[Глубокое обучение-Pytorch] Resnet18 - Распознавание кошек и собак

глубокое обучение

Это 25-й день моего участия в августовском испытании обновлений. Узнайте подробности события:Испытание августовского обновления

Введение

Распознавание кошек и собак является задачей начального уровня в сети CNN.Осуществляя распознавание кошек и собак, вы можете лучше понять структуру и эффект работы сети CNN.Что более ценно, так это то, что распознавание кошек и собак просто реализовать и имеет замечательный эффект, который может быть очень захватывающим Мотивация обучения.

Собаки против кошек | kaggle связь:woohoo.cardreform.com/from/dogs-vs-from…

2. Подготовьте набор данных

woohoo.cardreform.com/from/dogs-vs-from…

С kaggle мы можем загрузить 25 000 изображений кошек и собак, из которых 12 500 изображений кошек и собак.

Есть небольшой трюк, чтобы загрузить набор данных kaggle:

  • Сначала используйте Google Chrome для загрузки, Google Chrome перейдет на зеркальный сайт Google (должен быть), затем скопируйте ссылку для загрузки, откройте ее в инструменте загрузки, таком как Thunder, и скорость загрузки удвоится, сократив время ожидания загрузки.

Разархивируйте загруженные данные в каталог train файла проекта.

3. Разделить данные

import os
import shutil
def get_address():
    """获取所有图片路径"""
    data_file = os.listdir('./train/')
    
    dog_file = list(filter(lambda x: x[:3] == 'dog', data_file))
    cat_file = list(filter(lambda x: x[:3] == 'cat', data_file))

    root = os.getcwd()

    return dog_file, cat_file, root
    
def arrange():
    """整理数据,移动图片位置"""
    dog_file, cat_file, root = get_address()

    print('开始数据整理')
    # 新建文件夹
    for i in ['dog', 'cat']:
        for j in ['train', 'val']:
            try:
                os.makedirs(os.path.join(root,j,i))
            except FileExistsError as e:
                pass

    # 移动10%(1250)的狗图到验证集
    for i, file in enumerate(dog_file):
        ori_path = os.path.join(root, 'train', file)
        if i < 0.9*len(dog_file):
            des_path = os.path.join(root, 'train', 'dog')
        else:
            des_path = os.path.join(root, 'val', 'dog')
        shutil.move(ori_path, des_path)

    # 移动10%(1250)的猫图到验证集
    for i, file in enumerate(cat_file):
        ori_path = os.path.join(root, 'train', file)
        if i < 0.9*len(cat_file):
            des_path = os.path.join(root, 'train', 'cat')
        else:
            des_path = os.path.join(root, 'val', 'cat')
        shutil.move(ori_path, des_path)

    print('数据整理完成')

Поскольку kaggle не предоставляет проверочный набор, мы можем выделить часть обучающего набора в качестве проверочного набора. Обучение с учителем может следовать принципу 8: 1: 1. Мы делим 10% данных на проверочный набор, то есть по 1250 изображений кошек и собак.

Следует отметить, что вынутые здесь 2500 листов являютсяБольше нельзя вернуться к тренировочному набору для тренировкиДа, если тренировочный набор совпадет с проверочным, это приведет к переобучению (результат очень хороший, но в реальном бою использовать не получится).

4. Преобразование в читаемые данные

