This commit is contained in:
Yiyi Liao
2019-06-13 16:25:11 +02:00
parent 26157cbb80
commit f5e5c4bd3f
84 changed files with 31343 additions and 2 deletions
+4
View File
@@ -0,0 +1,4 @@
from .dataset import *
from .worker import *
from .functions import *
from .modules import *
+66
View File
@@ -0,0 +1,66 @@
import torch
import torch.utils.data
import numpy as np
class TestSet(object):
    """Record bundling a name, a dataset object and a test_frequency value."""

    def __init__(self, name, dset, test_frequency=1):
        # Plain attribute container: the arguments are stored as given,
        # nothing is validated or copied.
        self.name, self.dset, self.test_frequency = name, dset, test_frequency
class TestSets(list):
    """List whose append() builds the TestSet entry from its fields."""

    def append(self, name, dset, test_frequency=1):
        """Wrap the arguments in a TestSet and add it to the end of the list."""
        entry = TestSet(name, dset, test_frequency)
        list.append(self, entry)
class MultiDataset(torch.utils.data.Dataset):
    """Several datasets exposed behind one flat, contiguous index space.

    Sample ``i`` of the combined dataset is looked up in the first wrapped
    dataset whose cumulative sample range contains ``i``.  Each wrapped
    dataset only needs to support ``len()`` and integer indexing.
    """

    def __init__(self, *datasets):
        self.current_epoch = 0
        self.datasets = []
        # cum_n_samples[i] is the flat index of the first sample of
        # datasets[i]; the last entry is the total number of samples.
        self.cum_n_samples = [0]
        for dataset in datasets:
            self.append(dataset)

    def append(self, dataset):
        """Add ``dataset`` after the already registered ones."""
        self.datasets.append(dataset)
        self.__update_cum_n_samples(dataset)

    def __update_cum_n_samples(self, dataset):
        # Extend the running total by the length of the new dataset.
        n_samples = self.cum_n_samples[-1] + len(dataset)
        self.cum_n_samples.append(n_samples)

    def dataset_updated(self):
        """Rebuild the cumulative sizes after a wrapped dataset changed length."""
        self.cum_n_samples = [0]
        for dset in self.datasets:
            self.__update_cum_n_samples(dset)

    def __len__(self):
        return self.cum_n_samples[-1]

    def __getitem__(self, idx):
        """Return the sample at flat index ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` lies outside ``[-len(self), len(self))``.
        """
        n_total = len(self)
        if idx < 0:
            # Without this, a negative index selected datasets[-1] with a
            # nonsensical sample offset.
            idx += n_total
        if not 0 <= idx < n_total:
            raise IndexError("MultiDataset index out of range")
        # side='right' skips over empty datasets that share a boundary.
        didx = int(np.searchsorted(self.cum_n_samples, idx, side='right')) - 1
        sidx = idx - self.cum_n_samples[didx]
        return self.datasets[didx][sidx]
class BaseDataset(torch.utils.data.Dataset):
    """Dataset base class that hands out a seeded RNG per sample index.

    Subclasses have to provide ``__len__``; ``get_rng`` uses it in training
    mode to derive the seed.
    """

    def __init__(self, train=True, fix_seed_per_epoch=False):
        self.current_epoch = 0
        self.train = train
        self.fix_seed_per_epoch = fix_seed_per_epoch

    def get_rng(self, idx):
        """Return a ``numpy.random.RandomState`` seeded for sample ``idx``.

        Outside training the seed is ``idx`` itself.  In training it is
        ``factor * len(self) + idx`` where ``factor`` is 1 when
        ``fix_seed_per_epoch`` is set and ``current_epoch + 1`` otherwise.
        """
        if not self.train:
            seed = idx
        else:
            epoch_factor = 1 if self.fix_seed_per_epoch else self.current_epoch + 1
            seed = epoch_factor * len(self) + idx
        rng = np.random.RandomState()
        rng.seed(seed)
        return rng
+10
View File
@@ -0,0 +1,10 @@
// Defines CPU_GPU_FUNCTION, the qualifier placed in front of functions that
// are meant to be usable from both host code and CUDA device code.
#ifndef TYPES_H
#define TYPES_H

// With __CUDA_ARCH__ defined the functions are marked __host__ __device__;
// otherwise the macro expands to nothing and they are ordinary functions.
// NOTE(review): the qualifier therefore differs between compilation passes
// that define __CUDA_ARCH__ and those that do not - confirm this is intended
// (headers commonly key this on __CUDACC__ instead).
#ifdef __CUDA_ARCH__
#define CPU_GPU_FUNCTION __host__ __device__
#else
#define CPU_GPU_FUNCTION
#endif

#endif
+135
View File
@@ -0,0 +1,135 @@
#ifndef COMMON_H
#define COMMON_H
#include "co_types.h"
#include <cmath>
#include <algorithm>
#if defined(_OPENMP)
#include <omp.h>
#endif
// Deletes the copy constructor and copy assignment operator of `classname`.
// Use inside the class body.  The macro switches the access level to
// `private:`, so members declared after it are private unless another access
// specifier follows.
#define DISABLE_COPY_AND_ASSIGN(classname) \
private:\
  classname(const classname&) = delete;\
  classname& operator=(const classname&) = delete;
// Stores `val` in arr[0] .. arr[N-1]; does nothing when N <= 0.
template <typename T>
CPU_GPU_FUNCTION
void fill(T* arr, int N, T val) {
  T* cursor = arr;
  int remaining = N;
  while(remaining > 0) {
    *cursor = val;
    ++cursor;
    --remaining;
  }
}
// Assigns 0 to arr[0] .. arr[N-1]; does nothing when N <= 0.
template <typename T>
CPU_GPU_FUNCTION
void fill_zero(T* arr, int N) {
  T* cursor = arr;
  int remaining = N;
  while(remaining > 0) {
    *cursor = 0;
    ++cursor;
    --remaining;
  }
}
// Sum of squared component differences between the N-element vectors q and t.
// Note that no square root is taken here; distance_l2 applies it.
template <typename T>
CPU_GPU_FUNCTION
inline T distance_euclidean(const T* q, const T* t, int N) {
  T acc = 0;
  const T* qi = q;
  const T* ti = t;
  for(int left = N; left > 0; --left, ++qi, ++ti) {
    const T delta = *qi - *ti;
    acc += delta * delta;
  }
  return acc;
}
// Square root of distance_euclidean(q, t, N), converted back to T.
template <typename T>
CPU_GPU_FUNCTION
inline T distance_l2(const T* q, const T* t, int N) {
  const T squared = distance_euclidean(q, t, N);
  return std::sqrt(squared);
}
// Functor that writes one fixed value to arr[idx] on every call, so it can be
// handed to an index-iterating driver.
template <typename T>
struct FillFunctor {
  T* arr;       // destination buffer
  const T val;  // value stored by each call

  FillFunctor(T* dst, const T value) : arr(dst), val(value) {}

  CPU_GPU_FUNCTION void operator()(const int idx) {
    *(arr + idx) = val;
  }
};
// Minimum of a and b, usable wherever CPU_GPU_FUNCTION functions are.
// When __CUDA_ARCH__ is defined the unqualified min() is called, otherwise
// std::min from <algorithm>.  Both arguments must have the same type T.
template <typename T>
CPU_GPU_FUNCTION
T mmin(const T& a, const T& b) {
#ifdef __CUDA_ARCH__
  return min(a, b);
#else
  return std::min(a, b);
#endif
}
// Maximum of a and b, usable wherever CPU_GPU_FUNCTION functions are.
// When __CUDA_ARCH__ is defined the unqualified max() is called, otherwise
// std::max from <algorithm>.  Both arguments must have the same type T.
template <typename T>
CPU_GPU_FUNCTION
T mmax(const T& a, const T& b) {
#ifdef __CUDA_ARCH__
  return max(a, b);
#else
  return std::max(a, b);
#endif
}
// Rounds a to the nearest integral value and returns it as T.
// When __CUDA_ARCH__ is defined the unqualified round() is called; on the
// host std::round from <cmath> is used, mirroring how mmin/mmax select
// std::min/std::max, so the host path does not depend on round() also being
// visible in the global namespace.
template <typename T>
CPU_GPU_FUNCTION
T mround(const T& a) {
#ifdef __CUDA_ARCH__
  return round(a);
#else
  return std::round(a);
#endif
}
// Software atomicAdd for double, compiled only when __CUDA_ARCH__ is defined
// and below 600 (presumably because newer architectures ship a native
// overload - confirm against the CUDA documentation).
// Adds `val` to *address and returns the value stored there before the add.
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 600
__device__ double atomicAdd(double* address, double val)
{
  // The 64-bit pattern of the double is updated through atomicCAS on an
  // unsigned long long view of the same address.
  unsigned long long int* address_as_ull =
    (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;
  do {
    // Retry until the value we based the sum on is still the stored one.
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
      __double_as_longlong(val +
        __longlong_as_double(assumed)));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
  return __longlong_as_double(old);
}
#endif
#endif
// Adds `val` to *addr.
// With __CUDA_ARCH__ defined this forwards to atomicAdd.  Otherwise it is a
// plain `+=`, which is marked `#pragma omp atomic` only when the translation
// unit is compiled with OpenMP (_OPENMP defined); without OpenMP the host
// update is not atomic.
template <typename T>
CPU_GPU_FUNCTION
void matomic_add(T* addr, T val) {
#ifdef __CUDA_ARCH__
  atomicAdd(addr, val);
#else
#if defined(_OPENMP)
#pragma omp atomic
#endif
  *addr += val;
#endif
}
#endif
+173
View File
@@ -0,0 +1,173 @@
#ifndef COMMON_CUDA
#define COMMON_CUDA
#include <cublas_v2.h>
#include <stdio.h>
#define DEBUG 0
#define CUDA_DEBUG_DEVICE_SYNC 0
// Evaluates a CUDA runtime call (cudaMalloc, cudaMemcpy, ...) and aborts on
// failure: if the result is not cudaSuccess, the error string, file and line
// are printed and the process exits with code -1.
// When CUDA_DEBUG_DEVICE_SYNC is non-zero, cudaDeviceSynchronize() runs
// before `condition` is evaluated.
#define CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
    cudaError_t error = condition; \
    if(error != cudaSuccess) { \
      printf("%s in %s at %d\n", cudaGetErrorString(error), __FILE__, __LINE__); \
      exit(-1); \
    } \
  } while (0)
/// Map a cuBLAS status code to the name of its enumerator.
/// @param error status value returned by a cuBLAS call
/// @return the enumerator name as a string literal, or
///         "Unknown cublas status" when none of the listed cases matches
inline const char* cublasGetErrorString(cublasStatus_t error) {
// Each case returns the stringified enumerator it matches.
#define CUBLAS_STATUS_NAME_CASE(status) case status: return #status
  switch (error) {
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_SUCCESS);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_NOT_INITIALIZED);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_ALLOC_FAILED);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_INVALID_VALUE);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_ARCH_MISMATCH);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_MAPPING_ERROR);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_EXECUTION_FAILED);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_INTERNAL_ERROR);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_NOT_SUPPORTED);
    CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_LICENSE_ERROR);
  }
#undef CUBLAS_STATUS_NAME_CASE
  return "Unknown cublas status";
}
// cuBLAS counterpart of CUDA_CHECK: evaluates `condition`, and if the status
// is not CUBLAS_STATUS_SUCCESS prints its name with file and line and exits
// the process with code -1.  When CUDA_DEBUG_DEVICE_SYNC is non-zero,
// cudaDeviceSynchronize() runs before `condition` is evaluated.
#define CUBLAS_CHECK(condition) \
  do { \
    if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
    cublasStatus_t status = condition; \
    if(status != CUBLAS_STATUS_SUCCESS) { \
      printf("%s in %s at %d\n", cublasGetErrorString(status), __FILE__, __LINE__); \
      exit(-1); \
    } \
  } while (0)
