commit
d91a867d5d
@ -0,0 +1,20 @@ |
|||||||
|
seed: 0 |
||||||
|
mixed_precision: false |
||||||
|
base_lr: 4.0e-4 |
||||||
|
|
||||||
|
nr_gpus: 8 |
||||||
|
batch_size: 4 |
||||||
|
n_total_epoch: 600 |
||||||
|
minibatch_per_epoch: 500 |
||||||
|
|
||||||
|
loadmodel: ~ |
||||||
|
log_dir: "./train_log" |
||||||
|
model_save_freq_epoch: 1 |
||||||
|
|
||||||
|
max_disp: 256 |
||||||
|
image_width: 512 |
||||||
|
image_height: 384 |
||||||
|
training_data_path: "./stereo_trainset/crestereo" |
||||||
|
|
||||||
|
log_level: "logging.INFO" |
||||||
|
|
@ -0,0 +1,215 @@ |
|||||||
|
import os |
||||||
|
import cv2 |
||||||
|
import glob |
||||||
|
import numpy as np |
||||||
|
from PIL import Image, ImageEnhance |
||||||
|
|
||||||
|
from megengine.data.dataset import Dataset |
||||||
|
|
||||||
|
|
||||||
|
class Augmentor: |
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
image_height=384, |
||||||
|
image_width=512, |
||||||
|
max_disp=256, |
||||||
|
scale_min=0.6, |
||||||
|
scale_max=1.0, |
||||||
|
seed=0, |
||||||
|
): |
||||||
|
super().__init__() |
||||||
|
self.image_height = image_height |
||||||
|
self.image_width = image_width |
||||||
|
self.max_disp = max_disp |
||||||
|
self.scale_min = scale_min |
||||||
|
self.scale_max = scale_max |
||||||
|
self.rng = np.random.RandomState(seed) |
||||||
|
|
||||||
|
def chromatic_augmentation(self, img): |
||||||
|
random_brightness = np.random.uniform(0.8, 1.2) |
||||||
|
random_contrast = np.random.uniform(0.8, 1.2) |
||||||
|
random_gamma = np.random.uniform(0.8, 1.2) |
||||||
|
|
||||||
|
img = Image.fromarray(img) |
||||||
|
|
||||||
|
enhancer = ImageEnhance.Brightness(img) |
||||||
|
img = enhancer.enhance(random_brightness) |
||||||
|
enhancer = ImageEnhance.Contrast(img) |
||||||
|
img = enhancer.enhance(random_contrast) |
||||||
|
|
||||||
|
gamma_map = [ |
||||||
|
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256) |
||||||
|
] * 3 |
||||||
|
img = img.point(gamma_map) # use PIL's point-function to accelerate this part |
||||||
|
|
||||||
|
img_ = np.array(img) |
||||||
|
|
||||||
|
return img_ |
||||||
|
|
||||||
|
def __call__(self, left_img, right_img, left_disp): |
||||||
|
# 1. chromatic augmentation |
||||||
|
left_img = self.chromatic_augmentation(left_img) |
||||||
|
right_img = self.chromatic_augmentation(right_img) |
||||||
|
|
||||||
|
# 2. spatial augmentation |
||||||
|
# 2.1) rotate & vertical shift for right image |
||||||
|
if self.rng.binomial(1, 0.5): |
||||||
|
angle, pixel = 0.1, 2 |
||||||
|
px = self.rng.uniform(-pixel, pixel) |
||||||
|
ag = self.rng.uniform(-angle, angle) |
||||||
|
image_center = ( |
||||||
|
self.rng.uniform(0, right_img.shape[0]), |
||||||
|
self.rng.uniform(0, right_img.shape[1]), |
||||||
|
) |
||||||
|
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0) |
||||||
|
right_img = cv2.warpAffine( |
||||||
|
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR |
||||||
|
) |
||||||
|
trans_mat = np.float32([[1, 0, 0], [0, 1, px]]) |
||||||
|
right_img = cv2.warpAffine( |
||||||
|
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR |
||||||
|
) |
||||||
|
|
||||||
|
# 2.2) random resize |
||||||
|
resize_scale = self.rng.uniform(self.scale_min, self.scale_max) |
||||||
|
|
||||||
|
left_img = cv2.resize( |
||||||
|
left_img, |
||||||
|
None, |
||||||
|
fx=resize_scale, |
||||||
|
fy=resize_scale, |
||||||
|
interpolation=cv2.INTER_LINEAR, |
||||||
|
) |
||||||
|
right_img = cv2.resize( |
||||||
|
right_img, |
||||||
|
None, |
||||||
|
fx=resize_scale, |
||||||
|
fy=resize_scale, |
||||||
|
interpolation=cv2.INTER_LINEAR, |
||||||
|
) |
||||||
|
|
||||||
|
disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0) |
||||||
|
disp_mask = disp_mask.astype("float32") |
||||||
|
disp_mask = cv2.resize( |
||||||
|
disp_mask, |
||||||
|
None, |
||||||
|
fx=resize_scale, |
||||||
|
fy=resize_scale, |
||||||
|
interpolation=cv2.INTER_LINEAR, |
||||||
|
) |
||||||
|
|
||||||
|
left_disp = ( |
||||||
|
cv2.resize( |
||||||
|
left_disp, |
||||||
|
None, |
||||||
|
fx=resize_scale, |
||||||
|
fy=resize_scale, |
||||||
|
interpolation=cv2.INTER_LINEAR, |
||||||
|
) |
||||||
|
* resize_scale |
||||||
|
) |
||||||
|
|
||||||
|
# 2.3) random crop |
||||||
|
h, w, c = left_img.shape |
||||||
|
dx = w - self.image_width |
||||||
|
dy = h - self.image_height |
||||||
|
dy = self.rng.randint(min(0, dy), max(0, dy) + 1) |
||||||
|
dx = self.rng.randint(min(0, dx), max(0, dx) + 1) |
||||||
|
|
||||||
|
M = np.float32([[1.0, 0.0, -dx], [0.0, 1.0, -dy]]) |
||||||
|
left_img = cv2.warpAffine( |
||||||
|
left_img, |
||||||
|
M, |
||||||
|
(self.image_width, self.image_height), |
||||||
|
flags=cv2.INTER_LINEAR, |
||||||
|
borderValue=0, |
||||||
|
) |
||||||
|
right_img = cv2.warpAffine( |
||||||
|
right_img, |
||||||
|
M, |
||||||
|
(self.image_width, self.image_height), |
||||||
|
flags=cv2.INTER_LINEAR, |
||||||
|
borderValue=0, |
||||||
|
) |
||||||
|
left_disp = cv2.warpAffine( |
||||||
|
left_disp, |
||||||
|
M, |
||||||
|
(self.image_width, self.image_height), |
||||||
|
flags=cv2.INTER_LINEAR, |
||||||
|
borderValue=0, |
||||||
|
) |
||||||
|
disp_mask = cv2.warpAffine( |
||||||
|
disp_mask, |
||||||
|
M, |
||||||
|
(self.image_width, self.image_height), |
||||||
|
flags=cv2.INTER_LINEAR, |
||||||
|
borderValue=0, |
||||||
|
) |
||||||
|
|
||||||
|
# 3. add random occlusion to right image |
||||||
|
if self.rng.binomial(1, 0.5): |
||||||
|
sx = int(self.rng.uniform(50, 100)) |
||||||
|
sy = int(self.rng.uniform(50, 100)) |
||||||
|
cx = int(self.rng.uniform(sx, right_img.shape[0] - sx)) |
||||||
|
cy = int(self.rng.uniform(sy, right_img.shape[1] - sy)) |
||||||
|
right_img[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean( |
||||||
|
np.mean(right_img, 0), 0 |
||||||
|
)[np.newaxis, np.newaxis] |
||||||
|
|
||||||
|
return left_img, right_img, left_disp, disp_mask |
||||||
|
|
||||||
|
|
||||||
|
class CREStereoDataset(Dataset): |
||||||
|
def __init__(self, root): |
||||||
|
super().__init__() |
||||||
|
self.imgs = glob.glob(os.path.join(root, "**/*_left.jpg"), recursive=True) |
||||||
|
self.augmentor = Augmentor( |
||||||
|
image_height=384, |
||||||
|
image_width=512, |
||||||
|
max_disp=256, |
||||||
|
scale_min=0.6, |
||||||
|
scale_max=1.0, |
||||||
|
seed=0, |
||||||
|
) |
||||||
|
self.rng = np.random.RandomState(0) |
||||||
|
|
||||||
|
def get_disp(self, path): |
||||||
|
disp = cv2.imread(path, cv2.IMREAD_UNCHANGED) |
||||||
|
return disp.astype(np.float32) / 32 |
||||||
|
|
||||||
|
def __getitem__(self, index): |
||||||
|
# find path |
||||||
|
left_path = self.imgs[index] |
||||||
|
prefix = left_path[: left_path.rfind("_")] |
||||||
|
right_path = prefix + "_right.jpg" |
||||||
|
left_disp_path = prefix + "_left.disp.png" |
||||||
|
right_disp_path = prefix + "_right.disp.png" |
||||||
|
|
||||||
|
# read img, disp |
||||||
|
left_img = cv2.imread(left_path, cv2.IMREAD_COLOR) |
||||||
|
right_img = cv2.imread(right_path, cv2.IMREAD_COLOR) |
||||||
|
left_disp = self.get_disp(left_disp_path) |
||||||
|
right_disp = self.get_disp(right_disp_path) |
||||||
|
|
||||||
|
if self.rng.binomial(1, 0.5): |
||||||
|
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img) |
||||||
|
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp) |
||||||
|
left_disp[left_disp == np.inf] = 0 |
||||||
|
|
||||||
|
# augmentaion |
||||||
|
left_img, right_img, left_disp, disp_mask = self.augmentor( |
||||||
|
left_img, right_img, left_disp |
||||||
|
) |
||||||
|
|
||||||
|
left_img = left_img.transpose(2, 0, 1).astype("uint8") |
||||||
|
right_img = right_img.transpose(2, 0, 1).astype("uint8") |
||||||
|
|
||||||
|
return { |
||||||
|
"left": left_img, |
||||||
|
"right": right_img, |
||||||
|
"disparity": left_disp, |
||||||
|
"mask": disp_mask, |
||||||
|
} |
||||||
|
|
||||||
|
def __len__(self): |
||||||
|
return len(self.imgs) |
Binary file not shown.
@ -0,0 +1,274 @@ |
|||||||
|
import os |
||||||
|
import sys |
||||||
|
import time |
||||||
|
import logging |
||||||
|
from collections import namedtuple |
||||||
|
|
||||||
|
import yaml |
||||||
|
from tensorboardX import SummaryWriter |
||||||
|
|
||||||
|
from nets import Model |
||||||
|
from dataset import CREStereoDataset |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.nn as nn |
||||||
|
import torch.optim as optim |
||||||
|
from torch.utils.data import DataLoader |
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml(file_path: str) -> namedtuple: |
||||||
|
"""Parse yaml configuration file and return the object in `namedtuple`.""" |
||||||
|
with open(file_path, "rb") as f: |
||||||
|
cfg: dict = yaml.safe_load(f) |
||||||
|
args = namedtuple("train_args", cfg.keys())(*cfg.values()) |
||||||
|
return args |
||||||
|
|
||||||
|
|
||||||
|
def format_time(elapse): |
||||||
|
elapse = int(elapse) |
||||||
|
hour = elapse // 3600 |
||||||
|
minute = elapse % 3600 // 60 |
||||||
|
seconds = elapse % 60 |
||||||
|
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds) |
||||||
|
|
||||||
|
|
||||||
|
def ensure_dir(path): |
||||||
|
if not os.path.exists(path): |
||||||
|
os.makedirs(path, exist_ok=True) |
||||||
|
|
||||||
|
|
||||||
|
def adjust_learning_rate(optimizer, epoch): |
||||||
|
|
||||||
|
warm_up = 0.02 |
||||||
|
const_range = 0.6 |
||||||
|
min_lr_rate = 0.05 |
||||||
|
|
||||||
|
if epoch <= args.n_total_epoch * warm_up: |
||||||
|
lr = (1 - min_lr_rate) * args.base_lr / ( |
||||||
|
args.n_total_epoch * warm_up |
||||||
|
) * epoch + min_lr_rate * args.base_lr |
||||||
|
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range: |
||||||
|
lr = args.base_lr |
||||||
|
else: |
||||||
|
lr = (min_lr_rate - 1) * args.base_lr / ( |
||||||
|
(1 - const_range) * args.n_total_epoch |
||||||
|
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr |
||||||
|
|
||||||
|
for param_group in optimizer.param_groups: |
||||||
|
param_group['lr'] = lr |
||||||
|
|
||||||
|
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8): |
||||||
|
''' |
||||||
|
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W) |
||||||
|
flow_preds[0]: (B, 2, H, W) |
||||||
|
flow_gt: (B, 2, H, W) |
||||||
|
''' |
||||||
|
n_predictions = len(flow_preds) |
||||||
|
flow_loss = 0.0 |
||||||
|
for i in range(n_predictions): |
||||||
|
i_weight = gamma ** (n_predictions - i - 1) |
||||||
|
i_loss = torch.abs(flow_preds[i] - flow_gt) |
||||||
|
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean() |
||||||
|
|
||||||
|
return flow_loss |
||||||
|
|
||||||
|
|
||||||
|
def main(args): |
||||||
|
# initial info |
||||||
|
torch.manual_seed(args.seed) |
||||||
|
torch.cuda.manual_seed(args.seed) |
||||||
|
# rank, world_size = dist.get_rank(), dist.get_world_size() |
||||||
|
world_size = torch.cuda.device_count() # number of GPU(s) |
||||||
|
|
||||||
|
# directory check |
||||||
|
log_model_dir = os.path.join(args.log_dir, "models") |
||||||
|
ensure_dir(log_model_dir) |
||||||
|
|
||||||
|
# model / optimizer |
||||||
|
model = Model( |
||||||
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False |
||||||
|
) |
||||||
|
model = nn.DataParallel(model,device_ids=[i for i in range(world_size)]) |
||||||
|
model.cuda() |
||||||
|
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999)) |
||||||
|
# model = nn.DataParallel(model,device_ids=[0]) |
||||||
|
|
||||||
|
tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events")) |
||||||
|
|
||||||
|
# worklog |
||||||
|
logging.basicConfig(level=eval(args.log_level)) |
||||||
|
worklog = logging.getLogger("train_logger") |
||||||
|
worklog.propagate = False |
||||||
|
fileHandler = logging.FileHandler( |
||||||
|
os.path.join(args.log_dir, "worklog.txt"), mode="a", encoding="utf8" |
||||||
|
) |
||||||
|
formatter = logging.Formatter( |
||||||
|
fmt="%(asctime)s %(message)s", datefmt="%Y/%m/%d %H:%M:%S" |
||||||
|
) |
||||||
|
fileHandler.setFormatter(formatter) |
||||||
|
consoleHandler = logging.StreamHandler(sys.stdout) |
||||||
|
formatter = logging.Formatter( |
||||||
|
fmt="\x1b[32m%(asctime)s\x1b[0m %(message)s", datefmt="%Y/%m/%d %H:%M:%S" |
||||||
|
) |
||||||
|
consoleHandler.setFormatter(formatter) |
||||||
|
worklog.handlers = [fileHandler, consoleHandler] |
||||||
|
|
||||||
|
# params stat |
||||||
|
worklog.info(f"Use {world_size} GPU(s)") |
||||||
|
worklog.info("Params: %s" % sum([p.numel() for p in model.parameters()])) |
||||||
|
|
||||||
|
# load pretrained model if exist |
||||||
|
chk_path = os.path.join(log_model_dir, "latest.pth") |
||||||
|
if args.loadmodel is not None: |
||||||
|
chk_path = args.loadmodel |
||||||
|
elif not os.path.exists(chk_path): |
||||||
|
chk_path = None |
||||||
|
|
||||||
|
if chk_path is not None: |
||||||
|
# if rank == 0: |
||||||
|
worklog.info(f"loading model: {chk_path}") |
||||||
|
state_dict = torch.load(chk_path) |
||||||
|
model.load_state_dict(state_dict['state_dict']) |
||||||
|
optimizer.load_state_dict(state_dict['optim_state_dict']) |
||||||
|
resume_epoch_idx = state_dict["epoch"] |
||||||
|
resume_iters = state_dict["iters"] |
||||||
|
start_epoch_idx = resume_epoch_idx + 1 |
||||||
|
start_iters = resume_iters |
||||||
|
else: |
||||||
|
start_epoch_idx = 1 |
||||||
|
start_iters = 0 |
||||||
|
|
||||||
|
# datasets |
||||||
|
dataset = CREStereoDataset(args.training_data_path) |
||||||
|
# if rank == 0: |
||||||
|
worklog.info(f"Dataset size: {len(dataset)}") |
||||||
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True, |
||||||
|
num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) |
||||||
|
|
||||||
|
# counter |
||||||
|
cur_iters = start_iters |
||||||
|
total_iters = args.minibatch_per_epoch * args.n_total_epoch |
||||||
|
t0 = time.perf_counter() |
||||||
|
for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1): |
||||||
|
|
||||||
|
# adjust learning rate |
||||||
|
epoch_total_train_loss = 0 |
||||||
|
adjust_learning_rate(optimizer, epoch_idx) |
||||||
|
model.train() |
||||||
|
|
||||||
|
t1 = time.perf_counter() |
||||||
|
# batch_idx = 0 |
||||||
|
|
||||||
|
# for mini_batch_data in dataloader: |
||||||
|
for batch_idx, mini_batch_data in enumerate(dataloader): |
||||||
|
|
||||||
|
if batch_idx % args.minibatch_per_epoch == 0 and batch_idx != 0: |
||||||
|
break |
||||||
|
# batch_idx += 1 |
||||||
|
cur_iters += 1 |
||||||
|
|
||||||
|
# parse data |
||||||
|
left, right, gt_disp, valid_mask = ( |
||||||
|
mini_batch_data["left"], |
||||||
|
mini_batch_data["right"], |
||||||
|
mini_batch_data["disparity"].cuda(), |
||||||
|
mini_batch_data["mask"].cuda(), |
||||||
|
) |
||||||
|
|
||||||
|
t2 = time.perf_counter() |
||||||
|
optimizer.zero_grad() |
||||||
|
|
||||||
|
# pre-process |
||||||
|
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] |
||||||
|
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512] |
||||||
|
|
||||||
|
# forward |
||||||
|
flow_predictions = model(left.cuda(), right.cuda()) |
||||||
|
|
||||||
|
# loss & backword |
||||||
|
loss = sequence_loss( |
||||||
|
flow_predictions, gt_flow, valid_mask, gamma=0.8 |
||||||
|
) |
||||||
|
|
||||||
|
# loss stats |
||||||
|
loss_item = loss.data.item() |
||||||
|
epoch_total_train_loss += loss_item |
||||||
|
loss.backward() |
||||||
|
optimizer.step() |
||||||
|
t3 = time.perf_counter() |
||||||
|
|
||||||
|
if cur_iters % 10 == 0: |
||||||
|
tdata = t2 - t1 |
||||||
|
time_train_passed = t3 - t0 |
||||||
|
time_iter_passed = t3 - t1 |
||||||
|
step_passed = cur_iters - start_iters |
||||||
|
eta = ( |
||||||
|
(total_iters - cur_iters) |
||||||
|
/ max(step_passed, 1e-7) |
||||||
|
* time_train_passed |
||||||
|
) |
||||||
|
|
||||||
|
meta_info = list() |
||||||
|
meta_info.append("{:.2g} b/s".format(1.0 / time_iter_passed)) |
||||||
|
meta_info.append("passed:{}".format(format_time(time_train_passed))) |
||||||
|
meta_info.append("eta:{}".format(format_time(eta))) |
||||||
|
meta_info.append( |
||||||
|
"data_time:{:.2g}".format(tdata / time_iter_passed) |
||||||
|
) |
||||||
|
|
||||||
|
meta_info.append( |
||||||
|
"lr:{:.5g}".format(optimizer.param_groups[0]["lr"]) |
||||||
|
) |
||||||
|
meta_info.append( |
||||||
|
"[{}/{}:{}/{}]".format( |
||||||
|
epoch_idx, |
||||||
|
args.n_total_epoch, |
||||||
|
batch_idx, |
||||||
|
args.minibatch_per_epoch, |
||||||
|
) |
||||||
|
) |
||||||
|
loss_info = [" ==> {}:{:.4g}".format("loss", loss_item)] |
||||||
|
# exp_name = ['\n' + os.path.basename(os.getcwd())] |
||||||
|
|
||||||
|
info = [",".join(meta_info)] + loss_info |
||||||
|
worklog.info("".join(info)) |
||||||
|
|
||||||
|
# minibatch loss |
||||||
|
tb_log.add_scalar("train/loss_batch", loss_item, cur_iters) |
||||||
|
tb_log.add_scalar( |
||||||
|
"train/lr", optimizer.param_groups[0]["lr"], cur_iters |
||||||
|
) |
||||||
|
tb_log.flush() |
||||||
|
|
||||||
|
t1 = time.perf_counter() |
||||||
|
|
||||||
|
tb_log.add_scalar( |
||||||
|
"train/loss", |
||||||
|
epoch_total_train_loss / args.minibatch_per_epoch, |
||||||
|
epoch_idx, |
||||||
|
) |
||||||
|
tb_log.flush() |
||||||
|
|
||||||
|
# save model params |
||||||
|
ckp_data = { |
||||||
|
"epoch": epoch_idx, |
||||||
|
"iters": cur_iters, |
||||||
|
"batch_size": args.batch_size, |
||||||
|
"epoch_size": args.minibatch_per_epoch, |
||||||
|
"train_loss": epoch_total_train_loss / args.minibatch_per_epoch, |
||||||
|
"state_dict": model.state_dict(), |
||||||
|
"optim_state_dict": optimizer.state_dict(), |
||||||
|
} |
||||||
|
torch.save(ckp_data, os.path.join(log_model_dir, "latest.pth")) |
||||||
|
if epoch_idx % args.model_save_freq_epoch == 0: |
||||||
|
save_path = os.path.join(log_model_dir, "epoch-%d.pth" % epoch_idx) |
||||||
|
worklog.info(f"Model params saved: {save_path}") |
||||||
|
torch.save(ckp_data, save_path) |
||||||
|
|
||||||
|
worklog.info("Training is done, exit.") |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
# train configuration |
||||||
|
args = parse_yaml("cfgs/train.yaml") |
||||||
|
main(args) |
Loading…
Reference in new issue