[wpimath] Add LinearFilter::FiniteDifference() (#3900)

This allows making more general finite difference filters, like central
finite difference. SysId uses this for acceleration filtering.
This commit is contained in:
Tyler Veness
2022-01-15 20:18:11 -08:00
committed by GitHub
parent 63d1fb3bed
commit 7f4265facc
4 changed files with 360 additions and 84 deletions

View File

@@ -109,12 +109,84 @@ class LinearFilterTest {
0.0));
}
/** Test central finite difference. */
@Test
void centralFiniteDifferenceTest() {
double h = 0.005;
assertCentralResults(
1,
3,
// f(x) = x²
(double x) -> x * x,
// df/dx = 2x
(double x) -> 2.0 * x,
h,
-20.0,
20.0);
assertCentralResults(
1,
3,
// f(x) = sin(x)
(double x) -> Math.sin(x),
// df/dx = cos(x)
(double x) -> Math.cos(x),
h,
-20.0,
20.0);
assertCentralResults(
1,
3,
// f(x) = ln(x)
(double x) -> Math.log(x),
// df/dx = 1 / x
(double x) -> 1.0 / x,
h,
1.0,
20.0);
assertCentralResults(
2,
5,
// f(x) = x²
(double x) -> x * x,
// d²f/dx² = 2
(double x) -> 2.0,
h,
-20.0,
20.0);
assertCentralResults(
2,
5,
// f(x) = sin(x)
(double x) -> Math.sin(x),
// d²f/dx² = -sin(x)
(double x) -> -Math.sin(x),
h,
-20.0,
20.0);
assertCentralResults(
2,
5,
// f(x) = ln(x)
(double x) -> Math.log(x),
// d²f/dx² = -1 / x²
(double x) -> -1.0 / (x * x),
h,
1.0,
20.0);
}
/** Test backward finite difference. */
@Test
void backwardFiniteDifferenceTest() {
double h = 0.005;
assertResults(
assertBackwardResults(
1,
2,
// f(x) = x²
@@ -125,7 +197,7 @@ class LinearFilterTest {
-20.0,
20.0);
assertResults(
assertBackwardResults(
1,
2,
// f(x) = sin(x)
@@ -136,7 +208,7 @@ class LinearFilterTest {
-20.0,
20.0);
assertResults(
assertBackwardResults(
1,
2,
// f(x) = ln(x)
@@ -147,7 +219,7 @@ class LinearFilterTest {
1.0,
20.0);
assertResults(
assertBackwardResults(
2,
4,
// f(x) = x²
@@ -158,7 +230,7 @@ class LinearFilterTest {
-20.0,
20.0);
assertResults(
assertBackwardResults(
2,
4,
// f(x) = sin(x)
@@ -169,7 +241,7 @@ class LinearFilterTest {
-20.0,
20.0);
assertResults(
assertBackwardResults(
2,
4,
// f(x) = ln(x)
@@ -181,6 +253,53 @@ class LinearFilterTest {
20.0);
}
/**
* Helper for checking results of central finite difference.
*
* @param derivative The order of the derivative.
* @param samples The number of sample points.
* @param f Function of which to take derivative.
* @param dfdx Derivative of f.
* @param h Sample period in seconds.
* @param min Minimum of f's domain to test.
* @param max Maximum of f's domain to test.
*/
void assertCentralResults(
int derivative,
int samples,
DoubleFunction<Double> f,
DoubleFunction<Double> dfdx,
double h,
double min,
double max) {
if (samples % 2 == 0) {
throw new IllegalArgumentException("Number of samples must be odd.");
}
// Generate stencil points from -(samples - 1)/2 to (samples - 1)/2
int[] stencil = new int[samples];
for (int i = 0; i < samples; ++i) {
stencil[i] = -(samples - 1) / 2 + i;
}
var filter = LinearFilter.finiteDifference(derivative, samples, stencil, h);
for (int i = (int) (min / h); i < (int) (max / h); ++i) {
// Let filter initialize
if (i < (int) (min / h) + samples) {
filter.calculate(f.apply(i * h));
continue;
}
// The order of accuracy is O(h^(N - d)) where N is number of stencil
// points and d is order of derivative
assertEquals(
dfdx.apply((i - samples / 2) * h),
filter.calculate(f.apply(i * h)),
Math.pow(h, samples - derivative));
}
}
/**
* Helper for checking results of backward finite difference.
*
@@ -192,7 +311,7 @@ class LinearFilterTest {
* @param min Minimum of f's domain to test.
* @param max Maximum of f's domain to test.
*/
void assertResults(
void assertBackwardResults(
int derivative,
int samples,
DoubleFunction<Double> f,
@@ -209,6 +328,8 @@ class LinearFilterTest {
continue;
}
// For central finite difference, the derivative computed at this point is
// half the window size in the past.
// The order of accuracy is O(h^(N - d)) where N is number of stencil
// points and d is order of derivative
assertEquals(

View File

@@ -9,6 +9,7 @@
#include <memory>
#include <random>
#include <wpi/array.h>
#include <wpi/numbers>
#include "gtest/gtest.h"
@@ -120,8 +121,40 @@ INSTANTIATE_TEST_SUITE_P(Tests, LinearFilterOutputTest,
kTestMovAvg, kTestPulse));
template <int Derivative, int Samples, typename F, typename DfDx>
void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
double max) {
void AssertCentralResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
double max) {
static_assert(Samples % 2 != 0, "Number of samples must be odd.");
// Generate stencil points from -(samples - 1)/2 to (samples - 1)/2
wpi::array<int, Samples> stencil{wpi::empty_array};
for (int i = 0; i < Samples; ++i) {
stencil[i] = -(Samples - 1) / 2 + i;
}
auto filter =
frc::LinearFilter<double>::FiniteDifference<Derivative, Samples>(stencil,
h);
for (int i = min / h.value(); i < max / h.value(); ++i) {
// Let filter initialize
if (i < static_cast<int>(min / h.value()) + Samples) {
filter.Calculate(f(i * h.value()));
continue;
}
// For central finite difference, the derivative computed at this point is
// half the window size in the past.
// The order of accuracy is O(h^(N - d)) where N is number of stencil
// points and d is order of derivative
EXPECT_NEAR(dfdx((i - (Samples - 1) / 2) * h.value()),
filter.Calculate(f(i * h.value())),
std::pow(h.value(), Samples - Derivative));
}
}
template <int Derivative, int Samples, typename F, typename DfDx>
void AssertBackwardResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
double max) {
auto filter =
frc::LinearFilter<double>::BackwardFiniteDifference<Derivative, Samples>(
h);
@@ -141,12 +174,12 @@ void AssertResults(F&& f, DfDx&& dfdx, units::second_t h, double min,
}
/**
* Test backward finite difference.
* Test central finite difference.
*/
TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
TEST(LinearFilterOutputTest, CentralFiniteDifference) {
constexpr auto h = 5_ms;
AssertResults<1, 2>(
AssertCentralResults<1, 3>(
[](double x) {
// f(x) = x²
return x * x;
@@ -157,7 +190,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
AssertResults<1, 2>(
AssertCentralResults<1, 3>(
[](double x) {
// f(x) = std::sin(x)
return std::sin(x);
@@ -168,7 +201,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
AssertResults<1, 2>(
AssertCentralResults<1, 3>(
[](double x) {
// f(x) = ln(x)
return std::log(x);
@@ -179,7 +212,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, 1.0, 20.0);
AssertResults<2, 4>(
AssertCentralResults<2, 5>(
[](double x) {
// f(x) = x^2
return x * x;
@@ -190,7 +223,7 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
AssertResults<2, 4>(
AssertCentralResults<2, 5>(
[](double x) {
// f(x) = std::sin(x)
return std::sin(x);
@@ -201,7 +234,80 @@ TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
},
h, -20.0, 20.0);
AssertResults<2, 4>(
AssertCentralResults<2, 5>(
[](double x) {
// f(x) = ln(x)
return std::log(x);
},
[](double x) {
// d²f/dx² = -1 / x²
return -1.0 / (x * x);
},
h, 1.0, 20.0);
}
/**
* Test backward finite difference.
*/
TEST(LinearFilterOutputTest, BackwardFiniteDifference) {
constexpr auto h = 5_ms;
AssertBackwardResults<1, 2>(
[](double x) {
// f(x) = x²
return x * x;
},
[](double x) {
// df/dx = 2x
return 2.0 * x;
},
h, -20.0, 20.0);
AssertBackwardResults<1, 2>(
[](double x) {
// f(x) = std::sin(x)
return std::sin(x);
},
[](double x) {
// df/dx = std::cos(x)
return std::cos(x);
},
h, -20.0, 20.0);
AssertBackwardResults<1, 2>(
[](double x) {
// f(x) = ln(x)
return std::log(x);
},
[](double x) {
// df/dx = 1 / x
return 1.0 / x;
},
h, 1.0, 20.0);
AssertBackwardResults<2, 4>(
[](double x) {
// f(x) = x^2
return x * x;
},
[](double x) {
// d²f/dx² = 2
return 2.0;
},
h, -20.0, 20.0);
AssertBackwardResults<2, 4>(
[](double x) {
// f(x) = std::sin(x)
return std::sin(x);
},
[](double x) {
// d²f/dx² = -std::sin(x)
return -std::sin(x);
},
h, -20.0, 20.0);
AssertBackwardResults<2, 4>(
[](double x) {
// f(x) = ln(x)
return std::log(x);