Skip to main content

PyTorch LightningとWeights & Biasesを使った画像分類

この記事では、PyTorch Lightningを使用しPyTorchコードの可読性と再現性を向上させる方法についての実践的な紹介を提供します。
Created on January 4|Last edited on January 27
このレポートは、Ayush Thakurによる「Image Classification Using PyTorch Lightning and Weights & Biases」の翻訳です。


この記事では、PyTorch Lightningを使用して画像分類パイプラインを構築します。コードの可読性と再現性を高めるために、このスタイルガイドに従います。


Table of Contents



⚡️PyTorch Lightningとは?

PyTorchは深層学習研究のための非常に強力なフレームワークです。しかし、研究が複雑になり、16ビット精度、マルチGPUトレーニング、TPUト訓練のようなものが混ざると、ユーザーはバグを発生させる可能性が高いです。PyTorch Lightningを使えば、研究と工学から切り離すことができます。
PyTorch Lightningを使用して画像分類パイプラインを構築しましょう。これは、PyTorch Lightningの厄介ごとに慣れるためのスタートガイドだと考えてください。
PyTorch Lightning ⚡は別のフレームワークではなく、PyTorchのスタイルガイドです。

⏳取付けと輸入高

このチュートリアルでは、PyTorch Lightning(当たり前か!)とWeights & Biasesが必要です。
# install pytorch lighting
! pip install pytorch-lightning --quiet
# install weights and biases
!pip install wandb --quiet
通常のPyTorchのインポートに加え、以下のインポートが必要です ⚡。
import pytorch_lightning as pl
# your favorite machine learning tracking tool
from pytorch_lightning.loggers import WandbLogger
WandbLoggerを使って、実験結果を追跡し、W&Bに直接ログを記録します。

🔧 DataModule - 私たちが望むデータパイプライン

DataModuleはデータ関連のフックを LightningModule から切り離し、データセットにとらわれないモデルを開発できるようにするための方法です。
データモジュールは、データパイプラインを1つの共有可能で再利用可能なクラスとして整理します。データモジュールは、PyTorchのデータ処理に関わる5つのステップをカプセル化します。
  • ダウンロード/トークン化/処理
  • クリーンアップしてディスクに保存する。
  • データセット内に読み込む。
  • 変換(回転、トークン化、など)を適用する。
DataModuleについて詳しくはこちら。CIFAR-10データセット用のdatamoduleを作ってみましょう。

1. Init

CIFAR10DataModuleはPyTorch Lightningの LightningDataModuleをサブクラスにしています。データパイプラインに必要なハイパーパラメータを __init__ メソッドで渡します。また、データ変換パイプラインの定義もここで行います。
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(self, batch_size, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size

self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.dims = (3, 32, 32)
self.num_classes = 10

2. Perpare_data

ここでは、データセットをダウンロードするロジックを定義します。今回はtorchvisionのCIFAR10データセットクラスを使用してダウンロードを行います。このメソッドは、ディスクに書き込むようなことや、分散環境で1つのGPUからだけ行う必要があることを行うために使用します。この関数では、状態の割り当てを一切行わないでください(例 self.something = ...)。
def prepare_data(self):
# download
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)

3. Setup_data

ここでは、ファイルからデータを読み込み、各分割のためのPyTorchテンソルデータセットを準備します。したがって、データ分割は再現可能です。このメソッドでは、'fit'と'test'のロジックを分けるために `stage`ステージargが必要です。これは、データセット全体を一度に読み込まない場合に有効です。すべてのGPUで実行したいデータ操作は、ここで定義します。これには、PyTorchテンソルデータセットにtransformを適用することも含まれます。
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)


4. X_dataloader

train_dataloader()val_dataloader() test_dataloader() は、setup() で準備したそれぞれのデータセットをラップして作成したPyTorch DataLoader のインスタンスを返します。
train_dataloader(), val_dataloader(), and test_dataloader() all return PyTorch DataLoader instances that are created by wrapping their respective datasets that we prepared in setup()
def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


📱 コールバック

