이전에 pretrained model을 불러올 때 key set이 일치하지 않아 발생하는 문제를 dict를 수정하여 불러온 적이 있다.
이러한 일련의 과정을 자동으로 처리해주는 파라미터가 load_state_dict 함수에 있어서 기록.
pretrained model 의 구조 : ['A', 'B', 'C', 'D']
new model의 구조 : ['A', 'B', 'C', 'E'] 라고 가정하고..
단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면
Missing key(s) in state_dict: "E.weighjt", "E.bias". 와 같은 에러를 발생시킨다.
이때 load_state_dict 함수에 strict=False 파라미터를 추가해주면 key set 이 일치하는 레이어의 값만 읽어온다.
즉, 위의 가정과 같은 상황에서는 A, B, C 는 pretrained model의 weight가 load 되고
key가 없는 E 에는 random으로 초기값이 할당된다.
굳이 없는 key를 찾고 새로운 dict를 만들 필요없이 pytorch에서 제공하는 방법을 쓰면 간단하다..
'전공관련 > Deep Learning' 카테고리의 다른 글
[용어] Ablation Study (0) | 2021.02.25 |
---|---|
[Caffe] caffe 환경 설정없이 caffemodel 값을 확인하자 (0) | 2021.02.24 |
[ONNX] Onnx convert 모델을 검증하자 (2) | 2020.04.22 |
[Onnx] Onnxruntime - GPU를 사용하자 (8) | 2020.03.09 |
[Onnx] onnx 모듈을 사용하기 위한 class를 만들어보자 (0) | 2020.02.26 |