Skip to main content

如何从零开始训练条件扩散模型

通过训练一个这样的模型,大家能学到什么?
Created on December 1|Last edited on December 7
本报告是作者Thomas Capelle所写的"How to Train a Conditional Diffusion Model from Scratch"的翻译


从DALL-E到稳定扩散(Stable Diffusion),图像生成可能是目前深度学习中最令人兴奋的事情。但除了生成有趣的图片发布在Twitter上,大家还能想到哪些用例呢?
最近,我在YouTube发现了一段关于在PyTorch编程条件扩散模型(conditional diffusion model)的精彩视频。建议在继续之前,请大家先观看该视频,因为我们将深入探讨这段代码,并训练一些我们自己的模型!


从本视频中,大家可以训练一个模型,以从受监控的数据集生成任何数据。
若有一个标记的数据集,就可以生成合成数据 !这对于复杂/昂贵的标签数据来说,是非常有用的。
💡

初步检查:CIFAR-10

原始代码使用CIFAR-10来训练条件扩散模型,所以我们先这样操作!开始之前,我们对代码库做了如下改进:
  • 添加验证指标(在测试数据上计算)
  • 启用混合精度训练和多线程数据加载器
  • OneCycleScheduler
  • W&B,即权重和偏差)记录
请点击链接查看该实验的代码:https://github.com/tcapelle/Diffusion-Models-pytorch 
💡
这种遥测技术让我们能够跟进模型的训练情况,但对于以下模型来说是不够的:

Run set
3

正如所见,损失(`train_mse`)不是很平稳,所以可以认为模型没有学到任何东西。但是,如果绘制取样图像(每10个历时运行一次扩散推理,并将图像记录到W&B),我们可以看到模型是如何不断改进的。移动下方的滑块,可以看到模型如何随着时间的推移而改进。我们对常规模型和EMA复制模型都进行了采样。
滑块数字是模型看到的批量(batches),但为每10个历时(epochs)记录一次。


注:在视频中,作者声称EMA模型产生更好的结果,但我不确定。
💡

技术细节

有很多小细节使这项训练成为可能。首先,我们用最新的PyTorch完成了GCP机器的所有训练。在进入A100(40GB)之前,我们使用V100 16GB进行CIFAR初始训练。还有以下几点:
  • 图像去噪的UNet模型很大;模型有很多自我关注层,计算量很大。因此,即使图像的分辨率很小,自我关注也会在每个像素之间计算,所以对图像大小是二次计算。
  • 为了获得更清晰的结果,我们使用了具有64x64图像(而不是32x32)的CIFAR-10升级版。可在Kaggle找到此数据集
  • 只能将V100上的批量(batch)大小从4增加到10,而不会模糊精度。
  • 还抑制了网络上的一个较深的瓶颈(512 x 512)卷积层(convolutional layers),以使其速度更快。

代码

本文的完整代码可于此处找到。在本文中,我们只讨论代码库的相关部分,若有任何问题,请与我联系(或在下方评论中提出)。

模型

默认的非条件性扩散模型,是由一个带有自我关注层的UNet组成的。我们有经典的U型结构,具有下采样(downsampling)和上采样(upsampling)的路径。与传统UNet的主要区别在于,上行和下行块在其前向传递中,支持额外的时间步长参数(timestep argument)。这是通过将时间步长线性(timestep linearly)嵌入卷积(convolutions)来实现的,有关详情,请查看modules.py文件
class UNet(nn.Module):
def __init__(self, c_in=3, c_out=3, time_dim=256):
super().__init__()
self.time_dim = time_dim
self.inc = DoubleConv(c_in, 64)
self.down1 = Down(64, 128)
self.sa1 = SelfAttention(128)
self.down2 = Down(128, 256)
self.sa2 = SelfAttention(256)
self.down3 = Down(256, 256)
self.sa3 = SelfAttention(256)

self.bot1 = DoubleConv(256, 256)
self.bot2 = DoubleConv(256, 256)

self.up1 = Up(512, 128)
self.sa4 = SelfAttention(128)
self.up2 = Up(256, 64)
self.sa5 = SelfAttention(64)
self.up3 = Up(128, 64)
self.sa6 = SelfAttention(64)
self.outc = nn.Conv2d(64, c_out, kernel_size=1)
def unet_forwad(self, x, t):
"Classic UNet structure with down and up branches, self attention in between convs"
x1 = self.inc(x)
x2 = self.down1(x1, t)
x2 = self.sa1(x2)
x3 = self.down2(x2, t)
x3 = self.sa2(x3)
x4 = self.down3(x3, t)
x4 = self.sa3(x4)

x4 = self.bot1(x4)
x4 = self.bot2(x4)

