From 9ed0c264f5e2b28763bbd1efffdf9363a4da9f7b Mon Sep 17 00:00:00 2001 From: "Cpt.Captain" Date: Tue, 22 Feb 2022 13:36:13 +0100 Subject: [PATCH] Add wandb --- torchext/worker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchext/worker.py b/torchext/worker.py index 4617a0a..270e455 100644 --- a/torchext/worker.py +++ b/torchext/worker.py @@ -14,6 +14,7 @@ import json import matplotlib.pyplot as plt import time from collections import OrderedDict +import wandb class StopWatch(object): @@ -81,7 +82,7 @@ class ETA(object): class Worker(object): def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, - num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1): + num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1, no_double_heads=True): self.out_root = Path(out_root) self.experiment_name = experiment_name self.epochs = epochs @@ -93,6 +94,7 @@ class Worker(object): self.train_device = train_device self.test_device = test_device self.max_train_iter = max_train_iter + self.double_heads = no_double_heads self.errs_list = [] @@ -372,6 +374,7 @@ class Worker(object): num_workers=self.num_workers, drop_last=True, pin_memory=False) net = net.to(self.train_device) + wandb.watch(net) net.train() mean_loss = None @@ -396,6 +399,7 @@ class Worker(object): stopwatch.start('loss') errs = self.loss_forward(output, train=True) if isinstance(errs, dict): + wandb.log(errs) masks = errs['masks'] errs = errs['errs'] else: @@ -442,6 +446,7 @@ class Worker(object): err_str = self.format_err_str(mean_loss) logging.info(f'avg train_loss={err_str}') + wandb.log({'mean_loss': mean_loss}) return mean_loss def callback_test_start(self, epoch, set_idx): @@ -495,6 +500,7 @@ class Worker(object): stopwatch.start('loss') errs = self.loss_forward(output, train=False) if isinstance(errs, dict): + wandb.log(errs) masks = errs['masks'] errs = errs['errs'] else: @@ -529,4 +535,5 @@ class Worker(object): err_str = self.format_err_str(mean_loss) logging.info(f'test epoch {epoch}: avg test_loss={err_str}') + wandb.log({'Test loss': mean_loss}) return mean_loss