Seg3D  2.4
Seg3D is a free volume segmentation and processing tool developed by the NIH Center for Integrative Biomedical Computing at the University of Utah Scientific Computing and Imaging (SCI) Institute.
tree.hxx
1 /*
2  For more information, please see: http://software.sci.utah.edu
3 
4  The MIT License
5 
6  Copyright (c) 2016 Scientific Computing and Imaging Institute,
7  University of Utah.
8 
9 
10  Permission is hereby granted, free of charge, to any person obtaining a
11  copy of this software and associated documentation files (the "Software"),
12  to deal in the Software without restriction, including without limitation
13  the rights to use, copy, modify, merge, publish, distribute, sublicense,
14  and/or sell copies of the Software, and to permit persons to whom the
15  Software is furnished to do so, subject to the following conditions:
16 
17  The above copyright notice and this permission notice shall be included
18  in all copies or substantial portions of the Software.
19 
20  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
21  OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
23  THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
25  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
26  DEALINGS IN THE SOFTWARE.
27  */
28 
29 // File : tree.hxx
30 // Author : Pavel A. Koshevoy
31 // Created : 2006/02/09 13:14
32 // Copyright : (C) 2004-2008 University of Utah
33 // Description : A kd-tree.
34 
35 #ifndef TREE_HXX_
36 #define TREE_HXX_
37 
38 // system includes:
#include <math.h>

#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <list>
#include <vector>
44 
45 
//----------------------------------------------------------------
// node_t
//
// A node of a kd-tree built over an array of points. Each point
// must be indexable with operator [] for the indices 0 .. kd - 1.
//
// A node does not own the points it refers to: building the tree
// re-orders the caller's array in place (via std::sort), and leaf
// nodes keep raw pointers into that array. A node does own its two
// child nodes and deletes them in its destructor.
//
template <int kd, typename point_t, typename data_t = double>
class node_t
{
public:

  //----------------------------------------------------------------
  // point_sort_predicate
  //
  // Orders two points by the value of one selected coordinate.
  //
  class point_sort_predicate
  {
  public:
    point_sort_predicate(const unsigned int & sort_on_dimension):
      i_(sort_on_dimension)
    {}

    inline bool operator() (const point_t & a, const point_t & b) const
    { return a[i_] < b[i_]; }

    // dimension along which the points will be sorted:
    unsigned int i_;
  };

  // default constructor (an empty node with no payload and no children):
  node_t():
    i_(0),
    m_(data_t(0)),
    min_(data_t(0)),
    max_(data_t(0)),
    points_(NULL),
    num_pts_(0),
    parent_(NULL),
    a_(NULL),
    b_(NULL)
  {}

  // copy constructor (deep copy of the child nodes, see operator =):
  node_t(const node_t<kd, point_t, data_t> & node):
    i_(0),
    m_(data_t(0)),
    min_(data_t(0)),
    max_(data_t(0)),
    points_(NULL),
    num_pts_(0),
    parent_(NULL),
    a_(NULL),
    b_(NULL)
  { *this = node; }

  // destructor (recursively destroys both branches):
  ~node_t()
  {
    delete a_;
    a_ = NULL;

    delete b_;
    b_ = NULL;
  }

