[wpimath] Simplify pose estimator (#6705)

This commit is contained in:
Joseph Eng
2024-06-28 20:12:12 -07:00
committed by GitHub
parent 5e745bc5ef
commit 512a4bfc12
4 changed files with 276 additions and 205 deletions

View File

@@ -4,7 +4,10 @@
#pragma once
#include <map>
#include <optional>
#include <utility>
#include <vector>
#include <Eigen/Core>
#include <wpi/SymbolExports.h>
@@ -185,57 +188,46 @@ class WPILIB_DLLEXPORT PoseEstimator {
const WheelPositions& wheelPositions);
private:
struct InterpolationRecord {
// The pose observed given the current sensor inputs and the previous pose.
Pose2d pose;
/**
* Removes stale vision updates that won't affect sampling.
*/
void CleanUpVisionUpdates();
// The current gyroscope angle.
Rotation2d gyroAngle;
struct VisionUpdate {
// The vision-compensated pose estimate
Pose2d visionPose;
// The distances traveled by the wheels.
WheelPositions wheelPositions;
// The pose estimated based solely on odometry
Pose2d odometryPose;
/**
* Checks equality between this InterpolationRecord and another object.
* Returns the vision-compensated version of the pose. Specifically, changes
* the pose from being relative to this record's odometry pose to being
* relative to this record's vision pose.
*
* @param other The other object.
* @return Whether the two objects are equal.
* @param pose The pose to compensate.
* @return The compensated pose.
*/
bool operator==(const InterpolationRecord& other) const = default;
/**
* Checks inequality between this InterpolationRecord and another object.
*
* @param other The other object.
* @return Whether the two objects are not equal.
*/
bool operator!=(const InterpolationRecord& other) const = default;
/**
* Interpolates between two InterpolationRecords.
*
* @param endValue The end value for the interpolation.
* @param i The interpolant (fraction).
*
* @return The interpolated state.
*/
InterpolationRecord Interpolate(
Kinematics<WheelSpeeds, WheelPositions>& kinematics,
InterpolationRecord endValue, double i) const;
Pose2d Compensate(const Pose2d& pose) const {
auto delta = pose - odometryPose;
return visionPose + delta;
}
};
static constexpr units::second_t kBufferDuration = 1.5_s;
Kinematics<WheelSpeeds, WheelPositions>& m_kinematics;
Odometry<WheelSpeeds, WheelPositions>& m_odometry;
wpi::array<double, 3> m_q{wpi::empty_array};
Eigen::Matrix3d m_visionK = Eigen::Matrix3d::Zero();
TimeInterpolatableBuffer<InterpolationRecord> m_poseBuffer{
kBufferDuration, [this](const InterpolationRecord& start,
const InterpolationRecord& end, double t) {
return start.Interpolate(this->m_kinematics, end, t);
}};
// Maps timestamps to odometry-only pose estimates
TimeInterpolatableBuffer<Pose2d> m_odometryPoseBuffer{kBufferDuration};
// Maps timestamps to vision updates
// Always contains one entry before the oldest entry in m_odometryPoseBuffer,
// unless there have been no vision measurements after the last reset
std::map<units::second_t, VisionUpdate> m_visionUpdates;
Pose2d m_poseEstimate;
};
} // namespace frc

View File

