[Анализ исходного кода] PyTorch, распространяемый Autograd (3) ---- контекстно-зависимый

машинное обучение PyTorch

0x00 сводка

Мы уже знаем, как dist.autograd отправляет и получает сообщения.В этой статье мы рассмотрим другие вспомогательные части, то есть, как координировать два действия отправки и получения, как определить каждый отправляющий/получающий узел и как определить каждое сообщение. Сессия взаимодействия.

Из этой статьи вы можете понять: AutogradMetadata используется для передачи метаданных автоградации между различными узлами, DistAutogradContext представляет распределенную информацию, связанную с автоградацией, а DistAutogradContainer отвечает за хранение DistAutogradContext на рабочем месте.

Другие статьи о распространении PyTorch:

[Анализ исходного кода] Распространение PyTorch (1) ------ история и обзор

[Анализ исходного кода] Как PyTorch использует GPU

[Анализ исходного кода] Распределенный PyTorch (2) ----- DataParallel (включен)

[Анализ исходного кода] Распределенный PyTorch (3) ----- DataParallel (ниже)

[Анализ исходного кода] Распределенный PyTorch (4) ------ Основная концепция распределенного приложения

[Анализ исходного кода] Распределенный PyTorch (5) ------ Обзор DistributedDataParallel и способы его использования

[Анализ исходного кода] Распределенный PyTorch (6) ---DistributedDataParallel -- инициализация и хранение

[Анализ исходного кода] Распределенный PyTorch (7) ----- Группа процессов DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (8) -------- Бумага DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (9) ----- Инициализация DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (10) ------ Редуктор статической архитектуры DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (11) ----- DistributedDataParallel для создания операций Reducer и Join

[Анализ исходного кода] Распределенный PyTorch (12) ----- Прямое распространение до DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (13) ----- Обратное распространение DistributedDataParallel

[Анализ исходного кода] PyTorch, распространяемый Autograd (1) ---- дизайн

[Анализ исходного кода] PyTorch, распространяемый Autograd (2) ---- Фонд RPC

Для лучшего объяснения код в этой статье будет соответственно упрощен в соответствии с конкретной ситуацией.

0x01 Контекст дизайна

1.1 Предыдущий обзор

В предыдущей статье при отправке сообщения мы получили сообщение типа FORWARD_AUTOGRAD_REQ через getMessageWithAutograd в sendMessageWithAutograd.

c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
    
  auto msg = getMessageWithAutograd( // 这里会与上下文交互,构建了 FORWARD_AUTOGRAD_REQ
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  c10::intrusive_ptr<JitFuture> fut;
  if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
    auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
    auto msgWithProfiling = getMessageWithProfiling(
        std::move(msg),
        rpc::MessageType::RUN_WITH_PROFILING_REQ, //构建消息
        std::move(profilerConfig));
    // 发送消息
    fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
  } else {
    // 发送消息
    fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
  }

  return fut;
}

