Shapeworks Studio  2.1
Shape analysis software suite
itkPSMProcrustesRegistration2DTest.cxx
1 /*=========================================================================
2  *
3  * Copyright Insight Software Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 #include <iostream>
19 #include <vector>
20 #include <string>
21 #include <sstream>
22 #include "itkPSMProcrustesRegistration.h"
23 #include "itkImage.h"
24 #include "itkImageFileReader.h"
25 #include "itkPSMEntropyModelFilter.h"
26 #include "itkPSMProjectReader.h"
27 #include "itkPSMParticleSystem.h"
28 #include "itkPSMRegionDomain.h"
29 #include "vnl/vnl_matrix_fixed.h"
30 #include "vnl/vnl_vector_fixed.h"
31 #include "itkCommand.h"
32 
33 namespace itk{
34 
35 class MyPSMProcrustesIterationCommand : public itk::Command
36 {
37 public:
40  typedef Command Superclass;
41  typedef SmartPointer< Self > Pointer;
42  typedef SmartPointer< const Self > ConstPointer;
43 
44  typedef Image<float, 2> ImageType;
45 
46  PSMProcrustesRegistration<2> *procrustesRegistration;
49 
51  itkNewMacro(Self);
52 
54  virtual void Execute(Object *caller, const EventObject &)
55  {
57  = static_cast<PSMEntropyModelFilter<ImageType> *>(caller);
58 
59  if (this->procrustesRegistration->GetProcrustesInterval() != 0)
60  {
61  this->m_ProcrustesCounter++;
62 
63  if (this->m_ProcrustesCounter >= (int)this->procrustesRegistration->GetProcrustesInterval())
64  {
65  // Reset the counter
66  this->m_ProcrustesCounter = 0;
67  this->procrustesRegistration->RunRegistration();
68  std::cout << "Run Procrustes Registration" << std::endl;
69  }
70  }
71 
72  // Print every 10 iterations
73  if (o->GetNumberOfElapsedIterations() % 10 != 0) return;
74 
75  std::cout << "Iteration # " << o->GetNumberOfElapsedIterations() << std::endl;
76  std::cout << " Eigenmode variances: ";
77  for (unsigned int i = 0; i < o->GetShapePCAVariances().size(); i++)
78  {
79  std::cout << o->GetShapePCAVariances()[i] << " ";
80  }
81  std::cout << std::endl;
82  std::cout << " Regularization = " << o->GetRegularizationConstant() << std::endl;
83  }
84  virtual void Execute(const Object *, const EventObject &)
85  {
86  std::cout << "SHOULDN'T BE HERE" << std::endl;
87  }
88  void SetPSMProcrustesRegistration(PSMProcrustesRegistration<2> *p)
89  { procrustesRegistration = p; }
90 
91 protected:
93  {
94  m_ProcrustesCounter = 0;
95  }
97 private:
98  int m_ProcrustesCounter;
99  MyPSMProcrustesIterationCommand(const Self &); //purposely not implemented
100  void operator=(const Self &); //purposely not implemented
101 };
102 
103 } // end namespace itk
104 
/** \class object_reader
 *
 * Minimal reader for a binary file that stores a count followed by raw
 * objects.
 *
 * File layout:
 * - one native-endian int N;
 * - N records of sizeof(T) bytes each, copied bit-for-bit into T.
 *
 * T must therefore be trivially copyable, and the file must have been
 * written with the same int size and byte order as the reading machine.
 */
template <class T>
class object_reader
{
public:
  /** Standard class typedefs. */
  typedef object_reader Self;
  typedef T ObjectType;

  /** Objects loaded by the most recent successful Update(). */
  const std::vector<ObjectType> &GetOutput() const
  {
    return m_Output;
  }
  std::vector<ObjectType> &GetOutput()
  {
    return m_Output;
  }

  /** Name of the binary file to read. */
  void SetFileName(const char *fn)
  { m_FileName = fn; }
  void SetFileName(const std::string &fn)
  { m_FileName = fn; }
  const std::string& GetFileName() const
  { return m_FileName; }

  /** Synonym for Update(). */
  inline void Read()
  { this->Update(); }

