Skip to content

Studio/DeepSSM/DeepSSMParameters.h

Namespaces

Name
shapeworks
User usage reporting (telemetry)

Classes

Name
class shapeworks::DeepSSMParameters

Source code

#pragma once

#include <Project/Project.h>

namespace shapeworks {

/**
 * @brief Typed get/set interface over the DeepSSM settings of a Studio project.
 *
 * An instance keeps a Parameters object together with the ProjectHandle it was
 * constructed from. Every option is exposed as a get_/set_ pair, and
 * save_to_project() is the explicit write-back entry point.
 * NOTE(review): whether setters persist immediately or only on
 * save_to_project() is decided in the implementation file — confirm there.
 */
class DeepSSMParameters {
 public:
  /// @brief Create a parameter view bound to @p project.
  /// NOTE(review): presumably loads the DeepSSM parameter set from the project
  /// here — confirm in the implementation.
  explicit DeepSSMParameters(ProjectHandle project);

  /// @brief Write the current parameter values back to the bound project.
  void save_to_project();

  // ----- Data augmentation ---------------------------------------------------

  int get_aug_num_samples();
  void set_aug_num_samples(int count);

  int get_aug_num_dims();
  void set_aug_num_dims(int dims);

  double get_aug_percent_variability();
  void set_aug_percent_variability(double percent);

  /// Sampler identifier; presumably one of the DEEPSSM_SAMPLER_*_C constants
  /// declared below — TODO confirm against the implementation.
  std::string get_aug_sampler_type();
  void set_aug_sampler_type(std::string type);

  // ----- Network training ----------------------------------------------------

  int get_training_num_dims();
  void set_training_num_dims(int dims);

  int get_training_epochs();
  void set_training_epochs(int count);

  double get_training_learning_rate();
  void set_training_learning_rate(double lr);

  bool get_training_decay_learning_rate();
  void set_training_decay_learning_rate(bool enabled);

  bool get_training_fine_tuning();
  void set_training_fine_tuning(bool enabled);

  int get_training_fine_tuning_epochs();
  void set_training_fine_tuning_epochs(int count);

  double get_training_fine_tuning_learning_rate();
  void set_training_fine_tuning_learning_rate(double lr);

  int get_training_batch_size();
  void set_training_batch_size(int size);

  std::string get_loss_function();
  void set_loss_function(std::string name);

  // ----- Train / validation / test split -------------------------------------

  double get_training_split();
  void set_training_split(double fraction);

  double get_validation_split();
  void set_validation_split(double fraction);

  double get_testing_split();
  void set_testing_split(double fraction);

  // ----- Pipeline progress and status ----------------------------------------

  bool get_prep_step_complete();
  void set_prep_step_complete(bool done);

  int get_prep_stage();
  void set_prep_stage(int stage_index);

  bool get_aug_step_complete();
  void set_aug_step_complete(bool done);

  std::string get_aug_message();
  void set_aug_message(std::string text);

  bool get_training_step_complete();
  void set_training_step_complete(bool done);

  std::string get_training_message();
  void set_training_message(std::string text);

  std::vector<double> get_spacing();
  void set_spacing(std::vector<double> values);

  // ----- TL network options (tl_net_*) ---------------------------------------

  bool get_tl_net_enabled();
  void set_tl_net_enabled(bool on);

  int get_tl_net_ae_epochs();
  void set_tl_net_ae_epochs(int count);

  int get_tl_net_tf_epochs();
  void set_tl_net_tf_epochs(int count);

  int get_tl_net_joint_epochs();
  void set_tl_net_joint_epochs(int count);

  double get_tl_net_alpha();
  void set_tl_net_alpha(double value);

  double get_tl_net_a_ae();
  void set_tl_net_a_ae(double value);

  double get_tl_net_c_ae();
  void set_tl_net_c_ae(double value);

  double get_tl_net_a_lat();
  void set_tl_net_a_lat(double value);

  double get_tl_net_c_lat();
  void set_tl_net_c_lat(double value);

  // ----- Defaults ------------------------------------------------------------

  void restore_split_defaults();
  void restore_augmentation_defaults();
  void restore_training_defaults();

  /// NOTE(review): presumably resets every group above — confirm in the .cpp.
  void restore_defaults();

  // ----- Sampler type identifiers (values are defined out of line) ----------

  static const std::string DEEPSSM_SAMPLER_GAUSSIAN_C;
  static const std::string DEEPSSM_SAMPLER_MIXTURE_C;
  static const std::string DEEPSSM_SAMPLER_KDE_C;

 private:
  enum class SamplerTypeOption { gaussian, gaussian_mixture, kde };

  // Declaration order is kept as-is: params_ is initialized before project_.
  Parameters params_;
  ProjectHandle project_;
};

}  // namespace shapeworks

Updated on 2024-03-17 at 12:58:44 -0600