Add compile-time EJML matrix wrapper to wpiutil (#1804)

This adds a wrapper over EJML's SimpleMatrix that uses generated classes representing numbers to encode the dimensions of each matrix at compile time, and to check operations between matrices for validity at compile time, rather than failing with an exception at runtime. This is required for the Java implementation of state-space control.

Additions to the wpiutil gradle script, and a python script at the wpiutil root are used to generate numeric types from a template at build time for both gradle and cmake. Users will be able to access types through functions on the Nat class.
This commit is contained in:
Redrield
2019-08-18 18:00:40 -04:00
committed by Peter Johnson
parent 3ebc5a6d3a
commit 7e95010a29
13 changed files with 1046 additions and 1 deletions

View File

@@ -0,0 +1,50 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2019 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import java.util.Objects;
import org.ejml.simple.SimpleMatrix;
/**
* A class for constructing arbitrary RxC matrices.
*
* @param <R> The number of rows of the desired matrix.
* @param <C> The number of columns of the desired matrix.
*/
public class MatBuilder<R extends Num, C extends Num> {
private final Nat<R> m_rows;
private final Nat<C> m_cols;
/**
* Fills the matrix with the given data, encoded in row major form.
* (The matrix is filled row by row, left to right with the given data).
*
* @param data The data to fill the matrix with.
* @return The constructed matrix.
*/
@SuppressWarnings("LineLength")
public final Matrix<R, C> fill(double... data) {
if (Objects.requireNonNull(data).length != this.m_rows.getNum() * this.m_cols.getNum()) {
throw new IllegalArgumentException("Invalid matrix data provided. Wanted " + this.m_rows.getNum()
+ " x " + this.m_cols.getNum() + " matrix, but got " + data.length + " elements");
} else {
return new Matrix<>(new SimpleMatrix(this.m_rows.getNum(), this.m_cols.getNum(), true, data));
}
}
/**
* Creates a new {@link MatBuilder} with the given dimensions.
* @param rows The number of rows of the matrix.
* @param cols The number of columns of the matrix.
*/
public MatBuilder(Nat<R> rows, Nat<C> cols) {
this.m_rows = Objects.requireNonNull(rows);
this.m_cols = Objects.requireNonNull(cols);
}
}

View File

@@ -0,0 +1,327 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2019 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import java.util.Objects;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.NormOps_DDRM;
import org.ejml.simple.SimpleMatrix;
/**
* A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices.
*
* <p>This class is intended to be used alongside the state space library.
*
* @param <R> The number of rows in this matrix.
* @param <C> The number of columns in this matrix.
*/
@SuppressWarnings("PMD.TooManyMethods")
public class Matrix<R extends Num, C extends Num> {
private final SimpleMatrix m_storage;
/**
* Gets the number of columns in this matrix.
*
* @return The number of columns, according to the internal storage.
*/
public final int getNumCols() {
return this.m_storage.numCols();
}
/**
* Gets the number of rows in this matrix.
*
* @return The number of rows, according to the internal storage.
*/
public final int getNumRows() {
return this.m_storage.numRows();
}
/**
* Get an element of this matrix.
*
* @param row The row of the element.
* @param col The column of the element.
* @return The element in this matrix at row,col.
*/
public final double get(int row, int col) {
return this.m_storage.get(row, col);
}
/**
* Sets the value at the given indices.
*
* @param row The row of the element.
* @param col The column of the element.
* @param value The value to insert at the given location.
*/
public final void set(int row, int col, double value) {
this.m_storage.set(row, col, value);
}
/**
* If a vector then a square matrix is returned
* if a matrix then a vector of diagonal elements is returned.
*
* @return Diagonal elements inside a vector or a square matrix with the same diagonal elements.
*/
public final Matrix<R, C> diag() {
return new Matrix<>(this.m_storage.diag());
}
/**
* Returns the largest element of this matrix.
*
* @return The largest element of this matrix.
*/
public final double maxInternal() {
return CommonOps_DDRM.elementMax(this.m_storage.getDDRM());
}
/**
* Returns the smallest element of this matrix.
*
* @return The smallest element of this matrix.
*/
public final double minInternal() {
return CommonOps_DDRM.elementMin(this.m_storage.getDDRM());
}
/**
* Calculates the mean of the elements in this matrix.
*
* @return The mean value of this matrix.
*/
public final double mean() {
return this.elementSum() / (double) this.m_storage.getNumElements();
}
/**
* Multiplies this matrix with another that has C rows.
*
* <p>As matrix multiplication is only defined if the number of columns
* in the first matrix matches the number of rows in the second,
* this operation will fail to compile under any other circumstances.
*
* @param other The other matrix to multiply by.
* @param <C2> The number of columns in the second matrix.
* @return The result of the matrix multiplication between this and the given matrix.
*/
public final <C2 extends Num> Matrix<R, C2> times(Matrix<C, C2> other) {
return new Matrix<>(this.m_storage.mult(other.m_storage));
}
/**
* Multiplies all the elements of this matrix by the given scalar.
*
* @param value The scalar value to multiply by.
* @return A new matrix with all the elements multiplied by the given value.
*/
public final Matrix<R, C> times(double value) {
return new Matrix<>(this.m_storage.scale(value));
}
/**
* <p>
* Returns a matrix which is the result of an element by element multiplication of 'this' and 'b'.
* c<sub>i,j</sub> = a<sub>i,j</sub>*b<sub>i,j</sub>
* </p>
*
* @param other A matrix.
* @return The element by element multiplication of 'this' and 'b'.
*/
public final Matrix<R, C> elementTimes(Matrix<R, C> other) {
return new Matrix<>(this.m_storage.elementMult(Objects.requireNonNull(other).m_storage));
}
/**
* Subtracts the given value from all the elements of this matrix.
*
* @param value The value to subtract.
* @return The resultant matrix.
*/
public final Matrix<R, C> minus(double value) {
return new Matrix<>(this.m_storage.minus(value));
}
/**
* Subtracts the given matrix from this matrix.
*
* @param value The matrix to subtract.
* @return The resultant matrix.
*/
public final Matrix<R, C> minus(Matrix<R, C> value) {
return new Matrix<>(this.m_storage.minus(Objects.requireNonNull(value).m_storage));
}
/**
* Adds the given value to all the elements of this matrix.
*
* @param value The value to add.
* @return The resultant matrix.
*/
public final Matrix<R, C> plus(double value) {
return new Matrix<>(this.m_storage.plus(value));
}
/**
* Adds the given matrix to this matrix.
*
* @param value The matrix to add.
* @return The resultant matrix.
*/
public final Matrix<R, C> plus(Matrix<R, C> value) {
return new Matrix<>(this.m_storage.plus(value.m_storage));
}
/**
* Divides all elements of this matrix by the given value.
*
* @param value The value to divide by.
* @return The resultant matrix.
*/
public final Matrix<R, C> div(int value) {
return new Matrix<>(this.m_storage.divide((double) value));
}
/**
* Divides all elements of this matrix by the given value.
*
* @param value The value to divide by.
* @return The resultant matrix.
*/
public final Matrix<R, C> div(double value) {
return new Matrix<>(this.m_storage.divide(value));
}
/**
* Calculates the transpose, M^T of this matrix.
*
* @return The tranpose matrix.
*/
public final Matrix<C, R> transpose() {
return new Matrix<>(this.m_storage.transpose());
}
/**
* Returns a copy of this matrix.
*
* @return A copy of this matrix.
*/
public final Matrix<R, C> copy() {
return new Matrix<>(this.m_storage.copy());
}
/**
* Returns the inverse matrix of this matrix.
*
* @return The inverse of this matrix.
* @throws org.ejml.data.SingularMatrixException If this matrix is non-invertable.
*/
public final Matrix<R, C> inv() {
return new Matrix<>(this.m_storage.invert());
}
/**
* Returns the determinant of this matrix.
*
* @return The determinant of this matrix.
*/
public final double det() {
return this.m_storage.determinant();
}
/**
* Computes the Frobenius normal of the matrix.<br>
* <br>
* normF = Sqrt{ &sum;<sub>i=1:m</sub> &sum;<sub>j=1:n</sub> { a<sub>ij</sub><sup>2</sup>} }
*
* @return The matrix's Frobenius normal.
*/
public final double normF() {
return this.m_storage.normF();
}
/**
* Computes the induced p = 1 matrix norm.<br>
* <br>
* ||A||<sub>1</sub>= max(j=1 to n; sum(i=1 to m; |a<sub>ij</sub>|))
*
* @return The norm.
*/
public final double normIndP1() {
return NormOps_DDRM.inducedP1(this.m_storage.getDDRM());
}
/**
* Computes the sum of all the elements in the matrix.
*
* @return Sum of all the elements.
*/
public final double elementSum() {
return this.m_storage.elementSum();
}
/**
* Computes the trace of the matrix.
*
* @return The trace of the matrix.
*/
public final double trace() {
return this.m_storage.trace();
}
/**
* Returns a matrix which is the result of an element by element power of 'this' and 'b':
* c<sub>i,j</sub> = a<sub>i,j</sub> ^ b.
*
* @param b Scalar
* @return The element by element power of 'this' and 'b'.
*/
@SuppressWarnings("ParameterName")
public final Matrix<R, C> epow(double b) {
return new Matrix<>(this.m_storage.elementPower(b));
}
/**
* Returns a matrix which is the result of an element by element power of 'this' and 'b':
* c<sub>i,j</sub> = a<sub>i,j</sub> ^ b.
*
* @param b Scalar.
* @return The element by element power of 'this' and 'b'.
*/
@SuppressWarnings("ParameterName")
public final Matrix<R, C> epow(int b) {
return new Matrix<>(this.m_storage.elementPower((double) b));
}
/**
* Returns the EJML {@link SimpleMatrix} backing this wrapper.
*
* @return The untyped EJML {@link SimpleMatrix}.
*/
public final SimpleMatrix getStorage() {
return this.m_storage;
}
/**
* Constructs a new matrix with the given storage.
* Caller should make sure that the provided generic bounds match the shape of the provided matrix
*
* @param storage The {@link SimpleMatrix} to back this value
*/
public Matrix(SimpleMatrix storage) {
this.m_storage = Objects.requireNonNull(storage);
}
}

View File

@@ -0,0 +1,83 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2019 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import java.util.Objects;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.numbers.N1;
public final class MatrixUtils {
private MatrixUtils() {
throw new AssertionError("utility class");
}
/**
* Creates a new matrix of zeros.
*
* @param rows The number of rows in the matrix.
* @param cols The number of columns in the matrix.
* @param <R> The number of rows in the matrix as a generic.
* @param <C> The number of columns in the matrix as a generic.
* @return An RxC matrix filled with zeros.
*/
@SuppressWarnings("LineLength")
public static <R extends Num, C extends Num> Matrix<R, C> zeros(Nat<R> rows, Nat<C> cols) {
return new Matrix<>(
new SimpleMatrix(Objects.requireNonNull(rows).getNum(), Objects.requireNonNull(cols).getNum()));
}
/**
* Creates a new vector of zeros.
*
* @param nums The size of the desired vector.
* @param <N> The size of the desired vector as a generic.
* @return A vector of size N filled with zeros.
*/
public static <N extends Num> Matrix<N, N1> zeros(Nat<N> nums) {
return new Matrix<>(new SimpleMatrix(Objects.requireNonNull(nums).getNum(), 1));
}
/**
* Creates the identity matrix of the given dimension.
*
* @param dim The dimension of the desired matrix.
* @param <D> The dimension of the desired matrix as a generic.
* @return The DxD identity matrix.
*/
public static <D extends Num> Matrix<D, D> eye(Nat<D> dim) {
return new Matrix<>(SimpleMatrix.identity(Objects.requireNonNull(dim).getNum()));
}
/**
* Entrypoint to the MatBuilder class for creation
* of custom matrices with the given dimensions and contents.
*
* @param rows The number of rows of the desired matrix.
* @param cols The number of columns of the desired matrix.
* @param <R> The number of rows of the desired matrix as a generic.
* @param <C> The number of columns of the desired matrix as a generic.
* @return A builder to construct the matrix.
*/
public static <R extends Num, C extends Num> MatBuilder<R, C> mat(Nat<R> rows, Nat<C> cols) {
return new MatBuilder<>(rows, cols);
}
/**
* Entrypoint to the VecBuilder class for creation
* of custom vectors with the given size and contents.
*
* @param dim The dimension of the vector.
* @param <D> The dimension of the vector as a generic.
* @return A builder to construct the vector.
*/
public static <D extends Num> VecBuilder<D> vec(Nat<D> dim) {
return new VecBuilder<>(dim);
}
}

View File

@@ -0,0 +1,20 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2019 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
/**
* A number expressed as a java class.
*/
public abstract class Num {
/**
* The number this is backing.
*
* @return The number represented by this class.
*/
public abstract int getNum();
}

View File

@@ -0,0 +1,161 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2019 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import java.util.function.BiFunction;
import org.ejml.dense.row.NormOps_DDRM;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;
public class SimpleMatrixUtils {
private SimpleMatrixUtils() {}
/**
* Compute the matrix exponential, e^M of the given matrix.
*
* @param matrix The matrix to compute the exponential of.
* @return The resultant matrix.
*/
@SuppressWarnings({"LocalVariableName", "LineLength"})
public static SimpleMatrix expm(SimpleMatrix matrix) {
BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider = SimpleBase::solve;
SimpleMatrix A = matrix;
double A_L1 = NormOps_DDRM.inducedP1(matrix.getDDRM());
int n_squarings = 0;
if (A_L1 < 1.495585217958292e-002) {
Pair<SimpleMatrix, SimpleMatrix> pair = _pade3(A);
return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
} else if (A_L1 < 2.539398330063230e-001) {
Pair<SimpleMatrix, SimpleMatrix> pair = _pade5(A);
return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
} else if (A_L1 < 9.504178996162932e-001) {
Pair<SimpleMatrix, SimpleMatrix> pair = _pade7(A);
return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
} else if (A_L1 < 2.097847961257068e+000) {
Pair<SimpleMatrix, SimpleMatrix> pair = _pade9(A);
return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
} else {
double maxNorm = 5.371920351148152;
double log = Math.log(A_L1 / maxNorm) / Math.log(2); // logb(2, arg)
n_squarings = (int) Math.max(0, Math.ceil(log));
A = A.divide(Math.pow(2.0, n_squarings));
Pair<SimpleMatrix, SimpleMatrix> pair = _pade13(A);
return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider);
}
}
@SuppressWarnings({"LocalVariableName", "ParameterName", "LineLength"})
private static SimpleMatrix dispatchPade(SimpleMatrix U, SimpleMatrix V,
int nSquarings, BiFunction<SimpleMatrix, SimpleMatrix, SimpleMatrix> solveProvider) {
SimpleMatrix P = U.plus(V);
SimpleMatrix Q = U.negative().plus(V);
SimpleMatrix R = solveProvider.apply(Q, P);
for (int i = 0; i < nSquarings; i++) {
R = R.mult(R);
}
return R;
}
@SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade3(SimpleMatrix A) {
double[] b = new double[]{120, 60, 12, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix U = A.mult(A2.mult(ident.scale(b[1]).plus(b[3])));
SimpleMatrix V = A2.scale(b[2]).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade5(SimpleMatrix A) {
double[] b = new double[]{30240, 15120, 3360, 420, 30, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix U = A.mult(A4.scale(b[5]).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A4.scale(b[4]).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade7(SimpleMatrix A) {
double[] b = new double[]{17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix U = A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName", "LineLength"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade9(SimpleMatrix A) {
double[] b = new double[]{17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240,
2162160, 110880, 3960, 90, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix A8 = A6.mult(A2);
SimpleMatrix U = A.mult(A8.scale(b[9]).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A8.scale(b[8]).plus(A6.scale(b[6])).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"})
private static Pair<SimpleMatrix, SimpleMatrix> _pade13(SimpleMatrix A) {
double[] b = new double[]{64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
33522128640.0, 1323241920, 40840800, 960960, 16380, 182, 1};
SimpleMatrix ident = eye(A.numRows(), A.numCols());
SimpleMatrix A2 = A.mult(A);
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix U = A.mult(A6.scale(b[13]).plus(A4.scale(b[11])).plus(A2.scale(b[9])).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))).plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
return new Pair<>(U, V);
}
private static SimpleMatrix eye(int rows, int cols) {
return SimpleMatrix.identity(Math.min(rows, cols));
}
private static class Pair<A, B> {
private final A m_first;
private final B m_second;
Pair(A first, B second) {
m_first = first;
m_second = second;
}
public A getFirst() {
return m_first;
}
public B getSecond() {
return m_second;
}
}
}

View File

@@ -0,0 +1,21 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2019 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* A specialization of {@link MatBuilder} for constructing vectors (Nx1 matrices).
*
* @param <N> The dimension of the vector to be constructed.
*/
public class VecBuilder<N extends Num> extends MatBuilder<N, N1> {
public VecBuilder(Nat<N> rows) {
super(rows, Nat.N1());
}
}