pretrained model 의 구조 : ['A', 'B', 'C', 'D']
new model의 구조 : ['A', 'B', 'C', 'E'] 라고 가정하고..
단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면
Missing key(s) in state_dict: "E.weighjt", "E.bias". 와 같은 에러를 발생시킨다.
이럴 경우 pretrained model로 부터 new model에 있는 값만을 골라내는 작업 후 load 하면 정상적으로 동작
pretrained_dict = pretrained_model.state_dict()
new_model_dict = new_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
new_model_dict.update(pretrained_dict)
new_model.load_state_dict(new_model_dict)
개념 이해를 위한 참고사항
pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']
↓
pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']
출처 : https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/16
'전공관련 > Deep Learning' 카테고리의 다른 글
[Pytorch] network 결과물 사용시 주의사항! (0) | 2019.03.13 |
---|---|
[Pytorch] torch.load 에서 학습시와 환경이 달라서 못읽을 경우 (3) | 2019.03.13 |
[Pytorch] 학습 한 모델을 저장하고 불러오자 (1) | 2019.03.12 |
Deconvolution 파라미터에 따른 출력 크기 계산하기 (0) | 2018.10.19 |
[TensorFlow] meta file로부터 graph를 읽어오고 사용하는 방법 (2) | 2018.03.22 |