api_server.py: add missing device

main
Nils Koch 2 years ago
parent 8f0c3a32b8
commit ed71c16912
  1. 2
      api_server.py

@ -37,7 +37,7 @@ model_path = "../train_log/models/latest.pth"
device = torch.device('cuda:0') device = torch.device('cuda:0')
model = Model(max_disp=256, mixed_precision=False, test_mode=True) model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[]) model = nn.DataParallel(model, device_ids=[device])
# model.load_state_dict(torch.load(model_path), strict=False) # model.load_state_dict(torch.load(model_path), strict=False)
state_dict = torch.load(model_path)['state_dict'] state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)

Loading…
Cancel
Save