last batch of live-fixes and improvments
This commit is contained in:
@@ -76,7 +76,7 @@ class LocalFeatureTransformer(nn.Module):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
||||
def forward(self, feat0, feat1, mask0=None, mask1=None, ignore_second_feat=False):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [N, L, C]
|
||||
@@ -97,6 +97,9 @@ class LocalFeatureTransformer(nn.Module):
|
||||
name = self.layer_names[i]
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
if ignore_second_feat:
|
||||
# save some compute
|
||||
continue
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
elif name == 'cross':
|
||||
feat0 = layer(feat0, feat1, mask0, mask1)
|
||||
@@ -104,4 +107,6 @@ class LocalFeatureTransformer(nn.Module):
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
if ignore_second_feat:
|
||||
return feat0
|
||||
return feat0, feat1
|
||||
|
||||
Reference in New Issue
Block a user