// Checks for an error after a kernel launch by running CUDA_CHECK on
// cudaPeekAtLastError() and then on cudaGetLastError().
// NOTE(review): the expansion is two statements and already ends in ';', so
// it is not safe as the sole body of an unbraced if/else.
#define CUDA_POST_KERNEL_CHECK \
  CUDA_CHECK(cudaPeekAtLastError()); \
  CUDA_CHECK(cudaGetLastError());
// Loop header for use inside a kernel: `i` starts at the global thread index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total number of
// launched threads (blockDim.x * gridDim.x) while i < n, so every index in
// [0, n) is visited even when fewer than n threads were launched.
// The index is an int.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
// Default number of threads per block used by the helpers in this header.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks of N_THREADS threads needed to cover N items, i.e. the
// integer division (N + N_THREADS - 1) / N_THREADS.  Yields 0 for N == 0.
inline int GET_BLOCKS(const int N, const int N_THREADS=CUDA_NUM_THREADS) {
  const int padded = N + N_THREADS - 1;
  return padded / N_THREADS;
}
// Allocates room for N elements of type T with cudaMalloc and returns the
// pointer.  Failure is handled by CUDA_CHECK.  Release with device_free.
template<typename T>
T* device_malloc(long N) {
  T* dptr;
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMalloc(&dptr, n_bytes));
  if(DEBUG) {
    printf("[DEBUG] device_malloc %p, %ld\n", dptr, N);
  }
  return dptr;
}
// Releases memory with cudaFree; failure is handled by CUDA_CHECK.
// When DEBUG is non-zero the pointer is printed before it is freed.
template<typename T>
void device_free(T* dptr) {
  if(DEBUG) { printf("[DEBUG] device_free %p\n", dptr); }
  CUDA_CHECK(cudaFree(dptr));
}
// Copies N elements of type T from hptr to dptr with a blocking cudaMemcpy
// (cudaMemcpyHostToDevice).  Failure is handled by CUDA_CHECK.
template<typename T>
void host_to_device(const T* hptr, T* dptr, long N) {
  if(DEBUG) {
    printf("[DEBUG] host_to_device %p => %p, %ld\n", hptr, dptr, N);
  }
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMemcpy(dptr, hptr, n_bytes, cudaMemcpyHostToDevice));
}
// Allocates a buffer for N elements via device_malloc, fills it from hptr via
// host_to_device and returns the new pointer.
template<typename T>
T* host_to_device_malloc(const T* hptr, long N) {
  T* const device_copy = device_malloc<T>(N);
  host_to_device(hptr, device_copy, N);
  return device_copy;
}
// Copies N elements of type T from dptr to hptr with a blocking cudaMemcpy
// (cudaMemcpyDeviceToHost).  Failure is handled by CUDA_CHECK.
template<typename T>
void device_to_host(const T* dptr, T* hptr, long N) {
  if(DEBUG) {
    printf("[DEBUG] device_to_host %p => %p, %ld\n", dptr, hptr, N);
  }
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMemcpy(hptr, dptr, n_bytes, cudaMemcpyDeviceToHost));
}
// Allocates a host array of N elements with new[], fills it from dptr via
// device_to_host and returns it.  The array comes from new[], so it has to be
// released with delete[].
template<typename T>
T* device_to_host_malloc(const T* dptr, long N) {
  T* const host_copy = new T[N];
  device_to_host(dptr, host_copy, N);
  return host_copy;
}
// Copies N elements of type T with a blocking cudaMemcpy using
// cudaMemcpyDeviceToDevice.  `dptr` is the source; `hptr` is the destination
// despite its name.  Failure is handled by CUDA_CHECK.
template<typename T>
void device_to_device(const T* dptr, T* hptr, long N) {
  if(DEBUG) {
    printf("[DEBUG] device_to_device %p => %p, %ld\n", dptr, hptr, N);
  }
  const size_t n_bytes = N * sizeof(T);
  CUDA_CHECK(cudaMemcpy(hptr, dptr, n_bytes, cudaMemcpyDeviceToDevice));
}
// https://github.com/parallel-forall/code-samples/blob/master/posts/cuda-aware-mpi-example/src/Device.cu
// https://github.com/treecode/Bonsai/blob/master/runtime/profiling/derived_atomic_functions.h
// Raises *address to `value` if `value` is larger, using atomicCAS on the
// int view of the float.  Nothing is returned.
__device__ __forceinline__ void atomicMaxF(float * const address, const float value) {
  // Cheap early exit; this first read is a plain load, not an atomic.
  if (*address >= value) {
    return;
  }

  int * const address_as_i = (int *)address;
  int old = * address_as_i, assumed;

  do {
    assumed = old;
    // Stop as soon as the observed value is already large enough.
    if (__int_as_float(assumed) >= value) {
      break;
    }
    // Succeeds only if nobody changed the value since `assumed` was read.
    old = atomicCAS(address_as_i, assumed, __float_as_int(value));
  } while (assumed != old);
}
// Lowers *address to `value` if `value` is smaller, using atomicCAS on the
// int view of the float.  Nothing is returned.
__device__ __forceinline__ void atomicMinF(float * const address, const float value) {
  // Cheap early exit; this first read is a plain load, not an atomic.
  if (*address <= value) {
    return;
  }

  int * const address_as_i = (int *)address;
  int old = * address_as_i, assumed;

  do {
    assumed = old;
    // Stop as soon as the observed value is already small enough.
    if (__int_as_float(assumed) <= value) {
      break;
    }
    // Succeeds only if nobody changed the value since `assumed` was read.
    old = atomicCAS(address_as_i, assumed, __float_as_int(value));
  } while (assumed != old);
}
// Kernel that calls functor(idx) for every idx in [0, N), distributing the
// indices over the launched threads via CUDA_KERNEL_LOOP.
// The functor is passed by value.
template <typename FunctorT>
__global__ void iterate_kernel(FunctorT functor, int N) {
  CUDA_KERNEL_LOOP(idx, N) {
    functor(idx);
  }
}
// Runs functor(idx) for every idx in [0, N) on the GPU.
//   functor   - copied to the device as a kernel argument
//   N         - number of indices; N <= 0 is a no-op
//   N_THREADS - threads per block for the launch
// Launch errors are reported through CUDA_POST_KERNEL_CHECK.
template <typename FunctorT>
void iterate_cuda(FunctorT functor, int N, int N_THREADS=CUDA_NUM_THREADS) {
  // GET_BLOCKS(0) is 0 and a launch with a zero-sized grid is rejected as an
  // invalid configuration, which the check below would turn into exit(-1).
  // An empty range simply has nothing to do.
  if(N <= 0) {
    return;
  }
  iterate_kernel<<<GET_BLOCKS(N, N_THREADS), N_THREADS>>>(functor, N);
  CUDA_POST_KERNEL_CHECK;
}
#endif
+347
View File
@@ -0,0 +1,347 @@
#pragma once
#include "common.h"
// Tensor argument checks built on AT_ASSERTM; the failure message starts with
// the argument's spelling (#x).
// CHECK_CUDA asserts the tensor's type is a CUDA type.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
// CHECK_CONTIGUOUS asserts x.is_contiguous().
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// CPU inputs are only checked for contiguity.
#define CHECK_INPUT_CPU(x) CHECK_CONTIGUOUS(x)
// CUDA inputs are checked for device type and contiguity.  The expansion is
// two statements, so do not use it as the sole body of an unbraced if/else.
#define CHECK_INPUT_CUDA(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Brute-force nearest neighbour: for row idx0 of in0, scans every row of in1
// and stores the index of the one with the smallest squared distance in
// out[idx0].  Ties keep the lowest index.  out[idx0] stays -1 when no row has
// a squared distance below the initial bound of 1e9 (e.g. when in1 is empty).
template <typename T, int dim=3>
struct NNFunctor {
  const T* in0; // nelem0 x dim
  const T* in1; // nelem1 x dim
  const long nelem0;
  const long nelem1;
  long* out; // nelem0

  NNFunctor(const T* in0, const T* in1, long nelem0, long nelem1, long* out)
    : in0(in0), in1(in1), nelem0(nelem0), nelem1(nelem1), out(out) {}

  CPU_GPU_FUNCTION void operator()(long idx0) {
    // idx0 \in [nelem0]
    const T* query = in0 + idx0 * dim;

    T best_dist = 1e9;
    long best_idx = -1;
    const T* candidate = in1;
    for(long idx1 = 0; idx1 < nelem1; ++idx1, candidate += dim) {
      // squared distance between the query and the current candidate row
      T dist = 0;
      for(long k = 0; k < dim; ++k) {
        const T delta = query[k] - candidate[k];
        dist += delta * delta;
      }
      if(!(dist < best_dist)) {
        continue;
      }
      best_dist = dist;
      best_idx = idx1;
    }

    out[idx0] = best_idx;
  }
};
// Mutual-consistency test between two index maps: out[idx0] is 1 when
// in0[idx0] points at a valid entry idx1 of in1 and in1[idx1] points back at
// idx0, otherwise 0.
struct CrossCheckFunctor {
  const long* in0; // nelem0
  const long* in1; // nelem1
  const long nelem0;
  const long nelem1;
  uint8_t* out; // nelem0

  CrossCheckFunctor(const long* in0, const long* in1, long nelem0, long nelem1, uint8_t* out)
    : in0(in0), in1(in1), nelem0(nelem0), nelem1(nelem1), out(out) {}

  CPU_GPU_FUNCTION void operator()(long idx0) {
    // idx0 \in [nelem0]
    // Keep the full 64-bit index; narrowing it to int corrupts large values.
    const long idx1 = in0[idx0];
    // Negative entries (e.g. the -1 "no match" marker) and entries past the
    // end of in1 must not be dereferenced.
    const bool in_range = idx1 >= 0 && idx1 < nelem1;
    // idx0 is non-negative, so equality also implies in1[idx1] >= 0.
    out[idx0] = in_range && in1[idx1] == idx0;
  }
};
// Projective nearest-neighbour search.  Each 3-D point of xyz0 is projected
// with K to a pixel position; the points of xyz1 stored in a
// patch_size x patch_size pixel window around that position (same batch
// element) are compared by squared distance and the flat index of the closest
// one is written to out.  out is -1 when the window holds no in-image pixel
// or no candidate is closer than the initial bound of 1e9.
// The template parameter `dim` is not used by the code below.
template <typename T, int dim=3>
struct ProjNNFunctor {
  // xyz0, xyz1 in coord sys of 1
  const T* xyz0; // bs x height x width x 3
  const T* xyz1; // bs x height x width x 3
  const T* K; // 3 x 3
  const long batch_size;
  const long height;
  const long width;
  const long patch_size;
  long* out; // bs x height x width

  ProjNNFunctor(const T* xyz0, const T* xyz1, const T* K, long batch_size, long height, long width, long patch_size, long* out)
    : xyz0(xyz0), xyz1(xyz1), K(K), batch_size(batch_size), height(height), width(width), patch_size(patch_size), out(out) {}

