Skip to main content

使用 PyTorch Lightning、权重和偏差进行图像分类

本文介绍了如何使用 PyTorch Lightning 来提高 PyTorch 代码的可读性和可再现性。
Created on January 6|Last edited on January 6
本报告是作者Ayush Thakur所写的"Image Classification Using PyTorch Lightning and Weights & Biases"的翻译


在本文中,我们将使用 PyTorch Lightning 构建一个图像分类管道。我们将遵循该风格指南来提高代码的可读性和可再现性。


目录



⚡️什么是 PyTorch Lightning?

PyTorch 是一种非常强大的深度学习研究框架。但是,进行深入研究后,像 16 位精度、多 GPU 训练和 TPU 训练混合在一起,用户很可能会引入 bug。PyTorch Lightning 可支持工程分离式研究。
让我们使用 PyTorch Lightning 构建一个图像分类管道。本文可以作为熟悉 PyTorch Lightning 的入门指南。
PyTorch Lightning ⚡ 并非另外一种框架,而是 PyTorch 风格指南。

⏳ 安装和导入

在本教程中,我们需要 PyTorch Lightning(这不很明显吗!)、权重和偏差。
# install pytorch lighting
! pip install pytorch-lightning --quiet
# install weights and biases
!pip install wandb --quiet
除了常规 PyTorch 导入,还需要这些 ⚡导入。
import pytorch_lightning as pl
# your favorite machine learning tracking tool
from pytorch_lightning.loggers import WandbLogger
我们将使用 WandbLogger 来跟踪我们的实验结果,并将其直接记录在 W&B 中。

🔧 数据模块——我们应有的数据管道

数据模块是一种将相关数据钩子与 Lightning 模块分离的方法,因此可以开发与数据集无关的模型。
可将数据管道组织为可共享可重用类别。数据模块对 PyTorch 中所涉及的数据处理五个步骤进行了封装:
  • 下载/标记化/处理。
  • 清理并(可能)保存到磁盘。
  • 在数据集中加载。
  • 应用变换(旋转、标记等)。
  • 包裹在数据加载器中。
点击 此处 了解更多关于数据模块的信息。让我们为 CIFAR-10 数据集构建一个数据模块。

1. Init

CIFAR10 数据模块 为 PyTorch Lightning 的 Lightning 数据模块 子类。我们将使用 __init__ 方法传入数据管道所需的超参数。我们还将在此处定义数据转换管道。
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(self, batch_size, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size

self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.dims = (3, 32, 32)
self.num_classes = 10

2. Perpare_data

在这里我们将对数据集下载逻辑进行定义。我们正在使用 torchvision 的 CIFAR10 数据集类进行下载。使用该方法可执行写入磁盘操作,或在分布式设置中执行由单个 GPU 完成的操作。禁止在该函数中进行任何状态分配(即, self.something = ...)。
def prepare_data(self):
# download
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)

3. Setup_data

在这里,我们将从文件中加载数据,并为每次分割准备 PyTorch 张量数据集。因此,数据拆分是可重复的。该方法需要一个 stage 参数,用于分离“训练周期”和“测试周期”逻辑。如果不想一次性加载整个数据集,这个方法很有帮助。我们在此处对每个 GPU 上执行的数据操作进行了定义。包括 PyTorch 张量数据集应用转换。
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)


4. X_dataloader

train_dataloader(), val_dataloader(), 和 test_dataloader() 均会返回 PyTorch 数据加载器实例,这些实例通过包裹在 setup() 中所准备的各自数据集进行创建。
def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


📱 回调

回调是一个自包含程序,可跨项目重复使用。PyTorch Lightning 配备了一些常用的内置回调
点击 此处 了解更多有关 PyTorch Lightning 中的回调信息。

内置回调

在本教程中,我们将使用提前终止模型检查点内置回调。它们可以传递给 Trainer

自定义回调

如果你熟悉自定义 Keras 回调,那么在 PyTorch 管道中进行类似操作就非常简单了。
由于我们正在进行图像分类,因此在一些图像样本上可直观看到模型预测结果是很有帮助的。该回调形式有助于在早期阶段对模型进行调试。

1.__Init__

PyTorch Lightning 回调图像预测记录器子类。这里我们将传递 val_samples,其为一个图像和标签元组。 num_samples 为记录到 W&B 仪表板的图像数量。
class ImagePredictionLogger(Callback):
def __init__(self, val_samples, num_samples=32):
super().__init__()
self.num_samples = num_samples
self.val_imgs, self.val_labels = val_samples

2. 回调钩子