コールバックは、プロジェクト間で再利用できる自己完結型のプログラムです。PyTorch Lightningには、定期的に使用されるいくつかの組み込みコールバックが付属しています。
PyTorch Lightningのコールバックについてはこちらをご覧ください。

組み込みコールバック

このチュートリアルでは、Early StoppingModel Checkpointの組み込みコールバックを使用します。トレーナー に渡すことができます。

カスタムコールバック

Custom Keras callbackに慣れているなら、PyTorchパイプラインで同じことができるのは、まさにケーキの上のサクランボと言えるでしょう。
私たちは画像分類を行っているので、画像のいくつかのサンプルでモデルの予測を視覚化する機能は有用です。コールバックの形でこれを使えば、早い段階でモデルをデバッグするのに役立ちます。

1.__Init__

ImagePredictionLogger はPyTorch Lightningの Callback クラスをサブクラス化しています。ここでは、画像とラベルのタプルである val_samples を渡します。 num_samplesは、W&Bダッシュボードにログを残す画像の数です。
class ImagePredictionLogger(Callback):
def __init__(self, val_samples, num_samples=32):
super().__init__()
self.num_samples = num_samples
self.val_imgs, self.val_labels = val_samples

2. コールバックフック

利用可能なコールバック・フックはすべてこちらにあります。
on_validation_epoch_end メソッドは、検証エポックが終了したときに呼び出されます。このメソッドは2つの引数を取ります - trainerpl_module これらは Trainer から自動的に渡されます。
trainer.logger.experimental を使用することで、Weights & Biasesのすべての機能を使用することができます。
def on_validation_epoch_end(self, trainer, pl_module):
# Bring the tensors to CPU
val_imgs = self.val_imgs.to(device=pl_module.device)
val_labels = self.val_labels.to(device=pl_module.device)
# Get model prediction
logits = pl_module(val_imgs)
preds = torch.argmax(logits, -1)
# Log the images as wandb Image
trainer.logger.experiment.log({
"examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
for x, pred, y in zip(val_imgs[:self.num_samples],
preds[:self.num_samples],
val_labels[:self.num_samples])]
})

このコールバックの結果を見ます。


🎺 LightningModule - システムを定義する

LightningModule は、モデルではなくシステムを定義します。ここでは、システムはすべての研究コードを1つのクラスにグループ化し、自己完結するようにします。 LightningModuleはPyTorchのコードを5つのセクションに整理します。
  • 計算 (__init__).
  • Trainループ (training_step)
  • 検証ループ (validation_step)
  • テストループ (test_step)
  • オプティマイザー (configure_optimizers)
このように、データセットに依存しないモデルを構築することができ、簡単に共有することができます。Cifar-10分類のシステムを構築してみましょう。

1. 計算

LightningModule のこのコンポーネントは、モデルアーキテクチャとフォワードパスを包含しています。このコードスニペットは、通常のPyTorchのコードと見覚えがあるかもしれません。
モデルで必要なハイパーパラメータはすべて __init__ を通して渡すことができます。多くの場合、異なるハイパーパラメータでモデルの多くのバージョンをトレーニングします。 save_hyperparameters を呼び出すことで、Lightning に __init__ の値をチェックポイントに保存するように依頼することができます。これは便利な機能です。
_get_conv_output_forward_features という二つのメソッドがあることにお気づきでしょう。これらは、畳み込みブロックの出力のテンソルサイズを自動的に計算するために使用されます。それについてはこちらで学んでください。
f forward 方法は、通常のPyTorchのコードでは見慣れたものに見えるかもしれません。しかし、Lightningでは forward は推論アクションを定義するためにのみ使われます。 training_stepは訓練ループを定義しています。
class LitModel(pl.LightningModule):
def __init__(self, input_shape, num_classes, learning_rate=2e-4):
super().__init__()
# log hyperparameters
self.save_hyperparameters()
self.learning_rate = learning_rate
self.conv1 = nn.Conv2d(3, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 32, 3, 1)
self.conv3 = nn.Conv2d(32, 64, 3, 1)
self.conv4 = nn.Conv2d(64, 64, 3, 1)

