[wpimath] Add Sleipnir Java bindings (#8236)

The wrapper includes reverse mode autodiff, the Problem DSL, and the
optimal control problem API. I wrote it by directly translating the
upstream
[API](https://github.com/SleipnirGroup/Sleipnir/tree/main/include/sleipnir)
and [tests](https://github.com/SleipnirGroup/Sleipnir/tree/main/test) to
Java (i.e., copy-paste-modify).

I replaced the ArmFeedforward and Ellipse2d JNIs with implementations
using the Sleipnir Java bindings. Switching dev binary JNIs to release
by default sped up wpimath test runs from several minutes to 7 seconds.
This commit is contained in:
Tyler Veness
2026-03-29 22:34:21 -07:00
committed by GitHub
parent 3e821b9448
commit d248c040bf
84 changed files with 13405 additions and 170 deletions

View File

@@ -0,0 +1,47 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
/**
* Expression type.
*
* <p>Used for autodiff caching.
*/
public enum ExpressionType {
/** There is no expression. */
NONE(0),
/** The expression is a constant. */
CONSTANT(1),
/** The expression is composed of linear and lower-order operators. */
LINEAR(2),
/** The expression is composed of quadratic and lower-order operators. */
QUADRATIC(3),
/** The expression is composed of nonlinear and lower-order operators. */
NONLINEAR(4);
/** ExpressionType value. */
public final int value;
ExpressionType(int value) {
this.value = value;
}
/**
* Converts integer to its corresponding enum value.
*
* @param x The integer.
* @return The enum value.
*/
public static ExpressionType fromInt(int x) {
return switch (x) {
case 0 -> ExpressionType.NONE;
case 1 -> ExpressionType.CONSTANT;
case 2 -> ExpressionType.LINEAR;
case 3 -> ExpressionType.QUADRATIC;
case 4 -> ExpressionType.NONLINEAR;
default -> null;
};
}
}

View File

@@ -0,0 +1,78 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.ejml.simple.SimpleMatrix;
/**
* This class calculates the gradient of a variable with respect to a vector of variables.
*
* <p>The gradient is only recomputed if the variable expression is quadratic or higher order.
*/
public class Gradient implements AutoCloseable {
private long m_handle;
private int m_rows;
/**
* Constructs a Gradient object.
*
* @param variable Variable of which to compute the gradient.
* @param wrt Variable with respect to which to compute the gradient.
*/
public Gradient(Variable variable, Variable wrt) {
this(variable, new VariableMatrix(wrt));
}
/**
* Constructs a Gradient object.
*
* @param variable Variable of which to compute the gradient.
* @param wrt Vector of variables with respect to which to compute the gradient.
*/
public Gradient(Variable variable, VariableMatrix wrt) {
assert wrt.cols() == 1;
m_handle = GradientJNI.create(variable.getHandle(), wrt.getHandles());
m_rows = wrt.rows();
}
/**
* Constructs a Gradient object.
*
* @param variable Variable of which to compute the gradient.
* @param wrt Vector of variables with respect to which to compute the gradient.
*/
public Gradient(Variable variable, VariableBlock wrt) {
this(variable, new VariableMatrix(wrt));
}
@Override
public void close() {
if (m_handle != 0) {
GradientJNI.destroy(m_handle);
m_handle = 0;
}
}
/**
* Returns the gradient as a VariableMatrix.
*
* <p>This is useful when constructing optimization problems with derivatives in them.
*
* @return The gradient as a VariableMatrix.
*/
public VariableMatrix get() {
return new VariableMatrix(m_rows, 1, GradientJNI.get(m_handle));
}
/**
* Evaluates the gradient at wrt's value.
*
* @return The gradient at wrt's value.
*/
public SimpleMatrix value() {
return GradientJNI.value(m_handle).toSimpleMatrix(m_rows, 1);
}
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.wpilib.math.jni.WPIMathJNI;
/** Gradient JNI functions. */
final class GradientJNI extends WPIMathJNI {
private GradientJNI() {
// Utility class.
}
/**
* Constructs a Gradient object.
*
* @param variable Variable of which to compute the Gradient.
* @param wrt Vector of variables with respect to which to compute the Gradient.
*/
static native long create(long variable, long[] wrt);
/**
* Destructs a Gradient.
*
* @param handle Gradient handle.
*/
static native void destroy(long handle);
/**
* Returns the Gradient as an array of Variable handles.
*
* <p>This is useful when constructing optimization problems with derivatives in them.
*
* @param handle Gradient handle.
* @return The Gradient as an array of Variable handles.
*/
static native long[] get(long handle);
/**
* Evaluates the Gradient at wrt's value.
*
* @param handle Gradient handle.
* @return A record containing the triplet row, column and value arrays (int[], int[], and
* double[] respectively).
*/
static native NativeSparseTriplets value(long handle);
}

View File

@@ -0,0 +1,79 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.ejml.simple.SimpleMatrix;
/**
* This class calculates the Hessian of a variable with respect to a vector of variables.
*
* <p>The gradient tree is cached so subsequent Hessian calculations are faster, and the Hessian is
* only recomputed if the variable expression is nonlinear.
*/
public class Hessian implements AutoCloseable {
private long m_handle;
private int m_rows;
/**
* Constructs a Hessian object.
*
* @param variable Variable of which to compute the Hessian.
* @param wrt Variable with respect to which to compute the Hessian.
*/
public Hessian(Variable variable, Variable wrt) {
this(variable, new VariableMatrix(wrt));
}
/**
* Constructs a Hessian object.
*
* @param variable Variable of which to compute the Hessian.
* @param wrt Vector of variables with respect to which to compute the Hessian.
*/
public Hessian(Variable variable, VariableMatrix wrt) {
assert wrt.cols() == 1;
m_handle = HessianJNI.create(variable.getHandle(), wrt.getHandles());
m_rows = wrt.rows();
}
/**
* Constructs a Hessian object.
*
* @param variable Variable of which to compute the Hessian.
* @param wrt Vector of variables with respect to which to compute the Hessian.
*/
public Hessian(Variable variable, VariableBlock wrt) {
this(variable, new VariableMatrix(wrt));
}
@Override
public void close() {
if (m_handle != 0) {
HessianJNI.destroy(m_handle);
m_handle = 0;
}
}
/**
* Returns the Hessian as a VariableMatrix.
*
* <p>This is useful when constructing optimization problems with derivatives in them.
*
* @return The Hessian as a VariableMatrix.
*/
public VariableMatrix get() {
return new VariableMatrix(m_rows, m_rows, HessianJNI.get(m_handle));
}
/**
* Evaluates the Hessian at wrt's value.
*
* @return The Hessian at wrt's value.
*/
public SimpleMatrix value() {
return HessianJNI.value(m_handle).toSimpleMatrix(m_rows, m_rows);
}
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.wpilib.math.jni.WPIMathJNI;
/** Hessian JNI functions. */
final class HessianJNI extends WPIMathJNI {
private HessianJNI() {
// Utility class.
}
/**
* Constructs a Hessian object.
*
* @param variable Variable of which to compute the Hessian.
* @param wrt Vector of variables with respect to which to compute the Hessian.
*/
static native long create(long variable, long[] wrt);
/**
* Destructs a Hessian.
*
* @param handle Hessian handle.
*/
static native void destroy(long handle);
/**
* Returns the Hessian as an array of Variable handles.
*
* <p>This is useful when constructing optimization problems with derivatives in them.
*
* @param handle Hessian handle.
* @return The Hessian as an array of Variable handles.
*/
static native long[] get(long handle);
/**
* Evaluates the Hessian at wrt's value.
*
* @param handle Hessian handle.
* @return A record containing the triplet row, column and value arrays (int[], int[], and
* double[] respectively).
*/
static native NativeSparseTriplets value(long handle);
}

View File

@@ -0,0 +1,102 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.ejml.simple.SimpleMatrix;
/**
* This class calculates the Jacobian of a vector of variables with respect to a vector of
* variables.
*
* <p>The Jacobian is only recomputed if the variable expression is quadratic or higher order.
*/
public class Jacobian implements AutoCloseable {
private long m_handle;
private int m_rows;
private int m_cols;
/**
* Constructs a Jacobian object.
*
* @param variable Variable of which to compute the Jacobian.
* @param wrt Variable with respect to which to compute the Jacobian.
*/
public Jacobian(Variable variable, Variable wrt) {
this(new VariableMatrix(variable), new VariableMatrix(wrt));
}
/**
* Constructs a Jacobian object.
*
* @param variable Variable of which to compute the Jacobian.
* @param wrt Vector of variables with respect to which to compute the Jacobian.
*/
public Jacobian(Variable variable, VariableMatrix wrt) {
this(new VariableMatrix(variable), wrt);
}
/**
* Constructs a Jacobian object.
*
* @param variable Variable of which to compute the Jacobian.
* @param wrt Vector of variables with respect to which to compute the Jacobian.
*/
public Jacobian(Variable variable, VariableBlock wrt) {
this(new VariableMatrix(variable), new VariableMatrix(wrt));
}
/**
* Constructs a Jacobian object.
*
* @param variables Vector of variables of which to compute the Jacobian.
* @param wrt Vector of variables with respect to which to compute the Jacobian.
*/
public Jacobian(VariableMatrix variables, VariableMatrix wrt) {
assert variables.cols() == 1;
assert wrt.cols() == 1;
m_handle = JacobianJNI.create(variables.getHandles(), wrt.getHandles());
m_rows = variables.rows();
m_cols = wrt.rows();
}
/**
* Constructs a Jacobian object.
*
* @param variables Vector of variables of which to compute the Jacobian.
* @param wrt Vector of variables with respect to which to compute the Jacobian.
*/
public Jacobian(VariableMatrix variables, VariableBlock wrt) {
this(variables, new VariableMatrix(wrt));
}
@Override
public void close() {
if (m_handle != 0) {
JacobianJNI.destroy(m_handle);
m_handle = 0;
}
}
/**
* Returns the Jacobian as a VariableMatrix.
*
* <p>This is useful when constructing optimization problems with derivatives in them.
*
* @return The Jacobian as a VariableMatrix.
*/
public VariableMatrix get() {
return new VariableMatrix(m_rows, m_cols, JacobianJNI.get(m_handle));
}
/**
* Evaluates the Jacobian at wrt's value.
*
* @return The Jacobian at wrt's value.
*/
public SimpleMatrix value() {
return JacobianJNI.value(m_handle).toSimpleMatrix(m_rows, m_cols);
}
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.wpilib.math.jni.WPIMathJNI;
/** Jacobian JNI functions. */
final class JacobianJNI extends WPIMathJNI {
private JacobianJNI() {
// Utility class.
}
/**
* Constructs a Jacobian object.
*
* @param variables Vector of variables of which to compute the Jacobian.
* @param wrt Vector of variables with respect to which to compute the Jacobian.
*/
static native long create(long[] variables, long[] wrt);
/**
* Destructs a Jacobian.
*
* @param handle Jacobian handle.
*/
static native void destroy(long handle);
/**
* Returns the Jacobian as an array of Variable handles.
*
* <p>This is useful when constructing optimization problems with derivatives in them.
*
* @param handle Jacobian handle.
* @return The Jacobian as an array of Variable handles.
*/
static native long[] get(long handle);
/**
* Evaluates the Jacobian at wrt's value.
*
* @param handle Jacobian handle.
* @return A record containing the triplet row, column and value arrays (int[], int[], and
* double[] respectively).
*/
static native NativeSparseTriplets value(long handle);
}

View File

@@ -0,0 +1,37 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.ejml.data.DMatrixSparseCSC;
import org.ejml.data.DMatrixSparseTriplet;
import org.ejml.ops.DConvertMatrixStruct;
import org.ejml.simple.SimpleMatrix;
/**
* Wrapper for sparse matrix triplets from JNI.
*
* <p>We can't use DMatrixSparseTriplet because it doesn't have a method for bulk-initialization
* from triplet arrays.
*
* @param rows Triplet rows.
* @param cols Triplet columns.
* @param values Triplet values.
*/
public record NativeSparseTriplets(int[] rows, int[] cols, double[] values) {
/**
* Returns a SimpleMatrix wrapper for this set of triplets.
*
* @param rows Number of rows in sparse SimpleMatrix.
* @param cols Number of columns in sparse SimpleMatrix.
* @return A SimpleMatrix wrapper for this set of triplets.
*/
public SimpleMatrix toSimpleMatrix(int rows, int cols) {
var ejmlTriplets = new DMatrixSparseTriplet(rows, cols, values().length);
for (int i = 0; i < values().length; ++i) {
ejmlTriplets.addItem(rows()[i], cols()[i], values()[i]);
}
return new SimpleMatrix(DConvertMatrixStruct.convert(ejmlTriplets, (DMatrixSparseCSC) null));
}
}

View File

@@ -0,0 +1,89 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import java.util.function.BiFunction;
/** Numerical integration utilities. */
public final class NumericalIntegration {
private NumericalIntegration() {
// Utility class.
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dt The time over which to integrate.
* @return the integration of dx/dt = f(x, u) for dt.
*/
public static VariableMatrix rk4(
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
VariableBlock x,
VariableBlock u,
double dt) {
return rk4(f, new VariableMatrix(x), new VariableMatrix(u), dt);
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dt The time over which to integrate.
* @return the integration of dx/dt = f(x, u) for dt.
*/
public static VariableMatrix rk4(
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
VariableBlock x,
VariableMatrix u,
double dt) {
return rk4(f, new VariableMatrix(x), u, dt);
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dt The time over which to integrate.
* @return the integration of dx/dt = f(x, u) for dt.
*/
public static VariableMatrix rk4(
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
VariableMatrix x,
VariableBlock u,
double dt) {
return rk4(f, x, new VariableMatrix(u), dt);
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dt The time over which to integrate.
* @return the integration of dx/dt = f(x, u) for dt.
*/
public static VariableMatrix rk4(
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> f,
VariableMatrix x,
VariableMatrix u,
double dt) {
var h = dt;
var k1 = f.apply(x, u);
var k2 = f.apply(x.plus(k1.times(h * 0.5)), u);
var k3 = f.apply(x.plus(k2.times(h * 0.5)), u);
var k4 = f.apply(x.plus(k3.times(h)), u);
return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(h / 6.0));
}
}

View File

@@ -0,0 +1,267 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import java.util.OptionalInt;
/** Represents a sequence of elements in an iterable object. */
@SuppressWarnings("PMD.UnusedFormalParameter")
public class Slice {
/** Type tag used to designate an omitted argument of the slice. */
public static class None {
/** Default constructor. */
public None() {}
}
/** Designates an omitted argument of the slice. */
public static final None __ = null;
/** Start index (inclusive). */
public int start = 0;
/** Stop index (exclusive). */
public int stop = 0;
/** Step. */
public int step = 1;
/** Constructs a Slice. */
public Slice() {}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
*/
public Slice(None start) {
this(OptionalInt.of(0), OptionalInt.of(Integer.MAX_VALUE), OptionalInt.of(1));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
*/
public Slice(int start) {
this.start = start;
this.stop = (start == -1) ? Integer.MAX_VALUE : start + 1;
this.step = 1;
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
*/
public Slice(None start, None stop) {
this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.of(1));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
*/
public Slice(None start, int stop) {
this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.of(1));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
*/
public Slice(int start, None stop) {
this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.of(1));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
*/
public Slice(int start, int stop) {
this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.of(1));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(None start, None stop, None step) {
this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.empty());
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(None start, None stop, int step) {
this(OptionalInt.empty(), OptionalInt.empty(), OptionalInt.of(step));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(None start, int stop, None step) {
this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.empty());
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(None start, int stop, int step) {
this(OptionalInt.empty(), OptionalInt.of(stop), OptionalInt.of(step));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(int start, None stop, None step) {
this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.empty());
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(int start, None stop, int step) {
this(OptionalInt.of(start), OptionalInt.empty(), OptionalInt.of(step));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(int start, int stop, None step) {
this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.empty());
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(int start, int stop, int step) {
this(OptionalInt.of(start), OptionalInt.of(stop), OptionalInt.of(step));
}
/**
* Constructs a slice.
*
* @param start Slice start index (inclusive).
* @param stop Slice stop index (exclusive).
* @param step Slice step.
*/
public Slice(OptionalInt start, OptionalInt stop, OptionalInt step) {
if (!step.isPresent()) {
this.step = 1;
} else {
assert step.getAsInt() != 0;
this.step = step.getAsInt();
}
// Avoid UB for step = -step if step is INT_MIN
if (this.step == Integer.MIN_VALUE) {
this.step = -Integer.MAX_VALUE;
}
if (!start.isPresent()) {
if (this.step < 0) {
this.start = Integer.MAX_VALUE;
} else {
this.start = 0;
}
} else {
this.start = start.getAsInt();
}
if (!stop.isPresent()) {
if (this.step < 0) {
this.stop = Integer.MIN_VALUE;
} else {
this.stop = Integer.MAX_VALUE;
}
} else {
this.stop = stop.getAsInt();
}
}
/**
* Adjusts start and end slice indices assuming a sequence of the specified length.
*
* @param length The sequence length.
* @return The slice length.
*/
public int adjust(int length) {
assert step != 0;
assert step >= -Integer.MAX_VALUE;
if (start < 0) {
start += length;
if (start < 0) {
start = (step < 0) ? -1 : 0;
}
} else if (start >= length) {
start = (step < 0) ? length - 1 : length;
}
if (stop < 0) {
stop += length;
if (stop < 0) {
stop = (step < 0) ? -1 : 0;
}
} else if (stop >= length) {
stop = (step < 0) ? length - 1 : length;
}
if (step < 0) {
if (stop < start) {
return (start - stop - 1) / -step + 1;
} else {
return 0;
}
} else {
if (start < stop) {
return (stop - start - 1) / step + 1;
} else {
return 0;
}
}
}
}

View File

@@ -0,0 +1,634 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
/** An autodiff variable pointing to an expression node. */
public class Variable implements AutoCloseable {
/** Handle type tag. */
public static final class Handle {
/** Constructor for Handle. */
public Handle() {}
}
/** Instance of handle type tag. */
public static final Handle HANDLE = new Handle();
private long m_handle;
/** Constructs a linear Variable with a value of zero. */
@SuppressWarnings("this-escape")
public Variable() {
m_handle = VariableJNI.createDefault();
VariablePool.register(this);
}
/**
* Constructs a Variable from a floating point type.
*
* @param value The value of the Variable.
*/
@SuppressWarnings("this-escape")
public Variable(double value) {
m_handle = VariableJNI.createDouble(value);
VariablePool.register(this);
}
/**
* Constructs a Variable from an integral type.
*
* @param value The value of the Variable.
*/
@SuppressWarnings("this-escape")
public Variable(int value) {
m_handle = VariableJNI.createInt(value);
VariablePool.register(this);
}
/**
* Constructs a Variable from the given handle.
*
* <p>This constructor is for internal use only.
*
* @param handleTypeTag Handle type tag.
* @param handle Variable handle.
*/
@SuppressWarnings({"PMD.UnusedFormalParameter", "this-escape"})
public Variable(Handle handleTypeTag, long handle) {
m_handle = handle;
VariablePool.register(this);
}
@Override
public void close() {
if (m_handle != 0) {
VariableJNI.destroy(m_handle);
m_handle = 0;
}
}
/**
* Sets Variable's internal value.
*
* @param value The value of the Variable.
*/
public void setValue(double value) {
VariableJNI.setValue(m_handle, value);
}
/**
* Returns the value of this variable.
*
* @return The value of this variable.
*/
public double value() {
return VariableJNI.value(m_handle);
}
/**
* Returns the type of this expression (constant, linear, quadratic, or nonlinear).
*
* @return The type of this expression.
*/
public ExpressionType type() {
return ExpressionType.fromInt(VariableJNI.type(m_handle));
}
/**
* Returns internal handle.
*
* @return Internal handle.
*/
public long getHandle() {
return m_handle;
}
/**
* Variable-Variable multiplication operator.
*
* @param rhs Operator right-hand side.
* @return Result of multiplication.
*/
public Variable times(Variable rhs) {
return new Variable(HANDLE, VariableJNI.times(m_handle, rhs.getHandle()));
}
/**
* Variable-Variable multiplication operator.
*
* @param rhs Operator right-hand side.
* @return Result of multiplication.
*/
public Variable times(double rhs) {
return times(new Variable(rhs));
}
/**
* Variable-Variable division operator.
*
* @param rhs Operator right-hand side.
* @return Result of division.
*/
public Variable div(Variable rhs) {
return new Variable(HANDLE, VariableJNI.div(m_handle, rhs.getHandle()));
}
/**
* Variable-Variable division operator.
*
* @param rhs Operator right-hand side.
* @return Result of division.
*/
public Variable div(double rhs) {
return div(new Variable(rhs));
}
/**
* Variable-Variable addition operator.
*
* @param rhs Operator right-hand side.
* @return Result of addition.
*/
public Variable plus(Variable rhs) {
return new Variable(HANDLE, VariableJNI.plus(m_handle, rhs.getHandle()));
}
/**
* Variable-Variable addition operator.
*
* @param rhs Operator right-hand side.
* @return Result of addition.
*/
public Variable plus(double rhs) {
return plus(new Variable(rhs));
}
/**
* Variable-Variable subtraction operator.
*
* @param rhs Operator right-hand side.
* @return Result of subtraction.
*/
public Variable minus(Variable rhs) {
return new Variable(HANDLE, VariableJNI.minus(m_handle, rhs.getHandle()));
}
/**
* Variable-Variable subtraction operator.
*
* @param rhs Operator right-hand side.
* @return Result of subtraction.
*/
public Variable minus(double rhs) {
return minus(new Variable(rhs));
}
/**
* Unary minus operator.
*
* @return Result of unary minus.
*/
public Variable unaryMinus() {
return new Variable(HANDLE, VariableJNI.unaryMinus(m_handle));
}
/**
* Unary plus operator.
*
* @return Result of unary plus.
*/
public Variable unaryPlus() {
return this;
}
/**
* Math.abs() for Variables.
*
* @param x The argument.
* @return Result of abs().
*/
public static Variable abs(Variable x) {
return new Variable(HANDLE, VariableJNI.abs(x.getHandle()));
}
/**
* Math.acos() for Variables.
*
* @param x The argument.
* @return Result of acos().
*/
public static Variable acos(Variable x) {
return new Variable(HANDLE, VariableJNI.acos(x.getHandle()));
}
/**
* Math.asin() for Variables.
*
* @param x The argument.
* @return Result of asin().
*/
public static Variable asin(Variable x) {
return new Variable(HANDLE, VariableJNI.asin(x.getHandle()));
}
/**
* Math.atan() for Variables.
*
* @param x The argument.
* @return Result of atan().
*/
public static Variable atan(Variable x) {
return new Variable(HANDLE, VariableJNI.atan(x.getHandle()));
}
/**
* Math.atan2() for Variables.
*
* @param y The y argument.
* @param x The x argument.
* @return Result of atan2().
*/
public static Variable atan2(double y, Variable x) {
return atan2(new Variable(y), x);
}
/**
* Math.atan2() for Variables.
*
* @param y The y argument.
* @param x The x argument.
* @return Result of atan2().
*/
public static Variable atan2(Variable y, double x) {
return atan2(y, new Variable(x));
}
/**
* Math.atan2() for Variables.
*
* @param y The y argument.
* @param x The x argument.
* @return Result of atan2().
*/
public static Variable atan2(Variable y, Variable x) {
return new Variable(HANDLE, VariableJNI.atan2(y.getHandle(), x.getHandle()));
}
/**
* Math.cbrt() for Variables.
*
* @param x The argument.
* @return Result of cbrt().
*/
public static Variable cbrt(Variable x) {
return new Variable(HANDLE, VariableJNI.cbrt(x.getHandle()));
}
/**
* Math.cos() for Variables.
*
* @param x The argument.
* @return Result of cos().
*/
public static Variable cos(Variable x) {
return new Variable(HANDLE, VariableJNI.cos(x.getHandle()));
}
/**
* Math.cosh() for Variables.
*
* @param x The argument.
* @return Result of cosh().
*/
public static Variable cosh(Variable x) {
return new Variable(HANDLE, VariableJNI.cosh(x.getHandle()));
}
/**
* Math.exp() for Variables.
*
* @param x The argument.
* @return Result of exp().
*/
public static Variable exp(Variable x) {
return new Variable(HANDLE, VariableJNI.exp(x.getHandle()));
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @return Result of hypot().
*/
public static Variable hypot(double x, Variable y) {
return hypot(new Variable(x), y);
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @return Result of hypot().
*/
public static Variable hypot(Variable x, double y) {
return hypot(x, new Variable(y));
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @return Result of hypot().
*/
public static Variable hypot(Variable x, Variable y) {
return new Variable(HANDLE, VariableJNI.hypot(x.getHandle(), y.getHandle()));
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(double x, double y, Variable z) {
return hypot(new Variable(x), new Variable(y), z);
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(double x, Variable y, double z) {
return hypot(new Variable(x), y, new Variable(z));
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(double x, Variable y, Variable z) {
return hypot(new Variable(x), y, z);
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(Variable x, double y, double z) {
return hypot(x, new Variable(y), new Variable(z));
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(Variable x, double y, Variable z) {
return hypot(x, new Variable(y), z);
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(Variable x, Variable y, double z) {
return hypot(x, y, new Variable(z));
}
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
* @param z The z argument.
* @return Result of hypot().
*/
public static Variable hypot(Variable x, Variable y, Variable z) {
return sqrt(pow(x, 2).plus(pow(y, 2)).plus(pow(z, 2)));
}
/**
* Math.log() for Variables.
*
* @param x The argument.
* @return Result of log().
*/
public static Variable log(Variable x) {
return new Variable(HANDLE, VariableJNI.log(x.getHandle()));
}
/**
* Math.log10() for Variables.
*
* @param x The argument.
* @return Result of log10().
*/
public static Variable log10(Variable x) {
return new Variable(HANDLE, VariableJNI.log10(x.getHandle()));
}
/**
* Math.max() for Variables.
*
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
* @return Result of max().
*/
public static Variable max(double a, Variable b) {
return max(new Variable(a), b);
}
/**
* Math.max() for Variables.
*
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
* @return Result of max().
*/
public static Variable max(Variable a, double b) {
return max(a, new Variable(b));
}
/**
* Math.max() for Variables.
*
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
* @return Result of max().
*/
public static Variable max(Variable a, Variable b) {
return new Variable(HANDLE, VariableJNI.max(a.getHandle(), b.getHandle()));
}
/**
* min() for Variables.
*
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
* @return Result of min().
*/
public static Variable min(double a, Variable b) {
return min(new Variable(a), b);
}
/**
* min() for Variables.
*
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
* @return Result of min().
*/
public static Variable min(Variable a, double b) {
return min(a, new Variable(b));
}
/**
* min() for Variables.
*
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
* @return Result of min().
*/
public static Variable min(Variable a, Variable b) {
return new Variable(HANDLE, VariableJNI.min(a.getHandle(), b.getHandle()));
}
/**
* Math.pow() for Variables.
*
* @param base The base.
* @param power The power.
* @return Result of pow().
*/
public static Variable pow(double base, Variable power) {
return pow(new Variable(base), power);
}
/**
* Math.pow() for Variables.
*
* @param base The base.
* @param power The power.
* @return Result of pow().
*/
public static Variable pow(Variable base, double power) {
return pow(base, new Variable(power));
}
/**
* Math.pow() for Variables.
*
* @param base The base.
* @param power The power.
* @return Result of pow().
*/
public static Variable pow(Variable base, Variable power) {
return new Variable(HANDLE, VariableJNI.pow(base.getHandle(), power.getHandle()));
}
/**
* Math.signum() for Variables.
*
* @param x The argument.
* @return Result of signum().
*/
public static Variable signum(Variable x) {
return new Variable(HANDLE, VariableJNI.signum(x.getHandle()));
}
/**
* Math.sin() for Variables.
*
* @param x The argument.
* @return Result of sin().
*/
public static Variable sin(Variable x) {
return new Variable(HANDLE, VariableJNI.sin(x.getHandle()));
}
/**
* Math.sinh() for Variables.
*
* @param x The argument.
* @return Result of sinh().
*/
public static Variable sinh(Variable x) {
return new Variable(HANDLE, VariableJNI.sinh(x.getHandle()));
}
/**
* Math.sqrt() for Variables.
*
* @param x The argument.
* @return Result of sqrt().
*/
public static Variable sqrt(Variable x) {
return new Variable(HANDLE, VariableJNI.sqrt(x.getHandle()));
}
/**
* Math.tan() for Variables.
*
* @param x The argument.
* @return Result of tan().
*/
public static Variable tan(Variable x) {
return new Variable(HANDLE, VariableJNI.tan(x.getHandle()));
}
/**
* Math.tanh() for Variables.
*
* @param x The argument.
* @return Result of tanh().
*/
public static Variable tanh(Variable x) {
return new Variable(HANDLE, VariableJNI.tanh(x.getHandle()));
}
/**
* Returns the total native memory usage of all Variables in bytes.
*
* @return The total native memory usage of all Variables in bytes.
*/
public static long totalNativeMemoryUsage() {
return VariableJNI.totalNativeMemoryUsage();
}
}

View File

@@ -0,0 +1,826 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.UnaryOperator;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.ejml.simple.SimpleMatrix;
/** A submatrix of autodiff variables with reference semantics. */
public class VariableBlock implements Iterable<Variable> {
private final VariableMatrix m_mat;
private final Slice m_rowSlice;
private final int m_rowSliceLength;
private final Slice m_colSlice;
private final int m_colSliceLength;
/**
* Constructs a Variable block pointing to all of the given matrix.
*
* @param mat The matrix to which to point.
*/
public VariableBlock(VariableMatrix mat) {
this(mat, 0, 0, mat.rows(), mat.cols());
}
/**
* Constructs a Variable block pointing to a subset of the given matrix.
*
* @param mat The matrix to which to point.
* @param rowOffset The block's row offset.
* @param colOffset The block's column offset.
* @param blockRows The number of rows in the block.
* @param blockCols The number of columns in the block.
*/
public VariableBlock(
VariableMatrix mat, int rowOffset, int colOffset, int blockRows, int blockCols) {
m_mat = mat;
m_rowSlice = new Slice(rowOffset, rowOffset + blockRows, 1);
m_rowSliceLength = m_rowSlice.adjust(mat.rows());
m_colSlice = new Slice(colOffset, colOffset + blockCols, 1);
m_colSliceLength = m_colSlice.adjust(mat.cols());
}
/**
* Constructs a Variable block pointing to a subset of the given matrix.
*
* <p>Note that the slices are taken as is rather than adjusted.
*
* @param mat The matrix to which to point.
* @param rowSlice The block's row slice.
* @param rowSliceLength The block's row length.
* @param colSlice The block's column slice.
* @param colSliceLength The block's column length.
*/
public VariableBlock(
VariableMatrix mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) {
m_mat = mat;
m_rowSlice = rowSlice;
m_rowSliceLength = rowSliceLength;
m_colSlice = colSlice;
m_colSliceLength = colSliceLength;
}
/**
* Assigns a double to the block.
*
* <p>This only works for blocks with one row and one column.
*
* @param value Value to assign.
* @return This VariableBlock.
*/
public VariableBlock set(double value) {
assert rows() == 1 && cols() == 1;
set(0, 0, new Variable(value));
return this;
}
/**
* Assigns a Variable to the block.
*
* <p>This only works for blocks with one row and one column.
*
* @param value Value to assign.
* @return This VariableBlock.
*/
public VariableBlock set(Variable value) {
assert rows() == 1 && cols() == 1;
set(0, 0, value);
return this;
}
/**
* Assigns a double array to the block.
*
* @param values Double array of values to assign.
* @return This VariableBlock.
*/
public VariableBlock set(double[][] values) {
assert rows() == values.length;
// Assert all column counts are the same
for (var row : values) {
assert row.length == cols();
}
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
set(row, col, values[row][col]);
}
}
return this;
}
/**
* Assigns an EJML matrix to the block.
*
* @param values EJML matrix of values to assign.
* @return This VariableBlock.
*/
public VariableBlock set(SimpleMatrix values) {
assert rows() == values.getNumRows() && cols() == values.getNumCols();
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
set(row, col, values.get(row, col));
}
}
return this;
}
/**
* Assigns a VariableMatrix to the block.
*
* @param values VariableMatrix of values.
* @return This VariableBlock.
*/
public VariableBlock set(VariableMatrix values) {
assert rows() == values.rows() && cols() == values.cols();
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
set(row, col, values.get(row, col));
}
}
return this;
}
/**
* Assigns a VariableBlock to the block.
*
* @param values VariableBlock of values.
* @return This VariableBlock.
*/
public VariableBlock set(VariableBlock values) {
assert rows() == values.rows() && cols() == values.cols();
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
set(row, col, values.get(row, col));
}
}
return this;
}
/**
* Sets a scalar subblock at the given row and column.
*
* @param row The scalar subblock's row.
* @param col The scalar subblock's column.
* @param value The value.
*/
public void set(int row, int col, Variable value) {
assert row >= 0 && row < rows();
assert col >= 0 && col < cols();
m_mat.set(
m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
}
/**
* Sets a scalar subblock at the given row and column.
*
* @param row The scalar subblock's row.
* @param col The scalar subblock's column.
* @param value The value.
*/
public void set(int row, int col, double value) {
assert row >= 0 && row < rows();
assert col >= 0 && col < cols();
m_mat.set(
m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
}
/**
* Sets a scalar subblock at the given index.
*
* @param index The scalar subblock's index.
* @param value The value.
*/
public void set(int index, double value) {
set(index, new Variable(value));
}
/**
* Sets a scalar subblock at the given index.
*
* @param index The scalar subblock's index.
* @param value The value.
*/
public void set(int index, Variable value) {
assert index >= 0 && index < rows() * cols();
set(index / cols(), index % cols(), value);
}
/**
* Assigns a double to the block.
*
* <p>This only works for blocks with one row and one column.
*
* @param value Value to assign.
*/
public void setValue(double value) {
assert rows() == 1 && cols() == 1;
get(0, 0).setValue(value);
}
/**
* Sets block's internal values.
*
* @param values Double array of values.
*/
public void setValue(double[][] values) {
assert rows() == values.length;
// Assert all column counts are the same
for (var row : values) {
assert row.length == cols();
}
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
get(row, col).setValue(values[row][col]);
}
}
}
/**
* Sets block's internal values.
*
* @param values EJML matrix of values.
*/
public void setValue(SimpleMatrix values) {
assert rows() == values.getNumRows() && cols() == values.getNumCols();
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
get(row, col).setValue(values.get(row, col));
}
}
}
/**
* Returns a scalar subblock at the given row and column.
*
* @param row The scalar subblock's row.
* @param col The scalar subblock's column.
* @return A scalar subblock at the given row and column.
*/
public Variable get(int row, int col) {
assert row >= 0 && row < rows();
assert col >= 0 && col < cols();
return m_mat.get(
m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step);
}
/**
* Returns a scalar subblock at the given index.
*
* @param index The scalar subblock's index.
* @return A scalar subblock at the given index.
*/
public Variable get(int index) {
assert index >= 0 && index < rows() * cols();
return get(index / cols(), index % cols());
}
/**
* Returns a slice of the variable matrix.
*
* @param row The row.
* @param colSlice The column slice.
* @return A slice of the variable matrix.
*/
public VariableBlock get(int row, Slice.None colSlice) {
return get(new Slice(row), new Slice(colSlice));
}
/**
* Returns a slice of the variable matrix.
*
* @param row The row.
* @param colSlice The column slice.
* @return A slice of the variable matrix.
*/
public VariableBlock get(int row, Slice colSlice) {
return get(new Slice(row), colSlice);
}
/**
* Returns a slice of the variable matrix.
*
* @param rowSlice The row slice.
* @param col The column.
* @return A slice of the variable matrix.
*/
public VariableBlock get(Slice.None rowSlice, int col) {
return get(new Slice(rowSlice), new Slice(col));
}
/**
* Returns a slice of the variable matrix.
*
* @param rowSlice The row slice.
* @param col The column.
* @return A slice of the variable matrix.
*/
public VariableBlock get(Slice rowSlice, int col) {
return get(rowSlice, new Slice(col));
}
/**
* Returns a slice of the variable matrix.
*
* @param rowSlice The row slice.
* @param colSlice The column slice.
* @return A slice of the variable matrix.
*/
public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) {
return get(new Slice(rowSlice), new Slice(colSlice));
}
/**
* Returns a slice of the variable matrix.
*
* @param rowSlice The row slice.
* @param colSlice The column slice.
* @return A slice of the variable matrix.
*/
public VariableBlock get(Slice.None rowSlice, Slice colSlice) {
return get(new Slice(rowSlice), colSlice);
}
/**
* Returns a slice of the variable matrix.
*
* @param rowSlice The row slice.
* @param colSlice The column slice.
* @return A slice of the variable matrix.
*/
public VariableBlock get(Slice rowSlice, Slice.None colSlice) {
return get(rowSlice, new Slice(colSlice));
}
/**
* Returns a slice of the variable matrix.
*
* @param rowSlice The row slice.
* @param colSlice The column slice.
* @return A slice of the variable matrix.
*/
public VariableBlock get(Slice rowSlice, Slice colSlice) {
int rowSliceLength = rowSlice.adjust(m_rowSliceLength);
int colSliceLength = colSlice.adjust(m_colSliceLength);
return new VariableBlock(
m_mat,
new Slice(
m_rowSlice.start + rowSlice.start * m_rowSlice.step,
m_rowSlice.start + rowSlice.stop * m_rowSlice.step,
rowSlice.step * m_rowSlice.step),
rowSliceLength,
new Slice(
m_colSlice.start + colSlice.start * m_colSlice.step,
m_colSlice.start + colSlice.stop * m_colSlice.step,
colSlice.step * m_colSlice.step),
colSliceLength);
}
/**
* Returns a block of the variable matrix.
*
* @param rowOffset The row offset of the block selection.
* @param colOffset The column offset of the block selection.
* @param blockRows The number of rows in the block selection.
* @param blockCols The number of columns in the block selection.
* @return A block of the variable matrix.
*/
public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) {
assert rowOffset >= 0 && rowOffset <= rows();
assert colOffset >= 0 && colOffset <= cols();
assert blockRows >= 0 && blockRows <= rows() - rowOffset;
assert blockCols >= 0 && blockCols <= cols() - colOffset;
return get(
new Slice(rowOffset, rowOffset + blockRows, 1),
new Slice(colOffset, colOffset + blockCols, 1));
}
/**
* Returns a segment of the variable vector.
*
* @param offset The offset of the segment.
* @param length The length of the segment.
* @return A segment of the variable vector.
*/
public VariableBlock segment(int offset, int length) {
assert cols() == 1;
assert offset >= 0 && offset < rows();
assert length >= 0 && length <= rows() - offset;
return block(offset, 0, length, 1);
}
/**
* Returns a row slice of the variable matrix.
*
* @param row The row to slice.
* @return A row slice of the variable matrix.
*/
public VariableBlock row(int row) {
assert row >= 0 && row < rows();
return block(row, 0, 1, cols());
}
/**
* Returns a column slice of the variable matrix.
*
* @param col The column to slice.
* @return A column slice of the variable matrix.
*/
public VariableBlock col(int col) {
assert col >= 0 && col < cols();
return block(0, col, rows(), 1);
}
/**
* Matrix multiplication operator.
*
* @param rhs Operator right-hand side.
* @return Result of matrix multiplication.
*/
public VariableMatrix times(VariableMatrix rhs) {
assert cols() == rhs.rows();
var result = new VariableMatrix(rows(), rhs.cols());
for (int i = 0; i < rows(); ++i) {
for (int j = 0; j < rhs.cols(); ++j) {
var sum = new Variable(0.0);
for (int k = 0; k < cols(); ++k) {
sum = sum.plus(get(i, k).times(rhs.get(k, j)));
}
result.set(i, j, sum);
}
}
return result;
}
/**
* Matrix multiplication operator.
*
* @param rhs Operator right-hand side.
* @return Result of matrix multiplication.
*/
public VariableMatrix times(VariableBlock rhs) {
assert cols() == rhs.rows();
var result = new VariableMatrix(rows(), rhs.cols());
for (int i = 0; i < rows(); ++i) {
for (int j = 0; j < rhs.cols(); ++j) {
var sum = new Variable(0.0);
for (int k = 0; k < cols(); ++k) {
sum = sum.plus(get(i, k).times(rhs.get(k, j)));
}
result.set(i, j, sum);
}
}
return result;
}
/**
* Matrix-scalar multiplication operator.
*
* @param rhs Operator right-hand side.
* @return Result of matrix-scalar multiplication.
*/
public VariableMatrix times(double rhs) {
return times(new Variable(rhs));
}
/**
* Matrix-scalar multiplication operator.
*
* @param rhs Operator right-hand side.
* @return Result of matrix-scalar multiplication.
*/
public VariableMatrix times(Variable rhs) {
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).times(rhs));
}
}
return result;
}
/**
* Binary division operator.
*
* @param rhs Operator right-hand side.
* @return Result of division.
*/
public VariableMatrix div(double rhs) {
return div(new Variable(rhs));
}
/**
* Binary division operator.
*
* @param rhs Operator right-hand side.
* @return Result of division.
*/
public VariableMatrix div(Variable rhs) {
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).div(rhs));
}
}
return result;
}
/**
* Binary addition operator.
*
* @param rhs Operator right-hand side.
* @return Result of addition.
*/
public VariableMatrix plus(VariableMatrix rhs) {
assert rows() == rhs.rows() && cols() == rhs.cols();
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).plus(rhs.get(row, col)));
}
}
return result;
}
/**
* Binary addition operator.
*
* @param rhs Operator right-hand side.
* @return Result of addition.
*/
public VariableMatrix plus(VariableBlock rhs) {
assert rows() == rhs.rows() && cols() == rhs.cols();
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).plus(rhs.get(row, col)));
}
}
return result;
}
/**
* Binary addition operator.
*
* @param rhs Operator right-hand side.
* @return Result of addition.
*/
public VariableMatrix plus(SimpleMatrix rhs) {
assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).plus(rhs.get(row, col)));
}
}
return result;
}
/**
* Binary subtraction operator.
*
* @param rhs Operator right-hand side.
* @return Result of subtraction.
*/
public VariableMatrix minus(VariableMatrix rhs) {
assert rows() == rhs.rows() && cols() == rhs.cols();
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).minus(rhs.get(row, col)));
}
}
return result;
}
/**
* Binary subtraction operator.
*
* @param rhs Operator right-hand side.
* @return Result of subtraction.
*/
public VariableMatrix minus(VariableBlock rhs) {
assert rows() == rhs.rows() && cols() == rhs.cols();
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).minus(rhs.get(row, col)));
}
}
return result;
}
/**
* Binary subtraction operator.
*
* @param rhs Operator right-hand side.
* @return Result of subtraction.
*/
public VariableMatrix minus(SimpleMatrix rhs) {
assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).minus(rhs.get(row, col)));
}
}
return result;
}
/**
* Unary minus operator.
*
* @return Result of unary minus.
*/
public VariableMatrix unaryMinus() {
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
result.set(row, col, get(row, col).unaryMinus());
}
}
return result;
}
/**
* Returns the transpose of the variable matrix.
*
* @return The transpose of the variable matrix.
*/
public VariableMatrix T() {
var result = new VariableMatrix(cols(), rows());
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
result.set(col, row, get(row, col));
}
}
return result;
}
/**
* Returns the number of rows in the matrix.
*
* @return The number of rows in the matrix.
*/
public int rows() {
return m_rowSliceLength;
}
/**
* Returns the number of columns in the matrix.
*
* @return The number of columns in the matrix.
*/
public int cols() {
return m_colSliceLength;
}
/**
* Returns an element of the variable matrix.
*
* @param row The row of the element to return.
* @param col The column of the element to return.
* @return An element of the variable matrix.
*/
public double value(int row, int col) {
return get(row, col).value();
}
/**
* Returns an element of the variable block.
*
* @param index The index of the element to return.
* @return An element of the variable block.
*/
public double value(int index) {
return get(index).value();
}
/**
* Returns the contents of the variable matrix.
*
* @return The contents of the variable matrix.
*/
public SimpleMatrix value() {
var result = new SimpleMatrix(rows(), cols());
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
result.set(row, col, value(row, col));
}
}
return result;
}
/**
* Maps the matrix coefficient-wise with an unary operator.
*
* @param unaryOp The unary operator to use for the map operation.
* @return Result of the unary operator.
*/
public VariableMatrix cwiseMap(UnaryOperator<Variable> unaryOp) {
var result = new VariableMatrix(rows(), cols());
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
result.set(row, col, unaryOp.apply(get(row, col)));
}
}
return result;
}
/**
* Returns number of elements in matrix.
*
* @return Number of elements in matrix.
*/
public int size() {
return rows() * cols();
}
@Override
public Iterator<Variable> iterator() {
return new Iterator<>() {
private int m_index = 0;
@Override
public boolean hasNext() {
return m_index < VariableBlock.this.size();
}
@Override
public Variable next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
return VariableBlock.this.get(m_index++);
}
};
}
/**
* Creates a Stream of VariableBlock elements.
*
* @return A Stream of VariableBlock elements.
*/
public Stream<Variable> stream() {
return StreamSupport.stream(spliterator(), false);
}
}

