From 3da66347e7bcf7989b08517c7ca97bfd97594ca9 Mon Sep 17 00:00:00 2001
From: "Cpt.Captain"
Date: Tue, 22 Feb 2022 13:35:50 +0100
Subject: [PATCH] Add wandb, more args, increase learning rate

---
 train_val.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/train_val.py b/train_val.py
index 17f14f2..4601864 100644
--- a/train_val.py
+++ b/train_val.py
@@ -2,27 +2,41 @@ import os
 import torch
 from model import exp_synph
 from model import exp_synphge
+from model import exp_synph_real
 from model import networks
 from co.args import parse_args
+import wandb
+
+wandb.init(project="connecting_the_dots", entity="cpt-captain")
+wandb.config.epochs = 100
+wandb.config.batch_size = 3
 
 # parse args
 args = parse_args()
+double_head = args.no_double_heads
+
+wandb.config.update(args, allow_val_change=True)
 
 # loss types
 if args.loss == 'ph':
   worker = exp_synph.Worker(args)
 elif args.loss == 'phge':
   worker = exp_synphge.Worker(args)
+elif args.loss == 'phirl':
+  worker = exp_synph_real.Worker(args)
+  # double_head = False
+
 
 # concatenation of original image and lcn image
 channels_in = 2
 
 # set up network
 net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
-    output_ms=worker.ms)
+    output_ms=worker.ms, double_head=double_head)
 
 # optimizer
-optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
+# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
+optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
 
 # start the work
 worker.do(net, optimizer)
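
Note (not part of the patch): because wandb.init() now runs at import time in train_val.py, any metrics logged later in the run are attached to the run configured above, alongside the parsed args recorded via wandb.config.update(). The patch does not show how worker.do() logs metrics, so the following is only a minimal self-contained sketch of the kind of per-step logging that pairs with this setup; the "train/loss" and "epoch" keys and the dummy loss values are illustrative assumptions, not code from this repository.

    import wandb

    # Mirrors the init and config calls added by this patch.
    wandb.init(project="connecting_the_dots", entity="cpt-captain")
    wandb.config.epochs = 100
    wandb.config.batch_size = 3

    # Hypothetical per-epoch logging; in the real code the loss would come
    # from the training loop inside worker.do(net, optimizer).
    for epoch in range(wandb.config.epochs):
        dummy_loss = 1.0 / (epoch + 1)  # stand-in for the actual loss value
        wandb.log({"train/loss": dummy_loss, "epoch": epoch})

    wandb.finish()

Logging per step like this is what makes the learning-rate change (1e-4 to 5e-4) comparable across runs: the loss curves of old and new runs can be overlaid directly in the wandb project dashboard.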