  //----------------------------------------------------------------
  // node_t
  //
  // given a list of points, where each point is composed of
  // kd elements of type data_t, construct a balanced kd-tree
  //
  // parent  - the node this node hangs off (NULL for the root).
  // points  - array of num_pts points; it is sorted in place and
  //           must outlive the tree.
  // num_pts - number of points, must not be zero.
  //
  node_t(node_t<kd, point_t, data_t> * parent,
         point_t * points,
         const unsigned int num_pts):
    i_(0),
    m_(data_t(0)),
    min_(data_t(0)),
    max_(data_t(0)),
    points_(NULL),
    num_pts_(0),
    parent_(parent),
    a_(NULL),
    b_(NULL)
  {
    // an empty point set can not be turned into a node, the code
    // below reads points[0] unconditionally:
    assert(points != NULL && num_pts != 0);

    // first, find the mean point value for each dimension:
    data_t mean[kd] = { data_t(0) };

    for (unsigned int i = 0; i < num_pts; i++)
    {
      for (unsigned int j = 0; j < kd; j++)
      {
        mean[j] += data_t(points[i][j]);
      }
    }

    for (unsigned int j = 0; j < kd; j++)
    {
      mean[j] /= data_t(num_pts);
    }

    // next, find the (unnormalized) point variance for each dimension:
    data_t variance[kd] = { data_t(0) };

    for (unsigned int i = 0; i < num_pts; i++)
    {
      for (unsigned int j = 0; j < kd; j++)
      {
        const data_t d = data_t(points[i][j]) - mean[j];
        variance[j] += d * d;
      }
    }

    // split along the dimension with the highest variance:
    for (unsigned int j = 1; j < kd; j++)
    {
      if (variance[j] > variance[i_])
      {
        i_ = j;
      }
    }

    // becomes true once this node has been split into two branches:
    bool was_split = false;

    if (variance[i_] != data_t(0))
    {
      // sort the points in the ascending order:
      point_sort_predicate predicate(i_);
      std::sort(points, points + num_pts, predicate);

      // collect the distinct values along the split dimension,
      // so that the median is not skewed by duplicate points:
      std::vector<data_t> unique(num_pts);
      unsigned int num_unique = 1;
      unique[0] = points[0][i_];
      for (unsigned int i = 1; i < num_pts; i++)
      {
        const data_t next = points[i][i_];
        if (unique[num_unique - 1] != next)
        {
          unique[num_unique] = next;
          num_unique++;
        }
      }

      // num_unique == 1 can happen in spite of the non-zero variance
      // due to numerical imprecision of calculating the mean
      // (and therefore - the variance); such a node becomes a leaf:
      if (num_unique > 1)
      {
        was_split = true;
        m_ = unique[num_unique / 2];

        // store the range of point values bounded by this node:
        min_ = points[0][i_];
        max_ = points[num_pts - 1][i_];

        // find the cut; m_ is greater than the smallest value and not
        // greater than the largest one, therefore 1 <= num_a < num_pts:
        unsigned int num_a = 1;
        while (num_a < num_pts && data_t(points[num_a][i_]) < m_)
        {
          num_a++;
        }
        const unsigned int num_b = num_pts - num_a;

        // FIXME:
        assert(num_a < num_pts);
        assert(points[num_a - 1][i_] != points[num_a][i_]);

        if (num_a != 0)
        {
          a_ = new node_t<kd, point_t, data_t>(this, points, num_a);
        }

        if (num_b != 0)
        {
          b_ = new node_t<kd, point_t, data_t>(this, points + num_a, num_b);
        }
      }
    }

    if (!was_split)
    {
      // this is a leaf node:
      m_ = points[0][i_];
      min_ = m_;
      max_ = m_;
      points_ = points;
      num_pts_ = num_pts;
    }
  }

  // assignment operator (deep copy of the branches; the points
  // themselves are shared with the source node, not duplicated):
  node_t<kd, point_t, data_t> &
  operator = (const node_t<kd, point_t, data_t> & node)
  {
    // self-assignment would otherwise delete the very branches
    // that are about to be copied:
    if (this == &node) return *this;

    // clone the source branches before releasing anything, so that
    // the source may safely be a node from within this sub-tree:
    node_t<kd, point_t, data_t> * a = NULL;
    node_t<kd, point_t, data_t> * b = NULL;

    if (node.a_ != NULL)
    {
      a = new node_t<kd, point_t, data_t>(*node.a_);
    }

    if (node.b_ != NULL)
    {
      b = new node_t<kd, point_t, data_t>(*node.b_);
    }

    i_ = node.i_;
    m_ = node.m_;
    min_ = node.min_;
    max_ = node.max_;
    points_ = node.points_;
    num_pts_ = node.num_pts_;
    parent_ = node.parent_;

    delete a_;
    delete b_;
    a_ = a;
    b_ = b;

    // the copied branches must refer back to this node:
    if (a_ != NULL) a_->parent_ = this;
    if (b_ != NULL) b_->parent_ = this;

    return *this;
  }

  //----------------------------------------------------------------
  // unexplored_branch_t
  //
  // A branch that was skipped during a descent, together with the
  // lower bound on its distance from the query point.
  //
  class unexplored_branch_t
  {
  public:
    unexplored_branch_t(const node_t<kd, point_t, data_t> * node,
                        const data_t & distance_to_median):
      node_(node),
      dist_(distance_to_median)
    {}

    // this is used to sort the nodes:
    inline bool operator < (const unexplored_branch_t & b) const
    { return dist_ < b.dist_; }

    // the root node of the unexplored branch:
    const node_t<kd, point_t, data_t> * node_;

    // distance from the query point to the median:
    data_t dist_;
  };

