finish lightningification\n\nTraining still seems borked

This commit is contained in:
2022-08-24 19:18:20 +02:00
parent 0e2a4b2340
commit d8169e01bc
4 changed files with 103 additions and 185 deletions
+9 -8
View File
@@ -22,9 +22,10 @@ except:
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
class CREStereo(nn.Module):
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False, batch_size=4):
super(CREStereo, self).__init__()
self.batch_size = batch_size
self.max_flow = max_disp
self.mixed_precision = mixed_precision
self.test_mode = test_mode
@@ -37,10 +38,10 @@ class CREStereo(nn.Module):
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
# # NOTE Position_encoding as workaround for TensorRt
# image1_shape = [1, 2, 480, 640]
# self.pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
# )
image1_shape = [1, 2, 480, 640]
self.pos_encoding_fn_small = PositionEncodingSine(
d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
)
# loftr
self.self_att_fn = LocalFeatureTransformer(
@@ -136,9 +137,9 @@ class CREStereo(nn.Module):
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
# positional encoding and self-attention
# pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
# )
pos_encoding_fn_small = PositionEncodingSine(
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
)
# 'n c h w -> n (h w) c'
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
-6
View File
@@ -111,12 +111,6 @@ class BasicEncoder(nn.Module):
else:
x_tensor = x
print()
print()
print(x_tensor.shape)
print()
print()
x_tensor = self.conv1(x_tensor)
x_tensor = self.norm1(x_tensor)
x_tensor = self.relu1(x_tensor)