Reformat $EVERYTHING
This commit is contained in:
+45
-45
@@ -2,65 +2,65 @@ import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestSet(object):
|
||||
def __init__(self, name, dset, test_frequency=1):
|
||||
self.name = name
|
||||
self.dset = dset
|
||||
self.test_frequency = test_frequency
|
||||
def __init__(self, name, dset, test_frequency=1):
|
||||
self.name = name
|
||||
self.dset = dset
|
||||
self.test_frequency = test_frequency
|
||||
|
||||
|
||||
class TestSets(list):
|
||||
def append(self, name, dset, test_frequency=1):
|
||||
super().append(TestSet(name, dset, test_frequency))
|
||||
|
||||
def append(self, name, dset, test_frequency=1):
|
||||
super().append(TestSet(name, dset, test_frequency))
|
||||
|
||||
|
||||
class MultiDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, *datasets):
|
||||
self.current_epoch = 0
|
||||
def __init__(self, *datasets):
|
||||
self.current_epoch = 0
|
||||
|
||||
self.datasets = []
|
||||
self.cum_n_samples = [0]
|
||||
self.datasets = []
|
||||
self.cum_n_samples = [0]
|
||||
|
||||
for dataset in datasets:
|
||||
self.append(dataset)
|
||||
for dataset in datasets:
|
||||
self.append(dataset)
|
||||
|
||||
def append(self, dataset):
|
||||
self.datasets.append(dataset)
|
||||
self.__update_cum_n_samples(dataset)
|
||||
def append(self, dataset):
|
||||
self.datasets.append(dataset)
|
||||
self.__update_cum_n_samples(dataset)
|
||||
|
||||
def __update_cum_n_samples(self, dataset):
|
||||
n_samples = self.cum_n_samples[-1] + len(dataset)
|
||||
self.cum_n_samples.append(n_samples)
|
||||
def __update_cum_n_samples(self, dataset):
|
||||
n_samples = self.cum_n_samples[-1] + len(dataset)
|
||||
self.cum_n_samples.append(n_samples)
|
||||
|
||||
def dataset_updated(self):
|
||||
self.cum_n_samples = [0]
|
||||
for dset in self.datasets:
|
||||
self.__update_cum_n_samples(dset)
|
||||
def dataset_updated(self):
|
||||
self.cum_n_samples = [0]
|
||||
for dset in self.datasets:
|
||||
self.__update_cum_n_samples(dset)
|
||||
|
||||
def __len__(self):
|
||||
return self.cum_n_samples[-1]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
|
||||
sidx = idx - self.cum_n_samples[didx]
|
||||
return self.datasets[didx][sidx]
|
||||
def __len__(self):
|
||||
return self.cum_n_samples[-1]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
|
||||
sidx = idx - self.cum_n_samples[didx]
|
||||
return self.datasets[didx][sidx]
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, train=True, fix_seed_per_epoch=False):
|
||||
self.current_epoch = 0
|
||||
self.train = train
|
||||
self.fix_seed_per_epoch = fix_seed_per_epoch
|
||||
def __init__(self, train=True, fix_seed_per_epoch=False):
|
||||
self.current_epoch = 0
|
||||
self.train = train
|
||||
self.fix_seed_per_epoch = fix_seed_per_epoch
|
||||
|
||||
def get_rng(self, idx):
|
||||
rng = np.random.RandomState()
|
||||
if self.train:
|
||||
if self.fix_seed_per_epoch:
|
||||
seed = 1 * len(self) + idx
|
||||
else:
|
||||
seed = (self.current_epoch + 1) * len(self) + idx
|
||||
rng.seed(seed)
|
||||
else:
|
||||
rng.seed(idx)
|
||||
return rng
|
||||
def get_rng(self, idx):
|
||||
rng = np.random.RandomState()
|
||||
if self.train:
|
||||
if self.fix_seed_per_epoch:
|
||||
seed = 1 * len(self) + idx
|
||||
else:
|
||||
seed = (self.current_epoch + 1) * len(self) + idx
|
||||
rng.seed(seed)
|
||||
else:
|
||||
rng.seed(idx)
|
||||
return rng
|
||||
|
||||
+120
-115
@@ -2,146 +2,151 @@ import torch
|
||||
from . import ext_cpu
|
||||
from . import ext_cuda
|
||||
|
||||
class NNFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1):
|
||||
args = (in0, in1)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.nn_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.nn_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None
|
||||
class NNFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1):
|
||||
args = (in0, in1)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.nn_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.nn_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None
|
||||
|
||||
|
||||
def nn(in0, in1):
|
||||
return NNFunction.apply(in0, in1)
|
||||
return NNFunction.apply(in0, in1)
|
||||
|
||||
|
||||
class CrossCheckFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1):
|
||||
args = (in0, in1)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.crosscheck_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.crosscheck_cpu(*args)
|
||||
return out
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1):
|
||||
args = (in0, in1)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.crosscheck_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.crosscheck_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None
|
||||
|
||||
def crosscheck(in0, in1):
|
||||
return CrossCheckFunction.apply(in0, in1)
|
||||
return CrossCheckFunction.apply(in0, in1)
|
||||
|
||||
|
||||
class ProjNNFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz0, xyz1, K, patch_size):
|
||||
args = (xyz0, xyz1, K, patch_size)
|
||||
if xyz0.is_cuda:
|
||||
out = ext_cuda.proj_nn_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.proj_nn_cpu(*args)
|
||||
return out
|
||||
@staticmethod
|
||||
def forward(ctx, xyz0, xyz1, K, patch_size):
|
||||
args = (xyz0, xyz1, K, patch_size)
|
||||
if xyz0.is_cuda:
|
||||
out = ext_cuda.proj_nn_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.proj_nn_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None, None, None
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None, None, None
|
||||
|
||||
def proj_nn(xyz0, xyz1, K, patch_size):
|
||||
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
||||
|
||||
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
||||
|
||||
|
||||
class XCorrVolFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1, n_disps, block_size):
|
||||
args = (in0, in1, n_disps, block_size)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.xcorrvol_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.xcorrvol_cpu(*args)
|
||||
return out
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1, n_disps, block_size):
|
||||
args = (in0, in1, n_disps, block_size)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.xcorrvol_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.xcorrvol_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None, None, None
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None, None, None
|
||||
|
||||
def xcorrvol(in0, in1, n_disps, block_size):
|
||||
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
||||
|
||||
|
||||
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
||||
|
||||
|
||||
class PhotometricLossFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, es, ta, block_size, type, eps):
|
||||
args = (es, ta, block_size, type, eps)
|
||||
ctx.save_for_backward(es, ta)
|
||||
ctx.block_size = block_size
|
||||
ctx.type = type
|
||||
ctx.eps = eps
|
||||
if es.is_cuda:
|
||||
out = ext_cuda.photometric_loss_forward(*args)
|
||||
else:
|
||||
out = ext_cpu.photometric_loss_forward(*args)
|
||||
return out
|
||||
@staticmethod
|
||||
def forward(ctx, es, ta, block_size, type, eps):
|
||||
args = (es, ta, block_size, type, eps)
|
||||
ctx.save_for_backward(es, ta)
|
||||
ctx.block_size = block_size
|
||||
ctx.type = type
|
||||
ctx.eps = eps
|
||||
if es.is_cuda:
|
||||
out = ext_cuda.photometric_loss_forward(*args)
|
||||
else:
|
||||
out = ext_cpu.photometric_loss_forward(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
es, ta = ctx.saved_tensors
|
||||
block_size = ctx.block_size
|
||||
type = ctx.type
|
||||
eps = ctx.eps
|
||||
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
|
||||
if grad_out.is_cuda:
|
||||
grad_es = ext_cuda.photometric_loss_backward(*args)
|
||||
else:
|
||||
grad_es = ext_cpu.photometric_loss_backward(*args)
|
||||
return grad_es, None, None, None, None
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
es, ta = ctx.saved_tensors
|
||||
block_size = ctx.block_size
|
||||
type = ctx.type
|
||||
eps = ctx.eps
|
||||
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
|
||||
if grad_out.is_cuda:
|
||||
grad_es = ext_cuda.photometric_loss_backward(*args)
|
||||
else:
|
||||
grad_es = ext_cpu.photometric_loss_backward(*args)
|
||||
return grad_es, None, None, None, None
|
||||
|
||||
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
||||
type = type.lower()
|
||||
if type == 'mse':
|
||||
type = 0
|
||||
elif type == 'sad':
|
||||
type = 1
|
||||
elif type == 'census_mse':
|
||||
type = 2
|
||||
elif type == 'census_sad':
|
||||
type = 3
|
||||
else:
|
||||
raise Exception('invalid loss type')
|
||||
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
||||
type = type.lower()
|
||||
if type == 'mse':
|
||||
type = 0
|
||||
elif type == 'sad':
|
||||
type = 1
|
||||
elif type == 'census_mse':
|
||||
type = 2
|
||||
elif type == 'census_sad':
|
||||
type = 3
|
||||
else:
|
||||
raise Exception('invalid loss type')
|
||||
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
||||
|
||||
|
||||
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
|
||||
type = type.lower()
|
||||
p = block_size // 2
|
||||
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
|
||||
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
|
||||
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
|
||||
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
|
||||
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
|
||||
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
|
||||
if type == 'mse':
|
||||
ref = (es_uf - ta_uf)**2
|
||||
elif type == 'sad':
|
||||
ref = torch.abs(es_uf - ta_uf)
|
||||
elif type == 'census_mse' or type == 'census_sad':
|
||||
des = es_uf - es.unsqueeze(2)
|
||||
dta = ta_uf - ta.unsqueeze(2)
|
||||
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
|
||||
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
|
||||
diff = h_des - h_dta
|
||||
if type == 'census_mse':
|
||||
ref = diff * diff
|
||||
elif type == 'census_sad':
|
||||
ref = torch.abs(diff)
|
||||
else:
|
||||
raise Exception('invalid loss type')
|
||||
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
|
||||
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
|
||||
return ref
|
||||
type = type.lower()
|
||||
p = block_size // 2
|
||||
es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
|
||||
ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
|
||||
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
|
||||
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
|
||||
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
|
||||
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
|
||||
if type == 'mse':
|
||||
ref = (es_uf - ta_uf) ** 2
|
||||
elif type == 'sad':
|
||||
ref = torch.abs(es_uf - ta_uf)
|
||||
elif type == 'census_mse' or type == 'census_sad':
|
||||
des = es_uf - es.unsqueeze(2)
|
||||
dta = ta_uf - ta.unsqueeze(2)
|
||||
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
|
||||
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
|
||||
diff = h_des - h_dta
|
||||
if type == 'census_mse':
|
||||
ref = diff * diff
|
||||
elif type == 'census_sad':
|
||||
ref = torch.abs(diff)
|
||||
else:
|
||||
raise Exception('invalid loss type')
|
||||
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
|
||||
ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
|
||||
return ref
|
||||
|
||||
+19
-17
@@ -4,24 +4,26 @@ import numpy as np
|
||||
|
||||
from .functions import *
|
||||
|
||||
|
||||
class CoordConv2d(torch.nn.Module):
|
||||
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
||||
super().__init__()
|
||||
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
|
||||
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
|
||||
stride=stride)
|
||||
|
||||
self.uv = None
|
||||
self.uv = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.uv is None:
|
||||
height, width = x.shape[2], x.shape[3]
|
||||
u, v = np.meshgrid(range(width), range(height))
|
||||
u = 2 * u / (width - 1) - 1
|
||||
v = 2 * v / (height - 1) - 1
|
||||
uv = np.stack((u, v)).reshape(1, 2, height, width)
|
||||
self.uv = torch.from_numpy( uv.astype(np.float32) )
|
||||
self.uv = self.uv.to(x.device)
|
||||
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
|
||||
xuv = torch.cat((x, uv), dim=1)
|
||||
y = self.conv(xuv)
|
||||
return y
|
||||
def forward(self, x):
|
||||
if self.uv is None:
|
||||
height, width = x.shape[2], x.shape[3]
|
||||
u, v = np.meshgrid(range(width), range(height))
|
||||
u = 2 * u / (width - 1) - 1
|
||||
v = 2 * v / (height - 1) - 1
|
||||
uv = np.stack((u, v)).reshape(1, 2, height, width)
|
||||
self.uv = torch.from_numpy(uv.astype(np.float32))
|
||||
self.uv = self.uv.to(x.device)
|
||||
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
|
||||
xuv = torch.cat((x, uv), dim=1)
|
||||
y = self.conv(xuv)
|
||||
return y
|
||||
|
||||
+8
-7
@@ -11,11 +11,12 @@ nvcc_args = [
|
||||
]
|
||||
|
||||
setup(
|
||||
name='ext',
|
||||
ext_modules=[
|
||||
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
||||
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
||||
],
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
include_dirs=include_dirs
|
||||
name='ext',
|
||||
ext_modules=[
|
||||
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
||||
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
|
||||
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
||||
],
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
include_dirs=include_dirs
|
||||
)
|
||||
|
||||
+434
-430
@@ -17,512 +17,516 @@ from collections import OrderedDict
|
||||
|
||||
|
||||
class StopWatch(object):
|
||||
def __init__(self):
|
||||
self.timings = OrderedDict()
|
||||
self.starts = {}
|
||||
def __init__(self):
|
||||
self.timings = OrderedDict()
|
||||
self.starts = {}
|
||||
|
||||
def start(self, name):
|
||||
self.starts[name] = time.time()
|
||||
def start(self, name):
|
||||
self.starts[name] = time.time()
|
||||
|
||||
def stop(self, name):
|
||||
if name not in self.timings:
|
||||
self.timings[name] = []
|
||||
self.timings[name].append(time.time() - self.starts[name])
|
||||
def stop(self, name):
|
||||
if name not in self.timings:
|
||||
self.timings[name] = []
|
||||
self.timings[name].append(time.time() - self.starts[name])
|
||||
|
||||
def get(self, name=None, reduce=np.sum):
|
||||
if name is not None:
|
||||
return reduce(self.timings[name])
|
||||
else:
|
||||
ret = {}
|
||||
for k in self.timings:
|
||||
ret[k] = reduce(self.timings[k])
|
||||
return ret
|
||||
def get(self, name=None, reduce=np.sum):
|
||||
if name is not None:
|
||||
return reduce(self.timings[name])
|
||||
else:
|
||||
ret = {}
|
||||
for k in self.timings:
|
||||
ret[k] = reduce(self.timings[k])
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
||||
def __str__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
||||
def __repr__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||
|
||||
|
||||
class ETA(object):
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
self.start_time = time.time()
|
||||
self.current_idx = 0
|
||||
self.current_time = time.time()
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
self.start_time = time.time()
|
||||
self.current_idx = 0
|
||||
self.current_time = time.time()
|
||||
|
||||
def update(self, idx):
|
||||
self.current_idx = idx
|
||||
self.current_time = time.time()
|
||||
def update(self, idx):
|
||||
self.current_idx = idx
|
||||
self.current_time = time.time()
|
||||
|
||||
def get_elapsed_time(self):
|
||||
return self.current_time - self.start_time
|
||||
def get_elapsed_time(self):
|
||||
return self.current_time - self.start_time
|
||||
|
||||
def get_item_time(self):
|
||||
return self.get_elapsed_time() / (self.current_idx + 1)
|
||||
def get_item_time(self):
|
||||
return self.get_elapsed_time() / (self.current_idx + 1)
|
||||
|
||||
def get_remaining_time(self):
|
||||
return self.get_item_time() * (self.length - self.current_idx + 1)
|
||||
def get_remaining_time(self):
|
||||
return self.get_item_time() * (self.length - self.current_idx + 1)
|
||||
|
||||
def format_time(self, seconds):
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
hours = int(hours)
|
||||
minutes = int(minutes)
|
||||
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
||||
def format_time(self, seconds):
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
hours = int(hours)
|
||||
minutes = int(minutes)
|
||||
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
||||
|
||||
def get_elapsed_time_str(self):
|
||||
return self.format_time(self.get_elapsed_time())
|
||||
def get_elapsed_time_str(self):
|
||||
return self.format_time(self.get_elapsed_time())
|
||||
|
||||
def get_remaining_time_str(self):
|
||||
return self.format_time(self.get_remaining_time())
|
||||
|
||||
def get_remaining_time_str(self):
|
||||
return self.format_time(self.get_remaining_time())
|
||||
|
||||
class Worker(object):
|
||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
||||
self.out_root = Path(out_root)
|
||||
self.experiment_name = experiment_name
|
||||
self.epochs = epochs
|
||||
self.seed = seed
|
||||
self.train_batch_size = train_batch_size
|
||||
self.test_batch_size = test_batch_size
|
||||
self.num_workers = num_workers
|
||||
self.save_frequency = save_frequency
|
||||
self.train_device = train_device
|
||||
self.test_device = test_device
|
||||
self.max_train_iter = max_train_iter
|
||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
||||
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
||||
self.out_root = Path(out_root)
|
||||
self.experiment_name = experiment_name
|
||||
self.epochs = epochs
|
||||
self.seed = seed
|
||||
self.train_batch_size = train_batch_size
|
||||
self.test_batch_size = test_batch_size
|
||||
self.num_workers = num_workers
|
||||
self.save_frequency = save_frequency
|
||||
self.train_device = train_device
|
||||
self.test_device = test_device
|
||||
self.max_train_iter = max_train_iter
|
||||
|
||||
self.errs_list=[]
|
||||
self.errs_list = []
|
||||
|
||||
self.setup_experiment()
|
||||
self.setup_experiment()
|
||||
|
||||
def setup_experiment(self):
|
||||
self.exp_out_root = self.out_root / self.experiment_name
|
||||
self.exp_out_root.mkdir(parents=True, exist_ok=True)
|
||||
def setup_experiment(self):
|
||||
self.exp_out_root = self.out_root / self.experiment_name
|
||||
self.exp_out_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if logging.root: del logging.root.handlers[:]
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
logging.FileHandler( str(self.exp_out_root / 'train.log') ),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
|
||||
)
|
||||
if logging.root: del logging.root.handlers[:]
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
logging.FileHandler(str(self.exp_out_root / 'train.log')),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
|
||||
)
|
||||
|
||||
logging.info('='*80)
|
||||
logging.info(f'Start of experiment: {self.experiment_name}')
|
||||
logging.info(socket.gethostname())
|
||||
self.log_datetime()
|
||||
logging.info('='*80)
|
||||
logging.info('=' * 80)
|
||||
logging.info(f'Start of experiment: {self.experiment_name}')
|
||||
logging.info(socket.gethostname())
|
||||
self.log_datetime()
|
||||
logging.info('=' * 80)
|
||||
|
||||
self.metric_path = self.exp_out_root / 'metrics.json'
|
||||
if self.metric_path.exists():
|
||||
with open(str(self.metric_path), 'r') as fp:
|
||||
self.metric_data = json.load(fp)
|
||||
else:
|
||||
self.metric_data = {}
|
||||
self.metric_path = self.exp_out_root / 'metrics.json'
|
||||
if self.metric_path.exists():
|
||||
with open(str(self.metric_path), 'r') as fp:
|
||||
self.metric_data = json.load(fp)
|
||||
else:
|
||||
self.metric_data = {}
|
||||
|
||||
self.init_seed()
|
||||
self.init_seed()
|
||||
|
||||
def metric_add_train(self, epoch, key, val):
|
||||
epoch = str(epoch)
|
||||
key = str(key)
|
||||
if epoch not in self.metric_data:
|
||||
self.metric_data[epoch] = {}
|
||||
if 'train' not in self.metric_data[epoch]:
|
||||
self.metric_data[epoch]['train'] = {}
|
||||
self.metric_data[epoch]['train'][key] = val
|
||||
def metric_add_train(self, epoch, key, val):
|
||||
epoch = str(epoch)
|
||||
key = str(key)
|
||||
if epoch not in self.metric_data:
|
||||
self.metric_data[epoch] = {}
|
||||
if 'train' not in self.metric_data[epoch]:
|
||||
self.metric_data[epoch]['train'] = {}
|
||||
self.metric_data[epoch]['train'][key] = val
|
||||
|
||||
def metric_add_test(self, epoch, set_idx, key, val):
|
||||
epoch = str(epoch)
|
||||
set_idx = str(set_idx)
|
||||
key = str(key)
|
||||
if epoch not in self.metric_data:
|
||||
self.metric_data[epoch] = {}
|
||||
if 'test' not in self.metric_data[epoch]:
|
||||
self.metric_data[epoch]['test'] = {}
|
||||
if set_idx not in self.metric_data[epoch]['test']:
|
||||
self.metric_data[epoch]['test'][set_idx] = {}
|
||||
self.metric_data[epoch]['test'][set_idx][key] = val
|
||||
def metric_add_test(self, epoch, set_idx, key, val):
|
||||
epoch = str(epoch)
|
||||
set_idx = str(set_idx)
|
||||
key = str(key)
|
||||
if epoch not in self.metric_data:
|
||||
self.metric_data[epoch] = {}
|
||||
if 'test' not in self.metric_data[epoch]:
|
||||
self.metric_data[epoch]['test'] = {}
|
||||
if set_idx not in self.metric_data[epoch]['test']:
|
||||
self.metric_data[epoch]['test'][set_idx] = {}
|
||||
self.metric_data[epoch]['test'][set_idx][key] = val
|
||||
|
||||
def metric_save(self):
|
||||
with open(str(self.metric_path), 'w') as fp:
|
||||
json.dump(self.metric_data, fp, indent=2)
|
||||
def metric_save(self):
|
||||
with open(str(self.metric_path), 'w') as fp:
|
||||
json.dump(self.metric_data, fp, indent=2)
|
||||
|
||||
def init_seed(self, seed=None):
|
||||
if seed is not None:
|
||||
self.seed = seed
|
||||
logging.info(f'Set seed to {self.seed}')
|
||||
np.random.seed(self.seed)
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
torch.cuda.manual_seed(self.seed)
|
||||
def init_seed(self, seed=None):
|
||||
if seed is not None:
|
||||
self.seed = seed
|
||||
logging.info(f'Set seed to {self.seed}')
|
||||
np.random.seed(self.seed)
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
torch.cuda.manual_seed(self.seed)
|
||||
|
||||
def log_datetime(self):
|
||||
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
def log_datetime(self):
|
||||
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
|
||||
def mem_report(self):
|
||||
for obj in gc.get_objects():
|
||||
if torch.is_tensor(obj):
|
||||
print(type(obj), obj.shape)
|
||||
def mem_report(self):
|
||||
for obj in gc.get_objects():
|
||||
if torch.is_tensor(obj):
|
||||
print(type(obj), obj.shape)
|
||||
|
||||
def get_net_path(self, epoch, root=None):
|
||||
if root is None:
|
||||
root = self.exp_out_root
|
||||
return root / f'net_{epoch:04d}.params'
|
||||
def get_net_path(self, epoch, root=None):
|
||||
if root is None:
|
||||
root = self.exp_out_root
|
||||
return root / f'net_{epoch:04d}.params'
|
||||
|
||||
def get_do_parser_cmds(self):
|
||||
return ['retrain', 'resume', 'retest', 'test_init']
|
||||
def get_do_parser_cmds(self):
|
||||
return ['retrain', 'resume', 'retest', 'test_init']
|
||||
|
||||
def get_do_parser(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
|
||||
parser.add_argument('--epoch', type=int, default=-1)
|
||||
return parser
|
||||
def get_do_parser(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
|
||||
parser.add_argument('--epoch', type=int, default=-1)
|
||||
return parser
|
||||
|
||||
def do_cmd(self, args, net, optimizer, scheduler=None):
|
||||
if args.cmd == 'retrain':
|
||||
self.train(net, optimizer, resume=False, scheduler=scheduler)
|
||||
elif args.cmd == 'resume':
|
||||
self.train(net, optimizer, resume=True, scheduler=scheduler)
|
||||
elif args.cmd == 'retest':
|
||||
self.retest(net, epoch=args.epoch)
|
||||
elif args.cmd == 'test_init':
|
||||
test_sets = self.get_test_sets()
|
||||
self.test(-1, net, test_sets)
|
||||
else:
|
||||
raise Exception('invalid cmd')
|
||||
def do_cmd(self, args, net, optimizer, scheduler=None):
|
||||
if args.cmd == 'retrain':
|
||||
self.train(net, optimizer, resume=False, scheduler=scheduler)
|
||||
elif args.cmd == 'resume':
|
||||
self.train(net, optimizer, resume=True, scheduler=scheduler)
|
||||
elif args.cmd == 'retest':
|
||||
self.retest(net, epoch=args.epoch)
|
||||
elif args.cmd == 'test_init':
|
||||
test_sets = self.get_test_sets()
|
||||
self.test(-1, net, test_sets)
|
||||
else:
|
||||
raise Exception('invalid cmd')
|
||||
|
||||
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
|
||||
parser = self.get_do_parser()
|
||||
args, _ = parser.parse_known_args()
|
||||
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
|
||||
parser = self.get_do_parser()
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if load_net_optimizer is not None and args.cmd not in ['schedule']:
|
||||
net, optimizer = load_net_optimizer()
|
||||
if load_net_optimizer is not None and args.cmd not in ['schedule']:
|
||||
net, optimizer = load_net_optimizer()
|
||||
|
||||
self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
||||
self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
||||
|
||||
def retest(self, net, epoch=-1):
|
||||
if epoch < 0:
|
||||
epochs = range(self.epochs)
|
||||
else:
|
||||
epochs = [epoch]
|
||||
def retest(self, net, epoch=-1):
|
||||
if epoch < 0:
|
||||
epochs = range(self.epochs)
|
||||
else:
|
||||
epochs = [epoch]
|
||||
|
||||
test_sets = self.get_test_sets()
|
||||
test_sets = self.get_test_sets()
|
||||
|
||||
for epoch in epochs:
|
||||
net_path = self.get_net_path(epoch)
|
||||
if net_path.exists():
|
||||
state_dict = torch.load(str(net_path))
|
||||
net.load_state_dict(state_dict)
|
||||
self.test(epoch, net, test_sets)
|
||||
for epoch in epochs:
|
||||
net_path = self.get_net_path(epoch)
|
||||
if net_path.exists():
|
||||
state_dict = torch.load(str(net_path))
|
||||
net.load_state_dict(state_dict)
|
||||
self.test(epoch, net, test_sets)
|
||||
|
||||
def format_err_str(self, errs, div=1):
|
||||
err = sum(errs)
|
||||
if len(errs) > 1:
|
||||
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
|
||||
else:
|
||||
err_str = f'{err/div:0.4f}'
|
||||
return err_str
|
||||
def format_err_str(self, errs, div=1):
|
||||
err = sum(errs)
|
||||
if len(errs) > 1:
|
||||
err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
|
||||
else:
|
||||
err_str = f'{err / div:0.4f}'
|
||||
return err_str
|
||||
|
||||
def write_err_img(self):
|
||||
err_img_path = self.exp_out_root / 'errs.png'
|
||||
fig = plt.figure(figsize=(16,16))
|
||||
lines=[]
|
||||
for idx,errs in enumerate(self.errs_list):
|
||||
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
|
||||
lines.append(line)
|
||||
plt.tight_layout()
|
||||
plt.legend(handles=lines)
|
||||
plt.savefig(str(err_img_path))
|
||||
plt.close(fig)
|
||||
def write_err_img(self):
|
||||
err_img_path = self.exp_out_root / 'errs.png'
|
||||
fig = plt.figure(figsize=(16, 16))
|
||||
lines = []
|
||||
for idx, errs in enumerate(self.errs_list):
|
||||
line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
|
||||
lines.append(line)
|
||||
plt.tight_layout()
|
||||
plt.legend(handles=lines)
|
||||
plt.savefig(str(err_img_path))
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def callback_train_new_epoch(self, epoch, net, optimizer):
|
||||
pass
|
||||
|
||||
def train(self, net, optimizer, resume=False, scheduler=None):
|
||||
logging.info('='*80)
|
||||
logging.info('Start training')
|
||||
self.log_datetime()
|
||||
logging.info('='*80)
|
||||
|
||||
train_set = self.get_train_set()
|
||||
test_sets = self.get_test_sets()
|
||||
|
||||
net = net.to(self.train_device)
|
||||
|
||||
epoch = 0
|
||||
min_err = {ts.name: 1e9 for ts in test_sets}
|
||||
|
||||
state_path = self.exp_out_root / 'state.dict'
|
||||
if resume and state_path.exists():
|
||||
logging.info('='*80)
|
||||
logging.info(f'Loading state from {state_path}')
|
||||
logging.info('='*80)
|
||||
state = torch.load(str(state_path))
|
||||
epoch = state['epoch'] + 1
|
||||
if 'min_err' in state:
|
||||
min_err = state['min_err']
|
||||
|
||||
curr_state = net.state_dict()
|
||||
curr_state.update(state['state_dict'])
|
||||
net.load_state_dict(curr_state)
|
||||
|
||||
|
||||
try:
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
except:
|
||||
logging.info('Warning: cannot load optimizer from state_dict')
|
||||
def callback_train_new_epoch(self, epoch, net, optimizer):
|
||||
pass
|
||||
if 'cpu_rng_state' in state:
|
||||
torch.set_rng_state(state['cpu_rng_state'])
|
||||
if 'gpu_rng_state' in state:
|
||||
torch.cuda.set_rng_state(state['gpu_rng_state'])
|
||||
|
||||
for epoch in range(epoch, self.epochs):
|
||||
self.callback_train_new_epoch(epoch, net, optimizer)
|
||||
def train(self, net, optimizer, resume=False, scheduler=None):
|
||||
logging.info('=' * 80)
|
||||
logging.info('Start training')
|
||||
self.log_datetime()
|
||||
logging.info('=' * 80)
|
||||
|
||||
# train epoch
|
||||
self.train_epoch(epoch, net, optimizer, train_set)
|
||||
train_set = self.get_train_set()
|
||||
test_sets = self.get_test_sets()
|
||||
|
||||
# test epoch
|
||||
errs = self.test(epoch, net, test_sets)
|
||||
|
||||
if (epoch + 1) % self.save_frequency == 0:
|
||||
net = net.to(self.train_device)
|
||||
|
||||
# store state
|
||||
state_dict = {
|
||||
'epoch': epoch,
|
||||
'min_err': min_err,
|
||||
'state_dict': net.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'cpu_rng_state': torch.get_rng_state(),
|
||||
'gpu_rng_state': torch.cuda.get_rng_state(),
|
||||
}
|
||||
logging.info(f'save state to {state_path}')
|
||||
epoch = 0
|
||||
min_err = {ts.name: 1e9 for ts in test_sets}
|
||||
|
||||
state_path = self.exp_out_root / 'state.dict'
|
||||
torch.save(state_dict, str(state_path))
|
||||
if resume and state_path.exists():
|
||||
logging.info('=' * 80)
|
||||
logging.info(f'Loading state from {state_path}')
|
||||
logging.info('=' * 80)
|
||||
state = torch.load(str(state_path))
|
||||
epoch = state['epoch'] + 1
|
||||
if 'min_err' in state:
|
||||
min_err = state['min_err']
|
||||
|
||||
for test_set_name in errs:
|
||||
err = sum(errs[test_set_name])
|
||||
if err < min_err[test_set_name]:
|
||||
min_err[test_set_name] = err
|
||||
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
|
||||
logging.info(f'save state to {state_path}')
|
||||
torch.save(state_dict, str(state_path))
|
||||
curr_state = net.state_dict()
|
||||
curr_state.update(state['state_dict'])
|
||||
net.load_state_dict(curr_state)
|
||||
|
||||
# store network
|
||||
net_path = self.get_net_path(epoch)
|
||||
logging.info(f'save network to {net_path}')
|
||||
torch.save(net.state_dict(), str(net_path))
|
||||
try:
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
except:
|
||||
logging.info('Warning: cannot load optimizer from state_dict')
|
||||
pass
|
||||
if 'cpu_rng_state' in state:
|
||||
torch.set_rng_state(state['cpu_rng_state'])
|
||||
if 'gpu_rng_state' in state:
|
||||
torch.cuda.set_rng_state(state['gpu_rng_state'])
|
||||
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
for epoch in range(epoch, self.epochs):
|
||||
self.callback_train_new_epoch(epoch, net, optimizer)
|
||||
|
||||
logging.info('='*80)
|
||||
logging.info('Finished training')
|
||||
self.log_datetime()
|
||||
logging.info('='*80)
|
||||
# train epoch
|
||||
self.train_epoch(epoch, net, optimizer, train_set)
|
||||
|
||||
def get_train_set(self):
|
||||
# returns train_set
|
||||
raise NotImplementedError()
|
||||
# test epoch
|
||||
errs = self.test(epoch, net, test_sets)
|
||||
|
||||
def get_test_sets(self):
|
||||
# returns test_sets
|
||||
raise NotImplementedError()
|
||||
if (epoch + 1) % self.save_frequency == 0:
|
||||
net = net.to(self.train_device)
|
||||
|
||||
def copy_data(self, data, device, requires_grad, train):
|
||||
raise NotImplementedError()
|
||||
# store state
|
||||
state_dict = {
|
||||
'epoch': epoch,
|
||||
'min_err': min_err,
|
||||
'state_dict': net.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'cpu_rng_state': torch.get_rng_state(),
|
||||
'gpu_rng_state': torch.cuda.get_rng_state(),
|
||||
}
|
||||
logging.info(f'save state to {state_path}')
|
||||
state_path = self.exp_out_root / 'state.dict'
|
||||
torch.save(state_dict, str(state_path))
|
||||
|
||||
def net_forward(self, net, train):
|
||||
raise NotImplementedError()
|
||||
for test_set_name in errs:
|
||||
err = sum(errs[test_set_name])
|
||||
if err < min_err[test_set_name]:
|
||||
min_err[test_set_name] = err
|
||||
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
|
||||
logging.info(f'save state to {state_path}')
|
||||
torch.save(state_dict, str(state_path))
|
||||
|
||||
def loss_forward(self, output, train):
|
||||
raise NotImplementedError()
|
||||
# store network
|
||||
net_path = self.get_net_path(epoch)
|
||||
logging.info(f'save network to {net_path}')
|
||||
torch.save(net.state_dict(), str(net_path))
|
||||
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||
# err = False
|
||||
# for name, param in net.named_parameters():
|
||||
# if not torch.isfinite(param.grad).all():
|
||||
# print(name)
|
||||
# err = True
|
||||
# if err:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
def callback_train_start(self, epoch):
|
||||
pass
|
||||
logging.info('=' * 80)
|
||||
logging.info('Finished training')
|
||||
self.log_datetime()
|
||||
logging.info('=' * 80)
|
||||
|
||||
def callback_train_stop(self, epoch, loss):
|
||||
pass
|
||||
def get_train_set(self):
|
||||
# returns train_set
|
||||
raise NotImplementedError()
|
||||
|
||||
def train_epoch(self, epoch, net, optimizer, dset):
|
||||
self.callback_train_start(epoch)
|
||||
stopwatch = StopWatch()
|
||||
def get_test_sets(self):
|
||||
# returns test_sets
|
||||
raise NotImplementedError()
|
||||
|
||||
logging.info('='*80)
|
||||
logging.info('Train epoch %d' % epoch)
|
||||
def copy_data(self, data, device, requires_grad, train):
|
||||
raise NotImplementedError()
|
||||
|
||||
dset.current_epoch = epoch
|
||||
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||
def net_forward(self, net, train):
|
||||
raise NotImplementedError()
|
||||
|
||||
net = net.to(self.train_device)
|
||||
net.train()
|
||||
def loss_forward(self, output, train):
|
||||
raise NotImplementedError()
|
||||
|
||||
mean_loss = None
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||
# err = False
|
||||
# for name, param in net.named_parameters():
|
||||
# if not torch.isfinite(param.grad).all():
|
||||
# print(name)
|
||||
# err = True
|
||||
# if err:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
|
||||
bar = ETA(length=n_batches)
|
||||
def callback_train_start(self, epoch):
|
||||
pass
|
||||
|
||||
stopwatch.start('total')
|
||||
stopwatch.start('data')
|
||||
for batch_idx, data in enumerate(train_loader):
|
||||
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
|
||||
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
|
||||
stopwatch.stop('data')
|
||||
def callback_train_stop(self, epoch, loss):
|
||||
pass
|
||||
|
||||
optimizer.zero_grad()
|
||||
def train_epoch(self, epoch, net, optimizer, dset):
|
||||
self.callback_train_start(epoch)
|
||||
stopwatch = StopWatch()
|
||||
|
||||
stopwatch.start('forward')
|
||||
output = self.net_forward(net, train=True)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('forward')
|
||||
logging.info('=' * 80)
|
||||
logging.info('Train epoch %d' % epoch)
|
||||
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=True)
|
||||
if isinstance(errs, dict):
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
masks = []
|
||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||
errs = [errs]
|
||||
err = sum(errs)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('loss')
|
||||
dset.current_epoch = epoch
|
||||
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
|
||||
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||
|
||||
stopwatch.start('backward')
|
||||
err.backward()
|
||||
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('backward')
|
||||
net = net.to(self.train_device)
|
||||
net.train()
|
||||
|
||||
stopwatch.start('optimizer')
|
||||
optimizer.step()
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('optimizer')
|
||||
mean_loss = None
|
||||
|
||||
bar.update(batch_idx)
|
||||
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
||||
err_str = self.format_err_str(errs)
|
||||
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||
#self.write_err_img()
|
||||
|
||||
|
||||
if mean_loss is None:
|
||||
mean_loss = [0 for e in errs]
|
||||
for erridx, err in enumerate(errs):
|
||||
mean_loss[erridx] += err.item()
|
||||
|
||||
stopwatch.start('data')
|
||||
stopwatch.stop('total')
|
||||
logging.info('timings: %s' % stopwatch)
|
||||
|
||||
mean_loss = [l / len(train_loader) for l in mean_loss]
|
||||
self.callback_train_stop(epoch, mean_loss)
|
||||
self.metric_add_train(epoch, 'loss', mean_loss)
|
||||
|
||||
# save metrics
|
||||
self.metric_save()
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'avg train_loss={err_str}')
|
||||
return mean_loss
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
|
||||
pass
|
||||
|
||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
||||
pass
|
||||
|
||||
def callback_test_stop(self, epoch, set_idx, loss):
|
||||
pass
|
||||
|
||||
def test(self, epoch, net, test_sets):
|
||||
errs = {}
|
||||
for test_set_idx, test_set in enumerate(test_sets):
|
||||
if (epoch + 1) % test_set.test_frequency == 0:
|
||||
logging.info('='*80)
|
||||
logging.info(f'testing set {test_set.name}')
|
||||
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
|
||||
errs[test_set.name] = err
|
||||
return errs
|
||||
|
||||
def test_epoch(self, epoch, set_idx, net, dset):
|
||||
logging.info('-'*80)
|
||||
logging.info('Test epoch %d' % epoch)
|
||||
dset.current_epoch = epoch
|
||||
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
||||
|
||||
net = net.to(self.test_device)
|
||||
net.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
mean_loss = None
|
||||
|
||||
self.callback_test_start(epoch, set_idx)
|
||||
|
||||
bar = ETA(length=len(test_loader))
|
||||
stopwatch = StopWatch()
|
||||
stopwatch.start('total')
|
||||
stopwatch.start('data')
|
||||
for batch_idx, data in enumerate(test_loader):
|
||||
# if batch_idx == 10: break
|
||||
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
|
||||
stopwatch.stop('data')
|
||||
|
||||
stopwatch.start('forward')
|
||||
output = self.net_forward(net, train=False)
|
||||
if 'cuda' in self.test_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('forward')
|
||||
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=False)
|
||||
if isinstance(errs, dict):
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
masks = []
|
||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||
errs = [errs]
|
||||
|
||||
bar.update(batch_idx)
|
||||
if batch_idx % 25 == 0:
|
||||
err_str = self.format_err_str(errs)
|
||||
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||
|
||||
if mean_loss is None:
|
||||
mean_loss = [0 for e in errs]
|
||||
for erridx, err in enumerate(errs):
|
||||
mean_loss[erridx] += err.item()
|
||||
stopwatch.stop('loss')
|
||||
|
||||
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
|
||||
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
|
||||
bar = ETA(length=n_batches)
|
||||
|
||||
stopwatch.start('total')
|
||||
stopwatch.start('data')
|
||||
stopwatch.stop('total')
|
||||
logging.info('timings: %s' % stopwatch)
|
||||
for batch_idx, data in enumerate(train_loader):
|
||||
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
|
||||
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
|
||||
stopwatch.stop('data')
|
||||
|
||||
mean_loss = [l / len(test_loader) for l in mean_loss]
|
||||
self.callback_test_stop(epoch, set_idx, mean_loss)
|
||||
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# save metrics
|
||||
self.metric_save()
|
||||
stopwatch.start('forward')
|
||||
output = self.net_forward(net, train=True)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('forward')
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
||||
return mean_loss
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=True)
|
||||
if isinstance(errs, dict):
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
masks = []
|
||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||
errs = [errs]
|
||||
err = sum(errs)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('loss')
|
||||
|
||||
stopwatch.start('backward')
|
||||
err.backward()
|
||||
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('backward')
|
||||
|
||||
stopwatch.start('optimizer')
|
||||
optimizer.step()
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('optimizer')
|
||||
|
||||
bar.update(batch_idx)
|
||||
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
||||
err_str = self.format_err_str(errs)
|
||||
logging.info(
|
||||
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||
# self.write_err_img()
|
||||
|
||||
if mean_loss is None:
|
||||
mean_loss = [0 for e in errs]
|
||||
for erridx, err in enumerate(errs):
|
||||
mean_loss[erridx] += err.item()
|
||||
|
||||
stopwatch.start('data')
|
||||
stopwatch.stop('total')
|
||||
logging.info('timings: %s' % stopwatch)
|
||||
|
||||
mean_loss = [l / len(train_loader) for l in mean_loss]
|
||||
self.callback_train_stop(epoch, mean_loss)
|
||||
self.metric_add_train(epoch, 'loss', mean_loss)
|
||||
|
||||
# save metrics
|
||||
self.metric_save()
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'avg train_loss={err_str}')
|
||||
return mean_loss
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
|
||||
pass
|
||||
|
||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
||||
pass
|
||||
|
||||
def callback_test_stop(self, epoch, set_idx, loss):
|
||||
pass
|
||||
|
||||
def test(self, epoch, net, test_sets):
|
||||
errs = {}
|
||||
for test_set_idx, test_set in enumerate(test_sets):
|
||||
if (epoch + 1) % test_set.test_frequency == 0:
|
||||
logging.info('=' * 80)
|
||||
logging.info(f'testing set {test_set.name}')
|
||||
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
|
||||
errs[test_set.name] = err
|
||||
return errs
|
||||
|
||||
def test_epoch(self, epoch, set_idx, net, dset):
|
||||
logging.info('-' * 80)
|
||||
logging.info('Test epoch %d' % epoch)
|
||||
dset.current_epoch = epoch
|
||||
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
|
||||
num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
||||
|
||||
net = net.to(self.test_device)
|
||||
net.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
mean_loss = None
|
||||
|
||||
self.callback_test_start(epoch, set_idx)
|
||||
|
||||
bar = ETA(length=len(test_loader))
|
||||
stopwatch = StopWatch()
|
||||
stopwatch.start('total')
|
||||
stopwatch.start('data')
|
||||
for batch_idx, data in enumerate(test_loader):
|
||||
# if batch_idx == 10: break
|
||||
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
|
||||
stopwatch.stop('data')
|
||||
|
||||
stopwatch.start('forward')
|
||||
output = self.net_forward(net, train=False)
|
||||
if 'cuda' in self.test_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('forward')
|
||||
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=False)
|
||||
if isinstance(errs, dict):
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
masks = []
|
||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||
errs = [errs]
|
||||
|
||||
bar.update(batch_idx)
|
||||
if batch_idx % 25 == 0:
|
||||
err_str = self.format_err_str(errs)
|
||||
logging.info(
|
||||
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||
|
||||
if mean_loss is None:
|
||||
mean_loss = [0 for e in errs]
|
||||
for erridx, err in enumerate(errs):
|
||||
mean_loss[erridx] += err.item()
|
||||
stopwatch.stop('loss')
|
||||
|
||||
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
|
||||
|
||||
stopwatch.start('data')
|
||||
stopwatch.stop('total')
|
||||
logging.info('timings: %s' % stopwatch)
|
||||
|
||||
mean_loss = [l / len(test_loader) for l in mean_loss]
|
||||
self.callback_test_stop(epoch, set_idx, mean_loss)
|
||||
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
|
||||
|
||||
# save metrics
|
||||
self.metric_save()
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
||||
return mean_loss
|
||||
|
||||
Reference in New Issue
Block a user