api_server.py: add missing device
This commit is contained in:
parent
8f0c3a32b8
commit
ed71c16912
@ -37,7 +37,7 @@ model_path = "../train_log/models/latest.pth"
|
|||||||
device = torch.device('cuda:0')
|
device = torch.device('cuda:0')
|
||||||
|
|
||||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||||
model = nn.DataParallel(model, device_ids=[])
|
model = nn.DataParallel(model, device_ids=[device])
|
||||||
# model.load_state_dict(torch.load(model_path), strict=False)
|
# model.load_state_dict(torch.load(model_path), strict=False)
|
||||||
state_dict = torch.load(model_path)['state_dict']
|
state_dict = torch.load(model_path)['state_dict']
|
||||||
model.load_state_dict(state_dict, strict=True)
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user