Сравнение потери контраста под контролем самоконтроля и потери контраста под наблюдением

машинное обучение глубокое обучение компьютерное зрение NLP

Автор | Самрат Саха Компилировать|ВКонтакте Источник | К науке о данных

В статье «Контролируемое контрастное обучение» много обсуждается между контролируемым обучением, кросс-энтропийной потерей и контролируемой контрастной потерей, чтобы лучше реализовать задачи представления и классификации изображений. Давайте немного углубимся в то, о чем эта статья.

В документе указывается, что набор данных сети изображений может быть улучшен на 1%.

С точки зрения архитектуры это очень простая сеть resnet 50 со 128-мерной головкой. Вы также можете добавить больше слоев, если хотите.

Code

self.encoder = resnet50()

self.head = nn.Linear(2048, 128)

def forward(self, x):
 feat = self.encoder(x)
 #需要对128向量进行标准化
 feat = F.normalize(self.head(feat), dim=1)
 return feat

Как видно, обучение проходит в два этапа.

  • Тренировочный набор с использованием контрастных потерь (два варианта)

  • Заморозьте параметры, а затем изучите классификатор на линейном слое, используя потери softmax. (практика из бумаги)

Вышеизложенное говорит само за себя.

Основное содержание этой статьи состоит в том, чтобы понять контрастивную потерю с самоконтролем и контролируемую контрастивную потерю.

Как вы можете видеть из приведенного выше графика SCL (контролируемая контрастная потеря), кошка сравнивается с любым человеком, не являющимся кошкой. Это означает, что все кошки относятся к одной и той же метке, являются положительными парами, а любые не-кошки отрицательными. Это очень похоже на то, как работают триплетные данные и потери триплетов.

Каждое изображение кота увеличено, поэтому даже из одного изображения кота у нас будет много кошек.

Функция потерь для контролируемых контрастных потерь, хотя и выглядит устрашающе, на самом деле довольно проста.

Позже мы увидим некоторый код, но сначала очень простое объяснение. Каждый z представляет собой нормализованный 128-мерный вектор.

То есть ||z||=1

Чтобы повторить факт линейной алгебры, если два вектора u и v нормализованы, это означает, что u.v=cos(угол между u и v)

Это означает, что если два нормализованных вектора одинаковы, скалярное произведение между ними = 1.

#尝试理解下面的代码

import numpy as np
v = np.random.randn(128)
v = v/np.linalg.norm(v)
print(np.dot(v,v))
print(np.linalg.norm(v))

Функция потерь предполагает, что у каждого изображения есть дополненная версия, в каждой партии есть N изображений, а результирующий размер партии = 2*N.

Когда i!=j,yi=yj, числительexp(zi.zj)/tauПредставляет всех кошек в партии. Скалярное произведение i 128-го размерного вектора zi со всеми j 128-ми размерными векторами.

Знаменатель — это i изображений кошек, умноженных на другие изображения, не являющиеся кошками. Возьмите точки zi и zk так, что i!=k означает, что он умножает все изображения, кроме самого себя.

Наконец, мы берем логарифмическую вероятность и добавляем ее ко всем изображениям кошек в пакете, кроме самого себя, а затем делим на 2*N-1.

Сумма общих потерь для всех изображений

Мы можем понять вышеизложенное, используя некоторый код факела.

Давайте посмотрим, как рассчитать потери для одной партии, предполагая, что размер нашей партии равен 4.

Если размер пакета равен 4, ваш ввод в сети будет 8x3x224x224, где изображение имеет ширину и высоту 224.

Причина использования 8=4x2 заключается в том, что у нас всегда есть контраст для каждого изображения, поэтому загрузчик данных должен быть написан соответствующим образом.

Renet контрастных потерь выведет матрицу размеров 8x128, и вы можете разделить эти размеры, чтобы вычислить потери партии.

#batch大小
bs = 4

Эта часть может вычислить числитель

temperature = 0.07

anchor_feature = contrast_feature

anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)

Форма нашего объекта 8х128, возьмем матрицу 3х128 и транспонируем, ниже картинка после визуализации.

anchor_feature=3x128 и convert_feature=128x3, результат 3x3, как показано ниже.

Если вы заметили, что все диагональные элементы сами по себе являются точками, которые на самом деле нам не нужны, мы их удалим.

