[wpimath] Rewrite DARE solver (#5328)

I timed the DARE unit tests, and the new solver is 0 to 100% faster in
all cases (that is, it's at least as fast as Drake's and up to 2x faster
in some cases).

The new solver is also much simpler, takes less time to compile, and
drops the libwpimath.so size from 325 MB to 301 MB.

I think most of the compilation time is coming from the eigenvalue
decompositions used to enforce argument preconditions.
This commit is contained in:
Tyler Veness
2023-05-14 22:23:00 -07:00
committed by GitHub
parent 3876a2523a
commit 52bd5b972d
32 changed files with 831 additions and 2024 deletions

View File

@@ -6,8 +6,8 @@ package edu.wpi.first.math;
import org.ejml.simple.SimpleMatrix;
public final class Drake {
private Drake() {}
public final class DARE {
private DARE() {}
/**
* Solves the discrete algebraic Riccati equation.
@@ -18,10 +18,9 @@ public final class Drake {
* @param R Input cost matrix.
* @return Solution of DARE.
*/
public static SimpleMatrix discreteAlgebraicRiccatiEquation(
SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) {
public static SimpleMatrix dare(SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R) {
var S = new SimpleMatrix(A.numRows(), A.numCols());
WPIMathJNI.discreteAlgebraicRiccatiEquation(
WPIMathJNI.dare(
A.getDDRM().getData(),
B.getDDRM().getData(),
Q.getDDRM().getData(),
@@ -43,15 +42,12 @@ public final class Drake {
* @param R Input cost matrix.
* @return Solution of DARE.
*/
public static <States extends Num, Inputs extends Num>
Matrix<States, States> discreteAlgebraicRiccatiEquation(
Matrix<States, States> A,
Matrix<States, Inputs> B,
Matrix<States, States> Q,
Matrix<Inputs, Inputs> R) {
return new Matrix<>(
discreteAlgebraicRiccatiEquation(
A.getStorage(), B.getStorage(), Q.getStorage(), R.getStorage()));
public static <States extends Num, Inputs extends Num> Matrix<States, States> dare(
Matrix<States, States> A,
Matrix<States, Inputs> B,
Matrix<States, States> Q,
Matrix<Inputs, Inputs> R) {
return new Matrix<>(dare(A.getStorage(), B.getStorage(), Q.getStorage(), R.getStorage()));
}
/**
@@ -64,7 +60,7 @@ public final class Drake {
* @param N State-input cross-term cost matrix.
* @return Solution of DARE.
*/
public static SimpleMatrix discreteAlgebraicRiccatiEquation(
public static SimpleMatrix dare(
SimpleMatrix A, SimpleMatrix B, SimpleMatrix Q, SimpleMatrix R, SimpleMatrix N) {
// See
// https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
@@ -73,7 +69,7 @@ public final class Drake {
var scrQ = Q.minus(N.mult(R.solve(N.transpose())));
var S = new SimpleMatrix(A.numRows(), A.numCols());
WPIMathJNI.discreteAlgebraicRiccatiEquation(
WPIMathJNI.dare(
scrA.getDDRM().getData(),
B.getDDRM().getData(),
scrQ.getDDRM().getData(),
@@ -96,21 +92,28 @@ public final class Drake {
* @param N State-input cross-term cost matrix.
* @return Solution of DARE.
*/
public static <States extends Num, Inputs extends Num>
Matrix<States, States> discreteAlgebraicRiccatiEquation(
Matrix<States, States> A,
Matrix<States, Inputs> B,
Matrix<States, States> Q,
Matrix<Inputs, Inputs> R,
Matrix<States, Inputs> N) {
// See
// https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_discrete-time_LQR
// for the change of variables used here.
var scrA = A.minus(B.times(R.solve(N.transpose())));
var scrQ = Q.minus(N.times(R.solve(N.transpose())));
public static <States extends Num, Inputs extends Num> Matrix<States, States> dare(
Matrix<States, States> A,
Matrix<States, Inputs> B,
Matrix<States, States> Q,
Matrix<Inputs, Inputs> R,
Matrix<States, Inputs> N) {
// This is a change of variables to make the DARE that includes Q, R, and N
// cost matrices fit the form of the DARE that includes only Q and R cost
// matrices.
//
// This is equivalent to solving the original DARE:
//
// AᵀXA X AᵀXB(BᵀXB + R)¹BᵀXA + Q = 0
//
// where A and Q are a change of variables:
//
// A = A BR¹Nᵀ and Q = Q NR¹Nᵀ
return new Matrix<>(
discreteAlgebraicRiccatiEquation(
scrA.getStorage(), B.getStorage(), scrQ.getStorage(), R.getStorage()));
dare(
A.minus(B.times(R.solve(N.transpose()))).getStorage(),
B.getStorage(),
Q.minus(N.times(R.solve(N.transpose()))).getStorage(),
R.getStorage()));
}
}

View File

@@ -54,7 +54,7 @@ public final class WPIMathJNI {
* @param inputs Number of inputs in B matrix.
* @param S Array storage for DARE solution.
*/
public static native void discreteAlgebraicRiccatiEquation(
public static native void dare(
double[] A, double[] B, double[] Q, double[] R, int states, int inputs, double[] S);
/**

View File

@@ -4,7 +4,7 @@
package edu.wpi.first.math.controller;
import edu.wpi.first.math.Drake;
import edu.wpi.first.math.DARE;
import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Num;
@@ -111,7 +111,7 @@ public class LinearQuadraticRegulator<States extends Num, Inputs extends Num, Ou
throw new IllegalArgumentException(msg);
}
var S = Drake.discreteAlgebraicRiccatiEquation(discA, discB, Q, R);
var S = DARE.dare(discA, discB, Q, R);
// K = (BᵀSB + R)⁻¹BᵀSA
m_K =
@@ -150,7 +150,7 @@ public class LinearQuadraticRegulator<States extends Num, Inputs extends Num, Ou
var discA = discABPair.getFirst();
var discB = discABPair.getSecond();
var S = Drake.discreteAlgebraicRiccatiEquation(discA, discB, Q, R, N);
var S = DARE.dare(discA, discB, Q, R, N);
// K = (BᵀSB + R)⁻¹(BᵀSA + Nᵀ)
m_K =

View File

@@ -4,7 +4,7 @@
package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Drake;
import edu.wpi.first.math.DARE;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.Num;
@@ -145,8 +145,7 @@ public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Output
final var discR = Discretization.discretizeR(m_contR, dtSeconds);
if (StateSpaceUtil.isDetectable(discA, C) && outputs.getNum() <= states.getNum()) {
m_initP =
Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR);
m_initP = DARE.dare(discA.transpose(), C.transpose(), discQ, discR);
} else {
m_initP = new Matrix<>(states, states);
}

View File

@@ -4,7 +4,7 @@
package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Drake;
import edu.wpi.first.math.DARE;
import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
@@ -87,9 +87,7 @@ public class KalmanFilter<States extends Num, Inputs extends Num, Outputs extend
throw new IllegalArgumentException(msg);
}
var P =
new Matrix<>(
Drake.discreteAlgebraicRiccatiEquation(discA.transpose(), C.transpose(), discQ, discR));
var P = new Matrix<>(DARE.dare(discA.transpose(), C.transpose(), discQ, discR));
// S = CPCᵀ + R
var S = C.times(P).times(C.transpose()).plus(discR);