Libs/Optimize/Optimize.h
Namespaces
| Name |
|---|
| shapeworks — User usage reporting (telemetry) |
Classes
| | Name |
|---|---|
| class | shapeworks::MatrixContainer |
| class | shapeworks::Optimize |
Source code
#pragma once
#ifdef _WIN32
#pragma warning(disable : 4996)
#endif
// std
#include <array>
#include <chrono>
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
// itk
#include <itkCommand.h>
#include <itkImage.h>
#include <Eigen/Eigen>
// shapeworks
#include <Project/Project.h>
#include "Libs/Optimize/Domain/DomainType.h"
#include "Libs/Optimize/Function/VectorFunction.h"
#include "Libs/Optimize/Utils/OptimizationVisualizer.h"
#include "ProcrustesRegistration.h"
#include "Sampler.h"
namespace shapeworks {
// Forward declarations. Within this header both types are only held through
// std::shared_ptr (see Optimize::SetProject, project_ and m_good_bad).
class Project;
class ParticleGoodBadAssessment;
// Wrapper that holds a single dynamically sized matrix of doubles.
// It is the return type of Optimize::GetParticleSystem().
class MatrixContainer {
public:
// The wrapped matrix; publicly readable and writable.
Eigen::MatrixXd matrix_;
};
/**
 * Configuration holder and entry point (`Run()`) for an optimization that is
 * carried out through a `Sampler`.
 *
 * Only the interface, the inline accessors and the default values of the
 * parameters are visible in this header. Every member function that is merely
 * declared here is implemented elsewhere, so its exact behaviour must be
 * looked up in the implementation file.
 *
 * The class has a virtual destructor and several virtual hooks
 * (`GetLocalPoints`, `GetGlobalPoints`, `SetIterationCallback`,
 * `IterateCallback`, `UpdateExportablePoints`, `GetProcrustesTransforms`), so
 * it can be used as a polymorphic base class.
 */
class Optimize {
 public:
  using ImageType = itk::Image<float, 3>;
  using VectorType = VectorFunction::VectorType;
  using MatrixType = Eigen::MatrixXd;

  Optimize();
  virtual ~Optimize();

  // ---- Top-level control --------------------------------------------------
  // All of these return/accept plain values; bool results presumably signal
  // success -- confirm against the implementation.
  bool Run();
  bool LoadParameterFile(std::string filename);
  bool SetUpOptimize(ProjectHandle project);
  void SetProject(std::shared_ptr<Project> project);

  /// Stores a copy of @p f in iteration_callback_. When (and whether) it is
  /// invoked is decided by the implementation file.
  void SetIterationCallbackFunction(const std::function<void(void)>& f) { iteration_callback_ = f; }

  void AbortOptimization();
  bool GetAborted();

  // ---- Particle positions (overridable) -----------------------------------
  virtual std::vector<std::vector<itk::Point<double>>> GetLocalPoints();
  virtual std::vector<std::vector<itk::Point<double>>> GetGlobalPoints();

  // ---- Parameter setters / getters ----------------------------------------
  // Most of these have an identically named data member in the protected
  // section below (e.g. SetDomainsPerShape <-> m_domains_per_shape); the
  // mapping is presumed from the names and not shown in this header.
  void SetCutPlanes(std::vector<std::array<itk::Point<double>, 3>> cut_planes);
  void SetVerbosity(int verbosity_level);
  void SetDomainsPerShape(int domains_per_shape);
  int GetDomainsPerShape();
  void SetDomainType(shapeworks::DomainType type);
  shapeworks::DomainType GetDomainType();
  void SetNumberOfParticles(std::vector<int> number_of_particles);
  std::vector<int> GetNumberOfParticles();
  void SetTransformFile(std::string filename);
  std::string GetTransformFile();
  void SetPrefixTransformFile(std::string prefix_transform_file);
  std::string GetPrefixTransformFile();
  void SetOutputDir(std::string output_dir);
  void SetOutputTransformFile(std::string output_transform_file);
  void SetOutputIndividualTransformFiles(bool value);
  void SetUseMeshBasedAttributes(bool use_mesh_based_attributes);
  bool GetUseMeshBasedAttributes();
  void SetUseXYZ(std::vector<bool> use_xyz);
  void SetUseNormals(std::vector<bool> use_normals);
  void SetAttributesPerDomain(std::vector<int> attributes_per_domain);
  std::vector<int> GetAttributesPerDomain();
  void SetDistributionDomainID(int distribution_domain_id);
  int GetDistributionDomainID();
  void SetOutputCuttingPlaneFile(std::string output_cutting_plane_file);
  void SetUseCuttingPlanes(bool use_cutting_planes);
  void SetCuttingPlane(unsigned int i, const vnl_vector_fixed<double, 3>& va, const vnl_vector_fixed<double, 3>& vb,
                       const vnl_vector_fixed<double, 3>& vc);

