[Анализ исходного кода] Как Facebook обучает очень большие модели --- (2)

машинное обучение глубокое обучение

0x00 сводка

Как мы упоминали ранее, Microsoft ZeRO может масштабировать модель с триллионом параметров на 4096 графических процессорах NVIDIA A100, используя 8-сторонний параллелизм моделей, 64-сторонний конвейерный параллелизм и 8-сторонний параллелизм данных.

FSDP (Fully Sharded Data Parallel) — это обновленная версия PyTorch DDP, предложенная Facebook после подробного ознакомления с Microsoft ZeRO, которую можно рассматривать как эталон по сравнению с Microsoft ZeRO, и ее суть — разделение параметров. Разделение параметров заключается в разделении параметров модели на каждый графический процессор. Мы используем документы, блоги и код Google, Microsoft и Facebook для анализа обучения.

В предыдущей статье мы рассказали, как использовать FSDP, а в этой статье показано, как FSDP реализует разделение параметров с точки зрения исходного кода.

Другие статьи из этой серии:

[Анализ исходного кода] Распределенный PyTorch ZeroRedundancyOptimizer

[Перевод статьи] ZeroRO сегментирования параметров распределенного обучения

[Перевод статьи] Распределенное обучение Разделение параметров Google Weight Sharding

[Анализ исходного кода] Как Facebook обучает очень большую модель --- (1)

0x01 Обзор

1.1 ZeRO

Начнем с обзора ZeRO.

В глубоком обучении моделей видеопамять в основном используетсяModel StatesиActivationЗанят двумя частями. Модельные состояния включают:

  • Состояния оптимизатора: данные, используемые оптимизатором при обновлении градиента, например, в SGD.Momentum.
  • Градиент: Градиент, созданный обратным распространением.
  • Параметр модели: параметры модели, то есть информация, «полученная» из данных во время обучения.

ZeRO решает проблему занятости модельных состояний, а ZeRO делится на три уровня, как показано на рисунке ниже, соответствующих сегментации модельных состояний в разной степени.

img

1.1.1 ZeRO-1

Этот уровень будет разделенOptimizer States.

Оптимизатор не будет работать ни в прямой, ни в обратной фазе,Optimizer StatesТолько после создания градиента он будет рассчитываться вместе с параметрами модели при использовании градиента для обновления для создания новых параметров. Следовательно, пары ZeRO-1Optimizer StatesВыполните сегментацию, предполагая, что есть N рабочих, затем пусть у каждого рабочего будет только 1/NOptimizer States, используя эту 1/N изOptimizer StatesПосле обновления соответствующих параметров 1/N все параметры объединяются вместе для формирования полной модели (в частности, с помощью широковещательных операций или операций сбора всех рангов, чтобы убедиться, что все ранги получают самые последние обновленные значения параметров).

1.1.2 ZeRO-2

ZeRO-2 разделитсяOptimizer StatesиGradients.

ZeRO-2 основан на ZeRO-1, потому что ZeRO-1 имеетOptimizer StatesСегменты хранятся в нескольких воркерах, поэтому, естественно, вам нужно получить только свой собственный воркер.Optimizer StatesСоответствующий градиент также завершает фрагментацию градиента.

Градиенты всех воркеров агрегируются через AllReduce, каждый воркер может выбрать только часть нужных ему градиентов, а остальные градиенты можно отбросить.

1.1.3 ZeRO-3

ZeRO-3 разделитсяOptimizer States,GradientsиParameters. На основе ZeRO-1 и ZeRO-2 каждый рабочий процесс сохраняет только часть осколков модели, и все рабочие процессы работают вместе, чтобы обеспечить полное состояние модели, а также собирать и выпускать параметры в соответствии с конкретными вычислительными требованиями. Здесь следует подчеркнуть, что сбор и выдача параметров, которые делает ZeRO-3, заключается в том, чтобы детально обрабатывать каждый параметр, и мы разберем его в сочетании с кодом позже.

1.2 DDP VS FSDP

Давайте сначала найдем изображение из более ранней версии исходного кода, чтобы увидеть разницу между DDP и FSDP, вы можете просмотреть его.

img

0x02 Общая логика

2.1 FSDP

Сначала напомним общую логику FSDP следующим образом:

  • Model shard: существует только на каждом графическом процессореШардинг модели.
  • All-gather: каждый GPU собирает все данные с других GPU через all-gatherВеса, чтобы локально вычислить прямое распространение. Это подчеркнутая часть идеи диссертации, стр.
  • Вперед (местный): выполнить операцию переадресации локально. И прямое вычисление, и обратное вычисление используют полную модель.
  • All-gather: затем сделайте это снова перед обратным распространениемВесасобирать. Это подчеркнутая часть идеи диссертации, стр.
  • Назад (местный): локальное выполнение операции в обратном направлении. И прямые, и обратные вычисления используют полную модель, и каждый GPU такжевсе градиенты.
  • Reduce-scatter: после обратного распространения локальныйградиентОн агрегируется и сегментируется на каждом графическом процессоре с помощью редукции-разброса Градиент на каждом сегменте — это часть, соответствующая этому разделу после агрегации, что является подчеркнутой частью идеи статьи Pg.
  • Обновить вес (локальный): каждый GPU обновляет свой локальныйВесаФрагментация.

2.2 Оригинальный ЗЕРО

Во-вторых, давайте посмотрим, как обрабатывается исходный код Microsoft ZeRO.Вы можете сравнить его с приведенными выше идеями FSD, а также увидеть конкретную разницу в реализации между ними в последующем анализе кода FSDP.

1.2.1 Инициализация

Когда ZeRO инициализируется, параметры поровну распределяются между каждым процессом, и он будет:

  • Свести исходные параметры в одно измерение.
  • Каждый работник находит начальную и конечную позиции одномерных параметров в соответствии со своим рангом, а затем копирует соответствующие данные.
  • Чтобы предотвратить потерю исходных характеристик данных, вызванную последующим заполнением и разбиением, информация исходного тензора, такая как shape, numel и т. д., будет записана в _convert_to_deepspeed_param.
  • Он освободит исходный параметр и превратит его в тензор скалярного типа.

Поскольку для прямого/обратного распространения требуются полные параметры, вам нужно знать, как получить все параметры.Zero создаст управляющую информацию во время инициализации.Конкретная операция заключается в создании 4 ловушек для каждого подмодуля.

  • _pre_forward_module_hook,существуетsubmoduleсобранные до прямого распространенияmodule parameters.
  • _post_forward_module_hook,существуетsubmoduleвысвобождается после прямого распространенияmodule parameters.
  • _pre_backward_module_hook,существуетsubmoduleсобирается до начала обратного распространенияmodule parameters.
  • _post_backward_module_hook,существуетsubmoduleвыпущен после обратного распространенияmodule parameters.

Конкретный код:

# Pre forward hook
module.register_forward_pre_hook(_pre_forward_module_hook)
# Post forward hook
module.register_forward_hook(_post_forward_module_hook)
# Pre backward hook
module.register_forward_hook(_pre_backward_module_hook)
# post backward hook
module.register_forward_pre_hook(_post_backward_module_hook)

Затем будут построены два класса: PartitionedParameterCoordinator и PrefetchCoordinator, отвечающие за конкретный сбор и выпуск и вызываемые каждым хуком.

1.2.2 Прямое распространение

Прежде чем начнется прямое распространение,_pre_forward_module_hookВеса на каждом разделе собираются для построения исходных параметров. Вот несколько отличных советов.

