[Анализ исходного кода] Конвейер глубокого обучения, параллельный GPipe(3) -- перерасчет

машинное обучение глубокое обучение

0x00 сводка

GPipe — это параллельная библиотека для обучения нейронных сетей, разработанная командой Google Brain и поддерживающая сверхкрупномасштабные модели. В этой статье представлена ​​ее функция пересчета, которую можно проверить с помощью других реализаций.

Другие статьи из этой серии:

[Анализ исходного кода] Конвейер глубокого обучения, параллельный Gpipe(1) --- Базовая реализация конвейера

[Анализ исходного кода] Конвейер глубокого обучения, параллельный GPipe (2) ----- накопление градиента

0x01 Обзор

1.1 Предыдущий обзор

Как упоминалось выше, существует несколько необходимых параллельных технологий для обучения распределенной модели:

  • Параллельный конвейер, особенно как автоматически установить конвейер;
  • накопление градиента;
  • обратный пересчет;
  • стратегия 1F1B (будем использовать анализ PipeDream);

В предыдущей статье мы рассказали, как Gpipe реализует конвейерный параллелизм и накопление градиентов.

Есть проблема с конвейерным параллелизмом: использование памяти слишком велико. Если промежуточный результат (активация) каждого микропакетного прямого расчета потребляется обратным расчетом, необходимо кэшировать n копий (количество раз накопления градиента) полной прямой активации в видеопамяти. В это время необходимо использовать еще один важный метод: пересчет (контрольная точка).

На основе статьи «Обучение глубоких сетей с сублинейной стоимостью памяти» в этой статье анализируется исходный код pytorch и Gpipe, в надежде получить конкретное представление о технологии «градиентных контрольных точек».

1.2 Gradient checkpointing

В 2016 году команда Чена Тяньци предложила такие технологии, как «контрольные точки градиента / активации (обратный пересчет)», связанные с сублинейной оптимизацией памяти, с целью уменьшить использование памяти, вызванное промежуточной активацией во время обучения глубокому обучению. Технология контрольных точек — это своего рода сублинейная оптимизация памяти, в дополнение к другим технологиям, таким как разгрузка ЦП (разгрузка ЦП широко используется в фреймворке Microsoft Deepspeed).

Градиентная контрольная точка — это систематический подход к снижению потребления памяти во время глубокого обучения нейронной сети путем повторного запуска сегментов прямого распространения для каждого сегмента, установленного в качестве контрольной точки при обратном распространении:

  • Методы градиентных контрольных точек сосредоточены на уменьшении накладных расходов памяти на хранение промежуточных результатов (карт объектов) и градиентов, поскольку во многих распространенных глубоких сетях промежуточные результаты намного больше по сравнению с параметрами модели.
  • Градиентная контрольная точка — это метод, который обменивает время (вычислительную мощность) на пространство (видеопамять), сжимая след модели за счет уменьшения сохраненного значения активации, но значение активации, которое не сохраняется, должно быть пересчитано при расчете градиента, т. е. занимает в два раза больше времени расчета стоимости прямого распространения.
  • В частности, это установка некоторых контрольных точек градиента и выпуск промежуточных результатов за пределы контрольных точек.В дальнейшем, если прямой результат будет найден не в видеопамяти в процессе обратного распространения, будет найдена ближайшая контрольная точка градиента, а затем Вычислить, восстанавливая освобожденные тензоры.

0x02 Базовые знания

2.1 Как работает вывод

заимствовано отсюдаТехнология оптимизации видеопамяти во время обучения - слияние ОП и контрольная точка градиентаидеи.

Модели DNN состоят из ряда слоев разных типов (например, сверточные слои, полносвязные слои, объединяющие слои).

Ключом к обратному распространению является «автоматический вывод цепочки», но на самом деле BP также добавляет к этой основе небольшой механизм динамического программирования. Общий BP состоит из следующих двух шагов:

  • прямое проведение. Взяв в качестве примера классификацию изображений, текущие модели сначала делают прогнозы на небольшом наборе обучающих выборок (также известных как мини-пакеты). Этот процесс называется прямой проводимостью.

    • Для прогнозирования входные данные из мини-пакета подаются на первый слой модели.
    • Затем каждый слой вычисляет функцию на своем входе, чтобы сгенерировать выход для следующего слоя. Прямая проводимость записывает следующие два значения: выходное значение промежуточного узла и градиент выходного значения по отношению к входному значению.
    • Результатом последнего слоя является предсказание класса. На основе прогнозируемых меток модели и фактических меток для каждого изображения выходной слой вычисляет потери (или ошибки).
  • Расчет градиента обратного распространения. Обратное распространение — это процесс вычисления градиента конечного выходного значения сети по отношению к выходу этого слоя. То есть, начиная с вывода, распространяя обратно значения градиента,Рассчитайте градиент выходного значения для каждой промежуточной переменной и сохраните его.. Каждый слой вычисляет ошибку предыдущего слоя и обновляет веса (градиенты потерь) для всех связанных слоев, что сдвинет прогнозы модели к желаемому результату.

