如何从零开始训练条件扩散模型
通过训练一个这样的模型,大家能学到什么?
Created on December 1|Last edited on December 7
Comment
最近,我在YouTube发现了一段关于在PyTorch编程条件扩散模型(conditional diffusion model)的精彩视频。建议在继续之前,请大家先观看该视频,因为我们将深入探讨这段代码,并训练一些我们自己的模型!
从本视频中,大家可以训练一个模型,以从受监控的数据集生成任何数据。
若有一个标记的数据集,就可以生成合成数据 !这对于复杂/昂贵的标签数据来说,是非常有用的。
💡
初步检查:CIFAR-10
- 添加验证指标(在测试数据上计算)
- 启用混合精度训练和多线程数据加载器
- OneCycleScheduler
- W&B,即权重和偏差)记录
💡
这种遥测技术让我们能够跟进模型的训练情况,但对于以下模型来说是不够的:
Run set
3
正如所见,损失(`train_mse`)不是很平稳,所以可以认为模型没有学到任何东西。但是,如果绘制取样图像(每10个历时运行一次扩散推理,并将图像记录到W&B),我们可以看到模型是如何不断改进的。移动下方的滑块,可以看到模型如何随着时间的推移而改进。我们对常规模型和EMA复制模型都进行了采样。
滑块数字是模型看到的批量(batches),但为每10个历时(epochs)记录一次。
注:在视频中,作者声称EMA模型产生更好的结果,但我不确定。
💡
技术细节
有很多小细节使这项训练成为可能。首先,我们用最新的PyTorch完成了GCP机器的所有训练。在进入A100(40GB)之前,我们使用V100 16GB进行CIFAR初始训练。还有以下几点:
- 图像去噪的UNet模型很大;模型有很多自我关注层,计算量很大。因此,即使图像的分辨率很小,自我关注也会在每个像素之间计算,所以对图像大小是二次计算。
- 只能将V100上的批量(batch)大小从4增加到10,而不会模糊精度。
- 还抑制了网络上的一个较深的瓶颈(512 x 512)卷积层(convolutional layers),以使其速度更快。
代码
模型
默认的非条件性扩散模型,是由一个带有自我关注层的UNet组成的。我们有经典的U型结构,具有下采样(downsampling)和上采样(upsampling)的路径。与传统UNet的主要区别在于,上行和下行块在其前向传递中,支持额外的时间步长参数(timestep argument)。这是通过将时间步长线性(timestep linearly)嵌入卷积(convolutions)来实现的,有关详情,请查看modules.py文件。
class UNet(nn.Module):def __init__(self, c_in=3, c_out=3, time_dim=256):super().__init__()self.time_dim = time_dimself.inc = DoubleConv(c_in, 64)self.down1 = Down(64, 128)self.sa1 = SelfAttention(128)self.down2 = Down(128, 256)self.sa2 = SelfAttention(256)self.down3 = Down(256, 256)self.sa3 = SelfAttention(256)self.bot1 = DoubleConv(256, 256)self.bot2 = DoubleConv(256, 256)self.up1 = Up(512, 128)self.sa4 = SelfAttention(128)self.up2 = Up(256, 64)self.sa5 = SelfAttention(64)self.up3 = Up(128, 64)self.sa6 = SelfAttention(64)self.outc = nn.Conv2d(64, c_out, kernel_size=1)def unet_forwad(self, x, t):"Classic UNet structure with down and up branches, self attention in between convs"x1 = self.inc(x)x2 = self.down1(x1, t)x2 = self.sa1(x2)x3 = self.down2(x2, t)x3 = self.sa2(x3)x4 = self.down3(x3, t)x4 = self.sa3(x4)x4 = self.bot1(x4)x4 = self.bot2(x4)x = self.up1(x4, x3, t)x = self.sa4(x)x = self.up2(x, x2, t)x = self.sa5(x)x = self.up3(x, x1, t)x = self.sa6(x)output = self.outc(x)return outputdef forward(self, x, t):"Positional encoding of the timestep before the blocks"t = t.unsqueeze(-1)t = self.pos_encoding(t, self.time_dim)return self.unet_forwad(x, t)
条件模型几乎相同,但通过将标签传递到嵌入层(Embedding layer),将类标签的编码添加到时间步长中。这是一个非常简单而巧妙的解决方案。
class UNet_conditional(UNet):def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):super().__init__(c_in, c_out, time_dim)if num_classes is not None:self.label_emb = nn.Embedding(num_classes, time_dim)def forward(self, x, t, y=None):t = t.unsqueeze(-1)t = self.pos_encoding(t, self.time_dim)if y is not None:t += self.label_emb(y)return self.unet_forwad(x, t)
EMA代码
指数移动平均法(Exponential Moving Average)是一种让训练结果更好、更稳定的技术,其工作原理是保留上一次迭代的模型权重的副本,并以(1-beta)的系数更新当前迭代的权重。
class EMA:def __init__(self, beta):super().__init__()self.beta = betaself.step = 0def update_model_average(self, ma_model, current_model):for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):old_weight, up_weight = ma_params.data, current_params.datama_params.data = self.update_average(old_weight, up_weight)def update_average(self, old, new):if old is None:return newreturn old * self.beta + (1 - self.beta) * newdef step_ema(self, ema_model, model, step_start_ema=2000):if self.step < step_start_ema:self.reset_parameters(ema_model, model)self.step += 1returnself.update_model_average(ema_model, model)self.step += 1def reset_parameters(self, ema_model, model):ema_model.load_state_dict(model.state_dict())
训练
我们已经对代码进行了重构,使其能够正常运行。训练步骤发生在one_epoch函数上:
def train_step(self):self.optimizer.zero_grad()self.scaler.scale(loss).backward()self.scaler.step(self.optimizer)self.scaler.update()self.ema.step_ema(self.ema_model, self.model)self.scheduler.step()def one_epoch(self, train=True, use_wandb=False):avg_loss = 0.if train: self.model.train()else: self.model.eval()pbar = progress_bar(self.train_dataloader, leave=False)for i, (images, labels) in enumerate(pbar):with torch.autocast("cuda") and (torch.inference_mode() if not train else torch.enable_grad()):images = images.to(self.device)labels = labels.to(self.device)t = self.sample_timesteps(images.shape[0]).to(self.device)x_t, noise = self.noise_images(images, t)if np.random.random() < 0.1:labels = Nonepredicted_noise = self.model(x_t, t, labels)loss = self.mse(noise, predicted_noise)avg_loss += lossif train:self.train_step()if use_wandb:wandb.log({"train_mse": loss.item(),"learning_rate": self.scheduler.get_last_lr()[0]})pbar.comment = f"MSE={loss.item():2.3f}"return avg_loss.mean().item()
下方可以看到W&B工具的第一部分,我们记录训练损失(training loss)和学习率值(learning rate value),这样就可以跟踪所使用的调度器。为了实际记录样本,我们定义一个自定义函数来进行模型推理:
@torch.inference_mode()def log_images(self):"Log images to wandb and save them to disk"labels = torch.arange(self.num_classes).long().to(self.device)sampled_images = self.sample(use_ema=False, n=len(labels), labels=labels)ema_sampled_images = self.sample(use_ema=True, n=len(labels), labels=labels)plot_images(sampled_images) #to display on jupyter if available# log images to wandbwandb.log({"sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in sampled_images]})wandb.log({"ema_sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in ema_sampled_images]})
还有一个保存模型检查点的函数:
def save_model(self, run_name, epoch=-1):"Save model locally and to wandb"torch.save(self.model.state_dict(), os.path.join("models", run_name, f"ckpt.pt"))torch.save(self.ema_model.state_dict(), os.path.join("models", run_name, f"ema_ckpt.pt"))torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, f"optim.pt"))at = wandb.Artifact("model", type="model", description="Model weights for DDPM conditional", metadata={"epoch": epoch})at.add_dir(os.path.join("models", run_name))wandb.log_artifact(at)
一切都很好地融入了fit函数。
def prepare(self, args):"Prepare the model for training"setup_logging(args.run_name)device = args.deviceself.train_dataloader, self.val_dataloader = get_data(args)self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=0.001)self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=args.lr,steps_per_epoch=len(self.train_dataloader), epochs=args.epochs)self.mse = nn.MSELoss()self.ema = EMA(0.995)self.scaler = torch.cuda.amp.GradScaler()def fit(self, args):self.prepare(args)for epoch in range(args.epochs):logging.info(f"Starting epoch {epoch}:")self.one_epoch(train=True)## validationif args.do_validation:self.one_epoch(train=False)# log predicitonsif epoch % args.log_every_epoch == 0:self.log_images(use_wandb=args.use_wandb)# save modelself.save_model(run_name=args.run_name, use_wandb=args.use_wandb, epoch=epoch)
图像采样
为了对图像进行采样,需要从随机噪声开始,迭代去噪以获得最终的图像。这个过程在“稳定扩散图解”(The Illustrated Stable Diffusion)中有很好的说明。我们的情况要简单得多,但采样是一样的。对于那些有兴趣了解稳定扩散架构的字节和比特的人来说,这是一本好书。
在该案例中,无需解码器,因为图像在UNet的输出后已经是全分辨率。

