0x00 сводка
В предыдущей статье мы дали идею дизайна распределенного автограда, в начале этой статьи разберем конкретный исходный код. Потому что независимо от того, идет ли речь о прямом или обратном распространении, для завершения необходимо полагаться на RPC, поэтому давайте сначала рассмотрим некоторые базовые функции, инкапсулированные в RPC, такие как инициализация, прокси (функции, связанные с RPC, выполняются на основе прокси), прием сообщений. , отправить и т.д.
В этой статье вы узнаете, как инициализировать серверную часть RPC, как сгенерировать прокси-сервер RPC, как использовать прокси-сервер RPC для отправки и получения и как подключиться к удаленному механизму автоматической дифференциации dist.autograd.
Другие статьи о распространении PyTorch:
[Анализ исходного кода] Распространение PyTorch (1) ------ история и обзор
[Анализ исходного кода] Как PyTorch использует GPU
[Анализ исходного кода] Распределенный PyTorch (2) ----- DataParallel (включен)
[Анализ исходного кода] Распределенный PyTorch (3) ----- DataParallel (ниже)
[Анализ исходного кода] Распределенный PyTorch (7) ----- Группа процессов DistributedDataParallel
[Анализ исходного кода] Распределенный PyTorch (8) -------- Бумага DistributedDataParallel
[Анализ исходного кода] Распределенный PyTorch (9) ----- Инициализация DistributedDataParallel
[Анализ исходного кода] PyTorch, распространяемый Autograd (1) ---- дизайн
Для лучшего объяснения код в этой статье будет соответственно упрощен в соответствии с конкретной ситуацией.
0x01 Пример
Мы взяли код из раздела примеров PyTorch и немного изменили его, чтобы позволить двум рабочим процессам взаимодействовать через RPC. Рабочий пример разделен на две части:
- Операции RPC, построение базы зависимостей.
- Выполните обратное распространение.
def my_add(t1, t2):
return torch.add(t1, t2)
def worker0():
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# 第一阶段:RPC操作,构建依赖基础
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# 第二阶段,执行后向传播
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
print(loss)
Два воркера можно запустить следующим образом, используя rpc.init_rpc для инициализации rpc. Worker0 запустится, а затем использует RPC для выполнения некоторых операций над worker 1.
def run_worker(rank, world_size):
r"""
A wrapper function that initializes RPC, calls the function, and shuts down
RPC.
"""
# We need to use different port numbers in TCP init_method for init_rpc and
# init_process_group to avoid port conflicts.
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://localhost:29501"
# Rank 0 and 1 are trainers.
if rank == 0:
rpc.init_rpc(
"worker0",
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
worker0()
elif rank == 1:
rpc.init_rpc(
"worker1",
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
# block until all rpcs finish
rpc.shutdown()
0x02 Основы RPC
2.1 Инициализация
Давайте посмотрим на пример кода с нуля.Когда скрипт запускается, он вызывает rpc.init_rpc для инициализации rpc. Из комментариев RPC можно увидеть две концепции: общий ранг и размер мира.
rank (int): a globally unique id/rank of this node.
world_size (int): The number of workers in the group.
Конкретный код инициализации:
def init_rpc(
name,
backend=None,
rank=-1,
world_size=None,
rpc_backend_options=None,
):
dist_autograd._init(rank) # 我们后续会讨论分布式自动微分引擎
_set_profiler_node_id(rank)
# Initialize RPC.
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
Нас интересует следующее: _init_rpc_backend установит бэкенд.
2.1.1 Инициализация серверной части
_init_rpc_backend Здесь будет видно, какой Агент окончательно сгенерирован в соответствии с конфигурацией, а затем установить агент в текущем контексте. У RPC есть два бэкенда, TENSORPIPE и PROCESS_GROUP, из которых PROCESS_GROUP был заброшен и будет постепенно перенесен на TENSORPIPE.
def _init_rpc_backend(
backend=BackendType.TENSORPIPE, # 默认后端是TENSORPIPE
store=None,
name=None,
rank=-1,
world_size=-1,
rpc_backend_options=None,
):
_validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)
if _is_current_rpc_agent_set():
raise RuntimeError("RPC is already initialized")
# Initialize RPC.
rpc_agent = backend_registry.init_backend( # 生成一个agent
backend,
store=store,
name=name,
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
api._init_rpc_states(rpc_agent) # 设定代理到当前上下文
Как видите, TensorPipeAgent генерируется по умолчанию.
2.1.2 Создание прокси
Далее давайте посмотрим, как сгенерировать TensorPipeAgent, в частности, в torch/csrc/distributed/rpc/init.cpp. Когда здесь создается TensorPipeAgent, настройте RequestCallbackImpl как функцию обратного вызова. Эта функция обратного вызова используется внутри прокси-сервера для обработки полученного запроса.
shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
.def(
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
std::string selfName,
worker_id_t selfId,
int worldSize,
c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
TensorPipeRpcBackendOptions opts) {
return std::shared_ptr<TensorPipeAgent>(
new TensorPipeAgent(
store,
std::move(selfName),
selfId,
worldSize,
std::move(processGroup),
std::move(opts),
std::make_unique<RequestCallbackImpl>()), // RequestCallbackImpl 被配置到 Agent 之上
impl::destroy_without_gil<TensorPipeAgent>);
})
детали следующим образом:
+-----------------+ +-----------------------+
| TensorPipeAgent | | RequestCallbackImpl |
| | | |
| cb_ +----------> | |
| | | |
+-----------------+ +-----------------------+
2.1.3 Настройка прокси
_init_rpc_states установит прокси в среде PyTorch, которая определена в torch/distributed/rpc/api.py.
def _init_rpc_states(agent):
worker_infos = agent.get_worker_infos()
global _ALL_WORKER_NAMES
_ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
# NB: backend implementation might have already set the rpc_agent.
if not _is_current_rpc_agent_set():
_set_and_start_rpc_agent(agent)
Следующий шаг — войти в мир C++. В torch/csrc/distributed/rpc/init.cpp есть _set_and_start_rpc_agent, его роль:
- RpcAgent::setCurrentRpcAgent устанавливает агент.
- Вызовите rpcAgent->start(), чтобы запустить агент.
module.def(
"_set_and_start_rpc_agent",
[](const std::shared_ptr<RpcAgent>& rpcAgent) {
RpcAgent::setCurrentRpcAgent(rpcAgent); // 这里设定了 Agent
// Initializing typeResolver inside RpcAgent constructor will make
// RpcAgent have python dependency. To avoid RpcAgent to have python
// dependency, setTypeResolver() here.
std::shared_ptr<TypeResolver> typeResolver =
std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
qn.qualifiedName());
return c10::StrongTypePtr(
PythonRpcHandler::getInstance().jitCompilationUnit(),
std::move(typePtr));
});
rpcAgent->setTypeResolver(typeResolver);
rpcAgent->start(); // 启动代理
},
py::call_guard<py::gil_scoped_release>());
setCurrentRpcAgent определен в torch/csrc/distributed/rpc/rpc_agent.cpp.
2.1.4 Статические переменные класса
В RpcAgent есть статическая переменная-член currentRpcAgent_.
class TORCH_API RpcAgent {
// 我们省略了其他成员变量和函数
private:
static std::shared_ptr<RpcAgent> currentRpcAgent_;
}
В C++ статические переменные-члены имеют следующие характеристики:
- Он принадлежит всему классу.
- Его жизненный цикл не зависит ни от какого объекта, это жизненный цикл программы.
- Вы можете напрямую обращаться к общедоступным статическим переменным-членам через имя класса.
- Доступ к общедоступным статическим переменным-членам класса можно получить по имени объекта.
- Все производные объекты класса совместно используют статические переменные-члены этого класса.
- Статические переменные-члены должны быть выделены отдельно вне класса.
- Статические переменные-члены расположены в области глобальных данных внутри программы.
Итак, мы знаемRpcAgent::currentRpcAgent_
Его можно рассматривать как глобальную переменную, и rpc единообразно использует эту переменную для согласования. В частности, эти функции выполняются через некоторые общедоступные функции-члены RpcAgent.
std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;
bool RpcAgent::isCurrentRpcAgentSet() {
return std::atomic_load(¤tRpcAgent_) != nullptr;
}
std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
std::shared_ptr<RpcAgent> agent = std::atomic_load(¤tRpcAgent_);
return agent;
}
void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
if (rpcAgent) {
std::shared_ptr<RpcAgent> previousAgent;
// Use compare_exchange so that we don't actually perform the exchange if
// that would trigger the assert just below. See:
// https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
std::atomic_compare_exchange_strong(
¤tRpcAgent_, &previousAgent, std::move(rpcAgent));
} else {
// We can't use compare_exchange (we don't know what value to expect) but we
// don't need to, as the only case that would trigger the assert is if we
// replaced nullptr with nullptr, which we can just do as it has no effect.
std::shared_ptr<RpcAgent> previousAgent =
std::atomic_exchange(¤tRpcAgent_, std::move(rpcAgent));
}
}
Поэтому текущее расширение выглядит следующим образом: в дальнейшем операции RPC будут выполняться через глобальную переменную RpcAgent::currentRpcAgent_.
RpcAgent::currentRpcAgent_
+
|
|
|
v
+-----+-----------+ +-----------------------+
| TensorPipeAgent | | RequestCallbackImpl |
| | | |
| cb_ +----------> | |
| | | |
+-----------------+ +-----------------------+
2.2 RPC-прокси
Все связанные функции dist.autograd выполняются на основе прокси-сервера RPC, поэтому нам нужно более внимательно изучить прокси-сервер.
2.2.1 RpcAgent
Это прокси-сервер, используемый для передачи RPC, и базовый класс прокси-сервера для отправки и получения сообщений RPC, который:
- при условии
send
API используется для обработки запросов и ответов. - cb_ также настроен для обработки входящих запросов.
WorkerInfo
глобальный уникальный идентификатор рабочего процесса, в котором находится экземпляр прокси, включаяname_
иid_
Эти две переменные-члены.name_
глобально уникальное имя,id_
является глобально уникальным идентификатором.
class TORCH_API RpcAgent {
public:
RpcAgent(
WorkerInfo id,
std::unique_ptr<RequestCallback> cb,
std::chrono::milliseconds rpcTimeout);
// 给 to.id 代表的其他 RpcAgengt 发送一个消息,返回一个JitFuture,这个实现是异步的。
virtual c10::intrusive_ptr<JitFuture> send(
const WorkerInfo& to.id,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {}) = 0;
protected:
const WorkerInfo workerInfo_; // 代理实例的全局唯一标示
const std::unique_ptr<RequestCallback> cb_; // 回调函数
std::atomic<std::chrono::milliseconds> rpcTimeout_;
std::atomic<bool> profilingEnabled_;
std::shared_ptr<TypeResolver> typeResolver_;
std::atomic<bool> rpcAgentRunning_;
private:
static std::shared_ptr<RpcAgent> currentRpcAgent_; // 全局代理
// Add GIL wait time data point to metrics
virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
friend class PythonRpcHandler;
// Condition Variable to signal when the rpcRetryMap_ has been populated.
std::condition_variable rpcRetryMapCV_;
// Mutex to protect RpcRetryMap_.
std::mutex rpcRetryMutex_;
};
2.2.2 ProcessGroupAgent
ProcessGroupAgent — это производный класс от RpcAgent. Это использовалось раньше, но PyTorch предоставляет лучший TensorAgent. Мы выбрали только некоторые переменные-члены.
class TORCH_API ProcessGroupAgent : public RpcAgent {
public:
c10::intrusive_ptr<::c10d::ProcessGroup> pg_;
// worker name -> rank
std::unordered_map<std::string, worker_id_t> nameMap_;
std::vector<WorkerInfo> allWorkerInfo_;
MessageCounter sendCounts_;
MessageCounter recvCounts_;
std::atomic<int64_t> nextId_;
std::thread listenerThread_;
std::thread futureTimeoutThread_;
c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;
std::unordered_map<
worker_id_t,
std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
currentPendingSends_;
ThreadPool threadPool_;
// Mapping of request id to FutureInfo struct.
std::unordered_map<int64_t, FutureInfo> futures_;
};
2.2.3 TensorPipeAgent
TensorPipeAgent определен в torch/csrc/distributed/rpc/tensorpipe_agent.h, который есть сейчас и будет в будущем. TensorPipeAgent использует TensorPipe для прозрачного перемещения тензоров и данных между доступными транспортами или каналами. Это как гибридный транспорт RPC с поддержкой разделяемой памяти (linux) и TCP (linux и mac). PyTorch разрабатывает свою версию CUDA.
Мы выбрали только некоторые переменные-члены.
// TensorPipeAgent leverages TensorPipe (https://github.com/pytorch/tensorpipe)
// to transparently move tensors and payloads through the fastest available
// transport or channel. It acts like a hybrid RPC transport, providing shared
// memory (linux) and TCP (linux & mac) support. CUDA support is in progress.
class TensorPipeAgent : public RpcAgent {
public:
TensorPipeAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string selfName,
worker_id_t selfId,
int worldSize,
c10::intrusive_ptr<::c10d::ProcessGroup> processGroup,
TensorPipeRpcBackendOptions opts,
std::unique_ptr<RequestCallback> cb);
const TensorPipeRpcBackendOptions opts_;
std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
std::vector<c10::Device> devices_;
ThreadPool threadPool_;
std::shared_ptr<tensorpipe::Context> context_;
std::shared_ptr<tensorpipe::Listener> listener_;
mutable std::mutex connectedPipesMutex_;
std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;
// Maps keyed on name and id for easy WorkerInfo lookup.
std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
std::unordered_map<std::string, std::string> workerNameToURL_;
::c10d::PrefixStore rankToNameStore_;
::c10d::PrefixStore nameToAddressStore_;
const int worldSize_;
// The join method is required to behave like a barrier and perform collective
// operations. For simplicity and reliability, we offload this to a process
// group, but probably one day we might want to re-implement them using RPCs.
const c10::intrusive_ptr<::c10d::ProcessGroup> processGroup_;
std::atomic<uint64_t> nextMessageID_{0};
// Thread that will poll the timeoutMap_ for timed out messages and mark them
// with an error accordingly
std::thread timeoutThread_;
// Function run by the timeoutThread_ to check for timed out RPCs
void pollTimeoutRpcs();
};
2.2.4 Функция обратного вызова
Когда агент получает сообщение, он вызывает функцию обратного вызова. А RequestCallbackImpl реализует логику обратного вызова. RequestCallbackImpl является производным классом Давайте посмотрим на базовый класс RequestCallbackNoPython и найдем интерфейс RequestCallback, поэтому RequestCallback является основой этой системы производных.
class TORCH_API RequestCallbackNoPython : public RequestCallback
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython
2.2.4.1 RequestCallback
RequestCallback — это интерфейс для обработки сообщений RPC и абстрактный класс.
// Functor which is invoked to process an RPC message. This is an abstract class
// with some common functionality across all request handlers. Users need to
// implement this interface to perform the actual business logic.
class TORCH_API RequestCallback {
public:
// Invoke the callback.
c10::intrusive_ptr<JitFuture> operator()(
Message& request,
std::shared_ptr<LazyStreamContext> ctx) const;
// NOLINTNEXTLINE(modernize-use-equals-default)
virtual ~RequestCallback() {}
protected:
// RpcAgent implementation should invoke ``RequestCallback`` to process
// received requests. There is no restriction on the implementation's
// threading model. This function takes an rvalue reference of the Message
// object. It is expected to return the future to a response message or
// message containing an exception. Different rpc agent implementations are
// expected to ensure delivery of the response/exception based on their
// implementation specific mechanisms.
virtual c10::intrusive_ptr<JitFuture> processMessage(
Message& request,
std::shared_ptr<LazyStreamContext> ctx) const = 0;
};
2.2.4.2 RequestCallbackNoPython
RequestCallbackNoPython определен в файле torch/csrc/distributed/rpc/request_callback_no_python.h, который реализует некоторые механизмы обработки. Поскольку он содержит слишком много методов, мы можем извлечь только некоторые части. Если вам интересно, изучите его подробно.
// RequestCallback implementation with no Python dependencies.
class TORCH_API RequestCallbackNoPython : public RequestCallback {
public:
c10::intrusive_ptr<JitFuture> processMessage(
Message& request,
std::shared_ptr<LazyStreamContext> ctx) const override;
protected:
void processForwardAutogradReq(
RpcCommandBase& rpc,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const;
void processBackwardAutogradReq(
RpcCommandBase& rpc,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
void processRpc(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const;
virtual void processRpcWithErrors(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const;
virtual void processRRefBackward(
RpcCommandBase& rpc,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture) const;
};
Мы увидим, как вызывать функцию обратного вызова в последующем анализе логики принятия.
0x03 Логика отправки
Сначала рассмотрим логику отправки. То есть роль rpc.rpc_sync: установить root, добавить send и т.д.
3.1 Python
Начнем с части Python.
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
Сначала зашел в rpc_sync и обнаружил, что он вызывает _invoke_rpc.
@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
return fut.wait()
Далее пришел_invoke_rpc
, видно, что эта функция выбирает разные пути в зависимости от типа вызова (встроенная операция, скрипт, udf).
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = torch.autograd._profiler_enabled()
ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
with ctx_manager as rf:
args = args if args else ()
kwargs = kwargs if kwargs else {}
is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
if is_async_exec:
wrapped = func._wrapped_async_rpc_function
if isinstance(wrapped, torch.jit.ScriptFunction):
func = wrapped
if qualified_name is not None:
fut = _invoke_rpc_builtin( # 内置rpc
dst_worker_info,
qualified_name,
rpc_timeout,
*args,
**kwargs
)
elif isinstance(func, torch.jit.ScriptFunction): # 脚本
fut = _invoke_rpc_torchscript(
dst_worker_info.name,
torch._jit_internal._qualified_name(func),
args,
kwargs,
rpc_timeout,
is_async_exec
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
fut = _invoke_rpc_python_udf( # 用户udf
dst_worker_info,
pickled_python_udf,
tensors,
rpc_timeout,
is_async_exec
)
if should_profile:
fut = rf._call_end_callbacks_on_future(fut)
return fut
Отсюда я вошел в мир C++, torch/csrc/distributed/rpc/init.cpp.
3.2 C++
можно увидеть здесь _invoke_rpc_builtin
Соответствует pyRpcBuiltin,_invoke_rpc_python_udf
Соответствует pyRpcPythonUdf.
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
module.def(
"_invoke_rpc_builtin",
[](const WorkerInfo& dst,
const std::string& opName,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs) {
return std::make_shared<jit::PythonFutureWrapper>(
pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds)); # 内置函数
},
py::call_guard<py::gil_scoped_acquire>());
module.def(
"_invoke_rpc_python_udf",
[](const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors,
const float rpcTimeoutSeconds,
const bool isAsyncExecution) {
return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
dst,
pickledPythonUDF, # 对应了udf
tensors,
rpcTimeoutSeconds,
isAsyncExecution));
},
py::call_guard<py::gil_scoped_release>());
# 省略其他
}
мы выбираем_invoke_rpc_builtin
Взгляните на соответствующий файл pyRpcBuiltin.
3.2.1 pyRpcBuiltin
Как вы можете видеть в torch/csrc/distributed/rpc/python_functions.cpp, pyRpcBuiltin вызовет sendMessageWithAutograd.
c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs,
const float rpcTimeoutSeconds) {
DCHECK(PyGILState_Check());
Stack stack;
auto op = matchBuiltinOp(opName, args, kwargs, stack);
// Release GIL since args and kwargs processing is done.
py::gil_scoped_release release;
auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
auto agent = RpcAgent::getCurrentRpcAgent(); // 获取当前agent
return toPyJitFuture(sendMessageWithAutograd( // 发送请求
*agent,
dst,
std::move(*scriptCall).toMessage(),
false,
rpcTimeoutSeconds));
}
3.2.2 sendMessageWithAutograd
Используйте агент здесь, в torch/csrc/distributed/autograd/utils.cpp, чтобы отправить FORWARD_AUTOGRAD_REQ.
Позже на получателе мы увидим обработку сообщений FORWARD_AUTOGRAD_REQ, так что отправка и получение могут быть примерно связаны.
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
RpcAgent& agent,
const WorkerInfo& dst,
torch::distributed::rpc::Message&& wrappedRpcMsg,
bool forceGradRecording,
const float rpcTimeoutSeconds,
bool forceDisableProfiling) {
auto msg = getMessageWithAutograd( // 这里会与上下文交互,构建了 FORWARD_AUTOGRAD_REQ
dst.id_,
std::move(wrappedRpcMsg),
MessageType::FORWARD_AUTOGRAD_REQ,
forceGradRecording,
agent.getDeviceMap(dst));
c10::intrusive_ptr<JitFuture> fut;
// If profiler is enabled, wrap this message with profiling metadata that will
// tell the remote end to process this request with the profiler enabled.
if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
auto msgWithProfiling = getMessageWithProfiling(
std::move(msg),
rpc::MessageType::RUN_WITH_PROFILING_REQ, //构建消息
std::move(profilerConfig));
// 发送消息
fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
} else {
fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}
return fut;
}
Процесс отправки выглядит следующим образом, в котором sendMessageWithAutograd будет использовать RpcAgent::getCurrentRpcAgent() для получения RpcAgent::currentRpcAgent_, который должен получить глобально установленный агент, а затем отправить его через агент.
rpc.rpc_sync
+
|
|
v
_invoke_rpc_builtin
+
| Python
+---------------------------------------------------------------+
| C++
|
v
pyRpcBuiltin
+
|
|
v
sendMessageWithAutograd(RpcAgent::getCurrentRpcAgent())
+
|
|
| RpcAgent::currentRpcAgent_
| +
| |
| |
| v
| +-----+-----------+
| | TensorPipeAgent | +-----------------------+
| | | | RequestCallbackImpl |
| | cb_ +------------> | |
| | | +-----------------------+
| | |
| | |
+-----------> send +-----------> Will send message to other worker
| |
| |
+-----------------+
Логика принятия 0x04
4.1 Обратный вызов
Когда агент получает сообщение, он вызывает RequestCallback::operator(). Это функция обратного вызова, о которой мы упоминали ранее. Код находится в torch/csrc/distributed/rpc/tensorpipe_agent.cpp.
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipeRead(
pipe,
[this, pipe](
const tensorpipe::Error& error,
Message&& requestMessage,
std::shared_ptr<LazyStreamContext> ctx) mutable {
// Arm for next read
respond(pipe);
uint64_t messageId = requestMessage.id();
increaseCallCount(serverActiveCalls_);
// Defer user RPC UDF run to thread pool
threadPool_.run([this,
pipe,
messageId,
requestMessage{std::move(requestMessage)},
ctx{std::move(ctx)}]() mutable {
c10::intrusive_ptr<JitFuture> futureResponseMessage;
try {
// 这里会调用 RequestCallback 来进行回调逻辑处理
futureResponseMessage = cb_->operator()(requestMessage, ctx);
} catch (const std::exception& /* unused */) {
futureResponseMessage =
c10::make_intrusive<JitFuture>(at::AnyClassType::get());
futureResponseMessage->setError(std::current_exception());
}
// Shortcut if immediately done
if (futureResponseMessage->completed()) {
decreaseCallCount(serverActiveCalls_);
sendCompletedResponseMessage(
pipe, *futureResponseMessage, messageId, std::move(ctx));
} else {
// Not complete yet
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback(
[this, pipe, messageId, ctx{std::move(ctx)}](
JitFuture& futureResponseMessage) mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId, std::move(ctx));
});
}
});
});
}
4.2 operator()
В operator() будет вызываться processMessage для обработки сообщения.
c10::intrusive_ptr<JitFuture> RequestCallback::operator()(
Message& request,
std::shared_ptr<LazyStreamContext> ctx) const {
// NB: cannot clear autograd context id here because the processMessage method
// might pause waiting for all RRefs in the arguments to be confirmed by their
// owners and resumne processing in a different thread. Hence, the
// thread_local context id needs to be set and cleared in the thread that
// indeed carries out the processing logic.
return processMessage(request, std::move(ctx));
}
Затем он будет вызываться RequestCallbackNoPython::processMessage.
- Сначала вызовите команду deserializePythonRpcCommand, реализованную в RequestCallbackImpl, чтобы десериализовать PythonUDF.
- Затем вызовите processRpcWithErrors для обработки сообщения.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
Message& request,
std::shared_ptr<LazyStreamContext> ctx) const {
// We need two futures here because it could pause twice when processing a
// RPC message:
// 1) waiting for all RRefs in the arguments to become confirmed;
// 2) waiting for processRpc to finish.
auto retFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
auto& rrefContext = RRefContext::getInstance();
try {
rrefContext.recordThreadLocalPendingRRefs();
// Deserialize PythonUDF here to trigger RRef unpickling
// 调用 RequestCallbackImpl 中实现的 deserializePythonRpcCommand 来对 PythonUDF 反序列化
std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
deserializeRequest(request), request.type()); // 解析请求
auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();
rrefsReadyFuture->addCallback(
[this,
retFuture,
// std::function must be copyable, hence hae to cast the unique_ptr to
// a shared_ptr here.
rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
messageType = request.type(),
id = request.id(),
ctx = std::move(ctx)](JitFuture& /* unused */) mutable {
c10::MultiStreamGuard guard(
ctx ? ctx->getReservedStreams() : ArrayRef<Stream>({}));
// The cost of pre-request check is minimal thanks to
// std::shared_lock. The cost is in magnitude
// of 10us.
auto serverProcessGlobalProfilerStateStackEntryPtr =
profiler::processglobal::StateStackEntry::current();
// If server global profiler is enabled, we futher pay the
// cost of thread local profiler state initialization.
if (serverProcessGlobalProfilerStateStackEntryPtr) {
// Initialize thread-local profiler state from process-global
// profiler state.
::torch::autograd::profiler::enableProfilerLegacy(
serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
->config());
}
// 在这里
processRpcWithErrors(
*rpc, messageType, id, retFuture, std::move(ctx));
// Response message has been sent at this moment, this post-response
// work doesn't affect RPC trip time.
if (serverProcessGlobalProfilerStateStackEntryPtr) {
// Restore thread-local profiler state.
::torch::autograd::profiler::thread_event_lists event_lists =
::torch::autograd::profiler::disableProfilerLegacy();
// Put thread_local event_lists into the process-global profiler
// state.
profiler::processglobal::pushResultRecursive(
serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
}
});
} catch (std::exception& e) {
retFuture->markCompleted(handleError(e, request.type(), request.id()));
rrefContext.clearRecordedPendingRRefsOnError();
}
return retFuture;
}
Затем вызовите processRpcWithErrors.
void RequestCallbackNoPython::processRpcWithErrors(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const {
try {
processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx));
} catch (std::exception& e) {
responseFuture->markCompleted(handleError(e, messageType, messageId));
}
}
Далее идет процессRpc. Здесь вы можете увидеть обработку FORWARD_AUTOGRAD_REQ.
void RequestCallbackNoPython::processRpc(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const {
case MessageType::FORWARD_AUTOGRAD_REQ: { // 这里就和之前发送的对应上了
processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
return;
}
case MessageType::BACKWARD_AUTOGRAD_REQ: {
processBackwardAutogradReq(rpc, messageId, responseFuture);
return;
};
}
детали следующим образом:
TensorPipeAgent RequestCallback RequestCallbackNoPython RequestCallbackImpl
+ + + +
| | | |
| | | |
v | | |
respond | | |
+ | | |
| | | |
| | | |
v v v |
cb_->operator() +--> operator() +--> processMessage |
+ |
| |
| v
+---------------> deserializePythonRpcCommand
|
|
|
v
processRpcWithErrors
+
|
|
v
processRpc
+
|
|
v
processForwardAutogradReq
4.3 RequestCallbackImpl
В это время у читателей возникнут вопросы. До того, как TensorPipeAgent четко установил RequestCallbackImpl в качестве функции обратного вызова, зачем вызывать только его deserializePythonRpcCommand? DeserialXXX, похоже, связан с сериализацией. Говорят, что некоторые функции бизнес-обработки должны быть вызваны, такие как processXXXX и как. Далее рассмотрим RequestCallbackImpl.
RequestCallbackImpl определен в torch/csrc/distributed/rpc/request_callback_impl.h.
class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
public:
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
std::unique_ptr<RpcCommandBase> rpc,
const MessageType& messageType) const override;
void processPythonCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
void processScriptCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
void processScriptRemoteCall(
ScriptRemoteCall& scriptRemoteCall,
const std::function<void(void)>& postProcessing,
std::vector<at::IValue>& stack,
const c10::intrusive_ptr<OwnerRRef>& ownerRRef) const override;
void processPythonRemoteCall(
RpcCommandBase& rpc,
const std::function<void(Message)>& markComplete,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const override;
void processRpcWithErrors(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const override;
void processRRefBackward(
RpcCommandBase& rpc,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture) const override;
};
Поскольку RequestCallbackImpl окончательно сгенерирован, на самом деле в середине приведенного выше рисунка есть шаг: ProcessRpcWithErrors фактически вызывает функцию processRpcWithErrors в RequestCallbackImpl, что добавляет некоторую логику обработки исключений.
void RequestCallbackImpl::processRpcWithErrors(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const {
try {
processRpc(rpc, messageType, messageId, responseFuture, std::move(ctx));
} catch (py::error_already_set& e) {
responseFuture->markCompleted(handleError(e, messageType, messageId));
py::gil_scoped_acquire acquire;
e.restore(); // Release ownership on py::objects and also restore
// Python Error Indicator.
PyErr_Clear(); // Clear the Python Error Indicator as we has
// recorded the exception in the response message.
} catch (std::exception& e) {
responseFuture->markCompleted(handleError(e, messageType, messageId));
}
}
Логическая схема изменена следующим образом:
TensorPipeAgent RequestCallback RequestCallbackNoPython RequestCallbackImpl
+ + + +
| | | |
| | | |
v | | |
respond | | |
+ | | |
| | | |
| | | |
v v v |
cb_->operator() +--> operator() +--> processMessage |
+ |
| |
| v
+----------------> deserializePythonRpcCommand
| +
| |
| |
| v
|
+----------------> processRpcWithErrors
| +
| |
| |
| <------------------------+
|
|
v
processRpc
+
|
|
v
processForwardAutogradReq
Если совместить с предыдущей передачей, то расширяем схему следующим образом:
- rpc.rpc_sync вызывается, когда отправителю необходимо удаленно запустить автоматические вычисления градиента.
- Эта функция вызывается из Python в мир C++ и называется pyRpcBuiltin.
- Вызовите sendMessageWithAutograd, чтобы уведомить Receiver.
- RpcAgent::getCurrentRpcAgent() будет вызываться для получения локального агента.
- Вызов функции отправки текущего агента.
- Функция отправки отправляет FORWARD_AUTOGRAD_REQ рабочему процессу Receiver.
- Функция ответа вызовет функцию обратного вызова cb_ агента в приемнике.
- Вызов processRpcWithErrors для RequestCallbackImpl.
- Затем вызовите processRpc.
- Наконец, вызывается processForwardAutogradReq для завершения процесса запуска распределенного автограда на основе RPC.
+
rpc.rpc_sync Sender | Receiver
+ |
| |
| 1 |
v |
_invoke_rpc_builtin |
+ |
| Python |
+----------------------------------------------------------+ |
| C++ | +----------------------------+
| 2 | | RequestCallbackImpl |
v | | |
| +----> processRpcWithErrors |
pyRpcBuiltin | | | + |
+ | | | | 9 |
| 3 | | | | |
| | | | v |
v | | | processRpc |
4 | | | + |
sendMessageWithAutograd(RpcAgent::getCurrentRpcAgent()) | | | | 10 |
+ | | | | |
| | | | v |
| | | | processForwardAutogradReq |
| RpcAgent::currentRpcAgent_ | | | |
| + | | +----------------------------+
| | | |
| 5 | | |8 +-----------------+
| v | | | TensorPipeAgent |
| +------+--------+ | | | |
| |TensorPipeAgent| +-------------------+ | +------------+ cb_ |
| | | |RequestCallbackImpl| | | ^ |
| | cb_ +------->+ | | | 7 | |
| | | +-------------------+ | | | |
| | | 6 | | + |
+--------> send +----------------------------------+--------------> respond |
| | FORWARD_AUTOGRAD_REQ | |
| | + | |
+---------------+ | +-----------------+
+
Телефон такой:
На этом знакомство с RPC закончено, в следующей статье мы представим контекстно-зависимые классы управления, так что следите за обновлениями.
0xEE Личная информация
★★★★★★Думая о жизни и технологиях★★★★★★
Публичный аккаунт WeChat:мысли Росси