[Анализ исходного кода] Распределенный PyTorch (2) --- DataLoader для загрузки данных

машинное обучение глубокое обучение

0x00 сводка

Чтобы лучше представить загрузку данных сервера параметров Paracel, мы временно вставляем две загрузки данных PyTorch, в основном с распределенной точки зрения. Эта статья — просто закуска, и в будущем будет специальная серия для анализа распространения PyTorch.

Другие статьи из серии серверов параметров:

[Анализ исходного кода] Сервер параметров машинного обучения ps-lite (1) ----- PostOffice

[Анализ исходного кода] Сервер параметров машинного обучения ps-lite(2) ----- Коммуникационный модуль Van

[Анализ исходного кода] (3) сервера параметров машинного обучения ps-lite ----- Агент Заказчик

[анализ исходного кода] сервер параметров машинного обучения ps-lite(4) ----- реализация узла приложения

[Анализ исходного кода] Сервер параметров машинного обучения Paracel (1) ----- общая архитектура

[Анализ исходного кода] Сервер параметров машинного обучения Paracel (2) ----- Реализация SSP

[Анализ исходного кода] Распределенный PyTorch (1) --- DistributedSampler загрузки данных

0x01 Обзор предыдущей ситуации

Что касается загрузки данных, мы говорили о DistributedSampler в прошлой книге, а в этой статье далее будет анализироваться DataLoader.

Для лучшей иллюстрации сначала приведем приведенную выше схему конвейера, которая будет уточнена в этой статье.

                    +------------+
+--------+          |            |
|        |          | Process 1  |
| Data 1 +--------> |            +------+
|        |          | Load Data  |      |
+--------+          |            |      |
                    +------------+      |
                                        |
                                        |
                                        |
                    +------------+      |        +-----------------------------------+
+--------+          |            |      |        |                                   |
|        |          | Process 2  |      +------> | Pin-memory process                |
| Data 2 +--------> |            |               |                                   |
|        |          | Load Data  +-------------> |                                   |
+--------+          |            |               |        Transfer to Pinned Memory  |
                    +------------+       +-----> |                                   |
                                         |       |                                   |
                                         |       +-----------------------------------+
                                         |
+--------+          +------------+       |
|        |          |            |       |
| Data 3 +--------> | Process 3  +-------+
|        |          |            |
+--------+          | Load Data  |
                    |            |
                    +------------+
​

Во-вторых, кратко рассмотрим общую логику загрузки данных, как показано на рисунке ниже:

  1. DataSet отправляет количество наборов данных в DistributedSampler.
  2. Sampler формирует индексы данных по определенным правилам и отправляет их в DataLoader.
  3. DataLoader загружает данные из DataSet в соответствии с индексами (его внутренний объект DataLoaderIter отвечает за координацию однопроцессной/многопроцессной загрузки Dataset).
  4. DataLoader отправляет данные в модель для обучения.
+------------------------+                     +-----------+
|DistributedSampler      |                     |DataLoader |
|                        |     2 indices       |           |
|    Some strategy       +-------------------> |           |
|                        |                     |           |
|-------------+----------|                     |           |
              ^                                |           |  4 data  +-------+
              |                                |       -------------->+ train |
            1 | length                         |           |          +-------+
              |                                |           |
+-------------+----------+                     |           |
|DataSet                 |                     |           |
|        +---------+     |      3 Load         |           |
|        |  Data   +-------------------------> |           |
|        +---------+     |                     |           |
|                        |                     |           |
+------------------------+                     +-----------+

Далее мы официально входим в DataLoader.

0x02 DataLoader

Роль DataLoader: после объединения набора данных и семплера он предоставляет итератор для набора данных.

Это можно понять так:

DataSet — это исходные данные, Sampler предоставляет стратегию разделения данных (или предоставляет размерность разделения данных), DataLoader основан на стратегии конкретной работы, в которой загрузка одного процесса означает, что работает один человек, загрузка нескольких процессов заключается в привлечении нескольких человек к совместной работе.

2.1 Инициализация

