CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
CREStereo-pytorch-nxt/convert_weights.py

26 lines
614 B

import copy
import torch
import numpy as np
import megengine as mge
from nets import Model
def convert_param(value, target_shape):
    """Return ``value`` as a torch tensor with exactly ``target_shape``.

    Only size-1 axes may differ between ``value`` and ``target_shape``. They
    are dropped or inserted via ``reshape``, which never reorders data. Any
    other difference is treated as a real mismatch rather than guessed at.

    Args:
        value: array-like parameter read from the MegEngine checkpoint.
        target_shape: shape of the matching parameter in the PyTorch model
            (a ``torch.Size`` or any sequence of ints).

    Returns:
        A new ``torch.Tensor`` (data copied) of shape ``target_shape``.

    Raises:
        ValueError: if the non-singleton axes of ``value`` differ from those
            of ``target_shape``.
    """
    array = np.asarray(value)
    target_shape = tuple(target_shape)
    if array.shape == target_shape:
        return torch.tensor(array)
    # The checkpoint stores some parameters with extra singleton axes.
    # Presumably these are e.g. conv biases as (1, C, 1, 1) where PyTorch
    # wants (C,); confirm against the checkpoint. np.squeeze would strip
    # *every* size-1 axis, including ones the model needs (a 1-channel bias
    # would become 0-d). So reshape to the exact target shape instead.
    source_core = [dim for dim in array.shape if dim != 1]
    target_core = [dim for dim in target_shape if dim != 1]
    if source_core != target_core:
        raise ValueError(
            f"cannot convert parameter of shape {array.shape} "
            f"to model shape {target_shape}"
        )
    return torch.tensor(array.reshape(target_shape))


def main(input_path="models/crestereo_eth3d.mge",
         output_path="models/crestereo_eth3d.pth"):
    """Convert a MegEngine CREStereo checkpoint into a PyTorch state dict file.

    Loads ``input_path`` with ``mge.load`` and reads its ``'state_dict'``
    entry. It copies every entry whose key also exists in
    ``Model(max_disp=256, mixed_precision=False, test_mode=True).state_dict()``
    into that state dict, reshaped to the model's shape. The result is saved
    with ``torch.save`` to ``output_path``.

    Checkpoint keys unknown to the model are skipped and reported. Model keys
    missing from the checkpoint are reported and keep the model's initial
    values.

    Args:
        input_path: path of the MegEngine ``.mge`` checkpoint to read.
        output_path: path the PyTorch state dict is written to.

    Raises:
        ValueError: if a checkpoint parameter cannot be mapped onto the shape
            of the model parameter with the same key.
    """
    # Read Megengine parameters
    pretrained_dict = mge.load(input_path)
    source_params = pretrained_dict['state_dict']

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    model.eval()
    state_dict = model.state_dict()

    unexpected = []
    for key, value in source_params.items():
        if key not in state_dict:
            # Writing an unknown key would make strict load_state_dict fail.
            unexpected.append(key)
            continue
        print(f"Converting {key}")
        try:
            state_dict[key] = convert_param(value, state_dict[key].shape)
        except ValueError as err:
            raise ValueError(f"{key}: {err}") from err

    missing = [key for key in state_dict if key not in source_params]
    for key in unexpected:
        print(f"Skipping {key}: not present in the PyTorch model")
    for key in missing:
        print(f"Warning: {key} not found in checkpoint; keeping initial value")

    torch.save(state_dict, output_path)
    print(f"\nModel saved to: {output_path}")


if __name__ == "__main__":
    main()