You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
877 B
29 lines
877 B
3 years ago
|
import pickle
|
||
|
import numpy as np
|
||
|
import megengine as mge
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
def coords_grid(batch, ht, wd, device):
    """Build a batched grid of pixel coordinates.

    Args:
        batch: Number of copies of the grid along the leading dimension.
        ht: Grid height (number of rows).
        wd: Grid width (number of columns).
        device: Torch device on which the index ranges are created.

    Returns:
        A float tensor of shape ``(batch, 2, ht, wd)``. Channel 0 holds the
        column index (x) of each cell and channel 1 holds the row index (y).
    """
    row_idx = torch.arange(ht, device=device)
    col_idx = torch.arange(wd, device=device)
    # With 'ij' indexing the first output varies along rows (y) and the
    # second along columns (x).
    ys, xs = torch.meshgrid(row_idx, col_idx, indexing='ij')
    # Stack x before y, i.e. the reverse of the meshgrid output order.
    grid = torch.stack((xs, ys), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
|
||
|
|
||
|
def test_coords_grid():
    """Check the PyTorch ``coords_grid`` against a pickled reference grid.

    Loads ``(batch, ht, wd, coords)`` from
    ``test_data/coords_grid_test.pickle`` (resolved relative to the current
    working directory), rebuilds the grid with ``coords_grid`` on the CPU,
    prints the mean absolute difference together with both shapes, and then
    asserts that the two grids agree.

    Raises:
        AssertionError: If the shapes differ or the values are not close.
    """
    # Getting back the megengine objects:
    # NOTE: pickle.load can execute arbitrary code; only use it on the
    # trusted fixture shipped with the tests, never on external data.
    with open('test_data/coords_grid_test.pickle', 'rb') as f:
        batch, ht, wd, coords = pickle.load(f)

    # The reference is presumably a megengine tensor; convert it to NumPy.
    coords = coords.numpy()

    # Test Pytorch
    coords_pytorch = coords_grid(batch, ht, wd, 'cpu').numpy()

    # Use the absolute difference: a signed mean lets positive and negative
    # deviations cancel and can report ~0 for a grid that is actually wrong.
    error = np.mean(np.abs(coords_pytorch - coords))
    print(f"test_coords_grid - Avg. Error: {error}, \n "
          f"Obtained shape: {coords_pytorch.shape}, Expected shape: {coords.shape}")

    # Fail loudly so the check is meaningful under a test runner; the
    # diagnostics above are printed first to aid debugging.
    assert coords_pytorch.shape == coords.shape, (
        f"shape mismatch: {coords_pytorch.shape} != {coords.shape}")
    assert np.allclose(coords_pytorch, coords), (
        f"coords_grid differs from reference (mean abs error {error})")
|
||
|
|
||
|
# Script entry point: run the comparison when this file is executed directly.
if __name__ == '__main__':
    test_coords_grid()
|