fix lightning, prepare sweeps
This commit is contained in:
parent
d8169e01bc
commit
37c537ca31
@ -1,6 +1,8 @@
|
|||||||
seed: 0
|
seed: 0
|
||||||
mixed_precision: false
|
mixed_precision: false
|
||||||
base_lr: 4.0e-4
|
# base_lr: 4.0e-4
|
||||||
|
base_lr: 0.001
|
||||||
|
t_max: 161
|
||||||
|
|
||||||
nr_gpus: 3
|
nr_gpus: 3
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
@ -16,7 +18,7 @@ max_disp: 256
|
|||||||
image_width: 640
|
image_width: 640
|
||||||
image_height: 480
|
image_height: 480
|
||||||
# training_data_path: "./stereo_trainset/crestereo"
|
# training_data_path: "./stereo_trainset/crestereo"
|
||||||
pattern_attention: true
|
pattern_attention: false
|
||||||
dataset: "blender"
|
dataset: "blender"
|
||||||
# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
|
# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
|
||||||
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
|
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
|
||||||
|
@ -384,7 +384,7 @@ class BlenderDataset(CTDDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self.use_lightning:
|
if not self.use_lightning:
|
||||||
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
|
# right_img = right_img.transpose((2, 0, 1)).astype("uint8")
|
||||||
return {
|
return {
|
||||||
"left": left_img,
|
"left": left_img,
|
||||||
"right": right_img,
|
"right": right_img,
|
||||||
@ -408,7 +408,7 @@ class BlenderDataset(CTDDataset):
|
|||||||
# return disp.astype(np.float32) / 32
|
# return disp.astype(np.float32) / 32
|
||||||
# FIXME temporarily increase disparity until new data with better depth values is generated
|
# FIXME temporarily increase disparity until new data with better depth values is generated
|
||||||
# higher values seem to speedup convergence, but introduce much stronger artifacting
|
# higher values seem to speedup convergence, but introduce much stronger artifacting
|
||||||
# mystery_factor = 150
|
mystery_factor = 150
|
||||||
mystery_factor = 1
|
# mystery_factor = 1
|
||||||
disp = (baseline * fl * mystery_factor) / depth
|
disp = (baseline * fl * mystery_factor) / depth
|
||||||
return disp.astype(np.float32)
|
return disp.astype(np.float32)
|
||||||
|
@ -38,10 +38,10 @@ class CREStereo(nn.Module):
|
|||||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||||
|
|
||||||
# # NOTE Position_encoding as workaround for TensorRt
|
# # NOTE Position_encoding as workaround for TensorRt
|
||||||
image1_shape = [1, 2, 480, 640]
|
# image1_shape = [1, 2, 480, 640]
|
||||||
self.pos_encoding_fn_small = PositionEncodingSine(
|
# self.pos_encoding_fn_small = PositionEncodingSine(
|
||||||
d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||||
)
|
# )
|
||||||
|
|
||||||
# loftr
|
# loftr
|
||||||
self.self_att_fn = LocalFeatureTransformer(
|
self.self_att_fn = LocalFeatureTransformer(
|
||||||
@ -141,10 +141,12 @@ class CREStereo(nn.Module):
|
|||||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||||
)
|
)
|
||||||
# 'n c h w -> n (h w) c'
|
# 'n c h w -> n (h w) c'
|
||||||
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
# x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||||
|
x_tmp = pos_encoding_fn_small(fmap1_dw16)
|
||||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||||
# 'n c h w -> n (h w) c'
|
# 'n c h w -> n (h w) c'
|
||||||
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
|
# x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
|
||||||
|
x_tmp = pos_encoding_fn_small(fmap2_dw16)
|
||||||
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||||
|
|
||||||
# FIXME experimental ! no self-attention for pattern
|
# FIXME experimental ! no self-attention for pattern
|
||||||
|
1
train.py
1
train.py
@ -419,6 +419,7 @@ def main(args):
|
|||||||
# print(f'left {left.shape}, right {right.shape}')
|
# print(f'left {left.shape}, right {right.shape}')
|
||||||
# left = left.transpose([2, 0, 1])
|
# left = left.transpose([2, 0, 1])
|
||||||
right = right.transpose([1, 2, 0])
|
right = right.transpose([1, 2, 0])
|
||||||
|
|
||||||
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
|
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
|
||||||
# print(f'left {left.shape}, right {right.shape}')
|
# print(f'left {left.shape}, right {right.shape}')
|
||||||
|
|
||||||
|
@ -5,10 +5,8 @@ import logging
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
# from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
from nets import Model
|
from nets import Model
|
||||||
# from dataset import CREStereoDataset
|
|
||||||
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
|
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -18,8 +16,11 @@ import torch.optim as optim
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
||||||
from pytorch_lightning import Trainer, seed_everything
|
from pytorch_lightning import Trainer, seed_everything
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
|
||||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||||
|
from pytorch_lightning.callbacks import LearningRateMonitor
|
||||||
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
from pytorch_lightning.strategies import DDPSpawnStrategy
|
||||||
|
|
||||||
seed_everything(42, workers=True)
|
seed_everything(42, workers=True)
|
||||||
|
|
||||||
@ -39,11 +40,9 @@ def normalize_and_colormap(img):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
|
def log_images(left, right, pred_disp, gt_disp):
|
||||||
# wandb_logger.log_text('test')
|
|
||||||
# return
|
|
||||||
log = {}
|
log = {}
|
||||||
batch_idx = 1
|
batch_idx = 0
|
||||||
|
|
||||||
if isinstance(pred_disp, list):
|
if isinstance(pred_disp, list):
|
||||||
pred_disp = pred_disp[-1]
|
pred_disp = pred_disp[-1]
|
||||||
@ -100,32 +99,13 @@ def ensure_dir(path):
|
|||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def adjust_learning_rate(optimizer, epoch):
|
|
||||||
|
|
||||||
warm_up = 0.02
|
|
||||||
const_range = 0.6
|
|
||||||
min_lr_rate = 0.05
|
|
||||||
|
|
||||||
if epoch <= args.n_total_epoch * warm_up:
|
|
||||||
lr = (1 - min_lr_rate) * args.base_lr / (
|
|
||||||
args.n_total_epoch * warm_up
|
|
||||||
) * epoch + min_lr_rate * args.base_lr
|
|
||||||
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
|
|
||||||
lr = args.base_lr
|
|
||||||
else:
|
|
||||||
lr = (min_lr_rate - 1) * args.base_lr / (
|
|
||||||
(1 - const_range) * args.n_total_epoch
|
|
||||||
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
|
|
||||||
|
|
||||||
for param_group in optimizer.param_groups:
|
|
||||||
param_group['lr'] = lr
|
|
||||||
|
|
||||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
||||||
'''
|
'''
|
||||||
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
|
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
|
||||||
flow_preds[0]: (B, 2, H, W)
|
flow_preds[0]: (B, 2, H, W)
|
||||||
flow_gt: (B, 2, H, W)
|
flow_gt: (B, 2, H, W)
|
||||||
'''
|
'''
|
||||||
|
"""
|
||||||
if test:
|
if test:
|
||||||
# print('sequence loss')
|
# print('sequence loss')
|
||||||
if valid.shape != (2, 480, 640):
|
if valid.shape != (2, 480, 640):
|
||||||
@ -136,6 +116,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
|||||||
if valid.shape != (2, 480, 640):
|
if valid.shape != (2, 480, 640):
|
||||||
valid = valid.transpose(0,1)
|
valid = valid.transpose(0,1)
|
||||||
# print(valid.shape)
|
# print(valid.shape)
|
||||||
|
"""
|
||||||
# print(valid.shape)
|
# print(valid.shape)
|
||||||
# print(flow_preds[0].shape)
|
# print(flow_preds[0].shape)
|
||||||
# print(flow_gt.shape)
|
# print(flow_gt.shape)
|
||||||
@ -143,7 +124,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
|||||||
flow_loss = 0.0
|
flow_loss = 0.0
|
||||||
|
|
||||||
# TEST
|
# TEST
|
||||||
flow_gt = torch.squeeze(flow_gt, dim=-1)
|
# flow_gt = torch.squeeze(flow_gt, dim=-1)
|
||||||
|
|
||||||
for i in range(n_predictions):
|
for i in range(n_predictions):
|
||||||
i_weight = gamma ** (n_predictions - i - 1)
|
i_weight = gamma ** (n_predictions - i - 1)
|
||||||
@ -155,16 +136,88 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
|||||||
|
|
||||||
|
|
||||||
class CREStereoLightning(LightningModule):
|
class CREStereoLightning(LightningModule):
|
||||||
def __init__(self, args, logger):
|
def __init__(self, args, logger, pattern_path, data_path):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_size = args.batch_size
|
self.batch_size = args.batch_size
|
||||||
self.wandb_logger = logger
|
self.wandb_logger = logger
|
||||||
|
self.lr = args.base_lr
|
||||||
|
print(f'lr = {self.lr}')
|
||||||
|
self.T_max = args.t_max if args.t_max else None
|
||||||
|
self.pattern_attention = args.pattern_attention
|
||||||
|
self.pattern_path = pattern_path
|
||||||
|
self.data_path = data_path
|
||||||
self.model = Model(
|
self.model = Model(
|
||||||
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
|
def train_dataloader(self):
|
||||||
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
|
dataset = BlenderDataset(
|
||||||
|
root=self.data_path,
|
||||||
|
pattern_path=self.pattern_path,
|
||||||
|
use_lightning=True,
|
||||||
|
)
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
self.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4,
|
||||||
|
drop_last=True,
|
||||||
|
persistent_workers=True,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
test_dataset = BlenderDataset(
|
||||||
|
root=self.data_path,
|
||||||
|
pattern_path=self.pattern_path,
|
||||||
|
test_set=True,
|
||||||
|
use_lightning=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_dataloader = DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
self.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=4,
|
||||||
|
drop_last=False,
|
||||||
|
persistent_workers=True,
|
||||||
|
pin_memory=True
|
||||||
|
)
|
||||||
|
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
||||||
|
return test_dataloader
|
||||||
|
|
||||||
|
def test_dataloader(self):
|
||||||
|
# TODO change this to use IRL data?
|
||||||
|
test_dataset = BlenderDataset(
|
||||||
|
root=self.data_path,
|
||||||
|
pattern_path=self.pattern_path,
|
||||||
|
test_set=True,
|
||||||
|
use_lightning=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_dataloader = DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
self.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=4,
|
||||||
|
drop_last=False,
|
||||||
|
persistent_workers=True,
|
||||||
|
pin_memory=True
|
||||||
|
)
|
||||||
|
return test_dataloader
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
image1,
|
||||||
|
image2,
|
||||||
|
flow_init=None,
|
||||||
|
iters=10,
|
||||||
|
upsample=True,
|
||||||
|
test_mode=False,
|
||||||
|
):
|
||||||
|
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
left, right, gt_disp, valid_mask = batch
|
left, right, gt_disp, valid_mask = batch
|
||||||
@ -174,6 +227,10 @@ class CREStereoLightning(LightningModule):
|
|||||||
loss = sequence_loss(
|
loss = sequence_loss(
|
||||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||||
)
|
)
|
||||||
|
if batch_idx % 128 == 0:
|
||||||
|
image_log = log_images(left, right, flow_predictions, gt_disp)
|
||||||
|
image_log['key'] = 'debug_train'
|
||||||
|
self.wandb_logger.log_image(**image_log)
|
||||||
self.log("train_loss", loss)
|
self.log("train_loss", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@ -186,22 +243,31 @@ class CREStereoLightning(LightningModule):
|
|||||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||||
)
|
)
|
||||||
self.log("val_loss", val_loss)
|
self.log("val_loss", val_loss)
|
||||||
if batch_idx % 4 == 0:
|
if batch_idx % 8 == 0:
|
||||||
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
left, right, gt_disp, valid_mask = batch
|
left, right, gt_disp, valid_mask = batch
|
||||||
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
|
||||||
|
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||||
flow_predictions = self.forward(left, right, test_mode=True)
|
flow_predictions = self.forward(left, right, test_mode=True)
|
||||||
test_loss = sequence_loss(
|
test_loss = sequence_loss(
|
||||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||||
)
|
)
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
print('test_batch_idx:', batch_idx)
|
|
||||||
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
|
optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
|
||||||
|
print('len(self.train_dataloader)', len(self.train_dataloader()))
|
||||||
|
lr_scheduler = {
|
||||||
|
'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
optimizer,
|
||||||
|
T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size,
|
||||||
|
),
|
||||||
|
'name': 'CosineAnnealingLRScheduler',
|
||||||
|
}
|
||||||
|
return [optimizer], [lr_scheduler]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -209,61 +275,54 @@ if __name__ == "__main__":
|
|||||||
args = parse_yaml("cfgs/train.yaml")
|
args = parse_yaml("cfgs/train.yaml")
|
||||||
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||||
|
|
||||||
wandb_logger = WandbLogger(project="crestereo-lightning")
|
run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='')
|
||||||
wandb.config.update(args._asdict())
|
run.config.update(args._asdict())
|
||||||
|
config = wandb.config
|
||||||
|
wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
|
||||||
|
# wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
|
||||||
|
# wandb_logger.experiment.config.update(args._asdict())
|
||||||
|
|
||||||
model = CREStereoLightning(args, wandb_logger)
|
model = CREStereoLightning(
|
||||||
|
# args,
|
||||||
dataset = BlenderDataset(
|
config,
|
||||||
root=args.training_data_path,
|
wandb_logger,
|
||||||
pattern_path=pattern_path,
|
pattern_path,
|
||||||
use_lightning=True,
|
args.training_data_path,
|
||||||
)
|
# lr=0.00017378008287493763, # found with auto_lr_find=True
|
||||||
test_dataset = BlenderDataset(
|
)
|
||||||
root=args.training_data_path,
|
# NOTE turn this down once it's working, this might use too much space
|
||||||
pattern_path=pattern_path,
|
# wandb_logger.watch(model, log_graph=False) #, log='all')
|
||||||
test_set=True,
|
|
||||||
use_lightning=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
args.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=16,
|
|
||||||
drop_last=True,
|
|
||||||
persistent_workers=True,
|
|
||||||
pin_memory=True,
|
|
||||||
)
|
|
||||||
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
|
||||||
test_dataloader = DataLoader(
|
|
||||||
test_dataset,
|
|
||||||
args.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=16,
|
|
||||||
drop_last=False,
|
|
||||||
persistent_workers=True,
|
|
||||||
pin_memory=True
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
accelerator='gpu',
|
accelerator='gpu',
|
||||||
devices=2,
|
devices=args.nr_gpus,
|
||||||
max_epochs=args.n_total_epoch,
|
max_epochs=args.n_total_epoch,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
EarlyStopping(
|
EarlyStopping(
|
||||||
monitor="val_loss",
|
monitor="val_loss",
|
||||||
mode="min",
|
mode="min",
|
||||||
patience=4,
|
patience=16,
|
||||||
|
),
|
||||||
|
LearningRateMonitor(),
|
||||||
|
ModelCheckpoint(
|
||||||
|
monitor="val_loss",
|
||||||
|
mode="min",
|
||||||
|
save_top_k=2,
|
||||||
|
save_last=True,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
accumulate_grad_batches=8,
|
strategy=DDPSpawnStrategy(find_unused_parameters=False),
|
||||||
|
# auto_scale_batch_size='binsearch',
|
||||||
|
# auto_lr_find=True,
|
||||||
|
accumulate_grad_batches=4,
|
||||||
deterministic=True,
|
deterministic=True,
|
||||||
check_val_every_n_epoch=1,
|
check_val_every_n_epoch=1,
|
||||||
limit_val_batches=24,
|
limit_val_batches=64,
|
||||||
limit_test_batches=24,
|
limit_test_batches=256,
|
||||||
logger=wandb_logger,
|
logger=wandb_logger,
|
||||||
default_root_dir=args.log_dir_lightning,
|
default_root_dir=args.log_dir_lightning,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(model, dataloader, test_dataloader)
|
# trainer.tune(model)
|
||||||
|
trainer.fit(model)
|
||||||
|
trainer.validate()
|
||||||
|
Loading…
Reference in New Issue
Block a user