Skip to content

Libs/Optimize/Function/EarlyStop/EarlyStopping.h

Namespaces

Name
shapeworks
User usage reporting (telemetry)

Classes

Name
class shapeworks::EarlyStopping

Source code

#pragma once
#include <Eigen/Dense>

#include "Libs/Optimize/ParticleSystem.h"
#include "Libs/Optimize/EarlyStoppingConfig.h"
#include "MorphologicalDeviationScore.h"

namespace shapeworks {

class EarlyStopping {
public:
    typedef typename ParticleSystem::PointType PointType;
    constexpr static int VDimension = 3;
    EarlyStopping();
    void SetConfigParams(int frequency,
                  int window_size,
                  double threshold,
                  EarlyStoppingStrategy strategy = EarlyStoppingStrategy::RelativeDifference,
                  double ema_alpha = 0.2,
                  bool enable_logging = false,
                  const std::string& logger_name = "",
                  int warmup_iters = 1000);

    void reset();
    void update(int iteration, const ParticleSystem* p);
    bool ShouldStop() const;
    bool SetControlShapes(const ParticleSystem* p);
    Eigen::MatrixXd GetTestShapes(const ParticleSystem* p);

private:
    std::deque<Eigen::VectorXd> score_history_;
    int frequency_, window_size_;
    double threshold_, ema_alpha_;
    int last_checked_iter_;
    int warmup_iters_;
    // bool stop_flag_;
    mutable std::atomic<bool> stop_flag_{false};
    bool enable_logging_;
    std::string logger_name_;
    Eigen::MatrixXd control_shapes_; 
    MorphologicalDeviationScore score_func_;
    EarlyStoppingStrategy strategy_;
    mutable Eigen::VectorXd ema_diff_;
    mutable bool ema_initialized_ = false;

    Eigen::VectorXd ComputeScore(const Eigen::MatrixXd& X) ;
    Eigen::VectorXd ComputeRelativeDiff(const Eigen::VectorXd& a, const Eigen::VectorXd& b) const;
    bool HasConverged() const;
    bool CheckRelativeDifference() const;
    bool CheckExponentialMovingAverage() const;
    void LogStatus(int iter,
                   const Eigen::VectorXd& current_score,
                   const Eigen::VectorXd& diff,
                   const std::vector<bool>& per_subject_convergence) const;
};

} // namespace shapeworks

Updated on 2025-10-13 at 18:47:50 +0000