Shapeworks Studio  2.1
Shape analysis software suite
itkPSMProcrustesFunction.cxx
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
#include "itkPSMProcrustesFunction.h"
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <vnl/algo/vnl_svd.h>
23 
24 namespace itk
25 {
26 template<unsigned int VDimension>
28 ::RunGeneralizedProcrustes(SimilarityTransformListType & transforms,
29  ShapeListType & shapes)
30 {
31  ShapeListIteratorType leaveOutIt;
32  SimilarityTransformListIteratorType transformIt;
33  ShapeIteratorType shapeIt, meanIt;
34  ShapeType shape, mean;
35  SimilarityTransform3D transform;
36  PointType ssqShape, ssqMean;
37 
38  const RealType SOS_EPSILON = 0.1; //1.0e-8;
39 
40  int numOfShapes = shapes.size();
41  std::string errstring = "";
42  // Initialize transform structure
43  transform.rotation.set_identity();
44  transform.scale = 1.0;
45  transform.translation.fill(0.0);
46 
47  // Initialize transforms vector
48  transforms.clear();
49  transforms.reserve(shapes.size());
50  for(int i = 0; i<numOfShapes; i++)
51  {
52  transforms.push_back(transform);
53  }
54 
55  RealType sumOfSquares = ComputeSumOfSquares(shapes);
56  RealType newSumOfSquares, diff = 1e5; // 1e10;
57 
58  int counter = 0;
59  try
60  {
61  while(diff > SOS_EPSILON)
62  {
63  // Initialize ssqShape vector
64  ssqShape.fill(0.0);
65  // Initialize ssqMean vector
66  ssqMean.fill(0.0);
67  transformIt = transforms.begin();
68  //int count = 0;
69  for(leaveOutIt = shapes.begin(); leaveOutIt != shapes.end(); leaveOutIt++)
70  {
71  // Calculate mean of all shapes but one
72  LeaveOneOutMean(mean, shapes, leaveOutIt);
73  (*leaveOutIt) = RunProcrustes((*transformIt), mean, leaveOutIt);
74  transformIt++;
75  } // End shape list iteration
76 
77  // Fix scalings so geometric average = 1
78  RealType scaleAve = 0.0;
79  for(transformIt = transforms.begin(); transformIt != transforms.end(); transformIt++)
80  scaleAve += log((*transformIt).scale);
81 
82  scaleAve = exp(scaleAve / static_cast<RealType>(transforms.size()));
83 
84  SimilarityTransform3D scaleSim;
85  scaleSim.rotation.set_identity();
86  scaleSim.translation.fill(0.0);
87  scaleSim.scale = 1.0 / scaleAve;
88 
89  ShapeListIteratorType shapeListIt = shapes.begin();
90  transformIt = transforms.begin();
91  while(shapeListIt != shapes.end())
92  {
93  TransformShape((*shapeListIt), scaleSim);
94  (*transformIt).scale /= scaleAve;
95 
96  shapeListIt++;
97  transformIt++;
98  }
99  // Calculate the sum of squares of discrepancies between
100  // the shapes
101  newSumOfSquares = ComputeSumOfSquares(shapes);
102  diff = sumOfSquares - newSumOfSquares;
103 
104  sumOfSquares = newSumOfSquares;
105  counter++;
106  std::cout << "DIFF VALUE : " << diff << std::endl;
107  std::cout << "******PROCRUSTES FUNCTION COUNTER******: " << counter << std::endl;
108  if(counter >= 1000)
109  {
110  errstring = "Number of iterations on shapes is too high.";
111  ExceptionObject e( __FILE__, __LINE__ );
112  e.SetDescription( errstring.c_str() );
113  throw e;
114  }
115  } // End while loop
116  } // End try
117  catch(itk::ExceptionObject &e)
118  {
119  errstring = "ITK exception with description: " + std::string(e.GetDescription())
120  + std::string("\n at location:") + std::string(e.GetLocation())
121  + std::string("\n in file:") + std::string(e.GetFile());
122  }
123  catch(...)
124  {
125  errstring = "Unknown exception thrown";
126  }
127 }
128 // Explicitly instantiate the function for 3D and 2D
129 template void PSMProcrustesFunction<3>::RunGeneralizedProcrustes(SimilarityTransformListType & transforms,
130  ShapeListType & shapes);
131 template void PSMProcrustesFunction<2>::RunGeneralizedProcrustes(SimilarityTransformListType & transforms,
132  ShapeListType & shapes);
133 
134 template<unsigned int VDimension>
135 typename PSMProcrustesFunction<VDimension>::ShapeType
137 ::RunProcrustes(SimilarityTransform3D & transform, ShapeType mean,
138  ShapeListIteratorType & leaveOutIt)
139 {
140  ShapeIteratorType shapeIt1, shapeIt2;
141  SimilarityTransform3D newTransform;
142  ShapeType shapeScaled, meanScaled;
143  PointType colMeanShape, colMeanMean, ssqShape, ssqMean;
144  double normMean, normShape;
145  vnl_matrix_fixed<RealType, VDimension, VDimension> shapeMat;
146  shapeMat.fill(0.0);
147 
148  int numPoints = (*leaveOutIt).size();
149 
150  vnl_matrix<double> meanScaledTranspose(VDimension,numPoints);
151 
152  // Initialize variables
153  colMeanShape.fill(0.0);
154  colMeanMean.fill(0.0);
155  normMean = 0;
156  normShape = 0;
157  ssqShape.fill(0.0);
158  ssqMean.fill(0.0);
159 
160  // Centering the shapes at the origin
161  // First calculate mean along columns
162  for(int j = 0; j<VDimension; j++)
163  {
164  for(int i = 0; i<numPoints; i++)
165  {
166  colMeanShape[j] += (*leaveOutIt)[i][j];
167  colMeanMean[j] += mean[i][j];
168  }
169  colMeanShape[j] = colMeanShape[j]/numPoints;
170  colMeanMean[j] = colMeanMean[j]/numPoints;
171  }
172 
173  // Repeat rows to create new vector
174  ShapeIteratorType it1 = (*leaveOutIt).begin();
175  ShapeIteratorType it2 = mean.begin();
176 
177  while(it1 != (*leaveOutIt).end())
178  {
179  shapeScaled.push_back((*it1) - colMeanShape);
180  meanScaled.push_back((*it2) - colMeanMean);
181  it1++;
182  it2++;
183  }
184 
185  // Calculate sum of squared elements of shapeScaled and meanScaled vectors along columns
186  for(int j = 0; j<VDimension; j++)
187  {
188  for(int i = 0; i<numPoints; i++)
189  {
190  ssqShape[j] += (shapeScaled[i][j] * shapeScaled[i][j]);
191  ssqMean[j] += (meanScaled[i][j] * meanScaled[i][j]);
192  }
193  }
194  // TODO: Check if dimensions match?
195  // Check if shapes are the same
196  bool constShape = CheckDegenerateCase(ssqShape, ssqMean, colMeanShape, colMeanMean, numPoints);
197 
198  // Continue iterations
199  if(constShape)
200  {
201  // Calculate scale normalizing value
202  for(int j = 0; j<VDimension; j++)
203  {
204  normShape += ssqShape[j];
205  normMean += ssqMean[j];
206  }
207  normShape = sqrt(normShape);
208  normMean = sqrt(normMean);
209  // Scale shapes to equal (unit) norm
210  ShapeIteratorType shapeScaledIt = shapeScaled.begin();
211  ShapeIteratorType meanScaledIt = meanScaled.begin();
212  while(shapeScaledIt != shapeScaled.end())
213  {
214  (*shapeScaledIt) = (*shapeScaledIt)/normShape;
215  (*meanScaledIt) = (*meanScaledIt)/normMean;
216  shapeScaledIt++;
217  meanScaledIt++;
218  }
219  }
220  // The degenerate cases: both shapes are the same
221  else
222  {
223  ShapeType output;
224  vnl_vector_fixed<RealType, VDimension> vec;
225  vec.fill(0.0);
226  for(int i = 0; i < numPoints; i++)
227  output.push_back(vec);
228  ShapeIteratorType outputIt = output.begin();
229  while(outputIt != output.end())
230  {
231  (*outputIt) = colMeanShape;
232  outputIt++;
233  }
234  transform.scale = 1.0;
235  transform.rotation.set_identity();
236  outputIt = output.begin();
237  transform.translation = (*outputIt);
238  return output;
239  }
240 
241  for(int j = 0; j<VDimension; j++)
242  {
243  for(int i = 0; i<numPoints; i++)
244  {
245  meanScaledTranspose[j][i] = meanScaled[i][j];
246  }
247  }
248 
249  // Build shapeMat = meanScaledTranspose * shapeScaled
250  for(int i = 0; i<VDimension; i++)
251  {
252  for(int j = 0; j<VDimension; j++)
253  {
254  for(int k = 0; k<numPoints; k++)
255  {
256  shapeMat(i, j) += meanScaledTranspose[i][k] * shapeScaled[k][j];
257  }
258  }
259  }
260 
261  // Calculate SVD
262  vnl_svd<RealType> svd(shapeMat);
263 
264  newTransform.rotation = svd.V() * svd.U().transpose();
265  // Cumulatively multiply rotation values
266  transform.rotation = newTransform.rotation * transform.rotation;
267  // TODO: Calculate standardized distance between mean of shapes and registered shape?
268  // Calculate scale: Sum up elements of diagonal matrix
269  double trsqrt = 0;
270  for(int j = 0; j<VDimension; j++)
271  {
272  trsqrt += svd.W()(j);
273  }
274  newTransform.scale = trsqrt * (normMean / normShape);
275 
276  if(newTransform.scale == 0)
277  newTransform.scale = 1.0;
278 
279  // Cumulatively multiply scale values
280  transform.scale *= newTransform.scale;
281 
282  // Calculate translation
283  PointType mult1 = newTransform.scale * colMeanShape;
284 
285  PointType mult2;
286  mult2.fill(0.0);
287 
288  for(int i = 0; i<VDimension; i++)
289  {
290  for(int j = 0; j<VDimension; j++)
291  {
292  mult2[i] += mult1[j] * newTransform.rotation(j,i);
293  }
294  }
295 
296  PointType sub = colMeanMean - mult2;
297  newTransform.translation = sub;
298  // Cumulatively add translation values
299  transform.translation += newTransform.translation;
300  // Transform the shape
301  ShapeType outputShape = TransformShape((*leaveOutIt), newTransform);
302  // Re-initialize variables
303  colMeanShape.fill(0.0);
304  colMeanMean.fill(0.0);
305 
306  shapeScaled.clear();
307  meanScaled.clear();
308 
309  ssqShape.fill(0.0);
310  ssqMean.fill(0.0);
311 
312  normMean = 0;
313  normShape = 0;
314 
315  return outputShape;
316 }
317 
318 template<unsigned int VDimension>
319 typename PSMProcrustesFunction<VDimension>::ShapeType
321 ::TransformShape(ShapeType shape, SimilarityTransform3D & transform)
322 {
323  int numPoints = shape.size();
324  ShapeIteratorType shapeIt;
325  shapeIt = shape.begin();
326 
327  // Multiply by scale
328  while(shapeIt != shape.end())
329  {
330  PointType & point = *shapeIt;
331  (*shapeIt) = transform.scale * point;
332  shapeIt++;
333  }
334 
335  ShapeType transformedShape;
336  vnl_vector_fixed<RealType, VDimension> vec;
337  vec.fill(0.0);
338  for(int i = 0; i < numPoints; i++)
339  transformedShape.push_back(vec);
340 
341  // Multiply by rotation
342  for(int i = 0; i<numPoints; i++)
343  {
344  for(int j = 0; j<VDimension; j++)
345  {
346  for(int k = 0; k<VDimension; k++)
347  {
348  transformedShape[i][j] += shape[i][k] * transform.rotation[k][j];
349  }
350  }
351  }
352 
353  shapeIt = transformedShape.begin();
354 
355  // Add translation
356  while(shapeIt != transformedShape.end())
357  {
358  PointType & point = (*shapeIt);
359  point += transform.translation;
360  shapeIt++;
361  }
362  return transformedShape;
363 }
364 
365 template<unsigned int VDimension>
366 typename PSMProcrustesFunction<VDimension>::RealType
368 ::ComputeSumOfSquares(ShapeListType & shapes)
369 {
370  ShapeListIteratorType shapeIt1, shapeIt2;
371  ShapeIteratorType pointIt1, pointIt2;
372 
373  RealType sum = 0.0;
374 
375  for(shapeIt1 = shapes.begin(); shapeIt1 != shapes.end(); shapeIt1++)
376  {
377  for(shapeIt2 = shapes.begin(); shapeIt2 != shapes.end(); shapeIt2++)
378  {
379  ShapeType & shape1 = (*shapeIt1);
380  ShapeType & shape2 = (*shapeIt2);
381 
382  pointIt1 = shape1.begin();
383  pointIt2 = shape2.begin();
384  while(pointIt1 != shape1.end() && pointIt2 != shape2.end())
385  {
386  sum += ((*pointIt1) - (*pointIt2)).squared_magnitude();
387  pointIt1++;
388  pointIt2++;
389  }
390  }
391  }
392  return sum / static_cast<RealType>(shapes.size() * shapes[0].size());
393 }
394 
395 template<unsigned int VDimension>
397 ::CheckDegenerateCase(PointType ssqShape, PointType ssqMean,
398  PointType colMeanShape, PointType colMeanMean, int rows)
399 {
400  // TODO: Calculate standardized distance between mean of shapes and
401  // registered shape?
402  PointType valueShape, valueMean;
403  for(int i = 0; i<VDimension; i++)
404  {
405  valueShape[i] = 2.22e-16 * rows * colMeanShape[i];
406  valueShape[i] = valueShape[i] * valueShape[i];
407 
408  valueMean[i] = 2.22e-16 * rows * colMeanMean[i];
409  valueMean[i] = valueMean[i] * valueMean[i];
410  }
411 
412  // Check if any element in ssqShape and ssqMean is less than any element in
413  // valueShape and valueMean resp.
414  for(int j = 0; j<VDimension; j++)
415  {
416  if(ssqShape[j] <= valueShape.min_value() && ssqMean[j] <= valueMean.min_value())
417  return false;
418  }
419 
420  return true;
421 }
422 
423 template<unsigned int VDimension>
425 ::LeaveOneOutMean(ShapeType & mean, ShapeListType & shapeList, ShapeListIteratorType & leaveOutIt)
426 {
427  ShapeListIteratorType shapeListIt;
428  ShapeIteratorType shapeIt, meanIt;
429 
430  int i, numPoints = shapeList[0].size();
431 
432  mean.clear();
433  mean.reserve(numPoints);
434  vnl_vector_fixed<RealType, VDimension> vec;
435  vec.fill(0.0);
436  for(i = 0; i < numPoints; i++)
437  {
438  mean.push_back(vec);
439  }
440 
441  for(shapeListIt = shapeList.begin(); shapeListIt != shapeList.end(); shapeListIt++)
442  {
443  if(shapeListIt != leaveOutIt)
444  {
445  ShapeType & shape = (*shapeListIt);
446  shapeIt = shape.begin();
447  meanIt = mean.begin();
448  while(shapeIt != shape.end())
449  {
450  (*meanIt) += (*shapeIt);
451 
452  shapeIt++;
453  meanIt++;
454  }
455  }
456  }
457 
458  for(meanIt = mean.begin(); meanIt != mean.end(); meanIt++)
459  {
460  (*meanIt) /= static_cast<RealType>(shapeList.size() - 1);
461  }
462 }
463 
464 } // end namespace
Generalized Procrustes Analysis registers the input shapes to one another with similarity transforms (rotation, translation and uniform scaling); the shapes are represented ...
RealType ComputeSumOfSquares(ShapeListType &shapes)
void RunGeneralizedProcrustes(SimilarityTransformListType &transform, ShapeListType &shapes)
ShapeType TransformShape(ShapeType shape, SimilarityTransform3D &transform)