diff --git a/nets/crestereo.py b/nets/crestereo.py index f6d2a37..3f99917 100644 --- a/nets/crestereo.py +++ b/nets/crestereo.py @@ -81,7 +81,7 @@ class CREStereo(nn.Module): zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) return zero_flow - def forward(self, image1, image2, flow_init, iters=10, upsample=True, test_mode=False): + def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False): """ Estimate optical flow between pair of frames """ image1 = 2 * (image1 / 255.0) - 1.0