You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
connecting_the_dots/train_val.py

42 lines
1.1 KiB

import os
import torch
from model import exp_synph
from model import exp_synphge
from model import exp_synph_real
from model import networks
from co.args import parse_args
import wandb
# Log this run to Weights & Biases. The epochs/batch_size values set here are
# defaults; wandb.config.update(args, ...) below may overwrite them.
wandb.init(project="connecting_the_dots", entity="cpt-captain")
wandb.config.epochs = 100
wandb.config.batch_size = 3

# parse command-line arguments
args = parse_args()
# NOTE(review): `args.no_double_heads` is passed straight through as
# `double_head`. The flag name reads inverted, so presumably it is a
# store_false option (True unless given) -- confirm in co.args.parse_args.
double_head = args.no_double_heads
# allow_val_change lets values from `args` overwrite keys already set on
# wandb.config above.
wandb.config.update(args, allow_val_change=True)

# select the worker for the requested loss type
if args.loss == 'ph':
    worker = exp_synph.Worker(args)
elif args.loss == 'phge':
    worker = exp_synphge.Worker(args)
elif args.loss == 'phirl':
    worker = exp_synph_real.Worker(args)
else:
    # Without this branch an unrecognized loss left `worker` unbound and the
    # script failed later with an unhelpful NameError.
    raise ValueError(
        f"unknown loss type {args.loss!r}; expected 'ph', 'phge' or 'phirl'"
    )

# network input: the original image concatenated with its LCN image
channels_in = 2

# set up network; image sizes and multi-scale output setting come from the worker
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp,
                                imsizes=worker.imsizes, output_ms=worker.ms,
                                double_head=double_head)

# optimizer (learning rate was previously 1e-4)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

# start the work: hand the network and optimizer to the selected worker
worker.do(net, optimizer)