전공관련/Deep Learning
[Pytorch] model load시 key가 있는 레이어만 불러오자
매직블럭
2021. 2. 17. 17:00
이전에 pretrained model을 불러올 때 key set이 일치하지 않아 발생하는 문제를 dict를 수정하여 불러온 적이 있다.
이러한 일련의 과정을 자동으로 처리해주는 파라미터가 load_state_dict 함수에 있어서 기록.
pretrained model 의 구조 : ['A', 'B', 'C', 'D']
new model의 구조 : ['A', 'B', 'C', 'E'] 라고 가정하고..
단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면
Missing key(s) in state_dict: "E.weighjt", "E.bias". 와 같은 에러를 발생시킨다.
이때 load_state_dict 함수에 strict=False 파라미터를 추가해주면 key set 이 일치하는 레이어의 값만 읽어온다.
즉, 위의 가정과 같은 상황에서는 A, B, C 는 pretrained model의 weight가 load 되고
key가 없는 E 에는 random으로 초기값이 할당된다.
굳이 없는 key를 찾고 새로운 dict를 만들 필요없이 pytorch에서 제공하는 방법을 쓰면 간단하다..