目录

前言

我们知道一个良好的权重初始化,可以使收敛速度加快,甚至可以获得更好的精度。而在实际应用中,我们通常采用一个已经训练好模型的权重参数作为我们模型的初始化参数,也称为 Finetune,更宽泛的称之为迁移学习。迁移学习中的 Finetune 技术,本质上就是让我们新构建的模型,拥有一个较好的权重初始值。

  1. 为什么要 Model Finetune?

一般来说,需要模型微调的任务都有如下特点:
在新任务中数据量较小,不足以从头训练一个较大的 Model。此时可以用 Model Finetune 的方式辅助我们在新任务中训练出一个较好的模型,并让训练过程更快。

  2. 模型微调的步骤

  • 第一步:保存模型,拥有一个预训练模型;
  • 第二步:加载模型,把预训练模型中的权值取出来;
  • 第三步:初始化,将权值对应地放到新的模型中。
  3. 模型微调训练方法

由于需要保留 Features Extractor 的结构和参数,常用以下两种训练方法:

  • 固定预训练的参数:设置 requires_grad = False 或者 lr = 0,即不更新这些参数;
  • 将 Features Extractor 部分设置很小的学习率,这里用到参数组(param_groups)的概念,分组设置优化器的参数。
  4. 示例(finetune_resnet18)

4.1 不使用trick:所有的参数使用同一个学习率

# -*- coding: utf-8 -*-
""" 模型finetune方法,方法一:使用同一个学习率 """
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt

import sys
libo_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(libo_DIR)
 
from tools.my_dataset import PubuDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
# Directory of this script; data and checkpoint paths below are resolved relative to it.
BASEDIR = os.path.dirname(os.path.abspath(__file__))

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"use device :{device}")

# Project helper from tools.common_tools; presumably seeds the Python/NumPy/torch RNGs
# for reproducibility -- TODO confirm against its implementation.
set_seed(1)
# Class name -> label index for the two classes (not used by the code in this section).
label_name = dict(ants=0, bees=1)

# ---- hyper-parameters ----
# optimisation schedule
MAX_EPOCH = 25        # epochs run from start_epoch + 1 up to MAX_EPOCH - 1
start_epoch = -1      # i.e. training starts at epoch 0
LR = 0.001            # base learning rate
lr_decay_step = 7     # StepLR period (the scheduler is stepped once per epoch)
BATCH_SIZE = 16
# reporting cadence
log_interval = 10     # print training stats every N iterations
val_interval = 1      # validate every N epochs
# task
classes = 2           # output width of the replacement fc layer

# ============================ step 1/5 data ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "data", "pubu"))
# Fail early with a hint on where to put the dataset. FileNotFoundError subclasses
# Exception, so any existing `except Exception` handler still catches it; isdir (rather
# than exists) also rejects a stray file that happens to have the same name.
if not os.path.isdir(data_dir):
    raise FileNotFoundError("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(
        data_dir, os.path.dirname(data_dir)))

train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

# Per-channel normalisation statistics (the standard ImageNet values, presumably chosen
# to match the preprocessing the pretrained backbone was trained with).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: random resized crop + horizontal flip for augmentation, output 224x224.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation: deterministic resize + center crop to the same 224x224 input size.
valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Dataset instances (PubuDataset comes from tools.my_dataset).
train_data = PubuDataset(data_dir=train_dir, transform=train_transform)
valid_data = PubuDataset(data_dir=valid_dir, transform=valid_transform)

# DataLoaders; only the training set is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
 
# ============================ step 2/5 model ============================
# 1/3 build the (randomly initialised) ResNet-18
resnet18_ft = models.resnet18()

# 2/3 load the pretrained parameters
path_pretrained_model = os.path.join(BASEDIR, "..", "data", "resnet18-5c106cde.pth")
# map_location="cpu" lets the checkpoint load on CPU-only machines even if it was saved
# from GPU tensors; the whole model is moved to `device` below anyway.
state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
resnet18_ft.load_state_dict(state_dict_load)  # overwrite the random init with the pretrained weights

# 3/3 replace the fc layer
num_ftrs = resnet18_ft.fc.in_features  # input width of the original fc layer
resnet18_ft.fc = nn.Linear(num_ftrs, classes)  # new head: same input width, `classes` outputs

resnet18_ft.to(device)

# ============================ step 3/5 loss ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
# Baseline: every parameter (backbone and new head) shares the same learning rate.
optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)
# Multiply the LR by 0.1 every `lr_decay_step` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
 
