Skip to content

Commit

Permalink
Drop cppoptlib from Problem and Solver
Browse files Browse the repository at this point in the history
  • Loading branch information
zfergus committed Jan 5, 2024
1 parent 73d9800 commit 00474f4
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 93 deletions.
130 changes: 130 additions & 0 deletions src/polysolve/nonlinear/Criteria.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#pragma once

#include <iostream>
#include <Eigen/Core>

namespace polysolve::nonlinear
{
enum class Status
{
NotStarted = -1,
Continue = 0,
IterationLimit,
XDeltaTolerance,
FDeltaTolerance,
GradNormTolerance,
Condition,
UserDefined
};

template <typename T>
class Criteria
{
public:
size_t iterations; //!< Maximum number of iterations
T xDelta; //!< Minimum change in parameter vector
T fDelta; //!< Minimum change in cost function
T gradNorm; //!< Minimum norm of gradient vector
T condition; //!< Maximum condition number of Hessian

Criteria()
{
reset();
}

static Criteria defaults()
{
Criteria d;
d.iterations = 10000;
d.xDelta = 0;
d.fDelta = 0;
d.gradNorm = 1e-4;
d.condition = 0;
return d;
}

void reset()
{
iterations = 0;
xDelta = 0;
fDelta = 0;
gradNorm = 0;
condition = 0;
}

void print(std::ostream &os) const
{
os << "Iterations: " << iterations << std::endl;
os << "xDelta: " << xDelta << std::endl;
os << "fDelta: " << fDelta << std::endl;
os << "GradNorm: " << gradNorm << std::endl;
os << "Condition: " << condition << std::endl;
}
};

template <typename T>
Status checkConvergence(const Criteria<T> &stop, const Criteria<T> &current)
{

if ((stop.iterations > 0) && (current.iterations > stop.iterations))
{
return Status::IterationLimit;
}
if ((stop.xDelta > 0) && (current.xDelta < stop.xDelta))
{
return Status::XDeltaTolerance;
}
if ((stop.fDelta > 0) && (current.fDelta < stop.fDelta))
{
return Status::FDeltaTolerance;
}
if ((stop.gradNorm > 0) && (current.gradNorm < stop.gradNorm))
{
return Status::GradNormTolerance;
}
if ((stop.condition > 0) && (current.condition > stop.condition))
{
return Status::Condition;

Check warning on line 87 in src/polysolve/nonlinear/Criteria.hpp

View check run for this annotation

Codecov / codecov/patch

src/polysolve/nonlinear/Criteria.hpp#L87

Added line #L87 was not covered by tests
}
return Status::Continue;
}

inline std::ostream &operator<<(std::ostream &os, const Status &s)
{
switch (s)
{
case Status::NotStarted:
os << "Solver not started.";
break;
case Status::Continue:
os << "Convergence criteria not reached.";
break;
case Status::IterationLimit:
os << "Iteration limit reached.";
break;
case Status::XDeltaTolerance:
os << "Change in parameter vector too small.";
break;
case Status::FDeltaTolerance:
os << "Change in cost function value too small.";
break;

Check warning on line 110 in src/polysolve/nonlinear/Criteria.hpp

View check run for this annotation

Codecov / codecov/patch

src/polysolve/nonlinear/Criteria.hpp#L96-L110

Added lines #L96 - L110 were not covered by tests
case Status::GradNormTolerance:
os << "Gradient vector norm too small.";
break;
case Status::Condition:
os << "Condition of Hessian/Covariance matrix too large.";
break;
case Status::UserDefined:
os << "Stop condition defined in the callback.";
break;

Check warning on line 119 in src/polysolve/nonlinear/Criteria.hpp

View check run for this annotation

Codecov / codecov/patch

src/polysolve/nonlinear/Criteria.hpp#L114-L119

Added lines #L114 - L119 were not covered by tests
}
return os;
}

template <typename T>
std::ostream &operator<<(std::ostream &os, const Criteria<T> &c)
{
c.print(os);
return os;
}
} // namespace polysolve::nonlinear
100 changes: 86 additions & 14 deletions src/polysolve/nonlinear/Problem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,119 @@

#include <polysolve/Types.hpp>

#include "Criteria.hpp"
#include "PostStepData.hpp"

#include <cppoptlib/problem.h>

#include <memory>
#include <vector>

