这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。
tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,我们可以将tf的图和参数文件整合进一个后缀为pb的二进制文件中,由于整合过程回将变量转化为常量,所以我们在日后读取模型文件时不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称。
将图变量转换为常量的API:tf.graph_util.convert_variables_to_constants
转换后的graph_def对象转换为二进制数据(graph_def.SerializeToString())后,写入pb即可。
|
1
2
3
4
5
6
7
8
9
10
11
12
13
|
import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1]), name=‘v1‘)v2 = tf.Variable(tf.constant(2.0, shape=[1]), name=‘v2‘)result = v1 + v2saver = tf.train.Saver()with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, ‘./tmodel/test_model.ckpt‘) gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), [‘add‘])with tf.gfile.GFile(‘./tmodel/model.pb‘, ‘wb‘) as f: f.write(gd.SerializeToString()) |
我们可以直接查看gd:
node {
name: "v1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
}
float_val: 1.0
}
}
}
}
……
node {
name: "add"
op: "Add"
input: "v1/read"
input: "v2/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
}
四、从图上读取张量
上面的代码实际上已经包含了本小节的内容,但是由于从图上读取特定的张量是如此的重要,所以我仍然单独的补充上这部分的内容。
无论如何,想要获取特定的张量我们必须要有张量的名称和图的句柄,比如 ‘import/pool_3/_reshape:0‘ 这种,有了张量名和图,索引就很简单了。
从二进制模型加载张量
第二小节的代码很好的展示了这种情况
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
BOTTLENECK_TENSOR_NAME = ‘pool_3/_reshape:0‘ # 瓶颈层输出张量名称JPEG_DATA_TENSOR_NAME = ‘DecodeJpeg/contents:0‘ # 输入层张量名称MODEL_DIR = ‘./inception_dec_2015‘ # 模型存放文件夹MODEL_FILE = ‘tensorflow_inception_graph.pb‘ # 模型名# 加载模型# with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),‘rb‘) as f: # 阅读器上下文with open(os.path.join(MODEL_DIR, MODEL_FILE), ‘rb‘) as f: # 阅读器上下文 graph_def = tf.GraphDef() # 生成图 graph_def.ParseFromString(f.read()) # 图加载模型# 加载图上节点张量(按照句柄理解)bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def( # 从图上读取张量,同时导入默认图 graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME]) |
从当前图中获取对应张量
这个就是很普通的情况,从我们当前操作的图中获取某个张量,用于feed啦或者用于输出等操作,API也很简单,用法如下:
g.get_tensor_by_name(‘import/pool_3/_reshape:0‘)
g表示当前图句柄,可以简单的使用 g = tf.get_default_graph() 获取。
从图中获取节点信息
有的时候我们对于模型中的节点并不够了解,此时我们可以通过图句柄来查询图的构造:
|
1
2
|
g = tf.get_default_graph()print(g.as_graph_def().node) |
这个操作将返回图的构造结构。从这里,对比前面的代码,我们也可以了解到:graph_def 实际就是图的结构信息存储形式,我们可以将之还原为图(二进制模型加载代码中展示了),也可以从图中将之提取出来(本部分代码)。

