Практическое руководство по глубокому обучению 5.4 PyTorch Чтение и запись файлов

глубокое обучение PyTorch

Участвуйте в 21-м дне Ноябрьского испытания обновлений, узнайте подробности события:Вызов последнего обновления 2021 г.

Почему необходимо читать и записывать файлы?

Чтение и запись файлов на самом деле не является чтением наборов данных.

Это когда вы тренируетесь периодически сохранять промежуточные результаты, чтобы гарантировать, что в случае случайного отключения питания сервера или чего-то еще вы потеряете результаты вычислений предыдущих дней.

В этом разделе рассказывается, как хранить веса и всю модель.

import torch
from torch import nn
from torch.nn import functional as F

loadиsave

Для одного тензора мы можем напрямую вызватьloadиsaveФункция читает и записывает их отдельно.

  • torch.saves

    torch.save(obj, f, pickle_module=<module 'pickle' from '.../pickle.py'>, pickle_protocol=2)
    

    параметр:

    • obj — сохранить объект
    • f - строка, имя файла
    • pickle_module — модуль для травления метаданных и объектов
    • pickle_protocol — укажите протокол pickle для переопределения параметров по умолчанию
  • torch.load

    torch.load(f, map_location=None, pickle_module=<module 'pickle' from '.../pickle.py'>)
    

    Прочитать файл с диска черезtorch.save()сохраненный объект.

    параметр:

    • f – строка, имя файла
    • map_location — функция или словарь, указывающий, как переназначить место хранения
    • pickle_module — модуль для распаковки метаданных и объектов (должен совпадать с pickle_module при сериализации файлов)
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
print(x2)

инициализировать х

Сохраните x в текущей папке и назовите егоx-file, вы найдете файл с таким же именем в текущей папке. Конечно, это может быть не 0 1 2 3 после открытия, потому что метод кодирования другой, поэтому не беспокойтесь о том, что вы видите после открытия.

image.png

Объявите x2 и прочитайте его из файла, вы обнаружите, что результат тензор ([0, 1, 2, 3]), и результат правильный.

y = torch.zeros(4)
torch.save([x, y],'x-file')
x2, y2 = torch.load('x-file')
print(x2, y2)
>>
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

Сохраните список тензоров и прочитайте их обратно в память.

y = torch.zeros(4)
torch.save(y[:2],'x-file')
x2, y2 = torch.load('x-file')
(x2, y2)
print(x2, y2)

Также возможна нарезка.

mydict = {'x': x, 'y': y}
torch.save(mydict, 'x-file')
mydict2 = torch.load('x-file')
mydict2
>>
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

Хранение словарей тоже работает.

Загрузить и сохранить параметры модели

Фреймворки глубокого обучения предоставляют встроенные функции для сохранения и загрузки целых сетей.

Но вместо сохранения всей модели сохраняются параметры модели.

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

Помните многослойный персептрон, который мы написали вручную, используйте его, чтобы реализовать все сразу.

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

Теперь сгенерируйте сеть, используйте ее для вычисления X и назначьте ее Y.

torch.save(net.state_dict(), 'x-file')

Сохраните параметры сети.

net_ = MLP()
net_.load_state_dict(torch.load('x-file'))
net_.eval()

Генерация net_ также является многоуровневым персептроном, и параметры сети напрямую загружают параметры в файл.

net_.eval()заключается в изменении режима модели на режим оценки.

Y_clone = net_(X)
print(Y_clone == Y)
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

Назначьте новую сеть Y_clone, вы можете видеть, что Y_clone и Y одинаковы.

Конечно, также можно перейти на собственный слой pytorch.

    
MLP = nn.Sequential(nn.Linear(20,256),nn.Linear(256,10),nn.ReLU())

def init(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight)
        nn.init.zeros_(m.bias)


net = MLP
X = torch.randn(size=(2, 20))
Y = net(X)

torch.save(net.state_dict(), 'x-file')

net_ = MLP
net_.load_state_dict(torch.load('x-file'))
net_.eval()

Y_clone = net_(X)
Y_clone == Y
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

  1. Подробнее о серии «Практическое глубокое обучение» см. здесь:Колонка «Практическое глубокое обучение» (juejin.cn)

  2. Примечания Адрес Github:DeepLearningNotes/d2l(github.com)

Все еще в процессе обновления......