一、state_dict方式(推荐)

torch.save(model.state_dict(), PATH)

model = YourModel()
model.load_state_dict(torch.load(PATH))
model.eval()

记住一定要使用model.eval()来固定dropout归一化层,否则每次推理会生成不同的结果。

二、整个模型(结构+state_dict)方式

torch.save(model, PATH)

model = torch.load(PATH)
model.eval()

这种保存/加载模型的过程使用了最直观的语法,所用代码量少。这使用Python的pickle保存所有模块。这种方法的缺点是,保存模型的时候,序列化的数据被绑定到了特定的类和确切的目录。这是因为pickle不保存模型类本身,而是保存这个类的路径,并且在加载的时候会使用。因此,当在其他项目里使用或者重构的时候,加载模型的时候会出错。
记住一定要使用model.eval()来固定dropout归一化层,否则每次推理会生成不同的结果。

三、cptk方式

当我们在训练的时候,因为一些原因导致训练终止了,这个时候如果我们不想再浪费时间从头开始训练,就可以使用cptk的方式。这种方式不仅可以保存模型的state_dict,还可以保存训练中断时的训练的epoch,loss,优化器的state_dict等信息。

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)


model = yourModel()
optimizer = yourOptimizer()

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - 或者 -
model.train()

示例
【深度学习实战(9)】三种保存和加载模型的方式-LMLPHP

04-20 00:13