Skip to main content

PyTorch Lightning 및 Weights & Biases을 이용한 이미지 분류

이 문서에서는 PyTorch Lightning을 사용하여 PyTorch 코드의 가독성과 재현성을 향상하는 실질적인 방법을 소개합니다.
Created on January 6|Last edited on January 27
이는 여기에서 볼 수 있는 영어 기사를 번역한 것이다.


이 문서에서는 PyTorch Lightning을 사용하여 이미지 분류 파이프라인을 구축해 보겠습니다. 코드의 가독성과 재현성을 높이기 위해 이 스타일 가이드를 따르도록 하겠습니다.


목차



⚡️PyTorch Lightning이란?

PyTorch는 딥 러닝 연구를 위한 매우 강력한 프레임워크입니다. 하지만 연구가 복잡해지고 16비트 정밀도 및 다중 GPU 학습 및 TPU 학습 등이 혼재하게 되면, 버그가 발생할 가능성도 높아집니다. PyTorch Lightning은 연구와 엔지니어링을 분리할 수 있게 해줍니다.
PyTorch Lightning을 사용하여 이미지 분류 파이프라인을 구축해 보죠. PyTorch Lightning 사용에 익숙해지기 위한 시작 가이드로 생각하시기 바랍니다.
PyTorch Lightning ⚡은 또 다른 프레임워크가 아닌 PyTorch를 위한 스타일 가이드입니다.

⏳ 설치 및 가져오기

본 튜토리얼을 위해서는 PyTorch Lightning(당연하죠!)과 Weights & Biases가 필요합니다.
# install pytorch lighting
! pip install pytorch-lightning --quiet
# install weights and biases
!pip install wandb --quiet
보통의 PyTorch 가져오기 외에, 아래 ⚡가져오기도 필요합니다.
import pytorch_lightning as pl
# your favorite machine learning tracking tool
from pytorch_lightning.loggers import WandbLogger
WandbLogger는 실험 결과를 추적하고 W&B에 직접 기록하는 데에 사용하겠습니다.

🔧 DataModule - 우리에게 꼭 맞는 데이터 파이프라인

DataModules은 LightningModule 에서 데이터 관련 후크를 분리하는 하나의 방법입니다. 따라서 데이터세트 애그노스틱 모델을 개발할 수 있죠.
DataModules은 데이터 파이프라인을 공유 및 재사용 가능한 하나의 클래스로 구성합니다. PyTorch에서 하나의 데이터 모듈은 데이터 처리와 관련된 다섯 단계를 캡슐화합니다:
  • 다운로드 / 토큰화 / 프로세스.
  • 정리 및 (아마도) 디스크에 저장.
  • 내부 데이터세트 로드.
  • 변환 적용 (회전, 토큰화 등).
  • DataLoader 내부 래핑.
데이터 모듈에 대한 자세한 내용은 여기를 참조하세요. 이제 CIFAR-10 데이터세트를 위한 데이터 모듈을 구축해 보겠습니다.

1. Init

CIFAR10DataModulePyTorch Lightning LightningDataModule 의 하위 클래스로 분류됩니다. __init__ 메서드를 사용하여 데이터 파이프라인에 필요한 하이퍼파라미터를 전달합니다. 데이터 변환 파이프라인도 여기에서 정의합니다.
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(self, batch_size, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size

self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.dims = (3, 32, 32)
self.num_classes = 10

2. Perpare_data

여기에서는 데이터세트를 다운로드하기 위한 논리를 정의합니다. torchvision의 CIFAR10 데이터세트 클래스를 사용하여 다운로드합니다. 디스크에 저장하거나 분산 설정에서 단일 GPU에서만 수행해야 하는 작업을 수행하려면 이 메서드를 사용하시기 바랍니다. 이 함수 안에서는 상태 할당은 하지 않도록 하십시오 (예: self.something = ... ).
def prepare_data(self):
# download
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)

