전공관련/Deep Learning

[Pytorch] 학습 한 모델을 저장하고 불러오자

매직블럭 2019. 3. 12. 16:05
model = TestModel()
''' training '''

# save
savePath = "./output/test_model.pth"
torch.save(model.state_dict(), savePath)

# load
new_model = TestModel()
new_model.load_state_dict(torch.load("./output/test_model.pth"))


torch.save 나 torch.load 를 이용할 경우 환경에 따라서 정상적으로 읽고 쓰기가 불가능 할 수도 있음

추가1. 위 함수의 경우 파라미터를 serialization 하여 저장하는데 pytorch 버전에 따라 구조가 바뀔수 있음.

추가2. torch.save의 경우 기본으로 pickle 을 이용하여 serialization을 수행. 하지만 이 경우 일부 제대로 selialization 못하는 경우가 발생. 이럴 경우 dill 패키지 설치한 후 이용하자. (pickle_module=dill 을 추가 파라미터로 줄것)


그래서 state_dict 값만 따로 저장하고 읽고 하는 편이 추천된다고 한다..

단, 이경우 architecture에 대한 define은 있어야겠지..


추가3. model load 후 이어서 학습하기 위해서는 model.eval() 한번 호출이 필요한듯? batch norm이나 drop out 같은 놈이 기본값으로 설정되어 있기 때문에 이전 학습에서 사용된 파라미터를 가져 올 필요가 있는것 같다. (재확인 필요)