From 08784dc2d1f2eaf3f80ef4a29b2675f23ed0254e Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Mon, 1 Dec 2025 12:51:28 -0800 Subject: [PATCH] [upstream_utils] Upgrade to Sleipnir 0.3.2 (#8323) Also includes a C++ benchmark, which has a Java counterpart in #8236. --- .../src/main/native/cpp/CartPoleBenchmark.hpp | 27 +- docs/build.gradle | 5 + upstream_utils/sleipnir.py | 6 +- .../sleipnir_patches/0001-Use-fmtlib.patch | 85 +- .../0002-Use-wpi-SmallVector.patch | 53 +- .../0003-Use-wpi-byteswap.patch | 19 +- .../0004-Replace-std-to_underlying.patch | 78 +- .../0005-Replace-std-views-zip.patch | 54 +- ...-Suppress-clang-tidy-false-positives.patch | 4 +- ...ppress-GCC-12-warning-false-positive.patch | 10 +- ...ead-of-multidimensional-array-subsc.patch} | 724 +++++---- .../native/cpp/controller/ArmFeedforward.cpp | 4 +- .../main/native/cpp/geometry/Ellipse2d.cpp | 2 +- .../autodiff/adjoint_expression_graph.hpp | 45 +- .../include/sleipnir/autodiff/expression.hpp | 1389 ++++++++++------- .../sleipnir/autodiff/expression_graph.hpp | 20 +- .../include/sleipnir/autodiff/gradient.hpp | 20 +- .../include/sleipnir/autodiff/hessian.hpp | 47 +- .../include/sleipnir/autodiff/jacobian.hpp | 52 +- .../sleipnir/autodiff/sleipnir_base.hpp | 13 + .../include/sleipnir/autodiff/slice.hpp | 6 +- .../include/sleipnir/autodiff/variable.hpp | 730 ++++++--- .../sleipnir/autodiff/variable_block.hpp | 223 ++- .../sleipnir/autodiff/variable_matrix.hpp | 999 +++++++++--- .../sleipnir/optimization/multistart.hpp | 17 +- .../include/sleipnir/optimization/ocp.hpp | 134 +- .../include/sleipnir/optimization/problem.hpp | 456 +++++- .../optimization/solver/interior_point.hpp | 857 +++++++--- .../interior_point_matrix_callbacks.hpp | 207 +++ .../optimization/solver/iteration_info.hpp | 13 +- .../sleipnir/optimization/solver/newton.hpp | 329 +++- .../solver/newton_matrix_callbacks.hpp | 89 ++ .../sleipnir/optimization/solver/sqp.hpp | 605 +++++-- .../solver/sqp_matrix_callbacks.hpp | 148 ++ .../optimization/solver/util}/bounds.hpp | 86 +- .../solver/util/error_estimate.hpp | 65 +- .../optimization/solver/util/filter.hpp | 64 +- .../util/fraction_to_the_boundary_rule.hpp | 11 +- .../optimization/solver/util}/inertia.hpp | 38 +- .../solver/util/is_locally_infeasible.hpp | 23 +- .../optimization/solver/util/kkt_error.hpp | 43 +- .../solver/util}/regularized_ldlt.hpp | 51 +- .../sleipnir/include/sleipnir/util/assert.hpp | 10 +- .../include/sleipnir/util/concepts.hpp | 40 +- .../sleipnir/include/sleipnir/util/empty.hpp | 17 + .../include/sleipnir/util/function_ref.hpp | 3 +- .../sleipnir/util/intrusive_shared_ptr.hpp | 17 +- .../sleipnir/include/sleipnir/util/print.hpp | 12 +- .../sleipnir}/util/print_diagnostics.hpp | 72 +- .../sleipnir}/util/scope_exit.hpp | 19 + .../sleipnir}/util/scoped_profiler.hpp | 4 +- .../sleipnir}/util/setup_profiler.hpp | 0 .../sleipnir}/util/solve_profiler.hpp | 0 .../sleipnir/include/sleipnir/util/spy.hpp | 17 +- .../sleipnir/src/autodiff/gradient.cpp | 5 + .../sleipnir/src/autodiff/hessian.cpp | 6 + .../sleipnir/src/autodiff/jacobian.cpp | 5 + .../sleipnir/src/autodiff/variable_matrix.cpp | 257 +-- .../sleipnir/src/optimization/ocp.cpp | 5 + .../sleipnir/src/optimization/problem.cpp | 392 +---- .../optimization/solver/interior_point.cpp | 658 +------- .../src/optimization/solver/newton.cpp | 259 +-- .../sleipnir/src/optimization/solver/sqp.cpp | 478 +----- .../thirdparty/sleipnir/src/util/pool.cpp | 6 +- .../cpp/optimization/NonlinearProblemTest.cpp | 12 +- 65 files changed, 5699 insertions(+), 4446 deletions(-) rename upstream_utils/sleipnir_patches/{0008-Revert-Use-multidimensional-array-subscript-operator.patch => 0008-Use-operator-instead-of-multidimensional-array-subsc.patch} (61%) create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/sleipnir_base.hpp create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/optimization/solver/interior_point_matrix_callbacks.hpp create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/optimization/solver/newton_matrix_callbacks.hpp create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/optimization/solver/sqp_matrix_callbacks.hpp rename wpimath/src/main/native/thirdparty/sleipnir/{src/optimization => include/sleipnir/optimization/solver/util}/bounds.hpp (77%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/optimization/solver/util/error_estimate.hpp (57%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/optimization/solver/util/filter.hpp (67%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/optimization/solver/util/fraction_to_the_boundary_rule.hpp (78%) rename wpimath/src/main/native/thirdparty/sleipnir/{src/optimization => include/sleipnir/optimization/solver/util}/inertia.hpp (52%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/optimization/solver/util/is_locally_infeasible.hpp (66%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/optimization/solver/util/kkt_error.hpp (57%) rename wpimath/src/main/native/thirdparty/sleipnir/{src/optimization => include/sleipnir/optimization/solver/util}/regularized_ldlt.hpp (82%) create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/util/empty.hpp rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/util/print_diagnostics.hpp (82%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/util/scope_exit.hpp (61%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/util/scoped_profiler.hpp (94%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/util/setup_profiler.hpp (100%) rename wpimath/src/main/native/thirdparty/sleipnir/{src => include/sleipnir}/util/solve_profiler.hpp (100%) create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/src/autodiff/gradient.cpp create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/src/autodiff/hessian.cpp create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/src/autodiff/jacobian.cpp create mode 100644 wpimath/src/main/native/thirdparty/sleipnir/src/optimization/ocp.cpp diff --git a/benchmark/src/main/native/cpp/CartPoleBenchmark.hpp b/benchmark/src/main/native/cpp/CartPoleBenchmark.hpp index ab07064f8b..36f3996a6d 100644 --- a/benchmark/src/main/native/cpp/CartPoleBenchmark.hpp +++ b/benchmark/src/main/native/cpp/CartPoleBenchmark.hpp @@ -11,8 +11,9 @@ #include "wpi/math/system/NumericalIntegration.hpp" -inline slp::VariableMatrix CartPoleDynamics(const slp::VariableMatrix& x, - const slp::VariableMatrix& u) { +inline slp::VariableMatrix CartPoleDynamics( + const slp::VariableMatrix& x, + const slp::VariableMatrix& u) { constexpr double m_c = 5.0; // Cart mass (kg) constexpr double m_p = 0.5; // Pole mass (kg) constexpr double l = 0.5; // Pole length (m) @@ -25,23 +26,23 @@ inline slp::VariableMatrix CartPoleDynamics(const slp::VariableMatrix& x, // [ m_c + m_p m_p l cosθ] // M(q) = [m_p l cosθ m_p l² ] - slp::VariableMatrix M{{m_c + m_p, m_p * l * cos(theta)}, - {m_p * l * cos(theta), m_p * std::pow(l, 2)}}; + slp::VariableMatrix M{{m_c + m_p, m_p * l * cos(theta)}, + {m_p * l * cos(theta), m_p * std::pow(l, 2)}}; // [0 −m_p lθ̇ sinθ] // C(q, q̇) = [0 0 ] - slp::VariableMatrix C{{0, -m_p * l * thetadot * sin(theta)}, {0, 0}}; + slp::VariableMatrix C{{0, -m_p * l * thetadot * sin(theta)}, {0, 0}}; // [ 0 ] // τ_g(q) = [-m_p gl sinθ] - slp::VariableMatrix tau_g{{0}, {-m_p * g * l * sin(theta)}}; + slp::VariableMatrix tau_g{{0}, {-m_p * g * l * sin(theta)}}; // [1] // B = [0] constexpr Eigen::Matrix B{{1}, {0}}; // q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu) - slp::VariableMatrix qddot{4}; + slp::VariableMatrix qddot{4}; qddot.segment(0, 2) = qdot; qddot.segment(2, 2) = solve(M, tau_g - C * qdot + B * u); return qddot; @@ -63,7 +64,7 @@ inline void BM_CartPole(benchmark::State& state) { constexpr Eigen::Vector x_final{ {1.0, std::numbers::pi, 0.0, 0.0}}; - slp::Problem problem; + slp::Problem problem; // x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ auto X = problem.decision_variable(4, N + 1); @@ -95,11 +96,11 @@ inline void BM_CartPole(benchmark::State& state) { // Dynamics constraints - RK4 integration for (int k = 0; k < N; ++k) { - problem.subject_to( - X.col(k + 1) == - wpi::math::RK4(CartPoleDynamics, X.col(k), - U.col(k), dt)); + problem.subject_to(X.col(k + 1) == + wpi::math::RK4, + slp::VariableMatrix>( + CartPoleDynamics, X.col(k), U.col(k), dt)); } // Minimize sum squared inputs diff --git a/docs/build.gradle b/docs/build.gradle index cdd25115ea..00bd4eaa33 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -97,6 +97,11 @@ doxygen.sourceSets.main { exclude 'wpi/util/bit.hpp' exclude 'wpi/util/raw_ostream.hpp' + // Sleipnir + exclude 'sleipnir/optimization/solver/interior_point.hpp' + exclude 'sleipnir/optimization/solver/newton.hpp' + exclude 'sleipnir/optimization/solver/sqp.hpp' + // apriltag exclude 'apriltag_pose.h' diff --git a/upstream_utils/sleipnir.py b/upstream_utils/sleipnir.py index 96ddd0c628..a445d0172e 100755 --- a/upstream_utils/sleipnir.py +++ b/upstream_utils/sleipnir.py @@ -18,8 +18,7 @@ def copy_upstream_src(wpilib_root: Path): # Copy Sleipnir files into allwpilib walk_cwd_and_copy_if( - lambda dp, f: (has_prefix(dp, Path("include")) or has_prefix(dp, Path("src"))) - and f not in [".styleguide", ".styleguide-license"], + lambda dp, f: (has_prefix(dp, Path("include")) or has_prefix(dp, Path("src"))), wpimath / "src/main/native/thirdparty/sleipnir", ) @@ -49,8 +48,7 @@ using small_vector = wpi::util::SmallVector; def main(): name = "sleipnir" url = "https://github.com/SleipnirGroup/Sleipnir" - # main on 2025-09-19 - tag = "7f89d5547702a09e3617bc31fe5bafe6add04fab" + tag = "v0.3.2" sleipnir = Lib(name, url, tag, copy_upstream_src) sleipnir.main() diff --git a/upstream_utils/sleipnir_patches/0001-Use-fmtlib.patch b/upstream_utils/sleipnir_patches/0001-Use-fmtlib.patch index 3677503197..8873cebf15 100644 --- a/upstream_utils/sleipnir_patches/0001-Use-fmtlib.patch +++ b/upstream_utils/sleipnir_patches/0001-Use-fmtlib.patch @@ -4,41 +4,41 @@ Date: Wed, 29 May 2024 16:29:55 -0700 Subject: [PATCH 1/8] Use fmtlib --- - include/.styleguide | 1 + - include/sleipnir/util/assert.hpp | 5 +++-- - include/sleipnir/util/print.hpp | 31 ++++++++++++++++++------------- - src/.styleguide | 1 + - src/optimization/problem.cpp | 1 + - 5 files changed, 24 insertions(+), 15 deletions(-) + include/sleipnir/optimization/problem.hpp | 1 + + include/sleipnir/util/assert.hpp | 5 ++-- + include/sleipnir/util/print.hpp | 31 +++++++++++++---------- + 3 files changed, 22 insertions(+), 15 deletions(-) -diff --git a/include/.styleguide b/include/.styleguide -index 1b6652d3d5886cf8c9eca0d855c21031775bad7c..4f4c76204071f90bf49eddb8c2aceb583b5e09ba 100644 ---- a/include/.styleguide -+++ b/include/.styleguide -@@ -8,5 +8,6 @@ cppSrcFileInclude { +diff --git a/include/sleipnir/optimization/problem.hpp b/include/sleipnir/optimization/problem.hpp +index 3185466605b6604068e2807e461d07d8c856c505..95a33952a5a368c7c81491dbe849a8096357dc38 100644 +--- a/include/sleipnir/optimization/problem.hpp ++++ b/include/sleipnir/optimization/problem.hpp +@@ -15,6 +15,7 @@ - includeOtherLibs { - ^Eigen/ -+ ^fmt/ - ^gch/ - } + #include + #include ++#include + #include + + #include "sleipnir/autodiff/expression_type.hpp" diff --git a/include/sleipnir/util/assert.hpp b/include/sleipnir/util/assert.hpp -index 75d8ffca32accbf66ffce30f073de1db2f42469b..53de01928b929793fa77885ec4a6d1a928bdc5a9 100644 +index 0846928c3da7a6047a3c271dd2d377a3b755eeab..5d432608def05b6dee6b7cbdb9a0b91a6ab5e1c2 100644 --- a/include/sleipnir/util/assert.hpp +++ b/include/sleipnir/util/assert.hpp -@@ -3,9 +3,10 @@ - #pragma once +@@ -4,10 +4,11 @@ + + #ifdef SLEIPNIR_PYTHON - #ifdef JORMUNGANDR -#include #include #include -+ + +#include ++ /** * Throw an exception in Python. */ -@@ -13,7 +14,7 @@ +@@ -15,7 +16,7 @@ do { \ if (!(condition)) { \ auto location = std::source_location::current(); \ @@ -48,7 +48,7 @@ index 75d8ffca32accbf66ffce30f073de1db2f42469b..53de01928b929793fa77885ec4a6d1a9 location.line(), location.function_name(), #condition)); \ } \ diff --git a/include/sleipnir/util/print.hpp b/include/sleipnir/util/print.hpp -index fe430352dabf4cd6a890dc8007237c7a261dfd4b..055d5c9fa246201f1d8ae7ddca00b1159aeb2a57 100644 +index 797df849f63d960cf10eaf847415595961868ab0..a89b7d4f9864965443405a8e79cddd8dbfc54ad3 100644 --- a/include/sleipnir/util/print.hpp +++ b/include/sleipnir/util/print.hpp @@ -4,10 +4,15 @@ @@ -76,8 +76,8 @@ index fe430352dabf4cd6a890dc8007237c7a261dfd4b..055d5c9fa246201f1d8ae7ddca00b115 + * Wrapper around fmt::print() that squelches write failure exceptions. */ template --inline void print(std::format_string fmt, T&&... args) { -+inline void print(fmt::format_string fmt, T&&... args) { +-void print(std::format_string fmt, T&&... args) { ++void print(fmt::format_string fmt, T&&... args) { try { - std::print(fmt, std::forward(args)...); + fmt::print(fmt, std::forward(args)...); @@ -90,8 +90,8 @@ index fe430352dabf4cd6a890dc8007237c7a261dfd4b..055d5c9fa246201f1d8ae7ddca00b115 + * Wrapper around fmt::print() that squelches write failure exceptions. */ template --inline void print(std::FILE* f, std::format_string fmt, T&&... args) { -+inline void print(std::FILE* f, fmt::format_string fmt, T&&... args) { +-void print(std::FILE* f, std::format_string fmt, T&&... args) { ++void print(std::FILE* f, fmt::format_string fmt, T&&... args) { try { - std::print(f, fmt, std::forward(args)...); + fmt::print(f, fmt, std::forward(args)...); @@ -104,8 +104,8 @@ index fe430352dabf4cd6a890dc8007237c7a261dfd4b..055d5c9fa246201f1d8ae7ddca00b115 + * Wrapper around fmt::println() that squelches write failure exceptions. */ template --inline void println(std::format_string fmt, T&&... args) { -+inline void println(fmt::format_string fmt, T&&... args) { +-void println(std::format_string fmt, T&&... args) { ++void println(fmt::format_string fmt, T&&... args) { try { - std::println(fmt, std::forward(args)...); + fmt::println(fmt, std::forward(args)...); @@ -118,34 +118,11 @@ index fe430352dabf4cd6a890dc8007237c7a261dfd4b..055d5c9fa246201f1d8ae7ddca00b115 + * Wrapper around fmt::println() that squelches write failure exceptions. */ template --inline void println(std::FILE* f, std::format_string fmt, T&&... args) { -+inline void println(std::FILE* f, fmt::format_string fmt, T&&... args) { +-void println(std::FILE* f, std::format_string fmt, T&&... args) { ++void println(std::FILE* f, fmt::format_string fmt, T&&... args) { try { - std::println(f, fmt, std::forward(args)...); + fmt::println(f, fmt, std::forward(args)...); } catch (const std::system_error&) { } } -diff --git a/src/.styleguide b/src/.styleguide -index 1b6652d3d5886cf8c9eca0d855c21031775bad7c..4f4c76204071f90bf49eddb8c2aceb583b5e09ba 100644 ---- a/src/.styleguide -+++ b/src/.styleguide -@@ -8,5 +8,6 @@ cppSrcFileInclude { - - includeOtherLibs { - ^Eigen/ -+ ^fmt/ - ^gch/ - } -diff --git a/src/optimization/problem.cpp b/src/optimization/problem.cpp -index c3331197e2365934273f57422b79fa18c2b78a5b..09828cdb6d7cddff692b9d17603dc0c11cd5a3ec 100644 ---- a/src/optimization/problem.cpp -+++ b/src/optimization/problem.cpp -@@ -11,6 +11,7 @@ - - #include - #include -+#include - #include - - #include "optimization/bounds.hpp" diff --git a/upstream_utils/sleipnir_patches/0002-Use-wpi-SmallVector.patch b/upstream_utils/sleipnir_patches/0002-Use-wpi-SmallVector.patch index 63dde99b18..8f317b0f00 100644 --- a/upstream_utils/sleipnir_patches/0002-Use-wpi-SmallVector.patch +++ b/upstream_utils/sleipnir_patches/0002-Use-wpi-SmallVector.patch @@ -5,37 +5,37 @@ Subject: [PATCH 2/8] Use wpi::SmallVector --- include/sleipnir/autodiff/expression.hpp | 4 ++-- - include/sleipnir/autodiff/variable.hpp | 5 ++--- + include/sleipnir/autodiff/variable.hpp | 4 ++-- include/sleipnir/autodiff/variable_matrix.hpp | 4 ++-- - 3 files changed, 6 insertions(+), 7 deletions(-) + 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/sleipnir/autodiff/expression.hpp b/include/sleipnir/autodiff/expression.hpp -index bb4d8c5641a5b3d633d372674e0a35f857889cd4..53a5f6d68d3153537840c4ff45fe5e5d8b0076b7 100644 +index f5919de6c9c0be044335ce7764ded545215f0486..46814576a3db9f472329b880b94b1ab98d218867 100644 --- a/include/sleipnir/autodiff/expression.hpp +++ b/include/sleipnir/autodiff/expression.hpp -@@ -30,7 +30,7 @@ inline constexpr bool USE_POOL_ALLOCATOR = true; - struct Expression; - - inline constexpr void inc_ref_count(Expression* expr); --inline constexpr void dec_ref_count(Expression* expr); -+inline void dec_ref_count(Expression* expr); +@@ -33,7 +33,7 @@ struct Expression; + template + constexpr void inc_ref_count(Expression* expr); + template +-constexpr void dec_ref_count(Expression* expr); ++void dec_ref_count(Expression* expr); /** * Typedef for intrusive shared pointer to Expression. -@@ -733,7 +733,7 @@ inline constexpr void inc_ref_count(Expression* expr) { - * +@@ -801,7 +801,7 @@ constexpr void inc_ref_count(Expression* expr) { * @param expr The shared pointer's managed object. */ --inline constexpr void dec_ref_count(Expression* expr) { -+inline void dec_ref_count(Expression* expr) { + template +-constexpr void dec_ref_count(Expression* expr) { ++void dec_ref_count(Expression* expr) { // If a deeply nested tree is being deallocated all at once, calling the // Expression destructor when expr's refcount reaches zero can cause a stack // overflow. Instead, we iterate over its children to decrement their diff --git a/include/sleipnir/autodiff/variable.hpp b/include/sleipnir/autodiff/variable.hpp -index f60236811eba45c67a9638e90d5101d877ecc2d0..264f0950f293c67d6e6c7e729887090c050e40e2 100644 +index c78af7224b2ef93ad50b238117583e01940c53ce..0a55b906130d7506c80eb150644ac44c222d1368 100644 --- a/include/sleipnir/autodiff/variable.hpp +++ b/include/sleipnir/autodiff/variable.hpp -@@ -47,7 +47,7 @@ class SLEIPNIR_DLLEXPORT Variable { +@@ -61,7 +61,7 @@ class Variable : public SleipnirBase { /** * Constructs an empty Variable. */ @@ -43,35 +43,34 @@ index f60236811eba45c67a9638e90d5101d877ecc2d0..264f0950f293c67d6e6c7e729887090c + explicit Variable(std::nullptr_t) : expr{nullptr} {} /** - * Constructs a Variable from a floating point type. -@@ -77,8 +77,7 @@ class SLEIPNIR_DLLEXPORT Variable { + * Constructs a Variable from a scalar type. +@@ -116,7 +116,7 @@ class Variable : public SleipnirBase { * * @param expr The autodiff variable. */ -- explicit constexpr Variable(detail::ExpressionPtr&& expr) -- : expr{std::move(expr)} {} -+ explicit Variable(detail::ExpressionPtr&& expr) : expr{std::move(expr)} {} +- explicit constexpr Variable(detail::ExpressionPtr&& expr) ++ explicit Variable(detail::ExpressionPtr&& expr) + : expr{std::move(expr)} {} /** - * Assignment operator for double. diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp -index e1a419ca5356660b3c1c27230d1cb2a86977fb65..349a1550235516f9853609b61feded834ef2894b 100644 +index bb66bebc01413a291242886366ce329bb5f4b70a..7ddf02c0e2f66aff8da422b874cbe9772f9fd00d 100644 --- a/include/sleipnir/autodiff/variable_matrix.hpp +++ b/include/sleipnir/autodiff/variable_matrix.hpp -@@ -1120,14 +1120,14 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -1281,14 +1281,14 @@ class VariableMatrix : public SleipnirBase { * - * @return Begin iterator. + * @return Const begin iterator. */ - const_iterator cbegin() const { return const_iterator{m_storage.cbegin()}; } + const_iterator cbegin() const { return const_iterator{m_storage.begin()}; } /** - * Returns end iterator. + * Returns const end iterator. * - * @return End iterator. + * @return Const end iterator. */ - const_iterator cend() const { return const_iterator{m_storage.cend()}; } + const_iterator cend() const { return const_iterator{m_storage.end()}; } /** - * Returns number of elements in matrix. + * Returns reverse begin iterator. diff --git a/upstream_utils/sleipnir_patches/0003-Use-wpi-byteswap.patch b/upstream_utils/sleipnir_patches/0003-Use-wpi-byteswap.patch index 37141f257e..7423819d9b 100644 --- a/upstream_utils/sleipnir_patches/0003-Use-wpi-byteswap.patch +++ b/upstream_utils/sleipnir_patches/0003-Use-wpi-byteswap.patch @@ -4,22 +4,11 @@ Date: Tue, 28 Jan 2025 22:19:14 -0800 Subject: [PATCH 3/8] Use wpi::byteswap() --- - include/.styleguide | 1 + include/sleipnir/util/spy.hpp | 3 ++- - 2 files changed, 3 insertions(+), 1 deletion(-) + 1 file changed, 2 insertions(+), 1 deletion(-) -diff --git a/include/.styleguide b/include/.styleguide -index 4f4c76204071f90bf49eddb8c2aceb583b5e09ba..03938557c2600a7a1f72c6b93c935602f5acb2b2 100644 ---- a/include/.styleguide -+++ b/include/.styleguide -@@ -10,4 +10,5 @@ includeOtherLibs { - ^Eigen/ - ^fmt/ - ^gch/ -+ ^wpi/ - } diff --git a/include/sleipnir/util/spy.hpp b/include/sleipnir/util/spy.hpp -index a2f94803e3744cee771669210d1af883160e9896..74dd7990b03783ce805a186920d5142caeb178c6 100644 +index f9143f2b925064e9df5c763823dcf3d435e7aa28..4b810e54a8038162e03cf08fc8eab52b67b2cdd5 100644 --- a/include/sleipnir/util/spy.hpp +++ b/include/sleipnir/util/spy.hpp @@ -12,6 +12,7 @@ @@ -28,9 +17,9 @@ index a2f94803e3744cee771669210d1af883160e9896..74dd7990b03783ce805a186920d5142c #include +#include - #include "sleipnir/util/symbol_exports.hpp" + namespace slp { -@@ -115,7 +116,7 @@ class SLEIPNIR_DLLEXPORT Spy { +@@ -114,7 +115,7 @@ class Spy { */ void write32le(int32_t num) { if constexpr (std::endian::native != std::endian::little) { diff --git a/upstream_utils/sleipnir_patches/0004-Replace-std-to_underlying.patch b/upstream_utils/sleipnir_patches/0004-Replace-std-to_underlying.patch index f4d84bae89..8dc94c1923 100644 --- a/upstream_utils/sleipnir_patches/0004-Replace-std-to_underlying.patch +++ b/upstream_utils/sleipnir_patches/0004-Replace-std-to_underlying.patch @@ -4,51 +4,43 @@ Date: Tue, 28 Jan 2025 22:19:31 -0800 Subject: [PATCH 4/8] Replace std::to_underlying() --- - src/optimization/problem.cpp | 9 ++++----- - src/util/print_diagnostics.hpp | 6 +++--- - 2 files changed, 7 insertions(+), 8 deletions(-) + include/sleipnir/optimization/problem.hpp | 8 ++++---- + include/sleipnir/util/print_diagnostics.hpp | 6 +++--- + 2 files changed, 7 insertions(+), 7 deletions(-) -diff --git a/src/optimization/problem.cpp b/src/optimization/problem.cpp -index 09828cdb6d7cddff692b9d17603dc0c11cd5a3ec..886de24cc0532d31f1e186150da79e925f212556 100644 ---- a/src/optimization/problem.cpp -+++ b/src/optimization/problem.cpp -@@ -7,7 +7,6 @@ - #include - #include - #include --#include +diff --git a/include/sleipnir/optimization/problem.hpp b/include/sleipnir/optimization/problem.hpp +index 95a33952a5a368c7c81491dbe849a8096357dc38..d20777a5b1912754dda5504313549197e867d34b 100644 +--- a/include/sleipnir/optimization/problem.hpp ++++ b/include/sleipnir/optimization/problem.hpp +@@ -708,11 +708,11 @@ class Problem { + // Print problem structure + slp::println("\nProblem structure:"); + slp::println(" ↳ {} cost function", +- types[std::to_underlying(cost_function_type())]); ++ types[static_cast(cost_function_type())]); + slp::println(" ↳ {} equality constraints", +- types[std::to_underlying(equality_constraint_type())]); ++ types[static_cast(equality_constraint_type())]); + slp::println(" ↳ {} inequality constraints", +- types[std::to_underlying(inequality_constraint_type())]); ++ types[static_cast(inequality_constraint_type())]); - #include - #include -@@ -350,11 +349,11 @@ void Problem::print_problem_analysis() { - // Print problem structure - slp::println("\nProblem structure:"); - slp::println(" ↳ {} cost function", -- types[std::to_underlying(cost_function_type())]); -+ types[static_cast(cost_function_type())]); - slp::println(" ↳ {} equality constraints", -- types[std::to_underlying(equality_constraint_type())]); -+ types[static_cast(equality_constraint_type())]); - slp::println(" ↳ {} inequality constraints", -- types[std::to_underlying(inequality_constraint_type())]); -+ types[static_cast(inequality_constraint_type())]); - - if (m_decision_variables.size() == 1) { - slp::print("\n1 decision variable\n"); -@@ -366,7 +365,7 @@ void Problem::print_problem_analysis() { - [](const gch::small_vector& constraints) { - std::array counts{}; - for (const auto& constraint : constraints) { -- ++counts[std::to_underlying(constraint.type())]; -+ ++counts[static_cast(constraint.type())]; - } - for (const auto& [count, name] : - std::views::zip(counts, std::array{"empty", "constant", "linear", -diff --git a/src/util/print_diagnostics.hpp b/src/util/print_diagnostics.hpp -index fde36957c0258f6e3cd435ef6224d60407012ff7..82e0e082b0e40153dcb2fcd2c655a412a8a9540a 100644 ---- a/src/util/print_diagnostics.hpp -+++ b/src/util/print_diagnostics.hpp -@@ -238,9 +238,9 @@ void print_iteration_diagnostics(int iterations, IterationType type, + if (m_decision_variables.size() == 1) { + slp::print("\n1 decision variable\n"); +@@ -724,7 +724,7 @@ class Problem { + [](const gch::small_vector>& constraints) { + std::array counts{}; + for (const auto& constraint : constraints) { +- ++counts[std::to_underlying(constraint.type())]; ++ ++counts[static_cast(constraint.type())]; + } + for (const auto& [count, name] : + std::views::zip(counts, std::array{"empty", "constant", "linear", +diff --git a/include/sleipnir/util/print_diagnostics.hpp b/include/sleipnir/util/print_diagnostics.hpp +index 9c1f9eb71b9417e138b95fd4d2d678cfb54595d1..032be8fb7b5e4196ff401c77ae9e91f1c966cde6 100644 +--- a/include/sleipnir/util/print_diagnostics.hpp ++++ b/include/sleipnir/util/print_diagnostics.hpp +@@ -252,9 +252,9 @@ void print_iteration_diagnostics(int iterations, IterationType type, slp::println( "│{:4} {:4} {:9.3f} {:12e} {:13e} {:12e} {:12e} {:.2e} {:<5} {:.2e} " "{:.2e} {:2d}│", diff --git a/upstream_utils/sleipnir_patches/0005-Replace-std-views-zip.patch b/upstream_utils/sleipnir_patches/0005-Replace-std-views-zip.patch index f1fd7a7c59..2a20e218c6 100644 --- a/upstream_utils/sleipnir_patches/0005-Replace-std-views-zip.patch +++ b/upstream_utils/sleipnir_patches/0005-Replace-std-views-zip.patch @@ -5,14 +5,14 @@ Subject: [PATCH 5/8] Replace std::views::zip() --- include/sleipnir/autodiff/adjoint_expression_graph.hpp | 5 ++++- - src/optimization/problem.cpp | 9 +++++---- - 2 files changed, 9 insertions(+), 5 deletions(-) + include/sleipnir/optimization/problem.hpp | 8 +++++--- + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/include/sleipnir/autodiff/adjoint_expression_graph.hpp b/include/sleipnir/autodiff/adjoint_expression_graph.hpp -index 33b6eee615141a1d6472f116842d62052ef54dd9..b333aebd3e59fa23eed6046c13d736c3d2eccac7 100644 +index 16bea7efeeca78d25b34b0b1242ca19cbd05a482..a77323eee9277fc3c77a11ab57ab5003d9ed4543 100644 --- a/include/sleipnir/autodiff/adjoint_expression_graph.hpp +++ b/include/sleipnir/autodiff/adjoint_expression_graph.hpp -@@ -158,7 +158,10 @@ class AdjointExpressionGraph { +@@ -171,7 +171,10 @@ class AdjointExpressionGraph { } } } else { @@ -22,32 +22,24 @@ index 33b6eee615141a1d6472f116842d62052ef54dd9..b333aebd3e59fa23eed6046c13d736c3 + const auto& node = m_top_list[i]; + // Append adjoints of wrt to sparse matrix triplets - if (col != -1 && node->adjoint != 0.0) { + if (col != -1 && node->adjoint != Scalar(0)) { triplets.emplace_back(row, col, node->adjoint); -diff --git a/src/optimization/problem.cpp b/src/optimization/problem.cpp -index 886de24cc0532d31f1e186150da79e925f212556..e32481e9314c9ef472843adb5bedbd993627d5d9 100644 ---- a/src/optimization/problem.cpp -+++ b/src/optimization/problem.cpp -@@ -6,7 +6,6 @@ - #include - #include - #include --#include - - #include - #include -@@ -367,9 +366,11 @@ void Problem::print_problem_analysis() { - for (const auto& constraint : constraints) { - ++counts[static_cast(constraint.type())]; - } -- for (const auto& [count, name] : -- std::views::zip(counts, std::array{"empty", "constant", "linear", -- "quadratic", "nonlinear"})) { -+ for (size_t i = 0; i < counts.size(); ++i) { -+ constexpr std::array names{"empty", "constant", "linear", "quadratic", -+ "nonlinear"}; -+ const auto& count = counts[i]; -+ const auto& name = names[i]; - if (count > 0) { - slp::println(" ↳ {} {}", count, name); +diff --git a/include/sleipnir/optimization/problem.hpp b/include/sleipnir/optimization/problem.hpp +index d20777a5b1912754dda5504313549197e867d34b..5256d08e5f9d8642049d8bb8323d76c7b3bbbef7 100644 +--- a/include/sleipnir/optimization/problem.hpp ++++ b/include/sleipnir/optimization/problem.hpp +@@ -726,9 +726,11 @@ class Problem { + for (const auto& constraint : constraints) { + ++counts[static_cast(constraint.type())]; } +- for (const auto& [count, name] : +- std::views::zip(counts, std::array{"empty", "constant", "linear", +- "quadratic", "nonlinear"})) { ++ for (size_t i = 0; i < counts.size(); ++i) { ++ constexpr std::array names{"empty", "constant", "linear", ++ "quadratic", "nonlinear"}; ++ const auto& count = counts[i]; ++ const auto& name = names[i]; + if (count > 0) { + slp::println(" ↳ {} {}", count, name); + } diff --git a/upstream_utils/sleipnir_patches/0006-Suppress-clang-tidy-false-positives.patch b/upstream_utils/sleipnir_patches/0006-Suppress-clang-tidy-false-positives.patch index 109126b697..8eab50ee95 100644 --- a/upstream_utils/sleipnir_patches/0006-Suppress-clang-tidy-false-positives.patch +++ b/upstream_utils/sleipnir_patches/0006-Suppress-clang-tidy-false-positives.patch @@ -8,10 +8,10 @@ Subject: [PATCH 6/8] Suppress clang-tidy false positives 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/sleipnir/autodiff/variable.hpp b/include/sleipnir/autodiff/variable.hpp -index 264f0950f293c67d6e6c7e729887090c050e40e2..62135a5539308ae69f6b45a64d9337c4c3e96d7b 100644 +index 0a55b906130d7506c80eb150644ac44c222d1368..30ec62161df75c6948bbf3d65432c852a0d926c2 100644 --- a/include/sleipnir/autodiff/variable.hpp +++ b/include/sleipnir/autodiff/variable.hpp -@@ -633,7 +633,7 @@ struct SLEIPNIR_DLLEXPORT InequalityConstraints { +@@ -862,7 +862,7 @@ struct InequalityConstraints { * @param inequality_constraints The list of InequalityConstraints to * concatenate. */ diff --git a/upstream_utils/sleipnir_patches/0007-Suppress-GCC-12-warning-false-positive.patch b/upstream_utils/sleipnir_patches/0007-Suppress-GCC-12-warning-false-positive.patch index aed8ccc605..b4a01648a7 100644 --- a/upstream_utils/sleipnir_patches/0007-Suppress-GCC-12-warning-false-positive.patch +++ b/upstream_utils/sleipnir_patches/0007-Suppress-GCC-12-warning-false-positive.patch @@ -8,12 +8,12 @@ Subject: [PATCH 7/8] Suppress GCC 12 warning false positive 1 file changed, 7 insertions(+) diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp -index 349a1550235516f9853609b61feded834ef2894b..70bccf4fc078a49e22b6699db1228c765430a121 100644 +index 7ddf02c0e2f66aff8da422b874cbe9772f9fd00d..351030b4041027ba63a2e6ec08f2077b3c35b5db 100644 --- a/include/sleipnir/autodiff/variable_matrix.hpp +++ b/include/sleipnir/autodiff/variable_matrix.hpp -@@ -573,6 +573,10 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -578,6 +578,10 @@ class VariableMatrix : public SleipnirBase { - VariableMatrix result(VariableMatrix::empty, lhs.rows(), rhs.cols()); + VariableMatrix result(detail::empty, lhs.rows(), rhs.cols()); +#if __GNUC__ >= 12 +#pragma GCC diagnostic push @@ -21,8 +21,8 @@ index 349a1550235516f9853609b61feded834ef2894b..70bccf4fc078a49e22b6699db1228c76 +#endif for (int i = 0; i < lhs.rows(); ++i) { for (int j = 0; j < rhs.cols(); ++j) { - Variable sum; -@@ -590,6 +594,9 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { + Variable sum{Scalar(0)}; +@@ -637,6 +641,9 @@ class VariableMatrix : public SleipnirBase { result[i, j] = sum; } } diff --git a/upstream_utils/sleipnir_patches/0008-Revert-Use-multidimensional-array-subscript-operator.patch b/upstream_utils/sleipnir_patches/0008-Use-operator-instead-of-multidimensional-array-subsc.patch similarity index 61% rename from upstream_utils/sleipnir_patches/0008-Revert-Use-multidimensional-array-subscript-operator.patch rename to upstream_utils/sleipnir_patches/0008-Use-operator-instead-of-multidimensional-array-subsc.patch index b04d79f5bc..c080886432 100644 --- a/upstream_utils/sleipnir_patches/0008-Revert-Use-multidimensional-array-subscript-operator.patch +++ b/upstream_utils/sleipnir_patches/0008-Use-operator-instead-of-multidimensional-array-subsc.patch @@ -1,109 +1,96 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Tyler Veness Date: Sat, 12 Apr 2025 16:28:47 -0700 -Subject: [PATCH 8/8] Revert "Use multidimensional array subscript operator - (#843)" +Subject: [PATCH 8/8] Use operator() instead of multidimensional array + subscript operator -This reverts commit f9b2c450bbbf6f14b194b8b81708d032a6431ee0. --- include/sleipnir/autodiff/hessian.hpp | 4 +- include/sleipnir/autodiff/jacobian.hpp | 4 +- - include/sleipnir/autodiff/variable.hpp | 26 +---- - include/sleipnir/autodiff/variable_block.hpp | 70 +++++------ - include/sleipnir/autodiff/variable_matrix.hpp | 110 ++++++------------ - include/sleipnir/optimization/ocp.hpp | 14 +-- + include/sleipnir/autodiff/variable.hpp | 8 +- + include/sleipnir/autodiff/variable_block.hpp | 74 ++++---- + include/sleipnir/autodiff/variable_matrix.hpp | 158 +++++++++--------- + include/sleipnir/optimization/ocp.hpp | 14 +- include/sleipnir/optimization/problem.hpp | 6 +- - src/autodiff/variable_matrix.cpp | 66 +++++------ - 8 files changed, 118 insertions(+), 182 deletions(-) + 7 files changed, 134 insertions(+), 134 deletions(-) diff --git a/include/sleipnir/autodiff/hessian.hpp b/include/sleipnir/autodiff/hessian.hpp -index fa6d8af0843eca8b674744f02551584dd8d79c21..4f093b7b39ea84e56c4a12ae1b6f645c4f84a1f0 100644 +index 629b6b88274f3d0e6126fd68ccbc219618386518..10ee142ff8f02a9b9f2dc73a6b9c9efad7341ad2 100644 --- a/include/sleipnir/autodiff/hessian.hpp +++ b/include/sleipnir/autodiff/hessian.hpp -@@ -106,9 +106,9 @@ class SLEIPNIR_DLLEXPORT Hessian { +@@ -106,9 +106,9 @@ class Hessian { auto grad = m_graphs[row].generate_gradient_tree(m_wrt); for (int col = 0; col < m_wrt.rows(); ++col) { if (grad[col].expr != nullptr) { - result[row, col] = std::move(grad[col]); + result(row, col) = std::move(grad[col]); } else { -- result[row, col] = Variable{0.0}; -+ result(row, col) = Variable{0.0}; +- result[row, col] = Variable{Scalar(0)}; ++ result(row, col) = Variable{Scalar(0)}; } } } diff --git a/include/sleipnir/autodiff/jacobian.hpp b/include/sleipnir/autodiff/jacobian.hpp -index 4515076cde12a2112e1b5711acc3092bd807e250..3662b5e49b93f63b5ccac0e732149bd9178f1aae 100644 +index b7cedd63d554d6ccfa42c6d8deb62da27950cd53..c8e28a826f619bee201d3383a4dda23f148fa0b1 100644 --- a/include/sleipnir/autodiff/jacobian.hpp +++ b/include/sleipnir/autodiff/jacobian.hpp -@@ -99,9 +99,9 @@ class SLEIPNIR_DLLEXPORT Jacobian { +@@ -114,9 +114,9 @@ class Jacobian { auto grad = m_graphs[row].generate_gradient_tree(m_wrt); for (int col = 0; col < m_wrt.rows(); ++col) { if (grad[col].expr != nullptr) { - result[row, col] = std::move(grad[col]); + result(row, col) = std::move(grad[col]); } else { -- result[row, col] = Variable{0.0}; -+ result(row, col) = Variable{0.0}; +- result[row, col] = Variable{Scalar(0)}; ++ result(row, col) = Variable{Scalar(0)}; } } } diff --git a/include/sleipnir/autodiff/variable.hpp b/include/sleipnir/autodiff/variable.hpp -index 62135a5539308ae69f6b45a64d9337c4c3e96d7b..2fc2119d2dedaa5b4c941ce449b7fb113c641635 100644 +index 30ec62161df75c6948bbf3d65432c852a0d926c2..cb4c1a56ecd16ee2cd27cdd3a866fea3226ce388 100644 --- a/include/sleipnir/autodiff/variable.hpp +++ b/include/sleipnir/autodiff/variable.hpp -@@ -512,11 +512,7 @@ gch::small_vector make_constraints(LHS&& lhs, RHS&& rhs) { - for (int row = 0; row < rhs.rows(); ++row) { - for (int col = 0; col < rhs.cols(); ++col) { - // Make right-hand side zero -- if constexpr (EigenMatrixLike>) { -- constraints.emplace_back(lhs - rhs(row, col)); -- } else { -- constraints.emplace_back(lhs - rhs[row, col]); -- } -+ constraints.emplace_back(lhs - rhs(row, col)); - } - } - } else if constexpr (MatrixLike && ScalarLike) { -@@ -525,11 +521,7 @@ gch::small_vector make_constraints(LHS&& lhs, RHS&& rhs) { - for (int row = 0; row < lhs.rows(); ++row) { - for (int col = 0; col < lhs.cols(); ++col) { - // Make right-hand side zero -- if constexpr (EigenMatrixLike>) { -- constraints.emplace_back(lhs(row, col) - rhs); -- } else { -- constraints.emplace_back(lhs[row, col] - rhs); -- } -+ constraints.emplace_back(lhs(row, col) - rhs); - } - } - } else if constexpr (MatrixLike && MatrixLike) { -@@ -539,19 +531,7 @@ gch::small_vector make_constraints(LHS&& lhs, RHS&& rhs) { - for (int row = 0; row < lhs.rows(); ++row) { - for (int col = 0; col < lhs.cols(); ++col) { - // Make right-hand side zero -- if constexpr (EigenMatrixLike> && -- EigenMatrixLike>) { -- constraints.emplace_back(lhs(row, col) - rhs(row, col)); -- } else if constexpr (EigenMatrixLike> && -- SleipnirMatrixLike>) { -- constraints.emplace_back(lhs(row, col) - rhs[row, col]); -- } else if constexpr (SleipnirMatrixLike> && -- EigenMatrixLike>) { -- constraints.emplace_back(lhs[row, col] - rhs(row, col)); -- } else if constexpr (SleipnirMatrixLike> && -- SleipnirMatrixLike>) { -- constraints.emplace_back(lhs[row, col] - rhs[row, col]); -- } -+ constraints.emplace_back(lhs(row, col) - rhs(row, col)); - } +@@ -80,7 +80,7 @@ class Variable : public SleipnirBase { + * @param value The value of the Variable. + */ + // NOLINTNEXTLINE (google-explicit-constructor) +- Variable(SleipnirMatrixLike auto value) : expr{value[0, 0].expr} { ++ Variable(SleipnirMatrixLike auto value) : expr{value(0, 0).expr} { + slp_assert(value.rows() == 1 && value.cols() == 1); + } + +@@ -740,7 +740,7 @@ auto make_constraints(LHS&& lhs, RHS&& rhs) { + for (int row = 0; row < rhs.rows(); ++row) { + for (int col = 0; col < rhs.cols(); ++col) { + // Make right-hand side zero +- constraints.emplace_back(lhs - rhs[row, col]); ++ constraints.emplace_back(lhs - rhs(row, col)); } } + +@@ -756,7 +756,7 @@ auto make_constraints(LHS&& lhs, RHS&& rhs) { + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + // Make right-hand side zero +- constraints.emplace_back(lhs[row, col] - rhs); ++ constraints.emplace_back(lhs(row, col) - rhs); + } + } + +@@ -774,7 +774,7 @@ auto make_constraints(LHS&& lhs, RHS&& rhs) { + for (int row = 0; row < lhs.rows(); ++row) { + for (int col = 0; col < lhs.cols(); ++col) { + // Make right-hand side zero +- constraints.emplace_back(lhs[row, col] - rhs[row, col]); ++ constraints.emplace_back(lhs(row, col) - rhs(row, col)); + } + } + diff --git a/include/sleipnir/autodiff/variable_block.hpp b/include/sleipnir/autodiff/variable_block.hpp -index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fee39fe537 100644 +index d1b5ac928890dba3052918fc828371dedf26158d..c5351fec9f18f47e2fdfd724699036165c5b8506 100644 --- a/include/sleipnir/autodiff/variable_block.hpp +++ b/include/sleipnir/autodiff/variable_block.hpp -@@ -50,7 +50,7 @@ class VariableBlock { +@@ -57,7 +57,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -112,7 +99,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } } -@@ -85,7 +85,7 @@ class VariableBlock { +@@ -92,7 +92,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -121,7 +108,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } } -@@ -152,7 +152,7 @@ class VariableBlock { +@@ -155,7 +155,7 @@ class VariableBlock : public SleipnirBase { VariableBlock& operator=(ScalarLike auto value) { slp_assert(rows() == 1 && cols() == 1); @@ -130,8 +117,8 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe return *this; } -@@ -167,7 +167,7 @@ class VariableBlock { - void set_value(double value) { +@@ -170,7 +170,7 @@ class VariableBlock : public SleipnirBase { + void set_value(Scalar value) { slp_assert(rows() == 1 && cols() == 1); - (*this)[0, 0].set_value(value); @@ -139,25 +126,25 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } /** -@@ -182,7 +182,7 @@ class VariableBlock { +@@ -185,7 +185,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { -- (*this)[row, col] = values(row, col); +- (*this)[row, col] = values[row, col]; + (*this)(row, col) = values(row, col); } } -@@ -201,7 +201,7 @@ class VariableBlock { +@@ -204,7 +204,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { -- (*this)[row, col].set_value(values(row, col)); +- (*this)[row, col].set_value(values[row, col]); + (*this)(row, col).set_value(values(row, col)); } } } -@@ -217,7 +217,7 @@ class VariableBlock { +@@ -220,7 +220,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -166,7 +153,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } return *this; -@@ -234,7 +234,7 @@ class VariableBlock { +@@ -237,7 +237,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -175,12 +162,12 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } return *this; -@@ -247,13 +247,13 @@ class VariableBlock { +@@ -250,13 +250,13 @@ class VariableBlock : public SleipnirBase { * @param col The scalar subblock's column. * @return A scalar subblock at the given row and column. */ -- Variable& operator[](int row, int col) -+ Variable& operator()(int row, int col) +- Variable& operator[](int row, int col) ++ Variable& operator()(int row, int col) requires(!std::is_const_v) { slp_assert(row >= 0 && row < rows()); @@ -192,12 +179,12 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } /** -@@ -263,11 +263,11 @@ class VariableBlock { +@@ -266,11 +266,11 @@ class VariableBlock : public SleipnirBase { * @param col The scalar subblock's column. * @return A scalar subblock at the given row and column. */ -- const Variable& operator[](int row, int col) const { -+ const Variable& operator()(int row, int col) const { +- const Variable& operator[](int row, int col) const { ++ const Variable& operator()(int row, int col) const { slp_assert(row >= 0 && row < rows()); slp_assert(col >= 0 && col < cols()); - return (*m_mat)[m_row_slice.start + row * m_row_slice.step, @@ -207,25 +194,25 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } /** -@@ -280,7 +280,7 @@ class VariableBlock { +@@ -283,7 +283,7 @@ class VariableBlock : public SleipnirBase { requires(!std::is_const_v) { - slp_assert(row >= 0 && row < rows() * cols()); -- return (*this)[row / cols(), row % cols()]; -+ return (*this)(row / cols(), row % cols()); + slp_assert(index >= 0 && index < rows() * cols()); +- return (*this)[index / cols(), index % cols()]; ++ return (*this)(index / cols(), index % cols()); } /** -@@ -291,7 +291,7 @@ class VariableBlock { +@@ -294,7 +294,7 @@ class VariableBlock : public SleipnirBase { */ - const Variable& operator[](int row) const { - slp_assert(row >= 0 && row < rows() * cols()); -- return (*this)[row / cols(), row % cols()]; -+ return (*this)(row / cols(), row % cols()); + const Variable& operator[](int index) const { + slp_assert(index >= 0 && index < rows() * cols()); +- return (*this)[index / cols(), index % cols()]; ++ return (*this)(index / cols(), index % cols()); } /** -@@ -309,8 +309,8 @@ class VariableBlock { +@@ -312,8 +312,8 @@ class VariableBlock : public SleipnirBase { slp_assert(col_offset >= 0 && col_offset <= cols()); slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset); slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset); @@ -236,7 +223,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } /** -@@ -328,8 +328,8 @@ class VariableBlock { +@@ -331,8 +331,8 @@ class VariableBlock : public SleipnirBase { slp_assert(col_offset >= 0 && col_offset <= cols()); slp_assert(block_rows >= 0 && block_rows <= rows() - row_offset); slp_assert(block_cols >= 0 && block_cols <= cols() - col_offset); @@ -247,7 +234,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } /** -@@ -339,7 +339,7 @@ class VariableBlock { +@@ -342,10 +342,10 @@ class VariableBlock : public SleipnirBase { * @param col_slice The column slice. * @return A slice of the variable matrix. */ @@ -255,8 +242,12 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe + VariableBlock operator()(Slice row_slice, Slice col_slice) { int row_slice_length = row_slice.adjust(m_row_slice_length); int col_slice_length = col_slice.adjust(m_col_slice_length); - return VariableBlock{ -@@ -359,7 +359,7 @@ class VariableBlock { +- return (*this)[row_slice, row_slice_length, col_slice, col_slice_length]; ++ return (*this)(row_slice, row_slice_length, col_slice, col_slice_length); + } + + /** +@@ -355,11 +355,11 @@ class VariableBlock : public SleipnirBase { * @param col_slice The column slice. * @return A slice of the variable matrix. */ @@ -265,7 +256,12 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe Slice col_slice) const { int row_slice_length = row_slice.adjust(m_row_slice_length); int col_slice_length = col_slice.adjust(m_col_slice_length); -@@ -385,7 +385,7 @@ class VariableBlock { +- return (*this)[row_slice, row_slice_length, col_slice, col_slice_length]; ++ return (*this)(row_slice, row_slice_length, col_slice, col_slice_length); + } + + /** +@@ -374,7 +374,7 @@ class VariableBlock : public SleipnirBase { * @param col_slice_length The column slice length. * @return A slice of the variable matrix. */ @@ -274,7 +270,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe Slice col_slice, int col_slice_length) { return VariableBlock{ *m_mat, -@@ -409,7 +409,7 @@ class VariableBlock { +@@ -400,7 +400,7 @@ class VariableBlock : public SleipnirBase { * @param col_slice_length The column slice length. * @return A slice of the variable matrix. */ @@ -283,7 +279,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe int row_slice_length, Slice col_slice, int col_slice_length) const { -@@ -524,7 +524,7 @@ class VariableBlock { +@@ -519,7 +519,7 @@ class VariableBlock : public SleipnirBase { VariableBlock& operator*=(const ScalarLike auto& rhs) { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -292,7 +288,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -542,7 +542,7 @@ class VariableBlock { +@@ -537,7 +537,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -301,7 +297,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -558,7 +558,7 @@ class VariableBlock { +@@ -553,7 +553,7 @@ class VariableBlock : public SleipnirBase { VariableBlock& operator/=(const ScalarLike auto& rhs) { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -310,7 +306,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -576,7 +576,7 @@ class VariableBlock { +@@ -571,7 +571,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -319,7 +315,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -594,7 +594,7 @@ class VariableBlock { +@@ -589,7 +589,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -328,7 +324,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -612,7 +612,7 @@ class VariableBlock { +@@ -607,7 +607,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -337,7 +333,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -630,7 +630,7 @@ class VariableBlock { +@@ -625,7 +625,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -346,7 +342,7 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -655,7 +655,7 @@ class VariableBlock { +@@ -651,7 +651,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -355,18 +351,25 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } } -@@ -686,8 +686,8 @@ class VariableBlock { - double value(int row, int col) { - slp_assert(row >= 0 && row < rows()); - slp_assert(col >= 0 && col < cols()); -- return (*m_mat)[m_row_slice.start + row * m_row_slice.step, -- m_col_slice.start + col * m_col_slice.step] -+ return (*m_mat)(m_row_slice.start + row * m_row_slice.step, -+ m_col_slice.start + col * m_col_slice.step) - .value(); - } +@@ -679,7 +679,7 @@ class VariableBlock : public SleipnirBase { + * @param col The column of the element to return. + * @return An element of the variable matrix. + */ +- Scalar value(int row, int col) { return (*this)[row, col].value(); } ++ Scalar value(int row, int col) { return (*this)(row, col).value(); } -@@ -731,7 +731,7 @@ class VariableBlock { + /** + * Returns an element of the variable block. +@@ -703,7 +703,7 @@ class VariableBlock : public SleipnirBase { + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { +- result[row, col] = value(row, col); ++ result(row, col) = value(row, col); + } + } + +@@ -723,7 +723,7 @@ class VariableBlock : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -376,19 +379,46 @@ index f1c1ca0dc3fde663c3e74f6fca4b89b119cf377d..632d44beb5b3dae29b9829c52a6168fe } diff --git a/include/sleipnir/autodiff/variable_matrix.hpp b/include/sleipnir/autodiff/variable_matrix.hpp -index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da506e382b 100644 +index 351030b4041027ba63a2e6ec08f2077b3c35b5db..55788ce18fcfaa8631ea46b021ee867024ecddb2 100644 --- a/include/sleipnir/autodiff/variable_matrix.hpp +++ b/include/sleipnir/autodiff/variable_matrix.hpp -@@ -211,7 +211,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -174,7 +174,7 @@ class VariableMatrix : public SleipnirBase { + m_storage.reserve(values.rows() * values.cols()); + for (int row = 0; row < values.rows(); ++row) { + for (int col = 0; col < values.cols(); ++col) { +- m_storage.emplace_back(values[row, col]); ++ m_storage.emplace_back(values(row, col)); + } + } + } +@@ -232,7 +232,7 @@ class VariableMatrix : public SleipnirBase { + m_storage.reserve(rows() * cols()); + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { +- m_storage.emplace_back(values[row, col]); ++ m_storage.emplace_back(values(row, col)); + } + } + } +@@ -248,7 +248,7 @@ class VariableMatrix : public SleipnirBase { + m_storage.reserve(rows() * cols()); + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { +- m_storage.emplace_back(values[row, col]); ++ m_storage.emplace_back(values(row, col)); + } + } + } +@@ -298,7 +298,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < values.rows(); ++row) { for (int col = 0; col < values.cols(); ++col) { -- (*this)[row, col] = values(row, col); +- (*this)[row, col] = values[row, col]; + (*this)(row, col) = values(row, col); } } -@@ -229,7 +229,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -316,7 +316,7 @@ class VariableMatrix : public SleipnirBase { VariableMatrix& operator=(ScalarLike auto value) { slp_assert(rows() == 1 && cols() == 1); @@ -397,52 +427,34 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da return *this; } -@@ -246,7 +246,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -333,7 +333,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < values.rows(); ++row) { for (int col = 0; col < values.cols(); ++col) { -- (*this)[row, col].set_value(values(row, col)); +- (*this)[row, col].set_value(values[row, col]); + (*this)(row, col).set_value(values(row, col)); } } } -@@ -280,7 +280,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { - m_storage.reserve(rows() * cols()); - for (int row = 0; row < rows(); ++row) { - for (int col = 0; col < cols(); ++col) { -- m_storage.emplace_back(values[row, col]); -+ m_storage.emplace_back(values(row, col)); - } - } - } -@@ -295,7 +295,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { - m_storage.reserve(rows() * cols()); - for (int row = 0; row < rows(); ++row) { - for (int col = 0; col < cols(); ++col) { -- m_storage.emplace_back(values[row, col]); -+ m_storage.emplace_back(values(row, col)); - } - } - } -@@ -340,7 +340,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { - * @param col The block column. - * @return A block pointing to the given row and column. +@@ -345,7 +345,7 @@ class VariableMatrix : public SleipnirBase { + * @param col The column. + * @return The element at the given row and column. */ -- Variable& operator[](int row, int col) { -+ Variable& operator()(int row, int col) { +- Variable& operator[](int row, int col) { ++ Variable& operator()(int row, int col) { slp_assert(row >= 0 && row < rows()); slp_assert(col >= 0 && col < cols()); return m_storage[row * cols() + col]; -@@ -353,7 +353,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { - * @param col The block column. - * @return A block pointing to the given row and column. +@@ -358,7 +358,7 @@ class VariableMatrix : public SleipnirBase { + * @param col The column. + * @return The element at the given row and column. */ -- const Variable& operator[](int row, int col) const { -+ const Variable& operator()(int row, int col) const { +- const Variable& operator[](int row, int col) const { ++ const Variable& operator()(int row, int col) const { slp_assert(row >= 0 && row < rows()); slp_assert(col >= 0 && col < cols()); return m_storage[row * cols() + col]; -@@ -426,7 +426,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -431,7 +431,7 @@ class VariableMatrix : public SleipnirBase { * @param col_slice The column slice. * @return A slice of the variable matrix. */ @@ -451,7 +463,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da int row_slice_length = row_slice.adjust(rows()); int col_slice_length = col_slice.adjust(cols()); return VariableBlock{*this, std::move(row_slice), row_slice_length, -@@ -440,7 +440,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -445,7 +445,7 @@ class VariableMatrix : public SleipnirBase { * @param col_slice The column slice. * @return A slice of the variable matrix. */ @@ -460,7 +472,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da Slice col_slice) const { int row_slice_length = row_slice.adjust(rows()); int col_slice_length = col_slice.adjust(cols()); -@@ -461,7 +461,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -466,7 +466,7 @@ class VariableMatrix : public SleipnirBase { * @return A slice of the variable matrix. * */ @@ -469,7 +481,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da int row_slice_length, Slice col_slice, int col_slice_length) { -@@ -481,7 +481,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -486,7 +486,7 @@ class VariableMatrix : public SleipnirBase { * @param col_slice_length The column slice length. * @return A slice of the variable matrix. */ @@ -478,19 +490,35 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da Slice row_slice, int row_slice_length, Slice col_slice, int col_slice_length) const { return VariableBlock{*this, std::move(row_slice), row_slice_length, -@@ -581,17 +581,9 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -586,9 +586,9 @@ class VariableMatrix : public SleipnirBase { for (int j = 0; j < rhs.cols(); ++j) { - Variable sum; + Variable sum{Scalar(0)}; for (int k = 0; k < lhs.cols(); ++k) { -- if constexpr (SleipnirMatrixLike && SleipnirMatrixLike) { -- sum += lhs[i, k] * rhs[k, j]; -- } else if constexpr (SleipnirMatrixLike && -- EigenMatrixLike) { -- sum += lhs[i, k] * rhs(k, j); -- } else if constexpr (EigenMatrixLike && -- SleipnirMatrixLike) { -- sum += lhs(i, k) * rhs[k, j]; -- } +- sum += lhs(i, k) * rhs[k, j]; ++ sum += lhs(i, k) * rhs(k, j); + } +- result[i, j] = sum; ++ result(i, j) = sum; + } + } + +@@ -611,9 +611,9 @@ class VariableMatrix : public SleipnirBase { + for (int j = 0; j < rhs.cols(); ++j) { + Variable sum{Scalar(0)}; + for (int k = 0; k < lhs.cols(); ++k) { +- sum += lhs[i, k] * rhs(k, j); ++ sum += lhs(i, k) * rhs(k, j); + } +- result[i, j] = sum; ++ result(i, j) = sum; + } + } + +@@ -636,9 +636,9 @@ class VariableMatrix : public SleipnirBase { + for (int j = 0; j < rhs.cols(); ++j) { + Variable sum{Scalar(0)}; + for (int k = 0; k < lhs.cols(); ++k) { +- sum += lhs[i, k] * rhs[k, j]; + sum += lhs(i, k) * rhs(k, j); } - result[i, j] = sum; @@ -498,7 +526,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } #if __GNUC__ >= 12 -@@ -613,7 +605,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -661,7 +661,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { @@ -507,20 +535,16 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -632,11 +624,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -680,7 +680,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { -- if constexpr (SleipnirMatrixLike) { -- result[row, col] = lhs[row, col] * rhs; -- } else { -- result[row, col] = lhs(row, col) * rhs; -- } +- result[row, col] = lhs[row, col] * rhs; + result(row, col) = lhs(row, col) * rhs; } } -@@ -655,7 +643,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -700,7 +700,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { @@ -529,28 +553,20 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -674,11 +662,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -719,7 +719,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { -- if constexpr (SleipnirMatrixLike) { -- result[row, col] = rhs[row, col] * lhs; -- } else { -- result[row, col] = rhs(row, col) * lhs; -- } +- result[row, col] = rhs[row, col] * lhs; + result(row, col) = rhs(row, col) * lhs; } } -@@ -698,13 +682,9 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -739,9 +739,9 @@ class VariableMatrix : public SleipnirBase { for (int j = 0; j < rhs.cols(); ++j) { - Variable sum; + Variable sum{Scalar(0)}; for (int k = 0; k < cols(); ++k) { -- if constexpr (SleipnirMatrixLike) { -- sum += (*this)[i, k] * rhs[k, j]; -- } else { -- sum += (*this)[i, k] * rhs(k, j); -- } +- sum += (*this)[i, k] * rhs[k, j]; + sum += (*this)(i, k) * rhs(k, j); } - (*this)[i, j] = sum; @@ -558,7 +574,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -720,7 +700,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -757,7 +757,7 @@ class VariableMatrix : public SleipnirBase { VariableMatrix& operator*=(const ScalarLike auto& rhs) { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < rhs.cols(); ++col) { @@ -567,20 +583,34 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -740,11 +720,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -778,7 +778,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { -- if constexpr (SleipnirMatrixLike) { -- result[row, col] = lhs[row, col] / rhs; -- } else { -- result[row, col] = lhs(row, col) / rhs; -- } +- result[row, col] = lhs[row, col] / rhs; + result(row, col) = lhs(row, col) / rhs; } } -@@ -760,7 +736,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -799,7 +799,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { +- result[row, col] = lhs[row, col] / rhs; ++ result(row, col) = lhs(row, col) / rhs; + } + } + +@@ -820,7 +820,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { +- result[row, col] = lhs[row, col] / rhs; ++ result(row, col) = lhs(row, col) / rhs; + } + } + +@@ -836,7 +836,7 @@ class VariableMatrix : public SleipnirBase { VariableMatrix& operator/=(const ScalarLike auto& rhs) { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -589,35 +619,43 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -784,13 +760,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -858,7 +858,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { -- if constexpr (SleipnirMatrixLike && SleipnirMatrixLike) { -- result[row, col] = lhs[row, col] + rhs[row, col]; -- } else if constexpr (SleipnirMatrixLike && EigenMatrixLike) { -- result[row, col] = lhs[row, col] + rhs(row, col); -- } else if constexpr (EigenMatrixLike && SleipnirMatrixLike) { -- result[row, col] = lhs(row, col) + rhs[row, col]; -- } +- result[row, col] = lhs[row, col] + rhs[row, col]; + result(row, col) = lhs(row, col) + rhs(row, col); } } -@@ -808,11 +778,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -880,7 +880,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { +- result[row, col] = lhs[row, col] + rhs[row, col]; ++ result(row, col) = lhs(row, col) + rhs(row, col); + } + } + +@@ -902,7 +902,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { +- result[row, col] = lhs[row, col] + rhs[row, col]; ++ result(row, col) = lhs(row, col) + rhs(row, col); + } + } + +@@ -920,7 +920,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { -- if constexpr (SleipnirMatrixLike) { -- (*this)[row, col] += rhs[row, col]; -- } else { -- (*this)[row, col] += rhs(row, col); -- } +- (*this)[row, col] += rhs[row, col]; + (*this)(row, col) += rhs(row, col); } } -@@ -830,7 +796,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -938,7 +938,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -626,35 +664,43 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -854,13 +820,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -960,7 +960,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { -- if constexpr (SleipnirMatrixLike && SleipnirMatrixLike) { -- result[row, col] = lhs[row, col] - rhs[row, col]; -- } else if constexpr (SleipnirMatrixLike && EigenMatrixLike) { -- result[row, col] = lhs[row, col] - rhs(row, col); -- } else if constexpr (EigenMatrixLike && SleipnirMatrixLike) { -- result[row, col] = lhs(row, col) - rhs[row, col]; -- } +- result[row, col] = lhs[row, col] - rhs[row, col]; + result(row, col) = lhs(row, col) - rhs(row, col); } } -@@ -878,11 +838,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -982,7 +982,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { +- result[row, col] = lhs[row, col] - rhs[row, col]; ++ result(row, col) = lhs(row, col) - rhs(row, col); + } + } + +@@ -1004,7 +1004,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < result.rows(); ++row) { + for (int col = 0; col < result.cols(); ++col) { +- result[row, col] = lhs[row, col] - rhs[row, col]; ++ result(row, col) = lhs(row, col) - rhs(row, col); + } + } + +@@ -1022,7 +1022,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { -- if constexpr (SleipnirMatrixLike) { -- (*this)[row, col] -= rhs[row, col]; -- } else { -- (*this)[row, col] -= rhs(row, col); -- } +- (*this)[row, col] -= rhs[row, col]; + (*this)(row, col) -= rhs(row, col); } } -@@ -900,7 +856,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -1040,7 +1040,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -663,7 +709,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -918,7 +874,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -1058,7 +1058,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < result.rows(); ++row) { for (int col = 0; col < result.cols(); ++col) { @@ -672,16 +718,16 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -930,7 +886,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { - */ - operator Variable() const { // NOLINT +@@ -1071,7 +1071,7 @@ class VariableMatrix : public SleipnirBase { + // NOLINTNEXTLINE (google-explicit-constructor) + operator Variable() const { slp_assert(rows() == 1 && cols() == 1); - return (*this)[0, 0]; + return (*this)(0, 0); } /** -@@ -943,7 +899,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -1084,7 +1084,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -690,7 +736,25 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -1017,7 +973,7 @@ class SLEIPNIR_DLLEXPORT VariableMatrix { +@@ -1112,7 +1112,7 @@ class VariableMatrix : public SleipnirBase { + * @param col The column of the element to return. + * @return An element of the variable matrix. + */ +- Scalar value(int row, int col) { return (*this)[row, col].value(); } ++ Scalar value(int row, int col) { return (*this)(row, col).value(); } + + /** + * Returns an element of the variable matrix. +@@ -1133,7 +1133,7 @@ class VariableMatrix : public SleipnirBase { + + for (int row = 0; row < rows(); ++row) { + for (int col = 0; col < cols(); ++col) { +- result[row, col] = value(row, col); ++ result(row, col) = value(row, col); + } + } + +@@ -1153,7 +1153,7 @@ class VariableMatrix : public SleipnirBase { for (int row = 0; row < rows(); ++row) { for (int col = 0; col < cols(); ++col) { @@ -699,7 +763,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -@@ -1199,7 +1155,7 @@ SLEIPNIR_DLLEXPORT inline VariableMatrix cwise_reduce( +@@ -1422,7 +1422,7 @@ VariableMatrix cwise_reduce( for (int row = 0; row < lhs.rows(); ++row) { for (int col = 0; col < lhs.cols(); ++col) { @@ -708,99 +772,7 @@ index 70bccf4fc078a49e22b6699db1228c765430a121..2ed997819e70c584ce413f639826b6da } } -diff --git a/include/sleipnir/optimization/ocp.hpp b/include/sleipnir/optimization/ocp.hpp -index 124224cf5ba6e54c141086e3a21389530198449f..74492a0d756a9d587df6158c7e2ef8548ae22be4 100644 ---- a/include/sleipnir/optimization/ocp.hpp -+++ b/include/sleipnir/optimization/ocp.hpp -@@ -122,7 +122,7 @@ class SLEIPNIR_DLLEXPORT OCP : public Problem { - if (timestep_method == TimestepMethod::FIXED) { - m_DT = VariableMatrix{1, m_num_steps + 1}; - for (int i = 0; i < num_steps + 1; ++i) { -- m_DT[0, i] = dt.count(); -+ m_DT(0, i) = dt.count(); - } - } else if (timestep_method == TimestepMethod::VARIABLE_SINGLE) { - Variable single_dt = decision_variable(); -@@ -131,12 +131,12 @@ class SLEIPNIR_DLLEXPORT OCP : public Problem { - // Set the member variable matrix to track the decision variable - m_DT = VariableMatrix{1, m_num_steps + 1}; - for (int i = 0; i < num_steps + 1; ++i) { -- m_DT[0, i] = single_dt; -+ m_DT(0, i) = single_dt; - } - } else if (timestep_method == TimestepMethod::VARIABLE) { - m_DT = decision_variable(1, m_num_steps + 1); - for (int i = 0; i < num_steps + 1; ++i) { -- m_DT[0, i].set_value(dt.count()); -+ m_DT(0, i).set_value(dt.count()); - } - } - -@@ -212,7 +212,7 @@ class SLEIPNIR_DLLEXPORT OCP : public Problem { - for (int i = 0; i < m_num_steps + 1; ++i) { - auto x = X().col(i); - auto u = U().col(i); -- auto dt = this->dt()[0, i]; -+ auto dt = this->dt()(0, i); - callback(time, x, u, dt); - - time += dt; -@@ -353,7 +353,7 @@ class SLEIPNIR_DLLEXPORT OCP : public Problem { - - // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/ - for (int i = 0; i < m_num_steps; ++i) { -- Variable h = dt()[0, i]; -+ Variable h = dt()(0, i); - - auto& f = m_dynamics; - -@@ -391,7 +391,7 @@ class SLEIPNIR_DLLEXPORT OCP : public Problem { - auto x_begin = X().col(i); - auto x_end = X().col(i + 1); - auto u = U().col(i); -- Variable dt = this->dt()[0, i]; -+ Variable dt = this->dt()(0, i); - - if (m_dynamics_type == DynamicsType::EXPLICIT_ODE) { - subject_to(x_end == rk4dt()[0, i]; -+ Variable dt = this->dt()(0, i); - - if (m_dynamics_type == DynamicsType::EXPLICIT_ODE) { - x_end = rk4 solve(const VariableMatrix& A, if (A.rows() == 1 && A.cols() == 1) { // Compute optimal inverse instead of using Eigen's general solver @@ -823,7 +795,7 @@ index 6c3a040e08bdc5009885e762402a8b44434024c3..d9619a39d583e1a29c46602ba61e8815 VariableMatrix adj_A{{d, -b}, {-c, a}}; auto det_A = a * d - b * c; -@@ -39,15 +39,15 @@ VariableMatrix solve(const VariableMatrix& A, const VariableMatrix& B) { +@@ -1588,15 +1588,15 @@ VariableMatrix solve(const VariableMatrix& A, // // https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%7D%2C+%7Bd%2C+e%2C+f%7D%2C+%7Bg%2C+h%2C+i%7D%7D @@ -848,7 +820,7 @@ index 6c3a040e08bdc5009885e762402a8b44434024c3..d9619a39d583e1a29c46602ba61e8815 auto ae = a * e; auto af = a * f; -@@ -87,22 +87,22 @@ VariableMatrix solve(const VariableMatrix& A, const VariableMatrix& B) { +@@ -1636,22 +1636,22 @@ VariableMatrix solve(const VariableMatrix& A, // // https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%2C+d%7D%2C+%7Be%2C+f%2C+g%2C+h%7D%2C+%7Bi%2C+j%2C+k%2C+l%7D%2C+%7Bm%2C+n%2C+o%2C+p%7D%7D @@ -887,11 +859,11 @@ index 6c3a040e08bdc5009885e762402a8b44434024c3..d9619a39d583e1a29c46602ba61e8815 auto afk = a * f * k; auto afl = a * f * l; -@@ -232,14 +232,14 @@ VariableMatrix solve(const VariableMatrix& A, const VariableMatrix& B) { +@@ -1782,14 +1782,14 @@ VariableMatrix solve(const VariableMatrix& A, MatrixXv eigen_A{A.rows(), A.cols()}; for (int row = 0; row < A.rows(); ++row) { for (int col = 0; col < A.cols(); ++col) { -- eigen_A(row, col) = A[row, col]; +- eigen_A[row, col] = A[row, col]; + eigen_A(row, col) = A(row, col); } } @@ -899,17 +871,105 @@ index 6c3a040e08bdc5009885e762402a8b44434024c3..d9619a39d583e1a29c46602ba61e8815 MatrixXv eigen_B{B.rows(), B.cols()}; for (int row = 0; row < B.rows(); ++row) { for (int col = 0; col < B.cols(); ++col) { -- eigen_B(row, col) = B[row, col]; +- eigen_B[row, col] = B[row, col]; + eigen_B(row, col) = B(row, col); } } -@@ -248,7 +248,7 @@ VariableMatrix solve(const VariableMatrix& A, const VariableMatrix& B) { - VariableMatrix X{VariableMatrix::empty, A.cols(), B.cols()}; +@@ -1798,7 +1798,7 @@ VariableMatrix solve(const VariableMatrix& A, + VariableMatrix X{detail::empty, A.cols(), B.cols()}; for (int row = 0; row < X.rows(); ++row) { for (int col = 0; col < X.cols(); ++col) { -- X[row, col] = eigen_X(row, col); +- X[row, col] = eigen_X[row, col]; + X(row, col) = eigen_X(row, col); } } +diff --git a/include/sleipnir/optimization/ocp.hpp b/include/sleipnir/optimization/ocp.hpp +index 88316894362ff3004627308c81c8f251291eae97..d62432a67af1c75b5cc0bbab54df1d785aec2846 100644 +--- a/include/sleipnir/optimization/ocp.hpp ++++ b/include/sleipnir/optimization/ocp.hpp +@@ -125,7 +125,7 @@ class OCP : public Problem { + if (timestep_method == TimestepMethod::FIXED) { + m_DT = VariableMatrix{1, m_num_steps + 1}; + for (int i = 0; i < num_steps + 1; ++i) { +- m_DT[0, i] = dt.count(); ++ m_DT(0, i) = dt.count(); + } + } else if (timestep_method == TimestepMethod::VARIABLE_SINGLE) { + Variable single_dt = this->decision_variable(); +@@ -134,12 +134,12 @@ class OCP : public Problem { + // Set the member variable matrix to track the decision variable + m_DT = VariableMatrix{1, m_num_steps + 1}; + for (int i = 0; i < num_steps + 1; ++i) { +- m_DT[0, i] = single_dt; ++ m_DT(0, i) = single_dt; + } + } else if (timestep_method == TimestepMethod::VARIABLE) { + m_DT = this->decision_variable(1, m_num_steps + 1); + for (int i = 0; i < num_steps + 1; ++i) { +- m_DT[0, i].set_value(dt.count()); ++ m_DT(0, i).set_value(dt.count()); + } + } + +@@ -216,7 +216,7 @@ class OCP : public Problem { + for (int i = 0; i < m_num_steps + 1; ++i) { + auto x = X().col(i); + auto u = U().col(i); +- auto dt = this->dt()[0, i]; ++ auto dt = this->dt()(0, i); + callback(time, x, u, dt); + + time += dt; +@@ -358,7 +358,7 @@ class OCP : public Problem { + + // Derivation at https://mec560sbu.github.io/2016/09/30/direct_collocation/ + for (int i = 0; i < m_num_steps; ++i) { +- Variable h = dt()[0, i]; ++ Variable h = dt()(0, i); + + auto& f = m_dynamics; + +@@ -397,7 +397,7 @@ class OCP : public Problem { + auto x_begin = X().col(i); + auto x_end = X().col(i + 1); + auto u = U().col(i); +- Variable dt = this->dt()[0, i]; ++ Variable dt = this->dt()(0, i); + + if (m_dynamics_type == DynamicsType::EXPLICIT_ODE) { + this->subject_to( +@@ -422,7 +422,7 @@ class OCP : public Problem { + auto x_begin = X().col(i); + auto x_end = X().col(i + 1); + auto u = U().col(i); +- Variable dt = this->dt()[0, i]; ++ Variable dt = this->dt()(0, i); + + if (m_dynamics_type == DynamicsType::EXPLICIT_ODE) { + x_end = rk4, +diff --git a/include/sleipnir/optimization/problem.hpp b/include/sleipnir/optimization/problem.hpp +index 5256d08e5f9d8642049d8bb8323d76c7b3bbbef7..a5db8e5902e440afd9f9ee1cc44c60872db2e4c1 100644 +--- a/include/sleipnir/optimization/problem.hpp ++++ b/include/sleipnir/optimization/problem.hpp +@@ -98,7 +98,7 @@ class Problem { + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + m_decision_variables.emplace_back(); +- vars[row, col] = m_decision_variables.back(); ++ vars(row, col) = m_decision_variables.back(); + } + } + +@@ -133,8 +133,8 @@ class Problem { + for (int row = 0; row < rows; ++row) { + for (int col = 0; col <= row; ++col) { + m_decision_variables.emplace_back(); +- vars[row, col] = m_decision_variables.back(); +- vars[col, row] = m_decision_variables.back(); ++ vars(row, col) = m_decision_variables.back(); ++ vars(col, row) = m_decision_variables.back(); + } + } + diff --git a/wpimath/src/main/native/cpp/controller/ArmFeedforward.cpp b/wpimath/src/main/native/cpp/controller/ArmFeedforward.cpp index 8525923727..227b940cb8 100644 --- a/wpimath/src/main/native/cpp/controller/ArmFeedforward.cpp +++ b/wpimath/src/main/native/cpp/controller/ArmFeedforward.cpp @@ -19,7 +19,7 @@ wpi::units::volt_t ArmFeedforward::Calculate( wpi::units::unit_t currentAngle, wpi::units::unit_t currentVelocity, wpi::units::unit_t nextVelocity) const { - using VarMat = slp::VariableMatrix; + using VarMat = slp::VariableMatrix; // Small kₐ values make the solver ill-conditioned if (kA < wpi::units::unit_t{1e-1}) { @@ -40,7 +40,7 @@ wpi::units::volt_t ArmFeedforward::Calculate( Vectord<2> r_k{currentAngle.value(), currentVelocity.value()}; - slp::Variable u_k; + slp::Variable u_k; // Initial guess auto acceleration = (nextVelocity - currentVelocity) / m_dt; diff --git a/wpimath/src/main/native/cpp/geometry/Ellipse2d.cpp b/wpimath/src/main/native/cpp/geometry/Ellipse2d.cpp index fc35f501c3..cd4dd80c60 100644 --- a/wpimath/src/main/native/cpp/geometry/Ellipse2d.cpp +++ b/wpimath/src/main/native/cpp/geometry/Ellipse2d.cpp @@ -20,7 +20,7 @@ Translation2d Ellipse2d::Nearest(const Translation2d& point) const { // Find nearest point { - slp::Problem problem; + slp::Problem problem; // Point on ellipse auto x = problem.decision_variable(); diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/adjoint_expression_graph.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/adjoint_expression_graph.hpp index b333aebd3e..a77323eee9 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/adjoint_expression_graph.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/adjoint_expression_graph.hpp @@ -12,13 +12,17 @@ #include "sleipnir/autodiff/variable.hpp" #include "sleipnir/autodiff/variable_matrix.hpp" #include "sleipnir/util/assert.hpp" +#include "sleipnir/util/empty.hpp" namespace slp::detail { /** * This class is an adaptor type that performs value updates of an expression's * adjoint graph. + * + * @tparam Scalar Scalar type. */ +template class AdjointExpressionGraph { public: /** @@ -26,7 +30,7 @@ class AdjointExpressionGraph { * * @param root The root node of the expression. */ - explicit AdjointExpressionGraph(const Variable& root) + explicit AdjointExpressionGraph(const Variable& root) : m_top_list{topological_sort(root.expr)} { for (const auto& node : m_top_list) { m_col_list.emplace_back(node->col); @@ -50,18 +54,19 @@ class AdjointExpressionGraph { * @param wrt Variables with respect to which to compute the gradient. * @return The variable's gradient tree. */ - VariableMatrix generate_gradient_tree(const VariableMatrix& wrt) const { + VariableMatrix generate_gradient_tree( + const VariableMatrix& wrt) const { slp_assert(wrt.cols() == 1); // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation // for background on reverse accumulation automatic differentiation. if (m_top_list.empty()) { - return VariableMatrix{VariableMatrix::empty, wrt.rows(), 1}; + return VariableMatrix{detail::empty, wrt.rows(), 1}; } // Set root node's adjoint to 1 since df/df is 1 - m_top_list[0]->adjoint_expr = make_expression_ptr(1.0); + m_top_list[0]->adjoint_expr = constant_ptr(Scalar(1)); // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y // multiplied by dy/dx. If there are multiple "paths" from the root node to @@ -72,15 +77,19 @@ class AdjointExpressionGraph { auto& rhs = node->args[1]; if (lhs != nullptr) { - lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr); if (rhs != nullptr) { + // Binary operator + lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr); rhs->adjoint_expr += node->grad_expr_r(lhs, rhs, node->adjoint_expr); + } else { + // Unary operator + lhs->adjoint_expr += node->grad_expr_l(lhs, rhs, node->adjoint_expr); } } } // Move gradient tree to return value - VariableMatrix grad{VariableMatrix::empty, wrt.rows(), 1}; + VariableMatrix grad{detail::empty, wrt.rows(), 1}; for (int row = 0; row < grad.rows(); ++row) { grad[row] = Variable{std::move(wrt[row].expr->adjoint_expr)}; } @@ -104,16 +113,18 @@ class AdjointExpressionGraph { * @param wrt Vector of variables with respect to which to compute the * Jacobian. */ - void append_adjoint_triplets( - gch::small_vector>& triplets, int row, - const VariableMatrix& wrt) const { + void append_gradient_triplets( + gch::small_vector>& triplets, int row, + const VariableMatrix& wrt) const { + slp_assert(wrt.cols() == 1); + // Read docs/algorithms.md#Reverse_accumulation_automatic_differentiation // for background on reverse accumulation automatic differentiation. // If wrt has fewer nodes than graph, zero wrt's adjoints if (static_cast(wrt.rows()) < m_top_list.size()) { for (const auto& elem : wrt) { - elem.expr->adjoint = 0.0; + elem.expr->adjoint = Scalar(0); } } @@ -122,11 +133,11 @@ class AdjointExpressionGraph { } // Set root node's adjoint to 1 since df/df is 1 - m_top_list[0]->adjoint = 1.0; + m_top_list[0]->adjoint = Scalar(1); // Zero the rest of the adjoints for (auto& node : m_top_list | std::views::drop(1)) { - node->adjoint = 0.0; + node->adjoint = Scalar(0); } // df/dx = (df/dy)(dy/dx). The adjoint of x is equal to the adjoint of y @@ -139,10 +150,12 @@ class AdjointExpressionGraph { if (lhs != nullptr) { if (rhs != nullptr) { + // Binary operator lhs->adjoint += node->grad_l(lhs->val, rhs->val, node->adjoint); rhs->adjoint += node->grad_r(lhs->val, rhs->val, node->adjoint); } else { - lhs->adjoint += node->grad_l(lhs->val, 0.0, node->adjoint); + // Unary operator + lhs->adjoint += node->grad_l(lhs->val, Scalar(0), node->adjoint); } } } @@ -153,7 +166,7 @@ class AdjointExpressionGraph { const auto& node = wrt[col].expr; // Append adjoints of wrt to sparse matrix triplets - if (node->adjoint != 0.0) { + if (node->adjoint != Scalar(0)) { triplets.emplace_back(row, col, node->adjoint); } } @@ -163,7 +176,7 @@ class AdjointExpressionGraph { const auto& node = m_top_list[i]; // Append adjoints of wrt to sparse matrix triplets - if (col != -1 && node->adjoint != 0.0) { + if (col != -1 && node->adjoint != Scalar(0)) { triplets.emplace_back(row, col, node->adjoint); } } @@ -172,7 +185,7 @@ class AdjointExpressionGraph { private: // Topological sort of graph from parent to child - gch::small_vector m_top_list; + gch::small_vector*> m_top_list; // List that maps nodes to their respective column gch::small_vector m_col_list; diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression.hpp index 53a5f6d68d..46814576a3 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression.hpp @@ -27,15 +27,21 @@ inline constexpr bool USE_POOL_ALLOCATOR = false; inline constexpr bool USE_POOL_ALLOCATOR = true; #endif +template struct Expression; -inline constexpr void inc_ref_count(Expression* expr); -inline void dec_ref_count(Expression* expr); +template +constexpr void inc_ref_count(Expression* expr); +template +void dec_ref_count(Expression* expr); /** * Typedef for intrusive shared pointer to Expression. + * + * @tparam Scalar Scalar type. */ -using ExpressionPtr = IntrusiveSharedPtr; +template +using ExpressionPtr = IntrusiveSharedPtr>; /** * Creates an intrusive shared pointer to an expression from the global pool @@ -45,7 +51,7 @@ using ExpressionPtr = IntrusiveSharedPtr; * @param args Constructor arguments for Expression. */ template -static ExpressionPtr make_expression_ptr(Args&&... args) { +static ExpressionPtr make_expression_ptr(Args&&... args) { if constexpr (USE_POOL_ALLOCATOR) { return allocate_intrusive_shared(global_pool_allocator(), std::forward(args)...); @@ -54,32 +60,50 @@ static ExpressionPtr make_expression_ptr(Args&&... args) { } } -template +template struct BinaryMinusExpression; -template +template struct BinaryPlusExpression; +template struct ConstExpression; -template +template struct DivExpression; -template +template struct MultExpression; -template +template struct UnaryMinusExpression; /** - * An autodiff expression node. + * Creates an intrusive shared pointer to a constant expression. + * + * @tparam Scalar Scalar type. + * @param value The expression value. */ -struct Expression { - /// The value of the expression node. - double val = 0.0; +template +ExpressionPtr constant_ptr(Scalar value); - /// The adjoint of the expression node used during autodiff. - double adjoint = 0.0; +/** + * An autodiff expression node. + * + * @tparam Scalar Scalar type. + */ +template +struct Expression { + /** + * Scalar type alias. + */ + using Scalar = Scalar_; + + /// The value of the expression node. + Scalar val{0}; + + /// The adjoint of the expression node, used during autodiff. + Scalar adjoint{0}; /// Counts incoming edges for this node. uint32_t incoming_edges = 0; @@ -87,15 +111,15 @@ struct Expression { /// This expression's column in a Jacobian, or -1 otherwise. int32_t col = -1; - /// The adjoint of the expression node used during gradient expression tree + /// The adjoint of the expression node, used during gradient expression tree /// generation. - ExpressionPtr adjoint_expr; + ExpressionPtr adjoint_expr; /// Reference count for intrusive shared pointer. uint32_t ref_count = 0; /// Expression arguments. - std::array args{nullptr, nullptr}; + std::array, 2> args{nullptr, nullptr}; /** * Constructs a constant expression with a value of zero. @@ -107,14 +131,14 @@ struct Expression { * * @param value The expression value. */ - explicit constexpr Expression(double value) : val{value} {} + explicit constexpr Expression(Scalar value) : val{value} {} /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr Expression(ExpressionPtr lhs) + explicit constexpr Expression(ExpressionPtr lhs) : args{std::move(lhs), nullptr} {} /** @@ -123,7 +147,7 @@ struct Expression { * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs) + constexpr Expression(ExpressionPtr lhs, ExpressionPtr rhs) : args{std::move(lhs), std::move(rhs)} {} virtual ~Expression() = default; @@ -135,7 +159,7 @@ struct Expression { * * @return True if the expression is the given constant. */ - constexpr bool is_constant(double constant) const { + constexpr bool is_constant(Scalar constant) const { return type() == ExpressionType::CONSTANT && val == constant; } @@ -145,49 +169,49 @@ struct Expression { * @param lhs Operator left-hand side. * @param rhs Operator right-hand side. */ - friend ExpressionPtr operator*(const ExpressionPtr& lhs, - const ExpressionPtr& rhs) { + friend ExpressionPtr operator*(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { using enum ExpressionType; // Prune expression - if (lhs->is_constant(0.0)) { + if (lhs->is_constant(Scalar(0))) { // Return zero return lhs; - } else if (rhs->is_constant(0.0)) { + } else if (rhs->is_constant(Scalar(0))) { // Return zero return rhs; - } else if (lhs->is_constant(1.0)) { + } else if (lhs->is_constant(Scalar(1))) { return rhs; - } else if (rhs->is_constant(1.0)) { + } else if (rhs->is_constant(Scalar(1))) { return lhs; } // Evaluate constant if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) { - return make_expression_ptr(lhs->val * rhs->val); + return constant_ptr(lhs->val * rhs->val); } // Evaluate expression type if (lhs->type() == CONSTANT) { if (rhs->type() == LINEAR) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else if (rhs->type() == QUADRATIC) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } } else if (rhs->type() == CONSTANT) { if (lhs->type() == LINEAR) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else if (lhs->type() == QUADRATIC) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } } else if (lhs->type() == LINEAR && rhs->type() == LINEAR) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } } @@ -197,34 +221,34 @@ struct Expression { * @param lhs Operator left-hand side. * @param rhs Operator right-hand side. */ - friend ExpressionPtr operator/(const ExpressionPtr& lhs, - const ExpressionPtr& rhs) { + friend ExpressionPtr operator/(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { using enum ExpressionType; // Prune expression - if (lhs->is_constant(0.0)) { + if (lhs->is_constant(Scalar(0))) { // Return zero return lhs; - } else if (rhs->is_constant(1.0)) { + } else if (rhs->is_constant(Scalar(1))) { return lhs; } // Evaluate constant if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) { - return make_expression_ptr(lhs->val / rhs->val); + return constant_ptr(lhs->val / rhs->val); } // Evaluate expression type if (rhs->type() == CONSTANT) { if (lhs->type() == LINEAR) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else if (lhs->type() == QUADRATIC) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, rhs); } } @@ -234,29 +258,32 @@ struct Expression { * @param lhs Operator left-hand side. * @param rhs Operator right-hand side. */ - friend ExpressionPtr operator+(const ExpressionPtr& lhs, - const ExpressionPtr& rhs) { + friend ExpressionPtr operator+(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { using enum ExpressionType; // Prune expression - if (lhs == nullptr || lhs->is_constant(0.0)) { + if (lhs == nullptr || lhs->is_constant(Scalar(0))) { return rhs; - } else if (rhs == nullptr || rhs->is_constant(0.0)) { + } else if (rhs == nullptr || rhs->is_constant(Scalar(0))) { return lhs; } // Evaluate constant if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) { - return make_expression_ptr(lhs->val + rhs->val); + return constant_ptr(lhs->val + rhs->val); } auto type = std::max(lhs->type(), rhs->type()); if (type == LINEAR) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, + rhs); } else if (type == QUADRATIC) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, + rhs); } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, + rhs); } } @@ -266,8 +293,8 @@ struct Expression { * @param lhs Operator left-hand side. * @param rhs Operator right-hand side. */ - friend ExpressionPtr operator+=(ExpressionPtr& lhs, - const ExpressionPtr& rhs) { + friend ExpressionPtr operator+=(ExpressionPtr& lhs, + const ExpressionPtr& rhs) { return lhs = lhs + rhs; } @@ -277,34 +304,37 @@ struct Expression { * @param lhs Operator left-hand side. * @param rhs Operator right-hand side. */ - friend ExpressionPtr operator-(const ExpressionPtr& lhs, - const ExpressionPtr& rhs) { + friend ExpressionPtr operator-(const ExpressionPtr& lhs, + const ExpressionPtr& rhs) { using enum ExpressionType; // Prune expression - if (lhs->is_constant(0.0)) { - if (rhs->is_constant(0.0)) { + if (lhs->is_constant(Scalar(0))) { + if (rhs->is_constant(Scalar(0))) { // Return zero return rhs; } else { return -rhs; } - } else if (rhs->is_constant(0.0)) { + } else if (rhs->is_constant(Scalar(0))) { return lhs; } // Evaluate constant if (lhs->type() == CONSTANT && rhs->type() == CONSTANT) { - return make_expression_ptr(lhs->val - rhs->val); + return constant_ptr(lhs->val - rhs->val); } auto type = std::max(lhs->type(), rhs->type()); if (type == LINEAR) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, + rhs); } else if (type == QUADRATIC) { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, + rhs); } else { - return make_expression_ptr>(lhs, rhs); + return make_expression_ptr>(lhs, + rhs); } } @@ -313,26 +343,26 @@ struct Expression { * * @param lhs Operand of unary minus. */ - friend ExpressionPtr operator-(const ExpressionPtr& lhs) { + friend ExpressionPtr operator-(const ExpressionPtr& lhs) { using enum ExpressionType; // Prune expression - if (lhs->is_constant(0.0)) { + if (lhs->is_constant(Scalar(0))) { // Return zero return lhs; } // Evaluate constant if (lhs->type() == CONSTANT) { - return make_expression_ptr(-lhs->val); + return constant_ptr(-lhs->val); } if (lhs->type() == LINEAR) { - return make_expression_ptr>(lhs); + return make_expression_ptr>(lhs); } else if (lhs->type() == QUADRATIC) { - return make_expression_ptr>(lhs); + return make_expression_ptr>(lhs); } else { - return make_expression_ptr>(lhs); + return make_expression_ptr>(lhs); } } @@ -341,7 +371,9 @@ struct Expression { * * @param lhs Operand of unary plus. */ - friend ExpressionPtr operator+(const ExpressionPtr& lhs) { return lhs; } + friend ExpressionPtr operator+(const ExpressionPtr& lhs) { + return lhs; + } /** * Either nullary operator with no arguments, unary operator with one @@ -352,8 +384,8 @@ struct Expression { * @param rhs Right argument to binary operator. * @return The node's value. */ - virtual double value([[maybe_unused]] double lhs, - [[maybe_unused]] double rhs) const = 0; + virtual Scalar value([[maybe_unused]] Scalar lhs, + [[maybe_unused]] Scalar rhs) const = 0; /** * Returns the type of this expression (constant, linear, quadratic, or @@ -364,107 +396,119 @@ struct Expression { virtual ExpressionType type() const = 0; /** - * Returns double adjoint of the left child expression. + * Returns ∂/∂l as a Scalar. * * @param lhs Left argument to binary operator. * @param rhs Right argument to binary operator. * @param parent_adjoint Adjoint of parent expression. - * @return The double adjoint of the left child expression. + * @return ∂/∂l as a Scalar. */ - virtual double grad_l([[maybe_unused]] double lhs, - [[maybe_unused]] double rhs, - [[maybe_unused]] double parent_adjoint) const { - return 0.0; + virtual Scalar grad_l([[maybe_unused]] Scalar lhs, + [[maybe_unused]] Scalar rhs, + [[maybe_unused]] Scalar parent_adjoint) const { + return Scalar(0); } /** - * Returns double adjoint of the right child expression. + * Returns ∂/∂r as a Scalar. * * @param lhs Left argument to binary operator. * @param rhs Right argument to binary operator. * @param parent_adjoint Adjoint of parent expression. - * @return The double adjoint of the right child expression. + * @return ∂/∂r as a Scalar. */ - virtual double grad_r([[maybe_unused]] double lhs, - [[maybe_unused]] double rhs, - [[maybe_unused]] double parent_adjoint) const { - return 0.0; + virtual Scalar grad_r([[maybe_unused]] Scalar lhs, + [[maybe_unused]] Scalar rhs, + [[maybe_unused]] Scalar parent_adjoint) const { + return Scalar(0); } /** - * Returns Expression adjoint of the left child expression. + * Returns ∂/∂l as an Expression. * * @param lhs Left argument to binary operator. * @param rhs Right argument to binary operator. * @param parent_adjoint Adjoint of parent expression. - * @return The Expression adjoint of the left child expression. + * @return ∂/∂l as an Expression. */ - virtual ExpressionPtr grad_expr_l( - [[maybe_unused]] const ExpressionPtr& lhs, - [[maybe_unused]] const ExpressionPtr& rhs, - [[maybe_unused]] const ExpressionPtr& parent_adjoint) const { - return make_expression_ptr(); + virtual ExpressionPtr grad_expr_l( + [[maybe_unused]] const ExpressionPtr& lhs, + [[maybe_unused]] const ExpressionPtr& rhs, + [[maybe_unused]] const ExpressionPtr& parent_adjoint) const { + return constant_ptr(Scalar(0)); } /** - * Returns Expression adjoint of the right child expression. + * Returns ∂/∂r as an Expression. * * @param lhs Left argument to binary operator. * @param rhs Right argument to binary operator. * @param parent_adjoint Adjoint of parent expression. - * @return The Expression adjoint of the right child expression. + * @return ∂/∂r as an Expression. */ - virtual ExpressionPtr grad_expr_r( - [[maybe_unused]] const ExpressionPtr& lhs, - [[maybe_unused]] const ExpressionPtr& rhs, - [[maybe_unused]] const ExpressionPtr& parent_adjoint) const { - return make_expression_ptr(); + virtual ExpressionPtr grad_expr_r( + [[maybe_unused]] const ExpressionPtr& lhs, + [[maybe_unused]] const ExpressionPtr& rhs, + [[maybe_unused]] const ExpressionPtr& parent_adjoint) const { + return constant_ptr(Scalar(0)); } }; -inline ExpressionPtr cbrt(const ExpressionPtr& x); -inline ExpressionPtr exp(const ExpressionPtr& x); -inline ExpressionPtr sin(const ExpressionPtr& x); -inline ExpressionPtr sinh(const ExpressionPtr& x); -inline ExpressionPtr sqrt(const ExpressionPtr& x); +template +ExpressionPtr constant_ptr(Scalar value) { + return make_expression_ptr>(value); +} + +template +ExpressionPtr cbrt(const ExpressionPtr& x); +template +ExpressionPtr exp(const ExpressionPtr& x); +template +ExpressionPtr sin(const ExpressionPtr& x); +template +ExpressionPtr sinh(const ExpressionPtr& x); +template +ExpressionPtr sqrt(const ExpressionPtr& x); /** * Derived expression type for binary minus operator. * + * @tparam Scalar Scalar type. * @tparam T Expression type. */ -template -struct BinaryMinusExpression final : Expression { +template +struct BinaryMinusExpression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr BinaryMinusExpression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr BinaryMinusExpression(ExpressionPtr lhs, + ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double lhs, double rhs) const override { return lhs - rhs; } + Scalar value(Scalar lhs, Scalar rhs) const override { return lhs - rhs; } ExpressionType type() const override { return T; } - double grad_l(double, double, double parent_adjoint) const override { + Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override { return parent_adjoint; } - double grad_r(double, double, double parent_adjoint) const override { + Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override { return -parent_adjoint; } - ExpressionPtr grad_expr_l( - const ExpressionPtr&, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint; } - ExpressionPtr grad_expr_r( - const ExpressionPtr&, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_r( + const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return -parent_adjoint; } }; @@ -472,121 +516,136 @@ struct BinaryMinusExpression final : Expression { /** * Derived expression type for binary plus operator. * + * @tparam Scalar Scalar type. * @tparam T Expression type. */ -template -struct BinaryPlusExpression final : Expression { +template +struct BinaryPlusExpression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr BinaryPlusExpression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr BinaryPlusExpression(ExpressionPtr lhs, + ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double lhs, double rhs) const override { return lhs + rhs; } + Scalar value(Scalar lhs, Scalar rhs) const override { return lhs + rhs; } ExpressionType type() const override { return T; } - double grad_l(double, double, double parent_adjoint) const override { + Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override { return parent_adjoint; } - double grad_r(double, double, double parent_adjoint) const override { + Scalar grad_r(Scalar, Scalar, Scalar parent_adjoint) const override { return parent_adjoint; } - ExpressionPtr grad_expr_l( - const ExpressionPtr&, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint; } - ExpressionPtr grad_expr_r( - const ExpressionPtr&, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_r( + const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint; } }; /** - * Derived expression type for std::cbrt(). + * Derived expression type for cbrt(). + * + * @tparam Scalar Scalar type. */ -struct CbrtExpression final : Expression { +template +struct CbrtExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr CbrtExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr CbrtExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::cbrt(x); } + Scalar value(Scalar x, Scalar) const override { + using std::cbrt; + return cbrt(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - double c = std::cbrt(x); - return parent_adjoint / (3.0 * c * c); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::cbrt; + + Scalar c = cbrt(x); + return parent_adjoint / (Scalar(3) * c * c); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - auto c = slp::detail::cbrt(x); - return parent_adjoint / (make_expression_ptr(3.0) * c * c); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + auto c = cbrt(x); + return parent_adjoint / (constant_ptr(Scalar(3)) * c * c); } }; /** - * std::cbrt() for Expressions. + * cbrt() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr cbrt(const ExpressionPtr& x) { +template +ExpressionPtr cbrt(const ExpressionPtr& x) { using enum ExpressionType; + using std::cbrt; // Evaluate constant if (x->type() == CONSTANT) { - if (x->val == 0.0) { + if (x->val == Scalar(0)) { // Return zero return x; - } else if (x->val == -1.0 || x->val == 1.0) { + } else if (x->val == Scalar(-1) || x->val == Scalar(1)) { return x; } else { - return make_expression_ptr(std::cbrt(x->val)); + return constant_ptr(cbrt(x->val)); } } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** * Derived expression type for constant. + * + * @tparam Scalar Scalar type. */ -struct ConstExpression final : Expression { - /** - * Constructs a constant expression with a value of zero. - */ - constexpr ConstExpression() = default; - +template +struct ConstExpression final : Expression { /** * Constructs a nullary expression (an operator with no arguments). * * @param value The expression value. */ - explicit constexpr ConstExpression(double value) : Expression{value} {} + explicit constexpr ConstExpression(Scalar value) + : Expression{value} {} - double value(double, double) const override { return val; } + Scalar value(Scalar, Scalar) const override { return this->val; } ExpressionType type() const override { return ExpressionType::CONSTANT; } }; /** * Derived expression type for decision variable. + * + * @tparam Scalar Scalar type. */ -struct DecisionVariableExpression final : Expression { +template +struct DecisionVariableExpression final : Expression { /** * Constructs a decision variable expression with a value of zero. */ @@ -597,10 +656,10 @@ struct DecisionVariableExpression final : Expression { * * @param value The expression value. */ - explicit constexpr DecisionVariableExpression(double value) - : Expression{value} {} + explicit constexpr DecisionVariableExpression(Scalar value) + : Expression{value} {} - double value(double, double) const override { return val; } + Scalar value(Scalar, Scalar) const override { return this->val; } ExpressionType type() const override { return ExpressionType::LINEAR; } }; @@ -608,40 +667,41 @@ struct DecisionVariableExpression final : Expression { /** * Derived expression type for binary division operator. * + * @tparam Scalar Scalar type. * @tparam T Expression type. */ -template -struct DivExpression final : Expression { +template +struct DivExpression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr DivExpression(ExpressionPtr lhs, ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double lhs, double rhs) const override { return lhs / rhs; } + Scalar value(Scalar lhs, Scalar rhs) const override { return lhs / rhs; } ExpressionType type() const override { return T; } - double grad_l(double, double rhs, double parent_adjoint) const override { + Scalar grad_l(Scalar, Scalar rhs, Scalar parent_adjoint) const override { return parent_adjoint / rhs; }; - double grad_r(double lhs, double rhs, double parent_adjoint) const override { + Scalar grad_r(Scalar lhs, Scalar rhs, Scalar parent_adjoint) const override { return parent_adjoint * -lhs / (rhs * rhs); } - ExpressionPtr grad_expr_l( - const ExpressionPtr&, const ExpressionPtr& rhs, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr&, const ExpressionPtr& rhs, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint / rhs; } - ExpressionPtr grad_expr_r( - const ExpressionPtr& lhs, const ExpressionPtr& rhs, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_r( + const ExpressionPtr& lhs, const ExpressionPtr& rhs, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint * -lhs / (rhs * rhs); } }; @@ -649,42 +709,45 @@ struct DivExpression final : Expression { /** * Derived expression type for binary multiplication operator. * + * @tparam Scalar Scalar type. * @tparam T Expression type. */ -template -struct MultExpression final : Expression { +template +struct MultExpression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr MultExpression(ExpressionPtr lhs, ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double lhs, double rhs) const override { return lhs * rhs; } + Scalar value(Scalar lhs, Scalar rhs) const override { return lhs * rhs; } ExpressionType type() const override { return T; } - double grad_l([[maybe_unused]] double lhs, double rhs, - double parent_adjoint) const override { + Scalar grad_l([[maybe_unused]] Scalar lhs, Scalar rhs, + Scalar parent_adjoint) const override { return parent_adjoint * rhs; } - double grad_r(double lhs, [[maybe_unused]] double rhs, - double parent_adjoint) const override { + Scalar grad_r(Scalar lhs, [[maybe_unused]] Scalar rhs, + Scalar parent_adjoint) const override { return parent_adjoint * lhs; } - ExpressionPtr grad_expr_l( - [[maybe_unused]] const ExpressionPtr& lhs, const ExpressionPtr& rhs, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + [[maybe_unused]] const ExpressionPtr& lhs, + const ExpressionPtr& rhs, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint * rhs; } - ExpressionPtr grad_expr_r( - const ExpressionPtr& lhs, [[maybe_unused]] const ExpressionPtr& rhs, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_r( + const ExpressionPtr& lhs, + [[maybe_unused]] const ExpressionPtr& rhs, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint * lhs; } }; @@ -692,29 +755,30 @@ struct MultExpression final : Expression { /** * Derived expression type for unary minus operator. * + * @tparam Scalar Scalar type. * @tparam T Expression type. */ -template -struct UnaryMinusExpression final : Expression { +template +struct UnaryMinusExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr UnaryMinusExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr UnaryMinusExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double lhs, double) const override { return -lhs; } + Scalar value(Scalar lhs, Scalar) const override { return -lhs; } ExpressionType type() const override { return T; } - double grad_l(double, double, double parent_adjoint) const override { + Scalar grad_l(Scalar, Scalar, Scalar parent_adjoint) const override { return -parent_adjoint; } - ExpressionPtr grad_expr_l( - const ExpressionPtr&, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr&, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return -parent_adjoint; } }; @@ -722,23 +786,27 @@ struct UnaryMinusExpression final : Expression { /** * Refcount increment for intrusive shared pointer. * + * @tparam Scalar Scalar type. * @param expr The shared pointer's managed object. */ -inline constexpr void inc_ref_count(Expression* expr) { +template +constexpr void inc_ref_count(Expression* expr) { ++expr->ref_count; } /** * Refcount decrement for intrusive shared pointer. * + * @tparam Scalar Scalar type. * @param expr The shared pointer's managed object. */ -inline void dec_ref_count(Expression* expr) { +template +void dec_ref_count(Expression* expr) { // If a deeply nested tree is being deallocated all at once, calling the // Expression destructor when expr's refcount reaches zero can cause a stack // overflow. Instead, we iterate over its children to decrement their // refcounts and deallocate them. - gch::small_vector stack; + gch::small_vector*> stack; stack.emplace_back(expr); while (!stack.empty()) { @@ -760,1052 +828,1225 @@ inline void dec_ref_count(Expression* expr) { // Not calling the destructor here is safe because it only decrements // refcounts, which was already done above. if constexpr (USE_POOL_ALLOCATOR) { - auto alloc = global_pool_allocator(); - std::allocator_traits::deallocate(alloc, elem, - sizeof(Expression)); + auto alloc = global_pool_allocator>(); + std::allocator_traits::deallocate( + alloc, elem, sizeof(Expression)); } } } } /** - * Derived expression type for std::abs(). + * Derived expression type for abs(). + * + * @tparam Scalar Scalar type. */ -struct AbsExpression final : Expression { +template +struct AbsExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr AbsExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr AbsExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::abs(x); } + Scalar value(Scalar x, Scalar) const override { + using std::abs; + return abs(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - if (x < 0.0) { + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + if (x < Scalar(0)) { return -parent_adjoint; - } else if (x > 0.0) { + } else if (x > Scalar(0)) { return parent_adjoint; } else { - return 0.0; + return Scalar(0); } } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - if (x->val < 0.0) { + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + if (x->val < Scalar(0)) { return -parent_adjoint; - } else if (x->val > 0.0) { + } else if (x->val > Scalar(0)) { return parent_adjoint; } else { - // Return zero - return make_expression_ptr(); + return constant_ptr(Scalar(0)); } } }; /** - * std::abs() for Expressions. + * abs() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr abs(const ExpressionPtr& x) { +template +ExpressionPtr abs(const ExpressionPtr& x) { using enum ExpressionType; + using std::abs; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::abs(x->val)); + return constant_ptr(abs(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::acos(). + * Derived expression type for acos(). + * + * @tparam Scalar Scalar type. */ -struct AcosExpression final : Expression { +template +struct AcosExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr AcosExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr AcosExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::acos(x); } + Scalar value(Scalar x, Scalar) const override { + using std::acos; + return acos(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return -parent_adjoint / std::sqrt(1.0 - x * x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::sqrt; + return -parent_adjoint / sqrt(Scalar(1) - x * x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return -parent_adjoint / - slp::detail::sqrt(make_expression_ptr(1.0) - x * x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return -parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x); } }; /** - * std::acos() for Expressions. + * acos() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr acos(const ExpressionPtr& x) { +template +ExpressionPtr acos(const ExpressionPtr& x) { using enum ExpressionType; + using std::acos; // Prune expression - if (x->is_constant(0.0)) { - return make_expression_ptr(std::numbers::pi / 2.0); + if (x->is_constant(Scalar(0))) { + return constant_ptr(Scalar(std::numbers::pi) / Scalar(2)); } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::acos(x->val)); + return constant_ptr(acos(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::asin(). + * Derived expression type for asin(). + * + * @tparam Scalar Scalar type. */ -struct AsinExpression final : Expression { +template +struct AsinExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr AsinExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr AsinExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::asin(x); } + Scalar value(Scalar x, Scalar) const override { + using std::asin; + return asin(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint / std::sqrt(1.0 - x * x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::sqrt; + return parent_adjoint / sqrt(Scalar(1) - x * x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint / - slp::detail::sqrt(make_expression_ptr(1.0) - x * x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint / sqrt(constant_ptr(Scalar(1)) - x * x); } }; /** - * std::asin() for Expressions. + * asin() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr asin(const ExpressionPtr& x) { +template +ExpressionPtr asin(const ExpressionPtr& x) { using enum ExpressionType; + using std::asin; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::asin(x->val)); + return constant_ptr(asin(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::atan(). + * Derived expression type for atan(). + * + * @tparam Scalar Scalar type. */ -struct AtanExpression final : Expression { +template +struct AtanExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr AtanExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr AtanExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::atan(x); } + Scalar value(Scalar x, Scalar) const override { + using std::atan; + return atan(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint / (1.0 + x * x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + return parent_adjoint / (Scalar(1) + x * x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint / (make_expression_ptr(1.0) + x * x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint / (constant_ptr(Scalar(1)) + x * x); } }; /** - * std::atan() for Expressions. + * atan() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr atan(const ExpressionPtr& x) { +template +ExpressionPtr atan(const ExpressionPtr& x) { using enum ExpressionType; + using std::atan; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::atan(x->val)); + return constant_ptr(atan(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::atan2(). + * Derived expression type for atan2(). + * + * @tparam Scalar Scalar type. */ -struct Atan2Expression final : Expression { +template +struct Atan2Expression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr Atan2Expression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr Atan2Expression(ExpressionPtr lhs, + ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double y, double x) const override { return std::atan2(y, x); } + Scalar value(Scalar y, Scalar x) const override { + using std::atan2; + return atan2(y, x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double y, double x, double parent_adjoint) const override { + Scalar grad_l(Scalar y, Scalar x, Scalar parent_adjoint) const override { return parent_adjoint * x / (y * y + x * x); } - double grad_r(double y, double x, double parent_adjoint) const override { + Scalar grad_r(Scalar y, Scalar x, Scalar parent_adjoint) const override { return parent_adjoint * -y / (y * y + x * x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& y, const ExpressionPtr& x, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr& y, const ExpressionPtr& x, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint * x / (y * y + x * x); } - ExpressionPtr grad_expr_r( - const ExpressionPtr& y, const ExpressionPtr& x, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_r( + const ExpressionPtr& y, const ExpressionPtr& x, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint * -y / (y * y + x * x); } }; /** - * std::atan2() for Expressions. + * atan2() for Expressions. * + * @tparam Scalar Scalar type. * @param y The y argument. * @param x The x argument. */ -inline ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) { +template +ExpressionPtr atan2(const ExpressionPtr& y, + const ExpressionPtr& x) { using enum ExpressionType; + using std::atan2; // Prune expression - if (y->is_constant(0.0)) { + if (y->is_constant(Scalar(0))) { // Return zero return y; - } else if (x->is_constant(0.0)) { - return make_expression_ptr(std::numbers::pi / 2.0); + } else if (x->is_constant(Scalar(0))) { + return constant_ptr(Scalar(std::numbers::pi) / Scalar(2)); } // Evaluate constant if (y->type() == CONSTANT && x->type() == CONSTANT) { - return make_expression_ptr(std::atan2(y->val, x->val)); + return constant_ptr(atan2(y->val, x->val)); } - return make_expression_ptr(y, x); + return make_expression_ptr>(y, x); } /** - * Derived expression type for std::cos(). + * Derived expression type for cos(). + * + * @tparam Scalar Scalar type. */ -struct CosExpression final : Expression { +template +struct CosExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr CosExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr CosExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::cos(x); } + Scalar value(Scalar x, Scalar) const override { + using std::cos; + return cos(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return -parent_adjoint * std::sin(x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::sin; + return parent_adjoint * -sin(x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * -slp::detail::sin(x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * -sin(x); } }; /** - * std::cos() for Expressions. + * cos() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr cos(const ExpressionPtr& x) { +template +ExpressionPtr cos(const ExpressionPtr& x) { using enum ExpressionType; + using std::cos; // Prune expression - if (x->is_constant(0.0)) { - return make_expression_ptr(1.0); + if (x->is_constant(Scalar(0))) { + return constant_ptr(Scalar(1)); } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::cos(x->val)); + return constant_ptr(cos(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::cosh(). + * Derived expression type for cosh(). + * + * @tparam Scalar Scalar type. */ -struct CoshExpression final : Expression { +template +struct CoshExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr CoshExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr CoshExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::cosh(x); } + Scalar value(Scalar x, Scalar) const override { + using std::cosh; + return cosh(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint * std::sinh(x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::sinh; + return parent_adjoint * sinh(x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * slp::detail::sinh(x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * sinh(x); } }; /** - * std::cosh() for Expressions. + * cosh() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr cosh(const ExpressionPtr& x) { +template +ExpressionPtr cosh(const ExpressionPtr& x) { using enum ExpressionType; + using std::cosh; // Prune expression - if (x->is_constant(0.0)) { - return make_expression_ptr(1.0); + if (x->is_constant(Scalar(0))) { + return constant_ptr(Scalar(1)); } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::cosh(x->val)); + return constant_ptr(cosh(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::erf(). + * Derived expression type for erf(). + * + * @tparam Scalar Scalar type. */ -struct ErfExpression final : Expression { +template +struct ErfExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr ErfExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr ErfExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::erf(x); } + Scalar value(Scalar x, Scalar) const override { + using std::erf; + return erf(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint * 2.0 * std::numbers::inv_sqrtpi * std::exp(-x * x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::exp; + return parent_adjoint * Scalar(2.0 * std::numbers::inv_sqrtpi) * + exp(-x * x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint * - make_expression_ptr(2.0 * - std::numbers::inv_sqrtpi) * - slp::detail::exp(-x * x); + constant_ptr(Scalar(2.0 * std::numbers::inv_sqrtpi)) * exp(-x * x); } }; /** - * std::erf() for Expressions. + * erf() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr erf(const ExpressionPtr& x) { +template +ExpressionPtr erf(const ExpressionPtr& x) { using enum ExpressionType; + using std::erf; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::erf(x->val)); + return constant_ptr(erf(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::exp(). + * Derived expression type for exp(). + * + * @tparam Scalar Scalar type. */ -struct ExpExpression final : Expression { +template +struct ExpExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr ExpExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr ExpExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::exp(x); } + Scalar value(Scalar x, Scalar) const override { + using std::exp; + return exp(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint * std::exp(x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::exp; + return parent_adjoint * exp(x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * slp::detail::exp(x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * exp(x); } }; /** - * std::exp() for Expressions. + * exp() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr exp(const ExpressionPtr& x) { +template +ExpressionPtr exp(const ExpressionPtr& x) { using enum ExpressionType; + using std::exp; // Prune expression - if (x->is_constant(0.0)) { - return make_expression_ptr(1.0); + if (x->is_constant(Scalar(0))) { + return constant_ptr(Scalar(1)); } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::exp(x->val)); + return constant_ptr(exp(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } -inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y); +template +ExpressionPtr hypot(const ExpressionPtr& x, + const ExpressionPtr& y); /** - * Derived expression type for std::hypot(). + * Derived expression type for hypot(). + * + * @tparam Scalar Scalar type. */ -struct HypotExpression final : Expression { +template +struct HypotExpression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr HypotExpression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr HypotExpression(ExpressionPtr lhs, + ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double x, double y) const override { return std::hypot(x, y); } + Scalar value(Scalar x, Scalar y) const override { + using std::hypot; + return hypot(x, y); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double y, double parent_adjoint) const override { - return parent_adjoint * x / std::hypot(x, y); + Scalar grad_l(Scalar x, Scalar y, Scalar parent_adjoint) const override { + using std::hypot; + return parent_adjoint * x / hypot(x, y); } - double grad_r(double x, double y, double parent_adjoint) const override { - return parent_adjoint * y / std::hypot(x, y); + Scalar grad_r(Scalar x, Scalar y, Scalar parent_adjoint) const override { + using std::hypot; + return parent_adjoint * y / hypot(x, y); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr& y, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * x / slp::detail::hypot(x, y); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr& y, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * x / hypot(x, y); } - ExpressionPtr grad_expr_r( - const ExpressionPtr& x, const ExpressionPtr& y, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * y / slp::detail::hypot(x, y); + ExpressionPtr grad_expr_r( + const ExpressionPtr& x, const ExpressionPtr& y, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * y / hypot(x, y); } }; /** - * std::hypot() for Expressions. + * hypot() for Expressions. * + * @tparam Scalar Scalar type. * @param x The x argument. * @param y The y argument. */ -inline ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) { +template +ExpressionPtr hypot(const ExpressionPtr& x, + const ExpressionPtr& y) { using enum ExpressionType; + using std::hypot; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { return y; - } else if (y->is_constant(0.0)) { + } else if (y->is_constant(Scalar(0))) { return x; } // Evaluate constant if (x->type() == CONSTANT && y->type() == CONSTANT) { - return make_expression_ptr(std::hypot(x->val, y->val)); + return constant_ptr(hypot(x->val, y->val)); } - return make_expression_ptr(x, y); + return make_expression_ptr>(x, y); } /** - * Derived expression type for std::log(). + * Derived expression type for log(). + * + * @tparam Scalar Scalar type. */ -struct LogExpression final : Expression { +template +struct LogExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr LogExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr LogExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::log(x); } + Scalar value(Scalar x, Scalar) const override { + using std::log; + return log(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { return parent_adjoint / x; } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { return parent_adjoint / x; } }; /** - * std::log() for Expressions. + * log() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr log(const ExpressionPtr& x) { +template +ExpressionPtr log(const ExpressionPtr& x) { using enum ExpressionType; + using std::log; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::log(x->val)); + return constant_ptr(log(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::log10(). + * Derived expression type for log10(). + * + * @tparam Scalar Scalar type. */ -struct Log10Expression final : Expression { +template +struct Log10Expression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr Log10Expression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr Log10Expression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::log10(x); } + Scalar value(Scalar x, Scalar) const override { + using std::log10; + return log10(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint / (std::numbers::ln10 * x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + return parent_adjoint / (Scalar(std::numbers::ln10) * x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint / - (make_expression_ptr(std::numbers::ln10) * x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint / (constant_ptr(Scalar(std::numbers::ln10)) * x); } }; /** - * std::log10() for Expressions. + * log10() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr log10(const ExpressionPtr& x) { +template +ExpressionPtr log10(const ExpressionPtr& x) { using enum ExpressionType; + using std::log10; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::log10(x->val)); + return constant_ptr(log10(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } -inline ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& power); +template +ExpressionPtr pow(const ExpressionPtr& base, + const ExpressionPtr& power); /** - * Derived expression type for std::pow(). + * Derived expression type for pow(). * - * @tparam Expression type. + * @tparam Scalar Scalar type. + * @tparam T Expression type. */ -template -struct PowExpression final : Expression { +template +struct PowExpression final : Expression { /** * Constructs a binary expression (an operator with two arguments). * * @param lhs Binary operator's left operand. * @param rhs Binary operator's right operand. */ - constexpr PowExpression(ExpressionPtr lhs, ExpressionPtr rhs) - : Expression{std::move(lhs), std::move(rhs)} {} + constexpr PowExpression(ExpressionPtr lhs, ExpressionPtr rhs) + : Expression{std::move(lhs), std::move(rhs)} {} - double value(double base, double power) const override { - return std::pow(base, power); + Scalar value(Scalar base, Scalar power) const override { + using std::pow; + return pow(base, power); } ExpressionType type() const override { return T; } - double grad_l(double base, double power, - double parent_adjoint) const override { - return parent_adjoint * std::pow(base, power - 1) * power; + Scalar grad_l(Scalar base, Scalar power, + Scalar parent_adjoint) const override { + using std::pow; + return parent_adjoint * pow(base, power - Scalar(1)) * power; } - double grad_r(double base, double power, - double parent_adjoint) const override { - // Since x * std::log(x) -> 0 as x -> 0 - if (base == 0.0) { - return 0.0; + Scalar grad_r(Scalar base, Scalar power, + Scalar parent_adjoint) const override { + using std::log; + using std::pow; + + // Since x log(x) -> 0 as x -> 0 + if (base == Scalar(0)) { + return Scalar(0); } else { - return parent_adjoint * std::pow(base, power - 1) * base * std::log(base); + return parent_adjoint * pow(base, power) * log(base); } } - ExpressionPtr grad_expr_l( - const ExpressionPtr& base, const ExpressionPtr& power, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * - slp::detail::pow(base, - power - make_expression_ptr(1.0)) * - power; + ExpressionPtr grad_expr_l( + const ExpressionPtr& base, const ExpressionPtr& power, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * pow(base, power - constant_ptr(Scalar(1))) * power; } - ExpressionPtr grad_expr_r( - const ExpressionPtr& base, const ExpressionPtr& power, - const ExpressionPtr& parent_adjoint) const override { - // Since x * std::log(x) -> 0 as x -> 0 - if (base->val == 0.0) { + ExpressionPtr grad_expr_r( + const ExpressionPtr& base, const ExpressionPtr& power, + const ExpressionPtr& parent_adjoint) const override { + // Since x log(x) -> 0 as x -> 0 + if (base->val == Scalar(0)) { // Return zero return base; } else { - return parent_adjoint * - slp::detail::pow( - base, power - make_expression_ptr(1.0)) * - base * slp::detail::log(base); + return parent_adjoint * pow(base, power) * log(base); } } }; /** - * std::pow() for Expressions. + * pow() for Expressions. * + * @tparam Scalar Scalar type. * @param base The base. * @param power The power. */ -inline ExpressionPtr pow(const ExpressionPtr& base, - const ExpressionPtr& power) { +template +ExpressionPtr pow(const ExpressionPtr& base, + const ExpressionPtr& power) { using enum ExpressionType; + using std::pow; // Prune expression - if (base->is_constant(0.0)) { + if (base->is_constant(Scalar(0))) { // Return zero return base; - } else if (base->is_constant(1.0)) { + } else if (base->is_constant(Scalar(1))) { // Return one return base; } - if (power->is_constant(0.0)) { - return make_expression_ptr(1.0); - } else if (power->is_constant(1.0)) { + if (power->is_constant(Scalar(0))) { + return constant_ptr(Scalar(1)); + } else if (power->is_constant(Scalar(1))) { return base; } // Evaluate constant if (base->type() == CONSTANT && power->type() == CONSTANT) { - return make_expression_ptr( - std::pow(base->val, power->val)); + return constant_ptr(pow(base->val, power->val)); } - if (power->is_constant(2.0)) { + if (power->is_constant(Scalar(2))) { if (base->type() == LINEAR) { - return make_expression_ptr>(base, base); + return make_expression_ptr>(base, base); } else { - return make_expression_ptr>(base, base); + return make_expression_ptr>(base, base); } } - return make_expression_ptr>(base, power); + return make_expression_ptr>(base, power); } /** * Derived expression type for sign(). + * + * @tparam Scalar Scalar type. */ -struct SignExpression final : Expression { +template +struct SignExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr SignExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr SignExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { - if (x < 0.0) { - return -1.0; - } else if (x == 0.0) { - return 0.0; + Scalar value(Scalar x, Scalar) const override { + if (x < Scalar(0)) { + return Scalar(-1); + } else if (x == Scalar(0)) { + return Scalar(0); } else { - return 1.0; + return Scalar(1); } } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - - double grad_l(double, double, double) const override { return 0.0; } - - ExpressionPtr grad_expr_l(const ExpressionPtr&, const ExpressionPtr&, - const ExpressionPtr&) const override { - // Return zero - return make_expression_ptr(); - } }; /** * sign() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr sign(const ExpressionPtr& x) { +template +ExpressionPtr sign(const ExpressionPtr& x) { using enum ExpressionType; // Evaluate constant if (x->type() == CONSTANT) { - if (x->val < 0.0) { - return make_expression_ptr(-1.0); - } else if (x->val == 0.0) { + if (x->val < Scalar(0)) { + return constant_ptr(Scalar(-1)); + } else if (x->val == Scalar(0)) { // Return zero return x; } else { - return make_expression_ptr(1.0); + return constant_ptr(Scalar(1)); } } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::sin(). + * Derived expression type for sin(). + * + * @tparam Scalar Scalar type. */ -struct SinExpression final : Expression { +template +struct SinExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr SinExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr SinExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::sin(x); } + Scalar value(Scalar x, Scalar) const override { + using std::sin; + return sin(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint * std::cos(x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::cos; + return parent_adjoint * cos(x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * slp::detail::cos(x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * cos(x); } }; /** - * std::sin() for Expressions. + * sin() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr sin(const ExpressionPtr& x) { +template +ExpressionPtr sin(const ExpressionPtr& x) { using enum ExpressionType; + using std::sin; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::sin(x->val)); + return constant_ptr(sin(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::sinh(). + * Derived expression type for sinh(). + * + * @tparam Scalar Scalar type. */ -struct SinhExpression final : Expression { +template +struct SinhExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr SinhExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr SinhExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::sinh(x); } + Scalar value(Scalar x, Scalar) const override { + using std::sinh; + return sinh(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint * std::cosh(x); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::cosh; + return parent_adjoint * cosh(x); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint * slp::detail::cosh(x); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint * cosh(x); } }; /** - * std::sinh() for Expressions. + * sinh() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr sinh(const ExpressionPtr& x) { +template +ExpressionPtr sinh(const ExpressionPtr& x) { using enum ExpressionType; + using std::sinh; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::sinh(x->val)); + return constant_ptr(sinh(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::sqrt(). + * Derived expression type for sqrt(). + * + * @tparam Scalar Scalar type. */ -struct SqrtExpression final : Expression { +template +struct SqrtExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr SqrtExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr SqrtExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::sqrt(x); } + Scalar value(Scalar x, Scalar) const override { + using std::sqrt; + return sqrt(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint / (2.0 * std::sqrt(x)); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::sqrt; + return parent_adjoint / (Scalar(2) * sqrt(x)); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint / - (make_expression_ptr(2.0) * slp::detail::sqrt(x)); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + return parent_adjoint / (constant_ptr(Scalar(2)) * sqrt(x)); } }; /** - * std::sqrt() for Expressions. + * sqrt() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr sqrt(const ExpressionPtr& x) { +template +ExpressionPtr sqrt(const ExpressionPtr& x) { using enum ExpressionType; + using std::sqrt; // Evaluate constant if (x->type() == CONSTANT) { - if (x->val == 0.0) { + if (x->val == Scalar(0)) { // Return zero return x; - } else if (x->val == 1.0) { + } else if (x->val == Scalar(1)) { return x; } else { - return make_expression_ptr(std::sqrt(x->val)); + return constant_ptr(sqrt(x->val)); } } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::tan(). + * Derived expression type for tan(). + * + * @tparam Scalar Scalar type. */ -struct TanExpression final : Expression { +template +struct TanExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr TanExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr TanExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::tan(x); } + Scalar value(Scalar x, Scalar) const override { + using std::tan; + return tan(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint / (std::cos(x) * std::cos(x)); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::cos; + + auto c = cos(x); + return parent_adjoint / (c * c); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint / (slp::detail::cos(x) * slp::detail::cos(x)); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + auto c = cos(x); + return parent_adjoint / (c * c); } }; /** - * std::tan() for Expressions. + * tan() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr tan(const ExpressionPtr& x) { +template +ExpressionPtr tan(const ExpressionPtr& x) { using enum ExpressionType; + using std::tan; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::tan(x->val)); + return constant_ptr(tan(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } /** - * Derived expression type for std::tanh(). + * Derived expression type for tanh(). + * + * @tparam Scalar Scalar type. */ -struct TanhExpression final : Expression { +template +struct TanhExpression final : Expression { /** * Constructs an unary expression (an operator with one argument). * * @param lhs Unary operator's operand. */ - explicit constexpr TanhExpression(ExpressionPtr lhs) - : Expression{std::move(lhs)} {} + explicit constexpr TanhExpression(ExpressionPtr lhs) + : Expression{std::move(lhs)} {} - double value(double x, double) const override { return std::tanh(x); } + Scalar value(Scalar x, Scalar) const override { + using std::tanh; + return tanh(x); + } ExpressionType type() const override { return ExpressionType::NONLINEAR; } - double grad_l(double x, double, double parent_adjoint) const override { - return parent_adjoint / (std::cosh(x) * std::cosh(x)); + Scalar grad_l(Scalar x, Scalar, Scalar parent_adjoint) const override { + using std::cosh; + + auto c = cosh(x); + return parent_adjoint / (c * c); } - ExpressionPtr grad_expr_l( - const ExpressionPtr& x, const ExpressionPtr&, - const ExpressionPtr& parent_adjoint) const override { - return parent_adjoint / (slp::detail::cosh(x) * slp::detail::cosh(x)); + ExpressionPtr grad_expr_l( + const ExpressionPtr& x, const ExpressionPtr&, + const ExpressionPtr& parent_adjoint) const override { + auto c = cosh(x); + return parent_adjoint / (c * c); } }; /** - * std::tanh() for Expressions. + * tanh() for Expressions. * + * @tparam Scalar Scalar type. * @param x The argument. */ -inline ExpressionPtr tanh(const ExpressionPtr& x) { +template +ExpressionPtr tanh(const ExpressionPtr& x) { using enum ExpressionType; + using std::tanh; // Prune expression - if (x->is_constant(0.0)) { + if (x->is_constant(Scalar(0))) { // Return zero return x; } // Evaluate constant if (x->type() == CONSTANT) { - return make_expression_ptr(std::tanh(x->val)); + return constant_ptr(tanh(x->val)); } - return make_expression_ptr(x); + return make_expression_ptr>(x); } } // namespace slp::detail diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression_graph.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression_graph.hpp index 1e32cb07c2..df3d93e9ad 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression_graph.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/expression_graph.hpp @@ -15,11 +15,13 @@ namespace slp::detail { * * https://en.wikipedia.org/wiki/Topological_sorting * + * @tparam Scalar Scalar type. * @param root The root node of the expression. */ -inline gch::small_vector topological_sort( - const ExpressionPtr& root) { - gch::small_vector list; +template +gch::small_vector*> topological_sort( + const ExpressionPtr& root) { + gch::small_vector*> list; // If the root type is constant, updates are a no-op, so return an empty list if (root == nullptr || root->type() == ExpressionType::CONSTANT) { @@ -27,7 +29,7 @@ inline gch::small_vector topological_sort( } // Stack of nodes to explore - gch::small_vector stack; + gch::small_vector*> stack; // Enumerate incoming edges for each node via depth-first search stack.emplace_back(root.get()); @@ -72,20 +74,18 @@ inline gch::small_vector topological_sort( * Update the values of all nodes in this graph based on the values of * their dependent nodes. * + * @tparam Scalar Scalar type. * @param list Topological sort of graph from parent to child. */ -inline void update_values(const gch::small_vector& list) { +template +void update_values(const gch::small_vector*>& list) { // Traverse graph from child to parent and update values for (auto& node : list | std::views::reverse) { auto& lhs = node->args[0]; auto& rhs = node->args[1]; if (lhs != nullptr) { - if (rhs != nullptr) { - node->val = node->value(lhs->val, rhs->val); - } else { - node->val = node->value(lhs->val, 0.0); - } + node->val = node->value(lhs->val, rhs ? rhs->val : Scalar(0)); } } } diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/gradient.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/gradient.hpp index e9944652d8..80720accf8 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/gradient.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/gradient.hpp @@ -20,8 +20,11 @@ namespace slp { * * The gradient is only recomputed if the variable expression is quadratic or * higher order. + * + * @tparam Scalar Scalar type. */ -class SLEIPNIR_DLLEXPORT Gradient { +template +class Gradient { public: /** * Constructs a Gradient object. @@ -29,7 +32,7 @@ class SLEIPNIR_DLLEXPORT Gradient { * @param variable Variable of which to compute the gradient. * @param wrt Variable with respect to which to compute the gradient. */ - Gradient(Variable variable, Variable wrt) + Gradient(Variable variable, Variable wrt) : m_jacobian{std::move(variable), std::move(wrt)} {} /** @@ -39,7 +42,7 @@ class SLEIPNIR_DLLEXPORT Gradient { * @param wrt Vector of variables with respect to which to compute the * gradient. */ - Gradient(Variable variable, SleipnirMatrixLike auto wrt) + Gradient(Variable variable, SleipnirMatrixLike auto wrt) : m_jacobian{VariableMatrix{std::move(variable)}, std::move(wrt)} {} /** @@ -50,23 +53,26 @@ class SLEIPNIR_DLLEXPORT Gradient { * * @return The gradient as a VariableMatrix. */ - VariableMatrix get() const { return m_jacobian.get().T(); } + VariableMatrix get() const { return m_jacobian.get().T(); } /** * Evaluates the gradient at wrt's value. * * @return The gradient at wrt's value. */ - const Eigen::SparseVector& value() { + const Eigen::SparseVector& value() { m_g = m_jacobian.value().transpose(); return m_g; } private: - Eigen::SparseVector m_g; + Eigen::SparseVector m_g; - Jacobian m_jacobian; + Jacobian m_jacobian; }; +extern template class EXPORT_TEMPLATE_DECLARE(SLEIPNIR_DLLEXPORT) +Gradient; + } // namespace slp diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/hessian.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/hessian.hpp index 4f093b7b39..10ee142ff8 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/hessian.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/hessian.hpp @@ -12,7 +12,6 @@ #include "sleipnir/autodiff/variable_matrix.hpp" #include "sleipnir/util/assert.hpp" #include "sleipnir/util/concepts.hpp" -#include "sleipnir/util/symbol_exports.hpp" namespace slp { @@ -23,11 +22,12 @@ namespace slp { * The gradient tree is cached so subsequent Hessian calculations are faster, * and the Hessian is only recomputed if the variable expression is nonlinear. * + * @tparam Scalar Scalar type. * @tparam UpLo Which part of the Hessian to compute (Lower or Lower | Upper). */ -template +template requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper)) -class SLEIPNIR_DLLEXPORT Hessian { +class Hessian { public: /** * Constructs a Hessian object. @@ -35,8 +35,8 @@ class SLEIPNIR_DLLEXPORT Hessian { * @param variable Variable of which to compute the Hessian. * @param wrt Variable with respect to which to compute the Hessian. */ - Hessian(Variable variable, Variable wrt) - : Hessian{std::move(variable), VariableMatrix{std::move(wrt)}} {} + Hessian(Variable variable, Variable wrt) + : Hessian{std::move(variable), VariableMatrix{std::move(wrt)}} {} /** * Constructs a Hessian object. @@ -45,8 +45,8 @@ class SLEIPNIR_DLLEXPORT Hessian { * @param wrt Vector of variables with respect to which to compute the * Hessian. */ - Hessian(Variable variable, SleipnirMatrixLike auto wrt) - : m_variables{detail::AdjointExpressionGraph{variable} + Hessian(Variable variable, SleipnirMatrixLike auto wrt) + : m_variables{detail::AdjointExpressionGraph{variable} .generate_gradient_tree(wrt)}, m_wrt{wrt} { slp_assert(m_wrt.cols() == 1); @@ -74,7 +74,7 @@ class SLEIPNIR_DLLEXPORT Hessian { // If the row is linear, compute its gradient once here and cache its // triplets. Constant rows are ignored because their gradients have no // nonzero triplets. - m_graphs[row].append_adjoint_triplets(m_cached_triplets, row, m_wrt); + m_graphs[row].append_gradient_triplets(m_cached_triplets, row, m_wrt); } else if (m_variables[row].type() > ExpressionType::LINEAR) { // If the row is quadratic or nonlinear, add it to the list of nonlinear // rows to be recomputed in Value(). @@ -85,7 +85,7 @@ class SLEIPNIR_DLLEXPORT Hessian { if (m_nonlinear_rows.empty()) { m_H.setFromTriplets(m_cached_triplets.begin(), m_cached_triplets.end()); if constexpr (UpLo == Eigen::Lower) { - m_H = m_H.triangularView(); + m_H = m_H.template triangularView(); } } } @@ -98,9 +98,9 @@ class SLEIPNIR_DLLEXPORT Hessian { * * @return The Hessian as a VariableMatrix. */ - VariableMatrix get() const { - VariableMatrix result{VariableMatrix::empty, m_variables.rows(), - m_wrt.rows()}; + VariableMatrix get() const { + VariableMatrix result{detail::empty, m_variables.rows(), + m_wrt.rows()}; for (int row = 0; row < m_variables.rows(); ++row) { auto grad = m_graphs[row].generate_gradient_tree(m_wrt); @@ -108,7 +108,7 @@ class SLEIPNIR_DLLEXPORT Hessian { if (grad[col].expr != nullptr) { result(row, col) = std::move(grad[col]); } else { - result(row, col) = Variable{0.0}; + result(row, col) = Variable{Scalar(0)}; } } } @@ -121,7 +121,7 @@ class SLEIPNIR_DLLEXPORT Hessian { * * @return The Hessian at wrt's value. */ - const Eigen::SparseMatrix& value() { + const Eigen::SparseMatrix& value() { if (m_nonlinear_rows.empty()) { return m_H; } @@ -136,31 +136,36 @@ class SLEIPNIR_DLLEXPORT Hessian { // Compute each nonlinear row of the Hessian for (int row : m_nonlinear_rows) { - m_graphs[row].append_adjoint_triplets(triplets, row, m_wrt); + m_graphs[row].append_gradient_triplets(triplets, row, m_wrt); } m_H.setFromTriplets(triplets.begin(), triplets.end()); if constexpr (UpLo == Eigen::Lower) { - m_H = m_H.triangularView(); + m_H = m_H.template triangularView(); } return m_H; } private: - VariableMatrix m_variables; - VariableMatrix m_wrt; + VariableMatrix m_variables; + VariableMatrix m_wrt; - gch::small_vector m_graphs; + gch::small_vector> m_graphs; - Eigen::SparseMatrix m_H{m_variables.rows(), m_wrt.rows()}; + Eigen::SparseMatrix m_H{m_variables.rows(), m_wrt.rows()}; // Cached triplets for gradients of linear rows - gch::small_vector> m_cached_triplets; + gch::small_vector> m_cached_triplets; // List of row indices for nonlinear rows whose graients will be computed in // Value() gch::small_vector m_nonlinear_rows; }; +// @cond Suppress Doxygen +extern template class EXPORT_TEMPLATE_DECLARE(SLEIPNIR_DLLEXPORT) +Hessian; +// @endcond + } // namespace slp diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/jacobian.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/jacobian.hpp index 3662b5e49b..c8e28a826f 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/jacobian.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/jacobian.hpp @@ -12,6 +12,7 @@ #include "sleipnir/autodiff/variable_matrix.hpp" #include "sleipnir/util/assert.hpp" #include "sleipnir/util/concepts.hpp" +#include "sleipnir/util/empty.hpp" #include "sleipnir/util/symbol_exports.hpp" namespace slp { @@ -22,8 +23,11 @@ namespace slp { * * The Jacobian is only recomputed if the variable expression is quadratic or * higher order. + * + * @tparam Scalar Scalar type. */ -class SLEIPNIR_DLLEXPORT Jacobian { +template +class Jacobian { public: /** * Constructs a Jacobian object. @@ -31,9 +35,19 @@ class SLEIPNIR_DLLEXPORT Jacobian { * @param variable Variable of which to compute the Jacobian. * @param wrt Variable with respect to which to compute the Jacobian. */ - Jacobian(Variable variable, Variable wrt) - : Jacobian{VariableMatrix{std::move(variable)}, - VariableMatrix{std::move(wrt)}} {} + Jacobian(Variable variable, Variable wrt) + : Jacobian{VariableMatrix{std::move(variable)}, + VariableMatrix{std::move(wrt)}} {} + + /** + * Constructs a Jacobian object. + * + * @param variable Variable of which to compute the Jacobian. + * @param wrt Vector of variables with respect to which to compute the + * Jacobian. + */ + Jacobian(Variable variable, SleipnirMatrixLike auto wrt) + : Jacobian{VariableMatrix{std::move(variable)}, std::move(wrt)} {} /** * Constructs a Jacobian object. @@ -42,7 +56,8 @@ class SLEIPNIR_DLLEXPORT Jacobian { * @param wrt Vector of variables with respect to which to compute the * Jacobian. */ - Jacobian(VariableMatrix variables, SleipnirMatrixLike auto wrt) + Jacobian(VariableMatrix variables, + SleipnirMatrixLike auto wrt) : m_variables{std::move(variables)}, m_wrt{std::move(wrt)} { slp_assert(m_variables.cols() == 1); slp_assert(m_wrt.cols() == 1); @@ -70,7 +85,7 @@ class SLEIPNIR_DLLEXPORT Jacobian { // If the row is linear, compute its gradient once here and cache its // triplets. Constant rows are ignored because their gradients have no // nonzero triplets. - m_graphs[row].append_adjoint_triplets(m_cached_triplets, row, m_wrt); + m_graphs[row].append_gradient_triplets(m_cached_triplets, row, m_wrt); } else if (m_variables[row].type() > ExpressionType::LINEAR) { // If the row is quadratic or nonlinear, add it to the list of nonlinear // rows to be recomputed in Value(). @@ -91,9 +106,9 @@ class SLEIPNIR_DLLEXPORT Jacobian { * * @return The Jacobian as a VariableMatrix. */ - VariableMatrix get() const { - VariableMatrix result{VariableMatrix::empty, m_variables.rows(), - m_wrt.rows()}; + VariableMatrix get() const { + VariableMatrix result{detail::empty, m_variables.rows(), + m_wrt.rows()}; for (int row = 0; row < m_variables.rows(); ++row) { auto grad = m_graphs[row].generate_gradient_tree(m_wrt); @@ -101,7 +116,7 @@ class SLEIPNIR_DLLEXPORT Jacobian { if (grad[col].expr != nullptr) { result(row, col) = std::move(grad[col]); } else { - result(row, col) = Variable{0.0}; + result(row, col) = Variable{Scalar(0)}; } } } @@ -114,7 +129,7 @@ class SLEIPNIR_DLLEXPORT Jacobian { * * @return The Jacobian at wrt's value. */ - const Eigen::SparseMatrix& value() { + const Eigen::SparseMatrix& value() { if (m_nonlinear_rows.empty()) { return m_J; } @@ -129,7 +144,7 @@ class SLEIPNIR_DLLEXPORT Jacobian { // Compute each nonlinear row of the Jacobian for (int row : m_nonlinear_rows) { - m_graphs[row].append_adjoint_triplets(triplets, row, m_wrt); + m_graphs[row].append_gradient_triplets(triplets, row, m_wrt); } m_J.setFromTriplets(triplets.begin(), triplets.end()); @@ -138,19 +153,22 @@ class SLEIPNIR_DLLEXPORT Jacobian { } private: - VariableMatrix m_variables; - VariableMatrix m_wrt; + VariableMatrix m_variables; + VariableMatrix m_wrt; - gch::small_vector m_graphs; + gch::small_vector> m_graphs; - Eigen::SparseMatrix m_J{m_variables.rows(), m_wrt.rows()}; + Eigen::SparseMatrix m_J{m_variables.rows(), m_wrt.rows()}; // Cached triplets for gradients of linear rows - gch::small_vector> m_cached_triplets; + gch::small_vector> m_cached_triplets; // List of row indices for nonlinear rows whose graients will be computed in // Value() gch::small_vector m_nonlinear_rows; }; +extern template class EXPORT_TEMPLATE_DECLARE(SLEIPNIR_DLLEXPORT) +Jacobian; + } // namespace slp diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/sleipnir_base.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/sleipnir_base.hpp new file mode 100644 index 0000000000..874c0a4422 --- /dev/null +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/sleipnir_base.hpp @@ -0,0 +1,13 @@ +// Copyright (c) Sleipnir contributors + +#pragma once + +namespace slp { + +/** + * Marker interface for concepts to determine whether a given scalar or matrix + * type belongs to Sleipnir. + */ +class SleipnirBase {}; + +} // namespace slp diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/slice.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/slice.hpp index abb11f22f6..7df1ebc11d 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/slice.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/slice.hpp @@ -47,7 +47,8 @@ class SLEIPNIR_DLLEXPORT Slice { /** * Constructs a slice. */ - constexpr Slice(slicing::none_t) // NOLINT + // NOLINTNEXTLINE (google-explicit-constructor) + constexpr Slice(slicing::none_t) : Slice(0, std::numeric_limits::max(), 1) {} /** @@ -55,7 +56,8 @@ class SLEIPNIR_DLLEXPORT Slice { * * @param start Slice start index (inclusive). */ - constexpr Slice(int start) { // NOLINT + // NOLINTNEXTLINE (google-explicit-constructor) + constexpr Slice(int start) { this->start = start; this->stop = (start == -1) ? std::numeric_limits::max() : start + 1; this->step = 1; diff --git a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/variable.hpp b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/variable.hpp index 2fc2119d2d..cb4c1a56ec 100644 --- a/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/variable.hpp +++ b/wpimath/src/main/native/thirdparty/sleipnir/include/sleipnir/autodiff/variable.hpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include @@ -15,9 +14,9 @@ #include "sleipnir/autodiff/expression.hpp" #include "sleipnir/autodiff/expression_graph.hpp" +#include "sleipnir/autodiff/sleipnir_base.hpp" #include "sleipnir/util/assert.hpp" #include "sleipnir/util/concepts.hpp" -#include "sleipnir/util/symbol_exports.hpp" #ifndef SLEIPNIR_DISABLE_DIAGNOSTICS #include "sleipnir/util/print.hpp" @@ -26,19 +25,34 @@ namespace slp { // Forward declarations for friend declarations in Variable + namespace detail { + +template class AdjointExpressionGraph; + } // namespace detail -template + +template requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper)) -class SLEIPNIR_DLLEXPORT Hessian; -class SLEIPNIR_DLLEXPORT Jacobian; +class Hessian; + +template +class Jacobian; /** * An autodiff variable pointing to an expression node. + * + * @tparam Scalar_ Scalar type. */ -class SLEIPNIR_DLLEXPORT Variable { +template +class Variable : public SleipnirBase { public: + /** + * Scalar type alias. + */ + using Scalar = Scalar_; + /** * Constructs a linear Variable with a value of zero. */ @@ -50,43 +64,69 @@ class SLEIPNIR_DLLEXPORT Variable { explicit Variable(std::nullptr_t) : expr{nullptr} {} /** - * Constructs a Variable from a floating point type. + * Constructs a Variable from a scalar type. * * @param value The value of the Variable. */ - Variable(std::floating_point auto value) // NOLINT - : expr{detail::make_expression_ptr(value)} {} + // NOLINTNEXTLINE (google-explicit-constructor) + Variable(Scalar value) + requires(!MatrixLike) + : expr{detail::make_expression_ptr>( + value)} {} + + /** + * Constructs a Variable from a scalar type. + * + * @param value The value of the Variable. + */ + // NOLINTNEXTLINE (google-explicit-constructor) + Variable(SleipnirMatrixLike auto value) : expr{value(0, 0).expr} { + slp_assert(value.rows() == 1 && value.cols() == 1); + } + + /** + * Constructs a Variable from a floating-point type. + * + * @param value The value of the Variable. + */ + // NOLINTNEXTLINE (google-explicit-constructor) + Variable(std::floating_point auto value) + : expr{detail::make_expression_ptr>( + Scalar(value))} {} /** * Constructs a Variable from an integral type. * * @param value The value of the Variable. */ - Variable(std::integral auto value) // NOLINT - : expr{detail::make_expression_ptr(value)} {} + // NOLINTNEXTLINE (google-explicit-constructor) + Variable(std::integral auto value) + : expr{detail::make_expression_ptr>( + Scalar(value))} {} /** * Constructs a Variable pointing to the specified expression. * * @param expr The autodiff variable. */ - explicit Variable(const detail::ExpressionPtr& expr) : expr{expr} {} + explicit Variable(const detail::ExpressionPtr& expr) : expr{expr} {} /** * Constructs a Variable pointing to the specified expression. * * @param expr The autodiff variable. */ - explicit Variable(detail::ExpressionPtr&& expr) : expr{std::move(expr)} {} + explicit Variable(detail::ExpressionPtr&& expr) + : expr{std::move(expr)} {} /** - * Assignment operator for double. + * Assignment operator for scalar. * * @param value The value of the Variable. * @return This variable. */ - Variable& operator=(double value) { - expr = detail::make_expression_ptr(value); + Variable& operator=(ScalarLike auto value) { + expr = detail::make_expression_ptr>(value); m_graph_initialized = false; return *this; @@ -97,7 +137,7 @@ class SLEIPNIR_DLLEXPORT Variable { * * @param value The value of the Variable. */ - void set_value(double value) { + void set_value(Scalar value) { #ifndef SLEIPNIR_DISABLE_DIAGNOSTICS // We only need to check the first argument since unary and binary operators // both use it @@ -109,18 +149,65 @@ class SLEIPNIR_DLLEXPORT Variable { location.file_name(), location.line(), location.function_name()); } #endif - expr->val = value; + expr->val = Scalar(value); } /** - * Variable-Variable multiplication operator. + * Returns the value of this variable. + * + * @return The value of this variable. + */ + Scalar value() { + if (!m_graph_initialized) { + m_graph = detail::topological_sort(expr); + m_graph_initialized = true; + } + detail::update_values(m_graph); + + return Scalar(expr->val); + } + + /** + * Returns the type of this expression (constant, linear, quadratic, or + * nonlinear). + * + * @return The type of this expression. + */ + ExpressionType type() const { return expr->type(); } + + /** + * Variable-scalar multiplication operator. * * @param lhs Operator left-hand side. * @param rhs Operator right-hand side. * @return Result of multiplication. */ - friend SLEIPNIR_DLLEXPORT Variable operator*(const Variable& lhs, - const Variable& rhs) { + template RHS> + friend Variable operator*(const LHS& lhs, const RHS& rhs) { + return Variable{Variable{lhs}.expr * rhs.expr}; + } + + /** + * Variable-scalar multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + * @return Result of multiplication. + */ + template LHS, ScalarLike RHS> + friend Variable operator*(const LHS& lhs, const RHS& rhs) { + return Variable{lhs.expr * Variable{rhs}.expr}; + } + + /** + * Variable-scalar multiplication operator. + * + * @param lhs Operator left-hand side. + * @param rhs Operator right-hand side. + * @return Result of multiplication. + */ + friend Variable operator*(const Variable& lhs, + const Variable& rhs) { return Variable{lhs.expr * rhs.expr}; } @@ -130,7 +217,7 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of multiplication. */ - Variable& operator*=(const Variable& rhs) { + Variable& operator*=(const Variable& rhs) { *this = *this * rhs; return *this; } @@ -142,8 +229,8 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of division. */ - friend SLEIPNIR_DLLEXPORT Variable operator/(const Variable& lhs, - const Variable& rhs) { + friend Variable operator/(const Variable& lhs, + const Variable& rhs) { return Variable{lhs.expr / rhs.expr}; } @@ -153,7 +240,7 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of division. */ - Variable& operator/=(const Variable& rhs) { + Variable& operator/=(const Variable& rhs) { *this = *this / rhs; return *this; } @@ -165,8 +252,8 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of addition. */ - friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs, - const Variable& rhs) { + friend Variable operator+(const Variable& lhs, + const Variable& rhs) { return Variable{lhs.expr + rhs.expr}; } @@ -176,7 +263,7 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of addition. */ - Variable& operator+=(const Variable& rhs) { + Variable& operator+=(const Variable& rhs) { *this = *this + rhs; return *this; } @@ -188,8 +275,8 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of subtraction. */ - friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs, - const Variable& rhs) { + friend Variable operator-(const Variable& lhs, + const Variable& rhs) { return Variable{lhs.expr - rhs.expr}; } @@ -199,7 +286,7 @@ class SLEIPNIR_DLLEXPORT Variable { * @param rhs Operator right-hand side. * @return Result of subtraction. */ - Variable& operator-=(const Variable& rhs) { + Variable& operator-=(const Variable& rhs) { *this = *this - rhs; return *this; } @@ -209,7 +296,7 @@ class SLEIPNIR_DLLEXPORT Variable { * * @param lhs Operand for unary minus. */ - friend SLEIPNIR_DLLEXPORT Variable operator-(const Variable& lhs) { + friend Variable operator-(const Variable& lhs) { return Variable{-lhs.expr}; } @@ -218,321 +305,476 @@ class SLEIPNIR_DLLEXPORT Variable { * * @param lhs Operand for unary plus. */ - friend SLEIPNIR_DLLEXPORT Variable operator+(const Variable& lhs) { + friend Variable operator+(const Variable& lhs) { return Variable{+lhs.expr}; } - /** - * Returns the value of this variable. - * - * @return The value of this variable. - */ - double value() { - if (!m_graph_initialized) { - m_graph = detail::topological_sort(expr); - m_graph_initialized = true; - } - detail::update_values(m_graph); - - return expr->val; - } - - /** - * Returns the type of this expression (constant, linear, quadratic, or - * nonlinear). - * - * @return The type of this expression. - */ - ExpressionType type() const { return expr->type(); } - private: /// The expression node - detail::ExpressionPtr expr = - detail::make_expression_ptr(); + detail::ExpressionPtr expr = + detail::make_expression_ptr>(); /// Used to update the value of this variable based on the values of its /// dependent variables - gch::small_vector m_graph; + gch::small_vector*> m_graph; /// Used for lazy initialization of m_graph bool m_graph_initialized = false; - friend SLEIPNIR_DLLEXPORT Variable abs(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable acos(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable asin(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable atan(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable atan2(const Variable& y, - const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable cbrt(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable cos(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable cosh(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable erf(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable exp(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, - const Variable& y); - friend SLEIPNIR_DLLEXPORT Variable log(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable log10(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable pow(const Variable& base, - const Variable& power); - friend SLEIPNIR_DLLEXPORT Variable sign(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable sin(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable sinh(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable sqrt(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable tan(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable tanh(const Variable& x); - friend SLEIPNIR_DLLEXPORT Variable hypot(const Variable& x, const Variable& y, - const Variable& z); + template + friend Variable abs(const Variable& x); + template + friend Variable acos(const Variable& x); + template + friend Variable asin(const Variable& x); + template + friend Variable atan(const Variable& x); + template + friend Variable atan2(const ScalarLike auto& y, + const Variable& x); + template + friend Variable atan2(const Variable& y, + const ScalarLike auto& x); + template + friend Variable atan2(const Variable& y, + const Variable& x); + template + friend Variable cbrt(const Variable& x); + template + friend Variable cos(const Variable& x); + template + friend Variable cosh(const Variable& x); + template + friend Variable erf(const Variable& x); + template + friend Variable exp(const Variable& x); + template + friend Variable hypot(const ScalarLike auto& x, + const Variable& y); + template + friend Variable hypot(const Variable& x, + const ScalarLike auto& y); + template + friend Variable hypot(const Variable& x, + const Variable& y); + template + friend Variable log(const Variable& x); + template + friend Variable log10(const Variable& x); + template + friend Variable pow(const ScalarLike auto& base, + const Variable& power); + template + friend Variable pow(const Variable& base, + const ScalarLike auto& power); + template + friend Variable pow(const Variable& base, + const Variable& power); + template + friend Variable sign(const Variable& x); + template + friend Variable sin(const Variable& x); + template + friend Variable sinh(const Variable& x); + template + friend Variable sqrt(const Variable& x); + template + friend Variable tan(const Variable& x); + template + friend Variable tanh(const Variable& x); + template + friend Variable hypot(const Variable& x, + const Variable& y, + const Variable& z); - friend class detail::AdjointExpressionGraph; - template + friend class detail::AdjointExpressionGraph; + template requires(UpLo == Eigen::Lower) || (UpLo == (Eigen::Lower | Eigen::Upper)) - friend class SLEIPNIR_DLLEXPORT Hessian; - friend class SLEIPNIR_DLLEXPORT Jacobian; + friend class Hessian; + template + friend class Jacobian; }; +template