전공관련/Deep Learning

[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자2

매직블럭 2019. 12. 17. 17:35

[Pytorch] multiGPU 에서 학습한 모델을 singleGPU에서 사용하자

지난난 글에서 사용한한 방법은 저장 단계에서 적용해야 하는 문제가 있어서 

이번에는 다시 state_dict를 읽어서 key에서 module을 제거하고 다시 넣어주는 방법을 사용.


우선 state_dict의 key에서 module 을 제거한 새로운 dict를 반환하는 함수 선언

def get_state_dict(origin_dict):
    old_keys = origin_dict.keys()
    new_dict = {}

    for ii in old_keys:
        temp_key = str(ii)
        if temp_key[0:7] == "module.":
            new_key = temp_key[7:]
        else:
            new_key = temp_key

        new_dict[new_key] = origin_dict[temp_key]
    return new_dict

 

이후 네트워크 가중치를 읽을때 위 함수를 통해 다시 읽어주면 문제없이 사용 가능하다.

# 위쪽은 생략

net = CNN_Network(num_classes=1000)
net.to(device)
net.eval()

checkpoint = torch.load('model_path/model_file.pth.tar', map_location=device)
checkpoint_dict = get_state_dict(checkpoint["state_dict"])
net.load_state_dict(checkpoint_dict)

 

 

(21.01.06. 수정)

위의 get_state_dict를 람다식을 이용하여 깔끔하게 구현된 코드를 확인하여 기록.

(출처 : github.com/biubug6/Pytorch_Retinaface / RetinaFace)

 

def remove_prefix(state_dict, prefix):
    ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
    print('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
    
# 사용할때는
pretrained_dict = remove_prefix(pretrained_dict, 'module.')