import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from .utils import bilinear_sampler, coords_grid, manual_pad class AGCL: """ Implementation of Adaptive Group Correlation Layer (AGCL). """ def __init__(self, fmap1, fmap2, att=None): self.fmap1 = fmap1 self.fmap2 = fmap2 self.att = att self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device) def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False): if iter_mode: corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch) else: corr = self.corr_att_offset( self.fmap1, self.fmap2, flow, extra_offset, small_patch ) return corr def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)): N, C, H, W = left_feature.shape di_y, di_x = dilate[0], dilate[1] pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x right_pad = manual_pad(right_feature, pady, padx) corr_list = [] for h in range(0, pady * 2 + 1, di_y): for w in range(0, padx * 2 + 1, di_x): right_crop = right_pad[:, :, h : h + H, w : w + W] assert right_crop.shape == left_feature.shape corr = torch.mean(left_feature * right_crop, dim=1, keepdims=True) corr_list.append(corr) corr_final = torch.cat(corr_list, dim=1) return corr_final def corr_iter(self, left_feature, right_feature, flow, small_patch): coords = self.coords + flow coords = coords.permute(0, 2, 3, 1) right_feature = bilinear_sampler(right_feature, coords) if small_patch: psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)] dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] else: psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)] dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] N, C, H, W = left_feature.shape lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1) rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1) corrs = [] for i in range(len(psize_list)): corr = self.get_correlation( lefts[i], rights[i], psize_list[i], dilate_list[i] ) corrs.append(corr) final_corr = torch.cat(corrs, dim=1) return final_corr def corr_att_offset( self, left_feature, right_feature, flow, extra_offset, small_patch ): N, C, H, W = left_feature.shape if self.att is not None: left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c' right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c' # 'n (h w) c -> n c h w' left_feature, right_feature = self.att(left_feature, right_feature) # 'n (h w) c -> n c h w' left_feature, right_feature = [ x.reshape(N, H, W, C).permute(0, 3, 1, 2) for x in [left_feature, right_feature] ] lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1) rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1) C = C // 4 if small_patch: psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)] dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] else: psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)] dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] search_num = 9 extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2) # [N, search_num, 1, 1, 2] corrs = [] for i in range(len(psize_list)): left_feature, right_feature = lefts[i], rights[i] psize, dilate = psize_list[i], dilate_list[i] psizey, psizex = psize[0], psize[1] dilatey, dilatex = dilate[0], dilate[1] ry = psizey // 2 * dilatey rx = psizex // 2 * dilatex x_grid, y_grid = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device), torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device), indexing='xy') offsets = torch.stack((x_grid, y_grid)) offsets = offsets.reshape(2, -1).permute(1, 0) for d in sorted((0, 2, 3)): offsets = offsets.unsqueeze(d) offsets = offsets.repeat_interleave(N, dim=0) offsets = offsets + extra_offset coords = self.coords + flow # [N, 2, H, W] coords = coords.permute(0, 2, 3, 1) # [N, H, W, 2] coords = torch.unsqueeze(coords, 1) + offsets coords = coords.reshape(N, -1, W, 2) # [N, search_num*H, W, 2] right_feature = bilinear_sampler( right_feature, coords ) # [N, C, search_num*H, W] right_feature = right_feature.reshape(N, C, -1, H, W) # [N, C, search_num, H, W] left_feature = left_feature.unsqueeze(2).repeat_interleave(right_feature.shape[2], dim=2) corr = torch.mean(left_feature * right_feature, dim=1) corrs.append(corr) final_corr = torch.cat(corrs, dim=1) return final_corr