Так как обучение проводится послойно, ZeRO выполнит операцию предварительной выборки, то есть, когда будут собраны параметры этого слоя, будут также собраны параметры следующего слоя, что может сэкономить время связи. конкретно это:

  • При выполнении первой итерации записывается полная работающая запись модели, то есть порядок выполнения каждого nn.module.
  • Параметры этого слоя и следующего слоя будут собираться в соответствии с записями операций, а исходные большие параметры будут реконструироваться параллельно в соответствии с исходной тензорной информацией, записанной в вышеупомянутом параметре _convert_to_deepspeed_param.
  • Выполните пересылку подмодуля этого слоя и запишите, какой шаг был предпринят на этот раз, чтобы знать, параметры какого слоя предварительно выбрать в следующий раз.

После того, как прямое распространение закончится, оно будет вызвано_post_forward_module_hookосвободить исходные большие параметры, реконструированные этим слоем.

Вот ключ.Конкретная сборка/выпуск осуществляется послойно, то есть в каждой итерации постепенно строятся/выпускаются исходные большие параметры каждого слоя, и на GPU никогда не бывает полных всех слоев. Модель, но последовательно имеет исходные параметры каждого слоя. Например, если есть 6 слоев, при расчете четвертого слоя некий GPU выглядит следующим образом, только третий слой имеет полные параметры, первые три слоя освобождены, а последние два слоя не собраны.

+---------------------------------------------+
| GPU n                                       |
|                        +  gather            |
|                        |  forward           |
|                        |  release           |
|                        v                    |
|                    +---+----+               |
|  Layer 0           |        |               |
|                    +---+----+               |
|                        |  gather            |
|                        |  forward           |
|                        |  release           |
|                        v                    |
|                    +---+----+               |
|  Layer 1           |        |               |
|                    +---+----+               |
|                        |  gather            |
|                        |  forward           |
|                        |  release           |
|                        v                    |
|                    +---+----+               |
|  Layer 2           |        |               |
|                    +---+----+               |
|                        |                    |
|                        |  gather            |
|                        |  forward           |
|                        v                    |
|           +------------+------------------+ |
|  Layer 3  |                               | |
|           +-------------------------------+ |
|                                             |
|                    +--------+               |
|  Layer 4           |        |               |
|                    +--------+               |
|                                             |
|                    +--------+               |
|  Layer 5           |        |               |
|                    +--------+               |
+---------------------------------------------+

Таким образом, подмодулям в каждом рабочем потоке нужно только собирать/создавать параметры перед расчетом прямого распространения и освобождать параметры после расчета, что может уменьшить избыточное пространство памяти. Если мы возьмем отдельный слой, чтобы посмотреть на прямое распространение и обратное распространение, рабочий механизм будет следующим, и разделение параметров будет выполнено с помощью нескольких хуков.

+-------------------------------------------------------------+
| submodule                                                   |
|                                                             |
|        _pre_forward_module_hook()      gather & rebuild     |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|                 forward                                     |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|        _post_forward_module_hook()      release             |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|        _pre_backward_module_hook()      gather & rebuild    |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|                 backward                                    |
|                                                             |
|                    +                                        |
|                    |                                        |
|                    |                                        |
|                    v                                        |
|                                                             |
|        _post_backward_module_hook()     release             |
|                                                             |
+-------------------------------------------------------------+

1.2.3 Обратное распространение

_pre_backward_module_hookЭто также похоже на прямое распространение для сбора, предварительной выборки параметров и записи шагов выполнения.

_post_backward_module_hookЭто также похоже на прямое распространение, чтобы высвободить избыточные параметры, которые больше не нужны для расчета.

Просто потому, что PyTorch не поддерживает Pre Backward Hook, поэтому вregister_forward_hookнастроитьautograd.Function, целью которого является выполнение пользовательских операций до того, как модуль выполнит обратную операцию, поэтому операции all-gather и scatter-reduce присоединяются к каждому подмодулю.

2.3 Код ФСДП

Затем мы объединяем код, чтобы сделать обзор.

2.3.1 Инициализация

Основная функция инициализации состоит в том, чтобы разделить параметры модели, и каждый рабочий процесс будет использовать некоторые параметры модели.

Конкретная операция сегментирования реализуется путем обработки каждого параметра как одномерного тензора и сохранения только одного среза, где размер среза определяется количеством параллельных рабочих процессов. Следует отметить, что параметры модели должны быть разделены перед загрузкой в ​​GPU, а затем загружены в GPU каждого воркера.

Из-за более поздних операций заполнения и разделения, чтобы предотвратить потерю исходных характеристик данных, FSDP использует метод data.size() PyTorch для записи исходных характеристик данных в p._orig_size.

VS ZeRO: FSDP в настоящее время не выполняет управление ловушками.

2.3.2 Прямое распространение

Суть этой части: прямое распространение на каждом графическом процессоре и установление отношения управления для обратного распространения, чтобы обратное распространение знало, как собирать параметры и как освобождать параметры. Конкретные операции заключаются в следующем:

  • Прежде всего, поскольку прямой проход использует полную модель, сначала используйте All-gather для сбора всех весов от других графических процессоров, в частности, вызвав _rebuild_full_params() для завершения реконструкции всех параметров модели, которая будет использовать оригинал, хранящийся в p. _orig_size для восстановления исходных параметров.
  • Вызовите _register_post_backward_hooks, чтобы настроить уменьшение разброса для обратного распространения.
  • Выполнить операцию вперед.
  • Вызовите _register_pre_backward_hooks(outputs) для регистрации all-gather для обратного распространения.

Конкретный соответствующий упрощенный код:

self._rebuild_full_params() # 做前向操作之前的 all-gather
self._register_post_backward_hooks() # 为后向传播注册 reduce-scatter
outputs = self.module(*args, **kwargs) # 模型前向传播
outputs = self._register_pre_backward_hooks(outputs) # 为后向传播注册 all-gather

VS ZeRO: FSDP контролирует работу хука в это время, но вместо того, чтобы использовать различные хуки модуля, он единообразно использует register_hook тензора.

2.3.3 Иерархическая оптимизация

У вас могут быть сомнения. Это отличается от исходного кода ZeRO. Исходный код ZeRO должен выполнять сбор/отбрасывание на каждом уровне. FSDP здесь, похоже, выполняет переадресацию на общую модель, и нет многоуровневого выполнения.

На самом деле приведенный выше кодПросто стандартная реализация или просто обработка всей системы как одного уровня, многоуровневая реализация сбора/отбрасывания не задействована.. FSDP принял во внимание стратификацию следующим образом:

Чтобы максимизировать эффективность использования памяти, мы можем отбросить все веса после прямого прохода каждого слоя, сохраняя память для последующих слоев. Это можно сделать, добавив FSDPУпаковкаПрименяется к каждому слою в сети для реализации (с использованием auto_wrap для переноса каждого слоя и настройки reshard_after_forward=True). Ниже приведено представление псевдокода:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i # 权重
        forward pass for layer_i
        discard full weights for layer_i # 权重

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i # 权重
        backward pass for layer_i
        discard full weights for layer_i # 权重
        reduce-scatter gradients for layer_i # 梯度

2.3.4 Резюме

Мы видим, что если параметры модели фрагментированы, локальный оптимизатор будет оптимизировать эти локально выделенные параметры, состояние оптимизатора будет автоматически фрагментировано, и, следовательно, градиент будет автоматически фрагментирован, что является графиком в нижней части графика.Pos+g+pP_{os+g+p}.

