You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

72 lines
1.5 KiB

#pragma once
#include <vector>
class Sample {
public:
Sample(int channels, int height, int width)
: channels_(channels), height_(height), width_(width) {}
virtual ~Sample() {}
virtual float at(int c, int h, int w) const = 0;
virtual float operator()(int c, int h, int w) const {
return at(c,h,w);
}
virtual int channels() const { return channels_; }
virtual int height() const { return height_; }
virtual int width() const { return width_; }
protected:
int channels_;
int height_;
int width_;
};
typedef std::shared_ptr<Sample> SamplePtr;
class Target {
public:
Target() {}
virtual ~Target() {}
};
typedef std::shared_ptr<Target> TargetPtr;
typedef std::vector<TargetPtr> VecTargetPtr;
typedef std::shared_ptr<VecTargetPtr> VecPtrTargetPtr;
class ClassificationTarget : public Target {
public:
ClassificationTarget(int cl) : cl_(cl) {}
virtual ~ClassificationTarget() {}
int cl() const { return cl_; }
private:
int cl_;
};
typedef std::shared_ptr<ClassificationTarget> ClassificationTargetPtr;
struct TrainDatum {
SamplePtr sample;
TargetPtr target;
TargetPtr optimize_target;
TrainDatum() : sample(nullptr), target(nullptr), optimize_target(nullptr) {}
TrainDatum(SamplePtr sample, TargetPtr target)
: sample(sample), target(target), optimize_target(target) {}
TrainDatum(SamplePtr sample, TargetPtr target, TargetPtr optimize_target)
: sample(sample), target(target), optimize_target(optimize_target) {}
};