전공관련/Deep Learning
[TensorFlow] Saver를 이용하여 기존 model의 weight를 읽어오자.
매직블럭
2018. 3. 5. 11:26
Tensorflow에서 saver를 이용하여 기존 학습된 모델의 값을 읽어오는 방법.
1. 모델 전체를 읽어오는 경우
1 2 | saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=10) saver.restore(sess, pretrained_model) | cs |
2. architecture 중 일부 레이어를 제외할 경우
1 2 3 4 5 6 7 | variables = tf.trainable_variables() variables_to_restore = [v for v in variables if v.name != '제외할레이어'] saver = tf.train.Saver(variables_to_restore, max_to_keep=10) saver.restore(sess, pretrained_model) # scope을 이용한 경우 아래와 같이 해당 scope에 대한 부분을 제할수 있음 variables_to_restore = [v for v in variables if v.name.split('/')[0] != '제외할 scope'] | cs |