img

0x03 Инициализировать

Основная функция инициализации состоит в том, чтобы разделить параметры модели, и каждый рабочий процесс будет использовать некоторые параметры модели. Если их 3, каждый рабочий делит 1/3, поэтому они освобождают остальные 2/3, которые им не принадлежат (потому что они уже избыточны). Однако соответствующие параметры модели трех рабочих объединяются, что является полным параметром модели.

Сначала мы рассмотрим общий метод инициализации, у всех есть общее впечатление, а затем мы будем тщательно и постепенно его анализировать.

class FullyShardedDataParallel(nn.Module):

    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        reshard_after_forward: bool = True,
        mixed_precision: bool = False,
        fp32_reduce_scatter: bool = False,
        flatten_parameters: bool = True,
        move_params_to_cpu: bool = False,
        compute_dtype: Optional[torch.dtype] = None,
        buffer_dtype: Optional[torch.dtype] = None,
        move_grads_to_cpu: Optional[bool] = None,
        bucket_cap_mb: int = 25,
        compute_device: Optional[torch.device] = None,
        no_broadcast_optim_state: Optional[bool] = False,
        state_dict_device: Optional[torch.device] = None,
        clear_autocast_cache: bool = False,
        force_input_to_fp32: bool = False,
        verbose: bool = False,
        cpu_offload: bool = False,
    ):
        init_start = time.time()
        super().__init__()
        self.process_group = process_group or get_process_group_cached()
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        self.reshard_after_forward = reshard_after_forward
        self.mixed_precision = mixed_precision
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.flatten_parameters = flatten_parameters
        self.move_params_to_cpu = move_params_to_cpu or cpu_offload
        self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
        self.buffer_dtype = buffer_dtype or self.compute_dtype
        self.move_grads_to_cpu = self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
        self.bucket_cap_mb = bucket_cap_mb
        self.compute_device = compute_device or _get_default_cuda_device(module)
        self.uncollected_opt_state: Dict[int, Dict] = {}
        self.no_broadcast_optim_state = no_broadcast_optim_state
        self.state_dict_device = state_dict_device or self.compute_device
        self.clear_autocast_cache = clear_autocast_cache
        self.force_input_to_fp32 = force_input_to_fp32
        self.verbose = verbose

        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.numel_padded_per_param: List[int] = []
        self._tstart = time.time()

        # skip validation if the process group was created above
        if process_group:
            validate_process_group(self.compute_device, self.process_group)

        # enable pytorch sync_bn just in case model contains sync_bn layers.
        enable_pytorch_sync_bn(module)

        # 1. 打平参数
        
        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        param_names = []
        params = []
        
        # 1.1 遍历模型参数,收集到params之中
        for param_name, param in module.named_parameters():
            if not hasattr(param, "_is_sharded"):
                param_names.append(param_name)
                params.append(param)

        self._has_params = len(params) > 0

        # 1.2 把需要打平的参数收集到 to_be_flatten_params 之中
        to_be_flatten_params: List[List[Parameter]] = [[]]
        non_flatten_params = params
        param_name_groups = [[n] for n in param_names]
        if self.flatten_parameters:
            to_be_flatten_params = [params]
            non_flatten_params = []
            param_name_groups = [param_names]
        del param_names

        # 1.3 使用 FlattenParamsWrapper 来打平参数
        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
        del module  # free original module in case it helps garbage collection

        # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
        # params for doing sharding, gradient hooks, etc. Note, the ordering of the
        # list matters: flatten params are always in the front.
        #
        # The self._num_flatten_params and self._param_name_groups are computed
        # and kept here to support summon_full_params and shard-to-full weight
        # consolidation.
        
        # 1.4 把打平的参数和其他参数拼接到 self.params 之中
        self.params = cast(List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params
        self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
        self._param_name_groups = param_name_groups

        # 2. 进行参数分区
        
        # Shard module parameters in place
        self._shard_parameters_() # 

        # 3. 惰性初始化
        self._reset_lazy_init()

        # Flag to indicate if we require gradient reduction in the backward
        # pass. This will be False when inside the no_sync context manager.
        self._require_backward_grad_sync: bool = True

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState.IDLE

        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
        # prefix and before load_state_dict() to add it back.
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

        # Flag to indicate whether state_dict() should automatically summon the
        # full params. This defaults to True, but may be set to False if the
        # user explicitly requests the local state dict via local_state_dict().
        self._return_full_state_dict = True
        init_end = time.time()

        # Flag to guard multiple pre-backward hook being executed per iteration.
        # This is reset at the end of the backward pass.
        self._pre_backward_hook_has_run = False

3.1 Параметры обработки

Первым шагом метода инициализации является обработка параметров, а конкретные числа ниже соответствуют комментариям в коде.

  • 1.1 Он будет проходить по параметрам модели и собирать их в параметры.
  • 1.2 Соберите параметры, которые необходимо свести, в to_be_flatten_params.
  • 1.3 Используйте FlattenParamsWrapper для выравнивания параметров.
    • Теперь у нас есть список self.params, который содержит параметры для сглаживания и без сглаживания для сегментирования, перехватчиков градиента и т. д. Порядок списка следующий: параметр flatten всегда стоит первым.
    • Также вычисляются группы self._num_flatten_params и self._param_name_groups для поддержки функцииsummon_full_params и слияния осколков с полными весами.
  • 1.4 Объедините сведенные параметры и другие параметры в self.params.

3.1.2 Шардинг

инициализировать, а затем вызыватьshard_parametersОсколок. Как видно из комментариев, при инициализации будет упакован модуль с полными параметрами, а параметры будут разделены на месте. Конкретная операция сегментирования реализуется путем обработки каждого параметра как одномерного тензора и сохранения только одного среза, где размер среза определяется количеством параллельных рабочих процессов.

Следует отметить, что параметры модели должны быть разделены перед загрузкой в ​​GPU, а затем загружены в GPU каждого воркера.

@torch.no_grad()
def _shard_parameters_(self) -> None:
    """
    At initialization we wrap a module with full parameters and shard the
    parameters in-place. Sharding is implemented by viewing each parameter
    as a 1D Tensor and retaining only a single slice, where the slice size
    is determined by the number of data parallel workers.

    Wrapping modules with many small parameters (or with a very large data
    parallel world size) will result in many small parameter shards and slow
    performance. In this case it's better to set *``flatten_parameters``* to
    ``True``, so that all of the small parameters in the module are combined
    into a single contiguous Tensor and sharded once.

    After this initial sharding is complete, the user can initialize a
    ``torch.optim.Optimizer`` in the usual way, i.e.::

    The optimizer will see only a single slice of parameters and will thus
    allocate less memory for optimizer state, avoiding redundancy across
    data parallel workers.
    """
    self.numel_padded_per_param = []
    for p in self.params: # 遍历模型参数列表

        # If world_size is 1, then we all-reduce grads instead of sharding.
        p._is_sharded = self.world_size > 1
        p._orig_size = p.data.size() # 记录张量原始信息(shape, numel, etc)

        if not p._is_sharded:
            self.numel_padded_per_param.append(0)
            continue
        p._is_sharded = True

        # Replace p.data with the relevant shard.
        orig_data = p.data # 拿到原始数据
        p.data, num_padded = self._get_shard(p.data) # 获取这个模型参数的分区
        self.numel_padded_per_param.append(num_padded)
        free_storage_(orig_data) # 释放冗余数据

_get_shard — это конкретная операция раздела, но возвращает только раздел, соответствующий этому рангу.

В исходном коде ZeRO один набор тензоров параметров модели_convert_to_deepspeed_paramVest, чтобы можно было записать исходные характеристики тензора (форма, число и т. д.), чтобы предотвратить потерю исходных характеристик данных из-за заполнения и разбиения на более позднем этапе.FSDP не использует этот метод, а записывает его в p._orig_size, в частности, используя метод data.size() PyTorch..

def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    # 把传入的张量打平,按照world size分成一个list
    chunks = list(torch.flatten(tensor).chunk(self.world_size))
    # 把list之中都初始化
    while len(chunks) < self.world_size:
        chunks.append(chunks[0].new_empty(0)) # 插入空白张量

    # Determine number of padding elements.
    # 看看需要pad多少元素
    num_to_pad = chunks[0].numel() - chunks[self.rank].numel()

    # 获得本rank对应的分区
    shard = chunks[self.rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad]) # pad
    # 返回    
    return shard, num_to_pad

