一、代码中使用的数据集可以通过百度网盘获取（注：原文中的分享链接在此处缺失）

百度网盘提取码：lala

二、代码运行环境

PyTorch==1.7.1（GPU 版本，配套 torchvision==0.8.2）
Python==3.7
此外，代码还依赖 numpy、scikit-learn、tqdm 和 matplotlib。

三、数据集处理代码如下所示（保存为 data_loader.py，后续训练和预测脚本通过 from data_loader import loader_data 导入）

import torchvision
from torchvision import transforms
import os
from torch.utils.data import DataLoader


def loader_data(batch_size=64, data_root='dataset'):
    """Build the train/test DataLoaders for the weather image dataset.

    Images are read with torchvision's ``ImageFolder`` from
    ``<data_root>/train_weather`` and ``<data_root>/test_weather``; each
    sub-directory name is treated as a class label.

    Args:
        batch_size: number of images per batch for both loaders (default 64).
        data_root: directory containing the ``train_weather`` and
            ``test_weather`` folders (default ``'dataset'``).

    Returns:
        A tuple ``(train_dl, test_dl, class_to_idx)``: the shuffled training
        loader, the unshuffled test loader, and the class-name -> label-index
        mapping shared by both datasets.

    Raises:
        ValueError: if the train and test folders define different class sets,
            because their label indices would then not agree.
    """
    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(192),
        transforms.RandomHorizontalFlip(),
        # NOTE(review): torchvision interprets this as +/-0.2 *degrees*, which is
        # nearly a no-op; confirm whether a larger rotation range was intended.
        transforms.RandomRotation(0.2),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ToTensor(),
        # Maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
    # NOTE(review): training resizes the shorter side to 224 (aspect preserved)
    # and random-crops 192x192, while testing squashes the whole image to
    # 192x192; confirm this preprocessing mismatch is intended.
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
    train_ds = torchvision.datasets.ImageFolder(root=os.path.join(data_root, 'train_weather'),
                                                transform=train_transform)
    test_ds = torchvision.datasets.ImageFolder(root=os.path.join(data_root, 'test_weather'),
                                               transform=test_transform)
    # ImageFolder derives label indices from each folder's own class list, so a
    # class missing from one split would silently shift the indices of the rest.
    if train_ds.class_to_idx != test_ds.class_to_idx:
        raise ValueError(
            'train and test folders define different classes: {} vs {}'.format(
                sorted(train_ds.class_to_idx), sorted(test_ds.class_to_idx)))
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)

    return train_dl, test_dl, test_ds.class_to_idx

四、模型的构建代码如下所示（保存为 model_loader.py，后续脚本通过 from model_loader import load_model, load_resnet18 导入）

import torch
import torchvision


def load_model(num_classes=4):
    """Return an ImageNet-pretrained VGG16 with a new ``num_classes``-way output layer.

    The convolutional feature extractor (``model.features``) is frozen; the
    classifier layers, including the freshly created last layer, stay trainable.

    Args:
        num_classes: number of output classes (default 4).

    Returns:
        The modified ``torchvision.models.vgg16`` model.
    """
    model = torchvision.models.vgg16(pretrained=True)
    for p in model.features.parameters():
        p.requires_grad = False
    # Setting ``out_features`` on an existing nn.Linear only changes the
    # attribute: its weight/bias keep their 1000-class shape, so the model would
    # still emit 1000 logits. Replace the layer with a correctly sized one.
    in_f = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features=in_f, out_features=num_classes)
    return model


def load_resnet18(num_classes=4):
    """Return an ImageNet-pretrained ResNet-18 prepared for head-only fine-tuning.

    Every pretrained parameter is frozen, then ``model.fc`` is replaced by a new
    ``num_classes``-way linear layer. A newly constructed layer's parameters
    require gradients, so the new head is the only trainable part.

    Args:
        num_classes: number of output classes (default 4).

    Returns:
        The modified ``torchvision.models.resnet18`` model.
    """
    model = torchvision.models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    in_f = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=in_f, out_features=num_classes)

    return model

五、模型的训练代码如下所示

import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import numpy as np
import tqdm
import os
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler

# Load the data loaders and the class-name -> index mapping.
train_dl, test_dl, class_to_idx = loader_data()

# Build the model: pretrained ResNet-18 with a frozen backbone and a new fc head.
model = load_resnet18()

# Training configuration.
# Only the fc head is handed to the optimizer; load_resnet18 freezes the rest.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
# Multiply the learning rate by 0.9 every 5 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 100


def run_epoch(data_loader, desc, train):
    """Make one pass over ``data_loader``, showing running loss/accuracy in a tqdm bar.

    With ``train=True`` the model is in train mode and the optimizer is stepped
    on every batch; otherwise the model is in eval mode and no gradients are
    tracked. Loss and accuracy are weighted by batch size, so a short final
    batch does not skew the epoch's metrics.
    """
    # NOTE(review): train mode also lets the frozen backbone's BatchNorm layers
    # update their running statistics; confirm that is intended.
    model.train(train)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    progress = tqdm.tqdm(iterable=data_loader, total=len(data_loader))
    progress.set_description_str(desc)
    with torch.set_grad_enabled(train):
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)
            pred = model(images)
            loss = loss_fn(pred, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch = labels.size(0)
            # loss_fn averages over the batch; re-weight by its size.
            total_loss += loss.item() * batch
            total_correct += (torch.argmax(input=pred, dim=-1) == labels).sum().item()
            total_seen += batch
            progress.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(total_loss / total_seen,
                                                            total_correct / total_seen))
    progress.close()


# Training loop.
model = model.to(device)
for epoch in range(EPOCHS):
    run_epoch(train_dl, 'Train epoch {:2d}'.format(epoch), train=True)
    # Step the LR schedule once per epoch, after that epoch's optimizer steps.
    exp_lr_scheduler.step()
    run_epoch(test_dl, 'Val epoch {:2d}'.format(epoch), train=False)

# Save the trained weights.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

六、模型的预测代码如下所示

import os
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import matplotlib.pyplot as plt
import matplotlib

# Load the data; only the first batch of the test loader is used for this demo.
train_dl, test_dl, class_index = loader_data()
image, label = next(iter(test_dl))
# Invert class_to_idx so predicted indices map back to class names.
new_class = {v: k for k, v in class_index.items()}

# Rebuild the architecture and load the trained weights.
model = load_resnet18()
# map_location='cpu' lets weights saved during a GPU run load on a CPU-only
# machine; both the model and the batch stay on the CPU here.
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()

# Predict the whole batch, then display one sample.
# Clamp the index so a test batch smaller than 24 images does not raise IndexError.
index = min(23, len(image) - 1)
with torch.no_grad():
    pred = torch.argmax(input=model(image), dim=-1)

plt.axis('off')
plt.title('predict result: ' + new_class[pred[index].item()] + ', label result: ' + new_class[
    label[index].item()])
# loader_data normalizes with mean=0.5, std=0.5, giving values in [-1, 1];
# undo that so imshow gets [0, 1] floats instead of clipping the negative half.
plt.imshow((image[index] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
plt.savefig('result.png')
plt.show()

七、代码的运行结果如下所示

（此处原为运行结果截图，图片未能保留；预测脚本会把结果图保存为 result.png）

Logo

更多推荐