点击此处可获得所有可用的回调钩子。
在验证 epoch 结束时调用 On_validation_epoch_end 方法。需要两个参数—— trainerpl_module,这两个参数由 Trainer 自动传递。
通过使用 trainer.logger.experimental,我们就可以使用权重和偏差提供的所有功能。
def on_validation_epoch_end(self, trainer, pl_module):
# Bring the tensors to CPU
val_imgs = self.val_imgs.to(device=pl_module.device)
val_labels = self.val_labels.to(device=pl_module.device)
# Get model prediction
logits = pl_module(val_imgs)
preds = torch.argmax(logits, -1)
# Log the images as wandb Image
trainer.logger.experiment.log({
"examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
for x, pred, y in zip(val_imgs[:self.num_samples],
preds[:self.num_samples],
val_labels[:self.num_samples])]
})

我们将看到该回调结果。

🎺 Lightning 模块——系统定义

Lightning 模块定义了一个系统,而非模型。此处,系统将所有研究代码归入单一类中,使其自包含。 Lightning 模块 可将 PyTorch 代码分为 5 部分:
The LightningModule defines a system and not a model. Here a system groups all the research code into a single class to make it self-contained. LightningModule organizes your PyTorch code into 5 sections:
  • 计算 (__init__).
  • 训练循环 (training_step)
  • 验证循环 (validation_step)
  • 测试循环 (test_step)
  • 优化器 (configure_optimizers)
因此,人们可构建一个可轻松共享,与数据集无关的模型。让我们建立一个 Cifar-10 分类系统。

1. 初始化相关计算

Lightning 模块组件中包含了模型架构和前向传递。该代码片断看起来类似于正常的 PyTorch 代码。
你可以通过 __init__ 传递模型所需的所有超参数。通常情况下,我们用不同超参数对一个模型的多个版本进行训练。通过调用 save_hyperparameters,我们可通过 Lightning 将 __init__ 中的任何值保存至检查点。这是一个非常实用的功能。
主要有 _get_conv_output_forward_features 两种方法。可以用于自动计算卷积块输出的张量大小。点击此处以了解更多信息。
对于普通的 PyTorch 代码来说,前向方法可能看起来比较熟悉。然而,在 Lightning 中, 前向 方法仅用于定义推断动作。 Training_step 定义了训练循环。
class LitModel(pl.LightningModule):
def __init__(self, input_shape, num_classes, learning_rate=2e-4):
super().__init__()
# log hyperparameters
self.save_hyperparameters()
self.learning_rate = learning_rate
self.conv1 = nn.Conv2d(3, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 32, 3, 1)
self.conv3 = nn.Conv2d(32, 64, 3, 1)
self.conv4 = nn.Conv2d(64, 64, 3, 1)

self.pool1 = torch.nn.MaxPool2d(2)
self.pool2 = torch.nn.MaxPool2d(2)
n_sizes = self._get_conv_output(input_shape)

self.fc1 = nn.Linear(n_sizes, 512)
self.fc2 = nn.Linear(512, 128)
self.fc3 = nn.Linear(128, num_classes)

self.accuracy = torchmetrics.Accuracy()

# returns the size of the output tensor going into Linear layer from the conv block.
def _get_conv_output(self, shape):
batch_size = 1
input = torch.autograd.Variable(torch.rand(batch_size, *shape))

