Генеративные модели
В этом курсе мы в основном работали с размеченными данными. Мы научили нейронные сети решать задачи классификации, сегментации и т.д.
В этой лекции мы разберемся, как научить нейросеть создавать что-то новое.
Как подойти к такой задаче с помощью нейронных сетей?
Постановка задачи генерации
Дано: неразмеченные данные
Выход: новые данные, которые будут удовлетворять следующим условиям:
Предположим, у нас под рукой есть генератор случайных чисел, который позволяет нам легко получить случайные числа в диапазоне $[0,1]$:
import numpy as np
import matplotlib.pyplot as plt
uniform_noise = np.random.uniform(size=1000)
plt.hist(uniform_noise)
plt.title("uniform distribution U")
plt.ylabel("counts")
plt.xlabel("value")
plt.show()
Пусть $U_i$ — точка, которую даёт нам генератор равномерно распределенных случайных чисел. Если мы возьмём пару таких таких точек $U_1$, $U_2$ и произведём с ними так называемое преобразование Бокса-Мюллера, то получим пару новых точек $X$ и $Y$ на вещественной прямой:
Повторив такую процедуру для большого числа точек $U_i$, можно заметить, что распределение для объектов $X$ и $Y$ становится похожим на стандартное нормальное:
U1 = np.random.uniform(size=1000)
U2 = np.random.uniform(size=1000)
R = np.sqrt(-2 * np.log(U1))
Theta = 2 * np.pi * U2
X = R * np.cos(Theta)
Y = R * np.sin(Theta)
plt.hist(X)
plt.title("Box–Muller transform of U")
plt.ylabel("counts")
plt.xlabel("value")
plt.show()
gaussian_noise = np.random.normal(size=1000)
plt.hist(gaussian_noise)
plt.title("Gaussian distribution N")
plt.ylabel("counts")
plt.xlabel("value")
plt.show()
На самом деле мы можем в точности доказать, что преобразование Бокса-Мюллера преобразовало исходное равномерное распределение в стандартное нормальное:
Решим уравнение относительно переменных $U_i$: $$ \large \begin{array}{l} U_{1}=e^{-\left(X^{2}+Y^{2}\right) / 2} \\ U_{2}=\frac{1}{2 \pi} \arctan \left(\frac{X}{Y}\right) \end{array} $$
Матрица Якоби для такого преобразования:
Отсюда имеем:
То есть мы получили пару независимых случайных чисел $X$ и $Y$, каждое из которых лежит в нормальном распределении: $$ X\perp Y, \quad X \sim \mathcal{N}(0,1),\quad Y\sim \mathcal{N}(0,1) $$
Что делать, если нам хочется сгенерировать объекты не в стандартном нормальном распределении, а в каком-то более сложном? В общем случае нам необходимо подобрать некоторую функцию, которая будет отображать объекты из известного "простого" распределения в "целевое". Осуществляющие такое преобразование функции будем называть генеративными моделями.
Если целевое распределение $\mathcal{F}[\vec {x}]$ известно аналитически, то для подбора такой функции можно воспользоваться различными модификациями алгоритма Markov Chain Monte Carlo.
Эволюция в генерации изображений лиц:
Разберемся с элементом случайности. В нейронных сетях мы привыкли к воспроизводимости результата: в режиме валидации мы можем несколько раз подать на вход один и тот же объект и получить один и тот результат. Возникает два вопроса:
Ответ на оба вопроса: подавать в качестве входа вектор случайного шума.
Почему именно вектор? Почему не одно случайное число?
Ответ: входной вектор можно рассматривать как признаки генерируемого объекта. Каждый такой признак — независимая случайная величина. Если мы будем передавать только одно случайное число, то генерация будет однообразной. Чем больше признаков (степеней свободы) у входного вектора, тем разнообразнее будет результат генерации.
То есть случайный шум большей размерности даёт нам больше вариабельности для генерации. Это называется input latent space — входное латентное пространство.
Note: из-за неустоявшейся терминологии случайное распределение на входе генератора называется латентным пространством так же, как и скрытое пространство в автоэнкодерах. Поэтому в этой лекции будем называть его входным латентным пространством. Также в статьях встречается вариант: predefined latent space.
Каждую компоненту латентного пространства можно рассматривать как отдельную шкалу, вдоль которой изменяются определенные свойства генерируемых объектов. Например, можно выбрать четыре латентных вектора и посмотреть, как генерируемые объекты плавно изменяются при переходе от одного вектора к другому.
В выборе размерности входного латентного пространства важно соблюсти баланс.
Пример: генератор лиц с входным вектором длины 1. Результатом работы генератора будет всего одна шкала, вдоль которой будут расположены генерируемые изображения. Скорее всего, генератор выучит наиболее простую и "очевидную" шкалу: от молодой женщины блондинки к пожилому мужчине брюнету. У такой сети будет низкая вариабельность — она не сможет сгенерировать, например, рыжего ребенка в очках.
При обучении модели количество точек в этом латентном пространстве будет настолько мало, что в основном пространство будет состоять из пустот. Тогда модель будет крайне некачественно генерировать объекты в точках латентного пространства, далеких от точек обучающей выборки.
Лучший способ выбрать длину вектора — это найти публикацию с похожей задачей и взять значение из нее.
Если такой информации нет, то придется экспериментировать. Лучше начинать с низкой размерности латентного пространства, чтобы наладить работу всей сети, пусть и с низким разнообразием, а затем проводить эксперименты по поиску оптимальной размерности.
Можно использовать собственные знания в предметной области: спросите себя, сколькими вещественными числами можно описать важную информацию об объекте.
Как мы знаем из лекции про обучение сети, инициализация весов и нормализация входных данных вносят существенный вклад в работу модели. Поэтому принято использовать многомерное нормальное распределение для input latent space. Оно лучше взаимодействует с весами модели и улучшает сходимость.
(как делать на практике НЕ нужно)
Давайте попробуем создать генератор точек на параболе (cамым тривиальным решением кажется подача случайного шума на вход сети) и будем ожидать на выходе точки параболы. Проверим, как это будет работать.
import numpy as np
import matplotlib.pyplot as plt
def gen_pair(num=100):
x = np.random.uniform(low=-1, high=1, size=(num,))
y = x * x
return np.hstack(
(x.reshape(-1, 1), y.reshape(-1, 1))
) # Create num of correct dots(x,y) on parabola
pairs = gen_pair(100)
plt.scatter(pairs[:, 0], pairs[:, 1])
plt.title("Random dots on parabola,\nwhich will use like a dataset.")
plt.show()
Возьмём размерность входного латентного пространства $ls = 1$ и объединим шум с точками в датасеты.
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
# Define input parameters
n_batches = 10
batch_size = 128
ls = 1 # latent space
# Generate random noise
noise = torch.randn(size=(n_batches * batch_size, ls), dtype=torch.float)
print(f"NN Input: noise.shape: {noise.shape}")
# Generates dots on parabola
xy_pair = gen_pair(num=(n_batches * batch_size))
xy_pair = torch.tensor(xy_pair, dtype=torch.float)
print(f"NN Output: xy_pair.shape: {xy_pair.shape}")
dataset = TensorDataset(noise, xy_pair) # model inputs, model outputs
trainset, testset = train_test_split(
dataset, train_size=0.8
) # split dataset for train and test
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
NN Input: noise.shape: torch.Size([1280, 1]) NN Output: xy_pair.shape: torch.Size([1280, 2])
Определим простую модель, которая будет ожидать шум на вход и генерировать точки на выходе. (Обратите внимание, что функция активации на последнем слое отсутствует, поскольку мы не ограничиваем наш генератор в каком-то диапазоне.)
import torch.nn as nn
class GenModel(nn.Module):
def __init__(self, latent_space):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_space, 50),
nn.ReLU(),
nn.Linear(50, 50),
nn.ReLU(),
nn.Linear(50, 2),
) # x,y
def forward(self, x):
return self.model(x)
Напишем функцию для оценки loss.
Так как мы не знаем, в каком месте параболы генератор создаст новую точку, то непонятно, с каким элементом из датасета ее сравнивать.
Вариант №1
Для сгенерированного $x$ аналитически вычислять $y_{target}=x*x$ и считать разницу между $y$, сгенерированным моделью, и $y_{target}$, вычисленным аналитически:
def custom_loss(pair, label):
# All inputs are batches
x_fake = pair[:, 0]
y_fake = pair[:, 1]
return torch.abs(x_fake * x_fake - y_fake).mean() # average by batch
Это будет работать.
Однако если мы знаем способ точно предсказать выход по входу, то задача уже решена и нейронная сеть не нужна.
Вариант №2
Найти в датасете точку $ target = (x_{target},y_{target})$, наиболее близкую к созданной генератором $ generated = (x,y)$, и использовать расстояние между этими точками в качестве loss.
$$\large Loss = min(dist(target_{i},generated))$$В пространстве высокой размерности такой поиск будет весьма ресурсозатратным, но в нашем учебном примере работать будет.
class Loss(nn.Module):
def __init__(self, targets):
super().__init__()
self.targets = targets # Remember all real samples, impossible in real world
def forward(self, input, dummy_target=None):
dist = torch.cdist(input, self.targets) # claculate pairwise distances (euc.)
min_dist, index = torch.min(dist, dim=1) # take the best
return min_dist.mean()
Вспомогательный код для вывода loss:
@torch.inference_mode()
def get_test_loss(model, loader):
test_data = next(iter(loader))
test_loss = Loss(test_data[1])
outputs = model(samples.to(device))
return test_loss(outputs.cpu())
Основной код обучения.
Целевые точки из датасета запоминаются в loss, затем идет обычный цикл обучения.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_epochs = 600
model = GenModel(latent_space=ls)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
all_train_targets = next(iter(train_loader))[1]
criterion = Loss(all_train_targets.to(device))
for epoch in range(num_epochs):
loss_epoch = 0
for samples, labels in train_loader:
optimizer.zero_grad()
outputs = model(samples.to(device))
loss = criterion(outputs.to(device), labels.to(device))
loss.backward()
optimizer.step()
loss_epoch += loss.item()
loss_test = get_test_loss(model, test_loader)
if epoch % 100 == 0:
print(
f"Epoch={epoch} train_loss={loss_epoch/len(train_loader):.4} test_loss={loss_test:.4}"
)
Epoch=0 train_loss=0.05103 test_loss=0.03449 Epoch=100 train_loss=0.004604 test_loss=0.008559 Epoch=200 train_loss=0.004065 test_loss=0.008372 Epoch=300 train_loss=0.002949 test_loss=0.006444 Epoch=400 train_loss=0.004212 test_loss=0.008335 Epoch=500 train_loss=0.002926 test_loss=0.01005
Посмотрим результаты генерации на шуме:
def test_image(model, pairs, ls=1):
model.eval().to("cpu")
noise = torch.tensor(np.random.normal(size=(1000, ls)), dtype=torch.float)
with torch.no_grad():
xy_pair_gen = model(noise)
xy_pair_gen = xy_pair_gen.detach().numpy()
plt.scatter(pairs[:, 0], pairs[:, 1], color="red", label="real")
plt.scatter(xy_pair_gen[:, 0], xy_pair_gen[:, 1], color="blue", label="generated")
plt.axis([-1, 1, 0, 1])
plt.legend()
plt.show()
model.to(device)
test_image(model, pairs)
Видно, что модель генерирует точки, лежащие на параболе, при этом все они лежат в довольно узком интервале по оси х.
Это неудивительно: в loss function мы прописали, что сгенерированная точка должна лежать на параболе, и модель обучилась. А информацию о том, в каких частях кривой должны оказаться точки, мы в loss никак не кодировали.
Более того, модель может научиться хорошо генерировать одну единственную точку, и при этом loss может стать нулевым.
Итак, надо решить две проблемы:
Будем наказывать нейронную сеть не напрямую стандартной loss function, а второй сетью, которая определяет, лежит ли сгенерированная точка на параболе.
Создадим сеть-классификатор точек (лежит/не лежит на параболе), которую назовём дискриминатор или критик.
class DisModel(nn.Module):
def __init__(self, n_points):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, 25),
nn.ReLU(),
nn.Linear(25, 15),
nn.ReLU(),
nn.Linear(15, 1), # real/fake
nn.Sigmoid(),
)
def forward(self, x):
x = x.view(-1, 2 * n_points)
return self.model(x)
В общем случае задача дискриминатора — определять, принадлежит ли объект к распределению обучающей выборки.
Итого мы имеем:
Мы будем подавать в дискриминатор правильные точки (чтобы он знал, как это должно выглядеть) и точки, которые выдаёт генератор, считая их подделкой.
Таким образом, генератор будет учиться подражать реальным данным, а дискриминатор будет учиться отличать реальные точки, от подделок.
Мы пришли к идее генеративно-состязательных нейронных сетей.
2014 Generative Adversarial Networks (Goodfellow et al., 2014) (Cited by 33430!!!)
Видео разбор оригинальной статьи
Генеративно-состязательную сеть описал Иан Гудфеллоу из компании Google (на тот момент) в 2014 году. Сейчас он возглавляет подразделение машинного обучения в Apple. Принцип состязательности в сети GAN нередко описывается через метафоры.
Еще со времен AlexNet мы знаем, что если мы что-то и умеем делать с нейросетями, то это классификаторы. В классическом GAN дискриминатор выполняет простейшую из задач классификации — бинарную классификацию (либо real, либо fake). А вот задача генерации каким-то прямым образом на тот момент решена не была.
Как использовать всю мощь классификатора для создания генератора?
Представим, что есть фальшивомонетчик $G$ (generator) и банкир с прибором для проверки подлинности купюр $D$ (discriminator).
Фальшивомонетчик черпает вдохновение из генератора случайных чисел в виде случайного шума $z$ и создает подделки $G(z)$.
Банкир $D$ получает на вход пачку купюр $x$, проверяет их подлинность и сообщает вектор $D(x)$, состоящий из чисел от нуля до единицы — свою уверенность (вероятность) по каждой купюре в том, что она настоящая. Его цель — выдавать нули для подделок $D(G(z))$ и единицы для настоящих денег $D(x)$. Задачу можно записать как максимизацию произведения $D(x)(1-D(G(z)))$, а произведение, в свою очередь, можно представить как сумму через логарифм.
Таким образом, задача банкира — максимизировать $log(D(x))+log(1-D(G(z)))$.
Цель фальшивомонетчика прямо противоположна — максимизировать $D(G(z))$, то есть убедить банкира в том, что подделки настоящие.
Продолжая аналогию, обучение генератора можно представить так: фальшивомонетчик не просто генерирует подделки наудачу. Он добывает прибор для распознавания подделок, разбирает его, смотрит, как тот работает, и затем создает подделки, которые смогут обмануть этот прибор.
Математически это игра двух игроков:
$$\large \min\limits_{\theta_g} \max\limits_{\theta_d} [\mathbb{E}_{x _\sim p(data)} log(D_{\theta_d}(x)]+\mathbb{E}_{z _\sim p(z)} [log(1-D_{\theta_d}(G_{\theta_g}(z))]$$Дискриминатор
Генератор
Посредством чередования градиентного подъема и спуска сеть можно обучить.
Градиентный подъем на дискриминаторе:
$$\large \max\limits_{\theta_d} [\mathbb{E}_{x _\tilde{}p(data)} log(D_{\theta_d}(x)+\mathbb{E}_{z _\tilde{}p(z)} log(1-D_{\theta_d}(G_{\theta_g}(z)))]$$Градиентный спуск на генераторе:
$$\large \min\limits_{\theta_g} \mathbb{E}_{z _\tilde{}p(z)} log(1-D_{\theta_d}(G_{\theta_d}(z)))$$Градиентный спуск на генераторе эквивалентен градиентному подъему
$$\large \max\limits_{\theta_g} \mathbb{E}_{z _\tilde{}p(z)} log(D_{\theta_d}(G_{\theta_d}(z)))$$В процессе совместного конкурентного обучения, если система достаточно сбалансирована, достигается минимаксное состояние равновесия, в котором обе сети эффективно учатся.
Сгенерированные удачно обученной нейросетью изображения практически неотличимы от настоящих.
Если хорошенько подумать, то можно прийти к выводу, что loss function в GAN — это не какая-то функция, заданная людьми, а еще одна нейросеть.
Преимущества GAN
Недостатки GAN
Определим наши генератор и дискриминатор
class Generator(nn.Module):
def __init__(self, latent_space, hidden_dim):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_space, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2),
) # x,y
def forward(self, x):
return self.model(x)
class Discriminator(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1), # real/fake
nn.Sigmoid(),
)
def forward(self, x):
return self.model(x)
Определим входные параметры
latent_dim = 10 # latent space
num_epochs = 10000
batch_size = 32
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
Обратите внимание, что у нас так же, как и в первом примере, есть переменная latent space. Это тот шум, из которого мы будем генерировать наши точки. Закон сохранения масс в действии — нельзя создать что-то из ничего!
Определим все необходимое для обучения:
torch.manual_seed(42)
np.random.seed(42)
criterion = nn.BCELoss() # Loss
gen = Generator(latent_space=latent_dim, hidden_dim=50).to(device)
disc = Discriminator(hidden_dim=50).to(device)
# 2 optimizers for Discriminator and Generator
optimizerD = torch.optim.Adam(disc.parameters(), lr=3e-4)
optimizerG = torch.optim.Adam(gen.parameters(), lr=3e-4)
# Fix noise to compare
fixed_noise = torch.randn(128, latent_dim, device=device)
Обратите внимание, что мы используем BCELoss
(Binary Cross Entropy). Давайте разберемся почему:
Определим функцию, которая создает точки на параболе:
def gen_pair(num=100):
x = np.random.uniform(low=-1, high=1, size=(num,))
y = x * x
return torch.tensor(
np.hstack((x.reshape(-1, 1), y.reshape(-1, 1))), dtype=torch.float
) # Create num of correct dots(x,y) on parabola
Что сейчас будет происходить? $$$$
from IPython import display
from IPython.display import clear_output
real_label = 1
fake_label = 0
# Main Training Loop
print("Training...")
print(device)
x = []
y_D = []
y_G = []
for epoch in range(num_epochs):
# max log(D(x)) + log(1 - D(G(z)))
# train on real points
disc.zero_grad()
# Define real points
real_points = gen_pair(num=batch_size).to(device)
label = torch.full(
(batch_size,), real_label, dtype=torch.float, device=device
).view(-1)
# Train disc on real_points
output = disc(real_points).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
# Define fake points
# This dots generated by generator transform from latent space
noise = torch.randn(batch_size, latent_dim, device=device)
fake_points = gen(noise)
label.fill_(fake_label)
# Train disc on fake_points
output = disc(fake_points.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
# Discriminator loss(real+fake)
errD = errD_real + errD_fake
optimizerD.step()
# max log(D(G(z)))
# Now, train generator
gen.zero_grad()
# Let's tell the discriminator that our generator creates real points
label.fill_(real_label)
output = disc(fake_points).view(-1)
errG = criterion(output, label)
errG.backward()
optimizerG.step()
# Plotting every N epoch
x.append(epoch)
y_D.append(errD.item() / 2)
y_G.append(errG.item())
if epoch % 250 == 0:
fig, ax = plt.subplots(nrows=2, figsize=(9, 6))
ax[0].plot(x, y_D, color="red", lw=1, label="D")
ax[0].plot(x, y_G, color="green", lw=1, label="G")
# Generates dots from fixed_noise
fake_points = gen(fixed_noise)
ax[1].scatter(
fake_points.detach().to("cpu")[:, 0],
fake_points.detach().to("cpu")[:, 1],
color="green",
alpha=0.5,
)
ax[1].scatter(
real_points.detach().to("cpu")[:, 0],
real_points.detach().to("cpu")[:, 1],
color="red",
alpha=0.5,
)
ax[1].set_xlim(-1, 1)
ax[1].set_ylim(0, 1)
clear_output(wait=True)
display.display(plt.gcf())
plt.close()
Класс! У нас получилось (если вдруг не сошлось за 10000 эпох, перезапустите заново, к сожалению, фиксация seed еще не гарантирует стабильность GAN). Особенно круто смотреть, как красиво loss дискриминатора и генератора сходятся.
С помощью GAN можно, разумеется, генерировать не только точки на параболе. Можно генерировать, например, изображения. Но появляются закономерные вопросы.
Самым простым ответом будет: взять шум, пропустить его через полносвязные слои и сделать reshape до нужного разрешения. В целом, это будет работать.
Однако DCGAN (Deep Convolutional GAN) использует сверточные и сверточно-транспонированные (convolutional и convolutional-transpose) слои в дискриминаторе и генераторе соответственно. Впервые метод DCGAN был описан в статье Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (Radford et al., 2015).
Ниже видна разница в генерации при помощи исключительно полносвязных слоёв и при помощи обратных свёрток. Очевидно, результат DCGAN лучше, чем GAN.
На вход генератора подают шум для создания разнообразных объектов. Этот шум представляет собой вектор в многомерном пространстве. Один вектор — один сгенерированный объект. Задача дискриминатора — преобразовать вектор в изображение.
Такое преобразование возможно при помощи транспонированных сверточных (convolution-transpose, иногда называют fractionally strided convolution) слоев. Как и обычные сверточные слои, эти слои используют сверточные ядра, но перед вычислением сверток они увеличивают размер исходного изображения, "раздвигая" пиксели и заполняя образующиеся промежутки между пикселями нулями.
Давайте кратко вспомним, что делают TC слои.
Transposed convolution проходит по всем пикселям входа и умножает их на обучаемое ядро свертки. При этом каждый одиночный пиксель превращается в фрагмент. Там, где фрагменты накладываются друг на друга, значения попиксельно суммируются.
Используя Transposed convolution с параметром stride = 2
, можно повышать размер карты признаков приблизительно в два раза, добавляя на нее мелкие детали.
x = torch.rand((1, 3, 10, 10)) * 255 # one 3-channel image with 10x10 size
print(x.shape)
torch.Size([1, 3, 10, 10])
convT = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=3)
y = convT(x)
print(y.shape) # One 3-chanells image with 12x12 size
torch.Size([1, 3, 12, 12])
Полученное изображение не похоже на входное, потому что были применены свёрточные ядра со случайными коэффициентами.
fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True)
ax[0].imshow(x[0].permute(1, 2, 0).detach().numpy().astype(np.uint8))
ax[1].imshow(y[0].permute(1, 2, 0).detach().numpy().astype(np.uint8))
ax[0].set_title("Input")
ax[1].set_title("After ConvTranspose")
plt.show()
Помимо обратных свёрток существуют другие методы повышения разрешения из низкой размерности.
Самый простой способ — выполнить повышение разрешения с помощью интерполяции. Давайте вспомним, что в PyTorch это осуществляется слоем Upsample
x = torch.rand((1, 3, 10, 10)) # one 3-channal image with 10x10 size
print("Input shape:", x.shape)
upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
y = upsample(x)
print("Output shape", y.shape)
Input shape: torch.Size([1, 3, 10, 10]) Output shape torch.Size([1, 3, 20, 20])
fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True)
ax[0].imshow((x[0].permute(1, 2, 0) * 256).detach().numpy().astype(np.uint8))
ax[1].imshow((y[0].permute(1, 2, 0).detach().numpy() * 256).astype(np.uint8))
ax[0].set_title("Input")
ax[1].set_title("After Upsample")
plt.show()
Давайте посмотрим на пример обученного DCGAN
use_gpu = True if torch.cuda.is_available() else False
model = torch.hub.load(
"facebookresearch/pytorch_GAN_zoo:hub", "DCGAN", pretrained=True, useGPU=use_gpu
)
clear_output()
import torchvision
num_images = 16
noise, _ = model.buildNoiseData(num_images)
with torch.no_grad():
generated_images = model.test(noise)
generated_images = (
generated_images.clamp(-1, 1) + 1
) / 2.0 # normalization to 0..1 range
fig, ax = plt.subplots(figsize=(16 * 3, 2 * 3))
ax.imshow(
torchvision.utils.make_grid(generated_images).permute(1, 2, 0).cpu().numpy(),
interpolation="nearest",
aspect="equal",
)
ax.axis("off")
plt.show()
Теперь давайте попробуем сами написать свой DCGAN и обучить его на датасете FashionMNIST
num_epochs = 2 # Num of epochs
batch_size = 64 # batch size
lr = 2e-4 # Learning rate
b1 = 0.5 # Adam: decay of first order momentum of gradient
b2 = 0.999 # Adam: decay of first order momentum of gradient
num_cpu = 8 # Num of cpu threads to generate batch
latent_dim = 100 # latent space
img_size = 32 # images size
channels = 1 # Num of channels
sample_interval = 450 # interval between image sampling
Обычно мы инициализируем веса случайным образом, но ничто не мешает нам инициализировать их так, как мы хотим. В оригинальной статье про DCGAN предложено инициализировать веса нормальным распределением с центром в нуле и стандартным отклонением 0,02:
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
Обратите внимание, как преобразуется шум в генераторе:
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.init_size = img_size // 4
self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size**2))
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, channels, 3, stride=1, padding=1),
nn.Tanh(),
)
def forward(self, z):
out = self.l1(z)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, bn=True):
block = [
nn.Conv2d(in_filters, out_filters, 3, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25),
]
if bn:
block.append(nn.BatchNorm2d(out_filters, 0.8))
return block
self.model = nn.Sequential(
*discriminator_block(channels, 16, bn=False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
)
# The height and width of downsampled image
ds_size = img_size // 2**4
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size**2, 1), nn.Sigmoid())
def forward(self, img):
out = self.model(img)
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
# Initialize Generator and Discriminator
generator = Generator()
discriminator = Discriminator()
# Initialize weight
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
# Define loss
criterion = nn.BCELoss()
generator.to(device)
discriminator.to(device)
criterion.to(device)
BCELoss()
Напишем функцию для отображения изображений
from torchvision.utils import make_grid
def show_gen_img(model, latent_dim=100):
z = Tensor(np.random.normal(0, 1, (9, latent_dim))) # define latent dim
# Generate noise from latent dim
sample_images = generator(z)
sample_images = sample_images.cpu().detach()
# Plotting images
grid = (
make_grid(sample_images, nrow=3, ncols=3, normalize=True)
.permute(1, 2, 0)
.numpy()
)
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(grid)
plt.axis("off")
return fig
Подгрузим данные и загрузим и их в Data Loader
import os
from torchvision import datasets, transforms
os.makedirs("../../data/mnist", exist_ok=True)
data_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
),
),
batch_size=batch_size,
shuffle=True,
)
clear_output()
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
Tensor = torch.cuda.FloatTensor
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(data_loader):
# Adversarial ground truths
valid = Tensor(imgs.shape[0], 1).fill_(1.0)
fake = Tensor(imgs.shape[0], 1).fill_(0.0)
# Configure input
real_imgs = imgs.type(Tensor)
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
# Generate a batch of images
gen_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
g_loss = criterion(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = criterion(discriminator(real_imgs), valid)
fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
batches_done = epoch * len(data_loader) + i
if batches_done % sample_interval == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, num_epochs, i, len(data_loader), d_loss.item(), g_loss.item())
)
fig = show_gen_img(generator)
plt.show()
[Epoch 0/2] [Batch 0/938] [D loss: 0.693183] [G loss: 0.697421]
[Epoch 0/2] [Batch 450/938] [D loss: 0.689533] [G loss: 0.741189]
[Epoch 0/2] [Batch 900/938] [D loss: 0.640673] [G loss: 0.859734]
[Epoch 1/2] [Batch 412/938] [D loss: 0.633906] [G loss: 0.818824]
[Epoch 1/2] [Batch 862/938] [D loss: 0.672962] [G loss: 0.855835]
Чтобы картинки обрели приличный вид, хватает 2 эпох. Чтобы стали выглядеть хорошо — 5 эпох.
cGAN расшифровывается как Conditional Generative Adversarial Net — это GAN с условием. Условие может быть любым, например, генерация конкретной цифры. В этом случае нам нужен уже размеченный датасет для того, чтобы обучить дискриминатор.
Обучение в данном случае будет аналогичным обучению GAN: мы будем обучать сети, чередуя реальные данные и сгенерированные, добавив label
.
Поскольку подавать в сеть числа от 0 до 9 (в случае MNSIT) нет смысла, то нужно придумать, как подавать их в нейронную сеть. На помощь приходят Embeddings. Мы можем представить каждую метку в виде вектора с десятью элементами.
samples, labels = next(iter(data_loader))
label_emb = nn.Embedding(10, 10)
e = label_emb(labels)
print(f"Label: {labels[0]}")
print(f"Embedding for this label: {e[0]}")
Label: 6 Embedding for this label: tensor([-0.7364, 0.5600, 0.0687, 0.0613, -1.9721, 0.8849, -1.2226, 0.3140, -1.0129, 0.9434], grad_fn=<SelectBackward0>)
После этого эмбеддинги меток обычно склеиваются с входами сетей.
Если мы будем подавать просто число, например 0.1 для единицы и 0.5 для пяти, то вход у нас будет непрерывным, что довольно нелогично: тогда при небольшом изменении входа мы будем генерировать другую цифру. А также сети будет сложнее выучить небольшие расхождения в этом небольшом интервале. В случае с векторным представлением мы избегаем этих проблем.
Метки классов можно подавать не только способом, описанным выше. Можно вместо подачи их в дискриминатор сделать так, чтобы он их предсказывал — Semi-Supervised GAN.
Или же не подавать label в дискриминатор, но ждать от него классификации в соответствии с классом, который мы хотим получить от генератора — это InfoGAN
Ещё одна модификация cGAN — это AC-GAN (auxiliary classifier), в которой единственное различие заключается в том, что дискриминатор должен помимо распознавания реальных и фейковых изображений ещё и классифицировать их. Он имеет эффект стабилизации процесса обучения и позволяет генерировать большие высококачественные изображения, изучая представление в скрытом пространстве, которое не зависит от метки класса.
[-1,1]
, при этом генерируются данные [0,1]
. Это нехорошо, так как это подсказка для дискриминатора.(28,28)
, а генератор выдает (32,32)
. В таком случае нужно либо изменить архитектуру генератора, чтобы получать на выходе размер (28,28)
, либо сделать ресайз реальных данных до (32,32)
.BatchNorm
. Проблема BN
в том, что во время обучения его внутренняя статистика считается по минибатчу, а во время инференса она вычисляется как moving average, что в свою очередь может повлечь непредсказуемые результаты. Если архитектура GAN предполагает нормализацию, то лучше использовать InstanceNorm
.Если ваша версия GAN работает не так хорошо, как вам хотелось бы, попробуйте дать своему дискриминатору преимущество, обучив его на относительно большее число итераций, чем генератор. Другими словами, чем лучше дискриминатор различает настоящие и фальшивые данные, тем лучше сигнал, который генератор может извлечь из него. Обратите внимание, что эта логика не имела смысла во времена "до WGAN", поскольку слишком хорошая работа дискриминатора вредила обучению.
Например:
def train_on_batch(x, iter_, n_gen=5):
Generator:
...
...
if iter_ % n_gen == 0:
g_loss.backward()
opt_g.step()
Discriminator:
...
...
d_loss.backward()
d_loss.step()
Где iter_
— текущая итерация шага градиента, а n_gen
определяет интервал между обновлениями генератора. В данном случае, поскольку он равен 5, мы можем считать, что это означает, что дискриминатор обновляется в 5 раз чаще, чем генератор.
Естественно, работает не всегда и не везде. Но попробовать стоит
Можно обратить внимание, что почти во всех статьях по GAN используется ADAM. Сложно сказать, почему так получается, но он работает, и работает очень хорошо. Если качество вашего GAN оставляет желать лучшего, скорее всего оптимизатор тут не при чем. Ищите ошибку где-то еще.
Параметр epsilon
ADAM по умолчанию в PyTorch равен 1e-8
, что может вызвать проблемы после длительного периода обучения, например, значения функции потерь периодически взрываются или увеличиваются. Подробнее об этом на StackOverflow и в комментарии на Reddit.
Базовая публикация по теме.
Пусть мы имеем некоторый объект $\mathbf{x}_0$ из заданного имеющимся датасетом распределения $\mathbf{x}_0 \sim q(\mathbf{x})$. Определим так называемый прямой диффузный процесс, в ходе которого мы будем последовательно добавлять небольшое количество Гауссового шума последовательно $T$ раз, создавая из нашего объекта последовательность $\mathbf{x}_1, \dots, \mathbf{x}_T$ постепенно всё более зашумленных версий нашего исходного объекта. Параметры добавляемого шума зависят от номера шага зашумления и их набор $\{\beta_t \in (0, 1)\}_{t=1}^T$ задан заранее.
Исходное распределение данных из датасета будет последовательно преобразовано в Гауссов шум:
$$\large q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$$Важным свойством описанного выше процесса является то, что он допускает аналитическое вычисление шума на любом из шагов процесса в явном виде. Это связано с тем, что сумма нескольких нормально распределенных случайных величин также является случайной величиной с нормальным распределением. Введём обозначения из оригинальной статьи $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$ и $\alpha_t = 1 - \beta_t$. Тогда:
$\large \text{где } \boldsymbol{\epsilon}_{t-1}, \boldsymbol{\epsilon}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$
$\large \qquad \bar{\boldsymbol{\epsilon}}_{t-2} \text{ – новая нормально распределенная величина (*)}.$
(*) Напомним, что когда мы складываем две нормально распределенные величины с разной дисперсией и нулевым средним $\mathcal{N}(\mathbf{0}, \sigma_1^2\mathbf{I})$ и $\mathcal{N}(\mathbf{0}, \sigma_2^2\mathbf{I})$, то получаем новую нормально распределенную случайную величину $\mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I})$, что в нашей параметризации означает $\sqrt{(1 - \alpha_t) + \alpha_t (1-\alpha_{t-1})} = \sqrt{1 - \alpha_t\alpha_{t-1}}$
import torch.nn as nn
import torch.nn.functional as F
class ResNetBlock(nn.Module):
def __init__(self, in_features, out_features, mid_features=None, residual=False):
super().__init__()
self.residual = residual
if not mid_features:
mid_features = out_features
self.conv_stack = nn.Sequential(
nn.Conv2d(in_features, mid_features, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(1, mid_features),
nn.GELU(),
nn.Conv2d(mid_features, out_features, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(1, out_features),
)
def forward(self, x):
if self.residual:
return F.gelu(x + self.conv_stack(x))
else:
return self.conv_stack(x)
Блок уменьшения размера:
Блок увеличения размера:
import torch
class ResizeBlock(nn.Module):
def __init__(self, in_features, out_features, emb_dim):
super().__init__()
# defines non-linear map from time embedding features to conv features
self.emb_projection = nn.Sequential(
nn.SiLU(),
nn.Linear(emb_dim, out_features),
)
def add_emb(self, x, t_vector):
# [batch_size, time_embedding_dim] -> [batch_size, out_features]
emb = self.emb_projection(t_vector)
# [batch_size, out_features] - > [batch_size, out_features, H, W]
emb = emb[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb
def forward(self, x):
pass
class Down(ResizeBlock):
def __init__(self, in_features, out_features, emb_dim=256):
super().__init__(in_features, out_features, emb_dim)
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
ResNetBlock(in_features, in_features, residual=True),
ResNetBlock(in_features, out_features),
)
def forward(self, x, t):
x = self.maxpool_conv(x)
x = self.add_emb(x, t)
return x
class Up(ResizeBlock):
def __init__(self, in_features, out_features, emb_dim=256):
super().__init__(in_features, out_features, emb_dim)
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = nn.Sequential(
ResNetBlock(in_features, in_features, residual=True),
ResNetBlock(in_features, out_features, in_features // 2),
)
def forward(self, x, skip_x, t):
x = self.up(x)
x = torch.cat([skip_x, x], dim=1)
x = self.conv(x)
x = self.add_emb(x, t)
return x
class SelfAttention(nn.Module):
def __init__(self, channels, size):
super().__init__()
self.channels = channels
self.size = size
self.mha = nn.MultiheadAttention(channels, num_heads=4, batch_first=True)
self.ln = nn.LayerNorm([channels])
self.mlp = nn.Sequential(
nn.LayerNorm([channels]),
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels),
)
def forward(self, x):
# [batch_size, C, H, W] -> [batch_size, H*W, C]
x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
x_ln = self.ln(x)
# K, Q, V in Self attention are equal
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
# skip connection
attention_value = attention_value + x
# simple multilayer perceptron and second skip connection
attention_value = self.mlp(attention_value) + attention_value
# back to [batch_size, C, H, W]
return attention_value.swapaxes(2, 1).view(
-1, self.channels, self.size, self.size
)
class UNet(nn.Module):
r"""
Denoising U-Net model implementation based on arXiv:2006.11239 [cs.LG]
"""
def __init__(self, num_channels=3, img_size=64, time_enbed_dim=256):
super().__init__()
self.num_channels = num_channels
self.time_dim = time_enbed_dim
self.img_size = img_size
# Downsample and enlarge feature dim
self.inc = ResNetBlock(num_channels, 64)
self.down1 = Down(64, 128)
self.sa1 = SelfAttention(128, img_size // 2)
self.down2 = Down(128, 256)
self.sa2 = SelfAttention(256, img_size // 4)
self.down3 = Down(256, 256)
self.sa3 = SelfAttention(256, img_size // 8)
# Keep spatial dim constant
self.conv_bottleneck = nn.Sequential(
ResNetBlock(256, 512),
SelfAttention(512, img_size // 8),
ResNetBlock(512, 256),
)
# Upsample and reduce feature dim
# 512=256+256 from conv_bottleneck and sa3
self.up1 = Up(512, 128)
self.sa4 = SelfAttention(128, img_size // 4)
# 256=128+128 from sa4 and sa2
self.up2 = Up(256, 64)
self.sa5 = SelfAttention(64, img_size // 2)
# 128=64+64 from sa5 and sa1
self.up3 = Up(128, 64)
self.sa6 = SelfAttention(64, img_size)
self.outc = nn.Conv2d(64, num_channels, kernel_size=1)
def pos_encoding(self, t):
r"""
Returns embedding vector for given integer time index.
We adopt 1d Positional Encoding form arXiv:1706.03762 [cs.CL]
see 3.5 for more details.
PE(x,2i) = sin(x/10000^(2i/D))
PE(x,2i+1) = cos(x/10000^(2i/D))
Where:
x is a point in 1d space
i is an integer in [0, D/2), where D is the size of the feature dimension
Args:
t: Tensor, shape ``[batch_size, 1]``
Returns:
pe: Tensor, shape ``[batch_size, time_embedding_dim]``
"""
# placeholder for diffusion time encoding vector
pe = torch.zeros(t.shape[0], self.time_dim).to(t)
# factor 1/10000^(2i/D)
div_factors = torch.exp(
torch.arange(0, self.time_dim, 2)
* (-torch.log(torch.as_tensor(10000.0)) / self.time_dim)
).to(t)
# repeat t index for each feature
x = t.repeat(1, self.time_dim // 2)
# sin(x/10000^(2i/D)) for even features
pe[:, 0::2] = torch.sin(x * div_factors)
# cos(x/10000^(2i/D)) for odd features
pe[:, 1::2] = torch.cos(x * div_factors)
return pe
def forward(self, x, t):
t = t.unsqueeze(-1).type(torch.float).to(x)
t = self.pos_encoding(t)
x1 = self.inc(x)
x2 = self.down1(x1, t)
x2 = self.sa1(x2)
x3 = self.down2(x2, t)
x3 = self.sa2(x3)
x4 = self.down3(x3, t)
x4 = self.sa3(x4)
x4 = self.conv_bottleneck(x4)
x = self.up1(x4, x3, t)
x = self.sa4(x)
x = self.up2(x, x2, t)
x = self.sa5(x)
x = self.up3(x, x1, t)
x = self.sa6(x)
output = self.outc(x)
return output
from PIL import Image
from torch.optim import AdamW
from tqdm.notebook import tqdm
from torchvision.utils import make_grid
class DiffusionGenerativeModel:
def __init__(
self,
dataset,
model_name="unconditioned_baseline",
num_noise_steps=1000,
beta_start=1e-4,
beta_end=0.02,
img_size=64,
denoising_model_class=UNet,
batch_size=2,
lr=0.001,
device="cpu",
):
self.model_name = model_name
self.num_noise_steps = num_noise_steps
self.beta_start = beta_start
self.beta_end = beta_end
self.img_size = img_size
self.batch_size = batch_size
self.lr = lr
self.device = device
# diffusion process linear noise schedule
self.beta = self._get_noise_schedule().to(self.device)
self.alpha = 1.0 - self.beta
# \hat{\alpha}_{i-1} = \prod_{j=0}^{i-1} \alpha_j
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
# init denoising model
self.denoising_model = denoising_model_class(img_size=img_size).to(self.device)
# init dataset and dataloader
self.dataset = dataset
self.dl = torch.utils.data.DataLoader(
self.dataset, batch_size=self.batch_size, shuffle=True, drop_last=True
)
# init optimizer and loss for training
self.criterion = nn.MSELoss()
self.optimizer = AdamW(self.denoising_model.parameters(), lr=self.lr)
def _load_pretrained_model(self, path=None):
if path is not None:
self.denoising_model.load_state_dict(
torch.load(path, map_location=self.device)
)
else:
self.denoising_model.load_state_dict(
torch.load(f"{self.model_name}.pt", map_location=self.device)
)
def _get_noise_schedule(self):
return torch.linspace(self.beta_start, self.beta_end, self.num_noise_steps)
def _noise_images_batch(self, x, t):
# \mu_i = \sqrt{\hat{\alpha}_i}
mu = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
# \sigma_i = \sqrt{1 - \hat{\alpha}_i}
sigma = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
standard_normal_noise = torch.randn_like(x)
noised_image_batch = mu * x + sigma * standard_normal_noise
return noised_image_batch, standard_normal_noise
def _get_timestams_batch(self):
return torch.randint(low=1, high=self.num_noise_steps, size=(self.batch_size,))
def generate_images_batch(self, use_best_model=False):
if use_best_model:
self.denoising_model.load_state_dict(
torch.load(f"{self.model_name}.pt", map_location=self.device)
)
# stop tracking normalisation stats
self.denoising_model.eval()
# avoid training on generated samples
with torch.no_grad():
# start from pure noise batch
x = torch.randn((self.batch_size, 3, self.img_size, self.img_size)).to(
self.device
)
# and apply self.num_noise_steps denoising steps with model
for t_i in tqdm(reversed(range(1, self.num_noise_steps)), position=0):
# Build tensor with timestamp index. Same for each element in batch
t = torch.full((self.batch_size,), t_i).long().to(self.device)
# predict noise on current timestamp
pred_noise = self.denoising_model(x, t)
# restore noise parametrs on current timestamp
alpha = self.alpha[t][:, None, None, None]
alpha_hat = self.alpha_hat[t][:, None, None, None]
beta = self.beta[t][:, None, None, None]
# partialy denoise batch of images
x = x - (1.0 - alpha) / (1 - alpha_hat).sqrt() * pred_noise
x = (1 / alpha.sqrt()) * x
# add appropriate amount of noise for next step if any
if t_i > 0:
z = torch.randn_like(x).to(self.device)
x = x + beta.sqrt() * z
# clip x to valid 0..255 image range
x = x.clamp(0, 1)
x = (x * 255).type(torch.uint8)
# reenable tracking normalisation stats
self.denoising_model.train()
return x
def _train_step(self, batch):
# unpack data
images, labels = batch
t = self._get_timestams_batch()
# send tensors to device
images = images.to(self.device)
t = t.to(self.device)
# prep batch of noised images
noised_images, target_noise = self._noise_images_batch(images, t)
# estimate noise with U-Net
predicted_noise = self.denoising_model(noised_images, t)
# optimize model to fit target noise
loss = self.criterion(predicted_noise, target_noise)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def fit(self, n_epochs=20, save_images=True):
best_loss = float("inf")
for epoch in tqdm(range(n_epochs), position=1):
ep_loss = 0.0
with tqdm(total=len(self.dl), position=2) as progress_bar:
for batch in self.dl:
loss_step = self._train_step(batch)
# loss per single object
loss_step = loss_step / len(self.dl) * self.batch_size
ep_loss += loss_step
progress_bar.update(1)
log_string = f"Loss at epoch {epoch + 1}: {ep_loss:.3f}"
if best_loss > ep_loss:
best_loss = ep_loss
torch.save(self.denoising_model.state_dict(), f"{self.model_name}.pt")
log_string += " || Best model updated"
# Save images generated at this epoch
if save_images:
generated_images = self.generate_images_batch()
image_grid = make_grid(generated_images)
# convert to numpy
ndarr = image_grid.permute(1, 2, 0).to("cpu").numpy()
# and save
im = Image.fromarray(ndarr)
im.save(f"{self.model_name}_{epoch}.jpg")
print(log_string)
!pip install astronn -q
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.3/9.3 MB 15.5 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 49.2 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 887.9/887.9 kB 43.2 MB/s eta 0:00:00
import ssl
ssl._create_default_https_context = ssl._create_unverified_context # ignore url error
Рассмотрим пример использования опеределенной выше диффузной модели генерации на примере датасета Galaxy10
from astroNN.datasets import load_galaxy10
images, labels = load_galaxy10()
Galaxy10_DECals.h5: 100%|█████████▉| 2.73G/2.74G [00:28<00:00, 114MB/s]
Downloaded Galaxy10 successfully to /root/.astroNN/datasets/Galaxy10_DECals.h5
Galaxy10_DECals.h5: 2.74GB [00:36, 75.9MB/s]
from torch.utils.data.dataset import Dataset
class GalaxyDataset(Dataset):
def __init__(self, images, labels, transform):
super().__init__()
self.images = images
self.labels = labels
self.transform = transform
def __getitem__(self, indx):
image = self.images[indx]
label = self.labels[indx]
if self.transform:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.images)
from torchvision import transforms
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize(64, antialias=True),
transforms.RandomHorizontalFlip(0.5),
transforms.RandomVerticalFlip(0.5),
]
)
dataset = GalaxyDataset(images, labels, transform)
ddpm = DiffusionGenerativeModel(dataset, device="cuda", batch_size=32)
Обучение сети осуществляется следующим образом. Пропустим этот этап, так как процесс занимает продолжительное время.
# ddpm.fit(n_epochs=1)
Посмотрим на изображения, сгенерированные необученной моделью:
generated_images = ddpm.generate_images_batch()
image_grid = make_grid(generated_images)
# convert to numpy
ndarr = image_grid.permute(1, 2, 0).to("cpu").numpy()
# and show
Image.fromarray(ndarr)
0it [00:00, ?it/s]
Видим, что сгенерированные изображения представляют собой шум.
Загрузим веса предобученной модели:
!wget -q https://edunet.kea.su/repo/EduNet-content/L14/weights/unconditioned_baseline.pt
ddpm._load_pretrained_model("unconditioned_baseline.pt")
generated_images = ddpm.generate_images_batch()
image_grid = make_grid(generated_images)
# convert to numpy
ndarr = image_grid.permute(1, 2, 0).to("cpu").numpy()
# and show
Image.fromarray(ndarr)
0it [00:00, ?it/s]
Использованная литература
GAN
Generative Adversarial Networks (Goodfellow et al., 2014)
Видео разбор оригинальной статьи GAN
Generative adversarial networks
Самые современные генеративные модели
exactly how the NVIDIA GauGAN neural network works
DCGAN
Wassserstein GAN
Wasserstein GAN (Arjovsky et. al., 2017)
Improved Training of Wasserstein GANs (Gulrajani et al., 2017)
Spectral Normalization for Generative Adversarial Networks (Miyato et al., 2018).
ProGAN -> StyleGAN -> StyleGAN2 -> Alias-Free GAN
Analyzing and Improving the Image Quality of StyleGAN (StyleGAN2) [Karras et al., 2019]
Alias-Free Generative Adversarial Networks (Alias-Free GAN) [Karras et al., 2021]
Тонкости обучения GAN
Статья - детальный разбор тонкостей и советов
Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples (Sinha et al., 2020)
GAN Zoo:
BigGAN
Large Scale GAN Training for High Fidelity Natural Image Synthesis (Brock et al., 2018)
StackGAN
Дополнительно: репозиторий к статье с воспроизведением результатов.
Deep Learning Generative Models for Image Synthesis and Image Translation
youtube [StackGAN++] Realistic Image Synthesis with Stacked Generative Adversarial Networks | AISC
youtube Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks
ControlGAN
Controllable Generative Adversarial Network
Controllable Text-to-Image Generation
Image Generation and Recognition (Emotions)
Natural Language & Text-to-Image 2019
AC-GAN
How to Develop an Auxiliary Classifier GAN (AC-GAN) From Scratch with Keras
Conditional Image Synthesis with Auxiliary Classifier GANs: статья, репозиторий
An Auxiliary Classifier Generative Adversarial Framework for Relation Extraction
A Multi-Class Hinge Loss for Conditional GANs
Domain Transfer Network
Unsupervised Cross-Domain Image Generation (Taigma et al., 2016)
Pix2Pix
Image-to-Image Translation with Conditional Adversarial Networks (Isola et al., 2016)
Семантическая генерация
Learning to Generate Chairs, Tables and Cars with Convolutional Networks (Dosovitskiy et al., 2017)
Text to image
Text-to-Image Generation with Attention Based Recurrent Neural Networks (Zia et al., 2020)
Image-to-Image
GANs N' Roses: Stable, Controllable, Diverse Image to Image Translation (Chong et al., 2021)
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (Zhu et al., 2017)