CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

419 lines
14 KiB

import os
import sys
import time
import logging
from collections import namedtuple
import yaml
# from tensorboardX import SummaryWriter
from nets import Model
# from dataset import CREStereoDataset
from dataset import CREStereoDataset, CTDDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import numpy as np
import cv2
def normalize_and_colormap(img):
    """Min-max scale ``img`` to 0..255 and colour it with OpenCV's INFERNO map.

    Args:
        img: A ``torch.Tensor`` or ``numpy.ndarray`` image (callers pass
            disparity and disparity-error maps). Tensors are detached and
            copied to host memory before processing.

    Returns:
        The result of ``cv2.applyColorMap`` applied to the ``uint8``-scaled
        image. A constant image (max == min) is scaled to all zeros instead
        of dividing by zero.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    lo = img.min()
    value_range = img.max() - lo
    if value_range > 0:
        scaled = (img - lo) / value_range * 255.0
    else:
        # A flat image would otherwise produce NaNs, whose uint8 cast is undefined.
        scaled = np.zeros(img.shape, dtype=np.float64)
    return cv2.applyColorMap(scaled.astype("uint8"), cv2.COLORMAP_INFERNO)
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
    """Run a coarse-to-fine forward pass on one stereo pair and log results.

    The model is first run on a half-resolution copy of the pair. Its output
    is then passed as ``flow_init`` to a second, full-resolution pass. When
    ``wandb_log`` is enabled, the following are sent via ``wandb.log``:
    every iteration's prediction, the inputs, the ground truth and the
    disparity error.

    Args:
        left: Left image tensor of a single sample. After a batch axis is
            added it is moved channel-first with ``transpose(2, 3).transpose(1, 2)``,
            so it is presumably ``(H, W, C)`` -- TODO confirm.
        right: Right image tensor of the same sample. It is only transposed
            when its shape differs from the transposed left image, and it is
            logged via ``transpose(1, 2, 0)``, suggesting ``(C, H, W)``.
        gt_disp: Ground-truth disparity for the sample (presumably a CPU
            tensor taken from the DataLoader batch).
        mask: Valid-pixel mask. Currently unused.
        model: Stereo network, called as ``model(image1, image2, iters, flow_init)``.
        epoch: Currently unused (the training loop passes its batch index).
        n_iter: Refinement iterations requested from the model.
        wandb_log: If False, skip logging and return the full-resolution
            disparity channel.

    Returns:
        The squeezed disparity channel as a numpy array when ``wandb_log`` is
        False, otherwise None.

    Raises:
        RuntimeError: If the model produced no per-iteration predictions.
    """
    print("Model Forwarding...")
    left = left.cpu().detach().numpy()
    right_np = right.cpu().detach().numpy()
    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right_np[None, :, :, :])

    # Run on the device holding the model's weights instead of assuming cuda:0.
    device = next(model.parameters()).device
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    imgL = imgL.transpose(2, 3).transpose(1, 2)
    if imgL.shape != imgR.shape:
        imgR = imgR.transpose(2, 3).transpose(1, 2)

    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(imgL, size=half_size, mode="bilinear", align_corners=True)
    imgR_dw2 = F.interpolate(imgR, size=half_size, mode="bilinear", align_corners=True)

    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
    if not wandb_log:
        return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()

    log = {}
    in_h, in_w = left.shape[:2]
    # Evaluation resolution. It currently equals the input size; it is kept
    # separate so that a downscaled evaluation (to save GPU memory) only
    # needs this line changed.
    eval_h, eval_w = (in_h, in_w)

    pred_disp = None
    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        # Log each prediction at its own resolution; any fixed reshape would
        # fail (or scramble pixels) for other input sizes.
        log[f'pred_{i}'] = wandb.Image(
            np.array([pred_disp]),
            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
        )
        log[f'pred_norm_{i}'] = wandb.Image(
            np.array([pred_disp_norm]),
            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
        )
        log[f'pred_dw2_{i}'] = wandb.Image(
            np.array([pred_disp_dw2]),
            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.2f}/{pred_disp_dw2.max():.2f}",
        )
        log[f'pred_dw2_norm_{i}'] = wandb.Image(
            np.array([pred_disp_dw2_norm]),
            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.2f}/{pred_disp_dw2.max():.2f}",
        )
    if pred_disp is None:
        raise RuntimeError("model returned no disparity predictions")

    # Visualise the final prediction, whatever number of iterations the
    # model actually returned.
    t = float(in_w) / float(eval_w)
    disp = cv2.resize(pred_disp, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
    log['disp_vis'] = wandb.Image(
        normalize_and_colormap(disp),
        caption=f"Disparity \n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
    )

    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
    log['input_right'] = wandb.Image(right_np.transpose(1, 2, 0).astype('uint8'), caption="Input Right")
    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.2f}/{gt_disp.max():.2f}")

    # Compute the error purely in numpy. Mixing a torch tensor with a numpy
    # array is fragile. Squeezing also drops any singleton axis on the GT, so
    # it lines up with the (H, W) prediction instead of broadcasting against it.
    if isinstance(gt_disp, torch.Tensor):
        gt_np = gt_disp.cpu().detach().numpy()
    else:
        gt_np = np.asarray(gt_disp)
    disp_error = np.squeeze(gt_np) - disp
    abs_error = np.abs(disp_error)
    log['disp_error'] = wandb.Image(
        normalize_and_colormap(abs_error),
        caption=f"Disp. Error\n{disp_error.min():.2f}/{disp_error.max():.2f}\n{abs_error.mean():.2f}",
    )
    log['gt_disp_vis'] = wandb.Image(
        normalize_and_colormap(gt_disp),
        caption=f"GT Disp Vis \n{gt_disp.min():.2f}/{gt_disp.max():.2f}",
    )
    wandb.log(log)
def parse_yaml(file_path: str) -> namedtuple:
    """Parse a YAML configuration file into a ``train_args`` namedtuple.

    Each top-level key becomes a namedtuple field, so keys must be valid
    Python identifiers.

    Args:
        file_path: Path to the YAML file.

    Returns:
        A ``train_args`` namedtuple instance holding the top-level values.

    Raises:
        ValueError: If the document is empty or its top level is not a mapping.
    """
    with open(file_path, "rb") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        # safe_load returns None for an empty file and a list/scalar for other
        # documents. Either would otherwise fail below with an opaque
        # AttributeError on ``.keys()``.
        raise ValueError(
            f"{file_path}: expected a mapping at the top level, got {type(cfg).__name__}"
        )
    return namedtuple("train_args", cfg.keys())(*cfg.values())
def format_time(elapse):
    """Format a duration in seconds as ``HH:MM:SS``.

    Fractional seconds are truncated toward zero. Hours are not wrapped, so
    100 h or more yields three or more hour digits. A negative duration
    (e.g. an ETA that has gone negative) is shown as its magnitude with a
    leading ``-``. Floor division on negatives would otherwise give
    mixed-sign fields.

    Args:
        elapse: Duration in seconds (int or float).

    Returns:
        The formatted string.
    """
    elapse = int(elapse)
    sign = "-" if elapse < 0 else ""
    minutes, seconds = divmod(abs(elapse), 60)
    hours, minutes = divmod(minutes, 60)
    return "{}{:02d}:{:02d}:{:02d}".format(sign, hours, minutes, seconds)
def ensure_dir(path):
    """Create directory ``path`` and any missing parents; no-op if it exists.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory. This
            surfaces a misconfigured log directory immediately rather than
            when a file is later written into it.
    """
    # makedirs(exist_ok=True) already tolerates an existing directory. A
    # separate exists() pre-check is racy and would hide a file at the path.
    os.makedirs(path, exist_ok=True)
def adjust_learning_rate(optimizer, epoch, *, cfg=None, warm_up=0.02,
                         const_range=0.6, min_lr_rate=0.05):
    """Set every param group's learning rate for ``epoch`` and return it.

    Piecewise-linear schedule, with ``N = cfg.n_total_epoch`` and
    ``base = cfg.base_lr``:

    * ``epoch <= N * warm_up``: ramp from ``min_lr_rate * base`` at epoch 0
      up to ``base``.
    * ``N * warm_up < epoch <= N * const_range``: constant ``base``.
    * Afterwards: linear decay reaching ``min_lr_rate * base`` at epoch ``N``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts), e.g. a
            ``torch.optim.Optimizer``.
        epoch: Current epoch index.
        cfg: Object with ``n_total_epoch`` and ``base_lr`` attributes. If
            omitted, the module-level ``args`` created by the ``__main__``
            entry point is used, as before.
        warm_up: Fraction of ``N`` at which the warm-up ends.
        const_range: Fraction of ``N`` at which the constant phase ends.
        min_lr_rate: Fraction of ``base`` at the start and end of the schedule.

    Returns:
        The learning rate that was applied.
    """
    if cfg is None:
        cfg = args  # module-level global set in the ``__main__`` block
    n_total = cfg.n_total_epoch
    base_lr = cfg.base_lr
    if epoch <= n_total * warm_up:
        lr = (1 - min_lr_rate) * base_lr / (n_total * warm_up) * epoch + min_lr_rate * base_lr
    elif n_total * warm_up < epoch <= n_total * const_range:
        lr = base_lr
    else:
        # Line through (N * const_range, base) and (N, min_lr_rate * base).
        lr = (min_lr_rate - 1) * base_lr / ((1 - const_range) * n_total) * epoch + (
            1 - min_lr_rate * const_range
        ) / (1 - const_range) * base_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    """Exponentially weighted, masked L1 loss over a sequence of predictions.

    Prediction ``k`` of ``n`` is weighted by ``gamma ** (n - 1 - k)``, so the
    last prediction has weight 1 and earlier ones are discounted. Each term
    is the mean over *all* elements of ``valid * |pred - gt|``: masked-out
    pixels contribute zeros to the mean rather than being excluded from it.

    Args:
        flow_preds: Sequence of prediction tensors, each ``(B, 2, H, W)``.
        flow_gt: Target ``(B, 2, H, W)``. A trailing singleton axis, if
            present, is squeezed away first.
        valid: Mask ``(B, H, W)``; an axis is inserted at dim 1 so it
            broadcasts over the flow channels.
        gamma: Per-step discount factor.

    Returns:
        The weighted sum as a tensor, or ``0.0`` if ``flow_preds`` is empty.
    """
    target = flow_gt.squeeze(-1)
    mask = valid.unsqueeze(1)
    last = len(flow_preds) - 1
    total = 0.0
    for step, pred in enumerate(flow_preds):
        total += gamma ** (last - step) * (mask * (pred - target).abs()).mean()
    return total
def _create_worklog(log_dir, log_level):
    """Build the ``train_logger`` logger writing to ``<log_dir>/worklog.txt`` and stdout."""
    # NOTE(review): ``log_level`` is config text passed to eval() (presumably
    # something like "logging.INFO"). eval executes arbitrary code, so only
    # use trusted config files.
    logging.basicConfig(level=eval(log_level))
    worklog = logging.getLogger("train_logger")
    worklog.propagate = False
    file_handler = logging.FileHandler(
        os.path.join(log_dir, "worklog.txt"), mode="a", encoding="utf8"
    )
    file_handler.setFormatter(
        logging.Formatter(fmt="%(asctime)s %(message)s", datefmt="%Y/%m/%d %H:%M:%S")
    )
    console_handler = logging.StreamHandler(sys.stdout)
    # The console variant wraps the timestamp in ANSI green escape codes.
    console_handler.setFormatter(
        logging.Formatter(
            fmt="\x1b[32m%(asctime)s\x1b[0m %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
        )
    )
    worklog.handlers = [file_handler, console_handler]
    return worklog


def main(args):
    """Train the stereo model with the settings in ``args``.

    Steps:

    * Build the model, wrapped in ``nn.DataParallel`` over all visible GPUs,
      and the Adam optimizer.
    * Set up wandb and the file/console worklog.
    * Resume from ``<log_dir>/models/latest.pth`` (or ``args.loadmodel``)
      when available.
    * Run the epoch loop. ``latest.pth`` is saved every epoch and
      ``epoch-<n>.pth`` every ``args.model_save_freq_epoch`` epochs.

    Args:
        args: Configuration namedtuple (see ``parse_yaml``). Fields read here:
            seed, log_dir, max_disp, mixed_precision, log_level, loadmodel,
            training_data_path, batch_size, minibatch_per_epoch,
            n_total_epoch, model_save_freq_epoch. ``adjust_learning_rate``
            additionally reads the module-level ``args``.
    """
    # initial info
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    world_size = torch.cuda.device_count()  # number of GPU(s)

    # directory check
    log_model_dir = os.path.join(args.log_dir, "models")
    ensure_dir(log_model_dir)

    # model / optimizer
    model = Model(
        max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
    )
    model = nn.DataParallel(model, device_ids=list(range(world_size)))
    model.cuda()
    # The initial lr is a placeholder: adjust_learning_rate overwrites it at
    # the start of every epoch.
    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
    wandb.watch(model)

    # worklog
    worklog = _create_worklog(args.log_dir, args.log_level)

    # params stat
    worklog.info(f"Use {world_size} GPU(s)")
    worklog.info("Params: %s" % sum(p.numel() for p in model.parameters()))

    # load pretrained model if exist
    chk_path = os.path.join(log_model_dir, "latest.pth")
    if args.loadmodel is not None:
        chk_path = args.loadmodel
    elif not os.path.exists(chk_path):
        chk_path = None

    if chk_path is not None:
        worklog.info(f"loading model: {chk_path}")
        state_dict = torch.load(chk_path)
        model.load_state_dict(state_dict['state_dict'])
        optimizer.load_state_dict(state_dict['optim_state_dict'])
        start_epoch_idx = state_dict["epoch"] + 1
        start_iters = state_dict["iters"]
    else:
        start_epoch_idx = 1
        start_iters = 0

    # datasets
    # NOTE(review): machine-specific absolute path; consider moving it into the config.
    pattern_path = '/home/nils/kinect_reference_cropped.png'
    dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
    worklog.info(f"Dataset size: {len(dataset)}")
    dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
                            num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)

    # counter
    cur_iters = start_iters
    total_iters = args.minibatch_per_epoch * args.n_total_epoch
    t0 = time.perf_counter()
    for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
        # adjust learning rate
        epoch_total_train_loss = 0
        epoch_batches = 0
        adjust_learning_rate(optimizer, epoch_idx)
        model.train()
        t1 = time.perf_counter()

        for batch_idx, mini_batch_data in enumerate(dataloader):
            # An epoch is capped at minibatch_per_epoch batches (indices 0..N-1).
            if batch_idx >= args.minibatch_per_epoch:
                break
            cur_iters += 1
            epoch_batches += 1

            # parse data
            left, right, gt_disp, valid_mask = (
                mini_batch_data["left"],
                mini_batch_data["right"],
                mini_batch_data["disparity"].cuda(),
                mini_batch_data["mask"].cuda(),
            )
            t2 = time.perf_counter()
            optimizer.zero_grad()

            # pre-process: disparity -> 2-channel flow with a zero second channel
            gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
            gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]

            # forward
            # transpose(1, 3).transpose(2, 3) maps (B, H, W, C) -> (B, C, H, W),
            # so left presumably arrives channel-last -- TODO confirm.
            left = left.transpose(1, 3).transpose(2, 3)
            # right is passed through unchanged; presumably already channel-first.
            flow_predictions = model(left.cuda(), right.cuda())

            # loss & backward
            loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)

            if batch_idx % 128 == 0:
                # Periodic visual check on the first sample of the batch.
                # NOTE(review): this runs before backward() with the model in
                # train mode, so the training graph is still held in memory
                # during these extra forward passes -- confirm intended.
                inference(
                    mini_batch_data['left'][0],
                    mini_batch_data['right'][0],
                    mini_batch_data['disparity'][0],
                    mini_batch_data['mask'][0],
                    model,
                    batch_idx,
                )

            # loss stats
            loss_item = loss.item()
            epoch_total_train_loss += loss_item
            loss.backward()
            optimizer.step()
            t3 = time.perf_counter()

            if cur_iters % 10 == 0:
                tdata = t2 - t1
                time_train_passed = t3 - t0
                time_iter_passed = t3 - t1
                step_passed = cur_iters - start_iters
                eta = (
                    (total_iters - cur_iters)
                    / max(step_passed, 1e-7)
                    * time_train_passed
                )
                meta_info = [
                    "{:.2g} b/s".format(1.0 / time_iter_passed),
                    "passed:{}".format(format_time(time_train_passed)),
                    "eta:{}".format(format_time(eta)),
                    "data_time:{:.2g}".format(tdata / time_iter_passed),
                    "lr:{:.5g}".format(optimizer.param_groups[0]["lr"]),
                    "[{}/{}:{}/{}]".format(
                        epoch_idx,
                        args.n_total_epoch,
                        batch_idx,
                        args.minibatch_per_epoch,
                    ),
                ]
                loss_info = [" ==> {}:{:.4g}".format("loss", loss_item)]
                worklog.info("".join([",".join(meta_info)] + loss_info))

                # Minibatch metrics. A fresh dict is used so no stale epoch-level
                # values are re-logged alongside them.
                wandb.log({
                    'train/loss_batch': loss_item,
                    'train/lr': optimizer.param_groups[0]["lr"],
                })
            t1 = time.perf_counter()

        # Epoch loss, averaged over the batches actually processed (the loader
        # may yield fewer than minibatch_per_epoch batches).
        epoch_loss = epoch_total_train_loss / max(epoch_batches, 1)
        wandb.log({'train/loss': epoch_loss})

        # save model params
        ckp_data = {
            "epoch": epoch_idx,
            "iters": cur_iters,
            "batch_size": args.batch_size,
            "epoch_size": args.minibatch_per_epoch,
            "train_loss": epoch_loss,
            "state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }
        torch.save(ckp_data, os.path.join(log_model_dir, "latest.pth"))
        if epoch_idx % args.model_save_freq_epoch == 0:
            save_path = os.path.join(log_model_dir, "epoch-%d.pth" % epoch_idx)
            worklog.info(f"Model params saved: {save_path}")
            torch.save(ckp_data, save_path)

    worklog.info("Training is done, exit.")
if __name__ == "__main__":
    # Load the training configuration. ``args`` must remain a module-level
    # name because adjust_learning_rate reads it as a global.
    args = parse_yaml("cfgs/train.yaml")

    # Register the run with wandb and attach every config value to it.
    wandb.init(entity="cpt-captain", project="crestereo")
    config_values = args._asdict()
    wandb.config.update(config_values)

    main(args=args)