3.3 Ленивая инициализация

В прямых или других методах выполняется ленивая инициализация, а именно:

  • Вызовите _init_param_attributes для инициализации параметров.Для параметров, которые необходимо переместить, подготовьтесь к последующему перемещению в CPU и поместите их в pin_memory.
  • Вызовите _set_is_root, чтобы установить корень, в основном, чтобы сделать некоторые настройки для группы процессов.
  • Вызовите _setup_streams, чтобы настроить потоки CUDA. Создайте ReduceScatterBucketer с отдельными потоками CUDA для «fp32_to_fp16», «all_gather» и «post_backward».
  • Вызовите _wait_for_previous_optim_step, чтобы дождаться завершения потока.
def _lazy_init(self) -> None:
    """Initialization steps that should happen lazily, typically right
    before the first forward pass.
    """
    # Initialize param attributes lazily, in case the param's dtype or
    # device changes after __init__.
    for p in self.params:
        self._init_param_attributes(p) # 1. 初始化参数

    # Initialize _is_root and setup streams. These steps would ideally
    # happen in __init__, but _is_root can only be determined after the
    # entire model hierarchy is setup, thus we run it lazily.
    if self._is_root is None:
        self._set_is_root()
        self._setup_streams()

    if self._is_root:
        # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
        # applies recursively, we only call this from the root instance.
        self._cast_buffers()

        # Don't free the full params for the outer-most (root) instance,
        # since those params will be needed immediately after for the
        # backward pass.
        self.reshard_after_forward = False

        # Due to the use of streams, we need to make sure the previous
        # ``optim.step()`` is done before we all-gather parameters.
        self._wait_for_previous_optim_step()

3.3.1 Параметр инициализации

Здесь будут установлены следующие параметры, и здесь можно увидеть переключение смешанной точности:

  • _fp32_shard: срез с одним параметром полной точности (обычно fp32, но это зависит от типа данных модели, переданного пользователем). Может выполняться на CPU или GPU, в зависимости от*cpu_offload* ценность.
  • _fp16_shard:еслиmixed_precisionзаTrue, который будет единым шардом параметров в fp16 для all-gather.
  • _full_param_padded: все веса, используемые для вычисления в прямом и обратном проходах (заполнены дляworld_sizeпоровну). Это изменит размер на месте и материализуется только при необходимости (через all-gather).

Основная логика такова: при подготовке к последующему переносу на ЦП некоторые параметры будут помещены в pin_memory для создания _full_param_padded, вмещающего все веса.

@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
    """
    We manage several attributes on each Parameter instance. The first two
    are set by :func:`_shard_parameters_`:

        ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
            if the Parameter is intentionally not sharded (in which case we
            will all-reduce grads for this param).
        ``_orig_size``: the size of the original Parameter (before sharding)

    The remaining attributes are set here:
        ``_fp32_shard``: a single shard of the parameters in full precision
            (typically FP32, but this is dependent on the dtype of the model
            as it's passed in by the user). This can be on CPU or GPU
            depending on the value of *``cpu_offload``*.
        ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
            a single shard of the parameters in FP16, used for all-gather.
        ``_full_param_padded``: the full weight (padded to be evenly
            divisible by ``world_size``), used for computation in the
            forward and backward pass. This will be resized in place and
            only materialized (via all-gather) as needed.
    """
    if hasattr(p, "_fp32_shard"):
        return

    # A single shard of the parameters in full precision.
    p._fp32_shard = p.data

    if self.mixed_precision:
        if self.move_params_to_cpu: 
            # 为后续移动到CPU做准备,放到pin_memory之中
            # If we plan to keep the FP32 parameters on CPU, then pinning
            # memory allows us to later use non-blocking transfers when moving
            # the FP32 param shard to compute_device.
            p._fp32_shard = p._fp32_shard.pin_memory() 
            p.data = p._fp32_shard

        # In mixed precision mode, we maintain a reduced precision
        # (typically FP16) parameter shard on compute_device for performing
        # the computation in the forward/backward pass. We resize the
        # storage to size 0 at init (here) and re-materialize (by copying
        # from _fp32_shard) as needed.
        p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
        free_storage_(p._fp16_shard)
    else:
        p._fp16_shard = None  # use _fp32_shard

    # We also maintain a full-sized parameter of type self.compute_dtype
    # (FP16 for mixed_precision or FP32 otherwise). We resize the
    # storage to size 0 at init (here) and only materialize as needed. The
    # storage may contain padding elements so that it is evenly divisible by
    # world_size, although these padding elements will be removed before the
    # relevant computation.
    if p._is_sharded:
        p._full_param_padded = torch.zeros( # _full_param_padded 是所有权重
            p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
        )
        free_storage_(p._full_param_padded)

    if self.move_grads_to_cpu: 
        # 为后续移动到CPU做准备,放到pin_memory之中
        # We can optionally move the grad shard to CPU during the backward
        # pass. In this case, it's important to pre-allocate the CPU grad
        # shard in pinned memory so that we can do a non-blocking transfer.
        p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()

3.3.2 корневые настройки

В основном это делается для некоторых настроек группы процессов.

def _set_is_root(self) -> None:
    """If ``True``, implies that no other :class:`FullyShardedDataParallel`
    instance wraps this one. Called once by :func:`_lazy_init`.
    Also sets self.children_share_process_group = True if all child
    instances share the same process group. If some child instances use a
    different process group, self.clip_grad_norm_ will raise an error.
    """
    if self._is_root is not None:
        return
    # No FSDP instance wraps this, else _is_root would be set to False.
    self._is_root = True
    # As the root, we now set all children instances to False and
    # give them a closure to try to queue a wait_for_post_backward.
    self.children_share_process_group = True
    for n, m in self.named_modules():
        # `n != ""` excludes self.
        if n != "" and isinstance(m, FullyShardedDataParallel):
            # We relax the assert for non-root instance, when the nested inialized module is wrapped
            # again in FSDP later, for example after training to run inference.
            assert m._is_root is None or not m._is_root
            if m._is_root is None:
                m._is_root = False
            if m.process_group != self.process_group:
                self.children_share_process_group = False

            # if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
            # Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
            m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
                (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
            )

3.3.3 Установка потока CUDA

Создайте ReduceScatterBucketer с отдельными потоками CUDA для «fp32_to_fp16», «all_gather» и «post_backward».

