mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-20 00:51:42 +00:00
[wpimath] Use SDA algorithm instead of SSCA for DARE solver (#5526)
Both seem to work, but the SDA algorithm is specifically recommended for solving DAREs as opposed to P-DAREs. The QR decomposition was replaced with a partial pivoting LU decomposition at the recommendation of section 2.4 of the paper. More tests and a separate JNI function for each DARE solver variant were added.
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "Eigen/Cholesky"
|
||||
#include "Eigen/Core"
|
||||
#include "Eigen/Eigenvalues"
|
||||
#include "Eigen/LU"
|
||||
#include "Eigen/QR"
|
||||
#include "frc/fmt/Eigen.h"
|
||||
|
||||
@@ -47,7 +48,7 @@ bool IsStabilizable(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
}
|
||||
|
||||
Eigen::MatrixXcd E{A.rows(), A.rows() + B.cols()};
|
||||
E << es.eigenvalues()[i] * Eigen::MatrixXcd::Identity(A.rows(), A.rows()) -
|
||||
E << es.eigenvalues()[i] * Eigen::MatrixXcd::Identity(A.rows(), A.cols()) -
|
||||
A,
|
||||
B;
|
||||
|
||||
@@ -74,39 +75,6 @@ bool IsDetectable(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
return IsStabilizable(A.transpose(), C.transpose());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if all the matrix's eigenvalues are greater than or equal to
|
||||
* zero.
|
||||
*
|
||||
* @param A The matrix.
|
||||
*/
|
||||
bool IsPositiveSemidefinite(const Eigen::Ref<const Eigen::MatrixXd>& A) {
|
||||
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> es{A, Eigen::EigenvaluesOnly};
|
||||
for (int i = 0; i < A.rows(); ++i) {
|
||||
if (es.eigenvalues()[i] < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if all the matrix's eigenvalues are greater than zero.
|
||||
*
|
||||
* @param A The matrix.
|
||||
*/
|
||||
bool IsPositiveDefinite(const Eigen::Ref<const Eigen::MatrixXd>& A) {
|
||||
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> es{A, Eigen::EigenvaluesOnly};
|
||||
for (int i = 0; i < A.rows(); ++i) {
|
||||
if (es.eigenvalues()[i] <= 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
@@ -123,14 +91,6 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
assert(Q.rows() == states && Q.cols() == states);
|
||||
assert(R.rows() == inputs && R.cols() == inputs);
|
||||
|
||||
// Require the number of inputs be less than or equal to the number of states
|
||||
if (inputs > states) {
|
||||
std::string msg = fmt::format(
|
||||
"Number of inputs ({}) is greater than number of states ({})!", inputs,
|
||||
states);
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
// Require Q be symmetric
|
||||
if ((Q - Q.transpose()).norm() > 1e-10) {
|
||||
std::string msg =
|
||||
@@ -139,7 +99,17 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
}
|
||||
|
||||
// Require Q be positive semidefinite
|
||||
if (!IsPositiveSemidefinite(Q)) {
|
||||
//
|
||||
// If Q is a symmetric matrix with a decomposition LDLᵀ, the number of
|
||||
// positive, negative, and zero diagonal entries in D equals the number of
|
||||
// positive, negative, and zero eigenvalues respectively in Q (see
|
||||
// https://en.wikipedia.org/wiki/Sylvester's_law_of_inertia).
|
||||
//
|
||||
// Therefore, D having no negative diagonal entries is sufficient to prove Q
|
||||
// is positive semidefinite.
|
||||
auto Q_ldlt = Q.ldlt();
|
||||
if (Q_ldlt.info() != Eigen::Success ||
|
||||
(Q_ldlt.vectorD().array() < 0.0).any()) {
|
||||
std::string msg = fmt::format("Q isn't positive semidefinite!\n\nQ =\n{}\n",
|
||||
Eigen::MatrixXd{Q});
|
||||
throw std::invalid_argument(msg);
|
||||
@@ -152,13 +122,6 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
// Require R be positive definite
|
||||
if (!IsPositiveDefinite(R)) {
|
||||
std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n",
|
||||
Eigen::MatrixXd{R});
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
// Require (A, B) pair be stabilizable
|
||||
if (!IsStabilizable(A, B)) {
|
||||
std::string msg =
|
||||
@@ -169,9 +132,8 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
|
||||
// Require (A, C) pair be detectable where Q = CᵀC
|
||||
{
|
||||
Eigen::LDLT<Eigen::MatrixXd> ldlt{Q};
|
||||
Eigen::MatrixXd C = Eigen::MatrixXd{ldlt.matrixL()} *
|
||||
ldlt.vectorD().cwiseSqrt().asDiagonal();
|
||||
Eigen::MatrixXd C = Eigen::MatrixXd{Q_ldlt.matrixL()} *
|
||||
Q_ldlt.vectorD().cwiseSqrt().asDiagonal();
|
||||
|
||||
if (!IsDetectable(A, C)) {
|
||||
std::string msg = fmt::format(
|
||||
@@ -182,61 +144,7 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the SSCA algorithm on page 12 of [1].
|
||||
|
||||
// A₀ = A
|
||||
Eigen::MatrixXd A_k = A;
|
||||
Eigen::MatrixXd A_k1 = A;
|
||||
|
||||
// G₀ = BR⁻¹Bᵀ
|
||||
//
|
||||
// See equation (4) of [1].
|
||||
Eigen::MatrixXd G_k = B * R.llt().solve(B.transpose());
|
||||
|
||||
Eigen::MatrixXd I = Eigen::MatrixXd::Identity(A.rows(), A.cols());
|
||||
|
||||
// H₀ = Q
|
||||
//
|
||||
// See equation (4) of [1].
|
||||
Eigen::MatrixXd H_k = Q;
|
||||
Eigen::MatrixXd H_k1 = Q;
|
||||
|
||||
do {
|
||||
A_k = A_k1;
|
||||
H_k = H_k1;
|
||||
|
||||
// W = I + HₖGₖ
|
||||
Eigen::MatrixXd W = I + H_k * G_k;
|
||||
|
||||
// W is symmetric positive definite, so the LLT decomposition would work
|
||||
// here and is faster than the householder QR decomposition [2]. However,
|
||||
// it's not accurate enough. Experimentation showed that so many iterations
|
||||
// of iterative refinement [3] were required to fix the accuracy that the
|
||||
// total system solve time was much higher than householder QR.
|
||||
//
|
||||
// [2] https://eigen.tuxfamily.org/dox/group__TutorialLinearAlgebra.html
|
||||
// [3] https://en.wikipedia.org/wiki/Iterative_refinement
|
||||
auto W_solver = W.householderQr();
|
||||
|
||||
// Solve WV₁ = Aₖᵀ for V₁
|
||||
Eigen::MatrixXd V_1 = W_solver.solve(A_k.transpose());
|
||||
|
||||
// Solve WV₂ = Hₖ for V₂
|
||||
Eigen::MatrixXd V_2 = W_solver.solve(H_k);
|
||||
|
||||
// Aₖ₊₁ = V₁ᵀAₖ
|
||||
A_k1 = V_1.transpose() * A_k;
|
||||
|
||||
// Gₖ₊₁ = Gₖ + AₖGₖV₁
|
||||
G_k += A_k * G_k * V_1;
|
||||
|
||||
// Hₖ₊₁ = Hₖ + AₖᵀV₂Aₖ
|
||||
H_k1 = H_k + A_k.transpose() * V_2 * A_k;
|
||||
|
||||
// while |Hₖ₊₁ − Hₖ| > ε |Hₖ₊₁|
|
||||
} while ((H_k1 - H_k).norm() > 1e-10 * H_k1.norm());
|
||||
|
||||
return H_k1;
|
||||
return internal::DARE(A, B, Q, R);
|
||||
}
|
||||
|
||||
Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
@@ -244,8 +152,19 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& Q,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& R,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& N) {
|
||||
// These are unused if assertions aren't compiled in
|
||||
[[maybe_unused]] int states = A.rows();
|
||||
[[maybe_unused]] int inputs = B.cols();
|
||||
|
||||
// Check argument dimensions
|
||||
assert(N.rows() == B.rows() && N.cols() == B.cols());
|
||||
assert(N.rows() == states && N.cols() == inputs);
|
||||
|
||||
auto R_llt = R.llt();
|
||||
if (R_llt.info() != Eigen::Success) {
|
||||
std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n",
|
||||
Eigen::MatrixXd{R});
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
// This is a change of variables to make the DARE that includes Q, R, and N
|
||||
// cost matrices fit the form of the DARE that includes only Q and R cost
|
||||
@@ -258,8 +177,108 @@ Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
// where A₂ and Q₂ are a change of variables:
|
||||
//
|
||||
// A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ
|
||||
return DARE(A - B * R.llt().solve(N.transpose()), B,
|
||||
Q - N * R.llt().solve(N.transpose()), R);
|
||||
return DARE(A - B * R_llt.solve(N.transpose()), B,
|
||||
Q - N * R_llt.solve(N.transpose()), R);
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& B,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& Q,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& R) {
|
||||
// Require R be positive definite
|
||||
auto R_llt = R.llt();
|
||||
if (R_llt.info() != Eigen::Success) {
|
||||
std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n",
|
||||
Eigen::MatrixXd{R});
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
// Implements the SDA algorithm on page 5 of [1].
|
||||
|
||||
// A₀ = A
|
||||
Eigen::MatrixXd A_k = A;
|
||||
|
||||
// G₀ = BR⁻¹Bᵀ
|
||||
//
|
||||
// See equation (4) of [1].
|
||||
Eigen::MatrixXd G_k = B * R_llt.solve(B.transpose());
|
||||
|
||||
// H₀ = Q
|
||||
//
|
||||
// See equation (4) of [1].
|
||||
Eigen::MatrixXd H_k;
|
||||
Eigen::MatrixXd H_k1 = Q;
|
||||
|
||||
do {
|
||||
H_k = H_k1;
|
||||
|
||||
// W = I + GₖHₖ
|
||||
Eigen::MatrixXd W =
|
||||
Eigen::MatrixXd::Identity(H_k.rows(), H_k.cols()) + G_k * H_k;
|
||||
|
||||
auto W_solver = W.lu();
|
||||
|
||||
// Solve WV₁ = Aₖ for V₁
|
||||
Eigen::MatrixXd V_1 = W_solver.solve(A_k);
|
||||
|
||||
// Solve V₂Wᵀ = Gₖ for V₂
|
||||
//
|
||||
// We want to put V₂Wᵀ = Gₖ into Ax = b form so we can solve it more
|
||||
// efficiently.
|
||||
//
|
||||
// V₂Wᵀ = Gₖ
|
||||
// (V₂Wᵀ)ᵀ = Gₖᵀ
|
||||
// WV₂ᵀ = Gₖᵀ
|
||||
//
|
||||
// The solution of Ax = b can be found via x = A.solve(b).
|
||||
//
|
||||
// V₂ᵀ = W.solve(Gₖᵀ)
|
||||
// V₂ = W.solve(Gₖᵀ)ᵀ
|
||||
Eigen::MatrixXd V_2 = W_solver.solve(G_k.transpose()).transpose();
|
||||
|
||||
// Gₖ₊₁ = Gₖ + AₖV₂Aₖᵀ
|
||||
G_k += A_k * V_2 * A_k.transpose();
|
||||
|
||||
// Hₖ₊₁ = Hₖ + V₁ᵀHₖAₖ
|
||||
H_k1 = H_k + V_1.transpose() * H_k * A_k;
|
||||
|
||||
// Aₖ₊₁ = AₖV₁
|
||||
A_k *= V_1;
|
||||
|
||||
// while |Hₖ₊₁ − Hₖ| > ε |Hₖ₊₁|
|
||||
} while ((H_k1 - H_k).norm() > 1e-10 * H_k1.norm());
|
||||
|
||||
return H_k1;
|
||||
}
|
||||
|
||||
Eigen::MatrixXd DARE(const Eigen::Ref<const Eigen::MatrixXd>& A,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& B,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& Q,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& R,
|
||||
const Eigen::Ref<const Eigen::MatrixXd>& N) {
|
||||
auto R_llt = R.llt();
|
||||
if (R_llt.info() != Eigen::Success) {
|
||||
std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n",
|
||||
Eigen::MatrixXd{R});
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
|
||||
// This is a change of variables to make the DARE that includes Q, R, and N
|
||||
// cost matrices fit the form of the DARE that includes only Q and R cost
|
||||
// matrices.
|
||||
//
|
||||
// This is equivalent to solving the original DARE:
|
||||
//
|
||||
// A₂ᵀXA₂ − X − A₂ᵀXB(BᵀXB + R)⁻¹BᵀXA₂ + Q₂ = 0
|
||||
//
|
||||
// where A₂ and Q₂ are a change of variables:
|
||||
//
|
||||
// A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ
|
||||
return internal::DARE(A - B * R_llt.solve(N.transpose()), B,
|
||||
Q - N * R_llt.solve(N.transpose()), R);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace frc
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <jni.h>
|
||||
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <wpi/jni_util.h>
|
||||
|
||||
@@ -102,11 +103,11 @@ extern "C" {
|
||||
|
||||
/*
|
||||
* Class: edu_wpi_first_math_WPIMathJNI
|
||||
* Method: dare
|
||||
* Method: dareABQR
|
||||
* Signature: ([D[D[D[DII[D)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_edu_wpi_first_math_WPIMathJNI_dare
|
||||
Java_edu_wpi_first_math_WPIMathJNI_dareABQR
|
||||
(JNIEnv* env, jclass, jdoubleArray A, jdoubleArray B, jdoubleArray Q,
|
||||
jdoubleArray R, jint states, jint inputs, jdoubleArray S)
|
||||
{
|
||||
@@ -137,8 +138,58 @@ Java_edu_wpi_first_math_WPIMathJNI_dare
|
||||
env->ReleaseDoubleArrayElements(R, nativeR, 0);
|
||||
|
||||
env->SetDoubleArrayRegion(S, 0, states * states, result.data());
|
||||
} catch (const std::runtime_error& e) {
|
||||
jclass cls = env->FindClass("java/lang/RuntimeException");
|
||||
} catch (const std::invalid_argument& e) {
|
||||
jclass cls = env->FindClass("java/lang/IllegalArgumentException");
|
||||
if (cls) {
|
||||
env->ThrowNew(cls, e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: edu_wpi_first_math_WPIMathJNI
|
||||
* Method: dareABQRN
|
||||
* Signature: ([D[D[D[D[DII[D)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_edu_wpi_first_math_WPIMathJNI_dareABQRN
|
||||
(JNIEnv* env, jclass, jdoubleArray A, jdoubleArray B, jdoubleArray Q,
|
||||
jdoubleArray R, jdoubleArray N, jint states, jint inputs, jdoubleArray S)
|
||||
{
|
||||
jdouble* nativeA = env->GetDoubleArrayElements(A, nullptr);
|
||||
jdouble* nativeB = env->GetDoubleArrayElements(B, nullptr);
|
||||
jdouble* nativeQ = env->GetDoubleArrayElements(Q, nullptr);
|
||||
jdouble* nativeR = env->GetDoubleArrayElements(R, nullptr);
|
||||
jdouble* nativeN = env->GetDoubleArrayElements(N, nullptr);
|
||||
|
||||
Eigen::Map<
|
||||
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
Amat{nativeA, states, states};
|
||||
Eigen::Map<
|
||||
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
Bmat{nativeB, states, inputs};
|
||||
Eigen::Map<
|
||||
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
Qmat{nativeQ, states, states};
|
||||
Eigen::Map<
|
||||
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
Rmat{nativeR, inputs, inputs};
|
||||
Eigen::Map<
|
||||
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
Nmat{nativeN, states, inputs};
|
||||
|
||||
try {
|
||||
Eigen::MatrixXd result = frc::DARE(Amat, Bmat, Qmat, Rmat, Nmat);
|
||||
|
||||
env->ReleaseDoubleArrayElements(A, nativeA, 0);
|
||||
env->ReleaseDoubleArrayElements(B, nativeB, 0);
|
||||
env->ReleaseDoubleArrayElements(Q, nativeQ, 0);
|
||||
env->ReleaseDoubleArrayElements(R, nativeR, 0);
|
||||
env->ReleaseDoubleArrayElements(N, nativeN, 0);
|
||||
|
||||
env->SetDoubleArrayRegion(S, 0, states * states, result.data());
|
||||
} catch (const std::invalid_argument& e) {
|
||||
jclass cls = env->FindClass("java/lang/IllegalArgumentException");
|
||||
if (cls) {
|
||||
env->ThrowNew(cls, e.what());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user