0x00 сводка
GPipe — это параллельная библиотека для обучения нейронных сетей, разработанная командой Google Brain и поддерживающая сверхкрупномасштабные модели. В этой статье представлена ее функция пересчета, которую можно проверить с помощью других реализаций.
Другие статьи из этой серии:
0x01 Обзор
1.1 Предыдущий обзор
Как упоминалось выше, существует несколько необходимых параллельных технологий для обучения распределенной модели:
- Параллельный конвейер, особенно как автоматически установить конвейер;
- накопление градиента;
- обратный пересчет;
- стратегия 1F1B (будем использовать анализ PipeDream);
В предыдущей статье мы рассказали, как Gpipe реализует конвейерный параллелизм и накопление градиентов.
Есть проблема с конвейерным параллелизмом: использование памяти слишком велико. Если промежуточный результат (активация) каждого микропакетного прямого расчета потребляется обратным расчетом, необходимо кэшировать n копий (количество раз накопления градиента) полной прямой активации в видеопамяти. В это время необходимо использовать еще один важный метод: пересчет (контрольная точка).
На основе статьи «Обучение глубоких сетей с сублинейной стоимостью памяти» в этой статье анализируется исходный код pytorch и Gpipe, в надежде получить конкретное представление о технологии «градиентных контрольных точек».
1.2 Gradient checkpointing
В 2016 году команда Чена Тяньци предложила такие технологии, как «контрольные точки градиента / активации (обратный пересчет)», связанные с сублинейной оптимизацией памяти, с целью уменьшить использование памяти, вызванное промежуточной активацией во время обучения глубокому обучению. Технология контрольных точек — это своего рода сублинейная оптимизация памяти, в дополнение к другим технологиям, таким как разгрузка ЦП (разгрузка ЦП широко используется в фреймворке Microsoft Deepspeed).
Градиентная контрольная точка — это систематический подход к снижению потребления памяти во время глубокого обучения нейронной сети путем повторного запуска сегментов прямого распространения для каждого сегмента, установленного в качестве контрольной точки при обратном распространении:
- Методы градиентных контрольных точек сосредоточены на уменьшении накладных расходов памяти на хранение промежуточных результатов (карт объектов) и градиентов, поскольку во многих распространенных глубоких сетях промежуточные результаты намного больше по сравнению с параметрами модели.
- Градиентная контрольная точка — это метод, который обменивает время (вычислительную мощность) на пространство (видеопамять), сжимая след модели за счет уменьшения сохраненного значения активации, но значение активации, которое не сохраняется, должно быть пересчитано при расчете градиента, т. е. занимает в два раза больше времени расчета стоимости прямого распространения.
- В частности, это установка некоторых контрольных точек градиента и выпуск промежуточных результатов за пределы контрольных точек.В дальнейшем, если прямой результат будет найден не в видеопамяти в процессе обратного распространения, будет найдена ближайшая контрольная точка градиента, а затем Вычислить, восстанавливая освобожденные тензоры.
0x02 Базовые знания
2.1 Как работает вывод
заимствовано отсюдаТехнология оптимизации видеопамяти во время обучения - слияние ОП и контрольная точка градиентаидеи.
Модели DNN состоят из ряда слоев разных типов (например, сверточные слои, полносвязные слои, объединяющие слои).
Ключом к обратному распространению является «автоматический вывод цепочки», но на самом деле BP также добавляет к этой основе небольшой механизм динамического программирования. Общий BP состоит из следующих двух шагов:
-
прямое проведение. Взяв в качестве примера классификацию изображений, текущие модели сначала делают прогнозы на небольшом наборе обучающих выборок (также известных как мини-пакеты). Этот процесс называется прямой проводимостью.
- Для прогнозирования входные данные из мини-пакета подаются на первый слой модели.
- Затем каждый слой вычисляет функцию на своем входе, чтобы сгенерировать выход для следующего слоя. Прямая проводимость записывает следующие два значения: выходное значение промежуточного узла и градиент выходного значения по отношению к входному значению.
- Результатом последнего слоя является предсказание класса. На основе прогнозируемых меток модели и фактических меток для каждого изображения выходной слой вычисляет потери (или ошибки).
-
Расчет градиента обратного распространения. Обратное распространение — это процесс вычисления градиента конечного выходного значения сети по отношению к выходу этого слоя. То есть, начиная с вывода, распространяя обратно значения градиента,Рассчитайте градиент выходного значения для каждой промежуточной переменной и сохраните его.. Каждый слой вычисляет ошибку предыдущего слоя и обновляет веса (градиенты потерь) для всех связанных слоев, что сдвинет прогнозы модели к желаемому результату.
В процессе возврата градиента необходимо использовать выходное значение узла, но когда вычисление градиента выполняется в обратном распространении, БП не будет выполнять повторные вычисления, причина в том, что промежуточные переменные сохраняются при прямой передаче , то естьВыходное значение каждого промежуточного узла. BP непрерывно выполняет обратное распространение градиентов и сохраняет промежуточные градиенты до тех пор, пока не будут решены все промежуточные значения вычислительного графа и градиенты начальных значений.
Давайте посмотрим, как работает обратная обработка.
Так называемая система автоматического вывода на самом деле является «полуавтоматической»: она не находит напрямую аналитическую форму производной сложной функции, а строит расчетный график и заранее написанные правила вывода для базовых функций, объединенных с цепочкой правило вывода Реализован автоматический вывод.
Возьмем функцию в качестве примера для иллюстрации, ее выражение выглядит следующим образом:
f(x) = x * (x + 1)
Аналитическая формула его градиента, полученная простым математическим выводом:f'(x) = x + 1 + x
; Сначала отложите этот результат и посмотрите, как схема автоматического вывода шаг за шагом находит этот результат, и нарисуйте диаграмму расчета следующим образом:
+---------+
| |
+------>+ x + 1 +----+
| | | | 3
2 | +---------+ |
| |
| v
+-----+--+ ++------+
| | | |
+------> | x +----------------> | + +---------->
| | 1 | |
+--------+ +-------+
На вычислительном графе обратное распространение сначала проходит через операцию умножения в соответствии с приведенными выше правилами вывода:
- Градиент на пути 1 равен
x + 1
; - Градиент на пути 3
x
; - Путь 3 распространяется обратно через путь 2, за исключением того, что его градиент
x + 1
Кроме того, умножьте градиент пути 21
, - Путь 2 и Путь 1 сходятся вместе, поэтому окончательный градиент
x + 1(路径1)+ 1 * x(路径2)= x + 1 + x
, что в точности равно результату, рассчитанному по математической формуле;
Платформа автоматической деривации опирается на эти основные правила и правила деривации цепочки, чтобы работать эффективно и точно.
В процессе обучения подавляющего большинства нейронных сетей некоторые промежуточные переменные, полученные при прямом проходе, очень полезны (для удобства вывода) при расчете обратного прохода. На практике лучше всего запрограммировать кэширование этих промежуточных переменных, чтобы их можно было использовать во время обратного распространения. Поэтому основную часть занимаемой памяти занимает промежуточный результат, представляющий собой так называемую «карту признаков». В этой статье x — это промежуточный результат (карта объектов), выдаваемый предыдущим слоем.
При применении правила вывода умножения мы должны заранее сохранить промежуточные результаты x и x+1.. Обратите внимание, что правила умножения и его вывода, определенные фреймворком, являются общими.Левая и правая части умножения могут быть двумя несвязанными значениями, поэтому они должны сохраняться одновременно. То есть x + 1 может быть ( x + y ) + z .... в других функциях и может содержать другие входные переменные, поэтому его нельзя вычислить из входного x по простой формуле, например + 1 .
Без учета оптимизации самого фреймворка использование памяти включает в себя x и x + 1. Обратите внимание, что x — это не отдельное значение, а что-то вроде32x32x128
Карты объектов такого размера.
2.2 Контрольная точка градиента
Как упоминалось в предыдущем разделе, в оригинальном виде нейронных сетей:
- В прямой функции значения функции активации каждого слоя необходимо сохранять после расчета, потому что они должны быть использованы при расчете обратного распространения.
- При движении назад градиент рассчитывается на основе значения функции потерь и соответствующего значения функции активации слоя.
- Следовательно, нам нужно кэшировать n копий (количество накоплений градиента) полных прямых активаций в видеопамяти. То есть в этом случае использование памяти пропорционально количеству слоев.
Поэтому в настоящее время существует проблема с конвейерным параллелизмом: использование памяти слишком велико.
Можно ли не хранить значение активации? Например, в обратном направлении вы можете перемотать вперед, когда вам нужно активировать значение функции.
Если мы не сохраним ни одного из них, мы все пересчитаем через forward? Тогда это слишком трудоемко в больших моделях. Следовательно, мы можем выбрать компромиссный метод, например, сохранить только значения функции активации некоторых слоев. Когда в обратном направлении требуется значение функции активации, просто возьмите самое последнее значение активации. Так была внедрена важная технология: Checkpointing.
2.3 Содержание статьи
2.3.1 Основные документы
Основная идея Gpipe's Checkpointing исходит из следующих двух статей:
- Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of check- pointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1):19–45, 2000.
- Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.
Основная идея заключается в обмене вычислительной мощности на память (вычисления на видеопамять, а промежуточные результаты, необходимые для обратного вывода, пересчитываются из контрольной точки), а пропускную способность на видеопамять.
2.3.2 Бумажные обучающие глубокие сети с сублинейной стоимостью памяти
2.3.2.1 Основная идея
Мы в основном смотрим на эту бумагу.
Контрольные точки упоминаются в статье Чена Тяньци «Обучение глубоких сетей с сублинейной стоимостью памяти», опубликованной в 2016 году, также известной как сублинейная оптимизация памяти. Есть две идеи для сублинейной оптимизации памяти: контрольные точки и разгрузка ЦП:
- Основная идея контрольной точки состоит в том, чтобы пометить небольшое количество тензоров (тензоров, которые являются контрольными точками) в прямой сети, и только эти отмеченные тензоры будут сохранены в прямом расчете.Тензор временно пересчитывается один раз вперед, чтобы получить его. Таким образом, большое количество активаций не нужно сохранять до обратного вычисления, что эффективно сокращает жизненный цикл большого количества тензоров и значительно повышает эффективность повторного использования памяти.
- Идея разгрузки ЦП аналогична технологии «виртуальной памяти» в компьютерных операционных системах (временная замена памяти, которая обычно не используется, на диск, тем самым увеличивая общий объем памяти).При глубоком обучении память графического процессора (Память устройства) Характеристики дорогие, высокая скорость и небольшая емкость, в то время как основная память ЦП (Host Memory) характеризуется дешевизной, относительно низкой скоростью и большой емкостью; затем временно замените некоторые временно неиспользуемые активации в прямом расчете на Основная память ЦП, а затем подкачать ее в видеопамять ГП, когда требуется обратный расчет, что также может сэкономить видеопамять.
Две сублинейные оптимизации памяти достигают оптимизации видеопамяти разными способами: контрольные точки обменивают видеопамять с дополнительными вычислительными затратами, а разгрузка ЦП обменивает видеопамять с дополнительными затратами на передачу.
2.3.2.2 Оптимизация контрольных точек
На рисунке выше показано сравнение графика расчета до и после Checkpointing.
Серый слева — конфигурация сети.
Средний график нормального градиента представляет собой процесс распространения в прямом и обратном направлении нормальной сети.
Градиентный график, оптимизированный для памяти, справа является результатом применения градиента-контрольной точки. Для дальнейшего сокращения памяти некоторые промежуточные результаты удаляются и при необходимости восстанавливаются из дополнительных вычислений вперед.
- Во-первых, нейронная сеть разделена на несколько частей (три сегмента на рисунке справа), алгоритм запоминает только вывод каждого сегмента, и в каждом сегменте удаляет все промежуточные результаты.
- Во-вторых, на этапе обратного распространения мы можем повторно вычислить отброшенные промежуточные результаты, пробежав вперед от самого последнего записанного результата.
- Таким образом, мы оплачиваем только стоимость памяти для хранения выходных данных каждого сегмента плюс максимальную стоимость памяти для обратного распространения по каждому сегменту.
такgradient-checkpoint
этоДело не в том, что промежуточные результаты не нужны, но есть способ вычислить промежуточные результаты, которые были отброшены ранее, в режиме реального времени в процессе вывода..
Пересчет не предназначен только для конвейерного параллелизма, и раньше он в основном использовался в сценариях с одной картой или параллельными данными. Но эта оптимизация очень критична при конвейерном параллелизме, потому что она делает ненужным кэширование всех активаций в прямом направлении, а требует кэширования только очень небольшого числа (например, только одного слоя Transformer Layer), конкретного тензора, который проверено, это значительно экономит накладные расходы памяти при конвейерном параллелизме.
0x03 OpenAI
Представлено на OpenAIgradient-checkpoint
это тезисTraining Deep Nets with Sublinear Memory CostРеализация идеи, т.к. ее документация относительно полная (GitHub.com/openlove/grad…), мы можем извлечь из этого уроки.
Общая идея состоит в том, чтобы установить несколько контрольных точек в середине нейронной сети и сохранять контрольную точку каждый sqrt (n) для карты промежуточных результатов. Все промежуточные результаты, кроме контрольных, отбрасываются, а когда для обратного распространения требуется промежуточный результат, расчет начинается с ближайшей контрольной точки, что не только экономит видеопамять, но и позволяет избежать утомительного процесса расчета с нуля.
3.1 Расчетный граф
Для простой n-слойной нейронной сети с прямой связью вычислительный граф для получения градиента выглядит следующим образом:
детали следующим образом:
- Иерархические активации нейронной сети соответствуют узлам, помеченным f, и при прямом распространении все эти узлы необходимо вычислять последовательно.
- Функция потерь использует метки b-узлов для активаций и градиентов этих параметров иерархии, и во время обратного распространения все эти узлы необходимо вычислять в обратном порядке.
- Вычисление активаций f-узлов является необходимым условием для дальнейшего вычисления градиентов b-узлов, поэтому f-узлы остаются в памяти после прямого распространения.
- Эти активации удаляются из памяти только тогда, когда обратное распространение выполняется достаточно далеко, чтобы вычисление соответствующих градиентов больше не требовало использования активаций на более поздних уровнях или дочерних элементах f . Это означает, что простое обратное распространение требует памяти, которая линейно растет с количеством слоев в нейронной сети.
3.2 Пересчет
Простое обратное распространение уже оптимально с вычислительной точки зрения, потому что каждый узел нужно вычислить только один раз. Однако если мы захотим пересчитать узлы, то сможем сэкономить много памяти. Когда нам нужно значение активации узла, мы можем просто повторно вычислить значение активации узла с прямым распространением. Мы можем выполнять вычисления последовательно, пока не вычислим узлы, которые необходимо распространить обратно со значениями активации.
Использование этой стратегии требует, чтобы память для вычисления градиентов была стабильной на количестве слоев нейронной сети n, а n было оптимальным с точки зрения памяти. Обратите внимание, однако, что количество вычислений для узлов теперь увеличено на n^2 по сравнению с предыдущим n. Каждый из n узлов пересчитывается n раз. Следовательно, вычислительный граф становится очень медленным для вычисления глубоких сетей, что делает этот подход непригодным для глубокого обучения.
3.3 Стратегия
Чтобы соблюсти баланс между памятью и вычислением, нам нужна политика, позволяющая пересчитывать узлы, но это происходит не очень часто. Наша стратегия здесь состоит в том, чтобы использовать нейронную сеть для активации подмножества, помеченного как узел. Фиолетовый узел представляет собой заданный период времени, который необходимо хранить в памяти.
Эти узлы контрольных точек остаются в памяти после прямого распространения, в то время как оставшиеся узлы пересчитываются не более одного раза. После пересчета узлы без контрольных точек остаются в памяти до тех пор, пока они больше не понадобятся для выполнения обратного распространения. Для простой нейронной сети с прямой связью узлы активации всех нейронов являются либо точками соединения, либо точками разделения графа, определяемого прямым распространением. Это означает, что нам нужно пересчитать только узел между узлом b и последней контрольной точкой в процессе обратного распространения.Когда обратное распространение достигает узла контрольной точки, который мы сохранили, тогда все узлы, пересчитанные из этого узла, находятся в памяти, и их можно удалить.
3.4 Процесс
Сначала мы устанавливаем две контрольные точки, две фиолетовые слева от первой строки на рисунке, обратите внимание, что первая фиолетовая справа — это вход.
Во-вторых, прямое распространение завершено, а обратное распространение запущено, то есть обратное распространение начинается с фиолетовой цифры 1 в нижнем ряду.
В-третьих, пришел к фиолетовому числу 2 в следующей строке, которое зависит от фиолетового числа 3 выше для расчета (напомним, что вычисление обратного распространения требует вывода прямого вычисления), это фиолетовое число 3 является контрольной точкой, которая существует в память, поэтому выполняйте обратное распространение нормально
В-четвертых, я пришел к белой цифре 4 в следующей строке.Это зависит от фиолетовой цифры 5 выше для расчета.Цифра 5 не является контрольной точкой и не находится в памяти.Нужно перезапустить расчет с контрольной точки перед ней, что есть, начните с фиолетового числа 7. . Вычисляется новая контрольная точка, а изначальную фиолетовую цифру 5 в строке выше можно удалить, потому что она больше не нужна.
В-пятых, вычислите новое фиолетовое число 4 ниже, тем самым продолжив вычисление в обратном направлении.
Из-за того, что он включает автоматическую генерацию контрольных точек, код OpenAI относительно неясен, поэтому здесь он не будет анализироваться, если вам интересно, вы можете изучить его самостоятельно.
0x04 Реализация Pytorch
Давайте теперь посмотрим на Пёрха.
4.1 Базовые знания
4.1.1 Variable & Function
В PyTorch autograd является основным содержимым всех нейронных сетей, предоставляя методы автоматического вывода для всех операций Tensor. Это среда определения-за-запуском, что означает, что обратное распространение определяется тем, как выполняется код.
autograd.Variable — это autograd в основных классах. Обертка тензора и поддерживает почти все операции, определенные на нем. После того, как вы завершили свою операцию, вы можете вызвать .backward() для автоматического расчета всех градиентов.
Еще один класс, который очень важен для реализации autograd, — функция, функция — это просто операция над переменной, такая как сложение, вычитание, умножение и деление, relu, pool и т. д. Но это не простая арифметика. В отличие от обычных операций Python или numpy, функция предназначена для вычислительных графиков и должна вычислять градиент обратного распространения. Следовательно, ему нужно не только выполнять эту операцию (прямой процесс), но также нужно использовать кеш, чтобы сохранять входные данные прямого распространения (для вычисления градиента) и поддерживать обратное распространение для вычисления градиента.
Pytorch использует переменные и функции для построения вычислительных графиков. Оглядываясь назад на Переменную, Переменная похожа на узел в графе вычислений, сохраняя результаты вычислений (включая значение активации прямого распространения, градиент обратного распространения), а Функция — как ребро в графе вычислений, реализующее вычисление Переменная и вывод новой переменной.
Таким образом, функция и переменная представляют собой автоматический механизм вывода pytorch, который определяет взаимосвязь расчета между каждой переменной.
Примечания. В последнем коде PyTorch Function был изменен на класс Node, который должен лучше представлять концепцию узлов в графе вычислений.
4.1.2 Дальнейшее понимание функции
Мы можем использовать класс autograd.Function для настройки модели, слоя, функции активации и функции потерь, которые легче понять, Фактически, это по сути функция, и это только простая или сложная функция.
4.2 Нормальный режим
Эта часть кода находится в torch/utils/checkpoint.py. Pytorch требует, чтобы пользователь указал контрольную точку, поэтому реализация относительно проста.
4.2.1 Упаковка
В torch/utils/checkpoint.py есть пакет для чекпоинта Этот комментарий очень стоит прочесть Давайте изучим его подробнее.
-
Суть Checkpointing заключается в обмене вычислений на память.
-
Checkpointing НетСохраните все промежуточные активации для всего вычислительного графа, необходимого для обратного вычисления, но вместо этого пересчитайте их в обратном распространении.
-
Во время прямого прохода функция параметра Checkpointing запускается в режиме torch.no_grad, поэтому промежуточные значения активации не рассчитываются. Вместо этого прямой проход сохраняет входной кортеж и
function
параметр. -
При обратном проходе сохраненный ввод и
function
изъятые,function
будет вычисляться снова, на этот раз с учетом промежуточных значений активации, затемИспользуйте эти значения активации для вычисления градиентов.
def checkpoint(function, *args, **kwargs):
r"""Checkpoint a model or part of the model
Checkpointing works by trading compute for memory. Rather than storing all
intermediate activations of the entire computation graph for computing
backward, the checkpointed part does **not** save intermediate activations,
and instead recomputes them in backward pass. It can be applied on any part
of a model.
Specifically, in the forward pass, :attr:`function` will run in
:func:`torch.no_grad` manner, i.e., not storing the intermediate
activations. Instead, the forward pass saves the inputs tuple and the
:attr:`function` parameter. In the backwards pass, the saved inputs and
:attr:`function` is retrieved, and the forward pass is computed on
:attr:`function` again, now tracking the intermediate activations, and then
the gradients are calculated using these activation values.
The output of :attr:`function` can contain non-Tensor values and gradient
recording is only performed for the Tensor values. Note that if the output
consists of nested structures (ex: custom objects, lists, dicts etc.)
consisting of Tensors, these Tensors nested in custom structures will not
be considered as part of autograd.
Args:
function: describes what to run in the forward pass of the model or
part of the model. It should also know how to handle the inputs
passed as the tuple. For example, in LSTM, if user passes
``(activation, hidden)``, :attr:`function` should correctly use the
first input as ``activation`` and the second input as ``hidden``
preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
args: tuple containing inputs to the :attr:`function`
Returns:
Output of running :attr:`function` on :attr:`*args`
"""
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
return CheckpointFunction.apply(function, preserve, *args)
4.2.2 Обращение с оборудованием
Поскольку у pytorch нет возможности узнать, будет ли функция прямого прохода перемещать некоторые параметры на разные устройства, для этого требуется некоторая логика для сохранения состояния RNG для этих устройств. Хотя можно сохранить/восстановить все состояния ГСЧ для всех видимых устройств, в большинстве случаев это расточительно, поэтому в качестве компромисса pytorch сохраняет состояния ГСЧ только для всех устройств с тензорными параметрами.
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
# This will not error out if "arg" is a CPU tensor or a non-tensor type because
# the conditionals short-circuit.
fwd_gpu_devices = list(set(arg.get_device() for arg in args
if isinstance(arg, torch.Tensor) and arg.is_cuda))
fwd_gpu_states = []
for device in fwd_gpu_devices:
with torch.cuda.device(device):
fwd_gpu_states.append(torch.cuda.get_rng_state())
return fwd_gpu_devices, fwd_gpu_states
def set_device_states(devices, states) -> None:
for device, state in zip(devices, states):
with torch.cuda.device(device):
torch.cuda.set_rng_state(state)
4.2.3 Основная логика
CheckpointFunction расширяет torch.autograd.Function.
Мы можем расширить функцию для удовлетворения наших собственных потребностей, и расширение должно настроить прямую операцию функции и соответствующую обратную операцию.В то же время в прямом направлении нам нужно сохранить входное значение для обратного.
-
Прямая функция вводит тензор и вычисляет выходной тензор.
- Во время прямого прохода функция параметра Checkpointing запускается в режиме torch.no_grad, поэтому промежуточные значения активации не рассчитываются.
- Прямой проход содержит входной кортеж и
function
параметр. - Для CheckpointFunction по-прежнему необходимо хранить некоторую дополнительную информацию в прямом направлении (то есть информацию о rng, упомянутую выше) для расчета в обратном распространении.
- Выполните прямой проход, чтобы вернуть значение активации.
-
Обратная функция принимает градиент выходного тензора относительно некоторого скалярного значения и вычисляет градиент относительно входного тензора того же скалярного значения.
- При обратном проходе сохраненный ввод и
function
был вывезен. -
function
будет вычисляться снова, на этот раз отслеживая промежуточные активации, а затем используя эти активации для вычисления градиентов.
- При обратном проходе сохраненный ввод и
"""
我们可以通过建立torch.autograd的子类来实现我们自定义的autograd函数,
并完成张量的正向和反向传播。
"""
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
"""
在forward函数中,接收包含输入的Tensor并返回包含输出的Tensor。
ctx是环境变量,用于提供反向传播是需要的信息。我们可以使用上下文对象来缓存对象,以便在反向传播中使用。可通过ctx.save_for_backward方法缓存数据,save_for_backward只能传入Variable或是Tensor的变量。
"""
check_backward_validity(args)
# 保存前向传播函数
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
# 存储前向传播时候的设备状态
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args): # 存储输入数值
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
# `saved_for_backward`是会保留此input的全部信息, 并避免in-place操作导致的input在backward被修改的情况. 它是将函数的输入参数保存起来以便后面在求导时候再使用,起前向反向传播中协调作用。
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args) # 进行前向传播
return outputs
"""
在反向传播中,我们接收到上下文对象和一个张量,
其包含了相对于正向传播过程中产生的输出的损失的梯度。
我们可以从上下文对象中检索缓存的数据,
并且必须计算并返回与正向传播的输入相关的损失的梯度。
"""
# 自动求导是根据每个op的backward创建的graph来进行的
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument.")
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors # 获取前面保存的参数,也可以使用self.saved_variables
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices): # 利用存储的张量重新设置input
inputs[idx] = tensors[i]
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
# 存储目前rng状态,模拟前向传播状态,最后恢复目前状态
rng_devices = []
if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
rng_devices = ctx.fwd_gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state) # 恢复前向传播时候的设备状态
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
# 利用前向传播函数再次计算
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# run backward() with only tensor that requires grad
outputs_with_grad = [] # 激活值
args_with_grad = [] # 梯度
# 从前向传播计算的结果中筛选需要传播的张量
for i in range(len(outputs)):
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True,"
" this checkpoint() is not necessary")
# 开始后向传播
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)
return (None, None) + grads
4.3 Конвейерный режим
Давайте посмотрим, как конвейерный режим выполняет Checkpoint.
Конвейерный параллельный режим Pytorch вдохновлен GPipe и упоминается в его комментариях.
пройти черезCheckpointFunction
, pytorch может объединить пересчет и рекурсивное обратное распространение в функцию автоматического вывода, поэтому, когда прибудет градиент, начнется пересчет. Но в конвейерном режиме, чтобы сократить время простоя графического процессора, пересчет должен произойти до поступления градиента (поскольку пересчет фактически не зависит от градиента, пересчет можно выполнить до поступления градиента, чтобы получить значение активации, и подождать для получения градиента обратного распространения.После этого значения активации агрегируются для выполнения собственного расчета градиента).
Чтобы решить эту проблему, pytorch вводит две функции автоматического вывода: class:Recompute
and class:Checkpoint
, соответственно, представляют собой пересчет и рекурсивное обратное распространение, которое должно разделить CheckpointFunction в обычном режиме на два этапа, чтобы эти две функции можно было использовать для управления механизмом автоматического деривации и CUDA. В частности, в классе:Recompute
and class:Checkpoint
Вставьте синхронизацию CUDA между классами следующим образом:Checkpoint
Отложить до конца полную копию градиента.
Разделяя сегменты, несколько этапов конвейера могут работать параллельно.
4.3.1 Пример
Сначала мы можем взглянуть на код test/distributed/pipeline/sync/test_checkpoint.py.
Благодаря умной печати журнала мы можем увидеть использование контрольной точки в прямом и обратном распространении во время выполнения.
Конечным результатом временной шкалы является ["a:вперед", "b:вперед", "b:вперед", "b:назад", "a:вперед", "a:назад"],
Два из них соответствуют проходу вперед, Checkpoint(Log[b]), Checkpoint(Log[a]).
@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
# Copied from https://github.com/pytorch/pytorch/pull/18568.
timeline = []
class Log(torch.autograd.Function):
@staticmethod
def forward(ctx, name, x):
ctx.name = name
timeline.append(f"{name}:forward")
return x.detach()
@staticmethod
def backward(ctx, grad_output):
name = ctx.name
timeline.append(f"{name}:backward")
return None, grad_output
a = torch.rand(1, device=device, requires_grad=True)
b = torch.rand(1, device=device, requires_grad=True)
# Increase the next function sequence number.
_ = a + 1 + 2 + 3 + 4 + 5
# 这里意味着最后 backward 实际会运行"a:forward", "a:backward"
a = checkpoint(partial(Log.apply, "a"), a)
a, phony = fork(a)
b = join(b, phony)
# 这里意味着最后 backward 实际会运行"b:forward", "b:backward"
b = checkpoint(partial(Log.apply, "b"), b)
c = torch.cat((a, b))
out = c.sum()
# +--> {a} --Checkpoint(Log)--> {a}
# {out} --Sum--> {c} --Cat ^-----------------------------+
# +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
out.backward()
assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
# |----------------------| |-----------------------| |-----------------------|
# forward pass Checkpoint(Log[b]) Checkpoint(Log[a])
4.3.2 Общие переменные
class:Recompute
and class:Checkpoint
В частности, общие переменные сохраняются через контекст Context.
# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state)
class Context:
"""The common interface between the :class:`Checkpoint` and
:class:`Recompute` context.
"""
recomputed: Deque[Recomputed]
rng_states: Deque[RNGStates]
function: Function
input_atomic: bool
saved_tensors: Tuple[Tensor, ...]
def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover
pass
4.3.3 rng state
В зависимости от среды выполнения состояние ГСЧ может по-разному влиять на производительность, поэтому необходимо сохранять текущее состояние ГСЧ устройства во время каждой контрольной точки и восстанавливать текущее состояние ГСЧ устройства перед повторным вычислением.
Методы save_rng_states и restore_rng_states используются для доступа к состояниям RNG соответственно.
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state = torch.get_rng_state()
gpu_rng_state: Optional[Tensor]
if device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state(device)
else:
gpu_rng_state = None
rng_states.append((cpu_rng_state, gpu_rng_state))
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
""":meth:`Recompute.backward` restores the random number generator states
captured by :func:`save_rng_states` within its context.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
gpu_devices: List[torch.device] = []
if device.type == "cuda":
gpu_devices.append(device)
with torch.random.fork_rng(gpu_devices):
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
torch.cuda.set_rng_state(gpu_rng_state, device)
yield
4.3.4 Checkpoint
Checkpoint и последующий Recompute должны разделить код контрольной точки в обычном режиме на два этапа (функция forward разделена на два этапа, а back функция также разделена на два этапа), чтобы можно было лучше использовать конвейер.
class Checkpoint(torch.autograd.Function):
@staticmethod
# type: ignore[override]
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> TensorOrTensors:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
# 存RNG状态
save_rng_states(input[0].device, ctx.rng_states)
ctx.function = function
ctx.input_atomic = input_atomic
# 为BP做准备,其实目前没有实现
ctx.save_for_backward(*input)
# 进行前向计算
with torch.no_grad(), enable_checkpointing():
output = function(input[0] if input_atomic else input)
return output
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
# 从保存的重计算变量中弹出所需变量
output, input_leaf = ctx.recomputed.pop()
if isinstance(output, tuple):
tensors = output
else:
tensors = (output,)
if any(y.requires_grad for y in tensors):
tensors = tuple([x for x in tensors if x.requires_grad])
# 进行自动微分
torch.autograd.backward(tensors, grad_output)
grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
grad_input.extend(x.grad for x in input_leaf)
return tuple(grad_input)
4.3.5 Recompute
Пересчет — это пересчет промежуточных переменных на основе сохраненной информации.
class Recompute(torch.autograd.Function):
@staticmethod
# type: ignore[override]
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> Tensor:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
ctx.function = function
ctx.input_atomic = input_atomic
ctx.save_for_backward(*input)
return phony
@staticmethod
def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:
input = ctx.saved_tensors
input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)
# 取出保存的RNG状态,进行前向计算,得到中间变量
with restore_rng_states(input[0].device, ctx.rng_states):
with torch.enable_grad(), enable_recomputing():
output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)
# 保存变量,为Checkpoint使用
ctx.recomputed.append((output, input_leaf))
grad_input: List[None] = [None, None, None, None, None]
grad_input.extend(None for _ in ctx.saved_tensors)
return tuple(grad_input)
4.3.6 Pipeline
4.3.6.1 Task
Мы начнем с рассмотрения класса Task. Код находится по адресу: torch/distributed/pipeline/sync/worker.py.
Как видно из комментариев, Task используется для расчета микробатча на партиции.
compute
Может выполняться параллельно в рабочих потоках.
finalize
должен быть вcompute
Выполняется после завершения.
class Task:
"""A task represents how to compute a micro-batch on a partition.
It consists of two parts: :meth:`compute` and :meth:`finalize`.
:meth:`compute` should be executed in worker threads concurrently.
:meth:`finalize` should be executed after when worker threads complete to
execute :meth:`compute`.
:meth:`compute` might be boosted by worker threads. Because it produces
several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
are not serialized through GIL. So more than one CUDA API call can be
produced at the same time.
"""
def __init__(
self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
) -> None:
self.stream = stream
self._compute = compute
self._finalize = finalize
self._grad_enabled = torch.is_grad_enabled()
def compute(self) -> Batch:
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
return self._compute()
def finalize(self, batch: Batch) -> None:
if self._finalize is None:
return
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
self._finalize(batch)
4.3.6.2 compute
Это вычислительная функция класса Pipeline.
Логика Pipeline показана в его комментариях (комментарии PyTorch действительно информативны). Дело в томTask(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
Вот как пройти контрольно-пропускной пункт.
Как видите, метод перерасчета будет установлен как метод финализации Задачи, а затем будет запланирован перерасчет.
class Pipeline:
"""The pipeline parallelism for Pipe."""
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
devices = self.devices
copy_streams = self.copy_streams
checkpoint_stop = self.checkpoint_stop
# Disable checkpointing if in eval mode.
if not self.partitions[0].training:
checkpoint_stop = 0
n = len(partitions)
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
if j != 0:
_wait(batch, copy_streams[j][i], streams[j])
# Determine whether checkpointing or not.
checkpoint = i < checkpoint_stop
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
# 这里进行处理
chk = Checkpointing(function, batch)
# 分别设置了chk.checkpoint 和 chk.recompute
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(streams[j], compute=compute, finalize=None)
del compute
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task) # 将task插入到 pipeline的queue,这样可以并行。
for i, j in schedule:
ok, payload = self.out_queues[j].get()
# Hold the first exception.
if exc_info is not None:
continue
elif not ok:
exc_info = cast(ExcInfo, payload)
continue
# 取出 task
task, batch = cast(Tuple[Task, Batch], payload)
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if j != n - 1:
_wait(batch, streams[j], copy_streams[j][i])
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
with use_device(devices[j]):
task.finalize(batch) # 计划进行重计算
batches[i] = batch
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
0x05 Реализация Gpipe
Когда Gpipe выполняет обратное распространение, он может пересчитать функцию прямого распространения F_k на k-м ускорителе.
5.1 Функция API _Rematerialize
Во-первых, давайте посмотрим на методы API.
В builder.py есть функция _Rematerialize, которую можно использовать для переноса слоя, который необходимо пересчитать.
def _Rematerialize(self, name, body):
"""Forces rematerialization on FProp of the body layer."""
return builder_layers.RematerializationLayer.Params().Set(
name=name, body=body)
5.2 Уровень упаковки RematerializationLayer
RematerializationLayer — это слой-оболочка, который имеет:
FProp заключается в том, чтобы обернуть инкапсулированный слой в функцию Fn, а затем вызвать py_utils.RematerializeFn для передачи Fn вместе с входной переменной.
class RematerializationLayer(base_layer.BaseLayer):
"""A wrapper layer with rematerialization."""
@classmethod
def Params(cls):
p = super().Params()
p.Define('body', None,
'The main layer whose FProp will be wrapped by RematerializeFn.')
return p
def __init__(self, params):
super().__init__(params)
self.CreateChild('body', self.params.body)
def FProp(self, theta, *xs):
input_list = theta.body.Flatten() # 得到theta
theta_len = len(input_list)
input_list += list(xs) # 得到输入参数
input_len = len(input_list)
def Fn(*args): # 包装函数,会调用被封装层的 FProp
body_theta = theta.body.Pack(args[:theta_len])
return self.body.FProp(body_theta, *args[theta_len:input_len])
return py_utils.RematerializeFn(Fn, *input_list) # 调用,执行FProp,并且做Gradient checking
@classmethod
def FPropMeta(cls, p, *args): # 就是传播被封装层的信息
py_utils.CheckShapes(args)
return p.body.cls.FPropMeta(p.body, *args)
3.2.3 функция градиентов тензорного потока
RematerializeFn вызывает функцию градиентов тензорного потока для вычисления градиентов, поэтому нам нужно объяснить.
В тензорном потоке функция градиентов может автоматически вычислять градиент функции. Нам просто нужно спроектировать нашу функцию и вызвать ееtf.gradients
функция подойдет.
Параметры tf.gradients() следующие, где
-
tf.gradients()
выполнитьys
правильноxs
искать вывод -
grad_ys
также является списком, длина которого равнаlen(ys)
. Смысл этого параметра в том, чтобыxs
Производный вес каждого элемента в .
tf.gradients(ys, xs,
grad_ys=None,
name='gradients',
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None)
5.4 Функциональная функция RematerializeFn
RematerializeFn — это последняя функция, которая вызывает fn и повторно материализует fn во время обратного распространения.
def RematerializeFn(fn, *xs):
"""Calls fn and rematerializes fn in the backward pass.
`fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.
Args:
fn: A python function to be rematerialized in the backprop pass.
*xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
fn function.
Returns:
`fn(*xs)`
"""
initial_step_seed = GetStepSeed()
final_step_seed = MaybeGenerateSeedFromScope()
def Backward(fwd_xs, fwd_ys, d_fwd_ys):
"""The backward function that rematerializes forward outputs."""
del fwd_ys # 去掉传入的参数,因为在内部需要用备份的Checkpoint来处理
always_true = tf.random.uniform([]) < 2.0
# Alternatively, can do this:
# tf.where(tf.math.is_nan(x),
# tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
# x)
bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs] # 依据Checkpoint来生成 bak_xs
for dst, src in zip(bak_xs, xs):
dst.set_shape(src.shape)
ResetStepSeed(initial_step_seed)
ys = fn(*bak_xs) # 依据Checkpoint来重新生成ys
MaybeResetStepSeed(final_step_seed)
dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys) # ys 对 bak_xs 求导
dxs_final = [] # 聚合
for dx, x in zip(dxs, bak_xs):
if dx is None:
dxs_final.append(tf.zeros_like(x))
else:
dxs_final.append(dx)
assert len(dxs_final) == len(bak_xs)
return NestedMap(
initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)
ys_shapes = []
# TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
def Forward(fwd_xs):
"""Forward function plus sanity checks."""
for dst, src in zip(fwd_xs.xs, xs):
dst.set_shape(src.shape)
ResetStepSeed(fwd_xs.initial_step_seed)
ys = fn(*fwd_xs.xs) # 正常计算
# Some sanity check.
assert not GetExtraInputs()
assert not GetExtraArgs()
assert not GetExtraVars()
if isinstance(ys, tuple):
for y in ys:
assert isinstance(y, tf.Tensor)
ys_shapes.append(y.shape)
else:
assert isinstance(ys, tf.Tensor)
ys_shapes.append(ys.shape)
return ys
ys = CallDefun(
Forward,
NestedMap(initial_step_seed=initial_step_seed, xs=xs),
bak=Backward)
if isinstance(ys, tuple):
for y, s in zip(ys, ys_shapes):
y.set_shape(s)
else:
ys.set_shape(ys_shapes[0])
# TODO(b/129159299): The ResetStepSeed below is needed to work around this
# bug, which is a problem with global tensors being shared by different
# inference graphs. It should be replaced with the new step seed value
# returned from the Forward function when the bug is fixed.
MaybeResetStepSeed(final_step_seed)
return ys
CallDefun определяется следующим образом: он должен инкапсулировать прямой и обратный вызовы. Среди них роль Function заключается в построении графовой функции TensorFlow на основе вызываемого
def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
"""Wraps fwd in a defun with custom gradient bak and calls it with args.
Args:
fwd: A callable xs: Nested Structure -> ys: Nested Structure.
args: A Nested Structure of tf.Tensor or None.
bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
Structure. The custom backprop function for fwd. bak needs to return
dcapture if fwd uses any implicitly captured tensors, whose gradients are
dcapture.
bak_as_function: Whether to create a TF graph function for bak.
device: the device on which to run fwd and bak.
Returns:
A Nested Structure equivalent to what fwd(args) computes.
"""
if args is not None:
args = Transform(tf.convert_to_tensor, args)
sigs = Function(
fwd_sig=TensorSpecs(args),
bak=bak,
bak_as_function=bak_as_function,
device=device)(
fwd=fwd)
if args is None:
return sigs()
else:
return sigs(args)
На этом анализ GPipe завершен, и в следующей статье начнется анализ PipeDream, так что следите за обновлениями.
Кроме того, в отношении PyTorch Pipeline будет специальная серия для последующего анализа.
0xEE Личная информация
★★★★★★Думая о жизни и технологиях★★★★★★
Публичный аккаунт WeChat:мысли Росси
ссылка 0xFF
lingvo framework день чтение заметок
Накопление градиента с помощью tensorflow2
PipeDream: Fast and Efficient Pipeline Parallel DNN Training
На данный момент 231 you.GitHub.IO/neural-net…
Блог Woohoo.cn на.com/geek found/afraid/14…
Технология оптимизации видеопамяти во время обучения - слияние ОП и контрольная точка градиента
Pytorch Notes 04 - Пользовательский torch.autograd.Function
Учебное пособие по Autograd для PyTorch
Пользовательское расширение pytorch (3) - простое определение и случай torch.autograd.Function
Пользовательское расширение pytorch (2) - torch.autograd.Function завершает пользовательский слой
torch.autograd интерпретации исходного кода PyTorch: подробное объяснение расчета градиента
Перевод примечаний к курсу CS231n: примечания по обратному распространению