Tensorflow-模型的保存、恢复以及fine-tune
最近在做的本科毕设需要对迁移过来的ConvNet进⾏在线fine-tune,由于之前在学习Tensorflow时对模型的保存恢复学习的不够深⼊,所以今天花了⼀个下午看了⼏篇⽂章,觉得有的写的很不错,就搬运过来。在最后有⾃⼰的总结。
使⽤tensorflow的过程中,我们常常会⽤到训练好的模型。我们可以直接使⽤训练好的模型进⾏测试或者对训练好的模型做进⼀步的微调。(微调是指初始化⽹络参数的时候不再是随机初始化,⽽是使⽤先前训练好的权重参数进⾏初始化,在此基础上对⽹络的全部或者局部参数进⾏重新训练的过程)。为了实现模型的复⽤或微调,我将从以下四个⽅⾯进⾏说明:
模型是指什么?
如何保存模型?
如何恢复模型?
如何进⾏微调?
⼀、模型是指什么?
tensorflow训练后需要保存的模型主要包含两部分,⼀是⽹络图,⼆是⽹络图⾥的参数值。保存的模型⽂件结构如下(假设每过1000次保存⼀次):
checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-2000.data-00000-of-00001
MyModel-2000.index
MyModel-3000.data-00000-of-00001
麻省理工大学
MyModel-3000.index
.......
avoided1 checkpoint
checkpoint是⼀个⽂本⽂件,如下所⽰。其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。
model_checkpoint_path保存了最新的tensorflow模型⽂件的⽂件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型⽂件的⽂件名。
model_checkpoint_path: "MyModel-3000"
all_model_checkpoint_paths: "MyModel-1000"
all_model_checkpoint_paths: "MyModel-2000"
all_model_checkpoint_paths: "MyModel-3000"
......
emini2 .meta⽂件
.meta ⽂件⽤于保存⽹络结构,且以 protocol buffer 格式进⾏保存。protocol buffer是Google 公司内部使⽤的⼀种轻便⾼效的数据描述语⾔。类似于XML能够将结构化数据序列化,protocol buffer也可⽤于序列化结构化数据,并将其⽤于数据存储、通信协议等⽅⾯。相较于XML,protocol buffer更⼩、更快、也更简单。划重点:⽹络结构,仅仅是⽹络结构
3 .data-00000-of-00001 ⽂件和 .index ⽂件
在tensorflow 0.11之前,保存的⽂件结构如下。tensorflow 0.11之后,将ckpt⽂件拆分为了.data-00000-of-00001 和 .index 两个⽂件。.ckpt是⼆进制⽂件,保存了所有变量的值及变量的名称。拆分后的.data-00000-of-00001 保存的是变量值,.index⽂件保存的
是.data⽂件中数据和 .meta⽂件中结构图之间的对应关系(也就是变量的名称)划重点:变量!tf.Variable()
checkpoint
MyModel.ckpt
⼆、如何保存模型?
tensorflow 提供tf.train.Saver类及tf.train.Saver类下⾯的save⽅法共同保存模型。下⾯分别说明tf.train.Saver类及save⽅法:
keep_checkpoint_every_n_hours=10000.0, name=None, restore_quentially=Fal,
saver_def=None, builder=None, defer_build=Fal, allow_empty=Fal,
write_version=saver_pb2.SaverDef.V2, pad_step_number=Fal)
就常⽤的参数进⾏说明:
var_list:如果我们不对tf.train.Saver指定任何参数,默认会保存所有变量。如果你只想保存⼀部分变量,
可以通过将需要保存的变量构造list或者dictionary,赋值给var_list。
max_to_keep:tensorflow默认只会保存最近的5个模型⽂件,如果你希望保存更多,可以通过max_to_keep来指定
keep_checkpoint_every_n_hours:设置每隔⼏⼩时保存⼀次模型
save(ss,save_path,global_step=None,latest_filename=None,meta_graph_suffix="meta",
write_meta_graph=True, write_state=True)
就常⽤的参数进⾏说明:
ss:在tensorflow中,只有开启ssion时数据才会流动,因此保存模型的时候必须传⼊ssion。
save_path: 模型保存的路径及模型名称。
global_step:定义每隔多少步保存⼀次模型,每次会在保存的模型名称后⾯加上global_step的值作为后缀
write_meta_graph:布尔值,True表⽰每次都保存图,Fal表⽰不保存图(由于图是不变的,没必要每次都去保存)
注意:保存变量的时候必须在ssion中;保存的变量必须已经初始化;
1.简单⽰例
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
w3 = tf.Variable(tf.random_normal(shape=[1]), name='w3')
saver = tf.train.Saver()#未指定任何参数,默认保存所有变量。等价于saver = tf.train.ainable_variables())
中译英save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel
with tf.Session() as ss:
ss.run(tf.global_variables_initializer())
saver.save(ss, save_path)
执⾏后,在checkpoint_dir⽬录下创建模型⽂件如下:
checkpoint
MyModel.data-00000-of-00001
MyModel.index
2.经典⽰例
import tensorflow as tf
ves import xrange
import os
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w11')#变量w1在内存中的名字是w11;恢复变量时应该与name的名字保持⼀致
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w22')
因为英语w3 = tf.Variable(tf.random_normal(shape=[5]), name='w33')
#保存⼀部分变量[w1,w2];只保存最近的5个模型⽂件;每2⼩时保存⼀次模型
saver = tf.train.Saver([w1, w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel
# Launch the graph and train, saving the model every 1,000 steps.
with tf.Session() as ss:思忖
我是歌手第三季名单
ss.run(tf.global_variables_initializer())
for step in xrange(100):
if step % 10 == 0:
# 每隔step=10步保存⼀次模型( keep_checkpoint_every_n_hours与global_step可同时使⽤,表⽰'与',通常任选⼀个就够了);
#每次会在保存的模型名称后⾯加上global_step的值作为后缀
# write_meta_graph=Fal表⽰不保存图
saver.save(ss, save_path, global_step=step, write_meta_graph=Fal)
# 如果模型⽂件中没有保存⽹络图,则使⽤如下语句保存⼀张⽹络图(由于⽹络图不变,只保存⼀次就⾏)
if not ists('./checkpoint_a'):
# port_meta_graph(filename=None, collection_list=None,as_text=Fal,export_scope=None,clear_devices=Fal)
# port_meta_graph()仅仅保存⽹络图;参数filename表⽰⽹络图保存的路径即⽹络图名称
#注意:port_meta_graph()等价于tf.port_meta_graph()
执⾏后,在checkpoint_dir⽬录下创建模型⽂件如下:
checkpoint
MyModel-50.data-00000-of-00001
MyModel-50.index
MyModel-60.data-00000-of-00001
MyModel-60.index
MyModel-70.data-00000-of-00001
MyModel-70.index
MyModel-80.data-00000-of-00001
MyModel-80.index
MyModel-90.data-00000-of-00001
MyModel-90.index
三、如何恢复模型?
tensorflow保存模型时将⽹络图和⽹络图⾥的参数值分开保存。因此,在恢复模型时,也要分为2步:构造⽹络图和加载参数。模型的恢复分为两步,第⼀步是graph的重新构建,第⼆步是模型参数的加载。模型参数的加载对应的是变量的初始化。
1 构造⽹络图
构造⽹络图可以⼿动创建(需要创建⼀个跟保存的模型⼀模⼀样的⽹络图)
也可以从meta⽂件⾥加载graph进⾏创建,如下:
#⾸先恢复graph
saver = tf.train.import_meta_graph('./checkpoint_a')
2 恢复参数有两种⽅式,如下:
with tf.Session() as ss:
#恢复最新保存的权重
#指定⼀个权重恢复
四、如何进⾏微调?(*)
上⾯叙述了如何恢复模型,那么,对于恢复出来的模型应该如何使⽤呢?这⾥以tensorflow官⽹给出的vgg为例进⾏说明。恢复出来的模型有四种⽤途:
查看模型参数
直接使⽤原始模型进⾏测试
扩展原始模型(直接使⽤扩展后的⽹络进⾏测试,扩展后需要重新训练的情况见微调部分)
微调:使⽤训练好的权重参数进⾏初始化,在此基础上对⽹络的全部或局部参数进⾏重新训练
1.查看模型参数
import tensorflow as tf
import vgg
# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_class=1000)
saver = tf.train.Saver()
e bookwith tf.Session() as ss:
"""
查看恢复的模型参数
tf.global_variables()获得的与tf.trainable_variables()类似,只是多了⼀些⾮trainable的变量,⽐如定义时指定为trainable=Fal的变量; _operations()则可以获得⼏乎所有的operations相关的tensor
"""
tvs = [v for v ainable_variables()]
print('获得所有可训练变量的权重:')
for v in tvs:
print(v.name)
print(ss.run(v))
gv = [v for v in tf.global_variables()]
print('获得所有变量:')
for v in gv:
print(v.name, '\n')
# _operations()可以换为tf.get_default_graph().get_operations()
ops = [o for o _operations()]
print('获得所有operations相关的tensor:')
for o in ops:
print(o.name, '\n')
2.直接使⽤原始模型进⾏测试
import tensorflow as tf
import vgg
import numpy as np
import cv2
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = size(image, (224,224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))
#build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_class=1000)
print(end_points)
saver = tf.train.Saver()
with tf.Session() as ss:
#恢复权重
惊喜英文单词
# Get input and output tensors
# 需要特别注意,get_tensor_by_name后⾯传⼊的参数,如果没有重复,需要在后⾯加上“:0” # ss.graph等价于tf.get_default_graph()
input = _tensor_by_name('inputs:0')
output = _tensor_by_name('vgg_16/fc8/squeezed:0')
# Run forward pass to calculate pred
#使⽤不同的数据运⾏相同的⽹络,只需将新数据通过feed_dict传递到⽹络即可。
pred = ss.run(output, feed_dict={input:res_image})
#得到使⽤vgg⽹络对输⼊图⽚的分类结果
photos
print(np.argmax(pred, 1))
3.扩展原始模型