CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

29 lines
877 B

3 years ago
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def coords_grid(batch, ht, wd, device):
    """Return a batch of per-pixel (x, y) coordinate grids.

    Args:
        batch: Number of copies stacked along the leading dimension.
        ht: Grid height (number of rows).
        wd: Grid width (number of columns).
        device: Torch device on which the tensors are created.

    Returns:
        A float32 tensor of shape ``(batch, 2, ht, wd)``. Channel 0 holds each
        pixel's column index (x) and channel 1 its row index (y). Every batch
        entry is an independent copy, not a view of a shared buffer.
    """
    row_ids = torch.arange(ht, device=device)
    col_ids = torch.arange(wd, device=device)
    # 'ij' indexing: first output varies along rows (y), second along columns (x).
    grid_y, grid_x = torch.meshgrid(row_ids, col_ids, indexing='ij')
    # Put x before y to get an (x, y) channel order.
    xy = torch.stack((grid_x, grid_y), dim=0).float()
    return xy.unsqueeze(0).repeat(batch, 1, 1, 1)
def test_coords_grid(*, path='test_data/coords_grid_test.pickle'):
    """Compare ``coords_grid`` against a pickled MegEngine reference output.

    The pickle at *path* holds a ``(batch, ht, wd, coords)`` tuple. ``coords``
    is the reference grid and must expose ``.numpy()`` (presumably a MegEngine
    tensor, given the module's ``megengine`` import -- confirm).

    Args:
        path: Location of the reference pickle. It is keyword-only with a
            default so pytest does not treat it as a fixture. The default is
            relative and is resolved against the current working directory.

    Raises:
        AssertionError: If the obtained and expected shapes differ, or if any
            element differs beyond ``numpy.testing.assert_allclose`` defaults.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only ever point
    # this at trusted, locally generated fixtures.
    with open(path, 'rb') as f:
        batch, ht, wd, coords = pickle.load(f)
    expected = coords.numpy()

    obtained = coords_grid(batch, ht, wd, 'cpu').numpy()

    # Compare shapes before subtracting. Mismatched shapes could otherwise
    # broadcast silently (e.g. a batch of 1 against a batch of N).
    shapes_match = obtained.shape == expected.shape
    # Use the mean *absolute* error. A signed mean lets positive and negative
    # differences cancel, so it can report 0 for arrays that disagree.
    error = np.mean(np.abs(obtained - expected)) if shapes_match else float('nan')
    print(f"test_coords_grid - Mean Abs. Error: {error}, \n"
          f"Obtained shape: {obtained.shape}, Expected shape: {expected.shape}")

    assert shapes_match, (
        f"shape mismatch: obtained {obtained.shape}, expected {expected.shape}")
    np.testing.assert_allclose(
        obtained, expected, err_msg="coords_grid differs from reference")
if __name__ == "__main__":
    # Script entry point: run the MegEngine-vs-PyTorch comparison directly,
    # without going through a test runner.
    test_coords_grid()