• Tistory
    • 태그
    • 위치로그
    • 방명록
    • 관리자
    • 글쓰기
Carousel 01
Carousel 02
Previous Next

[Onnx] onnx 모듈을 사용하기 위한 class를 만들어보자

전공관련/Deep Learning 2020. 2. 26. 13:08




onnxruntime 예제코드에는 struct를 생성하여 사용하는 방법이 나와있다. 

코드 중 주요 부분만 떼서 보면 아래와 같다.

struct onnx_struct {
	onnx_struct() {
		auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
		input_tensor_ = Ort::Value::CreateTensor<float>(memory_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
		output_tensor_ = Ort::Value::CreateTensor<float>(memory_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size());
	}

	std::array<float, 1000> Run() {
		const char* input_names[] = { "input" };
		const char* output_names[] = { "output" };

		session_.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1);

		// result_ = std::distance(results_.begin(), std::max_element(results_.begin(), results_.end()));
		return results_;
	}

	static constexpr const int width_ = 128;
	static constexpr const int height_ = 128;

	std::array<float, width_ * height_> input_image_{};
	std::array<float, 1000> results_{};
	// int64_t result_{ 0 };

private:
	Ort::Env env;
	Ort::Session session_{ env, L"./data/test_onnx.onnx", Ort::SessionOptions{ nullptr } };

	Ort::Value input_tensor_{ nullptr };
	std::array<int64_t, 4> input_shape_{ 1, 1, width_, height_ };

	Ort::Value output_tensor_{ nullptr };
	std::array<int64_t, 2> output_shape_{ 1, 1000 };
};

입력 사이즈와 출력 크기가 다 하드코딩 되어있을 뿐 아니라 모델 로드 부분도 구조체 생성단계에서 수행하니

활용 측면에서는 영 꽝이다. 

 


 

그래서 동적으로 사용 가능하도록 클래스 수정 중..

필요한 기능에 따라 지속적으로 수정 예정. ( 혹 비효율 적이거나 다른 방법이 있다면 알려주세요)

class onnx_module
{
public:
	onnx_module(std::string sModelPath, int nInputC, int nInputWidth, int nInputHeight, int nOutputDims);
	onnx_module(std::string sModelPath, int nInputC, int nInputWidth, int nInputHeight, int nOutputC, int nOutputWidth, int nOutputHeight);
	void Run(std::vector<float>& vResults);

	std::vector<float> results_;
	std::vector<float> input_image_;

private:
	Ort::Env env;
	Ort::Session* session_;

	Ort::Value input_tensor_{ nullptr };
	std::vector<int64_t> input_shape_;

	Ort::Value output_tensor_{ nullptr };
	std::vector<int64_t> output_shape_;
};

