В этой статье обсуждается реализация Pytorch сети LSTM, а также организация кода и архитектура библиотеки Pytorch.
LSTM
LSTM — это рекуррентная нейронная сеть, подходящая для моделирования сериализованных входных данных. Этот пост Криса ОластатьяПодробное объяснение того, как работает ячейка LSTM, рекомендуется к прочтению.
две идеи
Ворота: ворота информационного потока
я т знак равно с я г м о я г ( Вт Икс я Икс т + Вт час я час т − 1 + б я )" role="презентация">it=sigmoid(Wxixt+Whiht-1+bi) я т знак равно с я г м о я г ( Вт Икс я Икс т + Вт час я час т − 1 + б я )ф т знак равно с я г м о я г ( Вт Икс ф Икс т + Вт час ф час т − 1 + б ф )" role="презентация">ft=sigmoid(Wxfxt+Whfht−1+bf) ф т знак равно с я г м о я г ( Вт Икс ф Икс т + Вт час ф час т − 1 + б ф )
о т знак равно с я г м о я г ( Вт Икс о Икс т + Вт час о час т − 1 + б о )" role="презентация">ot=sigmoid(Wxoxt+Whoht-1+bo) о т знак равно с я г м о я г ( Вт Икс о Икс т + Вт час о час т − 1 + б о )
x" role="presentation">
Ячейка: пул памяти
с т знак равно ф т ⊙ с т − 1 + я т ⊙ т а н час ( Вт Икс с Икс т + Вт час с час т − 1 + б с ) час т знак равно о т ⊙ т а н час ( с т )" role="презентация">ct=ft⊙ct−1+it⊙tanh(Wxcxt+Whcht−1+bc) ht=ot⊙tanh(ct) с т знак равно ф т ⊙ с т − 1 + я т ⊙ т а н час ( Вт Икс с Икс т + Вт час с час т − 1 + б с ) час т знак равно о т ⊙ т а н час ( с т )h" role="presentation">
Сравнение с обычным RNN
Обычные RNN имеют только один самообновляющийся скрытый блок состояния.
LSTM добавляет ячейку пула памяти и обновляет информацию в пуле памяти контролируемым образом через несколько шлюзов, а также определяет скрытое состояние с помощью информации в пуле памяти.
From Scratch
Ниже приведен код для ручной реализации LSTM, наследующий базовый класс.nn.Module
.
import torch.nn as nn
import torch
from torch.autograd import Variable
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, cell_size, output_size):
super(LSTM, self).__init__()
self.hidden_size = hidden_size
self.cell_size = cell_size
self.gate = nn.Linear(input_size + hidden_size, cell_size)
self.output = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.softmax = nn.LogSoftmax()
def forward(self, input, hidden, cell):
combined = torch.cat((input, hidden), 1)
f_gate = self.gate(combined)
i_gate = self.gate(combined)
o_gate = self.gate(combined)
f_gate = self.sigmoid(f_gate)
i_gate = self.sigmoid(i_gate)
o_gate = self.sigmoid(o_gate)
cell_helper = self.gate(combined)
cell_helper = self.tanh(cell_helper)
cell = torch.add(torch.mul(cell, f_gate), torch.mul(cell_helper, i_gate))
hidden = torch.mul(self.tanh(cell), o_gate)
output = self.output(hidden)
output = self.softmax(output)
return output, hidden, cell
def initHidden(self):
return Variable(torch.zeros(1, self.hidden_size))
def initCell(self):
return Variable(torch.zeros(1, self.cell_size))
Несколько ключевых моментов:
- Размер тензора
- Порядок доставки информации
Pytorch Module
Сама библиотека Pytorch инкапсулирует больше функций для реализации LSTM, и организация классов и функций также очень значима. Мое понимание его реализации основано на следующих двух моментах:
- Иерархическая развязка ячеек, слоев и сложенных слоев, каждый слой абстрагирует часть параметров (структуру)
- Передача дескриптора функции: вернуть дескриптор функции после обработки параметров
forward
Начнем следить за картинкой, смотрим исходный кодGitHub.
LSTM-класс
документ:nn/modules/rnn.py
# nn/modules/rnn.py
class RNNBase(Module):
def __init__(self, mode, input_size, output_size):
pass
def forward(self, input, hx=None):
if hx is None:
hx = torch.autograd.Variable()
if self.mode == 'LSTM':
hx = (hx, hx)
func = self._backend.RNN() #!!!
output, hidden = func(input, self.all_weights, hx) #!!!
return output, hidden
class LSTM(RNNBase):
def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs)
-
LSTM
класс простоRNNBase
Декоратор класса. - в базовом классе
nn.Module
Вход__call__()
определяется как вызовforward()
метод, поэтому реальная функциональность реализована в_backend.RNN()
середина
Функция АвтоградРНН
Ищите ниже_backend.RNN
.
документ:nn/backends/thnn.py
# nn/backends/thnn.py
def _initialize_backend():
from .._functions.rnn import RNN, LSTMCell
оригинальный,_backend
Также индекс.
наконец-то нашелRNN()
функция.
документ:nn/_functions/rnn.py
# nn/_functions/rnn.py
def RNN(*args, **kwargs):
def forward(input, *fargs, **fkwargs):
func = AutogradRNN(*args, **kwargs)
return func(input, *fargs, **fkwargs)
return forward
def AutogradRNN(mode, input_size, hidden_size):
cell = LSTMCell
rec_factory = Recurrent
layer = (rec_factory(cell),)
func = StackedRNN(layer, num_layers)
def forward(input, weight, hidden):
nexth, output = func(input, hidden, weight)
return output, nexth
return forward
-
RNN()
является декоратором, в зависимости от того, есть лиcudnn
библиотека решает позвонитьAutogradRNN()
все ещеCudnnRNN()
, тут только наблюдатьAutogradRNN()
-
AutogradRNN()
выбранLSTMCell
,использоватьRecurrent()
функция обработанаCell
составляютLayer
, а потомLayer
входящийStackedRNN()
функция -
RNN()
иAutogradRNN()
возвращаетforward()
дескриптор функции
НижеRecurrent()
функция:
def Recurrent(inner):
def forward(input, hidden, weight):
output = []
steps = range(input.size(0) - 1, -1, -1)
for i in steps:
hidden = inner(input[i], hidden, *weight)
output.append(hidden[0])
return hidden, output
return forward
-
Recurrent()
Функция реализует «рекурсивную» структуру, комбинируя по размеру входных данныхCell
, завершает итерацию скрытых состояний и параметров. -
Recurrent()
функция будетCell(inner)
в сочетании какLayer
.
Функция StackedRNN()
def StackedRNN(inners, num_layers):
num_directions = len(inners)
total_layers = num_layers * num_directions
def forward(input, hidden, weight):
next_hidden = []
hidden = list(zip(*hidden))
for i in range(num_layers):
all_output = []
for j, inner in enumerate(inners):
hy, output = inner(input, hidden[l], weight[l])
next_hidden.append(hy)
all_output.append(output)
input = torch.cat(all_output, input.dim() - 1)
next_h, next_c = zip(*next_hidden)
next_hidden = (torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(total_layers, *next_c[0].size()))
return next_hidden, input
return forward
-
StackedRNN()
функция будетLayer(inner)
объединены в стопку
Наконец, вычисление в базовой ячейке LSTM состоит изLSTMCell()
реализация функции.
Функция LSTMCell()
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
if input.is_cuda:
igates = F.linear(input, w_ih)
hgates = F.linear(hidden[0], w_hh)
state = fusedBackend.LSTMFused()
return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)
return hy, cy
Обратите внимание на приведенный выше код, который является основной формулой передачи информации LSTM. На этом наше путешествие завершено.
резюме
Нет ничего, что нельзя было бы решить, добавив слой абстракции, если нет, то добавьте еще один слой.
Чтобы повторить мое понимание приведенного выше кода:
- Иерархическая развязка ячеек, слоев и сложенных слоев, каждый слой абстрагирует часть параметров (структуру)
- Передача дескриптора функции: вернуть дескриптор функции после обработки параметров
forward
Как луковицу, мы очистили до конца и обнаружили, что обрабатываемая информация — это именно параметры входа, скрытого состояния и нескольких управляющих вентилей блока LSTM. В послойной абстракции Pytorch обрабатывает разные параметры на разных уровнях, обеспечивая масштабируемость и разделение между уровнями абстракции.
@ddlee
Эта статья соответствуетCreative Commons Attribution-ShareAlike 4.0 International License.
Это означает, что вы можете воспроизводить эту статью с указанием авторства, сопровождаемой этим соглашением.
Если вы хотите получать регулярные обновления о моих сообщениях в блоге, пожалуйста, подпишитесьДонгдон ежемесячно.
Ссылка на эту статью: блог Дошел о. Может только /posts/7 не 453…