3. Setup_data

여기서는 파일의 데이터를 로드하고 각 분할에 대한 PyTorch 텐서 데이터세트를 준비할 수 있습니다. 따라서 데이터 분할을 재현할 수 있죠. 이 메서드에는 '적합' 및 '테스트' 논리를 분리하는 데 사용되는 stage 인수가 필요합니다. 이는 전체 데이터세트를 한 번에 불러오지 않고자 하는 경우에 유용합니다. 모든 GPU에서 수행하고자 하는 데이터 작업은 여기에서 정의되는데, PyTorch 텐서 데이터세트에 변환을 적용하는 작업도 포함됩니다.
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)


4. X_dataloader

train_dataloader(), val_dataloader(), 및 test_dataloader() 는 모두 setup()에서 준비한 해당 데이터세트를 래핑하여 생성된 PyTorchDataLoader 인스턴스를 반환합니다
def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=self.batch_size)


📱콜백

콜백은 프로젝트 전반에 걸쳐 재사용할 수 있는 자체 포함 프로그램입니다. PyTorch Lightning에는 자주 사용되는 몇 가지 기본 제공 콜백이 있습니다.
PyTorch Lightning의 콜백에 대한 자세한 내용은 여기에서 알아보세요.

기본 제공 콜백

이 튜토리얼에서는 기본 제공된 Early StoppingModel Checkpoint 콜백을 사용하겠습니다. 해당 콜백은 Trainer 에 전달됩니다.

콜백 사용자 지정

Keras 콜백 사용자 지정에 익숙한 분에게는 PyTorch 파이프라인에서도 동일하게 할 수 있다면 더할 나위 없이 유용한 기능이 되겠죠.
지금 수행하는 작업은 이미지 분류이므로, 일부 이미지 샘플에 대한 모델의 예측을 시각화하는 게 도움이 될 겁니다. 이를 콜백 형태로 처리하면, 모델을 초기 단계에서 디버깅하는 데 유용합니다.

1.__Init__

ImagePredictionLogger 는 PyTorch Lightning Callback 클래스의 하위 클래스입니다. 여기서는 이미지와 레이블의 튜플인 val_samples 를 전달합니다. num_samples 은 W&B 대시보드에 기록할 이미지 개수입니다.
class ImagePredictionLogger(Callback):
def __init__(self, val_samples, num_samples=32):
super().__init__()
self.num_samples = num_samples
self.val_imgs, self.val_labels = val_samples

2. 콜백 후크 (Callback Hooks)

사용 가능한 콜백 후크는 모두 여기에서 확인하실 수 있습니다.
on_validation_epoch_end 메서드는 검증 에포크(epoch)가 종료될 때 호출됩니다. 두 개의 인수, trainerpl_module 이 필요한데, 이 인수들은 Trainer 가 자동으로 전달합니다.
trainer.logger.experimental 을 사용하면 Weights & Biases에서 제공하는 모든 기능을 사용할 수 있습니다.
def on_validation_epoch_end(self, trainer, pl_module):
# Bring the tensors to CPU
val_imgs = self.val_imgs.to(device=pl_module.device)
val_labels = self.val_labels.to(device=pl_module.device)
# Get model prediction
logits = pl_module(val_imgs)
preds = torch.argmax(logits, -1)
# Log the images as wandb Image
trainer.logger.experiment.log({
"examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
for x, pred, y in zip(val_imgs[:self.num_samples],
preds[:self.num_samples],
val_labels[:self.num_samples])]
})

이 콜백의 결과를 확인할 수 있습니다.

🎺 LightningModule - 시스템 정의

LightningModule 은 모델이 아닌 시스템을 정의합니다. 여기서 시스템은 모든 연구 코드를 하나의 클래스로 묶어 자체 포함하도록 합니다. LightningModule 은 PyTorch 코드를 5개 부분으로 나눕니다:
  • 계산 (__init__).
  • 학습 루프 (training_step)
  • 검증 루프 (validation_step)
  • 테스트 루프 (test_step)
  • 옵티마이저 (configure_optimizers)