View File

@@ -0,0 +1,268 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.wpilib.math.jni.WPIMathJNI;
/** Variable JNI functions. */
final class VariableJNI extends WPIMathJNI {
private VariableJNI() {
// Utility class.
}
/** Constructs a default Variable. */
static native long createDefault();
/**
* Constructs a Variable from a floating point type.
*
* @param value The value of the Variable.
*/
static native long createDouble(double value);
/**
* Constructs a Variable from an integral type.
*
* @param value The value of the Variable.
*/
static native long createInt(int value);
/**
* Destructs a Variable.
*
* @param handle Variable handle.
*/
static native void destroy(long handle);
/**
* Sets Variable's internal value.
*
* @param handle Variable handle.
* @param value The value of the Variable.
*/
static native void setValue(long handle, double value);
/**
* Variable-Variable multiplication operator.
*
* @param handle Variable handle.
* @param rhs Operator right-hand side.
* @return Result of multiplication.
*/
static native long times(long handle, long rhs);
/**
* Variable-Variable division operator.
*
* @param handle Variable handle.
* @param rhs Operator right-hand side.
* @return Result of division.
*/
static native long div(long handle, long rhs);
/**
* Variable-Variable addition operator.
*
* @param handle Variable handle.
* @param rhs Operator right-hand side.
* @return Result of addition.
*/
static native long plus(long handle, long rhs);
/**
* Variable-Variable subtraction operator.
*
* @param handle Variable handle.
* @param rhs Operator right-hand side.
* @return Result of subtraction.
*/
static native long minus(long handle, long rhs);
/**
* Unary minus operator.
*
* @param handle Variable handle.
*/
static native long unaryMinus(long handle);
/**
* Returns the value of this variable.
*
* @param handle Variable handle.
* @return The value of this variable.
*/
static native double value(long handle);
/**
* Returns the type of this expression (constant, linear, quadratic, or nonlinear).
*
* @param handle Variable handle.
* @return The type of this expression.
*/
static native int type(long handle);
/**
* Math.abs() for Variables.
*
* @param x The argument.
*/
static native long abs(long x);
/**
* Math.acos() for Variables.
*
* @param x The argument.
*/
static native long acos(long x);
/**
* Math.asin() for Variables.
*
* @param x The argument.
*/
static native long asin(long x);
/**
* Math.atan() for Variables.
*
* @param x The argument.
*/
static native long atan(long x);
/**
* Math.atan2() for Variables.
*
* @param y The y argument.
* @param x The x argument.
*/
static native long atan2(long y, long x);
/**
* Math.cbrt() for Variables.
*
* @param x The argument.
*/
static native long cbrt(long x);
/**
* Math.cos() for Variables.
*
* @param x The argument.
*/
static native long cos(long x);
/**
* Math.cosh() for Variables.
*
* @param x The argument.
*/
static native long cosh(long x);
/**
* Math.exp() for Variables.
*
* @param x The argument.
*/
static native long exp(long x);
/**
* Math.hypot() for Variables.
*
* @param x The x argument.
* @param y The y argument.
*/
static native long hypot(long x, long y);
/**
* Math.log() for Variables.
*
* @param x The argument.
*/
static native long log(long x);
/**
* Math.log10() for Variables.
*
* @param x The argument.
*/
static native long log10(long x);
/**
* Math.max() for Variables.
*
* <p>Returns the greater of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
*/
static native long max(long a, long b);
/**
* Math.min() for Variables.
*
* <p>Returns the lesser of a and b. If the values are equivalent, returns a.
*
* @param a The a argument.
* @param b The b argument.
*/
static native long min(long a, long b);
/**
* Math.pow() for Variables.
*
* @param base The base.
* @param power The power.
*/
static native long pow(long base, long power);
/**
* Math.signum() for Variables.
*
* @param x The argument.
*/
static native long signum(long x);
/**
* Math.sin() for Variables.
*
* @param x The argument.
*/
static native long sin(long x);
/**
* Math.sinh() for Variables.
*
* @param x The argument.
*/
static native long sinh(long x);
/**
* Math.sqrt() for Variables.
*
* @param x The argument.
*/
static native long sqrt(long x);
/**
* Math.tan() for Variables.
*
* @param x The argument.
*/
static native long tan(long x);
/**
* Math.tanh() for Variables.
*
* @param x The argument.
*/
static native long tanh(long x);
/**
* Returns the total native memory usage of Variables in bytes.
*
* @return The total native memory usage of Variables in bytes.
*/
static native long totalNativeMemoryUsage();
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import org.wpilib.math.jni.WPIMathJNI;
/** Variable JNI functions. */
final class VariableMatrixJNI extends WPIMathJNI {
private VariableMatrixJNI() {
// Utility class.
}
/**
* Solves the VariableMatrix equation AX = B for X.
*
* @param A The left-hand side as a flattened row-major matrix.
* @param Acols The number of columns in A.
* @param B The right-hand side as a flattened row-major matrix.
* @param Bcols The number of columns in B.
* @return The solution X as a flattened row-major matrix.
*/
static native long[] solve(long[] A, int Acols, long[] B, int Bcols);
}

View File

@@ -0,0 +1,55 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.autodiff;
import java.util.ArrayDeque;
import java.util.Deque;
import org.wpilib.util.ErrorMessages;
import org.wpilib.util.cleanup.CleanupPool;
/**
* Cleans up implicitly allocated Variables via try-with-resources.
*
* <p>This implements a stack of Variable pools containing a default global pool. The user can
* create additional pools via try-with-resources. Variable and VariableMatrix instances will
* register themselves with the latest pool.
*
* <p>It's strongly recommended to only instantiate this class via try-with-resources so the close()
* methods are always called in the correct order (i.e., nested scopes).
*/
public class VariablePool implements AutoCloseable {
private static Deque<VariablePool> s_variablePoolStack = new ArrayDeque<VariablePool>();
// Default global pool
@SuppressWarnings("PMD.UnusedPrivateField")
private static VariablePool s_globalPool = new VariablePool();
// Cleans up Variables in the scope of this VariablePool
private final CleanupPool m_cleanupPool = new CleanupPool();
/** Default constructor. */
@SuppressWarnings("this-escape")
public VariablePool() {
s_variablePoolStack.addFirst(this);
}
@Override
public void close() {
m_cleanupPool.close();
s_variablePoolStack.removeFirst();
}
/**
* Registers a Variable in the Variable stack for cleanup.
*
* @param variable The Variable to register
* @return The registered Variable
*/
public static Variable register(Variable variable) {
ErrorMessages.requireNonNullParam(variable, "variable", "register");
s_variablePoolStack.getFirst().m_cleanupPool.register(variable);
return variable;
}
}

View File

@@ -4,9 +4,17 @@
package org.wpilib.math.controller;
import static org.wpilib.math.autodiff.Variable.cos;
import static org.wpilib.math.autodiff.Variable.signum;
import org.wpilib.math.autodiff.Gradient;
import org.wpilib.math.autodiff.Hessian;
import org.wpilib.math.autodiff.NumericalIntegration;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.autodiff.VariablePool;
import org.wpilib.math.controller.proto.ArmFeedforwardProto;
import org.wpilib.math.controller.struct.ArmFeedforwardStruct;
import org.wpilib.math.jni.ArmFeedforwardJNI;
import org.wpilib.util.protobuf.ProtobufSerializable;
import org.wpilib.util.struct.StructSerializable;
@@ -191,8 +199,111 @@ public class ArmFeedforward implements ProtobufSerializable, StructSerializable
* @return The computed feedforward in volts.
*/
public double calculate(double currentAngle, double currentVelocity, double nextVelocity) {
return ArmFeedforwardJNI.calculate(
ks, kv, ka, kg, currentAngle, currentVelocity, nextVelocity, m_dt);
// Small kₐ values make the solver ill-conditioned
if (ka < 1e-1) {
double acceleration = (nextVelocity - currentVelocity) / m_dt;
return ks * Math.signum(currentVelocity)
+ kv * currentVelocity
+ ka * acceleration
+ kg * Math.cos(currentAngle);
}
try (var pool = new VariablePool()) {
// Arm dynamics
var A = new VariableMatrix(new double[][] {{0.0, 1.0}, {0.0, -kv / ka}});
var B = new VariableMatrix(new double[][] {{0.0}, {1.0 / ka}});
var r_k = new VariableMatrix(new double[][] {{currentAngle}, {currentVelocity}});
var u_k = new Variable();
// Initial guess
double acceleration = (nextVelocity - currentVelocity) / m_dt;
u_k.setValue(
ks * Math.signum(currentVelocity)
+ kv * currentVelocity
+ ka * acceleration
+ kg * Math.cos(currentAngle));
var r_k1 =
NumericalIntegration.rk4(
(VariableMatrix x, VariableMatrix u) -> {
var c =
new VariableMatrix(
new Variable[][] {
{new Variable(0.0)},
{signum(x.get(1)).times(-ks / ka).plus(cos(x.get(0)).times(-kg / ka))}
});
return A.times(x).plus(B.times(u)).plus(c);
},
r_k,
new VariableMatrix(u_k),
m_dt);
// Minimize difference between desired and actual next velocity
var cost =
new Variable(nextVelocity)
.minus(r_k1.get(1))
.times(new Variable(nextVelocity).minus(r_k1.get(1)));
// Refine solution via Newton's method
{
var xAD = u_k;
double x = xAD.value();
var gradientF = new Gradient(cost, xAD);
var g = gradientF.value();
var hessianF = new Hessian(cost, xAD);
var H = hessianF.value();
double error_k = Double.POSITIVE_INFINITY;
double error_k1 = Math.abs(g.get(0, 0));
// Loop until error stops decreasing or max iterations is reached
for (int iteration = 0; iteration < 50 && error_k1 < (1.0 - 1e-10) * error_k; ++iteration) {
error_k = error_k1;
// Iterate via Newton's method.
//
// xₖ₊₁ = xₖ H⁻¹g
//
// The Hessian is regularized to at least 1e-4.
double p_x = -g.get(0, 0) / Math.max(H.get(0, 0), 1e-4);
// Shrink step until cost goes down
{
double oldCost = cost.value();
double α = 1.0;
double trial_x = x + α * p_x;
xAD.setValue(trial_x);
while (cost.value() > oldCost) {
α *= 0.5;
trial_x = x + α * p_x;
xAD.setValue(trial_x);
}
x = trial_x;
}
xAD.setValue(x);
g = gradientF.value();
H = hessianF.value();
error_k1 = Math.abs(g.get(0, 0));
}
hessianF.close();
gradientF.close();
}
return u_k.value();
}
}
// Rearranging the main equation from the calculate() method yields the

View File

@@ -4,12 +4,14 @@
package org.wpilib.math.geometry;
import static org.wpilib.math.autodiff.Variable.pow;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.units.Units.Meters;
import java.util.Objects;
import org.wpilib.math.geometry.proto.Ellipse2dProto;
import org.wpilib.math.geometry.struct.Ellipse2dStruct;
import org.wpilib.math.jni.Ellipse2dJNI;
import org.wpilib.math.optimization.Problem;
import org.wpilib.math.util.Pair;
import org.wpilib.units.measure.Distance;
import org.wpilib.util.protobuf.ProtobufSerializable;
@@ -224,18 +226,38 @@ public class Ellipse2d implements ProtobufSerializable, StructSerializable {
return point;
}
// Rotate the point by the inverse of the ellipse's rotation
var rotPoint =
point.rotateAround(m_center.getTranslation(), m_center.getRotation().unaryMinus());
// Find nearest point
var nearestPoint = new double[2];
Ellipse2dJNI.nearest(
m_center.getX(),
m_center.getY(),
m_center.getRotation().getRadians(),
m_xSemiAxis,
m_ySemiAxis,
point.getX(),
point.getY(),
nearestPoint);
return new Translation2d(nearestPoint[0], nearestPoint[1]);
try (var problem = new Problem()) {
// Point on ellipse
var x = problem.decisionVariable();
x.setValue(rotPoint.getX());
var y = problem.decisionVariable();
y.setValue(rotPoint.getY());
problem.minimize(pow(x.minus(rotPoint.getX()), 2).plus(pow(y.minus(rotPoint.getY()), 2)));
// (x x_c)²/a² + (y y_c)²/b² = 1
// b²(x x_c)² + a²(y y_c)² = a²b²
double a2 = m_xSemiAxis * m_xSemiAxis;
double b2 = m_ySemiAxis * m_ySemiAxis;
problem.subjectTo(
eq(
pow(x.minus(m_center.getX()), 2)
.times(b2)
.plus(pow(y.minus(m_center.getY()), 2).times(a2)),
a2 * b2));
problem.solve();
rotPoint = new Translation2d(x.value(), y.value());
}
// Undo rotation
return rotPoint.rotateAround(m_center.getTranslation(), m_center.getRotation());
}
@Override

