[Анализ исходного кода] PyTorch, распространяемый Autograd (2) ---- Фонд RPC

глубокое обучение PyTorch

0x00 сводка

В предыдущей статье мы дали идею дизайна распределенного автограда, в начале этой статьи разберем конкретный исходный код. Потому что независимо от того, идет ли речь о прямом или обратном распространении, для завершения необходимо полагаться на RPC, поэтому давайте сначала рассмотрим некоторые базовые функции, инкапсулированные в RPC, такие как инициализация, прокси (функции, связанные с RPC, выполняются на основе прокси), прием сообщений. , отправить и т.д.

В этой статье вы узнаете, как инициализировать серверную часть RPC, как сгенерировать прокси-сервер RPC, как использовать прокси-сервер RPC для отправки и получения и как подключиться к удаленному механизму автоматической дифференциации dist.autograd.

Другие статьи о распространении PyTorch:

[Анализ исходного кода] Распространение PyTorch (1) ------ история и обзор

[Анализ исходного кода] Как PyTorch использует GPU

[Анализ исходного кода] Распределенный PyTorch (2) ----- DataParallel (включен)

[Анализ исходного кода] Распределенный PyTorch (3) ----- DataParallel (ниже)

[Анализ исходного кода] Распределенный PyTorch (4) ------ Основная концепция распределенного приложения

[Анализ исходного кода] Распределенный PyTorch (5) ------ Обзор DistributedDataParallel и способы его использования

[Анализ исходного кода] Распределенный PyTorch (6) ---DistributedDataParallel -- инициализация и хранение

[Анализ исходного кода] Распределенный PyTorch (7) ----- Группа процессов DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (8) -------- Бумага DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (9) ----- Инициализация DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (10) ------ Редуктор статической архитектуры DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (11) ----- DistributedDataParallel для создания операций Reducer и Join

[Анализ исходного кода] Распределенный PyTorch (12) ----- Прямое распространение до DistributedDataParallel

[Анализ исходного кода] Распределенный PyTorch (13) ----- Обратное распространение DistributedDataParallel

[Анализ исходного кода] PyTorch, распространяемый Autograd (1) ---- дизайн

Для лучшего объяснения код в этой статье будет соответственно упрощен в соответствии с конкретной ситуацией.

0x01 Пример

Мы взяли код из раздела примеров PyTorch и немного изменили его, чтобы позволить двум рабочим процессам взаимодействовать через RPC. Рабочий пример разделен на две части:

  • Операции RPC, построение базы зависимостей.
  • Выполните обратное распространение.
def my_add(t1, t2):
  return torch.add(t1, t2)

def worker0():
    # On worker 0:

    # Setup the autograd context. Computations that take
    # part in the distributed backward pass must be within
    # the distributed autograd context manager.
    with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)

      # 第一阶段:RPC操作,构建依赖基础
      
      # Perform some computation remotely.
      t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

      # Perform some computation locally based on remote result.
      t4 = torch.rand((3, 3), requires_grad=True)
      t5 = torch.mul(t3, t4)

      # Compute some loss.
      loss = t5.sum()

      # 第二阶段,执行后向传播
      
      # Run the backward pass.
      dist_autograd.backward(context_id, [loss])

      # Retrieve the gradients from the context.
      dist_autograd.get_gradients(context_id)

      print(loss)  

Два воркера можно запустить следующим образом, используя rpc.init_rpc для инициализации rpc. Worker0 запустится, а затем использует RPC для выполнения некоторых операций над worker 1.