따라서 쉽게 공유할 수 있는 데이터세트-애그노스틱 모델을 구축할 수 있죠. Cifar-10 분류를 위한 시스템을 구축해 보겠습니다.

1. 계산

LightningModule 의 이 구성 요소는 모델 아키텍처와 순방향 전달을 포함합니다. 이 코드 조각은 일반 PyTorch 코드와 비슷해 보이기도 합니다.
__init__ 을 통해 모델에 필요한 모든 필수 하이퍼파라미터를 전달할 수 있습니다. 보통 여러 버전의 모델은 서로 다른 하이퍼파라미터로 학습하죠. save_hyperparameters 를 호출하여 lightning__init__ 에 있는 모든 값을 체크포인트에 저장하도록 요청할 수 있습니다. 이는 유용한 기능입니다.
두 개의 메서드, _get_conv_output_forward_features 가 보이실 겁니다. 이 메서드는 컨볼루션 블록 출력의 텐서 크기를 자동으로 계산하는 데 사용됩니다. 자세한 내용은 여기를 참조하세요.
forward 메서드는 일반적인 PyTorch 코드와 비슷해 보일 수 있는데, Lightning에서 forward 는 추론 작업을 정의하는 데에만 사용됩니다. training_step 은 학습 루프를 정의합니다.
class LitModel(pl.LightningModule):
def __init__(self, input_shape, num_classes, learning_rate=2e-4):
super().__init__()
# log hyperparameters
self.save_hyperparameters()
self.learning_rate = learning_rate
self.conv1 = nn.Conv2d(3, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 32, 3, 1)
self.conv3 = nn.Conv2d(32, 64, 3, 1)
self.conv4 = nn.Conv2d(64, 64, 3, 1)

self.pool1 = torch.nn.MaxPool2d(2)
self.pool2 = torch.nn.MaxPool2d(2)
n_sizes = self._get_conv_output(input_shape)

self.fc1 = nn.Linear(n_sizes, 512)
self.fc2 = nn.Linear(512, 128)
self.fc3 = nn.Linear(128, num_classes)

self.accuracy = torchmetrics.Accuracy()

# returns the size of the output tensor going into Linear layer from the conv block.
def _get_conv_output(self, shape):
batch_size = 1
input = torch.autograd.Variable(torch.rand(batch_size, *shape))