  // descend from this node towards the leaf closest to the query point:
  //
  // unexplored    - receives every branch that was passed over on
  //                 the way down, so the caller can visit it later.
  // query         - the point being looked up.
  // best_distance - branches farther away than this are of no interest.
  //
  // returns the leaf node that was reached, or NULL when both
  // branches of a visited node lie farther than best_distance.
  //
  const node_t<kd, point_t, data_t> *
  leaf(std::list<unexplored_branch_t> & unexplored,
       const point_t & query,
       const data_t & best_distance) const
  {
    const data_t zero = data_t(0);
    const data_t p = data_t(query[i_]);

    // distance along dimension i_ from the query to the range of
    // values covered by each branch: [min_, m_] for branch a and
    // [m_, max_] for branch b. The distance is zero inside the range,
    // otherwise it is the amount by which p overshoots the nearer end:
    const data_t da =
      std::max(zero, std::max(data_t(min_ - p), data_t(p - m_)));
    const data_t db =
      std::max(zero, std::max(data_t(m_ - p), data_t(p - max_)));

    if (da > best_distance && db > best_distance) return NULL;
    if (points_ != NULL) return this;

    if (da < db)
    {
      // will traverse branch a, save branch b for later:
      unexplored.push_back(unexplored_branch_t(b_, db));
      return a_->leaf(unexplored, query, best_distance);
    }

    // will traverse branch b, save branch a for later:
    unexplored.push_back(unexplored_branch_t(a_, da));
    return b_->leaf(unexplored, query, best_distance);
  }

  // calculate the Euclidean distance between the query point and the
  // first point stored in this leaf node (must be called on a leaf):
  data_t
  euclidian_distance(const point_t & query) const
  {
    assert(points_ != NULL);

    const point_t & point = points_[0];
    data_t sum_of_squares = data_t(0);
    for (unsigned int j = 0; j < kd; j++)
    {
      const data_t d = data_t(query[j]) - data_t(point[j]);
      sum_of_squares += d * d;
    }

    return data_t(sqrt(double(sum_of_squares)));
  }

  // FIXME: this is for debugging:
  //
  // write the points of every leaf below this node into the stream,
  // one point per line (point_t must support operator <<):
  void dump(std::ostream & so) const
  {
    if (points_ != NULL)
    {
      for (unsigned int i = 0; i < num_pts_; i++)
      {
        so << points_[i] << std::endl;
      }
    }
    else
    {
      if (a_ != NULL) a_->dump(so);
      if (b_ != NULL) b_->dump(so);
    }
  }

  // dimension index in the point:
  unsigned int i_;

  // median point value along that dimension:
  data_t m_;

  // min/max point values along that dimension:
  data_t min_;
  data_t max_;

  // the payload (set for leaf nodes only, not owned by the node):
  point_t * points_;
  unsigned int num_pts_;

  // the parent node:
  node_t<kd, point_t, data_t> * parent_;

  // branch containing points with value lesser than the median:
  node_t<kd, point_t, data_t> * a_;

