parent
11959eef61
commit
63da24f429
@ -0,0 +1,346 @@ |
|||||||
|
import os |
||||||
|
import sys |
||||||
|
import time |
||||||
|
import logging |
||||||
|
from collections import namedtuple |
||||||
|
|
||||||
|
import yaml |
||||||
|
# from tensorboardX import SummaryWriter |
||||||
|
|
||||||
|
from nets import Model |
||||||
|
# from dataset import CREStereoDataset |
||||||
|
from dataset import BlenderDataset, CREStereoDataset, CTDDataset |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.nn as nn |
||||||
|
import torch.nn.functional as F |
||||||
|
import torch.optim as optim |
||||||
|
from torch.utils.data import DataLoader |
||||||
|
from pytorch_lightning.lite import LightningLite |
||||||
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
||||||
|
from pytorch_lightning import Trainer, seed_everything |
||||||
|
from pytorch_lightning.loggers import WandbLogger |
||||||
|
|
||||||
|
seed_everything(42, workers=True) |
||||||
|
|
||||||
|
|
||||||
|
import wandb |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import cv2 |
||||||
|
|
||||||
|
|
||||||
|
def normalize_and_colormap(img): |
||||||
|
ret = (img - img.min()) / (img.max() - img.min()) * 255.0 |
||||||
|
if isinstance(ret, torch.Tensor): |
||||||
|
ret = ret.cpu().detach().numpy() |
||||||
|
ret = ret.astype("uint8") |
||||||
|
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO) |
||||||
|
return ret |
||||||
|
|
||||||
|
|
||||||
|
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True): |
||||||
|
|
||||||
|
print("Model Forwarding...") |
||||||
|
if isinstance(left, torch.Tensor): |
||||||
|
left = left# .cpu().detach().numpy() |
||||||
|
imgR = right# .cpu().detach().numpy() |
||||||
|
imgL = left |
||||||
|
imgR = right |
||||||
|
imgL = np.ascontiguousarray(imgL[None, :, :, :]) |
||||||
|
imgR = np.ascontiguousarray(imgR[None, :, :, :]) |
||||||
|
|
||||||
|
flow_init = None |
||||||
|
|
||||||
|
# chosen for convenience |
||||||
|
|
||||||
|
imgL = torch.tensor(imgL.astype("float32")) |
||||||
|
imgR = torch.tensor(imgR.astype("float32")) |
||||||
|
imgL = imgL.transpose(2,3).transpose(1,2) |
||||||
|
if imgL.shape != imgR.shape: |
||||||
|
imgR = imgR.transpose(2,3).transpose(1,2) |
||||||
|
|
||||||
|
imgL_dw2 = F.interpolate( |
||||||
|
imgL, |
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
||||||
|
mode="bilinear", |
||||||
|
align_corners=True, |
||||||
|
).clamp(min=0, max=255) |
||||||
|
imgR_dw2 = F.interpolate( |
||||||
|
imgR, |
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
||||||
|
mode="bilinear", |
||||||
|
align_corners=True, |
||||||
|
).clamp(min=0, max=255) |
||||||
|
if last_img is not None: |
||||||
|
print('using flow_initialization') |
||||||
|
print(last_img.shape) |
||||||
|
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help |
||||||
|
print(last_img.max(), last_img.min()) |
||||||
|
if last_img.min() < 0: |
||||||
|
# print('Negative disparity detected. shifting...') |
||||||
|
last_img = last_img - last_img.min() |
||||||
|
if last_img.max() > 255: |
||||||
|
# print('Excessive disparity detected. scaling...') |
||||||
|
last_img = last_img / (last_img.max() / 255) |
||||||
|
|
||||||
|
|
||||||
|
last_img = np.dstack([last_img, last_img]) |
||||||
|
# last_img = np.dstack([last_img, last_img, last_img]) |
||||||
|
last_img = np.dstack([last_img]) |
||||||
|
last_img = last_img.reshape((1, 2, 480, 640)) |
||||||
|
# print(last_img.shape) |
||||||
|
# print(last_img.dtype) |
||||||
|
# print(last_img.max(), last_img.min()) |
||||||
|
flow_init = torch.tensor(last_img.astype("float32")) |
||||||
|
# flow_init = F.interpolate( |
||||||
|
# last_img, |
||||||
|
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2), |
||||||
|
# mode="bilinear", |
||||||
|
# align_corners=True, |
||||||
|
# ) |
||||||
|
with torch.inference_mode(): |
||||||
|
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern) |
||||||
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern) |
||||||
|
pf_base = pred_flow |
||||||
|
if isinstance(pf_base, list): |
||||||
|
pf_base = pred_flow[0] |
||||||
|
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy() |
||||||
|
print('pred_flow max min') |
||||||
|
print(pf.max(), pf.min()) |
||||||
|
|
||||||
|
|
||||||
|
if not wandb_log: |
||||||
|
if test: |
||||||
|
return pred_flow |
||||||
|
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy() |
||||||
|
|
||||||
|
log = {} |
||||||
|
in_h, in_w = left.shape[:2] |
||||||
|
|
||||||
|
# Resize image in case the GPU memory overflows |
||||||
|
eval_h, eval_w = (in_h,in_w) |
||||||
|
|
||||||
|
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): |
||||||
|
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy() |
||||||
|
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy() |
||||||
|
|
||||||
|
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) |
||||||
|
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) |
||||||
|
|
||||||
|
if i == n_iter-1: |
||||||
|
t = float(in_w) / float(eval_w) |
||||||
|
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t |
||||||
|
|
||||||
|
log[f'disp_vis'] = wandb.Image( |
||||||
|
normalize_and_colormap(disp), |
||||||
|
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
||||||
|
) |
||||||
|
|
||||||
|
log[f'pred_{i}'] = wandb.Image( |
||||||
|
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]), |
||||||
|
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
||||||
|
) |
||||||
|
# log[f'pred_norm_{i}'] = wandb.Image( |
||||||
|
# np.array([pred_disp_norm.reshape(480, 640)]), |
||||||
|
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
||||||
|
# ) |
||||||
|
|
||||||
|
# log[f'pred_dw2_{i}'] = wandb.Image( |
||||||
|
# np.array([pred_disp_dw2.reshape(240, 320)]), |
||||||
|
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
||||||
|
# ) |
||||||
|
# log[f'pred_dw2_norm_{i}'] = wandb.Image( |
||||||
|
# np.array([pred_disp_dw2_norm.reshape(240, 320)]), |
||||||
|
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
||||||
|
# ) |
||||||
|
|
||||||
|
|
||||||
|
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") |
||||||
|
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right |
||||||
|
if input_right.shape != (480, 640, 3): |
||||||
|
input_right.transpose(1,2,0) |
||||||
|
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right") |
||||||
|
|
||||||
|
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") |
||||||
|
|
||||||
|
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp |
||||||
|
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp |
||||||
|
|
||||||
|
disp_error = gt_disp - disp |
||||||
|
log['disp_error'] = wandb.Image( |
||||||
|
normalize_and_colormap(abs(disp_error)), |
||||||
|
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
log[f'gt_disp_vis'] = wandb.Image( |
||||||
|
normalize_and_colormap(gt_disp), |
||||||
|
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", |
||||||
|
) |
||||||
|
|
||||||
|
wandb.log(log) |
||||||
|
return pred_flow |
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml(file_path: str) -> namedtuple: |
||||||
|
"""Parse yaml configuration file and return the object in `namedtuple`.""" |
||||||
|
with open(file_path, "rb") as f: |
||||||
|
cfg: dict = yaml.safe_load(f) |
||||||
|
args = namedtuple("train_args", cfg.keys())(*cfg.values()) |
||||||
|
return args |
||||||
|
|
||||||
|
|
||||||
|
def format_time(elapse): |
||||||
|
elapse = int(elapse) |
||||||
|
hour = elapse // 3600 |
||||||
|
minute = elapse % 3600 // 60 |
||||||
|
seconds = elapse % 60 |
||||||
|
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds) |
||||||
|
|
||||||
|
|
||||||
|
def ensure_dir(path): |
||||||
|
if not os.path.exists(path): |
||||||
|
os.makedirs(path, exist_ok=True) |
||||||
|
|
||||||
|
|
||||||
|
def adjust_learning_rate(optimizer, epoch): |
||||||
|
|
||||||
|
warm_up = 0.02 |
||||||
|
const_range = 0.6 |
||||||
|
min_lr_rate = 0.05 |
||||||
|
|
||||||
|
if epoch <= args.n_total_epoch * warm_up: |
||||||
|
lr = (1 - min_lr_rate) * args.base_lr / ( |
||||||
|
args.n_total_epoch * warm_up |
||||||
|
) * epoch + min_lr_rate * args.base_lr |
||||||
|
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range: |
||||||
|
lr = args.base_lr |
||||||
|
else: |
||||||
|
lr = (min_lr_rate - 1) * args.base_lr / ( |
||||||
|
(1 - const_range) * args.n_total_epoch |
||||||
|
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr |
||||||
|
|
||||||
|
for param_group in optimizer.param_groups: |
||||||
|
param_group['lr'] = lr |
||||||
|
|
||||||
|
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): |
||||||
|
''' |
||||||
|
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W) |
||||||
|
flow_preds[0]: (B, 2, H, W) |
||||||
|
flow_gt: (B, 2, H, W) |
||||||
|
''' |
||||||
|
if test: |
||||||
|
# print('sequence loss') |
||||||
|
if valid.shape != (2, 480, 640): |
||||||
|
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2) |
||||||
|
# print(valid.shape) |
||||||
|
#valid = torch.stack([valid, valid]) |
||||||
|
# print(valid.shape) |
||||||
|
if valid.shape != (2, 480, 640): |
||||||
|
valid = valid.transpose(0,1) |
||||||
|
# print(valid.shape) |
||||||
|
# print(valid.shape) |
||||||
|
# print(flow_preds[0].shape) |
||||||
|
# print(flow_gt.shape) |
||||||
|
n_predictions = len(flow_preds) |
||||||
|
flow_loss = 0.0 |
||||||
|
|
||||||
|
# TEST |
||||||
|
flow_gt = torch.squeeze(flow_gt, dim=-1) |
||||||
|
|
||||||
|
for i in range(n_predictions): |
||||||
|
i_weight = gamma ** (n_predictions - i - 1) |
||||||
|
i_loss = torch.abs(flow_preds[i] - flow_gt) |
||||||
|
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean() |
||||||
|
|
||||||
|
return flow_loss |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CREStereoLightning(LightningModule): |
||||||
|
def __init__(self, args): |
||||||
|
super().__init__() |
||||||
|
self.batch_size = args.batch_size |
||||||
|
self.model = Model( |
||||||
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False |
||||||
|
) |
||||||
|
|
||||||
|
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True): |
||||||
|
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right) |
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx): |
||||||
|
# loss = self(batch) |
||||||
|
left, right, gt_disp, valid_mask = batch |
||||||
|
left = torch.Tensor(left).to(self.device) |
||||||
|
right = torch.Tensor(right).to(self.device) |
||||||
|
left = left |
||||||
|
right = right |
||||||
|
flow_predictions = self.forward(left, right) |
||||||
|
loss = sequence_loss( |
||||||
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
||||||
|
) |
||||||
|
self.log("train_loss", loss) |
||||||
|
return loss |
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx): |
||||||
|
left, right, gt_disp, valid_mask = batch |
||||||
|
left = torch.Tensor(left).to(self.device) |
||||||
|
right = torch.Tensor(right).to(self.device) |
||||||
|
print(left.shape) |
||||||
|
print(right.shape) |
||||||
|
flow_predictions = self.forward(left, right) |
||||||
|
val_loss = sequence_loss( |
||||||
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
||||||
|
) |
||||||
|
self.log("val_loss", val_loss) |
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx): |
||||||
|
left, right, gt_disp, valid_mask = batch |
||||||
|
# left, right, gt_disp, valid_mask = ( |
||||||
|
# batch["left"], |
||||||
|
# batch["right"], |
||||||
|
# batch["disparity"], |
||||||
|
# batch["mask"], |
||||||
|
# ) |
||||||
|
left = torch.Tensor(left).to(self.device) |
||||||
|
right = torch.Tensor(right).to(self.device) |
||||||
|
flow_predictions = self.forward(left, right) |
||||||
|
test_loss = sequence_loss( |
||||||
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
||||||
|
) |
||||||
|
self.log("test_loss", test_loss) |
||||||
|
|
||||||
|
def configure_optimizers(self): |
||||||
|
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999)) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
# train configuration |
||||||
|
args = parse_yaml("cfgs/train.yaml") |
||||||
|
# wandb.init(project="crestereo-lightning", entity="cpt-captain") |
||||||
|
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args) |
||||||
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' |
||||||
|
model = CREStereoLightning(args) |
||||||
|
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True) |
||||||
|
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True) |
||||||
|
print(len(dataset)) |
||||||
|
print(len(test_dataset)) |
||||||
|
wandb_logger = WandbLogger(project="crestereo-lightning") |
||||||
|
wandb.config.update(args._asdict()) |
||||||
|
|
||||||
|
trainer = Trainer( |
||||||
|
max_epochs=args.n_total_epoch, |
||||||
|
accelerator='gpu', |
||||||
|
devices=2, |
||||||
|
# auto_scale_batch_size='binsearch', |
||||||
|
# strategy='ddp', |
||||||
|
deterministic=True, |
||||||
|
check_val_every_n_epoch=1, |
||||||
|
limit_val_batches=24, |
||||||
|
limit_test_batches=24, |
||||||
|
logger=wandb_logger, |
||||||
|
default_root_dir=args.log_dir_lightning, |
||||||
|
) |
||||||
|
# trainer.tune(model) |
||||||
|
trainer.fit(model, dataset, test_dataset) |
Loading…
Reference in new issue