View File

@@ -1,36 +0,0 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.jni;
/** ArmFeedforward JNI. */
public final class ArmFeedforwardJNI extends WPIMathJNI {
/**
* Obtain a feedforward voltage from a single jointed arm feedforward object.
*
* <p>Constructs an ArmFeedforward object and runs its currentVelocity and nextVelocity overload
*
* @param ks The ArmFeedforward's static gain in volts.
* @param kv The ArmFeedforward's velocity gain in volt seconds per radian.
* @param ka The ArmFeedforward's acceleration gain in volt seconds² per radian.
* @param kg The ArmFeedforward's gravity gain in volts.
* @param currentAngle The current angle in the calculation in radians.
* @param currentVelocity The current velocity in the calculation in radians per second.
* @param nextVelocity The next velocity in the calculation in radians per second.
* @param dt The time between velocity setpoints in seconds.
* @return The calculated feedforward in volts.
*/
public static native double calculate(
double ks,
double kv,
double ka,
double kg,
double currentAngle,
double currentVelocity,
double nextVelocity,
double dt);
/** Utility class. */
private ArmFeedforwardJNI() {}
}

View File

@@ -1,35 +0,0 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.jni;
/** Ellipse2d JNI. */
public final class Ellipse2dJNI extends WPIMathJNI {
/**
* Returns the nearest point that is contained within the ellipse.
*
* <p>Constructs an Ellipse2d object and runs its nearest() method.
*
* @param centerX The x coordinate of the center of the ellipse in meters.
* @param centerY The y coordinate of the center of the ellipse in meters.
* @param centerHeading The ellipse's rotation in radians.
* @param xSemiAxis The x semi-axis in meters.
* @param ySemiAxis The y semi-axis in meters.
* @param pointX The x coordinate of the point that this will find the nearest point to.
* @param pointY The y coordinate of the point that this will find the nearest point to.
* @param nearestPoint Array to store nearest point into.
*/
public static native void nearest(
double centerX,
double centerY,
double centerHeading,
double xSemiAxis,
double ySemiAxis,
double pointX,
double pointY,
double[] nearestPoint);
/** Utility class. */
private Ellipse2dJNI() {}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import org.wpilib.math.autodiff.Variable;
/** A vector of equality constraints of the form cₑ(x) = 0. */
@SuppressWarnings("PMD.ArrayIsStoredDirectly")
public class EqualityConstraints {
/** List of equality constraints. */
public Variable[] constraints;
/**
* Constructs an EqualityConstraints.
*
* @param constraints The constraints.
*/
public EqualityConstraints(Variable[] constraints) {
this.constraints = constraints;
}
}

View File

@@ -0,0 +1,23 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import org.wpilib.math.autodiff.Variable;
/** A vector of inequality constraints of the form cᵢ(x) ≥ 0. */
@SuppressWarnings("PMD.ArrayIsStoredDirectly")
public class InequalityConstraints {
/** List of inequality constraints. */
public Variable[] constraints;
/**
* Constructs an InequalityConstraints.
*
* @param constraints The constraints.
*/
public InequalityConstraints(Variable[] constraints) {
this.constraints = constraints;
}
}

View File

@@ -0,0 +1,588 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import static org.wpilib.math.optimization.Constraints.eq;
import static org.wpilib.math.optimization.Constraints.ge;
import static org.wpilib.math.optimization.Constraints.le;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import org.ejml.simple.SimpleMatrix;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableBlock;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.optimization.ocp.ConstraintEvaluationFunction;
import org.wpilib.math.optimization.ocp.DynamicsFunction;
import org.wpilib.math.optimization.ocp.DynamicsType;
import org.wpilib.math.optimization.ocp.TimestepMethod;
import org.wpilib.math.optimization.ocp.TranscriptionMethod;
/**
* This class allows the user to pose and solve a constrained optimal control problem (OCP) in a
* variety of ways.
*
* <p>The system is transcripted by one of three methods (direct transcription, direct collocation,
* or single-shooting) and additional constraints can be added.
*
* <p>In direct transcription, each state is a decision variable constrained to the integrated
* dynamics of the previous state. In direct collocation, the trajectory is modeled as a series of
* cubic polynomials where the centerpoint slope is constrained. In single-shooting, states depend
* explicitly as a function of all previous states and all previous inputs.
*
* <p>Explicit ODEs are integrated using RK4.
*
* <p>For explicit ODEs, the function must be in the form dx/dt = f(t, x, u). For discrete state
* transition functions, the function must be in the form xₖ₊₁ = f(t, xₖ, uₖ).
*
* <p>Direct collocation requires an explicit ODE. Direct transcription and single-shooting can use
* either an ODE or state transition function.
*
* <p>https://underactuated.mit.edu/trajopt.html goes into more detail on each transcription method.
*/
public class OCP extends Problem {
private int m_numSteps;
private DynamicsFunction m_dynamics;
private DynamicsType m_dynamicsType;
private VariableMatrix m_X;
private VariableMatrix m_U;
private VariableMatrix m_DT;
/**
* Builds an optimization problem using a system evolution function (explicit ODE or discrete
* state transition function).
*
* @param numStates The number of system states.
* @param numInputs The number of system inputs.
* @param dt The timestep for fixed-step integration.
* @param numSteps The number of control points.
* @param dynamics Function representing an explicit or implicit ODE, or a discrete state
* transition function.
* <ul>
* <li>Explicit: dx/dt = f(x, u, *)
* <li>Implicit: f([x dx/dt]', u, *) = 0
* <li>State transition: xₖ₊₁ = f(xₖ, uₖ)
* </ul>
*
* @param dynamicsType The type of system evolution function.
* @param timestepMethod The timestep method.
* @param transcriptionMethod The transcription method.
*/
public OCP(
int numStates,
int numInputs,
double dt,
int numSteps,
BiFunction<VariableMatrix, VariableMatrix, VariableMatrix> dynamics,
DynamicsType dynamicsType,
TimestepMethod timestepMethod,
TranscriptionMethod transcriptionMethod) {
this(
numStates,
numInputs,
dt,
numSteps,
(Variable t, VariableMatrix x, VariableMatrix u, Variable _dt) -> dynamics.apply(x, u),
dynamicsType,
timestepMethod,
transcriptionMethod);
}
/**
* Builds an optimization problem using a system evolution function (explicit ODE or discrete
* state transition function).
*
* @param numStates The number of system states.
* @param numInputs The number of system inputs.
* @param dt The timestep for fixed-step integration.
* @param numSteps The number of control points.
* @param dynamics Function representing an explicit or implicit ODE, or a discrete state
* transition function.
* <ul>
* <li>Explicit: dx/dt = f(t, x, u, *)
* <li>Implicit: f(t, [x dx/dt]', u, *) = 0
* <li>State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt)
* </ul>
*
* @param dynamicsType The type of system evolution function.
* @param timestepMethod The timestep method.
* @param transcriptionMethod The transcription method.
*/
@SuppressWarnings("this-escape")
public OCP(
int numStates,
int numInputs,
double dt,
int numSteps,
DynamicsFunction dynamics,
DynamicsType dynamicsType,
TimestepMethod timestepMethod,
TranscriptionMethod transcriptionMethod) {
m_numSteps = numSteps;
m_dynamics = dynamics;
m_dynamicsType = dynamicsType;
// u is numSteps + 1 so that the final constraint function evaluation works
m_U = decisionVariable(numInputs, m_numSteps + 1);
if (timestepMethod == TimestepMethod.FIXED) {
m_DT = new VariableMatrix(1, m_numSteps + 1);
for (int i = 0; i < numSteps + 1; ++i) {
m_DT.set(0, i, dt);
}
} else if (timestepMethod == TimestepMethod.VARIABLE_SINGLE) {
Variable single_dt = decisionVariable();
single_dt.setValue(dt);
// Set the member variable matrix to track the decision variable
m_DT = new VariableMatrix(1, m_numSteps + 1);
for (int i = 0; i < numSteps + 1; ++i) {
m_DT.set(0, i, single_dt);
}
} else if (timestepMethod == TimestepMethod.VARIABLE) {
m_DT = decisionVariable(1, m_numSteps + 1);
for (int i = 0; i < numSteps + 1; ++i) {
m_DT.get(0, i).setValue(dt);
}
}
if (transcriptionMethod == TranscriptionMethod.DIRECT_TRANSCRIPTION) {
m_X = decisionVariable(numStates, m_numSteps + 1);
constrainDirectTranscription();
} else if (transcriptionMethod == TranscriptionMethod.DIRECT_COLLOCATION) {
m_X = decisionVariable(numStates, m_numSteps + 1);
constrainDirectCollocation();
} else if (transcriptionMethod == TranscriptionMethod.SINGLE_SHOOTING) {
// In single-shooting the states aren't decision variables, but instead
// depend on the input and previous states
m_X = new VariableMatrix(numStates, m_numSteps + 1);
constrainSingleShooting();
}
}
/**
* Constrains the initial state.
*
* @param initialState the initial state to constrain to.
*/
public void constrainInitialState(double initialState) {
subjectTo(eq(this.initialState(), initialState));
}
/**
* Constrains the initial state.
*
* @param initialState the initial state to constrain to.
*/
public void constrainInitialState(Variable initialState) {
subjectTo(eq(this.initialState(), initialState));
}
/**
* Constrains the initial state.
*
* @param initialState the initial state to constrain to.
*/
public void constrainInitialState(SimpleMatrix initialState) {
subjectTo(eq(this.initialState(), initialState));
}
/**
* Constrains the initial state.
*
* @param initialState the initial state to constrain to.
*/
public void constrainInitialState(VariableMatrix initialState) {
subjectTo(eq(this.initialState(), initialState));
}
/**
* Constrains the initial state.
*
* @param initialState the initial state to constrain to.
*/
public void constrainInitialState(VariableBlock initialState) {
subjectTo(eq(this.initialState(), initialState));
}
/**
* Constrains the final state.
*
* @param finalState the final state to constrain to.
*/
public void constrainFinalState(double finalState) {
subjectTo(eq(this.finalState(), finalState));
}
/**
* Constrains the final state.
*
* @param finalState the final state to constrain to.
*/
public void constrainFinalState(Variable finalState) {
subjectTo(eq(this.finalState(), finalState));
}
/**
* Constrains the final state.
*
* @param finalState the final state to constrain to.
*/
public void constrainFinalState(SimpleMatrix finalState) {
subjectTo(eq(this.finalState(), finalState));
}
/**
* Constrains the final state.
*
* @param finalState the final state to constrain to.
*/
public void constrainFinalState(VariableMatrix finalState) {
subjectTo(eq(this.finalState(), finalState));
}
/**
* Constrains the final state.
*
* @param finalState the final state to constrain to.
*/
public void constrainFinalState(VariableBlock finalState) {
subjectTo(eq(this.finalState(), finalState));
}
/**
* Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
* corresponding state and input VariableMatrices.
*
* @param callback The callback f(x, u) where x is the state and u is the input vector.
*/
public void forEachStep(BiConsumer<VariableMatrix, VariableMatrix> callback) {
for (int i = 0; i < m_numSteps + 1; ++i) {
var x = X().col(i);
var u = U().col(i);
callback.accept(new VariableMatrix(x), new VariableMatrix(u));
}
}
/**
* Sets the constraint evaluation function. This function is called `numSteps+1` times, with the
* corresponding state and input VariableMatrices.
*
* @param callback The callback f(t, x, u, dt) where t is time, x is the state vector, u is the
* input vector, and dt is the timestep duration.
*/
public void forEachStep(ConstraintEvaluationFunction callback) {
var time = new Variable(0.0);
for (int i = 0; i < m_numSteps + 1; ++i) {
var x = X().col(i);
var u = U().col(i);
var dt = this.dt().get(0, i);
callback.accept(time, new VariableMatrix(x), new VariableMatrix(u), dt);
time = time.plus(dt);
}
}
/**
* Sets a lower bound on the input.
*
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
* (numInputs)x1.
*/
public void setLowerInputBound(double lowerBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(ge(U().col(i), lowerBound));
}
}
/**
* Sets a lower bound on the input.
*
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
* (numInputs)x1.
*/
public void setLowerInputBound(Variable lowerBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(ge(U().col(i), lowerBound));
}
}
/**
* Sets a lower bound on the input.
*
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
* (numInputs)x1.
*/
public void setLowerInputBound(SimpleMatrix lowerBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(ge(U().col(i), lowerBound));
}
}
/**
* Sets a lower bound on the input.
*
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
* (numInputs)x1.
*/
public void setLowerInputBound(VariableMatrix lowerBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(ge(U().col(i), lowerBound));
}
}
/**
* Sets a lower bound on the input.
*
* @param lowerBound The lower bound that inputs must always be above. Must be shaped
* (numInputs)x1.
*/
public void setLowerInputBound(VariableBlock lowerBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(ge(U().col(i), lowerBound));
}
}
/**
* Sets an upper bound on the input.
*
* @param upperBound The upper bound that inputs must always be below. Must be shaped
* (numInputs)x1.
*/
public void setUpperInputBound(double upperBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(le(U().col(i), upperBound));
}
}
/**
* Sets an upper bound on the input.
*
* @param upperBound The upper bound that inputs must always be below. Must be shaped
* (numInputs)x1.
*/
public void setUpperInputBound(Variable upperBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(le(U().col(i), upperBound));
}
}
/**
* Sets an upper bound on the input.
*
* @param upperBound The upper bound that inputs must always be below. Must be shaped
* (numInputs)x1.
*/
public void setUpperInputBound(SimpleMatrix upperBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(le(U().col(i), upperBound));
}
}
/**
* Sets an upper bound on the input.
*
* @param upperBound The upper bound that inputs must always be below. Must be shaped
* (numInputs)x1.
*/
public void setUpperInputBound(VariableMatrix upperBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(le(U().col(i), upperBound));
}
}
/**
* Sets an upper bound on the input.
*
* @param upperBound The upper bound that inputs must always be below. Must be shaped
* (numInputs)x1.
*/
public void setUpperInputBound(VariableBlock upperBound) {
for (int i = 0; i < m_numSteps + 1; ++i) {
subjectTo(le(U().col(i), upperBound));
}
}
/**
* Sets a lower bound on the timestep.
*
* @param minTimestep The minimum timestep in seconds.
*/
public void setMinTimestep(double minTimestep) {
subjectTo(ge(dt(), minTimestep));
}
/**
* Sets an upper bound on the timestep.
*
* @param maxTimestep The maximum timestep in seconds.
*/
public void setMaxTimestep(double maxTimestep) {
subjectTo(le(dt(), maxTimestep));
}
/**
* Gets the state variables. After the problem is solved, this will contain the optimized
* trajectory.
*
* <p>Shaped (numStates)x(numSteps+1).
*
* @return The state variable matrix.
*/
public VariableMatrix X() {
return m_X;
}
/**
* Gets the input variables. After the problem is solved, this will contain the inputs
* corresponding to the optimized trajectory.
*
* <p>Shaped (numInputs)x(numSteps+1), although the last input step is unused in the trajectory.
*
* @return The input variable matrix.
*/
public VariableMatrix U() {
return m_U;
}
/**
* Gets the timestep variables. After the problem is solved, this will contain the timesteps
* corresponding to the optimized trajectory.
*
* <p>Shaped 1x(numSteps+1), although the last timestep is unused in the trajectory.
*
* @return The timestep variable matrix.
*/
public VariableMatrix dt() {
return m_DT;
}
/**
* Gets the initial state in the trajectory.
*
* @return The initial state of the trajectory.
*/
public VariableMatrix initialState() {
return new VariableMatrix(m_X.col(0));
}
/**
* Gets the final state in the trajectory.
*
* @return The final state of the trajectory.
*/
public VariableMatrix finalState() {
return new VariableMatrix(m_X.col(m_numSteps));
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(t, x, u) for dt.
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param t0 The initial time.
* @param dt The time over which to integrate.
*/
private static VariableMatrix rk4(
DynamicsFunction f, VariableMatrix x, VariableMatrix u, Variable t0, Variable dt) {
var halfdt = dt.times(0.5);
VariableMatrix k1 = f.apply(t0, x, u, dt);
VariableMatrix k2 = f.apply(t0.plus(halfdt), x.plus(k1.times(halfdt)), u, dt);
VariableMatrix k3 = f.apply(t0.plus(halfdt), x.plus(k2.times(halfdt)), u, dt);
VariableMatrix k4 = f.apply(t0.plus(dt), x.plus(k3.times(dt)), u, dt);
return x.plus(k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4).times(dt.div(6.0)));
}
/** Applies direct collocation dynamics constraints. */
private void constrainDirectCollocation() {
assert m_dynamicsType == DynamicsType.EXPLICIT_ODE;
var time = new Variable(0.0);
// Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/
for (int i = 0; i < m_numSteps; ++i) {
Variable h = dt().get(0, i);
var f = m_dynamics;
var t_begin = time;
var t_end = t_begin.plus(h);
var x_begin = X().col(i);
var x_end = X().col(i + 1);
var u_begin = U().col(i);
var u_end = U().col(i + 1);
var xdot_begin =
f.apply(t_begin, new VariableMatrix(x_begin), new VariableMatrix(u_begin), h);
var xdot_end = f.apply(t_end, new VariableMatrix(x_end), new VariableMatrix(u_end), h);
var xdot_c =
x_begin
.minus(x_end)
.times(new Variable(-3).div(h.times(2)))
.minus(xdot_begin.plus(xdot_end).times(0.25));
var t_c = t_begin.plus(h.times(0.5));
var x_c = x_begin.plus(x_end).times(0.5).plus(xdot_begin.minus(xdot_end).times(h.div(8)));
var u_c = u_begin.plus(u_end).times(0.5);
subjectTo(eq(xdot_c, f.apply(t_c, x_c, u_c, h)));
time = time.plus(h);
}
}
/** Applies direct transcription dynamics constraints. */
private void constrainDirectTranscription() {
var time = new Variable(0.0);
for (int i = 0; i < m_numSteps; ++i) {
var x_begin = X().col(i);
var x_end = X().col(i + 1);
var u = U().col(i);
Variable dt = this.dt().get(0, i);
if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
subjectTo(
eq(
x_end,
rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt)));
} else if (m_dynamicsType == DynamicsType.DISCRETE) {
subjectTo(
eq(
x_end,
m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt)));
}
time = time.plus(dt);
}
}
/** Applies single shooting dynamics constraints. */
private void constrainSingleShooting() {
var time = new Variable(0.0);
for (int i = 0; i < m_numSteps; ++i) {
var x_begin = X().col(i);
var x_end = X().col(i + 1);
var u = U().col(i);
Variable dt = this.dt().get(0, i);
if (m_dynamicsType == DynamicsType.EXPLICIT_ODE) {
x_end.set(rk4(m_dynamics, new VariableMatrix(x_begin), new VariableMatrix(u), time, dt));
} else if (m_dynamicsType == DynamicsType.DISCRETE) {
x_end.set(m_dynamics.apply(time, new VariableMatrix(x_begin), new VariableMatrix(u), dt));
}
time = time.plus(dt);
}
}
}

