以下代码等价,但第一种慢,第二种快:

import torchvision.models as models

self.resnet = models.resnet18(pretrained=True)  #  联网下载,慢
self.model = models.resnet18(pretrained=False)
state_dict = torch.load('src/model/resnet18-5c106cde.pth')  # 自己从网上下载.pth,快
self.model.load_state_dict(state_dict)  # 再把读出来的参数放进没有参数的模型

当pretrained=True,才会联网下载模型,否则很快,仅得到一个没训练过的模型。

.pth文件或者state_dict变量:模型参数,里面是模型每一层具体的浮点数

model:模型,不含参数

model和.pth如果是对应的,就可以用model.load_state_dict加载。注意这条语句是在模型上直接修改,不应写成model = model.load_state_dict。

所以我们可以自己在浏览器下载模型,然后加载进去。那么去哪里下载呢?Ctrl+函数打开源码自己就可以找到。

Logo

更多推荐