def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 0 and 1 are trainers.
    if rank == 0:
        rpc.init_rpc(
            "worker0",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        worker0()

    elif rank == 1:
        rpc.init_rpc(
            "worker1",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

    # block until all rpcs finish
    rpc.shutdown()

0x02 Основы RPC

2.1 Инициализация

Давайте посмотрим на пример кода с нуля.Когда скрипт запускается, он вызывает rpc.init_rpc для инициализации rpc. Из комментариев RPC можно увидеть две концепции: общий ранг и размер мира.

rank (int): a globally unique id/rank of this node.
world_size (int): The number of workers in the group.

Конкретный код инициализации:

def init_rpc(
    name,
    backend=None,
    rank=-1,
    world_size=None,
    rpc_backend_options=None,
):
        dist_autograd._init(rank) # 我们后续会讨论分布式自动微分引擎
        _set_profiler_node_id(rank)
        # Initialize RPC.
        _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)  

Нас интересует следующее: _init_rpc_backend установит бэкенд.

2.1.1 Инициализация серверной части

_init_rpc_backend Здесь будет видно, какой Агент окончательно сгенерирован в соответствии с конфигурацией, а затем установить агент в текущем контексте. У RPC есть два бэкенда, TENSORPIPE и PROCESS_GROUP, из которых PROCESS_GROUP был заброшен и будет постепенно перенесен на TENSORPIPE.

def _init_rpc_backend(
    backend=BackendType.TENSORPIPE,  # 默认后端是TENSORPIPE
    store=None,
    name=None,
    rank=-1,
    world_size=-1,
    rpc_backend_options=None,
):

    _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)

    if _is_current_rpc_agent_set():
        raise RuntimeError("RPC is already initialized")

    # Initialize RPC.
    rpc_agent = backend_registry.init_backend( # 生成一个agent
        backend,
        store=store,
        name=name,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )

    api._init_rpc_states(rpc_agent) # 设定代理到当前上下文

Как видите, TensorPipeAgent генерируется по умолчанию.

2.1.2 Создание прокси

Далее давайте посмотрим, как сгенерировать TensorPipeAgent, в частности, в torch/csrc/distributed/rpc/init.cpp. Когда здесь создается TensorPipeAgent, настройте RequestCallbackImpl как функцию обратного вызова. Эта функция обратного вызова используется внутри прокси-сервера для обработки полученного запроса.

shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
    .def(
        py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
                    std::string selfName,
                    worker_id_t selfId,
                    int worldSize,
                    c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
                    TensorPipeRpcBackendOptions opts) {
          return std::shared_ptr<TensorPipeAgent>(
              new TensorPipeAgent(
                  store,
                  std::move(selfName),
                  selfId,
                  worldSize,
                  std::move(processGroup),
                  std::move(opts),
                  std::make_unique<RequestCallbackImpl>()), // RequestCallbackImpl 被配置到 Agent 之上
              impl::destroy_without_gil<TensorPipeAgent>);
        })

детали следующим образом:

+-----------------+        +-----------------------+
| TensorPipeAgent |        | RequestCallbackImpl   |
|                 |        |                       |
|         cb_ +----------> |                       |
|                 |        |                       |
+-----------------+        +-----------------------+

2.1.3 Настройка прокси

_init_rpc_states установит прокси в среде PyTorch, которая определена в torch/distributed/rpc/api.py.

def _init_rpc_states(agent):
    worker_infos = agent.get_worker_infos()
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}

    # NB: backend implementation might have already set the rpc_agent.
    if not _is_current_rpc_agent_set():
        _set_and_start_rpc_agent(agent)

Следующий шаг — войти в мир C++. В torch/csrc/distributed/rpc/init.cpp есть _set_and_start_rpc_agent, его роль:

  • RpcAgent::setCurrentRpcAgent устанавливает агент.
  • Вызовите rpcAgent->start(), чтобы запустить агент.
module.def(
    "_set_and_start_rpc_agent",
    [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        
      RpcAgent::setCurrentRpcAgent(rpcAgent); // 这里设定了 Agent
        
      // Initializing typeResolver inside RpcAgent constructor will make
      // RpcAgent have python dependency. To avoid RpcAgent to have python
      // dependency, setTypeResolver() here.
        
      std::shared_ptr<TypeResolver> typeResolver =
          std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
            auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                qn.qualifiedName());
            return c10::StrongTypePtr(
                PythonRpcHandler::getInstance().jitCompilationUnit(),
                std::move(typePtr));
          });
      rpcAgent->setTypeResolver(typeResolver);
      rpcAgent->start(); // 启动代理
    },
    py::call_guard<py::gil_scoped_release>());

setCurrentRpcAgent определен в torch/csrc/distributed/rpc/rpc_agent.cpp.

2.1.4 Статические переменные класса

В RpcAgent есть статическая переменная-член currentRpcAgent_.

class TORCH_API RpcAgent {
     // 我们省略了其他成员变量和函数
     private:
      static std::shared_ptr<RpcAgent> currentRpcAgent_;
}

