[wpimath] Add core State-space classes (#2614)

Co-authored-by: Tyler Veness <calcmogul@gmail.com>
Co-authored-by: Claudius Tewari <cttewari@gmail.com>
Co-authored-by: Declan Freeman-Gleason <declanfreemangleason@gmail.com>
This commit is contained in:
Matt
2020-08-14 23:40:33 -07:00
committed by GitHub
parent e5b84e2f87
commit 3b283ab9aa
84 changed files with 11747 additions and 174 deletions

View File

@@ -10,6 +10,7 @@ package edu.wpi.first.math;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
public final class Drake {
private Drake() {
@@ -46,13 +47,13 @@ public final class Drake {
* @param R Input cost matrix.
* @return Solution of DARE.
*/
@SuppressWarnings("ParameterName")
public static SimpleMatrix discreteAlgebraicRiccatiEquation(
Matrix A,
Matrix B,
Matrix Q,
Matrix R) {
return discreteAlgebraicRiccatiEquation(A.getStorage(), B.getStorage(), Q.getStorage(),
R.getStorage());
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public static <States extends Num, Inputs extends Num> Matrix<States, States>
discreteAlgebraicRiccatiEquation(Matrix<States, States> A,
Matrix<States, Inputs> B,
Matrix<States, States> Q,
Matrix<Inputs, Inputs> R) {
return new Matrix<>(discreteAlgebraicRiccatiEquation(A.getStorage(), B.getStorage(),
Q.getStorage(), R.getStorage()));
}
}

View File

@@ -0,0 +1,215 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.controller;
import java.util.function.BiFunction;
import java.util.function.Function;
import edu.wpi.first.wpilibj.system.NumericalJacobian;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* Constructs a control-affine plant inversion model-based feedforward from
* given model dynamics.
*
* <p>If given the vector valued function as f(x, u) where x is the state
* vector and u is the input vector, the B matrix(continuous input matrix)
* is calculated through a {@link edu.wpi.first.wpilibj.system.NumericalJacobian}.
* In this case f has to be control-affine (of the form f(x) + Bu).
*
* <p>The feedforward is calculated as
* <strong> u_ff = B<sup>+</sup> (rDot - f(x))</strong>, where
* <strong> B<sup>+</sup> </strong> is the pseudoinverse of B.
*
* <p>This feedforward does not account for a dynamic B matrix, B is either
* determined or supplied when the feedforward is created and remains constant.
*
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName", "MemberName", "ClassTypeParameterName"})
public class ControlAffinePlantInversionFeedforward<States extends Num, Inputs extends Num> {
/**
* The current reference state.
*/
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_r;
/**
* The computed feedforward.
*/
private Matrix<Inputs, N1> m_uff;
@SuppressWarnings("MemberName")
private final Matrix<States, Inputs> m_B;
private final Nat<Inputs> m_inputs;
private final double m_dt;
/**
* The model dynamics.
*/
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
/**
* Constructs a feedforward with given model dynamics as a function
* of state and input.
*
* @param states A {@link Nat} representing the number of states.
* @param inputs A {@link Nat} representing the number of inputs.
* @param f A vector-valued function of x, the state, and
* u, the input, that returns the derivative of
* the state vector. HAS to be control-affine
* (of the form f(x) + Bu).
* @param dtSeconds The timestep between calls of calculate().
*/
public ControlAffinePlantInversionFeedforward(
Nat<States> states,
Nat<Inputs> inputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
double dtSeconds) {
this.m_dt = dtSeconds;
this.m_f = f;
this.m_inputs = inputs;
this.m_B = NumericalJacobian.numericalJacobianU(states, inputs,
m_f, new Matrix<>(states, Nat.N1()), new Matrix<>(inputs, Nat.N1()));
m_r = new Matrix<>(states, Nat.N1());
m_uff = new Matrix<>(inputs, Nat.N1());
reset(m_r);
}
/**
* Constructs a feedforward with given model dynamics as a function of state,
* and the plant's B(continuous input matrix) matrix.
*
* @param states A {@link Nat} representing the number of states.
* @param inputs A {@link Nat} representing the number of inputs.
* @param f A vector-valued function of x, the state,
* that returns the derivative of the state vector.
* @param B Continuous input matrix of the plant being controlled.
* @param dtSeconds The timestep between calls of calculate().
*/
public ControlAffinePlantInversionFeedforward(
Nat<States> states,
Nat<Inputs> inputs,
Function<Matrix<States, N1>, Matrix<States, N1>> f,
Matrix<States, Inputs> B,
double dtSeconds) {
this.m_dt = dtSeconds;
this.m_inputs = inputs;
this.m_f = (x, u) -> f.apply(x);
this.m_B = B;
m_r = new Matrix<>(states, Nat.N1());
m_uff = new Matrix<>(inputs, Nat.N1());
reset(m_r);
}
/**
* Returns the previously calculated feedforward as an input vector.
*
* @return The calculated feedforward.
*/
public Matrix<Inputs, N1> getUff() {
return m_uff;
}
/**
* Returns an element of the previously calculated feedforward.
*
* @param row Row of uff.
*
* @return The row of the calculated feedforward.
*/
public double getUff(int row) {
return m_uff.get(row, 0);
}
/**
* Returns the current reference vector r.
*
* @return The current reference vector.
*/
public Matrix<States, N1> getR() {
return m_r;
}
/**
* Returns an element of the current reference vector r.
*
* @param row Row of r.
*
* @return The row of the current reference vector.
*/
public double getR(int row) {
return m_r.get(row, 0);
}
/**
* Resets the feedforward with a specified initial state vector.
*
* @param initialState The initial state vector.
*/
public void reset(Matrix<States, N1> initialState) {
m_r = initialState;
m_uff.fill(0.0);
}
/**
* Resets the feedforward with a zero initial state vector.
*/
public void reset() {
m_r.fill(0.0);
m_uff.fill(0.0);
}
/**
* Calculate the feedforward with only the desired
* future reference. This uses the internally stored "current"
* reference.
*
* <p>If this method is used the initial state of the system is the one
* set using {@link LinearPlantInversionFeedforward#reset(Matrix)}.
* If the initial state is not set it defaults to a zero vector.
*
* @param nextR The reference state of the future timestep (k + dt).
*
* @return The calculated feedforward.
*/
public Matrix<Inputs, N1> calculate(Matrix<States, N1> nextR) {
return calculate(m_r, nextR);
}
/**
* Calculate the feedforward with current and future reference vectors.
*
* @param r The reference state of the current timestep (k).
* @param nextR The reference state of the future timestep (k + dt).
*
* @return The calculated feedforward.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public Matrix<Inputs, N1> calculate(Matrix<States, N1> r, Matrix<States, N1> nextR) {
var rDot = (nextR.minus(r)).div(m_dt);
m_uff = m_B.solve(rDot.minus(m_f.apply(r, new Matrix<>(m_inputs, Nat.N1()))));
m_r = nextR;
return m_uff;
}
}

View File

@@ -0,0 +1,170 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.controller;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.system.LinearSystem;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* Constructs a plant inversion model-based feedforward from a {@link LinearSystem}.
*
* <p>The feedforward is calculated as <strong> u_ff = B<sup>+</sup> (r_k+1 - A r_k) </strong>,
* where <strong> B<sup>+</sup> </strong> is the pseudoinverse of B.
*
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName", "MemberName", "ClassTypeParameterName"})
public class LinearPlantInversionFeedforward<States extends Num, Inputs extends Num,
Outputs extends Num> {
/**
* The current reference state.
*/
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_r;
/**
* The computed feedforward.
*/
private Matrix<Inputs, N1> m_uff;
@SuppressWarnings("MemberName")
private Matrix<States, Inputs> m_B;
@SuppressWarnings("MemberName")
private Matrix<States, States> m_A;
/**
* Constructs a feedforward with the given plant.
*
* @param plant The plant being controlled.
* @param dtSeconds Discretization timestep.
*/
public LinearPlantInversionFeedforward(
LinearSystem<States, Inputs, Outputs> plant,
double dtSeconds
) {
this(plant.getA(), plant.getB(), dtSeconds);
}
/**
* Constructs a feedforward with the given coefficients.
*
* @param A Continuous system matrix of the plant being controlled.
* @param B Continuous input matrix of the plant being controlled.
* @param dtSeconds Discretization timestep.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public LinearPlantInversionFeedforward(Matrix<States, States> A, Matrix<States, Inputs> B,
double dtSeconds) {
var discABPair = Discretization.discretizeAB(A, B, dtSeconds);
this.m_A = discABPair.getFirst();
this.m_B = discABPair.getSecond();
m_r = new Matrix<States, N1>(new SimpleMatrix(B.getNumRows(), 1));
m_uff = new Matrix<Inputs, N1>(new SimpleMatrix(B.getNumCols(), 1));
reset(m_r);
}
/**
* Returns the previously calculated feedforward as an input vector.
*
* @return The calculated feedforward.
*/
public Matrix<Inputs, N1> getUff() {
return m_uff;
}
/**
* Returns an element of the previously calculated feedforward.
*
* @param row Row of uff.
*
* @return The row of the calculated feedforward.
*/
public double getUff(int row) {
return m_uff.get(row, 0);
}
/**
* Returns the current reference vector r.
*
* @return The current reference vector.
*/
public Matrix<States, N1> getR() {
return m_r;
}
/**
* Returns an element of the current reference vector r.
*
* @param row Row of r.
*
* @return The row of the current reference vector.
*/
public double getR(int row) {
return m_r.get(row, 0);
}
/**
* Resets the feedforward with a specified initial state vector.
*
* @param initialState The initial state vector.
*/
public void reset(Matrix<States, N1> initialState) {
m_r = initialState;
m_uff.fill(0.0);
}
/**
* Resets the feedforward with a zero initial state vector.
*/
public void reset() {
m_r.fill(0.0);
m_uff.fill(0.0);
}
/**
* Calculate the feedforward with only the desired
* future reference. This uses the internally stored "current"
* reference.
*
* <p>If this method is used the initial state of the system is the one
* set using {@link LinearPlantInversionFeedforward#reset(Matrix)}.
* If the initial state is not set it defaults to a zero vector.
*
* @param nextR The reference state of the future timestep (k + dt).
*
* @return The calculated feedforward.
*/
public Matrix<Inputs, N1> calculate(Matrix<States, N1> nextR) {
return calculate(m_r, nextR);
}
/**
* Calculate the feedforward with current and future reference vectors.
*
* @param r The reference state of the current timestep (k).
* @param nextR The reference state of the future timestep (k + dt).
*
* @return The calculated feedforward.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public Matrix<Inputs, N1> calculate(Matrix<States, N1> r, Matrix<States, N1> nextR) {
m_uff = new Matrix<>(m_B.solve(nextR.minus(m_A.times(r))));
m_r = nextR;
return m_uff;
}
}

View File

@@ -0,0 +1,254 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.controller;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.math.Drake;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpilibj.system.LinearSystem;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.Vector;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* Contains the controller coefficients and logic for a linear-quadratic
* regulator (LQR).
* LQRs use the control law u = K(r - x).
*
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf.
*/
@SuppressWarnings("ClassTypeParameterName")
public class LinearQuadraticRegulator<States extends Num, Inputs extends Num,
Outputs extends Num> {
/**
* The current reference state.
*/
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_r;
/**
* The computed and capped controller output.
*/
@SuppressWarnings("MemberName")
private Matrix<Inputs, N1> m_u;
// Controller gain.
@SuppressWarnings("MemberName")
private Matrix<Inputs, States> m_K;
/**
* Constructs a controller with the given coefficients and plant. Rho is defaulted to 1.
*
* @param plant The plant being controlled.
* @param qelms The maximum desired error tolerance for each state.
* @param relms The maximum desired control effort for each input.
* @param dtSeconds Discretization timestep.
*/
public LinearQuadraticRegulator(
LinearSystem<States, Inputs, Outputs> plant,
Vector<States> qelms,
Vector<Inputs> relms,
double dtSeconds
) {
this(plant.getA(), plant.getB(), qelms, 1.0, relms, dtSeconds);
}
/**
* Constructs a controller with the given coefficients and plant.
*
* @param plant The plant being controlled.
* @param qelms The maximum desired error tolerance for each state.
* @param rho A weighting factor that balances control effort and state excursion.
* Greater values penalize state excursion more heavily. 1 is a good starting
* value.
* @param relms The maximum desired control effort for each input.
* @param dtSeconds Discretization timestep.
*/
public LinearQuadraticRegulator(
LinearSystem<States, Inputs, Outputs> plant,
Vector<States> qelms,
double rho,
Vector<Inputs> relms,
double dtSeconds
) {
this(plant.getA(), plant.getB(), qelms, rho, relms, dtSeconds);
}
/**
* Constructs a controller with the given coefficients and plant.
*
* @param A Continuous system matrix of the plant being controlled.
* @param B Continuous input matrix of the plant being controlled.
* @param qelms The maximum desired error tolerance for each state.
* @param relms The maximum desired control effort for each input.
* @param dtSeconds Discretization timestep.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public LinearQuadraticRegulator(Matrix<States, States> A, Matrix<States, Inputs> B,
Vector<States> qelms, Vector<Inputs> relms,
double dtSeconds
) {
this(A, B, qelms, 1.0, relms, dtSeconds);
}
/**
* Constructs a controller with the given coefficients and plant.
*
* @param A Continuous system matrix of the plant being controlled.
* @param B Continuous input matrix of the plant being controlled.
* @param qelms The maximum desired error tolerance for each state.
* @param rho A weighting factor that balances control effort and state excursion.
* Greater
* values penalize state excursion more heavily. 1 is a good starting value.
* @param relms The maximum desired control effort for each input.
* @param dtSeconds Discretization timestep.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public LinearQuadraticRegulator(Matrix<States, States> A, Matrix<States, Inputs> B,
Vector<States> qelms, double rho, Vector<Inputs> relms,
double dtSeconds
) {
this(A, B, StateSpaceUtil.makeCostMatrix(qelms).times(rho),
StateSpaceUtil.makeCostMatrix(relms), dtSeconds);
}
/**
* Constructs a controller with the given coefficients and plant.
* @param A Continuous system matrix of the plant being controlled.
* @param B Continuous input matrix of the plant being controlled.
* @param Q The state cost matrix.
* @param R The input cost matrix.
* @param dtSeconds Discretization timestep.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public LinearQuadraticRegulator(Matrix<States, States> A, Matrix<States, Inputs> B,
Matrix<States, States> Q, Matrix<Inputs, Inputs> R,
double dtSeconds
) {
var discABPair = Discretization.discretizeAB(A, B, dtSeconds);
var discA = discABPair.getFirst();
var discB = discABPair.getSecond();
var S = Drake.discreteAlgebraicRiccatiEquation(discA, discB, Q, R);
var temp = discB.transpose().times(S).times(discB).plus(R);
m_K = temp.solve(discB.transpose().times(S).times(discA));
m_r = new Matrix<>(new SimpleMatrix(B.getNumRows(), 1));
m_u = new Matrix<>(new SimpleMatrix(B.getNumCols(), 1));
reset();
}
/**
* Constructs a controller with the given coefficients and plant.
*
* @param states The number of states.
* @param inputs The number of inputs.
* @param k The gain matrix.
*/
@SuppressWarnings("ParameterName")
public LinearQuadraticRegulator(
Nat<States> states, Nat<Inputs> inputs,
Matrix<Inputs, States> k
) {
m_K = k;
m_r = new Matrix<>(states, Nat.N1());
m_u = new Matrix<>(inputs, Nat.N1());
reset();
}
/**
* Returns the control input vector u.
*
* @return The control input.
*/
public Matrix<Inputs, N1> getU() {
return m_u;
}
/**
* Returns an element of the control input vector u.
*
* @param row Row of u.
*
* @return The row of the control input vector.
*/
public double getU(int row) {
return m_u.get(row, 0);
}
/**
* Returns the reference vector r.
*
* @return The reference vector.
*/
public Matrix<States, N1> getR() {
return m_r;
}
/**
* Returns an element of the reference vector r.
*
* @param row Row of r.
*
* @return The row of the reference vector.
*/
public double getR(int row) {
return m_r.get(row, 0);
}
/**
* Returns the controller matrix K.
*
* @return the controller matrix K.
*/
public Matrix<Inputs, States> getK() {
return m_K;
}
/**
* Resets the controller.
*/
public void reset() {
m_r.fill(0.0);
m_u.fill(0.0);
}
/**
* Returns the next output of the controller.
*
* @param x The current state x.
*/
@SuppressWarnings("ParameterName")
public Matrix<Inputs, N1> calculate(Matrix<States, N1> x) {
m_u = m_K.times(m_r.minus(x));
return m_u;
}
/**
* Returns the next output of the controller.
*
* @param x The current state x.
* @param nextR the next reference vector r.
*/
@SuppressWarnings("ParameterName")
public Matrix<Inputs, N1> calculate(Matrix<States, N1> x, Matrix<States, N1> nextR) {
m_r = nextR;
return calculate(x);
}
}

View File

@@ -0,0 +1,286 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.function.BiFunction;
import edu.wpi.first.math.Drake;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpilibj.system.NumericalJacobian;
import edu.wpi.first.wpilibj.system.RungeKutta;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* Kalman filters combine predictions from a model and measurements to give an estimate of the true
* system state. This is useful because many states cannot be measured directly as a result of
* sensor noise, or because the state is "hidden".
*
* <p>The Extended Kalman filter is just like the {@link KalmanFilter Kalman filter}, but we make a
* linear approximation of nonlinear dynamics and/or nonlinear measurement models. This means that
* the EKF works with nonlinear systems.
*/
@SuppressWarnings("ClassTypeParameterName")
public class ExtendedKalmanFilter<States extends Num, Inputs extends Num, Outputs extends Num>
implements KalmanTypeFilter<States, Inputs, Outputs> {
private final Nat<States> m_states;
private final Nat<Outputs> m_outputs;
@SuppressWarnings("MemberName")
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
@SuppressWarnings("MemberName")
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
private final Matrix<States, States> m_contQ;
private Matrix<Outputs, Outputs> m_discR;
private final Matrix<States, States> m_initP;
private final Matrix<Outputs, Outputs> m_contR;
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_xHat;
@SuppressWarnings("MemberName")
private Matrix<States, States> m_P;
/**
* Constructs an extended Kalman filter.
*
* @param states a Nat representing the number of states.
* @param inputs a Nat representing the number of inputs.
* @param outputs a Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns
* the derivative of the state vector.
* @param h A vector-valued function of x and u that returns
* the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param dtSeconds Nominal discretization timestep.
*/
@SuppressWarnings("ParameterName")
public ExtendedKalmanFilter(
Nat<States> states,
Nat<Inputs> inputs,
Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double dtSeconds
) {
m_states = states;
m_outputs = outputs;
m_f = f;
m_h = h;
reset();
m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
final var contA = NumericalJacobian
.numericalJacobianX(states, states, f, m_xHat, new Matrix<>(inputs, Nat.N1()));
final var C = NumericalJacobian
.numericalJacobianX(outputs, states, h, m_xHat, new Matrix<>(inputs, Nat.N1()));
final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
final var discA = discPair.getFirst();
final var discQ = discPair.getSecond();
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
// IsStabilizable(A^T, C^T) will tell us if the system is observable.
boolean isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), C.transpose());
if (isObservable && outputs.getNum() <= states.getNum()) {
m_initP = Drake.discreteAlgebraicRiccatiEquation(
discA.transpose(), C.transpose(), discQ, m_discR) ;
} else {
m_initP = new Matrix<>(states, states);
}
m_P = m_initP;
}
/**
* Returns the error covariance matrix P.
*
* @return the error covariance matrix P.
*/
@Override
public Matrix<States, States> getP() {
return m_P;
}
/**
* Returns an element of the error covariance matrix P.
*
* @param row Row of P.
* @param col Column of P.
* @return the value of the error covariance matrix P at (i, j).
*/
@Override
public double getP(int row, int col) {
return m_P.get(row, col);
}
/**
* Sets the entire error covariance matrix P.
*
* @param newP The new value of P to use.
*/
@Override
public void setP(Matrix<States, States> newP) {
m_P = newP;
}
/**
* Returns the state estimate x-hat.
*
* @return the state estimate x-hat.
*/
@Override
public Matrix<States, N1> getXhat() {
return m_xHat;
}
/**
* Returns an element of the state estimate x-hat.
*
* @param row Row of x-hat.
* @return the value of the state estimate x-hat at i.
*/
@Override
public double getXhat(int row) {
return m_xHat.get(row, 0);
}
/**
* Set initial state estimate x-hat.
*
* @param xHat The state estimate x-hat.
*/
@SuppressWarnings("ParameterName")
@Override
public void setXhat(Matrix<States, N1> xHat) {
m_xHat = xHat;
}
/**
* Set an element of the initial state estimate x-hat.
*
* @param row Row of x-hat.
* @param value Value for element of x-hat.
*/
@Override
public void setXhat(int row, double value) {
m_xHat.set(row, 0, value);
}
@Override
public void reset() {
m_xHat = new Matrix<>(m_states, Nat.N1());
m_P = m_initP;
}
/**
* Project the model into the future with a new control input u.
*
* @param u New control input from controller.
* @param dtSeconds Timestep for prediction.
*/
@SuppressWarnings("ParameterName")
@Override
public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
predict(u, m_f, dtSeconds);
}
/**
* Project the model into the future with a new control input u.
*
* @param u New control input from controller.
* @param f The function used to linearlize the model.
* @param dtSeconds Timestep for prediction.
*/
@SuppressWarnings("ParameterName")
public void predict(
Matrix<Inputs, N1> u, BiFunction<Matrix<States, N1>,
Matrix<Inputs, N1>, Matrix<States, N1>> f,
double dtSeconds
) {
// Find continuous A
final var contA = NumericalJacobian.numericalJacobianX(m_states, m_states, f, m_xHat, u);
// Find discrete A and Q
final var discPair = Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds);
final var discA = discPair.getFirst();
final var discQ = discPair.getSecond();
m_xHat = RungeKutta.rungeKutta(f, m_xHat, u, dtSeconds);
m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* @param u Same control input used in the predict step.
* @param y Measurement vector.
*/
@SuppressWarnings("ParameterName")
@Override
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
correct(m_outputs, u, y, m_h, m_discR);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* <p>This is useful for when the measurements available during a timestep's
* Correct() call vary. The h(x, u) passed to the constructor is used if one is
* not provided (the two-argument version of this function).
*
* @param <Rows> Number of rows in the result of f(x, u).
* @param rows Number of rows in the result of f(x, u).
* @param u Same control input used in the predict step.
* @param y Measurement vector.
* @param h A vector-valued function of x and u that returns the measurement
* vector.
* @param R Discrete measurement noise covariance matrix.
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public <Rows extends Num> void correct(
Nat<Rows> rows, Matrix<Inputs, N1> u,
Matrix<Rows, N1> y,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Rows, N1>> h,
Matrix<Rows, Rows> R
) {
final var C = NumericalJacobian.numericalJacobianX(rows, m_states, h, m_xHat, u);
final var S = C.times(m_P).times(C.transpose()).plus(R);
// We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
// efficiently.
//
// K = PC^T S^-1
// KS = PC^T
// (KS)^T = (PC^T)^T
// S^T K^T = CP^T
//
// The solution of Ax = b can be found via x = A.solve(b).
//
// K^T = S^T.solve(CP^T)
// K = (S^T.solve(CP^T))^T
//
// Now we have the Kalman gain
final Matrix<States, Rows> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
m_xHat = m_xHat.plus(K.times(y.minus(h.apply(m_xHat, u))));
m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
}
}

View File

@@ -0,0 +1,290 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.math.Drake;
import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpilibj.system.LinearSystem;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* A Kalman filter combines predictions from a model and measurements to give an estimate of the
* true system state. This is useful because many states cannot be measured directly as a result of
* sensor noise, or because the state is "hidden".
*
* <p>Kalman filters use a K gain matrix to determine whether to trust the model or measurements
* more. Kalman filter theory uses statistics to compute an optimal K gain which minimizes the sum
* of squares error in the state estimate. This K gain is used to correct the state estimate by
* some amount of the difference between the actual measurements and the measurements predicted by
* the model.
*
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/controls-engineering-in-frc.pdf chapter 9.
*/
@SuppressWarnings("ClassTypeParameterName")
public class KalmanFilter<States extends Num, Inputs extends Num,
Outputs extends Num> implements KalmanTypeFilter<States, Inputs, Outputs> {
private final Nat<States> m_states;
private final LinearSystem<States, Inputs, Outputs> m_plant;
private final Matrix<States, N1> m_stateStdDevs;
private final Matrix<Outputs, N1> m_measurementStdDevs;
/**
* Error covariance matrix.
*/
@SuppressWarnings("MemberName")
private Matrix<States, States> m_P;
/**
* Continuous process noise covariance matrix.
*/
private final Matrix<States, States> m_contQ;
/**
* Continuous measurement noise covariance matrix.
*/
private final Matrix<Outputs, Outputs> m_contR;
/**
* Discrete measurement noise covariance matrix.
*/
private Matrix<Outputs, Outputs> m_discR;
/**
* The current state estimate x-hat.
*/
@SuppressWarnings("MemberName")
private Matrix<States, N1> m_xHat;
/**
* Constructs a state-space observer with the given plant.
*
* @param states A Nat representing the states of the system.
* @param outputs A Nat representing the outputs of the system.
* @param plant The plant used for the prediction step.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param dtSeconds Nominal discretization timestep.
*/
public KalmanFilter(
Nat<States> states, Nat<Outputs> outputs,
LinearSystem<States, Inputs, Outputs> plant,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double dtSeconds
) {
this.m_states = states;
this.m_plant = plant;
this.m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
this.m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
this.m_stateStdDevs = stateStdDevs;
this.m_measurementStdDevs = measurementStdDevs;
var pair = Discretization.discretizeAQTaylor(plant.getA(), m_contQ, dtSeconds);
var discA = pair.getFirst();
var discQ = pair.getSecond();
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
// IsStabilizable(A^T, C^T) will tell us if the system is observable.
var isObservable = StateSpaceUtil.isStabilizable(discA.transpose(), plant.getC().transpose());
if (isObservable) {
if (outputs.getNum() <= states.getNum()) {
m_P = Drake.discreteAlgebraicRiccatiEquation(
discA.transpose(), plant.getC().transpose(), discQ, m_discR);
} else {
m_P = new Matrix<>(new SimpleMatrix(states.getNum(), states.getNum()));
}
} else {
MathSharedStore.reportError("The system passed to the Kalman Filter is not observable!",
Thread.currentThread().getStackTrace());
throw new IllegalArgumentException(
"The system passed to the Kalman Filter is not observable!");
}
reset();
}
@Override
public void reset() {
m_xHat = new Matrix<>(m_states, Nat.N1());
}
/**
* Returns the error covariance matrix P.
*
* @return the error covariance matrix P.
*/
@Override
public Matrix<States, States> getP() {
return m_P;
}
/**
* Returns an element of the error covariance matrix P.
*
* @param row Row of P.
* @param col Column of P.
* @return the element (i, j) of the error covariance matrix P.
*/
@Override
public double getP(int row, int col) {
return m_P.get(row, col);
}
/**
* Sets the entire error covariance matrix P.
*
* @param newP The new value of P to use.
*/
@Override
public void setP(Matrix<States, States> newP) {
m_P = newP;
}
/**
* Set initial state estimate x-hat.
*
* @param xhat The state estimate x-hat.
*/
@Override
public void setXhat(Matrix<States, N1> xhat) {
this.m_xHat = xhat;
}
/**
* Set an element of the initial state estimate x-hat.
*
* @param row Row of x-hat.
* @param value Value for element of x-hat.
*/
@Override
public void setXhat(int row, double value) {
m_xHat.set(row, 0, value);
}
/**
* Returns the state estimate x-hat.
*
* @return The state estimate x-hat.
*/
@Override
public Matrix<States, N1> getXhat() {
return m_xHat;
}
/**
* Returns an element of the state estimate x-hat.
*
* @param row Row of x-hat.
* @return the state estimate x-hat at i.
*/
@Override
public double getXhat(int row) {
return m_xHat.get(row, 0);
}
/**
* Returns the state standard deviations used to make Q.
*/
public Matrix<States, N1> getStateStdDevs() {
return m_stateStdDevs;
}
/**
* Returns the measurement standard deviations used to make R.
*/
public Matrix<Outputs, N1> getMeasurementStdDevs() {
return m_measurementStdDevs;
}
/**
* Project the model into the future with a new control input u.
*
* @param u New control input from controller.
* @param dtSeconds Timestep for prediction.
*/
@SuppressWarnings("ParameterName")
@Override
public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
this.m_xHat = m_plant.calculateX(m_xHat, u, dtSeconds);
var pair = Discretization.discretizeAQTaylor(m_plant.getA(), m_contQ, dtSeconds);
var discA = pair.getFirst();
var discQ = pair.getSecond();
m_P = discA.times(m_P).times(discA.transpose()).plus(discQ);
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* @param u Same control input used in the last predict step.
* @param y Measurement vector.
*/
@SuppressWarnings("ParameterName")
@Override
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
correct(u, y, m_plant.getC(), m_plant.getD(), m_discR);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* <p>This is useful for when the measurements available during a timestep's
* Correct() call vary. The C matrix passed to the constructor is used if one
* is not provided (the two-argument version of this function).
*
* @param <R> Number of rows in the result of f(x, u).
* @param u Same control input used in the predict step.
* @param y Measurement vector.
* @param C Output matrix.
* @param r Discrete measurement noise covariance matrix.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public <R extends Num> void correct(
Matrix<Inputs, N1> u,
Matrix<R, N1> y,
Matrix<R, States> C,
Matrix<R, Inputs> D,
Matrix<R, R> r) {
var S = C.times(m_P).times(C.transpose()).plus(r);
// We want to put K = PC^T S^-1 into Ax = b form so we can solve it more
// efficiently.
//
// K = PC^T S^-1
// KS = PC^T
// (KS)^T = (PC^T)^T
// S^T K^T = CP^T
//
// The solution of Ax = b can be found via x = A.solve(b).
//
// K^T = S^T.solve(CP^T)
// K = (S^T.solve(CP^T))^T
Matrix<States, R> K = S.transpose().solve(C.times(m_P.transpose())).transpose();
m_xHat = m_xHat.plus(K.times(y.minus(C.times(m_xHat).plus(D.times(u)))));
m_P = Matrix.eye(m_states).minus(K.times(C)).times(m_P);
}
}

View File

@@ -0,0 +1,35 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
@SuppressWarnings({"ParameterName", "InterfaceTypeParameterName"})
interface KalmanTypeFilter<States extends Num, Inputs extends Num, Outputs extends Num> {
Matrix<States, States> getP();
double getP(int i, int j);
void setP(Matrix<States, States> newP);
Matrix<States, N1> getXhat();
double getXhat(int i);
void setXhat(Matrix<States, N1> xHat);
void setXhat(int i, double value);
void reset();
void predict(Matrix<Inputs, N1> u, double dtSeconds);
void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y);
}

View File

@@ -0,0 +1,168 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* Generates sigma points and weights according to Van der Merwe's 2004
* dissertation[1] for the UnscentedKalmanFilter class.
*
* <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the
* version seen in most publications. Unless you know better, this should be
* your default choice.
*
* <p>States is the dimensionality of the state. 2*States+1 weights will be
* generated.
*
* <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic
* Inference in Dynamic State-Space Models" (Doctoral dissertation)
*/
public class MerweScaledSigmaPoints<S extends Num> {
private final double m_alpha;
private final int m_kappa;
private final Nat<S> m_states;
private Matrix<?, N1> m_wm;
private Matrix<?, N1> m_wc;
/**
* Constructs a generator for Van der Merwe scaled sigma points.
*
* @param states an instance of Num that represents the number of states.
* @param alpha Determines the spread of the sigma points around the mean.
* Usually a small positive value (1e-3).
* @param beta Incorporates prior knowledge of the distribution of the mean.
* For Gaussian distributions, beta = 2 is optimal.
* @param kappa Secondary scaling parameter usually set to 0 or 3 - States.
*/
public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) {
this.m_states = states;
this.m_alpha = alpha;
this.m_kappa = kappa;
computeWeights(beta);
}
/**
* Constructs a generator for Van der Merwe scaled sigma points with default values for alpha,
* beta, and kappa.
*
* @param states an instance of Num that represents the number of states.
*/
public MerweScaledSigmaPoints(Nat<S> states) {
this(states, 1e-3, 2, 3 - states.getNum());
}
/**
* Returns number of sigma points for each variable in the state x.
*
* @return The number of sigma points for each variable in the state x.
*/
public int getNumSigmas() {
return 2 * m_states.getNum() + 1;
}
/**
* Computes the sigma points for an unscented Kalman filter given the mean
* (x) and covariance(P) of the filter.
*
* @param x An array of the means.
* @param P Covariance of the filter.
* @return Two dimensional array of sigma points. Each column contains all of
* the sigmas for one dimension in the problem space. Ordered by
* Xi_0, Xi_{1..n}, Xi_{n+1..2n}.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public Matrix<S, ?> sigmaPoints(
Matrix<S, N1> x,
Matrix<S, S> P) {
double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
var intermediate = P.times(lambda + m_states.getNum());
var U = intermediate.lltDecompose(true); // Lower triangular
// 2 * states + 1 by states
Matrix<S, ?> sigmas = new Matrix<>(
new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
sigmas.setColumn(0, x);
for (int k = 0; k < m_states.getNum(); k++) {
var xPlusU = x.plus(U.extractColumnVector(k));
var xMinusU = x.minus(U.extractColumnVector(k));
sigmas.setColumn(k + 1, xPlusU);
sigmas.setColumn(m_states.getNum() + k + 1, xMinusU);
}
return new Matrix<>(sigmas);
}
/**
* Computes the weights for the scaled unscented Kalman filter.
*
* @param beta Incorporates prior knowledge of the distribution of the mean.
*/
@SuppressWarnings("LocalVariableName")
private void computeWeights(double beta) {
double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum();
double c = 0.5 / (m_states.getNum() + lambda);
Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1));
wM.fill(c);
wC.fill(c);
wM.set(0, 0, lambda / (m_states.getNum() + lambda));
wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta));
this.m_wm = wM;
this.m_wc = wC;
}
/**
* Returns the weight for each sigma point for the mean.
*
* @return the weight for each sigma point for the mean.
*/
public Matrix<?, N1> getWm() {
return m_wm;
}
/**
* Returns an element of the weight for each sigma point for the mean.
*
* @param element Element of vector to return.
* @return the element i's weight for the mean.
*/
public double getWm(int element) {
return m_wm.get(element, 0);
}
/**
* Returns the weight for each sigma point for the covariance.
*
* @return the weight for each sigma point for the covariance.
*/
public Matrix<?, N1> getWc() {
return m_wc;
}
/**
* Returns an element of the weight for each sigma point for the covariance.
*
* @param element Element of vector to return.
* @return The element I's weight for the covariance.
*/
public double getWc(int element) {
return m_wc.get(element, 0);
}
}

