Skip to main content

基于宽残差网络(Wide Residual Networks)介绍Quickvision

使用Quickvision和CIFAR-10数据集对题为“Wide Residual Networks”的论文进行端到端复制
Created on December 22|Last edited on December 22

在Sergey Zagoruyko和Nikos Komodakis引入宽残差网络(Wide Residual Networks\WRN) 之前,深度残差网络已显示出部分性能提升,但代价是层数增加了一倍。 这减少了特征的重复使用,并且总体上使模型的训练速度变慢。 WRN显示,拥有更宽的residual network可带来更好的性能,能提高CIFAR,SVHN和COCO的SOTA结果。



👁 Quickvision介绍

Quickvision 是建立在Torchvision,PyTorch和Lightning之上的计算机视觉库(Computer Vision Library)

它提供:

  1. 易于使用的PyTorch本机API,适用于模型的fit()train_step()``和val_step()
  2. 具有各种主干的易于自定义和配置的模型。
  3. 完整的PyTorch本机界面。所有模型都是nn.Module,所有训练API是可选的,并且未绑定到模型。
  4. 一个闪电API,有助于加快对多个GPU,TPU的训练。
  5. 数据集API,可轻松,快速地将常见数据格式转换为PyTorch格式。
  6. 一个很小的安装包,具有非常低的依赖性。

Quickvision只是PyTorch!

  • Quickvision不会使您学习新的库。如果您知道PyTorch,那就足够了!!!
  • Quickvision不会从PyTorch提取任何代码,也不会在其上实施任何自定义类。
  • 它保持Tensor数据格式,因此您无需转换它。

您是否只需要具有某些主干配置的模型?

  • 使用我们做的模型。它只是一个nn.Module,具有仅适用张量输入和输出格式。
  • Quickvision也提供参考脚本进行训练!

您是否要训练模型但不想编写冗长的循环?

  • 只需使用我们的训练方法,例如fit()train_step()val_step()

您是否需要多GPU训练,但担心模型配置?

  • 只是继承PyTorch Lightning模型!
  • 执行train_step()val_step()


我们将向您展示:

  1. Quickvision可让您使用自己的数据集,模型或代码

  2. 您可以使用quickvision中的模型,训练函数或数据集加载工具。

  3. API可以与Lightning无缝连接。

  4. 使用和PyTorch或Lightning相同的操作,但进行更快的实验。

  5. 使用wandb.log API记录指标。

    通过GitHub访问我们

我们很欢迎对我们的package新的贡献或改进。

Quickivison是一个为更快的速度而构建的库,但不影响PyTorch训练!



‼️ 传统残差网络(Residual Networks)的问题

减少特征重用

具有恒等映射的残差块(Residual block)使我们可以训练非常深的网络,但这是一个缺点。当梯度流过网络时,没有什么可以强迫它通过残差块的权重,因此会避开训练期间的学习。这导致只有几个block可以运行有价值的表示,或者许多block共享很少的信息,而对最终目标的贡献很小。为了解决此问题,对于残差块可以使用一种特别的dropout,将标识标量权重添加到用了dropout的每个残差块。

随着我们扩宽残差块,这导致参数数量增加,并且作者决定研究dropout的影响,以规范训练并防止过度拟合。他们认为,应该将dropout插入卷积层之间,而不是插入到block的同一部分中,并表明这样做可以带来稳定的提高,从而产生新的SOTA结果。

论文Wide Residual Networks试图回答深度残差网络应该有多宽这个问题,并解决训练问题。




📚 重要要点

该论文重点介绍了一种方法,与ResNet-1001相比,总体改进了4.4%,并显示:

  • 拓宽’可稳定提高不同深度的残差网络的性能

  • 增加深度和宽度有助于提高性能,直到参数数量变得太高需要更强的regularization

*从很深的残差网络变成很宽的残差网络(两者参数数量相同)似乎没有出现regularization作用。此外,与“瘦”残差网络相比,宽残差网络可以成功学习两倍或更多数量的参数——“瘦”网络的深度加倍才能实现,使其训练成本过高。



💪🏻训练

在本教程中,我们将使用WideResnet 模型(包含在0.2.0rc1版本中)。 您可以使用以下方法下载该库的稳定版本

pip install quickvision

当前的稳定版本0.1需要PyTorch 1.7和Torchvision 0.8.1




Quickvision提供简单的函数来创建具有预训练权重的模型。

from quickvision.models.classification import cnn

# To create model with imagenet pretrained weights 
model = cnn.create_vision_cnn("wide_resnet101_2", num_classes=10, pretrained="imagenet")

# Alternatively if you don't need pretrained weights
model_bare = cnn.create_vision_cnn("resnet50", num_classes=10, pretrained=None)

# It also supports other weights, do check a list which are supported !
model_ssl = cnn.create_vision_cnn("resnet50", num_classes=10, pretrained="ssl")


就像在torch中一样,我们定义标准和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


而不是像这样

model = model.to(device)
for epoch in range(2):
    for batch_idx, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        target = target.to(device)
        out = model(inputs)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

Quickvision已经为您执行了这些无聊的过程,以加快训练速度!

您可以使用所示的.fit()方法通过单行代码来训练模型!

history = cnn.fit(model=model, epochs=2, train_loader=train_loader,
        val_loader=valid_loader, criterion=criterion, device=device, optimizer=optimizer)

如果您希望使用更精细的控制,则可以使用我们的train_step()val_step()方法。 我们为您计算常用的指标,例如准确性。

wandb.init(project="intro-to-quickvision")

for epoch in tqdm(range(5)):
    print()
    print(f"Training Epoch = {epoch}")
    train_metrics = cnn.train_step(model, train_loader, criterion, device, optimizer)
    print()
    wandb.log({"Training Top1 acc": train_metrics["top1"], "Training Top5 acc": train_metrics["top5"], "Training loss": train_metrics["loss"]})

    print(f"Validating Epoch = {epoch}")
    valid_metrics = cnn.val_step(model, valid_loader, criterion, device)
    print()
    wandb.log({"Validation Top1 acc": valid_metrics["top1"], "Validation Top5 acc": valid_metrics["top5"], "Validation loss": valid_metrics["loss"]})



This set of panels contains runs from a private project, which cannot be shown in this report


您也可以使用Lightning进行训练!

  • 我们也为PyTorch Lightning使用了相同的逻辑。
  • 这直接允许您使用所有Lighning功能,例如Multi-GPU训练,TPU训练,日志记录等。

用Torch快速制作原型,将其传输到Lightning!

model_imagenet = cnn.lit_cnn("resnet18", num_classes=10, pretrained="imagenet")

gpus = 1 if torch.cuda.is_available() else 0

# Again use all possible Trainer Params from Lightning here !!
trainer = pl.Trainer(gpus=gpus, max_epochs=2)
trainer.fit(model_imagenet, train_loader, valid_loader)



This set of panels contains runs from a private project, which cannot be shown in this report