Skip to main content

Road to 0.81

CheXpert X-Ray image classification
Created on October 18|Last edited on December 19

Abstract

Суммарно на все нижеописанное было потрачено примерно 30 часов моего времени, и примерно 2 недели компьюта хорошей карточки. Однако, как это обычно и бывает, надо было work smart, not hard.

Proposed method

Финальное решение выглядит относительно адекватно, хоть и использует некоторые интересные трюки.
Предобработка данных была выдумана после пристального смотрения на \approx 100 картинок. Во-первых, отказ от ресайза - это здорово. Тут у нас большие картинки (320х390 обычно) и мы не хотим терять информацию. Все полезное содержится в центре, так что было принято волевое решение отрезать края центркропом. Во-вторых, там людей на фотках колбасит туда-сюда во всех направлениях. Чтобы это учесть будем делать скейл и случайную перспективу с не очень агрессивными параметрами. В-третьих, авторы статьи по датасету используют RandomHorizontalFlip. С точки зрения здравого смысла я это осудил, но на практике помогло, так что почему нет.
transforms_dict = dict(
train=transforms.Compose([
transforms.CenterCrop(320),
transforms.RandomAffine(degrees=0, scale=(0.8, 1.1)),
transforms.RandomPerspective(0.3),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.54], [0.26]),
]),
val=transforms.Compose([
transforms.CenterCrop(320),
transforms.ToTensor(),
transforms.Normalize([0.54], [0.26]),
]),
)
Ко всему прочему конвертирую картинки в RGB, что позволяет использовать сетки из коробки, не меняя слои.
Пройдемся по основным гиперпараметрам обучения:
from torchvision.models import efficientnet_b0
import torch

lr = 0.001
n_epochs = 24
batch_size = 64

