6 #ifndef FAIF_DECISION_TREE_HPP 7 #define FAIF_DECISION_TREE_HPP 9 #if defined(_MSC_VER) && (_MSC_VER >= 1400) 11 #pragma warning(disable:4100) 12 #pragma warning(disable:4512) 17 #include "Classifier.hpp" 25 #include <boost/ref.hpp> 26 #include <boost/bind.hpp> 28 #include <boost/lambda/bind.hpp> 29 #include <boost/lambda/construct.hpp> 30 #include <boost/lambda/core.hpp> 31 #include <boost/lambda/lambda.hpp> 33 #include <boost/serialization/split_member.hpp> 34 #include <boost/serialization/base_object.hpp> 35 #include <boost/serialization/nvp.hpp> 36 #include <boost/serialization/singleton.hpp> 37 #include <boost/serialization/extended_type_info.hpp> 38 #include <boost/serialization/shared_ptr.hpp> 39 #include <boost/serialization/vector.hpp> 50 template<
class Archive>
51 void serialize( Archive &ar,
const unsigned int ){
52 ar & boost::serialization::make_nvp(
"allowedNbrMiscEx", allowedNbrMiscEx );
53 ar & boost::serialization::make_nvp(
"minInfGain", minInfGain );
62 template<
typename Val>
/**
 * @brief Constructs the classifier for the given attribute and category domains.
 *
 * Both arguments are forwarded unchanged to the Classifier<Val> base class.
 *
 * @param attr_domains    domains of the attributes used in examples
 * @param category_domain domain of the category (class label)
 */
76 DecisionTree(
const Domains& attr_domains,
const AttrDomain& category_domain);
/**
 * @brief Learns the classifier on the collection of training examples.
 *
 * The definition collects the attribute value ids that occur in the examples,
 * creates one DTTest per id and builds the tree with buildTreeRecur(), using
 * params_.allowedNbrMiscEx as the allowed number of misclassified examples.
 *
 * @param e training examples
 */
85 virtual void train(
const ExamplesTrain& e);
/**
 * @brief Classifies a single example.
 *
 * For an empty example the major category of the root node is returned;
 * otherwise the root node classifies the example. An earlier branch returns
 * AttrDomain::getUnknownId().
 * NOTE(review): the condition of that first branch is not visible here -
 * presumably "tree not trained yet"; confirm against the definition.
 *
 * @return id of the chosen category
 */
88 virtual AttrIdd getCategory(
const ExampleTest&)
const;
/**
 * @brief Classifies an example and returns all classes with the belief that
 *        the example is from the given class.
 *
 * For an empty example the beliefs stored in the root node are returned;
 * otherwise the root node is asked via getCategories().
 *
 * @return collection of (category id, probability) pairs
 */
91 virtual Beliefs getCategories(
const ExampleTest&)
const;
/**
 * @brief Writes a textual dump of the classifier to a stream.
 *
 * The definition prints either "Empty DTC" or the tree nodes via
 * writeDecTreeNodes() starting from the root at level 0.
 *
 * @param os output stream
 */
94 virtual void write(std::ostream& os)
const;
/**
 * @brief Prunes the tree using the given examples (delegates to pruneTreeRecur()
 *        starting at the root).
 *
 * Please do not pass the example set that was used for training.
 *
 * @param e examples used to decide which subtrees are replaced by leaves
 */
110 void prune(
const ExamplesTrain& e);
114 typedef boost::shared_ptr<DTNode> PDTNode;
/** @brief Default constructor: the stored id is initialised with AttrDomain::getUnknownId(). */
131 DTTest( ): idd_(AttrDomain::getUnknownId()) {}
/** @brief Creates a test bound to the attribute value id @p idd. */
132 explicit DTTest( AttrIdd idd ): idd_(idd) {}
/** @brief Returns the attribute value id stored in this test. */
135 AttrIdd
get()
const {
return idd_; }
/**
 * @brief Evaluates the test on an example.
 * @param e example to check
 * @return true when the stored id is found among the elements of @p e
 *         (the definition uses std::find over [e.begin(), e.end()))
 */
138 bool test(
const ExampleTest& e )
const;
/**
 * @brief Scores how well this test splits the examples in [eBeg, eEnd).
 *
 * The definition computes the entropy before the split minus the weighted
 * entropy of the accepted / not accepted subsets, and returns that gain
 * divided by a split term (testIc).
 * NOTE(review): the computation of testIc and the branch taken when
 * testIc < minInfGain are not visible in this listing - confirm there.
 *
 * @param eBeg       first training example of the range
 * @param eEnd       one past the last training example of the range
 * @param minInfGain threshold compared against testIc
 * @return gain / testIc on the regular path
 */
141 double entropyGain(
typename ExamplesTrain::const_iterator eBeg,
typename ExamplesTrain::const_iterator eEnd,
double minInfGain)
const;
144 void write(std::ostream& os)
const;
148 friend class boost::serialization::access;
150 template<
class Archive>
151 void save(Archive & ar,
const unsigned int )
const {
152 ar & boost::serialization::make_nvp(
"Idd", idd_ );
155 template<
class Archive>
156 void load(Archive & ar,
const unsigned int ) {
158 ar >> boost::serialization::make_nvp(
"Idd", i);
159 idd_ =
const_cast<AttrIdd
>(i);
162 template<
class Archive>
163 void serialize( Archive &ar,
const unsigned int file_version ){
164 boost::serialization::split_member(ar, *
this, file_version);
172 typedef std::list<DTTest> DTTests;
181 DTNode(Beliefs catBel) : catBel_(catBel) {}
185 AttrIdd getMajorCategory()
const {
186 if( catBel_.empty() )
187 return AttrDomain::getUnknownId();
189 return catBel_.front().getValue();
/** @brief Read-only access to the beliefs stored in this node (no copy is made). */
193 const Beliefs& getBeliefs()
const {
return catBel_; }
/**
 * @brief Tells whether the node is a leaf.
 * @return true when getTest() yields a null pointer, i.e. the node carries no test
 */
196 bool isLeaf()
const {
return getTest() == 0L; }
/**
 * @brief Factory: allocates a DTNode holding @p catBel and wraps it in a PDTNode.
 * @param catBel beliefs stored in the new leaf
 */
199 static PDTNode createLeaf(
const Beliefs& catBel);
/**
 * @brief Factory: allocates a DTNodeInternal and wraps it in a PDTNode.
 * @param catBel beliefs stored in the new node
 * @param test   test attached to the node
 * @param nTrue  child for examples that satisfy the test
 * @param nFalse child for examples that do not satisfy the test
 */
201 static PDTNode createInternal(
const Beliefs& catBel,
const DTTest& test, PDTNode nTrue, PDTNode nFalse);
/** @brief Child for examples satisfying the test; this base version returns an empty PDTNode. */
204 virtual PDTNode getNodeTrue()
const {
return PDTNode(); }
/** @brief Child for examples not satisfying the test; this base version returns an empty PDTNode. */
206 virtual PDTNode getNodeFalse()
const {
return PDTNode(); }
/** @brief Test attached to the node; this base version returns a null pointer, which isLeaf() relies on. */
208 virtual const DTTest* getTest()
const {
return 0L; }
/** @brief Replaces the "test satisfied" child; this base version ignores the argument. */
210 virtual void setNodeTrue(PDTNode) { }
/** @brief Replaces the "test not satisfied" child; this base version ignores the argument. */
212 virtual void setNodeFalse(PDTNode) { }
215 virtual AttrIdd getCategory(
const ExampleTest& e)
const {
216 return getMajorCategory();
220 virtual const Beliefs& getCategories(
const ExampleTest& e)
const {
225 virtual void write(std::ostream& os)
const {
226 os <<
"Leaf (Major:" << getMajorCategory()->get() <<
", Beliefs:" << getBeliefs() <<
");";
230 friend class boost::serialization::access;
232 template<
class Archive>
233 void serialize( Archive &ar,
const unsigned int ){
234 ar & boost::serialization::make_nvp(
"CatBel", catBel_ );
246 class DTNodeInternal :
public DTNode {
249 DTNodeInternal(
const Beliefs& catBel,
const DTTest& test, PDTNode nTrue, PDTNode nFalse)
250 : DTNode(catBel), test_(test), nodeTrue_(nTrue), nodeFalse_(nFalse)
/** @brief Returns the child stored for examples that satisfy the test. */
254 virtual PDTNode getNodeTrue()
const {
return nodeTrue_; }
/** @brief Returns the child stored for examples that do not satisfy the test. */
256 virtual PDTNode getNodeFalse()
const {
return nodeFalse_; }
/** @brief Returns a pointer to the test owned by this node; the pointer refers to a member of the node. */
258 virtual const DTTest* getTest()
const {
return &test_; }
/** @brief Replaces the "test satisfied" child with @p n. */
260 virtual void setNodeTrue(PDTNode n) { nodeTrue_ = n; }
/** @brief Replaces the "test not satisfied" child with @p n. */
262 virtual void setNodeFalse(PDTNode n) { nodeFalse_ = n; }
265 virtual AttrIdd getCategory(
const ExampleTest& e)
const {
267 return nodeTrue_->getCategory(e);
269 return nodeFalse_->getCategory(e);
273 virtual const Beliefs& getCategories(
const ExampleTest& e)
const {
275 return nodeTrue_->getCategories(e);
277 return nodeFalse_->getCategories(e);
281 virtual void write(std::ostream& os)
const {
282 os <<
"Internal (Major:" << this->getMajorCategory()->get() <<
", Beliefs:" << this->getBeliefs() <<
", test:";
288 friend class boost::serialization::access;
290 template<
class Archive>
291 void serialize( Archive &ar,
const unsigned int ){
292 ar & boost::serialization::make_nvp(
"NodeBase", boost::serialization::base_object<DTNode>(*
this) );
293 ar & boost::serialization::make_nvp(
"Test", test_ );
294 ar & boost::serialization::make_nvp(
"NodeTrue", nodeTrue_ );
295 ar & boost::serialization::make_nvp(
"NodeFalse", nodeFalse_ );
310 PDTNode buildTreeRecur(
typename ExamplesTrain::iterator eBeg,
typename ExamplesTrain::iterator eEnd,
311 const DTTests& inTests,
const int ALLOWED_NBR_MISC_EX);
319 static PDTNode pruneTreeRecur(
typename ExamplesTrain::iterator eBeg,
typename ExamplesTrain::iterator eEnd, PDTNode node);
324 static void writeDecTreeNodes(std::ostream& os,
typename DecisionTree<Val>::PDTNode node,
int level = 0);
327 friend class boost::serialization::access;
329 template<
class Archive>
330 void serialize( Archive &ar,
const unsigned int ) {
331 ar.template register_type<DTNode>();
332 ar.template register_type<DTNodeInternal>();
333 ar.template register_type<DecisionTreeTrainParams>();
335 ar & boost::serialization::make_nvp(
"DTCBase", boost::serialization::base_object<
Classifier<Val> >(*
this) );
336 ar & boost::serialization::make_nvp(
"Node", root_ );
337 ar & boost::serialization::make_nvp(
"params", params_);
346 template<
typename Val>
351 template<
typename Val>
353 :
Classifier<Val>(attr_domains, category_domain)
358 template<
typename Val>
368 template<
typename Val>
372 std::set<AttrIdd> attrib;
375 for (
typename ExamplesTrain::iterator it = ex.begin(); it != ex.end(); ++it) {
376 for(
typename ExampleTrain::iterator at = it->begin(); at != it->end();) {
377 AttrIdd& value = *at;
387 const ExampleTrain& ee = *it;
388 std::copy(ee.begin(), ee.end(), std::inserter(attrib, attrib.begin() ) );
392 std::transform(attrib.begin(), attrib.end(), std::back_inserter(tests),
393 boost::lambda::bind(boost::lambda::constructor<DTTest>(), boost::lambda::_1 ) );
395 root_ = buildTreeRecur(ex.begin(), ex.end(), tests, params_.allowedNbrMiscEx);
400 template<
typename Val>
403 return AttrIdd(AttrDomain::getUnknownId());
404 }
else if(e.empty() ) {
405 return root_->getMajorCategory();
407 return root_->getCategory(e);
412 template<
typename Val>
416 }
else if(e.empty() ) {
417 return root_->getBeliefs();
419 return root_->getCategories(e);
424 template<
typename Val>
427 os <<
"Empty DTC" << std::endl;
429 writeDecTreeNodes(os, root_, 0 );
440 template<
typename Val>
444 pruneTreeRecur(ex.begin(), ex.end(), root_ );
454 template<
typename Val>
455 typename DecisionTree<Val>::PDTNode
457 const DTTests& inTests,
const int ALLOWED_NBR_MISC_EX) {
464 int numCatWithManyExamples = 0;
465 const std::map<AttrIdd,int>& c = counters.
get();
466 for(
typename std::map<AttrIdd,int>::const_iterator i = c.begin(); i != c.end(); ++i) {
467 if( i->second > ALLOWED_NBR_MISC_EX) {
468 ++numCatWithManyExamples;
471 if(static_cast<int>(std::distance(eBeg, eEnd)) <= ALLOWED_NBR_MISC_EX ||
472 numCatWithManyExamples < 2) {
473 return DTNode::createLeaf(histogram);
476 DTTests tests(inTests);
477 typename DTTests::iterator best = tests.end();
478 double bestEntropy = std::numeric_limits<double>::min();
480 for(
typename DTTests::iterator i = tests.begin(); i != tests.end(); ++i ) {
481 double entr = i->entropyGain(eBeg, eEnd, params_.minInfGain);
482 if(entr > bestEntropy) {
487 if( best == tests.end() || bestEntropy < params_.minInfGain ) {
488 return DTNode::createLeaf(histogram);
493 typename ExamplesTrain::iterator middle = std::stable_partition(eBeg, eEnd, boost::bind(&DTTest::test, boost::ref(*best), _1) );
494 DTTest bestTest(*best);
496 PDTNode nTrue = buildTreeRecur(eBeg, middle, tests, ALLOWED_NBR_MISC_EX);
497 PDTNode nFalse = buildTreeRecur(middle, eEnd, tests, ALLOWED_NBR_MISC_EX);
498 return DTNode::createInternal(histogram, bestTest, nTrue, nFalse);
506 template<
typename Val>
507 typename DecisionTree<Val>::PDTNode
510 if(node->isLeaf() || std::distance(eBeg, eEnd) < 1)
514 const DTTest& t = *(node->getTest());
516 typename ExamplesTrain::iterator
517 middle = std::stable_partition(eBeg, eEnd, boost::bind(&DTTest::test, boost::ref(t), _1) );
519 node->setNodeTrue( pruneTreeRecur( eBeg, middle, node->getNodeTrue() ) );
520 node->setNodeFalse( pruneTreeRecur( middle, eEnd, node->getNodeFalse() ) );
523 int leafCount = 0, treeCount = 0;
524 for(
typename ExamplesTrain::const_iterator i = eBeg; i != eEnd; ++i) {
525 const ExampleTrain& e = *i;
526 if(node->getMajorCategory() == e.getFeature() )
528 if(node->getCategory(e) == e.getFeature())
533 if( leafCount >= treeCount ) {
534 return DTNode::createLeaf(node->getBeliefs());
542 template<
typename Val>
545 os << std::string(level,
' ');
548 writeDecTreeNodes(os, node->getNodeTrue(), level+1);
549 writeDecTreeNodes(os, node->getNodeFalse(), level+1);
558 template<
typename Val>
567 for(
typename ExamplesTrain::const_iterator i = eBeg; i != eEnd; ++ i) {
568 const ExampleTrain& ex = *i;
569 if( this->test(ex) ) {
576 double sum =
static_cast<double>( std::distance(eBeg, eEnd) );
577 double nrAcc =
static_cast<double>(acc.
getSum() );
578 double nrNAcc =
static_cast<double>(nacc.
getSum() );
580 if( testIc < minInfGain ) {
584 double entropy = acc.
entropy() * nrAcc / sum + nacc.
entropy() * nrNAcc / sum;
586 double gain = befSplit.
entropy() - entropy;
587 return gain / testIc;
592 template<
typename Val>
594 return std::find(e.begin(), e.end(), idd_) != e.end();
598 template<
typename Val>
600 os <<
"Domain: " << idd_->getDomain()->getId() <<
", Value:" << idd_->get();
608 template<
typename Val>
609 typename DecisionTree<Val>::PDTNode
611 return PDTNode(
new DTNode(catBel) );
615 template<
typename Val>
616 typename DecisionTree<Val>::PDTNode
618 return PDTNode(
new DTNodeInternal(catBel, test, nTrue, nFalse) );
626 #endif //FAIF_DECISION_TREE_HPP
Val::DomainType::ValueId AttrIdd
attribute id representation in learning
Definition: Classifier.hpp:55
virtual void write(std::ostream &os) const
Definition: DecisionTree.hpp:425
void setTrainParams(const DecisionTreeTrainParams &p)
Definition: DecisionTree.hpp:100
Val::Value AttrValue
attribute value representation in learning
Definition: Classifier.hpp:49
const Counters & get() const
access to counters
Definition: Classifier.hpp:277
Decision Tree Classifier.
Definition: DecisionTree.hpp:63
Definition: Classifier.hpp:233
point and some feature
Definition: Point.hpp:58
Belief< Val >::Beliefs Beliefs
collection of pair (AttrIdd, Probability)
Definition: Classifier.hpp:64
inner class - examples train collection
Definition: Classifier.hpp:82
virtual AttrIdd getCategory(const ExampleTest &) const
Definition: DecisionTree.hpp:401
virtual void train(const ExamplesTrain &e)
learn classifier (on the collection of training examples).
Definition: DecisionTree.hpp:369
double entropy() const
entropy of counters
Definition: Classifier.hpp:283
Point in n-space, each component of the same type.
Definition: Point.hpp:22
Beliefs getHistogram() const
histogram from counters - Beliefs class where each position is counter divided by counters sum...
Definition: Classifier.hpp:297
virtual Beliefs getCategories(const ExampleTest &) const
classify and return all classes with belief that the example is from given class
Definition: DecisionTree.hpp:413
Val::DomainType::ValueIdSerialize AttrIddSerialize
for serialization the const interferes
Definition: Classifier.hpp:58
const DecisionTreeTrainParams & getTrainParams() const
Definition: DecisionTree.hpp:97
double calcEntropy(double freq)
calculate x * log(x) value. If x == 0 return 0.
Definition: Classifier.hpp:28
param for training decision tree
Definition: DecisionTree.hpp:45
int getSum() const
optimization: instead of accumulating all values from the counters container, keep the integer member ...
Definition: Classifier.hpp:280
Val::DomainType AttrDomain
the attribute domain for learning
Definition: Classifier.hpp:52
virtual void reset()
Definition: DecisionTree.hpp:359
void prune(const ExamplesTrain &e)
prune tree - please do not use the example set used for training
Definition: DecisionTree.hpp:441
the classifier interface
Definition: Classifier.hpp:43