diff --git a/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java b/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java
index 93d6bea9e4..5eac0b3ba1 100644
--- a/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java
+++ b/wpimath/src/main/java/edu/wpi/first/math/filter/LinearFilter.java
@@ -138,42 +138,41 @@ public class LinearFilter {
}
/**
- * Creates a backward finite difference filter that computes the nth derivative of the input given
- * the specified number of samples.
+ * Creates a finite difference filter that computes the nth derivative of the input given the
+ * specified stencil points.
*
- *
For example, a first derivative filter that uses two samples and a sample period of 20 ms
- * would be
- *
- *
- * LinearFilter.backwardFiniteDifference(1, 2, 0.02);
- *
+ * Stencil points are the indices of the samples to use in the finite difference. 0 is the
+ * current sample, -1 is the previous sample, -2 is the sample before that, etc. Don't use
+ * positive stencil points (samples from the future) if the LinearFilter will be used for
+ * stream-based online filtering.
*
* @param derivative The order of the derivative to compute.
* @param samples The number of samples to use to compute the given derivative. This must be one
* more than the order of derivative or higher.
+ * @param stencil List of stencil points.
* @param period The period in seconds between samples taken by the user.
* @return Linear filter.
+ * @throws IllegalArgumentException if derivative < 1, samples <= 0, or derivative >=
+ * samples.
*/
@SuppressWarnings("LocalVariableName")
- public static LinearFilter backwardFiniteDifference(int derivative, int samples, double period) {
+ public static LinearFilter finiteDifference(
+ int derivative, int samples, int[] stencil, double period) {
// See
// https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points
//
- //
For a given list of stencil points s of length n and the order of
+ // For a given list of stencil points s of length n and the order of
// derivative d < n, the finite difference coefficients can be obtained by
// solving the following linear system for the vector a.
//
- //
// [s₁⁰ ⋯ sₙ⁰ ][a₁] [ δ₀,d ]
// [ ⋮ ⋱ ⋮ ][⋮ ] = d! [ ⋮ ]
// [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ] [δₙ₋₁,d]
- //
//
- // where δᵢ,ⱼ are the Kronecker delta. For backward finite difference,
- // the stencil points are the range [-n + 1, 0]. The FIR gains are the
- // elements of the vector a in reverse order divided by hᵈ.
+ // where δᵢ,ⱼ are the Kronecker delta. The FIR gains are the elements of the
+ // vector a in reverse order divided by hᵈ.
//
- //
The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ).
+ // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ).
if (derivative < 1) {
throw new IllegalArgumentException(
@@ -192,8 +191,7 @@ public class LinearFilter {
var S = new SimpleMatrix(samples, samples);
for (int row = 0; row < samples; ++row) {
for (int col = 0; col < samples; ++col) {
- double s = 1 - samples + col;
- S.set(row, col, Math.pow(s, row));
+ S.set(row, col, Math.pow(stencil[col], row));
}
}
@@ -211,9 +209,34 @@ public class LinearFilter {
ffGains[i] = a.get(samples - i - 1, 0);
}
- double[] fbGains = new double[0];
+ return new LinearFilter(ffGains, new double[0]);
+ }
- return new LinearFilter(ffGains, fbGains);
+ /**
+ * Creates a backward finite difference filter that computes the nth derivative of the input given
+ * the specified number of samples.
+ *
+ *
For example, a first derivative filter that uses two samples and a sample period of 20 ms
+ * would be
+ *
+ *
+ * LinearFilter.backwardFiniteDifference(1, 2, 0.02);
+ *
+ *
+ * @param derivative The order of the derivative to compute.
+ * @param samples The number of samples to use to compute the given derivative. This must be one
+ * more than the order of derivative or higher.
+ * @param period The period in seconds between samples taken by the user.
+ * @return Linear filter.
+ */
+ public static LinearFilter backwardFiniteDifference(int derivative, int samples, double period) {
+ // Generate stencil points from -(samples - 1) to 0
+ int[] stencil = new int[samples];
+ for (int i = 0; i < samples; ++i) {
+ stencil[i] = -(samples - 1) + i;
+ }
+
+ return finiteDifference(derivative, samples, stencil, period);
}
/** Reset the filter state. */
diff --git a/wpimath/src/main/native/include/frc/filter/LinearFilter.h b/wpimath/src/main/native/include/frc/filter/LinearFilter.h
index 92d8bdcf53..62953271dd 100644
--- a/wpimath/src/main/native/include/frc/filter/LinearFilter.h
+++ b/wpimath/src/main/native/include/frc/filter/LinearFilter.h
@@ -10,6 +10,7 @@
#include
#include
+#include
#include
#include
@@ -166,6 +167,73 @@ class LinearFilter {
return LinearFilter(gains, {});
}
+ /**
+ * Creates a finite difference filter that computes the nth derivative of the
+ * input given the specified stencil points.
+ *
+ * Stencil points are the indices of the samples to use in the finite
+ * difference. 0 is the current sample, -1 is the previous sample, -2 is the
+ * sample before that, etc. Don't use positive stencil points (samples from
+ * the future) if the LinearFilter will be used for stream-based online
+ * filtering.
+ *
+ * @tparam Derivative The order of the derivative to compute.
+ * @tparam Samples The number of samples to use to compute the given
+ * derivative. This must be one more than the order of
+ * derivative or higher.
+ * @param stencil List of stencil points.
+ * @param period The period in seconds between samples taken by the user.
+ */
+ template
+ static LinearFilter FiniteDifference(
+ const wpi::array stencil, units::second_t period) {
+ // See
+ // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points
+ //
+ // For a given list of stencil points s of length n and the order of
+ // derivative d < n, the finite difference coefficients can be obtained by
+ // solving the following linear system for the vector a.
+ //
+ // [s₁⁰ ⋯ sₙ⁰ ][a₁] [ δ₀,d ]
+ // [ ⋮ ⋱ ⋮ ][⋮ ] = d! [ ⋮ ]
+ // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ] [δₙ₋₁,d]
+ //
+ // where δᵢ,ⱼ are the Kronecker delta. The FIR gains are the elements of the
+ // vector a in reverse order divided by hᵈ.
+ //
+ // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ).
+
+ static_assert(Derivative >= 1,
+ "Order of derivative must be greater than or equal to one.");
+ static_assert(Samples > 0, "Number of samples must be greater than zero.");
+ static_assert(Derivative < Samples,
+ "Order of derivative must be less than number of samples.");
+
+ Eigen::Matrix S;
+ for (int row = 0; row < Samples; ++row) {
+ for (int col = 0; col < Samples; ++col) {
+ S(row, col) = std::pow(stencil[col], row);
+ }
+ }
+
+ // Fill in Kronecker deltas: https://en.wikipedia.org/wiki/Kronecker_delta
+ Eigen::Vector d;
+ for (int i = 0; i < Samples; ++i) {
+ d(i) = (i == Derivative) ? Factorial(Derivative) : 0.0;
+ }
+
+ Eigen::Vector a =
+ S.householderQr().solve(d) / std::pow(period.value(), Derivative);
+
+ // Reverse gains list
+ std::vector ffGains;
+ for (int i = Samples - 1; i >= 0; --i) {
+ ffGains.push_back(a(i));
+ }
+
+ return LinearFilter(ffGains, {});
+ }
+
/**
* Creates a backward finite difference filter that computes the nth
* derivative of the input given the specified number of samples.
@@ -184,56 +252,14 @@ class LinearFilter {
* @param period The period in seconds between samples taken by the user.
*/
template
- static auto BackwardFiniteDifference(units::second_t period) {
- // See
- // https://en.wikipedia.org/wiki/Finite_difference_coefficient#Arbitrary_stencil_points
- //
- // For a given list of stencil points s of length n and the order of
- // derivative d < n, the finite difference coefficients can be obtained by
- // solving the following linear system for the vector a.
- //
- // @verbatim
- // [s₁⁰ ⋯ sₙ⁰ ][a₁] [ δ₀,d ]
- // [ ⋮ ⋱ ⋮ ][⋮ ] = d! [ ⋮ ]
- // [s₁ⁿ⁻¹ ⋯ sₙⁿ⁻¹][aₙ] [δₙ₋₁,d]
- // @endverbatim
- //
- // where δᵢ,ⱼ are the Kronecker delta. For backward finite difference, the
- // stencil points are the range [-n + 1, 0]. The FIR gains are the elements
- // of the vector a in reverse order divided by hᵈ.
- //
- // The order of accuracy of the approximation is of the form O(hⁿ⁻ᵈ).
-
- static_assert(Derivative >= 1,
- "Order of derivative must be greater than or equal to one.");
- static_assert(Samples > 0, "Number of samples must be greater than zero.");
- static_assert(Derivative < Samples,
- "Order of derivative must be less than number of samples.");
-
- Eigen::Matrix S;
- for (int row = 0; row < Samples; ++row) {
- for (int col = 0; col < Samples; ++col) {
- double s = 1 - Samples + col;
- S(row, col) = std::pow(s, row);
- }
- }
-
- // Fill in Kronecker deltas: https://en.wikipedia.org/wiki/Kronecker_delta
- Eigen::Vector d;
+ static LinearFilter BackwardFiniteDifference(units::second_t period) {
+ // Generate stencil points from -(samples - 1) to 0
+ wpi::array stencil{wpi::empty_array};
for (int i = 0; i < Samples; ++i) {
- d(i) = (i == Derivative) ? Factorial(Derivative) : 0.0;
+ stencil[i] = -(Samples - 1) + i;
}
- Eigen::Vector a =
- S.householderQr().solve(d) / std::pow(period.value(), Derivative);
-
- // Reverse gains list
- std::vector gains;
- for (int i = Samples - 1; i >= 0; --i) {
- gains.push_back(a(i));
- }
-
- return LinearFilter(gains, {});
+ return FiniteDifference(stencil, period);
}
/**
diff --git a/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java b/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java
index 805129f95b..f0b39c6b91 100644
--- a/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java
+++ b/wpimath/src/test/java/edu/wpi/first/math/filter/LinearFilterTest.java
@@ -109,12 +109,84 @@ class LinearFilterTest {
0.0));
}
+ /** Test central finite difference. */
+ @Test
+ void centralFiniteDifferenceTest() {
+ double h = 0.005;
+
+ assertCentralResults(
+ 1,
+ 3,
+ // f(x) = x²
+ (double x) -> x * x,
+ // df/dx = 2x
+ (double x) -> 2.0 * x,
+ h,
+ -20.0,
+ 20.0);
+
+ assertCentralResults(
+ 1,
+ 3,
+ // f(x) = sin(x)
+ (double x) -> Math.sin(x),
+ // df/dx = cos(x)
+ (double x) -> Math.cos(x),
+ h,
+ -20.0,
+ 20.0);
+
+ assertCentralResults(
+ 1,
+ 3,
+ // f(x) = ln(x)
+ (double x) -> Math.log(x),
+ // df/dx = 1 / x
+ (double x) -> 1.0 / x,
+ h,
+ 1.0,
+ 20.0);
+
+ assertCentralResults(
+ 2,
+ 5,
+ // f(x) = x²
+ (double x) -> x * x,
+ // d²f/dx² = 2
+ (double x) -> 2.0,
+ h,
+ -20.0,
+ 20.0);
+
+ assertCentralResults(
+ 2,
+ 5,
+ // f(x) = sin(x)
+ (double x) -> Math.sin(x),
+ // d²f/dx² = -sin(x)
+ (double x) -> -Math.sin(x),
+ h,
+ -20.0,
+ 20.0);
+
+ assertCentralResults(
+ 2,
+ 5,
+ // f(x) = ln(x)
+ (double x) -> Math.log(x),
+ // d²f/dx² = -1 / x²
+ (double x) -> -1.0 / (x * x),
+ h,
+ 1.0,
+ 20.0);
+ }
+
/** Test backward finite difference. */
@Test
void backwardFiniteDifferenceTest() {
double h = 0.005;
- assertResults(
+ assertBackwardResults(
1,
2,
// f(x) = x²
@@ -125,7 +197,7 @@ class LinearFilterTest {
-20.0,
20.0);
- assertResults(
+ assertBackwardResults(
1,
2,
// f(x) = sin(x)
@@ -136,7 +208,7 @@ class LinearFilterTest {
-20.0,
20.0);
- assertResults(
+ assertBackwardResults(
1,
2,
// f(x) = ln(x)
@@ -147,7 +219,7 @@ class LinearFilterTest {
1.0,
20.0);
- assertResults(
+ assertBackwardResults(
2,
4,
// f(x) = x²
@@ -158,7 +230,7 @@ class LinearFilterTest {
-20.0,
20.0);
- assertResults(
+ assertBackwardResults(
2,
4,
// f(x) = sin(x)
@@ -169,7 +241,7 @@ class LinearFilterTest {
-20.0,
20.0);
- assertResults(
+ assertBackwardResults(
2,
4,
// f(x) = ln(x)
@@ -181,6 +253,53 @@ class LinearFilterTest {
20.0);
}
+ /**
+ * Helper for checking results of central finite difference.
+ *
+ * @param derivative The order of the derivative.
+ * @param samples The number of sample points.
+ * @param f Function of which to take derivative.
+ * @param dfdx Derivative of f.
+ * @param h Sample period in seconds.
+ * @param min Minimum of f's domain to test.
+ * @param max Maximum of f's domain to test.
+ */
+ void assertCentralResults(
+ int derivative,
+ int samples,
+ DoubleFunction f,
+ DoubleFunction dfdx,
+ double h,
+ double min,
+ double max) {
+ if (samples % 2 == 0) {
+ throw new IllegalArgumentException("Number of samples must be odd.");
+ }
+
+ // Generate stencil points from -(samples - 1)/2 to (samples - 1)/2
+ int[] stencil = new int[samples];
+ for (int i = 0; i < samples; ++i) {
+ stencil[i] = -(samples - 1) / 2 + i;
+ }
+
+ var filter = LinearFilter.finiteDifference(derivative, samples, stencil, h);
+
+ for (int i = (int) (min / h); i < (int) (max / h); ++i) {
+ // Let filter initialize
+ if (i < (int) (min / h) + samples) {
+ filter.calculate(f.apply(i * h));
+ continue;
+ }
+
+ // The order of accuracy is O(h^(N - d)) where N is number of stencil
+ // points and d is order of derivative
+ assertEquals(
+ dfdx.apply((i - samples / 2) * h),
+ filter.calculate(f.apply(i * h)),
+ Math.pow(h, samples - derivative));
+ }
+ }
+
/**
* Helper for checking results of backward finite difference.
*
@@ -192,7 +311,7 @@ class LinearFilterTest {
* @param min Minimum of f's domain to test.
* @param max Maximum of f's domain to test.
*/
- void assertResults(
+ void assertBackwardResults(
int derivative,
int samples,
DoubleFunction f,
@@ -209,6 +328,8 @@ class LinearFilterTest {
continue;
}
+ // For central finite difference, the derivative computed at this point is
+ // half the window size in the past.
// The order of accuracy is O(h^(N - d)) where N is number of stencil
// points and d is order of derivative
assertEquals(
diff --git a/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp b/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp
index bca3f9d9ee..e9f228b95b 100644
--- a/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp
+++ b/wpimath/src/test/native/cpp/filter/LinearFilterOutputTest.cpp
@@ -9,6 +9,7 @@
#include
#include
+#include
#include
#include "gtest/gtest.h"
@@ -120,8 +121,40 @@ INSTANTIATE_TEST_SUITE_P(Tests, LinearFilterOutputTest,
kTestMovAvg, kTestPulse));
template
-void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
- double max) {
+void AssertCentralResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
+ double max) {
+ static_assert(Samples % 2 != 0, "Number of samples must be odd.");
+
+ // Generate stencil points from -(samples - 1)/2 to (samples - 1)/2
+ wpi::array stencil{wpi::empty_array};
+ for (int i = 0; i < Samples; ++i) {
+ stencil[i] = -(Samples - 1) / 2 + i;
+ }
+
+ auto filter =
+ frc::LinearFilter::FiniteDifference(stencil,
+ h);
+
+ for (int i = min / h.value(); i < max / h.value(); ++i) {
+ // Let filter initialize
+ if (i < static_cast(min / h.value()) + Samples) {
+ filter.Calculate(f(i * h.value()));
+ continue;
+ }
+
+ // For central finite difference, the derivative computed at this point is
+ // half the window size in the past.
+ // The order of accuracy is O(h^(N - d)) where N is number of stencil
+ // points and d is order of derivative
+ EXPECT_NEAR(dfdx((i - (Samples - 1) / 2) * h.value()),
+ filter.Calculate(f(i * h.value())),
+ std::pow(h.value(), Samples - Derivative));
+ }
+}
+
+template
+void AssertBackwardResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
+ double max) {
auto filter =
frc::LinearFilter::BackwardFiniteDifference(
h);
@@ -141,12 +174,12 @@ void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
}
/**
- * Test backward finite difference.
+ * Test central finite difference.
*/
-TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
+TEST(LinearFilterOutputTest, CentralFiniteDifference) {
constexpr auto h = 5_ms;
- AssertResults<1, 2>(
+ AssertCentralResults<1, 3>(
[](double x) {
// f(x) = x²
return x * x;
@@ -157,7 +190,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
- AssertResults<1, 2>(
+ AssertCentralResults<1, 3>(
[](double x) {
// f(x) = std::sin(x)
return std::sin(x);
@@ -168,7 +201,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
- AssertResults<1, 2>(
+ AssertCentralResults<1, 3>(
[](double x) {
// f(x) = ln(x)
return std::log(x);
@@ -179,7 +212,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, 1.0, 20.0);
- AssertResults<2, 4>(
+ AssertCentralResults<2, 5>(
[](double x) {
// f(x) = x^2
return x * x;
@@ -190,7 +223,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
- AssertResults<2, 4>(
+ AssertCentralResults<2, 5>(
[](double x) {
// f(x) = std::sin(x)
return std::sin(x);
@@ -201,7 +234,80 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
- AssertResults<2, 4>(
+ AssertCentralResults<2, 5>(
+ [](double x) {
+ // f(x) = ln(x)
+ return std::log(x);
+ },
+ [](double x) {
+ // d²f/dx² = -1 / x²
+ return -1.0 / (x * x);
+ },
+ h, 1.0, 20.0);
+}
+
+/**
+ * Test backward finite difference.
+ */
+TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
+ constexpr auto h = 5_ms;
+
+ AssertBackwardResults<1, 2>(
+ [](double x) {
+ // f(x) = x²
+ return x * x;
+ },
+ [](double x) {
+ // df/dx = 2x
+ return 2.0 * x;
+ },
+ h, -20.0, 20.0);
+
+ AssertBackwardResults<1, 2>(
+ [](double x) {
+ // f(x) = std::sin(x)
+ return std::sin(x);
+ },
+ [](double x) {
+ // df/dx = std::cos(x)
+ return std::cos(x);
+ },
+ h, -20.0, 20.0);
+
+ AssertBackwardResults<1, 2>(
+ [](double x) {
+ // f(x) = ln(x)
+ return std::log(x);
+ },
+ [](double x) {
+ // df/dx = 1 / x
+ return 1.0 / x;
+ },
+ h, 1.0, 20.0);
+
+ AssertBackwardResults<2, 4>(
+ [](double x) {
+ // f(x) = x^2
+ return x * x;
+ },
+ [](double x) {
+ // d²f/dx² = 2
+ return 2.0;
+ },
+ h, -20.0, 20.0);
+
+ AssertBackwardResults<2, 4>(
+ [](double x) {
+ // f(x) = std::sin(x)
+ return std::sin(x);
+ },
+ [](double x) {
+ // d²f/dx² = -std::sin(x)
+ return -std::sin(x);
+ },
+ h, -20.0, 20.0);
+
+ AssertBackwardResults<2, 4>(
[](double x) {
// f(x) = ln(x)
return std::log(x);