Основные параметры инициализации следующие:

  • набор данных (Dataset): набор данных для загрузки.
  • batch_size (целое число, необязательный): сколько сэмплов загружается в пакет.
  • shuffle (bool, необязательный): если True, данные будут перемешиваться каждую эпоху.
  • Sampler (Sampler или Iterable, необязательный): определяет стратегию отбора проб из выборки. может быть любая реализация__len__итератор .
  • batch_sampler (Sampler или Iterable, необязательно): сsamplerАналогично, но возвращает пакет индексов данных за раз.
  • num_workers (int, необязательный): количество дочерних процессов для загрузки данных. Если он равен 0, это означает загрузку данных из основного процесса.
  • collate_fn (вызываемый, необязательный): объединить список выборок из мини-пакетного тензора. Используется при массовой загрузке из наборов данных в стиле карты.
  • pin_memory (bool, необязательный): если true, скопируйте тензор в закрепленную память CUDA перед возвратом тензора.
  • drop_last (bool, необязательный): если набор данных не может быть разделен равномерно, если это правда, удалить последний неполный пакет. Если False, последняя партия имеет меньший объем данных.
  • тайм-аут (числовой, необязательный): если целое число, значение тайм-аута для рабочего процесса для сбора пакетных данных.
  • worker_init_fn (вызываемый, необязательный): если он не равен нулю, он будет вызываться каждым дочерним процессом перед заполнением и загрузкой данных с идентификатором Iworker ([0, num_workers - 1]) в качестве входного параметра.
  • генератор (torch.Generator, необязательный): если он не равен нулю, он используется RandomSampler для генерации случайных индексов, а также используется несколькими процессами для генерацииbase_seed.
  • prefetch_factor (целое число, необязательный аргумент, состоящий только из ключевых слов): количество выборок для предварительной выборки на одного исполнителя.
  • персистентные_воркеры (логический, необязательный): еслиTrue, загрузчик данных не остановит рабочий процесс после его однократного использования. Это позволяет работникамDatasetЭкземпляр остается активным.

Конкретный код инициализации выглядит следующим образом, в основном для различных настроек.Для лучшего понимания код обработки исключений удален:

class DataLoader(Generic[T_co]):
​
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False
​
    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")
​
        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context
​
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # 省略异常处理
        else:
            self._dataset_kind = _DatasetKind.Map
​
        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            # 省略异常处理
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')
​
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  
                else:
                    sampler = SequentialSampler(dataset) 
​
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
​
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator
​
        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert
​
        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers
        self.__initialized = True
        self._IterableDataset_len_called = None 
        self._iterator = None
        self.check_worker_number_rationality()
​

2.2 Ключевые функции

Одной из ключевых функций здесь является _index_sampler, которая используется для создания сэмплера вызова итератора, о котором мы поговорим далее.

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
​

2.3 Загрузка одного процесса

В однопроцессном режиме Data Loader будет загружать данные в процессе расчета, поэтому расчет может быть заблокирован в процессе загрузки.

Оператор for вызывает перечисление и возвращает итератор для обхода набора данных. В eumerate загрузчик данных__next__(self)Метод будет вызываться для выборки следующих объектов один за другим, таким образом проходя по набору данных.

    cuda0 = torch.device('cuda:0')  # CUDA GPU 0
    for i, x in enumerate(train_loader):
        x = x.to(cuda0)

2.3.1 Дифференциальная генерация

При многопроцессной загрузке итератор создается только один раз в течение цикла объявления DataLoader, чтобы рабочие могли повторно использовать итератор.

При загрузке одного процесса он должен создаваться каждый раз, чтобы избежать сброса состояния.

    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0: # 如果是多进程或者设置了持久化
            if self._iterator is None: # 如果没有,才会新生成
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else: # 单进程
            return self._get_iterator() # 每次都直接生成新的
​

Конкретное поколение будет различаться в зависимости от того, является ли оно многопроцессорным.

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

2.3.2 Базовый класс итератора

_BaseDataLoaderIter — это базовый класс итератора, давайте выберем ключевые функции для просмотра.

Ключевыми переменными-членами здесь являются:

  • _index_sampler: здесь устанавливается сэмплер загрузчика, поэтому итератор может получить соответствующую стратегию выборки.
  • _sampler_iter: получить итератор сэмплера.
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # 初始化参数
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler # 得到采样策略
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler) # 得到sampler的迭代器
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
​
​
    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data() # 获取数据
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                    # 忽略错误提示处理
                warnings.warn(warn_msg)
            return data

2.3.3 Итератор одного процесса

_SingleProcessDataLoaderIterнаследовать_BaseDataLoaderIter, видно, что увеличивается_dataset_fetcher, переданный во время строительства_collate_fnи другие параметры.

Помните,__next__позвонюself._next_data()чтобы получить данные, и здесь,_next_dataбудет:

  • использоватьself._next_index(), который, в свою очередь, использует_sampler_iter(Итератор сэмплеров) для получения индексов.
  • использоватьself._dataset_fetcher.fetch(index)получить данные на основе индексов.
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
​
        # 获取样本方法
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
​
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        # 获取样本
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
    
    def _next_index(self): # 得到indices
        return next(self._sampler_iter)  # may raise StopIteration    

