предисловие:
В этой статье в основном описывается, как построить простую сверточную нейронную сеть с использованием Pytorch (среда глубокого обучения), которая в настоящее время популярна и популярна, а также обучить и протестировать ее на наборе данных MNIST. Набор данных MNIST представляет собой набор 28*28 изображений рукописных цифр, а тестовый набор используется для проверки точности распознавания обученной модели рукописных цифр.
Данные PyTorch:
Ссылка на официальную документацию PyTorch:PyTorch documentation, есть не только описания API, но и несколько классических примеров для справки.
Форум официального сайта PyTorch:vision, будет много информации, которой можно поделиться, и ответы на некоторые популярные вопросы.
PyTorch создает практику работы с нейронными сетями:
Вначале импортируйте два файла основной библиотеки torch и torchvision, которые необходимо импортировать в PyTorch.Эти две библиотеки в основном содержат множество методов и функций, которые будет использовать PyTorch.
import torchvision
import torch
from torchvision import datasets, transforms
Стоит отметить, что наборы данных torchvision могут легко загружать набор данных автоматически, и здесь используется набор данных MNIST. Кроме того, такие наборы данных, как COCO, ImageNet, CIFCAR, также можно легко загрузить и использовать, а команда импорта также очень проста.
data_train = datasets.MNIST(root = "./data/",
transform=transform,
train = True,
download = True)
data_test = datasets.MNIST(root="./data/",
transform = transform,
train = False)
root указывает путь, где хранится набор данных, transform указывает, какие операции преобразования необходимо выполнить при импорте набора данных, и если для train установлено значение True, это означает, что импортируется обучающий набор, в противном случае это тестовый набор.
В преобразовании есть много хороших методов, которые можно использовать для операций аргументации данных в наборах данных с меньшим количеством ресурсов изображений.Вот простое преобразование формата тензора и пакетная нормализация.
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
После завершения загрузки данных необходимо выполнить операцию загрузки данных.
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
batch_size = 64,
shuffle = True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
batch_size = 64,
shuffle = True)
batch_size устанавливает количество изображений данных, загружаемых в каждом пакете, равным 64, а для случайного перемешивания установлено значение True для случайного перемешивания в процессе загрузки.
На следующем рисунке показано отображение набора пакетных данных (64 изображения), видно, что все они являются одномерными изображениями размером 28*28.
После загрузки данных можно построить базовую программу. Вот нейронная сеть, включающая сверточный слой и полностью связанный слой.Сверточный слой построен с использованием torch.nn.Conv2d, а слой активации построен с использованием torch.nn .ReLU Для сборки слой пула строится с помощью torch.nn.MaxPool2d, а полносвязный слой строится с помощью torch.nn.Linear.
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(stride=2,kernel_size=2))
self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(1024, 10))
def forward(self, x):
x = self.conv1(x)
x = x.view(-1, 14*14*128)
x = self.dense(x)
return x
который определяет torch.nn.Dropout (p = 0,5), чтобы предотвратить переоснащение модели.
Прямая функция определяет прямое распространение, которое на самом деле является нормальным путем свертки. Сначала он обрабатывается сверткой self.conv1(x), затем сжимается и сглаживается x.view(-1, 14*14*128) и, наконец, классифицируется полным соединением self.dense(x)
После этого вызывается объект «Модель», затем определяется расчет потерь с использованием перекрестной энтропии, а расчет оптимизации использует метод автоматизации Адама, и, наконец, можно начинать обучение.
model = Model()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
Вы можете просмотреть архитектуру нейронной сети перед обучением, вывод на печать показывает следующее
Model (
(conv1): Sequential (
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU ()
(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU ()
(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
(dense): Sequential (
(0): Linear (25088 -> 1024)
(1): ReLU ()
(2): Dropout (p = 0.5)
(3): Linear (1024 -> 10)
)
)
Определите количество раз обучения как 5 и запустите нейронную сеть.После завершения обучения результаты, полученные при вводе тестового набора, следующие
Epoch 0/5
----------
Loss is:0.0003, Train Accuracy is:99.4167%, Test Accuracy is:98.6600
Epoch 1/5
----------
Loss is:0.0002, Train Accuracy is:99.5967%, Test Accuracy is:98.9200
Epoch 2/5
----------
Loss is:0.0002, Train Accuracy is:99.6667%, Test Accuracy is:98.7700
Epoch 3/5
----------
Loss is:0.0002, Train Accuracy is:99.7133%, Test Accuracy is:98.9600
Epoch 4/5
----------
Loss is:0.0001, Train Accuracy is:99.7317%, Test Accuracy is:98.7300
Судя по результатам, это неплохо, точность обучения составляет до 99,73%, а точность теста - 98,96%. Результаты показывают небольшие признаки переобучения, и лучшие результаты были бы достигнуты, если бы на тестовом наборе использовалась более надежная сверточная модель.
Случайным образом предсказать несколько изображений тестового набора и отобразить их визуально
Predict Label is: [3, 4, 9, 3]
Real Label is: [3, 4, 9, 3]
После завершения обучения параметры, полученные в результате обучения, также можно сохранить, чтобы их можно было использовать сразу после следующего импорта.
torch.save(model.state_dict(), "model_parameter.pkl")
Ссылка на полный код:JaimeTang/Pytorch-and-mnist(файл model_parameter.pkl слишком велик и не загружен)