首页 > 编程语言 > 详细

Tensorflow 保存模型 & 在java中调用

时间:2019-10-11 20:32:01      阅读:105      评论:0      收藏:0      [点我收藏+]

本节涉及:

  1. 保存TensorFlow 的模型供其他语言使用
  2. java中调用模型并进行预测计算

一、保存TensorFlow 的模型供其他语言使用

技术分享图片

 

 如果用户选择“y” ,则执行下面的步骤:

  • 判断程序执行目录下是否有 export 目录,如果有,调用 shutil 包中的 rmtress 函数将其删除,以免冲突
  • builder = tf .saved_model . builder . SavedModelBuilder ("export")   ———— 用于生成保存神经网络模型的对象builder,并指定保存位置为程序执行目录下的 export 子目录
  • builder.add_meta_graph_and_variables (sess,["tag"]) ———— 指定保存会话对象 sess 中的默认数据流图和可变参数(即保存模型的主要内容),并起标记名 “tag”,这个标记名 在以后被其他语言调用时会被引用
  • builder.save() ———— 保存

 

完后,会在程序执行目录下生成一个 export 子目录,其中包含了需要传递给其他语言程序的神经网络模型的相关文件。

在其他语言调用时,需要把这个文件夹 整个复制到需要使用的计算机上

二、java中调用模型并进行预测计算

 

技术分享图片

 

 调用模型文件进行预测的示例:

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;
import java.nio.FloatBuffer;
import java.util.Arrays;

public class TestTF {

    public static void main(String[] args) {  
        SavedModelBundle smb = SavedModelBundle.load("export", "tag");
        
        Session s = smb.session();
        
        float[][] matrix = {{1.0F, 2.0F, 3.0F, 4.0F}};
        System.out.println(Arrays.deepToString(matrix));  

        Tensor xFeed = Tensor.create(matrix);
        
        Tensor result = s.runner().feed("x", xFeed).fetch("y").run().get(0);
            
        FloatBuffer buf = FloatBuffer.allocate(2);
            
        result.writeTo(buf);
            
        System.out.println(result.toString());  

        System.out.println(buf.get(0));  
        System.out.println(buf.get(1));  
    }
}

技术分享图片

 

 技术分享图片

 

 技术分享图片

 

Tensorflow 保存模型 & 在java中调用

原文:https://www.cnblogs.com/expedition/p/11656593.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!