pytorch pth model to onnx model и проверьте правильность результатов

искусственный интеллект

Это 10-й день моего участия в ноябрьском испытании обновлений. Узнайте подробности события:Вызов последнего обновления 2021 г.

Очень важным шагом в развертывании модели pytorch является сброс модели pth в ONNX, В этой статье описан метод.

дамп onnx

  • Создайте свою собственную модель pytorch и загрузите веса
model = create_model(num_classes=2)
model.load_state_dict(load(model_path, map_location='cpu')["model"])
  • Дамп файла onnx
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')
torch.onnx._export(model, dummy_input, "faster_rcnn.onnx", verbose=True, opset_version=11)

Сохраните модель в текущем каталогеfaster_rcnn.onnxвнутри файла

Проверить валидность onnx

  • Установитьonnxruntimeбиблиотека
pip install onnxruntime
  • Загрузите модель onnx и протестируйте
import onnxruntime
from onnxruntime.datasets import get_example

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 测试数据
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')

example_model = get_example(<absolute_root_to_your_onnx_model_file>)
# netron.start(example_model) 使用 netron python 包可视化网络
sess = onnxruntime.InferenceSession(example_model)

# onnx 网络输出
onnx_out = sess.run(None, {<input_layer_name_of_your_network>: to_numpy(dummy_input)})
print(onnx_out)

model.eval()
with torch.no_grad():
    # pytorch model 网络输出
    torch_out = model(dummy_input)
	print(torch_out)
  • вывод:
onnx_out
[array([[  0.       ,  93.246    , 228.95842  , 256.       ],
       [  0.       ,   2.6370468, 209.39705  , 148.17822  ]],
      dtype=float32), array([1, 1], dtype=int64), array([0.1501071 , 0.07568519], dtype=float32)]

torch_out
[{'boxes': tensor([[  0.0000,  93.2459, 228.9584, 256.0000],
        [  0.0000,   2.6370, 209.3971, 148.1782]]), 'labels': tensor([1, 1]), 'scores': tensor([0.1501, 0.0757])}]

Получите собственное имя сетевого входного слоя

  • Иногда имя входного слоя модели непонятно, когда вы не знакомы с сетью.Вы можете использовать Netron для визуализации собственной сети, получить имя входного слоя и передать его в сессию onnx.

Уведомление ! ! !

  • В процессе преобразования модели pytorch в модель ONNX используемый экспортер является экспортером на основе траектории, что означает, что при его выполнении ему необходимо запустить модель один раз, а затем экспортировать операторы, которые фактически участвуют в операции. Это также означает, что если ваша модель является динамической, например, изменяя некоторые операции, которые зависят от входных данных, полученные результаты будут неточными.Кроме того, траектория может быть действительна только для определенного размера ввода (вот почему нам нужно имеют явную траекторию Одна из причин вводадокументация по pytorch

  • То есть, если в сетевом модуле есть ветви, подобные if.. else.., одна из ветвей будет выбрана в соответствии с исходными данными, используемыми при создании модели ONNX, так что сгенерированная модель ONNX будет только сохранить эту ветвь Структура исходной модели pytorch больше не существует в этой модели.

использованная литература