JA EN ZH

基于有限数据的图像分类迁移学习 - 微调实战指南

· 9 分钟阅读

什么是迁移学习 - 用有限数据构建高精度模型

迁移学习利用在大规模数据集上预训练的模型知识来完成数据有限的新任务。将在 ImageNet 1400 万张图像上训练的特征提取器用于仅有数百张图像的自定义分类任务,可在更短时间内达到远超从零训练的精度。

迁移学习为何有效:

CNN 浅层学习通用特征(边缘、纹理、颜色模式),深层学习任务特定特征。浅层特征在不同视觉任务间具有高度可迁移性,因此预训练模型的浅层可直接复用。这意味着新任务只需学习深层的任务特定表示,大幅减少所需数据量和训练时间。

迁移学习的三种策略:

  • 特征提取:冻结预训练模型,仅训练新的分类头。最快速,适合数据极少的场景。
  • 微调:以预训练权重为初始化,对全部或部分层进行额外训练。精度最高但需要更多数据。
  • 领域自适应:通过特殊技术弥合源域与目标域之间的差距。适用于图像特征差异大的场景。

数据量与策略选择指南:

  • 50-200 张图像:特征提取(冻结全部层,仅训练分类头)
  • 200-1000 张图像:微调最后几层(冻结浅层,训练深层+分类头)
  • 1000 张以上:全模型微调(所有层都可训练,使用较小学习率)

微调基础 - PyTorch 实现

微调以预训练模型权重为初始化,在新数据集上进行额外训练。替换最终分类头以匹配新任务的类别数,然后用任务特定数据训练全部或部分层进行适配。

实现步骤:

1. 加载预训练模型:model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=10) 自动替换最终层。2. 准备数据加载器,包含适当的数据增强。3. 设置差异化学习率:预训练层使用较小学习率(1e-5),新分类头使用较大学习率(1e-3)。4. 训练并监控验证损失以防过拟合。

差异化学习率设置:

optimizer = torch.optim.AdamW([{"params": model.features.parameters(), "lr": 1e-5}, {"params": model.classifier.parameters(), "lr": 1e-3}])

预训练层使用 1/100 的学习率,避免破坏已学习的有用特征。新层使用标准学习率快速收敛。

数据增强策略:

  • 基础增强:随机水平翻转、随机旋转(±15°)、颜色抖动
  • 高级增强:RandAugment、CutMix、MixUp(数据量少于 500 张时特别有效)
  • 验证集不做增强,仅做 Resize 和 Normalize

训练技巧:

  • 使用余弦退火学习率调度器平滑降低学习率
  • 早停法(patience=5-10)防止过拟合
  • 梯度裁剪(max_norm=1.0)稳定训练
  • 使用混合精度训练(AMP)加速 2 倍且不损失精度

特征提取 - 仅训练最终层

特征提取完全冻结预训练模型权重,仅训练最终分类层来完成新任务。计算成本低于微调,在数据极度有限(50-200 张图像)且过拟合风险最高时特别有效。

实现方法:

冻结所有参数后替换最终层。使用 requires_grad=False 冻结所有层,然后用匹配类别数的新 Linear 层替换 model.fc。只有新层的参数会被更新,大幅减少训练时间和内存需求。

代码示例:

model = timm.create_model("resnet50", pretrained=True)

for param in model.parameters(): param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)

特征提取的优势:

  • 训练速度快:仅更新最终层,GPU 上几分钟即可完成
  • 过拟合风险低:可训练参数极少,即使数据很少也不易过拟合
  • 无需 GPU:特征提取后的分类器训练在 CPU 上也可快速完成

局限性:

  • 当目标域与 ImageNet 差异大时(医学图像、卫星图像等)精度有限
  • 无法适配任务特定的特征表示
  • 分类精度上限低于微调方法

混合策略 - 渐进式解冻:

先以特征提取模式训练分类头至收敛,然后逐步解冻深层进行微调。这种渐进式方法比一次性微调所有层更稳定,特别适合中等数据量(200-500 张)的场景。

模型架构选择 - ResNet、EfficientNet、ViT

迁移学习的基础模型选择需要权衡精度、推理速度和模型大小。以下比较 2025 年可用的主要选项,涵盖不同硬件约束下的实际部署场景。

ResNet:

2015 年提出的残差连接架构,是迁移学习的标准模型。ResNet-50 拥有 2560 万参数,使用最为广泛。结构简单,在各框架中优化良好。ImageNet Top-1 精度根据变体在 76-80% 之间。推理速度快,适合生产环境部署。

EfficientNet:

通过复合缩放(同时缩放深度、宽度、分辨率)实现最优精度-效率平衡。EfficientNet-B0(540 万参数)在移动端部署中表现优异。B3-B4 在精度和速度间取得良好平衡。比同精度的 ResNet 小 4-8 倍。

