Skip to main content

Satellites and Anime Girls

pix2pix GAN implementation*
Created on December 18|Last edited on December 20


Abstract

Нужно было реализовать метод из статьи и поиграться с ним.

Task 1

Базовая конфигурация

Модель: UNet128
UNet(
(down_blocks): ModuleList(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): ConvBlock(
(features): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(2): ConvBlock(
(features): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(3): ConvBlock(
(features): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): ConvBlock(
(features): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): ConvBlock(
(features): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): ConvBlock(
(features): Sequential(
(0): LeakyReLU(negative_slope=0.2, inplace=True)
(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
)
)
(up_blocks): ModuleList(
(0): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.5, inplace=False)
)
)
(2): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.5, inplace=False)
)
)
(3): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
(3): Dropout(p=0.5, inplace=False)
)
)
(4): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): DeconvBlock(
(features): Sequential(
(0): ReLU(inplace=True)
(1): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(2): Tanh()
)
)
)
)
За основу взяли UNet из домашки по зрению week06_vision, а архитекутуру подсмотрели тут.
Данные: задача sat2map, датасет скачали здесь же.
Аугментации: RandomHorizontalFlip(p=0.5), RandomPad(mode="reflect"). Суть второй аугментации в том, что из одной картинки составляется вот такой квадрат 3×33\times 3 картинки:
Пример RandomPad
А затем из него берется случайный кроп размером с исходную картинку:
Пример данных
Преобразования данных: масштабирование в 256×256256\times 256, нормализация [0,255]  [1,1][0, 255]\ \to\ [-1, 1].
Функция потерь: L1Loss + EdgeLoss. Про EdgeLoss не было в статье, но по опыту помогает генерировать более резкие картинки. Работает очень просто:
EdgeLoss = PreprocessWrapper(EdgeDetector(), nn.L1Loss())

class EdgeDetector(nn.Module):
sobel_x = torch.tensor(
[[1, 0, -1],
[2, 0, -2],
[1, 0, -1]],
dtype=torch.float32
)
sobel_y = sobel_x.T
gaussian = GaussianBlur(channels=3, kernel_size=3)
luma_estimator = LuminanceEstimator()

def forward(self, x: Tensor) -> Tensor:
denoised = self.gaussian(x)
luma = self.luma_estimator(denoised)
edges = (F.conv2d(luma, self.sobel_x) ** 2 +
F.conv2d(luma, self.sobel_y) ** 2) ** (1 / 2)
edges = F.pad(edges, (1, 1, 1, 1), mode="replicate")
return edges
Как видим, берется яркость картинки, немного размывается чтобы убрать шум и выделяются контуры (эджи) с помощью оператора Собеля. Оценивать яркость можно, например, так:
def estimate_luma(x: Tensor) -> Tensor:
# Power curve to linear RGB
x = x ** 2.2
# Estimate luminance
x = (0.2126 * x[:, 0] + 0.7152 * x[:, 1] + 0.0722 * x[:, 2]).unsqueeze(1)
return x
Пример извлеченных контуров. Сверху - выход генератора, снизу - целевые изображения.
Метрики: PSNR, SSIM, LPIPS. Основная валидационная метрика LPIPS, по ее значению сохраняем лучший чекпоинт.
Остальные параметры:
num_epochs = 50
batch_size = 128
lr = 0.0002 # Из статьи
betas = (0.5, 0.999) # Из статьи

Результаты

На графиках можно увидеть сравнение UNet с BatchNorm и InstanceNorm. Явно видно преисущество BatchNorm (хотя тут большой батч, не такой как в статье), его и будем дальше использовать. Градиент в одном из экспериментов взорвался, но это позже было исправлено с помощью градиент клиппинга и возврата к предыдущему чекпоинту в случае взрыва.


В целом результаты нормальные, но картинки довольно мутные. Как мы знаем, чтобы улучшить резкость можно воспользоваться функцией потерь, основанной на корреляции глубоких признаков предобученной модели.
FID = 222.5 на 50 эпохе (посчитано с помощью pytorch-fid на валидации)

