import torch


def _extract_patches(img, block_size):
    """Return every pixel's ``block_size x block_size`` neighbourhood.

    ``img`` of shape ``(B, C, H, W)`` is replicate-padded by
    ``block_size // 2`` on each side and unfolded.  The result has shape
    ``(B, C, block_size ** 2, H, W)``; this holds only for odd ``block_size``.
    """
    pad = block_size // 2
    padded = torch.nn.functional.pad(img, (pad, pad, pad, pad), mode='replicate')
    # unfold -> (B, C * block_size**2, H * W) with channel-major ordering,
    # so it can be split into (C, block_size**2) and (H, W) directly.
    cols = torch.nn.functional.unfold(padded, kernel_size=block_size)
    return cols.view(img.shape[0], img.shape[1], -1, img.shape[2], img.shape[3])


def _soft_census(center_diff, eps):
    """Differentiable census bit: ~0 for negative, 0.5 at zero, ~1 for positive."""
    return 0.5 * (1 + center_diff / torch.sqrt(center_diff * center_diff + eps))


def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    """Compute a window-aggregated photometric loss for every pixel.

    For each pixel, a ``block_size x block_size`` window (image borders are
    replicate-padded) is taken from both ``es`` and ``ta`` and compared
    element-wise.  The per-element costs are summed over channels and window
    positions, then divided by ``block_size ** 2``.

    Args:
        es: Estimated image, a 4-D tensor ``(B, C, H, W)``.
        ta: Target image; expected to have the same shape as ``es``.
        block_size: Window side length; must be a positive odd integer so the
            window has a centre pixel.
        type: Loss type (case-insensitive):
            ``'mse'``        squared intensity difference,
            ``'sad'``        absolute intensity difference,
            ``'census_mse'`` squared difference of soft census transforms,
            ``'census_sad'`` absolute difference of soft census transforms.
        eps: Smoothing constant of the soft census transform
            ``0.5 * (1 + d / sqrt(d * d + eps))``; only used by census types.

    Returns:
        Tensor of shape ``(B, 1, H, W)``.

    Raises:
        ValueError: if ``type`` is unknown or ``block_size`` is not a positive
            odd integer.
    """
    loss_type = type.lower()
    # Validate before doing any (memory-heavy) unfolding.
    if loss_type not in ('mse', 'sad', 'census_mse', 'census_sad'):
        raise ValueError(f'invalid loss type: {type!r}')
    if block_size < 1 or block_size % 2 == 0:
        # With an even window, symmetric padding of block_size // 2 yields
        # (H + 1) x (W + 1) patches and the reshape to (H, W) is impossible.
        raise ValueError(
            f'block_size must be a positive odd integer, got {block_size!r}')

    es_uf = _extract_patches(es, block_size)
    ta_uf = _extract_patches(ta, block_size)

    if loss_type.startswith('census'):
        # Compare each neighbour against its window's centre pixel (the
        # original image, broadcast over the window axis) in both images.
        diff = (_soft_census(es_uf - es.unsqueeze(2), eps)
                - _soft_census(ta_uf - ta.unsqueeze(2), eps))
    else:
        diff = es_uf - ta_uf

    if loss_type.endswith('mse'):
        ref = diff * diff
    else:
        ref = torch.abs(diff)

    # Merge the channel and window axes, then sum and normalise by the
    # window area (channels are summed, not averaged).
    ref = ref.reshape(es.shape[0], -1, es.shape[2], es.shape[3])
    return torch.sum(ref, dim=1, keepdim=True) / block_size ** 2