首页 > 其他 > 详细

Tensorflow - tf.nn.embedding_lookup使用

时间:2019-09-02 14:51:36      阅读:92      评论:0      收藏:0      [点我收藏+]
  • 原型:tf.nn.embedding_lookup(params, ids, partition_strategy=‘mod‘, name=None, validate_indices=True, max_norm=None)
  • 实际上tf.nn.embedding_lookup的作用就是找到要寻找的embedding data中的对应的行下的vector。
  • 简单地通过代码了解一下,ids为1行或者多行都囊括:
    # -*- coding= utf-8 -*-
    import tensorflow as tf
    import numpy as np
    
    a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
    a = np.asarray(a)
    idx1 = tf.Variable([0, 2, 3, 1], tf.int32)
    idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], tf.int32)
    out1 = tf.nn.embedding_lookup(a, idx1)
    out2 = tf.nn.embedding_lookup(a, idx2)
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        print sess.run(out1)
        print out1
        print ==================
        print sess.run(out2)
        print out2
  • 输出:
    [[ 0.1  0.2  0.3]
     [ 2.1  2.2  2.3]
     [ 3.1  3.2  3.3]
     [ 1.1  1.2  1.3]]
    Tensor("embedding_lookup:0", shape=(4, 3), dtype=float64)
    ==================
    [[[ 0.1  0.2  0.3]
      [ 2.1  2.2  2.3]
      [ 3.1  3.2  3.3]
      [ 1.1  1.2  1.3]]
    
     [[ 4.1  4.2  4.3]
      [ 0.1  0.2  0.3]
      [ 2.1  2.2  2.3]
      [ 2.1  2.2  2.3]]]
    Tensor("embedding_lookup_1:0", shape=(2, 4, 3), dtype=float64)

     

  • 维度讨论:在要寻找的embedding数据中下找对应的index下的vector进行拼接。永远是ids部分的维度+embedding部分的除了第一维后的维度拼接。很明显,我们也可以得到,ids里面值是必须要小于等于embedding最大维度减一的。

 

参考文章https://www.jianshu.com/p/ad88a0afa98f

Tensorflow - tf.nn.embedding_lookup使用

原文:https://www.cnblogs.com/Jesee/p/11445560.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!