View File

@@ -0,0 +1,312 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import java.util.ArrayList;
import java.util.function.Predicate;
import org.ejml.data.DMatrixRMaj;
import org.ejml.simple.SimpleMatrix;
import org.wpilib.math.autodiff.ExpressionType;
import org.wpilib.math.autodiff.NativeSparseTriplets;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
import org.wpilib.math.autodiff.VariablePool;
import org.wpilib.math.optimization.solver.ExitStatus;
import org.wpilib.math.optimization.solver.IterationInfo;
import org.wpilib.math.optimization.solver.Options;
/**
* This class allows the user to pose a constrained nonlinear optimization problem in natural
* mathematical notation and solve it.
*
* <p>This class supports problems of the form:
*
* <pre>
* minₓ f(x)
* subject to cₑ(x) = 0
* cᵢ(x) ≥ 0
* </pre>
*
* <p>where f(x) is the scalar cost function, x is the vector of decision variables (variables the
* solver can tweak to minimize the cost function), cᵢ(x) are the inequality constraints, and cₑ(x)
* are the equality constraints. Constraints are equations or inequalities of the decision variables
* that constrain what values the solver is allowed to use when searching for an optimal solution.
*
* <p>The nice thing about this class is users don't have to put their system in the form shown
* above manually; they can write it in natural mathematical form and it'll be converted for them.
*/
public class Problem implements AutoCloseable {
private long m_handle;
// The iteration callbacks
private final ArrayList<Predicate<IterationInfo>> m_iterationCallbacks = new ArrayList<>();
// Cleans up Variables allocated within Problem's scope
private final VariablePool m_pool = new VariablePool();
/** Construct the optimization problem. */
@SuppressWarnings("this-escape")
public Problem() {
m_handle = ProblemJNI.create();
}
@Override
public void close() {
if (m_handle != 0) {
ProblemJNI.destroy(m_handle);
m_handle = 0;
m_pool.close();
}
}
/**
* Creates a decision variable in the optimization problem.
*
* <p>Decision variables have an initial value of zero.
*
* @return A decision variable in the optimization problem.
*/
public Variable decisionVariable() {
var handles = ProblemJNI.decisionVariable(m_handle, 1, 1);
return new Variable(Variable.HANDLE, handles[0]);
}
/**
* Creates a column vector of decision variables in the optimization problem.
*
* <p>Decision variables have an initial value of zero.
*
* @param rows Number of column vector rows.
* @return A column vector of decision variables in the optimization problem.
*/
public VariableMatrix decisionVariable(int rows) {
return decisionVariable(rows, 1);
}
/**
* Creates a matrix of decision variables in the optimization problem.
*
* <p>Decision variables have an initial value of zero.
*
* @param rows Number of matrix rows.
* @param cols Number of matrix columns.
* @return A matrix of decision variables in the optimization problem.
*/
public VariableMatrix decisionVariable(int rows, int cols) {
return new VariableMatrix(rows, cols, ProblemJNI.decisionVariable(m_handle, rows, cols));
}
/**
* Creates a symmetric matrix of decision variables in the optimization problem.
*
* <p>Variable instances are reused across the diagonal, which helps reduce problem
* dimensionality.
*
* <p>Decision variables have an initial value of zero.
*
* @param rows Number of matrix rows.
* @return A symmetric matrix of decision varaibles in the optimization problem.
*/
public VariableMatrix symmetricDecisionVariable(int rows) {
return new VariableMatrix(rows, rows, ProblemJNI.symmetricDecisionVariable(m_handle, rows));
}
/**
* Tells the solver to minimize the output of the given cost function.
*
* <p>Note that this is optional. If only constraints are specified, the solver will find the
* closest solution to the initial conditions that's in the feasible set.
*
* @param cost The cost function to minimize.
*/
public void minimize(Variable cost) {
ProblemJNI.minimize(m_handle, cost.getHandle());
}
/**
* Tells the solver to minimize the output of the given cost function.
*
* <p>Note that this is optional. If only constraints are specified, the solver will find the
* closest solution to the initial conditions that's in the feasible set.
*
* @param cost The cost function to minimize. An assertion is raised if the VariableMatrix isn't
* 1x1.
*/
public void minimize(VariableMatrix cost) {
assert cost.rows() == 1 && cost.cols() == 1;
minimize(cost.get(0, 0));
}
/**
* Tells the solver to maximize the output of the given objective function.
*
* <p>Note that this is optional. If only constraints are specified, the solver will find the
* closest solution to the initial conditions that's in the feasible set.
*
* @param objective The objective function to maximize.
*/
public void maximize(Variable objective) {
ProblemJNI.maximize(m_handle, objective.getHandle());
}
/**
* Tells the solver to maximize the output of the given objective function.
*
* <p>Note that this is optional. If only constraints are specified, the solver will find the
* closest solution to the initial conditions that's in the feasible set.
*
* @param objective The objective function to maximize. An assertion is raised if the
* VariableMatrix isn't 1x1.
*/
public void maximize(VariableMatrix objective) {
assert objective.rows() == 1 && objective.cols() == 1;
maximize(objective.get(0, 0));
}
/**
* Tells the solver to solve the problem while satisfying the given equality constraint.
*
* @param constraint The constraint to satisfy.
*/
public void subjectTo(EqualityConstraints constraint) {
var constraintHandles = new long[constraint.constraints.length];
for (int i = 0; i < constraintHandles.length; ++i) {
constraintHandles[i] = constraint.constraints[i].getHandle();
}
ProblemJNI.subjectToEq(m_handle, constraintHandles);
}
/**
* Tells the solver to solve the problem while satisfying the given inequality constraint.
*
* @param constraint The constraint to satisfy.
*/
public void subjectTo(InequalityConstraints constraint) {
var constraintHandles = new long[constraint.constraints.length];
for (int i = 0; i < constraintHandles.length; ++i) {
constraintHandles[i] = constraint.constraints[i].getHandle();
}
ProblemJNI.subjectToIneq(m_handle, constraintHandles);
}
/**
* Returns the cost function's type.
*
* @return The cost function's type.
*/
public ExpressionType costFunctionType() {
return ExpressionType.fromInt(ProblemJNI.costFunctionType(m_handle));
}
/**
* Returns the type of the highest order equality constraint.
*
* @return The type of the highest order equality constraint.
*/
public ExpressionType equalityConstraintType() {
return ExpressionType.fromInt(ProblemJNI.equalityConstraintType(m_handle));
}
/**
* Returns the type of the highest order inequality constraint.
*
* @return The type of the highest order inequality constraint.
*/
public ExpressionType inequalityConstraintType() {
return ExpressionType.fromInt(ProblemJNI.inequalityConstraintType(m_handle));
}
/**
* Solves the optimization problem. The solution will be stored in the original variables used to
* construct the problem.
*
* @return The solver status.
*/
public ExitStatus solve() {
return solve(new Options());
}
/**
* Solves the optimization problem. The solution will be stored in the original variables used to
* construct the problem.
*
* @param options Solver options.
* @return The solver status.
*/
public ExitStatus solve(Options options) {
return ExitStatus.fromInt(
ProblemJNI.solve(
this,
m_handle,
options.tolerance,
options.maxIterations,
options.timeout,
options.feasibleIPM,
options.diagnostics));
}
/**
* Adds a callback to be called at the beginning of each solver iteration.
*
* <p>The callback for this overload should return bool.
*
* @param callback The callback. Returning true from the callback causes the solver to exit early
* with the solution it has so far.
*/
public void addCallback(Predicate<IterationInfo> callback) {
m_iterationCallbacks.add(callback);
}
/** Clears the registered callbacks. */
public void clearCallbacks() {
m_iterationCallbacks.clear();
}
/**
* Runs the registered callbacks.
*
* <p>This function is called by native code in ProblemJNI.
*
* @param numEqualityConstraints The number of equality constraints.
* @param numInequalityConstraints The number of inequality constraints.
* @param iteration The solver iteration.
* @param x The decision variable values.
* @param gTriplets Gradient triplets.
* @param HTriplets Hessian triplets.
* @param A_eTriplets Equality constraint Jacobian triplets.
* @param A_iTriplets Inequality constraint Jacobian triplets.
* @return True if the solver shold exit early.
*/
boolean runCallbacks(
int numEqualityConstraints,
int numInequalityConstraints,
int iteration,
double[] x,
NativeSparseTriplets gTriplets,
NativeSparseTriplets HTriplets,
NativeSparseTriplets A_eTriplets,
NativeSparseTriplets A_iTriplets) {
if (m_iterationCallbacks.isEmpty()) {
return false;
}
var info =
new IterationInfo(
iteration,
new SimpleMatrix(DMatrixRMaj.wrap(x.length, 1, x)),
gTriplets.toSimpleMatrix(x.length, 1),
HTriplets.toSimpleMatrix(x.length, x.length),
A_eTriplets.toSimpleMatrix(numEqualityConstraints, x.length),
A_iTriplets.toSimpleMatrix(numInequalityConstraints, x.length));
for (var callback : m_iterationCallbacks) {
if (callback.test(info)) {
return true;
}
}
return false;
}
}

