tensorflow에서 데이터를 저장할때 생성되는 파일중에 meta파일에는 graph에 관한 정보가 저장되어 있다.
이 정보를 읽어오고 사용하는 방법에 대한 정리.
====================================================================
1. saver를 이용하여 graph 와 variable 저장하기
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | # 샘플1. graph 생성하고 saver로 저장 import tensorflow as tf g = tf.Graph() with g.as_default(): v0 = tf.placeholder(tf.int32, name="V0") v1 = tf.Variable(10, name="v1") v2 = tf.Variable(20, name="v2") v3 = tf.add(v0, v2, name="add") with tf.Session(graph=g) as sess: saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) save_path = saver.save(sess, "./saved/test1") feed_dict = {v0:7} output = sess.run([v3], feed_dict=feed_dict) | cs |
2. meta데이터와 ckpt 파일로부터 그래프와 값을 읽어오고 사용하기
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | # 샘플2. meta파일로부터 graph 복구하고 variable restore import tensorflow as tf with tf.Session() as sess: new_saver = tf.train.import_meta_graph("saved/test1.meta") new_saver.restore(sess. tf.train.latest_checkpoint("./saved")) g = tf.get_default_graph() g_collection_key_list = g.get_all_collection_keys() g_collection_list = g.get_collection(g_collection_key_list[0]) oper_list = g.get_operations() # 기존 그래프를 import, 새로운 데이터를입력으로 사용하여 연산 new_v0 = g.get_tensor_by_name("v0:0") feed_dict = {new_v0:9} output = sess.run(["v3:0"], feed_dict=feed_dict) print(output) # 기존 그래프에 새로운 그래프를 연결하여 연산 new_v3 = g.get_tensor_by_name("v3:0") add_on_op = tf.multiply(new_v3, 3, name="multiply") print(sess.run(add_on_op, feed_dict)) | cs |
참고 : http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
'전공관련 > Deep Learning' 카테고리의 다른 글
[Pytorch] 학습 한 모델을 저장하고 불러오자 (1) | 2019.03.12 |
---|---|
Deconvolution 파라미터에 따른 출력 크기 계산하기 (0) | 2018.10.19 |
[TensorFlow] Saver를 이용하여 기존 model의 weight를 읽어오자. (0) | 2018.03.05 |
digits 에서 학습한 텐서플로우 모델 파이썬에서 사용하기 (0) | 2018.01.19 |
[TensorFlow] Slim을 써보자 - custom preprocessing을 설정하자 (0) | 2017.12.06 |