general improvements
This commit is contained in:
parent
7633990c81
commit
168516924e
@ -224,7 +224,7 @@ class DispNetS(TimedModule):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def crop_like(self, input, ref):
|
def crop_like(self, input, ref):
|
||||||
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
|
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)), f'Assertion ({input.size(2)} >= {ref.size(2)} and {input.size(3)} >= {ref.size(3)}) failed'
|
||||||
return input[:, :, :ref.size(2), :ref.size(3)]
|
return input[:, :, :ref.size(2), :ref.size(3)]
|
||||||
|
|
||||||
def tforward(self, x):
|
def tforward(self, x):
|
||||||
@ -291,7 +291,8 @@ class DispNetS(TimedModule):
|
|||||||
|
|
||||||
if self.output_ms:
|
if self.output_ms:
|
||||||
if self.double_head:
|
if self.double_head:
|
||||||
return (disp1, disp1_d), (disp2, disp2_d), disp3, disp4
|
# NOTE return all tuples for easier handling
|
||||||
|
return (disp1, disp1_d), (disp2, disp2_d), (disp3, disp3), (disp4, disp4)
|
||||||
return disp1, disp2, disp3, disp4
|
return disp1, disp2, disp3, disp4
|
||||||
else:
|
else:
|
||||||
if self.double_head:
|
if self.double_head:
|
||||||
@ -304,8 +305,8 @@ class DispNetShallow(DispNetS):
|
|||||||
Edge Decoder based on DispNetS with fewer layers
|
Edge Decoder based on DispNetS with fewer layers
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, double_head=False):
|
||||||
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init, double_head=False)
|
||||||
self.mod_name = 'DispNetShallow'
|
self.mod_name = 'DispNetShallow'
|
||||||
conv_planes = [32, 64, 128, 256, 512, 512, 512]
|
conv_planes = [32, 64, 128, 256, 512, 512, 512]
|
||||||
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
|
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
|
||||||
@ -335,6 +336,21 @@ class DispNetShallow(DispNetS):
|
|||||||
out_iconv1 = self.iconv1(concat1)
|
out_iconv1 = self.iconv1(concat1)
|
||||||
disp1 = self.predict_disp1(out_iconv1)
|
disp1 = self.predict_disp1(out_iconv1)
|
||||||
|
|
||||||
|
if self.double_head:
|
||||||
|
out_upconv2_d = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||||
|
disp3_up_d = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||||
|
concat2_d = torch.cat((out_upconv2_d, out_conv1, disp3_up_d), 1)
|
||||||
|
out_iconv2_d = self.iconv2(concat2_d)
|
||||||
|
disp2_d = self.predict_disp2_double(out_iconv2_d)
|
||||||
|
|
||||||
|
out_upconv1_d = self.crop_like(self.upconv1(out_iconv2), x)
|
||||||
|
disp2_up_d = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp2_d, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||||
|
concat1_d = torch.cat((out_upconv1_d, disp2_up_d), 1)
|
||||||
|
out_iconv1_d = self.iconv1(concat1_d)
|
||||||
|
disp1_d = self.predict_disp1_double(out_iconv1_d)
|
||||||
|
|
||||||
if self.output_ms:
|
if self.output_ms:
|
||||||
return disp1, disp2, disp3
|
return disp1, disp2, disp3
|
||||||
else:
|
else:
|
||||||
@ -411,7 +427,6 @@ class RectifiedPatternSimilarityLoss(TimedModule):
|
|||||||
def tforward(self, disp0, im, std=None):
|
def tforward(self, disp0, im, std=None):
|
||||||
self.pattern = self.pattern.to(disp0.device)
|
self.pattern = self.pattern.to(disp0.device)
|
||||||
self.uv0 = self.uv0.to(disp0.device)
|
self.uv0 = self.uv0.to(disp0.device)
|
||||||
|
|
||||||
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
|
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
|
||||||
uv1 = torch.empty_like(uv0)
|
uv1 = torch.empty_like(uv0)
|
||||||
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
|
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
|
||||||
@ -512,7 +527,7 @@ class ProjectionBaseLoss(TimedModule):
|
|||||||
xyz = xyz + t.reshape(bs, 1, 3)
|
xyz = xyz + t.reshape(bs, 1, 3)
|
||||||
|
|
||||||
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
|
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
|
||||||
uv = torch.bmm(xyz, Kt)
|
uv = torch.bmm(xyz, Kt.float())
|
||||||
|
|
||||||
d = uv[:, :, 2:3]
|
d = uv[:, :, 2:3]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user