Understanding tf.segment_sum and tf.unsorted_segment_sum by example

This article comes from guotong1988's CSDN blog. Full text: https://blog.csdn.net/guotong1988/article/details/77622790

import tensorflow as tf  # TensorFlow 1.x API; on TF 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()

c = tf.constant([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]])

# tf.segment_sum: segment_ids must be sorted and have length 3 (= number of rows of c)
result = tf.segment_sum(c, tf.constant([0, 0, 1]))
result_ = tf.segment_sum(c, tf.constant([0, 1, 1]))
result__ = tf.segment_sum(c, tf.constant([0, 1, 2]))

# tf.unsorted_segment_sum(data, segment_ids, num_segments): ids may be in any order,
# segment_ids must again have length 3, and each id must be smaller than num_segments
result2 = tf.unsorted_segment_sum(c, tf.constant([2, 1, 1]), 3)
result3 = tf.unsorted_segment_sum(c, tf.constant([1, 0, 1]), 2)
# result4 = tf.unsorted_segment_sum(c, tf.constant([2, 0, 1]), 2)  # error: segment_ids[0] = 2 is out of range [0, 2)
result4 = tf.unsorted_segment_sum(c, tf.constant([2, 0, 1]), 3)
result5 = tf.unsorted_segment_sum(c, tf.constant([3, 1, 0]), 5)

# Evaluate every tensor in one session and print it under its variable name
with tf.Session() as sess:
    for name, tensor in [("result", result), ("result_", result_), ("result__", result__),
                         ("result2", result2), ("result3", result3),
                         ("result4", result4), ("result5", result5)]:
        print(name)
        print(sess.run(tensor))

Output:

result 
[[0 0 0 0] 
 [5 6 7 8]] 
result_ 
[[1 2 3 4] 
 [4 4 4 4]] 
result__ 
[[ 1  2  3  4] 
 [-1 -2 -3 -4] 
 [ 5  6  7  8]] 
result2 
[[0 0 0 0] 
 [4 4 4 4] 
 [1 2 3 4]] 
result3 
[[-1 -2 -3 -4] 
 [ 6  8 10 12]] 
result4 
[[-1 -2 -3 -4] 
 [ 5  6  7  8] 
 [ 1  2  3  4]] 
result5 
[[ 5  6  7  8] 
 [-1 -2 -3 -4] 
 [ 0  0  0  0] 
 [ 1  2  3  4] 
 [ 0  0  0  0]]
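To make the rule behind the first three outputs explicit, here is a minimal pure-Python model of tf.segment_sum. It is a sketch under stated assumptions, not TensorFlow's implementation: segment_sum_ref is a hypothetical helper name, the data is a list of equal-length rows, and the IDs are assumed to be non-negative.

```python
def segment_sum_ref(data, segment_ids):
    """Hypothetical pure-Python model of tf.segment_sum on a list of rows.

    Assumes non-negative, sorted (non-decreasing) segment_ids, as
    tf.segment_sum requires; the output has max(segment_ids) + 1 rows.
    """
    if len(segment_ids) != len(data):
        raise ValueError("segment_ids must have one entry per row of data")
    if any(a > b for a, b in zip(segment_ids, segment_ids[1:])):
        raise ValueError("segment_ids must be sorted")
    width = len(data[0])
    # Start with all-zero rows, so an id that never occurs yields a zero row
    out = [[0] * width for _ in range(max(segment_ids) + 1)]
    for row, seg in zip(data, segment_ids):
        # Add this row element-wise into the output row named by its segment id
        out[seg] = [acc + x for acc, x in zip(out[seg], row)]
    return out


c = [[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]]
print(segment_sum_ref(c, [0, 0, 1]))  # [[0, 0, 0, 0], [5, 6, 7, 8]]
print(segment_sum_ref(c, [0, 1, 1]))  # [[1, 2, 3, 4], [4, 4, 4, 4]]
```

Rows 0 and 1 cancel each other in the first call, which is why result starts with a zero row.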

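Rows of c that share a segment ID are summed into the output row with that index, and output rows whose index never appears in segment_ids are filled with zeros (see result2 and result5). tf.segment_sum requires sorted segment_ids and returns max(segment_ids) + 1 rows, while tf.unsorted_segment_sum accepts IDs in any order and always returns num_segments rows.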
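The unsorted variant, with its explicit num_segments, can be sketched the same way. Again this is a hypothetical pure-Python helper (unsorted_segment_sum_ref), not the library's code: it reproduces result2 to result5 and raises on an ID that is too large, mirroring the "out of range" error quoted in the code comment above. This sketch also rejects negative IDs; TensorFlow's own handling of negative IDs is not modelled here.

```python
def unsorted_segment_sum_ref(data, segment_ids, num_segments):
    """Hypothetical pure-Python model of tf.unsorted_segment_sum.

    segment_ids may be in any order; the output always has num_segments rows.
    """
    if len(segment_ids) != len(data):
        raise ValueError("segment_ids must have one entry per row of data")
    width = len(data[0])
    # One zero row per segment; segments that receive no rows stay zero
    out = [[0] * width for _ in range(num_segments)]
    for i, (row, seg) in enumerate(zip(data, segment_ids)):
        if not 0 <= seg < num_segments:
            raise ValueError(
                f"segment_ids[{i}] = {seg} is out of range [0, {num_segments})")
        # Accumulate the row into its segment, regardless of input order
        out[seg] = [acc + x for acc, x in zip(out[seg], row)]
    return out


c = [[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]]
print(unsorted_segment_sum_ref(c, [3, 1, 0], 5))
# [[5, 6, 7, 8], [-1, -2, -3, -4], [0, 0, 0, 0], [1, 2, 3, 4], [0, 0, 0, 0]]
```

Calling unsorted_segment_sum_ref(c, [2, 0, 1], 2) raises ValueError with the message "segment_ids[0] = 2 is out of range [0, 2)", matching the commented-out result4 line.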

Original post: https://www.cnblogs.com/gaofighting/p/9706081.html