You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
539 lines
18 KiB
539 lines
18 KiB
import numpy as np
|
|
import torch
|
|
import random
|
|
import logging
|
|
import datetime
|
|
from pathlib import Path
|
|
import argparse
|
|
import subprocess
|
|
import socket
|
|
import sys
|
|
import os
|
|
import gc
|
|
import json
|
|
import matplotlib.pyplot as plt
|
|
import time
|
|
from collections import OrderedDict
|
|
import wandb
|
|
|
|
|
|
class StopWatch(object):
|
|
def __init__(self):
|
|
self.timings = OrderedDict()
|
|
self.starts = {}
|
|
|
|
def start(self, name):
|
|
self.starts[name] = time.time()
|
|
|
|
def stop(self, name):
|
|
if name not in self.timings:
|
|
self.timings[name] = []
|
|
self.timings[name].append(time.time() - self.starts[name])
|
|
|
|
def get(self, name=None, reduce=np.sum):
|
|
if name is not None:
|
|
return reduce(self.timings[name])
|
|
else:
|
|
ret = {}
|
|
for k in self.timings:
|
|
ret[k] = reduce(self.timings[k])
|
|
return ret
|
|
|
|
def __repr__(self):
|
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
|
|
|
def __str__(self):
|
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
|
|
|
|
|
class ETA(object):
|
|
def __init__(self, length):
|
|
self.length = length
|
|
self.start_time = time.time()
|
|
self.current_idx = 0
|
|
self.current_time = time.time()
|
|
|
|
def update(self, idx):
|
|
self.current_idx = idx
|
|
self.current_time = time.time()
|
|
|
|
def get_elapsed_time(self):
|
|
return self.current_time - self.start_time
|
|
|
|
def get_item_time(self):
|
|
return self.get_elapsed_time() / (self.current_idx + 1)
|
|
|
|
def get_remaining_time(self):
|
|
return self.get_item_time() * (self.length - self.current_idx + 1)
|
|
|
|
def format_time(self, seconds):
|
|
minutes, seconds = divmod(seconds, 60)
|
|
hours, minutes = divmod(minutes, 60)
|
|
hours = int(hours)
|
|
minutes = int(minutes)
|
|
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
|
|
|
def get_elapsed_time_str(self):
|
|
return self.format_time(self.get_elapsed_time())
|
|
|
|
def get_remaining_time_str(self):
|
|
return self.format_time(self.get_remaining_time())
|
|
|
|
|
|
class Worker(object):
|
|
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
|
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1, no_double_heads=True):
|
|
self.out_root = Path(out_root)
|
|
self.experiment_name = experiment_name
|
|
self.epochs = epochs
|
|
self.seed = seed
|
|
self.train_batch_size = train_batch_size
|
|
self.test_batch_size = test_batch_size
|
|
self.num_workers = num_workers
|
|
self.save_frequency = save_frequency
|
|
self.train_device = train_device
|
|
self.test_device = test_device
|
|
self.max_train_iter = max_train_iter
|
|
self.double_heads = no_double_heads
|
|
|
|
self.errs_list = []
|
|
|
|
self.setup_experiment()
|
|
|
|
def setup_experiment(self):
|
|
self.exp_out_root = self.out_root / self.experiment_name
|
|
self.exp_out_root.mkdir(parents=True, exist_ok=True)
|
|
|
|
if logging.root: del logging.root.handlers[:]
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
handlers=[
|
|
logging.FileHandler(str(self.exp_out_root / 'train.log')),
|
|
logging.StreamHandler()
|
|
],
|
|
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
|
|
)
|
|
|
|
logging.info('=' * 80)
|
|
logging.info(f'Start of experiment: {self.experiment_name}')
|
|
logging.info(socket.gethostname())
|
|
self.log_datetime()
|
|
logging.info('=' * 80)
|
|
|
|
self.metric_path = self.exp_out_root / 'metrics.json'
|
|
if self.metric_path.exists():
|
|
with open(str(self.metric_path), 'r') as fp:
|
|
self.metric_data = json.load(fp)
|
|
else:
|
|
self.metric_data = {}
|
|
|
|
self.init_seed()
|
|
|
|
def metric_add_train(self, epoch, key, val):
|
|
epoch = str(epoch)
|
|
key = str(key)
|
|
if epoch not in self.metric_data:
|
|
self.metric_data[epoch] = {}
|
|
if 'train' not in self.metric_data[epoch]:
|
|
self.metric_data[epoch]['train'] = {}
|
|
self.metric_data[epoch]['train'][key] = val
|
|
|
|
def metric_add_test(self, epoch, set_idx, key, val):
|
|
epoch = str(epoch)
|
|
set_idx = str(set_idx)
|
|
key = str(key)
|
|
if epoch not in self.metric_data:
|
|
self.metric_data[epoch] = {}
|
|
if 'test' not in self.metric_data[epoch]:
|
|
self.metric_data[epoch]['test'] = {}
|
|
if set_idx not in self.metric_data[epoch]['test']:
|
|
self.metric_data[epoch]['test'][set_idx] = {}
|
|
self.metric_data[epoch]['test'][set_idx][key] = val
|
|
|
|
def metric_save(self):
|
|
with open(str(self.metric_path), 'w') as fp:
|
|
json.dump(self.metric_data, fp, indent=2)
|
|
|
|
def init_seed(self, seed=None):
|
|
if seed is not None:
|
|
self.seed = seed
|
|
logging.info(f'Set seed to {self.seed}')
|
|
np.random.seed(self.seed)
|
|
random.seed(self.seed)
|
|
torch.manual_seed(self.seed)
|
|
torch.cuda.manual_seed(self.seed)
|
|
|
|
def log_datetime(self):
|
|
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
|
|
|
def mem_report(self):
|
|
for obj in gc.get_objects():
|
|
if torch.is_tensor(obj):
|
|
print(type(obj), obj.shape)
|
|
|
|
def get_net_path(self, epoch, root=None):
|
|
if root is None:
|
|
root = self.exp_out_root
|
|
return root / f'net_{epoch:04d}.params'
|
|
|
|
def get_do_parser_cmds(self):
|
|
return ['retrain', 'resume', 'retest', 'test_init']
|
|
|
|
def get_do_parser(self):
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
|
|
parser.add_argument('--epoch', type=int, default=-1)
|
|
return parser
|
|
|
|
def do_cmd(self, args, net, optimizer, scheduler=None):
|
|
if args.cmd == 'retrain':
|
|
self.train(net, optimizer, resume=False, scheduler=scheduler)
|
|
elif args.cmd == 'resume':
|
|
self.train(net, optimizer, resume=True, scheduler=scheduler)
|
|
elif args.cmd == 'retest':
|
|
self.retest(net, epoch=args.epoch)
|
|
elif args.cmd == 'test_init':
|
|
test_sets = self.get_test_sets()
|
|
self.test(-1, net, test_sets)
|
|
else:
|
|
raise Exception('invalid cmd')
|
|
|
|
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
|
|
parser = self.get_do_parser()
|
|
args, _ = parser.parse_known_args()
|
|
|
|
if load_net_optimizer is not None and args.cmd not in ['schedule']:
|
|
net, optimizer = load_net_optimizer()
|
|
|
|
self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
|
|
|
def retest(self, net, epoch=-1):
|
|
if epoch < 0:
|
|
epochs = range(self.epochs)
|
|
else:
|
|
epochs = [epoch]
|
|
|
|
test_sets = self.get_test_sets()
|
|
|
|
for epoch in epochs:
|
|
net_path = self.get_net_path(epoch)
|
|
if net_path.exists():
|
|
state_dict = torch.load(str(net_path))
|
|
net.load_state_dict(state_dict)
|
|
self.test(epoch, net, test_sets)
|
|
|
|
def format_err_str(self, errs, div=1):
|
|
err = sum(errs)
|
|
if len(errs) > 1:
|
|
err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
|
|
else:
|
|
err_str = f'{err / div:0.4f}'
|
|
return err_str
|
|
|
|
def write_err_img(self):
|
|
err_img_path = self.exp_out_root / 'errs.png'
|
|
fig = plt.figure(figsize=(16, 16))
|
|
lines = []
|
|
for idx, errs in enumerate(self.errs_list):
|
|
line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
|
|
lines.append(line)
|
|
plt.tight_layout()
|
|
plt.legend(handles=lines)
|
|
plt.savefig(str(err_img_path))
|
|
plt.close(fig)
|
|
|
|
def callback_train_new_epoch(self, epoch, net, optimizer):
|
|
pass
|
|
|
|
def train(self, net, optimizer, resume=False, scheduler=None):
|
|
logging.info('=' * 80)
|
|
logging.info('Start training')
|
|
self.log_datetime()
|
|
logging.info('=' * 80)
|
|
|
|
train_set = self.get_train_set()
|
|
test_sets = self.get_test_sets()
|
|
|
|
net = net.to(self.train_device)
|
|
|
|
epoch = 0
|
|
min_err = {ts.name: 1e9 for ts in test_sets}
|
|
|
|
state_path = self.exp_out_root / 'state.dict'
|
|
if resume and state_path.exists():
|
|
logging.info('=' * 80)
|
|
logging.info(f'Loading state from {state_path}')
|
|
logging.info('=' * 80)
|
|
state = torch.load(str(state_path))
|
|
epoch = state['epoch'] + 1
|
|
if 'min_err' in state:
|
|
min_err = state['min_err']
|
|
|
|
curr_state = net.state_dict()
|
|
curr_state.update(state['state_dict'])
|
|
net.load_state_dict(curr_state)
|
|
|
|
try:
|
|
optimizer.load_state_dict(state['optimizer'])
|
|
except:
|
|
logging.info('Warning: cannot load optimizer from state_dict')
|
|
pass
|
|
if 'cpu_rng_state' in state:
|
|
torch.set_rng_state(state['cpu_rng_state'])
|
|
if 'gpu_rng_state' in state:
|
|
torch.cuda.set_rng_state(state['gpu_rng_state'])
|
|
|
|
for epoch in range(epoch, self.epochs):
|
|
self.callback_train_new_epoch(epoch, net, optimizer)
|
|
|
|
# train epoch
|
|
self.train_epoch(epoch, net, optimizer, train_set)
|
|
|
|
# test epoch
|
|
errs = self.test(epoch, net, test_sets)
|
|
|
|
if (epoch + 1) % self.save_frequency == 0:
|
|
net = net.to(self.train_device)
|
|
|
|
# store state
|
|
state_dict = {
|
|
'epoch': epoch,
|
|
'min_err': min_err,
|
|
'state_dict': net.state_dict(),
|
|
'optimizer': optimizer.state_dict(),
|
|
'cpu_rng_state': torch.get_rng_state(),
|
|
'gpu_rng_state': torch.cuda.get_rng_state(),
|
|
}
|
|
logging.info(f'save state to {state_path}')
|
|
state_path = self.exp_out_root / 'state.dict'
|
|
torch.save(state_dict, str(state_path))
|
|
|
|
for test_set_name in errs:
|
|
err = sum(errs[test_set_name])
|
|
if err < min_err[test_set_name]:
|
|
min_err[test_set_name] = err
|
|
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
|
|
logging.info(f'save state to {state_path}')
|
|
torch.save(state_dict, str(state_path))
|
|
|
|
# store network
|
|
net_path = self.get_net_path(epoch)
|
|
logging.info(f'save network to {net_path}')
|
|
torch.save(net.state_dict(), str(net_path))
|
|
|
|
if scheduler is not None:
|
|
scheduler.step()
|
|
|
|
logging.info('=' * 80)
|
|
logging.info('Finished training')
|
|
self.log_datetime()
|
|
logging.info('=' * 80)
|
|
|
|
def get_train_set(self):
|
|
# returns train_set
|
|
raise NotImplementedError()
|
|
|
|
def get_test_sets(self):
|
|
# returns test_sets
|
|
raise NotImplementedError()
|
|
|
|
def copy_data(self, data, device, requires_grad, train):
|
|
raise NotImplementedError()
|
|
|
|
def net_forward(self, net, train):
|
|
raise NotImplementedError()
|
|
|
|
def loss_forward(self, output, train):
|
|
raise NotImplementedError()
|
|
|
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
|
# err = False
|
|
# for name, param in net.named_parameters():
|
|
# if not torch.isfinite(param.grad).all():
|
|
# print(name)
|
|
# err = True
|
|
# if err:
|
|
# import ipdb; ipdb.set_trace()
|
|
pass
|
|
|
|
def callback_train_start(self, epoch):
|
|
pass
|
|
|
|
def callback_train_stop(self, epoch, loss):
|
|
pass
|
|
|
|
def train_epoch(self, epoch, net, optimizer, dset):
|
|
self.callback_train_start(epoch)
|
|
stopwatch = StopWatch()
|
|
|
|
logging.info('=' * 80)
|
|
logging.info('Train epoch %d' % epoch)
|
|
|
|
dset.current_epoch = epoch
|
|
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
|
|
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
|
|
|
net = net.to(self.train_device)
|
|
wandb.watch(net)
|
|
net.train()
|
|
|
|
mean_loss = None
|
|
|
|
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
|
|
bar = ETA(length=n_batches)
|
|
|
|
stopwatch.start('total')
|
|
stopwatch.start('data')
|
|
for batch_idx, data in enumerate(train_loader):
|
|
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
|
|
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
|
|
stopwatch.stop('data')
|
|
|
|
optimizer.zero_grad()
|
|
|
|
stopwatch.start('forward')
|
|
output = self.net_forward(net, train=True)
|
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
stopwatch.stop('forward')
|
|
|
|
stopwatch.start('loss')
|
|
errs = self.loss_forward(output, train=True)
|
|
if isinstance(errs, dict):
|
|
wandb.log(errs)
|
|
masks = errs['masks']
|
|
errs = errs['errs']
|
|
else:
|
|
masks = []
|
|
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
|
errs = [errs]
|
|
err = sum(errs)
|
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
stopwatch.stop('loss')
|
|
|
|
stopwatch.start('backward')
|
|
err.backward()
|
|
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
|
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
stopwatch.stop('backward')
|
|
|
|
stopwatch.start('optimizer')
|
|
optimizer.step()
|
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
stopwatch.stop('optimizer')
|
|
|
|
bar.update(batch_idx)
|
|
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
|
err_str = self.format_err_str(errs)
|
|
logging.info(
|
|
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
|
# self.write_err_img()
|
|
|
|
if mean_loss is None:
|
|
mean_loss = [0 for e in errs]
|
|
for erridx, err in enumerate(errs):
|
|
mean_loss[erridx] += err.item()
|
|
|
|
stopwatch.start('data')
|
|
stopwatch.stop('total')
|
|
logging.info('timings: %s' % stopwatch)
|
|
|
|
mean_loss = [l / len(train_loader) for l in mean_loss]
|
|
self.callback_train_stop(epoch, mean_loss)
|
|
self.metric_add_train(epoch, 'loss', mean_loss)
|
|
|
|
# save metrics
|
|
self.metric_save()
|
|
|
|
err_str = self.format_err_str(mean_loss)
|
|
logging.info(f'avg train_loss={err_str}')
|
|
wandb.log({'mean_loss': mean_loss})
|
|
return mean_loss
|
|
|
|
def callback_test_start(self, epoch, set_idx):
|
|
pass
|
|
|
|
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
|
pass
|
|
|
|
def callback_test_stop(self, epoch, set_idx, loss):
|
|
pass
|
|
|
|
def test(self, epoch, net, test_sets):
|
|
errs = {}
|
|
for test_set_idx, test_set in enumerate(test_sets):
|
|
if (epoch + 1) % test_set.test_frequency == 0:
|
|
logging.info('=' * 80)
|
|
logging.info(f'testing set {test_set.name}')
|
|
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
|
|
errs[test_set.name] = err
|
|
return errs
|
|
|
|
def test_epoch(self, epoch, set_idx, net, dset):
|
|
logging.info('-' * 80)
|
|
logging.info('Test epoch %d' % epoch)
|
|
dset.current_epoch = epoch
|
|
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
|
|
num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
|
|
|
net = net.to(self.test_device)
|
|
net.eval()
|
|
|
|
with torch.no_grad():
|
|
mean_loss = None
|
|
|
|
self.callback_test_start(epoch, set_idx)
|
|
|
|
bar = ETA(length=len(test_loader))
|
|
stopwatch = StopWatch()
|
|
stopwatch.start('total')
|
|
stopwatch.start('data')
|
|
for batch_idx, data in enumerate(test_loader):
|
|
# if batch_idx == 10: break
|
|
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
|
|
stopwatch.stop('data')
|
|
|
|
stopwatch.start('forward')
|
|
output = self.net_forward(net, train=False)
|
|
if 'cuda' in self.test_device: torch.cuda.synchronize()
|
|
stopwatch.stop('forward')
|
|
|
|
stopwatch.start('loss')
|
|
errs = self.loss_forward(output, train=False)
|
|
if isinstance(errs, dict):
|
|
wandb.log(errs)
|
|
masks = errs['masks']
|
|
errs = errs['errs']
|
|
else:
|
|
masks = []
|
|
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
|
errs = [errs]
|
|
|
|
bar.update(batch_idx)
|
|
if batch_idx % 25 == 0:
|
|
err_str = self.format_err_str(errs)
|
|
logging.info(
|
|
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
|
|
|
if mean_loss is None:
|
|
mean_loss = [0 for e in errs]
|
|
for erridx, err in enumerate(errs):
|
|
mean_loss[erridx] += err.item()
|
|
stopwatch.stop('loss')
|
|
|
|
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
|
|
|
|
stopwatch.start('data')
|
|
stopwatch.stop('total')
|
|
logging.info('timings: %s' % stopwatch)
|
|
|
|
mean_loss = [l / len(test_loader) for l in mean_loss]
|
|
self.callback_test_stop(epoch, set_idx, mean_loss)
|
|
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
|
|
|
|
# save metrics
|
|
self.metric_save()
|
|
|
|
err_str = self.format_err_str(mean_loss)
|
|
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
|
wandb.log({'Test loss': mean_loss})
|
|
return mean_loss
|
|
|