"""get_data.py"""def get_data(input_size, batch_size):
    """获取文件数据并转换"""
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    from torch.utils.data import DataLoader

    # 串联多个图片变换的操作(训练集)
    # transforms.RandomResizedCrop(input_size) 先随机采集,然后对裁剪得到的图像缩放为同一大小
    # RandomHorizontalFlip()  以给定的概率随机水平旋转给定的PIL的图像
    # transforms.ToTensor()  将图片转换为Tensor,归一化至[0,1]
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  归一化处理(平均数,标准偏差)
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # 获取训练集(通过上面的方面操作)
    train_set = ImageFolder('train', transform=transform_train)
    # 封装训练集
    train_loader = DataLoader(dataset=train_set,
                              batch_size=batch_size,
                              shuffle=True)

    # 串联多个图片变换的操作(验证集)
    transform_val = transforms.Compose([
        transforms.Resize([input_size, input_size]),  # 注意 Resize 参数是 2 维,和 RandomResizedCrop 不同
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # 获取验证集(通过上面的方面操作)
    val_set = ImageFolder('val', transform=transform_val)
    # 封装验证集
    val_loader = DataLoader(dataset=val_set,
                            batch_size=batch_size,
                            shuffle=False)
    # 输出
    return transform_train, train_set, train_loader, transform_val, val_set, val_loader

Для чтения данных я использую функцию чтения, которая поставляется с pytorch. Помимо чтения данных, он также может выполнять унифицированную обработку данных при чтении.

Здесь можно использовать ImageFolder в pytorch для прямого чтения данных набора изображений (первый параметр определяет адрес папки), но каждое изображение имеет разный размер и должно быть преобразовано в идентифицируемые данные. Считанное изображение необходимо преобразовать (т. е. параметр преобразования).В дополнение к масштабированию изображения также требуется обработка нормализации для уменьшения сложности данных и облегчения обработки данных. Через функцию transforms.Compose эти операции по смене изображения можно соединить последовательно, а необходимые данные быстро получить через вызов ImageFolder. Например, выше я использовал его инкапсулированную случайную обрезку до одинакового размера, случайное вращение, нормализацию и другие операции. То есть данные в сеть удобно кидать для обучения, и это может расширить возможности картинки (повернутая собака все равно собака).

5. Создайте сеть

Resnet-18: Остаточная сеть (18 указывает 18 слоев с весами, включая сверточные слои и полносвязные слои, исключая слои пула и слои BN) (После подробного ознакомления с сетью Resnet может быть представлена ​​отдельная статья. Я не буду здесь вдаваться в подробности. Короче говоря, это улучшенная сеть CNN)

Загрузите сетевую модель resnet18 и ее предварительно обученную модель.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")     # 选择训练模式
# pretrained=True   使用预训练模型
# 使用resnet18模型
transfer_model = models.resnet18(pretrained=True)
for param in transfer_model.parameters():
    # 屏蔽预训练模型的权重,只训练最后一层的全连接的权重
    param.requires_grad = False
# 修改最后一层维数,即把原来的全连接层替换成输出维数为2的全连接层
# 提取fc层中固定的参数
dim = transfer_model.fc.in_features
# 设置网络中的全连接层为2
transfer_model.fc = nn.Linear(dim, 2)
# 构建神经网络
net = transfer_model.to(device)

Поскольку мы решаем задачу классификации, и это проблема бинарной классификации, нам нужно установить выход полносвязного слоя равным 2. Мы можем оставить другие сетевые структуры разными.

6. Установите параметры тренировки

input_size = 224
batch_size = 128    # 一次训练所选取的样本数(直接影响到GPU内存的使用情况)
save_path = './weights.pt'  # 训练参数储存地址
lr = 1e-3             # 学习率(后面用)
n_epoch = 10          # 训练次数(后面用)

Установите параметры, необходимые для обучения: input_size: размер входного изображения (обрезано в квадрат, как это) batch_size: количество образцов, выбранных для одного обучения (напрямую влияет на использование памяти графического процессора). save_path: адрес хранения параметров обучения л: скорость обучения n_epoch = 10: количество тренировок

7. Начать обучение

def train(net, optimizer, device, criterion, train_loader):
   """训练"""
    net.train()
    batch_num = len(train_loader)
    running_loss = 0.0
    for i, data in enumerate(train_loader, start=1):
        # 将输入传入GPU(CPU)
        inputs, labels = data  
        inputs, labels = inputs.to(device), labels.to(device)
     # 参数梯度置零、向前、反向、优化
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 计算误差并显示
        running_loss += loss.item()
        if i % 10 == 0:
            print('batch:{}/{} loss:{:.3f}'.format(i, batch_num, running_loss / 20))
            running_loss = 0.0
  • optimizer.zero_grad(): градиент обнуляется (поскольку расчет градиента является кумулятивным).
  • outputs = net(inputs): распространение вперед, чтобы найти прогнозируемое значение.
  • потеря = критерий (выходные данные, метки): Рассчитайте потери.
  • loss.backward(): обратное распространение, вычисление текущего градиента.
  • optimizer.step() : Обновите параметры сети на основе градиентов.

По сути сказать нечего, это процесс ввода данных пачками, выполнение операций, вычисление функции потерь и передача ее обратно, обновление параметров сети и, наконец, постепенное сближение признаков изображения кошек и собак.

Восемь, функция проверки

def validate(net, device, val_loader):
    """验证函数"""
    net.eval()  # 测试,需关闭dropout
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('测试图像的网络精度: %d %%' %
          (100 * correct / total))

При проверке следует отметить, что необходимо использовать net.eval() для запуска режима проверки и отключения отсева, иначе обученная нами сеть будет изменена и сеть будет уничтожена.

9. Начать обучение

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.fc.parameters(), lr=lr)
# optimizer = torch.optim.Adam(net.parameters(), lr=lr)
for epoch in range(n_epoch):
    print('第{}次训练'.format(epoch+1))
    f.train(net, optimizer, device, criterion, train_loader)
    f.validate(net, device, val_loader)

# 保存模型参数
torch.save(net.state_dict(), save_path)

Оптимизатор, который я выбираю, — это метод стохастического градиентного спуска (Адам использовал оба, и SGD немного лучше, чем другие).

Поскольку это проблема классификации, здесь используется функция кросс-энтропийных потерь (она будет представлена ​​в отдельной главе, когда будет время).

10. Результаты обучения

Точность сети может достигать 95% после однократного обучения, а после десятикратного обучения точность может достигать 97%.

Ниже приведена простая инкапсуляция сети с использованием tk, и в результате получается следующее:

image.png

адрес проекта:GitHub.com/1224667889/…