[Блог технологий] Адаптация состязательного домена

искусственный интеллект

Введение в адаптацию домена

Адаптация предметной области — одна из наиболее распространенных проблем в трансферном обучении.Предметы разные, но задачи одинаковые, а данные исходного домена имеют метки, а данные целевого домена не имеют меток или мало данных имеют метки. Адаптация домена работает путем проецирования функций исходного и целевого доменов в аналогичное пространство признаков, так что классификатор исходного домена можно использовать для классификации целевого домена.

В качестве иллюстрации возьмем две категории, как показано ниже:域.PNGНа рисунке красный кружок — исходный домен, синий кружок — целевой домен, а кружок и крестик — данные разных характеристик.Классификатор исходного домена делит данные исходного домена на две категории, как показано пунктирной линией. В это время, если классификатор исходного домена используется для классификации целевого домена, из рисунка видно, что эффект очень плохой.   Итак, что делать? Один из способов — выровнять распределение исходного домена и целевого домена. Как показано в правой части рисунка, распределение исходного домена и целевого домена похоже (то есть данные аналогичные функции распределены в аналогичных позициях), так что вы можете напрямую взять Классификатор исходного домена классифицировал целевой домен.

Тренировочный процесс доменной состязательной генеративной сети GAN аналогичен Две модели обучаются одновременно: одна используется для извлечения функции целевого домена MT, а дискриминатор домена D используется для определения того, относится ли функция к исходному или целевому домену. максимизации D и создания ошибок, то есть извлеченного МТ. Эта особенность делает невозможным для D определение того, из исходного домена он или из целевого домена.

Экстрактор признаков целевого домена MT и дискриминатор домена D являются противниками друг друга: D учится различать, принадлежат ли признаки исходному домену или целевому домену, а MT учится приближать извлеченные им признаки к признакам, извлеченным из исходный домен. Средство извлечения признаков целевого домена MT можно рассматривать как группу фальшивомонетчиков, пытающуюся создать подделку и использовать ее без обнаружения, в то время как дискриминатор домена D подобен полицейскому, пытающемуся обнаружить поддельные деньги. Соревнование в этой игре заставляет обе команды совершенствовать свои методы, пока трудно сказать.

Адаптация враждебного домена

выбор данных

Для хорошего эффекта и простоты обучения я выбираю данные 0 и 1 в наборе данных mnist в качестве исходного домена, а данные 2 и 3 — в качестве целевого домена. В исходном домене и целевом домене содержится 10 000 фрагментов данных. Во время обучения исходный домен может получать данные и метки, в то время как целевой домен может получать только данные без меток для имитации контекста адаптации домена. Метки целевого домена используются только при проверке точности.

Интернет

1. Экстрактор признаков исходного домена MS, экстрактор признаков целевого домена MT. Так называемый экстрактор признаков фактически удаляет из сети последний слой классификации, который распознает mnist.

		(encoder): Sequential (
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (2): ReLU ()
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): ReLU ()
    )
    (fc1): Linear (64 * 4 * 4 -> 512)

Думайте о выходе этой сети как об извлеченных функциях.

2. Классификатор С. Фактически, это последний уровень классификации сети, который распознает mnist, простую полносвязную сеть.

		Classifier (
    (fc2): Linear (512 -> 2)
    )

3. Идентификатор домена D. В соответствии с выходными данными экстрактора признаков, чтобы определить, поступают ли данные из исходного домена или из целевого домена, выход 0 представляет из исходного домена, а выход 1 представляет из целевого домена.

		Discriminator (
     (layer): Sequential (
    (0): Linear (512 -> 512)
    (1): Linear (512 -> 512)
    (2): Linear (512 -> 2)
    ))

 

тренировочный процесс

Поезд МС, С

Во-первых, обучите экстрактор признаков MS и классификатор C в исходном домене.过程1.PNGПроцесс обучения аналогичен общему процессу обучения, за исключением того, что вся сеть разделена на две части для обучения и оптимизации.

def train_MS_C(loader_ms):
    # 模型
    MS = Encoder()
    C = Classifier()
    # 优化器
    o_ms = optim.SGD(MS.parameters(), lr=0.03)
    o_c = optim.SGD(C.parameters(), lr=0.03)
    criterion = nn.CrossEntropyLoss()  # 计算损失
    for j in range(1):
        print(j)
        # 训练
        for i, (images, labels) in enumerate(loader_ms):
            o_ms.zero_grad()
            o_c.zero_grad()
            outputs_mid = MS(images)
            outputs = C(outputs_mid)

            loss = criterion(outputs, labels)
            loss.backward()

            o_ms.step()  # 优化参数
            o_c.step()

            if i % 100 == 0:
                print(i)
                print('current loss : %.5f' % loss.data.item())
    # 保存模型
    np.save(params.MS_save_dir, MS.get_w())
    np.save(params.C_save_dir, C.get_w())

После обучения точность в исходном домене составляет 0,9985. Если вы напрямую используете экстрактор признаков и классификатор исходного домена для классификации целевого домена, точность составит всего 0,5840.acc1.PNG

