Pytorch torch.cat(inputs, dimension=0)

1. torch.cat(inputs, dimension=0)说明

 torch.cat用于对tensor的拼接,dim默认为0,即从第一维度拼接。表示为4维的图像tensor中,第一维默认为batchSize,第二维为channel(通道),第三维为height(图片的高),第四维为width(图片的宽),一般需要基于通道进行拼接。

2. 例子

2.1 定义输入

2.1.1 code

    # ====================================
    # 定义两个4维tensor数据:
    # (batchSize, channel, height, width),
    # 这里定义的一个是一个4维数据,可以定义其
    # 他维度的数据。
    # ====================================
    data1 = torch.rand([1, 1, 3, 3])
    data2 = torch.rand([1, 1, 3, 3])
    print("data1_shape: ", data1.shape)
    print("data1: ", data1)
    print("data2_shape: ", data2.shape)
    print("data2: ", data2)

2.1.2 输出显示

 data1_shape和data2_shape是tensor的维度信息,代表2个4维tensor。

2.2 拼接数据

2.2.1 code

   # ====================================
    # 拼接数据,可以根据dim进行调整,此处的
    # dim = 0: 代表基于batchSize拼接
    # dim = 1: 代表基于通道拼接
    # dim = 2: 代表基于高拼接
    # dim = 3: 代表基于宽拼接
    # ====================================
    data3 = torch.cat([data1, data2], dim=0)
    data4 = torch.cat([data1, data2], dim=1)
    data5 = torch.cat([data1, data2], dim=2)
    data6 = torch.cat([data1, data2], dim=3)

    print("data3_shape: ", data3.shape)
    print("data3: ", data3)

    print("data4_shape: ", data4.shape)
    print("data4: ", data4)

    print("data5_shape: ", data5.shape)
    print("data5: ", data5)

    print("data6_shape: ", data6.shape)
    print("data6: ", data6)

2.2.2 输出显示

分别从batchSize,channel,height,width进行拼接。

 

原文地址:https://www.cnblogs.com/haifwu/p/12790416.html