Understanding the MAE Source Code, Part 2: Debugging the Pretrained Model
Using the MAE pretrained model for classification
Contents
Part 1:
Understanding the MAE Source Code, Part 1: learning by debugging — YI_SHU_JIA's blog, CSDN
Official repo: GitHub - facebookresearch/mae: PyTorch implementation of MAE https://arxiv.org/abs/2111.06377
MAE is an upstream pretraining model; its whole point is to serve downstream tasks such as classification. So how is that done? Let's explore together.
1 Preparation
Fine-tuning is documented in FINETUNE.md; following its instructions, you first need to download the fine-tuned checkpoint.
Then comes this console command, which carries all the args settings. One of them, resume, is the path where you put the downloaded fine-tuned checkpoint; data_path points to the dataset. The default is ImageNet, which is far too big for me, so I dropped that flag and wired in a small dataset of my own.
This debugging session happens inside main_finetune.py. Click Run → Edit Configurations
and enter these parameters:
--eval --resume model_save/mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 16
Find the dataset line in the code and replace it with your own dataset. Now we can start debugging.
2 Debugging
Never mind args; let's jump straight into the main function.
misc.init_distributed_mode(args)
The very first line lost me. After some searching: ah, it's related to distributed training, and it is not used by default here.
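For reference, the relevant branch of that helper looks roughly like this (condensed from util/misc.py as I understand it; details may differ by commit):
import os
import torch

# condensed sketch of misc.init_distributed_mode
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # launched via torchrun / torch.distributed: pick up rank info from the environment
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        # plain single-process run (our case): distributed mode is simply switched off
        print('Not using distributed mode')
        args.distributed = False
        return
    args.distributed = True
    torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)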
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
These print the job directory and all the args.
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# dataset_train = build_dataset(is_train=True, args=args)
# dataset_val = build_dataset(is_train=False, args=args)
dataset_train = train_set
dataset_val = val_set
Some seeding for reproducibility, plus loading the datasets (I swapped my own train_set/val_set in for build_dataset).
if True:
Honestly, this hard-coded `if True:` is a bit janky.
parser.add_argument('--num_workers', default=0, type=int)
Because I started the Docker container without the --shm-size option shown below, I set num_workers to 0:
docker run --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=2,3 --shm-size 8G -it --rm dev:v1 /bin/bash
if True:  # args.distributed:
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
This fetches the world size (1 in my case) and the rank. There's a pile of multi-GPU sampler machinery here; I'm skipping all of it, it's a mess.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
)
The training DataLoader.
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
    print("Mixup is activated!")
    mixup_fn = Mixup(
        mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
        prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
        label_smoothing=args.smoothing, num_classes=args.nb_classes)
This sets up Mixup/CutMix data augmentation. We're not using any augmentation here.
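For context, here's a minimal sketch of how the Mixup callable gets used later in train_one_epoch (the alpha/prob values mirror the fine-tuning recipe in FINETUNE.md, as I understand it):
from timm.data import Mixup

# build the augmenter; values shown are the recommended recipe, not this run's settings
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
                 prob=1.0, switch_prob=0.5, mode='batch',
                 label_smoothing=0.1, num_classes=1000)

# inside the training loop: images are blended pairwise and the hard labels become
# soft mixing-proportion targets, which is why SoftTargetCrossEntropy is used as the loss
# samples, targets = mixup_fn(samples, targets)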
model = models_vit.__dict__[args.model](
    num_classes=args.nb_classes,
    drop_path_rate=args.drop_path,
    global_pool=args.global_pool,
)
Now the model. Three arguments are passed in: the number of classes, the drop-path rate, and whether to use global pooling.
def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
These get forwarded into this factory function, wrapped up in **kwargs, and we land in the ViT model:
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)
        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm
With global pooling on, it adds a normalization layer (fc_norm) and deletes the original final norm. At this point it still looks like a pure evaluation run; let's keep going.
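The point of fc_norm shows up in the matching forward_features override in models_vit.py, which mean-pools the patch tokens instead of taking the cls token (quoted roughly from the repo):
def forward_features(self, x):
    B = x.shape[0]
    x = self.patch_embed(x)
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    x = x + self.pos_embed
    x = self.pos_drop(x)
    for blk in self.blocks:
        x = blk(x)
    if self.global_pool:
        x = x[:, 1:, :].mean(dim=1)  # global average pool over patch tokens, cls token excluded
        outcome = self.fc_norm(x)
    else:
        x = self.norm(x)
        outcome = x[:, 0]  # take the cls token
    return outcome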
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model = %s" % str(model_without_ddp))
print('number of params (M): %.2f' % (n_parameters / 1.e6))
This prints the model and its trainable parameter count. It's a standard ViT.
# build optimizer with layer-wise lr decay (lrd)
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
                                    no_weight_decay_list=model_without_ddp.no_weight_decay(),
                                    layer_decay=args.layer_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler()
Layer-wise learning-rate decay (lrd) for the parameter groups, then the AdamW optimizer and the loss scaler. Calling loss_scaler amounts to: backpropagate the gradients and update the parameters.
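NativeScaler is essentially a thin wrapper around torch.cuda.amp.GradScaler: backward, optional gradient clipping, and the optimizer step in one call. A condensed sketch of the version in util/misc.py (details may differ by commit):
import torch

class NativeScaler:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)  # scaled backward pass
        if update_grad:
            if clip_grad is not None:
                self._scaler.unscale_(optimizer)  # unscale before clipping
                torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)  # optimizer step (skipped on inf/NaN grads)
            self._scaler.update()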
elif args.smoothing > 0.:
    criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
This is the loss: label-smoothing cross-entropy, i.e. the targets are probability distributions rather than hard one-hot labels.
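The idea in one tiny example: with smoothing ε and K classes, the one-hot target is softened to 1-ε+ε/K on the true class and ε/K elsewhere (names below are mine):
import torch

K, eps, true_cls = 5, 0.1, 2          # 5 classes, smoothing 0.1, ground truth is class 2
target = torch.full((K,), eps / K)    # eps/K probability mass on every class
target[true_cls] += 1.0 - eps         # the remaining mass goes to the true class
print(target)  # tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])
# the loss is then -(target * log_softmax(logits)).sum(), which is what
# timm's LabelSmoothingCrossEntropy computes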
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
Load the checkpoint. Step into the function:
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # This line errored for me: the checkpoint has a 1000-class head, so I simply set
        # nb_classes to 1000 -- we only care about the flow here, not the results.
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
if args.eval:
    test_stats = evaluate(data_loader_val, model, device)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
    exit(0)
Evaluating the validation set:
@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()  # classification loss
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'  # used only for the progress printout
    # switch to evaluation mode
    model.eval()
    for batch in metric_logger.log_every(data_loader, 10, header):  # the logger bits just print progress
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)
        # top-1 / top-5 accuracy; note this accuracy helper is imported from timm.utils
        # in this repo -- I never knew it existed
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
Evaluation gives us the accuracy; the impression is that the ViT's last layer is simply a classification head.
So far nothing special: load a model, compute accuracy. I honestly don't know why it's written in such a convoluted way. Now let's set
args.finetune to mae_pretrain_vit_base.pth and
args.eval to False, and step into the fine-tuning path.
if args.finetune and not args.eval:
    checkpoint = torch.load(args.finetune, map_location='cpu')
    print("Load pre-trained checkpoint from: %s" % args.finetune)
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
The pretrained checkpoint is loaded.
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
This says: if the head's shape in the checkpoint doesn't match the current model's (i.e. a different number of classes), drop the head weights from the checkpoint.
interpolate_pos_embed(model, checkpoint_model)
Positional embeddings: MAE uses fixed sin-cos positional embeddings, so the pretrained ones are loaded directly, and interpolated only if the patch grid size changed:
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

# back in main_finetune.py:
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
This load matters: the MAE pretrain checkpoint contains neither the head nor the fc_norm layer, so strict=False leaves them to be freshly initialized. The head then gets a truncated-normal init:
trunc_normal_(model.head.weight, std=2e-5)
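In fact main_finetune.py checks exactly this; with global pooling on, only the head and fc_norm may be missing (snippet from the repo as I recall it):
if args.global_pool:
    assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
else:
    assert set(msg.missing_keys) == {'head.weight', 'head.bias'}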
Next, the fine-tuning loop:
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn,
log_writer=log_writer,
args=args
)
if args.output_dir:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
Just straightforward setup.
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
The training function.
if data_iter_step % accum_iter == 0:
    lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
So the learning rate is adjusted once per accumulation step, at iteration granularity; combined with the lr_scale set by the layer-wise decay above, each layer can effectively get its own learning rate.
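The scheduler in util/lr_sched.py is per-iteration linear warmup followed by half-cycle cosine decay; roughly (note `epoch` here is fractional, i.e. iteration-resolved):
import math

def adjust_learning_rate(optimizer, epoch, args):
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs  # linear warmup
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:  # layer-wise decay: each group carries its own scale
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr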
if not math.isfinite(loss_value):  # stop if the loss has gone NaN/Inf
    print("Loss is {}, stopping training".format(loss_value))
    sys.exit(1)
loss /= accum_iter
loss_scaler(loss, optimizer, clip_grad=max_norm,
            parameters=model.parameters(), create_graph=False,
            update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
    optimizer.zero_grad()
It looks flashy, but it's the ordinary routine: compute gradients, backpropagate, update, zero the gradients; AMP scaling and gradient accumulation are just layered on top.
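Stripped of AMP and logging, the accumulate-then-step pattern is just this (a plain-PyTorch sketch; variable names are mine):
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = criterion(model(x), y) / accum_iter  # average the loss over the accumulation window
    loss.backward()                             # gradients add up across iterations
    if (step + 1) % accum_iter == 0:
        optimizer.step()                        # one real update every accum_iter batches
        optimizer.zero_grad()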
And that's fine-tuning done!! It looked complicated and messy, but it boils down to: take the MAE encoder, drop its final norm layer, bolt on a classification head plus fc_norm, and train. Plain old fine-tuning, in other words. So I'll help myself and build my own!!!
Using an MAE-pretrained model for your own downstream classification
Below is the code I used to fine-tune MAE for medical-image classification.
(I originally used it for food classification, but I lost that version; only the dataset loading differs anyway.)
First the args: set the number of classes, the drop-path rate, global pooling, the model choice, and the path to the pretrained checkpoint. Put the
mae_pretrain_vit_base.pth you downloaded earlier at that path:
def get_args_parser():
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
    # model
    parser.add_argument('--nb_classes', default=2, type=int,
                        help='number of the classification types')
    parser.add_argument('--drop_path', default=0.1, type=float, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    parser.add_argument('--model', default='vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    # path
    parser.add_argument('--predModelPath', default='model_save/mae_pretrain_vit_base.pth',
                        help='finetune from checkpoint')
    return parser
args = get_args_parser()
args = args.parse_args()
Initialize the model, i.e. build it and load the pretrained weights:
def initMaeClass(args):
    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )
    checkpoint = torch.load(args.predModelPath, map_location='cpu')
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    return model
Load the datasets and set the hyperparameters:
##################################################################
savePath = 'model_save/foodFine'
class1Train = r'/home/dataset/pendi/cls1/train'
class2Train = r'/home/dataset/pendi/cls2/train'
class1Val = r'/home/dataset/pendi/cls1/val'
class2Val = r'/home/dataset/pendi/cls2/val'
class1Test = r'/home/dataset/pendi/cls1/test'
class2Test = r'/home/dataset/pendi/cls2/test'
trainloader = getDataLoader(class1Train, class2Train, batchSize=1)
valloader = getDataLoader(class1Val, class2Val, batchSize=1)
#################################################################
random.seed(1)
batch_size = 128
learning_rate = 1e-4
w = 0.00001
criterion = nn.CrossEntropyLoss()
epoch = 2000
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
Train:
train_VAL(model,trainloader, valloader, optimizer, criterion, batch_size, w, num_epoch=epoch,save_=savePath,device=device)
That's the whole method for using an MAE-pretrained model to extract features and fine-tune a classifier.
Full code:
# --- training helper (presumably model_utils/train.py, judging by the imports in the main script below) ---
import torch
import matplotlib.pyplot as plt
import time
import numpy as np
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset
# learning-rate updating (the update_lr call inside is left commented out)
def train_VAL(model, train_set, val_set, optimizer, loss, batch_size, w, num_epoch, device, save_):
    # train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    # val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
    train_loader = train_set
    val_loader = val_set
    # train the model on the training set; use the validation set to monitor it
    plt_train_loss = []
    plt_val_loss = []
    plt_train_acc = []
    plt_val_acc = []
    maxacc = 0
    for epoch in range(num_epoch):
        # update_lr(optimizer, epoch)
        epoch_start_time = time.time()
        train_acc = 0.0
        train_loss = 0.0
        val_acc = 0.0
        val_loss = 0.0
        model.train()  # make sure the model is in training mode (enables Dropout etc.)
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()  # zero the parameter gradients
            train_pred = model(data[0].to(device))  # forward pass: predicted class scores
            # batch_loss = loss(train_pred, data[1].cuda(), w, model)  # prediction and label must sit on the same device
            batch_loss = loss(train_pred, data[1].to(device))
            batch_loss.backward()  # backpropagation computes each parameter's gradient
            optimizer.step()  # the optimizer updates the parameters with those gradients
            train_acc += np.sum(np.argmax(train_pred.cpu().data.numpy(), axis=1) == data[1].numpy())
            train_loss += batch_loss.item()
        # validation
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                val_pred = model(data[0].to(device))
                # batch_loss = loss(val_pred, data[1].cuda(), w, model)
                batch_loss = loss(val_pred, data[1].to(device))
                val_acc += np.sum(np.argmax(val_pred.cpu().data.numpy(), axis=1) == data[1].numpy())
                val_loss += batch_loss.item()
        if val_acc > maxacc:
            torch.save(model, save_ + 'max')
            maxacc = val_acc
            # torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_loss': val_loss,
            #             'optimizer': optimizer.state_dict(), 'alpha': loss.alpha, 'gamma': loss.gamma},
            #            'cat_dog_res18')
        # record for plotting
        plt_train_acc.append(train_acc / train_set.dataset.__len__())
        plt_train_loss.append(train_loss / train_set.dataset.__len__())
        plt_val_acc.append(val_acc / val_set.dataset.__len__())
        plt_val_loss.append(val_loss / val_set.dataset.__len__())
        # print the results
        print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' %
              (epoch + 1, num_epoch, time.time() - epoch_start_time,
               plt_train_acc[-1], plt_train_loss[-1], plt_val_acc[-1], plt_val_loss[-1]))
        if epoch == num_epoch - 1:
            torch.save(model, save_ + 'final')
    # loss curves
    plt.plot(plt_train_loss)
    plt.plot(plt_val_loss)
    plt.title('Loss')
    plt.legend(['train', 'val'])
    plt.savefig('loss.png')
    plt.show()
    # accuracy curves
    plt.plot(plt_train_acc)
    plt.plot(plt_val_acc)
    plt.title('Accuracy')
    plt.legend(['train', 'val'])
    plt.savefig('acc.png')
    plt.show()
# --- main script ---
import os
import numpy as np
import torch
import torch.nn as nn
import random
import argparse
import timm
assert timm.__version__ == "0.5.4"  # version check
import models_vit
from torch import optim
from model_utils.data import getDataLoader
from model_utils.train import train_VAL
# from model_utils.foodData import trainloader, valloader
def get_args_parser():
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
    # model
    parser.add_argument('--nb_classes', default=11, type=int,
                        help='number of the classification types')  # note: still the 11-class food default; the medical task above uses 2
    parser.add_argument('--drop_path', default=0.1, type=float, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    parser.add_argument('--model', default='vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    # path
    parser.add_argument('--predModelPath', default='model_save/mae_pretrain_vit_base.pth',
                        help='finetune from checkpoint')
    return parser
def initMaeClass(args):
    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )
    checkpoint = torch.load(args.predModelPath, map_location='cpu')
    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    return model
##################################################################
savePath = 'model_save/foodFine'
class1Train = r'/home/dataset/food/cls1/train'
class2Train = r'/home/dataset/pendi/cls2/train'
class1Val = r'/home/dataset/pendi/cls1/val'
class2Val = r'/home/dataset/pendi/cls2/val'
class1Test = r'/home/dataset/pendi/cls1/test'
class2Test = r'/home/dataset/pendi/cls2/test'
trainloader = getDataLoader(class1Train, class2Train, batchSize=1)
valloader = getDataLoader(class1Val, class2Val, batchSize=1)
# load the data however you like here
#################################################################
random.seed(1)
batch_size = 128
learning_rate = 1e-4
w = 0.00001
criterion = nn.CrossEntropyLoss()
epoch = 2000
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
##################################################################
if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    model = initMaeClass(args).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_VAL(model, trainloader, valloader, optimizer, criterion, batch_size, w,
              num_epoch=epoch, save_=savePath, device=device)
    # modelpath1 = savePath + 'max'
    # model1 = torch.load(modelpath1)
    # test(model1, test_set=test_dataset)
    # modelpath2 = savePath + 'final'
    # model2 = torch.load(modelpath2)
    # test(model2, test_set=test_dataset)
# --- data helpers (presumably model_utils/data.py, given the import in the main script) ---
import cv2
import os
import numpy as np
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.model_selection import train_test_split
import torch
import random
from imblearn.over_sampling import SMOTE
from collections import Counter
HW = 224
def readjpgfile(listpath, label, rate=None):
    assert rate == None or rate // 1 == rate
    # label is the class index assigned to every image in this folder
    image_dir = sorted(os.listdir(listpath))
    n = len(image_dir)
    if rate:
        n = n * rate
    # x holds the images: each one is HW (height) * HW (width) * 3 (RGB channels)
    x = np.zeros((n, HW, HW, 3), dtype=np.uint8)
    # y holds the labels, one per image
    y = np.zeros(n, dtype=np.uint8)
    if not rate:
        for i, file in enumerate(image_dir):
            img = cv2.imread(os.path.join(listpath, file))
            # cv2.resize() unifies images of different sizes to HW x HW;
            # os.path.join concatenates directory and file name
            x[i, :, :] = cv2.resize(img, (HW, HW))
            y[i] = label
    else:
        for i, file in enumerate(image_dir):
            img = cv2.imread(os.path.join(listpath, file))
            # with rate set, each image is duplicated `rate` times (naive oversampling)
            for j in range(rate):
                x[rate * i + j, :, :] = cv2.resize(img, (HW, HW))
                y[rate * i + j] = label
    return x, y
# at training time you could augment via random rotations / horizontal flips (data augmentation)
train_transform = transforms.Compose([
    # transforms.RandomResizedCrop(150),
    transforms.ToPILImage(),
    transforms.ToTensor()
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                      std=[0.229, 0.224, 0.225])
])
# at test time, no data augmentation
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])
class ImgDataset(Dataset):
    def __init__(self, x, y=None, transform=None, lessTran=False):
        self.x = x
        # labels must be LongTensor
        self.y = y
        if y is not None:
            self.y = torch.LongTensor(y)
        self.transform = transform
        self.lessTran = lessTran
        # forced horizontal flip
        self.trans0 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.RandomHorizontalFlip(p=1),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        # forced vertical flip
        self.trans1 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.RandomVerticalFlip(p=1),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        # random rotation within [-90, 90] degrees
        self.trans2 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.RandomRotation(90),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        # brightness jitter: factor sampled in [0, 2] (factor 1 is the original image)
        self.trans3 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.ColorJitter(brightness=1),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        # contrast jitter
        self.trans4 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.ColorJitter(contrast=2),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        # hue jitter
        self.trans5 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.ColorJitter(hue=0.5),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        # combined jitter
        self.trans6 = torchvision.transforms.Compose([
            transforms.ToPILImage(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.ColorJitter(brightness=1, contrast=2, hue=0.5),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ])
        self.trans_list = [self.trans0, self.trans1, self.trans2, self.trans3,
                           self.trans4, self.trans5, self.trans6]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        X = self.x[index]
        if self.y is not None:
            if self.lessTran:
                num = random.randint(0, 6)  # pick one of the seven transforms at random
                X = self.trans_list[num](X)
            else:
                if self.transform is not None:
                    X = self.transform(X)
            Y = self.y[index]
            return X, Y
        else:
            return X

    def getbatch(self, indices):
        images = []
        labels = []
        for index in indices:
            image, label = self.__getitem__(index)
            images.append(image)
            labels.append(label)
        return torch.stack(images), torch.tensor(labels)
def getDateset(dir_class1, dir_class2, testSize=0.3, rate=None, testNum=None, lessTran=False):
    '''
    :param dir_class1: the class with fewer samples
    :param dir_class2:
    :param testSize:
    :param rate:
    :param testNum:
    :return:
    '''
    x1, y1 = readjpgfile(dir_class1, 0, rate=rate)  # class 1 -> label 0
    x2, y2 = readjpgfile(dir_class2, 1)             # class 2 -> label 1
    if testNum == -1:
        X = np.concatenate((x1, x2))
        Y = np.concatenate((y1, y2))
        dataset = ImgDataset(X, Y, transform=train_transform, lessTran=lessTran)
        return dataset
    if not testNum:
        X = np.concatenate((x1, x2))
        Y = np.concatenate((y1, y2))
        train_x, test_x, train_y, test_y = train_test_split(X, Y, test_size=testSize, random_state=0)
    else:
        train_x1, test_x1, train_y1, test_y1 = train_test_split(x1, y1, test_size=testNum / len(y1), random_state=0)
        train_x2, test_x2, train_y2, test_y2 = train_test_split(x2, y2, test_size=testNum / len(y2), random_state=0)
        print(len(test_y2), len(test_y1))
        train_x = np.concatenate((train_x1, train_x2))
        test_x = np.concatenate((test_x1, test_x2))
        train_y = np.concatenate((train_y1, train_y2))
        test_y = np.concatenate((test_y1, test_y2))
    train_dataset = ImgDataset(train_x, train_y, transform=train_transform, lessTran=lessTran)
    test_dataset = ImgDataset(test_x, test_y, transform=test_transform, lessTran=lessTran)
    # test_x1, test_y1 = readjpgfile(r'F:\li_XIANGMU\pycharm\deeplearning\cat_dog\catsdogs\test\Cat', 0)  # cat -> 0
    # test_x2, test_y2 = readjpgfile(r'F:\li_XIANGMU\pycharm\deeplearning\cat_dog\catsdogs\test\Dog', 1)
    # test_x = np.concatenate((test_x1, test_x2))
    # test_y = np.concatenate((test_y1, test_y2))
    return train_dataset, test_dataset
def smote(X_train, y_train):
    oversampler = SMOTE(sampling_strategy='auto', random_state=np.random.randint(100), k_neighbors=5, n_jobs=-1)
    os_X_train, os_y_train = oversampler.fit_resample(X_train, y_train)
    print('Resampled dataset shape {}'.format(Counter(os_y_train)))
    return os_X_train, os_y_train
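Note that smote() is defined but never called anywhere in this pipeline. If you wanted to use it, SMOTE expects a 2-D feature matrix, so the images would have to be flattened first; a hypothetical usage sketch (variable names from the surrounding code):
# hypothetical usage -- not part of the pipeline above
X_flat = train_x.reshape(len(train_x), -1)           # (N, 224*224*3): SMOTE needs 2-D input
os_X, os_y = smote(X_flat, train_y)                  # oversample the minority class
os_X = os_X.reshape(-1, HW, HW, 3).astype(np.uint8)  # back to image shape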
def getDataLoader(class1path, class2path, batchSize, mode='train'):
    assert mode in ['train', 'val', 'test']
    if mode == 'train':
        train_set = getDateset(class1path, class2path, testNum=-1)
        trainloader = DataLoader(train_set, batch_size=batchSize, shuffle=True)
        return trainloader
    elif mode == 'test':
        testset = getDateset(class1path, class2path, testNum=-1)
        testLoader = DataLoader(testset, batch_size=1, shuffle=False)
        return testLoader
    # note: mode='val' is accepted by the assert but not handled (the function would return None);
    # the calls above therefore build the val loader with the default mode='train'