전공관련/Deep Learning

[Pytorch] torch.load 에서 학습시와 환경이 달라서 못읽을 경우

매직블럭 2019. 3. 13. 13:53

학습은 cuda:2 환경에서 진행한 모델을 다른 pc에서 불러와 cuda:0 환경에서 inference를 시도

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = Net()
net.to(device)
net.load_state_dict(torch.load("./output/trained_model.pth"))

이러한 경우 torch.load 에서 에러가 발생.

학습 당시의 환경은 cuda:2 인데 그 환경과 다르다는것! 


이럴 경우 map_location 파라미터를 설정 해 주면 정상적으로 읽기 가능

net.load_state_dict(torch.load("./output/0074.pth", map_location='cuda:0'))