diff --git a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java index a03dae5c37..26af2e5873 100644 --- a/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java +++ b/wpimath/src/main/java/edu/wpi/first/wpilibj/estimator/KalmanFilter.java @@ -7,8 +7,6 @@ package edu.wpi.first.wpilibj.estimator; -import org.ejml.simple.SimpleMatrix; - import edu.wpi.first.math.Drake; import edu.wpi.first.math.MathSharedStore; import edu.wpi.first.wpilibj.math.Discretization; @@ -35,38 +33,19 @@ import edu.wpi.first.wpiutil.math.numbers.N1; */ @SuppressWarnings("ClassTypeParameterName") public class KalmanFilter implements KalmanTypeFilter { - + Outputs extends Num> { private final Nat m_states; private final LinearSystem m_plant; - private final Matrix m_stateStdDevs; - private final Matrix m_measurementStdDevs; - /** - * Error covariance matrix. + * The steady-state Kalman gain matrix. */ @SuppressWarnings("MemberName") - private Matrix m_P; + private final Matrix m_K; /** - * Continuous process noise covariance matrix. - */ - private final Matrix m_contQ; - - /** - * Continuous measurement noise covariance matrix. - */ - private final Matrix m_contR; - - /** - * Discrete measurement noise covariance matrix. - */ - private Matrix m_discR; - - /** - * The current state estimate x-hat. + * The state estimate. */ @SuppressWarnings("MemberName") private Matrix m_xHat; @@ -81,6 +60,7 @@ public class KalmanFilter states, Nat outputs, LinearSystem plant, @@ -92,182 +72,30 @@ public class KalmanFilter(new SimpleMatrix(states.getNum(), states.getNum())); - } - } else { - MathSharedStore.reportError("The system passed to the Kalman Filter is not observable!", + var C = plant.getC(); + + // isStabilizable(A^T, C^T) will tell us if the system is observable. + var isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), C.transpose()); + if (!isObservable) { + MathSharedStore.reportError("The system passed to the Kalman filter is not observable!", Thread.currentThread().getStackTrace()); throw new IllegalArgumentException( - "The system passed to the Kalman Filter is not observable!"); + "The system passed to the Kalman filter is not observable!"); } - reset(); - } + var P = new Matrix<>(Drake.discreteAlgebraicRiccatiEquation( + discA.transpose(), C.transpose(), discQ, discR)); - @Override - public void reset() { - m_xHat = new Matrix<>(m_states, Nat.N1()); - } - - /** - * Returns the error covariance matrix P. - * - * @return the error covariance matrix P. - */ - @Override - public Matrix getP() { - return m_P; - } - - /** - * Returns an element of the error covariance matrix P. - * - * @param row Row of P. - * @param col Column of P. - * @return the element (i, j) of the error covariance matrix P. - */ - @Override - public double getP(int row, int col) { - return m_P.get(row, col); - } - - /** - * Sets the entire error covariance matrix P. - * - * @param newP The new value of P to use. - */ - @Override - public void setP(Matrix newP) { - m_P = newP; - } - - /** - * Set initial state estimate x-hat. - * - * @param xhat The state estimate x-hat. - */ - @Override - public void setXhat(Matrix xhat) { - this.m_xHat = xhat; - } - - /** - * Set an element of the initial state estimate x-hat. - * - * @param row Row of x-hat. - * @param value Value for element of x-hat. - */ - @Override - public void setXhat(int row, double value) { - m_xHat.set(row, 0, value); - } - - /** - * Returns the state estimate x-hat. - * - * @return The state estimate x-hat. - */ - @Override - public Matrix getXhat() { - return m_xHat; - } - - /** - * Returns an element of the state estimate x-hat. - * - * @param row Row of x-hat. - * @return the state estimate x-hat at i. - */ - @Override - public double getXhat(int row) { - return m_xHat.get(row, 0); - } - - /** - * Returns the state standard deviations used to make Q. - */ - public Matrix getStateStdDevs() { - return m_stateStdDevs; - } - - /** - * Returns the measurement standard deviations used to make R. - */ - public Matrix getMeasurementStdDevs() { - return m_measurementStdDevs; - } - - /** - * Project the model into the future with a new control input u. - * - * @param u New control input from controller. - * @param dtSeconds Timestep for prediction. - */ - @SuppressWarnings("ParameterName") - @Override - public void predict(Matrix u, double dtSeconds) { - this.m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds); - - var pair = Discretization.discretizeAQTaylor(m_plant.getA(), m_contQ, dtSeconds); - var discA = pair.getFirst(); - var discQ = pair.getSecond(); - - m_P = discA.times(m_P).times(discA.transpose()).plus(discQ); - m_discR = Discretization.discretizeR(m_contR, dtSeconds); - } - - /** - * Correct the state estimate x-hat using the measurements in y. - * - * @param u Same control input used in the last predict step. - * @param y Measurement vector. - */ - @SuppressWarnings("ParameterName") - @Override - public void correct(Matrix u, Matrix y) { - correct(u, y, m_plant.getC(), m_plant.getD(), m_discR); - } - - /** - * Correct the state estimate x-hat using the measurements in y. - * - *

This is useful for when the measurements available during a timestep's - * Correct() call vary. The C matrix passed to the constructor is used if one - * is not provided (the two-argument version of this function). - * - * @param Number of rows in the result of f(x, u). - * @param u Same control input used in the predict step. - * @param y Measurement vector. - * @param C Output matrix. - * @param r Discrete measurement noise covariance matrix. - */ - @SuppressWarnings({"ParameterName", "LocalVariableName"}) - public void correct( - Matrix u, - Matrix y, - Matrix C, - Matrix D, - Matrix r) { - var S = C.times(m_P).times(C.transpose()).plus(r); + var S = C.times(P).times(C.transpose()).plus(discR); // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more // efficiently. @@ -281,10 +109,95 @@ public class KalmanFilter K = S.transpose().solve(C.times(m_P.transpose())).transpose(); + m_K = new Matrix<>(S.transpose().getStorage() + .solve((C.times(P.transpose())).getStorage()).transpose()); - m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u))))); - m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P); + reset(); } + public void reset() { + m_xHat = new Matrix<>(m_states, Nat.N1()); + } + + /** + * Returns the steady-state Kalman gain matrix K. + * + * @return The steady-state Kalman gain matrix K. + */ + public Matrix getK() { + return m_K; + } + + /** + * Returns an element of the steady-state Kalman gain matrix K. + * + * @param row Row of K. + * @param col Column of K. + * @return the element (i, j) of the steady-state Kalman gain matrix K. + */ + public double getK(int row, int col) { + return m_K.get(row, col); + } + + /** + * Set initial state estimate x-hat. + * + * @param xhat The state estimate x-hat. + */ + public void setXhat(Matrix xhat) { + this.m_xHat = xhat; + } + + /** + * Set an element of the initial state estimate x-hat. + * + * @param row Row of x-hat. + * @param value Value for element of x-hat. + */ + public void setXhat(int row, double value) { + m_xHat.set(row, 0, value); + } + + /** + * Returns the state estimate x-hat. + * + * @return The state estimate x-hat. + */ + public Matrix getXhat() { + return m_xHat; + } + + /** + * Returns an element of the state estimate x-hat. + * + * @param row Row of x-hat. + * @return the state estimate x-hat at i. + */ + public double getXhat(int row) { + return m_xHat.get(row, 0); + } + + /** + * Project the model into the future with a new control input u. + * + * @param u New control input from controller. + * @param dtSeconds Timestep for prediction. + */ + @SuppressWarnings("ParameterName") + public void predict(Matrix u, double dtSeconds) { + this.m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds); + } + + /** + * Correct the state estimate x-hat using the measurements in y. + * + * @param u Same control input used in the last predict step. + * @param y Measurement vector. + */ + @SuppressWarnings("ParameterName") + public void correct(Matrix u, Matrix y) { + final var C = m_plant.getC(); + final var D = m_plant.getD(); + m_xHat = m_xHat.plus(m_K.times(y.minus(C.times(m_xHat).plus(D.times(u))))); + } } diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h index 4bff9d5ac5..dc522e60d7 100644 --- a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h @@ -56,55 +56,65 @@ class KalmanFilterImpl { units::second_t dt) { m_plant = &plant; - m_contQ = MakeCovMatrix(stateStdDevs); - m_contR = MakeCovMatrix(measurementStdDevs); + auto contQ = MakeCovMatrix(stateStdDevs); + auto contR = MakeCovMatrix(measurementStdDevs); Eigen::Matrix discA; Eigen::Matrix discQ; - DiscretizeAQTaylor(plant.A(), m_contQ, dt, &discA, &discQ); + DiscretizeAQTaylor(plant.A(), contQ, dt, &discA, &discQ); - m_discR = DiscretizeR(m_contR, dt); + auto discR = DiscretizeR(contR, dt); + + const auto& C = plant.C(); // IsStabilizable(A^T, C^T) will tell us if the system is observable. - bool isObservable = IsStabilizable(discA.transpose(), - plant.C().transpose()); - if (isObservable) { - if (Outputs <= States) { - m_P = drake::math::DiscreteAlgebraicRiccatiEquation( - discA.transpose(), plant.C().transpose(), discQ, m_discR); - } else { - m_P.setZero(); - } - } else { + bool isObservable = + IsStabilizable(discA.transpose(), C.transpose()); + if (!isObservable) { wpi::math::MathSharedStore::ReportError( - "The system passed to the Kalman Filter is not observable!"); + "The system passed to the Kalman filter is not observable!"); throw std::invalid_argument( - "The system passed to the Kalman Filter is not observable!"); + "The system passed to the Kalman filter is not observable!"); } + + Eigen::Matrix P = + drake::math::DiscreteAlgebraicRiccatiEquation( + discA.transpose(), C.transpose(), discQ, discR); + + Eigen::Matrix S = C * P * C.transpose() + discR; + + // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more + // efficiently. + // + // K = PC^T S^-1 + // KS = PC^T + // (KS)^T = (PC^T)^T + // S^T K^T = CP^T + // + // The solution of Ax = b can be found via x = A.solve(b). + // + // K^T = S^T.solve(CP^T) + // K = (S^T.solve(CP^T))^T + m_K = S.transpose().ldlt().solve(C * P.transpose()).transpose(); + + Reset(); } KalmanFilterImpl(KalmanFilterImpl&&) = default; KalmanFilterImpl& operator=(KalmanFilterImpl&&) = default; /** - * Returns the error covariance matrix P. + * Returns the steady-state Kalman gain matrix K. */ - const Eigen::Matrix& P() const { return m_P; } + const Eigen::Matrix& K() const { return m_K; } /** - * Returns an element of the error covariance matrix P. + * Returns an element of the steady-state Kalman gain matrix K. * - * @param i Row of P. - * @param j Column of P. + * @param i Row of K. + * @param j Column of K. */ - double P(int i, int j) const { return m_P(i, j); } - - /** - * Set the current error covariance matrix P. - * - * @param P The error covariance matrix P. - */ - void SetP(const Eigen::Matrix& P) { m_P = P; } + double K(int i, int j) const { return m_K(i, j); } /** * Returns the state estimate x-hat. @@ -146,13 +156,6 @@ class KalmanFilterImpl { */ void Predict(const Eigen::Matrix& u, units::second_t dt) { m_xHat = m_plant->CalculateX(m_xHat, u, dt); - - Eigen::Matrix discA; - Eigen::Matrix discQ; - DiscretizeAQTaylor(m_plant->A(), m_contQ, dt, &discA, &discQ); - - m_P = discA * m_P * discA.transpose() + discQ; - m_discR = DiscretizeR(m_contR, dt); } /** @@ -163,75 +166,19 @@ class KalmanFilterImpl { */ void Correct(const Eigen::Matrix& u, const Eigen::Matrix& y) { - Correct(u, y, m_plant->C(), m_plant->D(), m_discR); - } - - /** - * Correct the state estimate x-hat using the measurements in y. - * - * This is useful for when the measurements available during a timestep's - * Correct() call vary. The C matrix passed to the constructor is used if one - * is not provided (the two-argument version of this function). - * - * @param u Same control input used in the predict step. - * @param y Measurement vector. - * @param C Output matrix. - * @param D Feedthrough matrix. - * @param R Measurement noise covariance matrix. - */ - template - void Correct(const Eigen::Matrix& u, - const Eigen::Matrix& y, - const Eigen::Matrix& C, - const Eigen::Matrix& D, - const Eigen::Matrix& R) { - const auto& x = m_xHat; - Eigen::Matrix S = C * m_P * C.transpose() + R; - - // We want to put K = PC^T S^-1 into Ax = b form so we can solve it more - // efficiently. - // - // K = PC^T S^-1 - // KS = PC^T - // (KS)^T = (PC^T)^T - // S^T K^T = CP^T - // - // The solution of Ax = b can be found via x = A.solve(b). - // - // K^T = S^T.solve(CP^T) - // K = (S^T.solve(CP^T))^T - Eigen::Matrix K = - S.transpose().ldlt().solve(C * m_P.transpose()).transpose(); - - m_xHat = x + K * (y - (C * x + D * u)); - m_P = (Eigen::Matrix::Identity() - K * C) * m_P; + m_xHat += m_K * (y - (m_plant->C() * m_xHat + m_plant->D() * u)); } private: LinearSystem* m_plant; /** - * Error covariance matrix. + * The steady-state Kalman gain matrix. */ - Eigen::Matrix m_P; + Eigen::Matrix m_K; /** - * Continuous process noise covariance matrix. - */ - Eigen::Matrix m_contQ; - - /** - * Continuous measurement noise covariance matrix. - */ - Eigen::Matrix m_contR; - - /** - * Discrete measurement noise covariance matrix. - */ - Eigen::Matrix m_discR; - - /** - * State estimate x-hat. + * The state estimate. */ Eigen::Matrix m_xHat; };