[译]保存和加载模型¶
文章Saving and Loading Models详细的说明了几种保存和加载模型的使用方法。通篇看下去发现干货满满,没啥可以精简的地方,不如翻译之
引言¶
本文档提供了关于保存和加载PyTorch
模型的各种用例的解决方案。请随意阅读整个文档,或者直接跳到所需用例所需的代码。谈到保存和加载模型,有三个核心功能需要熟悉:
- torch.save: 将序列化对象保存到磁盘。该函数使用
Python
的pickle
程序进行序列化。使用此功能可以保存各种对象的模型、张量和字典 - torch.load: 使用
pickle
工具将目标文件反序列化到内存中。该功能还便于设备将数据加载到其中(参考章节跨设备保存和加载模型
) - torch.nn.Module.load_state_dict: 使用反序列化的
state_dict
加载模型的参数字典。参考章节什么是state_dict?
内容列表¶
- 什么是
state_dict?
- 保存和加载用于推断的模型
- 保存和加载常规检查点
- 在一个文件中保存多个模型
- 使用不同模型参数的热启动模型
- 跨设备保存和加载模型
什么是state_dict?¶
在PyTorch
中,torch.nn.Module
模型的可学习参数(即权重和偏差)包含在模型的参数中(通过model.parameters()
访问)。state_dict
(状态字典)只是一个Python
字典对象,保存了每一层映射的参数张量。请注意,只有具有可学习参数的层(卷积层、线性层等)和注册缓冲区(batchnorm
的running_mean
)在模型的state_dict
中有条目。优化器对象(torch.optim
)也有一个state_dict
,它包含关于优化器状态的信息,以及所使用的超参数
因为state_dict
对象是Python
字典,所以它们可以很容易地保存、更新、更改和恢复,为PyTorch
模型和优化器增加了大量的模块性
示例¶
查看Training a classifier教程中的模型的state_dict
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139742797940768, 139742797940688, 139742797940848, 139742797940928, 139742797941008, 139742797941088, 139742797941168, 139742797941248, 139742797941328, 139742797941408]}]
保存和加载用于推断的模型¶
保存/加载state_dict(推荐)¶
保存操作:
torch.save(model.state_dict(), PATH)
加载操作:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
保存模型进行推理时,只需保存训练好的模型的学习参数。用torch.save()
函数保存模型的state_dict
将为以后恢复模型提供最大的灵活性,这就是为什么它是保存模型的推荐方法。一个常见的PyTorch
约定是使用. pt
或.pth
文件扩展名。请记住,在运行推理之前,您必须调用model.eval()
将dropout
层和batch normalization
层设置为评估模式。不这样做将产生不一致的推理结果
注意:load_state_dict()
函数接受字典对象,而不是保存对象的路径。这意味着在将保存的state_dict
传递给load_state_dict()
函数之前,必须对其进行反序列化。例如,您不能使用model.load_state_dict(PATH)
进行加载
保存/加载完整模型¶
保存操作:
torch.save(model, PATH)
加载操作:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
这个保存/加载过程使用最直观的语法,涉及的代码量最少。以这种方式保存模型将使用Python
的pickle
模块保存整个模块。这种方法的缺点是序列化数据绑定到特定的类和保存模型时使用的确切目录结构。这是因为pickle
没有保存模型类本身。相反,它保存了包含类的文件的路径,该路径在加载时使用。因此,当在其他项目中使用或重构后,您的代码可能会以各种方式中断
注意 1:同样的,常用的文件规范是.pt
或.pth
注意 2:如果调用了dropout
层或者batch normalization
层,需要调用函数model.eval()
切换成评估模式
保存和加载用于推理或恢复训练的常规检查点¶
保存操作:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载操作:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
当保存用于推理或恢复训练的常规检查点时,必须保存的不仅仅是模型的state_dict
,还包括保存优化器的state_dict
,因为它包含随着模型训练而更新的缓冲区和参数。可能想要保存的其他项目包括中断的时期、最新记录的训练损失、外部的torch.nn.Embedding
层等。要保存多个组件,请在字典中组织它们,并使用torch.save()
序列化字典。一个常见的PyTorch
约定是使用.tar
文件扩展名
要加载项目,首先初始化模型和优化器,然后使用torch.load()
在本地加载字典。从这里,您可以像预期的那样通过简单地查询字典来轻松地访问保存的项目
注意:如果要用于推理,需要调用model.eval()
函数设置模型为评估模式;如果要用于恢复训练,需要调用model.train()
以确保模型处于训练模式
在一个文件中保存多个模型¶
保存操作:
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
加载操作:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
保存由多个torch.nn.Module
组成的模型时,如GAN
、序列到序列模型或模型集合,您可以遵循与保存常规检查点时相同的方法。换句话说,保存每个模型的state_dict
和相应的优化器。如前所述,您可以通过简单地将其他任何有助于恢复训练的项目添加到字典中来保存它们。文件命名规范同样是使用.tar
文件扩展
要加载模型,首先初始化模型和优化器,然后使用torch.load()
在本地加载字典。同样的,如果要用于推理,需要调用model.eval()
函数设置模型为评估模式;如果要用于恢复训练,需要调用model.train()
以确保模型处于训练模式
使用不同模型参数的热启动模型¶
保存操作:
torch.save(modelA.state_dict(), PATH)
加载操作:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
当迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见的场景。利用经过训练的参数,即使只有少数参数可用,也将有助于启动训练过程,并有望帮助您的模型比从头开始的训练收敛得快得多。无论您是从缺少某些键的部分状态字典加载,还是加载的状态字典的键比您要加载的模型多,您都可以在load_state_dict()
函数中将strict
参数设置为False
以忽略不匹配的键。如果想将参数从一个层加载到另一个层,但有些键不匹配,只需在加载的state_dict
中更改参数键的名称,以匹配待加载到的模型中的键即可
跨设备保存和加载模型¶
在GPU中保存,在CPU中加载¶
保存操作:
torch.save(model.state_dict(), PATH)
加载操作:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
当在CPU
中加载模型而在GPU
中训练时,将torch.device('cpu')
传入到torch.load()
函数的参数map_location
中。在这种情况下,张量下面的存储使用map_location
参数动态地重新映射到CPU
设备
在GPU中保存,在GPU中加载¶
保存操作:
torch.save(model.state_dict(), PATH)
加载操作:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
当模型在GPU
中进行训练和保存,然后重新加载到GPU
中时,只需将初始model
使用函数model.to(torch.device('cuda'))
转换成CUDA
优化模型即可。注意,确保所有模型输入数据均调用了.to(torch.device('cuda'))
注意,调用my_tensor.to(device)
返回的是my_tensor
在GPU
上的副本,并不会重写my_tensor
。所以需要手动重写张量:my_tensor = my_tensor.to(torch.device('cuda'))
在CPU中保存,在GPU中加载¶
保存操作:
torch.save(model.state_dict(), PATH)
加载操作:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
模型在CPU
中进行训练和保存,在GPU
中进行加载时,设置torch.load()
函数参数map_location
为cuda:device_id
,这样能在指定GPU
设备上加载模型。接下来,确保调用model.to(torch.device('cuda'))
转换模型的参数张量到CUDA
张量。确保对所有模型输入数据调用.to(torch.device('cuda'))
函数。
同样注意,调用my_tensor.to(device)
返回的是my_tensor
在GPU
上的副本,并不会重写my_tensor
。所以需要手动重写张量:my_tensor = my_tensor.to(torch.device('cuda'))
保存torch.nn.DataParallel模型¶
保存操作:
torch.save(model.module.state_dict(), PATH)
加载操作:
# Load to whatever device you want
torch.nn.DataParallel
是一个模型包装器,支持并行GPU
的使用。为了更通用地使用DataParallel
模型,保存模型的model.module.state_dict()
。这样就可以灵活地将模型加载到任何想要的设备上