change a bunch of stuff, add wip lightning implementation

This commit is contained in:
2022-08-24 16:25:12 +02:00
parent 11959eef61
commit 63da24f429
10 changed files with 916 additions and 112 deletions
+10 -3
View File
@@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
"""
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
for layer, name in zip(self.layers, self.layer_names):
# NOTE Workaround for non statically determinable zip
# for layer, name in zip(self.layers, self.layer_names):
# layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
# layer_zip = []
# for i, layer in enumerate(self.layers):
# layer_zip.append((layer, self.layer_names[i]))
for i, layer in enumerate(self.layers):
name = self.layer_names[i]
if name == 'self':
feat0 = layer(feat0, feat0, mask0, mask0)
feat1 = layer(feat1, feat1, mask1, mask1)
@@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
else:
raise KeyError
return feat0, feat1
return feat0, feat1
+19 -7
View File
@@ -36,6 +36,12 @@ class CREStereo(nn.Module):
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
# # NOTE Position_encoding as workaround for TensorRt
# image1_shape = [1, 2, 480, 640]
# self.pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
# )
# loftr
self.self_att_fn = LocalFeatureTransformer(
d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
@@ -81,7 +87,7 @@ class CREStereo(nn.Module):
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
return zero_flow
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
@@ -130,17 +136,22 @@ class CREStereo(nn.Module):
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
# positional encoding and self-attention
pos_encoding_fn_small = PositionEncodingSine(
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
)
# pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
# )
# 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap1_dw16)
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
# 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap2_dw16)
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
# FIXME experimental ! no self-attention for pattern
if not self_attend_right:
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
else:
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
fmap1_dw16, fmap2_dw16 = [
x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
for x in [fmap1_dw16, fmap2_dw16]
@@ -258,3 +269,4 @@ class CREStereo(nn.Module):
return flow_up
return predictions
+29 -12
View File
@@ -1,6 +1,8 @@
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
class ResidualBlock(nn.Module):
@@ -96,28 +98,43 @@ class BasicEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
def forward(self, x: List[Tensor]):
# NOTE always assume list, otherwise TensorRT is sad
# batch_dim = x[0].shape[0]
# x_tensor = torch.cat(list(x), dim=0)
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x_tensor = torch.cat(x, dim=0)
else:
x_tensor = x
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
print()
print()
print(x_tensor.shape)
print()
print()
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x_tensor = self.conv1(x_tensor)
x_tensor = self.norm1(x_tensor)
x_tensor = self.relu1(x_tensor)
x = self.conv2(x)
x_tensor = self.layer1(x_tensor)
x_tensor = self.layer2(x_tensor)
x_tensor = self.layer3(x_tensor)
x_tensor = self.conv2(x_tensor)
if self.dropout is not None:
x = self.dropout(x)
x_tensor = self.dropout(x_tensor)
if is_list:
x = torch.split(x, x.shape[0]//2, dim=0)
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
return x_list
return x
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
return x_list
# return list(x)
+1 -1
View File
@@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module):
nn.ReLU(inplace=True),
nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True):
def forward(self, net, inp, corr, flow, upsample: bool=True):
# print(inp.shape, corr.shape, flow.shape)
motion_features = self.encoder(flow, corr)
# print(motion_features.shape, inp.shape)