diff --git a/nets/crestereo.py b/nets/crestereo.py
index d685d4c..f6d2a37 100644
--- a/nets/crestereo.py
+++ b/nets/crestereo.py
@@ -81,7 +81,7 @@ class CREStereo(nn.Module):
         zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
         return zero_flow
 
-    def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False):
+    def forward(self, image1, image2, flow_init, iters=10, upsample=True, test_mode=False):
         """ Estimate optical flow between pair of frames """
 
         image1 = 2 * (image1 / 255.0) - 1.0
diff --git a/nets/utils/utils.py b/nets/utils/utils.py
index 96075ca..3a28205 100644
--- a/nets/utils/utils.py
+++ b/nets/utils/utils.py
@@ -12,7 +12,8 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
     ygrid = 2*ygrid/(H-1) - 1
 
     grid = torch.cat([xgrid, ygrid], dim=-1)
-    img = F.grid_sample(img, grid, align_corners=True)
+    # img = F.grid_sample(img, grid, align_corners=True)
+    img = bilinear_grid_sample(img, grid, align_corners=True)
 
     if mask:
         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
@@ -29,3 +30,79 @@ def manual_pad(x, pady, padx):
 
     pad = (padx, padx, pady, pady)
     return F.pad(x.clone().detach(), pad, "replicate")
+
+# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
+def bilinear_grid_sample(im, grid, align_corners=False):
+    """Given an input and a flow-field grid, computes the output using input
+    values and pixel locations from grid. Only bilinear interpolation is
+    supported for sampling the input pixels.
+
+    Args:
+        im (torch.Tensor): Input feature map, shape (N, C, H, W)
+        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input's
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input's corner pixels,
+            making the sampling more resolution agnostic.
+
+    Returns:
+        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+    """
+    n, c, h, w = im.shape
+    gn, gh, gw, _ = grid.shape
+    assert n == gn
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+
+    if align_corners:
+        x = ((x + 1) / 2) * (w - 1)
+        y = ((y + 1) / 2) * (h - 1)
+    else:
+        x = ((x + 1) * w - 1) / 2
+        y = ((y + 1) * h - 1) / 2
+
+    x = x.view(n, -1)
+    y = y.view(n, -1)
+
+    x0 = torch.floor(x).long()
+    y0 = torch.floor(y).long()
+    x1 = x0 + 1
+    y1 = y0 + 1
+
+    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+    wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+    # Zero padding, matching the default padding_mode of grid_sample
+    im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+    padded_h = h + 2
+    padded_w = w + 2
+    # Shift point positions to account for the padding
+    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+    # Clip coordinates to padded image size
+    x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0)
+    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0)
+    x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1)
+    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1)
+    y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0)
+    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0)
+    y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1)
+    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1)
+
+    im_padded = im_padded.view(n, c, -1)
+
+    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+    Ia = torch.gather(im_padded, 2, x0_y0)
+    Ib = torch.gather(im_padded, 2, x0_y1)
+    Ic = torch.gather(im_padded, 2, x1_y0)
+    Id = torch.gather(im_padded, 2, x1_y1)
+
+    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
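
Testing note (not part of the patch): bilinear_grid_sample re-implements F.grid_sample with plain gather ops, a common workaround when a deployment target lacks a GridSample op, which the referenced link suggests is the motivation here. Below is a minimal sanity check that the re-implementation stays close to the native op on random in-range inputs; the tensor shapes and tolerance are arbitrary illustration choices, not values from the repository.

    import torch
    import torch.nn.functional as F
    from nets.utils.utils import bilinear_grid_sample

    im = torch.randn(2, 3, 32, 48)            # (N, C, H, W) feature map
    grid = torch.rand(2, 16, 24, 2) * 2 - 1   # (N, Hg, Wg, 2) coords in [-1, 1]

    for align_corners in (True, False):
        ref = F.grid_sample(im, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=align_corners)
        out = bilinear_grid_sample(im, grid, align_corners=align_corners)
        assert torch.allclose(ref, out, atol=1e-5)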