В процессе возврата градиента необходимо использовать выходное значение узла, но когда вычисление градиента выполняется в обратном распространении, БП не будет выполнять повторные вычисления, причина в том, что промежуточные переменные сохраняются при прямой передаче , то естьВыходное значение каждого промежуточного узла. BP непрерывно выполняет обратное распространение градиентов и сохраняет промежуточные градиенты до тех пор, пока не будут решены все промежуточные значения вычислительного графа и градиенты начальных значений.

Давайте посмотрим, как работает обратная обработка.

Так называемая система автоматического вывода на самом деле является «полуавтоматической»: она не находит напрямую аналитическую форму производной сложной функции, а строит расчетный график и заранее написанные правила вывода для базовых функций, объединенных с цепочкой правило вывода Реализован автоматический вывод.

Возьмем функцию в качестве примера для иллюстрации, ее выражение выглядит следующим образом:

f(x) = x * (x + 1)

Аналитическая формула его градиента, полученная простым математическим выводом:f'(x) = x + 1 + x; Сначала отложите этот результат и посмотрите, как схема автоматического вывода шаг за шагом находит этот результат, и нарисуйте диаграмму расчета следующим образом:

                       +---------+
                       |         |
               +------>+  x + 1  +----+
               |       |         |    | 3
             2 |       +---------+    |
               |                      |
               |                      v
         +-----+--+                  ++------+
         |        |                  |       |
+------> |    x   +----------------> |   +   +---------->
         |        |         1        |       |
         +--------+                  +-------+
​

На вычислительном графе обратное распространение сначала проходит через операцию умножения в соответствии с приведенными выше правилами вывода:

  • Градиент на пути 1 равенx + 1;
  • Градиент на пути 3x;
  • Путь 3 распространяется обратно через путь 2, за исключением того, что его градиентx + 1Кроме того, умножьте градиент пути 21,
  • Путь 2 и Путь 1 сходятся вместе, поэтому окончательный градиентx + 1(路径1)+ 1 * x(路径2)= x + 1 + x, что в точности равно результату, рассчитанному по математической формуле;

Платформа автоматической деривации опирается на эти основные правила и правила деривации цепочки, чтобы работать эффективно и точно.

В процессе обучения подавляющего большинства нейронных сетей некоторые промежуточные переменные, полученные при прямом проходе, очень полезны (для удобства вывода) при расчете обратного прохода. На практике лучше всего запрограммировать кэширование этих промежуточных переменных, чтобы их можно было использовать во время обратного распространения. Поэтому основную часть занимаемой памяти занимает промежуточный результат, представляющий собой так называемую «карту признаков». В этой статье x — это промежуточный результат (карта объектов), выдаваемый предыдущим слоем.

При применении правила вывода умножения мы должны заранее сохранить промежуточные результаты x и x+1.. Обратите внимание, что правила умножения и его вывода, определенные фреймворком, являются общими.Левая и правая части умножения могут быть двумя несвязанными значениями, поэтому они должны сохраняться одновременно. То есть x + 1 может быть ( x + y ) + z .... в других функциях и может содержать другие входные переменные, поэтому его нельзя вычислить из входного x по простой формуле, например + 1 .

Без учета оптимизации самого фреймворка использование памяти включает в себя x и x + 1. Обратите внимание, что x — это не отдельное значение, а что-то вроде32x32x128Карты объектов такого размера.

2.2 Контрольная точка градиента

Как упоминалось в предыдущем разделе, в оригинальном виде нейронных сетей:

  • В прямой функции значения функции активации каждого слоя необходимо сохранять после расчета, потому что они должны быть использованы при расчете обратного распространения.
  • При движении назад градиент рассчитывается на основе значения функции потерь и соответствующего значения функции активации слоя.
  • Следовательно, нам нужно кэшировать n копий (количество накоплений градиента) полных прямых активаций в видеопамяти. То есть в этом случае использование памяти пропорционально количеству слоев.

Поэтому в настоящее время существует проблема с конвейерным параллелизмом: использование памяти слишком велико.

Можно ли не хранить значение активации? Например, в обратном направлении вы можете перемотать вперед, когда вам нужно активировать значение функции.

Если мы не сохраним ни одного из них, мы все пересчитаем через forward? Тогда это слишком трудоемко в больших моделях. Следовательно, мы можем выбрать компромиссный метод, например, сохранить только значения функции активации некоторых слоев. Когда в обратном направлении требуется значение функции активации, просто возьмите самое последнее значение активации. Так была внедрена важная технология: Checkpointing.

2.3 Содержание статьи

2.3.1 Основные документы

Основная идея Gpipe's Checkpointing исходит из следующих двух статей:

  • Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of check- pointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1):19–45, 2000.
  • Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.

Основная идея заключается в обмене вычислительной мощности на память (вычисления на видеопамять, а промежуточные результаты, необходимые для обратного вывода, пересчитываются из контрольной точки), а пропускную способность на видеопамять.

2.3.2 Бумажные обучающие глубокие сети с сублинейной стоимостью памяти

2.3.2.1 Основная идея

Мы в основном смотрим на эту бумагу.

