Change a bunch of stuff, add WIP Lightning implementation
This commit is contained in:
@@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
|
||||
"""
|
||||
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
||||
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
|
||||
# NOTE: workaround for a zip() that is not statically determinable; iterate by index instead
|
||||
# for layer, name in zip(self.layers, self.layer_names):
|
||||
# layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
|
||||
# layer_zip = []
|
||||
# for i, layer in enumerate(self.layers):
|
||||
# layer_zip.append((layer, self.layer_names[i]))
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
name = self.layer_names[i]
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
@@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
return feat0, feat1
|
||||
return feat0, feat1
|
||||
|
||||
+19
-7
@@ -36,6 +36,12 @@ class CREStereo(nn.Module):
|
||||
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
|
||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||
|
||||
# # NOTE: position encoding built in __init__ as a workaround for TensorRT
|
||||
# image1_shape = [1, 2, 480, 640]
|
||||
# self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
# )
|
||||
|
||||
# loftr
|
||||
self.self_att_fn = LocalFeatureTransformer(
|
||||
d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
|
||||
@@ -81,7 +87,7 @@ class CREStereo(nn.Module):
|
||||
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
|
||||
return zero_flow
|
||||
|
||||
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
|
||||
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
|
||||
""" Estimate optical flow between a pair of frames """
|
||||
|
||||
image1 = 2 * (image1 / 255.0) - 1.0
|
||||
@@ -130,17 +136,22 @@ class CREStereo(nn.Module):
|
||||
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
|
||||
|
||||
# positional encoding and self-attention
|
||||
pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
)
|
||||
# pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
# )
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = pos_encoding_fn_small(fmap1_dw16)
|
||||
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = pos_encoding_fn_small(fmap2_dw16)
|
||||
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
|
||||
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
|
||||
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
# FIXME: experimental! When self_attend_right is False, the self-attention output for the pattern (fmap2_dw16) is discarded
|
||||
if not self_attend_right:
|
||||
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
else:
|
||||
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
|
||||
fmap1_dw16, fmap2_dw16 = [
|
||||
x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
|
||||
for x in [fmap1_dw16, fmap2_dw16]
|
||||
@@ -258,3 +269,4 @@ class CREStereo(nn.Module):
|
||||
return flow_up
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
+29
-12
@@ -1,6 +1,8 @@
|
||||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
|
||||
class ResidualBlock(nn.Module):
|
||||
@@ -96,28 +98,43 @@ class BasicEncoder(nn.Module):
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: List[Tensor]):
|
||||
# NOTE: always assume the input is a list; otherwise TensorRT has problems
|
||||
# batch_dim = x[0].shape[0]
|
||||
# x_tensor = torch.cat(list(x), dim=0)
|
||||
|
||||
# if input is list, combine batch dimension
|
||||
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||
if is_list:
|
||||
batch_dim = x[0].shape[0]
|
||||
x = torch.cat(x, dim=0)
|
||||
x_tensor = torch.cat(x, dim=0)
|
||||
else:
|
||||
x_tensor = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
print()
|
||||
print()
|
||||
print(x_tensor.shape)
|
||||
print()
|
||||
print()
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x_tensor = self.conv1(x_tensor)
|
||||
x_tensor = self.norm1(x_tensor)
|
||||
x_tensor = self.relu1(x_tensor)
|
||||
|
||||
x = self.conv2(x)
|
||||
x_tensor = self.layer1(x_tensor)
|
||||
x_tensor = self.layer2(x_tensor)
|
||||
x_tensor = self.layer3(x_tensor)
|
||||
|
||||
x_tensor = self.conv2(x_tensor)
|
||||
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
x_tensor = self.dropout(x_tensor)
|
||||
|
||||
if is_list:
|
||||
x = torch.split(x, x.shape[0]//2, dim=0)
|
||||
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
|
||||
return x_list
|
||||
|
||||
return x
|
||||
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
|
||||
return x_list
|
||||
|
||||
# return list(x)
|
||||
|
||||
+1
-1
@@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module):
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
|
||||
|
||||
def forward(self, net, inp, corr, flow, upsample=True):
|
||||
def forward(self, net, inp, corr, flow, upsample: bool=True):
|
||||
# print(inp.shape, corr.shape, flow.shape)
|
||||
motion_features = self.encoder(flow, corr)
|
||||
# print(motion_features.shape, inp.shape)
|
||||
|
||||
Reference in New Issue
Block a user