전공관련/Deep Learning
[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자2
매직블럭
2019. 12. 17. 17:35
[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자
지난난 글에서 사용한한 방법은 저장 단계에서 적용해야 하는 문제가 있어서
이번에는 다시 state_dict를 읽어서 key에서 module을 제거하고 다시 넣어주는 방법을 사용.
우선 state_dict의 key에서 module 을 제거한 새로운 dict를 반환하는 함수 선언
def get_state_dict(origin_dict):
old_keys = origin_dict.keys()
new_dict = {}
for ii in old_keys:
temp_key = str(ii)
if temp_key[0:7] == "module.":
new_key = temp_key[7:]
else:
new_key = temp_key
new_dict[new_key] = origin_dict[temp_key]
return new_dict
이후 네트워크 가중치를 읽을때 위 함수를 통해 다시 읽어주면 문제없이 사용 가능하다.
# 위쪽은 생략
net = CNN_Network(num_classes=1000)
net.to(device)
net.eval()
checkpoint = torch.load('model_path/model_file.pth.tar', map_location=device)
checkpoint_dict = get_state_dict(checkpoint["state_dict"])
net.load_state_dict(checkpoint_dict)
(21.01.06. 수정)
위의 get_state_dict를 람다식을 이용하여 깔끔하게 구현된 코드를 확인하여 기록.
(출처 : github.com/biubug6/Pytorch_Retinaface / RetinaFace)
def remove_prefix(state_dict, prefix):
''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
print('remove prefix \'{}\''.format(prefix))
f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
return {f(key): value for key, value in state_dict.items()}
# 사용할때는
pretrained_dict = remove_prefix(pretrained_dict, 'module.')