Skip to main content

Conditional Diffusion Model(条件付き拡散モデル)を最初からトレーニングする方法

これらのモデルの1つをトレーニングすることで学べることとは?
Created on February 14|Last edited on February 14
このレポートは、Thomas Capelleによる「How To Train a Conditional Diffusion Model From Scratch」の翻訳です。


DALL-EからStable Diffusionまで、画像生成はおそらく現在ディープラーニングで最もエキサイティングなトピックであると思います。しかし、Twitterに投稿する面白い画像を生成する以外に、どのようなユースケースがあるでしょうか?
最近、PyTorchで条件付き拡散モデルのプログラミングに関するこちらの素晴らしい動画をYouTubeで見つけました。このコードを詳しく見ながら独自のモデルをトレーニングする前に、この動画をご覧ください!


この動画の要点は、管理データセットから任意のデータを生成するようにモデルをトレーニングできるということです。
ラベル付きデータセットをお持ちの場合は、合成データを生成できます!これは、複雑で高価なラベル付けデータに非常に役に立つ可能性があります。
💡

最初のチェック:CIFAR-10

元のコードでは、CIFAR-10を使用して条件付き拡散モデルをトレーニングしています。最初にこれを行いましょう!始める前に、私たちがコードベースに加えたいくつかの改善点は次のとおりです:
  • 検証メトリックを追加 (テストデータで計算)
  • 混合精度トレーニングとマルチスレッドデータローダーを実現
  • OneCycleScheduler
  • Weights and Biases (W&B) ロギング
この実験コードは、こちらでご覧いただけます : https://github.com/tcapelle/Diffusion-Models-pytorch
💡
このテレメトリにより、モデルのトレーニング方法を追跡できましたが、これらのモデルでは十分ではありません:


ご覧のとおり、「train_mse」はあまり滑らかではないため、モデルは何も学習していないと考えることができます。しかし、サンプリングされた画像をプロットすると(10エポックごとに拡散推論を実行し、画像をW&Bに記録しました)、モデルがどのように改善し続けるかを確認できます。下のスライダーを動かすと、モデルが時間の経過とともにどのように改善されるかを確認できます。私たちは、通常モデルとEMAコピーモデルの両方をサンプリングしました。
スライダー番号はモデルで表示されるバッチですが、10エポックごとにログに記録されています。


注意:著者は動画でEMAモデルがより優れた出力を生成すると主張していますが、確かではありません。
💡

技術的な詳細

このトレーニングを可能にする細かな点がたくさんあります。まず初めに、私たちはGCPマシンですべてのトレーニングを行い、最新のPyTorchを使用しました。V100(40GB)に移行する前に、CIFARの初期トレーニングにA100 16GBを使用しました。なお、追加ポイントは次のとおりです:
  • 画像のノイズを除去するUNetモデルは大きい。自己注意層(self-attention layers)が多く、コンピュータによる処理は非常に重い。 したがって、画像の解像度が小さい場合でも、自己注意はすべてのピクセル間で計算されるため、画像サイズで正方形になる。
  • より鮮明な結果を得るために、(32x32ではなく)64x64の画像を持つCIFAR-10のアップスケールバージョンを使用した。 Kaggleでこのデータセットをチェック可。
  • 混合精度が無い状態でV100のバッチサイズを10から4にすることしかできなかった。
  • また、ネットワーク上のより深いボトルネック(512x512)畳み込み層の1つを抑制して、少し高速にした。

コード

この記事の完全なコードはこちらでご覧いただけます。この記事では、コードベースの関連部分に限定します。ご不明な点が場合はお問い合わせください。(または以下のコメント欄にコメントをお寄せください!)

モデル