  // branch containing points with value greater or equal to the median:
  node_t<kd, point_t, data_t> * b_;
};
378 
379 //----------------------------------------------------------------
380 // operator <<
381 //
382 template <int kd, typename point_t, typename data_t>
383 inline std::ostream &
384 operator << (std::ostream & so, const node_t<kd, point_t, data_t> & node)
385 {
386  node.dump(so);
387  return so;
388 }
389 
390 //----------------------------------------------------------------
391 // tree_t
392 //
393 template <int kd, typename point_t, typename data_t = double>
394 class tree_t
395 {
396 public:
397  //----------------------------------------------------------------
398  // unexplored_branch_t
399  //
402 
403  // default constructor:
404  tree_t():
405  root_(NULL)
406  {}
407 
408  // copy constructor:
409  tree_t(const tree_t<kd, point_t, data_t> & tree):
410  root_(NULL)
411  {
412  *this = tree;
413  }
414 
415  // destructor:
416  ~tree_t()
417  {
418  delete root_;
419  root_ = NULL;
420  }
421 
422  // assignment operator:
424  operator = (const tree_t<kd, point_t, data_t> & tree)
425  {
426  delete root_;
427  root_ = NULL;
428  if (tree.root_ != NULL)
429  {
430  root_ = new node_t<kd, point_t, data_t>(*(tree.root_));
431  }
432 
433  return *this;
434  }
435 
436  // build a kd-tree from a given set of points (duplicates are allowed):
437  void setup(point_t * points, const unsigned int num_pts)
438  {
439  delete root_;
440  root_ = NULL;
441 
442  if (num_pts != 0)
443  {
444  root_ = new node_t<kd, point_t, data_t>(NULL, points, num_pts);
445  }
446  }
447 
448  //----------------------------------------------------------------
449  // nn_t
450  //
451  class nn_t
452  {
453  public:
454  nn_t(const node_t<kd, point_t, data_t> * node, const double & dist):
455  node_(node),
456  dist_(dist)
457  {}
458 
459  inline bool operator < (const nn_t & nn) const
460  { return dist_ < nn.dist_; }
461 
462  // pointer to the destination node:
463  const node_t<kd, point_t, data_t> * node_;
464 
465  // distance to the destination node:
466  double dist_;
467  };
468 
469  // find the node that contains the nearest neighbor(s) of a given point:
470  unsigned int
471  nn(const point_t & query,
472  data_t & best_distance,
473  std::list<nn_t> & nn_sorted,
474  const unsigned int max_traversals = 200,
475  const unsigned int max_nn = 3) const
476  {
477  if (root_ == NULL) return 0;
478 
479  // the results:
480  const node_t<kd, point_t, data_t> * best_match = NULL;
481  best_distance = std::numeric_limits<data_t>::max();
482 
483  // bootstrap the search by starting at the root:
484  std::list<unexplored_branch_t> unexplored;
485  unexplored.push_back(unexplored_branch_t(root_, best_distance));
486 
487  for (unsigned int i = 0; i < max_traversals && !unexplored.empty(); i++)
488  {
489  // retrieve the next search entry point:
490  const node_t<kd, point_t, data_t> * start_here =
491  unexplored.front().node_;
492  unexplored.pop_front();
493 
494  // find a matching point in the tree:
495  const node_t<kd, point_t, data_t> * match =
496  start_here->leaf(unexplored, query, best_distance);
497  if (match == NULL) continue;
498 
499  // find the distance to the matching point:
500  data_t distance = match->euclidian_distance(query);
501 
502  // update the best neighbor estimate:
503  if (distance < best_distance)
504  {
505  best_match = match;
506  best_distance = distance;
507 
508  // NOTE: this does not guarantee that the "correct" match will
509  // be added to the list, because it is entirely possible that
510  // the "correct" match was never found due to being pruned out.
511  nn_sorted.push_front(nn_t(best_match, best_distance));
512 
513  // FIXME: this should be a
514  if (nn_sorted.size() > max_nn)
515  {
516  // remove the worst neighbor:
517  nn_sorted.pop_back();
518  }
519  }
520 
521  // prune some of the branches:
522  typename std::list<unexplored_branch_t>::iterator it =
523  unexplored.begin();
524 
525  while (it != unexplored.end())
526  {
527  const unexplored_branch_t & ub = *it;
528 
529  if (ub.dist_ > best_distance)
530  {
531  it = unexplored.erase(it);
532  }
533  else
534  {
535  ++it;
536  }
537  }
538 
539  // sort the unexplored branches:
540  unexplored.sort();
541  }
542 
543  return nn_sorted.size();
544  }
545 
546  // same as above, except ignoring the full list of close neighbors found:
548  nn(const point_t & query,
549  data_t & best_distance,
550  unsigned int max_traversals = 200) const
551  {
552  std::list<nn_t> nn_sorted;
553  if (nn(query, best_distance, nn_sorted, max_traversals, 1) == 0)
554  {
555  return NULL;
556  }
557 
558  return nn_sorted.front().node_;
559  }
560 
561  // collect nodes within some radius around a given point:
562  unsigned int
563  neighbors(const point_t & query,
564  const data_t & radius,
565  std::list<nn_t> & nn_sorted,
566  const unsigned int max_traversals = 200) const
567  {
568  if (root_ == NULL) return 0;
569 
570  // bootstrap the search by starting at the root:
571  std::list<unexplored_branch_t> unexplored;
572  unexplored.push_back(unexplored_branch_t(root_, 0.0));
573 
574  for (unsigned int i = 0; i < max_traversals && !unexplored.empty(); i++)
575  {
576  // retrieve the next search entry point:
577  const node_t<kd, point_t, data_t> * start_here =
578  unexplored.front().node_;
579  unexplored.pop_front();
580 
581  // find a matching point in the tree:
582  const node_t<kd, point_t, data_t> * match =
583  start_here->leaf(unexplored, query, radius);
584  if (match == NULL) continue;
585 
586  // find the distance to the matching point:
587  data_t distance = match->euclidian_distance(query);
588  if (distance > radius) continue;
589  nn_sorted.push_back(nn_t(match, distance));
590 
591  // sort the unexplored branches:
592  unexplored.sort();
593  }
594 
595  // sort the neighbors from closest to furthest away:
596  nn_sorted.sort();
597 
598  return nn_sorted.size();
599  }
600 
601  // dump the leaf nodes into the stream:
602  void dump(std::ostream & so) const
603  {
604  if (root_ != NULL) root_->dump(so);
605  else so << "NULL";
606  }
607 
609 };
610 
611 //----------------------------------------------------------------
612 // operator <<
613 //
614 template <int kd, typename point_t, typename data_t>
615 inline std::ostream &
616 operator << (std::ostream & so, const tree_t<kd, point_t, data_t> & tree)
617 {
618  tree.dump(so);
619  return so;
620 }
621 
622 
623 #endif // TREE_HXX_
Definition: tree.hxx:394
Definition: tree.hxx:57
Definition: tree.hxx:451
Definition: tree.hxx:50
Definition: tree.hxx:265