Reformat $EVERYTHING

This commit is contained in:
2021-11-15 16:53:30 +01:00
parent 56f2aa7d5d
commit 43df77fb9b
32 changed files with 4171 additions and 3749 deletions
+45 -45
View File
@@ -2,65 +2,65 @@ import torch
import torch.utils.data
import numpy as np
class TestSet(object):
def __init__(self, name, dset, test_frequency=1):
self.name = name
self.dset = dset
self.test_frequency = test_frequency
def __init__(self, name, dset, test_frequency=1):
self.name = name
self.dset = dset
self.test_frequency = test_frequency
class TestSets(list):
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
class MultiDataset(torch.utils.data.Dataset):
def __init__(self, *datasets):
self.current_epoch = 0
def __init__(self, *datasets):
self.current_epoch = 0
self.datasets = []
self.cum_n_samples = [0]
self.datasets = []
self.cum_n_samples = [0]
for dataset in datasets:
self.append(dataset)
for dataset in datasets:
self.append(dataset)
def append(self, dataset):
self.datasets.append(dataset)
self.__update_cum_n_samples(dataset)
def append(self, dataset):
self.datasets.append(dataset)
self.__update_cum_n_samples(dataset)
def __update_cum_n_samples(self, dataset):
n_samples = self.cum_n_samples[-1] + len(dataset)
self.cum_n_samples.append(n_samples)
def __update_cum_n_samples(self, dataset):
n_samples = self.cum_n_samples[-1] + len(dataset)
self.cum_n_samples.append(n_samples)
def dataset_updated(self):
self.cum_n_samples = [0]
for dset in self.datasets:
self.__update_cum_n_samples(dset)
def dataset_updated(self):
self.cum_n_samples = [0]
for dset in self.datasets:
self.__update_cum_n_samples(dset)
def __len__(self):
return self.cum_n_samples[-1]
def __getitem__(self, idx):
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
sidx = idx - self.cum_n_samples[didx]
return self.datasets[didx][sidx]
def __len__(self):
return self.cum_n_samples[-1]
def __getitem__(self, idx):
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
sidx = idx - self.cum_n_samples[didx]
return self.datasets[didx][sidx]
class BaseDataset(torch.utils.data.Dataset):
def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0
self.train = train
self.fix_seed_per_epoch = fix_seed_per_epoch
def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0
self.train = train
self.fix_seed_per_epoch = fix_seed_per_epoch
def get_rng(self, idx):
rng = np.random.RandomState()
if self.train:
if self.fix_seed_per_epoch:
seed = 1 * len(self) + idx
else:
seed = (self.current_epoch + 1) * len(self) + idx
rng.seed(seed)
else:
rng.seed(idx)
return rng
def get_rng(self, idx):
rng = np.random.RandomState()
if self.train:
if self.fix_seed_per_epoch:
seed = 1 * len(self) + idx
else:
seed = (self.current_epoch + 1) * len(self) + idx
rng.seed(seed)
else:
rng.seed(idx)
return rng
+120 -115
View File
@@ -2,146 +2,151 @@ import torch
from . import ext_cpu
from . import ext_cuda
class NNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.nn_cuda(*args)
else:
out = ext_cpu.nn_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None
class NNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.nn_cuda(*args)
else:
out = ext_cpu.nn_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None
def nn(in0, in1):
return NNFunction.apply(in0, in1)
return NNFunction.apply(in0, in1)
class CrossCheckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.crosscheck_cuda(*args)
else:
out = ext_cpu.crosscheck_cpu(*args)
return out
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.crosscheck_cuda(*args)
else:
out = ext_cpu.crosscheck_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None
@staticmethod
def backward(ctx, grad_out):
return None, None
def crosscheck(in0, in1):
return CrossCheckFunction.apply(in0, in1)
return CrossCheckFunction.apply(in0, in1)
class ProjNNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size):
args = (xyz0, xyz1, K, patch_size)
if xyz0.is_cuda:
out = ext_cuda.proj_nn_cuda(*args)
else:
out = ext_cpu.proj_nn_cpu(*args)
return out
@staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size):
args = (xyz0, xyz1, K, patch_size)
if xyz0.is_cuda:
out = ext_cuda.proj_nn_cuda(*args)
else:
out = ext_cpu.proj_nn_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
def proj_nn(xyz0, xyz1, K, patch_size):
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
class XCorrVolFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1, n_disps, block_size):
args = (in0, in1, n_disps, block_size)
if in0.is_cuda:
out = ext_cuda.xcorrvol_cuda(*args)
else:
out = ext_cpu.xcorrvol_cpu(*args)
return out
@staticmethod
def forward(ctx, in0, in1, n_disps, block_size):
args = (in0, in1, n_disps, block_size)
if in0.is_cuda:
out = ext_cuda.xcorrvol_cuda(*args)
else:
out = ext_cpu.xcorrvol_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
def xcorrvol(in0, in1, n_disps, block_size):
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
class PhotometricLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, es, ta, block_size, type, eps):
args = (es, ta, block_size, type, eps)
ctx.save_for_backward(es, ta)
ctx.block_size = block_size
ctx.type = type
ctx.eps = eps
if es.is_cuda:
out = ext_cuda.photometric_loss_forward(*args)
else:
out = ext_cpu.photometric_loss_forward(*args)
return out
@staticmethod
def forward(ctx, es, ta, block_size, type, eps):
args = (es, ta, block_size, type, eps)
ctx.save_for_backward(es, ta)
ctx.block_size = block_size
ctx.type = type
ctx.eps = eps
if es.is_cuda:
out = ext_cuda.photometric_loss_forward(*args)
else:
out = ext_cpu.photometric_loss_forward(*args)
return out
@staticmethod
def backward(ctx, grad_out):
es, ta = ctx.saved_tensors
block_size = ctx.block_size
type = ctx.type
eps = ctx.eps
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
if grad_out.is_cuda:
grad_es = ext_cuda.photometric_loss_backward(*args)
else:
grad_es = ext_cpu.photometric_loss_backward(*args)
return grad_es, None, None, None, None
@staticmethod
def backward(ctx, grad_out):
es, ta = ctx.saved_tensors
block_size = ctx.block_size
type = ctx.type
eps = ctx.eps
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
if grad_out.is_cuda:
grad_es = ext_cuda.photometric_loss_backward(*args)
else:
grad_es = ext_cpu.photometric_loss_backward(*args)
return grad_es, None, None, None, None
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
type = type.lower()
if type == 'mse':
type = 0
elif type == 'sad':
type = 1
elif type == 'census_mse':
type = 2
elif type == 'census_sad':
type = 3
else:
raise Exception('invalid loss type')
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
type = type.lower()
if type == 'mse':
type = 0
elif type == 'sad':
type = 1
elif type == 'census_mse':
type = 2
elif type == 'census_sad':
type = 3
else:
raise Exception('invalid loss type')
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
type = type.lower()
p = block_size // 2
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
if type == 'mse':
ref = (es_uf - ta_uf)**2
elif type == 'sad':
ref = torch.abs(es_uf - ta_uf)
elif type == 'census_mse' or type == 'census_sad':
des = es_uf - es.unsqueeze(2)
dta = ta_uf - ta.unsqueeze(2)
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
diff = h_des - h_dta
if type == 'census_mse':
ref = diff * diff
elif type == 'census_sad':
ref = torch.abs(diff)
else:
raise Exception('invalid loss type')
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
return ref
type = type.lower()
p = block_size // 2
es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
if type == 'mse':
ref = (es_uf - ta_uf) ** 2
elif type == 'sad':
ref = torch.abs(es_uf - ta_uf)
elif type == 'census_mse' or type == 'census_sad':
des = es_uf - es.unsqueeze(2)
dta = ta_uf - ta.unsqueeze(2)
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
diff = h_des - h_dta
if type == 'census_mse':
ref = diff * diff
elif type == 'census_sad':
ref = torch.abs(diff)
else:
raise Exception('invalid loss type')
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
return ref
+19 -17
View File
@@ -4,24 +4,26 @@ import numpy as np
from .functions import *
class CoordConv2d(torch.nn.Module):
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__()
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__()
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
stride=stride)
self.uv = None
self.uv = None
def forward(self, x):
if self.uv is None:
height, width = x.shape[2], x.shape[3]
u, v = np.meshgrid(range(width), range(height))
u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy( uv.astype(np.float32) )
self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1)
y = self.conv(xuv)
return y
def forward(self, x):
if self.uv is None:
height, width = x.shape[2], x.shape[3]
u, v = np.meshgrid(range(width), range(height))
u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy(uv.astype(np.float32))
self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1)
y = self.conv(xuv)
return y
+8 -7
View File
@@ -11,11 +11,12 @@ nvcc_args = [
]
setup(
name='ext',
ext_modules=[
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
],
cmdclass={'build_ext': BuildExtension},
include_dirs=include_dirs
name='ext',
ext_modules=[
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
],
cmdclass={'build_ext': BuildExtension},
include_dirs=include_dirs
)
+434 -430
View File
@@ -17,512 +17,516 @@ from collections import OrderedDict
class StopWatch(object):
def __init__(self):
self.timings = OrderedDict()
self.starts = {}
def __init__(self):
self.timings = OrderedDict()
self.starts = {}
def start(self, name):
self.starts[name] = time.time()
def start(self, name):
self.starts[name] = time.time()
def stop(self, name):
if name not in self.timings:
self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name])
def stop(self, name):
if name not in self.timings:
self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name])
def get(self, name=None, reduce=np.sum):
if name is not None:
return reduce(self.timings[name])
else:
ret = {}
for k in self.timings:
ret[k] = reduce(self.timings[k])
return ret
def get(self, name=None, reduce=np.sum):
if name is not None:
return reduce(self.timings[name])
else:
ret = {}
for k in self.timings:
ret[k] = reduce(self.timings[k])
return ret
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object):
def __init__(self, length):
self.length = length
self.start_time = time.time()
self.current_idx = 0
self.current_time = time.time()
def __init__(self, length):
self.length = length
self.start_time = time.time()
self.current_idx = 0
self.current_time = time.time()
def update(self, idx):
self.current_idx = idx
self.current_time = time.time()
def update(self, idx):
self.current_idx = idx
self.current_time = time.time()
def get_elapsed_time(self):
return self.current_time - self.start_time
def get_elapsed_time(self):
return self.current_time - self.start_time
def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1)
def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1)
def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1)
def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1)
def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
hours = int(hours)
minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
hours = int(hours)
minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time())
def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time())
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
class Worker(object):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
self.out_root = Path(out_root)
self.experiment_name = experiment_name
self.epochs = epochs
self.seed = seed
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.num_workers = num_workers
self.save_frequency = save_frequency
self.train_device = train_device
self.test_device = test_device
self.max_train_iter = max_train_iter
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
self.out_root = Path(out_root)
self.experiment_name = experiment_name
self.epochs = epochs
self.seed = seed
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.num_workers = num_workers
self.save_frequency = save_frequency
self.train_device = train_device
self.test_device = test_device
self.max_train_iter = max_train_iter
self.errs_list=[]
self.errs_list = []
self.setup_experiment()
self.setup_experiment()
def setup_experiment(self):
self.exp_out_root = self.out_root / self.experiment_name
self.exp_out_root.mkdir(parents=True, exist_ok=True)
def setup_experiment(self):
self.exp_out_root = self.out_root / self.experiment_name
self.exp_out_root.mkdir(parents=True, exist_ok=True)
if logging.root: del logging.root.handlers[:]
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler( str(self.exp_out_root / 'train.log') ),
logging.StreamHandler()
],
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
)
if logging.root: del logging.root.handlers[:]
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler(str(self.exp_out_root / 'train.log')),
logging.StreamHandler()
],
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
)
logging.info('='*80)
logging.info(f'Start of experiment: {self.experiment_name}')
logging.info(socket.gethostname())
self.log_datetime()
logging.info('='*80)
logging.info('=' * 80)
logging.info(f'Start of experiment: {self.experiment_name}')
logging.info(socket.gethostname())
self.log_datetime()
logging.info('=' * 80)
self.metric_path = self.exp_out_root / 'metrics.json'
if self.metric_path.exists():
with open(str(self.metric_path), 'r') as fp:
self.metric_data = json.load(fp)
else:
self.metric_data = {}
self.metric_path = self.exp_out_root / 'metrics.json'
if self.metric_path.exists():
with open(str(self.metric_path), 'r') as fp:
self.metric_data = json.load(fp)
else:
self.metric_data = {}
self.init_seed()
self.init_seed()
def metric_add_train(self, epoch, key, val):
epoch = str(epoch)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'train' not in self.metric_data[epoch]:
self.metric_data[epoch]['train'] = {}
self.metric_data[epoch]['train'][key] = val
def metric_add_train(self, epoch, key, val):
epoch = str(epoch)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'train' not in self.metric_data[epoch]:
self.metric_data[epoch]['train'] = {}
self.metric_data[epoch]['train'][key] = val
def metric_add_test(self, epoch, set_idx, key, val):
epoch = str(epoch)
set_idx = str(set_idx)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'test' not in self.metric_data[epoch]:
self.metric_data[epoch]['test'] = {}
if set_idx not in self.metric_data[epoch]['test']:
self.metric_data[epoch]['test'][set_idx] = {}
self.metric_data[epoch]['test'][set_idx][key] = val
def metric_add_test(self, epoch, set_idx, key, val):
epoch = str(epoch)
set_idx = str(set_idx)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'test' not in self.metric_data[epoch]:
self.metric_data[epoch]['test'] = {}
if set_idx not in self.metric_data[epoch]['test']:
self.metric_data[epoch]['test'][set_idx] = {}
self.metric_data[epoch]['test'][set_idx][key] = val
def metric_save(self):
with open(str(self.metric_path), 'w') as fp:
json.dump(self.metric_data, fp, indent=2)
def metric_save(self):
with open(str(self.metric_path), 'w') as fp:
json.dump(self.metric_data, fp, indent=2)
def init_seed(self, seed=None):
if seed is not None:
self.seed = seed
logging.info(f'Set seed to {self.seed}')
np.random.seed(self.seed)
random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed(self.seed)
def init_seed(self, seed=None):
if seed is not None:
self.seed = seed
logging.info(f'Set seed to {self.seed}')
np.random.seed(self.seed)
random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed(self.seed)
def log_datetime(self):
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
def log_datetime(self):
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
def mem_report(self):
for obj in gc.get_objects():
if torch.is_tensor(obj):
print(type(obj), obj.shape)
def mem_report(self):
for obj in gc.get_objects():
if torch.is_tensor(obj):
print(type(obj), obj.shape)
def get_net_path(self, epoch, root=None):
if root is None:
root = self.exp_out_root
return root / f'net_{epoch:04d}.params'
def get_net_path(self, epoch, root=None):
if root is None:
root = self.exp_out_root
return root / f'net_{epoch:04d}.params'
def get_do_parser_cmds(self):
return ['retrain', 'resume', 'retest', 'test_init']
def get_do_parser_cmds(self):
return ['retrain', 'resume', 'retest', 'test_init']
def get_do_parser(self):
parser = argparse.ArgumentParser()
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
parser.add_argument('--epoch', type=int, default=-1)
return parser
def get_do_parser(self):
parser = argparse.ArgumentParser()
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
parser.add_argument('--epoch', type=int, default=-1)
return parser
def do_cmd(self, args, net, optimizer, scheduler=None):
if args.cmd == 'retrain':
self.train(net, optimizer, resume=False, scheduler=scheduler)
elif args.cmd == 'resume':
self.train(net, optimizer, resume=True, scheduler=scheduler)
elif args.cmd == 'retest':
self.retest(net, epoch=args.epoch)
elif args.cmd == 'test_init':
test_sets = self.get_test_sets()
self.test(-1, net, test_sets)
else:
raise Exception('invalid cmd')
def do_cmd(self, args, net, optimizer, scheduler=None):
if args.cmd == 'retrain':
self.train(net, optimizer, resume=False, scheduler=scheduler)
elif args.cmd == 'resume':
self.train(net, optimizer, resume=True, scheduler=scheduler)
elif args.cmd == 'retest':
self.retest(net, epoch=args.epoch)
elif args.cmd == 'test_init':
test_sets = self.get_test_sets()
self.test(-1, net, test_sets)
else:
raise Exception('invalid cmd')
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
parser = self.get_do_parser()
args, _ = parser.parse_known_args()
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
parser = self.get_do_parser()
args, _ = parser.parse_known_args()
if load_net_optimizer is not None and args.cmd not in ['schedule']:
net, optimizer = load_net_optimizer()
if load_net_optimizer is not None and args.cmd not in ['schedule']:
net, optimizer = load_net_optimizer()
self.do_cmd(args, net, optimizer, scheduler=scheduler)
self.do_cmd(args, net, optimizer, scheduler=scheduler)
def retest(self, net, epoch=-1):
if epoch < 0:
epochs = range(self.epochs)
else:
epochs = [epoch]
def retest(self, net, epoch=-1):
if epoch < 0:
epochs = range(self.epochs)
else:
epochs = [epoch]
test_sets = self.get_test_sets()
test_sets = self.get_test_sets()
for epoch in epochs:
net_path = self.get_net_path(epoch)
if net_path.exists():
state_dict = torch.load(str(net_path))
net.load_state_dict(state_dict)
self.test(epoch, net, test_sets)
for epoch in epochs:
net_path = self.get_net_path(epoch)
if net_path.exists():
state_dict = torch.load(str(net_path))
net.load_state_dict(state_dict)
self.test(epoch, net, test_sets)
def format_err_str(self, errs, div=1):
err = sum(errs)
if len(errs) > 1:
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
else:
err_str = f'{err/div:0.4f}'
return err_str
def format_err_str(self, errs, div=1):
err = sum(errs)
if len(errs) > 1:
err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
else:
err_str = f'{err / div:0.4f}'
return err_str
def write_err_img(self):
err_img_path = self.exp_out_root / 'errs.png'
fig = plt.figure(figsize=(16,16))
lines=[]
for idx,errs in enumerate(self.errs_list):
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
lines.append(line)
plt.tight_layout()
plt.legend(handles=lines)
plt.savefig(str(err_img_path))
plt.close(fig)
def write_err_img(self):
err_img_path = self.exp_out_root / 'errs.png'
fig = plt.figure(figsize=(16, 16))
lines = []
for idx, errs in enumerate(self.errs_list):
line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
lines.append(line)
plt.tight_layout()
plt.legend(handles=lines)
plt.savefig(str(err_img_path))
plt.close(fig)
def callback_train_new_epoch(self, epoch, net, optimizer):
pass
def train(self, net, optimizer, resume=False, scheduler=None):
logging.info('='*80)
logging.info('Start training')
self.log_datetime()
logging.info('='*80)
train_set = self.get_train_set()
test_sets = self.get_test_sets()
net = net.to(self.train_device)
epoch = 0
min_err = {ts.name: 1e9 for ts in test_sets}
state_path = self.exp_out_root / 'state.dict'
if resume and state_path.exists():
logging.info('='*80)
logging.info(f'Loading state from {state_path}')
logging.info('='*80)
state = torch.load(str(state_path))
epoch = state['epoch'] + 1
if 'min_err' in state:
min_err = state['min_err']
curr_state = net.state_dict()
curr_state.update(state['state_dict'])
net.load_state_dict(curr_state)
try:
optimizer.load_state_dict(state['optimizer'])
except:
logging.info('Warning: cannot load optimizer from state_dict')
def callback_train_new_epoch(self, epoch, net, optimizer):
pass
if 'cpu_rng_state' in state:
torch.set_rng_state(state['cpu_rng_state'])
if 'gpu_rng_state' in state:
torch.cuda.set_rng_state(state['gpu_rng_state'])
for epoch in range(epoch, self.epochs):
self.callback_train_new_epoch(epoch, net, optimizer)
def train(self, net, optimizer, resume=False, scheduler=None):
logging.info('=' * 80)
logging.info('Start training')
self.log_datetime()
logging.info('=' * 80)
# train epoch
self.train_epoch(epoch, net, optimizer, train_set)
train_set = self.get_train_set()
test_sets = self.get_test_sets()
# test epoch
errs = self.test(epoch, net, test_sets)
if (epoch + 1) % self.save_frequency == 0:
net = net.to(self.train_device)
# store state
state_dict = {
'epoch': epoch,
'min_err': min_err,
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'cpu_rng_state': torch.get_rng_state(),
'gpu_rng_state': torch.cuda.get_rng_state(),
}
logging.info(f'save state to {state_path}')
epoch = 0
min_err = {ts.name: 1e9 for ts in test_sets}
state_path = self.exp_out_root / 'state.dict'
torch.save(state_dict, str(state_path))
if resume and state_path.exists():
logging.info('=' * 80)
logging.info(f'Loading state from {state_path}')
logging.info('=' * 80)
state = torch.load(str(state_path))
epoch = state['epoch'] + 1
if 'min_err' in state:
min_err = state['min_err']
for test_set_name in errs:
err = sum(errs[test_set_name])
if err < min_err[test_set_name]:
min_err[test_set_name] = err
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
logging.info(f'save state to {state_path}')
torch.save(state_dict, str(state_path))
curr_state = net.state_dict()
curr_state.update(state['state_dict'])
net.load_state_dict(curr_state)
# store network
net_path = self.get_net_path(epoch)
logging.info(f'save network to {net_path}')
torch.save(net.state_dict(), str(net_path))
try:
optimizer.load_state_dict(state['optimizer'])
except:
logging.info('Warning: cannot load optimizer from state_dict')
pass
if 'cpu_rng_state' in state:
torch.set_rng_state(state['cpu_rng_state'])
if 'gpu_rng_state' in state:
torch.cuda.set_rng_state(state['gpu_rng_state'])
if scheduler is not None:
scheduler.step()
for epoch in range(epoch, self.epochs):
self.callback_train_new_epoch(epoch, net, optimizer)
logging.info('='*80)
logging.info('Finished training')
self.log_datetime()
logging.info('='*80)
# train epoch
self.train_epoch(epoch, net, optimizer, train_set)
def get_train_set(self):
# returns train_set
raise NotImplementedError()
# test epoch
errs = self.test(epoch, net, test_sets)
def get_test_sets(self):
# returns test_sets
raise NotImplementedError()
if (epoch + 1) % self.save_frequency == 0:
net = net.to(self.train_device)
def copy_data(self, data, device, requires_grad, train):
raise NotImplementedError()
# store state
state_dict = {
'epoch': epoch,
'min_err': min_err,
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'cpu_rng_state': torch.get_rng_state(),
'gpu_rng_state': torch.cuda.get_rng_state(),
}
logging.info(f'save state to {state_path}')
state_path = self.exp_out_root / 'state.dict'
torch.save(state_dict, str(state_path))
def net_forward(self, net, train):
raise NotImplementedError()
for test_set_name in errs:
err = sum(errs[test_set_name])
if err < min_err[test_set_name]:
min_err[test_set_name] = err
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
logging.info(f'save state to {state_path}')
torch.save(state_dict, str(state_path))
def loss_forward(self, output, train):
raise NotImplementedError()
# store network
net_path = self.get_net_path(epoch)
logging.info(f'save network to {net_path}')
torch.save(net.state_dict(), str(net_path))
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
# err = False
# for name, param in net.named_parameters():
# if not torch.isfinite(param.grad).all():
# print(name)
# err = True
# if err:
# import ipdb; ipdb.set_trace()
pass
if scheduler is not None:
scheduler.step()
def callback_train_start(self, epoch):
pass
logging.info('=' * 80)
logging.info('Finished training')
self.log_datetime()
logging.info('=' * 80)
def callback_train_stop(self, epoch, loss):
pass
def get_train_set(self):
# returns train_set
raise NotImplementedError()
def train_epoch(self, epoch, net, optimizer, dset):
self.callback_train_start(epoch)
stopwatch = StopWatch()
def get_test_sets(self):
# returns test_sets
raise NotImplementedError()
logging.info('='*80)
logging.info('Train epoch %d' % epoch)
def copy_data(self, data, device, requires_grad, train):
raise NotImplementedError()
dset.current_epoch = epoch
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
def net_forward(self, net, train):
raise NotImplementedError()
net = net.to(self.train_device)
net.train()
def loss_forward(self, output, train):
raise NotImplementedError()
mean_loss = None
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
# err = False
# for name, param in net.named_parameters():
# if not torch.isfinite(param.grad).all():
# print(name)
# err = True
# if err:
# import ipdb; ipdb.set_trace()
pass
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
bar = ETA(length=n_batches)
def callback_train_start(self, epoch):
pass
stopwatch.start('total')
stopwatch.start('data')
for batch_idx, data in enumerate(train_loader):
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
stopwatch.stop('data')
def callback_train_stop(self, epoch, loss):
pass
optimizer.zero_grad()
def train_epoch(self, epoch, net, optimizer, dset):
self.callback_train_start(epoch)
stopwatch = StopWatch()
stopwatch.start('forward')
output = self.net_forward(net, train=True)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('forward')
logging.info('=' * 80)
logging.info('Train epoch %d' % epoch)
stopwatch.start('loss')
errs = self.loss_forward(output, train=True)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
err = sum(errs)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('loss')
dset.current_epoch = epoch
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
num_workers=self.num_workers, drop_last=True, pin_memory=False)
stopwatch.start('backward')
err.backward()
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('backward')
net = net.to(self.train_device)
net.train()
stopwatch.start('optimizer')
optimizer.step()
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('optimizer')
mean_loss = None
bar.update(batch_idx)
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
err_str = self.format_err_str(errs)
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
#self.write_err_img()
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
mean_loss = [l / len(train_loader) for l in mean_loss]
self.callback_train_stop(epoch, mean_loss)
self.metric_add_train(epoch, 'loss', mean_loss)
# save metrics
self.metric_save()
err_str = self.format_err_str(mean_loss)
logging.info(f'avg train_loss={err_str}')
return mean_loss
def callback_test_start(self, epoch, set_idx):
pass
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
pass
def callback_test_stop(self, epoch, set_idx, loss):
pass
def test(self, epoch, net, test_sets):
errs = {}
for test_set_idx, test_set in enumerate(test_sets):
if (epoch + 1) % test_set.test_frequency == 0:
logging.info('='*80)
logging.info(f'testing set {test_set.name}')
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
errs[test_set.name] = err
return errs
def test_epoch(self, epoch, set_idx, net, dset):
logging.info('-'*80)
logging.info('Test epoch %d' % epoch)
dset.current_epoch = epoch
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
net = net.to(self.test_device)
net.eval()
with torch.no_grad():
mean_loss = None
self.callback_test_start(epoch, set_idx)
bar = ETA(length=len(test_loader))
stopwatch = StopWatch()
stopwatch.start('total')
stopwatch.start('data')
for batch_idx, data in enumerate(test_loader):
# if batch_idx == 10: break
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
stopwatch.stop('data')
stopwatch.start('forward')
output = self.net_forward(net, train=False)
if 'cuda' in self.test_device: torch.cuda.synchronize()
stopwatch.stop('forward')
stopwatch.start('loss')
errs = self.loss_forward(output, train=False)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
bar.update(batch_idx)
if batch_idx % 25 == 0:
err_str = self.format_err_str(errs)
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.stop('loss')
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
bar = ETA(length=n_batches)
stopwatch.start('total')
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
for batch_idx, data in enumerate(train_loader):
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
stopwatch.stop('data')
mean_loss = [l / len(test_loader) for l in mean_loss]
self.callback_test_stop(epoch, set_idx, mean_loss)
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
optimizer.zero_grad()
# save metrics
self.metric_save()
stopwatch.start('forward')
output = self.net_forward(net, train=True)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('forward')
err_str = self.format_err_str(mean_loss)
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
return mean_loss
stopwatch.start('loss')
errs = self.loss_forward(output, train=True)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
err = sum(errs)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('loss')
stopwatch.start('backward')
err.backward()
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('backward')
stopwatch.start('optimizer')
optimizer.step()
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('optimizer')
bar.update(batch_idx)
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
err_str = self.format_err_str(errs)
logging.info(
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
# self.write_err_img()
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
mean_loss = [l / len(train_loader) for l in mean_loss]
self.callback_train_stop(epoch, mean_loss)
self.metric_add_train(epoch, 'loss', mean_loss)
# save metrics
self.metric_save()
err_str = self.format_err_str(mean_loss)
logging.info(f'avg train_loss={err_str}')
return mean_loss
def callback_test_start(self, epoch, set_idx):
pass
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
pass
def callback_test_stop(self, epoch, set_idx, loss):
pass
def test(self, epoch, net, test_sets):
errs = {}
for test_set_idx, test_set in enumerate(test_sets):
if (epoch + 1) % test_set.test_frequency == 0:
logging.info('=' * 80)
logging.info(f'testing set {test_set.name}')
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
errs[test_set.name] = err
return errs
def test_epoch(self, epoch, set_idx, net, dset):
logging.info('-' * 80)
logging.info('Test epoch %d' % epoch)
dset.current_epoch = epoch
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
num_workers=self.num_workers, drop_last=False, pin_memory=False)
net = net.to(self.test_device)
net.eval()
with torch.no_grad():
mean_loss = None
self.callback_test_start(epoch, set_idx)
bar = ETA(length=len(test_loader))
stopwatch = StopWatch()
stopwatch.start('total')
stopwatch.start('data')
for batch_idx, data in enumerate(test_loader):
# if batch_idx == 10: break
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
stopwatch.stop('data')
stopwatch.start('forward')
output = self.net_forward(net, train=False)
if 'cuda' in self.test_device: torch.cuda.synchronize()
stopwatch.stop('forward')
stopwatch.start('loss')
errs = self.loss_forward(output, train=False)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
bar.update(batch_idx)
if batch_idx % 25 == 0:
err_str = self.format_err_str(errs)
logging.info(
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.stop('loss')
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
mean_loss = [l / len(test_loader) for l in mean_loss]
self.callback_test_stop(epoch, set_idx, mean_loss)
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
# save metrics
self.metric_save()
err_str = self.format_err_str(mean_loss)
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
return mean_loss