[upstream_utils] Upgrade to Sleipnir 0.6.2 (#8996)

This commit is contained in:
Tyler Veness
2026-06-19 20:16:03 -07:00
committed by GitHub
parent 396b553069
commit 481a586009
11 changed files with 255 additions and 266 deletions

View File

@@ -46,7 +46,7 @@ using small_vector = wpi::util::SmallVector<T>;
def main():
name = "sleipnir"
url = "https://github.com/SleipnirGroup/Sleipnir"
tag = "v0.6.1"
tag = "v0.6.2"
sleipnir = Lib(name, url, tag, copy_upstream_src)
sleipnir.main()

View File

@@ -9,7 +9,7 @@ Subject: [PATCH 02/10] Use wpi::SmallVector
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/include/sleipnir/autodiff/expression.hpp b/include/sleipnir/autodiff/expression.hpp
index f04967c6a84ea76b77cebb7ef8edcb798d03e2a5..361c74702107eadf6f1fbe1682a7ef296abd19a3 100644
index 1e0a57bd78bdd6a41d87539c1e4536aed8bfc42d..8b1c520934580ccc94d0a7e46ab500e6bdd65207 100644
--- a/include/sleipnir/autodiff/expression.hpp
+++ b/include/sleipnir/autodiff/expression.hpp
@@ -34,7 +34,7 @@ struct Expression;
@@ -21,7 +21,7 @@ index f04967c6a84ea76b77cebb7ef8edcb798d03e2a5..361c74702107eadf6f1fbe1682a7ef29
/// Typedef for intrusive shared pointer to Expression.
///
@@ -749,7 +749,7 @@ constexpr void inc_ref_count(Expression<Scalar>* expr) {
@@ -727,7 +727,7 @@ constexpr void inc_ref_count(Expression<Scalar>* expr) {
/// @tparam Scalar Scalar type.
/// @param expr The shared pointer's managed object.
template <typename Scalar>

View File

@@ -8,7 +8,7 @@ Subject: [PATCH 08/10] Suppress GCC 12 warning false positive
1 file changed, 7 insertions(+)
diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp
index 2bdbf11841b9920111b9cb1426d6fd04764040e4..8b84a04780240cffb801e2c6f84e22c0e5246286 100644
index 72559959cfc4cb38e099a2268c66088ea19551f5..284cfd3aafdee76e8623df402e9c1bdd90dab116 100644
--- a/include/sleipnir/autodiff/variable_matrix.hpp
+++ b/include/sleipnir/autodiff/variable_matrix.hpp
@@ -503,6 +503,10 @@ class VariableMatrix : public SleipnirBase {

View File

@@ -8,14 +8,14 @@ Subject: [PATCH 10/10] Use operator() instead of multidimensional array
include/sleipnir/autodiff/hessian.hpp | 4 +-
include/sleipnir/autodiff/jacobian.hpp | 4 +-
include/sleipnir/autodiff/variable.hpp | 8 +-
include/sleipnir/autodiff/variable_block.hpp | 76 ++++-----
include/sleipnir/autodiff/variable_block.hpp | 80 ++++-----
include/sleipnir/autodiff/variable_matrix.hpp | 158 +++++++++---------
include/sleipnir/optimization/ocp.hpp | 14 +-
include/sleipnir/optimization/problem.hpp | 6 +-
7 files changed, 135 insertions(+), 135 deletions(-)
7 files changed, 137 insertions(+), 137 deletions(-)
diff --git a/include/sleipnir/autodiff/hessian.hpp b/include/sleipnir/autodiff/hessian.hpp
index d9d25a00251dd31e446bd0419f0c13e8f2456f2d..39cd798160ff272aeac4b12c92edee3f2a213a22 100644
index 56100d882f9eafe2cdd4773e99a9166c21221ce8..ee42cd3fc2cdc9545f4dbfb1949a644620f30215 100644
--- a/include/sleipnir/autodiff/hessian.hpp
+++ b/include/sleipnir/autodiff/hessian.hpp
@@ -111,9 +111,9 @@ class Hessian {
@@ -87,7 +87,7 @@ index 5f32d7ea16864ad9a4fdcc4b067ef7dff2e9027d..5d7b56de884cdb73fc9f08fcfa1982b2
}
diff --git a/include/sleipnir/autodiff/variable_block.hpp b/include/sleipnir/autodiff/variable_block.hpp
index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d7bf629dc 100644
index 55cafb5d48afff2dbdff07516d85759180ef74c7..9c9e9017487fb0fd19c584c58e5ba3c6c6cd3e7d 100644
--- a/include/sleipnir/autodiff/variable_block.hpp
+++ b/include/sleipnir/autodiff/variable_block.hpp
@@ -49,7 +49,7 @@ class VariableBlock : public SleipnirBase {
@@ -279,7 +279,19 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
int row_slice_length,
Slice col_slice,
int col_slice_length) const {
@@ -453,7 +453,7 @@ class VariableBlock : public SleipnirBase {
@@ -438,9 +438,9 @@ class VariableBlock : public SleipnirBase {
for (int j = 0; j < rhs.cols(); ++j) {
Variable sum{Scalar(0)};
for (int k = 0; k < cols(); ++k) {
- sum += lhs_old_row[k] * rhs[k, j];
+ sum += lhs_old_row[k] * rhs(k, j);
}
- (*this)[i, j] = sum;
+ (*this)(i, j) = sum;
}
}
@@ -454,7 +454,7 @@ class VariableBlock : public SleipnirBase {
VariableBlock<Mat>& operator*=(const ScalarLike auto& rhs) {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -288,7 +300,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -469,7 +469,7 @@ class VariableBlock : public SleipnirBase {
@@ -470,7 +470,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -297,7 +309,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -483,7 +483,7 @@ class VariableBlock : public SleipnirBase {
@@ -484,7 +484,7 @@ class VariableBlock : public SleipnirBase {
VariableBlock<Mat>& operator/=(const ScalarLike auto& rhs) {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -306,7 +318,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -499,7 +499,7 @@ class VariableBlock : public SleipnirBase {
@@ -500,7 +500,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -315,7 +327,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -515,7 +515,7 @@ class VariableBlock : public SleipnirBase {
@@ -516,7 +516,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -324,7 +336,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -531,7 +531,7 @@ class VariableBlock : public SleipnirBase {
@@ -532,7 +532,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -333,7 +345,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -547,7 +547,7 @@ class VariableBlock : public SleipnirBase {
@@ -548,7 +548,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -342,7 +354,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -558,7 +558,7 @@ class VariableBlock : public SleipnirBase {
@@ -559,7 +559,7 @@ class VariableBlock : public SleipnirBase {
// NOLINTNEXTLINE (google-explicit-constructor)
operator Variable<Scalar>() const {
slp_assert(rows() == 1 && cols() == 1);
@@ -351,7 +363,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
/// Returns the transpose of the variable matrix.
@@ -569,7 +569,7 @@ class VariableBlock : public SleipnirBase {
@@ -570,7 +570,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -360,7 +372,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -591,7 +591,7 @@ class VariableBlock : public SleipnirBase {
@@ -592,7 +592,7 @@ class VariableBlock : public SleipnirBase {
/// @param row The row of the element to return.
/// @param col The column of the element to return.
/// @return An element of the variable matrix.
@@ -369,7 +381,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
/// Returns an element of the variable block.
///
@@ -611,7 +611,7 @@ class VariableBlock : public SleipnirBase {
@@ -612,7 +612,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -378,7 +390,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
}
@@ -629,7 +629,7 @@ class VariableBlock : public SleipnirBase {
@@ -630,7 +630,7 @@ class VariableBlock : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -388,7 +400,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
}
diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp
index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c7449d109 100644
index 284cfd3aafdee76e8623df402e9c1bdd90dab116..2f91d1714ccee8271c0186c877dd5f84720da692 100644
--- a/include/sleipnir/autodiff/variable_matrix.hpp
+++ b/include/sleipnir/autodiff/variable_matrix.hpp
@@ -154,7 +154,7 @@ class VariableMatrix : public SleipnirBase {
@@ -571,28 +583,28 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -650,9 +650,9 @@ class VariableMatrix : public SleipnirBase {
@@ -651,9 +651,9 @@ class VariableMatrix : public SleipnirBase {
for (int j = 0; j < rhs.cols(); ++j) {
Variable sum{Scalar(0)};
for (int k = 0; k < cols(); ++k) {
- sum += (*this)[i, k] * rhs[k, j];
+ sum += (*this)(i, k) * rhs(k, j);
- sum += lhs_old_row[k] * rhs[k, j];
+ sum += lhs_old_row[k] * rhs(k, j);
}
- (*this)[i, j] = sum;
+ (*this)(i, j) = sum;
}
}
@@ -666,7 +666,7 @@ class VariableMatrix : public SleipnirBase {
@@ -667,7 +667,7 @@ class VariableMatrix : public SleipnirBase {
VariableMatrix& operator*=(const ScalarLike auto& rhs) {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < rhs.cols(); ++col) {
for (int col = 0; col < cols(); ++col) {
- (*this)[row, col] *= rhs;
+ (*this)(row, col) *= rhs;
}
}
@@ -685,7 +685,7 @@ class VariableMatrix : public SleipnirBase {
@@ -686,7 +686,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -601,7 +613,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -704,7 +704,7 @@ class VariableMatrix : public SleipnirBase {
@@ -705,7 +705,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -610,7 +622,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -723,7 +723,7 @@ class VariableMatrix : public SleipnirBase {
@@ -724,7 +724,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -619,7 +631,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -737,7 +737,7 @@ class VariableMatrix : public SleipnirBase {
@@ -738,7 +738,7 @@ class VariableMatrix : public SleipnirBase {
VariableMatrix& operator/=(const ScalarLike auto& rhs) {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -628,7 +640,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -757,7 +757,7 @@ class VariableMatrix : public SleipnirBase {
@@ -758,7 +758,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -637,7 +649,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -777,7 +777,7 @@ class VariableMatrix : public SleipnirBase {
@@ -778,7 +778,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -646,7 +658,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -797,7 +797,7 @@ class VariableMatrix : public SleipnirBase {
@@ -798,7 +798,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -655,7 +667,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -813,7 +813,7 @@ class VariableMatrix : public SleipnirBase {
@@ -814,7 +814,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -664,7 +676,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -829,7 +829,7 @@ class VariableMatrix : public SleipnirBase {
@@ -830,7 +830,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -673,7 +685,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -849,7 +849,7 @@ class VariableMatrix : public SleipnirBase {
@@ -850,7 +850,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -682,7 +694,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -869,7 +869,7 @@ class VariableMatrix : public SleipnirBase {
@@ -870,7 +870,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -691,7 +703,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -889,7 +889,7 @@ class VariableMatrix : public SleipnirBase {
@@ -890,7 +890,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -700,7 +712,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -905,7 +905,7 @@ class VariableMatrix : public SleipnirBase {
@@ -906,7 +906,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -709,7 +721,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -921,7 +921,7 @@ class VariableMatrix : public SleipnirBase {
@@ -922,7 +922,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -718,7 +730,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -937,7 +937,7 @@ class VariableMatrix : public SleipnirBase {
@@ -938,7 +938,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < result.rows(); ++row) {
for (int col = 0; col < result.cols(); ++col) {
@@ -727,7 +739,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -948,7 +948,7 @@ class VariableMatrix : public SleipnirBase {
@@ -949,7 +949,7 @@ class VariableMatrix : public SleipnirBase {
// NOLINTNEXTLINE (google-explicit-constructor)
operator Variable<Scalar>() const {
slp_assert(rows() == 1 && cols() == 1);
@@ -736,7 +748,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
/// Returns the transpose of the variable matrix.
@@ -959,7 +959,7 @@ class VariableMatrix : public SleipnirBase {
@@ -960,7 +960,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -745,7 +757,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -981,7 +981,7 @@ class VariableMatrix : public SleipnirBase {
@@ -982,7 +982,7 @@ class VariableMatrix : public SleipnirBase {
/// @param row The row of the element to return.
/// @param col The column of the element to return.
/// @return An element of the variable matrix.
@@ -754,7 +766,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
/// Returns an element of the variable matrix.
///
@@ -998,7 +998,7 @@ class VariableMatrix : public SleipnirBase {
@@ -999,7 +999,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -763,7 +775,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -1016,7 +1016,7 @@ class VariableMatrix : public SleipnirBase {
@@ -1017,7 +1017,7 @@ class VariableMatrix : public SleipnirBase {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < cols(); ++col) {
@@ -772,7 +784,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -1269,7 +1269,7 @@ VariableMatrix<Scalar> cwise_reduce(
@@ -1270,7 +1270,7 @@ VariableMatrix<Scalar> cwise_reduce(
for (int row = 0; row < lhs.rows(); ++row) {
for (int col = 0; col < lhs.cols(); ++col) {
@@ -781,7 +793,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -1402,17 +1402,17 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
@@ -1403,17 +1403,17 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
if (A.rows() == 1 && A.cols() == 1) {
// Compute optimal inverse instead of using Eigen's general solver
@@ -804,7 +816,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
VariableMatrix adj_A{{d, -b}, {-c, a}};
auto det_A = a * d - b * c;
@@ -1429,15 +1429,15 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
@@ -1430,15 +1430,15 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
//
// https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%7D%2C+%7Bd%2C+e%2C+f%7D%2C+%7Bg%2C+h%2C+i%7D%7D
@@ -829,7 +841,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
auto ae = a * e;
auto af = a * f;
@@ -1477,22 +1477,22 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
@@ -1478,22 +1478,22 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
//
// https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%2C+d%7D%2C+%7Be%2C+f%2C+g%2C+h%7D%2C+%7Bi%2C+j%2C+k%2C+l%7D%2C+%7Bm%2C+n%2C+o%2C+p%7D%7D
@@ -868,7 +880,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
auto afk = a * f * k;
auto afl = a * f * l;
@@ -1623,14 +1623,14 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
@@ -1624,14 +1624,14 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
MatrixXv eigen_A{A.rows(), A.cols()};
for (int row = 0; row < A.rows(); ++row) {
for (int col = 0; col < A.cols(); ++col) {
@@ -885,7 +897,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
}
}
@@ -1639,7 +1639,7 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
@@ -1640,7 +1640,7 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
VariableMatrix<Scalar> X{detail::empty, A.cols(), B.cols()};
for (int row = 0; row < X.rows(); ++row) {
for (int col = 0; col < X.cols(); ++col) {

View File

@@ -379,11 +379,9 @@ struct Expression {
///
/// @param lhs Left argument to binary operator.
/// @param rhs Right argument to binary operator.
/// @param parent_adjoint Adjoint of parent expression.
/// @return ∂/∂l as a Scalar.
virtual Scalar grad_l([[maybe_unused]] Scalar lhs,
[[maybe_unused]] Scalar rhs,
[[maybe_unused]] Scalar parent_adjoint) const {
[[maybe_unused]] Scalar rhs) const {
return Scalar(0);
}
@@ -391,11 +389,9 @@ struct Expression {
///
/// @param lhs Left argument to binary operator.
/// @param rhs Right argument to binary operator.
/// @param parent_adjoint Adjoint of parent expression.
/// @return ∂/∂r as a Scalar.
virtual Scalar grad_r([[maybe_unused]] Scalar lhs,
[[maybe_unused]] Scalar rhs,
[[maybe_unused]] Scalar parent_adjoint) const {
[[maybe_unused]] Scalar rhs) const {
return Scalar(0);
}
@@ -403,12 +399,10 @@ struct Expression {
///
/// @param lhs Left argument to binary operator.
/// @param rhs Right argument to binary operator.
/// @param parent_adjoint Adjoint of parent expression.
/// @return ∂/∂l as an Expression.
virtual ExpressionPtr<Scalar> grad_expr_l(
[[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
[[maybe_unused]] const ExpressionPtr<Scalar>& parent_adjoint) const {
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const {
return constant_ptr(Scalar(0));
}
@@ -416,12 +410,10 @@ struct Expression {
///
/// @param lhs Left argument to binary operator.
/// @param rhs Right argument to binary operator.
/// @param parent_adjoint Adjoint of parent expression.
/// @return ∂/∂r as an Expression.
virtual ExpressionPtr<Scalar> grad_expr_r(
[[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
[[maybe_unused]] const ExpressionPtr<Scalar>& parent_adjoint) const {
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const {
return constant_ptr(Scalar(0));
}
};
@@ -462,24 +454,20 @@ struct BinaryMinusExpression final : Expression<Scalar> {
std::string_view name() const override { return "binary minus"; }
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
return parent_adjoint;
}
Scalar grad_l(Scalar, Scalar) const override { return this->adjoint; }
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override {
return -parent_adjoint;
}
Scalar grad_r(Scalar, Scalar) const override { return -this->adjoint; }
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint;
const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr;
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return -parent_adjoint;
const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>&) const override {
return -this->adjoint_expr;
}
};
@@ -503,24 +491,20 @@ struct BinaryPlusExpression final : Expression<Scalar> {
std::string_view name() const override { return "binary plus"; }
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
return parent_adjoint;
}
Scalar grad_l(Scalar, Scalar) const override { return this->adjoint; }
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override {
return parent_adjoint;
}
Scalar grad_r(Scalar, Scalar) const override { return this->adjoint; }
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint;
const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr;
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint;
const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr;
}
};
@@ -544,18 +528,18 @@ struct CbrtExpression final : Expression<Scalar> {
std::string_view name() const override { return "cbrt"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::cbrt;
Scalar c = cbrt(x);
return parent_adjoint / (Scalar(3) * c * c);
return this->adjoint / (Scalar(3) * c * c);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
auto c = cbrt(x);
return parent_adjoint / (constant_ptr(Scalar(3)) * c * c);
return this->adjoint_expr / (constant_ptr(Scalar(3)) * c * c);
}
};
@@ -641,24 +625,24 @@ struct DivExpression final : Expression<Scalar> {
std::string_view name() const override { return "division"; }
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override {
return parent_adjoint / rhs;
Scalar grad_l(Scalar, Scalar rhs) const override {
return this->adjoint / rhs;
};
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override {
return parent_adjoint * -lhs / (rhs * rhs);
Scalar grad_r(Scalar lhs, Scalar rhs) const override {
return this->adjoint * -lhs / (rhs * rhs);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>& rhs,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint / rhs;
const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& rhs) const override {
return this->adjoint_expr / rhs;
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& lhs, const ExpressionPtr<Scalar>& rhs,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * -lhs / (rhs * rhs);
const ExpressionPtr<Scalar>& lhs,
const ExpressionPtr<Scalar>& rhs) const override {
return this->adjoint_expr * -lhs / (rhs * rhs);
}
};
@@ -681,28 +665,24 @@ struct MultExpression final : Expression<Scalar> {
std::string_view name() const override { return "multiplication"; }
Scalar grad_l([[maybe_unused]] Scalar lhs, Scalar rhs,
Scalar parent_adjoint) const override {
return parent_adjoint * rhs;
Scalar grad_l([[maybe_unused]] Scalar lhs, Scalar rhs) const override {
return this->adjoint * rhs;
}
Scalar grad_r(Scalar lhs, [[maybe_unused]] Scalar rhs,
Scalar parent_adjoint) const override {
return parent_adjoint * lhs;
Scalar grad_r(Scalar lhs, [[maybe_unused]] Scalar rhs) const override {
return this->adjoint * lhs;
}
ExpressionPtr<Scalar> grad_expr_l(
[[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
const ExpressionPtr<Scalar>& rhs,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * rhs;
const ExpressionPtr<Scalar>& rhs) const override {
return this->adjoint_expr * rhs;
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& lhs,
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * lhs;
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const override {
return this->adjoint_expr * lhs;
}
};
@@ -724,14 +704,12 @@ struct UnaryMinusExpression final : Expression<Scalar> {
std::string_view name() const override { return "unary minus"; }
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
return -parent_adjoint;
}
Scalar grad_l(Scalar, Scalar) const override { return -this->adjoint; }
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return -parent_adjoint;
const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>&) const override {
return -this->adjoint_expr;
}
};
@@ -806,23 +784,23 @@ struct AbsExpression final : Expression<Scalar> {
std::string_view name() const override { return "abs"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
if (x < Scalar(0)) {
return -parent_adjoint;
return -this->adjoint;
} else if (x > Scalar(0)) {
return parent_adjoint;
return this->adjoint;
} else {
return Scalar(0);
}
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
if (x->val < Scalar(0)) {
return -parent_adjoint;
return -this->adjoint_expr;
} else if (x->val > Scalar(0)) {
return parent_adjoint;
return this->adjoint_expr;
} else {
return constant_ptr(Scalar(0));
}
@@ -872,15 +850,15 @@ struct AcosExpression final : Expression<Scalar> {
std::string_view name() const override { return "acos"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::sqrt;
return -parent_adjoint / sqrt(Scalar(1) - x * x);
return -this->adjoint / sqrt(Scalar(1) - x * x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return -parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return -this->adjoint_expr / sqrt(constant_ptr(Scalar(1)) - x * x);
}
};
@@ -926,15 +904,15 @@ struct AsinExpression final : Expression<Scalar> {
std::string_view name() const override { return "asin"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::sqrt;
return parent_adjoint / sqrt(Scalar(1) - x * x);
return this->adjoint / sqrt(Scalar(1) - x * x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr / sqrt(constant_ptr(Scalar(1)) - x * x);
}
};
@@ -981,14 +959,14 @@ struct AtanExpression final : Expression<Scalar> {
std::string_view name() const override { return "atan"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
return parent_adjoint / (Scalar(1) + x * x);
Scalar grad_l(Scalar x, Scalar) const override {
return this->adjoint / (Scalar(1) + x * x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint / (constant_ptr(Scalar(1)) + x * x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr / (constant_ptr(Scalar(1)) + x * x);
}
};
@@ -1037,24 +1015,24 @@ struct Atan2Expression final : Expression<Scalar> {
std::string_view name() const override { return "atan2"; }
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override {
return parent_adjoint * x / (y * y + x * x);
Scalar grad_l(Scalar y, Scalar x) const override {
return this->adjoint * x / (y * y + x * x);
}
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override {
return parent_adjoint * -y / (y * y + x * x);
Scalar grad_r(Scalar y, Scalar x) const override {
return this->adjoint * -y / (y * y + x * x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& y, const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * x / (y * y + x * x);
const ExpressionPtr<Scalar>& y,
const ExpressionPtr<Scalar>& x) const override {
return this->adjoint_expr * x / (y * y + x * x);
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& y, const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * -y / (y * y + x * x);
const ExpressionPtr<Scalar>& y,
const ExpressionPtr<Scalar>& x) const override {
return this->adjoint_expr * -y / (y * y + x * x);
}
};
@@ -1105,15 +1083,15 @@ struct CosExpression final : Expression<Scalar> {
std::string_view name() const override { return "cos"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::sin;
return parent_adjoint * -sin(x);
return this->adjoint * -sin(x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * -sin(x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr * -sin(x);
}
};
@@ -1159,15 +1137,15 @@ struct CoshExpression final : Expression<Scalar> {
std::string_view name() const override { return "cosh"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::sinh;
return parent_adjoint * sinh(x);
return this->adjoint * sinh(x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * sinh(x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr * sinh(x);
}
};
@@ -1213,16 +1191,15 @@ struct ErfExpression final : Expression<Scalar> {
std::string_view name() const override { return "erf"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::exp;
return parent_adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) *
exp(-x * x);
return this->adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) * exp(-x * x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint *
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr *
constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
}
};
@@ -1270,15 +1247,15 @@ struct ExpExpression final : Expression<Scalar> {
std::string_view name() const override { return "exp"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::exp;
return parent_adjoint * exp(x);
return this->adjoint * exp(x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * exp(x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr * exp(x);
}
};
@@ -1330,26 +1307,26 @@ struct HypotExpression final : Expression<Scalar> {
std::string_view name() const override { return "hypot"; }
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar y) const override {
using std::hypot;
return parent_adjoint * x / hypot(x, y);
return this->adjoint * x / hypot(x, y);
}
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override {
Scalar grad_r(Scalar x, Scalar y) const override {
using std::hypot;
return parent_adjoint * y / hypot(x, y);
return this->adjoint * y / hypot(x, y);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>& y,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * x / hypot(x, y);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>& y) const override {
return this->adjoint_expr * x / hypot(x, y);
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>& y,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * y / hypot(x, y);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>& y) const override {
return this->adjoint_expr * y / hypot(x, y);
}
};
@@ -1399,14 +1376,12 @@ struct LogExpression final : Expression<Scalar> {
std::string_view name() const override { return "log"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
return parent_adjoint / x;
}
Scalar grad_l(Scalar x, Scalar) const override { return this->adjoint / x; }
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint / x;
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr / x;
}
};
@@ -1453,14 +1428,14 @@ struct Log10Expression final : Expression<Scalar> {
std::string_view name() const override { return "log10"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
return parent_adjoint / (Scalar(std::numbers::ln10) * x);
Scalar grad_l(Scalar x, Scalar) const override {
return this->adjoint / (Scalar(std::numbers::ln10) * x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint / (constant_ptr(Scalar(std::numbers::ln10)) * x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr / (constant_ptr(Scalar(std::numbers::ln10)) * x);
}
};
@@ -1510,37 +1485,37 @@ struct MaxExpression final : Expression<Scalar> {
std::string_view name() const override { return "max"; }
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar a, Scalar b) const override {
if (a >= b) {
return parent_adjoint;
return this->adjoint;
} else {
return Scalar(0);
}
}
Scalar grad_r(Scalar a, Scalar b, Scalar parent_adjoint) const override {
Scalar grad_r(Scalar a, Scalar b) const override {
if (b > a) {
return parent_adjoint;
return this->adjoint;
} else {
return Scalar(0);
}
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& a,
const ExpressionPtr<Scalar>& b) const override {
if (a->val >= b->val) {
return parent_adjoint;
return this->adjoint_expr;
} else {
return constant_ptr(Scalar(0));
}
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& a,
const ExpressionPtr<Scalar>& b) const override {
if (b->val > a->val) {
return parent_adjoint;
return this->adjoint_expr;
} else {
return constant_ptr(Scalar(0));
}
@@ -1589,38 +1564,38 @@ struct MinExpression final : Expression<Scalar> {
std::string_view name() const override { return "min"; }
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar a, Scalar b) const override {
if (a <= b) {
return parent_adjoint;
return this->adjoint;
} else {
return Scalar(0);
}
}
Scalar grad_r([[maybe_unused]] Scalar a, [[maybe_unused]] Scalar b,
Scalar parent_adjoint) const override {
Scalar grad_r([[maybe_unused]] Scalar a,
[[maybe_unused]] Scalar b) const override {
if (b < a) {
return parent_adjoint;
return this->adjoint;
} else {
return Scalar(0);
}
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& a,
const ExpressionPtr<Scalar>& b) const override {
if (a->val <= b->val) {
return parent_adjoint;
return this->adjoint_expr;
} else {
return constant_ptr(Scalar(0));
}
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& a,
const ExpressionPtr<Scalar>& b) const override {
if (b->val < a->val) {
return parent_adjoint;
return this->adjoint_expr;
} else {
return constant_ptr(Scalar(0));
}
@@ -1672,14 +1647,12 @@ struct PowExpression final : Expression<Scalar> {
std::string_view name() const override { return "pow"; }
Scalar grad_l(Scalar base, Scalar power,
Scalar parent_adjoint) const override {
Scalar grad_l(Scalar base, Scalar power) const override {
using std::pow;
return parent_adjoint * pow(base, power - Scalar(1)) * power;
return this->adjoint * pow(base, power - Scalar(1)) * power;
}
Scalar grad_r(Scalar base, Scalar power,
Scalar parent_adjoint) const override {
Scalar grad_r(Scalar base, Scalar power) const override {
using std::log;
using std::pow;
@@ -1687,25 +1660,26 @@ struct PowExpression final : Expression<Scalar> {
if (base == Scalar(0)) {
return Scalar(0);
} else {
return parent_adjoint * pow(base, power) * log(base);
return this->adjoint * pow(base, power) * log(base);
}
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& base, const ExpressionPtr<Scalar>& power,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * pow(base, power - constant_ptr(Scalar(1))) * power;
const ExpressionPtr<Scalar>& base,
const ExpressionPtr<Scalar>& power) const override {
return this->adjoint_expr * pow(base, power - constant_ptr(Scalar(1))) *
power;
}
ExpressionPtr<Scalar> grad_expr_r(
const ExpressionPtr<Scalar>& base, const ExpressionPtr<Scalar>& power,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& base,
const ExpressionPtr<Scalar>& power) const override {
// Since x log(x) -> 0 as x -> 0
if (base->val == Scalar(0)) {
// Return zero
return base;
} else {
return parent_adjoint * pow(base, power) * log(base);
return this->adjoint_expr * pow(base, power) * log(base);
}
}
};
@@ -1821,15 +1795,15 @@ struct SinExpression final : Expression<Scalar> {
std::string_view name() const override { return "sin"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::cos;
return parent_adjoint * cos(x);
return this->adjoint * cos(x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * cos(x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr * cos(x);
}
};
@@ -1876,15 +1850,15 @@ struct SinhExpression final : Expression<Scalar> {
std::string_view name() const override { return "sinh"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::cosh;
return parent_adjoint * cosh(x);
return this->adjoint * cosh(x);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint * cosh(x);
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr * cosh(x);
}
};
@@ -1931,15 +1905,15 @@ struct SqrtExpression final : Expression<Scalar> {
std::string_view name() const override { return "sqrt"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::sqrt;
return parent_adjoint / (Scalar(2) * sqrt(x));
return this->adjoint / (Scalar(2) * sqrt(x));
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
return parent_adjoint / (constant_ptr(Scalar(2)) * sqrt(x));
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
return this->adjoint_expr / (constant_ptr(Scalar(2)) * sqrt(x));
}
};
@@ -1987,18 +1961,18 @@ struct TanExpression final : Expression<Scalar> {
std::string_view name() const override { return "tan"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::cos;
auto c = cos(x);
return parent_adjoint / (c * c);
return this->adjoint / (c * c);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
auto c = cos(x);
return parent_adjoint / (c * c);
return this->adjoint_expr / (c * c);
}
};
@@ -2045,18 +2019,18 @@ struct TanhExpression final : Expression<Scalar> {
std::string_view name() const override { return "tanh"; }
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
Scalar grad_l(Scalar x, Scalar) const override {
using std::cosh;
auto c = cosh(x);
return parent_adjoint / (c * c);
return this->adjoint / (c * c);
}
ExpressionPtr<Scalar> grad_expr_l(
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
const ExpressionPtr<Scalar>& parent_adjoint) const override {
const ExpressionPtr<Scalar>& x,
const ExpressionPtr<Scalar>&) const override {
auto c = cosh(x);
return parent_adjoint / (c * c);
return this->adjoint_expr / (c * c);
}
};

View File

@@ -133,11 +133,11 @@ void append_triplets(
if (lhs != nullptr) {
if (rhs != nullptr) {
// Binary operator
lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
lhs->adjoint += node->grad_l(lhs->val, rhs->val);
rhs->adjoint += node->grad_r(lhs->val, rhs->val);
} else {
// Unary operator
lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
lhs->adjoint += node->grad_l(lhs->val, Scalar(0));
}
}
}

View File

@@ -45,7 +45,7 @@ class Hessian {
: m_variables{detail::gradient_tree(
detail::topological_sort(variable.expr), wrt)},
m_wrt{std::move(wrt)} {
slp_assert(wrt.cols() == 1);
slp_assert(m_wrt.cols() == 1);
for (auto& variable : m_variables) {
m_top_lists.emplace_back(detail::topological_sort(variable.expr));

View File

@@ -434,10 +434,11 @@ class VariableBlock : public SleipnirBase {
slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
for (int i = 0; i < rows(); ++i) {
Mat lhs_old_row = row(i);
for (int j = 0; j < rhs.cols(); ++j) {
Variable sum{Scalar(0)};
for (int k = 0; k < cols(); ++k) {
sum += (*this)(i, k) * rhs(k, j);
sum += lhs_old_row[k] * rhs(k, j);
}
(*this)(i, j) = sum;
}

View File

@@ -647,10 +647,11 @@ class VariableMatrix : public SleipnirBase {
slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
for (int i = 0; i < rows(); ++i) {
VariableMatrix lhs_old_row = row(i);
for (int j = 0; j < rhs.cols(); ++j) {
Variable sum{Scalar(0)};
for (int k = 0; k < cols(); ++k) {
sum += (*this)(i, k) * rhs(k, j);
sum += lhs_old_row[k] * rhs(k, j);
}
(*this)(i, j) = sum;
}
@@ -665,7 +666,7 @@ class VariableMatrix : public SleipnirBase {
/// @return Result of multiplication.
VariableMatrix& operator*=(const ScalarLike auto& rhs) {
for (int row = 0; row < rows(); ++row) {
for (int col = 0; col < rhs.cols(); ++col) {
for (int col = 0; col < cols(); ++col) {
(*this)(row, col) *= rhs;
}
}
@@ -1686,11 +1687,11 @@ VariableMatrix<Scalar> gradient_tree(const ExpressionGraph<Scalar>& top_list,
if (lhs != nullptr) {
if (rhs != nullptr) {
// Binary operator
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs);
rhs->adjoint_expr += node->grad_expr_r(lhs, rhs);
} else {
// Unary operator
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs);
}
}
}

View File

@@ -501,6 +501,7 @@ ExitStatus interior_point(
// Loop until a step is accepted
while (1) {
trial_x = x + α * step.p_x;
trial_c_i = matrices.c_i(trial_x);
if (options.feasible_ipm && c_i.cwiseGreater(Scalar(0)).all()) {
// If the inequality constraints are all feasible, prevent them from
// becoming infeasible again.
@@ -515,7 +516,6 @@ ExitStatus interior_point(
trial_f = matrices.f(trial_x);
trial_c_e = matrices.c_e(trial_x);
trial_c_i = matrices.c_i(trial_x);
// If f(xₖ + αpₖˣ), cₑ(xₖ + αpₖˣ), or cᵢ(xₖ + αpₖˣ) aren't finite, reduce
// step size immediately

View File

@@ -3,6 +3,7 @@
#pragma once
#include <algorithm>
#include <utility>
#include <Eigen/Core>
#include <Eigen/SparseCore>
@@ -42,8 +43,8 @@ struct ProblemScaling {
/// @param f Cost scaling factor d_f.
/// @param c_e Equality constraint scaling factors d_cₑ.
/// @param c_i Inequality constraint scaling factors d_cᵢ.
ProblemScaling(Scalar f, const DenseVector& c_e, const DenseVector& c_i)
: f{f}, c_e{c_e}, c_i{c_i} {}
ProblemScaling(Scalar f, DenseVector c_e, DenseVector c_i)
: f{f}, c_e{std::move(c_e)}, c_i{std::move(c_i)} {}
/// Computes Newton problem scaling.
///