faif
KNearestNeighbor.hpp
1 // The k-Nearest Neighbor classifier
2 
3 #ifndef FAIF_K_NEAREST_NEIGHBOR_CLASSIFIER_HPP
4 #define FAIF_K_NEAREST_NEIGHBOR_CLASSIFIER_HPP
5 
#include <algorithm>
#include <functional>
#include <iterator>
#include <ostream>
#include <utility>
#include <vector>

#include <boost/bind.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/base_object.hpp>
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/vector.hpp>

#include "Classifier.hpp"
17 
18 namespace faif {
19  namespace ml {
20 
21 
22  /** \brief Distance metrics for Nomianal Values Collection, used in K Nearest Neighbors classifier
23 
24  distance between values is 0.0 (values equal), 0.5 (unknown value and other value) or 1.0 (different values)
25  distance between points are sum of distances of coordinates
26  */
27  template<typename Val> class DistanceNominalValue {
28  BOOST_CONCEPT_ASSERT((ValueConcept<Val>));
29  public:
30  typedef typename Classifier<Val>::AttrValue AttrValue;
31  typedef typename Classifier<Val>::AttrDomain AttrDomain;
32  typedef typename Classifier<Val>::AttrIdd AttrIdd;
33  typedef typename Classifier<Val>::Domains Domains;
35 
36  static double distance(const ExampleTest& a, const ExampleTest& b) {
37  double distance = 0.0;
38  typename ExampleTest::const_iterator i = a.begin(), j = b.begin();
39  for(; i != a.end() && j != b.end(); ++i, ++j) {
40  if(*i != *j) {
41  if(*i == AttrDomain::getUnknownId() || *j == AttrDomain::getUnknownId() )
42  distance += 0.5;
43  else
44  distance += 1.0;
45  }
46  }
47  for(; i != a.end(); ++i) { distance += 1.0; } //not matched values from 'a' example
48  for(; j != b.end(); ++j) { distance += 1.0; } //not matched values from 'b' example
49  return distance;
50  }
51  };
52 
53  /** \brief k Nearest Neighbor classifier
54 
55  Contains the attributes, attribute values and categories,
56  train examples, test examples and classifier methods.
57  */
58  template<typename Val,
59  template <typename> class Distance = DistanceNominalValue
60  >
61  class KNearestNeighbor : public Classifier<Val> {
62  public:
63  typedef typename Classifier<Val>::AttrValue AttrValue;
64  typedef typename Classifier<Val>::AttrDomain AttrDomain;
65  typedef typename Classifier<Val>::AttrIdd AttrIdd;
67  typedef typename Classifier<Val>::Domains Domains;
68  typedef typename Classifier<Val>::Beliefs Beliefs;
72  public:
74  KNearestNeighbor(const Domains& attr_domains, const AttrDomain& category_domain);
75 
76  virtual ~KNearestNeighbor() { }
77 
78  /** clear the classifier */
79  virtual void reset();
80 
81  /** \brief learn classifier (on the collection of training examples), here store all train examples.
82  */
83  virtual void train(const ExamplesTrain& e);
84 
85  /** classify - find the category in neighbors, use default K (number of neighbors) */
86  virtual AttrIdd getCategory(const ExampleTest& e) const { return getCategoryK(e, defaultK_); }
87 
88  /** classify - find the category in neighbors, use given K (number of neighbors) */
89  AttrIdd getCategoryK(const ExampleTest& e, int K) const;
90 
91 
92  /** \brief classify and return all classes with belief that the example is from given class
93 
94  use default K (number of neighbors)
95  */
96  virtual Beliefs getCategories(const ExampleTest& e) const { return getCategoriesK(e, defaultK_); }
97 
98  /** \brief classify and return all classes with belief that the example is from given class
99 
100  use given K (number of neighbors)
101  */
102  Beliefs getCategoriesK(const ExampleTest& e, int K) const;
103 
104  /** the ostream method */
105  virtual void write(std::ostream& os) const;
106 
107  /** accessor - get the default number of neighbors used in calculation */
108  int getDefaultK() const { return defaultK_; }
109 
110  /** mutator - set the default number of neighbors used in calculation */
111  void setDefaultK(int k) { defaultK_ = k; }
112  private:
113  /** copy c-tor not allowed */
115  /** assignment not allowed */
116  KNearestNeighbor& operator=(const KNearestNeighbor&);
117 
118  ExamplesTrain memory_; //store training examples
119  int defaultK_; //default number of neighbors used in calculation
120  private:
121  /** \brief serialization using boost::serialization */
122  friend class boost::serialization::access;
123 
124  template<class Archive>
125  void serialize( Archive &ar, const unsigned int ) {
126  ar & boost::serialization::make_nvp("KNNBase", boost::serialization::base_object<Classifier<Val> >(*this) );
127  ar & boost::serialization::make_nvp("memory", memory_ );
128  ar & boost::serialization::make_nvp("defaultK", defaultK_ );
129  }
130 
131  };
132 
133  //////////////////////////////////////////////////////////////////////////////////////////////////
134  // class KNearestNeighbor implementation
135  //////////////////////////////////////////////////////////////////////////////////////////////////
136 
137  template<typename Val, template <typename> class Distance>
138  KNearestNeighbor<Val, Distance>::KNearestNeighbor() : Classifier<Val>(), memory_(), defaultK_(3)
139  { }
140 
141  template<typename Val, template <typename> class Distance>
142  KNearestNeighbor<Val, Distance>::KNearestNeighbor(const Domains& attr_domains, const AttrDomain& category_domain)
143  : Classifier<Val>(attr_domains, category_domain), memory_(), defaultK_(3)
144  { }
145 
146  /** \brief reset - clear the memory */
147  template<typename Val, template <typename> class Distance>
149  memory_.clear();
150  }
151 
152  /** \brief learn classifier (on the collection of training examples) - remember training examples */
153  template<typename Val, template <typename> class Distance>
154  void KNearestNeighbor<Val, Distance>::train(const ExamplesTrain& e) {
155  memory_ = e;
156  }
157 
158  /** classify - return the major category for best node from decision tree */
159  template<typename Val, template <typename> class Distance>
160  typename KNearestNeighbor<Val, Distance>::AttrIdd
161  KNearestNeighbor<Val, Distance>::getCategoryK(const ExampleTest& e, int K) const {
162  Beliefs bel = getCategoriesK(e, K);
163  if( bel.empty() )
164  return AttrDomain::getUnknownId();
165  else
166  return bel.front().getValue(); //histogram is sorted
167  }
168 
169  /** \brief classify and return all classes with belief that the example is from given class */
170  template<typename Val, template <typename> class Distance>
171  typename KNearestNeighbor<Val, Distance>::Beliefs
172  KNearestNeighbor<Val, Distance>::getCategoriesK(const ExampleTest& e, int K) const {
173 
174  typedef std::pair<typename ExamplesTrain::const_iterator, double> DistanceDescr;
175  typedef std::vector<DistanceDescr> DistanceDescrVec;
176 
177  //std::cout << "Neighbors for: " << e << std::endl;
178  DistanceDescrVec distances;
179  distances.reserve( memory_.size() );
180  for(typename ExamplesTrain::const_iterator ii = memory_.begin(); ii != memory_.end(); ++ii) {
181  //std::cout << "Memory: " << *ii << " Distance: " << Distance<Val>::distance( *ii, e ) << std::endl;
182  distances.push_back( DistanceDescr(ii, Distance<Val>::distance( *ii, e ) ) );
183  }
184  typename DistanceDescrVec::iterator middle = distances.end(); //for too small collection
185  if( distances.end() - distances.begin() > K) {
186  middle = distances.begin() + K; //middle between begin and end
187  }
188  std::partial_sort( distances.begin(), middle, distances.end(),
189  boost::bind( std::less<double>(), boost::bind(&DistanceDescr::second, _1), boost::bind(&DistanceDescr::second, _2) ) );
190 
192  for(typename DistanceDescrVec::const_iterator jj = distances.begin(); jj != middle; ++jj) {
193  const ExampleTrain& ex = *(jj->first);
194  counters.inc(ex);
195  }
196  return counters.getHistogram();
197  }
198 
199  /** ostream method */
200  template<typename Val, template <typename> class Distance>
201  void KNearestNeighbor<Val, Distance>::write(std::ostream& os) const {
202  os << "KNN classifier, defaultK=" << defaultK_ << ", memSize=" << memory_.size() << ":" << std::endl;
203  std::copy(memory_.begin(), memory_.end(), std::ostream_iterator<ExampleTrain>(os,";") );
204  os << std::endl;
205  }
206 
207  }//namespace ml
208 } //namespace faif
209 
210 #endif //FAIF_K_NEAREST_NEIGHBOR_CLASSIFIER_HPP
Val::DomainType::ValueId AttrIdd
attribute id representation in learning
Definition: Classifier.hpp:55
k Nearest Neighbor classifier
Definition: KNearestNeighbor.hpp:61
virtual Beliefs getCategories(const ExampleTest &e) const
classify and return all classes with belief that the example is from given class
Definition: KNearestNeighbor.hpp:96
Definition: Chain.h:17
Val::Value AttrValue
attribute value representation in learning
Definition: Classifier.hpp:49
virtual void write(std::ostream &os) const
Definition: KNearestNeighbor.hpp:201
int getDefaultK() const
Definition: KNearestNeighbor.hpp:108
void setDefaultK(int k)
Definition: KNearestNeighbor.hpp:111
AttrIdd getCategoryK(const ExampleTest &e, int K) const
Definition: KNearestNeighbor.hpp:161
Definition: Classifier.hpp:233
point and some feature
Definition: Point.hpp:58
Belief< Val >::Beliefs Beliefs
collection of pair (AttrIdd, Probability)
Definition: Classifier.hpp:64
inner class - examples train collection
Definition: Classifier.hpp:82
Beliefs getCategoriesK(const ExampleTest &e, int K) const
classify and return all classes with belief that the example is from given class
Definition: KNearestNeighbor.hpp:172
virtual void reset()
reset - clear the memory
Definition: KNearestNeighbor.hpp:148
Point in n-space, each component of the same type.
Definition: Point.hpp:22
the value concept
Definition: Value.hpp:41
Val::DomainType::ValueIdSerialize AttrIddSerialize
for serialization the const interferes
Definition: Classifier.hpp:58
virtual void train(const ExamplesTrain &e)
learn classifier (on the collection of training examples), here store all train examples.
Definition: KNearestNeighbor.hpp:154
virtual AttrIdd getCategory(const ExampleTest &e) const
Definition: KNearestNeighbor.hpp:86
Val::DomainType AttrDomain
the attribute domain for learning
Definition: Classifier.hpp:52
Distance metrics for Nominal Values Collection, used in K Nearest Neighbors classifier.
Definition: KNearestNeighbor.hpp:27
the classifier interface
Definition: Classifier.hpp:43