diff --git a/doc/img/output.jpg b/doc/img/output.jpg
index 67ae482..c7dc0ec 100644
Binary files a/doc/img/output.jpg and b/doc/img/output.jpg differ
diff --git a/nets/attention/position_encoding.py b/nets/attention/position_encoding.py
index 6ead877..c78307d 100644
--- a/nets/attention/position_encoding.py
+++ b/nets/attention/position_encoding.py
@@ -8,7 +8,7 @@ class PositionEncodingSine(nn.Module):
     This is a sinusoidal position encoding that generalized to 2-dimensional images
     """
 
-    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
+    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=False):
         """
         Args:
             max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
diff --git a/nets/attention/transformer.py b/nets/attention/transformer.py
index f47c36a..de55ffc 100644
--- a/nets/attention/transformer.py
+++ b/nets/attention/transformer.py
@@ -24,7 +24,7 @@ class LoFTREncoderLayer(nn.Module):
         # feed-forward network
         self.mlp = nn.Sequential(
             nn.Linear(d_model*2, d_model*2, bias=False),
-            nn.ReLU(True),
+            nn.ReLU(),
             nn.Linear(d_model*2, d_model, bias=False),
         )
 
@@ -84,10 +84,10 @@ class LocalFeatureTransformer(nn.Module):
             mask0 (torch.Tensor): [N, L] (optional)
             mask1 (torch.Tensor): [N, S] (optional)
         """
-        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
 
         for layer, name in zip(self.layers, self.layer_names):
+
             if name == 'self':
                 feat0 = layer(feat0, feat0, mask0, mask0)
                 feat1 = layer(feat1, feat1, mask1, mask1)
diff --git a/nets/corr.py b/nets/corr.py
index aaa2af9..bf0e286 100644
--- a/nets/corr.py
+++ b/nets/corr.py
@@ -86,6 +86,8 @@ class AGCL:
             left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
             right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
             # 'n (h w) c -> n c h w'
+            left_feature, right_feature = self.att(left_feature, right_feature)
+            # 'n (h w) c -> n c h w'
             left_feature, right_feature = [
                 x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                 for x in [left_feature, right_feature]
diff --git a/nets/extractor.py b/nets/extractor.py
index 0faf510..993cd3a 100644
--- a/nets/extractor.py
+++ b/nets/extractor.py
@@ -24,9 +24,9 @@ class ResidualBlock(nn.Module):
             self.norm3 = nn.BatchNorm2d(planes)
 
         elif norm_fn == 'instance':
-            self.norm1 = nn.InstanceNorm2d(planes)
-            self.norm2 = nn.InstanceNorm2d(planes)
-            self.norm3 = nn.InstanceNorm2d(planes)
+            self.norm1 = nn.InstanceNorm2d(planes, affine=False)
+            self.norm2 = nn.InstanceNorm2d(planes, affine=False)
+            self.norm3 = nn.InstanceNorm2d(planes, affine=False)
 
         elif norm_fn == 'none':
             self.norm1 = nn.Sequential()
@@ -59,7 +59,7 @@ class BasicEncoder(nn.Module):
             self.norm1 = nn.BatchNorm2d(64)
 
         elif self.norm_fn == 'instance':
-            self.norm1 = nn.InstanceNorm2d(64)
+            self.norm1 = nn.InstanceNorm2d(64, affine=False)
 
         elif self.norm_fn == 'none':
             self.norm1 = nn.Sequential()
diff --git a/test_model.py b/test_model.py
index 3d9cf12..ddf98e0 100644
--- a/test_model.py
+++ b/test_model.py
@@ -33,10 +33,10 @@ def inference(left, right, model, n_iter=20):
         align_corners=True,
     )
 
     # print(imgR_dw2.shape)
+    with torch.inference_mode():
+        pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
-    pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
-
-    pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
     pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
 
     return pred_disp
@@ -49,7 +49,7 @@ if __name__ == '__main__':
     in_h, in_w = left_img.shape[:2]
 
     # Resize image in case the GPU memory overflows
-    eval_h, eval_w = (1024//4,1536//4)
+    eval_h, eval_w = (in_h,in_w)
     imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
     imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
 
@@ -65,11 +65,12 @@ if __name__ == '__main__':
     t = float(in_w) / float(eval_w)
     disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
 
-    disp_vis = (disp - disp.min()) / (256 - disp.min()) * 255.0
+    disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
     disp_vis = disp_vis.astype("uint8")
     disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
 
     combined_img = np.hstack((left_img, disp_vis))
+    cv2.namedWindow("output", cv2.WINDOW_NORMAL)
    cv2.imshow("output", combined_img)
    cv2.imwrite("output.jpg", disp_vis)
    cv2.waitKey(0)