Using a TensorFlow model trained in DIGITS from Python
Recorded here for future reference.
Because my question came mostly from my poor understanding of TensorFlow, I took a tour of the official documentation and found some answers.
First, I had been combining contrib/slim and contrib/tflearn; even though that is possible, it is not really appropriate. So I rewrote the network using only slim:
```python
import tensorflow as tf
import tensorflow.contrib.slim as slim


class LeNetModel():
    def gray28(self, nclasses, is_training=False):
        # x = input_data(shape=[None, 28, 28, 1])
        x = tf.placeholder(tf.float32, shape=[1, 28, 28], name="x")
        rs = tf.reshape(x, shape=[-1, 28, 28, 1])
        # scale (divide by MNIST std); if enabled, scale rs, the tensor the layers consume
        # rs = rs * 0.0125
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=tf.contrib.layers.xavier_initializer(),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            model = slim.conv2d(rs, 20, [5, 5], padding='VALID', scope='conv1')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
            model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
            model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
            model = slim.flatten(model)
            model = slim.fully_connected(model, 500, scope='fc1')
            # Dropout is only applied when is_training is True; it must be off for
            # inference, otherwise repeated runs on the same image give different outputs
            model = slim.dropout(model, 0.5, is_training=is_training, scope='do1')
            model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')
        return x, model
```
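As a sanity check on the architecture above, the output shape and trainable parameter count of each layer can be worked out by hand. Below is a minimal pure-Python sketch, assuming slim's defaults of stride 1 for `slim.conv2d` and stride 2 for `slim.max_pool2d`, and one bias per output unit on every conv and fully connected layer (slim's default when no normalizer is set). The helper names are hypothetical; dropout is left out because it changes neither the shape nor the parameter count.

```python
def conv_out(size, kernel):
    """Spatial size after a stride-1 'VALID' convolution."""
    return size - kernel + 1


def pool_out(size, kernel, stride=2):
    """Spatial size after 'VALID' max-pooling (slim.max_pool2d defaults to stride 2)."""
    return (size - kernel) // stride + 1


def lenet_summary(nclasses, size=28, channels=1):
    """Return (layer name, output shape, trainable parameter count) rows for gray28()."""
    rows = []
    # conv1: 20 filters of 5x5; weights plus one bias per filter
    params = 5 * 5 * channels * 20 + 20
    size, channels = conv_out(size, 5), 20
    rows.append(("conv1", (size, size, channels), params))
    size = pool_out(size, 2)
    rows.append(("pool1", (size, size, channels), 0))
    # conv2: 50 filters of 5x5 over the 20 input channels
    params = 5 * 5 * channels * 50 + 50
    size, channels = conv_out(size, 5), 50
    rows.append(("conv2", (size, size, channels), params))
    size = pool_out(size, 2)
    rows.append(("pool2", (size, size, channels), 0))
    # flatten, then the two fully connected layers
    flat = size * size * channels
    rows.append(("flatten", (flat,), 0))
    rows.append(("fc1", (500,), flat * 500 + 500))
    rows.append(("fc2", (nclasses,), 500 * nclasses + nclasses))
    return rows


if __name__ == "__main__":
    total = 0
    for name, shape, params in lenet_summary(nclasses=2):
        total += params
        print(f"{name:8s} {str(shape):14s} {params}")
    print("total trainable parameters:", total)
```

With two classes, as in the loading script, the flattened feature vector has 4 * 4 * 50 = 800 entries, and `fc1` (400,500 parameters) dominates the total of 427,072.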
I return the x placeholder and the model output, and use them to load the DIGITS pre-trained model (checkpoint):
```python
import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2
from models.lenet import LeNetModel

# Helper function to load/resize images
def image(path):
    # Flag 0 loads the image as single-channel grayscale
    img = cv2.imread(path, 0)
    return cv2.resize(img, dsize=(28, 28))

# Define a function that adds the model/ prefix to all variables:
def name_in_checkpoint(var):
    return 'model/' + var.op.name

# Instantiate the model
x, model = LeNetModel().gray28(2)

# Define the variables to restore:
# Exclude the "is_training" that I don't care about
variables_to_restore = slim.get_variables_to_restore(exclude=["is_training"])
# Rename the other variables with the function name_in_checkpoint
variables_to_restore = {name_in_checkpoint(var): var for var in variables_to_restore}

# Create a Saver to restore the checkpoint, given the variables
restorer = tf.train.Saver(variables_to_restore)

# Launch a session to restore the checkpoint and try to infer some images:
with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "src/prototype/models/snapshot_5.ckpt")
    print("Model restored.")
    print(sess.run(model, feed_dict={x: [image("/home/damien/Vidéos/1/positives/img/1-img143.jpg")]}))
    print(sess.run(model, feed_dict={x: [image("/home/damien/Vidéos/0/positives/img/1-img1.jpg")]}))
```
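Because `fc2` is built with `activation_fn=None`, each `sess.run(model, ...)` call prints raw logits with shape `[1, nclasses]`, not probabilities. Below is a minimal pure-Python sketch of turning one row of logits into a predicted class index and class probabilities with a numerically stable softmax; inside TensorFlow the same result could be obtained by running `tf.nn.softmax(model)` instead of `model`. The helper names are hypothetical, and which index corresponds to which label depends on how the dataset was labeled in DIGITS.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return (predicted class index, class probabilities) for one row of logits."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs


if __name__ == "__main__":
    # Hypothetical logits for one image, shaped like one row of sess.run(model, ...)
    label, probs = predict([2.0, 0.0])
    print(label, [round(p, 4) for p in probs])  # 0 [0.8808, 0.1192]
```

With the real output, the row to pass is the first (and only) row of the array returned by `sess.run`, since the placeholder holds a batch of one image.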
Source: https://stackoverflow.com/questions/47506607/define-tensorflow-network-key-names-according-to-an-existing-checkpoint/47525747#47525747
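The restore only succeeds if every key in `variables_to_restore` exists in the checkpoint, so the exclude-and-rename step is worth checking on its own. Below is a minimal pure-Python sketch on plain strings, assuming (as `name_in_checkpoint` implies) that the DIGITS checkpoint stores every variable under a `model/` prefix. The graph variable names are illustrative, based on slim's convention of naming layer variables `<scope>/weights` and `<scope>/biases`; the helpers `filter_and_rename` and `check_against_checkpoint` are hypothetical, and the `startswith` test only approximates how `slim.get_variables_to_restore(exclude=...)` matches scope names from the start of each variable name.

```python
def filter_and_rename(var_names, exclude=(), prefix="model/"):
    """Drop names that start with an excluded scope, then map checkpoint key -> graph name."""
    kept = [n for n in var_names if not any(n.startswith(s) for s in exclude)]
    return {prefix + n: n for n in kept}


def check_against_checkpoint(mapping, checkpoint_keys):
    """Return the checkpoint keys the mapping expects but the checkpoint does not contain."""
    return sorted(set(mapping) - set(checkpoint_keys))


if __name__ == "__main__":
    # Illustrative variable names as slim would create them in gray28()
    graph_vars = [
        "conv1/weights", "conv1/biases",
        "conv2/weights", "conv2/biases",
        "fc1/weights", "fc1/biases",
        "fc2/weights", "fc2/biases",
    ]
    # Hypothetical checkpoint contents: the same variables under "model/"
    ckpt_keys = ["model/" + n for n in graph_vars]
    mapping = filter_and_rename(graph_vars, exclude=["is_training"])
    print(mapping["model/conv1/weights"])  # conv1/weights
    print(check_against_checkpoint(mapping, ckpt_keys))  # []
```

Against a real checkpoint, the actual key list can be obtained in TensorFlow 1.x with `tf.train.list_variables(path)`, which returns (name, shape) pairs; any key reported as missing means the prefix or scope names do not line up.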