output_feat = self._forward_features(input)
n_size = output_feat.data.view(batch_size, -1).size(1)
return n_size
# returns the feature tensor from the conv block
def _forward_features(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(F.relu(self.conv2(x)))
x = F.relu(self.conv3(x))
x = self.pool2(F.relu(self.conv4(x)))
return x
# will be used during inference
def forward(self, x):
x = self._forward_features(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.log_softmax(self.fc3(x), dim=1)
return x


2. 训练循环

Lightning 为我们自动完成了大部分训练,包括 epoch 和 batch 迭代,我们仅需要保留训练步骤逻辑。 training_step 方法需要 batchbatch_idx 参数,这些参数由 Trainer 自动传递。点击此处了解更多有关训练循环的信息
为了计算 Epoch Wise 指标,将 on_epoch=True 传递给 .log 方法。逐步 (Step-wise) 指标会被自动记录下来。想要关闭,需传递 on_step=False
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
# training metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
return loss

3. 验证循环

与训练循环类似,验证循环可以通过覆盖 Lightning 模块validation_step 方法实现。点击此处了解更多有关验证循环的信息。
这些指标会自动按 epoch-wise 记录下来。
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)

# validation metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss

4. 测试循环

测试循环类似于验证循环。唯一的区别为仅在使用 trainer.test() 时才会调用测试循环。点击此处了解更多有关测试循环的信息。
这些指标会自动按 epoch-wise 记录下来。
def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
# validation metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('test_loss', loss, prog_bar=True)
self.log('test_acc', acc, prog_bar=True)
return loss

5. 优化器

我们可以使用 configure_optimizer 方法定义优化器和学习率调度器。甚至可以像 GAN 一样定义多个优化器。
点击此处了解更多有关该方法的信息。
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
注意:如果正在使用 Lightning 重构 PyTorch 代码,请从 Lightning 模块 中删除 .cuda().to()

🚋 训练和评估

现在我们已经用数据模块组织了数据管道,用 Lightning 模块组织了模型架构 + 训练循环,PyTorch Lightning Trainer 对其他剩余部分进行了自动化。
Trainer 自动化:
  • Epoch 和 batch 迭代
  • 调用 optimizer.step()backwardzero_grad()
  • 调用 .eval(),启用/禁用 grad
  • 保存和加载权重
  • 权重和偏差记录
  • 支持多 GPU 训练
  • 支持 TPU
  • 支持 16 位训练
点击此处了解更多有关 Trainer 的信息。让我们用这个来训练模型。
我们首先会对我们的数据管道进行初始化。Trainer 仅需要一个 PyTorch 数据加载器,用于区分训练/验证/测试。我们可以直接将已创建的 dm 对象传递给Trainer。但是由于图像预测记录器需要一些示例,我们将手动调用 prepare_data设置方法。
# Init our data pipeline
dm = CIFAR10DataModule(batch_size=32)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
训练模型从未如此简单。我们仅需要初始化模型和我们最喜欢的记录器。注意,我们已单独传递了 checkpoint_callback
# Init our model
model = LitModel(dm.size(), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize a trainer
trainer = pl.Trainer(max_epochs=50,
progress_bar_refresh_rate=20,
gpus=1,
logger=wandb_logger,
callbacks=[early_stop_callback,
ImagePredictionLogger(val_samples)],
checkpoint_callback=checkpoint_callback)

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test()

# Close wandb run
wandb.finish()

以下媒体面板显示了所记录的 W&B 指标。

Run set
0

以下媒体图表为 图像预测记录器 自定义回调的结果。你可以看到每个图像的预测结果和 Ground Truth 标签。
点击⚙️图标,移动滑块,查看模型在每个 epoch 的预测结果。

Run set
0


📉 精确度—召回曲线

图像分类模型需要进行彻底测试。使用精确度—召回曲线为标准做法。
权重和偏差支持自定义 vega 图,使用其可绘制 vega 所支持的任何东西。让我们用精确度—召回曲线来观察该模型的性能。
查看此报告,了解更多由权重和偏差提供的自定义可视化支持相关信息。查看此报告,了解如何记录平均精确度—召回曲线
尽管我们的测试准确率为 70% 左右,但该分类器仍有很大改进空间。


Run set
0


结语

我经常使用 TensorFlow/Keras 生态系统,发现 PyTorch 框架虽然比较简练,但却有点让人难以承受。不过这仅为我的个人观点。在探索 PyTorch Lightning 的过程中,我意识到几乎所有能让我远离 PyTorch 的原因均得到了解决。下面让我兴奋地快速总结一下:
  • 当时:传统的 PyTorch 模型定义比比皆是。模型在某个 model.py 脚本中,训练循环在 train.py 文件中。为了了解管道情况,我们来回看了很多遍。
  • 现在Lightning 模块 作为一个系统,其模型与 training_stepvalidation_step 等一起定义,现在可以模块化并可共享。
  • 当时TensorFlow/Keras 最好的部分为输入数据管道。他们的数据集目录非常丰富,且还在不断增加。PyTorch 的数据管道曾经为最大的痛点。在正常 PyTorch 代码中,数据的下载/清理/准备工作通常分散在许多文件中。
  • • 现在:数据模块将数据管道组织为一个可共享和可重用的类别。现在仅为 train_dataloader、val_dataloader、test_dataloader 及所需匹配转换和数据处理/下载步骤的集合。
  • 当时:使用 Keras,可调用 model.fit 训练模型,调用 model.predict 运行推断。 model.evaluate 对测试数据进行了一个简单的旧评估。PyTorch 的运行方式与其不同。人们通常会发现单独的 train.pytest.py 文件。
  • 现在:通过 Lightning 模块Trainer 可对所有步骤进行自动化。人们仅需调用 trainer.fittrainer.test 来训练和评估模型。
  • 当时TensorFlow 擅长应用 TPU,PyTorch……也是!
  • 现在:通过 PyTorch Lightning,使用多个 GPU 甚至 TPU 来训练同一个模型就变得非常容易。真不错!
  • 当时:我热衷于使用回调,更喜欢编写自定义回调。类似于提前终止这类小问题,曾经是传统 PyTorch 的讨论焦点。
  • 现在:通过 PyTorch Lightning,使用提前终止和模型检查点就变得非常容易了。我甚至可以编写自定义回调。
我非常高兴。以下是 PyTorch Lightning 提供的所有内容清单

🎨 结论和资源

我希望这份报告对你有所帮助。我鼓励你使用该代码,并使用所选择的数据集训练图像分类器。
这里有一些资源,可了解更多有关 PyTorch Lightning 的信息:
请在下方评论中将你的想法告诉我。

Iterate on AI agents and models faster. Try Weights & Biases today.