[wpimath] Implement Scaled Spherical Simplex Filter (S3F) (#8091)

Adds S3SigmaPoints based on MerweScaledSigmaPoints. In addition, restructures UnscentedKalmanFilter to support different sigma point generators and provides MerweUKF and S3UKF for convenience when working with either kind of filter.

S3UKFTest is copied from MerweUKFTest (which is a rename of UnscentedKalmanFilterTest). Curiously, however, in Java the original tolerance used in MerweUKFTest.testDriveConvergence() for the final rotation was too low for S3UKFTest, so the tolerance is increased from 0.000005 (5e-6) radians to 0.00015 (1.5e-4) radians. However, the C++ version still uses the original tolerance. (This difference is probably because Java uses a final rotation of 5.846 degrees while C++ uses a final rotation of 5.846 radians)

Closes #8072.

Breaking changes:

- (C++) UnscentedKalmanFilter has a new template parameter for the sigma point generator type.
- (Java) UnscentedKalmanFilter has an additional parameter to every constructor providing an instance of a sigma point generator.
- (C++) int MerweScaledSigmaPoints.NumSigmas() has been replaced with constexpr int MerweScaledSigmaPoints::NumSigmas.
- (C++) The second parameter of SquareRootUnscentedTransform has been changed from States to NumSigmas.
This commit is contained in:
Joseph Eng
2025-07-15 21:17:25 -07:00
committed by GitHub
parent f03df5388e
commit 1530fccbd0
22 changed files with 1694 additions and 136 deletions

View File

@@ -15,7 +15,8 @@ import org.ejml.simple.SimpleMatrix;
* UnscentedKalmanFilter class.
*
* <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in
* most publications. Unless you know better, this should be your default choice.
* most publications. S3SigmaPoints is generally preferred due to its greater performance with
* nearly identical accuracy.
*
* <p>States is the dimensionality of the state. 2*States+1 weights will be generated.
*
@@ -24,7 +25,7 @@ import org.ejml.simple.SimpleMatrix;
*
* @param <S> The dimensionality of the state. 2 * States + 1 weights will be generated.
*/
public class MerweScaledSigmaPoints<S extends Num> {
public class MerweScaledSigmaPoints<S extends Num> implements SigmaPoints<S> {
private final double m_alpha;
private final int m_kappa;
private final Nat<S> m_states;
@@ -64,6 +65,7 @@ public class MerweScaledSigmaPoints<S extends Num> {
*
* @return The number of sigma points for each variable in the state x.
*/
@Override
public int getNumSigmas() {
return 2 * m_states.getNum() + 1;
}
@@ -77,6 +79,7 @@ public class MerweScaledSigmaPoints<S extends Num> {
* @return Two-dimensional array of sigma points. Each column contains all the sigmas for one
* dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
*/
@Override
public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) {
double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
double eta = Math.sqrt(lambda + m_states.getNum());
@@ -125,6 +128,7 @@ public class MerweScaledSigmaPoints<S extends Num> {
*
* @return the weight for each sigma point for the mean.
*/
@Override
public Matrix<?, N1> getWm() {
return m_wm;
}
@@ -135,6 +139,7 @@ public class MerweScaledSigmaPoints<S extends Num> {
* @param element Element of vector to return.
* @return the element i's weight for the mean.
*/
@Override
public double getWm(int element) {
return m_wm.get(element, 0);
}
@@ -144,6 +149,7 @@ public class MerweScaledSigmaPoints<S extends Num> {
*
* @return the weight for each sigma point for the covariance.
*/
@Override
public Matrix<?, N1> getWc() {
return m_wc;
}
@@ -154,6 +160,7 @@ public class MerweScaledSigmaPoints<S extends Num> {
* @param element Element of vector to return.
* @return The element I's weight for the covariance.
*/
@Override
public double getWc(int element) {
return m_wc.get(element, 0);
}

View File

@@ -0,0 +1,117 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.Num;
import edu.wpi.first.math.numbers.N1;
import java.util.function.BiFunction;
/**
* An Unscented Kalman Filter using sigma points and weights from Van der Merwe's 2004 dissertation.
* S3UKF is generally preferred due to its greater performance with nearly identical accuracy.
*
* @param <States> Number of states.
* @param <Inputs> Number of inputs.
* @param <Outputs> Number of outputs.
*/
public class MerweUKF<States extends Num, Inputs extends Num, Outputs extends Num>
extends UnscentedKalmanFilter<States, Inputs, Outputs> {
/**
* Constructs a Merwe Unscented Kalman Filter.
*
* <p>See <a
* href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
* for how to select the standard deviations.
*
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param nominalDt Nominal discretization timestep in seconds.
*/
public MerweUKF(
Nat<States> states,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double nominalDt) {
super(
new MerweScaledSigmaPoints<>(states),
states,
outputs,
f,
h,
stateStdDevs,
measurementStdDevs,
(sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
(sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
Matrix::minus,
Matrix::minus,
Matrix::plus,
nominalDt);
}
/**
* Constructs a Merwe Unscented Kalman filter with custom mean, residual, and addition functions.
* Using custom functions for arithmetic can be useful if you have angles in the state or
* measurements, because they allow you to correctly account for the modular nature of angle
* arithmetic.
*
* <p>See <a
* href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
* for how to select the standard deviations.
*
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set
* of weights.
* @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a
* given set of weights.
* @param residualFuncX A function that computes the residual of two state vectors (i.e. it
* subtracts them.)
* @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
* subtracts them.)
* @param addFuncX A function that adds two state vectors.
* @param nominalDt Nominal discretization timestep in seconds.
*/
public MerweUKF(
Nat<States> states,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
double nominalDt) {
super(
new MerweScaledSigmaPoints<>(states),
states,
outputs,
f,
h,
stateStdDevs,
measurementStdDevs,
meanFuncX,
meanFuncY,
residualFuncX,
residualFuncY,
addFuncX,
nominalDt);
}
}

View File

@@ -0,0 +1,172 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.Num;
import edu.wpi.first.math.numbers.N1;
import org.ejml.simple.SimpleMatrix;
/**
* Generates sigma points and weights according to Papakonstantinou's paper[1] for the
* UnscentedKalmanFilter class.
*
* <p>It parameterizes the sigma points using alpha and beta terms. Unless you know better, this
* should be your default choice due to its high accuracy and performance.
*
* <p>[1] K. Papakonstantinou "A Scaled Spherical Simplex Filter (S3F) with a decreased n + 2 sigma
* points set size and equivalent 2n + 1 Unscented Kalman Filter (UKF) accuracy"
*
* @param <States> The dimenstionality of the state. States + 2 weights will be generated.
*/
public class S3SigmaPoints<States extends Num> implements SigmaPoints<States> {
private final Nat<States> m_states;
private final double m_alpha;
private Matrix<?, N1> m_wm;
private Matrix<?, N1> m_wc;
/**
* Constructs a generator for Papakonstantinou sigma points.
*
* @param states an instance of Num that represents the number of states.
* @param alpha Determines the spread of the sigma points around the mean. Usually a small
* positive value (1e-3).
* @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian
* distributions, beta = 2 is optimal.
*/
@SuppressWarnings("this-escape")
public S3SigmaPoints(Nat<States> states, double alpha, double beta) {
m_states = states;
m_alpha = alpha;
computeWeights(beta);
}
/**
* Constructs a generator for Papakonstantinou sigma points with default values for alpha, beta,
* and kappa.
*
* @param states an instance of Num that represents the number of states.
*/
public S3SigmaPoints(Nat<States> states) {
this(states, 1e-3, 2);
}
/**
* Returns number of sigma points for each variable in the state x.
*
* @return The number of sigma points for each variable in the state x.
*/
@Override
public int getNumSigmas() {
return m_states.getNum() + 2;
}
/**
* Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root
* covariance (s) of the filter.
*
* @param x An array of the means.
* @param s Square-root covariance of the filter.
* @return Two-dimensional array of sigma points. Each column contains all the sigmas for one
* dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
*/
@Override
public Matrix<States, ?> squareRootSigmaPoints(Matrix<States, N1> x, Matrix<States, States> s) {
// table (1), equation (12)
double[] q = new double[m_states.getNum()];
for (int t = 1; t <= m_states.getNum(); ++t) {
q[t - 1] = m_alpha * Math.sqrt(t * (m_states.getNum() + 1) / (double) (t + 1));
}
Matrix<States, ?> C = new Matrix<>(new SimpleMatrix(m_states.getNum(), getNumSigmas()));
for (int row = 0; row < m_states.getNum(); ++row) {
C.set(row, 0, 0.0);
}
for (int col = 1; col < getNumSigmas(); ++col) {
for (int row = 0; row < m_states.getNum(); ++row) {
if (row < col - 2) {
C.set(row, col, 0.0);
} else if (row == col - 2) {
C.set(row, col, q[row]);
} else {
C.set(row, col, -q[row] / (row + 1));
}
}
}
Matrix<States, ?> sigmas = new Matrix<>(new SimpleMatrix(m_states.getNum(), getNumSigmas()));
for (int col = 0; col < getNumSigmas(); ++col) {
sigmas.setColumn(col, x.plus(s.times(C.extractColumnVector(col))));
}
return sigmas;
}
/**
* Computes the weights for the scaled unscented Kalman filter.
*
* @param beta Incorporates prior knowledge of the distribution of the mean.
*/
private void computeWeights(double beta) {
double alpha_sq = m_alpha * m_alpha;
double c = 1.0 / (alpha_sq * (m_states.getNum() + 1));
Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(getNumSigmas(), 1));
Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(getNumSigmas(), 1));
wM.fill(c);
wC.fill(c);
wM.set(0, 0, 1.0 - 1.0 / alpha_sq);
wC.set(0, 0, 1.0 - 1.0 / alpha_sq + (1 - alpha_sq + beta));
this.m_wm = wM;
this.m_wc = wC;
}
/**
* Returns the weight for each sigma point for the mean.
*
* @return the weight for each sigma point for the mean.
*/
@Override
public Matrix<?, N1> getWm() {
return m_wm;
}
/**
* Returns an element of the weight for each sigma point for the mean.
*
* @param element Element of vector to return.
* @return the element i's weight for the mean.
*/
@Override
public double getWm(int element) {
return m_wm.get(element, 0);
}
/**
* Returns the weight for each sigma point for the covariance.
*
* @return the weight for each sigma point for the covariance.
*/
@Override
public Matrix<?, N1> getWc() {
return m_wc;
}
/**
* Returns an element of the weight for each sigma point for the covariance.
*
* @param element Element of vector to return.
* @return The element I's weight for the covariance.
*/
@Override
public double getWc(int element) {
return m_wc.get(element, 0);
}
}

View File

@@ -0,0 +1,117 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.Num;
import edu.wpi.first.math.numbers.N1;
import java.util.function.BiFunction;
/**
* An Unscented Kalman Filter using sigma points and weights from Papakonstantinou's paper. This is
* generally preferred over other UKF variants due to its high accuracy and performance.
*
* @param <States> Number of states.
* @param <Inputs> Number of inputs.
* @param <Outputs> Number of outputs.
*/
public class S3UKF<States extends Num, Inputs extends Num, Outputs extends Num>
extends UnscentedKalmanFilter<States, Inputs, Outputs> {
/**
* Constructs a Scaled Spherical Simplex (S3) Unscented Kalman Filter.
*
* <p>See <a
* href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
* for how to select the standard deviations.
*
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param nominalDt Nominal discretization timestep in seconds.
*/
public S3UKF(
Nat<States> states,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double nominalDt) {
super(
new S3SigmaPoints<>(states),
states,
outputs,
f,
h,
stateStdDevs,
measurementStdDevs,
(sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
(sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
Matrix::minus,
Matrix::minus,
Matrix::plus,
nominalDt);
}
/**
* Constructs a Scaled Spherical Simplex (S3) Unscented Kalman filter with custom mean, residual,
* and addition functions. Using custom functions for arithmetic can be useful if you have angles
* in the state or measurements, because they allow you to correctly account for the modular
* nature of angle arithmetic.
*
* <p>See <a
* href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
* for how to select the standard deviations.
*
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set
* of weights.
* @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a
* given set of weights.
* @param residualFuncX A function that computes the residual of two state vectors (i.e. it
* subtracts them.)
* @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
* subtracts them.)
* @param addFuncX A function that adds two state vectors.
* @param nominalDt Nominal discretization timestep in seconds.
*/
public S3UKF(
Nat<States> states,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
double nominalDt) {
super(
new S3SigmaPoints<>(states),
states,
outputs,
f,
h,
stateStdDevs,
measurementStdDevs,
meanFuncX,
meanFuncY,
residualFuncX,
residualFuncY,
addFuncX,
nominalDt);
}
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.estimator;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Num;
import edu.wpi.first.math.numbers.N1;
/**
* A sigma points generator for the UnscentedKalmanFilter class.
*
* @param <States> The dimensionality of the state.
*/
public interface SigmaPoints<States extends Num> {
/**
* Returns number of sigma points for each variable in the state x.
*
* @return The number of sigma points for each variable in the state x.
*/
int getNumSigmas();
/**
* Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root
* covariance (s) of the filter.
*
* @param x An array of the means.
* @param s Square-root covariance of the filter.
* @return Two-dimensional array of sigma points. Each column contains all the sigmas for one
* dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
*/
Matrix<States, ?> squareRootSigmaPoints(Matrix<States, N1> x, Matrix<States, States> s);
/**
* Returns the weight for each sigma point for the mean.
*
* @return the weight for each sigma point for the mean.
*/
Matrix<?, N1> getWm();
/**
* Returns an element of the weight for each sigma point for the mean.
*
* @param element Element of vector to return.
* @return the element i's weight for the mean.
*/
double getWm(int element);
/**
* Returns the weight for each sigma point for the covariance.
*
* @return the weight for each sigma point for the covariance.
*/
Matrix<?, N1> getWc();
/**
* Returns an element of the weight for each sigma point for the covariance.
*
* @param element Element of vector to return.
* @return The element I's weight for the covariance.
*/
double getWc(int element);
}

View File

@@ -22,6 +22,11 @@ import org.ejml.simple.SimpleMatrix;
* true system state. This is useful because many states cannot be measured directly as a result of
* sensor noise, or because the state is "hidden".
*
* <p>This class's constructors require a SigmaPoints generator. For convenience, {@link S3UKF} and
* {@link MerweUKF} subclasses are provided to create a suitable generator for you. S3UKF is
* generally preferred over MerweUKF because of its greater performance while maintaining nearly
* identical accuracy.
*
* <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
* more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
* of squares error in the state estimate. This K gain is used to correct the state estimate by some
@@ -64,7 +69,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
private Matrix<States, ?> m_sigmasF;
private double m_dt;
private final MerweScaledSigmaPoints<States> m_pts;
private final SigmaPoints<States> m_pts;
/**
* Constructs an Unscented Kalman Filter.
@@ -73,6 +78,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
* href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
* for how to select the standard deviations.
*
* @param pts A sigma points and weights generator.
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
@@ -82,6 +88,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
* @param nominalDt Nominal discretization timestep in seconds.
*/
public UnscentedKalmanFilter(
SigmaPoints<States> pts,
Nat<States> states,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
@@ -90,6 +97,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
Matrix<Outputs, N1> measurementStdDevs,
double nominalDt) {
this(
pts,
states,
outputs,
f,
@@ -113,16 +121,17 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
* href="https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices">https://docs.wpilib.org/en/stable/docs/software/advanced-controls/state-space/state-space-observers.html#process-and-measurement-noise-covariance-matrices</a>
* for how to select the standard deviations.
*
* @param pts A sigma points and weights generator.
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns the derivative of the state vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors using a
* @param meanFuncX A function that computes the mean of NumSigmas state vectors using a given set
* of weights.
* @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a
* given set of weights.
* @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
* a given set of weights.
* @param residualFuncX A function that computes the residual of two state vectors (i.e. it
* subtracts them.)
* @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
@@ -131,6 +140,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
* @param nominalDt Nominal discretization timestep in seconds.
*/
public UnscentedKalmanFilter(
SigmaPoints<States> pts,
Nat<States> states,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
@@ -160,37 +170,36 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
m_pts = new MerweScaledSigmaPoints<>(states);
m_pts = pts;
reset();
}
static <S extends Num, C extends Num>
Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
Nat<S> s,
Nat<C> dim,
Matrix<C, ?> sigmas,
Matrix<?, N1> Wm,
Matrix<?, N1> Wc,
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc,
Matrix<C, C> squareRootR) {
if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
static <C extends Num> Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
Nat<C> covdim,
int numSigmas,
Matrix<C, ?> sigmas,
Matrix<?, N1> Wm,
Matrix<?, N1> Wc,
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc,
Matrix<C, C> squareRootR) {
if (sigmas.getNumRows() != covdim.getNum() || sigmas.getNumCols() != numSigmas) {
throw new IllegalArgumentException(
"Sigmas must be covDim by 2 * states + 1! Got "
"Sigmas must be covDim by numSigmas! Got "
+ sigmas.getNumRows()
+ " by "
+ sigmas.getNumCols());
}
if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1) {
if (Wm.getNumRows() != numSigmas || Wm.getNumCols() != 1) {
throw new IllegalArgumentException(
"Wm must be 2 * states + 1 by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols());
"Wm must be numSigmas by 1! Got " + Wm.getNumRows() + " by " + Wm.getNumCols());
}
if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
if (Wc.getNumRows() != numSigmas || Wc.getNumCols() != 1) {
throw new IllegalArgumentException(
"Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
"Wc must be numSigmas by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
}
// New mean is usually just the sum of the sigmas * weights:
@@ -208,13 +217,18 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
// [√{W₁⁽ᶜ⁾}(𝒳_{1:2L} - x̂) √{Rᵛ}]
//
// the part of equations (20) and (24) within the "qr{}"
Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum()));
for (int i = 0; i < 2 * s.getNum(); i++) {
//
// Note that we allow a custom function instead of the difference to allow
// angle wrapping. Furthermore, we allow an arbitrary number of sigma points to
// support similar methods such as the Scaled Spherical Simplex Filter (S3F).
Matrix<C, ?> Sbar =
new Matrix<>(new SimpleMatrix(covdim.getNum(), numSigmas - 1 + covdim.getNum()));
for (int i = 0; i < numSigmas - 1; i++) {
Sbar.setColumn(
i,
residualFunc.apply(sigmas.extractColumnVector(1 + i), x).times(Math.sqrt(Wc.get(1, 0))));
}
Sbar.assignBlock(0, 2 * s.getNum(), squareRootR);
Sbar.assignBlock(0, numSigmas - 1, squareRootR);
QRDecompositionHouseholder_DDRM qr = new QRDecompositionHouseholder_DDRM();
var qrStorage = Sbar.transpose().getStorage();
@@ -355,7 +369,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
public final void reset() {
m_xHat = new Matrix<>(m_states, Nat.N1());
m_S = new Matrix<>(m_states, m_states);
m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), m_pts.getNumSigmas()));
}
/**
@@ -397,7 +411,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
var ret =
squareRootUnscentedTransform(
m_states,
m_states,
m_pts.getNumSigmas(),
m_sigmasF,
m_pts.getWm(),
m_pts.getWc(),
@@ -477,8 +491,8 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
* @param y Measurement vector.
* @param h A vector-valued function of x and u that returns the measurement vector.
* @param R Continuous measurement noise covariance matrix.
* @param meanFuncY A function that computes the mean of 2 * States + 1 measurement vectors using
* a given set of weights.
* @param meanFuncY A function that computes the mean of NumSigmas measurement vectors using a
* given set of weights.
* @param residualFuncY A function that computes the residual of two measurement vectors (i.e. it
* subtracts them.)
* @param residualFuncX A function that computes the residual of two state vectors (i.e. it
@@ -507,7 +521,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
// This differs from equation (22) which uses
// the prior sigma points, regenerating them allows
// multiple measurement updates per time update
Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), m_pts.getNumSigmas()));
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
@@ -521,8 +535,8 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num, Outpu
// equations (23) (24) and (25)
var transRet =
squareRootUnscentedTransform(
m_states,
rows,
m_pts.getNumSigmas(),
sigmasH,
m_pts.getWm(),
m_pts.getWc(),