|
|
@ -12,7 +12,8 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): |
|
|
|
ygrid = 2*ygrid/(H-1) - 1 |
|
|
|
ygrid = 2*ygrid/(H-1) - 1 |
|
|
|
|
|
|
|
|
|
|
|
grid = torch.cat([xgrid, ygrid], dim=-1) |
|
|
|
grid = torch.cat([xgrid, ygrid], dim=-1) |
|
|
|
img = F.grid_sample(img, grid, align_corners=True) |
|
|
|
# img = F.grid_sample(img, grid, align_corners=True) |
|
|
|
|
|
|
|
img = bilinear_grid_sample(img, grid, align_corners=True) |
|
|
|
|
|
|
|
|
|
|
|
if mask: |
|
|
|
if mask: |
|
|
|
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) |
|
|
|
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) |
|
|
@ -29,3 +30,79 @@ def manual_pad(x, pady, padx): |
|
|
|
|
|
|
|
|
|
|
|
pad = (padx, padx, pady, pady) |
|
|
|
pad = (padx, padx, pady, pady) |
|
|
|
return F.pad(x.clone().detach(), pad, "replicate") |
|
|
|
return F.pad(x.clone().detach(), pad, "replicate") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160 |
|
|
|
|
|
|
|
def bilinear_grid_sample(im, grid, align_corners=False): |
|
|
|
|
|
|
|
"""Given an input and a flow-field grid, computes the output using input |
|
|
|
|
|
|
|
values and pixel locations from grid. Supported only bilinear interpolation |
|
|
|
|
|
|
|
method to sample the input pixels. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
im (torch.Tensor): Input feature map, shape (N, C, H, W) |
|
|
|
|
|
|
|
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2) |
|
|
|
|
|
|
|
align_corners {bool}: If set to True, the extrema (-1 and 1) are |
|
|
|
|
|
|
|
considered as referring to the center points of the input’s |
|
|
|
|
|
|
|
corner pixels. If set to False, they are instead considered as |
|
|
|
|
|
|
|
referring to the corner points of the input’s corner pixels, |
|
|
|
|
|
|
|
making the sampling more resolution agnostic. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
|
torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg) |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
n, c, h, w = im.shape |
|
|
|
|
|
|
|
gn, gh, gw, _ = grid.shape |
|
|
|
|
|
|
|
assert n == gn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = grid[:, :, :, 0] |
|
|
|
|
|
|
|
y = grid[:, :, :, 1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if align_corners: |
|
|
|
|
|
|
|
x = ((x + 1) / 2) * (w - 1) |
|
|
|
|
|
|
|
y = ((y + 1) / 2) * (h - 1) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
x = ((x + 1) * w - 1) / 2 |
|
|
|
|
|
|
|
y = ((y + 1) * h - 1) / 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.view(n, -1) |
|
|
|
|
|
|
|
y = y.view(n, -1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x0 = torch.floor(x).long() |
|
|
|
|
|
|
|
y0 = torch.floor(y).long() |
|
|
|
|
|
|
|
x1 = x0 + 1 |
|
|
|
|
|
|
|
y1 = y0 + 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wa = ((x1 - x) * (y1 - y)).unsqueeze(1) |
|
|
|
|
|
|
|
wb = ((x1 - x) * (y - y0)).unsqueeze(1) |
|
|
|
|
|
|
|
wc = ((x - x0) * (y1 - y)).unsqueeze(1) |
|
|
|
|
|
|
|
wd = ((x - x0) * (y - y0)).unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Apply default for grid_sample function zero padding |
|
|
|
|
|
|
|
im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0) |
|
|
|
|
|
|
|
padded_h = h + 2 |
|
|
|
|
|
|
|
padded_w = w + 2 |
|
|
|
|
|
|
|
# save points positions after padding |
|
|
|
|
|
|
|
x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Clip coordinates to padded image size |
|
|
|
|
|
|
|
x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0) |
|
|
|
|
|
|
|
x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0) |
|
|
|
|
|
|
|
x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1) |
|
|
|
|
|
|
|
x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1) |
|
|
|
|
|
|
|
y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0) |
|
|
|
|
|
|
|
y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0) |
|
|
|
|
|
|
|
y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1) |
|
|
|
|
|
|
|
y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
im_padded = im_padded.view(n, c, -1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) |
|
|
|
|
|
|
|
x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) |
|
|
|
|
|
|
|
x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) |
|
|
|
|
|
|
|
x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ia = torch.gather(im_padded, 2, x0_y0) |
|
|
|
|
|
|
|
Ib = torch.gather(im_padded, 2, x0_y1) |
|
|
|
|
|
|
|
Ic = torch.gather(im_padded, 2, x1_y0) |
|
|
|
|
|
|
|
Id = torch.gather(im_padded, 2, x1_y1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw) |
|
|
|