Участвуйте в 21-м дне Ноябрьского испытания обновлений, узнайте подробности события:Вызов последнего обновления 2021 г.
Почему необходимо читать и записывать файлы?
Чтение и запись файлов на самом деле не является чтением наборов данных.
Это когда вы тренируетесь периодически сохранять промежуточные результаты, чтобы гарантировать, что в случае случайного отключения питания сервера или чего-то еще вы потеряете результаты вычислений предыдущих дней.
В этом разделе рассказывается, как хранить веса и всю модель.
import torch
from torch import nn
from torch.nn import functional as F
load
иsave
Для одного тензора мы можем напрямую вызватьload
иsave
Функция читает и записывает их отдельно.
-
torch.saves
torch.save(obj, f, pickle_module=<module 'pickle' from '.../pickle.py'>, pickle_protocol=2)
параметр:
- obj — сохранить объект
- f - строка, имя файла
- pickle_module — модуль для травления метаданных и объектов
- pickle_protocol — укажите протокол pickle для переопределения параметров по умолчанию
-
torch.load
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '.../pickle.py'>)
Прочитать файл с диска через
torch.save()
сохраненный объект.параметр:
- f – строка, имя файла
- map_location — функция или словарь, указывающий, как переназначить место хранения
- pickle_module — модуль для распаковки метаданных и объектов (должен совпадать с pickle_module при сериализации файлов)
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
print(x2)
инициализировать х
Сохраните x в текущей папке и назовите егоx-file
, вы найдете файл с таким же именем в текущей папке.
Конечно, это может быть не 0 1 2 3 после открытия, потому что метод кодирования другой, поэтому не беспокойтесь о том, что вы видите после открытия.
Объявите x2 и прочитайте его из файла, вы обнаружите, что результат тензор ([0, 1, 2, 3]), и результат правильный.
y = torch.zeros(4)
torch.save([x, y],'x-file')
x2, y2 = torch.load('x-file')
print(x2, y2)
>>
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
Сохраните список тензоров и прочитайте их обратно в память.
y = torch.zeros(4)
torch.save(y[:2],'x-file')
x2, y2 = torch.load('x-file')
(x2, y2)
print(x2, y2)
Также возможна нарезка.
mydict = {'x': x, 'y': y}
torch.save(mydict, 'x-file')
mydict2 = torch.load('x-file')
mydict2
>>
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
Хранение словарей тоже работает.
Загрузить и сохранить параметры модели
Фреймворки глубокого обучения предоставляют встроенные функции для сохранения и загрузки целых сетей.
Но вместо сохранения всей модели сохраняются параметры модели.
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
Помните многослойный персептрон, который мы написали вручную, используйте его, чтобы реализовать все сразу.
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
Теперь сгенерируйте сеть, используйте ее для вычисления X и назначьте ее Y.
torch.save(net.state_dict(), 'x-file')
Сохраните параметры сети.
net_ = MLP()
net_.load_state_dict(torch.load('x-file'))
net_.eval()
Генерация net_ также является многоуровневым персептроном, и параметры сети напрямую загружают параметры в файл.
net_.eval()
заключается в изменении режима модели на режим оценки.
Y_clone = net_(X)
print(Y_clone == Y)
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
Назначьте новую сеть Y_clone, вы можете видеть, что Y_clone и Y одинаковы.
Конечно, также можно перейти на собственный слой pytorch.
MLP = nn.Sequential(nn.Linear(20,256),nn.Linear(256,10),nn.ReLU())
def init(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight)
nn.init.zeros_(m.bias)
net = MLP
X = torch.randn(size=(2, 20))
Y = net(X)
torch.save(net.state_dict(), 'x-file')
net_ = MLP
net_.load_state_dict(torch.load('x-file'))
net_.eval()
Y_clone = net_(X)
Y_clone == Y
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
-
Подробнее о серии «Практическое глубокое обучение» см. здесь:Колонка «Практическое глубокое обучение» (juejin.cn)
-
Примечания Адрес Github:DeepLearningNotes/d2l(github.com)
Все еще в процессе обновления......