  CPU_GPU_FUNCTION void operator()(long idx0) {
    // idx0 \in [0, bs x height x width]

    // batch element this pixel belongs to
    const long bs = idx0 / (height * width);

    const T x = xyz0[idx0 * 3 + 0];
    const T y = xyz0[idx0 * 3 + 1];
    const T z = xyz0[idx0 * 3 + 2];

    // K is read as 9 consecutive values, row by row; (u, v) is the projection
    // after division by the third row's result.
    // NOTE(review): d == 0 is not guarded against.
    const T d = K[6] * x + K[7] * y + K[8] * z;
    const T u = (K[0] * x + K[1] * y + K[2] * z) / d;
    const T v = (K[3] * x + K[4] * y + K[5] * z) / d;

    // Add 0.5 and convert to int (conversion truncates toward zero).
    int u0 = u + 0.5;
    int v0 = v + 0.5;

    long min_idx1 = -1;
    T min_dist = 1e9;
    for(int pidx = 0; pidx < patch_size*patch_size; ++pidx) {
      // offset inside the window, shifted by patch_size/2
      int pu = pidx % patch_size;
      int pv = pidx / patch_size;

      int u1 = u0 + pu - patch_size/2;
      int v1 = v0 + pv - patch_size/2;

      // candidates outside the image are skipped
      if(u1 >= 0 && v1 >= 0 && u1 < width && v1 < height) {
        const long idx1 = (bs * height + v1) * width + u1;
        const T* xyz1n = xyz1 + idx1 * 3;
        // squared distance; this `d` shadows the projection depth above
        const T d = (x-xyz1n[0]) * (x-xyz1n[0]) + (y-xyz1n[1]) * (y-xyz1n[1]) + (z-xyz1n[2]) * (z-xyz1n[2]);
        if(d < min_dist) {
          min_dist = d;
          min_idx1 = idx1;
        }
      }
    }

    out[idx0] = min_idx1;
  }
};
// Correlation volume between two images.  For output element (d, h, w) a
// block_size x block_size window around (h, w) in in0 is compared with the
// window shifted d pixels to the left in in1.  Per channel the windows are
// mean-centred, their dot product is divided by sqrt(sigma0 * sigma1) + 1e-8,
// and the per-channel results are summed.  Window coordinates are clamped to
// the image border.
// The template parameter `dim` is not used by the code below.
template <typename T, int dim=3>
struct XCorrVolFunctor {
  const T* in0; // channels x height x width
  const T* in1; // channels x height x width
  const long channels;
  const long height;
  const long width;
  const long n_disps;
  const long block_size;
  T* out; // indexed by oidx over n_disps x height x width

  XCorrVolFunctor(const T* in0, const T* in1, long channels, long height, long width, long n_disps, long block_size, T* out) : in0(in0), in1(in1), channels(channels), height(height), width(width), n_disps(n_disps), block_size(block_size), out(out) {}

  CPU_GPU_FUNCTION void operator()(long oidx) {
    // idx0 \in [n_disps x height x width]

    // decompose the flat output index into disparity, row and column
    auto d = oidx / (height * width);
    auto h = (oidx / width) % height;
    auto w = oidx % width;

    long block_size2 = block_size * block_size;

    T val = 0;
    for(int c = 0; c < channels; ++c) {
      // compute means
      T mu0 = 0;
      T mu1 = 0;
      for(int bh = 0; bh < block_size; ++bh) {
        long h0 = h + bh - block_size / 2;
        h0 = mmax(long(0), mmin(height-1, h0));
        for(int bw = 0; bw < block_size; ++bw) {
          long w0 = w + bw - block_size / 2;
          // w1 is derived from the unclamped w0, then both are clamped
          long w1 = w0 - d;
          w0 = mmax(long(0), mmin(width-1, w0));
          w1 = mmax(long(0), mmin(width-1, w1));
          long idx0 = (c * height + h0) * width + w0;
          long idx1 = (c * height + h0) * width + w1;
          mu0 += in0[idx0] / block_size2;
          mu1 += in1[idx1] / block_size2;
        }
      }

      // compute stds and dot product
      T sigma0 = 0;
      T sigma1 = 0;
      T dot = 0;
      for(int bh = 0; bh < block_size; ++bh) {
        long h0 = h + bh - block_size / 2;
        h0 = mmax(long(0), mmin(height-1, h0));
        for(int bw = 0; bw < block_size; ++bw) {
          long w0 = w + bw - block_size / 2;
          long w1 = w0 - d;
          w0 = mmax(long(0), mmin(width-1, w0));
          w1 = mmax(long(0), mmin(width-1, w1));
          long idx0 = (c * height + h0) * width + w0;
          long idx1 = (c * height + h0) * width + w1;
          T v0 = in0[idx0] - mu0;
          T v1 = in1[idx1] - mu1;

          dot += v0 * v1;
          // sigma0/sigma1 accumulate sums of squared deviations
          sigma0 += v0 * v0;
          sigma1 += v1 * v1;
        }
      }

      // the 1e-8 keeps the division finite for constant windows
      T norm = sqrt(sigma0 * sigma1) + 1e-8;
      val += dot / norm;
    }

    out[oidx] = val;
  }
};
// Values of the `type` argument / template parameter of the photometric loss
// functors.  MSE and SAD compare the inputs directly (squared / absolute
// difference); the CENSUS variants compare a smooth step of each pixel's
// difference to the window centre.
const int PHOTOMETRIC_LOSS_MSE = 0;
const int PHOTOMETRIC_LOSS_SAD = 1;
const int PHOTOMETRIC_LOSS_CENSUS_MSE = 2;
const int PHOTOMETRIC_LOSS_CENSUS_SAD = 3;
// Forward pass of the photometric loss.  For each output pixel (n, h, w) the
// per-pixel error between `es` and `ta` is averaged over a
// block_size x block_size window (divided by block_size2) and summed over
// all channels.  `type` selects the error, see the PHOTOMETRIC_LOSS_*
// constants.  Window coordinates are clamped to the image border.
template <typename T, int type>
struct PhotometricLossForward {
  const T* es; // batch_size x channels x height x width;
  const T* ta;
  const int block_size;
  const int block_size2; // block_size * block_size, set by the constructor
  const T eps; // added under the square root of the census step
  const int batch_size;
  const int channels;
  const int height;
  const int width;
  T* out; // indexed by outidx, see operator()

  PhotometricLossForward(const T* es, const T* ta, int block_size, T eps, int batch_size, int channels, int height, int width, T* out) :
    es(es), ta(ta), block_size(block_size), block_size2(block_size*block_size), eps(eps), batch_size(batch_size), channels(channels), height(height), width(width), out(out) {}

  CPU_GPU_FUNCTION void operator()(int outidx) {
    // outidx \in [0, batch_size x height x width]

    int w = outidx % width;
    int h = (outidx / width) % height;
    int n = outidx / (height * width);

    T loss = 0;
    for(int bidx = 0; bidx < block_size2; ++bidx) {
      // position inside the window
      int bh = bidx / block_size;
      int bw = bidx % block_size;

      int h0 = h + bh - block_size / 2;
      int w0 = w + bw - block_size / 2;

      // clamp to the image border
      h0 = mmin(height-1, mmax(0, h0));
      w0 = mmin(width-1, mmax(0, w0));

      for(int c = 0; c < channels; ++c) {
        int inidx = ((n * channels + c) * height + h0) * width + w0;
        if(type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_MSE) {
          T diff = es[inidx] - ta[inidx];
          if(type == PHOTOMETRIC_LOSS_MSE) {
            loss += diff * diff / block_size2;
          }
          else if(type == PHOTOMETRIC_LOSS_SAD) {
            loss += fabs(diff) / block_size2;
          }
        }
        else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
          // index of the window centre in the same channel
          int inidxc = ((n * channels + c) * height + h) * width + w;
          T des = es[inidx] - es[inidxc];
          T dta = ta[inidx] - ta[inidxc];
          // smooth step 0.5 * (1 + x / sqrt(x^2 + eps)) with values in (0, 1)
          T h_des = 0.5 * (1 + des / sqrt(des * des + eps));
          T h_dta = 0.5 * (1 + dta / sqrt(dta * dta + eps));
          T diff = h_des - h_dta;
          // printf("%d,%d %d,%d: des=%f, dta=%f, h_des=%f, h_dta=%f, diff=%f\n", h,w, h0,w0, des,dta, h_des,h_dta, diff);
          // printf("%d,%d %d,%d: h_des=%f = 0.5 * (1 + %f / %f); %f, %f, %f\n", h,w, h0,w0, h_des, des, sqrt(des * des + eps), des*des, des*des+eps, eps);
          if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
            loss += diff * diff / block_size2;
          }
          else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
            loss += fabs(diff) / block_size2;
          }
        }
      }
    }

    out[outidx] = loss;
  }
};
// Backward pass matching PhotometricLossForward.  For each output pixel the
// same window is traversed and the derivative of that pixel's loss with
// respect to `es` is accumulated into grad_in, scaled by grad_out[outidx].
// Accumulation goes through matomic_add because windows of different output
// pixels overlap.  grad_in is only added to, never initialised here.
template <typename T, int type>
struct PhotometricLossBackward {
  const T* es; // batch_size x channels x height x width;
  const T* ta;
  const T* grad_out; // indexed by outidx, see operator()
  const int block_size;
  const int block_size2; // block_size * block_size, set by the constructor
  const T eps;
  const int batch_size;
  const int channels;
  const int height;
  const int width;
  T* grad_in; // batch_size x channels x height x width;

  PhotometricLossBackward(const T* es, const T* ta, const T* grad_out, int block_size, T eps, int batch_size, int channels, int height, int width, T* grad_in) :
    es(es), ta(ta), grad_out(grad_out), block_size(block_size), block_size2(block_size*block_size), eps(eps), batch_size(batch_size), channels(channels), height(height), width(width), grad_in(grad_in) {}

