|
|
|
import torch
|
|
|
|
|
|
|
|
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    """Compute a per-pixel, window-aggregated photometric loss between two images.

    For every pixel, the ``block_size`` x ``block_size`` window centred on it is
    extracted from both ``es`` and ``ta`` (image borders are handled with
    replicate padding). A per-element cost is computed inside the window,
    summed over the window elements *and* over channels, and divided by
    ``block_size ** 2`` (the window area; channels are not averaged).

    Args:
        es: Estimated image tensor, indexed as ``(N, C, H, W)``.
        ta: Target image tensor, expected to have the same shape as ``es``.
        block_size: Side length of the square window. Must be a positive odd
            integer so the window has a centre pixel and the output keeps the
            input's spatial size.
        type: Case-insensitive cost name:
            ``'mse'``        squared intensity difference,
            ``'sad'``        absolute intensity difference,
            ``'census_mse'`` / ``'census_sad'`` squared / absolute difference of
            ``0.5 * (1 + d / sqrt(d**2 + eps))`` where ``d`` is each window
            value minus the window's centre value (a smooth approximation of
            the step ``d > 0``).
        eps: Smoothing constant for the census costs; ignored otherwise. Must
            be > 0 for census costs, because ``d == 0`` at the window centre and
            ``eps == 0`` would give ``0 / 0``.

    Returns:
        Tensor of shape ``(N, 1, H, W)``.

    Raises:
        ValueError: if ``type`` is unknown, ``block_size`` is not a positive
            odd integer, or ``eps <= 0`` with a census cost.
    """
    loss_type = type.lower()
    if loss_type not in ('mse', 'sad', 'census_mse', 'census_sad'):
        # Validate before doing any (potentially large) unfold allocations.
        raise ValueError(f'invalid loss type: {type!r}')
    if block_size < 1 or block_size % 2 == 0:
        # Symmetric padding of block_size // 2 only preserves H x W for odd sizes.
        raise ValueError(
            f'block_size must be a positive odd integer, got {block_size!r}')
    is_census = loss_type.startswith('census_')
    if is_census and not eps > 0:
        raise ValueError(f'eps must be > 0 for census losses, got {eps!r}')

    p = block_size // 2

    def windows(x):
        # (N, C, H, W) -> (N, C, block_size**2, H, W): each pixel's neighbourhood.
        # unfold yields (N, C * k * k, L) with the channel as the slower index,
        # and L == H * W because of the symmetric padding for odd k.
        padded = torch.nn.functional.pad(x, (p, p, p, p), mode='replicate')
        cols = torch.nn.functional.unfold(padded, kernel_size=block_size)
        return cols.reshape(x.shape[0], x.shape[1], -1, x.shape[2], x.shape[3])

    es_win = windows(es)
    ta_win = windows(ta)

    if is_census:
        def soft_census(win, centre):
            # Compare every window element against the window's centre pixel.
            d = win - centre.unsqueeze(2)
            return 0.5 * (1 + d / torch.sqrt(d * d + eps))

        diff = soft_census(es_win, es) - soft_census(ta_win, ta)
    else:
        diff = es_win - ta_win

    cost = diff * diff if loss_type.endswith('mse') else torch.abs(diff)

    # Reduce channels and window elements together while keeping the batch
    # dimension of the (possibly broadcast) result intact.
    return cost.sum(dim=(1, 2)).unsqueeze(1) / block_size ** 2
|