[wpilib] Add pose estimators (#2867)

Pose and state estimators can filter latency-compensated global measurements and fuse them with state-space drivetrain model information to estimate robot position. They are drop-in replacements for the existing odometry classes.

Co-authored-by: Declan Freeman-Gleason <declanfreemangleason@gmail.com>
Co-authored-by: Prateek Machiraju <prateek.machiraju@gmail.com>
Co-authored-by: Claudius Tewari <cttewari@gmail.com>
Co-authored-by: Matt <matthew.morley.ca@gmail.com>
This commit is contained in:
Declan Freeman-Gleason
2020-11-28 17:35:35 -05:00
committed by GitHub
parent 3413bfc06a
commit bc8f338771
58 changed files with 4958 additions and 39 deletions

View File

@@ -0,0 +1,130 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.function.BiFunction;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
public final class AngleStatistics {
private AngleStatistics() {
// Utility class
}
/**
* Subtracts a and b while normalizing the resulting value in the selected row as if it were an
* angle.
*
* @param a A vector to subtract from.
* @param b A vector to subtract with.
* @param angleStateIdx The row containing angles to be normalized.
*/
@SuppressWarnings("checkstyle:ParameterName")
public static <S extends Num> Matrix<S, N1> angleResidual(Matrix<S, N1> a, Matrix<S, N1> b,
int angleStateIdx) {
Matrix<S, N1> ret = a.minus(b);
ret.set(angleStateIdx, 0, normalizeAngle(ret.get(angleStateIdx, 0)));
return ret;
}
/**
* Returns a function that subtracts two vectors while normalizing the resulting value in the
* selected row as if it were an angle.
*
* @param angleStateIdx The row containing angles to be normalized.
*/
@SuppressWarnings("checkstyle:ParameterName")
public static <S extends Num> BiFunction<Matrix<S, N1>, Matrix<S, N1>, Matrix<S, N1>>
angleResidual(int angleStateIdx) {
return (a, b) -> angleResidual(a, b, angleStateIdx);
}
/**
* Adds a and b while normalizing the resulting value in the selected row as an angle.
*
* @param a A vector to add with.
* @param b A vector to add with.
* @param angleStateIdx The row containing angles to be normalized.
*/
@SuppressWarnings("checkstyle:ParameterName")
public static <S extends Num> Matrix<S, N1> angleAdd(Matrix<S, N1> a, Matrix<S, N1> b,
int angleStateIdx) {
Matrix<S, N1> ret = a.plus(b);
ret.set(angleStateIdx, 0, normalizeAngle(ret.get(angleStateIdx, 0)));
return ret;
}
/**
* Returns a function that adds two vectors while normalizing the resulting value in the selected
* row as an angle.
*
* @param angleStateIdx The row containing angles to be normalized.
*/
@SuppressWarnings("checkstyle:ParameterName")
public static <S extends Num> BiFunction<Matrix<S, N1>, Matrix<S, N1>, Matrix<S, N1>>
angleAdd(int angleStateIdx) {
return (a, b) -> angleAdd(a, b, angleStateIdx);
}
static double normalizeAngle(double angle) {
final double tau = 2 * Math.PI;
angle -= Math.floor(angle / tau) * tau;
if (angle > Math.PI) {
angle -= tau;
}
return angle;
}
/**
* Computes the mean of sigmas with the weights Wm while computing a special angle mean for a
* select row.
*
* @param sigmas Sigma points.
* @param Wm Weights for the mean.
* @param angleStateIdx The row containing the angles.
*/
@SuppressWarnings("checkstyle:ParameterName")
public static <S extends Num> Matrix<S, N1> angleMean(Matrix<S, ?> sigmas, Matrix<?, N1> Wm,
int angleStateIdx) {
double[] angleSigmas = sigmas.extractRowVector(angleStateIdx).getData();
Matrix<N1, ?> sinAngleSigmas = new Matrix<>(new SimpleMatrix(1, sigmas.getNumCols()));
Matrix<N1, ?> cosAngleSigmas = new Matrix<>(new SimpleMatrix(1, sigmas.getNumCols()));
for (int i = 0; i < angleSigmas.length; i++) {
sinAngleSigmas.set(0, i, Math.sin(angleSigmas[i]));
cosAngleSigmas.set(0, i, Math.cos(angleSigmas[i]));
}
double sumSin = sinAngleSigmas.times(Matrix.changeBoundsUnchecked(Wm)).elementSum();
double sumCos = cosAngleSigmas.times(Matrix.changeBoundsUnchecked(Wm)).elementSum();
Matrix<S, N1> ret = sigmas.times(Matrix.changeBoundsUnchecked(Wm));
ret.set(angleStateIdx, 0, Math.atan2(sumSin, sumCos));
return ret;
}
/**
* Returns a function that computes the mean of sigmas with the weights Wm while computing a
* special angle mean for a select row.
*
* @param angleStateIdx The row containing the angles.
*/
@SuppressWarnings("checkstyle:ParameterName")
public static <S extends Num> BiFunction<Matrix<S, ?>, Matrix<?, N1>, Matrix<S, N1>>
angleMean(int angleStateIdx) {
return (sigmas, Wm) -> angleMean(sigmas, Wm, angleStateIdx);
}
}

View File

@@ -0,0 +1,297 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.function.BiConsumer;
import edu.wpi.first.wpilibj.geometry.Pose2d;
import edu.wpi.first.wpilibj.geometry.Rotation2d;
import edu.wpi.first.wpilibj.kinematics.DifferentialDriveWheelSpeeds;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpiutil.WPIUtilJNI;
import edu.wpi.first.wpiutil.math.MatBuilder;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.VecBuilder;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N3;
import edu.wpi.first.wpiutil.math.numbers.N5;
/**
* This class wraps an
* {@link edu.wpi.first.wpilibj.estimator.UnscentedKalmanFilter Unscented Kalman Filter}
* to fuse latency-compensated vision
* measurements with differential drive encoder measurements. It will correct
* for noisy vision measurements and encoder drift. It is intended to be an easy
* drop-in for
* {@link edu.wpi.first.wpilibj.kinematics.DifferentialDriveOdometry}; in fact,
* if you never call {@link DifferentialDrivePoseEstimator#addVisionMeasurement}
* and only call {@link DifferentialDrivePoseEstimator#update} then this will
* behave exactly the same as DifferentialDriveOdometry.
*
* <p>{@link DifferentialDrivePoseEstimator#update} should be called every robot
* loop (if your robot loops are faster than the default then you should change
* the {@link DifferentialDrivePoseEstimator#DifferentialDrivePoseEstimator(Rotation2d, Pose2d,
* Matrix, Matrix, Matrix, double) nominal delta time}.)
* {@link DifferentialDrivePoseEstimator#addVisionMeasurement} can be called as
* infrequently as you want; if you never call it then this class will behave
* exactly like regular encoder odometry.
*
* <p>Our state-space system is:
*
* <p><strong> x = [[x, y, theta, dist_l, dist_r]]^T </strong>
* in the field coordinate system (dist_* are wheel distances.)
*
* <p><strong> u = [[vx, vy, omega]]^T </strong> (robot-relative velocities)
* -- NB: using velocities make things considerably easier, because it means that
* teams don't have to worry about getting an accurate model.
* Basically, we suspect that it's easier for teams to get good encoder data than it is for
* them to perform system identification well enough to get a good model.
*
* <p><strong>y = [[x, y, theta]]^T </strong> from vision,
* or <strong>y = [[dist_l, dist_r, theta]] </strong> from encoders and gyro.
*/
public class DifferentialDrivePoseEstimator {
final UnscentedKalmanFilter<N5, N3, N3> m_observer; // Package-private to allow for unit testing
private final BiConsumer<Matrix<N3, N1>, Matrix<N3, N1>> m_visionCorrect;
private final KalmanFilterLatencyCompensator<N5, N3, N3> m_latencyCompensator;
private final double m_nominalDt; // Seconds
private double m_prevTimeSeconds = -1.0;
private Rotation2d m_gyroOffset;
private Rotation2d m_previousAngle;
/**
* Constructs a DifferentialDrivePoseEstimator.
*
* @param gyroAngle The current gyro angle.
* @param initialPoseMeters The starting pose estimate.
* @param stateStdDevs Standard deviations of model states. Increase these numbers to
* trust your wheel and gyro velocities less.
* @param localMeasurementStdDevs Standard deviations of the encoder and gyro measurements.
* Increase these numbers to trust encoder distances and gyro
* angle less.
* @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase
* these numbers to trust vision less.
*/
public DifferentialDrivePoseEstimator(
Rotation2d gyroAngle, Pose2d initialPoseMeters,
Matrix<N5, N1> stateStdDevs,
Matrix<N3, N1> localMeasurementStdDevs, Matrix<N3, N1> visionMeasurementStdDevs
) {
this(gyroAngle, initialPoseMeters,
stateStdDevs, localMeasurementStdDevs, visionMeasurementStdDevs, 0.02);
}
/**
* Constructs a DifferentialDrivePoseEstimator.
*
* @param gyroAngle The current gyro angle.
* @param initialPoseMeters The starting pose estimate.
* @param stateStdDevs Standard deviations of model states. Increase these numbers to
* trust your wheel and gyro velocities less.
* @param localMeasurementStdDevs Standard deviations of the encoder and gyro measurements.
* Increase these numbers to trust encoder distances and gyro
* angle less.
* @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase
* these numbers to trust vision less.
* @param nominalDtSeconds The time in seconds between each robot loop.
*/
@SuppressWarnings("ParameterName")
public DifferentialDrivePoseEstimator(
Rotation2d gyroAngle, Pose2d initialPoseMeters,
Matrix<N5, N1> stateStdDevs,
Matrix<N3, N1> localMeasurementStdDevs, Matrix<N3, N1> visionMeasurementStdDevs,
double nominalDtSeconds
) {
m_nominalDt = nominalDtSeconds;
m_observer = new UnscentedKalmanFilter<>(
Nat.N5(), Nat.N3(),
this::f,
(x, u) -> VecBuilder.fill(x.get(3, 0), x.get(4, 0), x.get(2, 0)),
stateStdDevs, localMeasurementStdDevs,
AngleStatistics.angleMean(2),
AngleStatistics.angleMean(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleAdd(2),
m_nominalDt
);
m_latencyCompensator = new KalmanFilterLatencyCompensator<>();
var visionContR = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), visionMeasurementStdDevs);
var visionDiscR = Discretization.discretizeR(visionContR, m_nominalDt);
m_visionCorrect = (u, y) -> m_observer.correct(
Nat.N3(), u, y,
(x, u_) -> new Matrix<>(x.getStorage().extractMatrix(0, 3, 0, 1)),
visionDiscR,
AngleStatistics.angleMean(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleAdd(2)
);
m_gyroOffset = initialPoseMeters.getRotation().minus(gyroAngle);
m_previousAngle = initialPoseMeters.getRotation();
m_observer.setXhat(fillStateVector(initialPoseMeters, 0.0, 0.0));
}
@SuppressWarnings({"ParameterName", "MethodName"})
private Matrix<N5, N1> f(Matrix<N5, N1> x, Matrix<N3, N1> u) {
// Apply a rotation matrix. Note that we do *not* add x--Runge-Kutta does that for us.
var theta = x.get(2, 0);
var toFieldRotation = new MatBuilder<>(Nat.N5(), Nat.N5()).fill(
Math.cos(theta), -Math.sin(theta), 0, 0, 0,
Math.sin(theta), Math.cos(theta), 0, 0, 0,
0, 0, 1, 0, 0,
0, 0, 0, 1, 0,
0, 0, 0, 0, 1
);
return toFieldRotation.times(VecBuilder.fill(
u.get(0, 0), u.get(1, 0), u.get(2, 0), u.get(0, 0), u.get(1, 0)
));
}
/**
* Resets the robot's position on the field.
*
* <p>You NEED to reset your encoders (to zero) when calling this method.
*
* <p>The gyroscope angle does not need to be reset here on the user's robot code.
* The library automatically takes care of offsetting the gyro angle.
*
* @param poseMeters The position on the field that your robot is at.
* @param gyroAngle The angle reported by the gyroscope.
*/
public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
m_previousAngle = poseMeters.getRotation();
m_gyroOffset = getEstimatedPosition().getRotation().minus(gyroAngle);
m_observer.setXhat(fillStateVector(poseMeters, 0.0, 0.0));
}
/**
* Gets the pose of the robot at the current time as estimated by the Unscented Kalman Filter.
*
* @return The estimated robot pose in meters.
*/
public Pose2d getEstimatedPosition() {
return new Pose2d(
m_observer.getXhat(0),
m_observer.getXhat(1),
new Rotation2d(m_observer.getXhat(2))
);
}
/**
* Add a vision measurement to the Unscented Kalman Filter. This will correct the
* odometry pose estimate while still accounting for measurement noise.
*
* <p>This method can be called as infrequently as you want, as long as you are
* calling {@link DifferentialDrivePoseEstimator#update} every loop.
*
* @param visionRobotPoseMeters The pose of the robot as measured by the vision
* camera.
* @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if
* you don't use your own time source by calling
* {@link DifferentialDrivePoseEstimator#updateWithTime} then you
* must use a timestamp with an epoch since FPGA startup
* (i.e. the epoch of this timestamp is the same epoch as
* Timer.getFPGATimestamp.) This means that you should
* use Timer.getFPGATimestamp as your time source in
* this case.
*/
public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
m_latencyCompensator.applyPastGlobalMeasurement(
Nat.N3(),
m_observer, m_nominalDt,
StateSpaceUtil.poseTo3dVector(visionRobotPoseMeters),
m_visionCorrect,
timestampSeconds
);
}
/**
* Updates the the Unscented Kalman Filter using only wheel encoder information.
* Note that this should be called every loop.
*
* @param gyroAngle The current gyro angle.
* @param wheelVelocitiesMetersPerSecond The velocities of the wheels in meters per second.
* @param distanceLeftMeters The total distance travelled by the left wheel in meters
* since the last time you called
* {@link DifferentialDrivePoseEstimator#resetPosition}.
* @param distanceRightMeters The total distance travelled by the right wheel in meters
* since the last time you called
* {@link DifferentialDrivePoseEstimator#resetPosition}.
* @return The estimated pose of the robot in meters.
*/
public Pose2d update(
Rotation2d gyroAngle,
DifferentialDriveWheelSpeeds wheelVelocitiesMetersPerSecond,
double distanceLeftMeters, double distanceRightMeters
) {
return updateWithTime(
WPIUtilJNI.now() * 1.0e-6, gyroAngle, wheelVelocitiesMetersPerSecond,
distanceLeftMeters, distanceRightMeters
);
}
/**
* Updates the the Unscented Kalman Filter using only wheel encoder information.
* Note that this should be called every loop.
*
* @param currentTimeSeconds Time at which this method was called, in seconds.
* @param gyroAngle The current gyro angle.
* @param wheelVelocitiesMetersPerSecond The velocities of the wheels in meters per second.
* @param distanceLeftMeters The total distance travelled by the left wheel in meters
* since the last time you called
* {@link DifferentialDrivePoseEstimator#resetPosition}.
* @param distanceRightMeters The total distance travelled by the right wheel in meters
* since the last time you called
* {@link DifferentialDrivePoseEstimator#resetPosition}.
* @return The estimated pose of the robot in meters.
*/
@SuppressWarnings({"LocalVariableName", "ParameterName"})
public Pose2d updateWithTime(
double currentTimeSeconds, Rotation2d gyroAngle,
DifferentialDriveWheelSpeeds wheelVelocitiesMetersPerSecond,
double distanceLeftMeters, double distanceRightMeters
) {
double dt = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : m_nominalDt;
m_prevTimeSeconds = currentTimeSeconds;
var angle = gyroAngle.plus(m_gyroOffset);
// Diff drive forward kinematics:
// v_c = (v_l + v_r) / 2
var wheelVels = wheelVelocitiesMetersPerSecond;
var u = VecBuilder.fill(
(wheelVels.leftMetersPerSecond + wheelVels.rightMetersPerSecond) / 2, 0,
angle.minus(m_previousAngle).getRadians() / dt
);
m_previousAngle = angle;
var localY = VecBuilder.fill(distanceLeftMeters, distanceRightMeters, angle.getRadians());
m_latencyCompensator.addObserverState(m_observer, u, localY, currentTimeSeconds);
m_observer.predict(u, dt);
m_observer.correct(u, localY);
return getEstimatedPosition();
}
private static Matrix<N5, N1> fillStateVector(Pose2d pose, double leftDist, double rightDist) {
return VecBuilder.fill(
pose.getTranslation().getX(),
pose.getTranslation().getY(),
pose.getRotation().getRadians(),
leftDist,
rightDist
);
}
}