seed_everything(0xbebebe)
model = efficientnet_b0(num_classes=5)
criterion = torch.nn.BCEWithLogitsLoss() # == Sigmoid + BCELoss
optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=lr,
pct_start=0.2,
total_steps=n_epochs * (len(train_dataset.data) // batch_size + 1)
)
lr, число эпох и размер батча не слишком подвергались сомнению, и в процессе экспериментов почти не менялись (разве что чуть-чуть в целях небольшого ablation), как и random_seed. Модель была найдена путем достаточно мучительного перебора (до нее упорно пытался завести ResNet - безуспешно, и DenseNet121 - более успешно). Финальная - самая маленькая из семейства, увеличение не влияет на скор (мб потому что мало эпох). Оптимайзер в последних экспах был выбран AdamW просто потому что давайте попробуем, но Adam с дефолтными параметрами показывал ТОЧНО ТАКОЙ ЖЕ результат, так что забейте. Шедулер все время использовался с понижением в Х раз каждые У эпох, в последние пару экспов был заменен на более умный (который еще и красивую картинку lr'a рисует), что, впрочем, тоже не сильно помогло.

Самое интересное

Ну обучили мы значит эту нашу финальную версию, а что по скорам? Все описанное выше дает в лб где-то 0.803-0.804, а хочется-то выше. Здесь в ход идут ансамбли хитрости. Умными исследователями было предложено усреднять чекпоинты моделей, близких к оптимуму (см. SWA). Так как пишется это примерно в 10 строчек, то отказываться от такого мы, конечно, не будем. И вуаля, скор в лидерборде становится уже 0.806.
def averaging(base_model, paths):
target_state_dict = base_model(num_classes=5).state_dict()
for key in target_state_dict:
target_state_dict[key].data.fill_(0.)
for path in tqdm(paths):
model = base_model(num_classes=5)
model.load_state_dict(torch.load(path))
state_dict = model.state_dict()
for key in target_state_dict:
if target_state_dict[key].data.dtype != torch.float32:
continue
target_state_dict[key].data += state_dict[key].data.clone() / len(paths)
return target_state_dict
Но можно лучше! Идея стара как мир и проста как пареная репа - Test time augmentation. Будем предсказываться на тесте не 1 раз с детерминированными аугментациями, а 41 раз с трейновыми - случайными. А так как они достаточно сильные, то и предсказания каждый раз будут получаться немного разные (надеемся на то, что в измененной картинке модель увидит что-то, чего не увидела в обычной). И вот это дает решающий прирост в качестве 0.806->0.81.
dataloader = torch.utils.data.DataLoader( # ключевое слово тут train
CustomDataset('data/sample_submission.csv', get_transforms('train')),
batch_size=256, num_workers=8
)
tta = 40
preds = inference(model, dataloader, device)
for t in range(tta):
preds += inference(model, dataloader, device)
Графики финальной модельки, если вдруг интересно



Страдания Что еще пробовал

Вы просили, мы рассказываем! Как уже упоминалось выше - на данное дз я потратил достаточное количество времени и электричества (carbon footprint сравнимо разве что с PALM 540В). Поэтому тут точно есть о чем рассказать.
ФАКАП: во всех ранних экспериментах можно заметить график от времени с неприлично большими значениями. Это происходит потому, что первые 168 часов выполнения домашки я обучался с num_workers=0 (default value в Dataloader'e). Удивлению не было предела, когда я ускорил резнет в 40 раз командой num_workers=4.
Сначала в ход пошла маленькая самописная сетка (3 слоя or smth)



Когда выяснилось, что мой сомнительный код для обучения способен обучить нейросеть, в ход пошел всеми любимый бейзлайн - ResNet18. И (ого) она отлично обучилась на свои 0.77. Надо отметить, что на всех экспериментах (кроме, может быть, последних) валидационное качество до 3 знаков совпадало с результатами в лб.


Дальше было много не очень интересных экспериментов с резнетами (разного размера, с разными аугментациями и тп). Кроме того, были испробованы MLPMixer (ему, оказалось, нужно сильно больше компьюта для достижения тех же результатов) и WideResNet (он тоже показал себя существенно хуже обычных резнетов, чему объяснения я не нашел).
В легком бессилии я решил посмотреть что пишут про эти данные и нашел статью авторов датасета. Они утверждали, что DenseNet121 победил всех в их экспериментах. Так как поводов не верить ученым из стенфорда у меня не было, то я лихо внедрил у себя это новшество и сразу стало хорошо.


Так же были попробованы разные лоссы (например SoftMargin), но получилось так себе. Еще были эксперименты, где модель выступала в роли feature extractor'a, а поверх обучалась лог регрессия из склерна. И да, формально это был не ансамбль, потому что я инициализировал обученными весами новый clf слой в конце сетки. Но никакого профита это не дало (видимо, голова модельки отлично обучилась и в коррекции не нуждается).
Дальше были попробованы efficientnet'ы разных размеров - как уже оговаривалось, разницы никакой не было. И конечная версия решения плавно вытекла из всего попробованного.

Что не пробовал

Есть так же большой список прикольных трюков, которые могли бы помочь, но применены в силу моей лени или каких-либо других обстоятельств не были. Тут вы можете найти их неполный список.
  • SGD. Известная история - для него сложнее подбирать lr, нужно больше эпох и правильный шедулер. Но частенько он дает оптимумы лучше других на задаче классификации изображений. Все это я вспомнил после последней посылки.
  • SWA. Не усреднение чекпоинтов за 16, 17, 19, 21 эпохи, а настоящий статейный SWA.
  • Взвешенный BCELoss. Была гипотеза, что можно взвесить семплы пропорционально частоте классов и передать это в лосс. Но оказалось, хорошее качество выбивается и так.
  • MixUp. Общепризнано хорошая аугментация, которая часто сильно помогает. Руки были заняты и не дошли.
  • Много чего еще, но я забыл

Заключение

Если все вышеописанное держать в голове перед решением очередного контеста на выбивание скора (на картинках в частности), правильно расставлять приоритеты экспериментов, а так же вовремя гуглить, то хороших результатов можно достичь существенно быстрее. Когда-нибудь я научусь. Когда-нибудь...