diff --git a/api_server.py b/api_server.py index ef2ac76..18434f4 100644 --- a/api_server.py +++ b/api_server.py @@ -37,7 +37,7 @@ model_path = "../train_log/models/latest.pth" device = torch.device('cuda:0') model = Model(max_disp=256, mixed_precision=False, test_mode=True) -model = nn.DataParallel(model, device_ids=[]) +model = nn.DataParallel(model, device_ids=[device]) # model.load_state_dict(torch.load(model_path), strict=False) state_dict = torch.load(model_path)['state_dict'] model.load_state_dict(state_dict, strict=True)