Автор: Фрэнк Одом
Компиляция: McGL
Хуки PyTorch — это простой и мощный способ взлома нейронных сетей, повышающий производительность.
Что такое крючок?
Хуки на самом деле довольно распространены в разработке программного обеспечения и не уникальны для PyTorch. Вообще говоря, «хук» — это функция, которая автоматически выполняется после определенного события. Некоторые примеры хуков, с которыми вы могли столкнуться в реальном мире:
-
Веб-сайт показывает рекламу после того, как вы посещаете N различных страниц.
-
Банковское приложение отправляет уведомление, когда средства зачисляются на ваш счет.
-
Когда окружающее освещение уменьшается, яркость экрана телефона уменьшается.
Этих вещей можно добиться и без хуков, но во многих случаях хуки облегчают жизнь программисту.
PyTorch регистрирует хуки для каждого тензора или объекта nn.Module. Хуки срабатывают при прямом или обратном распространении объектов. Они имеют следующие сигнатуры функций:
from torch import nn, Tensor
def module_hook(module: nn.Module, input: Tensor, output: Tensor):
# For nn.Module objects only.
def tensor_hook(grad: Tensor):
# For Tensor objects only.
# Only executed during the *backward* pass!
Каждый хук может изменять входные, выходные или внутренние параметры модуля. Чаще всего используется для отладки. Но мы увидим, что есть много других применений для них.
Пример №1: Детали выполнения модели
Вы сами вставляли операторы печати в модель, чтобы попытаться выяснить причину сообщения об ошибке? (Конечно, я виноват в этом.) Это отвратительная практика отладки, и во многих случаях мы забываем удалить оператор печати, когда закончим. Из-за этого наш код выглядит непрофессионально, а пользователи получают странные сообщения каждый раз, когда используют ваш код.
Никогда больше! Давайте использовать хуки для отладки моделей без какого-либо изменения их реализации. Например, предположим, что вы хотите знать форму вывода каждого слоя. Мы можем создать простую оболочку, которая выводит форму с помощью хука.
class VerboseExecution(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
# Register a hook for each layer
for name, layer in self.model.named_children():
layer.__name__ = name
layer.register_forward_hook(
lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
)
def forward(self, x: Tensor) -> Tensor:
return self.model(x)
Самое приятное: он работает даже с модулями PyTorch, которые мы не создавали! Давайте продемонстрируем это с помощью ResNet50 и некоторых фиктивных входных данных.
import torch
from torchvision.models import resnet50
verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
_ = verbose_resnet(dummy_input)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])
Пример № 2: Извлечение признаков
Часто мы хотим создать функции из предварительно обученной сети, а затем использовать их для другой задачи (например, классификации, поиска сходства и т. д.). Используя хуки, мы можем извлекать функции, не воссоздавая существующую модель и не изменяя ее каким-либо образом.
from typing import Dict, Iterable, Callable
class FeatureExtractor(nn.Module):
def __init__(self, model: nn.Module, layers: Iterable[str]):
super().__init__()
self.model = model
self.layers = layers
self._features = {layer: torch.empty(0) for layer in layers}
for layer_id in layers:
layer = dict([*self.model.named_modules()])[layer_id]
layer.register_forward_hook(self.save_outputs_hook(layer_id))
def save_outputs_hook(self, layer_id: str) -> Callable:
def fn(_, __, output):
self._features[layer_id] = output
return fn
def forward(self, x: Tensor) -> Dict[str, Tensor]:
_ = self.model(x)
return self._features
Мы можем использовать экстрактор функций, как и любой другой модуль PyTorch. Запустив тот же фиктивный ввод, что и раньше, мы получим:
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)
print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}
Пример № 3: Отсечение градиента
Отсечение градиента — это хорошо известный метод работы со взрывающимися градиентами. PyTorch уже предоставляет служебные методы для обрезки градиента, но мы также можем легко реализовать их с помощью хуков. Любой другой метод отсечения/нормализации/модификации градиента может быть реализован таким же образом.
def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
for parameter in model.parameters():
parameter.register_hook(lambda grad: grad.clamp_(-val, val))
return model
Этот хук запускается во время обратного распространения, поэтому на этот раз мы также вычисляем фиктивную метрику потерь. После выполнения loss.backward() мы можем вручную проверить градиенты параметров, чтобы убедиться, что он работает правильно.
clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()
print(clipped_resnet.fc.bias.grad[:25])
# tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015, 0.0027, 0.0017, -0.0023,
# 0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018, 0.0062,
# 0.0034, -0.0010, 0.0052, 0.0021, 0.0010, 0.0017, -0.0100, 0.0021,
# 0.0020])
"источник:" к data science.com/how-to-use-…