Контрольные точки упоминаются в статье Чена Тяньци «Обучение глубоких сетей с сублинейной стоимостью памяти», опубликованной в 2016 году, также известной как сублинейная оптимизация памяти. Есть две идеи для сублинейной оптимизации памяти: контрольные точки и разгрузка ЦП:

  • Основная идея контрольной точки состоит в том, чтобы пометить небольшое количество тензоров (тензоров, которые являются контрольными точками) в прямой сети, и только эти отмеченные тензоры будут сохранены в прямом расчете.Тензор временно пересчитывается один раз вперед, чтобы получить его. Таким образом, большое количество активаций не нужно сохранять до обратного вычисления, что эффективно сокращает жизненный цикл большого количества тензоров и значительно повышает эффективность повторного использования памяти.
  • Идея разгрузки ЦП аналогична технологии «виртуальной памяти» в компьютерных операционных системах (временная замена памяти, которая обычно не используется, на диск, тем самым увеличивая общий объем памяти).При глубоком обучении память графического процессора (Память устройства) Характеристики дорогие, высокая скорость и небольшая емкость, в то время как основная память ЦП (Host Memory) характеризуется дешевизной, относительно низкой скоростью и большой емкостью; затем временно замените некоторые временно неиспользуемые активации в прямом расчете на Основная память ЦП, а затем подкачать ее в видеопамять ГП, когда требуется обратный расчет, что также может сэкономить видеопамять.

Две сублинейные оптимизации памяти достигают оптимизации видеопамяти разными способами: контрольные точки обменивают видеопамять с дополнительными вычислительными затратами, а разгрузка ЦП обменивает видеопамять с дополнительными затратами на передачу.

2.3.2.2 Оптимизация контрольных точек

На рисунке выше показано сравнение графика расчета до и после Checkpointing.

Серый слева — конфигурация сети.

Средний график нормального градиента представляет собой процесс распространения в прямом и обратном направлении нормальной сети.

Градиентный график, оптимизированный для памяти, справа является результатом применения градиента-контрольной точки. Для дальнейшего сокращения памяти некоторые промежуточные результаты удаляются и при необходимости восстанавливаются из дополнительных вычислений вперед.

  • Во-первых, нейронная сеть разделена на несколько частей (три сегмента на рисунке справа), алгоритм запоминает только вывод каждого сегмента, и в каждом сегменте удаляет все промежуточные результаты.
  • Во-вторых, на этапе обратного распространения мы можем повторно вычислить отброшенные промежуточные результаты, пробежав вперед от самого последнего записанного результата.
  • Таким образом, мы оплачиваем только стоимость памяти для хранения выходных данных каждого сегмента плюс максимальную стоимость памяти для обратного распространения по каждому сегменту.

такgradient-checkpointэтоДело не в том, что промежуточные результаты не нужны, но есть способ вычислить промежуточные результаты, которые были отброшены ранее, в режиме реального времени в процессе вывода..

Пересчет не предназначен только для конвейерного параллелизма, и раньше он в основном использовался в сценариях с одной картой или параллельными данными. Но эта оптимизация очень критична при конвейерном параллелизме, потому что она делает ненужным кэширование всех активаций в прямом направлении, а требует кэширования только очень небольшого числа (например, только одного слоя Transformer Layer), конкретного тензора, который проверено, это значительно экономит накладные расходы памяти при конвейерном параллелизме.

0x03 OpenAI

Представлено на OpenAIgradient-checkpointэто тезисTraining Deep Nets with Sublinear Memory CostРеализация идеи, т.к. ее документация относительно полная (GitHub.com/openlove/grad…), мы можем извлечь из этого уроки.

Общая идея состоит в том, чтобы установить несколько контрольных точек в середине нейронной сети и сохранять контрольную точку каждый sqrt (n) для карты промежуточных результатов. Все промежуточные результаты, кроме контрольных, отбрасываются, а когда для обратного распространения требуется промежуточный результат, расчет начинается с ближайшей контрольной точки, что не только экономит видеопамять, но и позволяет избежать утомительного процесса расчета с нуля.

3.1 Расчетный граф

Для простой n-слойной нейронной сети с прямой связью вычислительный граф для получения градиента выглядит следующим образом:

детали следующим образом:

  • Иерархические активации нейронной сети соответствуют узлам, помеченным f, и при прямом распространении все эти узлы необходимо вычислять последовательно.
  • Функция потерь использует метки b-узлов для активаций и градиентов этих параметров иерархии, и во время обратного распространения все эти узлы необходимо вычислять в обратном порядке.
  • Вычисление активаций f-узлов является необходимым условием для дальнейшего вычисления градиентов b-узлов, поэтому f-узлы остаются в памяти после прямого распространения.
  • Эти активации удаляются из памяти только тогда, когда обратное распространение выполняется достаточно далеко, чтобы вычисление соответствующих градиентов больше не требовало использования активаций на более поздних уровнях или дочерних элементах f . Это означает, что простое обратное распространение требует памяти, которая линейно растет с количеством слоев в нейронной сети.

3.2 Пересчет

Простое обратное распространение уже оптимально с вычислительной точки зрения, потому что каждый узел нужно вычислить только один раз. Однако если мы захотим пересчитать узлы, то сможем сэкономить много памяти. Когда нам нужно значение активации узла, мы можем просто повторно вычислить значение активации узла с прямым распространением. Мы можем выполнять вычисления последовательно, пока не вычислим узлы, которые необходимо распространить обратно со значениями активации.

