查询模型参数总数

numel

参考:torch.numel(input) → int

函数numel作用是返回输入张量的元素总数

>>> import torch
>>>  
>>> a = torch.randn((2,3,4))
>>> a.shape
torch.Size([2, 3, 4])
>>> torch.numel(a)
24
>>> a.numel()
24

查询模型参数总数

参考:5.查看网络总参数

net = Model()
print('# Model parameters:', sum(param.numel() for param in net.parameters()))