  CPU_GPU_FUNCTION void operator()(int outidx) {
    // outidx \in [0, batch_size x height x width]

    int w = outidx % width;
    int h = (outidx / width) % height;
    int n = outidx / (height * width);

    for(int bidx = 0; bidx < block_size2; ++bidx) {
      int bh = bidx / block_size;
      int bw = bidx % block_size;

      int h0 = h + bh - block_size / 2;
      int w0 = w + bw - block_size / 2;

      // same border clamping as in the forward pass
      h0 = mmin(height-1, mmax(0, h0));
      w0 = mmin(width-1, mmax(0, w0));

      // upstream gradient of this output pixel
      const T go = grad_out[outidx];

      for(int c = 0; c < channels; ++c) {
        int inidx = ((n * channels + c) * height + h0) * width + w0;
        if(type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_MSE) {
          T diff = es[inidx] - ta[inidx];
          T grad = 0;
          if(type == PHOTOMETRIC_LOSS_MSE) {
            grad = 2 * diff;
          }
          else if(type == PHOTOMETRIC_LOSS_SAD) {
            // sign of diff, 0 at diff == 0
            grad = diff < 0 ? -1 : (diff > 0 ? 1 : 0);
          }
          grad = grad / block_size2 * go;
          matomic_add(grad_in + inidx, grad);
        }
        else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
          int inidxc = ((n * channels + c) * height + h) * width + w;
          T des = es[inidx] - es[inidxc];
          T dta = ta[inidx] - ta[inidxc];
          T h_des = 0.5 * (1 + des / sqrt(des * des + eps));
          T h_dta = 0.5 * (1 + dta / sqrt(dta * dta + eps));
          T diff = h_des - h_dta;

          T grad_loss = 0;
          if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
            grad_loss = 2 * diff;
          }
          else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
            grad_loss = diff < 0 ? -1 : (diff > 0 ? 1 : 0);
          }
          grad_loss = grad_loss / block_size2;

          // derivative of the smooth step with respect to des:
          // 0.5 * eps / (des^2 + eps)^(3/2)
          T tmp = des * des + eps;
          T grad_heaviside = 0.5 * eps / sqrt(tmp * tmp * tmp);

          T grad = go * grad_loss * grad_heaviside;
          // des = es[inidx] - es[inidxc], so the centre pixel receives the
          // negated gradient
          matomic_add(grad_in + inidx, grad);
          matomic_add(grad_in + inidxc, -grad);
        }
      }
    }
  }
};
+198
View File
@@ -0,0 +1,198 @@
#include <torch/extension.h>
#include <iostream>
#include "ext.h"
// Serial CPU driver: calls functor(idx) for idx = 0, 1, ..., N-1 in order.
// The functor is taken by value.  N <= 0 performs no calls.
// N and the loop index are `long` because callers pass 64-bit element counts
// (tensor sizes / numel()); an int parameter would silently truncate them.
template <typename FunctorT>
void iterate_cpu(FunctorT functor, long N) {
  for(long idx = 0; idx < N; ++idx) {
    functor(idx);
  }
}
// CPU nearest-neighbour lookup.
//   in0 - contiguous N0 x 3 floating point tensor (queries)
//   in1 - contiguous N1 x 3 floating point tensor (candidates)
// Returns a CPU long tensor of N0 entries filled by NNFunctor: for every row
// of in0 the index of the closest row of in1.
at::Tensor nn_cpu(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CPU(in0);
  CHECK_INPUT_CPU(in1);

  // Validate the rank before size(1) is read, so a tensor of the wrong rank
  // produces the intended message instead of an indexing error.
  AT_ASSERTM(in0.dim() == 2, "in0 has to be N0 x 3");
  AT_ASSERTM(in1.dim() == 2, "in1 has to be N1 x 3");

  auto nelem0 = in0.size(0);
  auto nelem1 = in1.size(0);
  auto dim = in0.size(1);

  AT_ASSERTM(dim == in1.size(1), "in0 and in1 have to be the same shape");
  AT_ASSERTM(dim == 3, "dim hast to be 3");

  auto out = at::empty({nelem0}, torch::CPU(at::kLong));

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "nn", ([&] {
    iterate_cpu(
      NNFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), nelem0, nelem1, out.data<long>()),
      nelem0);
  }));

  return out;
}
// CPU cross-check of two 1-D index tensors (read as long): returns a CPU byte
// tensor with one entry per element of in0, filled by CrossCheckFunctor.
at::Tensor crosscheck_cpu(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CPU(in0)
  CHECK_INPUT_CPU(in1)

  AT_ASSERTM(in0.dim() == 1, "")
  AT_ASSERTM(in1.dim() == 1, "")

  auto n0 = in0.size(0);
  auto n1 = in1.size(0);

  auto out = at::empty({n0}, torch::CPU(at::kByte));

  CrossCheckFunctor functor(in0.data<long>(), in1.data<long>(), n0, n1, out.data<uint8_t>());
  iterate_cpu(functor, n0);

  return out;
}
// CPU projective nearest-neighbour lookup.
//   xyz0, xyz1 - contiguous bs x height x width x 3 floating point tensors
//   K          - contiguous tensor whose first 9 values are read as a 3 x 3
//                matrix
//   patch_size - side length of the search window in pixels
// Returns a CPU long tensor of shape bs x height x width filled by
// ProjNNFunctor.
at::Tensor proj_nn_cpu(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size) {
  CHECK_INPUT_CPU(xyz0);
  CHECK_INPUT_CPU(xyz1);
  CHECK_INPUT_CPU(K);

  // Rank first: size(3) below is only valid for 4-D tensors.
  AT_ASSERTM(xyz0.dim() == 4, "xyz0 has to be bs x height x width x 3");
  AT_ASSERTM(xyz1.dim() == 4, "xyz1 has to be bs x height x width x 3");
  AT_ASSERTM(xyz0.size(3) == 3, "last dimension of xyz0 has to be 3");
  for(int d = 0; d < 4; ++d) {
    AT_ASSERTM(xyz0.size(d) == xyz1.size(d), "xyz0 and xyz1 have to be the same shape");
  }
  // The functor reads K[0] .. K[8] unconditionally.
  AT_ASSERTM(K.numel() >= 9, "K has to hold at least 3 x 3 values");

  auto batch_size = xyz0.size(0);
  auto height = xyz0.size(1);
  auto width = xyz0.size(2);

  auto out = at::empty({batch_size, height, width}, torch::CPU(at::kLong));

  AT_DISPATCH_FLOATING_TYPES(xyz0.scalar_type(), "proj_nn", ([&] {
    iterate_cpu(
      ProjNNFunctor<scalar_t>(xyz0.data<scalar_t>(), xyz1.data<scalar_t>(), K.data<scalar_t>(), batch_size, height, width, patch_size, out.data<long>()),
      batch_size * height * width);
  }));

  return out;
}
// CPU correlation volume.
//   in0, in1   - contiguous channels x height x width floating point tensors
//   n_disps    - number of disparities (size of the first output dimension)
//   block_size - side length of the comparison window
// Returns a tensor of shape n_disps x height x width with in0's options,
// filled by XCorrVolFunctor.
at::Tensor xcorrvol_cpu(at::Tensor in0, at::Tensor in1, int n_disps, int block_size) {
  CHECK_INPUT_CPU(in0);
  CHECK_INPUT_CPU(in1);

  // The functor indexes both inputs with in0's layout, so a smaller or
  // differently shaped in1 would be read out of bounds.
  AT_ASSERTM(in0.dim() == 3, "in0 has to be channels x height x width");
  AT_ASSERTM(in1.dim() == 3, "in1 has to be channels x height x width");
  for(int d = 0; d < 3; ++d) {
    AT_ASSERTM(in0.size(d) == in1.size(d), "in0 and in1 have to be the same shape");
  }

  auto channels = in0.size(0);
  auto height = in0.size(1);
  auto width = in0.size(2);

  auto out = at::empty({n_disps, height, width}, in0.options());

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "xcorrvol", ([&] {
    iterate_cpu(
      XCorrVolFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), channels, height, width, n_disps, block_size, out.data<scalar_t>()),
      n_disps * height * width);
  }));

  return out;
}
// CPU forward pass of the photometric loss.
//   es, ta     - contiguous batch_size x channels x height x width floating
//                point tensors of identical shape
//   block_size - side length of the averaging window
//   type       - one of the PHOTOMETRIC_LOSS_* constants
//   eps        - stabiliser used by the census variants
// Returns a batch_size x 1 x height x width tensor with es's options.
at::Tensor photometric_loss_forward(at::Tensor es, at::Tensor ta, int block_size, int type, float eps) {
  CHECK_INPUT_CPU(es);
  CHECK_INPUT_CPU(ta);

  // Both inputs are indexed with es's layout inside the functor.
  AT_ASSERTM(es.dim() == 4, "es has to be batch_size x channels x height x width");
  AT_ASSERTM(ta.dim() == 4, "ta has to be batch_size x channels x height x width");
  for(int d = 0; d < 4; ++d) {
    AT_ASSERTM(es.size(d) == ta.size(d), "es and ta have to be the same shape");
  }
  // Without this an unknown type skipped every branch below and the
  // uninitialised at::empty buffer was returned.
  AT_ASSERTM(type == PHOTOMETRIC_LOSS_MSE || type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE || type == PHOTOMETRIC_LOSS_CENSUS_SAD, "unknown photometric loss type");

  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  auto out = at::empty({batch_size, 1, height, width}, es.options());

  // `type` is a template parameter of the functor, hence one branch per value.
  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_forward_cpu", ([&] {
    if(type == PHOTOMETRIC_LOSS_MSE) {
      iterate_cpu(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_SAD) {
      iterate_cpu(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
      iterate_cpu(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
      iterate_cpu(
        PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
        out.numel());
    }
  }));

  return out;
}
// CPU backward pass of the photometric loss.
//   es, ta     - the forward inputs (batch_size x channels x height x width)
//   grad_out   - gradient of the forward output, one value per
//                (batch, height, width) position
//   block_size, type, eps - as in the forward pass
// Returns the gradient with respect to es, shaped like es, created with
// grad_out's options.
at::Tensor photometric_loss_backward(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps) {
  CHECK_INPUT_CPU(es);
  CHECK_INPUT_CPU(ta);
  CHECK_INPUT_CPU(grad_out);

  AT_ASSERTM(es.dim() == 4, "es has to be batch_size x channels x height x width");
  AT_ASSERTM(ta.dim() == 4, "ta has to be batch_size x channels x height x width");
  for(int d = 0; d < 4; ++d) {
    AT_ASSERTM(es.size(d) == ta.size(d), "es and ta have to be the same shape");
  }
  // An unknown type previously returned an all-zero gradient without notice.
  AT_ASSERTM(type == PHOTOMETRIC_LOSS_MSE || type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE || type == PHOTOMETRIC_LOSS_CENSUS_SAD, "unknown photometric loss type");

  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  // The functor runs once per grad_out element and derives (n, h, w) from the
  // flat index, so the element count has to match the forward output.
  AT_ASSERTM(grad_out.numel() == batch_size * height * width, "grad_out has to be batch_size x 1 x height x width");

  // zeros, not empty: the functor only accumulates into grad_in
  auto grad_in = at::zeros({batch_size, channels, height, width}, grad_out.options());

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_backward_cpu", ([&] {
    if(type == PHOTOMETRIC_LOSS_MSE) {
      iterate_cpu(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_SAD) {
      iterate_cpu(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
      iterate_cpu(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
    else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
      iterate_cpu(
        PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
        grad_out.numel());
    }
  }));

  return grad_in;
}
// Python bindings of the CPU extension.  The first four functions carry a
// _cpu suffix in their Python name; the two photometric loss functions are
// exported without suffix and without a docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nn_cpu", &nn_cpu, "nn_cpu");
  m.def("crosscheck_cpu", &crosscheck_cpu, "crosscheck_cpu");
  m.def("proj_nn_cpu", &proj_nn_cpu, "proj_nn_cpu");

  m.def("xcorrvol_cpu", &xcorrvol_cpu, "xcorrvol_cpu");

  m.def("photometric_loss_forward", &photometric_loss_forward);
  m.def("photometric_loss_backward", &photometric_loss_backward);
}
+135
View File
@@ -0,0 +1,135 @@
#include <torch/extension.h>
#include <iostream>
#include "ext.h"
// Implemented in the CUDA translation unit; fills `out` on the GPU.
void nn_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out);

// CUDA nearest-neighbour lookup.
//   in0 - contiguous CUDA tensor, N0 x 3 (queries)
//   in1 - contiguous CUDA tensor, N1 x 3 (candidates)
// Returns a CUDA long tensor with N0 entries filled by nn_kernel.
at::Tensor nn_cuda(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CUDA(in0);
  CHECK_INPUT_CUDA(in1);

  // Validate the rank before size(1) is read, so a tensor of the wrong rank
  // produces the intended message instead of an indexing error.
  AT_ASSERTM(in0.dim() == 2, "in0 has to be N0 x 3");
  AT_ASSERTM(in1.dim() == 2, "in1 has to be N1 x 3");

  auto nelem0 = in0.size(0);
  auto dim = in0.size(1);

  AT_ASSERTM(dim == in1.size(1), "in0 and in1 have to be the same shape");
  AT_ASSERTM(dim == 3, "dim hast to be 3");

  auto out = at::empty({nelem0}, torch::CUDA(at::kLong));

  nn_kernel(in0, in1, out);

  return out;
}
// Implemented in the CUDA translation unit; fills `out` on the GPU.
void crosscheck_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out);

