// RunGeneralizedProcrustes: iteratively aligns every shape in `shapes` to the
// mean of the remaining shapes (leave-one-out), keeping one accumulated
// similarity transform per shape in the output transform list.
// Each outer iteration:
//   1. replaces every shape by RunProcrustes' output, aligned to the
//      leave-one-out mean of the other shapes;
//   2. normalizes the accumulated scales by their geometric mean;
//   3. continues while ComputeSumOfSquares decreased by more than SOS_EPSILON.
// NOTE(review): the full signature, braces, iterator increments, the
// iteration counter and the `try` block are missing from this listing; the
// comments below describe only the visible statements.
19 #include "itkPSMProcrustesFunction.h" 21 #include <vnl/algo/vnl_svd.h> 26 template<
unsigned int VDimension>
29 ShapeListType & shapes)
31 ShapeListIteratorType leaveOutIt;
32 SimilarityTransformListIteratorType transformIt;
33 ShapeIteratorType shapeIt, meanIt;
34 ShapeType shape, mean;
35 SimilarityTransform3D transform;
36 PointType ssqShape, ssqMean;
// Absolute (not relative) convergence threshold on the decrease of
// ComputeSumOfSquares(), so its effect depends on the coordinate scale.
38 const RealType SOS_EPSILON = 0.1;
40 int numOfShapes = shapes.size();
41 std::string errstring =
"";
// Identity similarity transform: the starting transform for every shape.
43 transform.rotation.set_identity();
44 transform.scale = 1.0;
45 transform.translation.fill(0.0);
49 transforms.reserve(shapes.size());
50 for(
int i = 0; i<numOfShapes; i++)
52 transforms.push_back(transform);
55 RealType sumOfSquares = ComputeSumOfSquares(shapes);
// diff starts large so that the loop body runs at least once.
56 RealType newSumOfSquares, diff = 1e5;
// Stops once an iteration lowers the sum of squares by SOS_EPSILON or less
// (including when it increases, which makes diff negative).
61 while(diff > SOS_EPSILON)
67 transformIt = transforms.begin();
// Align each shape to the mean of all other shapes. RunProcrustes modifies
// `transform` in its body, which is presumably this first argument by
// reference; it returns the aligned copy of the shape.
69 for(leaveOutIt = shapes.begin(); leaveOutIt != shapes.end(); leaveOutIt++)
72 LeaveOneOutMean(mean, shapes, leaveOutIt);
73 (*leaveOutIt) = RunProcrustes((*transformIt), mean, leaveOutIt);
// Geometric mean of the accumulated scales: exp(mean(log(scale))).
78 RealType scaleAve = 0.0;
79 for(transformIt = transforms.begin(); transformIt != transforms.end(); transformIt++)
80 scaleAve += log((*transformIt).scale);
82 scaleAve = exp(scaleAve / static_cast<RealType>(transforms.size()));
// Uniform rescale by 1/scaleAve (identity rotation, zero translation), so
// that the accumulated scales end up with a geometric mean of 1.
84 SimilarityTransform3D scaleSim;
85 scaleSim.rotation.set_identity();
86 scaleSim.translation.fill(0.0);
87 scaleSim.scale = 1.0 / scaleAve;
89 ShapeListIteratorType shapeListIt = shapes.begin();
90 transformIt = transforms.begin();
91 while(shapeListIt != shapes.end())
// NOTE(review): TransformShape takes its shape by value and returns the
// transformed copy. Discarding the result here leaves *shapeListIt unchanged,
// while the stored scale on the next line IS divided by scaleAve.
// This looks like it should be
// `(*shapeListIt) = TransformShape((*shapeListIt), scaleSim);` -- confirm and fix.
93 TransformShape((*shapeListIt), scaleSim);
94 (*transformIt).scale /= scaleAve;
// Convergence measure: how much this iteration lowered the sum of squares.
101 newSumOfSquares = ComputeSumOfSquares(shapes);
102 diff = sumOfSquares - newSumOfSquares;
104 sumOfSquares = newSumOfSquares;
106 std::cout <<
"DIFF VALUE : " << diff << std::endl;
// `counter` is declared in lines that are not visible in this listing.
107 std::cout <<
"******PROCRUSTES FUNCTION COUNTER******: " << counter << std::endl;
// Presumably guarded by an iteration-limit check on `counter` and followed by
// a throw; neither the condition nor the throw is visible here.
110 errstring =
"Number of iterations on shapes is too high.";
111 ExceptionObject e( __FILE__, __LINE__ );
112 e.SetDescription( errstring.c_str() );
// ITK exceptions: build a message from the description, location and file.
117 catch(itk::ExceptionObject &e)
119 errstring =
"ITK exception with description: " + std::string(e.GetDescription())
120 + std::string(
"\n at location:") + std::string(e.GetLocation())
121 + std::string(
"\n in file:") + std::string(e.GetFile());
// Any other exception type. How errstring is reported afterwards is not visible.
125 errstring =
"Unknown exception thrown";
// NOTE(review): only the last line of two declarations survived in this
// listing. They share their final parameter with RunGeneralizedProcrustes,
// so they are presumably its explicit template instantiations -- confirm.
130 ShapeListType & shapes);
132 ShapeListType & shapes);
// RunProcrustes: aligns the shape at `leaveOutIt` to `mean` with one
// similarity transform x -> scale * x * R + t (the convention that
// TransformShape applies). It folds that step into the accumulated
// `transform` and computes the aligned copy of the shape.
// NOTE(review): the following fall in lines missing from this listing:
//   - the signature;
//   - the initialization of ssqShape, ssqMean, normShape, normMean, shapeMat,
//     trsqrt and mult2 (confirm they are zeroed before the `+=` loops below);
//   - the handling of `constShape`;
//   - the return statement.
134 template<
unsigned int VDimension>
135 typename PSMProcrustesFunction<VDimension>::ShapeType
138 ShapeListIteratorType & leaveOutIt)
140 ShapeIteratorType shapeIt1, shapeIt2;
141 SimilarityTransform3D newTransform;
142 ShapeType shapeScaled, meanScaled;
143 PointType colMeanShape, colMeanMean, ssqShape, ssqMean;
144 double normMean, normShape;
145 vnl_matrix_fixed<RealType, VDimension, VDimension> shapeMat;
148 int numPoints = (*leaveOutIt).size();
150 vnl_matrix<double> meanScaledTranspose(VDimension,numPoints);
153 colMeanShape.fill(0.0);
154 colMeanMean.fill(0.0);
// Centroids (per-coordinate means) of the shape and of the mean shape.
// Assumes `mean` has at least numPoints points, since it is indexed with the
// shape's point count.
162 for(
int j = 0; j<VDimension; j++)
164 for(
int i = 0; i<numPoints; i++)
166 colMeanShape[j] += (*leaveOutIt)[i][j];
167 colMeanMean[j] += mean[i][j];
169 colMeanShape[j] = colMeanShape[j]/numPoints;
170 colMeanMean[j] = colMeanMean[j]/numPoints;
// Centred copies of both point sets.
174 ShapeIteratorType it1 = (*leaveOutIt).begin();
175 ShapeIteratorType it2 = mean.begin();
177 while(it1 != (*leaveOutIt).end())
179 shapeScaled.push_back((*it1) - colMeanShape);
180 meanScaled.push_back((*it2) - colMeanMean);
// Per-coordinate sum of squares of the centred coordinates.
186 for(
int j = 0; j<VDimension; j++)
188 for(
int i = 0; i<numPoints; i++)
190 ssqShape[j] += (shapeScaled[i][j] * shapeScaled[i][j]);
191 ssqMean[j] += (meanScaled[i][j] * meanScaled[i][j]);
// Degeneracy test against near-machine-epsilon tolerances (see CheckDegenerateCase).
196 bool constShape = CheckDegenerateCase(ssqShape, ssqMean, colMeanShape, colMeanMean, numPoints);
// Frobenius norms of the centred point sets. Each set is then divided by its
// norm so both have unit norm before the rotation is estimated.
// NOTE(review): a zero norm (degenerate shape) makes these divisions produce
// inf/NaN unless the missing lines skip them when constShape is true.
202 for(
int j = 0; j<VDimension; j++)
204 normShape += ssqShape[j];
205 normMean += ssqMean[j];
207 normShape = sqrt(normShape);
208 normMean = sqrt(normMean);
210 ShapeIteratorType shapeScaledIt = shapeScaled.begin();
211 ShapeIteratorType meanScaledIt = meanScaled.begin();
212 while(shapeScaledIt != shapeScaled.end())
214 (*shapeScaledIt) = (*shapeScaledIt)/normShape;
215 (*meanScaledIt) = (*meanScaledIt)/normMean;
// Presumably the degenerate-case result; its guarding condition is not visible.
// Every output point is the shape centroid, and `transform` is reset to an
// identity rotation, unit scale and a translation equal to that centroid.
224 vnl_vector_fixed<RealType, VDimension> vec;
226 for(
int i = 0; i < numPoints; i++)
227 output.push_back(vec);
228 ShapeIteratorType outputIt = output.begin();
229 while(outputIt != output.end())
231 (*outputIt) = colMeanShape;
234 transform.scale = 1.0;
235 transform.rotation.set_identity();
236 outputIt = output.begin();
237 transform.translation = (*outputIt);
// Cross-covariance matrix shapeMat = meanScaled^T * shapeScaled
// (VDimension x VDimension).
241 for(
int j = 0; j<VDimension; j++)
243 for(
int i = 0; i<numPoints; i++)
245 meanScaledTranspose[j][i] = meanScaled[i][j];
250 for(
int i = 0; i<VDimension; i++)
252 for(
int j = 0; j<VDimension; j++)
254 for(
int k = 0; k<numPoints; k++)
256 shapeMat(i, j) += meanScaledTranspose[i][k] * shapeScaled[k][j];
// Rotation from the SVD of the cross-covariance: R = V * U^T, for row-vector
// points (x * R).
// NOTE(review): no determinant check is visible, so R can be a reflection
// when det(V * U^T) = -1.
262 vnl_svd<RealType> svd(shapeMat);
264 newTransform.rotation = svd.V() * svd.U().transpose();
// NOTE(review): TransformShape applies rotations to row vectors (x * R).
// Applying the previous rotation and then this one therefore composes to
// transform.rotation * newTransform.rotation; the product below uses the
// opposite order -- verify which is intended.
266 transform.rotation = newTransform.rotation * transform.rotation;
// Scale: sum of singular values times normMean / normShape.
270 for(
int j = 0; j<VDimension; j++)
272 trsqrt += svd.W()(j);
274 newTransform.scale = trsqrt * (normMean / normShape);
// A zero scale is replaced by 1.0 (exact floating-point comparison).
276 if(newTransform.scale == 0)
277 newTransform.scale = 1.0;
280 transform.scale *= newTransform.scale;
// Translation that maps the scaled, rotated shape centroid onto the mean
// centroid: t = colMeanMean - (scale * colMeanShape) * R.
283 PointType mult1 = newTransform.scale * colMeanShape;
288 for(
int i = 0; i<VDimension; i++)
290 for(
int j = 0; j<VDimension; j++)
292 mult2[i] += mult1[j] * newTransform.rotation(j,i);
296 PointType sub = colMeanMean - mult2;
297 newTransform.translation = sub;
// NOTE(review): under the scale * x * R + t convention, applying this step
// after the accumulated transform gives the translation
// newScale * (oldTranslation * newR) + newTranslation. Plain addition ignores
// the new scale and rotation acting on the old translation -- verify.
299 transform.translation += newTransform.translation;
301 ShapeType outputShape = TransformShape((*leaveOutIt), newTransform);
// Both are locals declared at the top of this function, so this reset does
// not affect later calls.
303 colMeanShape.fill(0.0);
304 colMeanMean.fill(0.0);
// TransformShape: returns a copy of `shape` mapped point-wise by
// x -> transform.scale * x * transform.rotation + transform.translation,
// with points treated as row vectors. `shape` is taken by value, so the
// caller's shape is not modified; callers must use the returned value.
// NOTE(review): the signature and the loop increments are missing from this listing.
318 template<
unsigned int VDimension>
319 typename PSMProcrustesFunction<VDimension>::ShapeType
323 int numPoints = shape.size();
324 ShapeIteratorType shapeIt;
325 shapeIt = shape.begin();
// Step 1: scale the local copy in place.
328 while(shapeIt != shape.end())
330 PointType & point = *shapeIt;
331 (*shapeIt) = transform.scale * point;
// Step 2: rotate. The accumulation below uses `+=`, so every element of the
// pushed `vec` must start at zero.
// NOTE(review): vnl_vector_fixed's default constructor does not initialize its
// elements; confirm that the line not shown between these declarations
// zero-fills `vec`.
335 ShapeType transformedShape;
336 vnl_vector_fixed<RealType, VDimension> vec;
338 for(
int i = 0; i < numPoints; i++)
339 transformedShape.push_back(vec);
342 for(
int i = 0; i<numPoints; i++)
344 for(
int j = 0; j<VDimension; j++)
346 for(
int k = 0; k<VDimension; k++)
348 transformedShape[i][j] += shape[i][k] * transform.rotation[k][j];
// Step 3: translate.
353 shapeIt = transformedShape.begin();
356 while(shapeIt != transformedShape.end())
358 PointType & point = (*shapeIt);
359 point += transform.translation;
362 return transformedShape;
// ComputeSumOfSquares: sums squared distances between corresponding points
// over every ordered pair of shapes. Self-pairs are included and each
// unordered pair is counted twice. The total is divided by
// (number of shapes * number of points in shapes[0]).
// Point-wise comparison of a pair stops at the shorter of the two shapes.
// NOTE(review): `sum` is declared in a line not visible here. `shapes[0]` is
// read unconditionally, so an empty list is undefined behavior.
365 template<
unsigned int VDimension>
366 typename PSMProcrustesFunction<VDimension>::RealType
370 ShapeListIteratorType shapeIt1, shapeIt2;
371 ShapeIteratorType pointIt1, pointIt2;
375 for(shapeIt1 = shapes.begin(); shapeIt1 != shapes.end(); shapeIt1++)
377 for(shapeIt2 = shapes.begin(); shapeIt2 != shapes.end(); shapeIt2++)
379 ShapeType & shape1 = (*shapeIt1);
380 ShapeType & shape2 = (*shapeIt2);
382 pointIt1 = shape1.begin();
383 pointIt2 = shape2.begin();
384 while(pointIt1 != shape1.end() && pointIt2 != shape2.end())
386 sum += ((*pointIt1) - (*pointIt2)).squared_magnitude();
392 return sum /
static_cast<RealType
>(shapes.size() * shapes[0].size());
// CheckDegenerateCase: compares each coordinate's centred sum of squares
// (ssqShape / ssqMean) with a tolerance of (2.22e-16 * rows * centroid)^2.
// Here 2.22e-16 is approximately double-precision machine epsilon.
// RunProcrustes calls it with numPoints as `rows` and stores the bool result
// in `constShape`.
// NOTE(review): the start of the signature, the action taken when the test
// passes and the return statement are not visible in this listing.
395 template<
unsigned int VDimension>
398 PointType colMeanShape, PointType colMeanMean,
int rows)
402 PointType valueShape, valueMean;
403 for(
int i = 0; i<VDimension; i++)
405 valueShape[i] = 2.22e-16 * rows * colMeanShape[i];
406 valueShape[i] = valueShape[i] * valueShape[i];
408 valueMean[i] = 2.22e-16 * rows * colMeanMean[i];
409 valueMean[i] = valueMean[i] * valueMean[i];
// NOTE(review): each coordinate j is compared with the smallest tolerance
// over all coordinates (min_value()), not with its own valueShape[j] /
// valueMean[j]. Confirm this is intended.
414 for(
int j = 0; j<VDimension; j++)
416 if(ssqShape[j] <= valueShape.min_value() && ssqMean[j] <= valueMean.min_value())
// LeaveOneOutMean: fills `mean` with the point-wise average of every shape in
// `shapeList` except the one at `leaveOutIt`. The point count is taken from
// shapeList[0].
// NOTE(review): the lines that clear `mean` and push its initial points are
// not visible. RunGeneralizedProcrustes reuses one `mean` object across calls,
// so confirm that `mean` is reset, and `vec` zero-filled, before accumulation.
// With a single shape the final divisor is zero; an empty list makes
// shapeList[0] out of range.
423 template<
unsigned int VDimension>
425 ::LeaveOneOutMean(ShapeType & mean, ShapeListType & shapeList, ShapeListIteratorType & leaveOutIt)
427 ShapeListIteratorType shapeListIt;
428 ShapeIteratorType shapeIt, meanIt;
430 int i, numPoints = shapeList[0].size();
433 mean.reserve(numPoints);
434 vnl_vector_fixed<RealType, VDimension> vec;
436 for(i = 0; i < numPoints; i++)
// Accumulate every shape except the left-out one into `mean`.
441 for(shapeListIt = shapeList.begin(); shapeListIt != shapeList.end(); shapeListIt++)
443 if(shapeListIt != leaveOutIt)
445 ShapeType & shape = (*shapeListIt);
446 shapeIt = shape.begin();
447 meanIt = mean.begin();
448 while(shapeIt != shape.end())
450 (*meanIt) += (*shapeIt);
// Divide by the number of contributing shapes (all but the left-out one).
458 for(meanIt = mean.begin(); meanIt != mean.end(); meanIt++)
460 (*meanIt) /=
static_cast<RealType
>(shapeList.size() - 1);
Generalized Procrustes Analysis is the similarity registration (rotation, translation and isotropic scaling) of several input shapes represented ...
// Averaged squared point-to-point distance over all ordered pairs of shapes.
RealType ComputeSumOfSquares(ShapeListType &shapes)
// Iteratively aligns all shapes; one accumulated similarity transform per shape is stored in `transform`.
void RunGeneralizedProcrustes(SimilarityTransformListType &transform, ShapeListType &shapes)
// Returns scale * x * rotation + translation for every point; the input shape is passed by value and left unchanged.
ShapeType TransformShape(ShapeType shape, SimilarityTransform3D &transform)