commit fc524edd76
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
@@ -0,0 +1,152 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
# PyCharm
# The JetBrains-specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
@@ -0,0 +1,2 @@
# CREStereo-Pytorch
Non-official PyTorch implementation of CREStereo (CVPR 2022 Oral).
@@ -0,0 +1,44 @@
import pickle
import numpy as np
import megengine as mge

import torch
import torch.nn.functional as F

def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img

def test_bilinear_sampler():
    # Getting back the megengine objects:
    with open('test_data/bilinear_sampler_test.pickle', 'rb') as f:
        right_feature_prev, coords, right_feature = pickle.load(f)

    right_feature_prev = torch.tensor(right_feature_prev.numpy())
    coords = torch.tensor(coords.numpy())
    right_feature = right_feature.numpy()

    # Test PyTorch
    right_feature_pytorch = bilinear_sampler(right_feature_prev, coords).numpy()

    error = np.mean(np.abs(right_feature_pytorch - right_feature))  # mean absolute error
    print(f"test_bilinear_sampler - Avg. Error: {error}, \n \
          Original shape: {coords.numpy().shape},\n \
          Obtained shape: {right_feature_pytorch.shape}, Expected shape: {right_feature.shape}")

if __name__ == '__main__':
    test_bilinear_sampler()
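
For intuition on the normalization above: grid_sample expects coordinates in [-1, 1], so the wrapper rescales pixel coordinates by (W-1) and (H-1). A minimal self-contained sanity check (synthetic tensor, shapes chosen purely for illustration):

import torch
import torch.nn.functional as F

# Sampling every pixel at its own (x, y) coordinate should reproduce the input.
img = torch.arange(12.).reshape(1, 1, 3, 4)                    # [N, C, H, W]
ys, xs = torch.meshgrid(torch.arange(3.), torch.arange(4.), indexing='ij')
coords = torch.stack([xs, ys], dim=-1)[None]                   # [N, H, W, 2], pixel units
grid = 2 * coords / torch.tensor([4 - 1., 3 - 1.]) - 1         # rescale by (W-1, H-1) to [-1, 1]
out = F.grid_sample(img, grid, mode='bilinear', align_corners=True)
assert torch.allclose(out, img)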
@@ -0,0 +1,29 @@
import pickle
import numpy as np
import megengine as mge

import torch
import torch.nn.functional as F

def coords_grid(batch, ht, wd, device):
    coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
    coords = torch.stack(coords[::-1], dim=0).float()  # reverse to (x, y) channel order
    return coords[None].repeat(batch, 1, 1, 1)

def test_coords_grid():
    # Getting back the megengine objects:
    with open('test_data/coords_grid_test.pickle', 'rb') as f:
        batch, ht, wd, coords = pickle.load(f)

    coords = coords.numpy()

    # Test PyTorch
    coords_pytorch = coords_grid(batch, ht, wd, 'cpu').numpy()

    error = np.mean(np.abs(coords_pytorch - coords))  # mean absolute error
    print(f"test_coords_grid - Avg. Error: {error}, \n \
          Obtained shape: {coords_pytorch.shape}, Expected shape: {coords.shape}")

if __name__ == '__main__':
    test_coords_grid()
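
One detail worth calling out: because of the coords[::-1] reversal, channel 0 of the returned grid holds x (column) indices and channel 1 holds y (row) indices. A quick standalone check of that convention (toy sizes):

import torch

ys, xs = torch.meshgrid(torch.arange(3), torch.arange(5), indexing='ij')
g = torch.stack([xs, ys], dim=0).float()[None].repeat(2, 1, 1, 1)  # equivalent to coords_grid(2, 3, 5, 'cpu')
assert g.shape == (2, 2, 3, 5)                                     # [batch, 2, ht, wd]
assert (g[0, 0, 0] == torch.arange(5.)).all()                      # channel 0: x varies along width
assert (g[0, 1, :, 0] == torch.arange(3.)).all()                   # channel 1: y varies along height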
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,51 @@
import pickle
import numpy as np
import megengine as mge

import torch
import torch.nn.functional as F

def manual_pad(x, pady, padx):
    pad = (padx, padx, pady, pady)  # (left, right, top, bottom)
    return F.pad(torch.tensor(x), pad, "replicate")

def test_pad_1_1():
    # Getting back the megengine objects:
    with open('test_data/manual_pad_test1_1.pickle', 'rb') as f:
        right_feature, pady, padx, right_pad = pickle.load(f)

    right_feature = right_feature.numpy()
    right_pad = right_pad.numpy()

    # Test PyTorch
    right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy()

    error = np.mean(np.abs(right_pad_pytorch - right_pad))  # mean absolute error
    print(f"test_pad_1_1 - Avg. Error: {error}, \n \
          Orig. shape: {right_feature.shape}, \n \
          Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}")

def test_pad_0_4():
    # Getting back the megengine objects:
    with open('test_data/manual_pad_test0_4.pickle', 'rb') as f:
        right_feature, pady, padx, right_pad = pickle.load(f)

    right_feature = right_feature.numpy()
    right_pad = right_pad.numpy()

    # Test PyTorch
    right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy()

    error = np.mean(np.abs(right_pad_pytorch - right_pad))  # mean absolute error
    print(f"test_pad_0_4 - Avg. Error: {error}, \n \
          Orig. shape: {right_feature.shape}, \n \
          Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}")

if __name__ == '__main__':
    test_pad_1_1()
    test_pad_0_4()
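
What manual_pad does in isolation: F.pad's 4-tuple reads (left, right, top, bottom), so width grows by 2*padx, height by 2*pady, and edge values are replicated. A toy check:

import torch
import torch.nn.functional as F

x = torch.arange(6.).reshape(1, 1, 2, 3)      # [N, C, H, W]
y = F.pad(x, (1, 1, 1, 1), "replicate")       # (padx, padx, pady, pady) with padx = pady = 1
assert y.shape == (1, 1, 4, 5)                # H + 2*pady, W + 2*padx
assert y[0, 0, 0, 0] == x[0, 0, 0, 0]         # the corner replicates the nearest edge value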
@@ -0,0 +1,30 @@
import pickle
import numpy as np
import megengine as mge

import torch
import torch.nn.functional as F

def test_meshgrid():
    # Getting back the megengine objects:
    with open('test_data/meshgrid_np_test.pkl', 'rb') as f:
        rx, dilatex, ry, dilatey, x_grid, y_grid = pickle.load(f)

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()

    # Test PyTorch
    x_grid_pytorch, y_grid_pytorch = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device='cpu'),
                                                    torch.arange(-ry, ry + 1, dilatey, device='cpu'), indexing='xy')

    error_x = np.mean(np.abs(x_grid_pytorch.numpy() - x_grid))  # mean absolute error
    error_y = np.mean(np.abs(y_grid_pytorch.numpy() - y_grid))
    print(f"test_meshgrid (X) - Avg. Error: {error_x}, \n \
          Obtained shape: {x_grid_pytorch.numpy().shape}, Expected shape: {x_grid.shape}")
    print(f"test_meshgrid (Y) - Avg. Error: {error_y}, \n \
          Obtained shape: {y_grid_pytorch.numpy().shape}, Expected shape: {y_grid.shape}")

if __name__ == '__main__':
    test_meshgrid()
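
The indexing='xy' argument is the crux of this test: it makes torch.meshgrid match NumPy's default meshgrid behaviour by putting the first input along the last (column) dimension. A small illustration with arbitrary sizes:

import torch

xg, yg = torch.meshgrid(torch.arange(3), torch.arange(2), indexing='xy')
assert xg.shape == (2, 3)                    # 'xy' output is (len(second), len(first))
assert (xg[0] == torch.arange(3)).all()      # x varies along columns
assert (yg[:, 0] == torch.arange(2)).all()   # y varies along rows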
@@ -0,0 +1,31 @@
import pickle
import numpy as np
import megengine as mge

import torch
import torch.nn.functional as F

def test_offset():
    # Getting back the megengine objects:
    with open('test_data/offset_test.pkl', 'rb') as f:
        x_grid, y_grid, reshape_shape, transpose_order, expand_size, repeat_size, repeat_axis, offsets = pickle.load(f)

    x_grid = torch.tensor(x_grid.numpy())
    y_grid = torch.tensor(y_grid.numpy())
    offsets_mge = offsets.numpy()
    N = repeat_size

    # Test PyTorch
    offsets = torch.stack((x_grid, y_grid))
    offsets = offsets.reshape(2, -1).permute(1, 0)
    for d in sorted((0, 2, 3)):
        offsets = offsets.unsqueeze(d)
    offsets = offsets.repeat_interleave(N, dim=0)

    error = np.mean(np.abs(offsets.numpy() - offsets_mge))  # mean absolute error
    print(f"test_offset - Avg. Error: {error}, \n \
          Obtained shape: {offsets.numpy().shape}, Expected shape: {offsets_mge.shape}")

if __name__ == '__main__':
    test_offset()
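
The unsqueeze loop is pure shape bookkeeping: it turns a [K, 2] offset table into [N, K, 1, 1, 2] so it can later broadcast against per-pixel coordinates. A toy walk-through (N and K chosen arbitrarily):

import torch

offsets = torch.randn(9, 2)                    # K = 9 search offsets, each (x, y)
for d in sorted((0, 2, 3)):
    offsets = offsets.unsqueeze(d)             # [9, 2] -> [1, 9, 1, 1, 2]
offsets = offsets.repeat_interleave(4, dim=0)  # N = 4 -> [4, 9, 1, 1, 2]
assert offsets.shape == (4, 9, 1, 1, 2)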
@@ -0,0 +1,47 @@
import pickle
import numpy as np
import megengine as mge

import torch
import torch.nn.functional as F

def test_split():
    # Getting back the megengine objects:
    with open('test_data/split_test.pkl', 'rb') as f:
        left_feature, size, axis, lefts = pickle.load(f)

    left_feature = torch.tensor(left_feature.numpy())

    # Test PyTorch
    lefts_pytorch = torch.split(left_feature, left_feature.shape[axis] // size, dim=axis)

    for i, (left_pytorch, left) in enumerate(zip(lefts_pytorch, lefts)):
        error = np.mean(np.abs(left_pytorch.numpy() - left.numpy()))  # mean absolute error
        print(f"test_split {i} - Avg. Error: {error}, \n \
              Obtained shape: {left_pytorch.numpy().shape}, Expected shape: {left.numpy().shape}\n")

def test_split_list():
    # Getting back the megengine objects:
    with open('test_data/split_test_list.pkl', 'rb') as f:
        fmap1, size, axis, net, inp = pickle.load(f)

    fmap1 = torch.tensor(fmap1.numpy())
    net = net.numpy()
    inp = inp.numpy()

    # Test PyTorch
    net_pytorch, inp_pytorch = torch.split(fmap1, [size[0], size[0]], dim=axis)

    error_net = np.mean(np.abs(net_pytorch.numpy() - net))  # mean absolute error
    error_inp = np.mean(np.abs(inp_pytorch.numpy() - inp))
    print(f"test_split_list (net) - Avg. Error: {error_net}, \n \
          Obtained shape: {net_pytorch.numpy().shape}, Expected shape: {net.shape}\n")
    print(f"test_split_list (inp) - Avg. Error: {error_inp}, \n \
          Obtained shape: {inp_pytorch.numpy().shape}, Expected shape: {inp.shape}\n")

if __name__ == '__main__':
    test_split()
    test_split_list()
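
The two torch.split call styles tested above behave differently: an int argument gives equal chunks of that size, while a list gives chunks of exactly those sizes. A minimal illustration:

import torch

x = torch.randn(1, 256, 4, 4)
parts = torch.split(x, 256 // 4, dim=1)        # int: chunks of 64 channels each
assert len(parts) == 4 and parts[0].shape[1] == 64
net, inp = torch.split(x, [128, 128], dim=1)   # list: exact chunk sizes
assert net.shape[1] == 128 and inp.shape[1] == 128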
@@ -0,0 +1 @@
from .crestereo import CREStereo as Model
@@ -0,0 +1,2 @@
from .transformer import LocalFeatureTransformer
from .position_encoding import PositionEncodingSine
@@ -0,0 +1,81 @@
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""

import torch
from torch.nn import Module, Dropout


def elu_feature_map(x):
    return torch.nn.functional.elu(x) + 1


class LinearAttention(Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)

        # set padded positions to zero
        if q_mask is not None:
            Q = Q * q_mask[:, :, None, None]
        if kv_mask is not None:
            K = K * kv_mask[:, :, None, None]
            values = values * kv_mask[:, :, None, None]

        v_length = values.size(1)
        values = values / v_length  # prevent fp16 overflow
        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # (S,D)' @ S,V
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length

        return queried_values.contiguous()


class FullAttention(Module):
    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a. full attention.
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """

        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)

        return queried_values.contiguous()
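
A quick shape sketch for the module above (dimensions illustrative; module path assumed from the imports in transformer.py). The einsum chain aggregates K^T V once per head, so the cost stays linear in sequence length instead of materializing an L x S attention matrix:

import torch
from nets.attention.linear_attention import LinearAttention  # assumed module path

attn = LinearAttention()
q = torch.randn(2, 100, 8, 32)        # [N, L, H, D]
k = torch.randn(2, 120, 8, 32)        # [N, S, H, D]
v = torch.randn(2, 120, 8, 32)        # [N, S, H, D]
out = attn(q, k, v)
assert out.shape == (2, 100, 8, 32)   # one queried value per query position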
@@ -0,0 +1,42 @@
import math
import torch
from torch import nn


class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that generalizes to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
        """
        Args:
            max_shape (tuple): for a 1/8 featmap, the max length of 256 corresponds to 2048 pixels
            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatibility.
                We will remove the buggy impl after re-training all variants of our released models.
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        if temp_bug_fix:
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
        else:  # a buggy implementation (for backward compatibility only)
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        """
        return x + self.pe[:, :, :x.size(2), :x.size(3)]
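
Usage sketch (module path assumed): d_model must be divisible by 4, since sin/cos pairs are interleaved for x and y, and the input's spatial size must fit within max_shape:

import torch
from nets.attention.position_encoding import PositionEncodingSine  # assumed module path

pe = PositionEncodingSine(d_model=256, max_shape=(64, 64))
x = torch.randn(1, 256, 30, 40)   # [N, C, H, W] with H, W <= max_shape
y = pe(x)                         # the precomputed encoding is cropped and added
assert y.shape == x.shape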
@@ -0,0 +1,100 @@
import copy
import torch
import torch.nn as nn
from .linear_attention import LinearAttention, FullAttention

# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
class LoFTREncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear'):
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message


class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, d_model, nhead, layer_names, attention):
        super(LocalFeatureTransformer, self).__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.layer_names = layer_names
        encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        """

        assert self.d_model == feat0.size(2), "the feature dimension of the inputs must match the transformer's d_model"

        for layer, name in zip(self.layers, self.layer_names):
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError(f"unknown layer name: {name}")

        return feat0, feat1
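
Instantiation sketch mirroring how crestereo.py (further below) builds the self-attention variant; inputs and outputs are flattened [N, H*W, C] feature maps (module path assumed):

import torch
from nets.attention.transformer import LocalFeatureTransformer  # assumed module path

lft = LocalFeatureTransformer(d_model=256, nhead=8, layer_names=['self'], attention='linear')
feat0 = torch.randn(1, 30 * 40, 256)    # e.g. a flattened 30x40 feature map
feat1 = torch.randn(1, 30 * 40, 256)
feat0, feat1 = lft(feat0, feat1)
assert feat0.shape == (1, 1200, 256)    # shape-preserving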
@@ -0,0 +1,146 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import bilinear_sampler, coords_grid, manual_pad

class AGCL:
    """
    Implementation of the Adaptive Group Correlation Layer (AGCL).
    """

    def __init__(self, fmap1, fmap2, att=None):
        self.fmap1 = fmap1
        self.fmap2 = fmap2

        self.att = att

        self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)

    def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
        if iter_mode:
            corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
        else:
            corr = self.corr_att_offset(
                self.fmap1, self.fmap2, flow, extra_offset, small_patch
            )
        return corr

    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):

        N, C, H, W = left_feature.shape

        di_y, di_x = dilate[0], dilate[1]
        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

        right_pad = manual_pad(right_feature, pady, padx)

        corr_list = []
        for h in range(0, pady * 2 + 1, di_y):
            for w in range(0, padx * 2 + 1, di_x):
                right_crop = right_pad[:, :, h : h + H, w : w + W]
                assert right_crop.shape == left_feature.shape
                corr = torch.mean(left_feature * right_crop, dim=1, keepdim=True)
                corr_list.append(corr)

        corr_final = torch.cat(corr_list, dim=1)

        return corr_final

    def corr_iter(self, left_feature, right_feature, flow, small_patch):

        coords = self.coords + flow
        coords = coords.permute(0, 2, 3, 1)
        right_feature = bilinear_sampler(right_feature, coords)

        if small_patch:
            psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
        else:
            psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

        N, C, H, W = left_feature.shape
        lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)

        corrs = []
        for i in range(len(psize_list)):
            corr = self.get_correlation(
                lefts[i], rights[i], psize_list[i], dilate_list[i]
            )
            corrs.append(corr)

        final_corr = torch.cat(corrs, dim=1)

        return final_corr

    def corr_att_offset(
        self, left_feature, right_feature, flow, extra_offset, small_patch
    ):

        N, C, H, W = left_feature.shape

        if self.att is not None:
            left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
            right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
            left_feature, right_feature = self.att(left_feature, right_feature)  # cross-attention between views
            # 'n (h w) c -> n c h w'
            left_feature, right_feature = [
                x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                for x in [left_feature, right_feature]
            ]

        lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)

        C = C // 4

        if small_patch:
            psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
        else:
            psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

        search_num = 9
        extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2)  # [N, search_num, H, W, 2]

        corrs = []
        for i in range(len(psize_list)):
            left_feature, right_feature = lefts[i], rights[i]
            psize, dilate = psize_list[i], dilate_list[i]

            psizey, psizex = psize[0], psize[1]
            dilatey, dilatex = dilate[0], dilate[1]

            ry = psizey // 2 * dilatey
            rx = psizex // 2 * dilatex
            x_grid, y_grid = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device),
                                            torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device), indexing='xy')

            offsets = torch.stack((x_grid, y_grid))
            offsets = offsets.reshape(2, -1).permute(1, 0)
            for d in sorted((0, 2, 3)):
                offsets = offsets.unsqueeze(d)  # [search_num, 2] -> [1, search_num, 1, 1, 2]
            offsets = offsets.repeat_interleave(N, dim=0)
            offsets = offsets + extra_offset

            coords = self.coords + flow  # [N, 2, H, W]
            coords = coords.permute(0, 2, 3, 1)  # [N, H, W, 2]
            coords = torch.unsqueeze(coords, 1) + offsets
            coords = coords.reshape(N, -1, W, 2)  # [N, search_num*H, W, 2]

            right_feature = bilinear_sampler(
                right_feature, coords
            )  # [N, C, search_num*H, W]
            right_feature = right_feature.reshape(N, C, -1, H, W)  # [N, C, search_num, H, W]
            left_feature = left_feature.unsqueeze(2).repeat_interleave(right_feature.shape[2], dim=2)

            corr = torch.mean(left_feature * right_feature, dim=1)

            corrs.append(corr)

        final_corr = torch.cat(corrs, dim=1)

        return final_corr
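
A rough shape sketch of iterative mode with random features (sizes illustrative, module path assumed). Channels are split into 4 groups and each group contributes a 9-tap local correlation, so the output has 4 * 9 = 36 channels:

import torch
from nets.corr import AGCL  # assumed module path

fmap1 = torch.randn(1, 256, 30, 40)
fmap2 = torch.randn(1, 256, 30, 40)
corr_fn = AGCL(fmap1, fmap2)
flow = torch.zeros(1, 2, 30, 40)                              # zero displacement
corr = corr_fn(flow, None, small_patch=True, iter_mode=True)  # 3x3 search window
assert corr.shape == (1, 36, 30, 40)                          # 4 groups x 9 taps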
@@ -0,0 +1,258 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .update import BasicUpdateBlock
from .extractor import BasicEncoder
from .corr import AGCL

from .attention import PositionEncodingSine, LocalFeatureTransformer

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass

# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
class CREStereo(nn.Module):
    def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
        super(CREStereo, self).__init__()

        self.max_flow = max_disp
        self.mixed_precision = mixed_precision
        self.test_mode = test_mode

        self.hidden_dim = 128
        self.context_dim = 128
        self.dropout = 0

        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
        self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)

        # loftr
        self.self_att_fn = LocalFeatureTransformer(
            d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
        )
        self.cross_att_fn = LocalFeatureTransformer(
            d_model=256, nhead=8, layer_names=["cross"] * 1, attention="linear"
        )

        # adaptive search
        self.search_num = 9
        self.conv_offset_16 = nn.Conv2d(
            256, self.search_num * 2, kernel_size=3, stride=1, padding=1
        )
        self.conv_offset_8 = nn.Conv2d(
            256, self.search_num * 2, kernel_size=3, stride=1, padding=1
        )
        self.range_16 = 1
        self.range_8 = 1

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def convex_upsample(self, flow, mask, rate=4):
        """ Upsample flow field [H/rate, W/rate, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, rate, rate, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(rate * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, rate*H, rate*W)

    def zero_init(self, fmap):
        N, C, H, W = fmap.shape
        _x = torch.zeros([N, 1, H, W], dtype=torch.float32)
        _y = torch.zeros([N, 1, H, W], dtype=torch.float32)
        zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
        return zero_flow

    def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between a pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        with autocast(enabled=self.mixed_precision):

            # 1/4 -> 1/8
            # feature
            fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)

            # offset
            offset_dw8 = self.conv_offset_8(fmap1_dw8)
            offset_dw8 = self.range_8 * (torch.sigmoid(offset_dw8) - 0.5) * 2.0

            # context
            net, inp = torch.split(fmap1, [hdim, hdim], dim=1)
            net = torch.tanh(net)
            inp = F.relu(inp)
            net_dw8 = F.avg_pool2d(net, 2, stride=2)
            inp_dw8 = F.avg_pool2d(inp, 2, stride=2)

            # 1/4 -> 1/16
            # feature
            fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
            fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
            offset_dw16 = self.conv_offset_16(fmap1_dw16)
            offset_dw16 = self.range_16 * (torch.sigmoid(offset_dw16) - 0.5) * 2.0

            # context
            net_dw16 = F.avg_pool2d(net, 4, stride=4)
            inp_dw16 = F.avg_pool2d(inp, 4, stride=4)

            # positional encoding and self-attention
            pos_encoding_fn_small = PositionEncodingSine(
                d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
            )
            # 'n c h w -> n (h w) c'
            x_tmp = pos_encoding_fn_small(fmap1_dw16)
            fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
            # 'n c h w -> n (h w) c'
            x_tmp = pos_encoding_fn_small(fmap2_dw16)
            fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])

            fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
            fmap1_dw16, fmap2_dw16 = [
                x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
                for x in [fmap1_dw16, fmap2_dw16]
            ]

        corr_fn = AGCL(fmap1, fmap2)
        corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8)
        corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn)

        # Cascaded refinement (1/16 + 1/8 + 1/4)
        predictions = []
        flow = None
        flow_up = None
        if flow_init is not None:
            scale = fmap1.shape[2] / flow_init.shape[2]
            flow = -scale * F.interpolate(
                flow_init,
                size=(fmap1.shape[2], fmap1.shape[3]),
                mode="bilinear",
                align_corners=True,
            )
        else:
            # zero initialization
            flow_dw16 = self.zero_init(fmap1_dw16)

            # Recurrent Update Module
            # RUM: 1/16
            for itr in range(iters // 2):
                if itr % 2 == 0:
                    small_patch = False
                else:
                    small_patch = True

                flow_dw16 = flow_dw16.detach()
                out_corrs = corr_fn_att_dw16(
                    flow_dw16, offset_dw16, small_patch=small_patch
                )

                with autocast(enabled=self.mixed_precision):
                    net_dw16, up_mask, delta_flow = self.update_block(
                        net_dw16, inp_dw16, out_corrs, flow_dw16
                    )

                flow_dw16 = flow_dw16 + delta_flow
                flow = self.convex_upsample(flow_dw16, up_mask, rate=4)
                flow_up = -4 * F.interpolate(
                    flow,
                    size=(4 * flow.shape[2], 4 * flow.shape[3]),
                    mode="bilinear",
                    align_corners=True,
                )
                predictions.append(flow_up)

            scale = fmap1_dw8.shape[2] / flow.shape[2]
            flow_dw8 = -scale * F.interpolate(
                flow,
                size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]),
                mode="bilinear",
                align_corners=True,
            )

            # RUM: 1/8
            for itr in range(iters // 2):
                if itr % 2 == 0:
                    small_patch = False
                else:
                    small_patch = True

                flow_dw8 = flow_dw8.detach()
                out_corrs = corr_fn_dw8(flow_dw8, offset_dw8, small_patch=small_patch)

                with autocast(enabled=self.mixed_precision):
                    net_dw8, up_mask, delta_flow = self.update_block(
                        net_dw8, inp_dw8, out_corrs, flow_dw8
                    )

                flow_dw8 = flow_dw8 + delta_flow
                flow = self.convex_upsample(flow_dw8, up_mask, rate=4)
                flow_up = -2 * F.interpolate(
                    flow,
                    size=(2 * flow.shape[2], 2 * flow.shape[3]),
                    mode="bilinear",
                    align_corners=True,
                )
                predictions.append(flow_up)

            scale = fmap1.shape[2] / flow.shape[2]
            flow = -scale * F.interpolate(
                flow,
                size=(fmap1.shape[2], fmap1.shape[3]),
                mode="bilinear",
                align_corners=True,
            )

        # RUM: 1/4
        for itr in range(iters):
            if itr % 2 == 0:
                small_patch = False
            else:
                small_patch = True

            flow = flow.detach()
            out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True)

            with autocast(enabled=self.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow)

            flow = flow + delta_flow
            flow_up = -self.convex_upsample(flow, up_mask, rate=4)
            predictions.append(flow_up)

        if self.test_mode:
            return flow_up

        return predictions
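
For intuition on convex_upsample: each fine-grid value is a softmax-weighted (hence convex) combination of the 9 coarse-grid neighbours around it. The same arithmetic in isolation, with arbitrary sizes:

import torch
import torch.nn.functional as F

N, H, W, rate = 1, 4, 5, 4
flow = torch.randn(N, 2, H, W)
mask = torch.randn(N, 9 * rate * rate, H, W).view(N, 1, 9, rate, rate, H, W)
mask = torch.softmax(mask, dim=2)                 # weights over 9 neighbours sum to 1

up = F.unfold(rate * flow, [3, 3], padding=1)     # gather each pixel's 3x3 neighbourhood
up = up.view(N, 2, 9, 1, 1, H, W)
up = torch.sum(mask * up, dim=2)                  # convex combination per sub-pixel
up = up.permute(0, 1, 4, 2, 5, 3).reshape(N, 2, rate * H, rate * W)
assert up.shape == (1, 2, 16, 20)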
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()

        self.downsample = nn.Sequential(
            nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        x = self.downsample(x)

        return self.relu(x + y)


class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, x.shape[0]//2, dim=0)

        return x
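
Shape sketch for the encoder (module path assumed): passing a list shares one forward pass over both images, and the two stride-2 stages give 1/4-resolution features:

import torch
from nets.extractor import BasicEncoder  # assumed module path

enc = BasicEncoder(output_dim=256, norm_fn='instance')
im1 = torch.randn(1, 3, 64, 96)
im2 = torch.randn(1, 3, 64, 96)
f1, f2 = enc([im1, im2])                 # batched jointly, then split back
assert f1.shape == (1, 256, 16, 24)      # stride 2 (conv1) x stride 2 (layer2) = 1/4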
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h


class BasicMotionEncoder(nn.Module):
    def __init__(self, cor_planes):
        super(BasicMotionEncoder, self).__init__()

        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)


class BasicUpdateBlock(nn.Module):
    def __init__(self, hidden_dim, cor_planes, mask_size=8):
        super(BasicUpdateBlock, self).__init__()

        self.encoder = BasicMotionEncoder(cor_planes)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, mask_size**2 * 9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat((inp, motion_features), dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
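
Shape sketch for one refinement step; cor_planes=4*9=36 matches how crestereo.py constructs the block (module path assumed):

import torch
from nets.update import BasicUpdateBlock  # assumed module path

block = BasicUpdateBlock(hidden_dim=128, cor_planes=36, mask_size=4)
net = torch.randn(1, 128, 30, 40)    # GRU hidden state
inp = torch.randn(1, 128, 30, 40)    # context features
corr = torch.randn(1, 36, 30, 40)    # correlation volume from AGCL
flow = torch.randn(1, 2, 30, 40)
net, mask, delta_flow = block(net, inp, corr, flow)
assert mask.shape == (1, 4 * 4 * 9, 30, 40)   # convex-upsampling weights
assert delta_flow.shape == (1, 2, 30, 40)     # per-pixel flow update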
@@ -0,0 +1 @@
from .utils import bilinear_sampler, coords_grid, manual_pad
@@ -0,0 +1,31 @@
import torch
import torch.nn.functional as F
import numpy as np

# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py

def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img

def coords_grid(batch, ht, wd, device):
    coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
    coords = torch.stack(coords[::-1], dim=0).float()  # channel order (x, y)
    return coords[None].repeat(batch, 1, 1, 1)

def manual_pad(x, pady, padx):
    pad = (padx, padx, pady, pady)  # (left, right, top, bottom)
    return F.pad(x.clone().detach(), pad, "replicate")
@@ -0,0 +1,15 @@
import torch

from nets import Model

model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model.eval()

t1 = torch.rand(1, 3, 480, 640)
t2 = torch.rand(1, 3, 480, 640)

output = model(t1, t2)
print(output.shape)