デフォルトの無条件拡散モデルは、自己注意層を持つUNetで構成されています。ダウンサンプリングパスとアップサンプリングパスを備えた従来のU構造が見られます。 従来のUNetとの主な違いは、アップブロックとダウンブロックがフォワードパスで追加のtimestep引数をサポートしていることです。これは、時間ステップを畳み込みに線形に埋め込むことによって行われています。詳細については、modules.py fileをご確認ください。
class UNet(nn.Module):
def __init__(self, c_in=3, c_out=3, time_dim=256):
super().__init__()
self.time_dim = time_dim
self.inc = DoubleConv(c_in, 64)
self.down1 = Down(64, 128)
self.sa1 = SelfAttention(128)
self.down2 = Down(128, 256)
self.sa2 = SelfAttention(256)
self.down3 = Down(256, 256)
self.sa3 = SelfAttention(256)

self.bot1 = DoubleConv(256, 256)
self.bot2 = DoubleConv(256, 256)

self.up1 = Up(512, 128)
self.sa4 = SelfAttention(128)
self.up2 = Up(256, 64)
self.sa5 = SelfAttention(64)
self.up3 = Up(128, 64)
self.sa6 = SelfAttention(64)
self.outc = nn.Conv2d(64, c_out, kernel_size=1)
def unet_forwad(self, x, t):
"Classic UNet structure with down and up branches, self attention in between convs"
x1 = self.inc(x)
x2 = self.down1(x1, t)
x2 = self.sa1(x2)
x3 = self.down2(x2, t)
x3 = self.sa2(x3)
x4 = self.down3(x3, t)
x4 = self.sa3(x4)

x4 = self.bot1(x4)
x4 = self.bot2(x4)

x = self.up1(x4, x3, t)
x = self.sa4(x)
x = self.up2(x, x2, t)
x = self.sa5(x)
x = self.up3(x, x1, t)
x = self.sa6(x)
output = self.outc(x)
return output
def forward(self, x, t):
"Positional encoding of the timestep before the blocks"
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)
return self.unet_forwad(x, t)
条件付きモデルはほぼ同じですが、ラベルを埋め込みレイヤーに渡すことで、クラスラベルのエンコーディングをタイムステップに追加します。これは非常にシンプルでエレガントなソリューションです。
class UNet_conditional(UNet):
def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):
super().__init__(c_in, c_out, time_dim)
if num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_dim)

def forward(self, x, t, y=None):
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)

if y is not None:
t += self.label_emb(y)

return self.unet_forwad(x, t)

EMA コード

EMA(指数平滑移動平均線)は、結果を改善し、より安定したトレーニングにするために使用されるテクニックです。 これは、前の反復モデルウェイトのコピーを保持し、現在の反復ウェイトで係数(1-beta)だけを更新することで機能します。
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
self.step = 0

def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)

def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new

def step_ema(self, ema_model, model, step_start_ema=2000):
if self.step < step_start_ema:
self.reset_parameters(ema_model, model)
self.step += 1
return
self.update_model_average(ema_model, model)
self.step += 1

def reset_parameters(self, ema_model, model):
ema_model.load_state_dict(model.state_dict())

トレーニング

私たちは、洗練するためにコードをリファクタリングしました。トレーニングステップは、one_epoch関数で行われます:
def train_step(self):
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
self.ema.step_ema(self.ema_model, self.model)
self.scheduler.step()