Исправить MS и C, обучить MT и D

Затем, зафиксировав MS и C без изменений, т. е. без изменения их сетевых весов, состязательно изучите экстрактор признаков целевого домена MT и дискриминатор домена D на исходном и целевом доменах. 1. Инициализируйте MT с MS, чтобы целевой домен получил хорошую точность 0,5840 в начале, а затем тренируйтесь на этой основе, легче сходиться в хорошем направлении, и процесс сходимости также быстрее.

MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())

过程2.PNG

def train_MT_D(loader_ms, loader_mt):
    # 模型
    MS = Encoder()
    MT = Encoder()
    D = Discriminator()
    # 加载模型
    MS.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())

    if params.first_train:
        params.first_train = False
        # 第一次训练
        # MT用MS的权重初始化
        MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())
    else:
        MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())
        D.update_w(np.load(params.D_save_dir, encoding='bytes', allow_pickle=True).item())

    # 优化器
    o_mt = optim.SGD(MT.parameters(), lr=0.00001)
    o_d = optim.SGD(D.parameters(), lr=0.00001)
    criterion = nn.CrossEntropyLoss()  # 计算损失
    # 训练
    for j in range(1):
        print(j)
        # 训练D 域辨别器
        data_zip = zip(loader_ms, loader_mt)
        for i, ((images_s, labels_s), (images_t, labels_t)) in enumerate(data_zip):
            ################对域辨别器D的训练
            # 提取的特征
            f_s = MS(images_s)
            f_t = MT(images_t)
            f_cat = torch.cat((f_s, f_t), 0)
            # 域辨别器辨别结果
            out_D = D(f_cat.detach())

            predicts_D = torch.max(out_D.data, 1)[1]
            if i == 0:
                print('域辨别器的辨别结果')
                print(predicts_D)

            # 构造损失对比用的标签
            len_s = len(labels_s)
            len_t = len(labels_t)

            temp1 = torch.zeros(len_s)
            temp2 = torch.ones(len_t)

            lab_D = torch.cat((temp1, temp2), 0).long()

            # 梯度置0
            o_d.zero_grad()
            # 计算loss
            loss_D = criterion(out_D, lab_D)
            # 反向传播
            loss_D.backward()
            # 优化网络
            o_d.step()
            ##############################对目标域特征提取器MT的训练
            # 提取的特征
            f_t = MT(images_t)
            # 域辨别器辨别结果
            d_t = D(f_t)
            # 构造计算损失的outputs、labels
            out_MT = d_t

            predicts_MT = torch.max(out_MT.data, 1)[1]

            lab_MT = torch.zeros(len_t).long()
            # 梯度置0
            o_mt.zero_grad()
            # 计算loss
            loss_MT = criterion(out_MT, lab_MT)
            # 反向传播
            loss_MT.backward()
            # 优化网络
            o_mt.step()

            if i % 100 == 0:
                print(i)
                print('current loss_D : %.5f' % loss_D.data.item())
                print('current loss_MT : %.5f' % loss_MT.data.item())
    # 保存模型
    np.save(params.MT_save_dir, MT.get_w())
    np.save(params.D_save_dir, D.get_w())

Классифицировать по целевому домену с помощью MT и C

Наконец, используйте обученный экстрактор признаков целевого домена MT и классификатор C для классификации в целевом домене.过程3.PNG

def test_MT_C(loader_mt):
    MT = Encoder()
    C = Classifier()
    # 加载模型
    MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())
    C.update_w(np.load(params.C_save_dir, encoding='bytes', allow_pickle=True).item())
    correct = 0
    for images, labels in loader_mt:
        outputs_mid = MT(images)
        outputs = C(outputs_mid)
        _, predicts = torch.max(outputs.data, 1)
        correct += (predicts == labels).sum()
    total = len(loader_mt.dataset)
    print('MT+C  Accuracy: %.4f' % (1.0 * correct / total))

Результаты экспериментов

Используя экстрактор признаков и классификатор исходного домена для классификации целевого домена, точность составляет всего 0,5840.acc1.PNGНа следующем рисунке показан результат дискриминатора домена D. Вход первой половины — это признак исходного домена, а ввод второй половины — признак целевого домена.Теперь большая часть D может судить правильно.捕获.PNG

После нескольких эпох обучения точность немного повышается.acc2.PNGСпособность D различать домен снижается, и большая часть входных данных целевого домена оценивается как исходный домен.捕获2.PNG

После 40 периодов обучения точность колеблется около 0,9, что является большим улучшением по сравнению с начальным значением 0,5840.acc3.PNGD больше не может различать исходный и целевой домены и распознает все входные данные как исходный домен.捕获3.PNG

кодовый адрес

Тихо потяните. Может /explore/5 отправить 1… 

Ссылаться на

Adversarial Discriminative Domain Adaptation blog.CSDN.net/День мертвых_29381… GitHub.com/core?l/Пак Ючон…