А getMessageWithAutograd взаимодействует с контекстом, и его код находится в torch/csrc/distributed/autograd/utils.cpp.

Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  
  // 获取到 DistAutogradContainer
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send original
  // rpc message. otherwise, attach grad info and grad functions and send
  // rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext(); // 获取到上下文,每个worker都有自己的上下文

  // Wrap the original rpc with autograd information.
  // newAutogradMessageId 会生成一个messageID
  AutogradMetadata autogradMetadata( // 构建了 AutogradMetadata
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward( // 这里把本地上下文,autograd 的元信息等一起打包
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage(); // 最终构建了一个message
}

Поэтому вводится ряд базовых классов, таких как AutogradMetadata, DistAutogradContainer и DistAutogradContext, и далее мы тщательно проанализируем их.

1.2 Общая идея

Обрисуем общую идею.

Давайте сначала рассмотрим проблему: если система состоит из трех узлов a, b и c, и на каждом узле работает рабочий процесс, то при выполнении операции распространения мы задействуем распространение среди этих трех узлов. Поэтому нам нужен механизм для уникальной маркировки процесса распространения среди этих трех узлов.В этом процессе распространения каждая отправка/получение должна быть отмечена на каждом узле, чтобы узел мог поддерживать несколько операций параллельно.

Взгляните еще раз на решение:

  • Используйте контекст, чтобы однозначно идентифицировать процесс распространения. DistAutogradContext хранится в рабочемКаждыйИнформация о распределенном автоградации, которая инкапсулирует прямое и обратное распространение в распределенном автоградусе, накапливая градиенты, что позволяет избежать влияния нескольких рабочих процессов на градиенты друг друга. Каждому процессу автоматической дифференциации присваивается уникальный идентификатор autograd_context_id. В контейнере контекст (DistAutogradContext) этого процесса дифференциации однозначно идентифицируется в соответствии с этим autograd_context_id.
  • Используйте autogradMessageId для представления пары функций autograd отправки/получения. Каждыйsend-recvприсваивается глобально уникальныйautograd_message_idчтобы однозначно идентифицироватьsend-recvправильно. Это полезно для поиска соответствующих функций на удаленных узлах во время обратного распространения.
  • Наконец, каждому рабочему процессу нужно место для хранения контекста и идентификатора сообщения, поэтому существует класс DistAutogradContainer. У каждого воркера есть уникальный синглтон DistAutogradContainer, который отвечает за:
    • Для каждого процесса автоматической дифференциации сохраняется его распределенный контекст.
    • Как только этот процесс автоматической дифференциации завершен, его данные очищаются.

Таким образом, во время прямого прохода Pytorch сохраняет в контексте значение каждого прохода autograd.sendиrecvфункция. Это гарантирует, что мы сохраним ссылку на соответствующий узел в графе autograd, чтобы поддерживать его в рабочем состоянии. Среди прочего, это также упрощает поиск соответствующихsendиrecvфункция.

0x02 AutogradMetadata

2.1 Определения

AutogradMetadata Этот класс используется для передачи метаинформации автограда между разными узлами, то есть для инкапсуляции контекста и другой информации. То есть отправитель уведомляет получателя о своей собственной контекстной информации, и получатель выполняет соответствующую обработку в соответствии с полученной контекстной информацией.

Спойлерим заранее, получатель будет использовать autogradContextId и autogradMessageId как уникальный идентификатор контекста и сообщения соответственно. Это видно из комментариев.

  • autogradContextId — глобально уникальное целое число, используемое для представления уникального распределенного процесса распространения autograd (включая прямое и обратное распространение). Процесс распространения будет включать в себя несколько пар функций автоградации отправки/получения в цепочке обратного распространения.
  • autogradMessageId — глобально уникальное целое число, используемое для представления пары функций автоградации отправки/получения. Каждыйsend-recvприсваивается глобально уникальныйautograd_message_idчтобы однозначно идентифицироватьsend-recvправильно. Это полезно для поиска соответствующих функций на удаленных узлах во время обратного распространения.
// This structure represents autograd metadata that we need to pass across
// different nodes when we call an RPC which needs autograd computation.
struct TORCH_API AutogradMetadata {
  AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);

  // autogradContextId_ is a globally unique integer that identifies a
  // particular distributed autograd pass.
  int64_t autogradContextId;
  // autogradMessageId_ is a globally unique integer that identifies a pair
  // of send/recv autograd functions.
  int64_t autogradMessageId;
};

Итак, вопрос в том, как autogradContextId и autogradMessageId могут быть глобально уникальными (включая несколько узлов)?

2.2 autogradMessageId

Давайте сначала подведем итоги: autogradMessageId косвенно генерируется рангом, а затем увеличивается внутри, поэтому можно гарантировать, что он будет глобально уникальным.

Переходим от задней части к передней.

  • Давайте сначала посмотрим, как newAutogradMessageId генерирует идентификатор сообщения.Оказалось, что переменная-член next_autograd_message_id_ в DistAutogradContainer увеличивается.
int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}
  • Затем посмотрите, как инициализироватьnext_autograd_message_id_? Из функции инициализации DistAutogradContainer можно узнать, что next_autograd_message_id_ генерируется на основе worker_id. work_id — это параметр, полученный функцией инициализации.
DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  auto& container = getInstanceInternal();
  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}
  • Давайте выведем и посмотрим, как установить рабочий идентификатор.Мы обнаружили следующее.Похоже, нам нужно посмотреть на метод _init в мире python.
