# 使用 thop 的 profile 统计模型的计算量和参数量 (count a model's ops and parameters with thop.profile)

import torch
from thop import profile
from torch import nn

 class Test(nn.Module):
     """Toy model made of a single fully connected layer, used to demo profiling.

     Args:
         input_size: Number of input features (``in_features`` of the Linear layer).
         output_szie: Number of output features (``out_features``). The name keeps
             its original (misspelled) form so existing keyword callers still work.
     """

     def __init__(self, input_size: int, output_szie: int) -> None:
         # Zero-argument super() is the Python 3 form of super(Test, self).
         super().__init__()
         self.out = nn.Linear(input_size, output_szie)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply the linear layer to ``x``.

         The last dimension of ``x`` must equal ``input_size``. The result has
         the same leading dimensions, with the last one replaced by ``output_szie``.
         """
         return self.out(x)

t = Test(10, 2)
x = torch.randn(4, 10)  # a batch of 4 samples, 10 features each

# profile() returns an (ops, params) pair. Keep it instead of discarding it,
# otherwise the numbers are only visible in an interactive session.
flops, params = profile(t, (x,), verbose=False)
print(flops, params)  # expected: 80.0 22.0
# params: 10*2 weights + 2 biases = 22
# flops:  4*10*2 = 80 (batch * in_features * out_features); judging by this
#         value, the bias additions are not included in the count.

# NOTE(review): the two lines below look like running-total updates for
# aggregating the results of several profile() calls — confirm intent.
# total_flops += flops
# model_params_num += params
# 原文地址 (original article): https://www.cnblogs.com/douzujun/p/13875078.html