self.pool1 = torch.nn.MaxPool2d(2)
self.pool2 = torch.nn.MaxPool2d(2)
n_sizes = self._get_conv_output(input_shape)

self.fc1 = nn.Linear(n_sizes, 512)
self.fc2 = nn.Linear(512, 128)
self.fc3 = nn.Linear(128, num_classes)

self.accuracy = torchmetrics.Accuracy()

# returns the size of the output tensor going into Linear layer from the conv block.
def _get_conv_output(self, shape):
batch_size = 1
input = torch.autograd.Variable(torch.rand(batch_size, *shape))

output_feat = self._forward_features(input)
n_size = output_feat.data.view(batch_size, -1).size(1)
return n_size
# returns the feature tensor from the conv block
def _forward_features(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(F.relu(self.conv2(x)))
x = F.relu(self.conv3(x))
x = self.pool2(F.relu(self.conv4(x)))
return x
# will be used during inference
def forward(self, x):
x = self._forward_features(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.log_softmax(self.fc3(x), dim=1)
return x


2. トレーニングループ

Lightningはエポックやバッチの反復など、訓練のほとんどを自動化してくれるので、必要なのはトレーニングステップのロジックだけです。 training_step 方法には batchbatch_idx の引数が必要で、これは Trainer から自動的に渡されます。トレーニングループについて詳しくはこちら
エポック単位のメトリクスを計算するために、 .log 方法に on_epoch=True を渡します。ステップごとのメトリクスが自動的にログに記録されます。これをオフにするには、 on_step=False を渡します。
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
# training metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
return loss

3. 検証ループ

トレーニングループと同様に、 LightningModulevalidation_step 方法を上書きすることで、検証ループを実装することができます。検証ループについてはこちらをご覧ください。
メトリクスは、自動的にエポック単位でログが記録されます。

def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)

# validation metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss

4. テストループ

テストループはバリデーションループと似ています。唯一の違いは、 trainer.test() が使われたときだけテストループが呼び出されることです。テストループについてはこちらをご覧ください。
メトリクスは、自動的にエポック単位でログが記録されます。
def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
# validation metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('test_loss', loss, prog_bar=True)
self.log('test_acc', acc, prog_bar=True)
return loss

5. Optimizer

configure_optimizer 方法を用いて、オプティマイザや学習率スケジューラを定義することができる。GANのように複数のオプティマイザを定義することも可能である。
このメソッドについて詳しくはこちら
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
注意:Lightningを使用してPyTorchのコードをリファクタリングする場合、 LightningModule から .cuda().to() を削除してください。

🚋 訓練と評価

DataModule でデータパイプラインを、 LightningModule でモデルアーキテクチャと学習ループを整理したので、あとはPyTorch Lightning Trainerが全てを自動化してくれます。
Trainerが自動化するのは:
  • エポックとバッチイテレーション
  • optimizer.step()backwardzero_grad() の呼び出し。
  • .eval() の呼び出し、grad の有効化/無効化
  • 重みの保存と読み込み
  • 重みとバイアスのロギング
  • マルチGPUトレーニングのサポート
  • TPUのサポート
  • 16ビットトレーニング対応
トレーナーについて詳しくはこちら。これを使っていよいよモデルをトレーニングしてみましょう。
まずはデータパイプラインを初期化します。Trainerはtrain/val/testの分割のためにPyTorch DataLoader を必要とするだけです。作成した dm オブジェクトは直接Trainerに渡すことができます。しかし、 ImagePredictionLogger のサンプルが必要なので、 prepare_datasetup 方法を手動で呼び出すことにします。
# Init our data pipeline
dm = CIFAR10DataModule(batch_size=32)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
モデルのトレーニングはこんなに簡単ではありませんでした。モデルとお気に入りのロガーを初期化するだけです。 checkpoint_callback を別に渡していることに注意してください。
# Init our model
model = LitModel(dm.size(), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize a trainer
trainer = pl.Trainer(max_epochs=50,
progress_bar_refresh_rate=20,
gpus=1,
logger=wandb_logger,
callbacks=[early_stop_callback,
ImagePredictionLogger(val_samples)],
checkpoint_callback=checkpoint_callback)

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test()

# Close wandb run
wandb.finish()

以下のメディアパネルは、W&Bに記録されるメトリックスを表示します。


下のメディアチャートは、 ImagePredictionLogger カスタムコールバックの結果です。各画像の予測値とグランドトゥルースラベルを見ることができます。
⚙️アイコンをクリックしてスライダーを動かすと、エポック毎のモデルの予測値を見ることができます。



📉 精度-再現率曲線

画像分類モデルは徹底的にテストされる必要があります。精度-再現曲線の使用は標準的な方法です。
Weights & Biasesはカスタムベガプロットをサポートしており、ベガでサポートされているものであれば文字通り何でもプロットすることができます。平均的な精度-再現曲線を用いて、モデルのパフォーマンスを見てみましょう。
Weights & Biasesのカスタム可視化サポートについては、こちらのレポートをご覧ください。平均精度-リコール曲線の記録方法については、このレポートをご覧ください。
テスト精度は70%程度ですが、この分類器を改善するためにできることはたくさんあります。




最終的な感想

私はTensorFlow/Kerasのエコシステムから来たので、PyTorchはエレガントなフレームワークですが、少し圧倒される感じがします。ただ、個人的な経験ですが。PyTorch Lightningを試しているうちに、私がPyTorchから遠ざかっていた理由のほとんどすべてが解決されたことに気づきました。以下は、私の感動を簡単にまとめたものです。
  • 以前:従来のPyTorchのモデル定義は、どこもかしこもバラバラでした。 model.py スクリプトにモデルがあって、 train.py ファイルに訓練ループがあるような状態です。パイプラインを理解するために、何度も見返したりしていました。
  • LightningModule は、モデルが training_stepvalidation_step などとともに定義されるシステムとして機能します。これでモジュール化され、共有できるようになりました。
  • 以前:TensorFlow/Kerasの最も優れた部分は、入力データパイプラインです。そのデータセットカタログは豊富で、増え続けています。PyTorchのデータパイプラインは、以前は最大のペインポイントでした。通常のPyTorchのコードでは、データのダウンロード/クリーニング/準備は、通常、多くのファイルに散らばっています。
  • DataModule は、データパイプラインを1つの共有可能で再利用可能なクラスとして整理しています。これは、train_dataloader, val_dataloader(s), test_dataloader(s) と、それに必要な変換やデータ処理/ダウンロードのステップを単純にまとめたものです。
  • 以前:Kerasでは、 model.fit を呼び出してモデルを学習し、 model.predict を呼び出して推論を実行します。 model.evaluate は、テストデータに対する古き良きシンプルな評価を提供しました。これはPyTorchの場合ではありません。通常、 train.pytest.py のファイルが別々に存在します。
  • LightningModule があれば、 Trainer がすべてを自動化してくれます。 trainer.fittrainer.test を呼び出すだけで、モデルの訓練と評価を行うことができます。
  • 以前TensorFlowはTPU、PyTorch...を愛用しています。
  • PyTorch Lightningを使えば、複数のGPUTPUでも同じモデルを簡単に学習させることができるんだ。すごい!
  • 以前:私はCallbacksの大ファンであり、カスタムのコールバックを書くことを好んでいます。Early Stoppingのような些細なことが、従来のPyTorchとの議論のポイントになっていました。
  • PyTorch Lightningでは、Early StoppingModel Checkpointingを使うのは簡単です。カスタムコールバックも書けるし。
興奮という名の暴言を吐き続けることができそうです。PyTorch Lightningが提供するすべてのリストはこちらです。

🎨 結論とリソース

このレポートがお役に立てれば幸いです。私は、このコードで遊び、あなたが選んだデータセットで画像分類器を訓練することをお勧めします。
PyTorch Lightningについてより詳しく知るためのリソースをいくつか紹介します。
下のコメントであなたの考えを聞かせてください。
Iterate on AI agents and models faster. Try Weights & Biases today.