PyTorch: Flatten (Part 1)

torch.nn.Flatten(start_dim=1, end_dim=-1)

Parameters

start_dim: first dim to flatten (default = 1).
end_dim: last dim to flatten (default = -1).
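A short sketch of how the two parameters change the result. The input shape (2, 3, 4, 5) is an arbitrary choice for illustration. Dimensions start_dim through end_dim (inclusive) are merged into one, and negative values count from the end, as in ordinary Python indexing.

```python
import torch
import torch.nn as nn

# Arbitrary 4-D tensor: (batch=2, channels=3, height=4, width=5)
x = torch.randn(2, 3, 4, 5)

# Defaults start_dim=1, end_dim=-1: keep dim 0, merge dims 1..3 (3*4*5 = 60)
print(nn.Flatten()(x).shape)                          # torch.Size([2, 60])

# start_dim=0: the batch dimension is merged as well (2*3*4*5 = 120)
print(nn.Flatten(start_dim=0)(x).shape)               # torch.Size([120])

# Merge only dims 1 and 2 (3*4 = 12); dims 0 and 3 are kept
print(nn.Flatten(start_dim=1, end_dim=2)(x).shape)    # torch.Size([2, 12, 5])

# Negative indices: merge the last two dims (4*5 = 20)
print(nn.Flatten(start_dim=-2, end_dim=-1)(x).shape)  # torch.Size([2, 3, 20])
```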

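Shape

Input: (*, S_start, ..., S_i, ..., S_end, *), where S_i is the size of dimension i and * means any number of dimensions, including none.
Output: (*, S_start × ... × S_end, *), i.e. the dimensions from start_dim to end_dim are replaced by the product of their sizes.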
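The shape rule can be written as a small function: dimensions before start_dim and after end_dim are kept, and dimensions start_dim through end_dim are replaced by their product. The helper expected_flatten_shape below is hypothetical (it is not part of PyTorch); it only predicts the output shape and checks the prediction against nn.Flatten for a few parameter combinations.

```python
import math
import torch
import torch.nn as nn

def expected_flatten_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: predict nn.Flatten's output shape.

    Dims before start_dim and after end_dim are kept unchanged;
    dims start_dim..end_dim (inclusive) collapse into their product.
    """
    ndim = len(shape)
    # Normalize negative indices the way Python indexing does (-1 -> ndim - 1)
    start = start_dim % ndim
    end = end_dim % ndim
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

x = torch.randn(32, 64, 56, 56)
for start_dim, end_dim in [(1, -1), (0, -1), (1, 2), (2, 3)]:
    predicted = expected_flatten_shape(tuple(x.shape), start_dim, end_dim)
    actual = tuple(nn.Flatten(start_dim, end_dim)(x).shape)
    print(start_dim, end_dim, predicted, actual)
    assert predicted == actual
```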

Examples

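[32, 64, 56, 56] -> [32, 200704]

With the default start_dim=1, the batch dimension (32) is kept and the remaining three dimensions are merged: 64 * 56 * 56 = 200704.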

import torch
import torch.nn as nn

x = torch.randn(32, 64, 56, 56)  # (batch, channels, height, width)
flatten = nn.Flatten()           # defaults: start_dim=1, end_dim=-1
print(flatten(x).size())         # torch.Size([32, 200704])
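A sketch of two related points, under assumptions of my own for the model part. First, nn.Flatten() with its defaults gives the same result as torch.flatten(x, start_dim=1) and x.reshape(x.size(0), -1); note that the functional torch.flatten defaults to start_dim=0, so start_dim=1 has to be passed explicitly to keep the batch dimension. Second, Flatten is commonly placed between convolutional layers and a fully connected layer; the layer sizes and the 28x28 input below are illustrative only and do not come from the original post.

```python
import torch
import torch.nn as nn

x = torch.randn(32, 64, 56, 56)
out = nn.Flatten()(x)  # defaults: start_dim=1, end_dim=-1

# Equivalent formulations; torch.flatten defaults to start_dim=0,
# so start_dim=1 is passed to keep the batch dimension
assert torch.equal(out, torch.flatten(x, start_dim=1))
assert torch.equal(out, x.reshape(x.size(0), -1))

# Illustrative model (sizes are assumptions): conv features -> Flatten -> Linear
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # (N, 3, 28, 28) -> (N, 8, 28, 28)
    nn.ReLU(),
    nn.Flatten(),                                # (N, 8, 28, 28) -> (N, 6272)
    nn.Linear(8 * 28 * 28, 10),                  # (N, 6272) -> (N, 10)
)
images = torch.randn(4, 3, 28, 28)
print(model(images).shape)  # torch.Size([4, 10])
```

Because the Linear layer's in_features must equal the flattened size, computing it as channels * height * width (here 8 * 28 * 28 = 6272) is the usual way to size it.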
Original post: https://www.cnblogs.com/timelesszxl/p/14549676.html