View File

@@ -0,0 +1,134 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization;
import org.wpilib.math.jni.WPIMathJNI;
/** Problem JNI functions. */
final class ProblemJNI extends WPIMathJNI {
private ProblemJNI() {
// Utility class.
}
/** Construct the optimization problem. */
static native long create();
/**
* Destruct the optimization problem.
*
* @param handle Problem handle.
*/
static native void destroy(long handle);
/**
* Create a matrix of decision variables in the optimization problem.
*
* @param handle Problem handle.
* @param rows Number of matrix rows.
* @param cols Number of matrix columns.
* @return A matrix of decision variables in the optimization problem.
*/
static native long[] decisionVariable(long handle, int rows, int cols);
/**
* Create a symmetric matrix of decision variables in the optimization problem.
*
* <p>Variable instances are reused across the diagonal, which helps reduce problem
* dimensionality.
*
* @param handle Problem handle.
* @param rows Number of matrix rows.
* @return A symmetric matrix of decision varaibles in the optimization problem.
*/
static native long[] symmetricDecisionVariable(long handle, int rows);
/**
* Tells the solver to minimize the output of the given cost function.
*
* <p>Note that this is optional. If only constraints are specified, the solver will find the
* closest solution to the initial conditions that's in the feasible set.
*
* @param handle Problem handle.
* @param costHandle Variable handle of the cost function to minimize.
*/
static native void minimize(long handle, long costHandle);
/**
* Tells the solver to maximize the output of the given objective function.
*
* <p>Note that this is optional. If only constraints are specified, the solver will find the
* closest solution to the initial conditions that's in the feasible set.
*
* @param handle Problem handle.
* @param objectiveHandle Variable handle of the objective function to maximize.
*/
static native void maximize(long handle, long objectiveHandle);
/**
* Tells the solver to solve the problem while satisfying the given equality constraint.
*
* @param handle Problem handle.
* @param constraintHandles Constraint handles.
*/
static native void subjectToEq(long handle, long[] constraintHandles);
/**
* Tells the solver to solve the problem while satisfying the given inequality constraint.
*
* @param handle Problem handle.
* @param constraintHandles Constraint handles.
*/
static native void subjectToIneq(long handle, long[] constraintHandles);
/**
* Returns the cost function's type.
*
* @param handle Problem handle.
* @return The cost function's type.
*/
static native int costFunctionType(long handle);
/**
* Returns the type of the highest order equality constraint.
*
* @param handle Problem handle.
* @return The type of the highest order equality constraint.
*/
static native int equalityConstraintType(long handle);
/**
* Returns the type of the highest order inequality constraint.
*
* @param handle Problem handle.
* @return The type of the highest order inequality constraint.
*/
static native int inequalityConstraintType(long handle);
/**
* Solve the optimization problem. The solution will be stored in the original variables used to
* construct the problem.
*
* @param obj Java Problem object.
* @param handle Problem handle.
* @param tolerance The solver will stop once the error is below this tolerance.
* @param maxIterations The maximum number of solver iterations before returning a solution.
* @param timeout The maximum elapsed wall clock time in seconds before returning a solution.
* @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints
* are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible
* again. This is useful when parts of the problem are ill-conditioned in infeasible regions
* (e.g., square root of a negative value). This can slow or prevent progress toward a
* solution though, so only enable it if necessary.
* @param diagnnostics Enables diagnostic prints.
* @return The solver status.
*/
static native int solve(
Problem obj,
long handle,
double tolerance,
int maxIterations,
double timeout,
boolean feasibleIPM,
boolean diagnostics);
}

