전공관련/Deep Learning

[Pytorch] model load시 key가 있는 레이어만 불러오자

매직블럭 2021. 2. 17. 17:00


이전에 pretrained model을 불러올 때 key set이 일치하지 않아 발생하는 문제를 dict를 수정하여 불러온 적이 있다.

 

[Pytorch] pretrained-model 의 일부만을 불러와 보자

pretrained model 의 구조 : ['A', 'B', 'C', 'D'] new model의 구조 : ['A', 'B', 'C', 'E'] 라고 가정하고.. 단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면 Missing key(s) in state_dic..

jangjy.tistory.com

 

이러한 일련의 과정을 자동으로 처리해주는 파라미터가 load_state_dict 함수에 있어서 기록.

 


 

pretrained model 의 구조 : ['A', 'B', 'C', 'D']

new model의 구조 : ['A', 'B', 'C', 'E']  라고 가정하고..

 

단순히 load_state_dict를 이용하여 pretrained model의 값을 읽어오면

Missing key(s) in state_dict: "E.weighjt", "E.bias". 와 같은 에러를 발생시킨다.


이때 load_state_dict 함수에 strict=False 파라미터를 추가해주면 key set 이 일치하는 레이어의 값만 읽어온다.

즉, 위의 가정과 같은 상황에서는 A, B, C 는 pretrained model의 weight가 load 되고

key가 없는 E 에는 random으로 초기값이 할당된다. 

 

굳이 없는 key를 찾고 새로운 dict를 만들 필요없이 pytorch에서 제공하는 방법을 쓰면 간단하다..