diff --git a/model/networks.py b/model/networks.py index cd21bae..ea6d8ab 100644 --- a/model/networks.py +++ b/model/networks.py @@ -192,6 +192,8 @@ class DispNetS(TimedModule): def conv(self, in_planes, out_planes): return torch.nn.Sequential( torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), + # TODO try this + torch.nn.LayerNorm(out_planes), torch.nn.ReLU(inplace=True) )