pytorch 등의 프레임워크에서 onnx로 convert 한 모델이 잘 변환됐는지 늘 확인이 필요하다.
이럴 때 확인을 위한 방법 정리
import torch
import numpy as np
import onnxruntime as rt
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def test():
model_pytorch = Net() # 네트워크 선언 및 가중치 로드 했다 치고..
x = torch.rand(b, c, h, w)
out_torch = model_pytorch(x)
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
out_onnx = sess.run(None, {input_name: x})
np.testing.assert_allclose(to_numpy(out_torch), out_onnx[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
결과가 오차범위 (rtol=1e-03, atol=1e-05) 안에 있다면 마지막 프린트문이 출력 될 것이고
오차범위 밖이라면 에러가 발생한다. 이는 변환이 잘못됐다는 얘기지..
'전공관련 > Deep Learning' 카테고리의 다른 글
[Caffe] caffe 환경 설정없이 caffemodel 값을 확인하자 (0) | 2021.02.24 |
---|---|
[Pytorch] model load시 key가 있는 레이어만 불러오자 (0) | 2021.02.17 |
[Onnx] Onnxruntime - GPU를 사용하자 (8) | 2020.03.09 |
[Onnx] onnx 모듈을 사용하기 위한 class를 만들어보자 (0) | 2020.02.26 |
[Onnx] visual studio에서 onnxruntime을 설치 해 보자 (0) | 2020.02.26 |