// CUDA cross-check of two 1-D index tensors: returns a CUDA byte tensor with
// one entry per element of in0, filled by crosscheck_kernel.
at::Tensor crosscheck_cuda(at::Tensor in0, at::Tensor in1) {
  CHECK_INPUT_CUDA(in0)
  CHECK_INPUT_CUDA(in1)

  AT_ASSERTM(in0.dim() == 1, "")
  AT_ASSERTM(in1.dim() == 1, "")

  auto n_queries = in0.size(0);
  auto result = at::empty({n_queries}, torch::CUDA(at::kByte));

  crosscheck_kernel(in0, in1, result);
  return result;
}
// Implemented in the CUDA translation unit; fills `out` on the GPU.
void proj_nn_kernel(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size, at::Tensor out);

// CUDA projective nearest-neighbour lookup.
//   xyz0, xyz1 - contiguous CUDA tensors, bs x height x width x 3
//   K          - contiguous CUDA tensor whose first 9 values are read as a
//                3 x 3 matrix
//   patch_size - side length of the search window in pixels
// Returns a CUDA long tensor of shape bs x height x width.
at::Tensor proj_nn_cuda(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size) {
  CHECK_INPUT_CUDA(xyz0);
  CHECK_INPUT_CUDA(xyz1);
  CHECK_INPUT_CUDA(K);

  // Rank first: size(3) below is only valid for 4-D tensors.
  AT_ASSERTM(xyz0.dim() == 4, "xyz0 has to be bs x height x width x 3");
  AT_ASSERTM(xyz1.dim() == 4, "xyz1 has to be bs x height x width x 3");
  AT_ASSERTM(xyz0.size(3) == 3, "last dimension of xyz0 has to be 3");
  for(int d = 0; d < 4; ++d) {
    AT_ASSERTM(xyz0.size(d) == xyz1.size(d), "xyz0 and xyz1 have to be the same shape");
  }
  // The functor reads K[0] .. K[8] unconditionally.
  AT_ASSERTM(K.numel() >= 9, "K has to hold at least 3 x 3 values");

  auto batch_size = xyz0.size(0);
  auto height = xyz0.size(1);
  auto width = xyz0.size(2);

  auto out = at::empty({batch_size, height, width}, torch::CUDA(at::kLong));

  proj_nn_kernel(xyz0, xyz1, K, patch_size, out);

  return out;
}
// Implemented in the CUDA translation unit; fills `out` on the GPU.
void xcorrvol_kernel(at::Tensor in0, at::Tensor in1, int n_disps, int block_size, at::Tensor out);

// CUDA correlation volume.
//   in0, in1   - contiguous CUDA tensors, channels x height x width
//   n_disps    - number of disparities (size of the first output dimension)
//   block_size - side length of the comparison window
// Returns a tensor of shape n_disps x height x width with in0's options.
at::Tensor xcorrvol_cuda(at::Tensor in0, at::Tensor in1, int n_disps, int block_size) {
  CHECK_INPUT_CUDA(in0);
  CHECK_INPUT_CUDA(in1);

  // The functor indexes both inputs with in0's layout, so a smaller or
  // differently shaped in1 would be read out of bounds.
  AT_ASSERTM(in0.dim() == 3, "in0 has to be channels x height x width");
  AT_ASSERTM(in1.dim() == 3, "in1 has to be channels x height x width");
  for(int d = 0; d < 3; ++d) {
    AT_ASSERTM(in0.size(d) == in1.size(d), "in0 and in1 have to be the same shape");
  }

  auto height = in0.size(1);
  auto width = in0.size(2);

  auto out = at::empty({n_disps, height, width}, in0.options());

  xcorrvol_kernel(in0, in1, n_disps, block_size, out);

  return out;
}
// Implemented in the CUDA translation unit; fills `out` on the GPU.
void photometric_loss_forward_kernel(at::Tensor es, at::Tensor ta, int block_size, int type, float eps, at::Tensor out);

// CUDA forward pass of the photometric loss.
//   es, ta     - contiguous CUDA tensors of identical shape,
//                batch_size x channels x height x width
//   block_size - side length of the averaging window
//   type       - one of the PHOTOMETRIC_LOSS_* constants
//   eps        - stabiliser used by the census variants
// Returns a batch_size x 1 x height x width tensor with es's options.
at::Tensor photometric_loss_forward(at::Tensor es, at::Tensor ta, int block_size, int type, float eps) {
  CHECK_INPUT_CUDA(es);
  CHECK_INPUT_CUDA(ta);

  // Both inputs are indexed with es's layout inside the functor.
  AT_ASSERTM(es.dim() == 4, "es has to be batch_size x channels x height x width");
  AT_ASSERTM(ta.dim() == 4, "ta has to be batch_size x channels x height x width");
  for(int d = 0; d < 4; ++d) {
    AT_ASSERTM(es.size(d) == ta.size(d), "es and ta have to be the same shape");
  }
  // The kernel launcher does nothing for an unknown type, which would hand
  // the uninitialised at::empty buffer back to the caller.
  AT_ASSERTM(type == PHOTOMETRIC_LOSS_MSE || type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE || type == PHOTOMETRIC_LOSS_CENSUS_SAD, "unknown photometric loss type");

  auto batch_size = es.size(0);
  auto height = es.size(2);
  auto width = es.size(3);

  auto out = at::empty({batch_size, 1, height, width}, es.options());

  photometric_loss_forward_kernel(es, ta, block_size, type, eps, out);

  return out;
}
// Implemented in the CUDA translation unit; accumulates into `grad_in`.
void photometric_loss_backward_kernel(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps, at::Tensor grad_in);

// CUDA backward pass of the photometric loss.
//   es, ta     - the forward inputs (batch_size x channels x height x width)
//   grad_out   - gradient of the forward output, one value per
//                (batch, height, width) position
//   block_size, type, eps - as in the forward pass
// Returns the gradient with respect to es, shaped like es, created with
// grad_out's options.
at::Tensor photometric_loss_backward(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps) {
  CHECK_INPUT_CUDA(es);
  CHECK_INPUT_CUDA(ta);
  CHECK_INPUT_CUDA(grad_out);

  AT_ASSERTM(es.dim() == 4, "es has to be batch_size x channels x height x width");
  AT_ASSERTM(ta.dim() == 4, "ta has to be batch_size x channels x height x width");
  for(int d = 0; d < 4; ++d) {
    AT_ASSERTM(es.size(d) == ta.size(d), "es and ta have to be the same shape");
  }
  // An unknown type previously returned an all-zero gradient without notice.
  AT_ASSERTM(type == PHOTOMETRIC_LOSS_MSE || type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE || type == PHOTOMETRIC_LOSS_CENSUS_SAD, "unknown photometric loss type");

  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  // The functor runs once per grad_out element and derives (n, h, w) from the
  // flat index, so the element count has to match the forward output.
  AT_ASSERTM(grad_out.numel() == batch_size * height * width, "grad_out has to be batch_size x 1 x height x width");

  // zeros, not empty: the kernel only accumulates into grad_in
  auto grad_in = at::zeros({batch_size, channels, height, width}, grad_out.options());

  photometric_loss_backward_kernel(es, ta, grad_out, block_size, type, eps, grad_in);

  return grad_in;
}
// Python bindings of the CUDA extension.  The first four functions carry a
// _cuda suffix in their Python name; the two photometric loss functions are
// exported under the same names as in the CPU extension and without a
// docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nn_cuda", &nn_cuda, "nn_cuda");
  m.def("crosscheck_cuda", &crosscheck_cuda, "crosscheck_cuda");
  m.def("proj_nn_cuda", &proj_nn_cuda, "proj_nn_cuda");

  m.def("xcorrvol_cuda", &xcorrvol_cuda, "xcorrvol_cuda");

  m.def("photometric_loss_forward", &photometric_loss_forward);
  m.def("photometric_loss_backward", &photometric_loss_backward);
}
+112
View File
@@ -0,0 +1,112 @@
#include <ATen/ATen.h>
#include "ext.h"
#include "common_cuda.h"
// Launches NNFunctor on the GPU, one functor call per row of in0, writing the
// index of the nearest row of in1 into `out` (read as long).
// Inputs are not validated here.
void nn_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out) {
  auto nelem0 = in0.size(0);
  auto nelem1 = in1.size(0);
  // The point dimension is fixed by NNFunctor's default template argument,
  // so the previously unused local holding in0.size(1) was dropped.

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "nn", ([&] {
    iterate_cuda(
      NNFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), nelem0, nelem1, out.data<long>()),
      nelem0);
  }));
}
// Launches CrossCheckFunctor on the GPU, one functor call per element of in0.
// in0 and in1 are read as long, `out` is written as uint8_t.
void crosscheck_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out) {
  auto n0 = in0.size(0);
  auto n1 = in1.size(0);

  CrossCheckFunctor functor(in0.data<long>(), in1.data<long>(), n0, n1, out.data<uint8_t>());
  iterate_cuda(functor, n0);
}
// Launches ProjNNFunctor on the GPU, one functor call per
// (batch, row, column) position of xyz0; `out` is written as long.
void proj_nn_kernel(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size, at::Tensor out) {
  auto batch_size = xyz0.size(0);
  auto height = xyz0.size(1);
  auto width = xyz0.size(2);
  auto n_pixels = batch_size * height * width;

  AT_DISPATCH_FLOATING_TYPES(xyz0.scalar_type(), "proj_nn", ([&] {
    ProjNNFunctor<scalar_t> functor(xyz0.data<scalar_t>(), xyz1.data<scalar_t>(), K.data<scalar_t>(), batch_size, height, width, patch_size, out.data<long>());
    iterate_cuda(functor, n_pixels);
  }));
}
// Launches XCorrVolFunctor on the GPU, one functor call per output element
// (n_disps x height x width), using 512 threads per block.
void xcorrvol_kernel(at::Tensor in0, at::Tensor in1, int n_disps, int block_size, at::Tensor out) {
  auto channels = in0.size(0);
  auto height = in0.size(1);
  auto width = in0.size(2);
  auto n_out = n_disps * height * width;

  AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "xcorrvol", ([&] {
    XCorrVolFunctor<scalar_t> functor(in0.data<scalar_t>(), in1.data<scalar_t>(), channels, height, width, n_disps, block_size, out.data<scalar_t>());
    iterate_cuda(functor, n_out, 512);
  }));
}
// Launches PhotometricLossForward on the GPU, one functor call per element of
// `out`.  The runtime `type` selects which functor instantiation runs; for a
// value matching none of the PHOTOMETRIC_LOSS_* constants nothing is launched.
void photometric_loss_forward_kernel(at::Tensor es, at::Tensor ta, int block_size, int type, float eps, at::Tensor out) {
  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_forward_cuda", ([&] {
    switch(type) {
      case PHOTOMETRIC_LOSS_MSE:
        iterate_cuda(
          PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
          out.numel());
        break;
      case PHOTOMETRIC_LOSS_SAD:
        iterate_cuda(
          PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
          out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_MSE:
        iterate_cuda(
          PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
          out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_SAD:
        iterate_cuda(
          PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
          out.numel());
        break;
    }
  }));
}
// Launches PhotometricLossBackward on the GPU, one functor call per element
// of grad_out, accumulating into grad_in.  The runtime `type` selects which
// functor instantiation runs; for a value matching none of the
// PHOTOMETRIC_LOSS_* constants nothing is launched.
void photometric_loss_backward_kernel(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps, at::Tensor grad_in) {
  auto batch_size = es.size(0);
  auto channels = es.size(1);
  auto height = es.size(2);
  auto width = es.size(3);

  AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_backward_cuda", ([&] {
    switch(type) {
      case PHOTOMETRIC_LOSS_MSE:
        iterate_cuda(
          PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
          grad_out.numel());
        break;
      case PHOTOMETRIC_LOSS_SAD:
        iterate_cuda(
          PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
          grad_out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_MSE:
        iterate_cuda(
          PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
          grad_out.numel());
        break;
      case PHOTOMETRIC_LOSS_CENSUS_SAD:
        iterate_cuda(
          PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
          grad_out.numel());
        break;
    }
  }));
}
+147
View File
@@ -0,0 +1,147 @@
import torch
from . import ext_cpu
from . import ext_cuda
class NNFunction(torch.autograd.Function):
    """Autograd wrapper around the ``nn_cpu`` / ``nn_cuda`` extension calls.

    The operation is treated as non-differentiable: ``backward`` yields no
    gradient for either input.
    """

    @staticmethod
    def forward(ctx, in0, in1):
        """Run the extension that matches the device of ``in0``."""
        impl = ext_cuda.nn_cuda if in0.is_cuda else ext_cpu.nn_cpu
        return impl(in0, in1)

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``None`` for both forward inputs."""
        return (None,) * 2
def nn(in0, in1):
    """Apply :class:`NNFunction` to ``in0`` and ``in1`` and return its output.

    NOTE(review): presumably a nearest-neighbour lookup; confirm against the
    ``ext_cpu`` / ``ext_cuda`` extension sources.
    """
    result = NNFunction.apply(in0, in1)
    return result
class CrossCheckFunction(torch.autograd.Function):
    """Autograd wrapper around the ``crosscheck_cpu`` / ``crosscheck_cuda`` extension calls.

    The operation is treated as non-differentiable: ``backward`` yields no
    gradient for either input.
    """

    @staticmethod
    def forward(ctx, in0, in1):
        """Run the extension that matches the device of ``in0``."""
        impl = ext_cuda.crosscheck_cuda if in0.is_cuda else ext_cpu.crosscheck_cpu
        return impl(in0, in1)

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``None`` for both forward inputs."""
        return (None,) * 2