View File

@@ -0,0 +1,314 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.estimator;
import java.util.function.BiFunction;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpilibj.system.NumericalJacobian;
import edu.wpi.first.wpilibj.system.RungeKutta;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.Pair;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* A Kalman filter combines predictions from a model and measurements to give an estimate of the
* true ystem state. This is useful because many states cannot be measured directly as a result of
* sensor noise, or because the state is "hidden".
*
* <p>The Unscented Kalman filter is similar to the {@link KalmanFilter Kalman filter}, except that
* it propagates carefully chosen points called sigma points through the non-linear model to obtain
* an estimate of the true covariance (as opposed to a linearized version of it). This means that
* the UKF works with nonlinear systems.
*/
@SuppressWarnings({"MemberName", "ClassTypeParameterName"})
public class UnscentedKalmanFilter<States extends Num, Inputs extends Num,
Outputs extends Num> implements KalmanTypeFilter<States, Inputs, Outputs> {
private final Nat<States> m_states;
private final Nat<Outputs> m_outputs;
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> m_f;
private final BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> m_h;
private Matrix<States, N1> m_xHat;
private Matrix<States, States> m_P;
private final Matrix<States, States> m_contQ;
private final Matrix<Outputs, Outputs> m_contR;
private Matrix<Outputs, Outputs> m_discR;
private Matrix<States, ?> m_sigmasF;
private final MerweScaledSigmaPoints<States> m_pts;
/**
* Constructs an Unscented Kalman Filter.
*
* @param states A Nat representing the number of states.
* @param outputs A Nat representing the number of outputs.
* @param f A vector-valued function of x and u that returns
* the derivative of the state vector.
* @param h A vector-valued function of x and u that returns
* the measurement vector.
* @param stateStdDevs Standard deviations of model states.
* @param measurementStdDevs Standard deviations of measurements.
* @param nominalDtSeconds Nominal discretization timestep.
*/
@SuppressWarnings("ParameterName")
public UnscentedKalmanFilter(Nat<States> states, Nat<Outputs> outputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>,
Matrix<States, N1>> f,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>,
Matrix<Outputs, N1>> h,
Matrix<States, N1> stateStdDevs,
Matrix<Outputs, N1> measurementStdDevs,
double nominalDtSeconds) {
this.m_states = states;
this.m_outputs = outputs;
m_f = f;
m_h = h;
m_contQ = StateSpaceUtil.makeCovarianceMatrix(states, stateStdDevs);
m_contR = StateSpaceUtil.makeCovarianceMatrix(outputs, measurementStdDevs);
m_discR = Discretization.discretizeR(m_contR, nominalDtSeconds);
m_pts = new MerweScaledSigmaPoints<>(states);
reset();
}
@SuppressWarnings({"ParameterName", "LocalVariableName", "PMD.CyclomaticComplexity"})
static <S extends Num, C extends Num>
Pair<Matrix<C, N1>, Matrix<C, C>> unscentedTransform(
Nat<S> s, Nat<C> dim, Matrix<C, ?> sigmas, Matrix<?, N1> Wm, Matrix<?, N1> Wc
) {
if (sigmas.getNumRows() != dim.getNum() || sigmas.getNumCols() != 2 * s.getNum() + 1) {
throw new IllegalArgumentException("Sigmas must be covDim by 2 * states + 1! Got "
+ sigmas.getNumRows() + " by " + sigmas.getNumCols());
}
if (Wm.getNumRows() != 2 * s.getNum() + 1 || Wm.getNumCols() != 1 ) {
throw new IllegalArgumentException("Wm must be 2 * states + 1 by 1! Got "
+ Wm.getNumRows() + " by " + Wm.getNumCols());
}
if (Wc.getNumRows() != 2 * s.getNum() + 1 || Wc.getNumCols() != 1) {
throw new IllegalArgumentException("Wc must be 2 * states + 1 by 1! Got "
+ Wc.getNumRows() + " by " + Wc.getNumCols());
}
// New mean is just the sum of the sigmas * weight
// dot = \Sigma^n_1 (W[k]*Xi[k])
Matrix<C, N1> x = sigmas.times(Matrix.changeBoundsUnchecked(Wm));
// New covariance is the sum of the outer product of the residuals times the
// weights
Matrix<C, ?> y = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + 1));
for (int i = 0; i < 2 * s.getNum() + 1; i++) {
y.setColumn(i, sigmas.extractColumnVector(i).minus(x));
}
Matrix<C, C> P = y.times(Matrix.changeBoundsUnchecked(Wc.diag()))
.times(Matrix.changeBoundsUnchecked(y.transpose()));
return new Pair<>(x, P);
}
/**
* Returns the error covariance matrix P.
*
* @return the error covariance matrix P.
*/
@Override
public Matrix<States, States> getP() {
return m_P;
}
/**
* Returns an element of the error covariance matrix P.
*
* @param row Row of P.
* @param col Column of P.
* @return the value of the error covariance matrix P at (i, j).
*/
@Override
public double getP(int row, int col) {
return m_P.get(row, col);
}
/**
* Sets the entire error covariance matrix P.
*
* @param newP The new value of P to use.
*/
@Override
public void setP(Matrix<States, States> newP) {
m_P = newP;
}
/**
* Returns the state estimate x-hat.
*
* @return the state estimate x-hat.
*/
@Override
public Matrix<States, N1> getXhat() {
return m_xHat;
}
/**
* Returns an element of the state estimate x-hat.
*
* @param row Row of x-hat.
* @return the value of the state estimate x-hat at i.
*/
@Override
public double getXhat(int row) {
return m_xHat.get(row, 0);
}
/**
* Set initial state estimate x-hat.
*
* @param xHat The state estimate x-hat.
*/
@SuppressWarnings("ParameterName")
@Override
public void setXhat(Matrix<States, N1> xHat) {
m_xHat = xHat;
}
/**
* Set an element of the initial state estimate x-hat.
*
* @param row Row of x-hat.
* @param value Value for element of x-hat.
*/
@Override
public void setXhat(int row, double value) {
m_xHat.set(row, 0, value);
}
/**
* Resets the observer.
*/
@Override
public void reset() {
m_xHat = new Matrix<>(m_states, Nat.N1());
m_P = new Matrix<>(m_states, m_states);
m_sigmasF = new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
}
/**
* Project the model into the future with a new control input u.
*
* @param u New control input from controller.
* @param dtSeconds Timestep for prediction.
*/
@SuppressWarnings({"LocalVariableName", "ParameterName"})
@Override
public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
// Discretize Q before projecting mean and covariance forward
Matrix<States, States> contA =
NumericalJacobian.numericalJacobianX(m_states, m_states, m_f, m_xHat, u);
var discQ =
Discretization.discretizeAQTaylor(contA, m_contQ, dtSeconds).getSecond();
var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
Matrix<States, N1> x = sigmas.extractColumnVector(i);
m_sigmasF.setColumn(i, RungeKutta.rungeKutta(m_f, x, u, dtSeconds));
}
var ret = unscentedTransform(m_states, m_states,
m_sigmasF, m_pts.getWm(), m_pts.getWc());
m_xHat = ret.getFirst();
m_P = ret.getSecond().plus(discQ);
m_discR = Discretization.discretizeR(m_contR, dtSeconds);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* @param u Same control input used in the predict step.
* @param y Measurement vector.
*/
@SuppressWarnings("ParameterName")
@Override
public void correct(Matrix<Inputs, N1> u, Matrix<Outputs, N1> y) {
correct(m_outputs, u, y, m_h, m_discR);
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* <p>This is useful for when the measurements available during a timestep's
* Correct() call vary. The h(x, u) passed to the constructor is used if one
* is not provided (the two-argument version of this function).
*
* @param u Same control input used in the predict step.
* @param y Measurement vector.
* @param h A vector-valued function of x and u that returns
* the measurement vector.
* @param R Measurement noise covariance matrix.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public <R extends Num> void correct(
Nat<R> rows, Matrix<Inputs, N1> u,
Matrix<R, N1> y,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<R, N1>> h,
Matrix<R, R> R) {
// Transform sigma points into measurement space
Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(
rows.getNum(), 2 * m_states.getNum() + 1));
var sigmas = m_pts.sigmaPoints(m_xHat, m_P);
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
Matrix<R, N1> hRet = h.apply(
sigmas.extractColumnVector(i),
u
);
sigmasH.setColumn(i, hRet);
}
// Mean and covariance of prediction passed through unscented transform
var transRet = unscentedTransform(m_states, rows, sigmasH, m_pts.getWm(), m_pts.getWc());
var yHat = transRet.getFirst();
var Py = transRet.getSecond().plus(R);
// Compute cross covariance of the state and the measurements
Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
var temp =
m_sigmasF.extractColumnVector(i).minus(m_xHat)
.times(sigmasH.extractColumnVector(i).minus(yHat).transpose());
Pxy = Pxy.plus(temp.times(m_pts.getWc(i)));
}
// K = P_{xy} Py^-1
// K^T = P_y^T^-1 P_{xy}^T
// P_y^T K^T = P_{xy}^T
// K^T = P_y^T.solve(P_{xy}^T)
// K = (P_y^T.solve(P_{xy}^T)^T
Matrix<States, R> K = new Matrix<>(
Py.transpose().solve(Pxy.transpose()).transpose()
);
m_xHat = m_xHat.plus(K.times(y.minus(yHat)));
m_P = m_P.minus(K.times(Py).times(K.transpose()));
}
}

