Pytorch 保存模型
Pytorch 保存、加载模型的方法
Pytorch 保存、加载模型的官方推荐方法?
- PATH 是保存文件的路径,并且需要指定保存文件的文件名,在 pytorch1.6 版本中,torch.save 存储的文件格式采用了新的基于压缩文件的格式 .pth.tar,torch.load 依然保留了加载了旧格式.pth 的能力
1
2
3
4
5
6
7
8
9
10#第一种:只存储模型中的参数,该方法速度快,占用空间少(官方推荐使用)
model = VGGNet()
torch.save(model.state_dict(), PATH) # 存储model中的参数
new_model = VGGNet() #建立新模型
new_model.load_state_dict(torch.load(PATH)) #将model中的参数加载到new_model中
#第二种:存储整个模型
model = VGGNet()
torch.save(model, PATH) #存储整个模型
new_model = torch.load(PATH) #将整个model加载到new_model中
#new_model 不再需要第一种方法中的建立新模型的步骤
Pytorch 如何保存 checkpoint(检查点)?
- 每隔一段时间就将训练模型信息保存一次, 这些信息不光包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数
1
2
3
4
5
6
7
8
9
10
11
12state = {
'epoch' : epoch + 1, #保存当前的迭代次数
'state_dict' : model.state_dict(), #保存模型参数
'optimizer' : optimizer.state_dict(), #保存优化器参数
..., #其余一些想保持的参数都可以添加进来
}
torch.save(state, 'checkpoint.pth.tar') # 将state中的信息保存到checkpoint.pth.tar
#Pytorch 约定使用.tar格式来保存这些检查点,当想恢复训练时
checkpoint = torch.load('checkpoint.pth.tar')
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict']) #加载模型的参数
optimizer.load_state_dict(checkpoint['optimizer']) #加载优化器的参数
Pytorch 在不同设备上的存储与加载?
- 单 GPU 与 CPU
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23#1、在CPU上存储模型,在GPU上加载模型
#CPU存储
torch.save(model.state_dict(), PATH)
#GPU加载
device = torch.device('cuda')
model = Model()
model.load_state_dict(torch.load(PATH, map_location='cuda:0')) #可以选择任意GPU设备
model.to(device)
#2、在GPU上存储,CPU上加载
#GPU存储
torch.save(model.state_dict(), PATH)
#CPU加载
device = torch.device('cpu')
model = Model()
model.load_state_dict(torch.load(PATH, map_location=device))
#3、在GPU上存储,在GPU上加载
#GPU存储
torch.save(model.state_dict(), PATH)
#GPU加载
device = torch.device('cuda')
model = Model()
model.load_state_dict(torch.load(PATH))
model.to(device) - 多 GPU
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27#(1)多卡训练,单卡加载部署
# 多卡训练保存的参数名称是module.conv1.weight,而单卡的参数名称是conv1.weight,直接加载会报错,提示找不到相应的字典的错误。此时可以通过手动的方式删减掉模型中前几位的名称,然后重新加载
model = torch.nn.DataParallel(model)
#存储
torch.save(model.module.state_dict(), PATH)
#加载
kwargs={'map_location':lambda storage, loc: storage.cuda(gpu_id)}
def load_GPUS(model,model_path,kwargs):
state_dict = torch.load(PATH, **kwargs)
# create new OrderedDict that does not contain 'module.'
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove 'module.'
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
return model
#(2)单卡训练,多卡加载部署
# 单卡训练参数是没有module的,而多卡加载的参数是有module的,因此需要保证参数加载在模型分发之前
#存储
torch.save(model.state_dict(), PATH)
#加载
model.load_state_dict(torch.load(PATH))
model = torch.nn.DataParallel(model) #模型分发
#(3)多卡训练,多卡加载部署
# 环境如果没有变化,则可以直接加载,如果环境有变化,则可以拆解成第1种情况,然后再分发模型
Pytorch 加载特定参数?
- 使用 resnet50、resnet101 等网络时,仅使用部分或修改了结构,如果仅仅想加载未被修改部分的权重,应该遍历 model.state_dict () 逐层进行权重加载
1
2
3
4
5
6
7model = OurModel()
model_checkpoint = torch.load('checkpoint.pth.tar')
pretrain_model_dict = model_checkpoint['state_dict']
model_dict = model.state_dict()
same_model_dict = {k : v for k, v in pretrain_model_dict if k in model_dict}
model_dict.update(same_model_dict)
model.load_state_dict(model_dict)
model.state_dict () 存储什么内容?
- model.state_dict () 返回的是一个 OrderedDict 对象。OrderedDict 是 dict 的子类, 这个字典中的 key-value 对是有顺序的。这个特性正好可以跟网络结构的层次性对应起来
1
2OrderedDict([('conv1.weight', TensorValue), ('conv1.bias', TensorValue),
('conv2.weight', TensorValue), ('conv2.bias', TensorValue)]) - 可以通过遍历 model.state_dict () 的方式,获得各层参数值,但值得注意的是,只有那些参数可以训练的层,这些层的参数会被保存到 model.state_dict () 中,如卷积层,线性层,没有参数的层,不会被保存
1
2
3
4
5
6for k, v in model.state_dict().items():
print('k = ', k, '; ', 'v.size = ', v.size()) #为了直观,输出参数的尺寸大小
#k = conv1.weight ; v.size = torch.Size([3, 3, 3, 3])
#k = conv1.bias ; v.size = torch.Size([3])
#k = conv2.weight ; v.size = torch.Size([3, 3, 3, 3])
#k = conv2.bias ; v.size = torch.Size([3])
Optimizer.state_dict () 存储什么内容?
- 保存了优化器的状态以及被使用的超参数
1
2
3
4
5optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, )
print(optimizer.state_dict())
#输出
{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0,
'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}