|
|
@ -16,10 +16,10 @@ import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
import torch.nn.functional as F |
|
|
|
import torch.optim as optim |
|
|
|
import torch.optim as optim |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from pytorch_lightning.lite import LightningLite |
|
|
|
|
|
|
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
|
|
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
|
|
|
from pytorch_lightning import Trainer, seed_everything |
|
|
|
from pytorch_lightning import Trainer, seed_everything |
|
|
|
from pytorch_lightning.loggers import WandbLogger |
|
|
|
from pytorch_lightning.loggers import WandbLogger |
|
|
|
|
|
|
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
|
|
|
|
|
|
|
|
|
|
seed_everything(42, workers=True) |
|
|
|
seed_everything(42, workers=True) |
|
|
|
|
|
|
|
|
|
|
@ -39,148 +39,44 @@ def normalize_and_colormap(img): |
|
|
|
return ret |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True): |
|
|
|
def log_images(left, right, pred_disp, gt_disp, wandb_logger=None): |
|
|
|
|
|
|
|
# wandb_logger.log_text('test') |
|
|
|
print("Model Forwarding...") |
|
|
|
# return |
|
|
|
if isinstance(left, torch.Tensor): |
|
|
|
|
|
|
|
left = left# .cpu().detach().numpy() |
|
|
|
|
|
|
|
imgR = right# .cpu().detach().numpy() |
|
|
|
|
|
|
|
imgL = left |
|
|
|
|
|
|
|
imgR = right |
|
|
|
|
|
|
|
imgL = np.ascontiguousarray(imgL[None, :, :, :]) |
|
|
|
|
|
|
|
imgR = np.ascontiguousarray(imgR[None, :, :, :]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flow_init = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# chosen for convenience |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
imgL = torch.tensor(imgL.astype("float32")) |
|
|
|
|
|
|
|
imgR = torch.tensor(imgR.astype("float32")) |
|
|
|
|
|
|
|
imgL = imgL.transpose(2,3).transpose(1,2) |
|
|
|
|
|
|
|
if imgL.shape != imgR.shape: |
|
|
|
|
|
|
|
imgR = imgR.transpose(2,3).transpose(1,2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
imgL_dw2 = F.interpolate( |
|
|
|
|
|
|
|
imgL, |
|
|
|
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
|
|
|
|
|
|
|
mode="bilinear", |
|
|
|
|
|
|
|
align_corners=True, |
|
|
|
|
|
|
|
).clamp(min=0, max=255) |
|
|
|
|
|
|
|
imgR_dw2 = F.interpolate( |
|
|
|
|
|
|
|
imgR, |
|
|
|
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
|
|
|
|
|
|
|
mode="bilinear", |
|
|
|
|
|
|
|
align_corners=True, |
|
|
|
|
|
|
|
).clamp(min=0, max=255) |
|
|
|
|
|
|
|
if last_img is not None: |
|
|
|
|
|
|
|
print('using flow_initialization') |
|
|
|
|
|
|
|
print(last_img.shape) |
|
|
|
|
|
|
|
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help |
|
|
|
|
|
|
|
print(last_img.max(), last_img.min()) |
|
|
|
|
|
|
|
if last_img.min() < 0: |
|
|
|
|
|
|
|
# print('Negative disparity detected. shifting...') |
|
|
|
|
|
|
|
last_img = last_img - last_img.min() |
|
|
|
|
|
|
|
if last_img.max() > 255: |
|
|
|
|
|
|
|
# print('Excessive disparity detected. scaling...') |
|
|
|
|
|
|
|
last_img = last_img / (last_img.max() / 255) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
last_img = np.dstack([last_img, last_img]) |
|
|
|
|
|
|
|
# last_img = np.dstack([last_img, last_img, last_img]) |
|
|
|
|
|
|
|
last_img = np.dstack([last_img]) |
|
|
|
|
|
|
|
last_img = last_img.reshape((1, 2, 480, 640)) |
|
|
|
|
|
|
|
# print(last_img.shape) |
|
|
|
|
|
|
|
# print(last_img.dtype) |
|
|
|
|
|
|
|
# print(last_img.max(), last_img.min()) |
|
|
|
|
|
|
|
flow_init = torch.tensor(last_img.astype("float32")) |
|
|
|
|
|
|
|
# flow_init = F.interpolate( |
|
|
|
|
|
|
|
# last_img, |
|
|
|
|
|
|
|
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2), |
|
|
|
|
|
|
|
# mode="bilinear", |
|
|
|
|
|
|
|
# align_corners=True, |
|
|
|
|
|
|
|
# ) |
|
|
|
|
|
|
|
with torch.inference_mode(): |
|
|
|
|
|
|
|
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern) |
|
|
|
|
|
|
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern) |
|
|
|
|
|
|
|
pf_base = pred_flow |
|
|
|
|
|
|
|
if isinstance(pf_base, list): |
|
|
|
|
|
|
|
pf_base = pred_flow[0] |
|
|
|
|
|
|
|
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy() |
|
|
|
|
|
|
|
print('pred_flow max min') |
|
|
|
|
|
|
|
print(pf.max(), pf.min()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not wandb_log: |
|
|
|
|
|
|
|
if test: |
|
|
|
|
|
|
|
return pred_flow |
|
|
|
|
|
|
|
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log = {} |
|
|
|
log = {} |
|
|
|
in_h, in_w = left.shape[:2] |
|
|
|
batch_idx = 1 |
|
|
|
|
|
|
|
|
|
|
|
# Resize image in case the GPU memory overflows |
|
|
|
|
|
|
|
eval_h, eval_w = (in_h,in_w) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): |
|
|
|
|
|
|
|
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy() |
|
|
|
|
|
|
|
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) |
|
|
|
|
|
|
|
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if i == n_iter-1: |
|
|
|
|
|
|
|
t = float(in_w) / float(eval_w) |
|
|
|
|
|
|
|
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log[f'disp_vis'] = wandb.Image( |
|
|
|
|
|
|
|
normalize_and_colormap(disp), |
|
|
|
|
|
|
|
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log[f'pred_{i}'] = wandb.Image( |
|
|
|
|
|
|
|
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]), |
|
|
|
|
|
|
|
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
# log[f'pred_norm_{i}'] = wandb.Image( |
|
|
|
|
|
|
|
# np.array([pred_disp_norm.reshape(480, 640)]), |
|
|
|
|
|
|
|
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
|
|
|
# ) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# log[f'pred_dw2_{i}'] = wandb.Image( |
|
|
|
if isinstance(pred_disp, list): |
|
|
|
# np.array([pred_disp_dw2.reshape(240, 320)]), |
|
|
|
pred_disp = pred_disp[-1] |
|
|
|
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
|
|
|
|
|
|
|
# ) |
|
|
|
|
|
|
|
# log[f'pred_dw2_norm_{i}'] = wandb.Image( |
|
|
|
|
|
|
|
# np.array([pred_disp_dw2_norm.reshape(240, 320)]), |
|
|
|
|
|
|
|
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
|
|
|
|
|
|
|
# ) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pred_disp = torch.squeeze(pred_disp[:, 0, :, :]) |
|
|
|
|
|
|
|
gt_disp = torch.squeeze(gt_disp[:, 0, :, :]) |
|
|
|
|
|
|
|
left = torch.squeeze(left[:, 0, :, :]) |
|
|
|
|
|
|
|
right = torch.squeeze(right[:, 0, :, :]) |
|
|
|
|
|
|
|
|
|
|
|
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") |
|
|
|
disp = pred_disp |
|
|
|
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right |
|
|
|
|
|
|
|
if input_right.shape != (480, 640, 3): |
|
|
|
|
|
|
|
input_right.transpose(1,2,0) |
|
|
|
|
|
|
|
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp |
|
|
|
|
|
|
|
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
disp_error = gt_disp - disp |
|
|
|
disp_error = gt_disp - disp |
|
|
|
log['disp_error'] = wandb.Image( |
|
|
|
|
|
|
|
normalize_and_colormap(abs(disp_error)), |
|
|
|
|
|
|
|
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) |
|
|
|
log[f'gt_disp_vis'] = wandb.Image( |
|
|
|
input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) |
|
|
|
normalize_and_colormap(gt_disp), |
|
|
|
|
|
|
|
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", |
|
|
|
wandb_log = dict( |
|
|
|
) |
|
|
|
key='samples', |
|
|
|
|
|
|
|
images=[ |
|
|
|
wandb.log(log) |
|
|
|
normalize_and_colormap(pred_disp[batch_idx]), |
|
|
|
return pred_flow |
|
|
|
normalize_and_colormap(abs(disp_error[batch_idx])), |
|
|
|
|
|
|
|
normalize_and_colormap(gt_disp[batch_idx]), |
|
|
|
|
|
|
|
input_left, |
|
|
|
|
|
|
|
input_right, |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
caption=[ |
|
|
|
|
|
|
|
f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", |
|
|
|
|
|
|
|
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}", |
|
|
|
|
|
|
|
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}", |
|
|
|
|
|
|
|
"Input Left", |
|
|
|
|
|
|
|
"Input Right" |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
return wandb_log |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_yaml(file_path: str) -> namedtuple: |
|
|
|
def parse_yaml(file_path: str) -> namedtuple: |
|
|
@ -259,9 +155,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CREStereoLightning(LightningModule): |
|
|
|
class CREStereoLightning(LightningModule): |
|
|
|
def __init__(self, args): |
|
|
|
def __init__(self, args, logger): |
|
|
|
super().__init__() |
|
|
|
super().__init__() |
|
|
|
self.batch_size = args.batch_size |
|
|
|
self.batch_size = args.batch_size |
|
|
|
|
|
|
|
self.wandb_logger = logger |
|
|
|
self.model = Model( |
|
|
|
self.model = Model( |
|
|
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False |
|
|
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False |
|
|
|
) |
|
|
|
) |
|
|
@ -270,13 +167,10 @@ class CREStereoLightning(LightningModule): |
|
|
|
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right) |
|
|
|
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right) |
|
|
|
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx): |
|
|
|
def training_step(self, batch, batch_idx): |
|
|
|
# loss = self(batch) |
|
|
|
|
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
left = torch.Tensor(left).to(self.device) |
|
|
|
|
|
|
|
right = torch.Tensor(right).to(self.device) |
|
|
|
|
|
|
|
left = left |
|
|
|
|
|
|
|
right = right |
|
|
|
|
|
|
|
flow_predictions = self.forward(left, right) |
|
|
|
flow_predictions = self.forward(left, right) |
|
|
|
|
|
|
|
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] |
|
|
|
|
|
|
|
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512] |
|
|
|
loss = sequence_loss( |
|
|
|
loss = sequence_loss( |
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
) |
|
|
|
) |
|
|
@ -285,62 +179,91 @@ class CREStereoLightning(LightningModule): |
|
|
|
|
|
|
|
|
|
|
|
def validation_step(self, batch, batch_idx): |
|
|
|
def validation_step(self, batch, batch_idx): |
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
left = torch.Tensor(left).to(self.device) |
|
|
|
flow_predictions = self.forward(left, right, test_mode=True) |
|
|
|
right = torch.Tensor(right).to(self.device) |
|
|
|
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] |
|
|
|
print(left.shape) |
|
|
|
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512] |
|
|
|
print(right.shape) |
|
|
|
|
|
|
|
flow_predictions = self.forward(left, right) |
|
|
|
|
|
|
|
val_loss = sequence_loss( |
|
|
|
val_loss = sequence_loss( |
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
) |
|
|
|
) |
|
|
|
self.log("val_loss", val_loss) |
|
|
|
self.log("val_loss", val_loss) |
|
|
|
|
|
|
|
if batch_idx % 4 == 0: |
|
|
|
|
|
|
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) |
|
|
|
|
|
|
|
|
|
|
|
def test_step(self, batch, batch_idx): |
|
|
|
def test_step(self, batch, batch_idx): |
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
left, right, gt_disp, valid_mask = batch |
|
|
|
# left, right, gt_disp, valid_mask = ( |
|
|
|
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512] |
|
|
|
# batch["left"], |
|
|
|
flow_predictions = self.forward(left, right, test_mode=True) |
|
|
|
# batch["right"], |
|
|
|
|
|
|
|
# batch["disparity"], |
|
|
|
|
|
|
|
# batch["mask"], |
|
|
|
|
|
|
|
# ) |
|
|
|
|
|
|
|
left = torch.Tensor(left).to(self.device) |
|
|
|
|
|
|
|
right = torch.Tensor(right).to(self.device) |
|
|
|
|
|
|
|
flow_predictions = self.forward(left, right) |
|
|
|
|
|
|
|
test_loss = sequence_loss( |
|
|
|
test_loss = sequence_loss( |
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
|
|
|
) |
|
|
|
) |
|
|
|
self.log("test_loss", test_loss) |
|
|
|
self.log("test_loss", test_loss) |
|
|
|
|
|
|
|
print('test_batch_idx:', batch_idx) |
|
|
|
|
|
|
|
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) |
|
|
|
|
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
|
|
def configure_optimizers(self): |
|
|
|
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999)) |
|
|
|
return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
if __name__ == "__main__": |
|
|
|
# train configuration |
|
|
|
# train configuration |
|
|
|
args = parse_yaml("cfgs/train.yaml") |
|
|
|
args = parse_yaml("cfgs/train.yaml") |
|
|
|
# wandb.init(project="crestereo-lightning", entity="cpt-captain") |
|
|
|
|
|
|
|
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args) |
|
|
|
|
|
|
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' |
|
|
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' |
|
|
|
model = CREStereoLightning(args) |
|
|
|
|
|
|
|
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True) |
|
|
|
|
|
|
|
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True) |
|
|
|
|
|
|
|
print(len(dataset)) |
|
|
|
|
|
|
|
print(len(test_dataset)) |
|
|
|
|
|
|
|
wandb_logger = WandbLogger(project="crestereo-lightning") |
|
|
|
wandb_logger = WandbLogger(project="crestereo-lightning") |
|
|
|
wandb.config.update(args._asdict()) |
|
|
|
wandb.config.update(args._asdict()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = CREStereoLightning(args, wandb_logger) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset = BlenderDataset( |
|
|
|
|
|
|
|
root=args.training_data_path, |
|
|
|
|
|
|
|
pattern_path=pattern_path, |
|
|
|
|
|
|
|
use_lightning=True, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
test_dataset = BlenderDataset( |
|
|
|
|
|
|
|
root=args.training_data_path, |
|
|
|
|
|
|
|
pattern_path=pattern_path, |
|
|
|
|
|
|
|
test_set=True, |
|
|
|
|
|
|
|
use_lightning=True, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataloader = DataLoader( |
|
|
|
|
|
|
|
dataset, |
|
|
|
|
|
|
|
args.batch_size, |
|
|
|
|
|
|
|
shuffle=True, |
|
|
|
|
|
|
|
num_workers=16, |
|
|
|
|
|
|
|
drop_last=True, |
|
|
|
|
|
|
|
persistent_workers=True, |
|
|
|
|
|
|
|
pin_memory=True, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) |
|
|
|
|
|
|
|
test_dataloader = DataLoader( |
|
|
|
|
|
|
|
test_dataset, |
|
|
|
|
|
|
|
args.batch_size, |
|
|
|
|
|
|
|
shuffle=False, |
|
|
|
|
|
|
|
num_workers=16, |
|
|
|
|
|
|
|
drop_last=False, |
|
|
|
|
|
|
|
persistent_workers=True, |
|
|
|
|
|
|
|
pin_memory=True |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
trainer = Trainer( |
|
|
|
trainer = Trainer( |
|
|
|
max_epochs=args.n_total_epoch, |
|
|
|
|
|
|
|
accelerator='gpu', |
|
|
|
accelerator='gpu', |
|
|
|
devices=2, |
|
|
|
devices=2, |
|
|
|
# auto_scale_batch_size='binsearch', |
|
|
|
max_epochs=args.n_total_epoch, |
|
|
|
# strategy='ddp', |
|
|
|
callbacks=[ |
|
|
|
|
|
|
|
EarlyStopping( |
|
|
|
|
|
|
|
monitor="val_loss", |
|
|
|
|
|
|
|
mode="min", |
|
|
|
|
|
|
|
patience=4, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
accumulate_grad_batches=8, |
|
|
|
deterministic=True, |
|
|
|
deterministic=True, |
|
|
|
check_val_every_n_epoch=1, |
|
|
|
check_val_every_n_epoch=1, |
|
|
|
limit_val_batches=24, |
|
|
|
limit_val_batches=24, |
|
|
|
limit_test_batches=24, |
|
|
|
limit_test_batches=24, |
|
|
|
logger=wandb_logger, |
|
|
|
logger=wandb_logger, |
|
|
|
default_root_dir=args.log_dir_lightning, |
|
|
|
default_root_dir=args.log_dir_lightning, |
|
|
|
) |
|
|
|
) |
|
|
|
# trainer.tune(model) |
|
|
|
|
|
|
|
trainer.fit(model, dataset, test_dataset) |
|
|
|
trainer.fit(model, dataloader, test_dataloader) |
|
|
|