# ============================ step 5/5 training ============================
train_curve = list()  # training loss of every iteration
valid_curve = list()  # mean validation loss of every validation pass
for epoch in range(start_epoch + 1, MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.
    resnet18_ft.train()
    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)
        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        # update weights
        optimizer.step()
        # Accuracy bookkeeping: detach() is the supported way to drop the autograd graph
        # (Tensor.data is discouraged), and .item() yields a plain Python number.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # logging: read the scalar loss once per iteration
        batch_loss = loss.item()
        loss_mean += batch_loss
        train_curve.append(batch_loss)
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.
            # Show one backbone kernel so you can see whether it changes during training.
            print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # advance the learning-rate schedule once per epoch

    # validate the model
    if (epoch + 1) % val_interval == 0:
        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        n_val_batches = 0
        resnet18_ft.eval()
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                loss_val += loss.item()
                n_val_batches += 1
        # An empty loader would otherwise surface as an opaque ZeroDivisionError.
        if n_val_batches == 0:
            raise RuntimeError("valid_loader produced no batches; check the validation data")
        loss_val_mean = loss_val / n_val_batches  # mean of the per-batch losses
        valid_curve.append(loss_val_mean)
        print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
            epoch, MAX_EPOCH, n_val_batches, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

# Plot the loss curves. Validation losses are recorded once per validation pass, so
# their x positions are converted from a pass count to an iteration count.
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

输出结果为:

use device :cpu
Training:Epoch[000/025] Iteration[010/016] Loss: 0.6572 Acc:60.62%
epoch:0 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
Valid:	 Epoch[000/025] Iteration[010/010] Loss: 0.4565 Acc:84.97%
Training:Epoch[001/025] Iteration[010/016] Loss: 0.4074 Acc:85.00%
epoch:1 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
Valid:	 Epoch[001/025] Iteration[010/010] Loss: 0.2846 Acc:93.46%
Training:Epoch[002/025] Iteration[010/016] Loss: 0.3542 Acc:83.12%
epoch:2 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
Valid:	 Epoch[002/025] Iteration[010/010] Loss: 0.2904 Acc:89.54%
Training:Epoch[003/025] Iteration[010/016] Loss: 0.2266 Acc:93.12%
epoch:3 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
Valid:	 Epoch[003/025] Iteration[010/010] Loss: 0.2252 Acc:94.12%
Training:Epoch[004/025] Iteration[010/016] Loss: 0.2805 Acc:87.50%
epoch:4 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
Valid:	 Epoch[004/025] Iteration[010/010] Loss: 0.1953 Acc:95.42%
Training:Epoch[005/025] Iteration[010/016] Loss: 0.2423 Acc:91.88%
epoch:5 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
Valid:	 Epoch[005/025] Iteration[010/010] Loss: 0.2399 Acc:92.16%
Training:Epoch[006/025] Iteration[010/016] Loss: 0.2455 Acc:90.00%
epoch:6 conv1.weights[0, 0, ...] :
 tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)

4.2 使用trick1:冻结卷积层(不更新其参数)

# -*- coding: utf-8 -*-
""" 模型finetune方法, trick 1: 冻结卷积层的学习率 """
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
 
import sys
libo_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(libo_DIR)
 
from tools.my_dataset import PubuDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
# Directory of this script; data and checkpoint paths below are resolved relative to it.
BASEDIR = os.path.dirname(os.path.abspath(__file__))

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"use device :{device}")

# Project helper from tools.common_tools; presumably seeds the Python/NumPy/torch RNGs
# for reproducibility -- TODO confirm against its implementation.
set_seed(1)
# Class name -> label index for the two classes (not used by the code in this section).
label_name = dict(ants=0, bees=1)

# ---- hyper-parameters ----
# optimisation schedule
MAX_EPOCH = 25        # epochs run from start_epoch + 1 up to MAX_EPOCH - 1
start_epoch = -1      # i.e. training starts at epoch 0
LR = 0.001            # base learning rate
lr_decay_step = 7     # StepLR period (the scheduler is stepped once per epoch)
BATCH_SIZE = 16
# reporting cadence
log_interval = 10     # print training stats every N iterations
val_interval = 1      # validate every N epochs
# task
classes = 2           # output width of the replacement fc layer
 
 
# ============================ step 1/5 data ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "data", "pubu"))
# Fail early with a hint on where to put the dataset. FileNotFoundError subclasses
# Exception, so any existing `except Exception` handler still catches it; isdir (rather
# than exists) also rejects a stray file that happens to have the same name.
if not os.path.isdir(data_dir):
    raise FileNotFoundError("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(
        data_dir, os.path.dirname(data_dir)))

