博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
动手学PyTorch | (59) 微调(fine-tuning)
阅读量:4037 次
发布时间:2019-05-24

本文共 5728 字,大约阅读时间需要 19 分钟。

在前⾯的一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最⼴泛的⼤规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。

假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的⽅法是先找出 100种常见的椅子,为每种椅⼦拍摄1,000张不同⻆度的图像,然后在收集到的图像数据集上训练⼀个分类模型。这个椅子数据集虽然可能比Fashion-MNIST数据集要庞大,但样本数仍然不及ImageNet数据集中样本数的十分之一。这可能会导致适用于ImageNet数据集的复杂模型在这个椅⼦数据集上过拟合。同时,因为数据量有限,最终训练得到的模型的精度也可能达不到实用的要求。

为了应对上述问题,⼀个显而易见的解决办法是收集更多的数据。然⽽,收集和标注数据会花费⼤量的 时间和资金。例如,为了收集ImageNet数据集,研究⼈员花费了数百万美元的研究经费。虽然⽬前的数据采集成本已降低了不少,但其成本仍然不可忽略。

另外⼀种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像⼤多跟椅⼦无关,但在该数据集上训练的模型可以抽取较通⽤的图像特征,从⽽能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

本节我们介绍迁移学习中的一种常用技术:微调(fine tuning)。如下图所示,微调由以下4步构成。

1)在源数据集(如ImageNet数据集)上预训练⼀个神经⽹络模型,即源模型。

2)  创建⼀个新的神经网络模型,即⽬标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适⽤于⽬标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在⽬标模型中不予采用。

3)为⽬标模型添加⼀个输出⼤小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。

4)在⽬标数据集(如椅⼦数据集)上训练⽬标模型。我们将从头训练输出层,⽽其余层的参数都基于源模型的参数微调得到。

当⽬标数据集远⼩于源数据集时,微调有助于提升模型的泛化能力。

目录


1. 热狗识别

接下来我们来实践⼀个具体的例子:热狗识别。我们将基于⼀个⼩数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该⼩数据集含有数千张包含热狗和不包含热狗的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

⾸先,导入实验所需的包或模块。torchvision的models包提供了常用的预训练模型。如果希望获取更多的预训练模型,可以使⽤pretrained-models.pytorch库。

%matplotlib inlineimport torchfrom torch import nn, optimfrom torch.utils.data import Dataset, DataLoaderimport torchvisionfrom torchvision.datasets import ImageFolderfrom torchvision import transformsfrom torchvision import modelsimport osimport syssys.path.append(".") import d2lzh_pytorch as d2los.environ["CUDA_VISIBLE_DEVICES"] = "1"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  • 获取数据集

我们使⽤的热狗数据集()是从⽹上抓取的,它含有1400张包含热狗的正类图像,和同样多包含其他食品的负类图像。各类的1000张图像被⽤于训练,其余则⽤于测试。

我们⾸先将压缩后的数据集下载到路径data_dir下,然后在该路径将下载好的数据集解压,得到两个文件夹hotdog/train和hotdog/test。这两个⽂件夹下面均有hotdog和not-hotdog两个类别文件夹,每个类别文件夹里面是图像文件。

data_dir = './Datasets'os.listdir(os.path.join(data_dir, "hotdog"))

上面这种存储结构可以使用ImageFolder 实例来分别读取训练数据集和测试数据集中的所有图像文件。

train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))

下⾯画出前8张正类图像和最后8张负类图像。可以看到,它们的⼤小和⾼宽⽐各不相同。

hotdogs = [train_imgs[i][0] for i in range(8)]not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在训练时,我们先从图像中裁剪出随机⼤小和随机⾼宽比的⼀块随机区域,然后将该区域缩放为⾼和宽 均为224像素的输⼊入。测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高和宽均为224 像素的中⼼区域作为输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化:每个数值 减去该通道所有数值的平均值,再除以该通道所有数值的标准差作为输出。

在使⽤预训练模型时,⼀定要和预训练时作同样的预处理。 如果你使用的是torchvision.models,那就要求All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. 如果你使用的是pretrained-models.pytorch仓库,请务必阅读其README,其中说明了如何预处理。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])train_augs = transforms.Compose([        transforms.RandomResizedCrop(size=224),        transforms.RandomHorizontalFlip(),        transforms.ToTensor(),        normalize    ])test_augs = transforms.Compose([        transforms.Resize(size=256),        transforms.CenterCrop(size=224),        transforms.ToTensor(),        normalize    ])
  • 定义和初始化模型

我们使⽤在ImageNet数据集上预训练的ResNet-18作为源模型。这⾥指定pretrained=True来⾃动下载并加载预训练的模型参数(False只加载结构,不加载参数)。在第一次使用时需要联⽹下载模型参数。

