전공관련/Deep Learning
[Pytorch] torch.load 에서 학습시와 환경이 달라서 못읽을 경우
매직블럭
2019. 3. 13. 13:53
학습은 cuda:2 환경에서 진행한 모델을 다른 pc에서 불러와 cuda:0 환경에서 inference를 시도
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net()
net.to(device)
net.load_state_dict(torch.load("./output/trained_model.pth"))
이러한 경우 torch.load 에서 에러가 발생.
학습 당시의 환경은 cuda:2 인데 그 환경과 다르다는것!
이럴 경우 map_location 파라미터를 설정 해 주면 정상적으로 읽기 가능
net.load_state_dict(torch.load("./output/0074.pth", map_location='cuda:0'))