Siamese Network & Triplet NetWork

алгоритм

Это 3-й день моего участия в ноябрьском испытании обновлений, узнайте подробности события:Вызов последнего обновления 2021 г.

Сиамская сеть

Проще говоря, сиамская сетьОбщие параметрыДве нейронные сети

В сиамской сети выкладываем образX1X_1На вход получаем кодировку картинкиGW(X1)G_W(X_1). Затем мы вводим другое изображение без каких-либо обновлений параметров сети.X2X_2, и получить кодировку измененной картинкиGW(X2)G_W(X_2). Так как похожие картинки должны иметь схожие признаки (кодировки), по этому можно сравнивать и судить о сходстве двух картинок

Функция потерь сиамской сети

Использование традиционной сиамской сетиContrastive Loss(функция потери контраста)

L=(1Y)12(DW)2+(Y)12{max(0,mDW)}2\mathcal{L} = (1-Y)\frac{1}{2}(D_W)^2+(Y)\frac{1}{2}\{max(0, m-D_W)\}^2

вDWD_Wопределяется как евклидово расстояние между двумя входами сиамской сети, т.е.

DW={GW(X1)GW(X2)}2D_W = \sqrt{\{G_W(X_1)-G_W(X_2)\}^2}
  • YYзначение равно 0 или 1, еслиX1,X2X_1,X_2Пара образцов принадлежит к одному классу, тогдаY=0Y=0,НапротивY=1Y=1
  • mmявляется предельным значением, т. е. когдаY=1Y=1,еслиX1X_1иX2X_2расстояние больше, чемmm, то не оптимизировать (экономя время и силы); еслиX1X_1иX2X_2дистанция междуmm, затем настройте параметры, чтобы увеличить расстояние доmm
Контрастный код потери
import torch
import numpy as np
import torch.nn.functional as F

class ContrastiveLoss(torch.nn.Module):
    "Contrastive loss function"
    def __init__(self, m=2.0):
        super(ContrastiveLoss, self).__init__()
        self.m = m
            
    def forward(self, output1, output2, label):
        d_w = F.pairwise_distance(output1, output2)
        contrastive_loss = torch.mean((1-label) * 0.5 * torch.pow(d_w, 2) +
                                      (label) * 0.5 * torch.pow(torch.clamp(self.m - d_w, min=0.0), 2))

        return contrastive_loss

в,F.pairwise_distance(x1, x2, p=2)Формула функции выглядит следующим образом

(i=1n(x1x2p))1px1,x2еRb×n(\sum_{i=1}^n(|x_1-x_2|^p))^{\frac{1}{p}}\\ x_1,x_2 \in \mathbb{R}^{b\times n}

pairwise_distance(x1, x2, p) Computes the batchwise pairwise distance between vectors x1x_1, x2x_2 using the p-norm

Использование сиамских сетей

Проще говоря, прямое использование сиамской сети заключается в измерении того, насколько разные (или похожие) два входа. Отправьте два входа в две нейронные сети соответственно, получите их представление в новом пространстве, а затем используйте функцию потерь для вычисления степени их различия (или степени сходства).

  • Анализ лексико-семантического подобия, сопоставление вопроса и ответа в QA
  • Распознавание рукописного ввода также может использовать сиамскую сеть.
  • Конкурс Quora's Question Pair на Kaggle, то есть на определение того, являются ли два вопроса одним и тем же вопросом.
Псевдосиамская сеть (псевдосиамская сеть)

Для псевдосиамской сети обе стороны могут бытьразличные нейронные сети(например, один lstm, другой cnn), и если это одна и та же нейронная сеть,не делитесь параметрамииз

Сценарии, в которых применимы соответственно сиамская сеть и псевдосиамская сеть
  • Сиамская сеть подходит для обработки двух входовпохожийСлучай
  • Псевдосиамская сеть, подходящая для обработки двух входовЕсть определенная разницаСлучай

Например, для вычисления смыслового сходства двух предложений или слов целесообразнее использовать Siamese Network; для проверки соответствия описания заголовка описанию текста (длина заголовка и текста очень отличается), или если текст описывает изображение (один - изображение, один - текст), следует использовать псевдо-сиамскую сеть

Тройная сеть

Если сиамские сети — это близнецы, то триплетные сети — это тройки. Его входов три: один положительный + два отрицательных или один отрицательный + два положительных. Цель обучения по-прежнему состоит в том, чтобы расстояние между одними и теми же классами было как можно меньше, а между разными классами — как можно больше. Triplet Network превосходит сиамскую сеть по наборам данных CIFAR и MNIST

Функция потерь определяется следующим образом:

L=max(d(a,p)d(a,n)+margin,0)\mathcal{L}=max(d(a,p)-d(a,n)+margin, 0)
  • aaпредставляет якорное изображение
  • ppПредставляет положительный образ
  • nnпредставляет собой негативный образ

мы надеемсяaaиppрасстояние должно быть меньшеaaиnnрасстояние.marginmarginэто гиперпараметр, который выражаетd(a,p)d(a,p)иd(a,n)d(a,n)какая разница должна быть между, например, предположениемmargin=0.2margin=0.2d(a,p)=0.5d(a,p)=0.5,Такd(a,n)d(a,n)должно быть больше или равно0.70.7

Reference