Skip to main content

Реализация прореживания в PyTorch: пример

Рассматривается пример, демонстрирующий регуляризацию модели PyTorch с прореживанием, дополненный кодом и интерактивными визуализациями
Created on January 11|Last edited on January 18
Этот отчет является переводом «Implementing Dropout in PyTorch: With Example » Ayush Thakur

Знакомство с прореживанием в PyTorch

В этом отчете мы увидим пример добавления прореживания в модель PyTorch и понаблюдаем за эффектом, который производит прореживание на функционирование модели посредством отслеживания наших моделей в Weights & Biases.

Определение прореживания

Прореживание – это техника машинного обучения, в которой вы удаляете (или «прореживаете») единицы в нейронной сети для моделирования обучения большого количества архитектур одновременно. Важно отметить, что прореживание может существенно снизить возможность переобучения в ходе обучения.


Пример добавления прореживания в модель PyTorch

1. Добавим прореживание в модель PyTorch

Добавление прореживания в модели PyTorch является достаточно простым, если использовать класс torch.nn.Dropout, который в качестве параметра предусматривает использование коэффициента прореживания, определяющего вероятность деактивации нейрона.
self.dropout = nn.Dropout(0.25)
Мы можем применить прореживание после любого невыводимого слоя.

2. Понаблюдаем за влиянием прореживания на производительность модели

Для наблюдения за эффектом прореживания обучим модель осуществлять классификацию изображений. Вначале мы обучим нерегуляризированную сеть, потом – сеть, регуляризированную посредством прореживания. Модели обучаются на наборе данных Cifar-10 для 15 эпох каждая.

Выполняем пример с добавлением прореживания в модель PyTorch:

class Net(nn.Module):
def __init__(self, input_shape=(3,32,32)):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3)
self.conv2 = nn.Conv2d(32, 64, 3)
self.conv3 = nn.Conv2d(64, 128, 3)
self.pool = nn.MaxPool2d(2,2)

n_size = self._get_conv_output(input_shape)
self.fc1 = nn.Linear(n_size, 512)
self.fc2 = nn.Linear(512, 10)

# Define proportion or neurons to dropout
self.dropout = nn.Dropout(0.25)
def forward(self, x):
x = self._forward_features(x)
x = x.view(x.size(0), -1)
x = self.dropout(x)
x = F.relu(self.fc1(x))
# Apply dropout
x = self.dropout(x)
x = self.fc2(x)
return x
Используя функцию wandb.log() в качестве вашей обучающей функции, вы сможете автоматически отслеживать функционирование вашей модели. Ознакомьтесь с соответствующими документами для получения подробной информации.
def train(model, device, train_loader, optimizer, criterion, epoch, steps_per_epoch=20):

# Log gradients and model parameters
wandb.watch(model)

# loop over the data iterator, and feed the inputs to the network and adjust the weights.
for batch_idx, (data, target) in enumerate(train_loader, start=0):
# ...
acc = round((train_correct / train_total) * 100, 2)
# Log metrics to visualize performance
wandb.log({'Train Loss': train_loss/train_total, 'Train Accuracy': acc})



Run set
2


Эффект использования прореживания в PyTorch

  • Нерегуляризованная сеть быстро переобучается на основе обучающего набора данных. Заметьте, как валидация потерь для запуска without-dropout быстро расходится всего лишь за несколько эпох. Этим объясняется более высокая ошибка обобщения.
  • Обучение с двумя слоями прореживания с вероятностью прореживания в 25% предотвращает переобучение модели. Однако это снижает точность обучения и означает, что регуляризованную сеть нужно обучать дольше.
  • Прореживание улучшает обобщение модели. Даже если точность обучения ниже, чем в нерегуляризованной сети, общая точность валидации улучшилась. Этим объясняется меньшая ошибка обобщения.
На этом мы завершим данное короткое учебное пособие по использованию прореживания в ваших моделях PyTorch.

Weights & Biases

Инструмент Weights & Biases помогает отслеживать эксперименты с машинным обучением. Используйте наш инструмент для регистрации гиперпараметров и выводимых показателей из ваших прогонов, затем визуализируйте и сравнивайте результаты и быстро делитесь результатами с коллегами.
Начните через 5 минут.

Iterate on AI agents and models faster. Try Weights & Biases today.