faif
DecisionTree.hpp
Go to the documentation of this file.
1 /**
2  * \file DecisionTree.hpp
3  * \brief The Decision Tree Classifier, inspired ID3 algorithm (Iterate Dichotomizer)
4  */
5 
6 #ifndef FAIF_DECISION_TREE_HPP
7 #define FAIF_DECISION_TREE_HPP
8 
9 #if defined(_MSC_VER) && (_MSC_VER >= 1400)
10 //msvc14.0 warnings for Boost.Serialization
11 #pragma warning(disable:4100)
12 #pragma warning(disable:4512)
13 #endif
14 
15 
16 
17 #include "Classifier.hpp"
18 
#include <algorithm>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <ostream>
#include <set>
#include <string>
25 #include <boost/ref.hpp>
26 #include <boost/bind.hpp>
27 
28 #include <boost/lambda/bind.hpp>
29 #include <boost/lambda/construct.hpp>
30 #include <boost/lambda/core.hpp>
31 #include <boost/lambda/lambda.hpp>
32 
33 #include <boost/serialization/split_member.hpp>
34 #include <boost/serialization/base_object.hpp>
35 #include <boost/serialization/nvp.hpp>
36 #include <boost/serialization/singleton.hpp>
37 #include <boost/serialization/extended_type_info.hpp>
38 #include <boost/serialization/shared_ptr.hpp>
39 #include <boost/serialization/vector.hpp>
40 
41 namespace faif {
42  namespace ml {
43 
44  /** \brief param for training decision tree */
46  DecisionTreeTrainParams() : allowedNbrMiscEx(1), minInfGain(0.000001) {}
47  int allowedNbrMiscEx; //allowed number of badly classified examples for each category
48  double minInfGain; // minimal information gain to accept test
49 
50  template<class Archive>
51  void serialize( Archive &ar, const unsigned int /* file_version*/ ){
52  ar & boost::serialization::make_nvp("allowedNbrMiscEx", allowedNbrMiscEx );
53  ar & boost::serialization::make_nvp("minInfGain", minInfGain );
54  }
55  };
56 
57  /** \brief Decision Tree Classifier.
58 
59  Contains the attributes, attribute values and categories,
60  train examples, test examples and classifier methods.
61  */
62  template<typename Val>
63  class DecisionTree : public Classifier<Val> {
64  public:
65  typedef typename Classifier<Val>::AttrValue AttrValue;
66  typedef typename Classifier<Val>::AttrDomain AttrDomain;
67  typedef typename Classifier<Val>::AttrIdd AttrIdd;
69  typedef typename Classifier<Val>::Domains Domains;
70  typedef typename Classifier<Val>::Beliefs Beliefs;
74  public:
75  DecisionTree();
76  DecisionTree(const Domains& attr_domains, const AttrDomain& category_domain);
77 
78  virtual ~DecisionTree() { }
79 
80  /** clear the tree */
81  virtual void reset();
82 
83  /** \brief learn classifier (on the collection of training examples).
84  */
85  virtual void train(const ExamplesTrain& e);
86 
87  /** classify */
88  virtual AttrIdd getCategory(const ExampleTest&) const;
89 
90  /** \brief classify and return all classes with belief that the example is from given class */
91  virtual Beliefs getCategories(const ExampleTest&) const;
92 
93  /** the ostream method */
94  virtual void write(std::ostream& os) const;
95 
96  /** accessor - get training parameters */
97  const DecisionTreeTrainParams& getTrainParams() const { return params_; }
98 
99  /** mutator - set training parameters */
100  void setTrainParams(const DecisionTreeTrainParams& p) { params_ = p; }
101 
102  /**
103  \brief prune tree - plase not use the example set used for training
104 
105  Return the (smart)pointer to node which
106  replace the old one. If no prunning is performed the input pointer and the output are the same.
107 
108  bottom-up method, the uneven distribution of categories is not considered
109  */
110  void prune(const ExamplesTrain& e);
111  private:
112  //forward declaration
113  class DTNode;
114  typedef boost::shared_ptr<DTNode> PDTNode;
115 
116  PDTNode root_; //main node for decision tree
117  DecisionTreeTrainParams params_; //params for training decision tree
118 
119  /** copy c-tor not allowed */
120  DecisionTree(const DecisionTree&);
121  /** assignment not allowed */
122  DecisionTree& operator=(const DecisionTree&);
123 
124  private:
125  /**
126  internal class - binary test (stored in each node), currently (for nomial values) equality test
127  */
128  class DTTest
129  {
130  public:
131  DTTest( ): idd_(AttrDomain::getUnknownId()) {} //for de-serialization
132  explicit DTTest( AttrIdd idd ): idd_(idd) {}
133  ~DTTest() {}
134 
135  AttrIdd get() const { return idd_; }
136 
137  /** \brief perform the test for given example */
138  bool test( const ExampleTest& e ) const;
139 
140  /** \brief calculate entropy gain for given test. The return value is normalized. */
141  double entropyGain(typename ExamplesTrain::const_iterator eBeg, typename ExamplesTrain::const_iterator eEnd, double minInfGain) const;
142 
143  /** \brief ostream method */
144  void write(std::ostream& os) const;
145  private:
146 
147  /** \brief serialization using boost::serialization */
148  friend class boost::serialization::access;
149 
150  template<class Archive>
151  void save(Archive & ar, const unsigned int /* file_version */) const {
152  ar & boost::serialization::make_nvp("Idd", idd_ );
153  }
154 
155  template<class Archive>
156  void load(Archive & ar, const unsigned int /* file_version */) {
157  AttrIddSerialize i;
158  ar >> boost::serialization::make_nvp("Idd", i);
159  idd_ = const_cast<AttrIdd>(i);
160  }
161 
162  template<class Archive>
163  void serialize( Archive &ar, const unsigned int file_version ){
164  boost::serialization::split_member(ar, *this, file_version);
165  }
166 
167  private:
168  AttrIdd idd_;
169  };
170 
171  /** collection of tests */
172  typedef std::list<DTTest> DTTests;
173 
174  /**
175  \brief internal class - Node in decision tree classifier (leaf)
176  */
177  class DTNode
178  {
179  public:
180  DTNode() {} //for de-serialization
181  DTNode(Beliefs catBel) : catBel_(catBel) {}
182  virtual ~DTNode() {}
183 
184  //major category for given node
185  AttrIdd getMajorCategory() const {
186  if( catBel_.empty() )
187  return AttrDomain::getUnknownId();
188  else
189  return catBel_.front().getValue();
190  }
191 
192  //categories with belief for given node
193  const Beliefs& getBeliefs() const { return catBel_; }
194 
195  //test if the node is leaf node
196  bool isLeaf() const { return getTest() == 0L; }
197 
198  //factory method
199  static PDTNode createLeaf(const Beliefs& catBel);
200  //factory method
201  static PDTNode createInternal(const Beliefs& catBel, const DTTest& test, PDTNode nTrue, PDTNode nFalse);
202 
203  //empty node - node when test return true
204  virtual PDTNode getNodeTrue() const { return PDTNode(); }
205  //empty node - node when test return false
206  virtual PDTNode getNodeFalse() const { return PDTNode(); }
207  //test stored in node (null)
208  virtual const DTTest* getTest() const { return 0L; }
209  /** mutator - set node when test return true. Empty operation for LeafNode. */
210  virtual void setNodeTrue(PDTNode) { }
211  /*** mutator - set node when test return false. Empty operation for LeafNode. */
212  virtual void setNodeFalse(PDTNode) { }
213 
214  /** \brief classify, return category for given testing example */
215  virtual AttrIdd getCategory(const ExampleTest& e) const {
216  return getMajorCategory();
217  }
218 
219  /** \brief classify, return categories and belief for given testing example */
220  virtual const Beliefs& getCategories(const ExampleTest& e) const {
221  return getBeliefs();
222  }
223 
224  //for debugging
225  virtual void write(std::ostream& os) const {
226  os << "Leaf (Major:" << getMajorCategory()->get() << ", Beliefs:" << getBeliefs() << ");";
227  }
228  private:
229  /** \brief serialization using boost::serialization */
230  friend class boost::serialization::access;
231 
232  template<class Archive>
233  void serialize( Archive &ar, const unsigned int /* file_version*/ ){
234  ar & boost::serialization::make_nvp("CatBel", catBel_ );
235  }
236 
237  private:
238  Beliefs catBel_;
239  };
240 
241 
242 
243  /**
244  \brief interanal class - internal node (with test and left and right children)
245  */
246  class DTNodeInternal : public DTNode {
247  public:
248  DTNodeInternal() {} //for de-serialization
249  DTNodeInternal(const Beliefs& catBel, const DTTest& test, PDTNode nTrue, PDTNode nFalse)
250  : DTNode(catBel), test_(test), nodeTrue_(nTrue), nodeFalse_(nFalse)
251  {}
252 
253  //node when test return true
254  virtual PDTNode getNodeTrue() const { return nodeTrue_; }
255  //node when test return false
256  virtual PDTNode getNodeFalse() const { return nodeFalse_; }
257  //test stored in node
258  virtual const DTTest* getTest() const { return &test_; }
259  /** mutator - set node when test return true. Empty operation for LeafNode. */
260  virtual void setNodeTrue(PDTNode n) { nodeTrue_ = n; }
261  /*** mutator - set node when test return false. Empty operation for LeafNode. */
262  virtual void setNodeFalse(PDTNode n) { nodeFalse_ = n; }
263 
264  /** \brief classify, return category for given testing example */
265  virtual AttrIdd getCategory(const ExampleTest& e) const {
266  if( test_.test(e) )
267  return nodeTrue_->getCategory(e);
268  else
269  return nodeFalse_->getCategory(e);
270  }
271 
272  /** \brief classify, return categories and belief for given testing example */
273  virtual const Beliefs& getCategories(const ExampleTest& e) const {
274  if( test_.test(e) )
275  return nodeTrue_->getCategories(e);
276  else
277  return nodeFalse_->getCategories(e);
278  }
279 
280  //for debugging
281  virtual void write(std::ostream& os) const {
282  os << "Internal (Major:" << this->getMajorCategory()->get() << ", Beliefs:" << this->getBeliefs() << ", test:";
283  test_.write(os);
284  os << ");";
285  }
286  private:
287  /** \brief serialization using boost::serialization */
288  friend class boost::serialization::access;
289 
290  template<class Archive>
291  void serialize( Archive &ar, const unsigned int /* file_version*/ ){
292  ar & boost::serialization::make_nvp("NodeBase", boost::serialization::base_object<DTNode>(*this) );
293  ar & boost::serialization::make_nvp("Test", test_ );
294  ar & boost::serialization::make_nvp("NodeTrue", nodeTrue_ );
295  ar & boost::serialization::make_nvp("NodeFalse", nodeFalse_ );
296  }
297  private:
298  DTTest test_; //binary test
299  PDTNode nodeTrue_; //node when test return true
300  PDTNode nodeFalse_; //node when test return false
301  };
302 
303  /**
304  \brief recurent function to build decision tree
305  \param eBeg training examples collection (iterator). The examples are re-order (partitioned) by tests (split)
306  \param eEnd training examples collection (iterator).
307  \param inTest the initial collection of tests
308  \param ALLOWED_NBR_MISC_EX allowed number of badly classified examples for each category
309  */
310  PDTNode buildTreeRecur(typename ExamplesTrain::iterator eBeg, typename ExamplesTrain::iterator eEnd,
311  const DTTests& inTests, const int ALLOWED_NBR_MISC_EX);
312 
313  /**
314  \brief recurent function to prune decision tree
315  \param eBeg pruning examples collection (iterator). The examples are re-order (partitioned) by tests (split)
316  \param eEnd pruning examples collection (iterator).
317  \param node the considered node. It is returned or changed into other node (internal node into leaf node).
318  */
319  static PDTNode pruneTreeRecur(typename ExamplesTrain::iterator eBeg, typename ExamplesTrain::iterator eEnd, PDTNode node);
320 
321  /**
322  \brief helping function for ostream operator
323  */
324  static void writeDecTreeNodes(std::ostream& os, typename DecisionTree<Val>::PDTNode node, int level = 0);
325 
326  /** \brief serialization using boost::serialization */
327  friend class boost::serialization::access;
328 
329  template<class Archive>
330  void serialize( Archive &ar, const unsigned int /* file_version */ ) {
331  ar.template register_type<DTNode>();
332  ar.template register_type<DTNodeInternal>();
333  ar.template register_type<DecisionTreeTrainParams>();
334 
335  ar & boost::serialization::make_nvp("DTCBase", boost::serialization::base_object<Classifier<Val> >(*this) );
336  ar & boost::serialization::make_nvp("Node", root_ );
337  ar & boost::serialization::make_nvp("params", params_);
338  }
339 
340  }; //class DecisionTree
341 
342  //////////////////////////////////////////////////////////////////////////////////////////////////
343  // class DecisionTree implementation
344  //////////////////////////////////////////////////////////////////////////////////////////////////
345 
346  template<typename Val>
348  {
349  }
350 
351  template<typename Val>
352  DecisionTree<Val>::DecisionTree(const Domains& attr_domains, const AttrDomain& category_domain)
353  : Classifier<Val>(attr_domains, category_domain)
354  {
355  }
356 
357  /** clear the tree */
358  template<typename Val>
360  root_ = PDTNode();
361  }
362 
363  /**
364  \brief learn classifier (on the collection of training examples), the decision tree using given train examples
365  \param e training examples collection
366  \param ALLOWED_NBR_MISC_EX allowed number of badly classified examples for each category
367  */
368  template<typename Val>
369  void DecisionTree<Val>::train(const ExamplesTrain& e) {
370 
371  ExamplesTrain ex(e); //make a copy, because the container will be changed (split re-order examples in sets)
372  std::set<AttrIdd> attrib; // structure for available tests for train examples collection
373 
374  // Look through all examples and remove redundant attributes, if exists
375  for (typename ExamplesTrain::iterator it = ex.begin(); it != ex.end(); ++it) {
376  for(typename ExampleTrain::iterator at = it->begin(); at != it->end();) {
377  AttrIdd& value = *at;
378  if( std::find(Classifier<Val>::getAttrDomains().begin(),
380  value->getDomain()->getId()) != Classifier<Val>::getAttrDomains().end())
381  ++at;
382  else
383  at = it->erase(at);
384 
385  }
386  // generate available tests for train examples collection
387  const ExampleTrain& ee = *it;
388  std::copy(ee.begin(), ee.end(), std::inserter(attrib, attrib.begin() ) );
389  }
390 
391  DTTests tests;
392  std::transform(attrib.begin(), attrib.end(), std::back_inserter(tests),
393  boost::lambda::bind(boost::lambda::constructor<DTTest>(), boost::lambda::_1 ) );
394 
395  root_ = buildTreeRecur(ex.begin(), ex.end(), tests, params_.allowedNbrMiscEx);
396 
397  }
398 
399  /** classify - return the major category for best node from decision tree */
400  template<typename Val>
401  typename DecisionTree<Val>::AttrIdd DecisionTree<Val>::getCategory(const ExampleTest& e) const {
402  if(!root_) { //empty tree
403  return AttrIdd(AttrDomain::getUnknownId());
404  } else if(e.empty() ) { //no common attrib (domains) between example and classifier
405  return root_->getMajorCategory();
406  } else { //classify using decision tree
407  return root_->getCategory(e);
408  }
409  }
410 
411  /** \brief classify and return all classes with belief that the example is from given class */
412  template<typename Val>
413  typename DecisionTree<Val>::Beliefs DecisionTree<Val>::getCategories(const ExampleTest& e) const {
414  if(!root_) { //empty tree
415  return Beliefs();
416  } else if(e.empty() ) { //no common attrib (domains) between example and classifier
417  return root_->getBeliefs();
418  } else { //classify using decision tree
419  return root_->getCategories(e);
420  }
421  }
422 
423  /** ostream method */
424  template<typename Val>
425  void DecisionTree<Val>::write(std::ostream& os) const {
426  if(!root_)
427  os << "Empty DTC" << std::endl;
428  else
429  writeDecTreeNodes(os, root_, 0 );
430  }
431 
432  /**
433  \brief prune tree - plase not use the example set used for training
434 
435  Return the (smart)pointer to node which
436  replace the old one. If no prunning is performed the input pointer and the output are the same.
437 
438  bottom-up method, the uneven distribution of categories is not considered
439  */
440  template<typename Val>
441  void DecisionTree<Val>::prune(const ExamplesTrain& e) {
442  if(root_) {
443  ExamplesTrain ex(e); //make a copy, because the container will be changed (split re-order examples in sets)
444  pruneTreeRecur(ex.begin(), ex.end(), root_ );
445  }
446  }
447 
448  /** \brief recurent function to build decision tree
449  \param eBeg training examples collection (iterator). The examples are re-order (partitioned) by tests (split)
450  \param eEnd training examples collection (iterator).
451  \param inTest the initial collection of tests
452  \param ALLOWED_NBR_MISC_EX allowed number of badly classified examples for each category
453  */
454  template<typename Val>
455  typename DecisionTree<Val>::PDTNode
456  DecisionTree<Val>::buildTreeRecur(typename ExamplesTrain::iterator eBeg, typename ExamplesTrain::iterator eEnd,
457  const DTTests& inTests, const int ALLOWED_NBR_MISC_EX) {
458 
459  //calculate histogram of categories for train example collection
460 
461  TrainExampleCategoryCounters<Val> counters(eBeg, eEnd);
462  Beliefs histogram = counters.getHistogram();
463 
464  int numCatWithManyExamples = 0;
465  const std::map<AttrIdd,int>& c = counters.get();
466  for(typename std::map<AttrIdd,int>::const_iterator i = c.begin(); i != c.end(); ++i) {
467  if( i->second > ALLOWED_NBR_MISC_EX) {
468  ++numCatWithManyExamples;
469  }
470  }
471  if(static_cast<int>(std::distance(eBeg, eEnd)) <= ALLOWED_NBR_MISC_EX || //no enough training examples - split is not sensible
472  numCatWithManyExamples < 2) { //only few examples in not-major category
473  return DTNode::createLeaf(histogram);
474  }
475  //find the best test
476  DTTests tests(inTests);
477  typename DTTests::iterator best = tests.end();
478  double bestEntropy = std::numeric_limits<double>::min();
479 
480  for(typename DTTests::iterator i = tests.begin(); i != tests.end(); ++i ) {
481  double entr = i->entropyGain(eBeg, eEnd, params_.minInfGain);
482  if(entr > bestEntropy) {
483  bestEntropy = entr;
484  best = i;
485  }
486  }
487  if( best == tests.end() || bestEntropy < params_.minInfGain ) { //no tests in tests set or no goot tests
488  return DTNode::createLeaf(histogram);
489  }
490  // std::cout << "best test:" << *best << " entropy gain:" << bestEntropy << std::endl;
491 
492  //split the examples using best test
493  typename ExamplesTrain::iterator middle = std::stable_partition(eBeg, eEnd, boost::bind(&DTTest::test, boost::ref(*best), _1) );
494  DTTest bestTest(*best);
495  tests.erase(best);
496  PDTNode nTrue = buildTreeRecur(eBeg, middle, tests, ALLOWED_NBR_MISC_EX);
497  PDTNode nFalse = buildTreeRecur(middle, eEnd, tests, ALLOWED_NBR_MISC_EX);
498  return DTNode::createInternal(histogram, bestTest, nTrue, nFalse);
499  }
500 
501  /** \brief recurent function to prune decision tree
502  \param eBeg pruning examples collection (iterator). The examples are re-order (partitioned) by tests (split)
503  \param eEnd pruning examples collection (iterator).
504  \param node the considered node. It is returned or changed into other node (internal node into leaf node).
505  */
506  template<typename Val>
507  typename DecisionTree<Val>::PDTNode
508  DecisionTree<Val>::pruneTreeRecur(typename ExamplesTrain::iterator eBeg, typename ExamplesTrain::iterator eEnd, PDTNode node) {
509 
510  if(node->isLeaf() || std::distance(eBeg, eEnd) < 1)
511  return node;
512 
513  //here assertion that for !node->isLeaf() node->getTest() return valid address
514  const DTTest& t = *(node->getTest());
515  //split the examples using the test from node
516  typename ExamplesTrain::iterator
517  middle = std::stable_partition(eBeg, eEnd, boost::bind(&DTTest::test, boost::ref(t), _1) );
518 
519  node->setNodeTrue( pruneTreeRecur( eBeg, middle, node->getNodeTrue() ) );
520  node->setNodeFalse( pruneTreeRecur( middle, eEnd, node->getNodeFalse() ) );
521 
522  //count the number of correctly classified pruning exampes
523  int leafCount = 0, treeCount = 0;
524  for(typename ExamplesTrain::const_iterator i = eBeg; i != eEnd; ++i) {
525  const ExampleTrain& e = *i;
526  if(node->getMajorCategory() == e.getFeature() )
527  ++leafCount;
528  if(node->getCategory(e) == e.getFeature())
529  ++treeCount;
530  }
531  // std::cout << "pruning:" << *node << " examples:" << std::distance(eBeg, eEnd)
532  // << " leafCount: " << leafCount << " treeCount: " << treeCount << std::endl;
533  if( leafCount >= treeCount ) { //major category for node is good enough
534  return DTNode::createLeaf(node->getBeliefs()); //switch to leaf
535  }
536  return node;
537  }
538 
539  /**
540  \brief helping function for ostream operator
541  */
542  template<typename Val>
543  void DecisionTree<Val>::writeDecTreeNodes(std::ostream& os, typename DecisionTree<Val>::PDTNode node, int level) {
544  if( node ) {
545  os << std::string(level,' ');
546  node->write(os);
547  os << std::endl;
548  writeDecTreeNodes(os, node->getNodeTrue(), level+1);
549  writeDecTreeNodes(os, node->getNodeFalse(), level+1);
550  }
551  }
552 
553  //////////////////////////////////////////////////////////////////////////////////////////////////
554  // class DecisionTree::DTTest implementation
555  //////////////////////////////////////////////////////////////////////////////////////////////////
556 
557  /** \brief calculate entropy gain for given test. The return value is normalized. */
558  template<typename Val>
559  double DecisionTree<Val>::DTTest::entropyGain(typename ExamplesTrain::const_iterator eBeg, typename ExamplesTrain::const_iterator eEnd, double minInfGain) const {
560 
561  if( eBeg == eEnd ) //not start calculation for empty set
562  return 0.0;
563 
566 
567  for(typename ExamplesTrain::const_iterator i = eBeg; i != eEnd; ++ i) {
568  const ExampleTrain& ex = *i;
569  if( this->test(ex) ) {
570  acc.inc(ex);
571  }
572  else {
573  nacc.inc(ex);
574  }
575  }
576  double sum = static_cast<double>( std::distance(eBeg, eEnd) );
577  double nrAcc = static_cast<double>(acc.getSum() );
578  double nrNAcc = static_cast<double>(nacc.getSum() );
579  double testIc = calcEntropy( nrAcc / sum ) + calcEntropy( nrNAcc/sum );
580  if( testIc < minInfGain ) {
581  return 0.0;
582  }
583  else {
584  double entropy = acc.entropy() * nrAcc / sum + nacc.entropy() * nrNAcc / sum;
585  TrainExampleCategoryCounters<Val> befSplit(eBeg, eEnd);
586  double gain = befSplit.entropy() - entropy;
587  return gain / testIc;
588  }
589  }
590 
591  /** \brief perform the test for given example */
592  template<typename Val>
593  bool DecisionTree<Val>::DTTest::test( const ExampleTest& e ) const {
594  return std::find(e.begin(), e.end(), idd_) != e.end();
595  }
596 
597  /** \brief ostream method */
598  template<typename Val>
599  void DecisionTree<Val>::DTTest::write(std::ostream& os) const {
600  os << "Domain: " << idd_->getDomain()->getId() << ", Value:" << idd_->get();
601  }
602 
603  //////////////////////////////////////////////////////////////////////////////////////////////////
604  // class DecisionTree::DTNode implementation
605  //////////////////////////////////////////////////////////////////////////////////////////////////
606 
607  //factory method
608  template<typename Val>
609  typename DecisionTree<Val>::PDTNode
610  DecisionTree<Val>::DTNode::createLeaf(const Beliefs& catBel) {
611  return PDTNode(new DTNode(catBel) );
612  }
613 
614  //factory method
615  template<typename Val>
616  typename DecisionTree<Val>::PDTNode
617  DecisionTree<Val>::DTNode::createInternal(const Beliefs& catBel, const DTTest& test, PDTNode nTrue, PDTNode nFalse) {
618  return PDTNode(new DTNodeInternal(catBel, test, nTrue, nFalse) );
619  }
620 
621  }//namespace ml
622 } //namespace faif
623 
624 
625 
626 #endif //FAIF_DECISION_TREE_HPP
Val::DomainType::ValueId AttrIdd
attribute id representation in learning
Definition: Classifier.hpp:55
Definition: Chain.h:17
virtual void write(std::ostream &os) const
Definition: DecisionTree.hpp:425
void setTrainParams(const DecisionTreeTrainParams &p)
Definition: DecisionTree.hpp:100
Val::Value AttrValue
attribute value representation in learning
Definition: Classifier.hpp:49
const Counters & get() const
access to counters
Definition: Classifier.hpp:277
Decision Tree Classifier.
Definition: DecisionTree.hpp:63
Definition: Classifier.hpp:233
point and some feature
Definition: Point.hpp:58
Belief< Val >::Beliefs Beliefs
collection of pair (AttrIdd, Probability)
Definition: Classifier.hpp:64
inner class - examples train collection
Definition: Classifier.hpp:82
virtual AttrIdd getCategory(const ExampleTest &) const
Definition: DecisionTree.hpp:401
virtual void train(const ExamplesTrain &e)
learn classifier (on the collection of training examples).
Definition: DecisionTree.hpp:369
double entropy() const
entropy of counters
Definition: Classifier.hpp:283
Point in n-space, each component of the same type.
Definition: Point.hpp:22
Beliefs getHistogram() const
histogram from counters - Beliefs class where each position is counter divided by counters sum...
Definition: Classifier.hpp:297
virtual Beliefs getCategories(const ExampleTest &) const
classify and return all classes with belief that the example is from given class
Definition: DecisionTree.hpp:413
Val::DomainType::ValueIdSerialize AttrIddSerialize
for serialization the const interferes
Definition: Classifier.hpp:58
const DecisionTreeTrainParams & getTrainParams() const
Definition: DecisionTree.hpp:97
double calcEntropy(double freq)
calculate x * log(x) value. If x == 0 return 0.
Definition: Classifier.hpp:28
param for training decision tree
Definition: DecisionTree.hpp:45
int getSum() const
optimization: instead of accumulate all values from counters container keep the integer member ...
Definition: Classifier.hpp:280
Val::DomainType AttrDomain
the attribute domain for learning
Definition: Classifier.hpp:52
virtual void reset()
Definition: DecisionTree.hpp:359
void prune(const ExamplesTrain &e)
prune tree - please use an example set separate from the one used for training
Definition: DecisionTree.hpp:441
the classifier interface
Definition: Classifier.hpp:43