module.def(
    "_init",
    [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
    py::call_guard<py::gil_scoped_release>());

Переходя к миру python, вы можете видеть, что ранг используется в качестве параметра, а ранг уникален для каждого воркера, что гарантирует, что идентификатор воркера уникален, и, следовательно, идентификатор сообщения уникален.

    def init_rpc(
        name,
        backend=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
			dist_autograd._init(rank) # rank是全局唯一

Резюмируем эти логические отношения:

worker_id = rank;

container.worker_id_ = worker_id;

container.next_autograd_message_id_ = static_cast<int64_t>(worker_id) << kAutoIncrementBits

Затем внутренне увеличивается значение next_autograd_message_id_.

int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}

Таким образом, AutogradMessageId глобально уникален. Давайте вспомним легенду:

+----------------------------------------------------------------------------------------+
| worker                                                                                 |
|                       +-------------------------------------+                          |
|                       | DistAutogradContainer               |                          |
|                       |                                     |                          |
|                       |                                     |                          |
|              init()   |                                     |                          |
|      rank +--------------+----> worker_id_                  |                          |
|                1      |  |                                  |   newAutogradMessageId() |
|                       |  +----> next_autograd_message_id_+------------------+          |
|                       |                                     |          2    |          |
|                       +-------------------------------------+               |          |
|                                                                             |          |
|                                                                             |          |
|                                                                             |          |
|                                                                             |          |
|                     +---------------------------------------------------------------+  |
|                     | getMessageWithAutograd                                |       |  |
|                     |                                                       |       |  |
|                     |                                                       v       |  |
|                     |                                                               |  |
|                     |   AutogradMetadata autogradMetadata(contextId(), MessageId()) |  |
|                     |                           4                           3       |  |
|                     |                                                               |  |
|                     +---------------------------------------------------------------+  |
|                                                                                        |
+----------------------------------------------------------------------------------------+

Чтобы понять, почему autogradContextId гарантированно уникален, нам нужно сначала проанализировать DistAutogradContainer и DistAutogradContext.

0x03 DistAutogradContainer

У каждого воркера есть уникальный синглтон DistAutogradContainer., который отвечает за:

  • Для каждого процесса автоматической дифференциации сохраняется его распределенный контекст.
  • Как только этот процесс автоматической дифференциации завершен, его данные очищаются.

Каждому процессу автоматической дифференциации назначается уникальный autograd_context_id. В каждом контейнере контекст процесса дифференциации (DistAutogradContext) однозначно идентифицируется этим autograd_context_id. autograd_context_id — это 64-битный глобальный уникальный идентификатор, первые 16 бит — это worker_id, а последние 48 бит — это автоматически увеличивающийся идентификатор внутри каждого работника. Следовательно, видно, что в контейнере есть несколько контекстов.

Этот контейнер также отвечает за поддержание глобальных уникальных идентификаторов сообщений, которые используются для связывания пар функций автоматической дифференциации отправки и получения. Формат аналогичен autograd_context_id, который представляет собой 64-битное целое число, первые 16 бит — это идентификатор рабочего процесса, а последние 48 бит автоматически увеличиваются внутри рабочего процесса.

Поскольку первые 16 битов идентификатора сообщения и идентификатора контекста являются worker_id, который является идентификатором ранга, а последние 48 битов увеличиваются внутри, можно гарантировать, что идентификатор сообщения и идентификатор контекста будут глобально уникальными.

3.1 Определения

DistAutogradContainer определяется следующим образом, где:

  • worker_id_ : ID этого работника на самом деле является рангом этого работника.
  • next_context_id_ : Идентификатор контекста с автоинкрементом, используемый для назначения уникального autograd_context_id каждому процессу автоматической дифференциации. Фактически, в цепочке распространения только DistAutogradContainer первого узла использует next_context_id_ для создания контекста, а DistAutogradContainer последующих узлов основан на информации об идентификаторе контекста первого DistAutogradContainer для локального создания контекста, соответствующего идентификатору контекста. .
  • next_autograd_message_id_ : поддерживать глобально уникальный идентификатор сообщения, используемый для связывания пар функций автоматической дифференциации отправки/получения. Эта переменная используется, когда узел отправляет.
// Singleton class per worker which is responsible for storing the distributed
// autograd context for each autograd pass and also cleans up data for an
// autograd pass once its done.
//
// Each autograd pass is assigned a unique autograd_context_id and all data for
// that pass (DistAutogradContext) is stored in this container indexed by the
// autograd_context_id. The autograd_context_id itself is a 64 bit globally
// unique id. The first 16 bits is the worker_id and the next 48 bits is an
// auto-incrementing id for each worker.
//
// This container is also responsible for maintaining a globally unique message
// id, which is used to associate send/recv autograd function pairs. The format
// is similar to the autograd_context_id where we have a 64 bit integer with
// first 16 bits being the worker id and next 48 bits are auto-incrementing.
class TORCH_API DistAutogradContainer {

 private:
  // Number of shards for the map storing autograd contexts. We'd like this
  // to be a power of 2 and we don't expect a value much higher than the
  // number of cores would provide much benefit.
  static constexpr uint32_t kNumDefaultShards = 128;

  // Use cache line size for alignment.
  static constexpr int kCacheLineSize = 64;

  // Structure holding one shard of the sharded autograd context map with its
  // associated lock. Align to cache line size to avoid contention between
  // adjacent entries.
  struct alignas(kCacheLineSize) ContextsShard {
    // Lock for this shard.
    mutable std::mutex lock;

    // Map storing autograd contexts for this shard.
    std::unordered_map<int64_t, ContextPtr> contexts; // 这里存储了上下文指针
  };

  // Auto incrementing context id used to identify unique autograd passes.
  // Initialized with the first 16 bits being the worker_id.
  std::atomic<int64_t> next_context_id_; // 新增上下文id

  // Unique id to identify a worker in the distributed setting.
  int16_t worker_id_;

  // Whether or not the container has been initialized appropriately.
  bool initialized_;

  // Sharded autograd context map.
  std::vector<ContextsShard> autograd_contexts_; // 存储上下文列表

  // Number of shards for the sharded autograd_contexts_ map.
  uint32_t num_shards_;

  // Autograd message id to identify unique send/recv autograd function pairs.
  std::atomic<int64_t> next_autograd_message_id_;

  // Maximum allowed value for autograd_context_id or autograd_message_id.
  int64_t max_id_;
};

3.2 Сборка

Метод Init создает DistAutogradContainer, который в основном использует worker_id для выполнения соответствующих назначений локальным переменным-членам.

DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]")

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_ || (worker_id == container.worker_id_),
      "Container is already initialized with worker_id: ",
      container.worker_id_,
      ", cannot initialize with different worker_id: ",
      worker_id);

  if (container.initialized_) {
    return container;
  }

  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