2.3.4 Получение образцов

Давайте посмотрим, как получить образец дальше. Это передача индекса в сборщик, чтобы получить желаемый образец.

Сборщик создается следующим образом, когда инициализируется _SingleProcessDataLoaderIter:

class _DatasetKind(object):
    Map = 0
    Iterable = 1
​
    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
​

Для Map-стиля используйте обработку _MapDatasetFetcher, то есть используйте возможно_batched_index для извлечения данных из набора данных, возможно_batched_index является ключом.

Если доступен пакетный пробоотборник, используйте пакетный пробоотборник.

Если вам нужно объединить список выборок из мини-пакетного тензора. Просто используйте постобработку collate_fn.

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
​
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 如果配置了batch_sampler,_auto_collation就为True,
            # 那么就优先使用batch_sampler,此时fetcher中传入的就是一个batch的索引
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

Для Iterable-стиля, потому что__init__Начальный итератор набора данных задается в методе, поэтому при выборке элементов в методе выборки, если это обычный сэмплер, индекс не имеет значения и получается непосредственно из итератора набора данных. Если это пакетный сэмплер, индекс имеет значение.

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
​
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 即auto_collation为True,表示使用batch_sampler。
            # 则使用possibly_batched_index,获取1个batch大小的样本       
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # sampler则直接往后遍历,提取1个样本
            data = next(self.dataset_iter)
        return self.collate_fn(data)

Общая логика на данном этапе следующая:

     +--------------------------+            +-------------------------------+
     | DataLoader               |            | _SingleProcessDataLoaderIter  |
     |                          |            |                               |
     |                          |            |               __next__        |
+---------------+ Sampler       |            |                               |
|    |                          |            |              _next_data +-----------+
|    |            Dataset       |            |                               |     |
|    |                          |            |              _next_index      |     |
|    |           __iter__       |            |                               |     |
|    |                          |            |             _index_sampler    |     |
|    |       _get_iterator  +--------------> |                    +          |     |
|    |                          |            |                    |          |     |
|    +--------------------------+            +-------------------------------+     |
|                                                                 |                |
|                                                                 |                |
|                                                                 |                |
|                                                                 |                |
|                                                                 |                |
|                           +----------------------------+        |                |
|                           |Sampler                     |        |                |
+------------------------>  |                            | <------+                |
                            |                            |                         |
                            |                            |                         |
                            |                            |                         |
                            +----------------------------+                         |
                                                                                   |
                                                                                   |
                            +----------------------------+                         |
                            |_BaseDatasetFetcher         |                         |
                            |                            |                         |
                            |                            |                         |
                            |          dataset           |                         |
                            |                            |  <----------------------+
                            |          collate_fn        |
                            |                            |
                            +----------------------------+
​

Динамический процесс выглядит следующим образом:

  User              DataLoader    _SingleProcessDataLoaderIter _DatasetKind   Sampler
​
    +                   +                    +                        +           +
    |                   |                    |                        |           |
    |         1         |                    |                        |           |
 enumerate-------->  __iter__                |                        |           |
    |                   +                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |          2         v            3           v           |
    |              _get_iterator--------> __init__  +----------> create_fetcher   |
    |         4         |                    +                        +           |
    | <-----------------+                    |                        |           |
    |      iterator     |                    |                        |           |
    |                   |          5         |                        |           |
for loop +------------------------------> __next__                    |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                _next_data                   |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |           6  next      |           |
    |                   |                _next_index  +-------------------------> |
    |                   |                    |                        |           |
    |                   |                    |  <---------------------------------+
    |                   |                    |           7  index     |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |        8 fetch(index)  |           |
    |                   |                    | +--------------------> |           |
    |                   |                    |                        |           |
    |                   |                    |  <---------------------+           |
    |                   |                    |         9  data        |           |
    |  <-------------------------------------+                        |           |
    |   10  data        |                    |                        |           |
    |                   |                    |                        |           |
    v                   v                    v                        v           v
​
​

2.4 Многопроцессная загрузка

Для ускорения PyTorch обеспечивает многопроцессорную загрузку, если параметрnum_workersЕсли установлено положительное целое число, система будет генерировать многопроцессную обработку соответственно, в этом режиме каждый рабочий процесс является независимым процессом.

Из предыдущего раздела мы можем знать, что _SingleProcessDataLoaderIter — это ядро ​​однопроцессной загрузки данных, посредством которого загрузчик взаимодействует с семплером и набором данных. В многопроцессорном режиме это ядро ​​соответствует _MultiProcessingDataLoaderIter.

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
​

Далее мы начнем анализ с _MultiProcessingDataLoaderIter.