View File

@@ -0,0 +1,25 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.ocp;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
/**
* A callback f(t, x, u, dt) where t is time, x is the state vector, u is the input vector, and dt
* is the timestep duration.
*/
@FunctionalInterface
public interface ConstraintEvaluationFunction {
/**
* Applies this function with the arguments.
*
* @param t Time in seconds.
* @param x State vector.
* @param u Input vector.
* @param dt Timestep duration in seconds.
*/
void accept(Variable t, VariableMatrix x, VariableMatrix u, Variable dt);
}

View File

@@ -0,0 +1,31 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.ocp;
import org.wpilib.math.autodiff.Variable;
import org.wpilib.math.autodiff.VariableMatrix;
/**
* Function representing an explicit or implicit ODE, or a discrete state transition function.
*
* <ul>
* <li>Explicit: dx/dt = f(t, x, u, *)
* <li>Implicit: f(t, [x dx/dt]', u, *) = 0
* <li>State transition: xₖ₊₁ = f(t, xₖ, uₖ, dt)
* </ul>
*/
@FunctionalInterface
public interface DynamicsFunction {
/**
* Applies this function with the arguments and returns the result.
*
* @param t Time in seconds.
* @param x State vector.
* @param u Input vector.
* @param dt Timestep duration in seconds.
* @return The state derivative dx/dt.
*/
VariableMatrix apply(Variable t, VariableMatrix x, VariableMatrix u, Variable dt);
}

View File

@@ -0,0 +1,20 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.ocp;
/** Enum describing a type of system dynamics constraints. */
public enum DynamicsType {
/** The dynamics are a function in the form dx/dt = f(t, x, u). */
EXPLICIT_ODE(0),
/** The dynamics are a function in the form xₖ₊₁ = f(t, xₖ, uₖ). */
DISCRETE(1);
/** DynamicsType value. */
public final int value;
DynamicsType(int value) {
this.value = value;
}
}

View File

@@ -0,0 +1,22 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.ocp;
/** Enum describing the type of system timestep. */
public enum TimestepMethod {
/** The timestep is a fixed constant. */
FIXED(0),
/** The timesteps are allowed to vary as independent decision variables. */
VARIABLE(1),
/** The timesteps are equal length but allowed to vary as a single decision variable. */
VARIABLE_SINGLE(2);
/** TimestepMethod value. */
public final int value;
TimestepMethod(int value) {
this.value = value;
}
}

View File

@@ -0,0 +1,27 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.ocp;
/** Enum describing an OCP transcription method. */
public enum TranscriptionMethod {
/**
* Each state is a decision variable constrained to the integrated dynamics of the previous state.
*/
DIRECT_TRANSCRIPTION(0),
/**
* The trajectory is modeled as a series of cubic polynomials where the centerpoint slope is
* constrained.
*/
DIRECT_COLLOCATION(1),
/** States depend explicitly as a function of all previous states and all previous inputs. */
SINGLE_SHOOTING(2);
/** TranscriptionMethod value. */
public final int value;
TranscriptionMethod(int value) {
this.value = value;
}
}

View File

