전공관련/Deep Learning
[Pytorch] pretrained-model 의 일부만을 불러와 보자
매직블럭
2019. 3. 12. 16:12
pretrained model 의 구조 : ['A', 'B', 'C', 'D']
new model의 구조 : ['A', 'B', 'C', 'E'] 라고 가정하고..
단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면
Missing key(s) in state_dict: "E.weighjt", "E.bias". 와 같은 에러를 발생시킨다.
이럴 경우 pretrained model로 부터 new model에 있는 값만을 골라내는 작업 후 load 하면 정상적으로 동작
pretrained_dict = pretrained_model.state_dict()
new_model_dict = new_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
new_model_dict.update(pretrained_dict)
new_model.load_state_dict(new_model_dict)
개념 이해를 위한 참고사항
pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']
↓
pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']
출처 : https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/16