2.4.1 Общая логика

Комментарии в _MultiProcessingDataLoaderIter очень подробны и их стоит прочитать, а логическая блок-схема представлена ​​следующим образом: Основной процесс выполняется вокруг трех очередей:

  • Основной процесс помещает индекс данных, которые необходимо получить, в index_queue, которая представляет собой очередь, указывающую, какие данные необходимо получить дочернему процессу. При этом очередь результатов также передается дочернему процессу.Относительно очереди результатов есть две ветки:

    • Если установлена ​​память контактов, передается worker_result_queue.
    • В противном случае перейдите в data_queue.
  • Дочерний процесс считывает индекс из index_queue, считывает данные, а затем помещает индекс прочитанных данных в worker_result_queue, очередь, которая возвращает результат основному процессу.

  • Обрабатывается основной процесс, здесь есть две ветки:

    • Если установлена ​​память пинов, то pin_memory_thread основного процесса будет считывать индекс данных из worker_result_queue, читать данные по индексу, обрабатывать их и помещать результат в data_queue, которая является очередью для обработки результата.
    • Если память контактов не требуется, результат уже хранится в data_queue, и новая операция не выполняется.

Как видите, вход каждого процесса — это очередь index_queue, а выход — очередь worker_result_queue. Основной процесс и дочерний процесс связаны через эти 2~3 очереди, чтобы добиться эффекта развязки и ускорения.

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               /
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
​

Как показано на рисунке ниже, если память контактов не требуется, это:

                                               +-----------+
               indices  -------------+ indices | Worker    | Data
             +--------->+index queue +-------->+ Process   +------+
             |          |            |         |           |      |
             |          -------------+         +-----------+      |
             |                                                    |   +------------+
             |                                                    |   |            |
+---------+  |                                                    +--->            |
| Main    |  | indices  -------------+ indices +-----------+          |            |
| Process +------------>+index queue +-------->+ Worker    | Data     | Data Queue |
|         |  |          |            |         | Process   +---------->            |
+---------+  |          -------------+         |           |          |            |
             |                                 +-----------+      +--->            |
             |                                                    |   +------------+
             |                                                    |
             | indices  -------------+ indices +-----------+      |
             +--------->+index queue +-------->+ Worker    | Data |
                        |            |         | Process   +------+
                        -------------+         |           |
                                               +-----------+
​

При наличии памяти пинов она сначала попадет в очередь результатов, а затем после обработки в очередь данных будет передана цепочка pin_memory_thread:

                                               +-----------+
               indices  -------------+ indices | Worker    | Data
             +--------->+index queue +-------->+ Process   +------+
             |          |            |         |           |      |
             |          -------------+         +-----------+      |
             |                                                    |   --------------+
             |                                                    |   |             |
+---------+  |                                                    +--->             |
| Main    |  | indices  -------------+ indices +-----------+          |             |
| Process +------------>+index queue +-------->+ Worker    | Data     | result_queue|
|         |  |          |            |         | Process   +---------->             |
+---------+  |          -------------+         |           |          |             |
             |                                 +-----------+      +--->             |
             |                                                    |   ---------+----+
             |                                                    |            |
             | indices  -------------+ indices +-----------+      |            |
             +--------->+index queue +-------->+ Worker    | Data |  +---------+--------+
                        |            |         | Process   +------+  | pin_memory_thread|
                        -------------+         |           |         |         |        |
                                               +-----------+         |         |        |
                                                                     |         |        |
                                                                     +------------------+
                                                                               |
                                                                               |
                                                                               |
                                                                               v
                                                                         +-----+------+
                                                                         | Data Queue |
                                                                         |            |
                                                                         +------------+
​

2.4.2 Инициализация

Функция инициализации в основном выглядит следующим образом:

  • Настройте, создайте различные переменные-члены и настройте различные очереди.
  • Запустите каждый дочерний процесс.
  • Поток, запускающий pin_memory в основном процессе.

Основные переменные-члены:

  • _index_queues: это список очереди, каждый элемент списка представляет собой очередь, которая является индексом данных, которые должны быть обработаны очередью каждого дочернего процесса.
  • _worker_result_queue: (idx, data) обрабатывается дочерним процессом.
  • data_queue: Очередь данных обрабатывается потоком основного процесса pin_memory, если пин не нужен, он будет использоваться напрямую_worker_result_queue.
  • _worker_queue_idx_cycleРабочий находил следующую работу.

Конкретный код выглядит следующим образом:

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
​
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
​
        assert self._num_workers > 0
        assert self._prefetch_factor > 0
​
        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context
