mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-21 01:01:43 +00:00
[wpimath] Replace UKF implementation with square root form (#4168)
Co-authored-by: Tyler Veness <calcmogul@gmail.com>
This commit is contained in:
@@ -323,6 +323,10 @@ public class Matrix<R extends Num, C extends Num> {
|
||||
* <p>The matrix equation could also be written as x = A<sup>-1</sup>b. Where the pseudo inverse
|
||||
* is used if A is not square.
|
||||
*
|
||||
* <p>Note that this method does not support solving using a QR decomposition with full-pivoting,
|
||||
* as only column-pivoting is supported. For full-pivoting, use {@link
|
||||
* #solveFullPivHouseholderQr}.
|
||||
*
|
||||
* @param <C2> Columns in b.
|
||||
* @param b The right-hand side of the equation to solve.
|
||||
* @return The solution to the linear system.
|
||||
@@ -332,6 +336,29 @@ public class Matrix<R extends Num, C extends Num> {
|
||||
return new Matrix<>(this.m_storage.solve(Objects.requireNonNull(b).m_storage));
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves the least-squares problem Ax=B using a QR decomposition with full pivoting, where this
|
||||
* matrix is A.
|
||||
*
|
||||
* @param <R2> Number of rows in B.
|
||||
* @param <C2> Number of columns in B.
|
||||
* @param other The B matrix.
|
||||
* @return The solution matrix.
|
||||
*/
|
||||
public final <R2 extends Num, C2 extends Num> Matrix<C, C2> solveFullPivHouseholderQr(
|
||||
Matrix<R2, C2> other) {
|
||||
Matrix<C, C2> solution = new Matrix<>(new SimpleMatrix(this.getNumCols(), other.getNumCols()));
|
||||
WPIMathJNI.solveFullPivHouseholderQr(
|
||||
this.getData(),
|
||||
this.getNumRows(),
|
||||
this.getNumCols(),
|
||||
other.getData(),
|
||||
other.getNumRows(),
|
||||
other.getNumCols(),
|
||||
solution.getData());
|
||||
return solution;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the matrix exponential using Eigen's solver. This method only works for square
|
||||
* matrices, and will otherwise throw an {@link MatrixDimensionException}.
|
||||
@@ -677,6 +704,20 @@ public class Matrix<R extends Num, C extends Num> {
|
||||
this.m_storage.getDDRM(), other.m_storage.getDDRM(), tolerance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs an inplace Cholesky rank update (or downdate).
|
||||
*
|
||||
* <p>If this matrix contains L where A = LL<sup>⊤</sup> before the update, it will contain L
|
||||
* where LL<sup>⊤</sup> = A + σvv<sup>⊤</sup> after the update.
|
||||
*
|
||||
* @param v Vector to use for the update.
|
||||
* @param sigma Sigma to use for the update.
|
||||
* @param lowerTriangular Whether or not this matrix is lower triangular.
|
||||
*/
|
||||
public void rankUpdate(Matrix<R, N1> v, double sigma, boolean lowerTriangular) {
|
||||
WPIMathJNI.rankUpdate(this.getData(), this.getNumRows(), v.getData(), sigma, lowerTriangular);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return m_storage.toString();
|
||||
|
||||
@@ -125,6 +125,19 @@ public final class WPIMathJNI {
|
||||
*/
|
||||
public static native String serializeTrajectory(double[] elements);
|
||||
|
||||
/**
|
||||
* Performs an inplace rank one update (or downdate) of an upper triangular Cholesky decomposition
|
||||
* matrix.
|
||||
*
|
||||
* @param mat Array of elements of the matrix to be updated.
|
||||
* @param lowerTriangular Whether or not mat is lower triangular.
|
||||
* @param rows How many rows there are.
|
||||
* @param vec Vector to use for the rank update.
|
||||
* @param sigma Sigma value to use for the rank update.
|
||||
*/
|
||||
public static native void rankUpdate(
|
||||
double[] mat, int rows, double[] vec, double sigma, boolean lowerTriangular);
|
||||
|
||||
public static class Helper {
|
||||
private static AtomicBoolean extractOnStaticLoad = new AtomicBoolean(true);
|
||||
|
||||
@@ -136,4 +149,18 @@ public final class WPIMathJNI {
|
||||
extractOnStaticLoad.set(load);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves the least-squares problem Ax=B using a QR decomposition with full pivoting.
|
||||
*
|
||||
* @param A Array of elements of the A matrix.
|
||||
* @param Arows Number of rows of the A matrix.
|
||||
* @param Acols Number of rows of the A matrix.
|
||||
* @param B Array of elements of the B matrix.
|
||||
* @param Brows Number of rows of the B matrix.
|
||||
* @param Bcols Number of rows of the B matrix.
|
||||
* @param dst Array to store solution in. If A is m-n and B is m-p, dst is n-p.
|
||||
*/
|
||||
public static native void solveFullPivHouseholderQr(
|
||||
double[] A, int Arows, int Acols, double[] B, int Brows, int Bcols, double[] dst);
|
||||
}
|
||||
|
||||
@@ -71,16 +71,16 @@ public class MerweScaledSigmaPoints<S extends Num> {
|
||||
* of the filter.
|
||||
*
|
||||
* @param x An array of the means.
|
||||
* @param P Covariance of the filter.
|
||||
* @param s Square-root covariance of the filter.
|
||||
* @return Two dimensional array of sigma points. Each column contains all of the sigmas for one
|
||||
* dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
|
||||
*/
|
||||
@SuppressWarnings({"ParameterName", "LocalVariableName"})
|
||||
public Matrix<S, ?> sigmaPoints(Matrix<S, N1> x, Matrix<S, S> P) {
|
||||
public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) {
|
||||
double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
|
||||
double eta = Math.sqrt(lambda + m_states.getNum());
|
||||
|
||||
var intermediate = P.times(lambda + m_states.getNum());
|
||||
var U = intermediate.lltDecompose(true); // Lower triangular
|
||||
Matrix<S, S> U = s.times(eta);
|
||||
|
||||
// 2 * states + 1 by states
|
||||
Matrix<S, ?> sigmas =
|
||||
|
||||
@@ -14,6 +14,7 @@ import edu.wpi.first.math.system.Discretization;
|
||||
import edu.wpi.first.math.system.NumericalIntegration;
|
||||
import edu.wpi.first.math.system.NumericalJacobian;
|
||||
import java.util.function.BiFunction;
|
||||
import org.ejml.dense.row.decomposition.qr.QRDecompositionHouseholder_DDRM;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/**
|
||||
@@ -33,6 +34,9 @@ import org.ejml.simple.SimpleMatrix;
|
||||
* <p>For more on the underlying math, read
|
||||
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9 "Stochastic control
|
||||
* theory".
|
||||
*
|
||||
* <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more
|
||||
* information about the SR-UKF, see https://www.researchgate.net/publication/3908304.
|
||||
*/
|
||||
@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
|
||||
public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
|
||||
@@ -50,7 +54,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
|
||||
|
||||
private Matrix<States, N1> m_xHat;
|
||||
private Matrix<States, States> m_P;
|
||||
private Matrix<States, States> m_S;
|
||||
private final Matrix<States, States> m_contQ;
|
||||
private final Matrix<Outputs, Outputs> m_contR;
|
||||
private Matrix<States, ?> m_sigmasF;
|
||||
@@ -152,14 +156,16 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
}
|
||||
|
||||
@SuppressWarnings({"ParameterName", "LocalVariableName"})
|
||||
static <S extends Num, C extends Num> Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
|
||||
Nat<S> s,
|
||||
Nat<C> dim,
|
||||
Matrix<C, ?> sigmas,
|
||||
Matrix<?, N1> Wm,
|
||||
Matrix<?, N1> Wc,
|
||||
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
|
||||
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc) {
|
||||
static <S extends Num, C extends Num>
|
||||
Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
|
||||
Nat<S> s,
|
||||
Nat<C> dim,
|
||||
Matrix<C, ?> sigmas,
|
||||
Matrix<?, N1> Wm,
|
||||
Matrix<?, N1> Wc,
|
||||
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
|
||||
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc,
|
||||
Matrix<C, C> squareRootR) {
|
||||
if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
|
||||
throw new IllegalArgumentException(
|
||||
"Sigmas must be covDim by 2 * states + 1! Got "
|
||||
@@ -184,28 +190,64 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
// k=1
|
||||
Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
|
||||
|
||||
// New covariance is the sum of the outer product of the residuals times the
|
||||
// weights
|
||||
Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
|
||||
for (int i = 0; i < 2 * s.getNum() + 1; i++) {
|
||||
// y[:, i] = sigmas[:, i] - x
|
||||
y.setColumn(i, residualFunc.apply(sigmas.extractColumnVector(i), x));
|
||||
Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum()));
|
||||
for (int i = 0; i < 2 * s.getNum(); i++) {
|
||||
Sbar.setColumn(
|
||||
i,
|
||||
residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0))));
|
||||
}
|
||||
Matrix<C, C> P =
|
||||
y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
|
||||
.times(Matrix.changeBoundsUnchecked(y.transpose()));
|
||||
Sbar.assignBlock(0, 2 * s.getNum(), squareRootR);
|
||||
|
||||
return new Pair<>(x, P);
|
||||
QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM();
|
||||
var qrStorage = Sbar.transpose().getStorage();
|
||||
|
||||
if (!qr.decompose(qrStorage.getDDRM())) {
|
||||
throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage.toString());
|
||||
}
|
||||
|
||||
Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true)));
|
||||
newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false);
|
||||
|
||||
return new Pair<>(x, newS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the error covariance matrix P.
|
||||
* Returns the square-root error covariance matrix S.
|
||||
*
|
||||
* @return the square-root error covariance matrix S.
|
||||
*/
|
||||
public Matrix<States, States> getS() {
|
||||
return m_S;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the square-root error covariance matrix S.
|
||||
*
|
||||
* @param row Row of S.
|
||||
* @param col Column of S.
|
||||
* @return the value of the square-root error covariance matrix S at (i, j).
|
||||
*/
|
||||
public double getS(int row, int col) {
|
||||
return m_S.get(row, col);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the entire square-root error covariance matrix S.
|
||||
*
|
||||
* @param newS The new value of S to use.
|
||||
*/
|
||||
public void setS(Matrix<States, States> newS) {
|
||||
m_S = newS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the reconstructed error covariance matrix P.
|
||||
*
|
||||
* @return the error covariance matrix P.
|
||||
*/
|
||||
@Override
|
||||
public Matrix<States, States> getP() {
|
||||
return m_P;
|
||||
return m_S.transpose().times(m_S);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -214,10 +256,12 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
* @param row Row of P.
|
||||
* @param col Column of P.
|
||||
* @return the value of the error covariance matrix P at (i, j).
|
||||
* @throws UnsupportedOperationException indexing into the reconstructed P matrix is not supported
|
||||
*/
|
||||
@Override
|
||||
public double getP(int row, int col) {
|
||||
return m_P.get(row, col);
|
||||
throw new UnsupportedOperationException(
|
||||
"indexing into the reconstructed P matrix is not supported");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -227,7 +271,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
*/
|
||||
@Override
|
||||
public void setP(Matrix<States, States> newP) {
|
||||
m_P = newP;
|
||||
m_S = newP.lltDecompose(false);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -277,7 +321,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
@Override
|
||||
public void reset() {
|
||||
m_xHat = new Matrix<>(m_states, Nat.N1());
|
||||
m_P = new Matrix<>(m_states, m_states);
|
||||
m_S = new Matrix<>(m_states, m_states);
|
||||
m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
|
||||
}
|
||||
|
||||
@@ -294,8 +338,9 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
Matrix<States, States> contA =
|
||||
NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
|
||||
var discQ = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond();
|
||||
var squareRootDiscQ = discQ.lltDecompose(true);
|
||||
|
||||
var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
|
||||
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
|
||||
|
||||
for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
|
||||
Matrix<States, N1> x = sigmas.extractColumnVector(i);
|
||||
@@ -304,17 +349,18 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
}
|
||||
|
||||
var ret =
|
||||
unscentedTransform(
|
||||
squareRootUnscentedTransform(
|
||||
m_states,
|
||||
m_states,
|
||||
m_sigmasF,
|
||||
m_pts.getWm(),
|
||||
m_pts.getWc(),
|
||||
m_meanFuncX,
|
||||
m_residualFuncX);
|
||||
m_residualFuncX,
|
||||
squareRootDiscQ);
|
||||
|
||||
m_xHat = ret.getFirst();
|
||||
m_P = ret.getSecond().plus(discQ);
|
||||
m_S = ret.getSecond();
|
||||
m_dtSeconds = dtSeconds;
|
||||
}
|
||||
|
||||
@@ -394,10 +440,11 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
|
||||
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
|
||||
final var discR = Discretization.discretizeR(R, m_dtSeconds);
|
||||
final var squareRootDiscR = discR.lltDecompose(true);
|
||||
|
||||
// Transform sigma points into measurement space
|
||||
Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
|
||||
var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
|
||||
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
|
||||
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
|
||||
Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
|
||||
sigmasH.setColumn(i, hRet);
|
||||
@@ -405,10 +452,17 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
|
||||
// Mean and covariance of prediction passed through unscented transform
|
||||
var transRet =
|
||||
unscentedTransform(
|
||||
m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc(), meanFuncY, residualFuncY);
|
||||
squareRootUnscentedTransform(
|
||||
m_states,
|
||||
rows,
|
||||
sigmasH,
|
||||
m_pts.getWm(),
|
||||
m_pts.getWc(),
|
||||
meanFuncY,
|
||||
residualFuncY,
|
||||
squareRootDiscR);
|
||||
var yHat = transRet.getFirst();
|
||||
var Py = transRet.getSecond().plus(discR);
|
||||
var Sy = transRet.getSecond();
|
||||
|
||||
// Compute cross covariance of the state and the measurements
|
||||
Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
|
||||
@@ -420,17 +474,20 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
|
||||
Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
|
||||
}
|
||||
|
||||
// K = P_{xy} P_y⁻¹
|
||||
// Kᵀ = P_yᵀ⁻¹ P_{xy}ᵀ
|
||||
// P_yᵀKᵀ = P_{xy}ᵀ
|
||||
// Kᵀ = P_yᵀ.solve(P_{xy}ᵀ)
|
||||
// K = (P_yᵀ.solve(P_{xy}ᵀ)ᵀ
|
||||
Matrix<States, R> K = new Matrix<>(Py.transpose().solve(Pxy.transpose()).transpose());
|
||||
// K = (P_{xy} / S_yᵀ) / S_y
|
||||
// K = (S_y \ P_{xy}ᵀ)ᵀ / S_y
|
||||
// K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ
|
||||
Matrix<States, R> K =
|
||||
Sy.transpose()
|
||||
.solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose()))
|
||||
.transpose();
|
||||
|
||||
// x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
|
||||
m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
|
||||
|
||||
// Pₖ₊₁⁺ = Pₖ₊₁⁻ − KP_yKᵀ
|
||||
m_P = m_P.minus(K.times(Py).times(K.transpose()));
|
||||
Matrix<States, R> U = K.times(Sy);
|
||||
for (int i = 0; i < rows.getNum(); i++) {
|
||||
m_S.rankUpdate(U.extractColumnVector(i), -1, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user