def _setup_streams(self) -> None:
    """Create streams to overlap data transfer and computation."""
    if len(self._streams) > 0 or not self._is_root:
        return

    if torch.cuda.is_available():
        # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
        self._streams["fp32_to_fp16"] = torch.cuda.Stream()
        # Stream for all-gathering parameters.
        self._streams["all_gather"] = torch.cuda.Stream()
        # Stream for overlapping grad reduction with the backward pass.
        self._streams["post_backward"] = torch.cuda.Stream()

    # Helper for bucketing reduce-scatter ops. This is also shared with
    # children instances to improve bucket utilization.
    self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
    # We share streams with all children instances, which allows them to
    # overlap transfers across the forward pass without synchronizing with
    # the default stream.
    for n, m in self.named_modules():
        if n != "" and isinstance(m, FullyShardedDataParallel):
            m._streams = self._streams
            m._reducer = self._reducer

3.3.4 Синхронизация

Подождите, пока поток завершит операцию.

def _wait_for_previous_optim_step(self) -> None:
    """
    The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
    instance) needs to synchronize with the default stream to ensure the
    previous optimizer step is done.
    """
    if not torch.cuda.is_available():
        return
    if self.mixed_precision:
        self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
    else:
        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

Таким образом, текущее состояние выглядит следующим образом: в предположении, что имеется 2 графических процессора, параметры модели делятся на два графических процессора. Предположим, что модель имеет два параметра, Параметр 0 и Параметр 1. Каждый параметр разделен на два сегмента, которые хранятся на двух графических процессорах соответственно, Параметр 0 разделен на Параметр 0_0 и Параметр 0_1, а Параметр 1 разделен на Параметр 1_0 и Параметр 1_1 .

                  Model Parameter
                  +----------------------------+
                  |       Parameter 0          |
                  |                            |
                  |       Parameter 1          |
                  |                            |
                  +------------+---------------+
                               |
                               | split
                               v
                         +-----+-----+
                         |           |
                         |           |
 GPU 0                   v           v                       GPU 1
+------------------------+----+   +--+---------------------------+
|  Model Parameter Shard 0    |   |  Model Parameter Shard 1     |
| +-------------------------+ |   | +--------------------------+ |
| |    Parameter 0_0        | |   | |      Parameter 0_1       | |
| |                         | |   | |                          | |
| |    Parameter 1_0        | |   | |      Parameter 1_1       | |
| |                         | |   | |                          | |
| +-------------------------+ |   | +--------------------------+ |
+-----------------------------+   +------------------------------+

0x04 Прямое распространение

Суть этой части заключается в точном сборе/использовании/выпуске параметров в соответствии с требованиями сегментирования параметров. Сбор — это All-gather на рисунке ниже, а выпуск — Reduce-Scatter.

img

4.1 forward

Согласно предыдущему анализу мы знаем, что форвардная операция состоит из двух частей:

  • All-gather: каждый GPU собирает все данные с других GPU через all-gatherВеса, вычислить переднее распространение локально.
  • Вперед (местный): выполнить операцию переадресации локально. И прямое вычисление, и обратное вычисление используют полную модель.

В соответствии с кодом конкретная логика такова:

  1. Если используется смешанная точность, преобразуйте входные данные в FP16.
  2. Если вы не используете смешанную точность и выполняете приведение к FP32, выполните преобразование.
  3. Перед вызовом _rebuild_full_params() для выполнения прямой операцииall-gather, который восстанавливает все параметры модели.
  4. Поскольку сбор/выпуск параметров происходит во время прямого и обратного распространения, необходимо настроить обратное распространение во время прямого распространения. В частности, вызовите _register_post_backward_hooks, чтобы установить уменьшение разброса для обратного распространения.
  5. Выполнить операцию вперед.
  6. Переключитесь на основной слайс параметров FP32. Мы сохраняем этот инвариант на протяжении всего кода, т.е. после каждой функции,p.data == p._fp32_shard. Поскольку состояние оптимизатора обычно находится вoptim.step()Средняя ленивая инициализация, которая также гарантирует, что после первой пересылки состояние оптимизатора будет инициализировано с правильным типом данных и размером (осколка).
  7. Вызовите _register_pre_backward_hooks(outputs) для регистрации all-gather для обратного распространения. здесь вокончательныйХук регистрируется в выходном тензоре, поэтому при обратном распространении он будетПервыйВызвав этот хук, можно сделать все-собрать по логике. Поскольку это должно быть зарегистрировано поверх окончательного вывода, _register_pre_backward_hooks вызывается только в конце прямого прохода.
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._lazy_init()

    # Start of a forward pass.
    self.training_state = TrainingState.FORWARD

		# 1. 如果使用混合精度,则把输入转换为FP16
    
    # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
    # the conversion).
    if self._is_root and self.mixed_precision:
        args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)

    # 2. 如果不使用混合精度,切强制转换FP32,则进行转换  
        
    # If enabled, convert the input to FP32 if we are in full precision.
    # no_grad is not used because the input might be for a non-root instance,
    # which mean autograd needs to go through the conversion.
    if self.force_input_to_fp32 and not self.mixed_precision:
        args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)

    # 3. 调用 _rebuild_full_params() 做前向操作之前的 all-gather
        
    # All-gather full parameters. This will also transfer FP32 parameters to
    # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
    self._rebuild_full_params() # 做前向操作之前的 all-gather

    # 4. 调用_register_post_backward_hooks为后向传播建立 Reduce-scatter
    
    # Register backward hooks to reshard params and reduce-scatter grads.
    # These need to be re-registered every forward pass.
    self._register_post_backward_hooks() # 为后向传播建立 Reduce-scatter

    # 5. 进行前向操作
    
    outputs = self.module(*args, **kwargs)

    # 6. 丢弃多余模型参数
    if self.reshard_after_forward:
        self._free_full_params()
        if self.mixed_precision:
            self._free_fp16_param_shard()
            
    # 7. 切换到主FP32参数分片。我们在整个代码中都保持这个不变量,即在每个函数之后,``p.data == p._fp32_shard``。因为优化器状态通常在``optim.step()``中延迟初始化,这还确保在第一次forward之后,优化器状态将使用正确的数据类型和(分片)大小来初始化,   

    # Switch to main FP32 param shard. We maintain this invariant throughout
    # the code, i.e., ``p.data == p._fp32_shard`` after each function. This
    # also ensures that after the first forward, the optimizer state will be
    # initialized with the correct dtype and (sharded) size, since optimizer
    # state is typically initialized lazily in ``optim.step()``.
    self._use_fp32_param_shard()

    # 8. 调用 _register_pre_backward_hooks(outputs) 为后向传播注册 all-gather
    
    # Register pre-backward hooks to all-gather the params for the backward
    # pass (if output's grad was needed). This won't register anything if
    # we are in eval mode.
    #
    # Some model does forward pass multiple times, we need to register the
    # pre-backward hook on every output since the last output's hook has to
    # fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
    # to prevent repeated overhead from multiple hook callbacks.
    outputs = self._register_pre_backward_hooks(outputs) # 为后向传播注册 all-gather

    # Done with a forward pass.
    self.training_state = TrainingState.IDLE

    # Only need to clear cache during forward. During backward, the cache is not used.
    # TODO (Min): Future PyTorch versions may provide a way to completely disable this
    #     cache. Update this when that's available.
    if self.clear_autocast_cache:
        torch.clear_autocast_cache()

    return outputs

Давайте посмотрим, как каждая часть реализована дальше.

4.1.1 All-gather

self._rebuild_full_params() выполнит операцию сбора перед операцией пересылки.

# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
4.1.1.1 _rebuild_full_params

