mirror of
https://github.com/wpilibsuite/allwpilib
synced 2026-07-04 03:11:43 +00:00
[upstream_utils] Upgrade to Sleipnir 0.6.2 (#8996)
This commit is contained in:
@@ -46,7 +46,7 @@ using small_vector = wpi::util::SmallVector<T>;
|
||||
def main():
|
||||
name = "sleipnir"
|
||||
url = "https://github.com/SleipnirGroup/Sleipnir"
|
||||
tag = "v0.6.1"
|
||||
tag = "v0.6.2"
|
||||
|
||||
sleipnir = Lib(name, url, tag, copy_upstream_src)
|
||||
sleipnir.main()
|
||||
|
||||
@@ -9,7 +9,7 @@ Subject: [PATCH 02/10] Use wpi::SmallVector
|
||||
2 files changed, 4 insertions(+), 4 deletions(-)
|
||||
|
||||
diff --git a/include/sleipnir/autodiff/expression.hpp b/include/sleipnir/autodiff/expression.hpp
|
||||
index f04967c6a84ea76b77cebb7ef8edcb798d03e2a5..361c74702107eadf6f1fbe1682a7ef296abd19a3 100644
|
||||
index 1e0a57bd78bdd6a41d87539c1e4536aed8bfc42d..8b1c520934580ccc94d0a7e46ab500e6bdd65207 100644
|
||||
--- a/include/sleipnir/autodiff/expression.hpp
|
||||
+++ b/include/sleipnir/autodiff/expression.hpp
|
||||
@@ -34,7 +34,7 @@ struct Expression;
|
||||
@@ -21,7 +21,7 @@ index f04967c6a84ea76b77cebb7ef8edcb798d03e2a5..361c74702107eadf6f1fbe1682a7ef29
|
||||
|
||||
/// Typedef for intrusive shared pointer to Expression.
|
||||
///
|
||||
@@ -749,7 +749,7 @@ constexpr void inc_ref_count(Expression<Scalar>* expr) {
|
||||
@@ -727,7 +727,7 @@ constexpr void inc_ref_count(Expression<Scalar>* expr) {
|
||||
/// @tparam Scalar Scalar type.
|
||||
/// @param expr The shared pointer's managed object.
|
||||
template <typename Scalar>
|
||||
|
||||
@@ -8,7 +8,7 @@ Subject: [PATCH 08/10] Suppress GCC 12 warning false positive
|
||||
1 file changed, 7 insertions(+)
|
||||
|
||||
diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp
|
||||
index 2bdbf11841b9920111b9cb1426d6fd04764040e4..8b84a04780240cffb801e2c6f84e22c0e5246286 100644
|
||||
index 72559959cfc4cb38e099a2268c66088ea19551f5..284cfd3aafdee76e8623df402e9c1bdd90dab116 100644
|
||||
--- a/include/sleipnir/autodiff/variable_matrix.hpp
|
||||
+++ b/include/sleipnir/autodiff/variable_matrix.hpp
|
||||
@@ -503,6 +503,10 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
@@ -8,14 +8,14 @@ Subject: [PATCH 10/10] Use operator() instead of multidimensional array
|
||||
include/sleipnir/autodiff/hessian.hpp | 4 +-
|
||||
include/sleipnir/autodiff/jacobian.hpp | 4 +-
|
||||
include/sleipnir/autodiff/variable.hpp | 8 +-
|
||||
include/sleipnir/autodiff/variable_block.hpp | 76 ++++-----
|
||||
include/sleipnir/autodiff/variable_block.hpp | 80 ++++-----
|
||||
include/sleipnir/autodiff/variable_matrix.hpp | 158 +++++++++---------
|
||||
include/sleipnir/optimization/ocp.hpp | 14 +-
|
||||
include/sleipnir/optimization/problem.hpp | 6 +-
|
||||
7 files changed, 135 insertions(+), 135 deletions(-)
|
||||
7 files changed, 137 insertions(+), 137 deletions(-)
|
||||
|
||||
diff --git a/include/sleipnir/autodiff/hessian.hpp b/include/sleipnir/autodiff/hessian.hpp
|
||||
index d9d25a00251dd31e446bd0419f0c13e8f2456f2d..39cd798160ff272aeac4b12c92edee3f2a213a22 100644
|
||||
index 56100d882f9eafe2cdd4773e99a9166c21221ce8..ee42cd3fc2cdc9545f4dbfb1949a644620f30215 100644
|
||||
--- a/include/sleipnir/autodiff/hessian.hpp
|
||||
+++ b/include/sleipnir/autodiff/hessian.hpp
|
||||
@@ -111,9 +111,9 @@ class Hessian {
|
||||
@@ -87,7 +87,7 @@ index 5f32d7ea16864ad9a4fdcc4b067ef7dff2e9027d..5d7b56de884cdb73fc9f08fcfa1982b2
|
||||
}
|
||||
|
||||
diff --git a/include/sleipnir/autodiff/variable_block.hpp b/include/sleipnir/autodiff/variable_block.hpp
|
||||
index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d7bf629dc 100644
|
||||
index 55cafb5d48afff2dbdff07516d85759180ef74c7..9c9e9017487fb0fd19c584c58e5ba3c6c6cd3e7d 100644
|
||||
--- a/include/sleipnir/autodiff/variable_block.hpp
|
||||
+++ b/include/sleipnir/autodiff/variable_block.hpp
|
||||
@@ -49,7 +49,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -279,7 +279,19 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
int row_slice_length,
|
||||
Slice col_slice,
|
||||
int col_slice_length) const {
|
||||
@@ -453,7 +453,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -438,9 +438,9 @@ class VariableBlock : public SleipnirBase {
|
||||
for (int j = 0; j < rhs.cols(); ++j) {
|
||||
Variable sum{Scalar(0)};
|
||||
for (int k = 0; k < cols(); ++k) {
|
||||
- sum += lhs_old_row[k] * rhs[k, j];
|
||||
+ sum += lhs_old_row[k] * rhs(k, j);
|
||||
}
|
||||
- (*this)[i, j] = sum;
|
||||
+ (*this)(i, j) = sum;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,7 +454,7 @@ class VariableBlock : public SleipnirBase {
|
||||
VariableBlock<Mat>& operator*=(const ScalarLike auto& rhs) {
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -288,7 +300,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,7 +469,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -470,7 +470,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -297,7 +309,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -483,7 +483,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -484,7 +484,7 @@ class VariableBlock : public SleipnirBase {
|
||||
VariableBlock<Mat>& operator/=(const ScalarLike auto& rhs) {
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -306,7 +318,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -499,7 +499,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -500,7 +500,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -315,7 +327,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,7 +515,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -516,7 +516,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -324,7 +336,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -531,7 +531,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -532,7 +532,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -333,7 +345,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -547,7 +547,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -548,7 +548,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -342,7 +354,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -558,7 +558,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -559,7 +559,7 @@ class VariableBlock : public SleipnirBase {
|
||||
// NOLINTNEXTLINE (google-explicit-constructor)
|
||||
operator Variable<Scalar>() const {
|
||||
slp_assert(rows() == 1 && cols() == 1);
|
||||
@@ -351,7 +363,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
|
||||
/// Returns the transpose of the variable matrix.
|
||||
@@ -569,7 +569,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -570,7 +570,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -360,7 +372,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -591,7 +591,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -592,7 +592,7 @@ class VariableBlock : public SleipnirBase {
|
||||
/// @param row The row of the element to return.
|
||||
/// @param col The column of the element to return.
|
||||
/// @return An element of the variable matrix.
|
||||
@@ -369,7 +381,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
|
||||
/// Returns an element of the variable block.
|
||||
///
|
||||
@@ -611,7 +611,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -612,7 +612,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -378,7 +390,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
}
|
||||
|
||||
@@ -629,7 +629,7 @@ class VariableBlock : public SleipnirBase {
|
||||
@@ -630,7 +630,7 @@ class VariableBlock : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -388,7 +400,7 @@ index 4018606df45941b578c861caf934495f8c9e368e..cf554832b82adb17b4b1d7b56842a77d
|
||||
}
|
||||
|
||||
diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp
|
||||
index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c7449d109 100644
|
||||
index 284cfd3aafdee76e8623df402e9c1bdd90dab116..2f91d1714ccee8271c0186c877dd5f84720da692 100644
|
||||
--- a/include/sleipnir/autodiff/variable_matrix.hpp
|
||||
+++ b/include/sleipnir/autodiff/variable_matrix.hpp
|
||||
@@ -154,7 +154,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -571,28 +583,28 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -650,9 +650,9 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -651,9 +651,9 @@ class VariableMatrix : public SleipnirBase {
|
||||
for (int j = 0; j < rhs.cols(); ++j) {
|
||||
Variable sum{Scalar(0)};
|
||||
for (int k = 0; k < cols(); ++k) {
|
||||
- sum += (*this)[i, k] * rhs[k, j];
|
||||
+ sum += (*this)(i, k) * rhs(k, j);
|
||||
- sum += lhs_old_row[k] * rhs[k, j];
|
||||
+ sum += lhs_old_row[k] * rhs(k, j);
|
||||
}
|
||||
- (*this)[i, j] = sum;
|
||||
+ (*this)(i, j) = sum;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -666,7 +666,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -667,7 +667,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
VariableMatrix& operator*=(const ScalarLike auto& rhs) {
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < rhs.cols(); ++col) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
- (*this)[row, col] *= rhs;
|
||||
+ (*this)(row, col) *= rhs;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -685,7 +685,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -686,7 +686,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -601,7 +613,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -704,7 +704,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -705,7 +705,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -610,7 +622,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -723,7 +723,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -724,7 +724,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -619,7 +631,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -737,7 +737,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -738,7 +738,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
VariableMatrix& operator/=(const ScalarLike auto& rhs) {
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -628,7 +640,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -757,7 +757,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -758,7 +758,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -637,7 +649,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -777,7 +777,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -778,7 +778,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -646,7 +658,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -797,7 +797,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -798,7 +798,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -655,7 +667,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -813,7 +813,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -814,7 +814,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -664,7 +676,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -829,7 +829,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -830,7 +830,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -673,7 +685,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -849,7 +849,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -850,7 +850,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -682,7 +694,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -869,7 +869,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -870,7 +870,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -691,7 +703,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -889,7 +889,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -890,7 +890,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -700,7 +712,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -905,7 +905,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -906,7 +906,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -709,7 +721,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -921,7 +921,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -922,7 +922,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -718,7 +730,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -937,7 +937,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -938,7 +938,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < result.rows(); ++row) {
|
||||
for (int col = 0; col < result.cols(); ++col) {
|
||||
@@ -727,7 +739,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -948,7 +948,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -949,7 +949,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
// NOLINTNEXTLINE (google-explicit-constructor)
|
||||
operator Variable<Scalar>() const {
|
||||
slp_assert(rows() == 1 && cols() == 1);
|
||||
@@ -736,7 +748,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
|
||||
/// Returns the transpose of the variable matrix.
|
||||
@@ -959,7 +959,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -960,7 +960,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -745,7 +757,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -981,7 +981,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -982,7 +982,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
/// @param row The row of the element to return.
|
||||
/// @param col The column of the element to return.
|
||||
/// @return An element of the variable matrix.
|
||||
@@ -754,7 +766,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
|
||||
/// Returns an element of the variable matrix.
|
||||
///
|
||||
@@ -998,7 +998,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -999,7 +999,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -763,7 +775,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1016,7 +1016,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
@@ -1017,7 +1017,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
@@ -772,7 +784,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1269,7 +1269,7 @@ VariableMatrix<Scalar> cwise_reduce(
|
||||
@@ -1270,7 +1270,7 @@ VariableMatrix<Scalar> cwise_reduce(
|
||||
|
||||
for (int row = 0; row < lhs.rows(); ++row) {
|
||||
for (int col = 0; col < lhs.cols(); ++col) {
|
||||
@@ -781,7 +793,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1402,17 +1402,17 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
@@ -1403,17 +1403,17 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
|
||||
if (A.rows() == 1 && A.cols() == 1) {
|
||||
// Compute optimal inverse instead of using Eigen's general solver
|
||||
@@ -804,7 +816,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
|
||||
VariableMatrix adj_A{{d, -b}, {-c, a}};
|
||||
auto det_A = a * d - b * c;
|
||||
@@ -1429,15 +1429,15 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
@@ -1430,15 +1430,15 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
//
|
||||
// https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%7D%2C+%7Bd%2C+e%2C+f%7D%2C+%7Bg%2C+h%2C+i%7D%7D
|
||||
|
||||
@@ -829,7 +841,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
|
||||
auto ae = a * e;
|
||||
auto af = a * f;
|
||||
@@ -1477,22 +1477,22 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
@@ -1478,22 +1478,22 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
//
|
||||
// https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%2C+d%7D%2C+%7Be%2C+f%2C+g%2C+h%7D%2C+%7Bi%2C+j%2C+k%2C+l%7D%2C+%7Bm%2C+n%2C+o%2C+p%7D%7D
|
||||
|
||||
@@ -868,7 +880,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
|
||||
auto afk = a * f * k;
|
||||
auto afl = a * f * l;
|
||||
@@ -1623,14 +1623,14 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
@@ -1624,14 +1624,14 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
MatrixXv eigen_A{A.rows(), A.cols()};
|
||||
for (int row = 0; row < A.rows(); ++row) {
|
||||
for (int col = 0; col < A.cols(); ++col) {
|
||||
@@ -885,7 +897,7 @@ index 8b84a04780240cffb801e2c6f84e22c0e5246286..d82b90191f6fb3f30460fecb2653655c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1639,7 +1639,7 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
@@ -1640,7 +1640,7 @@ VariableMatrix<Scalar> solve(const VariableMatrix<Scalar>& A,
|
||||
VariableMatrix<Scalar> X{detail::empty, A.cols(), B.cols()};
|
||||
for (int row = 0; row < X.rows(); ++row) {
|
||||
for (int col = 0; col < X.cols(); ++col) {
|
||||
|
||||
@@ -379,11 +379,9 @@ struct Expression {
|
||||
///
|
||||
/// @param lhs Left argument to binary operator.
|
||||
/// @param rhs Right argument to binary operator.
|
||||
/// @param parent_adjoint Adjoint of parent expression.
|
||||
/// @return ∂/∂l as a Scalar.
|
||||
virtual Scalar grad_l([[maybe_unused]] Scalar lhs,
|
||||
[[maybe_unused]] Scalar rhs,
|
||||
[[maybe_unused]] Scalar parent_adjoint) const {
|
||||
[[maybe_unused]] Scalar rhs) const {
|
||||
return Scalar(0);
|
||||
}
|
||||
|
||||
@@ -391,11 +389,9 @@ struct Expression {
|
||||
///
|
||||
/// @param lhs Left argument to binary operator.
|
||||
/// @param rhs Right argument to binary operator.
|
||||
/// @param parent_adjoint Adjoint of parent expression.
|
||||
/// @return ∂/∂r as a Scalar.
|
||||
virtual Scalar grad_r([[maybe_unused]] Scalar lhs,
|
||||
[[maybe_unused]] Scalar rhs,
|
||||
[[maybe_unused]] Scalar parent_adjoint) const {
|
||||
[[maybe_unused]] Scalar rhs) const {
|
||||
return Scalar(0);
|
||||
}
|
||||
|
||||
@@ -403,12 +399,10 @@ struct Expression {
|
||||
///
|
||||
/// @param lhs Left argument to binary operator.
|
||||
/// @param rhs Right argument to binary operator.
|
||||
/// @param parent_adjoint Adjoint of parent expression.
|
||||
/// @return ∂/∂l as an Expression.
|
||||
virtual ExpressionPtr<Scalar> grad_expr_l(
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& parent_adjoint) const {
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
|
||||
@@ -416,12 +410,10 @@ struct Expression {
|
||||
///
|
||||
/// @param lhs Left argument to binary operator.
|
||||
/// @param rhs Right argument to binary operator.
|
||||
/// @param parent_adjoint Adjoint of parent expression.
|
||||
/// @return ∂/∂r as an Expression.
|
||||
virtual ExpressionPtr<Scalar> grad_expr_r(
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& parent_adjoint) const {
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
};
|
||||
@@ -462,24 +454,20 @@ struct BinaryMinusExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "binary minus"; }
|
||||
|
||||
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint;
|
||||
}
|
||||
Scalar grad_l(Scalar, Scalar) const override { return this->adjoint; }
|
||||
|
||||
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override {
|
||||
return -parent_adjoint;
|
||||
}
|
||||
Scalar grad_r(Scalar, Scalar) const override { return -this->adjoint; }
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint;
|
||||
const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr;
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return -parent_adjoint;
|
||||
const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return -this->adjoint_expr;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -503,24 +491,20 @@ struct BinaryPlusExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "binary plus"; }
|
||||
|
||||
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint;
|
||||
}
|
||||
Scalar grad_l(Scalar, Scalar) const override { return this->adjoint; }
|
||||
|
||||
Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint;
|
||||
}
|
||||
Scalar grad_r(Scalar, Scalar) const override { return this->adjoint; }
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint;
|
||||
const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr;
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint;
|
||||
const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -544,18 +528,18 @@ struct CbrtExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "cbrt"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::cbrt;
|
||||
|
||||
Scalar c = cbrt(x);
|
||||
return parent_adjoint / (Scalar(3) * c * c);
|
||||
return this->adjoint / (Scalar(3) * c * c);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
auto c = cbrt(x);
|
||||
return parent_adjoint / (constant_ptr(Scalar(3)) * c * c);
|
||||
return this->adjoint_expr / (constant_ptr(Scalar(3)) * c * c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -641,24 +625,24 @@ struct DivExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "division"; }
|
||||
|
||||
Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint / rhs;
|
||||
Scalar grad_l(Scalar, Scalar rhs) const override {
|
||||
return this->adjoint / rhs;
|
||||
};
|
||||
|
||||
Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint * -lhs / (rhs * rhs);
|
||||
Scalar grad_r(Scalar lhs, Scalar rhs) const override {
|
||||
return this->adjoint * -lhs / (rhs * rhs);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>& rhs,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint / rhs;
|
||||
const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& rhs) const override {
|
||||
return this->adjoint_expr / rhs;
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& lhs, const ExpressionPtr<Scalar>& rhs,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * -lhs / (rhs * rhs);
|
||||
const ExpressionPtr<Scalar>& lhs,
|
||||
const ExpressionPtr<Scalar>& rhs) const override {
|
||||
return this->adjoint_expr * -lhs / (rhs * rhs);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -681,28 +665,24 @@ struct MultExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "multiplication"; }
|
||||
|
||||
Scalar grad_l([[maybe_unused]] Scalar lhs, Scalar rhs,
|
||||
Scalar parent_adjoint) const override {
|
||||
return parent_adjoint * rhs;
|
||||
Scalar grad_l([[maybe_unused]] Scalar lhs, Scalar rhs) const override {
|
||||
return this->adjoint * rhs;
|
||||
}
|
||||
|
||||
Scalar grad_r(Scalar lhs, [[maybe_unused]] Scalar rhs,
|
||||
Scalar parent_adjoint) const override {
|
||||
return parent_adjoint * lhs;
|
||||
Scalar grad_r(Scalar lhs, [[maybe_unused]] Scalar rhs) const override {
|
||||
return this->adjoint * lhs;
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& lhs,
|
||||
const ExpressionPtr<Scalar>& rhs,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * rhs;
|
||||
const ExpressionPtr<Scalar>& rhs) const override {
|
||||
return this->adjoint_expr * rhs;
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& lhs,
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * lhs;
|
||||
[[maybe_unused]] const ExpressionPtr<Scalar>& rhs) const override {
|
||||
return this->adjoint_expr * lhs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -724,14 +704,12 @@ struct UnaryMinusExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "unary minus"; }
|
||||
|
||||
Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override {
|
||||
return -parent_adjoint;
|
||||
}
|
||||
Scalar grad_l(Scalar, Scalar) const override { return -this->adjoint; }
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>&, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return -parent_adjoint;
|
||||
const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return -this->adjoint_expr;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -806,23 +784,23 @@ struct AbsExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "abs"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
if (x < Scalar(0)) {
|
||||
return -parent_adjoint;
|
||||
return -this->adjoint;
|
||||
} else if (x > Scalar(0)) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint;
|
||||
} else {
|
||||
return Scalar(0);
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
if (x->val < Scalar(0)) {
|
||||
return -parent_adjoint;
|
||||
return -this->adjoint_expr;
|
||||
} else if (x->val > Scalar(0)) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint_expr;
|
||||
} else {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
@@ -872,15 +850,15 @@ struct AcosExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "acos"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::sqrt;
|
||||
return -parent_adjoint / sqrt(Scalar(1) - x * x);
|
||||
return -this->adjoint / sqrt(Scalar(1) - x * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return -parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return -this->adjoint_expr / sqrt(constant_ptr(Scalar(1)) - x * x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -926,15 +904,15 @@ struct AsinExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "asin"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::sqrt;
|
||||
return parent_adjoint / sqrt(Scalar(1) - x * x);
|
||||
return this->adjoint / sqrt(Scalar(1) - x * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr / sqrt(constant_ptr(Scalar(1)) - x * x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -981,14 +959,14 @@ struct AtanExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "atan"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint / (Scalar(1) + x * x);
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
return this->adjoint / (Scalar(1) + x * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint / (constant_ptr(Scalar(1)) + x * x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr / (constant_ptr(Scalar(1)) + x * x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1037,24 +1015,24 @@ struct Atan2Expression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "atan2"; }
|
||||
|
||||
Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint * x / (y * y + x * x);
|
||||
Scalar grad_l(Scalar y, Scalar x) const override {
|
||||
return this->adjoint * x / (y * y + x * x);
|
||||
}
|
||||
|
||||
Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint * -y / (y * y + x * x);
|
||||
Scalar grad_r(Scalar y, Scalar x) const override {
|
||||
return this->adjoint * -y / (y * y + x * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& y, const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * x / (y * y + x * x);
|
||||
const ExpressionPtr<Scalar>& y,
|
||||
const ExpressionPtr<Scalar>& x) const override {
|
||||
return this->adjoint_expr * x / (y * y + x * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& y, const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * -y / (y * y + x * x);
|
||||
const ExpressionPtr<Scalar>& y,
|
||||
const ExpressionPtr<Scalar>& x) const override {
|
||||
return this->adjoint_expr * -y / (y * y + x * x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1105,15 +1083,15 @@ struct CosExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "cos"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::sin;
|
||||
return parent_adjoint * -sin(x);
|
||||
return this->adjoint * -sin(x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * -sin(x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr * -sin(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1159,15 +1137,15 @@ struct CoshExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "cosh"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::sinh;
|
||||
return parent_adjoint * sinh(x);
|
||||
return this->adjoint * sinh(x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * sinh(x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr * sinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1213,16 +1191,15 @@ struct ErfExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "erf"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::exp;
|
||||
return parent_adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) *
|
||||
exp(-x * x);
|
||||
return this->adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) * exp(-x * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint *
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr *
|
||||
constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x);
|
||||
}
|
||||
};
|
||||
@@ -1270,15 +1247,15 @@ struct ExpExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "exp"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::exp;
|
||||
return parent_adjoint * exp(x);
|
||||
return this->adjoint * exp(x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * exp(x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr * exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1330,26 +1307,26 @@ struct HypotExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "hypot"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar y) const override {
|
||||
using std::hypot;
|
||||
return parent_adjoint * x / hypot(x, y);
|
||||
return this->adjoint * x / hypot(x, y);
|
||||
}
|
||||
|
||||
Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override {
|
||||
Scalar grad_r(Scalar x, Scalar y) const override {
|
||||
using std::hypot;
|
||||
return parent_adjoint * y / hypot(x, y);
|
||||
return this->adjoint * y / hypot(x, y);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>& y,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * x / hypot(x, y);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>& y) const override {
|
||||
return this->adjoint_expr * x / hypot(x, y);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>& y,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * y / hypot(x, y);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>& y) const override {
|
||||
return this->adjoint_expr * y / hypot(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1399,14 +1376,12 @@ struct LogExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "log"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint / x;
|
||||
}
|
||||
Scalar grad_l(Scalar x, Scalar) const override { return this->adjoint / x; }
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint / x;
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr / x;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1453,14 +1428,14 @@ struct Log10Expression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "log10"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
return parent_adjoint / (Scalar(std::numbers::ln10) * x);
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
return this->adjoint / (Scalar(std::numbers::ln10) * x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint / (constant_ptr(Scalar(std::numbers::ln10)) * x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr / (constant_ptr(Scalar(std::numbers::ln10)) * x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1510,37 +1485,37 @@ struct MaxExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "max"; }
|
||||
|
||||
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar a, Scalar b) const override {
|
||||
if (a >= b) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint;
|
||||
} else {
|
||||
return Scalar(0);
|
||||
}
|
||||
}
|
||||
|
||||
Scalar grad_r(Scalar a, Scalar b, Scalar parent_adjoint) const override {
|
||||
Scalar grad_r(Scalar a, Scalar b) const override {
|
||||
if (b > a) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint;
|
||||
} else {
|
||||
return Scalar(0);
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& a,
|
||||
const ExpressionPtr<Scalar>& b) const override {
|
||||
if (a->val >= b->val) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint_expr;
|
||||
} else {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& a,
|
||||
const ExpressionPtr<Scalar>& b) const override {
|
||||
if (b->val > a->val) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint_expr;
|
||||
} else {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
@@ -1589,38 +1564,38 @@ struct MinExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "min"; }
|
||||
|
||||
Scalar grad_l(Scalar a, Scalar b, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar a, Scalar b) const override {
|
||||
if (a <= b) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint;
|
||||
} else {
|
||||
return Scalar(0);
|
||||
}
|
||||
}
|
||||
|
||||
Scalar grad_r([[maybe_unused]] Scalar a, [[maybe_unused]] Scalar b,
|
||||
Scalar parent_adjoint) const override {
|
||||
Scalar grad_r([[maybe_unused]] Scalar a,
|
||||
[[maybe_unused]] Scalar b) const override {
|
||||
if (b < a) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint;
|
||||
} else {
|
||||
return Scalar(0);
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& a,
|
||||
const ExpressionPtr<Scalar>& b) const override {
|
||||
if (a->val <= b->val) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint_expr;
|
||||
} else {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& a, const ExpressionPtr<Scalar>& b,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& a,
|
||||
const ExpressionPtr<Scalar>& b) const override {
|
||||
if (b->val < a->val) {
|
||||
return parent_adjoint;
|
||||
return this->adjoint_expr;
|
||||
} else {
|
||||
return constant_ptr(Scalar(0));
|
||||
}
|
||||
@@ -1672,14 +1647,12 @@ struct PowExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "pow"; }
|
||||
|
||||
Scalar grad_l(Scalar base, Scalar power,
|
||||
Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar base, Scalar power) const override {
|
||||
using std::pow;
|
||||
return parent_adjoint * pow(base, power - Scalar(1)) * power;
|
||||
return this->adjoint * pow(base, power - Scalar(1)) * power;
|
||||
}
|
||||
|
||||
Scalar grad_r(Scalar base, Scalar power,
|
||||
Scalar parent_adjoint) const override {
|
||||
Scalar grad_r(Scalar base, Scalar power) const override {
|
||||
using std::log;
|
||||
using std::pow;
|
||||
|
||||
@@ -1687,25 +1660,26 @@ struct PowExpression final : Expression<Scalar> {
|
||||
if (base == Scalar(0)) {
|
||||
return Scalar(0);
|
||||
} else {
|
||||
return parent_adjoint * pow(base, power) * log(base);
|
||||
return this->adjoint * pow(base, power) * log(base);
|
||||
}
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& base, const ExpressionPtr<Scalar>& power,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * pow(base, power - constant_ptr(Scalar(1))) * power;
|
||||
const ExpressionPtr<Scalar>& base,
|
||||
const ExpressionPtr<Scalar>& power) const override {
|
||||
return this->adjoint_expr * pow(base, power - constant_ptr(Scalar(1))) *
|
||||
power;
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_r(
|
||||
const ExpressionPtr<Scalar>& base, const ExpressionPtr<Scalar>& power,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& base,
|
||||
const ExpressionPtr<Scalar>& power) const override {
|
||||
// Since x log(x) -> 0 as x -> 0
|
||||
if (base->val == Scalar(0)) {
|
||||
// Return zero
|
||||
return base;
|
||||
} else {
|
||||
return parent_adjoint * pow(base, power) * log(base);
|
||||
return this->adjoint_expr * pow(base, power) * log(base);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -1821,15 +1795,15 @@ struct SinExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "sin"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::cos;
|
||||
return parent_adjoint * cos(x);
|
||||
return this->adjoint * cos(x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * cos(x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr * cos(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1876,15 +1850,15 @@ struct SinhExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "sinh"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::cosh;
|
||||
return parent_adjoint * cosh(x);
|
||||
return this->adjoint * cosh(x);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint * cosh(x);
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr * cosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1931,15 +1905,15 @@ struct SqrtExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "sqrt"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::sqrt;
|
||||
return parent_adjoint / (Scalar(2) * sqrt(x));
|
||||
return this->adjoint / (Scalar(2) * sqrt(x));
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
return parent_adjoint / (constant_ptr(Scalar(2)) * sqrt(x));
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
return this->adjoint_expr / (constant_ptr(Scalar(2)) * sqrt(x));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1987,18 +1961,18 @@ struct TanExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "tan"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::cos;
|
||||
|
||||
auto c = cos(x);
|
||||
return parent_adjoint / (c * c);
|
||||
return this->adjoint / (c * c);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
auto c = cos(x);
|
||||
return parent_adjoint / (c * c);
|
||||
return this->adjoint_expr / (c * c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -2045,18 +2019,18 @@ struct TanhExpression final : Expression<Scalar> {
|
||||
|
||||
std::string_view name() const override { return "tanh"; }
|
||||
|
||||
Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override {
|
||||
Scalar grad_l(Scalar x, Scalar) const override {
|
||||
using std::cosh;
|
||||
|
||||
auto c = cosh(x);
|
||||
return parent_adjoint / (c * c);
|
||||
return this->adjoint / (c * c);
|
||||
}
|
||||
|
||||
ExpressionPtr<Scalar> grad_expr_l(
|
||||
const ExpressionPtr<Scalar>& x, const ExpressionPtr<Scalar>&,
|
||||
const ExpressionPtr<Scalar>& parent_adjoint) const override {
|
||||
const ExpressionPtr<Scalar>& x,
|
||||
const ExpressionPtr<Scalar>&) const override {
|
||||
auto c = cosh(x);
|
||||
return parent_adjoint / (c * c);
|
||||
return this->adjoint_expr / (c * c);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -133,11 +133,11 @@ void append_triplets(
|
||||
if (lhs != nullptr) {
|
||||
if (rhs != nullptr) {
|
||||
// Binary operator
|
||||
lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint);
|
||||
rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint);
|
||||
lhs->adjoint += node->grad_l(lhs->val, rhs->val);
|
||||
rhs->adjoint += node->grad_r(lhs->val, rhs->val);
|
||||
} else {
|
||||
// Unary operator
|
||||
lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint);
|
||||
lhs->adjoint += node->grad_l(lhs->val, Scalar(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ class Hessian {
|
||||
: m_variables{detail::gradient_tree(
|
||||
detail::topological_sort(variable.expr), wrt)},
|
||||
m_wrt{std::move(wrt)} {
|
||||
slp_assert(wrt.cols() == 1);
|
||||
slp_assert(m_wrt.cols() == 1);
|
||||
|
||||
for (auto& variable : m_variables) {
|
||||
m_top_lists.emplace_back(detail::topological_sort(variable.expr));
|
||||
|
||||
@@ -434,10 +434,11 @@ class VariableBlock : public SleipnirBase {
|
||||
slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
|
||||
|
||||
for (int i = 0; i < rows(); ++i) {
|
||||
Mat lhs_old_row = row(i);
|
||||
for (int j = 0; j < rhs.cols(); ++j) {
|
||||
Variable sum{Scalar(0)};
|
||||
for (int k = 0; k < cols(); ++k) {
|
||||
sum += (*this)(i, k) * rhs(k, j);
|
||||
sum += lhs_old_row[k] * rhs(k, j);
|
||||
}
|
||||
(*this)(i, j) = sum;
|
||||
}
|
||||
|
||||
@@ -647,10 +647,11 @@ class VariableMatrix : public SleipnirBase {
|
||||
slp_assert(cols() == rhs.rows() && cols() == rhs.cols());
|
||||
|
||||
for (int i = 0; i < rows(); ++i) {
|
||||
VariableMatrix lhs_old_row = row(i);
|
||||
for (int j = 0; j < rhs.cols(); ++j) {
|
||||
Variable sum{Scalar(0)};
|
||||
for (int k = 0; k < cols(); ++k) {
|
||||
sum += (*this)(i, k) * rhs(k, j);
|
||||
sum += lhs_old_row[k] * rhs(k, j);
|
||||
}
|
||||
(*this)(i, j) = sum;
|
||||
}
|
||||
@@ -665,7 +666,7 @@ class VariableMatrix : public SleipnirBase {
|
||||
/// @return Result of multiplication.
|
||||
VariableMatrix& operator*=(const ScalarLike auto& rhs) {
|
||||
for (int row = 0; row < rows(); ++row) {
|
||||
for (int col = 0; col < rhs.cols(); ++col) {
|
||||
for (int col = 0; col < cols(); ++col) {
|
||||
(*this)(row, col) *= rhs;
|
||||
}
|
||||
}
|
||||
@@ -1686,11 +1687,11 @@ VariableMatrix<Scalar> gradient_tree(const ExpressionGraph<Scalar>& top_list,
|
||||
if (lhs != nullptr) {
|
||||
if (rhs != nullptr) {
|
||||
// Binary operator
|
||||
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
|
||||
rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr);
|
||||
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs);
|
||||
rhs->adjoint_expr += node->grad_expr_r(lhs, rhs);
|
||||
} else {
|
||||
// Unary operator
|
||||
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr);
|
||||
lhs->adjoint_expr += node->grad_expr_l(lhs, rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -501,6 +501,7 @@ ExitStatus interior_point(
|
||||
// Loop until a step is accepted
|
||||
while (1) {
|
||||
trial_x = x + α * step.p_x;
|
||||
trial_c_i = matrices.c_i(trial_x);
|
||||
if (options.feasible_ipm && c_i.cwiseGreater(Scalar(0)).all()) {
|
||||
// If the inequality constraints are all feasible, prevent them from
|
||||
// becoming infeasible again.
|
||||
@@ -515,7 +516,6 @@ ExitStatus interior_point(
|
||||
|
||||
trial_f = matrices.f(trial_x);
|
||||
trial_c_e = matrices.c_e(trial_x);
|
||||
trial_c_i = matrices.c_i(trial_x);
|
||||
|
||||
// If f(xₖ + αpₖˣ), cₑ(xₖ + αpₖˣ), or cᵢ(xₖ + αpₖˣ) aren't finite, reduce
|
||||
// step size immediately
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <Eigen/SparseCore>
|
||||
@@ -42,8 +43,8 @@ struct ProblemScaling {
|
||||
/// @param f Cost scaling factor d_f.
|
||||
/// @param c_e Equality constraint scaling factors d_cₑ.
|
||||
/// @param c_i Inequality constraint scaling factors d_cᵢ.
|
||||
ProblemScaling(Scalar f, const DenseVector& c_e, const DenseVector& c_i)
|
||||
: f{f}, c_e{c_e}, c_i{c_i} {}
|
||||
ProblemScaling(Scalar f, DenseVector c_e, DenseVector c_i)
|
||||
: f{f}, c_e{std::move(c_e)}, c_i{std::move(c_i)} {}
|
||||
|
||||
/// Computes Newton problem scaling.
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user