View File

@@ -0,0 +1,161 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O extends Num> {
private static final int kMaxPastObserverStates = 300;
private final List<Map.Entry<Double, ObserverSnapshot>> m_pastObserverSnapshots;
KalmanFilterLatencyCompensator() {
m_pastObserverSnapshots = new ArrayList<>();
}
/**
* Add past observer states to the observer snapshots list.
*
* @param observer The observer.
* @param u The input at the timestamp.
* @param localY The local output at the timestamp
* @param timestampSeconds The timesnap of the state.
*/
@SuppressWarnings("ParameterName")
public void addObserverState(
KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY,
double timestampSeconds
) {
m_pastObserverSnapshots.add(Map.entry(
timestampSeconds, new ObserverSnapshot(observer, u, localY)
));
if (m_pastObserverSnapshots.size() > kMaxPastObserverStates) {
m_pastObserverSnapshots.remove(0);
}
}
/**
* Add past global measurements (such as from vision)to the estimator.
*
* @param <R> The rows in the global measurement vector.
* @param rows The rows in the global measurement vector.
* @param observer The observer to apply the past global measurement.
* @param nominalDtSeconds The nominal timestep.
* @param globalMeasurement The measurement.
* @param globalMeasurementCorrect The function take calls correct() on the observer.
* @param globalMeasurementTimestampSeconds The timestamp of the measurement.
*/
@SuppressWarnings({"ParameterName", "PMD.AvoidInstantiatingObjectsInLoops"})
public <R extends Num> void applyPastGlobalMeasurement(
Nat<R> rows,
KalmanTypeFilter<S, I, O> observer,
double nominalDtSeconds,
Matrix<R, N1> globalMeasurement,
BiConsumer<Matrix<I, N1>, Matrix<R, N1>> globalMeasurementCorrect,
double globalMeasurementTimestampSeconds
) {
if (m_pastObserverSnapshots.isEmpty()) {
// State map was empty, which means that we got a past measurement right at startup. The only
// thing we can really do is ignore the measurement.
return;
}
// This index starts at one because we use the previous state later on, and we always want to
// have a "previous state".
int maxIdx = m_pastObserverSnapshots.size() - 1;
int low = 1;
int high = Math.max(maxIdx, 1);
while (low != high) {
int mid = (low + high) / 2;
if (m_pastObserverSnapshots.get(mid).getKey() < globalMeasurementTimestampSeconds) {
// This index and everything under it are less than the requested timestamp. Therefore, we
// can discard them.
low = mid + 1;
} else {
// t is at least as large as the element at this index. This means that anything after it
// cannot be what we are looking for.
high = mid;
}
}
// We are simply assigning this index to a new variable to avoid confusion
// with variable names.
int index = low;
double timestamp = globalMeasurementTimestampSeconds;
int indexOfClosestEntry =
Math.abs(timestamp - m_pastObserverSnapshots.get(index - 1).getKey())
<= Math.abs(timestamp - m_pastObserverSnapshots.get(Math.min(index, maxIdx)).getKey())
? index - 1
: index;
double lastTimestamp =
m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds;
// We will now go back in time to the state of the system at the time when
// the measurement was captured. We will reset the observer to that state,
// and apply correction based on the measurement. Then, we will go back
// through all observer states until the present and apply past inputs to
// get the present estimated state.
for (int i = indexOfClosestEntry; i < m_pastObserverSnapshots.size(); i++) {
var key = m_pastObserverSnapshots.get(i).getKey();
var snapshot = m_pastObserverSnapshots.get(i).getValue();
if (i == indexOfClosestEntry) {
observer.setP(snapshot.errorCovariances);
observer.setXhat(snapshot.xHat);
}
observer.predict(snapshot.inputs, key - lastTimestamp);
observer.correct(snapshot.inputs, snapshot.localMeasurements);
if (i == indexOfClosestEntry) {
// Note that the measurement is at a timestep close but probably not exactly equal to the
// timestep for which we called predict.
// This makes the assumption that the dt is small enough that the difference between the
// measurement time and the time that the inputs were captured at is very small.
globalMeasurementCorrect.accept(snapshot.inputs, globalMeasurement);
}
lastTimestamp = key;
m_pastObserverSnapshots.set(i, Map.entry(key,
new ObserverSnapshot(observer, snapshot.inputs, snapshot.localMeasurements)));
}
}
/**
* This class contains all the information about our observer at a given time.
*/
@SuppressWarnings("MemberName")
public class ObserverSnapshot {
public final Matrix<S, N1> xHat;
public final Matrix<S, S> errorCovariances;
public final Matrix<I, N1> inputs;
public final Matrix<O, N1> localMeasurements;
@SuppressWarnings("ParameterName")
private ObserverSnapshot(
KalmanTypeFilter<S, I, O> observer, Matrix<I, N1> u, Matrix<O, N1> localY
) {
this.xHat = observer.getXhat();
this.errorCovariances = observer.getP();
inputs = u;
localMeasurements = localY;
}
}
}

