Statistical analysis of PyTorch models (per-layer shapes, parameter counts, FLOPs)

Related tools:

1. torchsummary: print the output shape of each layer of a PyTorch model

sksq96/pytorch-summary: Model summary in PyTorch similar to `model.summary()` in Keras (github.com)

How to install

pip install torchsummary

How to use

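import torch.nn as nn
from torchsummary import summary

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))
# input_size excludes the batch dimension; device defaults to "cuda", so pass "cpu" without a GPU
summary(model, (1, 28, 28), device="cpu")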
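
torchsummary builds its table by registering forward hooks on the model's layers and running a dummy forward pass. If installing the package isn't an option, the same idea can be sketched in plain PyTorch. This is a minimal sketch, not torchsummary's code: the helper name `layer_shapes` is hypothetical, it hooks only leaf modules, and it assumes every leaf layer returns a single tensor.

```python
import torch
import torch.nn as nn

def layer_shapes(model, input_size, batch_size=1):
    """Return (name, output shape, own parameter count) for every leaf module."""
    rows, handles = [], []
    for name, module in model.named_modules():
        if next(module.children(), None) is not None:
            continue  # skip containers such as nn.Sequential; hook leaf layers only
        # bind `name` as a default argument so each hook keeps its own layer name
        def hook(mod, inputs, output, name=name):
            n_params = sum(p.numel() for p in mod.parameters())
            rows.append((name, tuple(output.shape), n_params))
        handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(torch.zeros(batch_size, *input_size))  # dummy forward pass
    for handle in handles:
        handle.remove()  # leave the model without our hooks
    return rows

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))
for name, shape, n_params in layer_shapes(model, (1, 28, 28)):
    print(f"{name:<4} {str(shape):<20} {n_params}")
```

Here the batch dimension shows up as 1, whereas torchsummary displays it as -1 by default.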

2. THOP: count the FLOPs and parameters of a PyTorch model

Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model. (github.com)

How to install

pip install thop (now continuously integrated on GitHub Actions)

OR

pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

How to use

  • Basic usage

    import torch
    from torchvision.models import resnet50
    from thop import profile

    model = resnet50()
    input = torch.randn(1, 3, 224, 224)  # dummy batch of one 224x224 RGB image
    macs, params = profile(model, inputs=(input, ))
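  • Define the counting rule for a third-party (custom) module.

    import torch
    import torch.nn as nn
    from thop import profile

    class YourModule(nn.Module):  # hypothetical custom layer
        def forward(self, x):
            return x * torch.sigmoid(x)

    def count_your_model(model, x, y):
        # x: tuple of inputs, y: output; add this layer's ops to its total_ops buffer
        model.total_ops += torch.DoubleTensor([int(y.numel())])

    model = nn.Sequential(nn.Conv2d(3, 8, 3), YourModule())
    input = torch.randn(1, 3, 224, 224)
    macs, params = profile(model, inputs=(input, ),
                           custom_ops={YourModule: count_your_model})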
  • Improve the output readability

    Call thop.clever_format to turn the raw counts into readable strings with unit suffixes such as M and G.

    from thop import clever_format
    macs, params = clever_format([macs, params], "%.3f")
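
Note that the README names the first return value of `profile` `macs`: it counts multiply-accumulate operations, and a common convention is FLOPs ≈ 2 × MACs. As a hedged sanity check on THOP's numbers, here is a stdlib sketch of the usual Conv2d arithmetic (one MAC per kernel weight per output element, bias ignored for MACs) plus a K/M/G formatter in the spirit of `clever_format`. Both helpers (`conv2d_stats`, `human_format`) are hypothetical and not THOP's source; THOP's exact thresholds and rounding may differ.

```python
def conv2d_stats(c_in, c_out, kernel, h_out, w_out, groups=1, bias=False):
    """Parameter count and MACs of one Conv2d layer for a single input image."""
    weights = c_out * (c_in // groups) * kernel * kernel
    params = weights + (c_out if bias else 0)
    # every output element needs (c_in // groups) * kernel * kernel multiply-adds,
    # and there are c_out * h_out * w_out output elements
    macs = weights * h_out * w_out
    return params, macs

def human_format(num, fmt="%.3f"):
    """Scale a count to a K/M/G/T suffix (decimal powers of 1000)."""
    for factor, suffix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")):
        if num >= factor:
            return fmt % (num / factor) + suffix
    return fmt % num

# ResNet-50 conv1: 3 -> 64 channels, 7x7 kernel, stride 2, no bias, 112x112 output
params, macs = conv2d_stats(3, 64, 7, 112, 112)
print(human_format(params), human_format(macs), human_format(2 * macs))
# 9.408K 118.014M 236.028M
```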

3. ptflops: Flops counter for convolutional networks in the PyTorch framework

sovrasov/flops-counter.pytorch: Flops counter for convolutional networks in pytorch framework (github.com)

How to install

From PyPI:

pip install ptflops

From this repository:

pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git

How to use

import torch
import torchvision.models as models
from ptflops import get_model_complexity_info

# selects GPU 0 as the current CUDA device; remove the with-block on a CPU-only machine
with torch.cuda.device(0):
    net = models.densenet161()
    # input_res (3, 224, 224) excludes the batch dimension; as_strings=True returns formatted strings
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
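
The parameter totals reported by any of these tools can be cross-checked directly in PyTorch by summing `numel()` over `model.parameters()`, which helps when a tool's count looks off for a custom layer. A minimal sketch; the helper `count_params` and the toy model are illustrative only.

```python
import torch.nn as nn

def count_params(model, trainable_only=False):
    """Sum of parameter elements; optionally only those with requires_grad=True."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
model[0].weight.requires_grad_(False)  # freeze the conv weights for the demo
print(count_params(model))                       # 242
print(count_params(model, trainable_only=True))  # 26
```

Calling `count_params(net)` on the densenet161 above gives a number to compare against ptflops' "Number of parameters" line.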

Original post: https://www.cnblogs.com/lucifer1997/p/14070453.html