전공관련/Deep Learning

[Pytorch] pretrained-model 의 일부만을 불러와 보자

매직블럭 2019. 3. 12. 16:12

pretrained model 의 구조 : ['A', 'B', 'C', 'D']

new model의 구조 : ['A', 'B', 'C', 'E']  라고 가정하고..


단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면

Missing key(s) in state_dict: "E.weighjt", "E.bias". 와 같은 에러를 발생시킨다.


이럴 경우 pretrained model로 부터 new model에 있는 값만을 골라내는 작업 후 load 하면 정상적으로 동작

pretrained_dict = pretrained_model.state_dict()
new_model_dict = new_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
new_model_dict.update(pretrained_dict)
new_model.load_state_dict(new_model_dict)


개념 이해를 위한 참고사항

    pretrained_dict: ['A', 'B', 'C', 'D']

    model_dict: ['A', 'B', 'C', 'E']

                  

    pretrained_dict: ['A', 'B', 'C']

    model_dict: ['A', 'B', 'C', 'E']

출처 : https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/16