Здесь, как и в OSS, синхронизируются все параметры модели, конкретная логика такова:

  1. Если у нас уже есть полные аргументы и полная точность не нужна, то выходим раньше.

  2. Настройте последующие операции на использование потока, соответствующего «all_gather».

  3. Сделайте прецизионное преобразование.

  4. Перебрать все параметры модели:

4.1 Если world_size==1, обновить напрямую, потому что ранг только один.

4.2 Переместить данные из CPU в CUDA.

4.3 Выполнить сборные операции.

4.4. Обновите локальный тензор с помощью общего результата, который будет реконструирован с использованием исходной информации, хранящейся в p._orig_size.

@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
    """
    Gather all shards of params.

    Args:
        force_full_precision (bool, Optional): by default params will be gathered
            in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
            ``True``, in which case they will be gathered in full precision
            (e.g., FP32), possibly in fresh storage. The parameter that's being
            rebuilt will end up in full precision as well.

    Returns:
        A list of tuples, where the first element is the full-sized param
        and the second element is a bool indicating if it's safe for the
        caller to free the full-sized param. This will be ``None`` if
        ``force_full_precision=False`` and the full params are already gathered.
    """
    output_tensors: List[Tuple[torch.Tensor, bool]] = []

    def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
        """
        Helper function to update p.data pointer.

        Args:
            custom_output_tensor (torch.Tensor, Optional): if not None, this
            tensor contains the data we just gathered.
        """
        if custom_output_tensor is not None:
            assert p._is_sharded
            p.data = custom_output_tensor
            output_tensors.append((p.data, True))
        elif not p._is_sharded:
            if self.mixed_precision and not force_full_precision:
                assert p._fp16_shard is not None
                p.data = p._fp16_shard
                output_tensors.append((p.data, True))
            else:
                # Here p.data == p._fp32_shard, so it's not safe to free.
                output_tensors.append((p.data, False))
        else:
            p.data = p._full_param_padded
            output_tensors.append((p.data, True))
        # Trim any padding and reshape to match original size.
        p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

	  # 1. 如果我们已经有完整的参数并且不需要完整的精度,那么就提前退出。
        
    # Early exit if we already have full params and don't need full precision.
    if self.has_full_params and not force_full_precision:
        for p in self.params:
            update_p_data()
        return output_tensors

    self.has_full_params = True

    # 2. 使用 all_gather"对应的流
    with torch.cuda.stream(self._streams["all_gather"对应的流]):
      
      	# 3. 进行精度转换
        
        if self.mixed_precision and not force_full_precision:
            self._cast_fp32_param_shards_to_fp16()

        # 4. 遍历所有模型参数
            
        for p in self.params: 
            if not p._is_sharded:  # e.g., when world_size == 1
                update_p_data() # 4.1 如果world_size==1,则直接更新,因为只有一个rank
            else:
              
                # 4.2 把数据从CPU移动到CUDA
                
                # If self.move_params_to_cpu and force_full_precision, we need to cast
                # the FP32 CPU param to CUDA for the all-gather.
                p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

                p_size = p._full_param_padded.size()
                if self.mixed_precision and force_full_precision:
                    # Allocate fresh tensor in full precision since we are in
                    # mixed precision and full precision rebuild is asked.
                    output_tensor = p_data.new_zeros(p_size)
                else:
                    if p._full_param_padded.storage().size() != p_size.numel():
                        # Allocate based on full size from all shards.
                        alloc_storage_(p._full_param_padded, size=p_size)
                    output_tensor = p._full_param_padded

                # 4.3 进行all-gather操作     
                
                # Fill output_tensor with (p.data for each shard in self.world_size)
                if hasattr(dist, "_all_gather_base"):
                    # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                    dist._all_gather_base(output_tensor, p_data, group=self.process_group)  # type: ignore
                else:
                    chunks = list(output_tensor.chunk(self.world_size))
                    dist.all_gather(chunks, p_data, group=self.process_group)

                # 4.4 用all-gather结果对本地张量进行更新
                
                # Set p.data = output_tensor (with padding trimmed)
                update_p_data(output_tensor)

                if self.mixed_precision and not force_full_precision:
                    self._free_fp16_param_shard([p])
                    
    torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
    return output_tensors
4.1.1.2 Прецизионная работа

_cast_fp32_param_shards_to_fp16 преобразует осколок параметра FP32 в осколок параметра FP16.

@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
    """Cast FP32 param shard to FP16 for a list of params."""
    if params is None:
        params = self.params
    with torch.cuda.stream(self._streams["fp32_to_fp16"]):
        for p in params:
            alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
            p._fp16_shard.copy_(
                # If cpu_offload is True, this will be non-blocking because
                # _fp32_shard is pinned, otherwise it's a no-op.
                p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
            )
            p.data = p._fp16_shard
    torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

4.1.2 Отменить избыточные параметры

Эта часть включает две возможности: или отбросить параметры FP32, такие как _free_full_params.

@torch.no_grad()
def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
    """Free up storage for full parameters."""
    if params is None:
        params = self.params
    self.has_full_params = False
    current_stream = torch.cuda.current_stream()
    for p in params:
        if not p._is_sharded:  # e.g., world_size == 1
            if self.mixed_precision:
                self._free_fp16_param_shard([p])
            continue
        # Don't let PyTorch reuse this memory until all work in the current
        # stream is complete.
        p._full_param_padded.record_stream(current_stream)
        # There may be external references to the Tensor Storage that we
        # can't modify, such as references that are created by
        # ctx.save_for_backward in the forward pass. Thus when we
        # unshard parameters, we should reuse the original Tensor
        # Storage object and unshard it in-place. For now, just resize
        # the Storage to 0 to save memory.
        free_storage_(p._full_param_padded)

Или отбросить параметр FP16.

@torch.no_grad()
def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
    """Free storage for FP16 shards for a list of params."""
    if params is None:
        params = self.params
    current_stream = torch.cuda.current_stream()
    for p in params:
        if p._fp16_shard is not None:
            # _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
            # free it until the work in the current stream completes.
            p._fp16_shard.record_stream(current_stream)
            free_storage_(p._fp16_shard)

4.3 Настроить назад

Эта часть функции заключается в настройке обратного распространения, позволяющего ему выполнять сбор в начале и уменьшение разброса в конце. Эти конфигурации входят в конкретную логику работы:

  • All-gather: затем сделайте это снова перед обратным распространениемВесасобирать. Это подчеркнутая часть идеи диссертации, стр.
  • Назад (местный): локальное выполнение операции в обратном направлении. Как прямые, так и обратные вычисления используют полную модель, и каждый графический процессор такжевсе градиенты.
  • Reduce-scatter: после обратного распространения локальныйградиентОн агрегируется и сегментируется на каждом графическом процессоре с помощью редукции-разброса Градиент на каждом сегменте — это часть, соответствующая этому разделу после агрегации, что является подчеркнутой частью идеи статьи Pg.

Соответствует предыдущему коду:

        self._register_post_backward_hooks() # 为后向传播注册 reduce-scatter
        outputs = self.module(*args, **kwargs) # 模型前向传播
        outputs = self._register_pre_backward_hooks(outputs) # 为后向传播注册 all-gather

Далее мы анализируем их один за другим.

4.3.1 _register_post_backward_hooks

_register_post_backward_hooks — хук, вызываемый после регистрации обратного распространения, здесь хук — операции переразбиения и уменьшения разброса. Давайте интерпретируем его записи:

_register_post_backward_hooks Вызывается при прямом распространении. Цель состоит в том, чтобы сгенерировать функцию в градиенте каждого параметра (нижеgrad_acc), чтобы присоединить метод ловушки, который будет вызываться после того, как будут рассчитаны все градиенты для этого параметра.

наша цель:

  1. Мы хотим, чтобы этот хук срабатывал только один раз и после того, как будут вычислены все градиенты для этого параметра.
  2. Если он запускается более одного раза, мы получаем неправильный градиент, разделенный на несколько раз. (может привести к слишком маленьким размерам).
  3. Если он запускается один раз, но слишком рано или не запускается, мы не разбиваем градиент. (может привести к слишком большим размерам).

Из-за нескольких прямых операций эту функцию можно вызывать несколько раз с одним и тем же параметром за один прямой проход. Если мы зарегистрируем хук несколько раз, хук в конечном итоге будет вызываться несколько раз. Мы можем каждый раз пытаться получить новый хук и удалять ранее зарегистрированный хук

Однако по неизвестным причинам в режиме смешанной точности мы получаем два разных вызова этой функции (в одном прямом проходе)grad_accобъект. Если мы сохраним последний, хук запустится преждевременно. В режиме полной точности нам повезло получить то же самое.grad_accобъект, поэтому удаление и повторная регистрация по-прежнему гарантируют, что хук сработает только один раз после того, как все градиенты будут сгенерированы.

Исходя из опыта, сохранение первого хука, зарегистрированного при каждом прямом проходе, кажется наиболее эффективным. Нам также действительно нужно удалить хук в конце обратного распространения. В противном случае следующий прямой проход не зарегистрирует новый хук, который необходим для нового прямого прохода.

Помимо аннотаций, вот несколько специальных приемов:

  • Зачем регистрировать хук в grad_fn.next_functions[0][0], а не непосредственно в тензоре p?

Здесь все сложнее, поэтому я могу рассказать о нем лишь вкратце, а заинтересованные читатели могут изучить исходный код более подробно.

Во-первых, AccumulateGrad является производным от TraceableFunction, а TraceableFunction — от Node.

Во-вторых, в Node есть два вида хуков.

std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;

Если вы запустите register_hook для тензора p, он будет зарегистрирован поверх p.grad_fn.pre_hooks_, вычисление градиента в это время не выполнялось, поэтому градиент, полученный в это время, является только входом функции градиента, которая является временной переменной и не накапливается в фактической памяти градиентов, поэтому ловушка тензора обычно используется для наблюдения за временным градиентом.

Если вы запустите register_hook на AccumulateGrad, он будет зарегистрирован поверх p.grad_fn.post_hooks_, в это время вычисление градиента завершено, и в это время можно получить градиент, а p и его градиент не будут выпущены.

Хук grad_fn по умолчанию не передает параметры.Для реализации allreduce обычно передается параметр p, а p - это grad_fnсоответствующая переменная, поэтому используйте functools.partial для создания параметров.

Итак, чтобы получить правильный градиент, следует использовать post_hook, то есть запускать register_hook непосредственно в функции градиента.

  • Использование Expand_as.

Это связано с тем, что при вызове _register_post_backward_hooks расчет вперед не производится, поэтому функция градиента grad_fn на p в это время не сгенерирована. Роль expand_as заключается в следующем: эта функция градиента grad_fn может быть сгенерирована без создания фактических градиентов.

Следующий код может продемонстрировать это более наглядно.

a = torch.tensor(1.0, requires_grad=True)
print(a.grad_fn) # None,此时没有前向计算,所以没有梯度函数

a_temp = a.expand_as(a) # 虽然没有前向计算,但是也可以生成梯度函数,而且不产生实际梯度
print(a_temp.grad_fn) # ExpandBackward
print(a_temp.grad_fn.next_functions[0][0]) # AccumulateGrad
  
# 输出是:
None
<ExpandBackward object at 0x7fef9794e898> 
<AccumulateGrad object at 0x7fef9794e7f0> # 就是在这里注册 hook 

Примечание. Приведенные выше советы принадлежат другу Хуперу (ууууууу. Calling.com/people/hoopoe…) учился там.

Конкретный код регистрационного хука выглядит следующим образом:

def _register_post_backward_hooks(self) -> None:
    """
    Register backward hooks to reshard params and reduce-scatter grads.

    This is called during forward pass. The goal is to attach a hook
    on each of the parameter's gradient generating function (``grad_acc``
    below) so that the hook is called *after* all gradients for that
    param are computed.

    Goals:

    1. We want the hook to fire once and only once *after* all gradients
    are accumulated for a param.
    2. If it fires more than once, we end up incorrectly shard the grad
    multiple times. (could lead to dimension too small)
    3. If it fires once but too early or doesn't fire, we leave gradients
    unsharded. (could lead to dimension too large)

    Due to multiple-pass forward, this function can be called on
    the same parameter multiple times in a single forward pass. If we register
    the hook multiple time, we end up getting called multiple times. We
    could try to get a new hook every time and delete the previous one
    registered. However, due to *unknown reason* (I have debugged it for
    a long time!), in mixed precision mode, we get two different ``grad_acc``
    objects below during different calls of this function (in the same
    forward pass). If we keep the last one, the hook end up firing too
    early. In full precision mode, we luckily get the *same* ``grad_acc``
    object, so deleting and re-registering still ensured the hook fire
    once after all gradients are generated.

    Empirically, keep the first hook register per forward pass seems to
    work the best. We do need to remove the hook at the end of the
    backward pass. Otherwise, the next forward pass will not register
    a new hook, which is needed for a new forward pass.
    """
    if not torch.is_grad_enabled():
        return  # don't register grad hooks if grad isn't enabled
    for p in self.params:
        if p.requires_grad:
            if hasattr(p, "_shard_bwd_hook"):
                continue
            # Register a hook on the first call, empirically, autograd
            # fires it at the end for this param, which makes sense.
            p_tmp = p.expand_as(p)  # Get a grad_fn on p_tmp.
            assert p_tmp.grad_fn is not None
            grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
            handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
            p._shard_bwd_hook = (grad_acc, handle)

4.3.2 _post_backward_hook

_post_backward_hookэто функция ловушки, которая будет регистрировать_post_reduction_hookи self._reducer.reduce_scatter_async.