def one_epoch(self, train=True, use_wandb=False):
avg_loss = 0.
if train: self.model.train()
else: self.model.eval()
pbar = progress_bar(self.train_dataloader, leave=False)
for i, (images, labels) in enumerate(pbar):
with torch.autocast("cuda") and (torch.inference_mode() if not train else torch.enable_grad()):
images = images.to(self.device)
labels = labels.to(self.device)
t = self.sample_timesteps(images.shape[0]).to(self.device)
x_t, noise = self.noise_images(images, t)
if np.random.random() < 0.1:
labels = None
predicted_noise = self.model(x_t, t, labels)
loss = self.mse(noise, predicted_noise)
avg_loss += loss
if train:
self.train_step()
if use_wandb:
wandb.log({"train_mse": loss.item(),
"learning_rate": self.scheduler.get_last_lr()[0]})
pbar.comment = f"MSE={loss.item():2.3f}"
return avg_loss.mean().item()
ここでは、W&Bインストルメンテーションの最初の部分で、トレーニング損失と学習率の値を記録していることがわかります。このようにして、使用しているスケジューラに従うことができます。サンプルを実際にログに記録するには、モデル推論を実行するカスタム関数を定義しましょう:
@torch.inference_mode()
def log_images(self):
"Log images to wandb and save them to disk"
labels = torch.arange(self.num_classes).long().to(self.device)
sampled_images = self.sample(use_ema=False, n=len(labels), labels=labels)
ema_sampled_images = self.sample(use_ema=True, n=len(labels), labels=labels)
plot_images(sampled_images) #to display on jupyter if available
# log images to wandb
wandb.log({"sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in sampled_images]})
wandb.log({"ema_sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in ema_sampled_images]})
また、モデルチェックポイントを保存する関数もあります:
def save_model(self, run_name, epoch=-1):
"Save model locally and to wandb"
torch.save(self.model.state_dict(), os.path.join("models", run_name, f"ckpt.pt"))
torch.save(self.ema_model.state_dict(), os.path.join("models", run_name, f"ema_ckpt.pt"))
torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, f"optim.pt"))
at = wandb.Artifact("model", type="model", description="Model weights for DDPM conditional", metadata={"epoch": epoch})
at.add_dir(os.path.join("models", run_name))
wandb.log_artifact(at)
全てがfit関数に上手くフィットします
def prepare(self, args):
"Prepare the model for training"
setup_logging(args.run_name)
device = args.device
self.train_dataloader, self.val_dataloader = get_data(args)
self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=0.001)
self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=args.lr,
steps_per_epoch=len(self.train_dataloader), epochs=args.epochs)
self.mse = nn.MSELoss()
self.ema = EMA(0.995)
self.scaler = torch.cuda.amp.GradScaler()

def fit(self, args):
self.prepare(args)
for epoch in range(args.epochs):
logging.info(f"Starting epoch {epoch}:")
self.one_epoch(train=True)
## validation
if args.do_validation:
self.one_epoch(train=False)
# log predicitons
if epoch % args.log_every_epoch == 0:
self.log_images(use_wandb=args.use_wandb)

# save model
self.save_model(run_name=args.run_name, use_wandb=args.use_wandb, epoch=epoch)

画像のサンプリング

画像をサンプリングするには、ランダムノイズから始めて、最終的な画像を取得するためにノイズ除去を繰り返す必要があります。 この手順は、「The Illustrated Stable Diffusion.」で詳しく説明されています。私たちのケースははるかに単純ですが、サンプリングは同じです。とにかく、Stable Diffusion構造を理解することに興味がある方には、是非読んでいただきたい記事です。
私たちケースの場合、UNetの出力後に画像はすでにフル解像度であるため、デコーダーは必要ありません。
ステップを1度に1つずつノイズ除去。私たちはまさに同じことをしていますが、Image Decoderの部分はありません。
サンプリングコードは、ノイズのスケジューラに従って画像からノイズを徐々に除去します。ランダムな純粋なノイズから始めて、サンプリングされた画像で終わります。パラメータはDDPM論文にちなんで命名されているため、コードはやや紛らわしくなっています。
私たちは論文の2番目のアルゴリズムを使用しています:
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

フォントを生成するためのモデルのトレーニング

別のラベル付きデータセットを試してみましょう。この単純な例では、小さな画像を持つデータセットが必要です。Kaggleで検索した結果、Alphabet Character Fonts Datasetが適切なデータセットのようでした。このデータセットがどのようなものか見てみましょう。
データセットの各行はフォントファミリーであり、A から Z までの文字が 32x32 ピクセルの BW 画像としてレンダリングされています。これをW&Bにすばやく記録して、細かく調べることができます:

Run set
10

ここでのアイデアは、拡散モデルをトレーニングして新しいフォントを生成することです。 CIFAR-10と同じ考え方と条件生成を、作成したい実際の文字に適用できます。以前と同じように、トレーニング中に定期的に出力をサンプリングして、トレーニングプロセスを見てみましょう:


