last batch of live-fixes and improvments

This commit is contained in:
2022-09-23 11:27:04 +02:00
parent 2731ef1ada
commit 6f6ac23175
6 changed files with 264 additions and 137 deletions
+6 -1
View File
@@ -76,7 +76,7 @@ class LocalFeatureTransformer(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat0, feat1, mask0=None, mask1=None):
def forward(self, feat0, feat1, mask0=None, mask1=None, ignore_second_feat=False):
"""
Args:
feat0 (torch.Tensor): [N, L, C]
@@ -97,6 +97,9 @@ class LocalFeatureTransformer(nn.Module):
name = self.layer_names[i]
if name == 'self':
feat0 = layer(feat0, feat0, mask0, mask0)
if ignore_second_feat:
# save some compute
continue
feat1 = layer(feat1, feat1, mask1, mask1)
elif name == 'cross':
feat0 = layer(feat0, feat1, mask0, mask1)
@@ -104,4 +107,6 @@ class LocalFeatureTransformer(nn.Module):
else:
raise KeyError
if ignore_second_feat:
return feat0
return feat0, feat1
+2 -1
View File
@@ -151,7 +151,8 @@ class CREStereo(nn.Module):
# FIXME experimental ! no self-attention for pattern
if not self_attend_right:
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
print('skipping right attention')
fmap1_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16, ignore_second_feat=True)
else:
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)