Add wandb, more args, increase learning rate
This commit is contained in:
parent
168516924e
commit
3da66347e7
18
train_val.py
18
train_val.py
@ -2,27 +2,41 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
from model import exp_synph
|
from model import exp_synph
|
||||||
from model import exp_synphge
|
from model import exp_synphge
|
||||||
|
from model import exp_synph_real
|
||||||
from model import networks
|
from model import networks
|
||||||
from co.args import parse_args
|
from co.args import parse_args
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
wandb.init(project="connecting_the_dots", entity="cpt-captain")
|
||||||
|
wandb.config.epochs = 100
|
||||||
|
wandb.config.batch_size = 3
|
||||||
|
|
||||||
# parse args
|
# parse args
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
double_head = args.no_double_heads
|
||||||
|
|
||||||
|
wandb.config.update(args, allow_val_change=True)
|
||||||
|
|
||||||
# loss types
|
# loss types
|
||||||
if args.loss == 'ph':
|
if args.loss == 'ph':
|
||||||
worker = exp_synph.Worker(args)
|
worker = exp_synph.Worker(args)
|
||||||
elif args.loss == 'phge':
|
elif args.loss == 'phge':
|
||||||
worker = exp_synphge.Worker(args)
|
worker = exp_synphge.Worker(args)
|
||||||
|
elif args.loss == 'phirl':
|
||||||
|
worker = exp_synph_real.Worker(args)
|
||||||
|
# double_head = False
|
||||||
|
|
||||||
|
|
||||||
# concatenation of original image and lcn image
|
# concatenation of original image and lcn image
|
||||||
channels_in = 2
|
channels_in = 2
|
||||||
|
|
||||||
# set up network
|
# set up network
|
||||||
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
|
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
|
||||||
output_ms=worker.ms)
|
output_ms=worker.ms, double_head=double_head)
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
||||||
|
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
|
||||||
|
|
||||||
# start the work
|
# start the work
|
||||||
worker.do(net, optimizer)
|
worker.do(net, optimizer)
|
||||||
|
Loading…
Reference in New Issue
Block a user