Add wandb
This commit is contained in:
parent
3da66347e7
commit
9ed0c264f5
@ -14,6 +14,7 @@ import json
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
class StopWatch(object):
|
class StopWatch(object):
|
||||||
@ -81,7 +82,7 @@ class ETA(object):
|
|||||||
|
|
||||||
class Worker(object):
|
class Worker(object):
|
||||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
||||||
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1, no_double_heads=True):
|
||||||
self.out_root = Path(out_root)
|
self.out_root = Path(out_root)
|
||||||
self.experiment_name = experiment_name
|
self.experiment_name = experiment_name
|
||||||
self.epochs = epochs
|
self.epochs = epochs
|
||||||
@ -93,6 +94,7 @@ class Worker(object):
|
|||||||
self.train_device = train_device
|
self.train_device = train_device
|
||||||
self.test_device = test_device
|
self.test_device = test_device
|
||||||
self.max_train_iter = max_train_iter
|
self.max_train_iter = max_train_iter
|
||||||
|
self.double_heads = no_double_heads
|
||||||
|
|
||||||
self.errs_list = []
|
self.errs_list = []
|
||||||
|
|
||||||
@ -372,6 +374,7 @@ class Worker(object):
|
|||||||
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||||
|
|
||||||
net = net.to(self.train_device)
|
net = net.to(self.train_device)
|
||||||
|
wandb.watch(net)
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
mean_loss = None
|
mean_loss = None
|
||||||
@ -396,6 +399,7 @@ class Worker(object):
|
|||||||
stopwatch.start('loss')
|
stopwatch.start('loss')
|
||||||
errs = self.loss_forward(output, train=True)
|
errs = self.loss_forward(output, train=True)
|
||||||
if isinstance(errs, dict):
|
if isinstance(errs, dict):
|
||||||
|
wandb.log(errs)
|
||||||
masks = errs['masks']
|
masks = errs['masks']
|
||||||
errs = errs['errs']
|
errs = errs['errs']
|
||||||
else:
|
else:
|
||||||
@ -442,6 +446,7 @@ class Worker(object):
|
|||||||
|
|
||||||
err_str = self.format_err_str(mean_loss)
|
err_str = self.format_err_str(mean_loss)
|
||||||
logging.info(f'avg train_loss={err_str}')
|
logging.info(f'avg train_loss={err_str}')
|
||||||
|
wandb.log({'mean_loss': mean_loss})
|
||||||
return mean_loss
|
return mean_loss
|
||||||
|
|
||||||
def callback_test_start(self, epoch, set_idx):
|
def callback_test_start(self, epoch, set_idx):
|
||||||
@ -495,6 +500,7 @@ class Worker(object):
|
|||||||
stopwatch.start('loss')
|
stopwatch.start('loss')
|
||||||
errs = self.loss_forward(output, train=False)
|
errs = self.loss_forward(output, train=False)
|
||||||
if isinstance(errs, dict):
|
if isinstance(errs, dict):
|
||||||
|
wandb.log(errs)
|
||||||
masks = errs['masks']
|
masks = errs['masks']
|
||||||
errs = errs['errs']
|
errs = errs['errs']
|
||||||
else:
|
else:
|
||||||
@ -529,4 +535,5 @@ class Worker(object):
|
|||||||
|
|
||||||
err_str = self.format_err_str(mean_loss)
|
err_str = self.format_err_str(mean_loss)
|
||||||
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
||||||
|
wandb.log({'Test loss': mean_loss})
|
||||||
return mean_loss
|
return mean_loss
|
||||||
|
Loading…
Reference in New Issue
Block a user