전공관련/Deep Learning
[ONNX] Onnx convert 모델을 검증하자
매직블럭
2020. 4. 22. 11:29
pytorch 등의 프레임워크에서 onnx로 convert 한 모델이 잘 변환됐는지 늘 확인이 필요하다.
이럴 때 확인을 위한 방법 정리
import torch
import numpy as np
import onnxruntime as rt
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def test():
model_pytorch = Net() # 네트워크 선언 및 가중치 로드 했다 치고..
x = torch.rand(b, c, h, w)
out_torch = model_pytorch(x)
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
out_onnx = sess.run(None, {input_name: x})
np.testing.assert_allclose(to_numpy(out_torch), out_onnx[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
결과가 오차범위 (rtol=1e-03, atol=1e-05) 안에 있다면 마지막 프린트문이 출력 될 것이고
오차범위 밖이라면 에러가 발생한다. 이는 변환이 잘못됐다는 얘기지..