  /** Read the file, replacing (not appending to) the current output.
   *
   * Throws the int 1 in any of these cases:
   * - the file cannot be opened;
   * - the count is missing or negative;
   * - the file ends before N objects have been read.
   *
   * On failure the previous output is left untouched.
   */
  void Update()
  {
    // Open the input file.
    std::ifstream in( m_FileName.c_str(), std::ios::binary );
    if (!in)
      {
      std::cerr << "Could not open filename " << m_FileName << std::endl;
      throw 1;
      }

    // Read the number of objects; a short read would leave N meaningless.
    int N = 0;
    if (!in.read(reinterpret_cast<char *>(&N), sizeof(int)) || N < 0)
      {
      std::cerr << "Could not read a valid object count from " << m_FileName << std::endl;
      throw 1;
      }

    // Read into a temporary so a truncated file does not corrupt m_Output.
    std::vector<ObjectType> objects;
    for (int i = 0; i < N; i++)
      {
      ObjectType q; // maybe not the most efficient, but safe
      if (!in.read(reinterpret_cast<char *>(&q), sizeof(ObjectType)))
        {
        std::cerr << "File " << m_FileName << " ended after " << i
                  << " of " << N << " objects" << std::endl;
        throw 1;
        }
      objects.push_back(q);
      }
    m_Output.swap(objects);
  }

  object_reader() { }
  virtual ~object_reader() {}

private:
  object_reader(const Self&); //purposely not implemented
  void operator=(const Self&); //purposely not implemented