@@ -0,0 +1,66 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.solver;
/** Solver exit status. Negative values indicate failure. */
public enum ExitStatus {
/** Solved the problem to the desired tolerance. */
SUCCESS(0),
/** The solver returned its solution so far after the user requested a stop. */
CALLBACK_REQUESTED_STOP(1),
/** The solver determined the problem to be overconstrained and gave up. */
TOO_FEW_DOFS(-1),
/** The solver determined the problem to be locally infeasible and gave up. */
LOCALLY_INFEASIBLE(-2),
/** The problem setup frontend determined the problem to have an empty feasible region. */
GLOBALLY_INFEASIBLE(-3),
/** The linear system factorization failed. */
FACTORIZATION_FAILED(-4),
/**
* The solver failed to reach the desired tolerance, and feasibility restoration failed to
* converge.
*/
FEASIBILITY_RESTORATION_FAILED(-5),
/** The solver encountered nonfinite initial cost, constraints, or derivatives and gave up. */
NONFINITE_INITIAL_GUESS(-6),
/** The solver encountered diverging primal iterates xₖ and/or sₖ and gave up. */
DIVERGING_ITERATES(-7),
/** The solver returned its solution so far after exceeding the maximum number of iterations. */
MAX_ITERATIONS_EXCEEDED(-8),
/**
* The solver returned its solution so far after exceeding the maximum elapsed wall clock time.
*/
TIMEOUT(-9);
/** ExitStatus value. */
public final int value;
ExitStatus(int value) {
this.value = value;
}
/**
* Converts integer to its corresponding enum value.
*
* @param x The integer.
* @return The enum value.
*/
public static ExitStatus fromInt(int x) {
return switch (x) {
case 0 -> ExitStatus.SUCCESS;
case 1 -> ExitStatus.CALLBACK_REQUESTED_STOP;
case -1 -> ExitStatus.TOO_FEW_DOFS;
case -2 -> ExitStatus.LOCALLY_INFEASIBLE;
case -3 -> ExitStatus.GLOBALLY_INFEASIBLE;
case -4 -> ExitStatus.FACTORIZATION_FAILED;
case -5 -> ExitStatus.FEASIBILITY_RESTORATION_FAILED;
case -6 -> ExitStatus.NONFINITE_INITIAL_GUESS;
case -7 -> ExitStatus.DIVERGING_ITERATES;
case -8 -> ExitStatus.MAX_ITERATIONS_EXCEEDED;
case -9 -> ExitStatus.TIMEOUT;
default -> null;
};
}
}

View File

@@ -0,0 +1,53 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.solver;
import org.ejml.simple.SimpleMatrix;
/** Solver iteration information exposed to an iteration callback. */
public class IterationInfo {
/** The solver iteration. */
public final int iteration;
/** The decision variables (dense internal storage). */
public final SimpleMatrix x;
/** The gradient of the cost function (sparse internal storage). */
public final SimpleMatrix g;
/** The Hessian of the Lagrangian (sparse internal storage). */
public final SimpleMatrix H;
/** The equality constraint Jacobian (sparse internal storage). */
public final SimpleMatrix A_e;
/** The inequality constraint Jacobian (sparse internal storage). */
public final SimpleMatrix A_i;
/**
* Constructs iteration info.
*
* @param iteration The solver iteration.
* @param x The decision variables (dense internal storage).
* @param g The gradient of the cost function (sparse internal storage).
* @param H The Hessian of the Lagrangian (sparse internal storage).
* @param A_e The equality constraint Jacobian (sparse internal storage).
* @param A_i The inequality constraint Jacobian (sparse internal storage).
*/
public IterationInfo(
int iteration,
SimpleMatrix x,
SimpleMatrix g,
SimpleMatrix H,
SimpleMatrix A_e,
SimpleMatrix A_i) {
this.iteration = iteration;
this.x = x;
this.g = g;
this.H = H;
this.A_e = A_e;
this.A_i = A_i;
}
}

View File

@@ -0,0 +1,98 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package org.wpilib.math.optimization.solver;
/** Solver options. */
public class Options {
/** The solver will stop once the error is below this tolerance. */
public double tolerance = 1e-8;
/** The maximum number of solver iterations before returning a solution. */
public int maxIterations = 5000;
/** The maximum elapsed wall clock time in seconds before returning a solution. */
public double timeout = Double.POSITIVE_INFINITY;
/**
* Enables the feasible interior-point method.
*
* <p>When the inequality constraints are all feasible, step sizes are reduced when necessary to
* prevent them becoming infeasible again. This is useful when parts of the problem are
* ill-conditioned in infeasible regions (e.g., square root of a negative value). This can slow or
* prevent progress toward a solution though, so only enable it if necessary.
*/
public boolean feasibleIPM = false;
/**
* Enables diagnostic output.
*
* <p>See <a
* href="https://sleipnirgroup.github.io/Sleipnir/md_usage.html#output">https://sleipnirgroup.github.io/Sleipnir/md_usage.html#output</a>
* for more information.
*/
public boolean diagnostics = false;
/** Default options. */
public Options() {}
/**
* Set tolerance.
*
* @param tolerance The solver will stop once the error is below this tolerance.
* @return This Options object.
*/
public Options withTolerance(double tolerance) {
this.tolerance = tolerance;
return this;
}
/**
* Set max iterations.
*
* @param maxIterations The maximum number of solver iterations before returning a solution.
* @return This Options object.
*/
public Options withMaxIterations(int maxIterations) {
this.maxIterations = maxIterations;
return this;
}
/**
* Set timeout.
*
* @param timeout The maximum elapsed wall clock time in seconds before returning a solution.
* @return This Options object.
*/
public Options withTimeout(double timeout) {
this.timeout = timeout;
return this;
}
/**
* Enable or disable feasible IPM.
*
* @param feasibleIPM Enables the feasible interior-point method. When the inequality constraints
* are all feasible, step sizes are reduced when necessary to prevent them becoming infeasible
* again. This is useful when parts of the problem are ill-conditioned in infeasible regions
* (e.g., square root of a negative value). This can slow or prevent progress toward a
* solution though, so only enable it if necessary.
* @return This Options object.
*/
public Options withFeasibleIPM(boolean feasibleIPM) {
this.feasibleIPM = feasibleIPM;
return this;
}
/**
* Enable or disable diagnostics.
*
* @param diagnostics Enables diagnostic prints.
* @return This Options object.
*/
public Options withDiagnostics(boolean diagnostics) {
this.diagnostics = diagnostics;
return this;
}
}

View File

@@ -8,6 +8,7 @@ import java.util.function.BiFunction;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.UnaryOperator;
import org.ejml.simple.SimpleMatrix;
import org.wpilib.math.linalg.Matrix;
import org.wpilib.math.numbers.N1;
import org.wpilib.math.util.Num;
@@ -56,6 +57,30 @@ public final class NumericalIntegration {
return x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dt The time over which to integrate.
* @return the integration of dx/dt = f(x, u) for dt.
*/
public static SimpleMatrix rk4(
BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> f,
SimpleMatrix x,
SimpleMatrix u,
double dt) {
var h = dt;
var k1 = f.apply(x, u);
var k2 = f.apply(x.plus(k1.scale(h * 0.5)), u);
var k3 = f.apply(x.plus(k2.scale(h * 0.5)), u);
var k4 = f.apply(x.plus(k3.scale(h)), u);
return x.plus(k1.plus(k2.scale(2.0)).plus(k3.scale(2.0)).plus(k4).scale(h / 6.0));
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*

View File

@@ -1,37 +0,0 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include "org_wpilib_math_jni_ArmFeedforwardJNI.h"
#include "wpi/math/controller/ArmFeedforward.hpp"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
/*
* Class: org_wpilib_math_jni_ArmFeedforwardJNI
* Method: calculate
* Signature: (DDDDDDDD)D
*/
JNIEXPORT jdouble JNICALL
Java_org_wpilib_math_jni_ArmFeedforwardJNI_calculate
(JNIEnv* env, jclass, jdouble ks, jdouble kv, jdouble ka, jdouble kg,
jdouble currentAngle, jdouble currentVelocity, jdouble nextVelocity,
jdouble dt)
{
return wpi::math::ArmFeedforward{
wpi::units::volt_t{ks}, wpi::units::volt_t{kg},
wpi::units::unit_t<wpi::math::ArmFeedforward::kv_unit>{kv},
wpi::units::unit_t<wpi::math::ArmFeedforward::ka_unit>{ka},
wpi::units::second_t{dt}}
.Calculate(wpi::units::radian_t{currentAngle},
wpi::units::radians_per_second_t{currentVelocity},
wpi::units::radians_per_second_t{nextVelocity})
.value();
}
} // extern "C"

View File

@@ -1,39 +0,0 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include "org_wpilib_math_jni_Ellipse2dJNI.h"
#include "wpi/math/geometry/Ellipse2d.hpp"
#include "wpi/util/array.hpp"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
/*
* Class: org_wpilib_math_jni_Ellipse2dJNI
* Method: nearest
* Signature: (DDDDDDD[D)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_jni_Ellipse2dJNI_nearest
(JNIEnv* env, jclass, jdouble centerX, jdouble centerY, jdouble centerHeading,
jdouble xSemiAxis, jdouble ySemiAxis, jdouble pointX, jdouble pointY,
jdoubleArray nearestPoint)
{
auto point =
wpi::math::Ellipse2d{
wpi::math::Pose2d{wpi::units::meter_t{centerX},
wpi::units::meter_t{centerY},
wpi::units::radian_t{centerHeading}},
wpi::units::meter_t{xSemiAxis}, wpi::units::meter_t{ySemiAxis}}
.Nearest({wpi::units::meter_t{pointX}, wpi::units::meter_t{pointY}});
wpi::util::array buf{point.X().value(), point.Y().value()};
env->SetDoubleArrayRegion(nearestPoint, 0, 2, buf.data());
}
} // extern "C"

View File

@@ -0,0 +1,65 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <jni.h>
#include <concepts>
#include <vector>
#include <Eigen/SparseCore>
#include "wpi/util/jni_util.hpp"
namespace wpi::math::detail {
/**
* Converts Eigen sparse matrix to triplets.
*
* @param env JNI environment.
* @param mat Eigen sparse matrix to convert.
* @return NativeSparseTriplets instance.
*/
template <typename Derived>
requires std::derived_from<Derived, Eigen::SparseCompressedBase<Derived>>
jobject GetTriplets(JNIEnv* env, const Derived& mat) {
const int nonZeros = mat.nonZeros();
std::vector<int> rows;
rows.reserve(nonZeros);
std::vector<int> cols;
cols.reserve(nonZeros);
std::vector<double> values;
values.reserve(nonZeros);
for (int k = 0; k < mat.outerSize(); ++k) {
for (typename Derived::InnerIterator it{mat, k}; it; ++it) {
rows.emplace_back(it.row());
cols.emplace_back(it.col());
values.emplace_back(it.value());
}
}
// Find NativeSparseTriplets class
static wpi::util::java::JClass cls{
env, "org/wpilib/math/autodiff/NativeSparseTriplets"};
if (!cls) {
return nullptr;
}
// Find NativeSparseTriplets constructor
static jmethodID ctor = env->GetMethodID(cls, "<init>", "([I[I[D)V");
if (!ctor) {
return nullptr;
}
return env->NewObject(cls, ctor, wpi::util::java::MakeJIntArray(env, rows),
wpi::util::java::MakeJIntArray(env, cols),
wpi::util::java::MakeJDoubleArray(env, values));
}
} // namespace wpi::math::detail

View File

@@ -0,0 +1,90 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include <utility>
#include <vector>
#include <sleipnir/autodiff/gradient.hpp>
#include <sleipnir/autodiff/variable.hpp>
#include <sleipnir/autodiff/variable_matrix.hpp>
#include "../SleipnirJNIUtil.hpp"
#include "org_wpilib_math_autodiff_GradientJNI.h"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
/*
* Class: org_wpilib_math_autodiff_GradientJNI
* Method: create
* Signature: (J[J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_GradientJNI_create
(JNIEnv* env, jclass, jlong variable, jlongArray wrt)
{
auto& variableObj = *reinterpret_cast<slp::Variable<double>*>(variable);
JSpan<const jlong> wrtSpan{env, wrt};
slp::VariableMatrix<double> wrtObj(slp::detail::empty, wrtSpan.size(), 1);
for (size_t i = 0; i < wrtSpan.size(); ++i) {
wrtObj[i] = *reinterpret_cast<slp::Variable<double>*>(wrtSpan[i]);
}
return reinterpret_cast<jlong>(
new slp::Gradient{variableObj, std::move(wrtObj)});
}
/*
* Class: org_wpilib_math_autodiff_GradientJNI
* Method: destroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_autodiff_GradientJNI_destroy
(JNIEnv* env, jclass, jlong handle)
{
delete reinterpret_cast<slp::Gradient<double>*>(handle);
}
/*
* Class: org_wpilib_math_autodiff_GradientJNI
* Method: get
* Signature: (J)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_wpilib_math_autodiff_GradientJNI_get
(JNIEnv* env, jclass, jlong handle)
{
auto& gradient = *reinterpret_cast<slp::Gradient<double>*>(handle);
auto g = gradient.get();
std::vector<jlong> varHandles;
varHandles.reserve(g.size());
for (auto& var : g) {
varHandles.emplace_back(
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
}
return MakeJLongArray(env, varHandles);
}
/*
* Class: org_wpilib_math_autodiff_GradientJNI
* Method: value
* Signature: (J)Ljava/lang/Object;
*/
JNIEXPORT jobject JNICALL
Java_org_wpilib_math_autodiff_GradientJNI_value
(JNIEnv* env, jclass, jlong handle)
{
auto& gradient = *reinterpret_cast<slp::Gradient<double>*>(handle);
return wpi::math::detail::GetTriplets(env, gradient.value());
}
} // extern "C"

View File

@@ -0,0 +1,96 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include <utility>
#include <vector>
#include <sleipnir/autodiff/hessian.hpp>
#include <sleipnir/autodiff/variable.hpp>
#include <sleipnir/autodiff/variable_matrix.hpp>
#include "../SleipnirJNIUtil.hpp"
#include "org_wpilib_math_autodiff_HessianJNI.h"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
/*
* Class: org_wpilib_math_autodiff_HessianJNI
* Method: create
* Signature: (J[J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_HessianJNI_create
(JNIEnv* env, jclass, jlong variable, jlongArray wrt)
{
auto& variableObj = *reinterpret_cast<slp::Variable<double>*>(variable);
JSpan<const jlong> wrtSpan{env, wrt};
slp::VariableMatrix<double> wrtObj(slp::detail::empty, wrtSpan.size(), 1);
for (size_t i = 0; i < wrtSpan.size(); ++i) {
wrtObj[i] = *reinterpret_cast<slp::Variable<double>*>(wrtSpan[i]);
}
return reinterpret_cast<jlong>(
new slp::Hessian<double, Eigen::Lower | Eigen::Upper>{variableObj,
std::move(wrtObj)});
}
/*
* Class: org_wpilib_math_autodiff_HessianJNI
* Method: destroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_autodiff_HessianJNI_destroy
(JNIEnv* env, jclass, jlong handle)
{
delete reinterpret_cast<slp::Hessian<double, Eigen::Lower | Eigen::Upper>*>(
handle);
}
/*
* Class: org_wpilib_math_autodiff_HessianJNI
* Method: get
* Signature: (J)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_wpilib_math_autodiff_HessianJNI_get
(JNIEnv* env, jclass, jlong handle)
{
auto& hessian =
*reinterpret_cast<slp::Hessian<double, Eigen::Lower | Eigen::Upper>*>(
handle);
auto H = hessian.get();
std::vector<jlong> varHandles;
varHandles.reserve(H.size());
for (auto& var : H) {
varHandles.emplace_back(
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
}
return MakeJLongArray(env, varHandles);
}
/*
* Class: org_wpilib_math_autodiff_HessianJNI
* Method: value
* Signature: (J)Ljava/lang/Object;
*/
JNIEXPORT jobject JNICALL
Java_org_wpilib_math_autodiff_HessianJNI_value
(JNIEnv* env, jclass, jlong handle)
{
auto& hessian =
*reinterpret_cast<slp::Hessian<double, Eigen::Lower | Eigen::Upper>*>(
handle);
return wpi::math::detail::GetTriplets(env, hessian.value());
}
} // extern "C"