ここでのモデルの欠点は、各文字を個別に取得してモデルをトレーニングしたため、フォントファミリー(またはスタイル)を生成する方法がないことです。したがって、文字を1つずつのみ生成できます。そのため、ラベルAZ(またはW&B🤣)を渡したとしても、毎回独立したラ��ダムな文字を取得することになります。


こちらのノートブックでこのwandb.Tablesの作成方法をご確認いただけます。

fastai を使用して拡散モデルをトレーニングする

条件付き拡散モデルをゼロからトレーニングするために必要なコードが少ないことがわかりましたが、ライブラリにfastaiを使用すると、さらにコードを削減することができます! 現在進行中の「Deep Learning for Coders Part 2(コーダーのためのディープラーニングパート2)」をフォローして、ジェレミーと一緒に生成モデルについてより詳しく学ぶことをお勧めします。
それまでの間、私が前に紹介したのと同じCIFARコードは、fastaiの言語の単なるコールバックです。多数のfastai開発者が拡散モデルの実装を追加しているfastdiffusionリポジトリをチェックしてみてください。
CIFAR-10 をトレーニングするには、このコールバックを Image 2 Image fastai パイプラインに渡すだけです:
class ConditionalDDPMCallback(Callback):
def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):
store_attr()

def before_fit(self):
self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increased with timestep
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.sigma = torch.sqrt(self.beta)

def before_batch_training(self):
x0 = self.xb[0] # original images, x_0
eps = self.tensor_type(torch.randn(x0.shape, device=x0.device)) # noise, x_T
batch_size = x0.shape[0]
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps
alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps #noisify the image
self.learn.xb = (xt, t, self.yb[0]) # input to our model is noisy image and timestep
self.learn.yb = (eps,) # ground truth is the noise


def before_batch_sampling(self):
xt = self.tensor_type(self.xb[0]) # a full batch at once!
batch_size = xt.shape[0]
label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(batch_size//10 + 1).flatten()[0:batch_size]
for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
alpha_t = self.alpha[t] # get noise level at current timestep
alpha_bar_t = self.alpha_bar[t]
sigma_t = self.sigma[t]
xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch, label=label)) + sigma_t*z # predict x_(t-1) in accordance to Algorithm 2 in paper
self.learn.pred = (xt,)
raise CancelBatchException

def before_batch(self):
if not hasattr(self, 'gather_preds'): self.before_batch_training()
else: self.before_batch_sampling()
それから、このコールバックを対応するLearnerに渡します:
model = ConditionalUnet(dim=32, channels=1, num_classes=10).cuda()
ddpm_learner = Learner(dls, model,
cbs=[ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, tensor_type=TensorImageBW)],
loss_func=nn.MSELoss()).to_fp16()
fastaiはW&Bをサポートしているので、WandbCallbackを渡すのと同じくらい簡単です。
with wandb.init(project="sd_from_scratch"):
ddpm_learner.fit_one_cycle(10, 1e-4, cbs=WandbCallback(log_preds=False))
より包括的な例については、こちらのノートブックをご確認ください。

まとめ

私は主に拡散モデルをゼロから理解するのを助けるためにこれを行いました。GANは複雑で、トレーニングするために多くのエンジニアリングトリックが必要なため、あまり好きではありませんでした。そこで、古い信頼できるUNetを使用して画像を生成するこの斬新な手法がオープンソース界でメジャーになった今、これを試さない手はないと思ったのです!
この新しい強力なツールは、ラベル付きデータを生成する新たな方法への扉を開きました。これは、データセットを事前トレーニング/拡張し、より強力な事前トレーニング済みモデルを取得する新しい方法になる可能性があります。この仮説をテストする方法についていくつかのアイデアがありますので、発表までどうぞ楽しみにお待ちください。
Iterate on AI agents and models faster. Try Weights & Biases today.