전공관련/Deep Learning

[TensorFlow] meta file로부터 graph를 읽어오고 사용하는 방법

매직블럭 2018. 3. 22. 14:25

tensorflow에서 데이터를 저장할때 생성되는 파일중에 meta파일에는 graph에 관한 정보가 저장되어 있다.


이 정보를 읽어오고 사용하는 방법에 대한 정리.


====================================================================


1. saver를 이용하여 graph 와 variable 저장하기


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 샘플1. graph 생성하고 saver로 저장
import tensorflow as tf
 
= tf.Graph()
 
with g.as_default():
    v0 = tf.placeholder(tf.int32, name="V0")
    v1 = tf.Variable(10, name="v1")
    v2 = tf.Variable(20, name="v2")
    v3 = tf.add(v0, v2, name="add")
 
with tf.Session(graph=g) as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    save_path = saver.save(sess, "./saved/test1")
    feed_dict = {v0:7}
    output = sess.run([v3], feed_dict=feed_dict)
 
 
 
cs



2. meta데이터와 ckpt 파일로부터 그래프와 값을 읽어오고 사용하기


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 샘플2. meta파일로부터 graph 복구하고 variable restore
import tensorflow as tf
 
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph("saved/test1.meta")
    new_saver.restore(sess. tf.train.latest_checkpoint("./saved"))
 
    g = tf.get_default_graph()
    
    g_collection_key_list = g.get_all_collection_keys()
 
    g_collection_list = g.get_collection(g_collection_key_list[0])
 
    oper_list = g.get_operations()
    
    # 기존 그래프를 import, 새로운 데이터를입력으로 사용하여 연산
    new_v0 = g.get_tensor_by_name("v0:0")
    feed_dict = {new_v0:9}
    output = sess.run(["v3:0"], feed_dict=feed_dict)
    print(output)
 
    # 기존 그래프에 새로운 그래프를 연결하여 연산
    new_v3 = g.get_tensor_by_name("v3:0")
    add_on_op = tf.multiply(new_v3, 3, name="multiply")
    print(sess.run(add_on_op, feed_dict))
cs


참고 : http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/