View File

@@ -0,0 +1,179 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.math;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.Pair;
@SuppressWarnings({"PMD.TooManyMethods", "ParameterName", "MethodTypeParameterName"})
public final class Discretization {
private Discretization() {
// Utility class
}
/**
* Discretizes the given continuous A matrix.
*
* @param <States> Num representing the number of states.
* @param contA Continuous system matrix.
* @param dtSeconds Discretization timestep.
* @return the discrete matrix system.
*/
public static <States extends Num> Matrix<States, States> discretizeA(
Matrix<States, States> contA, double dtSeconds) {
return contA.times(dtSeconds).exp();
}
/**
* Discretizes the given continuous A and B matrices.
*
* <p>Rather than solving a (States + Inputs) x (States + Inputs) matrix
* exponential like in DiscretizeAB(), we take advantage of the structure of the
* block matrix of A and B.
*
* <p>1) The exponential of A*t, which is only N x N, is relatively cheap.
* 2) The upper-right quarter of the (States + Inputs) x (States + Inputs)
* matrix, which we can approximate using a taylor series to several terms
* and still be substantially cheaper than taking the big exponential.
*
* @param states Nat representing the states of the system.
* @param contA Continuous system matrix.
* @param contB Continuous input matrix.
* @param dtseconds Discretization timestep.
*/
public static <States extends Num, Inputs extends Num> Pair<Matrix<States, States>,
Matrix<States, Inputs>>
discretizeABTaylor(Nat<States> states,
Matrix<States, States> contA,
Matrix<States, Inputs> contB,
double dtseconds) {
Matrix<States, States> lastTerm = Matrix.eye(states);
double lastCoeff = dtseconds;
var phi12 = lastTerm.times(lastCoeff);
// i = 6 i.e. 5th order should be enough precision
for (int i = 2; i < 6; ++i) {
lastTerm = contA.times(lastTerm);
lastCoeff *= dtseconds / ((double) i);
phi12 = phi12.plus(lastTerm.times(lastCoeff));
}
var discB = phi12.times(contB);
var discA = discretizeA(contA, dtseconds);
return Pair.of(discA, discB);
}
/**
* Discretizes the given continuous A and Q matrices.
*
* <p>Rather than solving a 2N x 2N matrix exponential like in DiscretizeQ() (which
* is expensive), we take advantage of the structure of the block matrix of A
* and Q.
*
* <p>The exponential of A*t, which is only N x N, is relatively cheap.
* 2) The upper-right quarter of the 2N x 2N matrix, which we can approximate
* using a taylor series to several terms and still be substantially cheaper
* than taking the big exponential.
*
* @param <States> Nat representing the number of states.
* @param contA Continuous system matrix.
* @param contQ Continuous process noise covariance matrix.
* @param dtSeconds Discretization timestep.
* @return a pair representing the discrete system matrix and process noise covariance matrix.
*/
@SuppressWarnings("LocalVariableName")
public static <States extends Num> Pair<Matrix<States, States>,
Matrix<States, States>> discretizeAQTaylor(Matrix<States, States> contA,
Matrix<States, States> contQ,
double dtSeconds) {
Matrix<States, States> Q = (contQ.plus(contQ.transpose())).div(2.0);
Matrix<States, States> lastTerm = Q.copy();
double lastCoeff = dtSeconds;
// A^T^n
Matrix<States, States> Atn = contA.transpose();
Matrix<States, States> phi12 = lastTerm.times(lastCoeff);
// i = 6 i.e. 6th order should be enough precision
for (int i = 2; i < 6; ++i) {
lastTerm = contA.times(-1).times(lastTerm).plus(Q.times(Atn));
lastCoeff *= dtSeconds / ((double) i);
phi12 = phi12.plus(lastTerm.times(lastCoeff));
Atn = Atn.times(contA.transpose());
}
var discA = discretizeA(contA, dtSeconds);
Q = discA.times(phi12);
// Make Q symmetric if it isn't already
var discQ = Q.plus(Q.transpose()).div(2.0);
return new Pair<>(discA, discQ);
}
/**
* Returns a discretized version of the provided continuous measurement noise
* covariance matrix. Note that dt=0.0 divides R by zero.
*
* @param <O> Nat representing the number of outputs.
* @param R Continuous measurement noise covariance matrix.
* @param dtSeconds Discretization timestep.
* @return Discretized version of the provided continuous measurement noise covariance matrix.
*/
public static <O extends Num> Matrix<O, O> discretizeR(Matrix<O, O> R, double dtSeconds) {
return R.div(dtSeconds);
}
/**
* Discretizes the given continuous A and B matrices.
*
* @param <States> Nat representing the states of the system.
* @param <Inputs> Nat representing the inputs to the system.
* @param contA Continuous system matrix.
* @param contB Continuous input matrix.
* @param dtSeconds Discretization timestep.
* @return a Pair representing discA and diskB.
*/
@SuppressWarnings("LocalVariableName")
public static <States extends Num, Inputs extends Num> Pair<Matrix<States, States>,
Matrix<States, Inputs>> discretizeAB(
Matrix<States, States> contA,
Matrix<States, Inputs> contB,
double dtSeconds) {
var scaledA = contA.times(dtSeconds);
var scaledB = contB.times(dtSeconds);
var contSize = contB.getNumRows() + contB.getNumCols();
var Mcont = new Matrix<>(new SimpleMatrix(contSize, contSize));
Mcont.assignBlock(0, 0, scaledA);
Mcont.assignBlock(0, scaledA.getNumCols(), scaledB);
var Mdisc = Mcont.exp();
var discA = new Matrix<States, States>(new SimpleMatrix(contB.getNumRows(),
contB.getNumRows()));
var discB = new Matrix<States, Inputs>(new SimpleMatrix(contB.getNumRows(),
contB.getNumCols()));
discA.extractFrom(0, 0, Mdisc);
discB.extractFrom(0, contB.getNumRows(), Mdisc);
return new Pair<>(discA, discB);
}
}

