You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.6 KiB
100 lines
3.6 KiB
import copy
|
|
import torch
|
|
import torch.nn as nn
|
|
from .linear_attention import LinearAttention, FullAttention
|
|
|
|
# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py


class LoFTREncoderLayer(nn.Module):
    """One LoFTR encoder layer: multi-head attention followed by a feed-forward MLP.

    Queries come from ``x``; keys and values come from ``source``. The heads
    are merged and layer-normalised. ``[x, message]`` is then passed through a
    two-layer MLP, normalised again, and added to ``x`` as a residual.

    Args:
        d_model: Channel dimension C of the input features.
        nhead: Number of attention heads; must be positive and divide ``d_model``.
        attention: ``'linear'`` selects ``LinearAttention``; any other value
            selects ``FullAttention``.

    Raises:
        ValueError: If ``nhead`` is not positive or ``d_model`` is not a
            multiple of ``nhead``.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear'):
        super(LoFTREncoderLayer, self).__init__()

        # Fail fast: an uneven split would otherwise only surface as a shape
        # error when the projections are viewed per head inside ``forward``.
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                "d_model ({}) must be divisible by a positive nhead ({})".format(d_model, nhead))

        self.dim = d_model // nhead  # per-head channel size D
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network; its input is [x, message] concatenated on channels
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model * 2, d_model, bias=False),
        )

        # layer norms (this layer applies no dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """Update ``x`` with messages attended from ``source``.

        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)

        Returns:
            torch.Tensor: [N, L, C], ``x`` plus this layer's residual update.
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention: split C channels into H heads of D channels
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        # reshape rather than view: the attention module is defined elsewhere
        # and its output may not be contiguous; reshape copies only if needed.
        message = self.merge(message.reshape(bs, -1, self.nhead * self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network over the concatenation of input and message
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message
|
|
|
|
|
|
class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module.

    Builds one ``LoFTREncoderLayer`` per entry of ``layer_names``. In a
    ``'self'`` layer each feature set attends to itself. In a ``'cross'``
    layer each feature set attends to the other.

    Args:
        d_model: Channel dimension C of the features.
        nhead: Number of attention heads per layer.
        layer_names: Iterable of ``'self'`` / ``'cross'`` strings, one per layer.
        attention: Attention type forwarded to every ``LoFTREncoderLayer``.

    Raises:
        ValueError: If ``layer_names`` contains anything other than
            ``'self'`` or ``'cross'``.
    """

    def __init__(self, d_model, nhead, layer_names, attention):
        super(LocalFeatureTransformer, self).__init__()

        # Materialise once so generators are accepted and later mutation of
        # the caller's sequence cannot desynchronise names from self.layers.
        layer_names = list(layer_names)
        unknown = [name for name in layer_names if name not in ('self', 'cross')]
        if unknown:
            # Reject bad configs at construction instead of on the first forward.
            raise ValueError(
                "unknown layer name(s) {!r}; expected 'self' or 'cross'".format(unknown))

        self.d_model = d_model
        self.nhead = nhead
        self.layer_names = layer_names
        encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
        # deepcopy gives every layer its own independent parameters.
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        Parameters with one dimension (e.g. norm weights and biases) keep
        their default initialisation.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """Run every layer over both feature sets.

        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)

        Returns:
            tuple: The updated ``(feat0, feat1)``.

        Raises:
            KeyError: If a layer name is neither ``'self'`` nor ``'cross'``.
        """
        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
        assert self.d_model == feat1.size(2), "the feature number of src and transformer must be equal"

        for layer, name in zip(self.layers, self.layer_names):
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                # Sequential update: feat1 attends to the feat0 that was
                # already updated on the previous line.
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError(name)

        return feat0, feat1