Skip to main content

PyTorch에서 모델을 저장하고 불러오는 방법

이 글은 PyTorch에서 모델을 저장하고 불러오는 방법을 다루는 머신 러닝 튜토리얼이며, 버전 관리를 위해 Weights & Biases를 사용합니다. 이 글은 AI 번역본입니다. 오역이 의심되는 부분이 있다면 댓글로 알려 주세요.
Created on September 15|Last edited on September 15
모델 학습은 비용이 많이 들고 실제 활용 사례에서는 시간이 오래 걸립니다. 대부분의 ML 워크플로에서는 학습된 모델을 저장하는 것이 마지막 단계이며, 이후 추론에 재사용합니다.
학습된 모델을 저장하고 불러오는 방법은 여러 가지가 있습니다 PyTorch 이 튜토리얼에서는 PyTorch에서 학습된 모델을 저장하고 불러오는 여러 방법을 살펴봅니다. 자세한 안내는 다음을 확인하세요 공식 PyTorch 문서. 

목차




state_dict로 PyTorch 모델 저장 및 불러오기 state_dict (권장)

A state_dict 은 각 레이어를 해당 파라미터 텐서에 매핑하는 단순한 파이썬 딕셔너리입니다. 모델의 학습 가능한 파라미터(컨볼루션 레이어, 선형 레이어 등)와 등록된 버퍼(BatchNorm의 running_mean) 항목이 있습니다 state_dict.
사용하기 state_dict 여러 가지 이유로 다음과 같은 점에서 권장되는 접근 방식입니다.
  • 먼저, state_dict 모델의 핵심 파라미터(가중치와 바이어스 등)만 저장하므로 파일 크기가 작고 다루기 쉽습니다. 둘째, 유연성이 높습니다. `state_dict`는 파이썬 딕셔너리이기 때문에 모델 파라미터뿐 아니라 옵티마이저 상태와 기타 메타데이터도 함께 저장할 수 있어, 학습 재개나 미세 조정이 수월해집니다. 또한 `state_dict`는 파이썬의 기본 `pickle` 모듈과 호환되어 저장과 로딩이 간단하므로, 서로 다른 환경에서 모델을 다시 불러오는 작업이 한층 쉬워집니다. 이 가이드는 PyTorch에서 `state_dict`를 사용해 모델을 저장하고 로드하는 방법을 단계별로 설명하고, 효과적인 모델 관리를 위한 모범 사례를 제시합니다.
  • 둘째, `state_dict`는 파이썬 딕셔너리이므로 모델 파라미터뿐 아니라 옵티마이저 상태와 기타 메타데이터도 함께 저장할 수 있어 학습을 더 쉽게 재개하거나 모델을 미세 조정하다.
  • 또한 `state_dict`는 파이썬 기본 `pickle` 모듈과 호환되므로 저장과 로딩이 간단해, 서로 다른 환경에서 모델을 다시 불러오는 과정을 한층 단순화합니다. 이 가이드는 PyTorch에서 `state_dict`를 사용해 모델을 저장하고 로드하는 절차를 단계별로 안내하고, 효과적인 모델 관리를 위한 모범 사례를 설명합니다.
자세히 알아보기 state_dict 여기.
다음은 PyTorch에서 모델을 저장하고 로드하는 코드입니다 using state_dict:

다음 방법으로 PyTorch에서 모델을 저장하세요 state_dict

torch.save(model.state_dict(), 'save/to/path/model.pth')

다음 방법으로 PyTorch에서 모델을 로드하세요 state_dict

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))

장점

  • PyTorch는 Python의 pickle 모듈로, Python 딕셔너리를 손쉽게 피클로 직렬화하고 역직렬화하며 업데이트하고 복원할 수 있습니다. state_dict를 사용하면 모델 파라미터를 유연하게 관리할 수 있습니다.
  • 모델 파라미터와 함께 옵티마이저 상태, 하이퍼파라미터 등의 추가 요소를 키-값 쌍으로 저장할 수 있습니다. state_dict모델을 다시 로드할 때 쉽게 접근할 수 있습니다.

단점:

  • 로드하려면 state_dict정확히 같은 모델 정의가 필요합니다. 그렇지 않으면 저장된 파라미터를 올바르게 로드할 수 없습니다.

주의할 점:

  • 반드시 호출하세요 model.eval() 추론을 수행할 때는 드롭아웃과 배치 정규화 같은 레이어를 평가 모드로 설정하세요.
  • 다음을 사용해 모델을 저장하세요 .pt 또는 .pth 일관성과 호환성을 위한 확장자.

전체 PyTorch 모델 저장 및 로드

PyTorch에서 모델을 저장하고 로드할 때, 단지 모델의 일부만이 아니라 전체 모델을 저장하는 옵션도 있습니다. state_dict이 방법은 모델의 전체 아키텍처와 파라미터를 한 번에 담아, 최소한의 코드로 빠르고 쉽게 복원할 수 있게 해줍니다.
편리하긴 하지만, 전체 모델을 저장하는 방식은 일반적으로 권장되지 않습니다. 이는 Python의 pickle 모듈과 함께 저장된 모델 파일은 저장 당시 사용된 특정 클래스 정의와 디렉터리 구조에 강하게 결합됩니다. 이 의존성 때문에 코드가 리팩터링되거나 다른 프로젝트에서 모델을 사용하려고 할 때 호환성 문제가 발생할 수 있습니다.
이 방식으로 PyTorch 모델을 저장하고 로드하는 방법을 살펴보겠습니다.

전체 PyTorch 모델 저장하기

torch.save(model, 'save/to/path/model.pt')