一次一个步骤去噪。我们正是这样做的,但没有图像解码器(Image Decoder)部分。
我们使用的是该论文中的第二种算法:
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
训练模型来生成字体
我们尝试一下另一个标记的数据集。对于这个简单的示例,需要具有小图像的数据集。在Kaggle上搜索后,Alphabet字符字体数据集(Alphabet character fonts dataset)似乎是一个不错的选择。我们看看这个数据集是什么样子的。
数据集的每一行都是一个字体系列,从A到Z的字母呈现为32x32像素BW图像。我们可以快速将其记录到W&B并探索:
Run set
10
这里的想法是要训练一个扩散模型来生成新的字体!我们可以应用与CIFAR-10相同的想法,并将条件生成到我们想要产生的实际字母上。如前所述,看看训练过程,在训练过程中定期对输出进行采样:
这里模型的缺点是,无法生成字体系列(或样式),因为我们通过单独获取每个字母来训练模型,所以只能逐个生成字母。因此,即使通过标签A-Z(或W&B🤣) ,每次都会得到独立的随机字母。
使用fastai来训练扩散模型
我们看到从头开始训练一个条件扩散模型所需的代码很少,但如果使用fastai这样的库,就可以进一步减少代码!我建议大家关注“编码员用深度学习第二部分”(Deep Learning for Coders Part 2),和Jeremy一起学习更多关于生成模型的知识。
要训练CIFAR-10,只需把这个回调传给一个Image 2 Image fastai管道:
class ConditionalDDPMCallback(Callback):def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):store_attr()def before_fit(self):self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increased with timestepself.alpha = 1. - self.betaself.alpha_bar = torch.cumprod(self.alpha, dim=0)self.sigma = torch.sqrt(self.beta)def before_batch_training(self):x0 = self.xb[0] # original images, x_0eps = self.tensor_type(torch.randn(x0.shape, device=x0.device)) # noise, x_Tbatch_size = x0.shape[0]t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timestepsalpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps #noisify the imageself.learn.xb = (xt, t, self.yb[0]) # input to our model is noisy image and timestepself.learn.yb = (eps,) # ground truth is the noisedef before_batch_sampling(self):xt = self.tensor_type(self.xb[0]) # a full batch at once!batch_size = xt.shape[0]label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(batch_size//10 + 1).flatten()[0:batch_size]for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)alpha_t = self.alpha[t] # get noise level at current timestepalpha_bar_t = self.alpha_bar[t]sigma_t = self.sigma[t]xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch, label=label)) + sigma_t*z # predict x_(t-1) in accordance to Algorithm 2 in paperself.learn.pred = (xt,)raise CancelBatchExceptiondef before_batch(self):if not hasattr(self, 'gather_preds'): self.before_batch_training()else: self.before_batch_sampling()
然后将此回调传给相应的Learner:
model = ConditionalUnet(dim=32, channels=1, num_classes=10).cuda()ddpm_learner = Learner(dls, model,cbs=[ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, tensor_type=TensorImageBW)],loss_func=nn.MSELoss()).to_fp16()
由于fastai支持W&B开箱即用,因此只需传递WandbCallback即可:
with wandb.init(project="sd_from_scratch"):ddpm_learner.fit_one_cycle(10, 1e-4, cbs=WandbCallback(log_preds=False))
结论
我这样操作主要是为了帮助自己从头理解扩散模型。我从来都不太喜欢GAN,因为GAN很复杂,需要很多工程技巧来训练,所以当使用这种老的、可靠的UNet来生成图像的新技术征服了开源世界时,我觉得必须要尝试一下!
这一新的强大工具,为生成标记数据的新方法打开了大门。该工具可能成为一种预训练或扩充数据集的新方法,以获得更强大的预训练模型。关于如何检验这一假设,我有几个想法,请继续关注。
Add a comment
Iterate on AI agents and models faster. Try Weights & Biases today.