def crosscheck(in0, in1):
    """Apply :class:`CrossCheckFunction` to ``in0`` and ``in1`` and return its output."""
    result = CrossCheckFunction.apply(in0, in1)
    return result
class ProjNNFunction(torch.autograd.Function):
    """Autograd wrapper around the ``proj_nn_cpu`` / ``proj_nn_cuda`` extension calls.

    The operation is treated as non-differentiable: ``backward`` yields no
    gradient for any of the four forward arguments.
    """

    @staticmethod
    def forward(ctx, xyz0, xyz1, K, patch_size):
        """Run the extension that matches the device of ``xyz0``."""
        impl = ext_cuda.proj_nn_cuda if xyz0.is_cuda else ext_cpu.proj_nn_cpu
        return impl(xyz0, xyz1, K, patch_size)

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``None`` for every forward argument."""
        return (None,) * 4
def proj_nn(xyz0, xyz1, K, patch_size):
    """Apply :class:`ProjNNFunction` to the given arguments and return its output."""
    result = ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
    return result
class XCorrVolFunction(torch.autograd.Function):
    """Autograd wrapper around the ``xcorrvol_cpu`` / ``xcorrvol_cuda`` extension calls.

    The operation is treated as non-differentiable: ``backward`` yields no
    gradient for any of the four forward arguments.
    """

    @staticmethod
    def forward(ctx, in0, in1, n_disps, block_size):
        """Run the extension that matches the device of ``in0``."""
        impl = ext_cuda.xcorrvol_cuda if in0.is_cuda else ext_cpu.xcorrvol_cpu
        return impl(in0, in1, n_disps, block_size)

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``None`` for every forward argument."""
        return (None,) * 4
def xcorrvol(in0, in1, n_disps, block_size):
    """Apply :class:`XCorrVolFunction` to the given arguments and return its output."""
    result = XCorrVolFunction.apply(in0, in1, n_disps, block_size)
    return result
class PhotometricLossFunction(torch.autograd.Function):
    """Autograd wrapper around the photometric-loss extension kernels.

    Only ``es`` receives a gradient; ``ta``, ``block_size``, ``type`` and
    ``eps`` get ``None`` from ``backward``.
    """

    @staticmethod
    def forward(ctx, es, ta, block_size, type, eps):
        """Evaluate the loss with the extension matching the device of ``es``.

        Both tensors and the scalar settings are stashed on ``ctx`` so that
        ``backward`` can hand them to the backward kernel.
        """
        ctx.save_for_backward(es, ta)
        ctx.block_size, ctx.type, ctx.eps = block_size, type, eps

        ext = ext_cuda if es.is_cuda else ext_cpu
        return ext.photometric_loss_forward(es, ta, block_size, type, eps)

    @staticmethod
    def backward(ctx, grad_out):
        """Compute the gradient w.r.t. ``es`` with the extension matching ``grad_out``'s device."""
        es, ta = ctx.saved_tensors
        # The incoming gradient is made contiguous before it reaches the kernel.
        kernel_args = (es, ta, grad_out.contiguous(), ctx.block_size, ctx.type, ctx.eps)

        ext = ext_cuda if grad_out.is_cuda else ext_cpu
        grad_es = ext.photometric_loss_backward(*kernel_args)
        return grad_es, None, None, None, None
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
    """Evaluate the photometric loss through :class:`PhotometricLossFunction`.

    Args:
        es, ta: tensors forwarded unchanged to the autograd function.
        block_size: forwarded unchanged.
        type: loss name, case-insensitive; one of ``'mse'``, ``'sad'``,
            ``'census_mse'`` or ``'census_sad'``.
        eps: forwarded unchanged.

    Raises:
        Exception: if ``type`` is not one of the supported names.
    """
    # Integer ids expected by the extension, keyed by lower-case loss name.
    type_ids = {'mse': 0, 'sad': 1, 'census_mse': 2, 'census_sad': 3}

    type = type.lower()
    if type not in type_ids:
        raise Exception('invalid loss type')
    return PhotometricLossFunction.apply(es, ta, block_size, type_ids[type], eps)
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    """Photometric loss written with stock PyTorch operators.

    Both inputs are replicate-padded by ``block_size // 2`` and unfolded into
    ``block_size`` x ``block_size`` patches around every pixel. The per-patch
    residual is summed over channels and patch positions and divided by
    ``block_size**2``.

    Args:
        es, ta: 4-D tensors indexed as (batch, channel, height, width).
        block_size: side length of the square patch.
        type: ``'mse'``, ``'sad'``, ``'census_mse'`` or ``'census_sad'``
            (case-insensitive).
        eps: added under the square root of the census terms.

    Returns:
        Tensor of shape (batch, 1, height, width).

    Raises:
        Exception: if ``type`` is not one of the supported names.
    """
    type = type.lower()
    p = block_size // 2

    def _patches(img):
        # (B, C, H, W) -> (B, C, patch positions, H, W)
        padded = torch.nn.functional.pad(img, (p, p, p, p), mode='replicate')
        columns = torch.nn.functional.unfold(padded, kernel_size=block_size)
        return columns.view(img.shape[0], img.shape[1], -1, img.shape[2], img.shape[3])

    es_uf = _patches(es)
    ta_uf = _patches(ta)

    if type == 'mse':
        ref = (es_uf - ta_uf)**2
    elif type == 'sad':
        ref = torch.abs(es_uf - ta_uf)
    elif type in ('census_mse', 'census_sad'):
        # Difference of every patch entry to the patch's centre pixel,
        # squashed into (0, 1) by x / sqrt(x^2 + eps).
        des = es_uf - es.unsqueeze(2)
        dta = ta_uf - ta.unsqueeze(2)
        h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
        h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
        diff = h_des - h_dta
        ref = diff * diff if type == 'census_mse' else torch.abs(diff)
    else:
        raise Exception('invalid loss type')

    ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
    return torch.sum(ref, dim=1, keepdim=True) / block_size**2
+27
View File
@@ -0,0 +1,27 @@
import torch
import math
import numpy as np
from .functions import *
class CoordConv2d(torch.nn.Module):
    """2-D convolution whose input is augmented with two coordinate channels.

    A ``u`` (horizontal) and a ``v`` (vertical) channel, each running linearly
    from -1 to 1 across the image, are concatenated to the input before the
    convolution is applied.

    Args:
        channels_in: number of channels of the input tensor (the wrapped
            convolution receives ``channels_in + 2``).
        channels_out: number of output channels.
        kernel_size, stride, padding: forwarded to ``torch.nn.Conv2d``.
    """

    def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
        # Cached (1, 2, H, W) float32 coordinate grid; built lazily in forward().
        self.uv = None

    @staticmethod
    def _make_uv(height, width, device):
        """Build the (1, 2, height, width) float32 coordinate grid on ``device``."""
        u, v = np.meshgrid(range(width), range(height))
        # max(..., 1) keeps a 1-pixel axis finite (it maps to -1) instead of
        # producing 0/0 = NaN coordinates.
        u = 2 * u / max(width - 1, 1) - 1
        v = 2 * v / max(height - 1, 1) - 1
        uv = np.stack((u, v)).reshape(1, 2, height, width)
        return torch.from_numpy(uv.astype(np.float32)).to(device)

    def forward(self, x):
        """Append the coordinate channels to ``x`` (B, C, H, W) and convolve."""
        height, width = x.shape[2], x.shape[3]
        # The cached grid is only valid for the spatial size and device it was
        # built for; rebuild it whenever either changes.
        stale = (
            self.uv is None
            or tuple(self.uv.shape[2:]) != (height, width)
            or self.uv.device != x.device
        )
        if stale:
            self.uv = self._make_uv(height, width, x.device)
        uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
        xuv = torch.cat((x, uv), dim=1)
        return self.conv(xuv)
+24
View File
@@ -0,0 +1,24 @@
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
import os
# Directory that contains this setup script (not referenced further below).
this_dir = os.path.dirname(os.path.realpath(__file__))

# No extra header search paths are passed to the build.
include_dirs = []

# Flags forwarded to nvcc when compiling the CUDA extension.
nvcc_args = [
    '-arch=sm_30',
    '-gencode=arch=compute_30,code=sm_30',
    '-gencode=arch=compute_35,code=sm_35',
]

# CPU-only extension built from a single C++ source.
cpu_extension = CppExtension('ext_cpu', ['ext/ext_cpu.cpp'])

# CUDA extension: C++ binding code plus the kernel translation unit.
cuda_extension = CUDAExtension(
    'ext_cuda',
    ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
    extra_compile_args={'cxx': [], 'nvcc': nvcc_args},
)