onnx_module::onnx_module(std::string sModelPath, int nInputC, int nInputWidth, int nInputHeight, int nOutputDims)
{
	std::string sPath = sModelPath;
	wchar_t* wPath = new wchar_t[sPath.length() + 1];
	std::copy(sPath.begin(), sPath.end(), wPath);
	wPath[sPath.length()] = 0;

	session_ = new Ort::Session(env, wPath, Ort::SessionOptions{ nullptr });
	delete[] wPath;

	const int batch_ = 1;
	const int channel_ = nInputC;
	const int width_ = nInputWidth;
	const int height_ = nInputHeight;

	input_image_.assign(width_*height_*channel_, 0.0);
	results_.assign(nOutputDims, 0.0);

	input_shape_.clear();
	input_shape_.push_back(batch_);
	input_shape_.push_back(channel_);
	input_shape_.push_back(width_);
	input_shape_.push_back(height_);

	output_shape_.clear();
	output_shape_.push_back(batch_);
	output_shape_.push_back(nOutputDims);



	auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	input_tensor_ = Ort::Value::CreateTensor<float>(memory_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
	output_tensor_ = Ort::Value::CreateTensor<float>(memory_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size());
}

onnx_module::onnx_module(std::string sModelPath, int nInputC, int nInputWidth, int nInputHeight, int nOutputC, int nOutputWidth, int nOutputHeight)
{
	std::string sPath = sModelPath;
	wchar_t* wPath = new wchar_t[sPath.length() + 1];
	std::copy(sPath.begin(), sPath.end(), wPath);
	wPath[sPath.length()] = 0;


	session_ = new Ort::Session(env, wPath, Ort::SessionOptions{ nullptr });
	delete[] wPath;

	const int batch_ = 1;

	const int channel_in = nInputC;
	const int width_in = nInputWidth;
	const int height_in = nInputHeight;

	const int channel_out = nOutputC;
	const int width_out = nOutputWidth;
	const int height_out = nOutputHeight;

	input_image_.assign(width_in * height_in * channel_in, 0.0);
	results_.assign(nOutputWidth * nOutputHeight * nOutputC, 0.0);

	input_shape_.clear();
	input_shape_.push_back(batch_);
	input_shape_.push_back(channel_in);
	input_shape_.push_back(width_in);
	input_shape_.push_back(height_in);

	output_shape_.clear();
	output_shape_.push_back(batch_);
	output_shape_.push_back(channel_out);
	output_shape_.push_back(width_out);
	output_shape_.push_back(height_out);



	auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	input_tensor_ = Ort::Value::CreateTensor<float>(memory_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
	output_tensor_ = Ort::Value::CreateTensor<float>(memory_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size());
}

void onnx_module::Run(std::vector<float>& vResults)
{
	const char* input_names[] = { "input" };
	const char* output_names[] = { "output" };

	(*session_).Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1);

	vResults.assign(results_.begin(), results_.end());
}

 

아직 수정해야 할 부분이 많다... 어렵다.. ㅋ

저작자표시 (새창열림)

'전공관련 > Deep Learning' 카테고리의 다른 글

[ONNX] Onnx convert 모델을 검증하자  (2) 2020.04.22
[Onnx] Onnxruntime - GPU를 사용하자  (8) 2020.03.09
[Onnx] visual studio에서 onnxruntime을 설치 해 보자  (0) 2020.02.26
[Onnx] pytorch model을 onnx로 변환하여 사용하자  (1) 2020.02.26
[Pytorch] Custom Dataloader를 사용하자  (0) 2019.12.23
블로그 이미지

매직블럭

작은 지식들 그리고 기억 한조각

,

카테고리

  • 살다보니.. (449)
    • 주절거림 (3)
    • 취미생활 (36)
      • 지식과 지혜 (3)
      • 풍경이 되어 (4)
      • Memories (17)
      • 엥겔지수를 높여라 (2)
    • mathematics (6)
      • Matrix Computation (2)
      • RandomProcesses (3)
    • English.. (8)
    • Programming (147)
      • C, C++, MFC (51)
      • C# (1)
      • OpenCV (17)
      • Python (58)
      • Git, Docker (3)
      • Matlab (4)
      • Windows (3)
      • Kinect V2 (2)
      • 기타 etc. (8)
    • 전공관련 (80)
      • Algorithm (6)
      • Deep Learning (54)
      • 실습 프로그램 (4)
      • 주워들은 용어정리 (8)
      • 기타 etc. (8)
    • Computer (118)
      • Utility (21)
      • Windows (31)
      • Mac (4)
      • Ubuntu, Linux (58)
      • NAS (2)
      • Embedded, Mobile (2)
    • IT, Device (41)
      • 제품 사용기, 개봉기 (14)
      • 스마트 체험단 신청 (27)
    • Wish List (3)
    • TISTORY TIP (5)
    • 미분류. 수정중 (1)

태그목록

  • LIBSVM
  • random variable
  • Deep Learning
  • 칼로리 대폭발
  • 스마트체험단
  • Convolutional Neural Networks
  • 에누리닷컴
  • 포르투갈
  • SVM
  • matlab
  • DeepLearning
  • review
  • 후쿠오카
  • 큐슈
  • portugal
  • 크롬
  • ReadString
  • CStdioFile
  • 매트랩
  • ColorMeRad
  • DSLR
  • 오봉자싸롱
  • 매트랩 함수
  • 일본
  • 딥러닝
  • function
  • 갤럭시노트3
  • utility
  • matlab function
  • Computer Tip

달력

«   2025/06   »
일 월 화 수 목 금 토
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
06-26 10:32

LATEST FROM OUR BLOG

RSS 구독하기

BLOG VISITORS

  • Total :
  • Today :
  • Yesterday :

Copyright © 2015 Socialdev. All Rights Reserved.

티스토리툴바