|
|
@ -224,7 +224,7 @@ class DispNetS(TimedModule): |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def crop_like(self, input, ref): |
|
|
|
def crop_like(self, input, ref): |
|
|
|
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) |
|
|
|
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)), f'Assertion ({input.size(2)} >= {ref.size(2)} and {input.size(3)} >= {ref.size(3)}) failed' |
|
|
|
return input[:, :, :ref.size(2), :ref.size(3)] |
|
|
|
return input[:, :, :ref.size(2), :ref.size(3)] |
|
|
|
|
|
|
|
|
|
|
|
def tforward(self, x): |
|
|
|
def tforward(self, x): |
|
|
@ -291,7 +291,8 @@ class DispNetS(TimedModule): |
|
|
|
|
|
|
|
|
|
|
|
if self.output_ms: |
|
|
|
if self.output_ms: |
|
|
|
if self.double_head: |
|
|
|
if self.double_head: |
|
|
|
return (disp1, disp1_d), (disp2, disp2_d), disp3, disp4 |
|
|
|
# NOTE return all tuples for easier handling |
|
|
|
|
|
|
|
return (disp1, disp1_d), (disp2, disp2_d), (disp3, disp3), (disp4, disp4) |
|
|
|
return disp1, disp2, disp3, disp4 |
|
|
|
return disp1, disp2, disp3, disp4 |
|
|
|
else: |
|
|
|
else: |
|
|
|
if self.double_head: |
|
|
|
if self.double_head: |
|
|
@ -304,8 +305,8 @@ class DispNetShallow(DispNetS): |
|
|
|
Edge Decoder based on DispNetS with fewer layers |
|
|
|
Edge Decoder based on DispNetS with fewer layers |
|
|
|
''' |
|
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False): |
|
|
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, double_head=False): |
|
|
|
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init) |
|
|
|
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init, double_head=False) |
|
|
|
self.mod_name = 'DispNetShallow' |
|
|
|
self.mod_name = 'DispNetShallow' |
|
|
|
conv_planes = [32, 64, 128, 256, 512, 512, 512] |
|
|
|
conv_planes = [32, 64, 128, 256, 512, 512, 512] |
|
|
|
upconv_planes = [512, 512, 256, 128, 64, 32, 16] |
|
|
|
upconv_planes = [512, 512, 256, 128, 64, 32, 16] |
|
|
@ -335,6 +336,21 @@ class DispNetShallow(DispNetS): |
|
|
|
out_iconv1 = self.iconv1(concat1) |
|
|
|
out_iconv1 = self.iconv1(concat1) |
|
|
|
disp1 = self.predict_disp1(out_iconv1) |
|
|
|
disp1 = self.predict_disp1(out_iconv1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.double_head: |
|
|
|
|
|
|
|
out_upconv2_d = self.crop_like(self.upconv2(out_iconv3), out_conv1) |
|
|
|
|
|
|
|
disp3_up_d = self.crop_like( |
|
|
|
|
|
|
|
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) |
|
|
|
|
|
|
|
concat2_d = torch.cat((out_upconv2_d, out_conv1, disp3_up_d), 1) |
|
|
|
|
|
|
|
out_iconv2_d = self.iconv2(concat2_d) |
|
|
|
|
|
|
|
disp2_d = self.predict_disp2_double(out_iconv2_d) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out_upconv1_d = self.crop_like(self.upconv1(out_iconv2), x) |
|
|
|
|
|
|
|
disp2_up_d = self.crop_like( |
|
|
|
|
|
|
|
torch.nn.functional.interpolate(disp2_d, scale_factor=2, mode='bilinear', align_corners=False), x) |
|
|
|
|
|
|
|
concat1_d = torch.cat((out_upconv1_d, disp2_up_d), 1) |
|
|
|
|
|
|
|
out_iconv1_d = self.iconv1(concat1_d) |
|
|
|
|
|
|
|
disp1_d = self.predict_disp1_double(out_iconv1_d) |
|
|
|
|
|
|
|
|
|
|
|
if self.output_ms: |
|
|
|
if self.output_ms: |
|
|
|
return disp1, disp2, disp3 |
|
|
|
return disp1, disp2, disp3 |
|
|
|
else: |
|
|
|
else: |
|
|
@ -411,7 +427,6 @@ class RectifiedPatternSimilarityLoss(TimedModule): |
|
|
|
def tforward(self, disp0, im, std=None): |
|
|
|
def tforward(self, disp0, im, std=None): |
|
|
|
self.pattern = self.pattern.to(disp0.device) |
|
|
|
self.pattern = self.pattern.to(disp0.device) |
|
|
|
self.uv0 = self.uv0.to(disp0.device) |
|
|
|
self.uv0 = self.uv0.to(disp0.device) |
|
|
|
|
|
|
|
|
|
|
|
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) |
|
|
|
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) |
|
|
|
uv1 = torch.empty_like(uv0) |
|
|
|
uv1 = torch.empty_like(uv0) |
|
|
|
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1) |
|
|
|
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1) |
|
|
@ -512,7 +527,7 @@ class ProjectionBaseLoss(TimedModule): |
|
|
|
xyz = xyz + t.reshape(bs, 1, 3) |
|
|
|
xyz = xyz + t.reshape(bs, 1, 3) |
|
|
|
|
|
|
|
|
|
|
|
Kt = self.K.transpose(1, 2).expand(bs, -1, -1) |
|
|
|
Kt = self.K.transpose(1, 2).expand(bs, -1, -1) |
|
|
|
uv = torch.bmm(xyz, Kt) |
|
|
|
uv = torch.bmm(xyz, Kt.float()) |
|
|
|
|
|
|
|
|
|
|
|
d = uv[:, :, 2:3] |
|
|
|
d = uv[:, :, 2:3] |
|
|
|
|
|
|
|
|
|
|
|