pretrained_net = models.resnet18(pretrained=True)

不管你是使用的torchvision的models还是pretrained-models.pytorch仓库,默认都会将预训练好的模型参数下载到你的home⽬录下 .torch文件夹。你可以通过修改环境变量$TORCH_MODEL_ZOO来更改下载⽬录 :export TORCH_MODEL_ZOO="/local/pretrainedmodels 另外我⽐较常使用的方法是,在其源码中找到 下载地址直接浏览器输入地址下载,下载好后将其放到环境变量$TORCH_MODEL_ZOO所指文件夹即可,这样⽐较快。

下⾯打印源模型的成员变量 fc。作为一个全连接层,它将ResNet最终的全局平均池化层输出变换成 ImageNet数据集上1000类的输出。

print(pretrained_net.fc)

如果你使⽤的是其他模型,那可能没有成员变量fc (⽐如models中的VGG预训练模型), 所以正确做法是查看对应模型源码中其定义部分,这样既不会出错也能加深我们对模型的理解。 pretrained-models.pytorch仓库貌似统一了接口,但是我还是建议使⽤时查看一下对应模型的源码。

可⻅此时pretrained_net最后的输出个数等于⽬标数据集的类别数1000。所以我们应该将最后的fc修改成我们需要的输出类别数:

pretrained_net.fc = nn.Linear(512, 2)print(pretrained_net.fc)

此时, pretrained_net的fc层就被随机初始化了,但是其他层依然保存着预训练得到的参数。由于是在很大的ImageNet数据集上预训练的,所以参数已经足够好,因此一般只需使用较小的学习率来微调这些参数(或者也可以冻结),⽽fc中的随机初始化参数⼀般需要更⼤的学习率从头训练。PyTorch可以⽅便的对模型的不同部分设置不同的学习参数,我们在下面代码中将fc的学习率设为已经预训练过的部分的10倍。

output_params = list(map(id, pretrained_net.fc.parameters()))feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())lr = 0.01optimizer = optim.SGD([{'params': feature_params},                       {'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],                       lr=lr, weight_decay=0.001)
  • 微调模型

我们先定义⼀个使⽤微调的训练函数train_fine_tuning以便多次调用。

def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=5):    train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform=train_augs),                            batch_size, shuffle=True)    test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs),                           batch_size)    loss = torch.nn.CrossEntropyLoss()    d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)

根据前⾯的设置,我们将以10倍的学习率从头训练⽬标模型的输出层参数。

train_fine_tuning(pretrained_net, optimizer)

作为对比,我们定义⼀个相同的模型,但将它的所有模型参数都初始化为随机值。由于整个模型都需要 从头训练,我们可以使⽤较⼤的学习率。

scratch_net = models.resnet18(pretrained=False, num_classes=2) #只加载结构 不加载参数 同时修改输出层lr = 0.1optimizer = optim.SGD(scratch_net.parameters(), lr=lr, weight_decay=0.001)train_fine_tuning(scratch_net, optimizer)

 

可以看到,微调的模型因为参数初始值更好,往往在相同迭代周期下取得更⾼的精度。

 

2. 小结

1)迁移学习将从源数据集学到的知识迁移到⽬标数据集上。微调是迁移学习的⼀种常用技术。

2)⽬标模型复制了源模型上除了输出层外的所有模型设计及其参数,并基于目标数据集微调这些参数。⽽⽬标模型的输出层需要从头训练。

3)⼀般来说,微调参数会使用较小的学习率,而从头训练输出层可以使用较大的学习率。

 

转载地址:http://ewsdi.baihongyu.com/

你可能感兴趣的文章
idea讲web项目部署到tomcat,热部署
查看>>
IDEA Properties中文unicode转码问题
查看>>
Idea下安装Lombok插件
查看>>
zookeeper
查看>>
Idea导入的工程看不到src等代码
查看>>
技术栈
查看>>
Jenkins中shell-script执行报错sh: line 2: npm: command not found
查看>>
8.X版本的node打包时,gulp命令报错 require.extensions.hasownproperty
查看>>
Jenkins 启动命令
查看>>
Maven项目版本继承 – 我必须指定父版本?
查看>>
Maven跳过单元测试的两种方式
查看>>
通过C++反射实现C++与任意脚本(lua、js等)的交互(二)
查看>>
利用清华镜像站解决pip超时问题
查看>>
[leetcode BY python]1两数之和
查看>>
微信小程序开发全线记录
查看>>
PTA:一元多项式的加乘运算
查看>>
CCF 分蛋糕
查看>>
解决python2.7中UnicodeEncodeError
查看>>
小谈python 输出
查看>>
Django objects.all()、objects.get()与objects.filter()之间的区别介绍
查看>>