Подробное объяснение основных параметров BN PyTorch

искусственный интеллект
Подробное объяснение основных параметров BN PyTorch

Подробное объяснение основных параметров BN PyTorch

Оригинальный документ:Yuque.com/Pull the head/UG kV9 hair…

BN — это обычная операция и модуль в CNN. В конкретной реализации он содержит несколько параметров. Это также приводит к разным эффектам при различных комбинациях параметров.

affine

Изменено при инициализации

Если для параметра affine установлено значение True, слой BatchNorm изучит параметры gamma и beta, в противном случае эти две переменные не будут включены, а именами переменных будут вес и смещение.

.train()

  • еслиaffine==True, затем выполните аффинное преобразование на нормализованном пакете, то есть умножьте вес внутри модуля (начальное значение равно [1., 1., 1., 1.]), а затем добавьте смещение внутри модуля (начальное значение равно [0., 0., 0., 0.]), эти две переменные обновляются во время обратного распространения.
  • еслиaffine==False, то BatchNorm не содержит двух переменных, веса и смещения, и ничего не делает.

.eval()

  • еслиaffine==True, затем нормированная партия преобразуется радиационно, то есть умножается на вес внутри модуля, а затем добавляется со смещением внутри модуля Эти две переменные изучаются во время обучения сети.
  • еслиaffine==False, то BatchNorm не содержит двух переменных, веса и смещения, и ничего не делает.

Изменить свойства экземпляра

Никакого эффекта, все еще в соответствии с первоначальными настройками.

track_running_stats

Поскольку этот атрибут участвует в прямом распространении BN, модификация атрибута экземпляра повлияет на окончательный процесс расчета.

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True
    ) -> None:
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()
    ...

class _BatchNorm(_NormBase):
    ...

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:  # type: ignore
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.

        可以看到这里的bn_training控制的是,数据运算使用当前batch计算得到的统计量(True)
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).

        这里强调的是统计量buffer的使用条件(self.running_mean, self.running_var)
        - training==True and track_running_stats==False, 这些属性被传入F.batch_norm中时,均替换为None
        - training==True and track_running_stats==True, 会使用这些属性中存放的内容
        - training==False and track_running_stats==True, 会使用这些属性中存放的内容
        - training==False and track_running_stats==False, 会使用这些属性中存放的内容
        """
        assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
        assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight, self.bias, bn_training, exponential_average_factor, self.eps)

.train()

Обратите внимание на комментарии в коде: буферы обновляются только в том случае, если их нужно отслеживать, и мы находимся в режиме обучения.track_running_stats==TrueЭти буферы статистики обновляются каждый раз.

Кроме того, в это времяself.training==True.bn_training=True.

track_running_stats==True

Слой BatchNorm будет подсчитывать глобальное среднее значение running_mean и дисперсию running_var, а при нормализации партии используются только статистические данные текущей партии.

            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

Используйте импульс для обновления running_mean внутри модуля.

  • Если импульс равен None, то используется кумулятивная скользящая средняя (здесь используются атрибутыself.num_batches_trackedдля подсчета количества прошедших пакетов), в противном случае используйте экспоненциальное скользящее среднее (с использованием импульса в качестве коэффициента). Базовая структура формулы обновления для них одинакова:xnew=(1factor)×xcur+factor×xbatchx_{new}=(1 - factor) \times x_{cur} + factor \times x_{batch}

, только конкретныеfactorfactorразные.

  • xnewx_{new}представляет обновленное значение running_mean и running_var;
  • xcurx_{cur}Представляет running_mean и running_var до обновления;
  • x_{batch}\Представляет среднюю и несмещенную выборочную дисперсию текущей партии.
  • Обновление кумулятивной скользящей среднейfactor=1/num_batches_trackedfactor=1/num\_batches\_tracked.
  • Формула обновления для экспоненциальной скользящей средней:factor=momentumfactor=momentum.
Изменить свойства экземпляра

Если установлено.track_running_stats==False,В настоящее времяself.num_batches_trackedне будет обновляться иexponential_average_factorОн также не будет масштабироваться. И из-за:

            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,

и в это времяself.training==Trueself.track_running_stats==False, так что присылайтеF.batch_normизself.running_mean&self.running_varОба параметра равны None.То есть установить в этот момент и непосредственно в инициализации**track_running_stats==False**тот же эффект. Но будьте осторожны здесь~~exponential_average_factor~~Перемена. Однако, поскольку обычно при инициализации BN мы отправляем только~~num_features~~, поэтому по умолчанию он будет использовать~~exponential_average_factor = self.momentum~~для построения экспоненциальной статистики времени выполнения обновления скользящего среднего.(В настоящее времяexponential_average_factorне будет работать)

track_running_stats==False

Тогда BatchNorm не содержит переменных running_mean и running_var, то есть для нормализации батча используется только статистика текущего батча.

            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
Изменить свойства экземпляра

Если установлено.track_running_stats==True,В настоящее времяself.num_batches_trackedВсе еще не будет обновляться, потому что его начальное значение равно None. В целом, такие изменения не имеют реального влияния.

.eval()

В настоящее времяself.training==False.

            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,

отправить в это времяF.batch_normДве статистики буфера и инициализации согласованы.

track_running_stats==True

            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

В настоящее времяbn_training = (self.running_mean is None) and (self.running_var is None) == False. Так что пользуйтесь глобальной статистикой. Нормируйте партию, формулаy=xE^[x]Var^[x]+ϵy=\frac{x-\hat{E}[x]}{\sqrt{\hat{Var}[x]+\epsilon}}, обратите внимание, что среднее значение и дисперсия здесь равныrunning_mean и running_var, рассчитанный при обучении сетиГлобальное среднее и несмещенная выборочная дисперсия.

Изменить свойства экземпляра

Если установлено.track_running_stats==False,В настоящее времяbn_trainingбез изменений, все еще не False, поэтому по-прежнему используйте глобальную статистику. этоself.running_mean, self.running_varсодержимое, хранящееся в . В целом, изменение свойств в настоящее время не имеет никакого эффекта.

track_running_stats==False

            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)

В настоящее времяbn_training = (self.running_mean is None) and (self.running_var is None) == True. Поэтому используйте статистику текущей партии. Нормируйте партию, формулаy=xE[x]Var[x]+ϵy=\frac{x-{E}[x]}{\sqrt{{Var}[x]+\epsilon }}, обратите внимание, что среднее значение и дисперсия здесь равныпакетное собственное среднее значение и переменная, running_mean и running_var в настоящее время не включены в BatchNorm. Обратите внимание, что в это время используется несмещенная выборочная дисперсия (отличная от обучения), поэтому, если batch_size=1, знаменатель будет равен 0, и будет сообщено об ошибке.

Изменить свойства экземпляра

Если установлено.track_running_stats==True,В настоящее времяbn_trainingБез изменений, по-прежнему True, поэтому по-прежнему используется статистика текущего пакета. то есть игнорироватьself.running_mean, self.running_varсодержимое, хранящееся в . Поведение в этот момент такое же, как если бы оно не было изменено.

Резюме

Скриншот изображения из оригинального документа.

Ссылаться на