|
|
|
@ -2,27 +2,41 @@ import os |
|
|
|
|
import torch |
|
|
|
|
from model import exp_synph |
|
|
|
|
from model import exp_synphge |
|
|
|
|
from model import exp_synph_real |
|
|
|
|
from model import networks |
|
|
|
|
from co.args import parse_args |
|
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
wandb.init(project="connecting_the_dots", entity="cpt-captain") |
|
|
|
|
wandb.config.epochs = 100 |
|
|
|
|
wandb.config.batch_size = 3 |
|
|
|
|
|
|
|
|
|
# parse args |
|
|
|
|
args = parse_args() |
|
|
|
|
double_head = args.no_double_heads |
|
|
|
|
|
|
|
|
|
wandb.config.update(args, allow_val_change=True) |
|
|
|
|
|
|
|
|
|
# loss types |
|
|
|
|
if args.loss == 'ph': |
|
|
|
|
worker = exp_synph.Worker(args) |
|
|
|
|
elif args.loss == 'phge': |
|
|
|
|
worker = exp_synphge.Worker(args) |
|
|
|
|
elif args.loss == 'phirl': |
|
|
|
|
worker = exp_synph_real.Worker(args) |
|
|
|
|
# double_head = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# concatenation of original image and lcn image |
|
|
|
|
channels_in = 2 |
|
|
|
|
|
|
|
|
|
# set up network |
|
|
|
|
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, |
|
|
|
|
output_ms=worker.ms) |
|
|
|
|
output_ms=worker.ms, double_head=double_head) |
|
|
|
|
|
|
|
|
|
# optimizer |
|
|
|
|
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) |
|
|
|
|
# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) |
|
|
|
|
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4) |
|
|
|
|
|
|
|
|
|
# start the work |
|
|
|
|
worker.do(net, optimizer) |
|
|
|
|