У линейной алгебры есть свойство: если u и v — два вектора, то u.v — наибольший при u=v. Таким образом, в каждой строке, если мы возьмем максимальное значение контраста привязки и возьмем одно и то же значение, все диагонали станут равными 0.

Уменьшим размерность со 128 до 2

#bs 1 和 dim 2 意味着 2*1x2 
features = torch.randn(2, 2)

temperature = 0.07 
contrast_feature  = features
anchor_feature = contrast_feature
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)
print('anchor_dot_contrast=\n{}'.format(anchor_dot_contrast))

logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
print('logits_max = {}'.format(logits_max))
logits = anchor_dot_contrast - logits_max.detach()
print(' logits = {}'.format(logits))

#输出看看对角线发生了什么

anchor_dot_contrast=
tensor([[128.8697, -12.0467],
        [-12.0467,  50.5816]])
 logits_max = tensor([[128.8697],
        [ 50.5816]])
 logits = tensor([[   0.0000, -140.9164],
        [ -62.6283,    0.0000]])

Создавайте искусственные метки и соответствующие маски для сравнительных расчетов. Этот код немного сложен, поэтому дважды проверьте вывод.

bs = 4
print('batch size', bs)
temperature = 0.07
labels = torch.randint(4, (1,4))
print('labels', labels)
mask = torch.eq(labels, labels.T).float()
print('mask = \n{}'.format(logits_mask))

#对它进行硬编码,以使其更容易理解
contrast_count = 2
anchor_count = contrast_count

mask = mask.repeat(anchor_count, contrast_count)

#屏蔽self-contrast的情况
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(bs * anchor_count).view(-1, 1),
    0
)
mask = mask * logits_mask
print('mask * logits_mask = \n{}'.format(mask))

Давайте разберемся с выводом.

batch size 4
labels tensor([[3, 0, 2, 3]])

#以上的意思是在这批4个品种的葡萄中,我们有3,0,2,3个标签。以防你们忘了我们在这里只做了一次对比所以我们会有3_c 0_c 2_c 3_c作为输入批处理中的对比。

mask = 
tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0.]])
        
#这是非常重要的,所以我们创建了mask = mask * logits_mask,它告诉我们在第0个图像表示中,它应该与哪个图像进行对比。

# 所以我们的标签就是标签张量([[3,0,2,3]])
# 我重新命名它们是为了更好地理解张量([[3_1,0_1,2_1,3_2]])

mask * logits_mask = 
tensor([[0., 0., 0., 1., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 1.],
        [1., 0., 0., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 1., 1., 0., 0., 0.]])

Код сравнения точек привязки

logits = anchor_dot_contrast — logits_max.detach()

функция потерь

Математический обзор

У нас уже есть скалярный продукт первой части, разделенный на тау как логиты.

#上述等式的第二部分等于torch.log(exp_logits.sum(1, keepdim=True))

exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# 计算对数似然的均值
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# 损失
loss = - mean_log_prob_pos

loss = loss.view(anchor_count, 4).mean()
print('19. loss {}'.format(loss))

Я думаю, что это контролируемая контрастная потеря. Я думаю, что самоконтролируемую контрастивную потерю теперь легко понять, потому что она проще.

Согласно результатам исследования в этой статье, чем больше значение convert_count, тем четче модель. Нужно изменить convert_count на 2 или больше, надеюсь, вы сможете попробовать это с помощью приведенных выше инструкций.

ссылка

  • [1] : Supervised Contrastive Learning
  • [2]: Флориан Шрофф, Дмитрий Калениченко и Джеймс Филбин, Facenet: унифицированное встраивание для распознавания лиц и кластеризации, Материалы конференции IEEE по компьютерному зрению и распознаванию образов, стр. 815–823, 2015 г.
  • [3] : A Simple Framework for Contrastive Learning of Visual Representations, Ting Chen, Simon Kornblith Mohammad Norouzi, Geoffrey Hinton
  • [4] : GitHub.com/Google – Горячие цвета…

Оригинальная ссылка:к data science.com/ah-detailed-…

Добро пожаловать на сайт блога Panchuang AI:panchuang.net/

sklearn машинное обучение китайские официальные документы:sklearn123.com/

Добро пожаловать на станцию ​​сводки ресурсов блога Panchuang:docs.panchuang.net/