namespace polysolve::nonlinear
{
class Problem : public cppoptlib::Problem<double>
class Problem
{
public:
using typename cppoptlib::Problem<double>::Scalar;
using typename cppoptlib::Problem<double>::TVector;
typedef polysolve::StiffnessMatrix THessian;

// disable warning for dense hessian
using cppoptlib::Problem<double>::hessian;
static const int Dim = Eigen::Dynamic;
using Scalar = double;
using TVector = Eigen::Matrix<Scalar, Dim, 1>;
using TMatrix = Eigen::Matrix<Scalar, Dim, Dim>;
using THessian = StiffnessMatrix;

public:
Problem() {}
~Problem() = default;
virtual ~Problem() = default;

/// @brief Initialize the problem.
/// @param x0 Initial guess.
virtual void init(const TVector &x0) {}

virtual double value(const TVector &x) override = 0;
virtual void gradient(const TVector &x, TVector &gradv) override = 0;
/// @brief Compute the value of the function at x.
/// @param x Degrees of freedom.
/// @return The value of the function at x.
Scalar operator()(const TVector &x) { return value(x); }

/// @brief Compute the value of the function at x.
/// @param x Degrees of freedom.
/// @return The value of the function at x.
virtual Scalar value(const TVector &x) = 0;

/// @brief Compute the gradient of the function at x.
/// @param[in] x Degrees of freedom.
/// @param[out] grad Gradient of the function at x.
virtual void gradient(const TVector &x, TVector &grad) = 0;

// TODO: Add dense Hessian

/// @brief Compute the Hessian of the function at x.
/// @param[in] x Degrees of freedom.
/// @param[out] hessian Hessian of the function at x.
virtual void hessian(const TVector &x, TMatrix &hessian)

Check warning on line 50 in src/polysolve/nonlinear/Problem.hpp

View check run for this annotation

Codecov / codecov/patch

src/polysolve/nonlinear/Problem.hpp#L50

Added line #L50 was not covered by tests
{
StiffnessMatrix sparse_hessian;
hessian(x, sparse_hessian);
hessian = sparse_hessian;
}

Check warning on line 55 in src/polysolve/nonlinear/Problem.hpp

View check run for this annotation

Codecov / codecov/patch

src/polysolve/nonlinear/Problem.hpp#L52-L55

Added lines #L52 - L55 were not covered by tests

/// @brief Compute the Hessian of the function at x.
/// @param[in] x Degrees of freedom.
/// @param[out] hessian Hessian of the function at x.
virtual void hessian(const TVector &x, THessian &hessian) = 0;

virtual bool is_step_valid(const TVector &x0, const TVector &x1) { return true; }
virtual double max_step_size(const TVector &x0, const TVector &x1) { return 1; }
/// @brief Determine if the step from x0 to x1 is valid.
/// @param x0 Starting point.
/// @param x1 Ending point.
/// @return True if the step is valid, false otherwise.
virtual bool is_step_valid(const TVector &x0, const TVector &x1) const { return true; }

/// @brief Determine a maximum step size from x0 to x1.
/// @param x0 Starting point.
/// @param x1 Ending point.
/// @return Maximum step size.
virtual double max_step_size(const TVector &x0, const TVector &x1) const { return 1; }

// --- Callbacks ------------------------------------------------------

/// @brief Callback function for the start of a line search.
/// @param x0 Starting point.
/// @param x1 Ending point.
virtual void line_search_begin(const TVector &x0, const TVector &x1) {}

/// @brief Callback function for the end of a line search.
virtual void line_search_end() {}

/// @brief Callback function for the end of a step.
/// @param data Post step data.
virtual void post_step(const PostStepData &data) {}

/// @brief Set the project to PSD flag.
/// @param val True if the problem should be projected to PSD, false otherwise.
virtual void set_project_to_psd(bool val) {}

/// @brief Callback function for when the solution changes.
/// @param new_x New solution.
virtual void solution_changed(const TVector &new_x) {}

/// @brief Callback function used to determine if the solver should stop.
/// @param state Current state of the solver.
/// @param x Current solution.
/// @return True if the solver should stop, false otherwise.
virtual bool callback(const Criteria<Scalar> &state, const TVector &x) { return true; }

/// @brief Callback function used Determine if the solver should stop.
/// @param x Current solution.
/// @return True if the solver should stop, false otherwise.
virtual bool stop(const TVector &x) { return false; }

/// --- Misc ----------------------------------------------------------

/// @brief Sample the function along a direction.
/// @param[in] x Starting point.
/// @param[in] direction Direction to sample along.
/// @param[in] start Starting step size.
/// @param[in] end Ending step size.
/// @param[in] num_samples Number of samples to take.
/// @param[out] alphas Sampled step sizes.
/// @param[out] fs Sampled function values.
/// @param[out] valid If each sample is valid.
void sample_along_direction(
const Problem::TVector &x,
const Problem::TVector &direction,
Expand Down
Loading

0 comments on commit 00474f4

Please sign in to comment.