|
|
@ -14,6 +14,7 @@ import json |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import time |
|
|
|
import time |
|
|
|
from collections import OrderedDict |
|
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StopWatch(object): |
|
|
|
class StopWatch(object): |
|
|
@ -81,7 +82,7 @@ class ETA(object): |
|
|
|
|
|
|
|
|
|
|
|
class Worker(object): |
|
|
|
class Worker(object): |
|
|
|
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, |
|
|
|
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, |
|
|
|
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1): |
|
|
|
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1, no_double_heads=True): |
|
|
|
self.out_root = Path(out_root) |
|
|
|
self.out_root = Path(out_root) |
|
|
|
self.experiment_name = experiment_name |
|
|
|
self.experiment_name = experiment_name |
|
|
|
self.epochs = epochs |
|
|
|
self.epochs = epochs |
|
|
@ -93,6 +94,7 @@ class Worker(object): |
|
|
|
self.train_device = train_device |
|
|
|
self.train_device = train_device |
|
|
|
self.test_device = test_device |
|
|
|
self.test_device = test_device |
|
|
|
self.max_train_iter = max_train_iter |
|
|
|
self.max_train_iter = max_train_iter |
|
|
|
|
|
|
|
self.double_heads = no_double_heads |
|
|
|
|
|
|
|
|
|
|
|
self.errs_list = [] |
|
|
|
self.errs_list = [] |
|
|
|
|
|
|
|
|
|
|
@ -372,6 +374,7 @@ class Worker(object): |
|
|
|
num_workers=self.num_workers, drop_last=True, pin_memory=False) |
|
|
|
num_workers=self.num_workers, drop_last=True, pin_memory=False) |
|
|
|
|
|
|
|
|
|
|
|
net = net.to(self.train_device) |
|
|
|
net = net.to(self.train_device) |
|
|
|
|
|
|
|
wandb.watch(net) |
|
|
|
net.train() |
|
|
|
net.train() |
|
|
|
|
|
|
|
|
|
|
|
mean_loss = None |
|
|
|
mean_loss = None |
|
|
@ -396,6 +399,7 @@ class Worker(object): |
|
|
|
stopwatch.start('loss') |
|
|
|
stopwatch.start('loss') |
|
|
|
errs = self.loss_forward(output, train=True) |
|
|
|
errs = self.loss_forward(output, train=True) |
|
|
|
if isinstance(errs, dict): |
|
|
|
if isinstance(errs, dict): |
|
|
|
|
|
|
|
wandb.log(errs) |
|
|
|
masks = errs['masks'] |
|
|
|
masks = errs['masks'] |
|
|
|
errs = errs['errs'] |
|
|
|
errs = errs['errs'] |
|
|
|
else: |
|
|
|
else: |
|
|
@ -442,6 +446,7 @@ class Worker(object): |
|
|
|
|
|
|
|
|
|
|
|
err_str = self.format_err_str(mean_loss) |
|
|
|
err_str = self.format_err_str(mean_loss) |
|
|
|
logging.info(f'avg train_loss={err_str}') |
|
|
|
logging.info(f'avg train_loss={err_str}') |
|
|
|
|
|
|
|
wandb.log({'mean_loss': mean_loss}) |
|
|
|
return mean_loss |
|
|
|
return mean_loss |
|
|
|
|
|
|
|
|
|
|
|
def callback_test_start(self, epoch, set_idx): |
|
|
|
def callback_test_start(self, epoch, set_idx): |
|
|
@ -495,6 +500,7 @@ class Worker(object): |
|
|
|
stopwatch.start('loss') |
|
|
|
stopwatch.start('loss') |
|
|
|
errs = self.loss_forward(output, train=False) |
|
|
|
errs = self.loss_forward(output, train=False) |
|
|
|
if isinstance(errs, dict): |
|
|
|
if isinstance(errs, dict): |
|
|
|
|
|
|
|
wandb.log(errs) |
|
|
|
masks = errs['masks'] |
|
|
|
masks = errs['masks'] |
|
|
|
errs = errs['errs'] |
|
|
|
errs = errs['errs'] |
|
|
|
else: |
|
|
|
else: |
|
|
@ -529,4 +535,5 @@ class Worker(object): |
|
|
|
|
|
|
|
|
|
|
|
err_str = self.format_err_str(mean_loss) |
|
|
|
err_str = self.format_err_str(mean_loss) |
|
|
|
logging.info(f'test epoch {epoch}: avg test_loss={err_str}') |
|
|
|
logging.info(f'test epoch {epoch}: avg test_loss={err_str}') |
|
|
|
|
|
|
|
wandb.log({'Test loss': mean_loss}) |
|
|
|
return mean_loss |
|
|
|
return mean_loss |
|
|
|