|
|
|
@ -31,8 +31,22 @@ import numpy as np |
|
|
|
|
import cv2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_and_colormap(img): |
|
|
|
|
def normalize_and_colormap(img, reduce_dynamic_range=False): |
|
|
|
|
# print(img.min()) |
|
|
|
|
# print(img.max()) |
|
|
|
|
# print(img.mean()) |
|
|
|
|
ret = (img - img.min()) / (img.max() - img.min()) * 255.0 |
|
|
|
|
# print(ret.min()) |
|
|
|
|
# print(ret.max()) |
|
|
|
|
# print(ret.mean()) |
|
|
|
|
|
|
|
|
|
# FIXME do I need to compress dynamic range somehow or something? |
|
|
|
|
if reduce_dynamic_range and img.max() > 5*img.mean(): |
|
|
|
|
ret = (img - img.min()) / (5*img.mean() - img.min()) * 255.0 |
|
|
|
|
# print(ret.min()) |
|
|
|
|
# print(ret.max()) |
|
|
|
|
# print(ret.mean()) |
|
|
|
|
|
|
|
|
|
if isinstance(ret, torch.Tensor): |
|
|
|
|
ret = ret.cpu().detach().numpy() |
|
|
|
|
ret = ret.astype("uint8") |
|
|
|
@ -47,34 +61,71 @@ def log_images(left, right, pred_disp, gt_disp): |
|
|
|
|
if isinstance(pred_disp, list): |
|
|
|
|
pred_disp = pred_disp[-1] |
|
|
|
|
|
|
|
|
|
pred_disp = torch.squeeze(pred_disp[:, 0, :, :]) |
|
|
|
|
gt_disp = torch.squeeze(gt_disp[:, 0, :, :]) |
|
|
|
|
left = torch.squeeze(left[:, 0, :, :]) |
|
|
|
|
right = torch.squeeze(right[:, 0, :, :]) |
|
|
|
|
|
|
|
|
|
pred_disp = torch.squeeze(pred_disp[:, 0, :, :]) |
|
|
|
|
gt_disp = torch.squeeze(gt_disp[:, 0, :, :]) |
|
|
|
|
|
|
|
|
|
# print('gt_disp debug') |
|
|
|
|
# print(gt_disp.shape) |
|
|
|
|
|
|
|
|
|
singular_batch = False |
|
|
|
|
if len(left.shape) == 2: |
|
|
|
|
singular_batch = True |
|
|
|
|
print('batch_size seems to be 1') |
|
|
|
|
input_left = left.cpu().detach().numpy() |
|
|
|
|
input_right = right.cpu().detach().numpy() |
|
|
|
|
else: |
|
|
|
|
input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) |
|
|
|
|
input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) |
|
|
|
|
|
|
|
|
|
disp = pred_disp |
|
|
|
|
disp_error = gt_disp - disp |
|
|
|
|
|
|
|
|
|
input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) |
|
|
|
|
input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) |
|
|
|
|
|
|
|
|
|
wandb_log = dict( |
|
|
|
|
key='samples', |
|
|
|
|
images=[ |
|
|
|
|
normalize_and_colormap(pred_disp[batch_idx]), |
|
|
|
|
normalize_and_colormap(abs(disp_error[batch_idx])), |
|
|
|
|
normalize_and_colormap(gt_disp[batch_idx]), |
|
|
|
|
input_left, |
|
|
|
|
input_right, |
|
|
|
|
], |
|
|
|
|
caption=[ |
|
|
|
|
f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", |
|
|
|
|
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}", |
|
|
|
|
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}", |
|
|
|
|
"Input Left", |
|
|
|
|
"Input Right" |
|
|
|
|
], |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# print('gt_disp debug normalize') |
|
|
|
|
# print(gt_disp.max(), gt_disp.min()) |
|
|
|
|
# print(gt_disp.dtype) |
|
|
|
|
|
|
|
|
|
if singular_batch: |
|
|
|
|
wandb_log = dict( |
|
|
|
|
key='samples', |
|
|
|
|
images=[ |
|
|
|
|
pred_disp, |
|
|
|
|
normalize_and_colormap(pred_disp), |
|
|
|
|
normalize_and_colormap(abs(disp_error), reduce_dynamic_range=True), |
|
|
|
|
normalize_and_colormap(gt_disp, reduce_dynamic_range=True), |
|
|
|
|
input_left, |
|
|
|
|
input_right, |
|
|
|
|
], |
|
|
|
|
caption=[ |
|
|
|
|
f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
f"Disparity (vis) \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", |
|
|
|
|
f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", |
|
|
|
|
"Input Left", |
|
|
|
|
"Input Right" |
|
|
|
|
], |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
wandb_log = dict( |
|
|
|
|
key='samples', |
|
|
|
|
images=[ |
|
|
|
|
# pred_disp.cpu().detach().numpy().transpose(1,2,0), |
|
|
|
|
normalize_and_colormap(pred_disp[batch_idx]), |
|
|
|
|
normalize_and_colormap(abs(disp_error[batch_idx])), |
|
|
|
|
normalize_and_colormap(gt_disp[batch_idx]), |
|
|
|
|
input_left, |
|
|
|
|
input_right, |
|
|
|
|
], |
|
|
|
|
caption=[ |
|
|
|
|
# f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", |
|
|
|
|
f"Disparity (vis)\n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", |
|
|
|
|
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}", |
|
|
|
|
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}", |
|
|
|
|
"Input Left", |
|
|
|
|
"Input Right" |
|
|
|
|
], |
|
|
|
|
) |
|
|
|
|
return wandb_log |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -104,7 +155,10 @@ def outlier_fraction(estimate, target, mask=None, threshold=0): |
|
|
|
|
else: |
|
|
|
|
mask = mask != 0 |
|
|
|
|
if estimate.shape != mask.shape: |
|
|
|
|
raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})') |
|
|
|
|
if len(mask.shape) == 3: |
|
|
|
|
mask = mask[0] |
|
|
|
|
if estimate.shape != mask.shape: |
|
|
|
|
raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})') |
|
|
|
|
return estimate, target, mask |
|
|
|
|
estimate = torch.squeeze(estimate[:, 0, :, :]) |
|
|
|
|
target = torch.squeeze(target[:, 0, :, :]) |
|
|
|
@ -131,27 +185,9 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): |
|
|
|
|
flow_preds[0]: (B, 2, H, W) |
|
|
|
|
flow_gt: (B, 2, H, W) |
|
|
|
|
''' |
|
|
|
|
""" |
|
|
|
|
if test: |
|
|
|
|
# print('sequence loss') |
|
|
|
|
if valid.shape != (2, 480, 640): |
|
|
|
|
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2) |
|
|
|
|
# print(valid.shape) |
|
|
|
|
#valid = torch.stack([valid, valid]) |
|
|
|
|
# print(valid.shape) |
|
|
|
|
if valid.shape != (2, 480, 640): |
|
|
|
|
valid = valid.transpose(0,1) |
|
|
|
|
# print(valid.shape) |
|
|
|
|
""" |
|
|
|
|
# print(valid.shape) |
|
|
|
|
# print(flow_preds[0].shape) |
|
|
|
|
# print(flow_gt.shape) |
|
|
|
|
n_predictions = len(flow_preds) |
|
|
|
|
flow_loss = 0.0 |
|
|
|
|
|
|
|
|
|
# TEST |
|
|
|
|
# flow_gt = torch.squeeze(flow_gt, dim=-1) |
|
|
|
|
|
|
|
|
|
for i in range(n_predictions): |
|
|
|
|
i_weight = gamma ** (n_predictions - i - 1) |
|
|
|
|
i_loss = torch.abs(flow_preds[i] - flow_gt) |
|
|
|
@ -161,38 +197,50 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CREStereoLightning(LightningModule): |
|
|
|
|
def __init__(self, args, logger=None, pattern_path='', data_path=''): |
|
|
|
|
def __init__(self, args, logger=None, pattern_path=''): |
|
|
|
|
super().__init__() |
|
|
|
|
self.batch_size = args.batch_size |
|
|
|
|
self.wandb_logger = logger |
|
|
|
|
self.data_type = 'blender' if 'blender' in data_path else 'ctd' |
|
|
|
|
self.imwidth = args.image_width |
|
|
|
|
self.imheight = args.image_height |
|
|
|
|
self.data_type = 'blender' if 'blender' in args.training_data_path else 'ctd' |
|
|
|
|
self.eval_type = 'kinect' if 'kinect' in args.test_data_path else args.training_data_path |
|
|
|
|
self.lr = args.base_lr |
|
|
|
|
print(f'lr = {self.lr}') |
|
|
|
|
self.T_max = args.t_max if args.t_max else None |
|
|
|
|
self.pattern_attention = args.pattern_attention |
|
|
|
|
self.pattern_path = pattern_path |
|
|
|
|
self.data_path = data_path |
|
|
|
|
self.data_path = args.training_data_path |
|
|
|
|
self.test_data_path = args.test_data_path |
|
|
|
|
self.data_limit = args.data_limit # between 0 and 1. |
|
|
|
|
self.model = Model( |
|
|
|
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False |
|
|
|
|
) |
|
|
|
|
# so I can access it in adjust learn rate more easily |
|
|
|
|
|
|
|
|
|
if args.scheduler == 'default': |
|
|
|
|
self.automatic_optimization = False |
|
|
|
|
# so I can access it in adjust learn rate more easily |
|
|
|
|
|
|
|
|
|
self.n_total_epoch = args.n_total_epoch |
|
|
|
|
self.base_lr = args.base_lr |
|
|
|
|
|
|
|
|
|
self.automatic_optimization = False |
|
|
|
|
|
|
|
|
|
def train_dataloader(self): |
|
|
|
|
# we never train on kinect |
|
|
|
|
is_kinect = False |
|
|
|
|
if self.data_type == 'blender': |
|
|
|
|
dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
use_lightning=True, |
|
|
|
|
data_type='kinect' if is_kinect else 'blender', |
|
|
|
|
disp_avail=not is_kinect, |
|
|
|
|
data_limit = self.data_limit, |
|
|
|
|
) |
|
|
|
|
elif self.data_type == 'ctd': |
|
|
|
|
dataset = CTDDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
use_lightning=True, |
|
|
|
|
data_limit = self.data_limit, |
|
|
|
|
) |
|
|
|
|
dataloader = DataLoader( |
|
|
|
|
dataset, |
|
|
|
@ -203,16 +251,20 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
persistent_workers=True, |
|
|
|
|
pin_memory=True, |
|
|
|
|
) |
|
|
|
|
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) |
|
|
|
|
return dataloader |
|
|
|
|
|
|
|
|
|
def val_dataloader(self): |
|
|
|
|
# we also don't want to validate on kinect data |
|
|
|
|
is_kinect = False |
|
|
|
|
if self.data_type == 'blender': |
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
data_type='kinect' if is_kinect else 'blender', |
|
|
|
|
disp_avail=not is_kinect, |
|
|
|
|
data_limit = self.data_limit, |
|
|
|
|
) |
|
|
|
|
elif self.data_type == 'ctd': |
|
|
|
|
test_dataset = CTDDataset( |
|
|
|
@ -220,6 +272,7 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
data_limit = self.data_limit, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
test_dataloader = DataLoader( |
|
|
|
@ -231,29 +284,35 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
persistent_workers=True, |
|
|
|
|
pin_memory=True |
|
|
|
|
) |
|
|
|
|
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) |
|
|
|
|
return test_dataloader |
|
|
|
|
|
|
|
|
|
def test_dataloader(self): |
|
|
|
|
# TODO change this to use IRL data? |
|
|
|
|
is_kinect = self.eval_type == 'kinect' |
|
|
|
|
if self.data_type == 'blender': |
|
|
|
|
test_dataset = CTDDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
root=self.test_data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
split=0. if is_kinect else 0.9, # if we test on kinect data, use all available samples for test set |
|
|
|
|
use_lightning=True, |
|
|
|
|
augment=False, |
|
|
|
|
disp_avail=not is_kinect, |
|
|
|
|
data_type='kinect' if is_kinect else 'blender', |
|
|
|
|
data_limit = self.data_limit, |
|
|
|
|
) |
|
|
|
|
elif self.data_type == 'ctd': |
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
root=self.data_path, |
|
|
|
|
test_dataset = CTDDataset( |
|
|
|
|
root=self.test_data_path, |
|
|
|
|
pattern_path=self.pattern_path, |
|
|
|
|
test_set=True, |
|
|
|
|
use_lightning=True, |
|
|
|
|
augment=False, |
|
|
|
|
data_limit = self.data_limit, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
test_dataloader = DataLoader( |
|
|
|
|
test_dataset, |
|
|
|
|
self.batch_size, |
|
|
|
|
1 if is_kinect else self.batch_size, |
|
|
|
|
shuffle=False, |
|
|
|
|
num_workers=4, |
|
|
|
|
drop_last=False, |
|
|
|
@ -307,7 +366,8 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
self.log("outlier_fraction", of) |
|
|
|
|
# print(', '.join(f'of{thr}={val}' for thr, val in of.items())) |
|
|
|
|
if batch_idx % 8 == 0: |
|
|
|
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) |
|
|
|
|
images = log_images(left, right, flow_predictions, gt_disp) |
|
|
|
|
self.wandb_logger.log_image(**images) |
|
|
|
|
|
|
|
|
|
def test_step(self, batch, batch_idx): |
|
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
@ -318,20 +378,28 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
|
) |
|
|
|
|
self.log("test_loss", test_loss) |
|
|
|
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) |
|
|
|
|
of = {} |
|
|
|
|
for threshold in [0.1, 0.5, 1, 2, 5]: |
|
|
|
|
of[str(threshold)] = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold) |
|
|
|
|
self.log("outlier_fraction", of) |
|
|
|
|
images = log_images(left, right, flow_predictions, gt_disp) |
|
|
|
|
images['images'].append(gt_disp) |
|
|
|
|
images['caption'].append('GT Disp') |
|
|
|
|
self.wandb_logger.log_image(**images) |
|
|
|
|
|
|
|
|
|
def predict_step(self, batch, batch_idx, dataloader_idx=0): |
|
|
|
|
return self(batch) |
|
|
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
|
|
|
optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999)) |
|
|
|
|
print('len(self.train_dataloader)', len(self.train_dataloader())) |
|
|
|
|
if not self.automatic_optimization: |
|
|
|
|
return optimizer |
|
|
|
|
lr_scheduler = { |
|
|
|
|
'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR( |
|
|
|
|
optimizer, |
|
|
|
|
T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size, |
|
|
|
|
), |
|
|
|
|
'name': 'CosineAnnealingLRScheduler', |
|
|
|
|
'name': 'LR Scheduler', |
|
|
|
|
} |
|
|
|
|
return [optimizer], [lr_scheduler] |
|
|
|
|
|
|
|
|
@ -356,18 +424,21 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
|
|
|
|
|
for param_group in optimizer.param_groups: |
|
|
|
|
param_group['lr'] = lr |
|
|
|
|
self.log('train/lr', lr) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
# wandb.init(project='crestereo-lightning') |
|
|
|
|
wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True) |
|
|
|
|
# train configuration |
|
|
|
|
args = parse_yaml("cfgs/train.yaml") |
|
|
|
|
wandb_logger.experiment.config.update(args._asdict()) |
|
|
|
|
config = wandb.config |
|
|
|
|
data_limit = config.data_limit |
|
|
|
|
if 'blender' in config.training_data_path: |
|
|
|
|
# this was used for our blender renders |
|
|
|
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' |
|
|
|
|
if 'ctd' in config.training_data_path: |
|
|
|
|
elif 'ctd' in config.training_data_path: |
|
|
|
|
# this one is used (i hope) for ctd |
|
|
|
|
pattern_path = '/home/nils/kinect_from_settings.png' |
|
|
|
|
|
|
|
|
@ -381,7 +452,6 @@ if __name__ == "__main__": |
|
|
|
|
config, |
|
|
|
|
wandb_logger, |
|
|
|
|
pattern_path, |
|
|
|
|
config.training_data_path, |
|
|
|
|
# lr=0.00017378008287493763, # found with auto_lr_find=True |
|
|
|
|
) |
|
|
|
|
# NOTE turn this down once it's working, this might use too much space |
|
|
|
@ -394,31 +464,59 @@ if __name__ == "__main__": |
|
|
|
|
save_last=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
trainer = Trainer( |
|
|
|
|
accelerator='gpu', |
|
|
|
|
devices=devices, |
|
|
|
|
max_epochs=config.n_total_epoch, |
|
|
|
|
callbacks=[ |
|
|
|
|
EarlyStopping( |
|
|
|
|
monitor="val_loss", |
|
|
|
|
mode="min", |
|
|
|
|
patience=16, |
|
|
|
|
), |
|
|
|
|
LearningRateMonitor(), |
|
|
|
|
model_checkpoint, |
|
|
|
|
], |
|
|
|
|
strategy=DDPSpawnStrategy(find_unused_parameters=False), |
|
|
|
|
# auto_scale_batch_size='binsearch', |
|
|
|
|
# auto_lr_find=True, |
|
|
|
|
# accumulate_grad_batches=4, # needed to disable for manual optimization |
|
|
|
|
deterministic=True, |
|
|
|
|
check_val_every_n_epoch=1, |
|
|
|
|
limit_val_batches=64, |
|
|
|
|
limit_test_batches=256, |
|
|
|
|
logger=wandb_logger, |
|
|
|
|
default_root_dir=config.log_dir_lightning, |
|
|
|
|
) |
|
|
|
|
if config.scheduler == 'default': |
|
|
|
|
trainer = Trainer( |
|
|
|
|
accelerator='gpu', |
|
|
|
|
devices=devices, |
|
|
|
|
max_epochs=config.n_total_epoch, |
|
|
|
|
callbacks=[ |
|
|
|
|
EarlyStopping( |
|
|
|
|
monitor="val_loss", |
|
|
|
|
mode="min", |
|
|
|
|
patience=8, |
|
|
|
|
), |
|
|
|
|
LearningRateMonitor(), |
|
|
|
|
model_checkpoint, |
|
|
|
|
], |
|
|
|
|
strategy=DDPSpawnStrategy(find_unused_parameters=False), |
|
|
|
|
# auto_scale_batch_size='binsearch', |
|
|
|
|
# auto_lr_find=True, |
|
|
|
|
# accumulate_grad_batches=4, # needed to disable for manual optimization |
|
|
|
|
deterministic=True, |
|
|
|
|
check_val_every_n_epoch=1, |
|
|
|
|
limit_val_batches=64, |
|
|
|
|
limit_test_batches=256, |
|
|
|
|
logger=wandb_logger, |
|
|
|
|
default_root_dir=config.log_dir_lightning, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
trainer = Trainer( |
|
|
|
|
accelerator='gpu', |
|
|
|
|
devices=devices, |
|
|
|
|
max_epochs=config.n_total_epoch, |
|
|
|
|
callbacks=[ |
|
|
|
|
EarlyStopping( |
|
|
|
|
monitor="val_loss", |
|
|
|
|
mode="min", |
|
|
|
|
patience=8, |
|
|
|
|
), |
|
|
|
|
LearningRateMonitor(), |
|
|
|
|
model_checkpoint, |
|
|
|
|
], |
|
|
|
|
strategy=DDPSpawnStrategy(find_unused_parameters=False), |
|
|
|
|
# auto_scale_batch_size='binsearch', |
|
|
|
|
# auto_lr_find=True, |
|
|
|
|
accumulate_grad_batches=4, # needed to disable for manual optimization |
|
|
|
|
deterministic=True, |
|
|
|
|
check_val_every_n_epoch=1, |
|
|
|
|
limit_val_batches=64, |
|
|
|
|
limit_test_batches=256, |
|
|
|
|
logger=wandb_logger, |
|
|
|
|
default_root_dir=config.log_dir_lightning, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# trainer.tune(model) |
|
|
|
|
trainer.fit(model) |
|
|
|
|
# trainer.validate(chkpt_path=model_checkpoint.best_model_path) |
|
|
|
|
trainer.test(model) |
|
|
|
|