fix lightning, prepare sweeps

This commit is contained in:
2022-08-27 11:21:00 +02:00
parent d8169e01bc
commit 37c537ca31
5 changed files with 151 additions and 87 deletions
+8 -6
View File
@@ -38,10 +38,10 @@ class CREStereo(nn.Module):
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
# # NOTE Position_encoding as workaround for TensorRt
image1_shape = [1, 2, 480, 640]
self.pos_encoding_fn_small = PositionEncodingSine(
d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
)
# image1_shape = [1, 2, 480, 640]
# self.pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
# )
# loftr
self.self_att_fn = LocalFeatureTransformer(
@@ -141,10 +141,12 @@ class CREStereo(nn.Module):
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
)
# 'n c h w -> n (h w) c'
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
# x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
x_tmp = pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
# 'n c h w -> n (h w) c'
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
# x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
x_tmp = pos_encoding_fn_small(fmap2_dw16)
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
# FIXME experimental ! no self-attention for pattern