• Tistory
    • 태그
    • 위치로그
    • 방명록
    • 관리자
    • 글쓰기
Carousel 01
Carousel 02
Previous Next

digits 에서 학습한 텐서플로우 모델 파이썬에서 사용하기

전공관련/Deep Learning 2018. 1. 19. 10:27




digits 에서 학습한 텐서플로우 모델 파이썬에서 사용하기

추후 참고용으로 기록


Because my question is mostly due to my bad understanding of TensorFlow, I do a trip on the official documentation, and I've found some answers.

Firstly, I combine the use of contrib/slim and contrib/tflearn and even if it is possible, it is not really relevant. So I rewrite the network using only slim :

import tensorflow as tf
import tensorflow.contrib.slim as slim


class LeNetModel():

    def gray28(self, nclasses):
        # x = input_data(shape=[None, 28, 28, 1])
        x = tf.placeholder(tf.float32, shape=[1, 28, 28], name="x")
        rs = tf.reshape(x, shape=[-1, 28, 28, 1])
        # scale (divide by MNIST std)
        # x = x * 0.0125
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=tf.contrib.layers.xavier_initializer(),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            model = slim.conv2d(rs, 20, [5, 5], padding='VALID', scope='conv1')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
            model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
            model = slim.flatten(model)
            model = slim.fully_connected(model, 500, scope='fc1')
            model = slim.dropout(model, 0.5, is_training=True, scope='do1')
            model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')

            return x, model

I return the x placeholder and the model, and I use it to load the DIGITS pre-trained model (checkpoint) :

import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2
from models.lenet import LeNetModel

# Helper function to load/resize images
def image(path):
    img = cv2.imread(path, 0)
    return cv2.resize(img, dsize=(28,28))

# Define a function that adds the model/ prefix to all variables :
def name_in_checkpoint(var):
  return 'model/' + var.op.name

#Instantiate the model
x, model = LeNetModel().gray28(2)

# Define the variables to restore :
# Exclude the "is_training" that I don't care about
variables_to_restore = slim.get_variables_to_restore(exclude=["is_training"])
# Rename the other variables with the function name_in_checkpoint
variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}

# Create a Saver to restore the checkpoint, given the variables
restorer = tf.train.Saver(variables_to_restore)

#Launch a session to restore the checkpoint and try to infer some images :
with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "src/prototype/models/snapshot_5.ckpt")
    print("Model restored.")
    print(sess.run(model, feed_dict={x:[image("/home/damien/Vidéos/1/positives/img/1-img143.jpg")]}))
    print(sess.run(model, feed_dict={x:[image("/home/damien/Vidéos/0/positives/img/1-img1.jpg")]}))


출처 : https://stackoverflow.com/questions/47506607/define-tensorflow-network-key-names-according-to-an-existing-checkpoint/47525747#47525747

저작자표시비영리

'전공관련 > Deep Learning' 카테고리의 다른 글

[TensorFlow] meta file로부터 graph를 읽어오고 사용하는 방법  (2) 2018.03.22
[TensorFlow] Saver를 이용하여 기존 model의 weight를 읽어오자.  (0) 2018.03.05
digits 에서 학습한 텐서플로우 모델 파이썬에서 사용하기  (0) 2018.01.19
[TensorFlow] Slim을 써보자 - custom preprocessing을 설정하자  (0) 2017.12.06
[TensorFlow] Slim을 써보자 - custom architecture를 만들자  (0) 2017.12.06
[TensorFlow] Slim을 써보자 - custom dataset을 만들자  (0) 2017.12.06
블로그 이미지

매직블럭

작은 지식들 그리고 기억 한조각

트랙백 0개, 댓글 0개가 달렸습니다

댓글을 달아 주세요

  • «
  • 1
  • ···
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • ···
  • 414
  • »

카테고리

  • 살다보니.. (414)
    • 주절거림 (3)
    • 취미생활 (36)
      • 지식과 지혜 (3)
      • 풍경이 되어 (4)
      • Memories (17)
      • 엥겔지수를 높여라 (2)
    • mathematics (6)
      • Matrix Computation (2)
      • RandomProcesses (3)
    • English.. (8)
    • Programming (134)
      • C, C++, MFC (51)
      • C# (1)
      • OpenCV (17)
      • Python (47)
      • Git, Docker (3)
      • Matlab (4)
      • Windows (3)
      • Kinect V2 (2)
      • 기타 etc. (6)
    • 전공관련 (73)
      • Algorithm (6)
      • Deep Learning (49)
      • 실습 프로그램 (4)
      • 주워들은 용어정리 (8)
      • 기타 etc. (6)
    • Computer (104)
      • Utility (21)
      • Windows (24)
      • Ubuntu, Linux (55)
      • NAS (2)
      • Embedded, Mobile (2)
    • IT, Device (41)
      • 제품 사용기, 개봉기 (14)
      • 스마트 체험단 신청 (27)
    • Wish List (3)
    • TISTORY TIP (5)
    • 미분류. 수정중 (1)

태그목록

  • 매트랩 함수
  • Deep Learning
  • 칼로리 대폭발
  • utility
  • 스마트체험단
  • ReadString
  • 갤럭시노트3
  • random variable
  • matlab function
  • portugal
  • CStdioFile
  • 딥러닝
  • ColorMeRad
  • 일본
  • function
  • SVM
  • 오봉자싸롱
  • matlab
  • Computer Tip
  • 크롬
  • 에누리닷컴
  • 매트랩
  • review
  • LIBSVM
  • 후쿠오카
  • 포르투갈
  • 큐슈
  • DeepLearning
  • Convolutional Neural Networks
  • DSLR

달력

«   2022/05   »
일 월 화 수 목 금 토
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31        
05-21 17:58

LATEST FROM OUR BLOG

  • 골뱅이 연산자의 의미 (행렬곱)..
  • 프린터 용지 부족 문제를 해⋯.
  • [MXNet] 데이터 리스트를 만⋯.
  • 예쁘게 출력하자 pprint - pr⋯.
  • 작업표시줄 미리보기를 리스⋯.
  • 이미지 실제 파일 포맷 확인하기.
  • 알리 등 해외배송 배송상태를⋯.
  • 티스토리 코드블럭 내용을 복⋯.
  • warning 을 on/off 하자.
  • windows 10 파일 선택, 파일⋯.
RSS 구독하기

BLOG VISITORS

  • Total : 1,115,104
  • Today : 136
  • Yesterday : 530

Copyright © 2015 Socialdev. All Rights Reserved.

티스토리툴바