0x04 DistAutogradContext

DistAutogradContext хранится в рабочемКаждыйИнформация о распределенном автоградации, которая инкапсулирует прямое и обратное распространение в распределенном автоградусе, накапливая градиенты, что позволяет избежать влияния нескольких рабочих процессов на градиенты друг друга.

Как видно спереди, contextId_ глобально уникален.

4.1 Определения

Здесь указаны только переменные-члены DistAutogradContext, а его функции-члены игнорируются. Среди них есть три основных переменных-члена:

  • contextId_ — идентификатор контекста.
  • sendAutogradFunctions_ — это переменная типа карты, которая собирает оператор обратного распространения SendRpcBackward, соответствующий всем запросам на отправку.
  • recvAutogradFunctions_ — это переменная типа карты, которая собирает все операторы обратного распространения RecvRpcBackward, соответствующие входящим и исходящим запросам.

Что касается SendRpcBackward и RecvRpcBackward, мы проанализируем их в сочетании с движком позже.

// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {
 private:
  friend class BackwardPassCleanupGuard;
  friend class DistEngine;
  friend class RecvRpcBackward;
  friend class DistAccumulateGradCaptureHook;

  const int64_t contextId_;

  // Set containing known worker IDs, used in cleaning up autograd context.
  // Whenever a sendRpcBackward is attached to the autograd graph for this
  // context, the destination is added here.
  std::unordered_set<rpc::worker_id_t> knownWorkerIds_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;

  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable..
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

  // See comments for recordGradEvent(c10::Device device);
  std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  const c10::impl::VirtualGuardImpl impl_;

  // The autograd GraphTask for the backward pass on this node for this context.
  std::shared_ptr<torch::autograd::GraphTask> graphTask_;

  // List of futures for RPCs initiated by this node to propagate gradients to
  // other nodes. The distributed autograd engine on this node can return
  // successfully only if all these futures are done and are successful.
  std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;

  // Lock to protect concurrent modification of the context.
  mutable std::mutex lock_;
};

4.2 Сообщения

Контекст в основном включает в себя несколько типов сообщений, таких как:

// Messages with autograd info
FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,

// Messages to propagate gradients on the backward pass.
BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,

4.3 Сборка

Давайте сначала посмотрим, как конструируется контекст.

4.3.1 getOrCreateContext

Функция getOrCreateContext используется для получения контекста, если он уже существует, получить его напрямую, если нет, создать новый. Это пассивный вызов, его будет использовать принимающая сторона.

ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id); // 根据这个context id来查找
  if (it != shard.contexts.end()) {
    return it->second; // 找到就返回
  }

  auto& context = // 如果没有,就构建一个 context
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

4.3.2 newContext

Это активный вызов, и отправляющая сторона вызовет этот метод.

4.3.2.1 Python

При вызове распределенного мир Python генерирует контекст.

            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer. Gradients propagated all the way to the parameter servers
                opt.step(context_id)

При генерации__enter___new_context() будет вызываться для создания контекста в C++.

