diff --git a/model/networks.py b/model/networks.py index feb4c69..9542903 100644 --- a/model/networks.py +++ b/model/networks.py @@ -224,7 +224,7 @@ class DispNetS(TimedModule): ) def crop_like(self, input, ref): - assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) + assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)), f'Assertion ({input.size(2)} >= {ref.size(2)} and {input.size(3)} >= {ref.size(3)}) failed' return input[:, :, :ref.size(2), :ref.size(3)] def tforward(self, x): @@ -291,7 +291,8 @@ class DispNetS(TimedModule): if self.output_ms: if self.double_head: - return (disp1, disp1_d), (disp2, disp2_d), disp3, disp4 + # NOTE return all tuples for easier handling + return (disp1, disp1_d), (disp2, disp2_d), (disp3, disp3), (disp4, disp4) return disp1, disp2, disp3, disp4 else: if self.double_head: @@ -304,8 +305,8 @@ class DispNetShallow(DispNetS): Edge Decoder based on DispNetS with fewer layers ''' - def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False): - super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init) + def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, double_head=False): + super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init, double_head=False) self.mod_name = 'DispNetShallow' conv_planes = [32, 64, 128, 256, 512, 512, 512] upconv_planes = [512, 512, 256, 128, 64, 32, 16] @@ -335,6 +336,21 @@ class DispNetShallow(DispNetS): out_iconv1 = self.iconv1(concat1) disp1 = self.predict_disp1(out_iconv1) + if self.double_head: + out_upconv2_d = self.crop_like(self.upconv2(out_iconv3), out_conv1) + disp3_up_d = self.crop_like( + torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) + concat2_d = torch.cat((out_upconv2_d, out_conv1, disp3_up_d), 1) + out_iconv2_d = self.iconv2(concat2_d) + disp2_d = self.predict_disp2_double(out_iconv2_d) + + out_upconv1_d = self.crop_like(self.upconv1(out_iconv2), x) + disp2_up_d = self.crop_like( + torch.nn.functional.interpolate(disp2_d, scale_factor=2, mode='bilinear', align_corners=False), x) + concat1_d = torch.cat((out_upconv1_d, disp2_up_d), 1) + out_iconv1_d = self.iconv1(concat1_d) + disp1_d = self.predict_disp1_double(out_iconv1_d) + if self.output_ms: return disp1, disp2, disp3 else: @@ -411,7 +427,6 @@ class RectifiedPatternSimilarityLoss(TimedModule): def tforward(self, disp0, im, std=None): self.pattern = self.pattern.to(disp0.device) self.uv0 = self.uv0.to(disp0.device) - uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) uv1 = torch.empty_like(uv0) uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1) @@ -512,7 +527,7 @@ class ProjectionBaseLoss(TimedModule): xyz = xyz + t.reshape(bs, 1, 3) Kt = self.K.transpose(1, 2).expand(bs, -1, -1) - uv = torch.bmm(xyz, Kt) + uv = torch.bmm(xyz, Kt.float()) d = uv[:, :, 2:3]