View File

@@ -0,0 +1,180 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.math;
import java.util.Random;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpilibj.geometry.Pose2d;
import edu.wpi.first.wpiutil.math.MathUtil;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.VecBuilder;
import edu.wpi.first.wpiutil.math.WPIMathJNI;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N3;
@SuppressWarnings({"PMD.TooManyMethods", "ParameterName"})
public final class StateSpaceUtil {
private StateSpaceUtil() {
// Utility class
}
/**
* Creates a covariance matrix from the given vector for use with Kalman
* filters.
*
* <p>Each element is squared and placed on the covariance matrix diagonal.
*
* @param <States> Num representing the states of the system.
* @param states A Nat representing the states of the system.
* @param stdDevs For a Q matrix, its elements are the standard deviations of
* each state from how the model behaves. For an R matrix, its
* elements are the standard deviations for each output
* measurement.
* @return Process noise or measurement noise covariance matrix.
*/
@SuppressWarnings("MethodTypeParameterName")
public static <States extends Num> Matrix<States, States> makeCovarianceMatrix(
Nat<States> states, Matrix<States, N1> stdDevs
) {
var result = new Matrix<>(states, states);
for (int i = 0; i < states.getNum(); i++) {
result.set(i, i, Math.pow(stdDevs.get(i, 0), 2));
}
return result;
}
/**
* Creates a vector of normally distributed white noise with the given noise
* intensities for each element.
*
* @param <N> Num representing the dimensionality of the noise vector to create.
* @param stdDevs A matrix whose elements are the standard deviations of each
* element of the noise vector.
* @return White noise vector.
*/
public static <N extends Num> Matrix<N, N1> makeWhiteNoiseVector(
Matrix<N, N1> stdDevs
) {
var rand = new Random();
Matrix<N, N1> result = new Matrix<>(new SimpleMatrix(stdDevs.getNumRows(), 1));
for (int i = 0; i < stdDevs.getNumRows(); i++) {
result.set(i, 0, rand.nextGaussian() * stdDevs.get(i, 0));
}
return result;
}
/**
* Creates a cost matrix from the given vector for use with LQR.
*
* <p>The cost matrix is constructed using Bryson's rule. The inverse square of
* each element in the input is taken and placed on the cost matrix diagonal.
*
* @param <States> Nat representing the states of the system.
* @param costs An array. For a Q matrix, its elements are the maximum allowed
* excursions of the states from the reference. For an R matrix,
* its elements are the maximum allowed excursions of the control
* inputs from no actuation.
* @return State excursion or control effort cost matrix.
*/
@SuppressWarnings("MethodTypeParameterName")
public static <States extends Num> Matrix<States, States>
makeCostMatrix(Matrix<States, N1> costs) {
Matrix<States, States> result =
new Matrix<>(new SimpleMatrix(costs.getNumRows(), costs.getNumRows()));
result.fill(0.0);
for (int i = 0; i < costs.getNumRows(); i++) {
result.set(i, i, 1.0 / (Math.pow(costs.get(i, 0), 2)));
}
return result;
}
/**
* Returns true if (A, B) is a stabilizable pair.
*
* <p>(A,B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
* any, have absolute values less than one, where an eigenvalue is
* uncontrollable if rank(lambda * I - A, B) %3C n where n is number of states.
*
* @param <States> Num representing the size of A.
* @param <Inputs> Num representing the columns of B.
* @param A System matrix.
* @param B Input matrix.
* @return If the system is stabilizable.
*/
@SuppressWarnings("MethodTypeParameterName")
public static <States extends Num, Inputs extends Num> boolean isStabilizable(
Matrix<States, States> A, Matrix<States, Inputs> B) {
return WPIMathJNI.isStabilizable(A.getNumRows(), B.getNumCols(),
A.getData(), B.getData());
}
/**
* Convert a {@link Pose2d} to a vector of [x, y, theta], where theta is in radians.
*
* @param pose A pose to convert to a vector.
* @return The given pose in vector form, with the third element, theta, in radians.
*/
public static Matrix<N3, N1> poseToVector(Pose2d pose) {
return VecBuilder.fill(
pose.getX(),
pose.getY(),
pose.getRotation().getRadians()
);
}
/**
* Clamp the input u to the min and max.
*
* @param u The input to clamp.
* @param umin The minimum input magnitude.
* @param umax The maximum input magnitude.
* @param <I> The number of inputs.
* @return The clamped input.
*/
@SuppressWarnings({"ParameterName", "LocalVariableName"})
public static <I extends Num> Matrix<I, N1> clampInputMaxMagnitude(Matrix<I, N1> u,
Matrix<I, N1> umin,
Matrix<I, N1> umax) {
var result = new Matrix<I, N1>(new SimpleMatrix(u.getNumRows(), 1));
for (int i = 0; i < u.getNumRows(); i++) {
result.set(i, 0, MathUtil.clamp(
u.get(i, 0),
umin.get(i, 0),
umax.get(i, 0)));
}
return result;
}
/**
* Normalize all inputs if any excedes the maximum magnitude. Useful for systems such as
* differential drivetrains.
*
* @param u The input vector.
* @param maxMagnitude The maximum magnitude any input can have.
* @param <I> The number of inputs.
* @return The normalizedInput
*/
public static <I extends Num> Matrix<I, N1> normalizeInputVector(Matrix<I, N1> u,
double maxMagnitude) {
double maxValue = u.maxAbs();
boolean isCapped = maxValue > maxMagnitude;
if (isCapped) {
return u.times(maxMagnitude / maxValue);
}
return u;
}
}

View File

@@ -0,0 +1,182 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.system;
import edu.wpi.first.wpilibj.math.Discretization;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
@SuppressWarnings({"PMD.TooManyMethods", "ClassTypeParameterName"})
public class LinearSystem<States extends Num, Inputs extends Num,
Outputs extends Num> {
/**
* Continuous system matrix.
*/
@SuppressWarnings("MemberName")
private final Matrix<States, States> m_A;
/**
* Continuous input matrix.
*/
@SuppressWarnings("MemberName")
private final Matrix<States, Inputs> m_B;
/**
* Output matrix.
*/
@SuppressWarnings("MemberName")
private final Matrix<Outputs, States> m_C;
/**
* Feedthrough matrix.
*/
@SuppressWarnings("MemberName")
private final Matrix<Outputs, Inputs> m_D;
/**
* Construct a new LinearSystem from the four system matrices.
*
* @param a The system matrix A.
* @param b The input matrix B.
* @param c The output matrix C.
* @param d The feedthrough matrix D.
*/
@SuppressWarnings("ParameterName")
public LinearSystem(Matrix<States, States> a, Matrix<States, Inputs> b,
Matrix<Outputs, States> c, Matrix<Outputs, Inputs> d) {
this.m_A = a;
this.m_B = b;
this.m_C = c;
this.m_D = d;
}
/**
* Returns the system matrix A.
*
* @return the system matrix A.
*/
public Matrix<States, States> getA() {
return m_A;
}
/**
* Returns an element of the system matrix A.
*
* @param row Row of A.
* @param col Column of A.
* @return the system matrix A at (i, j).
*/
public double getA(int row, int col) {
return m_A.get(row, col);
}
/**
* Returns the input matrix B.
*
* @return the input matrix B.
*/
public Matrix<States, Inputs> getB() {
return m_B;
}
/**
* Returns an element of the input matrix B.
*
* @param row Row of B.
* @param col Column of B.
* @return The value of the input matrix B at (i, j).
*/
public double getB(int row, int col) {
return m_B.get(row, col);
}
/**
* Returns the output matrix C.
*
* @return Output matrix C.
*/
public Matrix<Outputs, States> getC() {
return m_C;
}
/**
* Returns an element of the output matrix C.
*
* @param row Row of C.
* @param col Column of C.
* @return the double value of C at the given position.
*/
public double getC(int row, int col) {
return m_C.get(row, col);
}
/**
* Returns the feedthrough matrix D.
*
* @return the feedthrough matrix D.
*/
public Matrix<Outputs, Inputs> getD() {
return m_D;
}
/**
* Returns an element of the feedthrough matrix D.
*
* @param row Row of D.
* @param col Column of D.
* @return The feedthrough matrix D at (i, j).
*/
public double getD(int row, int col) {
return m_D.get(row, col);
}
/**
* Computes the new x given the old x and the control input.
*
* <p>This is used by state observers directly to run updates based on state
* estimate.
*
* @param x The current state.
* @param clampedU The control input.
* @param dtSeconds Timestep for model update.
* @return the updated x.
*/
@SuppressWarnings("ParameterName")
public Matrix<States, N1> calculateX(Matrix<States, N1> x, Matrix<Inputs, N1> clampedU,
double dtSeconds) {
var discABpair = Discretization.discretizeAB(m_A, m_B, dtSeconds);
return (discABpair.getFirst().times(x)).plus(discABpair.getSecond().times(clampedU));
}
/**
* Computes the new y given the control input.
*
* <p>This is used by state observers directly to run updates based on state
* estimate.
*
* @param x The current state.
* @param clampedU The control input.
* @return the updated output matrix Y.
*/
@SuppressWarnings("ParameterName")
public Matrix<Outputs, N1> calculateY(
Matrix<States, N1> x,
Matrix<Inputs, N1> clampedU) {
return m_C.times(x).plus(m_D.times(clampedU));
}
@Override
public String toString() {
return String.format("Linear System: A\n%s\n\nB:\n%s\n\nC:\n%s\n\nD:\n%s\n", m_A.toString(),
m_B.toString(), m_C.toString(), m_D.toString());
}
}

View File

