model = TestModel()
''' training '''
# save
savePath = "./output/test_model.pth"
torch.save(model.state_dict(), savePath)
# load
new_model = TestModel()
new_model.load_state_dict(torch.load("./output/test_model.pth"))
torch.save 나 torch.load 를 이용할 경우 환경에 따라서 정상적으로 읽고 쓰기가 불가능 할 수도 있음
추가1. 위 함수의 경우 파라미터를 serialization 하여 저장하는데 pytorch 버전에 따라 구조가 바뀔수 있음.
추가2. torch.save의 경우 기본으로 pickle 을 이용하여 serialization을 수행. 하지만 이 경우 일부 제대로 selialization 못하는 경우가 발생. 이럴 경우 dill 패키지 설치한 후 이용하자. (pickle_module=dill 을 추가 파라미터로 줄것)
그래서 state_dict 값만 따로 저장하고 읽고 하는 편이 추천된다고 한다..
단, 이경우 architecture에 대한 define은 있어야겠지..
추가3. model load 후 이어서 학습하기 위해서는 model.eval() 한번 호출이 필요한듯? batch norm이나 drop out 같은 놈이 기본값으로 설정되어 있기 때문에 이전 학습에서 사용된 파라미터를 가져 올 필요가 있는것 같다. (재확인 필요)
'전공관련 > Deep Learning' 카테고리의 다른 글
[Pytorch] torch.load 에서 학습시와 환경이 달라서 못읽을 경우 (3) | 2019.03.13 |
---|---|
[Pytorch] pretrained-model 의 일부만을 불러와 보자 (0) | 2019.03.12 |
Deconvolution 파라미터에 따른 출력 크기 계산하기 (0) | 2018.10.19 |
[TensorFlow] meta file로부터 graph를 읽어오고 사용하는 방법 (2) | 2018.03.22 |
[TensorFlow] Saver를 이용하여 기존 model의 weight를 읽어오자. (0) | 2018.03.05 |