setup(
    name='ext',
    ext_modules=[cpu_extension, cuda_extension],
    cmdclass={'build_ext': BuildExtension},
    include_dirs=include_dirs,
)
+528
View File
@@ -0,0 +1,528 @@
import numpy as np
import torch
import random
import logging
import datetime
from pathlib import Path
import argparse
import subprocess
import socket
import sys
import os
import gc
import json
import matplotlib.pyplot as plt
import time
from collections import OrderedDict
class StopWatch(object):
    """Accumulates wall-clock durations for named sections.

    ``start(name)`` records a start time, ``stop(name)`` appends the time
    elapsed since the most recent ``start(name)`` to that name's list.
    """

    def __init__(self):
        # name -> list of measured durations, in first-stop order
        self.timings = OrderedDict()
        # name -> timestamp of the most recent start()
        self.starts = {}

    def start(self, name):
        """Remember the current time as the start of section ``name``."""
        self.starts[name] = time.time()

    def stop(self, name):
        """Append the time since the last ``start(name)`` to the section's measurements."""
        measurements = self.timings.setdefault(name, [])
        measurements.append(time.time() - self.starts[name])

    def get(self, name=None, reduce=np.sum):
        """Return ``reduce`` over one section's durations, or a dict for all sections.

        Args:
            name: section to query; ``None`` returns ``{name: reduce(durations)}``
                for every recorded section.
            reduce: callable applied to a list of durations (default: sum).
        """
        if name is None:
            return {key: reduce(values) for key, values in self.timings.items()}
        return reduce(self.timings[name])

    def _summary(self):
        # Shared text form used by both __repr__ and __str__.
        return ', '.join('%s: %f[s]' % item for item in self.get().items())

    def __repr__(self):
        return self._summary()

    def __str__(self):
        return self._summary()
class ETA(object):
    """Estimates the remaining time of a loop over ``length`` items.

    Call ``update(idx)`` with the zero-based index of the item that has just
    been completed; the estimate assumes every item takes the average time
    observed so far.
    """

    def __init__(self, length):
        self.length = length
        self.start_time = time.time()
        self.current_idx = 0
        self.current_time = time.time()

    def update(self, idx):
        """Record that item ``idx`` (zero-based) has just finished."""
        self.current_idx = idx
        self.current_time = time.time()

    def get_elapsed_time(self):
        """Seconds between construction and the most recent ``update``."""
        return self.current_time - self.start_time

    def get_item_time(self):
        """Average seconds per completed item (``current_idx + 1`` items are done)."""
        return self.get_elapsed_time() / (self.current_idx + 1)

    def get_remaining_time(self):
        """Estimated seconds until all ``length`` items are done.

        ``current_idx + 1`` items are already finished, so
        ``length - (current_idx + 1)`` remain; the count is clamped at zero in
        case the index runs past ``length``.
        """
        remaining_items = max(self.length - (self.current_idx + 1), 0)
        return self.get_item_time() * remaining_items

    def format_time(self, seconds):
        """Format a duration in seconds as ``HH:MM:SS.ss``."""
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours = int(hours)
        minutes = int(minutes)
        return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'

    def get_elapsed_time_str(self):
        """Elapsed time formatted by ``format_time``."""
        return self.format_time(self.get_elapsed_time())

    def get_remaining_time_str(self):
        """Remaining time formatted by ``format_time``."""
        return self.format_time(self.get_remaining_time())
class Worker(object):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
    """Store the run configuration and set up the experiment directory.

    Args:
        out_root: base directory; outputs go to ``out_root / experiment_name``.
        experiment_name: name of the experiment sub-directory.
        epochs: upper bound (exclusive) of the epoch loop in ``train``.
        seed: seed applied by ``init_seed``.
        train_batch_size, test_batch_size: batch sizes of the data loaders.
        num_workers: worker count passed to the data loaders.
        save_frequency: state is stored every ``save_frequency`` epochs.
        train_device, test_device: device strings the network is moved to.
        max_train_iter: if positive, caps the batches per training epoch.
    """
    # where results are written
    self.out_root = Path(out_root)
    self.experiment_name = experiment_name

    # schedule and reproducibility
    self.epochs = epochs
    self.seed = seed

    # data loading
    self.train_batch_size = train_batch_size
    self.test_batch_size = test_batch_size
    self.num_workers = num_workers

    # checkpointing, devices, iteration cap
    self.save_frequency = save_frequency
    self.train_device = train_device
    self.test_device = test_device
    self.max_train_iter = max_train_iter

    # error curves drawn by write_err_img
    self.errs_list = []

    self.setup_experiment()
def setup_experiment(self):
    """Create the experiment directory, configure logging, load metrics and seed the RNGs."""
    self.exp_out_root = self.out_root / self.experiment_name
    self.exp_out_root.mkdir(parents=True, exist_ok=True)

    # Remove handlers from a previous configuration; basicConfig does nothing
    # when the root logger already has handlers.
    if logging.root:
        logging.root.handlers.clear()

    file_handler = logging.FileHandler( str(self.exp_out_root / 'train.log') )
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        handlers=[file_handler, console_handler],
        format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
    )

    separator = '='*80
    logging.info(separator)
    logging.info(f'Start of experiment: {self.experiment_name}')
    logging.info(socket.gethostname())
    self.log_datetime()
    logging.info(separator)

    # Continue an existing metrics file if there is one, otherwise start empty.
    self.metric_path = self.exp_out_root / 'metrics.json'
    if not self.metric_path.exists():
        self.metric_data = {}
    else:
        with open(str(self.metric_path), 'r') as fp:
            self.metric_data = json.load(fp)

    self.init_seed()
def metric_add_train(self, epoch, key, val):
    """Store ``val`` under ``metric_data[str(epoch)]['train'][str(key)]``.

    Missing intermediate dictionaries are created on the way.
    """
    epoch_entry = self.metric_data.setdefault(str(epoch), {})
    train_entry = epoch_entry.setdefault('train', {})
    train_entry[str(key)] = val
def metric_add_test(self, epoch, set_idx, key, val):
    """Store ``val`` under ``metric_data[str(epoch)]['test'][str(set_idx)][str(key)]``.

    Missing intermediate dictionaries are created on the way.
    """
    epoch_entry = self.metric_data.setdefault(str(epoch), {})
    test_entry = epoch_entry.setdefault('test', {})
    set_entry = test_entry.setdefault(str(set_idx), {})
    set_entry[str(key)] = val
def metric_save(self):
    """Write ``self.metric_data`` as indented JSON to ``self.metric_path``."""
    target = str(self.metric_path)
    with open(target, 'w') as out_file:
        json.dump(self.metric_data, out_file, indent=2)
def init_seed(self, seed=None):
    """Seed NumPy, ``random`` and PyTorch (CPU and CUDA) with ``self.seed``.

    Args:
        seed: if given, replaces ``self.seed`` before seeding.
    """
    if seed is not None:
        self.seed = seed
    logging.info(f'Set seed to {self.seed}')

    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(self.seed)
def log_datetime(self):
    """Log the current local date and time as ``YYYY-MM-DD HH:MM:SS``."""
    now = datetime.datetime.now()
    logging.info(now.strftime('%Y-%m-%d %H:%M:%S'))
def mem_report(self):
    """Print type and shape of every tensor currently tracked by the garbage collector."""
    tracked_tensors = (obj for obj in gc.get_objects() if torch.is_tensor(obj))
    for tensor in tracked_tensors:
        print(type(tensor), tensor.shape)
def get_net_path(self, epoch, root=None):
    """Return the path of the network parameter file for ``epoch``.

    Args:
        epoch: integer epoch, zero-padded to four digits in the file name.
        root: directory to use; defaults to ``self.exp_out_root``.
    """
    base_dir = self.exp_out_root if root is None else root
    return base_dir / f'net_{epoch:04d}.params'
def get_do_parser_cmds(self):
    """Return the command names accepted by the ``--cmd`` argument."""
    commands = ['retrain', 'resume', 'retest', 'test_init']
    return commands
def get_do_parser(self):
    """Build the argument parser with ``--cmd`` (default ``resume``) and ``--epoch`` (default -1)."""
    allowed_cmds = self.get_do_parser_cmds()
    parser = argparse.ArgumentParser()
    parser.add_argument('--cmd', type=str, default='resume', choices=allowed_cmds)
    parser.add_argument('--epoch', type=int, default=-1)
    return parser
def do_cmd(self, args, net, optimizer, scheduler=None):
    """Run the action selected by ``args.cmd``.

    ``retrain`` / ``resume`` call ``train`` without / with resuming,
    ``retest`` calls ``retest`` for ``args.epoch``, and ``test_init`` tests
    the network as passed in, using epoch -1.

    Raises:
        Exception: for any other command.
    """
    cmd = args.cmd
    if cmd in ('retrain', 'resume'):
        self.train(net, optimizer, resume=(cmd == 'resume'), scheduler=scheduler)
        return
    if cmd == 'retest':
        self.retest(net, epoch=args.epoch)
        return
    if cmd == 'test_init':
        test_sets = self.get_test_sets()
        self.test(-1, net, test_sets)
        return
    raise Exception('invalid cmd')
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
    """Parse the command line and run the selected command.

    Args:
        net, optimizer: used as given unless ``load_net_optimizer`` replaces them.
        load_net_optimizer: optional callable returning ``(net, optimizer)``;
            it is called unless the command is ``'schedule'``.
        scheduler: forwarded to ``do_cmd``.
    """
    args, _unknown = self.get_do_parser().parse_known_args()

    reload_requested = load_net_optimizer is not None
    if reload_requested and args.cmd not in ['schedule']:
        net, optimizer = load_net_optimizer()

    self.do_cmd(args, net, optimizer, scheduler=scheduler)
def retest(self, net, epoch=-1):
    """Re-run the tests for stored network parameters.

    Args:
        net: network whose parameters are overwritten from the stored files.
        epoch: epoch to test; a negative value tests every epoch in
            ``range(self.epochs)``. Epochs without a parameter file are skipped.
    """
    candidate_epochs = range(self.epochs) if epoch < 0 else [epoch]
    test_sets = self.get_test_sets()

    for current in candidate_epochs:
        net_path = self.get_net_path(current)
        if not net_path.exists():
            continue
        net.load_state_dict(torch.load(str(net_path)))
        self.test(current, net, test_sets)
def format_err_str(self, errs, div=1):
    """Format a list of error terms for logging.

    Args:
        errs: sequence of error terms.
        div: every printed number is divided by this value.

    Returns:
        The total with four decimals; when there is more than one term the
        total is followed by ``=`` and the ``+``-joined individual terms.
    """
    total_str = f'{sum(errs)/div:0.4f}'
    if len(errs) <= 1:
        return total_str
    terms = '+'.join(f'{term/div:0.4f}' for term in errs)
    return total_str + '=' + terms
def write_err_img(self):
    """Plot every curve in ``self.errs_list`` and save the figure as ``errs.png``.

    The image is written into ``self.exp_out_root``; curve ``idx`` is
    labelled ``error{idx}``.
    """
    err_img_path = self.exp_out_root / 'errs.png'
    fig = plt.figure(figsize=(16,16))

    handles = []
    for curve_idx, curve in enumerate(self.errs_list):
        x_values = range(len(curve))
        (handle,) = plt.plot(x_values, curve, label=f'error{curve_idx}')
        handles.append(handle)

    plt.tight_layout()
    plt.legend(handles=handles)
    plt.savefig(str(err_img_path))
    plt.close(fig)
def callback_train_new_epoch(self, epoch, net, optimizer):
    """Hook invoked by ``train`` at the start of every epoch; does nothing by default."""
    return None