существует_post_backward_hookНачинать,param.gradСодержит полный градиент локальной партии. Операция уменьшения разброса поместитparam.gradЗаменено одним осколком суммы градиентов по всем графическим процессорам. Этот осколок соответствует текущему рангу, например:

    before reduce_scatter:
        param.grad (GPU #0): [1, 2, 3, 4]
        param.grad (GPU #1): [5, 6, 7, 8]

    after reduce_scatter:
        param.grad (GPU #0): [6, 8]    # 1+5, 2+6
        param.grad (GPU #1): [10, 12]  # 3+7, 4+8

локальный графический процессорoptim.stepОдин сегмент, отвечающий за обновление параметров, также соответствующий текущему рангу графического процессора. Это выравнивание задается_shard_parameters_Созданный, он гарантирует, что локальный оптимизатор видит только соответствующие фрагменты параметров.

Следующий код удаляет некоторые проверки.

@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
    """
    At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
    full gradient for the local batch. The reduce-scatter op will replace
    ``param.grad`` with a single shard of the summed gradient across all
    GPUs. This shard will align with the current GPU rank. 

    The local GPU's ``optim.step`` is responsible for updating a single
    shard of params, also corresponding to the current GPU's rank. This
    alignment is created by :func:`_shard_parameters_`, which ensures that
    the local optimizer only sees the relevant parameter shard.
    """
    # First hook callback will see PRE state. If we have multiple params,
    # then subsequent hook callbacks will see POST state. When checkpoint
    # fwd counter is used, IDLE is also possible since the pre-backward hook
    # is not triggered (see ``auto_wrap_bn`` below, we have to use
    # FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).

    self.training_state = TrainingState.BACKWARD_POST

    # If this is a checkpointed module, we check if the following
    # counter reaches 0. If not, it is not the final backward call
    # for this module yet. Therefore, we early return in that case.
    if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
        if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
            return

    if self._require_backward_grad_sync or self.reshard_after_forward:
        # Free full params. As a special case, we don't free the full params
        # when in a ``no_sync`` context (as inversely indicated by
        # ``self._require_backward_grad_sync``), since the params will not
        # get updated before the next forward. This saves networking
        # bandwidth but uses more GPU memory.
        self._free_full_params([param])

    if self.mixed_precision:
        # This is a no-op if reshard_after_forward is True, since we already
        # free the param shard when rebuilding the full params in the
        # pre_backward_hook.
        self._free_fp16_param_shard([param])

    # Switch to FP32 shard after backward.
    self._use_fp32_param_shard([param])

    # Wait for all work in the current stream to finish, then start the
    # reductions in post_backward stream.
    self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(self._streams["post_backward"]):
        orig_grad_data = param.grad.data

        if self.mixed_precision and self.fp32_reduce_scatter:
            # Cast grad to FP32.
            param.grad.data = param.grad.data.to(param.dtype)

        if self.gradient_predivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            param.grad.data.div_(self.gradient_predivide_factor)

        # 执行reduce-scatter操作    
            
        callback_fn = functools.partial(self._post_reduction_hook, param)
        if param._is_sharded:
            grad_chunks = chunk_and_pad(param.grad.data, self.world_size)
            self._reducer.reduce_scatter_async(grad_chunks, group=self.process_group, callback_fn=callback_fn)
        else:
            # Currently the only way for _is_sharded to be False is if
            # world_size == 1. This could be relaxed in the future, in which
            # case grads should be all-reduced here.
            callback_fn(param.grad.data)

        # After _post_backward_hook returns, orig_grad_data will eventually
        # go out of scope, at which point it could otherwise be freed for
        # further reuse by the main stream while the div/reduce_scatter/copy
        # are underway in the post_backward stream. See:
        # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
        orig_grad_data.record_stream(self._streams["post_backward"])

4.3.3 Reduce Scatter

ReduceScatterBucketer используется для объединения нескольких операций уменьшения разброса на небольших тензорах в более крупные операции уменьшения разброса для повышения эффективности связи. Асинхронное сокращение-рассеяние списка тензоров позволяет объединять меньшие сокращения. данный обратный вызов (callback_fn) будет вызвана позже с результатом сокращения. можно назватьflush()для принудительного выполнения всех операций в очереди и обратных вызовов. Обратите внимание, что большие входные данные будут немедленно уменьшены, эта функция также может очищать связанные корзины дляinput_listСделать комнату.

class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.
    """

    def __init__(self, bucket_cap_mb: int = 25):
        self.bucket_cap_mb = bucket_cap_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def reduce_scatter_async(
        self, input_list: List[Tensor], group: ProcessGroup, callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()
        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base"):
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.data.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()

        # copy data from input_list into bucket
        stacked_input = torch.stack(input_list).view(world_size, first_input_size)
        offset = bucket.offset
        bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
        bucket.offset += first_input_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
            bucket.callbacks.append(functools.partial(callback_fn, result_view))

4.3.4 _register_pre_backward_hooks

Здесь регистрируется хук, который будет вызываться перед обратным распространением, в хуке вызывается _rebuild_full_params, а внутри вызывается all-gather. Поскольку это должно быть зарегистрировано поверх окончательного вывода, _register_pre_backward_hooks вызывается только в конце прямого прохода.

def _register_pre_backward_hooks(self, outputs: Any) -> Any:
    """Register pre-backward hook to run before the wrapped module's
    backward. Hooks should be attached to all outputs from the forward.

    Returns:
        outputs: new outputs with hooks registered if they requires gradient.
    """
    if not torch.is_grad_enabled():
        return outputs  # don't register hooks if grad isn't enabled

    if self._is_root:
        # This actually means that only root instance has
        # _post_backward_callback_queued defined. Accidentally accessing this field
        # will assert on all other instances, giving us a nice bug checker.
        self._post_backward_callback_queued = False

    def _pre_backward_hook(*unused: Any) -> None:
        # try to queue final backward callback only once for root, so
        # that final backward callback is attached to the outer most
        # backward graph task and called after all the backward
        # calls are completed.
        if self._is_root:
            self._queue_wait_for_post_backward()

        if self._pre_backward_hook_has_run:
            return  # only run once (from multiple outputs or multiple forward passes)
        self._pre_backward_hook_has_run = True

        # Start of a backward pass.
        self.training_state = TrainingState.BACKWARD_PRE

        # All-gather full parameters.
        if self.reshard_after_forward:
            self._rebuild_full_params() # 这里调用 all-gather
        else:
            self._use_full_params()

        # Prepare p.grad.
        self._prep_grads_for_backward()

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(_pre_backward_hook)
        return t

    # Attach hooks to Tensor outputs.
    outputs = apply_to_tensors(_register_hook, outputs)

    return outputs

Логика работы примерно такая:

+---------------------------------------------+     +--------------------------------------------------+
| forward                                     |     | backward                                         |
|                                             |     |                                                  |
| +                                           |     |                                                  |
| |       all_gather()                        |     |                ^                               ^ |
| |           +                               |     |                |                               | |
| |           |                               |     |                |                               | |
| |           |                               |     |                |                               | |
| |           |                               +     +                |                               | |
| |           v                               register               |                               | |
| | _register_post_backward_hooks()  +--------+-----+--> _post_backward_hook() +--> reduce_scatter() | |
| |                                           |     |                                                | |
| |           +                               |     |                ^                               | |
| |           |                               |     |                |                               | |
| |           |                               |     |                |                               | |
| |           v                               |     |                +                               | |
| |    outputs = self.module(*args, **kwargs) |     |         compute gradient                       | |
| |           +                               |     |                                                | |
| |           |                               |     |                ^                               | |
| |           |                               |     |                |                               | |
| |           |                               +     +                |                               | |
| |           v                               register               +                               | |
| | _register_pre_backward_hooks(outputs) +---+-----+--> _pre_backward_hook() +---> all_gather()     | |
| v                                           |     |                                                + |
|                                             |     |                                                  |
| Timeline                                    |     |                                         Timeline |
|                                             |     |                                                  |
+---------------------------------------------+     +--------------------------------------------------+

Телефон такой:

img

До сих пор мы рассказывали, как FSDP нарезает параметры модели для уменьшения накладных расходов на видеопамять.В следующей статье мы увидим, как разгрузка может еще больше сэкономить видеопамять.

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

ссылка 0xFF

docs.NVIDIA.com/deep учитесь в…

developer.NVIDIA.com/automatic - нет…

blogs.NVIDIA.com/blog/2019/1…

woohoo.paddle paddle.org.capable/document ATI…

on-demand.GPU tech conf.com/Steel Mills-Taiwan/…

bin dog.GitHub.IO/blog/2020/0…

docs.NVIDIA.com/deep учитесь в…

Optimizer state sharding (ZeRO)