首页 > 其他 > 详细

TensorFlow使用记录 (一): 基本概念

时间:2019-10-03 16:53:13      阅读:92      评论:0      收藏:0      [点我收藏+]

基本使用

  • 使用graph来表示计算任务
  • 在被称之为Session的上下文中执行graph
  • 使用tensor表示数据
  • 通过Variable维护状态
  • 使用feed和fetch可以为任意的操作(op)赋值或者取数据

综述

TensorFlow 是一个编程系统, 使用图来表示计算任务. 图中的节点被称之为 op (operation 的缩写). 一个 op 获得 0 个或多个 Tensor, 执行计算, 产生 0 个或多个 Tensor. 每个 Tensor 是一个类型化的多维数组. 例如, 你可以将一小组图像集表示为一个四维浮点数数组, 这四个维度分别是 [batch, height, width, channels].

一个 TensorFlow 图描述了计算的过程. 为了进行计算, 图必须在 会话 里被启动. 会话 将图的 op 分发到诸如 CPU 或 GPU 之类的 设备 上, 同时提供执行 op 的方法. 这些方法执行后, 将产生的 tensor 返回. 在 Python 语言中, 返回的 tensor 是 numpy ndarray 对象; 在 C 和 C++ 语言中, 返回的 tensor 是tensorflow::Tensor 实例.

MNIST

技术分享图片
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

print(tf.__version__)

# 1. create data
mnist = input_data.read_data_sets(../MNIST_data, one_hot=True)
with tf.variable_scope(Input):
    tf_x = tf.placeholder(tf.float32, [None, 28 * 28], name=x)
    image = tf.reshape(tf_x, [-1, 28, 28, 1], name=image)
    tf_y = tf.placeholder(tf.float32, [None, 10], name=y)
    is_training = tf.placeholder(tf.bool, None)

# 2. define Network
with tf.variable_scope(Net):
    """
    "SAME" 类型的padding:
    out_height = ceil(in_height / strides[1]); ceil向上取整
    out_width = ceil(in_width / strides[2])

    "VALID"类型的padding:
    out_height = ceil((in_height - filter_height + 1) / striders[1])
    out_width = ceil((in_width - filter_width + 1) / striders[2]
    """
    conv1 = tf.layers.conv2d(inputs=image, filters=32, kernel_size=5,
                             strides=1, padding=same, activation=tf.nn.relu)  # 32x28x28
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)  # 32x14x14
    conv2 = tf.layers.conv2d(pool1, 64, 3, 1, same, activation=tf.nn.relu)  # 64x14x14
    pool2 = tf.layers.max_pooling2d(conv2, 2, 2)  # 64x7x7
    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
    fc1 = tf.layers.dense(pool2_flat, 1024, tf.nn.relu)
    fc1 = tf.layers.dropout(fc1, rate=0.5, training=is_training)
    predict = tf.layers.dense(fc1, 10, tf.nn.softmax)

# 3. define loss
with tf.name_scope(loss):
    loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=predict)
    tf.summary.scalar(loss, loss)
with tf.name_scope(accuracy):
    # tf.metrics.accuracy() 返回 [上次的accuracy, 这次的accuracy]
    accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(predict, axis=1))[1]
    tf.summary.scalar(accuracy, accuracy)

# 4. define optimizer
with tf.name_scope(train):
    optimizer = tf.train.AdamOptimizer(1e-4).minimize(loss)

# 5. initialize
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

# 6. train
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # =================
    merge_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs/train, sess.graph)
    test_writer = tf.summary.FileWriter(logs/test, sess.graph)
    # tensorboard --logdir=logs
    # =================
    for step in range(10001):
        batch_x, batch_y = mnist.train.next_batch(100)
        _, ls, train_output = sess.run([optimizer, loss, merge_op],
                                       feed_dict={tf_x: batch_x, tf_y: batch_y, is_training: True})
        if step % 100 == 0:
            acc_test, test_output = sess.run([accuracy, merge_op],
                                             feed_dict={tf_x: mnist.test.images, tf_y: mnist.test.labels,
                                                        is_training: False})
            print(Step:, step, | train loss: %.4f % ls, | test accuracy: %.2f % acc_test)
            train_writer.add_summary(train_output, step)
            test_writer.add_summary(test_output, step)
    save_path = saver.save(sess, ./cnn_mnist.ckpt)

with tf.Session() as sess:
    sess.run(init_op)
    saver.restore(sess, ./cnn_mnist.ckpt)
    acc_test = sess.run(accuracy, feed_dict={tf_x: mnist.test.images,
                                             tf_y: mnist.test.labels,
                                             is_training: False})
    print(acc_test)
View Code

 

TensorFlow使用记录 (一): 基本概念

原文:https://www.cnblogs.com/xuanyuyt/p/9690537.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!