В C++ статические переменные-члены имеют следующие характеристики:

  • Он принадлежит всему классу.
  • Его жизненный цикл не зависит ни от какого объекта, это жизненный цикл программы.
  • Вы можете напрямую обращаться к общедоступным статическим переменным-членам через имя класса.
  • Доступ к общедоступным статическим переменным-членам класса можно получить по имени объекта.
  • Все производные объекты класса совместно используют статические переменные-члены этого класса.
  • Статические переменные-члены должны быть выделены отдельно вне класса.
  • Статические переменные-члены расположены в области глобальных данных внутри программы.

Итак, мы знаемRpcAgent::currentRpcAgent_Его можно рассматривать как глобальную переменную, и rpc единообразно использует эту переменную для согласования. В частности, эти функции выполняются через некоторые общедоступные функции-члены RpcAgent.

std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;

bool RpcAgent::isCurrentRpcAgentSet() {
  return std::atomic_load(&currentRpcAgent_) != nullptr;
}

std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
  std::shared_ptr<RpcAgent> agent = std::atomic_load(&currentRpcAgent_);
  return agent;
}

void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
  if (rpcAgent) {
    std::shared_ptr<RpcAgent> previousAgent;
    // Use compare_exchange so that we don't actually perform the exchange if
    // that would trigger the assert just below. See:
    // https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
    std::atomic_compare_exchange_strong(
        &currentRpcAgent_, &previousAgent, std::move(rpcAgent));
  } else {
    // We can't use compare_exchange (we don't know what value to expect) but we
    // don't need to, as the only case that would trigger the assert is if we
    // replaced nullptr with nullptr, which we can just do as it has no effect.
    std::shared_ptr<RpcAgent> previousAgent =
        std::atomic_exchange(&currentRpcAgent_, std::move(rpcAgent));
  }
}

Поэтому текущее расширение выглядит следующим образом: в дальнейшем операции RPC будут выполняться через глобальную переменную RpcAgent::currentRpcAgent_.

RpcAgent::currentRpcAgent_
      +
      |
      |
      |
      v
+-----+-----------+        +-----------------------+
| TensorPipeAgent |        | RequestCallbackImpl   |
|                 |        |                       |
|         cb_ +----------> |                       |
|                 |        |                       |
+-----------------+        +-----------------------+

2.2 RPC-прокси

Все связанные функции dist.autograd выполняются на основе прокси-сервера RPC, поэтому нам нужно более внимательно изучить прокси-сервер.

2.2.1 RpcAgent

Это прокси-сервер, используемый для передачи RPC, и базовый класс прокси-сервера для отправки и получения сообщений RPC, который:

  • при условииsendAPI используется для обработки запросов и ответов.
  • cb_ также настроен для обработки входящих запросов.

WorkerInfoглобальный уникальный идентификатор рабочего процесса, в котором находится экземпляр прокси, включаяname_иid_Эти две переменные-члены.name_глобально уникальное имя,id_является глобально уникальным идентификатором.

class TORCH_API RpcAgent {
 public:
  RpcAgent(
      WorkerInfo id,
      std::unique_ptr<RequestCallback> cb,
      std::chrono::milliseconds rpcTimeout);
  
  // 给 to.id 代表的其他 RpcAgengt 发送一个消息,返回一个JitFuture,这个实现是异步的。
  virtual c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to.id,
      Message&& message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
      const std::unordered_map<c10::Device, c10::Device>& deviceMap = {}) = 0;

 protected:
  const WorkerInfo workerInfo_; // 代理实例的全局唯一标示
  const std::unique_ptr<RequestCallback> cb_; // 回调函数
  std::atomic<std::chrono::milliseconds> rpcTimeout_;
  std::atomic<bool> profilingEnabled_;
  std::shared_ptr<TypeResolver> typeResolver_;
  std::atomic<bool> rpcAgentRunning_;

 private:
  static std::shared_ptr<RpcAgent> currentRpcAgent_; // 全局代理
  // Add GIL wait time data point to metrics
  virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
  friend class PythonRpcHandler;
  // Condition Variable to signal when the rpcRetryMap_ has been populated.
  std::condition_variable rpcRetryMapCV_;
  // Mutex to protect RpcRetryMap_.
  std::mutex rpcRetryMutex_;
};

