ShapeWorks Studio 2.1
Shape analysis software suite
itkPSMProcrustesRegistrationTest.cxx
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "itkPSMProcrustesRegistration.h"
#include "itkImage.h"
#include "itkImageFileReader.h"
#include "itkPSMEntropyModelFilter.h"
#include "itkPSMProjectReader.h"
#include "itkPSMParticleSystem.h"
#include "itkPSMRegionDomain.h"
#include "vnl/vnl_matrix_fixed.h"
#include "vnl/vnl_vector_fixed.h"
#include "itkCommand.h"
32 
33 namespace itk{
34 
35 class MyPSMProcrustesIterationCommand : public itk::Command
36 {
37 public:
40  typedef Command Superclass;
41  typedef SmartPointer< Self > Pointer;
42  typedef SmartPointer< const Self > ConstPointer;
43 
44  typedef Image<float, 3> ImageType;
45 
46  PSMProcrustesRegistration<3> *procrustesRegistration;
49 
51  itkNewMacro(Self);
52 
54  virtual void Execute(Object *caller, const EventObject &)
55  {
57  = static_cast<PSMEntropyModelFilter<ImageType> *>(caller);
58 
59  if (this->procrustesRegistration->GetProcrustesInterval() != 0)
60  {
61  this->m_ProcrustesCounter++;
62 
63  if (this->m_ProcrustesCounter >= (int)this->procrustesRegistration->GetProcrustesInterval())
64  {
65  // Reset the counter
66  this->m_ProcrustesCounter = 0;
67  this->procrustesRegistration->RunRegistration();
68  std::cout << "Run Procrustes Registration" << std::endl;
69  }
70  }
71 
72  // Print every 10 iterations
73  if (o->GetNumberOfElapsedIterations() % 10 != 0) return;
74 
75  std::cout << "Iteration # " << o->GetNumberOfElapsedIterations() << std::endl;
76  std::cout << " Eigenmode variances: ";
77  for (unsigned int i = 0; i < o->GetShapePCAVariances().size(); i++)
78  {
79  std::cout << o->GetShapePCAVariances()[i] << " ";
80  }
81  std::cout << std::endl;
82  std::cout << " Regularization = " << o->GetRegularizationConstant() << std::endl;
83  }
84  virtual void Execute(const Object *, const EventObject &)
85  {
86  std::cout << "SHOULDN'T BE HERE" << std::endl;
87  }
88  void SetPSMProcrustesRegistration(PSMProcrustesRegistration<3> *p)
89  { procrustesRegistration = p; }
90 
91 protected:
93  {
94  m_ProcrustesCounter = 0;
95  }
97 private:
98  int m_ProcrustesCounter;
99  MyPSMProcrustesIterationCommand(const Self &); //purposely not implemented
100  void operator=(const Self &); //purposely not implemented
101 };
102 
103 } // end namespace itk
104 
/**
 * Minimal reader for a binary file holding a count followed by raw objects.
 *
 * File layout (native byte order):
 *   - one `int` N;
 *   - followed by N objects of type T, sizeof(T) bytes each.
 *
 * Objects are filled by raw byte reads, so T is assumed to be trivially
 * copyable (e.g. a fixed-size vnl matrix of doubles).
 *
 * Update()/Read() print a message to std::cerr and `throw 1` when:
 *   - the file cannot be opened;
 *   - the count is missing or negative;
 *   - the file ends before N objects have been read.
 *
 * On failure the previous output is left unchanged. On success it is replaced,
 * never appended to.
 */
template <class T>
class object_reader
{
public:
  typedef object_reader Self;
  typedef T ObjectType;

  /** Objects read by the last successful Update(). */
  const std::vector<ObjectType> &GetOutput() const
  {
    return m_Output;
  }
  std::vector<ObjectType> &GetOutput()
  {
    return m_Output;
  }

  void SetFileName(const char *fn)
  { m_FileName = fn; }
  void SetFileName(const std::string &fn)
  { m_FileName = fn; }
  const std::string& GetFileName() const
  { return m_FileName; }

  /** Synonym for Update(). */
  inline void Read()
  { this->Update(); }