  /// Records the mode locally as a flag (any non-zero value is stored as
  /// `true`) and forwards the unmodified integer to the sampler.
  /// m_sampler is dereferenced unconditionally, so it must be non-null here.
  void SetMeshFFCMode(int mesh_ffc_mode) {
    m_mesh_ffc_mode = (mesh_ffc_mode != 0);
    m_sampler->SetMeshFFCMode(mesh_ffc_mode);
  }

  void SetTimePtsPerSubject(int time_pts_per_subject);
  int GetTimePtsPerSubject();
  void SetOptimizationIterations(int optimization_iterations);
  void SetOptimizationIterationsCompleted(int optimization_iterations_completed);
  void SetIterationsPerSplit(int iterations_per_split);
  void SetUseShapeStatisticsInInit(bool use_shape_statistics_in_init);
  void SetProcrustesInterval(int procrustes_interval);
  void SetProcrustesScaling(bool procrustes_scaling);
  void SetProcrustesRotationTranslation(bool procrustes_rotation_translation);
  void SetRelativeWeighting(double relative_weighting);
  void SetInitialRelativeWeighting(double initial_relative_weighting);
  void SetStartingRegularization(double starting_regularization);
  void SetEndingRegularization(double ending_regularization);
  void SetRecomputeRegularizationInterval(int recompute_regularization_interval);
  void SetSaveInitSplits(bool save_init_splits);
  void SetCheckpointingInterval(int checkpointing_interval);
  void SetKeepCheckpoints(int keep_checkpoints);
  void SetUseRegression(bool use_regression);
  void SetUseMixedEffects(bool use_mixed_effects);
  void SetUseDisentangledSpatiotemporalSSM(bool use_disentangled_ssm_4d);
  bool GetUseDisentangledSpatiotemporalSSM();
  void SetNormalAngle(double normal_angle);
  void SetPerformGoodBad(bool perform_good_bad);
  void SetLogEnergy(bool log_energy);

  // ---- Input data ---------------------------------------------------------
  void AddImage(ImageType::Pointer image, std::string name = "");
  void AddMesh(vtkSmartPointer<vtkPolyData> poly_data);
  void AddContour(vtkSmartPointer<vtkPolyData> poly_data);
  void SetFilenames(const std::vector<std::string>& filenames);
  void SetPointFiles(const std::vector<std::string>& point_files);
  void SetInitialPoints(std::vector<std::vector<itk::Point<double>>> initial_points);
  int GetNumShapes();
  void SetAttributeScales(const std::vector<double>& scales);
  void SetFieldAttributes(const std::vector<std::string>& field_attributes);
  void SetParticleFlags(std::vector<int> flags);
  void SetFixedDomains(std::vector<int> flags);
  void SetSharedBoundaryEnabled(bool enabled);
  void SetSharedBoundaryWeight(double weight);
  const std::vector<int>& GetDomainFlags();
  void SetFileOutputEnabled(bool enabled);
  std::vector<bool> GetUseXYZ();
  std::vector<bool> GetUseNormals();
  void SetNarrowBand(double v);
  double GetNarrowBand();
  void SetUseShapeStatisticsAfter(int num_particles);
  int GetUseShapeStatisticsAfter();
  void PrintParamInfo();

