// NOTE(review): this text looks like a lossy extraction of a kd-tree
// header. The leading numbers on many lines appear to be original line
// numbers, some identifiers are split across lines ("po" / "int_t"), and
// intermediate lines are missing. It is not compilable as-is; the comments
// below describe only what the visible lines show.
//
// Template parameters (presumably of the node class -- the matching
// operator<< further down takes node_t<kd, point_t, data_t>; confirm):
//   kd      - number of coordinates per point (the loops run j < kd).
//   point_t - point type, indexed with operator[].
//   data_t  - scalar type for means, spreads and distances; defaults to double.
49 template <
int kd,
typename po
int_t,
typename data_t =
double>
// Point comparison used for sorting: true when a's i_-th coordinate is less
// than b's. Presumably this is the call operator of point_sort_predicate,
// which the node constructor builds with i_ and passes to std::sort.
64 inline bool operator() (
const point_t & a,
const point_t & b)
const
65 {
return a[i_] < b[i_]; }
// Fragment of the node constructor (only the end of its parameter list and
// selected body lines are visible). Visible steps:
//   1. per-dimension mean of the num_pts points;
//   2. per-dimension sum of squared deviations from that mean (no division
//      by num_pts is visible, so "variance" is unnormalized, which is
//      sufficient for comparing dimensions against each other);
//   3. choose i_ as the dimension with the largest spread;
//   4. if that spread is non-zero: sort the points along i_, pick the split
//      value m_ from the distinct i_-coordinates, record min_/max_, and
//      count how many points fall below m_.
// NOTE(review): nothing visible guards num_pts == 0. In that case the mean
// would divide by zero and unique[0] would index an empty vector.
// Presumably callers guarantee num_pts > 0; confirm.
115 const unsigned int num_pts):
// Step 1: accumulate, then average, each coordinate.
127 data_t mean[kd] = { data_t(0) };
129 for (
unsigned int i = 0; i < num_pts; i++)
131 for (
unsigned int j = 0; j < kd; j++)
133 mean[j] += data_t(points[i][j]);
137 for (
unsigned int j = 0; j < kd; j++)
139 mean[j] /= data_t(num_pts);
// Step 2: sum of squared deviations from the mean, per coordinate.
143 data_t variance[kd] = { data_t(0) };
145 for (
unsigned int i = 0; i < num_pts; i++)
147 for (
unsigned int j = 0; j < kd; j++)
149 data_t d = data_t(points[i][j]) - mean[j];
150 variance[j] += d * d;
// Step 3: pick the dimension with the largest spread. The scan starts at
// j = 1, so i_ is presumably initialized to 0 on a line not shown here.
155 for (
unsigned int j = 1; j < kd; j++)
157 if (variance[j] > variance[i_])
// Step 4: only split along i_ when the points actually differ there.
163 if (variance[i_] != data_t(0))
166 point_sort_predicate predicate(i_);
167 std::sort(&points[0], &points[num_pts], predicate);
// Collect the distinct i_-coordinates of the now-sorted points. The lines
// that compare next with prev and advance num_unique are not visible.
// NOTE(review): the loop starts at i = 0 even though unique[0] already holds
// points[0][i_]. Starting at i = 1 would suffice if the hidden lines only
// append values that differ from prev; confirm.
170 std::vector<data_t> unique(num_pts);
171 unsigned int num_unique = 1;
172 unique[0] = points[0][i_];
173 for (
unsigned int i = 0; i < num_pts; i++)
175 const data_t & prev = unique[num_unique - 1];
176 const data_t next = points[i][i_];
179 unique[num_unique] = next;
// Split value: the middle *distinct* value, not the median of all points.
// With at least two distinct values, num_unique / 2 >= 1. Assuming unique
// is strictly increasing, this gives m_ > min_, so the lower side is non-empty.
186 m_ = unique[num_unique / 2];
// Extent of the points along i_ (first and last after sorting).
189 min_ = points[0][i_];
190 max_ = points[num_pts - 1][i_];
// num_a = number of leading (sorted) points whose i_-coordinate is < m_.
// Starting at 1 relies on points[0][i_] == min_ < m_. num_b counts the rest.
193 unsigned int num_a = 1;
194 for (; num_a < num_pts && data_t(points[num_a][i_]) < m_; num_a++);
195 unsigned int num_b = num_pts - num_a;
// The split index must sit between two different coordinate values.
198 assert(points[num_a - 1][i_] != points[num_a][i_]);
// Resets the recorded spread along i_ (the guarding condition is not visible).
214 variance[i_] = data_t(0);
// No spread along i_; the body is not visible (presumably the leaf case).
218 if (variance[i_] == data_t(0))
// Fragment copying state from another node (the enclosing signature is not
// visible; presumably a copy constructor or assignment operator). It copies
// the points_ pointer, the point count and the parent_ pointer.
// NOTE(review): this is a shallow copy of points_. Nothing visible
// duplicates the pointed-to points, so both nodes refer to the same storage.
// Confirm the intended ownership.
237 points_ = node.points_;
238 num_pts_ = node.num_pts_;
239 parent_ = node.parent_;
// Fragment of unexplored_branch_t (a queued subtree plus a distance
// estimate). The constructor stores the estimate in dist_; its node
// parameter is on a line not visible here.
269 const data_t & distance_to_median):
271 dist_(distance_to_median)
// Presumably operator<: orders branches by ascending distance estimate.
276 {
return dist_ < b.dist_; }
// Recursive descent used by the searches. Visible behavior:
//  - returns NULL when both per-child distance bounds (da for a_, db for b_)
//    exceed best_distance (presumably: neither side can hold a closer point);
//    this test runs before the points_ check below;
//  - returns `this` when the node holds points (points_ != NULL);
//  - otherwise recurses into one child and appends the other child to
//    `unexplored` together with its distance bound. The lines that decide
//    which child to take first are not visible.
//
// NOTE(review): da/db look intended to be the distance from p to the
// intervals [min_, m_] and [m_, max_], i.e. the larger of (min_ - p),
// (p - m_) and 0. As written they use std::min, and min_ <= m_ <= max_ holds
// for the values the constructor assigns. Under that condition at most one
// of the two clamped terms can be positive, so da and db always evaluate
// to 0. Consequences:
//  - the NULL early return never fires for best_distance >= 0;
//  - every branch queued here carries distance 0, so callers that prune
//    queued branches by distance never prune anything.
// Confirm and replace the outer std::min with std::max.
// NOTE(review): std::max(min_ - p, 0.0) mixes data_t and double. Template
// argument deduction only succeeds when data_t is double (the default), so
// e.g. data_t = float would not compile. data_t(0) avoids this.
287 leaf(std::list<unexplored_branch_t> & unexplored,
288 const point_t & query,
289 const data_t & best_distance)
const
291 const data_t p = data_t(query[i_]);
292 const data_t da = std::min(std::max(min_ - p, 0.0), std::max(p - m_, 0.0));
293 const data_t db = std::min(std::max(m_ - p, 0.0), std::max(p - max_, 0.0));
295 if (da > best_distance && db > best_distance)
return NULL;
296 if (points_ != NULL)
return this;
302 return a_->leaf(unexplored, query, best_distance);
// Defer the a_ side and continue down b_.
307 unexplored.push_back(unexplored_branch_t(a_, da));
308 return b_->leaf(unexplored, query, best_distance);
// Euclidean distance between `query` and the first point stored in this
// node (points_[0]). Requires points_ != NULL (asserted). The per-coordinate
// differences are computed below; the lines that presumably accumulate
// d * d are not visible. The square root is taken in double precision and
// the result is converted back to data_t.
// NOTE(review): only points_[0] is examined. If a node can hold several
// distinct points, the others are ignored; confirm that a node's points
// coincide.
315 euclidian_distance(
const point_t & query)
const
317 assert(points_ != NULL);
319 const point_t & point = points_[0];
320 data_t distance = 0.0;
321 for (
unsigned int j = 0; j < kd; j++)
323 data_t d = data_t(query[j]) - data_t(point[j]);
326 distance = data_t(sqrt(
double(distance)));
// Writes this subtree to `so`. The loop over the node's points writes
// points_[i] on its own line (some loop-body lines are not visible), then
// the a_ and b_ children are dumped recursively when present.
// NOTE(review): std::endl flushes the stream on every point; '\n' would
// avoid repeated flushes when dumping large trees.
331 void dump(std::ostream & so)
const
335 for (
unsigned int i = 0; i < num_pts_; i++)
344 so << points_[i] << std::endl;
350 if (a_ != NULL) a_->dump(so);
351 if (b_ != NULL) b_->dump(so);
// Number of points reachable through points_ (the loop above indexes
// points_[0 .. num_pts_ - 1]).
367 unsigned int num_pts_;
// Stream insertion for a node. The body is not visible in this listing
// (presumably it forwards to node.dump(so); confirm).
382 template <
int kd,
typename po
int_t,
typename data_t>
383 inline std::ostream &
384 operator << (std::ostream & so, const node_t<kd, point_t, data_t> & node)
// Template header, presumably of the tree class: same parameters as the
// node, with data_t defaulting to double. The class body is only partially
// visible below.
393 template <
int kd,
typename po
int_t,
typename data_t =
double>
// Fragment of a member taking another tree (presumably a copy constructor
// or assignment); it acts only when that tree has a root.
428 if (tree.root_ != NULL)
// Presumably builds the tree over num_pts points starting at `points`. The
// body is not visible. The node constructor sorts the points it receives,
// so the caller's array may be reordered; confirm.
437 void setup(point_t * points,
const unsigned int num_pts)
// Orders search results (nn_t) by ascending distance dist_.
459 inline bool operator < (
const nn_t & nn)
const
460 {
return dist_ < nn.dist_; }
// Approximate nearest-neighbour search.
//   query          - point to search around.
//   best_distance  - out: distance of the best match found. It is reset to
//                    std::numeric_limits<data_t>::max() before searching,
//                    but left untouched when the tree is empty.
//   nn_sorted      - out: matches recorded during the search, trimmed from
//                    the back to at most max_nn entries.
//   max_traversals - maximum number of queued branches popped and explored.
//                    This bounds the work, so the result may be approximate.
//   max_nn         - maximum number of entries kept in nn_sorted.
// Returns nn_sorted.size(), or 0 when the tree has no root.
// NOTE(review): an entry appears to be pushed only when a strictly closer
// match is found. nn_sorted therefore holds the last max_nn best-so-far
// matches (closest at the front), not necessarily the max_nn nearest points.
// Nothing visible clears nn_sorted first, so callers presumably pass it empty.
471 nn(
const point_t & query,
472 data_t & best_distance,
473 std::list<nn_t> & nn_sorted,
474 const unsigned int max_traversals = 200,
475 const unsigned int max_nn = 3)
const
477 if (root_ == NULL)
return 0;
481 best_distance = std::numeric_limits<data_t>::max();
// Work list of subtrees still to explore, seeded with the root.
484 std::list<unexplored_branch_t> unexplored;
485 unexplored.push_back(unexplored_branch_t(root_, best_distance));
// Pop branches from the front and descend each to a node holding points.
// leaf() appends further branches at the back, so (absent any sorting on
// lines not shown) the work list is processed in FIFO order.
487 for (
unsigned int i = 0; i < max_traversals && !unexplored.empty(); i++)
491 unexplored.front().node_;
492 unexplored.pop_front();
496 start_here->leaf(unexplored, query, best_distance);
497 if (match == NULL)
continue;
500 data_t distance = match->euclidian_distance(query);
503 if (distance < best_distance)
506 best_distance = distance;
511 nn_sorted.push_front(
nn_t(best_match, best_distance));
// Keep the result list bounded by dropping the back (apparently the oldest,
// farthest entry).
514 if (nn_sorted.size() > max_nn)
517 nn_sorted.pop_back();
// Drop queued branches whose distance bound exceeds the new best distance.
// This only prunes if leaf() supplies non-zero distance bounds.
522 typename std::list<unexplored_branch_t>::iterator it =
525 while (it != unexplored.end())
527 const unexplored_branch_t & ub = *it;
529 if (ub.dist_ > best_distance)
531 it = unexplored.erase(it);
543 return nn_sorted.size();
// Single nearest neighbour: runs the list-based nn overload with max_nn = 1
// and returns the matched node (nn_t::node_), not a point. best_distance
// receives that match's distance. The lines handling a zero result are not
// visible (presumably they return NULL).
548 nn(
const point_t & query,
549 data_t & best_distance,
550 unsigned int max_traversals = 200)
const
552 std::list<nn_t> nn_sorted;
553 if (nn(query, best_distance, nn_sorted, max_traversals, 1) == 0)
558 return nn_sorted.front().node_;
// Radius search. Every match whose distance to `query` is <= radius is
// appended to nn_sorted with push_back, in discovery order. No sort is
// visible despite the parameter name.
//   query          - point to search around.
//   radius         - inclusive distance bound; it is also passed to leaf()
//                    as the bound for skipping subtrees.
//   nn_sorted      - out: matches found. Nothing visible clears it first.
//   max_traversals - maximum number of queued branches explored. Matches can
//                    be missed when this cap is reached.
// Returns nn_sorted.size(), or 0 when the tree has no root.
563 neighbors(
const point_t & query,
564 const data_t & radius,
565 std::list<nn_t> & nn_sorted,
566 const unsigned int max_traversals = 200)
const
568 if (root_ == NULL)
return 0;
// Work list of subtrees still to explore, seeded with the root.
571 std::list<unexplored_branch_t> unexplored;
572 unexplored.push_back(unexplored_branch_t(root_, 0.0));
574 for (
unsigned int i = 0; i < max_traversals && !unexplored.empty(); i++)
578 unexplored.front().node_;
579 unexplored.pop_front();
583 start_here->leaf(unexplored, query, radius);
584 if (match == NULL)
continue;
587 data_t distance = match->euclidian_distance(query);
// Points exactly at `radius` are kept (strict > rejects only farther ones).
588 if (distance > radius)
continue;
589 nn_sorted.push_back(nn_t(match, distance));
598 return nn_sorted.size();
// Writes the whole tree to `so` by dumping the root node; writes nothing
// for an empty tree.
602 void dump(std::ostream & so)
const
604 if (root_ != NULL) root_->dump(so);
// Stream insertion for a tree. The body is not visible in this listing
// (presumably it forwards to tree.dump(so); confirm).
614 template <
int kd,
typename po
int_t,
typename data_t>
615 inline std::ostream &
616 operator << (std::ostream & so, const tree_t<kd, point_t, data_t> & tree)