2.2.2 ProcessGroupAgent

ProcessGroupAgent — это производный класс от RpcAgent. Это использовалось раньше, но PyTorch предоставляет лучший TensorAgent. Мы выбрали только некоторые переменные-члены.

class TORCH_API ProcessGroupAgent : public RpcAgent {
 public:

  c10::intrusive_ptr<::c10d::ProcessGroup> pg_;
  // worker name -> rank
  std::unordered_map<std::string, worker_id_t> nameMap_;
  std::vector<WorkerInfo> allWorkerInfo_;

  MessageCounter sendCounts_;
  MessageCounter recvCounts_;

  std::atomic<int64_t> nextId_;

  std::thread listenerThread_;
  std::thread futureTimeoutThread_;
  c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;

  std::unordered_map<
      worker_id_t,
      std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
      currentPendingSends_;

  ThreadPool threadPool_;

  // Mapping of request id to FutureInfo struct.
  std::unordered_map<int64_t, FutureInfo> futures_;
};

2.2.3 TensorPipeAgent

TensorPipeAgent определен в torch/csrc/distributed/rpc/tensorpipe_agent.h, который есть сейчас и будет в будущем. TensorPipeAgent использует TensorPipe для прозрачного перемещения тензоров и данных между доступными транспортами или каналами. Это как гибридный транспорт RPC с поддержкой разделяемой памяти (linux) и TCP (linux и mac). PyTorch разрабатывает свою версию CUDA.

Мы выбрали только некоторые переменные-члены.

// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
// memory (linux) and TCP (linux & mac) support. CUDA support is in progress.
class TensorPipeAgent : public RpcAgent {
 public:
  TensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      int worldSize,
      c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
      TensorPipeRpcBackendOptions opts,
      std::unique_ptr<RequestCallback> cb);

  const TensorPipeRpcBackendOptions opts_;
  std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
  std::vector<c10::Device> devices_;

  ThreadPool threadPool_;
  std::shared_ptr<tensorpipe::Context> context_;
  std::shared_ptr<tensorpipe::Listener> listener_;

  mutable std::mutex connectedPipesMutex_;
  std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;

  // Maps keyed on name and id for easy WorkerInfo lookup.
  std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
  std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
  std::unordered_map<std::string, std::string> workerNameToURL_;

  ::c10d::PrefixStore rankToNameStore_;
  ::c10d::PrefixStore nameToAddressStore_;
  const int worldSize_;

  // The join method is required to behave like a barrier and perform collective
  // operations. For simplicity and reliability, we offload this to a process
  // group, but probably one day we might want to re-implement them using RPCs.
  const c10::intrusive_ptr<::c10d::ProcessGroup> processGroup_;

  std::atomic<uint64_t> nextMessageID_{0};

  // Thread that will poll the timeoutMap_ for timed out messages and mark them
  // with an error accordingly
  std::thread timeoutThread_;

  // Function run by the timeoutThread_ to check for timed out RPCs
  void pollTimeoutRpcs();
};

2.2.4 Функция обратного вызова

Когда агент получает сообщение, он вызывает функцию обратного вызова. А RequestCallbackImpl реализует логику обратного вызова. RequestCallbackImpl является производным классом Давайте посмотрим на базовый класс RequestCallbackNoPython и найдем интерфейс RequestCallback, поэтому RequestCallback является основой этой системы производных.

class TORCH_API RequestCallbackNoPython : public RequestCallback
  
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython   
2.2.4.1 RequestCallback

RequestCallback — это интерфейс для обработки сообщений RPC и абстрактный класс.

// Functor which is invoked to process an RPC message. This is an abstract class
// with some common functionality across all request handlers. Users need to
// implement this interface to perform the actual business logic.
class TORCH_API RequestCallback {
 public:
  // Invoke the callback.
  c10::intrusive_ptr<JitFuture> operator()(
      Message& request,
      std::shared_ptr<LazyStreamContext> ctx) const;

  // NOLINTNEXTLINE(modernize-use-equals-default)
  virtual ~RequestCallback() {}