class context(object):
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement  is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>   t1 = torch.rand((3, 3), requires_grad=True)
        >>>   t2 = torch.rand((3, 3), requires_grad=True)
        >>>   loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>   dist_autograd.backward(context_id, [loss])
    '''
    def __enter__(self):
        self.autograd_context = _new_context() # 这里生成一个上下文
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        _release_context(self.autograd_context._context_id())

В частности, с помощью следующего сопоставления мы видим, что соответствующий метод в мире C++ вызывает DistAutogradContainer::getInstance().newContext().

  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference);
4.3.2.2 C++

Мы пришли в мир C++. Каждый поток имеет autograd_context_id.

constexpr int64_t kInvalidContextId = -1;

// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;

newContext предназначен для создания DistAutogradContext, в котором идентификатор следующего контекста указывается путем увеличения переменной-члена next_context_id_ контейнера.

const ContextPtr DistAutogradContainer::newContext() {

  auto context_id = next_context_id_++; // 递增
  current_context_id_ = context_id;  // 在这里设置了本地线程的 current_context_id_

  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(context_id < max_id_);

  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;

  return context;
}

4.4 Как поделиться контекстом

В конкретном использовании, вwithгенерируется в заявленииcontext_idМожет использоваться для уникальной идентификации распределенного обратного распространения (как прямого, так и обратного) для всех воркеров. Каждый рабочий хранится с этимcontext_idСвязанные метаданные, необходимые для правильного выполнения распределенного процесса автозагрузки.

Потому что это нужно хранить в нескольких рабочихcontext_idСвязанные метаданные, поэтому для передачи этих метаданных между рабочими процессами необходим механизм инкапсуляции/отправки/получения Механизм инкапсуляции — это AutogradMetadata, о котором мы упоминали ранее. Далее давайте посмотрим, как отправлять/получать контекстную метаинформацию.

4.4.1 Отправитель

При отправке сообщения getMessageWithAutograd будет использовать autogradContainer.currentContext(), чтобы получить текущий контекст и отправить его.

Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send original
  // rpc message. otherwise, attach grad info and grad functions and send
  // rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext(); // 获取当前上下文

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata( // 使用上下文id和消息id来构建元数据
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}

Теперь нашу предыдущую диаграмму можно расширить, включив идентификаторы контекста.

+----------------------------------------------------------------------------------------+
| worker                                                                                 |
|                  +------------------------------------------+                          |
|                  |DistAutogradContainer                     |                          |
|          init()  |                                          |                          |
|  rank +-------------+----> worker_id_                       |                          |
|                  |  |                                       |                          |
|                  |  +----> next_context_id_+-------------+  |                          |
|                  |  |                                    |  |                          |
|                  |  +----> next_autograd_message_id_ +----------------------+          |
|                  |                                       |  |               |          |
|                  |                                       |  |               |          |
|                  +------------------------------------------+               |          |
|                                                          |                  |          |
|                                                          |                  |          |
|                                                          |                  |          |
|                  +------------------------------------------------------------------+  |
|                  |getMessageWithAutograd                 |                  |       |  |
|                  |                                       |                  |       |  |
|                  |                                       v                  v       |  |
|                  |                                                                  |  |
|                  |    AutogradMetadata autogradMetadata(contextId(), MessageId())   |  |
|                  |                                                                  |  |
|                  |                                                                  |  |
|                  +------------------------------------------------------------------+  |
|                                                                                        |
+----------------------------------------------------------------------------------------+

addSendRpcBackward передается в текущий контекст, а addSendRpcBackward будет удален во время последующего обратного распространения.

void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

4.4.2 Получатель

В addRecvRpcBackward контекст создается на основе переданного autogradMetadata.autogradContextId.

ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  // 生成或者得到一个上下文,把发送方的 autogradContextId 传入,即利用 autogradContextId 作为key后续可以查找到这个上下文
  auto autogradContext = 
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

Таким образом, отправитель и получатель имеют общий контекст, и идентификатор этого контекста является глобально уникальным.

Конкретная логика такова, верх — отправитель, низ — получатель.

  • отправитель
    • AutogradMetadata построен с локальным context_id, AutogradMetadata содержит ctx_id, msg_id.
    • Сообщение создается с использованием AutogradMetadata.
    • Сообщение отправлено через agent.send.
  • Приемный конец:
    • Сообщение доставлено.
    • Анализ AutogradMetadata из сообщения.
    • Извлеките context_id из AutogradMetadata.
    • Локальный DistAutogradContext создается с помощью context_id.
  • Отправитель и получатель имеют общий контекст (идентификатор этого контекста глобально уникален).
+----------------------------------------------------------------------------------+
| sendMessageWithAutograd                                                          |
|                                                                                  |
|  +----------------------------------------------------------------------------+  |
|  | addSendRpcBackward                                                         |  |
|  |                                                                            |  |
|  |                                                                            |  |
|  |               autogradMetadata = AutogradMetadata(context_id, message_id)  |  |
|  |                          +                                                 |  |
|  |                          |                                                 |  |
|  +----------------------------------------------------------------------------+  |
|                             |                                                    |
|                             v                                                    |
|        agent.send(message(autogradMetadata)                                      |
|                             +                                                    |
|                             |                                                    |
+----------------------------------------------------------------------------------+
                              |
                              |
                              |
                              |                                             Sender
+-----------------------------------------------------------------------------------+
                              |                                             Receiver
                              | message
                              v
                              |
+----------------------------------------------------------------------------------+
| processForwardAutogradReq   |                                                    |
|                             |                                                    |
|                             | message.autogradMetadata                           |
|                             v                                                    |
|  +----------------------------------------------------------------------------+  |
|  | addSendRpcBackward       |                                                 |  |
|  |                          |                                                 |  |
|  |                          +--------------------+                            |  |
|  |                                               |                            |  |
|  |                                               v                            |  |
|  |   autogradContext = getOrCreateContext(autogradMetadata.autogradContextId) |  |
|  |                                                                            |  |
|  |                                                                            |  |
|  +----------------------------------------------------------------------------+  |
|                                                                                  |
+----------------------------------------------------------------------------------+

0x05 Интерактивный процесс прямого распространения

Предыдущий процесс обмена по-прежнему краток, и мы подробно проанализируем весь процесс отправки/получения.

5.1 Отправить

Это соответствует следующему тексту в дизайне:

Во время прямого прохода мы сохраняем в контексте значение каждого прохода автоградаsendиrecvфункция. Это гарантирует, что мы сохраним ссылку на соответствующий узел в графе autograd, чтобы поддерживать его в рабочем состоянии. Кроме того, это также упрощает поиск соответствующихsendиrecvфункция.

5.1.1 Логика отправки

Логика кода следующая:

  • Создайте grad_fn типа SendRpcBackward.
  • Вызовите collect_next_edges и set_next_edges, чтобы добавить последующие ребра в SendRpcBackward, эти функции мы проанализировали в предыдущей серии.
  • Вызовите add_input_metadata, чтобы добавить входные метаданные.
  • Вызовите addSendFunction, чтобы добавить grad_fn в контекст.
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges( // 这里会设置其输出边
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

5.1.2 Настройка контекста

Давайте еще раз вспомним определение DistAutogradContext, здесь даны только некоторые его переменные-члены.

  • contextId_ — идентификатор контекста.
  • sendAutogradFunctions_ — это переменная типа карты, которая собирает оператор обратного распространения SendRpcBackward, соответствующий всем запросам на отправку.
  • recvAutogradFunctions_ — это переменная типа карты, которая собирает все операторы обратного распространения RecvRpcBackward, соответствующие входящим и исходящим запросам.
// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {

  const int64_t contextId_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;
};

addSendFunction заключается в том, чтобы добавить SendRpcBackward в sendAutogradFunctions_, а затем вы можете получить этот SendRpcBackward в соответствии с идентификатором сообщения.

void DistAutogradContext::addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id) {

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      sendAutogradFunctions_.find(autograd_message_id) ==
      sendAutogradFunctions_.end());
  sendAutogradFunctions_.emplace(autograd_message_id, func);
}

Предыдущее — с точки зрения построения контекста, на этот раз — с точки зрения содержания контекста.

На данный момент логика отправителя такова:

+--------------------------------------------------------------+    +-------------------+
| worker                                                       |    |SendRpcBackward    |
| +---------------------------------------------------------+  |    |                   |
| | DistAutogradContext                                     |  |    |   input_metadata_ |
| |                                                 +-------------> |                   |
| |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
| |                                                 +       |  |    |                   |
| |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
| |                                                         |  |
| |                                                         |  |
| |  recvAutogradFunctions_                                 |  |
| |                                                         |  |
| +---------------------------------------------------------+  |
|                                                              |
+--------------------------------------------------------------+

                                                                                  sender
+---------------------------------------------------------------------------------------+

5.2 Принятие

Давайте пропустим внутреннюю обработку отправки агента и вместо этого рассмотрим бизнес-процесс FORWARD_AUTOGRAD_REQ.

5.2.1 Получение сообщения ---> Получатель

При создании TensorPipeAgent настройте RequestCallbackImpl как функцию обратного вызова. Это унифицированная функция ответа агента.

Когда мы упоминали логику получения агентом ранее, мы введем следующую функцию, в которой мы можем увидеть логику обработки для processForwardAutogradReq.

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

    case MessageType::FORWARD_AUTOGRAD_REQ: {
      // 会来到这里
      processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
      return;
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      processBackwardAutogradReq(rpc, messageId, responseFuture);
      return;
    };  
  
}  

5.2.2 Обработка сообщений

processForwardAutogradReq отвечает за конкретную обработку сообщений, и логика его обработки следующая:

  • Несмотря на то, что запрос на прямое распространение получен, поскольку это принимающая сторона, необходимо выполнить последующее обратное распространение, поэтому deviceMap транспонируется.
  • Используйте addRecvRpcBackward для контекстуализации сообщения rpc.
  • Возможны вложенные команды, поэтому вам нужно снова вызвать processRpc.
  • Установите наиболее оригинальное сообщение для обработки и выполните соответствующие операции.
void RequestCallbackNoPython::processForwardAutogradReq(
    RpcCommandBase& rpc,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // Need to reverse the device map for the backward pass of distributed
  // autograd.
  std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
  // 对deviceMap进行转置
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward( // 调用了 addRecvRpcBackward 加入上下文
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on server side, before processRpc(),
  // set current_context_id_ to be context_id passed from client.
  // In this way, if there is nested rpc call in python rpc call, original
  // context_id from client can be passed in the chain calls.
  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Make an overall future for the wrapped response.
  auto wrappedRpcResponseFuture =
      c10::make_intrusive<JitFuture>(at::AnyClassType::get());
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  processRpc( // 可能会有nested命令的可能,所以需要再处理一次
      rpcWithAutograd.wrappedRpc(),
      wrappedMessageType,
      messageId,
      wrappedRpcResponseFuture,
      std::move(ctx));

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  wrappedRpcResponseFuture->addCallback(
      [responseFuture,
       messageId,
       fromWorkerId,
       ctxId =
           autogradContext->contextId()](JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states in the previous thread is
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          responseFuture->setError(wrappedRpcResponseFuture.exception_ptr());
        } else {
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              std::move(
                  *wrappedRpcResponseFuture.value().toCustomClass<Message>()),
              MessageType::FORWARD_AUTOGRAD_RESP);
          msg.setId(messageId);
          responseFuture->markCompleted(
              IValue(c10::make_intrusive<Message>(std::move(msg))));
        }
      });
}

5.2.3 Контекстное взаимодействие

В torch/csrc/distributed/autograd/utils.cpp функция addRecvRpcBackward обрабатывает контекст.

Вот соответствующий дизайн:

Во время прямого прохода мы сохраняем в контексте значение каждого прохода автоградаsendиrecvфункция. Это гарантирует, что мы сохраним ссылку на соответствующий узел в графе autograd, чтобы поддерживать его в рабочем состоянии. Среди прочего, это также упрощает поиск соответствующихsendиrecvфункция.

Его конкретная логика такова:

  • Получите локальный контекст в соответствии с autogradContextId в информации rpc.
  • Создайте RecvRpcBackward.
  • Настройте RecvRpcBackward с тензорами в информации rpc, включая torch::autograd::set_history(tensor, grad_fn).
  • Вызовите addRecvFunction, чтобы добавить RecvRpcBackward в контекст.
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

Операция сложения addRecvFunction заключается в следующем, то есть посмотреть, существует ли уже оператор, соответствующий id сообщения, в recvAutogradFunctions_, если нет, добавить его.

void DistAutogradContext::addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      recvAutogradFunctions_.find(autograd_message_id) ==
      recvAutogradFunctions_.end());
  recvAutogradFunctions_.emplace(autograd_message_id, func);
}

Пока что логика расширена следующим образом: DistAutogradContext есть и у отправителя, и у получателя, и его id — context_id_1.

В каждом DistAutogradContext в качестве ключа используется msg_id_1, один — SendRpcBackward, а другой — RecvRpcBackward.

Это соответствует тому, что было упомянуто в дизайне:

Каждому процессу автоматической дифференциации присваивается уникальный идентификатор autograd_context_id. В контейнере контекст (DistAutogradContext) этого процесса дифференциации однозначно идентифицируется в соответствии с этим autograd_context_id. autograd_context_id — это 64-битный глобальный уникальный идентификатор, первые 16 бит — это worker_id, а последние 48 бит — это автоматически увеличивающийся идентификатор внутри каждого работника. Следовательно, видно, что в контейнере есть несколько контекстов.

Этот контейнер также отвечает за поддержание глобальных уникальных идентификаторов сообщений, которые используются для связывания пар функций автоматической дифференциации отправки и получения. Формат аналогичен autograd_context_id, который представляет собой 64-битное целое число, первые 16 бит — это идентификатор рабочего процесса, а последние 48 бит автоматически увеличиваются внутри рабочего процесса.

+----------------------------------------------------------------+
| worker                                                         |    +-------------------+
|                                                                |    |SendRpcBackward    |
|   +---------------------------------------------------------+  |    |                   |
|   | DistAutogradContext                                     |  |    |   input_metadata_ |
|   |                                                 +-------------> |                   |
|   |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
|   |                                                 +       |  |    |                   |
|   |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
|   |                                                         |  |
|   |  recvAutogradFunctions_                                 |  |
|   |                                                         |  |
|   +---------------------------------------------------------+  |
|                                                                |
|                             +                                  |
|                             |                                  |
+----------------------------------------------------------------+
                              |
                              |
                              |                                                     Sender
+-----------------------------------------------------------------------------------------+
                              |                                                     Receiver
                              |
                              v
+-----------------------------+----------------------------------+
| worker                                                         |
|                                                                |    +-------------------+
|   +---------------------------------------------------------+  |    |RecvRpcBackward    |
|   | DistAutogradContext                                     |  |    |                   |
|   |                                                         |  |    |                   |
|   |   contextId_ = context_id_1                 +-----------------> |   input_metadata_ |
|   |                                             |           |  |    |                   |
|   |   sendAutogradFunctions_                    |           |  |    |   next_edges_     |
|   |                                             +           |  |    |                   |
|   |   recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1]|  |    +-------------------+
|   |                                                         |  |
|   +---------------------------------------------------------+  |
|                                                                |
+----------------------------------------------------------------+

Давайте добавим Container и расширим текущую логику следующим образом:

  • Каждый рабочий включает DistAutogradContainer.
  • Каждый DistAutogradContainer включает в себя несколько DistAutogradContext, а DistAutogradContext извлекается в соответствии с идентификатором контекста.
  • Каждый DistAutogradContext включает sendAutogradFunctions_ и recvAutogradFunctions_ и использует идентификатор сообщения для получения SendRpcBackward или RecvRpcBackward.

Таким образом строится цепочка обратного распространения.

+------------------------------------------------------------------------------------------------------------------------------------+
| worker                                                                                                                             |
|                                                                                                                                    |
| +---------------------------------------+     +---------------------------------------------------------+    +-------------------+ |
| | DistAutogradContainer                 |     | DistAutogradContext                                     |    |SendRpcBackward    | |
| |                                       |     |                                                 +----------> |                   | |
| |   worker_id_                          |     |  contextId_ = ctx_id_1                          |       |    |   input_metadata_ | |
| |                                       |     |                                                 +       |    |                   | |
| |   next_autograd_message_id_     +---------> |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |    |   next_edges_     | |
| |                                 |     |     |                                                         |    |                   | |
| |   next_context_id_              |     |     |  recvAutogradFunctions_                                 |    +-------------------+ |
| |                                 +     |     |                                                         |                          |
| |   autograd_contexts_[ctx_id_1 : ctx]  |     +---------------------------------------------------------+                          |
| |                                       |                                                                                          |
| +----------------------------+----------+                                                                                          |
|                              |                                                                                                     |
+------------------------------------------------------------------------------------------------------------------------------------+
                               |
                               |
+-------------------------------------------------------------------------------------------------------------------------------------+
                               |
                               v
+------------------------------+-----------------------------------------------------------------------------------------------------+
| worker                                                                                                                             |
|                                                                                                                                    |
| +---------------------------------------+     +---------------------------------------------------------+    +-------------------+ |
| | DistAutogradContainer                 |     | DistAutogradContext                                     |    |RecvRpcBackward    | |
| |                                       |     |                                                 +----------> |                   | |
| |   worker_id_                          |     |  contextId_ = ctx_id_1                          |       |    |   input_metadata_ | |
| |                                       |     |                                                 |       |    |                   | |
| |   next_autograd_message_id_     +---------> |  sendAutogradFunctions_                         |       |    |   next_edges_     | |
| |                                 |     |     |                                                 +       |    |                   | |
| |   next_context_id_              |     |     |  recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1] |    +-------------------+ |
| |                                 +     |     |                                                         |                          |
| |   autograd_contexts_[ctx_id_1 : ctx]  |     +---------------------------------------------------------+                          |
| |                                       |                                                                                          |
| +---------------------------------------+                                                                                          |
|                                                                                                                                    |
+------------------------------------------------------------------------------------------------------------------------------------+

Телефон такой:

До сих пор мы предварительно анализировали контекстно-зависимые классы, далее мы объединим уже проанализированный контент, и система будет смотреть на бизнес-логику.

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

ссылка 0xFF