전공관련/Deep Learning

[ONNX] Onnx convert 모델을 검증하자

매직블럭 2020. 4. 22. 11:29

pytorch 등의 프레임워크에서 onnx로 convert 한 모델이 잘 변환됐는지 늘 확인이 필요하다.

 

이럴 때 확인을 위한 방법 정리

 


import torch
import numpy as np
import onnxruntime as rt

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    
def test():
    model_pytorch = Net() # 네트워크 선언 및 가중치 로드 했다 치고..
    x = torch.rand(b, c, h, w)
    
    out_torch = model_pytorch(x)
    
    sess = rt.InferenceSession("onnx_model.onnx")
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    
    out_onnx = sess.run(None, {input_name: x})
    
    np.testing.assert_allclose(to_numpy(out_torch), out_onnx[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")

 

결과가 오차범위 (rtol=1e-03, atol=1e-05) 안에 있다면 마지막 프린트문이 출력 될 것이고

오차범위 밖이라면 에러가 발생한다. 이는 변환이 잘못됐다는 얘기지..