Серия PyTorch | Быстрое обучение переносу знаний

PyTorch

оригинальное название | TRANSFER LEARNING TUTORIAL

автор | Sasank Chilamkurthy

оригинальный | py torch.org/tutorials/ нет…

переводчик| kbsc13 (автор публичного аккаунта «The Growth of Algorithm Apes»)

утверждение| Перевод предназначен для общения и обучения, добро пожаловать на перепечатку, но, пожалуйста, сохраните эту статью для коммерческих или незаконных целей.

Введение

В этом руководстве в основном рассказывается, как реализовать трансферное обучение с помощью глубокого обучения. Для получения более подробных знаний о переносе обучения вы можете просмотреть курс cs231n--На данный момент 231 you.GitHub.IO/transfer-...

На практике мало кто начнет с нуля обучать сверточную нейронную сеть со случайной инициализацией, потому что она слишком мала, чтобы иметь достаточное количество наборов данных. Обычно все выбирают предварительно обученную модель, обученную на относительно большом наборе данных (например, набор данных ImageNet, всего 1,2 миллиона изображений в 1000 категориях), а затем инициализируют сверточную нейронную сеть или используют ее для извлечения признаков.

Существует два основных сценария применения трансферного обучения:

  • Точная настройка сети: Инициализируйте сеть с помощью предварительно обученной модели вместо случайной инициализации, которая может быстрее сходиться и достигать лучших результатов.
  • как средство извлечения признаков: это приложение будет фиксировать параметры веса сетевых слоев, кроме последнего полностью подключенного слоя, а затем последний полностью подключенный слой изменит выходные данные в соответствии с количеством категорий набора данных и случайным образом инициализирует свои веса, а затем обучит слой. .

В учебнике этой статьи модели, которые необходимо импортировать в код, следующие:

# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

Скачать данные

Эта часть загрузки данных будет использоватьtorchvisionиtorch.utils.dataдва модуля.

Цель этого руководства — обучить модель бинарной классификации, категории — муравьи и пчелы, поэтому набор данных содержит 120 обучающих изображений муравьев и пчел соответственно, а затем 75 изображений для каждой категории в качестве проверочного набора. То есть в этом наборе данных всего 390 изображений и менее тысячи изображений.Это очень маленький набор данных.Если модель обучается с нуля, трудно получить хорошую способность к обобщению. Поэтому в этой статье будет использоваться метод трансферного обучения для этого набора данных, чтобы улучшить способность к обобщению.

Получите набор данных и код этой статьи, вы можете ответить в фоновом режиме официального аккаунта"передача обучения"Получать.

Код для загрузки данных выглядит так:

# 数据增强方法,训练集会实现随机裁剪和水平翻转,然后进行归一化
# 验证集仅仅是裁剪和归一化,并不会做数据增强
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# 数据集所在文件夹
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Визуализируйте картинки

Сначала визуализируйте некоторые обучающие изображения, чтобы лучше понять увеличение данных. Код выглядит следующим образом:

# 图片展示的函数
def imshow(inp, title=None):
    """Imshow for Tensor."""
    # 逆转操作,从 tensor 变回 numpy 数组需要转换通道位置
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # 从归一化后变回原始图片
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# 获取一个 batch 的训练数据
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

Отображаемые изображения следующие:

Обучите модель

После загрузки данных пришло время начать обучение модели, Здесь будут представлены следующие два содержания:

  • Стратегии повышения скорости обучения
  • Сохраните лучшую модель

В приведенном ниже коде параметрschedulerсостоит в том, чтобы принятьtorch.optim.lr_schedulerИнициализированный объект политики LR:

# 训练模型的函数
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 每个 epoch 都分为训练阶段和验证阶段
        for phase in ['train', 'val']:
            # 注意训练和验证阶段,需要分别对 model 的设置
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清空参数的梯度
                optimizer.zero_grad()

                # 只有训练阶段才追踪历史
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 训练阶段才进行反向传播和参数的更新
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 记录 loss 和 准确率
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 载入最好的模型参数
    model.load_state_dict(best_model_wts)
    return model

Вышеупомянутая функция реализует обучение модели, которое разделено на этапы обучения и проверки в эпоху.Фаза обучения, естественно, требует прямого расчета и обратного распространения и обновляет параметры сетевого уровня, но этап проверки требует только прямого расчета, а затем записывает потери и точность в наборе проверки.

Кроме того, должны быть установлены условия для сохранения модели.Здесь модель сохраняется, когда точность каждого набора проверки выше, чем предыдущая наилучшая точность.

Визуализируйте предсказания модели

Следующее определяет функцию для визуализации результатов прогнозирования модели, которая используется для отображения изображения и информации о прогнозируемой категории модели для изображения:

# 可视化模型预测结果,即展示图片和模型对该图片的预测类别信息,默认展示 6 张图片
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

Точная настройка сети

Эта часть является основным содержанием этого трансферного обучения. Спереди — обычный код для загрузки данных и определения процесса обучения. Вот как можно настроить сеть. Код выглядит следующим образом:

# 加载 resnet18 网络模型,并且设置加载预训练模型
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# 修改输出层的输出数量,本次采用的数据集类别为 2
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# 对所有网络层参数进行更新
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 学习率策略,每 7 个 epochs 乘以 0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Поскольку этот шаг предназначен для установки и загрузки предварительно обученной модели, предварительно обученная модель будет загружена после запуска.resnet18файл сетевой модели

обучение и проверка

Далее происходит официальное обучение сетевой модели.Код следующий.При использовании процессора это занимает около 15-25 минут, а при использовании графического процессора скорость очень высокая, и обучение в основном завершается примерно за одну минуту .

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Результат обучения:

Визуализируйте предсказания модели:

visualize_model(model_ft)

Результаты визуализации следующие:

для экстрактора признаков

Он просто используется для тонкой настройки сети, то есть модель предварительного обучения используется для инициализации параметров сетевого уровня.Далее вводится второе использование трансферного обучения в качестве экстрактора признаков, то есть вес параметры сетевого уровня в фиксированной части модели предобучения. Код реализации этой части выглядит следующим образом, в котором параметры части сверточного слоя необходимо зафиксировать, то есть установитьrequires_grad==False, чтобы их градиенты не вычислялись при обратном распространении, можно посмотреть больше памятиПишите на torch.org/docs/notes/…

model_conv = torchvision.models.resnet18(pretrained=True)
# 固定卷积层的权重参数
for param in model_conv.parameters():
    param.requires_grad = False

# 新的网络层的参数默认 requires_grad=True
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# 只对输出层的参数进行更新
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# 学习率策略,每 7 个 epochs 乘以 0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

Тренируйся снова:

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

Результат тренировки:

Визуализируйте результаты предсказания сети

visualize_model(model_conv)

plt.ioff()
plt.show()

Выходной результат:

резюме

В этом учебном пособии кратко представлено содержание трансферного обучения путем обучения модели бинарной классификации, а также кратко представлены два варианта использования трансферного обучения: тонкая настройка сети и средство извлечения признаков для фиксированных значений.

Кодовый адрес этой статьи:

GitHub.com/CCC013/глубокий…

Получите код и метод набора данных для этой статьи:

  1. Обратите внимание на общественный номер»Рост алгоритма обезьяны"
  2. Официальный интерфейс диалога учетной записи отвечает "передача обучения"

Добро пожаловать в мой общедоступный аккаунт WeChat--Рост алгоритма обезьяныили отсканируйте QR-код ниже, чтобы общаться, учиться и развиваться вместе!