mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-20 00:51:42 +00:00
[wpimath] Refactor KalmanFilter to be steady-state only (#2657)
I didn't notice a performance difference between the original implementation and this one for a flywheel simulation, so this simplifies a lot of internals. This class can no longer implement KalmanTypeFilter because that class allows setting the error covariance for use in the KalmanFilterLatencyCompensator class. This won't impact the holonomic pose estimators that use KalmanFilterLatencyCompensator because they all use an EKF.
This commit is contained in:
@@ -7,8 +7,6 @@
|
||||
|
||||
package edu.wpi.first.wpilibj.estimator;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
import edu.wpi.first.math.Drake;
|
||||
import edu.wpi.first.math.MathSharedStore;
|
||||
import edu.wpi.first.wpilibj.math.Discretization;
|
||||
@@ -35,38 +33,19 @@ import edu.wpi.first.wpiutil.math.numbers.N1;
|
||||
*/
|
||||
@SuppressWarnings("ClassTypeParameterName")
|
||||
public class KalmanFilter<States extends Num, Inputs extends Num,
|
||||
Outputs extends Num> implements KalmanTypeFilter<States, Inputs, Outputs> {
|
||||
|
||||
Outputs extends Num> {
|
||||
private final Nat<States> m_states;
|
||||
|
||||
private final LinearSystem<States, Inputs, Outputs> m_plant;
|
||||
|
||||
private final Matrix<States, N1> m_stateStdDevs;
|
||||
private final Matrix<Outputs, N1> m_measurementStdDevs;
|
||||
|
||||
/**
|
||||
* Error covariance matrix.
|
||||
* The steady-state Kalman gain matrix.
|
||||
*/
|
||||
@SuppressWarnings("MemberName")
|
||||
private Matrix<States, States> m_P;
|
||||
private final Matrix<States, Outputs> m_K;
|
||||
|
||||
/**
|
||||
* Continuous process noise covariance matrix.
|
||||
*/
|
||||
private final Matrix<States, States> m_contQ;
|
||||
|
||||
/**
|
||||
* Continuous measurement noise covariance matrix.
|
||||
*/
|
||||
private final Matrix<Outputs, Outputs> m_contR;
|
||||
|
||||
/**
|
||||
* Discrete measurement noise covariance matrix.
|
||||
*/
|
||||
private Matrix<Outputs, Outputs> m_discR;
|
||||
|
||||
/**
|
||||
* The current state estimate x-hat.
|
||||
* The state estimate.
|
||||
*/
|
||||
@SuppressWarnings("MemberName")
|
||||
private Matrix<States, N1> m_xHat;
|
||||
@@ -81,6 +60,7 @@ public class KalmanFilter<States extends Num, Inputs extends Num,
|
||||
* @param measurementStdDevs Standard deviations of measurements.
|
||||
* @param dtSeconds Nominal discretization timestep.
|
||||
*/
|
||||
@SuppressWarnings("LocalVariableName")
|
||||
public KalmanFilter(
|
||||
Nat<States> states, Nat<Outputs> outputs,
|
||||
LinearSystem<States, Inputs, Outputs> plant,
|
||||
@@ -92,182 +72,30 @@ public class KalmanFilter<States extends Num, Inputs extends Num,
|
||||
|
||||
this.m_plant = plant;
|
||||
|
||||
this.m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
|
||||
this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
|
||||
var contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
|
||||
var contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
|
||||
|
||||
this.m_stateStdDevs = stateStdDevs;
|
||||
this.m_measurementStdDevs = measurementStdDevs;
|
||||
|
||||
var pair = Discretization.discretizeAQTaylor(plant.getA(), m_contQ, dtSeconds);
|
||||
var pair = Discretization.discretizeAQTaylor(plant.getA(), contQ, dtSeconds);
|
||||
var discA = pair.getFirst();
|
||||
var discQ = pair.getSecond();
|
||||
|
||||
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
|
||||
var discR = Discretization.discretizeR(contR, dtSeconds);
|
||||
|
||||
// IsStabilizable(A^T, C^T) will tell us if the system is observable.
|
||||
var isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), plant.getC().transpose());
|
||||
if (isObservable) {
|
||||
if (outputs.getNum() <= states.getNum()) {
|
||||
m_P = Drake.discreteAlgebraicRiccatiEquation(
|
||||
discA.transpose(), plant.getC().transpose(), discQ, m_discR);
|
||||
} else {
|
||||
m_P = new Matrix<>(new SimpleMatrix(states.getNum(), states.getNum()));
|
||||
}
|
||||
} else {
|
||||
MathSharedStore.reportError("The system passed to the Kalman Filter is not observable!",
|
||||
var C = plant.getC();
|
||||
|
||||
// isStabilizable(A^T, C^T) will tell us if the system is observable.
|
||||
var isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), C.transpose());
|
||||
if (!isObservable) {
|
||||
MathSharedStore.reportError("The system passed to the Kalman filter is not observable!",
|
||||
Thread.currentThread().getStackTrace());
|
||||
throw new IllegalArgumentException(
|
||||
"The system passed to the Kalman Filter is not observable!");
|
||||
"The system passed to the Kalman filter is not observable!");
|
||||
}
|
||||
|
||||
reset();
|
||||
}
|
||||
var P = new Matrix<>(Drake.discreteAlgebraicRiccatiEquation(
|
||||
discA.transpose(), C.transpose(), discQ, discR));
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
m_xHat = new Matrix<>(m_states, Nat.N1());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the error covariance matrix P.
|
||||
*
|
||||
* @return the error covariance matrix P.
|
||||
*/
|
||||
@Override
|
||||
public Matrix<States, States> getP() {
|
||||
return m_P;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the error covariance matrix P.
|
||||
*
|
||||
* @param row Row of P.
|
||||
* @param col Column of P.
|
||||
* @return the element (i, j) of the error covariance matrix P.
|
||||
*/
|
||||
@Override
|
||||
public double getP(int row, int col) {
|
||||
return m_P.get(row, col);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the entire error covariance matrix P.
|
||||
*
|
||||
* @param newP The new value of P to use.
|
||||
*/
|
||||
@Override
|
||||
public void setP(Matrix<States, States> newP) {
|
||||
m_P = newP;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set initial state estimate x-hat.
|
||||
*
|
||||
* @param xhat The state estimate x-hat.
|
||||
*/
|
||||
@Override
|
||||
public void setXhat(Matrix<States, N1> xhat) {
|
||||
this.m_xHat = xhat;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set an element of the initial state estimate x-hat.
|
||||
*
|
||||
* @param row Row of x-hat.
|
||||
* @param value Value for element of x-hat.
|
||||
*/
|
||||
@Override
|
||||
public void setXhat(int row, double value) {
|
||||
m_xHat.set(row, 0, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the state estimate x-hat.
|
||||
*
|
||||
* @return The state estimate x-hat.
|
||||
*/
|
||||
@Override
|
||||
public Matrix<States, N1> getXhat() {
|
||||
return m_xHat;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the state estimate x-hat.
|
||||
*
|
||||
* @param row Row of x-hat.
|
||||
* @return the state estimate x-hat at i.
|
||||
*/
|
||||
@Override
|
||||
public double getXhat(int row) {
|
||||
return m_xHat.get(row, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the state standard deviations used to make Q.
|
||||
*/
|
||||
public Matrix<States, N1> getStateStdDevs() {
|
||||
return m_stateStdDevs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the measurement standard deviations used to make R.
|
||||
*/
|
||||
public Matrix<Outputs, N1> getMeasurementStdDevs() {
|
||||
return m_measurementStdDevs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Project the model into the future with a new control input u.
|
||||
*
|
||||
* @param u New control input from controller.
|
||||
* @param dtSeconds Timestep for prediction.
|
||||
*/
|
||||
@SuppressWarnings("ParameterName")
|
||||
@Override
|
||||
public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
|
||||
this.m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
|
||||
|
||||
var pair = Discretization.discretizeAQTaylor(m_plant.getA(), m_contQ, dtSeconds);
|
||||
var discA = pair.getFirst();
|
||||
var discQ = pair.getSecond();
|
||||
|
||||
m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
|
||||
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
|
||||
}
|
||||
|
||||
/**
|
||||
* Correct the state estimate x-hat using the measurements in y.
|
||||
*
|
||||
* @param u Same control input used in the last predict step.
|
||||
* @param y Measurement vector.
|
||||
*/
|
||||
@SuppressWarnings("ParameterName")
|
||||
@Override
|
||||
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
|
||||
correct(u, y, m_plant.getC(), m_plant.getD(), m_discR);
|
||||
}
|
||||
|
||||
/**
|
||||
* Correct the state estimate x-hat using the measurements in y.
|
||||
*
|
||||
* <p>This is useful for when the measurements available during a timestep's
|
||||
* Correct() call vary. The C matrix passed to the constructor is used if one
|
||||
* is not provided (the two-argument version of this function).
|
||||
*
|
||||
* @param <R> Number of rows in the result of f(x, u).
|
||||
* @param u Same control input used in the predict step.
|
||||
* @param y Measurement vector.
|
||||
* @param C Output matrix.
|
||||
* @param r Discrete measurement noise covariance matrix.
|
||||
*/
|
||||
@SuppressWarnings({"ParameterName", "LocalVariableName"})
|
||||
public <R extends Num> void correct(
|
||||
Matrix<Inputs, N1> u,
|
||||
Matrix<R, N1> y,
|
||||
Matrix<R, States> C,
|
||||
Matrix<R, Inputs> D,
|
||||
Matrix<R, R> r) {
|
||||
var S = C.times(m_P).times(C.transpose()).plus(r);
|
||||
var S = C.times(P).times(C.transpose()).plus(discR);
|
||||
|
||||
// We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
|
||||
// efficiently.
|
||||
@@ -281,10 +109,95 @@ public class KalmanFilter<States extends Num, Inputs extends Num,
|
||||
//
|
||||
// K^T = S^T.solve(CP^T)
|
||||
// K = (S^T.solve(CP^T))^T
|
||||
Matrix<States, R> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
|
||||
m_K = new Matrix<>(S.transpose().getStorage()
|
||||
.solve((C.times(P.transpose())).getStorage()).transpose());
|
||||
|
||||
m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
|
||||
m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
|
||||
reset();
|
||||
}
|
||||
|
||||
public void reset() {
|
||||
m_xHat = new Matrix<>(m_states, Nat.N1());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the steady-state Kalman gain matrix K.
|
||||
*
|
||||
* @return The steady-state Kalman gain matrix K.
|
||||
*/
|
||||
public Matrix<States, Outputs> getK() {
|
||||
return m_K;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the steady-state Kalman gain matrix K.
|
||||
*
|
||||
* @param row Row of K.
|
||||
* @param col Column of K.
|
||||
* @return the element (i, j) of the steady-state Kalman gain matrix K.
|
||||
*/
|
||||
public double getK(int row, int col) {
|
||||
return m_K.get(row, col);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set initial state estimate x-hat.
|
||||
*
|
||||
* @param xhat The state estimate x-hat.
|
||||
*/
|
||||
public void setXhat(Matrix<States, N1> xhat) {
|
||||
this.m_xHat = xhat;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set an element of the initial state estimate x-hat.
|
||||
*
|
||||
* @param row Row of x-hat.
|
||||
* @param value Value for element of x-hat.
|
||||
*/
|
||||
public void setXhat(int row, double value) {
|
||||
m_xHat.set(row, 0, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the state estimate x-hat.
|
||||
*
|
||||
* @return The state estimate x-hat.
|
||||
*/
|
||||
public Matrix<States, N1> getXhat() {
|
||||
return m_xHat;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the state estimate x-hat.
|
||||
*
|
||||
* @param row Row of x-hat.
|
||||
* @return the state estimate x-hat at i.
|
||||
*/
|
||||
public double getXhat(int row) {
|
||||
return m_xHat.get(row, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Project the model into the future with a new control input u.
|
||||
*
|
||||
* @param u New control input from controller.
|
||||
* @param dtSeconds Timestep for prediction.
|
||||
*/
|
||||
@SuppressWarnings("ParameterName")
|
||||
public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
|
||||
this.m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
|
||||
}
|
||||
|
||||
/**
|
||||
* Correct the state estimate x-hat using the measurements in y.
|
||||
*
|
||||
* @param u Same control input used in the last predict step.
|
||||
* @param y Measurement vector.
|
||||
*/
|
||||
@SuppressWarnings("ParameterName")
|
||||
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
|
||||
final var C = m_plant.getC();
|
||||
final var D = m_plant.getD();
|
||||
m_xHat = m_xHat.plus(m_K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user