import copy

import torch
import torch.nn as nn

from .linear_attention import LinearAttention, FullAttention

# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py


class LoFTREncoderLayer(nn.Module):
    """One LoFTR encoder layer: multi-head attention followed by a feed-forward MLP.

    ``x`` attends to ``source`` (pass ``source=x`` for self-attention). The
    attention output is merged, layer-normalised, concatenated with ``x``,
    passed through a two-layer MLP, normalised again and added back to ``x``
    as a residual.

    Args:
        d_model (int): feature dimension C of the inputs and the output.
        nhead (int): number of attention heads; must divide ``d_model``.
        attention (str): ``'linear'`` selects ``LinearAttention``; any other
            value selects ``FullAttention``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, attention='linear'):
        super().__init__()
        # Each head gets d_model // nhead channels. With a remainder the
        # per-head views in forward() would split the channels wrongly and the
        # layer would fail later with a less informative shape error.
        if d_model % nhead != 0:
            raise ValueError(
                "d_model ({}) must be divisible by nhead ({})".format(d_model, nhead)
            )

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network; its input is [x, message] concatenated on the
        # channel axis, hence 2 * d_model input features
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2, bias=False),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model, bias=False),
        )

        # layer norms (this layer applies no dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """Attend from ``x`` to ``source`` and return the residually updated ``x``.

        Args:
            x (torch.Tensor): [N, L, C] query features.
            source (torch.Tensor): [N, S, C] key/value features.
            x_mask (torch.Tensor): [N, L] (optional); passed to the attention
                module as ``q_mask`` (its exact semantics are defined there).
            source_mask (torch.Tensor): [N, S] (optional); passed as ``kv_mask``.

        Returns:
            torch.Tensor: [N, L, C], equal to ``x + message``.
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention: project, then split channels into heads
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead * self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network on the concatenation of input and attention message
        message = self.mlp(torch.cat([x, message], dim=2))  # [N, L, 2C] -> [N, L, C]
        message = self.norm2(message)

        return x + message


class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module.

    Stacks one ``LoFTREncoderLayer`` per entry of ``layer_names``. Each entry is
    ``'self'`` (each feature set attends to itself) or ``'cross'`` (the two
    feature sets attend to each other).

    Args:
        d_model (int): feature dimension C.
        nhead (int): number of attention heads per layer.
        layer_names (list[str]): ordered layer types, each ``'self'`` or ``'cross'``.
        attention (str): attention type passed to every ``LoFTREncoderLayer``.
    """

    def __init__(self, d_model, nhead, layer_names, attention):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.layer_names = layer_names
        encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
        # Independent deep copies, so layers do not share weights; their
        # multi-dimensional parameters are re-initialised by _reset_parameters.
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in self.layer_names])
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None, ignore_second_feat=False):
        """Run the layer stack over the two feature sets.

        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
            ignore_second_feat (bool): if True, skip the ``'self'``-layer updates
                of ``feat1`` and return only ``feat0``. ``feat1`` is still
                updated in ``'cross'`` layers.

        Returns:
            ``feat0`` if ``ignore_second_feat`` is True, else ``(feat0, feat1)``.

        Raises:
            AssertionError: if the channel dimension of ``feat0`` differs from ``d_model``.
            KeyError: if a layer name is neither ``'self'`` nor ``'cross'``.
        """
        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"

        # NOTE: iterate by index instead of zip(self.layers, self.layer_names)
        # as a workaround for zip not being statically determinable.
        for i, layer in enumerate(self.layers):
            name = self.layer_names[i]
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                if ignore_second_feat:  # save some compute
                    continue
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                # Sequential update: feat1 attends to the already-updated feat0.
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError(
                    "unknown layer name '{}'; expected 'self' or 'cross'".format(name)
                )

        if ignore_second_feat:
            return feat0
        return feat0, feat1