@@ -0,0 +1,357 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.system;
import java.util.function.Function;
import org.ejml.MatrixDimensionException;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpilibj.controller.LinearPlantInversionFeedforward;
import edu.wpi.first.wpilibj.controller.LinearQuadraticRegulator;
import edu.wpi.first.wpilibj.estimator.KalmanFilter;
import edu.wpi.first.wpilibj.math.StateSpaceUtil;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* Combines a plant, controller, and observer for controlling a mechanism with
* full state feedback.
*
* <p>For everything in this file, "inputs" and "outputs" are defined from the
* perspective of the plant. This means U is an input and Y is an output
* (because you give the plant U (powers) and it gives you back a Y (sensor
* values). This is the opposite of what they mean from the perspective of the
* controller (U is an output because that's what goes to the motors and Y is an
* input because that's what comes back from the sensors).
*
* <p>For more on the underlying math, read
* https://file.tavsys.net/control/state-space-guide.pdf.
*/
@SuppressWarnings("ClassTypeParameterName")
public class LinearSystemLoop<States extends Num, Inputs extends Num,
Outputs extends Num> {
private final LinearSystem<States, Inputs, Outputs> m_plant;
private final LinearQuadraticRegulator<States, Inputs, Outputs> m_controller;
private final LinearPlantInversionFeedforward<States, Inputs, Outputs> m_feedforward;
private final KalmanFilter<States, Inputs, Outputs> m_observer;
private Matrix<States, N1> m_nextR;
private Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> m_clampFunction;
/**
* Constructs a state-space loop with the given plant, controller, and
* observer. By default, the initial reference is all zeros. Users should
* call reset with the initial system state before enabling the loop. This
* constructor assumes that the input(s) to this system are voltage.
*
* @param plant State-space plant.
* @param controller State-space controller.
* @param observer State-space observer.
* @param maxVoltageVolts The maximum voltage that can be applied. Commonly 12.
* @param dtSeconds The nominal timestep.
*/
public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
LinearQuadraticRegulator<States, Inputs, Outputs> controller,
KalmanFilter<States, Inputs, Outputs> observer,
double maxVoltageVolts,
double dtSeconds) {
this(plant, controller,
new LinearPlantInversionFeedforward<>(plant, dtSeconds), observer,
u -> StateSpaceUtil.normalizeInputVector(u, maxVoltageVolts));
}
/**
* Constructs a state-space loop with the given plant, controller, and
* observer. By default, the initial reference is all zeros. Users should
* call reset with the initial system state before enabling the loop.
*
* @param plant State-space plant.
* @param controller State-space controller.
* @param observer State-space observer.
* @param clampFunction The function used to clamp the input U.
* @param dtSeconds The nominal timestep.
*/
public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
LinearQuadraticRegulator<States, Inputs, Outputs> controller,
KalmanFilter<States, Inputs, Outputs> observer,
Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction,
double dtSeconds) {
this(plant, controller, new LinearPlantInversionFeedforward<>(plant, dtSeconds),
observer, clampFunction);
}
/**
* Constructs a state-space loop with the given plant, controller, and
* observer. By default, the initial reference is all zeros. Users should
* call reset with the initial system state before enabling the loop.
*
* @param plant State-space plant.
* @param controller State-space controller.
* @param feedforward Plant inversion feedforward.
* @param observer State-space observer.
* @param maxVoltageVolts The maximum voltage that can be applied. Assumes that the
* inputs are voltages.
*/
public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
LinearQuadraticRegulator<States, Inputs, Outputs> controller,
LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
KalmanFilter<States, Inputs, Outputs> observer,
double maxVoltageVolts
) {
this(plant, controller, feedforward,
observer, u -> StateSpaceUtil.normalizeInputVector(u, maxVoltageVolts));
}
/**
* Constructs a state-space loop with the given plant, controller, and
* observer. By default, the initial reference is all zeros. Users should
* call reset with the initial system state before enabling the loop.
*
* @param plant State-space plant.
* @param controller State-space controller.
* @param feedforward Plant inversion feedforward.
* @param observer State-space observer.
* @param clampFunction The function used to clamp the input U.
*/
public LinearSystemLoop(LinearSystem<States, Inputs, Outputs> plant,
LinearQuadraticRegulator<States, Inputs, Outputs> controller,
LinearPlantInversionFeedforward<States, Inputs, Outputs> feedforward,
KalmanFilter<States, Inputs, Outputs> observer,
Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
this.m_plant = plant;
this.m_controller = controller;
this.m_feedforward = feedforward;
this.m_observer = observer;
this.m_clampFunction = clampFunction;
m_nextR = new Matrix<>(new SimpleMatrix(controller.getK().getNumCols(), 0));
reset(m_nextR);
}
/**
* Returns the observer's state estimate x-hat.
*
* @return the observer's state estimate x-hat.
*/
public Matrix<States, N1> getXHat() {
return getObserver().getXhat();
}
/**
* Returns an element of the observer's state estimate x-hat.
*
* @param row Row of x-hat.
* @return the i-th element of the observer's state estimate x-hat.
*/
public double getXHat(int row) {
return getObserver().getXhat(row);
}
/**
* Set the initial state estimate x-hat.
*
* @param xhat The initial state estimate x-hat.
*/
public void setXHat(Matrix<States, N1> xhat) {
getObserver().setXhat(xhat);
}
/**
* Set an element of the initial state estimate x-hat.
*
* @param row Row of x-hat.
* @param value Value for element of x-hat.
*/
public void setXHat(int row, double value) {
getObserver().setXhat(row, value);
}
/**
* Returns an element of the controller's next reference r.
*
* @param row Row of r.
* @return the element i of the controller's next reference r.
*/
public double getNextR(int row) {
return getNextR().get(row, 0);
}
/**
* Returns the controller's next reference r.
*
* @return the controller's next reference r.
*/
public Matrix<States, N1> getNextR() {
return m_nextR;
}
/**
* Set the next reference r.
*
* @param nextR Next reference.
*/
public void setNextR(Matrix<States, N1> nextR) {
m_nextR = nextR;
}
/**
* Set the next reference r.
*
* @param nextR Next reference.
*/
public void setNextR(double... nextR) {
if (nextR.length != m_nextR.getNumRows()) {
throw new MatrixDimensionException(String.format("The next reference does not have the "
+ "correct number of entries! Expected %s, but got %s.",
m_nextR.getNumRows(),
nextR.length));
}
m_nextR = new Matrix<>(new SimpleMatrix(m_nextR.getNumRows(), 1, true, nextR));
}
/**
* Returns the controller's calculated control input u plus the calculated feedforward u_ff.
*
* @return the calculated control input u.
*/
public Matrix<Inputs, N1> getU() {
return clampInput(m_controller.getU().plus(m_feedforward.getUff()));
}
/**
* Returns an element of the controller's calculated control input u.
*
* @param row Row of u.
* @return the calculated control input u at the row i.
*/
public double getU(int row) {
return getU().get(row, 0);
}
/**
* Return the plant used internally.
*
* @return the plant used internally.
*/
public LinearSystem<States, Inputs, Outputs> getPlant() {
return m_plant;
}
/**
* Return the controller used internally.
*
* @return the controller used internally.
*/
public LinearQuadraticRegulator<States, Inputs, Outputs> getController() {
return m_controller;
}
/**
* Return the feedforward used internally.
*
* @return the feedforward used internally.
*/
public LinearPlantInversionFeedforward<States, Inputs, Outputs> getFeedforward() {
return m_feedforward;
}
/**
* Return the observer used internally.
*
* @return the observer used internally.
*/
public KalmanFilter<States, Inputs, Outputs> getObserver() {
return m_observer;
}
/**
* Zeroes reference r, controller output u and plant output y.
* The previous reference for PlantInversionFeedforward is set to the
* initial reference.
* @param initialReference The initial reference.
*/
public void reset(Matrix<States, N1> initialReference) {
m_controller.reset();
m_feedforward.reset(initialReference);
m_observer.reset();
m_nextR.fill(0.0);
}
/**
* Returns difference between reoid predict(double dtSference r and x-hat.
*
* @return the
*/
public Matrix<States, N1> getError() {
return getController().getR().minus(m_observer.getXhat());
}
/**
* Returns difference between reference r and x-hat.
*
* @param index The index of the error matrix to return.
* @return The error at that index.
*/
public double getError(int index) {
return (getController().getR().minus(m_observer.getXhat())).get(index, 0);
}
/**
* Get the function used to clamp the input u.
* @return The clamping function.
*/
public Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> getClampFunction() {
return m_clampFunction;
}
/**
* Set the clamping function used to clamp inputs.
*/
public void setClampFunction(Function<Matrix<Inputs, N1>, Matrix<Inputs, N1>> clampFunction) {
this.m_clampFunction = clampFunction;
}
/**
* Correct the state estimate x-hat using the measurements in y.
*
* @param y Measurement vector.
*/
@SuppressWarnings("ParameterName")
public void correct(Matrix<Outputs, N1> y) {
getObserver().correct(getU(), y);
}
/**
* Sets new controller output, projects model forward, and runs observer
* prediction.
*
* <p>After calling this, the user should send the elements of u to the
* actuators.
*
* @param dtSeconds Timestep for model update.
*/
@SuppressWarnings("LocalVariableName")
public void predict(double dtSeconds) {
var u = clampInput(m_controller.calculate(getObserver().getXhat(), m_nextR)
.plus(m_feedforward.calculate(m_nextR)));
getObserver().predict(u, dtSeconds);
}
/**
* Clamp the input u to the min and max.
*
* @param unclampedU The input to clamp.
* @return The clamped input.
*/
public Matrix<Inputs, N1> clampInput(Matrix<Inputs, N1> unclampedU) {
return m_clampFunction.apply(unclampedU);
}
}

View File

@@ -0,0 +1,111 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.system;
import java.util.function.BiFunction;
import java.util.function.Function;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
public final class NumericalJacobian {
private NumericalJacobian() {
// Utility Class.
}
private static final double kEpsilon = 1e-5;
/**
* Computes the numerical Jacobian with respect to x for f(x).
*
* @param <Rows> Number of rows in the result of f(x).
* @param <States> Num representing the number of rows in the output of f.
* @param <Cols> Number of columns in the result of f(x).
* @param rows Number of rows in the result of f(x).
* @param cols Number of columns in the result of f(x).
* @param f Vector-valued function from which to compute the Jacobian.
* @param x Vector argument.
* @return The numerical Jacobian with respect to x for f(x, u, ...).
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public static <Rows extends Num, Cols extends Num, States extends Num> Matrix<Rows, Cols>
numericalJacobian(
Nat<Rows> rows,
Nat<Cols> cols,
Function<Matrix<Cols, N1>, Matrix<States, N1>> f,
Matrix<Cols, N1> x
) {
var result = new Matrix<>(rows, cols);
for (int i = 0; i < cols.getNum(); i++) {
var dxPlus = x.copy();
var dxMinus = x.copy();
dxPlus.set(i, 0, dxPlus.get(i, 0) + kEpsilon);
dxMinus.set(i, 0, dxMinus.get(i, 0) - kEpsilon);
@SuppressWarnings("LocalVariableName")
var dF = f.apply(dxPlus).minus(f.apply(dxMinus)).div(2 * kEpsilon);
result.setColumn(i, Matrix.changeBoundsUnchecked(dF));
}
return result;
}
/**
* Returns numerical Jacobian with respect to x for f(x, u, ...).
*
* @param <Rows> Number of rows in the result of f(x, u).
* @param <States> Number of rows in x.
* @param <Inputs> Number of rows in the second input to f.
* @param <Outputs> Num representing the rows in the output of f.
* @param rows Number of rows in the result of f(x, u).
* @param states Number of rows in x.
* @param f Vector-valued function from which to compute Jacobian.
* @param x State vector.
* @param u Input vector.
* @return The numerical Jacobian with respect to x for f(x, u, ...).
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public static <Rows extends Num, States extends Num, Inputs extends Num, Outputs extends Num>
Matrix<Rows, States> numericalJacobianX(
Nat<Rows> rows,
Nat<States> states,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<Outputs, N1>> f,
Matrix<States, N1> x,
Matrix<Inputs, N1> u
) {
return numericalJacobian(rows, states, _x -> f.apply(_x, u), x);
}
/**
* Returns the numerical Jacobian with respect to u for f(x, u).
*
* @param <States> The states of the system.
* @param <Inputs> The inputs to the system.
* @param <Rows> Number of rows in the result of f(x, u).
* @param rows Number of rows in the result of f(x, u).
* @param inputs Number of rows in u.
* @param f Vector-valued function from which to compute the Jacobian.
* @param x State vector.
* @param u Input vector.
* @return the numerical Jacobian with respect to u for f(x, u).
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public static <Rows extends Num, States extends Num, Inputs extends Num> Matrix<Rows, Inputs>
numericalJacobianU(
Nat<Rows> rows,
Nat<Inputs> inputs,
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
Matrix<States, N1> x,
Matrix<Inputs, N1> u
) {
return numericalJacobian(rows, inputs, _u -> f.apply(x, _u), u);
}
}

View File

@@ -0,0 +1,113 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.system;
import java.util.function.BiFunction;
import java.util.function.DoubleFunction;
import java.util.function.Function;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Num;
import edu.wpi.first.wpiutil.math.numbers.N1;
public final class RungeKutta {
private RungeKutta() {
// utility Class
}
/**
* Performs Runge Kutta integration (4th order).
*
* @param f The function to integrate, which takes one argument x.
* @param x The initial value of x.
* @param dtSeconds The time over which to integrate.
* @return the integration of dx/dt = f(x) for dt.
*/
@SuppressWarnings("ParameterName")
public static double rungeKutta(
DoubleFunction<Double> f,
double x,
double dtSeconds
) {
final var halfDt = 0.5 * dtSeconds;
final var k1 = f.apply(x);
final var k2 = f.apply(x + k1 * halfDt);
final var k3 = f.apply(x + k2 * halfDt);
final var k4 = f.apply(x + k3 * dtSeconds);
return x + dtSeconds / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
}
/**
* Performs Runge Kutta integration (4th order).
*
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dtSeconds The time over which to integrate.
* @return The result of Runge Kutta integration (4th order).
*/
@SuppressWarnings("ParameterName")
public static double rungeKutta(
BiFunction<Double, Double, Double> f,
double x, Double u, double dtSeconds
) {
final var halfDt = 0.5 * dtSeconds;
final var k1 = f.apply(x, u);
final var k2 = f.apply(x + k1 * halfDt, u);
final var k3 = f.apply(x + k2 * halfDt, u);
final var k4 = f.apply(x + k3 * dtSeconds, u);
return x + dtSeconds / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x, u) for dt.
*
* @param <States> A Num representing the states of the system to integrate.
* @param <Inputs> A Num representing the inputs of the system to integrate.
* @param f The function to integrate. It must take two arguments x and u.
* @param x The initial value of x.
* @param u The value u held constant over the integration period.
* @param dtSeconds The time over which to integrate.
* @return the integration of dx/dt = f(x, u) for dt.
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public static <States extends Num, Inputs extends Num> Matrix<States, N1> rungeKutta(
BiFunction<Matrix<States, N1>, Matrix<Inputs, N1>, Matrix<States, N1>> f,
Matrix<States, N1> x, Matrix<Inputs, N1> u, double dtSeconds) {
final var halfDt = 0.5 * dtSeconds;
Matrix<States, N1> k1 = f.apply(x, u);
Matrix<States, N1> k2 = f.apply(x.plus(k1.times(halfDt)), u);
Matrix<States, N1> k3 = f.apply(x.plus(k2.times(halfDt)), u);
Matrix<States, N1> k4 = f.apply(x.plus(k3.times(dtSeconds)), u);
return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(dtSeconds).div(6.0));
}
/**
* Performs 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
*
* @param <States> A Num prepresenting the states of the system.
* @param f The function to integrate. It must take one argument x.
* @param x The initial value of x.
* @param dtSeconds The time over which to integrate.
* @return 4th order Runge-Kutta integration of dx/dt = f(x) for dt.
*/
@SuppressWarnings({"ParameterName", "MethodTypeParameterName"})
public static <States extends Num> Matrix<States, N1> rungeKutta(
Function<Matrix<States, N1>, Matrix<States, N1>> f,
Matrix<States, N1> x, double dtSeconds) {
final var halfDt = 0.5 * dtSeconds;
Matrix<States, N1> k1 = f.apply(x);
Matrix<States, N1> k2 = f.apply(x.plus(k1.times(halfDt)));
Matrix<States, N1> k3 = f.apply(x.plus(k2.times(halfDt)));
Matrix<States, N1> k4 = f.apply(x.plus(k3.times(dtSeconds)));
return x.plus((k1.plus(k2.times(2.0)).plus(k3.times(2.0)).plus(k4)).times(dtSeconds).div(6.0));
}
}

View File