  /// Returns the shared sampler handle (a copy of the shared_ptr, so the
  /// caller shares ownership). Constness is shallow: the Sampler itself stays
  /// mutable.
  std::shared_ptr<Sampler> GetSampler() const { return m_sampler; }

  MatrixContainer GetParticleSystem();
  void SetPythonFile(std::string filename);

  // ---- Geodesics ----------------------------------------------------------
  void SetGeodesicsEnabled(bool is_enabled);
  /// Returns the current value of m_geodesics_enabled.
  bool GetGeodesicsEnabled() const { return m_geodesics_enabled; }
  void SetGeodesicsCacheSizeMultiplier(size_t n);
  void SetGeodesicsRemeshPercent(double percent);

  // ---- Visualizer ---------------------------------------------------------
  OptimizationVisualizer& GetVisualizer();
  void SetShowVisualizer(bool show);
  bool GetShowVisualizer();

  /// Returns the locally stored flag. Note the asymmetry with
  /// SetMeshFFCMode(int): the integer mode is reported back only as
  /// zero / non-zero.
  bool GetMeshFFCMode() const { return m_mesh_ffc_mode; }

  vnl_vector_fixed<double, 3> TransformPoint(int domain, vnl_vector_fixed<double, 3> input);
  void UpdateProgress();

  /// Replaces particle_format_ (default "particles") with @p format.
  /// The by-value argument is moved into the member to avoid a second copy.
  void set_particle_format(std::string format) { particle_format_ = std::move(format); }

 protected:
  // ---- Internal steps (implemented in the .cpp) ---------------------------
  virtual void SetIterationCallback();
  void ComputeTotalIterations();
  void OptimizerStop();
  void ReadTransformFile();
  void ReadPrefixTransformFile(const std::string& s);
  void InitializeSampler();
  double GetMinNeighborhoodRadius();
  void AddSinglePoint();
  void Initialize();
  void RunOptimize();
  virtual void IterateCallback(itk::Object*, const itk::EventObject&);
  void ComputeEnergyAfterIteration();
  void SetCotanSigma();

  // ---- Output writers -----------------------------------------------------
  // Each writer comes in two overloads: one taking an iteration number
  // (defaulting to -1) and one taking a string prefix.
  void WriteTransformFile(int iter = -1) const;
  void WriteTransformFile(std::string iter_prefix) const;
  void WriteTransformFiles(int iter = -1) const;
  void WriteTransformFiles(std::string iter_prefix) const;
  void WritePointFiles(int iter = -1);
  void WritePointFiles(std::string iter_prefix);
  void WritePointFilesWithFeatures(int iter = -1);
  void WritePointFilesWithFeatures(std::string iter_prefix);
  void WriteEnergyFiles();
  void WriteSplitFiles(std::string name);
  void WriteCuttingPlanePoints(int iter = -1);
  void WriteParameters(std::string output_dir = "");
  void ReportBadParticles();
  int SetParameters();
  void WriteModes();
  void PrintStartMessage(std::string str, unsigned int vlevel = 0) const;
  void PrintDoneMessage(unsigned int vlevel = 0) const;
  virtual void UpdateExportablePoints();
  virtual std::vector<std::vector<std::vector<double>>> GetProcrustesTransforms();
  void UpdateProject();

  // return a checkpoint dir for the current iteration
  std::string GetCheckpointDir();

  // ---- Collaborators ------------------------------------------------------
  // NOTE: data members keep their original declaration order; the order
  // determines construction/destruction order and object layout.
  std::shared_ptr<Sampler> m_sampler;
  ProcrustesRegistration::Pointer m_procrustes;
  std::shared_ptr<ParticleGoodBadAssessment> m_good_bad;

  unsigned int m_verbosity_level = 0;
  std::vector<std::vector<itk::Point<double>>> m_local_points, m_global_points;
  int m_checkpoint_counter = 0;
  int m_procrustes_counter = 0;
  bool m_use_cutting_planes = false;
  bool m_optimizing = false;
  bool m_use_regression = false;
  bool m_use_mixed_effects = false;
  bool m_use_disentangled_ssm = false;

