Это пятый день моего участия в ноябрьском испытании обновлений, подробности о мероприятии:Вызов последнего обновления 2021 г.
import torch
from IPython import display
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Установите размер мини-пакета на 256 и прочитайте итератор набора данных.
num_inputs = 784
num_outputs = 10
#初始化为均值为0,方差为0.1的张量
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
-
num_inputs
Картинка в наборе данных 28*28, количество каналов 1, преобразование во входной вектор 784 -
num_outputs
На выходе десять классификаций, а выходной вектор равен 10.
def softmax(X):
X_exp = torch.exp(X)
partition = X_exp.sum(1, keepdim=True) # 按照每行求和,保持维度不变
return X_exp / partition # 这里用到了广播机制
определить софтмакс
def net(X):
return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
Определите модель.
def cross_entropy(y_hat, y):
return - torch.log(y_hat[range(len(y_hat)), y])
Определение потери перекрестной энтропии
def accuracy(y_hat, y):
# y_hat维度>1并且多余一行
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1)
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())
Эта функция используется для вычисления и возврата количества правильных прогнозов.
-
argmax()
Выньте значение индекса, соответствующее самому большому элементу.В это время y_hat был преобразован в вектор значения индекса самого большого элемента - cmp — логическое значение, сравните y_hat и y в это время, чтобы увидеть точность прогноза.
- Наконец, преобразуйте тип данных cmp в тип данных y, то есть преобразуйте true false of bool в 1 0, а затем просуммируйте, чтобы вычислить общее количество правильных прогнозов и вернуть результат.
-
accuracy(y_hat, y) / len(y)
Разделите на общее число y, чтобы найти точность
class Accumulator:
"""在`n`个变量上累加。"""
def __init__(self, n):
self.data = [0.0] * n # 初始化列表,长度为n
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0.0] * len(self.data)
def __getitem__(self, idx):
return self.data[idx]
Класс Accumulator предназначен для реализации аккумулятора. Наложение входящих данных каждый раз.
-
*args
Использование: Когда количество входящих параметров неизвестно, и вам не нужно знать имя параметра- Поскольку длина инициализации неопределенна, поэтому
add
Длина параметра также неопределенна - но выполнить
add
Если вы хотите убедиться, что количество входящих параметров иn
Такой же
- Поскольку длина инициализации неопределенна, поэтому
def evaluate_accuracy(net, data_iter):
if isinstance(net, torch.nn.Module):# 判断类型
net.eval()
metric = Accumulator(2) # 正确预测数、预测总数 共两个
for X, y in data_iter:
metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
Вычисляет точность модели для указанного набора данных.
-
net.eval()
Установите модель в режим оценки -
accuracy(net(X), y)
Рассчитать правильное количество образцов -
y.numel()
общее количество образцов
Оцените и сложите каждую партию, чтобы найти всю сумму.
def updater(batch_size):
return d2l.sgd([W, b], lr, batch_size)
Этот sgd представляет собой функцию обновления весов и параметров вручную, реализованную в версии 3.2.
def train_epoch_ch3(net, train_iter, loss, updater): #@save
"""训练模型一个迭代周期(定义见第3章)。"""
# 将模型设置为训练模式
if isinstance(net, torch.nn.Module):
net.train()
# 训练损失总和、训练准确度总和、样本数
metric = Accumulator(3)
for X, y in train_iter:
# 计算梯度并更新参数
y_hat = net(X)
l = loss(y_hat, y)
if isinstance(updater, torch.optim.Optimizer):
# 使用PyTorch内置的优化器和损失函数
updater.zero_grad()
l.backward()
updater.step()
metric.add(float(l) * len(y), accuracy(y_hat, y),
y.size().numel())
else:
# 使用定制的优化器和损失函数
l.sum().backward()
updater(X.shape[0])
metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
# 返回训练损失和训练准确率
return metric[0] / metric[2], metric[1] / metric[2]
# ch3即第三章的训练函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
"""训练模型。"""
for epoch in range(num_epochs):
train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
test_acc = evaluate_accuracy(net, test_iter)
train_loss, train_acc = train_metrics
assert train_loss < 0.5, train_loss
assert train_acc <= 1 and train_acc > 0.7, train_acc
assert test_acc <= 1 and test_acc > 0.7, test_acc
- assert : используется для оценки выражения и запуска исключения, когда условие выражения ложно.
lr = 0.1
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
В нем были визуальные результаты, но я их удалил.