Использование этой стратегии требует, чтобы память для вычисления градиентов была стабильной на количестве слоев нейронной сети n, а n было оптимальным с точки зрения памяти. Обратите внимание, однако, что количество вычислений для узлов теперь увеличено на n^2 по сравнению с предыдущим n. Каждый из n узлов пересчитывается n раз. Следовательно, вычислительный граф становится очень медленным для вычисления глубоких сетей, что делает этот подход непригодным для глубокого обучения.

3.3 Стратегия

Чтобы соблюсти баланс между памятью и вычислением, нам нужна политика, позволяющая пересчитывать узлы, но это происходит не очень часто. Наша стратегия здесь состоит в том, чтобы использовать нейронную сеть для активации подмножества, помеченного как узел. Фиолетовый узел представляет собой заданный период времени, который необходимо хранить в памяти.

Эти узлы контрольных точек остаются в памяти после прямого распространения, в то время как оставшиеся узлы пересчитываются не более одного раза. После пересчета узлы без контрольных точек остаются в памяти до тех пор, пока они больше не понадобятся для выполнения обратного распространения. Для простой нейронной сети с прямой связью узлы активации всех нейронов являются либо точками соединения, либо точками разделения графа, определяемого прямым распространением. Это означает, что нам нужно пересчитать только узел между узлом b и последней контрольной точкой в ​​процессе обратного распространения.Когда обратное распространение достигает узла контрольной точки, который мы сохранили, тогда все узлы, пересчитанные из этого узла, находятся в памяти, и их можно удалить.

3.4 Процесс

Сначала мы устанавливаем две контрольные точки, две фиолетовые слева от первой строки на рисунке, обратите внимание, что первая фиолетовая справа — это вход.

Во-вторых, прямое распространение завершено, а обратное распространение запущено, то есть обратное распространение начинается с фиолетовой цифры 1 в нижнем ряду.

В-третьих, пришел к фиолетовому числу 2 в следующей строке, которое зависит от фиолетового числа 3 выше для расчета (напомним, что вычисление обратного распространения требует вывода прямого вычисления), это фиолетовое число 3 является контрольной точкой, которая существует в память, поэтому выполняйте обратное распространение нормально

В-четвертых, я пришел к белой цифре 4 в следующей строке.Это зависит от фиолетовой цифры 5 выше для расчета.Цифра 5 не является контрольной точкой и не находится в памяти.Нужно перезапустить расчет с контрольной точки перед ней, что есть, начните с фиолетового числа 7. . Вычисляется новая контрольная точка, а изначальную фиолетовую цифру 5 в строке выше можно удалить, потому что она больше не нужна.

В-пятых, вычислите новое фиолетовое число 4 ниже, тем самым продолжив вычисление в обратном направлении.

Из-за того, что он включает автоматическую генерацию контрольных точек, код OpenAI относительно неясен, поэтому здесь он не будет анализироваться, если вам интересно, вы можете изучить его самостоятельно.

0x04 Реализация Pytorch

Давайте теперь посмотрим на Пёрха.

4.1 Базовые знания

4.1.1 Variable & Function

В PyTorch autograd является основным содержимым всех нейронных сетей, предоставляя методы автоматического вывода для всех операций Tensor. Это среда определения-за-запуском, что означает, что обратное распространение определяется тем, как выполняется код.

autograd.Variable — это autograd в основных классах. Обертка тензора и поддерживает почти все операции, определенные на нем. После того, как вы завершили свою операцию, вы можете вызвать .backward() для автоматического расчета всех градиентов.

Еще один класс, который очень важен для реализации autograd, — функция, функция — это просто операция над переменной, такая как сложение, вычитание, умножение и деление, relu, pool и т. д. Но это не простая арифметика. В отличие от обычных операций Python или numpy, функция предназначена для вычислительных графиков и должна вычислять градиент обратного распространения. Следовательно, ему нужно не только выполнять эту операцию (прямой процесс), но также нужно использовать кеш, чтобы сохранять входные данные прямого распространения (для вычисления градиента) и поддерживать обратное распространение для вычисления градиента.

Pytorch использует переменные и функции для построения вычислительных графиков. Оглядываясь назад на Переменную, Переменная похожа на узел в графе вычислений, сохраняя результаты вычислений (включая значение активации прямого распространения, градиент обратного распространения), а Функция — как ребро в графе вычислений, реализующее вычисление Переменная и вывод новой переменной.

Таким образом, функция и переменная представляют собой автоматический механизм вывода pytorch, который определяет взаимосвязь расчета между каждой переменной.

Примечания. В последнем коде PyTorch Function был изменен на класс Node, который должен лучше представлять концепцию узлов в графе вычислений.

4.1.2 Дальнейшее понимание функции

Мы можем использовать класс autograd.Function для настройки модели, слоя, функции активации и функции потерь, которые легче понять, Фактически, это по сути функция, и это только простая или сложная функция.

4.2 Нормальный режим

Эта часть кода находится в torch/utils/checkpoint.py. Pytorch требует, чтобы пользователь указал контрольную точку, поэтому реализация относительно проста.

4.2.1 Упаковка