@@ -0,0 +1,171 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.system.plant;
import edu.wpi.first.wpilibj.util.Units;
/**
* Holds the constants for a DC motor.
*/
public class DCMotor {
public final double m_nominalVoltageVolts;
public final double m_stallTorqueNewtonMeters;
public final double m_stallCurrentAmps;
public final double m_freeCurrentAmps;
public final double m_freeSpeedRadPerSec;
@SuppressWarnings("MemberName")
public final double m_rOhms;
@SuppressWarnings("MemberName")
public final double m_KvRadPerSecPerVolt;
@SuppressWarnings("MemberName")
public final double m_KtNMPerAmp;
/**
* Constructs a DC motor.
*
* @param nominalVoltageVolts Voltage at which the motor constants were measured.
* @param stallTorqueNewtonMeters Current draw when stalled.
* @param stallCurrentAmps Current draw when stalled.
* @param freeCurrentAmps Current draw under no load.
* @param freeSpeedRadPerSec Angular velocity under no load.
*/
public DCMotor(double nominalVoltageVolts,
double stallTorqueNewtonMeters,
double stallCurrentAmps,
double freeCurrentAmps,
double freeSpeedRadPerSec) {
this.m_nominalVoltageVolts = nominalVoltageVolts;
this.m_stallTorqueNewtonMeters = stallTorqueNewtonMeters;
this.m_stallCurrentAmps = stallCurrentAmps;
this.m_freeCurrentAmps = freeCurrentAmps;
this.m_freeSpeedRadPerSec = freeSpeedRadPerSec;
this.m_rOhms = nominalVoltageVolts / stallCurrentAmps;
this.m_KvRadPerSecPerVolt = freeSpeedRadPerSec / (nominalVoltageVolts - m_rOhms
* freeCurrentAmps);
this.m_KtNMPerAmp = stallTorqueNewtonMeters / stallCurrentAmps;
}
/**
* Return a gearbox of CIM motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getCIM(int numMotors) {
return new DCMotor(12,
2.42 * numMotors, 133,
2.7, Units.rotationsPerMinuteToRadiansPerSecond(5310));
}
/**
* Return a gearbox of 775Pro motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getVex775Pro(int numMotors) {
return gearbox(new DCMotor(12,
0.71, 134,
0.7, Units.rotationsPerMinuteToRadiansPerSecond(18730)), numMotors);
}
/**
* Return a gearbox of NEO motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getNEO(int numMotors) {
return gearbox(new DCMotor(12, 2.6,
105, 1.8, Units.rotationsPerMinuteToRadiansPerSecond(5676)), numMotors);
}
/**
* Return a gearbox of MiniCIM motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getMiniCIM(int numMotors) {
return gearbox(new DCMotor(12, 1.41, 89, 3,
Units.rotationsPerMinuteToRadiansPerSecond(5840)), numMotors);
}
/**
* Return a gearbox of Bag motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getBag(int numMotors) {
return gearbox(new DCMotor(12, 0.43, 53, 1.8,
Units.rotationsPerMinuteToRadiansPerSecond(13180)), numMotors);
}
/**
* Return a gearbox of Andymark RS775-125 motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getAndymarkRs775_125(int numMotors) {
return gearbox(new DCMotor(12, 0.28, 18, 1.6,
Units.rotationsPerMinuteToRadiansPerSecond(5800.0)), numMotors);
}
/**
* Return a gearbox of Banebots RS775 motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getBanebotsRs775(int numMotors) {
return gearbox(new DCMotor(12, 0.72, 97, 2.7,
Units.rotationsPerMinuteToRadiansPerSecond(13050.0)), numMotors);
}
/**
* Return a gearbox of Andymark 9015 motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getAndymark9015(int numMotors) {
return gearbox(new DCMotor(12, 0.36, 71, 3.7,
Units.rotationsPerMinuteToRadiansPerSecond(14270.0)), numMotors);
}
/**
* Return a gearbox of Banebots RS 550 motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getBanebotsRs550(int numMotors) {
return gearbox(new DCMotor(12, 0.38, 84, 0.4,
Units.rotationsPerMinuteToRadiansPerSecond(19000.0)), numMotors);
}
/**
* Return a gearbox of Neo 550 motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getNeo550(int numMotors) {
return gearbox(new DCMotor(12, 0.97, 100, 1.4,
Units.rotationsPerMinuteToRadiansPerSecond(11000.0)), numMotors);
}
/**
* Return a gearbox of Falcon 500 motors.
*
* @param numMotors Number of motors in the gearbox.
*/
public static DCMotor getFalcon500(int numMotors) {
return gearbox(new DCMotor(12, 4.69, 257, 1.5,
Units.rotationsPerMinuteToRadiansPerSecond(6380.0)), numMotors);
}
private static DCMotor gearbox(DCMotor motor, double numMotors) {
return new DCMotor(motor.m_nominalVoltageVolts, motor.m_stallTorqueNewtonMeters * numMotors,
motor.m_stallCurrentAmps, motor.m_freeCurrentAmps, motor.m_freeSpeedRadPerSec);
}
}

View File

@@ -0,0 +1,199 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpilibj.system.plant;
import edu.wpi.first.wpilibj.system.LinearSystem;
import edu.wpi.first.wpiutil.math.Matrix;
import edu.wpi.first.wpiutil.math.Nat;
import edu.wpi.first.wpiutil.math.VecBuilder;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N2;
public final class LinearSystemId {
private LinearSystemId() {
// Utility class
}
/**
* Create a state-space model of an elevator system.
*
* @param motor The motor (or gearbox) attached to the arm.
* @param massKg The mass of the elevator carriage, in kilograms.
* @param radiusMeters The radius of thd driving drum of the elevator, in meters.
* @param G The reduction between motor and drum, as a ratio of output to input.
* @return A LinearSystem representing the given characterized constants.
*/
@SuppressWarnings("ParameterName")
public static LinearSystem<N2, N1, N1> createElevatorSystem(DCMotor motor, double massKg,
double radiusMeters, double G) {
return new LinearSystem<>(
Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1,
0, -Math.pow(G, 2) * motor.m_KtNMPerAmp
/ (motor.m_rOhms * radiusMeters * radiusMeters * massKg
* motor.m_KvRadPerSecPerVolt)),
VecBuilder.fill(
0, G * motor.m_KtNMPerAmp / (motor.m_rOhms * radiusMeters * massKg)),
Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0),
new Matrix<>(Nat.N1(), Nat.N1()));
}
/**
* Create a state-space model of a flywheel system.
*
* @param motor The motor (or gearbox) attached to the arm.
* @param jKgMetersSquared The moment of inertia J of the flywheel.
* @param G The reduction between motor and drum, as a ratio of output to input.
* @return A LinearSystem representing the given characterized constants.
*/
@SuppressWarnings("ParameterName")
public static LinearSystem<N1, N1, N1> createFlywheelSystem(DCMotor motor,
double jKgMetersSquared,
double G) {
return new LinearSystem<>(
VecBuilder.fill(
-G * G * motor.m_KtNMPerAmp
/ (motor.m_KvRadPerSecPerVolt * motor.m_rOhms * jKgMetersSquared)),
VecBuilder.fill(G * motor.m_KtNMPerAmp
/ (motor.m_rOhms * jKgMetersSquared)),
Matrix.eye(Nat.N1()),
new Matrix<>(Nat.N1(), Nat.N1()));
}
/**
* Create a state-space model of a differential drive drivetrain.
*
* @param motor the gearbox representing the motors driving the drivetrain.
* @param massKg the mass of the robot.
* @param rMeters the radius of the wheels in meters.
* @param rbMeters the radius of the base (half the track width) in meters.
* @param JKgMetersSquared the moment of inertia of the robot.
* @param G the gearing reduction as output over input.
* @return A LinearSystem representing a differential drivetrain.
*/
@SuppressWarnings({"LocalVariableName", "ParameterName"})
public static LinearSystem<N2, N2, N2> createDrivetrainVelocitySystem(DCMotor motor,
double massKg,
double rMeters,
double rbMeters,
double JKgMetersSquared,
double G) {
var C1 =
-(G * G) * motor.m_KtNMPerAmp
/ (motor.m_KvRadPerSecPerVolt * motor.m_rOhms * rMeters * rMeters);
var C2 = G * motor.m_KtNMPerAmp / (motor.m_rOhms * rMeters);
final double C3 = 1 / massKg + rbMeters * rbMeters / JKgMetersSquared;
final double C4 = 1 / massKg - rbMeters * rbMeters / JKgMetersSquared;
var A = Matrix.mat(Nat.N2(), Nat.N2()).fill(
C3 * C1,
C4 * C1,
C4 * C1,
C3 * C1);
var B = Matrix.mat(Nat.N2(), Nat.N2()).fill(
C3 * C2,
C4 * C2,
C4 * C2,
C3 * C2);
var C = Matrix.mat(Nat.N2(), Nat.N2()).fill(1.0, 0.0, 0.0, 1.0);
var D = Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 0.0, 0.0, 0.0);
return new LinearSystem<>(A, B, C, D);
}
/**
* Create a state-space model of a single jointed arm system.
*
* @param motor The motor (or gearbox) attached to the arm.
* @param jKgSquaredMeters The moment of inertia J of the arm.
* @param G the gearing between the motor and arm, in output over input.
* Most of the time this will be greater than 1.
* @return A LinearSystem representing the given characterized constants.
*/
@SuppressWarnings("ParameterName")
public static LinearSystem<N2, N1, N1> createSingleJointedArmSystem(DCMotor motor,
double jKgSquaredMeters,
double G) {
return new LinearSystem<>(
Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 1,
0, -Math.pow(G, 2) * motor.m_KtNMPerAmp
/ (motor.m_KvRadPerSecPerVolt * motor.m_rOhms * jKgSquaredMeters)),
VecBuilder.fill(0, G * motor.m_KtNMPerAmp
/ (motor.m_rOhms * jKgSquaredMeters)),
Matrix.mat(Nat.N1(), Nat.N2()).fill(1, 0),
new Matrix<>(Nat.N1(), Nat.N1()));
}
/**
* Identify a velocity system from it's kV (volts/(unit/sec)) and kA (volts/(unit/sec^2).
* These constants cam be found using frc-characterization.
*
* @param kV The velocity gain, in volts per (units per second)
* @param kA The acceleration gain, in volts per (units per second squared)
* @return A LinearSystem representing the given characterized constants.
* @see <a href="https://github.com/wpilibsuite/frc-characterization">
* https://github.com/wpilibsuite/frc-characterization</a>
*/
@SuppressWarnings("ParameterName")
public static LinearSystem<N1, N1, N1> identifyVelocitySystem(double kV, double kA) {
return new LinearSystem<>(
VecBuilder.fill(-kV / kA),
VecBuilder.fill(1.0 / kA),
VecBuilder.fill(1.0),
VecBuilder.fill(0.0));
}
/**
* Identify a position system from it's kV (volts/(unit/sec)) and kA (volts/(unit/sec^2).
* These constants cam be found using frc-characterization.
*
* @param kV The velocity gain, in volts per (units per second)
* @param kA The acceleration gain, in volts per (units per second squared)
* @return A LinearSystem representing the given characterized constants.
* @see <a href="https://github.com/wpilibsuite/frc-characterization">
* https://github.com/wpilibsuite/frc-characterization</a>
*/
@SuppressWarnings("ParameterName")
public static LinearSystem<N2, N1, N1> identifyPositionSystem(double kV, double kA) {
return new LinearSystem<>(
Matrix.mat(Nat.N2(), Nat.N2()).fill(0.0, 1.0, 0.0, -kV / kA),
VecBuilder.fill(0.0, 1.0 / kA),
Matrix.mat(Nat.N1(), Nat.N2()).fill(1.0, 0.0),
VecBuilder.fill(0.0));
}
/**
* Identify a standard differential drive drivetrain, given the drivetrain's
* kV and kA in both linear (volts/(meter/sec) and volts/(meter/sec^2)) and
* angular (volts/(radian/sec) and volts/(radian/sec^2)) cases. This can be
* found using frc-characterization.
*
* @param kVLinear The linear velocity gain, volts per (meter per second).
* @param kALinear The linear acceleration gain, volts per (meter per second squared).
* @param kVAngular The angular velocity gain, volts per (radians per second).
* @param kAAngular The angular acceleration gain, volts per (radians per second squared).
* @return A LinearSystem representing the given characterized constants.
* @see <a href="https://github.com/wpilibsuite/frc-characterization">
* https://github.com/wpilibsuite/frc-characterization</a>
*/
@SuppressWarnings("ParameterName")
public static LinearSystem<N2, N2, N2> identifyDrivetrainSystem(
double kVLinear, double kALinear, double kVAngular, double kAAngular) {
final double c = 0.5 / (kALinear * kAAngular);
final double A1 = c * (-kALinear * kVAngular - kVLinear * kAAngular);
final double A2 = c * (kALinear * kVAngular - kVLinear * kAAngular);
final double B1 = c * (kALinear + kAAngular);
final double B2 = c * (kAAngular - kALinear);
return new LinearSystem<>(
Matrix.mat(Nat.N2(), Nat.N2()).fill(A1, A2, A2, A1),
Matrix.mat(Nat.N2(), Nat.N2()).fill(B1, B2, B2, B1),
Matrix.mat(Nat.N2(), Nat.N2()).fill(1, 0, 0, 1),
Matrix.mat(Nat.N2(), Nat.N2()).fill(0, 0, 0, 0));
}
}

View File

