|
|
|
@@ -94,6 +94,32 @@ def format_time(elapse):
|
|
|
|
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def outlier_fraction(estimate, target, mask=None, threshold=0): |
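# fraction of valid (masked) pixels whose absolute error |estimate - target| exceeds `threshold`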
|
|
|
|
def _process_inputs(estimate, target, mask): |
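# check that estimate/target (and mask, if given) agree in shape and normalize the
# mask to a boolean array; without a mask, every pixel counts as valid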
|
|
|
|
if estimate.shape != target.shape: |
|
|
|
|
raise Exception(f'estimate and target must have the same shape (expected {estimate.shape} == {target.shape})')
|
|
|
|
if mask is None: |
|
|
|
|
mask = np.ones(estimate.shape, dtype=bool)
|
|
|
|
else: |
|
|
|
|
mask = mask != 0 |
|
|
|
|
if estimate.shape != mask.shape: |
|
|
|
|
raise Exception(f'estimate and mask must have the same shape (expected {estimate.shape} == {mask.shape})')
|
|
|
|
return estimate, target, mask |
|
|
|
|
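# the predictions come as (N, 2, H, W) flow fields; only channel 0 (the horizontal
# component, which corresponds to the disparity for rectified stereo) is evaluated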
estimate = torch.squeeze(estimate[:, 0, :, :]) |
|
|
|
|
target = torch.squeeze(target[:, 0, :, :]) |
|
|
|
|
|
|
|
|
|
estimate, target, mask = _process_inputs(estimate, target, mask) |
|
|
|
|
|
|
|
|
|
mask = mask.cpu().detach().numpy() |
|
|
|
|
estimate = estimate.cpu().detach().numpy() |
|
|
|
|
target = target.cpu().detach().numpy() |
|
|
|
|
|
|
|
|
|
diff = np.abs(estimate[mask] - target[mask]) |
|
|
|
|
m = (diff > threshold).sum() / mask.sum() |
|
|
|
|
return m |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_dir(path): |
|
|
|
|
if not os.path.exists(path): |
|
|
|
|
os.makedirs(path, exist_ok=True) |
|
|
|
@@ -134,12 +160,12 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
|
|
|
return flow_loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CREStereoLightning(LightningModule): |
|
|
|
|
def __init__(self, args, logger, pattern_path, data_path): |
|
|
|
|
def __init__(self, args, logger=None, pattern_path='', data_path=''): |
|
|
|
|
super().__init__() |
|
|
|
|
self.batch_size = args.batch_size |
|
|
|
|
self.wandb_logger = logger |
|
|
|
|
self.data_type = 'blender' if 'blender' in data_path else 'ctd' |
|
|
|
|
self.lr = args.base_lr |
|
|
|
|
print(f'lr = {self.lr}') |
|
|
|
|
self.T_max = args.t_max if args.t_max else None |
|
|
|
@@ -149,13 +175,25 @@ class CREStereoLightning(LightningModule):
|
|
|
|
self.model = Model( |
|
|
|
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False |
|
|
|
|
) |
|
|
|
|
# kept as attributes so adjust_learning_rate can access them more easily
|
|
|
|
self.n_total_epoch = args.n_total_epoch |
|
|
|
|
self.base_lr = args.base_lr |
|
|
|
|
|
|
|
|
|
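# manual optimization: Lightning no longer calls optimizer.step()/zero_grad()
# automatically, and Trainer features such as accumulate_grad_batches are
# unavailable (see the Trainer setup below)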
self.automatic_optimization = False |
|
|
|
|
|
|
|
|
|
def train_dataloader(self): |
|
|
|
|
dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
if self.data_type == 'blender': |
|
|
|
|
dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
elif self.data_type == 'ctd': |
|
|
|
|
dataset = CTDDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
dataloader = DataLoader( |
|
|
|
|
dataset, |
|
|
|
|
self.batch_size, |
|
|
|
@@ -169,12 +207,20 @@ class CREStereoLightning(LightningModule):
|
|
|
|
return dataloader |
|
|
|
|
|
|
|
|
|
def val_dataloader(self): |
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
if self.data_type == 'blender': |
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
elif self.data_type == 'ctd': |
|
|
|
|
test_dataset = CTDDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
test_dataloader = DataLoader( |
|
|
|
|
test_dataset, |
|
|
|
@@ -190,12 +236,20 @@ class CREStereoLightning(LightningModule):
|
|
|
|
|
|
|
|
|
def test_dataloader(self): |
|
|
|
|
# TODO change this to use IRL data? |
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
if self.data_type == 'blender': |
|
|
|
|
test_dataset = BlenderDataset(
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
elif self.data_type == 'ctd': |
|
|
|
|
test_dataset = CTDDataset(
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
test_dataloader = DataLoader( |
|
|
|
|
test_dataset, |
|
|
|
@@ -232,6 +286,10 @@ class CREStereoLightning(LightningModule):
|
|
|
|
image_log['key'] = 'debug_train' |
|
|
|
|
self.wandb_logger.log_image(**image_log) |
|
|
|
|
self.log("train_loss", loss) |
|
|
|
|
|
|
|
|
|
# update the learning rate at the end of every epoch
|
|
|
|
if self.trainer.is_last_batch: |
|
|
|
|
self.adjust_learning_rate() |
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx): |
|
|
|
@@ -243,6 +301,11 @@ class CREStereoLightning(LightningModule):
|
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
|
) |
|
|
|
|
self.log("val_loss", val_loss) |
|
|
|
|
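# bad-pixel style metrics: fraction of valid pixels whose error exceeds each threshold (in pixels)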
of = {} |
|
|
|
|
for threshold in [0.1, 0.5, 1, 2, 5]: |
|
|
|
|
of[threshold] = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold) |
|
|
|
|
self.log_dict({f"outlier_fraction/{threshold}": value for threshold, value in of.items()})
|
|
|
|
# print(', '.join(f'of{thr}={val}' for thr, val in of.items())) |
|
|
|
|
if batch_idx % 8 == 0: |
|
|
|
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) |
|
|
|
|
|
|
|
|
@@ -257,6 +320,9 @@ class CREStereoLightning(LightningModule):
|
|
|
|
self.log("test_loss", test_loss) |
|
|
|
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) |
|
|
|
|
|
|
|
|
|
def predict_step(self, batch, batch_idx, dataloader_idx=0): |
|
|
|
|
return self(batch) |
|
|
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
|
|
|
optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999)) |
|
|
|
|
print('len(self.train_dataloader)', len(self.train_dataloader())) |
|
|
|
@@ -269,34 +335,69 @@ class CREStereoLightning(LightningModule):
|
|
|
|
} |
|
|
|
|
return [optimizer], [lr_scheduler] |
|
|
|
|
|
|
|
|
|
def adjust_learning_rate(self): |
|
|
|
|
optimizer = self.optimizers().optimizer |
|
|
|
|
epoch = self.trainer.current_epoch+1 |
|
|
|
|
|
|
|
|
|
warm_up = 0.02 |
|
|
|
|
const_range = 0.6 |
|
|
|
|
min_lr_rate = 0.05 |
|
|
|
|
|
|
|
|
|
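# three-phase schedule: linear warm-up from min_lr_rate * base_lr to base_lr
# over the first 2% of epochs, constant base_lr until 60% of the epochs,
# then linear decay back down to min_lr_rate * base_lr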
if epoch <= self.n_total_epoch * warm_up: |
|
|
|
|
lr = (1 - min_lr_rate) * self.base_lr / ( |
|
|
|
|
self.n_total_epoch * warm_up |
|
|
|
|
) * epoch + min_lr_rate * self.base_lr |
|
|
|
|
elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range: |
|
|
|
|
lr = self.base_lr |
|
|
|
|
else: |
|
|
|
|
lr = (min_lr_rate - 1) * self.base_lr / ( |
|
|
|
|
(1 - const_range) * self.n_total_epoch |
|
|
|
|
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr |
|
|
|
|
|
|
|
|
|
for param_group in optimizer.param_groups: |
|
|
|
|
param_group['lr'] = lr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True) |
|
|
|
|
# train configuration |
|
|
|
|
args = parse_yaml("cfgs/train.yaml") |
|
|
|
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' |
|
|
|
|
|
|
|
|
|
run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='') |
|
|
|
|
run.config.update(args._asdict()) |
|
|
|
|
wandb_logger.experiment.config.update(args._asdict()) |
|
|
|
|
config = wandb.config |
|
|
|
|
wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True) |
|
|
|
|
# wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all') |
|
|
|
|
# wandb_logger.experiment.config.update(args._asdict()) |
|
|
|
|
if 'blender' in config.training_data_path: |
|
|
|
|
# this was used for our blender renders |
|
|
|
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' |
|
|
|
|
if 'ctd' in config.training_data_path: |
|
|
|
|
# this one should be the pattern used for CTD
|
|
|
|
pattern_path = '/home/nils/kinect_from_settings.png' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
devices = min(config.nr_gpus, torch.cuda.device_count()) |
|
|
|
|
if devices != config.nr_gpus: |
|
|
|
|
print(f'Using fewer devices than requested! ({devices} / {config.nr_gpus})')
|
|
|
|
|
|
|
|
|
model = CREStereoLightning( |
|
|
|
|
# args, |
|
|
|
|
config, |
|
|
|
|
wandb_logger, |
|
|
|
|
pattern_path, |
|
|
|
|
args.training_data_path, |
|
|
|
|
config.training_data_path, |
|
|
|
|
# lr=0.00017378008287493763, # found with auto_lr_find=True |
|
|
|
|
) |
|
|
|
|
# NOTE: turn this down once it's working; it might use too much space
|
|
|
|
# wandb_logger.watch(model, log_graph=False) #, log='all') |
|
|
|
|
|
|
|
|
|
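# keep the two checkpoints with the lowest validation loss, plus the most recent one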
model_checkpoint = ModelCheckpoint( |
|
|
|
|
monitor="val_loss", |
|
|
|
|
mode="min", |
|
|
|
|
save_top_k=2, |
|
|
|
|
save_last=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
trainer = Trainer( |
|
|
|
|
accelerator='gpu', |
|
|
|
|
devices=args.nr_gpus, |
|
|
|
|
max_epochs=args.n_total_epoch, |
|
|
|
|
devices=devices, |
|
|
|
|
max_epochs=config.n_total_epoch, |
|
|
|
|
callbacks=[ |
|
|
|
|
EarlyStopping( |
|
|
|
|
monitor="val_loss", |
|
|
|
@@ -304,25 +405,20 @@ if __name__ == "__main__":
|
|
|
|
patience=16, |
|
|
|
|
), |
|
|
|
|
LearningRateMonitor(), |
|
|
|
|
ModelCheckpoint( |
|
|
|
|
monitor="val_loss", |
|
|
|
|
mode="min", |
|
|
|
|
save_top_k=2, |
|
|
|
|
save_last=True, |
|
|
|
|
) |
|
|
|
|
model_checkpoint, |
|
|
|
|
], |
|
|
|
|
strategy=DDPSpawnStrategy(find_unused_parameters=False), |
|
|
|
|
# auto_scale_batch_size='binsearch', |
|
|
|
|
# auto_lr_find=True, |
|
|
|
|
accumulate_grad_batches=4, |
|
|
|
|
# accumulate_grad_batches=4,  # had to be disabled for manual optimization
|
|
|
|
deterministic=True, |
|
|
|
|
check_val_every_n_epoch=1, |
|
|
|
|
limit_val_batches=64, |
|
|
|
|
limit_test_batches=256, |
|
|
|
|
logger=wandb_logger, |
|
|
|
|
default_root_dir=args.log_dir_lightning, |
|
|
|
|
default_root_dir=config.log_dir_lightning, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# trainer.tune(model) |
|
|
|
|
trainer.fit(model) |
|
|
|
|
trainer.validate() |
|
|
|
|
# trainer.validate(ckpt_path=model_checkpoint.best_model_path)
|
|
|
|