​
        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # 子进程输出,读取完数据的index
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()
​
        self._index_queues = [] # 子进程输入,需读取数据的index
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop, # worker进程主函数,把各种queue和函数传进去
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue) # 把这个worker对应的index_queue放到主进程这里存起来,以后就可以交互了
            self._workers.append(w)
​
        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
​
            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # pin 处理之后的数据结果
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue # 如果不需要pin,则直接使用_worker_result_queue
​
        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        
        self._reset(loader, first_iter=True) # 继续完善业务

2.4.3 Сброс бизнеса

__init__В конце функция вызовет функцию _reset, которая предназначена для дальнейшего улучшения инициализации бизнеса, а также для сброса среды.

В функции из предыдущего раздела рабочий подпроцесс запущен, но задача не выделена, поэтому функция _reset выполнит выделение задачи и предварительную выборку.

_MultiProcessingDataLoaderIter имеет следующие параметры флага для координации работы между воркерами (включая различные очереди):

  • _send_idx: Отправьте индекс для записи идентификатора пакета, который на этот раз будет помещен в очередь index_queue.
  • _rcvd_idx: принять индекс, записать idx пакета, который нужно взять из data_queue.
  • _task_info: словарь, в котором хранится информация о данных, которая должна быть сгенерирована, ключ — идентификатор задачи (целочисленный индекс, начинающийся с 0), а значение —(worker_id,)или(worker_id, data), соответствующие невыбранным и извлеченным данным соответственно
  • _tasks_outstanding: Целое число, представляющее количество подготовленных задач/пакетов (некоторые из них могут находиться в стадии подготовки).
  • _send_idx: отправьте индекс и запишите idx пакета задач в index_queue в следующий раз.
  • _rcvd_idx: принять индекс и записать идентификатор следующего пакета задач, который будет взят из очереди данных._send_idxи_rcvd_idxВ основном используется для управления потоком и для обеспечения того, чтобы принятие индексов имело смысл.
  • _task_info: словарь, в котором хранится информация о данных, которая должна быть сгенерирована, ключом является идентификатор пакета задач (целочисленный индекс, начинающийся с 0), а значение равно(worker_id,)или(worker_id, data), соответствующие невыбранным и извлеченным данным соответственно._task_infoФункция состоит в том, чтобы получить соответствующий рабочий идентификатор и временно сохранить неупорядоченные данные в соответствии с идентификатором пакета задач.
  • _tasks_outstanding: Integer, количество подготавливаемых задач/пакетов, по сути, это выполнение какой-то подтверждающей работы, что не очень практично.

Для загрузки данных каждый рабочий процесс генерирует по одному пакету данных за раз.Перед возвратом пакетных данных он помещает нижний индекс данных для обработки в следующий пакет, поэтому функция сброса помещает данные в следующий пакет._send_idx,_rcvd_idxВсе возвращаются к 0, так что следующая итерация может повторить обработку.

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Not that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        # 每个worker的状态
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                return_idx, return_data = self._get_data()
                if isinstance(return_idx, _utils.worker._ResumeIteration):
                    assert return_data is None
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        
        # 预取若干index
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()
​

2.4.4 Получить индекс

Функция _try_put_index должна использовать сэмплер для получения следующего пакета индекса данных. Здесь значение _prefetch_factor по умолчанию равно 2, а основная логика такова.

  • Получить индекс следующей партии из сэмплера.
  • Найдите следующего доступного воркера через _worker_queue_idx_cycle и присвойте ему индекс.
  • И настроить информацию основного процесса.
    def _next_index(self): # 定义在基类 _BaseDataLoaderIter 之中,就是获取下一批index
        return next(self._sampler_iter)  # may raise StopIteration
​
    def _try_put_index(self):
        
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
​
        try:
            index = self._next_index() # 获取下一批index
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]: # 如果已经工作,就继续找
                break
        else:
            # not found (i.e., didn't break)
            return
​
        # 以下是主进程进行相关记录
        # 给下一个工作worker放入 (任务index, 数据index), 就是给queue放入数据,所以worker loop之中就立刻会从queue中得到index,从而开始获取数据。
        self._index_queues[worker_queue_idx].put((self._send_idx, index)) 
        # 记录 将要产生的 data 信息
        self._task_info[self._send_idx] = (worker_queue_idx,)
        # 正在处理的batch个数+1
        self._tasks_outstanding += 1
        # send_idx 记录从sample_iter中发送索引到index_queue的次数
        self._send_idx += 1 # 递增下一批发送的task index
​

2.4.5 Основная функция работника