Добавляем VGGLoss

Функция потерь: 10 * L1Loss + 10 * EdgeLoss + 0.3 * VGGLoss
class VGGPerceptualLoss(nn.Module):
def __init__(self):
super().__init__()
# Get pretrained VGG model
self.vgg_model = torchvision.models.vgg16(pretrained=True)
# Remove classifier part
self.vgg_model.classifier = nn.Identity()
# Remove layers with deep features
self.vgg_model.features = nn.Sequential(*self.vgg_model.features[:22])
# Freeze model
self.vgg_model.eval()
for param in self.vgg_model.parameters():
param.requires_grad = False
# L1 loss instance
self.loss = nn.L1Loss()

self.mean = (0.48235, 0.45882, 0.40784)
self.std = (0.229, 0.224, 0.225)

def forward(self, input: Tensor, target: Tensor) -> Tensor:
out = self.loss(self.vgg_model(norm(input / 2 + 0.5, self.mean, self.std)),
self.vgg_model(norm(target / 2 + 0.5, self.mean, self.std)))
return out
То есть просто прогоняем выход генератора и целевую картинку через преобученную на ImageNet VGG и считаем L1Loss от выходов с 10 сверточного слоя.
Модель: UNet256. Картинки размером 256x256, значит моделька не совсем правильная была. На самом деле это влияет очень слабо, потому что разница только в самых центральных слоях, где размер признаков 1x1. Перетащили веса из предыдущей обученной модели в новую, как именно - см. в тетрадке transfer_weights.ipynb.
Остальные параметры: как в предыдущем, но
num_epochs = 75
batch_size = 8 # На больших батчах получаются более искаженные картинки
# с извилистыми линиями

Результаты



LPIPS стал заметно лучше, да и картинки почетче.
FID = 119.2 на 75 эпохе

Добавляем skip-connections

Идея: давайте добавим в UNet residual learning. Будем прокидывать признаки между слоями с помощью билинейной интерполяции и pointwise 1x1 свертки, где требуется поменять число каналов. По примеру ConvNeXt не будем выполнять активацию на residual признаках, а только внутри сверточного блока.
Модель: UNet256Skip. Предобученные параметры из предыдущей модели опять перенесли в новую.
Остальные параметры: как в предыдущем, но
num_epochs = 150

Результаты



Новая модель быстро догнала и перегнала старую по метрикам, обратите внимание на появление зеленого парка-прямоугольника на примере генерируемых картинок во втором ряду сверху.
FID = 100.6 на 112 эпохе

Task 2

Добавляем adversarial потери

