mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-06-19 00:41:43 +00:00
[wpimath] Add Sleipnir Java bindings (#8236)
The wrapper includes reverse mode autodiff, the Problem DSL, and the optimal control problem API. I wrote it by directly translating the upstream [API](https://github.com/SleipnirGroup/Sleipnir/tree/main/include/sleipnir) and [tests](https://github.com/SleipnirGroup/Sleipnir/tree/main/test) to Java (i.e., copy-paste-modify). I replaced the ArmFeedforward and Ellipse2d JNIs with implementations using the Sleipnir Java bindings. Switching dev binary JNIs to release by default sped up wpimath test runs from several minutes to 7 seconds.
This commit is contained in:
128
benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java
Normal file
128
benchmark/src/main/java/wpilib/robot/CartPoleBenchmark.java
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package wpilib.robot;
|
||||
|
||||
import static org.wpilib.math.autodiff.NumericalIntegration.rk4;
|
||||
import static org.wpilib.math.autodiff.Variable.cos;
|
||||
import static org.wpilib.math.autodiff.Variable.sin;
|
||||
import static org.wpilib.math.autodiff.VariableMatrix.solve;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
import static org.wpilib.math.optimization.Constraints.le;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.optimization.Problem;
|
||||
import org.wpilib.math.optimization.solver.Options;
|
||||
import org.wpilib.math.util.MathUtil;
|
||||
|
||||
public final class CartPoleBenchmark {
|
||||
private CartPoleBenchmark() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
@SuppressWarnings("LocalVariableName")
|
||||
private static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) {
|
||||
final double m_c = 5.0; // Cart mass (kg)
|
||||
final double m_p = 0.5; // Pole mass (kg)
|
||||
final double l = 0.5; // Pole length (m)
|
||||
final double g = 9.806; // Acceleration due to gravity (m/s²)
|
||||
|
||||
var q = x.segment(0, 2);
|
||||
var qdot = x.segment(2, 2);
|
||||
var theta = q.get(1);
|
||||
var thetadot = qdot.get(1);
|
||||
|
||||
// [ m_c + m_p m_p l cosθ]
|
||||
// M(q) = [m_p l cosθ m_p l² ]
|
||||
var M =
|
||||
new VariableMatrix(
|
||||
new Variable[][] {
|
||||
{new Variable(m_c + m_p), cos(theta).times(m_p * l)},
|
||||
{cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))}
|
||||
});
|
||||
|
||||
// [0 −m_p lθ̇ sinθ]
|
||||
// C(q, q̇) = [0 0 ]
|
||||
var C =
|
||||
new VariableMatrix(
|
||||
new Variable[][] {
|
||||
{new Variable(0), thetadot.times(-m_p * l).times(sin(theta))},
|
||||
{new Variable(0), new Variable(0)}
|
||||
});
|
||||
|
||||
// [ 0 ]
|
||||
// τ_g(q) = [-m_p gl sinθ]
|
||||
var tau_g =
|
||||
new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}});
|
||||
|
||||
// [1]
|
||||
// B = [0]
|
||||
var B = new VariableMatrix(new double[][] {{1}, {0}});
|
||||
|
||||
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
|
||||
var qddot = new VariableMatrix(4);
|
||||
qddot.segment(0, 2).set(qdot);
|
||||
qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u))));
|
||||
return qddot;
|
||||
}
|
||||
|
||||
/** Cart-pole benchmark. */
|
||||
public static void cartPole() {
|
||||
final double T = 5.0; // s
|
||||
final double dt = 0.05; // s
|
||||
final int N = (int) (T / dt);
|
||||
|
||||
final double u_max = 20.0; // N
|
||||
final double d_max = 2.0; // m
|
||||
|
||||
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
|
||||
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
|
||||
|
||||
var problem = new Problem();
|
||||
|
||||
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
|
||||
var X = problem.decisionVariable(4, N + 1);
|
||||
|
||||
// Initial guess
|
||||
for (int k = 0; k < N + 1; ++k) {
|
||||
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
|
||||
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
|
||||
}
|
||||
|
||||
// u = f_x
|
||||
var U = problem.decisionVariable(1, N);
|
||||
|
||||
// Initial conditions
|
||||
problem.subjectTo(eq(X.col(0), x_initial));
|
||||
|
||||
// Final conditions
|
||||
problem.subjectTo(eq(X.col(N), x_final));
|
||||
|
||||
// Cart position constraints
|
||||
problem.subjectTo(ge(X.row(0), 0.0));
|
||||
problem.subjectTo(le(X.row(0), d_max));
|
||||
|
||||
// Input constraints
|
||||
problem.subjectTo(ge(U, -u_max));
|
||||
problem.subjectTo(le(U, u_max));
|
||||
|
||||
// Dynamics constraints - RK4 integration
|
||||
for (int k = 0; k < N; ++k) {
|
||||
problem.subjectTo(
|
||||
eq(X.col(k + 1), rk4(CartPoleBenchmark::cartPoleDynamics, X.col(k), U.col(k), dt)));
|
||||
}
|
||||
|
||||
// Minimize sum squared inputs
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N; ++k) {
|
||||
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
problem.solve(new Options().withDiagnostics(true));
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,13 @@ public class Main {
|
||||
new Runner(opt).run();
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@BenchmarkMode(Mode.AverageTime)
|
||||
@OutputTimeUnit(TimeUnit.MICROSECONDS)
|
||||
public void cartPole() {
|
||||
CartPoleBenchmark.cartPole();
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@BenchmarkMode(Mode.AverageTime)
|
||||
@OutputTimeUnit(TimeUnit.MICROSECONDS)
|
||||
|
||||
@@ -236,6 +236,9 @@ wpilib_cc_shared_library(
|
||||
wpilib_jni_java_library(
|
||||
name = "wpimath-java",
|
||||
srcs = [":generated_java"] + glob(["src/main/java/**/*.java"]),
|
||||
javacopts = [
|
||||
"-Xep:UnicodeInCode:OFF",
|
||||
],
|
||||
maven_artifact_name = "wpimath-java",
|
||||
maven_group_id = "org.wpilib.wpimath",
|
||||
native_libs = [":wpimathjni"],
|
||||
@@ -288,6 +291,7 @@ wpilib_java_junit5_test(
|
||||
"//wpiunits:wpiunits-java",
|
||||
"//wpiutil:wpiutil-java",
|
||||
"@maven//:org_ejml_ejml_core",
|
||||
"@maven//:org_ejml_ejml_ddense",
|
||||
"@maven//:org_ejml_ejml_simple",
|
||||
"@maven//:us_hebi_quickbuf_quickbuf_runtime",
|
||||
],
|
||||
|
||||
@@ -7,14 +7,18 @@ include(DownloadAndCheck)
|
||||
|
||||
file(
|
||||
GLOB wpimath_jni_src
|
||||
src/main/native/cpp/jni/ArmFeedforwardJNI.cpp
|
||||
src/main/native/cpp/jni/DAREJNI.cpp
|
||||
src/main/native/cpp/jni/EigenJNI.cpp
|
||||
src/main/native/cpp/jni/Ellipse2dJNI.cpp
|
||||
src/main/native/cpp/jni/Exceptions.cpp
|
||||
src/main/native/cpp/jni/LinearSystemUtilJNI.cpp
|
||||
src/main/native/cpp/jni/Transform3dJNI.cpp
|
||||
src/main/native/cpp/jni/Twist3dJNI.cpp
|
||||
src/main/native/cpp/jni/autodiff/GradientJNI.cpp
|
||||
src/main/native/cpp/jni/autodiff/HessianJNI.cpp
|
||||
src/main/native/cpp/jni/autodiff/JacobianJNI.cpp
|
||||
src/main/native/cpp/jni/autodiff/VariableJNI.cpp
|
||||
src/main/native/cpp/jni/autodiff/VariableMatrixJNI.cpp
|
||||
src/main/native/cpp/jni/optimization/ProblemJNI.cpp
|
||||
)
|
||||
|
||||
# Java bindings
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
/**
|
||||
* Expression type.
|
||||
*
|
||||
* <p>Used for autodiff caching.
|
||||
*/
|
||||
public enum ExpressionType {
|
||||
/** There is no expression. */
|
||||
NONE(0),
|
||||
/** The expression is a constant. */
|
||||
CONSTANT(1),
|
||||
/** The expression is composed of linear and lower-order operators. */
|
||||
LINEAR(2),
|
||||
/** The expression is composed of quadratic and lower-order operators. */
|
||||
QUADRATIC(3),
|
||||
/** The expression is composed of nonlinear and lower-order operators. */
|
||||
NONLINEAR(4);
|
||||
|
||||
/** ExpressionType value. */
|
||||
public final int value;
|
||||
|
||||
ExpressionType(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts integer to its corresponding enum value.
|
||||
*
|
||||
* @param x The integer.
|
||||
* @return The enum value.
|
||||
*/
|
||||
public static ExpressionType fromInt(int x) {
|
||||
return switch (x) {
|
||||
case 0 -> ExpressionType.NONE;
|
||||
case 1 -> ExpressionType.CONSTANT;
|
||||
case 2 -> ExpressionType.LINEAR;
|
||||
case 3 -> ExpressionType.QUADRATIC;
|
||||
case 4 -> ExpressionType.NONLINEAR;
|
||||
default -> null;
|
||||
};
|
||||
}
|
||||
}
|
||||
78
wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java
Normal file
78
wpimath/src/main/java/org/wpilib/math/autodiff/Gradient.java
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/**
|
||||
* This class calculates the gradient of a variable with respect to a vector of variables.
|
||||
*
|
||||
* <p>The gradient is only recomputed if the variable expression is quadratic or higher order.
|
||||
*/
|
||||
public class Gradient implements AutoCloseable {
|
||||
private long m_handle;
|
||||
private int m_rows;
|
||||
|
||||
/**
|
||||
* Constructs a Gradient object.
|
||||
*
|
||||
* @param variable Variable of which to compute the gradient.
|
||||
* @param wrt Variable with respect to which to compute the gradient.
|
||||
*/
|
||||
public Gradient(Variable variable, Variable wrt) {
|
||||
this(variable, new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Gradient object.
|
||||
*
|
||||
* @param variable Variable of which to compute the gradient.
|
||||
* @param wrt Vector of variables with respect to which to compute the gradient.
|
||||
*/
|
||||
public Gradient(Variable variable, VariableMatrix wrt) {
|
||||
assert wrt.cols() == 1;
|
||||
|
||||
m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles());
|
||||
m_rows = wrt.rows();
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Gradient object.
|
||||
*
|
||||
* @param variable Variable of which to compute the gradient.
|
||||
* @param wrt Vector of variables with respect to which to compute the gradient.
|
||||
*/
|
||||
public Gradient(Variable variable, VariableBlock wrt) {
|
||||
this(variable, new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (m_handle != 0) {
|
||||
GradientJNI.destroy(m_handle);
|
||||
m_handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the gradient as a VariableMatrix.
|
||||
*
|
||||
* <p>This is useful when constructing optimization problems with derivatives in them.
|
||||
*
|
||||
* @return The gradient as a VariableMatrix.
|
||||
*/
|
||||
public VariableMatrix get() {
|
||||
return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates the gradient at wrt's value.
|
||||
*
|
||||
* @return The gradient at wrt's value.
|
||||
*/
|
||||
public SimpleMatrix value() {
|
||||
return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.wpilib.math.jni.WPIMathJNI;
|
||||
|
||||
/** Gradient JNI functions. */
|
||||
final class GradientJNI extends WPIMathJNI {
|
||||
private GradientJNI() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Gradient object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Gradient.
|
||||
* @param wrt Vector of variables with respect to which to compute the Gradient.
|
||||
*/
|
||||
static native long create(long variable, long[] wrt);
|
||||
|
||||
/**
|
||||
* Destructs a Gradient.
|
||||
*
|
||||
* @param handle Gradient handle.
|
||||
*/
|
||||
static native void destroy(long handle);
|
||||
|
||||
/**
|
||||
* Returns the Gradient as an array of Variable handles.
|
||||
*
|
||||
* <p>This is useful when constructing optimization problems with derivatives in them.
|
||||
*
|
||||
* @param handle Gradient handle.
|
||||
* @return The Gradient as an array of Variable handles.
|
||||
*/
|
||||
static native long[] get(long handle);
|
||||
|
||||
/**
|
||||
* Evaluates the Gradient at wrt's value.
|
||||
*
|
||||
* @param handle Gradient handle.
|
||||
* @return A record containing the triplet row, column and value arrays (int[], int[], and
|
||||
* double[] respectively).
|
||||
*/
|
||||
static native NativeSparseTriplets value(long handle);
|
||||
}
|
||||
79
wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java
Normal file
79
wpimath/src/main/java/org/wpilib/math/autodiff/Hessian.java
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/**
|
||||
* This class calculates the Hessian of a variable with respect to a vector of variables.
|
||||
*
|
||||
* <p>The gradient tree is cached so subsequent Hessian calculations are faster, and the Hessian is
|
||||
* only recomputed if the variable expression is nonlinear.
|
||||
*/
|
||||
public class Hessian implements AutoCloseable {
|
||||
private long m_handle;
|
||||
private int m_rows;
|
||||
|
||||
/**
|
||||
* Constructs a Hessian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Hessian.
|
||||
* @param wrt Variable with respect to which to compute the Hessian.
|
||||
*/
|
||||
public Hessian(Variable variable, Variable wrt) {
|
||||
this(variable, new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Hessian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Hessian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Hessian.
|
||||
*/
|
||||
public Hessian(Variable variable, VariableMatrix wrt) {
|
||||
assert wrt.cols() == 1;
|
||||
|
||||
m_handle = HessianJNI.create(variable.getHandle(), wrt.getHandles());
|
||||
m_rows = wrt.rows();
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Hessian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Hessian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Hessian.
|
||||
*/
|
||||
public Hessian(Variable variable, VariableBlock wrt) {
|
||||
this(variable, new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (m_handle != 0) {
|
||||
HessianJNI.destroy(m_handle);
|
||||
m_handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the Hessian as a VariableMatrix.
|
||||
*
|
||||
* <p>This is useful when constructing optimization problems with derivatives in them.
|
||||
*
|
||||
* @return The Hessian as a VariableMatrix.
|
||||
*/
|
||||
public VariableMatrix get() {
|
||||
return new VariableMatrix(m_rows, m_rows, HessianJNI.get(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates the Hessian at wrt's value.
|
||||
*
|
||||
* @return The Hessian at wrt's value.
|
||||
*/
|
||||
public SimpleMatrix value() {
|
||||
return HessianJNI.value(m_handle).toSimpleMatrix(m_rows, m_rows);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.wpilib.math.jni.WPIMathJNI;
|
||||
|
||||
/** Hessian JNI functions. */
|
||||
final class HessianJNI extends WPIMathJNI {
|
||||
private HessianJNI() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Hessian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Hessian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Hessian.
|
||||
*/
|
||||
static native long create(long variable, long[] wrt);
|
||||
|
||||
/**
|
||||
* Destructs a Hessian.
|
||||
*
|
||||
* @param handle Hessian handle.
|
||||
*/
|
||||
static native void destroy(long handle);
|
||||
|
||||
/**
|
||||
* Returns the Hessian as an array of Variable handles.
|
||||
*
|
||||
* <p>This is useful when constructing optimization problems with derivatives in them.
|
||||
*
|
||||
* @param handle Hessian handle.
|
||||
* @return The Hessian as an array of Variable handles.
|
||||
*/
|
||||
static native long[] get(long handle);
|
||||
|
||||
/**
|
||||
* Evaluates the Hessian at wrt's value.
|
||||
*
|
||||
* @param handle Hessian handle.
|
||||
* @return A record containing the triplet row, column and value arrays (int[], int[], and
|
||||
* double[] respectively).
|
||||
*/
|
||||
static native NativeSparseTriplets value(long handle);
|
||||
}
|
||||
102
wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java
Normal file
102
wpimath/src/main/java/org/wpilib/math/autodiff/Jacobian.java
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/**
|
||||
* This class calculates the Jacobian of a vector of variables with respect to a vector of
|
||||
* variables.
|
||||
*
|
||||
* <p>The Jacobian is only recomputed if the variable expression is quadratic or higher order.
|
||||
*/
|
||||
public class Jacobian implements AutoCloseable {
|
||||
private long m_handle;
|
||||
private int m_rows;
|
||||
private int m_cols;
|
||||
|
||||
/**
|
||||
* Constructs a Jacobian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Jacobian.
|
||||
* @param wrt Variable with respect to which to compute the Jacobian.
|
||||
*/
|
||||
public Jacobian(Variable variable, Variable wrt) {
|
||||
this(new VariableMatrix(variable), new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Jacobian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Jacobian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Jacobian.
|
||||
*/
|
||||
public Jacobian(Variable variable, VariableMatrix wrt) {
|
||||
this(new VariableMatrix(variable), wrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Jacobian object.
|
||||
*
|
||||
* @param variable Variable of which to compute the Jacobian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Jacobian.
|
||||
*/
|
||||
public Jacobian(Variable variable, VariableBlock wrt) {
|
||||
this(new VariableMatrix(variable), new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Jacobian object.
|
||||
*
|
||||
* @param variables Vector of variables of which to compute the Jacobian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Jacobian.
|
||||
*/
|
||||
public Jacobian(VariableMatrix variables, VariableMatrix wrt) {
|
||||
assert variables.cols() == 1;
|
||||
assert wrt.cols() == 1;
|
||||
|
||||
m_handle = JacobianJNI.create(variables.getHandles(), wrt.getHandles());
|
||||
m_rows = variables.rows();
|
||||
m_cols = wrt.rows();
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Jacobian object.
|
||||
*
|
||||
* @param variables Vector of variables of which to compute the Jacobian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Jacobian.
|
||||
*/
|
||||
public Jacobian(VariableMatrix variables, VariableBlock wrt) {
|
||||
this(variables, new VariableMatrix(wrt));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (m_handle != 0) {
|
||||
JacobianJNI.destroy(m_handle);
|
||||
m_handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the Jacobian as a VariableMatrix.
|
||||
*
|
||||
* <p>This is useful when constructing optimization problems with derivatives in them.
|
||||
*
|
||||
* @return The Jacobian as a VariableMatrix.
|
||||
*/
|
||||
public VariableMatrix get() {
|
||||
return new VariableMatrix(m_rows, m_cols, JacobianJNI.get(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates the Jacobian at wrt's value.
|
||||
*
|
||||
* @return The Jacobian at wrt's value.
|
||||
*/
|
||||
public SimpleMatrix value() {
|
||||
return JacobianJNI.value(m_handle).toSimpleMatrix(m_rows, m_cols);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.wpilib.math.jni.WPIMathJNI;
|
||||
|
||||
/** Jacobian JNI functions. */
|
||||
final class JacobianJNI extends WPIMathJNI {
|
||||
private JacobianJNI() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Jacobian object.
|
||||
*
|
||||
* @param variables Vector of variables of which to compute the Jacobian.
|
||||
* @param wrt Vector of variables with respect to which to compute the Jacobian.
|
||||
*/
|
||||
static native long create(long[] variables, long[] wrt);
|
||||
|
||||
/**
|
||||
* Destructs a Jacobian.
|
||||
*
|
||||
* @param handle Jacobian handle.
|
||||
*/
|
||||
static native void destroy(long handle);
|
||||
|
||||
/**
|
||||
* Returns the Jacobian as an array of Variable handles.
|
||||
*
|
||||
* <p>This is useful when constructing optimization problems with derivatives in them.
|
||||
*
|
||||
* @param handle Jacobian handle.
|
||||
* @return The Jacobian as an array of Variable handles.
|
||||
*/
|
||||
static native long[] get(long handle);
|
||||
|
||||
/**
|
||||
* Evaluates the Jacobian at wrt's value.
|
||||
*
|
||||
* @param handle Jacobian handle.
|
||||
* @return A record containing the triplet row, column and value arrays (int[], int[], and
|
||||
* double[] respectively).
|
||||
*/
|
||||
static native NativeSparseTriplets value(long handle);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.ejml.data.DMatrixSparseCSC;
|
||||
import org.ejml.data.DMatrixSparseTriplet;
|
||||
import org.ejml.ops.DConvertMatrixStruct;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/**
|
||||
* Wrapper for sparse matrix triplets from JNI.
|
||||
*
|
||||
* <p>We can't use DMatrixSparseTriplet because it doesn't have a method for bulk-initialization
|
||||
* from triplet arrays.
|
||||
*
|
||||
* @param rows Triplet rows.
|
||||
* @param cols Triplet columns.
|
||||
* @param values Triplet values.
|
||||
*/
|
||||
public record NativeSparseTriplets(int[] rows, int[] cols, double[] values) {
|
||||
/**
|
||||
* Returns a SimpleMatrix wrapper for this set of triplets.
|
||||
*
|
||||
* @param rows Number of rows in sparse SimpleMatrix.
|
||||
* @param cols Number of columns in sparse SimpleMatrix.
|
||||
* @return A SimpleMatrix wrapper for this set of triplets.
|
||||
*/
|
||||
public SimpleMatrix toSimpleMatrix(int rows, int cols) {
|
||||
var ejmlTriplets = new DMatrixSparseTriplet(rows, cols, values().length);
|
||||
for (int i = 0; i < values().length; ++i) {
|
||||
ejmlTriplets.addItem(rows()[i], cols()[i], values()[i]);
|
||||
}
|
||||
return new SimpleMatrix(DConvertMatrixStruct.convert(ejmlTriplets, (DMatrixSparseCSC) null));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
/** Numerical integration utilities. */
|
||||
public final class NumericalIntegration {
|
||||
private NumericalIntegration() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
|
||||
*
|
||||
* @param f The function to integrate. It must take two arguments x and u.
|
||||
* @param x The initial value of x.
|
||||
* @param u The value u held constant over the integration period.
|
||||
* @param dt The time over which to integrate.
|
||||
* @return the integration of dx/dt = f(x, u) for dt.
|
||||
*/
|
||||
public static VariableMatrix rk4(
|
||||
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
|
||||
VariableBlock x,
|
||||
VariableBlock u,
|
||||
double dt) {
|
||||
return rk4(f, new VariableMatrix(x), new VariableMatrix(u), dt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
|
||||
*
|
||||
* @param f The function to integrate. It must take two arguments x and u.
|
||||
* @param x The initial value of x.
|
||||
* @param u The value u held constant over the integration period.
|
||||
* @param dt The time over which to integrate.
|
||||
* @return the integration of dx/dt = f(x, u) for dt.
|
||||
*/
|
||||
public static VariableMatrix rk4(
|
||||
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
|
||||
VariableBlock x,
|
||||
VariableMatrix u,
|
||||
double dt) {
|
||||
return rk4(f, new VariableMatrix(x), u, dt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
|
||||
*
|
||||
* @param f The function to integrate. It must take two arguments x and u.
|
||||
* @param x The initial value of x.
|
||||
* @param u The value u held constant over the integration period.
|
||||
* @param dt The time over which to integrate.
|
||||
* @return the integration of dx/dt = f(x, u) for dt.
|
||||
*/
|
||||
public static VariableMatrix rk4(
|
||||
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
|
||||
VariableMatrix x,
|
||||
VariableBlock u,
|
||||
double dt) {
|
||||
return rk4(f, x, new VariableMatrix(u), dt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
|
||||
*
|
||||
* @param f The function to integrate. It must take two arguments x and u.
|
||||
* @param x The initial value of x.
|
||||
* @param u The value u held constant over the integration period.
|
||||
* @param dt The time over which to integrate.
|
||||
* @return the integration of dx/dt = f(x, u) for dt.
|
||||
*/
|
||||
public static VariableMatrix rk4(
|
||||
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
|
||||
VariableMatrix x,
|
||||
VariableMatrix u,
|
||||
double dt) {
|
||||
var h = dt;
|
||||
|
||||
var k1 = f.apply(x, u);
|
||||
var k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
|
||||
var k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
|
||||
var k4 = f.apply(x.plus(k3.times(h)), u);
|
||||
|
||||
return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(h / 6.0));
|
||||
}
|
||||
}
|
||||
267
wpimath/src/main/java/org/wpilib/math/autodiff/Slice.java
Normal file
267
wpimath/src/main/java/org/wpilib/math/autodiff/Slice.java
Normal file
@@ -0,0 +1,267 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import java.util.OptionalInt;
|
||||
|
||||
/** Represents a sequence of elements in an iterable object. */
|
||||
@SuppressWarnings("PMD.UnusedFormalParameter")
|
||||
public class Slice {
|
||||
/** Type tag used to designate an omitted argument of the slice. */
|
||||
public static class None {
|
||||
/** Default constructor. */
|
||||
public None() {}
|
||||
}
|
||||
|
||||
/** Designates an omitted argument of the slice. */
|
||||
public static final None __ = null;
|
||||
|
||||
/** Start index (inclusive). */
|
||||
public int start = 0;
|
||||
|
||||
/** Stop index (exclusive). */
|
||||
public int stop = 0;
|
||||
|
||||
/** Step. */
|
||||
public int step = 1;
|
||||
|
||||
/** Constructs a Slice. */
|
||||
public Slice() {}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
*/
|
||||
public Slice(None start) {
|
||||
this(OptionalInt.of(0), OptionalInt.of(Integer.MAX_VALUE), OptionalInt.of(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
*/
|
||||
public Slice(int start) {
|
||||
this.start = start;
|
||||
this.stop = (start == -1) ? Integer.MAX_VALUE : start + 1;
|
||||
this.step = 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
*/
|
||||
public Slice(None start, None stop) {
|
||||
this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.of(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
*/
|
||||
public Slice(None start, int stop) {
|
||||
this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.of(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
*/
|
||||
public Slice(int start, None stop) {
|
||||
this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.of(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
*/
|
||||
public Slice(int start, int stop) {
|
||||
this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.of(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(None start, None stop, None step) {
|
||||
this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.empty());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(None start, None stop, int step) {
|
||||
this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.of(step));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(None start, int stop, None step) {
|
||||
this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.empty());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(None start, int stop, int step) {
|
||||
this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.of(step));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(int start, None stop, None step) {
|
||||
this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.empty());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(int start, None stop, int step) {
|
||||
this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.of(step));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(int start, int stop, None step) {
|
||||
this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.empty());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(int start, int stop, int step) {
|
||||
this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.of(step));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a slice.
|
||||
*
|
||||
* @param start Slice start index (inclusive).
|
||||
* @param stop Slice stop index (exclusive).
|
||||
* @param step Slice step.
|
||||
*/
|
||||
public Slice(OptionalInt start, OptionalInt stop, OptionalInt step) {
|
||||
if (!step.isPresent()) {
|
||||
this.step = 1;
|
||||
} else {
|
||||
assert step.getAsInt() != 0;
|
||||
|
||||
this.step = step.getAsInt();
|
||||
}
|
||||
|
||||
// Avoid UB for step = -step if step is INT_MIN
|
||||
if (this.step == Integer.MIN_VALUE) {
|
||||
this.step = -Integer.MAX_VALUE;
|
||||
}
|
||||
|
||||
if (!start.isPresent()) {
|
||||
if (this.step < 0) {
|
||||
this.start = Integer.MAX_VALUE;
|
||||
} else {
|
||||
this.start = 0;
|
||||
}
|
||||
} else {
|
||||
this.start = start.getAsInt();
|
||||
}
|
||||
|
||||
if (!stop.isPresent()) {
|
||||
if (this.step < 0) {
|
||||
this.stop = Integer.MIN_VALUE;
|
||||
} else {
|
||||
this.stop = Integer.MAX_VALUE;
|
||||
}
|
||||
} else {
|
||||
this.stop = stop.getAsInt();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjusts start and end slice indices assuming a sequence of the specified length.
|
||||
*
|
||||
* @param length The sequence length.
|
||||
* @return The slice length.
|
||||
*/
|
||||
public int adjust(int length) {
|
||||
assert step != 0;
|
||||
assert step >= -Integer.MAX_VALUE;
|
||||
|
||||
if (start < 0) {
|
||||
start += length;
|
||||
|
||||
if (start < 0) {
|
||||
start = (step < 0) ? -1 : 0;
|
||||
}
|
||||
} else if (start >= length) {
|
||||
start = (step < 0) ? length - 1 : length;
|
||||
}
|
||||
|
||||
if (stop < 0) {
|
||||
stop += length;
|
||||
|
||||
if (stop < 0) {
|
||||
stop = (step < 0) ? -1 : 0;
|
||||
}
|
||||
} else if (stop >= length) {
|
||||
stop = (step < 0) ? length - 1 : length;
|
||||
}
|
||||
|
||||
if (step < 0) {
|
||||
if (stop < start) {
|
||||
return (start - stop - 1) / -step + 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
if (start < stop) {
|
||||
return (stop - start - 1) / step + 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
634
wpimath/src/main/java/org/wpilib/math/autodiff/Variable.java
Normal file
634
wpimath/src/main/java/org/wpilib/math/autodiff/Variable.java
Normal file
@@ -0,0 +1,634 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
/** An autodiff variable pointing to an expression node. */
|
||||
public class Variable implements AutoCloseable {
|
||||
/** Handle type tag. */
|
||||
public static final class Handle {
|
||||
/** Constructor for Handle. */
|
||||
public Handle() {}
|
||||
}
|
||||
|
||||
/** Instance of handle type tag. */
|
||||
public static final Handle HANDLE = new Handle();
|
||||
|
||||
private long m_handle;
|
||||
|
||||
/** Constructs a linear Variable with a value of zero. */
|
||||
@SuppressWarnings("this-escape")
|
||||
public Variable() {
|
||||
m_handle = VariableJNI.createDefault();
|
||||
VariablePool.register(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Variable from a floating point type.
|
||||
*
|
||||
* @param value The value of the Variable.
|
||||
*/
|
||||
@SuppressWarnings("this-escape")
|
||||
public Variable(double value) {
|
||||
m_handle = VariableJNI.createDouble(value);
|
||||
VariablePool.register(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Variable from an integral type.
|
||||
*
|
||||
* @param value The value of the Variable.
|
||||
*/
|
||||
@SuppressWarnings("this-escape")
|
||||
public Variable(int value) {
|
||||
m_handle = VariableJNI.createInt(value);
|
||||
VariablePool.register(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Variable from the given handle.
|
||||
*
|
||||
* <p>This constructor is for internal use only.
|
||||
*
|
||||
* @param handleTypeTag Handle type tag.
|
||||
* @param handle Variable handle.
|
||||
*/
|
||||
@SuppressWarnings({"PMD.UnusedFormalParameter", "this-escape"})
|
||||
public Variable(Handle handleTypeTag, long handle) {
|
||||
m_handle = handle;
|
||||
VariablePool.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (m_handle != 0) {
|
||||
VariableJNI.destroy(m_handle);
|
||||
m_handle = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets Variable's internal value.
|
||||
*
|
||||
* @param value The value of the Variable.
|
||||
*/
|
||||
public void setValue(double value) {
|
||||
VariableJNI.setValue(m_handle, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value of this variable.
|
||||
*
|
||||
* @return The value of this variable.
|
||||
*/
|
||||
public double value() {
|
||||
return VariableJNI.value(m_handle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the type of this expression (constant, linear, quadratic, or nonlinear).
|
||||
*
|
||||
* @return The type of this expression.
|
||||
*/
|
||||
public ExpressionType type() {
|
||||
return ExpressionType.fromInt(VariableJNI.type(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns internal handle.
|
||||
*
|
||||
* @return Internal handle.
|
||||
*/
|
||||
public long getHandle() {
|
||||
return m_handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable multiplication operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of multiplication.
|
||||
*/
|
||||
public Variable times(Variable rhs) {
|
||||
return new Variable(HANDLE, VariableJNI.times(m_handle, rhs.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable multiplication operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of multiplication.
|
||||
*/
|
||||
public Variable times(double rhs) {
|
||||
return times(new Variable(rhs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable division operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of division.
|
||||
*/
|
||||
public Variable div(Variable rhs) {
|
||||
return new Variable(HANDLE, VariableJNI.div(m_handle, rhs.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable division operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of division.
|
||||
*/
|
||||
public Variable div(double rhs) {
|
||||
return div(new Variable(rhs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable addition operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of addition.
|
||||
*/
|
||||
public Variable plus(Variable rhs) {
|
||||
return new Variable(HANDLE, VariableJNI.plus(m_handle, rhs.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable addition operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of addition.
|
||||
*/
|
||||
public Variable plus(double rhs) {
|
||||
return plus(new Variable(rhs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable subtraction operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of subtraction.
|
||||
*/
|
||||
public Variable minus(Variable rhs) {
|
||||
return new Variable(HANDLE, VariableJNI.minus(m_handle, rhs.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Variable-Variable subtraction operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of subtraction.
|
||||
*/
|
||||
public Variable minus(double rhs) {
|
||||
return minus(new Variable(rhs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Unary minus operator.
|
||||
*
|
||||
* @return Result of unary minus.
|
||||
*/
|
||||
public Variable unaryMinus() {
|
||||
return new Variable(HANDLE, VariableJNI.unaryMinus(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Unary plus operator.
|
||||
*
|
||||
* @return Result of unary plus.
|
||||
*/
|
||||
public Variable unaryPlus() {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.abs() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of abs().
|
||||
*/
|
||||
public static Variable abs(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.abs(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.acos() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of acos().
|
||||
*/
|
||||
public static Variable acos(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.acos(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.asin() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of asin().
|
||||
*/
|
||||
public static Variable asin(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.asin(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.atan() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of atan().
|
||||
*/
|
||||
public static Variable atan(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.atan(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.atan2() for Variables.
|
||||
*
|
||||
* @param y The y argument.
|
||||
* @param x The x argument.
|
||||
* @return Result of atan2().
|
||||
*/
|
||||
public static Variable atan2(double y, Variable x) {
|
||||
return atan2(new Variable(y), x);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.atan2() for Variables.
|
||||
*
|
||||
* @param y The y argument.
|
||||
* @param x The x argument.
|
||||
* @return Result of atan2().
|
||||
*/
|
||||
public static Variable atan2(Variable y, double x) {
|
||||
return atan2(y, new Variable(x));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.atan2() for Variables.
|
||||
*
|
||||
* @param y The y argument.
|
||||
* @param x The x argument.
|
||||
* @return Result of atan2().
|
||||
*/
|
||||
public static Variable atan2(Variable y, Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.atan2(y.getHandle(), x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.cbrt() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of cbrt().
|
||||
*/
|
||||
public static Variable cbrt(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.cbrt(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.cos() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of cos().
|
||||
*/
|
||||
public static Variable cos(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.cos(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.cosh() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of cosh().
|
||||
*/
|
||||
public static Variable cosh(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.cosh(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.exp() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of exp().
|
||||
*/
|
||||
public static Variable exp(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.exp(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(double x, Variable y) {
|
||||
return hypot(new Variable(x), y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(Variable x, double y) {
|
||||
return hypot(x, new Variable(y));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(Variable x, Variable y) {
|
||||
return new Variable(HANDLE, VariableJNI.hypot(x.getHandle(), y.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(double x, double y, Variable z) {
|
||||
return hypot(new Variable(x), new Variable(y), z);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(double x, Variable y, double z) {
|
||||
return hypot(new Variable(x), y, new Variable(z));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(double x, Variable y, Variable z) {
|
||||
return hypot(new Variable(x), y, z);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(Variable x, double y, double z) {
|
||||
return hypot(x, new Variable(y), new Variable(z));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(Variable x, double y, Variable z) {
|
||||
return hypot(x, new Variable(y), z);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(Variable x, Variable y, double z) {
|
||||
return hypot(x, y, new Variable(z));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
* @param z The z argument.
|
||||
* @return Result of hypot().
|
||||
*/
|
||||
public static Variable hypot(Variable x, Variable y, Variable z) {
|
||||
return sqrt(pow(x, 2).plus(pow(y, 2)).plus(pow(z, 2)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.log() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of log().
|
||||
*/
|
||||
public static Variable log(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.log(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.log10() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of log10().
|
||||
*/
|
||||
public static Variable log10(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.log10(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.max() for Variables.
|
||||
*
|
||||
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
* @return Result of max().
|
||||
*/
|
||||
public static Variable max(double a, Variable b) {
|
||||
return max(new Variable(a), b);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.max() for Variables.
|
||||
*
|
||||
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
* @return Result of max().
|
||||
*/
|
||||
public static Variable max(Variable a, double b) {
|
||||
return max(a, new Variable(b));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.max() for Variables.
|
||||
*
|
||||
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
* @return Result of max().
|
||||
*/
|
||||
public static Variable max(Variable a, Variable b) {
|
||||
return new Variable(HANDLE, VariableJNI.max(a.getHandle(), b.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* min() for Variables.
|
||||
*
|
||||
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
* @return Result of min().
|
||||
*/
|
||||
public static Variable min(double a, Variable b) {
|
||||
return min(new Variable(a), b);
|
||||
}
|
||||
|
||||
/**
|
||||
* min() for Variables.
|
||||
*
|
||||
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
* @return Result of min().
|
||||
*/
|
||||
public static Variable min(Variable a, double b) {
|
||||
return min(a, new Variable(b));
|
||||
}
|
||||
|
||||
/**
|
||||
* min() for Variables.
|
||||
*
|
||||
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
* @return Result of min().
|
||||
*/
|
||||
public static Variable min(Variable a, Variable b) {
|
||||
return new Variable(HANDLE, VariableJNI.min(a.getHandle(), b.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.pow() for Variables.
|
||||
*
|
||||
* @param base The base.
|
||||
* @param power The power.
|
||||
* @return Result of pow().
|
||||
*/
|
||||
public static Variable pow(double base, Variable power) {
|
||||
return pow(new Variable(base), power);
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.pow() for Variables.
|
||||
*
|
||||
* @param base The base.
|
||||
* @param power The power.
|
||||
* @return Result of pow().
|
||||
*/
|
||||
public static Variable pow(Variable base, double power) {
|
||||
return pow(base, new Variable(power));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.pow() for Variables.
|
||||
*
|
||||
* @param base The base.
|
||||
* @param power The power.
|
||||
* @return Result of pow().
|
||||
*/
|
||||
public static Variable pow(Variable base, Variable power) {
|
||||
return new Variable(HANDLE, VariableJNI.pow(base.getHandle(), power.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.signum() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of signum().
|
||||
*/
|
||||
public static Variable signum(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.signum(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.sin() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of sin().
|
||||
*/
|
||||
public static Variable sin(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.sin(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.sinh() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of sinh().
|
||||
*/
|
||||
public static Variable sinh(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.sinh(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.sqrt() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of sqrt().
|
||||
*/
|
||||
public static Variable sqrt(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.sqrt(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.tan() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of tan().
|
||||
*/
|
||||
public static Variable tan(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.tan(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Math.tanh() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
* @return Result of tanh().
|
||||
*/
|
||||
public static Variable tanh(Variable x) {
|
||||
return new Variable(HANDLE, VariableJNI.tanh(x.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the total native memory usage of all Variables in bytes.
|
||||
*
|
||||
* @return The total native memory usage of all Variables in bytes.
|
||||
*/
|
||||
public static long totalNativeMemoryUsage() {
|
||||
return VariableJNI.totalNativeMemoryUsage();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,826 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.function.UnaryOperator;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.stream.StreamSupport;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/** A submatrix of autodiff variables with reference semantics. */
|
||||
public class VariableBlock implements Iterable<Variable> {
|
||||
private final VariableMatrix m_mat;
|
||||
|
||||
private final Slice m_rowSlice;
|
||||
private final int m_rowSliceLength;
|
||||
|
||||
private final Slice m_colSlice;
|
||||
private final int m_colSliceLength;
|
||||
|
||||
/**
|
||||
* Constructs a Variable block pointing to all of the given matrix.
|
||||
*
|
||||
* @param mat The matrix to which to point.
|
||||
*/
|
||||
public VariableBlock(VariableMatrix mat) {
|
||||
this(mat, 0, 0, mat.rows(), mat.cols());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Variable block pointing to a subset of the given matrix.
|
||||
*
|
||||
* @param mat The matrix to which to point.
|
||||
* @param rowOffset The block's row offset.
|
||||
* @param colOffset The block's column offset.
|
||||
* @param blockRows The number of rows in the block.
|
||||
* @param blockCols The number of columns in the block.
|
||||
*/
|
||||
public VariableBlock(
|
||||
VariableMatrix mat, int rowOffset, int colOffset, int blockRows, int blockCols) {
|
||||
m_mat = mat;
|
||||
m_rowSlice = new Slice(rowOffset, rowOffset + blockRows, 1);
|
||||
m_rowSliceLength = m_rowSlice.adjust(mat.rows());
|
||||
m_colSlice = new Slice(colOffset, colOffset + blockCols, 1);
|
||||
m_colSliceLength = m_colSlice.adjust(mat.cols());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a Variable block pointing to a subset of the given matrix.
|
||||
*
|
||||
* <p>Note that the slices are taken as is rather than adjusted.
|
||||
*
|
||||
* @param mat The matrix to which to point.
|
||||
* @param rowSlice The block's row slice.
|
||||
* @param rowSliceLength The block's row length.
|
||||
* @param colSlice The block's column slice.
|
||||
* @param colSliceLength The block's column length.
|
||||
*/
|
||||
public VariableBlock(
|
||||
VariableMatrix mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) {
|
||||
m_mat = mat;
|
||||
m_rowSlice = rowSlice;
|
||||
m_rowSliceLength = rowSliceLength;
|
||||
m_colSlice = colSlice;
|
||||
m_colSliceLength = colSliceLength;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a double to the block.
|
||||
*
|
||||
* <p>This only works for blocks with one row and one column.
|
||||
*
|
||||
* @param value Value to assign.
|
||||
* @return This VariableBlock.
|
||||
*/
|
||||
public VariableBlock set(double value) {
|
||||
assert rows() == 1 && cols() == 1;
|
||||
|
||||
set(0, 0, new Variable(value));
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a Variable to the block.
|
||||
*
|
||||
* <p>This only works for blocks with one row and one column.
|
||||
*
|
||||
* @param value Value to assign.
|
||||
* @return This VariableBlock.
|
||||
*/
|
||||
public VariableBlock set(Variable value) {
|
||||
assert rows() == 1 && cols() == 1;
|
||||
|
||||
set(0, 0, value);
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a double array to the block.
|
||||
*
|
||||
* @param values Double array of values to assign.
|
||||
* @return This VariableBlock.
|
||||
*/
|
||||
public VariableBlock set(double[][] values) {
|
||||
assert rows() == values.length;
|
||||
|
||||
// Assert all column counts are the same
|
||||
for (var row : values) {
|
||||
assert row.length == cols();
|
||||
}
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
set(row, col, values[row][col]);
|
||||
}
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns an EJML matrix to the block.
|
||||
*
|
||||
* @param values EJML matrix of values to assign.
|
||||
* @return This VariableBlock.
|
||||
*/
|
||||
public VariableBlock set(SimpleMatrix values) {
|
||||
assert rows() == values.getNumRows() && cols() == values.getNumCols();
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
set(row, col, values.get(row, col));
|
||||
}
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a VariableMatrix to the block.
|
||||
*
|
||||
* @param values VariableMatrix of values.
|
||||
* @return This VariableBlock.
|
||||
*/
|
||||
public VariableBlock set(VariableMatrix values) {
|
||||
assert rows() == values.rows() && cols() == values.cols();
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
set(row, col, values.get(row, col));
|
||||
}
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a VariableBlock to the block.
|
||||
*
|
||||
* @param values VariableBlock of values.
|
||||
* @return This VariableBlock.
|
||||
*/
|
||||
public VariableBlock set(VariableBlock values) {
|
||||
assert rows() == values.rows() && cols() == values.cols();
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
set(row, col, values.get(row, col));
|
||||
}
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a scalar subblock at the given row and column.
|
||||
*
|
||||
* @param row The scalar subblock's row.
|
||||
* @param col The scalar subblock's column.
|
||||
* @param value The value.
|
||||
*/
|
||||
public void set(int row, int col, Variable value) {
|
||||
assert row >= 0 && row < rows();
|
||||
assert col >= 0 && col < cols();
|
||||
m_mat.set(
|
||||
m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a scalar subblock at the given row and column.
|
||||
*
|
||||
* @param row The scalar subblock's row.
|
||||
* @param col The scalar subblock's column.
|
||||
* @param value The value.
|
||||
*/
|
||||
public void set(int row, int col, double value) {
|
||||
assert row >= 0 && row < rows();
|
||||
assert col >= 0 && col < cols();
|
||||
m_mat.set(
|
||||
m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a scalar subblock at the given index.
|
||||
*
|
||||
* @param index The scalar subblock's index.
|
||||
* @param value The value.
|
||||
*/
|
||||
public void set(int index, double value) {
|
||||
set(index, new Variable(value));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a scalar subblock at the given index.
|
||||
*
|
||||
* @param index The scalar subblock's index.
|
||||
* @param value The value.
|
||||
*/
|
||||
public void set(int index, Variable value) {
|
||||
assert index >= 0 && index < rows() * cols();
|
||||
set(index / cols(), index % cols(), value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a double to the block.
|
||||
*
|
||||
* <p>This only works for blocks with one row and one column.
|
||||
*
|
||||
* @param value Value to assign.
|
||||
*/
|
||||
public void setValue(double value) {
|
||||
assert rows() == 1 && cols() == 1;
|
||||
|
||||
get(0, 0).setValue(value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets block's internal values.
|
||||
*
|
||||
* @param values Double array of values.
|
||||
*/
|
||||
public void setValue(double[][] values) {
|
||||
assert rows() == values.length;
|
||||
|
||||
// Assert all column counts are the same
|
||||
for (var row : values) {
|
||||
assert row.length == cols();
|
||||
}
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
get(row, col).setValue(values[row][col]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets block's internal values.
|
||||
*
|
||||
* @param values EJML matrix of values.
|
||||
*/
|
||||
public void setValue(SimpleMatrix values) {
|
||||
assert rows() == values.getNumRows() && cols() == values.getNumCols();
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
get(row, col).setValue(values.get(row, col));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a scalar subblock at the given row and column.
|
||||
*
|
||||
* @param row The scalar subblock's row.
|
||||
* @param col The scalar subblock's column.
|
||||
* @return A scalar subblock at the given row and column.
|
||||
*/
|
||||
public Variable get(int row, int col) {
|
||||
assert row >= 0 && row < rows();
|
||||
assert col >= 0 && col < cols();
|
||||
return m_mat.get(
|
||||
m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a scalar subblock at the given index.
|
||||
*
|
||||
* @param index The scalar subblock's index.
|
||||
* @return A scalar subblock at the given index.
|
||||
*/
|
||||
public Variable get(int index) {
|
||||
assert index >= 0 && index < rows() * cols();
|
||||
return get(index / cols(), index % cols());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param row The row.
|
||||
* @param colSlice The column slice.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(int row, Slice.None colSlice) {
|
||||
return get(new Slice(row), new Slice(colSlice));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param row The row.
|
||||
* @param colSlice The column slice.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(int row, Slice colSlice) {
|
||||
return get(new Slice(row), colSlice);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param rowSlice The row slice.
|
||||
* @param col The column.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(Slice.None rowSlice, int col) {
|
||||
return get(new Slice(rowSlice), new Slice(col));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param rowSlice The row slice.
|
||||
* @param col The column.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(Slice rowSlice, int col) {
|
||||
return get(rowSlice, new Slice(col));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param rowSlice The row slice.
|
||||
* @param colSlice The column slice.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) {
|
||||
return get(new Slice(rowSlice), new Slice(colSlice));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param rowSlice The row slice.
|
||||
* @param colSlice The column slice.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(Slice.None rowSlice, Slice colSlice) {
|
||||
return get(new Slice(rowSlice), colSlice);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param rowSlice The row slice.
|
||||
* @param colSlice The column slice.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(Slice rowSlice, Slice.None colSlice) {
|
||||
return get(rowSlice, new Slice(colSlice));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a slice of the variable matrix.
|
||||
*
|
||||
* @param rowSlice The row slice.
|
||||
* @param colSlice The column slice.
|
||||
* @return A slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock get(Slice rowSlice, Slice colSlice) {
|
||||
int rowSliceLength = rowSlice.adjust(m_rowSliceLength);
|
||||
int colSliceLength = colSlice.adjust(m_colSliceLength);
|
||||
return new VariableBlock(
|
||||
m_mat,
|
||||
new Slice(
|
||||
m_rowSlice.start + rowSlice.start * m_rowSlice.step,
|
||||
m_rowSlice.start + rowSlice.stop * m_rowSlice.step,
|
||||
rowSlice.step * m_rowSlice.step),
|
||||
rowSliceLength,
|
||||
new Slice(
|
||||
m_colSlice.start + colSlice.start * m_colSlice.step,
|
||||
m_colSlice.start + colSlice.stop * m_colSlice.step,
|
||||
colSlice.step * m_colSlice.step),
|
||||
colSliceLength);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a block of the variable matrix.
|
||||
*
|
||||
* @param rowOffset The row offset of the block selection.
|
||||
* @param colOffset The column offset of the block selection.
|
||||
* @param blockRows The number of rows in the block selection.
|
||||
* @param blockCols The number of columns in the block selection.
|
||||
* @return A block of the variable matrix.
|
||||
*/
|
||||
public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) {
|
||||
assert rowOffset >= 0 && rowOffset <= rows();
|
||||
assert colOffset >= 0 && colOffset <= cols();
|
||||
assert blockRows >= 0 && blockRows <= rows() - rowOffset;
|
||||
assert blockCols >= 0 && blockCols <= cols() - colOffset;
|
||||
return get(
|
||||
new Slice(rowOffset, rowOffset + blockRows, 1),
|
||||
new Slice(colOffset, colOffset + blockCols, 1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a segment of the variable vector.
|
||||
*
|
||||
* @param offset The offset of the segment.
|
||||
* @param length The length of the segment.
|
||||
* @return A segment of the variable vector.
|
||||
*/
|
||||
public VariableBlock segment(int offset, int length) {
|
||||
assert cols() == 1;
|
||||
assert offset >= 0 && offset < rows();
|
||||
assert length >= 0 && length <= rows() - offset;
|
||||
return block(offset, 0, length, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a row slice of the variable matrix.
|
||||
*
|
||||
* @param row The row to slice.
|
||||
* @return A row slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock row(int row) {
|
||||
assert row >= 0 && row < rows();
|
||||
return block(row, 0, 1, cols());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a column slice of the variable matrix.
|
||||
*
|
||||
* @param col The column to slice.
|
||||
* @return A column slice of the variable matrix.
|
||||
*/
|
||||
public VariableBlock col(int col) {
|
||||
assert col >= 0 && col < cols();
|
||||
return block(0, col, rows(), 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix multiplication operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of matrix multiplication.
|
||||
*/
|
||||
public VariableMatrix times(VariableMatrix rhs) {
|
||||
assert cols() == rhs.rows();
|
||||
|
||||
var result = new VariableMatrix(rows(), rhs.cols());
|
||||
|
||||
for (int i = 0; i < rows(); ++i) {
|
||||
for (int j = 0; j < rhs.cols(); ++j) {
|
||||
var sum = new Variable(0.0);
|
||||
for (int k = 0; k < cols(); ++k) {
|
||||
sum = sum.plus(get(i, k).times(rhs.get(k, j)));
|
||||
}
|
||||
result.set(i, j, sum);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix multiplication operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of matrix multiplication.
|
||||
*/
|
||||
public VariableMatrix times(VariableBlock rhs) {
|
||||
assert cols() == rhs.rows();
|
||||
|
||||
var result = new VariableMatrix(rows(), rhs.cols());
|
||||
|
||||
for (int i = 0; i < rows(); ++i) {
|
||||
for (int j = 0; j < rhs.cols(); ++j) {
|
||||
var sum = new Variable(0.0);
|
||||
for (int k = 0; k < cols(); ++k) {
|
||||
sum = sum.plus(get(i, k).times(rhs.get(k, j)));
|
||||
}
|
||||
result.set(i, j, sum);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix-scalar multiplication operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of matrix-scalar multiplication.
|
||||
*/
|
||||
public VariableMatrix times(double rhs) {
|
||||
return times(new Variable(rhs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix-scalar multiplication operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of matrix-scalar multiplication.
|
||||
*/
|
||||
public VariableMatrix times(Variable rhs) {
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).times(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary division operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of division.
|
||||
*/
|
||||
public VariableMatrix div(double rhs) {
|
||||
return div(new Variable(rhs));
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary division operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of division.
|
||||
*/
|
||||
public VariableMatrix div(Variable rhs) {
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).div(rhs));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary addition operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of addition.
|
||||
*/
|
||||
public VariableMatrix plus(VariableMatrix rhs) {
|
||||
assert rows() == rhs.rows() && cols() == rhs.cols();
|
||||
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).plus(rhs.get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary addition operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of addition.
|
||||
*/
|
||||
public VariableMatrix plus(VariableBlock rhs) {
|
||||
assert rows() == rhs.rows() && cols() == rhs.cols();
|
||||
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).plus(rhs.get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary addition operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of addition.
|
||||
*/
|
||||
public VariableMatrix plus(SimpleMatrix rhs) {
|
||||
assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
|
||||
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).plus(rhs.get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary subtraction operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of subtraction.
|
||||
*/
|
||||
public VariableMatrix minus(VariableMatrix rhs) {
|
||||
assert rows() == rhs.rows() && cols() == rhs.cols();
|
||||
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).minus(rhs.get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary subtraction operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of subtraction.
|
||||
*/
|
||||
public VariableMatrix minus(VariableBlock rhs) {
|
||||
assert rows() == rhs.rows() && cols() == rhs.cols();
|
||||
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).minus(rhs.get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binary subtraction operator.
|
||||
*
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of subtraction.
|
||||
*/
|
||||
public VariableMatrix minus(SimpleMatrix rhs) {
|
||||
assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
|
||||
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).minus(rhs.get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unary minus operator.
|
||||
*
|
||||
* @return Result of unary minus.
|
||||
*/
|
||||
public VariableMatrix unaryMinus() {
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
result.set(row, col, get(row, col).unaryMinus());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the transpose of the variable matrix.
|
||||
*
|
||||
* @return The transpose of the variable matrix.
|
||||
*/
|
||||
public VariableMatrix T() {
|
||||
var result = new VariableMatrix(cols(), rows());
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
result.set(col, row, get(row, col));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of rows in the matrix.
|
||||
*
|
||||
* @return The number of rows in the matrix.
|
||||
*/
|
||||
public int rows() {
|
||||
return m_rowSliceLength;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of columns in the matrix.
|
||||
*
|
||||
* @return The number of columns in the matrix.
|
||||
*/
|
||||
public int cols() {
|
||||
return m_colSliceLength;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the variable matrix.
|
||||
*
|
||||
* @param row The row of the element to return.
|
||||
* @param col The column of the element to return.
|
||||
* @return An element of the variable matrix.
|
||||
*/
|
||||
public double value(int row, int col) {
|
||||
return get(row, col).value();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element of the variable block.
|
||||
*
|
||||
* @param index The index of the element to return.
|
||||
* @return An element of the variable block.
|
||||
*/
|
||||
public double value(int index) {
|
||||
return get(index).value();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the contents of the variable matrix.
|
||||
*
|
||||
* @return The contents of the variable matrix.
|
||||
*/
|
||||
public SimpleMatrix value() {
|
||||
var result = new SimpleMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
result.set(row, col, value(row, col));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps the matrix coefficient-wise with an unary operator.
|
||||
*
|
||||
* @param unaryOp The unary operator to use for the map operation.
|
||||
* @return Result of the unary operator.
|
||||
*/
|
||||
public VariableMatrix cwiseMap(UnaryOperator<Variable> unaryOp) {
|
||||
var result = new VariableMatrix(rows(), cols());
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
result.set(row, col, unaryOp.apply(get(row, col)));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns number of elements in matrix.
|
||||
*
|
||||
* @return Number of elements in matrix.
|
||||
*/
|
||||
public int size() {
|
||||
return rows() * cols();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Variable> iterator() {
|
||||
return new Iterator<>() {
|
||||
private int m_index = 0;
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return m_index < VariableBlock.this.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Variable next() {
|
||||
if (!hasNext()) {
|
||||
throw new NoSuchElementException();
|
||||
}
|
||||
|
||||
return VariableBlock.this.get(m_index++);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a Stream of VariableBlock elements.
|
||||
*
|
||||
* @return A Stream of VariableBlock elements.
|
||||
*/
|
||||
public Stream<Variable> stream() {
|
||||
return StreamSupport.stream(spliterator(), false);
|
||||
}
|
||||
}
|
||||
268
wpimath/src/main/java/org/wpilib/math/autodiff/VariableJNI.java
Normal file
268
wpimath/src/main/java/org/wpilib/math/autodiff/VariableJNI.java
Normal file
@@ -0,0 +1,268 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.wpilib.math.jni.WPIMathJNI;
|
||||
|
||||
/** Variable JNI functions. */
|
||||
final class VariableJNI extends WPIMathJNI {
|
||||
private VariableJNI() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/** Constructs a default Variable. */
|
||||
static native long createDefault();
|
||||
|
||||
/**
|
||||
* Constructs a Variable from a floating point type.
|
||||
*
|
||||
* @param value The value of the Variable.
|
||||
*/
|
||||
static native long createDouble(double value);
|
||||
|
||||
/**
|
||||
* Constructs a Variable from an integral type.
|
||||
*
|
||||
* @param value The value of the Variable.
|
||||
*/
|
||||
static native long createInt(int value);
|
||||
|
||||
/**
|
||||
* Destructs a Variable.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
*/
|
||||
static native void destroy(long handle);
|
||||
|
||||
/**
|
||||
* Sets Variable's internal value.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @param value The value of the Variable.
|
||||
*/
|
||||
static native void setValue(long handle, double value);
|
||||
|
||||
/**
|
||||
* Variable-Variable multiplication operator.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of multiplication.
|
||||
*/
|
||||
static native long times(long handle, long rhs);
|
||||
|
||||
/**
|
||||
* Variable-Variable division operator.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of division.
|
||||
*/
|
||||
static native long div(long handle, long rhs);
|
||||
|
||||
/**
|
||||
* Variable-Variable addition operator.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of addition.
|
||||
*/
|
||||
static native long plus(long handle, long rhs);
|
||||
|
||||
/**
|
||||
* Variable-Variable subtraction operator.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @param rhs Operator right-hand side.
|
||||
* @return Result of subtraction.
|
||||
*/
|
||||
static native long minus(long handle, long rhs);
|
||||
|
||||
/**
|
||||
* Unary minus operator.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
*/
|
||||
static native long unaryMinus(long handle);
|
||||
|
||||
/**
|
||||
* Returns the value of this variable.
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @return The value of this variable.
|
||||
*/
|
||||
static native double value(long handle);
|
||||
|
||||
/**
|
||||
* Returns the type of this expression (constant, linear, quadratic, or nonlinear).
|
||||
*
|
||||
* @param handle Variable handle.
|
||||
* @return The type of this expression.
|
||||
*/
|
||||
static native int type(long handle);
|
||||
|
||||
/**
|
||||
* Math.abs() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long abs(long x);
|
||||
|
||||
/**
|
||||
* Math.acos() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long acos(long x);
|
||||
|
||||
/**
|
||||
* Math.asin() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long asin(long x);
|
||||
|
||||
/**
|
||||
* Math.atan() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long atan(long x);
|
||||
|
||||
/**
|
||||
* Math.atan2() for Variables.
|
||||
*
|
||||
* @param y The y argument.
|
||||
* @param x The x argument.
|
||||
*/
|
||||
static native long atan2(long y, long x);
|
||||
|
||||
/**
|
||||
* Math.cbrt() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long cbrt(long x);
|
||||
|
||||
/**
|
||||
* Math.cos() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long cos(long x);
|
||||
|
||||
/**
|
||||
* Math.cosh() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long cosh(long x);
|
||||
|
||||
/**
|
||||
* Math.exp() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long exp(long x);
|
||||
|
||||
/**
|
||||
* Math.hypot() for Variables.
|
||||
*
|
||||
* @param x The x argument.
|
||||
* @param y The y argument.
|
||||
*/
|
||||
static native long hypot(long x, long y);
|
||||
|
||||
/**
|
||||
* Math.log() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long log(long x);
|
||||
|
||||
/**
|
||||
* Math.log10() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long log10(long x);
|
||||
|
||||
/**
|
||||
* Math.max() for Variables.
|
||||
*
|
||||
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
*/
|
||||
static native long max(long a, long b);
|
||||
|
||||
/**
|
||||
* Math.min() for Variables.
|
||||
*
|
||||
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
|
||||
*
|
||||
* @param a The a argument.
|
||||
* @param b The b argument.
|
||||
*/
|
||||
static native long min(long a, long b);
|
||||
|
||||
/**
|
||||
* Math.pow() for Variables.
|
||||
*
|
||||
* @param base The base.
|
||||
* @param power The power.
|
||||
*/
|
||||
static native long pow(long base, long power);
|
||||
|
||||
/**
|
||||
* Math.signum() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long signum(long x);
|
||||
|
||||
/**
|
||||
* Math.sin() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long sin(long x);
|
||||
|
||||
/**
|
||||
* Math.sinh() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long sinh(long x);
|
||||
|
||||
/**
|
||||
* Math.sqrt() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long sqrt(long x);
|
||||
|
||||
/**
|
||||
* Math.tan() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long tan(long x);
|
||||
|
||||
/**
|
||||
* Math.tanh() for Variables.
|
||||
*
|
||||
* @param x The argument.
|
||||
*/
|
||||
static native long tanh(long x);
|
||||
|
||||
/**
|
||||
* Returns the total native memory usage of Variables in bytes.
|
||||
*
|
||||
* @return The total native memory usage of Variables in bytes.
|
||||
*/
|
||||
static native long totalNativeMemoryUsage();
|
||||
}
|
||||
1047
wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java
Normal file
1047
wpimath/src/main/java/org/wpilib/math/autodiff/VariableMatrix.java
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import org.wpilib.math.jni.WPIMathJNI;
|
||||
|
||||
/** Variable JNI functions. */
|
||||
final class VariableMatrixJNI extends WPIMathJNI {
|
||||
private VariableMatrixJNI() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves the VariableMatrix equation AX = B for X.
|
||||
*
|
||||
* @param A The left-hand side as a flattened row-major matrix.
|
||||
* @param Acols The number of columns in A.
|
||||
* @param B The right-hand side as a flattened row-major matrix.
|
||||
* @param Bcols The number of columns in B.
|
||||
* @return The solution X as a flattened row-major matrix.
|
||||
*/
|
||||
static native long[] solve(long[] A, int Acols, long[] B, int Bcols);
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.Deque;
|
||||
import org.wpilib.util.ErrorMessages;
|
||||
import org.wpilib.util.cleanup.CleanupPool;
|
||||
|
||||
/**
|
||||
* Cleans up implicitly allocated Variables via try-with-resources.
|
||||
*
|
||||
* <p>This implements a stack of Variable pools containing a default global pool. The user can
|
||||
* create additional pools via try-with-resources. Variable and VariableMatrix instances will
|
||||
* register themselves with the latest pool.
|
||||
*
|
||||
* <p>It's strongly recommended to only instantiate this class via try-with-resources so the close()
|
||||
* methods are always called in the correct order (i.e., nested scopes).
|
||||
*/
|
||||
public class VariablePool implements AutoCloseable {
|
||||
private static Deque<VariablePool> s_variablePoolStack = new ArrayDeque<VariablePool>();
|
||||
|
||||
// Default global pool
|
||||
@SuppressWarnings("PMD.UnusedPrivateField")
|
||||
private static VariablePool s_globalPool = new VariablePool();
|
||||
|
||||
// Cleans up Variables in the scope of this VariablePool
|
||||
private final CleanupPool m_cleanupPool = new CleanupPool();
|
||||
|
||||
/** Default constructor. */
|
||||
@SuppressWarnings("this-escape")
|
||||
public VariablePool() {
|
||||
s_variablePoolStack.addFirst(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
m_cleanupPool.close();
|
||||
s_variablePoolStack.removeFirst();
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a Variable in the Variable stack for cleanup.
|
||||
*
|
||||
* @param variable The Variable to register
|
||||
* @return The registered Variable
|
||||
*/
|
||||
public static Variable register(Variable variable) {
|
||||
ErrorMessages.requireNonNullParam(variable, "variable", "register");
|
||||
s_variablePoolStack.getFirst().m_cleanupPool.register(variable);
|
||||
return variable;
|
||||
}
|
||||
}
|
||||
@@ -4,9 +4,17 @@
|
||||
|
||||
package org.wpilib.math.controller;
|
||||
|
||||
import static org.wpilib.math.autodiff.Variable.cos;
|
||||
import static org.wpilib.math.autodiff.Variable.signum;
|
||||
|
||||
import org.wpilib.math.autodiff.Gradient;
|
||||
import org.wpilib.math.autodiff.Hessian;
|
||||
import org.wpilib.math.autodiff.NumericalIntegration;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.autodiff.VariablePool;
|
||||
import org.wpilib.math.controller.proto.ArmFeedforwardProto;
|
||||
import org.wpilib.math.controller.struct.ArmFeedforwardStruct;
|
||||
import org.wpilib.math.jni.ArmFeedforwardJNI;
|
||||
import org.wpilib.util.protobuf.ProtobufSerializable;
|
||||
import org.wpilib.util.struct.StructSerializable;
|
||||
|
||||
@@ -191,8 +199,111 @@ public class ArmFeedforward implements ProtobufSerializable, StructSerializable
|
||||
* @return The computed feedforward in volts.
|
||||
*/
|
||||
public double calculate(double currentAngle, double currentVelocity, double nextVelocity) {
|
||||
return ArmFeedforwardJNI.calculate(
|
||||
ks, kv, ka, kg, currentAngle, currentVelocity, nextVelocity, m_dt);
|
||||
// Small kₐ values make the solver ill-conditioned
|
||||
if (ka < 1e-1) {
|
||||
double acceleration = (nextVelocity - currentVelocity) / m_dt;
|
||||
return ks * Math.signum(currentVelocity)
|
||||
+ kv * currentVelocity
|
||||
+ ka * acceleration
|
||||
+ kg * Math.cos(currentAngle);
|
||||
}
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// Arm dynamics
|
||||
var A = new VariableMatrix(new double[][] {{0.0, 1.0}, {0.0, -kv / ka}});
|
||||
var B = new VariableMatrix(new double[][] {{0.0}, {1.0 / ka}});
|
||||
|
||||
var r_k = new VariableMatrix(new double[][] {{currentAngle}, {currentVelocity}});
|
||||
|
||||
var u_k = new Variable();
|
||||
|
||||
// Initial guess
|
||||
double acceleration = (nextVelocity - currentVelocity) / m_dt;
|
||||
u_k.setValue(
|
||||
ks * Math.signum(currentVelocity)
|
||||
+ kv * currentVelocity
|
||||
+ ka * acceleration
|
||||
+ kg * Math.cos(currentAngle));
|
||||
|
||||
var r_k1 =
|
||||
NumericalIntegration.rk4(
|
||||
(VariableMatrix x, VariableMatrix u) -> {
|
||||
var c =
|
||||
new VariableMatrix(
|
||||
new Variable[][] {
|
||||
{new Variable(0.0)},
|
||||
{signum(x.get(1)).times(-ks / ka).plus(cos(x.get(0)).times(-kg / ka))}
|
||||
});
|
||||
return A.times(x).plus(B.times(u)).plus(c);
|
||||
},
|
||||
r_k,
|
||||
new VariableMatrix(u_k),
|
||||
m_dt);
|
||||
|
||||
// Minimize difference between desired and actual next velocity
|
||||
var cost =
|
||||
new Variable(nextVelocity)
|
||||
.minus(r_k1.get(1))
|
||||
.times(new Variable(nextVelocity).minus(r_k1.get(1)));
|
||||
|
||||
// Refine solution via Newton's method
|
||||
{
|
||||
var xAD = u_k;
|
||||
double x = xAD.value();
|
||||
|
||||
var gradientF = new Gradient(cost, xAD);
|
||||
var g = gradientF.value();
|
||||
|
||||
var hessianF = new Hessian(cost, xAD);
|
||||
var H = hessianF.value();
|
||||
|
||||
double error_k = Double.POSITIVE_INFINITY;
|
||||
double error_k1 = Math.abs(g.get(0, 0));
|
||||
|
||||
// Loop until error stops decreasing or max iterations is reached
|
||||
for (int iteration = 0; iteration < 50 && error_k1 < (1.0 - 1e-10) * error_k; ++iteration) {
|
||||
error_k = error_k1;
|
||||
|
||||
// Iterate via Newton's method.
|
||||
//
|
||||
// xₖ₊₁ = xₖ − H⁻¹g
|
||||
//
|
||||
// The Hessian is regularized to at least 1e-4.
|
||||
double p_x = -g.get(0, 0) / Math.max(H.get(0, 0), 1e-4);
|
||||
|
||||
// Shrink step until cost goes down
|
||||
{
|
||||
double oldCost = cost.value();
|
||||
|
||||
double α = 1.0;
|
||||
double trial_x = x + α * p_x;
|
||||
|
||||
xAD.setValue(trial_x);
|
||||
|
||||
while (cost.value() > oldCost) {
|
||||
α *= 0.5;
|
||||
trial_x = x + α * p_x;
|
||||
|
||||
xAD.setValue(trial_x);
|
||||
}
|
||||
|
||||
x = trial_x;
|
||||
}
|
||||
|
||||
xAD.setValue(x);
|
||||
|
||||
g = gradientF.value();
|
||||
H = hessianF.value();
|
||||
|
||||
error_k1 = Math.abs(g.get(0, 0));
|
||||
}
|
||||
|
||||
hessianF.close();
|
||||
gradientF.close();
|
||||
}
|
||||
|
||||
return u_k.value();
|
||||
}
|
||||
}
|
||||
|
||||
// Rearranging the main equation from the calculate() method yields the
|
||||
|
||||
@@ -4,12 +4,14 @@
|
||||
|
||||
package org.wpilib.math.geometry;
|
||||
|
||||
import static org.wpilib.math.autodiff.Variable.pow;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.units.Units.Meters;
|
||||
|
||||
import java.util.Objects;
|
||||
import org.wpilib.math.geometry.proto.Ellipse2dProto;
|
||||
import org.wpilib.math.geometry.struct.Ellipse2dStruct;
|
||||
import org.wpilib.math.jni.Ellipse2dJNI;
|
||||
import org.wpilib.math.optimization.Problem;
|
||||
import org.wpilib.math.util.Pair;
|
||||
import org.wpilib.units.measure.Distance;
|
||||
import org.wpilib.util.protobuf.ProtobufSerializable;
|
||||
@@ -224,18 +226,38 @@ public class Ellipse2d implements ProtobufSerializable, StructSerializable {
|
||||
return point;
|
||||
}
|
||||
|
||||
// Rotate the point by the inverse of the ellipse's rotation
|
||||
var rotPoint =
|
||||
point.rotateAround(m_center.getTranslation(), m_center.getRotation().unaryMinus());
|
||||
|
||||
// Find nearest point
|
||||
var nearestPoint = new double[2];
|
||||
Ellipse2dJNI.nearest(
|
||||
m_center.getX(),
|
||||
m_center.getY(),
|
||||
m_center.getRotation().getRadians(),
|
||||
m_xSemiAxis,
|
||||
m_ySemiAxis,
|
||||
point.getX(),
|
||||
point.getY(),
|
||||
nearestPoint);
|
||||
return new Translation2d(nearestPoint[0], nearestPoint[1]);
|
||||
try (var problem = new Problem()) {
|
||||
// Point on ellipse
|
||||
var x = problem.decisionVariable();
|
||||
x.setValue(rotPoint.getX());
|
||||
var y = problem.decisionVariable();
|
||||
y.setValue(rotPoint.getY());
|
||||
|
||||
problem.minimize(pow(x.minus(rotPoint.getX()), 2).plus(pow(y.minus(rotPoint.getY()), 2)));
|
||||
|
||||
// (x − x_c)²/a² + (y − y_c)²/b² = 1
|
||||
// b²(x − x_c)² + a²(y − y_c)² = a²b²
|
||||
double a2 = m_xSemiAxis * m_xSemiAxis;
|
||||
double b2 = m_ySemiAxis * m_ySemiAxis;
|
||||
problem.subjectTo(
|
||||
eq(
|
||||
pow(x.minus(m_center.getX()), 2)
|
||||
.times(b2)
|
||||
.plus(pow(y.minus(m_center.getY()), 2).times(a2)),
|
||||
a2 * b2));
|
||||
|
||||
problem.solve();
|
||||
|
||||
rotPoint = new Translation2d(x.value(), y.value());
|
||||
}
|
||||
|
||||
// Undo rotation
|
||||
return rotPoint.rotateAround(m_center.getTranslation(), m_center.getRotation());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.jni;
|
||||
|
||||
/** ArmFeedforward JNI. */
|
||||
public final class ArmFeedforwardJNI extends WPIMathJNI {
|
||||
/**
|
||||
* Obtain a feedforward voltage from a single jointed arm feedforward object.
|
||||
*
|
||||
* <p>Constructs an ArmFeedforward object and runs its currentVelocity and nextVelocity overload
|
||||
*
|
||||
* @param ks The ArmFeedforward's static gain in volts.
|
||||
* @param kv The ArmFeedforward's velocity gain in volt seconds per radian.
|
||||
* @param ka The ArmFeedforward's acceleration gain in volt seconds² per radian.
|
||||
* @param kg The ArmFeedforward's gravity gain in volts.
|
||||
* @param currentAngle The current angle in the calculation in radians.
|
||||
* @param currentVelocity The current velocity in the calculation in radians per second.
|
||||
* @param nextVelocity The next velocity in the calculation in radians per second.
|
||||
* @param dt The time between velocity setpoints in seconds.
|
||||
* @return The calculated feedforward in volts.
|
||||
*/
|
||||
public static native double calculate(
|
||||
double ks,
|
||||
double kv,
|
||||
double ka,
|
||||
double kg,
|
||||
double currentAngle,
|
||||
double currentVelocity,
|
||||
double nextVelocity,
|
||||
double dt);
|
||||
|
||||
/** Utility class. */
|
||||
private ArmFeedforwardJNI() {}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.jni;
|
||||
|
||||
/** Ellipse2d JNI. */
|
||||
public final class Ellipse2dJNI extends WPIMathJNI {
|
||||
/**
|
||||
* Returns the nearest point that is contained within the ellipse.
|
||||
*
|
||||
* <p>Constructs an Ellipse2d object and runs its nearest() method.
|
||||
*
|
||||
* @param centerX The x coordinate of the center of the ellipse in meters.
|
||||
* @param centerY The y coordinate of the center of the ellipse in meters.
|
||||
* @param centerHeading The ellipse's rotation in radians.
|
||||
* @param xSemiAxis The x semi-axis in meters.
|
||||
* @param ySemiAxis The y semi-axis in meters.
|
||||
* @param pointX The x coordinate of the point that this will find the nearest point to.
|
||||
* @param pointY The y coordinate of the point that this will find the nearest point to.
|
||||
* @param nearestPoint Array to store nearest point into.
|
||||
*/
|
||||
public static native void nearest(
|
||||
double centerX,
|
||||
double centerY,
|
||||
double centerHeading,
|
||||
double xSemiAxis,
|
||||
double ySemiAxis,
|
||||
double pointX,
|
||||
double pointY,
|
||||
double[] nearestPoint);
|
||||
|
||||
/** Utility class. */
|
||||
private Ellipse2dJNI() {}
|
||||
}
|
||||
1472
wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java
Normal file
1472
wpimath/src/main/java/org/wpilib/math/optimization/Constraints.java
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
|
||||
/** A vector of equality constraints of the form cₑ(x) = 0. */
|
||||
@SuppressWarnings("PMD.ArrayIsStoredDirectly")
|
||||
public class EqualityConstraints {
|
||||
/** List of equality constraints. */
|
||||
public Variable[] constraints;
|
||||
|
||||
/**
|
||||
* Constructs an EqualityConstraints.
|
||||
*
|
||||
* @param constraints The constraints.
|
||||
*/
|
||||
public EqualityConstraints(Variable[] constraints) {
|
||||
this.constraints = constraints;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
|
||||
/** A vector of inequality constraints of the form cᵢ(x) ≥ 0. */
|
||||
@SuppressWarnings("PMD.ArrayIsStoredDirectly")
|
||||
public class InequalityConstraints {
|
||||
/** List of inequality constraints. */
|
||||
public Variable[] constraints;
|
||||
|
||||
/**
|
||||
* Constructs an InequalityConstraints.
|
||||
*
|
||||
* @param constraints The constraints.
|
||||
*/
|
||||
public InequalityConstraints(Variable[] constraints) {
|
||||
this.constraints = constraints;
|
||||
}
|
||||
}
|
||||
588
wpimath/src/main/java/org/wpilib/math/optimization/OCP.java
Normal file
588
wpimath/src/main/java/org/wpilib/math/optimization/OCP.java
Normal file
@@ -0,0 +1,588 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
import static org.wpilib.math.optimization.Constraints.le;
|
||||
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.BiFunction;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableBlock;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.optimization.ocp.ConstraintEvaluationFunction;
|
||||
import org.wpilib.math.optimization.ocp.DynamicsFunction;
|
||||
import org.wpilib.math.optimization.ocp.DynamicsType;
|
||||
import org.wpilib.math.optimization.ocp.TimestepMethod;
|
||||
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
|
||||
|
||||
/**
|
||||
* This class allows the user to pose and solve a constrained optimal control problem (OCP) in a
|
||||
* variety of ways.
|
||||
*
|
||||
* <p>The system is transcripted by one of three methods (direct transcription, direct collocation,
|
||||
* or single-shooting) and additional constraints can be added.
|
||||
*
|
||||
* <p>In direct transcription, each state is a decision variable constrained to the integrated
|
||||
* dynamics of the previous state. In direct collocation, the trajectory is modeled as a series of
|
||||
* cubic polynomials where the centerpoint slope is constrained. In single-shooting, states depend
|
||||
* explicitly as a function of all previous states and all previous inputs.
|
||||
*
|
||||
* <p>Explicit ODEs are integrated using RK4.
|
||||
*
|
||||
* <p>For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). For discrete state
|
||||
* transition functions, the function must be in the form xₖ₊₁ = f(t, xₖ, uₖ).
|
||||
*
|
||||
* <p>Direct collocation requires an explicit ODE. Direct transcription and single-shooting can use
|
||||
* either an ODE or state transition function.
|
||||
*
|
||||
* <p>https://underactuated.mit.edu/trajopt.html goes into more detail on each transcription method.
|
||||
*/
|
||||
public class OCP extends Problem {
|
||||
private int m_numSteps;
|
||||
|
||||
private DynamicsFunction m_dynamics;
|
||||
private DynamicsType m_dynamicsType;
|
||||
|
||||
private VariableMatrix m_X;
|
||||
private VariableMatrix m_U;
|
||||
private VariableMatrix m_DT;
|
||||
|
||||
/**
|
||||
* Builds an optimization problem using a system evolution function (explicit ODE or discrete
|
||||
* state transition function).
|
||||
*
|
||||
* @param numStates The number of system states.
|
||||
* @param numInputs The number of system inputs.
|
||||
* @param dt The timestep for fixed-step integration.
|
||||
* @param numSteps The number of control points.
|
||||
* @param dynamics Function representing an explicit or implicit ODE, or a discrete state
|
||||
* transition function.
|
||||
* <ul>
|
||||
* <li>Explicit: dx/dt = f(x, u, *)
|
||||
* <li>Implicit: f([x dx/dt]', u, *) = 0
|
||||
* <li>State transition: xₖ₊₁ = f(xₖ, uₖ)
|
||||
* </ul>
|
||||
*
|
||||
* @param dynamicsType The type of system evolution function.
|
||||
* @param timestepMethod The timestep method.
|
||||
* @param transcriptionMethod The transcription method.
|
||||
*/
|
||||
public OCP(
|
||||
int numStates,
|
||||
int numInputs,
|
||||
double dt,
|
||||
int numSteps,
|
||||
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> dynamics,
|
||||
DynamicsType dynamicsType,
|
||||
TimestepMethod timestepMethod,
|
||||
TranscriptionMethod transcriptionMethod) {
|
||||
this(
|
||||
numStates,
|
||||
numInputs,
|
||||
dt,
|
||||
numSteps,
|
||||
(Variable t, VariableMatrix x, VariableMatrix u, Variable _dt) -> dynamics.apply(x, u),
|
||||
dynamicsType,
|
||||
timestepMethod,
|
||||
transcriptionMethod);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an optimization problem using a system evolution function (explicit ODE or discrete
|
||||
* state transition function).
|
||||
*
|
||||
* @param numStates The number of system states.
|
||||
* @param numInputs The number of system inputs.
|
||||
* @param dt The timestep for fixed-step integration.
|
||||
* @param numSteps The number of control points.
|
||||
* @param dynamics Function representing an explicit or implicit ODE, or a discrete state
|
||||
* transition function.
|
||||
* <ul>
|
||||
* <li>Explicit: dx/dt = f(t, x, u, *)
|
||||
* <li>Implicit: f(t, [x dx/dt]', u, *) = 0
|
||||
* <li>State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt)
|
||||
* </ul>
|
||||
*
|
||||
* @param dynamicsType The type of system evolution function.
|
||||
* @param timestepMethod The timestep method.
|
||||
* @param transcriptionMethod The transcription method.
|
||||
*/
|
||||
@SuppressWarnings("this-escape")
|
||||
public OCP(
|
||||
int numStates,
|
||||
int numInputs,
|
||||
double dt,
|
||||
int numSteps,
|
||||
DynamicsFunction dynamics,
|
||||
DynamicsType dynamicsType,
|
||||
TimestepMethod timestepMethod,
|
||||
TranscriptionMethod transcriptionMethod) {
|
||||
m_numSteps = numSteps;
|
||||
m_dynamics = dynamics;
|
||||
m_dynamicsType = dynamicsType;
|
||||
|
||||
// u is numSteps + 1 so that the final constraint function evaluation works
|
||||
m_U = decisionVariable(numInputs, m_numSteps + 1);
|
||||
|
||||
if (timestepMethod == TimestepMethod.FIXED) {
|
||||
m_DT = new VariableMatrix(1, m_numSteps + 1);
|
||||
for (int i = 0; i < numSteps + 1; ++i) {
|
||||
m_DT.set(0, i, dt);
|
||||
}
|
||||
} else if (timestepMethod == TimestepMethod.VARIABLE_SINGLE) {
|
||||
Variable single_dt = decisionVariable();
|
||||
single_dt.setValue(dt);
|
||||
|
||||
// Set the member variable matrix to track the decision variable
|
||||
m_DT = new VariableMatrix(1, m_numSteps + 1);
|
||||
for (int i = 0; i < numSteps + 1; ++i) {
|
||||
m_DT.set(0, i, single_dt);
|
||||
}
|
||||
} else if (timestepMethod == TimestepMethod.VARIABLE) {
|
||||
m_DT = decisionVariable(1, m_numSteps + 1);
|
||||
for (int i = 0; i < numSteps + 1; ++i) {
|
||||
m_DT.get(0, i).setValue(dt);
|
||||
}
|
||||
}
|
||||
|
||||
if (transcriptionMethod == TranscriptionMethod.DIRECT_TRANSCRIPTION) {
|
||||
m_X = decisionVariable(numStates, m_numSteps + 1);
|
||||
constrainDirectTranscription();
|
||||
} else if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) {
|
||||
m_X = decisionVariable(numStates, m_numSteps + 1);
|
||||
constrainDirectCollocation();
|
||||
} else if (transcriptionMethod == TranscriptionMethod.SINGLE_SHOOTING) {
|
||||
// In single-shooting the states aren't decision variables, but instead
|
||||
// depend on the input and previous states
|
||||
m_X = new VariableMatrix(numStates, m_numSteps + 1);
|
||||
constrainSingleShooting();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the initial state.
|
||||
*
|
||||
* @param initialState the initial state to constrain to.
|
||||
*/
|
||||
public void constrainInitialState(double initialState) {
|
||||
subjectTo(eq(this.initialState(), initialState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the initial state.
|
||||
*
|
||||
* @param initialState the initial state to constrain to.
|
||||
*/
|
||||
public void constrainInitialState(Variable initialState) {
|
||||
subjectTo(eq(this.initialState(), initialState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the initial state.
|
||||
*
|
||||
* @param initialState the initial state to constrain to.
|
||||
*/
|
||||
public void constrainInitialState(SimpleMatrix initialState) {
|
||||
subjectTo(eq(this.initialState(), initialState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the initial state.
|
||||
*
|
||||
* @param initialState the initial state to constrain to.
|
||||
*/
|
||||
public void constrainInitialState(VariableMatrix initialState) {
|
||||
subjectTo(eq(this.initialState(), initialState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the initial state.
|
||||
*
|
||||
* @param initialState the initial state to constrain to.
|
||||
*/
|
||||
public void constrainInitialState(VariableBlock initialState) {
|
||||
subjectTo(eq(this.initialState(), initialState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the final state.
|
||||
*
|
||||
* @param finalState the final state to constrain to.
|
||||
*/
|
||||
public void constrainFinalState(double finalState) {
|
||||
subjectTo(eq(this.finalState(), finalState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the final state.
|
||||
*
|
||||
* @param finalState the final state to constrain to.
|
||||
*/
|
||||
public void constrainFinalState(Variable finalState) {
|
||||
subjectTo(eq(this.finalState(), finalState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the final state.
|
||||
*
|
||||
* @param finalState the final state to constrain to.
|
||||
*/
|
||||
public void constrainFinalState(SimpleMatrix finalState) {
|
||||
subjectTo(eq(this.finalState(), finalState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the final state.
|
||||
*
|
||||
* @param finalState the final state to constrain to.
|
||||
*/
|
||||
public void constrainFinalState(VariableMatrix finalState) {
|
||||
subjectTo(eq(this.finalState(), finalState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Constrains the final state.
|
||||
*
|
||||
* @param finalState the final state to constrain to.
|
||||
*/
|
||||
public void constrainFinalState(VariableBlock finalState) {
|
||||
subjectTo(eq(this.finalState(), finalState));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
|
||||
* corresponding state and input VariableMatrices.
|
||||
*
|
||||
* @param callback The callback f(x, u) where x is the state and u is the input vector.
|
||||
*/
|
||||
public void forEachStep(BiConsumer<VariableMatrix, VariableMatrix> callback) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
var x = X().col(i);
|
||||
var u = U().col(i);
|
||||
callback.accept(new VariableMatrix(x), new VariableMatrix(u));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
|
||||
* corresponding state and input VariableMatrices.
|
||||
*
|
||||
* @param callback The callback f(t, x, u, dt) where t is time, x is the state vector, u is the
|
||||
* input vector, and dt is the timestep duration.
|
||||
*/
|
||||
public void forEachStep(ConstraintEvaluationFunction callback) {
|
||||
var time = new Variable(0.0);
|
||||
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
var x = X().col(i);
|
||||
var u = U().col(i);
|
||||
var dt = this.dt().get(0, i);
|
||||
callback.accept(time, new VariableMatrix(x), new VariableMatrix(u), dt);
|
||||
|
||||
time = time.plus(dt);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a lower bound on the input.
|
||||
*
|
||||
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setLowerInputBound(double lowerBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(ge(U().col(i), lowerBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a lower bound on the input.
|
||||
*
|
||||
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setLowerInputBound(Variable lowerBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(ge(U().col(i), lowerBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a lower bound on the input.
|
||||
*
|
||||
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setLowerInputBound(SimpleMatrix lowerBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(ge(U().col(i), lowerBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a lower bound on the input.
|
||||
*
|
||||
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setLowerInputBound(VariableMatrix lowerBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(ge(U().col(i), lowerBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a lower bound on the input.
|
||||
*
|
||||
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setLowerInputBound(VariableBlock lowerBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(ge(U().col(i), lowerBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an upper bound on the input.
|
||||
*
|
||||
* @param upperBound The upper bound that inputs must always be below. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setUpperInputBound(double upperBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(le(U().col(i), upperBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an upper bound on the input.
|
||||
*
|
||||
* @param upperBound The upper bound that inputs must always be below. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setUpperInputBound(Variable upperBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(le(U().col(i), upperBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an upper bound on the input.
|
||||
*
|
||||
* @param upperBound The upper bound that inputs must always be below. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setUpperInputBound(SimpleMatrix upperBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(le(U().col(i), upperBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an upper bound on the input.
|
||||
*
|
||||
* @param upperBound The upper bound that inputs must always be below. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setUpperInputBound(VariableMatrix upperBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(le(U().col(i), upperBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an upper bound on the input.
|
||||
*
|
||||
* @param upperBound The upper bound that inputs must always be below. Must be shaped
|
||||
* (numInputs)x1.
|
||||
*/
|
||||
public void setUpperInputBound(VariableBlock upperBound) {
|
||||
for (int i = 0; i < m_numSteps + 1; ++i) {
|
||||
subjectTo(le(U().col(i), upperBound));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a lower bound on the timestep.
|
||||
*
|
||||
* @param minTimestep The minimum timestep in seconds.
|
||||
*/
|
||||
public void setMinTimestep(double minTimestep) {
|
||||
subjectTo(ge(dt(), minTimestep));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an upper bound on the timestep.
|
||||
*
|
||||
* @param maxTimestep The maximum timestep in seconds.
|
||||
*/
|
||||
public void setMaxTimestep(double maxTimestep) {
|
||||
subjectTo(le(dt(), maxTimestep));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the state variables. After the problem is solved, this will contain the optimized
|
||||
* trajectory.
|
||||
*
|
||||
* <p>Shaped (numStates)x(numSteps+1).
|
||||
*
|
||||
* @return The state variable matrix.
|
||||
*/
|
||||
public VariableMatrix X() {
|
||||
return m_X;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the input variables. After the problem is solved, this will contain the inputs
|
||||
* corresponding to the optimized trajectory.
|
||||
*
|
||||
* <p>Shaped (numInputs)x(numSteps+1), although the last input step is unused in the trajectory.
|
||||
*
|
||||
* @return The input variable matrix.
|
||||
*/
|
||||
public VariableMatrix U() {
|
||||
return m_U;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the timestep variables. After the problem is solved, this will contain the timesteps
|
||||
* corresponding to the optimized trajectory.
|
||||
*
|
||||
* <p>Shaped 1x(numSteps+1), although the last timestep is unused in the trajectory.
|
||||
*
|
||||
* @return The timestep variable matrix.
|
||||
*/
|
||||
public VariableMatrix dt() {
|
||||
return m_DT;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the initial state in the trajectory.
|
||||
*
|
||||
* @return The initial state of the trajectory.
|
||||
*/
|
||||
public VariableMatrix initialState() {
|
||||
return new VariableMatrix(m_X.col(0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the final state in the trajectory.
|
||||
*
|
||||
* @return The final state of the trajectory.
|
||||
*/
|
||||
public VariableMatrix finalState() {
|
||||
return new VariableMatrix(m_X.col(m_numSteps));
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt.
|
||||
*
|
||||
* @param f The function to integrate. It must take two arguments x and u.
|
||||
* @param x The initial value of x.
|
||||
* @param u The value u held constant over the integration period.
|
||||
* @param t0 The initial time.
|
||||
* @param dt The time over which to integrate.
|
||||
*/
|
||||
private static VariableMatrix rk4(
|
||||
DynamicsFunction f, VariableMatrix x, VariableMatrix u, Variable t0, Variable dt) {
|
||||
var halfdt = dt.times(0.5);
|
||||
VariableMatrix k1 = f.apply(t0, x, u, dt);
|
||||
VariableMatrix k2 = f.apply(t0.plus(halfdt), x.plus(k1.times(halfdt)), u, dt);
|
||||
VariableMatrix k3 = f.apply(t0.plus(halfdt), x.plus(k2.times(halfdt)), u, dt);
|
||||
VariableMatrix k4 = f.apply(t0.plus(dt), x.plus(k3.times(dt)), u, dt);
|
||||
|
||||
return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(dt.div(6.0)));
|
||||
}
|
||||
|
||||
/** Applies direct collocation dynamics constraints. */
|
||||
private void constrainDirectCollocation() {
|
||||
assert m_dynamicsType == DynamicsType.EXPLICIT_ODE;
|
||||
|
||||
var time = new Variable(0.0);
|
||||
|
||||
// Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/
|
||||
for (int i = 0; i < m_numSteps; ++i) {
|
||||
Variable h = dt().get(0, i);
|
||||
|
||||
var f = m_dynamics;
|
||||
|
||||
var t_begin = time;
|
||||
var t_end = t_begin.plus(h);
|
||||
|
||||
var x_begin = X().col(i);
|
||||
var x_end = X().col(i + 1);
|
||||
|
||||
var u_begin = U().col(i);
|
||||
var u_end = U().col(i + 1);
|
||||
|
||||
var xdot_begin =
|
||||
f.apply(t_begin, new VariableMatrix(x_begin), new VariableMatrix(u_begin), h);
|
||||
var xdot_end = f.apply(t_end, new VariableMatrix(x_end), new VariableMatrix(u_end), h);
|
||||
var xdot_c =
|
||||
x_begin
|
||||
.minus(x_end)
|
||||
.times(new Variable(-3).div(h.times(2)))
|
||||
.minus(xdot_begin.plus(xdot_end).times(0.25));
|
||||
|
||||
var t_c = t_begin.plus(h.times(0.5));
|
||||
var x_c = x_begin.plus(x_end).times(0.5).plus(xdot_begin.minus(xdot_end).times(h.div(8)));
|
||||
var u_c = u_begin.plus(u_end).times(0.5);
|
||||
|
||||
subjectTo(eq(xdot_c, f.apply(t_c, x_c, u_c, h)));
|
||||
|
||||
time = time.plus(h);
|
||||
}
|
||||
}
|
||||
|
||||
/** Applies direct transcription dynamics constraints. */
|
||||
private void constrainDirectTranscription() {
|
||||
var time = new Variable(0.0);
|
||||
|
||||
for (int i = 0; i < m_numSteps; ++i) {
|
||||
var x_begin = X().col(i);
|
||||
var x_end = X().col(i + 1);
|
||||
var u = U().col(i);
|
||||
Variable dt = this.dt().get(0, i);
|
||||
|
||||
if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
|
||||
subjectTo(
|
||||
eq(
|
||||
x_end,
|
||||
rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt)));
|
||||
} else if (m_dynamicsType == DynamicsType.DISCRETE) {
|
||||
subjectTo(
|
||||
eq(
|
||||
x_end,
|
||||
m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt)));
|
||||
}
|
||||
|
||||
time = time.plus(dt);
|
||||
}
|
||||
}
|
||||
|
||||
/** Applies single shooting dynamics constraints. */
|
||||
private void constrainSingleShooting() {
|
||||
var time = new Variable(0.0);
|
||||
|
||||
for (int i = 0; i < m_numSteps; ++i) {
|
||||
var x_begin = X().col(i);
|
||||
var x_end = X().col(i + 1);
|
||||
var u = U().col(i);
|
||||
Variable dt = this.dt().get(0, i);
|
||||
|
||||
if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
|
||||
x_end.set(rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt));
|
||||
} else if (m_dynamicsType == DynamicsType.DISCRETE) {
|
||||
x_end.set(m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt));
|
||||
}
|
||||
|
||||
time = time.plus(dt);
|
||||
}
|
||||
}
|
||||
}
|
||||
312
wpimath/src/main/java/org/wpilib/math/optimization/Problem.java
Normal file
312
wpimath/src/main/java/org/wpilib/math/optimization/Problem.java
Normal file
@@ -0,0 +1,312 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.function.Predicate;
|
||||
import org.ejml.data.DMatrixRMaj;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.NativeSparseTriplets;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.autodiff.VariablePool;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
import org.wpilib.math.optimization.solver.IterationInfo;
|
||||
import org.wpilib.math.optimization.solver.Options;
|
||||
|
||||
/**
|
||||
* This class allows the user to pose a constrained nonlinear optimization problem in natural
|
||||
* mathematical notation and solve it.
|
||||
*
|
||||
* <p>This class supports problems of the form:
|
||||
*
|
||||
* <pre>
|
||||
* minₓ f(x)
|
||||
* subject to cₑ(x) = 0
|
||||
* cᵢ(x) ≥ 0
|
||||
* </pre>
|
||||
*
|
||||
* <p>where f(x) is the scalar cost function, x is the vector of decision variables (variables the
|
||||
* solver can tweak to minimize the cost function), cᵢ(x) are the inequality constraints, and cₑ(x)
|
||||
* are the equality constraints. Constraints are equations or inequalities of the decision variables
|
||||
* that constrain what values the solver is allowed to use when searching for an optimal solution.
|
||||
*
|
||||
* <p>The nice thing about this class is users don't have to put their system in the form shown
|
||||
* above manually; they can write it in natural mathematical form and it'll be converted for them.
|
||||
*/
|
||||
public class Problem implements AutoCloseable {
|
||||
private long m_handle;
|
||||
|
||||
// The iteration callbacks
|
||||
private final ArrayList<Predicate<IterationInfo>> m_iterationCallbacks = new ArrayList<>();
|
||||
|
||||
// Cleans up Variables allocated within Problem's scope
|
||||
private final VariablePool m_pool = new VariablePool();
|
||||
|
||||
/** Construct the optimization problem. */
|
||||
@SuppressWarnings("this-escape")
|
||||
public Problem() {
|
||||
m_handle = ProblemJNI.create();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (m_handle != 0) {
|
||||
ProblemJNI.destroy(m_handle);
|
||||
m_handle = 0;
|
||||
|
||||
m_pool.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a decision variable in the optimization problem.
|
||||
*
|
||||
* <p>Decision variables have an initial value of zero.
|
||||
*
|
||||
* @return A decision variable in the optimization problem.
|
||||
*/
|
||||
public Variable decisionVariable() {
|
||||
var handles = ProblemJNI.decisionVariable(m_handle, 1, 1);
|
||||
return new Variable(Variable.HANDLE, handles[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a column vector of decision variables in the optimization problem.
|
||||
*
|
||||
* <p>Decision variables have an initial value of zero.
|
||||
*
|
||||
* @param rows Number of column vector rows.
|
||||
* @return A column vector of decision variables in the optimization problem.
|
||||
*/
|
||||
public VariableMatrix decisionVariable(int rows) {
|
||||
return decisionVariable(rows, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a matrix of decision variables in the optimization problem.
|
||||
*
|
||||
* <p>Decision variables have an initial value of zero.
|
||||
*
|
||||
* @param rows Number of matrix rows.
|
||||
* @param cols Number of matrix columns.
|
||||
* @return A matrix of decision variables in the optimization problem.
|
||||
*/
|
||||
public VariableMatrix decisionVariable(int rows, int cols) {
|
||||
return new VariableMatrix(rows, cols, ProblemJNI.decisionVariable(m_handle, rows, cols));
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a symmetric matrix of decision variables in the optimization problem.
|
||||
*
|
||||
* <p>Variable instances are reused across the diagonal, which helps reduce problem
|
||||
* dimensionality.
|
||||
*
|
||||
* <p>Decision variables have an initial value of zero.
|
||||
*
|
||||
* @param rows Number of matrix rows.
|
||||
* @return A symmetric matrix of decision varaibles in the optimization problem.
|
||||
*/
|
||||
public VariableMatrix symmetricDecisionVariable(int rows) {
|
||||
return new VariableMatrix(rows, rows, ProblemJNI.symmetricDecisionVariable(m_handle, rows));
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the solver to minimize the output of the given cost function.
|
||||
*
|
||||
* <p>Note that this is optional. If only constraints are specified, the solver will find the
|
||||
* closest solution to the initial conditions that's in the feasible set.
|
||||
*
|
||||
* @param cost The cost function to minimize.
|
||||
*/
|
||||
public void minimize(Variable cost) {
|
||||
ProblemJNI.minimize(m_handle, cost.getHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the solver to minimize the output of the given cost function.
|
||||
*
|
||||
* <p>Note that this is optional. If only constraints are specified, the solver will find the
|
||||
* closest solution to the initial conditions that's in the feasible set.
|
||||
*
|
||||
* @param cost The cost function to minimize. An assertion is raised if the VariableMatrix isn't
|
||||
* 1x1.
|
||||
*/
|
||||
public void minimize(VariableMatrix cost) {
|
||||
assert cost.rows() == 1 && cost.cols() == 1;
|
||||
minimize(cost.get(0, 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the solver to maximize the output of the given objective function.
|
||||
*
|
||||
* <p>Note that this is optional. If only constraints are specified, the solver will find the
|
||||
* closest solution to the initial conditions that's in the feasible set.
|
||||
*
|
||||
* @param objective The objective function to maximize.
|
||||
*/
|
||||
public void maximize(Variable objective) {
|
||||
ProblemJNI.maximize(m_handle, objective.getHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the solver to maximize the output of the given objective function.
|
||||
*
|
||||
* <p>Note that this is optional. If only constraints are specified, the solver will find the
|
||||
* closest solution to the initial conditions that's in the feasible set.
|
||||
*
|
||||
* @param objective The objective function to maximize. An assertion is raised if the
|
||||
* VariableMatrix isn't 1x1.
|
||||
*/
|
||||
public void maximize(VariableMatrix objective) {
|
||||
assert objective.rows() == 1 && objective.cols() == 1;
|
||||
maximize(objective.get(0, 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the solver to solve the problem while satisfying the given equality constraint.
|
||||
*
|
||||
* @param constraint The constraint to satisfy.
|
||||
*/
|
||||
public void subjectTo(EqualityConstraints constraint) {
|
||||
var constraintHandles = new long[constraint.constraints.length];
|
||||
for (int i = 0; i < constraintHandles.length; ++i) {
|
||||
constraintHandles[i] = constraint.constraints[i].getHandle();
|
||||
}
|
||||
ProblemJNI.subjectToEq(m_handle, constraintHandles);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the solver to solve the problem while satisfying the given inequality constraint.
|
||||
*
|
||||
* @param constraint The constraint to satisfy.
|
||||
*/
|
||||
public void subjectTo(InequalityConstraints constraint) {
|
||||
var constraintHandles = new long[constraint.constraints.length];
|
||||
for (int i = 0; i < constraintHandles.length; ++i) {
|
||||
constraintHandles[i] = constraint.constraints[i].getHandle();
|
||||
}
|
||||
ProblemJNI.subjectToIneq(m_handle, constraintHandles);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the cost function's type.
|
||||
*
|
||||
* @return The cost function's type.
|
||||
*/
|
||||
public ExpressionType costFunctionType() {
|
||||
return ExpressionType.fromInt(ProblemJNI.costFunctionType(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the type of the highest order equality constraint.
|
||||
*
|
||||
* @return The type of the highest order equality constraint.
|
||||
*/
|
||||
public ExpressionType equalityConstraintType() {
|
||||
return ExpressionType.fromInt(ProblemJNI.equalityConstraintType(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the type of the highest order inequality constraint.
|
||||
*
|
||||
* @return The type of the highest order inequality constraint.
|
||||
*/
|
||||
public ExpressionType inequalityConstraintType() {
|
||||
return ExpressionType.fromInt(ProblemJNI.inequalityConstraintType(m_handle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves the optimization problem. The solution will be stored in the original variables used to
|
||||
* construct the problem.
|
||||
*
|
||||
* @return The solver status.
|
||||
*/
|
||||
public ExitStatus solve() {
|
||||
return solve(new Options());
|
||||
}
|
||||
|
||||
/**
|
||||
* Solves the optimization problem. The solution will be stored in the original variables used to
|
||||
* construct the problem.
|
||||
*
|
||||
* @param options Solver options.
|
||||
* @return The solver status.
|
||||
*/
|
||||
public ExitStatus solve(Options options) {
|
||||
return ExitStatus.fromInt(
|
||||
ProblemJNI.solve(
|
||||
this,
|
||||
m_handle,
|
||||
options.tolerance,
|
||||
options.maxIterations,
|
||||
options.timeout,
|
||||
options.feasibleIPM,
|
||||
options.diagnostics));
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a callback to be called at the beginning of each solver iteration.
|
||||
*
|
||||
* <p>The callback for this overload should return bool.
|
||||
*
|
||||
* @param callback The callback. Returning true from the callback causes the solver to exit early
|
||||
* with the solution it has so far.
|
||||
*/
|
||||
public void addCallback(Predicate<IterationInfo> callback) {
|
||||
m_iterationCallbacks.add(callback);
|
||||
}
|
||||
|
||||
/** Clears the registered callbacks. */
|
||||
public void clearCallbacks() {
|
||||
m_iterationCallbacks.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the registered callbacks.
|
||||
*
|
||||
* <p>This function is called by native code in ProblemJNI.
|
||||
*
|
||||
* @param numEqualityConstraints The number of equality constraints.
|
||||
* @param numInequalityConstraints The number of inequality constraints.
|
||||
* @param iteration The solver iteration.
|
||||
* @param x The decision variable values.
|
||||
* @param gTriplets Gradient triplets.
|
||||
* @param HTriplets Hessian triplets.
|
||||
* @param A_eTriplets Equality constraint Jacobian triplets.
|
||||
* @param A_iTriplets Inequality constraint Jacobian triplets.
|
||||
* @return True if the solver shold exit early.
|
||||
*/
|
||||
boolean runCallbacks(
|
||||
int numEqualityConstraints,
|
||||
int numInequalityConstraints,
|
||||
int iteration,
|
||||
double[] x,
|
||||
NativeSparseTriplets gTriplets,
|
||||
NativeSparseTriplets HTriplets,
|
||||
NativeSparseTriplets A_eTriplets,
|
||||
NativeSparseTriplets A_iTriplets) {
|
||||
if (m_iterationCallbacks.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var info =
|
||||
new IterationInfo(
|
||||
iteration,
|
||||
new SimpleMatrix(DMatrixRMaj.wrap(x.length, 1, x)),
|
||||
gTriplets.toSimpleMatrix(x.length, 1),
|
||||
HTriplets.toSimpleMatrix(x.length, x.length),
|
||||
A_eTriplets.toSimpleMatrix(numEqualityConstraints, x.length),
|
||||
A_iTriplets.toSimpleMatrix(numInequalityConstraints, x.length));
|
||||
|
||||
for (var callback : m_iterationCallbacks) {
|
||||
if (callback.test(info)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import org.wpilib.math.jni.WPIMathJNI;
|
||||
|
||||
/** Problem JNI functions. */
|
||||
final class ProblemJNI extends WPIMathJNI {
|
||||
private ProblemJNI() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/** Construct the optimization problem. */
|
||||
static native long create();
|
||||
|
||||
/**
|
||||
* Destruct the optimization problem.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
*/
|
||||
static native void destroy(long handle);
|
||||
|
||||
/**
|
||||
* Create a matrix of decision variables in the optimization problem.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @param rows Number of matrix rows.
|
||||
* @param cols Number of matrix columns.
|
||||
* @return A matrix of decision variables in the optimization problem.
|
||||
*/
|
||||
static native long[] decisionVariable(long handle, int rows, int cols);
|
||||
|
||||
/**
|
||||
* Create a symmetric matrix of decision variables in the optimization problem.
|
||||
*
|
||||
* <p>Variable instances are reused across the diagonal, which helps reduce problem
|
||||
* dimensionality.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @param rows Number of matrix rows.
|
||||
* @return A symmetric matrix of decision varaibles in the optimization problem.
|
||||
*/
|
||||
static native long[] symmetricDecisionVariable(long handle, int rows);
|
||||
|
||||
/**
|
||||
* Tells the solver to minimize the output of the given cost function.
|
||||
*
|
||||
* <p>Note that this is optional. If only constraints are specified, the solver will find the
|
||||
* closest solution to the initial conditions that's in the feasible set.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @param costHandle Variable handle of the cost function to minimize.
|
||||
*/
|
||||
static native void minimize(long handle, long costHandle);
|
||||
|
||||
/**
|
||||
* Tells the solver to maximize the output of the given objective function.
|
||||
*
|
||||
* <p>Note that this is optional. If only constraints are specified, the solver will find the
|
||||
* closest solution to the initial conditions that's in the feasible set.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @param objectiveHandle Variable handle of the objective function to maximize.
|
||||
*/
|
||||
static native void maximize(long handle, long objectiveHandle);
|
||||
|
||||
/**
|
||||
* Tells the solver to solve the problem while satisfying the given equality constraint.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @param constraintHandles Constraint handles.
|
||||
*/
|
||||
static native void subjectToEq(long handle, long[] constraintHandles);
|
||||
|
||||
/**
|
||||
* Tells the solver to solve the problem while satisfying the given inequality constraint.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @param constraintHandles Constraint handles.
|
||||
*/
|
||||
static native void subjectToIneq(long handle, long[] constraintHandles);
|
||||
|
||||
/**
|
||||
* Returns the cost function's type.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @return The cost function's type.
|
||||
*/
|
||||
static native int costFunctionType(long handle);
|
||||
|
||||
/**
|
||||
* Returns the type of the highest order equality constraint.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @return The type of the highest order equality constraint.
|
||||
*/
|
||||
static native int equalityConstraintType(long handle);
|
||||
|
||||
/**
|
||||
* Returns the type of the highest order inequality constraint.
|
||||
*
|
||||
* @param handle Problem handle.
|
||||
* @return The type of the highest order inequality constraint.
|
||||
*/
|
||||
static native int inequalityConstraintType(long handle);
|
||||
|
||||
/**
|
||||
* Solve the optimization problem. The solution will be stored in the original variables used to
|
||||
* construct the problem.
|
||||
*
|
||||
* @param obj Java Problem object.
|
||||
* @param handle Problem handle.
|
||||
* @param tolerance The solver will stop once the error is below this tolerance.
|
||||
* @param maxIterations The maximum number of solver iterations before returning a solution.
|
||||
* @param timeout The maximum elapsed wall clock time in seconds before returning a solution.
|
||||
* @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints
|
||||
* are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible
|
||||
* again. This is useful when parts of the problem are ill-conditioned in infeasible regions
|
||||
* (e.g., square root of a negative value). This can slow or prevent progress toward a
|
||||
* solution though, so only enable it if necessary.
|
||||
* @param diagnnostics Enables diagnostic prints.
|
||||
* @return The solver status.
|
||||
*/
|
||||
static native int solve(
|
||||
Problem obj,
|
||||
long handle,
|
||||
double tolerance,
|
||||
int maxIterations,
|
||||
double timeout,
|
||||
boolean feasibleIPM,
|
||||
boolean diagnostics);
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.ocp;
|
||||
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
|
||||
/**
|
||||
* A callback f(t, x, u, dt) where t is time, x is the state vector, u is the input vector, and dt
|
||||
* is the timestep duration.
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface ConstraintEvaluationFunction {
|
||||
/**
|
||||
* Applies this function with the arguments.
|
||||
*
|
||||
* @param t Time in seconds.
|
||||
* @param x State vector.
|
||||
* @param u Input vector.
|
||||
* @param dt Timestep duration in seconds.
|
||||
*/
|
||||
void accept(Variable t, VariableMatrix x, VariableMatrix u, Variable dt);
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.ocp;
|
||||
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
|
||||
/**
|
||||
* Function representing an explicit or implicit ODE, or a discrete state transition function.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Explicit: dx/dt = f(t, x, u, *)
|
||||
* <li>Implicit: f(t, [x dx/dt]', u, *) = 0
|
||||
* <li>State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt)
|
||||
* </ul>
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface DynamicsFunction {
|
||||
/**
|
||||
* Applies this function with the arguments and returns the result.
|
||||
*
|
||||
* @param t Time in seconds.
|
||||
* @param x State vector.
|
||||
* @param u Input vector.
|
||||
* @param dt Timestep duration in seconds.
|
||||
* @return The state derivative dx/dt.
|
||||
*/
|
||||
VariableMatrix apply(Variable t, VariableMatrix x, VariableMatrix u, Variable dt);
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.ocp;
|
||||
|
||||
/** Enum describing a type of system dynamics constraints. */
|
||||
public enum DynamicsType {
|
||||
/** The dynamics are a function in the form dx/dt = f(t, x, u). */
|
||||
EXPLICIT_ODE(0),
|
||||
/** The dynamics are a function in the form xₖ₊₁ = f(t, xₖ, uₖ). */
|
||||
DISCRETE(1);
|
||||
|
||||
/** DynamicsType value. */
|
||||
public final int value;
|
||||
|
||||
DynamicsType(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.ocp;
|
||||
|
||||
/** Enum describing the type of system timestep. */
|
||||
public enum TimestepMethod {
|
||||
/** The timestep is a fixed constant. */
|
||||
FIXED(0),
|
||||
/** The timesteps are allowed to vary as independent decision variables. */
|
||||
VARIABLE(1),
|
||||
/** The timesteps are equal length but allowed to vary as a single decision variable. */
|
||||
VARIABLE_SINGLE(2);
|
||||
|
||||
/** TimestepMethod value. */
|
||||
public final int value;
|
||||
|
||||
TimestepMethod(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.ocp;
|
||||
|
||||
/** Enum describing an OCP transcription method. */
|
||||
public enum TranscriptionMethod {
|
||||
/**
|
||||
* Each state is a decision variable constrained to the integrated dynamics of the previous state.
|
||||
*/
|
||||
DIRECT_TRANSCRIPTION(0),
|
||||
/**
|
||||
* The trajectory is modeled as a series of cubic polynomials where the centerpoint slope is
|
||||
* constrained.
|
||||
*/
|
||||
DIRECT_COLLOCATION(1),
|
||||
/** States depend explicitly as a function of all previous states and all previous inputs. */
|
||||
SINGLE_SHOOTING(2);
|
||||
|
||||
/** TranscriptionMethod value. */
|
||||
public final int value;
|
||||
|
||||
TranscriptionMethod(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.solver;
|
||||
|
||||
/** Solver exit status. Negative values indicate failure. */
|
||||
public enum ExitStatus {
|
||||
/** Solved the problem to the desired tolerance. */
|
||||
SUCCESS(0),
|
||||
/** The solver returned its solution so far after the user requested a stop. */
|
||||
CALLBACK_REQUESTED_STOP(1),
|
||||
/** The solver determined the problem to be overconstrained and gave up. */
|
||||
TOO_FEW_DOFS(-1),
|
||||
/** The solver determined the problem to be locally infeasible and gave up. */
|
||||
LOCALLY_INFEASIBLE(-2),
|
||||
/** The problem setup frontend determined the problem to have an empty feasible region. */
|
||||
GLOBALLY_INFEASIBLE(-3),
|
||||
/** The linear system factorization failed. */
|
||||
FACTORIZATION_FAILED(-4),
|
||||
/**
|
||||
* The solver failed to reach the desired tolerance, and feasibility restoration failed to
|
||||
* converge.
|
||||
*/
|
||||
FEASIBILITY_RESTORATION_FAILED(-5),
|
||||
/** The solver encountered nonfinite initial cost, constraints, or derivatives and gave up. */
|
||||
NONFINITE_INITIAL_GUESS(-6),
|
||||
/** The solver encountered diverging primal iterates xₖ and/or sₖ and gave up. */
|
||||
DIVERGING_ITERATES(-7),
|
||||
/** The solver returned its solution so far after exceeding the maximum number of iterations. */
|
||||
MAX_ITERATIONS_EXCEEDED(-8),
|
||||
/**
|
||||
* The solver returned its solution so far after exceeding the maximum elapsed wall clock time.
|
||||
*/
|
||||
TIMEOUT(-9);
|
||||
|
||||
/** ExitStatus value. */
|
||||
public final int value;
|
||||
|
||||
ExitStatus(int value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts integer to its corresponding enum value.
|
||||
*
|
||||
* @param x The integer.
|
||||
* @return The enum value.
|
||||
*/
|
||||
public static ExitStatus fromInt(int x) {
|
||||
return switch (x) {
|
||||
case 0 -> ExitStatus.SUCCESS;
|
||||
case 1 -> ExitStatus.CALLBACK_REQUESTED_STOP;
|
||||
case -1 -> ExitStatus.TOO_FEW_DOFS;
|
||||
case -2 -> ExitStatus.LOCALLY_INFEASIBLE;
|
||||
case -3 -> ExitStatus.GLOBALLY_INFEASIBLE;
|
||||
case -4 -> ExitStatus.FACTORIZATION_FAILED;
|
||||
case -5 -> ExitStatus.FEASIBILITY_RESTORATION_FAILED;
|
||||
case -6 -> ExitStatus.NONFINITE_INITIAL_GUESS;
|
||||
case -7 -> ExitStatus.DIVERGING_ITERATES;
|
||||
case -8 -> ExitStatus.MAX_ITERATIONS_EXCEEDED;
|
||||
case -9 -> ExitStatus.TIMEOUT;
|
||||
default -> null;
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.solver;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
/** Solver iteration information exposed to an iteration callback. */
|
||||
public class IterationInfo {
|
||||
/** The solver iteration. */
|
||||
public final int iteration;
|
||||
|
||||
/** The decision variables (dense internal storage). */
|
||||
public final SimpleMatrix x;
|
||||
|
||||
/** The gradient of the cost function (sparse internal storage). */
|
||||
public final SimpleMatrix g;
|
||||
|
||||
/** The Hessian of the Lagrangian (sparse internal storage). */
|
||||
public final SimpleMatrix H;
|
||||
|
||||
/** The equality constraint Jacobian (sparse internal storage). */
|
||||
public final SimpleMatrix A_e;
|
||||
|
||||
/** The inequality constraint Jacobian (sparse internal storage). */
|
||||
public final SimpleMatrix A_i;
|
||||
|
||||
/**
|
||||
* Constructs iteration info.
|
||||
*
|
||||
* @param iteration The solver iteration.
|
||||
* @param x The decision variables (dense internal storage).
|
||||
* @param g The gradient of the cost function (sparse internal storage).
|
||||
* @param H The Hessian of the Lagrangian (sparse internal storage).
|
||||
* @param A_e The equality constraint Jacobian (sparse internal storage).
|
||||
* @param A_i The inequality constraint Jacobian (sparse internal storage).
|
||||
*/
|
||||
public IterationInfo(
|
||||
int iteration,
|
||||
SimpleMatrix x,
|
||||
SimpleMatrix g,
|
||||
SimpleMatrix H,
|
||||
SimpleMatrix A_e,
|
||||
SimpleMatrix A_i) {
|
||||
this.iteration = iteration;
|
||||
this.x = x;
|
||||
this.g = g;
|
||||
this.H = H;
|
||||
this.A_e = A_e;
|
||||
this.A_i = A_i;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.solver;
|
||||
|
||||
/** Solver options. */
|
||||
public class Options {
|
||||
/** The solver will stop once the error is below this tolerance. */
|
||||
public double tolerance = 1e-8;
|
||||
|
||||
/** The maximum number of solver iterations before returning a solution. */
|
||||
public int maxIterations = 5000;
|
||||
|
||||
/** The maximum elapsed wall clock time in seconds before returning a solution. */
|
||||
public double timeout = Double.POSITIVE_INFINITY;
|
||||
|
||||
/**
|
||||
* Enables the feasible interior-point method.
|
||||
*
|
||||
* <p>When the inequality constraints are all feasible, step sizes are reduced when necessary to
|
||||
* prevent them becoming infeasible again. This is useful when parts of the problem are
|
||||
* ill-conditioned in infeasible regions (e.g., square root of a negative value). This can slow or
|
||||
* prevent progress toward a solution though, so only enable it if necessary.
|
||||
*/
|
||||
public boolean feasibleIPM = false;
|
||||
|
||||
/**
|
||||
* Enables diagnostic output.
|
||||
*
|
||||
* <p>See <a
|
||||
* href="https://sleipnirgroup.github.io/Sleipnir/md_usage.html#output">https://sleipnirgroup.github.io/Sleipnir/md_usage.html#output</a>
|
||||
* for more information.
|
||||
*/
|
||||
public boolean diagnostics = false;
|
||||
|
||||
/** Default options. */
|
||||
public Options() {}
|
||||
|
||||
/**
|
||||
* Set tolerance.
|
||||
*
|
||||
* @param tolerance The solver will stop once the error is below this tolerance.
|
||||
* @return This Options object.
|
||||
*/
|
||||
public Options withTolerance(double tolerance) {
|
||||
this.tolerance = tolerance;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set max iterations.
|
||||
*
|
||||
* @param maxIterations The maximum number of solver iterations before returning a solution.
|
||||
* @return This Options object.
|
||||
*/
|
||||
public Options withMaxIterations(int maxIterations) {
|
||||
this.maxIterations = maxIterations;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set timeout.
|
||||
*
|
||||
* @param timeout The maximum elapsed wall clock time in seconds before returning a solution.
|
||||
* @return This Options object.
|
||||
*/
|
||||
public Options withTimeout(double timeout) {
|
||||
this.timeout = timeout;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable or disable feasible IPM.
|
||||
*
|
||||
* @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints
|
||||
* are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible
|
||||
* again. This is useful when parts of the problem are ill-conditioned in infeasible regions
|
||||
* (e.g., square root of a negative value). This can slow or prevent progress toward a
|
||||
* solution though, so only enable it if necessary.
|
||||
* @return This Options object.
|
||||
*/
|
||||
public Options withFeasibleIPM(boolean feasibleIPM) {
|
||||
this.feasibleIPM = feasibleIPM;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable or disable diagnostics.
|
||||
*
|
||||
* @param diagnostics Enables diagnostic prints.
|
||||
* @return This Options object.
|
||||
*/
|
||||
public Options withDiagnostics(boolean diagnostics) {
|
||||
this.diagnostics = diagnostics;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import java.util.function.BiFunction;
|
||||
import java.util.function.DoubleBinaryOperator;
|
||||
import java.util.function.DoubleUnaryOperator;
|
||||
import java.util.function.UnaryOperator;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.wpilib.math.linalg.Matrix;
|
||||
import org.wpilib.math.numbers.N1;
|
||||
import org.wpilib.math.util.Num;
|
||||
@@ -56,6 +57,30 @@ public final class NumericalIntegration {
|
||||
return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
|
||||
*
|
||||
* @param f The function to integrate. It must take two arguments x and u.
|
||||
* @param x The initial value of x.
|
||||
* @param u The value u held constant over the integration period.
|
||||
* @param dt The time over which to integrate.
|
||||
* @return the integration of dx/dt = f(x, u) for dt.
|
||||
*/
|
||||
public static SimpleMatrix rk4(
|
||||
BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> f,
|
||||
SimpleMatrix x,
|
||||
SimpleMatrix u,
|
||||
double dt) {
|
||||
var h = dt;
|
||||
|
||||
var k1 = f.apply(x, u);
|
||||
var k2 = f.apply(x.plus(k1.scale(h * 0.5)), u);
|
||||
var k3 = f.apply(x.plus(k2.scale(h * 0.5)), u);
|
||||
var k4 = f.apply(x.plus(k3.scale(h)), u);
|
||||
|
||||
return x.plus(k1.plus(k2.scale(2.0)).plus(k3.scale(2.0)).plus(k4).scale(h / 6.0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
|
||||
*
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "org_wpilib_math_jni_ArmFeedforwardJNI.h"
|
||||
#include "wpi/math/controller/ArmFeedforward.hpp"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_jni_ArmFeedforwardJNI
|
||||
* Method: calculate
|
||||
* Signature: (DDDDDDDD)D
|
||||
*/
|
||||
JNIEXPORT jdouble JNICALL
|
||||
Java_org_wpilib_math_jni_ArmFeedforwardJNI_calculate
|
||||
(JNIEnv* env, jclass, jdouble ks, jdouble kv, jdouble ka, jdouble kg,
|
||||
jdouble currentAngle, jdouble currentVelocity, jdouble nextVelocity,
|
||||
jdouble dt)
|
||||
{
|
||||
return wpi::math::ArmFeedforward{
|
||||
wpi::units::volt_t{ks}, wpi::units::volt_t{kg},
|
||||
wpi::units::unit_t<wpi::math::ArmFeedforward::kv_unit>{kv},
|
||||
wpi::units::unit_t<wpi::math::ArmFeedforward::ka_unit>{ka},
|
||||
wpi::units::second_t{dt}}
|
||||
.Calculate(wpi::units::radian_t{currentAngle},
|
||||
wpi::units::radians_per_second_t{currentVelocity},
|
||||
wpi::units::radians_per_second_t{nextVelocity})
|
||||
.value();
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "org_wpilib_math_jni_Ellipse2dJNI.h"
|
||||
#include "wpi/math/geometry/Ellipse2d.hpp"
|
||||
#include "wpi/util/array.hpp"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_jni_Ellipse2dJNI
|
||||
* Method: nearest
|
||||
* Signature: (DDDDDDD[D)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_jni_Ellipse2dJNI_nearest
|
||||
(JNIEnv* env, jclass, jdouble centerX, jdouble centerY, jdouble centerHeading,
|
||||
jdouble xSemiAxis, jdouble ySemiAxis, jdouble pointX, jdouble pointY,
|
||||
jdoubleArray nearestPoint)
|
||||
{
|
||||
auto point =
|
||||
wpi::math::Ellipse2d{
|
||||
wpi::math::Pose2d{wpi::units::meter_t{centerX},
|
||||
wpi::units::meter_t{centerY},
|
||||
wpi::units::radian_t{centerHeading}},
|
||||
wpi::units::meter_t{xSemiAxis}, wpi::units::meter_t{ySemiAxis}}
|
||||
.Nearest({wpi::units::meter_t{pointX}, wpi::units::meter_t{pointY}});
|
||||
|
||||
wpi::util::array buf{point.X().value(), point.Y().value()};
|
||||
env->SetDoubleArrayRegion(nearestPoint, 0, 2, buf.data());
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
65
wpimath/src/main/native/cpp/jni/SleipnirJNIUtil.hpp
Normal file
65
wpimath/src/main/native/cpp/jni/SleipnirJNIUtil.hpp
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <concepts>
|
||||
#include <vector>
|
||||
|
||||
#include <Eigen/SparseCore>
|
||||
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
namespace wpi::math::detail {
|
||||
|
||||
/**
|
||||
* Converts Eigen sparse matrix to triplets.
|
||||
*
|
||||
* @param env JNI environment.
|
||||
* @param mat Eigen sparse matrix to convert.
|
||||
* @return NativeSparseTriplets instance.
|
||||
*/
|
||||
template <typename Derived>
|
||||
requires std::derived_from<Derived, Eigen::SparseCompressedBase<Derived>>
|
||||
jobject GetTriplets(JNIEnv* env, const Derived& mat) {
|
||||
const int nonZeros = mat.nonZeros();
|
||||
|
||||
std::vector<int> rows;
|
||||
rows.reserve(nonZeros);
|
||||
|
||||
std::vector<int> cols;
|
||||
cols.reserve(nonZeros);
|
||||
|
||||
std::vector<double> values;
|
||||
values.reserve(nonZeros);
|
||||
|
||||
for (int k = 0; k < mat.outerSize(); ++k) {
|
||||
for (typename Derived::InnerIterator it{mat, k}; it; ++it) {
|
||||
rows.emplace_back(it.row());
|
||||
cols.emplace_back(it.col());
|
||||
values.emplace_back(it.value());
|
||||
}
|
||||
}
|
||||
|
||||
// Find NativeSparseTriplets class
|
||||
static wpi::util::java::JClass cls{
|
||||
env, "org/wpilib/math/autodiff/NativeSparseTriplets"};
|
||||
if (!cls) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Find NativeSparseTriplets constructor
|
||||
static jmethodID ctor = env->GetMethodID(cls, "<init>", "([I[I[D)V");
|
||||
if (!ctor) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return env->NewObject(cls, ctor, wpi::util::java::MakeJIntArray(env, rows),
|
||||
wpi::util::java::MakeJIntArray(env, cols),
|
||||
wpi::util::java::MakeJDoubleArray(env, values));
|
||||
}
|
||||
|
||||
} // namespace wpi::math::detail
|
||||
90
wpimath/src/main/native/cpp/jni/autodiff/GradientJNI.cpp
Normal file
90
wpimath/src/main/native/cpp/jni/autodiff/GradientJNI.cpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <sleipnir/autodiff/gradient.hpp>
|
||||
#include <sleipnir/autodiff/variable.hpp>
|
||||
#include <sleipnir/autodiff/variable_matrix.hpp>
|
||||
|
||||
#include "../SleipnirJNIUtil.hpp"
|
||||
#include "org_wpilib_math_autodiff_GradientJNI.h"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_GradientJNI
|
||||
* Method: create
|
||||
* Signature: (J[J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_GradientJNI_create
|
||||
(JNIEnv* env, jclass, jlong variable, jlongArray wrt)
|
||||
{
|
||||
auto& variableObj = *reinterpret_cast<slp::Variable<double>*>(variable);
|
||||
|
||||
JSpan<const jlong> wrtSpan{env, wrt};
|
||||
slp::VariableMatrix<double> wrtObj(slp::detail::empty, wrtSpan.size(), 1);
|
||||
for (size_t i = 0; i < wrtSpan.size(); ++i) {
|
||||
wrtObj[i] = *reinterpret_cast<slp::Variable<double>*>(wrtSpan[i]);
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(
|
||||
new slp::Gradient{variableObj, std::move(wrtObj)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_GradientJNI
|
||||
* Method: destroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_autodiff_GradientJNI_destroy
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
delete reinterpret_cast<slp::Gradient<double>*>(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_GradientJNI
|
||||
* Method: get
|
||||
* Signature: (J)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_wpilib_math_autodiff_GradientJNI_get
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& gradient = *reinterpret_cast<slp::Gradient<double>*>(handle);
|
||||
auto g = gradient.get();
|
||||
|
||||
std::vector<jlong> varHandles;
|
||||
varHandles.reserve(g.size());
|
||||
for (auto& var : g) {
|
||||
varHandles.emplace_back(
|
||||
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
|
||||
}
|
||||
|
||||
return MakeJLongArray(env, varHandles);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_GradientJNI
|
||||
* Method: value
|
||||
* Signature: (J)Ljava/lang/Object;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_wpilib_math_autodiff_GradientJNI_value
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& gradient = *reinterpret_cast<slp::Gradient<double>*>(handle);
|
||||
return wpi::math::detail::GetTriplets(env, gradient.value());
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
96
wpimath/src/main/native/cpp/jni/autodiff/HessianJNI.cpp
Normal file
96
wpimath/src/main/native/cpp/jni/autodiff/HessianJNI.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <sleipnir/autodiff/hessian.hpp>
|
||||
#include <sleipnir/autodiff/variable.hpp>
|
||||
#include <sleipnir/autodiff/variable_matrix.hpp>
|
||||
|
||||
#include "../SleipnirJNIUtil.hpp"
|
||||
#include "org_wpilib_math_autodiff_HessianJNI.h"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_HessianJNI
|
||||
* Method: create
|
||||
* Signature: (J[J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_HessianJNI_create
|
||||
(JNIEnv* env, jclass, jlong variable, jlongArray wrt)
|
||||
{
|
||||
auto& variableObj = *reinterpret_cast<slp::Variable<double>*>(variable);
|
||||
|
||||
JSpan<const jlong> wrtSpan{env, wrt};
|
||||
slp::VariableMatrix<double> wrtObj(slp::detail::empty, wrtSpan.size(), 1);
|
||||
for (size_t i = 0; i < wrtSpan.size(); ++i) {
|
||||
wrtObj[i] = *reinterpret_cast<slp::Variable<double>*>(wrtSpan[i]);
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(
|
||||
new slp::Hessian<double, Eigen::Lower | Eigen::Upper>{variableObj,
|
||||
std::move(wrtObj)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_HessianJNI
|
||||
* Method: destroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_autodiff_HessianJNI_destroy
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
delete reinterpret_cast<slp::Hessian<double, Eigen::Lower | Eigen::Upper>*>(
|
||||
handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_HessianJNI
|
||||
* Method: get
|
||||
* Signature: (J)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_wpilib_math_autodiff_HessianJNI_get
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& hessian =
|
||||
*reinterpret_cast<slp::Hessian<double, Eigen::Lower | Eigen::Upper>*>(
|
||||
handle);
|
||||
auto H = hessian.get();
|
||||
|
||||
std::vector<jlong> varHandles;
|
||||
varHandles.reserve(H.size());
|
||||
for (auto& var : H) {
|
||||
varHandles.emplace_back(
|
||||
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
|
||||
}
|
||||
|
||||
return MakeJLongArray(env, varHandles);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_HessianJNI
|
||||
* Method: value
|
||||
* Signature: (J)Ljava/lang/Object;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_wpilib_math_autodiff_HessianJNI_value
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& hessian =
|
||||
*reinterpret_cast<slp::Hessian<double, Eigen::Lower | Eigen::Upper>*>(
|
||||
handle);
|
||||
return wpi::math::detail::GetTriplets(env, hessian.value());
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
96
wpimath/src/main/native/cpp/jni/autodiff/JacobianJNI.cpp
Normal file
96
wpimath/src/main/native/cpp/jni/autodiff/JacobianJNI.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <sleipnir/autodiff/jacobian.hpp>
|
||||
#include <sleipnir/autodiff/variable.hpp>
|
||||
#include <sleipnir/autodiff/variable_matrix.hpp>
|
||||
|
||||
#include "../SleipnirJNIUtil.hpp"
|
||||
#include "org_wpilib_math_autodiff_JacobianJNI.h"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_JacobianJNI
|
||||
* Method: create
|
||||
* Signature: ([J[J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_JacobianJNI_create
|
||||
(JNIEnv* env, jclass, jlongArray variables, jlongArray wrt)
|
||||
{
|
||||
JSpan<const jlong> variablesSpan{env, variables};
|
||||
slp::VariableMatrix<double> variablesObj(slp::detail::empty,
|
||||
variablesSpan.size(), 1);
|
||||
for (size_t i = 0; i < variablesSpan.size(); ++i) {
|
||||
variablesObj[i] =
|
||||
*reinterpret_cast<slp::Variable<double>*>(variablesSpan[i]);
|
||||
}
|
||||
|
||||
JSpan<const jlong> wrtSpan{env, wrt};
|
||||
slp::VariableMatrix<double> wrtObj(slp::detail::empty, wrtSpan.size(), 1);
|
||||
for (size_t i = 0; i < wrtSpan.size(); ++i) {
|
||||
wrtObj[i] = *reinterpret_cast<slp::Variable<double>*>(wrtSpan[i]);
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(
|
||||
new slp::Jacobian{std::move(variablesObj), std::move(wrtObj)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_JacobianJNI
|
||||
* Method: destroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_autodiff_JacobianJNI_destroy
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
delete reinterpret_cast<slp::Jacobian<double>*>(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_JacobianJNI
|
||||
* Method: get
|
||||
* Signature: (J)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_wpilib_math_autodiff_JacobianJNI_get
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& jacobian = *reinterpret_cast<slp::Jacobian<double>*>(handle);
|
||||
auto J = jacobian.get();
|
||||
|
||||
std::vector<jlong> varHandles;
|
||||
varHandles.reserve(J.size());
|
||||
for (auto& var : J) {
|
||||
varHandles.emplace_back(
|
||||
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
|
||||
}
|
||||
|
||||
return MakeJLongArray(env, varHandles);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_JacobianJNI
|
||||
* Method: value
|
||||
* Signature: (J)Ljava/lang/Object;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL
|
||||
Java_org_wpilib_math_autodiff_JacobianJNI_value
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& jacobian = *reinterpret_cast<slp::Jacobian<double>*>(handle);
|
||||
return wpi::math::detail::GetTriplets(env, jacobian.value());
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
463
wpimath/src/main/native/cpp/jni/autodiff/VariableJNI.cpp
Normal file
463
wpimath/src/main/native/cpp/jni/autodiff/VariableJNI.cpp
Normal file
@@ -0,0 +1,463 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <sleipnir/autodiff/variable.hpp>
|
||||
|
||||
#include "org_wpilib_math_autodiff_VariableJNI.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: createDefault
|
||||
* Signature: ()J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_createDefault
|
||||
(JNIEnv* env, jclass)
|
||||
{
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: createDouble
|
||||
* Signature: (D)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_createDouble
|
||||
(JNIEnv* env, jclass, jdouble value)
|
||||
{
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{value});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: createInt
|
||||
* Signature: (I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_createInt
|
||||
(JNIEnv* env, jclass, jint value)
|
||||
{
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{value});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: destroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_destroy
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
delete reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: setValue
|
||||
* Signature: (JD)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_setValue
|
||||
(JNIEnv* env, jclass, jlong handle, jdouble value)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
lhsVar.set_value(value);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: times
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_times
|
||||
(JNIEnv* env, jclass, jlong handle, jlong rhs)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar * rhsVar});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: div
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_div
|
||||
(JNIEnv* env, jclass, jlong handle, jlong rhs)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar / rhsVar});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: plus
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_plus
|
||||
(JNIEnv* env, jclass, jlong handle, jlong rhs)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar + rhsVar});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: minus
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_minus
|
||||
(JNIEnv* env, jclass, jlong handle, jlong rhs)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar - rhsVar});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: unaryMinus
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_unaryMinus
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{-lhsVar});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: value
|
||||
* Signature: (J)D
|
||||
*/
|
||||
JNIEXPORT jdouble JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_value
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
return lhsVar.value();
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: type
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_type
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
|
||||
return static_cast<jint>(lhsVar.type());
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: abs
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_abs
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{abs(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: acos
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_acos
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{acos(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: asin
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_asin
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{asin(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: atan
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_atan
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{atan(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: atan2
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_atan2
|
||||
(JNIEnv* env, jclass, jlong y, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
auto& yVar = *reinterpret_cast<slp::Variable<double>*>(y);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{atan2(yVar, xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: cbrt
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_cbrt
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{cbrt(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: cos
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_cos
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{cos(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: cosh
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_cosh
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{cosh(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: exp
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_exp
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{exp(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: hypot
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_hypot
|
||||
(JNIEnv* env, jclass, jlong x, jlong y)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
auto& yVar = *reinterpret_cast<slp::Variable<double>*>(y);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{hypot(xVar, yVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: log
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_log
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{log(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: log10
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_log10
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{log10(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: max
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_max
|
||||
(JNIEnv* env, jclass, jlong a, jlong b)
|
||||
{
|
||||
auto& aVar = *reinterpret_cast<slp::Variable<double>*>(a);
|
||||
auto& bVar = *reinterpret_cast<slp::Variable<double>*>(b);
|
||||
return reinterpret_cast<jlong>(
|
||||
new slp::Variable<double>{(slp::max)(aVar, bVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: min
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_min
|
||||
(JNIEnv* env, jclass, jlong a, jlong b)
|
||||
{
|
||||
auto& aVar = *reinterpret_cast<slp::Variable<double>*>(a);
|
||||
auto& bVar = *reinterpret_cast<slp::Variable<double>*>(b);
|
||||
return reinterpret_cast<jlong>(
|
||||
new slp::Variable<double>{(slp::min)(aVar, bVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: pow
|
||||
* Signature: (JJ)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_pow
|
||||
(JNIEnv* env, jclass, jlong base, jlong power)
|
||||
{
|
||||
auto& baseVar = *reinterpret_cast<slp::Variable<double>*>(base);
|
||||
auto& powerVar = *reinterpret_cast<slp::Variable<double>*>(power);
|
||||
return reinterpret_cast<jlong>(
|
||||
new slp::Variable<double>{pow(baseVar, powerVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: signum
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_signum
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{sign(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: sin
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_sin
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{sin(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: sinh
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_sinh
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{sinh(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: sqrt
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_sqrt
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{sqrt(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: tan
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_tan
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{tan(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: tanh
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_tanh
|
||||
(JNIEnv* env, jclass, jlong x)
|
||||
{
|
||||
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
|
||||
return reinterpret_cast<jlong>(new slp::Variable<double>{tanh(xVar)});
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableJNI
|
||||
* Method: totalNativeMemoryUsage
|
||||
* Signature: ()J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableJNI_totalNativeMemoryUsage
|
||||
(JNIEnv* env, jclass)
|
||||
{
|
||||
return slp::global_pool_resource().blocks_in_use() *
|
||||
sizeof(slp::detail::Expression<double>);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <sleipnir/autodiff/variable.hpp>
|
||||
#include <sleipnir/autodiff/variable_matrix.hpp>
|
||||
|
||||
#include "org_wpilib_math_autodiff_VariableMatrixJNI.h"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_autodiff_VariableMatrixJNI
|
||||
* Method: solve
|
||||
* Signature: ([JI[JI)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_wpilib_math_autodiff_VariableMatrixJNI_solve
|
||||
(JNIEnv* env, jclass, jlongArray A, jint Acols, jlongArray B, jint Bcols)
|
||||
{
|
||||
JSpan<const jlong> ASpan{env, A};
|
||||
slp::VariableMatrix<double> AObj(slp::detail::empty, ASpan.size() / Acols,
|
||||
Acols);
|
||||
for (size_t i = 0; i < ASpan.size(); ++i) {
|
||||
AObj[i] = *reinterpret_cast<slp::Variable<double>*>(ASpan[i]);
|
||||
}
|
||||
|
||||
JSpan<const jlong> BSpan{env, B};
|
||||
slp::VariableMatrix<double> BObj(slp::detail::empty, BSpan.size() / Bcols,
|
||||
Bcols);
|
||||
for (size_t i = 0; i < BSpan.size(); ++i) {
|
||||
BObj[i] = *reinterpret_cast<slp::Variable<double>*>(BSpan[i]);
|
||||
}
|
||||
|
||||
auto X = slp::solve(AObj, BObj);
|
||||
|
||||
std::vector<jlong> varHandles;
|
||||
varHandles.reserve(X.size());
|
||||
for (auto& var : X) {
|
||||
varHandles.emplace_back(
|
||||
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
|
||||
}
|
||||
return MakeJLongArray(env, varHandles);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
255
wpimath/src/main/native/cpp/jni/optimization/ProblemJNI.cpp
Normal file
255
wpimath/src/main/native/cpp/jni/optimization/ProblemJNI.cpp
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <sleipnir/optimization/problem.hpp>
|
||||
|
||||
#include "../SleipnirJNIUtil.hpp"
|
||||
#include "org_wpilib_math_optimization_ProblemJNI.h"
|
||||
#include "wpi/util/jni_util.hpp"
|
||||
|
||||
using namespace wpi::util::java;
|
||||
|
||||
extern "C" {
|
||||
|
||||
namespace {
|
||||
|
||||
// ProblemJNI_solve() sets these before calling Problem::solve() so the Java
|
||||
// callback has a valid JNIEnv and object on which to call
|
||||
// Problem.runCallbacks()
|
||||
thread_local JNIEnv* callbackEnv;
|
||||
thread_local jobject callbackObj;
|
||||
|
||||
} // namespace
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: create
|
||||
* Signature: ()J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_create
|
||||
(JNIEnv* env, jclass)
|
||||
{
|
||||
auto problem = new slp::Problem<double>;
|
||||
|
||||
// Configure Java iteration callbacks
|
||||
problem->add_persistent_callback(
|
||||
[](const slp::IterationInfo<double>& info) -> bool {
|
||||
// Find Problem class
|
||||
static JClass cls{callbackEnv, "org/wpilib/math/optimization/Problem"};
|
||||
if (!cls) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Find Problem.runCallbacks()
|
||||
static jmethodID runCallbacks = callbackEnv->GetMethodID(
|
||||
cls, "runCallbacks",
|
||||
"(III[DLorg/wpilib/math/autodiff/NativeSparseTriplets;"
|
||||
"Lorg/wpilib/math/autodiff/NativeSparseTriplets;"
|
||||
"Lorg/wpilib/math/autodiff/NativeSparseTriplets;"
|
||||
"Lorg/wpilib/math/autodiff/NativeSparseTriplets;)Z");
|
||||
if (!runCallbacks) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Run Java callbacks
|
||||
return callbackEnv->CallBooleanMethod(
|
||||
callbackObj, runCallbacks, info.A_e.rows(), info.A_i.rows(),
|
||||
info.iteration, MakeJDoubleArray(callbackEnv, info.x),
|
||||
wpi::math::detail::GetTriplets(callbackEnv, info.g),
|
||||
wpi::math::detail::GetTriplets(callbackEnv, info.H),
|
||||
wpi::math::detail::GetTriplets(callbackEnv, info.A_e),
|
||||
wpi::math::detail::GetTriplets(callbackEnv, info.A_i));
|
||||
});
|
||||
|
||||
return reinterpret_cast<jlong>(problem);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: destroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_destroy
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
delete reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: decisionVariable
|
||||
* Signature: (JII)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_decisionVariable
|
||||
(JNIEnv* env, jclass, jlong handle, jint rows, jint cols)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
auto vars = problem.decision_variable(rows, cols);
|
||||
|
||||
std::vector<jlong> varHandles;
|
||||
varHandles.reserve(vars.size());
|
||||
for (auto& var : vars) {
|
||||
varHandles.emplace_back(
|
||||
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
|
||||
}
|
||||
return MakeJLongArray(env, varHandles);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: symmetricDecisionVariable
|
||||
* Signature: (JI)[J
|
||||
*/
|
||||
JNIEXPORT jlongArray JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_symmetricDecisionVariable
|
||||
(JNIEnv* env, jclass, jlong handle, jint rows)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
auto vars = problem.symmetric_decision_variable(rows);
|
||||
|
||||
std::vector<jlong> varHandles;
|
||||
varHandles.reserve(vars.size());
|
||||
for (auto& var : vars) {
|
||||
varHandles.emplace_back(
|
||||
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
|
||||
}
|
||||
return MakeJLongArray(env, varHandles);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: minimize
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_minimize
|
||||
(JNIEnv* env, jclass, jlong handle, jlong costHandle)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
auto& costVar = *reinterpret_cast<slp::Variable<double>*>(costHandle);
|
||||
problem.minimize(costVar);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: maximize
|
||||
* Signature: (JJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_maximize
|
||||
(JNIEnv* env, jclass, jlong handle, jlong objectiveHandle)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
auto& objectiveVar =
|
||||
*reinterpret_cast<slp::Variable<double>*>(objectiveHandle);
|
||||
problem.maximize(objectiveVar);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: subjectToEq
|
||||
* Signature: (J[J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_subjectToEq
|
||||
(JNIEnv* env, jclass, jlong handle, jlongArray constraintHandles)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
JSpan<const jlong> constraintHandlesSpan{env, constraintHandles};
|
||||
|
||||
for (const auto& constraintHandle : constraintHandlesSpan) {
|
||||
const auto& constraint =
|
||||
*reinterpret_cast<slp::Variable<double>*>(constraintHandle);
|
||||
problem.subject_to(constraint == 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: subjectToIneq
|
||||
* Signature: (J[J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_subjectToIneq
|
||||
(JNIEnv* env, jclass, jlong handle, jlongArray constraintHandles)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
JSpan<const jlong> constraintHandlesSpan{env, constraintHandles};
|
||||
|
||||
for (const auto& constraintHandle : constraintHandlesSpan) {
|
||||
const auto& constraint =
|
||||
*reinterpret_cast<slp::Variable<double>*>(constraintHandle);
|
||||
problem.subject_to(constraint >= 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: costFunctionType
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_costFunctionType
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
return static_cast<jint>(problem.cost_function_type());
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: equalityConstraintType
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_equalityConstraintType
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
return static_cast<jint>(problem.equality_constraint_type());
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: inequalityConstraintType
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_inequalityConstraintType
|
||||
(JNIEnv* env, jclass, jlong handle)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
return static_cast<jint>(problem.inequality_constraint_type());
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_wpilib_math_optimization_ProblemJNI
|
||||
* Method: solve
|
||||
* Signature: (Ljava/lang/Object;JDIDZZ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_wpilib_math_optimization_ProblemJNI_solve
|
||||
(JNIEnv* env, jclass, jobject obj, jlong handle, jdouble tolerance,
|
||||
jint maxIterations, jdouble timeout, jboolean feasibleIPM,
|
||||
jboolean diagnostics)
|
||||
{
|
||||
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
|
||||
|
||||
callbackEnv = env;
|
||||
callbackObj = obj;
|
||||
|
||||
slp::Options options{
|
||||
tolerance, maxIterations, std::chrono::duration<double>{timeout},
|
||||
static_cast<bool>(feasibleIPM), static_cast<bool>(diagnostics)};
|
||||
return static_cast<int>(problem.solve(options));
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
24
wpimath/src/test/java/org/wpilib/math/DoubleRange.java
Normal file
24
wpimath/src/test/java/org/wpilib/math/DoubleRange.java
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
public final class DoubleRange {
|
||||
private DoubleRange() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
public static ArrayList<Double> range(double start, double end, double step) {
|
||||
var ret = new ArrayList<Double>();
|
||||
|
||||
int steps = (int) ((end - start) / step);
|
||||
for (int i = 0; i < steps; ++i) {
|
||||
ret.add(start + i * step);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
41
wpimath/src/test/java/org/wpilib/math/MatrixAssertions.java
Normal file
41
wpimath/src/test/java/org/wpilib/math/MatrixAssertions.java
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import org.ejml.dense.row.MatrixFeatures_DDRM;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
||||
public final class MatrixAssertions {
|
||||
private MatrixAssertions() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that two SimpleMatrices are equal.
|
||||
*
|
||||
* @param expected Expected value.
|
||||
* @param actual The value to check against expected.
|
||||
*/
|
||||
public static void assertEquals(SimpleMatrix expected, SimpleMatrix actual) {
|
||||
assertFalse(MatrixFeatures_DDRM.hasUncountable(expected.getDDRM()));
|
||||
assertTrue(MatrixFeatures_DDRM.isEquals(expected.getDDRM(), actual.getDDRM()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that two SimpleMatrices are equal to within a positive delta.
|
||||
*
|
||||
* @param expected Expected value.
|
||||
* @param actual The value to check against expected.
|
||||
* @param delta The maximum delta between expected and actual for which both values are still
|
||||
* considered equal.
|
||||
*/
|
||||
public static void assertEquals(SimpleMatrix expected, SimpleMatrix actual, double delta) {
|
||||
assertFalse(MatrixFeatures_DDRM.hasUncountable(expected.getDDRM()));
|
||||
assertTrue(MatrixFeatures_DDRM.isEquals(expected.getDDRM(), actual.getDDRM(), delta));
|
||||
}
|
||||
}
|
||||
@@ -2,15 +2,15 @@
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.jni;
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class ArmFeedforwardJNITest {
|
||||
public class GradientJNITest {
|
||||
@Test
|
||||
public void testLink() {
|
||||
assertDoesNotThrow(ArmFeedforwardJNI::forceLoad);
|
||||
assertDoesNotThrow(GradientJNI::forceLoad);
|
||||
}
|
||||
}
|
||||
964
wpimath/src/test/java/org/wpilib/math/autodiff/GradientTest.java
Normal file
964
wpimath/src/test/java/org/wpilib/math/autodiff/GradientTest.java
Normal file
@@ -0,0 +1,964 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.MatrixAssertions.assertEquals;
|
||||
import static org.wpilib.math.autodiff.Variable.abs;
|
||||
import static org.wpilib.math.autodiff.Variable.acos;
|
||||
import static org.wpilib.math.autodiff.Variable.asin;
|
||||
import static org.wpilib.math.autodiff.Variable.atan;
|
||||
import static org.wpilib.math.autodiff.Variable.atan2;
|
||||
import static org.wpilib.math.autodiff.Variable.cbrt;
|
||||
import static org.wpilib.math.autodiff.Variable.cos;
|
||||
import static org.wpilib.math.autodiff.Variable.cosh;
|
||||
import static org.wpilib.math.autodiff.Variable.exp;
|
||||
import static org.wpilib.math.autodiff.Variable.hypot;
|
||||
import static org.wpilib.math.autodiff.Variable.log;
|
||||
import static org.wpilib.math.autodiff.Variable.log10;
|
||||
import static org.wpilib.math.autodiff.Variable.max;
|
||||
import static org.wpilib.math.autodiff.Variable.min;
|
||||
import static org.wpilib.math.autodiff.Variable.pow;
|
||||
import static org.wpilib.math.autodiff.Variable.signum;
|
||||
import static org.wpilib.math.autodiff.Variable.sin;
|
||||
import static org.wpilib.math.autodiff.Variable.sinh;
|
||||
import static org.wpilib.math.autodiff.Variable.sqrt;
|
||||
import static org.wpilib.math.autodiff.Variable.tan;
|
||||
import static org.wpilib.math.autodiff.Variable.tanh;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class GradientTest {
|
||||
@Test
|
||||
void testTrivialCase() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var a = new Variable();
|
||||
a.setValue(10);
|
||||
var b = new Variable();
|
||||
b.setValue(20);
|
||||
var c = a;
|
||||
|
||||
try (var g_a_a = new Gradient(a, a)) {
|
||||
assertEquals(1.0, g_a_a.value().get(0, 0));
|
||||
}
|
||||
try (var g_a_b = new Gradient(a, b)) {
|
||||
assertEquals(0.0, g_a_b.value().get(0, 0));
|
||||
}
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
try (var g_c_b = new Gradient(c, b)) {
|
||||
assertEquals(0.0, g_c_b.value().get(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testUnaryPlus() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var a = new Variable();
|
||||
a.setValue(10);
|
||||
var c = a.unaryPlus();
|
||||
|
||||
assertEquals(a.value(), c.value());
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testUnaryMinus() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var a = new Variable();
|
||||
a.setValue(10);
|
||||
var c = a.unaryMinus();
|
||||
|
||||
assertEquals(a.unaryMinus().value(), c.value());
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(-1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIdenticalVariables() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var a = new Variable();
|
||||
a.setValue(10);
|
||||
var x = a;
|
||||
var c = a.times(a).plus(x);
|
||||
|
||||
assertEquals(a.value() * a.value() + x.value(), c.value());
|
||||
try (var g_x_a = new Gradient(x, a);
|
||||
var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(2 * a.value() + g_x_a.value().get(0, 0), g_c_a.value().get(0, 0));
|
||||
}
|
||||
try (var g_a_x = new Gradient(a, x);
|
||||
var g_c_x = new Gradient(c, x)) {
|
||||
assertEquals(2 * a.value() * g_a_x.value().get(0, 0) + 1, g_c_x.value().get(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testElementary() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var a = new Variable();
|
||||
a.setValue(1.0);
|
||||
var b = new Variable();
|
||||
b.setValue(2.0);
|
||||
|
||||
var c = a.times(-2);
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(-2.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
|
||||
c = a.div(3.0);
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(1.0 / 3.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
|
||||
a.setValue(100.0);
|
||||
b.setValue(200.0);
|
||||
|
||||
c = a.plus(b);
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
try (var g_c_b = new Gradient(c, b)) {
|
||||
assertEquals(1.0, g_c_b.value().get(0, 0));
|
||||
}
|
||||
|
||||
c = a.minus(b);
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
try (var g_c_b = new Gradient(c, b)) {
|
||||
assertEquals(-1.0, g_c_b.value().get(0, 0));
|
||||
}
|
||||
|
||||
c = a.unaryMinus().plus(b);
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(-1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
try (var g_c_b = new Gradient(c, b)) {
|
||||
assertEquals(1.0, g_c_b.value().get(0, 0));
|
||||
}
|
||||
|
||||
c = a.plus(1);
|
||||
try (var g_c_a = new Gradient(c, a)) {
|
||||
assertEquals(1.0, g_c_a.value().get(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTrigonometry() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(0.5);
|
||||
|
||||
// Math.sin(x)
|
||||
assertEquals(Math.sin(x.value()), sin(x).value());
|
||||
|
||||
var g = new Gradient(sin(x), x);
|
||||
assertEquals(Math.cos(x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(Math.cos(x.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.cos(x)
|
||||
assertEquals(Math.cos(x.value()), cos(x).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(cos(x), x);
|
||||
assertEquals(-Math.sin(x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(-Math.sin(x.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.tan(x)
|
||||
assertEquals(Math.tan(x.value()), tan(x).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(tan(x), x);
|
||||
assertEquals(1.0 / (Math.cos(x.value()) * Math.cos(x.value())), g.get().value().get(0, 0));
|
||||
assertEquals(1.0 / (Math.cos(x.value()) * Math.cos(x.value())), g.value().get(0, 0));
|
||||
|
||||
// Math.asin(x)
|
||||
assertEquals(Math.asin(x.value()), asin(x).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(asin(x), x);
|
||||
assertEquals(1.0 / Math.sqrt(1 - x.value() * x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(1.0 / Math.sqrt(1 - x.value() * x.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.acos(x)
|
||||
assertEquals(Math.acos(x.value()), acos(x).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(acos(x), x);
|
||||
assertEquals(-1.0 / Math.sqrt(1 - x.value() * x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(-1.0 / Math.sqrt(1 - x.value() * x.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.atan(x)
|
||||
assertEquals(Math.atan(x.value()), atan(x).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(atan(x), x);
|
||||
assertEquals(1.0 / (1 + x.value() * x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(1.0 / (1 + x.value() * x.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testHyperbolic() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
|
||||
// sinh(x)
|
||||
assertEquals(Math.sinh(x.value()), sinh(x).value());
|
||||
|
||||
var g = new Gradient(sinh(x), x);
|
||||
assertEquals(Math.cosh(x.value()), g.get().value().get(0, 0), 1e-15);
|
||||
assertEquals(Math.cosh(x.value()), g.value().get(0, 0), 1e-15);
|
||||
|
||||
// Math.cosh(x)
|
||||
assertEquals(Math.cosh(x.value()), cosh(x).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(cosh(x), x);
|
||||
assertEquals(Math.sinh(x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(Math.sinh(x.value()), g.value().get(0, 0));
|
||||
|
||||
// tanh(x)
|
||||
assertEquals(Math.tanh(x.value()), tanh(x).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(tanh(x), x);
|
||||
assertEquals(
|
||||
1.0 / (Math.cosh(x.value()) * Math.cosh(x.value())), g.get().value().get(0, 0), 1e-15);
|
||||
assertEquals(1.0 / (Math.cosh(x.value()) * Math.cosh(x.value())), g.value().get(0, 0), 1e-15);
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testExponential() {
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
|
||||
// Math.log(x)
|
||||
assertEquals(Math.log(x.value()), log(x).value());
|
||||
|
||||
var g = new Gradient(log(x), x);
|
||||
assertEquals(1.0 / x.value(), g.get().value().get(0, 0));
|
||||
assertEquals(1.0 / x.value(), g.value().get(0, 0));
|
||||
|
||||
// Math.log10(x)
|
||||
assertEquals(Math.log10(x.value()), log10(x).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(log10(x), x);
|
||||
assertEquals(1.0 / (Math.log(10.0) * x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(1.0 / (Math.log(10.0) * x.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.exp(x)
|
||||
assertEquals(Math.exp(x.value()), exp(x).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(exp(x), x);
|
||||
assertEquals(Math.exp(x.value()), g.get().value().get(0, 0), 1e-15);
|
||||
assertEquals(Math.exp(x.value()), g.value().get(0, 0), 1e-15);
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPower() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
// Math.sqrt(x)
|
||||
assertEquals(Math.sqrt(x.value()), sqrt(x).value());
|
||||
|
||||
var g = new Gradient(sqrt(x), x);
|
||||
assertEquals(0.5 / Math.sqrt(x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(0.5 / Math.sqrt(x.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.sqrt(a)
|
||||
assertEquals(Math.sqrt(a.value()), sqrt(a).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
var g = new Gradient(sqrt(a), a);
|
||||
assertEquals(0.5 / Math.sqrt(a.value()), g.get().value().get(0, 0));
|
||||
assertEquals(0.5 / Math.sqrt(a.value()), g.value().get(0, 0));
|
||||
|
||||
// Math.cbrt(x)
|
||||
assertEquals(Math.cbrt(x.value()), cbrt(x).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
var g = new Gradient(cbrt(x), x);
|
||||
assertEquals(
|
||||
1.0 / (3.0 * Math.cbrt(x.value()) * Math.cbrt(x.value())), g.get().value().get(0, 0));
|
||||
assertEquals(1.0 / (3.0 * Math.cbrt(x.value()) * Math.cbrt(x.value())), g.value().get(0, 0));
|
||||
|
||||
// Math.cbrt(a)
|
||||
assertEquals(Math.cbrt(a.value()), cbrt(a).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
var g = new Gradient(cbrt(a), a);
|
||||
assertEquals(
|
||||
1.0 / (3.0 * Math.cbrt(a.value()) * Math.cbrt(a.value())),
|
||||
g.get().value().get(0, 0),
|
||||
1e-15);
|
||||
assertEquals(
|
||||
1.0 / (3.0 * Math.cbrt(a.value()) * Math.cbrt(a.value())), g.value().get(0, 0), 1e-15);
|
||||
|
||||
// x²
|
||||
assertEquals(Math.pow(x.value(), 2.0), pow(x, 2.0).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
|
||||
var g = new Gradient(pow(x, 2.0), x);
|
||||
assertEquals(2.0 * x.value(), g.get().value().get(0, 0));
|
||||
assertEquals(2.0 * x.value(), g.value().get(0, 0));
|
||||
|
||||
// 2ˣ
|
||||
assertEquals(Math.pow(2.0, x.value()), pow(2.0, x).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
|
||||
var g = new Gradient(pow(2.0, x), x);
|
||||
assertEquals(Math.log(2.0) * Math.pow(2.0, x.value()), g.get().value().get(0, 0));
|
||||
assertEquals(Math.log(2.0) * Math.pow(2.0, x.value()), g.value().get(0, 0));
|
||||
|
||||
// xˣ
|
||||
assertEquals(Math.pow(x.value(), x.value()), pow(x, x).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
var g = new Gradient(pow(x, x), x);
|
||||
assertEquals(
|
||||
(Math.log(x.value()) + 1) * Math.pow(x.value(), x.value()), g.get().value().get(0, 0));
|
||||
assertEquals((Math.log(x.value()) + 1) * Math.pow(x.value(), x.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
// y(a)
|
||||
var y = a.times(2);
|
||||
assertEquals(2 * a.value(), y.value());
|
||||
|
||||
var g = new Gradient(y, a);
|
||||
assertEquals(2.0, g.get().value().get(0, 0));
|
||||
assertEquals(2.0, g.value().get(0, 0));
|
||||
|
||||
// xʸ(x)
|
||||
assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
// y(a)
|
||||
var y = a.times(2);
|
||||
assertEquals(2 * a.value(), y.value());
|
||||
|
||||
var g = new Gradient(pow(x, y), x);
|
||||
assertEquals(
|
||||
y.value() / x.value() * Math.pow(x.value(), y.value()), g.get().value().get(0, 0));
|
||||
assertEquals(y.value() / x.value() * Math.pow(x.value(), y.value()), g.value().get(0, 0));
|
||||
|
||||
// xʸ(a)
|
||||
assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value());
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
// y(a)
|
||||
var y = a.times(2);
|
||||
assertEquals(2 * a.value(), y.value());
|
||||
|
||||
try (var g = new Gradient(pow(x, y), a);
|
||||
var g_x_a = new Gradient(x, a);
|
||||
var g_y_a = new Gradient(y, a)) {
|
||||
assertEquals(
|
||||
Math.pow(x.value(), y.value())
|
||||
* (y.value() / x.value() * g_x_a.value().get(0, 0)
|
||||
+ Math.log(x.value()) * g_y_a.value().get(0, 0)),
|
||||
g.get().value().get(0, 0));
|
||||
assertEquals(
|
||||
Math.pow(x.value(), y.value())
|
||||
* (y.value() / x.value() * g_x_a.value().get(0, 0)
|
||||
+ Math.log(x.value()) * g_y_a.value().get(0, 0)),
|
||||
g.value().get(0, 0));
|
||||
}
|
||||
|
||||
// xʸ(y)
|
||||
assertEquals(Math.pow(x.value(), y.value()), pow(x, y).value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(1.0);
|
||||
var a = new Variable();
|
||||
a.setValue(2.0);
|
||||
|
||||
// y(a)
|
||||
var y = a.times(2);
|
||||
assertEquals(2 * a.value(), y.value());
|
||||
|
||||
var g = new Gradient(pow(x, y), y);
|
||||
assertEquals(Math.log(x.value()) * Math.pow(x.value(), y.value()), g.get().value().get(0, 0));
|
||||
assertEquals(Math.log(x.value()) * Math.pow(x.value(), y.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAbs() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
var g = new Gradient(abs(x), x);
|
||||
|
||||
x.setValue(1.0);
|
||||
assertEquals(Math.abs(x.value()), abs(x).value());
|
||||
assertEquals(1.0, g.get().value().get(0, 0));
|
||||
assertEquals(1.0, g.value().get(0, 0));
|
||||
|
||||
x.setValue(-1.0);
|
||||
assertEquals(Math.abs(x.value()), abs(x).value());
|
||||
assertEquals(-1.0, g.get().value().get(0, 0));
|
||||
assertEquals(-1.0, g.value().get(0, 0));
|
||||
|
||||
x.setValue(0.0);
|
||||
assertEquals(Math.abs(x.value()), abs(x).value());
|
||||
assertEquals(0.0, g.get().value().get(0, 0));
|
||||
assertEquals(0.0, g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAtan2() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
var y = new Variable();
|
||||
|
||||
// Testing atan2 function on (double, var)
|
||||
x.setValue(1.0);
|
||||
y.setValue(0.9);
|
||||
assertEquals(Math.atan2(2.0, x.value()), atan2(2.0, x).value());
|
||||
|
||||
var g = new Gradient(atan2(2.0, x), x);
|
||||
assertEquals(-2.0 / (2 * 2 + x.value() * x.value()), g.get().value().get(0, 0), 1e-15);
|
||||
assertEquals(-2.0 / (2 * 2 + x.value() * x.value()), g.value().get(0, 0), 1e-15);
|
||||
|
||||
// Testing atan2 function on (var, double)
|
||||
x.setValue(1.0);
|
||||
y.setValue(0.9);
|
||||
assertEquals(Math.atan2(x.value(), 2.0), atan2(x, 2.0).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(atan2(x, 2.0), x);
|
||||
assertEquals(2.0 / (2 * 2 + x.value() * x.value()), g.get().value().get(0, 0), 1e-15);
|
||||
assertEquals(2.0 / (2 * 2 + x.value() * x.value()), g.value().get(0, 0), 1e-15);
|
||||
|
||||
// Testing atan2 function on (var, var)
|
||||
x.setValue(1.1);
|
||||
y.setValue(0.9);
|
||||
assertEquals(Math.atan2(y.value(), x.value()), atan2(y, x).value(), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(atan2(y, x), y);
|
||||
assertEquals(
|
||||
x.value() / (x.value() * x.value() + y.value() * y.value()),
|
||||
g.get().value().get(0, 0),
|
||||
1e-15);
|
||||
assertEquals(
|
||||
x.value() / (x.value() * x.value() + y.value() * y.value()), g.value().get(0, 0), 1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(atan2(y, x), x);
|
||||
assertEquals(
|
||||
-y.value() / (x.value() * x.value() + y.value() * y.value()),
|
||||
g.get().value().get(0, 0),
|
||||
1e-15);
|
||||
assertEquals(
|
||||
-y.value() / (x.value() * x.value() + y.value() * y.value()), g.value().get(0, 0), 1e-15);
|
||||
|
||||
// Testing atan2 function on (expr, expr)
|
||||
assertEquals(
|
||||
3 * Math.atan2(Math.sin(y.value()), 2 * x.value() + 1),
|
||||
3 * atan2(sin(y), x.times(2).plus(1)).value(),
|
||||
1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(atan2(sin(y), x.times(2).plus(1)).times(3), y);
|
||||
assertEquals(
|
||||
3
|
||||
* (2 * x.value() + 1)
|
||||
* Math.cos(y.value())
|
||||
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
|
||||
+ Math.sin(y.value()) * Math.sin(y.value())),
|
||||
g.get().value().get(0, 0),
|
||||
1e-15);
|
||||
assertEquals(
|
||||
3
|
||||
* (2 * x.value() + 1)
|
||||
* Math.cos(y.value())
|
||||
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
|
||||
+ Math.sin(y.value()) * Math.sin(y.value())),
|
||||
g.value().get(0, 0),
|
||||
1e-15);
|
||||
|
||||
g.close();
|
||||
g = new Gradient(atan2(sin(y), x.times(2).plus(1)).times(3), x);
|
||||
assertEquals(
|
||||
3
|
||||
* -2
|
||||
* Math.sin(y.value())
|
||||
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
|
||||
+ Math.sin(y.value()) * Math.sin(y.value())),
|
||||
g.get().value().get(0, 0),
|
||||
1e-15);
|
||||
assertEquals(
|
||||
3
|
||||
* -2
|
||||
* Math.sin(y.value())
|
||||
/ ((2 * x.value() + 1) * (2 * x.value() + 1)
|
||||
+ Math.sin(y.value()) * Math.sin(y.value())),
|
||||
g.value().get(0, 0),
|
||||
1e-15);
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
private double hypot(double x, double y, double z) {
|
||||
return Math.sqrt(x * x + y * y + z * z);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testHypot() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
var y = new Variable();
|
||||
|
||||
// Testing hypot function on (var, double)
|
||||
x.setValue(1.8);
|
||||
y.setValue(1.5);
|
||||
assertEquals(Math.hypot(x.value(), 2.0), Variable.hypot(x, 2.0).value());
|
||||
|
||||
var g = new Gradient(Variable.hypot(x, 2.0), x);
|
||||
assertEquals(x.value() / Math.hypot(x.value(), 2.0), g.get().value().get(0, 0));
|
||||
assertEquals(x.value() / Math.hypot(x.value(), 2.0), g.value().get(0, 0));
|
||||
|
||||
// Testing hypot function on (double, var)
|
||||
assertEquals(Math.hypot(2.0, y.value()), Variable.hypot(2.0, y).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(2.0, y), y);
|
||||
assertEquals(y.value() / Math.hypot(2.0, y.value()), g.get().value().get(0, 0));
|
||||
assertEquals(y.value() / Math.hypot(2.0, y.value()), g.value().get(0, 0));
|
||||
|
||||
// Testing hypot function on (var, var)
|
||||
x.setValue(1.3);
|
||||
y.setValue(2.3);
|
||||
assertEquals(Math.hypot(x.value(), y.value()), Variable.hypot(x, y).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x, y), x);
|
||||
assertEquals(x.value() / Math.hypot(x.value(), y.value()), g.get().value().get(0, 0));
|
||||
assertEquals(x.value() / Math.hypot(x.value(), y.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x, y), y);
|
||||
assertEquals(y.value() / Math.hypot(x.value(), y.value()), g.get().value().get(0, 0));
|
||||
assertEquals(y.value() / Math.hypot(x.value(), y.value()), g.value().get(0, 0));
|
||||
|
||||
// Testing hypot function on (expr, expr)
|
||||
x.setValue(1.3);
|
||||
y.setValue(2.3);
|
||||
assertEquals(
|
||||
Math.hypot(2.0 * x.value(), 3.0 * y.value()),
|
||||
Variable.hypot(x.times(2.0), y.times(3.0)).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x.times(2.0), y.times(3.0)), x);
|
||||
assertEquals(
|
||||
4.0 * x.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()),
|
||||
g.get().value().get(0, 0));
|
||||
assertEquals(
|
||||
4.0 * x.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x.times(2.0), y.times(3.0)), y);
|
||||
assertEquals(
|
||||
9.0 * y.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()),
|
||||
g.get().value().get(0, 0));
|
||||
assertEquals(
|
||||
9.0 * y.value() / Math.hypot(2.0 * x.value(), 3.0 * y.value()), g.value().get(0, 0));
|
||||
|
||||
// Testing hypot function on (var, var, var)
|
||||
var z = new Variable();
|
||||
x.setValue(1.3);
|
||||
y.setValue(2.3);
|
||||
z.setValue(3.3);
|
||||
assertEquals(Variable.hypot(x, y, z).value(), hypot(x.value(), y.value(), z.value()));
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x, y, z), x);
|
||||
assertEquals(x.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0));
|
||||
assertEquals(x.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x, y, z), y);
|
||||
assertEquals(y.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0));
|
||||
assertEquals(y.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
g = new Gradient(Variable.hypot(x, y, z), z);
|
||||
assertEquals(z.value() / hypot(x.value(), y.value(), z.value()), g.get().value().get(0, 0));
|
||||
assertEquals(z.value() / hypot(x.value(), y.value(), z.value()), g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMax() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(2.0);
|
||||
|
||||
var x2 = x.times(x);
|
||||
var x3 = x.times(x).times(x);
|
||||
|
||||
try (var g_x3 = new Gradient(x3, x)) {
|
||||
// Testing lhs < rhs
|
||||
var g = new Gradient(max(x2, x3), x);
|
||||
assertEquals(x3.value(), max(x2, x3).value());
|
||||
assertEquals(g_x3.value().get(0, 0), g.get().value().get(0, 0));
|
||||
assertEquals(g_x3.value().get(0, 0), g.value().get(0, 0));
|
||||
|
||||
// Testing lhs > rhs
|
||||
g.close();
|
||||
g = new Gradient(max(x3, x2), x);
|
||||
assertEquals(x3.value(), max(x3, x2).value());
|
||||
assertEquals(g_x3.value().get(0, 0), g.get().value().get(0, 0));
|
||||
assertEquals(g_x3.value().get(0, 0), g.value().get(0, 0));
|
||||
|
||||
// Testing lhs == rhs
|
||||
g.close();
|
||||
g = new Gradient(max(x, x), x);
|
||||
assertEquals(x.value(), max(x, x).value());
|
||||
assertEquals(1.0, g.get().value().get(0, 0));
|
||||
assertEquals(1.0, g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMin() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
x.setValue(2.0);
|
||||
|
||||
var x2 = x.times(x);
|
||||
var x3 = x.times(x).times(x);
|
||||
|
||||
try (var g_x2 = new Gradient(x2, x)) {
|
||||
// Testing lhs < rhs
|
||||
var g = new Gradient(min(x2, x3), x);
|
||||
assertEquals(x2.value(), min(x2, x3).value());
|
||||
assertEquals(g_x2.value().get(0, 0), g.get().value().get(0, 0));
|
||||
assertEquals(g_x2.value().get(0, 0), g.value().get(0, 0));
|
||||
|
||||
// Testing lhs > rhs
|
||||
g.close();
|
||||
g = new Gradient(min(x3, x2), x);
|
||||
assertEquals(x2.value(), min(x3, x2).value());
|
||||
assertEquals(g_x2.value().get(0, 0), g.get().value().get(0, 0));
|
||||
assertEquals(g_x2.value().get(0, 0), g.value().get(0, 0));
|
||||
|
||||
// Testing lhs == rhs
|
||||
g.close();
|
||||
g = new Gradient(min(x, x), x);
|
||||
assertEquals(x.value(), min(x, x).value());
|
||||
assertEquals(1.0, g.get().value().get(0, 0));
|
||||
assertEquals(1.0, g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMiscellaneous() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
|
||||
// dx/dx
|
||||
x.setValue(3.0);
|
||||
assertEquals(Math.abs(x.value()), abs(x).value());
|
||||
|
||||
var g = new Gradient(x, x);
|
||||
assertEquals(1.0, g.get().value().get(0, 0));
|
||||
assertEquals(1.0, g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVariableReuse() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var a = new Variable();
|
||||
a.setValue(10);
|
||||
|
||||
var b = new Variable();
|
||||
b.setValue(20);
|
||||
|
||||
var x = a.times(b);
|
||||
|
||||
var g = new Gradient(x, a);
|
||||
|
||||
assertEquals(20.0, g.get().value().get(0, 0));
|
||||
assertEquals(20.0, g.value().get(0, 0));
|
||||
|
||||
b.setValue(10);
|
||||
assertEquals(10.0, g.get().value().get(0, 0));
|
||||
assertEquals(10.0, g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSignum() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new Variable();
|
||||
|
||||
// signum(1.0)
|
||||
x.setValue(1.0);
|
||||
assertEquals(Math.signum(x.value()), signum(x).value());
|
||||
|
||||
var g = new Gradient(signum(x), x);
|
||||
assertEquals(0.0, g.get().value().get(0, 0));
|
||||
assertEquals(0.0, g.value().get(0, 0));
|
||||
|
||||
// signum(-1.0)
|
||||
x.setValue(-1.0);
|
||||
assertEquals(Math.signum(x.value()), signum(x).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(signum(x), x);
|
||||
assertEquals(0.0, g.get().value().get(0, 0));
|
||||
assertEquals(0.0, g.value().get(0, 0));
|
||||
|
||||
// signum(0.0)
|
||||
x.setValue(0.0);
|
||||
assertEquals(Math.signum(x.value()), signum(x).value());
|
||||
|
||||
g.close();
|
||||
g = new Gradient(signum(x), x);
|
||||
assertEquals(0.0, g.get().value().get(0, 0));
|
||||
assertEquals(0.0, g.value().get(0, 0));
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNonScalar() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(3);
|
||||
x.get(0).setValue(1);
|
||||
x.get(1).setValue(2);
|
||||
x.get(2).setValue(3);
|
||||
|
||||
// y = [x₁ + 3x₂ − 5x₃]
|
||||
//
|
||||
// dy/dx = [1 3 −5]
|
||||
var y = x.get(0).plus(x.get(1).times(3)).minus(x.get(2).times(5));
|
||||
var g = new Gradient(y, x);
|
||||
|
||||
var expected_g = new SimpleMatrix(new double[][] {{1.0}, {3.0}, {-5.0}});
|
||||
|
||||
var g_get_value = g.get().value();
|
||||
assertEquals(3, g_get_value.getNumRows());
|
||||
assertEquals(1, g_get_value.getNumCols());
|
||||
assertEquals(expected_g, g_get_value);
|
||||
|
||||
var g_value = g.value();
|
||||
assertEquals(3, g_value.getNumRows());
|
||||
assertEquals(1, g_value.getNumCols());
|
||||
assertEquals(expected_g, g_value);
|
||||
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -2,15 +2,15 @@
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.jni;
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class Ellipse2dJNITest {
|
||||
public class HessianJNITest {
|
||||
@Test
|
||||
public void testLink() {
|
||||
assertDoesNotThrow(Ellipse2dJNI::forceLoad);
|
||||
assertDoesNotThrow(HessianJNI::forceLoad);
|
||||
}
|
||||
}
|
||||
499
wpimath/src/test/java/org/wpilib/math/autodiff/HessianTest.java
Normal file
499
wpimath/src/test/java/org/wpilib/math/autodiff/HessianTest.java
Normal file
@@ -0,0 +1,499 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.DoubleRange.range;
|
||||
import static org.wpilib.math.MatrixAssertions.assertEquals;
|
||||
import static org.wpilib.math.autodiff.Variable.log;
|
||||
import static org.wpilib.math.autodiff.Variable.pow;
|
||||
import static org.wpilib.math.autodiff.Variable.sin;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class HessianTest {
|
||||
@Test
|
||||
void testLinear() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// y = x
|
||||
var x = new VariableMatrix(1);
|
||||
x.get(0).setValue(3);
|
||||
var y = x.get(0);
|
||||
|
||||
// dy/dx = 1
|
||||
var gradient = new Gradient(y, x.get(0));
|
||||
double g = gradient.value().get(0, 0);
|
||||
assertEquals(1.0, g);
|
||||
|
||||
// d²y/dx² = 0
|
||||
var H = new Hessian(y, x);
|
||||
assertEquals(0.0, H.get().value(0, 0));
|
||||
assertEquals(0.0, H.value().get(0, 0));
|
||||
|
||||
H.close();
|
||||
gradient.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testQuadratic() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// y = x²
|
||||
var x = new VariableMatrix(1);
|
||||
x.get(0).setValue(3);
|
||||
var y = x.get(0).times(x.get(0));
|
||||
|
||||
// dy/dx = 2x = 6
|
||||
var gradient = new Gradient(y, x.get(0));
|
||||
double g = gradient.value().get(0, 0);
|
||||
assertEquals(6.0, g);
|
||||
|
||||
// d²y/dx² = 2
|
||||
var H = new Hessian(y, x);
|
||||
assertEquals(2.0, H.get().value(0, 0));
|
||||
assertEquals(2.0, H.value().get(0, 0));
|
||||
|
||||
H.close();
|
||||
gradient.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCubic() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// y = x³
|
||||
var x = new VariableMatrix(1);
|
||||
x.get(0).setValue(3);
|
||||
var y = x.get(0).times(x.get(0)).times(x.get(0));
|
||||
|
||||
// dy/dx = 3x² = 27
|
||||
var gradient = new Gradient(y, x.get(0));
|
||||
double g = gradient.value().get(0, 0);
|
||||
assertEquals(27.0, g);
|
||||
|
||||
// d²y/dx² = 6x = 18
|
||||
var H = new Hessian(y, x);
|
||||
assertEquals(18.0, H.get().value(0, 0));
|
||||
assertEquals(18.0, H.value().get(0, 0));
|
||||
|
||||
H.close();
|
||||
gradient.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testQuartic() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// y = x⁴
|
||||
var x = new VariableMatrix(1);
|
||||
x.get(0).setValue(3);
|
||||
var y = x.get(0).times(x.get(0)).times(x.get(0)).times(x.get(0));
|
||||
|
||||
// dy/dx = 4x³ = 108
|
||||
var gradient = new Gradient(y, x.get(0));
|
||||
double g = gradient.value().get(0, 0);
|
||||
assertEquals(108.0, g);
|
||||
|
||||
// d²y/dx² = 12x² = 108
|
||||
var H = new Hessian(y, x);
|
||||
assertEquals(108.0, H.get().value(0, 0));
|
||||
assertEquals(108.0, H.value().get(0, 0));
|
||||
|
||||
H.close();
|
||||
gradient.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSum() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = sum(x)
|
||||
var y = x.get(0).plus(x.get(1)).plus(x.get(2)).plus(x.get(3)).plus(x.get(4));
|
||||
assertEquals(15.0, y.value());
|
||||
|
||||
var g = new Gradient(y, x);
|
||||
assertEquals(SimpleMatrix.filled(5, 1, 1.0), g.get().value());
|
||||
assertEquals(SimpleMatrix.filled(5, 1, 1.0), g.value());
|
||||
|
||||
var H = new Hessian(y, x);
|
||||
assertEquals(SimpleMatrix.filled(5, 5, 0.0), H.get().value());
|
||||
assertEquals(SimpleMatrix.filled(5, 5, 0.0), H.value());
|
||||
|
||||
H.close();
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSumOfProducts() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = ||x||²
|
||||
var y = x.T().times(x).get(0);
|
||||
assertEquals(1 * 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, y.value());
|
||||
|
||||
var g = new Gradient(y, x);
|
||||
assertEquals(x.value().scale(2), g.get().value());
|
||||
assertEquals(x.value().scale(2), g.value());
|
||||
|
||||
var H = new Hessian(y, x);
|
||||
|
||||
var expected_H = SimpleMatrix.diag(2.0, 2.0, 2.0, 2.0, 2.0);
|
||||
assertEquals(expected_H, H.get().value());
|
||||
assertEquals(expected_H, H.value());
|
||||
|
||||
H.close();
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProductOfSines() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = prod(sin(x))
|
||||
var y = x.cwiseMap(Variable::sin).stream().reduce(new Variable(1.0), (a, b) -> a.times(b));
|
||||
assertEquals(
|
||||
Math.sin(1) * Math.sin(2) * Math.sin(3) * Math.sin(4) * Math.sin(5), y.value(), 1e-15);
|
||||
|
||||
var g = new Gradient(y, x);
|
||||
for (int i = 0; i < x.rows(); ++i) {
|
||||
assertEquals(y.value() / Math.tan(x.get(i).value()), g.get().value(i), 1e-15);
|
||||
assertEquals(y.value() / Math.tan(x.get(i).value()), g.value().get(i, 0), 1e-15);
|
||||
}
|
||||
|
||||
var H = new Hessian(y, x);
|
||||
|
||||
var expected_H = new SimpleMatrix(5, 5);
|
||||
for (int i = 0; i < x.rows(); ++i) {
|
||||
for (int j = 0; j < x.rows(); ++j) {
|
||||
if (i == j) {
|
||||
expected_H.set(i, j, -y.value());
|
||||
} else {
|
||||
expected_H.set(
|
||||
i, j, y.value() / (Math.tan(x.get(i).value()) * Math.tan(x.get(j).value())));
|
||||
}
|
||||
}
|
||||
}
|
||||
assertEquals(expected_H, H.get().value(), 1e-15);
|
||||
assertEquals(expected_H, H.value(), 1e-15);
|
||||
|
||||
H.close();
|
||||
g.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSumOfSquaredResiduals() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(5);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
x.get(i).setValue(1);
|
||||
}
|
||||
|
||||
// y = sum(diff(x).^2)
|
||||
var temp = x.block(0, 0, 4, 1).minus(x.block(1, 0, 4, 1)).cwiseMap(a -> pow(a, 2));
|
||||
var y = temp.stream().reduce(new Variable(0.0), (a, b) -> a.plus(b));
|
||||
var gradient = new Gradient(y, x);
|
||||
var g = gradient.value();
|
||||
|
||||
assertEquals(0.0, y.value());
|
||||
assertEquals(g.get(0, 0), 2 * x.get(0).value() - 2 * x.get(1).value());
|
||||
assertEquals(
|
||||
g.get(1, 0), -2 * x.get(0).value() + 4 * x.get(1).value() - 2 * x.get(2).value());
|
||||
assertEquals(
|
||||
g.get(2, 0), -2 * x.get(1).value() + 4 * x.get(2).value() - 2 * x.get(3).value());
|
||||
assertEquals(
|
||||
g.get(3, 0), -2 * x.get(2).value() + 4 * x.get(3).value() - 2 * x.get(4).value());
|
||||
assertEquals(g.get(4, 0), -2 * x.get(3).value() + 2 * x.get(4).value());
|
||||
|
||||
var H = new Hessian(y, x);
|
||||
|
||||
var expected_H =
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{2.0, -2.0, 0.0, 0.0, 0.0},
|
||||
{-2.0, 4.0, -2.0, 0.0, 0.0},
|
||||
{0.0, -2.0, 4.0, -2.0, 0.0},
|
||||
{0.0, 0.0, -2.0, 4.0, -2.0},
|
||||
{0.0, 0.0, 0.0, -2.0, 2.0}
|
||||
});
|
||||
assertEquals(expected_H, H.get().value());
|
||||
assertEquals(expected_H, H.value());
|
||||
|
||||
H.close();
|
||||
gradient.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSumOfSquares() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var r = new VariableMatrix(4);
|
||||
r.setValue(new double[][] {{25.0}, {10.0}, {5.0}, {0.0}});
|
||||
|
||||
var x = new VariableMatrix(4);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
x.get(i).setValue(0.0);
|
||||
}
|
||||
|
||||
var J = new Variable(0.0);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
J = J.plus(r.get(i).minus(x.get(i)).times(r.get(i).minus(x.get(i))));
|
||||
}
|
||||
|
||||
var H = new Hessian(J, x);
|
||||
|
||||
var expected_H = SimpleMatrix.diag(2.0, 2.0, 2.0, 2.0);
|
||||
assertEquals(expected_H, H.get().value());
|
||||
assertEquals(expected_H, H.value());
|
||||
|
||||
H.close();
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNestedPowers() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
final var x0 = 3.0;
|
||||
|
||||
var x = new Variable();
|
||||
x.setValue(x0);
|
||||
|
||||
var y = pow(pow(x, 2), 2);
|
||||
|
||||
var jacobian = new Jacobian(y, x);
|
||||
var J = jacobian.value();
|
||||
assertEquals(4 * x0 * x0 * x0, J.get(0, 0), 1e-12);
|
||||
|
||||
var hessian = new Hessian(y, x);
|
||||
var H = hessian.value();
|
||||
assertEquals(12 * x0 * x0, H.get(0, 0), 1e-12);
|
||||
|
||||
hessian.close();
|
||||
jacobian.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRosenbrock() {
|
||||
// z = (1 − x)² + 100(y − x²)²
|
||||
// = 100(−x² + y)² + (−x + 1)²
|
||||
//
|
||||
// ∂z/∂x = 200(−x² + y)⋅−2x + 2(−x + 1)⋅−1
|
||||
// = −400x(−x² + y) − 2(−x + 1)
|
||||
// = 400x³ − 400xy + 2x − 2
|
||||
//
|
||||
// ∂z/∂y = 200(−x² + y)
|
||||
//
|
||||
// ∂²z/∂x² = 1200x² − 400y + 2
|
||||
// ∂²z/∂xy = −400x
|
||||
// ∂²z/∂y² = 200
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var input = new VariableMatrix(2);
|
||||
var x = input.get(0);
|
||||
var y = input.get(1);
|
||||
var hessian =
|
||||
new Hessian(
|
||||
pow(new Variable(1).minus(x), 2).plus(pow(y.minus(pow(x, 2)), 2).times(100)), input);
|
||||
|
||||
for (var x0 : range(-2.5, 2.5, 0.1)) {
|
||||
for (var y0 : range(-2.5, 2.5, 0.1)) {
|
||||
x.setValue(x0);
|
||||
y.setValue(y0);
|
||||
|
||||
var H = hessian.value();
|
||||
assertEquals(1200 * x0 * x0 - 400 * y0 + 2, H.get(0, 0), 1e-11);
|
||||
assertEquals(-400 * x0, H.get(0, 1), 1e-15);
|
||||
assertEquals(-400 * x0, H.get(1, 0), 1e-15);
|
||||
assertEquals(200, H.get(1, 1));
|
||||
}
|
||||
}
|
||||
|
||||
hessian.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEdgePushingWangExample1() {
|
||||
// See example 1 of [1]
|
||||
//
|
||||
// [1] Wang, M., et al. "Capitalizing on live variables: new algorithms for
|
||||
// efficient Hessian computation via automatic differentiation", 2016.
|
||||
// https://sci-hub.st/10.1007/s12532-016-0100-3
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(2);
|
||||
x.get(0).setValue(3);
|
||||
x.get(1).setValue(4);
|
||||
|
||||
// y = (x₀sin(x₁)) x₀
|
||||
var y = (x.get(0).times(sin(x.get(1)))).times(x.get(0));
|
||||
|
||||
// dy/dx = [2x₀sin(x₁) x₀²cos(x₁)]
|
||||
// dy/dx = [ 6sin(4) 9cos(4) ]
|
||||
var J = new Jacobian(y, x);
|
||||
var expected_J =
|
||||
new SimpleMatrix(new double[][] {{6.0 * Math.sin(4.0), 9.0 * Math.cos(4.0)}});
|
||||
assertEquals(expected_J, J.get().value(), 1e-15);
|
||||
assertEquals(expected_J, J.value(), 1e-15);
|
||||
|
||||
// [ 2sin(x₁) 2x₀cos(x₁)]
|
||||
// d²y/dx² = [2x₀cos(x₁) −x₀²sin(x₁)]
|
||||
//
|
||||
// [2sin(4) 6cos(4)]
|
||||
// d²y/dx² = [6cos(4) −9sin(4)]
|
||||
var H = new Hessian(y, x);
|
||||
var expected_H =
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{2.0 * Math.sin(4.0), 6.0 * Math.cos(4.0)},
|
||||
{6.0 * Math.cos(4.0), -9.0 * Math.sin(4.0)}
|
||||
});
|
||||
assertEquals(expected_H, H.get().value(), 1e-15);
|
||||
assertEquals(expected_H, H.value(), 1e-15);
|
||||
|
||||
H.close();
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEdgePushingPetroFigure1() {
|
||||
// See figure 1 of [1]
|
||||
//
|
||||
// [1] Petro, C. G., et al. "On efficient Hessian computation using the edge
|
||||
// pushing algorithm in Julia", 2017.
|
||||
// https://mlubin.github.io/pdf/edge_pushing_julia.pdf
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// y = p₁ log(x₁x₂)
|
||||
var p_1 = new Variable(2.0);
|
||||
var x = new VariableMatrix(2);
|
||||
x.get(0).setValue(2.0);
|
||||
x.get(1).setValue(3.0);
|
||||
var y = p_1.times(log(x.get(0).times(x.get(1))));
|
||||
|
||||
// d²y/dx² = [−p₁/x₁² 0 ]
|
||||
// [ 0 −p₁/x₂²]
|
||||
var H = new Hessian(y, x);
|
||||
var expected_H =
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{-p_1.value() / (x.get(0).value() * x.get(0).value()), 0.0},
|
||||
{0.0, -p_1.value() / (x.get(1).value() * x.get(1).value())}
|
||||
});
|
||||
assertEquals(expected_H, H.get().value(), 1e-15);
|
||||
assertEquals(expected_H, H.value(), 1e-15);
|
||||
|
||||
H.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVariableReuse() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
Variable y;
|
||||
var x = new VariableMatrix(1);
|
||||
|
||||
// y = x³
|
||||
x.get(0).setValue(1);
|
||||
y = x.get(0).times(x.get(0)).times(x.get(0));
|
||||
|
||||
var hessian = new Hessian(y, x);
|
||||
|
||||
// d²y/dx² = 6x
|
||||
// H = 6
|
||||
var H = hessian.value();
|
||||
|
||||
assertEquals(1, H.getNumRows());
|
||||
assertEquals(1, H.getNumCols());
|
||||
assertEquals(6.0, H.get(0, 0));
|
||||
|
||||
x.get(0).setValue(2);
|
||||
// d²y/dx² = 6x
|
||||
// H = 12
|
||||
H = hessian.value();
|
||||
|
||||
assertEquals(1, H.getNumRows());
|
||||
assertEquals(1, H.getNumCols());
|
||||
assertEquals(12.0, H.get(0, 0));
|
||||
|
||||
hessian.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class JacobianJNITest {
|
||||
@Test
|
||||
public void testLink() {
|
||||
assertDoesNotThrow(JacobianJNI::forceLoad);
|
||||
}
|
||||
}
|
||||
266
wpimath/src/test/java/org/wpilib/math/autodiff/JacobianTest.java
Normal file
266
wpimath/src/test/java/org/wpilib/math/autodiff/JacobianTest.java
Normal file
@@ -0,0 +1,266 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.MatrixAssertions.assertEquals;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class JacobianTest {
|
||||
@Test
|
||||
void testYEqualsX() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = x
|
||||
//
|
||||
// [1 0 0]
|
||||
// dy/dx = [0 1 0]
|
||||
// [0 0 1]
|
||||
var y = x;
|
||||
var J = new Jacobian(y, x);
|
||||
|
||||
var expected_J =
|
||||
new SimpleMatrix(new double[][] {{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}});
|
||||
assertEquals(expected_J, J.get().value());
|
||||
assertEquals(expected_J, J.value());
|
||||
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testYEquals3X() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = 3x
|
||||
//
|
||||
// [3 0 0]
|
||||
// dy/dx = [0 3 0]
|
||||
// [0 0 3]
|
||||
var y = x.times(3);
|
||||
var J = new Jacobian(y, x);
|
||||
|
||||
var expected_J =
|
||||
new SimpleMatrix(new double[][] {{3.0, 0.0, 0.0}, {0.0, 3.0, 0.0}, {0.0, 0.0, 3.0}});
|
||||
assertEquals(expected_J, J.get().value());
|
||||
assertEquals(expected_J, J.value());
|
||||
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testProducts() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// [x₁x₂]
|
||||
// y = [x₂x₃]
|
||||
// [x₁x₃]
|
||||
//
|
||||
// [x₂ x₁ 0 ]
|
||||
// dy/dx = [0 x₃ x₂]
|
||||
// [x₃ 0 x₁]
|
||||
//
|
||||
// [2 1 0]
|
||||
// dy/dx = [0 3 2]
|
||||
// [3 0 1]
|
||||
var y = new VariableMatrix(3);
|
||||
y.set(0, x.get(0).times(x.get(1)));
|
||||
y.set(1, x.get(1).times(x.get(2)));
|
||||
y.set(2, x.get(0).times(x.get(2)));
|
||||
var J = new Jacobian(y, x);
|
||||
|
||||
var expected_J =
|
||||
new SimpleMatrix(new double[][] {{2.0, 1.0, 0.0}, {0.0, 3.0, 2.0}, {3.0, 0.0, 1.0}});
|
||||
assertEquals(expected_J, J.get().value());
|
||||
assertEquals(expected_J, J.value());
|
||||
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNestedProducts() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(1);
|
||||
x.get(0).setValue(3);
|
||||
assertEquals(3.0, x.value(0));
|
||||
|
||||
// [ 5x] [15]
|
||||
// y = [ 7x] = [21]
|
||||
// [11x] [33]
|
||||
var y = new VariableMatrix(3);
|
||||
y.set(0, x.get(0).times(5));
|
||||
y.set(1, x.get(0).times(7));
|
||||
y.set(2, x.get(0).times(11));
|
||||
assertEquals(15.0, y.value(0));
|
||||
assertEquals(21.0, y.value(1));
|
||||
assertEquals(33.0, y.value(2));
|
||||
|
||||
// [y₁y₂] [15⋅21] [315]
|
||||
// z = [y₂y₃] = [21⋅33] = [693]
|
||||
// [y₁y₃] [15⋅33] [495]
|
||||
var z = new VariableMatrix(3);
|
||||
z.set(0, y.get(0).times(y.get(1)));
|
||||
z.set(1, y.get(1).times(y.get(2)));
|
||||
z.set(2, y.get(0).times(y.get(2)));
|
||||
assertEquals(315.0, z.value(0));
|
||||
assertEquals(693.0, z.value(1));
|
||||
assertEquals(495.0, z.value(2));
|
||||
|
||||
// [ 5x]
|
||||
// y = [ 7x]
|
||||
// [11x]
|
||||
//
|
||||
// [ 5]
|
||||
// dy/dx = [ 7]
|
||||
// [11]
|
||||
var J = new Jacobian(y, x);
|
||||
assertEquals(5.0, J.get().value(0, 0));
|
||||
assertEquals(7.0, J.get().value(1, 0));
|
||||
assertEquals(11.0, J.get().value(2, 0));
|
||||
assertEquals(5.0, J.value().get(0, 0));
|
||||
assertEquals(7.0, J.value().get(1, 0));
|
||||
assertEquals(11.0, J.value().get(2, 0));
|
||||
|
||||
// [y₁y₂]
|
||||
// z = [y₂y₃]
|
||||
// [y₁y₃]
|
||||
//
|
||||
// [y₂ y₁ 0 ] [21 15 0]
|
||||
// dz/dy = [0 y₃ y₂] = [ 0 33 21]
|
||||
// [y₃ 0 y₁] [33 0 15]
|
||||
J.close();
|
||||
J = new Jacobian(z, y);
|
||||
var expected_J =
|
||||
new SimpleMatrix(
|
||||
new double[][] {{21.0, 15.0, 0.0}, {0.0, 33.0, 21.0}, {33.0, 0.0, 15.0}});
|
||||
assertEquals(expected_J, J.get().value());
|
||||
assertEquals(expected_J, J.value());
|
||||
|
||||
// [y₁y₂] [5x⋅ 7x] [35x²]
|
||||
// z = [y₂y₃] = [7x⋅11x] = [77x²]
|
||||
// [y₁y₃] [5x⋅11x] [55x²]
|
||||
//
|
||||
// [ 70x] [210]
|
||||
// dz/dx = [154x] = [462]
|
||||
// [110x] = [330]
|
||||
J.close();
|
||||
J = new Jacobian(z, x);
|
||||
expected_J = new SimpleMatrix(new double[][] {{210.0}, {462.0}, {330.0}});
|
||||
assertEquals(expected_J, J.get().value());
|
||||
assertEquals(expected_J, J.value());
|
||||
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNonSquare() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = [x₁ + 3x₂ − 5x₃]
|
||||
//
|
||||
// dy/dx = [1 3 −5]
|
||||
var y = new VariableMatrix(1);
|
||||
y.set(0, x.get(0).plus(x.get(1).times(3)).minus(x.get(2).times(5)));
|
||||
var J = new Jacobian(y, x);
|
||||
|
||||
var expected_J = new SimpleMatrix(new double[][] {{1.0, 3.0, -5.0}});
|
||||
|
||||
var J_get_value = J.get().value();
|
||||
assertEquals(1, J_get_value.getNumRows());
|
||||
assertEquals(3, J_get_value.getNumCols());
|
||||
assertEquals(expected_J, J_get_value);
|
||||
|
||||
var J_value = J.value();
|
||||
assertEquals(1, J_value.getNumRows());
|
||||
assertEquals(3, J_value.getNumCols());
|
||||
assertEquals(expected_J, J_value);
|
||||
|
||||
J.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVariableReuse() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var x = new VariableMatrix(2);
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
x.get(i).setValue(i + 1);
|
||||
}
|
||||
|
||||
// y = [x₁x₂]
|
||||
var y = new VariableMatrix(1);
|
||||
y.set(0, x.get(0).times(x.get(1)));
|
||||
|
||||
var jacobian = new Jacobian(y, x);
|
||||
|
||||
// dy/dx = [x₂ x₁]
|
||||
// dy/dx = [2 1]
|
||||
var J = jacobian.value();
|
||||
|
||||
assertEquals(1, J.getNumRows());
|
||||
assertEquals(2, J.getNumCols());
|
||||
assertEquals(2.0, J.get(0, 0));
|
||||
assertEquals(1.0, J.get(0, 1));
|
||||
|
||||
x.get(0).setValue(2);
|
||||
x.get(1).setValue(1);
|
||||
// dy/dx = [x₂ x₁]
|
||||
// dy/dx = [1 2]
|
||||
J = jacobian.value();
|
||||
|
||||
assertEquals(1, J.getNumRows());
|
||||
assertEquals(2, J.getNumCols());
|
||||
assertEquals(1.0, J.get(0, 0));
|
||||
assertEquals(2.0, J.get(0, 1));
|
||||
|
||||
jacobian.close();
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
481
wpimath/src/test/java/org/wpilib/math/autodiff/SliceTest.java
Normal file
481
wpimath/src/test/java/org/wpilib/math/autodiff/SliceTest.java
Normal file
@@ -0,0 +1,481 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class SliceTest {
|
||||
@Test
|
||||
void testDefaultConstructor() {
|
||||
var slice = new Slice();
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(0, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(0, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(0, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testOneArgConstructor() {
|
||||
// none
|
||||
{
|
||||
var slice = new Slice(Slice.__);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(3, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// +
|
||||
{
|
||||
var slice = new Slice(1);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// -1
|
||||
{
|
||||
var slice = new Slice(-1);
|
||||
assertEquals(-1, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// -2
|
||||
{
|
||||
var slice = new Slice(-2);
|
||||
assertEquals(-2, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTwoArgConstructor() {
|
||||
// none, none
|
||||
{
|
||||
var slice = new Slice(Slice.__, Slice.__);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(3, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// none, +
|
||||
{
|
||||
var slice = new Slice(Slice.__, 1);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// none, -
|
||||
{
|
||||
var slice = new Slice(Slice.__, -1);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(2, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// +, none
|
||||
{
|
||||
var slice = new Slice(1, Slice.__);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(2, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// -, none
|
||||
{
|
||||
var slice = new Slice(-1, Slice.__);
|
||||
assertEquals(-1, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// +, +
|
||||
{
|
||||
var slice = new Slice(1, 2);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// +, -
|
||||
{
|
||||
var slice = new Slice(1, -1);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// -, -
|
||||
{
|
||||
var slice = new Slice(-2, -1);
|
||||
assertEquals(-2, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testThreeArgConstructor() {
|
||||
// none, none, none
|
||||
{
|
||||
var slice = new Slice(Slice.__, Slice.__, Slice.__);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
|
||||
assertEquals(3, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(1, slice.step);
|
||||
}
|
||||
|
||||
// none, none, +
|
||||
{
|
||||
var slice = new Slice(Slice.__, Slice.__, 2);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
|
||||
assertEquals(2, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
}
|
||||
|
||||
// none, none, -
|
||||
{
|
||||
var slice = new Slice(Slice.__, Slice.__, -2);
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
|
||||
assertEquals(2, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
}
|
||||
|
||||
// none, +, +
|
||||
{
|
||||
var slice = new Slice(Slice.__, 1, 2);
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(0, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
}
|
||||
|
||||
// none, +, -
|
||||
{
|
||||
var slice = new Slice(Slice.__, 1, -2);
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
}
|
||||
|
||||
// none, -, -
|
||||
{
|
||||
var slice = new Slice(Slice.__, -2, -1);
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(-1, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(-1, slice.step);
|
||||
}
|
||||
|
||||
// +, none, +
|
||||
{
|
||||
var slice = new Slice(1, Slice.__, 2);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(Integer.MAX_VALUE, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(3, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
}
|
||||
|
||||
// +, none, -
|
||||
{
|
||||
var slice = new Slice(1, Slice.__, -2);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
}
|
||||
|
||||
// +, +, +
|
||||
{
|
||||
var slice = new Slice(1, 2, 2);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
}
|
||||
|
||||
// +, +, -
|
||||
{
|
||||
var slice = new Slice(2, 1, -2);
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
|
||||
assertEquals(1, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEmptySlices() {
|
||||
// +, +, +
|
||||
{
|
||||
var slice = new Slice(2, 1, 2);
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
|
||||
assertEquals(0, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(1, slice.stop);
|
||||
assertEquals(2, slice.step);
|
||||
}
|
||||
|
||||
// +, +, -
|
||||
{
|
||||
var slice = new Slice(1, 2, -2);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
|
||||
assertEquals(0, slice.adjust(3));
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
}
|
||||
|
||||
// +, -, -
|
||||
{
|
||||
var slice = new Slice(3, -1, -2);
|
||||
assertEquals(3, slice.start);
|
||||
assertEquals(-1, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
|
||||
assertEquals(0, slice.adjust(3));
|
||||
assertEquals(2, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(-2, slice.step);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testStepUBGuard() {
|
||||
{
|
||||
// none, none, INT_MIN
|
||||
var slice = new Slice(Slice.__, Slice.__, Integer.MIN_VALUE);
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// none, +, INT_MIN
|
||||
var slice = new Slice(Slice.__, 2, Integer.MIN_VALUE);
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// none, -, INT_MIN
|
||||
var slice = new Slice(Slice.__, -2, Integer.MIN_VALUE);
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(Integer.MAX_VALUE, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// +, none, INT_MIN
|
||||
var slice = new Slice(1, Slice.__, Integer.MIN_VALUE);
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(1, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// -, none, INT_MIN
|
||||
var slice = new Slice(-2, Slice.__, Integer.MIN_VALUE);
|
||||
assertEquals(-2, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(-2, slice.start);
|
||||
assertEquals(Integer.MIN_VALUE, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// +, +, INT_MIN
|
||||
var slice = new Slice(1000, 0, Integer.MIN_VALUE);
|
||||
assertEquals(1000, slice.start);
|
||||
assertEquals(0, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(1000, slice.start);
|
||||
assertEquals(0, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// +, -, INT_MIN
|
||||
var slice = new Slice(1000, -2, Integer.MIN_VALUE);
|
||||
assertEquals(1000, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(1000, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// -, +, INT_MIN
|
||||
var slice = new Slice(-1, 2, Integer.MIN_VALUE);
|
||||
assertEquals(-1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(-1, slice.start);
|
||||
assertEquals(2, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
|
||||
{
|
||||
// -, -, INT_MIN
|
||||
var slice = new Slice(-1, -2, Integer.MIN_VALUE);
|
||||
assertEquals(-1, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(Integer.MIN_VALUE + 1, slice.step);
|
||||
|
||||
slice.step = -slice.step;
|
||||
assertEquals(-1, slice.start);
|
||||
assertEquals(-2, slice.stop);
|
||||
assertEquals(Integer.MAX_VALUE, slice.step);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class VariableJNITest {
|
||||
@Test
|
||||
public void testLink() {
|
||||
assertDoesNotThrow(VariableJNI::forceLoad);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class VariableMatrixJNITest {
|
||||
@Test
|
||||
public void testLink() {
|
||||
assertDoesNotThrow(VariableMatrixJNI::forceLoad);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,600 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.wpilib.math.MatrixAssertions.assertEquals;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class VariableMatrixTest {
|
||||
@Test
|
||||
void testConstructFromDoubleArray() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var mat = new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})) {
|
||||
var expected = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
|
||||
assertEquals(expected, mat.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testConstructFromSimpleMatrix() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var mat =
|
||||
new VariableMatrix(new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}))) {
|
||||
var expected = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
|
||||
assertEquals(expected, mat.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAssignmentToDefault() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var mat = new VariableMatrix(2, 2);
|
||||
|
||||
assertEquals(2, mat.rows());
|
||||
assertEquals(2, mat.cols());
|
||||
assertEquals(0.0, mat.get(0, 0).value());
|
||||
assertEquals(0.0, mat.get(0, 1).value());
|
||||
assertEquals(0.0, mat.get(1, 0).value());
|
||||
assertEquals(0.0, mat.get(1, 1).value());
|
||||
|
||||
mat.set(0, 0, 1.0);
|
||||
mat.set(0, 1, 2.0);
|
||||
mat.set(1, 0, 3.0);
|
||||
mat.set(1, 1, 4.0);
|
||||
|
||||
assertEquals(1.0, mat.get(0, 0).value());
|
||||
assertEquals(2.0, mat.get(0, 1).value());
|
||||
assertEquals(3.0, mat.get(1, 0).value());
|
||||
assertEquals(4.0, mat.get(1, 1).value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAssignmentAliasing() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var A = new VariableMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
|
||||
var B = new VariableMatrix(new double[][] {{5.0, 6.0}, {7.0, 8.0}});
|
||||
|
||||
// A and B initially contain different values
|
||||
var expected_A = new SimpleMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
|
||||
var expected_B = new SimpleMatrix(new double[][] {{5.0, 6.0}, {7.0, 8.0}});
|
||||
assertEquals(expected_A, A.value());
|
||||
assertEquals(expected_B, B.value());
|
||||
|
||||
// Make A point to B's storage
|
||||
A.set(B);
|
||||
assertEquals(expected_B, A.value());
|
||||
assertEquals(expected_B, B.value());
|
||||
|
||||
// Changes to B should be reflected in A
|
||||
B.get(0, 0).setValue(2.0);
|
||||
expected_B.set(0, 0, 2.0);
|
||||
assertEquals(expected_B, A.value());
|
||||
assertEquals(expected_B, B.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBlockMemberFunction() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var A =
|
||||
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
|
||||
|
||||
// Block assignment
|
||||
A.block(1, 1, 2, 2).set(new double[][] {{10.0, 11.0}, {12.0, 13.0}});
|
||||
|
||||
var expected1 =
|
||||
new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 10.0, 11.0}, {7.0, 12.0, 13.0}});
|
||||
assertEquals(expected1, A.value());
|
||||
|
||||
// Block-of-block assignment
|
||||
A.block(1, 1, 2, 2).block(1, 1, 1, 1).set(14.0);
|
||||
|
||||
var expected2 =
|
||||
new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 10.0, 11.0}, {7.0, 12.0, 14.0}});
|
||||
assertEquals(expected2, A.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSlicing() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var mat =
|
||||
new VariableMatrix(
|
||||
new double[][] {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}});
|
||||
assertEquals(4, mat.rows());
|
||||
assertEquals(4, mat.cols());
|
||||
|
||||
// Single-arg index operator on full matrix
|
||||
for (int i = 0; i < mat.rows() * mat.cols(); ++i) {
|
||||
assertEquals(i + 1, mat.get(i).value());
|
||||
}
|
||||
|
||||
// Slice from start
|
||||
{
|
||||
var s = mat.get(new Slice(1, Slice.__), new Slice(2, Slice.__));
|
||||
assertEquals(3, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
// Single-arg index operator on forward slice
|
||||
assertEquals(7.0, s.get(0).value());
|
||||
assertEquals(8.0, s.get(1).value());
|
||||
assertEquals(11.0, s.get(2).value());
|
||||
assertEquals(12.0, s.get(3).value());
|
||||
assertEquals(15.0, s.get(4).value());
|
||||
assertEquals(16.0, s.get(5).value());
|
||||
// Double-arg index operator on forward slice
|
||||
assertEquals(7.0, s.get(0, 0).value());
|
||||
assertEquals(8.0, s.get(0, 1).value());
|
||||
assertEquals(11.0, s.get(1, 0).value());
|
||||
assertEquals(12.0, s.get(1, 1).value());
|
||||
assertEquals(15.0, s.get(2, 0).value());
|
||||
assertEquals(16.0, s.get(2, 1).value());
|
||||
}
|
||||
|
||||
// Slice from end
|
||||
{
|
||||
var s = mat.get(new Slice(-1, Slice.__), new Slice(-2, Slice.__));
|
||||
assertEquals(1, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
// Single-arg index operator on reverse slice
|
||||
assertEquals(15.0, s.get(0).value());
|
||||
assertEquals(16.0, s.get(1).value());
|
||||
// Double-arg index operator on reverse slice
|
||||
assertEquals(15.0, s.get(0, 0).value());
|
||||
assertEquals(16.0, s.get(0, 1).value());
|
||||
}
|
||||
|
||||
// Slice from start with step of 2
|
||||
{
|
||||
var s = mat.get(Slice.__, new Slice(Slice.__, Slice.__, 2));
|
||||
assertEquals(4, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
assertEquals(
|
||||
new SimpleMatrix(new double[][] {{1.0, 3.0}, {5.0, 7.0}, {9.0, 11.0}, {13.0, 15.0}}),
|
||||
s.value());
|
||||
}
|
||||
|
||||
// Slice from end with negative step for row and column
|
||||
{
|
||||
var s = mat.get(new Slice(Slice.__, Slice.__, -1), new Slice(Slice.__, Slice.__, -2));
|
||||
assertEquals(4, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
assertEquals(
|
||||
new SimpleMatrix(new double[][] {{16.0, 14.0}, {12.0, 10.0}, {8.0, 6.0}, {4.0, 2.0}}),
|
||||
s.value());
|
||||
}
|
||||
|
||||
// Slice from start and column -1
|
||||
{
|
||||
var s = mat.get(new Slice(1, Slice.__), -1);
|
||||
assertEquals(3, s.rows());
|
||||
assertEquals(1, s.cols());
|
||||
assertEquals(new SimpleMatrix(new double[][] {{8.0}, {12.0}, {16.0}}), s.value());
|
||||
}
|
||||
|
||||
// Slice from start and column -2
|
||||
{
|
||||
var s = mat.get(new Slice(1, Slice.__), -2);
|
||||
assertEquals(3, s.rows());
|
||||
assertEquals(1, s.cols());
|
||||
assertEquals(new SimpleMatrix(new double[][] {{7.0}, {11.0}, {15.0}}), s.value());
|
||||
}
|
||||
|
||||
// Block assignment
|
||||
{
|
||||
var s = mat.get(new Slice(Slice.__, Slice.__, 2), new Slice(Slice.__, Slice.__, 2));
|
||||
assertEquals(2, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
s.setValue(new double[][] {{17.0, 18.0}, {19.0, 20.0}});
|
||||
assertEquals(
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{17.0, 2.0, 18.0, 4.0},
|
||||
{5.0, 6.0, 7.0, 8.0},
|
||||
{19.0, 10.0, 20.0, 12.0},
|
||||
{13.0, 14.0, 15.0, 16.0}
|
||||
}),
|
||||
mat.value());
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSubslicing() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// Block-of-block assignment (row skip forward)
|
||||
{
|
||||
var mat = new VariableMatrix(5, 5);
|
||||
var s =
|
||||
mat.get(new Slice(Slice.__, Slice.__, 2), new Slice(Slice.__, Slice.__, 1))
|
||||
.get(new Slice(1, 3), new Slice(1, 4));
|
||||
assertEquals(2, s.rows());
|
||||
assertEquals(3, s.cols());
|
||||
s.setValue(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||
|
||||
assertEquals(
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 1, 2, 3, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 4, 5, 6, 0}
|
||||
}),
|
||||
mat.value());
|
||||
}
|
||||
|
||||
// Block-of-block assignment (row skip backward)
|
||||
{
|
||||
var mat = new VariableMatrix(5, 5);
|
||||
var s =
|
||||
mat.get(new Slice(Slice.__, Slice.__, -2), new Slice(Slice.__, Slice.__, -1))
|
||||
.get(new Slice(1, 3), new Slice(1, 4));
|
||||
assertEquals(2, s.rows());
|
||||
assertEquals(3, s.cols());
|
||||
s.setValue(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||
|
||||
assertEquals(
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{0, 6, 5, 4, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 3, 2, 1, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0}
|
||||
}),
|
||||
mat.value());
|
||||
}
|
||||
|
||||
// Block-of-block assignment (column skip forward)
|
||||
{
|
||||
var mat = new VariableMatrix(5, 5);
|
||||
var s =
|
||||
mat.get(new Slice(Slice.__, Slice.__, 1), new Slice(Slice.__, Slice.__, 2))
|
||||
.get(new Slice(1, 4), new Slice(1, 3));
|
||||
assertEquals(3, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
s.setValue(new double[][] {{1, 2}, {3, 4}, {5, 6}});
|
||||
|
||||
assertEquals(
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 0, 1, 0, 2},
|
||||
{0, 0, 3, 0, 4},
|
||||
{0, 0, 5, 0, 6},
|
||||
{0, 0, 0, 0, 0}
|
||||
}),
|
||||
mat.value());
|
||||
}
|
||||
|
||||
// Block-of-block assignment (column skip backward)
|
||||
{
|
||||
var mat = new VariableMatrix(5, 5);
|
||||
var s =
|
||||
mat.get(new Slice(Slice.__, Slice.__, -1), new Slice(Slice.__, Slice.__, -2))
|
||||
.get(new Slice(1, 4), new Slice(1, 3));
|
||||
assertEquals(3, s.rows());
|
||||
assertEquals(2, s.cols());
|
||||
s.setValue(new double[][] {{1, 2}, {3, 4}, {5, 6}});
|
||||
|
||||
assertEquals(
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{0, 0, 0, 0, 0},
|
||||
{6, 0, 5, 0, 0},
|
||||
{4, 0, 3, 0, 0},
|
||||
{2, 0, 1, 0, 0},
|
||||
{0, 0, 0, 0, 0}
|
||||
}),
|
||||
mat.value());
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@SuppressWarnings("PMD.UnusedLocalVariable")
|
||||
@Test
|
||||
void testIterators() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
final var A =
|
||||
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
|
||||
final var sub_A = A.block(2, 1, 1, 2);
|
||||
|
||||
int distance = 0;
|
||||
for (var elem : A) {
|
||||
++distance;
|
||||
}
|
||||
assertEquals(9, distance);
|
||||
|
||||
distance = 0;
|
||||
for (var elem : sub_A) {
|
||||
++distance;
|
||||
}
|
||||
assertEquals(2, distance);
|
||||
|
||||
int i = 1;
|
||||
for (var elem : A) {
|
||||
assertEquals(i, elem.value());
|
||||
++i;
|
||||
}
|
||||
|
||||
i = 8;
|
||||
for (var elem : sub_A) {
|
||||
assertEquals(i, elem.value());
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testValue() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var A =
|
||||
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
|
||||
var expected =
|
||||
new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
|
||||
|
||||
// Full matrix
|
||||
assertEquals(expected, A.value());
|
||||
assertEquals(4.0, A.value(3));
|
||||
assertEquals(2.0, A.T().value(3));
|
||||
|
||||
// Block
|
||||
assertEquals(expected.extractMatrix(1, 3, 1, 3), A.block(1, 1, 2, 2).value());
|
||||
assertEquals(8.0, A.block(1, 1, 2, 2).value(2));
|
||||
assertEquals(6.0, A.T().block(1, 1, 2, 2).value(2));
|
||||
|
||||
// Slice
|
||||
assertEquals(
|
||||
expected.extractMatrix(1, 3, 1, 3), A.get(new Slice(1, 3), new Slice(1, 3)).value());
|
||||
assertEquals(8.0, A.get(new Slice(1, 3), new Slice(1, 3)).value(2));
|
||||
assertEquals(6.0, A.get(new Slice(1, 3), new Slice(1, 3)).T().value(2));
|
||||
|
||||
// Block-of-block
|
||||
assertEquals(
|
||||
expected.extractMatrix(1, 3, 1, 3).extractMatrix(0, 2, 1, 2),
|
||||
A.block(1, 1, 2, 2).block(0, 1, 2, 1).value());
|
||||
assertEquals(9.0, A.block(1, 1, 2, 2).block(0, 1, 2, 1).value(1));
|
||||
assertEquals(9.0, A.block(1, 1, 2, 2).T().block(0, 1, 2, 1).value(1));
|
||||
|
||||
// Slice-of-slice
|
||||
assertEquals(
|
||||
expected.extractMatrix(1, 3, 1, 3).extractMatrix(0, 2, 1, 2),
|
||||
A.get(new Slice(1, 3), new Slice(1, 3)).get(Slice.__, new Slice(1, Slice.__)).value());
|
||||
assertEquals(
|
||||
9.0,
|
||||
A.get(new Slice(1, 3), new Slice(1, 3)).get(Slice.__, new Slice(1, Slice.__)).value(1));
|
||||
assertEquals(
|
||||
9.0,
|
||||
A.get(new Slice(1, 3), new Slice(1, 3))
|
||||
.T()
|
||||
.get(Slice.__, new Slice(1, Slice.__))
|
||||
.value(1));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCwiseMap() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// VariableMatrix cwiseMap
|
||||
var A = new VariableMatrix(new double[][] {{-2.0, -3.0, -4.0}, {-5.0, -6.0, -7.0}});
|
||||
|
||||
var result1 = A.cwiseMap(Variable::abs);
|
||||
var expected1 = new SimpleMatrix(new double[][] {{2.0, 3.0, 4.0}, {5.0, 6.0, 7.0}});
|
||||
|
||||
// Don't modify original matrix
|
||||
assertEquals(expected1.scale(-1.0), A.value());
|
||||
|
||||
assertEquals(expected1, result1.value());
|
||||
|
||||
// VariableBlock cwiseMap
|
||||
var sub_A = A.block(0, 0, 2, 2);
|
||||
|
||||
var result2 = sub_A.cwiseMap(Variable::abs);
|
||||
var expected2 = new SimpleMatrix(new double[][] {{2.0, 3.0}, {5.0, 6.0}});
|
||||
|
||||
// Don't modify original matrix
|
||||
assertEquals(expected1.scale(-1.0), A.value());
|
||||
assertEquals(expected2.scale(-1.0), sub_A.value());
|
||||
|
||||
assertEquals(expected2, result2.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testZeroStaticFunction() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var A = VariableMatrix.zero(2, 3)) {
|
||||
for (var elem : A) {
|
||||
assertEquals(0.0, elem.value());
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testOneStaticFunction() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var A = VariableMatrix.one(2, 3)) {
|
||||
for (var elem : A) {
|
||||
assertEquals(1.0, elem.value());
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testConstantStaticFunction() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var A = VariableMatrix.constant(2, 3, 2.0)) {
|
||||
for (var elem : A) {
|
||||
assertEquals(2.0, elem.value());
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCwiseReduce() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var A = new VariableMatrix(new double[][] {{2.0, 3.0, 4.0}, {5.0, 6.0, 7.0}});
|
||||
var B = new VariableMatrix(new double[][] {{8.0, 9.0, 10.0}, {11.0, 12.0, 13.0}});
|
||||
var result = VariableMatrix.cwiseReduce(A, B, (Variable x, Variable y) -> x.times(y));
|
||||
|
||||
var expected = new SimpleMatrix(new double[][] {{16.0, 27.0, 40.0}, {55.0, 72.0, 91.0}});
|
||||
assertEquals(expected, result.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBlockFreeFunction() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
var A = new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
|
||||
var B = new VariableMatrix(new double[][] {{7.0}, {8.0}});
|
||||
|
||||
var mat1 = VariableMatrix.block(new VariableMatrix[][] {{A, B}});
|
||||
var expected1 = new SimpleMatrix(new double[][] {{1.0, 2.0, 3.0, 7.0}, {4.0, 5.0, 6.0, 8.0}});
|
||||
assertEquals(2, mat1.rows());
|
||||
assertEquals(4, mat1.cols());
|
||||
assertEquals(expected1, mat1.value());
|
||||
|
||||
var C = new VariableMatrix(new double[][] {{9.0, 10.0, 11.0, 12.0}});
|
||||
|
||||
var mat2 = VariableMatrix.block(new VariableMatrix[][] {{A, B}, {C}});
|
||||
var expected2 =
|
||||
new SimpleMatrix(
|
||||
new double[][] {{1.0, 2.0, 3.0, 7.0}, {4.0, 5.0, 6.0, 8.0}, {9.0, 10.0, 11.0, 12.0}});
|
||||
assertEquals(3, mat2.rows());
|
||||
assertEquals(4, mat2.cols());
|
||||
assertEquals(expected2, mat2.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
private void checkSolve(VariableMatrix A, VariableMatrix B) {
|
||||
try (var X = VariableMatrix.solve(A, B)) {
|
||||
assertEquals(A.cols(), X.rows());
|
||||
assertEquals(B.cols(), X.cols());
|
||||
assertTrue(A.value().mult(X.value()).minus(B.value()).normF() < 1e-12);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSolveFreeFunction() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// 1x1 special case
|
||||
try (var pool = new VariablePool()) {
|
||||
checkSolve(
|
||||
new VariableMatrix(new double[][] {{2.0}}), new VariableMatrix(new double[][] {{5.0}}));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// 2x2 special case
|
||||
try (var pool = new VariablePool()) {
|
||||
checkSolve(
|
||||
new VariableMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}}),
|
||||
new VariableMatrix(new double[][] {{5.0}, {6.0}}));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// 3x3 special case
|
||||
try (var pool = new VariablePool()) {
|
||||
checkSolve(
|
||||
new VariableMatrix(new double[][] {{1.0, 2.0, 3.0}, {-4.0, -5.0, 6.0}, {7.0, 8.0, 9.0}}),
|
||||
new VariableMatrix(new double[][] {{10.0}, {11.0}, {12.0}}));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// 4x4 special case
|
||||
try (var pool = new VariablePool()) {
|
||||
checkSolve(
|
||||
new VariableMatrix(
|
||||
new double[][] {
|
||||
{1.0, 2.0, 3.0, -4.0},
|
||||
{-5.0, 6.0, 7.0, 8.0},
|
||||
{9.0, 10.0, 11.0, 12.0},
|
||||
{13.0, 14.0, 15.0, 16.0}
|
||||
}),
|
||||
new VariableMatrix(new double[][] {{17.0}, {18.0}, {19.0}, {20.0}}));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// 5x5 general case
|
||||
try (var pool = new VariablePool()) {
|
||||
checkSolve(
|
||||
new VariableMatrix(
|
||||
new double[][] {
|
||||
{1.0, 2.0, 3.0, -4.0, 5.0},
|
||||
{-5.0, 6.0, 7.0, 8.0, 9.0},
|
||||
{9.0, 10.0, 11.0, 12.0, 13.0},
|
||||
{13.0, 14.0, 15.0, 16.0, 17.0},
|
||||
{17.0, 18.0, 19.0, 20.0, 21.0}
|
||||
}),
|
||||
new VariableMatrix(new double[][] {{21.0}, {22.0}, {23.0}, {24.0}, {25.0}}));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.autodiff;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class VariableTest {
|
||||
@Test
|
||||
void testDefaultConstructor() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var a = new Variable()) {
|
||||
assertEquals(0.0, a.value());
|
||||
assertEquals(ExpressionType.LINEAR, a.type());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testConstantConstructor() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var pool = new VariablePool()) {
|
||||
// float
|
||||
var a = new Variable(1.0);
|
||||
assertEquals(1, a.value());
|
||||
assertEquals(ExpressionType.CONSTANT, a.type());
|
||||
|
||||
// int
|
||||
var b = new Variable(2);
|
||||
assertEquals(2, b.value());
|
||||
assertEquals(ExpressionType.CONSTANT, b.type());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSetValue() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var a = new Variable()) {
|
||||
a.setValue(1.0);
|
||||
assertEquals(1.0, a.value());
|
||||
|
||||
a.setValue(2.0);
|
||||
assertEquals(2.0, a.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
import java.util.function.BiFunction;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.linalg.MatBuilder;
|
||||
import org.wpilib.math.linalg.Matrix;
|
||||
import org.wpilib.math.numbers.N1;
|
||||
@@ -95,6 +96,8 @@ class ArmFeedforwardTest {
|
||||
|
||||
@Test
|
||||
void testCalculate() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double ks = 0.5;
|
||||
final double kv = 1.5;
|
||||
final double ka = 2;
|
||||
@@ -110,10 +113,14 @@ class ArmFeedforwardTest {
|
||||
calculateAndSimulate(armFF, ks, kv, ka, kg, Math.PI / 3, 1.0, 0.95, 0.020);
|
||||
calculateAndSimulate(armFF, ks, kv, ka, kg, -Math.PI / 3, 1.0, 1.05, 0.020);
|
||||
calculateAndSimulate(armFF, ks, kv, ka, kg, -Math.PI / 3, 1.0, 0.95, 0.020);
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCalculateIllConditionedModel() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double ks = 0.39671;
|
||||
final double kv = 2.7167;
|
||||
final double ka = 1e-2;
|
||||
@@ -129,10 +136,14 @@ class ArmFeedforwardTest {
|
||||
assertEquals(
|
||||
armFF.calculate(currentAngle, currentVelocity, nextVelocity),
|
||||
ks + kv * currentVelocity + ka * averageAccel + kg * Math.cos(currentAngle));
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCalculateIllConditionedGradient() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double ks = 0.39671;
|
||||
final double kv = 2.7167;
|
||||
final double ka = 0.50799;
|
||||
@@ -140,6 +151,8 @@ class ArmFeedforwardTest {
|
||||
final ArmFeedforward armFF = new ArmFeedforward(ks, kg, kv, ka);
|
||||
|
||||
calculateAndSimulate(armFF, ks, kv, ka, kg, 1.0, 0.02, 0.0, 0.02);
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
|
||||
class Ellipse2dTest {
|
||||
private static final double kEpsilon = 1E-9;
|
||||
@@ -56,6 +57,8 @@ class Ellipse2dTest {
|
||||
|
||||
@Test
|
||||
void testDistance() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
var center = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(270.0));
|
||||
var ellipse = new Ellipse2d(center, 1.0, 2.0);
|
||||
|
||||
@@ -70,10 +73,14 @@ class Ellipse2dTest {
|
||||
|
||||
var point4 = new Translation2d(-1.0, 2.5);
|
||||
assertEquals(0.19210128384806818, ellipse.getDistance(point4), kEpsilon);
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNearest() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
var center = new Pose2d(1.0, 2.0, Rotation2d.fromDegrees(270.0));
|
||||
var ellipse = new Ellipse2d(center, 1.0, 2.0);
|
||||
|
||||
@@ -100,6 +107,8 @@ class Ellipse2dTest {
|
||||
assertAll(
|
||||
() -> assertEquals(-0.8512799937611617, nearestPoint4.getX(), kEpsilon),
|
||||
() -> assertEquals(2.378405333174535, nearestPoint4.getY(), kEpsilon));
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.autodiff.Variable.pow;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.le;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class ArmOnElevatorProblemTest {
|
||||
@Test
|
||||
void testArmOnElevatorProblem() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final int N = 800;
|
||||
|
||||
final double ELEVATOR_START_HEIGHT = 1.0; // m
|
||||
final double ELEVATOR_END_HEIGHT = 1.25; // m
|
||||
final double ELEVATOR_MAX_VELOCITY = 1.0; // m/s
|
||||
final double ELEVATOR_MAX_ACCELERATION = 2.0; // m/s²
|
||||
|
||||
final double ARM_LENGTH = 1.0; // m
|
||||
final double ARM_START_ANGLE = 0.0; // rad
|
||||
final double ARM_END_ANGLE = Math.PI; // rad
|
||||
final double ARM_MAX_VELOCITY = 2.0 * Math.PI; // rad/s
|
||||
final double ARM_MAX_ACCELERATION = 4.0 * Math.PI; // rad/s²
|
||||
|
||||
final double END_EFFECTOR_MAX_HEIGHT = 1.8; // m
|
||||
|
||||
final double TOTAL_TIME = 4.0;
|
||||
final double dt = TOTAL_TIME / N;
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var elevator = problem.decisionVariable(2, N + 1);
|
||||
var elevator_accel = problem.decisionVariable(1, N);
|
||||
|
||||
var arm = problem.decisionVariable(2, N + 1);
|
||||
var arm_accel = problem.decisionVariable(1, N);
|
||||
|
||||
for (int k = 0; k < N; ++k) {
|
||||
// Elevator dynamics constraints
|
||||
problem.subjectTo(
|
||||
eq(
|
||||
elevator.get(0, k + 1),
|
||||
elevator
|
||||
.get(0, k)
|
||||
.plus(elevator.get(1, k).times(dt))
|
||||
.plus(elevator_accel.get(0, k).times(0.5 * dt * dt))));
|
||||
problem.subjectTo(
|
||||
eq(
|
||||
elevator.get(1, k + 1),
|
||||
elevator.get(1, k).plus(elevator_accel.get(0, k).times(dt))));
|
||||
|
||||
// Arm dynamics constraints
|
||||
problem.subjectTo(
|
||||
eq(
|
||||
arm.get(0, k + 1),
|
||||
arm.get(0, k)
|
||||
.plus(arm.get(1, k).times(dt))
|
||||
.plus(arm_accel.get(0, k).times(0.5 * dt * dt))));
|
||||
problem.subjectTo(eq(arm.get(1, k + 1), arm.get(1, k).plus(arm_accel.get(0, k).times(dt))));
|
||||
}
|
||||
|
||||
// Elevator start and end conditions
|
||||
problem.subjectTo(
|
||||
eq(elevator.col(0), new VariableMatrix(new double[][] {{ELEVATOR_START_HEIGHT}, {0.0}})));
|
||||
problem.subjectTo(
|
||||
eq(elevator.col(N), new VariableMatrix(new double[][] {{ELEVATOR_END_HEIGHT}, {0.0}})));
|
||||
|
||||
// Arm start and end conditions
|
||||
problem.subjectTo(
|
||||
eq(arm.col(0), new VariableMatrix(new double[][] {{ARM_START_ANGLE}, {0.0}})));
|
||||
problem.subjectTo(
|
||||
eq(arm.col(N), new VariableMatrix(new double[][] {{ARM_END_ANGLE}, {0.0}})));
|
||||
|
||||
// Elevator velocity limits
|
||||
problem.subjectTo(bounds(-ELEVATOR_MAX_VELOCITY, elevator.row(1), ELEVATOR_MAX_VELOCITY));
|
||||
|
||||
// Elevator acceleration limits
|
||||
problem.subjectTo(
|
||||
bounds(-ELEVATOR_MAX_ACCELERATION, elevator_accel, ELEVATOR_MAX_ACCELERATION));
|
||||
|
||||
// Arm velocity limits
|
||||
problem.subjectTo(bounds(-ARM_MAX_VELOCITY, arm.row(1), ARM_MAX_VELOCITY));
|
||||
|
||||
// Arm acceleration limits
|
||||
problem.subjectTo(bounds(-ARM_MAX_ACCELERATION, arm_accel, ARM_MAX_ACCELERATION));
|
||||
|
||||
// Height limit
|
||||
var heights = elevator.row(0).plus(arm.row(0).cwiseMap(Variable::sin).times(ARM_LENGTH));
|
||||
problem.subjectTo(le(heights, END_EFFECTOR_MAX_HEIGHT));
|
||||
|
||||
// Cost function
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N + 1; ++k) {
|
||||
J =
|
||||
J.plus(
|
||||
pow(new Variable(ELEVATOR_END_HEIGHT).minus(elevator.get(0, k)), 2)
|
||||
.plus(pow(new Variable(ARM_END_ANGLE).minus(arm.get(0, k)), 2)));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.ocp.DynamicsType;
|
||||
import org.wpilib.math.optimization.ocp.TimestepMethod;
|
||||
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
import org.wpilib.math.util.MathUtil;
|
||||
|
||||
class CartPoleOCPTest {
|
||||
@Test
|
||||
void testCartPole() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double TOTAL_TIME = 5.0; // s
|
||||
final double dt = 0.05; // s
|
||||
final int N = (int) (TOTAL_TIME / dt);
|
||||
|
||||
final double u_max = 20.0; // N
|
||||
final double d_max = 2.0; // m
|
||||
|
||||
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
|
||||
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
|
||||
|
||||
try (var problem =
|
||||
new OCP(
|
||||
4,
|
||||
1,
|
||||
dt,
|
||||
N,
|
||||
CartPoleUtil::cartPoleDynamics,
|
||||
DynamicsType.EXPLICIT_ODE,
|
||||
TimestepMethod.VARIABLE_SINGLE,
|
||||
TranscriptionMethod.DIRECT_COLLOCATION)) {
|
||||
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
|
||||
var X = problem.X();
|
||||
|
||||
// Initial guess
|
||||
for (int k = 0; k < N + 1; ++k) {
|
||||
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
|
||||
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
|
||||
}
|
||||
|
||||
// Initial conditions
|
||||
problem.constrainInitialState(x_initial);
|
||||
|
||||
// Final conditions
|
||||
problem.constrainFinalState(x_final);
|
||||
|
||||
// Cart position constraints
|
||||
problem.forEachStep(
|
||||
(x, u) -> {
|
||||
problem.subjectTo(bounds(0.0, x.get(0), d_max));
|
||||
});
|
||||
|
||||
// Input constraints
|
||||
problem.setLowerInputBound(-u_max);
|
||||
problem.setUpperInputBound(u_max);
|
||||
|
||||
// u = f_x
|
||||
var U = problem.U();
|
||||
|
||||
// Minimize sum squared inputs
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N; ++k) {
|
||||
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
|
||||
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
|
||||
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
|
||||
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
|
||||
|
||||
// Verify final state
|
||||
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
|
||||
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
|
||||
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
|
||||
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.wpilib.math.autodiff.NumericalIntegration.rk4;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.system.NumericalIntegration.rk4;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
import org.wpilib.math.util.MathUtil;
|
||||
|
||||
class CartPoleProblemTest {
|
||||
@Test
|
||||
void testCartPoleProblem() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double TOTAL_TIME = 5.0; // s
|
||||
final double dt = 0.05; // s
|
||||
final int N = (int) (TOTAL_TIME / dt);
|
||||
|
||||
final double u_max = 20.0; // N
|
||||
final double d_max = 2.0; // m
|
||||
|
||||
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
|
||||
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
|
||||
var X = problem.decisionVariable(4, N + 1);
|
||||
|
||||
// Initial guess
|
||||
for (int k = 0; k < N + 1; ++k) {
|
||||
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
|
||||
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
|
||||
}
|
||||
|
||||
// u = f_x
|
||||
var U = problem.decisionVariable(1, N);
|
||||
|
||||
// Initial conditions
|
||||
problem.subjectTo(eq(X.col(0), x_initial));
|
||||
|
||||
// Final conditions
|
||||
problem.subjectTo(eq(X.col(N), x_final));
|
||||
|
||||
// Cart position constraints
|
||||
problem.subjectTo(bounds(0.0, X.row(0), d_max));
|
||||
|
||||
// Input constraints
|
||||
problem.subjectTo(bounds(-u_max, U, u_max));
|
||||
|
||||
// Dynamics constraints - RK4 integration
|
||||
for (int k = 0; k < N; ++k) {
|
||||
problem.subjectTo(
|
||||
eq(X.col(k + 1), rk4(CartPoleUtil::cartPoleDynamics, X.col(k), U.col(k), dt)));
|
||||
}
|
||||
|
||||
// Minimize sum squared inputs
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N; ++k) {
|
||||
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
|
||||
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
|
||||
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
|
||||
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
|
||||
|
||||
// Verify solution
|
||||
for (int k = 0; k < N; ++k) {
|
||||
// Cart position constraints
|
||||
assertTrue(X.get(0, k).value() >= 0.0);
|
||||
assertTrue(X.get(0, k).value() <= d_max);
|
||||
|
||||
// Input constraints
|
||||
assertTrue(U.get(0, k).value() >= -u_max);
|
||||
assertTrue(U.get(0, k).value() <= u_max);
|
||||
|
||||
// Dynamics constraints
|
||||
var expected_x_k1 =
|
||||
rk4(CartPoleUtil::cartPoleDynamics, X.col(k).value(), U.col(k).value(), dt);
|
||||
var actual_x_k1 = X.col(k + 1).value();
|
||||
for (int row = 0; row < actual_x_k1.getNumRows(); ++row) {
|
||||
assertEquals(expected_x_k1.get(row), actual_x_k1.get(row), 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
|
||||
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
|
||||
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
|
||||
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.wpilib.math.autodiff.Variable.cos;
|
||||
import static org.wpilib.math.autodiff.Variable.sin;
|
||||
import static org.wpilib.math.autodiff.VariableMatrix.solve;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
|
||||
// https://underactuated.mit.edu/acrobot.html#cart_pole
|
||||
//
|
||||
// θ is CCW+ measured from negative y-axis.
|
||||
//
|
||||
// q = [x, θ]ᵀ
|
||||
// q̇ = [ẋ, θ̇]ᵀ
|
||||
// u = f_x
|
||||
//
|
||||
// M(q)q̈ + C(q, q̇)q̇ = τ_g(q) + Bu
|
||||
// M(q)q̈ = τ_g(q) − C(q, q̇)q̇ + Bu
|
||||
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
|
||||
//
|
||||
// [ m_c + m_p m_p l cosθ]
|
||||
// M(q) = [m_p l cosθ m_p l² ]
|
||||
//
|
||||
// [0 −m_p lθ̇ sinθ]
|
||||
// C(q, q̇) = [0 0 ]
|
||||
//
|
||||
// [ 0 ]
|
||||
// τ_g(q) = [-m_p gl sinθ]
|
||||
//
|
||||
// [1]
|
||||
// B = [0]
|
||||
|
||||
public final class CartPoleUtil {
|
||||
private CartPoleUtil() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
private static final double m_c = 5.0; // Cart mass (kg)
|
||||
private static final double m_p = 0.5; // Pole mass (kg)
|
||||
private static final double l = 0.5; // Pole length (m)
|
||||
private static final double g = 9.806; // Acceleration due to gravity (m/s²)
|
||||
|
||||
public static SimpleMatrix cartPoleDynamics(SimpleMatrix x, SimpleMatrix u) {
|
||||
var q = x.extractMatrix(0, 2, 0, 1);
|
||||
var qdot = x.extractMatrix(2, 4, 0, 1);
|
||||
var theta = q.get(1, 0);
|
||||
var thetadot = qdot.get(1, 0);
|
||||
|
||||
// [ m_c + m_p m_p l cosθ]
|
||||
// M(q) = [m_p l cosθ m_p l² ]
|
||||
var M =
|
||||
new SimpleMatrix(
|
||||
new double[][] {
|
||||
{m_c + m_p, m_p * l * Math.cos(theta)},
|
||||
{m_p * l * Math.cos(theta), m_p * Math.pow(l, 2)}
|
||||
});
|
||||
|
||||
// [0 −m_p lθ̇ sinθ]
|
||||
// C(q, q̇) = [0 0 ]
|
||||
var C = new SimpleMatrix(new double[][] {{0, -m_p * l * thetadot * Math.sin(theta)}, {0, 0}});
|
||||
|
||||
// [ 0 ]
|
||||
// τ_g(q) = [-m_p gl sinθ]
|
||||
var tau_g = new SimpleMatrix(new double[][] {{0}, {-m_p * g * l * Math.sin(theta)}});
|
||||
|
||||
// [1]
|
||||
// B = [0]
|
||||
final var B = new SimpleMatrix(new double[][] {{1}, {0}});
|
||||
|
||||
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
|
||||
var qddot = new SimpleMatrix(4, 1);
|
||||
qddot.insertIntoThis(0, 0, qdot);
|
||||
qddot.insertIntoThis(2, 0, M.solve(tau_g.minus(C.mult(qdot)).plus(B.mult(u))));
|
||||
return qddot;
|
||||
}
|
||||
|
||||
public static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) {
|
||||
var q = x.segment(0, 2);
|
||||
var qdot = x.segment(2, 2);
|
||||
var theta = q.get(1);
|
||||
var thetadot = qdot.get(1);
|
||||
|
||||
// [ m_c + m_p m_p l cosθ]
|
||||
// M(q) = [m_p l cosθ m_p l² ]
|
||||
var M =
|
||||
new VariableMatrix(
|
||||
new Variable[][] {
|
||||
{new Variable(m_c + m_p), cos(theta).times(m_p * l)},
|
||||
{cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))}
|
||||
});
|
||||
|
||||
// [0 −m_p lθ̇ sinθ]
|
||||
// C(q, q̇) = [0 0 ]
|
||||
var C =
|
||||
new VariableMatrix(
|
||||
new Variable[][] {
|
||||
{new Variable(0), thetadot.times(-m_p * l).times(sin(theta))},
|
||||
{new Variable(0), new Variable(0)}
|
||||
});
|
||||
|
||||
// [ 0 ]
|
||||
// τ_g(q) = [-m_p gl sinθ]
|
||||
var tau_g =
|
||||
new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}});
|
||||
|
||||
// [1]
|
||||
// B = [0]
|
||||
var B = new VariableMatrix(new double[][] {{1}, {0}});
|
||||
|
||||
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
|
||||
var qddot = new VariableMatrix(4);
|
||||
qddot.segment(0, 2).set(qdot);
|
||||
qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u))));
|
||||
return qddot;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
import static org.wpilib.math.optimization.Constraints.le;
|
||||
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
|
||||
/**
|
||||
* This class computes the optimal current allocation for a list of subsystems given a list of their
|
||||
* desired currents and current tolerances that determine which subsystem gets less current if the
|
||||
* current budget is exceeded. Subsystems with a smaller tolerance are given higher priority.
|
||||
*/
|
||||
public class CurrentManager implements AutoCloseable {
|
||||
private final Problem m_problem = new Problem();
|
||||
private final VariableMatrix m_desiredCurrents;
|
||||
private final VariableMatrix m_allocatedCurrents;
|
||||
|
||||
/**
|
||||
* Constructs a CurrentManager.
|
||||
*
|
||||
* @param currentTolerances The relative current tolerance of each subsystem.
|
||||
* @param maxCurrent The current budget to allocate between subsystems.
|
||||
*/
|
||||
public CurrentManager(double[] currentTolerances, double maxCurrent) {
|
||||
this.m_desiredCurrents = new VariableMatrix(currentTolerances.length, 1);
|
||||
this.m_allocatedCurrents = m_problem.decisionVariable(currentTolerances.length);
|
||||
|
||||
// Ensure m_desired_currents contains initialized Variables
|
||||
for (int row = 0; row < m_desiredCurrents.rows(); ++row) {
|
||||
// Don't initialize to 0 or 1, because those will get folded by Sleipnir
|
||||
m_desiredCurrents.get(row).setValue(Double.POSITIVE_INFINITY);
|
||||
}
|
||||
|
||||
var J = new Variable(0.0);
|
||||
var currentSum = new Variable(0.0);
|
||||
for (int i = 0; i < currentTolerances.length; ++i) {
|
||||
// The weight is 1/tolᵢ² where tolᵢ is the tolerance between the desired
|
||||
// and allocated current for subsystem i
|
||||
var error = m_desiredCurrents.get(i).minus(m_allocatedCurrents.get(i));
|
||||
J = J.plus(error.times(error).div(currentTolerances[i] * currentTolerances[i]));
|
||||
|
||||
currentSum = currentSum.plus(m_allocatedCurrents.get(i));
|
||||
|
||||
// Currents must be nonnegative
|
||||
m_problem.subjectTo(ge(m_allocatedCurrents.get(i), 0.0));
|
||||
}
|
||||
m_problem.minimize(J);
|
||||
|
||||
// Keep total current below maximum
|
||||
m_problem.subjectTo(le(currentSum, maxCurrent));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
m_problem.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the optimal current allocation for a list of subsystems given a list of their desired
|
||||
* currents and current tolerances that determine which subsystem gets less current if the current
|
||||
* budget is exceeded. Subsystems with a smaller tolerance are given higher priority.
|
||||
*
|
||||
* @param desiredCurrents The desired current for each subsystem.
|
||||
* @throws RuntimeException if the number of desired currents doesn't equal the number of
|
||||
* tolerances passed in the constructor.
|
||||
*/
|
||||
public double[] calculate(double[] desiredCurrents) {
|
||||
if (m_desiredCurrents.rows() != desiredCurrents.length) {
|
||||
throw new RuntimeException(
|
||||
"Number of desired currents must equal the number of tolerances passed in the "
|
||||
+ "constructor.");
|
||||
}
|
||||
|
||||
for (int i = 0; i < desiredCurrents.length; ++i) {
|
||||
m_desiredCurrents.get(i).setValue(desiredCurrents[i]);
|
||||
}
|
||||
|
||||
m_problem.solve();
|
||||
|
||||
var result = new double[desiredCurrents.length];
|
||||
for (int i = 0; i < desiredCurrents.length; ++i) {
|
||||
result[i] = Math.max(m_allocatedCurrents.value(i), 0.0);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
|
||||
class CurrentManagerTest {
|
||||
@Test
|
||||
void testEnoughCurrent() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var manager = new CurrentManager(new double[] {1.0, 5.0, 10.0, 5.0}, 40.0)) {
|
||||
var currents = manager.calculate(new double[] {25.0, 10.0, 5.0, 0.0});
|
||||
|
||||
assertEquals(25.0, currents[0], 1e-3);
|
||||
assertEquals(10.0, currents[1], 1e-3);
|
||||
assertEquals(5.0, currents[2], 1e-3);
|
||||
assertEquals(0.0, currents[3], 1e-3);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNotEnoughCurrent() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var manager = new CurrentManager(new double[] {1.0, 5.0, 10.0, 5.0}, 40.0)) {
|
||||
var currents = manager.calculate(new double[] {30.0, 10.0, 5.0, 0.0});
|
||||
|
||||
// Expected values are from the following program:
|
||||
//
|
||||
// #!/usr/bin/env python3
|
||||
//
|
||||
// from scipy.optimize import minimize
|
||||
//
|
||||
// r = [30.0, 10.0, 5.0, 0.0]
|
||||
// q = [1.0, 5.0, 10.0, 5.0]
|
||||
//
|
||||
// result = minimize(
|
||||
// lambda x: sum((r[i] - x[i]) ** 2 / q[i] ** 2 for i in range(4)),
|
||||
// [0.0, 0.0, 0.0, 0.0],
|
||||
// constraints=[
|
||||
// {"type": "ineq", "fun": lambda x: x},
|
||||
// {"type": "ineq", "fun": lambda x: 40.0 - sum(x)},
|
||||
// ],
|
||||
// )
|
||||
// print(result.x)
|
||||
assertEquals(29.960, currents[0], 1e-3);
|
||||
assertEquals(9.008, currents[1], 1e-3);
|
||||
assertEquals(1.032, currents[2], 1e-3);
|
||||
assertEquals(0.0, currents[3], 1e-3);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.MatrixAssertions.assertEquals;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
|
||||
class DecisionVariableTest {
|
||||
@Test
|
||||
void testScalarInitAssign() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// Scalar zero init
|
||||
var x = problem.decisionVariable();
|
||||
assertEquals(0.0, x.value());
|
||||
|
||||
// Scalar assignment
|
||||
x.setValue(1.0);
|
||||
assertEquals(1.0, x.value());
|
||||
x.setValue(2.0);
|
||||
assertEquals(2.0, x.value());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVectorInitAssign() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// Vector zero init
|
||||
var y = problem.decisionVariable(2);
|
||||
assertEquals(0.0, y.value(0));
|
||||
assertEquals(0.0, y.value(1));
|
||||
|
||||
// Vector assignment
|
||||
y.get(0).setValue(1.0);
|
||||
y.get(1).setValue(2.0);
|
||||
assertEquals(1.0, y.value(0));
|
||||
assertEquals(2.0, y.value(1));
|
||||
y.get(0).setValue(3.0);
|
||||
y.get(1).setValue(4.0);
|
||||
assertEquals(3.0, y.value(0));
|
||||
assertEquals(4.0, y.value(1));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDynamicMatrixInitAssign() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// Matrix zero init
|
||||
var z = problem.decisionVariable(3, 2);
|
||||
assertEquals(0.0, z.value(0, 0));
|
||||
assertEquals(0.0, z.value(0, 1));
|
||||
assertEquals(0.0, z.value(1, 0));
|
||||
assertEquals(0.0, z.value(1, 1));
|
||||
assertEquals(0.0, z.value(2, 0));
|
||||
assertEquals(0.0, z.value(2, 1));
|
||||
|
||||
// Matrix assignment; element comparison
|
||||
z.setValue(new double[][] {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
|
||||
assertEquals(1.0, z.value(0, 0));
|
||||
assertEquals(2.0, z.value(0, 1));
|
||||
assertEquals(3.0, z.value(1, 0));
|
||||
assertEquals(4.0, z.value(1, 1));
|
||||
assertEquals(5.0, z.value(2, 0));
|
||||
assertEquals(6.0, z.value(2, 1));
|
||||
|
||||
// Matrix assignment; matrix comparison
|
||||
{
|
||||
var expected = new SimpleMatrix(new double[][] {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}});
|
||||
z.setValue(expected);
|
||||
assertEquals(expected, z.value());
|
||||
}
|
||||
|
||||
// Block assignment
|
||||
{
|
||||
var expected_block = new double[][] {{1.0}, {1.0}};
|
||||
z.block(0, 0, 2, 1).setValue(expected_block);
|
||||
|
||||
var expected_result =
|
||||
new SimpleMatrix(new double[][] {{1.0, 8.0}, {1.0, 10.0}, {11.0, 12.0}});
|
||||
assertEquals(expected_result, z.value());
|
||||
}
|
||||
|
||||
// Segment assignment
|
||||
{
|
||||
var expected_block = new double[][] {{1.0}, {1.0}};
|
||||
z.block(0, 0, 3, 1).segment(0, 2).setValue(expected_block);
|
||||
|
||||
var expected_result =
|
||||
new SimpleMatrix(new double[][] {{1.0, 8.0}, {1.0, 10.0}, {11.0, 12.0}});
|
||||
assertEquals(expected_result, z.value());
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSymmetricMatrix() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// Matrix zero init
|
||||
var A = problem.symmetricDecisionVariable(2);
|
||||
assertEquals(0.0, A.value(0, 0));
|
||||
assertEquals(0.0, A.value(0, 1));
|
||||
assertEquals(0.0, A.value(1, 0));
|
||||
assertEquals(0.0, A.value(1, 1));
|
||||
|
||||
// Assign to lower triangle
|
||||
A.get(0, 0).setValue(1.0);
|
||||
A.get(1, 0).setValue(2.0);
|
||||
A.get(1, 1).setValue(3.0);
|
||||
|
||||
// Confirm whole matrix changed
|
||||
assertEquals(1.0, A.value(0, 0));
|
||||
assertEquals(2.0, A.value(0, 1));
|
||||
assertEquals(2.0, A.value(1, 0));
|
||||
assertEquals(3.0, A.value(1, 1));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.ocp.DynamicsType;
|
||||
import org.wpilib.math.optimization.ocp.TimestepMethod;
|
||||
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
import org.wpilib.math.optimization.solver.Options;
|
||||
|
||||
class DifferentialDriveOCPTest {
|
||||
@Test
|
||||
void testDifferentialDrive() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final int N = 50;
|
||||
|
||||
final double minTimestep = 0.05; // s
|
||||
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
|
||||
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {1.0}, {0.0}, {0.0}, {0.0}});
|
||||
final var u_min = new SimpleMatrix(new double[][] {{-12.0}, {-12.0}});
|
||||
final var u_max = new SimpleMatrix(new double[][] {{12.0}, {12.0}});
|
||||
|
||||
try (var problem =
|
||||
new OCP(
|
||||
5,
|
||||
2,
|
||||
minTimestep,
|
||||
N,
|
||||
DifferentialDriveUtil::differentialDriveDynamics,
|
||||
DynamicsType.EXPLICIT_ODE,
|
||||
TimestepMethod.VARIABLE_SINGLE,
|
||||
TranscriptionMethod.DIRECT_TRANSCRIPTION)) {
|
||||
// Seed the min time formulation with lerp between waypoints
|
||||
for (int i = 0; i < N + 1; ++i) {
|
||||
problem.X().get(0, i).setValue((double) i / (N + 1));
|
||||
problem.X().get(1, i).setValue((double) i / (N + 1));
|
||||
}
|
||||
|
||||
problem.constrainInitialState(x_initial);
|
||||
problem.constrainFinalState(x_final);
|
||||
|
||||
problem.setLowerInputBound(u_min);
|
||||
problem.setUpperInputBound(u_max);
|
||||
|
||||
problem.setMinTimestep(minTimestep);
|
||||
problem.setMaxTimestep(3.0);
|
||||
|
||||
// Set up cost
|
||||
problem.minimize(problem.dt().times(SimpleMatrix.ones(N + 1, 1)));
|
||||
|
||||
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve(new Options().withMaxIterations(1000)));
|
||||
|
||||
var X = problem.X();
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
|
||||
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
|
||||
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
|
||||
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
|
||||
assertEquals(x_initial.get(4), X.value(4, 0), 1e-8);
|
||||
|
||||
// Verify final state
|
||||
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
|
||||
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
|
||||
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
|
||||
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
|
||||
assertEquals(x_final.get(4), X.value(4, N), 1e-8);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.wpilib.math.autodiff.NumericalIntegration.rk4;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.system.NumericalIntegration.rk4;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
import org.wpilib.math.util.MathUtil;
|
||||
|
||||
class DifferentialDriveProblemTest {
|
||||
@Test
|
||||
void testDifferentialDriveProblem() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double TOTAL_TIME = 5.0; // s
|
||||
final double dt = 0.05; // s
|
||||
final int N = (int) (TOTAL_TIME / dt);
|
||||
|
||||
final double u_max = 12.0; // V
|
||||
|
||||
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
|
||||
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {1.0}, {0.0}, {0.0}, {0.0}});
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// x = [x, y, heading, left velocity, right velocity]ᵀ
|
||||
var X = problem.decisionVariable(5, N + 1);
|
||||
|
||||
// Initial guess
|
||||
for (int k = 0; k < N; ++k) {
|
||||
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
|
||||
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
|
||||
}
|
||||
|
||||
// u = [left voltage, right voltage]ᵀ
|
||||
var U = problem.decisionVariable(2, N);
|
||||
|
||||
// Initial conditions
|
||||
problem.subjectTo(eq(X.col(0), x_initial));
|
||||
|
||||
// Final conditions
|
||||
problem.subjectTo(eq(X.col(N), x_final));
|
||||
|
||||
// Input constraints
|
||||
problem.subjectTo(bounds(-u_max, U, u_max));
|
||||
|
||||
// Dynamics constraints - RK4 integration
|
||||
for (int k = 0; k < N; ++k) {
|
||||
problem.subjectTo(
|
||||
eq(
|
||||
X.col(k + 1),
|
||||
rk4(DifferentialDriveUtil::differentialDriveDynamics, X.col(k), U.col(k), dt)));
|
||||
}
|
||||
|
||||
// Minimize sum squared states and inputs
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N; ++k) {
|
||||
J = J.plus(X.col(k).T().times(X.col(k)).plus(U.col(k).T().times(U.col(k))).get(0));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(x_initial.get(0), X.value(0, 0), 1e-8);
|
||||
assertEquals(x_initial.get(1), X.value(1, 0), 1e-8);
|
||||
assertEquals(x_initial.get(2), X.value(2, 0), 1e-8);
|
||||
assertEquals(x_initial.get(3), X.value(3, 0), 1e-8);
|
||||
assertEquals(x_initial.get(4), X.value(4, 0), 1e-8);
|
||||
|
||||
// Verify solution
|
||||
var x = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
|
||||
for (int k = 0; k < N; ++k) {
|
||||
// Input constraints
|
||||
assertTrue(U.get(0, k).value() >= -u_max);
|
||||
assertTrue(U.get(0, k).value() <= u_max);
|
||||
assertTrue(U.get(1, k).value() >= -u_max);
|
||||
assertTrue(U.get(1, k).value() <= u_max);
|
||||
|
||||
// Verify state
|
||||
assertEquals(x.get(0), X.value(0, k), 1e-8);
|
||||
assertEquals(x.get(1), X.value(1, k), 1e-8);
|
||||
assertEquals(x.get(2), X.value(2, k), 1e-8);
|
||||
assertEquals(x.get(3), X.value(3, k), 1e-8);
|
||||
assertEquals(x.get(4), X.value(4, k), 1e-8);
|
||||
|
||||
// Project state forward
|
||||
var u = U.col(k).value();
|
||||
x = rk4(DifferentialDriveUtil::differentialDriveDynamics, x, u, dt);
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
assertEquals(x_final.get(0), X.value(0, N), 1e-8);
|
||||
assertEquals(x_final.get(1), X.value(1, N), 1e-8);
|
||||
assertEquals(x_final.get(2), X.value(2, N), 1e-8);
|
||||
assertEquals(x_final.get(3), X.value(3, N), 1e-8);
|
||||
assertEquals(x_final.get(4), X.value(4, N), 1e-8);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.wpilib.math.autodiff.Variable.cos;
|
||||
import static org.wpilib.math.autodiff.Variable.sin;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
|
||||
// x = [x, y, heading, left velocity, right velocity]ᵀ
|
||||
// u = [left voltage, right voltage]ᵀ
|
||||
|
||||
public final class DifferentialDriveUtil {
|
||||
private DifferentialDriveUtil() {
|
||||
// Utility class.
|
||||
}
|
||||
|
||||
private static final double trackwidth = 0.699; // m
|
||||
private static final double Kv_linear = 3.02; // V/(m/s)
|
||||
private static final double Ka_linear = 0.642; // V/(m/s²)
|
||||
private static final double Kv_angular = 1.382; // V/(m/s)
|
||||
private static final double Ka_angular = 0.08495; // V/(m/s²)
|
||||
|
||||
private static final double A1 = -(Kv_linear / Ka_linear + Kv_angular / Ka_angular) / 2.0;
|
||||
private static final double A2 = -(Kv_linear / Ka_linear - Kv_angular / Ka_angular) / 2.0;
|
||||
private static final double B1 = 0.5 / Ka_linear + 0.5 / Ka_angular;
|
||||
private static final double B2 = 0.5 / Ka_linear - 0.5 / Ka_angular;
|
||||
private static final SimpleMatrix A = new SimpleMatrix(new double[][] {{A1, A2}, {A2, A1}});
|
||||
private static final SimpleMatrix B = new SimpleMatrix(new double[][] {{B1, B2}, {B2, B1}});
|
||||
|
||||
public static SimpleMatrix differentialDriveDynamics(SimpleMatrix x, SimpleMatrix u) {
|
||||
var xdot = new SimpleMatrix(5, 1);
|
||||
|
||||
var v = (x.get(3, 0) + x.get(4, 0)) / 2.0;
|
||||
xdot.set(0, 0, v * Math.cos(x.get(2, 0)));
|
||||
xdot.set(1, 0, v * Math.sin(x.get(2, 0)));
|
||||
xdot.set(2, 0, (x.get(4, 0) - x.get(3, 0)) / trackwidth);
|
||||
xdot.insertIntoThis(3, 0, A.mult(x.extractMatrix(3, 5, 0, 1)).plus(B.mult(u)));
|
||||
|
||||
return xdot;
|
||||
}
|
||||
|
||||
public static VariableMatrix differentialDriveDynamics(VariableMatrix x, VariableMatrix u) {
|
||||
var xdot = new VariableMatrix(5);
|
||||
|
||||
var v = x.get(3).plus(x.get(4)).div(2.0);
|
||||
xdot.set(0, 0, v.times(cos(x.get(2))));
|
||||
xdot.set(1, 0, v.times(sin(x.get(2))));
|
||||
xdot.set(2, 0, x.get(4).minus(x.get(3)).div(trackwidth));
|
||||
xdot.segment(3, 2)
|
||||
.set(new VariableMatrix(A).times(x.segment(3, 2)).plus(new VariableMatrix(B).times(u)));
|
||||
|
||||
return xdot;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.wpilib.math.autodiff.Variable.pow;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class DoubleIntegratorProblemTest {
|
||||
@Test
|
||||
void testDoubleIntegratorProblem() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double TOTAL_TIME = 3.5; // s
|
||||
final double dt = 0.005; // s
|
||||
final int N = (int) (TOTAL_TIME / dt);
|
||||
|
||||
final double r = 2.0; // m
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
// 2x1 state vector with N + 1 timesteps (includes last state)
|
||||
var X = problem.decisionVariable(2, N + 1);
|
||||
|
||||
// 1x1 input vector with N timesteps (input at last state doesn't matter)
|
||||
var U = problem.decisionVariable(1, N);
|
||||
|
||||
// Kinematics constraint assuming constant acceleration between timesteps
|
||||
for (int k = 0; k < N; ++k) {
|
||||
final double t = dt;
|
||||
var p_k1 = X.get(0, k + 1);
|
||||
var v_k1 = X.get(1, k + 1);
|
||||
var p_k = X.get(0, k);
|
||||
var v_k = X.get(1, k);
|
||||
var a_k = U.get(0, k);
|
||||
|
||||
// pₖ₊₁ = pₖ + vₖt + 1/2aₖt²
|
||||
problem.subjectTo(eq(p_k1, p_k.plus(v_k.times(t)).plus(a_k.times(0.5 * t * t))));
|
||||
|
||||
// vₖ₊₁ = vₖ + aₖt
|
||||
problem.subjectTo(eq(v_k1, v_k.plus(a_k.times(t))));
|
||||
}
|
||||
|
||||
// Start and end at rest
|
||||
problem.subjectTo(eq(X.col(0), new VariableMatrix(new double[][] {{0.0}, {0.0}})));
|
||||
problem.subjectTo(eq(X.col(N), new VariableMatrix(new double[][] {{r}, {0.0}})));
|
||||
|
||||
// Limit velocity
|
||||
problem.subjectTo(bounds(-1, X.row(1), 1));
|
||||
|
||||
// Limit acceleration
|
||||
problem.subjectTo(bounds(-1, U, 1));
|
||||
|
||||
// Cost function - minimize position error
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N + 1; ++k) {
|
||||
J = J.plus(pow(new Variable(r).minus(X.get(0, k)), 2));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
var A = new SimpleMatrix(new double[][] {{1.0, dt}, {0.0, 1.0}});
|
||||
var B = new SimpleMatrix(new double[][] {{0.5 * dt * dt}, {dt}});
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(0.0, X.value(0, 0), 1e-8);
|
||||
assertEquals(0.0, X.value(1, 0), 1e-8);
|
||||
|
||||
// Verify solution
|
||||
var x = new SimpleMatrix(new double[][] {{0.0}, {0.0}});
|
||||
var u = new SimpleMatrix(new double[][] {{0.0}});
|
||||
for (int k = 0; k < N; ++k) {
|
||||
// Verify state
|
||||
assertEquals(x.get(0), X.value(0, k), 1e-2);
|
||||
assertEquals(x.get(1), X.value(1, k), 1e-2);
|
||||
|
||||
// Determine expected input for this timestep
|
||||
if (k * dt < 1.0) {
|
||||
// Accelerate
|
||||
u.set(0, 0, 1.0);
|
||||
} else if (k * dt < 2.05) {
|
||||
// Maintain speed
|
||||
u.set(0, 0, 0.0);
|
||||
} else if (k * dt < 3.275) {
|
||||
// Decelerate
|
||||
u.set(0, 0, -1.0);
|
||||
} else {
|
||||
// Accelerate
|
||||
u.set(0, 0, 1.0);
|
||||
}
|
||||
|
||||
// Verify input
|
||||
if (k > 0 && k < N - 1 && Math.abs(U.value(0, k - 1) - U.value(0, k + 1)) >= 1.0 - 1e-2) {
|
||||
// If control input is transitioning between -1, 0, or 1, ensure it's within (-1, 1)
|
||||
assertTrue(U.value(0, k) >= -1.0);
|
||||
assertTrue(U.value(0, k) <= 1.0);
|
||||
} else {
|
||||
assertEquals(u.get(0), U.value(0, k), 1e-4);
|
||||
}
|
||||
|
||||
// Project state forward
|
||||
x = A.mult(x).plus(B.mult(u));
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
assertEquals(r, X.value(0, N), 1e-8);
|
||||
assertEquals(0.0, X.value(1, N), 1e-8);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import java.util.function.BiFunction;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.optimization.ocp.DynamicsType;
|
||||
import org.wpilib.math.optimization.ocp.TimestepMethod;
|
||||
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class FlywheelOCPTest {
|
||||
private boolean near(double expected, double actual, double tolerance) {
|
||||
return Math.abs(expected - actual) < tolerance;
|
||||
}
|
||||
|
||||
void flywheelTest(
|
||||
double A,
|
||||
double B,
|
||||
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
|
||||
DynamicsType dynamicsType,
|
||||
TranscriptionMethod transcriptionMethod) {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double TOTAL_TIME = 5.0; // s
|
||||
final double dt = 0.005; // s
|
||||
final int N = (int) (TOTAL_TIME / dt);
|
||||
|
||||
// Flywheel model:
|
||||
// States: [velocity]
|
||||
// Inputs: [voltage]
|
||||
final double A_discrete = Math.exp(A * dt);
|
||||
final double B_discrete = (1.0 - A_discrete) * B;
|
||||
|
||||
final double r = 10.0;
|
||||
|
||||
try (var problem =
|
||||
new OCP(1, 1, dt, N, f, dynamicsType, TimestepMethod.FIXED, transcriptionMethod)) {
|
||||
problem.constrainInitialState(0.0);
|
||||
problem.setUpperInputBound(12);
|
||||
problem.setLowerInputBound(-12);
|
||||
|
||||
// Set up cost
|
||||
var r_mat = new VariableMatrix(SimpleMatrix.filled(1, N + 1, r));
|
||||
problem.minimize(r_mat.minus(problem.X()).times(r_mat.minus(problem.X()).T()));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
// Voltage for steady-state velocity:
|
||||
//
|
||||
// rₖ₊₁ = Arₖ + Buₖ
|
||||
// uₖ = B⁺(rₖ₊₁ − Arₖ)
|
||||
// uₖ = B⁺(rₖ − Arₖ)
|
||||
// uₖ = B⁺(I − A)rₖ
|
||||
double u_ss = 1.0 / B_discrete * (1.0 - A_discrete) * r;
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(0.0, problem.X().value(0, 0), 1e-8);
|
||||
|
||||
// Verify solution
|
||||
double x = 0.0;
|
||||
double u;
|
||||
for (int k = 0; k < N; ++k) {
|
||||
// Verify state
|
||||
assertEquals(x, problem.X().value(0, k), 1e-2);
|
||||
|
||||
// Determine expected input for this timestep
|
||||
double error = r - x;
|
||||
if (error > 1e-2) {
|
||||
// Max control input until the reference is reached
|
||||
u = 12.0;
|
||||
} else {
|
||||
// Maintain speed
|
||||
u = u_ss;
|
||||
}
|
||||
|
||||
// Verify input
|
||||
if (k > 0
|
||||
&& k < N - 1
|
||||
&& near(12.0, problem.U().value(0, k - 1), 1e-2)
|
||||
&& near(u_ss, problem.U().value(0, k + 1), 1e-2)) {
|
||||
// If control input is transitioning between 12 and u_ss, ensure it's
|
||||
// within (u_ss, 12)
|
||||
assertTrue(problem.U().value(0, k) >= u_ss);
|
||||
assertTrue(problem.U().value(0, k) <= 12.0);
|
||||
} else {
|
||||
if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) {
|
||||
// The tolerance is large because the trajectory is represented by a
|
||||
// spline, and splines chatter when transitioning quickly between
|
||||
// steady-states.
|
||||
assertEquals(u, problem.U().value(0, k), 2.0);
|
||||
} else {
|
||||
assertEquals(u, problem.U().value(0, k), 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
// Project state forward
|
||||
x = A_discrete * x + B_discrete * u;
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
assertEquals(r, problem.X().value(0, N), 2e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
static final double A = -1.0;
|
||||
static final double B = 1.0;
|
||||
|
||||
static final double dt = 0.005; // s
|
||||
|
||||
static final double A_discrete = Math.exp(A * dt);
|
||||
static final double B_discrete = (1.0 - A_discrete) * B;
|
||||
|
||||
private static VariableMatrix f_ode(VariableMatrix x, VariableMatrix u) {
|
||||
return new VariableMatrix(new double[][] {{A}})
|
||||
.times(x)
|
||||
.plus(new VariableMatrix(new double[][] {{B}}).times(u));
|
||||
}
|
||||
|
||||
private static VariableMatrix f_discrete(VariableMatrix x, VariableMatrix u) {
|
||||
return new VariableMatrix(new double[][] {{A_discrete}})
|
||||
.times(x)
|
||||
.plus(new VariableMatrix(new double[][] {{B_discrete}}).times(u));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFlywheelExplicit() {
|
||||
flywheelTest(
|
||||
A,
|
||||
B,
|
||||
FlywheelOCPTest::f_ode,
|
||||
DynamicsType.EXPLICIT_ODE,
|
||||
TranscriptionMethod.DIRECT_COLLOCATION);
|
||||
flywheelTest(
|
||||
A,
|
||||
B,
|
||||
FlywheelOCPTest::f_ode,
|
||||
DynamicsType.EXPLICIT_ODE,
|
||||
TranscriptionMethod.DIRECT_TRANSCRIPTION);
|
||||
flywheelTest(
|
||||
A,
|
||||
B,
|
||||
FlywheelOCPTest::f_ode,
|
||||
DynamicsType.EXPLICIT_ODE,
|
||||
TranscriptionMethod.SINGLE_SHOOTING);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFlywheelDiscrete() {
|
||||
flywheelTest(
|
||||
A,
|
||||
B,
|
||||
FlywheelOCPTest::f_discrete,
|
||||
DynamicsType.DISCRETE,
|
||||
TranscriptionMethod.DIRECT_TRANSCRIPTION);
|
||||
flywheelTest(
|
||||
A,
|
||||
B,
|
||||
FlywheelOCPTest::f_discrete,
|
||||
DynamicsType.DISCRETE,
|
||||
TranscriptionMethod.SINGLE_SHOOTING);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.autodiff.VariableMatrix;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class FlywheelProblemTest {
|
||||
private boolean near(double expected, double actual, double tolerance) {
|
||||
return Math.abs(expected - actual) < tolerance;
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFlywheelProblem() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
final double TOTAL_TIME = 5.0; // s
|
||||
final double dt = 0.005; // s
|
||||
final int N = (int) (TOTAL_TIME / dt);
|
||||
|
||||
// Flywheel model:
|
||||
// States: [velocity]
|
||||
// Inputs: [voltage]
|
||||
double A = Math.exp(-dt);
|
||||
double B = 1.0 - Math.exp(-dt);
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var X = problem.decisionVariable(1, N + 1);
|
||||
var U = problem.decisionVariable(1, N);
|
||||
|
||||
// Dynamics constraint
|
||||
for (int k = 0; k < N; ++k) {
|
||||
problem.subjectTo(
|
||||
eq(
|
||||
X.col(k + 1),
|
||||
new Variable(A)
|
||||
.times(X.col(k).get(0))
|
||||
.plus(new Variable(B).times(U.col(k).get(0)))));
|
||||
}
|
||||
|
||||
// State and input constraints
|
||||
problem.subjectTo(eq(X.col(0), 0.0));
|
||||
problem.subjectTo(bounds(-12, U, 12));
|
||||
|
||||
// Cost function - minimize error
|
||||
final var r = new VariableMatrix(new double[][] {{10.0}});
|
||||
var J = new Variable(0.0);
|
||||
for (int k = 0; k < N + 1; ++k) {
|
||||
J = J.plus(r.minus(X.col(k)).T().times(r.minus(X.col(k))).get(0));
|
||||
}
|
||||
problem.minimize(J);
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
// Voltage for steady-state velocity:
|
||||
//
|
||||
// rₖ₊₁ = Arₖ + Buₖ
|
||||
// uₖ = B⁺(rₖ₊₁ − Arₖ)
|
||||
// uₖ = B⁺(rₖ − Arₖ)
|
||||
// uₖ = B⁺(I − A)rₖ
|
||||
double u_ss = 1.0 / B * (1.0 - A) * r.value(0);
|
||||
|
||||
// Verify initial state
|
||||
assertEquals(0.0, X.value(0, 0), 1e-8);
|
||||
|
||||
// Verify solution
|
||||
double x = 0.0;
|
||||
double u;
|
||||
for (int k = 0; k < N; ++k) {
|
||||
// Verify state
|
||||
assertEquals(x, X.value(0, k), 1e-2);
|
||||
|
||||
// Determine expected input for this timestep
|
||||
double error = r.value(0) - x;
|
||||
if (error > 1e-2) {
|
||||
// Max control input until the reference is reached
|
||||
u = 12.0;
|
||||
} else {
|
||||
// Maintain speed
|
||||
u = u_ss;
|
||||
}
|
||||
|
||||
// Verify input
|
||||
if (k > 0
|
||||
&& k < N - 1
|
||||
&& near(12.0, U.value(0, k - 1), 1e-2)
|
||||
&& near(u_ss, U.value(0, k + 1), 1e-2)) {
|
||||
// If control input is transitioning between 12 and u_ss, ensure it's
|
||||
// within (u_ss, 12)
|
||||
assertTrue(U.value(0, k) >= u_ss);
|
||||
assertTrue(U.value(0, k) <= 12.0);
|
||||
} else {
|
||||
assertEquals(u, U.value(0, k), 1e-4);
|
||||
}
|
||||
|
||||
// Project state forward
|
||||
x = A * x + B * u;
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
assertEquals(r.value(0), X.value(0, N), 2e-7);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
import static org.wpilib.math.optimization.Constraints.le;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class LinearProblemTest {
|
||||
@Test
|
||||
void testMaximize() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
x.setValue(1.0);
|
||||
y.setValue(1.0);
|
||||
|
||||
problem.maximize(x.times(50).plus(y.times(40)));
|
||||
|
||||
problem.subjectTo(le(x.plus(y.times(1.5)), 750));
|
||||
problem.subjectTo(le(x.times(2).plus(y.times(3)), 1500));
|
||||
problem.subjectTo(le(x.times(2).plus(y), 1000));
|
||||
problem.subjectTo(ge(x, 0));
|
||||
problem.subjectTo(ge(y, 0));
|
||||
|
||||
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(375.0, x.value(), 1e-6);
|
||||
assertEquals(250.0, y.value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFreeVariable() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable(2);
|
||||
x.get(0).setValue(1.0);
|
||||
x.get(1).setValue(2.0);
|
||||
|
||||
problem.subjectTo(eq(x.get(0), 0));
|
||||
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(0.0, x.get(0).value(), 1e-6);
|
||||
assertEquals(2.0, x.get(1).value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.wpilib.math.DoubleRange.range;
|
||||
import static org.wpilib.math.autodiff.Variable.hypot;
|
||||
import static org.wpilib.math.autodiff.Variable.pow;
|
||||
import static org.wpilib.math.autodiff.Variable.sqrt;
|
||||
import static org.wpilib.math.optimization.Constraints.bounds;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
import static org.wpilib.math.optimization.Constraints.le;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class NonlinearProblemTest {
|
||||
@Test
|
||||
void testQuartic() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
x.setValue(20.0);
|
||||
|
||||
problem.minimize(pow(x, 4));
|
||||
|
||||
problem.subjectTo(ge(x, 1));
|
||||
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(1.0, x.value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
private boolean near(double expected, double actual, double tolerance) {
|
||||
return Math.abs(expected - actual) < tolerance;
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRosenbrockWithCubicAndLineConstraint() {
|
||||
// https://en.wikipedia.org/wiki/Test_functions_for_optimization#Test_functions_for_constrained_optimization
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
|
||||
problem.minimize(
|
||||
pow(y.minus(pow(x, 2)), 2).times(100).plus(pow(new Variable(1).minus(x), 2)));
|
||||
|
||||
problem.subjectTo(ge(y, pow(x.minus(1), 3).plus(1)));
|
||||
problem.subjectTo(le(y, x.unaryMinus().plus(2)));
|
||||
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType());
|
||||
|
||||
for (var x0 : range(-1.5, 1.5, 0.1)) {
|
||||
for (var y0 : range(-0.5, 2.5, 0.1)) {
|
||||
x.setValue(x0);
|
||||
y.setValue(y0);
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
// Local minimum at (0.0, 0.0)
|
||||
// Global minimum at (1.0, 1.0)
|
||||
assertTrue(near(0.0, x.value(), 1e-2) || near(1.0, x.value(), 1e-2));
|
||||
assertTrue(near(0.0, y.value(), 1e-2) || near(1.0, y.value(), 1e-2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRosenbrockWithDiskConstraint() {
|
||||
// https://en.wikipedia.org/wiki/Test_functions_for_optimization#Test_functions_for_constrained_optimization
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
|
||||
problem.minimize(
|
||||
pow(new Variable(1).minus(x), 2).plus(pow(y.minus(pow(x, 2)), 2).times(100)));
|
||||
|
||||
problem.subjectTo(le(pow(x, 2).plus(pow(y, 2)), 2));
|
||||
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.inequalityConstraintType());
|
||||
|
||||
for (var x0 : range(-1.5, 1.5, 0.1)) {
|
||||
for (var y0 : range(-1.5, 1.5, 0.1)) {
|
||||
x.setValue(x0);
|
||||
y.setValue(y0);
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(1.0, x.value(), 1e-3);
|
||||
assertEquals(1.0, y.value(), 1e-3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMinimum2DDistanceWithLinearConstraint() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
x.setValue(20.0);
|
||||
y.setValue(50.0);
|
||||
|
||||
problem.minimize(sqrt(x.times(x).plus(y.times(y))));
|
||||
|
||||
problem.subjectTo(eq(y, x.unaryMinus().plus(5.0)));
|
||||
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(2.5, x.value(), 1e-2);
|
||||
assertEquals(2.5, y.value(), 1e-2);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testConflictingBounds() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
|
||||
problem.minimize(hypot(x, y));
|
||||
|
||||
problem.subjectTo(le(hypot(x, y), 1));
|
||||
problem.subjectTo(bounds(0.5, x, -0.5));
|
||||
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONLINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.GLOBALLY_INFEASIBLE, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testWachterAndBieglerLineSearchFailure() {
|
||||
// See example 19.2 of [1]
|
||||
//
|
||||
// [1] Nocedal, J. and Wright, S. "Numerical Optimization", 2nd. ed., Ch. 19. Springer, 2006.
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var s1 = problem.decisionVariable();
|
||||
var s2 = problem.decisionVariable();
|
||||
|
||||
x.setValue(-2);
|
||||
s1.setValue(3);
|
||||
s2.setValue(1);
|
||||
|
||||
problem.minimize(x);
|
||||
|
||||
problem.subjectTo(eq(pow(x, 2).minus(s1).minus(1), 0));
|
||||
problem.subjectTo(eq(x.minus(s2).minus(0.5), 0));
|
||||
problem.subjectTo(ge(s1, 0));
|
||||
problem.subjectTo(ge(s2, 0));
|
||||
|
||||
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(1.0, x.value(), 1e-6);
|
||||
assertEquals(0.0, s1.value(), 1e-6);
|
||||
assertEquals(0.5, s2.value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class ProblemJNITest {
|
||||
@Test
|
||||
public void testLink() {
|
||||
assertDoesNotThrow(ProblemJNI::forceLoad);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class QuadraticProblemTest {
|
||||
@Test
|
||||
void testUnconstrained1D() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
x.setValue(2.0);
|
||||
|
||||
problem.minimize(x.times(x).minus(x.times(6.0)));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(3.0, x.value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testUnconstrained2D() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
x.setValue(1.0);
|
||||
y.setValue(2.0);
|
||||
|
||||
problem.minimize(x.times(x).plus(y.times(y)));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(0.0, x.value(), 1e-6);
|
||||
assertEquals(0.0, y.value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable(2);
|
||||
x.get(0).setValue(1.0);
|
||||
x.get(1).setValue(2.0);
|
||||
|
||||
problem.minimize(x.T().times(x).get(0));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(0.0, x.value(0), 1e-6);
|
||||
assertEquals(0.0, x.value(1), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEqualityConstrained() {
|
||||
// Maximize xy subject to x + 3y = 36.
|
||||
//
|
||||
// Maximize f(x,y) = xy
|
||||
// subject to g(x,y) = x + 3y - 36 = 0
|
||||
//
|
||||
// value func constraint
|
||||
// | |
|
||||
// v v
|
||||
// L(x,y,λ) = f(x,y) - λg(x,y)
|
||||
// L(x,y,λ) = xy - λ(x + 3y - 36)
|
||||
// L(x,y,λ) = xy - xλ - 3yλ + 36λ
|
||||
//
|
||||
// ∇_x,y,λ L(x,y,λ) = 0
|
||||
//
|
||||
// ∂L/∂x = y - λ
|
||||
// ∂L/∂y = x - 3λ
|
||||
// ∂L/∂λ = -x - 3y + 36
|
||||
//
|
||||
// 0x + 1y - 1λ = 0
|
||||
// 1x + 0y - 3λ = 0
|
||||
// -1x - 3y + 0λ + 36 = 0
|
||||
//
|
||||
// [ 0 1 -1][x] [ 0]
|
||||
// [ 1 0 -3][y] = [ 0]
|
||||
// [-1 -3 0][λ] [-36]
|
||||
//
|
||||
// Solve with:
|
||||
//
|
||||
// ```python
|
||||
// np.linalg.solve(
|
||||
// np.array([[0,1,-1],
|
||||
// [1,0,-3],
|
||||
// [-1,-3,0]]),
|
||||
// np.array([[0], [0], [-36]]))
|
||||
// ```
|
||||
//
|
||||
// [x] [18]
|
||||
// [y] = [ 6]
|
||||
// [λ] [ 6]
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
|
||||
problem.maximize(x.times(y));
|
||||
|
||||
problem.subjectTo(eq(x.plus(y.times(3)), 36));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(18.0, x.value(), 1e-5);
|
||||
assertEquals(6.0, y.value(), 1e-5);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable(2);
|
||||
x.get(0).setValue(1.0);
|
||||
x.get(1).setValue(2.0);
|
||||
|
||||
problem.minimize(x.T().times(x).get(0));
|
||||
|
||||
problem.subjectTo(eq(x, new double[][] {{3.0}, {3.0}}));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(3.0, x.value(0), 1e-5);
|
||||
assertEquals(3.0, x.value(1), 1e-5);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testInequalityConstrained2D() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
x.setValue(5.0);
|
||||
y.setValue(5.0);
|
||||
|
||||
problem.minimize(x.times(x).plus(y.times(2).times(y)));
|
||||
problem.subjectTo(ge(y, x.unaryMinus().plus(5)));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
assertEquals(3.0 + 1.0 / 3.0, x.value(), 1e-6);
|
||||
assertEquals(1.0 + 2.0 / 3.0, y.value(), 1e-6);
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.solver.ExitStatus;
|
||||
|
||||
class TrivialProblemTest {
|
||||
@Test
|
||||
void testEmpty() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNoCostUnconstrained() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
@SuppressWarnings("VariableDeclarationUsageDistance")
|
||||
var X = problem.decisionVariable(2, 3);
|
||||
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
for (int row = 0; row < X.rows(); ++row) {
|
||||
for (int col = 0; col < X.cols(); ++col) {
|
||||
assertEquals(0.0, X.value(row, col));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var X = problem.decisionVariable(2, 3);
|
||||
X.setValue(SimpleMatrix.ones(2, 3));
|
||||
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
for (int row = 0; row < X.rows(); ++row) {
|
||||
for (int col = 0; col < X.cols(); ++col) {
|
||||
assertEquals(1.0, X.value(row, col));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
package org.wpilib.math.optimization.solver;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.wpilib.math.autodiff.Variable.sqrt;
|
||||
import static org.wpilib.math.optimization.Constraints.eq;
|
||||
import static org.wpilib.math.optimization.Constraints.ge;
|
||||
import static org.wpilib.math.optimization.Constraints.gt;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.wpilib.math.autodiff.ExpressionType;
|
||||
import org.wpilib.math.autodiff.Variable;
|
||||
import org.wpilib.math.optimization.Problem;
|
||||
|
||||
// These tests ensure coverage of the off-nominal exit statuses
|
||||
|
||||
class ExitStatusTest {
|
||||
@Test
|
||||
void testCallbackRequestedStop() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.minimize(x.times(x));
|
||||
|
||||
problem.addCallback(info -> false);
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
|
||||
problem.addCallback(info -> true);
|
||||
assertEquals(ExitStatus.CALLBACK_REQUESTED_STOP, problem.solve());
|
||||
|
||||
problem.clearCallbacks();
|
||||
problem.addCallback(info -> false);
|
||||
assertEquals(ExitStatus.SUCCESS, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTooFewDOFs() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
var z = problem.decisionVariable();
|
||||
|
||||
problem.subjectTo(eq(x, 1));
|
||||
problem.subjectTo(eq(x, 2));
|
||||
problem.subjectTo(eq(y, 1));
|
||||
problem.subjectTo(eq(z, 1));
|
||||
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.TOO_FEW_DOFS, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testLocallyInfeasible() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// Equality constraints
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
var z = problem.decisionVariable();
|
||||
|
||||
problem.subjectTo(eq(x, y.plus(1)));
|
||||
problem.subjectTo(eq(y, z.plus(1)));
|
||||
problem.subjectTo(eq(z, x.plus(1)));
|
||||
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.LOCALLY_INFEASIBLE, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// Inequality constraints
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
var y = problem.decisionVariable();
|
||||
var z = problem.decisionVariable();
|
||||
|
||||
problem.subjectTo(ge(x, y.plus(1)));
|
||||
problem.subjectTo(ge(y, z.plus(1)));
|
||||
problem.subjectTo(ge(z, x.plus(1)));
|
||||
|
||||
assertEquals(ExpressionType.NONE, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.LINEAR, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.LOCALLY_INFEASIBLE, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNonfiniteInitialGuess() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
// Nonfinite cost
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.minimize(new Variable(1).div(x));
|
||||
|
||||
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
|
||||
}
|
||||
|
||||
// Nonfinite gradient
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.minimize(sqrt(x));
|
||||
|
||||
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
|
||||
}
|
||||
|
||||
// Nonfinite equality constraint
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.subjectTo(eq(new Variable(1).div(x), 1));
|
||||
|
||||
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
|
||||
}
|
||||
|
||||
// Nonfinite equality constraint Jacobian
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.subjectTo(eq(sqrt(x), 1));
|
||||
|
||||
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
|
||||
}
|
||||
|
||||
// Nonfinite inequality constraint
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.subjectTo(gt(new Variable(1).div(x), 1));
|
||||
|
||||
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
|
||||
}
|
||||
|
||||
// Nonfinite inequality constraint Jacobian
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
problem.subjectTo(gt(sqrt(x), 1));
|
||||
|
||||
assertEquals(ExitStatus.NONFINITE_INITIAL_GUESS, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDivergingIterates() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
|
||||
problem.minimize(x);
|
||||
|
||||
assertEquals(ExpressionType.LINEAR, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.DIVERGING_ITERATES, problem.solve());
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMaxIterationsExceeded() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
|
||||
problem.minimize(x.times(x));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(
|
||||
ExitStatus.MAX_ITERATIONS_EXCEEDED, problem.solve(new Options().withMaxIterations(0)));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testTimeout() {
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
|
||||
try (var problem = new Problem()) {
|
||||
var x = problem.decisionVariable();
|
||||
|
||||
problem.minimize(x.times(x));
|
||||
|
||||
assertEquals(ExpressionType.QUADRATIC, problem.costFunctionType());
|
||||
assertEquals(ExpressionType.NONE, problem.equalityConstraintType());
|
||||
assertEquals(ExpressionType.NONE, problem.inequalityConstraintType());
|
||||
|
||||
assertEquals(ExitStatus.TIMEOUT, problem.solve(new Options().withTimeout(0.0)));
|
||||
}
|
||||
|
||||
assertEquals(0, Variable.totalNativeMemoryUsage());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#include "wpi/math/optimization/CurrentManager.hpp"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
TEST(CurrentManagerTest, EnoughCurrent) {
|
||||
wpi::math::CurrentManager manager{std::array{1.0, 5.0, 10.0, 5.0}, 40.0};
|
||||
|
||||
auto currents = manager.calculate(std::array{25.0, 10.0, 5.0, 0.0});
|
||||
|
||||
EXPECT_NEAR(currents[0], 25.0, 1e-3);
|
||||
EXPECT_NEAR(currents[1], 10.0, 1e-3);
|
||||
EXPECT_NEAR(currents[2], 5.0, 1e-3);
|
||||
EXPECT_NEAR(currents[3], 0.0, 1e-3);
|
||||
}
|
||||
|
||||
TEST(CurrentManagerTest, NotEnoughCurrent) {
|
||||
wpi::math::CurrentManager manager{std::array{1.0, 5.0, 10.0, 5.0}, 40.0};
|
||||
|
||||
auto currents = manager.calculate(std::array{30.0, 10.0, 5.0, 0.0});
|
||||
|
||||
// Expected values are from the following program:
|
||||
//
|
||||
// #!/usr/bin/env python3
|
||||
//
|
||||
// from scipy.optimize import minimize
|
||||
//
|
||||
// r = [30.0, 10.0, 5.0, 0.0]
|
||||
// q = [1.0, 5.0, 10.0, 5.0]
|
||||
//
|
||||
// result = minimize(
|
||||
// lambda x: sum((r[i] - x[i]) ** 2 / q[i] ** 2 for i in range(4)),
|
||||
// [0.0, 0.0, 0.0, 0.0],
|
||||
// constraints=[
|
||||
// {"type": "ineq", "fun": lambda x: x},
|
||||
// {"type": "ineq", "fun": lambda x: 40.0 - sum(x)},
|
||||
// ],
|
||||
// )
|
||||
// print(result.x)
|
||||
EXPECT_NEAR(currents[0], 29.960, 1e-3);
|
||||
EXPECT_NEAR(currents[1], 9.008, 1e-3);
|
||||
EXPECT_NEAR(currents[2], 1.032, 1e-3);
|
||||
EXPECT_NEAR(currents[3], 0.0, 1e-3);
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) FIRST and other WPILib contributors.
|
||||
// Open Source Software; you can modify and/or share it under the terms of
|
||||
// the WPILib BSD license file in the root directory of this project.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <span>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include <sleipnir/optimization/problem.hpp>
|
||||
|
||||
namespace wpi::math {
|
||||
|
||||
/**
|
||||
* This class computes the optimal current allocation for a list of subsystems
|
||||
* given a list of their desired currents and current tolerances that determine
|
||||
* which subsystem gets less current if the current budget is exceeded.
|
||||
* Subsystems with a smaller tolerance are given higher priority.
|
||||
*/
|
||||
class CurrentManager {
|
||||
public:
|
||||
/**
|
||||
* Constructs a CurrentManager.
|
||||
*
|
||||
* @param currentTolerances The relative current tolerance of each subsystem.
|
||||
* @param maxCurrent The current budget to allocate between subsystems.
|
||||
*/
|
||||
CurrentManager(std::span<const double> currentTolerances, double maxCurrent)
|
||||
: m_desiredCurrents{static_cast<int>(currentTolerances.size()), 1},
|
||||
m_allocatedCurrents{
|
||||
m_problem.decision_variable(currentTolerances.size())} {
|
||||
// Ensure m_desiredCurrents contains initialized Variables
|
||||
for (int row = 0; row < m_desiredCurrents.rows(); ++row) {
|
||||
// Don't initialize to 0 or 1, because those will get folded by Sleipnir
|
||||
m_desiredCurrents[row] = std::numeric_limits<double>::infinity();
|
||||
}
|
||||
|
||||
slp::Variable J = 0.0;
|
||||
slp::Variable current_sum = 0.0;
|
||||
for (size_t i = 0; i < currentTolerances.size(); ++i) {
|
||||
// The weight is 1/tolᵢ² where tolᵢ is the tolerance between the desired
|
||||
// and allocated current for subsystem i
|
||||
auto error = m_desiredCurrents[i] - m_allocatedCurrents[i];
|
||||
J += error * error / (currentTolerances[i] * currentTolerances[i]);
|
||||
|
||||
current_sum += m_allocatedCurrents[i];
|
||||
|
||||
// Currents must be nonnegative
|
||||
m_problem.subject_to(m_allocatedCurrents[i] >= 0.0);
|
||||
}
|
||||
m_problem.minimize(J);
|
||||
|
||||
// Keep total current below maximum
|
||||
m_problem.subject_to(current_sum <= maxCurrent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the optimal current allocation for a list of subsystems given a
|
||||
* list of their desired currents and current tolerances that determine which
|
||||
* subsystem gets less current if the current budget is exceeded. Subsystems
|
||||
* with a smaller tolerance are given higher priority.
|
||||
*
|
||||
* @param desiredCurrents The desired current for each subsystem.
|
||||
* @throws std::runtime_error if the number of desired currents doesn't equal
|
||||
* the number of tolerances passed in the constructor.
|
||||
*/
|
||||
std::vector<double> calculate(std::span<const double> desiredCurrents) {
|
||||
if (m_desiredCurrents.rows() != static_cast<int>(desiredCurrents.size())) {
|
||||
throw std::runtime_error(
|
||||
"Number of desired currents must equal the number of tolerances "
|
||||
"passed in the constructor.");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < desiredCurrents.size(); ++i) {
|
||||
m_desiredCurrents[i].set_value(desiredCurrents[i]);
|
||||
}
|
||||
|
||||
m_problem.solve();
|
||||
|
||||
std::vector<double> result;
|
||||
for (size_t i = 0; i < desiredCurrents.size(); ++i) {
|
||||
result.emplace_back(std::max(m_allocatedCurrents.value(i), 0.0));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
slp::Problem<double> m_problem;
|
||||
slp::VariableMatrix<double> m_desiredCurrents;
|
||||
slp::VariableMatrix<double> m_allocatedCurrents;
|
||||
};
|
||||
|
||||
} // namespace wpi::math
|
||||
Reference in New Issue
Block a user