tf.gather
笔者遇到一个问题:
a = tf.Variable([[1,2,3,4,5,6,7,8,9,10]])
y = tf.Variable([2])
p = tf.gather(a,y)
》》输出:
a.shape (1, 10)
a [[ 1 2 3 4 5 6 7 8 9 10]]
y.shape (1,)
y [2]
p.shape (1, 10)
p输出出错
笔者本来以为用tf.gather可以获取出获取出相应位置的参数,但是获取的并不正确。笔者想要弄懂是为什么,探讨如下:
tf.gather
gather(
params,
indices,
validate_indices=None,
name=None
)
功能是从params中根据索引获取参数,其中indices必须是一个interger参数,产生的一个输出的tensor的shape是indices.shape+params.shpe[1:]
看完官网的解释,笔者还是不太明白,故从网上找了几个tf.gether的列子,列子如下,具体参考[1],笔者做适量修改.
例子一:
params = [0,1,2]
indices = [2]
gather = tf.gather(params,indices)
with tf.Session() as sess:
print 'gather',gather
print sess.run(gather)
笔者发现想这样打印的时候,就可以打印粗来了。笔者原来的目标里是二维的。
所以现在有两个解决方式,要么迭代的方法讲二维变成一维。要么通过tf.gather从二维中获取出相应值。笔者准备采用第二种方式,因为第一种方法会给代码添加多余的操作,不利于代码的直观阅读。
tf.gather_nd
笔者在搜索过程中,查阅到tf.gather_nd函数,觉得更加适用,故改选该方法。
tf.gather_nd
gather_nd(
params,
indices,
name=None
)
由于例子非常简单,笔者不在这里详述,详情参考:[2]
笔者只解决最开始遇到的问题,将其进行改正。
改正如下:
a = tf.Variable([[1,2,3,4,5,6,7,8,9,10]])
y = tf.Variable([0,2])
p = tf.gather_nd(a,y)
笔者发现这样打印出来的东西就是一个值
a = tf.Variable([[1,2,3,4,5,6,7,8,9,10]])
y = tf.Variable([[0,2]])
p = tf.gather_nd(a,y)
如果是这样改,打印出来的东西就是一个列表。
所以到底怎么取值,完全取决于你需要什么样的数据格式。
补充:笔者在这里补充一下,笔者在一个函数中用到tf.gather_nd函数,却讲获取行号填成了[[1,*]],结果犯了错找了两个小时才找到,用的时候切忌越界。笔者心里痛痛的。
参考目录:
[1]:http://programtalk.com/python-examples/tensorflow.gather/
[2]: https://www.tensorflow.org/api_docs/python/tf/gather_nd
没有评论:
发表评论