fix lightning, prepare sweeps
This commit is contained in:
+8
-6
@@ -38,10 +38,10 @@ class CREStereo(nn.Module):
|
||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||
|
||||
# # NOTE Position_encoding as workaround for TensorRt
|
||||
image1_shape = [1, 2, 480, 640]
|
||||
self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
)
|
||||
# image1_shape = [1, 2, 480, 640]
|
||||
# self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
# )
|
||||
|
||||
# loftr
|
||||
self.self_att_fn = LocalFeatureTransformer(
|
||||
@@ -141,10 +141,12 @@ class CREStereo(nn.Module):
|
||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
)
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||
# x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||
x_tmp = pos_encoding_fn_small(fmap1_dw16)
|
||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
|
||||
# x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
|
||||
x_tmp = pos_encoding_fn_small(fmap2_dw16)
|
||||
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
|
||||
# FIXME experimental ! no self-attention for pattern
|
||||
|
||||
Reference in New Issue
Block a user