Pytorch——Linear(一)

1 torch.nn.Linear(in_features, out_features, bias=True)
  • 用于设置网络中的全连接层。
  • 输入与输出均是二维张量。
  • 输入与输出形状是 [batch_size, size]。
  • 之前一般使用view或nn.Flatten()将4维张量变为2维张量。

Parameters

1 in_features – size of each input sample
2 out_features – size of each output sample
3 bias – If set to False, the layer will not learn an additive bias. Default: True

Shape

Applies a linear transformation to the incoming data: y = xA^T + b

Examples

[32,  512] ——> [32, 128]

1 import torch
2 import torch.nn as nn
3 
4 input = torch.randn(32, 512)
5 linear = nn.Linear(512, 128)
6 print(linear(input).size())

[32, 128] ——> [32, 32] 

1 import torch
2 import torch.nn as nn
3 
4 input = torch.randn(32, 128)
5 linear = nn.Linear(128, 32)
6 print(linear(input).size())
原文地址:https://www.cnblogs.com/timelesszxl/p/14549359.html