В torch/utils/checkpoint.py есть пакет для чекпоинта Этот комментарий очень стоит прочесть Давайте изучим его подробнее.

  • Суть Checkpointing заключается в обмене вычислений на память.

  • Checkpointing НетСохраните все промежуточные активации для всего вычислительного графа, необходимого для обратного вычисления, но вместо этого пересчитайте их в обратном распространении.

  • Во время прямого прохода функция параметра Checkpointing запускается в режиме torch.no_grad, поэтому промежуточные значения активации не рассчитываются. Вместо этого прямой проход сохраняет входной кортеж иfunctionпараметр.

  • При обратном проходе сохраненный ввод иfunctionизъятые,functionбудет вычисляться снова, на этот раз с учетом промежуточных значений активации, затем

    Используйте эти значения активации для вычисления градиентов.

def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model
​
    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.
​
    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.
​
    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.
​
    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`
​
    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
​
    return CheckpointFunction.apply(function, preserve, *args)

4.2.2 Обращение с оборудованием

Поскольку у pytorch нет возможности узнать, будет ли функция прямого прохода перемещать некоторые параметры на разные устройства, для этого требуется некоторая логика для сохранения состояния RNG для этих устройств. Хотя можно сохранить/восстановить все состояния ГСЧ для всех видимых устройств, в большинстве случаев это расточительно, поэтому в качестве компромисса pytorch сохраняет состояния ГСЧ только для всех устройств с тензорными параметрами.

def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
​
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())
​
    return fwd_gpu_devices, fwd_gpu_states
​
​
def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

4.2.3 Основная логика

CheckpointFunction расширяет torch.autograd.Function.

Мы можем расширить функцию для удовлетворения наших собственных потребностей, и расширение должно настроить прямую операцию функции и соответствующую обратную операцию.В то же время в прямом направлении нам нужно сохранить входное значение для обратного.

  • Прямая функция вводит тензор и вычисляет выходной тензор.

    • Во время прямого прохода функция параметра Checkpointing запускается в режиме torch.no_grad, поэтому промежуточные значения активации не рассчитываются.
    • Прямой проход содержит входной кортеж иfunctionпараметр.
    • Для CheckpointFunction по-прежнему необходимо хранить некоторую дополнительную информацию в прямом направлении (то есть информацию о rng, упомянутую выше) для расчета в обратном распространении.
    • Выполните прямой проход, чтобы вернуть значение активации.
  • Обратная функция принимает градиент выходного тензора относительно некоторого скалярного значения и вычисляет градиент относительно входного тензора того же скалярного значения.

    • При обратном проходе сохраненный ввод иfunctionбыл вывезен.
    • functionбудет вычисляться снова, на этот раз отслеживая промежуточные активации, а затем используя эти активации для вычисления градиентов.
"""
我们可以通过建立torch.autograd的子类来实现我们自定义的autograd函数,
并完成张量的正向和反向传播。
"""
class CheckpointFunction(torch.autograd.Function):
​
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """
        在forward函数中,接收包含输入的Tensor并返回包含输出的Tensor。
        ctx是环境变量,用于提供反向传播是需要的信息。我们可以使用上下文对象来缓存对象,以便在反向传播中使用。可通过ctx.save_for_backward方法缓存数据,save_for_backward只能传入Variable或是Tensor的变量。
        """
        check_backward_validity(args)
        # 保存前向传播函数
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            # 存储前向传播时候的设备状态
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
​
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = [] 
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args): # 存储输入数值
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
​
        # `saved_for_backward`是会保留此input的全部信息, 并避免in-place操作导致的input在backward被修改的情况. 它是将函数的输入参数保存起来以便后面在求导时候再使用,起前向反向传播中协调作用。      
        ctx.save_for_backward(*tensor_inputs)
​
        with torch.no_grad():
            outputs = run_function(*args) # 进行前向传播
        return outputs
​
"""
在反向传播中,我们接收到上下文对象和一个张量,
其包含了相对于正向传播过程中产生的输出的损失的梯度。
我们可以从上下文对象中检索缓存的数据,
并且必须计算并返回与正向传播的输入相关的损失的梯度。
"""      
    # 自动求导是根据每个op的backward创建的graph来进行的
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors # 获取前面保存的参数,也可以使用self.saved_variables
​
        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices): # 利用存储的张量重新设置input
            inputs[idx] = tensors[i]
​
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        # 存储目前rng状态,模拟前向传播状态,最后恢复目前状态
        rng_devices = [] 
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state) # 恢复前向传播时候的设备状态
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                # 利用前向传播函数再次计算
                outputs = ctx.run_function(*detached_inputs)
​
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
​
        # run backward() with only tensor that requires grad
        outputs_with_grad = [] # 激活值
        args_with_grad = [] # 梯度
        # 从前向传播计算的结果中筛选需要传播的张量
        for i in range(len(outputs)): 
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
            
        # 开始后向传播    
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)
​
        return (None, None) + grads

4.3 Конвейерный режим

Давайте посмотрим, как конвейерный режим выполняет Checkpoint.

Конвейерный параллельный режим Pytorch вдохновлен GPipe и упоминается в его комментариях.

пройти черезCheckpointFunction, pytorch может объединить пересчет и рекурсивное обратное распространение в функцию автоматического вывода, поэтому, когда прибудет градиент, начнется пересчет. Но в конвейерном режиме, чтобы сократить время простоя графического процессора, пересчет должен произойти до поступления градиента (поскольку пересчет фактически не зависит от градиента, пересчет можно выполнить до поступления градиента, чтобы получить значение активации, и подождать для получения градиента обратного распространения.После этого значения активации агрегируются для выполнения собственного расчета градиента).

Чтобы решить эту проблему, pytorch вводит две функции автоматического вывода: class:Recompute and class:Checkpoint, соответственно, представляют собой пересчет и рекурсивное обратное распространение, которое должно разделить CheckpointFunction в обычном режиме на два этапа, чтобы эти две функции можно было использовать для управления механизмом автоматического деривации и CUDA. В частности, в классе:Recompute and class:CheckpointВставьте синхронизацию CUDA между классами следующим образом:CheckpointОтложить до конца полную копию градиента.

Разделяя сегменты, несколько этапов конвейера могут работать параллельно.

4.3.1 Пример

Сначала мы можем взглянуть на код test/distributed/pipeline/sync/test_checkpoint.py.

Благодаря умной печати журнала мы можем увидеть использование контрольной точки в прямом и обратном распространении во время выполнения.

Конечным результатом временной шкалы является ["a:вперед", "b:вперед", "b:вперед", "b:назад", "a:вперед", "a:назад"],

Два из них соответствуют проходу вперед, Checkpoint(Log[b]), Checkpoint(Log[a]).

@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []
​
    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()
​
        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output
​
    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)
​
    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5
​
    # 这里意味着最后 backward 实际会运行"a:forward", "a:backward"
    a = checkpoint(partial(Log.apply, "a"), a)
​
    a, phony = fork(a)
    b = join(b, phony)
​
    # 这里意味着最后 backward 实际会运行"b:forward", "b:backward"
    b = checkpoint(partial(Log.apply, "b"), b)
​
    c = torch.cat((a, b))
​
    out = c.sum()
​
    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()
​
    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])

4.3.2 Общие переменные

class:Recompute and class:CheckpointВ частности, общие переменные сохраняются через контекст Context.

# Types for shared memory between Checkpoint and Recompute.
​
Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)
​
class Context:
    """The common interface between the :class:`Checkpoint` and
    :class:`Recompute` context.
    """
​
    recomputed: Deque[Recomputed]
    rng_states: Deque[RNGStates]
    function: Function
    input_atomic: bool
​
    saved_tensors: Tuple[Tensor, ...]
​
    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
        pass

4.3.3 rng state

В зависимости от среды выполнения состояние ГСЧ может по-разному влиять на производительность, поэтому необходимо сохранять текущее состояние ГСЧ устройства во время каждой контрольной точки и восстанавливать текущее состояние ГСЧ устройства перед повторным вычислением.

Методы save_rng_states и restore_rng_states используются для доступа к состояниям RNG соответственно.

def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
    """:meth:`Checkpoint.forward` captures the current PyTorch's random number
    generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
​
    .. seealso:: :ref:`Referential Transparency`
​
    """
    cpu_rng_state = torch.get_rng_state()
​
    gpu_rng_state: Optional[Tensor]
    if device.type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state(device)
    else:
        gpu_rng_state = None
​
    rng_states.append((cpu_rng_state, gpu_rng_state))
​
​
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
    """:meth:`Recompute.backward` restores the random number generator states
    captured by :func:`save_rng_states` within its context.
​
    .. seealso:: :ref:`Referential Transparency`
​
    """
    cpu_rng_state, gpu_rng_state = rng_states.pop()
​
    gpu_devices: List[torch.device] = []
    if device.type == "cuda":
        gpu_devices.append(device)
​
    with torch.random.fork_rng(gpu_devices):
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state(gpu_rng_state, device)
        yield
​

4.3.4 Checkpoint

Checkpoint и последующий Recompute должны разделить код контрольной точки в обычном режиме на два этапа (функция forward разделена на два этапа, а back функция также разделена на два этапа), чтобы можно было лучше использовать конвейер.

class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states
​
        # 存RNG状态
        save_rng_states(input[0].device, ctx.rng_states)
​
        ctx.function = function
        ctx.input_atomic = input_atomic
        # 为BP做准备,其实目前没有实现
        ctx.save_for_backward(*input)
​
        # 进行前向计算
        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)
​
        return output
​
    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        # 从保存的重计算变量中弹出所需变量
        output, input_leaf = ctx.recomputed.pop() 
​
        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)
            
        if any(y.requires_grad for y in tensors):
            tensors = tuple([x for x in tensors if x.requires_grad])
            # 进行自动微分
            torch.autograd.backward(tensors, grad_output)
​
        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)

4.3.5 Recompute

Пересчет — это пересчет промежуточных переменных на основе сохраненной информации.

class Recompute(torch.autograd.Function):
  
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states
​
        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)
​
        return phony
​
    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  
        input = ctx.saved_tensors
        input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)
​
        # 取出保存的RNG状态,进行前向计算,得到中间变量
        with restore_rng_states(input[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)
​
        # 保存变量,为Checkpoint使用
        ctx.recomputed.append((output, input_leaf))
​
        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.saved_tensors)
        return tuple(grad_input)

4.3.6 Pipeline

4.3.6.1 Task

Мы начнем с рассмотрения класса Task. Код находится по адресу: torch/distributed/pipeline/sync/worker.py.

Как видно из комментариев, Task используется для расчета микробатча на партиции.

computeМожет выполняться параллельно в рабочих потоках.

finalizeдолжен быть вcomputeВыполняется после завершения.

class Task:
    """A task represents how to compute a micro-batch on a partition.
​
    It consists of two parts: :meth:`compute` and :meth:`finalize`.
    :meth:`compute` should be executed in worker threads concurrently.
    :meth:`finalize` should be executed after when worker threads complete to
    execute :meth:`compute`.
​
    :meth:`compute` might be boosted by worker threads. Because it produces
    several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
    are not serialized through GIL. So more than one CUDA API call can be
    produced at the same time.
​
    """
​
    def __init__(
        self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
    ) -> None:
        self.stream = stream
        self._compute = compute
        self._finalize = finalize
        self._grad_enabled = torch.is_grad_enabled()
​
    def compute(self) -> Batch:
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            return self._compute()
​
    def finalize(self, batch: Batch) -> None:
        if self._finalize is None:
            return
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            self._finalize(batch)
4.3.6.2 compute

Это вычислительная функция класса Pipeline.

Логика Pipeline показана в его комментариях (комментарии PyTorch действительно информативны). Дело в томTask(streams[j], compute=chk.checkpoint, finalize=chk.recompute)Вот как пройти контрольно-пропускной пункт.

Как видите, метод перерасчета будет установлен как метод финализации Задачи, а затем будет запланирован перерасчет.

class Pipeline:
    """The pipeline parallelism for Pipe."""
    
    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop
​
        # Disable checkpointing if in eval mode.
        if not self.partitions[0].training:
            checkpoint_stop = 0
​
        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None
​
        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]
​
            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])
​
            # Determine whether checkpointing or not.
            checkpoint = i < checkpoint_stop
            if checkpoint:
​
                def function(
                    input: TensorOrTensors,
                    partition: nn.Sequential = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(input)
​
                # 这里进行处理
                chk = Checkpointing(function, batch)
                # 分别设置了chk.checkpoint 和 chk.recompute
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk
​
            else:
​
                def compute(
                    batch: Batch = batch,
                    partition: nn.Sequential = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)
​
                task = Task(streams[j], compute=compute, finalize=None)
                del compute
​
            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task) # 将task插入到 pipeline的queue,这样可以并行。
​
        for i, j in schedule: 
            ok, payload = self.out_queues[j].get()
​
            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue
​
            # 取出 task    
            task, batch = cast(Tuple[Task, Batch], payload)
​
            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])
​
            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch) # 计划进行重计算
​
            batches[i] = batch
​
        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
​

0x05 Реализация Gpipe

Когда Gpipe выполняет обратное распространение, он может пересчитать функцию прямого распространения F_k на k-м ускорителе.

5.1 Функция API _Rematerialize

Во-первых, давайте посмотрим на методы API.

В builder.py есть функция _Rematerialize, которую можно использовать для переноса слоя, который необходимо пересчитать.

  def _Rematerialize(self, name, body):
    """Forces rematerialization on FProp of the body layer."""
    return builder_layers.RematerializationLayer.Params().Set(
        name=name, body=body)

5.2 Уровень упаковки RematerializationLayer

RematerializationLayer — это слой-оболочка, который имеет:

FProp заключается в том, чтобы обернуть инкапсулированный слой в функцию Fn, а затем вызвать py_utils.RematerializeFn для передачи Fn вместе с входной переменной.

class RematerializationLayer(base_layer.BaseLayer):
  """A wrapper layer with rematerialization."""
​
  @classmethod
  def Params(cls):
    p = super().Params()
    p.Define('body', None,
             'The main layer whose FProp will be wrapped by RematerializeFn.')
    return p
​
  def __init__(self, params):
    super().__init__(params)
    self.CreateChild('body', self.params.body)
​
  def FProp(self, theta, *xs):
    input_list = theta.body.Flatten() # 得到theta
    theta_len = len(input_list)
    input_list += list(xs) # 得到输入参数
    input_len = len(input_list)
​
    def Fn(*args): # 包装函数,会调用被封装层的 FProp
      body_theta = theta.body.Pack(args[:theta_len])
      return self.body.FProp(body_theta, *args[theta_len:input_len])
​
    return py_utils.RematerializeFn(Fn, *input_list) # 调用,执行FProp,并且做Gradient checking
​
  @classmethod
  def FPropMeta(cls, p, *args): # 就是传播被封装层的信息
    py_utils.CheckShapes(args)
    return p.body.cls.FPropMeta(p.body, *args)
​

3.2.3 функция градиентов тензорного потока

RematerializeFn вызывает функцию градиентов тензорного потока для вычисления градиентов, поэтому нам нужно объяснить.

В тензорном потоке функция градиентов может автоматически вычислять градиент функции. Нам просто нужно спроектировать нашу функцию и вызвать ееtf.gradientsфункция подойдет.

Параметры tf.gradients() следующие, где

  • tf.gradients()выполнитьysправильноxsискать вывод
  • grad_ysтакже является списком, длина которого равнаlen(ys). Смысл этого параметра в том, чтобыxsПроизводный вес каждого элемента в .
tf.gradients(ys, xs, 
             grad_ys=None, 
             name='gradients',
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None,
             stop_gradients=None)

5.4 Функциональная функция RematerializeFn

RematerializeFn — это последняя функция, которая вызывает fn и повторно материализует fn во время обратного распространения.

def RematerializeFn(fn, *xs):
  """Calls fn and rematerializes fn in the backward pass.
​
  `fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.