View File

@@ -0,0 +1,250 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.function.BiConsumer;
import edu.wpi.first.wpilibj.geometry.Pose2d;
import edu.wpi.first.wpilibj.geometry.Rotation2d;
import edu.wpi.first.wpilibj.geometry.Translation2d;
import edu.wpi.first.wpilibj.kinematics.MecanumDriveKinematics;
import edu.wpi.first.wpilibj.kinematics.MecanumDriveWheelSpeeds;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpiutil.WPIUtilJNI;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.VecBuilder;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N3;
/**
* This class wraps an {@link UnscentedKalmanFilter Unscented Kalman Filter} to fuse
* latency-compensated vision measurements with mecanum drive encoder velocity measurements.
* It will correct for noisy measurements and encoder drift. It is intended to be an easy
* but more accurate drop-in for {@link edu.wpi.first.wpilibj.kinematics.MecanumDriveOdometry}.
*
* <p>{@link MecanumDrivePoseEstimator#update} should be called every robot loop. If
* your loops are faster or slower than the default of 0.02s, then you should change
* the nominal delta time using the secondary constructor:
* {@link MecanumDrivePoseEstimator#MecanumDrivePoseEstimator(Rotation2d, Pose2d,
* MecanumDriveKinematics, Matrix, Matrix, Matrix, double)}.
*
* <p>{@link MecanumDrivePoseEstimator#addVisionMeasurement} can be called as
* infrequently as you want; if you never call it, then this class will behave mostly like regular
* encoder odometry.
*
* <p>Our state-space system is:
*
* <p><strong> x = [[x, y, theta]]^T </strong> in the field-coordinate system.
*
* <p><strong> u = [[vx, vy, theta]]^T </strong> in the field-coordinate system.
*
* <p><strong> y = [[x, y, theta]]^T </strong> in field coords from vision,
* or <strong> y = [[theta]]^T </strong> from the gyro.
*/
public class MecanumDrivePoseEstimator {
private final UnscentedKalmanFilter<N3, N3, N1> m_observer;
private final MecanumDriveKinematics m_kinematics;
private final BiConsumer<Matrix<N3, N1>, Matrix<N3, N1>> m_visionCorrect;
private final KalmanFilterLatencyCompensator<N3, N3, N1> m_latencyCompensator;
private final double m_nominalDt; // Seconds
private double m_prevTimeSeconds = -1.0;
private Rotation2d m_gyroOffset;
private Rotation2d m_previousAngle;
/**
* Constructs a MecanumDrivePoseEstimator.
*
* @param gyroAngle The current gyro angle.
* @param initialPoseMeters The starting pose estimate.
* @param kinematics A correctly-configured kinematics object for your drivetrain.
* @param stateStdDevs Standard deviations of model states. Increase these numbers to
* trust your wheel and gyro velocities less.
* @param localMeasurementStdDevs Standard deviations of the gyro measurement. Increase this
* number to trust gyro angle measurements less.
* @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase
* these numbers to trust vision less.
*/
public MecanumDrivePoseEstimator(
Rotation2d gyroAngle, Pose2d initialPoseMeters, MecanumDriveKinematics kinematics,
Matrix<N3, N1> stateStdDevs, Matrix<N1, N1> localMeasurementStdDevs,
Matrix<N3, N1> visionMeasurementStdDevs
) {
this(gyroAngle, initialPoseMeters, kinematics, stateStdDevs, localMeasurementStdDevs,
visionMeasurementStdDevs, 0.02);
}
/**
* Constructs a MecanumDrivePoseEstimator.
*
* @param gyroAngle The current gyro angle.
* @param initialPoseMeters The starting pose estimate.
* @param kinematics A correctly-configured kinematics object for your drivetrain.
* @param stateStdDevs Standard deviations of model states. Increase these numbers to
* trust your wheel and gyro velocities less.
* @param localMeasurementStdDevs Standard deviations of the gyro measurement. Increase this
* number to trust gyro angle measurements less.
* @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase
* these numbers to trust vision less.
* @param nominalDtSeconds The time in seconds between each robot loop.
*/
@SuppressWarnings("ParameterName")
public MecanumDrivePoseEstimator(
Rotation2d gyroAngle, Pose2d initialPoseMeters, MecanumDriveKinematics kinematics,
Matrix<N3, N1> stateStdDevs, Matrix<N1, N1> localMeasurementStdDevs,
Matrix<N3, N1> visionMeasurementStdDevs, double nominalDtSeconds
) {
m_nominalDt = nominalDtSeconds;
m_observer = new UnscentedKalmanFilter<>(
Nat.N3(), Nat.N1(),
(x_, u) -> u,
(x, u_) -> x.extractRowVector(2),
stateStdDevs,
localMeasurementStdDevs,
AngleStatistics.angleMean(2),
AngleStatistics.angleMean(0),
AngleStatistics.angleResidual(2),
AngleStatistics.angleResidual(0),
AngleStatistics.angleAdd(2),
m_nominalDt
);
m_kinematics = kinematics;
m_latencyCompensator = new KalmanFilterLatencyCompensator<>();
var visionContR = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), visionMeasurementStdDevs);
var visionDiscR = Discretization.discretizeR(visionContR, m_nominalDt);
m_visionCorrect = (u, y) -> m_observer.correct(
Nat.N3(), u, y,
(x, u_) -> x,
visionDiscR,
AngleStatistics.angleMean(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleAdd(2)
);
m_gyroOffset = initialPoseMeters.getRotation().minus(gyroAngle);
m_previousAngle = initialPoseMeters.getRotation();
m_observer.setXhat(StateSpaceUtil.poseTo3dVector(initialPoseMeters));
}
/**
* Resets the robot's position on the field.
*
* <p>You NEED to reset your encoders (to zero) when calling this method.
*
* <p>The gyroscope angle does not need to be reset in the user's robot code.
* The library automatically takes care of offsetting the gyro angle.
*
* @param poseMeters The position on the field that your robot is at.
* @param gyroAngle The angle reported by the gyroscope.
*/
public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
m_previousAngle = poseMeters.getRotation();
m_gyroOffset = getEstimatedPosition().getRotation().minus(gyroAngle);
m_observer.setXhat(StateSpaceUtil.poseTo3dVector(poseMeters));
}
/**
* Gets the pose of the robot at the current time as estimated by the Unscented Kalman Filter.
*
* @return The estimated robot pose in meters.
*/
public Pose2d getEstimatedPosition() {
return new Pose2d(
m_observer.getXhat(0),
m_observer.getXhat(1),
new Rotation2d(m_observer.getXhat(2))
);
}
/**
* Add a vision measurement to the Unscented Kalman Filter. This will correct the
* odometry pose estimate while still accounting for measurement noise.
*
* <p>This method can be called as infrequently as you want, as long as you are
* calling {@link MecanumDrivePoseEstimator#update} every loop.
*
* @param visionRobotPoseMeters The pose of the robot as measured by the vision
* camera.
* @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if
* you don't use your own time source by calling
* {@link MecanumDrivePoseEstimator#updateWithTime} then you
* must use a timestamp with an epoch since FPGA startup
* (i.e. the epoch of this timestamp is the same epoch as
* Timer.getFPGATimestamp.) This means that you should
* use Timer.getFPGATimestamp as your time source
* or sync the epochs.
*/
public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
m_latencyCompensator.applyPastGlobalMeasurement(
Nat.N3(),
m_observer, m_nominalDt,
StateSpaceUtil.poseTo3dVector(visionRobotPoseMeters),
m_visionCorrect,
timestampSeconds
);
}
/**
* Updates the the Unscented Kalman Filter using only wheel encoder information.
* This should be called every loop, and the correct loop period must be passed
* into the constructor of this class.
*
* @param gyroAngle The current gyro angle.
* @param wheelSpeeds The current speeds of the mecanum drive wheels.
* @return The estimated pose of the robot in meters.
*/
public Pose2d update(Rotation2d gyroAngle, MecanumDriveWheelSpeeds wheelSpeeds) {
return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, wheelSpeeds);
}
/**
* Updates the the Unscented Kalman Filter using only wheel encoder information.
* This should be called every loop, and the correct loop period must be passed
* into the constructor of this class.
*
* @param currentTimeSeconds Time at which this method was called, in seconds.
* @param gyroAngle The current gyroscope angle.
* @param wheelSpeeds The current speeds of the mecanum drive wheels.
* @return The estimated pose of the robot in meters.
*/
@SuppressWarnings("LocalVariableName")
public Pose2d updateWithTime(double currentTimeSeconds, Rotation2d gyroAngle,
MecanumDriveWheelSpeeds wheelSpeeds) {
double dt = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : m_nominalDt;
m_prevTimeSeconds = currentTimeSeconds;
var angle = gyroAngle.plus(m_gyroOffset);
var omega = angle.minus(m_previousAngle).getRadians() / dt;
var chassisSpeeds = m_kinematics.toChassisSpeeds(wheelSpeeds);
var fieldRelativeVelocities =
new Translation2d(chassisSpeeds.vxMetersPerSecond, chassisSpeeds.vyMetersPerSecond)
.rotateBy(angle);
var u = VecBuilder.fill(
fieldRelativeVelocities.getX(),
fieldRelativeVelocities.getY(),
omega
);
m_previousAngle = angle;
var localY = VecBuilder.fill(angle.getRadians());
m_latencyCompensator.addObserverState(m_observer, u, localY, currentTimeSeconds);
m_observer.predict(u, dt);
m_observer.correct(u, localY);
return getEstimatedPosition();
}
}