train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

# Per-channel normalisation statistics (the standard ImageNet values, presumably chosen
# to match the preprocessing the pretrained backbone was trained with).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: random resized crop + horizontal flip for augmentation, output 224x224.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation: deterministic resize + center crop to the same 224x224 input size.
valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Dataset instances (PubuDataset comes from tools.my_dataset).
train_data = PubuDataset(data_dir=train_dir, transform=train_transform)
valid_data = PubuDataset(data_dir=valid_dir, transform=valid_transform)

# DataLoaders; only the training set is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
 
# ============================ step 2/5 model ============================
# 1/3 build the (randomly initialised) ResNet-18
resnet18_ft = models.resnet18()

# 2/3 load the pretrained parameters
path_pretrained_model = os.path.join(BASEDIR, "..", "data", "resnet18-5c106cde.pth")
# map_location="cpu" lets the checkpoint load on CPU-only machines even if it was saved
# from GPU tensors; the whole model is moved to `device` below anyway.
state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
resnet18_ft.load_state_dict(state_dict_load)  # overwrite the random init with the pretrained weights

# Method 1: freeze every pretrained parameter so autograd computes no gradients for it.
# NOTE: BatchNorm running mean/var are buffers, not parameters, so they are not frozen
# by this and still update whenever the model is in train mode.
for param in resnet18_ft.parameters():
    param.requires_grad = False
print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 replace the fc layer. It is created after the freeze loop, so its parameters keep
# the default requires_grad=True and the head is the only trainable part.
num_ftrs = resnet18_ft.fc.in_features  # input width of the original fc layer
resnet18_ft.fc = nn.Linear(num_ftrs, classes)  # new head: same input width, `classes` outputs

resnet18_ft.to(device)

# ============================ step 3/5 loss ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
# Give the optimizer only the trainable parameters. This makes the freeze explicit and
# avoids visiting frozen tensors on every step. Updates are the same as before, because
# SGD already skips parameters whose .grad is None.
trainable_params = [p for p in resnet18_ft.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=LR, momentum=0.9)
# Multiply the LR by 0.1 every `lr_decay_step` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
 
# ============================ step 5/5 training ============================
train_curve = list()  # training loss of every iteration
valid_curve = list()  # mean validation loss of every validation pass
for epoch in range(start_epoch + 1, MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.
    resnet18_ft.train()
    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)
        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        # update weights
        optimizer.step()
        # Accuracy bookkeeping: detach() is the supported way to drop the autograd graph
        # (Tensor.data is discouraged), and .item() yields a plain Python number.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # logging: read the scalar loss once per iteration
        batch_loss = loss.item()
        loss_mean += batch_loss
        train_curve.append(batch_loss)
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.
            # Show one backbone kernel so you can see whether it changes during training.
            print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # advance the learning-rate schedule once per epoch

    # validate the model
    if (epoch + 1) % val_interval == 0:
        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        n_val_batches = 0
        resnet18_ft.eval()
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                loss_val += loss.item()
                n_val_batches += 1
        # An empty loader would otherwise surface as an opaque ZeroDivisionError.
        if n_val_batches == 0:
            raise RuntimeError("valid_loader produced no batches; check the validation data")
        loss_val_mean = loss_val / n_val_batches  # mean of the per-batch losses
        valid_curve.append(loss_val_mean)
        print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
            epoch, MAX_EPOCH, n_val_batches, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

# Plot the loss curves. Validation losses are recorded once per validation pass, so
# their x positions are converted from a pass count to an iteration count.
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

4.3 使用trick2:不同参数不同学习率

