You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
951 B
31 lines
951 B
3 years ago
|
import pickle
|
||
|
import numpy as np
|
||
|
import megengine as mge
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
def test_offset():
|
||
|
# Getting back the megengine objects:
|
||
|
with open('test_data/offset_test.pkl', 'rb') as f:
|
||
|
x_grid, y_grid, reshape_shape, transpose_order, expand_size, repeat_size, repeat_axis, offsets = pickle.load(f)
|
||
|
|
||
|
x_grid = torch.tensor(x_grid.numpy())
|
||
|
y_grid = torch.tensor(y_grid.numpy())
|
||
|
offsets_mge = offsets.numpy()
|
||
|
N = repeat_size
|
||
|
|
||
|
# Test Pytorch
|
||
|
offsets = torch.stack((x_grid, y_grid))
|
||
|
offsets = offsets.reshape(2, -1).permute(1, 0)
|
||
|
for d in sorted((0, 2, 3)):
|
||
|
offsets = offsets.unsqueeze(d)
|
||
|
offsets = offsets.repeat_interleave(N, dim=0)
|
||
|
|
||
|
error = np.mean(offsets.numpy()-offsets_mge)
|
||
|
print(f"test_offset - Avg. Error: {error}, \n \
|
||
|
Obtained shape: {offsets.numpy().shape}, Expected shape: {offsets_mge.shape}")
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
|
||
|
test_offset()
|