finish lightningification\n\nTraining still seems borked
This commit is contained in:
+9
-8
@@ -22,9 +22,10 @@ except:
|
||||
|
||||
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
|
||||
class CREStereo(nn.Module):
|
||||
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
|
||||
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False, batch_size=4):
|
||||
super(CREStereo, self).__init__()
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.max_flow = max_disp
|
||||
self.mixed_precision = mixed_precision
|
||||
self.test_mode = test_mode
|
||||
@@ -37,10 +38,10 @@ class CREStereo(nn.Module):
|
||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||
|
||||
# # NOTE Position_encoding as workaround for TensorRt
|
||||
# image1_shape = [1, 2, 480, 640]
|
||||
# self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
# )
|
||||
image1_shape = [1, 2, 480, 640]
|
||||
self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
)
|
||||
|
||||
# loftr
|
||||
self.self_att_fn = LocalFeatureTransformer(
|
||||
@@ -136,9 +137,9 @@ class CREStereo(nn.Module):
|
||||
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
|
||||
|
||||
# positional encoding and self-attention
|
||||
# pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
# )
|
||||
pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
)
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
|
||||
@@ -111,12 +111,6 @@ class BasicEncoder(nn.Module):
|
||||
else:
|
||||
x_tensor = x
|
||||
|
||||
print()
|
||||
print()
|
||||
print(x_tensor.shape)
|
||||
print()
|
||||
print()
|
||||
|
||||
x_tensor = self.conv1(x_tensor)
|
||||
x_tensor = self.norm1(x_tensor)
|
||||
x_tensor = self.relu1(x_tensor)
|
||||
|
||||
Reference in New Issue
Block a user