|
|
|
@ -37,7 +37,7 @@ model_path = "../train_log/models/latest.pth" |
|
|
|
|
device = torch.device('cuda:0') |
|
|
|
|
|
|
|
|
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
|
|
|
|
model = nn.DataParallel(model, device_ids=[]) |
|
|
|
|
model = nn.DataParallel(model, device_ids=[device]) |
|
|
|
|
# model.load_state_dict(torch.load(model_path), strict=False) |
|
|
|
|
state_dict = torch.load(model_path)['state_dict'] |
|
|
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
|