전공관련/Deep Learning

[Onnx] pytorch model을 onnx로 변환하여 사용하자

매직블럭 2020. 2. 26. 12:52

onnx는 open neural network exchange의 약자로 신경망 모델을 framework 간 변환하도록 만들어진 것이다.

 

기존 framework만을 사용할 수 있는 환경이라면 그냥 사용해도 문제가 없지만
여러 이유로 인하여 onnx로의 변환이 필요할 수 있다.

 

그래서 우선 torch model을 onnx 모델로 변환하는 방법 정리.

 

변환 방법은 torchScript를 이용하여 pt 파일을 생성하는 것과 거의 동일한 형태로 변환이 가능하다.

 


# model load
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = pytorch_model()
model.to(device)
model.eval()

checkpoint = torch.load("./model/pytorch_model_weight.pth.tar", map_location=device)
checkpoint_dict = get_state_dict(checkpoint["state_dict"])
model.load_state_dict(checkpoint_dict)

# make dummy data
batch_size = 1
# model input size에 맞게 b c h w 순으로 파라미터 설정
x = torch.rand(batch_size, 1, 128, 128, requires_grad=True).to(device)
# feed-forward test
output = model(x)

# convert
torch.onnx.export(model, x, "./test_onnx.onnx", export_params=True, opset_version=10, do_constant_folding=True
                  , input_names = ['input'], output_names=['output']
                  # , dynamic_axes={'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}}
                  # dynamic axes 는 pytorch 1.2 부터 지원하는듯??
                  )