View File

@@ -0,0 +1,255 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.function.BiConsumer;
import edu.wpi.first.wpilibj.geometry.Pose2d;
import edu.wpi.first.wpilibj.geometry.Rotation2d;
import edu.wpi.first.wpilibj.geometry.Translation2d;
import edu.wpi.first.wpilibj.kinematics.SwerveDriveKinematics;
import edu.wpi.first.wpilibj.kinematics.SwerveModuleState;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpiutil.WPIUtilJNI;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.VecBuilder;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N3;
/**
* This class wraps an {@link UnscentedKalmanFilter Unscented Kalman Filter} to fuse
* latency-compensated vision measurements with swerve drive encoder velocity measurements.
* It will correct for noisy measurements and encoder drift. It is intended to be an easy
* but more accurate drop-in for {@link edu.wpi.first.wpilibj.kinematics.SwerveDriveOdometry}.
*
* <p>{@link SwerveDrivePoseEstimator#update} should be called every robot loop. If
* your loops are faster or slower than the default of 0.02s, then you should change
* the nominal delta time using the secondary constructor:
* {@link SwerveDrivePoseEstimator#SwerveDrivePoseEstimator(Rotation2d, Pose2d,
* SwerveDriveKinematics, Matrix, Matrix, Matrix, double)}.
*
* <p>{@link SwerveDrivePoseEstimator#addVisionMeasurement} can be called as
* infrequently as you want; if you never call it, then this class will behave mostly like regular
* encoder odometry.
*
* <p>Our state-space system is:
*
* <p><strong> x = [[x, y, theta]]^T </strong> in the field-coordinate system.
*
* <p><strong> u = [[vx, vy, omega]]^T </strong> in the field-coordinate system.
*
* <p><strong> y = [[x, y, theta]]^T </strong> in field coords from vision,
* or <strong> y = [[theta]]^T </strong> from the gyro.
*/
public class SwerveDrivePoseEstimator {
private final UnscentedKalmanFilter<N3, N3, N1> m_observer;
private final SwerveDriveKinematics m_kinematics;
private final BiConsumer<Matrix<N3, N1>, Matrix<N3, N1>> m_visionCorrect;
private final KalmanFilterLatencyCompensator<N3, N3, N1> m_latencyCompensator;
private final double m_nominalDt; // Seconds
private double m_prevTimeSeconds = -1.0;
private Rotation2d m_gyroOffset;
private Rotation2d m_previousAngle;
/**
* Constructs a SwerveDrivePoseEstimator.
*
* @param gyroAngle The current gyro angle.
* @param initialPoseMeters The starting pose estimate.
* @param kinematics A correctly-configured kinematics object for your drivetrain.
* @param stateStdDevs Standard deviations of model states. Increase these numbers to
* trust your wheel and gyro velocities less.
* @param localMeasurementStdDevs Standard deviations of the gyro measurement. Increase this
* number to trust gyro angle measurements less.
* @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase
* these numbers to trust vision less.
*/
public SwerveDrivePoseEstimator(
Rotation2d gyroAngle, Pose2d initialPoseMeters, SwerveDriveKinematics kinematics,
Matrix<N3, N1> stateStdDevs, Matrix<N1, N1> localMeasurementStdDevs,
Matrix<N3, N1> visionMeasurementStdDevs
) {
this(gyroAngle, initialPoseMeters, kinematics, stateStdDevs, localMeasurementStdDevs,
visionMeasurementStdDevs, 0.02);
}
/**
* Constructs a SwerveDrivePoseEstimator.
*
* @param gyroAngle The current gyro angle.
* @param initialPoseMeters The starting pose estimate.
* @param kinematics A correctly-configured kinematics object for your drivetrain.
* @param stateStdDevs Standard deviations of model states. Increase these numbers to
* trust your wheel and gyro velocities less.
* @param localMeasurementStdDevs Standard deviations of the gyro measurement. Increase this
* number to trust gyro angle measurements less.
* @param visionMeasurementStdDevs Standard deviations of the vision measurements. Increase
* these numbers to trust vision less.
* @param nominalDtSeconds The time in seconds between each robot loop.
*/
@SuppressWarnings("ParameterName")
public SwerveDrivePoseEstimator(
Rotation2d gyroAngle, Pose2d initialPoseMeters, SwerveDriveKinematics kinematics,
Matrix<N3, N1> stateStdDevs, Matrix<N1, N1> localMeasurementStdDevs,
Matrix<N3, N1> visionMeasurementStdDevs, double nominalDtSeconds
) {
m_nominalDt = nominalDtSeconds;
m_observer = new UnscentedKalmanFilter<>(
Nat.N3(), Nat.N1(),
(x_, u) -> u,
(x, u_) -> x.extractRowVector(2),
stateStdDevs,
localMeasurementStdDevs,
AngleStatistics.angleMean(2),
AngleStatistics.angleMean(0),
AngleStatistics.angleResidual(2),
AngleStatistics.angleResidual(0),
AngleStatistics.angleAdd(2),
m_nominalDt
);
m_kinematics = kinematics;
m_latencyCompensator = new KalmanFilterLatencyCompensator<>();
var visionContR = StateSpaceUtil.makeCovarianceMatrix(Nat.N3(), visionMeasurementStdDevs);
var visionDiscR = Discretization.discretizeR(visionContR, m_nominalDt);
m_visionCorrect = (u, y) -> m_observer.correct(
Nat.N3(), u, y,
(x, u_) -> x,
visionDiscR,
AngleStatistics.angleMean(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleResidual(2),
AngleStatistics.angleAdd(2)
);
m_gyroOffset = initialPoseMeters.getRotation().minus(gyroAngle);
m_previousAngle = initialPoseMeters.getRotation();
m_observer.setXhat(StateSpaceUtil.poseTo3dVector(initialPoseMeters));
}
/**
* Resets the robot's position on the field.
*
* <p>You NEED to reset your encoders (to zero) when calling this method.
*
* <p>The gyroscope angle does not need to be reset in the user's robot code.
* The library automatically takes care of offsetting the gyro angle.
*
* @param poseMeters The position on the field that your robot is at.
* @param gyroAngle The angle reported by the gyroscope.
*/
public void resetPosition(Pose2d poseMeters, Rotation2d gyroAngle) {
m_previousAngle = poseMeters.getRotation();
m_gyroOffset = getEstimatedPosition().getRotation().minus(gyroAngle);
m_observer.setXhat(StateSpaceUtil.poseTo3dVector(poseMeters));
}
/**
* Gets the pose of the robot at the current time as estimated by the Unscented Kalman Filter.
*
* @return The estimated robot pose in meters.
*/
public Pose2d getEstimatedPosition() {
return new Pose2d(
m_observer.getXhat(0),
m_observer.getXhat(1),
new Rotation2d(m_observer.getXhat(2))
);
}
/**
* Add a vision measurement to the Unscented Kalman Filter. This will correct the
* odometry pose estimate while still accounting for measurement noise.
*
* <p>This method can be called as infrequently as you want, as long as you are
* calling {@link SwerveDrivePoseEstimator#update} every loop.
*
* @param visionRobotPoseMeters The pose of the robot as measured by the vision
* camera.
* @param timestampSeconds The timestamp of the vision measurement in seconds. Note that if
* you don't use your own time source by calling
* {@link SwerveDrivePoseEstimator#updateWithTime} then you
* must use a timestamp with an epoch since FPGA startup
* (i.e. the epoch of this timestamp is the same epoch as
* Timer.getFPGATimestamp.) This means that you should
* use Timer.getFPGATimestamp as your time source or
* sync the epochs.
*/
public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
m_latencyCompensator.applyPastGlobalMeasurement(
Nat.N3(),
m_observer, m_nominalDt,
StateSpaceUtil.poseTo3dVector(visionRobotPoseMeters),
m_visionCorrect,
timestampSeconds
);
}
/**
* Updates the the Unscented Kalman Filter using only wheel encoder information.
* This should be called every loop, and the correct loop period must be passed
* into the constructor of this class.
*
* @param gyroAngle The current gyro angle.
* @param moduleStates The current velocities and rotations of the swerve modules.
* @return The estimated pose of the robot in meters.
*/
public Pose2d update(
Rotation2d gyroAngle,
SwerveModuleState... moduleStates
) {
return updateWithTime(WPIUtilJNI.now() * 1.0e-6, gyroAngle, moduleStates);
}
/**
* Updates the the Unscented Kalman Filter using only wheel encoder information.
* This should be called every loop, and the correct loop period must be passed
* into the constructor of this class.
*
* @param currentTimeSeconds Time at which this method was called, in seconds.
* @param gyroAngle The current gyroscope angle.
* @param moduleStates The current velocities and rotations of the swerve modules.
* @return The estimated pose of the robot in meters.
*/
@SuppressWarnings("LocalVariableName")
public Pose2d updateWithTime(
double currentTimeSeconds,
Rotation2d gyroAngle, SwerveModuleState... moduleStates
) {
double dt = m_prevTimeSeconds >= 0 ? currentTimeSeconds - m_prevTimeSeconds : m_nominalDt;
m_prevTimeSeconds = currentTimeSeconds;
var angle = gyroAngle.plus(m_gyroOffset);
var omega = angle.minus(m_previousAngle).getRadians() / dt;
var chassisSpeeds = m_kinematics.toChassisSpeeds(moduleStates);
var fieldRelativeVelocities = new Translation2d(
chassisSpeeds.vxMetersPerSecond, chassisSpeeds.vyMetersPerSecond
).rotateBy(angle);
var u = VecBuilder.fill(
fieldRelativeVelocities.getX(),
fieldRelativeVelocities.getY(),
omega
);
m_previousAngle = angle;
var localY = VecBuilder.fill(angle.getRadians());
m_latencyCompensator.addObserverState(m_observer, u, localY, currentTimeSeconds);
m_observer.predict(u, dt);
m_observer.correct(u, localY);
return getEstimatedPosition();
}
}

