diff --git a/.gitignore b/.gitignore
index d9005f2..cdcc1c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,3 +150,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+vis_results/
\ No newline at end of file
diff --git a/README.md b/README.md
index eadf1c1..76f1a26 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@
 # References:
 - CREStereo: https://github.com/megvii-research/CREStereo
+- CREStereo-Pytorch: https://github.com/ibaiGorordo/CREStereo-Pytorch
 - RAFT: https://github.com/princeton-vl/RAFT
 - LoFTR: https://github.com/zju3dv/LoFTR
 - Grid sample replacement: https://zenn.dev/pinto0309/scraps/7d4032067d0160
diff --git a/cfgs/train.yaml b/cfgs/train.yaml
new file mode 100644
index 0000000..cfc5f7b
--- /dev/null
+++ b/cfgs/train.yaml
@@ -0,0 +1,20 @@
+seed: 0
+mixed_precision: false
+base_lr: 4.0e-4
+
+nr_gpus: 8
+batch_size: 4
+n_total_epoch: 600
+minibatch_per_epoch: 500
+
+loadmodel: ~
+log_dir: "./train_log"
+model_save_freq_epoch: 1
+
+max_disp: 256
+image_width: 512
+image_height: 384
+training_data_path: "./stereo_trainset/crestereo"
+
+log_level: "logging.INFO"
+
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..e13e7b9
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,215 @@
+import os
+import cv2
+import glob
+import numpy as np
+from PIL import Image, ImageEnhance
+
+from torch.utils.data import Dataset
+
+
+class Augmentor:
+    def __init__(
+        self,
+        image_height=384,
+        image_width=512,
+        max_disp=256,
+        scale_min=0.6,
+        scale_max=1.0,
+        seed=0,
+    ):
+        super().__init__()
+        self.image_height = image_height
+        self.image_width = image_width
+        self.max_disp = max_disp
+        self.scale_min = scale_min
+        self.scale_max = scale_max
+        self.rng = np.random.RandomState(seed)
+
+    def chromatic_augmentation(self, img):
+        # draw the photometric factors from the seeded RNG so that
+        # augmentation is reproducible
+        random_brightness = self.rng.uniform(0.8, 1.2)
+        random_contrast = self.rng.uniform(0.8, 1.2)
+        random_gamma = self.rng.uniform(0.8, 1.2)
+
+        img = Image.fromarray(img)
+
+        enhancer = ImageEnhance.Brightness(img)
+        img = enhancer.enhance(random_brightness)
+        enhancer = ImageEnhance.Contrast(img)
+        img = enhancer.enhance(random_contrast)
+
+        gamma_map = [
+            255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
+        ] * 3
+        img = img.point(gamma_map)  # use PIL's point-function to accelerate this part
+
+        img_ = np.array(img)
+
+        return img_
+
+    def __call__(self, left_img, right_img, left_disp):
+        # 1. chromatic augmentation
+        left_img = self.chromatic_augmentation(left_img)
+        right_img = self.chromatic_augmentation(right_img)
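+        # note: the photometric factors are sampled independently for the
+        # left and right views, so the two images receive slightly
+        # different perturbations
+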
+        # 2. spatial augmentation
+        # 2.1) rotate & vertical shift for right image
+        if self.rng.binomial(1, 0.5):
+            angle, pixel = 0.1, 2
+            px = self.rng.uniform(-pixel, pixel)
+            ag = self.rng.uniform(-angle, angle)
+            image_center = (
+                self.rng.uniform(0, right_img.shape[0]),
+                self.rng.uniform(0, right_img.shape[1]),
+            )
+            rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
+            right_img = cv2.warpAffine(
+                right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
+            )
+            trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
+            right_img = cv2.warpAffine(
+                right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
+            )
+
+        # 2.2) random resize
+        resize_scale = self.rng.uniform(self.scale_min, self.scale_max)
+
+        left_img = cv2.resize(
+            left_img,
+            None,
+            fx=resize_scale,
+            fy=resize_scale,
+            interpolation=cv2.INTER_LINEAR,
+        )
+        right_img = cv2.resize(
+            right_img,
+            None,
+            fx=resize_scale,
+            fy=resize_scale,
+            interpolation=cv2.INTER_LINEAR,
+        )
+
+        disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
+        disp_mask = disp_mask.astype("float32")
+        disp_mask = cv2.resize(
+            disp_mask,
+            None,
+            fx=resize_scale,
+            fy=resize_scale,
+            interpolation=cv2.INTER_LINEAR,
+        )
+
+        left_disp = (
+            cv2.resize(
+                left_disp,
+                None,
+                fx=resize_scale,
+                fy=resize_scale,
+                interpolation=cv2.INTER_LINEAR,
+            )
+            * resize_scale
+        )
+
+        # 2.3) random crop
+        h, w, c = left_img.shape
+        dx = w - self.image_width
+        dy = h - self.image_height
+        dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
+        dx = self.rng.randint(min(0, dx), max(0, dx) + 1)
+
+        M = np.float32([[1.0, 0.0, -dx], [0.0, 1.0, -dy]])
+        left_img = cv2.warpAffine(
+            left_img,
+            M,
+            (self.image_width, self.image_height),
+            flags=cv2.INTER_LINEAR,
+            borderValue=0,
+        )
+        right_img = cv2.warpAffine(
+            right_img,
+            M,
+            (self.image_width, self.image_height),
+            flags=cv2.INTER_LINEAR,
+            borderValue=0,
+        )
+        left_disp = cv2.warpAffine(
+            left_disp,
+            M,
+            (self.image_width, self.image_height),
+            flags=cv2.INTER_LINEAR,
+            borderValue=0,
+        )
+        disp_mask = cv2.warpAffine(
+            disp_mask,
+            M,
+            (self.image_width, self.image_height),
+            flags=cv2.INTER_LINEAR,
+            borderValue=0,
+        )
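+
+        # when resize_scale < 1 the resized image can be smaller than the
+        # crop window; dx/dy then become negative and warpAffine zero-pads
+        # the borders instead of cropping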
+        # 3. add random occlusion to right image
+        if self.rng.binomial(1, 0.5):
+            sx = int(self.rng.uniform(50, 100))
+            sy = int(self.rng.uniform(50, 100))
+            cx = int(self.rng.uniform(sx, right_img.shape[0] - sx))
+            cy = int(self.rng.uniform(sy, right_img.shape[1] - sy))
+            right_img[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean(
+                np.mean(right_img, 0), 0
+            )[np.newaxis, np.newaxis]
+
+        return left_img, right_img, left_disp, disp_mask
+
+
+class CREStereoDataset(Dataset):
+    def __init__(self, root):
+        super().__init__()
+        self.imgs = glob.glob(os.path.join(root, "**/*_left.jpg"), recursive=True)
+        self.augmentor = Augmentor(
+            image_height=384,
+            image_width=512,
+            max_disp=256,
+            scale_min=0.6,
+            scale_max=1.0,
+            seed=0,
+        )
+        self.rng = np.random.RandomState(0)
+
+    def get_disp(self, path):
+        disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        # disparities are stored as 16-bit PNGs scaled by a factor of 32
+        return disp.astype(np.float32) / 32
+
+    def __getitem__(self, index):
+        # find path
+        left_path = self.imgs[index]
+        prefix = left_path[: left_path.rfind("_")]
+        right_path = prefix + "_right.jpg"
+        left_disp_path = prefix + "_left.disp.png"
+        right_disp_path = prefix + "_right.disp.png"
+
+        # read img, disp
+        left_img = cv2.imread(left_path, cv2.IMREAD_COLOR)
+        right_img = cv2.imread(right_path, cv2.IMREAD_COLOR)
+        left_disp = self.get_disp(left_disp_path)
+        right_disp = self.get_disp(right_disp_path)
+
+        if self.rng.binomial(1, 0.5):  # random horizontal flip swaps the views
+            left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
+            left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
+        left_disp[left_disp == np.inf] = 0
+
+        # augmentation
+        left_img, right_img, left_disp, disp_mask = self.augmentor(
+            left_img, right_img, left_disp
+        )
+
+        left_img = left_img.transpose(2, 0, 1).astype("uint8")
+        right_img = right_img.transpose(2, 0, 1).astype("uint8")
+
+        return {
+            "left": left_img,
+            "right": right_img,
+            "disparity": left_disp,
+            "mask": disp_mask,
+        }
+
+    def __len__(self):
+        return len(self.imgs)
diff --git a/models/crestereo_eth3d.mge b/models/crestereo_eth3d.mge
new file mode 100644
index 0000000..5be102e
Binary files /dev/null and b/models/crestereo_eth3d.mge differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..8f8d385
--- /dev/null
+++ b/train.py
@@ -0,0 +1,274 @@
+import os
+import sys
+import time
+import logging
+from collections import namedtuple
+
+import yaml
+from tensorboardX import SummaryWriter
+
+from nets import Model
+from dataset import CREStereoDataset
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+
+def parse_yaml(file_path: str) -> namedtuple:
+    """Parse a yaml configuration file and return the options as a namedtuple."""
+    with open(file_path, "rb") as f:
+        cfg: dict = yaml.safe_load(f)
+    args = namedtuple("train_args", cfg.keys())(*cfg.values())
+    return args
+
+
+def format_time(elapse):
+    elapse = int(elapse)
+    hour = elapse // 3600
+    minute = elapse % 3600 // 60
+    seconds = elapse % 60
+    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
+
+
+def ensure_dir(path):
+    if not os.path.exists(path):
+        os.makedirs(path, exist_ok=True)
+
+
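+# piecewise schedule: linear warm-up from 5% of base_lr to base_lr over the
+# first 2% of epochs, constant at base_lr until 60% of epochs, then linear
+# decay back down to 5% of base_lr by the final epoch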
+def adjust_learning_rate(optimizer, epoch):
+    # note: relies on the module-level `args` parsed under __main__
+    warm_up = 0.02
+    const_range = 0.6
+    min_lr_rate = 0.05
+
+    if epoch <= args.n_total_epoch * warm_up:
+        lr = (1 - min_lr_rate) * args.base_lr / (
+            args.n_total_epoch * warm_up
+        ) * epoch + min_lr_rate * args.base_lr
+    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
+        lr = args.base_lr
+    else:
+        lr = (min_lr_rate - 1) * args.base_lr / (
+            (1 - const_range) * args.n_total_epoch
+        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
+
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
+    '''
+    valid: (B, H, W) -> unsqueezed to (B, 1, H, W)
+    flow_preds[i]: (B, 2, H, W)
+    flow_gt: (B, 2, H, W)
+    '''
+    # exponentially increasing weights: later (more refined) predictions
+    # contribute more to the total loss
+    n_predictions = len(flow_preds)
+    flow_loss = 0.0
+    for i in range(n_predictions):
+        i_weight = gamma ** (n_predictions - i - 1)
+        i_loss = torch.abs(flow_preds[i] - flow_gt)
+        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
+
+    return flow_loss
+
+
+def main(args):
+    # initial info
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+    world_size = torch.cuda.device_count()  # number of available GPUs
+
+    # directory check
+    log_model_dir = os.path.join(args.log_dir, "models")
+    ensure_dir(log_model_dir)
+
+    # model / optimizer
+    model = Model(
+        max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
+    )
+    model = nn.DataParallel(model, device_ids=list(range(world_size)))
+    model.cuda()
+    # the lr set here is a placeholder; adjust_learning_rate() overrides it
+    # at the start of every epoch
+    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
+
+    tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
+
+    # worklog
+    logging.basicConfig(level=eval(args.log_level))
+    worklog = logging.getLogger("train_logger")
+    worklog.propagate = False
+    fileHandler = logging.FileHandler(
+        os.path.join(args.log_dir, "worklog.txt"), mode="a", encoding="utf8"
+    )
+    formatter = logging.Formatter(
+        fmt="%(asctime)s %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+    )
+    fileHandler.setFormatter(formatter)
+    consoleHandler = logging.StreamHandler(sys.stdout)
+    formatter = logging.Formatter(
+        fmt="\x1b[32m%(asctime)s\x1b[0m %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+    )
+    consoleHandler.setFormatter(formatter)
+    worklog.handlers = [fileHandler, consoleHandler]
+
+    # params stat
+    worklog.info(f"Use {world_size} GPU(s)")
+    worklog.info("Params: %s" % sum([p.numel() for p in model.parameters()]))
+
+    # load pretrained model if one exists
+    chk_path = os.path.join(log_model_dir, "latest.pth")
+    if args.loadmodel is not None:
+        chk_path = args.loadmodel
+    elif not os.path.exists(chk_path):
+        chk_path = None
+
+    if chk_path is not None:
+        worklog.info(f"loading model: {chk_path}")
+        state_dict = torch.load(chk_path)
+        model.load_state_dict(state_dict["state_dict"])
+        optimizer.load_state_dict(state_dict["optim_state_dict"])
+        resume_epoch_idx = state_dict["epoch"]
+        resume_iters = state_dict["iters"]
+        start_epoch_idx = resume_epoch_idx + 1
+        start_iters = resume_iters
+    else:
+        start_epoch_idx = 1
+        start_iters = 0
+
+    # datasets
+    dataset = CREStereoDataset(args.training_data_path)
+    worklog.info(f"Dataset size: {len(dataset)}")
+    dataloader = DataLoader(
+        dataset,
+        args.batch_size,
+        shuffle=True,
+        num_workers=0,
+        drop_last=True,
+        persistent_workers=False,
+        pin_memory=True,
+    )
+
+    # counter
+    cur_iters = start_iters
+    total_iters = args.minibatch_per_epoch * args.n_total_epoch
+    t0 = time.perf_counter()
+    for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
+
+        # adjust learning rate
+        epoch_total_train_loss = 0
+        adjust_learning_rate(optimizer, epoch_idx)
+        model.train()
+
+        t1 = time.perf_counter()
+
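+        # each "epoch" is capped at minibatch_per_epoch batches (see the
+        # break below), so epoch length is fixed regardless of dataset size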
+        for batch_idx, mini_batch_data in enumerate(dataloader):
+
+            if batch_idx % args.minibatch_per_epoch == 0 and batch_idx != 0:
+                break
+            cur_iters += 1
+
+            # parse data
+            left, right, gt_disp, valid_mask = (
+                mini_batch_data["left"],
+                mini_batch_data["right"],
+                mini_batch_data["disparity"].cuda(),
+                mini_batch_data["mask"].cuda(),
+            )
+
+            t2 = time.perf_counter()
+            optimizer.zero_grad()
+
+            # pre-process
+            gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [B, H, W] -> [B, 1, H, W]
+            # disparity is the horizontal flow component; the vertical one is zero
+            gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [B, 2, H, W]
+
+            # forward
+            flow_predictions = model(left.cuda(), right.cuda())
+
+            # loss & backward
+            loss = sequence_loss(
+                flow_predictions, gt_flow, valid_mask, gamma=0.8
+            )
+
+            # loss stats
+            loss_item = loss.item()
+            epoch_total_train_loss += loss_item
+            loss.backward()
+            optimizer.step()
+            t3 = time.perf_counter()
+
+            if cur_iters % 10 == 0:
+                tdata = t2 - t1
+                time_train_passed = t3 - t0
+                time_iter_passed = t3 - t1
+                step_passed = cur_iters - start_iters
+                eta = (
+                    (total_iters - cur_iters)
+                    / max(step_passed, 1e-7)
+                    * time_train_passed
+                )
+
+                meta_info = list()
+                meta_info.append("{:.2g} b/s".format(1.0 / time_iter_passed))
+                meta_info.append("passed:{}".format(format_time(time_train_passed)))
+                meta_info.append("eta:{}".format(format_time(eta)))
+                meta_info.append(
+                    "data_time:{:.2g}".format(tdata / time_iter_passed)
+                )
+
+                meta_info.append(
+                    "lr:{:.5g}".format(optimizer.param_groups[0]["lr"])
+                )
+                meta_info.append(
+                    "[{}/{}:{}/{}]".format(
+                        epoch_idx,
+                        args.n_total_epoch,
+                        batch_idx,
+                        args.minibatch_per_epoch,
+                    )
+                )
+                loss_info = [" ==> {}:{:.4g}".format("loss", loss_item)]
+
+                info = [",".join(meta_info)] + loss_info
+                worklog.info("".join(info))
+
+                # minibatch loss
+                tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
+                tb_log.add_scalar(
+                    "train/lr", optimizer.param_groups[0]["lr"], cur_iters
+                )
+                tb_log.flush()
+
+            t1 = time.perf_counter()
+
+        tb_log.add_scalar(
+            "train/loss",
+            epoch_total_train_loss / args.minibatch_per_epoch,
+            epoch_idx,
+        )
+        tb_log.flush()
+
+        # save model params
+        ckp_data = {
+            "epoch": epoch_idx,
+            "iters": cur_iters,
+            "batch_size": args.batch_size,
+            "epoch_size": args.minibatch_per_epoch,
+            "train_loss": epoch_total_train_loss / args.minibatch_per_epoch,
+            "state_dict": model.state_dict(),
+            "optim_state_dict": optimizer.state_dict(),
+        }
+        torch.save(ckp_data, os.path.join(log_model_dir, "latest.pth"))
+        if epoch_idx % args.model_save_freq_epoch == 0:
+            save_path = os.path.join(log_model_dir, "epoch-%d.pth" % epoch_idx)
+            worklog.info(f"Model params saved: {save_path}")
+            torch.save(ckp_data, save_path)
+
+    worklog.info("Training is done, exit.")
+
+
+if __name__ == "__main__":
+    # train configuration
+    args = parse_yaml("cfgs/train.yaml")
+    main(args)
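+
+# Usage: run `python train.py`; all options are read from cfgs/train.yaml.
+# The dataset root (training_data_path) is expected to contain sample
+# quadruplets named as in the glob patterns in dataset.py:
+#   xxx_left.jpg, xxx_right.jpg, xxx_left.disp.png, xxx_right.disp.png
+# To resume training, point `loadmodel` in cfgs/train.yaml at a saved
+# checkpoint (e.g. ./train_log/models/latest.pth).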