_worker_loop — это основная функция рабочего процесса, и основная логика показана в его комментариях:

    # [ worker processes ]
    #   While loader process is alive:
    #     Get from `index_queue`.
    #       If get anything else,
    #          Check `workers_done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data:
    #                If is fetching from an `IterableDataset` and the iterator
    #                    is exhausted, send an `_IterableDatasetStopIteration`
    #                    object to signal iteration end. The main process, upon
    #                    receiving such an object, will send `None` to this
    #                    worker and not use the corresponding `index_queue`
    #                    anymore.
    #       If timed out,
    #          No matter `workers_done_event` is set (still need to see `None`)
    #          or not, must continue to next iteration.
    #   (outside loop)
    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
    #                                          main process won't read from it;
    #                                          other workers will also call
    #                                          `cancel_join_thread`.)
​

Это взаимодействие с основным процессом через index_queue, data_queue.

  • Получить новый индекс данных из index_queue;
  • Если конец этого воркера не установлен, используйте fetcher для получения данных.
  • Затем поместите данные в data_queue и уведомите основной процесс, здесь нужно обратить внимание,data_queue — входящий параметр, если установлена ​​память пинов, передается worker_result_queue, в противном случае передается data_queue..
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
​
    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()
​
        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np
            np.random.seed(np_seed)
​
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)
​
        from torch.utils.data import _DatasetKind
​
        init_exception = None
​
        try:
            if init_fn is not None:
                init_fn(worker_id)
​
            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))
​
        iteration_end = False
        watchdog = ManagerWatchdog()
​
        while watchdog.is_alive(): # 等待在这里
            try:
                # _try_put_index 如果放入了数据index,这里就被激活,开始工作
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    # 省略处理代码
            
            data_queue.put((idx, data)) # 放入数据,通知主进程
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()

2.4.6 Pin memory thread

В основном процессе, если требуется контактная память, pin_memory_thread основного процесса будет считывать данные из worker_result_queue, обрабатывать их и помещать результат в data_queue.

Конкретный код выглядит следующим образом:

def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)
​
    torch.cuda.set_device(device_id)
​
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            data = pin_memory(data)
            # 省略异常处理代码
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        del r  # save memory
​
​
def pin_memory(data):
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    elif isinstance(data, string_classes):
        return data
    elif isinstance(data, collections.abc.Mapping):
        return {k: pin_memory(sample) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(sample) for sample in data))
    elif isinstance(data, collections.abc.Sequence):
        return [pin_memory(sample) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data

2.4.7 Пользователь получает данные

Теперь, когда данные загружены, давайте посмотрим, как пользователь получает данные из DataLoader.

Здесь есть ключевой момент: как сохранить согласованность порядка чтения данных в разных экспериментах. Чтобы проводить сравнения между несколькими экспериментами, необходимо обеспечить, чтобы в этих экспериментах порядок считывания данных каждый раз был постоянным, чтобы не вызывать ошибок в результатах из-за данных.

Самая большая возможность нарушить последовательную согласованность — это неупорядоченные данные. Причина проблемы неупорядоченности: многопроцессорное чтение, некоторые процессы могут быть быстрыми, а некоторые — медленными. Например, на этот раз пользователю нужно прочитать 6-19, 16-26, 37-46. Но какой-то рабочий работает медленно, 6-19 не может вернуться сразу, а другой рабочий 16-26 возвращается первым, поэтому это вызовет беспорядок.

Как быть с неупорядоченными данными? Конкретный подход PyTorch таков: DataLoader возвращает данные строго в порядке Sampler. Если какие-то данные не по порядку, они будут временно сохранены, а затем перейдут к получению следующих данных, см. комментарий «сохранить неупорядоченные образцы» в приведенном ниже коде. Подождите, пока он не должен быть возвращен (эта последовательность данных не прибудет), прежде чем возвращаться.

Но есть риск, что возврат данных будет медленнее, чем текущий запрос, например, должно быть получено 6, а таких данных нет в очереди данных, только 16, 27, поэтому пользователь может ждать только 6 загружен.

Решение медлительности: предварительная выборка. То есть в конце метода сброса заранее извлеките ряд индексов и позвольте DataLoader получить их заранее.Хотя это не гарантирует, что порядок возврата данных любых двух тренировок будет точно таким же, это может быть гарантировано в наибольшей степени.

Конкретный код выглядит следующим образом: во-первых, вспомните базовый класс__next__вы можете видеть, что она вызывает _next_data для получения данных.

class _BaseDataLoaderIter(object):
    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data() # 获取数据
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                    # 忽略错误提示处理
                warnings.warn(warn_msg)
            return data
​

Итак, давайте посмотрим_MultiProcessingDataLoaderIterиз_next_data.

  • Поскольку индекс был предварительно загружен ранее, рабочий процесс уже начал получать данные, поэтому основной процесс может получить данные в это время.Если данных нет, он будет продолжать ждать, пока True.
  • Если выборка прошла успешно, используйте_process_dataУстановите следующий индекс и подготовьтесь к следующей итерации.
  • пройти через_task_infoДля записи неупорядоченных данных, если они временно не могут быть обработаны, сохраните их здесь.
    def _next_data(self):
        while True:
            
            # 找到待取idx
            while self._rcvd_idx < self._send_idx: # 如果 待取batch idx < 已取batch idx
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break # 有数据或者正在工作,就跳出内部这个while
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration
​
            # Now `self._rcvd_idx` is the batch index we want to fetch
​
            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data) # 设定下一次的indx,进行下一次迭代
