timm从本地加载预训练模型
想要从timm加载本地预训练模型,首先是参考。
·
想要从timm加载本地预训练模型,首先是参考timm.create_model()从本地加载pretrained模型
将
model = timm.create_model('modelxxx', pretrained=True, xxx)
改为
pretrained_cfg = timm.models.create_model("modelxxx").default_cfg
pretrained_cfg['file'] = 'path/to/checkpoint'
model = timm.models.create_model("modelxxx", pretrained=True, xxx, pretrained_cfg=pretrained_cfg))
但是遇到错误
AssertionError: pretrained_cfg should not be set when sourcing model from Hugging Face Hub.
后来参考LocalEntryNotFoundError when loading downloaded pretrained model using timm.create_model (side load offline weights, e.g. on Kaggle) #1826成功将从本地加载预训练模型
timm.create_model(
'modelxxx',
pretrained=True,
pretrained_cfg_overlay=dict(file='path/to/checkpoint'),
)
---------------------xxx-----------------------
从本地加载
model = timm.create_model('modelxxx', pretrained=False, xxx) # pretrained=True —> False
model.load_state_dict(torch.load(pth_local_path), strict=True)
更多推荐
所有评论(0)