View File

@@ -0,0 +1,96 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include <utility>
#include <vector>
#include <sleipnir/autodiff/jacobian.hpp>
#include <sleipnir/autodiff/variable.hpp>
#include <sleipnir/autodiff/variable_matrix.hpp>
#include "../SleipnirJNIUtil.hpp"
#include "org_wpilib_math_autodiff_JacobianJNI.h"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
/*
* Class: org_wpilib_math_autodiff_JacobianJNI
* Method: create
* Signature: ([J[J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_JacobianJNI_create
(JNIEnv* env, jclass, jlongArray variables, jlongArray wrt)
{
JSpan<const jlong> variablesSpan{env, variables};
slp::VariableMatrix<double> variablesObj(slp::detail::empty,
variablesSpan.size(), 1);
for (size_t i = 0; i < variablesSpan.size(); ++i) {
variablesObj[i] =
*reinterpret_cast<slp::Variable<double>*>(variablesSpan[i]);
}
JSpan<const jlong> wrtSpan{env, wrt};
slp::VariableMatrix<double> wrtObj(slp::detail::empty, wrtSpan.size(), 1);
for (size_t i = 0; i < wrtSpan.size(); ++i) {
wrtObj[i] = *reinterpret_cast<slp::Variable<double>*>(wrtSpan[i]);
}
return reinterpret_cast<jlong>(
new slp::Jacobian{std::move(variablesObj), std::move(wrtObj)});
}
/*
* Class: org_wpilib_math_autodiff_JacobianJNI
* Method: destroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_autodiff_JacobianJNI_destroy
(JNIEnv* env, jclass, jlong handle)
{
delete reinterpret_cast<slp::Jacobian<double>*>(handle);
}
/*
* Class: org_wpilib_math_autodiff_JacobianJNI
* Method: get
* Signature: (J)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_wpilib_math_autodiff_JacobianJNI_get
(JNIEnv* env, jclass, jlong handle)
{
auto& jacobian = *reinterpret_cast<slp::Jacobian<double>*>(handle);
auto J = jacobian.get();
std::vector<jlong> varHandles;
varHandles.reserve(J.size());
for (auto& var : J) {
varHandles.emplace_back(
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
}
return MakeJLongArray(env, varHandles);
}
/*
* Class: org_wpilib_math_autodiff_JacobianJNI
* Method: value
* Signature: (J)Ljava/lang/Object;
*/
JNIEXPORT jobject JNICALL
Java_org_wpilib_math_autodiff_JacobianJNI_value
(JNIEnv* env, jclass, jlong handle)
{
auto& jacobian = *reinterpret_cast<slp::Jacobian<double>*>(handle);
return wpi::math::detail::GetTriplets(env, jacobian.value());
}
} // extern "C"

View File

@@ -0,0 +1,463 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include <sleipnir/autodiff/variable.hpp>
#include "org_wpilib_math_autodiff_VariableJNI.h"
extern "C" {
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: createDefault
* Signature: ()J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_createDefault
(JNIEnv* env, jclass)
{
return reinterpret_cast<jlong>(new slp::Variable<double>{});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: createDouble
* Signature: (D)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_createDouble
(JNIEnv* env, jclass, jdouble value)
{
return reinterpret_cast<jlong>(new slp::Variable<double>{value});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: createInt
* Signature: (I)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_createInt
(JNIEnv* env, jclass, jint value)
{
return reinterpret_cast<jlong>(new slp::Variable<double>{value});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: destroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_destroy
(JNIEnv* env, jclass, jlong handle)
{
delete reinterpret_cast<slp::Variable<double>*>(handle);
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: setValue
* Signature: (JD)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_setValue
(JNIEnv* env, jclass, jlong handle, jdouble value)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
lhsVar.set_value(value);
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: times
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_times
(JNIEnv* env, jclass, jlong handle, jlong rhs)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar * rhsVar});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: div
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_div
(JNIEnv* env, jclass, jlong handle, jlong rhs)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar / rhsVar});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: plus
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_plus
(JNIEnv* env, jclass, jlong handle, jlong rhs)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar + rhsVar});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: minus
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_minus
(JNIEnv* env, jclass, jlong handle, jlong rhs)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
auto& rhsVar = *reinterpret_cast<slp::Variable<double>*>(rhs);
return reinterpret_cast<jlong>(new slp::Variable<double>{lhsVar - rhsVar});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: unaryMinus
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_unaryMinus
(JNIEnv* env, jclass, jlong handle)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
return reinterpret_cast<jlong>(new slp::Variable<double>{-lhsVar});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: value
* Signature: (J)D
*/
JNIEXPORT jdouble JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_value
(JNIEnv* env, jclass, jlong handle)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
return lhsVar.value();
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: type
* Signature: (J)I
*/
JNIEXPORT jint JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_type
(JNIEnv* env, jclass, jlong handle)
{
auto& lhsVar = *reinterpret_cast<slp::Variable<double>*>(handle);
return static_cast<jint>(lhsVar.type());
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: abs
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_abs
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{abs(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: acos
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_acos
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{acos(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: asin
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_asin
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{asin(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: atan
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_atan
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{atan(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: atan2
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_atan2
(JNIEnv* env, jclass, jlong y, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
auto& yVar = *reinterpret_cast<slp::Variable<double>*>(y);
return reinterpret_cast<jlong>(new slp::Variable<double>{atan2(yVar, xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: cbrt
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_cbrt
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{cbrt(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: cos
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_cos
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{cos(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: cosh
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_cosh
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{cosh(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: exp
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_exp
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{exp(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: hypot
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_hypot
(JNIEnv* env, jclass, jlong x, jlong y)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
auto& yVar = *reinterpret_cast<slp::Variable<double>*>(y);
return reinterpret_cast<jlong>(new slp::Variable<double>{hypot(xVar, yVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: log
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_log
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{log(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: log10
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_log10
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{log10(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: max
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_max
(JNIEnv* env, jclass, jlong a, jlong b)
{
auto& aVar = *reinterpret_cast<slp::Variable<double>*>(a);
auto& bVar = *reinterpret_cast<slp::Variable<double>*>(b);
return reinterpret_cast<jlong>(
new slp::Variable<double>{(slp::max)(aVar, bVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: min
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_min
(JNIEnv* env, jclass, jlong a, jlong b)
{
auto& aVar = *reinterpret_cast<slp::Variable<double>*>(a);
auto& bVar = *reinterpret_cast<slp::Variable<double>*>(b);
return reinterpret_cast<jlong>(
new slp::Variable<double>{(slp::min)(aVar, bVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: pow
* Signature: (JJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_pow
(JNIEnv* env, jclass, jlong base, jlong power)
{
auto& baseVar = *reinterpret_cast<slp::Variable<double>*>(base);
auto& powerVar = *reinterpret_cast<slp::Variable<double>*>(power);
return reinterpret_cast<jlong>(
new slp::Variable<double>{pow(baseVar, powerVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: signum
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_signum
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{sign(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: sin
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_sin
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{sin(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: sinh
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_sinh
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{sinh(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: sqrt
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_sqrt
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{sqrt(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: tan
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_tan
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{tan(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: tanh
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_tanh
(JNIEnv* env, jclass, jlong x)
{
auto& xVar = *reinterpret_cast<slp::Variable<double>*>(x);
return reinterpret_cast<jlong>(new slp::Variable<double>{tanh(xVar)});
}
/*
* Class: org_wpilib_math_autodiff_VariableJNI
* Method: totalNativeMemoryUsage
* Signature: ()J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_autodiff_VariableJNI_totalNativeMemoryUsage
(JNIEnv* env, jclass)
{
return slp::global_pool_resource().blocks_in_use() *
sizeof(slp::detail::Expression<double>);
}
} // extern "C"

View File

@@ -0,0 +1,53 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include <vector>
#include <sleipnir/autodiff/variable.hpp>
#include <sleipnir/autodiff/variable_matrix.hpp>
#include "org_wpilib_math_autodiff_VariableMatrixJNI.h"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
/*
* Class: org_wpilib_math_autodiff_VariableMatrixJNI
* Method: solve
* Signature: ([JI[JI)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_wpilib_math_autodiff_VariableMatrixJNI_solve
(JNIEnv* env, jclass, jlongArray A, jint Acols, jlongArray B, jint Bcols)
{
JSpan<const jlong> ASpan{env, A};
slp::VariableMatrix<double> AObj(slp::detail::empty, ASpan.size() / Acols,
Acols);
for (size_t i = 0; i < ASpan.size(); ++i) {
AObj[i] = *reinterpret_cast<slp::Variable<double>*>(ASpan[i]);
}
JSpan<const jlong> BSpan{env, B};
slp::VariableMatrix<double> BObj(slp::detail::empty, BSpan.size() / Bcols,
Bcols);
for (size_t i = 0; i < BSpan.size(); ++i) {
BObj[i] = *reinterpret_cast<slp::Variable<double>*>(BSpan[i]);
}
auto X = slp::solve(AObj, BObj);
std::vector<jlong> varHandles;
varHandles.reserve(X.size());
for (auto& var : X) {
varHandles.emplace_back(
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
}
return MakeJLongArray(env, varHandles);
}
} // extern "C"

View File

@@ -0,0 +1,255 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#include <jni.h>
#include <vector>
#include <sleipnir/optimization/problem.hpp>
#include "../SleipnirJNIUtil.hpp"
#include "org_wpilib_math_optimization_ProblemJNI.h"
#include "wpi/util/jni_util.hpp"
using namespace wpi::util::java;
extern "C" {
namespace {
// ProblemJNI_solve() sets these before calling Problem::solve() so the Java
// callback has a valid JNIEnv and object on which to call
// Problem.runCallbacks()
thread_local JNIEnv* callbackEnv;
thread_local jobject callbackObj;
} // namespace
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: create
* Signature: ()J
*/
JNIEXPORT jlong JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_create
(JNIEnv* env, jclass)
{
auto problem = new slp::Problem<double>;
// Configure Java iteration callbacks
problem->add_persistent_callback(
[](const slp::IterationInfo<double>& info) -> bool {
// Find Problem class
static JClass cls{callbackEnv, "org/wpilib/math/optimization/Problem"};
if (!cls) {
return true;
}
// Find Problem.runCallbacks()
static jmethodID runCallbacks = callbackEnv->GetMethodID(
cls, "runCallbacks",
"(III[DLorg/wpilib/math/autodiff/NativeSparseTriplets;"
"Lorg/wpilib/math/autodiff/NativeSparseTriplets;"
"Lorg/wpilib/math/autodiff/NativeSparseTriplets;"
"Lorg/wpilib/math/autodiff/NativeSparseTriplets;)Z");
if (!runCallbacks) {
return true;
}
// Run Java callbacks
return callbackEnv->CallBooleanMethod(
callbackObj, runCallbacks, info.A_e.rows(), info.A_i.rows(),
info.iteration, MakeJDoubleArray(callbackEnv, info.x),
wpi::math::detail::GetTriplets(callbackEnv, info.g),
wpi::math::detail::GetTriplets(callbackEnv, info.H),
wpi::math::detail::GetTriplets(callbackEnv, info.A_e),
wpi::math::detail::GetTriplets(callbackEnv, info.A_i));
});
return reinterpret_cast<jlong>(problem);
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: destroy
* Signature: (J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_destroy
(JNIEnv* env, jclass, jlong handle)
{
delete reinterpret_cast<slp::Problem<double>*>(handle);
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: decisionVariable
* Signature: (JII)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_decisionVariable
(JNIEnv* env, jclass, jlong handle, jint rows, jint cols)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
auto vars = problem.decision_variable(rows, cols);
std::vector<jlong> varHandles;
varHandles.reserve(vars.size());
for (auto& var : vars) {
varHandles.emplace_back(
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
}
return MakeJLongArray(env, varHandles);
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: symmetricDecisionVariable
* Signature: (JI)[J
*/
JNIEXPORT jlongArray JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_symmetricDecisionVariable
(JNIEnv* env, jclass, jlong handle, jint rows)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
auto vars = problem.symmetric_decision_variable(rows);
std::vector<jlong> varHandles;
varHandles.reserve(vars.size());
for (auto& var : vars) {
varHandles.emplace_back(
reinterpret_cast<jlong>(new slp::Variable<double>{var}));
}
return MakeJLongArray(env, varHandles);
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: minimize
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_minimize
(JNIEnv* env, jclass, jlong handle, jlong costHandle)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
auto& costVar = *reinterpret_cast<slp::Variable<double>*>(costHandle);
problem.minimize(costVar);
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: maximize
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_maximize
(JNIEnv* env, jclass, jlong handle, jlong objectiveHandle)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
auto& objectiveVar =
*reinterpret_cast<slp::Variable<double>*>(objectiveHandle);
problem.maximize(objectiveVar);
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: subjectToEq
* Signature: (J[J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_subjectToEq
(JNIEnv* env, jclass, jlong handle, jlongArray constraintHandles)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
JSpan<const jlong> constraintHandlesSpan{env, constraintHandles};
for (const auto& constraintHandle : constraintHandlesSpan) {
const auto& constraint =
*reinterpret_cast<slp::Variable<double>*>(constraintHandle);
problem.subject_to(constraint == 0.0);
}
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: subjectToIneq
* Signature: (J[J)V
*/
JNIEXPORT void JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_subjectToIneq
(JNIEnv* env, jclass, jlong handle, jlongArray constraintHandles)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
JSpan<const jlong> constraintHandlesSpan{env, constraintHandles};
for (const auto& constraintHandle : constraintHandlesSpan) {
const auto& constraint =
*reinterpret_cast<slp::Variable<double>*>(constraintHandle);
problem.subject_to(constraint >= 0.0);
}
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: costFunctionType
* Signature: (J)I
*/
JNIEXPORT jint JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_costFunctionType
(JNIEnv* env, jclass, jlong handle)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
return static_cast<jint>(problem.cost_function_type());
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: equalityConstraintType
* Signature: (J)I
*/
JNIEXPORT jint JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_equalityConstraintType
(JNIEnv* env, jclass, jlong handle)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
return static_cast<jint>(problem.equality_constraint_type());
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: inequalityConstraintType
* Signature: (J)I
*/
JNIEXPORT jint JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_inequalityConstraintType
(JNIEnv* env, jclass, jlong handle)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
return static_cast<jint>(problem.inequality_constraint_type());
}
/*
* Class: org_wpilib_math_optimization_ProblemJNI
* Method: solve
* Signature: (Ljava/lang/Object;JDIDZZ)I
*/
JNIEXPORT jint JNICALL
Java_org_wpilib_math_optimization_ProblemJNI_solve
(JNIEnv* env, jclass, jobject obj, jlong handle, jdouble tolerance,
jint maxIterations, jdouble timeout, jboolean feasibleIPM,
jboolean diagnostics)
{
auto& problem = *reinterpret_cast<slp::Problem<double>*>(handle);
callbackEnv = env;
callbackObj = obj;
slp::Options options{
tolerance, maxIterations, std::chrono::duration<double>{timeout},
static_cast<bool>(feasibleIPM), static_cast<bool>(diagnostics)};
return static_cast<int>(problem.solve(options));
}
} // extern "C"