tf.gather

gather就是按行取值:

a1 = [[1,2], [3, 4], [5, 6]]
a2 = tf.gather(tf.constant(a1), [0, 1])
print(a2)

输出:

tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)

相当于:

a1[:2]
原文地址:https://www.cnblogs.com/oaks/p/14044090.html