pytorch реализует слой обратного градиента GRL

алгоритм

 

В GRL цель, которую необходимо достичь, такова: при прямой передаче результат операции не меняется, а при градиентной передаче градиент, переданный предыдущему конечному узлу, становится направлением, противоположным исходному. Лучше всего иллюстрирует пример:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()

print(s)
s.backward()
print(x)
print(x.grad)

Результат работы этой программы:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

Процесс работы для каждого измерения в тензоре:

f(x)=(x^{2}+x)*6

Тогда производная по x равна:

\frac{\mathrm{d} f}{\mathrm{d} x} = 12x+6

Итак, когда ввод x=[1,2,3], соответствующий градиент: [18,30,42]

Итак, это обычный процесс получения градиента, но как перевернуть градиент? Очень просто, посмотрите на код ниже:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input


Grl = GRL()

s =6* f.sum()
s = Grl(s)

print(s)
s.backward()
print(x)
print(x.grad)

Текущий результат:

tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

По сравнению с предыдущей программой, эта программа добавляет только градиентный флип-слой:

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input

Форвард в этой части не выполняет никаких операций, а в реверсе выполняется операция .neg(), что эквивалентно отражению градиента. В обратной части FUnction в torch.autograd значение grad_output по умолчанию равно 1 без каких-либо операций.