Skip to main content

PyTorch에서 모델을 저장하고 로드하는 방법

이 튜토리얼에서는 PyTorch에서 트레이닝 된 모델을 올바르게 저장하고 로드하는 방법을 배울 수 있습니다.
Created on February 2|Last edited on February 2
이는 여기에서 볼 수 있는 영어 기사를 번역한 것이다.

도입

모델 트레이닝은 비용이 많이 들며 실제 활용 사례의 경우 많은 시간이 소요됩니다. 트레이닝 된 모델을 저장하는 것은 일반적으로 대부분의 모델 트레이닝 워크플로우의 마지막 단계에 해당합니다. 이후에는 이러한 사전 트레이닝 모델을 재사용합니다. PyTorch에서 트레이닝 된 모델을 저장하고 로드하는 방법으로 여러가지가 있습니다.
이 짧은 리포트에서는 PyTorch 생태계에서 트레이닝 된 모델을 저장하고 로드하는 방법을 살펴볼 것입니다. 자세한 지침은 PyTorch 공식 문서를 참고하세요

섹션




state_dict (권장사항) 사용

state_dict는 단순하게 각 레이어를 해당 파라미터 텐서(Tensor)에 매핑해주는 Python 사전입니다. 모델 (컨볼루션 레이어, 선형 레이어, 등) 및 등록된 버퍼(batchnorm’s running_mean)의 학습 가능한 매개 변수에는 state_dict 항목이 있습니다. state_dict에 대한 자세한 내용은 여기를 참고하세요.

저장

torch.save(model.state_dict(), 'save/to/path/model.pth')

로드

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

장점:

  • PyTorch는 내부적으로 Python의 피클(Pickle) 모듈에 의존합니다. Python 사전은 쉽게 피클 및 언피클 (Pickled, Unpickled), 업데이트 및 복원 가능합니다. 따라서 state_dict를 사용하여 모델을 저장하면 유연성이 향상됩니다.
  • 옵티마이저(Optimizer)의 상태, 하이퍼 파라미터 등을 모델의 state_dict와 함께 Key-Value 패어로 저장할 수도 있습니다. 복원된 후에는 통상적인 Python 사전처럼 액세스할 수 있습니다. 어떻게 가능한지에 대해서는 뒷부분에서 좀 더 자세히 알아보도록 하겠습니다.

단점:

  • state_dict를 로드하려면 모델의 정의가 필요합니다.

Gotcha:

  • 추론(Inference)을 원하시는 경우에는 반드시 model.eval()를 콜(Call)해야 합니다.
  • .pt 또는 .pth확장자를 사용하여 모델을 저장합니다.

전체 모델 저장 및 로드

또한 state_dict 뿐만 아니라 전체 모델을 PyTorch에 저장하실 수도 있습니다. 그러나 이러한 방법은 모형을 저장하는 데 권장되는 방법이 아닙니다.

저장

torch.save(model, 'save/to/path/model.pt')

로드

model = torch.load('load/from/path/model.pt')

장점:

  • 최소한의 코드로 전체 모델을 저장할 수 있는 가장 쉬운 방법입니다.
  • API 저장과 로딩이 보다 직관적으로 가능합니다.

단점:

  • Python의 피클 모듈은 내부적으로 사용되기 때문에 직렬화된(Serialized) 데이터는 특정 클래스에 바인딩되며 모델을 저장할 때 정확한 디렉토리 구조가 사용됩니다. 피클은 단순히 특정 클래스를 포함하는 파일의 경로를 저장합니다. 로딩 시간 동안 사용됩니다.
  • 리팩토링 후 저장된 모델이 동일한 경로에 링크되지 않으면 코드가 끊어질 수 있습니다. 이러한 모델을 다른 프로젝트에서 사용하는 것 또한 경로 구조가 유지되어야 함에 따라 어려울 수 있습니다.

Gotchas:

  • 추론(Inference)을 원하시는 경우에는 반드시 model.eval()를 콜(Call)해야 합니다.
  • .pt 또는 .pth확장자를 사용하여 모델을 저장합니다.

체크포인트에서 모델 저장 및 로드

추론 (Inference)을 위해 일반적인 체크포인트 모형을 저장하고 로딩 하거나 트레이닝을 재개하면 마지막으로 중단한 부분을 픽업하는 데 도움이 될 수 있습니다. 체크포인트의 컨텍스트에서 저장 모델의 state_dict는 충분하지 않습니다. 또한 Optimizer의 state_dict와 함께 마지막 epoch 넘버, loss 등을 저장해야 합니다. 체크포인트를 사용하여 트레이닝을 재개하는 데 필요한 모든 사항을 저장하는 것이 좋습니다.

저장

torch.save({
'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, 'save/to/path/model.pth')

로드

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Gotcha:

  • 이를 트레이닝에 사용하려면 model.train()를 콜(Call)해야 합니다.
  • 이를 추론(Inference)에 사용하려면 model.eval()를 콜(Call)해야 합니다.

모델 버전 제어에 W&B 아티팩트 사용


모델을 Weights and Biases Artifacts팩트로 저장하고 동일한 내용을 사용하려면 Colab Notebook을 사용해 보세요.
W&B 아티팩트는 머신 러닝 파이프라인에서 데이터 세트, 모델 및 평가 결과를 저장하고 추적하는 데 사용할 수 있습니다. 아티팩트를 데이터의 버전 폴더로 간주합니다. 아티팩트에 대한 내용은 여기를 참고하세요.

아티팩트로 저장

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.join()

아티팩트로 로드

이를 통해 저장된 모델이 다운로드될 것입니다. 그런 다음 torch.load를 사용하여 모델을 로드할 수 있습니다.
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.join()

그림 1: 여기에서 실시간 아티팩트 대시보드를 확인하세요.


Weights & Biases

저희 툴을 사용하시면 실행에서 하이퍼 파라미터 (Hyperparameter)와 출력 메트릭을 로그(기록)한 후에 결과값을 시각화 및 비교하고 결과값을 동료 분들과 신속하게 공유하실 수 있습니다.
5분 내로 시작해보세요.

PyTorch에 관심있는 분들에게 추천하는 내용


Iterate on AI agents and models faster. Try Weights & Biases today.