From d248c040bf9ff36b7b1d2d91a130fa6337fc6340 Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Sun, 29 Mar 2026 22:34:21 -0700 Subject: [PATCH] [wpimath] Add Sleipnir Java bindings (#8236) The wrapper includes reverse mode autodiff, the Problem DSL, and the optimal control problem API. I wrote it by directly translating the upstream [API](https://github.com/SleipnirGroup/Sleipnir/tree/main/include/sleipnir) and [tests](https://github.com/SleipnirGroup/Sleipnir/tree/main/test) to Java (i.e., copy-paste-modify). I replaced the ArmFeedforward and Ellipse2d JNIs with implementations using the Sleipnir Java bindings. Switching dev binary JNIs to release by default sped up wpimath test runs from several minutes to 7 seconds. --- .../java/wpilib/robot/CartPoleBenchmark.java | 128 ++ .../src/main/java/wpilib/robot/Main.java | 7 + wpimath/BUILD.bazel | 4 + wpimath/CMakeLists.txt | 8 +- .../wpilib/math/autodiff/ExpressionType.java | 47 + .../org/wpilib/math/autodiff/Gradient.java | 78 + .../org/wpilib/math/autodiff/GradientJNI.java | 48 + .../org/wpilib/math/autodiff/Hessian.java | 79 + .../org/wpilib/math/autodiff/HessianJNI.java | 48 + .../org/wpilib/math/autodiff/Jacobian.java | 102 ++ .../org/wpilib/math/autodiff/JacobianJNI.java | 48 + .../math/autodiff/NativeSparseTriplets.java | 37 + .../math/autodiff/NumericalIntegration.java | 89 + .../java/org/wpilib/math/autodiff/Slice.java | 267 +++ .../org/wpilib/math/autodiff/Variable.java | 634 +++++++ .../wpilib/math/autodiff/VariableBlock.java | 826 +++++++++ .../org/wpilib/math/autodiff/VariableJNI.java | 268 +++ .../wpilib/math/autodiff/VariableMatrix.java | 1047 ++++++++++++ .../math/autodiff/VariableMatrixJNI.java | 25 + .../wpilib/math/autodiff/VariablePool.java | 55 + .../math/controller/ArmFeedforward.java | 117 +- .../org/wpilib/math/geometry/Ellipse2d.java | 46 +- .../wpilib/math/jni/ArmFeedforwardJNI.java | 36 - .../org/wpilib/math/jni/Ellipse2dJNI.java | 35 - .../wpilib/math/optimization/Constraints.java | 1472 +++++++++++++++++ .../optimization/EqualityConstraints.java | 23 + .../optimization/InequalityConstraints.java | 23 + .../org/wpilib/math/optimization/OCP.java | 588 +++++++ .../org/wpilib/math/optimization/Problem.java | 312 ++++ .../wpilib/math/optimization/ProblemJNI.java | 134 ++ .../ocp/ConstraintEvaluationFunction.java | 25 + .../optimization/ocp/DynamicsFunction.java | 31 + .../math/optimization/ocp/DynamicsType.java | 20 + .../math/optimization/ocp/TimestepMethod.java | 22 + .../optimization/ocp/TranscriptionMethod.java | 27 + .../math/optimization/solver/ExitStatus.java | 66 + .../optimization/solver/IterationInfo.java | 53 + .../math/optimization/solver/Options.java | 98 ++ .../math/system/NumericalIntegration.java | 25 + .../main/native/cpp/jni/ArmFeedforwardJNI.cpp | 37 - .../src/main/native/cpp/jni/Ellipse2dJNI.cpp | 39 - .../main/native/cpp/jni/SleipnirJNIUtil.hpp | 65 + .../native/cpp/jni/autodiff/GradientJNI.cpp | 90 + .../native/cpp/jni/autodiff/HessianJNI.cpp | 96 ++ .../native/cpp/jni/autodiff/JacobianJNI.cpp | 96 ++ .../native/cpp/jni/autodiff/VariableJNI.cpp | 463 ++++++ .../cpp/jni/autodiff/VariableMatrixJNI.cpp | 53 + .../cpp/jni/optimization/ProblemJNI.cpp | 255 +++ .../java/org/wpilib/math/DoubleRange.java | 24 + .../org/wpilib/math/MatrixAssertions.java | 41 + .../GradientJNITest.java} | 6 +- .../wpilib/math/autodiff/GradientTest.java | 964 +++++++++++ .../HessianJNITest.java} | 6 +- .../org/wpilib/math/autodiff/HessianTest.java | 499 ++++++ .../wpilib/math/autodiff/JacobianJNITest.java | 16 + .../wpilib/math/autodiff/JacobianTest.java | 266 +++ .../org/wpilib/math/autodiff/SliceTest.java | 481 ++++++ .../wpilib/math/autodiff/VariableJNITest.java | 16 + .../math/autodiff/VariableMatrixJNITest.java | 16 + .../math/autodiff/VariableMatrixTest.java | 600 +++++++ .../wpilib/math/autodiff/VariableTest.java | 57 + .../math/controller/ArmFeedforwardTest.java | 13 + .../wpilib/math/geometry/Ellipse2dTest.java | 9 + .../ArmOnElevatorProblemTest.java | 121 ++ .../math/optimization/CartPoleOCPTest.java | 101 ++ .../optimization/CartPoleProblemTest.java | 114 ++ .../math/optimization/CartPoleUtil.java | 122 ++ .../math/optimization/CurrentManager.java | 92 ++ .../math/optimization/CurrentManagerTest.java | 62 + .../optimization/DecisionVariableTest.java | 138 ++ .../DifferentialDriveOCPTest.java | 85 + .../DifferentialDriveProblemTest.java | 116 ++ .../optimization/DifferentialDriveUtil.java | 58 + .../DoubleIntegratorProblemTest.java | 127 ++ .../math/optimization/FlywheelOCPTest.java | 178 ++ .../optimization/FlywheelProblemTest.java | 120 ++ .../math/optimization/LinearProblemTest.java | 72 + .../optimization/NonlinearProblemTest.java | 212 +++ .../math/optimization/ProblemJNITest.java | 16 + .../optimization/QuadraticProblemTest.java | 194 +++ .../math/optimization/TrivialProblemTest.java | 73 + .../optimization/solver/ExitStatusTest.java | 222 +++ .../cpp/optimization/CurrentManagerTest.cpp | 49 + .../wpi/math/optimization/CurrentManager.hpp | 97 ++ 84 files changed, 13405 insertions(+), 170 deletions(-) create mode 100644 benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/Slice.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/Variable.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/VariableJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java delete mode 100644 wpimath/src/main/java/org/wpilib/math/jni/ArmFeedforwardJNI.java delete mode 100644 wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/EqualityConstraints.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/InequalityConstraints.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/OCP.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/Problem.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java create mode 100644 wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java delete mode 100644 wpimath/src/main/native/cpp/jni/ArmFeedforwardJNI.cpp delete mode 100644 wpimath/src/main/native/cpp/jni/Ellipse2dJNI.cpp create mode 100644 wpimath/src/main/native/cpp/jni/SleipnirJNIUtil.hpp create mode 100644 wpimath/src/main/native/cpp/jni/autodiff/GradientJNI.cpp create mode 100644 wpimath/src/main/native/cpp/jni/autodiff/HessianJNI.cpp create mode 100644 wpimath/src/main/native/cpp/jni/autodiff/JacobianJNI.cpp create mode 100644 wpimath/src/main/native/cpp/jni/autodiff/VariableJNI.cpp create mode 100644 wpimath/src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp create mode 100644 wpimath/src/main/native/cpp/jni/optimization/ProblemJNI.cpp create mode 100644 wpimath/src/test/java/org/wpilib/math/DoubleRange.java create mode 100644 wpimath/src/test/java/org/wpilib/math/MatrixAssertions.java rename wpimath/src/test/java/org/wpilib/math/{jni/ArmFeedforwardJNITest.java => autodiff/GradientJNITest.java} (74%) create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/GradientTest.java rename wpimath/src/test/java/org/wpilib/math/{jni/Ellipse2dJNITest.java => autodiff/HessianJNITest.java} (75%) create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/HessianTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/JacobianJNITest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/JacobianTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/SliceTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/VariableJNITest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixJNITest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/autodiff/VariableTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/ArmOnElevatorProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/CartPoleOCPTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/CartPoleProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/CartPoleUtil.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/CurrentManager.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/CurrentManagerTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/DecisionVariableTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveOCPTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveUtil.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/DoubleIntegratorProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/FlywheelOCPTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/FlywheelProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/LinearProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/NonlinearProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/ProblemJNITest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/QuadraticProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/TrivialProblemTest.java create mode 100644 wpimath/src/test/java/org/wpilib/math/optimization/solver/ExitStatusTest.java create mode 100644 wpimath/src/test/native/cpp/optimization/CurrentManagerTest.cpp create mode 100644 wpimath/src/test/native/include/wpi/math/optimization/CurrentManager.hpp diff --git a/benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java b/benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java new file mode 100644 index 0000000000..cd08ecccea --- /dev/null +++ b/benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java @@ -0,0 +1,128 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package wpilib.robot; + +import static org.wpilib.math.autodiff.NumericalIntegration.rk4; +import static org.wpilib.math.autodiff.Variable.cos; +import static org.wpilib.math.autodiff.Variable.sin; +import static org.wpilib.math.autodiff.VariableMatrix.solve; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.le; + +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.Problem; +import org.wpilib.math.optimization.solver.Options; +import org.wpilib.math.util.MathUtil; + +public final class CartPoleBenchmark { + private CartPoleBenchmark() { + // Utility class. + } + + @SuppressWarnings("LocalVariableName") + private static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) { + final double m_c = 5.0; // Cart mass (kg) + final double m_p = 0.5; // Pole mass (kg) + final double l = 0.5; // Pole length (m) + final double g = 9.806; // Acceleration due to gravity (m/s²) + + var q = x.segment(0, 2); + var qdot = x.segment(2, 2); + var theta = q.get(1); + var thetadot = qdot.get(1); + + // [ m_c + m_p m_p l cosθ] + // M(q) = [m_p l cosθ m_p l² ] + var M = + new VariableMatrix( + new Variable[][] { + {new Variable(m_c + m_p), cos(theta).times(m_p * l)}, + {cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))} + }); + + // [0 −m_p lθ̇ sinθ] + // C(q, q̇) = [0 0 ] + var C = + new VariableMatrix( + new Variable[][] { + {new Variable(0), thetadot.times(-m_p * l).times(sin(theta))}, + {new Variable(0), new Variable(0)} + }); + + // [ 0 ] + // τ_g(q) = [-m_p gl sinθ] + var tau_g = + new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}}); + + // [1] + // B = [0] + var B = new VariableMatrix(new double[][] {{1}, {0}}); + + // q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu) + var qddot = new VariableMatrix(4); + qddot.segment(0, 2).set(qdot); + qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u)))); + return qddot; + } + + /** Cart-pole benchmark. */ + public static void cartPole() { + final double T = 5.0; // s + final double dt = 0.05; // s + final int N = (int) (T / dt); + + final double u_max = 20.0; // N + final double d_max = 2.0; // m + + final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}}); + final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}}); + + var problem = new Problem(); + + // x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ + var X = problem.decisionVariable(4, N + 1); + + // Initial guess + for (int k = 0; k < N + 1; ++k) { + X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N)); + X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N)); + } + + // u = f_x + var U = problem.decisionVariable(1, N); + + // Initial conditions + problem.subjectTo(eq(X.col(0), x_initial)); + + // Final conditions + problem.subjectTo(eq(X.col(N), x_final)); + + // Cart position constraints + problem.subjectTo(ge(X.row(0), 0.0)); + problem.subjectTo(le(X.row(0), d_max)); + + // Input constraints + problem.subjectTo(ge(U, -u_max)); + problem.subjectTo(le(U, u_max)); + + // Dynamics constraints - RK4 integration + for (int k = 0; k < N; ++k) { + problem.subjectTo( + eq(X.col(k + 1), rk4(CartPoleBenchmark::cartPoleDynamics, X.col(k), U.col(k), dt))); + } + + // Minimize sum squared inputs + var J = new Variable(0.0); + for (int k = 0; k < N; ++k) { + J = J.plus(U.col(k).T().times(U.col(k)).get(0)); + } + problem.minimize(J); + + problem.solve(new Options().withDiagnostics(true)); + } +} diff --git a/benchmark/src/main/java/wpilib/robot/Main.java b/benchmark/src/main/java/wpilib/robot/Main.java index 975e3401aa..3db4d43472 100644 --- a/benchmark/src/main/java/wpilib/robot/Main.java +++ b/benchmark/src/main/java/wpilib/robot/Main.java @@ -37,6 +37,13 @@ public class Main { new Runner(opt).run(); } + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MICROSECONDS) + public void cartPole() { + CartPoleBenchmark.cartPole(); + } + @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MICROSECONDS) diff --git a/wpimath/BUILD.bazel b/wpimath/BUILD.bazel index 80f99590fc..bd2e92a1c7 100644 --- a/wpimath/BUILD.bazel +++ b/wpimath/BUILD.bazel @@ -236,6 +236,9 @@ wpilib_cc_shared_library( wpilib_jni_java_library( name = "wpimath-java", srcs = [":generated_java"] + glob(["src/main/java/**/*.java"]), + javacopts = [ + "-Xep:UnicodeInCode:OFF", + ], maven_artifact_name = "wpimath-java", maven_group_id = "org.wpilib.wpimath", native_libs = [":wpimathjni"], @@ -288,6 +291,7 @@ wpilib_java_junit5_test( "//wpiunits:wpiunits-java", "//wpiutil:wpiutil-java", "@maven//:org_ejml_ejml_core", + "@maven//:org_ejml_ejml_ddense", "@maven//:org_ejml_ejml_simple", "@maven//:us_hebi_quickbuf_quickbuf_runtime", ], diff --git a/wpimath/CMakeLists.txt b/wpimath/CMakeLists.txt index 273698b4d4..73d5210913 100644 --- a/wpimath/CMakeLists.txt +++ b/wpimath/CMakeLists.txt @@ -7,14 +7,18 @@ include(DownloadAndCheck) file( GLOB wpimath_jni_src - src/main/native/cpp/jni/ArmFeedforwardJNI.cpp src/main/native/cpp/jni/DAREJNI.cpp src/main/native/cpp/jni/EigenJNI.cpp - src/main/native/cpp/jni/Ellipse2dJNI.cpp src/main/native/cpp/jni/Exceptions.cpp src/main/native/cpp/jni/LinearSystemUtilJNI.cpp src/main/native/cpp/jni/Transform3dJNI.cpp src/main/native/cpp/jni/Twist3dJNI.cpp + src/main/native/cpp/jni/autodiff/GradientJNI.cpp + src/main/native/cpp/jni/autodiff/HessianJNI.cpp + src/main/native/cpp/jni/autodiff/JacobianJNI.cpp + src/main/native/cpp/jni/autodiff/VariableJNI.cpp + src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp + src/main/native/cpp/jni/optimization/ProblemJNI.cpp ) # Java bindings diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java b/wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java new file mode 100644 index 0000000000..dc9518e8fe --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java @@ -0,0 +1,47 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +/** + * Expression type. + * + *

Used for autodiff caching. + */ +public enum ExpressionType { + /** There is no expression. */ + NONE(0), + /** The expression is a constant. */ + CONSTANT(1), + /** The expression is composed of linear and lower-order operators. */ + LINEAR(2), + /** The expression is composed of quadratic and lower-order operators. */ + QUADRATIC(3), + /** The expression is composed of nonlinear and lower-order operators. */ + NONLINEAR(4); + + /** ExpressionType value. */ + public final int value; + + ExpressionType(int value) { + this.value = value; + } + + /** + * Converts integer to its corresponding enum value. + * + * @param x The integer. + * @return The enum value. + */ + public static ExpressionType fromInt(int x) { + return switch (x) { + case 0 -> ExpressionType.NONE; + case 1 -> ExpressionType.CONSTANT; + case 2 -> ExpressionType.LINEAR; + case 3 -> ExpressionType.QUADRATIC; + case 4 -> ExpressionType.NONLINEAR; + default -> null; + }; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java new file mode 100644 index 0000000000..136a7b59f7 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java @@ -0,0 +1,78 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.simple.SimpleMatrix; + +/** + * This class calculates the gradient of a variable with respect to a vector of variables. + * + *

The gradient is only recomputed if the variable expression is quadratic or higher order. + */ +public class Gradient implements AutoCloseable { + private long m_handle; + private int m_rows; + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Variable with respect to which to compute the gradient. + */ + public Gradient(Variable variable, Variable wrt) { + this(variable, new VariableMatrix(wrt)); + } + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Vector of variables with respect to which to compute the gradient. + */ + public Gradient(Variable variable, VariableMatrix wrt) { + assert wrt.cols() == 1; + + m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles()); + m_rows = wrt.rows(); + } + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Vector of variables with respect to which to compute the gradient. + */ + public Gradient(Variable variable, VariableBlock wrt) { + this(variable, new VariableMatrix(wrt)); + } + + @Override + public void close() { + if (m_handle != 0) { + GradientJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Returns the gradient as a VariableMatrix. + * + *

This is useful when constructing optimization problems with derivatives in them. + * + * @return The gradient as a VariableMatrix. + */ + public VariableMatrix get() { + return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle)); + } + + /** + * Evaluates the gradient at wrt's value. + * + * @return The gradient at wrt's value. + */ + public SimpleMatrix value() { + return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java new file mode 100644 index 0000000000..e5f7568bb2 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java @@ -0,0 +1,48 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Gradient JNI functions. */ +final class GradientJNI extends WPIMathJNI { + private GradientJNI() { + // Utility class. + } + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the Gradient. + * @param wrt Vector of variables with respect to which to compute the Gradient. + */ + static native long create(long variable, long[] wrt); + + /** + * Destructs a Gradient. + * + * @param handle Gradient handle. + */ + static native void destroy(long handle); + + /** + * Returns the Gradient as an array of Variable handles. + * + *

This is useful when constructing optimization problems with derivatives in them. + * + * @param handle Gradient handle. + * @return The Gradient as an array of Variable handles. + */ + static native long[] get(long handle); + + /** + * Evaluates the Gradient at wrt's value. + * + * @param handle Gradient handle. + * @return A record containing the triplet row, column and value arrays (int[], int[], and + * double[] respectively). + */ + static native NativeSparseTriplets value(long handle); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java new file mode 100644 index 0000000000..ba94ed2622 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java @@ -0,0 +1,79 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.simple.SimpleMatrix; + +/** + * This class calculates the Hessian of a variable with respect to a vector of variables. + * + *

The gradient tree is cached so subsequent Hessian calculations are faster, and the Hessian is + * only recomputed if the variable expression is nonlinear. + */ +public class Hessian implements AutoCloseable { + private long m_handle; + private int m_rows; + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Variable with respect to which to compute the Hessian. + */ + public Hessian(Variable variable, Variable wrt) { + this(variable, new VariableMatrix(wrt)); + } + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the Hessian. + */ + public Hessian(Variable variable, VariableMatrix wrt) { + assert wrt.cols() == 1; + + m_handle = HessianJNI.create(variable.getHandle(), wrt.getHandles()); + m_rows = wrt.rows(); + } + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the Hessian. + */ + public Hessian(Variable variable, VariableBlock wrt) { + this(variable, new VariableMatrix(wrt)); + } + + @Override + public void close() { + if (m_handle != 0) { + HessianJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Returns the Hessian as a VariableMatrix. + * + *

This is useful when constructing optimization problems with derivatives in them. + * + * @return The Hessian as a VariableMatrix. + */ + public VariableMatrix get() { + return new VariableMatrix(m_rows, m_rows, HessianJNI.get(m_handle)); + } + + /** + * Evaluates the Hessian at wrt's value. + * + * @return The Hessian at wrt's value. + */ + public SimpleMatrix value() { + return HessianJNI.value(m_handle).toSimpleMatrix(m_rows, m_rows); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java new file mode 100644 index 0000000000..4ae7735c9b --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java @@ -0,0 +1,48 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Hessian JNI functions. */ +final class HessianJNI extends WPIMathJNI { + private HessianJNI() { + // Utility class. + } + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the Hessian. + */ + static native long create(long variable, long[] wrt); + + /** + * Destructs a Hessian. + * + * @param handle Hessian handle. + */ + static native void destroy(long handle); + + /** + * Returns the Hessian as an array of Variable handles. + * + *

This is useful when constructing optimization problems with derivatives in them. + * + * @param handle Hessian handle. + * @return The Hessian as an array of Variable handles. + */ + static native long[] get(long handle); + + /** + * Evaluates the Hessian at wrt's value. + * + * @param handle Hessian handle. + * @return A record containing the triplet row, column and value arrays (int[], int[], and + * double[] respectively). + */ + static native NativeSparseTriplets value(long handle); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java new file mode 100644 index 0000000000..6dc5f99922 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java @@ -0,0 +1,102 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.simple.SimpleMatrix; + +/** + * This class calculates the Jacobian of a vector of variables with respect to a vector of + * variables. + * + *

The Jacobian is only recomputed if the variable expression is quadratic or higher order. + */ +public class Jacobian implements AutoCloseable { + private long m_handle; + private int m_rows; + private int m_cols; + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Variable with respect to which to compute the Jacobian. + */ + public Jacobian(Variable variable, Variable wrt) { + this(new VariableMatrix(variable), new VariableMatrix(wrt)); + } + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(Variable variable, VariableMatrix wrt) { + this(new VariableMatrix(variable), wrt); + } + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(Variable variable, VariableBlock wrt) { + this(new VariableMatrix(variable), new VariableMatrix(wrt)); + } + + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(VariableMatrix variables, VariableMatrix wrt) { + assert variables.cols() == 1; + assert wrt.cols() == 1; + + m_handle = JacobianJNI.create(variables.getHandles(), wrt.getHandles()); + m_rows = variables.rows(); + m_cols = wrt.rows(); + } + + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(VariableMatrix variables, VariableBlock wrt) { + this(variables, new VariableMatrix(wrt)); + } + + @Override + public void close() { + if (m_handle != 0) { + JacobianJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Returns the Jacobian as a VariableMatrix. + * + *

This is useful when constructing optimization problems with derivatives in them. + * + * @return The Jacobian as a VariableMatrix. + */ + public VariableMatrix get() { + return new VariableMatrix(m_rows, m_cols, JacobianJNI.get(m_handle)); + } + + /** + * Evaluates the Jacobian at wrt's value. + * + * @return The Jacobian at wrt's value. + */ + public SimpleMatrix value() { + return JacobianJNI.value(m_handle).toSimpleMatrix(m_rows, m_cols); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java new file mode 100644 index 0000000000..f969d22e39 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java @@ -0,0 +1,48 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Jacobian JNI functions. */ +final class JacobianJNI extends WPIMathJNI { + private JacobianJNI() { + // Utility class. + } + + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + static native long create(long[] variables, long[] wrt); + + /** + * Destructs a Jacobian. + * + * @param handle Jacobian handle. + */ + static native void destroy(long handle); + + /** + * Returns the Jacobian as an array of Variable handles. + * + *

This is useful when constructing optimization problems with derivatives in them. + * + * @param handle Jacobian handle. + * @return The Jacobian as an array of Variable handles. + */ + static native long[] get(long handle); + + /** + * Evaluates the Jacobian at wrt's value. + * + * @param handle Jacobian handle. + * @return A record containing the triplet row, column and value arrays (int[], int[], and + * double[] respectively). + */ + static native NativeSparseTriplets value(long handle); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java b/wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java new file mode 100644 index 0000000000..fd1deb138f --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java @@ -0,0 +1,37 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.data.DMatrixSparseCSC; +import org.ejml.data.DMatrixSparseTriplet; +import org.ejml.ops.DConvertMatrixStruct; +import org.ejml.simple.SimpleMatrix; + +/** + * Wrapper for sparse matrix triplets from JNI. + * + *

We can't use DMatrixSparseTriplet because it doesn't have a method for bulk-initialization + * from triplet arrays. + * + * @param rows Triplet rows. + * @param cols Triplet columns. + * @param values Triplet values. + */ +public record NativeSparseTriplets(int[] rows, int[] cols, double[] values) { + /** + * Returns a SimpleMatrix wrapper for this set of triplets. + * + * @param rows Number of rows in sparse SimpleMatrix. + * @param cols Number of columns in sparse SimpleMatrix. + * @return A SimpleMatrix wrapper for this set of triplets. + */ + public SimpleMatrix toSimpleMatrix(int rows, int cols) { + var ejmlTriplets = new DMatrixSparseTriplet(rows, cols, values().length); + for (int i = 0; i < values().length; ++i) { + ejmlTriplets.addItem(rows()[i], cols()[i], values()[i]); + } + return new SimpleMatrix(DConvertMatrixStruct.convert(ejmlTriplets, (DMatrixSparseCSC) null)); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java b/wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java new file mode 100644 index 0000000000..9d4ee80c82 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java @@ -0,0 +1,89 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import java.util.function.BiFunction; + +/** Numerical integration utilities. */ +public final class NumericalIntegration { + private NumericalIntegration() { + // Utility class. + } + + /** + * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param dt The time over which to integrate. + * @return the integration of dx/dt = f(x, u) for dt. + */ + public static VariableMatrix rk4( + BiFunction f, + VariableBlock x, + VariableBlock u, + double dt) { + return rk4(f, new VariableMatrix(x), new VariableMatrix(u), dt); + } + + /** + * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param dt The time over which to integrate. + * @return the integration of dx/dt = f(x, u) for dt. + */ + public static VariableMatrix rk4( + BiFunction f, + VariableBlock x, + VariableMatrix u, + double dt) { + return rk4(f, new VariableMatrix(x), u, dt); + } + + /** + * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param dt The time over which to integrate. + * @return the integration of dx/dt = f(x, u) for dt. + */ + public static VariableMatrix rk4( + BiFunction f, + VariableMatrix x, + VariableBlock u, + double dt) { + return rk4(f, x, new VariableMatrix(u), dt); + } + + /** + * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param dt The time over which to integrate. + * @return the integration of dx/dt = f(x, u) for dt. + */ + public static VariableMatrix rk4( + BiFunction f, + VariableMatrix x, + VariableMatrix u, + double dt) { + var h = dt; + + var k1 = f.apply(x, u); + var k2 = f.apply(x.plus(k1.times(h * 0.5)), u); + var k3 = f.apply(x.plus(k2.times(h * 0.5)), u); + var k4 = f.apply(x.plus(k3.times(h)), u); + + return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(h / 6.0)); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Slice.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Slice.java new file mode 100644 index 0000000000..f1f31327ed --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Slice.java @@ -0,0 +1,267 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import java.util.OptionalInt; + +/** Represents a sequence of elements in an iterable object. */ +@SuppressWarnings("PMD.UnusedFormalParameter") +public class Slice { + /** Type tag used to designate an omitted argument of the slice. */ + public static class None { + /** Default constructor. */ + public None() {} + } + + /** Designates an omitted argument of the slice. */ + public static final None __ = null; + + /** Start index (inclusive). */ + public int start = 0; + + /** Stop index (exclusive). */ + public int stop = 0; + + /** Step. */ + public int step = 1; + + /** Constructs a Slice. */ + public Slice() {} + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + */ + public Slice(None start) { + this(OptionalInt.of(0), OptionalInt.of(Integer.MAX_VALUE), OptionalInt.of(1)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + */ + public Slice(int start) { + this.start = start; + this.stop = (start == -1) ? Integer.MAX_VALUE : start + 1; + this.step = 1; + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + */ + public Slice(None start, None stop) { + this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.of(1)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + */ + public Slice(None start, int stop) { + this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.of(1)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + */ + public Slice(int start, None stop) { + this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.of(1)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + */ + public Slice(int start, int stop) { + this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.of(1)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(None start, None stop, None step) { + this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.empty()); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(None start, None stop, int step) { + this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.of(step)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(None start, int stop, None step) { + this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.empty()); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(None start, int stop, int step) { + this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.of(step)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(int start, None stop, None step) { + this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.empty()); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(int start, None stop, int step) { + this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.of(step)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(int start, int stop, None step) { + this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.empty()); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(int start, int stop, int step) { + this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.of(step)); + } + + /** + * Constructs a slice. + * + * @param start Slice start index (inclusive). + * @param stop Slice stop index (exclusive). + * @param step Slice step. + */ + public Slice(OptionalInt start, OptionalInt stop, OptionalInt step) { + if (!step.isPresent()) { + this.step = 1; + } else { + assert step.getAsInt() != 0; + + this.step = step.getAsInt(); + } + + // Avoid UB for step = -step if step is INT_MIN + if (this.step == Integer.MIN_VALUE) { + this.step = -Integer.MAX_VALUE; + } + + if (!start.isPresent()) { + if (this.step < 0) { + this.start = Integer.MAX_VALUE; + } else { + this.start = 0; + } + } else { + this.start = start.getAsInt(); + } + + if (!stop.isPresent()) { + if (this.step < 0) { + this.stop = Integer.MIN_VALUE; + } else { + this.stop = Integer.MAX_VALUE; + } + } else { + this.stop = stop.getAsInt(); + } + } + + /** + * Adjusts start and end slice indices assuming a sequence of the specified length. + * + * @param length The sequence length. + * @return The slice length. + */ + public int adjust(int length) { + assert step != 0; + assert step >= -Integer.MAX_VALUE; + + if (start < 0) { + start += length; + + if (start < 0) { + start = (step < 0) ? -1 : 0; + } + } else if (start >= length) { + start = (step < 0) ? length - 1 : length; + } + + if (stop < 0) { + stop += length; + + if (stop < 0) { + stop = (step < 0) ? -1 : 0; + } + } else if (stop >= length) { + stop = (step < 0) ? length - 1 : length; + } + + if (step < 0) { + if (stop < start) { + return (start - stop - 1) / -step + 1; + } else { + return 0; + } + } else { + if (start < stop) { + return (stop - start - 1) / step + 1; + } else { + return 0; + } + } + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Variable.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Variable.java new file mode 100644 index 0000000000..1750c7405e --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Variable.java @@ -0,0 +1,634 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +/** An autodiff variable pointing to an expression node. */ +public class Variable implements AutoCloseable { + /** Handle type tag. */ + public static final class Handle { + /** Constructor for Handle. */ + public Handle() {} + } + + /** Instance of handle type tag. */ + public static final Handle HANDLE = new Handle(); + + private long m_handle; + + /** Constructs a linear Variable with a value of zero. */ + @SuppressWarnings("this-escape") + public Variable() { + m_handle = VariableJNI.createDefault(); + VariablePool.register(this); + } + + /** + * Constructs a Variable from a floating point type. + * + * @param value The value of the Variable. + */ + @SuppressWarnings("this-escape") + public Variable(double value) { + m_handle = VariableJNI.createDouble(value); + VariablePool.register(this); + } + + /** + * Constructs a Variable from an integral type. + * + * @param value The value of the Variable. + */ + @SuppressWarnings("this-escape") + public Variable(int value) { + m_handle = VariableJNI.createInt(value); + VariablePool.register(this); + } + + /** + * Constructs a Variable from the given handle. + * + *

This constructor is for internal use only. + * + * @param handleTypeTag Handle type tag. + * @param handle Variable handle. + */ + @SuppressWarnings({"PMD.UnusedFormalParameter", "this-escape"}) + public Variable(Handle handleTypeTag, long handle) { + m_handle = handle; + VariablePool.register(this); + } + + @Override + public void close() { + if (m_handle != 0) { + VariableJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Sets Variable's internal value. + * + * @param value The value of the Variable. + */ + public void setValue(double value) { + VariableJNI.setValue(m_handle, value); + } + + /** + * Returns the value of this variable. + * + * @return The value of this variable. + */ + public double value() { + return VariableJNI.value(m_handle); + } + + /** + * Returns the type of this expression (constant, linear, quadratic, or nonlinear). + * + * @return The type of this expression. + */ + public ExpressionType type() { + return ExpressionType.fromInt(VariableJNI.type(m_handle)); + } + + /** + * Returns internal handle. + * + * @return Internal handle. + */ + public long getHandle() { + return m_handle; + } + + /** + * Variable-Variable multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of multiplication. + */ + public Variable times(Variable rhs) { + return new Variable(HANDLE, VariableJNI.times(m_handle, rhs.getHandle())); + } + + /** + * Variable-Variable multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of multiplication. + */ + public Variable times(double rhs) { + return times(new Variable(rhs)); + } + + /** + * Variable-Variable division operator. + * + * @param rhs Operator right-hand side. + * @return Result of division. + */ + public Variable div(Variable rhs) { + return new Variable(HANDLE, VariableJNI.div(m_handle, rhs.getHandle())); + } + + /** + * Variable-Variable division operator. + * + * @param rhs Operator right-hand side. + * @return Result of division. + */ + public Variable div(double rhs) { + return div(new Variable(rhs)); + } + + /** + * Variable-Variable addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public Variable plus(Variable rhs) { + return new Variable(HANDLE, VariableJNI.plus(m_handle, rhs.getHandle())); + } + + /** + * Variable-Variable addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public Variable plus(double rhs) { + return plus(new Variable(rhs)); + } + + /** + * Variable-Variable subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public Variable minus(Variable rhs) { + return new Variable(HANDLE, VariableJNI.minus(m_handle, rhs.getHandle())); + } + + /** + * Variable-Variable subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public Variable minus(double rhs) { + return minus(new Variable(rhs)); + } + + /** + * Unary minus operator. + * + * @return Result of unary minus. + */ + public Variable unaryMinus() { + return new Variable(HANDLE, VariableJNI.unaryMinus(m_handle)); + } + + /** + * Unary plus operator. + * + * @return Result of unary plus. + */ + public Variable unaryPlus() { + return this; + } + + /** + * Math.abs() for Variables. + * + * @param x The argument. + * @return Result of abs(). + */ + public static Variable abs(Variable x) { + return new Variable(HANDLE, VariableJNI.abs(x.getHandle())); + } + + /** + * Math.acos() for Variables. + * + * @param x The argument. + * @return Result of acos(). + */ + public static Variable acos(Variable x) { + return new Variable(HANDLE, VariableJNI.acos(x.getHandle())); + } + + /** + * Math.asin() for Variables. + * + * @param x The argument. + * @return Result of asin(). + */ + public static Variable asin(Variable x) { + return new Variable(HANDLE, VariableJNI.asin(x.getHandle())); + } + + /** + * Math.atan() for Variables. + * + * @param x The argument. + * @return Result of atan(). + */ + public static Variable atan(Variable x) { + return new Variable(HANDLE, VariableJNI.atan(x.getHandle())); + } + + /** + * Math.atan2() for Variables. + * + * @param y The y argument. + * @param x The x argument. + * @return Result of atan2(). + */ + public static Variable atan2(double y, Variable x) { + return atan2(new Variable(y), x); + } + + /** + * Math.atan2() for Variables. + * + * @param y The y argument. + * @param x The x argument. + * @return Result of atan2(). + */ + public static Variable atan2(Variable y, double x) { + return atan2(y, new Variable(x)); + } + + /** + * Math.atan2() for Variables. + * + * @param y The y argument. + * @param x The x argument. + * @return Result of atan2(). + */ + public static Variable atan2(Variable y, Variable x) { + return new Variable(HANDLE, VariableJNI.atan2(y.getHandle(), x.getHandle())); + } + + /** + * Math.cbrt() for Variables. + * + * @param x The argument. + * @return Result of cbrt(). + */ + public static Variable cbrt(Variable x) { + return new Variable(HANDLE, VariableJNI.cbrt(x.getHandle())); + } + + /** + * Math.cos() for Variables. + * + * @param x The argument. + * @return Result of cos(). + */ + public static Variable cos(Variable x) { + return new Variable(HANDLE, VariableJNI.cos(x.getHandle())); + } + + /** + * Math.cosh() for Variables. + * + * @param x The argument. + * @return Result of cosh(). + */ + public static Variable cosh(Variable x) { + return new Variable(HANDLE, VariableJNI.cosh(x.getHandle())); + } + + /** + * Math.exp() for Variables. + * + * @param x The argument. + * @return Result of exp(). + */ + public static Variable exp(Variable x) { + return new Variable(HANDLE, VariableJNI.exp(x.getHandle())); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @return Result of hypot(). + */ + public static Variable hypot(double x, Variable y) { + return hypot(new Variable(x), y); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @return Result of hypot(). + */ + public static Variable hypot(Variable x, double y) { + return hypot(x, new Variable(y)); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @return Result of hypot(). + */ + public static Variable hypot(Variable x, Variable y) { + return new Variable(HANDLE, VariableJNI.hypot(x.getHandle(), y.getHandle())); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(double x, double y, Variable z) { + return hypot(new Variable(x), new Variable(y), z); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(double x, Variable y, double z) { + return hypot(new Variable(x), y, new Variable(z)); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(double x, Variable y, Variable z) { + return hypot(new Variable(x), y, z); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(Variable x, double y, double z) { + return hypot(x, new Variable(y), new Variable(z)); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(Variable x, double y, Variable z) { + return hypot(x, new Variable(y), z); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(Variable x, Variable y, double z) { + return hypot(x, y, new Variable(z)); + } + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + * @param z The z argument. + * @return Result of hypot(). + */ + public static Variable hypot(Variable x, Variable y, Variable z) { + return sqrt(pow(x, 2).plus(pow(y, 2)).plus(pow(z, 2))); + } + + /** + * Math.log() for Variables. + * + * @param x The argument. + * @return Result of log(). + */ + public static Variable log(Variable x) { + return new Variable(HANDLE, VariableJNI.log(x.getHandle())); + } + + /** + * Math.log10() for Variables. + * + * @param x The argument. + * @return Result of log10(). + */ + public static Variable log10(Variable x) { + return new Variable(HANDLE, VariableJNI.log10(x.getHandle())); + } + + /** + * Math.max() for Variables. + * + *

Returns the greater of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + * @return Result of max(). + */ + public static Variable max(double a, Variable b) { + return max(new Variable(a), b); + } + + /** + * Math.max() for Variables. + * + *

Returns the greater of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + * @return Result of max(). + */ + public static Variable max(Variable a, double b) { + return max(a, new Variable(b)); + } + + /** + * Math.max() for Variables. + * + *

Returns the greater of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + * @return Result of max(). + */ + public static Variable max(Variable a, Variable b) { + return new Variable(HANDLE, VariableJNI.max(a.getHandle(), b.getHandle())); + } + + /** + * min() for Variables. + * + *

Returns the lesser of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + * @return Result of min(). + */ + public static Variable min(double a, Variable b) { + return min(new Variable(a), b); + } + + /** + * min() for Variables. + * + *

Returns the lesser of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + * @return Result of min(). + */ + public static Variable min(Variable a, double b) { + return min(a, new Variable(b)); + } + + /** + * min() for Variables. + * + *

Returns the lesser of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + * @return Result of min(). + */ + public static Variable min(Variable a, Variable b) { + return new Variable(HANDLE, VariableJNI.min(a.getHandle(), b.getHandle())); + } + + /** + * Math.pow() for Variables. + * + * @param base The base. + * @param power The power. + * @return Result of pow(). + */ + public static Variable pow(double base, Variable power) { + return pow(new Variable(base), power); + } + + /** + * Math.pow() for Variables. + * + * @param base The base. + * @param power The power. + * @return Result of pow(). + */ + public static Variable pow(Variable base, double power) { + return pow(base, new Variable(power)); + } + + /** + * Math.pow() for Variables. + * + * @param base The base. + * @param power The power. + * @return Result of pow(). + */ + public static Variable pow(Variable base, Variable power) { + return new Variable(HANDLE, VariableJNI.pow(base.getHandle(), power.getHandle())); + } + + /** + * Math.signum() for Variables. + * + * @param x The argument. + * @return Result of signum(). + */ + public static Variable signum(Variable x) { + return new Variable(HANDLE, VariableJNI.signum(x.getHandle())); + } + + /** + * Math.sin() for Variables. + * + * @param x The argument. + * @return Result of sin(). + */ + public static Variable sin(Variable x) { + return new Variable(HANDLE, VariableJNI.sin(x.getHandle())); + } + + /** + * Math.sinh() for Variables. + * + * @param x The argument. + * @return Result of sinh(). + */ + public static Variable sinh(Variable x) { + return new Variable(HANDLE, VariableJNI.sinh(x.getHandle())); + } + + /** + * Math.sqrt() for Variables. + * + * @param x The argument. + * @return Result of sqrt(). + */ + public static Variable sqrt(Variable x) { + return new Variable(HANDLE, VariableJNI.sqrt(x.getHandle())); + } + + /** + * Math.tan() for Variables. + * + * @param x The argument. + * @return Result of tan(). + */ + public static Variable tan(Variable x) { + return new Variable(HANDLE, VariableJNI.tan(x.getHandle())); + } + + /** + * Math.tanh() for Variables. + * + * @param x The argument. + * @return Result of tanh(). + */ + public static Variable tanh(Variable x) { + return new Variable(HANDLE, VariableJNI.tanh(x.getHandle())); + } + + /** + * Returns the total native memory usage of all Variables in bytes. + * + * @return The total native memory usage of all Variables in bytes. + */ + public static long totalNativeMemoryUsage() { + return VariableJNI.totalNativeMemoryUsage(); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java new file mode 100644 index 0000000000..b92c6133ad --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java @@ -0,0 +1,826 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.function.UnaryOperator; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.ejml.simple.SimpleMatrix; + +/** A submatrix of autodiff variables with reference semantics. */ +public class VariableBlock implements Iterable { + private final VariableMatrix m_mat; + + private final Slice m_rowSlice; + private final int m_rowSliceLength; + + private final Slice m_colSlice; + private final int m_colSliceLength; + + /** + * Constructs a Variable block pointing to all of the given matrix. + * + * @param mat The matrix to which to point. + */ + public VariableBlock(VariableMatrix mat) { + this(mat, 0, 0, mat.rows(), mat.cols()); + } + + /** + * Constructs a Variable block pointing to a subset of the given matrix. + * + * @param mat The matrix to which to point. + * @param rowOffset The block's row offset. + * @param colOffset The block's column offset. + * @param blockRows The number of rows in the block. + * @param blockCols The number of columns in the block. + */ + public VariableBlock( + VariableMatrix mat, int rowOffset, int colOffset, int blockRows, int blockCols) { + m_mat = mat; + m_rowSlice = new Slice(rowOffset, rowOffset + blockRows, 1); + m_rowSliceLength = m_rowSlice.adjust(mat.rows()); + m_colSlice = new Slice(colOffset, colOffset + blockCols, 1); + m_colSliceLength = m_colSlice.adjust(mat.cols()); + } + + /** + * Constructs a Variable block pointing to a subset of the given matrix. + * + *

Note that the slices are taken as is rather than adjusted. + * + * @param mat The matrix to which to point. + * @param rowSlice The block's row slice. + * @param rowSliceLength The block's row length. + * @param colSlice The block's column slice. + * @param colSliceLength The block's column length. + */ + public VariableBlock( + VariableMatrix mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) { + m_mat = mat; + m_rowSlice = rowSlice; + m_rowSliceLength = rowSliceLength; + m_colSlice = colSlice; + m_colSliceLength = colSliceLength; + } + + /** + * Assigns a double to the block. + * + *

This only works for blocks with one row and one column. + * + * @param value Value to assign. + * @return This VariableBlock. + */ + public VariableBlock set(double value) { + assert rows() == 1 && cols() == 1; + + set(0, 0, new Variable(value)); + + return this; + } + + /** + * Assigns a Variable to the block. + * + *

This only works for blocks with one row and one column. + * + * @param value Value to assign. + * @return This VariableBlock. + */ + public VariableBlock set(Variable value) { + assert rows() == 1 && cols() == 1; + + set(0, 0, value); + + return this; + } + + /** + * Assigns a double array to the block. + * + * @param values Double array of values to assign. + * @return This VariableBlock. + */ + public VariableBlock set(double[][] values) { + assert rows() == values.length; + + // Assert all column counts are the same + for (var row : values) { + assert row.length == cols(); + } + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + set(row, col, values[row][col]); + } + } + + return this; + } + + /** + * Assigns an EJML matrix to the block. + * + * @param values EJML matrix of values to assign. + * @return This VariableBlock. + */ + public VariableBlock set(SimpleMatrix values) { + assert rows() == values.getNumRows() && cols() == values.getNumCols(); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + set(row, col, values.get(row, col)); + } + } + + return this; + } + + /** + * Assigns a VariableMatrix to the block. + * + * @param values VariableMatrix of values. + * @return This VariableBlock. + */ + public VariableBlock set(VariableMatrix values) { + assert rows() == values.rows() && cols() == values.cols(); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + set(row, col, values.get(row, col)); + } + } + return this; + } + + /** + * Assigns a VariableBlock to the block. + * + * @param values VariableBlock of values. + * @return This VariableBlock. + */ + public VariableBlock set(VariableBlock values) { + assert rows() == values.rows() && cols() == values.cols(); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + set(row, col, values.get(row, col)); + } + } + return this; + } + + /** + * Sets a scalar subblock at the given row and column. + * + * @param row The scalar subblock's row. + * @param col The scalar subblock's column. + * @param value The value. + */ + public void set(int row, int col, Variable value) { + assert row >= 0 && row < rows(); + assert col >= 0 && col < cols(); + m_mat.set( + m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value); + } + + /** + * Sets a scalar subblock at the given row and column. + * + * @param row The scalar subblock's row. + * @param col The scalar subblock's column. + * @param value The value. + */ + public void set(int row, int col, double value) { + assert row >= 0 && row < rows(); + assert col >= 0 && col < cols(); + m_mat.set( + m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value); + } + + /** + * Sets a scalar subblock at the given index. + * + * @param index The scalar subblock's index. + * @param value The value. + */ + public void set(int index, double value) { + set(index, new Variable(value)); + } + + /** + * Sets a scalar subblock at the given index. + * + * @param index The scalar subblock's index. + * @param value The value. + */ + public void set(int index, Variable value) { + assert index >= 0 && index < rows() * cols(); + set(index / cols(), index % cols(), value); + } + + /** + * Assigns a double to the block. + * + *

This only works for blocks with one row and one column. + * + * @param value Value to assign. + */ + public void setValue(double value) { + assert rows() == 1 && cols() == 1; + + get(0, 0).setValue(value); + } + + /** + * Sets block's internal values. + * + * @param values Double array of values. + */ + public void setValue(double[][] values) { + assert rows() == values.length; + + // Assert all column counts are the same + for (var row : values) { + assert row.length == cols(); + } + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + get(row, col).setValue(values[row][col]); + } + } + } + + /** + * Sets block's internal values. + * + * @param values EJML matrix of values. + */ + public void setValue(SimpleMatrix values) { + assert rows() == values.getNumRows() && cols() == values.getNumCols(); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + get(row, col).setValue(values.get(row, col)); + } + } + } + + /** + * Returns a scalar subblock at the given row and column. + * + * @param row The scalar subblock's row. + * @param col The scalar subblock's column. + * @return A scalar subblock at the given row and column. + */ + public Variable get(int row, int col) { + assert row >= 0 && row < rows(); + assert col >= 0 && col < cols(); + return m_mat.get( + m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step); + } + + /** + * Returns a scalar subblock at the given index. + * + * @param index The scalar subblock's index. + * @return A scalar subblock at the given index. + */ + public Variable get(int index) { + assert index >= 0 && index < rows() * cols(); + return get(index / cols(), index % cols()); + } + + /** + * Returns a slice of the variable matrix. + * + * @param row The row. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(int row, Slice.None colSlice) { + return get(new Slice(row), new Slice(colSlice)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param row The row. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(int row, Slice colSlice) { + return get(new Slice(row), colSlice); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param col The column. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice.None rowSlice, int col) { + return get(new Slice(rowSlice), new Slice(col)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param col The column. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice rowSlice, int col) { + return get(rowSlice, new Slice(col)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) { + return get(new Slice(rowSlice), new Slice(colSlice)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice.None rowSlice, Slice colSlice) { + return get(new Slice(rowSlice), colSlice); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice rowSlice, Slice.None colSlice) { + return get(rowSlice, new Slice(colSlice)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice rowSlice, Slice colSlice) { + int rowSliceLength = rowSlice.adjust(m_rowSliceLength); + int colSliceLength = colSlice.adjust(m_colSliceLength); + return new VariableBlock( + m_mat, + new Slice( + m_rowSlice.start + rowSlice.start * m_rowSlice.step, + m_rowSlice.start + rowSlice.stop * m_rowSlice.step, + rowSlice.step * m_rowSlice.step), + rowSliceLength, + new Slice( + m_colSlice.start + colSlice.start * m_colSlice.step, + m_colSlice.start + colSlice.stop * m_colSlice.step, + colSlice.step * m_colSlice.step), + colSliceLength); + } + + /** + * Returns a block of the variable matrix. + * + * @param rowOffset The row offset of the block selection. + * @param colOffset The column offset of the block selection. + * @param blockRows The number of rows in the block selection. + * @param blockCols The number of columns in the block selection. + * @return A block of the variable matrix. + */ + public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) { + assert rowOffset >= 0 && rowOffset <= rows(); + assert colOffset >= 0 && colOffset <= cols(); + assert blockRows >= 0 && blockRows <= rows() - rowOffset; + assert blockCols >= 0 && blockCols <= cols() - colOffset; + return get( + new Slice(rowOffset, rowOffset + blockRows, 1), + new Slice(colOffset, colOffset + blockCols, 1)); + } + + /** + * Returns a segment of the variable vector. + * + * @param offset The offset of the segment. + * @param length The length of the segment. + * @return A segment of the variable vector. + */ + public VariableBlock segment(int offset, int length) { + assert cols() == 1; + assert offset >= 0 && offset < rows(); + assert length >= 0 && length <= rows() - offset; + return block(offset, 0, length, 1); + } + + /** + * Returns a row slice of the variable matrix. + * + * @param row The row to slice. + * @return A row slice of the variable matrix. + */ + public VariableBlock row(int row) { + assert row >= 0 && row < rows(); + return block(row, 0, 1, cols()); + } + + /** + * Returns a column slice of the variable matrix. + * + * @param col The column to slice. + * @return A column slice of the variable matrix. + */ + public VariableBlock col(int col) { + assert col >= 0 && col < cols(); + return block(0, col, rows(), 1); + } + + /** + * Matrix multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix multiplication. + */ + public VariableMatrix times(VariableMatrix rhs) { + assert cols() == rhs.rows(); + + var result = new VariableMatrix(rows(), rhs.cols()); + + for (int i = 0; i < rows(); ++i) { + for (int j = 0; j < rhs.cols(); ++j) { + var sum = new Variable(0.0); + for (int k = 0; k < cols(); ++k) { + sum = sum.plus(get(i, k).times(rhs.get(k, j))); + } + result.set(i, j, sum); + } + } + + return result; + } + + /** + * Matrix multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix multiplication. + */ + public VariableMatrix times(VariableBlock rhs) { + assert cols() == rhs.rows(); + + var result = new VariableMatrix(rows(), rhs.cols()); + + for (int i = 0; i < rows(); ++i) { + for (int j = 0; j < rhs.cols(); ++j) { + var sum = new Variable(0.0); + for (int k = 0; k < cols(); ++k) { + sum = sum.plus(get(i, k).times(rhs.get(k, j))); + } + result.set(i, j, sum); + } + } + + return result; + } + + /** + * Matrix-scalar multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix-scalar multiplication. + */ + public VariableMatrix times(double rhs) { + return times(new Variable(rhs)); + } + + /** + * Matrix-scalar multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix-scalar multiplication. + */ + public VariableMatrix times(Variable rhs) { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).times(rhs)); + } + } + + return result; + } + + /** + * Binary division operator. + * + * @param rhs Operator right-hand side. + * @return Result of division. + */ + public VariableMatrix div(double rhs) { + return div(new Variable(rhs)); + } + + /** + * Binary division operator. + * + * @param rhs Operator right-hand side. + * @return Result of division. + */ + public VariableMatrix div(Variable rhs) { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).div(rhs)); + } + } + + return result; + } + + /** + * Binary addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public VariableMatrix plus(VariableMatrix rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).plus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public VariableMatrix plus(VariableBlock rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).plus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public VariableMatrix plus(SimpleMatrix rhs) { + assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).plus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public VariableMatrix minus(VariableMatrix rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).minus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public VariableMatrix minus(VariableBlock rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).minus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public VariableMatrix minus(SimpleMatrix rhs) { + assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).minus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Unary minus operator. + * + * @return Result of unary minus. + */ + public VariableMatrix unaryMinus() { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).unaryMinus()); + } + } + + return result; + } + + /** + * Returns the transpose of the variable matrix. + * + * @return The transpose of the variable matrix. + */ + public VariableMatrix T() { + var result = new VariableMatrix(cols(), rows()); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + result.set(col, row, get(row, col)); + } + } + + return result; + } + + /** + * Returns the number of rows in the matrix. + * + * @return The number of rows in the matrix. + */ + public int rows() { + return m_rowSliceLength; + } + + /** + * Returns the number of columns in the matrix. + * + * @return The number of columns in the matrix. + */ + public int cols() { + return m_colSliceLength; + } + + /** + * Returns an element of the variable matrix. + * + * @param row The row of the element to return. + * @param col The column of the element to return. + * @return An element of the variable matrix. + */ + public double value(int row, int col) { + return get(row, col).value(); + } + + /** + * Returns an element of the variable block. + * + * @param index The index of the element to return. + * @return An element of the variable block. + */ + public double value(int index) { + return get(index).value(); + } + + /** + * Returns the contents of the variable matrix. + * + * @return The contents of the variable matrix. + */ + public SimpleMatrix value() { + var result = new SimpleMatrix(rows(), cols()); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + result.set(row, col, value(row, col)); + } + } + + return result; + } + + /** + * Maps the matrix coefficient-wise with an unary operator. + * + * @param unaryOp The unary operator to use for the map operation. + * @return Result of the unary operator. + */ + public VariableMatrix cwiseMap(UnaryOperator unaryOp) { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + result.set(row, col, unaryOp.apply(get(row, col))); + } + } + + return result; + } + + /** + * Returns number of elements in matrix. + * + * @return Number of elements in matrix. + */ + public int size() { + return rows() * cols(); + } + + @Override + public Iterator iterator() { + return new Iterator<>() { + private int m_index = 0; + + @Override + public boolean hasNext() { + return m_index < VariableBlock.this.size(); + } + + @Override + public Variable next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + return VariableBlock.this.get(m_index++); + } + }; + } + + /** + * Creates a Stream of VariableBlock elements. + * + * @return A Stream of VariableBlock elements. + */ + public Stream stream() { + return StreamSupport.stream(spliterator(), false); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableJNI.java new file mode 100644 index 0000000000..d0de949ded --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableJNI.java @@ -0,0 +1,268 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Variable JNI functions. */ +final class VariableJNI extends WPIMathJNI { + private VariableJNI() { + // Utility class. + } + + /** Constructs a default Variable. */ + static native long createDefault(); + + /** + * Constructs a Variable from a floating point type. + * + * @param value The value of the Variable. + */ + static native long createDouble(double value); + + /** + * Constructs a Variable from an integral type. + * + * @param value The value of the Variable. + */ + static native long createInt(int value); + + /** + * Destructs a Variable. + * + * @param handle Variable handle. + */ + static native void destroy(long handle); + + /** + * Sets Variable's internal value. + * + * @param handle Variable handle. + * @param value The value of the Variable. + */ + static native void setValue(long handle, double value); + + /** + * Variable-Variable multiplication operator. + * + * @param handle Variable handle. + * @param rhs Operator right-hand side. + * @return Result of multiplication. + */ + static native long times(long handle, long rhs); + + /** + * Variable-Variable division operator. + * + * @param handle Variable handle. + * @param rhs Operator right-hand side. + * @return Result of division. + */ + static native long div(long handle, long rhs); + + /** + * Variable-Variable addition operator. + * + * @param handle Variable handle. + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + static native long plus(long handle, long rhs); + + /** + * Variable-Variable subtraction operator. + * + * @param handle Variable handle. + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + static native long minus(long handle, long rhs); + + /** + * Unary minus operator. + * + * @param handle Variable handle. + */ + static native long unaryMinus(long handle); + + /** + * Returns the value of this variable. + * + * @param handle Variable handle. + * @return The value of this variable. + */ + static native double value(long handle); + + /** + * Returns the type of this expression (constant, linear, quadratic, or nonlinear). + * + * @param handle Variable handle. + * @return The type of this expression. + */ + static native int type(long handle); + + /** + * Math.abs() for Variables. + * + * @param x The argument. + */ + static native long abs(long x); + + /** + * Math.acos() for Variables. + * + * @param x The argument. + */ + static native long acos(long x); + + /** + * Math.asin() for Variables. + * + * @param x The argument. + */ + static native long asin(long x); + + /** + * Math.atan() for Variables. + * + * @param x The argument. + */ + static native long atan(long x); + + /** + * Math.atan2() for Variables. + * + * @param y The y argument. + * @param x The x argument. + */ + static native long atan2(long y, long x); + + /** + * Math.cbrt() for Variables. + * + * @param x The argument. + */ + static native long cbrt(long x); + + /** + * Math.cos() for Variables. + * + * @param x The argument. + */ + static native long cos(long x); + + /** + * Math.cosh() for Variables. + * + * @param x The argument. + */ + static native long cosh(long x); + + /** + * Math.exp() for Variables. + * + * @param x The argument. + */ + static native long exp(long x); + + /** + * Math.hypot() for Variables. + * + * @param x The x argument. + * @param y The y argument. + */ + static native long hypot(long x, long y); + + /** + * Math.log() for Variables. + * + * @param x The argument. + */ + static native long log(long x); + + /** + * Math.log10() for Variables. + * + * @param x The argument. + */ + static native long log10(long x); + + /** + * Math.max() for Variables. + * + *

Returns the greater of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + */ + static native long max(long a, long b); + + /** + * Math.min() for Variables. + * + *

Returns the lesser of a and b. If the values are equivalent, returns a. + * + * @param a The a argument. + * @param b The b argument. + */ + static native long min(long a, long b); + + /** + * Math.pow() for Variables. + * + * @param base The base. + * @param power The power. + */ + static native long pow(long base, long power); + + /** + * Math.signum() for Variables. + * + * @param x The argument. + */ + static native long signum(long x); + + /** + * Math.sin() for Variables. + * + * @param x The argument. + */ + static native long sin(long x); + + /** + * Math.sinh() for Variables. + * + * @param x The argument. + */ + static native long sinh(long x); + + /** + * Math.sqrt() for Variables. + * + * @param x The argument. + */ + static native long sqrt(long x); + + /** + * Math.tan() for Variables. + * + * @param x The argument. + */ + static native long tan(long x); + + /** + * Math.tanh() for Variables. + * + * @param x The argument. + */ + static native long tanh(long x); + + /** + * Returns the total native memory usage of Variables in bytes. + * + * @return The total native memory usage of Variables in bytes. + */ + static native long totalNativeMemoryUsage(); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java new file mode 100644 index 0000000000..1c7a7ffb4d --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java @@ -0,0 +1,1047 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.function.BinaryOperator; +import java.util.function.UnaryOperator; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.ejml.simple.SimpleMatrix; + +/** A matrix of autodiff variables. */ +public class VariableMatrix implements AutoCloseable, Iterable { + private final Variable[] m_storage; + private int m_rows; + private int m_cols; + + /** + * Constructs a VariableMatrix from Variable internal handles. + * + *

This constructor is for internal use only. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + * @param handles Variable handles. + */ + public VariableMatrix(int rows, int cols, long[] handles) { + assert handles.length == rows * cols; + + m_rows = rows; + m_cols = cols; + m_storage = new Variable[rows * cols]; + for (int index = 0; index < m_storage.length; ++index) { + m_storage[index] = new Variable(Variable.HANDLE, handles[index]); + } + } + + /** + * Constructs a zero-initialized VariableMatrix column vector with the given rows. + * + * @param rows The number of matrix rows. + */ + public VariableMatrix(int rows) { + this(rows, 1); + } + + /** + * Constructs a zero-initialized VariableMatrix with the given dimensions. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + */ + public VariableMatrix(int rows, int cols) { + m_rows = rows; + m_cols = cols; + m_storage = new Variable[rows * cols]; + for (int index = 0; index < m_storage.length; ++index) { + m_storage[index] = new Variable(); + } + } + + /** + * Constructs a scalar VariableMatrix from a nested list of doubles. + * + * @param list The nested list of Variables. + */ + public VariableMatrix(double[][] list) { + // Get row and column counts for destination matrix + m_rows = list.length; + m_cols = 0; + if (list.length > 0) { + m_cols = list[0].length; + } + + // Assert all column counts are the same + for (var row : list) { + assert row.length == m_cols; + } + + m_storage = new Variable[m_rows * m_cols]; + int index = 0; + for (var row : list) { + for (var elem : row) { + m_storage[index] = new Variable(elem); + ++index; + } + } + } + + /** + * Constructs a scalar VariableMatrix from a nested list of Variables. + * + * @param list The nested list of Variables. + */ + public VariableMatrix(Variable[][] list) { + // Get row and column counts for destination matrix + m_rows = list.length; + m_cols = 0; + if (list.length > 0) { + m_cols = list[0].length; + } + + // Assert all column counts are the same + for (var row : list) { + assert row.length == m_cols; + } + + m_storage = new Variable[m_rows * m_cols]; + int index = 0; + for (var row : list) { + for (var elem : row) { + m_storage[index] = elem; + ++index; + } + } + } + + /** + * Constructs a VariableMatrix from an EJML matrix. + * + * @param values EJML matrix of values. + */ + public VariableMatrix(SimpleMatrix values) { + m_rows = values.getNumRows(); + m_cols = values.getNumCols(); + m_storage = new Variable[m_rows * m_cols]; + for (int row = 0; row < values.getNumRows(); ++row) { + for (int col = 0; col < values.getNumCols(); ++col) { + m_storage[row * m_cols + col] = new Variable(values.get(row, col)); + } + } + } + + /** + * Constructs a scalar VariableMatrix from a Variable. + * + * @param variable Variable. + */ + public VariableMatrix(Variable variable) { + m_rows = 1; + m_cols = 1; + m_storage = new Variable[] {variable}; + } + + /** + * Constructs a VariableMatrix from a VariableBlock. + * + * @param values VariableBlock of values. + */ + public VariableMatrix(VariableBlock values) { + m_rows = values.rows(); + m_cols = values.cols(); + m_storage = new Variable[m_rows * m_cols]; + for (int row = 0; row < m_rows; ++row) { + for (int col = 0; col < m_cols; ++col) { + m_storage[row * m_cols + col] = values.get(row, col); + } + } + } + + @Override + public void close() { + for (int index = 0; index < rows() * cols(); ++index) { + m_storage[index].close(); + } + } + + /** + * Assigns a double array to a VariableMatrix. + * + * @param values Double array of values. + * @return This VariableMatrix. + */ + public VariableMatrix set(double[][] values) { + assert rows() == values.length; + + // Assert all column counts are the same + for (var row : values) { + assert row.length == cols(); + } + + for (int row = 0; row < values.length; ++row) { + for (int col = 0; col < values[0].length; ++col) { + set(row, col, values[row][col]); + } + } + + return this; + } + + /** + * Assigns an EJML matrix to a VariableMatrix. + * + * @param values EJML matrix of values. + * @return This VariableMatrix. + */ + public VariableMatrix set(SimpleMatrix values) { + assert rows() == values.getNumRows() && cols() == values.getNumCols(); + + for (int row = 0; row < values.getNumRows(); ++row) { + for (int col = 0; col < values.getNumCols(); ++col) { + set(row, col, values.get(row, col)); + } + } + + return this; + } + + /** + * Assigns a VariableMatrix to a VariableMatrix. + * + * @param values VariableMatrix of values. + * @return This VariableMatrix. + */ + public VariableMatrix set(VariableMatrix values) { + assert rows() == values.rows() && cols() == values.cols(); + + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { + set(row, col, values.get(row, col)); + } + } + + return this; + } + + /** + * Assigns a VariableBlock to a VariableMatrix. + * + * @param values VariableBlock of values. + * @return This VariableMatrix. + */ + public VariableMatrix set(VariableBlock values) { + assert rows() == values.rows() && cols() == values.cols(); + + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { + set(row, col, values.get(row, col)); + } + } + + return this; + } + + /** + * Assigns a double to the matrix. + * + *

This only works for matrices with one row and one column. + * + * @param value Value to assign. + * @return This VariableMatrix. + */ + public VariableMatrix set(double value) { + return set(new Variable(value)); + } + + /** + * Assigns a Variable to the matrix. + * + *

This only works for matrices with one row and one column. + * + * @param value Value to assign. + * @return This VariableMatrix. + */ + public VariableMatrix set(Variable value) { + assert rows() == 1 && cols() == 1; + + m_storage[0] = value; + + return this; + } + + /** + * Sets an element to the given value. + * + * @param row The row. + * @param col The column. + * @param value The value. + */ + public void set(int row, int col, Variable value) { + assert row >= 0 && row < rows(); + assert col >= 0 && col < cols(); + m_storage[row * cols() + col] = value; + } + + /** + * Sets an element to the given value. + * + * @param row The row. + * @param col The column. + * @param value The value. + */ + public void set(int row, int col, double value) { + assert row >= 0 && row < rows(); + assert col >= 0 && col < cols(); + m_storage[row * cols() + col] = new Variable(value); + } + + /** + * Sets an element to the given value. + * + * @param index The index of the element. + * @param value The value. + */ + public void set(int index, double value) { + set(index, new Variable(value)); + } + + /** + * Sets an element to the given value. + * + * @param index The index of the element. + * @param value The value. + */ + public void set(int index, Variable value) { + assert index >= 0 && index < rows() * cols(); + m_storage[index] = value; + } + + /** + * Sets the VariableMatrix's internal values. + * + * @param values Double array of values. + */ + public void setValue(double[][] values) { + assert rows() == values.length; + + // Assert all column counts are the same + for (var row : values) { + assert row.length == cols(); + } + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + get(row, col).setValue(values[row][col]); + } + } + } + + /** + * Sets the VariableMatrix's internal values. + * + * @param values EJML matrix of values. + */ + public void setValue(SimpleMatrix values) { + assert rows() == values.getNumRows() && cols() == values.getNumCols(); + + for (int row = 0; row < values.getNumRows(); ++row) { + for (int col = 0; col < values.getNumCols(); ++col) { + get(row, col).setValue(values.get(row, col)); + } + } + } + + /** + * Returns the element at the given row and column. + * + * @param row The row. + * @param col The column. + * @return The element at the given row and column. + */ + public Variable get(int row, int col) { + assert row >= 0 && row < rows(); + assert col >= 0 && col < cols(); + return m_storage[row * cols() + col]; + } + + /** + * Returns the element at the given index. + * + * @param index The index. + * @return The element at the given index. + */ + public Variable get(int index) { + assert index >= 0 && index < rows() * cols(); + return m_storage[index]; + } + + /** + * Returns a slice of the variable matrix. + * + * @param row The row. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(int row, Slice.None colSlice) { + return get(new Slice(row), new Slice(colSlice)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param row The row. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(int row, Slice colSlice) { + return get(new Slice(row), colSlice); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param col The column. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice.None rowSlice, int col) { + return get(new Slice(rowSlice), new Slice(col)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param col The column. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice rowSlice, int col) { + return get(rowSlice, new Slice(col)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) { + return get(new Slice(rowSlice), new Slice(colSlice)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice.None rowSlice, Slice colSlice) { + return get(new Slice(rowSlice), colSlice); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice rowSlice, Slice.None colSlice) { + return get(rowSlice, new Slice(colSlice)); + } + + /** + * Returns a slice of the variable matrix. + * + * @param rowSlice The row slice. + * @param colSlice The column slice. + * @return A slice of the variable matrix. + */ + public VariableBlock get(Slice rowSlice, Slice colSlice) { + int rowSliceLength = rowSlice.adjust(rows()); + int colSliceLength = colSlice.adjust(cols()); + return new VariableBlock(this, rowSlice, rowSliceLength, colSlice, colSliceLength); + } + + /** + * Returns a block of the variable matrix. + * + * @param rowOffset The row offset of the block selection. + * @param colOffset The column offset of the block selection. + * @param blockRows The number of rows in the block selection. + * @param blockCols The number of columns in the block selection. + * @return A block of the variable matrix. + */ + public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) { + assert rowOffset >= 0 && rowOffset <= rows(); + assert colOffset >= 0 && colOffset <= cols(); + assert blockRows >= 0 && blockRows <= rows() - rowOffset; + assert blockCols >= 0 && blockCols <= cols() - colOffset; + return new VariableBlock(this, rowOffset, colOffset, blockRows, blockCols); + } + + /** + * Returns a segment of the variable vector. + * + * @param offset The offset of the segment. + * @param length The length of the segment. + * @return A segment of the variable vector. + */ + public VariableBlock segment(int offset, int length) { + assert cols() == 1; + assert offset >= 0 && offset < rows(); + assert length >= 0 && length <= rows() - offset; + return block(offset, 0, length, 1); + } + + /** + * Returns a row slice of the variable matrix. + * + * @param row The row to slice. + * @return A row slice of the variable matrix. + */ + public VariableBlock row(int row) { + assert row >= 0 && row < rows(); + return block(row, 0, 1, cols()); + } + + /** + * Returns a column slice of the variable matrix. + * + * @param col The column to slice. + * @return A column slice of the variable matrix. + */ + public VariableBlock col(int col) { + assert col >= 0 && col < cols(); + return block(0, col, rows(), 1); + } + + /** + * Matrix multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix multiplication. + */ + public VariableMatrix times(VariableMatrix rhs) { + assert cols() == rhs.rows(); + + var result = new VariableMatrix(rows(), rhs.cols()); + + for (int i = 0; i < rows(); ++i) { + for (int j = 0; j < rhs.cols(); ++j) { + var sum = new Variable(0.0); + for (int k = 0; k < cols(); ++k) { + sum = sum.plus(get(i, k).times(rhs.get(k, j))); + } + result.set(i, j, sum); + } + } + + return result; + } + + /** + * Matrix multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix multiplication. + */ + public VariableMatrix times(VariableBlock rhs) { + assert cols() == rhs.rows(); + + var result = new VariableMatrix(rows(), rhs.cols()); + + for (int i = 0; i < rows(); ++i) { + for (int j = 0; j < rhs.cols(); ++j) { + var sum = new Variable(0.0); + for (int k = 0; k < cols(); ++k) { + sum = sum.plus(get(i, k).times(rhs.get(k, j))); + } + result.set(i, j, sum); + } + } + + return result; + } + + /** + * Matrix multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix multiplication. + */ + public VariableMatrix times(SimpleMatrix rhs) { + return times(new VariableMatrix(rhs)); + } + + /** + * Matrix-scalar multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix-scalar multiplication. + */ + public VariableMatrix times(double rhs) { + return times(new Variable(rhs)); + } + + /** + * Matrix-scalar multiplication operator. + * + * @param rhs Operator right-hand side. + * @return Result of matrix-scalar multiplication. + */ + public VariableMatrix times(Variable rhs) { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).times(rhs)); + } + } + + return result; + } + + /** + * Binary division operator. + * + * @param rhs Operator right-hand side. + * @return Result of division. + */ + public VariableMatrix div(double rhs) { + return div(new Variable(rhs)); + } + + /** + * Binary division operator. + * + * @param rhs Operator right-hand side. + * @return Result of division. + */ + public VariableMatrix div(Variable rhs) { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).div(rhs)); + } + } + + return result; + } + + /** + * Binary addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public VariableMatrix plus(VariableMatrix rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).plus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public VariableMatrix plus(VariableBlock rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).plus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary addition operator. + * + * @param rhs Operator right-hand side. + * @return Result of addition. + */ + public VariableMatrix plus(SimpleMatrix rhs) { + assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).plus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public VariableMatrix minus(VariableMatrix rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).minus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public VariableMatrix minus(VariableBlock rhs) { + assert rows() == rhs.rows() && cols() == rhs.cols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).minus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Binary subtraction operator. + * + * @param rhs Operator right-hand side. + * @return Result of subtraction. + */ + public VariableMatrix minus(SimpleMatrix rhs) { + assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); + + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).minus(rhs.get(row, col))); + } + } + + return result; + } + + /** + * Unary minus operator. + * + * @return Result of unary minus. + */ + public VariableMatrix unaryMinus() { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { + result.set(row, col, get(row, col).unaryMinus()); + } + } + + return result; + } + + /** + * Returns the transpose of the variable matrix. + * + * @return The transpose of the variable matrix. + */ + public VariableMatrix T() { + var result = new VariableMatrix(cols(), rows()); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + result.set(col, row, get(row, col)); + } + } + + return result; + } + + /** + * Returns the number of rows in the matrix. + * + * @return The number of rows in the matrix. + */ + public int rows() { + return m_rows; + } + + /** + * Returns the number of columns in the matrix. + * + * @return The number of columns in the matrix. + */ + public int cols() { + return m_cols; + } + + /** + * Returns an element of the variable matrix. + * + * @param row The row of the element to return. + * @param col The column of the element to return. + * @return An element of the variable matrix. + */ + public double value(int row, int col) { + return get(row, col).value(); + } + + /** + * Returns an element of the variable matrix. + * + * @param index The index of the element to return. + * @return An element of the variable matrix. + */ + public double value(int index) { + return get(index).value(); + } + + /** + * Returns the contents of the variable matrix. + * + * @return The contents of the variable matrix. + */ + public SimpleMatrix value() { + var result = new SimpleMatrix(rows(), cols()); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + result.set(row, col, value(row, col)); + } + } + + return result; + } + + /** + * Maps the matrix coefficient-wise with an unary operator. + * + * @param unaryOp The unary operator to use for the map operation. + * @return Result of the unary operator. + */ + public VariableMatrix cwiseMap(UnaryOperator unaryOp) { + var result = new VariableMatrix(rows(), cols()); + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { + result.set(row, col, unaryOp.apply(get(row, col))); + } + } + + return result; + } + + /** + * Returns number of elements in matrix. + * + * @return Number of elements in matrix. + */ + public int size() { + return m_storage.length; + } + + @Override + public Iterator iterator() { + return new Iterator<>() { + private int m_index = 0; + + @Override + public boolean hasNext() { + return m_index < VariableMatrix.this.size(); + } + + @Override + public Variable next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + return VariableMatrix.this.get(m_index++); + } + }; + } + + /** + * Creates a Stream of VariableMatrix elements. + * + * @return A Stream of VariableMatrix elements. + */ + public Stream stream() { + return StreamSupport.stream(spliterator(), false); + } + + /** + * Returns a variable matrix filled with zeroes. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + * @return A variable matrix filled with zeroes. + */ + public static VariableMatrix zero(int rows, int cols) { + return new VariableMatrix(new SimpleMatrix(rows, cols)); + } + + /** + * Returns a variable matrix filled with ones. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + * @return A variable matrix filled with ones. + */ + public static VariableMatrix one(int rows, int cols) { + return new VariableMatrix(SimpleMatrix.ones(rows, cols)); + } + + /** + * Returns a variable matrix filled with a constant. + * + * @param rows The number of matrix rows. + * @param cols The number of matrix columns. + * @param constant The constant. + * @return A variable matrix filled with a constant. + */ + public static VariableMatrix constant(int rows, int cols, double constant) { + return new VariableMatrix(SimpleMatrix.filled(rows, cols, constant)); + } + + /** + * Applies a coefficient-wise reduce operation to two matrices. + * + * @param lhs The left-hand side of the binary operator. + * @param rhs The right-hand side of the binary operator. + * @param binaryOp The binary operator to use for the reduce operation. + * @return Result of binary operator. + */ + public static VariableMatrix cwiseReduce( + VariableMatrix lhs, VariableMatrix rhs, BinaryOperator binaryOp) { + assert lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols(); + + var result = new VariableMatrix(lhs.rows(), lhs.cols()); + + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + result.set(row, col, binaryOp.apply(lhs.get(row, col), rhs.get(row, col))); + } + } + + return result; + } + + /** + * Assembles a VariableMatrix from a nested list of blocks. + * + *

Each row's blocks must have the same height, and the assembled block rows must have the same + * width. For example, for the block matrix [[A, B], [C]] to be constructible, the number of rows + * in A and B must match, and the number of columns in [A, B] and [C] must match. + * + * @param list The nested list of blocks. + * @return Block matrix. + */ + @SuppressWarnings("OverloadMethodsDeclarationOrder") + public static VariableMatrix block(VariableMatrix[][] list) { + // Get row and column counts for destination matrix + int rows = 0; + int cols = -1; + for (var row : list) { + if (row.length > 0) { + rows += row[0].rows(); + } + + // Get number of columns in this row + int latestCols = 0; + for (var elem : row) { + // Assert the first and latest row have the same height + assert row[0].rows() == elem.rows(); + + latestCols += elem.cols(); + } + + // If this is the first row, record the column count. Otherwise, assert the + // first and latest column counts are the same. + if (cols == -1) { + cols = latestCols; + } else { + assert cols == latestCols; + } + } + + var result = new VariableMatrix(rows, cols); + + int rowOffset = 0; + for (var row : list) { + int colOffset = 0; + for (var elem : row) { + result.block(rowOffset, colOffset, elem.rows(), elem.cols()).set(elem); + colOffset += elem.cols(); + } + if (row.length > 0) { + rowOffset += row[0].rows(); + } + } + + return result; + } + + /** + * Solves the VariableMatrix equation AX = B for X. + * + * @param A The left-hand side. + * @param B The right-hand side. + * @return The solution X. + */ + public static VariableMatrix solve(VariableMatrix A, VariableMatrix B) { + // m x n * n x p = m x p + assert A.rows() == B.rows(); + + return new VariableMatrix( + A.cols(), + B.cols(), + VariableMatrixJNI.solve(A.getHandles(), A.cols(), B.getHandles(), B.cols())); + } + + /** + * Returns an array of VariableMatrix internal handles in row-major order. + * + * @return Array of VariableMatrix internal handles in row-major order. + */ + long[] getHandles() { + var handles = new long[size()]; + for (int index = 0; index < size(); ++index) { + handles[index] = m_storage[index].getHandle(); + } + return handles; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java new file mode 100644 index 0000000000..f7267c06fe --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java @@ -0,0 +1,25 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Variable JNI functions. */ +final class VariableMatrixJNI extends WPIMathJNI { + private VariableMatrixJNI() { + // Utility class. + } + + /** + * Solves the VariableMatrix equation AX = B for X. + * + * @param A The left-hand side as a flattened row-major matrix. + * @param Acols The number of columns in A. + * @param B The right-hand side as a flattened row-major matrix. + * @param Bcols The number of columns in B. + * @return The solution X as a flattened row-major matrix. + */ + static native long[] solve(long[] A, int Acols, long[] B, int Bcols); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java new file mode 100644 index 0000000000..419cd2cfc0 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java @@ -0,0 +1,55 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import java.util.ArrayDeque; +import java.util.Deque; +import org.wpilib.util.ErrorMessages; +import org.wpilib.util.cleanup.CleanupPool; + +/** + * Cleans up implicitly allocated Variables via try-with-resources. + * + *

This implements a stack of Variable pools containing a default global pool. The user can + * create additional pools via try-with-resources. Variable and VariableMatrix instances will + * register themselves with the latest pool. + * + *

It's strongly recommended to only instantiate this class via try-with-resources so the close() + * methods are always called in the correct order (i.e., nested scopes). + */ +public class VariablePool implements AutoCloseable { + private static Deque s_variablePoolStack = new ArrayDeque(); + + // Default global pool + @SuppressWarnings("PMD.UnusedPrivateField") + private static VariablePool s_globalPool = new VariablePool(); + + // Cleans up Variables in the scope of this VariablePool + private final CleanupPool m_cleanupPool = new CleanupPool(); + + /** Default constructor. */ + @SuppressWarnings("this-escape") + public VariablePool() { + s_variablePoolStack.addFirst(this); + } + + @Override + public void close() { + m_cleanupPool.close(); + s_variablePoolStack.removeFirst(); + } + + /** + * Registers a Variable in the Variable stack for cleanup. + * + * @param variable The Variable to register + * @return The registered Variable + */ + public static Variable register(Variable variable) { + ErrorMessages.requireNonNullParam(variable, "variable", "register"); + s_variablePoolStack.getFirst().m_cleanupPool.register(variable); + return variable; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/controller/ArmFeedforward.java b/wpimath/src/main/java/org/wpilib/math/controller/ArmFeedforward.java index 637af8f1cc..21c742a6f4 100644 --- a/wpimath/src/main/java/org/wpilib/math/controller/ArmFeedforward.java +++ b/wpimath/src/main/java/org/wpilib/math/controller/ArmFeedforward.java @@ -4,9 +4,17 @@ package org.wpilib.math.controller; +import static org.wpilib.math.autodiff.Variable.cos; +import static org.wpilib.math.autodiff.Variable.signum; + +import org.wpilib.math.autodiff.Gradient; +import org.wpilib.math.autodiff.Hessian; +import org.wpilib.math.autodiff.NumericalIntegration; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.autodiff.VariablePool; import org.wpilib.math.controller.proto.ArmFeedforwardProto; import org.wpilib.math.controller.struct.ArmFeedforwardStruct; -import org.wpilib.math.jni.ArmFeedforwardJNI; import org.wpilib.util.protobuf.ProtobufSerializable; import org.wpilib.util.struct.StructSerializable; @@ -191,8 +199,111 @@ public class ArmFeedforward implements ProtobufSerializable, StructSerializable * @return The computed feedforward in volts. */ public double calculate(double currentAngle, double currentVelocity, double nextVelocity) { - return ArmFeedforwardJNI.calculate( - ks, kv, ka, kg, currentAngle, currentVelocity, nextVelocity, m_dt); + // Small kₐ values make the solver ill-conditioned + if (ka < 1e-1) { + double acceleration = (nextVelocity - currentVelocity) / m_dt; + return ks * Math.signum(currentVelocity) + + kv * currentVelocity + + ka * acceleration + + kg * Math.cos(currentAngle); + } + + try (var pool = new VariablePool()) { + // Arm dynamics + var A = new VariableMatrix(new double[][] {{0.0, 1.0}, {0.0, -kv / ka}}); + var B = new VariableMatrix(new double[][] {{0.0}, {1.0 / ka}}); + + var r_k = new VariableMatrix(new double[][] {{currentAngle}, {currentVelocity}}); + + var u_k = new Variable(); + + // Initial guess + double acceleration = (nextVelocity - currentVelocity) / m_dt; + u_k.setValue( + ks * Math.signum(currentVelocity) + + kv * currentVelocity + + ka * acceleration + + kg * Math.cos(currentAngle)); + + var r_k1 = + NumericalIntegration.rk4( + (VariableMatrix x, VariableMatrix u) -> { + var c = + new VariableMatrix( + new Variable[][] { + {new Variable(0.0)}, + {signum(x.get(1)).times(-ks / ka).plus(cos(x.get(0)).times(-kg / ka))} + }); + return A.times(x).plus(B.times(u)).plus(c); + }, + r_k, + new VariableMatrix(u_k), + m_dt); + + // Minimize difference between desired and actual next velocity + var cost = + new Variable(nextVelocity) + .minus(r_k1.get(1)) + .times(new Variable(nextVelocity).minus(r_k1.get(1))); + + // Refine solution via Newton's method + { + var xAD = u_k; + double x = xAD.value(); + + var gradientF = new Gradient(cost, xAD); + var g = gradientF.value(); + + var hessianF = new Hessian(cost, xAD); + var H = hessianF.value(); + + double error_k = Double.POSITIVE_INFINITY; + double error_k1 = Math.abs(g.get(0, 0)); + + // Loop until error stops decreasing or max iterations is reached + for (int iteration = 0; iteration < 50 && error_k1 < (1.0 - 1e-10) * error_k; ++iteration) { + error_k = error_k1; + + // Iterate via Newton's method. + // + // xₖ₊₁ = xₖ − H⁻¹g + // + // The Hessian is regularized to at least 1e-4. + double p_x = -g.get(0, 0) / Math.max(H.get(0, 0), 1e-4); + + // Shrink step until cost goes down + { + double oldCost = cost.value(); + + double α = 1.0; + double trial_x = x + α * p_x; + + xAD.setValue(trial_x); + + while (cost.value() > oldCost) { + α *= 0.5; + trial_x = x + α * p_x; + + xAD.setValue(trial_x); + } + + x = trial_x; + } + + xAD.setValue(x); + + g = gradientF.value(); + H = hessianF.value(); + + error_k1 = Math.abs(g.get(0, 0)); + } + + hessianF.close(); + gradientF.close(); + } + + return u_k.value(); + } } // Rearranging the main equation from the calculate() method yields the diff --git a/wpimath/src/main/java/org/wpilib/math/geometry/Ellipse2d.java b/wpimath/src/main/java/org/wpilib/math/geometry/Ellipse2d.java index f17dc536df..03b66fee0b 100644 --- a/wpimath/src/main/java/org/wpilib/math/geometry/Ellipse2d.java +++ b/wpimath/src/main/java/org/wpilib/math/geometry/Ellipse2d.java @@ -4,12 +4,14 @@ package org.wpilib.math.geometry; +import static org.wpilib.math.autodiff.Variable.pow; +import static org.wpilib.math.optimization.Constraints.eq; import static org.wpilib.units.Units.Meters; import java.util.Objects; import org.wpilib.math.geometry.proto.Ellipse2dProto; import org.wpilib.math.geometry.struct.Ellipse2dStruct; -import org.wpilib.math.jni.Ellipse2dJNI; +import org.wpilib.math.optimization.Problem; import org.wpilib.math.util.Pair; import org.wpilib.units.measure.Distance; import org.wpilib.util.protobuf.ProtobufSerializable; @@ -224,18 +226,38 @@ public class Ellipse2d implements ProtobufSerializable, StructSerializable { return point; } + // Rotate the point by the inverse of the ellipse's rotation + var rotPoint = + point.rotateAround(m_center.getTranslation(), m_center.getRotation().unaryMinus()); + // Find nearest point - var nearestPoint = new double[2]; - Ellipse2dJNI.nearest( - m_center.getX(), - m_center.getY(), - m_center.getRotation().getRadians(), - m_xSemiAxis, - m_ySemiAxis, - point.getX(), - point.getY(), - nearestPoint); - return new Translation2d(nearestPoint[0], nearestPoint[1]); + try (var problem = new Problem()) { + // Point on ellipse + var x = problem.decisionVariable(); + x.setValue(rotPoint.getX()); + var y = problem.decisionVariable(); + y.setValue(rotPoint.getY()); + + problem.minimize(pow(x.minus(rotPoint.getX()), 2).plus(pow(y.minus(rotPoint.getY()), 2))); + + // (x − x_c)²/a² + (y − y_c)²/b² = 1 + // b²(x − x_c)² + a²(y − y_c)² = a²b² + double a2 = m_xSemiAxis * m_xSemiAxis; + double b2 = m_ySemiAxis * m_ySemiAxis; + problem.subjectTo( + eq( + pow(x.minus(m_center.getX()), 2) + .times(b2) + .plus(pow(y.minus(m_center.getY()), 2).times(a2)), + a2 * b2)); + + problem.solve(); + + rotPoint = new Translation2d(x.value(), y.value()); + } + + // Undo rotation + return rotPoint.rotateAround(m_center.getTranslation(), m_center.getRotation()); } @Override diff --git a/wpimath/src/main/java/org/wpilib/math/jni/ArmFeedforwardJNI.java b/wpimath/src/main/java/org/wpilib/math/jni/ArmFeedforwardJNI.java deleted file mode 100644 index 9b315de857..0000000000 --- a/wpimath/src/main/java/org/wpilib/math/jni/ArmFeedforwardJNI.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -package org.wpilib.math.jni; - -/** ArmFeedforward JNI. */ -public final class ArmFeedforwardJNI extends WPIMathJNI { - /** - * Obtain a feedforward voltage from a single jointed arm feedforward object. - * - *

Constructs an ArmFeedforward object and runs its currentVelocity and nextVelocity overload - * - * @param ks The ArmFeedforward's static gain in volts. - * @param kv The ArmFeedforward's velocity gain in volt seconds per radian. - * @param ka The ArmFeedforward's acceleration gain in volt seconds² per radian. - * @param kg The ArmFeedforward's gravity gain in volts. - * @param currentAngle The current angle in the calculation in radians. - * @param currentVelocity The current velocity in the calculation in radians per second. - * @param nextVelocity The next velocity in the calculation in radians per second. - * @param dt The time between velocity setpoints in seconds. - * @return The calculated feedforward in volts. - */ - public static native double calculate( - double ks, - double kv, - double ka, - double kg, - double currentAngle, - double currentVelocity, - double nextVelocity, - double dt); - - /** Utility class. */ - private ArmFeedforwardJNI() {} -} diff --git a/wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java b/wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java deleted file mode 100644 index d766ddbf63..0000000000 --- a/wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -package org.wpilib.math.jni; - -/** Ellipse2d JNI. */ -public final class Ellipse2dJNI extends WPIMathJNI { - /** - * Returns the nearest point that is contained within the ellipse. - * - *

Constructs an Ellipse2d object and runs its nearest() method. - * - * @param centerX The x coordinate of the center of the ellipse in meters. - * @param centerY The y coordinate of the center of the ellipse in meters. - * @param centerHeading The ellipse's rotation in radians. - * @param xSemiAxis The x semi-axis in meters. - * @param ySemiAxis The y semi-axis in meters. - * @param pointX The x coordinate of the point that this will find the nearest point to. - * @param pointY The y coordinate of the point that this will find the nearest point to. - * @param nearestPoint Array to store nearest point into. - */ - public static native void nearest( - double centerX, - double centerY, - double centerHeading, - double xSemiAxis, - double ySemiAxis, - double pointX, - double pointY, - double[] nearestPoint); - - /** Utility class. */ - private Ellipse2dJNI() {} -} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java b/wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java new file mode 100644 index 0000000000..188ac44bb8 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java @@ -0,0 +1,1472 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import java.util.ArrayList; +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableBlock; +import org.wpilib.math.autodiff.VariableMatrix; + +/** Constraint creation helper functions. */ +public final class Constraints { + /** Utility class. */ + private Constraints() {} + + // == + + /** + * Equality operator that returns an equality constraint for a double and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(double lhs, Variable rhs) { + return eq(new Variable(lhs), rhs); + } + + /** + * Equality operator that returns an equality constraint for a Variable and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(Variable lhs, double rhs) { + return eq(lhs, new Variable(rhs)); + } + + /** + * Equality operator that returns an equality constraint for two Variables. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(Variable lhs, Variable rhs) { + return new EqualityConstraints(new Variable[] {lhs.minus(rhs)}); + } + + /** + * Equality operator that returns an equality constraint for a double and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(double lhs, VariableBlock rhs) { + return eq(new Variable(lhs), new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a Variable and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(Variable lhs, VariableBlock rhs) { + return eq(lhs, new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a double and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(double lhs, VariableMatrix rhs) { + return eq(new Variable(lhs), rhs); + } + + /** + * Equality operator that returns an equality constraint for a Variable and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(Variable lhs, VariableMatrix rhs) { + var constraints = new ArrayList(rhs.rows() * rhs.cols()); + for (int row = 0; row < rhs.rows(); ++row) { + for (int col = 0; col < rhs.cols(); ++col) { + constraints.add(lhs.minus(rhs.get(row, col))); + } + } + + var array = new Variable[constraints.size()]; + return new EqualityConstraints(constraints.toArray(array)); + } + + /** + * Equality operator that returns an equality constraint for a VariableBlock and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableBlock lhs, double rhs) { + return eq(new VariableMatrix(lhs), new Variable(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a VariableBlock and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableBlock lhs, Variable rhs) { + return eq(new VariableMatrix(lhs), rhs); + } + + /** + * Equality operator that returns an equality constraint for a VariableMatrix and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableMatrix lhs, double rhs) { + return eq(lhs, new Variable(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a VariableMatrix and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableMatrix lhs, Variable rhs) { + var constraints = new ArrayList(lhs.rows() * lhs.cols()); + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + constraints.add(lhs.get(row, col).minus(rhs)); + } + } + + var array = new Variable[constraints.size()]; + return new EqualityConstraints(constraints.toArray(array)); + } + + /** + * Equality operator that returns an equality constraint for two VariableBlocks. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableBlock lhs, VariableBlock rhs) { + return eq(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a VariableBlock and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableBlock lhs, VariableMatrix rhs) { + return eq(new VariableMatrix(lhs), rhs); + } + + /** + * Equality operator that returns an equality constraint for a VariableMatrix and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableMatrix lhs, VariableBlock rhs) { + return eq(lhs, new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a double[][] and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(double[][] lhs, VariableBlock rhs) { + return eq(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a SimpleMatrix and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(SimpleMatrix lhs, VariableBlock rhs) { + return eq(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a double array and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(double[][] lhs, VariableMatrix rhs) { + return eq(new VariableMatrix(lhs), rhs); + } + + /** + * Equality operator that returns an equality constraint for a SimpleMatrix and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(SimpleMatrix lhs, VariableMatrix rhs) { + return eq(new VariableMatrix(lhs), rhs); + } + + /** + * Equality operator that returns an equality constraint for a VariableBlock and a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableBlock lhs, double[][] rhs) { + return eq(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a VariableBlock and a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableBlock lhs, SimpleMatrix rhs) { + return eq(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a VariableMatrix and a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableMatrix lhs, double[][] rhs) { + return eq(lhs, new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for a VariableMatrix and a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableMatrix lhs, SimpleMatrix rhs) { + return eq(lhs, new VariableMatrix(rhs)); + } + + /** + * Equality operator that returns an equality constraint for two VariableMatrices. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static EqualityConstraints eq(VariableMatrix lhs, VariableMatrix rhs) { + assert lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols(); + + var constraints = new ArrayList(lhs.rows() * lhs.cols()); + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + constraints.add(lhs.get(row, col).minus(rhs.get(row, col))); + } + } + + var array = new Variable[constraints.size()]; + return new EqualityConstraints(constraints.toArray(array)); + } + + // < + + /** + * Less-than comparison operator that returns an inequality constraint for a double and a + * Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(double lhs, Variable rhs) { + return ge(rhs, new Variable(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a Variable and a + * double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(Variable lhs, double rhs) { + return ge(new Variable(rhs), lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for two Variables. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(Variable lhs, Variable rhs) { + return ge(rhs, lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a double and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(double lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), new Variable(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a Variable and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(Variable lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a double and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(double lhs, VariableMatrix rhs) { + return ge(rhs, new Variable(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a Variable and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(Variable lhs, VariableMatrix rhs) { + return ge(rhs, lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableBlock and a + * double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableBlock lhs, double rhs) { + return ge(new Variable(rhs), new VariableMatrix(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableBlock and a + * Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableBlock lhs, Variable rhs) { + return ge(rhs, new VariableMatrix(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableMatrix and a + * double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableMatrix lhs, double rhs) { + return ge(new Variable(rhs), lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableMatrix and a + * Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableMatrix lhs, Variable rhs) { + return ge(rhs, lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for two VariableBlocks. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableBlock lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), new VariableMatrix(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableBlock and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableBlock lhs, VariableMatrix rhs) { + return ge(rhs, new VariableMatrix(lhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableMatrix and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableMatrix lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), lhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a double array and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(double[][] lhs, VariableBlock rhs) { + return lt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a SimpleMatrix and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(SimpleMatrix lhs, VariableBlock rhs) { + return lt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a double array and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(double[][] lhs, VariableMatrix rhs) { + return lt(new VariableMatrix(lhs), rhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a SimpleMatrix and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(SimpleMatrix lhs, VariableMatrix rhs) { + return lt(new VariableMatrix(lhs), rhs); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableBlock and a + * double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(VariableBlock lhs, double[][] rhs) { + return lt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableBlock and a + * SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(VariableBlock lhs, SimpleMatrix rhs) { + return lt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableMatrix and a + * double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(VariableMatrix lhs, double[][] rhs) { + return lt(lhs, new VariableMatrix(rhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for a VariableMatrix and a + * SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints lt(VariableMatrix lhs, SimpleMatrix rhs) { + return lt(lhs, new VariableMatrix(rhs)); + } + + /** + * Less-than comparison operator that returns an inequality constraint for two VariableMatrices. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints lt(VariableMatrix lhs, VariableMatrix rhs) { + return ge(rhs, lhs); + } + + // <= + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a double + * and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(double lhs, Variable rhs) { + return ge(rhs, new Variable(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a Variable + * and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(Variable lhs, double rhs) { + return ge(new Variable(rhs), lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for two + * Variables. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(Variable lhs, Variable rhs) { + return ge(rhs, lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a double + * and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(double lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), new Variable(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a Variable + * and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(Variable lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a double + * and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(double lhs, VariableMatrix rhs) { + return ge(rhs, new Variable(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a Variable + * and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(Variable lhs, VariableMatrix rhs) { + return ge(rhs, lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableBlock lhs, double rhs) { + return ge(new Variable(rhs), new VariableMatrix(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableBlock lhs, Variable rhs) { + return ge(rhs, new VariableMatrix(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableMatrix lhs, double rhs) { + return ge(new Variable(rhs), lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableMatrix lhs, Variable rhs) { + return ge(rhs, lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for two + * VariableBlocks. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableBlock lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), new VariableMatrix(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableBlock lhs, VariableMatrix rhs) { + return ge(rhs, new VariableMatrix(lhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableMatrix lhs, VariableBlock rhs) { + return ge(new VariableMatrix(rhs), lhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a double + * array and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(double[][] lhs, VariableBlock rhs) { + return le(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * SimpleMatrix and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(SimpleMatrix lhs, VariableBlock rhs) { + return le(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a double + * array and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(double[][] lhs, VariableMatrix rhs) { + return le(new VariableMatrix(lhs), rhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * SimpleMatrix and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(SimpleMatrix lhs, VariableMatrix rhs) { + return le(new VariableMatrix(lhs), rhs); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(VariableBlock lhs, double[][] rhs) { + return le(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(VariableBlock lhs, SimpleMatrix rhs) { + return le(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(VariableMatrix lhs, double[][] rhs) { + return le(lhs, new VariableMatrix(rhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints le(VariableMatrix lhs, SimpleMatrix rhs) { + return le(lhs, new VariableMatrix(rhs)); + } + + /** + * Less-than-or-equal-to comparison operator that returns an inequality constraint for two + * VariableMatrices. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints le(VariableMatrix lhs, VariableMatrix rhs) { + return ge(rhs, lhs); + } + + // > + + /** + * Greater-than comparison operator that returns an inequality constraint for a double and a + * Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(double lhs, Variable rhs) { + return ge(new Variable(lhs), rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a Variable and a + * double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(Variable lhs, double rhs) { + return ge(lhs, new Variable(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for two Variables. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(Variable lhs, Variable rhs) { + return ge(lhs, rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a double and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(double lhs, VariableBlock rhs) { + return ge(new Variable(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a Variable and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(Variable lhs, VariableBlock rhs) { + return ge(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a double and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(double lhs, VariableMatrix rhs) { + return ge(new Variable(lhs), rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a Variable and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(Variable lhs, VariableMatrix rhs) { + return ge(lhs, rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableBlock and + * a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableBlock lhs, double rhs) { + return ge(new VariableMatrix(lhs), new Variable(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableBlock and + * a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableBlock lhs, Variable rhs) { + return ge(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableMatrix and + * a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableMatrix lhs, double rhs) { + return ge(lhs, new Variable(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableMatrix and + * a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableMatrix lhs, Variable rhs) { + return ge(lhs, rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for two VariableBlocks. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableBlock lhs, VariableBlock rhs) { + return ge(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableBlock and + * a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableBlock lhs, VariableMatrix rhs) { + return ge(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableMatrix and + * a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableMatrix lhs, VariableBlock rhs) { + return ge(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a double array and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(double[][] lhs, VariableBlock rhs) { + return gt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a SimpleMatrix and a + * VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(SimpleMatrix lhs, VariableBlock rhs) { + return gt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a double array and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(double[][] lhs, VariableMatrix rhs) { + return gt(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a SimpleMatrix and a + * VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(SimpleMatrix lhs, VariableMatrix rhs) { + return gt(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableBlock and + * a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(VariableBlock lhs, double[][] rhs) { + return gt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableBlock and + * a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(VariableBlock lhs, SimpleMatrix rhs) { + return gt(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableMatrix and + * a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(VariableMatrix lhs, double[][] rhs) { + return gt(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for a VariableMatrix and + * a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints gt(VariableMatrix lhs, SimpleMatrix rhs) { + return gt(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than comparison operator that returns an inequality constraint for two + * VariableMatrices. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints gt(VariableMatrix lhs, VariableMatrix rhs) { + return ge(lhs, rhs); + } + + // >= + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a double + * and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(double lhs, Variable rhs) { + return ge(new Variable(lhs), rhs); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * Variable and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(Variable lhs, double rhs) { + return ge(lhs, new Variable(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for two + * Variables. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(Variable lhs, Variable rhs) { + return new InequalityConstraints(new Variable[] {lhs.minus(rhs)}); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a double + * and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(double lhs, VariableBlock rhs) { + return ge(new Variable(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * Variable and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(Variable lhs, VariableBlock rhs) { + return ge(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a double + * and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(double lhs, VariableMatrix rhs) { + return ge(new Variable(lhs), rhs); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * Variable and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(Variable lhs, VariableMatrix rhs) { + var constraints = new ArrayList(rhs.rows() * rhs.cols()); + for (int row = 0; row < rhs.rows(); ++row) { + for (int col = 0; col < rhs.cols(); ++col) { + constraints.add(lhs.minus(rhs.get(row, col))); + } + } + + var array = new Variable[constraints.size()]; + return new InequalityConstraints(constraints.toArray(array)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableBlock lhs, double rhs) { + return ge(new VariableMatrix(lhs), new Variable(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableBlock lhs, Variable rhs) { + return ge(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a double. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableMatrix lhs, double rhs) { + return ge(lhs, new Variable(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a Variable. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableMatrix lhs, Variable rhs) { + var constraints = new ArrayList(lhs.rows() * lhs.cols()); + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + constraints.add(lhs.get(row, col).minus(rhs)); + } + } + + var array = new Variable[constraints.size()]; + return new InequalityConstraints(constraints.toArray(array)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for two + * VariableBlocks. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableBlock lhs, VariableBlock rhs) { + return ge(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableBlock lhs, VariableMatrix rhs) { + return ge(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableMatrix lhs, VariableBlock rhs) { + return ge(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a double + * array and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(double[][] lhs, VariableBlock rhs) { + return ge(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * SimpleMatrix and a VariableBlock. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(SimpleMatrix lhs, VariableBlock rhs) { + return ge(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a double + * array and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(double[][] lhs, VariableMatrix rhs) { + return ge(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * SimpleMatrix and a VariableMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(SimpleMatrix lhs, VariableMatrix rhs) { + return ge(new VariableMatrix(lhs), rhs); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(VariableBlock lhs, double[][] rhs) { + return ge(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableBlock and a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(VariableBlock lhs, SimpleMatrix rhs) { + return ge(new VariableMatrix(lhs), new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a double array. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(VariableMatrix lhs, double[][] rhs) { + return ge(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for a + * VariableMatrix and a SimpleMatrix. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Equality constraints. + */ + public static InequalityConstraints ge(VariableMatrix lhs, SimpleMatrix rhs) { + return ge(lhs, new VariableMatrix(rhs)); + } + + /** + * Greater-than-or-equal-to comparison operator that returns an inequality constraint for two + * VariableMatrices. + * + * @param lhs Left-hand side. + * @param rhs Right-hand side. + * @return Inequality constraints. + */ + public static InequalityConstraints ge(VariableMatrix lhs, VariableMatrix rhs) { + assert lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols(); + + var constraints = new ArrayList(lhs.rows() * lhs.cols()); + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + constraints.add(lhs.get(row, col).minus(rhs.get(row, col))); + } + } + + var array = new Variable[constraints.size()]; + return new InequalityConstraints(constraints.toArray(array)); + } + + /** + * Helper function for creating bound constraints. + * + * @param l Lower bound. + * @param x Variable to bound. + * @param u Upper bound. + * @return Inequality constraints. + */ + public static InequalityConstraints bounds(double l, Variable x, double u) { + return bounds(l, new VariableMatrix(x), u); + } + + /** + * Helper function for creating bound constraints. + * + * @param l Lower bound. + * @param x Variable to bound. + * @param u Upper bound. + * @return Inequality constraints. + */ + public static InequalityConstraints bounds(double l, VariableMatrix x, double u) { + var ineq1 = le(l, x).constraints; + var ineq2 = le(x, u).constraints; + var result = new Variable[ineq1.length + ineq2.length]; + System.arraycopy(ineq1, 0, result, 0, ineq1.length); + System.arraycopy(ineq2, 0, result, ineq1.length, ineq2.length); + return new InequalityConstraints(result); + } + + /** + * Helper function for creating bound constraints. + * + * @param l Lower bound. + * @param x Variable to bound. + * @param u Upper bound. + * @return Inequality constraints. + */ + public static InequalityConstraints bounds(double l, VariableBlock x, double u) { + return bounds(l, new VariableMatrix(x), u); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/EqualityConstraints.java b/wpimath/src/main/java/org/wpilib/math/optimization/EqualityConstraints.java new file mode 100644 index 0000000000..38c4f45c30 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/EqualityConstraints.java @@ -0,0 +1,23 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import org.wpilib.math.autodiff.Variable; + +/** A vector of equality constraints of the form cₑ(x) = 0. */ +@SuppressWarnings("PMD.ArrayIsStoredDirectly") +public class EqualityConstraints { + /** List of equality constraints. */ + public Variable[] constraints; + + /** + * Constructs an EqualityConstraints. + * + * @param constraints The constraints. + */ + public EqualityConstraints(Variable[] constraints) { + this.constraints = constraints; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/InequalityConstraints.java b/wpimath/src/main/java/org/wpilib/math/optimization/InequalityConstraints.java new file mode 100644 index 0000000000..ad392b04ba --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/InequalityConstraints.java @@ -0,0 +1,23 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import org.wpilib.math.autodiff.Variable; + +/** A vector of inequality constraints of the form cᵢ(x) ≥ 0. */ +@SuppressWarnings("PMD.ArrayIsStoredDirectly") +public class InequalityConstraints { + /** List of inequality constraints. */ + public Variable[] constraints; + + /** + * Constructs an InequalityConstraints. + * + * @param constraints The constraints. + */ + public InequalityConstraints(Variable[] constraints) { + this.constraints = constraints; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/OCP.java b/wpimath/src/main/java/org/wpilib/math/optimization/OCP.java new file mode 100644 index 0000000000..c66242c832 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/OCP.java @@ -0,0 +1,588 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.le; + +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableBlock; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.ocp.ConstraintEvaluationFunction; +import org.wpilib.math.optimization.ocp.DynamicsFunction; +import org.wpilib.math.optimization.ocp.DynamicsType; +import org.wpilib.math.optimization.ocp.TimestepMethod; +import org.wpilib.math.optimization.ocp.TranscriptionMethod; + +/** + * This class allows the user to pose and solve a constrained optimal control problem (OCP) in a + * variety of ways. + * + *

The system is transcripted by one of three methods (direct transcription, direct collocation, + * or single-shooting) and additional constraints can be added. + * + *

In direct transcription, each state is a decision variable constrained to the integrated + * dynamics of the previous state. In direct collocation, the trajectory is modeled as a series of + * cubic polynomials where the centerpoint slope is constrained. In single-shooting, states depend + * explicitly as a function of all previous states and all previous inputs. + * + *

Explicit ODEs are integrated using RK4. + * + *

For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). For discrete state + * transition functions, the function must be in the form xₖ₊₁ = f(t, xₖ, uₖ). + * + *

Direct collocation requires an explicit ODE. Direct transcription and single-shooting can use + * either an ODE or state transition function. + * + *

https://underactuated.mit.edu/trajopt.html goes into more detail on each transcription method. + */ +public class OCP extends Problem { + private int m_numSteps; + + private DynamicsFunction m_dynamics; + private DynamicsType m_dynamicsType; + + private VariableMatrix m_X; + private VariableMatrix m_U; + private VariableMatrix m_DT; + + /** + * Builds an optimization problem using a system evolution function (explicit ODE or discrete + * state transition function). + * + * @param numStates The number of system states. + * @param numInputs The number of system inputs. + * @param dt The timestep for fixed-step integration. + * @param numSteps The number of control points. + * @param dynamics Function representing an explicit or implicit ODE, or a discrete state + * transition function. + *

+ * + * @param dynamicsType The type of system evolution function. + * @param timestepMethod The timestep method. + * @param transcriptionMethod The transcription method. + */ + public OCP( + int numStates, + int numInputs, + double dt, + int numSteps, + BiFunction dynamics, + DynamicsType dynamicsType, + TimestepMethod timestepMethod, + TranscriptionMethod transcriptionMethod) { + this( + numStates, + numInputs, + dt, + numSteps, + (Variable t, VariableMatrix x, VariableMatrix u, Variable _dt) -> dynamics.apply(x, u), + dynamicsType, + timestepMethod, + transcriptionMethod); + } + + /** + * Builds an optimization problem using a system evolution function (explicit ODE or discrete + * state transition function). + * + * @param numStates The number of system states. + * @param numInputs The number of system inputs. + * @param dt The timestep for fixed-step integration. + * @param numSteps The number of control points. + * @param dynamics Function representing an explicit or implicit ODE, or a discrete state + * transition function. + * + * + * @param dynamicsType The type of system evolution function. + * @param timestepMethod The timestep method. + * @param transcriptionMethod The transcription method. + */ + @SuppressWarnings("this-escape") + public OCP( + int numStates, + int numInputs, + double dt, + int numSteps, + DynamicsFunction dynamics, + DynamicsType dynamicsType, + TimestepMethod timestepMethod, + TranscriptionMethod transcriptionMethod) { + m_numSteps = numSteps; + m_dynamics = dynamics; + m_dynamicsType = dynamicsType; + + // u is numSteps + 1 so that the final constraint function evaluation works + m_U = decisionVariable(numInputs, m_numSteps + 1); + + if (timestepMethod == TimestepMethod.FIXED) { + m_DT = new VariableMatrix(1, m_numSteps + 1); + for (int i = 0; i < numSteps + 1; ++i) { + m_DT.set(0, i, dt); + } + } else if (timestepMethod == TimestepMethod.VARIABLE_SINGLE) { + Variable single_dt = decisionVariable(); + single_dt.setValue(dt); + + // Set the member variable matrix to track the decision variable + m_DT = new VariableMatrix(1, m_numSteps + 1); + for (int i = 0; i < numSteps + 1; ++i) { + m_DT.set(0, i, single_dt); + } + } else if (timestepMethod == TimestepMethod.VARIABLE) { + m_DT = decisionVariable(1, m_numSteps + 1); + for (int i = 0; i < numSteps + 1; ++i) { + m_DT.get(0, i).setValue(dt); + } + } + + if (transcriptionMethod == TranscriptionMethod.DIRECT_TRANSCRIPTION) { + m_X = decisionVariable(numStates, m_numSteps + 1); + constrainDirectTranscription(); + } else if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) { + m_X = decisionVariable(numStates, m_numSteps + 1); + constrainDirectCollocation(); + } else if (transcriptionMethod == TranscriptionMethod.SINGLE_SHOOTING) { + // In single-shooting the states aren't decision variables, but instead + // depend on the input and previous states + m_X = new VariableMatrix(numStates, m_numSteps + 1); + constrainSingleShooting(); + } + } + + /** + * Constrains the initial state. + * + * @param initialState the initial state to constrain to. + */ + public void constrainInitialState(double initialState) { + subjectTo(eq(this.initialState(), initialState)); + } + + /** + * Constrains the initial state. + * + * @param initialState the initial state to constrain to. + */ + public void constrainInitialState(Variable initialState) { + subjectTo(eq(this.initialState(), initialState)); + } + + /** + * Constrains the initial state. + * + * @param initialState the initial state to constrain to. + */ + public void constrainInitialState(SimpleMatrix initialState) { + subjectTo(eq(this.initialState(), initialState)); + } + + /** + * Constrains the initial state. + * + * @param initialState the initial state to constrain to. + */ + public void constrainInitialState(VariableMatrix initialState) { + subjectTo(eq(this.initialState(), initialState)); + } + + /** + * Constrains the initial state. + * + * @param initialState the initial state to constrain to. + */ + public void constrainInitialState(VariableBlock initialState) { + subjectTo(eq(this.initialState(), initialState)); + } + + /** + * Constrains the final state. + * + * @param finalState the final state to constrain to. + */ + public void constrainFinalState(double finalState) { + subjectTo(eq(this.finalState(), finalState)); + } + + /** + * Constrains the final state. + * + * @param finalState the final state to constrain to. + */ + public void constrainFinalState(Variable finalState) { + subjectTo(eq(this.finalState(), finalState)); + } + + /** + * Constrains the final state. + * + * @param finalState the final state to constrain to. + */ + public void constrainFinalState(SimpleMatrix finalState) { + subjectTo(eq(this.finalState(), finalState)); + } + + /** + * Constrains the final state. + * + * @param finalState the final state to constrain to. + */ + public void constrainFinalState(VariableMatrix finalState) { + subjectTo(eq(this.finalState(), finalState)); + } + + /** + * Constrains the final state. + * + * @param finalState the final state to constrain to. + */ + public void constrainFinalState(VariableBlock finalState) { + subjectTo(eq(this.finalState(), finalState)); + } + + /** + * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the + * corresponding state and input VariableMatrices. + * + * @param callback The callback f(x, u) where x is the state and u is the input vector. + */ + public void forEachStep(BiConsumer callback) { + for (int i = 0; i < m_numSteps + 1; ++i) { + var x = X().col(i); + var u = U().col(i); + callback.accept(new VariableMatrix(x), new VariableMatrix(u)); + } + } + + /** + * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the + * corresponding state and input VariableMatrices. + * + * @param callback The callback f(t, x, u, dt) where t is time, x is the state vector, u is the + * input vector, and dt is the timestep duration. + */ + public void forEachStep(ConstraintEvaluationFunction callback) { + var time = new Variable(0.0); + + for (int i = 0; i < m_numSteps + 1; ++i) { + var x = X().col(i); + var u = U().col(i); + var dt = this.dt().get(0, i); + callback.accept(time, new VariableMatrix(x), new VariableMatrix(u), dt); + + time = time.plus(dt); + } + } + + /** + * Sets a lower bound on the input. + * + * @param lowerBound The lower bound that inputs must always be above. Must be shaped + * (numInputs)x1. + */ + public void setLowerInputBound(double lowerBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(ge(U().col(i), lowerBound)); + } + } + + /** + * Sets a lower bound on the input. + * + * @param lowerBound The lower bound that inputs must always be above. Must be shaped + * (numInputs)x1. + */ + public void setLowerInputBound(Variable lowerBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(ge(U().col(i), lowerBound)); + } + } + + /** + * Sets a lower bound on the input. + * + * @param lowerBound The lower bound that inputs must always be above. Must be shaped + * (numInputs)x1. + */ + public void setLowerInputBound(SimpleMatrix lowerBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(ge(U().col(i), lowerBound)); + } + } + + /** + * Sets a lower bound on the input. + * + * @param lowerBound The lower bound that inputs must always be above. Must be shaped + * (numInputs)x1. + */ + public void setLowerInputBound(VariableMatrix lowerBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(ge(U().col(i), lowerBound)); + } + } + + /** + * Sets a lower bound on the input. + * + * @param lowerBound The lower bound that inputs must always be above. Must be shaped + * (numInputs)x1. + */ + public void setLowerInputBound(VariableBlock lowerBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(ge(U().col(i), lowerBound)); + } + } + + /** + * Sets an upper bound on the input. + * + * @param upperBound The upper bound that inputs must always be below. Must be shaped + * (numInputs)x1. + */ + public void setUpperInputBound(double upperBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(le(U().col(i), upperBound)); + } + } + + /** + * Sets an upper bound on the input. + * + * @param upperBound The upper bound that inputs must always be below. Must be shaped + * (numInputs)x1. + */ + public void setUpperInputBound(Variable upperBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(le(U().col(i), upperBound)); + } + } + + /** + * Sets an upper bound on the input. + * + * @param upperBound The upper bound that inputs must always be below. Must be shaped + * (numInputs)x1. + */ + public void setUpperInputBound(SimpleMatrix upperBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(le(U().col(i), upperBound)); + } + } + + /** + * Sets an upper bound on the input. + * + * @param upperBound The upper bound that inputs must always be below. Must be shaped + * (numInputs)x1. + */ + public void setUpperInputBound(VariableMatrix upperBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(le(U().col(i), upperBound)); + } + } + + /** + * Sets an upper bound on the input. + * + * @param upperBound The upper bound that inputs must always be below. Must be shaped + * (numInputs)x1. + */ + public void setUpperInputBound(VariableBlock upperBound) { + for (int i = 0; i < m_numSteps + 1; ++i) { + subjectTo(le(U().col(i), upperBound)); + } + } + + /** + * Sets a lower bound on the timestep. + * + * @param minTimestep The minimum timestep in seconds. + */ + public void setMinTimestep(double minTimestep) { + subjectTo(ge(dt(), minTimestep)); + } + + /** + * Sets an upper bound on the timestep. + * + * @param maxTimestep The maximum timestep in seconds. + */ + public void setMaxTimestep(double maxTimestep) { + subjectTo(le(dt(), maxTimestep)); + } + + /** + * Gets the state variables. After the problem is solved, this will contain the optimized + * trajectory. + * + *

Shaped (numStates)x(numSteps+1). + * + * @return The state variable matrix. + */ + public VariableMatrix X() { + return m_X; + } + + /** + * Gets the input variables. After the problem is solved, this will contain the inputs + * corresponding to the optimized trajectory. + * + *

Shaped (numInputs)x(numSteps+1), although the last input step is unused in the trajectory. + * + * @return The input variable matrix. + */ + public VariableMatrix U() { + return m_U; + } + + /** + * Gets the timestep variables. After the problem is solved, this will contain the timesteps + * corresponding to the optimized trajectory. + * + *

Shaped 1x(numSteps+1), although the last timestep is unused in the trajectory. + * + * @return The timestep variable matrix. + */ + public VariableMatrix dt() { + return m_DT; + } + + /** + * Gets the initial state in the trajectory. + * + * @return The initial state of the trajectory. + */ + public VariableMatrix initialState() { + return new VariableMatrix(m_X.col(0)); + } + + /** + * Gets the final state in the trajectory. + * + * @return The final state of the trajectory. + */ + public VariableMatrix finalState() { + return new VariableMatrix(m_X.col(m_numSteps)); + } + + /** + * Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param t0 The initial time. + * @param dt The time over which to integrate. + */ + private static VariableMatrix rk4( + DynamicsFunction f, VariableMatrix x, VariableMatrix u, Variable t0, Variable dt) { + var halfdt = dt.times(0.5); + VariableMatrix k1 = f.apply(t0, x, u, dt); + VariableMatrix k2 = f.apply(t0.plus(halfdt), x.plus(k1.times(halfdt)), u, dt); + VariableMatrix k3 = f.apply(t0.plus(halfdt), x.plus(k2.times(halfdt)), u, dt); + VariableMatrix k4 = f.apply(t0.plus(dt), x.plus(k3.times(dt)), u, dt); + + return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(dt.div(6.0))); + } + + /** Applies direct collocation dynamics constraints. */ + private void constrainDirectCollocation() { + assert m_dynamicsType == DynamicsType.EXPLICIT_ODE; + + var time = new Variable(0.0); + + // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/ + for (int i = 0; i < m_numSteps; ++i) { + Variable h = dt().get(0, i); + + var f = m_dynamics; + + var t_begin = time; + var t_end = t_begin.plus(h); + + var x_begin = X().col(i); + var x_end = X().col(i + 1); + + var u_begin = U().col(i); + var u_end = U().col(i + 1); + + var xdot_begin = + f.apply(t_begin, new VariableMatrix(x_begin), new VariableMatrix(u_begin), h); + var xdot_end = f.apply(t_end, new VariableMatrix(x_end), new VariableMatrix(u_end), h); + var xdot_c = + x_begin + .minus(x_end) + .times(new Variable(-3).div(h.times(2))) + .minus(xdot_begin.plus(xdot_end).times(0.25)); + + var t_c = t_begin.plus(h.times(0.5)); + var x_c = x_begin.plus(x_end).times(0.5).plus(xdot_begin.minus(xdot_end).times(h.div(8))); + var u_c = u_begin.plus(u_end).times(0.5); + + subjectTo(eq(xdot_c, f.apply(t_c, x_c, u_c, h))); + + time = time.plus(h); + } + } + + /** Applies direct transcription dynamics constraints. */ + private void constrainDirectTranscription() { + var time = new Variable(0.0); + + for (int i = 0; i < m_numSteps; ++i) { + var x_begin = X().col(i); + var x_end = X().col(i + 1); + var u = U().col(i); + Variable dt = this.dt().get(0, i); + + if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) { + subjectTo( + eq( + x_end, + rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt))); + } else if (m_dynamicsType == DynamicsType.DISCRETE) { + subjectTo( + eq( + x_end, + m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt))); + } + + time = time.plus(dt); + } + } + + /** Applies single shooting dynamics constraints. */ + private void constrainSingleShooting() { + var time = new Variable(0.0); + + for (int i = 0; i < m_numSteps; ++i) { + var x_begin = X().col(i); + var x_end = X().col(i + 1); + var u = U().col(i); + Variable dt = this.dt().get(0, i); + + if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) { + x_end.set(rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt)); + } else if (m_dynamicsType == DynamicsType.DISCRETE) { + x_end.set(m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt)); + } + + time = time.plus(dt); + } + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/Problem.java b/wpimath/src/main/java/org/wpilib/math/optimization/Problem.java new file mode 100644 index 0000000000..7d5502e52a --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/Problem.java @@ -0,0 +1,312 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import java.util.ArrayList; +import java.util.function.Predicate; +import org.ejml.data.DMatrixRMaj; +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.NativeSparseTriplets; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.autodiff.VariablePool; +import org.wpilib.math.optimization.solver.ExitStatus; +import org.wpilib.math.optimization.solver.IterationInfo; +import org.wpilib.math.optimization.solver.Options; + +/** + * This class allows the user to pose a constrained nonlinear optimization problem in natural + * mathematical notation and solve it. + * + *

This class supports problems of the form: + * + *

+ *       minₓ f(x)
+ * subject to cₑ(x) = 0
+ *            cᵢ(x) ≥ 0
+ * 
+ * + *

where f(x) is the scalar cost function, x is the vector of decision variables (variables the + * solver can tweak to minimize the cost function), cᵢ(x) are the inequality constraints, and cₑ(x) + * are the equality constraints. Constraints are equations or inequalities of the decision variables + * that constrain what values the solver is allowed to use when searching for an optimal solution. + * + *

The nice thing about this class is users don't have to put their system in the form shown + * above manually; they can write it in natural mathematical form and it'll be converted for them. + */ +public class Problem implements AutoCloseable { + private long m_handle; + + // The iteration callbacks + private final ArrayList> m_iterationCallbacks = new ArrayList<>(); + + // Cleans up Variables allocated within Problem's scope + private final VariablePool m_pool = new VariablePool(); + + /** Construct the optimization problem. */ + @SuppressWarnings("this-escape") + public Problem() { + m_handle = ProblemJNI.create(); + } + + @Override + public void close() { + if (m_handle != 0) { + ProblemJNI.destroy(m_handle); + m_handle = 0; + + m_pool.close(); + } + } + + /** + * Creates a decision variable in the optimization problem. + * + *

Decision variables have an initial value of zero. + * + * @return A decision variable in the optimization problem. + */ + public Variable decisionVariable() { + var handles = ProblemJNI.decisionVariable(m_handle, 1, 1); + return new Variable(Variable.HANDLE, handles[0]); + } + + /** + * Creates a column vector of decision variables in the optimization problem. + * + *

Decision variables have an initial value of zero. + * + * @param rows Number of column vector rows. + * @return A column vector of decision variables in the optimization problem. + */ + public VariableMatrix decisionVariable(int rows) { + return decisionVariable(rows, 1); + } + + /** + * Creates a matrix of decision variables in the optimization problem. + * + *

Decision variables have an initial value of zero. + * + * @param rows Number of matrix rows. + * @param cols Number of matrix columns. + * @return A matrix of decision variables in the optimization problem. + */ + public VariableMatrix decisionVariable(int rows, int cols) { + return new VariableMatrix(rows, cols, ProblemJNI.decisionVariable(m_handle, rows, cols)); + } + + /** + * Creates a symmetric matrix of decision variables in the optimization problem. + * + *

Variable instances are reused across the diagonal, which helps reduce problem + * dimensionality. + * + *

Decision variables have an initial value of zero. + * + * @param rows Number of matrix rows. + * @return A symmetric matrix of decision varaibles in the optimization problem. + */ + public VariableMatrix symmetricDecisionVariable(int rows) { + return new VariableMatrix(rows, rows, ProblemJNI.symmetricDecisionVariable(m_handle, rows)); + } + + /** + * Tells the solver to minimize the output of the given cost function. + * + *

Note that this is optional. If only constraints are specified, the solver will find the + * closest solution to the initial conditions that's in the feasible set. + * + * @param cost The cost function to minimize. + */ + public void minimize(Variable cost) { + ProblemJNI.minimize(m_handle, cost.getHandle()); + } + + /** + * Tells the solver to minimize the output of the given cost function. + * + *

Note that this is optional. If only constraints are specified, the solver will find the + * closest solution to the initial conditions that's in the feasible set. + * + * @param cost The cost function to minimize. An assertion is raised if the VariableMatrix isn't + * 1x1. + */ + public void minimize(VariableMatrix cost) { + assert cost.rows() == 1 && cost.cols() == 1; + minimize(cost.get(0, 0)); + } + + /** + * Tells the solver to maximize the output of the given objective function. + * + *

Note that this is optional. If only constraints are specified, the solver will find the + * closest solution to the initial conditions that's in the feasible set. + * + * @param objective The objective function to maximize. + */ + public void maximize(Variable objective) { + ProblemJNI.maximize(m_handle, objective.getHandle()); + } + + /** + * Tells the solver to maximize the output of the given objective function. + * + *

Note that this is optional. If only constraints are specified, the solver will find the + * closest solution to the initial conditions that's in the feasible set. + * + * @param objective The objective function to maximize. An assertion is raised if the + * VariableMatrix isn't 1x1. + */ + public void maximize(VariableMatrix objective) { + assert objective.rows() == 1 && objective.cols() == 1; + maximize(objective.get(0, 0)); + } + + /** + * Tells the solver to solve the problem while satisfying the given equality constraint. + * + * @param constraint The constraint to satisfy. + */ + public void subjectTo(EqualityConstraints constraint) { + var constraintHandles = new long[constraint.constraints.length]; + for (int i = 0; i < constraintHandles.length; ++i) { + constraintHandles[i] = constraint.constraints[i].getHandle(); + } + ProblemJNI.subjectToEq(m_handle, constraintHandles); + } + + /** + * Tells the solver to solve the problem while satisfying the given inequality constraint. + * + * @param constraint The constraint to satisfy. + */ + public void subjectTo(InequalityConstraints constraint) { + var constraintHandles = new long[constraint.constraints.length]; + for (int i = 0; i < constraintHandles.length; ++i) { + constraintHandles[i] = constraint.constraints[i].getHandle(); + } + ProblemJNI.subjectToIneq(m_handle, constraintHandles); + } + + /** + * Returns the cost function's type. + * + * @return The cost function's type. + */ + public ExpressionType costFunctionType() { + return ExpressionType.fromInt(ProblemJNI.costFunctionType(m_handle)); + } + + /** + * Returns the type of the highest order equality constraint. + * + * @return The type of the highest order equality constraint. + */ + public ExpressionType equalityConstraintType() { + return ExpressionType.fromInt(ProblemJNI.equalityConstraintType(m_handle)); + } + + /** + * Returns the type of the highest order inequality constraint. + * + * @return The type of the highest order inequality constraint. + */ + public ExpressionType inequalityConstraintType() { + return ExpressionType.fromInt(ProblemJNI.inequalityConstraintType(m_handle)); + } + + /** + * Solves the optimization problem. The solution will be stored in the original variables used to + * construct the problem. + * + * @return The solver status. + */ + public ExitStatus solve() { + return solve(new Options()); + } + + /** + * Solves the optimization problem. The solution will be stored in the original variables used to + * construct the problem. + * + * @param options Solver options. + * @return The solver status. + */ + public ExitStatus solve(Options options) { + return ExitStatus.fromInt( + ProblemJNI.solve( + this, + m_handle, + options.tolerance, + options.maxIterations, + options.timeout, + options.feasibleIPM, + options.diagnostics)); + } + + /** + * Adds a callback to be called at the beginning of each solver iteration. + * + *

The callback for this overload should return bool. + * + * @param callback The callback. Returning true from the callback causes the solver to exit early + * with the solution it has so far. + */ + public void addCallback(Predicate callback) { + m_iterationCallbacks.add(callback); + } + + /** Clears the registered callbacks. */ + public void clearCallbacks() { + m_iterationCallbacks.clear(); + } + + /** + * Runs the registered callbacks. + * + *

This function is called by native code in ProblemJNI. + * + * @param numEqualityConstraints The number of equality constraints. + * @param numInequalityConstraints The number of inequality constraints. + * @param iteration The solver iteration. + * @param x The decision variable values. + * @param gTriplets Gradient triplets. + * @param HTriplets Hessian triplets. + * @param A_eTriplets Equality constraint Jacobian triplets. + * @param A_iTriplets Inequality constraint Jacobian triplets. + * @return True if the solver shold exit early. + */ + boolean runCallbacks( + int numEqualityConstraints, + int numInequalityConstraints, + int iteration, + double[] x, + NativeSparseTriplets gTriplets, + NativeSparseTriplets HTriplets, + NativeSparseTriplets A_eTriplets, + NativeSparseTriplets A_iTriplets) { + if (m_iterationCallbacks.isEmpty()) { + return false; + } + + var info = + new IterationInfo( + iteration, + new SimpleMatrix(DMatrixRMaj.wrap(x.length, 1, x)), + gTriplets.toSimpleMatrix(x.length, 1), + HTriplets.toSimpleMatrix(x.length, x.length), + A_eTriplets.toSimpleMatrix(numEqualityConstraints, x.length), + A_iTriplets.toSimpleMatrix(numInequalityConstraints, x.length)); + + for (var callback : m_iterationCallbacks) { + if (callback.test(info)) { + return true; + } + } + return false; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java b/wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java new file mode 100644 index 0000000000..c3f22ca8c7 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java @@ -0,0 +1,134 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Problem JNI functions. */ +final class ProblemJNI extends WPIMathJNI { + private ProblemJNI() { + // Utility class. + } + + /** Construct the optimization problem. */ + static native long create(); + + /** + * Destruct the optimization problem. + * + * @param handle Problem handle. + */ + static native void destroy(long handle); + + /** + * Create a matrix of decision variables in the optimization problem. + * + * @param handle Problem handle. + * @param rows Number of matrix rows. + * @param cols Number of matrix columns. + * @return A matrix of decision variables in the optimization problem. + */ + static native long[] decisionVariable(long handle, int rows, int cols); + + /** + * Create a symmetric matrix of decision variables in the optimization problem. + * + *

Variable instances are reused across the diagonal, which helps reduce problem + * dimensionality. + * + * @param handle Problem handle. + * @param rows Number of matrix rows. + * @return A symmetric matrix of decision varaibles in the optimization problem. + */ + static native long[] symmetricDecisionVariable(long handle, int rows); + + /** + * Tells the solver to minimize the output of the given cost function. + * + *

Note that this is optional. If only constraints are specified, the solver will find the + * closest solution to the initial conditions that's in the feasible set. + * + * @param handle Problem handle. + * @param costHandle Variable handle of the cost function to minimize. + */ + static native void minimize(long handle, long costHandle); + + /** + * Tells the solver to maximize the output of the given objective function. + * + *

Note that this is optional. If only constraints are specified, the solver will find the + * closest solution to the initial conditions that's in the feasible set. + * + * @param handle Problem handle. + * @param objectiveHandle Variable handle of the objective function to maximize. + */ + static native void maximize(long handle, long objectiveHandle); + + /** + * Tells the solver to solve the problem while satisfying the given equality constraint. + * + * @param handle Problem handle. + * @param constraintHandles Constraint handles. + */ + static native void subjectToEq(long handle, long[] constraintHandles); + + /** + * Tells the solver to solve the problem while satisfying the given inequality constraint. + * + * @param handle Problem handle. + * @param constraintHandles Constraint handles. + */ + static native void subjectToIneq(long handle, long[] constraintHandles); + + /** + * Returns the cost function's type. + * + * @param handle Problem handle. + * @return The cost function's type. + */ + static native int costFunctionType(long handle); + + /** + * Returns the type of the highest order equality constraint. + * + * @param handle Problem handle. + * @return The type of the highest order equality constraint. + */ + static native int equalityConstraintType(long handle); + + /** + * Returns the type of the highest order inequality constraint. + * + * @param handle Problem handle. + * @return The type of the highest order inequality constraint. + */ + static native int inequalityConstraintType(long handle); + + /** + * Solve the optimization problem. The solution will be stored in the original variables used to + * construct the problem. + * + * @param obj Java Problem object. + * @param handle Problem handle. + * @param tolerance The solver will stop once the error is below this tolerance. + * @param maxIterations The maximum number of solver iterations before returning a solution. + * @param timeout The maximum elapsed wall clock time in seconds before returning a solution. + * @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints + * are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible + * again. This is useful when parts of the problem are ill-conditioned in infeasible regions + * (e.g., square root of a negative value). This can slow or prevent progress toward a + * solution though, so only enable it if necessary. + * @param diagnnostics Enables diagnostic prints. + * @return The solver status. + */ + static native int solve( + Problem obj, + long handle, + double tolerance, + int maxIterations, + double timeout, + boolean feasibleIPM, + boolean diagnostics); +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java new file mode 100644 index 0000000000..8c44700f49 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java @@ -0,0 +1,25 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.ocp; + +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; + +/** + * A callback f(t, x, u, dt) where t is time, x is the state vector, u is the input vector, and dt + * is the timestep duration. + */ +@FunctionalInterface +public interface ConstraintEvaluationFunction { + /** + * Applies this function with the arguments. + * + * @param t Time in seconds. + * @param x State vector. + * @param u Input vector. + * @param dt Timestep duration in seconds. + */ + void accept(Variable t, VariableMatrix x, VariableMatrix u, Variable dt); +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java new file mode 100644 index 0000000000..d60fae255a --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java @@ -0,0 +1,31 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.ocp; + +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; + +/** + * Function representing an explicit or implicit ODE, or a discrete state transition function. + * + *

    + *
  • Explicit: dx/dt = f(t, x, u, *) + *
  • Implicit: f(t, [x dx/dt]', u, *) = 0 + *
  • State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt) + *
+ */ +@FunctionalInterface +public interface DynamicsFunction { + /** + * Applies this function with the arguments and returns the result. + * + * @param t Time in seconds. + * @param x State vector. + * @param u Input vector. + * @param dt Timestep duration in seconds. + * @return The state derivative dx/dt. + */ + VariableMatrix apply(Variable t, VariableMatrix x, VariableMatrix u, Variable dt); +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java new file mode 100644 index 0000000000..e5ccea4763 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java @@ -0,0 +1,20 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.ocp; + +/** Enum describing a type of system dynamics constraints. */ +public enum DynamicsType { + /** The dynamics are a function in the form dx/dt = f(t, x, u). */ + EXPLICIT_ODE(0), + /** The dynamics are a function in the form xₖ₊₁ = f(t, xₖ, uₖ). */ + DISCRETE(1); + + /** DynamicsType value. */ + public final int value; + + DynamicsType(int value) { + this.value = value; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java new file mode 100644 index 0000000000..949082e5fd --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java @@ -0,0 +1,22 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.ocp; + +/** Enum describing the type of system timestep. */ +public enum TimestepMethod { + /** The timestep is a fixed constant. */ + FIXED(0), + /** The timesteps are allowed to vary as independent decision variables. */ + VARIABLE(1), + /** The timesteps are equal length but allowed to vary as a single decision variable. */ + VARIABLE_SINGLE(2); + + /** TimestepMethod value. */ + public final int value; + + TimestepMethod(int value) { + this.value = value; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java new file mode 100644 index 0000000000..93bbffb93e --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java @@ -0,0 +1,27 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.ocp; + +/** Enum describing an OCP transcription method. */ +public enum TranscriptionMethod { + /** + * Each state is a decision variable constrained to the integrated dynamics of the previous state. + */ + DIRECT_TRANSCRIPTION(0), + /** + * The trajectory is modeled as a series of cubic polynomials where the centerpoint slope is + * constrained. + */ + DIRECT_COLLOCATION(1), + /** States depend explicitly as a function of all previous states and all previous inputs. */ + SINGLE_SHOOTING(2); + + /** TranscriptionMethod value. */ + public final int value; + + TranscriptionMethod(int value) { + this.value = value; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java b/wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java new file mode 100644 index 0000000000..705bcfd319 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java @@ -0,0 +1,66 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.solver; + +/** Solver exit status. Negative values indicate failure. */ +public enum ExitStatus { + /** Solved the problem to the desired tolerance. */ + SUCCESS(0), + /** The solver returned its solution so far after the user requested a stop. */ + CALLBACK_REQUESTED_STOP(1), + /** The solver determined the problem to be overconstrained and gave up. */ + TOO_FEW_DOFS(-1), + /** The solver determined the problem to be locally infeasible and gave up. */ + LOCALLY_INFEASIBLE(-2), + /** The problem setup frontend determined the problem to have an empty feasible region. */ + GLOBALLY_INFEASIBLE(-3), + /** The linear system factorization failed. */ + FACTORIZATION_FAILED(-4), + /** + * The solver failed to reach the desired tolerance, and feasibility restoration failed to + * converge. + */ + FEASIBILITY_RESTORATION_FAILED(-5), + /** The solver encountered nonfinite initial cost, constraints, or derivatives and gave up. */ + NONFINITE_INITIAL_GUESS(-6), + /** The solver encountered diverging primal iterates xₖ and/or sₖ and gave up. */ + DIVERGING_ITERATES(-7), + /** The solver returned its solution so far after exceeding the maximum number of iterations. */ + MAX_ITERATIONS_EXCEEDED(-8), + /** + * The solver returned its solution so far after exceeding the maximum elapsed wall clock time. + */ + TIMEOUT(-9); + + /** ExitStatus value. */ + public final int value; + + ExitStatus(int value) { + this.value = value; + } + + /** + * Converts integer to its corresponding enum value. + * + * @param x The integer. + * @return The enum value. + */ + public static ExitStatus fromInt(int x) { + return switch (x) { + case 0 -> ExitStatus.SUCCESS; + case 1 -> ExitStatus.CALLBACK_REQUESTED_STOP; + case -1 -> ExitStatus.TOO_FEW_DOFS; + case -2 -> ExitStatus.LOCALLY_INFEASIBLE; + case -3 -> ExitStatus.GLOBALLY_INFEASIBLE; + case -4 -> ExitStatus.FACTORIZATION_FAILED; + case -5 -> ExitStatus.FEASIBILITY_RESTORATION_FAILED; + case -6 -> ExitStatus.NONFINITE_INITIAL_GUESS; + case -7 -> ExitStatus.DIVERGING_ITERATES; + case -8 -> ExitStatus.MAX_ITERATIONS_EXCEEDED; + case -9 -> ExitStatus.TIMEOUT; + default -> null; + }; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java b/wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java new file mode 100644 index 0000000000..08f8866174 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java @@ -0,0 +1,53 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.solver; + +import org.ejml.simple.SimpleMatrix; + +/** Solver iteration information exposed to an iteration callback. */ +public class IterationInfo { + /** The solver iteration. */ + public final int iteration; + + /** The decision variables (dense internal storage). */ + public final SimpleMatrix x; + + /** The gradient of the cost function (sparse internal storage). */ + public final SimpleMatrix g; + + /** The Hessian of the Lagrangian (sparse internal storage). */ + public final SimpleMatrix H; + + /** The equality constraint Jacobian (sparse internal storage). */ + public final SimpleMatrix A_e; + + /** The inequality constraint Jacobian (sparse internal storage). */ + public final SimpleMatrix A_i; + + /** + * Constructs iteration info. + * + * @param iteration The solver iteration. + * @param x The decision variables (dense internal storage). + * @param g The gradient of the cost function (sparse internal storage). + * @param H The Hessian of the Lagrangian (sparse internal storage). + * @param A_e The equality constraint Jacobian (sparse internal storage). + * @param A_i The inequality constraint Jacobian (sparse internal storage). + */ + public IterationInfo( + int iteration, + SimpleMatrix x, + SimpleMatrix g, + SimpleMatrix H, + SimpleMatrix A_e, + SimpleMatrix A_i) { + this.iteration = iteration; + this.x = x; + this.g = g; + this.H = H; + this.A_e = A_e; + this.A_i = A_i; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java b/wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java new file mode 100644 index 0000000000..a83f099871 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java @@ -0,0 +1,98 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.solver; + +/** Solver options. */ +public class Options { + /** The solver will stop once the error is below this tolerance. */ + public double tolerance = 1e-8; + + /** The maximum number of solver iterations before returning a solution. */ + public int maxIterations = 5000; + + /** The maximum elapsed wall clock time in seconds before returning a solution. */ + public double timeout = Double.POSITIVE_INFINITY; + + /** + * Enables the feasible interior-point method. + * + *

When the inequality constraints are all feasible, step sizes are reduced when necessary to + * prevent them becoming infeasible again. This is useful when parts of the problem are + * ill-conditioned in infeasible regions (e.g., square root of a negative value). This can slow or + * prevent progress toward a solution though, so only enable it if necessary. + */ + public boolean feasibleIPM = false; + + /** + * Enables diagnostic output. + * + *

See https://sleipnirgroup.github.io/Sleipnir/md_usage.html#output + * for more information. + */ + public boolean diagnostics = false; + + /** Default options. */ + public Options() {} + + /** + * Set tolerance. + * + * @param tolerance The solver will stop once the error is below this tolerance. + * @return This Options object. + */ + public Options withTolerance(double tolerance) { + this.tolerance = tolerance; + return this; + } + + /** + * Set max iterations. + * + * @param maxIterations The maximum number of solver iterations before returning a solution. + * @return This Options object. + */ + public Options withMaxIterations(int maxIterations) { + this.maxIterations = maxIterations; + return this; + } + + /** + * Set timeout. + * + * @param timeout The maximum elapsed wall clock time in seconds before returning a solution. + * @return This Options object. + */ + public Options withTimeout(double timeout) { + this.timeout = timeout; + return this; + } + + /** + * Enable or disable feasible IPM. + * + * @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints + * are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible + * again. This is useful when parts of the problem are ill-conditioned in infeasible regions + * (e.g., square root of a negative value). This can slow or prevent progress toward a + * solution though, so only enable it if necessary. + * @return This Options object. + */ + public Options withFeasibleIPM(boolean feasibleIPM) { + this.feasibleIPM = feasibleIPM; + return this; + } + + /** + * Enable or disable diagnostics. + * + * @param diagnostics Enables diagnostic prints. + * @return This Options object. + */ + public Options withDiagnostics(boolean diagnostics) { + this.diagnostics = diagnostics; + return this; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java b/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java index ac760488b7..88ad7c6d41 100644 --- a/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java +++ b/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java @@ -8,6 +8,7 @@ import java.util.function.BiFunction; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.UnaryOperator; +import org.ejml.simple.SimpleMatrix; import org.wpilib.math.linalg.Matrix; import org.wpilib.math.numbers.N1; import org.wpilib.math.util.Num; @@ -56,6 +57,30 @@ public final class NumericalIntegration { return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4); } + /** + * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. + * + * @param f The function to integrate. It must take two arguments x and u. + * @param x The initial value of x. + * @param u The value u held constant over the integration period. + * @param dt The time over which to integrate. + * @return the integration of dx/dt = f(x, u) for dt. + */ + public static SimpleMatrix rk4( + BiFunction f, + SimpleMatrix x, + SimpleMatrix u, + double dt) { + var h = dt; + + var k1 = f.apply(x, u); + var k2 = f.apply(x.plus(k1.scale(h * 0.5)), u); + var k3 = f.apply(x.plus(k2.scale(h * 0.5)), u); + var k4 = f.apply(x.plus(k3.scale(h)), u); + + return x.plus(k1.plus(k2.scale(2.0)).plus(k3.scale(2.0)).plus(k4).scale(h / 6.0)); + } + /** * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt. * diff --git a/wpimath/src/main/native/cpp/jni/ArmFeedforwardJNI.cpp b/wpimath/src/main/native/cpp/jni/ArmFeedforwardJNI.cpp deleted file mode 100644 index a6a46e5b37..0000000000 --- a/wpimath/src/main/native/cpp/jni/ArmFeedforwardJNI.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#include - -#include "org_wpilib_math_jni_ArmFeedforwardJNI.h" -#include "wpi/math/controller/ArmFeedforward.hpp" -#include "wpi/util/jni_util.hpp" - -using namespace wpi::util::java; - -extern "C" { - -/* - * Class: org_wpilib_math_jni_ArmFeedforwardJNI - * Method: calculate - * Signature: (DDDDDDDD)D - */ -JNIEXPORT jdouble JNICALL -Java_org_wpilib_math_jni_ArmFeedforwardJNI_calculate - (JNIEnv* env, jclass, jdouble ks, jdouble kv, jdouble ka, jdouble kg, - jdouble currentAngle, jdouble currentVelocity, jdouble nextVelocity, - jdouble dt) -{ - return wpi::math::ArmFeedforward{ - wpi::units::volt_t{ks}, wpi::units::volt_t{kg}, - wpi::units::unit_t{kv}, - wpi::units::unit_t{ka}, - wpi::units::second_t{dt}} - .Calculate(wpi::units::radian_t{currentAngle}, - wpi::units::radians_per_second_t{currentVelocity}, - wpi::units::radians_per_second_t{nextVelocity}) - .value(); -} - -} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/Ellipse2dJNI.cpp b/wpimath/src/main/native/cpp/jni/Ellipse2dJNI.cpp deleted file mode 100644 index 14f7a0c7a4..0000000000 --- a/wpimath/src/main/native/cpp/jni/Ellipse2dJNI.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) FIRST and other WPILib contributors. -// Open Source Software; you can modify and/or share it under the terms of -// the WPILib BSD license file in the root directory of this project. - -#include - -#include "org_wpilib_math_jni_Ellipse2dJNI.h" -#include "wpi/math/geometry/Ellipse2d.hpp" -#include "wpi/util/array.hpp" -#include "wpi/util/jni_util.hpp" - -using namespace wpi::util::java; - -extern "C" { - -/* - * Class: org_wpilib_math_jni_Ellipse2dJNI - * Method: nearest - * Signature: (DDDDDDD[D)V - */ -JNIEXPORT void JNICALL -Java_org_wpilib_math_jni_Ellipse2dJNI_nearest - (JNIEnv* env, jclass, jdouble centerX, jdouble centerY, jdouble centerHeading, - jdouble xSemiAxis, jdouble ySemiAxis, jdouble pointX, jdouble pointY, - jdoubleArray nearestPoint) -{ - auto point = - wpi::math::Ellipse2d{ - wpi::math::Pose2d{wpi::units::meter_t{centerX}, - wpi::units::meter_t{centerY}, - wpi::units::radian_t{centerHeading}}, - wpi::units::meter_t{xSemiAxis}, wpi::units::meter_t{ySemiAxis}} - .Nearest({wpi::units::meter_t{pointX}, wpi::units::meter_t{pointY}}); - - wpi::util::array buf{point.X().value(), point.Y().value()}; - env->SetDoubleArrayRegion(nearestPoint, 0, 2, buf.data()); -} - -} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/SleipnirJNIUtil.hpp b/wpimath/src/main/native/cpp/jni/SleipnirJNIUtil.hpp new file mode 100644 index 0000000000..1cdfb9dff4 --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/SleipnirJNIUtil.hpp @@ -0,0 +1,65 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include + +#include +#include + +#include + +#include "wpi/util/jni_util.hpp" + +namespace wpi::math::detail { + +/** + * Converts Eigen sparse matrix to triplets. + * + * @param env JNI environment. + * @param mat Eigen sparse matrix to convert. + * @return NativeSparseTriplets instance. + */ +template + requires std::derived_from> +jobject GetTriplets(JNIEnv* env, const Derived& mat) { + const int nonZeros = mat.nonZeros(); + + std::vector rows; + rows.reserve(nonZeros); + + std::vector cols; + cols.reserve(nonZeros); + + std::vector values; + values.reserve(nonZeros); + + for (int k = 0; k < mat.outerSize(); ++k) { + for (typename Derived::InnerIterator it{mat, k}; it; ++it) { + rows.emplace_back(it.row()); + cols.emplace_back(it.col()); + values.emplace_back(it.value()); + } + } + + // Find NativeSparseTriplets class + static wpi::util::java::JClass cls{ + env, "org/wpilib/math/autodiff/NativeSparseTriplets"}; + if (!cls) { + return nullptr; + } + + // Find NativeSparseTriplets constructor + static jmethodID ctor = env->GetMethodID(cls, "", "([I[I[D)V"); + if (!ctor) { + return nullptr; + } + + return env->NewObject(cls, ctor, wpi::util::java::MakeJIntArray(env, rows), + wpi::util::java::MakeJIntArray(env, cols), + wpi::util::java::MakeJDoubleArray(env, values)); +} + +} // namespace wpi::math::detail diff --git a/wpimath/src/main/native/cpp/jni/autodiff/GradientJNI.cpp b/wpimath/src/main/native/cpp/jni/autodiff/GradientJNI.cpp new file mode 100644 index 0000000000..399b8c0738 --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/autodiff/GradientJNI.cpp @@ -0,0 +1,90 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include +#include + +#include +#include +#include + +#include "../SleipnirJNIUtil.hpp" +#include "org_wpilib_math_autodiff_GradientJNI.h" +#include "wpi/util/jni_util.hpp" + +using namespace wpi::util::java; + +extern "C" { + +/* + * Class: org_wpilib_math_autodiff_GradientJNI + * Method: create + * Signature: (J[J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_GradientJNI_create + (JNIEnv* env, jclass, jlong variable, jlongArray wrt) +{ + auto& variableObj = *reinterpret_cast*>(variable); + + JSpan wrtSpan{env, wrt}; + slp::VariableMatrix wrtObj(slp::detail::empty, wrtSpan.size(), 1); + for (size_t i = 0; i < wrtSpan.size(); ++i) { + wrtObj[i] = *reinterpret_cast*>(wrtSpan[i]); + } + + return reinterpret_cast( + new slp::Gradient{variableObj, std::move(wrtObj)}); +} + +/* + * Class: org_wpilib_math_autodiff_GradientJNI + * Method: destroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_autodiff_GradientJNI_destroy + (JNIEnv* env, jclass, jlong handle) +{ + delete reinterpret_cast*>(handle); +} + +/* + * Class: org_wpilib_math_autodiff_GradientJNI + * Method: get + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_wpilib_math_autodiff_GradientJNI_get + (JNIEnv* env, jclass, jlong handle) +{ + auto& gradient = *reinterpret_cast*>(handle); + auto g = gradient.get(); + + std::vector varHandles; + varHandles.reserve(g.size()); + for (auto& var : g) { + varHandles.emplace_back( + reinterpret_cast(new slp::Variable{var})); + } + + return MakeJLongArray(env, varHandles); +} + +/* + * Class: org_wpilib_math_autodiff_GradientJNI + * Method: value + * Signature: (J)Ljava/lang/Object; + */ +JNIEXPORT jobject JNICALL +Java_org_wpilib_math_autodiff_GradientJNI_value + (JNIEnv* env, jclass, jlong handle) +{ + auto& gradient = *reinterpret_cast*>(handle); + return wpi::math::detail::GetTriplets(env, gradient.value()); +} + +} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/autodiff/HessianJNI.cpp b/wpimath/src/main/native/cpp/jni/autodiff/HessianJNI.cpp new file mode 100644 index 0000000000..fd569fa8ba --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/autodiff/HessianJNI.cpp @@ -0,0 +1,96 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include +#include + +#include +#include +#include + +#include "../SleipnirJNIUtil.hpp" +#include "org_wpilib_math_autodiff_HessianJNI.h" +#include "wpi/util/jni_util.hpp" + +using namespace wpi::util::java; + +extern "C" { + +/* + * Class: org_wpilib_math_autodiff_HessianJNI + * Method: create + * Signature: (J[J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_HessianJNI_create + (JNIEnv* env, jclass, jlong variable, jlongArray wrt) +{ + auto& variableObj = *reinterpret_cast*>(variable); + + JSpan wrtSpan{env, wrt}; + slp::VariableMatrix wrtObj(slp::detail::empty, wrtSpan.size(), 1); + for (size_t i = 0; i < wrtSpan.size(); ++i) { + wrtObj[i] = *reinterpret_cast*>(wrtSpan[i]); + } + + return reinterpret_cast( + new slp::Hessian{variableObj, + std::move(wrtObj)}); +} + +/* + * Class: org_wpilib_math_autodiff_HessianJNI + * Method: destroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_autodiff_HessianJNI_destroy + (JNIEnv* env, jclass, jlong handle) +{ + delete reinterpret_cast*>( + handle); +} + +/* + * Class: org_wpilib_math_autodiff_HessianJNI + * Method: get + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_wpilib_math_autodiff_HessianJNI_get + (JNIEnv* env, jclass, jlong handle) +{ + auto& hessian = + *reinterpret_cast*>( + handle); + auto H = hessian.get(); + + std::vector varHandles; + varHandles.reserve(H.size()); + for (auto& var : H) { + varHandles.emplace_back( + reinterpret_cast(new slp::Variable{var})); + } + + return MakeJLongArray(env, varHandles); +} + +/* + * Class: org_wpilib_math_autodiff_HessianJNI + * Method: value + * Signature: (J)Ljava/lang/Object; + */ +JNIEXPORT jobject JNICALL +Java_org_wpilib_math_autodiff_HessianJNI_value + (JNIEnv* env, jclass, jlong handle) +{ + auto& hessian = + *reinterpret_cast*>( + handle); + return wpi::math::detail::GetTriplets(env, hessian.value()); +} + +} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/autodiff/JacobianJNI.cpp b/wpimath/src/main/native/cpp/jni/autodiff/JacobianJNI.cpp new file mode 100644 index 0000000000..b627989a44 --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/autodiff/JacobianJNI.cpp @@ -0,0 +1,96 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include +#include + +#include +#include +#include + +#include "../SleipnirJNIUtil.hpp" +#include "org_wpilib_math_autodiff_JacobianJNI.h" +#include "wpi/util/jni_util.hpp" + +using namespace wpi::util::java; + +extern "C" { + +/* + * Class: org_wpilib_math_autodiff_JacobianJNI + * Method: create + * Signature: ([J[J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_JacobianJNI_create + (JNIEnv* env, jclass, jlongArray variables, jlongArray wrt) +{ + JSpan variablesSpan{env, variables}; + slp::VariableMatrix variablesObj(slp::detail::empty, + variablesSpan.size(), 1); + for (size_t i = 0; i < variablesSpan.size(); ++i) { + variablesObj[i] = + *reinterpret_cast*>(variablesSpan[i]); + } + + JSpan wrtSpan{env, wrt}; + slp::VariableMatrix wrtObj(slp::detail::empty, wrtSpan.size(), 1); + for (size_t i = 0; i < wrtSpan.size(); ++i) { + wrtObj[i] = *reinterpret_cast*>(wrtSpan[i]); + } + + return reinterpret_cast( + new slp::Jacobian{std::move(variablesObj), std::move(wrtObj)}); +} + +/* + * Class: org_wpilib_math_autodiff_JacobianJNI + * Method: destroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_autodiff_JacobianJNI_destroy + (JNIEnv* env, jclass, jlong handle) +{ + delete reinterpret_cast*>(handle); +} + +/* + * Class: org_wpilib_math_autodiff_JacobianJNI + * Method: get + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_wpilib_math_autodiff_JacobianJNI_get + (JNIEnv* env, jclass, jlong handle) +{ + auto& jacobian = *reinterpret_cast*>(handle); + auto J = jacobian.get(); + + std::vector varHandles; + varHandles.reserve(J.size()); + for (auto& var : J) { + varHandles.emplace_back( + reinterpret_cast(new slp::Variable{var})); + } + + return MakeJLongArray(env, varHandles); +} + +/* + * Class: org_wpilib_math_autodiff_JacobianJNI + * Method: value + * Signature: (J)Ljava/lang/Object; + */ +JNIEXPORT jobject JNICALL +Java_org_wpilib_math_autodiff_JacobianJNI_value + (JNIEnv* env, jclass, jlong handle) +{ + auto& jacobian = *reinterpret_cast*>(handle); + return wpi::math::detail::GetTriplets(env, jacobian.value()); +} + +} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/autodiff/VariableJNI.cpp b/wpimath/src/main/native/cpp/jni/autodiff/VariableJNI.cpp new file mode 100644 index 0000000000..5a7642de02 --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/autodiff/VariableJNI.cpp @@ -0,0 +1,463 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include + +#include "org_wpilib_math_autodiff_VariableJNI.h" + +extern "C" { + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: createDefault + * Signature: ()J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_createDefault + (JNIEnv* env, jclass) +{ + return reinterpret_cast(new slp::Variable{}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: createDouble + * Signature: (D)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_createDouble + (JNIEnv* env, jclass, jdouble value) +{ + return reinterpret_cast(new slp::Variable{value}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: createInt + * Signature: (I)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_createInt + (JNIEnv* env, jclass, jint value) +{ + return reinterpret_cast(new slp::Variable{value}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: destroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_destroy + (JNIEnv* env, jclass, jlong handle) +{ + delete reinterpret_cast*>(handle); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: setValue + * Signature: (JD)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_setValue + (JNIEnv* env, jclass, jlong handle, jdouble value) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + lhsVar.set_value(value); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: times + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_times + (JNIEnv* env, jclass, jlong handle, jlong rhs) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + auto& rhsVar = *reinterpret_cast*>(rhs); + return reinterpret_cast(new slp::Variable{lhsVar * rhsVar}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: div + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_div + (JNIEnv* env, jclass, jlong handle, jlong rhs) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + auto& rhsVar = *reinterpret_cast*>(rhs); + return reinterpret_cast(new slp::Variable{lhsVar / rhsVar}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: plus + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_plus + (JNIEnv* env, jclass, jlong handle, jlong rhs) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + auto& rhsVar = *reinterpret_cast*>(rhs); + return reinterpret_cast(new slp::Variable{lhsVar + rhsVar}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: minus + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_minus + (JNIEnv* env, jclass, jlong handle, jlong rhs) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + auto& rhsVar = *reinterpret_cast*>(rhs); + return reinterpret_cast(new slp::Variable{lhsVar - rhsVar}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: unaryMinus + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_unaryMinus + (JNIEnv* env, jclass, jlong handle) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + return reinterpret_cast(new slp::Variable{-lhsVar}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: value + * Signature: (J)D + */ +JNIEXPORT jdouble JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_value + (JNIEnv* env, jclass, jlong handle) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + return lhsVar.value(); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: type + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_type + (JNIEnv* env, jclass, jlong handle) +{ + auto& lhsVar = *reinterpret_cast*>(handle); + return static_cast(lhsVar.type()); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: abs + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_abs + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{abs(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: acos + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_acos + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{acos(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: asin + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_asin + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{asin(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: atan + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_atan + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{atan(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: atan2 + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_atan2 + (JNIEnv* env, jclass, jlong y, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + auto& yVar = *reinterpret_cast*>(y); + return reinterpret_cast(new slp::Variable{atan2(yVar, xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: cbrt + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_cbrt + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{cbrt(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: cos + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_cos + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{cos(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: cosh + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_cosh + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{cosh(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: exp + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_exp + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{exp(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: hypot + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_hypot + (JNIEnv* env, jclass, jlong x, jlong y) +{ + auto& xVar = *reinterpret_cast*>(x); + auto& yVar = *reinterpret_cast*>(y); + return reinterpret_cast(new slp::Variable{hypot(xVar, yVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: log + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_log + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{log(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: log10 + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_log10 + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{log10(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: max + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_max + (JNIEnv* env, jclass, jlong a, jlong b) +{ + auto& aVar = *reinterpret_cast*>(a); + auto& bVar = *reinterpret_cast*>(b); + return reinterpret_cast( + new slp::Variable{(slp::max)(aVar, bVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: min + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_min + (JNIEnv* env, jclass, jlong a, jlong b) +{ + auto& aVar = *reinterpret_cast*>(a); + auto& bVar = *reinterpret_cast*>(b); + return reinterpret_cast( + new slp::Variable{(slp::min)(aVar, bVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: pow + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_pow + (JNIEnv* env, jclass, jlong base, jlong power) +{ + auto& baseVar = *reinterpret_cast*>(base); + auto& powerVar = *reinterpret_cast*>(power); + return reinterpret_cast( + new slp::Variable{pow(baseVar, powerVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: signum + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_signum + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{sign(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: sin + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_sin + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{sin(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: sinh + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_sinh + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{sinh(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: sqrt + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_sqrt + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{sqrt(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: tan + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_tan + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{tan(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: tanh + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_tanh + (JNIEnv* env, jclass, jlong x) +{ + auto& xVar = *reinterpret_cast*>(x); + return reinterpret_cast(new slp::Variable{tanh(xVar)}); +} + +/* + * Class: org_wpilib_math_autodiff_VariableJNI + * Method: totalNativeMemoryUsage + * Signature: ()J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_autodiff_VariableJNI_totalNativeMemoryUsage + (JNIEnv* env, jclass) +{ + return slp::global_pool_resource().blocks_in_use() * + sizeof(slp::detail::Expression); +} + +} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp b/wpimath/src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp new file mode 100644 index 0000000000..4065d3066d --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp @@ -0,0 +1,53 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include + +#include +#include + +#include "org_wpilib_math_autodiff_VariableMatrixJNI.h" +#include "wpi/util/jni_util.hpp" + +using namespace wpi::util::java; + +extern "C" { + +/* + * Class: org_wpilib_math_autodiff_VariableMatrixJNI + * Method: solve + * Signature: ([JI[JI)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_wpilib_math_autodiff_VariableMatrixJNI_solve + (JNIEnv* env, jclass, jlongArray A, jint Acols, jlongArray B, jint Bcols) +{ + JSpan ASpan{env, A}; + slp::VariableMatrix AObj(slp::detail::empty, ASpan.size() / Acols, + Acols); + for (size_t i = 0; i < ASpan.size(); ++i) { + AObj[i] = *reinterpret_cast*>(ASpan[i]); + } + + JSpan BSpan{env, B}; + slp::VariableMatrix BObj(slp::detail::empty, BSpan.size() / Bcols, + Bcols); + for (size_t i = 0; i < BSpan.size(); ++i) { + BObj[i] = *reinterpret_cast*>(BSpan[i]); + } + + auto X = slp::solve(AObj, BObj); + + std::vector varHandles; + varHandles.reserve(X.size()); + for (auto& var : X) { + varHandles.emplace_back( + reinterpret_cast(new slp::Variable{var})); + } + return MakeJLongArray(env, varHandles); +} + +} // extern "C" diff --git a/wpimath/src/main/native/cpp/jni/optimization/ProblemJNI.cpp b/wpimath/src/main/native/cpp/jni/optimization/ProblemJNI.cpp new file mode 100644 index 0000000000..dcf3c5d279 --- /dev/null +++ b/wpimath/src/main/native/cpp/jni/optimization/ProblemJNI.cpp @@ -0,0 +1,255 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include + +#include + +#include + +#include "../SleipnirJNIUtil.hpp" +#include "org_wpilib_math_optimization_ProblemJNI.h" +#include "wpi/util/jni_util.hpp" + +using namespace wpi::util::java; + +extern "C" { + +namespace { + +// ProblemJNI_solve() sets these before calling Problem::solve() so the Java +// callback has a valid JNIEnv and object on which to call +// Problem.runCallbacks() +thread_local JNIEnv* callbackEnv; +thread_local jobject callbackObj; + +} // namespace + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: create + * Signature: ()J + */ +JNIEXPORT jlong JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_create + (JNIEnv* env, jclass) +{ + auto problem = new slp::Problem; + + // Configure Java iteration callbacks + problem->add_persistent_callback( + [](const slp::IterationInfo& info) -> bool { + // Find Problem class + static JClass cls{callbackEnv, "org/wpilib/math/optimization/Problem"}; + if (!cls) { + return true; + } + + // Find Problem.runCallbacks() + static jmethodID runCallbacks = callbackEnv->GetMethodID( + cls, "runCallbacks", + "(III[DLorg/wpilib/math/autodiff/NativeSparseTriplets;" + "Lorg/wpilib/math/autodiff/NativeSparseTriplets;" + "Lorg/wpilib/math/autodiff/NativeSparseTriplets;" + "Lorg/wpilib/math/autodiff/NativeSparseTriplets;)Z"); + if (!runCallbacks) { + return true; + } + + // Run Java callbacks + return callbackEnv->CallBooleanMethod( + callbackObj, runCallbacks, info.A_e.rows(), info.A_i.rows(), + info.iteration, MakeJDoubleArray(callbackEnv, info.x), + wpi::math::detail::GetTriplets(callbackEnv, info.g), + wpi::math::detail::GetTriplets(callbackEnv, info.H), + wpi::math::detail::GetTriplets(callbackEnv, info.A_e), + wpi::math::detail::GetTriplets(callbackEnv, info.A_i)); + }); + + return reinterpret_cast(problem); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: destroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_destroy + (JNIEnv* env, jclass, jlong handle) +{ + delete reinterpret_cast*>(handle); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: decisionVariable + * Signature: (JII)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_decisionVariable + (JNIEnv* env, jclass, jlong handle, jint rows, jint cols) +{ + auto& problem = *reinterpret_cast*>(handle); + auto vars = problem.decision_variable(rows, cols); + + std::vector varHandles; + varHandles.reserve(vars.size()); + for (auto& var : vars) { + varHandles.emplace_back( + reinterpret_cast(new slp::Variable{var})); + } + return MakeJLongArray(env, varHandles); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: symmetricDecisionVariable + * Signature: (JI)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_symmetricDecisionVariable + (JNIEnv* env, jclass, jlong handle, jint rows) +{ + auto& problem = *reinterpret_cast*>(handle); + auto vars = problem.symmetric_decision_variable(rows); + + std::vector varHandles; + varHandles.reserve(vars.size()); + for (auto& var : vars) { + varHandles.emplace_back( + reinterpret_cast(new slp::Variable{var})); + } + return MakeJLongArray(env, varHandles); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: minimize + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_minimize + (JNIEnv* env, jclass, jlong handle, jlong costHandle) +{ + auto& problem = *reinterpret_cast*>(handle); + auto& costVar = *reinterpret_cast*>(costHandle); + problem.minimize(costVar); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: maximize + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_maximize + (JNIEnv* env, jclass, jlong handle, jlong objectiveHandle) +{ + auto& problem = *reinterpret_cast*>(handle); + auto& objectiveVar = + *reinterpret_cast*>(objectiveHandle); + problem.maximize(objectiveVar); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: subjectToEq + * Signature: (J[J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_subjectToEq + (JNIEnv* env, jclass, jlong handle, jlongArray constraintHandles) +{ + auto& problem = *reinterpret_cast*>(handle); + JSpan constraintHandlesSpan{env, constraintHandles}; + + for (const auto& constraintHandle : constraintHandlesSpan) { + const auto& constraint = + *reinterpret_cast*>(constraintHandle); + problem.subject_to(constraint == 0.0); + } +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: subjectToIneq + * Signature: (J[J)V + */ +JNIEXPORT void JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_subjectToIneq + (JNIEnv* env, jclass, jlong handle, jlongArray constraintHandles) +{ + auto& problem = *reinterpret_cast*>(handle); + JSpan constraintHandlesSpan{env, constraintHandles}; + + for (const auto& constraintHandle : constraintHandlesSpan) { + const auto& constraint = + *reinterpret_cast*>(constraintHandle); + problem.subject_to(constraint >= 0.0); + } +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: costFunctionType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_costFunctionType + (JNIEnv* env, jclass, jlong handle) +{ + auto& problem = *reinterpret_cast*>(handle); + return static_cast(problem.cost_function_type()); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: equalityConstraintType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_equalityConstraintType + (JNIEnv* env, jclass, jlong handle) +{ + auto& problem = *reinterpret_cast*>(handle); + return static_cast(problem.equality_constraint_type()); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: inequalityConstraintType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_inequalityConstraintType + (JNIEnv* env, jclass, jlong handle) +{ + auto& problem = *reinterpret_cast*>(handle); + return static_cast(problem.inequality_constraint_type()); +} + +/* + * Class: org_wpilib_math_optimization_ProblemJNI + * Method: solve + * Signature: (Ljava/lang/Object;JDIDZZ)I + */ +JNIEXPORT jint JNICALL +Java_org_wpilib_math_optimization_ProblemJNI_solve + (JNIEnv* env, jclass, jobject obj, jlong handle, jdouble tolerance, + jint maxIterations, jdouble timeout, jboolean feasibleIPM, + jboolean diagnostics) +{ + auto& problem = *reinterpret_cast*>(handle); + + callbackEnv = env; + callbackObj = obj; + + slp::Options options{ + tolerance, maxIterations, std::chrono::duration{timeout}, + static_cast(feasibleIPM), static_cast(diagnostics)}; + return static_cast(problem.solve(options)); +} + +} // extern "C" diff --git a/wpimath/src/test/java/org/wpilib/math/DoubleRange.java b/wpimath/src/test/java/org/wpilib/math/DoubleRange.java new file mode 100644 index 0000000000..9d94485122 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/DoubleRange.java @@ -0,0 +1,24 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math; + +import java.util.ArrayList; + +public final class DoubleRange { + private DoubleRange() { + // Utility class. + } + + public static ArrayList range(double start, double end, double step) { + var ret = new ArrayList(); + + int steps = (int) ((end - start) / step); + for (int i = 0; i < steps; ++i) { + ret.add(start + i * step); + } + + return ret; + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/MatrixAssertions.java b/wpimath/src/test/java/org/wpilib/math/MatrixAssertions.java new file mode 100644 index 0000000000..57544601a4 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/MatrixAssertions.java @@ -0,0 +1,41 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.ejml.dense.row.MatrixFeatures_DDRM; +import org.ejml.simple.SimpleMatrix; + +public final class MatrixAssertions { + private MatrixAssertions() { + // Utility class. + } + + /** + * Asserts that two SimpleMatrices are equal. + * + * @param expected Expected value. + * @param actual The value to check against expected. + */ + public static void assertEquals(SimpleMatrix expected, SimpleMatrix actual) { + assertFalse(MatrixFeatures_DDRM.hasUncountable(expected.getDDRM())); + assertTrue(MatrixFeatures_DDRM.isEquals(expected.getDDRM(), actual.getDDRM())); + } + + /** + * Asserts that two SimpleMatrices are equal to within a positive delta. + * + * @param expected Expected value. + * @param actual The value to check against expected. + * @param delta The maximum delta between expected and actual for which both values are still + * considered equal. + */ + public static void assertEquals(SimpleMatrix expected, SimpleMatrix actual, double delta) { + assertFalse(MatrixFeatures_DDRM.hasUncountable(expected.getDDRM())); + assertTrue(MatrixFeatures_DDRM.isEquals(expected.getDDRM(), actual.getDDRM(), delta)); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/jni/ArmFeedforwardJNITest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/GradientJNITest.java similarity index 74% rename from wpimath/src/test/java/org/wpilib/math/jni/ArmFeedforwardJNITest.java rename to wpimath/src/test/java/org/wpilib/math/autodiff/GradientJNITest.java index fe227967d3..36d28594e3 100644 --- a/wpimath/src/test/java/org/wpilib/math/jni/ArmFeedforwardJNITest.java +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/GradientJNITest.java @@ -2,15 +2,15 @@ // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. -package org.wpilib.math.jni; +package org.wpilib.math.autodiff; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import org.junit.jupiter.api.Test; -public class ArmFeedforwardJNITest { +public class GradientJNITest { @Test public void testLink() { - assertDoesNotThrow(ArmFeedforwardJNI::forceLoad); + assertDoesNotThrow(GradientJNI::forceLoad); } } diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/GradientTest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/GradientTest.java new file mode 100644 index 0000000000..e39a503400 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/GradientTest.java @@ -0,0 +1,964 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.MatrixAssertions.assertEquals; +import static org.wpilib.math.autodiff.Variable.abs; +import static org.wpilib.math.autodiff.Variable.acos; +import static org.wpilib.math.autodiff.Variable.asin; +import static org.wpilib.math.autodiff.Variable.atan; +import static org.wpilib.math.autodiff.Variable.atan2; +import static org.wpilib.math.autodiff.Variable.cbrt; +import static org.wpilib.math.autodiff.Variable.cos; +import static org.wpilib.math.autodiff.Variable.cosh; +import static org.wpilib.math.autodiff.Variable.exp; +import static org.wpilib.math.autodiff.Variable.hypot; +import static org.wpilib.math.autodiff.Variable.log; +import static org.wpilib.math.autodiff.Variable.log10; +import static org.wpilib.math.autodiff.Variable.max; +import static org.wpilib.math.autodiff.Variable.min; +import static org.wpilib.math.autodiff.Variable.pow; +import static org.wpilib.math.autodiff.Variable.signum; +import static org.wpilib.math.autodiff.Variable.sin; +import static org.wpilib.math.autodiff.Variable.sinh; +import static org.wpilib.math.autodiff.Variable.sqrt; +import static org.wpilib.math.autodiff.Variable.tan; +import static org.wpilib.math.autodiff.Variable.tanh; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; + +class GradientTest { + @Test + void testTrivialCase() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var a = new Variable(); + a.setValue(10); + var b = new Variable(); + b.setValue(20); + var c = a; + + try (var g_a_a = new Gradient(a, a)) { + assertEquals(1.0, g_a_a.value().get(0, 0)); + } + try (var g_a_b = new Gradient(a, b)) { + assertEquals(0.0, g_a_b.value().get(0, 0)); + } + try (var g_c_a = new Gradient(c, a)) { + assertEquals(1.0, g_c_a.value().get(0, 0)); + } + try (var g_c_b = new Gradient(c, b)) { + assertEquals(0.0, g_c_b.value().get(0, 0)); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testUnaryPlus() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var a = new Variable(); + a.setValue(10); + var c = a.unaryPlus(); + + assertEquals(a.value(), c.value()); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(1.0, g_c_a.value().get(0, 0)); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testUnaryMinus() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var a = new Variable(); + a.setValue(10); + var c = a.unaryMinus(); + + assertEquals(a.unaryMinus().value(), c.value()); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(-1.0, g_c_a.value().get(0, 0)); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testIdenticalVariables() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var a = new Variable(); + a.setValue(10); + var x = a; + var c = a.times(a).plus(x); + + assertEquals(a.value() * a.value() + x.value(), c.value()); + try (var g_x_a = new Gradient(x, a); + var g_c_a = new Gradient(c, a)) { + assertEquals(2 * a.value() + g_x_a.value().get(0, 0), g_c_a.value().get(0, 0)); + } + try (var g_a_x = new Gradient(a, x); + var g_c_x = new Gradient(c, x)) { + assertEquals(2 * a.value() * g_a_x.value().get(0, 0) + 1, g_c_x.value().get(0, 0)); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testElementary() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var a = new Variable(); + a.setValue(1.0); + var b = new Variable(); + b.setValue(2.0); + + var c = a.times(-2); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(-2.0, g_c_a.value().get(0, 0)); + } + + c = a.div(3.0); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(1.0 / 3.0, g_c_a.value().get(0, 0)); + } + + a.setValue(100.0); + b.setValue(200.0); + + c = a.plus(b); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(1.0, g_c_a.value().get(0, 0)); + } + try (var g_c_b = new Gradient(c, b)) { + assertEquals(1.0, g_c_b.value().get(0, 0)); + } + + c = a.minus(b); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(1.0, g_c_a.value().get(0, 0)); + } + try (var g_c_b = new Gradient(c, b)) { + assertEquals(-1.0, g_c_b.value().get(0, 0)); + } + + c = a.unaryMinus().plus(b); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(-1.0, g_c_a.value().get(0, 0)); + } + try (var g_c_b = new Gradient(c, b)) { + assertEquals(1.0, g_c_b.value().get(0, 0)); + } + + c = a.plus(1); + try (var g_c_a = new Gradient(c, a)) { + assertEquals(1.0, g_c_a.value().get(0, 0)); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testTrigonometry() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(0.5); + + // Math.sin(x) + assertEquals(Math.sin(x.value()), sin(x).value()); + + var g = new Gradient(sin(x), x); + assertEquals(Math.cos(x.value()), g.get().value().get(0, 0)); + assertEquals(Math.cos(x.value()), g.value().get(0, 0)); + + // Math.cos(x) + assertEquals(Math.cos(x.value()), cos(x).value()); + + g.close(); + g = new Gradient(cos(x), x); + assertEquals(-Math.sin(x.value()), g.get().value().get(0, 0)); + assertEquals(-Math.sin(x.value()), g.value().get(0, 0)); + + // Math.tan(x) + assertEquals(Math.tan(x.value()), tan(x).value()); + + g.close(); + g = new Gradient(tan(x), x); + assertEquals(1.0 / (Math.cos(x.value()) * Math.cos(x.value())), g.get().value().get(0, 0)); + assertEquals(1.0 / (Math.cos(x.value()) * Math.cos(x.value())), g.value().get(0, 0)); + + // Math.asin(x) + assertEquals(Math.asin(x.value()), asin(x).value(), 1e-15); + + g.close(); + g = new Gradient(asin(x), x); + assertEquals(1.0 / Math.sqrt(1 - x.value() * x.value()), g.get().value().get(0, 0)); + assertEquals(1.0 / Math.sqrt(1 - x.value() * x.value()), g.value().get(0, 0)); + + // Math.acos(x) + assertEquals(Math.acos(x.value()), acos(x).value(), 1e-15); + + g.close(); + g = new Gradient(acos(x), x); + assertEquals(-1.0 / Math.sqrt(1 - x.value() * x.value()), g.get().value().get(0, 0)); + assertEquals(-1.0 / Math.sqrt(1 - x.value() * x.value()), g.value().get(0, 0)); + + // Math.atan(x) + assertEquals(Math.atan(x.value()), atan(x).value(), 1e-15); + + g.close(); + g = new Gradient(atan(x), x); + assertEquals(1.0 / (1 + x.value() * x.value()), g.get().value().get(0, 0)); + assertEquals(1.0 / (1 + x.value() * x.value()), g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testHyperbolic() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + + // sinh(x) + assertEquals(Math.sinh(x.value()), sinh(x).value()); + + var g = new Gradient(sinh(x), x); + assertEquals(Math.cosh(x.value()), g.get().value().get(0, 0), 1e-15); + assertEquals(Math.cosh(x.value()), g.value().get(0, 0), 1e-15); + + // Math.cosh(x) + assertEquals(Math.cosh(x.value()), cosh(x).value(), 1e-15); + + g.close(); + g = new Gradient(cosh(x), x); + assertEquals(Math.sinh(x.value()), g.get().value().get(0, 0)); + assertEquals(Math.sinh(x.value()), g.value().get(0, 0)); + + // tanh(x) + assertEquals(Math.tanh(x.value()), tanh(x).value()); + + g.close(); + g = new Gradient(tanh(x), x); + assertEquals( + 1.0 / (Math.cosh(x.value()) * Math.cosh(x.value())), g.get().value().get(0, 0), 1e-15); + assertEquals(1.0 / (Math.cosh(x.value()) * Math.cosh(x.value())), g.value().get(0, 0), 1e-15); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testExponential() { + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + + // Math.log(x) + assertEquals(Math.log(x.value()), log(x).value()); + + var g = new Gradient(log(x), x); + assertEquals(1.0 / x.value(), g.get().value().get(0, 0)); + assertEquals(1.0 / x.value(), g.value().get(0, 0)); + + // Math.log10(x) + assertEquals(Math.log10(x.value()), log10(x).value()); + + g.close(); + g = new Gradient(log10(x), x); + assertEquals(1.0 / (Math.log(10.0) * x.value()), g.get().value().get(0, 0)); + assertEquals(1.0 / (Math.log(10.0) * x.value()), g.value().get(0, 0)); + + // Math.exp(x) + assertEquals(Math.exp(x.value()), exp(x).value(), 1e-15); + + g.close(); + g = new Gradient(exp(x), x); + assertEquals(Math.exp(x.value()), g.get().value().get(0, 0), 1e-15); + assertEquals(Math.exp(x.value()), g.value().get(0, 0), 1e-15); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testPower() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + // Math.sqrt(x) + assertEquals(Math.sqrt(x.value()), sqrt(x).value()); + + var g = new Gradient(sqrt(x), x); + assertEquals(0.5 / Math.sqrt(x.value()), g.get().value().get(0, 0)); + assertEquals(0.5 / Math.sqrt(x.value()), g.value().get(0, 0)); + + // Math.sqrt(a) + assertEquals(Math.sqrt(a.value()), sqrt(a).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + var g = new Gradient(sqrt(a), a); + assertEquals(0.5 / Math.sqrt(a.value()), g.get().value().get(0, 0)); + assertEquals(0.5 / Math.sqrt(a.value()), g.value().get(0, 0)); + + // Math.cbrt(x) + assertEquals(Math.cbrt(x.value()), cbrt(x).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + var g = new Gradient(cbrt(x), x); + assertEquals( + 1.0 / (3.0 * Math.cbrt(x.value()) * Math.cbrt(x.value())), g.get().value().get(0, 0)); + assertEquals(1.0 / (3.0 * Math.cbrt(x.value()) * Math.cbrt(x.value())), g.value().get(0, 0)); + + // Math.cbrt(a) + assertEquals(Math.cbrt(a.value()), cbrt(a).value(), 1e-15); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + var g = new Gradient(cbrt(a), a); + assertEquals( + 1.0 / (3.0 * Math.cbrt(a.value()) * Math.cbrt(a.value())), + g.get().value().get(0, 0), + 1e-15); + assertEquals( + 1.0 / (3.0 * Math.cbrt(a.value()) * Math.cbrt(a.value())), g.value().get(0, 0), 1e-15); + + // x² + assertEquals(Math.pow(x.value(), 2.0), pow(x, 2.0).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + + var g = new Gradient(pow(x, 2.0), x); + assertEquals(2.0 * x.value(), g.get().value().get(0, 0)); + assertEquals(2.0 * x.value(), g.value().get(0, 0)); + + // 2ˣ + assertEquals(Math.pow(2.0, x.value()), pow(2.0, x).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + + var g = new Gradient(pow(2.0, x), x); + assertEquals(Math.log(2.0) * Math.pow(2.0, x.value()), g.get().value().get(0, 0)); + assertEquals(Math.log(2.0) * Math.pow(2.0, x.value()), g.value().get(0, 0)); + + // xˣ + assertEquals(Math.pow(x.value(), x.value()), pow(x, x).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + var g = new Gradient(pow(x, x), x); + assertEquals( + (Math.log(x.value()) + 1) * Math.pow(x.value(), x.value()), g.get().value().get(0, 0)); + assertEquals((Math.log(x.value()) + 1) * Math.pow(x.value(), x.value()), g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + // y(a) + var y = a.times(2); + assertEquals(2 * a.value(), y.value()); + + var g = new Gradient(y, a); + assertEquals(2.0, g.get().value().get(0, 0)); + assertEquals(2.0, g.value().get(0, 0)); + + // xʸ(x) + assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + // y(a) + var y = a.times(2); + assertEquals(2 * a.value(), y.value()); + + var g = new Gradient(pow(x, y), x); + assertEquals( + y.value() / x.value() * Math.pow(x.value(), y.value()), g.get().value().get(0, 0)); + assertEquals(y.value() / x.value() * Math.pow(x.value(), y.value()), g.value().get(0, 0)); + + // xʸ(a) + assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value()); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + // y(a) + var y = a.times(2); + assertEquals(2 * a.value(), y.value()); + + try (var g = new Gradient(pow(x, y), a); + var g_x_a = new Gradient(x, a); + var g_y_a = new Gradient(y, a)) { + assertEquals( + Math.pow(x.value(), y.value()) + * (y.value() / x.value() * g_x_a.value().get(0, 0) + + Math.log(x.value()) * g_y_a.value().get(0, 0)), + g.get().value().get(0, 0)); + assertEquals( + Math.pow(x.value(), y.value()) + * (y.value() / x.value() * g_x_a.value().get(0, 0) + + Math.log(x.value()) * g_y_a.value().get(0, 0)), + g.value().get(0, 0)); + } + + // xʸ(y) + assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(1.0); + var a = new Variable(); + a.setValue(2.0); + + // y(a) + var y = a.times(2); + assertEquals(2 * a.value(), y.value()); + + var g = new Gradient(pow(x, y), y); + assertEquals(Math.log(x.value()) * Math.pow(x.value(), y.value()), g.get().value().get(0, 0)); + assertEquals(Math.log(x.value()) * Math.pow(x.value(), y.value()), g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testAbs() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + var g = new Gradient(abs(x), x); + + x.setValue(1.0); + assertEquals(Math.abs(x.value()), abs(x).value()); + assertEquals(1.0, g.get().value().get(0, 0)); + assertEquals(1.0, g.value().get(0, 0)); + + x.setValue(-1.0); + assertEquals(Math.abs(x.value()), abs(x).value()); + assertEquals(-1.0, g.get().value().get(0, 0)); + assertEquals(-1.0, g.value().get(0, 0)); + + x.setValue(0.0); + assertEquals(Math.abs(x.value()), abs(x).value()); + assertEquals(0.0, g.get().value().get(0, 0)); + assertEquals(0.0, g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testAtan2() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + var y = new Variable(); + + // Testing atan2 function on (double, var) + x.setValue(1.0); + y.setValue(0.9); + assertEquals(Math.atan2(2.0, x.value()), atan2(2.0, x).value()); + + var g = new Gradient(atan2(2.0, x), x); + assertEquals(-2.0 / (2 * 2 + x.value() * x.value()), g.get().value().get(0, 0), 1e-15); + assertEquals(-2.0 / (2 * 2 + x.value() * x.value()), g.value().get(0, 0), 1e-15); + + // Testing atan2 function on (var, double) + x.setValue(1.0); + y.setValue(0.9); + assertEquals(Math.atan2(x.value(), 2.0), atan2(x, 2.0).value()); + + g.close(); + g = new Gradient(atan2(x, 2.0), x); + assertEquals(2.0 / (2 * 2 + x.value() * x.value()), g.get().value().get(0, 0), 1e-15); + assertEquals(2.0 / (2 * 2 + x.value() * x.value()), g.value().get(0, 0), 1e-15); + + // Testing atan2 function on (var, var) + x.setValue(1.1); + y.setValue(0.9); + assertEquals(Math.atan2(y.value(), x.value()), atan2(y, x).value(), 1e-15); + + g.close(); + g = new Gradient(atan2(y, x), y); + assertEquals( + x.value() / (x.value() * x.value() + y.value() * y.value()), + g.get().value().get(0, 0), + 1e-15); + assertEquals( + x.value() / (x.value() * x.value() + y.value() * y.value()), g.value().get(0, 0), 1e-15); + + g.close(); + g = new Gradient(atan2(y, x), x); + assertEquals( + -y.value() / (x.value() * x.value() + y.value() * y.value()), + g.get().value().get(0, 0), + 1e-15); + assertEquals( + -y.value() / (x.value() * x.value() + y.value() * y.value()), g.value().get(0, 0), 1e-15); + + // Testing atan2 function on (expr, expr) + assertEquals( + 3 * Math.atan2(Math.sin(y.value()), 2 * x.value() + 1), + 3 * atan2(sin(y), x.times(2).plus(1)).value(), + 1e-15); + + g.close(); + g = new Gradient(atan2(sin(y), x.times(2).plus(1)).times(3), y); + assertEquals( + 3 + * (2 * x.value() + 1) + * Math.cos(y.value()) + / ((2 * x.value() + 1) * (2 * x.value() + 1) + + Math.sin(y.value()) * Math.sin(y.value())), + g.get().value().get(0, 0), + 1e-15); + assertEquals( + 3 + * (2 * x.value() + 1) + * Math.cos(y.value()) + / ((2 * x.value() + 1) * (2 * x.value() + 1) + + Math.sin(y.value()) * Math.sin(y.value())), + g.value().get(0, 0), + 1e-15); + + g.close(); + g = new Gradient(atan2(sin(y), x.times(2).plus(1)).times(3), x); + assertEquals( + 3 + * -2 + * Math.sin(y.value()) + / ((2 * x.value() + 1) * (2 * x.value() + 1) + + Math.sin(y.value()) * Math.sin(y.value())), + g.get().value().get(0, 0), + 1e-15); + assertEquals( + 3 + * -2 + * Math.sin(y.value()) + / ((2 * x.value() + 1) * (2 * x.value() + 1) + + Math.sin(y.value()) * Math.sin(y.value())), + g.value().get(0, 0), + 1e-15); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + private double hypot(double x, double y, double z) { + return Math.sqrt(x * x + y * y + z * z); + } + + @Test + void testHypot() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + var y = new Variable(); + + // Testing hypot function on (var, double) + x.setValue(1.8); + y.setValue(1.5); + assertEquals(Math.hypot(x.value(), 2.0), Variable.hypot(x, 2.0).value()); + + var g = new Gradient(Variable.hypot(x, 2.0), x); + assertEquals(x.value() / Math.hypot(x.value(), 2.0), g.get().value().get(0, 0)); + assertEquals(x.value() / Math.hypot(x.value(), 2.0), g.value().get(0, 0)); + + // Testing hypot function on (double, var) + assertEquals(Math.hypot(2.0, y.value()), Variable.hypot(2.0, y).value()); + + g.close(); + g = new Gradient(Variable.hypot(2.0, y), y); + assertEquals(y.value() / Math.hypot(2.0, y.value()), g.get().value().get(0, 0)); + assertEquals(y.value() / Math.hypot(2.0, y.value()), g.value().get(0, 0)); + + // Testing hypot function on (var, var) + x.setValue(1.3); + y.setValue(2.3); + assertEquals(Math.hypot(x.value(), y.value()), Variable.hypot(x, y).value()); + + g.close(); + g = new Gradient(Variable.hypot(x, y), x); + assertEquals(x.value() / Math.hypot(x.value(), y.value()), g.get().value().get(0, 0)); + assertEquals(x.value() / Math.hypot(x.value(), y.value()), g.value().get(0, 0)); + + g.close(); + g = new Gradient(Variable.hypot(x, y), y); + assertEquals(y.value() / Math.hypot(x.value(), y.value()), g.get().value().get(0, 0)); + assertEquals(y.value() / Math.hypot(x.value(), y.value()), g.value().get(0, 0)); + + // Testing hypot function on (expr, expr) + x.setValue(1.3); + y.setValue(2.3); + assertEquals( + Math.hypot(2.0 * x.value(), 3.0 * y.value()), + Variable.hypot(x.times(2.0), y.times(3.0)).value()); + + g.close(); + g = new Gradient(Variable.hypot(x.times(2.0), y.times(3.0)), x); + assertEquals( + 4.0 * x.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), + g.get().value().get(0, 0)); + assertEquals( + 4.0 * x.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), g.value().get(0, 0)); + + g.close(); + g = new Gradient(Variable.hypot(x.times(2.0), y.times(3.0)), y); + assertEquals( + 9.0 * y.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), + g.get().value().get(0, 0)); + assertEquals( + 9.0 * y.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), g.value().get(0, 0)); + + // Testing hypot function on (var, var, var) + var z = new Variable(); + x.setValue(1.3); + y.setValue(2.3); + z.setValue(3.3); + assertEquals(Variable.hypot(x, y, z).value(), hypot(x.value(), y.value(), z.value())); + + g.close(); + g = new Gradient(Variable.hypot(x, y, z), x); + assertEquals(x.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0)); + assertEquals(x.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0)); + + g.close(); + g = new Gradient(Variable.hypot(x, y, z), y); + assertEquals(y.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0)); + assertEquals(y.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0)); + + g.close(); + g = new Gradient(Variable.hypot(x, y, z), z); + assertEquals(z.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0)); + assertEquals(z.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testMax() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(2.0); + + var x2 = x.times(x); + var x3 = x.times(x).times(x); + + try (var g_x3 = new Gradient(x3, x)) { + // Testing lhs < rhs + var g = new Gradient(max(x2, x3), x); + assertEquals(x3.value(), max(x2, x3).value()); + assertEquals(g_x3.value().get(0, 0), g.get().value().get(0, 0)); + assertEquals(g_x3.value().get(0, 0), g.value().get(0, 0)); + + // Testing lhs > rhs + g.close(); + g = new Gradient(max(x3, x2), x); + assertEquals(x3.value(), max(x3, x2).value()); + assertEquals(g_x3.value().get(0, 0), g.get().value().get(0, 0)); + assertEquals(g_x3.value().get(0, 0), g.value().get(0, 0)); + + // Testing lhs == rhs + g.close(); + g = new Gradient(max(x, x), x); + assertEquals(x.value(), max(x, x).value()); + assertEquals(1.0, g.get().value().get(0, 0)); + assertEquals(1.0, g.value().get(0, 0)); + + g.close(); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testMin() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + x.setValue(2.0); + + var x2 = x.times(x); + var x3 = x.times(x).times(x); + + try (var g_x2 = new Gradient(x2, x)) { + // Testing lhs < rhs + var g = new Gradient(min(x2, x3), x); + assertEquals(x2.value(), min(x2, x3).value()); + assertEquals(g_x2.value().get(0, 0), g.get().value().get(0, 0)); + assertEquals(g_x2.value().get(0, 0), g.value().get(0, 0)); + + // Testing lhs > rhs + g.close(); + g = new Gradient(min(x3, x2), x); + assertEquals(x2.value(), min(x3, x2).value()); + assertEquals(g_x2.value().get(0, 0), g.get().value().get(0, 0)); + assertEquals(g_x2.value().get(0, 0), g.value().get(0, 0)); + + // Testing lhs == rhs + g.close(); + g = new Gradient(min(x, x), x); + assertEquals(x.value(), min(x, x).value()); + assertEquals(1.0, g.get().value().get(0, 0)); + assertEquals(1.0, g.value().get(0, 0)); + + g.close(); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testMiscellaneous() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + + // dx/dx + x.setValue(3.0); + assertEquals(Math.abs(x.value()), abs(x).value()); + + var g = new Gradient(x, x); + assertEquals(1.0, g.get().value().get(0, 0)); + assertEquals(1.0, g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testVariableReuse() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var a = new Variable(); + a.setValue(10); + + var b = new Variable(); + b.setValue(20); + + var x = a.times(b); + + var g = new Gradient(x, a); + + assertEquals(20.0, g.get().value().get(0, 0)); + assertEquals(20.0, g.value().get(0, 0)); + + b.setValue(10); + assertEquals(10.0, g.get().value().get(0, 0)); + assertEquals(10.0, g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSignum() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new Variable(); + + // signum(1.0) + x.setValue(1.0); + assertEquals(Math.signum(x.value()), signum(x).value()); + + var g = new Gradient(signum(x), x); + assertEquals(0.0, g.get().value().get(0, 0)); + assertEquals(0.0, g.value().get(0, 0)); + + // signum(-1.0) + x.setValue(-1.0); + assertEquals(Math.signum(x.value()), signum(x).value()); + + g.close(); + g = new Gradient(signum(x), x); + assertEquals(0.0, g.get().value().get(0, 0)); + assertEquals(0.0, g.value().get(0, 0)); + + // signum(0.0) + x.setValue(0.0); + assertEquals(Math.signum(x.value()), signum(x).value()); + + g.close(); + g = new Gradient(signum(x), x); + assertEquals(0.0, g.get().value().get(0, 0)); + assertEquals(0.0, g.value().get(0, 0)); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNonScalar() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(3); + x.get(0).setValue(1); + x.get(1).setValue(2); + x.get(2).setValue(3); + + // y = [x₁ + 3x₂ − 5x₃] + // + // dy/dx = [1 3 −5] + var y = x.get(0).plus(x.get(1).times(3)).minus(x.get(2).times(5)); + var g = new Gradient(y, x); + + var expected_g = new SimpleMatrix(new double[][] {{1.0}, {3.0}, {-5.0}}); + + var g_get_value = g.get().value(); + assertEquals(3, g_get_value.getNumRows()); + assertEquals(1, g_get_value.getNumCols()); + assertEquals(expected_g, g_get_value); + + var g_value = g.value(); + assertEquals(3, g_value.getNumRows()); + assertEquals(1, g_value.getNumCols()); + assertEquals(expected_g, g_value); + + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/jni/Ellipse2dJNITest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/HessianJNITest.java similarity index 75% rename from wpimath/src/test/java/org/wpilib/math/jni/Ellipse2dJNITest.java rename to wpimath/src/test/java/org/wpilib/math/autodiff/HessianJNITest.java index 61cb068cc4..e13ed8d57b 100644 --- a/wpimath/src/test/java/org/wpilib/math/jni/Ellipse2dJNITest.java +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/HessianJNITest.java @@ -2,15 +2,15 @@ // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. -package org.wpilib.math.jni; +package org.wpilib.math.autodiff; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import org.junit.jupiter.api.Test; -public class Ellipse2dJNITest { +public class HessianJNITest { @Test public void testLink() { - assertDoesNotThrow(Ellipse2dJNI::forceLoad); + assertDoesNotThrow(HessianJNI::forceLoad); } } diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/HessianTest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/HessianTest.java new file mode 100644 index 0000000000..dd658cc89c --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/HessianTest.java @@ -0,0 +1,499 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.DoubleRange.range; +import static org.wpilib.math.MatrixAssertions.assertEquals; +import static org.wpilib.math.autodiff.Variable.log; +import static org.wpilib.math.autodiff.Variable.pow; +import static org.wpilib.math.autodiff.Variable.sin; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; + +class HessianTest { + @Test + void testLinear() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // y = x + var x = new VariableMatrix(1); + x.get(0).setValue(3); + var y = x.get(0); + + // dy/dx = 1 + var gradient = new Gradient(y, x.get(0)); + double g = gradient.value().get(0, 0); + assertEquals(1.0, g); + + // d²y/dx² = 0 + var H = new Hessian(y, x); + assertEquals(0.0, H.get().value(0, 0)); + assertEquals(0.0, H.value().get(0, 0)); + + H.close(); + gradient.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testQuadratic() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // y = x² + var x = new VariableMatrix(1); + x.get(0).setValue(3); + var y = x.get(0).times(x.get(0)); + + // dy/dx = 2x = 6 + var gradient = new Gradient(y, x.get(0)); + double g = gradient.value().get(0, 0); + assertEquals(6.0, g); + + // d²y/dx² = 2 + var H = new Hessian(y, x); + assertEquals(2.0, H.get().value(0, 0)); + assertEquals(2.0, H.value().get(0, 0)); + + H.close(); + gradient.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testCubic() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // y = x³ + var x = new VariableMatrix(1); + x.get(0).setValue(3); + var y = x.get(0).times(x.get(0)).times(x.get(0)); + + // dy/dx = 3x² = 27 + var gradient = new Gradient(y, x.get(0)); + double g = gradient.value().get(0, 0); + assertEquals(27.0, g); + + // d²y/dx² = 6x = 18 + var H = new Hessian(y, x); + assertEquals(18.0, H.get().value(0, 0)); + assertEquals(18.0, H.value().get(0, 0)); + + H.close(); + gradient.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testQuartic() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // y = x⁴ + var x = new VariableMatrix(1); + x.get(0).setValue(3); + var y = x.get(0).times(x.get(0)).times(x.get(0)).times(x.get(0)); + + // dy/dx = 4x³ = 108 + var gradient = new Gradient(y, x.get(0)); + double g = gradient.value().get(0, 0); + assertEquals(108.0, g); + + // d²y/dx² = 12x² = 108 + var H = new Hessian(y, x); + assertEquals(108.0, H.get().value(0, 0)); + assertEquals(108.0, H.value().get(0, 0)); + + H.close(); + gradient.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSum() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(5); + for (int i = 0; i < 5; ++i) { + x.get(i).setValue(i + 1); + } + + // y = sum(x) + var y = x.get(0).plus(x.get(1)).plus(x.get(2)).plus(x.get(3)).plus(x.get(4)); + assertEquals(15.0, y.value()); + + var g = new Gradient(y, x); + assertEquals(SimpleMatrix.filled(5, 1, 1.0), g.get().value()); + assertEquals(SimpleMatrix.filled(5, 1, 1.0), g.value()); + + var H = new Hessian(y, x); + assertEquals(SimpleMatrix.filled(5, 5, 0.0), H.get().value()); + assertEquals(SimpleMatrix.filled(5, 5, 0.0), H.value()); + + H.close(); + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSumOfProducts() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(5); + for (int i = 0; i < 5; ++i) { + x.get(i).setValue(i + 1); + } + + // y = ||x||² + var y = x.T().times(x).get(0); + assertEquals(1 * 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, y.value()); + + var g = new Gradient(y, x); + assertEquals(x.value().scale(2), g.get().value()); + assertEquals(x.value().scale(2), g.value()); + + var H = new Hessian(y, x); + + var expected_H = SimpleMatrix.diag(2.0, 2.0, 2.0, 2.0, 2.0); + assertEquals(expected_H, H.get().value()); + assertEquals(expected_H, H.value()); + + H.close(); + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testProductOfSines() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(5); + for (int i = 0; i < 5; ++i) { + x.get(i).setValue(i + 1); + } + + // y = prod(sin(x)) + var y = x.cwiseMap(Variable::sin).stream().reduce(new Variable(1.0), (a, b) -> a.times(b)); + assertEquals( + Math.sin(1) * Math.sin(2) * Math.sin(3) * Math.sin(4) * Math.sin(5), y.value(), 1e-15); + + var g = new Gradient(y, x); + for (int i = 0; i < x.rows(); ++i) { + assertEquals(y.value() / Math.tan(x.get(i).value()), g.get().value(i), 1e-15); + assertEquals(y.value() / Math.tan(x.get(i).value()), g.value().get(i, 0), 1e-15); + } + + var H = new Hessian(y, x); + + var expected_H = new SimpleMatrix(5, 5); + for (int i = 0; i < x.rows(); ++i) { + for (int j = 0; j < x.rows(); ++j) { + if (i == j) { + expected_H.set(i, j, -y.value()); + } else { + expected_H.set( + i, j, y.value() / (Math.tan(x.get(i).value()) * Math.tan(x.get(j).value()))); + } + } + } + assertEquals(expected_H, H.get().value(), 1e-15); + assertEquals(expected_H, H.value(), 1e-15); + + H.close(); + g.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSumOfSquaredResiduals() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(5); + for (int i = 0; i < 5; ++i) { + x.get(i).setValue(1); + } + + // y = sum(diff(x).^2) + var temp = x.block(0, 0, 4, 1).minus(x.block(1, 0, 4, 1)).cwiseMap(a -> pow(a, 2)); + var y = temp.stream().reduce(new Variable(0.0), (a, b) -> a.plus(b)); + var gradient = new Gradient(y, x); + var g = gradient.value(); + + assertEquals(0.0, y.value()); + assertEquals(g.get(0, 0), 2 * x.get(0).value() - 2 * x.get(1).value()); + assertEquals( + g.get(1, 0), -2 * x.get(0).value() + 4 * x.get(1).value() - 2 * x.get(2).value()); + assertEquals( + g.get(2, 0), -2 * x.get(1).value() + 4 * x.get(2).value() - 2 * x.get(3).value()); + assertEquals( + g.get(3, 0), -2 * x.get(2).value() + 4 * x.get(3).value() - 2 * x.get(4).value()); + assertEquals(g.get(4, 0), -2 * x.get(3).value() + 2 * x.get(4).value()); + + var H = new Hessian(y, x); + + var expected_H = + new SimpleMatrix( + new double[][] { + {2.0, -2.0, 0.0, 0.0, 0.0}, + {-2.0, 4.0, -2.0, 0.0, 0.0}, + {0.0, -2.0, 4.0, -2.0, 0.0}, + {0.0, 0.0, -2.0, 4.0, -2.0}, + {0.0, 0.0, 0.0, -2.0, 2.0} + }); + assertEquals(expected_H, H.get().value()); + assertEquals(expected_H, H.value()); + + H.close(); + gradient.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSumOfSquares() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var r = new VariableMatrix(4); + r.setValue(new double[][] {{25.0}, {10.0}, {5.0}, {0.0}}); + + var x = new VariableMatrix(4); + for (int i = 0; i < 4; ++i) { + x.get(i).setValue(0.0); + } + + var J = new Variable(0.0); + for (int i = 0; i < 4; ++i) { + J = J.plus(r.get(i).minus(x.get(i)).times(r.get(i).minus(x.get(i)))); + } + + var H = new Hessian(J, x); + + var expected_H = SimpleMatrix.diag(2.0, 2.0, 2.0, 2.0); + assertEquals(expected_H, H.get().value()); + assertEquals(expected_H, H.value()); + + H.close(); + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNestedPowers() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + final var x0 = 3.0; + + var x = new Variable(); + x.setValue(x0); + + var y = pow(pow(x, 2), 2); + + var jacobian = new Jacobian(y, x); + var J = jacobian.value(); + assertEquals(4 * x0 * x0 * x0, J.get(0, 0), 1e-12); + + var hessian = new Hessian(y, x); + var H = hessian.value(); + assertEquals(12 * x0 * x0, H.get(0, 0), 1e-12); + + hessian.close(); + jacobian.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testRosenbrock() { + // z = (1 − x)² + 100(y − x²)² + // = 100(−x² + y)² + (−x + 1)² + // + // ∂z/∂x = 200(−x² + y)⋅−2x + 2(−x + 1)⋅−1 + // = −400x(−x² + y) − 2(−x + 1) + // = 400x³ − 400xy + 2x − 2 + // + // ∂z/∂y = 200(−x² + y) + // + // ∂²z/∂x² = 1200x² − 400y + 2 + // ∂²z/∂xy = −400x + // ∂²z/∂y² = 200 + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var input = new VariableMatrix(2); + var x = input.get(0); + var y = input.get(1); + var hessian = + new Hessian( + pow(new Variable(1).minus(x), 2).plus(pow(y.minus(pow(x, 2)), 2).times(100)), input); + + for (var x0 : range(-2.5, 2.5, 0.1)) { + for (var y0 : range(-2.5, 2.5, 0.1)) { + x.setValue(x0); + y.setValue(y0); + + var H = hessian.value(); + assertEquals(1200 * x0 * x0 - 400 * y0 + 2, H.get(0, 0), 1e-11); + assertEquals(-400 * x0, H.get(0, 1), 1e-15); + assertEquals(-400 * x0, H.get(1, 0), 1e-15); + assertEquals(200, H.get(1, 1)); + } + } + + hessian.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testEdgePushingWangExample1() { + // See example 1 of [1] + // + // [1] Wang, M., et al. "Capitalizing on live variables: new algorithms for + // efficient Hessian computation via automatic differentiation", 2016. + // https://sci-hub.st/10.1007/s12532-016-0100-3 + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(2); + x.get(0).setValue(3); + x.get(1).setValue(4); + + // y = (x₀sin(x₁)) x₀ + var y = (x.get(0).times(sin(x.get(1)))).times(x.get(0)); + + // dy/dx = [2x₀sin(x₁) x₀²cos(x₁)] + // dy/dx = [ 6sin(4) 9cos(4) ] + var J = new Jacobian(y, x); + var expected_J = + new SimpleMatrix(new double[][] {{6.0 * Math.sin(4.0), 9.0 * Math.cos(4.0)}}); + assertEquals(expected_J, J.get().value(), 1e-15); + assertEquals(expected_J, J.value(), 1e-15); + + // [ 2sin(x₁) 2x₀cos(x₁)] + // d²y/dx² = [2x₀cos(x₁) −x₀²sin(x₁)] + // + // [2sin(4) 6cos(4)] + // d²y/dx² = [6cos(4) −9sin(4)] + var H = new Hessian(y, x); + var expected_H = + new SimpleMatrix( + new double[][] { + {2.0 * Math.sin(4.0), 6.0 * Math.cos(4.0)}, + {6.0 * Math.cos(4.0), -9.0 * Math.sin(4.0)} + }); + assertEquals(expected_H, H.get().value(), 1e-15); + assertEquals(expected_H, H.value(), 1e-15); + + H.close(); + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testEdgePushingPetroFigure1() { + // See figure 1 of [1] + // + // [1] Petro, C. G., et al. "On efficient Hessian computation using the edge + // pushing algorithm in Julia", 2017. + // https://mlubin.github.io/pdf/edge_pushing_julia.pdf + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // y = p₁ log(x₁x₂) + var p_1 = new Variable(2.0); + var x = new VariableMatrix(2); + x.get(0).setValue(2.0); + x.get(1).setValue(3.0); + var y = p_1.times(log(x.get(0).times(x.get(1)))); + + // d²y/dx² = [−p₁/x₁² 0 ] + // [ 0 −p₁/x₂²] + var H = new Hessian(y, x); + var expected_H = + new SimpleMatrix( + new double[][] { + {-p_1.value() / (x.get(0).value() * x.get(0).value()), 0.0}, + {0.0, -p_1.value() / (x.get(1).value() * x.get(1).value())} + }); + assertEquals(expected_H, H.get().value(), 1e-15); + assertEquals(expected_H, H.value(), 1e-15); + + H.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testVariableReuse() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + Variable y; + var x = new VariableMatrix(1); + + // y = x³ + x.get(0).setValue(1); + y = x.get(0).times(x.get(0)).times(x.get(0)); + + var hessian = new Hessian(y, x); + + // d²y/dx² = 6x + // H = 6 + var H = hessian.value(); + + assertEquals(1, H.getNumRows()); + assertEquals(1, H.getNumCols()); + assertEquals(6.0, H.get(0, 0)); + + x.get(0).setValue(2); + // d²y/dx² = 6x + // H = 12 + H = hessian.value(); + + assertEquals(1, H.getNumRows()); + assertEquals(1, H.getNumCols()); + assertEquals(12.0, H.get(0, 0)); + + hessian.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/JacobianJNITest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/JacobianJNITest.java new file mode 100644 index 0000000000..0eea269354 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/JacobianJNITest.java @@ -0,0 +1,16 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.Test; + +public class JacobianJNITest { + @Test + public void testLink() { + assertDoesNotThrow(JacobianJNI::forceLoad); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/JacobianTest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/JacobianTest.java new file mode 100644 index 0000000000..e92465aeae --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/JacobianTest.java @@ -0,0 +1,266 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.MatrixAssertions.assertEquals; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; + +class JacobianTest { + @Test + void testYEqualsX() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(3); + for (int i = 0; i < 3; ++i) { + x.get(i).setValue(i + 1); + } + + // y = x + // + // [1 0 0] + // dy/dx = [0 1 0] + // [0 0 1] + var y = x; + var J = new Jacobian(y, x); + + var expected_J = + new SimpleMatrix(new double[][] {{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}}); + assertEquals(expected_J, J.get().value()); + assertEquals(expected_J, J.value()); + + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testYEquals3X() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(3); + for (int i = 0; i < 3; ++i) { + x.get(i).setValue(i + 1); + } + + // y = 3x + // + // [3 0 0] + // dy/dx = [0 3 0] + // [0 0 3] + var y = x.times(3); + var J = new Jacobian(y, x); + + var expected_J = + new SimpleMatrix(new double[][] {{3.0, 0.0, 0.0}, {0.0, 3.0, 0.0}, {0.0, 0.0, 3.0}}); + assertEquals(expected_J, J.get().value()); + assertEquals(expected_J, J.value()); + + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testProducts() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(3); + for (int i = 0; i < 3; ++i) { + x.get(i).setValue(i + 1); + } + + // [x₁x₂] + // y = [x₂x₃] + // [x₁x₃] + // + // [x₂ x₁ 0 ] + // dy/dx = [0 x₃ x₂] + // [x₃ 0 x₁] + // + // [2 1 0] + // dy/dx = [0 3 2] + // [3 0 1] + var y = new VariableMatrix(3); + y.set(0, x.get(0).times(x.get(1))); + y.set(1, x.get(1).times(x.get(2))); + y.set(2, x.get(0).times(x.get(2))); + var J = new Jacobian(y, x); + + var expected_J = + new SimpleMatrix(new double[][] {{2.0, 1.0, 0.0}, {0.0, 3.0, 2.0}, {3.0, 0.0, 1.0}}); + assertEquals(expected_J, J.get().value()); + assertEquals(expected_J, J.value()); + + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNestedProducts() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(1); + x.get(0).setValue(3); + assertEquals(3.0, x.value(0)); + + // [ 5x] [15] + // y = [ 7x] = [21] + // [11x] [33] + var y = new VariableMatrix(3); + y.set(0, x.get(0).times(5)); + y.set(1, x.get(0).times(7)); + y.set(2, x.get(0).times(11)); + assertEquals(15.0, y.value(0)); + assertEquals(21.0, y.value(1)); + assertEquals(33.0, y.value(2)); + + // [y₁y₂] [15⋅21] [315] + // z = [y₂y₃] = [21⋅33] = [693] + // [y₁y₃] [15⋅33] [495] + var z = new VariableMatrix(3); + z.set(0, y.get(0).times(y.get(1))); + z.set(1, y.get(1).times(y.get(2))); + z.set(2, y.get(0).times(y.get(2))); + assertEquals(315.0, z.value(0)); + assertEquals(693.0, z.value(1)); + assertEquals(495.0, z.value(2)); + + // [ 5x] + // y = [ 7x] + // [11x] + // + // [ 5] + // dy/dx = [ 7] + // [11] + var J = new Jacobian(y, x); + assertEquals(5.0, J.get().value(0, 0)); + assertEquals(7.0, J.get().value(1, 0)); + assertEquals(11.0, J.get().value(2, 0)); + assertEquals(5.0, J.value().get(0, 0)); + assertEquals(7.0, J.value().get(1, 0)); + assertEquals(11.0, J.value().get(2, 0)); + + // [y₁y₂] + // z = [y₂y₃] + // [y₁y₃] + // + // [y₂ y₁ 0 ] [21 15 0] + // dz/dy = [0 y₃ y₂] = [ 0 33 21] + // [y₃ 0 y₁] [33 0 15] + J.close(); + J = new Jacobian(z, y); + var expected_J = + new SimpleMatrix( + new double[][] {{21.0, 15.0, 0.0}, {0.0, 33.0, 21.0}, {33.0, 0.0, 15.0}}); + assertEquals(expected_J, J.get().value()); + assertEquals(expected_J, J.value()); + + // [y₁y₂] [5x⋅ 7x] [35x²] + // z = [y₂y₃] = [7x⋅11x] = [77x²] + // [y₁y₃] [5x⋅11x] [55x²] + // + // [ 70x] [210] + // dz/dx = [154x] = [462] + // [110x] = [330] + J.close(); + J = new Jacobian(z, x); + expected_J = new SimpleMatrix(new double[][] {{210.0}, {462.0}, {330.0}}); + assertEquals(expected_J, J.get().value()); + assertEquals(expected_J, J.value()); + + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNonSquare() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(3); + for (int i = 0; i < 3; ++i) { + x.get(i).setValue(i + 1); + } + + // y = [x₁ + 3x₂ − 5x₃] + // + // dy/dx = [1 3 −5] + var y = new VariableMatrix(1); + y.set(0, x.get(0).plus(x.get(1).times(3)).minus(x.get(2).times(5))); + var J = new Jacobian(y, x); + + var expected_J = new SimpleMatrix(new double[][] {{1.0, 3.0, -5.0}}); + + var J_get_value = J.get().value(); + assertEquals(1, J_get_value.getNumRows()); + assertEquals(3, J_get_value.getNumCols()); + assertEquals(expected_J, J_get_value); + + var J_value = J.value(); + assertEquals(1, J_value.getNumRows()); + assertEquals(3, J_value.getNumCols()); + assertEquals(expected_J, J_value); + + J.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testVariableReuse() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var x = new VariableMatrix(2); + for (int i = 0; i < 2; ++i) { + x.get(i).setValue(i + 1); + } + + // y = [x₁x₂] + var y = new VariableMatrix(1); + y.set(0, x.get(0).times(x.get(1))); + + var jacobian = new Jacobian(y, x); + + // dy/dx = [x₂ x₁] + // dy/dx = [2 1] + var J = jacobian.value(); + + assertEquals(1, J.getNumRows()); + assertEquals(2, J.getNumCols()); + assertEquals(2.0, J.get(0, 0)); + assertEquals(1.0, J.get(0, 1)); + + x.get(0).setValue(2); + x.get(1).setValue(1); + // dy/dx = [x₂ x₁] + // dy/dx = [1 2] + J = jacobian.value(); + + assertEquals(1, J.getNumRows()); + assertEquals(2, J.getNumCols()); + assertEquals(1.0, J.get(0, 0)); + assertEquals(2.0, J.get(0, 1)); + + jacobian.close(); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/SliceTest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/SliceTest.java new file mode 100644 index 0000000000..3ca62a4d5e --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/SliceTest.java @@ -0,0 +1,481 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class SliceTest { + @Test + void testDefaultConstructor() { + var slice = new Slice(); + assertEquals(0, slice.start); + assertEquals(0, slice.stop); + assertEquals(1, slice.step); + + assertEquals(0, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(0, slice.stop); + assertEquals(1, slice.step); + } + + @Test + void testOneArgConstructor() { + // none + { + var slice = new Slice(Slice.__); + assertEquals(0, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(1, slice.step); + + assertEquals(3, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(3, slice.stop); + assertEquals(1, slice.step); + } + + // + + { + var slice = new Slice(1); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + } + + // -1 + { + var slice = new Slice(-1); + assertEquals(-1, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(3, slice.stop); + assertEquals(1, slice.step); + } + + // -2 + { + var slice = new Slice(-2); + assertEquals(-2, slice.start); + assertEquals(-1, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + } + } + + @Test + void testTwoArgConstructor() { + // none, none + { + var slice = new Slice(Slice.__, Slice.__); + assertEquals(0, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(1, slice.step); + + assertEquals(3, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(3, slice.stop); + assertEquals(1, slice.step); + } + + // none, + + { + var slice = new Slice(Slice.__, 1); + assertEquals(0, slice.start); + assertEquals(1, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(1, slice.stop); + assertEquals(1, slice.step); + } + + // none, - + { + var slice = new Slice(Slice.__, -1); + assertEquals(0, slice.start); + assertEquals(-1, slice.stop); + assertEquals(1, slice.step); + + assertEquals(2, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + } + + // +, none + { + var slice = new Slice(1, Slice.__); + assertEquals(1, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(1, slice.step); + + assertEquals(2, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(3, slice.stop); + assertEquals(1, slice.step); + } + + // -, none + { + var slice = new Slice(-1, Slice.__); + assertEquals(-1, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(3, slice.stop); + assertEquals(1, slice.step); + } + + // +, + + { + var slice = new Slice(1, 2); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + } + + // +, - + { + var slice = new Slice(1, -1); + assertEquals(1, slice.start); + assertEquals(-1, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + } + + // -, - + { + var slice = new Slice(-2, -1); + assertEquals(-2, slice.start); + assertEquals(-1, slice.stop); + assertEquals(1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(1, slice.step); + } + } + + @Test + void testThreeArgConstructor() { + // none, none, none + { + var slice = new Slice(Slice.__, Slice.__, Slice.__); + assertEquals(0, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(1, slice.step); + + assertEquals(3, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(3, slice.stop); + assertEquals(1, slice.step); + } + + // none, none, + + { + var slice = new Slice(Slice.__, Slice.__, 2); + assertEquals(0, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(2, slice.step); + + assertEquals(2, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(3, slice.stop); + assertEquals(2, slice.step); + } + + // none, none, - + { + var slice = new Slice(Slice.__, Slice.__, -2); + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(-2, slice.step); + + assertEquals(2, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(-1, slice.stop); + assertEquals(-2, slice.step); + } + + // none, +, + + { + var slice = new Slice(Slice.__, 1, 2); + assertEquals(0, slice.start); + assertEquals(1, slice.stop); + assertEquals(2, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(0, slice.start); + assertEquals(1, slice.stop); + assertEquals(2, slice.step); + } + + // none, +, - + { + var slice = new Slice(Slice.__, 1, -2); + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(1, slice.stop); + assertEquals(-2, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(1, slice.stop); + assertEquals(-2, slice.step); + } + + // none, -, - + { + var slice = new Slice(Slice.__, -2, -1); + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(-2, slice.stop); + assertEquals(-1, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(1, slice.stop); + assertEquals(-1, slice.step); + } + + // +, none, + + { + var slice = new Slice(1, Slice.__, 2); + assertEquals(1, slice.start); + assertEquals(Integer.MAX_VALUE, slice.stop); + assertEquals(2, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(3, slice.stop); + assertEquals(2, slice.step); + } + + // +, none, - + { + var slice = new Slice(1, Slice.__, -2); + assertEquals(1, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(-2, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(-1, slice.stop); + assertEquals(-2, slice.step); + } + + // +, +, + + { + var slice = new Slice(1, 2, 2); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(2, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(2, slice.step); + } + + // +, +, - + { + var slice = new Slice(2, 1, -2); + assertEquals(2, slice.start); + assertEquals(1, slice.stop); + assertEquals(-2, slice.step); + + assertEquals(1, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(1, slice.stop); + assertEquals(-2, slice.step); + } + } + + @Test + void testEmptySlices() { + // +, +, + + { + var slice = new Slice(2, 1, 2); + assertEquals(2, slice.start); + assertEquals(1, slice.stop); + assertEquals(2, slice.step); + + assertEquals(0, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(1, slice.stop); + assertEquals(2, slice.step); + } + + // +, +, - + { + var slice = new Slice(1, 2, -2); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(-2, slice.step); + + assertEquals(0, slice.adjust(3)); + assertEquals(1, slice.start); + assertEquals(2, slice.stop); + assertEquals(-2, slice.step); + } + + // +, -, - + { + var slice = new Slice(3, -1, -2); + assertEquals(3, slice.start); + assertEquals(-1, slice.stop); + assertEquals(-2, slice.step); + + assertEquals(0, slice.adjust(3)); + assertEquals(2, slice.start); + assertEquals(2, slice.stop); + assertEquals(-2, slice.step); + } + } + + @Test + void testStepUBGuard() { + { + // none, none, INT_MIN + var slice = new Slice(Slice.__, Slice.__, Integer.MIN_VALUE); + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // none, +, INT_MIN + var slice = new Slice(Slice.__, 2, Integer.MIN_VALUE); + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(2, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(2, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // none, -, INT_MIN + var slice = new Slice(Slice.__, -2, Integer.MIN_VALUE); + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(-2, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(Integer.MAX_VALUE, slice.start); + assertEquals(-2, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // +, none, INT_MIN + var slice = new Slice(1, Slice.__, Integer.MIN_VALUE); + assertEquals(1, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(1, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // -, none, INT_MIN + var slice = new Slice(-2, Slice.__, Integer.MIN_VALUE); + assertEquals(-2, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(-2, slice.start); + assertEquals(Integer.MIN_VALUE, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // +, +, INT_MIN + var slice = new Slice(1000, 0, Integer.MIN_VALUE); + assertEquals(1000, slice.start); + assertEquals(0, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(1000, slice.start); + assertEquals(0, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // +, -, INT_MIN + var slice = new Slice(1000, -2, Integer.MIN_VALUE); + assertEquals(1000, slice.start); + assertEquals(-2, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(1000, slice.start); + assertEquals(-2, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // -, +, INT_MIN + var slice = new Slice(-1, 2, Integer.MIN_VALUE); + assertEquals(-1, slice.start); + assertEquals(2, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(-1, slice.start); + assertEquals(2, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + + { + // -, -, INT_MIN + var slice = new Slice(-1, -2, Integer.MIN_VALUE); + assertEquals(-1, slice.start); + assertEquals(-2, slice.stop); + assertEquals(Integer.MIN_VALUE + 1, slice.step); + + slice.step = -slice.step; + assertEquals(-1, slice.start); + assertEquals(-2, slice.stop); + assertEquals(Integer.MAX_VALUE, slice.step); + } + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/VariableJNITest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableJNITest.java new file mode 100644 index 0000000000..34776ed4ff --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableJNITest.java @@ -0,0 +1,16 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.Test; + +public class VariableJNITest { + @Test + public void testLink() { + assertDoesNotThrow(VariableJNI::forceLoad); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixJNITest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixJNITest.java new file mode 100644 index 0000000000..84b04a49e2 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixJNITest.java @@ -0,0 +1,16 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.Test; + +public class VariableMatrixJNITest { + @Test + public void testLink() { + assertDoesNotThrow(VariableMatrixJNI::forceLoad); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixTest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixTest.java new file mode 100644 index 0000000000..1c45eb9fc4 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableMatrixTest.java @@ -0,0 +1,600 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.math.MatrixAssertions.assertEquals; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; + +class VariableMatrixTest { + @Test + void testConstructFromDoubleArray() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var mat = new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})) { + var expected = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); + assertEquals(expected, mat.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testConstructFromSimpleMatrix() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var mat = + new VariableMatrix(new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}))) { + var expected = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); + assertEquals(expected, mat.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testAssignmentToDefault() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var mat = new VariableMatrix(2, 2); + + assertEquals(2, mat.rows()); + assertEquals(2, mat.cols()); + assertEquals(0.0, mat.get(0, 0).value()); + assertEquals(0.0, mat.get(0, 1).value()); + assertEquals(0.0, mat.get(1, 0).value()); + assertEquals(0.0, mat.get(1, 1).value()); + + mat.set(0, 0, 1.0); + mat.set(0, 1, 2.0); + mat.set(1, 0, 3.0); + mat.set(1, 1, 4.0); + + assertEquals(1.0, mat.get(0, 0).value()); + assertEquals(2.0, mat.get(0, 1).value()); + assertEquals(3.0, mat.get(1, 0).value()); + assertEquals(4.0, mat.get(1, 1).value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testAssignmentAliasing() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var A = new VariableMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}}); + var B = new VariableMatrix(new double[][] {{5.0, 6.0}, {7.0, 8.0}}); + + // A and B initially contain different values + var expected_A = new SimpleMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}}); + var expected_B = new SimpleMatrix(new double[][] {{5.0, 6.0}, {7.0, 8.0}}); + assertEquals(expected_A, A.value()); + assertEquals(expected_B, B.value()); + + // Make A point to B's storage + A.set(B); + assertEquals(expected_B, A.value()); + assertEquals(expected_B, B.value()); + + // Changes to B should be reflected in A + B.get(0, 0).setValue(2.0); + expected_B.set(0, 0, 2.0); + assertEquals(expected_B, A.value()); + assertEquals(expected_B, B.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testBlockMemberFunction() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var A = + new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}); + + // Block assignment + A.block(1, 1, 2, 2).set(new double[][] {{10.0, 11.0}, {12.0, 13.0}}); + + var expected1 = + new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 10.0, 11.0}, {7.0, 12.0, 13.0}}); + assertEquals(expected1, A.value()); + + // Block-of-block assignment + A.block(1, 1, 2, 2).block(1, 1, 1, 1).set(14.0); + + var expected2 = + new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 10.0, 11.0}, {7.0, 12.0, 14.0}}); + assertEquals(expected2, A.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSlicing() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var mat = + new VariableMatrix( + new double[][] {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}); + assertEquals(4, mat.rows()); + assertEquals(4, mat.cols()); + + // Single-arg index operator on full matrix + for (int i = 0; i < mat.rows() * mat.cols(); ++i) { + assertEquals(i + 1, mat.get(i).value()); + } + + // Slice from start + { + var s = mat.get(new Slice(1, Slice.__), new Slice(2, Slice.__)); + assertEquals(3, s.rows()); + assertEquals(2, s.cols()); + // Single-arg index operator on forward slice + assertEquals(7.0, s.get(0).value()); + assertEquals(8.0, s.get(1).value()); + assertEquals(11.0, s.get(2).value()); + assertEquals(12.0, s.get(3).value()); + assertEquals(15.0, s.get(4).value()); + assertEquals(16.0, s.get(5).value()); + // Double-arg index operator on forward slice + assertEquals(7.0, s.get(0, 0).value()); + assertEquals(8.0, s.get(0, 1).value()); + assertEquals(11.0, s.get(1, 0).value()); + assertEquals(12.0, s.get(1, 1).value()); + assertEquals(15.0, s.get(2, 0).value()); + assertEquals(16.0, s.get(2, 1).value()); + } + + // Slice from end + { + var s = mat.get(new Slice(-1, Slice.__), new Slice(-2, Slice.__)); + assertEquals(1, s.rows()); + assertEquals(2, s.cols()); + // Single-arg index operator on reverse slice + assertEquals(15.0, s.get(0).value()); + assertEquals(16.0, s.get(1).value()); + // Double-arg index operator on reverse slice + assertEquals(15.0, s.get(0, 0).value()); + assertEquals(16.0, s.get(0, 1).value()); + } + + // Slice from start with step of 2 + { + var s = mat.get(Slice.__, new Slice(Slice.__, Slice.__, 2)); + assertEquals(4, s.rows()); + assertEquals(2, s.cols()); + assertEquals( + new SimpleMatrix(new double[][] {{1.0, 3.0}, {5.0, 7.0}, {9.0, 11.0}, {13.0, 15.0}}), + s.value()); + } + + // Slice from end with negative step for row and column + { + var s = mat.get(new Slice(Slice.__, Slice.__, -1), new Slice(Slice.__, Slice.__, -2)); + assertEquals(4, s.rows()); + assertEquals(2, s.cols()); + assertEquals( + new SimpleMatrix(new double[][] {{16.0, 14.0}, {12.0, 10.0}, {8.0, 6.0}, {4.0, 2.0}}), + s.value()); + } + + // Slice from start and column -1 + { + var s = mat.get(new Slice(1, Slice.__), -1); + assertEquals(3, s.rows()); + assertEquals(1, s.cols()); + assertEquals(new SimpleMatrix(new double[][] {{8.0}, {12.0}, {16.0}}), s.value()); + } + + // Slice from start and column -2 + { + var s = mat.get(new Slice(1, Slice.__), -2); + assertEquals(3, s.rows()); + assertEquals(1, s.cols()); + assertEquals(new SimpleMatrix(new double[][] {{7.0}, {11.0}, {15.0}}), s.value()); + } + + // Block assignment + { + var s = mat.get(new Slice(Slice.__, Slice.__, 2), new Slice(Slice.__, Slice.__, 2)); + assertEquals(2, s.rows()); + assertEquals(2, s.cols()); + s.setValue(new double[][] {{17.0, 18.0}, {19.0, 20.0}}); + assertEquals( + new SimpleMatrix( + new double[][] { + {17.0, 2.0, 18.0, 4.0}, + {5.0, 6.0, 7.0, 8.0}, + {19.0, 10.0, 20.0, 12.0}, + {13.0, 14.0, 15.0, 16.0} + }), + mat.value()); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSubslicing() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // Block-of-block assignment (row skip forward) + { + var mat = new VariableMatrix(5, 5); + var s = + mat.get(new Slice(Slice.__, Slice.__, 2), new Slice(Slice.__, Slice.__, 1)) + .get(new Slice(1, 3), new Slice(1, 4)); + assertEquals(2, s.rows()); + assertEquals(3, s.cols()); + s.setValue(new double[][] {{1, 2, 3}, {4, 5, 6}}); + + assertEquals( + new SimpleMatrix( + new double[][] { + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 1, 2, 3, 0}, + {0, 0, 0, 0, 0}, + {0, 4, 5, 6, 0} + }), + mat.value()); + } + + // Block-of-block assignment (row skip backward) + { + var mat = new VariableMatrix(5, 5); + var s = + mat.get(new Slice(Slice.__, Slice.__, -2), new Slice(Slice.__, Slice.__, -1)) + .get(new Slice(1, 3), new Slice(1, 4)); + assertEquals(2, s.rows()); + assertEquals(3, s.cols()); + s.setValue(new double[][] {{1, 2, 3}, {4, 5, 6}}); + + assertEquals( + new SimpleMatrix( + new double[][] { + {0, 6, 5, 4, 0}, + {0, 0, 0, 0, 0}, + {0, 3, 2, 1, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0} + }), + mat.value()); + } + + // Block-of-block assignment (column skip forward) + { + var mat = new VariableMatrix(5, 5); + var s = + mat.get(new Slice(Slice.__, Slice.__, 1), new Slice(Slice.__, Slice.__, 2)) + .get(new Slice(1, 4), new Slice(1, 3)); + assertEquals(3, s.rows()); + assertEquals(2, s.cols()); + s.setValue(new double[][] {{1, 2}, {3, 4}, {5, 6}}); + + assertEquals( + new SimpleMatrix( + new double[][] { + {0, 0, 0, 0, 0}, + {0, 0, 1, 0, 2}, + {0, 0, 3, 0, 4}, + {0, 0, 5, 0, 6}, + {0, 0, 0, 0, 0} + }), + mat.value()); + } + + // Block-of-block assignment (column skip backward) + { + var mat = new VariableMatrix(5, 5); + var s = + mat.get(new Slice(Slice.__, Slice.__, -1), new Slice(Slice.__, Slice.__, -2)) + .get(new Slice(1, 4), new Slice(1, 3)); + assertEquals(3, s.rows()); + assertEquals(2, s.cols()); + s.setValue(new double[][] {{1, 2}, {3, 4}, {5, 6}}); + + assertEquals( + new SimpleMatrix( + new double[][] { + {0, 0, 0, 0, 0}, + {6, 0, 5, 0, 0}, + {4, 0, 3, 0, 0}, + {2, 0, 1, 0, 0}, + {0, 0, 0, 0, 0} + }), + mat.value()); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @SuppressWarnings("PMD.UnusedLocalVariable") + @Test + void testIterators() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + final var A = + new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}); + final var sub_A = A.block(2, 1, 1, 2); + + int distance = 0; + for (var elem : A) { + ++distance; + } + assertEquals(9, distance); + + distance = 0; + for (var elem : sub_A) { + ++distance; + } + assertEquals(2, distance); + + int i = 1; + for (var elem : A) { + assertEquals(i, elem.value()); + ++i; + } + + i = 8; + for (var elem : sub_A) { + assertEquals(i, elem.value()); + ++i; + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testValue() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var A = + new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}); + var expected = + new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}); + + // Full matrix + assertEquals(expected, A.value()); + assertEquals(4.0, A.value(3)); + assertEquals(2.0, A.T().value(3)); + + // Block + assertEquals(expected.extractMatrix(1, 3, 1, 3), A.block(1, 1, 2, 2).value()); + assertEquals(8.0, A.block(1, 1, 2, 2).value(2)); + assertEquals(6.0, A.T().block(1, 1, 2, 2).value(2)); + + // Slice + assertEquals( + expected.extractMatrix(1, 3, 1, 3), A.get(new Slice(1, 3), new Slice(1, 3)).value()); + assertEquals(8.0, A.get(new Slice(1, 3), new Slice(1, 3)).value(2)); + assertEquals(6.0, A.get(new Slice(1, 3), new Slice(1, 3)).T().value(2)); + + // Block-of-block + assertEquals( + expected.extractMatrix(1, 3, 1, 3).extractMatrix(0, 2, 1, 2), + A.block(1, 1, 2, 2).block(0, 1, 2, 1).value()); + assertEquals(9.0, A.block(1, 1, 2, 2).block(0, 1, 2, 1).value(1)); + assertEquals(9.0, A.block(1, 1, 2, 2).T().block(0, 1, 2, 1).value(1)); + + // Slice-of-slice + assertEquals( + expected.extractMatrix(1, 3, 1, 3).extractMatrix(0, 2, 1, 2), + A.get(new Slice(1, 3), new Slice(1, 3)).get(Slice.__, new Slice(1, Slice.__)).value()); + assertEquals( + 9.0, + A.get(new Slice(1, 3), new Slice(1, 3)).get(Slice.__, new Slice(1, Slice.__)).value(1)); + assertEquals( + 9.0, + A.get(new Slice(1, 3), new Slice(1, 3)) + .T() + .get(Slice.__, new Slice(1, Slice.__)) + .value(1)); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testCwiseMap() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // VariableMatrix cwiseMap + var A = new VariableMatrix(new double[][] {{-2.0, -3.0, -4.0}, {-5.0, -6.0, -7.0}}); + + var result1 = A.cwiseMap(Variable::abs); + var expected1 = new SimpleMatrix(new double[][] {{2.0, 3.0, 4.0}, {5.0, 6.0, 7.0}}); + + // Don't modify original matrix + assertEquals(expected1.scale(-1.0), A.value()); + + assertEquals(expected1, result1.value()); + + // VariableBlock cwiseMap + var sub_A = A.block(0, 0, 2, 2); + + var result2 = sub_A.cwiseMap(Variable::abs); + var expected2 = new SimpleMatrix(new double[][] {{2.0, 3.0}, {5.0, 6.0}}); + + // Don't modify original matrix + assertEquals(expected1.scale(-1.0), A.value()); + assertEquals(expected2.scale(-1.0), sub_A.value()); + + assertEquals(expected2, result2.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testZeroStaticFunction() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var A = VariableMatrix.zero(2, 3)) { + for (var elem : A) { + assertEquals(0.0, elem.value()); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testOneStaticFunction() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var A = VariableMatrix.one(2, 3)) { + for (var elem : A) { + assertEquals(1.0, elem.value()); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testConstantStaticFunction() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var A = VariableMatrix.constant(2, 3, 2.0)) { + for (var elem : A) { + assertEquals(2.0, elem.value()); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testCwiseReduce() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var A = new VariableMatrix(new double[][] {{2.0, 3.0, 4.0}, {5.0, 6.0, 7.0}}); + var B = new VariableMatrix(new double[][] {{8.0, 9.0, 10.0}, {11.0, 12.0, 13.0}}); + var result = VariableMatrix.cwiseReduce(A, B, (Variable x, Variable y) -> x.times(y)); + + var expected = new SimpleMatrix(new double[][] {{16.0, 27.0, 40.0}, {55.0, 72.0, 91.0}}); + assertEquals(expected, result.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testBlockFreeFunction() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + var A = new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); + var B = new VariableMatrix(new double[][] {{7.0}, {8.0}}); + + var mat1 = VariableMatrix.block(new VariableMatrix[][] {{A, B}}); + var expected1 = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0, 7.0}, {4.0, 5.0, 6.0, 8.0}}); + assertEquals(2, mat1.rows()); + assertEquals(4, mat1.cols()); + assertEquals(expected1, mat1.value()); + + var C = new VariableMatrix(new double[][] {{9.0, 10.0, 11.0, 12.0}}); + + var mat2 = VariableMatrix.block(new VariableMatrix[][] {{A, B}, {C}}); + var expected2 = + new SimpleMatrix( + new double[][] {{1.0, 2.0, 3.0, 7.0}, {4.0, 5.0, 6.0, 8.0}, {9.0, 10.0, 11.0, 12.0}}); + assertEquals(3, mat2.rows()); + assertEquals(4, mat2.cols()); + assertEquals(expected2, mat2.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + private void checkSolve(VariableMatrix A, VariableMatrix B) { + try (var X = VariableMatrix.solve(A, B)) { + assertEquals(A.cols(), X.rows()); + assertEquals(B.cols(), X.cols()); + assertTrue(A.value().mult(X.value()).minus(B.value()).normF() < 1e-12); + } + } + + @Test + void testSolveFreeFunction() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // 1x1 special case + try (var pool = new VariablePool()) { + checkSolve( + new VariableMatrix(new double[][] {{2.0}}), new VariableMatrix(new double[][] {{5.0}})); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // 2x2 special case + try (var pool = new VariablePool()) { + checkSolve( + new VariableMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}}), + new VariableMatrix(new double[][] {{5.0}, {6.0}})); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // 3x3 special case + try (var pool = new VariablePool()) { + checkSolve( + new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {-4.0, -5.0, 6.0}, {7.0, 8.0, 9.0}}), + new VariableMatrix(new double[][] {{10.0}, {11.0}, {12.0}})); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // 4x4 special case + try (var pool = new VariablePool()) { + checkSolve( + new VariableMatrix( + new double[][] { + {1.0, 2.0, 3.0, -4.0}, + {-5.0, 6.0, 7.0, 8.0}, + {9.0, 10.0, 11.0, 12.0}, + {13.0, 14.0, 15.0, 16.0} + }), + new VariableMatrix(new double[][] {{17.0}, {18.0}, {19.0}, {20.0}})); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // 5x5 general case + try (var pool = new VariablePool()) { + checkSolve( + new VariableMatrix( + new double[][] { + {1.0, 2.0, 3.0, -4.0, 5.0}, + {-5.0, 6.0, 7.0, 8.0, 9.0}, + {9.0, 10.0, 11.0, 12.0, 13.0}, + {13.0, 14.0, 15.0, 16.0, 17.0}, + {17.0, 18.0, 19.0, 20.0, 21.0} + }), + new VariableMatrix(new double[][] {{21.0}, {22.0}, {23.0}, {24.0}, {25.0}})); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/autodiff/VariableTest.java b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableTest.java new file mode 100644 index 0000000000..ce4f51372b --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/autodiff/VariableTest.java @@ -0,0 +1,57 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class VariableTest { + @Test + void testDefaultConstructor() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var a = new Variable()) { + assertEquals(0.0, a.value()); + assertEquals(ExpressionType.LINEAR, a.type()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testConstantConstructor() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var pool = new VariablePool()) { + // float + var a = new Variable(1.0); + assertEquals(1, a.value()); + assertEquals(ExpressionType.CONSTANT, a.type()); + + // int + var b = new Variable(2); + assertEquals(2, b.value()); + assertEquals(ExpressionType.CONSTANT, b.type()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSetValue() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var a = new Variable()) { + a.setValue(1.0); + assertEquals(1.0, a.value()); + + a.setValue(2.0); + assertEquals(2.0, a.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/controller/ArmFeedforwardTest.java b/wpimath/src/test/java/org/wpilib/math/controller/ArmFeedforwardTest.java index 67b007e162..3a2df9ee5d 100644 --- a/wpimath/src/test/java/org/wpilib/math/controller/ArmFeedforwardTest.java +++ b/wpimath/src/test/java/org/wpilib/math/controller/ArmFeedforwardTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.function.BiFunction; import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.Variable; import org.wpilib.math.linalg.MatBuilder; import org.wpilib.math.linalg.Matrix; import org.wpilib.math.numbers.N1; @@ -95,6 +96,8 @@ class ArmFeedforwardTest { @Test void testCalculate() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + final double ks = 0.5; final double kv = 1.5; final double ka = 2; @@ -110,10 +113,14 @@ class ArmFeedforwardTest { calculateAndSimulate(armFF, ks, kv, ka, kg, Math.PI / 3, 1.0, 0.95, 0.020); calculateAndSimulate(armFF, ks, kv, ka, kg, -Math.PI / 3, 1.0, 1.05, 0.020); calculateAndSimulate(armFF, ks, kv, ka, kg, -Math.PI / 3, 1.0, 0.95, 0.020); + + assertEquals(0, Variable.totalNativeMemoryUsage()); } @Test void testCalculateIllConditionedModel() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + final double ks = 0.39671; final double kv = 2.7167; final double ka = 1e-2; @@ -129,10 +136,14 @@ class ArmFeedforwardTest { assertEquals( armFF.calculate(currentAngle, currentVelocity, nextVelocity), ks + kv * currentVelocity + ka * averageAccel + kg * Math.cos(currentAngle)); + + assertEquals(0, Variable.totalNativeMemoryUsage()); } @Test void testCalculateIllConditionedGradient() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + final double ks = 0.39671; final double kv = 2.7167; final double ka = 0.50799; @@ -140,6 +151,8 @@ class ArmFeedforwardTest { final ArmFeedforward armFF = new ArmFeedforward(ks, kg, kv, ka); calculateAndSimulate(armFF, ks, kv, ka, kg, 1.0, 0.02, 0.0, 0.02); + + assertEquals(0, Variable.totalNativeMemoryUsage()); } @Test diff --git a/wpimath/src/test/java/org/wpilib/math/geometry/Ellipse2dTest.java b/wpimath/src/test/java/org/wpilib/math/geometry/Ellipse2dTest.java index dbe027ad7e..355a415eb8 100644 --- a/wpimath/src/test/java/org/wpilib/math/geometry/Ellipse2dTest.java +++ b/wpimath/src/test/java/org/wpilib/math/geometry/Ellipse2dTest.java @@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.Variable; class Ellipse2dTest { private static final double kEpsilon = 1E-9; @@ -56,6 +57,8 @@ class Ellipse2dTest { @Test void testDistance() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + var center = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(270.0)); var ellipse = new Ellipse2d(center, 1.0, 2.0); @@ -70,10 +73,14 @@ class Ellipse2dTest { var point4 = new Translation2d(-1.0, 2.5); assertEquals(0.19210128384806818, ellipse.getDistance(point4), kEpsilon); + + assertEquals(0, Variable.totalNativeMemoryUsage()); } @Test void testNearest() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + var center = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(270.0)); var ellipse = new Ellipse2d(center, 1.0, 2.0); @@ -100,6 +107,8 @@ class Ellipse2dTest { assertAll( () -> assertEquals(-0.8512799937611617, nearestPoint4.getX(), kEpsilon), () -> assertEquals(2.378405333174535, nearestPoint4.getY(), kEpsilon)); + + assertEquals(0, Variable.totalNativeMemoryUsage()); } @Test diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/ArmOnElevatorProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/ArmOnElevatorProblemTest.java new file mode 100644 index 0000000000..94bb149abf --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/ArmOnElevatorProblemTest.java @@ -0,0 +1,121 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.autodiff.Variable.pow; +import static org.wpilib.math.optimization.Constraints.bounds; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.le; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.solver.ExitStatus; + +class ArmOnElevatorProblemTest { + @Test + void testArmOnElevatorProblem() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final int N = 800; + + final double ELEVATOR_START_HEIGHT = 1.0; // m + final double ELEVATOR_END_HEIGHT = 1.25; // m + final double ELEVATOR_MAX_VELOCITY = 1.0; // m/s + final double ELEVATOR_MAX_ACCELERATION = 2.0; // m/s² + + final double ARM_LENGTH = 1.0; // m + final double ARM_START_ANGLE = 0.0; // rad + final double ARM_END_ANGLE = Math.PI; // rad + final double ARM_MAX_VELOCITY = 2.0 * Math.PI; // rad/s + final double ARM_MAX_ACCELERATION = 4.0 * Math.PI; // rad/s² + + final double END_EFFECTOR_MAX_HEIGHT = 1.8; // m + + final double TOTAL_TIME = 4.0; + final double dt = TOTAL_TIME / N; + + try (var problem = new Problem()) { + var elevator = problem.decisionVariable(2, N + 1); + var elevator_accel = problem.decisionVariable(1, N); + + var arm = problem.decisionVariable(2, N + 1); + var arm_accel = problem.decisionVariable(1, N); + + for (int k = 0; k < N; ++k) { + // Elevator dynamics constraints + problem.subjectTo( + eq( + elevator.get(0, k + 1), + elevator + .get(0, k) + .plus(elevator.get(1, k).times(dt)) + .plus(elevator_accel.get(0, k).times(0.5 * dt * dt)))); + problem.subjectTo( + eq( + elevator.get(1, k + 1), + elevator.get(1, k).plus(elevator_accel.get(0, k).times(dt)))); + + // Arm dynamics constraints + problem.subjectTo( + eq( + arm.get(0, k + 1), + arm.get(0, k) + .plus(arm.get(1, k).times(dt)) + .plus(arm_accel.get(0, k).times(0.5 * dt * dt)))); + problem.subjectTo(eq(arm.get(1, k + 1), arm.get(1, k).plus(arm_accel.get(0, k).times(dt)))); + } + + // Elevator start and end conditions + problem.subjectTo( + eq(elevator.col(0), new VariableMatrix(new double[][] {{ELEVATOR_START_HEIGHT}, {0.0}}))); + problem.subjectTo( + eq(elevator.col(N), new VariableMatrix(new double[][] {{ELEVATOR_END_HEIGHT}, {0.0}}))); + + // Arm start and end conditions + problem.subjectTo( + eq(arm.col(0), new VariableMatrix(new double[][] {{ARM_START_ANGLE}, {0.0}}))); + problem.subjectTo( + eq(arm.col(N), new VariableMatrix(new double[][] {{ARM_END_ANGLE}, {0.0}}))); + + // Elevator velocity limits + problem.subjectTo(bounds(-ELEVATOR_MAX_VELOCITY, elevator.row(1), ELEVATOR_MAX_VELOCITY)); + + // Elevator acceleration limits + problem.subjectTo( + bounds(-ELEVATOR_MAX_ACCELERATION, elevator_accel, ELEVATOR_MAX_ACCELERATION)); + + // Arm velocity limits + problem.subjectTo(bounds(-ARM_MAX_VELOCITY, arm.row(1), ARM_MAX_VELOCITY)); + + // Arm acceleration limits + problem.subjectTo(bounds(-ARM_MAX_ACCELERATION, arm_accel, ARM_MAX_ACCELERATION)); + + // Height limit + var heights = elevator.row(0).plus(arm.row(0).cwiseMap(Variable::sin).times(ARM_LENGTH)); + problem.subjectTo(le(heights, END_EFFECTOR_MAX_HEIGHT)); + + // Cost function + var J = new Variable(0.0); + for (int k = 0; k < N + 1; ++k) { + J = + J.plus( + pow(new Variable(ELEVATOR_END_HEIGHT).minus(elevator.get(0, k)), 2) + .plus(pow(new Variable(ARM_END_ANGLE).minus(arm.get(0, k)), 2))); + } + problem.minimize(J); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleOCPTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleOCPTest.java new file mode 100644 index 0000000000..33ea4adda9 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleOCPTest.java @@ -0,0 +1,101 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.optimization.Constraints.bounds; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.ocp.DynamicsType; +import org.wpilib.math.optimization.ocp.TimestepMethod; +import org.wpilib.math.optimization.ocp.TranscriptionMethod; +import org.wpilib.math.optimization.solver.ExitStatus; +import org.wpilib.math.util.MathUtil; + +class CartPoleOCPTest { + @Test + void testCartPole() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final double TOTAL_TIME = 5.0; // s + final double dt = 0.05; // s + final int N = (int) (TOTAL_TIME / dt); + + final double u_max = 20.0; // N + final double d_max = 2.0; // m + + final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}}); + final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}}); + + try (var problem = + new OCP( + 4, + 1, + dt, + N, + CartPoleUtil::cartPoleDynamics, + DynamicsType.EXPLICIT_ODE, + TimestepMethod.VARIABLE_SINGLE, + TranscriptionMethod.DIRECT_COLLOCATION)) { + // x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ + var X = problem.X(); + + // Initial guess + for (int k = 0; k < N + 1; ++k) { + X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N)); + X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N)); + } + + // Initial conditions + problem.constrainInitialState(x_initial); + + // Final conditions + problem.constrainFinalState(x_final); + + // Cart position constraints + problem.forEachStep( + (x, u) -> { + problem.subjectTo(bounds(0.0, x.get(0), d_max)); + }); + + // Input constraints + problem.setLowerInputBound(-u_max); + problem.setUpperInputBound(u_max); + + // u = f_x + var U = problem.U(); + + // Minimize sum squared inputs + var J = new Variable(0.0); + for (int k = 0; k < N; ++k) { + J = J.plus(U.col(k).T().times(U.col(k)).get(0)); + } + problem.minimize(J); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + // Verify initial state + assertEquals(x_initial.get(0), X.value(0, 0), 1e-8); + assertEquals(x_initial.get(1), X.value(1, 0), 1e-8); + assertEquals(x_initial.get(2), X.value(2, 0), 1e-8); + assertEquals(x_initial.get(3), X.value(3, 0), 1e-8); + + // Verify final state + assertEquals(x_final.get(0), X.value(0, N), 1e-8); + assertEquals(x_final.get(1), X.value(1, N), 1e-8); + assertEquals(x_final.get(2), X.value(2, N), 1e-8); + assertEquals(x_final.get(3), X.value(3, N), 1e-8); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleProblemTest.java new file mode 100644 index 0000000000..bdb9b78efc --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleProblemTest.java @@ -0,0 +1,114 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.math.autodiff.NumericalIntegration.rk4; +import static org.wpilib.math.optimization.Constraints.bounds; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.system.NumericalIntegration.rk4; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.solver.ExitStatus; +import org.wpilib.math.util.MathUtil; + +class CartPoleProblemTest { + @Test + void testCartPoleProblem() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final double TOTAL_TIME = 5.0; // s + final double dt = 0.05; // s + final int N = (int) (TOTAL_TIME / dt); + + final double u_max = 20.0; // N + final double d_max = 2.0; // m + + final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}}); + final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}}); + + try (var problem = new Problem()) { + // x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ + var X = problem.decisionVariable(4, N + 1); + + // Initial guess + for (int k = 0; k < N + 1; ++k) { + X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N)); + X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N)); + } + + // u = f_x + var U = problem.decisionVariable(1, N); + + // Initial conditions + problem.subjectTo(eq(X.col(0), x_initial)); + + // Final conditions + problem.subjectTo(eq(X.col(N), x_final)); + + // Cart position constraints + problem.subjectTo(bounds(0.0, X.row(0), d_max)); + + // Input constraints + problem.subjectTo(bounds(-u_max, U, u_max)); + + // Dynamics constraints - RK4 integration + for (int k = 0; k < N; ++k) { + problem.subjectTo( + eq(X.col(k + 1), rk4(CartPoleUtil::cartPoleDynamics, X.col(k), U.col(k), dt))); + } + + // Minimize sum squared inputs + var J = new Variable(0.0); + for (int k = 0; k < N; ++k) { + J = J.plus(U.col(k).T().times(U.col(k)).get(0)); + } + problem.minimize(J); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + // Verify initial state + assertEquals(x_initial.get(0), X.value(0, 0), 1e-8); + assertEquals(x_initial.get(1), X.value(1, 0), 1e-8); + assertEquals(x_initial.get(2), X.value(2, 0), 1e-8); + assertEquals(x_initial.get(3), X.value(3, 0), 1e-8); + + // Verify solution + for (int k = 0; k < N; ++k) { + // Cart position constraints + assertTrue(X.get(0, k).value() >= 0.0); + assertTrue(X.get(0, k).value() <= d_max); + + // Input constraints + assertTrue(U.get(0, k).value() >= -u_max); + assertTrue(U.get(0, k).value() <= u_max); + + // Dynamics constraints + var expected_x_k1 = + rk4(CartPoleUtil::cartPoleDynamics, X.col(k).value(), U.col(k).value(), dt); + var actual_x_k1 = X.col(k + 1).value(); + for (int row = 0; row < actual_x_k1.getNumRows(); ++row) { + assertEquals(expected_x_k1.get(row), actual_x_k1.get(row), 1e-8); + } + } + + // Verify final state + assertEquals(x_final.get(0), X.value(0, N), 1e-8); + assertEquals(x_final.get(1), X.value(1, N), 1e-8); + assertEquals(x_final.get(2), X.value(2, N), 1e-8); + assertEquals(x_final.get(3), X.value(3, N), 1e-8); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleUtil.java b/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleUtil.java new file mode 100644 index 0000000000..07624ee9da --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/CartPoleUtil.java @@ -0,0 +1,122 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.wpilib.math.autodiff.Variable.cos; +import static org.wpilib.math.autodiff.Variable.sin; +import static org.wpilib.math.autodiff.VariableMatrix.solve; + +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; + +// https://underactuated.mit.edu/acrobot.html#cart_pole +// +// θ is CCW+ measured from negative y-axis. +// +// q = [x, θ]ᵀ +// q̇ = [ẋ, θ̇]ᵀ +// u = f_x +// +// M(q)q̈ + C(q, q̇)q̇ = τ_g(q) + Bu +// M(q)q̈ = τ_g(q) − C(q, q̇)q̇ + Bu +// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu) +// +// [ m_c + m_p m_p l cosθ] +// M(q) = [m_p l cosθ m_p l² ] +// +// [0 −m_p lθ̇ sinθ] +// C(q, q̇) = [0 0 ] +// +// [ 0 ] +// τ_g(q) = [-m_p gl sinθ] +// +// [1] +// B = [0] + +public final class CartPoleUtil { + private CartPoleUtil() { + // Utility class. + } + + private static final double m_c = 5.0; // Cart mass (kg) + private static final double m_p = 0.5; // Pole mass (kg) + private static final double l = 0.5; // Pole length (m) + private static final double g = 9.806; // Acceleration due to gravity (m/s²) + + public static SimpleMatrix cartPoleDynamics(SimpleMatrix x, SimpleMatrix u) { + var q = x.extractMatrix(0, 2, 0, 1); + var qdot = x.extractMatrix(2, 4, 0, 1); + var theta = q.get(1, 0); + var thetadot = qdot.get(1, 0); + + // [ m_c + m_p m_p l cosθ] + // M(q) = [m_p l cosθ m_p l² ] + var M = + new SimpleMatrix( + new double[][] { + {m_c + m_p, m_p * l * Math.cos(theta)}, + {m_p * l * Math.cos(theta), m_p * Math.pow(l, 2)} + }); + + // [0 −m_p lθ̇ sinθ] + // C(q, q̇) = [0 0 ] + var C = new SimpleMatrix(new double[][] {{0, -m_p * l * thetadot * Math.sin(theta)}, {0, 0}}); + + // [ 0 ] + // τ_g(q) = [-m_p gl sinθ] + var tau_g = new SimpleMatrix(new double[][] {{0}, {-m_p * g * l * Math.sin(theta)}}); + + // [1] + // B = [0] + final var B = new SimpleMatrix(new double[][] {{1}, {0}}); + + // q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu) + var qddot = new SimpleMatrix(4, 1); + qddot.insertIntoThis(0, 0, qdot); + qddot.insertIntoThis(2, 0, M.solve(tau_g.minus(C.mult(qdot)).plus(B.mult(u)))); + return qddot; + } + + public static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) { + var q = x.segment(0, 2); + var qdot = x.segment(2, 2); + var theta = q.get(1); + var thetadot = qdot.get(1); + + // [ m_c + m_p m_p l cosθ] + // M(q) = [m_p l cosθ m_p l² ] + var M = + new VariableMatrix( + new Variable[][] { + {new Variable(m_c + m_p), cos(theta).times(m_p * l)}, + {cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))} + }); + + // [0 −m_p lθ̇ sinθ] + // C(q, q̇) = [0 0 ] + var C = + new VariableMatrix( + new Variable[][] { + {new Variable(0), thetadot.times(-m_p * l).times(sin(theta))}, + {new Variable(0), new Variable(0)} + }); + + // [ 0 ] + // τ_g(q) = [-m_p gl sinθ] + var tau_g = + new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}}); + + // [1] + // B = [0] + var B = new VariableMatrix(new double[][] {{1}, {0}}); + + // q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu) + var qddot = new VariableMatrix(4); + qddot.segment(0, 2).set(qdot); + qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u)))); + return qddot; + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/CurrentManager.java b/wpimath/src/test/java/org/wpilib/math/optimization/CurrentManager.java new file mode 100644 index 0000000000..f3aef56c0b --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/CurrentManager.java @@ -0,0 +1,92 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.le; + +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; + +/** + * This class computes the optimal current allocation for a list of subsystems given a list of their + * desired currents and current tolerances that determine which subsystem gets less current if the + * current budget is exceeded. Subsystems with a smaller tolerance are given higher priority. + */ +public class CurrentManager implements AutoCloseable { + private final Problem m_problem = new Problem(); + private final VariableMatrix m_desiredCurrents; + private final VariableMatrix m_allocatedCurrents; + + /** + * Constructs a CurrentManager. + * + * @param currentTolerances The relative current tolerance of each subsystem. + * @param maxCurrent The current budget to allocate between subsystems. + */ + public CurrentManager(double[] currentTolerances, double maxCurrent) { + this.m_desiredCurrents = new VariableMatrix(currentTolerances.length, 1); + this.m_allocatedCurrents = m_problem.decisionVariable(currentTolerances.length); + + // Ensure m_desired_currents contains initialized Variables + for (int row = 0; row < m_desiredCurrents.rows(); ++row) { + // Don't initialize to 0 or 1, because those will get folded by Sleipnir + m_desiredCurrents.get(row).setValue(Double.POSITIVE_INFINITY); + } + + var J = new Variable(0.0); + var currentSum = new Variable(0.0); + for (int i = 0; i < currentTolerances.length; ++i) { + // The weight is 1/tolᵢ² where tolᵢ is the tolerance between the desired + // and allocated current for subsystem i + var error = m_desiredCurrents.get(i).minus(m_allocatedCurrents.get(i)); + J = J.plus(error.times(error).div(currentTolerances[i] * currentTolerances[i])); + + currentSum = currentSum.plus(m_allocatedCurrents.get(i)); + + // Currents must be nonnegative + m_problem.subjectTo(ge(m_allocatedCurrents.get(i), 0.0)); + } + m_problem.minimize(J); + + // Keep total current below maximum + m_problem.subjectTo(le(currentSum, maxCurrent)); + } + + @Override + public void close() { + m_problem.close(); + } + + /** + * Returns the optimal current allocation for a list of subsystems given a list of their desired + * currents and current tolerances that determine which subsystem gets less current if the current + * budget is exceeded. Subsystems with a smaller tolerance are given higher priority. + * + * @param desiredCurrents The desired current for each subsystem. + * @throws RuntimeException if the number of desired currents doesn't equal the number of + * tolerances passed in the constructor. + */ + public double[] calculate(double[] desiredCurrents) { + if (m_desiredCurrents.rows() != desiredCurrents.length) { + throw new RuntimeException( + "Number of desired currents must equal the number of tolerances passed in the " + + "constructor."); + } + + for (int i = 0; i < desiredCurrents.length; ++i) { + m_desiredCurrents.get(i).setValue(desiredCurrents[i]); + } + + m_problem.solve(); + + var result = new double[desiredCurrents.length]; + for (int i = 0; i < desiredCurrents.length; ++i) { + result[i] = Math.max(m_allocatedCurrents.value(i), 0.0); + } + + return result; + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/CurrentManagerTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/CurrentManagerTest.java new file mode 100644 index 0000000000..b4c7b0ff1e --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/CurrentManagerTest.java @@ -0,0 +1,62 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.Variable; + +class CurrentManagerTest { + @Test + void testEnoughCurrent() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var manager = new CurrentManager(new double[] {1.0, 5.0, 10.0, 5.0}, 40.0)) { + var currents = manager.calculate(new double[] {25.0, 10.0, 5.0, 0.0}); + + assertEquals(25.0, currents[0], 1e-3); + assertEquals(10.0, currents[1], 1e-3); + assertEquals(5.0, currents[2], 1e-3); + assertEquals(0.0, currents[3], 1e-3); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNotEnoughCurrent() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var manager = new CurrentManager(new double[] {1.0, 5.0, 10.0, 5.0}, 40.0)) { + var currents = manager.calculate(new double[] {30.0, 10.0, 5.0, 0.0}); + + // Expected values are from the following program: + // + // #!/usr/bin/env python3 + // + // from scipy.optimize import minimize + // + // r = [30.0, 10.0, 5.0, 0.0] + // q = [1.0, 5.0, 10.0, 5.0] + // + // result = minimize( + // lambda x: sum((r[i] - x[i]) ** 2 / q[i] ** 2 for i in range(4)), + // [0.0, 0.0, 0.0, 0.0], + // constraints=[ + // {"type": "ineq", "fun": lambda x: x}, + // {"type": "ineq", "fun": lambda x: 40.0 - sum(x)}, + // ], + // ) + // print(result.x) + assertEquals(29.960, currents[0], 1e-3); + assertEquals(9.008, currents[1], 1e-3); + assertEquals(1.032, currents[2], 1e-3); + assertEquals(0.0, currents[3], 1e-3); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/DecisionVariableTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/DecisionVariableTest.java new file mode 100644 index 0000000000..348183a5c9 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/DecisionVariableTest.java @@ -0,0 +1,138 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.MatrixAssertions.assertEquals; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.Variable; + +class DecisionVariableTest { + @Test + void testScalarInitAssign() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + // Scalar zero init + var x = problem.decisionVariable(); + assertEquals(0.0, x.value()); + + // Scalar assignment + x.setValue(1.0); + assertEquals(1.0, x.value()); + x.setValue(2.0); + assertEquals(2.0, x.value()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testVectorInitAssign() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + // Vector zero init + var y = problem.decisionVariable(2); + assertEquals(0.0, y.value(0)); + assertEquals(0.0, y.value(1)); + + // Vector assignment + y.get(0).setValue(1.0); + y.get(1).setValue(2.0); + assertEquals(1.0, y.value(0)); + assertEquals(2.0, y.value(1)); + y.get(0).setValue(3.0); + y.get(1).setValue(4.0); + assertEquals(3.0, y.value(0)); + assertEquals(4.0, y.value(1)); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testDynamicMatrixInitAssign() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + // Matrix zero init + var z = problem.decisionVariable(3, 2); + assertEquals(0.0, z.value(0, 0)); + assertEquals(0.0, z.value(0, 1)); + assertEquals(0.0, z.value(1, 0)); + assertEquals(0.0, z.value(1, 1)); + assertEquals(0.0, z.value(2, 0)); + assertEquals(0.0, z.value(2, 1)); + + // Matrix assignment; element comparison + z.setValue(new double[][] {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + assertEquals(1.0, z.value(0, 0)); + assertEquals(2.0, z.value(0, 1)); + assertEquals(3.0, z.value(1, 0)); + assertEquals(4.0, z.value(1, 1)); + assertEquals(5.0, z.value(2, 0)); + assertEquals(6.0, z.value(2, 1)); + + // Matrix assignment; matrix comparison + { + var expected = new SimpleMatrix(new double[][] {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}}); + z.setValue(expected); + assertEquals(expected, z.value()); + } + + // Block assignment + { + var expected_block = new double[][] {{1.0}, {1.0}}; + z.block(0, 0, 2, 1).setValue(expected_block); + + var expected_result = + new SimpleMatrix(new double[][] {{1.0, 8.0}, {1.0, 10.0}, {11.0, 12.0}}); + assertEquals(expected_result, z.value()); + } + + // Segment assignment + { + var expected_block = new double[][] {{1.0}, {1.0}}; + z.block(0, 0, 3, 1).segment(0, 2).setValue(expected_block); + + var expected_result = + new SimpleMatrix(new double[][] {{1.0, 8.0}, {1.0, 10.0}, {11.0, 12.0}}); + assertEquals(expected_result, z.value()); + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testSymmetricMatrix() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + // Matrix zero init + var A = problem.symmetricDecisionVariable(2); + assertEquals(0.0, A.value(0, 0)); + assertEquals(0.0, A.value(0, 1)); + assertEquals(0.0, A.value(1, 0)); + assertEquals(0.0, A.value(1, 1)); + + // Assign to lower triangle + A.get(0, 0).setValue(1.0); + A.get(1, 0).setValue(2.0); + A.get(1, 1).setValue(3.0); + + // Confirm whole matrix changed + assertEquals(1.0, A.value(0, 0)); + assertEquals(2.0, A.value(0, 1)); + assertEquals(2.0, A.value(1, 0)); + assertEquals(3.0, A.value(1, 1)); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveOCPTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveOCPTest.java new file mode 100644 index 0000000000..2d0f339811 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveOCPTest.java @@ -0,0 +1,85 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.ocp.DynamicsType; +import org.wpilib.math.optimization.ocp.TimestepMethod; +import org.wpilib.math.optimization.ocp.TranscriptionMethod; +import org.wpilib.math.optimization.solver.ExitStatus; +import org.wpilib.math.optimization.solver.Options; + +class DifferentialDriveOCPTest { + @Test + void testDifferentialDrive() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final int N = 50; + + final double minTimestep = 0.05; // s + final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}}); + final var x_final = new SimpleMatrix(new double[][] {{1.0}, {1.0}, {0.0}, {0.0}, {0.0}}); + final var u_min = new SimpleMatrix(new double[][] {{-12.0}, {-12.0}}); + final var u_max = new SimpleMatrix(new double[][] {{12.0}, {12.0}}); + + try (var problem = + new OCP( + 5, + 2, + minTimestep, + N, + DifferentialDriveUtil::differentialDriveDynamics, + DynamicsType.EXPLICIT_ODE, + TimestepMethod.VARIABLE_SINGLE, + TranscriptionMethod.DIRECT_TRANSCRIPTION)) { + // Seed the min time formulation with lerp between waypoints + for (int i = 0; i < N + 1; ++i) { + problem.X().get(0, i).setValue((double) i / (N + 1)); + problem.X().get(1, i).setValue((double) i / (N + 1)); + } + + problem.constrainInitialState(x_initial); + problem.constrainFinalState(x_final); + + problem.setLowerInputBound(u_min); + problem.setUpperInputBound(u_max); + + problem.setMinTimestep(minTimestep); + problem.setMaxTimestep(3.0); + + // Set up cost + problem.minimize(problem.dt().times(SimpleMatrix.ones(N + 1, 1))); + + assertEquals(ExpressionType.LINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve(new Options().withMaxIterations(1000))); + + var X = problem.X(); + + // Verify initial state + assertEquals(x_initial.get(0), X.value(0, 0), 1e-8); + assertEquals(x_initial.get(1), X.value(1, 0), 1e-8); + assertEquals(x_initial.get(2), X.value(2, 0), 1e-8); + assertEquals(x_initial.get(3), X.value(3, 0), 1e-8); + assertEquals(x_initial.get(4), X.value(4, 0), 1e-8); + + // Verify final state + assertEquals(x_final.get(0), X.value(0, N), 1e-8); + assertEquals(x_final.get(1), X.value(1, N), 1e-8); + assertEquals(x_final.get(2), X.value(2, N), 1e-8); + assertEquals(x_final.get(3), X.value(3, N), 1e-8); + assertEquals(x_final.get(4), X.value(4, N), 1e-8); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveProblemTest.java new file mode 100644 index 0000000000..8cb71ef17f --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveProblemTest.java @@ -0,0 +1,116 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.math.autodiff.NumericalIntegration.rk4; +import static org.wpilib.math.optimization.Constraints.bounds; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.system.NumericalIntegration.rk4; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.solver.ExitStatus; +import org.wpilib.math.util.MathUtil; + +class DifferentialDriveProblemTest { + @Test + void testDifferentialDriveProblem() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final double TOTAL_TIME = 5.0; // s + final double dt = 0.05; // s + final int N = (int) (TOTAL_TIME / dt); + + final double u_max = 12.0; // V + + final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}}); + final var x_final = new SimpleMatrix(new double[][] {{1.0}, {1.0}, {0.0}, {0.0}, {0.0}}); + + try (var problem = new Problem()) { + // x = [x, y, heading, left velocity, right velocity]ᵀ + var X = problem.decisionVariable(5, N + 1); + + // Initial guess + for (int k = 0; k < N; ++k) { + X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N)); + X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N)); + } + + // u = [left voltage, right voltage]ᵀ + var U = problem.decisionVariable(2, N); + + // Initial conditions + problem.subjectTo(eq(X.col(0), x_initial)); + + // Final conditions + problem.subjectTo(eq(X.col(N), x_final)); + + // Input constraints + problem.subjectTo(bounds(-u_max, U, u_max)); + + // Dynamics constraints - RK4 integration + for (int k = 0; k < N; ++k) { + problem.subjectTo( + eq( + X.col(k + 1), + rk4(DifferentialDriveUtil::differentialDriveDynamics, X.col(k), U.col(k), dt))); + } + + // Minimize sum squared states and inputs + var J = new Variable(0.0); + for (int k = 0; k < N; ++k) { + J = J.plus(X.col(k).T().times(X.col(k)).plus(U.col(k).T().times(U.col(k))).get(0)); + } + problem.minimize(J); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + // Verify initial state + assertEquals(x_initial.get(0), X.value(0, 0), 1e-8); + assertEquals(x_initial.get(1), X.value(1, 0), 1e-8); + assertEquals(x_initial.get(2), X.value(2, 0), 1e-8); + assertEquals(x_initial.get(3), X.value(3, 0), 1e-8); + assertEquals(x_initial.get(4), X.value(4, 0), 1e-8); + + // Verify solution + var x = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}}); + for (int k = 0; k < N; ++k) { + // Input constraints + assertTrue(U.get(0, k).value() >= -u_max); + assertTrue(U.get(0, k).value() <= u_max); + assertTrue(U.get(1, k).value() >= -u_max); + assertTrue(U.get(1, k).value() <= u_max); + + // Verify state + assertEquals(x.get(0), X.value(0, k), 1e-8); + assertEquals(x.get(1), X.value(1, k), 1e-8); + assertEquals(x.get(2), X.value(2, k), 1e-8); + assertEquals(x.get(3), X.value(3, k), 1e-8); + assertEquals(x.get(4), X.value(4, k), 1e-8); + + // Project state forward + var u = U.col(k).value(); + x = rk4(DifferentialDriveUtil::differentialDriveDynamics, x, u, dt); + } + + // Verify final state + assertEquals(x_final.get(0), X.value(0, N), 1e-8); + assertEquals(x_final.get(1), X.value(1, N), 1e-8); + assertEquals(x_final.get(2), X.value(2, N), 1e-8); + assertEquals(x_final.get(3), X.value(3, N), 1e-8); + assertEquals(x_final.get(4), X.value(4, N), 1e-8); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveUtil.java b/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveUtil.java new file mode 100644 index 0000000000..2233bebf71 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/DifferentialDriveUtil.java @@ -0,0 +1,58 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.wpilib.math.autodiff.Variable.cos; +import static org.wpilib.math.autodiff.Variable.sin; + +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.VariableMatrix; + +// x = [x, y, heading, left velocity, right velocity]ᵀ +// u = [left voltage, right voltage]ᵀ + +public final class DifferentialDriveUtil { + private DifferentialDriveUtil() { + // Utility class. + } + + private static final double trackwidth = 0.699; // m + private static final double Kv_linear = 3.02; // V/(m/s) + private static final double Ka_linear = 0.642; // V/(m/s²) + private static final double Kv_angular = 1.382; // V/(m/s) + private static final double Ka_angular = 0.08495; // V/(m/s²) + + private static final double A1 = -(Kv_linear / Ka_linear + Kv_angular / Ka_angular) / 2.0; + private static final double A2 = -(Kv_linear / Ka_linear - Kv_angular / Ka_angular) / 2.0; + private static final double B1 = 0.5 / Ka_linear + 0.5 / Ka_angular; + private static final double B2 = 0.5 / Ka_linear - 0.5 / Ka_angular; + private static final SimpleMatrix A = new SimpleMatrix(new double[][] {{A1, A2}, {A2, A1}}); + private static final SimpleMatrix B = new SimpleMatrix(new double[][] {{B1, B2}, {B2, B1}}); + + public static SimpleMatrix differentialDriveDynamics(SimpleMatrix x, SimpleMatrix u) { + var xdot = new SimpleMatrix(5, 1); + + var v = (x.get(3, 0) + x.get(4, 0)) / 2.0; + xdot.set(0, 0, v * Math.cos(x.get(2, 0))); + xdot.set(1, 0, v * Math.sin(x.get(2, 0))); + xdot.set(2, 0, (x.get(4, 0) - x.get(3, 0)) / trackwidth); + xdot.insertIntoThis(3, 0, A.mult(x.extractMatrix(3, 5, 0, 1)).plus(B.mult(u))); + + return xdot; + } + + public static VariableMatrix differentialDriveDynamics(VariableMatrix x, VariableMatrix u) { + var xdot = new VariableMatrix(5); + + var v = x.get(3).plus(x.get(4)).div(2.0); + xdot.set(0, 0, v.times(cos(x.get(2)))); + xdot.set(1, 0, v.times(sin(x.get(2)))); + xdot.set(2, 0, x.get(4).minus(x.get(3)).div(trackwidth)); + xdot.segment(3, 2) + .set(new VariableMatrix(A).times(x.segment(3, 2)).plus(new VariableMatrix(B).times(u))); + + return xdot; + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/DoubleIntegratorProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/DoubleIntegratorProblemTest.java new file mode 100644 index 0000000000..2f4a53aa46 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/DoubleIntegratorProblemTest.java @@ -0,0 +1,127 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.math.autodiff.Variable.pow; +import static org.wpilib.math.optimization.Constraints.bounds; +import static org.wpilib.math.optimization.Constraints.eq; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.solver.ExitStatus; + +class DoubleIntegratorProblemTest { + @Test + void testDoubleIntegratorProblem() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final double TOTAL_TIME = 3.5; // s + final double dt = 0.005; // s + final int N = (int) (TOTAL_TIME / dt); + + final double r = 2.0; // m + + try (var problem = new Problem()) { + // 2x1 state vector with N + 1 timesteps (includes last state) + var X = problem.decisionVariable(2, N + 1); + + // 1x1 input vector with N timesteps (input at last state doesn't matter) + var U = problem.decisionVariable(1, N); + + // Kinematics constraint assuming constant acceleration between timesteps + for (int k = 0; k < N; ++k) { + final double t = dt; + var p_k1 = X.get(0, k + 1); + var v_k1 = X.get(1, k + 1); + var p_k = X.get(0, k); + var v_k = X.get(1, k); + var a_k = U.get(0, k); + + // pₖ₊₁ = pₖ + vₖt + 1/2aₖt² + problem.subjectTo(eq(p_k1, p_k.plus(v_k.times(t)).plus(a_k.times(0.5 * t * t)))); + + // vₖ₊₁ = vₖ + aₖt + problem.subjectTo(eq(v_k1, v_k.plus(a_k.times(t)))); + } + + // Start and end at rest + problem.subjectTo(eq(X.col(0), new VariableMatrix(new double[][] {{0.0}, {0.0}}))); + problem.subjectTo(eq(X.col(N), new VariableMatrix(new double[][] {{r}, {0.0}}))); + + // Limit velocity + problem.subjectTo(bounds(-1, X.row(1), 1)); + + // Limit acceleration + problem.subjectTo(bounds(-1, U, 1)); + + // Cost function - minimize position error + var J = new Variable(0.0); + for (int k = 0; k < N + 1; ++k) { + J = J.plus(pow(new Variable(r).minus(X.get(0, k)), 2)); + } + problem.minimize(J); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + var A = new SimpleMatrix(new double[][] {{1.0, dt}, {0.0, 1.0}}); + var B = new SimpleMatrix(new double[][] {{0.5 * dt * dt}, {dt}}); + + // Verify initial state + assertEquals(0.0, X.value(0, 0), 1e-8); + assertEquals(0.0, X.value(1, 0), 1e-8); + + // Verify solution + var x = new SimpleMatrix(new double[][] {{0.0}, {0.0}}); + var u = new SimpleMatrix(new double[][] {{0.0}}); + for (int k = 0; k < N; ++k) { + // Verify state + assertEquals(x.get(0), X.value(0, k), 1e-2); + assertEquals(x.get(1), X.value(1, k), 1e-2); + + // Determine expected input for this timestep + if (k * dt < 1.0) { + // Accelerate + u.set(0, 0, 1.0); + } else if (k * dt < 2.05) { + // Maintain speed + u.set(0, 0, 0.0); + } else if (k * dt < 3.275) { + // Decelerate + u.set(0, 0, -1.0); + } else { + // Accelerate + u.set(0, 0, 1.0); + } + + // Verify input + if (k > 0 && k < N - 1 && Math.abs(U.value(0, k - 1) - U.value(0, k + 1)) >= 1.0 - 1e-2) { + // If control input is transitioning between -1, 0, or 1, ensure it's within (-1, 1) + assertTrue(U.value(0, k) >= -1.0); + assertTrue(U.value(0, k) <= 1.0); + } else { + assertEquals(u.get(0), U.value(0, k), 1e-4); + } + + // Project state forward + x = A.mult(x).plus(B.mult(u)); + } + + // Verify final state + assertEquals(r, X.value(0, N), 1e-8); + assertEquals(0.0, X.value(1, N), 1e-8); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/FlywheelOCPTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/FlywheelOCPTest.java new file mode 100644 index 0000000000..79856580c0 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/FlywheelOCPTest.java @@ -0,0 +1,178 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.function.BiFunction; +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.ocp.DynamicsType; +import org.wpilib.math.optimization.ocp.TimestepMethod; +import org.wpilib.math.optimization.ocp.TranscriptionMethod; +import org.wpilib.math.optimization.solver.ExitStatus; + +class FlywheelOCPTest { + private boolean near(double expected, double actual, double tolerance) { + return Math.abs(expected - actual) < tolerance; + } + + void flywheelTest( + double A, + double B, + BiFunction f, + DynamicsType dynamicsType, + TranscriptionMethod transcriptionMethod) { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final double TOTAL_TIME = 5.0; // s + final double dt = 0.005; // s + final int N = (int) (TOTAL_TIME / dt); + + // Flywheel model: + // States: [velocity] + // Inputs: [voltage] + final double A_discrete = Math.exp(A * dt); + final double B_discrete = (1.0 - A_discrete) * B; + + final double r = 10.0; + + try (var problem = + new OCP(1, 1, dt, N, f, dynamicsType, TimestepMethod.FIXED, transcriptionMethod)) { + problem.constrainInitialState(0.0); + problem.setUpperInputBound(12); + problem.setLowerInputBound(-12); + + // Set up cost + var r_mat = new VariableMatrix(SimpleMatrix.filled(1, N + 1, r)); + problem.minimize(r_mat.minus(problem.X()).times(r_mat.minus(problem.X()).T())); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + // Voltage for steady-state velocity: + // + // rₖ₊₁ = Arₖ + Buₖ + // uₖ = B⁺(rₖ₊₁ − Arₖ) + // uₖ = B⁺(rₖ − Arₖ) + // uₖ = B⁺(I − A)rₖ + double u_ss = 1.0 / B_discrete * (1.0 - A_discrete) * r; + + // Verify initial state + assertEquals(0.0, problem.X().value(0, 0), 1e-8); + + // Verify solution + double x = 0.0; + double u; + for (int k = 0; k < N; ++k) { + // Verify state + assertEquals(x, problem.X().value(0, k), 1e-2); + + // Determine expected input for this timestep + double error = r - x; + if (error > 1e-2) { + // Max control input until the reference is reached + u = 12.0; + } else { + // Maintain speed + u = u_ss; + } + + // Verify input + if (k > 0 + && k < N - 1 + && near(12.0, problem.U().value(0, k - 1), 1e-2) + && near(u_ss, problem.U().value(0, k + 1), 1e-2)) { + // If control input is transitioning between 12 and u_ss, ensure it's + // within (u_ss, 12) + assertTrue(problem.U().value(0, k) >= u_ss); + assertTrue(problem.U().value(0, k) <= 12.0); + } else { + if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) { + // The tolerance is large because the trajectory is represented by a + // spline, and splines chatter when transitioning quickly between + // steady-states. + assertEquals(u, problem.U().value(0, k), 2.0); + } else { + assertEquals(u, problem.U().value(0, k), 1e-4); + } + } + + // Project state forward + x = A_discrete * x + B_discrete * u; + } + + // Verify final state + assertEquals(r, problem.X().value(0, N), 2e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + static final double A = -1.0; + static final double B = 1.0; + + static final double dt = 0.005; // s + + static final double A_discrete = Math.exp(A * dt); + static final double B_discrete = (1.0 - A_discrete) * B; + + private static VariableMatrix f_ode(VariableMatrix x, VariableMatrix u) { + return new VariableMatrix(new double[][] {{A}}) + .times(x) + .plus(new VariableMatrix(new double[][] {{B}}).times(u)); + } + + private static VariableMatrix f_discrete(VariableMatrix x, VariableMatrix u) { + return new VariableMatrix(new double[][] {{A_discrete}}) + .times(x) + .plus(new VariableMatrix(new double[][] {{B_discrete}}).times(u)); + } + + @Test + void testFlywheelExplicit() { + flywheelTest( + A, + B, + FlywheelOCPTest::f_ode, + DynamicsType.EXPLICIT_ODE, + TranscriptionMethod.DIRECT_COLLOCATION); + flywheelTest( + A, + B, + FlywheelOCPTest::f_ode, + DynamicsType.EXPLICIT_ODE, + TranscriptionMethod.DIRECT_TRANSCRIPTION); + flywheelTest( + A, + B, + FlywheelOCPTest::f_ode, + DynamicsType.EXPLICIT_ODE, + TranscriptionMethod.SINGLE_SHOOTING); + } + + @Test + void testFlywheelDiscrete() { + flywheelTest( + A, + B, + FlywheelOCPTest::f_discrete, + DynamicsType.DISCRETE, + TranscriptionMethod.DIRECT_TRANSCRIPTION); + flywheelTest( + A, + B, + FlywheelOCPTest::f_discrete, + DynamicsType.DISCRETE, + TranscriptionMethod.SINGLE_SHOOTING); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/FlywheelProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/FlywheelProblemTest.java new file mode 100644 index 0000000000..6e63be89ce --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/FlywheelProblemTest.java @@ -0,0 +1,120 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.math.optimization.Constraints.bounds; +import static org.wpilib.math.optimization.Constraints.eq; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.solver.ExitStatus; + +class FlywheelProblemTest { + private boolean near(double expected, double actual, double tolerance) { + return Math.abs(expected - actual) < tolerance; + } + + @Test + void testFlywheelProblem() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + final double TOTAL_TIME = 5.0; // s + final double dt = 0.005; // s + final int N = (int) (TOTAL_TIME / dt); + + // Flywheel model: + // States: [velocity] + // Inputs: [voltage] + double A = Math.exp(-dt); + double B = 1.0 - Math.exp(-dt); + + try (var problem = new Problem()) { + var X = problem.decisionVariable(1, N + 1); + var U = problem.decisionVariable(1, N); + + // Dynamics constraint + for (int k = 0; k < N; ++k) { + problem.subjectTo( + eq( + X.col(k + 1), + new Variable(A) + .times(X.col(k).get(0)) + .plus(new Variable(B).times(U.col(k).get(0))))); + } + + // State and input constraints + problem.subjectTo(eq(X.col(0), 0.0)); + problem.subjectTo(bounds(-12, U, 12)); + + // Cost function - minimize error + final var r = new VariableMatrix(new double[][] {{10.0}}); + var J = new Variable(0.0); + for (int k = 0; k < N + 1; ++k) { + J = J.plus(r.minus(X.col(k)).T().times(r.minus(X.col(k))).get(0)); + } + problem.minimize(J); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + // Voltage for steady-state velocity: + // + // rₖ₊₁ = Arₖ + Buₖ + // uₖ = B⁺(rₖ₊₁ − Arₖ) + // uₖ = B⁺(rₖ − Arₖ) + // uₖ = B⁺(I − A)rₖ + double u_ss = 1.0 / B * (1.0 - A) * r.value(0); + + // Verify initial state + assertEquals(0.0, X.value(0, 0), 1e-8); + + // Verify solution + double x = 0.0; + double u; + for (int k = 0; k < N; ++k) { + // Verify state + assertEquals(x, X.value(0, k), 1e-2); + + // Determine expected input for this timestep + double error = r.value(0) - x; + if (error > 1e-2) { + // Max control input until the reference is reached + u = 12.0; + } else { + // Maintain speed + u = u_ss; + } + + // Verify input + if (k > 0 + && k < N - 1 + && near(12.0, U.value(0, k - 1), 1e-2) + && near(u_ss, U.value(0, k + 1), 1e-2)) { + // If control input is transitioning between 12 and u_ss, ensure it's + // within (u_ss, 12) + assertTrue(U.value(0, k) >= u_ss); + assertTrue(U.value(0, k) <= 12.0); + } else { + assertEquals(u, U.value(0, k), 1e-4); + } + + // Project state forward + x = A * x + B * u; + } + + // Verify final state + assertEquals(r.value(0), X.value(0, N), 2e-7); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/LinearProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/LinearProblemTest.java new file mode 100644 index 0000000000..3de69e5565 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/LinearProblemTest.java @@ -0,0 +1,72 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.le; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.solver.ExitStatus; + +class LinearProblemTest { + @Test + void testMaximize() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + x.setValue(1.0); + y.setValue(1.0); + + problem.maximize(x.times(50).plus(y.times(40))); + + problem.subjectTo(le(x.plus(y.times(1.5)), 750)); + problem.subjectTo(le(x.times(2).plus(y.times(3)), 1500)); + problem.subjectTo(le(x.times(2).plus(y), 1000)); + problem.subjectTo(ge(x, 0)); + problem.subjectTo(ge(y, 0)); + + assertEquals(ExpressionType.LINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(375.0, x.value(), 1e-6); + assertEquals(250.0, y.value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testFreeVariable() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(2); + x.get(0).setValue(1.0); + x.get(1).setValue(2.0); + + problem.subjectTo(eq(x.get(0), 0)); + + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(0.0, x.get(0).value(), 1e-6); + assertEquals(2.0, x.get(1).value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/NonlinearProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/NonlinearProblemTest.java new file mode 100644 index 0000000000..13012ea5de --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/NonlinearProblemTest.java @@ -0,0 +1,212 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.math.DoubleRange.range; +import static org.wpilib.math.autodiff.Variable.hypot; +import static org.wpilib.math.autodiff.Variable.pow; +import static org.wpilib.math.autodiff.Variable.sqrt; +import static org.wpilib.math.optimization.Constraints.bounds; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.le; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.solver.ExitStatus; + +class NonlinearProblemTest { + @Test + void testQuartic() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + x.setValue(20.0); + + problem.minimize(pow(x, 4)); + + problem.subjectTo(ge(x, 1)); + + assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(1.0, x.value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + private boolean near(double expected, double actual, double tolerance) { + return Math.abs(expected - actual) < tolerance; + } + + @Test + void testRosenbrockWithCubicAndLineConstraint() { + // https://en.wikipedia.org/wiki/Test_functions_for_optimization#Test_functions_for_constrained_optimization + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + + problem.minimize( + pow(y.minus(pow(x, 2)), 2).times(100).plus(pow(new Variable(1).minus(x), 2))); + + problem.subjectTo(ge(y, pow(x.minus(1), 3).plus(1))); + problem.subjectTo(le(y, x.unaryMinus().plus(2))); + + assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType()); + + for (var x0 : range(-1.5, 1.5, 0.1)) { + for (var y0 : range(-0.5, 2.5, 0.1)) { + x.setValue(x0); + y.setValue(y0); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + // Local minimum at (0.0, 0.0) + // Global minimum at (1.0, 1.0) + assertTrue(near(0.0, x.value(), 1e-2) || near(1.0, x.value(), 1e-2)); + assertTrue(near(0.0, y.value(), 1e-2) || near(1.0, y.value(), 1e-2)); + } + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testRosenbrockWithDiskConstraint() { + // https://en.wikipedia.org/wiki/Test_functions_for_optimization#Test_functions_for_constrained_optimization + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + + problem.minimize( + pow(new Variable(1).minus(x), 2).plus(pow(y.minus(pow(x, 2)), 2).times(100))); + + problem.subjectTo(le(pow(x, 2).plus(pow(y, 2)), 2)); + + assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.QUADRATIC, problem.inequalityConstraintType()); + + for (var x0 : range(-1.5, 1.5, 0.1)) { + for (var y0 : range(-1.5, 1.5, 0.1)) { + x.setValue(x0); + y.setValue(y0); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(1.0, x.value(), 1e-3); + assertEquals(1.0, y.value(), 1e-3); + } + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testMinimum2DDistanceWithLinearConstraint() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + x.setValue(20.0); + y.setValue(50.0); + + problem.minimize(sqrt(x.times(x).plus(y.times(y)))); + + problem.subjectTo(eq(y, x.unaryMinus().plus(5.0))); + + assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(2.5, x.value(), 1e-2); + assertEquals(2.5, y.value(), 1e-2); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testConflictingBounds() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + + problem.minimize(hypot(x, y)); + + problem.subjectTo(le(hypot(x, y), 1)); + problem.subjectTo(bounds(0.5, x, -0.5)); + + assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.GLOBALLY_INFEASIBLE, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testWachterAndBieglerLineSearchFailure() { + // See example 19.2 of [1] + // + // [1] Nocedal, J. and Wright, S. "Numerical Optimization", 2nd. ed., Ch. 19. Springer, 2006. + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var s1 = problem.decisionVariable(); + var s2 = problem.decisionVariable(); + + x.setValue(-2); + s1.setValue(3); + s2.setValue(1); + + problem.minimize(x); + + problem.subjectTo(eq(pow(x, 2).minus(s1).minus(1), 0)); + problem.subjectTo(eq(x.minus(s2).minus(0.5), 0)); + problem.subjectTo(ge(s1, 0)); + problem.subjectTo(ge(s2, 0)); + + assertEquals(ExpressionType.LINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.QUADRATIC, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(1.0, x.value(), 1e-6); + assertEquals(0.0, s1.value(), 1e-6); + assertEquals(0.5, s2.value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/ProblemJNITest.java b/wpimath/src/test/java/org/wpilib/math/optimization/ProblemJNITest.java new file mode 100644 index 0000000000..c7477540ae --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/ProblemJNITest.java @@ -0,0 +1,16 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.Test; + +public class ProblemJNITest { + @Test + public void testLink() { + assertDoesNotThrow(ProblemJNI::forceLoad); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/QuadraticProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/QuadraticProblemTest.java new file mode 100644 index 0000000000..4d2907c672 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/QuadraticProblemTest.java @@ -0,0 +1,194 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.solver.ExitStatus; + +class QuadraticProblemTest { + @Test + void testUnconstrained1D() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + x.setValue(2.0); + + problem.minimize(x.times(x).minus(x.times(6.0))); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(3.0, x.value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testUnconstrained2D() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + x.setValue(1.0); + y.setValue(2.0); + + problem.minimize(x.times(x).plus(y.times(y))); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(0.0, x.value(), 1e-6); + assertEquals(0.0, y.value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(2); + x.get(0).setValue(1.0); + x.get(1).setValue(2.0); + + problem.minimize(x.T().times(x).get(0)); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(0.0, x.value(0), 1e-6); + assertEquals(0.0, x.value(1), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testEqualityConstrained() { + // Maximize xy subject to x + 3y = 36. + // + // Maximize f(x,y) = xy + // subject to g(x,y) = x + 3y - 36 = 0 + // + // value func constraint + // | | + // v v + // L(x,y,λ) = f(x,y) - λg(x,y) + // L(x,y,λ) = xy - λ(x + 3y - 36) + // L(x,y,λ) = xy - xλ - 3yλ + 36λ + // + // ∇_x,y,λ L(x,y,λ) = 0 + // + // ∂L/∂x = y - λ + // ∂L/∂y = x - 3λ + // ∂L/∂λ = -x - 3y + 36 + // + // 0x + 1y - 1λ = 0 + // 1x + 0y - 3λ = 0 + // -1x - 3y + 0λ + 36 = 0 + // + // [ 0 1 -1][x] [ 0] + // [ 1 0 -3][y] = [ 0] + // [-1 -3 0][λ] [-36] + // + // Solve with: + // + // ```python + // np.linalg.solve( + // np.array([[0,1,-1], + // [1,0,-3], + // [-1,-3,0]]), + // np.array([[0], [0], [-36]])) + // ``` + // + // [x] [18] + // [y] = [ 6] + // [λ] [ 6] + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + + problem.maximize(x.times(y)); + + problem.subjectTo(eq(x.plus(y.times(3)), 36)); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(18.0, x.value(), 1e-5); + assertEquals(6.0, y.value(), 1e-5); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(2); + x.get(0).setValue(1.0); + x.get(1).setValue(2.0); + + problem.minimize(x.T().times(x).get(0)); + + problem.subjectTo(eq(x, new double[][] {{3.0}, {3.0}})); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(3.0, x.value(0), 1e-5); + assertEquals(3.0, x.value(1), 1e-5); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testInequalityConstrained2D() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + x.setValue(5.0); + y.setValue(5.0); + + problem.minimize(x.times(x).plus(y.times(2).times(y))); + problem.subjectTo(ge(y, x.unaryMinus().plus(5))); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + assertEquals(3.0 + 1.0 / 3.0, x.value(), 1e-6); + assertEquals(1.0 + 2.0 / 3.0, y.value(), 1e-6); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/TrivialProblemTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/TrivialProblemTest.java new file mode 100644 index 0000000000..cf6b6042c3 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/TrivialProblemTest.java @@ -0,0 +1,73 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.solver.ExitStatus; + +class TrivialProblemTest { + @Test + void testEmpty() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNoCostUnconstrained() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + @SuppressWarnings("VariableDeclarationUsageDistance") + var X = problem.decisionVariable(2, 3); + + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + for (int row = 0; row < X.rows(); ++row) { + for (int col = 0; col < X.cols(); ++col) { + assertEquals(0.0, X.value(row, col)); + } + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var X = problem.decisionVariable(2, 3); + X.setValue(SimpleMatrix.ones(2, 3)); + + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + for (int row = 0; row < X.rows(); ++row) { + for (int col = 0; col < X.cols(); ++col) { + assertEquals(1.0, X.value(row, col)); + } + } + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/java/org/wpilib/math/optimization/solver/ExitStatusTest.java b/wpimath/src/test/java/org/wpilib/math/optimization/solver/ExitStatusTest.java new file mode 100644 index 0000000000..38d37076b3 --- /dev/null +++ b/wpimath/src/test/java/org/wpilib/math/optimization/solver/ExitStatusTest.java @@ -0,0 +1,222 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.optimization.solver; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.math.autodiff.Variable.sqrt; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.gt; + +import org.junit.jupiter.api.Test; +import org.wpilib.math.autodiff.ExpressionType; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.optimization.Problem; + +// These tests ensure coverage of the off-nominal exit statuses + +class ExitStatusTest { + @Test + void testCallbackRequestedStop() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.minimize(x.times(x)); + + problem.addCallback(info -> false); + assertEquals(ExitStatus.SUCCESS, problem.solve()); + + problem.addCallback(info -> true); + assertEquals(ExitStatus.CALLBACK_REQUESTED_STOP, problem.solve()); + + problem.clearCallbacks(); + problem.addCallback(info -> false); + assertEquals(ExitStatus.SUCCESS, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testTooFewDOFs() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + var z = problem.decisionVariable(); + + problem.subjectTo(eq(x, 1)); + problem.subjectTo(eq(x, 2)); + problem.subjectTo(eq(y, 1)); + problem.subjectTo(eq(z, 1)); + + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.TOO_FEW_DOFS, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testLocallyInfeasible() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // Equality constraints + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + var z = problem.decisionVariable(); + + problem.subjectTo(eq(x, y.plus(1))); + problem.subjectTo(eq(y, z.plus(1))); + problem.subjectTo(eq(z, x.plus(1))); + + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.LOCALLY_INFEASIBLE, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // Inequality constraints + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + var y = problem.decisionVariable(); + var z = problem.decisionVariable(); + + problem.subjectTo(ge(x, y.plus(1))); + problem.subjectTo(ge(y, z.plus(1))); + problem.subjectTo(ge(z, x.plus(1))); + + assertEquals(ExpressionType.NONE, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.LOCALLY_INFEASIBLE, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testNonfiniteInitialGuess() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + // Nonfinite cost + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.minimize(new Variable(1).div(x)); + + assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve()); + } + + // Nonfinite gradient + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.minimize(sqrt(x)); + + assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve()); + } + + // Nonfinite equality constraint + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.subjectTo(eq(new Variable(1).div(x), 1)); + + assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve()); + } + + // Nonfinite equality constraint Jacobian + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.subjectTo(eq(sqrt(x), 1)); + + assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve()); + } + + // Nonfinite inequality constraint + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.subjectTo(gt(new Variable(1).div(x), 1)); + + assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve()); + } + + // Nonfinite inequality constraint Jacobian + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + problem.subjectTo(gt(sqrt(x), 1)); + + assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testDivergingIterates() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + + problem.minimize(x); + + assertEquals(ExpressionType.LINEAR, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.DIVERGING_ITERATES, problem.solve()); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testMaxIterationsExceeded() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + + problem.minimize(x.times(x)); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals( + ExitStatus.MAX_ITERATIONS_EXCEEDED, problem.solve(new Options().withMaxIterations(0))); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } + + @Test + void testTimeout() { + assertEquals(0, Variable.totalNativeMemoryUsage()); + + try (var problem = new Problem()) { + var x = problem.decisionVariable(); + + problem.minimize(x.times(x)); + + assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType()); + assertEquals(ExpressionType.NONE, problem.equalityConstraintType()); + assertEquals(ExpressionType.NONE, problem.inequalityConstraintType()); + + assertEquals(ExitStatus.TIMEOUT, problem.solve(new Options().withTimeout(0.0))); + } + + assertEquals(0, Variable.totalNativeMemoryUsage()); + } +} diff --git a/wpimath/src/test/native/cpp/optimization/CurrentManagerTest.cpp b/wpimath/src/test/native/cpp/optimization/CurrentManagerTest.cpp new file mode 100644 index 0000000000..10efc8a90a --- /dev/null +++ b/wpimath/src/test/native/cpp/optimization/CurrentManagerTest.cpp @@ -0,0 +1,49 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include "wpi/math/optimization/CurrentManager.hpp" + +#include + +#include + +TEST(CurrentManagerTest, EnoughCurrent) { + wpi::math::CurrentManager manager{std::array{1.0, 5.0, 10.0, 5.0}, 40.0}; + + auto currents = manager.calculate(std::array{25.0, 10.0, 5.0, 0.0}); + + EXPECT_NEAR(currents[0], 25.0, 1e-3); + EXPECT_NEAR(currents[1], 10.0, 1e-3); + EXPECT_NEAR(currents[2], 5.0, 1e-3); + EXPECT_NEAR(currents[3], 0.0, 1e-3); +} + +TEST(CurrentManagerTest, NotEnoughCurrent) { + wpi::math::CurrentManager manager{std::array{1.0, 5.0, 10.0, 5.0}, 40.0}; + + auto currents = manager.calculate(std::array{30.0, 10.0, 5.0, 0.0}); + + // Expected values are from the following program: + // + // #!/usr/bin/env python3 + // + // from scipy.optimize import minimize + // + // r = [30.0, 10.0, 5.0, 0.0] + // q = [1.0, 5.0, 10.0, 5.0] + // + // result = minimize( + // lambda x: sum((r[i] - x[i]) ** 2 / q[i] ** 2 for i in range(4)), + // [0.0, 0.0, 0.0, 0.0], + // constraints=[ + // {"type": "ineq", "fun": lambda x: x}, + // {"type": "ineq", "fun": lambda x: 40.0 - sum(x)}, + // ], + // ) + // print(result.x) + EXPECT_NEAR(currents[0], 29.960, 1e-3); + EXPECT_NEAR(currents[1], 9.008, 1e-3); + EXPECT_NEAR(currents[2], 1.032, 1e-3); + EXPECT_NEAR(currents[3], 0.0, 1e-3); +} diff --git a/wpimath/src/test/native/include/wpi/math/optimization/CurrentManager.hpp b/wpimath/src/test/native/include/wpi/math/optimization/CurrentManager.hpp new file mode 100644 index 0000000000..ffa6cd4e2b --- /dev/null +++ b/wpimath/src/test/native/include/wpi/math/optimization/CurrentManager.hpp @@ -0,0 +1,97 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace wpi::math { + +/** + * This class computes the optimal current allocation for a list of subsystems + * given a list of their desired currents and current tolerances that determine + * which subsystem gets less current if the current budget is exceeded. + * Subsystems with a smaller tolerance are given higher priority. + */ +class CurrentManager { + public: + /** + * Constructs a CurrentManager. + * + * @param currentTolerances The relative current tolerance of each subsystem. + * @param maxCurrent The current budget to allocate between subsystems. + */ + CurrentManager(std::span currentTolerances, double maxCurrent) + : m_desiredCurrents{static_cast(currentTolerances.size()), 1}, + m_allocatedCurrents{ + m_problem.decision_variable(currentTolerances.size())} { + // Ensure m_desiredCurrents contains initialized Variables + for (int row = 0; row < m_desiredCurrents.rows(); ++row) { + // Don't initialize to 0 or 1, because those will get folded by Sleipnir + m_desiredCurrents[row] = std::numeric_limits::infinity(); + } + + slp::Variable J = 0.0; + slp::Variable current_sum = 0.0; + for (size_t i = 0; i < currentTolerances.size(); ++i) { + // The weight is 1/tolᵢ² where tolᵢ is the tolerance between the desired + // and allocated current for subsystem i + auto error = m_desiredCurrents[i] - m_allocatedCurrents[i]; + J += error * error / (currentTolerances[i] * currentTolerances[i]); + + current_sum += m_allocatedCurrents[i]; + + // Currents must be nonnegative + m_problem.subject_to(m_allocatedCurrents[i] >= 0.0); + } + m_problem.minimize(J); + + // Keep total current below maximum + m_problem.subject_to(current_sum <= maxCurrent); + } + + /** + * Returns the optimal current allocation for a list of subsystems given a + * list of their desired currents and current tolerances that determine which + * subsystem gets less current if the current budget is exceeded. Subsystems + * with a smaller tolerance are given higher priority. + * + * @param desiredCurrents The desired current for each subsystem. + * @throws std::runtime_error if the number of desired currents doesn't equal + * the number of tolerances passed in the constructor. + */ + std::vector calculate(std::span desiredCurrents) { + if (m_desiredCurrents.rows() != static_cast(desiredCurrents.size())) { + throw std::runtime_error( + "Number of desired currents must equal the number of tolerances " + "passed in the constructor."); + } + + for (size_t i = 0; i < desiredCurrents.size(); ++i) { + m_desiredCurrents[i].set_value(desiredCurrents[i]); + } + + m_problem.solve(); + + std::vector result; + for (size_t i = 0; i < desiredCurrents.size(); ++i) { + result.emplace_back(std::max(m_allocatedCurrents.value(i), 0.0)); + } + + return result; + } + + private: + slp::Problem m_problem; + slp::VariableMatrix m_desiredCurrents; + slp::VariableMatrix m_allocatedCurrents; +}; + +} // namespace wpi::math