# -*- coding: utf-8 -*-
""" 模型finetune, trick2 方法:不同参数不同的学习率"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
 
import sys
libo_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(libo_DIR)
 
from tools.my_dataset import PubuDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
# Directory of this script; data and checkpoint paths below are resolved relative to it.
BASEDIR = os.path.dirname(os.path.abspath(__file__))

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"use device :{device}")

# Project helper from tools.common_tools; presumably seeds the Python/NumPy/torch RNGs
# for reproducibility -- TODO confirm against its implementation.
set_seed(1)
# Class name -> label index for the two classes (not used by the code in this section).
label_name = dict(ants=0, bees=1)

# ---- hyper-parameters ----
# optimisation schedule
MAX_EPOCH = 25        # epochs run from start_epoch + 1 up to MAX_EPOCH - 1
start_epoch = -1      # i.e. training starts at epoch 0
LR = 0.001            # base learning rate (the fc head's LR; the backbone uses a fraction)
lr_decay_step = 7     # StepLR period (the scheduler is stepped once per epoch)
BATCH_SIZE = 16
# reporting cadence
log_interval = 10     # print training stats every N iterations
val_interval = 1      # validate every N epochs
# task
classes = 2           # output width of the replacement fc layer

# ============================ step 1/5 data ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "data", "pubu"))
# Fail early with a hint on where to put the dataset. FileNotFoundError subclasses
# Exception, so any existing `except Exception` handler still catches it; isdir (rather
# than exists) also rejects a stray file that happens to have the same name.
if not os.path.isdir(data_dir):
    raise FileNotFoundError("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(
        data_dir, os.path.dirname(data_dir)))

train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

# Per-channel normalisation statistics (the standard ImageNet values, presumably chosen
# to match the preprocessing the pretrained backbone was trained with).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: random resized crop + horizontal flip for augmentation, output 224x224.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation: deterministic resize + center crop to the same 224x224 input size.
valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Dataset instances (PubuDataset comes from tools.my_dataset).
train_data = PubuDataset(data_dir=train_dir, transform=train_transform)
valid_data = PubuDataset(data_dir=valid_dir, transform=valid_transform)

# DataLoaders; only the training set is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
 
# ============================ step 2/5 model ============================
# 1/3 build the (randomly initialised) ResNet-18
resnet18_ft = models.resnet18()

# 2/3 load the pretrained parameters
path_pretrained_model = os.path.join(BASEDIR, "..", "data", "resnet18-5c106cde.pth")
# map_location="cpu" lets the checkpoint load on CPU-only machines even if it was saved
# from GPU tensors; the whole model is moved to `device` below anyway.
state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
resnet18_ft.load_state_dict(state_dict_load)  # overwrite the random init with the pretrained weights

# 3/3 replace the fc layer
num_ftrs = resnet18_ft.fc.in_features  # input width of the original fc layer
resnet18_ft.fc = nn.Linear(num_ftrs, classes)  # new head: same input width, `classes` outputs

resnet18_ft.to(device)

# ============================ step 3/5 loss ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
# Method 2: use two optimizer parameter groups, with a small LR for the pretrained
# backbone and the full LR for the new fc head.
# Parameters are told apart by object identity; a set makes each membership test O(1).
fc_params_id = {id(p) for p in resnet18_ft.fc.parameters()}
# Materialise the backbone parameters as a list. A bare `filter` object is a one-shot
# iterator and would be silently empty if anything iterated it a second time.
base_params = [p for p in resnet18_ft.parameters() if id(p) not in fc_params_id]
optimizer = optim.SGD([
    # Backbone: 10x smaller LR than the head. Setting it to 0 keeps these weights fixed.
    # NOTE: BatchNorm running statistics still update while the model is in train mode.
    {'params': base_params, 'lr': LR * 0.1},
    # New classification head: full LR.
    {'params': resnet18_ft.fc.parameters(), 'lr': LR},
], momentum=0.9)

# Multiply every group's LR by 0.1 every `lr_decay_step` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
 
# ============================ step 5/5 training ============================
train_curve = list()  # training loss of every iteration
valid_curve = list()  # mean validation loss of every validation pass
for epoch in range(start_epoch + 1, MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.
    resnet18_ft.train()
    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)
        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        # update weights
        optimizer.step()
        # Accuracy bookkeeping: detach() is the supported way to drop the autograd graph
        # (Tensor.data is discouraged), and .item() yields a plain Python number.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # logging: read the scalar loss once per iteration
        batch_loss = loss.item()
        loss_mean += batch_loss
        train_curve.append(batch_loss)
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # advance the learning-rate schedule once per epoch

    # validate the model
    if (epoch + 1) % val_interval == 0:
        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        n_val_batches = 0
        resnet18_ft.eval()
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                loss_val += loss.item()
                n_val_batches += 1
        # An empty loader would otherwise surface as an opaque ZeroDivisionError.
        if n_val_batches == 0:
            raise RuntimeError("valid_loader produced no batches; check the validation data")
        loss_val_mean = loss_val / n_val_batches  # mean of the per-batch losses
        valid_curve.append(loss_val_mean)
        print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
            epoch, MAX_EPOCH, n_val_batches, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

# Plot the loss curves. Validation losses are recorded once per validation pass, so
# their x positions are converted from a pass count to an iteration count.
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

参考文献

  1. https://blog.csdn.net/pengchengliu/article/details/108968158
Logo

更多推荐