Примите участие в 9-м дне ноябрьского испытания обновлений и узнайте подробности события:Вызов последнего обновления 2021 г.
Одним из широко используемых наборов данных для классификации изображений является набор данных MNIST.LeCun.Bottou.Bengio.ea.1998
(Только набор данных для распознавания рукописных цифр). Хотя это хороший эталонный набор данных, даже простые модели могут достигать точности классификации более 95% по сегодняшним стандартам, что делает их непригодными для различения сильных и слабых моделей. Теперь мы используем аналогичный, но относительно сложный набор данных Fashion-MNIST, опубликованный в 2017 году.Xiao.Rasul.Vollgraf.2017
.
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# d2l使用svg来显示图片使其清晰度更高
d2l.use_svg_display()
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=False)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=False)
- Здесь появится предупреждение, см. здесь:Подробное объяснение torchvision.transforms.ToTensor | Предупреждение пользователя при использовании transforms.ToTensor() - Самородки (juejin.cn)
-
trans = transforms.ToTensor()
Предварительная обработка, преобразование изображения в тензор.Преобразуйте данные изображения из типа PIL в 32-битный формат с плавающей запятой через экземпляр ToTensor и разделите на 255, чтобы все значения пикселей были между 0 и 1. -
torchvision.datasets.FashionMNIST
: получить набор данных из набора данных torchvision.- обучать, загружать ли обучающий набор данных
- трансформировать ли трансформировать
- скачать ли скачать
len(mnist_train), len(mnist_test)
# 看训练集和测试集的样本数量
mnist_train[0][0].shape
# 查看第一张图片的形状
Здесь вывод формыtorch.Size([1, 28, 28])
, канал, карта 28*28.
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据。"""
return 4
# num_workers多线程读取,这里是4线程
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
Чтение небольшого пакета данных размеромbatch_size
Все выборки также случайным образом перемешиваются в итераторе обучающих данных.
Затем вы можете посмотреть, сколько времени требуется для чтения обучающего набора данных.
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
Наконец, объедините все функции в одну функцию. Это изменение размера зависит от того, следует ли корректировать форму ввода.
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中。"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))