Pytorch создает SearchTransfer
SearchTransfer взят из статьи Learning Texture Transformer Network for Image Super-Resolution.
Его основная идея похожа на само-внимание, но само-внимание просто вычисляет пакетное матричное умножение (B, HW, C) и (B, C, HW), эта статья не рассматривает входные данные (B) напрямую. , HW), C), но расширить до (B, количество пикселей в каждом блоке, num_blocks) с помощью скользящего окна свертки, а затем сделать упор.
В этой статье описаны некоторые варианты использования, возникающие при воспроизведении модуля трансформатора.
ключевая функция
torch.nn.functional.unfold
torch.nn.functional.fold
torch.expand
torch.gather
Развернуть удобно, чтобы сделать внимание между блоками, а затем использовать полученный график подобия для расчета индекса, чтобы извлечь информацию в ref_unfold, и, наконец, восстановить его с помощью fold
1. unfold
unfold
с сnn.Conv2d
То же скользящее окно делит ввод на блоки
import torch
import torch.nn.functional as F
x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape) # torch.Size([1, 3, 5, 5])
print(x_unfold.shape) # torch.Size([1, 27, 25])
Форма x (пакет, канал, H, W), вы можете видеть, что форма x_unfold (пакет, k x k x канал, number_blocks)
k — это размер ядра, k x k x канал представляет собой количество пикселей в блоке.
number_blocks — сколько блоков может выскользнуть с учетом размера ядра, заполнения и шага.
2. fold
Использование fold противоположно unfold, то есть восстанавливает блоки один за другим (пакет, канал, H, W).
k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape) # torch.Size([1, 108, 2500])
print(x_fold.shape) # torch.Size([1, 3, 10, 10])
print(x.mean()) # tensor(0.5012)
print(x_fold.mean()) # tensor(4.3924)
Видно, что хотя форма восстановлена, диапазон значений x и x_fold изменился, поскольку одна позиция (1x1xchannel) может появиться в нескольких блоках при развертывании, поэтому эти перекрывающиеся позиции будут суммироваться при складывании, что приведет к несогласованности данные. Следовательно, после получения x_fold необходимо разделить на количество перекрытий, чтобы получить исходный диапазон данных. Когда k=6, s=2, позиция появится в блоках 3*3=9 (окно скользит вверх и вниз, влево и вправо).
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3.*3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean()) # tensor(0.4998)
print(x_fold.mean()) # tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum()) # tensor(189)
Из функции sum() видно, что толькочастьДанные восстановлены.Другой способ точного вычисления делителей (например, 3. x 3.) состоит в том, чтобы использовать torch.ones в качестве входных данных.
k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
x_fold = x_fold / ones_fold
print(x.mean()) # tensor(0.5001)
print(x_fold.mean()) # tensor(0.5001)
print((x == x_fold).sum()) # tensor(30000) 每个点都被还原了
3. expand
использованиеTensor.expand(*size)
, вы можете использовать -1 для размера, чтобы представить размер, который остается неизменным
x = torch.rand((1, 4)) # x = torch.rand(4) 也可以得到同样的结果
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))
print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand1)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand2)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914]])
4. gather
использованиеtorch.gather(input, dim, index, *, sparse_grad=False, out=None)
, эффект следующий
for i in range(dim0):
for j in range(dim1):
for k in range(dim2):
out[i, j, k] = input[index[i][j][k], j, k] # if dim == 0
out[i, j, k] = input[i, index[i][j][k], k] # if dim == 1
out[i, j, k] = input[i, j, index[i][j][k]] # if dim == 2
При использовании сбора сначала используйте расширение, чтобы сделать размер индекса равным вводу.
какindex.shape == [B, blocks]
, используйте расширение, чтобы изменить index.shape на [B, c x c x k, блоки], чтобыindex[i, :, k]
является одномерным тензором, и значение каждого элемента равно значению до расширенияindex[i, j]
Таким образом, при изменении jindex[i][j][k]
не изменится, поэтому в циклеout[i, j, k] = input[i, j, index[i][j][k]]
Просто поместите k-й блок на выход и k-й блок на входindex[i][j][k]
Каждая точка каждого блока соответствует однозначно (ход j).
5. Передача функций сборки
import torch
import torch.nn as nn
import torch.nn.functional as F
class Transfer(nn.Module):
def __init__(self):
super(Transfer, self).__init__()
def bis(self, unfold, dim, index):
"""
block index select
args:
unfold: [B, k*k*C, Hr*Wr]
dim: 哪个维度是blocks
index: [B, H*W], value range is [0, Hr*Wr-1]
return: [B, k*k*C, H*W]
"""
views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, len(unfold.size()))] # [B, 1, -1(H*W)]
expanse = list(unfold.size())
expanse[0] = -1
expanse[dim] = -1 # [-1, k*k*C, -1]
index = index.view(views).expand(expanse) # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
return torch.gather(unfold, dim, index) # return[i][j][k] = unfold[i][j][index[i][j][k]]
def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
"""
args:
lrsr_lv3: [B, C, H, W]
refsr_lv3: [B, C, Hr, Wr]
ref_lv1: [B, C, Hr*4, Wr*4]
ref_lv2: [B, C, Hr*2, Wr*2]
ref_lv3: [B, C, Hr, Wr]
"""
H, W = lrsr_lv3.size()[-2:]
lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1) # [B, k*k*C, H*W]
refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2) # [B, Hr*Wr, k*k*C]
lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)
R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold) # [B, Hr*Wr, H*W]
score, index = torch.max(R, dim=1) # [B, H*W]
ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1) # vgg19
ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2) # lv1->lv2, lv2->lv3有一次max pooling
ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4) # kernel_size没有按照真实的感受野计算
# 被除数,记录fold(unfold)时的overlap
divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)
T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index) # [B, k*k*C, H*W]
T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)
divisor_lv3 = self.bis(divisor_lv3, 2, index) # [B, k*k*C, H*W]
divisor_lv2 = self.bis(divisor_lv2, 2, index)
divisor_lv1 = self.bis(divisor_lv1, 2, index)
divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
divisor_lv2 = F.fold(divisor_lv2, (2*H, 2*W), kernel_size=6, padding=2, stride=2)
divisor_lv1 = F.fold(divisor_lv1, (4*H, 4*W), kernel_size=12, padding=4, stride=4)
T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
T_lv2 = F.fold(T_lv2_unfold, (2*H, 2*W), kernel_size=6, padding=2, stride=2) / divisor_lv2
T_lv1 = F.fold(T_lv1_unfold, (4*H, 4*W), kernel_size=12, padding=4, stride=4) / divisor_lv1
score = score.view(lrsr_lv3.size(0), 1, H, W) # [B, 1, H, W]
return score, T_lv1, T_lv2, T_lv3
** Пояснение к сбору в bis: ** При использовании сбора сначала используйте расширение, чтобы размер индекса был равен входному.
какindex.shape == [B, blocks]
, используйте расширить доindex.shape
становится [B, c x c x k, блоки], так чтоindex[i, :, k]
является одномерным тензором, и значение каждого элемента равно значению до расширенияindex[i, j]
Таким образом, при изменении jindex[i][j][k]
не изменится, поэтому в циклеout[i, j, k] = input[i, j, index[i][j][k]]
Просто поместите k-й блок на выход и k-й блок на входindex[i][j][k]
Каждая точка каждого блока соответствует однозначно (ход j).