init
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from .dataset import *
|
||||
from .worker import *
|
||||
from .functions import *
|
||||
from .modules import *
|
||||
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
|
||||
class TestSet(object):
    """Named dataset bundled with a test-frequency value.

    The three constructor arguments are stored unchanged as the public
    attributes ``name``, ``dset`` and ``test_frequency``; how the frequency
    is interpreted is left to the code that consumes the object.
    """

    def __init__(self, name, dset, test_frequency=1):
        self.name, self.dset, self.test_frequency = name, dset, test_frequency
|
||||
|
||||
class TestSets(list):
    """List whose ``append`` builds a ``TestSet`` from its arguments."""

    def append(self, name, dset, test_frequency=1):
        """Create a ``TestSet`` from the arguments and add it to the end."""
        entry = TestSet(name, dset, test_frequency)
        super().append(entry)
|
||||
|
||||
|
||||
|
||||
class MultiDataset(torch.utils.data.Dataset):
    """Several datasets exposed as one flat, integer-indexable dataset.

    ``cum_n_samples[k]`` is the total number of samples in the first ``k``
    member datasets, so the list starts with 0 and ends with the overall
    length.
    """

    def __init__(self, *datasets):
        self.current_epoch = 0

        self.datasets = []
        self.cum_n_samples = [0]

        for dataset in datasets:
            self.append(dataset)

    def append(self, dataset):
        """Add ``dataset`` at the end and extend the cumulative counts."""
        self.datasets.append(dataset)
        self.__update_cum_n_samples(dataset)

    def __update_cum_n_samples(self, dataset):
        n_samples = self.cum_n_samples[-1] + len(dataset)
        self.cum_n_samples.append(n_samples)

    def dataset_updated(self):
        """Rebuild the cumulative counts from the current member lengths."""
        self.cum_n_samples = [0]
        for dset in self.datasets:
            self.__update_cum_n_samples(dset)

    def __len__(self):
        return self.cum_n_samples[-1]

    def __getitem__(self, idx):
        """Return sample ``idx`` of the concatenation.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_total = len(self)
        if idx < 0:
            idx += n_total
        if not 0 <= idx < n_total:
            raise IndexError('index out of range for MultiDataset')

        # side='right' skips over empty member datasets (repeated counts).
        didx = int(np.searchsorted(self.cum_n_samples, idx, side='right')) - 1
        sidx = idx - self.cum_n_samples[didx]
        return self.datasets[didx][sidx]
|
||||
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
    """Dataset base class that provides a per-sample random generator.

    Subclasses must implement ``__len__``; ``get_rng`` uses it to derive
    seeds in training mode.
    """

    def __init__(self, train=True, fix_seed_per_epoch=False):
        self.current_epoch = 0
        self.train = train
        self.fix_seed_per_epoch = fix_seed_per_epoch

    def get_rng(self, idx):
        """Return a ``np.random.RandomState`` seeded for sample ``idx``.

        In training mode the seed is ``k * len(self) + idx``, where ``k`` is
        1 when ``fix_seed_per_epoch`` is set and ``current_epoch + 1``
        otherwise. Outside training mode the seed is ``idx`` itself.

        The seed is reduced modulo 2**32 because ``RandomState`` only accepts
        seeds in ``[0, 2**32 - 1]``; seeds already in that range are used
        unchanged.
        """
        if self.train:
            epoch_factor = 1 if self.fix_seed_per_epoch else self.current_epoch + 1
            seed = epoch_factor * len(self) + idx
        else:
            seed = idx
        return np.random.RandomState(int(seed) % (2 ** 32))
|
||||
@@ -0,0 +1,10 @@
|
||||
#ifndef TYPES_H
#define TYPES_H

// CPU_GPU_FUNCTION marks a function as callable from host and device code
// when the translation unit is compiled by nvcc, and expands to nothing for
// a plain C++ compiler.
//
// The switch is __CUDACC__ (defined for the whole nvcc compilation) rather
// than __CUDA_ARCH__ (defined only during the device pass): a function's
// qualifiers must be the same in nvcc's host and device passes.
#ifdef __CUDACC__
#define CPU_GPU_FUNCTION __host__ __device__
#else
#define CPU_GPU_FUNCTION
#endif

#endif
|
||||
@@ -0,0 +1,135 @@
|
||||
#ifndef COMMON_H
|
||||
#define COMMON_H
|
||||
|
||||
#include "co_types.h"
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
|
||||
// Place inside a class body to delete its copy constructor and copy
// assignment operator. The macro switches the access level to private, so
// any members declared after it are private unless re-labelled.
#define DISABLE_COPY_AND_ASSIGN(classname)              \
 private:                                               \
  classname(const classname&) = delete;                 \
  classname& operator=(const classname&) = delete;
|
||||
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void fill(T* arr, int N, T val) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
arr[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void fill_zero(T* arr, int N) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
arr[idx] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T distance_euclidean(const T* q, const T* t, int N) {
|
||||
T out = 0;
|
||||
for(int idx = 0; idx < N; idx++) {
|
||||
T diff = q[idx] - t[idx];
|
||||
out += diff * diff;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T distance_l2(const T* q, const T* t, int N) {
|
||||
T out = distance_euclidean(q, t, N);
|
||||
out = std::sqrt(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct FillFunctor {
|
||||
T* arr;
|
||||
const T val;
|
||||
|
||||
FillFunctor(T* arr, const T val) : arr(arr), val(val) {}
|
||||
CPU_GPU_FUNCTION void operator()(const int idx) {
|
||||
arr[idx] = val;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mmin(const T& a, const T& b) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return min(a, b);
|
||||
#else
|
||||
return std::min(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mmax(const T& a, const T& b) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return max(a, b);
|
||||
#else
|
||||
return std::max(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mround(const T& a) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return round(a);
|
||||
#else
|
||||
return round(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// atomicAdd for double built on a 64-bit compare-and-swap loop, compiled only
// for device architectures below 600. Returns the value that was stored at
// `address` before the addition.
//
// Declared static inline because this header may be included by several
// translation units; a plain definition would be emitted once per unit.
static __inline__ __device__ double atomicAdd(double* address, double val)
{
  unsigned long long int* address_as_ull =
      (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val +
                                         __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void matomic_add(T* addr, T val) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
atomicAdd(addr, val);
|
||||
#else
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp atomic
|
||||
#endif
|
||||
*addr += val;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,173 @@
|
||||
#ifndef COMMON_CUDA
|
||||
#define COMMON_CUDA
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <stdio.h>
|
||||
|
||||
// Set to 1 to print a trace line from the memory helpers in this header.
#define DEBUG 0
// Set to 1 to call cudaDeviceSynchronize() before every checked call.
#define CUDA_DEBUG_DEVICE_SYNC 0

// cuda check for cudaMalloc and so on
// Evaluates `condition` exactly once; if the result is not cudaSuccess the
// error string, file and line are printed and the process exits with -1.
// The do/while block keeps the local out of the caller's scope, and the local
// has a reserved-looking name so an argument that mentions a variable called
// `error` is not captured by it.
#define CUDA_CHECK(condition) \
  do { \
    if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
    const cudaError_t cuda_check_error_ = (condition); \
    if(cuda_check_error_ != cudaSuccess) { \
      printf("%s in %s at %d\n", cudaGetErrorString(cuda_check_error_), __FILE__, __LINE__); \
      exit(-1); \
    } \
  } while (0)
|
||||
|
||||
/// Get error string for error code.
|
||||
/// @param error
|
||||
inline const char* cublasGetErrorString(cublasStatus_t error) {
|
||||
switch (error) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "CUBLAS_STATUS_SUCCESS";
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
}
|
||||
return "Unknown cublas status";
|
||||
}
|
||||
|
||||
// Evaluates `condition` (a cuBLAS call) exactly once; if the result is not
// CUBLAS_STATUS_SUCCESS the status name, file and line are printed and the
// process exits with -1. The local has a reserved-looking name so an argument
// that mentions a variable called `status` is not captured by it.
#define CUBLAS_CHECK(condition) \
  do { \
    if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
    const cublasStatus_t cublas_check_status_ = (condition); \
    if(cublas_check_status_ != CUBLAS_STATUS_SUCCESS) { \
      printf("%s in %s at %d\n", cublasGetErrorString(cublas_check_status_), __FILE__, __LINE__); \
      exit(-1); \
    } \
  } while (0)
|
||||
|
||||
// check if there is a error after kernel execution
// Both checks sit in one brace block so the macro expands to a single
// statement (e.g. under an un-braced `if`). Call sites may keep or omit the
// trailing semicolon.
#define CUDA_POST_KERNEL_CHECK \
  { \
    CUDA_CHECK(cudaPeekAtLastError()); \
    CUDA_CHECK(cudaGetLastError()); \
  }
|
||||
|
||||
// Loop header for use inside a kernel: `i` starts at this thread's global
// index (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total
// number of launched threads until it reaches n.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)
|
||||
|
||||
// Default number of threads per block used by the launch helpers.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks of N_THREADS threads needed to cover N items
// (integer ceiling of N / N_THREADS).
inline int GET_BLOCKS(const int N, const int N_THREADS=CUDA_NUM_THREADS) {
  const int rounded_up = N + N_THREADS - 1;
  return rounded_up / N_THREADS;
}
|
||||
|
||||
// Allocates room for N elements of T with cudaMalloc and returns the pointer.
// A failed allocation is handled by CUDA_CHECK.
template<typename T>
T* device_malloc(long N) {
  T* device_ptr;
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMalloc(&device_ptr, n_bytes));
  if(DEBUG) { printf("[DEBUG] device_malloc %p, %ld\n", device_ptr, N); }
  return device_ptr;
}
|
||||
|
||||
// Releases memory with cudaFree; errors are handled by CUDA_CHECK.
template<typename T>
void device_free(T* dptr) {
  if(DEBUG) {
    printf("[DEBUG] device_free %p\n", dptr);
  }
  CUDA_CHECK(cudaFree(dptr));
}
|
||||
|
||||
// Copies N elements of T from host memory hptr to device memory dptr
// (cudaMemcpy, host-to-device).
template<typename T>
void host_to_device(const T* hptr, T* dptr, long N) {
  if(DEBUG) { printf("[DEBUG] host_to_device %p => %p, %ld\n", hptr, dptr, N); }
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMemcpy(dptr, hptr, n_bytes, cudaMemcpyHostToDevice));
}
|
||||
|
||||
// Allocates a device buffer for N elements, fills it from hptr and returns
// it. The caller is responsible for releasing it (see device_free).
template<typename T>
T* host_to_device_malloc(const T* hptr, long N) {
  T* const device_copy = device_malloc<T>(N);
  host_to_device(hptr, device_copy, N);
  return device_copy;
}
|
||||
|
||||
// Copies N elements of T from device memory dptr to host memory hptr
// (cudaMemcpy, device-to-host).
template<typename T>
void device_to_host(const T* dptr, T* hptr, long N) {
  if(DEBUG) { printf("[DEBUG] device_to_host %p => %p, %ld\n", dptr, hptr, N); }
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMemcpy(hptr, dptr, n_bytes, cudaMemcpyDeviceToHost));
}
|
||||
|
||||
// Allocates a host array of N elements with new[], fills it from device
// memory dptr and returns it. The caller owns the array (delete[]).
template<typename T>
T* device_to_host_malloc(const T* dptr, long N) {
  T* const host_copy = new T[N];
  device_to_host(dptr, host_copy, N);
  return host_copy;
}
|
||||
|
||||
// Copies N elements of T from dptr to hptr with cudaMemcpyDeviceToDevice.
// Despite its name, hptr is the destination of a device-to-device copy.
template<typename T>
void device_to_device(const T* dptr, T* hptr, long N) {
  if(DEBUG) { printf("[DEBUG] device_to_device %p => %p, %ld\n", dptr, hptr, N); }
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMemcpy(hptr, dptr, n_bytes, cudaMemcpyDeviceToDevice));
}
|
||||
|
||||
// https://github.com/parallel-forall/code-samples/blob/master/posts/cuda-aware-mpi-example/src/Device.cu
// https://github.com/treecode/Bonsai/blob/master/runtime/profiling/derived_atomic_functions.h
// Atomically raises *address to value when value is larger, using an integer
// compare-and-swap on the float's bit pattern. Nothing is written when the
// stored value is already >= value.
__device__ __forceinline__ void atomicMaxF(float * const address, const float value) {
  // cheap early exit on a plain read
  if (*address >= value) {
    return;
  }

  int * const bits = (int *)address;
  int observed = *bits;
  int expected;

  do {
    expected = observed;
    // another thread may have stored something at least as large meanwhile
    if (__int_as_float(expected) >= value) {
      break;
    }

    observed = atomicCAS(bits, expected, __float_as_int(value));
  } while (expected != observed);
}
|
||||
|
||||
// Atomically lowers *address to value when value is smaller, using an integer
// compare-and-swap on the float's bit pattern. Nothing is written when the
// stored value is already <= value.
__device__ __forceinline__ void atomicMinF(float * const address, const float value) {
  // cheap early exit on a plain read
  if (*address <= value) {
    return;
  }

  int * const bits = (int *)address;
  int observed = *bits;
  int expected;

  do {
    expected = observed;
    // another thread may have stored something at least as small meanwhile
    if (__int_as_float(expected) <= value) {
      break;
    }

    observed = atomicCAS(bits, expected, __float_as_int(value));
  } while (expected != observed);
}
|
||||
|
||||
|
||||
// Kernel that calls functor(idx) for every idx in [0, N). Each thread starts
// at its global index and advances by the total number of launched threads.
template <typename FunctorT>
__global__ void iterate_kernel(FunctorT functor, int N) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < (N); idx += blockDim.x * gridDim.x) {
    functor(idx);
  }
}
|
||||
|
||||
// Launches iterate_kernel so that functor(idx) runs for every idx in [0, N),
// using N_THREADS threads per block, then checks for launch errors.
//
// Nothing is launched for N <= 0: a grid with zero blocks is an invalid
// launch configuration, and the post-kernel check would terminate the
// process for what is merely an empty input.
template <typename FunctorT>
void iterate_cuda(FunctorT functor, int N, int N_THREADS=CUDA_NUM_THREADS) {
  if(N <= 0) {
    return;
  }
  iterate_kernel<<<GET_BLOCKS(N, N_THREADS), N_THREADS>>>(functor, N);
  CUDA_POST_KERNEL_CHECK;
}
|
||||
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,347 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
|
||||
// Tensor argument checks. The macro argument is parenthesized so that an
// expression argument binds as a whole; the message still shows the argument
// as written at the call site.
#define CHECK_CUDA(x) AT_ASSERTM((x).type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")

// CPU inputs only have to be contiguous.
#define CHECK_INPUT_CPU(x) CHECK_CONTIGUOUS(x)
// NOTE: expands to two statements; do not use as the body of an un-braced if.
#define CHECK_INPUT_CUDA(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
|
||||
template <typename T, int dim=3>
|
||||
struct NNFunctor {
|
||||
const T* in0; // nelem0 x dim
|
||||
const T* in1; // nelem1 x dim
|
||||
const long nelem0;
|
||||
const long nelem1;
|
||||
long* out; // nelem0
|
||||
|
||||
NNFunctor(const T* in0, const T* in1, long nelem0, long nelem1, long* out) : in0(in0), in1(in1), nelem0(nelem0), nelem1(nelem1), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long idx0) {
|
||||
// idx0 \in [nelem0]
|
||||
|
||||
const T* vec0 = in0 + idx0 * dim;
|
||||
|
||||
T min_dist = 1e9;
|
||||
long min_arg = -1;
|
||||
for(long idx1 = 0; idx1 < nelem1; ++idx1) {
|
||||
const T* vec1 = in1 + idx1 * dim;
|
||||
T dist = 0;
|
||||
for(long didx = 0; didx < dim; ++didx) {
|
||||
T diff = vec0[didx] - vec1[didx];
|
||||
dist += diff * diff;
|
||||
}
|
||||
|
||||
if(dist < min_dist) {
|
||||
min_dist = dist;
|
||||
min_arg = idx1;
|
||||
}
|
||||
}
|
||||
|
||||
out[idx0] = min_arg;
|
||||
}
|
||||
};
|
||||
|
||||
struct CrossCheckFunctor {
|
||||
const long* in0; // nelem0
|
||||
const long* in1; // nelem1
|
||||
const long nelem0;
|
||||
const long nelem1;
|
||||
uint8_t* out; // nelem0
|
||||
|
||||
CrossCheckFunctor(const long* in0, const long* in1, long nelem0, long nelem1, uint8_t* out) : in0(in0), in1(in1), nelem0(nelem0), nelem1(nelem1), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long idx0) {
|
||||
// idx0 \in [nelem0]
|
||||
int idx1 = in0[idx0];
|
||||
out[idx0] = idx1 >=0 && in1[idx1] >= 0 && idx0 == in1[idx1];
|
||||
// out[idx0] = idx0 == in1[in0[idx0]];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int dim=3>
|
||||
struct ProjNNFunctor {
|
||||
// xyz0, xyz1 in coord sys of 1
|
||||
const T* xyz0; // bs x height x width x 3
|
||||
const T* xyz1; // bs x height x width x 3
|
||||
const T* K; // 3 x 3
|
||||
const long batch_size;
|
||||
const long height;
|
||||
const long width;
|
||||
const long patch_size;
|
||||
long* out; // bs x height x width
|
||||
|
||||
ProjNNFunctor(const T* xyz0, const T* xyz1, const T* K, long batch_size, long height, long width, long patch_size, long* out)
|
||||
: xyz0(xyz0), xyz1(xyz1), K(K), batch_size(batch_size), height(height), width(width), patch_size(patch_size), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long idx0) {
|
||||
// idx0 \in [0, bs x height x width]
|
||||
|
||||
const long bs = idx0 / (height * width);
|
||||
|
||||
const T x = xyz0[idx0 * 3 + 0];
|
||||
const T y = xyz0[idx0 * 3 + 1];
|
||||
const T z = xyz0[idx0 * 3 + 2];
|
||||
const T d = K[6] * x + K[7] * y + K[8] * z;
|
||||
const T u = (K[0] * x + K[1] * y + K[2] * z) / d;
|
||||
const T v = (K[3] * x + K[4] * y + K[5] * z) / d;
|
||||
|
||||
int u0 = u + 0.5;
|
||||
int v0 = v + 0.5;
|
||||
|
||||
long min_idx1 = -1;
|
||||
T min_dist = 1e9;
|
||||
for(int pidx = 0; pidx < patch_size*patch_size; ++pidx) {
|
||||
int pu = pidx % patch_size;
|
||||
int pv = pidx / patch_size;
|
||||
|
||||
int u1 = u0 + pu - patch_size/2;
|
||||
int v1 = v0 + pv - patch_size/2;
|
||||
|
||||
if(u1 >= 0 && v1 >= 0 && u1 < width && v1 < height) {
|
||||
const long idx1 = (bs * height + v1) * width + u1;
|
||||
const T* xyz1n = xyz1 + idx1 * 3;
|
||||
const T d = (x-xyz1n[0]) * (x-xyz1n[0]) + (y-xyz1n[1]) * (y-xyz1n[1]) + (z-xyz1n[2]) * (z-xyz1n[2]);
|
||||
if(d < min_dist) {
|
||||
min_dist = d;
|
||||
min_idx1 = idx1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out[idx0] = min_idx1;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T, int dim=3>
|
||||
struct XCorrVolFunctor {
|
||||
const T* in0; // channels x height x width
|
||||
const T* in1; // channels x height x width
|
||||
const long channels;
|
||||
const long height;
|
||||
const long width;
|
||||
const long n_disps;
|
||||
const long block_size;
|
||||
T* out; // nelem0
|
||||
|
||||
XCorrVolFunctor(const T* in0, const T* in1, long channels, long height, long width, long n_disps, long block_size, T* out) : in0(in0), in1(in1), channels(channels), height(height), width(width), n_disps(n_disps), block_size(block_size), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long oidx) {
|
||||
// idx0 \in [n_disps x height x width]
|
||||
|
||||
auto d = oidx / (height * width);
|
||||
auto h = (oidx / width) % height;
|
||||
auto w = oidx % width;
|
||||
|
||||
long block_size2 = block_size * block_size;
|
||||
|
||||
T val = 0;
|
||||
for(int c = 0; c < channels; ++c) {
|
||||
// compute means
|
||||
T mu0 = 0;
|
||||
T mu1 = 0;
|
||||
for(int bh = 0; bh < block_size; ++bh) {
|
||||
long h0 = h + bh - block_size / 2;
|
||||
h0 = mmax(long(0), mmin(height-1, h0));
|
||||
for(int bw = 0; bw < block_size; ++bw) {
|
||||
long w0 = w + bw - block_size / 2;
|
||||
long w1 = w0 - d;
|
||||
w0 = mmax(long(0), mmin(width-1, w0));
|
||||
w1 = mmax(long(0), mmin(width-1, w1));
|
||||
long idx0 = (c * height + h0) * width + w0;
|
||||
long idx1 = (c * height + h0) * width + w1;
|
||||
mu0 += in0[idx0] / block_size2;
|
||||
mu1 += in1[idx1] / block_size2;
|
||||
}
|
||||
}
|
||||
|
||||
// compute stds and dot product
|
||||
T sigma0 = 0;
|
||||
T sigma1 = 0;
|
||||
T dot = 0;
|
||||
for(int bh = 0; bh < block_size; ++bh) {
|
||||
long h0 = h + bh - block_size / 2;
|
||||
h0 = mmax(long(0), mmin(height-1, h0));
|
||||
for(int bw = 0; bw < block_size; ++bw) {
|
||||
long w0 = w + bw - block_size / 2;
|
||||
long w1 = w0 - d;
|
||||
w0 = mmax(long(0), mmin(width-1, w0));
|
||||
w1 = mmax(long(0), mmin(width-1, w1));
|
||||
long idx0 = (c * height + h0) * width + w0;
|
||||
long idx1 = (c * height + h0) * width + w1;
|
||||
T v0 = in0[idx0] - mu0;
|
||||
T v1 = in1[idx1] - mu1;
|
||||
|
||||
dot += v0 * v1;
|
||||
sigma0 += v0 * v0;
|
||||
sigma1 += v1 * v1;
|
||||
}
|
||||
}
|
||||
|
||||
T norm = sqrt(sigma0 * sigma1) + 1e-8;
|
||||
val += dot / norm;
|
||||
}
|
||||
|
||||
out[oidx] = val;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
const int PHOTOMETRIC_LOSS_MSE = 0;
|
||||
const int PHOTOMETRIC_LOSS_SAD = 1;
|
||||
const int PHOTOMETRIC_LOSS_CENSUS_MSE = 2;
|
||||
const int PHOTOMETRIC_LOSS_CENSUS_SAD = 3;
|
||||
|
||||
template <typename T, int type>
|
||||
struct PhotometricLossForward {
|
||||
const T* es; // batch_size x channels x height x width;
|
||||
const T* ta;
|
||||
const int block_size;
|
||||
const int block_size2;
|
||||
const T eps;
|
||||
const int batch_size;
|
||||
const int channels;
|
||||
const int height;
|
||||
const int width;
|
||||
T* out; // batch_size x channels x height x width;
|
||||
|
||||
PhotometricLossForward(const T* es, const T* ta, int block_size, T eps, int batch_size, int channels, int height, int width, T* out) :
|
||||
es(es), ta(ta), block_size(block_size), block_size2(block_size*block_size), eps(eps), batch_size(batch_size), channels(channels), height(height), width(width), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(int outidx) {
|
||||
// outidx \in [0, batch_size x height x width]
|
||||
|
||||
int w = outidx % width;
|
||||
int h = (outidx / width) % height;
|
||||
int n = outidx / (height * width);
|
||||
|
||||
T loss = 0;
|
||||
for(int bidx = 0; bidx < block_size2; ++bidx) {
|
||||
int bh = bidx / block_size;
|
||||
int bw = bidx % block_size;
|
||||
int h0 = h + bh - block_size / 2;
|
||||
int w0 = w + bw - block_size / 2;
|
||||
|
||||
h0 = mmin(height-1, mmax(0, h0));
|
||||
w0 = mmin(width-1, mmax(0, w0));
|
||||
|
||||
for(int c = 0; c < channels; ++c) {
|
||||
int inidx = ((n * channels + c) * height + h0) * width + w0;
|
||||
if(type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_MSE) {
|
||||
T diff = es[inidx] - ta[inidx];
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
loss += diff * diff / block_size2;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
loss += fabs(diff) / block_size2;
|
||||
}
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
int inidxc = ((n * channels + c) * height + h) * width + w;
|
||||
T des = es[inidx] - es[inidxc];
|
||||
T dta = ta[inidx] - ta[inidxc];
|
||||
T h_des = 0.5 * (1 + des / sqrt(des * des + eps));
|
||||
T h_dta = 0.5 * (1 + dta / sqrt(dta * dta + eps));
|
||||
T diff = h_des - h_dta;
|
||||
// printf("%d,%d %d,%d: des=%f, dta=%f, h_des=%f, h_dta=%f, diff=%f\n", h,w, h0,w0, des,dta, h_des,h_dta, diff);
|
||||
// printf("%d,%d %d,%d: h_des=%f = 0.5 * (1 + %f / %f); %f, %f, %f\n", h,w, h0,w0, h_des, des, sqrt(des * des + eps), des*des, des*des+eps, eps);
|
||||
if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
loss += diff * diff / block_size2;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
loss += fabs(diff) / block_size2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out[outidx] = loss;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int type>
|
||||
struct PhotometricLossBackward {
|
||||
const T* es; // batch_size x channels x height x width;
|
||||
const T* ta;
|
||||
const T* grad_out;
|
||||
const int block_size;
|
||||
const int block_size2;
|
||||
const T eps;
|
||||
const int batch_size;
|
||||
const int channels;
|
||||
const int height;
|
||||
const int width;
|
||||
T* grad_in; // batch_size x channels x height x width;
|
||||
|
||||
PhotometricLossBackward(const T* es, const T* ta, const T* grad_out, int block_size, T eps, int batch_size, int channels, int height, int width, T* grad_in) :
|
||||
es(es), ta(ta), grad_out(grad_out), block_size(block_size), block_size2(block_size*block_size), eps(eps), batch_size(batch_size), channels(channels), height(height), width(width), grad_in(grad_in) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(int outidx) {
|
||||
// outidx \in [0, batch_size x height x width]
|
||||
|
||||
int w = outidx % width;
|
||||
int h = (outidx / width) % height;
|
||||
int n = outidx / (height * width);
|
||||
|
||||
for(int bidx = 0; bidx < block_size2; ++bidx) {
|
||||
int bh = bidx / block_size;
|
||||
int bw = bidx % block_size;
|
||||
int h0 = h + bh - block_size / 2;
|
||||
int w0 = w + bw - block_size / 2;
|
||||
|
||||
h0 = mmin(height-1, mmax(0, h0));
|
||||
w0 = mmin(width-1, mmax(0, w0));
|
||||
|
||||
const T go = grad_out[outidx];
|
||||
|
||||
for(int c = 0; c < channels; ++c) {
|
||||
int inidx = ((n * channels + c) * height + h0) * width + w0;
|
||||
if(type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_MSE) {
|
||||
T diff = es[inidx] - ta[inidx];
|
||||
T grad = 0;
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
grad = 2 * diff;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
grad = diff < 0 ? -1 : (diff > 0 ? 1 : 0);
|
||||
}
|
||||
grad = grad / block_size2 * go;
|
||||
matomic_add(grad_in + inidx, grad);
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
int inidxc = ((n * channels + c) * height + h) * width + w;
|
||||
T des = es[inidx] - es[inidxc];
|
||||
T dta = ta[inidx] - ta[inidxc];
|
||||
T h_des = 0.5 * (1 + des / sqrt(des * des + eps));
|
||||
T h_dta = 0.5 * (1 + dta / sqrt(dta * dta + eps));
|
||||
T diff = h_des - h_dta;
|
||||
|
||||
T grad_loss = 0;
|
||||
if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
grad_loss = 2 * diff;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
grad_loss = diff < 0 ? -1 : (diff > 0 ? 1 : 0);
|
||||
}
|
||||
grad_loss = grad_loss / block_size2;
|
||||
|
||||
T tmp = des * des + eps;
|
||||
T grad_heaviside = 0.5 * eps / sqrt(tmp * tmp * tmp);
|
||||
|
||||
T grad = go * grad_loss * grad_heaviside;
|
||||
matomic_add(grad_in + inidx, grad);
|
||||
matomic_add(grad_in + inidxc, -grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ext.h"
|
||||
|
||||
// Calls functor(idx) for idx = 0 .. N-1 in ascending order on the calling
// thread. N is `long` so that the 64-bit element counts passed by the callers
// in this file are not truncated.
template <typename FunctorT>
void iterate_cpu(FunctorT functor, long N) {
  for(long idx = 0; idx < N; ++idx) {
    functor(idx);
  }
}
|
||||
|
||||
// For every row of in0 (N0 x 3) returns the index of the nearest row of in1
// (N1 x 3) as a CPU int64 tensor of length N0; see NNFunctor.
at::Tensor nn_cpu(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CPU(in0);
  CHECK_INPUT_CPU(in1);

  // Check the rank first so that size(1) below is known to exist.
  AT_ASSERTM(in0.dim() == 2, "in0 has to be N0 x 3");
  AT_ASSERTM(in1.dim() == 2, "in1 has to be N1 x 3");

  auto nelem0 = in0.size(0);
  auto nelem1 = in1.size(0);
  auto dim = in0.size(1);

  AT_ASSERTM(dim == in1.size(1), "in0 and in1 have to be the same shape");
  AT_ASSERTM(dim == 3, "dim hast to be 3");

  auto out = at::empty({nelem0}, torch::CPU(at::kLong));

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "nn", ([&] {
    iterate_cpu(
        NNFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), nelem0, nelem1, out.data<long>()),
        nelem0);
  }));

  return out;
}
|
||||
|
||||
|
||||
// Returns a CPU uint8 tensor with one entry per element of in0 that is 1 when
// the match in0 -> in1 is confirmed by in1 -> in0; see CrossCheckFunctor.
// Both inputs are 1-D index tensors.
at::Tensor crosscheck_cpu(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CPU(in0);
  CHECK_INPUT_CPU(in1);

  AT_ASSERTM(in0.dim() == 1, "");
  AT_ASSERTM(in1.dim() == 1, "");

  auto nelem0 = in0.size(0);
  auto nelem1 = in1.size(0);

  auto out = at::empty({nelem0}, torch::CPU(at::kByte));

  iterate_cpu(
      CrossCheckFunctor(in0.data<long>(), in1.data<long>(), nelem0, nelem1, out.data<uint8_t>()),
      nelem0);

  return out;
}
|
||||
|
||||
|
||||
// Projective nearest-neighbour search on the CPU; see ProjNNFunctor.
// xyz0 and xyz1 are bs x height x width x 3, K supplies at least the 9
// entries of a 3 x 3 matrix. Returns a CPU int64 tensor bs x height x width.
at::Tensor proj_nn_cpu(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size) {
  CHECK_INPUT_CPU(xyz0);
  CHECK_INPUT_CPU(xyz1);
  CHECK_INPUT_CPU(K);

  // Check the rank first so that the size(3) reads below are valid.
  AT_ASSERTM(xyz0.dim() == 4, "");
  AT_ASSERTM(xyz1.dim() == 4, "");

  auto batch_size = xyz0.size(0);
  auto height = xyz0.size(1);
  auto width = xyz0.size(2);

  AT_ASSERTM(xyz0.size(0) == xyz1.size(0), "");
  AT_ASSERTM(xyz0.size(1) == xyz1.size(1), "");
  AT_ASSERTM(xyz0.size(2) == xyz1.size(2), "");
  AT_ASSERTM(xyz0.size(3) == xyz1.size(3), "");
  AT_ASSERTM(xyz0.size(3) == 3, "");
  // The functor reads K[0] .. K[8].
  AT_ASSERTM(K.numel() >= 9, "K has to hold a 3 x 3 matrix");

  auto out = at::empty({batch_size, height, width}, torch::CPU(at::kLong));

  AT_DISPATCH_FLOATING_TYPES(xyz0.scalar_type(), "proj_nn", ([&] {
    iterate_cpu(
        ProjNNFunctor<scalar_t>(xyz0.data<scalar_t>(), xyz1.data<scalar_t>(), K.data<scalar_t>(), batch_size, height, width, patch_size, out.data<long>()),
        batch_size * height * width);
  }));

  return out;
}
|
||||
|
||||
|
||||
// Builds an n_disps x height x width correlation volume from two
// channels x height x width tensors on the CPU; see XCorrVolFunctor.
at::Tensor xcorrvol_cpu(at::Tensor in0, at::Tensor in1, int n_disps, int block_size) {
  CHECK_INPUT_CPU(in0);
  CHECK_INPUT_CPU(in1);

  // The functor indexes both inputs with in0's extents.
  AT_ASSERTM(in0.dim() == 3, "in0 has to be channels x height x width");
  AT_ASSERTM(in0.sizes().equals(in1.sizes()), "in0 and in1 have to be the same shape");

  auto channels = in0.size(0);
  auto height = in0.size(1);
  auto width = in0.size(2);

  auto out = at::empty({n_disps, height, width}, in0.options());

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "xcorrvol", ([&] {
    iterate_cpu(
        XCorrVolFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), channels, height, width, n_disps, block_size, out.data<scalar_t>()),
        n_disps * height * width);
  }));

  return out;
}
|
||||
|
||||
|
||||
|
||||
|
||||
// CPU forward pass of the photometric loss; see PhotometricLossForward.
// es and ta are batch_size x channels x height x width; `type` is one of the
// PHOTOMETRIC_LOSS_* constants. Returns batch_size x 1 x height x width.
at::Tensor photometric_loss_forward(at::Tensor es, at::Tensor ta, int block_size, int type, float eps) {
  CHECK_INPUT_CPU(es);
  CHECK_INPUT_CPU(ta);

  AT_ASSERTM(es.dim() == 4, "es has to be batch_size x channels x height x width");
  AT_ASSERTM(es.sizes().equals(ta.sizes()), "es and ta have to be the same shape");
  // Reject unknown types up front; otherwise `out` would be returned
  // uninitialized.
  AT_ASSERTM(type == PHOTOMETRIC_LOSS_MSE || type == PHOTOMETRIC_LOSS_SAD ||
             type == PHOTOMETRIC_LOSS_CENSUS_MSE || type == PHOTOMETRIC_LOSS_CENSUS_SAD,
             "unknown photometric loss type");

  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  auto out = at::empty({batch_size, 1, height, width}, es.options());

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_forward_cpu", ([&] {
    const scalar_t* es_ptr = es.data<scalar_t>();
    const scalar_t* ta_ptr = ta.data<scalar_t>();
    scalar_t* out_ptr = out.data<scalar_t>();
    // `type` is a template parameter of the functor, so each value needs its
    // own instantiation.
    switch(type) {
      case PHOTOMETRIC_LOSS_MSE:
        iterate_cpu(
            PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es_ptr, ta_ptr, block_size, eps, batch_size, channels, height, width, out_ptr),
            out.numel());
        break;
      case PHOTOMETRIC_LOSS_SAD:
        iterate_cpu(
            PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es_ptr, ta_ptr, block_size, eps, batch_size, channels, height, width, out_ptr),
            out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_MSE:
        iterate_cpu(
            PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es_ptr, ta_ptr, block_size, eps, batch_size, channels, height, width, out_ptr),
            out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_SAD:
        iterate_cpu(
            PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es_ptr, ta_ptr, block_size, eps, batch_size, channels, height, width, out_ptr),
            out.numel());
        break;
    }
  }));

  return out;
}
|
||||
|
||||
// CPU backward pass of the photometric loss; see PhotometricLossBackward.
// es and ta are batch_size x channels x height x width, grad_out holds one
// value per (batch, row, column). Returns the gradient with respect to es,
// shaped like es.
at::Tensor photometric_loss_backward(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps) {
  CHECK_INPUT_CPU(es);
  CHECK_INPUT_CPU(ta);
  CHECK_INPUT_CPU(grad_out);

  AT_ASSERTM(es.dim() == 4, "es has to be batch_size x channels x height x width");
  AT_ASSERTM(es.sizes().equals(ta.sizes()), "es and ta have to be the same shape");
  AT_ASSERTM(type == PHOTOMETRIC_LOSS_MSE || type == PHOTOMETRIC_LOSS_SAD ||
             type == PHOTOMETRIC_LOSS_CENSUS_MSE || type == PHOTOMETRIC_LOSS_CENSUS_SAD,
             "unknown photometric loss type");

  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  // The functor is run once per element of grad_out and decodes
  // (batch, row, column) from the flat index.
  AT_ASSERTM(grad_out.numel() == batch_size * height * width, "grad_out has to be batch_size x 1 x height x width");

  auto grad_in = at::zeros({batch_size, channels, height, width}, grad_out.options());

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_backward_cpu", ([&] {
    const scalar_t* es_ptr = es.data<scalar_t>();
    const scalar_t* ta_ptr = ta.data<scalar_t>();
    const scalar_t* go_ptr = grad_out.data<scalar_t>();
    scalar_t* gi_ptr = grad_in.data<scalar_t>();
    // `type` is a template parameter of the functor, so each value needs its
    // own instantiation.
    switch(type) {
      case PHOTOMETRIC_LOSS_MSE:
        iterate_cpu(
            PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es_ptr, ta_ptr, go_ptr, block_size, eps, batch_size, channels, height, width, gi_ptr),
            grad_out.numel());
        break;
      case PHOTOMETRIC_LOSS_SAD:
        iterate_cpu(
            PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es_ptr, ta_ptr, go_ptr, block_size, eps, batch_size, channels, height, width, gi_ptr),
            grad_out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_MSE:
        iterate_cpu(
            PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es_ptr, ta_ptr, go_ptr, block_size, eps, batch_size, channels, height, width, gi_ptr),
            grad_out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_SAD:
        iterate_cpu(
            PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es_ptr, ta_ptr, go_ptr, block_size, eps, batch_size, channels, height, width, gi_ptr),
            grad_out.numel());
        break;
    }
  }));

  return grad_in;
}
|
||||
|
||||
|
||||
|
||||
|
||||
// Python bindings for the CPU implementations. Every function is registered
// under its C++ name with that name as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nn_cpu", &nn_cpu, "nn_cpu");
  m.def("crosscheck_cpu", &crosscheck_cpu, "crosscheck_cpu");
  m.def("proj_nn_cpu", &proj_nn_cpu, "proj_nn_cpu");

  m.def("xcorrvol_cpu", &xcorrvol_cpu, "xcorrvol_cpu");

  m.def("photometric_loss_forward", &photometric_loss_forward, "photometric_loss_forward");
  m.def("photometric_loss_backward", &photometric_loss_backward, "photometric_loss_backward");
}
|
||||
@@ -0,0 +1,135 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ext.h"
|
||||
|
||||
void nn_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out);
|
||||
|
||||
at::Tensor nn_cuda(at::Tensor in0, at::Tensor in1) {
|
||||
CHECK_INPUT_CUDA(in0)
|
||||
CHECK_INPUT_CUDA(in1)
|
||||
|
||||
auto nelem0 = in0.size(0);
|
||||
auto dim = in0.size(1);
|
||||
|
||||
AT_ASSERTM(dim == in1.size(1), "in0 and in1 have to be the same shape")
|
||||
AT_ASSERTM(dim == 3, "dim hast to be 3")
|
||||
AT_ASSERTM(in0.dim() == 2, "in0 has to be N0 x 3")
|
||||
AT_ASSERTM(in1.dim() == 2, "in1 has to be N1 x 3")
|
||||
|
||||
auto out = at::empty({nelem0}, torch::CUDA(at::kLong));
|
||||
|
||||
nn_kernel(in0, in1, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
void crosscheck_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out);
|
||||
|
||||
// Host-side entry point of the CUDA `crosscheck` op.
//
// Both inputs must be 1-D. Allocates a uint8 output with one element per
// entry of in0 and hands all three tensors to crosscheck_kernel.
// Returns the output tensor.
at::Tensor crosscheck_cuda(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CUDA(in0)
  CHECK_INPUT_CUDA(in1)

  AT_ASSERTM(in0.dim() == 1, "in0 has to be a 1-D tensor")
  AT_ASSERTM(in1.dim() == 1, "in1 has to be a 1-D tensor")

  auto nelem0 = in0.size(0);
  // Allocate next to in0 instead of on whatever CUDA device is current.
  auto out = at::empty({nelem0}, in0.options().dtype(at::kByte));
  crosscheck_kernel(in0, in1, out);

  return out;
}
|
||||
|
||||
void proj_nn_kernel(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size, at::Tensor out);
|
||||
|
||||
// Host-side entry point of the CUDA `proj_nn` op.
//
// xyz0 and xyz1 must be 4-D tensors of identical shape whose last dimension
// is 3. Allocates an int64 output of shape (size(0), size(1), size(2)) and
// forwards everything, including K and patch_size, to proj_nn_kernel.
// NOTE(review): K is only run through CHECK_INPUT_CUDA; its shape is not
// validated here.
at::Tensor proj_nn_cuda(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size) {
  CHECK_INPUT_CUDA(xyz0)
  CHECK_INPUT_CUDA(xyz1)
  CHECK_INPUT_CUDA(K)

  // Rank checks come first so that the size(3) accesses below cannot fail
  // with a generic index error on tensors of lower rank.
  AT_ASSERTM(xyz0.dim() == 4, "xyz0 has to be a 4-D tensor")
  AT_ASSERTM(xyz1.dim() == 4, "xyz1 has to be a 4-D tensor")
  AT_ASSERTM(xyz0.size(0) == xyz1.size(0), "xyz0 and xyz1 differ in dimension 0")
  AT_ASSERTM(xyz0.size(1) == xyz1.size(1), "xyz0 and xyz1 differ in dimension 1")
  AT_ASSERTM(xyz0.size(2) == xyz1.size(2), "xyz0 and xyz1 differ in dimension 2")
  AT_ASSERTM(xyz0.size(3) == xyz1.size(3), "xyz0 and xyz1 differ in dimension 3")
  AT_ASSERTM(xyz0.size(3) == 3, "last dimension of xyz0 has to be 3")

  auto batch_size = xyz0.size(0);
  auto height = xyz0.size(1);
  auto width = xyz0.size(2);

  // Allocate next to xyz0 instead of on whatever CUDA device is current.
  auto out = at::empty({batch_size, height, width}, xyz0.options().dtype(at::kLong));

  proj_nn_kernel(xyz0, xyz1, K, patch_size, out);

  return out;
}
|
||||
|
||||
void xcorrvol_kernel(at::Tensor in0, at::Tensor in1, int n_disps, int block_size, at::Tensor out);
|
||||
|
||||
// Host-side entry point of the CUDA `xcorrvol` op.
//
// Allocates an (n_disps, in0.size(1), in0.size(2)) output with the options
// of in0 and forwards it together with the inputs to xcorrvol_kernel.
at::Tensor xcorrvol_cuda(at::Tensor in0, at::Tensor in1, int n_disps, int block_size) {
  CHECK_INPUT_CUDA(in0)
  CHECK_INPUT_CUDA(in1)

  // Dimension 0 of in0 (channels) is not needed for sizing the output.
  const auto rows = in0.size(1);
  const auto cols = in0.size(2);

  auto volume = at::empty({n_disps, rows, cols}, in0.options());
  xcorrvol_kernel(in0, in1, n_disps, block_size, volume);
  return volume;
}
|
||||
|
||||
|
||||
|
||||
void photometric_loss_forward_kernel(at::Tensor es, at::Tensor ta, int block_size, int type, float eps, at::Tensor out);
|
||||
|
||||
// Host-side entry point of the CUDA photometric loss forward pass.
//
// Allocates a (batch, 1, height, width) output with the options of es, where
// batch/height/width are dimensions 0/2/3 of es, and forwards all arguments
// to photometric_loss_forward_kernel.
at::Tensor photometric_loss_forward(at::Tensor es, at::Tensor ta, int block_size, int type, float eps) {
  CHECK_INPUT_CUDA(es)
  CHECK_INPUT_CUDA(ta)

  const auto n = es.size(0);
  const auto h = es.size(2);
  const auto w = es.size(3);

  auto loss = at::empty({n, 1, h, w}, es.options());
  photometric_loss_forward_kernel(es, ta, block_size, type, eps, loss);
  return loss;
}
|
||||
|
||||
|
||||
void photometric_loss_backward_kernel(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps, at::Tensor grad_in);
|
||||
|
||||
// Host-side entry point of the CUDA photometric loss backward pass.
//
// Allocates a zero-initialised gradient tensor with the shape of es and the
// options of grad_out, then forwards all arguments to
// photometric_loss_backward_kernel.
at::Tensor photometric_loss_backward(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps) {
  CHECK_INPUT_CUDA(es)
  CHECK_INPUT_CUDA(ta)
  CHECK_INPUT_CUDA(grad_out)

  const auto n = es.size(0);
  const auto c = es.size(1);
  const auto h = es.size(2);
  const auto w = es.size(3);

  auto grad_es = at::zeros({n, c, h, w}, grad_out.options());
  photometric_loss_backward_kernel(es, ta, grad_out, block_size, type, eps, grad_es);
  return grad_es;
}
|
||||
|
||||
|
||||
// Python bindings for the CUDA entry points defined in this translation unit.
// Each function is exported under the same name it has in C++.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Ops registered with a docstring equal to their name.
  m.def("nn_cuda", &nn_cuda, "nn_cuda");
  m.def("crosscheck_cuda", &crosscheck_cuda, "crosscheck_cuda");
  m.def("proj_nn_cuda", &proj_nn_cuda, "proj_nn_cuda");
  m.def("xcorrvol_cuda", &xcorrvol_cuda, "xcorrvol_cuda");

  // Photometric loss pair, registered without a docstring.
  m.def("photometric_loss_forward", &photometric_loss_forward);
  m.def("photometric_loss_backward", &photometric_loss_backward);
}
|
||||
@@ -0,0 +1,112 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "ext.h"
|
||||
#include "common_cuda.h"
|
||||
|
||||
// Runs NNFunctor for the `nn` op.
//
// Dispatches on the floating-point dtype of in0 and passes the functor to
// iterate_cuda with in0.size(0) as the element count. out is accessed as
// `long` data.
void nn_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out) {
  auto nelem0 = in0.size(0);
  auto nelem1 = in1.size(0);

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "nn", ([&] {
    iterate_cuda(
      NNFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), nelem0, nelem1, out.data<long>()),
      nelem0);
  }));
}
|
||||
|
||||
|
||||
// Runs CrossCheckFunctor for the `crosscheck` op.
//
// in0 and in1 are accessed as `long` data and out as uint8_t data; the
// functor is passed to iterate_cuda with in0.size(0) as the element count.
void crosscheck_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out) {
  const auto n0 = in0.size(0);
  const auto n1 = in1.size(0);

  CrossCheckFunctor functor(in0.data<long>(), in1.data<long>(), n0, n1, out.data<uint8_t>());
  iterate_cuda(functor, n0);
}
|
||||
|
||||
// Runs ProjNNFunctor for the `proj_nn` op.
//
// Dispatches on the floating-point dtype of xyz0 and passes the functor to
// iterate_cuda with batch * height * width (dimensions 0, 1, 2 of xyz0) as
// the element count. out is accessed as `long` data.
void proj_nn_kernel(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size, at::Tensor out) {
  const auto n = xyz0.size(0);
  const auto h = xyz0.size(1);
  const auto w = xyz0.size(2);

  AT_DISPATCH_FLOATING_TYPES(xyz0.scalar_type(), "proj_nn", ([&] {
    ProjNNFunctor<scalar_t> functor(xyz0.data<scalar_t>(), xyz1.data<scalar_t>(), K.data<scalar_t>(), n, h, w, patch_size, out.data<long>());
    iterate_cuda(functor, n * h * w);
  }));
}
|
||||
|
||||
// Runs XCorrVolFunctor for the `xcorrvol` op.
//
// Dispatches on the floating-point dtype of in0 and passes the functor to
// iterate_cuda with n_disps * height * width as the element count and 512 as
// the extra trailing argument.
// NOTE(review): 512 is presumably the thread-block size; confirm against
// iterate_cuda in common_cuda.h.
void xcorrvol_kernel(at::Tensor in0, at::Tensor in1, int n_disps, int block_size, at::Tensor out) {
  const auto c = in0.size(0);
  const auto h = in0.size(1);
  const auto w = in0.size(2);

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "xcorrvol", ([&] {
    XCorrVolFunctor<scalar_t> functor(in0.data<scalar_t>(), in1.data<scalar_t>(), c, h, w, n_disps, block_size, out.data<scalar_t>());
    iterate_cuda(functor, n_disps * h * w, 512);
  }));
}
|
||||
|
||||
|
||||
|
||||
// Runs the PhotometricLossForward functor selected by `type`.
//
// Dispatches on the floating-point dtype of es. The functor receives the
// four dimensions of es as batch_size/channels/height/width and is passed to
// iterate_cuda with out.numel() as the element count.
// An unknown `type` raises; previously the call returned without ever
// writing to `out`, which the host wrapper allocates uninitialised.
void photometric_loss_forward_kernel(at::Tensor es, at::Tensor ta, int block_size, int type, float eps, at::Tensor out) {
  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_forward_cuda", ([&] {
    if(type == PHOTOMETRIC_LOSS_MSE) {
      iterate_cuda(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_SAD) {
      iterate_cuda(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
      iterate_cuda(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
      iterate_cuda(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else {
      // Fail loudly instead of handing back an unwritten buffer.
      AT_ASSERTM(false, "invalid photometric loss type");
    }
  }));
}
|
||||
|
||||
// Runs the PhotometricLossBackward functor selected by `type`.
//
// Dispatches on the floating-point dtype of es. The functor receives the
// four dimensions of es as batch_size/channels/height/width and is passed to
// iterate_cuda with grad_out.numel() as the element count.
// An unknown `type` raises; previously the call returned without touching
// grad_in, so the caller silently received an all-zero gradient.
void photometric_loss_backward_kernel(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps, at::Tensor grad_in) {
  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_backward_cuda", ([&] {
    if(type == PHOTOMETRIC_LOSS_MSE) {
      iterate_cuda(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_SAD) {
      iterate_cuda(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
      iterate_cuda(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
      iterate_cuda(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else {
      // Fail loudly instead of returning a silently-zero gradient.
      AT_ASSERTM(false, "invalid photometric loss type");
    }
  }));
}
|
||||
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
from . import ext_cpu
|
||||
from . import ext_cuda
|
||||
|
||||
class NNFunction(torch.autograd.Function):
    """Autograd wrapper around the `nn` extension op.

    The forward pass calls the CUDA or the CPU implementation depending on
    where `in0` lives; the backward pass provides no gradient for either
    input.
    """

    @staticmethod
    def forward(ctx, in0, in1):
        # The device of the first input selects the backend.
        backend = ext_cuda.nn_cuda if in0.is_cuda else ext_cpu.nn_cpu
        return backend(in0, in1)

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input.
        return (None,) * 2
|
||||
|
||||
def nn(in0, in1):
    """Apply `NNFunction` to `in0` and `in1` and return its output."""
    inputs = (in0, in1)
    return NNFunction.apply(*inputs)
|
||||
|
||||
|
||||
class CrossCheckFunction(torch.autograd.Function):
    """Autograd wrapper around the `crosscheck` extension op.

    The forward pass calls the CUDA or the CPU implementation depending on
    where `in0` lives; the backward pass provides no gradient for either
    input.
    """

    @staticmethod
    def forward(ctx, in0, in1):
        # The device of the first input selects the backend.
        backend = ext_cuda.crosscheck_cuda if in0.is_cuda else ext_cpu.crosscheck_cpu
        return backend(in0, in1)

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input.
        return (None,) * 2
|
||||
|
||||
def crosscheck(in0, in1):
    """Apply `CrossCheckFunction` to `in0` and `in1` and return its output."""
    inputs = (in0, in1)
    return CrossCheckFunction.apply(*inputs)
|
||||
|
||||
class ProjNNFunction(torch.autograd.Function):
    """Autograd wrapper around the `proj_nn` extension op.

    The forward pass calls the CUDA or the CPU implementation depending on
    where `xyz0` lives; the backward pass provides no gradient for any of
    the four inputs.
    """

    @staticmethod
    def forward(ctx, xyz0, xyz1, K, patch_size):
        # The device of the first input selects the backend.
        backend = ext_cuda.proj_nn_cuda if xyz0.is_cuda else ext_cpu.proj_nn_cpu
        return backend(xyz0, xyz1, K, patch_size)

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input.
        return (None,) * 4
|
||||
|
||||
def proj_nn(xyz0, xyz1, K, patch_size):
    """Apply `ProjNNFunction` to the given arguments and return its output."""
    inputs = (xyz0, xyz1, K, patch_size)
    return ProjNNFunction.apply(*inputs)
|
||||
|
||||
|
||||
|
||||
class XCorrVolFunction(torch.autograd.Function):
    """Autograd wrapper around the `xcorrvol` extension op.

    The forward pass calls the CUDA or the CPU implementation depending on
    where `in0` lives; the backward pass provides no gradient for any of the
    four inputs.
    """

    @staticmethod
    def forward(ctx, in0, in1, n_disps, block_size):
        # The device of the first input selects the backend.
        backend = ext_cuda.xcorrvol_cuda if in0.is_cuda else ext_cpu.xcorrvol_cpu
        return backend(in0, in1, n_disps, block_size)

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input.
        return (None,) * 4
|
||||
|
||||
def xcorrvol(in0, in1, n_disps, block_size):
    """Apply `XCorrVolFunction` to the given arguments and return its output."""
    inputs = (in0, in1, n_disps, block_size)
    return XCorrVolFunction.apply(*inputs)
|
||||
|
||||
|
||||
|
||||
|
||||
class PhotometricLossFunction(torch.autograd.Function):
    """Autograd wrapper around the photometric loss extension kernels.

    Only `es` receives a gradient; `ta`, `block_size`, `type` and `eps` are
    treated as constants in the backward pass.
    """

    @staticmethod
    def forward(ctx, es, ta, block_size, type, eps):
        # Keep everything the backward kernel needs.
        ctx.save_for_backward(es, ta)
        ctx.block_size, ctx.type, ctx.eps = block_size, type, eps

        if es.is_cuda:
            return ext_cuda.photometric_loss_forward(es, ta, block_size, type, eps)
        return ext_cpu.photometric_loss_forward(es, ta, block_size, type, eps)

    @staticmethod
    def backward(ctx, grad_out):
        es, ta = ctx.saved_tensors
        # The kernels receive a contiguous copy of the incoming gradient.
        kernel_args = (es, ta, grad_out.contiguous(), ctx.block_size, ctx.type, ctx.eps)

        if grad_out.is_cuda:
            grad_es = ext_cuda.photometric_loss_backward(*kernel_args)
        else:
            grad_es = ext_cpu.photometric_loss_backward(*kernel_args)

        # One entry per forward input; only `es` is differentiable.
        return grad_es, None, None, None, None
|
||||
|
||||
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
    """Evaluate the photometric loss through `PhotometricLossFunction`.

    `type` is a case-insensitive name ('mse', 'sad', 'census_mse' or
    'census_sad') that is translated to the integer code passed on to the
    autograd function. Any other name raises `Exception`.
    """
    type_codes = {'mse': 0, 'sad': 1, 'census_mse': 2, 'census_sad': 3}
    key = type.lower()
    if key not in type_codes:
        raise Exception('invalid loss type')
    return PhotometricLossFunction.apply(es, ta, block_size, type_codes[key], eps)
|
||||
|
||||
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    """Photometric loss written with plain PyTorch operations.

    Both inputs are replicate-padded by `block_size // 2`, unfolded into
    `block_size` x `block_size` patches and compared per patch element. The
    per-element terms are summed over channels and patch positions and
    divided by `block_size**2`, giving a tensor with a single channel.
    `type` selects 'mse', 'sad', 'census_mse' or 'census_sad'
    (case-insensitive); any other value raises `Exception`.
    """
    mode = type.lower()
    half = block_size // 2
    functional = torch.nn.functional

    def patches(img):
        # (B, C, H, W) -> (B, C, block_size**2, H, W)
        padded = functional.pad(img, (half, half, half, half), mode='replicate')
        cols = functional.unfold(padded, kernel_size=block_size)
        return cols.view(img.shape[0], img.shape[1], -1, img.shape[2], img.shape[3])

    def soft_sign(delta):
        # Smooth step in (0, 1); eps keeps the square root away from zero.
        return 0.5 * (1 + delta / torch.sqrt(delta * delta + eps))

    es_uf = patches(es)
    ta_uf = patches(ta)

    if mode == 'mse':
        ref = (es_uf - ta_uf)**2
    elif mode == 'sad':
        ref = torch.abs(es_uf - ta_uf)
    elif mode in ('census_mse', 'census_sad'):
        # Compare every patch element against the patch centre pixel.
        diff = soft_sign(es_uf - es.unsqueeze(2)) - soft_sign(ta_uf - ta.unsqueeze(2))
        ref = diff * diff if mode == 'census_mse' else torch.abs(diff)
    else:
        raise Exception('invalid loss type')

    ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
    return torch.sum(ref, dim=1, keepdim=True) / block_size**2
|
||||
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from .functions import *
|
||||
|
||||
class CoordConv2d(torch.nn.Module):
    """2-D convolution that sees two extra normalised coordinate channels.

    The input is concatenated with a `u` (column) and a `v` (row) channel,
    each running linearly from -1 to 1 across the image, before it is passed
    to a regular `torch.nn.Conv2d` with `channels_in + 2` input channels.
    """

    def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
        super().__init__()

        self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)

        # Cached coordinate grid of shape (1, 2, H, W); built lazily in forward.
        self.uv = None

    def forward(self, x):
        """Append the coordinate channels to `x` and apply the convolution."""
        height, width = x.shape[2], x.shape[3]

        # Rebuild the cache whenever the spatial size changes; a grid cached
        # for a different resolution cannot be concatenated with `x`.
        if self.uv is None or tuple(self.uv.shape[2:]) != (height, width):
            u, v = np.meshgrid(range(width), range(height))
            u = 2 * u / (width - 1) - 1
            v = 2 * v / (height - 1) - 1
            uv = np.stack((u, v)).reshape(1, 2, height, width)
            self.uv = torch.from_numpy(uv.astype(np.float32))

        self.uv = self.uv.to(x.device)
        uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
        xuv = torch.cat((x, uv), dim=1)
        return self.conv(xuv)
|
||||
@@ -0,0 +1,24 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
|
||||
import os
|
||||
|
||||
# Directory that contains this setup script.
this_dir = os.path.dirname(os.path.realpath(__file__))

# No additional include directories are required.
include_dirs = []

# Flags handed to nvcc when the CUDA extension is compiled.
nvcc_args = [
    '-arch=sm_30',
    '-gencode=arch=compute_30,code=sm_30',
    '-gencode=arch=compute_35,code=sm_35',
]

# One extension with the CPU ops, one with the CUDA host code and kernels.
cpu_extension = CppExtension('ext_cpu', ['ext/ext_cpu.cpp'])
cuda_extension = CUDAExtension(
    'ext_cuda',
    ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
    extra_compile_args={'cxx': [], 'nvcc': nvcc_args},
)

setup(
    name='ext',
    ext_modules=[cpu_extension, cuda_extension],
    cmdclass={'build_ext': BuildExtension},
    include_dirs=include_dirs
)
|
||||
@@ -0,0 +1,528 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
import logging
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import subprocess
|
||||
import socket
|
||||
import sys
|
||||
import os
|
||||
import gc
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class StopWatch(object):
    """Collects wall-clock durations under named timers.

    Every `start(name)` / `stop(name)` pair appends one duration in seconds
    to the list kept for `name`.
    """

    def __init__(self):
        self.timings = OrderedDict()  # name -> list of measured durations
        self.starts = {}              # name -> timestamp of the last start()

    def start(self, name):
        """Remember the current time as the start of timer `name`."""
        self.starts[name] = time.time()

    def stop(self, name):
        """Append the time elapsed since the last `start(name)`."""
        durations = self.timings.setdefault(name, [])
        durations.append(time.time() - self.starts[name])

    def get(self, name=None, reduce=np.sum):
        """Return the reduced durations of one timer, or a dict for all of them."""
        if name is None:
            return {key: reduce(values) for key, values in self.timings.items()}
        return reduce(self.timings[name])

    def __repr__(self):
        entries = []
        for key, total in self.get().items():
            entries.append('%s: %f[s]' % (key, total))
        return ', '.join(entries)

    # str() and repr() produce the same text.
    __str__ = __repr__
|
||||
|
||||
|
||||
class ETA(object):
    """Progress tracker that estimates the time left for a fixed-length job.

    `update(idx)` is called with the zero-based index of the item that has
    just been finished; the estimate assumes every item takes the average
    time observed so far.
    """

    def __init__(self, length):
        self.length = length            # total number of items
        self.start_time = time.time()
        self.current_idx = 0            # index of the last finished item
        self.current_time = time.time()

    def update(self, idx):
        """Record that item `idx` has just been finished."""
        self.current_idx = idx
        self.current_time = time.time()

    def get_elapsed_time(self):
        """Seconds between construction and the last `update`."""
        return self.current_time - self.start_time

    def get_item_time(self):
        """Average seconds per finished item."""
        return self.get_elapsed_time() / (self.current_idx + 1)

    def get_remaining_time(self):
        """Estimated seconds until all `length` items are finished."""
        # current_idx is zero-based, so current_idx + 1 items are done. The
        # previous `length - current_idx + 1` over-counted by two items.
        remaining_items = self.length - (self.current_idx + 1)
        return self.get_item_time() * max(remaining_items, 0)

    def format_time(self, seconds):
        """Format a duration in seconds as HH:MM:SS.ss."""
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours = int(hours)
        minutes = int(minutes)
        return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'

    def get_elapsed_time_str(self):
        return self.format_time(self.get_elapsed_time())

    def get_remaining_time_str(self):
        return self.format_time(self.get_remaining_time())
|
||||
|
||||
class Worker(object):
|
||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
    """Store the experiment configuration and set up the experiment.

    All arguments are kept as attributes of the same name (`out_root` is
    converted to a `Path`); afterwards `setup_experiment` is called.
    """
    # Where and under which name results are written.
    self.out_root = Path(out_root)
    self.experiment_name = experiment_name

    # Training schedule.
    self.epochs, self.seed = epochs, seed
    self.train_batch_size, self.test_batch_size = train_batch_size, test_batch_size
    self.num_workers = num_workers
    self.save_frequency = save_frequency
    self.train_device, self.test_device = train_device, test_device
    self.max_train_iter = max_train_iter

    # Error curves drawn by write_err_img.
    self.errs_list = []

    self.setup_experiment()
|
||||
|
||||
def setup_experiment(self):
    """Create the output directory, configure logging and load old metrics.

    Logging goes to `train.log` inside the experiment directory and to the
    console. Metrics already stored in `metrics.json` are loaded so that
    later epochs extend them; finally the random seeds are initialised.
    """
    self.exp_out_root = self.out_root / self.experiment_name
    self.exp_out_root.mkdir(parents=True, exist_ok=True)

    # Drop handlers installed earlier so that basicConfig takes effect.
    if logging.root:
        logging.root.handlers.clear()
    file_handler = logging.FileHandler(str(self.exp_out_root / 'train.log'))
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        handlers=[file_handler, console_handler],
        format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
    )

    separator = '='*80
    logging.info(separator)
    logging.info(f'Start of experiment: {self.experiment_name}')
    logging.info(socket.gethostname())
    self.log_datetime()
    logging.info(separator)

    self.metric_path = self.exp_out_root / 'metrics.json'
    if not self.metric_path.exists():
        self.metric_data = {}
    else:
        with open(str(self.metric_path), 'r') as fp:
            self.metric_data = json.load(fp)

    self.init_seed()
|
||||
|
||||
def metric_add_train(self, epoch, key, val):
    """Store `val` as training metric `key` of `epoch`.

    Epoch and key are converted to strings; the value ends up at
    `metric_data[epoch]['train'][key]`.
    """
    epoch_key, metric_key = str(epoch), str(key)
    epoch_entry = self.metric_data.setdefault(epoch_key, {})
    train_entry = epoch_entry.setdefault('train', {})
    train_entry[metric_key] = val
|
||||
|
||||
def metric_add_test(self, epoch, set_idx, key, val):
    """Store `val` as test metric `key` of test set `set_idx` in `epoch`.

    Epoch, set index and key are converted to strings; the value ends up at
    `metric_data[epoch]['test'][set_idx][key]`.
    """
    epoch_key, set_key, metric_key = str(epoch), str(set_idx), str(key)
    epoch_entry = self.metric_data.setdefault(epoch_key, {})
    test_entry = epoch_entry.setdefault('test', {})
    set_entry = test_entry.setdefault(set_key, {})
    set_entry[metric_key] = val
|
||||
|
||||
def metric_save(self):
    """Write all collected metrics to `metric_path` as indented JSON."""
    target = str(self.metric_path)
    with open(target, 'w') as handle:
        json.dump(self.metric_data, handle, indent=2)
|
||||
|
||||
def init_seed(self, seed=None):
    """Seed NumPy, `random` and PyTorch (CPU and CUDA) generators.

    When `seed` is given it replaces `self.seed` first.
    """
    if seed is not None:
        self.seed = seed
    logging.info(f'Set seed to {self.seed}')

    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(self.seed)
|
||||
|
||||
def log_datetime(self):
    """Log the current local date and time."""
    now = datetime.datetime.now()
    logging.info(now.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
|
||||
def mem_report(self):
    """Print type and shape of every tensor the garbage collector tracks."""
    tracked_tensors = (obj for obj in gc.get_objects() if torch.is_tensor(obj))
    for tensor in tracked_tensors:
        print(type(tensor), tensor.shape)
|
||||
|
||||
def get_net_path(self, epoch, root=None):
    """Return the parameter file path of `epoch`.

    The file lives in `root`, or in the experiment directory when `root` is
    None, and is named `net_<epoch as four digits>.params`.
    """
    base_dir = self.exp_out_root if root is None else root
    return base_dir / f'net_{epoch:04d}.params'
|
||||
|
||||
def get_do_parser_cmds(self):
    """Return the command names accepted by the `--cmd` option."""
    return 'retrain resume retest test_init'.split()
|
||||
|
||||
def get_do_parser(self):
    """Build the command line parser used by `do`.

    `--cmd` selects one of `get_do_parser_cmds()` (default 'resume') and
    `--epoch` is an integer that defaults to -1.
    """
    allowed_cmds = self.get_do_parser_cmds()
    cli = argparse.ArgumentParser()
    cli.add_argument('--cmd', type=str, default='resume', choices=allowed_cmds)
    cli.add_argument('--epoch', type=int, default=-1)
    return cli
|
||||
|
||||
def do_cmd(self, args, net, optimizer, scheduler=None):
    """Run the action selected by `args.cmd`.

    'retrain' and 'resume' start training (without / with resuming from a
    saved state), 'retest' re-evaluates stored networks and 'test_init'
    evaluates the given network as epoch -1. Anything else raises.
    """
    cmd = args.cmd
    if cmd in ('retrain', 'resume'):
        self.train(net, optimizer, resume=(cmd == 'resume'), scheduler=scheduler)
        return
    if cmd == 'retest':
        self.retest(net, epoch=args.epoch)
        return
    if cmd == 'test_init':
        self.test(-1, net, self.get_test_sets())
        return
    raise Exception('invalid cmd')
|
||||
|
||||
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
    """Parse the command line and run the requested command.

    Unknown command line arguments are ignored. When `load_net_optimizer`
    is given (and the command is not 'schedule') it is called to obtain the
    network and optimizer that replace the ones passed in.
    """
    args, _unknown = self.get_do_parser().parse_known_args()

    use_loader = load_net_optimizer is not None and args.cmd not in ['schedule']
    if use_loader:
        net, optimizer = load_net_optimizer()

    self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
||||
|
||||
def retest(self, net, epoch=-1):
    """Re-run the tests for stored network parameters.

    A negative `epoch` visits all epochs in `range(self.epochs)`, otherwise
    only the given one. Epochs without a parameter file are skipped.
    """
    epochs = range(self.epochs) if epoch < 0 else [epoch]
    test_sets = self.get_test_sets()

    for current_epoch in epochs:
        net_path = self.get_net_path(current_epoch)
        if not net_path.exists():
            continue
        net.load_state_dict(torch.load(str(net_path)))
        self.test(current_epoch, net, test_sets)
|
||||
|
||||
def format_err_str(self, errs, div=1):
    """Format error terms as 'total' or 'total=term+term+...'.

    Every value is divided by `div` and printed with four decimals; the
    breakdown is only added when there is more than one term.
    """
    total_str = f'{sum(errs)/div:0.4f}'
    if len(errs) <= 1:
        return total_str
    terms = '+'.join(f'{e/div:0.4f}' for e in errs)
    return total_str + '=' + terms
|
||||
|
||||
def write_err_img(self):
    """Plot every curve in `errs_list` and save the figure as `errs.png`."""
    err_img_path = self.exp_out_root / 'errs.png'
    fig = plt.figure(figsize=(16,16))

    # plt.plot returns a list with a single line for a single curve.
    handles = [
        plt.plot(range(len(curve)), curve, label=f'error{idx}')[0]
        for idx, curve in enumerate(self.errs_list)
    ]

    plt.tight_layout()
    plt.legend(handles=handles)
    plt.savefig(str(err_img_path))
    plt.close(fig)
|
||||
|
||||
|
||||
def callback_train_new_epoch(self, epoch, net, optimizer):
    """Hook called by `train` at the start of every epoch; no-op by default."""
    return None
|
||||
|
||||
def train(self, net, optimizer, resume=False, scheduler=None):
    """Run the training loop: train, test and checkpoint every epoch.

    With `resume=True` and an existing `state.dict` in the experiment
    directory, the epoch counter, best errors, network weights, optimizer
    state and RNG states are restored before training continues.

    Every `save_frequency` epochs the full training state is written to
    `state.dict`, a copy is written to `state_set<name>_best.dict` for every
    test set whose summed error improved, and the network weights are
    written to the path returned by `get_net_path`. `scheduler.step()` is
    called once per epoch when a scheduler is given.
    """
    logging.info('='*80)
    logging.info('Start training')
    self.log_datetime()
    logging.info('='*80)

    train_set = self.get_train_set()
    test_sets = self.get_test_sets()

    net = net.to(self.train_device)

    epoch = 0
    min_err = {ts.name: 1e9 for ts in test_sets}

    state_path = self.exp_out_root / 'state.dict'
    if resume and state_path.exists():
        logging.info('='*80)
        logging.info(f'Loading state from {state_path}')
        logging.info('='*80)
        state = torch.load(str(state_path))
        epoch = state['epoch'] + 1
        if 'min_err' in state:
            min_err = state['min_err']

        # Merge into the current weights so that parameters missing from the
        # stored state keep their freshly initialised values.
        curr_state = net.state_dict()
        curr_state.update(state['state_dict'])
        net.load_state_dict(curr_state)

        try:
            optimizer.load_state_dict(state['optimizer'])
        except Exception as exc:
            # Best effort: training continues with a fresh optimizer state.
            logging.info(f'Warning: cannot load optimizer from state_dict ({exc})')
        if 'cpu_rng_state' in state:
            torch.set_rng_state(state['cpu_rng_state'])
        if 'gpu_rng_state' in state:
            torch.cuda.set_rng_state(state['gpu_rng_state'])

    for epoch in range(epoch, self.epochs):
        self.callback_train_new_epoch(epoch, net, optimizer)

        # train epoch
        self.train_epoch(epoch, net, optimizer, train_set)

        # test epoch
        errs = self.test(epoch, net, test_sets)

        if (epoch + 1) % self.save_frequency == 0:
            net = net.to(self.train_device)

            # Update the best errors first so that state.dict stores the
            # values that are valid after this epoch.
            improved = []
            for test_set_name in errs:
                err = sum(errs[test_set_name])
                # .get: a resumed min_err may not know every test set.
                if err < min_err.get(test_set_name, 1e9):
                    min_err[test_set_name] = err
                    improved.append(test_set_name)

            # store state
            state_dict = {
                'epoch': epoch,
                'min_err': min_err,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'cpu_rng_state': torch.get_rng_state(),
            }
            # Querying the CUDA RNG without a CUDA device would raise.
            if torch.cuda.is_available():
                state_dict['gpu_rng_state'] = torch.cuda.get_rng_state()

            state_path = self.exp_out_root / 'state.dict'
            logging.info(f'save state to {state_path}')
            torch.save(state_dict, str(state_path))

            for test_set_name in improved:
                best_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                logging.info(f'save state to {best_path}')
                torch.save(state_dict, str(best_path))

            # store network
            net_path = self.get_net_path(epoch)
            logging.info(f'save network to {net_path}')
            torch.save(net.state_dict(), str(net_path))

        if scheduler is not None:
            scheduler.step()

    logging.info('='*80)
    logging.info('Finished training')
    self.log_datetime()
    logging.info('='*80)
|
||||
|
||||
def get_train_set(self):
    """Return the training set; has to be implemented by subclasses."""
    raise NotImplementedError()
|
||||
|
||||
def get_test_sets(self):
    """Return the test sets; has to be implemented by subclasses."""
    raise NotImplementedError()
|
||||
|
||||
def copy_data(self, data, device, requires_grad, train):
    """Receive one loaded batch; has to be implemented by subclasses.

    Called once per batch by the train/test loops before `net_forward`.
    """
    raise NotImplementedError()
|
||||
|
||||
def net_forward(self, net, train):
    """Run the network and return its output; implemented by subclasses."""
    raise NotImplementedError()
|
||||
|
||||
def loss_forward(self, output, train):
    """Compute the loss term(s) for `output`; implemented by subclasses.

    The training loop accepts a single value, a list/tuple of values, or a
    dict with the keys 'errs' and 'masks'.
    """
    raise NotImplementedError()
|
||||
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
    """Hook called after `backward()` of every training batch.

    Does nothing by default. Subclasses can use it, for example, to check
    the parameter gradients for non-finite values.
    """
    return None
|
||||
|
||||
def callback_train_start(self, epoch):
    """Hook called at the beginning of `train_epoch`; no-op by default."""
    return None
|
||||
|
||||
def callback_train_stop(self, epoch, loss):
    """Hook called with the mean loss at the end of `train_epoch`; no-op by default."""
    return None
|
||||
|
||||
def train_epoch(self, epoch, net, optimizer, dset):
    """Train `net` for one epoch on `dset` and return the mean loss terms.

    Batches are drawn shuffled with incomplete batches dropped. When
    `max_train_iter` is positive at most that many batches are processed.
    The per-term mean over the processed batches is stored as train metric
    'loss', the metrics file is saved and the mean is returned as a list
    (empty when no batch was processed).
    """
    self.callback_train_start(epoch)
    stopwatch = StopWatch()

    logging.info('='*80)
    logging.info('Train epoch %d' % epoch)

    dset.current_epoch = epoch
    train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)

    net = net.to(self.train_device)
    net.train()

    # Synchronising makes the per-stage timings meaningful on the GPU.
    use_cuda = 'cuda' in self.train_device

    # Number of batches this epoch will actually process.
    n_batches = len(train_loader)
    if self.max_train_iter > 0:
        n_batches = min(n_batches, self.max_train_iter)
    bar = ETA(length=n_batches)

    mean_loss = None
    n_processed = 0

    stopwatch.start('total')
    stopwatch.start('data')
    for batch_idx, data in enumerate(train_loader):
        # `>=` so that exactly n_batches batches run (indices are zero-based).
        if batch_idx >= n_batches:
            break
        self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
        stopwatch.stop('data')

        optimizer.zero_grad()

        stopwatch.start('forward')
        output = self.net_forward(net, train=True)
        if use_cuda:
            torch.cuda.synchronize()
        stopwatch.stop('forward')

        stopwatch.start('loss')
        errs = self.loss_forward(output, train=True)
        # loss_forward may return a dict that also carries masks.
        if isinstance(errs, dict):
            masks = errs['masks']
            errs = errs['errs']
        else:
            masks = []
        if not isinstance(errs, (list, tuple)):
            errs = [errs]
        err = sum(errs)
        if use_cuda:
            torch.cuda.synchronize()
        stopwatch.stop('loss')

        stopwatch.start('backward')
        err.backward()
        self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
        if use_cuda:
            torch.cuda.synchronize()
        stopwatch.stop('backward')

        stopwatch.start('optimizer')
        optimizer.step()
        if use_cuda:
            torch.cuda.synchronize()
        stopwatch.stop('optimizer')

        bar.update(batch_idx)
        # Log every batch early in training, every 16th batch afterwards.
        if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
            err_str = self.format_err_str(errs)
            logging.info(f'train e{epoch}: {batch_idx+1}/{n_batches}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')

        if mean_loss is None:
            mean_loss = [0 for _ in errs]
        for erridx, term in enumerate(errs):
            mean_loss[erridx] += term.item()
        n_processed += 1

        stopwatch.start('data')
    stopwatch.stop('total')
    logging.info('timings: %s' % stopwatch)

    # Average over the batches that were really processed; an epoch without
    # any batch yields an empty list instead of failing on None.
    mean_loss = [l / n_processed for l in mean_loss] if n_processed else []
    self.callback_train_stop(epoch, mean_loss)
    self.metric_add_train(epoch, 'loss', mean_loss)

    # save metrics
    self.metric_save()

    err_str = self.format_err_str(mean_loss)
    logging.info(f'avg train_loss={err_str}')
    return mean_loss
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
    """Hook called before the first batch of a test set; no-op by default."""
    return None
|
||||
|
||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
    """Hook called after every test batch with its output; no-op by default."""
    return None
|
||||
|
||||
def callback_test_stop(self, epoch, set_idx, loss):
    """Hook that receives the loss of a finished test set; no-op by default."""
    return None
|
||||
|
||||
def test(self, epoch, net, test_sets):
    """Evaluate `net` on every test set that is due in `epoch`.

    A test set is evaluated when `epoch + 1` is a multiple of its
    `test_frequency`. Returns a dict that maps the names of the evaluated
    sets to the value returned by `test_epoch`.
    """
    results = {}
    for set_idx, test_set in enumerate(test_sets):
        if (epoch + 1) % test_set.test_frequency != 0:
            continue
        logging.info('='*80)
        logging.info(f'testing set {test_set.name}')
        results[test_set.name] = self.test_epoch(epoch, set_idx, net, test_set.dset)
    return results
|
||||
|
||||
def test_epoch(self, epoch, set_idx, net, dset):
    """Run one full evaluation pass of ``net`` over ``dset``.

    The dataset's ``current_epoch`` is set to ``epoch``, the network is
    moved to ``self.test_device`` and switched to eval mode, and every
    batch is processed with gradients disabled. For each batch the loss
    terms are summed via ``.item()``; after the loop the sums are divided
    by the number of batches.

    Args:
        epoch: current epoch number.
        set_idx: index of the test set, forwarded to the callbacks and
            to metric_add_test.
        net: network to evaluate.
        dset: dataset wrapped in a non-shuffling DataLoader.

    Returns:
        List with one mean loss value per loss term.

    Note:
        If the loader yields no batches the accumulator stays ``None``
        and the averaging step raises ``TypeError``.
    """
    logging.info('-' * 80)
    logging.info('Test epoch %d' % epoch)
    dset.current_epoch = epoch
    test_loader = torch.utils.data.DataLoader(
        dset,
        batch_size=self.test_batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        drop_last=False,
        pin_memory=False,
    )

    net = net.to(self.test_device)
    net.eval()

    with torch.no_grad():
        loss_sums = None  # created lazily once the number of loss terms is known

        self.callback_test_start(epoch, set_idx)

        n_batches = len(test_loader)
        bar = ETA(length=n_batches)
        stopwatch = StopWatch()
        stopwatch.start('total')
        stopwatch.start('data')
        for batch_idx, data in enumerate(test_loader):
            self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
            stopwatch.stop('data')

            stopwatch.start('forward')
            output = self.net_forward(net, train=False)
            if 'cuda' in self.test_device:
                # wait for pending CUDA work before the 'forward' timer is stopped
                torch.cuda.synchronize()
            stopwatch.stop('forward')

            stopwatch.start('loss')
            loss_result = self.loss_forward(output, train=False)
            # The loss may come back as {'masks': ..., 'errs': ...} or as the errors alone.
            if isinstance(loss_result, dict):
                masks, errs = loss_result['masks'], loss_result['errs']
            else:
                masks, errs = [], loss_result
            # A single loss value is wrapped so the code below always sees a sequence.
            if not isinstance(errs, (list, tuple)):
                errs = [errs]

            bar.update(batch_idx)
            if batch_idx % 25 == 0:
                err_str = self.format_err_str(errs)
                logging.info(f'test e{epoch}: {batch_idx+1}/{n_batches}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')

            if loss_sums is None:
                loss_sums = [0] * len(errs)
            for term_idx, term in enumerate(errs):
                loss_sums[term_idx] += term.item()
            stopwatch.stop('loss')

            self.callback_test_add(epoch, set_idx, batch_idx, n_batches, output, masks)

            # restart the 'data' timer before the loader fetches the next batch
            stopwatch.start('data')
        stopwatch.stop('total')
        logging.info('timings: %s' % stopwatch)

        mean_loss = [total / n_batches for total in loss_sums]
        self.callback_test_stop(epoch, set_idx, mean_loss)
        self.metric_add_test(epoch, set_idx, 'loss', mean_loss)

        # save metrics
        self.metric_save()

        err_str = self.format_err_str(mean_loss)
        logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
        return mean_loss
|
||||
Reference in New Issue
Block a user