[wpimath] Add Sleipnir Java bindings (#8236)

The wrapper includes reverse mode autodiff, the Problem DSL, and the
optimal control problem API. I wrote it by directly translating the
upstream
[API](https://github.com/SleipnirGroup/Sleipnir/tree/main/include/sleipnir)
and [tests](https://github.com/SleipnirGroup/Sleipnir/tree/main/test) to
Java (i.e., copy-paste-modify).

I replaced the ArmFeedforward and Ellipse2d JNIs with implementations
using the Sleipnir Java bindings. Switching dev binary JNIs to release
by default sped up wpimath test runs from several minutes to 7 seconds.
This commit is contained in:
Tyler Veness
2026-03-29 22:34:21 -07:00
committed by GitHub
parent 3e821b9448
commit d248c040bf
84 changed files with 13405 additions and 170 deletions

View File

@@ -0,0 +1,24 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math;
import java.util.ArrayList;
public final class DoubleRange {
private DoubleRange() {
// Utility class.
}
public static ArrayList<Double> range(double start, double end, double step) {
var ret = new ArrayList<Double>();
int steps = (int) ((end - start) / step);
for (int i = 0; i < steps; ++i) {
ret.add(start + i * step);
}
return ret;
}
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.ejml.dense.row.MatrixFeatures_DDRM;
import org.ejml.simple.SimpleMatrix;
public final class MatrixAssertions {
private MatrixAssertions() {
// Utility class.
}
/**
* Asserts that two SimpleMatrices are equal.
*
* @param expected Expected value.
* @param actual The value to check against expected.
*/
public static void assertEquals(SimpleMatrix expected, SimpleMatrix actual) {
assertFalse(MatrixFeatures_DDRM.hasUncountable(expected.getDDRM()));
assertTrue(MatrixFeatures_DDRM.isEquals(expected.getDDRM(), actual.getDDRM()));
}
/**
* Asserts that two SimpleMatrices are equal to within a positive delta.
*
* @param expected Expected value.
* @param actual The value to check against expected.
* @param delta The maximum delta between expected and actual for which both values are still
* considered equal.
*/
public static void assertEquals(SimpleMatrix expected, SimpleMatrix actual, double delta) {
assertFalse(MatrixFeatures_DDRM.hasUncountable(expected.getDDRM()));
assertTrue(MatrixFeatures_DDRM.isEquals(expected.getDDRM(), actual.getDDRM(), delta));
}
}

View File

@@ -2,15 +2,15 @@
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.jni;
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test;
public class ArmFeedforwardJNITest {
public class GradientJNITest {
@Test
public void testLink() {
assertDoesNotThrow(ArmFeedforwardJNI::forceLoad);
assertDoesNotThrow(GradientJNI::forceLoad);
}
}

View File

@@ -0,0 +1,964 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.MatrixAssertions.assertEquals;
import static org.wpilib.math.autodiff.Variable.abs;
import static org.wpilib.math.autodiff.Variable.acos;
import static org.wpilib.math.autodiff.Variable.asin;
import static org.wpilib.math.autodiff.Variable.atan;
import static org.wpilib.math.autodiff.Variable.atan2;
import static org.wpilib.math.autodiff.Variable.cbrt;
import static org.wpilib.math.autodiff.Variable.cos;
import static org.wpilib.math.autodiff.Variable.cosh;
import static org.wpilib.math.autodiff.Variable.exp;
import static org.wpilib.math.autodiff.Variable.hypot;
import static org.wpilib.math.autodiff.Variable.log;
import static org.wpilib.math.autodiff.Variable.log10;
import static org.wpilib.math.autodiff.Variable.max;
import static org.wpilib.math.autodiff.Variable.min;
import static org.wpilib.math.autodiff.Variable.pow;
import static org.wpilib.math.autodiff.Variable.signum;
import static org.wpilib.math.autodiff.Variable.sin;
import static org.wpilib.math.autodiff.Variable.sinh;
import static org.wpilib.math.autodiff.Variable.sqrt;
import static org.wpilib.math.autodiff.Variable.tan;
import static org.wpilib.math.autodiff.Variable.tanh;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
class GradientTest {
@Test
void testTrivialCase() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var a = new Variable();
a.setValue(10);
var b = new Variable();
b.setValue(20);
var c = a;
try (var g_a_a = new Gradient(a, a)) {
assertEquals(1.0, g_a_a.value().get(0, 0));
}
try (var g_a_b = new Gradient(a, b)) {
assertEquals(0.0, g_a_b.value().get(0, 0));
}
try (var g_c_a = new Gradient(c, a)) {
assertEquals(1.0, g_c_a.value().get(0, 0));
}
try (var g_c_b = new Gradient(c, b)) {
assertEquals(0.0, g_c_b.value().get(0, 0));
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testUnaryPlus() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var a = new Variable();
a.setValue(10);
var c = a.unaryPlus();
assertEquals(a.value(), c.value());
try (var g_c_a = new Gradient(c, a)) {
assertEquals(1.0, g_c_a.value().get(0, 0));
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testUnaryMinus() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var a = new Variable();
a.setValue(10);
var c = a.unaryMinus();
assertEquals(a.unaryMinus().value(), c.value());
try (var g_c_a = new Gradient(c, a)) {
assertEquals(-1.0, g_c_a.value().get(0, 0));
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testIdenticalVariables() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var a = new Variable();
a.setValue(10);
var x = a;
var c = a.times(a).plus(x);
assertEquals(a.value() * a.value() + x.value(), c.value());
try (var g_x_a = new Gradient(x, a);
var g_c_a = new Gradient(c, a)) {
assertEquals(2 * a.value() + g_x_a.value().get(0, 0), g_c_a.value().get(0, 0));
}
try (var g_a_x = new Gradient(a, x);
var g_c_x = new Gradient(c, x)) {
assertEquals(2 * a.value() * g_a_x.value().get(0, 0) + 1, g_c_x.value().get(0, 0));
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testElementary() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var a = new Variable();
a.setValue(1.0);
var b = new Variable();
b.setValue(2.0);
var c = a.times(-2);
try (var g_c_a = new Gradient(c, a)) {
assertEquals(-2.0, g_c_a.value().get(0, 0));
}
c = a.div(3.0);
try (var g_c_a = new Gradient(c, a)) {
assertEquals(1.0 / 3.0, g_c_a.value().get(0, 0));
}
a.setValue(100.0);
b.setValue(200.0);
c = a.plus(b);
try (var g_c_a = new Gradient(c, a)) {
assertEquals(1.0, g_c_a.value().get(0, 0));
}
try (var g_c_b = new Gradient(c, b)) {
assertEquals(1.0, g_c_b.value().get(0, 0));
}
c = a.minus(b);
try (var g_c_a = new Gradient(c, a)) {
assertEquals(1.0, g_c_a.value().get(0, 0));
}
try (var g_c_b = new Gradient(c, b)) {
assertEquals(-1.0, g_c_b.value().get(0, 0));
}
c = a.unaryMinus().plus(b);
try (var g_c_a = new Gradient(c, a)) {
assertEquals(-1.0, g_c_a.value().get(0, 0));
}
try (var g_c_b = new Gradient(c, b)) {
assertEquals(1.0, g_c_b.value().get(0, 0));
}
c = a.plus(1);
try (var g_c_a = new Gradient(c, a)) {
assertEquals(1.0, g_c_a.value().get(0, 0));
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testTrigonometry() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(0.5);
// Math.sin(x)
assertEquals(Math.sin(x.value()), sin(x).value());
var g = new Gradient(sin(x), x);
assertEquals(Math.cos(x.value()), g.get().value().get(0, 0));
assertEquals(Math.cos(x.value()), g.value().get(0, 0));
// Math.cos(x)
assertEquals(Math.cos(x.value()), cos(x).value());
g.close();
g = new Gradient(cos(x), x);
assertEquals(-Math.sin(x.value()), g.get().value().get(0, 0));
assertEquals(-Math.sin(x.value()), g.value().get(0, 0));
// Math.tan(x)
assertEquals(Math.tan(x.value()), tan(x).value());
g.close();
g = new Gradient(tan(x), x);
assertEquals(1.0 / (Math.cos(x.value()) * Math.cos(x.value())), g.get().value().get(0, 0));
assertEquals(1.0 / (Math.cos(x.value()) * Math.cos(x.value())), g.value().get(0, 0));
// Math.asin(x)
assertEquals(Math.asin(x.value()), asin(x).value(), 1e-15);
g.close();
g = new Gradient(asin(x), x);
assertEquals(1.0 / Math.sqrt(1 - x.value() * x.value()), g.get().value().get(0, 0));
assertEquals(1.0 / Math.sqrt(1 - x.value() * x.value()), g.value().get(0, 0));
// Math.acos(x)
assertEquals(Math.acos(x.value()), acos(x).value(), 1e-15);
g.close();
g = new Gradient(acos(x), x);
assertEquals(-1.0 / Math.sqrt(1 - x.value() * x.value()), g.get().value().get(0, 0));
assertEquals(-1.0 / Math.sqrt(1 - x.value() * x.value()), g.value().get(0, 0));
// Math.atan(x)
assertEquals(Math.atan(x.value()), atan(x).value(), 1e-15);
g.close();
g = new Gradient(atan(x), x);
assertEquals(1.0 / (1 + x.value() * x.value()), g.get().value().get(0, 0));
assertEquals(1.0 / (1 + x.value() * x.value()), g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testHyperbolic() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
// sinh(x)
assertEquals(Math.sinh(x.value()), sinh(x).value());
var g = new Gradient(sinh(x), x);
assertEquals(Math.cosh(x.value()), g.get().value().get(0, 0), 1e-15);
assertEquals(Math.cosh(x.value()), g.value().get(0, 0), 1e-15);
// Math.cosh(x)
assertEquals(Math.cosh(x.value()), cosh(x).value(), 1e-15);
g.close();
g = new Gradient(cosh(x), x);
assertEquals(Math.sinh(x.value()), g.get().value().get(0, 0));
assertEquals(Math.sinh(x.value()), g.value().get(0, 0));
// tanh(x)
assertEquals(Math.tanh(x.value()), tanh(x).value());
g.close();
g = new Gradient(tanh(x), x);
assertEquals(
1.0 / (Math.cosh(x.value()) * Math.cosh(x.value())), g.get().value().get(0, 0), 1e-15);
assertEquals(1.0 / (Math.cosh(x.value()) * Math.cosh(x.value())), g.value().get(0, 0), 1e-15);
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testExponential() {
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
// Math.log(x)
assertEquals(Math.log(x.value()), log(x).value());
var g = new Gradient(log(x), x);
assertEquals(1.0 / x.value(), g.get().value().get(0, 0));
assertEquals(1.0 / x.value(), g.value().get(0, 0));
// Math.log10(x)
assertEquals(Math.log10(x.value()), log10(x).value());
g.close();
g = new Gradient(log10(x), x);
assertEquals(1.0 / (Math.log(10.0) * x.value()), g.get().value().get(0, 0));
assertEquals(1.0 / (Math.log(10.0) * x.value()), g.value().get(0, 0));
// Math.exp(x)
assertEquals(Math.exp(x.value()), exp(x).value(), 1e-15);
g.close();
g = new Gradient(exp(x), x);
assertEquals(Math.exp(x.value()), g.get().value().get(0, 0), 1e-15);
assertEquals(Math.exp(x.value()), g.value().get(0, 0), 1e-15);
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testPower() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
// Math.sqrt(x)
assertEquals(Math.sqrt(x.value()), sqrt(x).value());
var g = new Gradient(sqrt(x), x);
assertEquals(0.5 / Math.sqrt(x.value()), g.get().value().get(0, 0));
assertEquals(0.5 / Math.sqrt(x.value()), g.value().get(0, 0));
// Math.sqrt(a)
assertEquals(Math.sqrt(a.value()), sqrt(a).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
var g = new Gradient(sqrt(a), a);
assertEquals(0.5 / Math.sqrt(a.value()), g.get().value().get(0, 0));
assertEquals(0.5 / Math.sqrt(a.value()), g.value().get(0, 0));
// Math.cbrt(x)
assertEquals(Math.cbrt(x.value()), cbrt(x).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
var g = new Gradient(cbrt(x), x);
assertEquals(
1.0 / (3.0 * Math.cbrt(x.value()) * Math.cbrt(x.value())), g.get().value().get(0, 0));
assertEquals(1.0 / (3.0 * Math.cbrt(x.value()) * Math.cbrt(x.value())), g.value().get(0, 0));
// Math.cbrt(a)
assertEquals(Math.cbrt(a.value()), cbrt(a).value(), 1e-15);
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
var g = new Gradient(cbrt(a), a);
assertEquals(
1.0 / (3.0 * Math.cbrt(a.value()) * Math.cbrt(a.value())),
g.get().value().get(0, 0),
1e-15);
assertEquals(
1.0 / (3.0 * Math.cbrt(a.value()) * Math.cbrt(a.value())), g.value().get(0, 0), 1e-15);
// x²
assertEquals(Math.pow(x.value(), 2.0), pow(x, 2.0).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var g = new Gradient(pow(x, 2.0), x);
assertEquals(2.0 * x.value(), g.get().value().get(0, 0));
assertEquals(2.0 * x.value(), g.value().get(0, 0));
// 2ˣ
assertEquals(Math.pow(2.0, x.value()), pow(2.0, x).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var g = new Gradient(pow(2.0, x), x);
assertEquals(Math.log(2.0) * Math.pow(2.0, x.value()), g.get().value().get(0, 0));
assertEquals(Math.log(2.0) * Math.pow(2.0, x.value()), g.value().get(0, 0));
// xˣ
assertEquals(Math.pow(x.value(), x.value()), pow(x, x).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
var g = new Gradient(pow(x, x), x);
assertEquals(
(Math.log(x.value()) + 1) * Math.pow(x.value(), x.value()), g.get().value().get(0, 0));
assertEquals((Math.log(x.value()) + 1) * Math.pow(x.value(), x.value()), g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
// y(a)
var y = a.times(2);
assertEquals(2 * a.value(), y.value());
var g = new Gradient(y, a);
assertEquals(2.0, g.get().value().get(0, 0));
assertEquals(2.0, g.value().get(0, 0));
// xʸ(x)
assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
// y(a)
var y = a.times(2);
assertEquals(2 * a.value(), y.value());
var g = new Gradient(pow(x, y), x);
assertEquals(
y.value() / x.value() * Math.pow(x.value(), y.value()), g.get().value().get(0, 0));
assertEquals(y.value() / x.value() * Math.pow(x.value(), y.value()), g.value().get(0, 0));
// xʸ(a)
assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value());
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
// y(a)
var y = a.times(2);
assertEquals(2 * a.value(), y.value());
try (var g = new Gradient(pow(x, y), a);
var g_x_a = new Gradient(x, a);
var g_y_a = new Gradient(y, a)) {
assertEquals(
Math.pow(x.value(), y.value())
* (y.value() / x.value() * g_x_a.value().get(0, 0)
+ Math.log(x.value()) * g_y_a.value().get(0, 0)),
g.get().value().get(0, 0));
assertEquals(
Math.pow(x.value(), y.value())
* (y.value() / x.value() * g_x_a.value().get(0, 0)
+ Math.log(x.value()) * g_y_a.value().get(0, 0)),
g.value().get(0, 0));
}
// xʸ(y)
assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(1.0);
var a = new Variable();
a.setValue(2.0);
// y(a)
var y = a.times(2);
assertEquals(2 * a.value(), y.value());
var g = new Gradient(pow(x, y), y);
assertEquals(Math.log(x.value()) * Math.pow(x.value(), y.value()), g.get().value().get(0, 0));
assertEquals(Math.log(x.value()) * Math.pow(x.value(), y.value()), g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testAbs() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
var g = new Gradient(abs(x), x);
x.setValue(1.0);
assertEquals(Math.abs(x.value()), abs(x).value());
assertEquals(1.0, g.get().value().get(0, 0));
assertEquals(1.0, g.value().get(0, 0));
x.setValue(-1.0);
assertEquals(Math.abs(x.value()), abs(x).value());
assertEquals(-1.0, g.get().value().get(0, 0));
assertEquals(-1.0, g.value().get(0, 0));
x.setValue(0.0);
assertEquals(Math.abs(x.value()), abs(x).value());
assertEquals(0.0, g.get().value().get(0, 0));
assertEquals(0.0, g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testAtan2() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
var y = new Variable();
// Testing atan2 function on (double, var)
x.setValue(1.0);
y.setValue(0.9);
assertEquals(Math.atan2(2.0, x.value()), atan2(2.0, x).value());
var g = new Gradient(atan2(2.0, x), x);
assertEquals(-2.0 / (2 * 2 + x.value() * x.value()), g.get().value().get(0, 0), 1e-15);
assertEquals(-2.0 / (2 * 2 + x.value() * x.value()), g.value().get(0, 0), 1e-15);
// Testing atan2 function on (var, double)
x.setValue(1.0);
y.setValue(0.9);
assertEquals(Math.atan2(x.value(), 2.0), atan2(x, 2.0).value());
g.close();
g = new Gradient(atan2(x, 2.0), x);
assertEquals(2.0 / (2 * 2 + x.value() * x.value()), g.get().value().get(0, 0), 1e-15);
assertEquals(2.0 / (2 * 2 + x.value() * x.value()), g.value().get(0, 0), 1e-15);
// Testing atan2 function on (var, var)
x.setValue(1.1);
y.setValue(0.9);
assertEquals(Math.atan2(y.value(), x.value()), atan2(y, x).value(), 1e-15);
g.close();
g = new Gradient(atan2(y, x), y);
assertEquals(
x.value() / (x.value() * x.value() + y.value() * y.value()),
g.get().value().get(0, 0),
1e-15);
assertEquals(
x.value() / (x.value() * x.value() + y.value() * y.value()), g.value().get(0, 0), 1e-15);
g.close();
g = new Gradient(atan2(y, x), x);
assertEquals(
-y.value() / (x.value() * x.value() + y.value() * y.value()),
g.get().value().get(0, 0),
1e-15);
assertEquals(
-y.value() / (x.value() * x.value() + y.value() * y.value()), g.value().get(0, 0), 1e-15);
// Testing atan2 function on (expr, expr)
assertEquals(
3 * Math.atan2(Math.sin(y.value()), 2 * x.value() + 1),
3 * atan2(sin(y), x.times(2).plus(1)).value(),
1e-15);
g.close();
g = new Gradient(atan2(sin(y), x.times(2).plus(1)).times(3), y);
assertEquals(
3
* (2 * x.value() + 1)
* Math.cos(y.value())
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
+ Math.sin(y.value()) * Math.sin(y.value())),
g.get().value().get(0, 0),
1e-15);
assertEquals(
3
* (2 * x.value() + 1)
* Math.cos(y.value())
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
+ Math.sin(y.value()) * Math.sin(y.value())),
g.value().get(0, 0),
1e-15);
g.close();
g = new Gradient(atan2(sin(y), x.times(2).plus(1)).times(3), x);
assertEquals(
3
* -2
* Math.sin(y.value())
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
+ Math.sin(y.value()) * Math.sin(y.value())),
g.get().value().get(0, 0),
1e-15);
assertEquals(
3
* -2
* Math.sin(y.value())
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
+ Math.sin(y.value()) * Math.sin(y.value())),
g.value().get(0, 0),
1e-15);
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
private double hypot(double x, double y, double z) {
return Math.sqrt(x * x + y * y + z * z);
}
@Test
void testHypot() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
var y = new Variable();
// Testing hypot function on (var, double)
x.setValue(1.8);
y.setValue(1.5);
assertEquals(Math.hypot(x.value(), 2.0), Variable.hypot(x, 2.0).value());
var g = new Gradient(Variable.hypot(x, 2.0), x);
assertEquals(x.value() / Math.hypot(x.value(), 2.0), g.get().value().get(0, 0));
assertEquals(x.value() / Math.hypot(x.value(), 2.0), g.value().get(0, 0));
// Testing hypot function on (double, var)
assertEquals(Math.hypot(2.0, y.value()), Variable.hypot(2.0, y).value());
g.close();
g = new Gradient(Variable.hypot(2.0, y), y);
assertEquals(y.value() / Math.hypot(2.0, y.value()), g.get().value().get(0, 0));
assertEquals(y.value() / Math.hypot(2.0, y.value()), g.value().get(0, 0));
// Testing hypot function on (var, var)
x.setValue(1.3);
y.setValue(2.3);
assertEquals(Math.hypot(x.value(), y.value()), Variable.hypot(x, y).value());
g.close();
g = new Gradient(Variable.hypot(x, y), x);
assertEquals(x.value() / Math.hypot(x.value(), y.value()), g.get().value().get(0, 0));
assertEquals(x.value() / Math.hypot(x.value(), y.value()), g.value().get(0, 0));
g.close();
g = new Gradient(Variable.hypot(x, y), y);
assertEquals(y.value() / Math.hypot(x.value(), y.value()), g.get().value().get(0, 0));
assertEquals(y.value() / Math.hypot(x.value(), y.value()), g.value().get(0, 0));
// Testing hypot function on (expr, expr)
x.setValue(1.3);
y.setValue(2.3);
assertEquals(
Math.hypot(2.0 * x.value(), 3.0 * y.value()),
Variable.hypot(x.times(2.0), y.times(3.0)).value());
g.close();
g = new Gradient(Variable.hypot(x.times(2.0), y.times(3.0)), x);
assertEquals(
4.0 * x.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()),
g.get().value().get(0, 0));
assertEquals(
4.0 * x.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), g.value().get(0, 0));
g.close();
g = new Gradient(Variable.hypot(x.times(2.0), y.times(3.0)), y);
assertEquals(
9.0 * y.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()),
g.get().value().get(0, 0));
assertEquals(
9.0 * y.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), g.value().get(0, 0));
// Testing hypot function on (var, var, var)
var z = new Variable();
x.setValue(1.3);
y.setValue(2.3);
z.setValue(3.3);
assertEquals(Variable.hypot(x, y, z).value(), hypot(x.value(), y.value(), z.value()));
g.close();
g = new Gradient(Variable.hypot(x, y, z), x);
assertEquals(x.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0));
assertEquals(x.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0));
g.close();
g = new Gradient(Variable.hypot(x, y, z), y);
assertEquals(y.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0));
assertEquals(y.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0));
g.close();
g = new Gradient(Variable.hypot(x, y, z), z);
assertEquals(z.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0));
assertEquals(z.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testMax() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(2.0);
var x2 = x.times(x);
var x3 = x.times(x).times(x);
try (var g_x3 = new Gradient(x3, x)) {
// Testing lhs < rhs
var g = new Gradient(max(x2, x3), x);
assertEquals(x3.value(), max(x2, x3).value());
assertEquals(g_x3.value().get(0, 0), g.get().value().get(0, 0));
assertEquals(g_x3.value().get(0, 0), g.value().get(0, 0));
// Testing lhs > rhs
g.close();
g = new Gradient(max(x3, x2), x);
assertEquals(x3.value(), max(x3, x2).value());
assertEquals(g_x3.value().get(0, 0), g.get().value().get(0, 0));
assertEquals(g_x3.value().get(0, 0), g.value().get(0, 0));
// Testing lhs == rhs
g.close();
g = new Gradient(max(x, x), x);
assertEquals(x.value(), max(x, x).value());
assertEquals(1.0, g.get().value().get(0, 0));
assertEquals(1.0, g.value().get(0, 0));
g.close();
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testMin() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
x.setValue(2.0);
var x2 = x.times(x);
var x3 = x.times(x).times(x);
try (var g_x2 = new Gradient(x2, x)) {
// Testing lhs < rhs
var g = new Gradient(min(x2, x3), x);
assertEquals(x2.value(), min(x2, x3).value());
assertEquals(g_x2.value().get(0, 0), g.get().value().get(0, 0));
assertEquals(g_x2.value().get(0, 0), g.value().get(0, 0));
// Testing lhs > rhs
g.close();
g = new Gradient(min(x3, x2), x);
assertEquals(x2.value(), min(x3, x2).value());
assertEquals(g_x2.value().get(0, 0), g.get().value().get(0, 0));
assertEquals(g_x2.value().get(0, 0), g.value().get(0, 0));
// Testing lhs == rhs
g.close();
g = new Gradient(min(x, x), x);
assertEquals(x.value(), min(x, x).value());
assertEquals(1.0, g.get().value().get(0, 0));
assertEquals(1.0, g.value().get(0, 0));
g.close();
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testMiscellaneous() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
// dx/dx
x.setValue(3.0);
assertEquals(Math.abs(x.value()), abs(x).value());
var g = new Gradient(x, x);
assertEquals(1.0, g.get().value().get(0, 0));
assertEquals(1.0, g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testVariableReuse() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var a = new Variable();
a.setValue(10);
var b = new Variable();
b.setValue(20);
var x = a.times(b);
var g = new Gradient(x, a);
assertEquals(20.0, g.get().value().get(0, 0));
assertEquals(20.0, g.value().get(0, 0));
b.setValue(10);
assertEquals(10.0, g.get().value().get(0, 0));
assertEquals(10.0, g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSignum() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new Variable();
// signum(1.0)
x.setValue(1.0);
assertEquals(Math.signum(x.value()), signum(x).value());
var g = new Gradient(signum(x), x);
assertEquals(0.0, g.get().value().get(0, 0));
assertEquals(0.0, g.value().get(0, 0));
// signum(-1.0)
x.setValue(-1.0);
assertEquals(Math.signum(x.value()), signum(x).value());
g.close();
g = new Gradient(signum(x), x);
assertEquals(0.0, g.get().value().get(0, 0));
assertEquals(0.0, g.value().get(0, 0));
// signum(0.0)
x.setValue(0.0);
assertEquals(Math.signum(x.value()), signum(x).value());
g.close();
g = new Gradient(signum(x), x);
assertEquals(0.0, g.get().value().get(0, 0));
assertEquals(0.0, g.value().get(0, 0));
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNonScalar() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(3);
x.get(0).setValue(1);
x.get(1).setValue(2);
x.get(2).setValue(3);
// y = [x₁ + 3x₂ 5x₃]
//
// dy/dx = [1 3 5]
var y = x.get(0).plus(x.get(1).times(3)).minus(x.get(2).times(5));
var g = new Gradient(y, x);
var expected_g = new SimpleMatrix(new double[][] {{1.0}, {3.0}, {-5.0}});
var g_get_value = g.get().value();
assertEquals(3, g_get_value.getNumRows());
assertEquals(1, g_get_value.getNumCols());
assertEquals(expected_g, g_get_value);
var g_value = g.value();
assertEquals(3, g_value.getNumRows());
assertEquals(1, g_value.getNumCols());
assertEquals(expected_g, g_value);
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -2,15 +2,15 @@
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.jni;
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test;
public class Ellipse2dJNITest {
public class HessianJNITest {
@Test
public void testLink() {
assertDoesNotThrow(Ellipse2dJNI::forceLoad);
assertDoesNotThrow(HessianJNI::forceLoad);
}
}

View File

@@ -0,0 +1,499 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.DoubleRange.range;
import static org.wpilib.math.MatrixAssertions.assertEquals;
import static org.wpilib.math.autodiff.Variable.log;
import static org.wpilib.math.autodiff.Variable.pow;
import static org.wpilib.math.autodiff.Variable.sin;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
class HessianTest {
@Test
void testLinear() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// y = x
var x = new VariableMatrix(1);
x.get(0).setValue(3);
var y = x.get(0);
// dy/dx = 1
var gradient = new Gradient(y, x.get(0));
double g = gradient.value().get(0, 0);
assertEquals(1.0, g);
// d²y/dx² = 0
var H = new Hessian(y, x);
assertEquals(0.0, H.get().value(0, 0));
assertEquals(0.0, H.value().get(0, 0));
H.close();
gradient.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testQuadratic() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// y = x²
var x = new VariableMatrix(1);
x.get(0).setValue(3);
var y = x.get(0).times(x.get(0));
// dy/dx = 2x = 6
var gradient = new Gradient(y, x.get(0));
double g = gradient.value().get(0, 0);
assertEquals(6.0, g);
// d²y/dx² = 2
var H = new Hessian(y, x);
assertEquals(2.0, H.get().value(0, 0));
assertEquals(2.0, H.value().get(0, 0));
H.close();
gradient.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testCubic() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// y = x³
var x = new VariableMatrix(1);
x.get(0).setValue(3);
var y = x.get(0).times(x.get(0)).times(x.get(0));
// dy/dx = 3x² = 27
var gradient = new Gradient(y, x.get(0));
double g = gradient.value().get(0, 0);
assertEquals(27.0, g);
// d²y/dx² = 6x = 18
var H = new Hessian(y, x);
assertEquals(18.0, H.get().value(0, 0));
assertEquals(18.0, H.value().get(0, 0));
H.close();
gradient.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testQuartic() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// y = x⁴
var x = new VariableMatrix(1);
x.get(0).setValue(3);
var y = x.get(0).times(x.get(0)).times(x.get(0)).times(x.get(0));
// dy/dx = 4x³ = 108
var gradient = new Gradient(y, x.get(0));
double g = gradient.value().get(0, 0);
assertEquals(108.0, g);
// d²y/dx² = 12x² = 108
var H = new Hessian(y, x);
assertEquals(108.0, H.get().value(0, 0));
assertEquals(108.0, H.value().get(0, 0));
H.close();
gradient.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSum() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(5);
for (int i = 0; i < 5; ++i) {
x.get(i).setValue(i + 1);
}
// y = sum(x)
var y = x.get(0).plus(x.get(1)).plus(x.get(2)).plus(x.get(3)).plus(x.get(4));
assertEquals(15.0, y.value());
var g = new Gradient(y, x);
assertEquals(SimpleMatrix.filled(5, 1, 1.0), g.get().value());
assertEquals(SimpleMatrix.filled(5, 1, 1.0), g.value());
var H = new Hessian(y, x);
assertEquals(SimpleMatrix.filled(5, 5, 0.0), H.get().value());
assertEquals(SimpleMatrix.filled(5, 5, 0.0), H.value());
H.close();
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSumOfProducts() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(5);
for (int i = 0; i < 5; ++i) {
x.get(i).setValue(i + 1);
}
// y = ||x||²
var y = x.T().times(x).get(0);
assertEquals(1 * 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, y.value());
var g = new Gradient(y, x);
assertEquals(x.value().scale(2), g.get().value());
assertEquals(x.value().scale(2), g.value());
var H = new Hessian(y, x);
var expected_H = SimpleMatrix.diag(2.0, 2.0, 2.0, 2.0, 2.0);
assertEquals(expected_H, H.get().value());
assertEquals(expected_H, H.value());
H.close();
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testProductOfSines() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(5);
for (int i = 0; i < 5; ++i) {
x.get(i).setValue(i + 1);
}
// y = prod(sin(x))
var y = x.cwiseMap(Variable::sin).stream().reduce(new Variable(1.0), (a, b) -> a.times(b));
assertEquals(
Math.sin(1) * Math.sin(2) * Math.sin(3) * Math.sin(4) * Math.sin(5), y.value(), 1e-15);
var g = new Gradient(y, x);
for (int i = 0; i < x.rows(); ++i) {
assertEquals(y.value() / Math.tan(x.get(i).value()), g.get().value(i), 1e-15);
assertEquals(y.value() / Math.tan(x.get(i).value()), g.value().get(i, 0), 1e-15);
}
var H = new Hessian(y, x);
var expected_H = new SimpleMatrix(5, 5);
for (int i = 0; i < x.rows(); ++i) {
for (int j = 0; j < x.rows(); ++j) {
if (i == j) {
expected_H.set(i, j, -y.value());
} else {
expected_H.set(
i, j, y.value() / (Math.tan(x.get(i).value()) * Math.tan(x.get(j).value())));
}
}
}
assertEquals(expected_H, H.get().value(), 1e-15);
assertEquals(expected_H, H.value(), 1e-15);
H.close();
g.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSumOfSquaredResiduals() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(5);
for (int i = 0; i < 5; ++i) {
x.get(i).setValue(1);
}
// y = sum(diff(x).^2)
var temp = x.block(0, 0, 4, 1).minus(x.block(1, 0, 4, 1)).cwiseMap(a -> pow(a, 2));
var y = temp.stream().reduce(new Variable(0.0), (a, b) -> a.plus(b));
var gradient = new Gradient(y, x);
var g = gradient.value();
assertEquals(0.0, y.value());
assertEquals(g.get(0, 0), 2 * x.get(0).value() - 2 * x.get(1).value());
assertEquals(
g.get(1, 0), -2 * x.get(0).value() + 4 * x.get(1).value() - 2 * x.get(2).value());
assertEquals(
g.get(2, 0), -2 * x.get(1).value() + 4 * x.get(2).value() - 2 * x.get(3).value());
assertEquals(
g.get(3, 0), -2 * x.get(2).value() + 4 * x.get(3).value() - 2 * x.get(4).value());
assertEquals(g.get(4, 0), -2 * x.get(3).value() + 2 * x.get(4).value());
var H = new Hessian(y, x);
var expected_H =
new SimpleMatrix(
new double[][] {
{2.0, -2.0, 0.0, 0.0, 0.0},
{-2.0, 4.0, -2.0, 0.0, 0.0},
{0.0, -2.0, 4.0, -2.0, 0.0},
{0.0, 0.0, -2.0, 4.0, -2.0},
{0.0, 0.0, 0.0, -2.0, 2.0}
});
assertEquals(expected_H, H.get().value());
assertEquals(expected_H, H.value());
H.close();
gradient.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSumOfSquares() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var r = new VariableMatrix(4);
r.setValue(new double[][] {{25.0}, {10.0}, {5.0}, {0.0}});
var x = new VariableMatrix(4);
for (int i = 0; i < 4; ++i) {
x.get(i).setValue(0.0);
}
var J = new Variable(0.0);
for (int i = 0; i < 4; ++i) {
J = J.plus(r.get(i).minus(x.get(i)).times(r.get(i).minus(x.get(i))));
}
var H = new Hessian(J, x);
var expected_H = SimpleMatrix.diag(2.0, 2.0, 2.0, 2.0);
assertEquals(expected_H, H.get().value());
assertEquals(expected_H, H.value());
H.close();
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNestedPowers() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
final var x0 = 3.0;
var x = new Variable();
x.setValue(x0);
var y = pow(pow(x, 2), 2);
var jacobian = new Jacobian(y, x);
var J = jacobian.value();
assertEquals(4 * x0 * x0 * x0, J.get(0, 0), 1e-12);
var hessian = new Hessian(y, x);
var H = hessian.value();
assertEquals(12 * x0 * x0, H.get(0, 0), 1e-12);
hessian.close();
jacobian.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testRosenbrock() {
// z = (1 x)² + 100(y x²)²
// = 100(x² + y)² + (x + 1)²
//
// ∂z/∂x = 200(x² + y)⋅2x + 2(x + 1)⋅1
// = 400x(x² + y) 2(x + 1)
// = 400x³ 400xy + 2x 2
//
// ∂z/∂y = 200(x² + y)
//
// ∂²z/∂x² = 1200x² 400y + 2
// ∂²z/∂xy = 400x
// ∂²z/∂y² = 200
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var input = new VariableMatrix(2);
var x = input.get(0);
var y = input.get(1);
var hessian =
new Hessian(
pow(new Variable(1).minus(x), 2).plus(pow(y.minus(pow(x, 2)), 2).times(100)), input);
for (var x0 : range(-2.5, 2.5, 0.1)) {
for (var y0 : range(-2.5, 2.5, 0.1)) {
x.setValue(x0);
y.setValue(y0);
var H = hessian.value();
assertEquals(1200 * x0 * x0 - 400 * y0 + 2, H.get(0, 0), 1e-11);
assertEquals(-400 * x0, H.get(0, 1), 1e-15);
assertEquals(-400 * x0, H.get(1, 0), 1e-15);
assertEquals(200, H.get(1, 1));
}
}
hessian.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testEdgePushingWangExample1() {
// See example 1 of [1]
//
// [1] Wang, M., et al. "Capitalizing on live variables: new algorithms for
// efficient Hessian computation via automatic differentiation", 2016.
// https://sci-hub.st/10.1007/s12532-016-0100-3
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(2);
x.get(0).setValue(3);
x.get(1).setValue(4);
// y = (x₀sin(x₁)) x₀
var y = (x.get(0).times(sin(x.get(1)))).times(x.get(0));
// dy/dx = [2x₀sin(x₁) x₀²cos(x₁)]
// dy/dx = [ 6sin(4) 9cos(4) ]
var J = new Jacobian(y, x);
var expected_J =
new SimpleMatrix(new double[][] {{6.0 * Math.sin(4.0), 9.0 * Math.cos(4.0)}});
assertEquals(expected_J, J.get().value(), 1e-15);
assertEquals(expected_J, J.value(), 1e-15);
// [ 2sin(x₁) 2x₀cos(x₁)]
// d²y/dx² = [2x₀cos(x₁) x₀²sin(x₁)]
//
// [2sin(4) 6cos(4)]
// d²y/dx² = [6cos(4) 9sin(4)]
var H = new Hessian(y, x);
var expected_H =
new SimpleMatrix(
new double[][] {
{2.0 * Math.sin(4.0), 6.0 * Math.cos(4.0)},
{6.0 * Math.cos(4.0), -9.0 * Math.sin(4.0)}
});
assertEquals(expected_H, H.get().value(), 1e-15);
assertEquals(expected_H, H.value(), 1e-15);
H.close();
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testEdgePushingPetroFigure1() {
// See figure 1 of [1]
//
// [1] Petro, C. G., et al. "On efficient Hessian computation using the edge
// pushing algorithm in Julia", 2017.
// https://mlubin.github.io/pdf/edge_pushing_julia.pdf
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// y = p₁ log(x₁x₂)
var p_1 = new Variable(2.0);
var x = new VariableMatrix(2);
x.get(0).setValue(2.0);
x.get(1).setValue(3.0);
var y = p_1.times(log(x.get(0).times(x.get(1))));
// d²y/dx² = [p₁/x₁² 0 ]
// [ 0 p₁/x₂²]
var H = new Hessian(y, x);
var expected_H =
new SimpleMatrix(
new double[][] {
{-p_1.value() / (x.get(0).value() * x.get(0).value()), 0.0},
{0.0, -p_1.value() / (x.get(1).value() * x.get(1).value())}
});
assertEquals(expected_H, H.get().value(), 1e-15);
assertEquals(expected_H, H.value(), 1e-15);
H.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testVariableReuse() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
Variable y;
var x = new VariableMatrix(1);
// y = x³
x.get(0).setValue(1);
y = x.get(0).times(x.get(0)).times(x.get(0));
var hessian = new Hessian(y, x);
// d²y/dx² = 6x
// H = 6
var H = hessian.value();
assertEquals(1, H.getNumRows());
assertEquals(1, H.getNumCols());
assertEquals(6.0, H.get(0, 0));
x.get(0).setValue(2);
// d²y/dx² = 6x
// H = 12
H = hessian.value();
assertEquals(1, H.getNumRows());
assertEquals(1, H.getNumCols());
assertEquals(12.0, H.get(0, 0));
hessian.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,16 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test;
public class JacobianJNITest {
@Test
public void testLink() {
assertDoesNotThrow(JacobianJNI::forceLoad);
}
}

View File

@@ -0,0 +1,266 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.MatrixAssertions.assertEquals;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
class JacobianTest {
@Test
void testYEqualsX() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(3);
for (int i = 0; i < 3; ++i) {
x.get(i).setValue(i + 1);
}
// y = x
//
// [1 0 0]
// dy/dx = [0 1 0]
// [0 0 1]
var y = x;
var J = new Jacobian(y, x);
var expected_J =
new SimpleMatrix(new double[][] {{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}});
assertEquals(expected_J, J.get().value());
assertEquals(expected_J, J.value());
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testYEquals3X() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(3);
for (int i = 0; i < 3; ++i) {
x.get(i).setValue(i + 1);
}
// y = 3x
//
// [3 0 0]
// dy/dx = [0 3 0]
// [0 0 3]
var y = x.times(3);
var J = new Jacobian(y, x);
var expected_J =
new SimpleMatrix(new double[][] {{3.0, 0.0, 0.0}, {0.0, 3.0, 0.0}, {0.0, 0.0, 3.0}});
assertEquals(expected_J, J.get().value());
assertEquals(expected_J, J.value());
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testProducts() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(3);
for (int i = 0; i < 3; ++i) {
x.get(i).setValue(i + 1);
}
// [x₁x₂]
// y = [x₂x₃]
// [x₁x₃]
//
// [x₂ x₁ 0 ]
// dy/dx = [0 x₃ x₂]
// [x₃ 0 x₁]
//
// [2 1 0]
// dy/dx = [0 3 2]
// [3 0 1]
var y = new VariableMatrix(3);
y.set(0, x.get(0).times(x.get(1)));
y.set(1, x.get(1).times(x.get(2)));
y.set(2, x.get(0).times(x.get(2)));
var J = new Jacobian(y, x);
var expected_J =
new SimpleMatrix(new double[][] {{2.0, 1.0, 0.0}, {0.0, 3.0, 2.0}, {3.0, 0.0, 1.0}});
assertEquals(expected_J, J.get().value());
assertEquals(expected_J, J.value());
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNestedProducts() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(1);
x.get(0).setValue(3);
assertEquals(3.0, x.value(0));
// [ 5x] [15]
// y = [ 7x] = [21]
// [11x] [33]
var y = new VariableMatrix(3);
y.set(0, x.get(0).times(5));
y.set(1, x.get(0).times(7));
y.set(2, x.get(0).times(11));
assertEquals(15.0, y.value(0));
assertEquals(21.0, y.value(1));
assertEquals(33.0, y.value(2));
// [y₁y₂] [15⋅21] [315]
// z = [y₂y₃] = [21⋅33] = [693]
// [y₁y₃] [15⋅33] [495]
var z = new VariableMatrix(3);
z.set(0, y.get(0).times(y.get(1)));
z.set(1, y.get(1).times(y.get(2)));
z.set(2, y.get(0).times(y.get(2)));
assertEquals(315.0, z.value(0));
assertEquals(693.0, z.value(1));
assertEquals(495.0, z.value(2));
// [ 5x]
// y = [ 7x]
// [11x]
//
// [ 5]
// dy/dx = [ 7]
// [11]
var J = new Jacobian(y, x);
assertEquals(5.0, J.get().value(0, 0));
assertEquals(7.0, J.get().value(1, 0));
assertEquals(11.0, J.get().value(2, 0));
assertEquals(5.0, J.value().get(0, 0));
assertEquals(7.0, J.value().get(1, 0));
assertEquals(11.0, J.value().get(2, 0));
// [y₁y₂]
// z = [y₂y₃]
// [y₁y₃]
//
// [y₂ y₁ 0 ] [21 15 0]
// dz/dy = [0 y₃ y₂] = [ 0 33 21]
// [y₃ 0 y₁] [33 0 15]
J.close();
J = new Jacobian(z, y);
var expected_J =
new SimpleMatrix(
new double[][] {{21.0, 15.0, 0.0}, {0.0, 33.0, 21.0}, {33.0, 0.0, 15.0}});
assertEquals(expected_J, J.get().value());
assertEquals(expected_J, J.value());
// [y₁y₂] [5x⋅ 7x] [35x²]
// z = [y₂y₃] = [7x⋅11x] = [77x²]
// [y₁y₃] [5x⋅11x] [55x²]
//
// [ 70x] [210]
// dz/dx = [154x] = [462]
// [110x] = [330]
J.close();
J = new Jacobian(z, x);
expected_J = new SimpleMatrix(new double[][] {{210.0}, {462.0}, {330.0}});
assertEquals(expected_J, J.get().value());
assertEquals(expected_J, J.value());
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNonSquare() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(3);
for (int i = 0; i < 3; ++i) {
x.get(i).setValue(i + 1);
}
// y = [x₁ + 3x₂ 5x₃]
//
// dy/dx = [1 3 5]
var y = new VariableMatrix(1);
y.set(0, x.get(0).plus(x.get(1).times(3)).minus(x.get(2).times(5)));
var J = new Jacobian(y, x);
var expected_J = new SimpleMatrix(new double[][] {{1.0, 3.0, -5.0}});
var J_get_value = J.get().value();
assertEquals(1, J_get_value.getNumRows());
assertEquals(3, J_get_value.getNumCols());
assertEquals(expected_J, J_get_value);
var J_value = J.value();
assertEquals(1, J_value.getNumRows());
assertEquals(3, J_value.getNumCols());
assertEquals(expected_J, J_value);
J.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testVariableReuse() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var x = new VariableMatrix(2);
for (int i = 0; i < 2; ++i) {
x.get(i).setValue(i + 1);
}
// y = [x₁x₂]
var y = new VariableMatrix(1);
y.set(0, x.get(0).times(x.get(1)));
var jacobian = new Jacobian(y, x);
// dy/dx = [x₂ x₁]
// dy/dx = [2 1]
var J = jacobian.value();
assertEquals(1, J.getNumRows());
assertEquals(2, J.getNumCols());
assertEquals(2.0, J.get(0, 0));
assertEquals(1.0, J.get(0, 1));
x.get(0).setValue(2);
x.get(1).setValue(1);
// dy/dx = [x₂ x₁]
// dy/dx = [1 2]
J = jacobian.value();
assertEquals(1, J.getNumRows());
assertEquals(2, J.getNumCols());
assertEquals(1.0, J.get(0, 0));
assertEquals(2.0, J.get(0, 1));
jacobian.close();
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,481 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
class SliceTest {
@Test
void testDefaultConstructor() {
var slice = new Slice();
assertEquals(0, slice.start);
assertEquals(0, slice.stop);
assertEquals(1, slice.step);
assertEquals(0, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(0, slice.stop);
assertEquals(1, slice.step);
}
@Test
void testOneArgConstructor() {
// none
{
var slice = new Slice(Slice.__);
assertEquals(0, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(1, slice.step);
assertEquals(3, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(3, slice.stop);
assertEquals(1, slice.step);
}
// +
{
var slice = new Slice(1);
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
}
// -1
{
var slice = new Slice(-1);
assertEquals(-1, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(3, slice.stop);
assertEquals(1, slice.step);
}
// -2
{
var slice = new Slice(-2);
assertEquals(-2, slice.start);
assertEquals(-1, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
}
}
@Test
void testTwoArgConstructor() {
// none, none
{
var slice = new Slice(Slice.__, Slice.__);
assertEquals(0, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(1, slice.step);
assertEquals(3, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(3, slice.stop);
assertEquals(1, slice.step);
}
// none, +
{
var slice = new Slice(Slice.__, 1);
assertEquals(0, slice.start);
assertEquals(1, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(1, slice.stop);
assertEquals(1, slice.step);
}
// none, -
{
var slice = new Slice(Slice.__, -1);
assertEquals(0, slice.start);
assertEquals(-1, slice.stop);
assertEquals(1, slice.step);
assertEquals(2, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
}
// +, none
{
var slice = new Slice(1, Slice.__);
assertEquals(1, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(1, slice.step);
assertEquals(2, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(3, slice.stop);
assertEquals(1, slice.step);
}
// -, none
{
var slice = new Slice(-1, Slice.__);
assertEquals(-1, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(3, slice.stop);
assertEquals(1, slice.step);
}
// +, +
{
var slice = new Slice(1, 2);
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
}
// +, -
{
var slice = new Slice(1, -1);
assertEquals(1, slice.start);
assertEquals(-1, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
}
// -, -
{
var slice = new Slice(-2, -1);
assertEquals(-2, slice.start);
assertEquals(-1, slice.stop);
assertEquals(1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(1, slice.step);
}
}
@Test
void testThreeArgConstructor() {
// none, none, none
{
var slice = new Slice(Slice.__, Slice.__, Slice.__);
assertEquals(0, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(1, slice.step);
assertEquals(3, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(3, slice.stop);
assertEquals(1, slice.step);
}
// none, none, +
{
var slice = new Slice(Slice.__, Slice.__, 2);
assertEquals(0, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(2, slice.step);
assertEquals(2, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(3, slice.stop);
assertEquals(2, slice.step);
}
// none, none, -
{
var slice = new Slice(Slice.__, Slice.__, -2);
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(-2, slice.step);
assertEquals(2, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(-1, slice.stop);
assertEquals(-2, slice.step);
}
// none, +, +
{
var slice = new Slice(Slice.__, 1, 2);
assertEquals(0, slice.start);
assertEquals(1, slice.stop);
assertEquals(2, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(0, slice.start);
assertEquals(1, slice.stop);
assertEquals(2, slice.step);
}
// none, +, -
{
var slice = new Slice(Slice.__, 1, -2);
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(1, slice.stop);
assertEquals(-2, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(1, slice.stop);
assertEquals(-2, slice.step);
}
// none, -, -
{
var slice = new Slice(Slice.__, -2, -1);
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(-2, slice.stop);
assertEquals(-1, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(1, slice.stop);
assertEquals(-1, slice.step);
}
// +, none, +
{
var slice = new Slice(1, Slice.__, 2);
assertEquals(1, slice.start);
assertEquals(Integer.MAX_VALUE, slice.stop);
assertEquals(2, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(3, slice.stop);
assertEquals(2, slice.step);
}
// +, none, -
{
var slice = new Slice(1, Slice.__, -2);
assertEquals(1, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(-2, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(-1, slice.stop);
assertEquals(-2, slice.step);
}
// +, +, +
{
var slice = new Slice(1, 2, 2);
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(2, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(2, slice.step);
}
// +, +, -
{
var slice = new Slice(2, 1, -2);
assertEquals(2, slice.start);
assertEquals(1, slice.stop);
assertEquals(-2, slice.step);
assertEquals(1, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(1, slice.stop);
assertEquals(-2, slice.step);
}
}
@Test
void testEmptySlices() {
// +, +, +
{
var slice = new Slice(2, 1, 2);
assertEquals(2, slice.start);
assertEquals(1, slice.stop);
assertEquals(2, slice.step);
assertEquals(0, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(1, slice.stop);
assertEquals(2, slice.step);
}
// +, +, -
{
var slice = new Slice(1, 2, -2);
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(-2, slice.step);
assertEquals(0, slice.adjust(3));
assertEquals(1, slice.start);
assertEquals(2, slice.stop);
assertEquals(-2, slice.step);
}
// +, -, -
{
var slice = new Slice(3, -1, -2);
assertEquals(3, slice.start);
assertEquals(-1, slice.stop);
assertEquals(-2, slice.step);
assertEquals(0, slice.adjust(3));
assertEquals(2, slice.start);
assertEquals(2, slice.stop);
assertEquals(-2, slice.step);
}
}
@Test
void testStepUBGuard() {
{
// none, none, INT_MIN
var slice = new Slice(Slice.__, Slice.__, Integer.MIN_VALUE);
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// none, +, INT_MIN
var slice = new Slice(Slice.__, 2, Integer.MIN_VALUE);
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(2, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(2, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// none, -, INT_MIN
var slice = new Slice(Slice.__, -2, Integer.MIN_VALUE);
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(-2, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(Integer.MAX_VALUE, slice.start);
assertEquals(-2, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// +, none, INT_MIN
var slice = new Slice(1, Slice.__, Integer.MIN_VALUE);
assertEquals(1, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(1, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// -, none, INT_MIN
var slice = new Slice(-2, Slice.__, Integer.MIN_VALUE);
assertEquals(-2, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(-2, slice.start);
assertEquals(Integer.MIN_VALUE, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// +, +, INT_MIN
var slice = new Slice(1000, 0, Integer.MIN_VALUE);
assertEquals(1000, slice.start);
assertEquals(0, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(1000, slice.start);
assertEquals(0, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// +, -, INT_MIN
var slice = new Slice(1000, -2, Integer.MIN_VALUE);
assertEquals(1000, slice.start);
assertEquals(-2, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(1000, slice.start);
assertEquals(-2, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// -, +, INT_MIN
var slice = new Slice(-1, 2, Integer.MIN_VALUE);
assertEquals(-1, slice.start);
assertEquals(2, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(-1, slice.start);
assertEquals(2, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
{
// -, -, INT_MIN
var slice = new Slice(-1, -2, Integer.MIN_VALUE);
assertEquals(-1, slice.start);
assertEquals(-2, slice.stop);
assertEquals(Integer.MIN_VALUE + 1, slice.step);
slice.step = -slice.step;
assertEquals(-1, slice.start);
assertEquals(-2, slice.stop);
assertEquals(Integer.MAX_VALUE, slice.step);
}
}
}

View File

@@ -0,0 +1,16 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test;
public class VariableJNITest {
@Test
public void testLink() {
assertDoesNotThrow(VariableJNI::forceLoad);
}
}

View File

@@ -0,0 +1,16 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test;
public class VariableMatrixJNITest {
@Test
public void testLink() {
assertDoesNotThrow(VariableMatrixJNI::forceLoad);
}
}

View File

@@ -0,0 +1,600 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.wpilib.math.MatrixAssertions.assertEquals;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
class VariableMatrixTest {
@Test
void testConstructFromDoubleArray() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var mat = new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})) {
var expected = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
assertEquals(expected, mat.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testConstructFromSimpleMatrix() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var mat =
new VariableMatrix(new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}))) {
var expected = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
assertEquals(expected, mat.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testAssignmentToDefault() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var mat = new VariableMatrix(2, 2);
assertEquals(2, mat.rows());
assertEquals(2, mat.cols());
assertEquals(0.0, mat.get(0, 0).value());
assertEquals(0.0, mat.get(0, 1).value());
assertEquals(0.0, mat.get(1, 0).value());
assertEquals(0.0, mat.get(1, 1).value());
mat.set(0, 0, 1.0);
mat.set(0, 1, 2.0);
mat.set(1, 0, 3.0);
mat.set(1, 1, 4.0);
assertEquals(1.0, mat.get(0, 0).value());
assertEquals(2.0, mat.get(0, 1).value());
assertEquals(3.0, mat.get(1, 0).value());
assertEquals(4.0, mat.get(1, 1).value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testAssignmentAliasing() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var A = new VariableMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
var B = new VariableMatrix(new double[][] {{5.0, 6.0}, {7.0, 8.0}});
// A and B initially contain different values
var expected_A = new SimpleMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
var expected_B = new SimpleMatrix(new double[][] {{5.0, 6.0}, {7.0, 8.0}});
assertEquals(expected_A, A.value());
assertEquals(expected_B, B.value());
// Make A point to B's storage
A.set(B);
assertEquals(expected_B, A.value());
assertEquals(expected_B, B.value());
// Changes to B should be reflected in A
B.get(0, 0).setValue(2.0);
expected_B.set(0, 0, 2.0);
assertEquals(expected_B, A.value());
assertEquals(expected_B, B.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testBlockMemberFunction() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var A =
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
// Block assignment
A.block(1, 1, 2, 2).set(new double[][] {{10.0, 11.0}, {12.0, 13.0}});
var expected1 =
new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 10.0, 11.0}, {7.0, 12.0, 13.0}});
assertEquals(expected1, A.value());
// Block-of-block assignment
A.block(1, 1, 2, 2).block(1, 1, 1, 1).set(14.0);
var expected2 =
new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 10.0, 11.0}, {7.0, 12.0, 14.0}});
assertEquals(expected2, A.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSlicing() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var mat =
new VariableMatrix(
new double[][] {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}});
assertEquals(4, mat.rows());
assertEquals(4, mat.cols());
// Single-arg index operator on full matrix
for (int i = 0; i < mat.rows() * mat.cols(); ++i) {
assertEquals(i + 1, mat.get(i).value());
}
// Slice from start
{
var s = mat.get(new Slice(1, Slice.__), new Slice(2, Slice.__));
assertEquals(3, s.rows());
assertEquals(2, s.cols());
// Single-arg index operator on forward slice
assertEquals(7.0, s.get(0).value());
assertEquals(8.0, s.get(1).value());
assertEquals(11.0, s.get(2).value());
assertEquals(12.0, s.get(3).value());
assertEquals(15.0, s.get(4).value());
assertEquals(16.0, s.get(5).value());
// Double-arg index operator on forward slice
assertEquals(7.0, s.get(0, 0).value());
assertEquals(8.0, s.get(0, 1).value());
assertEquals(11.0, s.get(1, 0).value());
assertEquals(12.0, s.get(1, 1).value());
assertEquals(15.0, s.get(2, 0).value());
assertEquals(16.0, s.get(2, 1).value());
}
// Slice from end
{
var s = mat.get(new Slice(-1, Slice.__), new Slice(-2, Slice.__));
assertEquals(1, s.rows());
assertEquals(2, s.cols());
// Single-arg index operator on reverse slice
assertEquals(15.0, s.get(0).value());
assertEquals(16.0, s.get(1).value());
// Double-arg index operator on reverse slice
assertEquals(15.0, s.get(0, 0).value());
assertEquals(16.0, s.get(0, 1).value());
}
// Slice from start with step of 2
{
var s = mat.get(Slice.__, new Slice(Slice.__, Slice.__, 2));
assertEquals(4, s.rows());
assertEquals(2, s.cols());
assertEquals(
new SimpleMatrix(new double[][] {{1.0, 3.0}, {5.0, 7.0}, {9.0, 11.0}, {13.0, 15.0}}),
s.value());
}
// Slice from end with negative step for row and column
{
var s = mat.get(new Slice(Slice.__, Slice.__, -1), new Slice(Slice.__, Slice.__, -2));
assertEquals(4, s.rows());
assertEquals(2, s.cols());
assertEquals(
new SimpleMatrix(new double[][] {{16.0, 14.0}, {12.0, 10.0}, {8.0, 6.0}, {4.0, 2.0}}),
s.value());
}
// Slice from start and column -1
{
var s = mat.get(new Slice(1, Slice.__), -1);
assertEquals(3, s.rows());
assertEquals(1, s.cols());
assertEquals(new SimpleMatrix(new double[][] {{8.0}, {12.0}, {16.0}}), s.value());
}
// Slice from start and column -2
{
var s = mat.get(new Slice(1, Slice.__), -2);
assertEquals(3, s.rows());
assertEquals(1, s.cols());
assertEquals(new SimpleMatrix(new double[][] {{7.0}, {11.0}, {15.0}}), s.value());
}
// Block assignment
{
var s = mat.get(new Slice(Slice.__, Slice.__, 2), new Slice(Slice.__, Slice.__, 2));
assertEquals(2, s.rows());
assertEquals(2, s.cols());
s.setValue(new double[][] {{17.0, 18.0}, {19.0, 20.0}});
assertEquals(
new SimpleMatrix(
new double[][] {
{17.0, 2.0, 18.0, 4.0},
{5.0, 6.0, 7.0, 8.0},
{19.0, 10.0, 20.0, 12.0},
{13.0, 14.0, 15.0, 16.0}
}),
mat.value());
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSubslicing() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// Block-of-block assignment (row skip forward)
{
var mat = new VariableMatrix(5, 5);
var s =
mat.get(new Slice(Slice.__, Slice.__, 2), new Slice(Slice.__, Slice.__, 1))
.get(new Slice(1, 3), new Slice(1, 4));
assertEquals(2, s.rows());
assertEquals(3, s.cols());
s.setValue(new double[][] {{1, 2, 3}, {4, 5, 6}});
assertEquals(
new SimpleMatrix(
new double[][] {
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 1, 2, 3, 0},
{0, 0, 0, 0, 0},
{0, 4, 5, 6, 0}
}),
mat.value());
}
// Block-of-block assignment (row skip backward)
{
var mat = new VariableMatrix(5, 5);
var s =
mat.get(new Slice(Slice.__, Slice.__, -2), new Slice(Slice.__, Slice.__, -1))
.get(new Slice(1, 3), new Slice(1, 4));
assertEquals(2, s.rows());
assertEquals(3, s.cols());
s.setValue(new double[][] {{1, 2, 3}, {4, 5, 6}});
assertEquals(
new SimpleMatrix(
new double[][] {
{0, 6, 5, 4, 0},
{0, 0, 0, 0, 0},
{0, 3, 2, 1, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0}
}),
mat.value());
}
// Block-of-block assignment (column skip forward)
{
var mat = new VariableMatrix(5, 5);
var s =
mat.get(new Slice(Slice.__, Slice.__, 1), new Slice(Slice.__, Slice.__, 2))
.get(new Slice(1, 4), new Slice(1, 3));
assertEquals(3, s.rows());
assertEquals(2, s.cols());
s.setValue(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertEquals(
new SimpleMatrix(
new double[][] {
{0, 0, 0, 0, 0},
{0, 0, 1, 0, 2},
{0, 0, 3, 0, 4},
{0, 0, 5, 0, 6},
{0, 0, 0, 0, 0}
}),
mat.value());
}
// Block-of-block assignment (column skip backward)
{
var mat = new VariableMatrix(5, 5);
var s =
mat.get(new Slice(Slice.__, Slice.__, -1), new Slice(Slice.__, Slice.__, -2))
.get(new Slice(1, 4), new Slice(1, 3));
assertEquals(3, s.rows());
assertEquals(2, s.cols());
s.setValue(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertEquals(
new SimpleMatrix(
new double[][] {
{0, 0, 0, 0, 0},
{6, 0, 5, 0, 0},
{4, 0, 3, 0, 0},
{2, 0, 1, 0, 0},
{0, 0, 0, 0, 0}
}),
mat.value());
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@SuppressWarnings("PMD.UnusedLocalVariable")
@Test
void testIterators() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
final var A =
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
final var sub_A = A.block(2, 1, 1, 2);
int distance = 0;
for (var elem : A) {
++distance;
}
assertEquals(9, distance);
distance = 0;
for (var elem : sub_A) {
++distance;
}
assertEquals(2, distance);
int i = 1;
for (var elem : A) {
assertEquals(i, elem.value());
++i;
}
i = 8;
for (var elem : sub_A) {
assertEquals(i, elem.value());
++i;
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testValue() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var A =
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
var expected =
new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
// Full matrix
assertEquals(expected, A.value());
assertEquals(4.0, A.value(3));
assertEquals(2.0, A.T().value(3));
// Block
assertEquals(expected.extractMatrix(1, 3, 1, 3), A.block(1, 1, 2, 2).value());
assertEquals(8.0, A.block(1, 1, 2, 2).value(2));
assertEquals(6.0, A.T().block(1, 1, 2, 2).value(2));
// Slice
assertEquals(
expected.extractMatrix(1, 3, 1, 3), A.get(new Slice(1, 3), new Slice(1, 3)).value());
assertEquals(8.0, A.get(new Slice(1, 3), new Slice(1, 3)).value(2));
assertEquals(6.0, A.get(new Slice(1, 3), new Slice(1, 3)).T().value(2));
// Block-of-block
assertEquals(
expected.extractMatrix(1, 3, 1, 3).extractMatrix(0, 2, 1, 2),
A.block(1, 1, 2, 2).block(0, 1, 2, 1).value());
assertEquals(9.0, A.block(1, 1, 2, 2).block(0, 1, 2, 1).value(1));
assertEquals(9.0, A.block(1, 1, 2, 2).T().block(0, 1, 2, 1).value(1));
// Slice-of-slice
assertEquals(
expected.extractMatrix(1, 3, 1, 3).extractMatrix(0, 2, 1, 2),
A.get(new Slice(1, 3), new Slice(1, 3)).get(Slice.__, new Slice(1, Slice.__)).value());
assertEquals(
9.0,
A.get(new Slice(1, 3), new Slice(1, 3)).get(Slice.__, new Slice(1, Slice.__)).value(1));
assertEquals(
9.0,
A.get(new Slice(1, 3), new Slice(1, 3))
.T()
.get(Slice.__, new Slice(1, Slice.__))
.value(1));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testCwiseMap() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// VariableMatrix cwiseMap
var A = new VariableMatrix(new double[][] {{-2.0, -3.0, -4.0}, {-5.0, -6.0, -7.0}});
var result1 = A.cwiseMap(Variable::abs);
var expected1 = new SimpleMatrix(new double[][] {{2.0, 3.0, 4.0}, {5.0, 6.0, 7.0}});
// Don't modify original matrix
assertEquals(expected1.scale(-1.0), A.value());
assertEquals(expected1, result1.value());
// VariableBlock cwiseMap
var sub_A = A.block(0, 0, 2, 2);
var result2 = sub_A.cwiseMap(Variable::abs);
var expected2 = new SimpleMatrix(new double[][] {{2.0, 3.0}, {5.0, 6.0}});
// Don't modify original matrix
assertEquals(expected1.scale(-1.0), A.value());
assertEquals(expected2.scale(-1.0), sub_A.value());
assertEquals(expected2, result2.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testZeroStaticFunction() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var A = VariableMatrix.zero(2, 3)) {
for (var elem : A) {
assertEquals(0.0, elem.value());
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testOneStaticFunction() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var A = VariableMatrix.one(2, 3)) {
for (var elem : A) {
assertEquals(1.0, elem.value());
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testConstantStaticFunction() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var A = VariableMatrix.constant(2, 3, 2.0)) {
for (var elem : A) {
assertEquals(2.0, elem.value());
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testCwiseReduce() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var A = new VariableMatrix(new double[][] {{2.0, 3.0, 4.0}, {5.0, 6.0, 7.0}});
var B = new VariableMatrix(new double[][] {{8.0, 9.0, 10.0}, {11.0, 12.0, 13.0}});
var result = VariableMatrix.cwiseReduce(A, B, (Variable x, Variable y) -> x.times(y));
var expected = new SimpleMatrix(new double[][] {{16.0, 27.0, 40.0}, {55.0, 72.0, 91.0}});
assertEquals(expected, result.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testBlockFreeFunction() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
var A = new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
var B = new VariableMatrix(new double[][] {{7.0}, {8.0}});
var mat1 = VariableMatrix.block(new VariableMatrix[][] {{A, B}});
var expected1 = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0, 7.0}, {4.0, 5.0, 6.0, 8.0}});
assertEquals(2, mat1.rows());
assertEquals(4, mat1.cols());
assertEquals(expected1, mat1.value());
var C = new VariableMatrix(new double[][] {{9.0, 10.0, 11.0, 12.0}});
var mat2 = VariableMatrix.block(new VariableMatrix[][] {{A, B}, {C}});
var expected2 =
new SimpleMatrix(
new double[][] {{1.0, 2.0, 3.0, 7.0}, {4.0, 5.0, 6.0, 8.0}, {9.0, 10.0, 11.0, 12.0}});
assertEquals(3, mat2.rows());
assertEquals(4, mat2.cols());
assertEquals(expected2, mat2.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
private void checkSolve(VariableMatrix A, VariableMatrix B) {
try (var X = VariableMatrix.solve(A, B)) {
assertEquals(A.cols(), X.rows());
assertEquals(B.cols(), X.cols());
assertTrue(A.value().mult(X.value()).minus(B.value()).normF() < 1e-12);
}
}
@Test
void testSolveFreeFunction() {
assertEquals(0, Variable.totalNativeMemoryUsage());
// 1x1 special case
try (var pool = new VariablePool()) {
checkSolve(
new VariableMatrix(new double[][] {{2.0}}), new VariableMatrix(new double[][] {{5.0}}));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
// 2x2 special case
try (var pool = new VariablePool()) {
checkSolve(
new VariableMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}}),
new VariableMatrix(new double[][] {{5.0}, {6.0}}));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
// 3x3 special case
try (var pool = new VariablePool()) {
checkSolve(
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {-4.0, -5.0, 6.0}, {7.0, 8.0, 9.0}}),
new VariableMatrix(new double[][] {{10.0}, {11.0}, {12.0}}));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
// 4x4 special case
try (var pool = new VariablePool()) {
checkSolve(
new VariableMatrix(
new double[][] {
{1.0, 2.0, 3.0, -4.0},
{-5.0, 6.0, 7.0, 8.0},
{9.0, 10.0, 11.0, 12.0},
{13.0, 14.0, 15.0, 16.0}
}),
new VariableMatrix(new double[][] {{17.0}, {18.0}, {19.0}, {20.0}}));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
// 5x5 general case
try (var pool = new VariablePool()) {
checkSolve(
new VariableMatrix(
new double[][] {
{1.0, 2.0, 3.0, -4.0, 5.0},
{-5.0, 6.0, 7.0, 8.0, 9.0},
{9.0, 10.0, 11.0, 12.0, 13.0},
{13.0, 14.0, 15.0, 16.0, 17.0},
{17.0, 18.0, 19.0, 20.0, 21.0}
}),
new VariableMatrix(new double[][] {{21.0}, {22.0}, {23.0}, {24.0}, {25.0}}));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,57 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
class VariableTest {
@Test
void testDefaultConstructor() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var a = new Variable()) {
assertEquals(0.0, a.value());
assertEquals(ExpressionType.LINEAR, a.type());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testConstantConstructor() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var pool = new VariablePool()) {
// float
var a = new Variable(1.0);
assertEquals(1, a.value());
assertEquals(ExpressionType.CONSTANT, a.type());
// int
var b = new Variable(2);
assertEquals(2, b.value());
assertEquals(ExpressionType.CONSTANT, b.type());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSetValue() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var a = new Variable()) {
a.setValue(1.0);
assertEquals(1.0, a.value());
a.setValue(2.0);
assertEquals(2.0, a.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.function.BiFunction;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.linalg.MatBuilder;
import org.wpilib.math.linalg.Matrix;
import org.wpilib.math.numbers.N1;
@@ -95,6 +96,8 @@ class ArmFeedforwardTest {
@Test
void testCalculate() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double ks = 0.5;
final double kv = 1.5;
final double ka = 2;
@@ -110,10 +113,14 @@ class ArmFeedforwardTest {
calculateAndSimulate(armFF, ks, kv, ka, kg, Math.PI / 3, 1.0, 0.95, 0.020);
calculateAndSimulate(armFF, ks, kv, ka, kg, -Math.PI / 3, 1.0, 1.05, 0.020);
calculateAndSimulate(armFF, ks, kv, ka, kg, -Math.PI / 3, 1.0, 0.95, 0.020);
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testCalculateIllConditionedModel() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double ks = 0.39671;
final double kv = 2.7167;
final double ka = 1e-2;
@@ -129,10 +136,14 @@ class ArmFeedforwardTest {
assertEquals(
armFF.calculate(currentAngle, currentVelocity, nextVelocity),
ks + kv * currentVelocity + ka * averageAccel + kg * Math.cos(currentAngle));
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testCalculateIllConditionedGradient() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double ks = 0.39671;
final double kv = 2.7167;
final double ka = 0.50799;
@@ -140,6 +151,8 @@ class ArmFeedforwardTest {
final ArmFeedforward armFF = new ArmFeedforward(ks, kg, kv, ka);
calculateAndSimulate(armFF, ks, kv, ka, kg, 1.0, 0.02, 0.0, 0.02);
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test

View File

@@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.Variable;
class Ellipse2dTest {
private static final double kEpsilon = 1E-9;
@@ -56,6 +57,8 @@ class Ellipse2dTest {
@Test
void testDistance() {
assertEquals(0, Variable.totalNativeMemoryUsage());
var center = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(270.0));
var ellipse = new Ellipse2d(center, 1.0, 2.0);
@@ -70,10 +73,14 @@ class Ellipse2dTest {
var point4 = new Translation2d(-1.0, 2.5);
assertEquals(0.19210128384806818, ellipse.getDistance(point4), kEpsilon);
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNearest() {
assertEquals(0, Variable.totalNativeMemoryUsage());
var center = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(270.0));
var ellipse = new Ellipse2d(center, 1.0, 2.0);
@@ -100,6 +107,8 @@ class Ellipse2dTest {
assertAll(
() -> assertEquals(-0.8512799937611617, nearestPoint4.getX(), kEpsilon),
() -> assertEquals(2.378405333174535, nearestPoint4.getY(), kEpsilon));
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test

View File

@@ -0,0 +1,121 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.autodiff.Variable.pow;
import static org.wpilib.math.optimization.Constraints.bounds;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.optimization.Constraints.le;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.optimization.solver.ExitStatus;
class ArmOnElevatorProblemTest {
@Test
void testArmOnElevatorProblem() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final int N = 800;
final double ELEVATOR_START_HEIGHT = 1.0; // m
final double ELEVATOR_END_HEIGHT = 1.25; // m
final double ELEVATOR_MAX_VELOCITY = 1.0; // m/s
final double ELEVATOR_MAX_ACCELERATION = 2.0; // m/s²
final double ARM_LENGTH = 1.0; // m
final double ARM_START_ANGLE = 0.0; // rad
final double ARM_END_ANGLE = Math.PI; // rad
final double ARM_MAX_VELOCITY = 2.0 * Math.PI; // rad/s
final double ARM_MAX_ACCELERATION = 4.0 * Math.PI; // rad/s²
final double END_EFFECTOR_MAX_HEIGHT = 1.8; // m
final double TOTAL_TIME = 4.0;
final double dt = TOTAL_TIME / N;
try (var problem = new Problem()) {
var elevator = problem.decisionVariable(2, N + 1);
var elevator_accel = problem.decisionVariable(1, N);
var arm = problem.decisionVariable(2, N + 1);
var arm_accel = problem.decisionVariable(1, N);
for (int k = 0; k < N; ++k) {
// Elevator dynamics constraints
problem.subjectTo(
eq(
elevator.get(0, k + 1),
elevator
.get(0, k)
.plus(elevator.get(1, k).times(dt))
.plus(elevator_accel.get(0, k).times(0.5 * dt * dt))));
problem.subjectTo(
eq(
elevator.get(1, k + 1),
elevator.get(1, k).plus(elevator_accel.get(0, k).times(dt))));
// Arm dynamics constraints
problem.subjectTo(
eq(
arm.get(0, k + 1),
arm.get(0, k)
.plus(arm.get(1, k).times(dt))
.plus(arm_accel.get(0, k).times(0.5 * dt * dt))));
problem.subjectTo(eq(arm.get(1, k + 1), arm.get(1, k).plus(arm_accel.get(0, k).times(dt))));
}
// Elevator start and end conditions
problem.subjectTo(
eq(elevator.col(0), new VariableMatrix(new double[][] {{ELEVATOR_START_HEIGHT}, {0.0}})));
problem.subjectTo(
eq(elevator.col(N), new VariableMatrix(new double[][] {{ELEVATOR_END_HEIGHT}, {0.0}})));
// Arm start and end conditions
problem.subjectTo(
eq(arm.col(0), new VariableMatrix(new double[][] {{ARM_START_ANGLE}, {0.0}})));
problem.subjectTo(
eq(arm.col(N), new VariableMatrix(new double[][] {{ARM_END_ANGLE}, {0.0}})));
// Elevator velocity limits
problem.subjectTo(bounds(-ELEVATOR_MAX_VELOCITY, elevator.row(1), ELEVATOR_MAX_VELOCITY));
// Elevator acceleration limits
problem.subjectTo(
bounds(-ELEVATOR_MAX_ACCELERATION, elevator_accel, ELEVATOR_MAX_ACCELERATION));
// Arm velocity limits
problem.subjectTo(bounds(-ARM_MAX_VELOCITY, arm.row(1), ARM_MAX_VELOCITY));
// Arm acceleration limits
problem.subjectTo(bounds(-ARM_MAX_ACCELERATION, arm_accel, ARM_MAX_ACCELERATION));
// Height limit
var heights = elevator.row(0).plus(arm.row(0).cwiseMap(Variable::sin).times(ARM_LENGTH));
problem.subjectTo(le(heights, END_EFFECTOR_MAX_HEIGHT));
// Cost function
var J = new Variable(0.0);
for (int k = 0; k < N + 1; ++k) {
J =
J.plus(
pow(new Variable(ELEVATOR_END_HEIGHT).minus(elevator.get(0, k)), 2)
.plus(pow(new Variable(ARM_END_ANGLE).minus(arm.get(0, k)), 2)));
}
problem.minimize(J);
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,101 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.optimization.Constraints.bounds;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.ocp.DynamicsType;
import org.wpilib.math.optimization.ocp.TimestepMethod;
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
import org.wpilib.math.optimization.solver.ExitStatus;
import org.wpilib.math.util.MathUtil;
class CartPoleOCPTest {
@Test
void testCartPole() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double TOTAL_TIME = 5.0; // s
final double dt = 0.05; // s
final int N = (int) (TOTAL_TIME / dt);
final double u_max = 20.0; // N
final double d_max = 2.0; // m
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
try (var problem =
new OCP(
4,
1,
dt,
N,
CartPoleUtil::cartPoleDynamics,
DynamicsType.EXPLICIT_ODE,
TimestepMethod.VARIABLE_SINGLE,
TranscriptionMethod.DIRECT_COLLOCATION)) {
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
var X = problem.X();
// Initial guess
for (int k = 0; k < N + 1; ++k) {
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
}
// Initial conditions
problem.constrainInitialState(x_initial);
// Final conditions
problem.constrainFinalState(x_final);
// Cart position constraints
problem.forEachStep(
(x, u) -> {
problem.subjectTo(bounds(0.0, x.get(0), d_max));
});
// Input constraints
problem.setLowerInputBound(-u_max);
problem.setUpperInputBound(u_max);
// u = f_x
var U = problem.U();
// Minimize sum squared inputs
var J = new Variable(0.0);
for (int k = 0; k < N; ++k) {
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
}
problem.minimize(J);
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
// Verify initial state
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
// Verify final state
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,114 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.wpilib.math.autodiff.NumericalIntegration.rk4;
import static org.wpilib.math.optimization.Constraints.bounds;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.system.NumericalIntegration.rk4;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.solver.ExitStatus;
import org.wpilib.math.util.MathUtil;
class CartPoleProblemTest {
@Test
void testCartPoleProblem() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double TOTAL_TIME = 5.0; // s
final double dt = 0.05; // s
final int N = (int) (TOTAL_TIME / dt);
final double u_max = 20.0; // N
final double d_max = 2.0; // m
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
try (var problem = new Problem()) {
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
var X = problem.decisionVariable(4, N + 1);
// Initial guess
for (int k = 0; k < N + 1; ++k) {
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
}
// u = f_x
var U = problem.decisionVariable(1, N);
// Initial conditions
problem.subjectTo(eq(X.col(0), x_initial));
// Final conditions
problem.subjectTo(eq(X.col(N), x_final));
// Cart position constraints
problem.subjectTo(bounds(0.0, X.row(0), d_max));
// Input constraints
problem.subjectTo(bounds(-u_max, U, u_max));
// Dynamics constraints - RK4 integration
for (int k = 0; k < N; ++k) {
problem.subjectTo(
eq(X.col(k + 1), rk4(CartPoleUtil::cartPoleDynamics, X.col(k), U.col(k), dt)));
}
// Minimize sum squared inputs
var J = new Variable(0.0);
for (int k = 0; k < N; ++k) {
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
}
problem.minimize(J);
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
// Verify initial state
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
// Verify solution
for (int k = 0; k < N; ++k) {
// Cart position constraints
assertTrue(X.get(0, k).value() >= 0.0);
assertTrue(X.get(0, k).value() <= d_max);
// Input constraints
assertTrue(U.get(0, k).value() >= -u_max);
assertTrue(U.get(0, k).value() <= u_max);
// Dynamics constraints
var expected_x_k1 =
rk4(CartPoleUtil::cartPoleDynamics, X.col(k).value(), U.col(k).value(), dt);
var actual_x_k1 = X.col(k + 1).value();
for (int row = 0; row < actual_x_k1.getNumRows(); ++row) {
assertEquals(expected_x_k1.get(row), actual_x_k1.get(row), 1e-8);
}
}
// Verify final state
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,122 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.wpilib.math.autodiff.Variable.cos;
import static org.wpilib.math.autodiff.Variable.sin;
import static org.wpilib.math.autodiff.VariableMatrix.solve;
import org.ejml.simple.SimpleMatrix;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
// https://underactuated.mit.edu/acrobot.html#cart_pole
//
// θ is CCW+ measured from negative y-axis.
//
// q = [x, θ]ᵀ
// q̇ = [ẋ, θ̇]ᵀ
// u = f_x
//
// M(q)q̈ + C(q, q̇)q̇ = τ_g(q) + Bu
// M(q)q̈ = τ_g(q) C(q, q̇)q̇ + Bu
// q̈ = M⁻¹(q)(τ_g(q) C(q, q̇)q̇ + Bu)
//
// [ m_c + m_p m_p l cosθ]
// M(q) = [m_p l cosθ m_p l² ]
//
// [0 m_p lθ̇ sinθ]
// C(q, q̇) = [0 0 ]
//
// [ 0 ]
// τ_g(q) = [-m_p gl sinθ]
//
// [1]
// B = [0]
public final class CartPoleUtil {
private CartPoleUtil() {
// Utility class.
}
private static final double m_c = 5.0; // Cart mass (kg)
private static final double m_p = 0.5; // Pole mass (kg)
private static final double l = 0.5; // Pole length (m)
private static final double g = 9.806; // Acceleration due to gravity (m/s²)
public static SimpleMatrix cartPoleDynamics(SimpleMatrix x, SimpleMatrix u) {
var q = x.extractMatrix(0, 2, 0, 1);
var qdot = x.extractMatrix(2, 4, 0, 1);
var theta = q.get(1, 0);
var thetadot = qdot.get(1, 0);
// [ m_c + m_p m_p l cosθ]
// M(q) = [m_p l cosθ m_p l² ]
var M =
new SimpleMatrix(
new double[][] {
{m_c + m_p, m_p * l * Math.cos(theta)},
{m_p * l * Math.cos(theta), m_p * Math.pow(l, 2)}
});
// [0 m_p lθ̇ sinθ]
// C(q, q̇) = [0 0 ]
var C = new SimpleMatrix(new double[][] {{0, -m_p * l * thetadot * Math.sin(theta)}, {0, 0}});
// [ 0 ]
// τ_g(q) = [-m_p gl sinθ]
var tau_g = new SimpleMatrix(new double[][] {{0}, {-m_p * g * l * Math.sin(theta)}});
// [1]
// B = [0]
final var B = new SimpleMatrix(new double[][] {{1}, {0}});
// q̈ = M⁻¹(q)(τ_g(q) C(q, q̇)q̇ + Bu)
var qddot = new SimpleMatrix(4, 1);
qddot.insertIntoThis(0, 0, qdot);
qddot.insertIntoThis(2, 0, M.solve(tau_g.minus(C.mult(qdot)).plus(B.mult(u))));
return qddot;
}
public static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) {
var q = x.segment(0, 2);
var qdot = x.segment(2, 2);
var theta = q.get(1);
var thetadot = qdot.get(1);
// [ m_c + m_p m_p l cosθ]
// M(q) = [m_p l cosθ m_p l² ]
var M =
new VariableMatrix(
new Variable[][] {
{new Variable(m_c + m_p), cos(theta).times(m_p * l)},
{cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))}
});
// [0 m_p lθ̇ sinθ]
// C(q, q̇) = [0 0 ]
var C =
new VariableMatrix(
new Variable[][] {
{new Variable(0), thetadot.times(-m_p * l).times(sin(theta))},
{new Variable(0), new Variable(0)}
});
// [ 0 ]
// τ_g(q) = [-m_p gl sinθ]
var tau_g =
new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}});
// [1]
// B = [0]
var B = new VariableMatrix(new double[][] {{1}, {0}});
// q̈ = M⁻¹(q)(τ_g(q) C(q, q̇)q̇ + Bu)
var qddot = new VariableMatrix(4);
qddot.segment(0, 2).set(qdot);
qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u))));
return qddot;
}
}

View File

@@ -0,0 +1,92 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.wpilib.math.optimization.Constraints.ge;
import static org.wpilib.math.optimization.Constraints.le;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
/**
* This class computes the optimal current allocation for a list of subsystems given a list of their
* desired currents and current tolerances that determine which subsystem gets less current if the
* current budget is exceeded. Subsystems with a smaller tolerance are given higher priority.
*/
public class CurrentManager implements AutoCloseable {
private final Problem m_problem = new Problem();
private final VariableMatrix m_desiredCurrents;
private final VariableMatrix m_allocatedCurrents;
/**
* Constructs a CurrentManager.
*
* @param currentTolerances The relative current tolerance of each subsystem.
* @param maxCurrent The current budget to allocate between subsystems.
*/
public CurrentManager(double[] currentTolerances, double maxCurrent) {
this.m_desiredCurrents = new VariableMatrix(currentTolerances.length, 1);
this.m_allocatedCurrents = m_problem.decisionVariable(currentTolerances.length);
// Ensure m_desired_currents contains initialized Variables
for (int row = 0; row < m_desiredCurrents.rows(); ++row) {
// Don't initialize to 0 or 1, because those will get folded by Sleipnir
m_desiredCurrents.get(row).setValue(Double.POSITIVE_INFINITY);
}
var J = new Variable(0.0);
var currentSum = new Variable(0.0);
for (int i = 0; i < currentTolerances.length; ++i) {
// The weight is 1/tolᵢ² where tolᵢ is the tolerance between the desired
// and allocated current for subsystem i
var error = m_desiredCurrents.get(i).minus(m_allocatedCurrents.get(i));
J = J.plus(error.times(error).div(currentTolerances[i] * currentTolerances[i]));
currentSum = currentSum.plus(m_allocatedCurrents.get(i));
// Currents must be nonnegative
m_problem.subjectTo(ge(m_allocatedCurrents.get(i), 0.0));
}
m_problem.minimize(J);
// Keep total current below maximum
m_problem.subjectTo(le(currentSum, maxCurrent));
}
@Override
public void close() {
m_problem.close();
}
/**
* Returns the optimal current allocation for a list of subsystems given a list of their desired
* currents and current tolerances that determine which subsystem gets less current if the current
* budget is exceeded. Subsystems with a smaller tolerance are given higher priority.
*
* @param desiredCurrents The desired current for each subsystem.
* @throws RuntimeException if the number of desired currents doesn't equal the number of
* tolerances passed in the constructor.
*/
public double[] calculate(double[] desiredCurrents) {
if (m_desiredCurrents.rows() != desiredCurrents.length) {
throw new RuntimeException(
"Number of desired currents must equal the number of tolerances passed in the "
+ "constructor.");
}
for (int i = 0; i < desiredCurrents.length; ++i) {
m_desiredCurrents.get(i).setValue(desiredCurrents[i]);
}
m_problem.solve();
var result = new double[desiredCurrents.length];
for (int i = 0; i < desiredCurrents.length; ++i) {
result[i] = Math.max(m_allocatedCurrents.value(i), 0.0);
}
return result;
}
}

View File

@@ -0,0 +1,62 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.Variable;
class CurrentManagerTest {
@Test
void testEnoughCurrent() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var manager = new CurrentManager(new double[] {1.0, 5.0, 10.0, 5.0}, 40.0)) {
var currents = manager.calculate(new double[] {25.0, 10.0, 5.0, 0.0});
assertEquals(25.0, currents[0], 1e-3);
assertEquals(10.0, currents[1], 1e-3);
assertEquals(5.0, currents[2], 1e-3);
assertEquals(0.0, currents[3], 1e-3);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNotEnoughCurrent() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var manager = new CurrentManager(new double[] {1.0, 5.0, 10.0, 5.0}, 40.0)) {
var currents = manager.calculate(new double[] {30.0, 10.0, 5.0, 0.0});
// Expected values are from the following program:
//
// #!/usr/bin/env python3
//
// from scipy.optimize import minimize
//
// r = [30.0, 10.0, 5.0, 0.0]
// q = [1.0, 5.0, 10.0, 5.0]
//
// result = minimize(
// lambda x: sum((r[i] - x[i]) ** 2 / q[i] ** 2 for i in range(4)),
// [0.0, 0.0, 0.0, 0.0],
// constraints=[
// {"type": "ineq", "fun": lambda x: x},
// {"type": "ineq", "fun": lambda x: 40.0 - sum(x)},
// ],
// )
// print(result.x)
assertEquals(29.960, currents[0], 1e-3);
assertEquals(9.008, currents[1], 1e-3);
assertEquals(1.032, currents[2], 1e-3);
assertEquals(0.0, currents[3], 1e-3);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.MatrixAssertions.assertEquals;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.Variable;
class DecisionVariableTest {
@Test
void testScalarInitAssign() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
// Scalar zero init
var x = problem.decisionVariable();
assertEquals(0.0, x.value());
// Scalar assignment
x.setValue(1.0);
assertEquals(1.0, x.value());
x.setValue(2.0);
assertEquals(2.0, x.value());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testVectorInitAssign() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
// Vector zero init
var y = problem.decisionVariable(2);
assertEquals(0.0, y.value(0));
assertEquals(0.0, y.value(1));
// Vector assignment
y.get(0).setValue(1.0);
y.get(1).setValue(2.0);
assertEquals(1.0, y.value(0));
assertEquals(2.0, y.value(1));
y.get(0).setValue(3.0);
y.get(1).setValue(4.0);
assertEquals(3.0, y.value(0));
assertEquals(4.0, y.value(1));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testDynamicMatrixInitAssign() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
// Matrix zero init
var z = problem.decisionVariable(3, 2);
assertEquals(0.0, z.value(0, 0));
assertEquals(0.0, z.value(0, 1));
assertEquals(0.0, z.value(1, 0));
assertEquals(0.0, z.value(1, 1));
assertEquals(0.0, z.value(2, 0));
assertEquals(0.0, z.value(2, 1));
// Matrix assignment; element comparison
z.setValue(new double[][] {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
assertEquals(1.0, z.value(0, 0));
assertEquals(2.0, z.value(0, 1));
assertEquals(3.0, z.value(1, 0));
assertEquals(4.0, z.value(1, 1));
assertEquals(5.0, z.value(2, 0));
assertEquals(6.0, z.value(2, 1));
// Matrix assignment; matrix comparison
{
var expected = new SimpleMatrix(new double[][] {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}});
z.setValue(expected);
assertEquals(expected, z.value());
}
// Block assignment
{
var expected_block = new double[][] {{1.0}, {1.0}};
z.block(0, 0, 2, 1).setValue(expected_block);
var expected_result =
new SimpleMatrix(new double[][] {{1.0, 8.0}, {1.0, 10.0}, {11.0, 12.0}});
assertEquals(expected_result, z.value());
}
// Segment assignment
{
var expected_block = new double[][] {{1.0}, {1.0}};
z.block(0, 0, 3, 1).segment(0, 2).setValue(expected_block);
var expected_result =
new SimpleMatrix(new double[][] {{1.0, 8.0}, {1.0, 10.0}, {11.0, 12.0}});
assertEquals(expected_result, z.value());
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testSymmetricMatrix() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
// Matrix zero init
var A = problem.symmetricDecisionVariable(2);
assertEquals(0.0, A.value(0, 0));
assertEquals(0.0, A.value(0, 1));
assertEquals(0.0, A.value(1, 0));
assertEquals(0.0, A.value(1, 1));
// Assign to lower triangle
A.get(0, 0).setValue(1.0);
A.get(1, 0).setValue(2.0);
A.get(1, 1).setValue(3.0);
// Confirm whole matrix changed
assertEquals(1.0, A.value(0, 0));
assertEquals(2.0, A.value(0, 1));
assertEquals(2.0, A.value(1, 0));
assertEquals(3.0, A.value(1, 1));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,85 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.ocp.DynamicsType;
import org.wpilib.math.optimization.ocp.TimestepMethod;
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
import org.wpilib.math.optimization.solver.ExitStatus;
import org.wpilib.math.optimization.solver.Options;
class DifferentialDriveOCPTest {
@Test
void testDifferentialDrive() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final int N = 50;
final double minTimestep = 0.05; // s
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {1.0}, {0.0}, {0.0}, {0.0}});
final var u_min = new SimpleMatrix(new double[][] {{-12.0}, {-12.0}});
final var u_max = new SimpleMatrix(new double[][] {{12.0}, {12.0}});
try (var problem =
new OCP(
5,
2,
minTimestep,
N,
DifferentialDriveUtil::differentialDriveDynamics,
DynamicsType.EXPLICIT_ODE,
TimestepMethod.VARIABLE_SINGLE,
TranscriptionMethod.DIRECT_TRANSCRIPTION)) {
// Seed the min time formulation with lerp between waypoints
for (int i = 0; i < N + 1; ++i) {
problem.X().get(0, i).setValue((double) i / (N + 1));
problem.X().get(1, i).setValue((double) i / (N + 1));
}
problem.constrainInitialState(x_initial);
problem.constrainFinalState(x_final);
problem.setLowerInputBound(u_min);
problem.setUpperInputBound(u_max);
problem.setMinTimestep(minTimestep);
problem.setMaxTimestep(3.0);
// Set up cost
problem.minimize(problem.dt().times(SimpleMatrix.ones(N + 1, 1)));
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve(new Options().withMaxIterations(1000)));
var X = problem.X();
// Verify initial state
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
assertEquals(x_initial.get(4), X.value(4, 0), 1e-8);
// Verify final state
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
assertEquals(x_final.get(4), X.value(4, N), 1e-8);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,116 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.wpilib.math.autodiff.NumericalIntegration.rk4;
import static org.wpilib.math.optimization.Constraints.bounds;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.system.NumericalIntegration.rk4;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.solver.ExitStatus;
import org.wpilib.math.util.MathUtil;
class DifferentialDriveProblemTest {
@Test
void testDifferentialDriveProblem() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double TOTAL_TIME = 5.0; // s
final double dt = 0.05; // s
final int N = (int) (TOTAL_TIME / dt);
final double u_max = 12.0; // V
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {1.0}, {0.0}, {0.0}, {0.0}});
try (var problem = new Problem()) {
// x = [x, y, heading, left velocity, right velocity]ᵀ
var X = problem.decisionVariable(5, N + 1);
// Initial guess
for (int k = 0; k < N; ++k) {
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
}
// u = [left voltage, right voltage]ᵀ
var U = problem.decisionVariable(2, N);
// Initial conditions
problem.subjectTo(eq(X.col(0), x_initial));
// Final conditions
problem.subjectTo(eq(X.col(N), x_final));
// Input constraints
problem.subjectTo(bounds(-u_max, U, u_max));
// Dynamics constraints - RK4 integration
for (int k = 0; k < N; ++k) {
problem.subjectTo(
eq(
X.col(k + 1),
rk4(DifferentialDriveUtil::differentialDriveDynamics, X.col(k), U.col(k), dt)));
}
// Minimize sum squared states and inputs
var J = new Variable(0.0);
for (int k = 0; k < N; ++k) {
J = J.plus(X.col(k).T().times(X.col(k)).plus(U.col(k).T().times(U.col(k))).get(0));
}
problem.minimize(J);
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
// Verify initial state
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
assertEquals(x_initial.get(4), X.value(4, 0), 1e-8);
// Verify solution
var x = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
for (int k = 0; k < N; ++k) {
// Input constraints
assertTrue(U.get(0, k).value() >= -u_max);
assertTrue(U.get(0, k).value() <= u_max);
assertTrue(U.get(1, k).value() >= -u_max);
assertTrue(U.get(1, k).value() <= u_max);
// Verify state
assertEquals(x.get(0), X.value(0, k), 1e-8);
assertEquals(x.get(1), X.value(1, k), 1e-8);
assertEquals(x.get(2), X.value(2, k), 1e-8);
assertEquals(x.get(3), X.value(3, k), 1e-8);
assertEquals(x.get(4), X.value(4, k), 1e-8);
// Project state forward
var u = U.col(k).value();
x = rk4(DifferentialDriveUtil::differentialDriveDynamics, x, u, dt);
}
// Verify final state
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
assertEquals(x_final.get(4), X.value(4, N), 1e-8);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.wpilib.math.autodiff.Variable.cos;
import static org.wpilib.math.autodiff.Variable.sin;
import org.ejml.simple.SimpleMatrix;
import org.wpilib.math.autodiff.VariableMatrix;
// x = [x, y, heading, left velocity, right velocity]ᵀ
// u = [left voltage, right voltage]ᵀ
public final class DifferentialDriveUtil {
private DifferentialDriveUtil() {
// Utility class.
}
private static final double trackwidth = 0.699; // m
private static final double Kv_linear = 3.02; // V/(m/s)
private static final double Ka_linear = 0.642; // V/(m/s²)
private static final double Kv_angular = 1.382; // V/(m/s)
private static final double Ka_angular = 0.08495; // V/(m/s²)
private static final double A1 = -(Kv_linear / Ka_linear + Kv_angular / Ka_angular) / 2.0;
private static final double A2 = -(Kv_linear / Ka_linear - Kv_angular / Ka_angular) / 2.0;
private static final double B1 = 0.5 / Ka_linear + 0.5 / Ka_angular;
private static final double B2 = 0.5 / Ka_linear - 0.5 / Ka_angular;
private static final SimpleMatrix A = new SimpleMatrix(new double[][] {{A1, A2}, {A2, A1}});
private static final SimpleMatrix B = new SimpleMatrix(new double[][] {{B1, B2}, {B2, B1}});
public static SimpleMatrix differentialDriveDynamics(SimpleMatrix x, SimpleMatrix u) {
var xdot = new SimpleMatrix(5, 1);
var v = (x.get(3, 0) + x.get(4, 0)) / 2.0;
xdot.set(0, 0, v * Math.cos(x.get(2, 0)));
xdot.set(1, 0, v * Math.sin(x.get(2, 0)));
xdot.set(2, 0, (x.get(4, 0) - x.get(3, 0)) / trackwidth);
xdot.insertIntoThis(3, 0, A.mult(x.extractMatrix(3, 5, 0, 1)).plus(B.mult(u)));
return xdot;
}
public static VariableMatrix differentialDriveDynamics(VariableMatrix x, VariableMatrix u) {
var xdot = new VariableMatrix(5);
var v = x.get(3).plus(x.get(4)).div(2.0);
xdot.set(0, 0, v.times(cos(x.get(2))));
xdot.set(1, 0, v.times(sin(x.get(2))));
xdot.set(2, 0, x.get(4).minus(x.get(3)).div(trackwidth));
xdot.segment(3, 2)
.set(new VariableMatrix(A).times(x.segment(3, 2)).plus(new VariableMatrix(B).times(u)));
return xdot;
}
}

View File

@@ -0,0 +1,127 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.wpilib.math.autodiff.Variable.pow;
import static org.wpilib.math.optimization.Constraints.bounds;
import static org.wpilib.math.optimization.Constraints.eq;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.optimization.solver.ExitStatus;
class DoubleIntegratorProblemTest {
@Test
void testDoubleIntegratorProblem() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double TOTAL_TIME = 3.5; // s
final double dt = 0.005; // s
final int N = (int) (TOTAL_TIME / dt);
final double r = 2.0; // m
try (var problem = new Problem()) {
// 2x1 state vector with N + 1 timesteps (includes last state)
var X = problem.decisionVariable(2, N + 1);
// 1x1 input vector with N timesteps (input at last state doesn't matter)
var U = problem.decisionVariable(1, N);
// Kinematics constraint assuming constant acceleration between timesteps
for (int k = 0; k < N; ++k) {
final double t = dt;
var p_k1 = X.get(0, k + 1);
var v_k1 = X.get(1, k + 1);
var p_k = X.get(0, k);
var v_k = X.get(1, k);
var a_k = U.get(0, k);
// pₖ₊₁ = pₖ + vₖt + 1/2aₖt²
problem.subjectTo(eq(p_k1, p_k.plus(v_k.times(t)).plus(a_k.times(0.5 * t * t))));
// vₖ₊₁ = vₖ + aₖt
problem.subjectTo(eq(v_k1, v_k.plus(a_k.times(t))));
}
// Start and end at rest
problem.subjectTo(eq(X.col(0), new VariableMatrix(new double[][] {{0.0}, {0.0}})));
problem.subjectTo(eq(X.col(N), new VariableMatrix(new double[][] {{r}, {0.0}})));
// Limit velocity
problem.subjectTo(bounds(-1, X.row(1), 1));
// Limit acceleration
problem.subjectTo(bounds(-1, U, 1));
// Cost function - minimize position error
var J = new Variable(0.0);
for (int k = 0; k < N + 1; ++k) {
J = J.plus(pow(new Variable(r).minus(X.get(0, k)), 2));
}
problem.minimize(J);
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
var A = new SimpleMatrix(new double[][] {{1.0, dt}, {0.0, 1.0}});
var B = new SimpleMatrix(new double[][] {{0.5 * dt * dt}, {dt}});
// Verify initial state
assertEquals(0.0, X.value(0, 0), 1e-8);
assertEquals(0.0, X.value(1, 0), 1e-8);
// Verify solution
var x = new SimpleMatrix(new double[][] {{0.0}, {0.0}});
var u = new SimpleMatrix(new double[][] {{0.0}});
for (int k = 0; k < N; ++k) {
// Verify state
assertEquals(x.get(0), X.value(0, k), 1e-2);
assertEquals(x.get(1), X.value(1, k), 1e-2);
// Determine expected input for this timestep
if (k * dt < 1.0) {
// Accelerate
u.set(0, 0, 1.0);
} else if (k * dt < 2.05) {
// Maintain speed
u.set(0, 0, 0.0);
} else if (k * dt < 3.275) {
// Decelerate
u.set(0, 0, -1.0);
} else {
// Accelerate
u.set(0, 0, 1.0);
}
// Verify input
if (k > 0 && k < N - 1 && Math.abs(U.value(0, k - 1) - U.value(0, k + 1)) >= 1.0 - 1e-2) {
// If control input is transitioning between -1, 0, or 1, ensure it's within (-1, 1)
assertTrue(U.value(0, k) >= -1.0);
assertTrue(U.value(0, k) <= 1.0);
} else {
assertEquals(u.get(0), U.value(0, k), 1e-4);
}
// Project state forward
x = A.mult(x).plus(B.mult(u));
}
// Verify final state
assertEquals(r, X.value(0, N), 1e-8);
assertEquals(0.0, X.value(1, N), 1e-8);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,178 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.function.BiFunction;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.optimization.ocp.DynamicsType;
import org.wpilib.math.optimization.ocp.TimestepMethod;
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
import org.wpilib.math.optimization.solver.ExitStatus;
class FlywheelOCPTest {
private boolean near(double expected, double actual, double tolerance) {
return Math.abs(expected - actual) < tolerance;
}
void flywheelTest(
double A,
double B,
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
DynamicsType dynamicsType,
TranscriptionMethod transcriptionMethod) {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double TOTAL_TIME = 5.0; // s
final double dt = 0.005; // s
final int N = (int) (TOTAL_TIME / dt);
// Flywheel model:
// States: [velocity]
// Inputs: [voltage]
final double A_discrete = Math.exp(A * dt);
final double B_discrete = (1.0 - A_discrete) * B;
final double r = 10.0;
try (var problem =
new OCP(1, 1, dt, N, f, dynamicsType, TimestepMethod.FIXED, transcriptionMethod)) {
problem.constrainInitialState(0.0);
problem.setUpperInputBound(12);
problem.setLowerInputBound(-12);
// Set up cost
var r_mat = new VariableMatrix(SimpleMatrix.filled(1, N + 1, r));
problem.minimize(r_mat.minus(problem.X()).times(r_mat.minus(problem.X()).T()));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
// Voltage for steady-state velocity:
//
// rₖ₊₁ = Arₖ + Buₖ
// uₖ = B⁺(rₖ₊₁ Arₖ)
// uₖ = B⁺(rₖ Arₖ)
// uₖ = B⁺(I A)rₖ
double u_ss = 1.0 / B_discrete * (1.0 - A_discrete) * r;
// Verify initial state
assertEquals(0.0, problem.X().value(0, 0), 1e-8);
// Verify solution
double x = 0.0;
double u;
for (int k = 0; k < N; ++k) {
// Verify state
assertEquals(x, problem.X().value(0, k), 1e-2);
// Determine expected input for this timestep
double error = r - x;
if (error > 1e-2) {
// Max control input until the reference is reached
u = 12.0;
} else {
// Maintain speed
u = u_ss;
}
// Verify input
if (k > 0
&& k < N - 1
&& near(12.0, problem.U().value(0, k - 1), 1e-2)
&& near(u_ss, problem.U().value(0, k + 1), 1e-2)) {
// If control input is transitioning between 12 and u_ss, ensure it's
// within (u_ss, 12)
assertTrue(problem.U().value(0, k) >= u_ss);
assertTrue(problem.U().value(0, k) <= 12.0);
} else {
if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) {
// The tolerance is large because the trajectory is represented by a
// spline, and splines chatter when transitioning quickly between
// steady-states.
assertEquals(u, problem.U().value(0, k), 2.0);
} else {
assertEquals(u, problem.U().value(0, k), 1e-4);
}
}
// Project state forward
x = A_discrete * x + B_discrete * u;
}
// Verify final state
assertEquals(r, problem.X().value(0, N), 2e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
static final double A = -1.0;
static final double B = 1.0;
static final double dt = 0.005; // s
static final double A_discrete = Math.exp(A * dt);
static final double B_discrete = (1.0 - A_discrete) * B;
private static VariableMatrix f_ode(VariableMatrix x, VariableMatrix u) {
return new VariableMatrix(new double[][] {{A}})
.times(x)
.plus(new VariableMatrix(new double[][] {{B}}).times(u));
}
private static VariableMatrix f_discrete(VariableMatrix x, VariableMatrix u) {
return new VariableMatrix(new double[][] {{A_discrete}})
.times(x)
.plus(new VariableMatrix(new double[][] {{B_discrete}}).times(u));
}
@Test
void testFlywheelExplicit() {
flywheelTest(
A,
B,
FlywheelOCPTest::f_ode,
DynamicsType.EXPLICIT_ODE,
TranscriptionMethod.DIRECT_COLLOCATION);
flywheelTest(
A,
B,
FlywheelOCPTest::f_ode,
DynamicsType.EXPLICIT_ODE,
TranscriptionMethod.DIRECT_TRANSCRIPTION);
flywheelTest(
A,
B,
FlywheelOCPTest::f_ode,
DynamicsType.EXPLICIT_ODE,
TranscriptionMethod.SINGLE_SHOOTING);
}
@Test
void testFlywheelDiscrete() {
flywheelTest(
A,
B,
FlywheelOCPTest::f_discrete,
DynamicsType.DISCRETE,
TranscriptionMethod.DIRECT_TRANSCRIPTION);
flywheelTest(
A,
B,
FlywheelOCPTest::f_discrete,
DynamicsType.DISCRETE,
TranscriptionMethod.SINGLE_SHOOTING);
}
}

View File

@@ -0,0 +1,120 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.wpilib.math.optimization.Constraints.bounds;
import static org.wpilib.math.optimization.Constraints.eq;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.optimization.solver.ExitStatus;
class FlywheelProblemTest {
private boolean near(double expected, double actual, double tolerance) {
return Math.abs(expected - actual) < tolerance;
}
@Test
void testFlywheelProblem() {
assertEquals(0, Variable.totalNativeMemoryUsage());
final double TOTAL_TIME = 5.0; // s
final double dt = 0.005; // s
final int N = (int) (TOTAL_TIME / dt);
// Flywheel model:
// States: [velocity]
// Inputs: [voltage]
double A = Math.exp(-dt);
double B = 1.0 - Math.exp(-dt);
try (var problem = new Problem()) {
var X = problem.decisionVariable(1, N + 1);
var U = problem.decisionVariable(1, N);
// Dynamics constraint
for (int k = 0; k < N; ++k) {
problem.subjectTo(
eq(
X.col(k + 1),
new Variable(A)
.times(X.col(k).get(0))
.plus(new Variable(B).times(U.col(k).get(0)))));
}
// State and input constraints
problem.subjectTo(eq(X.col(0), 0.0));
problem.subjectTo(bounds(-12, U, 12));
// Cost function - minimize error
final var r = new VariableMatrix(new double[][] {{10.0}});
var J = new Variable(0.0);
for (int k = 0; k < N + 1; ++k) {
J = J.plus(r.minus(X.col(k)).T().times(r.minus(X.col(k))).get(0));
}
problem.minimize(J);
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
// Voltage for steady-state velocity:
//
// rₖ₊₁ = Arₖ + Buₖ
// uₖ = B⁺(rₖ₊₁ Arₖ)
// uₖ = B⁺(rₖ Arₖ)
// uₖ = B⁺(I A)rₖ
double u_ss = 1.0 / B * (1.0 - A) * r.value(0);
// Verify initial state
assertEquals(0.0, X.value(0, 0), 1e-8);
// Verify solution
double x = 0.0;
double u;
for (int k = 0; k < N; ++k) {
// Verify state
assertEquals(x, X.value(0, k), 1e-2);
// Determine expected input for this timestep
double error = r.value(0) - x;
if (error > 1e-2) {
// Max control input until the reference is reached
u = 12.0;
} else {
// Maintain speed
u = u_ss;
}
// Verify input
if (k > 0
&& k < N - 1
&& near(12.0, U.value(0, k - 1), 1e-2)
&& near(u_ss, U.value(0, k + 1), 1e-2)) {
// If control input is transitioning between 12 and u_ss, ensure it's
// within (u_ss, 12)
assertTrue(U.value(0, k) >= u_ss);
assertTrue(U.value(0, k) <= 12.0);
} else {
assertEquals(u, U.value(0, k), 1e-4);
}
// Project state forward
x = A * x + B * u;
}
// Verify final state
assertEquals(r.value(0), X.value(0, N), 2e-7);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,72 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.optimization.Constraints.ge;
import static org.wpilib.math.optimization.Constraints.le;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.solver.ExitStatus;
class LinearProblemTest {
@Test
void testMaximize() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
x.setValue(1.0);
y.setValue(1.0);
problem.maximize(x.times(50).plus(y.times(40)));
problem.subjectTo(le(x.plus(y.times(1.5)), 750));
problem.subjectTo(le(x.times(2).plus(y.times(3)), 1500));
problem.subjectTo(le(x.times(2).plus(y), 1000));
problem.subjectTo(ge(x, 0));
problem.subjectTo(ge(y, 0));
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(375.0, x.value(), 1e-6);
assertEquals(250.0, y.value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testFreeVariable() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable(2);
x.get(0).setValue(1.0);
x.get(1).setValue(2.0);
problem.subjectTo(eq(x.get(0), 0));
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(0.0, x.get(0).value(), 1e-6);
assertEquals(2.0, x.get(1).value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,212 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.wpilib.math.DoubleRange.range;
import static org.wpilib.math.autodiff.Variable.hypot;
import static org.wpilib.math.autodiff.Variable.pow;
import static org.wpilib.math.autodiff.Variable.sqrt;
import static org.wpilib.math.optimization.Constraints.bounds;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.optimization.Constraints.ge;
import static org.wpilib.math.optimization.Constraints.le;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.solver.ExitStatus;
class NonlinearProblemTest {
@Test
void testQuartic() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
x.setValue(20.0);
problem.minimize(pow(x, 4));
problem.subjectTo(ge(x, 1));
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(1.0, x.value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
private boolean near(double expected, double actual, double tolerance) {
return Math.abs(expected - actual) < tolerance;
}
@Test
void testRosenbrockWithCubicAndLineConstraint() {
// https://en.wikipedia.org/wiki/Test_functions_for_optimization#Test_functions_for_constrained_optimization
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
problem.minimize(
pow(y.minus(pow(x, 2)), 2).times(100).plus(pow(new Variable(1).minus(x), 2)));
problem.subjectTo(ge(y, pow(x.minus(1), 3).plus(1)));
problem.subjectTo(le(y, x.unaryMinus().plus(2)));
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType());
for (var x0 : range(-1.5, 1.5, 0.1)) {
for (var y0 : range(-0.5, 2.5, 0.1)) {
x.setValue(x0);
y.setValue(y0);
assertEquals(ExitStatus.SUCCESS, problem.solve());
// Local minimum at (0.0, 0.0)
// Global minimum at (1.0, 1.0)
assertTrue(near(0.0, x.value(), 1e-2) || near(1.0, x.value(), 1e-2));
assertTrue(near(0.0, y.value(), 1e-2) || near(1.0, y.value(), 1e-2));
}
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testRosenbrockWithDiskConstraint() {
// https://en.wikipedia.org/wiki/Test_functions_for_optimization#Test_functions_for_constrained_optimization
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
problem.minimize(
pow(new Variable(1).minus(x), 2).plus(pow(y.minus(pow(x, 2)), 2).times(100)));
problem.subjectTo(le(pow(x, 2).plus(pow(y, 2)), 2));
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.QUADRATIC, problem.inequalityConstraintType());
for (var x0 : range(-1.5, 1.5, 0.1)) {
for (var y0 : range(-1.5, 1.5, 0.1)) {
x.setValue(x0);
y.setValue(y0);
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(1.0, x.value(), 1e-3);
assertEquals(1.0, y.value(), 1e-3);
}
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testMinimum2DDistanceWithLinearConstraint() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
x.setValue(20.0);
y.setValue(50.0);
problem.minimize(sqrt(x.times(x).plus(y.times(y))));
problem.subjectTo(eq(y, x.unaryMinus().plus(5.0)));
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(2.5, x.value(), 1e-2);
assertEquals(2.5, y.value(), 1e-2);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testConflictingBounds() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
problem.minimize(hypot(x, y));
problem.subjectTo(le(hypot(x, y), 1));
problem.subjectTo(bounds(0.5, x, -0.5));
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.GLOBALLY_INFEASIBLE, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testWachterAndBieglerLineSearchFailure() {
// See example 19.2 of [1]
//
// [1] Nocedal, J. and Wright, S. "Numerical Optimization", 2nd. ed., Ch. 19. Springer, 2006.
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var s1 = problem.decisionVariable();
var s2 = problem.decisionVariable();
x.setValue(-2);
s1.setValue(3);
s2.setValue(1);
problem.minimize(x);
problem.subjectTo(eq(pow(x, 2).minus(s1).minus(1), 0));
problem.subjectTo(eq(x.minus(s2).minus(0.5), 0));
problem.subjectTo(ge(s1, 0));
problem.subjectTo(ge(s2, 0));
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
assertEquals(ExpressionType.QUADRATIC, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(1.0, x.value(), 1e-6);
assertEquals(0.0, s1.value(), 1e-6);
assertEquals(0.5, s2.value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,16 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test;
public class ProblemJNITest {
@Test
public void testLink() {
assertDoesNotThrow(ProblemJNI::forceLoad);
}
}

View File

@@ -0,0 +1,194 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.optimization.Constraints.ge;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.solver.ExitStatus;
class QuadraticProblemTest {
@Test
void testUnconstrained1D() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
x.setValue(2.0);
problem.minimize(x.times(x).minus(x.times(6.0)));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(3.0, x.value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testUnconstrained2D() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
x.setValue(1.0);
y.setValue(2.0);
problem.minimize(x.times(x).plus(y.times(y)));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(0.0, x.value(), 1e-6);
assertEquals(0.0, y.value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable(2);
x.get(0).setValue(1.0);
x.get(1).setValue(2.0);
problem.minimize(x.T().times(x).get(0));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(0.0, x.value(0), 1e-6);
assertEquals(0.0, x.value(1), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testEqualityConstrained() {
// Maximize xy subject to x + 3y = 36.
//
// Maximize f(x,y) = xy
// subject to g(x,y) = x + 3y - 36 = 0
//
// value func constraint
// | |
// v v
// L(x,y,λ) = f(x,y) - λg(x,y)
// L(x,y,λ) = xy - λ(x + 3y - 36)
// L(x,y,λ) = xy - xλ - 3yλ + 36λ
//
// ∇_x,y,λ L(x,y,λ) = 0
//
// ∂L/∂x = y - λ
// ∂L/∂y = x - 3λ
// ∂L/∂λ = -x - 3y + 36
//
// 0x + 1y - 1λ = 0
// 1x + 0y - 3λ = 0
// -1x - 3y + 0λ + 36 = 0
//
// [ 0 1 -1][x] [ 0]
// [ 1 0 -3][y] = [ 0]
// [-1 -3 0][λ] [-36]
//
// Solve with:
//
// ```python
// np.linalg.solve(
// np.array([[0,1,-1],
// [1,0,-3],
// [-1,-3,0]]),
// np.array([[0], [0], [-36]]))
// ```
//
// [x] [18]
// [y] = [ 6]
// [λ] [ 6]
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
problem.maximize(x.times(y));
problem.subjectTo(eq(x.plus(y.times(3)), 36));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(18.0, x.value(), 1e-5);
assertEquals(6.0, y.value(), 1e-5);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable(2);
x.get(0).setValue(1.0);
x.get(1).setValue(2.0);
problem.minimize(x.T().times(x).get(0));
problem.subjectTo(eq(x, new double[][] {{3.0}, {3.0}}));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(3.0, x.value(0), 1e-5);
assertEquals(3.0, x.value(1), 1e-5);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testInequalityConstrained2D() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
x.setValue(5.0);
y.setValue(5.0);
problem.minimize(x.times(x).plus(y.times(2).times(y)));
problem.subjectTo(ge(y, x.unaryMinus().plus(5)));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
assertEquals(3.0 + 1.0 / 3.0, x.value(), 1e-6);
assertEquals(1.0 + 2.0 / 3.0, y.value(), 1e-6);
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,73 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.solver.ExitStatus;
class TrivialProblemTest {
@Test
void testEmpty() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNoCostUnconstrained() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
@SuppressWarnings("VariableDeclarationUsageDistance")
var X = problem.decisionVariable(2, 3);
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
for (int row = 0; row < X.rows(); ++row) {
for (int col = 0; col < X.cols(); ++col) {
assertEquals(0.0, X.value(row, col));
}
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var X = problem.decisionVariable(2, 3);
X.setValue(SimpleMatrix.ones(2, 3));
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.SUCCESS, problem.solve());
for (int row = 0; row < X.rows(); ++row) {
for (int col = 0; col < X.cols(); ++col) {
assertEquals(1.0, X.value(row, col));
}
}
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}

View File

@@ -0,0 +1,222 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.solver;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.wpilib.math.autodiff.Variable.sqrt;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.optimization.Constraints.ge;
import static org.wpilib.math.optimization.Constraints.gt;
import org.junit.jupiter.api.Test;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.optimization.Problem;
// These tests ensure coverage of the off-nominal exit statuses
class ExitStatusTest {
@Test
void testCallbackRequestedStop() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.minimize(x.times(x));
problem.addCallback(info -> false);
assertEquals(ExitStatus.SUCCESS, problem.solve());
problem.addCallback(info -> true);
assertEquals(ExitStatus.CALLBACK_REQUESTED_STOP, problem.solve());
problem.clearCallbacks();
problem.addCallback(info -> false);
assertEquals(ExitStatus.SUCCESS, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testTooFewDOFs() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
var z = problem.decisionVariable();
problem.subjectTo(eq(x, 1));
problem.subjectTo(eq(x, 2));
problem.subjectTo(eq(y, 1));
problem.subjectTo(eq(z, 1));
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.TOO_FEW_DOFS, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testLocallyInfeasible() {
assertEquals(0, Variable.totalNativeMemoryUsage());
// Equality constraints
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
var z = problem.decisionVariable();
problem.subjectTo(eq(x, y.plus(1)));
problem.subjectTo(eq(y, z.plus(1)));
problem.subjectTo(eq(z, x.plus(1)));
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.LOCALLY_INFEASIBLE, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
// Inequality constraints
try (var problem = new Problem()) {
var x = problem.decisionVariable();
var y = problem.decisionVariable();
var z = problem.decisionVariable();
problem.subjectTo(ge(x, y.plus(1)));
problem.subjectTo(ge(y, z.plus(1)));
problem.subjectTo(ge(z, x.plus(1)));
assertEquals(ExpressionType.NONE, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
assertEquals(ExitStatus.LOCALLY_INFEASIBLE, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testNonfiniteInitialGuess() {
assertEquals(0, Variable.totalNativeMemoryUsage());
// Nonfinite cost
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.minimize(new Variable(1).div(x));
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
}
// Nonfinite gradient
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.minimize(sqrt(x));
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
}
// Nonfinite equality constraint
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.subjectTo(eq(new Variable(1).div(x), 1));
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
}
// Nonfinite equality constraint Jacobian
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.subjectTo(eq(sqrt(x), 1));
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
}
// Nonfinite inequality constraint
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.subjectTo(gt(new Variable(1).div(x), 1));
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
}
// Nonfinite inequality constraint Jacobian
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.subjectTo(gt(sqrt(x), 1));
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testDivergingIterates() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.minimize(x);
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.DIVERGING_ITERATES, problem.solve());
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testMaxIterationsExceeded() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.minimize(x.times(x));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(
ExitStatus.MAX_ITERATIONS_EXCEEDED, problem.solve(new Options().withMaxIterations(0)));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
@Test
void testTimeout() {
assertEquals(0, Variable.totalNativeMemoryUsage());
try (var problem = new Problem()) {
var x = problem.decisionVariable();
problem.minimize(x.times(x));
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
assertEquals(ExitStatus.TIMEOUT, problem.solve(new Options().withTimeout(0.0)));
}
assertEquals(0, Variable.totalNativeMemoryUsage());
}
}