x = self.up1(x4, x3, t)
x = self.sa4(x)
x = self.up2(x, x2, t)
x = self.sa5(x)
x = self.up3(x, x1, t)
x = self.sa6(x)
output = self.outc(x)
return output
def forward(self, x, t):
"Positional encoding of the timestep before the blocks"
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)
return self.unet_forwad(x, t)
条件模型几乎相同,但通过将标签传递到嵌入层(Embedding layer),将类标签的编码添加到时间步长中。这是一个非常简单而巧妙的解决方案。
class UNet_conditional(UNet):
def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):
super().__init__(c_in, c_out, time_dim)
if num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_dim)

def forward(self, x, t, y=None):
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)

if y is not None:
t += self.label_emb(y)

return self.unet_forwad(x, t)

EMA代码

指数移动平均法(Exponential Moving Average)是一种让训练结果更好、更稳定的技术,其工作原理是保留上一次迭代的模型权重的副本,并以(1-beta)的系数更新当前迭代的权重。
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
self.step = 0

def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)

def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new

def step_ema(self, ema_model, model, step_start_ema=2000):
if self.step < step_start_ema:
self.reset_parameters(ema_model, model)
self.step += 1
return
self.update_model_average(ema_model, model)
self.step += 1

def reset_parameters(self, ema_model, model):
ema_model.load_state_dict(model.state_dict())

训练

我们已经对代码进行了重构,使其能够正常运行。训练步骤发生在one_epoch函数上:
def train_step(self):
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
self.ema.step_ema(self.ema_model, self.model)
self.scheduler.step()

def one_epoch(self, train=True, use_wandb=False):
avg_loss = 0.
if train: self.model.train()
else: self.model.eval()
pbar = progress_bar(self.train_dataloader, leave=False)
for i, (images, labels) in enumerate(pbar):
with torch.autocast("cuda") and (torch.inference_mode() if not train else torch.enable_grad()):
images = images.to(self.device)
labels = labels.to(self.device)
t = self.sample_timesteps(images.shape[0]).to(self.device)
x_t, noise = self.noise_images(images, t)
if np.random.random() < 0.1:
labels = None
predicted_noise = self.model(x_t, t, labels)
loss = self.mse(noise, predicted_noise)
avg_loss += loss
if train:
self.train_step()
if use_wandb:
wandb.log({"train_mse": loss.item(),
"learning_rate": self.scheduler.get_last_lr()[0]})
pbar.comment = f"MSE={loss.item():2.3f}"
return avg_loss.mean().item()
下方可以看到W&B工具的第一部分,我们记录训练损失(training loss)和学习率值(learning rate value),这样就可以跟踪所使用的调度器。为了实际记录样本,我们定义一个自定义函数来进行模型推理:
@torch.inference_mode()
def log_images(self):
"Log images to wandb and save them to disk"
labels = torch.arange(self.num_classes).long().to(self.device)
sampled_images = self.sample(use_ema=False, n=len(labels), labels=labels)
ema_sampled_images = self.sample(use_ema=True, n=len(labels), labels=labels)
plot_images(sampled_images) #to display on jupyter if available
# log images to wandb
wandb.log({"sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in sampled_images]})
wandb.log({"ema_sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in ema_sampled_images]})
还有一个保存模型检查点的函数:
def save_model(self, run_name, epoch=-1):
"Save model locally and to wandb"
torch.save(self.model.state_dict(), os.path.join("models", run_name, f"ckpt.pt"))
torch.save(self.ema_model.state_dict(), os.path.join("models", run_name, f"ema_ckpt.pt"))
torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, f"optim.pt"))
at = wandb.Artifact("model", type="model", description="Model weights for DDPM conditional", metadata={"epoch": epoch})
at.add_dir(os.path.join("models", run_name))
wandb.log_artifact(at)
一切都很好地融入了fit函数。
def prepare(self, args):
"Prepare the model for training"
setup_logging(args.run_name)
device = args.device
self.train_dataloader, self.val_dataloader = get_data(args)
self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=0.001)
self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=args.lr,
steps_per_epoch=len(self.train_dataloader), epochs=args.epochs)
self.mse = nn.MSELoss()
self.ema = EMA(0.995)
self.scaler = torch.cuda.amp.GradScaler()

def fit(self, args):
self.prepare(args)
for epoch in range(args.epochs):
logging.info(f"Starting epoch {epoch}:")
self.one_epoch(train=True)
## validation
if args.do_validation:
self.one_epoch(train=False)
# log predicitons
if epoch % args.log_every_epoch == 0:
self.log_images(use_wandb=args.use_wandb)

# save model
self.save_model(run_name=args.run_name, use_wandb=args.use_wandb, epoch=epoch)

图像采样

