From 394cfeadbd3a2218a367bb14af3c2423d6a4ddec Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Sat, 12 Aug 2023 19:45:45 -0700 Subject: [PATCH] [wpimath] Use SDA algorithm instead of SSCA for DARE solver (#5526) Both seem to work, but the SDA algorithm is specifically recommended for solving DAREs as opposed to P-DAREs. The QR decomposition was replaced with a partial pivoting LU decomposition at the recommendation of section 2.4 of the paper. More tests and a separate JNI function for each DARE solver variant were added. --- .../main/java/edu/wpi/first/math/DARE.java | 48 ++-- .../java/edu/wpi/first/math/WPIMathJNI.java | 32 ++- wpimath/src/main/native/cpp/DARE.cpp | 241 ++++++++++-------- .../src/main/native/cpp/jni/WPIMathJNI.cpp | 59 ++++- wpimath/src/main/native/include/frc/DARE.h | 91 ++++++- .../java/edu/wpi/first/math/DARETest.java | 116 +++++++++ wpimath/src/test/native/cpp/DARETest.cpp | 117 ++++++++- 7 files changed, 549 insertions(+), 155 deletions(-) diff --git a/wpimath/src/main/java/edu/wpi/first/math/DARE.java b/wpimath/src/main/java/edu/wpi/first/math/DARE.java index a6f88fb6c1..698c3f7731 100644 --- a/wpimath/src/main/java/edu/wpi/first/math/DARE.java +++ b/wpimath/src/main/java/edu/wpi/first/math/DARE.java @@ -19,10 +19,14 @@ public final class DARE { * @param Q State cost matrix. * @param R Input cost matrix. * @return Solution of DARE. + * @throws IllegalArgumentException if Q isn't symmetric positive semidefinite. + * @throws IllegalArgumentException if R isn't symmetric positive definite. + * @throws IllegalArgumentException if the (A, B) pair isn't stabilizable. + * @throws IllegalArgumentException if the (A, C) pair where Q = CᵀC isn't detectable. */ public static SimpleMatrix dare(SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) { var S = new SimpleMatrix(A.getNumRows(), A.getNumCols()); - WPIMathJNI.dare( + WPIMathJNI.dareABQR( A.getDDRM().getData(), B.getDDRM().getData(), Q.getDDRM().getData(), @@ -43,6 +47,10 @@ public final class DARE { * @param Q State cost matrix. * @param R Input cost matrix. * @return Solution of DARE. + * @throws IllegalArgumentException if Q isn't symmetric positive semidefinite. + * @throws IllegalArgumentException if R isn't symmetric positive definite. + * @throws IllegalArgumentException if the (A, B) pair isn't stabilizable. + * @throws IllegalArgumentException if the (A, C) pair where Q = CᵀC isn't detectable. */ public static Matrix dare( Matrix A, @@ -61,21 +69,20 @@ public final class DARE { * @param R Input cost matrix. * @param N State-input cross-term cost matrix. * @return Solution of DARE. + * @throws IllegalArgumentException if Q − NR⁻¹Nᵀ isn't symmetric positive semidefinite. + * @throws IllegalArgumentException if R isn't symmetric positive definite. + * @throws IllegalArgumentException if the (A, B) pair isn't stabilizable. + * @throws IllegalArgumentException if the (A, C) pair where Q = CᵀC isn't detectable. */ public static SimpleMatrix dare( SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R, SimpleMatrix N) { - // See - // https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR - // for the change of variables used here. - var scrA = A.minus(B.mult(R.solve(N.transpose()))); - var scrQ = Q.minus(N.mult(R.solve(N.transpose()))); - var S = new SimpleMatrix(A.getNumRows(), A.getNumCols()); - WPIMathJNI.dare( - scrA.getDDRM().getData(), + WPIMathJNI.dareABQRN( + A.getDDRM().getData(), B.getDDRM().getData(), - scrQ.getDDRM().getData(), + Q.getDDRM().getData(), R.getDDRM().getData(), + N.getDDRM().getData(), A.getNumCols(), B.getNumCols(), S.getDDRM().getData()); @@ -93,6 +100,10 @@ public final class DARE { * @param R Input cost matrix. * @param N State-input cross-term cost matrix. * @return Solution of DARE. + * @throws IllegalArgumentException if Q − NR⁻¹Nᵀ isn't symmetric positive semidefinite. + * @throws IllegalArgumentException if R isn't symmetric positive definite. + * @throws IllegalArgumentException if the (A, B) pair isn't stabilizable. + * @throws IllegalArgumentException if the (A, C) pair where Q = CᵀC isn't detectable. */ public static Matrix dare( Matrix A, @@ -100,22 +111,7 @@ public final class DARE { Matrix Q, Matrix R, Matrix N) { - // This is a change of variables to make the DARE that includes Q, R, and N - // cost matrices fit the form of the DARE that includes only Q and R cost - // matrices. - // - // This is equivalent to solving the original DARE: - // - // A₂ᵀXA₂ − X − A₂ᵀXB(BᵀXB + R)⁻¹BᵀXA₂ + Q₂ = 0 - // - // where A₂ and Q₂ are a change of variables: - // - // A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ return new Matrix<>( - dare( - A.minus(B.times(R.solve(N.transpose()))).getStorage(), - B.getStorage(), - Q.minus(N.times(R.solve(N.transpose()))).getStorage(), - R.getStorage())); + dare(A.getStorage(), B.getStorage(), Q.getStorage(), R.getStorage(), N.getStorage())); } } diff --git a/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java b/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java index f843fde7b9..37a85ffa61 100644 --- a/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java +++ b/wpimath/src/main/java/edu/wpi/first/math/WPIMathJNI.java @@ -53,10 +53,40 @@ public final class WPIMathJNI { * @param states Number of states in A matrix. * @param inputs Number of inputs in B matrix. * @param S Array storage for DARE solution. + * @throws IllegalArgumentException if Q isn't symmetric positive semidefinite. + * @throws IllegalArgumentException if R isn't symmetric positive definite. + * @throws IllegalArgumentException if the (A, B) pair isn't stabilizable. + * @throws IllegalArgumentException if the (A, C) pair where Q = CᵀC isn't detectable. */ - public static native void dare( + public static native void dareABQR( double[] A, double[] B, double[] Q, double[] R, int states, int inputs, double[] S); + /** + * Solves the discrete alegebraic Riccati equation. + * + * @param A Array containing elements of A in row-major order. + * @param B Array containing elements of B in row-major order. + * @param Q Array containing elements of Q in row-major order. + * @param R Array containing elements of R in row-major order. + * @param N Array containing elements of N in row-major order. + * @param states Number of states in A matrix. + * @param inputs Number of inputs in B matrix. + * @param S Array storage for DARE solution. + * @throws IllegalArgumentException if Q − NR⁻¹Nᵀ isn't symmetric positive semidefinite. + * @throws IllegalArgumentException if R isn't symmetric positive definite. + * @throws IllegalArgumentException if the (A, B) pair isn't stabilizable. + * @throws IllegalArgumentException if the (A, C) pair where Q = CᵀC isn't detectable. + */ + public static native void dareABQRN( + double[] A, + double[] B, + double[] Q, + double[] R, + double[] N, + int states, + int inputs, + double[] S); + /** * Computes the matrix exp. * diff --git a/wpimath/src/main/native/cpp/DARE.cpp b/wpimath/src/main/native/cpp/DARE.cpp index 3a91df8115..3a67602f60 100644 --- a/wpimath/src/main/native/cpp/DARE.cpp +++ b/wpimath/src/main/native/cpp/DARE.cpp @@ -11,6 +11,7 @@ #include "Eigen/Cholesky" #include "Eigen/Core" #include "Eigen/Eigenvalues" +#include "Eigen/LU" #include "Eigen/QR" #include "frc/fmt/Eigen.h" @@ -47,7 +48,7 @@ bool IsStabilizable(const Eigen::Ref& A, } Eigen::MatrixXcd E{A.rows(), A.rows() + B.cols()}; - E << es.eigenvalues()[i] * Eigen::MatrixXcd::Identity(A.rows(), A.rows()) - + E << es.eigenvalues()[i] * Eigen::MatrixXcd::Identity(A.rows(), A.cols()) - A, B; @@ -74,39 +75,6 @@ bool IsDetectable(const Eigen::Ref& A, return IsStabilizable(A.transpose(), C.transpose()); } -/** - * Returns true if all the matrix's eigenvalues are greater than or equal to - * zero. - * - * @param A The matrix. - */ -bool IsPositiveSemidefinite(const Eigen::Ref& A) { - Eigen::SelfAdjointEigenSolver es{A, Eigen::EigenvaluesOnly}; - for (int i = 0; i < A.rows(); ++i) { - if (es.eigenvalues()[i] < 0) { - return false; - } - } - - return true; -} - -/** - * Returns true if all the matrix's eigenvalues are greater than zero. - * - * @param A The matrix. - */ -bool IsPositiveDefinite(const Eigen::Ref& A) { - Eigen::SelfAdjointEigenSolver es{A, Eigen::EigenvaluesOnly}; - for (int i = 0; i < A.rows(); ++i) { - if (es.eigenvalues()[i] <= 0) { - return false; - } - } - - return true; -} - } // namespace Eigen::MatrixXd DARE(const Eigen::Ref& A, @@ -123,14 +91,6 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, assert(Q.rows() == states && Q.cols() == states); assert(R.rows() == inputs && R.cols() == inputs); - // Require the number of inputs be less than or equal to the number of states - if (inputs > states) { - std::string msg = fmt::format( - "Number of inputs ({}) is greater than number of states ({})!", inputs, - states); - throw std::invalid_argument(msg); - } - // Require Q be symmetric if ((Q - Q.transpose()).norm() > 1e-10) { std::string msg = @@ -139,7 +99,17 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, } // Require Q be positive semidefinite - if (!IsPositiveSemidefinite(Q)) { + // + // If Q is a symmetric matrix with a decomposition LDLᵀ, the number of + // positive, negative, and zero diagonal entries in D equals the number of + // positive, negative, and zero eigenvalues respectively in Q (see + // https://en.wikipedia.org/wiki/Sylvester's_law_of_inertia). + // + // Therefore, D having no negative diagonal entries is sufficient to prove Q + // is positive semidefinite. + auto Q_ldlt = Q.ldlt(); + if (Q_ldlt.info() != Eigen::Success || + (Q_ldlt.vectorD().array() < 0.0).any()) { std::string msg = fmt::format("Q isn't positive semidefinite!\n\nQ =\n{}\n", Eigen::MatrixXd{Q}); throw std::invalid_argument(msg); @@ -152,13 +122,6 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, throw std::invalid_argument(msg); } - // Require R be positive definite - if (!IsPositiveDefinite(R)) { - std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n", - Eigen::MatrixXd{R}); - throw std::invalid_argument(msg); - } - // Require (A, B) pair be stabilizable if (!IsStabilizable(A, B)) { std::string msg = @@ -169,9 +132,8 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, // Require (A, C) pair be detectable where Q = CᵀC { - Eigen::LDLT ldlt{Q}; - Eigen::MatrixXd C = Eigen::MatrixXd{ldlt.matrixL()} * - ldlt.vectorD().cwiseSqrt().asDiagonal(); + Eigen::MatrixXd C = Eigen::MatrixXd{Q_ldlt.matrixL()} * + Q_ldlt.vectorD().cwiseSqrt().asDiagonal(); if (!IsDetectable(A, C)) { std::string msg = fmt::format( @@ -182,61 +144,7 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, } } - // Implements the SSCA algorithm on page 12 of [1]. - - // A₀ = A - Eigen::MatrixXd A_k = A; - Eigen::MatrixXd A_k1 = A; - - // G₀ = BR⁻¹Bᵀ - // - // See equation (4) of [1]. - Eigen::MatrixXd G_k = B * R.llt().solve(B.transpose()); - - Eigen::MatrixXd I = Eigen::MatrixXd::Identity(A.rows(), A.cols()); - - // H₀ = Q - // - // See equation (4) of [1]. - Eigen::MatrixXd H_k = Q; - Eigen::MatrixXd H_k1 = Q; - - do { - A_k = A_k1; - H_k = H_k1; - - // W = I + HₖGₖ - Eigen::MatrixXd W = I + H_k * G_k; - - // W is symmetric positive definite, so the LLT decomposition would work - // here and is faster than the householder QR decomposition [2]. However, - // it's not accurate enough. Experimentation showed that so many iterations - // of iterative refinement [3] were required to fix the accuracy that the - // total system solve time was much higher than householder QR. - // - // [2] https://eigen.tuxfamily.org/dox/group__TutorialLinearAlgebra.html - // [3] https://en.wikipedia.org/wiki/Iterative_refinement - auto W_solver = W.householderQr(); - - // Solve WV₁ = Aₖᵀ for V₁ - Eigen::MatrixXd V_1 = W_solver.solve(A_k.transpose()); - - // Solve WV₂ = Hₖ for V₂ - Eigen::MatrixXd V_2 = W_solver.solve(H_k); - - // Aₖ₊₁ = V₁ᵀAₖ - A_k1 = V_1.transpose() * A_k; - - // Gₖ₊₁ = Gₖ + AₖGₖV₁ - G_k += A_k * G_k * V_1; - - // Hₖ₊₁ = Hₖ + AₖᵀV₂Aₖ - H_k1 = H_k + A_k.transpose() * V_2 * A_k; - - // while |Hₖ₊₁ − Hₖ| > ε |Hₖ₊₁| - } while ((H_k1 - H_k).norm() > 1e-10 * H_k1.norm()); - - return H_k1; + return internal::DARE(A, B, Q, R); } Eigen::MatrixXd DARE(const Eigen::Ref& A, @@ -244,8 +152,19 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, const Eigen::Ref& Q, const Eigen::Ref& R, const Eigen::Ref& N) { + // These are unused if assertions aren't compiled in + [[maybe_unused]] int states = A.rows(); + [[maybe_unused]] int inputs = B.cols(); + // Check argument dimensions - assert(N.rows() == B.rows() && N.cols() == B.cols()); + assert(N.rows() == states && N.cols() == inputs); + + auto R_llt = R.llt(); + if (R_llt.info() != Eigen::Success) { + std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n", + Eigen::MatrixXd{R}); + throw std::invalid_argument(msg); + } // This is a change of variables to make the DARE that includes Q, R, and N // cost matrices fit the form of the DARE that includes only Q and R cost @@ -258,8 +177,108 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, // where A₂ and Q₂ are a change of variables: // // A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ - return DARE(A - B * R.llt().solve(N.transpose()), B, - Q - N * R.llt().solve(N.transpose()), R); + return DARE(A - B * R_llt.solve(N.transpose()), B, + Q - N * R_llt.solve(N.transpose()), R); } +namespace internal { + +Eigen::MatrixXd DARE(const Eigen::Ref& A, + const Eigen::Ref& B, + const Eigen::Ref& Q, + const Eigen::Ref& R) { + // Require R be positive definite + auto R_llt = R.llt(); + if (R_llt.info() != Eigen::Success) { + std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n", + Eigen::MatrixXd{R}); + throw std::invalid_argument(msg); + } + + // Implements the SDA algorithm on page 5 of [1]. + + // A₀ = A + Eigen::MatrixXd A_k = A; + + // G₀ = BR⁻¹Bᵀ + // + // See equation (4) of [1]. + Eigen::MatrixXd G_k = B * R_llt.solve(B.transpose()); + + // H₀ = Q + // + // See equation (4) of [1]. + Eigen::MatrixXd H_k; + Eigen::MatrixXd H_k1 = Q; + + do { + H_k = H_k1; + + // W = I + GₖHₖ + Eigen::MatrixXd W = + Eigen::MatrixXd::Identity(H_k.rows(), H_k.cols()) + G_k * H_k; + + auto W_solver = W.lu(); + + // Solve WV₁ = Aₖ for V₁ + Eigen::MatrixXd V_1 = W_solver.solve(A_k); + + // Solve V₂Wᵀ = Gₖ for V₂ + // + // We want to put V₂Wᵀ = Gₖ into Ax = b form so we can solve it more + // efficiently. + // + // V₂Wᵀ = Gₖ + // (V₂Wᵀ)ᵀ = Gₖᵀ + // WV₂ᵀ = Gₖᵀ + // + // The solution of Ax = b can be found via x = A.solve(b). + // + // V₂ᵀ = W.solve(Gₖᵀ) + // V₂ = W.solve(Gₖᵀ)ᵀ + Eigen::MatrixXd V_2 = W_solver.solve(G_k.transpose()).transpose(); + + // Gₖ₊₁ = Gₖ + AₖV₂Aₖᵀ + G_k += A_k * V_2 * A_k.transpose(); + + // Hₖ₊₁ = Hₖ + V₁ᵀHₖAₖ + H_k1 = H_k + V_1.transpose() * H_k * A_k; + + // Aₖ₊₁ = AₖV₁ + A_k *= V_1; + + // while |Hₖ₊₁ − Hₖ| > ε |Hₖ₊₁| + } while ((H_k1 - H_k).norm() > 1e-10 * H_k1.norm()); + + return H_k1; +} + +Eigen::MatrixXd DARE(const Eigen::Ref& A, + const Eigen::Ref& B, + const Eigen::Ref& Q, + const Eigen::Ref& R, + const Eigen::Ref& N) { + auto R_llt = R.llt(); + if (R_llt.info() != Eigen::Success) { + std::string msg = fmt::format("R isn't positive definite!\n\nR =\n{}\n", + Eigen::MatrixXd{R}); + throw std::invalid_argument(msg); + } + + // This is a change of variables to make the DARE that includes Q, R, and N + // cost matrices fit the form of the DARE that includes only Q and R cost + // matrices. + // + // This is equivalent to solving the original DARE: + // + // A₂ᵀXA₂ − X − A₂ᵀXB(BᵀXB + R)⁻¹BᵀXA₂ + Q₂ = 0 + // + // where A₂ and Q₂ are a change of variables: + // + // A₂ = A − BR⁻¹Nᵀ and Q₂ = Q − NR⁻¹Nᵀ + return internal::DARE(A - B * R_llt.solve(N.transpose()), B, + Q - N * R_llt.solve(N.transpose()), R); +} + +} // namespace internal } // namespace frc diff --git a/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp b/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp index bfc07cc16f..a674ecf557 100644 --- a/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp +++ b/wpimath/src/main/native/cpp/jni/WPIMathJNI.cpp @@ -5,6 +5,7 @@ #include #include +#include #include @@ -102,11 +103,11 @@ extern "C" { /* * Class: edu_wpi_first_math_WPIMathJNI - * Method: dare + * Method: dareABQR * Signature: ([D[D[D[DII[D)V */ JNIEXPORT void JNICALL -Java_edu_wpi_first_math_WPIMathJNI_dare +Java_edu_wpi_first_math_WPIMathJNI_dareABQR (JNIEnv* env, jclass, jdoubleArray A, jdoubleArray B, jdoubleArray Q, jdoubleArray R, jint states, jint inputs, jdoubleArray S) { @@ -137,8 +138,58 @@ Java_edu_wpi_first_math_WPIMathJNI_dare env->ReleaseDoubleArrayElements(R, nativeR, 0); env->SetDoubleArrayRegion(S, 0, states * states, result.data()); - } catch (const std::runtime_error& e) { - jclass cls = env->FindClass("java/lang/RuntimeException"); + } catch (const std::invalid_argument& e) { + jclass cls = env->FindClass("java/lang/IllegalArgumentException"); + if (cls) { + env->ThrowNew(cls, e.what()); + } + } +} + +/* + * Class: edu_wpi_first_math_WPIMathJNI + * Method: dareABQRN + * Signature: ([D[D[D[D[DII[D)V + */ +JNIEXPORT void JNICALL +Java_edu_wpi_first_math_WPIMathJNI_dareABQRN + (JNIEnv* env, jclass, jdoubleArray A, jdoubleArray B, jdoubleArray Q, + jdoubleArray R, jdoubleArray N, jint states, jint inputs, jdoubleArray S) +{ + jdouble* nativeA = env->GetDoubleArrayElements(A, nullptr); + jdouble* nativeB = env->GetDoubleArrayElements(B, nullptr); + jdouble* nativeQ = env->GetDoubleArrayElements(Q, nullptr); + jdouble* nativeR = env->GetDoubleArrayElements(R, nullptr); + jdouble* nativeN = env->GetDoubleArrayElements(N, nullptr); + + Eigen::Map< + Eigen::Matrix> + Amat{nativeA, states, states}; + Eigen::Map< + Eigen::Matrix> + Bmat{nativeB, states, inputs}; + Eigen::Map< + Eigen::Matrix> + Qmat{nativeQ, states, states}; + Eigen::Map< + Eigen::Matrix> + Rmat{nativeR, inputs, inputs}; + Eigen::Map< + Eigen::Matrix> + Nmat{nativeN, states, inputs}; + + try { + Eigen::MatrixXd result = frc::DARE(Amat, Bmat, Qmat, Rmat, Nmat); + + env->ReleaseDoubleArrayElements(A, nativeA, 0); + env->ReleaseDoubleArrayElements(B, nativeB, 0); + env->ReleaseDoubleArrayElements(Q, nativeQ, 0); + env->ReleaseDoubleArrayElements(R, nativeR, 0); + env->ReleaseDoubleArrayElements(N, nativeN, 0); + + env->SetDoubleArrayRegion(S, 0, states * states, result.data()); + } catch (const std::invalid_argument& e) { + jclass cls = env->FindClass("java/lang/IllegalArgumentException"); if (cls) { env->ThrowNew(cls, e.what()); } diff --git a/wpimath/src/main/native/include/frc/DARE.h b/wpimath/src/main/native/include/frc/DARE.h index b6c3de5b06..c0220a1cd8 100644 --- a/wpimath/src/main/native/include/frc/DARE.h +++ b/wpimath/src/main/native/include/frc/DARE.h @@ -11,7 +11,6 @@ namespace frc { /** - * * Computes the unique stabilizing solution X to the discrete-time algebraic * Riccati equation: * @@ -21,8 +20,6 @@ namespace frc { * @param B The input matrix. * @param Q The state cost matrix. * @param R The input cost matrix. - * @throws std::invalid_argument if number of inputs is greater than number of - * states. * @throws std::invalid_argument if Q isn't symmetric positive semidefinite. * @throws std::invalid_argument if R isn't symmetric positive definite. * @throws std::invalid_argument if the (A, B) pair isn't stabilizable. @@ -73,8 +70,6 @@ J = Σ [uₖ] [0 R][uₖ] ΔT @param Q The state cost matrix. @param R The input cost matrix. @param N The state-input cross cost matrix. -@throws std::invalid_argument if number of inputs is greater than number of - states. @throws std::invalid_argument if Q − NR⁻¹Nᵀ isn't symmetric positive semidefinite. @throws std::invalid_argument if R isn't symmetric positive definite. @@ -88,4 +83,90 @@ Eigen::MatrixXd DARE(const Eigen::Ref& A, const Eigen::Ref& R, const Eigen::Ref& N); +namespace internal { + +/** + * Computes the unique stabilizing solution X to the discrete-time algebraic + * Riccati equation: + * + * AᵀXA − X − AᵀXB(BᵀXB + R)⁻¹BᵀXA + Q = 0 + * + * This internal function skips expensive precondition checks for increased + * performance. The solver may hang if any of the following occur: + *
    + *
  • Q isn't symmetric positive semidefinite
  • + *
  • R isn't symmetric positive definite
  • + *
  • The (A, B) pair isn't stabilizable
  • + *
  • The (A, C) pair where Q = CᵀC isn't detectable
  • + *
+ * Only use this function if you're sure the preconditions are met. + * + * @param A The system matrix. + * @param B The input matrix. + * @param Q The state cost matrix. + * @param R The input cost matrix. + */ +WPILIB_DLLEXPORT +Eigen::MatrixXd DARE(const Eigen::Ref& A, + const Eigen::Ref& B, + const Eigen::Ref& Q, + const Eigen::Ref& R); + +/** +Computes the unique stabilizing solution X to the discrete-time algebraic +Riccati equation: + + AᵀXA − X − (AᵀXB + N)(BᵀXB + R)⁻¹(BᵀXA + Nᵀ) + Q = 0 + +This overload of the DARE is useful for finding the control law uₖ that +minimizes the following cost function subject to xₖ₊₁ = Axₖ + Buₖ. + +@verbatim + ∞ [xₖ]ᵀ[Q N][xₖ] +J = Σ [uₖ] [Nᵀ R][uₖ] ΔT + k=0 +@endverbatim + +This is a more general form of the following. The linear-quadratic regulator +is the feedback control law uₖ that minimizes the following cost function +subject to xₖ₊₁ = Axₖ + Buₖ: + +@verbatim + ∞ +J = Σ (xₖᵀQxₖ + uₖᵀRuₖ) ΔT + k=0 +@endverbatim + +This can be refactored as: + +@verbatim + ∞ [xₖ]ᵀ[Q 0][xₖ] +J = Σ [uₖ] [0 R][uₖ] ΔT + k=0 +@endverbatim + +This internal function skips expensive precondition checks for increased +performance. The solver may hang if any of the following occur: +
    +
  • Q − NR⁻¹Nᵀ isn't symmetric positive semidefinite
  • +
  • R isn't symmetric positive definite
  • +
  • The (A, B) pair isn't stabilizable
  • +
  • The (A, C) pair where Q = CᵀC isn't detectable
  • +
+Only use this function if you're sure the preconditions are met. + +@param A The system matrix. +@param B The input matrix. +@param Q The state cost matrix. +@param R The input cost matrix. +@param N The state-input cross cost matrix. +*/ +WPILIB_DLLEXPORT +Eigen::MatrixXd DARE(const Eigen::Ref& A, + const Eigen::Ref& B, + const Eigen::Ref& Q, + const Eigen::Ref& R, + const Eigen::Ref& N); + +} // namespace internal } // namespace frc diff --git a/wpimath/src/test/java/edu/wpi/first/math/DARETest.java b/wpimath/src/test/java/edu/wpi/first/math/DARETest.java index f1a47ac74b..6dddaefef9 100644 --- a/wpimath/src/test/java/edu/wpi/first/math/DARETest.java +++ b/wpimath/src/test/java/edu/wpi/first/math/DARETest.java @@ -5,6 +5,7 @@ package edu.wpi.first.math; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import edu.wpi.first.wpilibj.UtilityClassTest; import org.ejml.simple.SimpleMatrix; @@ -187,4 +188,119 @@ class DARETest extends UtilityClassTest { assertMatrixEqual(X, X.transpose()); assertDARESolution(A, B, Q, R, N, X); } + + @Test + void testMoreInputsThanStates_ABQR() { + var A = SimpleMatrix.identity(2); + var B = new SimpleMatrix(2, 3, true, new double[] {1, 0, 0, 0, 0.5, 0.3}); + var Q = SimpleMatrix.identity(2); + var R = SimpleMatrix.identity(3); + + var X = DARE.dare(A, B, Q, R); + assertMatrixEqual(X, X.transpose()); + assertDARESolution(A, B, Q, R, X); + } + + @Test + void testMoreInputsThanStates_ABQRN() { + var A = SimpleMatrix.identity(2); + var B = new SimpleMatrix(2, 3, true, new double[] {1, 0, 0, 0, 0.5, 0.3}); + var Q = SimpleMatrix.identity(2); + var R = SimpleMatrix.identity(3); + var N = new SimpleMatrix(2, 3, true, new double[] {1, 0, 0, 0, 1, 0}); + + var X = DARE.dare(A, B, Q, R, N); + assertMatrixEqual(X, X.transpose()); + assertDARESolution(A, B, Q, R, N, X); + } + + @Test + void testQNotSymmetricPositiveSemidefinite_ABQR() { + var A = SimpleMatrix.identity(2); + var B = SimpleMatrix.identity(2); + var Q = SimpleMatrix.diag(-1.0, -1.0); + var R = SimpleMatrix.identity(2); + + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R)); + } + + @Test + void testQNotSymmetricPositiveSemidefinite_ABQRN() { + var A = SimpleMatrix.identity(2); + var B = SimpleMatrix.identity(2); + var Q = SimpleMatrix.identity(2); + var R = SimpleMatrix.diag(-1.0, -1.0); + var N = SimpleMatrix.diag(2.0, 2.0); + + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R, N)); + } + + @Test + void testRNotSymmetricPositiveDefinite_ABQR() { + var A = SimpleMatrix.identity(2); + var B = SimpleMatrix.identity(2); + var Q = SimpleMatrix.identity(2); + + var R1 = new SimpleMatrix(2, 2); + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R1)); + + var R2 = SimpleMatrix.diag(-1.0, -1.0); + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R2)); + } + + @Test + void testRNotSymmetricPositiveDefinite_ABQRN() { + var A = SimpleMatrix.identity(2); + var B = SimpleMatrix.identity(2); + var Q = SimpleMatrix.identity(2); + var N = SimpleMatrix.identity(2); + + var R1 = new SimpleMatrix(2, 2); + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R1, N)); + + var R2 = SimpleMatrix.diag(-1.0, -1.0); + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R2, N)); + } + + @Test + void testABNotStabilizable_ABQR() { + var A = SimpleMatrix.identity(2); + var B = new SimpleMatrix(2, 2); + var Q = SimpleMatrix.identity(2); + var R = SimpleMatrix.identity(2); + + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R)); + } + + @Test + void testABNotStabilizable_ABQRN() { + var A = SimpleMatrix.identity(2); + var B = new SimpleMatrix(2, 2); + var Q = SimpleMatrix.identity(2); + var R = SimpleMatrix.identity(2); + var N = SimpleMatrix.identity(2); + + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R, N)); + } + + @Test + void testACNotDetectable_ABQR() { + var A = SimpleMatrix.identity(2); + var B = SimpleMatrix.identity(2); + var Q = new SimpleMatrix(2, 2); + var R = SimpleMatrix.identity(2); + + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R)); + } + + @Test + void testACNotDetectable_ABQRN() { + var A = SimpleMatrix.identity(2); + var B = SimpleMatrix.identity(2); + var Q = new SimpleMatrix(2, 2); + var R = SimpleMatrix.identity(2); + var N = new SimpleMatrix(2, 2); + + assertThrows(IllegalArgumentException.class, () -> DARE.dare(A, B, Q, R, N)); + } } diff --git a/wpimath/src/test/native/cpp/DARETest.cpp b/wpimath/src/test/native/cpp/DARETest.cpp index 51397bfbf4..97fd412eb1 100644 --- a/wpimath/src/test/native/cpp/DARETest.cpp +++ b/wpimath/src/test/native/cpp/DARETest.cpp @@ -2,6 +2,8 @@ // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. +#include + #include #include "Eigen/Core" @@ -71,7 +73,6 @@ void ExpectDARESolution(const Eigen::Ref& A, ExpectMatrixEqual(Y, Eigen::MatrixXd::Zero(X.rows(), X.cols()), 1e-10); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, NonInvertibleA_ABQR) { // Example 2 of "On the Numerical Solution of the Discrete-Time Algebraic // Riccati Equation" @@ -91,7 +92,6 @@ TEST(DARETest, NonInvertibleA_ABQR) { ExpectDARESolution(A, B, Q, R, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, NonInvertibleA_ABQRN) { // Example 2 of "On the Numerical Solution of the Discrete-Time Algebraic // Riccati Equation" @@ -117,7 +117,6 @@ TEST(DARETest, NonInvertibleA_ABQRN) { ExpectDARESolution(A, B, Q, R, N, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, InvertibleA_ABQR) { Eigen::MatrixXd A{2, 2}; A << 1, 1, 0, 1; @@ -134,7 +133,6 @@ TEST(DARETest, InvertibleA_ABQR) { ExpectDARESolution(A, B, Q, R, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, InvertibleA_ABQRN) { Eigen::MatrixXd A{2, 2}; A << 1, 1, 0, 1; @@ -157,7 +155,6 @@ TEST(DARETest, InvertibleA_ABQRN) { ExpectDARESolution(A, B, Q, R, N, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, FirstGeneralizedEigenvalueOfSTIsStable_ABQR) { // The first generalized eigenvalue of (S, T) is stable @@ -176,7 +173,6 @@ TEST(DARETest, FirstGeneralizedEigenvalueOfSTIsStable_ABQR) { ExpectDARESolution(A, B, Q, R, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, FirstGeneralizedEigenvalueOfSTIsStable_ABQRN) { // The first generalized eigenvalue of (S, T) is stable @@ -201,7 +197,6 @@ TEST(DARETest, FirstGeneralizedEigenvalueOfSTIsStable_ABQRN) { ExpectDARESolution(A, B, Q, R, N, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, IdentitySystem_ABQR) { const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; @@ -214,7 +209,6 @@ TEST(DARETest, IdentitySystem_ABQR) { ExpectDARESolution(A, B, Q, R, X); } -// NOLINTNEXTLINE(google-readability-avoid-underscore-in-googletest-name) TEST(DARETest, IdentitySystem_ABQRN) { const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; @@ -227,3 +221,110 @@ TEST(DARETest, IdentitySystem_ABQRN) { ExpectPositiveSemidefinite(X); ExpectDARESolution(A, B, Q, R, N, X); } + +TEST(DARETest, MoreInputsThanStates_ABQR) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{{1.0, 0.0, 0.0}, {0.0, 0.5, 0.3}}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd R{Eigen::Matrix3d::Identity()}; + + Eigen::MatrixXd X = frc::DARE(A, B, Q, R); + ExpectMatrixEqual(X, X.transpose(), 1e-10); + ExpectPositiveSemidefinite(X); + ExpectDARESolution(A, B, Q, R, X); +} + +TEST(DARETest, MoreInputsThanStates_ABQRN) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{{1.0, 0.0, 0.0}, {0.0, 0.5, 0.3}}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd R{Eigen::Matrix3d::Identity()}; + const Eigen::MatrixXd N{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}}; + + Eigen::MatrixXd X = frc::DARE(A, B, Q, R, N); + ExpectMatrixEqual(X, X.transpose(), 1e-10); + ExpectPositiveSemidefinite(X); + ExpectDARESolution(A, B, Q, R, N, X); +} + +TEST(DARETest, QNotSymmetricPositiveSemidefinite_ABQR) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd Q{-Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd R{Eigen::Matrix2d::Identity()}; + + EXPECT_THROW(frc::DARE(A, B, Q, R), std::invalid_argument); +} + +TEST(DARETest, QNotSymmetricPositiveSemidefinite_ABQRN) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd R{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd N{2.0 * Eigen::Matrix2d::Identity()}; + + EXPECT_THROW(frc::DARE(A, B, Q, R, N), std::invalid_argument); +} + +TEST(DARETest, RNotSymmetricPositiveDefinite_ABQR) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + + const Eigen::MatrixXd R1{Eigen::Matrix2d::Zero()}; + EXPECT_THROW(frc::DARE(A, B, Q, R1), std::invalid_argument); + + const Eigen::MatrixXd R2{-Eigen::Matrix2d::Identity()}; + EXPECT_THROW(frc::DARE(A, B, Q, R2), std::invalid_argument); +} + +TEST(DARETest, RNotSymmetricPositiveDefinite_ABQRN) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd N{Eigen::Matrix2d::Identity()}; + + const Eigen::MatrixXd R1{Eigen::Matrix2d::Zero()}; + EXPECT_THROW(frc::DARE(A, B, Q, R1, N), std::invalid_argument); + + const Eigen::MatrixXd R2{-Eigen::Matrix2d::Identity()}; + EXPECT_THROW(frc::DARE(A, B, Q, R2, N), std::invalid_argument); +} + +TEST(DARETest, ABNotStabilizable_ABQR) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Zero()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd R{Eigen::Matrix2d::Identity()}; + + EXPECT_THROW(frc::DARE(A, B, Q, R), std::invalid_argument); +} + +TEST(DARETest, ABNotStabilizable_ABQRN) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Zero()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd R{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd N{Eigen::Matrix2d::Identity()}; + + EXPECT_THROW(frc::DARE(A, B, Q, R, N), std::invalid_argument); +} + +TEST(DARETest, ACNotDetectable_ABQR) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Zero()}; + const Eigen::MatrixXd R{Eigen::Matrix2d::Identity()}; + + EXPECT_THROW(frc::DARE(A, B, Q, R), std::invalid_argument); +} + +TEST(DARETest, ACNotDetectable_ABQRN) { + const Eigen::MatrixXd A{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd B{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd Q{Eigen::Matrix2d::Zero()}; + const Eigen::MatrixXd R{Eigen::Matrix2d::Identity()}; + const Eigen::MatrixXd N{Eigen::Matrix2d::Zero()}; + + EXPECT_THROW(frc::DARE(A, B, Q, R, N), std::invalid_argument); +}