 protected:
  // RpcAgent implementation should invoke ``RequestCallback`` to process
  // received requests. There is no restriction on the implementation's
  // threading model. This function takes an rvalue reference of the Message
  // object. It is expected to return the future to a response message or
  // message containing an exception. Different rpc agent implementations are
  // expected to ensure delivery of the response/exception based on their
  // implementation specific mechanisms.
  virtual c10::intrusive_ptr<JitFuture> processMessage(
      Message& request,
      std::shared_ptr<LazyStreamContext> ctx) const = 0;
};
2.2.4.2 RequestCallbackNoPython

RequestCallbackNoPython определен в файле torch/csrc/distributed/rpc/request_callback_no_python.h, который реализует некоторые механизмы обработки. Поскольку он содержит слишком много методов, мы можем извлечь только некоторые части. Если вам интересно, изучите его подробно.

// RequestCallback implementation with no Python dependencies.
class TORCH_API RequestCallbackNoPython : public RequestCallback {
 public:
  c10::intrusive_ptr<JitFuture> processMessage(
      Message& request,
      std::shared_ptr<LazyStreamContext> ctx) const override;

 protected:

  void processForwardAutogradReq(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const;

  void processBackwardAutogradReq(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const;

  void processRpc(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const;

  virtual void processRpcWithErrors(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const;

  virtual void processRRefBackward(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const;
};

Мы увидим, как вызывать функцию обратного вызова в последующем анализе логики принятия.

0x03 Логика отправки

Сначала рассмотрим логику отправки. То есть роль rpc.rpc_sync: установить root, добавить send и т.д.

3.1 Python

Начнем с части Python.

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

Сначала зашел в rpc_sync и обнаружил, что он вызывает _invoke_rpc.

@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    return fut.wait()

Далее пришел_invoke_rpc, видно, что эта функция выбирает разные пути в зависимости от типа вызова (встроенная операция, скрипт, udf).

def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = torch.autograd._profiler_enabled()
    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

        if qualified_name is not None:
            fut = _invoke_rpc_builtin( # 内置rpc
                dst_worker_info,
                qualified_name,
                rpc_timeout,
                *args,
                **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction): # 脚本
            fut = _invoke_rpc_torchscript( 
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf( # 用户udf
                dst_worker_info,
                pickled_python_udf,
                tensors,
                rpc_timeout,
                is_async_exec
            )
        if should_profile:
            fut = rf._call_end_callbacks_on_future(fut)
    return fut

Отсюда я вошел в мир C++, torch/csrc/distributed/rpc/init.cpp.

3.2 C++

можно увидеть здесь _invoke_rpc_builtinСоответствует pyRpcBuiltin,_invoke_rpc_python_udfСоответствует pyRpcPythonUdf.

PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds)); # 内置函数
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF, # 对应了udf
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());  
  
  # 省略其他
}

мы выбираем_invoke_rpc_builtinВзгляните на соответствующий файл pyRpcBuiltin.

3.2.1 pyRpcBuiltin

Как вы можете видеть в torch/csrc/distributed/rpc/python_functions.cpp, pyRpcBuiltin вызовет sendMessageWithAutograd.

c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs,
    const float rpcTimeoutSeconds) {
  DCHECK(PyGILState_Check());
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
  auto agent = RpcAgent::getCurrentRpcAgent(); // 获取当前agent
  return toPyJitFuture(sendMessageWithAutograd( // 发送请求
      *agent,
      dst,
      std::move(*scriptCall).toMessage(),
      false,
      rpcTimeoutSeconds));
}

3.2.2 sendMessageWithAutograd

Используйте агент здесь, в torch/csrc/distributed/autograd/utils.cpp, чтобы отправить FORWARD_AUTOGRAD_REQ.

Позже на получателе мы увидим обработку сообщений FORWARD_AUTOGRAD_REQ, так что отправка и получение могут быть примерно связаны.

c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd( // 这里会与上下文交互,构建了 FORWARD_AUTOGRAD_REQ
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  c10::intrusive_ptr<JitFuture> fut;
  // If profiler is enabled, wrap this message with profiling metadata that will
  // tell the remote end to process this request with the profiler enabled.
  if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
    auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
    auto msgWithProfiling = getMessageWithProfiling(
        std::move(msg),
        rpc::MessageType::RUN_WITH_PROFILING_REQ, //构建消息
        std::move(profilerConfig));
    // 发送消息
    fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
  } else {
    fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
  }