为了对图像进行采样,需要从随机噪声开始,迭代去噪以获得最终的图像。这个过程在“稳定扩散图解”(The Illustrated Stable Diffusion)中有很好的说明。我们的情况要简单得多,但采样是一样的。对于那些有兴趣了解稳定扩散架构的字节和比特的人来说,这是一本好书。
在该案例中,无需解码器,因为图像在UNet的输出后已经是全分辨率。
一次一个步骤去噪。我们正是这样做的,但没有图像解码器(Image Decoder)部分。
采样代码按照噪声调度器从图像中逐渐去除噪声。我们从随机的纯噪声开始,以采样图像结束。这段代码有点混乱,因为参数是以DDPM论文上的方程式命名的。
我们使用的是该论文中的第二种算法:
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

训练模型来生成字体

我们尝试一下另一个标记的数据集。对于这个简单的示例,需要具有小图像的数据集。在Kaggle上搜索后,Alphabet字符字体数据集(Alphabet character fonts dataset)似乎是一个不错的选择。我们看看这个数据集是什么样子的。
数据集的每一行都是一个字体系列,从A到Z的字母呈现为32x32像素BW图像。我们可以快速将其记录到W&B并探索:

Run set
10

这里的想法是要训练一个扩散模型来生成新的字体!我们可以应用与CIFAR-10相同的想法,并将条件生成到我们想要产生的实际字母上。如前所述,看看训练过程,在训练过程中定期对输出进行采样:


这里模型的缺点是,无法生成字体系列(或样式),因为我们通过单独获取每个字母来训练模型,所以只能逐个生成字母。因此,即使通过标签A-Z(或W&B🤣) ,每次都会得到独立的随机字母。


请于该记事本查看我是如何创建这个wandb.Tables的。

使用fastai来训练扩散模型

我们看到从头开始训练一个条件扩散模型所需的代码很少,但如果使用fastai这样的库,就可以进一步减少代码!我建议大家关注“编码员用深度学习第二部分”(Deep Learning for Coders Part 2),和Jeremy一起学习更多关于生成模型的知识。
同时:我之前展示的CIFAR代码,只是fastai语言中的一个回调。查看fastdiffusion存储库,那里有一群fastai开发人员正在添加其扩散模型的实现。
要训练CIFAR-10,只需把这个回调传给一个Image 2 Image fastai管道:
class ConditionalDDPMCallback(Callback):
def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):
store_attr()

def before_fit(self):
self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increased with timestep
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.sigma = torch.sqrt(self.beta)

def before_batch_training(self):
x0 = self.xb[0] # original images, x_0
eps = self.tensor_type(torch.randn(x0.shape, device=x0.device)) # noise, x_T
batch_size = x0.shape[0]
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps
alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps #noisify the image
self.learn.xb = (xt, t, self.yb[0]) # input to our model is noisy image and timestep
self.learn.yb = (eps,) # ground truth is the noise


def before_batch_sampling(self):
xt = self.tensor_type(self.xb[0]) # a full batch at once!
batch_size = xt.shape[0]
label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(batch_size//10 + 1).flatten()[0:batch_size]
for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
alpha_t = self.alpha[t] # get noise level at current timestep
alpha_bar_t = self.alpha_bar[t]
sigma_t = self.sigma[t]
xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch, label=label)) + sigma_t*z # predict x_(t-1) in accordance to Algorithm 2 in paper
self.learn.pred = (xt,)
raise CancelBatchException

def before_batch(self):
if not hasattr(self, 'gather_preds'): self.before_batch_training()
else: self.before_batch_sampling()
然后将此回调传给相应的Learner
model = ConditionalUnet(dim=32, channels=1, num_classes=10).cuda()
ddpm_learner = Learner(dls, model,
cbs=[ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, tensor_type=TensorImageBW)],
loss_func=nn.MSELoss()).to_fp16()
由于fastai支持W&B开箱即用,因此只需传递WandbCallback即可:
with wandb.init(project="sd_from_scratch"):
ddpm_learner.fit_one_cycle(10, 1e-4, cbs=WandbCallback(log_preds=False))
如需更全面的例子,请查看此记事本

结论

我这样操作主要是为了帮助自己从头理解扩散模型。我从来都不太喜欢GAN,因为GAN很复杂,需要很多工程技巧来训练,所以当使用这种老的、可靠的UNet来生成图像的新技术征服了开源世界时,我觉得必须要尝试一下!
这一新的强大工具,为生成标记数据的新方法打开了大门。该工具可能成为一种预训练或扩充数据集的新方法,以获得更强大的预训练模型。关于如何检验这一假设,我有几个想法,请继续关注。
Iterate on AI agents and models faster. Try Weights & Biases today.