|
|
|
@ -81,7 +81,7 @@ class CREStereo(nn.Module): |
|
|
|
|
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) |
|
|
|
|
return zero_flow |
|
|
|
|
|
|
|
|
|
def forward(self, image1, image2, flow_init, iters=10, upsample=True, test_mode=False): |
|
|
|
|
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False): |
|
|
|
|
""" Estimate optical flow between pair of frames """ |
|
|
|
|
|
|
|
|
|
image1 = 2 * (image1 / 255.0) - 1.0 |
|
|
|
|