​
  Args:
    fn: A python function to be rematerialized in the backprop pass.
    *xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
      fn function.
​
  Returns:
    `fn(*xs)`
  """
  initial_step_seed = GetStepSeed()
  final_step_seed = MaybeGenerateSeedFromScope()
​
  def Backward(fwd_xs, fwd_ys, d_fwd_ys):
    """The backward function that rematerializes forward outputs."""
    del fwd_ys # 去掉传入的参数,因为在内部需要用备份的Checkpoint来处理
    always_true = tf.random.uniform([]) < 2.0
    # Alternatively, can do this:
    # tf.where(tf.math.is_nan(x),
    #          tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
    #          x)
    bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs] # 依据Checkpoint来生成 bak_xs
    for dst, src in zip(bak_xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(initial_step_seed)
    ys = fn(*bak_xs) # 依据Checkpoint来重新生成ys
    MaybeResetStepSeed(final_step_seed)
    dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys) # ys 对 bak_xs 求导
    dxs_final = [] # 聚合
    for dx, x in zip(dxs, bak_xs):
      if dx is None:
        dxs_final.append(tf.zeros_like(x))
      else:
        dxs_final.append(dx)
    assert len(dxs_final) == len(bak_xs)
    return NestedMap(
        initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)
​
  ys_shapes = []
​
  # TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
  def Forward(fwd_xs):
    """Forward function plus sanity checks."""
    for dst, src in zip(fwd_xs.xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(fwd_xs.initial_step_seed)
    ys = fn(*fwd_xs.xs) # 正常计算
    # Some sanity check.
    assert not GetExtraInputs()
    assert not GetExtraArgs()
    assert not GetExtraVars()
    if isinstance(ys, tuple):
      for y in ys:
        assert isinstance(y, tf.Tensor)
        ys_shapes.append(y.shape)
    else:
      assert isinstance(ys, tf.Tensor)
      ys_shapes.append(ys.shape)
    return ys
​
  ys = CallDefun(
      Forward,
      NestedMap(initial_step_seed=initial_step_seed, xs=xs),
      bak=Backward)
  if isinstance(ys, tuple):
    for y, s in zip(ys, ys_shapes):
      y.set_shape(s)
  else:
    ys.set_shape(ys_shapes[0])
  # TODO(b/129159299): The ResetStepSeed below is needed to work around this
  # bug, which is a problem with global tensors being shared by different
  # inference graphs. It should be replaced with the new step seed value
  # returned from the Forward function when the bug is fixed.
  MaybeResetStepSeed(final_step_seed)
  return ys

CallDefun определяется следующим образом: он должен инкапсулировать прямой и обратный вызовы. Среди них роль Function заключается в построении графовой функции TensorFlow на основе вызываемого

def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
  """Wraps fwd in a defun with custom gradient bak and calls it with args.