  /** Read the whole file named by GetFileName() into GetOutput(). */
  void Update()
  {
    // Open the input file.
    std::ifstream in( m_FileName.c_str(), std::ios::binary );
    if (!in)
      {
      std::cerr << "Could not open filename " << m_FileName << std::endl;
      throw 1;
      }

    // Read the number of objects. A failed read or a negative count would
    // otherwise drive the loop below with garbage.
    int N = 0;
    if (!in.read(reinterpret_cast<char *>(&N), sizeof(int)) || N < 0)
      {
      std::cerr << "Could not read a valid object count from " << m_FileName << std::endl;
      throw 1;
      }

    // Read into a temporary so a truncated file never leaves partial or
    // garbage objects in m_Output.
    std::vector<ObjectType> objects;
    for (int i = 0; i < N; ++i)
      {
      ObjectType q; // maybe not the most efficient, but safe
      if (!in.read(reinterpret_cast<char *>(&q), sizeof(ObjectType)))
        {
        std::cerr << "Unexpected end of file in " << m_FileName
                  << " after " << i << " of " << N << " objects" << std::endl;
        throw 1;
        }
      objects.push_back(q);
      }

    // Replace (rather than append to) any output from an earlier Update().
    m_Output.swap(objects);
  }

  object_reader() { }
  virtual ~object_reader() { }

private:
  object_reader(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  std::vector<ObjectType> m_Output;
  std::string m_FileName;
};
174 
175 
177 int itkPSMProcrustesRegistrationTest(int argc, char* argv[] )
178 {
179  bool passed = true;
180  std::string errstring = "";
181  std::string output_path = "";
182  std::string input_path_prefix = "";
183 
184  // Check for proper arguments
185  if (argc < 3)
186  {
187  std::cout << "Wrong number of arguments. \nUse: "
188  << "itkPSMProcrustesRegistrationTest parameter_file transforms_file [output_path] [input_path]\n"
189  << "See itk::PSMParameterFileReader for documentation on the parameter file format.\n"
190  <<" Note that input_path will be prefixed to any file names and paths in the xml parameter file.\n"
191  << std::endl;
192  return EXIT_FAILURE;
193  }
194 
195  if (argc >3)
196  {
197  output_path = std::string(argv[3]);
198  }
199 
200  if (argc >4)
201  {
202  input_path_prefix = std::string(argv[4]);
203  }
204 
205  typedef itk::Image<float, 3> ImageType;
206 
207  try
208  {
209  // Read the project parameters
210  itk::PSMProjectReader::Pointer xmlreader =
211  itk::PSMProjectReader::New();
212  xmlreader->SetFileName(argv[1]);
213  xmlreader->Update();
214 
215  itk::PSMProject::Pointer project = xmlreader->GetOutput();
216 
217  // Create the modeling filter and set up the optimization.
218  itk::PSMEntropyModelFilter<ImageType>::Pointer P
220 
221  // Setup the Callback function that is executed after each
222  // iteration of the solver.
223  itk::MyPSMProcrustesIterationCommand::Pointer mycommand
224  = itk::MyPSMProcrustesIterationCommand::New();
225  P->AddObserver(itk::IterationEvent(), mycommand);
226 
227  // Create the ProcrustesRegistration pointer
228  itk::PSMProcrustesRegistration<3>::Pointer procrustesRegistration
230 
231  mycommand->SetPSMProcrustesRegistration( procrustesRegistration );
232  // Load the distance transforms
233  const std::vector<std::string> &dt_files = project->GetDistanceTransforms();
234  itk::ImageFileReader<ImageType>::Pointer reader =
235  itk::ImageFileReader<ImageType>::New();
236 
237  std::cout << "Reading distance transforms ..." << std::endl;
238  for (unsigned int i = 0; i < dt_files.size(); i++)
239  {
240  reader->SetFileName(input_path_prefix + dt_files[i]);
241  reader->Update();
242 
243  std::cout << " " << dt_files[i] << std::endl;
244  }
245  int number_of_inputs = 100;
246  //TODO: Why does number of inputs need to be set to greater than 100?
247  for(unsigned int i = 0; i < 103; i++)
248  {
249  P->SetInput(i,reader->GetOutput());
250  }
251  std::cout << "Done!" << std::endl;
252  std::cout << "Number of inputs: " << P->GetNumberOfInputs() << std::endl;
253 
254  // Load the model initialization. It should be specified as a model with a name.
255  const std::vector<std::string> &pt_files = project->GetModel(std::string("initialization"));
256  std::vector<itk::PSMEntropyModelFilter<ImageType>::PointType> c;
257  std::cout << "Reading the initial model correspondences ..." << std::endl;
258  unsigned int numOfPoints;
259  for (unsigned int i = 0; i < pt_files.size(); i++)
260  {
261  // Read the points for this file and add as a list
262  numOfPoints = 0;
263  // Open the ascii file.
264  std::ifstream in( (input_path_prefix + pt_files[0]).c_str() );
265  if ( !in )
266  {
267  errstring += "Could not open point file for input.";
268  passed = false;
269  break;
270  }
271 
272  // Read all of the points, one point per line.
273  while (in)
274  {
276 
277  for (unsigned int d = 0; d < 3; d++)
278  {
279  in >> pt[d];
280  }
281  c.push_back(pt);
282  numOfPoints++;
283  }
284  // this algorithm pushes the last point twice
285  c.pop_back();
286  std::cout << "Read " << numOfPoints-1 << " points. " << std::endl;
287  in.close();
288  }
289 
290  for(unsigned int i = 0; i < number_of_inputs; i++)
291  {
293  }
294 
295  std::cout << "Done!" << std::endl;
296 
297  // Read the input transforms
299  transform_reader.SetFileName(argv[2]);
300  transform_reader.Update();
301 
302  std::cout << "Reading transforms." << std::endl;
303  // Read transforms and apply to the Particle System
304  for (unsigned int i = 0; i < P->GetParticleSystem()->GetNumberOfDomains(); i++)
305  {
306  for(unsigned int j = 0; j < numOfPoints; j++)
307  {
310  itk::PSMParticleSystem<3>::Pointer PS = P->GetParticleSystem();
311 
312  point[0] = PS->GetPosition(j,i)[0];
313  point[1] = PS->GetPosition(j,i)[1];
314  point[2] = PS->GetPosition(j,i)[2];
315  // Transform the points and set them in the Particle System
316  trPoint = PS->TransformPoint( point, transform_reader.GetOutput()[i] );
317  PS->SetPosition( trPoint, j, i);
318  }
319  }
320 
321  // Read some parameters from the file or provide defaults
322  double regularization_initial = 10.0f;
323  double regularization_final = 2.0f;
324  double regularization_decayspan = 5000.0f;
325  double tolerance = 0.01;
326  unsigned int maximum_iterations = 1000;
327  unsigned int procrustes_interval = 1;
328  if ( project->HasOptimizationAttribute("regularization_initial") )
329  {
330  regularization_initial = project->GetOptimizationAttribute("regularization_initial");
331  }
332  if ( project->HasOptimizationAttribute("regularization_final") )
333  {
334  regularization_final = project->GetOptimizationAttribute("regularization_final");
335  }
336  if ( project->HasOptimizationAttribute("regularization_decayspan") )
337  {
338  regularization_decayspan = project->GetOptimizationAttribute("regularization_decayspan");
339  }
340  if ( project->HasOptimizationAttribute("tolerance") )
341  {
342  tolerance = project->GetOptimizationAttribute("tolerance");
343  }
344  if ( project->HasOptimizationAttribute("maximum_iterations") )
345  {
346  maximum_iterations
347  = static_cast<unsigned int>(project->GetOptimizationAttribute("maximum_iterations"));
348  }
349  if ( project->HasOptimizationAttribute("procrustes_interval") )
350  {
351  procrustes_interval
352  = static_cast<unsigned int>(project->GetOptimizationAttribute("procrustes_interval"));
353  }
354 
355  // Set variables for PSMProcrustesRegistration
356  procrustesRegistration->SetProcrustesInterval(procrustes_interval);
357  procrustesRegistration->SetPSMParticleSystem(P->GetParticleSystem());
358 
359  std::cout << "Optimization parameters: " << std::endl;
360  std::cout << " regularization_initial = " << regularization_initial << std::endl;
361  std::cout << " regularization_final = " << regularization_final << std::endl;
362  std::cout << " regularization_decayspan = " << regularization_decayspan << std::endl;
363  std::cout << " tolerance = " << tolerance << std::endl;
364  std::cout << " maximum_iterations = " << maximum_iterations << std::endl;
365  std::cout << " procrustes_interval = " << procrustes_interval << std::endl;
366 
367  // Set the parameters and run the optimization.
368  P->SetMaximumNumberOfIterations(maximum_iterations);
369  P->SetRegularizationInitial(regularization_initial);
370  P->SetRegularizationFinal(regularization_final);
371  P->SetRegularizationDecaySpan(regularization_decayspan);
372  P->SetTolerance(tolerance);
373  P->Update();
374  // TODO: Should this be a comparison of tolerance instead of iterations?
375  if (P->GetNumberOfElapsedIterations() >= maximum_iterations)
376  {
377  errstring += "Optimization did not converge based on tolerance criteria.\n";
378  passed = false;
379  }
380 
381  // Write out the transforms
382  std::string output_transform_file = "output_transforms_PSMProcrustesRegistrationTest.txt";
383  std::string out_file = output_path + output_transform_file;
384  std::ofstream out(out_file.c_str());
385  for (unsigned int d = 0; d < P->GetParticleSystem()->GetNumberOfDomains(); d++)
386  {
387  if(!out)
388  {
389  errstring += "Could not open file for output: ";
390  }
391  else
392  {
393  out << P->GetParticleSystem()->GetTransform(d);
394  out << std::endl;
395  }
396  }
397 
398  // Print out points for domain d
399  // Load the model initialization. It should be specified as a model with a name.
400  const std::vector<std::string> &opt_files = project->GetModel(std::string("optimized"));
401 
402  for (unsigned int d = 0; d < P->GetParticleSystem()->GetNumberOfDomains(); d++)
403  {
404  // Open the output file and append the number
405  std::ostringstream ss;
406  ss << d;
407  std::string fname = output_path + opt_files[0] + "_" + ss.str() + ".lpts";
408  std::ofstream out_file( fname.c_str() );
409  if ( !out_file )
410  {
411  errstring += "Could not open point file for output: ";
412  }
413  else
414  {
415  for (unsigned int j = 0; j < P->GetParticleSystem()->GetNumberOfParticles(d); j++)
416  {
417  for (unsigned int i = 0; i < 3; i++)
418  {
419  out_file << P->GetParticleSystem()->GetPosition(j,d)[i] << " ";
420  }
421  out_file << std::endl;
422  }
423  }
424  }
425  }
426  catch(itk::ExceptionObject &e)
427  {
428  errstring = "ITK exception with description: " + std::string(e.GetDescription())
429  + std::string("\n at location:") + std::string(e.GetLocation())
430  + std::string("\n in file:") + std::string(e.GetFile());
431  passed = false;
432  }
433  catch(...)
434  {
435  errstring = "Unknown exception thrown";
436  passed = false;
437  }
438 
439  if (passed)
440  {
441  std::cout << "All tests passed" << std::endl;
442  return EXIT_SUCCESS;
443  }
444  else
445  {
446  std::cout << "Test failed with the following error:" << std::endl;
447  std::cout << errstring << std::endl;
448  return EXIT_FAILURE;
449  }
450 }
void SetMaximumNumberOfIterations(const std::vector< unsigned int > &n)
unsigned int GetNumberOfElapsedIterations() const
PointType & GetPosition(unsigned long int k, unsigned int d=0)
void SetInputCorrespondencePoints(unsigned int index, const std::vector< PointType > &corr)
const std::vector< double > & GetShapePCAVariances() const
void SetRegularizationInitial(const std::vector< double > &v)
const std::vector< ObjectType > & GetOutput() const
void SetInput(const std::string &s, itk::DataObject *o)
itkTypeMacro(MyPSMProcrustesIterationCommand, Command)
PointType TransformPoint(const PointType &, const TransformType &) const
void SetPSMParticleSystem(PSMParticleSystemType *p)
virtual void Execute(Object *caller, const EventObject &)
void SetTolerance(const std::vector< double > &v)
vnl_matrix_fixed< double, VDimension+1, VDimension+1 > TransformType