import numpy as np
import torch
import random
import logging
import datetime
from pathlib import Path
import argparse
import subprocess
import socket
import sys
import os
import gc
import json
import matplotlib.pyplot as plt
import time
from collections import OrderedDict


class StopWatch(object):
    def __init__(self):
        self.timings = OrderedDict()
        self.starts = {}

    def start(self, name):
        self.starts[name] = time.time()

    def stop(self, name):
        if name not in self.timings:
            self.timings[name] = []
        self.timings[name].append(time.time() - self.starts[name])

    def get(self, name=None, reduce=np.sum):
        if name is not None:
            return reduce(self.timings[name])
        else:
            ret = {}
            for k in self.timings:
                ret[k] = reduce(self.timings[k])
            return ret

    def __repr__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])

    def __str__(self):
        return self.__repr__()


class ETA(object):
    def __init__(self, length):
        self.length = length
        self.start_time = time.time()
        self.current_idx = 0
        self.current_time = time.time()

    def update(self, idx):
        self.current_idx = idx
        self.current_time = time.time()

    def get_elapsed_time(self):
        return self.current_time - self.start_time

    def get_item_time(self):
        return self.get_elapsed_time() / (self.current_idx + 1)

    def get_remaining_time(self):
        # current_idx is the index of the last finished item, so
        # (length - current_idx - 1) items remain
        return self.get_item_time() * (self.length - self.current_idx - 1)

    def format_time(self, seconds):
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours = int(hours)
        minutes = int(minutes)
        return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'

    def get_elapsed_time_str(self):
        return self.format_time(self.get_elapsed_time())

    def get_remaining_time_str(self):
        return self.format_time(self.get_remaining_time())
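
# A minimal usage sketch for the two timing helpers above. The function name
# `_timing_demo` is hypothetical and only illustrates the intended call
# pattern; it is not called anywhere in this module.
def _timing_demo(n_items=4):
    sw = StopWatch()
    eta = ETA(length=n_items)
    sw.start('total')
    for idx in range(n_items):
        sw.start('work')
        time.sleep(0.01)  # stand-in for real work
        sw.stop('work')
        eta.update(idx)
        print(f'{idx+1}/{n_items} | elapsed {eta.get_elapsed_time_str()} '
              f'| remaining {eta.get_remaining_time_str()}')
    sw.stop('total')
    print(sw)  # e.g. "work: 0.04[s], total: 0.04[s]"
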
class Worker(object):
    def __init__(self, out_root, experiment_name, epochs=10, seed=42,
                 train_batch_size=8, test_batch_size=16, num_workers=16,
                 save_frequency=1, train_device='cuda:0', test_device='cuda:0',
                 max_train_iter=-1):
        self.out_root = Path(out_root)
        self.experiment_name = experiment_name
        self.epochs = epochs
        self.seed = seed
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.save_frequency = save_frequency
        self.train_device = train_device
        self.test_device = test_device
        self.max_train_iter = max_train_iter
        self.errs_list = []

        self.setup_experiment()

    def setup_experiment(self):
        self.exp_out_root = self.out_root / self.experiment_name
        self.exp_out_root.mkdir(parents=True, exist_ok=True)

        if logging.root:
            del logging.root.handlers[:]
        logging.basicConfig(
            level=logging.INFO,
            handlers=[
                logging.FileHandler(str(self.exp_out_root / 'train.log')),
                logging.StreamHandler()
            ],
            format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
        )

        logging.info('=' * 80)
        logging.info(f'Start of experiment: {self.experiment_name}')
        logging.info(socket.gethostname())
        self.log_datetime()
        logging.info('=' * 80)

        self.metric_path = self.exp_out_root / 'metrics.json'
        if self.metric_path.exists():
            with open(str(self.metric_path), 'r') as fp:
                self.metric_data = json.load(fp)
        else:
            self.metric_data = {}

        self.init_seed()

    def metric_add_train(self, epoch, key, val):
        epoch = str(epoch)
        key = str(key)
        if epoch not in self.metric_data:
            self.metric_data[epoch] = {}
        if 'train' not in self.metric_data[epoch]:
            self.metric_data[epoch]['train'] = {}
        self.metric_data[epoch]['train'][key] = val

    def metric_add_test(self, epoch, set_idx, key, val):
        epoch = str(epoch)
        set_idx = str(set_idx)
        key = str(key)
        if epoch not in self.metric_data:
            self.metric_data[epoch] = {}
        if 'test' not in self.metric_data[epoch]:
            self.metric_data[epoch]['test'] = {}
        if set_idx not in self.metric_data[epoch]['test']:
            self.metric_data[epoch]['test'][set_idx] = {}
        self.metric_data[epoch]['test'][set_idx][key] = val

    def metric_save(self):
        with open(str(self.metric_path), 'w') as fp:
            json.dump(self.metric_data, fp, indent=2)

    def init_seed(self, seed=None):
        if seed is not None:
            self.seed = seed
        logging.info(f'Set seed to {self.seed}')
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

    def log_datetime(self):
        logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    def mem_report(self):
        # list all tensors that are still alive; useful to hunt memory leaks
        for obj in gc.get_objects():
            if torch.is_tensor(obj):
                print(type(obj), obj.shape)

    def get_net_path(self, epoch, root=None):
        if root is None:
            root = self.exp_out_root
        return root / f'net_{epoch:04d}.params'

    def get_do_parser_cmds(self):
        return ['retrain', 'resume', 'retest', 'test_init']

    def get_do_parser(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--cmd', type=str, default='resume',
                            choices=self.get_do_parser_cmds())
        parser.add_argument('--epoch', type=int, default=-1)
        return parser

    def do_cmd(self, args, net, optimizer, scheduler=None):
        if args.cmd == 'retrain':
            self.train(net, optimizer, resume=False, scheduler=scheduler)
        elif args.cmd == 'resume':
            self.train(net, optimizer, resume=True, scheduler=scheduler)
        elif args.cmd == 'retest':
            self.retest(net, epoch=args.epoch)
        elif args.cmd == 'test_init':
            test_sets = self.get_test_sets()
            self.test(-1, net, test_sets)
        else:
            raise Exception('invalid cmd')

    def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
        parser = self.get_do_parser()
        args, _ = parser.parse_known_args()

        if load_net_optimizer is not None and args.cmd not in ['schedule']:
            net, optimizer = load_net_optimizer()

        self.do_cmd(args, net, optimizer, scheduler=scheduler)

    def retest(self, net, epoch=-1):
        if epoch < 0:
            epochs = range(self.epochs)
        else:
            epochs = [epoch]

        test_sets = self.get_test_sets()

        for epoch in epochs:
            net_path = self.get_net_path(epoch)
            if net_path.exists():
                # map to the test device so checkpoints saved on GPU also
                # load on machines without one
                state_dict = torch.load(str(net_path), map_location=self.test_device)
                net.load_state_dict(state_dict)
                self.test(epoch, net, test_sets)

    def format_err_str(self, errs, div=1):
        err = sum(errs)
        if len(errs) > 1:
            err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
        else:
            err_str = f'{err/div:0.4f}'
        return err_str

    def write_err_img(self):
        err_img_path = self.exp_out_root / 'errs.png'
        fig = plt.figure(figsize=(16, 16))
        lines = []
        for idx, errs in enumerate(self.errs_list):
            line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
            lines.append(line)
        plt.tight_layout()
        plt.legend(handles=lines)
        plt.savefig(str(err_img_path))
        plt.close(fig)

    def callback_train_new_epoch(self, epoch, net, optimizer):
        pass
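
    # train() below periodically writes a checkpoint `state.dict` (and a
    # `state_set<NAME>_best.dict` per test set) holding the keys 'epoch',
    # 'min_err', 'state_dict', 'optimizer', 'cpu_rng_state' and
    # 'gpu_rng_state'. A hedged sketch of restoring a network outside of
    # this class (the paths are illustrative):
    #
    #   state = torch.load('out/experiment/state.dict', map_location='cpu')
    #   net.load_state_dict(state['state_dict'])
    #   optimizer.load_state_dict(state['optimizer'])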
    def train(self, net, optimizer, resume=False, scheduler=None):
        logging.info('=' * 80)
        logging.info('Start training')
        self.log_datetime()
        logging.info('=' * 80)

        train_set = self.get_train_set()
        test_sets = self.get_test_sets()

        net = net.to(self.train_device)

        epoch = 0
        min_err = {ts.name: 1e9 for ts in test_sets}

        state_path = self.exp_out_root / 'state.dict'
        if resume and state_path.exists():
            logging.info('=' * 80)
            logging.info(f'Loading state from {state_path}')
            logging.info('=' * 80)
            state = torch.load(str(state_path))
            epoch = state['epoch'] + 1
            if 'min_err' in state:
                min_err = state['min_err']

            curr_state = net.state_dict()
            curr_state.update(state['state_dict'])
            net.load_state_dict(curr_state)

            try:
                optimizer.load_state_dict(state['optimizer'])
            except Exception:
                logging.warning('cannot load optimizer from state_dict')

            if 'cpu_rng_state' in state:
                torch.set_rng_state(state['cpu_rng_state'])
            if state.get('gpu_rng_state') is not None and torch.cuda.is_available():
                torch.cuda.set_rng_state(state['gpu_rng_state'])

        for epoch in range(epoch, self.epochs):
            self.callback_train_new_epoch(epoch, net, optimizer)

            # train epoch
            self.train_epoch(epoch, net, optimizer, train_set)

            # test epoch
            errs = self.test(epoch, net, test_sets)

            if (epoch + 1) % self.save_frequency == 0:
                net = net.to(self.train_device)

                # store state
                state_dict = {
                    'epoch': epoch,
                    'min_err': min_err,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cpu_rng_state': torch.get_rng_state(),
                    'gpu_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
                }
                state_path = self.exp_out_root / 'state.dict'
                logging.info(f'save state to {state_path}')
                torch.save(state_dict, str(state_path))

                for test_set_name in errs:
                    err = sum(errs[test_set_name])
                    if err < min_err[test_set_name]:
                        min_err[test_set_name] = err
                        state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                        logging.info(f'save state to {state_path}')
                        torch.save(state_dict, str(state_path))

                # store network
                net_path = self.get_net_path(epoch)
                logging.info(f'save network to {net_path}')
                torch.save(net.state_dict(), str(net_path))

            if scheduler is not None:
                scheduler.step()

        logging.info('=' * 80)
        logging.info('Finished training')
        self.log_datetime()
        logging.info('=' * 80)

    def get_train_set(self):
        # returns train_set
        raise NotImplementedError()

    def get_test_sets(self):
        # returns test_sets
        raise NotImplementedError()

    def copy_data(self, data, device, requires_grad, train):
        raise NotImplementedError()

    def net_forward(self, net, train):
        raise NotImplementedError()

    def loss_forward(self, output, train):
        raise NotImplementedError()

    def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
        # err = False
        # for name, param in net.named_parameters():
        #     if not torch.isfinite(param.grad).all():
        #         print(name)
        #         err = True
        # if err:
        #     import ipdb; ipdb.set_trace()
        pass

    def callback_train_start(self, epoch):
        pass

    def callback_train_stop(self, epoch, loss):
        pass
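
    # The five NotImplementedError hooks above are the interface a concrete
    # worker has to provide. A hedged sketch of typical overrides (the
    # attribute name `self.batch` is hypothetical; any stash works, since
    # copy_data and net_forward/loss_forward only communicate via `self`):
    #
    #   def copy_data(self, data, device, requires_grad, train):
    #       x, y = data
    #       self.batch = (x.to(device).requires_grad_(requires_grad),
    #                     y.to(device))
    #
    #   def net_forward(self, net, train):
    #       return net(self.batch[0])
    #
    #   def loss_forward(self, output, train):
    #       return torch.nn.functional.mse_loss(output, self.batch[1])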
    def train_epoch(self, epoch, net, optimizer, dset):
        self.callback_train_start(epoch)
        stopwatch = StopWatch()

        logging.info('=' * 80)
        logging.info('Train epoch %d' % epoch)

        dset.current_epoch = epoch
        train_loader = torch.utils.data.DataLoader(
            dset, batch_size=self.train_batch_size, shuffle=True,
            num_workers=self.num_workers, drop_last=True, pin_memory=False)

        net = net.to(self.train_device)
        net.train()

        mean_loss = None

        n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
        bar = ETA(length=n_batches)

        stopwatch.start('total')
        stopwatch.start('data')
        for batch_idx, data in enumerate(train_loader):
            if self.max_train_iter > 0 and batch_idx >= self.max_train_iter:
                break
            self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
            stopwatch.stop('data')

            optimizer.zero_grad()

            stopwatch.start('forward')
            output = self.net_forward(net, train=True)
            if 'cuda' in self.train_device:
                torch.cuda.synchronize()
            stopwatch.stop('forward')

            stopwatch.start('loss')
            errs = self.loss_forward(output, train=True)
            if isinstance(errs, dict):
                masks = errs['masks']
                errs = errs['errs']
            else:
                masks = []
            if not isinstance(errs, list) and not isinstance(errs, tuple):
                errs = [errs]
            err = sum(errs)
            if 'cuda' in self.train_device:
                torch.cuda.synchronize()
            stopwatch.stop('loss')

            stopwatch.start('backward')
            err.backward()
            self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
            if 'cuda' in self.train_device:
                torch.cuda.synchronize()
            stopwatch.stop('backward')

            stopwatch.start('optimizer')
            optimizer.step()
            if 'cuda' in self.train_device:
                torch.cuda.synchronize()
            stopwatch.stop('optimizer')

            bar.update(batch_idx)
            if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
                err_str = self.format_err_str(errs)
                logging.info(f'train e{epoch}: {batch_idx+1}/{n_batches}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
                # self.write_err_img()

            if mean_loss is None:
                mean_loss = [0 for e in errs]
            for erridx, err in enumerate(errs):
                mean_loss[erridx] += err.item()

            stopwatch.start('data')
        stopwatch.stop('total')
        logging.info('timings: %s' % stopwatch)

        # average over the batches that were actually processed
        mean_loss = [l / n_batches for l in mean_loss]
        self.callback_train_stop(epoch, mean_loss)
        self.metric_add_train(epoch, 'loss', mean_loss)

        # save metrics
        self.metric_save()

        err_str = self.format_err_str(mean_loss)
        logging.info(f'avg train_loss={err_str}')
        return mean_loss

    def callback_test_start(self, epoch, set_idx):
        pass

    def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
        pass

    def callback_test_stop(self, epoch, set_idx, loss):
        pass

    def test(self, epoch, net, test_sets):
        errs = {}
        for test_set_idx, test_set in enumerate(test_sets):
            if (epoch + 1) % test_set.test_frequency == 0:
                logging.info('=' * 80)
                logging.info(f'testing set {test_set.name}')
                err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
                errs[test_set.name] = err
        return errs

    def test_epoch(self, epoch, set_idx, net, dset):
        logging.info('-' * 80)
        logging.info('Test epoch %d' % epoch)
        dset.current_epoch = epoch
        test_loader = torch.utils.data.DataLoader(
            dset, batch_size=self.test_batch_size, shuffle=False,
            num_workers=self.num_workers, drop_last=False, pin_memory=False)

        net = net.to(self.test_device)
        net.eval()

        with torch.no_grad():
            mean_loss = None

            self.callback_test_start(epoch, set_idx)

            bar = ETA(length=len(test_loader))
            stopwatch = StopWatch()
            stopwatch.start('total')
            stopwatch.start('data')
            for batch_idx, data in enumerate(test_loader):
                # if batch_idx == 10: break
                self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
                stopwatch.stop('data')

                stopwatch.start('forward')
                output = self.net_forward(net, train=False)
                if 'cuda' in self.test_device:
                    torch.cuda.synchronize()
                stopwatch.stop('forward')

                stopwatch.start('loss')
                errs = self.loss_forward(output, train=False)
                if isinstance(errs, dict):
                    masks = errs['masks']
                    errs = errs['errs']
                else:
                    masks = []
                if not isinstance(errs, list) and not isinstance(errs, tuple):
                    errs = [errs]

                bar.update(batch_idx)
                if batch_idx % 25 == 0:
                    err_str = self.format_err_str(errs)
                    logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')

                if mean_loss is None:
                    mean_loss = [0 for e in errs]
                for erridx, err in enumerate(errs):
                    mean_loss[erridx] += err.item()
                stopwatch.stop('loss')

                self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)

                stopwatch.start('data')
            stopwatch.stop('total')
            logging.info('timings: %s' % stopwatch)

            mean_loss = [l / len(test_loader) for l in mean_loss]
            self.callback_test_stop(epoch, set_idx, mean_loss)
            self.metric_add_test(epoch, set_idx, 'loss', mean_loss)

            # save metrics
            self.metric_save()

            err_str = self.format_err_str(mean_loss)
            logging.info(f'test epoch {epoch}: avg test_loss={err_str}')

        return mean_loss
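
# ----------------------------------------------------------------------------
# Hedged end-to-end usage sketch (not part of the original module): it wires a
# minimal Worker subclass to random data and a linear net on the CPU. All
# names below (TestSet, RandomDataset, ExampleWorker, the './out' root) are
# hypothetical and only illustrate the intended call pattern.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    from collections import namedtuple
    import torch.nn as nn

    # test() expects objects exposing .name, .dset and .test_frequency
    TestSet = namedtuple('TestSet', ['name', 'dset', 'test_frequency'])

    class RandomDataset(torch.utils.data.Dataset):
        # toy 4 -> 1 regression problem on random data
        def __init__(self, n=64):
            self.x = torch.randn(n, 4)
            self.y = torch.randn(n, 1)

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    class ExampleWorker(Worker):
        def get_train_set(self):
            return RandomDataset()

        def get_test_sets(self):
            return [TestSet(name='rand', dset=RandomDataset(), test_frequency=1)]

        def copy_data(self, data, device, requires_grad, train):
            # stash the batch; net_forward/loss_forward read it back
            self.x = data[0].to(device).requires_grad_(requires_grad)
            self.y = data[1].to(device)

        def net_forward(self, net, train):
            return net(self.x)

        def loss_forward(self, output, train):
            return torch.nn.functional.mse_loss(output, self.y)

    net = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
    worker = ExampleWorker(out_root='./out', experiment_name='toy', epochs=2,
                           train_device='cpu', test_device='cpu', num_workers=0)
    # honors --cmd {retrain,resume,retest,test_init} from the command line
    worker.do(net, optimizer)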