Skip to content

Libs/Optimize/Matrix/MixedEffectsShapeMatrix.h

Namespaces

Name
shapeworks
User usage reporting (telemetry)

Classes

Name
class shapeworks::MixedEffectsShapeMatrix

Source code

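/* Shape matrix that fits a linear mixed-effects regression (fixed slope/intercept plus a random
 * slope/intercept per individual) of particle positions against an explanatory variable. */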

#pragma once

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

#include "Libs/Optimize/Matrix/LegacyShapeMatrix.h"
#include "ParticleSystem.h"
#include "vnl/vnl_trace.h"
#include "vnl/vnl_vector.h"

namespace shapeworks {
class MixedEffectsShapeMatrix : public LegacyShapeMatrix {
 public:
  typedef double DataType;
  typedef MixedEffectsShapeMatrix Self;
  typedef LegacyShapeMatrix Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
  typedef itk::WeakPointer<const Self> ConstWeakPointer;

  itkNewMacro(Self);

  itkTypeMacro(MixedEffectsShapeMatrix, LegacyShapeMatrix);

  void UpdateMeanMatrix() {
    // for each sample
    vnl_vector<double> tempvect;
    tempvect.set_size(m_MeanMatrix.rows());
    tempvect.fill(0.0);
    for (unsigned int i = 0; i < m_MeanMatrix.cols(); i++) {
      int group_indx = i / m_TimeptsPerIndividual;
      tempvect = m_Intercept + m_Slope * m_Expl(i);
      tempvect = tempvect + m_InterceptRand.get_row(group_indx);
      tempvect = tempvect + m_SlopeRand.get_row(group_indx) * m_Expl(i);
      // compute the mean
      m_MeanMatrix.set_column(i, tempvect);
    }
  }

  inline vnl_vector<double> ComputeMean(double k) const { return m_Intercept + m_Slope * k; }

  void ResizeParameters(unsigned int n) {
    vnl_vector<double> tmpA = m_Intercept;  // copy existing  matrix
    vnl_vector<double> tmpB = m_Slope;      // copy existing  matrix

    // Create new
    m_Intercept.set_size(n);
    m_Slope.set_size(n);

    // Copy old data into new vector.
    for (unsigned int r = 0; r < tmpA.size(); r++) {
      m_Intercept(r) = tmpA(r);
      m_Slope(r) = tmpB(r);
    }
  }

  virtual void ResizeMeanMatrix(int rs, int cs) {
    vnl_matrix<double> tmp = m_MeanMatrix;  // copy existing  matrix

    // Create new column (shape)
    m_MeanMatrix.set_size(rs, cs);

    m_MeanMatrix.fill(0.0);

    // Copy old data into new matrix.
    for (unsigned int c = 0; c < tmp.cols(); c++) {
      for (unsigned int r = 0; r < tmp.rows(); r++) {
        m_MeanMatrix(r, c) = tmp(r, c);
      }
    }
  }

  void ResizeExplanatory(unsigned int n) {
    if (n > m_Expl.size()) {
      vnl_vector<double> tmp = m_Expl;  // copy existing  matrix

      // Create new
      m_Expl.set_size(n);
      m_Expl.fill(0.0);

      // Copy old data into new vector.
      for (unsigned int r = 0; r < tmp.size(); r++) {
        m_Expl(r) = tmp(r);
      }
    }
  }

  virtual void DomainAddEventCallback(Object*, const itk::EventObject& e) {
    const ParticleDomainAddEvent& event = dynamic_cast<const ParticleDomainAddEvent&>(e);
    unsigned int d = event.GetDomainIndex();

    if (d % this->m_DomainsPerShape == 0) {
      this->ResizeMatrix(this->rows(), this->cols() + 1);
      this->ResizeMeanMatrix(this->rows(), this->cols() + 1);
      this->ResizeExplanatory(this->cols());
    }
  }

  virtual void PositionAddEventCallback(Object* o, const itk::EventObject& e) {
    const int VDimension = 3;

    const ParticlePositionAddEvent& event = dynamic_cast<const ParticlePositionAddEvent&>(e);
    const ParticleSystem* ps = dynamic_cast<const ParticleSystem*>(o);
    const int d = event.GetDomainIndex();
    const unsigned int idx = event.GetPositionIndex();
    const typename ParticleSystem::PointType pos = ps->GetTransformedPosition(idx, d);

    const unsigned int PointsPerDomain = ps->GetNumberOfParticles(d);

    // Make sure we have enough rows.
    if ((ps->GetNumberOfParticles(d) * VDimension * this->m_DomainsPerShape) > this->rows()) {
      this->ResizeParameters(PointsPerDomain * VDimension * this->m_DomainsPerShape);
      this->ResizeMatrix(PointsPerDomain * VDimension * this->m_DomainsPerShape, this->cols());
      this->ResizeMeanMatrix(PointsPerDomain * VDimension * this->m_DomainsPerShape, this->cols());
    }

    // CANNOT ADD POSITION INFO UNTIL ALL POINTS PER DOMAIN IS KNOWN
    // Add position info to the matrix
    unsigned int k = ((d % this->m_DomainsPerShape) * PointsPerDomain * VDimension) + (idx * VDimension);
    for (unsigned int i = 0; i < VDimension; i++) {
      this->operator()(i + k, d / this->m_DomainsPerShape) = pos[i];
    }

    //   std::cout << "Row " << k << " Col " << d / this->m_DomainsPerShape << " = " << pos << std::endl;
  }

  virtual void PositionSetEventCallback(Object* o, const itk::EventObject& e) {
    const int VDimension = 3;

    const ParticlePositionSetEvent& event = dynamic_cast<const ParticlePositionSetEvent&>(e);

    const ParticleSystem* ps = dynamic_cast<const ParticleSystem*>(o);
    const int d = event.GetDomainIndex();
    const unsigned int idx = event.GetPositionIndex();
    const typename ParticleSystem::PointType pos = ps->GetTransformedPosition(idx, d);
    const unsigned int PointsPerDomain = ps->GetNumberOfParticles(d);

    // Modify matrix info
    //    unsigned int k = VDimension * idx;
    unsigned int k = ((d % this->m_DomainsPerShape) * PointsPerDomain * VDimension) + (idx * VDimension);

    for (unsigned int i = 0; i < VDimension; i++) {
      this->operator()(i + k, d / this->m_DomainsPerShape) = pos[i] - m_MeanMatrix(i + k, d / this->m_DomainsPerShape);
    }
  }

  virtual void PositionRemoveEventCallback(Object*, const itk::EventObject&) {
    // NEED TO IMPLEMENT THIS
  }

  void SetDomainsPerShape(int i) { this->m_DomainsPerShape = i; }
  int GetDomainsPerShape() const { return this->m_DomainsPerShape; }

  void SetTimeptsPerIndividual(int i) { this->m_TimeptsPerIndividual = i; }

  int GetTimeptsPerIndividual() const { return this->m_TimeptsPerIndividual; }

  void SetExplanatory(std::vector<double> v) {
    //    std::cout << "Setting expl " << std::endl;
    ResizeExplanatory(v.size());
    for (unsigned int i = 0; i < v.size(); i++) {
      //      std::cout << v[i] << std::endl;
      m_Expl[i] = v[i];
    }
  }
  void SetExplanatory(unsigned int i, double q) { m_Expl[i] = q; }
  const double& GetExplanatory(unsigned int i) const { return m_Expl[i]; }
  double& GetExplanatory(unsigned int i) { return m_Expl[i]; }

  const vnl_vector<double>& GetSlope() const { return m_Slope; }
  const vnl_vector<double>& GetIntercept() const { return m_Intercept; }
  const vnl_matrix<double>& GetSlopeRandom() const { return m_SlopeRand; }
  const vnl_matrix<double>& GetInterceptRandom() const { return m_InterceptRand; }

  void SetSlope(const std::vector<double>& v) {
    ResizeParameters(v.size());
    for (unsigned int i = 0; i < v.size(); i++) {
      m_Slope[i] = v[i];
    }
  }

  void SetIntercept(const std::vector<double>& v) {
    ResizeParameters(v.size());
    for (unsigned int i = 0; i < v.size(); i++) {
      m_Intercept[i] = v[i];
    }
  }

  void EstimateParameters() {
    //    std::cout << "Estimating params" << std::endl;
    //    std::cout << "Explanatory: " << m_Expl << std::endl;

    vnl_matrix<double> X = *this + m_MeanMatrix;

    // Number of samples
    int num_shapes = static_cast<double>(X.cols());
    this->m_NumIndividuals = num_shapes / this->GetTimeptsPerIndividual();
    int nr = X.rows();  // number of points*3

    // set the sizes of random slope and intercept matrix
    m_SlopeRand.set_size(m_NumIndividuals, nr);      // num_groups X num_points*3
    m_InterceptRand.set_size(m_NumIndividuals, nr);  // num_groups X num_points*3

    vnl_matrix<double> fixed;   // slopes + intercepts for all points
    vnl_matrix<double> random;  // slopes + intercepts for all groups, for all points
    fixed.set_size(2, nr);
    random.set_size(2, nr * m_NumIndividuals);
    vnl_matrix<double> Ds(2, 2);  // covariance matrix of random parameters (2x2)
    Ds.set_identity();            // initialize to identity
    double sigma2s = 1;           // variance of error
    vnl_matrix<double> identity_n;
    identity_n.set_size(m_TimeptsPerIndividual, m_TimeptsPerIndividual);
    identity_n.set_identity();
    vnl_matrix<double> identity_2;
    identity_2.set_size(2, 2);
    identity_2.set_identity();
    vnl_matrix<double>*Ws = NULL, *Vs = NULL;
    Ws = new vnl_matrix<double>[m_NumIndividuals];
    Vs = new vnl_matrix<double>[m_NumIndividuals];
    for (int i = 0; i < m_NumIndividuals; i++) {
      Vs[i].set_size(m_TimeptsPerIndividual, m_TimeptsPerIndividual);
      Ws[i].set_size(m_TimeptsPerIndividual, m_TimeptsPerIndividual);
    }

    vnl_matrix<double> sum_mat1(2, 2, 0);
    vnl_vector<double> sum_mat2(2);
    sum_mat2.fill(0.0);
    vnl_vector<double> residual;
    residual.set_size(m_TimeptsPerIndividual);
    residual.fill(0.0);
    double ecorr = 0.0;
    double tracevar = 0.0;
    vnl_matrix<double> bscorr(2, 2, 0.0);
    vnl_matrix<double> bsvar(2, 2, 0.0);
    vnl_matrix<double> Xp;
    Xp.set_size(m_TimeptsPerIndividual, 2);
    vnl_vector<double> y;
    y.set_size(m_TimeptsPerIndividual);
    vnl_vector<double> tempvect;
    tempvect.set_size(2);
    for (int i = 0; i < nr; i++)  // for all points (x,y,z coordinates)
    {
      sigma2s = 1.0;
      Ds.set_identity();
      for (int j = 0; j < 50; j++)  // EM iterations
      {
        sum_mat1.fill(0.0);
        sum_mat2.fill(0.0);
        residual.fill(0.0);
        ecorr = 0.0;
        tracevar = 0.0;
        bscorr.fill(0.0);
        bsvar.fill(0.0);
        for (int k = 0; k < m_NumIndividuals; k++) {
          for (int l = 0; l < m_TimeptsPerIndividual; l++) {
            Xp(l, 0) = m_Expl(k * m_TimeptsPerIndividual + l);
            Xp(l, 1) = 1;
            y(l) = X(i, k * m_TimeptsPerIndividual + l);
          }
          Vs[k] = (identity_n * sigma2s) + Xp * Ds * vnl_transpose(Xp);
          // Ws = static_cast<vnl_matrix> (vnl_matrix_inverse<double>(Vs));
          Ws[k] = vnl_inverse(Vs[k]);
          sum_mat1 = sum_mat1 + vnl_transpose(Xp) * Ws[k] * Xp;
          sum_mat2 = sum_mat2 + vnl_transpose(Xp) * Ws[k] * y;
        }
        tempvect = vnl_inverse(sum_mat1) * sum_mat2;
        fixed.set_column(i, tempvect);
        for (int k = 0; k < m_NumIndividuals; k++) {
          for (int l = 0; l < m_TimeptsPerIndividual; l++) {
            Xp(l, 0) = m_Expl(k * m_TimeptsPerIndividual + l);
            Xp(l, 1) = 1;
            y(l) = X(i, k * m_TimeptsPerIndividual + l);
          }
          tempvect = Ds * vnl_transpose(Xp) * Ws[k] * (y - (Xp * fixed.get_column(i)));
          random.set_column(i * m_NumIndividuals + k, tempvect);
          residual = y - (Xp * fixed.get_column(i)) - (Xp * random.get_column(i * m_NumIndividuals + k));
          ecorr = ecorr + dot_product(residual, residual);
          tracevar = tracevar + (m_TimeptsPerIndividual - sigma2s * vnl_trace(Ws[k]));
          bscorr = bscorr + outer_product(random.get_column(i * m_NumIndividuals + k),
                                          random.get_column(i * m_NumIndividuals + k));
          bsvar = bsvar + (identity_2 - (vnl_transpose(Xp) * Ws[k] * Xp * Ds));
        }
        sigma2s = (ecorr + sigma2s * tracevar) / (num_shapes);
        Ds = (bscorr + Ds * bsvar) / m_NumIndividuals;
      }  // endfor EM iterations
      // printf ("point #%d\n", i);
    }  // endfor all points on shape (x,y & z)

    m_Slope = fixed.get_row(0);
    m_Intercept = fixed.get_row(1);
    for (int i = 0; i < m_NumIndividuals; i++) {
      for (int j = 0; j < nr; j++)  // for all points * 3
      {
        m_SlopeRand(i, j) = random(0, j * m_NumIndividuals + i);
        m_InterceptRand(i, j) = random(1, j * m_NumIndividuals + i);
      }
    }
    delete[] Vs;
    delete[] Ws;
    // printf ("points:\n");
    // for (int k = 0; k < m_NumIndividuals; k++)
    //  for (int l = 0; l < m_TimeptsPerIndividual; l++)
    //      printf ("%g   %g\n", X(0,k*m_TimeptsPerIndividual + l), m_Expl(k*m_TimeptsPerIndividual + l));

    // printf ("fixed: slope %g, intercept %g", m_Slope(0), m_Intercept(0));
    // printf ("random: slopes %g %g, intercepts %g %g", m_SlopeRand(0,0), m_SlopeRand(1,0), m_InterceptRand(0,0),
    // m_InterceptRand(1,0));
  }

  //
  void Initialize() {
    m_Intercept.fill(0.0);
    m_Slope.fill(0.0);
    m_MeanMatrix.fill(0.0);

    m_SlopeRand.fill(0.0);
    m_InterceptRand.fill(0.0);
  }

  virtual void BeforeIteration() {
    m_UpdateCounter++;
    if (m_UpdateCounter >= m_RegressionInterval) {
      m_UpdateCounter = 0;
      this->EstimateParameters();
      this->UpdateMeanMatrix();
    }
  }

  void SetRegressionInterval(int i) { m_RegressionInterval = i; }
  int GetRegressionInterval() const { return m_RegressionInterval; }

 protected:
  MixedEffectsShapeMatrix() {
    this->m_DefinedCallbacks.DomainAddEvent = true;
    this->m_DefinedCallbacks.PositionAddEvent = true;
    this->m_DefinedCallbacks.PositionSetEvent = true;
    this->m_DefinedCallbacks.PositionRemoveEvent = true;
    m_UpdateCounter = 0;
    m_RegressionInterval = 1;
    m_NumIndividuals = 13;
    m_TimeptsPerIndividual = 3;
  }
  virtual ~MixedEffectsShapeMatrix(){};

  void PrintSelf(std::ostream& os, itk::Indent indent) const { Superclass::PrintSelf(os, indent); }

 private:
  MixedEffectsShapeMatrix(const Self&);  // purposely not implemented
  void operator=(const Self&);                            // purposely not implemented

  int m_UpdateCounter;
  int m_RegressionInterval;

  // Parameters for the linear model
  vnl_vector<double> m_Intercept;
  vnl_vector<double> m_Slope;

  // The explanatory variable value for each sample (matrix column)
  vnl_vector<double> m_Expl;

  // A matrix to store the mean estimated for each explanatory variable (each sample)
  vnl_matrix<double> m_MeanMatrix;

  vnl_matrix<double> m_InterceptRand;  // added: AK , random intercepts for each group
  vnl_matrix<double> m_SlopeRand;      // added: AK , random slopes for each group
  int m_NumIndividuals;
  int m_TimeptsPerIndividual;
};

}  // namespace shapeworks

Updated on 2024-11-11 at 19:51:46 +0000