전체 PyTorch 모델 로드하기

model = torch.load('load/from/path/model.pt')

장점:

  • 모델 전체를 저장하는 방식은 필요한 코드가 가장 적어, 저장과 로드를 빠르게 처리할 수 있는 옵션입니다.
  • 전체 모델을 저장하고 로드하는 API는 간단하고 사용하기 쉽습니다. 저장 및 로드 API가 더 직관적입니다.

단점:

  • 파이썬의 pickle 이 모듈이 내부적으로 사용되면, 저장된 모델은 특정 클래스 정의와 디렉터리 구조에 강하게 결합됩니다. 즉, 리팩터링이나 파일 경로 변경이 발생하면 로딩 문제가 생길 수 있습니다.
  • 전체 모델 저장본을 다른 프로젝트에서 사용하는 것은 어렵습니다. 저장 당시와 동일한 디렉터리 구조와 클래스 경로를 유지해야 하기 때문입니다.

주의할 점:

  • 반드시 호출하세요 model.eval() 드롭아웃과 배치 정규화 같은 레이어가 올바르게 동작하도록, 추론을 수행하기 전에.
  • 다음을 사용해 모델을 저장하세요 .pt 또는 .pth 일관성과 가독성을 위해 모델을 저장할 때 확장자를 지정하세요.

체크포인트에서 PyTorch 모델 저장 및 불러오기

대부분의 머신러닝 파이프라인에서는 일정 주기나 특정 조건에 따라 모델 체크포인트를 저장하는 것이 필수적입니다. 이렇게 하면 중단 상황이 발생하더라도 최신 또는 최적의 체크포인트에서 학습을 재개하여 연속성을 보장할 수 있습니다. 체크포인트는 파인튜닝이나 다양한 단계에서 모델 성능을 평가하는 데도 유용합니다.
체크포인트를 저장할 때, 모델의 state_dict 충분하지 않습니다. 옵티마이저의 것도 함께 저장해야 합니다. state_dict마지막 에포크 번호, 현재 손실값, 그리고 학습을 끊김 없이 재개하는 데 필요한 기타 관련 정보.

PyTorch 모델 체크포인트 저장

torch.save({'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS},
'save/to/path/model.pth')

PyTorch 모델 체크포인트 불러오기

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

주의할 점:

  • 학습을 재개할 때는 반드시 호출하세요 model.train() 모델이 학습 모드로 설정되었는지 확인하기 위해서입니다.
  • 체크포인트를 불러온 뒤 추론을 수행하려면 다음을 호출하세요 model.eval() 모델을 평가 모드로 설정하기 위해서입니다.


모델 버전 관리를 위해 W&B 아티팩트를 사용하세요

W&B 아티팩트 버전 관리를 통해 PyTorch 모델을 저장하고 불러오는 강력한 방법을 제공합니다. 아티팩트는 버전이 있는 디렉터리로 볼 수 있으며, 모델 버전과 기타 자산을 체계적으로 저장하고 추적할 수 있게 해줍니다.
아티팩트 더 알아보기 여기.

W&B 아티팩트로 PyTorch 모델 저장하기

# Import
import wandb
# Save your model.
torch.save(model.state_dict(), 'save/to/path/model.pth')
# Save as artifact for version control.
run = wandb.init(project='your-project-name')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('save/to/path/model.pth')
run.log_artifact(artifact)
run.finish()

W&B 아티팩트로 PyTorch 모델 불러오기

저장된 모델이 다운로드됩니다. 그런 다음 다음을 사용해 모델을 불러올 수 있습니다. torch.load.
import wandb
run = wandb.init()

artifact = run.use_artifact('entity/your-project-name/model:v0', type='model')
artifact_dir = artifact.download()

run.finish()

그림 1라이브 아티팩트 대시보드를 확인하세요 여기

W&B Artifacts를 사용하면 PyTorch 모델을 손쉽게 저장하고 불러올 수 있어 버전 관리를 간편하게 하고, 프로젝트에서 모델 반복 관리를 쉽게 할 수 있습니다.

Weights & Biases를 사용해 보세요

Weights & Biases는 머신러닝 실험을 체계적으로 추적할 수 있도록 도와줍니다. 실행에서 하이퍼파라미터와 출력 지표를 기록하고, 결과를 시각화·비교한 뒤, 동료들과 신속하게 공유해 보세요.
시작하기 5분 안에 완료하거나 Replit에서 간단한 실험 2가지를 실행해 보고, 아래 지침을 따라 W&B가 작업 정리에 어떻게 도움이 되는지 확인해 보세요:
지침
  1. 아래의 초록색 “Run” 버튼을 클릭하세요 처음으로 Run을 클릭하면 Replit이 머신을 할당하는 데 약 30~45초가 걸립니다.
  2. 터미널 창에 표시되는 안내를 따라 진행하세요 (아래 오른쪽 창)
  3. 터미널 창의 크기를 조절할 수 있습니다 (오른쪽 아래)에서 더 크게 보기



추천 읽을거리



이 글은 AI 번역본입니다. 오역이 의심되는 부분이 있으면 댓글로 알려 주세요. 원문은 아래 링크에서 확인할 수 있습니다: 원문 보고서 보기
Hongbo Miao
Hongbo Miao •  
I saw ONNX file format in the video "Integrate Weights & Biases with PyTorch" https://youtu.be/G7GH0SeNBMA?t=925 It would be great to add it. Thanks!
Reply
Arnav Das
Arnav Das •  
If for some reason Optimizer state wasn't saved, then the saved model state will be useless right ? or is there any way to that we can again retrain with only model state, loss and epoch information ?
3 replies