  return fut;
}

Процесс отправки выглядит следующим образом, в котором sendMessageWithAutograd будет использовать RpcAgent::getCurrentRpcAgent() для получения RpcAgent::currentRpcAgent_, который должен получить глобально установленный агент, а затем отправить его через агент.

  rpc.rpc_sync
         +
         |
         |
         v
  _invoke_rpc_builtin
         +
         |                                               Python
+---------------------------------------------------------------+
         |                                               C++
         |
         v

    pyRpcBuiltin
         +
         |
         |
         v

 sendMessageWithAutograd(RpcAgent::getCurrentRpcAgent())
         +
         |
         |
         |   RpcAgent::currentRpcAgent_
         |           +
         |           |
         |           |
         |           v
         |     +-----+-----------+
         |     | TensorPipeAgent |        +-----------------------+
         |     |                 |        | RequestCallbackImpl   |
         |     |       cb_ +------------> |                       |
         |     |                 |        +-----------------------+
         |     |                 |
         |     |                 |
         +-----------> send +-----------> Will send message to other worker
               |                 |
               |                 |
               +-----------------+

Логика принятия 0x04

4.1 Обратный вызов

Когда агент получает сообщение, он вызывает RequestCallback::operator(). Это функция обратного вызова, о которой мы упоминали ранее. Код находится в torch/csrc/distributed/rpc/tensorpipe_agent.cpp.

void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          Message&& requestMessage,
          std::shared_ptr<LazyStreamContext> ctx) mutable {

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage.id();
        increaseCallCount(serverActiveCalls_);

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         ctx{std::move(ctx)}]() mutable {

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
              
            // 这里会调用 RequestCallback 来进行回调逻辑处理
              
            futureResponseMessage = cb_->operator()(requestMessage, ctx);
            
          } catch (const std::exception& /* unused */) {
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          // Shortcut if immediately done
          if (futureResponseMessage->completed()) {
            decreaseCallCount(serverActiveCalls_);
            sendCompletedResponseMessage(
                pipe, *futureResponseMessage, messageId, std::move(ctx));
          } else {
            // Not complete yet
            increaseCallCount(serverActiveAsyncCalls_);
            futureResponseMessage->addCallback(
                [this, pipe, messageId, ctx{std::move(ctx)}](
                    JitFuture& futureResponseMessage) mutable {
                  decreaseCallCount(serverActiveCalls_);
                  decreaseCallCount(serverActiveAsyncCalls_);
                  sendCompletedResponseMessage(
                      pipe, futureResponseMessage, messageId, std::move(ctx));
                });
          }
        });
      });
}

4.2 operator()

В operator() будет вызываться processMessage для обработки сообщения.

c10::intrusive_ptr<JitFuture> RequestCallback::operator()(
    Message& request,
    std::shared_ptr<LazyStreamContext> ctx) const {
  // NB: cannot clear autograd context id here because the processMessage method
  // might pause waiting for all RRefs in the arguments to be confirmed by their
  // owners and resumne processing in a different thread. Hence, the
  // thread_local context id needs to be set and cleared in the thread that
  // indeed carries out the processing logic.
  return processMessage(request, std::move(ctx));
}

Затем он будет вызываться RequestCallbackNoPython::processMessage.

  • Сначала вызовите команду deserializePythonRpcCommand, реализованную в RequestCallbackImpl, чтобы десериализовать PythonUDF.
  • Затем вызовите processRpcWithErrors для обработки сообщения.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::shared_ptr<LazyStreamContext> ctx) const {
  // We need two futures here because it could pause twice when processing a
  // RPC message:
  //  1) waiting for all RRefs in the arguments to become confirmed;
  //  2) waiting for processRpc to finish.
  auto retFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    // 调用 RequestCallbackImpl 中实现的  deserializePythonRpcCommand 来对 PythonUDF 反序列化
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type()); // 解析请求
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();

    rrefsReadyFuture->addCallback(
        [this,
         retFuture,
         // std::function must be copyable, hence hae to cast the unique_ptr to
         // a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         id = request.id(),
         ctx = std::move(ctx)](JitFuture& /* unused */) mutable {
          c10::MultiStreamGuard guard(
              ctx ? ctx->getReservedStreams() : ArrayRef<Stream>({}));
          // The cost of pre-request check is minimal thanks to
          // std::shared_lock. The cost is in magnitude
          // of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If server global profiler is enabled, we futher pay the
          // cost of thread local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            ::torch::autograd::profiler::enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
                    ->config());
          }

          // 在这里
          processRpcWithErrors(
              *rpc, messageType, id, retFuture, std::move(ctx));

          // Response message has been sent at this moment, this post-response
          // work doesn't affect RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            ::torch::autograd::profiler::thread_event_lists event_lists =
                ::torch::autograd::profiler::disableProfilerLegacy();
            // Put thread_local event_lists into the process-global profiler
            // state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }
        });
  } catch (std::exception& e) {
    retFuture->markCompleted(handleError(e, request.type(), request.id()));
    rrefContext.clearRecordedPendingRRefsOnError();
  }
  return retFuture;
}

