diff --git a/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java b/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java index 93d6bea9e4..5eac0b3ba1 100644 --- a/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java +++ b/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java @@ -138,42 +138,41 @@ public class LinearFilter { } /** - * Creates a backward finite difference filter that computes the nth derivative of the input given - * the specified number of samples. + * Creates a finite difference filter that computes the nth derivative of the input given the + * specified stencil points. * - *

For example, a first derivative filter that uses two samples and a sample period of 20 ms - * would be - * - *


-   * LinearFilter.backwardFiniteDifference(1, 2, 0.02);
-   * 
+ *

Stencil points are the indices of the samples to use in the finite difference. 0 is the + * current sample, -1 is the previous sample, -2 is the sample before that, etc. Don't use + * positive stencil points (samples from the future) if the LinearFilter will be used for + * stream-based online filtering. * * @param derivative The order of the derivative to compute. * @param samples The number of samples to use to compute the given derivative. This must be one * more than the order of derivative or higher. + * @param stencil List of stencil points. * @param period The period in seconds between samples taken by the user. * @return Linear filter. + * @throws IllegalArgumentException if derivative < 1, samples <= 0, or derivative >= + * samples. */ @SuppressWarnings("LocalVariableName") - public static LinearFilter backwardFiniteDifference(int derivative, int samples, double period) { + public static LinearFilter finiteDifference( + int derivative, int samples, int[] stencil, double period) { // See // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points // - //

For a given list of stencil points s of length n and the order of + // For a given list of stencil points s of length n and the order of // derivative d < n, the finite difference coefficients can be obtained by // solving the following linear system for the vector a. // - //

     // [s₁⁰   ⋯  sₙ⁰ ][a₁]      [ δ₀,d ]
     // [ ⋮    ⋱  ⋮   ][⋮ ] = d! [  ⋮   ]
     // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ]      [δₙ₋₁,d]
-    // 
// - //

where δᵢ,ⱼ are the Kronecker delta. For backward finite difference, - // the stencil points are the range [-n + 1, 0]. The FIR gains are the - // elements of the vector a in reverse order divided by hᵈ. + // where δᵢ,ⱼ are the Kronecker delta. The FIR gains are the elements of the + // vector a in reverse order divided by hᵈ. // - //

The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ). + // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ). if (derivative < 1) { throw new IllegalArgumentException( @@ -192,8 +191,7 @@ public class LinearFilter { var S = new SimpleMatrix(samples, samples); for (int row = 0; row < samples; ++row) { for (int col = 0; col < samples; ++col) { - double s = 1 - samples + col; - S.set(row, col, Math.pow(s, row)); + S.set(row, col, Math.pow(stencil[col], row)); } } @@ -211,9 +209,34 @@ public class LinearFilter { ffGains[i] = a.get(samples - i - 1, 0); } - double[] fbGains = new double[0]; + return new LinearFilter(ffGains, new double[0]); + } - return new LinearFilter(ffGains, fbGains); + /** + * Creates a backward finite difference filter that computes the nth derivative of the input given + * the specified number of samples. + * + *

For example, a first derivative filter that uses two samples and a sample period of 20 ms + * would be + * + *


+   * LinearFilter.backwardFiniteDifference(1, 2, 0.02);
+   * 
+ * + * @param derivative The order of the derivative to compute. + * @param samples The number of samples to use to compute the given derivative. This must be one + * more than the order of derivative or higher. + * @param period The period in seconds between samples taken by the user. + * @return Linear filter. + */ + public static LinearFilter backwardFiniteDifference(int derivative, int samples, double period) { + // Generate stencil points from -(samples - 1) to 0 + int[] stencil = new int[samples]; + for (int i = 0; i < samples; ++i) { + stencil[i] = -(samples - 1) + i; + } + + return finiteDifference(derivative, samples, stencil, period); } /** Reset the filter state. */ diff --git a/wpimath/src/main/native/include/frc/filter/LinearFilter.h b/wpimath/src/main/native/include/frc/filter/LinearFilter.h index 92d8bdcf53..62953271dd 100644 --- a/wpimath/src/main/native/include/frc/filter/LinearFilter.h +++ b/wpimath/src/main/native/include/frc/filter/LinearFilter.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -166,6 +167,73 @@ class LinearFilter { return LinearFilter(gains, {}); } + /** + * Creates a finite difference filter that computes the nth derivative of the + * input given the specified stencil points. + * + * Stencil points are the indices of the samples to use in the finite + * difference. 0 is the current sample, -1 is the previous sample, -2 is the + * sample before that, etc. Don't use positive stencil points (samples from + * the future) if the LinearFilter will be used for stream-based online + * filtering. + * + * @tparam Derivative The order of the derivative to compute. + * @tparam Samples The number of samples to use to compute the given + * derivative. This must be one more than the order of + * derivative or higher. + * @param stencil List of stencil points. + * @param period The period in seconds between samples taken by the user. + */ + template + static LinearFilter FiniteDifference( + const wpi::array stencil, units::second_t period) { + // See + // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points + // + // For a given list of stencil points s of length n and the order of + // derivative d < n, the finite difference coefficients can be obtained by + // solving the following linear system for the vector a. + // + // [s₁⁰ ⋯ sₙ⁰ ][a₁] [ δ₀,d ] + // [ ⋮ ⋱ ⋮ ][⋮ ] = d! [ ⋮ ] + // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ] [δₙ₋₁,d] + // + // where δᵢ,ⱼ are the Kronecker delta. The FIR gains are the elements of the + // vector a in reverse order divided by hᵈ. + // + // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ). + + static_assert(Derivative >= 1, + "Order of derivative must be greater than or equal to one."); + static_assert(Samples > 0, "Number of samples must be greater than zero."); + static_assert(Derivative < Samples, + "Order of derivative must be less than number of samples."); + + Eigen::Matrix S; + for (int row = 0; row < Samples; ++row) { + for (int col = 0; col < Samples; ++col) { + S(row, col) = std::pow(stencil[col], row); + } + } + + // Fill in Kronecker deltas: https://en.wikipedia.org/wiki/Kronecker_delta + Eigen::Vector d; + for (int i = 0; i < Samples; ++i) { + d(i) = (i == Derivative) ? Factorial(Derivative) : 0.0; + } + + Eigen::Vector a = + S.householderQr().solve(d) / std::pow(period.value(), Derivative); + + // Reverse gains list + std::vector ffGains; + for (int i = Samples - 1; i >= 0; --i) { + ffGains.push_back(a(i)); + } + + return LinearFilter(ffGains, {}); + } + /** * Creates a backward finite difference filter that computes the nth * derivative of the input given the specified number of samples. @@ -184,56 +252,14 @@ class LinearFilter { * @param period The period in seconds between samples taken by the user. */ template - static auto BackwardFiniteDifference(units::second_t period) { - // See - // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points - // - // For a given list of stencil points s of length n and the order of - // derivative d < n, the finite difference coefficients can be obtained by - // solving the following linear system for the vector a. - // - // @verbatim - // [s₁⁰ ⋯ sₙ⁰ ][a₁] [ δ₀,d ] - // [ ⋮ ⋱ ⋮ ][⋮ ] = d! [ ⋮ ] - // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ] [δₙ₋₁,d] - // @endverbatim - // - // where δᵢ,ⱼ are the Kronecker delta. For backward finite difference, the - // stencil points are the range [-n + 1, 0]. The FIR gains are the elements - // of the vector a in reverse order divided by hᵈ. - // - // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ). - - static_assert(Derivative >= 1, - "Order of derivative must be greater than or equal to one."); - static_assert(Samples > 0, "Number of samples must be greater than zero."); - static_assert(Derivative < Samples, - "Order of derivative must be less than number of samples."); - - Eigen::Matrix S; - for (int row = 0; row < Samples; ++row) { - for (int col = 0; col < Samples; ++col) { - double s = 1 - Samples + col; - S(row, col) = std::pow(s, row); - } - } - - // Fill in Kronecker deltas: https://en.wikipedia.org/wiki/Kronecker_delta - Eigen::Vector d; + static LinearFilter BackwardFiniteDifference(units::second_t period) { + // Generate stencil points from -(samples - 1) to 0 + wpi::array stencil{wpi::empty_array}; for (int i = 0; i < Samples; ++i) { - d(i) = (i == Derivative) ? Factorial(Derivative) : 0.0; + stencil[i] = -(Samples - 1) + i; } - Eigen::Vector a = - S.householderQr().solve(d) / std::pow(period.value(), Derivative); - - // Reverse gains list - std::vector gains; - for (int i = Samples - 1; i >= 0; --i) { - gains.push_back(a(i)); - } - - return LinearFilter(gains, {}); + return FiniteDifference(stencil, period); } /** diff --git a/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java index 805129f95b..f0b39c6b91 100644 --- a/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java +++ b/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java @@ -109,12 +109,84 @@ class LinearFilterTest { 0.0)); } + /** Test central finite difference. */ + @Test + void centralFiniteDifferenceTest() { + double h = 0.005; + + assertCentralResults( + 1, + 3, + // f(x) = x² + (double x) -> x * x, + // df/dx = 2x + (double x) -> 2.0 * x, + h, + -20.0, + 20.0); + + assertCentralResults( + 1, + 3, + // f(x) = sin(x) + (double x) -> Math.sin(x), + // df/dx = cos(x) + (double x) -> Math.cos(x), + h, + -20.0, + 20.0); + + assertCentralResults( + 1, + 3, + // f(x) = ln(x) + (double x) -> Math.log(x), + // df/dx = 1 / x + (double x) -> 1.0 / x, + h, + 1.0, + 20.0); + + assertCentralResults( + 2, + 5, + // f(x) = x² + (double x) -> x * x, + // d²f/dx² = 2 + (double x) -> 2.0, + h, + -20.0, + 20.0); + + assertCentralResults( + 2, + 5, + // f(x) = sin(x) + (double x) -> Math.sin(x), + // d²f/dx² = -sin(x) + (double x) -> -Math.sin(x), + h, + -20.0, + 20.0); + + assertCentralResults( + 2, + 5, + // f(x) = ln(x) + (double x) -> Math.log(x), + // d²f/dx² = -1 / x² + (double x) -> -1.0 / (x * x), + h, + 1.0, + 20.0); + } + /** Test backward finite difference. */ @Test void backwardFiniteDifferenceTest() { double h = 0.005; - assertResults( + assertBackwardResults( 1, 2, // f(x) = x² @@ -125,7 +197,7 @@ class LinearFilterTest { -20.0, 20.0); - assertResults( + assertBackwardResults( 1, 2, // f(x) = sin(x) @@ -136,7 +208,7 @@ class LinearFilterTest { -20.0, 20.0); - assertResults( + assertBackwardResults( 1, 2, // f(x) = ln(x) @@ -147,7 +219,7 @@ class LinearFilterTest { 1.0, 20.0); - assertResults( + assertBackwardResults( 2, 4, // f(x) = x² @@ -158,7 +230,7 @@ class LinearFilterTest { -20.0, 20.0); - assertResults( + assertBackwardResults( 2, 4, // f(x) = sin(x) @@ -169,7 +241,7 @@ class LinearFilterTest { -20.0, 20.0); - assertResults( + assertBackwardResults( 2, 4, // f(x) = ln(x) @@ -181,6 +253,53 @@ class LinearFilterTest { 20.0); } + /** + * Helper for checking results of central finite difference. + * + * @param derivative The order of the derivative. + * @param samples The number of sample points. + * @param f Function of which to take derivative. + * @param dfdx Derivative of f. + * @param h Sample period in seconds. + * @param min Minimum of f's domain to test. + * @param max Maximum of f's domain to test. + */ + void assertCentralResults( + int derivative, + int samples, + DoubleFunction f, + DoubleFunction dfdx, + double h, + double min, + double max) { + if (samples % 2 == 0) { + throw new IllegalArgumentException("Number of samples must be odd."); + } + + // Generate stencil points from -(samples - 1)/2 to (samples - 1)/2 + int[] stencil = new int[samples]; + for (int i = 0; i < samples; ++i) { + stencil[i] = -(samples - 1) / 2 + i; + } + + var filter = LinearFilter.finiteDifference(derivative, samples, stencil, h); + + for (int i = (int) (min / h); i < (int) (max / h); ++i) { + // Let filter initialize + if (i < (int) (min / h) + samples) { + filter.calculate(f.apply(i * h)); + continue; + } + + // The order of accuracy is O(h^(N - d)) where N is number of stencil + // points and d is order of derivative + assertEquals( + dfdx.apply((i - samples / 2) * h), + filter.calculate(f.apply(i * h)), + Math.pow(h, samples - derivative)); + } + } + /** * Helper for checking results of backward finite difference. * @@ -192,7 +311,7 @@ class LinearFilterTest { * @param min Minimum of f's domain to test. * @param max Maximum of f's domain to test. */ - void assertResults( + void assertBackwardResults( int derivative, int samples, DoubleFunction f, @@ -209,6 +328,8 @@ class LinearFilterTest { continue; } + // For central finite difference, the derivative computed at this point is + // half the window size in the past. // The order of accuracy is O(h^(N - d)) where N is number of stencil // points and d is order of derivative assertEquals( diff --git a/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp b/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp index bca3f9d9ee..e9f228b95b 100644 --- a/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp +++ b/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include "gtest/gtest.h" @@ -120,8 +121,40 @@ INSTANTIATE_TEST_SUITE_P(Tests, LinearFilterOutputTest, kTestMovAvg, kTestPulse)); template -void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min, - double max) { +void AssertCentralResults(F&& f, DfDx&& dfdx, units::second_t h, double min, + double max) { + static_assert(Samples % 2 != 0, "Number of samples must be odd."); + + // Generate stencil points from -(samples - 1)/2 to (samples - 1)/2 + wpi::array stencil{wpi::empty_array}; + for (int i = 0; i < Samples; ++i) { + stencil[i] = -(Samples - 1) / 2 + i; + } + + auto filter = + frc::LinearFilter::FiniteDifference(stencil, + h); + + for (int i = min / h.value(); i < max / h.value(); ++i) { + // Let filter initialize + if (i < static_cast(min / h.value()) + Samples) { + filter.Calculate(f(i * h.value())); + continue; + } + + // For central finite difference, the derivative computed at this point is + // half the window size in the past. + // The order of accuracy is O(h^(N - d)) where N is number of stencil + // points and d is order of derivative + EXPECT_NEAR(dfdx((i - (Samples - 1) / 2) * h.value()), + filter.Calculate(f(i * h.value())), + std::pow(h.value(), Samples - Derivative)); + } +} + +template +void AssertBackwardResults(F&& f, DfDx&& dfdx, units::second_t h, double min, + double max) { auto filter = frc::LinearFilter::BackwardFiniteDifference( h); @@ -141,12 +174,12 @@ void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min, } /** - * Test backward finite difference. + * Test central finite difference. */ -TEST(LinearFilterOutputTest, BackwardFiniteDifference) { +TEST(LinearFilterOutputTest, CentralFiniteDifference) { constexpr auto h = 5_ms; - AssertResults<1, 2>( + AssertCentralResults<1, 3>( [](double x) { // f(x) = x² return x * x; @@ -157,7 +190,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) { }, h, -20.0, 20.0); - AssertResults<1, 2>( + AssertCentralResults<1, 3>( [](double x) { // f(x) = std::sin(x) return std::sin(x); @@ -168,7 +201,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) { }, h, -20.0, 20.0); - AssertResults<1, 2>( + AssertCentralResults<1, 3>( [](double x) { // f(x) = ln(x) return std::log(x); @@ -179,7 +212,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) { }, h, 1.0, 20.0); - AssertResults<2, 4>( + AssertCentralResults<2, 5>( [](double x) { // f(x) = x^2 return x * x; @@ -190,7 +223,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) { }, h, -20.0, 20.0); - AssertResults<2, 4>( + AssertCentralResults<2, 5>( [](double x) { // f(x) = std::sin(x) return std::sin(x); @@ -201,7 +234,80 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) { }, h, -20.0, 20.0); - AssertResults<2, 4>( + AssertCentralResults<2, 5>( + [](double x) { + // f(x) = ln(x) + return std::log(x); + }, + [](double x) { + // d²f/dx² = -1 / x² + return -1.0 / (x * x); + }, + h, 1.0, 20.0); +} + +/** + * Test backward finite difference. + */ +TEST(LinearFilterOutputTest, BackwardFiniteDifference) { + constexpr auto h = 5_ms; + + AssertBackwardResults<1, 2>( + [](double x) { + // f(x) = x² + return x * x; + }, + [](double x) { + // df/dx = 2x + return 2.0 * x; + }, + h, -20.0, 20.0); + + AssertBackwardResults<1, 2>( + [](double x) { + // f(x) = std::sin(x) + return std::sin(x); + }, + [](double x) { + // df/dx = std::cos(x) + return std::cos(x); + }, + h, -20.0, 20.0); + + AssertBackwardResults<1, 2>( + [](double x) { + // f(x) = ln(x) + return std::log(x); + }, + [](double x) { + // df/dx = 1 / x + return 1.0 / x; + }, + h, 1.0, 20.0); + + AssertBackwardResults<2, 4>( + [](double x) { + // f(x) = x^2 + return x * x; + }, + [](double x) { + // d²f/dx² = 2 + return 2.0; + }, + h, -20.0, 20.0); + + AssertBackwardResults<2, 4>( + [](double x) { + // f(x) = std::sin(x) + return std::sin(x); + }, + [](double x) { + // d²f/dx² = -std::sin(x) + return -std::sin(x); + }, + h, -20.0, 20.0); + + AssertBackwardResults<2, 4>( [](double x) { // f(x) = ln(x) return std::log(x);