Дискриминатор: PatchDiscriminator
PatchDiscriminator(
(features): Sequential(
(0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
)
Подсмотрено там же, где и архитектура UNet.
Функция потерь: 10 * L1Loss + 10 * EdgeLoss + 0.3 * VGGLoss + 0.25 * LSGAN. Взяли LSGAN, потому что он более стабильный, чем Vanilla GAN. При потерях дискриминатора меньше 0.3 (значит, что он слишком уверен в своих предсказаниях) его шаг пропускался, то есть обновлялся только генератор. Иначе дискриминатор просто сойдется в 100% качество и adversarial обучения не будет.
Остальные параметры: как в предыдущем, но
num_epoch = 100
batch_size = 4 # См. картинки, 8 слишком усредняет
generator_lr = 0.0002 # Из статьи
discriminator_lr = 0.0002 # Из статьи

Результаты



По картинкам особо не понять, но LPIPS и FID на валидации стали лучше.
FID = 88.0 на 100 эпохе

Добавляем второй датасет

В качестве второй задачи возьмем генерацию картинки по ее контурам (зря мы что ли EdgeLoss писали?) Естественно, генерировать будет аниме-девочек. Чтобы заиметь нужный датасет, пришлось немало потрудиться, но оно того стоило(?)
Ясно, что если просто взять датасет Danbooru2021, то будет слишком много разных картинок. Хотелось бы выделить какое-то подмножество, желательно с постоянным набором персонажей, чтобы модель научилась правильно их раскрашивать (а не просто как-то).
Кроме того, хотелось бы иметь белый задний фон, чтобы сетка не мучилась восстанавливать то, о чем информации совсем мало (фон может быть какой угодно, по его контурам фиг поймешь, какого он цвета). Но если выделить только такие картинки, то их окажется маловато.
В итоге было принято решение скачать картинки из актуальной версии базы danbooru с помощью API. Ниже приведен набор тегов, по которым производилась селекция:
["monogatari_(series)", "white_background", "rating:sensitive,general"]
Скрипт для скачивания данных полностью оригинален.
Аугментации: RandomCrop(), RandomHorizontalFlip(p=0.5)
Преобразования данных:
Обучение: масштабирование в 256 по малой стороне, аугментации, затем выделение контуров с помощью алгоритма Canny(t1=100, t2=200), нормализация [0,255]  [1,1][0, 255]\ \to\ [-1, 1]. Контуры инвертируются, то есть получаются черные линии на белом фоне (как и в целевой картинке).
Валидация: масштабирование в 256 по малой стороне, TopCrop() (потому что на картинках чаще всего лицо сверху) , затем то же самое, что и в обучении после аугментаций.
Пример данных с цветным фоном, слева совершенно не видно, что он оранжевый
Пример данных с белым фоном

Supervised режим: L1

Функция потерь: L1Loss
Остальные параметры:
num_epochs = 50
batch_size = 8
lr = 0.0002
betas = (0.5, 0.999)
Дополнительно: был реализован метод пост-процессинга предсказанных картинок, дающий более гладкие градиенты цвета. Широкоформатные иллюстрации далее обработаны с помощью этого метода. Он незначительно влияет на FID (~0.1 - 1.0), но позволяет получить более приятные для глаза изображения.
Идея метода состоит в оптимизации ColorConsistencyLoss с помощью градиентного спуска по картинке.
class ColorConsistencyLoss(nn.Module):
def __init__(self, kernel_size, epsilon):
super().__init__()
self.gaussian_blur = GaussianBlur(channels=1, kernel_size=kernel_size)
self.epsilon = epsilon

def forward(self, input: Tensor, target: Tensor) -> Tensor:
inverted_edges = target
inverted_edges = self.gaussian_blur(inverted_edges)
inverted_edges_density = inverted_edges * (inverted_edges > self.epsilon)

x = input
padded = F.pad(x, (1, 1, 1, 1), mode="replicate")

tl = padded[:, :, :-2, :-2] # Top left
t = padded[:, :, :-2, 1:-1] # Top
tr = padded[:, :, :-2, 2:] # Top right
r = padded[:, :, 1:-1, 2:] # Right
br = padded[:, :, 2:, 2:] # Bottom right
b = padded[:, :, 2:, 1:-1] # Bottom
bl = padded[:, :, 2:, :-2] # Bottom left
l = padded[:, :, 1:-1, :-2] # Left

se = (tl - x) ** 2 + (t - x) ** 2 + (tr - x) ** 2 + \
(r - x) ** 2 + (br - x) ** 2 + (b - x) ** 2 + \
(bl - x) ** 2 + (l - x) ** 2
weighted_se = se * inverted_edges_density

return weighted_se.mean()
По сути происходит усреднение цветов внутри замкнутых областей, определяющихся маской inverted_edges_density.
Маска inverted_edges_density разбивает картинку на области, внутри которых применяется заливка
Верхний ряд - применили пост-процессинг, нижний ряд - сырой выход генеративной модели

Результаты



Цвета получились очень блеклые, выцветшие.
Пример валидации, эпоха 52. Здесь и далее: верхний ряд - вход, средний ряд - предсказание, нижний ряд - целевая картинка.
FID = 231.6 на 52 эпохе

Добавляем VGGLoss

Функция потерь: L1Loss + 0.03 * VGGLoss. EdgeLoss здесь бессмыслен, т.к. мы и так подаем контуры на вход. Кроме того, градиенты от него взрываются, слишком сильный фидбек.
Предобученные веса: из предыдущего эксперимента.
Остальные параметры: как в предыдущем, но
num_epoch = 200

Результаты



Это уже ощутимо лучше. Цвета все равно недожатые, но сетка явно научилась распознавать персонажей и примерно правильно их раскрашивать.
Пример валидации, эпоха 190
FID = 126.6 на 190 эпохе

Добавляем adversarial потери

Дискриминатор: PatchDiscriminator
Функция потерь: L1Loss + 0.03 * VGGLoss + 0.02 * LSGAN
Предобученные веса: из предыдущего эксперимента.
Остальные параметры: как в предыдущем, но
num_epoch = 100
generator_lr = 0.0002 # Из статьи
discriminator_lr = 2e-5

Результаты



Особых изменений не наблюдается.
Пример валидации, эпоха 88
FID = 120.5 на 88 эпохе

Улучшения

В экспериментах с высоким коэффициентом GAN потерь наблюдались вот такие артефакты:


Это наводит на мысль о том, что генератор и дискриминатор подстраиваются друг под друга с точностью до пикселя и обманывают друг друга с помощью вот такой градиентной атаки. Чтобы это исправить, достаточно в 50% случаев подавать в дискриминатор картинку, сдвинутую на 1 пиксель (например, вправо вниз).
class RandomShift(nn.Module):
def __init__(self, shift=(0, 1, 0, 1)):
super().__init__()
self.shift = shift

def forward(self, x: Tensor) -> Tensor:
if self.training and np.random.rand() < 0.5:
x = F.pad(x, self.shift, mode="replicate")
x = x[:, :, self.shift[3]:, self.shift[1]:] # сори за говнокод
return x
Заодно применим тот же самый трюк и для VGG потерь, благодаря глубине признаков можно даже сдвигать prediction и target независимо друг от друга, это будет действовать как дополнительная регуляризация.
Теперь мы можем увеличить коэффициент GAN потерь, не опасаясь получить артефакты.
Функция потерь: L1Loss + 0.03 * VGGLoss + 0.1 * LSGAN
Предобученные веса: из предыдущего эксперимента.
Остальные параметры: как в предыдущем, но
discriminator_lr = 2e-5 or 0.0002

Результаты



Видно, что у эксперимента с большим коэффициентом GAN потерь и большим discriminator_lr меньше SSIM и PSNR, но больше LPIPS, что и есть adversarial эффект. Отличные цвета (обратите внимание на вторую колонку, волосы стали более фиолетовыми).
Пример валидации, эпоха 18
FID = 121.3 на 23 эпохе

Back to satellite

А что будет, если наши новые знания применить снова к задаче sat2map? Ответ вот:
(розовый - предыдущий лучший GAN)


Даже дорожки на траве стали генерироваться!
FID = 79.9 на 93 эпохе

А также

Еще много чего можно написать, делал эксперименты с BlurPool из AntiAliasedCNNs, но оно не очень нужно оказалось, случайный сдвиг все решает.

Воспроизводимость

Скачать данные можно с помощью ноутбука dataset.ipynb. Там находятся ссылки на кастомные датасеты (чтобы не грузить их скриптом через тор, да и содержимое сайта может поменяться).
Чтобы воспроизвести результаты достаточно запустить (предварительно поместив данные в нужное место)
python3 train.py -c CONFIG
В папке src/config вы найдете множество файлов конфигурации для различных экспериментов. В ноутбуке train.ipynb можно посмотреть пример запуска скриптов и установку правильной конфигурации для датасферы.
Для применения уже обученной модели написан скрипт predict
python3 predict.py -c CONFIG --gen_pretrained WEIGHTS
Где WEIGHTS - путь к файлу (без расширения!) полного чекпоинта или чекпоинта весов модели.
Все обученные веса, упомянутые в данном отчете, лежат в папке на гугл диске.
Скрипт для подготовки данных для подсчета FID и другие полезности находятся в папке bin. Все скрипты снабжены мануалом, доступным по флагу -h.