Затем вызовите processRpcWithErrors.

void RequestCallbackNoPython::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  try {
    processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx));
  } catch (std::exception& e) {
    responseFuture->markCompleted(handleError(e, messageType, messageId));
  }
}

Далее идет процессRpc. Здесь вы можете увидеть обработку FORWARD_AUTOGRAD_REQ.

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

    case MessageType::FORWARD_AUTOGRAD_REQ: { // 这里就和之前发送的对应上了
      processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
      return;
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      processBackwardAutogradReq(rpc, messageId, responseFuture);
      return;
    };  
  
}  

детали следующим образом:

 TensorPipeAgent      RequestCallback  RequestCallbackNoPython     RequestCallbackImpl
        +                   +                 +                          +
        |                   |                 |                          |
        |                   |                 |                          |
        v                   |                 |                          |
    respond                 |                 |                          |
        +                   |                 |                          |
        |                   |                 |                          |
        |                   |                 |                          |
        v                   v                 v                          |
cb_->operator()  +-->   operator()  +-->  processMessage                 |
                                              +                          |
                                              |                          |
                                              |                          v
                                              +--------------->  deserializePythonRpcCommand
                                              |
                                              |
                                              |
                                              v

                                      processRpcWithErrors
                                              +
                                              |
                                              |
                                              v
                                          processRpc
                                              +
                                              |
                                              |
                                              v
                                    processForwardAutogradReq

4.3 RequestCallbackImpl

В это время у читателей возникнут вопросы. До того, как TensorPipeAgent четко установил RequestCallbackImpl в качестве функции обратного вызова, зачем вызывать только его deserializePythonRpcCommand? DeserialXXX, похоже, связан с сериализацией. Говорят, что некоторые функции бизнес-обработки должны быть вызваны, такие как processXXXX и как. Далее рассмотрим RequestCallbackImpl.

RequestCallbackImpl определен в torch/csrc/distributed/rpc/request_callback_impl.h.

class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
 public:
  std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
      std::unique_ptr<RpcCommandBase> rpc,
      const MessageType& messageType) const override;

  void processPythonCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  void processScriptCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;

  void processScriptRemoteCall(
      ScriptRemoteCall& scriptRemoteCall,
      const std::function<void(void)>& postProcessing,
      std::vector<at::IValue>& stack,
      const c10::intrusive_ptr<OwnerRRef>& ownerRRef) const override;

  void processPythonRemoteCall(
      RpcCommandBase& rpc,
      const std::function<void(Message)>& markComplete,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  void processRpcWithErrors(
      RpcCommandBase& rpc,
      const MessageType& messageType,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture,
      std::shared_ptr<LazyStreamContext> ctx) const override;

  void processRRefBackward(
      RpcCommandBase& rpc,
      const int64_t messageId,
      const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
};

Поскольку RequestCallbackImpl окончательно сгенерирован, на самом деле в середине приведенного выше рисунка есть шаг: ProcessRpcWithErrors фактически вызывает функцию processRpcWithErrors в RequestCallbackImpl, что добавляет некоторую логику обработки исключений.

void RequestCallbackImpl::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  try {
    processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx));
  } catch (py::error_already_set& e) {
    responseFuture->markCompleted(handleError(e, messageType, messageId));
    py::gil_scoped_acquire acquire;
    e.restore(); // Release ownership on py::objects and also restore
                 // Python Error Indicator.
    PyErr_Clear(); // Clear the Python Error Indicator as we has
                   // recorded the exception in the response message.
  } catch (std::exception& e) {
    responseFuture->markCompleted(handleError(e, messageType, messageId));
  }
}

