Подробное объяснение основных параметров BN PyTorch
Оригинальный документ:Yuque.com/Pull the head/UG kV9 hair…
BN — это обычная операция и модуль в CNN. В конкретной реализации он содержит несколько параметров. Это также приводит к разным эффектам при различных комбинациях параметров.
affine
Изменено при инициализации
Если для параметра affine установлено значение True, слой BatchNorm изучит параметры gamma и beta, в противном случае эти две переменные не будут включены, а именами переменных будут вес и смещение.
.train()
- если
affine==True, затем выполните аффинное преобразование на нормализованном пакете, то есть умножьте вес внутри модуля (начальное значение равно [1., 1., 1., 1.]), а затем добавьте смещение внутри модуля (начальное значение равно [0., 0., 0., 0.]), эти две переменные обновляются во время обратного распространения. - если
affine==False, то BatchNorm не содержит двух переменных, веса и смещения, и ничего не делает.
.eval()
- если
affine==True, затем нормированная партия преобразуется радиационно, то есть умножается на вес внутри модуля, а затем добавляется со смещением внутри модуля Эти две переменные изучаются во время обучения сети. - если
affine==False, то BatchNorm не содержит двух переменных, веса и смещения, и ничего не делает.
Изменить свойства экземпляра
Никакого эффекта, все еще в соответствии с первоначальными настройками.
track_running_stats
Поскольку этот атрибут участвует в прямом распространении BN, модификация атрибута экземпляра повлияет на окончательный процесс расчета.
class _NormBase(Module):
"""Common base of _InstanceNorm and _BatchNorm"""
_version = 2
__constants__ = ['track_running_stats', 'momentum', 'eps',
'num_features', 'affine']
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
# WARNING: weight and bias purposely not defined here.
# See https://github.com/pytorch/pytorch/issues/39670
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True
) -> None:
super(_NormBase, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
self.reset_parameters()
...
class _BatchNorm(_NormBase):
...
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None: # type: ignore
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
可以看到这里的bn_training控制的是,数据运算使用当前batch计算得到的统计量(True)
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
这里强调的是统计量buffer的使用条件(self.running_mean, self.running_var)
- training==True and track_running_stats==False, 这些属性被传入F.batch_norm中时,均替换为None
- training==True and track_running_stats==True, 会使用这些属性中存放的内容
- training==False and track_running_stats==True, 会使用这些属性中存放的内容
- training==False and track_running_stats==False, 会使用这些属性中存放的内容
"""
assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor)
assert self.running_var is None or isinstance(self.running_var, torch.Tensor)
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
.train()
Обратите внимание на комментарии в коде: буферы обновляются только в том случае, если их нужно отслеживать, и мы находимся в режиме обучения.track_running_stats==TrueЭти буферы статистики обновляются каждый раз.
Кроме того, в это времяself.training==True.bn_training=True.
track_running_stats==True
Слой BatchNorm будет подсчитывать глобальное среднее значение running_mean и дисперсию running_var, а при нормализации партии используются только статистические данные текущей партии.
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
Используйте импульс для обновления running_mean внутри модуля.
- Если импульс равен None, то используется кумулятивная скользящая средняя (здесь используются атрибуты
self.num_batches_trackedдля подсчета количества прошедших пакетов), в противном случае используйте экспоненциальное скользящее среднее (с использованием импульса в качестве коэффициента). Базовая структура формулы обновления для них одинакова:
, только конкретныеразные.
- представляет обновленное значение running_mean и running_var;
- Представляет running_mean и running_var до обновления;
- x_{batch}\Представляет среднюю и несмещенную выборочную дисперсию текущей партии.
- Обновление кумулятивной скользящей средней.
- Формула обновления для экспоненциальной скользящей средней:.
Изменить свойства экземпляра
Если установлено.track_running_stats==False,В настоящее времяself.num_batches_trackedне будет обновляться иexponential_average_factorОн также не будет масштабироваться.
И из-за:
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
и в это времяself.training==True,иself.track_running_stats==False, так что присылайтеF.batch_normизself.running_mean&self.running_varОба параметра равны None.То есть установить в этот момент и непосредственно в инициализации**track_running_stats==False**тот же эффект.
Но будьте осторожны здесь~~exponential_average_factor~~Перемена. Однако, поскольку обычно при инициализации BN мы отправляем только~~num_features~~, поэтому по умолчанию он будет использовать~~exponential_average_factor = self.momentum~~для построения экспоненциальной статистики времени выполнения обновления скользящего среднего.(В настоящее времяexponential_average_factorне будет работать)
track_running_stats==False
Тогда BatchNorm не содержит переменных running_mean и running_var, то есть для нормализации батча используется только статистика текущего батча.
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
Изменить свойства экземпляра
Если установлено.track_running_stats==True,В настоящее времяself.num_batches_trackedВсе еще не будет обновляться, потому что его начальное значение равно None.
В целом, такие изменения не имеют реального влияния.
.eval()
В настоящее времяself.training==False.
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
отправить в это времяF.batch_normДве статистики буфера и инициализации согласованы.
track_running_stats==True
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
В настоящее времяbn_training = (self.running_mean is None) and (self.running_var is None) == False. Так что пользуйтесь глобальной статистикой.
Нормируйте партию, формула, обратите внимание, что среднее значение и дисперсия здесь равныrunning_mean и running_var, рассчитанный при обучении сетиГлобальное среднее и несмещенная выборочная дисперсия.
Изменить свойства экземпляра
Если установлено.track_running_stats==False,В настоящее времяbn_trainingбез изменений, все еще не False, поэтому по-прежнему используйте глобальную статистику. этоself.running_mean, self.running_varсодержимое, хранящееся в .
В целом, изменение свойств в настоящее время не имеет никакого эффекта.
track_running_stats==False
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
В настоящее времяbn_training = (self.running_mean is None) and (self.running_var is None) == True. Поэтому используйте статистику текущей партии.
Нормируйте партию, формула, обратите внимание, что среднее значение и дисперсия здесь равныпакетное собственное среднее значение и переменная, running_mean и running_var в настоящее время не включены в BatchNorm.
Обратите внимание, что в это время используется несмещенная выборочная дисперсия (отличная от обучения), поэтому, если batch_size=1, знаменатель будет равен 0, и будет сообщено об ошибке.
Изменить свойства экземпляра
Если установлено.track_running_stats==True,В настоящее времяbn_trainingБез изменений, по-прежнему True, поэтому по-прежнему используется статистика текущего пакета. то есть игнорироватьself.running_mean, self.running_varсодержимое, хранящееся в .
Поведение в этот момент такое же, как если бы оно не было изменено.
Резюме
Скриншот изображения из оригинального документа.