
TensorFlow: Saving and Loading Models

Date: 2020-11-18 19:39:09

TensorFlow offers two ways to save a model: pb (a frozen graph file) and saved_model. Either one works.

 

1. pb

1.1 Saving a model as pb

freozen_pb.py

# Written against the TensorFlow 1.x API (tf.Session, tf.placeholder).
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='in_x')
    y = tf.placeholder(tf.int32, name='in_y')
    b = tf.Variable(1, name='b')
    m = tf.multiply(x, y)
    a = tf.add(m, b, name='out_add')

    sess.run(tf.global_variables_initializer())
    # Replace every variable that 'out_add' depends on with a constant holding its current value.
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out_add'])

    feed_dict = {x: 10, y: 3}
    print(sess.run(a, feed_dict))  # 10 * 3 + 1 = 31

    # Serialize the frozen GraphDef to disk.
    with tf.gfile.FastGFile('./model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
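To make the effect of freezing concrete, the sketch below rebuilds the same small graph and compares the op type of node b before and after convert_variables_to_constants. It is an illustration, not part of the original scripts: the tf1 alias and the helper names freeze_demo_graph and ops_by_name are made up here, and it assumes a TensorFlow version that ships tf.compat.v1 (1.13 or later, including 2.x).

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available (TF 1.13+ or 2.x)


def freeze_demo_graph():
    """Rebuild the article's graph; return (original GraphDef, frozen GraphDef)."""
    with tf1.Session(graph=tf.Graph()) as sess:
        x = tf1.placeholder(tf.int32, name="in_x")
        y = tf1.placeholder(tf.int32, name="in_y")
        b = tf.Variable(1, name="b")
        tf.add(tf.multiply(x, y), b, name="out_add")

        sess.run(tf1.global_variables_initializer())
        original = sess.graph.as_graph_def()
        # Keep only the nodes needed for 'out_add' and turn variables into constants.
        frozen = tf1.graph_util.convert_variables_to_constants(
            sess, original, ["out_add"]
        )
    return original, frozen


def ops_by_name(graph_def):
    """Map each node name in a GraphDef to its op type."""
    return {node.name: node.op for node in graph_def.node}


original, frozen = freeze_demo_graph()
print("op type of 'b' before freezing:", ops_by_name(original)["b"])
print("op type of 'b' after freezing: ", ops_by_name(frozen)["b"])
```

After freezing, b should appear as a Const node. That is why the pb file can later be loaded and run without any checkpoint or variable initialization.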

 

 

1.2 Loading the pb model

call_pb.py

import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
with gfile.FastGFile('./model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # name='' keeps the original node names instead of prefixing them with 'import/'.
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')

# The frozen graph contains no variables, so this initializer has nothing to do.
sess.run(tf.global_variables_initializer())
# print(sess.run('b:0'))

# Look up the input and output tensors by the names given when the graph was built.
in_x = sess.graph.get_tensor_by_name('in_x:0')
in_y = sess.graph.get_tensor_by_name('in_y:0')
out_add = sess.graph.get_tensor_by_name('out_add:0')

ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 9})
print(ret)  # 8 * 9 + 1 = 73
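The loading steps can also be wrapped in a small helper that imports the pb file into its own tf.Graph, so the default graph stays untouched. This is only a sketch under assumptions: load_frozen_graph, run_out_add and the tf1 alias are hypothetical names introduced here, and tf.compat.v1 is assumed to be available (TensorFlow 1.13 or later, including 2.x).

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available (TF 1.13+ or 2.x)


def load_frozen_graph(pb_path):
    """Parse a frozen GraphDef file and import it into a new, private tf.Graph."""
    graph_def = tf1.GraphDef()
    with tf1.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps node names such as 'in_x' free of an 'import/' prefix.
        tf1.import_graph_def(graph_def, name="")
    return graph


def run_out_add(graph, x_value, y_value):
    """Feed in_x and in_y by tensor name and fetch out_add."""
    with tf1.Session(graph=graph) as sess:
        return sess.run(
            "out_add:0", feed_dict={"in_x:0": x_value, "in_y:0": y_value}
        )
```

With the model.pb written in section 1.1, calling run_out_add(load_frozen_graph('./model.pb'), 8, 9) should give the same result as call_pb.py.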

 

2. saved_model

2.1 Saving a model as a SavedModel

freozen_sm.py

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='in_x')
    y = tf.placeholder(tf.int32, name='in_y')
    b = tf.Variable(1, name='b')
    m = tf.multiply(x, y)
    a = tf.add(m, b, name='out_add')

    sess.run(tf.global_variables_initializer())

    # Arguments: session, export directory, named inputs, named outputs.
    tf.saved_model.simple_save(sess, './sm', {'in_x': x, 'in_y': y}, {'out_add': a})
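One practical detail when re-running the save script: the SavedModel writer refuses to export into a directory that already exists and has content, so a second run with the same ./sm path fails. A minimal workaround, assuming it is acceptable to discard the previous export, is to delete the directory first. The helper name prepare_export_dir is hypothetical, and the demonstration uses a throwaway temporary directory, not the article's ./sm.

```python
import os
import shutil
import tempfile


def prepare_export_dir(export_dir):
    """Delete export_dir if it already exists so a fresh export can create it."""
    if os.path.isdir(export_dir):
        shutil.rmtree(export_dir)
    return export_dir


# Demonstration on a throwaway directory that imitates a previous export.
demo_dir = os.path.join(tempfile.mkdtemp(), "sm")
os.makedirs(demo_dir)
open(os.path.join(demo_dir, "saved_model.pb"), "w").close()

prepare_export_dir(demo_dir)
print(os.path.exists(demo_dir))
```

In the save script, the returned path would be passed as the export directory argument of tf.saved_model.simple_save.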

 

2.2 Loading the SavedModel

call_sm.py

import tensorflow as tf

sess = tf.Session()
# Load the graph and variable values tagged SERVING into the session.
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './sm')

# Look up the input and output tensors by the names given when the graph was built.
in_x = sess.graph.get_tensor_by_name('in_x:0')
in_y = sess.graph.get_tensor_by_name('in_y:0')
out_add = sess.graph.get_tensor_by_name('out_add:0')

ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 5})
print(ret)  # 8 * 5 + 1 = 41
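The loader also returns the MetaGraphDef, whose signature records which tensors were registered as inputs and outputs at save time. The sketch below uses that signature instead of hard-coded tensor names. It is a self-contained round trip under assumptions: it saves to a temporary directory (not the article's ./sm), the tf1 alias is introduced here, and tf.compat.v1 is assumed to be available (TensorFlow 1.13 or later, including 2.x).

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available (TF 1.13+ or 2.x)

# Hypothetical throwaway location; the directory itself must not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), "sm")

# Save the same small graph as in section 2.1.
with tf1.Session(graph=tf.Graph()) as sess:
    x = tf1.placeholder(tf.int32, name="in_x")
    y = tf1.placeholder(tf.int32, name="in_y")
    b = tf.Variable(1, name="b")
    a = tf.add(tf.multiply(x, y), b, name="out_add")
    sess.run(tf1.global_variables_initializer())
    tf1.saved_model.simple_save(
        sess, export_dir, {"in_x": x, "in_y": y}, {"out_add": a}
    )

# Load it back and resolve tensor names through the serving signature.
with tf1.Session(graph=tf.Graph()) as sess:
    meta_graph = tf1.saved_model.loader.load(
        sess, [tf1.saved_model.tag_constants.SERVING], export_dir
    )
    key = tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    signature = meta_graph.signature_def[key]

    feed = {
        signature.inputs["in_x"].name: 8,
        signature.inputs["in_y"].name: 5,
    }
    result = sess.run(signature.outputs["out_add"].name, feed_dict=feed)

print(result)
```

Going through the signature means the calling code depends only on the dictionary keys passed to simple_save, not on the internal node names of the graph.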

 


Source: https://www.cnblogs.com/vsignsoft/p/14000250.html