output_feat = self._forward_features(input)
n_size = output_feat.data.view(batch_size, -1).size(1)
return n_size
# returns the feature tensor from the conv block
def _forward_features(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(F.relu(self.conv2(x)))
x = F.relu(self.conv3(x))
x = self.pool2(F.relu(self.conv4(x)))
return x
# will be used during inference
def forward(self, x):
x = self._forward_features(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.log_softmax(self.fc3(x), dim=1)
return x


2. 학습 루프

Lightning 은 대부분의 학습, 에포크 및 배치(batch) 반복을 자동화하므로, 사용자는 학습 단계 논리를 유지하기만 하면 됩니다. training_step 메서드에는 batchbatch_idx 인수가 필요한데 이는 Trainer 가 자동으로 전달합니다. 학습 루프에 대한 자세한 내용은 여기를 확인하세요
에포크 메트릭을 계산하려면 on_epoch=True.log 메서드에 전달합니다. 단계별 메트릭은 자동으로 기록되는데, 이 기능을 끄고자 하면 on_step=False 를 전달합니다.
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
# training metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
return loss

3. 검증 루프

학습 루프와 유사하게, LightningModulevalidation_step 메서드를 덮어쓰면 검증 루프를 구현할 수 있습니다. 검증 루프에 대한 자세한 내용은 여기를 참조하세요.
메트릭은 에포크별로 자동 기록됩니다.
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)

# validation metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss

4. 테스트 루프

테스트 루프는 검증 루프와 유사합니다. 유일한 차이는 test 루프는 trainer.test() 가 사용될 때만 호출된다는 점입니다. 테스트 루프에 대한 자세한 내용은 여기를 참조하세요.
메트릭은 에포크별로 자동 기록됩니다.
def test_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
# validation metrics
preds = torch.argmax(logits, dim=1)
acc = self.accuracy(preds, y)
self.log('test_loss', loss, prog_bar=True)
self.log('test_acc', acc, prog_bar=True)
return loss

5. 옵티마이저

configure_optimizer 메서드를 사용하여 옵티마이저 및 학습 속도 스케줄러를 정의할 수 있습니다. GAN의 경우와 같이 여러 옵티마이저를 정의할 수도 있죠.
이 메서드에 대한 자세한 내용은 여기를 참조하세요.
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
참고: Lightning을 사용하여 PyTorch 코드를 리팩터링할 때에는 Lightning Module 에서 .cuda().to() 를 제거합니다.

🚋 학습 및 평가

이제 DataModule 로 데이터 파이프라인을 구성하고 LightningModule 로 모델 아키텍처+학습 루프를 구성했으니, PyTorch Lightning Trainer는 나머지 모든 작업을 자동화합니다.
Trainer는 다음을 자동화합니다:
  • 에포크 및 배치(batch) 반복
  • optimizer.step(), backward, zero_grad() 호출
  • .eval(), enabling/disabling grads 호출
  • 가중치 저장 및 로딩
  • 가중치 및 편향 로깅
  • 다중 GPU 학습 지원
  • TPU 지원
  • 16 비트 학습 지원
Trainer에 대한 자세한 내용은 여기를 참조하세요. 이제 이를 사용하여 최종적으로 모델이 학습하도록 해보죠.
먼저 데이터 파이프라인을 초기화합니다. TrainerPyTorch DataLoader 만 있으면 학습/검증/테스트 분할을 수행하죠. 생성된 dm 객체는 Trainer에 직접 전달할 수 있습니다. 그러나 ImagePredictionLogger 에 몇 가지 샘플이 필요하므로, prepare_datasetup 메서드를 수동으로 호출하겠습니다.
# Init our data pipeline
dm = CIFAR10DataModule(batch_size=32)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
모델의 학습이 전에 없이 간편해졌습니다. 모델과 선호하는 로거를 초기화하기만 하면 되죠. checkpoint_callback 은 별도로 전달했다는 사실에 유의하시기 바랍니다.
# Init our model
model = LitModel(dm.size(), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize a trainer
trainer = pl.Trainer(max_epochs=50,
progress_bar_refresh_rate=20,
gpus=1,
logger=wandb_logger,
callbacks=[early_stop_callback,
ImagePredictionLogger(val_samples)],
checkpoint_callback=checkpoint_callback)

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test()

# Close wandb run
wandb.finish()

아래 미디어 패널은 W&B에 기록된 메트릭을 보여줍니다.


아래 미디어 차트는 ImagePredictionLogger 콜백을 사용자 지정한 결과입니다. 각 이미지의 예측값 및 실제값을 확인하실 수 있죠.
⚙️ 아이콘을 클릭하고 슬라이더를 이동하여 각 에포크에서의 모델 예측을 확인해 보세요.



📉 정밀도-재현율 곡선

이미지 분류 모델은 철저한 테스트가 필요하죠. 그래서 정밀도-재현율 곡선을 사용하는 것이 일반적입니다.
Weights & Biases는 말 그대로 Vega에서 지원하는 모든 것을 그릴 수 있는 사용자 지정 vega 플롯을 지원합니다. 평균 정밀도-재현율 곡선을 사용하여 모델의 성능을 살펴보시죠.
Weights & Biases가 지원하는 시각화 사용자 지정에 대한 자세한 내용은 이 보고서를 확인하세요. 평균 정밀도-재현율 곡선을 기록하는 방법은 이 보고서를 확인하세요.
테스트 정확도는 70%이하이지만, 이 분류기를 개선할 수 있는 방법은 많습니다.




맺음말

제 전문 분야는 TensorFlow/Keras 생태계이며, PyTorch는 우아한 프레임워크이지만 다소 버겁다고 생각합니다. 물론 제 개인적인 경험에 불과합니다. PyTorch Lightning을 살펴보면서, 저는 제가 PyTorch를 멀리했던 거의 모든 이유가 해결되었다는 사실을 깨달았습니다. 제가 발견한 기쁜 소식을 알려드리죠:
  • 이전: 기존의 PyTorch 모델 정의는 두서없었습니다. 일부 model.py 스크립트에 모델이 있기도 하고 train.py 파일에 학습 루프가 있기도 했죠. 파이프라인을 이해하려면 이리저리 살펴보아야 했습니다.
  • 현재: LightningModule 은 모델이 training_step, validation_step 등과 함께 정의되는 시스템 역할을 하므로 이제 모듈 형식으로 공유 가능합니다.
  • 이전: TensorFlow/Keras의 가장 좋은 점은 입력 데이터 파이프라인입니다. 데이터세트 카탈로그가 풍부하고 계속 확장하고 있죠. PyTorch의 데이터 파이프라인은 큰 문제였습니다. 일반적인 PyTorch 코드에서 데이터 다운로드/정리/준비는 보통 여러 파일에 분산되어 있습니다.
  • 현재: DataModule 은 데이터 파이프라인을 공유 및 재사용 가능한 하나의 클래스로 구성합니다. 그저 일치하는 변환 및 필요한 데이터 처리/다운로드 단계가 있는 train_dataloader, val_dataloader(s), test_dataloader(s)의 모음이죠.
  • 이전: Keras를 사용하면 모델 학습은 model.fit 을, 추론 실행은 model.predict 를 호출하여 수행할 수 있죠. model.evaluate 는 테스트 데이터에 대한 오래되고 단순한 평가를 제공했고요. PyTorch의 경우는 그렇지 않죠. 일반적으로 별도의 train.pytest.py 파일이 있습니다.
  • 현재: LightningModule 이 있어 Trainer 가 모든 것을 자동화합니다. 모델의 학습과 평가는 trainer.fittrainer.test 를 호출하기만 하면 됩니다.
  • 이전: TensorFlowTPU, PyTorch를 좋아합니다... 네!
  • 현재: PyTorch Lightning을 사용하면, 동일한 모델이 여러 GPU와 심지어 TPU에서도 쉽게 학습할 수 있습니다. 와!
  • 이전: 저는 콜백의 열렬한 팬으로 콜백을 사용자 지정하는 것을 선호합니다. 기존 PyTorch에서는 Early Stopping처럼 간단한 콜백조차도 애를 먹이곤 했죠.
  • 현재: PyTorch Lightning에서는 Early StoppingModel Checkpointing을 사용하는 일이 식은 죽 먹기입니다. 사용자 지정 콜백도 작성할 수 있죠.
저는 앞으로도 이 기쁜 소식을 계속 외치고 다니게 될 것 같습니다. PyTorch Lightning이 제공하는 모든 지원 내용은 여기에서 확인하실 수 있습니다 .

🎨 결론 및 리소스

이 보고서가 여러분께 도움이 되길 바랍니다. 코드를 이리저리 바꿔보시고, 여러분이 고른 데이터세트를 이용해 이미지 분류기 학습을 수행하셨으면 좋겠습니다.
PyTorch Lightning에 대한 자세한 정보를 제공하는 자료를 알려드리죠:
아래에 댓글로 여러분의 생각을 알려주세요.

Iterate on AI agents and models faster. Try Weights & Biases today.