  // IO Parameters
  unsigned int m_domains_per_shape = 1;
  shapeworks::DomainType m_domain_type = shapeworks::DomainType::Image;
  std::vector<int> m_number_of_particles;
  std::string m_transform_file;
  std::string m_prefix_transform_file;
  std::string m_output_dir;
  std::string m_output_transform_file;
  bool m_output_transform_files = false;
  bool m_mesh_based_attributes = false;
  std::vector<bool> m_use_xyz;
  std::vector<bool> m_use_normals;
  std::vector<int> m_attributes_per_domain;
  int m_distribution_domain_id = -1;
  std::string m_output_cutting_plane_file;

  // Optimization Parameters
  bool m_mesh_ffc_mode = false;  // zero/non-zero view of the int given to SetMeshFFCMode
  unsigned int m_timepts_per_subject = 1;
  int m_optimization_iterations = 2000;
  int m_optimization_iterations_completed = 0;
  int m_iterations_per_split = 1000;
  bool m_use_shape_statistics_in_init = false;
  unsigned int m_procrustes_interval = 3;
  bool m_procrustes_scaling = true;
  bool m_procrustes_rotation_translation = true;
  double m_relative_weighting = 1.0;
  double m_initial_relative_weighting = 0.05;
  double m_starting_regularization = 1000;
  double m_ending_regularization = 1.0;
  int m_recompute_regularization_interval = 1;
  bool m_save_init_splits = false;
  unsigned int m_checkpointing_interval = 50;
  int m_keep_checkpoints = 0;
  double m_cotan_sigma_factor = 5.0;
  std::vector<int> m_particle_flags;
  std::vector<int> m_domain_flags;
  double m_narrow_band = 0.0;
  bool m_narrow_band_set = false;
  bool m_fixed_domains_present = false;
  int m_use_shape_statistics_after = -1;
  std::string m_python_filename;
  bool m_geodesics_enabled = false;             // geodesics disabled by default
  size_t m_geodesic_cache_size_multiplier = 0;  // 0 => MeshWrapper will use a heuristic to determine cache size
  double m_geodesic_remesh_percent = 100.0;     // 100% by default (e.g. no remeshing)

  // m_spacing is used to scale the random update vector for particle splitting.
  double m_spacing = 0;

  std::vector<std::string> m_filenames;
  int m_num_shapes = 0;

  // Energy bookkeeping
  std::vector<double> m_energy_a;
  std::vector<double> m_energy_b;
  std::vector<double> m_total_energy;
  bool m_log_energy = false;
  std::string m_str_energy;

  // GoodBadAssessment
  std::vector<std::vector<int>> m_bad_ids;
  double m_normal_angle = itk::Math::pi / 2.0;  // default is pi/2
  bool m_perform_good_bad = false;

  std::vector<int> m_cutting_planes_per_input;
  std::vector<int> m_spheres_per_input;
  bool m_file_output_enabled = true;
  bool m_aborted = false;
  std::vector<std::array<itk::Point<double>, 3>> m_cut_planes;

  // Iteration / progress counters
  int m_total_iterations = 0;
  int m_iteration_count = 0;
  int m_split_number = 0;
  int current_particle_iterations_ = 0;
  int total_particle_iterations_ = 0;

  // Callback stored by SetIterationCallbackFunction(); empty by default.
  std::function<void(void)> iteration_callback_;

  bool show_visualizer_ = false;
  shapeworks::OptimizationVisualizer visualizer_;
  std::string particle_format_ = "particles";
  std::shared_ptr<Project> project_;

  // Wall-clock timestamps and the associated status message
  // (presumably used for progress / remaining-time reporting -- confirm in the .cpp).
  std::chrono::system_clock::time_point m_start_time;
  std::chrono::system_clock::time_point m_last_update_time;
  std::chrono::system_clock::time_point m_last_remaining_update_time;
  std::string m_remaining_time_message;
};
} // namespace shapeworks
Updated on 2024-11-11 at 19:51:46 +0000