  std::vector<ObjectType> m_Output;
  std::string m_FileName;
};
174 
175 
177 int itkPSMProcrustesRegistration2DTest(int argc, char* argv[] )
178 {
179  bool passed = true;
180  std::string errstring = "";
181  std::string output_path = "";
182  std::string input_path_prefix = "";
183 
184  // Check for proper arguments
185  if (argc < 3)
186  {
187  std::cout << "Wrong number of arguments. \nUse: "
188  << "itkPSMProcrustesRegistration2DTest parameter_file transforms_file [output_path] [input_path]\n"
189  << "See itk::PSMParameterFileReader for documentation on the parameter file format.\n"
190  <<" Note that input_path will be prefixed to any file names and paths in the xml parameter file.\n"
191  << std::endl;
192  return EXIT_FAILURE;
193  }
194 
195  if (argc >3)
196  {
197  output_path = std::string(argv[3]);
198  }
199 
200  if (argc >4)
201  {
202  input_path_prefix = std::string(argv[4]);
203  }
204 
205  typedef itk::Image<float, 2> ImageType;
206 
207  try
208  {
209  // Read the project parameters
210  itk::PSMProjectReader::Pointer xmlreader =
211  itk::PSMProjectReader::New();
212  xmlreader->SetFileName(argv[1]);
213  xmlreader->Update();
214 
215  itk::PSMProject::Pointer project = xmlreader->GetOutput();
216 
217  // Create the modeling filter and set up the optimization.
218  itk::PSMEntropyModelFilter<ImageType>::Pointer P
220 
221  // Setup the Callback function that is executed after each
222  // iteration of the solver.
223  itk::MyPSMProcrustesIterationCommand::Pointer mycommand
224  = itk::MyPSMProcrustesIterationCommand::New();
225  P->AddObserver(itk::IterationEvent(), mycommand);
226 
227  // Create the ProcrustesRegistration pointer
228  itk::PSMProcrustesRegistration<2>::Pointer procrustesRegistration
230 
231  mycommand->SetPSMProcrustesRegistration( procrustesRegistration );
232  // Load the distance transforms
233  const std::vector<std::string> &dt_files = project->GetDistanceTransforms();
234  itk::ImageFileReader<ImageType>::Pointer reader =
235  itk::ImageFileReader<ImageType>::New();
236 
237  std::cout << "Reading distance transforms ..." << std::endl;
238  for (unsigned int i = 0; i < dt_files.size(); i++)
239  {
240  reader->SetFileName(input_path_prefix + dt_files[i]);
241  reader->Update();
242 
243  std::cout << " " << dt_files[i] << std::endl;
244  }
245  //TODO: Why does number of inputs need to be set to greater than 100?
246  for(unsigned int i = 0; i < 103; i++)
247  {
248  P->SetInput(i,reader->GetOutput());
249  }
250  std::cout << "Done!" << std::endl;
251  std::cout << "Number of inputs: " << P->GetNumberOfInputs() << std::endl;
252 
253  // Load the model initialization. It should be specified as a model with a name.
254  const std::vector<std::string> &pt_files = project->GetModel(std::string("initialization"));
255  std::vector<itk::PSMEntropyModelFilter<ImageType>::PointType> c;
256  std::cout << "Reading the initial model correspondences ..." << std::endl;
257  unsigned int numOfPoints;
258  for (unsigned int i = 0; i < pt_files.size(); i++)
259  {
260  // Read the points for this file and add as a list
261  numOfPoints = 0;
262  // Open the ascii file.
263  std::ifstream in( (input_path_prefix + pt_files[0]).c_str() );
264  if ( !in )
265  {
266  errstring += "Could not open point file for input.";
267  passed = false;
268  break;
269  }
270 
271  // Read all of the points, one point per line.
272  while (in)
273  {
275 
276  for (unsigned int d = 0; d < 2; d++)
277  {
278  in >> pt[d];
279  }
280  c.push_back(pt);
281  numOfPoints++;
282  }
283  // this algorithm pushes the last point twice
284  c.pop_back();
285  std::cout << "Read " << numOfPoints-1 << " points. " << std::endl;
286  in.close();
287  }
288 
289  for(unsigned int i = 0; i < 100; i++)
290  {
292  }
293 
294  std::cout << "Done!" << std::endl;
295 
296  // Read the input transforms
298  transform_reader.SetFileName(argv[2]);
299  transform_reader.Update();
300 
301  std::cout << "Reading transforms." << std::endl;
302  // Read transforms and apply to the Particle System
303  for (unsigned int i = 0; i < P->GetParticleSystem()->GetNumberOfDomains(); i++)
304  {
305  for(unsigned int j = 0; j < numOfPoints; j++)
306  {
309  itk::PSMParticleSystem<2>::Pointer PS = P->GetParticleSystem();
310 
311  point[0] = PS->GetPosition(j,i)[0];
312  point[1] = PS->GetPosition(j,i)[1];
313 
314  // Transform the points and set them in the Particle System
315  trPoint = PS->TransformPoint( point, transform_reader.GetOutput()[i] );
316  PS->SetPosition( trPoint, j, i);
317  }
318  }
319 
320  // Read some parameters from the file or provide defaults
321  double regularization_initial = 100.0f;
322  double regularization_final = 5.0f;
323  double regularization_decayspan = 2000.0f;
324  double tolerance = 1.0e-8;
325  unsigned int maximum_iterations = 200000;
326  unsigned int procrustes_interval = 1;
327  if ( project->HasOptimizationAttribute("regularization_initial") )
328  {
329  regularization_initial = project->GetOptimizationAttribute("regularization_initial");
330  }
331  if ( project->HasOptimizationAttribute("regularization_final") )
332  {
333  regularization_final = project->GetOptimizationAttribute("regularization_final");
334  }
335  if ( project->HasOptimizationAttribute("regularization_decayspan") )
336  {
337  regularization_decayspan = project->GetOptimizationAttribute("regularization_decayspan");
338  }
339  if ( project->HasOptimizationAttribute("tolerance") )
340  {
341  tolerance = project->GetOptimizationAttribute("tolerance");
342  }
343  if ( project->HasOptimizationAttribute("maximum_iterations") )
344  {
345  maximum_iterations
346  = static_cast<unsigned int>(project->GetOptimizationAttribute("maximum_iterations"));
347  }
348  if ( project->HasOptimizationAttribute("procrustes_interval") )
349  {
350  procrustes_interval
351  = static_cast<unsigned int>(project->GetOptimizationAttribute("procrustes_interval"));
352  }
353 
354  // Set variables for PSMProcrustesRegistration
355  procrustesRegistration->SetProcrustesInterval(procrustes_interval);
356  procrustesRegistration->SetPSMParticleSystem(P->GetParticleSystem());
357 
358  std::cout << "Optimization parameters: " << std::endl;
359  std::cout << " regularization_initial = " << regularization_initial << std::endl;
360  std::cout << " regularization_final = " << regularization_final << std::endl;
361  std::cout << " regularization_decayspan = " << regularization_decayspan << std::endl;
362  std::cout << " tolerance = " << tolerance << std::endl;
363  std::cout << " maximum_iterations = " << maximum_iterations << std::endl;
364  std::cout << " procrustes_interval = " << procrustes_interval << std::endl;
365 
366  // Set the parameters and run the optimization.
367  P->SetMaximumNumberOfIterations(maximum_iterations);
368  P->SetRegularizationInitial(regularization_initial);
369  P->SetRegularizationFinal(regularization_final);
370  P->SetRegularizationDecaySpan(regularization_decayspan);
371  P->SetTolerance(tolerance);
372  P->Update();
373 
374  if (P->GetNumberOfElapsedIterations() >= maximum_iterations)
375  {
376  errstring += "Optimization did not converge based on tolerance criteria.\n";
377  passed = false;
378  }
379 
380  // Write out the transforms
381  std::string output_transform_file = "output_transforms_PSMProcrustesRegistration2DTest.txt";
382  std::string out_file = output_path + output_transform_file;
383  std::ofstream out(out_file.c_str());
384  for (unsigned int d = 0; d < P->GetParticleSystem()->GetNumberOfDomains(); d++)
385  {
386  if(!out)
387  {
388  errstring += "Could not open file for output: ";
389  }
390  else
391  {
392  out << P->GetParticleSystem()->GetTransform(d);
393  out << std::endl;
394  }
395  }
396 
397  // Print out points for domain d
398  // Load the model initialization. It should be specified as a model with a name.
399  const std::vector<std::string> &out_files = project->GetModel(std::string("optimized"));
400 
401  for (unsigned int d = 0; d < P->GetParticleSystem()->GetNumberOfDomains(); d++)
402  {
403  // Open the output file and append the number
404  std::ostringstream ss;
405  ss << d;
406  std::string fname = output_path + out_files[0] + "_" + ss.str() + ".lpts";
407  std::ofstream out( fname.c_str() );
408  if ( !out )
409  {
410  errstring += "Could not open point file for output: ";
411  }
412  else
413  {
414  for (unsigned int j = 0; j < P->GetParticleSystem()->GetNumberOfParticles(d); j++)
415  {
416  //for (unsigned int i = 0; i < 2; i++)
417  // {
418  // Print the last point as 0.0 so that SWViewer can read the point files
419  out << P->GetParticleSystem()->GetPosition(j,d)[0] << " " << P->GetParticleSystem()->GetPosition(j,d)[1] << " " << 0.0;
420  // }
421  out << std::endl;
422  }
423  }
424  }
425  }
426  catch(itk::ExceptionObject &e)
427  {
428  errstring = "ITK exception with description: " + std::string(e.GetDescription())
429  + std::string("\n at location:") + std::string(e.GetLocation())
430  + std::string("\n in file:") + std::string(e.GetFile());
431  passed = false;
432  }
433  catch(...)
434  {
435  errstring = "Unknown exception thrown";
436  passed = false;
437  }
438 
439  if (passed)
440  {
441  std::cout << "All tests passed" << std::endl;
442  return EXIT_SUCCESS;
443  }
444  else
445  {
446  std::cout << "Test failed with the following error:" << std::endl;
447  std::cout << errstring << std::endl;
448  return EXIT_FAILURE;
449  }
450 }
void SetMaximumNumberOfIterations(const std::vector< unsigned int > &n)
unsigned int GetNumberOfElapsedIterations() const
PointType & GetPosition(unsigned long int k, unsigned int d=0)
void SetInputCorrespondencePoints(unsigned int index, const std::vector< PointType > &corr)
const std::vector< double > & GetShapePCAVariances() const
void SetRegularizationInitial(const std::vector< double > &v)
const std::vector< ObjectType > & GetOutput() const
void SetInput(const std::string &s, itk::DataObject *o)
itkTypeMacro(MyPSMProcrustesIterationCommand, Command)
PointType TransformPoint(const PointType &, const TransformType &) const
void SetPSMParticleSystem(PSMParticleSystemType *p)
virtual void Execute(Object *caller, const EventObject &)
void SetTolerance(const std::vector< double > &v)
vnl_matrix_fixed< double, VDimension+1, VDimension+1 > TransformType