From ed71c16912b9d9ad41719124330a7c39b24ee370 Mon Sep 17 00:00:00 2001 From: Nils Koch Date: Tue, 31 May 2022 11:56:54 +0200 Subject: [PATCH] api_server.py: add missing device --- api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_server.py b/api_server.py index ef2ac76..18434f4 100644 --- a/api_server.py +++ b/api_server.py @@ -37,7 +37,7 @@ model_path = "../train_log/models/latest.pth" device = torch.device('cuda:0') model = Model(max_disp=256, mixed_precision=False, test_mode=True) -model = nn.DataParallel(model, device_ids=[]) +model = nn.DataParallel(model, device_ids=[device]) # model.load_state_dict(torch.load(model_path), strict=False) state_dict = torch.load(model_path)['state_dict'] model.load_state_dict(state_dict, strict=True)