​
  Args:
    fwd: A callable xs: Nested Structure -> ys: Nested Structure.
    args: A Nested Structure of tf.Tensor or None.
    bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
      Structure. The custom backprop function for fwd. bak needs to return
      dcapture if fwd uses any implicitly captured tensors, whose gradients are
      dcapture.
    bak_as_function: Whether to create a TF graph function for bak.
    device: the device on which to run fwd and bak.
​
  Returns:
    A Nested Structure equivalent to what fwd(args) computes.
  """
  if args is not None:
    args = Transform(tf.convert_to_tensor, args)
  sigs = Function(
      fwd_sig=TensorSpecs(args),
      bak=bak,
      bak_as_function=bak_as_function,
      device=device)(
          fwd=fwd)
  if args is None:
    return sigs()
  else:
    return sigs(args)

На этом анализ GPipe завершен, и в следующей статье начнется анализ PipeDream, так что следите за обновлениями.

Кроме того, в отношении PyTorch Pipeline будет специальная серия для последующего анализа.

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

ссылка 0xFF

lingvo framework день чтение заметок

Tensorflow понимает, что градиенты нескольких мини-пакетных вычислений сначала накапливаются, а затем распространяются обратно.

Накопление градиента с помощью tensorflow2

В десять раз время расчета модели увеличилось всего на 20%: плагин для замены градиента с открытым исходным кодом OpenAI

PipeDream: Fast and Efficient Pipeline Parallel DNN Training

Paper Interpretation Series 5: Microsoft Stanford и другие PipeDream быстро обучают крупномасштабные нейронные сети

На данный момент 231 you.GitHub.IO/neural-net…

Блог Woohoo.cn на.com/geek found/afraid/14…

Технология оптимизации видеопамяти во время обучения - слияние ОП и контрольная точка градиента

Pytorch Notes 04 - Пользовательский torch.autograd.Function

Учебное пособие по Autograd для PyTorch

Пользовательское расширение pytorch (3) - простое определение и случай torch.autograd.Function

Пользовательское расширение pytorch (2) - torch.autograd.Function завершает пользовательский слой

torch.autograd интерпретации исходного кода PyTorch: подробное объяснение расчета градиента

Обратное распространение

Перевод примечаний к курсу CS231n: примечания по обратному распространению