@@ -5,10 +5,8 @@ import logging

from collections import namedtuple

import yaml

# from tensorboardX import SummaryWriter

from nets import Model

# from dataset import CREStereoDataset
from dataset import BlenderDataset, CREStereoDataset, CTDDataset

import torch
@@ -18,8 +16,11 @@ import torch.optim as optim
from torch.utils.data import DataLoader

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy

seed_everything(42, workers=True)
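# workers=True asks Lightning to also seed DataLoader worker processes, so any
# random augmentation in the datasets stays reproducible across runs.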

@@ -39,11 +40,9 @@ def normalize_and_colormap(img):
    return ret


def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
    # wandb_logger.log_text('test')
    # return
def log_images(left, right, pred_disp, gt_disp):
    log = {}
    batch_idx = 1
    batch_idx = 0

    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

@@ -100,32 +99,13 @@ def ensure_dir(path):
    os.makedirs(path, exist_ok=True)


def adjust_learning_rate(optimizer, epoch):
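    # Piecewise schedule implemented below: linear warm-up from min_lr_rate * base_lr
    # to base_lr over the first 2% of n_total_epoch, constant base_lr until 60%, then
    # linear decay back down to min_lr_rate * base_lr. Note it reads the module-level
    # `args` rather than taking the config as a parameter.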
    warm_up = 0.02
    const_range = 0.6
    min_lr_rate = 0.05

    if epoch <= args.n_total_epoch * warm_up:
        lr = (1 - min_lr_rate) * args.base_lr / (
            args.n_total_epoch * warm_up
        ) * epoch + min_lr_rate * args.base_lr
    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
        lr = args.base_lr
    else:
        lr = (min_lr_rate - 1) * args.base_lr / (
            (1 - const_range) * args.n_total_epoch
        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    '''
    valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
    flow_preds[0]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    '''
    """
    if test:
        # print('sequence loss')
        if valid.shape != (2, 480, 640):
@@ -136,6 +116,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
        if valid.shape != (2, 480, 640):
            valid = valid.transpose(0,1)
        # print(valid.shape)
    """
    # print(valid.shape)
    # print(flow_preds[0].shape)
    # print(flow_gt.shape)
@@ -143,7 +124,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    flow_loss = 0.0

    # TEST
    flow_gt = torch.squeeze(flow_gt, dim=-1)
    # flow_gt = torch.squeeze(flow_gt, dim=-1)
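
    # Each refinement iteration is supervised; iteration i is weighted by
    # gamma ** (n_predictions - i - 1), so later (more refined) predictions
    # contribute exponentially more to the total loss.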
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
@@ -155,16 +136,88 @@
class CREStereoLightning(LightningModule):
    def __init__(self, args, logger):
    def __init__(self, args, logger, pattern_path, data_path):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.lr = args.base_lr
        print(f'lr = {self.lr}')
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = data_path
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)

    def train_dataloader(self):
        dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            use_lightning=True,
        )
        dataloader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )
        # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
        return dataloader
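
    # Note: val_dataloader and test_dataloader below both build the BlenderDataset
    # test split (test_set=True); only the Trainer's limit_val_batches /
    # limit_test_batches settings differ between validation and testing.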

    def val_dataloader(self):
        test_dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            test_set=True,
            use_lightning=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True
        )
        # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
        return test_dataloader

    def test_dataloader(self):
        # TODO change this to use IRL data?
        test_dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            test_set=True,
            use_lightning=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True
        )
        return test_dataloader

    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
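        # The config's pattern_attention flag is forwarded as the model's final
        # positional argument (named self_attend_right in the signature removed above).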
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
@@ -174,6 +227,10 @@ class CREStereoLightning(LightningModule):
        loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)
        return loss
@@ -186,22 +243,31 @@ class CREStereoLightning(LightningModule):
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        self.log("val_loss", val_loss)
        if batch_idx % 4 == 0:
        if batch_idx % 8 == 0:
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
        gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        self.log("test_loss", test_loss)
        print('test_batch_idx:', batch_idx)
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        print('len(self.train_dataloader)', len(self.train_dataloader()))
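        # When no explicit t_max is configured, the cosine annealing period falls back
        # to a value derived from the training dataloader; note len(self.train_dataloader())
        # already counts batches, so the extra division by batch_size shortens the period further.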
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size,
            ),
            'name': 'CosineAnnealingLRScheduler',
        }
        return [optimizer], [lr_scheduler]


if __name__ == "__main__":
@@ -209,61 +275,54 @@ if __name__ == "__main__":
    args = parse_yaml("cfgs/train.yaml")
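    # The YAML config is expected to provide at least: batch_size, base_lr, t_max,
    # pattern_attention, max_disp, mixed_precision, n_total_epoch, nr_gpus,
    # training_data_path and log_dir_lightning; these are the fields accessed via
    # args/config in this file.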
    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'

    wandb_logger = WandbLogger(project="crestereo-lightning")
    wandb.config.update(args._asdict())

    model = CREStereoLightning(args, wandb_logger)

    dataset = BlenderDataset(
        root=args.training_data_path,
        pattern_path=pattern_path,
        use_lightning=True,
    )
    test_dataset = BlenderDataset(
        root=args.training_data_path,
        pattern_path=pattern_path,
        test_set=True,
        use_lightning=True,
    )

    dataloader = DataLoader(
        dataset,
        args.batch_size,
        shuffle=True,
        num_workers=16,
        drop_last=True,
        persistent_workers=True,
        pin_memory=True,
    )
    # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
    test_dataloader = DataLoader(
        test_dataset,
        args.batch_size,
        shuffle=False,
        num_workers=16,
        drop_last=False,
        persistent_workers=True,
        pin_memory=True
    )
    run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='')
    run.config.update(args._asdict())
    config = wandb.config
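    # Reading hyperparameters back from wandb.config (rather than from args directly)
    # lets a wandb sweep override individual values; presumably that is the intent here.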
    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
    # wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
    # wandb_logger.experiment.config.update(args._asdict())

    model = CREStereoLightning(
        # args,
        config,
        wandb_logger,
        pattern_path,
        args.training_data_path,
        # lr=0.00017378008287493763, # found with auto_lr_find=True
    )
    # NOTE turn this down once it's working, this might use too much space
    # wandb_logger.watch(model, log_graph=False) #, log='all')

    trainer = Trainer(
        accelerator='gpu',
        devices=2,
        devices=args.nr_gpus,
        max_epochs=args.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=4,
                patience=16,
            ),
            LearningRateMonitor(),
            ModelCheckpoint(
                monitor="val_loss",
                mode="min",
                save_top_k=2,
                save_last=True,
            )
        ],
        accumulate_grad_batches=8,
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        accumulate_grad_batches=4,
        deterministic=True,
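        # deterministic=True pairs with the module-level seed_everything(42, workers=True)
        # call to keep runs reproducible across restarts.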
        check_val_every_n_epoch=1,
        limit_val_batches=24,
        limit_test_batches=24,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=args.log_dir_lightning,
    )

    trainer.fit(model, dataloader, test_dataloader)
    # trainer.tune(model)
    trainer.fit(model)
    trainer.validate()