View File

@@ -31,7 +31,7 @@ import edu.wpi.first.wpiutil.math.numbers.N1;
* an estimate of the true covariance (as opposed to a linearized version of it). This means that
* the UKF works with nonlinear systems.
*/
@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
@SuppressWarnings({"MemberName", "ClassTypeParameterName", "PMD.TooManyFields"})
public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
Outputs extends Num> implements KalmanTypeFilter<States, Inputs, Outputs> {
@@ -41,6 +41,12 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
private BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> m_meanFuncX;
private BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> m_meanFuncY;
private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_residualFuncX;
private BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> m_residualFuncY;
private BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> m_addFuncX;
private Matrix<States, N1> m_xHat;
private Matrix<States, States> m_P;
private final Matrix<States, States> m_contQ;
@@ -61,7 +67,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
* the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param dtSeconds Nominal discretization timestep.
* @param nominalDtSeconds Nominal discretization timestep.
*/
@SuppressWarnings("ParameterName")
public UnscentedKalmanFilter(Nat<States> states, Nat<Outputs> outputs,
@@ -71,18 +77,75 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double dtSeconds) {
double nominalDtSeconds) {
this(
states, outputs,
f, h,
stateStdDevs, measurementStdDevs,
(sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
(sigmas, Wm) -> sigmas.times(Matrix.changeBoundsUnchecked(Wm)),
Matrix::minus,
Matrix::minus,
Matrix::plus,
nominalDtSeconds
);
}
/**
* Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using
* custom functions for arithmetic can be useful if you have angles in the state or measurements,
* because they allow you to correctly account for the modular nature of angle arithmetic.
*
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns
* the derivative of the state vector.
* @param h A vector-valued function of x and u that returns
* the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param meanFuncX A function that computes the mean of 2 * States + 1 state vectors
* using a given set of weights.
* @param meanFuncY A function that computes the mean of 2 * States + 1 measurement
* vectors using a given set of weights.
* @param residualFuncX A function that computes the residual of two state vectors (i.e. it
* subtracts them.)
* @param residualFuncY A function that computes the residual of two measurement vectors
* (i.e. it subtracts them.)
* @param addFuncX A function that adds two state vectors.
* @param nominalDtSeconds Nominal discretization timestep.
*/
@SuppressWarnings({"ParameterName", "PMD.ExcessiveParameterList"})
public UnscentedKalmanFilter(
Nat<States> states, Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
BiFunction<Matrix<States, ?>, Matrix<?, N1>, Matrix<States, N1>> meanFuncX,
BiFunction<Matrix<Outputs, ?>, Matrix<?, N1>, Matrix<Outputs, N1>> meanFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
BiFunction<Matrix<Outputs, N1>, Matrix<Outputs, N1>, Matrix<Outputs, N1>> residualFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX,
double nominalDtSeconds
) {
this.m_states = states;
this.m_outputs = outputs;
m_f = f;
m_h = h;
m_meanFuncX = meanFuncX;
m_meanFuncY = meanFuncY;
m_residualFuncX = residualFuncX;
m_residualFuncY = residualFuncY;
m_addFuncX = addFuncX;
m_dtSeconds = nominalDtSeconds;
m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
m_dtSeconds = dtSeconds;
m_pts = new MerweScaledSigmaPoints<>(states);
reset();
@@ -91,7 +154,9 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
@SuppressWarnings({"ParameterName", "LocalVariableName", "PMD.CyclomaticComplexity"})
static <S extends Num, C extends Num>
Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
Nat<S> s, Nat<C> dim, Matrix<C, ?> sigmas, Matrix<?, N1> Wm, Matrix<?, N1> Wc
Nat<S> s, Nat<C> dim, Matrix<C, ?> sigmas, Matrix<?, N1> Wm, Matrix<?, N1> Wc,
BiFunction<Matrix<C, ?>, Matrix<?, N1>, Matrix<C, N1>> meanFunc,
BiFunction<Matrix<C, N1>, Matrix<C, N1>, Matrix<C, N1>> residualFunc
) {
if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
throw new IllegalArgumentException("Sigmas must be covDim by 2 * states + 1! Got "
@@ -108,15 +173,16 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
+ Wc.getNumRows() + " by " + Wc.getNumCols());
}
// New mean is just the sum of the sigmas * weight
// New mean is usually just the sum of the sigmas * weight:
// dot = \Sigma^n_1 (W[k]*Xi[k])
Matrix<C, N1> x = sigmas.times(Matrix.changeBoundsUnchecked(Wm));
Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
// New covariance is the sum of the outer product of the residuals times the
// weights
Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
for (int i = 0; i < 2 * s.getNum() + 1; i++) {
y.setColumn(i, sigmas.extractColumnVector(i).minus(x));
// y[:, i] = sigmas[:, i] - x
y.setColumn(i, residualFunc.apply(sigmas.extractColumnVector(i), x));
}
Matrix<C, C> P = y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
.times(Matrix.changeBoundsUnchecked(y.transpose()));
@@ -234,7 +300,7 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
}
var ret = unscentedTransform(m_states, m_states,
m_sigmasF, m_pts.getWm(), m_pts.getWc());
m_sigmasF, m_pts.getWm(), m_pts.getWc(), m_meanFuncX, m_residualFuncX);
m_xHat = ret.getFirst();
m_P = ret.getSecond().plus(discQ);
@@ -250,7 +316,8 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
@SuppressWarnings("ParameterName")
@Override
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
correct(m_outputs, u, y, m_h, m_contR);
correct(m_outputs, u, y, m_h, m_contR,
m_meanFuncY, m_residualFuncY, m_residualFuncX, m_addFuncX);
}
/**
@@ -271,7 +338,11 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
Nat<R> rows, Matrix<Inputs, N1> u,
Matrix<R, N1> y,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
Matrix<R, R> R) {
Matrix<R, R> R,
BiFunction<Matrix<R, ?>, Matrix<?, N1>, Matrix<R, N1>> meanFuncY,
BiFunction<Matrix<R, N1>, Matrix<R, N1>, Matrix<R, N1>> residualFuncY,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> residualFuncX,
BiFunction<Matrix<States, N1>, Matrix<States, N1>, Matrix<States, N1>> addFuncX) {
final var discR = Discretization.discretizeR(R, m_dtSeconds);
// Transform sigma points into measurement space
@@ -287,18 +358,19 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
}
// Mean and covariance of prediction passed through unscented transform
var transRet = unscentedTransform(m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc());
var transRet = unscentedTransform(m_states, rows,
sigmasH, m_pts.getWm(), m_pts.getWc(), meanFuncY, residualFuncY);
var yHat = transRet.getFirst();
var Py = transRet.getSecond().plus(discR);
// Compute cross covariance of the state and the measurements
Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
var temp =
m_sigmasF.extractColumnVector(i).minus(m_xHat)
.times(sigmasH.extractColumnVector(i).minus(yHat).transpose());
// Pxy += (sigmas_f[:, i] - xHat) * (sigmas_h[:, i] - yHat)^T * W_c[i]
var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat);
var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose();
Pxy = Pxy.plus(temp.times(m_pts.getWc(i)));
Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
}
// K = P_{xy} Py^-1
@@ -310,7 +382,8 @@ public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
Py.transpose().solve(Pxy.transpose()).transpose()
);
m_xHat = m_xHat.plus(K.times(y.minus(yHat)));
// xHat + K * (y - yHat)
m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
m_P = m_P.minus(K.times(Py).times(K.transpose()));
}
}

View File

@@ -20,6 +20,7 @@ import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.VecBuilder;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N3;
import edu.wpi.first.wpiutil.math.numbers.N4;
@SuppressWarnings("ParameterName")
public final class StateSpaceUtil {
@@ -177,4 +178,32 @@ public final class StateSpaceUtil {
return u;
}
/**
* Convert a {@link Pose2d} to a vector of [x, y, cos(theta), sin(theta)],
* where theta is in radians.
*
* @param pose A pose to convert to a vector.
*/
public static Matrix<N4, N1> poseTo4dVector(Pose2d pose) {
return VecBuilder.fill(
pose.getTranslation().getX(),
pose.getTranslation().getY(),
pose.getRotation().getCos(),
pose.getRotation().getSin()
);
}
/**
* Convert a {@link Pose2d} to a vector of [x, y, theta], where theta is in radians.
*
* @param pose A pose to convert to a vector.
* @return The given pose in vector form, with the third element, theta, in radians.
*/
public static Matrix<N3, N1> poseTo3dVector(Pose2d pose) {
return VecBuilder.fill(
pose.getTranslation().getX(),
pose.getTranslation().getY(),
pose.getRotation().getRadians()
);
}
}