[wpimath] Add custom residual support to EKF (#3148)

Fixes #3145.

Co-authored-by: Declan Freeman-Gleason <declanfreemangleason@gmail.com>
This commit is contained in:
Tyler Veness
2021-02-12 22:13:36 -08:00
committed by GitHub
parent 5899f3dd28
commit 94e685e1bd
2 changed files with 193 additions and 10 deletions

View File

@@ -36,6 +36,9 @@ public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Output
@SuppressWarnings("MemberName")
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
private final Matrix<States, States> m_contQ;
private final Matrix<States, States> m_initP;
private final Matrix<Outputs, Outputs> m_contR;
@@ -70,12 +73,55 @@ public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Output
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double dtSeconds) {
this(
states,
inputs,
outputs,
f,
h,
stateStdDevs,
measurementStdDevs,
Matrix::minus,
Matrix::plus,
dtSeconds);
}
/**
* Constructs an extended Kalman filter.
*
* @param states a Nat representing the number of states.
* @param inputs a Nat representing the number of inputs.
* @param outputs a Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
* subtracts them.)
* @param addFuncX A function that adds two state vectors.
* @param dtSeconds Nominal discretization timestep.
*/
@SuppressWarnings({"ParameterName", "PMD.ExcessiveParameterList"})
public ExtendedKalmanFilter(
Nat<States> states,
Nat<Inputs> inputs,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
double dtSeconds) {
m_states = states;
m_outputs = outputs;
m_f = f;
m_h = h;
m_residualFuncY = residualFuncY;
m_addFuncX = addFuncX;
m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
m_dtSeconds = dtSeconds;
@@ -234,7 +280,7 @@ public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Output
@SuppressWarnings("ParameterName")
@Override
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
correct(m_outputs, u, y, m_h, m_contR);
correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX);
}
/**
@@ -258,6 +304,35 @@ public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Output
Matrix<Rows, N1> y,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
Matrix<Rows, Rows> R) {
correct(rows, u, y, h, R, Matrix::minus, Matrix::plus);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* <p>This is useful for when the measurements available during a timestep's Correct() call vary.
* The h(x, u) passed to the constructor is used if one is not provided (the two-argument version
* of this function).
*
* @param <Rows> Number of rows in the result of f(x, u).
* @param rows Number of rows in the result of f(x, u).
* @param u Same control input used in the predict step.
* @param y Measurement vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param R Discrete measurement noise covariance matrix.
* @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
* subtracts them.)
* @param addFuncX A function that adds two state vectors.
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public <Rows extends Num> void correct(
Nat<Rows> rows,
Matrix<Inputs, N1> u,
Matrix<Rows, N1> y,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
Matrix<Rows, Rows> R,
BiFunction<Matrix<Rows, N1>, Matrix<Rows, N1>, Matrix<Rows, N1>> residualFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
final var discR = Discretization.discretizeR(R, m_dtSeconds);
@@ -279,7 +354,7 @@ public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Output
// Now we have the Kalman gain
final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
m_xHat = m_xHat.plus(K.times(y.minus(h.apply(m_xHat, u))));
m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u))));
m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
}
}