tensorflow2.0——where与mask的取值操作

import tensorflow as tf

a = tf.random.normal([3,3])
print('初始a:',a)

mask = a > 0
print('mask:',mask)

bool_mask = tf.boolean_mask(a,mask)
print('bool_mask:',bool_mask)

where_mask = tf.where(mask)
print('where_mask:',where_mask)

原文地址:https://www.cnblogs.com/cxhzy/p/13490993.html