From 1530fccbd0193f33821e2a5289609bf3e9ebf194 Mon Sep 17 00:00:00 2001 From: Joseph Eng <91924258+KangarooKoala@users.noreply.github.com> Date: Tue, 15 Jul 2025 21:17:25 -0700 Subject: [PATCH] [wpimath] Implement Scaled Spherical Simplex Filter (S3F) (#8091) Adds S3SigmaPoints based on MerweScaledSigmaPoints. In addition, restructures UnscentedKalmanFilter to support different sigma point generators and provides MerweUKF and S3UKF for convenience when working with either kind of filter. S3UKFTest is copied from MerweUKFTest (which is a rename of UnscentedKalmanFilterTest). Curiously, however, in Java the original tolerance used in MerweUKFTest.testDriveConvergence() for the final rotation was too low for S3UKFTest, so the tolerance is increased from 0.000005 (5e-6) radians to 0.00015 (1.5e-4) radians. However, the C++ version still uses the original tolerance. (This difference is probably because Java uses a final rotation of 5.846 degrees while C++ uses a final rotation of 5.846 radians) Closes #8072. Breaking changes: - (C++) UnscentedKalmanFilter has a new template parameter for the sigma point generator type. - (Java) UnscentedKalmanFilter has an additional parameter to every constructor providing an instance of a sigma point generator. - (C++) int MerweScaledSigmaPoints.NumSigmas() has been replaced with constexpr int MerweScaledSigmaPoints::NumSigmas. - (C++) The second parameter of SquareRootUnscentedTransform has been changed from States to NumSigmas. --- .../estimator/MerweScaledSigmaPoints.java | 11 +- .../wpi/first/math/estimator/MerweUKF.java | 117 ++++++ .../first/math/estimator/S3SigmaPoints.java | 172 ++++++++ .../edu/wpi/first/math/estimator/S3UKF.java | 117 ++++++ .../wpi/first/math/estimator/SigmaPoints.java | 64 +++ .../math/estimator/UnscentedKalmanFilter.java | 74 ++-- .../main/native/cpp/estimator/MerweUKF.cpp | 14 + .../{UnscentedKalmanFilter.cpp => S3UKF.cpp} | 6 +- .../frc/estimator/MerweScaledSigmaPoints.h | 11 +- .../native/include/frc/estimator/MerweUKF.h | 25 ++ .../include/frc/estimator/S3SigmaPoints.h | 135 ++++++ .../main/native/include/frc/estimator/S3UKF.h | 25 ++ .../include/frc/estimator/SigmaPoints.h | 26 ++ .../frc/estimator/UnscentedKalmanFilter.h | 71 ++-- .../frc/estimator/UnscentedTransform.h | 25 +- .../estimator/ExtendedKalmanFilterTest.java | 2 +- ...almanFilterTest.java => MerweUKFTest.java} | 39 +- .../math/estimator/S3SigmaPointsTest.java | 86 ++++ .../wpi/first/math/estimator/S3UKFTest.java | 393 ++++++++++++++++++ ...dKalmanFilterTest.cpp => MerweUKFTest.cpp} | 58 +-- .../cpp/estimator/S3SigmaPointsTest.cpp | 50 +++ .../test/native/cpp/estimator/S3UKFTest.cpp | 309 ++++++++++++++ 22 files changed, 1694 insertions(+), 136 deletions(-) create mode 100644 wpimath/src/main/java/edu/wpi/first/math/estimator/MerweUKF.java create mode 100644 wpimath/src/main/java/edu/wpi/first/math/estimator/S3SigmaPoints.java create mode 100644 wpimath/src/main/java/edu/wpi/first/math/estimator/S3UKF.java create mode 100644 wpimath/src/main/java/edu/wpi/first/math/estimator/SigmaPoints.java create mode 100644 wpimath/src/main/native/cpp/estimator/MerweUKF.cpp rename wpimath/src/main/native/cpp/estimator/{UnscentedKalmanFilter.cpp => S3UKF.cpp} (71%) create mode 100644 wpimath/src/main/native/include/frc/estimator/MerweUKF.h create mode 100644 wpimath/src/main/native/include/frc/estimator/S3SigmaPoints.h create mode 100644 wpimath/src/main/native/include/frc/estimator/S3UKF.h create mode 100644 wpimath/src/main/native/include/frc/estimator/SigmaPoints.h rename wpimath/src/test/java/edu/wpi/first/math/estimator/{UnscentedKalmanFilterTest.java => MerweUKFTest.java} (91%) create mode 100644 wpimath/src/test/java/edu/wpi/first/math/estimator/S3SigmaPointsTest.java create mode 100644 wpimath/src/test/java/edu/wpi/first/math/estimator/S3UKFTest.java rename wpimath/src/test/native/cpp/estimator/{UnscentedKalmanFilterTest.cpp => MerweUKFTest.cpp} (83%) create mode 100644 wpimath/src/test/native/cpp/estimator/S3SigmaPointsTest.cpp create mode 100644 wpimath/src/test/native/cpp/estimator/S3UKFTest.cpp diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java index d3131a83e9..d159af5d19 100644 --- a/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java +++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java @@ -15,7 +15,8 @@ import org.ejml.simple.SimpleMatrix; * UnscentedKalmanFilter class. * *

It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in - * most publications. Unless you know better, this should be your default choice. + * most publications. S3SigmaPoints is generally preferred due to its greater performance with + * nearly identical accuracy. * *

States is the dimensionality of the state. 2*States+1 weights will be generated. * @@ -24,7 +25,7 @@ import org.ejml.simple.SimpleMatrix; * * @param The dimensionality of the state. 2 * States + 1 weights will be generated. */ -public class MerweScaledSigmaPoints { +public class MerweScaledSigmaPoints implements SigmaPoints { private final double m_alpha; private final int m_kappa; private final Nat m_states; @@ -64,6 +65,7 @@ public class MerweScaledSigmaPoints { * * @return The number of sigma points for each variable in the state x. */ + @Override public int getNumSigmas() { return 2 * m_states.getNum() + 1; } @@ -77,6 +79,7 @@ public class MerweScaledSigmaPoints { * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one * dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}. */ + @Override public Matrix squareRootSigmaPoints(Matrix x, Matrix s) { double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum(); double eta = Math.sqrt(lambda + m_states.getNum()); @@ -125,6 +128,7 @@ public class MerweScaledSigmaPoints { * * @return the weight for each sigma point for the mean. */ + @Override public Matrix getWm() { return m_wm; } @@ -135,6 +139,7 @@ public class MerweScaledSigmaPoints { * @param element Element of vector to return. * @return the element i's weight for the mean. */ + @Override public double getWm(int element) { return m_wm.get(element, 0); } @@ -144,6 +149,7 @@ public class MerweScaledSigmaPoints { * * @return the weight for each sigma point for the covariance. */ + @Override public Matrix getWc() { return m_wc; } @@ -154,6 +160,7 @@ public class MerweScaledSigmaPoints { * @param element Element of vector to return. * @return The element I's weight for the covariance. */ + @Override public double getWc(int element) { return m_wc.get(element, 0); } diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweUKF.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweUKF.java new file mode 100644 index 0000000000..d34ebe1ff9 --- /dev/null +++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/MerweUKF.java @@ -0,0 +1,117 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.math.estimator; + +import edu.wpi.first.math.Matrix; +import edu.wpi.first.math.Nat; +import edu.wpi.first.math.Num; +import edu.wpi.first.math.numbers.N1; +import java.util.function.BiFunction; + +/** + * An Unscented Kalman Filter using sigma points and weights from Van der Merwe's 2004 dissertation. + * S3UKF is generally preferred due to its greater performance with nearly identical accuracy. + * + * @param Number of states. + * @param Number of inputs. + * @param Number of outputs. + */ +public class MerweUKF + extends UnscentedKalmanFilter { + /** + * Constructs a Merwe Unscented Kalman Filter. + * + *

See https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices + * for how to select the standard deviations. + * + * @param states A Nat representing the number of states. + * @param outputs A Nat representing the number of outputs. + * @param f A vector-valued function of x and u that returns the derivative of the state vector. + * @param h A vector-valued function of x and u that returns the measurement vector. + * @param stateStdDevs Standard deviations of model states. + * @param measurementStdDevs Standard deviations of measurements. + * @param nominalDt Nominal discretization timestep in seconds. + */ + public MerweUKF( + Nat states, + Nat outputs, + BiFunction, Matrix, Matrix> f, + BiFunction, Matrix, Matrix> h, + Matrix stateStdDevs, + Matrix measurementStdDevs, + double nominalDt) { + super( + new MerweScaledSigmaPoints<>(states), + states, + outputs, + f, + h, + stateStdDevs, + measurementStdDevs, + (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), + (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), + Matrix::minus, + Matrix::minus, + Matrix::plus, + nominalDt); + } + + /** + * Constructs a Merwe Unscented Kalman filter with custom mean, residual, and addition functions. + * Using custom functions for arithmetic can be useful if you have angles in the state or + * measurements, because they allow you to correctly account for the modular nature of angle + * arithmetic. + * + *

See https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices + * for how to select the standard deviations. + * + * @param states A Nat representing the number of states. + * @param outputs A Nat representing the number of outputs. + * @param f A vector-valued function of x and u that returns the derivative of the state vector. + * @param h A vector-valued function of x and u that returns the measurement vector. + * @param stateStdDevs Standard deviations of model states. + * @param measurementStdDevs Standard deviations of measurements. + * @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set + * of weights. + * @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a + * given set of weights. + * @param residualFuncX A function that computes the residual of two state vectors (i.e. it + * subtracts them.) + * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it + * subtracts them.) + * @param addFuncX A function that adds two state vectors. + * @param nominalDt Nominal discretization timestep in seconds. + */ + public MerweUKF( + Nat states, + Nat outputs, + BiFunction, Matrix, Matrix> f, + BiFunction, Matrix, Matrix> h, + Matrix stateStdDevs, + Matrix measurementStdDevs, + BiFunction, Matrix, Matrix> meanFuncX, + BiFunction, Matrix, Matrix> meanFuncY, + BiFunction, Matrix, Matrix> residualFuncX, + BiFunction, Matrix, Matrix> residualFuncY, + BiFunction, Matrix, Matrix> addFuncX, + double nominalDt) { + super( + new MerweScaledSigmaPoints<>(states), + states, + outputs, + f, + h, + stateStdDevs, + measurementStdDevs, + meanFuncX, + meanFuncY, + residualFuncX, + residualFuncY, + addFuncX, + nominalDt); + } +} diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/S3SigmaPoints.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/S3SigmaPoints.java new file mode 100644 index 0000000000..430296e109 --- /dev/null +++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/S3SigmaPoints.java @@ -0,0 +1,172 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.math.estimator; + +import edu.wpi.first.math.Matrix; +import edu.wpi.first.math.Nat; +import edu.wpi.first.math.Num; +import edu.wpi.first.math.numbers.N1; +import org.ejml.simple.SimpleMatrix; + +/** + * Generates sigma points and weights according to Papakonstantinou's paper[1] for the + * UnscentedKalmanFilter class. + * + *

It parameterizes the sigma points using alpha and beta terms. Unless you know better, this + * should be your default choice due to its high accuracy and performance. + * + *

[1] K. Papakonstantinou "A Scaled Spherical Simplex Filter (S3F) with a decreased n + 2 sigma + * points set size and equivalent 2n + 1 Unscented Kalman Filter (UKF) accuracy" + * + * @param The dimenstionality of the state. States + 2 weights will be generated. + */ +public class S3SigmaPoints implements SigmaPoints { + private final Nat m_states; + private final double m_alpha; + private Matrix m_wm; + private Matrix m_wc; + + /** + * Constructs a generator for Papakonstantinou sigma points. + * + * @param states an instance of Num that represents the number of states. + * @param alpha Determines the spread of the sigma points around the mean. Usually a small + * positive value (1e-3). + * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian + * distributions, beta = 2 is optimal. + */ + @SuppressWarnings("this-escape") + public S3SigmaPoints(Nat states, double alpha, double beta) { + m_states = states; + m_alpha = alpha; + + computeWeights(beta); + } + + /** + * Constructs a generator for Papakonstantinou sigma points with default values for alpha, beta, + * and kappa. + * + * @param states an instance of Num that represents the number of states. + */ + public S3SigmaPoints(Nat states) { + this(states, 1e-3, 2); + } + + /** + * Returns number of sigma points for each variable in the state x. + * + * @return The number of sigma points for each variable in the state x. + */ + @Override + public int getNumSigmas() { + return m_states.getNum() + 2; + } + + /** + * Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root + * covariance (s) of the filter. + * + * @param x An array of the means. + * @param s Square-root covariance of the filter. + * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one + * dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}. + */ + @Override + public Matrix squareRootSigmaPoints(Matrix x, Matrix s) { + // table (1), equation (12) + double[] q = new double[m_states.getNum()]; + for (int t = 1; t <= m_states.getNum(); ++t) { + q[t - 1] = m_alpha * Math.sqrt(t * (m_states.getNum() + 1) / (double) (t + 1)); + } + + Matrix C = new Matrix<>(new SimpleMatrix(m_states.getNum(), getNumSigmas())); + for (int row = 0; row < m_states.getNum(); ++row) { + C.set(row, 0, 0.0); + } + for (int col = 1; col < getNumSigmas(); ++col) { + for (int row = 0; row < m_states.getNum(); ++row) { + if (row < col - 2) { + C.set(row, col, 0.0); + } else if (row == col - 2) { + C.set(row, col, q[row]); + } else { + C.set(row, col, -q[row] / (row + 1)); + } + } + } + + Matrix sigmas = new Matrix<>(new SimpleMatrix(m_states.getNum(), getNumSigmas())); + for (int col = 0; col < getNumSigmas(); ++col) { + sigmas.setColumn(col, x.plus(s.times(C.extractColumnVector(col)))); + } + + return sigmas; + } + + /** + * Computes the weights for the scaled unscented Kalman filter. + * + * @param beta Incorporates prior knowledge of the distribution of the mean. + */ + private void computeWeights(double beta) { + double alpha_sq = m_alpha * m_alpha; + + double c = 1.0 / (alpha_sq * (m_states.getNum() + 1)); + + Matrix wM = new Matrix<>(new SimpleMatrix(getNumSigmas(), 1)); + Matrix wC = new Matrix<>(new SimpleMatrix(getNumSigmas(), 1)); + wM.fill(c); + wC.fill(c); + + wM.set(0, 0, 1.0 - 1.0 / alpha_sq); + wC.set(0, 0, 1.0 - 1.0 / alpha_sq + (1 - alpha_sq + beta)); + + this.m_wm = wM; + this.m_wc = wC; + } + + /** + * Returns the weight for each sigma point for the mean. + * + * @return the weight for each sigma point for the mean. + */ + @Override + public Matrix getWm() { + return m_wm; + } + + /** + * Returns an element of the weight for each sigma point for the mean. + * + * @param element Element of vector to return. + * @return the element i's weight for the mean. + */ + @Override + public double getWm(int element) { + return m_wm.get(element, 0); + } + + /** + * Returns the weight for each sigma point for the covariance. + * + * @return the weight for each sigma point for the covariance. + */ + @Override + public Matrix getWc() { + return m_wc; + } + + /** + * Returns an element of the weight for each sigma point for the covariance. + * + * @param element Element of vector to return. + * @return The element I's weight for the covariance. + */ + @Override + public double getWc(int element) { + return m_wc.get(element, 0); + } +} diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/S3UKF.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/S3UKF.java new file mode 100644 index 0000000000..4de46faee3 --- /dev/null +++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/S3UKF.java @@ -0,0 +1,117 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.math.estimator; + +import edu.wpi.first.math.Matrix; +import edu.wpi.first.math.Nat; +import edu.wpi.first.math.Num; +import edu.wpi.first.math.numbers.N1; +import java.util.function.BiFunction; + +/** + * An Unscented Kalman Filter using sigma points and weights from Papakonstantinou's paper. This is + * generally preferred over other UKF variants due to its high accuracy and performance. + * + * @param Number of states. + * @param Number of inputs. + * @param Number of outputs. + */ +public class S3UKF + extends UnscentedKalmanFilter { + /** + * Constructs a Scaled Spherical Simplex (S3) Unscented Kalman Filter. + * + *

See https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices + * for how to select the standard deviations. + * + * @param states A Nat representing the number of states. + * @param outputs A Nat representing the number of outputs. + * @param f A vector-valued function of x and u that returns the derivative of the state vector. + * @param h A vector-valued function of x and u that returns the measurement vector. + * @param stateStdDevs Standard deviations of model states. + * @param measurementStdDevs Standard deviations of measurements. + * @param nominalDt Nominal discretization timestep in seconds. + */ + public S3UKF( + Nat states, + Nat outputs, + BiFunction, Matrix, Matrix> f, + BiFunction, Matrix, Matrix> h, + Matrix stateStdDevs, + Matrix measurementStdDevs, + double nominalDt) { + super( + new S3SigmaPoints<>(states), + states, + outputs, + f, + h, + stateStdDevs, + measurementStdDevs, + (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), + (sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)), + Matrix::minus, + Matrix::minus, + Matrix::plus, + nominalDt); + } + + /** + * Constructs a Scaled Spherical Simplex (S3) Unscented Kalman filter with custom mean, residual, + * and addition functions. Using custom functions for arithmetic can be useful if you have angles + * in the state or measurements, because they allow you to correctly account for the modular + * nature of angle arithmetic. + * + *

See https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices + * for how to select the standard deviations. + * + * @param states A Nat representing the number of states. + * @param outputs A Nat representing the number of outputs. + * @param f A vector-valued function of x and u that returns the derivative of the state vector. + * @param h A vector-valued function of x and u that returns the measurement vector. + * @param stateStdDevs Standard deviations of model states. + * @param measurementStdDevs Standard deviations of measurements. + * @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set + * of weights. + * @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a + * given set of weights. + * @param residualFuncX A function that computes the residual of two state vectors (i.e. it + * subtracts them.) + * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it + * subtracts them.) + * @param addFuncX A function that adds two state vectors. + * @param nominalDt Nominal discretization timestep in seconds. + */ + public S3UKF( + Nat states, + Nat outputs, + BiFunction, Matrix, Matrix> f, + BiFunction, Matrix, Matrix> h, + Matrix stateStdDevs, + Matrix measurementStdDevs, + BiFunction, Matrix, Matrix> meanFuncX, + BiFunction, Matrix, Matrix> meanFuncY, + BiFunction, Matrix, Matrix> residualFuncX, + BiFunction, Matrix, Matrix> residualFuncY, + BiFunction, Matrix, Matrix> addFuncX, + double nominalDt) { + super( + new S3SigmaPoints<>(states), + states, + outputs, + f, + h, + stateStdDevs, + measurementStdDevs, + meanFuncX, + meanFuncY, + residualFuncX, + residualFuncY, + addFuncX, + nominalDt); + } +} diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/SigmaPoints.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/SigmaPoints.java new file mode 100644 index 0000000000..cb0ca2334b --- /dev/null +++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/SigmaPoints.java @@ -0,0 +1,64 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.math.estimator; + +import edu.wpi.first.math.Matrix; +import edu.wpi.first.math.Num; +import edu.wpi.first.math.numbers.N1; + +/** + * A sigma points generator for the UnscentedKalmanFilter class. + * + * @param The dimensionality of the state. + */ +public interface SigmaPoints { + /** + * Returns number of sigma points for each variable in the state x. + * + * @return The number of sigma points for each variable in the state x. + */ + int getNumSigmas(); + + /** + * Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root + * covariance (s) of the filter. + * + * @param x An array of the means. + * @param s Square-root covariance of the filter. + * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one + * dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}. + */ + Matrix squareRootSigmaPoints(Matrix x, Matrix s); + + /** + * Returns the weight for each sigma point for the mean. + * + * @return the weight for each sigma point for the mean. + */ + Matrix getWm(); + + /** + * Returns an element of the weight for each sigma point for the mean. + * + * @param element Element of vector to return. + * @return the element i's weight for the mean. + */ + double getWm(int element); + + /** + * Returns the weight for each sigma point for the covariance. + * + * @return the weight for each sigma point for the covariance. + */ + Matrix getWc(); + + /** + * Returns an element of the weight for each sigma point for the covariance. + * + * @param element Element of vector to return. + * @return The element I's weight for the covariance. + */ + double getWc(int element); +} diff --git a/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java b/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java index 0d402c7af9..455cd67d46 100644 --- a/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java +++ b/wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java @@ -22,6 +22,11 @@ import org.ejml.simple.SimpleMatrix; * true system state. This is useful because many states cannot be measured directly as a result of * sensor noise, or because the state is "hidden". * + *

This class's constructors require a SigmaPoints generator. For convenience, {@link S3UKF} and + * {@link MerweUKF} subclasses are provided to create a suitable generator for you. S3UKF is + * generally preferred over MerweUKF because of its greater performance while maintaining nearly + * identical accuracy. + * *

Kalman filters use a K gain matrix to determine whether to trust the model or measurements * more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum * of squares error in the state estimate. This K gain is used to correct the state estimate by some @@ -64,7 +69,7 @@ public class UnscentedKalmanFilter m_sigmasF; private double m_dt; - private final MerweScaledSigmaPoints m_pts; + private final SigmaPoints m_pts; /** * Constructs an Unscented Kalman Filter. @@ -73,6 +78,7 @@ public class UnscentedKalmanFilterhttps://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices * for how to select the standard deviations. * + * @param pts A sigma points and weights generator. * @param states A Nat representing the number of states. * @param outputs A Nat representing the number of outputs. * @param f A vector-valued function of x and u that returns the derivative of the state vector. @@ -82,6 +88,7 @@ public class UnscentedKalmanFilter pts, Nat states, Nat outputs, BiFunction, Matrix, Matrix> f, @@ -90,6 +97,7 @@ public class UnscentedKalmanFilter measurementStdDevs, double nominalDt) { this( + pts, states, outputs, f, @@ -113,16 +121,17 @@ public class UnscentedKalmanFilterhttps://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices * for how to select the standard deviations. * + * @param pts A sigma points and weights generator. * @param states A Nat representing the number of states. * @param outputs A Nat representing the number of outputs. * @param f A vector-valued function of x and u that returns the derivative of the state vector. * @param h A vector-valued function of x and u that returns the measurement vector. * @param stateStdDevs Standard deviations of model states. * @param measurementStdDevs Standard deviations of measurements. - * @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a + * @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set + * of weights. + * @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a * given set of weights. - * @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using - * a given set of weights. * @param residualFuncX A function that computes the residual of two state vectors (i.e. it * subtracts them.) * @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it @@ -131,6 +140,7 @@ public class UnscentedKalmanFilter pts, Nat states, Nat outputs, BiFunction, Matrix, Matrix> f, @@ -160,37 +170,36 @@ public class UnscentedKalmanFilter(states); + m_pts = pts; reset(); } - static - Pair, Matrix> squareRootUnscentedTransform( - Nat s, - Nat dim, - Matrix sigmas, - Matrix Wm, - Matrix Wc, - BiFunction, Matrix, Matrix> meanFunc, - BiFunction, Matrix, Matrix> residualFunc, - Matrix squareRootR) { - if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) { + static Pair, Matrix> squareRootUnscentedTransform( + Nat covdim, + int numSigmas, + Matrix sigmas, + Matrix Wm, + Matrix Wc, + BiFunction, Matrix, Matrix> meanFunc, + BiFunction, Matrix, Matrix> residualFunc, + Matrix squareRootR) { + if (sigmas.getNumRows() != covdim.getNum() || sigmas.getNumCols() != numSigmas) { throw new IllegalArgumentException( - "Sigmas must be covDim by 2 * states + 1! Got " + "Sigmas must be covDim by numSigmas! Got " + sigmas.getNumRows() + " by " + sigmas.getNumCols()); } - if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) { + if (Wm.getNumRows() != numSigmas || Wm.getNumCols() != 1) { throw new IllegalArgumentException( - "Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols()); + "Wm must be numSigmas by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols()); } - if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) { + if (Wc.getNumRows() != numSigmas || Wc.getNumCols() != 1) { throw new IllegalArgumentException( - "Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols()); + "Wc must be numSigmas by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols()); } // New mean is usually just the sum of the sigmas * weights: @@ -208,13 +217,18 @@ public class UnscentedKalmanFilter Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum())); - for (int i = 0; i < 2 * s.getNum(); i++) { + // + // Note that we allow a custom function instead of the difference to allow + // angle wrapping. Furthermore, we allow an arbitrary number of sigma points to + // support similar methods such as the Scaled Spherical Simplex Filter (S3F). + Matrix Sbar = + new Matrix<>(new SimpleMatrix(covdim.getNum(), numSigmas - 1 + covdim.getNum())); + for (int i = 0; i < numSigmas - 1; i++) { Sbar.setColumn( i, residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0)))); } - Sbar.assignBlock(0, 2 * s.getNum(), squareRootR); + Sbar.assignBlock(0, numSigmas - 1, squareRootR); QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM(); var qrStorage = Sbar.transpose().getStorage(); @@ -355,7 +369,7 @@ public class UnscentedKalmanFilter(m_states, Nat.N1()); m_S = new Matrix<>(m_states, m_states); - m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1)); + m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), m_pts.getNumSigmas())); } /** @@ -397,7 +411,7 @@ public class UnscentedKalmanFilter sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1)); + Matrix sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), m_pts.getNumSigmas())); var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S); for (int i = 0; i < m_pts.getNumSigmas(); i++) { Matrix hRet = h.apply(sigmas.extractColumnVector(i), u); @@ -521,8 +535,8 @@ public class UnscentedKalmanFilter>; +template class EXPORT_TEMPLATE_DEFINE(WPILIB_DLLEXPORT) + UnscentedKalmanFilter<5, 3, 3, MerweScaledSigmaPoints<5>>; + +} // namespace frc diff --git a/wpimath/src/main/native/cpp/estimator/UnscentedKalmanFilter.cpp b/wpimath/src/main/native/cpp/estimator/S3UKF.cpp similarity index 71% rename from wpimath/src/main/native/cpp/estimator/UnscentedKalmanFilter.cpp rename to wpimath/src/main/native/cpp/estimator/S3UKF.cpp index d5e869b6cb..d0302d7ca9 100644 --- a/wpimath/src/main/native/cpp/estimator/UnscentedKalmanFilter.cpp +++ b/wpimath/src/main/native/cpp/estimator/S3UKF.cpp @@ -2,13 +2,13 @@ // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. -#include "frc/estimator/UnscentedKalmanFilter.h" +#include "frc/estimator/S3UKF.h" namespace frc { template class EXPORT_TEMPLATE_DEFINE(WPILIB_DLLEXPORT) - UnscentedKalmanFilter<3, 3, 1>; + UnscentedKalmanFilter<3, 3, 1, S3SigmaPoints<3>>; template class EXPORT_TEMPLATE_DEFINE(WPILIB_DLLEXPORT) - UnscentedKalmanFilter<5, 3, 3>; + UnscentedKalmanFilter<5, 3, 3, S3SigmaPoints<5>>; } // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h b/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h index 2605d5e03e..1b5b5cfa1f 100644 --- a/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h +++ b/wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h @@ -15,8 +15,8 @@ namespace frc { * dissertation[1] for the UnscentedKalmanFilter class. * * It parametrizes the sigma points using alpha, beta, kappa terms, and is the - * version seen in most publications. Unless you know better, this should be - * your default choice. + * version seen in most publications. S3SigmaPoints is generally preferred due + * to its greater performance with nearly identical accuracy. * * [1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilistic * Inference in Dynamic State-Space Models" (Doctoral dissertation) @@ -27,6 +27,8 @@ namespace frc { template class MerweScaledSigmaPoints { public: + static constexpr int NumSigmas = 2 * States + 1; + /** * Constructs a generator for Van der Merwe scaled sigma points. * @@ -44,11 +46,6 @@ class MerweScaledSigmaPoints { ComputeWeights(beta); } - /** - * Returns number of sigma points for each variable in the state x. - */ - int NumSigmas() { return 2 * States + 1; } - /** * Computes the sigma points for an unscented Kalman filter given the mean * (x) and square-root covariance (S) of the filter. diff --git a/wpimath/src/main/native/include/frc/estimator/MerweUKF.h b/wpimath/src/main/native/include/frc/estimator/MerweUKF.h new file mode 100644 index 0000000000..1b5165aeb4 --- /dev/null +++ b/wpimath/src/main/native/include/frc/estimator/MerweUKF.h @@ -0,0 +1,25 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +#include "frc/estimator/MerweScaledSigmaPoints.h" +#include "frc/estimator/UnscentedKalmanFilter.h" + +namespace frc { + +template +using MerweUKF = UnscentedKalmanFilter>; + +// Because MerweUKF is a type alias and not a class, we have to use +// UnscentedKalmanFilter instead +extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) + UnscentedKalmanFilter<3, 3, 1, MerweScaledSigmaPoints<3>>; +extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) + UnscentedKalmanFilter<5, 3, 3, MerweScaledSigmaPoints<5>>; + +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/S3SigmaPoints.h b/wpimath/src/main/native/include/frc/estimator/S3SigmaPoints.h new file mode 100644 index 0000000000..a19cd072ac --- /dev/null +++ b/wpimath/src/main/native/include/frc/estimator/S3SigmaPoints.h @@ -0,0 +1,135 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +#include "frc/EigenCore.h" +#include "frc/estimator/SigmaPoints.h" + +namespace frc { + +/** + * Generates sigma points and weights according to Papakonstantinou's paper[1] + * for the UnscentedKalmanFilter class. + * + * It parameterizes the sigma points using alpha and beta terms. Unless you know + * better, this should be your default choice due to its high accuracy and + * performance. + * + * [1] K. Papakonstantinou "A Scaled Spherical Simplex Filter (S3F) with a + * decreased n + 2 sigma points set size and equivalent 2n + 1 Unscented Kalman + * Filter (UKF) accuracy" + * + * @tparam States The dimenstionality of the state. States + 2 weights will be + * generated. + */ +template +class S3SigmaPoints { + public: + static constexpr int NumSigmas = States + 2; + + /** + * Constructs a generator for Papakonstantinou sigma points. + * + * @param alpha Determines the spread of the sigma points around the mean. + * Usually a small positive value (1e-3). + * @param beta Incorporates prior knowledge of the distribution of the mean. + * For Gaussian distributions, beta = 2 is optimal. + */ + explicit S3SigmaPoints(double alpha = 1e-3, double beta = 2) + : m_alpha{alpha} { + ComputeWeights(beta); + } + + /** + * Computes the sigma points for an unscented Kalman filter given the mean (x) + * and square-root covariance (S) of the filter. + * + * @param x An array of the means. + * @param S Square-root covariance of the filter. + * + * @return Two dimensional array of sigma points. Each column contains all of + * the sigmas for one dimension in the problem space. Ordered by Xi_0, + * Xi_{1..n+1}. + */ + Matrixd SquareRootSigmaPoints( + const Vectord& x, const Matrixd& S) const { + // table (1), equation (12) + wpi::array q(wpi::empty_array); + for (size_t t = 1; t <= States; ++t) { + q[t - 1] = m_alpha * std::sqrt(static_cast(t * (States + 1)) / + static_cast(t + 1)); + } + + Matrixd C; + C.template block(0, 0) = Vectord::Constant(0.0); + for (int col = 1; col < NumSigmas; ++col) { + for (int row = 0; row < States; ++row) { + if (row < col - 2) { + C(row, col) = 0.0; + } else if (row == col - 2) { + C(row, col) = q[row]; + } else { + C(row, col) = -q[row] / (row + 1); + } + } + } + + Matrixd sigmas; + for (int col = 0; col < NumSigmas; ++col) { + sigmas.col(col) = x + S * C.col(col); + } + + return sigmas; + } + + /** + * Returns the weight for each sigma point for the mean. + */ + const Vectord& Wm() const { return m_Wm; } + + /** + * Returns an element of the weight for each sigma point for the mean. + * + * @param i Element of vector to return. + */ + double Wm(int i) const { return m_Wm(i, 0); } + + /** + * Returns the weight for each sigma point for the covariance. + */ + const Vectord& Wc() const { return m_Wc; } + + /** + * Returns an element of the weight for each sigma point for the covariance. + * + * @param i Element of vector to return. + */ + double Wc(int i) const { return m_Wc(i, 0); } + + private: + Vectord m_Wm; + Vectord m_Wc; + double m_alpha; + + /** + * Computes the weights for the scaled unscented Kalman filter. + * + * @param beta Incorporates prior knowledge of the distribution of the mean. + */ + void ComputeWeights(double beta) { + double alpha_sq = m_alpha * m_alpha; + + double c = 1.0 / (alpha_sq * (States + 1)); + m_Wm = Vectord::Constant(c); + m_Wc = Vectord::Constant(c); + + m_Wm(0) = 1.0 - 1.0 / alpha_sq; + m_Wc(0) = 1.0 - 1.0 / alpha_sq + (1 - alpha_sq + beta); + } +}; + +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/S3UKF.h b/wpimath/src/main/native/include/frc/estimator/S3UKF.h new file mode 100644 index 0000000000..e94f2af0c6 --- /dev/null +++ b/wpimath/src/main/native/include/frc/estimator/S3UKF.h @@ -0,0 +1,25 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +#include "frc/estimator/S3SigmaPoints.h" +#include "frc/estimator/UnscentedKalmanFilter.h" + +namespace frc { + +template +using S3UKF = + UnscentedKalmanFilter>; + +// Because S3UKF is a type alias and not a class, we have to use +// UnscentedKalmanFilter instead +extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) + UnscentedKalmanFilter<3, 3, 1, S3SigmaPoints<3>>; +extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) + UnscentedKalmanFilter<5, 3, 3, S3SigmaPoints<5>>; + +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/SigmaPoints.h b/wpimath/src/main/native/include/frc/estimator/SigmaPoints.h new file mode 100644 index 0000000000..a459197c75 --- /dev/null +++ b/wpimath/src/main/native/include/frc/estimator/SigmaPoints.h @@ -0,0 +1,26 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +#include "frc/EigenCore.h" + +namespace frc { + +template +concept SigmaPoints = + requires(T t, Vectord x, Matrixd S, int i) { + { T::NumSigmas } -> std::convertible_to; + { + t.SquareRootSigmaPoints(x, S) + } -> std::same_as>; + { t.Wm() } -> std::convertible_to>; + { t.Wm(i) } -> std::same_as; + { t.Wc() } -> std::convertible_to>; + { t.Wc(i) } -> std::same_as; + } && std::default_initializable; + +} // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h index 21838870a8..2d2ed82c62 100644 --- a/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/UnscentedKalmanFilter.h @@ -13,7 +13,7 @@ #include "frc/EigenCore.h" #include "frc/StateSpaceUtil.h" -#include "frc/estimator/MerweScaledSigmaPoints.h" +#include "frc/estimator/SigmaPoints.h" #include "frc/estimator/UnscentedTransform.h" #include "frc/system/Discretization.h" #include "frc/system/NumericalIntegration.h" @@ -28,6 +28,11 @@ namespace frc { * be measured directly as a result of sensor noise, or because the state is * "hidden". * + * This class requires a SigmaPoints template parameter. For convenience, S3UKF + * and MerweUKF type aliases are provided to specify a suitable generator for + * you. S3UKF is generally preferred over MerweUKF because of its greater + * performance while maintaining nearly identical accuracy. + * * Kalman filters use a K gain matrix to determine whether to trust the model or * measurements more. Kalman filter theory uses statistics to compute an optimal * K gain which minimizes the sum of squares error in the state estimate. This K @@ -51,10 +56,13 @@ namespace frc { * @tparam States Number of states. * @tparam Inputs Number of inputs. * @tparam Outputs Number of outputs. + * @tparam SigmaPoints Type used to generate sigma points and weights. */ -template +template SigmaPoints> class UnscentedKalmanFilter { public: + static constexpr int NumSigmas = SigmaPoints::NumSigmas; + using StateVector = Vectord; using InputVector = Vectord; using OutputVector = Vectord; @@ -87,12 +95,12 @@ class UnscentedKalmanFilter { : m_f(std::move(f)), m_h(std::move(h)) { m_contQ = MakeCovMatrix(stateStdDevs); m_contR = MakeCovMatrix(measurementStdDevs); - m_meanFuncX = [](const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wm) -> StateVector { + m_meanFuncX = [](const Matrixd& sigmas, + const Vectord& Wm) -> StateVector { return sigmas * Wm; }; - m_meanFuncY = [](const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wc) -> OutputVector { + m_meanFuncY = [](const Matrixd& sigmas, + const Vectord& Wc) -> OutputVector { return sigmas * Wc; }; m_residualFuncX = [](const StateVector& a, @@ -141,11 +149,11 @@ class UnscentedKalmanFilter { std::function f, std::function h, const StateArray& stateStdDevs, const OutputArray& measurementStdDevs, - std::function&, - const Vectord<2 * States + 1>&)> + std::function&, + const Vectord&)> meanFuncX, - std::function&, - const Vectord<2 * States + 1>&)> + std::function&, + const Vectord&)> meanFuncY, std::function residualFuncX, @@ -257,7 +265,7 @@ class UnscentedKalmanFilter { // Generate sigma points around the state mean // // equation (17) - Matrixd sigmas = + Matrixd sigmas = m_pts.SquareRootSigmaPoints(m_xHat, m_S); // Project each sigma point forward in time according to the @@ -267,7 +275,7 @@ class UnscentedKalmanFilter { // sigmasF = 𝒳ₖ,ₖ₋₁ or just 𝒳 for readability // // equation (18) - for (int i = 0; i < m_pts.NumSigmas(); ++i) { + for (int i = 0; i < NumSigmas; ++i) { StateVector x = sigmas.template block(0, i); m_sigmasF.template block(0, i) = RK4(m_f, x, u, dt); } @@ -276,7 +284,7 @@ class UnscentedKalmanFilter { // to compute the prior state mean and covariance // // equations (18) (19) and (20) - auto [xHat, S] = SquareRootUnscentedTransform( + auto [xHat, S] = SquareRootUnscentedTransform( m_sigmasF, m_pts.Wm(), m_pts.Wc(), m_meanFuncX, m_residualFuncX, discQ.template triangularView()); m_xHat = xHat; @@ -327,8 +335,8 @@ class UnscentedKalmanFilter { const InputVector& u, const Vectord& y, std::function(const StateVector&, const InputVector&)> h, const Matrixd& R) { - auto meanFuncY = [](const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wc) -> Vectord { + auto meanFuncY = [](const Matrixd& sigmas, + const Vectord& Wc) -> Vectord { return sigmas * Wc; }; auto residualFuncX = [](const StateVector& a, @@ -358,7 +366,7 @@ class UnscentedKalmanFilter { * @param h A vector-valued function of x and u that returns the * measurement vector. * @param R Continuous measurement noise covariance matrix. - * @param meanFuncY A function that computes the mean of 2 * States + 1 + * @param meanFuncY A function that computes the mean of NumSigmas * measurement vectors using a given set of weights. * @param residualFuncY A function that computes the residual of two * measurement vectors (i.e. it subtracts them.) @@ -371,8 +379,8 @@ class UnscentedKalmanFilter { const InputVector& u, const Vectord& y, std::function(const StateVector&, const InputVector&)> h, const Matrixd& R, - std::function(const Matrixd&, - const Vectord<2 * States + 1>&)> + std::function(const Matrixd&, + const Vectord&)> meanFuncY, std::function(const Vectord&, const Vectord&)> residualFuncY, @@ -392,10 +400,10 @@ class UnscentedKalmanFilter { // This differs from equation (22) which uses // the prior sigma points, regenerating them allows // multiple measurement updates per time update - Matrixd sigmasH; - Matrixd sigmas = + Matrixd sigmasH; + Matrixd sigmas = m_pts.SquareRootSigmaPoints(m_xHat, m_S); - for (int i = 0; i < m_pts.NumSigmas(); ++i) { + for (int i = 0; i < NumSigmas; ++i) { sigmasH.template block(0, i) = h(sigmas.template block(0, i), u); } @@ -405,7 +413,7 @@ class UnscentedKalmanFilter { // covariance. // // equations (23) (24) and (25) - auto [yHat, Sy] = SquareRootUnscentedTransform( + auto [yHat, Sy] = SquareRootUnscentedTransform( sigmasH, m_pts.Wm(), m_pts.Wc(), meanFuncY, residualFuncY, discR.template triangularView()); @@ -419,7 +427,7 @@ class UnscentedKalmanFilter { // equation (26) Matrixd Pxy; Pxy.setZero(); - for (int i = 0; i < m_pts.NumSigmas(); ++i) { + for (int i = 0; i < NumSigmas; ++i) { Pxy += m_pts.Wc(i) * (residualFuncX(m_sigmasF.template block(0, i), m_xHat)) * @@ -467,11 +475,11 @@ class UnscentedKalmanFilter { private: std::function m_f; std::function m_h; - std::function&, - const Vectord<2 * States + 1>&)> + std::function&, + const Vectord&)> m_meanFuncX; - std::function&, - const Vectord<2 * States + 1>&)> + std::function&, + const Vectord&)> m_meanFuncY; std::function m_residualFuncX; @@ -482,15 +490,10 @@ class UnscentedKalmanFilter { StateMatrix m_S; StateMatrix m_contQ; Matrixd m_contR; - Matrixd m_sigmasF; + Matrixd m_sigmasF; units::second_t m_dt; - MerweScaledSigmaPoints m_pts; + SigmaPoints m_pts; }; -extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) - UnscentedKalmanFilter<3, 3, 1>; -extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) - UnscentedKalmanFilter<5, 3, 3>; - } // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h b/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h index 0002638d79..f8b4a469c8 100644 --- a/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h +++ b/wpimath/src/main/native/include/frc/estimator/UnscentedTransform.h @@ -22,11 +22,11 @@ namespace frc { * * @tparam CovDim Dimension of covariance of sigma points after passing * through the transform. - * @tparam States Number of states. + * @tparam NumSigmas Number of sigma points. * @param sigmas List of sigma points. * @param Wm Weights for the mean. * @param Wc Weights for the covariance. - * @param meanFunc A function that computes the mean of 2 * States + 1 state + * @param meanFunc A function that computes the mean of NumSigmas state * vectors using a given set of weights. * @param residualFunc A function that computes the residual of two state * vectors (i.e. it subtracts them.) @@ -35,13 +35,13 @@ namespace frc { * @return Tuple of x, mean of sigma points; S, square-root covariance of * sigmas. */ -template +template std::tuple, Matrixd> SquareRootUnscentedTransform( - const Matrixd& sigmas, - const Vectord<2 * States + 1>& Wm, const Vectord<2 * States + 1>& Wc, - std::function(const Matrixd&, - const Vectord<2 * States + 1>&)> + const Matrixd& sigmas, const Vectord& Wm, + const Vectord& Wc, + std::function(const Matrixd&, + const Vectord&)> meanFunc, std::function(const Vectord&, const Vectord&)> @@ -62,13 +62,18 @@ SquareRootUnscentedTransform( // [√{W₁⁽ᶜ⁾}(𝒳_{1:2L} - x̂) √{Rᵛ}] // // the part of equations (20) and (24) within the "qr{}" - Matrixd Sbar; - for (int i = 0; i < States * 2; i++) { + // + // Note that we allow a custom function instead of the difference to allow + // angle wrapping. Furthermore, we allow an arbitrary number of sigma points + // to support similar methods such as the Scaled Spherical Simplex Filter + // (S3F). + Matrixd Sbar; + for (int i = 0; i < NumSigmas - 1; i++) { Sbar.template block(0, i) = std::sqrt(Wc[1]) * residualFunc(sigmas.template block(0, 1 + i), x); } - Sbar.template block(0, States * 2) = squareRootR; + Sbar.template block(0, NumSigmas - 1) = squareRootR; // Compute the square-root covariance of the sigma points. // diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java index 29554b176d..0fc743ff2a 100644 --- a/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java +++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/ExtendedKalmanFilterTest.java @@ -114,7 +114,7 @@ class ExtendedKalmanFilterTest { List waypoints = List.of( new Pose2d(2.75, 22.521, Rotation2d.kZero), - new Pose2d(24.73, 19.68, Rotation2d.fromDegrees(5.846))); + new Pose2d(24.73, 19.68, Rotation2d.fromRadians(5.846))); var trajectory = TrajectoryGenerator.generateTrajectory(waypoints, new TrajectoryConfig(8.8, 0.1)); diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/UnscentedKalmanFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/MerweUKFTest.java similarity index 91% rename from wpimath/src/test/java/edu/wpi/first/math/estimator/UnscentedKalmanFilterTest.java rename to wpimath/src/test/java/edu/wpi/first/math/estimator/MerweUKFTest.java index aa723fcf52..70041ebac2 100644 --- a/wpimath/src/test/java/edu/wpi/first/math/estimator/UnscentedKalmanFilterTest.java +++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/MerweUKFTest.java @@ -33,7 +33,7 @@ import java.util.Collections; import java.util.List; import org.junit.jupiter.api.Test; -class UnscentedKalmanFilterTest { +class MerweUKFTest { private static Matrix driveDynamics(Matrix x, Matrix u) { var motors = DCMotor.getCIM(2); @@ -78,12 +78,12 @@ class UnscentedKalmanFilterTest { var dt = 0.005; assertDoesNotThrow( () -> { - UnscentedKalmanFilter observer = - new UnscentedKalmanFilter<>( + MerweUKF observer = + new MerweUKF<>( Nat.N5(), Nat.N3(), - UnscentedKalmanFilterTest::driveDynamics, - UnscentedKalmanFilterTest::driveLocalMeasurementModel, + MerweUKFTest::driveDynamics, + MerweUKFTest::driveLocalMeasurementModel, VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0), VecBuilder.fill(0.0001, 0.01, 0.01), AngleStatistics.angleMean(2), @@ -107,7 +107,7 @@ class UnscentedKalmanFilterTest { Nat.N5(), u, globalY, - UnscentedKalmanFilterTest::driveGlobalMeasurementModel, + MerweUKFTest::driveGlobalMeasurementModel, R, AngleStatistics.angleMean(2), AngleStatistics.angleResidual(2), @@ -121,12 +121,12 @@ class UnscentedKalmanFilterTest { final double dt = 0.005; final double rb = 0.8382 / 2.0; // Robot radius - UnscentedKalmanFilter observer = - new UnscentedKalmanFilter<>( + MerweUKF observer = + new MerweUKF<>( Nat.N5(), Nat.N3(), - UnscentedKalmanFilterTest::driveDynamics, - UnscentedKalmanFilterTest::driveLocalMeasurementModel, + MerweUKFTest::driveDynamics, + MerweUKFTest::driveLocalMeasurementModel, VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0), VecBuilder.fill(0.0001, 0.5, 0.5), AngleStatistics.angleMean(2), @@ -139,7 +139,7 @@ class UnscentedKalmanFilterTest { List waypoints = List.of( new Pose2d(2.75, 22.521, Rotation2d.kZero), - new Pose2d(24.73, 19.68, Rotation2d.fromDegrees(5.846))); + new Pose2d(24.73, 19.68, Rotation2d.fromRadians(5.846))); var trajectory = TrajectoryGenerator.generateTrajectory(waypoints, new TrajectoryConfig(8.8, 0.1)); @@ -150,7 +150,7 @@ class UnscentedKalmanFilterTest { NumericalJacobian.numericalJacobianU( Nat.N5(), Nat.N2(), - UnscentedKalmanFilterTest::driveDynamics, + MerweUKFTest::driveDynamics, new Matrix<>(Nat.N5(), Nat.N1()), new Matrix<>(Nat.N2(), Nat.N1())); @@ -190,8 +190,7 @@ class UnscentedKalmanFilterTest { observer.predict(u, dt); r = nextR; - trueXhat = - NumericalIntegration.rk4(UnscentedKalmanFilterTest::driveDynamics, trueXhat, u, dt); + trueXhat = NumericalIntegration.rk4(MerweUKFTest::driveDynamics, trueXhat, u, dt); } var localY = driveLocalMeasurementModel(trueXhat, u); @@ -205,7 +204,7 @@ class UnscentedKalmanFilterTest { Nat.N5(), u, globalY, - UnscentedKalmanFilterTest::driveGlobalMeasurementModel, + MerweUKFTest::driveGlobalMeasurementModel, R, AngleStatistics.angleMean(2), AngleStatistics.angleResidual(2), @@ -226,7 +225,7 @@ class UnscentedKalmanFilterTest { var dt = 0.020; var plant = LinearSystemId.identifyVelocitySystem(0.02, 0.006); var observer = - new UnscentedKalmanFilter<>( + new MerweUKF<>( Nat.N1(), Nat.N1(), (x, u) -> plant.getA().times(x).plus(plant.getB().times(u)), @@ -256,7 +255,7 @@ class UnscentedKalmanFilterTest { var dt = 0.005; var observer = - new UnscentedKalmanFilter<>( + new MerweUKF<>( Nat.N2(), Nat.N2(), (x, u) -> x, @@ -371,11 +370,11 @@ class UnscentedKalmanFilterTest { 0.0, 0.0, 10); var observer = - new UnscentedKalmanFilter( + new MerweUKF( Nat.N4(), Nat.N3(), - UnscentedKalmanFilterTest::motorDynamics, - UnscentedKalmanFilterTest::motorMeasurementModel, + MerweUKFTest::motorDynamics, + MerweUKFTest::motorMeasurementModel, VecBuilder.fill(0.1, 1.0, 1e-10, 1e-10), VecBuilder.fill(pos_stddev, vel_stddev, accel_stddev), dt); diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/S3SigmaPointsTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/S3SigmaPointsTest.java new file mode 100644 index 0000000000..b9c83cbfb8 --- /dev/null +++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/S3SigmaPointsTest.java @@ -0,0 +1,86 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.math.estimator; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import edu.wpi.first.math.MatBuilder; +import edu.wpi.first.math.Matrix; +import edu.wpi.first.math.Nat; +import edu.wpi.first.math.VecBuilder; +import edu.wpi.first.math.Vector; +import edu.wpi.first.math.numbers.N2; +import org.junit.jupiter.api.Test; + +class S3SigmaPointsTest { + @Test + void testSimplex() { + double alpha = 1e-3; + double beta = 2; + Nat N = Nat.N2(); + + var sigmaPoints = new S3SigmaPoints<>(N, alpha, beta); + var points = sigmaPoints.squareRootSigmaPoints(new Vector<>(N), Matrix.eye(N)); + + var v1 = new Vector<>(points.extractColumnVector(1)); + var v2 = new Vector<>(points.extractColumnVector(2)); + var v3 = new Vector<>(points.extractColumnVector(3)); + + assertAll( + () -> assertEquals(alpha * Math.sqrt(N.getNum()), v1.norm(), 1e-15), + () -> assertEquals(alpha * Math.sqrt(N.getNum()), v2.norm(), 1e-15), + () -> assertEquals(alpha * Math.sqrt(N.getNum()), v3.norm(), 1e-15), + () -> assertEquals(v1.minus(v2).norm(), v1.minus(v3).norm(), 1e-15), + () -> assertEquals(v1.minus(v2).norm(), v2.minus(v3).norm(), 1e-15)); + } + + @Test + void testZeroMeanPoints() { + var sigmaPoints = new S3SigmaPoints<>(Nat.N2()); + var points = + sigmaPoints.squareRootSigmaPoints( + VecBuilder.fill(0, 0), MatBuilder.fill(Nat.N2(), Nat.N2(), 1, 0, 0, 1)); + + assertTrue( + points.isEqual( + MatBuilder.fill( + Nat.N2(), + Nat.N4(), + 0.0, + -0.00122474, + 0.00122474, + 0.0, + 0.0, + -0.00070711, + -0.00070711, + 0.00141421), + 1E-6)); + } + + @Test + void testNonzeroMeanPoints() { + var sigmaPoints = new S3SigmaPoints<>(Nat.N2()); + var points = + sigmaPoints.squareRootSigmaPoints( + VecBuilder.fill(1, 2), MatBuilder.fill(Nat.N2(), Nat.N2(), 1, 0, 0, Math.sqrt(10))); + + assertTrue( + points.isEqual( + MatBuilder.fill( + Nat.N2(), + Nat.N4(), + 1.0, + 0.99877526, + 1.00122474, + 1.0, + 2.0, + 1.99776393, + 1.99776393, + 2.00447214), + 1E-6)); + } +} diff --git a/wpimath/src/test/java/edu/wpi/first/math/estimator/S3UKFTest.java b/wpimath/src/test/java/edu/wpi/first/math/estimator/S3UKFTest.java new file mode 100644 index 0000000000..be9740820e --- /dev/null +++ b/wpimath/src/test/java/edu/wpi/first/math/estimator/S3UKFTest.java @@ -0,0 +1,393 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.math.estimator; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import edu.wpi.first.math.MatBuilder; +import edu.wpi.first.math.MathUtil; +import edu.wpi.first.math.Matrix; +import edu.wpi.first.math.Nat; +import edu.wpi.first.math.StateSpaceUtil; +import edu.wpi.first.math.VecBuilder; +import edu.wpi.first.math.geometry.Pose2d; +import edu.wpi.first.math.geometry.Rotation2d; +import edu.wpi.first.math.numbers.N1; +import edu.wpi.first.math.numbers.N2; +import edu.wpi.first.math.numbers.N3; +import edu.wpi.first.math.numbers.N4; +import edu.wpi.first.math.numbers.N5; +import edu.wpi.first.math.system.Discretization; +import edu.wpi.first.math.system.NumericalIntegration; +import edu.wpi.first.math.system.NumericalJacobian; +import edu.wpi.first.math.system.plant.DCMotor; +import edu.wpi.first.math.system.plant.LinearSystemId; +import edu.wpi.first.math.trajectory.TrajectoryConfig; +import edu.wpi.first.math.trajectory.TrajectoryGenerator; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; + +class S3UKFTest { + private static Matrix driveDynamics(Matrix x, Matrix u) { + var motors = DCMotor.getCIM(2); + + // var gLow = 15.32; // Low gear ratio + var gHigh = 7.08; // High gear ratio + var rb = 0.8382 / 2.0; // Robot radius + var r = 0.0746125; // Wheel radius + var m = 63.503; // Robot mass + var J = 5.6; // Robot moment of inertia + + var C1 = -Math.pow(gHigh, 2) * motors.Kt / (motors.Kv * motors.R * r * r); + var C2 = gHigh * motors.Kt / (motors.R * r); + var k1 = 1.0 / m + Math.pow(rb, 2) / J; + var k2 = 1.0 / m - Math.pow(rb, 2) / J; + + var vl = x.get(3, 0); + var vr = x.get(4, 0); + var Vl = u.get(0, 0); + var Vr = u.get(1, 0); + + var v = 0.5 * (vl + vr); + return VecBuilder.fill( + v * Math.cos(x.get(2, 0)), + v * Math.sin(x.get(2, 0)), + (vr - vl) / (2.0 * rb), + k1 * (C1 * vl + C2 * Vl) + k2 * (C1 * vr + C2 * Vr), + k2 * (C1 * vl + C2 * Vl) + k1 * (C1 * vr + C2 * Vr)); + } + + @SuppressWarnings("PMD.UnusedFormalParameter") + private static Matrix driveLocalMeasurementModel(Matrix x, Matrix u) { + return VecBuilder.fill(x.get(2, 0), x.get(3, 0), x.get(4, 0)); + } + + @SuppressWarnings("PMD.UnusedFormalParameter") + private static Matrix driveGlobalMeasurementModel(Matrix x, Matrix u) { + return x.copy(); + } + + @Test + void testDriveInit() { + var dt = 0.005; + assertDoesNotThrow( + () -> { + S3UKF observer = + new S3UKF<>( + Nat.N5(), + Nat.N3(), + S3UKFTest::driveDynamics, + S3UKFTest::driveLocalMeasurementModel, + VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0), + VecBuilder.fill(0.0001, 0.01, 0.01), + AngleStatistics.angleMean(2), + AngleStatistics.angleMean(0), + AngleStatistics.angleResidual(2), + AngleStatistics.angleResidual(0), + AngleStatistics.angleAdd(2), + dt); + + var u = VecBuilder.fill(12.0, 12.0); + observer.predict(u, dt); + + var localY = driveLocalMeasurementModel(observer.getXhat(), u); + observer.correct(u, localY); + + var globalY = driveGlobalMeasurementModel(observer.getXhat(), u); + var R = + StateSpaceUtil.makeCovarianceMatrix( + Nat.N5(), VecBuilder.fill(0.01, 0.01, 0.0001, 0.01, 0.01)); + observer.correct( + Nat.N5(), + u, + globalY, + S3UKFTest::driveGlobalMeasurementModel, + R, + AngleStatistics.angleMean(2), + AngleStatistics.angleResidual(2), + AngleStatistics.angleResidual(2), + AngleStatistics.angleAdd(2)); + }); + } + + @Test + void testDriveConvergence() { + final double dt = 0.005; + final double rb = 0.8382 / 2.0; // Robot radius + + S3UKF observer = + new S3UKF<>( + Nat.N5(), + Nat.N3(), + S3UKFTest::driveDynamics, + S3UKFTest::driveLocalMeasurementModel, + VecBuilder.fill(0.5, 0.5, 10.0, 1.0, 1.0), + VecBuilder.fill(0.0001, 0.5, 0.5), + AngleStatistics.angleMean(2), + AngleStatistics.angleMean(0), + AngleStatistics.angleResidual(2), + AngleStatistics.angleResidual(0), + AngleStatistics.angleAdd(2), + dt); + + List waypoints = + List.of( + new Pose2d(2.75, 22.521, Rotation2d.kZero), + new Pose2d(24.73, 19.68, Rotation2d.fromRadians(5.846))); + var trajectory = + TrajectoryGenerator.generateTrajectory(waypoints, new TrajectoryConfig(8.8, 0.1)); + + Matrix r = new Matrix<>(Nat.N5(), Nat.N1()); + Matrix u = new Matrix<>(Nat.N2(), Nat.N1()); + + var B = + NumericalJacobian.numericalJacobianU( + Nat.N5(), + Nat.N2(), + S3UKFTest::driveDynamics, + new Matrix<>(Nat.N5(), Nat.N1()), + new Matrix<>(Nat.N2(), Nat.N1())); + + observer.setXhat( + VecBuilder.fill( + trajectory.getInitialPose().getTranslation().getX(), + trajectory.getInitialPose().getTranslation().getY(), + trajectory.getInitialPose().getRotation().getRadians(), + 0.0, + 0.0)); + + var trueXhat = observer.getXhat(); + + double totalTime = trajectory.getTotalTime(); + for (int i = 0; i < (totalTime / dt); ++i) { + var ref = trajectory.sample(dt * i); + double vl = ref.velocity * (1 - (ref.curvature * rb)); + double vr = ref.velocity * (1 + (ref.curvature * rb)); + + var nextR = + VecBuilder.fill( + ref.pose.getTranslation().getX(), + ref.pose.getTranslation().getY(), + ref.pose.getRotation().getRadians(), + vl, + vr); + + Matrix localY = + driveLocalMeasurementModel(trueXhat, new Matrix<>(Nat.N2(), Nat.N1())); + var noiseStdDev = VecBuilder.fill(0.0001, 0.5, 0.5); + + observer.correct(u, localY.plus(StateSpaceUtil.makeWhiteNoiseVector(noiseStdDev))); + + var rdot = nextR.minus(r).div(dt); + u = new Matrix<>(B.solve(rdot.minus(driveDynamics(r, new Matrix<>(Nat.N2(), Nat.N1()))))); + + observer.predict(u, dt); + + r = nextR; + trueXhat = NumericalIntegration.rk4(S3UKFTest::driveDynamics, trueXhat, u, dt); + } + + var localY = driveLocalMeasurementModel(trueXhat, u); + observer.correct(u, localY); + + var globalY = driveGlobalMeasurementModel(trueXhat, u); + var R = + StateSpaceUtil.makeCovarianceMatrix( + Nat.N5(), VecBuilder.fill(0.01, 0.01, 0.0001, 0.5, 0.5)); + observer.correct( + Nat.N5(), + u, + globalY, + S3UKFTest::driveGlobalMeasurementModel, + R, + AngleStatistics.angleMean(2), + AngleStatistics.angleResidual(2), + AngleStatistics.angleResidual(2), + AngleStatistics.angleAdd(2)); + + final var finalPosition = trajectory.sample(trajectory.getTotalTime()); + + assertEquals(finalPosition.pose.getTranslation().getX(), observer.getXhat(0), 0.055); + assertEquals(finalPosition.pose.getTranslation().getY(), observer.getXhat(1), 0.15); + assertEquals(finalPosition.pose.getRotation().getRadians(), observer.getXhat(2), 0.00015); + assertEquals(0.0, observer.getXhat(3), 0.1); + assertEquals(0.0, observer.getXhat(4), 0.1); + } + + @Test + void testLinearUKF() { + var dt = 0.020; + var plant = LinearSystemId.identifyVelocitySystem(0.02, 0.006); + var observer = + new S3UKF<>( + Nat.N1(), + Nat.N1(), + (x, u) -> plant.getA().times(x).plus(plant.getB().times(u)), + plant::calculateY, + VecBuilder.fill(0.05), + VecBuilder.fill(1.0), + dt); + + var discABPair = Discretization.discretizeAB(plant.getA(), plant.getB(), dt); + var discA = discABPair.getFirst(); + var discB = discABPair.getSecond(); + + Matrix ref = VecBuilder.fill(100); + Matrix u = VecBuilder.fill(0); + + for (int i = 0; i < (2.0 / dt); ++i) { + observer.predict(u, dt); + + u = discB.solve(ref.minus(discA.times(ref))); + } + + assertEquals(ref.get(0, 0), observer.getXhat(0), 5); + } + + @Test + void testRoundTripP() { + var dt = 0.005; + + var observer = + new S3UKF<>( + Nat.N2(), + Nat.N2(), + (x, u) -> x, + (x, u) -> x, + VecBuilder.fill(0.0, 0.0), + VecBuilder.fill(0.0, 0.0), + dt); + + var P = MatBuilder.fill(Nat.N2(), Nat.N2(), 2.0, 1.0, 1.0, 2.0); + observer.setP(P); + + assertTrue(observer.getP().isEqual(P, 1e-9)); + } + + // Second system, single motor feedforward estimator + private static Matrix motorDynamics(Matrix x, Matrix u) { + double v = x.get(1, 0); + double kV = x.get(2, 0); + double kA = x.get(3, 0); + + double V = u.get(0, 0); + + double a = -kV / kA * v + 1.0 / kA * V; + return MatBuilder.fill(Nat.N4(), Nat.N1(), v, a, 0, 0); + } + + private static Matrix motorMeasurementModel(Matrix x, Matrix u) { + double p = x.get(0, 0); + double v = x.get(1, 0); + double kV = x.get(2, 0); + double kA = x.get(3, 0); + double V = u.get(0, 0); + + double a = -kV / kA * v + 1.0 / kA * V; + return MatBuilder.fill(Nat.N3(), Nat.N1(), p, v, a); + } + + private static Matrix motorControlInput(double t) { + return MatBuilder.fill( + Nat.N1(), + Nat.N1(), + MathUtil.clamp( + 8 * Math.sin(Math.PI * Math.sqrt(2.0) * t) + + 6 * Math.sin(Math.PI * Math.sqrt(3.0) * t) + + 4 * Math.sin(Math.PI * Math.sqrt(5.0) * t), + -12.0, + 12.0)); + } + + @Test + void testMotorConvergence() { + final double dt = 0.01; + final int steps = 500; + final double true_kV = 3; + final double true_kA = 0.2; + + final double pos_stddev = 0.02; + final double vel_stddev = 0.1; + final double accel_stddev = 0.1; + + var states = + new ArrayList<>( + Collections.nCopies( + steps + 1, MatBuilder.fill(Nat.N4(), Nat.N1(), 0.0, 0.0, 0.0, 0.0))); + var inputs = + new ArrayList<>(Collections.nCopies(steps, MatBuilder.fill(Nat.N1(), Nat.N1(), 0.0))); + var measurements = + new ArrayList<>( + Collections.nCopies(steps + 1, MatBuilder.fill(Nat.N3(), Nat.N1(), 0.0, 0.0, 0.0))); + states.set(0, MatBuilder.fill(Nat.N4(), Nat.N1(), 0.0, 0.0, true_kV, true_kA)); + + var A = + MatBuilder.fill( + Nat.N4(), + Nat.N4(), + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + -true_kV / true_kA, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0); + var B = MatBuilder.fill(Nat.N4(), Nat.N1(), 0.0, 1.0 / true_kA, 0.0, 0.0); + + var discABPair = Discretization.discretizeAB(A, B, dt); + var discA = discABPair.getFirst(); + var discB = discABPair.getSecond(); + + for (int i = 0; i < steps; ++i) { + inputs.set(i, motorControlInput(i * dt)); + states.set(i + 1, discA.times(states.get(i)).plus(discB.times(inputs.get(i)))); + measurements.set( + i, + motorMeasurementModel(states.get(i + 1), inputs.get(i)) + .plus( + StateSpaceUtil.makeWhiteNoiseVector( + VecBuilder.fill(pos_stddev, vel_stddev, accel_stddev)))); + } + + var P0 = + MatBuilder.fill( + Nat.N4(), Nat.N4(), 0.001, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 10, 0.0, 0.0, + 0.0, 0.0, 10); + + var observer = + new S3UKF( + Nat.N4(), + Nat.N3(), + S3UKFTest::motorDynamics, + S3UKFTest::motorMeasurementModel, + VecBuilder.fill(0.1, 1.0, 1e-10, 1e-10), + VecBuilder.fill(pos_stddev, vel_stddev, accel_stddev), + dt); + + observer.setXhat(MatBuilder.fill(Nat.N4(), Nat.N1(), 0.0, 0.0, 2.0, 2.0)); + observer.setP(P0); + + for (int i = 0; i < steps; ++i) { + observer.predict(inputs.get(i), dt); + observer.correct(inputs.get(i), measurements.get(i)); + } + + assertEquals(true_kV, observer.getXhat(2), true_kV * 0.5); + assertEquals(true_kA, observer.getXhat(3), true_kA * 0.5); + } +} diff --git a/wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp b/wpimath/src/test/native/cpp/estimator/MerweUKFTest.cpp similarity index 83% rename from wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp rename to wpimath/src/test/native/cpp/estimator/MerweUKFTest.cpp index 21d2df98f4..649b514ba6 100644 --- a/wpimath/src/test/native/cpp/estimator/UnscentedKalmanFilterTest.cpp +++ b/wpimath/src/test/native/cpp/estimator/MerweUKFTest.cpp @@ -13,7 +13,7 @@ #include "frc/EigenCore.h" #include "frc/StateSpaceUtil.h" #include "frc/estimator/AngleStatistics.h" -#include "frc/estimator/UnscentedKalmanFilter.h" +#include "frc/estimator/MerweUKF.h" #include "frc/system/Discretization.h" #include "frc/system/NumericalIntegration.h" #include "frc/system/NumericalJacobian.h" @@ -67,19 +67,19 @@ frc::Vectord<5> DriveGlobalMeasurementModel( return frc::Vectord<5>{x(0), x(1), x(2), x(3), x(4)}; } -TEST(UnscentedKalmanFilterTest, DriveInit) { +TEST(MerweUKFTest, DriveInit) { constexpr auto dt = 5_ms; - frc::UnscentedKalmanFilter<5, 2, 3> observer{DriveDynamics, - DriveLocalMeasurementModel, - {0.5, 0.5, 10.0, 1.0, 1.0}, - {0.0001, 0.01, 0.01}, - frc::AngleMean<5, 5>(2), - frc::AngleMean<3, 5>(0), - frc::AngleResidual<5>(2), - frc::AngleResidual<3>(0), - frc::AngleAdd<5>(2), - dt}; + frc::MerweUKF<5, 2, 3> observer{DriveDynamics, + DriveLocalMeasurementModel, + {0.5, 0.5, 10.0, 1.0, 1.0}, + {0.0001, 0.01, 0.01}, + frc::AngleMean<5, 5>(2), + frc::AngleMean<3, 5>(0), + frc::AngleResidual<5>(2), + frc::AngleResidual<3>(0), + frc::AngleAdd<5>(2), + dt}; frc::Vectord<2> u{12.0, 12.0}; observer.Predict(u, dt); @@ -93,20 +93,20 @@ TEST(UnscentedKalmanFilterTest, DriveInit) { frc::AngleResidual<5>(2), frc::AngleAdd<5>(2)); } -TEST(UnscentedKalmanFilterTest, DriveConvergence) { +TEST(MerweUKFTest, DriveConvergence) { constexpr auto dt = 5_ms; constexpr auto rb = 0.8382_m / 2.0; // Robot radius - frc::UnscentedKalmanFilter<5, 2, 3> observer{DriveDynamics, - DriveLocalMeasurementModel, - {0.5, 0.5, 10.0, 1.0, 1.0}, - {0.0001, 0.5, 0.5}, - frc::AngleMean<5, 5>(2), - frc::AngleMean<3, 5>(0), - frc::AngleResidual<5>(2), - frc::AngleResidual<3>(0), - frc::AngleAdd<5>(2), - dt}; + frc::MerweUKF<5, 2, 3> observer{DriveDynamics, + DriveLocalMeasurementModel, + {0.5, 0.5, 10.0, 1.0, 1.0}, + {0.0001, 0.5, 0.5}, + frc::AngleMean<5, 5>(2), + frc::AngleMean<3, 5>(0), + frc::AngleResidual<5>(2), + frc::AngleResidual<3>(0), + frc::AngleAdd<5>(2), + dt}; auto waypoints = std::vector{frc::Pose2d{2.75_m, 22.521_m, 0_rad}, @@ -174,11 +174,11 @@ TEST(UnscentedKalmanFilterTest, DriveConvergence) { EXPECT_NEAR(0.0, observer.Xhat(4), 0.1); } -TEST(UnscentedKalmanFilterTest, LinearUKF) { +TEST(MerweUKFTest, LinearUKF) { constexpr units::second_t dt = 20_ms; auto plant = frc::LinearSystemId::IdentifyVelocitySystem( 0.02_V / 1_mps, 0.006_V / 1_mps_sq); - frc::UnscentedKalmanFilter<1, 1, 1> observer{ + frc::MerweUKF<1, 1, 1> observer{ [&](const frc::Vectord<1>& x, const frc::Vectord<1>& u) { return plant.A() * x + plant.B() * u; }, @@ -205,10 +205,10 @@ TEST(UnscentedKalmanFilterTest, LinearUKF) { EXPECT_NEAR(ref(0, 0), observer.Xhat(0), 5); } -TEST(UnscentedKalmanFilterTest, RoundTripP) { +TEST(MerweUKFTest, RoundTripP) { constexpr auto dt = 5_ms; - frc::UnscentedKalmanFilter<2, 2, 2> observer{ + frc::MerweUKF<2, 2, 2> observer{ [](const frc::Vectord<2>& x, const frc::Vectord<2>& u) { return x; }, [](const frc::Vectord<2>& x, const frc::Vectord<2>& u) { return x; }, {0.0, 0.0}, @@ -255,7 +255,7 @@ frc::Vectord<1> MotorControlInput(double t) { -12.0, 12.0)}; } -TEST(UnscentedKalmanFilterTest, MotorConvergence) { +TEST(MerweUKFTest, MotorConvergence) { constexpr units::second_t dt = 10_ms; constexpr int steps = 500; constexpr double true_kV = 3; @@ -290,7 +290,7 @@ TEST(UnscentedKalmanFilterTest, MotorConvergence) { frc::Vectord<4> P0{0.001, 0.001, 10, 10}; - frc::UnscentedKalmanFilter<4, 1, 3> observer{ + frc::MerweUKF<4, 1, 3> observer{ MotorDynamics, MotorMeasurementModel, wpi::array{0.1, 1.0, 1e-10, 1e-10}, wpi::array{pos_stddev, vel_stddev, accel_stddev}, dt}; diff --git a/wpimath/src/test/native/cpp/estimator/S3SigmaPointsTest.cpp b/wpimath/src/test/native/cpp/estimator/S3SigmaPointsTest.cpp new file mode 100644 index 0000000000..00926d1ef4 --- /dev/null +++ b/wpimath/src/test/native/cpp/estimator/S3SigmaPointsTest.cpp @@ -0,0 +1,50 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include "frc/estimator/S3SigmaPoints.h" + +TEST(S3SigmaPointsTest, Simplex) { + constexpr double alpha = 1e-3; + constexpr double beta = 2; + constexpr size_t N = 2; + + frc::S3SigmaPoints sigmaPoints{alpha, beta}; + auto points = sigmaPoints.SquareRootSigmaPoints( + frc::Vectord::Zero(), frc::Matrixd::Identity()); + + auto v1 = points.template block<2, 1>(0, 1); + auto v2 = points.template block<2, 1>(0, 2); + auto v3 = points.template block<2, 1>(0, 3); + + EXPECT_DOUBLE_EQ(alpha * std::sqrt(N), v1.norm()); + EXPECT_DOUBLE_EQ(alpha * std::sqrt(N), v2.norm()); + EXPECT_DOUBLE_EQ(alpha * std::sqrt(N), v3.norm()); + EXPECT_DOUBLE_EQ((v1 - v2).norm(), (v1 - v3).norm()); + EXPECT_DOUBLE_EQ((v1 - v2).norm(), (v2 - v3).norm()); +} + +TEST(S3SigmaPointsTest, ZeroMean) { + frc::S3SigmaPoints<2> sigmaPoints; + auto points = sigmaPoints.SquareRootSigmaPoints( + frc::Vectord<2>{0.0, 0.0}, frc::Matrixd<2, 2>{{1.0, 0.0}, {0.0, 1.0}}); + + EXPECT_TRUE( + (points - frc::Matrixd<2, 4>{{0.0, -0.00122474, 0.00122474, 0.0}, + {0.0, -0.00070711, -0.00070711, 0.00141421}}) + .norm() < 1e-7); +} + +TEST(S3SigmaPointsTest, NonzeroMean) { + frc::S3SigmaPoints<2> sigmaPoints; + auto points = sigmaPoints.SquareRootSigmaPoints( + frc::Vectord<2>{1.0, 2.0}, + frc::Matrixd<2, 2>{{1.0, 0.0}, {0.0, std::sqrt(10.0)}}); + + EXPECT_TRUE( + (points - frc::Matrixd<2, 4>{{1.0, 0.99877526, 1.00122474, 1.0}, + {2.0, 1.99776393, 1.99776393, 2.00447214}}) + .norm() < 1e-7); +} diff --git a/wpimath/src/test/native/cpp/estimator/S3UKFTest.cpp b/wpimath/src/test/native/cpp/estimator/S3UKFTest.cpp new file mode 100644 index 0000000000..649b514ba6 --- /dev/null +++ b/wpimath/src/test/native/cpp/estimator/S3UKFTest.cpp @@ -0,0 +1,309 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include +#include +#include +#include + +#include +#include + +#include "frc/EigenCore.h" +#include "frc/StateSpaceUtil.h" +#include "frc/estimator/AngleStatistics.h" +#include "frc/estimator/MerweUKF.h" +#include "frc/system/Discretization.h" +#include "frc/system/NumericalIntegration.h" +#include "frc/system/NumericalJacobian.h" +#include "frc/system/plant/DCMotor.h" +#include "frc/system/plant/LinearSystemId.h" +#include "frc/trajectory/TrajectoryGenerator.h" +#include "units/moment_of_inertia.h" + +namespace { + +// First test system, differential drive +frc::Vectord<5> DriveDynamics(const frc::Vectord<5>& x, + const frc::Vectord<2>& u) { + auto motors = frc::DCMotor::CIM(2); + + // constexpr double Glow = 15.32; // Low gear ratio + constexpr double Ghigh = 7.08; // High gear ratio + constexpr auto rb = 0.8382_m / 2.0; // Robot radius + constexpr auto r = 0.0746125_m; // Wheel radius + constexpr auto m = 63.503_kg; // Robot mass + constexpr auto J = 5.6_kg_sq_m; // Robot moment of inertia + + auto C1 = -std::pow(Ghigh, 2) * motors.Kt / + (motors.Kv * motors.R * units::math::pow<2>(r)); + auto C2 = Ghigh * motors.Kt / (motors.R * r); + auto k1 = (1 / m + units::math::pow<2>(rb) / J); + auto k2 = (1 / m - units::math::pow<2>(rb) / J); + + units::meters_per_second_t vl{x(3)}; + units::meters_per_second_t vr{x(4)}; + units::volt_t Vl{u(0)}; + units::volt_t Vr{u(1)}; + + auto v = 0.5 * (vl + vr); + return frc::Vectord<5>{ + v.value() * std::cos(x(2)), v.value() * std::sin(x(2)), + ((vr - vl) / (2.0 * rb)).value(), + k1.value() * ((C1 * vl).value() + (C2 * Vl).value()) + + k2.value() * ((C1 * vr).value() + (C2 * Vr).value()), + k2.value() * ((C1 * vl).value() + (C2 * Vl).value()) + + k1.value() * ((C1 * vr).value() + (C2 * Vr).value())}; +} + +frc::Vectord<3> DriveLocalMeasurementModel( + const frc::Vectord<5>& x, [[maybe_unused]] const frc::Vectord<2>& u) { + return frc::Vectord<3>{x(2), x(3), x(4)}; +} + +frc::Vectord<5> DriveGlobalMeasurementModel( + const frc::Vectord<5>& x, [[maybe_unused]] const frc::Vectord<2>& u) { + return frc::Vectord<5>{x(0), x(1), x(2), x(3), x(4)}; +} + +TEST(MerweUKFTest, DriveInit) { + constexpr auto dt = 5_ms; + + frc::MerweUKF<5, 2, 3> observer{DriveDynamics, + DriveLocalMeasurementModel, + {0.5, 0.5, 10.0, 1.0, 1.0}, + {0.0001, 0.01, 0.01}, + frc::AngleMean<5, 5>(2), + frc::AngleMean<3, 5>(0), + frc::AngleResidual<5>(2), + frc::AngleResidual<3>(0), + frc::AngleAdd<5>(2), + dt}; + frc::Vectord<2> u{12.0, 12.0}; + observer.Predict(u, dt); + + auto localY = DriveLocalMeasurementModel(observer.Xhat(), u); + observer.Correct(u, localY); + + auto globalY = DriveGlobalMeasurementModel(observer.Xhat(), u); + auto R = frc::MakeCovMatrix(0.01, 0.01, 0.0001, 0.01, 0.01); + observer.Correct<5>(u, globalY, DriveGlobalMeasurementModel, R, + frc::AngleMean<5, 5>(2), frc::AngleResidual<5>(2), + frc::AngleResidual<5>(2), frc::AngleAdd<5>(2)); +} + +TEST(MerweUKFTest, DriveConvergence) { + constexpr auto dt = 5_ms; + constexpr auto rb = 0.8382_m / 2.0; // Robot radius + + frc::MerweUKF<5, 2, 3> observer{DriveDynamics, + DriveLocalMeasurementModel, + {0.5, 0.5, 10.0, 1.0, 1.0}, + {0.0001, 0.5, 0.5}, + frc::AngleMean<5, 5>(2), + frc::AngleMean<3, 5>(0), + frc::AngleResidual<5>(2), + frc::AngleResidual<3>(0), + frc::AngleAdd<5>(2), + dt}; + + auto waypoints = + std::vector{frc::Pose2d{2.75_m, 22.521_m, 0_rad}, + frc::Pose2d{24.73_m, 19.68_m, 5.846_rad}}; + auto trajectory = frc::TrajectoryGenerator::GenerateTrajectory( + waypoints, {8.8_mps, 0.1_mps_sq}); + + frc::Vectord<5> r = frc::Vectord<5>::Zero(); + frc::Vectord<2> u = frc::Vectord<2>::Zero(); + + auto B = frc::NumericalJacobianU<5, 5, 2>( + DriveDynamics, frc::Vectord<5>::Zero(), frc::Vectord<2>::Zero()); + + observer.SetXhat(frc::Vectord<5>{ + trajectory.InitialPose().Translation().X().value(), + trajectory.InitialPose().Translation().Y().value(), + trajectory.InitialPose().Rotation().Radians().value(), 0.0, 0.0}); + + auto trueXhat = observer.Xhat(); + + auto totalTime = trajectory.TotalTime(); + for (size_t i = 0; i < (totalTime / dt).value(); ++i) { + auto ref = trajectory.Sample(dt * i); + units::meters_per_second_t vl = + ref.velocity * (1 - (ref.curvature * rb).value()); + units::meters_per_second_t vr = + ref.velocity * (1 + (ref.curvature * rb).value()); + + frc::Vectord<5> nextR{ + ref.pose.Translation().X().value(), ref.pose.Translation().Y().value(), + ref.pose.Rotation().Radians().value(), vl.value(), vr.value()}; + + auto localY = DriveLocalMeasurementModel(trueXhat, frc::Vectord<2>::Zero()); + observer.Correct(u, localY + frc::MakeWhiteNoiseVector(0.0001, 0.5, 0.5)); + + frc::Vectord<5> rdot = (nextR - r) / dt.value(); + u = B.householderQr().solve(rdot - + DriveDynamics(r, frc::Vectord<2>::Zero())); + + observer.Predict(u, dt); + + r = nextR; + trueXhat = frc::RK4(DriveDynamics, trueXhat, u, dt); + } + + auto localY = DriveLocalMeasurementModel(trueXhat, u); + observer.Correct(u, localY); + + auto globalY = DriveGlobalMeasurementModel(trueXhat, u); + auto R = frc::MakeCovMatrix(0.01, 0.01, 0.0001, 0.5, 0.5); + observer.Correct<5>(u, globalY, DriveGlobalMeasurementModel, R, + frc::AngleMean<5, 5>(2), frc::AngleResidual<5>(2), + frc::AngleResidual<5>(2), frc::AngleAdd<5>(2) + + ); + + auto finalPosition = trajectory.Sample(trajectory.TotalTime()); + EXPECT_NEAR(finalPosition.pose.Translation().X().value(), observer.Xhat(0), + 0.055); + EXPECT_NEAR(finalPosition.pose.Translation().Y().value(), observer.Xhat(1), + 0.15); + EXPECT_NEAR(finalPosition.pose.Rotation().Radians().value(), observer.Xhat(2), + 0.000005); + EXPECT_NEAR(0.0, observer.Xhat(3), 0.1); + EXPECT_NEAR(0.0, observer.Xhat(4), 0.1); +} + +TEST(MerweUKFTest, LinearUKF) { + constexpr units::second_t dt = 20_ms; + auto plant = frc::LinearSystemId::IdentifyVelocitySystem( + 0.02_V / 1_mps, 0.006_V / 1_mps_sq); + frc::MerweUKF<1, 1, 1> observer{ + [&](const frc::Vectord<1>& x, const frc::Vectord<1>& u) { + return plant.A() * x + plant.B() * u; + }, + [&](const frc::Vectord<1>& x, const frc::Vectord<1>& u) { + return plant.CalculateY(x, u); + }, + {0.05}, + {1.0}, + dt}; + + frc::Matrixd<1, 1> discA; + frc::Matrixd<1, 1> discB; + frc::DiscretizeAB<1, 1>(plant.A(), plant.B(), dt, &discA, &discB); + + frc::Vectord<1> ref{100.0}; + frc::Vectord<1> u{0.0}; + + for (int i = 0; i < 2.0 / dt.value(); ++i) { + observer.Predict(u, dt); + + u = discB.householderQr().solve(ref - discA * ref); + } + + EXPECT_NEAR(ref(0, 0), observer.Xhat(0), 5); +} + +TEST(MerweUKFTest, RoundTripP) { + constexpr auto dt = 5_ms; + + frc::MerweUKF<2, 2, 2> observer{ + [](const frc::Vectord<2>& x, const frc::Vectord<2>& u) { return x; }, + [](const frc::Vectord<2>& x, const frc::Vectord<2>& u) { return x; }, + {0.0, 0.0}, + {0.0, 0.0}, + dt}; + + frc::Matrixd<2, 2> P({{2, 1}, {1, 2}}); + observer.SetP(P); + + ASSERT_TRUE(observer.P().isApprox(P)); +} + +// Second system, single motor feedforward estimator +frc::Vectord<4> MotorDynamics(const frc::Vectord<4>& x, + const frc::Vectord<1>& u) { + double v = x(1); + double kV = x(2); + double kA = x(3); + + double V = u(0); + + double a = -kV / kA * v + 1.0 / kA * V; + return frc::Vectord<4>{v, a, 0.0, 0.0}; +} + +frc::Vectord<3> MotorMeasurementModel(const frc::Vectord<4>& x, + const frc::Vectord<1>& u) { + double p = x(0); + double v = x(1); + double kV = x(2); + double kA = x(3); + + double V = u(0); + + double a = -kV / kA * v + 1.0 / kA * V; + return frc::Vectord<3>{p, v, a}; +} + +frc::Vectord<1> MotorControlInput(double t) { + return frc::Vectord<1>{ + std::clamp(8 * std::sin(std::numbers::pi * std::sqrt(2.0) * t) + + 6 * std::sin(std::numbers::pi * std::sqrt(3.0) * t) + + 4 * std::sin(std::numbers::pi * std::sqrt(5.0) * t), + -12.0, 12.0)}; +} + +TEST(MerweUKFTest, MotorConvergence) { + constexpr units::second_t dt = 10_ms; + constexpr int steps = 500; + constexpr double true_kV = 3; + constexpr double true_kA = 0.2; + + constexpr double pos_stddev = 0.02; + constexpr double vel_stddev = 0.1; + constexpr double accel_stddev = 0.1; + + std::vector> states(steps + 1); + std::vector> inputs(steps); + std::vector> measurements(steps); + states[0] = frc::Vectord<4>{{0.0}, {0.0}, {true_kV}, {true_kA}}; + + constexpr frc::Matrixd<4, 4> A{{0.0, 1.0, 0.0, 0.0}, + {0.0, -true_kV / true_kA, 0.0, 0.0}, + {0.0, 0.0, 0.0, 0.0}, + {0.0, 0.0, 0.0, 0.0}}; + constexpr frc::Matrixd<4, 1> B{{0.0}, {1.0 / true_kA}, {0.0}, {0.0}}; + + frc::Matrixd<4, 4> discA; + frc::Matrixd<4, 1> discB; + frc::DiscretizeAB(A, B, dt, &discA, &discB); + + for (int i = 0; i < steps; ++i) { + inputs[i] = MotorControlInput(i * dt.value()); + states[i + 1] = discA * states[i] + discB * inputs[i]; + measurements[i] = + MotorMeasurementModel(states[i + 1], inputs[i]) + + frc::MakeWhiteNoiseVector(pos_stddev, vel_stddev, accel_stddev); + } + + frc::Vectord<4> P0{0.001, 0.001, 10, 10}; + + frc::MerweUKF<4, 1, 3> observer{ + MotorDynamics, MotorMeasurementModel, wpi::array{0.1, 1.0, 1e-10, 1e-10}, + wpi::array{pos_stddev, vel_stddev, accel_stddev}, dt}; + + observer.SetXhat(frc::Vectord<4>{0.0, 0.0, 2.0, 2.0}); + observer.SetP(P0.asDiagonal()); + + for (int i = 0; i < steps; ++i) { + observer.Predict(inputs[i], dt); + observer.Correct(inputs[i], measurements[i]); + } + + EXPECT_NEAR(true_kV, observer.Xhat(2), true_kV * 0.5); + EXPECT_NEAR(true_kA, observer.Xhat(3), true_kA * 0.5); +} + +} // namespace