"""Training entry point for the connecting_the_dots disparity network.

Starts a Weights & Biases run, parses command-line arguments, selects the
experiment worker for the requested ``--loss`` type, builds the
``DispEdgeDecoders`` network plus an Adam optimizer, and hands both to the
worker's ``do`` method.
"""
import os

import torch
import wandb

from co.args import parse_args
from model import exp_synph
from model import exp_synph_real
from model import exp_synphge
from model import networks

wandb.init(project="connecting_the_dots", entity="cpt-captain")
# NOTE(review): these values are recorded on the wandb config here; whether the
# worker actually reads them is not visible in this file — confirm.
wandb.config.epochs = 100
wandb.config.batch_size = 3

# parse args
args = parse_args()
# NOTE(review): `no_double_heads` is used directly as `double_head`. That is only
# correct if the flag is declared with action="store_false" (True unless
# --no_double_heads is passed) — verify in co/args.py.
double_head = args.no_double_heads
# allow_val_change permits args to overwrite config keys that were already set
# above, should args contain any of them.
wandb.config.update(args, allow_val_change=True)

# Map each supported --loss value to the experiment worker implementing it.
WORKERS = {
    'ph': exp_synph.Worker,
    'phge': exp_synphge.Worker,
    'phirl': exp_synph_real.Worker,
}
if args.loss not in WORKERS:
    # Fail fast with a clear message instead of an unbound `worker` NameError
    # further down.
    raise ValueError(
        f"unknown loss type {args.loss!r}; expected one of {sorted(WORKERS)}"
    )
worker = WORKERS[args.loss](args)

# concatenation of original image and lcn image
channels_in = 2

# set up network
net = networks.DispEdgeDecoders(
    channels_in=channels_in,
    max_disp=args.max_disp,
    imsizes=worker.imsizes,
    output_ms=worker.ms,
    double_head=double_head,
)

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

# start the work
worker.do(net, optimizer)