전공관련/Deep Learning

[TensorFlow] Saver를 이용하여 기존 model의 weight를 읽어오자.

매직블럭 2018. 3. 5. 11:26

Tensorflow에서 saver를 이용하여 기존 학습된 모델의 값을 읽어오는 방법.


1. 모델 전체를 읽어오는 경우

1
2
saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=10)
saver.restore(sess, pretrained_model)
cs


2. architecture 중 일부 레이어를 제외할 경우

1
2
3
4
5
6
7
variables = tf.trainable_variables()
variables_to_restore = [v for v in variables if v.name != '제외할레이어']
saver = tf.train.Saver(variables_to_restore, max_to_keep=10)
saver.restore(sess, pretrained_model)
 
# scope을 이용한 경우 아래와 같이 해당 scope에 대한 부분을 제할수 있음
variables_to_restore = [v for v in variables if v.name.split('/')[0!= '제외할 scope']
cs