diff --git a/benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java b/benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java new file mode 100644 index 0000000000..cd08ecccea --- /dev/null +++ b/benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java @@ -0,0 +1,128 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package wpilib.robot; + +import static org.wpilib.math.autodiff.NumericalIntegration.rk4; +import static org.wpilib.math.autodiff.Variable.cos; +import static org.wpilib.math.autodiff.Variable.sin; +import static org.wpilib.math.autodiff.VariableMatrix.solve; +import static org.wpilib.math.optimization.Constraints.eq; +import static org.wpilib.math.optimization.Constraints.ge; +import static org.wpilib.math.optimization.Constraints.le; + +import org.ejml.simple.SimpleMatrix; +import org.wpilib.math.autodiff.Variable; +import org.wpilib.math.autodiff.VariableMatrix; +import org.wpilib.math.optimization.Problem; +import org.wpilib.math.optimization.solver.Options; +import org.wpilib.math.util.MathUtil; + +public final class CartPoleBenchmark { + private CartPoleBenchmark() { + // Utility class. + } + + @SuppressWarnings("LocalVariableName") + private static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) { + final double m_c = 5.0; // Cart mass (kg) + final double m_p = 0.5; // Pole mass (kg) + final double l = 0.5; // Pole length (m) + final double g = 9.806; // Acceleration due to gravity (m/s²) + + var q = x.segment(0, 2); + var qdot = x.segment(2, 2); + var theta = q.get(1); + var thetadot = qdot.get(1); + + // [ m_c + m_p m_p l cosθ] + // M(q) = [m_p l cosθ m_p l² ] + var M = + new VariableMatrix( + new Variable[][] { + {new Variable(m_c + m_p), cos(theta).times(m_p * l)}, + {cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))} + }); + + // [0 −m_p lθ̇ sinθ] + // C(q, q̇) = [0 0 ] + var C = + new VariableMatrix( + new Variable[][] { + {new Variable(0), thetadot.times(-m_p * l).times(sin(theta))}, + {new Variable(0), new Variable(0)} + }); + + // [ 0 ] + // τ_g(q) = [-m_p gl sinθ] + var tau_g = + new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}}); + + // [1] + // B = [0] + var B = new VariableMatrix(new double[][] {{1}, {0}}); + + // q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu) + var qddot = new VariableMatrix(4); + qddot.segment(0, 2).set(qdot); + qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u)))); + return qddot; + } + + /** Cart-pole benchmark. */ + public static void cartPole() { + final double T = 5.0; // s + final double dt = 0.05; // s + final int N = (int) (T / dt); + + final double u_max = 20.0; // N + final double d_max = 2.0; // m + + final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}}); + final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}}); + + var problem = new Problem(); + + // x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ + var X = problem.decisionVariable(4, N + 1); + + // Initial guess + for (int k = 0; k < N + 1; ++k) { + X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N)); + X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N)); + } + + // u = f_x + var U = problem.decisionVariable(1, N); + + // Initial conditions + problem.subjectTo(eq(X.col(0), x_initial)); + + // Final conditions + problem.subjectTo(eq(X.col(N), x_final)); + + // Cart position constraints + problem.subjectTo(ge(X.row(0), 0.0)); + problem.subjectTo(le(X.row(0), d_max)); + + // Input constraints + problem.subjectTo(ge(U, -u_max)); + problem.subjectTo(le(U, u_max)); + + // Dynamics constraints - RK4 integration + for (int k = 0; k < N; ++k) { + problem.subjectTo( + eq(X.col(k + 1), rk4(CartPoleBenchmark::cartPoleDynamics, X.col(k), U.col(k), dt))); + } + + // Minimize sum squared inputs + var J = new Variable(0.0); + for (int k = 0; k < N; ++k) { + J = J.plus(U.col(k).T().times(U.col(k)).get(0)); + } + problem.minimize(J); + + problem.solve(new Options().withDiagnostics(true)); + } +} diff --git a/benchmark/src/main/java/wpilib/robot/Main.java b/benchmark/src/main/java/wpilib/robot/Main.java index 975e3401aa..3db4d43472 100644 --- a/benchmark/src/main/java/wpilib/robot/Main.java +++ b/benchmark/src/main/java/wpilib/robot/Main.java @@ -37,6 +37,13 @@ public class Main { new Runner(opt).run(); } + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MICROSECONDS) + public void cartPole() { + CartPoleBenchmark.cartPole(); + } + @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MICROSECONDS) diff --git a/wpimath/BUILD.bazel b/wpimath/BUILD.bazel index 80f99590fc..bd2e92a1c7 100644 --- a/wpimath/BUILD.bazel +++ b/wpimath/BUILD.bazel @@ -236,6 +236,9 @@ wpilib_cc_shared_library( wpilib_jni_java_library( name = "wpimath-java", srcs = [":generated_java"] + glob(["src/main/java/**/*.java"]), + javacopts = [ + "-Xep:UnicodeInCode:OFF", + ], maven_artifact_name = "wpimath-java", maven_group_id = "org.wpilib.wpimath", native_libs = [":wpimathjni"], @@ -288,6 +291,7 @@ wpilib_java_junit5_test( "//wpiunits:wpiunits-java", "//wpiutil:wpiutil-java", "@maven//:org_ejml_ejml_core", + "@maven//:org_ejml_ejml_ddense", "@maven//:org_ejml_ejml_simple", "@maven//:us_hebi_quickbuf_quickbuf_runtime", ], diff --git a/wpimath/CMakeLists.txt b/wpimath/CMakeLists.txt index 273698b4d4..73d5210913 100644 --- a/wpimath/CMakeLists.txt +++ b/wpimath/CMakeLists.txt @@ -7,14 +7,18 @@ include(DownloadAndCheck) file( GLOB wpimath_jni_src - src/main/native/cpp/jni/ArmFeedforwardJNI.cpp src/main/native/cpp/jni/DAREJNI.cpp src/main/native/cpp/jni/EigenJNI.cpp - src/main/native/cpp/jni/Ellipse2dJNI.cpp src/main/native/cpp/jni/Exceptions.cpp src/main/native/cpp/jni/LinearSystemUtilJNI.cpp src/main/native/cpp/jni/Transform3dJNI.cpp src/main/native/cpp/jni/Twist3dJNI.cpp + src/main/native/cpp/jni/autodiff/GradientJNI.cpp + src/main/native/cpp/jni/autodiff/HessianJNI.cpp + src/main/native/cpp/jni/autodiff/JacobianJNI.cpp + src/main/native/cpp/jni/autodiff/VariableJNI.cpp + src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp + src/main/native/cpp/jni/optimization/ProblemJNI.cpp ) # Java bindings diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java b/wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java new file mode 100644 index 0000000000..dc9518e8fe --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/ExpressionType.java @@ -0,0 +1,47 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +/** + * Expression type. + * + *
Used for autodiff caching. + */ +public enum ExpressionType { + /** There is no expression. */ + NONE(0), + /** The expression is a constant. */ + CONSTANT(1), + /** The expression is composed of linear and lower-order operators. */ + LINEAR(2), + /** The expression is composed of quadratic and lower-order operators. */ + QUADRATIC(3), + /** The expression is composed of nonlinear and lower-order operators. */ + NONLINEAR(4); + + /** ExpressionType value. */ + public final int value; + + ExpressionType(int value) { + this.value = value; + } + + /** + * Converts integer to its corresponding enum value. + * + * @param x The integer. + * @return The enum value. + */ + public static ExpressionType fromInt(int x) { + return switch (x) { + case 0 -> ExpressionType.NONE; + case 1 -> ExpressionType.CONSTANT; + case 2 -> ExpressionType.LINEAR; + case 3 -> ExpressionType.QUADRATIC; + case 4 -> ExpressionType.NONLINEAR; + default -> null; + }; + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java new file mode 100644 index 0000000000..136a7b59f7 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java @@ -0,0 +1,78 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.simple.SimpleMatrix; + +/** + * This class calculates the gradient of a variable with respect to a vector of variables. + * + *
The gradient is only recomputed if the variable expression is quadratic or higher order. + */ +public class Gradient implements AutoCloseable { + private long m_handle; + private int m_rows; + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Variable with respect to which to compute the gradient. + */ + public Gradient(Variable variable, Variable wrt) { + this(variable, new VariableMatrix(wrt)); + } + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Vector of variables with respect to which to compute the gradient. + */ + public Gradient(Variable variable, VariableMatrix wrt) { + assert wrt.cols() == 1; + + m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles()); + m_rows = wrt.rows(); + } + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the gradient. + * @param wrt Vector of variables with respect to which to compute the gradient. + */ + public Gradient(Variable variable, VariableBlock wrt) { + this(variable, new VariableMatrix(wrt)); + } + + @Override + public void close() { + if (m_handle != 0) { + GradientJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Returns the gradient as a VariableMatrix. + * + *
This is useful when constructing optimization problems with derivatives in them. + * + * @return The gradient as a VariableMatrix. + */ + public VariableMatrix get() { + return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle)); + } + + /** + * Evaluates the gradient at wrt's value. + * + * @return The gradient at wrt's value. + */ + public SimpleMatrix value() { + return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java new file mode 100644 index 0000000000..e5f7568bb2 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/GradientJNI.java @@ -0,0 +1,48 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Gradient JNI functions. */ +final class GradientJNI extends WPIMathJNI { + private GradientJNI() { + // Utility class. + } + + /** + * Constructs a Gradient object. + * + * @param variable Variable of which to compute the Gradient. + * @param wrt Vector of variables with respect to which to compute the Gradient. + */ + static native long create(long variable, long[] wrt); + + /** + * Destructs a Gradient. + * + * @param handle Gradient handle. + */ + static native void destroy(long handle); + + /** + * Returns the Gradient as an array of Variable handles. + * + *
This is useful when constructing optimization problems with derivatives in them. + * + * @param handle Gradient handle. + * @return The Gradient as an array of Variable handles. + */ + static native long[] get(long handle); + + /** + * Evaluates the Gradient at wrt's value. + * + * @param handle Gradient handle. + * @return A record containing the triplet row, column and value arrays (int[], int[], and + * double[] respectively). + */ + static native NativeSparseTriplets value(long handle); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java new file mode 100644 index 0000000000..ba94ed2622 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java @@ -0,0 +1,79 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.simple.SimpleMatrix; + +/** + * This class calculates the Hessian of a variable with respect to a vector of variables. + * + *
The gradient tree is cached so subsequent Hessian calculations are faster, and the Hessian is + * only recomputed if the variable expression is nonlinear. + */ +public class Hessian implements AutoCloseable { + private long m_handle; + private int m_rows; + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Variable with respect to which to compute the Hessian. + */ + public Hessian(Variable variable, Variable wrt) { + this(variable, new VariableMatrix(wrt)); + } + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the Hessian. + */ + public Hessian(Variable variable, VariableMatrix wrt) { + assert wrt.cols() == 1; + + m_handle = HessianJNI.create(variable.getHandle(), wrt.getHandles()); + m_rows = wrt.rows(); + } + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the Hessian. + */ + public Hessian(Variable variable, VariableBlock wrt) { + this(variable, new VariableMatrix(wrt)); + } + + @Override + public void close() { + if (m_handle != 0) { + HessianJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Returns the Hessian as a VariableMatrix. + * + *
This is useful when constructing optimization problems with derivatives in them. + * + * @return The Hessian as a VariableMatrix. + */ + public VariableMatrix get() { + return new VariableMatrix(m_rows, m_rows, HessianJNI.get(m_handle)); + } + + /** + * Evaluates the Hessian at wrt's value. + * + * @return The Hessian at wrt's value. + */ + public SimpleMatrix value() { + return HessianJNI.value(m_handle).toSimpleMatrix(m_rows, m_rows); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java new file mode 100644 index 0000000000..4ae7735c9b --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/HessianJNI.java @@ -0,0 +1,48 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Hessian JNI functions. */ +final class HessianJNI extends WPIMathJNI { + private HessianJNI() { + // Utility class. + } + + /** + * Constructs a Hessian object. + * + * @param variable Variable of which to compute the Hessian. + * @param wrt Vector of variables with respect to which to compute the Hessian. + */ + static native long create(long variable, long[] wrt); + + /** + * Destructs a Hessian. + * + * @param handle Hessian handle. + */ + static native void destroy(long handle); + + /** + * Returns the Hessian as an array of Variable handles. + * + *
This is useful when constructing optimization problems with derivatives in them. + * + * @param handle Hessian handle. + * @return The Hessian as an array of Variable handles. + */ + static native long[] get(long handle); + + /** + * Evaluates the Hessian at wrt's value. + * + * @param handle Hessian handle. + * @return A record containing the triplet row, column and value arrays (int[], int[], and + * double[] respectively). + */ + static native NativeSparseTriplets value(long handle); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java b/wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java new file mode 100644 index 0000000000..6dc5f99922 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java @@ -0,0 +1,102 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.simple.SimpleMatrix; + +/** + * This class calculates the Jacobian of a vector of variables with respect to a vector of + * variables. + * + *
The Jacobian is only recomputed if the variable expression is quadratic or higher order. + */ +public class Jacobian implements AutoCloseable { + private long m_handle; + private int m_rows; + private int m_cols; + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Variable with respect to which to compute the Jacobian. + */ + public Jacobian(Variable variable, Variable wrt) { + this(new VariableMatrix(variable), new VariableMatrix(wrt)); + } + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(Variable variable, VariableMatrix wrt) { + this(new VariableMatrix(variable), wrt); + } + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(Variable variable, VariableBlock wrt) { + this(new VariableMatrix(variable), new VariableMatrix(wrt)); + } + + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(VariableMatrix variables, VariableMatrix wrt) { + assert variables.cols() == 1; + assert wrt.cols() == 1; + + m_handle = JacobianJNI.create(variables.getHandles(), wrt.getHandles()); + m_rows = variables.rows(); + m_cols = wrt.rows(); + } + + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + public Jacobian(VariableMatrix variables, VariableBlock wrt) { + this(variables, new VariableMatrix(wrt)); + } + + @Override + public void close() { + if (m_handle != 0) { + JacobianJNI.destroy(m_handle); + m_handle = 0; + } + } + + /** + * Returns the Jacobian as a VariableMatrix. + * + *
This is useful when constructing optimization problems with derivatives in them. + * + * @return The Jacobian as a VariableMatrix. + */ + public VariableMatrix get() { + return new VariableMatrix(m_rows, m_cols, JacobianJNI.get(m_handle)); + } + + /** + * Evaluates the Jacobian at wrt's value. + * + * @return The Jacobian at wrt's value. + */ + public SimpleMatrix value() { + return JacobianJNI.value(m_handle).toSimpleMatrix(m_rows, m_cols); + } +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java new file mode 100644 index 0000000000..f969d22e39 --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/JacobianJNI.java @@ -0,0 +1,48 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.wpilib.math.jni.WPIMathJNI; + +/** Jacobian JNI functions. */ +final class JacobianJNI extends WPIMathJNI { + private JacobianJNI() { + // Utility class. + } + + /** + * Constructs a Jacobian object. + * + * @param variables Vector of variables of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the Jacobian. + */ + static native long create(long[] variables, long[] wrt); + + /** + * Destructs a Jacobian. + * + * @param handle Jacobian handle. + */ + static native void destroy(long handle); + + /** + * Returns the Jacobian as an array of Variable handles. + * + *
This is useful when constructing optimization problems with derivatives in them. + * + * @param handle Jacobian handle. + * @return The Jacobian as an array of Variable handles. + */ + static native long[] get(long handle); + + /** + * Evaluates the Jacobian at wrt's value. + * + * @param handle Jacobian handle. + * @return A record containing the triplet row, column and value arrays (int[], int[], and + * double[] respectively). + */ + static native NativeSparseTriplets value(long handle); +} diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java b/wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java new file mode 100644 index 0000000000..fd1deb138f --- /dev/null +++ b/wpimath/src/main/java/org/wpilib/math/autodiff/NativeSparseTriplets.java @@ -0,0 +1,37 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.math.autodiff; + +import org.ejml.data.DMatrixSparseCSC; +import org.ejml.data.DMatrixSparseTriplet; +import org.ejml.ops.DConvertMatrixStruct; +import org.ejml.simple.SimpleMatrix; + +/** + * Wrapper for sparse matrix triplets from JNI. + * + *
We can't use DMatrixSparseTriplet because it doesn't have a method for bulk-initialization
+ * from triplet arrays.
+ *
+ * @param rows Triplet rows.
+ * @param cols Triplet columns.
+ * @param values Triplet values.
+ */
+public record NativeSparseTriplets(int[] rows, int[] cols, double[] values) {
+ /**
+ * Returns a SimpleMatrix wrapper for this set of triplets.
+ *
+ * @param rows Number of rows in sparse SimpleMatrix.
+ * @param cols Number of columns in sparse SimpleMatrix.
+ * @return A SimpleMatrix wrapper for this set of triplets.
+ */
+ public SimpleMatrix toSimpleMatrix(int rows, int cols) {
+ var ejmlTriplets = new DMatrixSparseTriplet(rows, cols, values().length);
+ for (int i = 0; i < values().length; ++i) {
+ ejmlTriplets.addItem(rows()[i], cols()[i], values()[i]);
+ }
+ return new SimpleMatrix(DConvertMatrixStruct.convert(ejmlTriplets, (DMatrixSparseCSC) null));
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java b/wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java
new file mode 100644
index 0000000000..9d4ee80c82
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/autodiff/NumericalIntegration.java
@@ -0,0 +1,89 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.autodiff;
+
+import java.util.function.BiFunction;
+
+/** Numerical integration utilities. */
+public final class NumericalIntegration {
+ private NumericalIntegration() {
+ // Utility class.
+ }
+
+ /**
+ * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
+ *
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dt The time over which to integrate.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ public static VariableMatrix rk4(
+ BiFunction This constructor is for internal use only.
+ *
+ * @param handleTypeTag Handle type tag.
+ * @param handle Variable handle.
+ */
+ @SuppressWarnings({"PMD.UnusedFormalParameter", "this-escape"})
+ public Variable(Handle handleTypeTag, long handle) {
+ m_handle = handle;
+ VariablePool.register(this);
+ }
+
+ @Override
+ public void close() {
+ if (m_handle != 0) {
+ VariableJNI.destroy(m_handle);
+ m_handle = 0;
+ }
+ }
+
+ /**
+ * Sets Variable's internal value.
+ *
+ * @param value The value of the Variable.
+ */
+ public void setValue(double value) {
+ VariableJNI.setValue(m_handle, value);
+ }
+
+ /**
+ * Returns the value of this variable.
+ *
+ * @return The value of this variable.
+ */
+ public double value() {
+ return VariableJNI.value(m_handle);
+ }
+
+ /**
+ * Returns the type of this expression (constant, linear, quadratic, or nonlinear).
+ *
+ * @return The type of this expression.
+ */
+ public ExpressionType type() {
+ return ExpressionType.fromInt(VariableJNI.type(m_handle));
+ }
+
+ /**
+ * Returns internal handle.
+ *
+ * @return Internal handle.
+ */
+ public long getHandle() {
+ return m_handle;
+ }
+
+ /**
+ * Variable-Variable multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of multiplication.
+ */
+ public Variable times(Variable rhs) {
+ return new Variable(HANDLE, VariableJNI.times(m_handle, rhs.getHandle()));
+ }
+
+ /**
+ * Variable-Variable multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of multiplication.
+ */
+ public Variable times(double rhs) {
+ return times(new Variable(rhs));
+ }
+
+ /**
+ * Variable-Variable division operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of division.
+ */
+ public Variable div(Variable rhs) {
+ return new Variable(HANDLE, VariableJNI.div(m_handle, rhs.getHandle()));
+ }
+
+ /**
+ * Variable-Variable division operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of division.
+ */
+ public Variable div(double rhs) {
+ return div(new Variable(rhs));
+ }
+
+ /**
+ * Variable-Variable addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public Variable plus(Variable rhs) {
+ return new Variable(HANDLE, VariableJNI.plus(m_handle, rhs.getHandle()));
+ }
+
+ /**
+ * Variable-Variable addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public Variable plus(double rhs) {
+ return plus(new Variable(rhs));
+ }
+
+ /**
+ * Variable-Variable subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public Variable minus(Variable rhs) {
+ return new Variable(HANDLE, VariableJNI.minus(m_handle, rhs.getHandle()));
+ }
+
+ /**
+ * Variable-Variable subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public Variable minus(double rhs) {
+ return minus(new Variable(rhs));
+ }
+
+ /**
+ * Unary minus operator.
+ *
+ * @return Result of unary minus.
+ */
+ public Variable unaryMinus() {
+ return new Variable(HANDLE, VariableJNI.unaryMinus(m_handle));
+ }
+
+ /**
+ * Unary plus operator.
+ *
+ * @return Result of unary plus.
+ */
+ public Variable unaryPlus() {
+ return this;
+ }
+
+ /**
+ * Math.abs() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of abs().
+ */
+ public static Variable abs(Variable x) {
+ return new Variable(HANDLE, VariableJNI.abs(x.getHandle()));
+ }
+
+ /**
+ * Math.acos() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of acos().
+ */
+ public static Variable acos(Variable x) {
+ return new Variable(HANDLE, VariableJNI.acos(x.getHandle()));
+ }
+
+ /**
+ * Math.asin() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of asin().
+ */
+ public static Variable asin(Variable x) {
+ return new Variable(HANDLE, VariableJNI.asin(x.getHandle()));
+ }
+
+ /**
+ * Math.atan() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of atan().
+ */
+ public static Variable atan(Variable x) {
+ return new Variable(HANDLE, VariableJNI.atan(x.getHandle()));
+ }
+
+ /**
+ * Math.atan2() for Variables.
+ *
+ * @param y The y argument.
+ * @param x The x argument.
+ * @return Result of atan2().
+ */
+ public static Variable atan2(double y, Variable x) {
+ return atan2(new Variable(y), x);
+ }
+
+ /**
+ * Math.atan2() for Variables.
+ *
+ * @param y The y argument.
+ * @param x The x argument.
+ * @return Result of atan2().
+ */
+ public static Variable atan2(Variable y, double x) {
+ return atan2(y, new Variable(x));
+ }
+
+ /**
+ * Math.atan2() for Variables.
+ *
+ * @param y The y argument.
+ * @param x The x argument.
+ * @return Result of atan2().
+ */
+ public static Variable atan2(Variable y, Variable x) {
+ return new Variable(HANDLE, VariableJNI.atan2(y.getHandle(), x.getHandle()));
+ }
+
+ /**
+ * Math.cbrt() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of cbrt().
+ */
+ public static Variable cbrt(Variable x) {
+ return new Variable(HANDLE, VariableJNI.cbrt(x.getHandle()));
+ }
+
+ /**
+ * Math.cos() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of cos().
+ */
+ public static Variable cos(Variable x) {
+ return new Variable(HANDLE, VariableJNI.cos(x.getHandle()));
+ }
+
+ /**
+ * Math.cosh() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of cosh().
+ */
+ public static Variable cosh(Variable x) {
+ return new Variable(HANDLE, VariableJNI.cosh(x.getHandle()));
+ }
+
+ /**
+ * Math.exp() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of exp().
+ */
+ public static Variable exp(Variable x) {
+ return new Variable(HANDLE, VariableJNI.exp(x.getHandle()));
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(double x, Variable y) {
+ return hypot(new Variable(x), y);
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(Variable x, double y) {
+ return hypot(x, new Variable(y));
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(Variable x, Variable y) {
+ return new Variable(HANDLE, VariableJNI.hypot(x.getHandle(), y.getHandle()));
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(double x, double y, Variable z) {
+ return hypot(new Variable(x), new Variable(y), z);
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(double x, Variable y, double z) {
+ return hypot(new Variable(x), y, new Variable(z));
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(double x, Variable y, Variable z) {
+ return hypot(new Variable(x), y, z);
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(Variable x, double y, double z) {
+ return hypot(x, new Variable(y), new Variable(z));
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(Variable x, double y, Variable z) {
+ return hypot(x, new Variable(y), z);
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(Variable x, Variable y, double z) {
+ return hypot(x, y, new Variable(z));
+ }
+
+ /**
+ * Math.hypot() for Variables.
+ *
+ * @param x The x argument.
+ * @param y The y argument.
+ * @param z The z argument.
+ * @return Result of hypot().
+ */
+ public static Variable hypot(Variable x, Variable y, Variable z) {
+ return sqrt(pow(x, 2).plus(pow(y, 2)).plus(pow(z, 2)));
+ }
+
+ /**
+ * Math.log() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of log().
+ */
+ public static Variable log(Variable x) {
+ return new Variable(HANDLE, VariableJNI.log(x.getHandle()));
+ }
+
+ /**
+ * Math.log10() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of log10().
+ */
+ public static Variable log10(Variable x) {
+ return new Variable(HANDLE, VariableJNI.log10(x.getHandle()));
+ }
+
+ /**
+ * Math.max() for Variables.
+ *
+ * Returns the greater of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ * @return Result of max().
+ */
+ public static Variable max(double a, Variable b) {
+ return max(new Variable(a), b);
+ }
+
+ /**
+ * Math.max() for Variables.
+ *
+ * Returns the greater of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ * @return Result of max().
+ */
+ public static Variable max(Variable a, double b) {
+ return max(a, new Variable(b));
+ }
+
+ /**
+ * Math.max() for Variables.
+ *
+ * Returns the greater of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ * @return Result of max().
+ */
+ public static Variable max(Variable a, Variable b) {
+ return new Variable(HANDLE, VariableJNI.max(a.getHandle(), b.getHandle()));
+ }
+
+ /**
+ * min() for Variables.
+ *
+ * Returns the lesser of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ * @return Result of min().
+ */
+ public static Variable min(double a, Variable b) {
+ return min(new Variable(a), b);
+ }
+
+ /**
+ * min() for Variables.
+ *
+ * Returns the lesser of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ * @return Result of min().
+ */
+ public static Variable min(Variable a, double b) {
+ return min(a, new Variable(b));
+ }
+
+ /**
+ * min() for Variables.
+ *
+ * Returns the lesser of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ * @return Result of min().
+ */
+ public static Variable min(Variable a, Variable b) {
+ return new Variable(HANDLE, VariableJNI.min(a.getHandle(), b.getHandle()));
+ }
+
+ /**
+ * Math.pow() for Variables.
+ *
+ * @param base The base.
+ * @param power The power.
+ * @return Result of pow().
+ */
+ public static Variable pow(double base, Variable power) {
+ return pow(new Variable(base), power);
+ }
+
+ /**
+ * Math.pow() for Variables.
+ *
+ * @param base The base.
+ * @param power The power.
+ * @return Result of pow().
+ */
+ public static Variable pow(Variable base, double power) {
+ return pow(base, new Variable(power));
+ }
+
+ /**
+ * Math.pow() for Variables.
+ *
+ * @param base The base.
+ * @param power The power.
+ * @return Result of pow().
+ */
+ public static Variable pow(Variable base, Variable power) {
+ return new Variable(HANDLE, VariableJNI.pow(base.getHandle(), power.getHandle()));
+ }
+
+ /**
+ * Math.signum() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of signum().
+ */
+ public static Variable signum(Variable x) {
+ return new Variable(HANDLE, VariableJNI.signum(x.getHandle()));
+ }
+
+ /**
+ * Math.sin() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of sin().
+ */
+ public static Variable sin(Variable x) {
+ return new Variable(HANDLE, VariableJNI.sin(x.getHandle()));
+ }
+
+ /**
+ * Math.sinh() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of sinh().
+ */
+ public static Variable sinh(Variable x) {
+ return new Variable(HANDLE, VariableJNI.sinh(x.getHandle()));
+ }
+
+ /**
+ * Math.sqrt() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of sqrt().
+ */
+ public static Variable sqrt(Variable x) {
+ return new Variable(HANDLE, VariableJNI.sqrt(x.getHandle()));
+ }
+
+ /**
+ * Math.tan() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of tan().
+ */
+ public static Variable tan(Variable x) {
+ return new Variable(HANDLE, VariableJNI.tan(x.getHandle()));
+ }
+
+ /**
+ * Math.tanh() for Variables.
+ *
+ * @param x The argument.
+ * @return Result of tanh().
+ */
+ public static Variable tanh(Variable x) {
+ return new Variable(HANDLE, VariableJNI.tanh(x.getHandle()));
+ }
+
+ /**
+ * Returns the total native memory usage of all Variables in bytes.
+ *
+ * @return The total native memory usage of all Variables in bytes.
+ */
+ public static long totalNativeMemoryUsage() {
+ return VariableJNI.totalNativeMemoryUsage();
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java
new file mode 100644
index 0000000000..b92c6133ad
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableBlock.java
@@ -0,0 +1,826 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.autodiff;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import java.util.function.UnaryOperator;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+import org.ejml.simple.SimpleMatrix;
+
+/** A submatrix of autodiff variables with reference semantics. */
+public class VariableBlock implements Iterable Note that the slices are taken as is rather than adjusted.
+ *
+ * @param mat The matrix to which to point.
+ * @param rowSlice The block's row slice.
+ * @param rowSliceLength The block's row length.
+ * @param colSlice The block's column slice.
+ * @param colSliceLength The block's column length.
+ */
+ public VariableBlock(
+ VariableMatrix mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) {
+ m_mat = mat;
+ m_rowSlice = rowSlice;
+ m_rowSliceLength = rowSliceLength;
+ m_colSlice = colSlice;
+ m_colSliceLength = colSliceLength;
+ }
+
+ /**
+ * Assigns a double to the block.
+ *
+ * This only works for blocks with one row and one column.
+ *
+ * @param value Value to assign.
+ * @return This VariableBlock.
+ */
+ public VariableBlock set(double value) {
+ assert rows() == 1 && cols() == 1;
+
+ set(0, 0, new Variable(value));
+
+ return this;
+ }
+
+ /**
+ * Assigns a Variable to the block.
+ *
+ * This only works for blocks with one row and one column.
+ *
+ * @param value Value to assign.
+ * @return This VariableBlock.
+ */
+ public VariableBlock set(Variable value) {
+ assert rows() == 1 && cols() == 1;
+
+ set(0, 0, value);
+
+ return this;
+ }
+
+ /**
+ * Assigns a double array to the block.
+ *
+ * @param values Double array of values to assign.
+ * @return This VariableBlock.
+ */
+ public VariableBlock set(double[][] values) {
+ assert rows() == values.length;
+
+ // Assert all column counts are the same
+ for (var row : values) {
+ assert row.length == cols();
+ }
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ set(row, col, values[row][col]);
+ }
+ }
+
+ return this;
+ }
+
+ /**
+ * Assigns an EJML matrix to the block.
+ *
+ * @param values EJML matrix of values to assign.
+ * @return This VariableBlock.
+ */
+ public VariableBlock set(SimpleMatrix values) {
+ assert rows() == values.getNumRows() && cols() == values.getNumCols();
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ set(row, col, values.get(row, col));
+ }
+ }
+
+ return this;
+ }
+
+ /**
+ * Assigns a VariableMatrix to the block.
+ *
+ * @param values VariableMatrix of values.
+ * @return This VariableBlock.
+ */
+ public VariableBlock set(VariableMatrix values) {
+ assert rows() == values.rows() && cols() == values.cols();
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ set(row, col, values.get(row, col));
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Assigns a VariableBlock to the block.
+ *
+ * @param values VariableBlock of values.
+ * @return This VariableBlock.
+ */
+ public VariableBlock set(VariableBlock values) {
+ assert rows() == values.rows() && cols() == values.cols();
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ set(row, col, values.get(row, col));
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Sets a scalar subblock at the given row and column.
+ *
+ * @param row The scalar subblock's row.
+ * @param col The scalar subblock's column.
+ * @param value The value.
+ */
+ public void set(int row, int col, Variable value) {
+ assert row >= 0 && row < rows();
+ assert col >= 0 && col < cols();
+ m_mat.set(
+ m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
+ }
+
+ /**
+ * Sets a scalar subblock at the given row and column.
+ *
+ * @param row The scalar subblock's row.
+ * @param col The scalar subblock's column.
+ * @param value The value.
+ */
+ public void set(int row, int col, double value) {
+ assert row >= 0 && row < rows();
+ assert col >= 0 && col < cols();
+ m_mat.set(
+ m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
+ }
+
+ /**
+ * Sets a scalar subblock at the given index.
+ *
+ * @param index The scalar subblock's index.
+ * @param value The value.
+ */
+ public void set(int index, double value) {
+ set(index, new Variable(value));
+ }
+
+ /**
+ * Sets a scalar subblock at the given index.
+ *
+ * @param index The scalar subblock's index.
+ * @param value The value.
+ */
+ public void set(int index, Variable value) {
+ assert index >= 0 && index < rows() * cols();
+ set(index / cols(), index % cols(), value);
+ }
+
+ /**
+ * Assigns a double to the block.
+ *
+ * This only works for blocks with one row and one column.
+ *
+ * @param value Value to assign.
+ */
+ public void setValue(double value) {
+ assert rows() == 1 && cols() == 1;
+
+ get(0, 0).setValue(value);
+ }
+
+ /**
+ * Sets block's internal values.
+ *
+ * @param values Double array of values.
+ */
+ public void setValue(double[][] values) {
+ assert rows() == values.length;
+
+ // Assert all column counts are the same
+ for (var row : values) {
+ assert row.length == cols();
+ }
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ get(row, col).setValue(values[row][col]);
+ }
+ }
+ }
+
+ /**
+ * Sets block's internal values.
+ *
+ * @param values EJML matrix of values.
+ */
+ public void setValue(SimpleMatrix values) {
+ assert rows() == values.getNumRows() && cols() == values.getNumCols();
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ get(row, col).setValue(values.get(row, col));
+ }
+ }
+ }
+
+ /**
+ * Returns a scalar subblock at the given row and column.
+ *
+ * @param row The scalar subblock's row.
+ * @param col The scalar subblock's column.
+ * @return A scalar subblock at the given row and column.
+ */
+ public Variable get(int row, int col) {
+ assert row >= 0 && row < rows();
+ assert col >= 0 && col < cols();
+ return m_mat.get(
+ m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step);
+ }
+
+ /**
+ * Returns a scalar subblock at the given index.
+ *
+ * @param index The scalar subblock's index.
+ * @return A scalar subblock at the given index.
+ */
+ public Variable get(int index) {
+ assert index >= 0 && index < rows() * cols();
+ return get(index / cols(), index % cols());
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param row The row.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(int row, Slice.None colSlice) {
+ return get(new Slice(row), new Slice(colSlice));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param row The row.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(int row, Slice colSlice) {
+ return get(new Slice(row), colSlice);
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param col The column.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice.None rowSlice, int col) {
+ return get(new Slice(rowSlice), new Slice(col));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param col The column.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice rowSlice, int col) {
+ return get(rowSlice, new Slice(col));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) {
+ return get(new Slice(rowSlice), new Slice(colSlice));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice.None rowSlice, Slice colSlice) {
+ return get(new Slice(rowSlice), colSlice);
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice rowSlice, Slice.None colSlice) {
+ return get(rowSlice, new Slice(colSlice));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice rowSlice, Slice colSlice) {
+ int rowSliceLength = rowSlice.adjust(m_rowSliceLength);
+ int colSliceLength = colSlice.adjust(m_colSliceLength);
+ return new VariableBlock(
+ m_mat,
+ new Slice(
+ m_rowSlice.start + rowSlice.start * m_rowSlice.step,
+ m_rowSlice.start + rowSlice.stop * m_rowSlice.step,
+ rowSlice.step * m_rowSlice.step),
+ rowSliceLength,
+ new Slice(
+ m_colSlice.start + colSlice.start * m_colSlice.step,
+ m_colSlice.start + colSlice.stop * m_colSlice.step,
+ colSlice.step * m_colSlice.step),
+ colSliceLength);
+ }
+
+ /**
+ * Returns a block of the variable matrix.
+ *
+ * @param rowOffset The row offset of the block selection.
+ * @param colOffset The column offset of the block selection.
+ * @param blockRows The number of rows in the block selection.
+ * @param blockCols The number of columns in the block selection.
+ * @return A block of the variable matrix.
+ */
+ public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) {
+ assert rowOffset >= 0 && rowOffset <= rows();
+ assert colOffset >= 0 && colOffset <= cols();
+ assert blockRows >= 0 && blockRows <= rows() - rowOffset;
+ assert blockCols >= 0 && blockCols <= cols() - colOffset;
+ return get(
+ new Slice(rowOffset, rowOffset + blockRows, 1),
+ new Slice(colOffset, colOffset + blockCols, 1));
+ }
+
+ /**
+ * Returns a segment of the variable vector.
+ *
+ * @param offset The offset of the segment.
+ * @param length The length of the segment.
+ * @return A segment of the variable vector.
+ */
+ public VariableBlock segment(int offset, int length) {
+ assert cols() == 1;
+ assert offset >= 0 && offset < rows();
+ assert length >= 0 && length <= rows() - offset;
+ return block(offset, 0, length, 1);
+ }
+
+ /**
+ * Returns a row slice of the variable matrix.
+ *
+ * @param row The row to slice.
+ * @return A row slice of the variable matrix.
+ */
+ public VariableBlock row(int row) {
+ assert row >= 0 && row < rows();
+ return block(row, 0, 1, cols());
+ }
+
+ /**
+ * Returns a column slice of the variable matrix.
+ *
+ * @param col The column to slice.
+ * @return A column slice of the variable matrix.
+ */
+ public VariableBlock col(int col) {
+ assert col >= 0 && col < cols();
+ return block(0, col, rows(), 1);
+ }
+
+ /**
+ * Matrix multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix multiplication.
+ */
+ public VariableMatrix times(VariableMatrix rhs) {
+ assert cols() == rhs.rows();
+
+ var result = new VariableMatrix(rows(), rhs.cols());
+
+ for (int i = 0; i < rows(); ++i) {
+ for (int j = 0; j < rhs.cols(); ++j) {
+ var sum = new Variable(0.0);
+ for (int k = 0; k < cols(); ++k) {
+ sum = sum.plus(get(i, k).times(rhs.get(k, j)));
+ }
+ result.set(i, j, sum);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Matrix multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix multiplication.
+ */
+ public VariableMatrix times(VariableBlock rhs) {
+ assert cols() == rhs.rows();
+
+ var result = new VariableMatrix(rows(), rhs.cols());
+
+ for (int i = 0; i < rows(); ++i) {
+ for (int j = 0; j < rhs.cols(); ++j) {
+ var sum = new Variable(0.0);
+ for (int k = 0; k < cols(); ++k) {
+ sum = sum.plus(get(i, k).times(rhs.get(k, j)));
+ }
+ result.set(i, j, sum);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Matrix-scalar multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix-scalar multiplication.
+ */
+ public VariableMatrix times(double rhs) {
+ return times(new Variable(rhs));
+ }
+
+ /**
+ * Matrix-scalar multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix-scalar multiplication.
+ */
+ public VariableMatrix times(Variable rhs) {
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).times(rhs));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary division operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of division.
+ */
+ public VariableMatrix div(double rhs) {
+ return div(new Variable(rhs));
+ }
+
+ /**
+ * Binary division operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of division.
+ */
+ public VariableMatrix div(Variable rhs) {
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).div(rhs));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public VariableMatrix plus(VariableMatrix rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).plus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public VariableMatrix plus(VariableBlock rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).plus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public VariableMatrix plus(SimpleMatrix rhs) {
+ assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).plus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public VariableMatrix minus(VariableMatrix rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).minus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public VariableMatrix minus(VariableBlock rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).minus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public VariableMatrix minus(SimpleMatrix rhs) {
+ assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).minus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Unary minus operator.
+ *
+ * @return Result of unary minus.
+ */
+ public VariableMatrix unaryMinus() {
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).unaryMinus());
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Returns the transpose of the variable matrix.
+ *
+ * @return The transpose of the variable matrix.
+ */
+ public VariableMatrix T() {
+ var result = new VariableMatrix(cols(), rows());
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ result.set(col, row, get(row, col));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Returns the number of rows in the matrix.
+ *
+ * @return The number of rows in the matrix.
+ */
+ public int rows() {
+ return m_rowSliceLength;
+ }
+
+ /**
+ * Returns the number of columns in the matrix.
+ *
+ * @return The number of columns in the matrix.
+ */
+ public int cols() {
+ return m_colSliceLength;
+ }
+
+ /**
+ * Returns an element of the variable matrix.
+ *
+ * @param row The row of the element to return.
+ * @param col The column of the element to return.
+ * @return An element of the variable matrix.
+ */
+ public double value(int row, int col) {
+ return get(row, col).value();
+ }
+
+ /**
+ * Returns an element of the variable block.
+ *
+ * @param index The index of the element to return.
+ * @return An element of the variable block.
+ */
+ public double value(int index) {
+ return get(index).value();
+ }
+
+ /**
+ * Returns the contents of the variable matrix.
+ *
+ * @return The contents of the variable matrix.
+ */
+ public SimpleMatrix value() {
+ var result = new SimpleMatrix(rows(), cols());
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ result.set(row, col, value(row, col));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Maps the matrix coefficient-wise with an unary operator.
+ *
+ * @param unaryOp The unary operator to use for the map operation.
+ * @return Result of the unary operator.
+ */
+ public VariableMatrix cwiseMap(UnaryOperator Returns the greater of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ */
+ static native long max(long a, long b);
+
+ /**
+ * Math.min() for Variables.
+ *
+ * Returns the lesser of a and b. If the values are equivalent, returns a.
+ *
+ * @param a The a argument.
+ * @param b The b argument.
+ */
+ static native long min(long a, long b);
+
+ /**
+ * Math.pow() for Variables.
+ *
+ * @param base The base.
+ * @param power The power.
+ */
+ static native long pow(long base, long power);
+
+ /**
+ * Math.signum() for Variables.
+ *
+ * @param x The argument.
+ */
+ static native long signum(long x);
+
+ /**
+ * Math.sin() for Variables.
+ *
+ * @param x The argument.
+ */
+ static native long sin(long x);
+
+ /**
+ * Math.sinh() for Variables.
+ *
+ * @param x The argument.
+ */
+ static native long sinh(long x);
+
+ /**
+ * Math.sqrt() for Variables.
+ *
+ * @param x The argument.
+ */
+ static native long sqrt(long x);
+
+ /**
+ * Math.tan() for Variables.
+ *
+ * @param x The argument.
+ */
+ static native long tan(long x);
+
+ /**
+ * Math.tanh() for Variables.
+ *
+ * @param x The argument.
+ */
+ static native long tanh(long x);
+
+ /**
+ * Returns the total native memory usage of Variables in bytes.
+ *
+ * @return The total native memory usage of Variables in bytes.
+ */
+ static native long totalNativeMemoryUsage();
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java
new file mode 100644
index 0000000000..1c7a7ffb4d
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java
@@ -0,0 +1,1047 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.autodiff;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import java.util.function.BinaryOperator;
+import java.util.function.UnaryOperator;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+import org.ejml.simple.SimpleMatrix;
+
+/** A matrix of autodiff variables. */
+public class VariableMatrix implements AutoCloseable, Iterable This constructor is for internal use only.
+ *
+ * @param rows The number of matrix rows.
+ * @param cols The number of matrix columns.
+ * @param handles Variable handles.
+ */
+ public VariableMatrix(int rows, int cols, long[] handles) {
+ assert handles.length == rows * cols;
+
+ m_rows = rows;
+ m_cols = cols;
+ m_storage = new Variable[rows * cols];
+ for (int index = 0; index < m_storage.length; ++index) {
+ m_storage[index] = new Variable(Variable.HANDLE, handles[index]);
+ }
+ }
+
+ /**
+ * Constructs a zero-initialized VariableMatrix column vector with the given rows.
+ *
+ * @param rows The number of matrix rows.
+ */
+ public VariableMatrix(int rows) {
+ this(rows, 1);
+ }
+
+ /**
+ * Constructs a zero-initialized VariableMatrix with the given dimensions.
+ *
+ * @param rows The number of matrix rows.
+ * @param cols The number of matrix columns.
+ */
+ public VariableMatrix(int rows, int cols) {
+ m_rows = rows;
+ m_cols = cols;
+ m_storage = new Variable[rows * cols];
+ for (int index = 0; index < m_storage.length; ++index) {
+ m_storage[index] = new Variable();
+ }
+ }
+
+ /**
+ * Constructs a scalar VariableMatrix from a nested list of doubles.
+ *
+ * @param list The nested list of Variables.
+ */
+ public VariableMatrix(double[][] list) {
+ // Get row and column counts for destination matrix
+ m_rows = list.length;
+ m_cols = 0;
+ if (list.length > 0) {
+ m_cols = list[0].length;
+ }
+
+ // Assert all column counts are the same
+ for (var row : list) {
+ assert row.length == m_cols;
+ }
+
+ m_storage = new Variable[m_rows * m_cols];
+ int index = 0;
+ for (var row : list) {
+ for (var elem : row) {
+ m_storage[index] = new Variable(elem);
+ ++index;
+ }
+ }
+ }
+
+ /**
+ * Constructs a scalar VariableMatrix from a nested list of Variables.
+ *
+ * @param list The nested list of Variables.
+ */
+ public VariableMatrix(Variable[][] list) {
+ // Get row and column counts for destination matrix
+ m_rows = list.length;
+ m_cols = 0;
+ if (list.length > 0) {
+ m_cols = list[0].length;
+ }
+
+ // Assert all column counts are the same
+ for (var row : list) {
+ assert row.length == m_cols;
+ }
+
+ m_storage = new Variable[m_rows * m_cols];
+ int index = 0;
+ for (var row : list) {
+ for (var elem : row) {
+ m_storage[index] = elem;
+ ++index;
+ }
+ }
+ }
+
+ /**
+ * Constructs a VariableMatrix from an EJML matrix.
+ *
+ * @param values EJML matrix of values.
+ */
+ public VariableMatrix(SimpleMatrix values) {
+ m_rows = values.getNumRows();
+ m_cols = values.getNumCols();
+ m_storage = new Variable[m_rows * m_cols];
+ for (int row = 0; row < values.getNumRows(); ++row) {
+ for (int col = 0; col < values.getNumCols(); ++col) {
+ m_storage[row * m_cols + col] = new Variable(values.get(row, col));
+ }
+ }
+ }
+
+ /**
+ * Constructs a scalar VariableMatrix from a Variable.
+ *
+ * @param variable Variable.
+ */
+ public VariableMatrix(Variable variable) {
+ m_rows = 1;
+ m_cols = 1;
+ m_storage = new Variable[] {variable};
+ }
+
+ /**
+ * Constructs a VariableMatrix from a VariableBlock.
+ *
+ * @param values VariableBlock of values.
+ */
+ public VariableMatrix(VariableBlock values) {
+ m_rows = values.rows();
+ m_cols = values.cols();
+ m_storage = new Variable[m_rows * m_cols];
+ for (int row = 0; row < m_rows; ++row) {
+ for (int col = 0; col < m_cols; ++col) {
+ m_storage[row * m_cols + col] = values.get(row, col);
+ }
+ }
+ }
+
+ @Override
+ public void close() {
+ for (int index = 0; index < rows() * cols(); ++index) {
+ m_storage[index].close();
+ }
+ }
+
+ /**
+ * Assigns a double array to a VariableMatrix.
+ *
+ * @param values Double array of values.
+ * @return This VariableMatrix.
+ */
+ public VariableMatrix set(double[][] values) {
+ assert rows() == values.length;
+
+ // Assert all column counts are the same
+ for (var row : values) {
+ assert row.length == cols();
+ }
+
+ for (int row = 0; row < values.length; ++row) {
+ for (int col = 0; col < values[0].length; ++col) {
+ set(row, col, values[row][col]);
+ }
+ }
+
+ return this;
+ }
+
+ /**
+ * Assigns an EJML matrix to a VariableMatrix.
+ *
+ * @param values EJML matrix of values.
+ * @return This VariableMatrix.
+ */
+ public VariableMatrix set(SimpleMatrix values) {
+ assert rows() == values.getNumRows() && cols() == values.getNumCols();
+
+ for (int row = 0; row < values.getNumRows(); ++row) {
+ for (int col = 0; col < values.getNumCols(); ++col) {
+ set(row, col, values.get(row, col));
+ }
+ }
+
+ return this;
+ }
+
+ /**
+ * Assigns a VariableMatrix to a VariableMatrix.
+ *
+ * @param values VariableMatrix of values.
+ * @return This VariableMatrix.
+ */
+ public VariableMatrix set(VariableMatrix values) {
+ assert rows() == values.rows() && cols() == values.cols();
+
+ for (int row = 0; row < values.rows(); ++row) {
+ for (int col = 0; col < values.cols(); ++col) {
+ set(row, col, values.get(row, col));
+ }
+ }
+
+ return this;
+ }
+
+ /**
+ * Assigns a VariableBlock to a VariableMatrix.
+ *
+ * @param values VariableBlock of values.
+ * @return This VariableMatrix.
+ */
+ public VariableMatrix set(VariableBlock values) {
+ assert rows() == values.rows() && cols() == values.cols();
+
+ for (int row = 0; row < values.rows(); ++row) {
+ for (int col = 0; col < values.cols(); ++col) {
+ set(row, col, values.get(row, col));
+ }
+ }
+
+ return this;
+ }
+
+ /**
+ * Assigns a double to the matrix.
+ *
+ * This only works for matrices with one row and one column.
+ *
+ * @param value Value to assign.
+ * @return This VariableMatrix.
+ */
+ public VariableMatrix set(double value) {
+ return set(new Variable(value));
+ }
+
+ /**
+ * Assigns a Variable to the matrix.
+ *
+ * This only works for matrices with one row and one column.
+ *
+ * @param value Value to assign.
+ * @return This VariableMatrix.
+ */
+ public VariableMatrix set(Variable value) {
+ assert rows() == 1 && cols() == 1;
+
+ m_storage[0] = value;
+
+ return this;
+ }
+
+ /**
+ * Sets an element to the given value.
+ *
+ * @param row The row.
+ * @param col The column.
+ * @param value The value.
+ */
+ public void set(int row, int col, Variable value) {
+ assert row >= 0 && row < rows();
+ assert col >= 0 && col < cols();
+ m_storage[row * cols() + col] = value;
+ }
+
+ /**
+ * Sets an element to the given value.
+ *
+ * @param row The row.
+ * @param col The column.
+ * @param value The value.
+ */
+ public void set(int row, int col, double value) {
+ assert row >= 0 && row < rows();
+ assert col >= 0 && col < cols();
+ m_storage[row * cols() + col] = new Variable(value);
+ }
+
+ /**
+ * Sets an element to the given value.
+ *
+ * @param index The index of the element.
+ * @param value The value.
+ */
+ public void set(int index, double value) {
+ set(index, new Variable(value));
+ }
+
+ /**
+ * Sets an element to the given value.
+ *
+ * @param index The index of the element.
+ * @param value The value.
+ */
+ public void set(int index, Variable value) {
+ assert index >= 0 && index < rows() * cols();
+ m_storage[index] = value;
+ }
+
+ /**
+ * Sets the VariableMatrix's internal values.
+ *
+ * @param values Double array of values.
+ */
+ public void setValue(double[][] values) {
+ assert rows() == values.length;
+
+ // Assert all column counts are the same
+ for (var row : values) {
+ assert row.length == cols();
+ }
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ get(row, col).setValue(values[row][col]);
+ }
+ }
+ }
+
+ /**
+ * Sets the VariableMatrix's internal values.
+ *
+ * @param values EJML matrix of values.
+ */
+ public void setValue(SimpleMatrix values) {
+ assert rows() == values.getNumRows() && cols() == values.getNumCols();
+
+ for (int row = 0; row < values.getNumRows(); ++row) {
+ for (int col = 0; col < values.getNumCols(); ++col) {
+ get(row, col).setValue(values.get(row, col));
+ }
+ }
+ }
+
+ /**
+ * Returns the element at the given row and column.
+ *
+ * @param row The row.
+ * @param col The column.
+ * @return The element at the given row and column.
+ */
+ public Variable get(int row, int col) {
+ assert row >= 0 && row < rows();
+ assert col >= 0 && col < cols();
+ return m_storage[row * cols() + col];
+ }
+
+ /**
+ * Returns the element at the given index.
+ *
+ * @param index The index.
+ * @return The element at the given index.
+ */
+ public Variable get(int index) {
+ assert index >= 0 && index < rows() * cols();
+ return m_storage[index];
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param row The row.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(int row, Slice.None colSlice) {
+ return get(new Slice(row), new Slice(colSlice));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param row The row.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(int row, Slice colSlice) {
+ return get(new Slice(row), colSlice);
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param col The column.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice.None rowSlice, int col) {
+ return get(new Slice(rowSlice), new Slice(col));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param col The column.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice rowSlice, int col) {
+ return get(rowSlice, new Slice(col));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) {
+ return get(new Slice(rowSlice), new Slice(colSlice));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice.None rowSlice, Slice colSlice) {
+ return get(new Slice(rowSlice), colSlice);
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice rowSlice, Slice.None colSlice) {
+ return get(rowSlice, new Slice(colSlice));
+ }
+
+ /**
+ * Returns a slice of the variable matrix.
+ *
+ * @param rowSlice The row slice.
+ * @param colSlice The column slice.
+ * @return A slice of the variable matrix.
+ */
+ public VariableBlock get(Slice rowSlice, Slice colSlice) {
+ int rowSliceLength = rowSlice.adjust(rows());
+ int colSliceLength = colSlice.adjust(cols());
+ return new VariableBlock(this, rowSlice, rowSliceLength, colSlice, colSliceLength);
+ }
+
+ /**
+ * Returns a block of the variable matrix.
+ *
+ * @param rowOffset The row offset of the block selection.
+ * @param colOffset The column offset of the block selection.
+ * @param blockRows The number of rows in the block selection.
+ * @param blockCols The number of columns in the block selection.
+ * @return A block of the variable matrix.
+ */
+ public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) {
+ assert rowOffset >= 0 && rowOffset <= rows();
+ assert colOffset >= 0 && colOffset <= cols();
+ assert blockRows >= 0 && blockRows <= rows() - rowOffset;
+ assert blockCols >= 0 && blockCols <= cols() - colOffset;
+ return new VariableBlock(this, rowOffset, colOffset, blockRows, blockCols);
+ }
+
+ /**
+ * Returns a segment of the variable vector.
+ *
+ * @param offset The offset of the segment.
+ * @param length The length of the segment.
+ * @return A segment of the variable vector.
+ */
+ public VariableBlock segment(int offset, int length) {
+ assert cols() == 1;
+ assert offset >= 0 && offset < rows();
+ assert length >= 0 && length <= rows() - offset;
+ return block(offset, 0, length, 1);
+ }
+
+ /**
+ * Returns a row slice of the variable matrix.
+ *
+ * @param row The row to slice.
+ * @return A row slice of the variable matrix.
+ */
+ public VariableBlock row(int row) {
+ assert row >= 0 && row < rows();
+ return block(row, 0, 1, cols());
+ }
+
+ /**
+ * Returns a column slice of the variable matrix.
+ *
+ * @param col The column to slice.
+ * @return A column slice of the variable matrix.
+ */
+ public VariableBlock col(int col) {
+ assert col >= 0 && col < cols();
+ return block(0, col, rows(), 1);
+ }
+
+ /**
+ * Matrix multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix multiplication.
+ */
+ public VariableMatrix times(VariableMatrix rhs) {
+ assert cols() == rhs.rows();
+
+ var result = new VariableMatrix(rows(), rhs.cols());
+
+ for (int i = 0; i < rows(); ++i) {
+ for (int j = 0; j < rhs.cols(); ++j) {
+ var sum = new Variable(0.0);
+ for (int k = 0; k < cols(); ++k) {
+ sum = sum.plus(get(i, k).times(rhs.get(k, j)));
+ }
+ result.set(i, j, sum);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Matrix multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix multiplication.
+ */
+ public VariableMatrix times(VariableBlock rhs) {
+ assert cols() == rhs.rows();
+
+ var result = new VariableMatrix(rows(), rhs.cols());
+
+ for (int i = 0; i < rows(); ++i) {
+ for (int j = 0; j < rhs.cols(); ++j) {
+ var sum = new Variable(0.0);
+ for (int k = 0; k < cols(); ++k) {
+ sum = sum.plus(get(i, k).times(rhs.get(k, j)));
+ }
+ result.set(i, j, sum);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Matrix multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix multiplication.
+ */
+ public VariableMatrix times(SimpleMatrix rhs) {
+ return times(new VariableMatrix(rhs));
+ }
+
+ /**
+ * Matrix-scalar multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix-scalar multiplication.
+ */
+ public VariableMatrix times(double rhs) {
+ return times(new Variable(rhs));
+ }
+
+ /**
+ * Matrix-scalar multiplication operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of matrix-scalar multiplication.
+ */
+ public VariableMatrix times(Variable rhs) {
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).times(rhs));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary division operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of division.
+ */
+ public VariableMatrix div(double rhs) {
+ return div(new Variable(rhs));
+ }
+
+ /**
+ * Binary division operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of division.
+ */
+ public VariableMatrix div(Variable rhs) {
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).div(rhs));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public VariableMatrix plus(VariableMatrix rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).plus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public VariableMatrix plus(VariableBlock rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).plus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary addition operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of addition.
+ */
+ public VariableMatrix plus(SimpleMatrix rhs) {
+ assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).plus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public VariableMatrix minus(VariableMatrix rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).minus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public VariableMatrix minus(VariableBlock rhs) {
+ assert rows() == rhs.rows() && cols() == rhs.cols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).minus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Binary subtraction operator.
+ *
+ * @param rhs Operator right-hand side.
+ * @return Result of subtraction.
+ */
+ public VariableMatrix minus(SimpleMatrix rhs) {
+ assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
+
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).minus(rhs.get(row, col)));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Unary minus operator.
+ *
+ * @return Result of unary minus.
+ */
+ public VariableMatrix unaryMinus() {
+ var result = new VariableMatrix(rows(), cols());
+
+ for (int row = 0; row < result.rows(); ++row) {
+ for (int col = 0; col < result.cols(); ++col) {
+ result.set(row, col, get(row, col).unaryMinus());
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Returns the transpose of the variable matrix.
+ *
+ * @return The transpose of the variable matrix.
+ */
+ public VariableMatrix T() {
+ var result = new VariableMatrix(cols(), rows());
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ result.set(col, row, get(row, col));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Returns the number of rows in the matrix.
+ *
+ * @return The number of rows in the matrix.
+ */
+ public int rows() {
+ return m_rows;
+ }
+
+ /**
+ * Returns the number of columns in the matrix.
+ *
+ * @return The number of columns in the matrix.
+ */
+ public int cols() {
+ return m_cols;
+ }
+
+ /**
+ * Returns an element of the variable matrix.
+ *
+ * @param row The row of the element to return.
+ * @param col The column of the element to return.
+ * @return An element of the variable matrix.
+ */
+ public double value(int row, int col) {
+ return get(row, col).value();
+ }
+
+ /**
+ * Returns an element of the variable matrix.
+ *
+ * @param index The index of the element to return.
+ * @return An element of the variable matrix.
+ */
+ public double value(int index) {
+ return get(index).value();
+ }
+
+ /**
+ * Returns the contents of the variable matrix.
+ *
+ * @return The contents of the variable matrix.
+ */
+ public SimpleMatrix value() {
+ var result = new SimpleMatrix(rows(), cols());
+
+ for (int row = 0; row < rows(); ++row) {
+ for (int col = 0; col < cols(); ++col) {
+ result.set(row, col, value(row, col));
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Maps the matrix coefficient-wise with an unary operator.
+ *
+ * @param unaryOp The unary operator to use for the map operation.
+ * @return Result of the unary operator.
+ */
+ public VariableMatrix cwiseMap(UnaryOperator Each row's blocks must have the same height, and the assembled block rows must have the same
+ * width. For example, for the block matrix [[A, B], [C]] to be constructible, the number of rows
+ * in A and B must match, and the number of columns in [A, B] and [C] must match.
+ *
+ * @param list The nested list of blocks.
+ * @return Block matrix.
+ */
+ @SuppressWarnings("OverloadMethodsDeclarationOrder")
+ public static VariableMatrix block(VariableMatrix[][] list) {
+ // Get row and column counts for destination matrix
+ int rows = 0;
+ int cols = -1;
+ for (var row : list) {
+ if (row.length > 0) {
+ rows += row[0].rows();
+ }
+
+ // Get number of columns in this row
+ int latestCols = 0;
+ for (var elem : row) {
+ // Assert the first and latest row have the same height
+ assert row[0].rows() == elem.rows();
+
+ latestCols += elem.cols();
+ }
+
+ // If this is the first row, record the column count. Otherwise, assert the
+ // first and latest column counts are the same.
+ if (cols == -1) {
+ cols = latestCols;
+ } else {
+ assert cols == latestCols;
+ }
+ }
+
+ var result = new VariableMatrix(rows, cols);
+
+ int rowOffset = 0;
+ for (var row : list) {
+ int colOffset = 0;
+ for (var elem : row) {
+ result.block(rowOffset, colOffset, elem.rows(), elem.cols()).set(elem);
+ colOffset += elem.cols();
+ }
+ if (row.length > 0) {
+ rowOffset += row[0].rows();
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Solves the VariableMatrix equation AX = B for X.
+ *
+ * @param A The left-hand side.
+ * @param B The right-hand side.
+ * @return The solution X.
+ */
+ public static VariableMatrix solve(VariableMatrix A, VariableMatrix B) {
+ // m x n * n x p = m x p
+ assert A.rows() == B.rows();
+
+ return new VariableMatrix(
+ A.cols(),
+ B.cols(),
+ VariableMatrixJNI.solve(A.getHandles(), A.cols(), B.getHandles(), B.cols()));
+ }
+
+ /**
+ * Returns an array of VariableMatrix internal handles in row-major order.
+ *
+ * @return Array of VariableMatrix internal handles in row-major order.
+ */
+ long[] getHandles() {
+ var handles = new long[size()];
+ for (int index = 0; index < size(); ++index) {
+ handles[index] = m_storage[index].getHandle();
+ }
+ return handles;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java
new file mode 100644
index 0000000000..f7267c06fe
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrixJNI.java
@@ -0,0 +1,25 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.autodiff;
+
+import org.wpilib.math.jni.WPIMathJNI;
+
+/** Variable JNI functions. */
+final class VariableMatrixJNI extends WPIMathJNI {
+ private VariableMatrixJNI() {
+ // Utility class.
+ }
+
+ /**
+ * Solves the VariableMatrix equation AX = B for X.
+ *
+ * @param A The left-hand side as a flattened row-major matrix.
+ * @param Acols The number of columns in A.
+ * @param B The right-hand side as a flattened row-major matrix.
+ * @param Bcols The number of columns in B.
+ * @return The solution X as a flattened row-major matrix.
+ */
+ static native long[] solve(long[] A, int Acols, long[] B, int Bcols);
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java b/wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java
new file mode 100644
index 0000000000..419cd2cfc0
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/autodiff/VariablePool.java
@@ -0,0 +1,55 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.autodiff;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import org.wpilib.util.ErrorMessages;
+import org.wpilib.util.cleanup.CleanupPool;
+
+/**
+ * Cleans up implicitly allocated Variables via try-with-resources.
+ *
+ * This implements a stack of Variable pools containing a default global pool. The user can
+ * create additional pools via try-with-resources. Variable and VariableMatrix instances will
+ * register themselves with the latest pool.
+ *
+ * It's strongly recommended to only instantiate this class via try-with-resources so the close()
+ * methods are always called in the correct order (i.e., nested scopes).
+ */
+public class VariablePool implements AutoCloseable {
+ private static Deque Constructs an ArmFeedforward object and runs its currentVelocity and nextVelocity overload
- *
- * @param ks The ArmFeedforward's static gain in volts.
- * @param kv The ArmFeedforward's velocity gain in volt seconds per radian.
- * @param ka The ArmFeedforward's acceleration gain in volt seconds² per radian.
- * @param kg The ArmFeedforward's gravity gain in volts.
- * @param currentAngle The current angle in the calculation in radians.
- * @param currentVelocity The current velocity in the calculation in radians per second.
- * @param nextVelocity The next velocity in the calculation in radians per second.
- * @param dt The time between velocity setpoints in seconds.
- * @return The calculated feedforward in volts.
- */
- public static native double calculate(
- double ks,
- double kv,
- double ka,
- double kg,
- double currentAngle,
- double currentVelocity,
- double nextVelocity,
- double dt);
-
- /** Utility class. */
- private ArmFeedforwardJNI() {}
-}
diff --git a/wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java b/wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java
deleted file mode 100644
index d766ddbf63..0000000000
--- a/wpimath/src/main/java/org/wpilib/math/jni/Ellipse2dJNI.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright (c) FIRST and other WPILib contributors.
-// Open Source Software; you can modify and/or share it under the terms of
-// the WPILib BSD license file in the root directory of this project.
-
-package org.wpilib.math.jni;
-
-/** Ellipse2d JNI. */
-public final class Ellipse2dJNI extends WPIMathJNI {
- /**
- * Returns the nearest point that is contained within the ellipse.
- *
- * Constructs an Ellipse2d object and runs its nearest() method.
- *
- * @param centerX The x coordinate of the center of the ellipse in meters.
- * @param centerY The y coordinate of the center of the ellipse in meters.
- * @param centerHeading The ellipse's rotation in radians.
- * @param xSemiAxis The x semi-axis in meters.
- * @param ySemiAxis The y semi-axis in meters.
- * @param pointX The x coordinate of the point that this will find the nearest point to.
- * @param pointY The y coordinate of the point that this will find the nearest point to.
- * @param nearestPoint Array to store nearest point into.
- */
- public static native void nearest(
- double centerX,
- double centerY,
- double centerHeading,
- double xSemiAxis,
- double ySemiAxis,
- double pointX,
- double pointY,
- double[] nearestPoint);
-
- /** Utility class. */
- private Ellipse2dJNI() {}
-}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java b/wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java
new file mode 100644
index 0000000000..188ac44bb8
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java
@@ -0,0 +1,1472 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization;
+
+import java.util.ArrayList;
+import org.ejml.simple.SimpleMatrix;
+import org.wpilib.math.autodiff.Variable;
+import org.wpilib.math.autodiff.VariableBlock;
+import org.wpilib.math.autodiff.VariableMatrix;
+
+/** Constraint creation helper functions. */
+public final class Constraints {
+ /** Utility class. */
+ private Constraints() {}
+
+ // ==
+
+ /**
+ * Equality operator that returns an equality constraint for a double and a Variable.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(double lhs, Variable rhs) {
+ return eq(new Variable(lhs), rhs);
+ }
+
+ /**
+ * Equality operator that returns an equality constraint for a Variable and a double.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(Variable lhs, double rhs) {
+ return eq(lhs, new Variable(rhs));
+ }
+
+ /**
+ * Equality operator that returns an equality constraint for two Variables.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(Variable lhs, Variable rhs) {
+ return new EqualityConstraints(new Variable[] {lhs.minus(rhs)});
+ }
+
+ /**
+ * Equality operator that returns an equality constraint for a double and a VariableBlock.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(double lhs, VariableBlock rhs) {
+ return eq(new Variable(lhs), new VariableMatrix(rhs));
+ }
+
+ /**
+ * Equality operator that returns an equality constraint for a Variable and a VariableBlock.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(Variable lhs, VariableBlock rhs) {
+ return eq(lhs, new VariableMatrix(rhs));
+ }
+
+ /**
+ * Equality operator that returns an equality constraint for a double and a VariableMatrix.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(double lhs, VariableMatrix rhs) {
+ return eq(new Variable(lhs), rhs);
+ }
+
+ /**
+ * Equality operator that returns an equality constraint for a Variable and a VariableMatrix.
+ *
+ * @param lhs Left-hand side.
+ * @param rhs Right-hand side.
+ * @return Equality constraints.
+ */
+ public static EqualityConstraints eq(Variable lhs, VariableMatrix rhs) {
+ var constraints = new ArrayList The system is transcripted by one of three methods (direct transcription, direct collocation,
+ * or single-shooting) and additional constraints can be added.
+ *
+ * In direct transcription, each state is a decision variable constrained to the integrated
+ * dynamics of the previous state. In direct collocation, the trajectory is modeled as a series of
+ * cubic polynomials where the centerpoint slope is constrained. In single-shooting, states depend
+ * explicitly as a function of all previous states and all previous inputs.
+ *
+ * Explicit ODEs are integrated using RK4.
+ *
+ * For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). For discrete state
+ * transition functions, the function must be in the form xₖ₊₁ = f(t, xₖ, uₖ).
+ *
+ * Direct collocation requires an explicit ODE. Direct transcription and single-shooting can use
+ * either an ODE or state transition function.
+ *
+ * https://underactuated.mit.edu/trajopt.html goes into more detail on each transcription method.
+ */
+public class OCP extends Problem {
+ private int m_numSteps;
+
+ private DynamicsFunction m_dynamics;
+ private DynamicsType m_dynamicsType;
+
+ private VariableMatrix m_X;
+ private VariableMatrix m_U;
+ private VariableMatrix m_DT;
+
+ /**
+ * Builds an optimization problem using a system evolution function (explicit ODE or discrete
+ * state transition function).
+ *
+ * @param numStates The number of system states.
+ * @param numInputs The number of system inputs.
+ * @param dt The timestep for fixed-step integration.
+ * @param numSteps The number of control points.
+ * @param dynamics Function representing an explicit or implicit ODE, or a discrete state
+ * transition function.
+ * Shaped (numStates)x(numSteps+1).
+ *
+ * @return The state variable matrix.
+ */
+ public VariableMatrix X() {
+ return m_X;
+ }
+
+ /**
+ * Gets the input variables. After the problem is solved, this will contain the inputs
+ * corresponding to the optimized trajectory.
+ *
+ * Shaped (numInputs)x(numSteps+1), although the last input step is unused in the trajectory.
+ *
+ * @return The input variable matrix.
+ */
+ public VariableMatrix U() {
+ return m_U;
+ }
+
+ /**
+ * Gets the timestep variables. After the problem is solved, this will contain the timesteps
+ * corresponding to the optimized trajectory.
+ *
+ * Shaped 1x(numSteps+1), although the last timestep is unused in the trajectory.
+ *
+ * @return The timestep variable matrix.
+ */
+ public VariableMatrix dt() {
+ return m_DT;
+ }
+
+ /**
+ * Gets the initial state in the trajectory.
+ *
+ * @return The initial state of the trajectory.
+ */
+ public VariableMatrix initialState() {
+ return new VariableMatrix(m_X.col(0));
+ }
+
+ /**
+ * Gets the final state in the trajectory.
+ *
+ * @return The final state of the trajectory.
+ */
+ public VariableMatrix finalState() {
+ return new VariableMatrix(m_X.col(m_numSteps));
+ }
+
+ /**
+ * Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt.
+ *
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param t0 The initial time.
+ * @param dt The time over which to integrate.
+ */
+ private static VariableMatrix rk4(
+ DynamicsFunction f, VariableMatrix x, VariableMatrix u, Variable t0, Variable dt) {
+ var halfdt = dt.times(0.5);
+ VariableMatrix k1 = f.apply(t0, x, u, dt);
+ VariableMatrix k2 = f.apply(t0.plus(halfdt), x.plus(k1.times(halfdt)), u, dt);
+ VariableMatrix k3 = f.apply(t0.plus(halfdt), x.plus(k2.times(halfdt)), u, dt);
+ VariableMatrix k4 = f.apply(t0.plus(dt), x.plus(k3.times(dt)), u, dt);
+
+ return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(dt.div(6.0)));
+ }
+
+ /** Applies direct collocation dynamics constraints. */
+ private void constrainDirectCollocation() {
+ assert m_dynamicsType == DynamicsType.EXPLICIT_ODE;
+
+ var time = new Variable(0.0);
+
+ // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/
+ for (int i = 0; i < m_numSteps; ++i) {
+ Variable h = dt().get(0, i);
+
+ var f = m_dynamics;
+
+ var t_begin = time;
+ var t_end = t_begin.plus(h);
+
+ var x_begin = X().col(i);
+ var x_end = X().col(i + 1);
+
+ var u_begin = U().col(i);
+ var u_end = U().col(i + 1);
+
+ var xdot_begin =
+ f.apply(t_begin, new VariableMatrix(x_begin), new VariableMatrix(u_begin), h);
+ var xdot_end = f.apply(t_end, new VariableMatrix(x_end), new VariableMatrix(u_end), h);
+ var xdot_c =
+ x_begin
+ .minus(x_end)
+ .times(new Variable(-3).div(h.times(2)))
+ .minus(xdot_begin.plus(xdot_end).times(0.25));
+
+ var t_c = t_begin.plus(h.times(0.5));
+ var x_c = x_begin.plus(x_end).times(0.5).plus(xdot_begin.minus(xdot_end).times(h.div(8)));
+ var u_c = u_begin.plus(u_end).times(0.5);
+
+ subjectTo(eq(xdot_c, f.apply(t_c, x_c, u_c, h)));
+
+ time = time.plus(h);
+ }
+ }
+
+ /** Applies direct transcription dynamics constraints. */
+ private void constrainDirectTranscription() {
+ var time = new Variable(0.0);
+
+ for (int i = 0; i < m_numSteps; ++i) {
+ var x_begin = X().col(i);
+ var x_end = X().col(i + 1);
+ var u = U().col(i);
+ Variable dt = this.dt().get(0, i);
+
+ if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
+ subjectTo(
+ eq(
+ x_end,
+ rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt)));
+ } else if (m_dynamicsType == DynamicsType.DISCRETE) {
+ subjectTo(
+ eq(
+ x_end,
+ m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt)));
+ }
+
+ time = time.plus(dt);
+ }
+ }
+
+ /** Applies single shooting dynamics constraints. */
+ private void constrainSingleShooting() {
+ var time = new Variable(0.0);
+
+ for (int i = 0; i < m_numSteps; ++i) {
+ var x_begin = X().col(i);
+ var x_end = X().col(i + 1);
+ var u = U().col(i);
+ Variable dt = this.dt().get(0, i);
+
+ if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
+ x_end.set(rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt));
+ } else if (m_dynamicsType == DynamicsType.DISCRETE) {
+ x_end.set(m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt));
+ }
+
+ time = time.plus(dt);
+ }
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/Problem.java b/wpimath/src/main/java/org/wpilib/math/optimization/Problem.java
new file mode 100644
index 0000000000..7d5502e52a
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/Problem.java
@@ -0,0 +1,312 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization;
+
+import java.util.ArrayList;
+import java.util.function.Predicate;
+import org.ejml.data.DMatrixRMaj;
+import org.ejml.simple.SimpleMatrix;
+import org.wpilib.math.autodiff.ExpressionType;
+import org.wpilib.math.autodiff.NativeSparseTriplets;
+import org.wpilib.math.autodiff.Variable;
+import org.wpilib.math.autodiff.VariableMatrix;
+import org.wpilib.math.autodiff.VariablePool;
+import org.wpilib.math.optimization.solver.ExitStatus;
+import org.wpilib.math.optimization.solver.IterationInfo;
+import org.wpilib.math.optimization.solver.Options;
+
+/**
+ * This class allows the user to pose a constrained nonlinear optimization problem in natural
+ * mathematical notation and solve it.
+ *
+ * This class supports problems of the form:
+ *
+ * where f(x) is the scalar cost function, x is the vector of decision variables (variables the
+ * solver can tweak to minimize the cost function), cᵢ(x) are the inequality constraints, and cₑ(x)
+ * are the equality constraints. Constraints are equations or inequalities of the decision variables
+ * that constrain what values the solver is allowed to use when searching for an optimal solution.
+ *
+ * The nice thing about this class is users don't have to put their system in the form shown
+ * above manually; they can write it in natural mathematical form and it'll be converted for them.
+ */
+public class Problem implements AutoCloseable {
+ private long m_handle;
+
+ // The iteration callbacks
+ private final ArrayList Decision variables have an initial value of zero.
+ *
+ * @return A decision variable in the optimization problem.
+ */
+ public Variable decisionVariable() {
+ var handles = ProblemJNI.decisionVariable(m_handle, 1, 1);
+ return new Variable(Variable.HANDLE, handles[0]);
+ }
+
+ /**
+ * Creates a column vector of decision variables in the optimization problem.
+ *
+ * Decision variables have an initial value of zero.
+ *
+ * @param rows Number of column vector rows.
+ * @return A column vector of decision variables in the optimization problem.
+ */
+ public VariableMatrix decisionVariable(int rows) {
+ return decisionVariable(rows, 1);
+ }
+
+ /**
+ * Creates a matrix of decision variables in the optimization problem.
+ *
+ * Decision variables have an initial value of zero.
+ *
+ * @param rows Number of matrix rows.
+ * @param cols Number of matrix columns.
+ * @return A matrix of decision variables in the optimization problem.
+ */
+ public VariableMatrix decisionVariable(int rows, int cols) {
+ return new VariableMatrix(rows, cols, ProblemJNI.decisionVariable(m_handle, rows, cols));
+ }
+
+ /**
+ * Creates a symmetric matrix of decision variables in the optimization problem.
+ *
+ * Variable instances are reused across the diagonal, which helps reduce problem
+ * dimensionality.
+ *
+ * Decision variables have an initial value of zero.
+ *
+ * @param rows Number of matrix rows.
+ * @return A symmetric matrix of decision varaibles in the optimization problem.
+ */
+ public VariableMatrix symmetricDecisionVariable(int rows) {
+ return new VariableMatrix(rows, rows, ProblemJNI.symmetricDecisionVariable(m_handle, rows));
+ }
+
+ /**
+ * Tells the solver to minimize the output of the given cost function.
+ *
+ * Note that this is optional. If only constraints are specified, the solver will find the
+ * closest solution to the initial conditions that's in the feasible set.
+ *
+ * @param cost The cost function to minimize.
+ */
+ public void minimize(Variable cost) {
+ ProblemJNI.minimize(m_handle, cost.getHandle());
+ }
+
+ /**
+ * Tells the solver to minimize the output of the given cost function.
+ *
+ * Note that this is optional. If only constraints are specified, the solver will find the
+ * closest solution to the initial conditions that's in the feasible set.
+ *
+ * @param cost The cost function to minimize. An assertion is raised if the VariableMatrix isn't
+ * 1x1.
+ */
+ public void minimize(VariableMatrix cost) {
+ assert cost.rows() == 1 && cost.cols() == 1;
+ minimize(cost.get(0, 0));
+ }
+
+ /**
+ * Tells the solver to maximize the output of the given objective function.
+ *
+ * Note that this is optional. If only constraints are specified, the solver will find the
+ * closest solution to the initial conditions that's in the feasible set.
+ *
+ * @param objective The objective function to maximize.
+ */
+ public void maximize(Variable objective) {
+ ProblemJNI.maximize(m_handle, objective.getHandle());
+ }
+
+ /**
+ * Tells the solver to maximize the output of the given objective function.
+ *
+ * Note that this is optional. If only constraints are specified, the solver will find the
+ * closest solution to the initial conditions that's in the feasible set.
+ *
+ * @param objective The objective function to maximize. An assertion is raised if the
+ * VariableMatrix isn't 1x1.
+ */
+ public void maximize(VariableMatrix objective) {
+ assert objective.rows() == 1 && objective.cols() == 1;
+ maximize(objective.get(0, 0));
+ }
+
+ /**
+ * Tells the solver to solve the problem while satisfying the given equality constraint.
+ *
+ * @param constraint The constraint to satisfy.
+ */
+ public void subjectTo(EqualityConstraints constraint) {
+ var constraintHandles = new long[constraint.constraints.length];
+ for (int i = 0; i < constraintHandles.length; ++i) {
+ constraintHandles[i] = constraint.constraints[i].getHandle();
+ }
+ ProblemJNI.subjectToEq(m_handle, constraintHandles);
+ }
+
+ /**
+ * Tells the solver to solve the problem while satisfying the given inequality constraint.
+ *
+ * @param constraint The constraint to satisfy.
+ */
+ public void subjectTo(InequalityConstraints constraint) {
+ var constraintHandles = new long[constraint.constraints.length];
+ for (int i = 0; i < constraintHandles.length; ++i) {
+ constraintHandles[i] = constraint.constraints[i].getHandle();
+ }
+ ProblemJNI.subjectToIneq(m_handle, constraintHandles);
+ }
+
+ /**
+ * Returns the cost function's type.
+ *
+ * @return The cost function's type.
+ */
+ public ExpressionType costFunctionType() {
+ return ExpressionType.fromInt(ProblemJNI.costFunctionType(m_handle));
+ }
+
+ /**
+ * Returns the type of the highest order equality constraint.
+ *
+ * @return The type of the highest order equality constraint.
+ */
+ public ExpressionType equalityConstraintType() {
+ return ExpressionType.fromInt(ProblemJNI.equalityConstraintType(m_handle));
+ }
+
+ /**
+ * Returns the type of the highest order inequality constraint.
+ *
+ * @return The type of the highest order inequality constraint.
+ */
+ public ExpressionType inequalityConstraintType() {
+ return ExpressionType.fromInt(ProblemJNI.inequalityConstraintType(m_handle));
+ }
+
+ /**
+ * Solves the optimization problem. The solution will be stored in the original variables used to
+ * construct the problem.
+ *
+ * @return The solver status.
+ */
+ public ExitStatus solve() {
+ return solve(new Options());
+ }
+
+ /**
+ * Solves the optimization problem. The solution will be stored in the original variables used to
+ * construct the problem.
+ *
+ * @param options Solver options.
+ * @return The solver status.
+ */
+ public ExitStatus solve(Options options) {
+ return ExitStatus.fromInt(
+ ProblemJNI.solve(
+ this,
+ m_handle,
+ options.tolerance,
+ options.maxIterations,
+ options.timeout,
+ options.feasibleIPM,
+ options.diagnostics));
+ }
+
+ /**
+ * Adds a callback to be called at the beginning of each solver iteration.
+ *
+ * The callback for this overload should return bool.
+ *
+ * @param callback The callback. Returning true from the callback causes the solver to exit early
+ * with the solution it has so far.
+ */
+ public void addCallback(Predicate This function is called by native code in ProblemJNI.
+ *
+ * @param numEqualityConstraints The number of equality constraints.
+ * @param numInequalityConstraints The number of inequality constraints.
+ * @param iteration The solver iteration.
+ * @param x The decision variable values.
+ * @param gTriplets Gradient triplets.
+ * @param HTriplets Hessian triplets.
+ * @param A_eTriplets Equality constraint Jacobian triplets.
+ * @param A_iTriplets Inequality constraint Jacobian triplets.
+ * @return True if the solver shold exit early.
+ */
+ boolean runCallbacks(
+ int numEqualityConstraints,
+ int numInequalityConstraints,
+ int iteration,
+ double[] x,
+ NativeSparseTriplets gTriplets,
+ NativeSparseTriplets HTriplets,
+ NativeSparseTriplets A_eTriplets,
+ NativeSparseTriplets A_iTriplets) {
+ if (m_iterationCallbacks.isEmpty()) {
+ return false;
+ }
+
+ var info =
+ new IterationInfo(
+ iteration,
+ new SimpleMatrix(DMatrixRMaj.wrap(x.length, 1, x)),
+ gTriplets.toSimpleMatrix(x.length, 1),
+ HTriplets.toSimpleMatrix(x.length, x.length),
+ A_eTriplets.toSimpleMatrix(numEqualityConstraints, x.length),
+ A_iTriplets.toSimpleMatrix(numInequalityConstraints, x.length));
+
+ for (var callback : m_iterationCallbacks) {
+ if (callback.test(info)) {
+ return true;
+ }
+ }
+ return false;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java b/wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java
new file mode 100644
index 0000000000..c3f22ca8c7
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/ProblemJNI.java
@@ -0,0 +1,134 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization;
+
+import org.wpilib.math.jni.WPIMathJNI;
+
+/** Problem JNI functions. */
+final class ProblemJNI extends WPIMathJNI {
+ private ProblemJNI() {
+ // Utility class.
+ }
+
+ /** Construct the optimization problem. */
+ static native long create();
+
+ /**
+ * Destruct the optimization problem.
+ *
+ * @param handle Problem handle.
+ */
+ static native void destroy(long handle);
+
+ /**
+ * Create a matrix of decision variables in the optimization problem.
+ *
+ * @param handle Problem handle.
+ * @param rows Number of matrix rows.
+ * @param cols Number of matrix columns.
+ * @return A matrix of decision variables in the optimization problem.
+ */
+ static native long[] decisionVariable(long handle, int rows, int cols);
+
+ /**
+ * Create a symmetric matrix of decision variables in the optimization problem.
+ *
+ * Variable instances are reused across the diagonal, which helps reduce problem
+ * dimensionality.
+ *
+ * @param handle Problem handle.
+ * @param rows Number of matrix rows.
+ * @return A symmetric matrix of decision varaibles in the optimization problem.
+ */
+ static native long[] symmetricDecisionVariable(long handle, int rows);
+
+ /**
+ * Tells the solver to minimize the output of the given cost function.
+ *
+ * Note that this is optional. If only constraints are specified, the solver will find the
+ * closest solution to the initial conditions that's in the feasible set.
+ *
+ * @param handle Problem handle.
+ * @param costHandle Variable handle of the cost function to minimize.
+ */
+ static native void minimize(long handle, long costHandle);
+
+ /**
+ * Tells the solver to maximize the output of the given objective function.
+ *
+ * Note that this is optional. If only constraints are specified, the solver will find the
+ * closest solution to the initial conditions that's in the feasible set.
+ *
+ * @param handle Problem handle.
+ * @param objectiveHandle Variable handle of the objective function to maximize.
+ */
+ static native void maximize(long handle, long objectiveHandle);
+
+ /**
+ * Tells the solver to solve the problem while satisfying the given equality constraint.
+ *
+ * @param handle Problem handle.
+ * @param constraintHandles Constraint handles.
+ */
+ static native void subjectToEq(long handle, long[] constraintHandles);
+
+ /**
+ * Tells the solver to solve the problem while satisfying the given inequality constraint.
+ *
+ * @param handle Problem handle.
+ * @param constraintHandles Constraint handles.
+ */
+ static native void subjectToIneq(long handle, long[] constraintHandles);
+
+ /**
+ * Returns the cost function's type.
+ *
+ * @param handle Problem handle.
+ * @return The cost function's type.
+ */
+ static native int costFunctionType(long handle);
+
+ /**
+ * Returns the type of the highest order equality constraint.
+ *
+ * @param handle Problem handle.
+ * @return The type of the highest order equality constraint.
+ */
+ static native int equalityConstraintType(long handle);
+
+ /**
+ * Returns the type of the highest order inequality constraint.
+ *
+ * @param handle Problem handle.
+ * @return The type of the highest order inequality constraint.
+ */
+ static native int inequalityConstraintType(long handle);
+
+ /**
+ * Solve the optimization problem. The solution will be stored in the original variables used to
+ * construct the problem.
+ *
+ * @param obj Java Problem object.
+ * @param handle Problem handle.
+ * @param tolerance The solver will stop once the error is below this tolerance.
+ * @param maxIterations The maximum number of solver iterations before returning a solution.
+ * @param timeout The maximum elapsed wall clock time in seconds before returning a solution.
+ * @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints
+ * are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible
+ * again. This is useful when parts of the problem are ill-conditioned in infeasible regions
+ * (e.g., square root of a negative value). This can slow or prevent progress toward a
+ * solution though, so only enable it if necessary.
+ * @param diagnnostics Enables diagnostic prints.
+ * @return The solver status.
+ */
+ static native int solve(
+ Problem obj,
+ long handle,
+ double tolerance,
+ int maxIterations,
+ double timeout,
+ boolean feasibleIPM,
+ boolean diagnostics);
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java
new file mode 100644
index 0000000000..8c44700f49
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/ConstraintEvaluationFunction.java
@@ -0,0 +1,25 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.ocp;
+
+import org.wpilib.math.autodiff.Variable;
+import org.wpilib.math.autodiff.VariableMatrix;
+
+/**
+ * A callback f(t, x, u, dt) where t is time, x is the state vector, u is the input vector, and dt
+ * is the timestep duration.
+ */
+@FunctionalInterface
+public interface ConstraintEvaluationFunction {
+ /**
+ * Applies this function with the arguments.
+ *
+ * @param t Time in seconds.
+ * @param x State vector.
+ * @param u Input vector.
+ * @param dt Timestep duration in seconds.
+ */
+ void accept(Variable t, VariableMatrix x, VariableMatrix u, Variable dt);
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java
new file mode 100644
index 0000000000..d60fae255a
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsFunction.java
@@ -0,0 +1,31 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.ocp;
+
+import org.wpilib.math.autodiff.Variable;
+import org.wpilib.math.autodiff.VariableMatrix;
+
+/**
+ * Function representing an explicit or implicit ODE, or a discrete state transition function.
+ *
+ * When the inequality constraints are all feasible, step sizes are reduced when necessary to
+ * prevent them becoming infeasible again. This is useful when parts of the problem are
+ * ill-conditioned in infeasible regions (e.g., square root of a negative value). This can slow or
+ * prevent progress toward a solution though, so only enable it if necessary.
+ */
+ public boolean feasibleIPM = false;
+
+ /**
+ * Enables diagnostic output.
+ *
+ * See https://sleipnirgroup.github.io/Sleipnir/md_usage.html#output
+ * for more information.
+ */
+ public boolean diagnostics = false;
+
+ /** Default options. */
+ public Options() {}
+
+ /**
+ * Set tolerance.
+ *
+ * @param tolerance The solver will stop once the error is below this tolerance.
+ * @return This Options object.
+ */
+ public Options withTolerance(double tolerance) {
+ this.tolerance = tolerance;
+ return this;
+ }
+
+ /**
+ * Set max iterations.
+ *
+ * @param maxIterations The maximum number of solver iterations before returning a solution.
+ * @return This Options object.
+ */
+ public Options withMaxIterations(int maxIterations) {
+ this.maxIterations = maxIterations;
+ return this;
+ }
+
+ /**
+ * Set timeout.
+ *
+ * @param timeout The maximum elapsed wall clock time in seconds before returning a solution.
+ * @return This Options object.
+ */
+ public Options withTimeout(double timeout) {
+ this.timeout = timeout;
+ return this;
+ }
+
+ /**
+ * Enable or disable feasible IPM.
+ *
+ * @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints
+ * are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible
+ * again. This is useful when parts of the problem are ill-conditioned in infeasible regions
+ * (e.g., square root of a negative value). This can slow or prevent progress toward a
+ * solution though, so only enable it if necessary.
+ * @return This Options object.
+ */
+ public Options withFeasibleIPM(boolean feasibleIPM) {
+ this.feasibleIPM = feasibleIPM;
+ return this;
+ }
+
+ /**
+ * Enable or disable diagnostics.
+ *
+ * @param diagnostics Enables diagnostic prints.
+ * @return This Options object.
+ */
+ public Options withDiagnostics(boolean diagnostics) {
+ this.diagnostics = diagnostics;
+ return this;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java b/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java
index ac760488b7..88ad7c6d41 100644
--- a/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java
+++ b/wpimath/src/main/java/org/wpilib/math/system/NumericalIntegration.java
@@ -8,6 +8,7 @@ import java.util.function.BiFunction;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.UnaryOperator;
+import org.ejml.simple.SimpleMatrix;
import org.wpilib.math.linalg.Matrix;
import org.wpilib.math.numbers.N1;
import org.wpilib.math.util.Num;
@@ -56,6 +57,30 @@ public final class NumericalIntegration {
return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
}
+ /**
+ * Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
+ *
+ * @param f The function to integrate. It must take two arguments x and u.
+ * @param x The initial value of x.
+ * @param u The value u held constant over the integration period.
+ * @param dt The time over which to integrate.
+ * @return the integration of dx/dt = f(x, u) for dt.
+ */
+ public static SimpleMatrix rk4(
+ BiFunction
+ *
+ *
+ * @param dynamicsType The type of system evolution function.
+ * @param timestepMethod The timestep method.
+ * @param transcriptionMethod The transcription method.
+ */
+ public OCP(
+ int numStates,
+ int numInputs,
+ double dt,
+ int numSteps,
+ BiFunction
+ *
+ *
+ * @param dynamicsType The type of system evolution function.
+ * @param timestepMethod The timestep method.
+ * @param transcriptionMethod The transcription method.
+ */
+ @SuppressWarnings("this-escape")
+ public OCP(
+ int numStates,
+ int numInputs,
+ double dt,
+ int numSteps,
+ DynamicsFunction dynamics,
+ DynamicsType dynamicsType,
+ TimestepMethod timestepMethod,
+ TranscriptionMethod transcriptionMethod) {
+ m_numSteps = numSteps;
+ m_dynamics = dynamics;
+ m_dynamicsType = dynamicsType;
+
+ // u is numSteps + 1 so that the final constraint function evaluation works
+ m_U = decisionVariable(numInputs, m_numSteps + 1);
+
+ if (timestepMethod == TimestepMethod.FIXED) {
+ m_DT = new VariableMatrix(1, m_numSteps + 1);
+ for (int i = 0; i < numSteps + 1; ++i) {
+ m_DT.set(0, i, dt);
+ }
+ } else if (timestepMethod == TimestepMethod.VARIABLE_SINGLE) {
+ Variable single_dt = decisionVariable();
+ single_dt.setValue(dt);
+
+ // Set the member variable matrix to track the decision variable
+ m_DT = new VariableMatrix(1, m_numSteps + 1);
+ for (int i = 0; i < numSteps + 1; ++i) {
+ m_DT.set(0, i, single_dt);
+ }
+ } else if (timestepMethod == TimestepMethod.VARIABLE) {
+ m_DT = decisionVariable(1, m_numSteps + 1);
+ for (int i = 0; i < numSteps + 1; ++i) {
+ m_DT.get(0, i).setValue(dt);
+ }
+ }
+
+ if (transcriptionMethod == TranscriptionMethod.DIRECT_TRANSCRIPTION) {
+ m_X = decisionVariable(numStates, m_numSteps + 1);
+ constrainDirectTranscription();
+ } else if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) {
+ m_X = decisionVariable(numStates, m_numSteps + 1);
+ constrainDirectCollocation();
+ } else if (transcriptionMethod == TranscriptionMethod.SINGLE_SHOOTING) {
+ // In single-shooting the states aren't decision variables, but instead
+ // depend on the input and previous states
+ m_X = new VariableMatrix(numStates, m_numSteps + 1);
+ constrainSingleShooting();
+ }
+ }
+
+ /**
+ * Constrains the initial state.
+ *
+ * @param initialState the initial state to constrain to.
+ */
+ public void constrainInitialState(double initialState) {
+ subjectTo(eq(this.initialState(), initialState));
+ }
+
+ /**
+ * Constrains the initial state.
+ *
+ * @param initialState the initial state to constrain to.
+ */
+ public void constrainInitialState(Variable initialState) {
+ subjectTo(eq(this.initialState(), initialState));
+ }
+
+ /**
+ * Constrains the initial state.
+ *
+ * @param initialState the initial state to constrain to.
+ */
+ public void constrainInitialState(SimpleMatrix initialState) {
+ subjectTo(eq(this.initialState(), initialState));
+ }
+
+ /**
+ * Constrains the initial state.
+ *
+ * @param initialState the initial state to constrain to.
+ */
+ public void constrainInitialState(VariableMatrix initialState) {
+ subjectTo(eq(this.initialState(), initialState));
+ }
+
+ /**
+ * Constrains the initial state.
+ *
+ * @param initialState the initial state to constrain to.
+ */
+ public void constrainInitialState(VariableBlock initialState) {
+ subjectTo(eq(this.initialState(), initialState));
+ }
+
+ /**
+ * Constrains the final state.
+ *
+ * @param finalState the final state to constrain to.
+ */
+ public void constrainFinalState(double finalState) {
+ subjectTo(eq(this.finalState(), finalState));
+ }
+
+ /**
+ * Constrains the final state.
+ *
+ * @param finalState the final state to constrain to.
+ */
+ public void constrainFinalState(Variable finalState) {
+ subjectTo(eq(this.finalState(), finalState));
+ }
+
+ /**
+ * Constrains the final state.
+ *
+ * @param finalState the final state to constrain to.
+ */
+ public void constrainFinalState(SimpleMatrix finalState) {
+ subjectTo(eq(this.finalState(), finalState));
+ }
+
+ /**
+ * Constrains the final state.
+ *
+ * @param finalState the final state to constrain to.
+ */
+ public void constrainFinalState(VariableMatrix finalState) {
+ subjectTo(eq(this.finalState(), finalState));
+ }
+
+ /**
+ * Constrains the final state.
+ *
+ * @param finalState the final state to constrain to.
+ */
+ public void constrainFinalState(VariableBlock finalState) {
+ subjectTo(eq(this.finalState(), finalState));
+ }
+
+ /**
+ * Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
+ * corresponding state and input VariableMatrices.
+ *
+ * @param callback The callback f(x, u) where x is the state and u is the input vector.
+ */
+ public void forEachStep(BiConsumer
+ * minₓ f(x)
+ * subject to cₑ(x) = 0
+ * cᵢ(x) ≥ 0
+ *
+ *
+ *
+ *
+ */
+@FunctionalInterface
+public interface DynamicsFunction {
+ /**
+ * Applies this function with the arguments and returns the result.
+ *
+ * @param t Time in seconds.
+ * @param x State vector.
+ * @param u Input vector.
+ * @param dt Timestep duration in seconds.
+ * @return The state derivative dx/dt.
+ */
+ VariableMatrix apply(Variable t, VariableMatrix x, VariableMatrix u, Variable dt);
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java
new file mode 100644
index 0000000000..e5ccea4763
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/DynamicsType.java
@@ -0,0 +1,20 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.ocp;
+
+/** Enum describing a type of system dynamics constraints. */
+public enum DynamicsType {
+ /** The dynamics are a function in the form dx/dt = f(t, x, u). */
+ EXPLICIT_ODE(0),
+ /** The dynamics are a function in the form xₖ₊₁ = f(t, xₖ, uₖ). */
+ DISCRETE(1);
+
+ /** DynamicsType value. */
+ public final int value;
+
+ DynamicsType(int value) {
+ this.value = value;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java
new file mode 100644
index 0000000000..949082e5fd
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TimestepMethod.java
@@ -0,0 +1,22 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.ocp;
+
+/** Enum describing the type of system timestep. */
+public enum TimestepMethod {
+ /** The timestep is a fixed constant. */
+ FIXED(0),
+ /** The timesteps are allowed to vary as independent decision variables. */
+ VARIABLE(1),
+ /** The timesteps are equal length but allowed to vary as a single decision variable. */
+ VARIABLE_SINGLE(2);
+
+ /** TimestepMethod value. */
+ public final int value;
+
+ TimestepMethod(int value) {
+ this.value = value;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java
new file mode 100644
index 0000000000..93bbffb93e
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/ocp/TranscriptionMethod.java
@@ -0,0 +1,27 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.ocp;
+
+/** Enum describing an OCP transcription method. */
+public enum TranscriptionMethod {
+ /**
+ * Each state is a decision variable constrained to the integrated dynamics of the previous state.
+ */
+ DIRECT_TRANSCRIPTION(0),
+ /**
+ * The trajectory is modeled as a series of cubic polynomials where the centerpoint slope is
+ * constrained.
+ */
+ DIRECT_COLLOCATION(1),
+ /** States depend explicitly as a function of all previous states and all previous inputs. */
+ SINGLE_SHOOTING(2);
+
+ /** TranscriptionMethod value. */
+ public final int value;
+
+ TranscriptionMethod(int value) {
+ this.value = value;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java b/wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java
new file mode 100644
index 0000000000..705bcfd319
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/solver/ExitStatus.java
@@ -0,0 +1,66 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.solver;
+
+/** Solver exit status. Negative values indicate failure. */
+public enum ExitStatus {
+ /** Solved the problem to the desired tolerance. */
+ SUCCESS(0),
+ /** The solver returned its solution so far after the user requested a stop. */
+ CALLBACK_REQUESTED_STOP(1),
+ /** The solver determined the problem to be overconstrained and gave up. */
+ TOO_FEW_DOFS(-1),
+ /** The solver determined the problem to be locally infeasible and gave up. */
+ LOCALLY_INFEASIBLE(-2),
+ /** The problem setup frontend determined the problem to have an empty feasible region. */
+ GLOBALLY_INFEASIBLE(-3),
+ /** The linear system factorization failed. */
+ FACTORIZATION_FAILED(-4),
+ /**
+ * The solver failed to reach the desired tolerance, and feasibility restoration failed to
+ * converge.
+ */
+ FEASIBILITY_RESTORATION_FAILED(-5),
+ /** The solver encountered nonfinite initial cost, constraints, or derivatives and gave up. */
+ NONFINITE_INITIAL_GUESS(-6),
+ /** The solver encountered diverging primal iterates xₖ and/or sₖ and gave up. */
+ DIVERGING_ITERATES(-7),
+ /** The solver returned its solution so far after exceeding the maximum number of iterations. */
+ MAX_ITERATIONS_EXCEEDED(-8),
+ /**
+ * The solver returned its solution so far after exceeding the maximum elapsed wall clock time.
+ */
+ TIMEOUT(-9);
+
+ /** ExitStatus value. */
+ public final int value;
+
+ ExitStatus(int value) {
+ this.value = value;
+ }
+
+ /**
+ * Converts integer to its corresponding enum value.
+ *
+ * @param x The integer.
+ * @return The enum value.
+ */
+ public static ExitStatus fromInt(int x) {
+ return switch (x) {
+ case 0 -> ExitStatus.SUCCESS;
+ case 1 -> ExitStatus.CALLBACK_REQUESTED_STOP;
+ case -1 -> ExitStatus.TOO_FEW_DOFS;
+ case -2 -> ExitStatus.LOCALLY_INFEASIBLE;
+ case -3 -> ExitStatus.GLOBALLY_INFEASIBLE;
+ case -4 -> ExitStatus.FACTORIZATION_FAILED;
+ case -5 -> ExitStatus.FEASIBILITY_RESTORATION_FAILED;
+ case -6 -> ExitStatus.NONFINITE_INITIAL_GUESS;
+ case -7 -> ExitStatus.DIVERGING_ITERATES;
+ case -8 -> ExitStatus.MAX_ITERATIONS_EXCEEDED;
+ case -9 -> ExitStatus.TIMEOUT;
+ default -> null;
+ };
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java b/wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java
new file mode 100644
index 0000000000..08f8866174
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/solver/IterationInfo.java
@@ -0,0 +1,53 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.solver;
+
+import org.ejml.simple.SimpleMatrix;
+
+/** Solver iteration information exposed to an iteration callback. */
+public class IterationInfo {
+ /** The solver iteration. */
+ public final int iteration;
+
+ /** The decision variables (dense internal storage). */
+ public final SimpleMatrix x;
+
+ /** The gradient of the cost function (sparse internal storage). */
+ public final SimpleMatrix g;
+
+ /** The Hessian of the Lagrangian (sparse internal storage). */
+ public final SimpleMatrix H;
+
+ /** The equality constraint Jacobian (sparse internal storage). */
+ public final SimpleMatrix A_e;
+
+ /** The inequality constraint Jacobian (sparse internal storage). */
+ public final SimpleMatrix A_i;
+
+ /**
+ * Constructs iteration info.
+ *
+ * @param iteration The solver iteration.
+ * @param x The decision variables (dense internal storage).
+ * @param g The gradient of the cost function (sparse internal storage).
+ * @param H The Hessian of the Lagrangian (sparse internal storage).
+ * @param A_e The equality constraint Jacobian (sparse internal storage).
+ * @param A_i The inequality constraint Jacobian (sparse internal storage).
+ */
+ public IterationInfo(
+ int iteration,
+ SimpleMatrix x,
+ SimpleMatrix g,
+ SimpleMatrix H,
+ SimpleMatrix A_e,
+ SimpleMatrix A_i) {
+ this.iteration = iteration;
+ this.x = x;
+ this.g = g;
+ this.H = H;
+ this.A_e = A_e;
+ this.A_i = A_i;
+ }
+}
diff --git a/wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java b/wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java
new file mode 100644
index 0000000000..a83f099871
--- /dev/null
+++ b/wpimath/src/main/java/org/wpilib/math/optimization/solver/Options.java
@@ -0,0 +1,98 @@
+// Copyright (c) FIRST and other WPILib contributors.
+// Open Source Software; you can modify and/or share it under the terms of
+// the WPILib BSD license file in the root directory of this project.
+
+package org.wpilib.math.optimization.solver;
+
+/** Solver options. */
+public class Options {
+ /** The solver will stop once the error is below this tolerance. */
+ public double tolerance = 1e-8;
+
+ /** The maximum number of solver iterations before returning a solution. */
+ public int maxIterations = 5000;
+
+ /** The maximum elapsed wall clock time in seconds before returning a solution. */
+ public double timeout = Double.POSITIVE_INFINITY;
+
+ /**
+ * Enables the feasible interior-point method.
+ *
+ *