diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java index 5f89e80f8f..2f3343c982 100644 --- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java +++ b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/ExtendedKalmanFilter.java @@ -36,6 +36,9 @@ public class ExtendedKalmanFilter, Matrix, Matrix> m_h; + private BiFunction, Matrix, Matrix> m_residualFuncY; + private BiFunction, Matrix, Matrix> m_addFuncX; + private final Matrix m_contQ; private final Matrix m_initP; private final Matrix m_contR; @@ -70,12 +73,55 @@ public class ExtendedKalmanFilter stateStdDevs, Matrix measurementStdDevs, double dtSeconds) { + this( + states, + inputs, + outputs, + f, + h, + stateStdDevs, + measurementStdDevs, + Matrix::minus, + Matrix::plus, + dtSeconds); + } + + /** + * Constructs an extended Kalman filter. + * + * @param states a Nat representing the number of states. + * @param inputs a Nat representing the number of inputs. + * @param outputs a Nat representing the number of outputs. + * @param f A vector-valued function of x and u that returns the derivative of the state vector. + * @param h A vector-valued function of x and u that returns the measurement vector. + * @param stateStdDevs Standard deviations of model states. + * @param measurementStdDevs Standard deviations of measurements. + * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it + * subtracts them.) + * @param addFuncX A function that adds two state vectors. + * @param dtSeconds Nominal discretization timestep. + */ + @SuppressWarnings({"ParameterName", "PMD.ExcessiveParameterList"}) + public ExtendedKalmanFilter( + Nat states, + Nat inputs, + Nat outputs, + BiFunction, Matrix, Matrix> f, + BiFunction, Matrix, Matrix> h, + Matrix stateStdDevs, + Matrix measurementStdDevs, + BiFunction, Matrix, Matrix> residualFuncY, + BiFunction, Matrix, Matrix> addFuncX, + double dtSeconds) { m_states = states; m_outputs = outputs; m_f = f; m_h = h; + m_residualFuncY = residualFuncY; + m_addFuncX = addFuncX; + m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs); this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs); m_dtSeconds = dtSeconds; @@ -234,7 +280,7 @@ public class ExtendedKalmanFilter u, Matrix y) { - correct(m_outputs, u, y, m_h, m_contR); + correct(m_outputs, u, y, m_h, m_contR, m_residualFuncY, m_addFuncX); } /** @@ -258,6 +304,35 @@ public class ExtendedKalmanFilter y, BiFunction, Matrix, Matrix> h, Matrix R) { + correct(rows, u, y, h, R, Matrix::minus, Matrix::plus); + } + + /** + * Correct the state estimate x-hat using the measurements in y. + * + *

This is useful for when the measurements available during a timestep's Correct() call vary. + * The h(x, u) passed to the constructor is used if one is not provided (the two-argument version + * of this function). + * + * @param Number of rows in the result of f(x, u). + * @param rows Number of rows in the result of f(x, u). + * @param u Same control input used in the predict step. + * @param y Measurement vector. + * @param h A vector-valued function of x and u that returns the measurement vector. + * @param R Discrete measurement noise covariance matrix. + * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it + * subtracts them.) + * @param addFuncX A function that adds two state vectors. + */ + @SuppressWarnings({"ParameterName", "MethodTypeParameterName"}) + public void correct( + Nat rows, + Matrix u, + Matrix y, + BiFunction, Matrix, Matrix> h, + Matrix R, + BiFunction, Matrix, Matrix> residualFuncY, + BiFunction, Matrix, Matrix> addFuncX) { final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u); final var discR = Discretization.discretizeR(R, m_dtSeconds); @@ -279,7 +354,7 @@ public class ExtendedKalmanFilter K = S.transpose().solve(C.times(m_P.transpose())).transpose(); - m_xHat = m_xHat.plus(K.times(y.minus(h.apply(m_xHat, u)))); + m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, h.apply(m_xHat, u)))); m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P); } } diff --git a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h index ac8ddc86ed..460cbfe81d 100644 --- a/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/ExtendedKalmanFilter.h @@ -47,6 +47,78 @@ class ExtendedKalmanFilter { : m_f(f), m_h(h) { m_contQ = MakeCovMatrix(stateStdDevs); m_contR = MakeCovMatrix(measurementStdDevs); + m_residualFuncY = [](auto a, auto b) -> Eigen::Matrix { + return a - b; + }; + m_addFuncX = [](auto a, auto b) -> Eigen::Matrix { + return a + b; + }; + m_dt = dt; + + Reset(); + + Eigen::Matrix contA = + NumericalJacobianX( + m_f, m_xHat, Eigen::Matrix::Zero()); + Eigen::Matrix C = + NumericalJacobianX( + m_h, m_xHat, Eigen::Matrix::Zero()); + + Eigen::Matrix discA; + Eigen::Matrix discQ; + DiscretizeAQTaylor(contA, m_contQ, dt, &discA, &discQ); + + Eigen::Matrix discR = + DiscretizeR(m_contR, dt); + + // IsStabilizable(A^T, C^T) will tell us if the system is observable. + bool isObservable = + IsStabilizable(discA.transpose(), C.transpose()); + if (isObservable && Outputs <= States) { + m_initP = drake::math::DiscreteAlgebraicRiccatiEquation( + discA.transpose(), C.transpose(), discQ, discR); + } else { + m_initP = Eigen::Matrix::Zero(); + } + m_P = m_initP; + } + + /** + * Constructs an Extended Kalman filter. + * + * @param f A vector-valued function of x and u that returns + * the derivative of the state vector. + * @param h A vector-valued function of x and u that returns + * the measurement vector. + * @param stateStdDevs Standard deviations of model states. + * @param measurementStdDevs Standard deviations of measurements. + * @param residualFuncY A function that computes the residual of two + * measurement vectors (i.e. it subtracts them.) + * @param addFuncX A function that adds two state vectors. + * @param dt Nominal discretization timestep. + */ + ExtendedKalmanFilter(std::function( + const Eigen::Matrix&, + const Eigen::Matrix&)> + f, + std::function( + const Eigen::Matrix&, + const Eigen::Matrix&)> + h, + const wpi::array& stateStdDevs, + const wpi::array& measurementStdDevs, + std::function( + const Eigen::Matrix&, + const Eigen::Matrix&)> + residualFuncY, + std::function( + const Eigen::Matrix&, + const Eigen::Matrix&)> + addFuncX, + units::second_t dt) + : m_f(f), m_h(h), m_residualFuncY(residualFuncY), m_addFuncX(addFuncX) { + m_contQ = MakeCovMatrix(stateStdDevs); + m_contR = MakeCovMatrix(measurementStdDevs); m_dt = dt; Reset(); @@ -162,7 +234,24 @@ class ExtendedKalmanFilter { */ void Correct(const Eigen::Matrix& u, const Eigen::Matrix& y) { - Correct(u, y, m_h, m_contR); + Correct(u, y, m_h, m_contR, m_residualFuncY, m_addFuncX); + } + + template + void Correct(const Eigen::Matrix& u, + const Eigen::Matrix& y, + std::function( + const Eigen::Matrix&, + const Eigen::Matrix&)> + h, + const Eigen::Matrix& R) { + auto residualFuncY = [](auto a, auto b) -> Eigen::Matrix { + return a - b; + }; + auto addFuncX = [](auto a, auto b) -> Eigen::Matrix { + return a + b; + }; + Correct(u, y, h, R, residualFuncY, addFuncX); } /** @@ -172,11 +261,14 @@ class ExtendedKalmanFilter { * Correct() call vary. The h(x, u) passed to the constructor is used if one * is not provided (the two-argument version of this function). * - * @param u Same control input used in the predict step. - * @param y Measurement vector. - * @param h A vector-valued function of x and u that returns - * the measurement vector. - * @param R Discrete measurement noise covariance matrix. + * @param u Same control input used in the predict step. + * @param y Measurement vector. + * @param h A vector-valued function of x and u that returns + * the measurement vector. + * @param R Discrete measurement noise covariance matrix. + * @param residualFuncY A function that computes the residual of two + * measurement vectors (i.e. it subtracts them.) + * @param addFuncX A function that adds two state vectors. */ template void Correct(const Eigen::Matrix& u, @@ -185,7 +277,15 @@ class ExtendedKalmanFilter { const Eigen::Matrix&, const Eigen::Matrix&)> h, - const Eigen::Matrix& R) { + const Eigen::Matrix& R, + std::function( + const Eigen::Matrix&, + const Eigen::Matrix&)> + residualFuncY, + std::function( + const Eigen::Matrix&, + const Eigen::Matrix)> + addFuncX) { const Eigen::Matrix C = NumericalJacobianX(h, m_xHat, u); const Eigen::Matrix discR = DiscretizeR(R, m_dt); @@ -207,7 +307,7 @@ class ExtendedKalmanFilter { Eigen::Matrix K = S.transpose().ldlt().solve(C * m_P.transpose()).transpose(); - m_xHat += K * (y - h(m_xHat, u)); + m_xHat = addFuncX(m_xHat, K * residualFuncY(y, h(m_xHat, u))); m_P = (Eigen::Matrix::Identity() - K * C) * m_P; } @@ -220,6 +320,14 @@ class ExtendedKalmanFilter { const Eigen::Matrix&, const Eigen::Matrix&)> m_h; + std::function( + const Eigen::Matrix&, + const Eigen::Matrix)> + m_residualFuncY; + std::function( + const Eigen::Matrix&, + const Eigen::Matrix)> + m_addFuncX; Eigen::Matrix m_xHat; Eigen::Matrix m_P; Eigen::Matrix m_contQ;