|
|
|
@ -20,18 +20,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
t1 = torch.rand(1, 3, in_h, in_w) |
|
|
|
|
t2 = torch.rand(1, 3, in_h, in_w) |
|
|
|
|
flow_init = torch.rand(1, 2, in_h//2, in_w) |
|
|
|
|
|
|
|
|
|
# Export the model |
|
|
|
|
# !! Needs Pytorch nightly until next release (1.12). Ref: https://github.com/pytorch/pytorch/pull/73760 |
|
|
|
|
torch.onnx.export(model, |
|
|
|
|
(t1_half, t2_half), |
|
|
|
|
"crestereo_without_flow.onnx", # where to save the model (can be a file or file-like object) |
|
|
|
|
export_params=True, # store the trained parameter weights inside the model file |
|
|
|
|
opset_version=12, # the ONNX version to export the model to |
|
|
|
|
do_constant_folding=True, # whether to execute constant folding for optimization |
|
|
|
|
input_names = ['left', 'right'], # the model's input names |
|
|
|
|
output_names = ['output']) |
|
|
|
|
flow_init = torch.rand(1, 2, in_h//2, in_w//2) |
|
|
|
|
|
|
|
|
|
# Export the model |
|
|
|
|
torch.onnx.export(model, |
|
|
|
@ -43,5 +32,17 @@ if __name__ == '__main__': |
|
|
|
|
input_names = ['left', 'right','flow_init'], # the model's input names |
|
|
|
|
output_names = ['output']) |
|
|
|
|
|
|
|
|
|
# # Export the model without init_flow (it takes a lot of time) |
|
|
|
|
# # !! Needs Pytorch nightly until next release (1.12). Ref: https://github.com/pytorch/pytorch/pull/73760 |
|
|
|
|
# torch.onnx.export(model, |
|
|
|
|
# (t1_half, t2_half), |
|
|
|
|
# "crestereo_without_flow.onnx", # where to save the model (can be a file or file-like object) |
|
|
|
|
# export_params=True, # store the trained parameter weights inside the model file |
|
|
|
|
# opset_version=12, # the ONNX version to export the model to |
|
|
|
|
# do_constant_folding=True, # whether to execute constant folding for optimization |
|
|
|
|
# input_names = ['left', 'right'], # the model's input names |
|
|
|
|
# output_names = ['output']) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|