PyTorch —— 使用 PyTorch 的预训练模型实现四种天气分类问题
使用 PyTorch 的预训练模型（VGG16 / ResNet-18）实现四种天气图像的分类
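一、数据集获取：代码中使用的数据集可以通过原文给出的链接获取（本页中该链接缺失）。训练集和测试集需分别放在 dataset/train_weather 和 dataset/test_weather 目录下，每个类别一个子文件夹（ImageFolder 的要求）。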
二、代码运行环境
- PyTorch（GPU 版）== 1.7.1
- Python == 3.7
- 另需安装代码中用到的 torchvision、numpy、tqdm、scikit-learn 和 matplotlib
三、数据集处理代码（保存为 data_loader.py）如下所示
import torchvision
from torchvision import transforms
import os
from torch.utils.data import DataLoader
def loader_data(batch_size=64, data_root='dataset'):
    """Build the train/test DataLoaders for the four-class weather dataset.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<data_root>/train_weather`` and ``<data_root>/test_weather``; each
    sub-directory of those folders is one class.

    Args:
        batch_size: Images per batch for both loaders (default 64, the value
            that used to be hard-coded).
        data_root: Directory that contains ``train_weather`` and
            ``test_weather`` (default ``'dataset'``).

    Returns:
        ``(train_dl, test_dl, class_to_idx)`` where ``class_to_idx`` maps each
        class-folder name to its integer label (taken from the test set).
    """
    # Shared by both splits: maps ToTensor's [0, 1] range to [-1, 1].
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(192),
        transforms.RandomHorizontalFlip(),
        # NOTE(review): torchvision reads this as +/-0.2 *degrees*, which is
        # practically no rotation. If a Keras-style fraction of a full turn
        # was intended, use RandomRotation(72) instead - confirm.
        transforms.RandomRotation(0.2),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ToTensor(),
        normalize,
    ])
    # NOTE(review): training keeps the aspect ratio (Resize(224) + crop 192),
    # while test images are squashed to 192x192. Resize(224) + CenterCrop(192)
    # would match training more closely - confirm before changing.
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        normalize,
    ])
    train_ds = torchvision.datasets.ImageFolder(
        root=os.path.join(data_root, 'train_weather'), transform=train_transform)
    test_ds = torchvision.datasets.ImageFolder(
        root=os.path.join(data_root, 'test_weather'), transform=test_transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl, test_ds.class_to_idx
四、模型的构建代码（保存为 model_loader.py）如下所示
import torch
import torchvision
def load_model(num_classes=4):
    """Return an ImageNet-pretrained VGG16 whose last layer has ``num_classes`` outputs.

    The convolutional ``features`` block is frozen; the ``classifier`` head,
    including the new final layer, stays trainable.

    Args:
        num_classes: Number of output logits (default 4, the weather classes).
    """
    model = torchvision.models.vgg16(pretrained=True)
    for p in model.features.parameters():
        p.requires_grad = False
    # Assigning ``out_features`` on the existing Linear only changes an
    # attribute: its weight keeps its original shape and the network would
    # still emit 1000 logits. Replace the layer so the output really changes.
    in_f = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features=in_f, out_features=num_classes)
    return model
def load_resnet18():
    """Return an ImageNet-pretrained ResNet-18 with a fresh 4-way ``fc`` head.

    Every pretrained parameter is frozen before the head is swapped, so the
    newly created ``fc`` layer is the only trainable part of the network.
    """
    net = torchvision.models.resnet18(pretrained=True)
    # Freeze the entire pretrained network (the old fc is discarded below).
    for tensor in net.parameters():
        tensor.requires_grad = False
    # in_features is read from the old head before it is overwritten.
    net.fc = torch.nn.Linear(in_features=net.fc.in_features, out_features=4)
    return net
五、模型的训练代码如下所示
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import numpy as np
import tqdm
import os
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler
EPOCHS = 100


def _run_epoch(model, data_loader, loss_fn, device, description, optimizer=None):
    """Run one pass over ``data_loader`` and show running loss/accuracy in tqdm.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch. Without one, it is put in eval mode with gradients disabled.

    Loss and accuracy are weighted by batch size, so a smaller final batch no
    longer counts as much as a full one (a plain mean of per-batch values did).
    Assumes ``loss_fn`` uses mean reduction, like the default CrossEntropyLoss.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    progress = tqdm.tqdm(iterable=data_loader, total=len(data_loader))
    progress.set_description_str(description)
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)
            pred = model(images)
            loss = loss_fn(pred, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch = labels.size(0)
            total_loss += loss.item() * batch
            total_correct += (torch.argmax(input=pred, dim=-1) == labels).sum().item()
            total_seen += batch
            progress.set_postfix_str('loss is {:14f}, accuracy is {:14f}'.format(
                total_loss / total_seen, total_correct / total_seen))
    progress.close()
    seen = max(total_seen, 1)
    return total_loss / seen, total_correct / seen


def main():
    """Train the ResNet-18 head on the weather data and save its state_dict."""
    train_dl, test_dl, class_to_idx = loader_data()
    # Backbone is frozen inside load_resnet18; only model.fc is optimised.
    # NOTE(review): model.train() still updates BatchNorm running statistics in
    # the frozen backbone - keep it in eval mode if that is unwanted.
    model = load_resnet18()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
    # Multiply the learning rate by 0.9 every 5 epochs.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    for epoch in range(EPOCHS):
        _run_epoch(model, train_dl, loss_fn, device,
                   'Train epoch {:2d}'.format(epoch), optimizer=optimizer)
        exp_lr_scheduler.step()
        _run_epoch(model, test_dl, loss_fn, device, 'Val epoch {:2d}'.format(epoch))
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('model_data', exist_ok=True)
    torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))


if __name__ == '__main__':
    main()
六、模型的预测代码如下所示
import os
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import matplotlib.pyplot as plt
import matplotlib
def main(index=23):
    """Classify the first test batch with the saved ResNet-18 and plot one image.

    The figure, titled with predicted and true class names, is saved to
    ``result.png`` and then shown.

    Args:
        index: Position, inside the first test batch, of the image to display.

    Raises:
        IndexError: If ``index`` is outside the first test batch.
    """
    _, test_dl, class_index = loader_data()
    image, label = next(iter(test_dl))
    # Invert class_to_idx so integer predictions map back to folder names.
    idx_to_class = {v: k for k, v in class_index.items()}
    if not -len(label) <= index < len(label):
        raise IndexError('index {} is out of range for a batch of {} images'.format(
            index, len(label)))

    model = load_resnet18()
    # map_location='cpu': the weights may have been saved from a CUDA model,
    # which would otherwise fail to load on a CPU-only machine.
    model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
    model.load_state_dict(model_state_dict)
    model.eval()
    with torch.no_grad():
        pred = torch.argmax(input=model(image), dim=-1)

    plt.axis('off')
    plt.title('predict result: ' + idx_to_class.get(pred[index].item())
              + ', label result: ' + idx_to_class.get(label[index].item()))
    # The test transform normalised with mean=std=0.5 (values in [-1, 1]).
    # Undo it so imshow receives RGB values in [0, 1] instead of clipping.
    plt.imshow((image[index] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()


if __name__ == '__main__':
    main()
七、代码的运行结果（预测脚本会将结果图保存为 result.png；原文中的结果截图在本页中缺失）
更多推荐
已为社区贡献1条内容
所有评论(0)