전공관련/Deep Learning
[TensorFlow] meta file로부터 graph를 읽어오고 사용하는 방법
매직블럭
2018. 3. 22. 14:25
tensorflow에서 데이터를 저장할때 생성되는 파일중에 meta파일에는 graph에 관한 정보가 저장되어 있다.
이 정보를 읽어오고 사용하는 방법에 대한 정리.
====================================================================
1. saver를 이용하여 graph 와 variable 저장하기
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | # 샘플1. graph 생성하고 saver로 저장 import tensorflow as tf g = tf.Graph() with g.as_default(): v0 = tf.placeholder(tf.int32, name="V0") v1 = tf.Variable(10, name="v1") v2 = tf.Variable(20, name="v2") v3 = tf.add(v0, v2, name="add") with tf.Session(graph=g) as sess: saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) save_path = saver.save(sess, "./saved/test1") feed_dict = {v0:7} output = sess.run([v3], feed_dict=feed_dict) | cs |
2. meta데이터와 ckpt 파일로부터 그래프와 값을 읽어오고 사용하기
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | # 샘플2. meta파일로부터 graph 복구하고 variable restore import tensorflow as tf with tf.Session() as sess: new_saver = tf.train.import_meta_graph("saved/test1.meta") new_saver.restore(sess. tf.train.latest_checkpoint("./saved")) g = tf.get_default_graph() g_collection_key_list = g.get_all_collection_keys() g_collection_list = g.get_collection(g_collection_key_list[0]) oper_list = g.get_operations() # 기존 그래프를 import, 새로운 데이터를입력으로 사용하여 연산 new_v0 = g.get_tensor_by_name("v0:0") feed_dict = {new_v0:9} output = sess.run(["v3:0"], feed_dict=feed_dict) print(output) # 기존 그래프에 새로운 그래프를 연결하여 연산 new_v3 = g.get_tensor_by_name("v3:0") add_on_op = tf.multiply(new_v3, 3, name="multiply") print(sess.run(add_on_op, feed_dict)) | cs |
참고 : http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/