Логическая схема изменена следующим образом:

 TensorPipeAgent      RequestCallback  RequestCallbackNoPython     RequestCallbackImpl
        +                   +                 +                          +
        |                   |                 |                          |
        |                   |                 |                          |
        v                   |                 |                          |
    respond                 |                 |                          |
        +                   |                 |                          |
        |                   |                 |                          |
        |                   |                 |                          |
        v                   v                 v                          |
cb_->operator()  +-->   operator()  +-->  processMessage                 |
                                              +                          |
                                              |                          |
                                              |                          v
                                              +----------------> deserializePythonRpcCommand
                                              |                          +
                                              |                          |
                                              |                          |
                                              |                          v
                                              |
                                              +----------------> processRpcWithErrors
                                              |                          +
                                              |                          |
                                              |                          |
                                              | <------------------------+
                                              |
                                              |
                                              v
                                          processRpc
                                              +
                                              |
                                              |
                                              v
                                    processForwardAutogradReq

Если совместить с предыдущей передачей, то расширяем схему следующим образом:

  1. rpc.rpc_sync вызывается, когда отправителю необходимо удаленно запустить автоматические вычисления градиента.
  2. Эта функция вызывается из Python в мир C++ и называется pyRpcBuiltin.
  3. Вызовите sendMessageWithAutograd, чтобы уведомить Receiver.
  4. RpcAgent::getCurrentRpcAgent() будет вызываться для получения локального агента.
  5. Вызов функции отправки текущего агента.
  6. Функция отправки отправляет FORWARD_AUTOGRAD_REQ рабочему процессу Receiver.
  7. Функция ответа вызовет функцию обратного вызова cb_ агента в приемнике.
  8. Вызов processRpcWithErrors для RequestCallbackImpl.
  9. Затем вызовите processRpc.
  10. Наконец, вызывается processForwardAutogradReq для завершения процесса запуска распределенного автограда на основе RPC.
                                                             +
 rpc.rpc_sync                                 Sender         |     Receiver
        +                                                    |
        |                                                    |
        | 1                                                  |
        v                                                    |
 _invoke_rpc_builtin                                         |
        +                                                    |
        |                                      Python        |
+----------------------------------------------------------+ |
        |                                      C++           |      +----------------------------+
        |  2                                                 |      | RequestCallbackImpl        |
        v                                                    |      |                            |
                                                             |   +----> processRpcWithErrors     |
   pyRpcBuiltin                                              |   |  |             +              |
        +                                                    |   |  |             | 9            |
        |  3                                                 |   |  |             |              |
        |                                                    |   |  |             v              |
        v                                                    |   |  |         processRpc         |
                                     4                       |   |  |             +              |
sendMessageWithAutograd(RpcAgent::getCurrentRpcAgent())      |   |  |             | 10           |
        +                                                    |   |  |             |              |
        |                                                    |   |  |             v              |
        |                                                    |   |  |  processForwardAutogradReq |
        |   RpcAgent::currentRpcAgent_                       |   |  |                            |
        |           +                                        |   |  +----------------------------+
        |           |                                        |   |
        | 5         |                                        |   |8     +-----------------+
        |           v                                        |   |      | TensorPipeAgent |
        |    +------+--------+                               |   |      |                 |
        |    |TensorPipeAgent|   +-------------------+       |   +------------+ cb_       |
        |    |               |   |RequestCallbackImpl|       |          |        ^        |
        |    |      cb_ +------->+                   |       |          |      7 |        |
        |    |               |   +-------------------+       |          |        |        |
        |    |               |                          6    |          |        +        |
        +--------> send   +----------------------------------+--------------> respond     |
             |               |                   FORWARD_AUTOGRAD_REQ   |                 |
             |               |                               +          |                 |
             +---------------+                               |          +-----------------+
                                                             +


Телефон такой:

На этом знакомство с RPC закончено, в следующей статье мы представим контекстно-зависимые классы управления, так что следите за обновлениями.

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

ссылка 0xFF