Vision Transformer (ViT):

基于自注意力机制的架构,在大数据集上表现卓越。ViT-Base(8600 万参数)在充足数据下超越 CNN。但在小数据集(<1000 张)上容易过拟合,不适合特征提取模式。DeiT 变体通过知识蒸馏改善了小数据性能。

选择指南:

  • 数据少于 500 张:EfficientNet-B0 或 ResNet-34(参数少,不易过拟合)
  • 数据 500-5000 张:EfficientNet-B3 或 ResNet-50(精度与泛化的平衡)
  • 数据超过 5000 张:ViT-Base 或 EfficientNet-B5(充分利用大数据)
  • 边缘部署:MobileNetV3 或 EfficientNet-Lite(推理速度优先)

领域自适应与实用技巧

迁移学习的效果取决于预训练数据与目标任务之间的领域相似度。本节介绍领域差距较大时的应对策略,以及在实际项目中最大化迁移学习性能的实用技巧。

领域差距问题:

将 ImageNet 训练的模型应用于医学图像或卫星图像时,由于图像特征差异巨大会导致性能下降。医学图像通常是灰度的,纹理模式与自然图像完全不同。卫星图像的视角、分辨率和色彩分布也与日常照片截然不同。

应对策略:

  • 中间域预训练:先在与目标域更接近的大数据集上预训练,再微调到目标任务。例如医学图像可先在 X 光数据集上预训练。
  • 自监督预训练:使用 MAE、DINO 等自监督方法在目标域无标签数据上预训练,学习域特定特征。
  • 数据增强对齐:设计模拟目标域特征的增强策略(如医学图像的灰度化、对比度调整)。

实用技巧:

  • 学习率预热:前 5-10% 的训练步骤线性增加学习率,避免初期大梯度破坏预训练特征
  • 标签平滑:使用 0.1 的标签平滑减少过拟合,提高泛化能力
  • 测试时增强(TTA):推理时对输入做多种增强并平均预测,可提升 1-3% 精度
  • 模型集成:组合多个不同架构的微调模型,进一步提升精度

实战项目 - 用 100 张图像构建分类器

分步指南:用约 100 张图像构建图像分类器,附代码示例。以二分类为实际案例,演示完整的迁移学习工作流程。

数据准备:

每类准备 50 张图像,共 100 张。按训练:验证 = 8:2 划分,得到 80 张训练和 20 张验证图像。目录结构遵循标准约定,train 和 val 文件夹下按类别建立子目录。

完整训练代码:

import timm, torch

model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

训练配置:

  • Epoch 数:20-30(配合早停)
  • Batch size:16(小数据集用小 batch)
  • 优化器:AdamW(weight_decay=0.01)
  • 损失函数:CrossEntropyLoss(配合标签平滑 0.1)

预期结果:

使用 EfficientNet-B0 微调,100 张图像的二分类任务通常可达到 90-95% 的验证精度。相比从零训练(通常 60-70%),迁移学习带来 20-30% 的精度提升。训练时间在单 GPU 上约 5-10 分钟。

部署建议:

  • 使用 TorchScript 或 ONNX 导出模型用于生产部署
  • 量化(INT8)可将模型大小减少 4 倍,推理速度提升 2-3 倍
  • 考虑使用 TensorRT 或 OpenVINO 进行推理优化

Related Articles

机器学习数据增强 - 实用图像增强技术

学习机器学习中的数据增强技术,包括几何变换、颜色增强、MixUp/CutMix 以及自动增强策略的实践方法。

目标检测概述 - YOLO、SSD 和 Faster R-CNN 架构与性能对比

全面解析目标检测技术,从 Faster R-CNN 到 YOLO 系列和 SSD,比较各架构的精度、速度和适用场景。

NeRF 基础 - 从图像进行 3D 场景重建

详解 Neural Radiance Fields (NeRF) 的原理,从体积渲染到 Instant NGP 和 3D Gaussian Splatting 等加速方法,涵盖实际工作流程。

扩散模型工作原理 - Stable Diffusion 技术深度解析

从扩散模型原理到 Stable Diffusion 架构。涵盖 DDPM、潜在扩散、CFG、加速技术和实用控制方法。

GAN 图像生成的应用 - 从超分辨率到风格迁移

系统讲解 GAN 在图像处理中的实际应用。涵盖超分辨率、风格迁移、图像修复、人脸生成和实用部署方案。

深度学习超分辨率 - 从 SRCNN 到 Real-ESRGAN 的演进与实践

全面解析深度学习图像超分辨率技术,从 SRCNN 开创性工作到 Real-ESRGAN 的真实世界退化处理,涵盖实际部署指南。

Related Terms