使用meshgrid生成热图和单位向量场

需求:
生成

中heatmap
unit vector field

目前的数据:
图像的shape, 关键点的x,y , heatmap的半径R

思路:
如果使用for循环来判断距离,会很慢,如果预先准备数据又会很大,所以找其他的方案。
使用矩阵运算会很简洁而且方便计算

uu = x
vv = y
out_h = 512
out_w = 512
max_dist_2d = 20.0 # R


xx, yy = tf.meshgrid(tf.range(out_h), tf.range(out_w))
xx, yy = tf.cast(xx, tf.float32), tf.cast(yy, tf.float32)
dis = tf.sqrt(tf.square(xx-uu)+tf.square(yy-vv))
hm = tf.maximum(max_dist_2d-dis, tf.zeros_like(xx))/max_dist_2d
vecx = tf.where(tf.less(dis, max_dist_2d), xx-uu, 0) / dis
vecy = tf.where(tf.less(dis, max_dist_2d), yy-vv, 0) / dis

主要用到的就是tf.meshgrid(numpy也有同功能函数)
meshgrid用于从数组a和b产生网格。生成的网格矩阵A和B大小是相同的

tf.where 定义如下: where(condition, x=None, y=None,name=None)
condition:一个Tensor,数据类型为tf.bool类型
返回值:如果x、y不为空的话,返回值和x、y有相同的形状,如果condition对应位置值为True那么返回Tensor对应位置为x的值,否则为y的值.

tf.less(类似的还有greater等)
tf.less返回两个张量各元素比较(x<y)得到的真假值组成的张量

参考资料:https://arxiv.org/abs/1711.08996

原文地址:https://www.cnblogs.com/flyuz/p/11891465.html