From 7e95010a298f132d4f3ccc105ce296c4b69203ad Mon Sep 17 00:00:00 2001 From: Redrield Date: Sun, 18 Aug 2019 18:00:40 -0400 Subject: [PATCH] Add compile-time EJML matrix wrapper to wpiutil (#1804) This adds a wrapper over EJML's SimpleMatrix that uses generated classes representing numbers to encode the dimensions of each matrix at compile time, and to check operations between matrices for validity at compile time, rather than failing with an exception at runtime. This is required for the Java implementation of state-space control. Additions to the wpiutil gradle script, and a python script at the wpiutil root are used to generate numeric types from a template at build time for both gradle and cmake. Users will be able to access types through functions on the Nat class. --- wpiutil/CMakeLists.txt | 4 +- wpiutil/build.gradle | 52 +++ wpiutil/generate_numbers.py | 37 ++ wpiutil/src/generate/GenericNumber.java.in | 34 ++ wpiutil/src/generate/Nat.java.in | 45 +++ wpiutil/src/generate/NatGetter.java.in | 3 + .../wpi/first/wpiutil/math/MatBuilder.java | 50 +++ .../edu/wpi/first/wpiutil/math/Matrix.java | 327 ++++++++++++++++++ .../wpi/first/wpiutil/math/MatrixUtils.java | 83 +++++ .../java/edu/wpi/first/wpiutil/math/Num.java | 20 ++ .../first/wpiutil/math/SimpleMatrixUtils.java | 161 +++++++++ .../wpi/first/wpiutil/math/VecBuilder.java | 21 ++ .../wpi/first/wpiutil/math/MatrixTest.java | 210 +++++++++++ 13 files changed, 1046 insertions(+), 1 deletion(-) create mode 100644 wpiutil/generate_numbers.py create mode 100644 wpiutil/src/generate/GenericNumber.java.in create mode 100644 wpiutil/src/generate/Nat.java.in create mode 100644 wpiutil/src/generate/NatGetter.java.in create mode 100644 wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java create mode 100644 wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java create mode 100644 wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java create mode 100644 wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Num.java create mode 100644 wpiutil/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java create mode 100644 wpiutil/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java create mode 100644 wpiutil/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java diff --git a/wpiutil/CMakeLists.txt b/wpiutil/CMakeLists.txt index 52d5260579..ef3c5d6b02 100644 --- a/wpiutil/CMakeLists.txt +++ b/wpiutil/CMakeLists.txt @@ -40,7 +40,9 @@ if (NOT WITHOUT_JAVA) set(CMAKE_JAVA_INCLUDE_PATH wpiutil.jar ${EJML_JARS}) - file(GLOB_RECURSE JAVA_SOURCES src/main/java/*.java) + execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/wpiutil/generate_numbers.py ${CMAKE_BINARY_DIR}/wpiutil) + + file(GLOB_RECURSE JAVA_SOURCES src/main/java/*.java ${CMAKE_BINARY_DIR}/wpiutil/generated/*.java) add_jar(wpiutil_jar ${JAVA_SOURCES} INCLUDE_JARS ${EJML_JARS} OUTPUT_NAME wpiutil) get_property(WPIUTIL_JAR_FILE TARGET wpiutil_jar PROPERTY JAR_FILE) diff --git a/wpiutil/build.gradle b/wpiutil/build.gradle index 6e1b35e08c..7a9e7ac95c 100644 --- a/wpiutil/build.gradle +++ b/wpiutil/build.gradle @@ -240,3 +240,55 @@ model { dependencies { compile "org.ejml:ejml-simple:0.38" } + +def wpilibNumberFileInput = file("src/generate/GenericNumber.java.in") +def natFileInput = file("src/generate/Nat.java.in") +def natGetterInput = file("src/generate/NatGetter.java.in") +def wpilibNumberFileOutputDir = file("$buildDir/generated/java/edu/wpi/first/wpiutil/math/numbers") +def wpilibNatFileOutput = file("$buildDir/generated/java/edu/wpi/first/wpiutil/math/Nat.java") +def maxNum = 20 + +task generateNumbers() { + description = "Generates generic number classes from template" + group = "WPILib" + + inputs.file wpilibNumberFileInput + + doLast { + if(wpilibNumberFileOutputDir.exists()) { + wpilibNumberFileOutputDir.delete() + } + wpilibNumberFileOutputDir.mkdirs() + + for(i in 0..maxNum) { + def outputFile = new File(wpilibNumberFileOutputDir, "N${i}.java") + def read = wpilibNumberFileInput.text.replace('${num}', i.toString()) + outputFile.write(read) + } + } +} + +task generateNat() { + description = "Generates Nat.java" + group = "WPILib" + inputs.files([natFileInput, natGetterInput]) + dependsOn generateNumbers + + doLast { + if(wpilibNatFileOutput.exists()) { + wpilibNatFileOutput.delete() + } + + def template = natFileInput.text + "\n" + for(i in 0..maxNum) { + template += natGetterInput.text.replace('${num}', i.toString()) + "\n" + } + template += "}\n" // Close the class body + + wpilibNatFileOutput.write(template) + } +} + +sourceSets.main.java.srcDir "${buildDir}/generated/java" +compileJava.dependsOn generateNumbers +compileJava.dependsOn generateNat diff --git a/wpiutil/generate_numbers.py b/wpiutil/generate_numbers.py new file mode 100644 index 0000000000..91f2bb69d9 --- /dev/null +++ b/wpiutil/generate_numbers.py @@ -0,0 +1,37 @@ +import os +import shutil +import sys + +MAX_NUM = 20 + +dirname, _ = os.path.split(os.path.abspath(__file__)) +cmake_binary_dir = sys.argv[1] + +with open(f"{dirname}/src/generate/GenericNumber.java.in", "r") as templateFile: + template = templateFile.read() + rootPath = f"{cmake_binary_dir}/generated/main/java/edu/wpi/first/wpiutil/math/numbers" + + if os.path.exists(rootPath): + shutil.rmtree(rootPath) + os.makedirs(rootPath) + + for i in range(MAX_NUM + 1): + with open(f"{rootPath}/N{i}.java", "w") as f: + f.write(template.replace("${num}", str(i))) + +with open(f"{dirname}/src/generate/Nat.java.in", "r") as templateFile: + template = templateFile.read() + outputPath = f"{cmake_binary_dir}/generated/main/java/edu/wpi/first/wpiutil/math/Nat.java" + with open(f"{dirname}/src/generate/NatGetter.java.in", "r") as getterFile: + getter = getterFile.read() + + if os.path.exists(outputPath): + os.remove(outputPath) + + for i in range(MAX_NUM + 1): + template += getter.replace("${num}", str(i)) + + template += "}\n" + + with open(outputPath, "w") as f: + f.write(template) diff --git a/wpiutil/src/generate/GenericNumber.java.in b/wpiutil/src/generate/GenericNumber.java.in new file mode 100644 index 0000000000..5a36582e4d --- /dev/null +++ b/wpiutil/src/generate/GenericNumber.java.in @@ -0,0 +1,34 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math.numbers; + +import edu.wpi.first.wpiutil.math.Nat; +import edu.wpi.first.wpiutil.math.Num; + +/** + * A class representing the number ${num}. +*/ +public final class N${num} extends Num implements Nat { + private N${num}() { + } + + /** + * The integer this class represents. + * + * @return The literal number ${num}. + */ + @Override + public int getNum() { + return ${num}; + } + + /** + * The singleton instance of this class. + */ + public static final N${num} instance = new N${num}(); +} diff --git a/wpiutil/src/generate/Nat.java.in b/wpiutil/src/generate/Nat.java.in new file mode 100644 index 0000000000..5263c53a58 --- /dev/null +++ b/wpiutil/src/generate/Nat.java.in @@ -0,0 +1,45 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import edu.wpi.first.wpiutil.math.numbers.N0; +import edu.wpi.first.wpiutil.math.numbers.N1; +import edu.wpi.first.wpiutil.math.numbers.N10; +import edu.wpi.first.wpiutil.math.numbers.N11; +import edu.wpi.first.wpiutil.math.numbers.N12; +import edu.wpi.first.wpiutil.math.numbers.N13; +import edu.wpi.first.wpiutil.math.numbers.N14; +import edu.wpi.first.wpiutil.math.numbers.N15; +import edu.wpi.first.wpiutil.math.numbers.N16; +import edu.wpi.first.wpiutil.math.numbers.N17; +import edu.wpi.first.wpiutil.math.numbers.N18; +import edu.wpi.first.wpiutil.math.numbers.N19; +import edu.wpi.first.wpiutil.math.numbers.N2; +import edu.wpi.first.wpiutil.math.numbers.N20; +import edu.wpi.first.wpiutil.math.numbers.N3; +import edu.wpi.first.wpiutil.math.numbers.N4; +import edu.wpi.first.wpiutil.math.numbers.N5; +import edu.wpi.first.wpiutil.math.numbers.N6; +import edu.wpi.first.wpiutil.math.numbers.N7; +import edu.wpi.first.wpiutil.math.numbers.N8; +import edu.wpi.first.wpiutil.math.numbers.N9; + +/** + * A natural number expressed as a java class. + * The counterpart to {@link Num} that should be used as a concrete value. + * + * @param The {@link Num} this represents. + */ +@SuppressWarnings({"MethodName", "unused", "PMD.TooManyMethods"}) +public interface Nat { + /** + * The number this interface represents. + * + * @return The number backing this value. + */ + int getNum(); diff --git a/wpiutil/src/generate/NatGetter.java.in b/wpiutil/src/generate/NatGetter.java.in new file mode 100644 index 0000000000..d268fab425 --- /dev/null +++ b/wpiutil/src/generate/NatGetter.java.in @@ -0,0 +1,3 @@ + static Nat N${num}() { + return N${num}.instance; + } diff --git a/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java new file mode 100644 index 0000000000..a5490a33d4 --- /dev/null +++ b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatBuilder.java @@ -0,0 +1,50 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import java.util.Objects; + +import org.ejml.simple.SimpleMatrix; + +/** + * A class for constructing arbitrary RxC matrices. + * + * @param The number of rows of the desired matrix. + * @param The number of columns of the desired matrix. + */ +public class MatBuilder { + private final Nat m_rows; + private final Nat m_cols; + + /** + * Fills the matrix with the given data, encoded in row major form. + * (The matrix is filled row by row, left to right with the given data). + * + * @param data The data to fill the matrix with. + * @return The constructed matrix. + */ + @SuppressWarnings("LineLength") + public final Matrix fill(double... data) { + if (Objects.requireNonNull(data).length != this.m_rows.getNum() * this.m_cols.getNum()) { + throw new IllegalArgumentException("Invalid matrix data provided. Wanted " + this.m_rows.getNum() + + " x " + this.m_cols.getNum() + " matrix, but got " + data.length + " elements"); + } else { + return new Matrix<>(new SimpleMatrix(this.m_rows.getNum(), this.m_cols.getNum(), true, data)); + } + } + + /** + * Creates a new {@link MatBuilder} with the given dimensions. + * @param rows The number of rows of the matrix. + * @param cols The number of columns of the matrix. + */ + public MatBuilder(Nat rows, Nat cols) { + this.m_rows = Objects.requireNonNull(rows); + this.m_cols = Objects.requireNonNull(cols); + } +} diff --git a/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java new file mode 100644 index 0000000000..5780498b30 --- /dev/null +++ b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Matrix.java @@ -0,0 +1,327 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import java.util.Objects; + +import org.ejml.dense.row.CommonOps_DDRM; +import org.ejml.dense.row.NormOps_DDRM; +import org.ejml.simple.SimpleMatrix; + +/** + * A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices. + * + *

This class is intended to be used alongside the state space library. + * + * @param The number of rows in this matrix. + * @param The number of columns in this matrix. + */ +@SuppressWarnings("PMD.TooManyMethods") +public class Matrix { + + private final SimpleMatrix m_storage; + + /** + * Gets the number of columns in this matrix. + * + * @return The number of columns, according to the internal storage. + */ + public final int getNumCols() { + return this.m_storage.numCols(); + } + + /** + * Gets the number of rows in this matrix. + * + * @return The number of rows, according to the internal storage. + */ + public final int getNumRows() { + return this.m_storage.numRows(); + } + + /** + * Get an element of this matrix. + * + * @param row The row of the element. + * @param col The column of the element. + * @return The element in this matrix at row,col. + */ + public final double get(int row, int col) { + return this.m_storage.get(row, col); + } + + /** + * Sets the value at the given indices. + * + * @param row The row of the element. + * @param col The column of the element. + * @param value The value to insert at the given location. + */ + public final void set(int row, int col, double value) { + this.m_storage.set(row, col, value); + } + + /** + * If a vector then a square matrix is returned + * if a matrix then a vector of diagonal elements is returned. + * + * @return Diagonal elements inside a vector or a square matrix with the same diagonal elements. + */ + public final Matrix diag() { + return new Matrix<>(this.m_storage.diag()); + } + + /** + * Returns the largest element of this matrix. + * + * @return The largest element of this matrix. + */ + public final double maxInternal() { + return CommonOps_DDRM.elementMax(this.m_storage.getDDRM()); + } + + /** + * Returns the smallest element of this matrix. + * + * @return The smallest element of this matrix. + */ + public final double minInternal() { + return CommonOps_DDRM.elementMin(this.m_storage.getDDRM()); + } + + /** + * Calculates the mean of the elements in this matrix. + * + * @return The mean value of this matrix. + */ + public final double mean() { + return this.elementSum() / (double) this.m_storage.getNumElements(); + } + + /** + * Multiplies this matrix with another that has C rows. + * + *

As matrix multiplication is only defined if the number of columns + * in the first matrix matches the number of rows in the second, + * this operation will fail to compile under any other circumstances. + * + * @param other The other matrix to multiply by. + * @param The number of columns in the second matrix. + * @return The result of the matrix multiplication between this and the given matrix. + */ + public final Matrix times(Matrix other) { + return new Matrix<>(this.m_storage.mult(other.m_storage)); + } + + /** + * Multiplies all the elements of this matrix by the given scalar. + * + * @param value The scalar value to multiply by. + * @return A new matrix with all the elements multiplied by the given value. + */ + public final Matrix times(double value) { + return new Matrix<>(this.m_storage.scale(value)); + } + + /** + *

+ * Returns a matrix which is the result of an element by element multiplication of 'this' and 'b'. + * ci,j = ai,j*bi,j + *

+ * + * @param other A matrix. + * @return The element by element multiplication of 'this' and 'b'. + */ + public final Matrix elementTimes(Matrix other) { + return new Matrix<>(this.m_storage.elementMult(Objects.requireNonNull(other).m_storage)); + } + + /** + * Subtracts the given value from all the elements of this matrix. + * + * @param value The value to subtract. + * @return The resultant matrix. + */ + public final Matrix minus(double value) { + return new Matrix<>(this.m_storage.minus(value)); + } + + + /** + * Subtracts the given matrix from this matrix. + * + * @param value The matrix to subtract. + * @return The resultant matrix. + */ + public final Matrix minus(Matrix value) { + return new Matrix<>(this.m_storage.minus(Objects.requireNonNull(value).m_storage)); + } + + + /** + * Adds the given value to all the elements of this matrix. + * + * @param value The value to add. + * @return The resultant matrix. + */ + public final Matrix plus(double value) { + return new Matrix<>(this.m_storage.plus(value)); + } + + /** + * Adds the given matrix to this matrix. + * + * @param value The matrix to add. + * @return The resultant matrix. + */ + public final Matrix plus(Matrix value) { + return new Matrix<>(this.m_storage.plus(value.m_storage)); + } + + /** + * Divides all elements of this matrix by the given value. + * + * @param value The value to divide by. + * @return The resultant matrix. + */ + public final Matrix div(int value) { + return new Matrix<>(this.m_storage.divide((double) value)); + } + + /** + * Divides all elements of this matrix by the given value. + * + * @param value The value to divide by. + * @return The resultant matrix. + */ + public final Matrix div(double value) { + return new Matrix<>(this.m_storage.divide(value)); + } + + /** + * Calculates the transpose, M^T of this matrix. + * + * @return The tranpose matrix. + */ + public final Matrix transpose() { + return new Matrix<>(this.m_storage.transpose()); + } + + + /** + * Returns a copy of this matrix. + * + * @return A copy of this matrix. + */ + public final Matrix copy() { + return new Matrix<>(this.m_storage.copy()); + } + + + /** + * Returns the inverse matrix of this matrix. + * + * @return The inverse of this matrix. + * @throws org.ejml.data.SingularMatrixException If this matrix is non-invertable. + */ + public final Matrix inv() { + return new Matrix<>(this.m_storage.invert()); + } + + /** + * Returns the determinant of this matrix. + * + * @return The determinant of this matrix. + */ + public final double det() { + return this.m_storage.determinant(); + } + + /** + * Computes the Frobenius normal of the matrix.
+ *
+ * normF = Sqrt{ ∑i=1:mj=1:n { aij2} } + * + * @return The matrix's Frobenius normal. + */ + public final double normF() { + return this.m_storage.normF(); + } + + /** + * Computes the induced p = 1 matrix norm.
+ *
+ * ||A||1= max(j=1 to n; sum(i=1 to m; |aij|)) + * + * @return The norm. + */ + public final double normIndP1() { + return NormOps_DDRM.inducedP1(this.m_storage.getDDRM()); + } + + /** + * Computes the sum of all the elements in the matrix. + * + * @return Sum of all the elements. + */ + public final double elementSum() { + return this.m_storage.elementSum(); + } + + /** + * Computes the trace of the matrix. + * + * @return The trace of the matrix. + */ + public final double trace() { + return this.m_storage.trace(); + } + + /** + * Returns a matrix which is the result of an element by element power of 'this' and 'b': + * ci,j = ai,j ^ b. + * + * @param b Scalar + * @return The element by element power of 'this' and 'b'. + */ + @SuppressWarnings("ParameterName") + public final Matrix epow(double b) { + return new Matrix<>(this.m_storage.elementPower(b)); + } + + /** + * Returns a matrix which is the result of an element by element power of 'this' and 'b': + * ci,j = ai,j ^ b. + * + * @param b Scalar. + * @return The element by element power of 'this' and 'b'. + */ + @SuppressWarnings("ParameterName") + public final Matrix epow(int b) { + return new Matrix<>(this.m_storage.elementPower((double) b)); + } + + /** + * Returns the EJML {@link SimpleMatrix} backing this wrapper. + * + * @return The untyped EJML {@link SimpleMatrix}. + */ + public final SimpleMatrix getStorage() { + return this.m_storage; + } + + /** + * Constructs a new matrix with the given storage. + * Caller should make sure that the provided generic bounds match the shape of the provided matrix + * + * @param storage The {@link SimpleMatrix} to back this value + */ + public Matrix(SimpleMatrix storage) { + this.m_storage = Objects.requireNonNull(storage); + } +} diff --git a/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java new file mode 100644 index 0000000000..ac4deb9320 --- /dev/null +++ b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/MatrixUtils.java @@ -0,0 +1,83 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import java.util.Objects; + +import org.ejml.simple.SimpleMatrix; + +import edu.wpi.first.wpiutil.math.numbers.N1; + +public final class MatrixUtils { + private MatrixUtils() { + throw new AssertionError("utility class"); + } + + /** + * Creates a new matrix of zeros. + * + * @param rows The number of rows in the matrix. + * @param cols The number of columns in the matrix. + * @param The number of rows in the matrix as a generic. + * @param The number of columns in the matrix as a generic. + * @return An RxC matrix filled with zeros. + */ + @SuppressWarnings("LineLength") + public static Matrix zeros(Nat rows, Nat cols) { + return new Matrix<>( + new SimpleMatrix(Objects.requireNonNull(rows).getNum(), Objects.requireNonNull(cols).getNum())); + } + + /** + * Creates a new vector of zeros. + * + * @param nums The size of the desired vector. + * @param The size of the desired vector as a generic. + * @return A vector of size N filled with zeros. + */ + public static Matrix zeros(Nat nums) { + return new Matrix<>(new SimpleMatrix(Objects.requireNonNull(nums).getNum(), 1)); + } + + /** + * Creates the identity matrix of the given dimension. + * + * @param dim The dimension of the desired matrix. + * @param The dimension of the desired matrix as a generic. + * @return The DxD identity matrix. + */ + public static Matrix eye(Nat dim) { + return new Matrix<>(SimpleMatrix.identity(Objects.requireNonNull(dim).getNum())); + } + + /** + * Entrypoint to the MatBuilder class for creation + * of custom matrices with the given dimensions and contents. + * + * @param rows The number of rows of the desired matrix. + * @param cols The number of columns of the desired matrix. + * @param The number of rows of the desired matrix as a generic. + * @param The number of columns of the desired matrix as a generic. + * @return A builder to construct the matrix. + */ + public static MatBuilder mat(Nat rows, Nat cols) { + return new MatBuilder<>(rows, cols); + } + + /** + * Entrypoint to the VecBuilder class for creation + * of custom vectors with the given size and contents. + * + * @param dim The dimension of the vector. + * @param The dimension of the vector as a generic. + * @return A builder to construct the vector. + */ + public static VecBuilder vec(Nat dim) { + return new VecBuilder<>(dim); + } +} diff --git a/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Num.java b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Num.java new file mode 100644 index 0000000000..c7385ea62c --- /dev/null +++ b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/Num.java @@ -0,0 +1,20 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +/** + * A number expressed as a java class. + */ +public abstract class Num { + /** + * The number this is backing. + * + * @return The number represented by this class. + */ + public abstract int getNum(); +} diff --git a/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java new file mode 100644 index 0000000000..ea983d47ec --- /dev/null +++ b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/SimpleMatrixUtils.java @@ -0,0 +1,161 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import java.util.function.BiFunction; + +import org.ejml.dense.row.NormOps_DDRM; +import org.ejml.simple.SimpleBase; +import org.ejml.simple.SimpleMatrix; + +public class SimpleMatrixUtils { + private SimpleMatrixUtils() {} + + /** + * Compute the matrix exponential, e^M of the given matrix. + * + * @param matrix The matrix to compute the exponential of. + * @return The resultant matrix. + */ + @SuppressWarnings({"LocalVariableName", "LineLength"}) + public static SimpleMatrix expm(SimpleMatrix matrix) { + BiFunction solveProvider = SimpleBase::solve; + SimpleMatrix A = matrix; + double A_L1 = NormOps_DDRM.inducedP1(matrix.getDDRM()); + int n_squarings = 0; + + if (A_L1 < 1.495585217958292e-002) { + Pair pair = _pade3(A); + return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); + } else if (A_L1 < 2.539398330063230e-001) { + Pair pair = _pade5(A); + return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); + } else if (A_L1 < 9.504178996162932e-001) { + Pair pair = _pade7(A); + return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); + } else if (A_L1 < 2.097847961257068e+000) { + Pair pair = _pade9(A); + return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); + } else { + double maxNorm = 5.371920351148152; + double log = Math.log(A_L1 / maxNorm) / Math.log(2); // logb(2, arg) + n_squarings = (int) Math.max(0, Math.ceil(log)); + A = A.divide(Math.pow(2.0, n_squarings)); + Pair pair = _pade13(A); + return dispatchPade(pair.getFirst(), pair.getSecond(), n_squarings, solveProvider); + } + } + + @SuppressWarnings({"LocalVariableName", "ParameterName", "LineLength"}) + private static SimpleMatrix dispatchPade(SimpleMatrix U, SimpleMatrix V, + int nSquarings, BiFunction solveProvider) { + SimpleMatrix P = U.plus(V); + SimpleMatrix Q = U.negative().plus(V); + + SimpleMatrix R = solveProvider.apply(Q, P); + + for (int i = 0; i < nSquarings; i++) { + R = R.mult(R); + } + + return R; + } + + @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"}) + private static Pair _pade3(SimpleMatrix A) { + double[] b = new double[]{120, 60, 12, 1}; + SimpleMatrix ident = eye(A.numRows(), A.numCols()); + + SimpleMatrix A2 = A.mult(A); + SimpleMatrix U = A.mult(A2.mult(ident.scale(b[1]).plus(b[3]))); + SimpleMatrix V = A2.scale(b[2]).plus(ident.scale(b[0])); + return new Pair<>(U, V); + } + + @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName"}) + private static Pair _pade5(SimpleMatrix A) { + double[] b = new double[]{30240, 15120, 3360, 420, 30, 1}; + SimpleMatrix ident = eye(A.numRows(), A.numCols()); + SimpleMatrix A2 = A.mult(A); + SimpleMatrix A4 = A2.mult(A2); + + SimpleMatrix U = A.mult(A4.scale(b[5]).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); + SimpleMatrix V = A4.scale(b[4]).plus(A2.scale(b[2])).plus(ident.scale(b[0])); + + return new Pair<>(U, V); + } + + @SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"}) + private static Pair _pade7(SimpleMatrix A) { + double[] b = new double[]{17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1}; + SimpleMatrix ident = eye(A.numRows(), A.numCols()); + SimpleMatrix A2 = A.mult(A); + SimpleMatrix A4 = A2.mult(A2); + SimpleMatrix A6 = A4.mult(A2); + + SimpleMatrix U = A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); + SimpleMatrix V = A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])); + + return new Pair<>(U, V); + } + + @SuppressWarnings({"MethodName", "LocalVariableName", "ParameterName", "LineLength"}) + private static Pair _pade9(SimpleMatrix A) { + double[] b = new double[]{17643225600.0, 8821612800.0, 2075673600, 302702400, 30270240, + 2162160, 110880, 3960, 90, 1}; + SimpleMatrix ident = eye(A.numRows(), A.numCols()); + SimpleMatrix A2 = A.mult(A); + SimpleMatrix A4 = A2.mult(A2); + SimpleMatrix A6 = A4.mult(A2); + SimpleMatrix A8 = A6.mult(A2); + + SimpleMatrix U = A.mult(A8.scale(b[9]).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); + SimpleMatrix V = A8.scale(b[8]).plus(A6.scale(b[6])).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])); + + return new Pair<>(U, V); + } + + @SuppressWarnings({"MethodName", "LocalVariableName", "LineLength", "ParameterName"}) + private static Pair _pade13(SimpleMatrix A) { + double[] b = new double[]{64764752532480000.0, 32382376266240000.0, 7771770303897600.0, + 1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0, + 33522128640.0, 1323241920, 40840800, 960960, 16380, 182, 1}; + SimpleMatrix ident = eye(A.numRows(), A.numCols()); + + SimpleMatrix A2 = A.mult(A); + SimpleMatrix A4 = A2.mult(A2); + SimpleMatrix A6 = A4.mult(A2); + + SimpleMatrix U = A.mult(A6.scale(b[13]).plus(A4.scale(b[11])).plus(A2.scale(b[9])).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1]))); + SimpleMatrix V = A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))).plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]))); + + return new Pair<>(U, V); + } + + private static SimpleMatrix eye(int rows, int cols) { + return SimpleMatrix.identity(Math.min(rows, cols)); + } + + private static class Pair { + private final A m_first; + private final B m_second; + + Pair(A first, B second) { + m_first = first; + m_second = second; + } + + public A getFirst() { + return m_first; + } + + public B getSecond() { + return m_second; + } + } +} diff --git a/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java new file mode 100644 index 0000000000..deaaa41725 --- /dev/null +++ b/wpiutil/src/main/java/edu/wpi/first/wpiutil/math/VecBuilder.java @@ -0,0 +1,21 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import edu.wpi.first.wpiutil.math.numbers.N1; + +/** + * A specialization of {@link MatBuilder} for constructing vectors (Nx1 matrices). + * + * @param The dimension of the vector to be constructed. + */ +public class VecBuilder extends MatBuilder { + public VecBuilder(Nat rows) { + super(rows, Nat.N1()); + } +} diff --git a/wpiutil/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java b/wpiutil/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java new file mode 100644 index 0000000000..4d6697d3b1 --- /dev/null +++ b/wpiutil/src/test/java/edu/wpi/first/wpiutil/math/MatrixTest.java @@ -0,0 +1,210 @@ +/*----------------------------------------------------------------------------*/ +/* Copyright (c) 2019 FIRST. All Rights Reserved. */ +/* Open Source Software - may be modified and shared by FRC teams. The code */ +/* must be accompanied by the FIRST BSD license file in the root directory of */ +/* the project. */ +/*----------------------------------------------------------------------------*/ + +package edu.wpi.first.wpiutil.math; + +import org.ejml.data.SingularMatrixException; +import org.ejml.dense.row.MatrixFeatures_DDRM; +import org.ejml.simple.SimpleMatrix; +import org.junit.jupiter.api.Test; + +import edu.wpi.first.wpiutil.math.numbers.N1; +import edu.wpi.first.wpiutil.math.numbers.N2; +import edu.wpi.first.wpiutil.math.numbers.N3; +import edu.wpi.first.wpiutil.math.numbers.N4; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MatrixTest { + @Test + void testMatrixMultiplication() { + var mat1 = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(2.0, 1.0, + 0.0, 1.0); + var mat2 = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(3.0, 0.0, + 0.0, 2.5); + + Matrix result = mat1.times(mat2); + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(6.0, 2.5, + 0.0, 2.5).getStorage().getDDRM(), + result.getStorage().getDDRM() + )); + + var mat3 = MatrixUtils.mat(Nat.N2(), Nat.N3()) + .fill(1.0, 3.0, 0.5, + 2.0, 4.3, 1.2); + var mat4 = MatrixUtils.mat(Nat.N3(), Nat.N4()) + .fill(3.0, 1.5, 2.0, 4.5, + 2.3, 1.0, 1.6, 3.1, + 5.2, 2.1, 2.0, 1.0); + + Matrix result2 = mat3.times(mat4); + + assertTrue(MatrixFeatures_DDRM.isIdentical( + MatrixUtils.mat(Nat.N2(), Nat.N4()) + .fill(12.5, 5.55, 7.8, 14.3, + 22.13, 9.82, 13.28, 23.53).getStorage().getDDRM(), + result2.getStorage().getDDRM(), + 1E-9 + )); + } + + @Test + void testMatrixVectorMultiplication() { + var mat = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(1.0, 1.0, + 0.0, 1.0); + + var vec = MatrixUtils.vec(Nat.N2()) + .fill(3.0, + 2.0); + + Matrix result = mat.times(vec); + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.vec(Nat.N2()) + .fill(5.0, + 2.0).getStorage().getDDRM(), + result.getStorage().getDDRM() + )); + } + + @Test + void testTranspose() { + Matrix vec = MatrixUtils.vec(Nat.N3()) + .fill(1.0, + 2.0, + 3.0); + + Matrix transpose = vec.transpose(); + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N1(), Nat.N3()).fill(1.0, 2.0, 3.0).getStorage() + .getDDRM(), + transpose.getStorage().getDDRM() + )); + } + + @Test + void testInverse() { + var mat = MatrixUtils.mat(Nat.N3(), Nat.N3()) + .fill(1.0, 3.0, 2.0, + 5.0, 2.0, 1.5, + 0.0, 1.3, 2.5); + + var inv = mat.inv(); + + assertTrue(MatrixFeatures_DDRM.isIdentical( + MatrixUtils.eye(Nat.N3()).getStorage().getDDRM(), + mat.times(inv).getStorage().getDDRM(), + 1E-9 + )); + + assertTrue(MatrixFeatures_DDRM.isIdentical( + MatrixUtils.eye(Nat.N3()).getStorage().getDDRM(), + inv.times(mat).getStorage().getDDRM(), + 1E-9 + )); + } + + @Test + void testUninvertableMatrix() { + var singularMatrix = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(2.0, 1.0, + 2.0, 1.0); + + assertThrows(SingularMatrixException.class, singularMatrix::inv); + } + + @Test + void testMatrixScalarArithmetic() { + var mat = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(1.0, 2.0, + 3.0, 4.0); + + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(3.0, 4.0, + 5.0, 6.0).getStorage().getDDRM(), + mat.plus(2.0).getStorage().getDDRM() + )); + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(0.0, 1.0, + 2.0, 3.0).getStorage().getDDRM(), + mat.minus(1.0).getStorage().getDDRM() + )); + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(2.0, 4.0, + 6.0, 8.0).getStorage().getDDRM(), + mat.times(2.0).getStorage().getDDRM() + )); + + assertTrue(MatrixFeatures_DDRM.isIdentical( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(0.5, 1.0, + 1.5, 2.0).getStorage().getDDRM(), + mat.div(2.0).getStorage().getDDRM(), + 1E-3 + )); + } + + @Test + void testMatrixMatrixArithmetic() { + var mat1 = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(1.0, 2.0, + 3.0, 4.0); + + var mat2 = MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(5.0, 6.0, + 7.0, 8.0); + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(-4.0, -4.0, + -4.0, -4.0).getStorage().getDDRM(), + mat1.minus(mat2).getStorage().getDDRM() + )); + + assertTrue(MatrixFeatures_DDRM.isEquals( + MatrixUtils.mat(Nat.N2(), Nat.N2()) + .fill(6.0, 8.0, + 10.0, 12.0).getStorage().getDDRM(), + mat1.plus(mat2).getStorage().getDDRM() + )); + } + + @Test + void testMatrixExponential() { + SimpleMatrix matrix = MatrixUtils.eye(Nat.N2()).getStorage(); + var result = SimpleMatrixUtils.expm(matrix); + + assertTrue(MatrixFeatures_DDRM.isIdentical( + result.getDDRM(), + new SimpleMatrix(2, 2, true, new double[]{Math.E, 0, 0, Math.E}).getDDRM(), + 1E-9 + )); + + matrix = new SimpleMatrix(2, 2, true, new double[]{1, 2, 3, 4}); + result = SimpleMatrixUtils.expm(matrix.scale(0.01)); + + assertTrue(MatrixFeatures_DDRM.isIdentical( + result.getDDRM(), + new SimpleMatrix(2, 2, true, new double[]{1.01035625, 0.02050912, + 0.03076368, 1.04111993}).getDDRM(), + 1E-8 + )); + } +}