tf.extract_image_patches函数的理解

示例:

input_patch = tf.extract_image_patches(input, ksizes=[1, patch_sz, patch_sz, 1],
                                           strides=[1, 1, 1, 1],
                                           rates=[1, rates, rates, 1],
                                           padding="SAME")

理解:

将输入的4-D张量在每一个位置切取 patch_sz x patch_sz 大小领域的patch,获得的领域数据保存在depth维的位置(最后一维)上。
ksizes和strides、padding的设置和卷积操作保持一致。
rates表示的是取领域的patch网格在行列上的采样间隔,rate越大,patch直接感知的领域区域越大,但采样的点的间隔变大,采样点数依旧是 patch_sz*patch_sz

举例:

输入:(1,20,20,3)的图
参数:ksizes=[1, 3, 3, 1],strides=[1, 1, 1, 1],rates=[1, rates, rates, 1],padding="SAME"
输出:(1,20,20,27),patches数据在最后一维,也就是每个位置点采样了27个的数据,27对应着 depth*patch_sz*patch_sz,顺序是按照patch块的顺序,依次输出每个位置各个深度维度的数据。

原文地址:https://www.cnblogs.com/wioponsen/p/13589492.html