3 #ifndef FAIF_RANDOM_FOREST_HPP 4 #define FAIF_RANDOM_FOREST_HPP 6 #if defined(_MSC_VER) && (_MSC_VER >= 1400) 8 #pragma warning(disable:4100) 9 #pragma warning(disable:4512) 13 #include "Classifier.hpp" 15 #include "../utils/Random.hpp" 25 #include <boost/serialization/list.hpp> 26 #include <boost/serialization/base_object.hpp> 27 #include <boost/serialization/nvp.hpp> 28 #include <boost/serialization/vector.hpp> 29 #include <boost/serialization/shared_ptr.hpp> 37 allowedNbrMiscEx(1), minInfGain(0.000001) {}
39 int numFeaturesPerTree;
47 template<
class Archive>
48 void serialize( Archive &ar,
const unsigned int ){
49 ar & boost::serialization::make_nvp(
"numFeaturesPerTree", numFeaturesPerTree );
50 ar & boost::serialization::make_nvp(
"numTrees", numTrees );
51 ar & boost::serialization::make_nvp(
"allowedNbrMiscEx", allowedNbrMiscEx );
52 ar & boost::serialization::make_nvp(
"minInfGain", minInfGain );
61 template<
typename Val>
75 RandomForest(
const Domains& attr_domains,
const AttrDomain& category_domain);
/**
 * Number of trees as per Breiman's recommendation:
 *   max( ceil(sqrt(2 * dataSize)), ceil(2 * numFeatures / ceil(sqrt(numFeatures))) )
 * @param dataSize    size of the training data
 * @param numFeatures number of features
 * @return the larger of the two estimates, converted to int
 */
static int getBreimanNumTrees(size_t dataSize, size_t numFeatures)
{
    const double fromData = ceil(sqrt(2 * dataSize));
    // For numFeatures == 0 this term is 0/0 (NaN); std::max(a, NaN) yields a,
    // so the data-size estimate is returned in that case.
    const double fromFeatures = ceil(2 * numFeatures / ceil(sqrt(numFeatures)));
    return static_cast<int>(std::max(fromData, fromFeatures));
}
100 virtual void reset();
104 virtual void train(
const ExamplesTrain& e);
107 virtual AttrIdd getCategory(
const ExampleTest&)
const;
111 virtual Beliefs getCategories(
const ExampleTest&)
const;
121 Beliefs prepareResults(
const std::list<
typename RandomForest<Val>::AttrIdd>&)
const;
126 ExamplesTrain exampleBootstrap(
const ExamplesTrain&);
134 std::vector<int> uniformRandomGenerator(std::size_t size, std::size_t range,
bool isReplacementON);
149 RandomTree(
const Domains& attr_domains,
const AttrDomain& category_domain)
161 void train(
const ExamplesTrain& e)
167 AttrIdd getCategory(
const ExampleTest& e)
const 174 typename RandomForest<Val>::Beliefs getCategories(
const ExampleTest& e)
const 179 template <
class Tcontainer >
180 static Tcontainer CoverDomains(Tcontainer domains_, std::vector<int> attribs_allowed)
182 Tcontainer newDomains_;
183 for (std::vector<int>::const_iterator it=attribs_allowed.begin(); it != attribs_allowed.end(); ++it)
185 typename Tcontainer::iterator it_d=domains_.begin();
186 std::advance(it_d,*it);
187 newDomains_.push_back(*it_d);
194 friend class boost::serialization::access;
196 template<
class Archive>
197 void serialize( Archive &ar,
const unsigned int ){
198 ar & boost::serialization::make_nvp(
"BaseTree",boost::serialization::base_object<
DecisionTree<Val> >(*
this));
206 friend class boost::serialization::access;
208 template<
class Archive>
209 void serialize( Archive &ar,
const unsigned int ){
210 ar.template register_type<RandomTree>();
211 ar.template register_type<RandomForestParams>();
213 ar & boost::serialization::make_nvp(
"RFCBase", boost::serialization::base_object<
Classifier<Val> >(*
this) );
214 ar & boost::serialization::make_nvp(
"RTrees", trees_ );
215 ar & boost::serialization::make_nvp(
"params", params_);
220 typedef std::list<boost::shared_ptr<RandomTree>> RTrees;
230 template<
typename Val>
233 template<
typename Val>
235 :
Classifier<Val>(attr_domains, category_domain){}
238 template<
typename Val>
244 template<
typename Val>
248 for(
typename RTrees::iterator it=trees_.begin(); it != trees_.end(); it++ )
250 RandomTree & obj = *(*it);
251 obj.train(exampleBootstrap(e));
256 template<
typename Val>
260 std::list< RandomForest<Val>::AttrIdd> results_;
261 for(
typename RTrees::const_iterator it=trees_.begin(); it != trees_.end(); it++ )
263 RandomTree & obj = *(*it);
264 results_.push_back(obj.getCategory(e));
266 Beliefs bel_ = prepareResults(results_);
268 return AttrDomain::getUnknownId();
270 return bel_.at(0).getValue();
273 template<
typename Val>
276 std::list< RandomForest<Val>::AttrIdd> results_;
277 for(
typename RTrees::const_iterator it=trees_.begin(); it != trees_.end(); it++ )
279 RandomTree & obj = *(*it);
280 results_.push_back(obj.getCategory(e));
283 Beliefs bel_ = prepareResults(results_);
290 template<
typename Val>
296 p.allowedNbrMiscEx = params_.allowedNbrMiscEx;
297 p.minInfGain = params_.minInfGain;
298 for(
int i=0; i < params_.numTrees; ++i) {
299 auto tree = boost::shared_ptr<RandomTree>(
new RandomTree(
300 RandomTree::CoverDomains(
302 uniformRandomGenerator(params_.numFeaturesPerTree, size_-1,
false)),
304 tree->setTrainParams(p);
305 trees_.push_back(tree);
309 template<
typename Val>
312 std::vector<std::pair<typename RandomForest<Val>::AttrIdd,
double>> resultsList_;
313 std::map<typename RandomForest<Val>::AttrIdd,
double> count_;
316 for(
typename std::list<
typename RandomForest<Val>::AttrIdd>::const_iterator it=list_.begin(); it != list_.end(); it++ )
319 typename Val::Value class_ = ((
typename RandomForest<Val>::AttrIdd) *it)->get();
321 auto itClass_ = find_if(count_.begin(), count_.end(), [&class_](
const std::pair<typename RandomForest<Val>::AttrIdd,
double>& obj) {
return obj.first->get()==class_;});
323 if (itClass_ != std::end(count_))
330 for(
typename std::map<
typename RandomForest<Val>::AttrIdd,
double>::const_iterator it = count_.begin(); it != count_.end(); ++it )
331 resultsList_.push_back(std::pair<
typename RandomForest<Val>::AttrIdd,
double>(it->first,it->second/list_.size()));
333 std::sort(resultsList_.begin(), resultsList_.end(),
334 boost::bind(&std::pair<
typename RandomForest<Val>::AttrIdd,
double>::second, _1) >
335 boost::bind(&std::pair<
typename RandomForest<Val>::AttrIdd,
double>::second, _2));
337 for(
typename std::vector<std::pair<
typename RandomForest<Val>::AttrIdd,
double>>::const_iterator it = resultsList_.begin(); it != resultsList_.end(); ++it )
338 toRet_.push_back(
typename Beliefs::value_type(it->first,it->second));
342 template<
typename Val>
345 size_t S_ = example_.size();
346 ExamplesTrain subset_;
347 std::vector<int> sample_ = uniformRandomGenerator(S_, S_-1,
true);
348 for( std::vector<int>::iterator it = sample_.begin(); it != sample_.end(); ++it )
349 subset_.push_back(*std::next(example_.begin(),*it));
354 template<
typename Val>
357 if (size > range + 1 && !isReplacementON) {
358 throw std::invalid_argument(
"Size can not be bigger than range if replacement isn't on!");
360 std::vector<int> toRet_(size, -1);
363 RandomInt uniform_generator(0, static_cast<int>(range) );
364 for( std::vector<int>::iterator it = toRet_.begin(); it != toRet_.end(); ++it )
366 int sample_ = uniform_generator();
367 while( !isReplacementON && std::find(toRet_.begin(), toRet_.end(), sample_) != toRet_.end() )
369 sample_ = uniform_generator();
381 #endif //FAIF_RANDOM_FOREST_HPP Val::DomainType::ValueId AttrIdd
attribute id representation in learning
Definition: Classifier.hpp:55
const Domains & getAttrDomains() const
accessor
Definition: Classifier.hpp:109
virtual void train(const ExamplesTrain &e)
learn classifier (on the collection of training examples).
Definition: RandomForest.hpp:245
virtual AttrIdd getCategory(const ExampleTest &) const
Definition: RandomForest.hpp:257
Val::Value AttrValue
attribute value representation in learning
Definition: Classifier.hpp:49
Decision Tree Classifier.
Definition: DecisionTree.hpp:63
The Decision Tree Classifier, inspired by the ID3 algorithm (Iterative Dichotomiser 3)
point and some feature
Definition: Point.hpp:58
static int getBreimanNumTrees(size_t dataSize, size_t numFeatures)
get number of trees as per Breiman's recommendation
Definition: RandomForest.hpp:83
Belief< Val >::Beliefs Beliefs
collection of pairs (AttrIdd, probability)
Definition: Classifier.hpp:64
inner class - collection of training examples
Definition: Classifier.hpp:82
virtual AttrIdd getCategory(const ExampleTest &) const
Definition: DecisionTree.hpp:401
const RandomForestParams & getTrainParams() const
Definition: RandomForest.hpp:94
virtual void train(const ExamplesTrain &e)
learn classifier (on the collection of training examples).
Definition: DecisionTree.hpp:369
Point in n-space, each component of the same type.
Definition: Point.hpp:22
virtual void reset()
Definition: RandomForest.hpp:239
virtual Beliefs getCategories(const ExampleTest &) const
classify the example and return every class together with the belief that the example belongs to that class
Definition: DecisionTree.hpp:413
Val::DomainType::ValueIdSerialize AttrIddSerialize
used for serialization, where the const qualifier interferes
Definition: Classifier.hpp:58
the uniform distribution for int, in range <min,max>, uses RandomSingleton
Definition: Random.hpp:107
virtual Beliefs getCategories(const ExampleTest &) const
classify the example and return every class together with the belief that the example belongs to that class
Definition: RandomForest.hpp:274
parameters for training a decision tree
Definition: DecisionTree.hpp:45
random forest's parameters
Definition: RandomForest.hpp:35
Val::DomainType AttrDomain
the attribute domain for learning
Definition: Classifier.hpp:52
virtual void reset()
Definition: DecisionTree.hpp:359
void setTrainParams(const RandomForestParams &p)
Definition: RandomForest.hpp:97
static int getBreimanNumFeatures(size_t numFeatures)
get number of features per tree as recommended by Breiman
Definition: RandomForest.hpp:91
the classifier interface
Definition: Classifier.hpp:43
Random Forest Classifier.
Definition: RandomForest.hpp:62