From ae7b1851ec704a81dbf0dbc1dcb3a5542bc0ceb5 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Fri, 29 Apr 2022 17:24:23 -0700 Subject: [PATCH] [wpimath] KalmanFilter: Use extern template instead of Impl class --- .../native/cpp/estimator/KalmanFilter.cpp | 13 +- .../include/frc/estimator/KalmanFilter.h | 139 ++---------------- .../include/frc/estimator/KalmanFilter.inc | 87 +++++++++++ 3 files changed, 104 insertions(+), 135 deletions(-) create mode 100644 wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc diff --git a/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp b/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp index 1209eae0bb..a56efe4f8d 100644 --- a/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp +++ b/wpimath/src/main/native/cpp/estimator/KalmanFilter.cpp @@ -6,16 +6,7 @@ namespace frc { -KalmanFilter<1, 1, 1>::KalmanFilter( - LinearSystem<1, 1, 1>& plant, const wpi::array& stateStdDevs, - const wpi::array& measurementStdDevs, units::second_t dt) - : detail::KalmanFilterImpl<1, 1, 1>{plant, stateStdDevs, measurementStdDevs, - dt} {} - -KalmanFilter<2, 1, 1>::KalmanFilter( - LinearSystem<2, 1, 1>& plant, const wpi::array& stateStdDevs, - const wpi::array& measurementStdDevs, units::second_t dt) - : detail::KalmanFilterImpl<2, 1, 1>{plant, stateStdDevs, measurementStdDevs, - dt} {} +template class EXPORT_TEMPLATE_DEFINE(WPILIB_DLLEXPORT) KalmanFilter<1, 1, 1>; +template class EXPORT_TEMPLATE_DEFINE(WPILIB_DLLEXPORT) KalmanFilter<2, 1, 1>; } // namespace frc diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h index 3aa4dbd212..dd0abc3edf 100644 --- a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h +++ b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.h @@ -4,25 +4,14 @@ #pragma once -#include - -#include -#include - #include #include -#include "Eigen/Cholesky" #include "Eigen/Core" -#include "drake/math/discrete_algebraic_riccati_equation.h" -#include "frc/StateSpaceUtil.h" -#include "frc/system/Discretization.h" #include "frc/system/LinearSystem.h" #include "units/time.h" -#include "wpimath/MathShared.h" namespace frc { -namespace detail { /** * A Kalman filter combines predictions from a model and measurements to give an @@ -45,7 +34,7 @@ namespace detail { * @tparam Outputs The number of outputs. */ template -class KalmanFilterImpl { +class KalmanFilter { public: /** * Constructs a state-space observer with the given plant. @@ -55,59 +44,13 @@ class KalmanFilterImpl { * @param measurementStdDevs Standard deviations of measurements. * @param dt Nominal discretization timestep. */ - KalmanFilterImpl(LinearSystem& plant, - const wpi::array& stateStdDevs, - const wpi::array& measurementStdDevs, - units::second_t dt) { - m_plant = &plant; + KalmanFilter(LinearSystem& plant, + const wpi::array& stateStdDevs, + const wpi::array& measurementStdDevs, + units::second_t dt); - auto contQ = MakeCovMatrix(stateStdDevs); - auto contR = MakeCovMatrix(measurementStdDevs); - - Eigen::Matrix discA; - Eigen::Matrix discQ; - DiscretizeAQTaylor(plant.A(), contQ, dt, &discA, &discQ); - - auto discR = DiscretizeR(contR, dt); - - const auto& C = plant.C(); - - if (!IsDetectable(discA, C)) { - std::string msg = fmt::format( - "The system passed to the Kalman filter is " - "unobservable!\n\nA =\n{}\nC =\n{}\n", - discA, C); - - wpi::math::MathSharedStore::ReportError(msg); - throw std::invalid_argument(msg); - } - - Eigen::Matrix P = - drake::math::DiscreteAlgebraicRiccatiEquation( - discA.transpose(), C.transpose(), discQ, discR); - - // S = CPCᵀ + R - Eigen::Matrix S = C * P * C.transpose() + discR; - - // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more - // efficiently. - // - // K = PCᵀS⁻¹ - // KS = PCᵀ - // (KS)ᵀ = (PCᵀ)ᵀ - // SᵀKᵀ = CPᵀ - // - // The solution of Ax = b can be found via x = A.solve(b). - // - // Kᵀ = Sᵀ.solve(CPᵀ) - // K = (Sᵀ.solve(CPᵀ))ᵀ - m_K = S.transpose().ldlt().solve(C * P.transpose()).transpose(); - - Reset(); - } - - KalmanFilterImpl(KalmanFilterImpl&&) = default; - KalmanFilterImpl& operator=(KalmanFilterImpl&&) = default; + KalmanFilter(KalmanFilter&&) = default; + KalmanFilter& operator=(KalmanFilter&&) = default; /** * Returns the steady-state Kalman gain matrix K. @@ -160,9 +103,7 @@ class KalmanFilterImpl { * @param u New control input from controller. * @param dt Timestep for prediction. */ - void Predict(const Eigen::Vector& u, units::second_t dt) { - m_xHat = m_plant->CalculateX(m_xHat, u, dt); - } + void Predict(const Eigen::Vector& u, units::second_t dt); /** * Correct the state estimate x-hat using the measurements in y. @@ -171,10 +112,7 @@ class KalmanFilterImpl { * @param y Measurement vector. */ void Correct(const Eigen::Vector& u, - const Eigen::Vector& y) { - // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) - m_xHat += m_K * (y - (m_plant->C() * m_xHat + m_plant->D() * u)); - } + const Eigen::Vector& y); private: LinearSystem* m_plant; @@ -190,58 +128,11 @@ class KalmanFilterImpl { Eigen::Vector m_xHat; }; -} // namespace detail - -template -class KalmanFilter : public detail::KalmanFilterImpl { - public: - /** - * Constructs a state-space observer with the given plant. - * - * @param plant The plant used for the prediction step. - * @param stateStdDevs Standard deviations of model states. - * @param measurementStdDevs Standard deviations of measurements. - * @param dt Nominal discretization timestep. - */ - KalmanFilter(LinearSystem& plant, - const wpi::array& stateStdDevs, - const wpi::array& measurementStdDevs, - units::second_t dt) - : detail::KalmanFilterImpl{ - plant, stateStdDevs, measurementStdDevs, dt} {} - - KalmanFilter(KalmanFilter&&) = default; - KalmanFilter& operator=(KalmanFilter&&) = default; -}; - -// Template specializations are used here to make common state-input-output -// triplets compile faster. -template <> -class WPILIB_DLLEXPORT KalmanFilter<1, 1, 1> - : public detail::KalmanFilterImpl<1, 1, 1> { - public: - KalmanFilter(LinearSystem<1, 1, 1>& plant, - const wpi::array& stateStdDevs, - const wpi::array& measurementStdDevs, - units::second_t dt); - - KalmanFilter(KalmanFilter&&) = default; - KalmanFilter& operator=(KalmanFilter&&) = default; -}; - -// Template specializations are used here to make common state-input-output -// triplets compile faster. -template <> -class WPILIB_DLLEXPORT KalmanFilter<2, 1, 1> - : public detail::KalmanFilterImpl<2, 1, 1> { - public: - KalmanFilter(LinearSystem<2, 1, 1>& plant, - const wpi::array& stateStdDevs, - const wpi::array& measurementStdDevs, - units::second_t dt); - - KalmanFilter(KalmanFilter&&) = default; - KalmanFilter& operator=(KalmanFilter&&) = default; -}; +extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) + KalmanFilter<1, 1, 1>; +extern template class EXPORT_TEMPLATE_DECLARE(WPILIB_DLLEXPORT) + KalmanFilter<2, 1, 1>; } // namespace frc + +#include "KalmanFilter.inc" diff --git a/wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc new file mode 100644 index 0000000000..df5fb91baa --- /dev/null +++ b/wpimath/src/main/native/include/frc/estimator/KalmanFilter.inc @@ -0,0 +1,87 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +#include +#include + +#include "Eigen/Cholesky" +#include "drake/math/discrete_algebraic_riccati_equation.h" +#include "frc/StateSpaceUtil.h" +#include "frc/estimator/KalmanFilter.h" +#include "frc/system/Discretization.h" +#include "wpimath/MathShared.h" + +namespace frc { + +template +KalmanFilter::KalmanFilter( + LinearSystem& plant, + const wpi::array& stateStdDevs, + const wpi::array& measurementStdDevs, units::second_t dt) { + m_plant = &plant; + + auto contQ = MakeCovMatrix(stateStdDevs); + auto contR = MakeCovMatrix(measurementStdDevs); + + Eigen::Matrix discA; + Eigen::Matrix discQ; + DiscretizeAQTaylor(plant.A(), contQ, dt, &discA, &discQ); + + auto discR = DiscretizeR(contR, dt); + + const auto& C = plant.C(); + + if (!IsDetectable(discA, C)) { + std::string msg = fmt::format( + "The system passed to the Kalman filter is " + "unobservable!\n\nA =\n{}\nC =\n{}\n", + discA, C); + + wpi::math::MathSharedStore::ReportError(msg); + throw std::invalid_argument(msg); + } + + Eigen::Matrix P = + drake::math::DiscreteAlgebraicRiccatiEquation( + discA.transpose(), C.transpose(), discQ, discR); + + // S = CPCᵀ + R + Eigen::Matrix S = C * P * C.transpose() + discR; + + // We want to put K = PCᵀS⁻¹ into Ax = b form so we can solve it more + // efficiently. + // + // K = PCᵀS⁻¹ + // KS = PCᵀ + // (KS)ᵀ = (PCᵀ)ᵀ + // SᵀKᵀ = CPᵀ + // + // The solution of Ax = b can be found via x = A.solve(b). + // + // Kᵀ = Sᵀ.solve(CPᵀ) + // K = (Sᵀ.solve(CPᵀ))ᵀ + m_K = S.transpose().ldlt().solve(C * P.transpose()).transpose(); + + Reset(); +} + +template +void KalmanFilter::Predict( + const Eigen::Vector& u, units::second_t dt) { + m_xHat = m_plant->CalculateX(m_xHat, u, dt); +} + +template +void KalmanFilter::Correct( + const Eigen::Vector& u, + const Eigen::Vector& y) { + // x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − (Cx̂ₖ₊₁⁻ + Duₖ₊₁)) + m_xHat += m_K * (y - (m_plant->C() * m_xHat + m_plant->D() * u)); +} + +} // namespace frc