[wpimath] Replace UKF implementation with square root form (#4168)

Co-authored-by: Tyler Veness <calcmogul@gmail.com>
This commit is contained in:
Connor Worley
2022-06-08 22:19:01 -07:00
committed by GitHub
parent 45b7fc445b
commit a99c11c14c
22 changed files with 494 additions and 297 deletions

View File

@@ -323,6 +323,10 @@ public class Matrix<R extends Num, C extends Num> {
* <p>The matrix equation could also be written as x = A<sup>-1</sup>b. Where the pseudo inverse
* is used if A is not square.
*
* <p>Note that this method does not support solving using a QR decomposition with full-pivoting,
* as only column-pivoting is supported. For full-pivoting, use {@link
* #solveFullPivHouseholderQr}.
*
* @param <C2> Columns in b.
* @param b The right-hand side of the equation to solve.
* @return The solution to the linear system.
@@ -332,6 +336,29 @@ public class Matrix<R extends Num, C extends Num> {
return new Matrix<>(this.m_storage.solve(Objects.requireNonNull(b).m_storage));
}
/**
* Solves the least-squares problem Ax=B using a QR decomposition with full pivoting, where this
* matrix is A.
*
* @param <R2> Number of rows in B.
* @param <C2> Number of columns in B.
* @param other The B matrix.
* @return The solution matrix.
*/
public final <R2 extends Num, C2 extends Num> Matrix<C, C2> solveFullPivHouseholderQr(
Matrix<R2, C2> other) {
Matrix<C, C2> solution = new Matrix<>(new SimpleMatrix(this.getNumCols(), other.getNumCols()));
WPIMathJNI.solveFullPivHouseholderQr(
this.getData(),
this.getNumRows(),
this.getNumCols(),
other.getData(),
other.getNumRows(),
other.getNumCols(),
solution.getData());
return solution;
}
/**
* Computes the matrix exponential using Eigen's solver. This method only works for square
* matrices, and will otherwise throw an {@link MatrixDimensionException}.
@@ -677,6 +704,20 @@ public class Matrix<R extends Num, C extends Num> {
this.m_storage.getDDRM(), other.m_storage.getDDRM(), tolerance);
}
/**
* Performs an inplace Cholesky rank update (or downdate).
*
* <p>If this matrix contains L where A = LL<sup>&top;</sup> before the update, it will contain L
* where LL<sup>&top;</sup> = A + &sigma;vv<sup>&top;</sup> after the update.
*
* @param v Vector to use for the update.
* @param sigma Sigma to use for the update.
* @param lowerTriangular Whether or not this matrix is lower triangular.
*/
public void rankUpdate(Matrix<R, N1> v, double sigma, boolean lowerTriangular) {
WPIMathJNI.rankUpdate(this.getData(), this.getNumRows(), v.getData(), sigma, lowerTriangular);
}
@Override
public String toString() {
return m_storage.toString();

View File

@@ -125,6 +125,19 @@ public final class WPIMathJNI {
*/
public static native String serializeTrajectory(double[] elements);
/**
* Performs an inplace rank one update (or downdate) of an upper triangular Cholesky decomposition
* matrix.
*
* @param mat Array of elements of the matrix to be updated.
* @param lowerTriangular Whether or not mat is lower triangular.
* @param rows How many rows there are.
* @param vec Vector to use for the rank update.
* @param sigma Sigma value to use for the rank update.
*/
public static native void rankUpdate(
double[] mat, int rows, double[] vec, double sigma, boolean lowerTriangular);
public static class Helper {
private static AtomicBoolean extractOnStaticLoad = new AtomicBoolean(true);
@@ -136,4 +149,18 @@ public final class WPIMathJNI {
extractOnStaticLoad.set(load);
}
}
/**
* Solves the least-squares problem Ax=B using a QR decomposition with full pivoting.
*
* @param A Array of elements of the A matrix.
* @param Arows Number of rows of the A matrix.
* @param Acols Number of rows of the A matrix.
* @param B Array of elements of the B matrix.
* @param Brows Number of rows of the B matrix.
* @param Bcols Number of rows of the B matrix.
* @param dst Array to store solution in. If A is m-n and B is m-p, dst is n-p.
*/
public static native void solveFullPivHouseholderQr(
double[] A, int Arows, int Acols, double[] B, int Brows, int Bcols, double[] dst);
}

View File

@@ -71,16 +71,16 @@ public class MerweScaledSigmaPoints<S extends Num> {
* of the filter.
*
* @param x An array of the means.
* @param P Covariance of the filter.
* @param s Square-root covariance of the filter.
* @return Two dimensional array of sigma points. Each column contains all of the sigmas for one
* dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public Matrix<S, ?> sigmaPoints(Matrix<S, N1> x, Matrix<S, S> P) {
public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) {
double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
double eta = Math.sqrt(lambda + m_states.getNum());
var intermediate = P.times(lambda + m_states.getNum());
var U = intermediate.lltDecompose(true); // Lower triangular
Matrix<S, S> U = s.times(eta);
// 2 * states + 1 by states
Matrix<S, ?> sigmas =

View File

@@ -14,6 +14,7 @@ import edu.wpi.first.math.system.Discretization;
import edu.wpi.first.math.system.NumericalIntegration;
import edu.wpi.first.math.system.NumericalJacobian;
import java.util.function.BiFunction;
import org.ejml.dense.row.decomposition.qr.QRDecompositionHouseholder_DDRM;
import org.ejml.simple.SimpleMatrix;
/**
@@ -33,6 +34,9 @@ import org.ejml.simple.SimpleMatrix;
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
* theory".
*
* <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more
* information about the SR-UKF, see https://www.researchgate.net/publication/3908304.
*/
@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
@@ -50,7 +54,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
private Matrix<States, N1> m_xHat;
private Matrix<States, States> m_P;
private Matrix<States, States> m_S;
private final Matrix<States, States> m_contQ;
private final Matrix<Outputs, Outputs> m_contR;
private Matrix<States, ?> m_sigmasF;
@@ -152,14 +156,16 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
}
@SuppressWarnings({"ParameterName", "LocalVariableName"})
static <S extends Num, C extends Num> Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
Nat<S> s,
Nat<C> dim,
Matrix<C, ?> sigmas,
Matrix<?, N1> Wm,
Matrix<?, N1> Wc,
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc) {
static <S extends Num, C extends Num>
Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
Nat<S> s,
Nat<C> dim,
Matrix<C, ?> sigmas,
Matrix<?, N1> Wm,
Matrix<?, N1> Wc,
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc,
Matrix<C, C> squareRootR) {
if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
throw new IllegalArgumentException(
"Sigmas must be covDim by 2 * states + 1! Got "
@@ -184,28 +190,64 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
// k=1
Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
// New covariance is the sum of the outer product of the residuals times the
// weights
Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
for (int i = 0; i < 2 * s.getNum() + 1; i++) {
// y[:, i] = sigmas[:, i] - x
y.setColumn(i, residualFunc.apply(sigmas.extractColumnVector(i), x));
Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum()));
for (int i = 0; i < 2 * s.getNum(); i++) {
Sbar.setColumn(
i,
residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0))));
}
Matrix<C, C> P =
y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
.times(Matrix.changeBoundsUnchecked(y.transpose()));
Sbar.assignBlock(0, 2 * s.getNum(), squareRootR);
return new Pair<>(x, P);
QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM();
var qrStorage = Sbar.transpose().getStorage();
if (!qr.decompose(qrStorage.getDDRM())) {
throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage.toString());
}
Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true)));
newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false);
return new Pair<>(x, newS);
}
/**
* Returns the error covariance matrix P.
* Returns the square-root error covariance matrix S.
*
* @return the square-root error covariance matrix S.
*/
public Matrix<States, States> getS() {
return m_S;
}
/**
* Returns an element of the square-root error covariance matrix S.
*
* @param row Row of S.
* @param col Column of S.
* @return the value of the square-root error covariance matrix S at (i, j).
*/
public double getS(int row, int col) {
return m_S.get(row, col);
}
/**
* Sets the entire square-root error covariance matrix S.
*
* @param newS The new value of S to use.
*/
public void setS(Matrix<States, States> newS) {
m_S = newS;
}
/**
* Returns the reconstructed error covariance matrix P.
*
* @return the error covariance matrix P.
*/
@Override
public Matrix<States, States> getP() {
return m_P;
return m_S.transpose().times(m_S);
}
/**
@@ -214,10 +256,12 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
* @param row Row of P.
* @param col Column of P.
* @return the value of the error covariance matrix P at (i, j).
* @throws UnsupportedOperationException indexing into the reconstructed P matrix is not supported
*/
@Override
public double getP(int row, int col) {
return m_P.get(row, col);
throw new UnsupportedOperationException(
"indexing into the reconstructed P matrix is not supported");
}
/**
@@ -227,7 +271,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
*/
@Override
public void setP(Matrix<States, States> newP) {
m_P = newP;
m_S = newP.lltDecompose(false);
}
/**
@@ -277,7 +321,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
@Override
public void reset() {
m_xHat = new Matrix<>(m_states, Nat.N1());
m_P = new Matrix<>(m_states, m_states);
m_S = new Matrix<>(m_states, m_states);
m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
}
@@ -294,8 +338,9 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
Matrix<States, States> contA =
NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
var discQ = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond();
var squareRootDiscQ = discQ.lltDecompose(true);
var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
Matrix<States, N1> x = sigmas.extractColumnVector(i);
@@ -304,17 +349,18 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
}
var ret =
unscentedTransform(
squareRootUnscentedTransform(
m_states,
m_states,
m_sigmasF,
m_pts.getWm(),
m_pts.getWc(),
m_meanFuncX,
m_residualFuncX);
m_residualFuncX,
squareRootDiscQ);
m_xHat = ret.getFirst();
m_P = ret.getSecond().plus(discQ);
m_S = ret.getSecond();
m_dtSeconds = dtSeconds;
}
@@ -394,10 +440,11 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
final var discR = Discretization.discretizeR(R, m_dtSeconds);
final var squareRootDiscR = discR.lltDecompose(true);
// Transform sigma points into measurement space
Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
sigmasH.setColumn(i, hRet);
@@ -405,10 +452,17 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
// Mean and covariance of prediction passed through unscented transform
var transRet =
unscentedTransform(
m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc(), meanFuncY, residualFuncY);
squareRootUnscentedTransform(
m_states,
rows,
sigmasH,
m_pts.getWm(),
m_pts.getWc(),
meanFuncY,
residualFuncY,
squareRootDiscR);
var yHat = transRet.getFirst();
var Py = transRet.getSecond().plus(discR);
var Sy = transRet.getSecond();
// Compute cross covariance of the state and the measurements
Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
@@ -420,17 +474,20 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
}
// K = P_{xy} P_y⁻¹
// K = P_yᵀ⁻¹ P_{xy}ᵀ
// P_yᵀKᵀ = P_{xy}ᵀ
// Kᵀ = P_yᵀ.solve(P_{xy}ᵀ)
// K = (P_yᵀ.solve(P_{xy}ᵀ)ᵀ
Matrix<States, R> K = new Matrix<>(Py.transpose().solve(Pxy.transpose()).transpose());
// K = (P_{xy} / S_yᵀ) / S_y
// K = (S_y \ P_{xy}ᵀ)ᵀ / S_y
// K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ
Matrix<States, R> K =
Sy.transpose()
.solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose()))
.transpose();
// x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y ŷ)
m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
// Pₖ₊₁⁺ = Pₖ₊₁⁻ KP_yKᵀ
m_P = m_P.minus(K.times(Py).times(K.transpose()));
Matrix<States, R> U = K.times(Sy);
for (int i = 0; i < rows.getNum(); i++) {
m_S.rankUpdate(U.extractColumnVector(i), -1, false);
}
}
}