首页 > 其他 > 详细

tensorflow笔记6:tf.nn.dynamic_rnn:的输出,output和state

时间:2018-07-11 16:52:52      阅读:726      评论:0      收藏:0      [点我收藏+]

 技术分享图片

然后经过LSTM产生两个输出,output和state

所以如果想用dynamic_rnn得到输出后,只需要最后一次的状态输出,直接调用s.h即可。

#注意,执行两次run的时候相当于执行了两次图,所以得到的结果不同;

 

import tensorflow as tf
batch_size = 4 
input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state, time_major=True) #time_major如果是True,就表示RNN的steps用第一个维度表示,建议用这个,运行速度快一点。
#如果是False,那么输入的第二个维度就是steps。
#如果是True,output的维度是[steps, batch_size, depth],反之就是[batch_size, max_time, depth]。就是和输入是一样的
#final_state就是整个LSTM输出的最终的状态,包含c和h。c和h的维度都是[batch_size, n_hidden]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print(sess.run(output))                                                                                                                                                                                                                                                   
    #print(sess.run(final_state))
    print(sess.run([output,final_state]))

输出比较:

由输出可以看出来,这个output 包含了隐含层所有时刻的输出,下面output的最后一个时间点就是state的输出;

state的输出包含了C和H,两者代表的都是当前层的最后一时刻的输出;

[array([[[ 0.11201711,  0.05266698,  0.12750182,  0.03627545,
          0.02706259, -0.11562401,  0.08477378,  0.06157489,
          0.07341921,  0.15011263],
        [-0.09552816, -0.17314027, -0.0895469 , -0.26399866,
         -0.36308575,  0.10537394, -0.09443868, -0.31130335,
          0.0132737 , -0.12810872],
        [-0.00719012,  0.04438379, -0.03804718, -0.06637346,
         -0.02082551,  0.132549  , -0.05982352,  0.11778613,
         -0.09206182,  0.02547247],
        [ 0.14723007,  0.05410767,  0.06571447,  0.06775881,
         -0.03286515,  0.31600857,  0.03567648,  0.10357846,
         -0.0679171 , -0.00785992]],

       [[ 0.06683166, -0.05883167,  0.10910213,  0.05030679,
          0.17738451,  0.00631482, -0.00457612, -0.03694798,
          0.17743434,  0.06658468],
        [-0.03385706, -0.20001511, -0.05247132, -0.14611273,
         -0.17433529,  0.14970839, -0.07725038, -0.32652032,
          0.09670977, -0.17828827],
        [ 0.03988864, -0.03118243, -0.09246919,  0.1831698 ,
         -0.01006366,  0.01672944,  0.01009638,  0.10943947,
         -0.00420897, -0.0054652 ],
        [ 0.16778645,  0.08699884,  0.12571299,  0.12276714,
          0.04879797,  0.10517071,  0.10341848,  0.15480027,
         -0.04619027,  0.11167715]],

       [[ 0.14293307, -0.10649989,  0.09144076, -0.03020415,
          0.18182378,  0.22111537, -0.02275194, -0.14586878,
          0.19310513, -0.02283864],
        [-0.0553881 , -0.16710383, -0.09584018, -0.06020959,
         -0.11862611,  0.05812657, -0.05461238, -0.21729217,
          0.08961426, -0.1420837 ],
        [ 0.03053934,  0.02213254, -0.11577073,  0.08933022,
         -0.08349261,  0.044699  ,  0.01332499,  0.14753158,
         -0.12446564,  0.00095996],
        [ 0.21244884,  0.11677884,  0.15352076,  0.04703464,
          0.07084017,  0.04610508,  0.09713535,  0.12495688,
          0.00218641,  0.17711937]]], dtype=float32), 
LSTMStateTuple(
c=array([[ 0.264239 , -0.16139928, 0.25842854, -0.05938458, 0.38918033, 0.37621742, -0.06394874, -0.263255 , 0.32704324, -0.04286532], [-0.11041687, -0.3316248 , -0.21551779, -0.12425021, -0.2452825 , 0.12507899, -0.11451716, -0.40844095, 0.20570038, -0.28551656], [ 0.0634905 , 0.05425977, -0.19805768, 0.15730162, -0.14432296, 0.09046975, 0.02406704, 0.34546444, -0.22364679, 0.00243504], [ 0.40725306, 0.25660557, 0.3873769 , 0.11941462, 0.16212168, 0.10613891, 0.1803763 , 0.26139545, 0.00540481, 0.31761324]], dtype=float32),
h=array([[ 0.14293307, -0.10649989, 0.09144076, -0.03020415, 0.18182378, 0.22111537, -0.02275194, -0.14586878, 0.19310513, -0.02283864], [-0.0553881 , -0.16710383, -0.09584018, -0.06020959, -0.11862611, 0.05812657, -0.05461238, -0.21729217, 0.08961426, -0.1420837 ], [ 0.03053934, 0.02213254, -0.11577073, 0.08933022, -0.08349261, 0.044699 , 0.01332499, 0.14753158, -0.12446564, 0.00095996], [ 0.21244884, 0.11677884, 0.15352076, 0.04703464, 0.07084017, 0.04610508, 0.09713535, 0.12495688, 0.00218641, 0.17711937]], dtype=float32))]

 

tensorflow笔记6:tf.nn.dynamic_rnn:的输出,output和state

原文:https://www.cnblogs.com/lovychen/p/9294624.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!