​
            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data() # 从 self._data_queue 中取数据
            self._tasks_outstanding -= 1 # 正在准备的batch个数需要减1
            
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index() 
                    continue
​
            if idx != self._rcvd_idx: # 乱序数据
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx] # 正常数据
                return self._process_data(data) # 设定下一次的indx,进行下一次迭代

Во-вторых, давайте посмотрим на_get_dataКак добраться изself._data_queueизвлечь данные. В частности, используйте _try_get_data для извлечения.

  • Если есть конфигурация тайм-аута, читайте в соответствии с тайм-аутом.
  • Если установлена ​​память штифта, данные считываются из резьбы штифта после обработки.
  • В противном случае данные, обрабатываемые рабочим потоком, считываются в цикле до тех пор, пока данные не будут получены.
    def _get_data(self):
        # Fetches data from `self._data_queue`.
 
        if self._timeout > 0: # 如果有超时配置,就按照超时读取
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory: # 从pin 线程处理之后的数据读取
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data() # 读取worker处理的数据
                if success:
                    return data
​

_try_get_dataиз_data_queueчитать. Основной процесс и рабочий процесс взаимодействуют и взаимодействуют посредством размещения и получения в очереди.

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
            # 省略异常处理代码
            import tempfile
            import errno
            try:
                # Raise an exception if we are this close to the FDs limit.
                # Apparently, trying to open only one file is not a sufficient
                # test.
                # See NOTE [ DataLoader on Linux and open files limit ]
                fds_limit_margin = 10
                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
            except OSError as e:
                # 省略异常处理代码
            raise
​

Установите следующую итерацию для использования_process_data.

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index() # 设定下一次的indx,进行下一次迭代
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data # 返回数据
​

2.4.8 Резюме

Подытожим логику многопроцессорности.

Общая логика следующая:

  • Основной процесс помещает полученный индекс данных в index_queue.
  • Дочерний процесс считывает индекс из очереди index_queue, считывает данные, а затем помещает индекс прочитанных данных в очередь worker_result_queue.
  • pin_memory_thread основного процесса будет считывать индекс данных из worker_result_queue, читать данные по индексу, обрабатывать их и помещать результат в data_queue.

Конкретный процесс выглядит следующим образом:

  1. В функции инициализации _MultiProcessingDataLoaderIter__init__будет инициализирован в:

    • Настройте, создайте различные переменные-члены и настройте различные очереди.
    • Запустите каждый дочерний процесс.
    • Поток, запускающий pin_memory в основном процессе.
    • Вызовите функцию _reset, которая предназначена для дальнейшего улучшения инициализации бизнеса, а также для сброса среды. Рабочий подпроцесс был запущен выше, но задачи не были назначены, поэтому функция сброса назначит задачи.Предварительная загрузка.
  2. Далее идет операция предварительной выборки (обязательно обратите внимание, глядя на рисунок ниже).

    • Функция _try_put_index должна использовать сэмплер для получения следующего пакета индекса данных. Здесь значение _prefetch_factor по умолчанию равно 2, а основная логика такова.

      • Используйте _next_index, чтобы получить индекс следующей партии из сэмплера.
      • Найдите следующего доступного воркера через _worker_queue_idx_cycle и присвойте ему индекс.
      • И настроить информацию основного процесса.
    • Получив индекс, вернитесь к основному потоку. Здесь происходит извлечение данных. Это взаимодействие с основным процессом через index_queue, data_queue.

      • Получить новый индекс данных из index_queue;
      • Если конец этого воркера не установлен, используйте fetcher для получения данных.
      • Затем поместите данные в data_queue и уведомите основной процесс.Здесь следует отметить, что data_queue является входным параметром.Если установлена ​​память пинов, передается worker_result_queue, иначе передается data_queue.
  3. Когда пользователь выполняет итерацию, вызывается базовый класс Loader.__next__function , которая вызывает _next_data для получения данных.

    • использовать_get_dataКак добраться изself._data_queueизвлечь данные.
    • использовать_process_dataУстановите индекс для следующей итерации, т.е. используйте_try_put_index,_next_indexчтобы перейти к следующему раунду настроек.
