Pytorch создает SearchTransfer

глубокое обучение

Pytorch создает SearchTransfer

SearchTransfer взят из статьи Learning Texture Transformer Network for Image Super-Resolution.

[paper] [code]

Его основная идея похожа на само-внимание, но само-внимание просто вычисляет пакетное матричное умножение (B, HW, C) и (B, C, HW), эта статья не рассматривает входные данные (B) напрямую. , HW), C), но расширить до (B, количество пикселей в каждом блоке, num_blocks) с помощью скользящего окна свертки, а затем сделать упор.

В этой статье описаны некоторые варианты использования, возникающие при воспроизведении модуля трансформатора.

ключевая функция

  • torch.nn.functional.unfold
  • torch.nn.functional.fold
  • torch.expand
  • torch.gather

Развернуть удобно, чтобы сделать внимание между блоками, а затем использовать полученный график подобия для расчета индекса, чтобы извлечь информацию в ref_unfold, и, наконец, восстановить его с помощью fold

1. unfold

unfoldс сnn.Conv2dТо же скользящее окно делит ввод на блоки

import torch
import torch.nn.functional as F

x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape)	# torch.Size([1, 3, 5, 5])
print(x_unfold.shape)	# torch.Size([1, 27, 25])

Форма x (пакет, канал, H, W), вы можете видеть, что форма x_unfold (пакет, k x k x канал, number_blocks)

k — это размер ядра, k x k x канал представляет собой количество пикселей в блоке.

number_blocks — сколько блоков может выскользнуть с учетом размера ядра, заполнения и шага.

2. fold

Использование fold противоположно unfold, то есть восстанавливает блоки один за другим (пакет, канал, H, W).

k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape)	# torch.Size([1, 108, 2500])
print(x_fold.shape)		# torch.Size([1, 3, 10, 10])
print(x.mean())			# tensor(0.5012)
print(x_fold.mean())	# tensor(4.3924)

Видно, что хотя форма восстановлена, диапазон значений x и x_fold изменился, поскольку одна позиция (1x1xchannel) может появиться в нескольких блоках при развертывании, поэтому эти перекрывающиеся позиции будут суммироваться при складывании, что приведет к несогласованности данные. Следовательно, после получения x_fold необходимо разделить на количество перекрытий, чтобы получить исходный диапазон данных. Когда k=6, s=2, позиция появится в блоках 3*3=9 (окно скользит вверх и вниз, влево и вправо).

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3.*3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean())			# tensor(0.4998)
print(x_fold.mean())	# tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum()) # tensor(189)

Из функции sum() видно, что толькочастьДанные восстановлены.Другой способ точного вычисления делителей (например, 3. x 3.) состоит в том, чтобы использовать torch.ones в качестве входных данных.

k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

x_fold = x_fold / ones_fold
print(x.mean())			# tensor(0.5001)
print(x_fold.mean())	# tensor(0.5001)
print((x == x_fold).sum())	# tensor(30000) 每个点都被还原了

3. expand

использованиеTensor.expand(*size), вы можете использовать -1 для размера, чтобы представить размер, который остается неизменным

x = torch.rand((1, 4))	# x = torch.rand(4) 也可以得到同样的结果
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))

print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand1)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand2)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914]])

4. gather

использованиеtorch.gather(input, dim, index, *, sparse_grad=False, out=None), эффект следующий

for i in range(dim0):
    for j in range(dim1):
        for k in range(dim2):
            out[i, j, k] = input[index[i][j][k], j, k]  # if dim == 0
			out[i, j, k] = input[i, index[i][j][k], k]  # if dim == 1
			out[i, j, k] = input[i, j, index[i][j][k]]  # if dim == 2

При использовании сбора сначала используйте расширение, чтобы сделать размер индекса равным вводу.

какindex.shape == [B, blocks], используйте расширение, чтобы изменить index.shape на [B, c x c x k, блоки], чтобыindex[i, :, k]является одномерным тензором, и значение каждого элемента равно значению до расширенияindex[i, j]

Таким образом, при изменении jindex[i][j][k]не изменится, поэтому в циклеout[i, j, k] = input[i, j, index[i][j][k]]Просто поместите k-й блок на выход и k-й блок на входindex[i][j][k]Каждая точка каждого блока соответствует однозначно (ход j).

5. Передача функций сборки

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transfer(nn.Module):
    def __init__(self):
        super(Transfer, self).__init__()

    def bis(self, unfold, dim, index):
        """
        block index select
        args:
            unfold: [B, k*k*C, Hr*Wr]
            dim: 哪个维度是blocks
            index: [B, H*W],  value range is [0, Hr*Wr-1]
            return: [B, k*k*C, H*W]
        """
        views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, len(unfold.size()))]  # [B, 1, -1(H*W)]
        expanse = list(unfold.size())
        expanse[0] = -1
        expanse[dim] = -1   # [-1, k*k*C, -1]
        index = index.view(views).expand(expanse)   # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
        return torch.gather(unfold, dim, index)    # return[i][j][k] = unfold[i][j][index[i][j][k]]

    def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
        """
            args:
                lrsr_lv3: [B, C, H, W]
                refsr_lv3: [B, C, Hr, Wr]
                ref_lv1: [B, C, Hr*4, Wr*4]
                ref_lv2: [B, C, Hr*2, Wr*2]
                ref_lv3: [B, C, Hr, Wr]
        """
        H, W = lrsr_lv3.size()[-2:]

        lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1)    # [B, k*k*C, H*W]
        refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2)  # [B, Hr*Wr, k*k*C]

        lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
        refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)

        R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold)  # [B, Hr*Wr, H*W]
        score, index = torch.max(R, dim=1)  # [B, H*W]

        ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1)      # vgg19
        ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2)      # lv1->lv2, lv2->lv3有一次max pooling
        ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4)     # kernel_size没有按照真实的感受野计算

        # 被除数,记录fold(unfold)时的overlap
        divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)

        T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index)   # [B, k*k*C, H*W]
        T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
        T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)

        divisor_lv3 = self.bis(divisor_lv3, 2, index)  # [B, k*k*C, H*W]
        divisor_lv2 = self.bis(divisor_lv2, 2, index)
        divisor_lv1 = self.bis(divisor_lv1, 2, index)

        divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.fold(divisor_lv2, (2*H, 2*W), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.fold(divisor_lv1, (4*H, 4*W), kernel_size=12, padding=4, stride=4)

        T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
        T_lv2 = F.fold(T_lv2_unfold, (2*H, 2*W), kernel_size=6, padding=2, stride=2) / divisor_lv2
        T_lv1 = F.fold(T_lv1_unfold, (4*H, 4*W), kernel_size=12, padding=4, stride=4) / divisor_lv1

        score = score.view(lrsr_lv3.size(0), 1, H, W)   # [B, 1, H, W]

        return score, T_lv1, T_lv2, T_lv3

** Пояснение к сбору в bis: ** При использовании сбора сначала используйте расширение, чтобы размер индекса был равен входному.

какindex.shape == [B, blocks], используйте расширить доindex.shapeстановится [B, c x c x k, блоки], так чтоindex[i, :, k]является одномерным тензором, и значение каждого элемента равно значению до расширенияindex[i, j]

Таким образом, при изменении jindex[i][j][k]не изменится, поэтому в циклеout[i, j, k] = input[i, j, index[i][j][k]]Просто поместите k-й блок на выход и k-й блок на входindex[i][j][k]Каждая точка каждого блока соответствует однозначно (ход j).

Ссылаться на

py torch.org/docs/stable…

GitHub.com/researchMM/…