用 Transformers微调ViT图像分类
批处理以字典列表的形式出现,因此您可以将它们解压缩+堆叠到批处理张量中。由于 将返回批处理字典,因此您可以稍后将输入到模型中。return {来自的准确度指标可以轻松用于将预测与标签进行比较。下面,您可以看到如何使用datasetscompute_metricsTrainer让我们加载预训练模型。我们将添加 init,以便模型创建具有正确单位数的分类头。我们还将在 Hub 微件中包含 和 映射以具
【翻译 Nate Raw 的 Fine-Tune ViT for Image Classification with 🤗 Transformers】
正如基于 Transformers 的模型彻底改变了NLP一样,我们现在看到将它们应用于各种其他领域的论文激增。其中最具革命性的是 Vision Transformer (ViT),它由 Google Brain 的一组研究人员于 2021 年六月推出。
本文探讨了如何标记图像,就像标记句子一样,以便将它们传递给 transformer 模型进行训练。这是一个非常简单的概念,真的…
- 将映像拆分为子映像修补程序的网格
- 使用线性投影嵌入每个图面
- 每个嵌入式修补程序都将成为标记,生成的嵌入式修补程序序列就是传递给模型的序列。
事实证明,一旦你完成了上述操作,你就可以像习惯NLP任务一样对转换器进行预训练和微调。很贴心😎。
在这篇博文中,我们将介绍如何利用 🤗 下载的数据集和处理图像分类数据集,然后使用它们来微调带有 🤗 transformers 的预训练 ViT。
首先,让我们先安装这两个软件包。
pip install datasets transformers
补充建议4.28.x, 原因后面再提
加载一个数据集
让我们首先加载一个小的图像分类数据集并查看其结构。
我们将使用豆数据集,这是健康和不健康豆叶图片的集合。🍃
from datasets import load_dataset
ds = load_dataset('beans')
ds
让我们看一下从 由’train’拆分的豆荚数据集的第 400 个示例。您会注意到数据集中的每个示例都有 3 个特征:
-
image:PIL 图像
-
image_file_path:str 路径指向需要加载的图像文件
-
labels:数据集。类标签功能,它是标签的整数表示形式。(稍后您将看到如何获取字符串类名,别担心!)
ex = ds['train'][400] ex { 'image': <PIL.JpegImagePlugin ...>, 'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg', 'labels': 1 }
让我们先看一下图像👀
image = ex['image']
image
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Q081sdpR-1685285046313)(null)]
那绝对是一片叶子!但是什么样的呢?😅
由于此数据集的特征是 datasets.features.ClassLabel,我们可以使用它来查找此示例的标签 ID 的相应名称。
首先,让我们访问labels的特征定义
labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)
现在,让我们打印出示例的类标签。你可以通过使用 int2str 函数来做到这一点,顾名思义,它允许传递类的整数表示形式来查找字符串 label.ClassLabel
labels.int2str(ex['labels'])
'bean_rust'
事实证明,上面显示的叶子感染了豆锈病,豆锈病是豆类植物中的一种严重疾病。😢
让我们编写一个函数,该函数将显示每个类的示例网格,以便更好地了解您正在使用的内容。
import random
from PIL import ImageDraw, ImageFont, Image
def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):
w, h = size
labels = ds['train'].features['labels'].names
grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
draw = ImageDraw.Draw(grid)
font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)
for label_id, label in enumerate(labels):
# Filter the dataset by a single label, shuffle it, and grab a few samples
ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))
# Plot this label's examples along a row
for i, example in enumerate(ds_slice):
image = example['image']
idx = examples_per_class * label_id + i
box = (idx % examples_per_class * w, idx // examples_per_class * h)
grid.paste(image.resize(size), box=box)
draw.text(box, label, (255, 255, 255), font=font)
return grid
show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
依我所见:
- 角叶斑:有不规则的棕色斑块
- 豆锈:有圆形棕色斑点,周围环绕着白色黄色的环
- 健康:。。。看起来很健康。🤷♂️
加载 ViT 特征提取器
现在我们知道了我们的图像是什么样子的,并更好地了解了我们试图解决的问题。让我们看看如何为我们的模型准备这些图像!
训练 ViT 模型时,会将特定转换应用于馈入其中的图像。在图像上使用错误的转换,模型将无法理解它所看到的内容!🖼 ➡️ 🔢
为了确保我们应用正确的转换,我们将使用一个 ViTFeatureExtractor 初始化,该配置与我们计划使用的预训练模型一起保存。在我们的例子中,我们将使用google/vit-base-patch16-224-in21k模型,所以让我们从Hugging Face Hub加载它的特征提取器。
from transformers import ViTFeatureExtractor
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
您可以通过打印来查看特征提取器配置。
ViTFeatureExtractor {
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "ViTFeatureExtractor",
"image_mean": [
0.5,
0.5,
0.5
],
"image_std": [
0.5,
0.5,
0.5
],
"resample": 2,
"size": 224
}
要处理图像,只需将其传递给特征提取器的调用函数即可。这将返回一个字典容器,这是要传递给model.pixel values的数字表示形式。
默认情况下,你会得到一个 NumPy 数组,但如果你添加参数 return_tensors=‘pt’,你会得到 torch 张量。
feature_extractor(image, return_tensors='pt')
应该给你一些类似的东西…
{
'pixel_values': tensor([[[[ 0.2706, 0.3255, 0.3804, ...]]]])
}
…其中张量的形状为 。(1, 3, 224, 224)
处理数据集
现在您知道如何读取图像并将其转换为输入,让我们编写一个函数,将这两件事放在一起以处理数据集中的单个示例。
def process_example(example):
inputs = feature_extractor(example['image'], return_tensors='pt')
inputs['labels'] = example['labels']
return inputs
process_example(ds['train'][0])
{
'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078, ..., ]]]]),
'labels': 0
}
虽然您可以一次调用并将其应用于每个示例,但这可能会非常慢,尤其是在使用较大的数据集时。相反,您可以对数据集应用转换。转换仅在为示例编制索引时应用于示例 ds.map
但是,首先,您需要更新最后一个函数以接受一批数据,因为这是预期的。ds.with_transform
ds = load_dataset('beans')
def transform(example_batch):
# Take a list of PIL images and turn them to pixel values
inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
# Don't forget to include the labels!
inputs['labels'] = example_batch['labels']
return inputs
您可以使用 ds.with_transform(transform) 将其直接应用于数据集。
prepared_ds = ds.with_transform(transform)
现在,每当您从数据集中获取示例时,转换将是 实时应用(在样品和切片上,如下所示)
prepared_ds['train'][0:2]
这一次,得到的张量将具有形状。pixel_values(2, 3, 224, 224)
{
'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078, ..., ]]]]),
'labels': [0, 0]
}
培训和评估
数据已处理完毕,即可开始设置训练管道。这篇博文使用 🤗 的 Trainer,但这需要我们先做几件事:
-
定义排序规则函数。
-
定义评估指标。在训练期间,应评估模型的预测准确性。您应该相应地定义一个函数。compute_metrics
-
加载预训练的检查点。您需要加载预训练的检查点并正确配置它以进行训练。
-
定义训练配置。
微调模型后,您将根据评估数据正确评估模型,并验证它是否确实学会了正确分类图像。
定义我们的数据整理器
批处理以字典列表的形式出现,因此您可以将它们解压缩+堆叠到批处理张量中。
由于 将返回批处理字典,因此您可以稍后将输入到模型中。✨collate_fn**unpack
import torch
def collate_fn(batch):
return {
'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
'labels': torch.tensor([x['labels'] for x in batch])
}
定义评估指标
来自的准确度指标可以轻松用于将预测与标签进行比较。下面,您可以看到如何使用datasetscompute_metricsTrainer
import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")
def compute_metrics(p):
return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
让我们加载预训练模型。我们将添加 init,以便模型创建具有正确单位数的分类头。我们还将在 Hub 微件中包含 和 映射以具有人类可读的标签(如果您选择 )。num_labelsid2labellabel2idpush_to_hub
from transformers import ViTForImageClassification
labels = ds['train'].features['labels'].names
model = ViTForImageClassification.from_pretrained(
model_name_or_path,
num_labels=len(labels),
id2label={str(i): c for i, c in enumerate(labels)},
label2id={c: str(i) for i, c in enumerate(labels)}
)
快准备好训练了!在此之前需要做的最后一件事是通过定义 TrainingArguments 来设置训练配置。
其中大多数都是不言自明的,但这里非常重要的一点是。这将删除模型调用函数未使用的任何功能。默认情况下,这是因为通常最好删除未使用的特征列,从而更轻松地将输入解压缩到模型的调用函数中。但是,在我们的例子中,我们需要未使用的功能(特别是“图像”)来创建“pixel_values”.remove_unused_columns=FalseTrue
我想说的是,如果你忘记设置.remove_unused_columns=False,你会过得很糟糕
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="./vit-base-beans",
per_device_train_batch_size=16,
evaluation_strategy="steps",
num_train_epochs=4,
fp16=True,
save_steps=100,
eval_steps=100,
logging_steps=10,
learning_rate=2e-4,
save_total_limit=2,
remove_unused_columns=False,
push_to_hub=False,
report_to='tensorboard',
load_best_model_at_end=True,
)
此处我一直遇到“NameError: name ‘PartialState’ is not defined.”,查阅
资料新的transformers用的accelerate有用到PartialState,所以需要控制它的版本,并且pip install git+https://github.com/huggingface/accelerate
安装dev版本或没有用过multi-GPUs (such as in Colab)的使用 pip install accelerate -U
另外, 根据提示安装缺失的pypi包,解决raw.githubusercontent.com无法访问的问题
现在,所有实例都可以传递给训练师,我们准备开始训练了!
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
data_collator=collate_fn,
compute_metrics=compute_metrics,
train_dataset=prepared_ds["train"],
eval_dataset=prepared_ds["validation"],
tokenizer=feature_extractor,
)
训练🚀
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
评价📊
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
这是我的评估结果 - 酷豆!对不起,不得不说。
***** eval metrics *****
epoch = 4.0
eval_accuracy = 0.985
eval_loss = 0.0637
eval_runtime = 0:00:02.13
eval_samples_per_second = 62.356
eval_steps_per_second = 7.97
最后,如果需要,可以将模型推送到hub。在这里,如果您在训练配置中指定,我们将其向上推送。请注意,为了推送到 hub,您必须安装 git-lfs 并登录到您的 Hugging Face 帐户(可以通过 ).push_to_hub=Truehuggingface-cli 登录来完成
kwargs = {
"finetuned_from": model.config._name_or_path,
"tasks": "image-classification",
"dataset": 'beans',
"tags": ['image-classification'],
}
if training_args.push_to_hub:
trainer.push_to_hub('🍻 cheers', **kwargs)
else:
trainer.create_model_card(**kwargs)
由此产生的模型已共享给nateraw/vit-base-beans。我假设你没有豆叶的图片,所以我添加了一些例子让你试一试!🚀
更多推荐
所有评论(0)