Fix implementation issues
This commit is contained in:
@@ -8,7 +8,7 @@ class PositionEncodingSine(nn.Module):
|
||||
This is a sinusoidal position encoding that generalized to 2-dimensional images
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
|
||||
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=False):
|
||||
"""
|
||||
Args:
|
||||
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
|
||||
|
||||
@@ -24,7 +24,7 @@ class LoFTREncoderLayer(nn.Module):
|
||||
# feed-forward network
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(d_model*2, d_model*2, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.ReLU(),
|
||||
nn.Linear(d_model*2, d_model, bias=False),
|
||||
)
|
||||
|
||||
@@ -84,10 +84,10 @@ class LocalFeatureTransformer(nn.Module):
|
||||
mask0 (torch.Tensor): [N, L] (optional)
|
||||
mask1 (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
|
||||
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
||||
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
|
||||
@@ -86,6 +86,8 @@ class AGCL:
|
||||
left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c'
|
||||
right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c'
|
||||
# 'n (h w) c -> n c h w'
|
||||
left_feature, right_feature = self.att(left_feature, right_feature)
|
||||
# 'n (h w) c -> n c h w'
|
||||
left_feature, right_feature = [
|
||||
x.reshape(N, H, W, C).permute(0, 3, 1, 2)
|
||||
for x in [left_feature, right_feature]
|
||||
|
||||
+4
-4
@@ -24,9 +24,9 @@ class ResidualBlock(nn.Module):
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == 'instance':
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
self.norm1 = nn.InstanceNorm2d(planes, affine=False)
|
||||
self.norm2 = nn.InstanceNorm2d(planes, affine=False)
|
||||
self.norm3 = nn.InstanceNorm2d(planes, affine=False)
|
||||
|
||||
elif norm_fn == 'none':
|
||||
self.norm1 = nn.Sequential()
|
||||
@@ -59,7 +59,7 @@ class BasicEncoder(nn.Module):
|
||||
self.norm1 = nn.BatchNorm2d(64)
|
||||
|
||||
elif self.norm_fn == 'instance':
|
||||
self.norm1 = nn.InstanceNorm2d(64)
|
||||
self.norm1 = nn.InstanceNorm2d(64, affine=False)
|
||||
|
||||
elif self.norm_fn == 'none':
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
Reference in New Issue
Block a user