|
|
|
@ -387,7 +387,7 @@ class RectifiedPatternSimilarityLoss(TimedModule): |
|
|
|
|
if std is not None: |
|
|
|
|
mask = mask * std |
|
|
|
|
|
|
|
|
|
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) |
|
|
|
|
diff = torchext.photometric_loss_pytorch(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) |
|
|
|
|
val = (mask * diff).sum() / mask.sum() |
|
|
|
|
return val, pattern_proj |
|
|
|
|
|
|
|
|
|