@@ -14,7 +14,7 @@ PoseEstimator<WheelSpeeds, WheelPositions>::PoseEstimator(
Odometry<WheelSpeeds, WheelPositions>& odometry,
const wpi::array<double, 3>& stateStdDevs,
const wpi::array<double, 3>& visionMeasurementStdDevs)
: m_kinematics(kinematics), m_odometry(odometry) {
: m_odometry(odometry), m_poseEstimate(m_odometry.GetPose()) {
for (size_t i = 0; i < 3; ++i) {
m_q[i] = stateStdDevs[i] * stateStdDevs[i];
}
@@ -48,26 +48,93 @@ void PoseEstimator<WheelSpeeds, WheelPositions>::ResetPosition(
const Pose2d& pose) {
// Reset state estimate and error covariance
m_odometry.ResetPosition(gyroAngle, wheelPositions, pose);
m_poseBuffer.Clear();
m_odometryPoseBuffer.Clear();
m_visionUpdates.clear();
m_poseEstimate = m_odometry.GetPose();
}
template <typename WheelSpeeds, typename WheelPositions>
Pose2d PoseEstimator<WheelSpeeds, WheelPositions>::GetEstimatedPosition()
const {
return m_odometry.GetPose();
return m_poseEstimate;
if (m_visionUpdates.empty()) {
return m_odometry.GetPose();
}
auto visionUpdate = m_visionUpdates.rbegin()->second;
return visionUpdate.Compensate(m_odometry.GetPose());
}
template <typename WheelSpeeds, typename WheelPositions>
std::optional<Pose2d> PoseEstimator<WheelSpeeds, WheelPositions>::SampleAt(
units::second_t timestamp) const {
// TODO Replace with std::optional::transform() in C++23
std::optional<PoseEstimator<WheelSpeeds, WheelPositions>::InterpolationRecord>
sample = m_poseBuffer.Sample(timestamp);
if (sample) {
return sample->pose;
} else {
// Step 0: If there are no odometry updates to sample, skip.
if (m_odometryPoseBuffer.GetInternalBuffer().empty()) {
return std::nullopt;
}
// Step 1: Make sure timestamp matches the sample from the odometry pose
// buffer. (When sampling, the buffer will always use a timestamp
// between the first and last timestamps)
units::second_t oldestOdometryTimestamp =
m_odometryPoseBuffer.GetInternalBuffer().front().first;
units::second_t newestOdometryTimestamp =
m_odometryPoseBuffer.GetInternalBuffer().back().first;
timestamp =
std::clamp(timestamp, oldestOdometryTimestamp, newestOdometryTimestamp);
// Step 2: If there are no applicable vision updates, use the odometry-only
// information.
if (m_visionUpdates.empty() || timestamp < m_visionUpdates.begin()->first) {
return m_odometryPoseBuffer.Sample(timestamp);
}
// Step 3: Get the latest vision update from before or at the timestamp to
// sample at.
// First, find the iterator past the sample timestamp, then go back one. Note
// that upper_bound() won't return begin() because we check begin() earlier.
auto floorIter = m_visionUpdates.upper_bound(timestamp);
--floorIter;
auto visionUpdate = floorIter->second;
// Step 4: Get the pose measured by odometry at the time of the sample.
auto odometryEstimate = m_odometryPoseBuffer.Sample(timestamp);
// Step 5: Apply the vision compensation to the odometry pose.
// TODO Replace with std::optional::transform() in C++23
if (odometryEstimate) {
return visionUpdate.Compensate(*odometryEstimate);
}
return std::nullopt;
}
template <typename WheelSpeeds, typename WheelPositions>
void PoseEstimator<WheelSpeeds, WheelPositions>::CleanUpVisionUpdates() {
// Step 0: If there are no odometry samples, skip.
if (m_odometryPoseBuffer.GetInternalBuffer().empty()) {
return;
}
// Step 1: Find the oldest timestamp that needs a vision update.
units::second_t oldestOdometryTimestamp =
m_odometryPoseBuffer.GetInternalBuffer().front().first;
// Step 2: If there are no vision updates before that timestamp, skip.
if (m_visionUpdates.empty() ||
oldestOdometryTimestamp < m_visionUpdates.begin()->first) {
return;
}
// Step 3: Find the newest vision update timestamp before or at the oldest
// timestamp.
// First, find the iterator past the oldest odometry timestamp, then go
// back one. Note that upper_bound() won't return begin() because we check
// begin() earlier.
auto newestNeededVisionUpdate =
m_visionUpdates.upper_bound(oldestOdometryTimestamp);
--newestNeededVisionUpdate;
// Step 4: Remove all entries strictly before the newest timestamp we need.
m_visionUpdates.erase(m_visionUpdates.begin(), newestNeededVisionUpdate);
}
template <typename WheelSpeeds, typename WheelPositions>
@@ -75,57 +142,58 @@ void PoseEstimator<WheelSpeeds, WheelPositions>::AddVisionMeasurement(
const Pose2d& visionRobotPose, units::second_t timestamp) {
// Step 0: If this measurement is old enough to be outside the pose buffer's
// timespan, skip.
if (!m_poseBuffer.GetInternalBuffer().empty() &&
m_poseBuffer.GetInternalBuffer().front().first - kBufferDuration >
if (m_odometryPoseBuffer.GetInternalBuffer().empty() ||
m_odometryPoseBuffer.GetInternalBuffer().front().first - kBufferDuration >
timestamp) {
return;
}
// Step 1: Get the estimated pose from when the vision measurement was made.
auto sample = m_poseBuffer.Sample(timestamp);
// Step 1: Clean up any old entries
CleanUpVisionUpdates();
if (!sample.has_value()) {
// Step 2: Get the pose measured by odometry at the moment the vision
// measurement was made.
auto odometrySample = m_odometryPoseBuffer.Sample(timestamp);
if (!odometrySample) {
return;
}
// Step 2: Measure the twist between the odometry pose and the vision pose
auto twist = sample.value().pose.Log(visionRobotPose);
// Step 3: Get the vision-compensated pose estimate at the moment the vision
// measurement was made.
auto visionSample = SampleAt(timestamp);
// Step 3: We should not trust the twist entirely, so instead we scale this
if (!visionSample) {
return;
}
// Step 4: Measure the twist between the old pose estimate and the vision
// pose.
auto twist = visionSample.value().Log(visionRobotPose);
// Step 5: We should not trust the twist entirely, so instead we scale this
// twist by a Kalman gain matrix representing how much we trust vision
// measurements compared to our current pose.
Eigen::Vector3d k_times_twist =
m_visionK *
Eigen::Vector3d{twist.dx.value(), twist.dy.value(), twist.dtheta.value()};
// Step 4: Convert back to Twist2d
// Step 6: Convert back to Twist2d.
Twist2d scaledTwist{units::meter_t{k_times_twist(0)},
units::meter_t{k_times_twist(1)},
units::radian_t{k_times_twist(2)}};
// Step 5: Reset Odometry to state at sample with vision adjustment.
m_odometry.ResetPosition(sample.value().gyroAngle,
sample.value().wheelPositions,
sample.value().pose.Exp(scaledTwist));
// Step 7: Calculate and record the vision update.
VisionUpdate visionUpdate{visionSample->Exp(scaledTwist), *odometrySample};
m_visionUpdates[timestamp] = visionUpdate;
// Step 6: Record the current pose to allow multiple measurements from the
// same timestamp
m_poseBuffer.AddSample(timestamp,
{GetEstimatedPosition(), sample.value().gyroAngle,
sample.value().wheelPositions});
// Step 8: Remove later vision measurements. (Matches previous behavior)
auto firstAfter = m_visionUpdates.upper_bound(timestamp);
m_visionUpdates.erase(firstAfter, m_visionUpdates.end());
// Step 7: Replay odometry inputs between sample time and latest recorded
// sample to update the pose buffer and correct odometry.
auto internal_buf = m_poseBuffer.GetInternalBuffer();
auto upper_bound =
std::lower_bound(internal_buf.begin(), internal_buf.end(), timestamp,
[](const auto& pair, auto t) { return t > pair.first; });
for (auto entry = upper_bound; entry != internal_buf.end(); entry++) {
UpdateWithTime(entry->first, entry->second.gyroAngle,
entry->second.wheelPositions);
}
// Step 9: Update latest pose estimate. Since we cleared all updates after
// this vision update, it's guaranteed to be the latest vision update.
m_poseEstimate = visionUpdate.Compensate(m_odometry.GetPose());
}
template <typename WheelSpeeds, typename WheelPositions>
@@ -139,40 +207,18 @@ template <typename WheelSpeeds, typename WheelPositions>
Pose2d PoseEstimator<WheelSpeeds, WheelPositions>::UpdateWithTime(
units::second_t currentTime, const Rotation2d& gyroAngle,
const WheelPositions& wheelPositions) {
m_odometry.Update(gyroAngle, wheelPositions);
auto odometryEstimate = m_odometry.Update(gyroAngle, wheelPositions);
WheelPositions internalWheelPositions = wheelPositions;
m_odometryPoseBuffer.AddSample(currentTime, odometryEstimate);
m_poseBuffer.AddSample(
currentTime, {GetEstimatedPosition(), gyroAngle, internalWheelPositions});
if (m_visionUpdates.empty()) {
m_poseEstimate = odometryEstimate;
} else {
auto visionUpdate = m_visionUpdates.rbegin()->second;
m_poseEstimate = visionUpdate.Compensate(odometryEstimate);
}
return GetEstimatedPosition();
}
template <typename WheelSpeeds, typename WheelPositions>
typename PoseEstimator<WheelSpeeds, WheelPositions>::InterpolationRecord
PoseEstimator<WheelSpeeds, WheelPositions>::InterpolationRecord::Interpolate(
Kinematics<WheelSpeeds, WheelPositions>& kinematics,
InterpolationRecord endValue, double i) const {
if (i < 0) {
return *this;
} else if (i > 1) {
return endValue;
} else {
// Find the new wheel distance measurements.
WheelPositions wheels_lerp = kinematics.Interpolate(
this->wheelPositions, endValue.wheelPositions, i);
// Find the new gyro angle.
auto gyro = wpi::Lerp(this->gyroAngle, endValue.gyroAngle, i);
// Create a twist to represent the change based on the interpolated
// sensor inputs.
auto twist = kinematics.ToTwist2d(this->wheelPositions, wheels_lerp);
twist.dtheta = (gyro - gyroAngle).Radians();
return {pose.Exp(twist), gyro, wheels_lerp};
}
}
} // namespace frc

View File

@@ -144,6 +144,13 @@ class TimeInterpolatableBuffer {
return m_pastSnapshots;
}
/**
* Grant access to the internal sample buffer.
*/
const std::vector<std::pair<units::second_t, T>>& GetInternalBuffer() const {
return m_pastSnapshots;
}
private:
units::second_t m_historySize;
std::vector<std::pair<units::second_t, T>> m_pastSnapshots;