@@ -18,8 +18,8 @@ import org.ejml.simple.SimpleMatrix;
* @param <C> The number of columns of the desired matrix.
*/
public class MatBuilder<R extends Num, C extends Num> {
private final Nat<R> m_rows;
private final Nat<C> m_cols;
final Nat<R> m_rows;
final Nat<C> m_cols;
/**
* Fills the matrix with the given data, encoded in row major form.

View File

@@ -9,10 +9,17 @@ package edu.wpi.first.wpiutil.math;
import java.util.Objects;
import org.ejml.MatrixDimensionException;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.MatrixFeatures_DDRM;
import org.ejml.dense.row.NormOps_DDRM;
import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices.
*
@@ -21,10 +28,60 @@ import org.ejml.simple.SimpleMatrix;
* @param <R> The number of rows in this matrix.
* @param <C> The number of columns in this matrix.
*/
@SuppressWarnings("PMD.TooManyMethods")
@SuppressWarnings({"PMD.TooManyMethods", "PMD.ExcessivePublicCount"})
public class Matrix<R extends Num, C extends Num> {
protected final SimpleMatrix m_storage;
private final SimpleMatrix m_storage;
/**
* Constructs an empty zero matrix of the given dimensions.
*
* @param rows The number of rows of the matrix.
* @param columns The number of columns of the matrix.
*/
public Matrix(Nat<R> rows, Nat<C> columns) {
this.m_storage = new SimpleMatrix(
Objects.requireNonNull(rows).getNum(),
Objects.requireNonNull(columns).getNum()
);
}
/**
* Constructs a new {@link Matrix} with the given storage.
* Caller should make sure that the provided generic bounds match
* the shape of the provided {@link Matrix}.
*
* <p>NOTE:It is not recommend to use this constructor unless the
* {@link SimpleMatrix} API is absolutely necessary due to the desired
* function not being accessible through the {@link Matrix} wrapper.
*
* @param storage The {@link SimpleMatrix} to back this value.
*/
public Matrix(SimpleMatrix storage) {
this.m_storage = Objects.requireNonNull(storage);
}
/**
* Constructs a new matrix with the storage of the supplied matrix.
*
* @param other The {@link Matrix} to copy the storage of.
*/
public Matrix(Matrix<R, C> other) {
this.m_storage = Objects.requireNonNull(other).getStorage().copy();
}
/**
* Gets the underlying {@link SimpleMatrix} that this {@link Matrix} wraps.
*
* <p>NOTE:The use of this method is heavily discouraged as this removes any
* guarantee of type safety. This should only be called if the {@link SimpleMatrix}
* API is absolutely necessary due to the desired function not being accessible through
* the {@link Matrix} wrapper.
*
* @return The underlying {@link SimpleMatrix} storage.
*/
public SimpleMatrix getStorage() {
return m_storage;
}
/**
* Gets the number of columns in this matrix.
@@ -58,8 +115,8 @@ public class Matrix<R extends Num, C extends Num> {
/**
* Sets the value at the given indices.
*
* @param row The row of the element.
* @param col The column of the element.
* @param row The row of the element.
* @param col The column of the element.
* @param value The value to insert at the given location.
*/
public final void set(int row, int col, double value) {
@@ -67,10 +124,44 @@ public class Matrix<R extends Num, C extends Num> {
}
/**
* If a vector then a square matrix is returned
* if a matrix then a vector of diagonal elements is returned.
* Sets a row to a given row vector.
*
* @return Diagonal elements inside a vector or a square matrix with the same diagonal elements.
* @param row The row to set.
* @param val The row vector to set the given row to.
*/
public final void setRow(int row, Matrix<N1, C> val) {
this.m_storage.setRow(row, 0,
Objects.requireNonNull(val).m_storage.getDDRM().getData());
}
/**
* Sets a column to a given column vector.
*
* @param column The column to set.
* @param val The column vector to set the given row to.
*/
public final void setColumn(int column, Matrix<R, N1> val) {
this.m_storage.setColumn(column, 0,
Objects.requireNonNull(val).m_storage.getDDRM().getData());
}
/**
* Sets all the elements in "this" matrix equal to the specified value.
*
* @param value The value each element is set to.
*/
public void fill(double value) {
this.m_storage.fill(value);
}
/**
* Returns the diagonal elements inside a vector or square matrix.
*
* <p>If "this" {@link Matrix} is a vector then a square matrix is returned. If a "this"
* {@link Matrix} is a matrix then a vector of diagonal elements is returned.
*
* @return The diagonal elements inside a vector or a square matrix.
*/
public final Matrix<R, C> diag() {
return new Matrix<>(this.m_storage.diag());
@@ -81,10 +172,20 @@ public class Matrix<R extends Num, C extends Num> {
*
* @return The largest element of this matrix.
*/
public final double maxInternal() {
public final double max() {
return CommonOps_DDRM.elementMax(this.m_storage.getDDRM());
}
/**
* Returns the absolute value of the element in this matrix with the largest absolute value.
*
* @return The absolute value of the element with the largest absolute value.
*/
public final double maxAbs() {
return CommonOps_DDRM.elementMaxAbs(this.m_storage.getDDRM());
}
/**
* Returns the smallest element of this matrix.
*
@@ -112,10 +213,10 @@ public class Matrix<R extends Num, C extends Num> {
*
* @param other The other matrix to multiply by.
* @param <C2> The number of columns in the second matrix.
* @return The result of the matrix multiplication between this and the given matrix.
* @return The result of the matrix multiplication between "this" and the given matrix.
*/
public final <C2 extends Num> Matrix<R, C2> times(Matrix<C, C2> other) {
return new Matrix<>(this.m_storage.mult(other.m_storage));
return new Matrix<>(this.m_storage.mult(Objects.requireNonNull(other).m_storage));
}
/**
@@ -129,13 +230,14 @@ public class Matrix<R extends Num, C extends Num> {
}
/**
* <p>
* Returns a matrix which is the result of an element by element multiplication of 'this' and 'b'.
* c<sub>i,j</sub> = a<sub>i,j</sub>*b<sub>i,j</sub>
* </p>
* Returns a matrix which is the result of an element by element multiplication of
* "this" and other.
*
* @param other A matrix.
* @return The element by element multiplication of 'this' and 'b'.
* <p>c<sub>i,j</sub> = a<sub>i,j</sub>*other<sub>i,j</sub>
*
*
* @param other The other {@link Matrix} to preform element multiplication on.
* @return The element by element multiplication of "this" and other.
*/
public final Matrix<R, C> elementTimes(Matrix<R, C> other) {
return new Matrix<>(this.m_storage.elementMult(Objects.requireNonNull(other).m_storage));
@@ -180,7 +282,7 @@ public class Matrix<R extends Num, C extends Num> {
* @return The resultant matrix.
*/
public final Matrix<R, C> plus(Matrix<R, C> value) {
return new Matrix<>(this.m_storage.plus(value.m_storage));
return new Matrix<>(this.m_storage.plus(Objects.requireNonNull(value).m_storage));
}
/**
@@ -206,7 +308,7 @@ public class Matrix<R extends Num, C extends Num> {
/**
* Calculates the transpose, M^T of this matrix.
*
* @return The tranpose matrix.
* @return The transpose matrix.
*/
public final Matrix<C, R> transpose() {
return new Matrix<>(this.m_storage.transpose());
@@ -224,15 +326,47 @@ public class Matrix<R extends Num, C extends Num> {
/**
* Returns the inverse matrix of this matrix.
* Returns the inverse matrix of "this" matrix.
*
* @return The inverse of this matrix.
* @throws org.ejml.data.SingularMatrixException If this matrix is non-invertable.
* @return The inverse of "this" matrix.
* @throws org.ejml.data.SingularMatrixException If "this" matrix is non-invertable.
*/
public final Matrix<R, C> inv() {
return new Matrix<>(this.m_storage.invert());
}
/**
* Returns the solution x to the equation Ax = b, where A is "this" matrix.
*
* <p>The matrix equation could also be written as x = A<sup>-1</sup>b. Where the
* pseudo inverse is used if A is not square.
*
* @param b The right-hand side of the equation to solve.
* @return The solution to the linear system.
*/
@SuppressWarnings("ParameterName")
public final <C2 extends Num> Matrix<C, C2> solve(Matrix<R, C2> b) {
return new Matrix<>(this.m_storage.solve(Objects.requireNonNull(b).m_storage));
}
/**
* Computes the matrix exponential using Eigen's solver.
* This method only works for square matrices, and will
* otherwise throw an {@link MatrixDimensionException}.
*
* @return The exponential of A.
*/
public final Matrix<R, C> exp() {
if (this.getNumRows() != this.getNumCols()) {
throw new MatrixDimensionException("Non-square matrices cannot be exponentiated! "
+ "This matrix is " + this.getNumRows() + " x " + this.getNumCols());
}
Matrix<R, C> toReturn = new Matrix<>(new SimpleMatrix(this.getNumRows(), this.getNumCols()));
WPIMathJNI.exp(this.m_storage.getDDRM().getData(), this.getNumRows(),
toReturn.m_storage.getDDRM().getData());
return toReturn;
}
/**
* Returns the determinant of this matrix.
*
@@ -243,9 +377,9 @@ public class Matrix<R extends Num, C extends Num> {
}
/**
* Computes the Frobenius normal of the matrix.<br>
* <br>
* normF = Sqrt{ &sum;<sub>i=1:m</sub> &sum;<sub>j=1:n</sub> { a<sub>ij</sub><sup>2</sup>} }
* Computes the Frobenius normal of the matrix.
*
* <p>normF = Sqrt{ &sum;<sub>i=1:m</sub> &sum;<sub>j=1:n</sub> { a<sub>ij</sub><sup>2</sup>} }
*
* @return The matrix's Frobenius normal.
*/
@@ -254,9 +388,9 @@ public class Matrix<R extends Num, C extends Num> {
}
/**
* Computes the induced p = 1 matrix norm.<br>
* <br>
* ||A||<sub>1</sub>= max(j=1 to n; sum(i=1 to m; |a<sub>ij</sub>|))
* Computes the induced p = 1 matrix norm.
*
* <p>||A||<sub>1</sub>= max(j=1 to n; sum(i=1 to m; |a<sub>ij</sub>|))
*
* @return The norm.
*/
@@ -283,45 +417,261 @@ public class Matrix<R extends Num, C extends Num> {
}
/**
* Returns a matrix which is the result of an element by element power of 'this' and 'b':
* c<sub>i,j</sub> = a<sub>i,j</sub> ^ b.
* Returns a matrix which is the result of an element by element power of "this" and b.
*
* @param b Scalar
* @return The element by element power of 'this' and 'b'.
* <p>c<sub>i,j</sub> = a<sub>i,j</sub> ^ b
*
* @param b Scalar.
* @return The element by element power of "this" and b.
*/
@SuppressWarnings("ParameterName")
public final Matrix<R, C> epow(double b) {
public final Matrix<R, C> elementPower(double b) {
return new Matrix<>(this.m_storage.elementPower(b));
}
/**
* Returns a matrix which is the result of an element by element power of 'this' and 'b':
* c<sub>i,j</sub> = a<sub>i,j</sub> ^ b.
* Returns a matrix which is the result of an element by element power of "this" and b.
*
* <p>c<sub>i,j</sub> = a<sub>i,j</sub> ^ b
*
* @param b Scalar.
* @return The element by element power of 'this' and 'b'.
* @return The element by element power of "this" and b.
*/
@SuppressWarnings("ParameterName")
public final Matrix<R, C> epow(int b) {
public final Matrix<R, C> elementPower(int b) {
return new Matrix<>(this.m_storage.elementPower((double) b));
}
/**
* Returns the EJML {@link SimpleMatrix} backing this wrapper.
* Extracts a given row into a row vector with new underlying storage.
*
* @return The untyped EJML {@link SimpleMatrix}.
* @param row The row to extract a vector from.
* @return A row vector from the given row.
*/
public final SimpleMatrix getStorage() {
return this.m_storage;
public final Matrix<N1, C> extractRowVector(int row) {
return new Matrix<>(this.m_storage.extractVector(true, row));
}
/**
* Constructs a new matrix with the given storage.
* Caller should make sure that the provided generic bounds match the shape of the provided matrix
* Extracts a given column into a column vector with new underlying storage.
*
* @param storage The {@link SimpleMatrix} to back this value
* @param column The column to extract a vector from.
* @return A column vector from the given column.
*/
public Matrix(SimpleMatrix storage) {
this.m_storage = Objects.requireNonNull(storage);
public final Matrix<R, N1> extractColumnVector(int column) {
return new Matrix<>(this.m_storage.extractVector(false, column));
}
/**
* Extracts a matrix of a given size and start position with new underlying
* storage.
*
* @param height The number of rows of the extracted matrix.
* @param width The number of columns of the extracted matrix.
* @param startingRow The starting row of the extracted matrix.
* @param startingCol The starting column of the extracted matrix.
* @return The extracted matrix.
*/
public final <R2 extends Num, C2 extends Num> Matrix<R2, C2> block(
Nat<R2> height, Nat<C2> width, int startingRow, int startingCol) {
return new Matrix<>(this.m_storage.extractMatrix(
startingRow,
Objects.requireNonNull(height).getNum() + startingRow,
startingCol,
Objects.requireNonNull(width).getNum() + startingCol));
}
/**
* Assign a matrix of a given size and start position.
*
* @param startingRow The row to start at.
* @param startingCol The column to start at.
* @param other The matrix to assign the block to.
*/
public <R2 extends Num, C2 extends Num> void assignBlock(int startingRow, int startingCol,
Matrix<R2, C2> other) {
this.m_storage.insertIntoThis(
startingRow,
startingCol,
Objects.requireNonNull(other).m_storage);
}
/**
* Extracts a submatrix from the supplied matrix and inserts it in a submatrix in "this". The
* shape of "this" is used to determine the size of the matrix extracted.
*
* @param startingRow The starting row in the supplied matrix to extract the submatrix.
* @param startingCol The starting column in the supplied matrix to extract the submatrix.
* @param other The matrix to extract the submatrix from.
*/
public <R2 extends Num, C2 extends Num> void extractFrom(int startingRow, int startingCol,
Matrix<R2, C2> other) {
CommonOps_DDRM.extract(other.m_storage.getDDRM(), startingRow, startingCol,
this.m_storage.getDDRM());
}
/**
* Decompose "this" matrix using Cholesky Decomposition. If the "this" matrix is zeros, it
* will return the zero matrix.
*
* @param lowerTriangular Whether or not we want to decompose to the lower triangular
* Cholesky matrix.
* @return The decomposed matrix.
* @throws RuntimeException if the matrix could not be decomposed(ie. is not positive
* semidefinite).
*/
@SuppressWarnings("PMD.AvoidThrowingRawExceptionTypes")
public Matrix<R, C> lltDecompose(boolean lowerTriangular) {
SimpleMatrix temp = m_storage.copy();
CholeskyDecomposition_F64<DMatrixRMaj> chol =
DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
if (!chol.decompose(temp.getMatrix())) {
// check that the input is not all zeros -- if they are, we special case and return all
// zeros.
var matData = temp.getDDRM().data;
var isZeros = true;
for (double matDatum : matData) {
isZeros &= Math.abs(matDatum) < 1e-6;
}
if (isZeros) {
return new Matrix<>(new SimpleMatrix(temp.numRows(), temp.numCols()));
}
throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n"
+ m_storage.toString());
}
return new Matrix<>(SimpleMatrix.wrap(chol.getT(null)));
}
/**
* Returns the row major data of this matrix as a double array.
*
* @return The row major data of this matrix as a double array.
*/
public double[] getData() {
return m_storage.getDDRM().getData();
}
/**
* Creates the identity matrix of the given dimension.
*
* @param dim The dimension of the desired matrix as a {@link Nat}.
* @param <D> The dimension of the desired matrix as a generic.
* @return The DxD identity matrix.
*/
public static <D extends Num> Matrix<D, D> eye(Nat<D> dim) {
return new Matrix<>(SimpleMatrix.identity(Objects.requireNonNull(dim).getNum()));
}
/**
* Creates the identity matrix of the given dimension.
*
* @param dim The dimension of the desired matrix as a {@link Num}.
* @param <D> The dimension of the desired matrix as a generic.
* @return The DxD identity matrix.
*/
public static <D extends Num> Matrix<D, D> eye(D dim) {
return new Matrix<>(SimpleMatrix.identity(Objects.requireNonNull(dim).getNum()));
}
/**
* Entrypoint to the {@link MatBuilder} class for creation
* of custom matrices with the given dimensions and contents.
*
* @param rows The number of rows of the desired matrix.
* @param cols The number of columns of the desired matrix.
* @param <R> The number of rows of the desired matrix as a generic.
* @param <C> The number of columns of the desired matrix as a generic.
* @return A builder to construct the matrix.
*/
public static <R extends Num, C extends Num> MatBuilder<R, C> mat(Nat<R> rows, Nat<C> cols) {
return new MatBuilder<>(Objects.requireNonNull(rows), Objects.requireNonNull(cols));
}
/**
* Reassigns dimensions of a {@link Matrix} to allow for operations with
* other matrices that have wildcard dimensions.
*
* @param mat The {@link Matrix} to remove the dimensions from.
* @return The matrix with reassigned dimensions.
*/
public static <R1 extends Num, C1 extends Num> Matrix<R1, C1> changeBoundsUnchecked(
Matrix<?, ?> mat) {
return new Matrix<>(mat.m_storage);
}
/**
* Checks if another {@link Matrix} is identical to "this" one within a specified tolerance.
*
* <p>This will check if each element is in tolerance of the corresponding element
* from the other {@link Matrix} or if the elements have the same symbolic meaning. For two
* elements to have the same symbolic meaning they both must be either Double.NaN,
* Double.POSITIVE_INFINITY, or Double.NEGATIVE_INFINITY.
*
* <p>NOTE:It is recommend to use {@link Matrix#isEqual(Matrix, double)} over this
* method when checking if two matrices are equal as {@link Matrix#isEqual(Matrix, double)}
* will return false if an element is uncountable. This method should only be used when
* uncountable elements need to compared.
*
* @param other The {@link Matrix} to check against this one.
* @param tolerance The tolerance to check equality with.
* @return true if this matrix is identical to the one supplied.
*/
public boolean isIdentical(Matrix<?, ?> other, double tolerance) {
return MatrixFeatures_DDRM.isIdentical(this.m_storage.getDDRM(),
other.m_storage.getDDRM(), tolerance);
}
/**
* Checks if another {@link Matrix} is equal to "this" within a specified tolerance.
*
* <p>This will check if each element is in tolerance of the corresponding element
* from the other {@link Matrix}.
*
* <p>tol &ge; |a<sub>ij</sub> - b<sub>ij</sub>|
*
* @param other The {@link Matrix} to check against this one.
* @param tolerance The tolerance to check equality with.
* @return true if this matrix is equal to the one supplied.
*/
public boolean isEqual(Matrix<?, ?> other, double tolerance) {
return MatrixFeatures_DDRM.isEquals(this.m_storage.getDDRM(),
other.m_storage.getDDRM(), tolerance);
}
@Override
public String toString() {
return m_storage.toString();
}
/**
* Checks if an object is equal to this {@link Matrix}.
*
* <p>a<sub>ij</sub> == b<sub>ij</sub>
*
* @param other The Object to check against this {@link Matrix}.
* @return true if the object supplied is a {@link Matrix} and is equal to this matrix.
*/
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof Matrix)) {
return false;
}
Matrix<?, ?> matrix = (Matrix<?, ?>) other;
if (MatrixFeatures_DDRM.hasUncountable(matrix.m_storage.getDDRM())) {
return false;
}
return MatrixFeatures_DDRM.isEquals(this.m_storage.getDDRM(), matrix.m_storage.getDDRM());
}
@Override
public int hashCode() {
return Objects.hash(m_storage);
}
}

View File

@@ -13,6 +13,7 @@ import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.numbers.N1;
@Deprecated
public final class MatrixUtils {
private MatrixUtils() {
throw new AssertionError("utility class");

View File

@@ -0,0 +1,31 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
public class Pair<A, B> {
private final A m_first;
private final B m_second;
public Pair(A first, B second) {
m_first = first;
m_second = second;
}
public A getFirst() {
return m_first;
}
public B getSecond() {
return m_second;
}
@SuppressWarnings("ParameterName")
public static <A, B> Pair<A, B> of(A a, B b) {
return new Pair<>(a, b);
}
}

View File

@@ -9,12 +9,17 @@ package edu.wpi.first.wpiutil.math;
import java.util.function.BiFunction;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.NormOps_DDRM;
import org.ejml.dense.row.factory.DecompositionFactory_DDRM;
import org.ejml.interfaces.decomposition.CholeskyDecomposition_F64;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;
public class SimpleMatrixUtils {
private SimpleMatrixUtils() {}
@SuppressWarnings("PMD.TooManyMethods")
public final class SimpleMatrixUtils {
private SimpleMatrixUtils() {
}
/**
* Compute the matrix exponential, e^M of the given matrix.
@@ -98,8 +103,10 @@ public class SimpleMatrixUtils {
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix U = A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
SimpleMatrix U =
A.mult(A6.scale(b[7]).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V =
A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@@ -114,8 +121,10 @@ public class SimpleMatrixUtils {
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix A8 = A6.mult(A2);
SimpleMatrix U = A.mult(A8.scale(b[9]).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A8.scale(b[8]).plus(A6.scale(b[6])).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
SimpleMatrix U =
A.mult(A8.scale(b[9]).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V =
A8.scale(b[8]).plus(A6.scale(b[6])).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0]));
return new Pair<>(U, V);
}
@@ -131,8 +140,10 @@ public class SimpleMatrixUtils {
SimpleMatrix A4 = A2.mult(A2);
SimpleMatrix A6 = A4.mult(A2);
SimpleMatrix U = A.mult(A6.scale(b[13]).plus(A4.scale(b[11])).plus(A2.scale(b[9])).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V = A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))).plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
SimpleMatrix U =
A.mult(A6.scale(b[13]).plus(A4.scale(b[11])).plus(A2.scale(b[9])).plus(A6.scale(b[7])).plus(A4.scale(b[5])).plus(A2.scale(b[3])).plus(ident.scale(b[1])));
SimpleMatrix V =
A6.mult(A6.scale(b[12]).plus(A4.scale(b[10])).plus(A2.scale(b[8]))).plus(A6.scale(b[6]).plus(A4.scale(b[4])).plus(A2.scale(b[2])).plus(ident.scale(b[0])));
return new Pair<>(U, V);
}
@@ -141,21 +152,76 @@ public class SimpleMatrixUtils {
return SimpleMatrix.identity(Math.min(rows, cols));
}
private static class Pair<A, B> {
private final A m_first;
private final B m_second;
Pair(A first, B second) {
m_first = first;
m_second = second;
}
public A getFirst() {
return m_first;
}
public B getSecond() {
return m_second;
}
/**
* The identy of a square matrix.
*
* @param rows the number of rows (and columns)
* @return the identiy matrix, rows x rows.
*/
public static SimpleMatrix eye(int rows) {
return SimpleMatrix.identity(rows);
}
/**
* Decompose the given matrix using Cholesky Decomposition and return a view of the upper
* triangular matrix (if you want lower triangular see the other overload of this method.) If the
* input matrix is zeros, this will return the zero matrix.
*
* @param src The matrix to decompose.
* @return The decomposed matrix.
* @throws RuntimeException if the matrix could not be decomposed (ie. is not positive
* semidefinite).
*/
public static SimpleMatrix lltDecompose(SimpleMatrix src) {
return lltDecompose(src, false);
}
/**
* Decompose the given matrix using Cholesky Decomposition. If the input matrix is zeros, this
* will return the zero matrix.
*
* @param src The matrix to decompose.
* @param lowerTriangular if we want to decompose to the lower triangular Cholesky matrix.
* @return The decomposed matrix.
* @throws RuntimeException if the matrix could not be decomposed (ie. is not positive
* semidefinite).
*/
@SuppressWarnings("PMD.AvoidThrowingRawExceptionTypes")
public static SimpleMatrix lltDecompose(SimpleMatrix src, boolean lowerTriangular) {
SimpleMatrix temp = src.copy();
CholeskyDecomposition_F64<DMatrixRMaj> chol =
DecompositionFactory_DDRM.chol(temp.numRows(), lowerTriangular);
if (!chol.decompose(temp.getMatrix())) {
// check that the input is not all zeros -- if they are, we special case and return all
// zeros.
var matData = temp.getDDRM().data;
var isZeros = true;
for (double matDatum : matData) {
isZeros &= Math.abs(matDatum) < 1e-6;
}
if (isZeros) {
return new SimpleMatrix(temp.numRows(), temp.numCols());
}
throw new RuntimeException("Cholesky decomposition failed! Input matrix:\n" + src.toString());
}
return SimpleMatrix.wrap(chol.getT(null));
}
/**
* Computes the matrix exponential using Eigen's solver.
*
* @param A the matrix to exponentiate.
* @return the exponential of A.
*/
@SuppressWarnings("ParameterName")
public static SimpleMatrix exp(
SimpleMatrix A) {
SimpleMatrix toReturn = new SimpleMatrix(A.numRows(), A.numRows());
WPIMathJNI.exp(A.getDDRM().getData(), A.numRows(), toReturn.getDDRM().getData());
return toReturn;
}
}

View File

@@ -8,14 +8,171 @@
package edu.wpi.first.wpiutil.math;
import edu.wpi.first.wpiutil.math.numbers.N1;
import edu.wpi.first.wpiutil.math.numbers.N10;
import edu.wpi.first.wpiutil.math.numbers.N2;
import edu.wpi.first.wpiutil.math.numbers.N3;
import edu.wpi.first.wpiutil.math.numbers.N4;
import edu.wpi.first.wpiutil.math.numbers.N5;
import edu.wpi.first.wpiutil.math.numbers.N6;
import edu.wpi.first.wpiutil.math.numbers.N7;
import edu.wpi.first.wpiutil.math.numbers.N8;
import edu.wpi.first.wpiutil.math.numbers.N9;
/**
* A specialization of {@link MatBuilder} for constructing vectors (Nx1 matrices).
*
* @param <N> The dimension of the vector to be constructed.
*/
@SuppressWarnings("PMD.TooManyMethods")
public class VecBuilder<N extends Num> extends MatBuilder<N, N1> {
public VecBuilder(Nat<N> rows) {
super(rows, Nat.N1());
}
private Vector<N> fillVec(double... data) {
return new Vector<>(fill(data));
}
/**
* Returns a 1x1 vector containing the given elements.
*
* @param n1 the first element.
*/
public static Vector<N1> fill(double n1) {
return new VecBuilder<>(Nat.N1()).fillVec(n1);
}
/**
* Returns a 2x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
*/
public static Vector<N2> fill(double n1, double n2) {
return new VecBuilder<>(Nat.N2()).fillVec(n1, n2);
}
/**
* Returns a 3x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
*/
public static Vector<N3> fill(double n1, double n2, double n3) {
return new VecBuilder<>(Nat.N3()).fillVec(n1, n2, n3);
}
/**
* Returns a 4x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
*/
public static Vector<N4> fill(double n1, double n2, double n3, double n4) {
return new VecBuilder<>(Nat.N4()).fillVec(n1, n2, n3, n4);
}
/**
* Returns a 5x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
*/
public static Vector<N5> fill(double n1, double n2, double n3, double n4, double n5) {
return new VecBuilder<>(Nat.N5()).fillVec(n1, n2, n3, n4, n5);
}
/**
* Returns a 6x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
* @param n6 the sixth element.
*/
public static Vector<N6> fill(double n1, double n2, double n3, double n4, double n5,
double n6) {
return new VecBuilder<>(Nat.N6()).fillVec(n1, n2, n3, n4, n5, n6);
}
/**
* Returns a 7x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
* @param n6 the sixth element.
* @param n7 the seventh element.
*/
public static Vector<N7> fill(double n1, double n2, double n3, double n4, double n5,
double n6, double n7) {
return new VecBuilder<>(Nat.N7()).fillVec(n1, n2, n3, n4, n5, n6, n7);
}
/**
* Returns a 8x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
* @param n6 the sixth element.
* @param n7 the seventh element.
* @param n8 the eighth element.
*/
public static Vector<N8> fill(double n1, double n2, double n3, double n4, double n5,
double n6, double n7, double n8) {
return new VecBuilder<>(Nat.N8()).fillVec(n1, n2, n3, n4, n5, n6, n7, n8);
}
/**
* Returns a 9x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
* @param n6 the sixth element.
* @param n7 the seventh element.
* @param n8 the eighth element.
* @param n9 the ninth element.
*/
public static Vector<N9> fill(double n1, double n2, double n3, double n4, double n5,
double n6, double n7, double n8, double n9) {
return new VecBuilder<>(Nat.N9()).fillVec(n1, n2, n3, n4, n5, n6, n7, n8, n9);
}
/**
* Returns a 10x1 vector containing the given elements.
*
* @param n1 the first element.
* @param n2 the second element.
* @param n3 the third element.
* @param n4 the fourth element.
* @param n5 the fifth element.
* @param n6 the sixth element.
* @param n7 the seventh element.
* @param n8 the eighth element.
* @param n9 the ninth element.
* @param n10 the tenth element.
*/
@SuppressWarnings("PMD.ExcessiveParameterList")
public static Vector<N10> fill(double n1, double n2, double n3, double n4, double n5,
double n6, double n7, double n8, double n9, double n10) {
return new VecBuilder<>(Nat.N10()).fillVec(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10);
}
}

View File

@@ -0,0 +1,55 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import org.ejml.simple.SimpleMatrix;
import edu.wpi.first.wpiutil.math.numbers.N1;
/**
* A shape-safe wrapper over Efficient Java Matrix Library (EJML) matrices.
*
* <p>This class is intended to be used alongside the state space library.
*
* @param <R> The number of rows in this matrix.
*/
public class Vector<R extends Num> extends Matrix<R, N1> {
/**
* Constructs an empty zero vector of the given dimensions.
*
* @param rows The number of rows of the vector.
*/
public Vector(Nat<R> rows) {
super(rows, Nat.N1());
}
/**
* Constructs a new {@link Vector} with the given storage.
* Caller should make sure that the provided generic bounds match
* the shape of the provided {@link Vector}.
*
* <p>NOTE:It is not recommended to use this constructor unless the
* {@link SimpleMatrix} API is absolutely necessary due to the desired
* function not being accessible through the {@link Vector} wrapper.
*
* @param storage The {@link SimpleMatrix} to back this vector.
*/
public Vector(SimpleMatrix storage) {
super(storage);
}
/**
* Constructs a new vector with the storage of the supplied matrix.
*
* @param other The {@link Vector} to copy the storage of.
*/
public Vector(Matrix<R, N1> other) {
super(other);
}
}

View File

@@ -0,0 +1,79 @@
/*----------------------------------------------------------------------------*/
/* Copyright (c) 2020 FIRST. All Rights Reserved. */
/* Open Source Software - may be modified and shared by FRC teams. The code */
/* must be accompanied by the FIRST BSD license file in the root directory of */
/* the project. */
/*----------------------------------------------------------------------------*/
package edu.wpi.first.wpiutil.math;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
import edu.wpi.first.wpiutil.RuntimeLoader;
public final class WPIMathJNI {
static boolean libraryLoaded = false;
static RuntimeLoader<WPIMathJNI> loader = null;
public static class Helper {
private static AtomicBoolean extractOnStaticLoad = new AtomicBoolean(true);
public static boolean getExtractOnStaticLoad() {
return extractOnStaticLoad.get();
}
public static void setExtractOnStaticLoad(boolean load) {
extractOnStaticLoad.set(load);
}
}
static {
if (Helper.getExtractOnStaticLoad()) {
try {
loader = new RuntimeLoader<>("wpimathjni", RuntimeLoader.getDefaultExtractionRoot(), WPIMathJNI.class);
loader.loadLibrary();
} catch (IOException ex) {
ex.printStackTrace();
System.exit(1);
}
libraryLoaded = true;
}
}
/**
* Force load the library.
*/
public static synchronized void forceLoad() throws IOException {
if (libraryLoaded) {
return;
}
loader = new RuntimeLoader<>("wpiutiljni", RuntimeLoader.getDefaultExtractionRoot(), WPIMathJNI.class);
loader.loadLibrary();
libraryLoaded = true;
}
/**
* Computes the matrix exp.
*
* @param src Array of elements of the matrix to be exponentiated.
* @param rows how many rows there are.
* @param dst Array where the result will be stored.
*/
public static native void exp(double[] src, int rows, double[] dst);
/**
* Returns true if (A, B) is a stabilizable pair.
*
* <p>(A,B) is stabilizable if and only if the uncontrollable eigenvalues of A, if
* any, have absolute values less than one, where an eigenvalue is
* uncontrollable if rank(lambda * I - A, B) &lt; n where n is number of states.
*
* @param states the number of states of the system.
* @param inputs the number of inputs to the system.
* @param A System matrix.
* @param B Input matrix.
* @return If the system is stabilizable.
*/
public static native boolean isStabilizable(int states, int inputs, double[] A, double[] B);
}