def train(self, net, optimizer, resume=False, scheduler=None):
    """Run the train/test loop and store checkpoints.

    Args:
        net: network to train; it is moved to ``self.train_device``.
        optimizer: optimizer stepping the network parameters.
        resume: if True and ``state.dict`` exists in the experiment
            directory, continue from the epoch after the stored one.
        scheduler: optional object whose ``step()`` is called after every epoch.

    Every ``save_frequency`` epochs the full training state is written to
    ``state.dict``, a copy is written to ``state_set<name>_best.dict`` for
    every test set whose summed error improved, and the network parameters
    are written to the path given by ``get_net_path``.
    """
    logging.info('='*80)
    logging.info('Start training')
    self.log_datetime()
    logging.info('='*80)

    train_set = self.get_train_set()
    test_sets = self.get_test_sets()

    net = net.to(self.train_device)

    epoch = 0
    # Initial "best error" per test set; also the fallback for unknown sets.
    default_min_err = 1e9
    min_err = {ts.name: default_min_err for ts in test_sets}

    state_path = self.exp_out_root / 'state.dict'
    if resume and state_path.exists():
        logging.info('='*80)
        logging.info(f'Loading state from {state_path}')
        logging.info('='*80)
        state = torch.load(str(state_path))
        epoch = state['epoch'] + 1
        if 'min_err' in state:
            min_err = state['min_err']

        # Start from the current parameters so that entries missing in the
        # checkpoint keep their present values.
        curr_state = net.state_dict()
        curr_state.update(state['state_dict'])
        net.load_state_dict(curr_state)

        # Best effort: a checkpoint with an incompatible optimizer state must
        # not abort the run, but KeyboardInterrupt/SystemExit still propagate.
        try:
            optimizer.load_state_dict(state['optimizer'])
        except Exception:
            logging.info('Warning: cannot load optimizer from state_dict')

        if 'cpu_rng_state' in state:
            torch.set_rng_state(state['cpu_rng_state'])
        # The CUDA RNG can only be restored on a host that has CUDA.
        if 'gpu_rng_state' in state and torch.cuda.is_available():
            torch.cuda.set_rng_state(state['gpu_rng_state'])

    for epoch in range(epoch, self.epochs):
        self.callback_train_new_epoch(epoch, net, optimizer)

        # train epoch
        self.train_epoch(epoch, net, optimizer, train_set)

        # test epoch
        errs = self.test(epoch, net, test_sets)

        if (epoch + 1) % self.save_frequency == 0:
            net = net.to(self.train_device)

            # Update the best errors BEFORE anything is written, so that the
            # min_err stored in state.dict already reflects this epoch and a
            # resumed run cannot overwrite a better "best" checkpoint.
            improved_sets = []
            for test_set_name in errs:
                err = sum(errs[test_set_name])
                if err < min_err.get(test_set_name, default_min_err):
                    min_err[test_set_name] = err
                    improved_sets.append(test_set_name)

            # store state
            state_dict = {
                'epoch': epoch,
                'min_err': min_err,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'cpu_rng_state': torch.get_rng_state(),
            }
            # Querying the CUDA RNG fails on CPU-only hosts.
            if torch.cuda.is_available():
                state_dict['gpu_rng_state'] = torch.cuda.get_rng_state()

            logging.info(f'save state to {state_path}')
            torch.save(state_dict, str(state_path))

            # A separate variable keeps state_path pointing at state.dict.
            for test_set_name in improved_sets:
                best_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                logging.info(f'save state to {best_path}')
                torch.save(state_dict, str(best_path))

            # store network
            net_path = self.get_net_path(epoch)
            logging.info(f'save network to {net_path}')
            torch.save(net.state_dict(), str(net_path))

        if scheduler is not None:
            scheduler.step()

    logging.info('='*80)
    logging.info('Finished training')
    self.log_datetime()
    logging.info('='*80)
def get_train_set(self):
    """Return the training dataset; must be provided by a subclass.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def get_test_sets(self):
    """Return the test sets; must be provided by a subclass.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def copy_data(self, data, device, requires_grad, train):
    """Receive one batch from the data loader; must be provided by a subclass.

    The epoch loops ignore the return value.
    NOTE(review): presumably the batch is moved to ``device`` and kept on
    ``self`` for ``net_forward``; confirm against a concrete subclass.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def net_forward(self, net, train):
    """Run ``net`` on the current batch and return its output; must be provided by a subclass.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def loss_forward(self, output, train):
    """Compute the loss for ``output``; must be provided by a subclass.

    The epoch loops accept a single value, a list/tuple of values, or a dict
    with the keys ``'errs'`` and ``'masks'``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
    """Hook invoked after the backward pass of every training batch; does nothing by default.

    Subclasses may override it, e.g. to check ``param.grad`` of
    ``net.named_parameters()`` for non-finite values.
    """
    return None
def callback_train_start(self, epoch):
    """Hook invoked at the beginning of ``train_epoch``; does nothing by default."""
    return None
def callback_train_stop(self, epoch, loss):
    """Hook invoked at the end of ``train_epoch`` with the mean losses; does nothing by default."""
    return None
def train_epoch(self, epoch, net, optimizer, dset):
    """Train ``net`` for one epoch on ``dset``.

    Args:
        epoch: epoch index; also written to ``dset.current_epoch``.
        net: network, moved to ``self.train_device`` and put in train mode.
        optimizer: optimizer stepped once per batch.
        dset: dataset wrapped in a shuffling ``DataLoader``.

    Returns:
        List with the mean of every loss term over the batches that were
        actually processed (empty if no batch was processed).

    At most ``self.max_train_iter`` batches are processed when that value is
    positive.
    """
    self.callback_train_start(epoch)
    stopwatch = StopWatch()

    logging.info('='*80)
    logging.info('Train epoch %d' % epoch)

    dset.current_epoch = epoch
    train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)

    net = net.to(self.train_device)
    net.train()

    mean_loss = None

    # Number of batches this epoch will really process: the loader length,
    # capped by max_train_iter when a cap is set.
    n_batches = len(train_loader)
    if self.max_train_iter > 0:
        n_batches = min(n_batches, self.max_train_iter)
    n_processed = 0

    bar = ETA(length=n_batches)

    stopwatch.start('total')
    stopwatch.start('data')
    for batch_idx, data in enumerate(train_loader):
        # batch_idx is zero-based, so ">=" stops after exactly n_batches.
        if batch_idx >= n_batches:
            break
        self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
        stopwatch.stop('data')

        optimizer.zero_grad()

        # The synchronize calls make the per-stage timings reflect finished GPU work.
        stopwatch.start('forward')
        output = self.net_forward(net, train=True)
        if 'cuda' in self.train_device: torch.cuda.synchronize()
        stopwatch.stop('forward')

        stopwatch.start('loss')
        errs = self.loss_forward(output, train=True)
        if isinstance(errs, dict):
            masks = errs['masks']
            errs = errs['errs']
        else:
            masks = []
        if not isinstance(errs, (list, tuple)):
            errs = [errs]
        total_err = sum(errs)
        if 'cuda' in self.train_device: torch.cuda.synchronize()
        stopwatch.stop('loss')

        stopwatch.start('backward')
        total_err.backward()
        self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
        if 'cuda' in self.train_device: torch.cuda.synchronize()
        stopwatch.stop('backward')

        stopwatch.start('optimizer')
        optimizer.step()
        if 'cuda' in self.train_device: torch.cuda.synchronize()
        stopwatch.stop('optimizer')

        bar.update(batch_idx)
        # Log densely during the first two epochs, then every 16th batch.
        if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
            err_str = self.format_err_str(errs)
            logging.info(f'train e{epoch}: {batch_idx+1}/{n_batches}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')

        if mean_loss is None:
            mean_loss = [0 for e in errs]
        for erridx, term in enumerate(errs):
            mean_loss[erridx] += term.item()
        n_processed += 1

        stopwatch.start('data')
    stopwatch.stop('total')
    logging.info('timings: %s' % stopwatch)

    if mean_loss is None:
        # No batch was produced (e.g. drop_last on a dataset smaller than one batch).
        logging.info('Warning: no training batch was processed in this epoch')
        mean_loss = []
    else:
        # Average over the batches that were really processed.
        mean_loss = [l / n_processed for l in mean_loss]

    self.callback_train_stop(epoch, mean_loss)
    self.metric_add_train(epoch, 'loss', mean_loss)

    # save metrics
    self.metric_save()

    err_str = self.format_err_str(mean_loss)
    logging.info(f'avg train_loss={err_str}')
    return mean_loss
def callback_test_start(self, epoch, set_idx):
    """Hook invoked by ``test_epoch`` before the first test batch; does nothing by default."""
    return None
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
    """Hook invoked by ``test_epoch`` after every test batch; does nothing by default."""
    return None
def callback_test_stop(self, epoch, set_idx, loss):
    """Hook invoked by ``test_epoch`` with the mean losses of a test set; does nothing by default."""
    return None
def test(self, epoch, net, test_sets):
    """Evaluate ``net`` on every test set that is due in this epoch.

    A test set is due when ``(epoch + 1)`` is a multiple of its
    ``test_frequency``.

    Returns:
        Dict mapping the name of every evaluated test set to the mean losses
        returned by ``test_epoch``.
    """
    results = {}
    for set_idx, test_set in enumerate(test_sets):
        if (epoch + 1) % test_set.test_frequency != 0:
            continue
        logging.info('='*80)
        logging.info(f'testing set {test_set.name}')
        results[test_set.name] = self.test_epoch(epoch, set_idx, net, test_set.dset)
    return results
def test_epoch(self, epoch, set_idx, net, dset):
    """Evaluate ``net`` on ``dset`` without gradient tracking.

    Args:
        epoch: epoch index; also written to ``dset.current_epoch``.
        set_idx: index of the test set, forwarded to callbacks and metrics.
        net: network, moved to ``self.test_device`` and put in eval mode.
        dset: dataset wrapped in a non-shuffling ``DataLoader``.

    Returns:
        List with every loss term summed over all batches and divided by the
        number of batches.
    """
    logging.info('-'*80)
    logging.info('Test epoch %d' % epoch)
    dset.current_epoch = epoch
    loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)

    net = net.to(self.test_device)
    net.eval()
    with torch.no_grad():
        loss_sums = None

        self.callback_test_start(epoch, set_idx)

        progress = ETA(length=len(loader))
        timer = StopWatch()
        timer.start('total')
        timer.start('data')
        for batch_idx, batch in enumerate(loader):
            self.copy_data(batch, device=self.test_device, requires_grad=False, train=False)
            timer.stop('data')

            timer.start('forward')
            output = self.net_forward(net, train=False)
            if 'cuda' in self.test_device:
                torch.cuda.synchronize()
            timer.stop('forward')

            timer.start('loss')
            errs = self.loss_forward(output, train=False)
            # loss_forward may return {'errs': ..., 'masks': ...}, a sequence, or one value.
            if isinstance(errs, dict):
                masks, errs = errs['masks'], errs['errs']
            else:
                masks = []
            if not isinstance(errs, (list, tuple)):
                errs = [errs]

            progress.update(batch_idx)
            if batch_idx % 25 == 0:
                err_str = self.format_err_str(errs)
                logging.info(f'test e{epoch}: {batch_idx+1}/{len(loader)}: loss={err_str} | {progress.get_elapsed_time_str()} / {progress.get_remaining_time_str()}')

            if loss_sums is None:
                loss_sums = [0] * len(errs)
            for term_idx, term in enumerate(errs):
                loss_sums[term_idx] += term.item()
            timer.stop('loss')

            self.callback_test_add(epoch, set_idx, batch_idx, len(loader), output, masks)

            timer.start('data')
        timer.stop('total')
        logging.info('timings: %s' % timer)

        mean_loss = [total / len(loader) for total in loss_sums]

        self.callback_test_stop(epoch, set_idx, mean_loss)
        self.metric_add_test(epoch, set_idx, 'loss', mean_loss)

        # save metrics
        self.metric_save()

        err_str = self.format_err_str(mean_loss)
        logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
        return mean_loss