user        _MultiProcessingDataLoaderIter   Sampler        Queue(index_queue)    Queue(data_queue)    _worker_loop     Fetcher
 +                       +                      +                  +                     +                  +              +
 |                       |                      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       v                      |                  |                     |                  |              |
 |                   __init__                   |                  |                     |                  |              |
 |               1    _reset                    |                  |                     |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       v                      |                  |                     |                  |              |
 |            2   _try_put_index     next       |                  |                     |                  |              |
 |                  _next_index  +------------> |                  |                     |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                       |  <-----------------+ |                  |                     |                  |              |
 |                       |           index      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       | +------------------------------------>  |                     |                  |              |
 |                       |           put        |                  |                     |       get        |              |
 |                       |                      |                  +--------------------------------------> |              |
 |                       |                      |                  |                     |                  |    index     |
 |                       |                      |                  |                     |                  +------------> |
 |         next          |                      |                  |                     |                  | <----------+ |
 +---------------------> |                      |                  |                     | <----------------+    data      |
 |                       |                      |                  |                     |      data        |              |
 |                       +                      |                  |                     |                  |              |
 |                   _next_data                 |                  |                     |                  |              |
 |              3   _get_data          get      |                  |                     |                  |              |
 |                  _try_get_data  +-------------------------------------------------->  |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                       |  <----------------------------------------------------------+ |                  |              |
 |                       |             data     |                  |                     |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                   _process_data              |                  |                     |                  |              |
 |                  _try_put_index     next     |                  |                     |                  |              |
 |                  _next_index +-------------> |                  |                     |                  |              |
 |                       + <--------------------+                  |                     |                  |              |
 |                       |           index      |                  |                     |                  |              |
 |                       +---------------------------------------> |                     |       get        |              |
 | <-------------------+ |             put      |                  +------------------------------------->  |     index    |
 |        data           |                      |                  |                     |                  | +----------> |
 |                       |                      |                  |                     |                  +<-----------+ |
 v                       v                      v                  v                     v                  v     data     v
​

По телефону так:

img

2.5 Pipleline

На данный момент мы усовершенствовали предыдущую диаграмму конвейера следующим образом:

                                                  +------------+
                              +--------+          |            |
                              |        |          | Process 1  |
                      +-----> | Data 1 +--------> |            +------+
                      |       |        |          | Load Data  |      |
                      |       +--------+          |            |      |
                      |                           +------------+      |
                      |                                               |
                      |                                               |
                      |                                               |
+----------------+    |                           +------------+      |                                          +-------------------------+
|Main process    |    |       +--------+          |            |      |                                          |  pin_memory_thread      |
|                |    |       |        |          | Process 2  |      +------>  +------------------------+       |                         |          +------------+
|  index_queue   +----------> | Data 2 +--------> |            |                |                        |       |                         |          |            |
|                |    |       |        |          | Load Data  +------------->  |  _worker_result_queue  +-----> |  Write to pinned memory +--------> | data_queue |
|                |    |       +--------+          |            |                |                        |       |                         |          |            |
+----------------+    |                           +------------+       +----->  |                        |       |                         |          +------------+
                      |                                                |        +------------------------+       |                         |
                      |                                                |                                         +-------------------------+
                      |                                                |
                      |       +--------+          +------------+       |
                      |       |        |          |            |       |
                      +-----> | Data 3 +--------> | Process 3  +-------+
                              |        |          |            |
                              +--------+          | Load Data  |
                                                  |            |
                                                  +------------+
​

Телефон такой:

img

На этом анализ части распределенной загрузки данных PyTorch завершен, в следующей статье мы вернемся к тому, как Paracel обрабатывает загрузку данных.

0xEE Личная информация

★★★★★★Думая о жизни и технологиях★★★★★★

Публичный аккаунт WeChat:мысли Росси

ссылка 0xFF

Модель распараллеливания сверточных нейронных сетей — один странный трюк для распараллеливания сверточных нейронных сетей

Проблемы и решения обработки данных в рамках ИИ

torch.utils.data интерпретации исходного кода PyTorch: весь процесс парсинга обработки данных

Расскажите о своем понимании и осведомленности в области крупномасштабного машинного обучения?

Nvidia-DALI от отказа до входа

pytorch (распределенные) данные параллельной личной практики - DataParallel/DistributedDataParallel