// Copyright (c) FIRST and other WPILib contributors. // Open Source Software; you can modify and/or share it under the terms of // the WPILib BSD license file in the root directory of this project. #pragma once #include #include #include #include #include "pb.h" #include "wpi/util/SmallVector.hpp" #include "wpi/util/array.hpp" #include "wpi/util/protobuf/Protobuf.hpp" namespace wpi::util { /** * The behavior to use when more elements are in the message then expected when * decoding. */ enum class DecodeLimits { // Ignore any extra elements Ignore, // Add any extra elements to the backing vector Add, // Cause decoding to fail if extra elements exist Fail, }; template concept StringLike = std::is_convertible_v; template concept ConstVectorLike = std::is_convertible_v>; template concept MutableVectorLike = std::is_convertible_v>; template concept PackBytes = StringLike || ConstVectorLike; template concept UnpackBytes = requires(T& t) { { t.resize(size_t()) }; // NOLINT { t.size() } -> std::same_as; { t.data() } -> std::convertible_to; } && (PackBytes || MutableVectorLike); template concept ProtoEnumeration = std::is_enum_v; template concept ProtoPackable = ProtoEnumeration || std::integral || std::floating_point; template concept ProtoCallbackPackable = ProtobufSerializable || PackBytes || ProtoPackable; template concept ProtoCallbackUnpackable = ProtobufSerializable || UnpackBytes || ProtoPackable; namespace detail { template concept Validatable = ProtoCallbackPackable || ProtoCallbackUnpackable; template constexpr bool ValidateType(pb_type_t type) { switch (type) { case PB_LTYPE_BOOL: return std::integral; case PB_LTYPE_VARINT: return std::signed_integral || ProtoEnumeration; case PB_LTYPE_UVARINT: return std::unsigned_integral; case PB_LTYPE_SVARINT: return std::signed_integral; case PB_LTYPE_FIXED32: return std::integral || std::floating_point; case PB_LTYPE_FIXED64: return std::integral || std::floating_point; case PB_LTYPE_BYTES: case PB_LTYPE_STRING: return PackBytes || UnpackBytes; case PB_LTYPE_SUBMESSAGE: return ProtobufSerializable; default: return false; } } } // namespace detail /** * A callback method that will directly unpack elements into * the specified vector like data structure. The size passed * is the expected number of elements. * * By default, any elements in the packed buffer past N will * still be added to the vector. * * @tparam T object type * @tparam U vector type to pack into * @tparam N number of elements */ template class DirectUnpackCallback { public: /** * Constructs a callback from a vector like type. * * @param storage the vector to store into */ explicit DirectUnpackCallback(U& storage) : m_storage{storage} { m_callback.funcs.decode = CallbackFunc; m_callback.arg = this; } DirectUnpackCallback(const DirectUnpackCallback&) = delete; DirectUnpackCallback(DirectUnpackCallback&&) = delete; DirectUnpackCallback& operator=(const DirectUnpackCallback&) = delete; DirectUnpackCallback& operator=(DirectUnpackCallback&&) = delete; /** * Set the limits on what happens if more elements exist in the buffer then * expected. * * @param limit the limit to set */ void SetLimits(DecodeLimits limit) noexcept { m_limits = limit; } /** * Gets the nanopb callback pointing to this object. * * @return nanopb callback */ pb_callback_t Callback() const { return m_callback; } private: bool SizeCheck(bool* retVal) const { if (m_storage.size() >= N) { switch (m_limits) { case DecodeLimits::Ignore: *retVal = true; return false; case DecodeLimits::Add: break; default: *retVal = false; return false; } } return true; } bool Decode(pb_istream_t* stream, pb_type_t fieldType) { if constexpr (ProtoPackable) { switch (fieldType) { case PB_LTYPE_BOOL: if constexpr (std::integral) { bool val = false; if (!pb_decode_bool(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else { return false; } case PB_LTYPE_VARINT: if constexpr (std::signed_integral || ProtoEnumeration) { int64_t val = 0; if (!pb_decode_varint(stream, reinterpret_cast(&val))) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else { return false; } case PB_LTYPE_UVARINT: if constexpr (std::unsigned_integral) { uint64_t val = 0; if (!pb_decode_varint(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else { return false; } case PB_LTYPE_SVARINT: if constexpr (std::signed_integral) { int64_t val = 0; if (!pb_decode_svarint(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else { return false; } case PB_LTYPE_FIXED32: if constexpr (std::signed_integral) { int32_t val = 0; if (!pb_decode_fixed32(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else if constexpr (std::unsigned_integral) { uint32_t val = 0; if (!pb_decode_fixed32(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } if constexpr (std::floating_point) { float val = 0; if (!pb_decode_fixed32(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else { return false; } case PB_LTYPE_FIXED64: if constexpr (std::signed_integral) { int64_t val = 0; if (!pb_decode_fixed64(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else if constexpr (std::unsigned_integral) { uint64_t val = 0; if (!pb_decode_fixed64(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } if constexpr (std::floating_point) { double val = 0; if (!pb_decode_fixed64(stream, &val)) { return false; } m_storage.emplace_back(static_cast(val)); return true; } else { return false; } default: return false; } } else if constexpr (UnpackBytes) { T& space = m_storage.emplace_back(T{}); space.resize(stream->bytes_left); return pb_read(stream, reinterpret_cast(space.data()), space.size()); } else if constexpr (ProtobufSerializable) { ProtoInputStream istream{stream}; auto decoded = wpi::util::Protobuf::Unpack(istream); if (decoded.has_value()) { m_storage.emplace_back(std::move(decoded.value())); return true; } return false; } } bool CallbackFunc(pb_istream_t* stream, const pb_field_t* field) { pb_type_t fieldType = PB_LTYPE(field->type); if (!detail::ValidateType(fieldType)) { return false; } // Validate our types if constexpr (ProtoPackable) { // Handle decode loop while (stream->bytes_left > 0) { bool sizeRetVal = 0; if (!SizeCheck(&sizeRetVal)) { return sizeRetVal; } if (!Decode(stream, fieldType)) { return false; } } return true; } else { // At this point, do the size check bool sizeRetVal = 0; if (!SizeCheck(&sizeRetVal)) { return sizeRetVal; } // At this point, we're good to decode return Decode(stream, fieldType); } } static bool CallbackFunc(pb_istream_t* stream, const pb_field_t* field, void** arg) { return reinterpret_cast(*arg)->CallbackFunc(stream, field); } U& m_storage; pb_callback_t m_callback; DecodeLimits m_limits{DecodeLimits::Add}; }; /** * A DirectUnpackCallback backed by a SmallVector. * * By default, any elements in the packed buffer past N will * be ignored, but decoding will still succeed * * @tparam T object type * @tparam N small vector small size/number of expected elements */ template class UnpackCallback : public DirectUnpackCallback, N> { public: /** * Constructs an UnpackCallback. */ UnpackCallback() : DirectUnpackCallback, N>{m_storedBuffer} { this->SetLimits(DecodeLimits::Ignore); } /** * Gets a span pointing to the storage buffer. * * @return storage buffer span */ std::span Items() noexcept { return m_storedBuffer; } /** * Gets a const span pointing to the storage buffer. * * @return storage buffer span */ std::span Items() const noexcept { return m_storedBuffer; } /** * Gets a reference to the backing small vector. * * @return small vector reference */ wpi::util::SmallVector& Vec() noexcept { return m_storedBuffer; } private: wpi::util::SmallVector m_storedBuffer; }; /** * A DirectUnpackCallback backed by a std::vector. * * By default, any elements in the packed buffer past N will * be ignored, but decoding will still succeed * * @tparam T object type * @tparam N number of expected elements */ template class StdVectorUnpackCallback : public DirectUnpackCallback, N> { public: /** * Constructs a StdVectorUnpackCallback. */ StdVectorUnpackCallback() : DirectUnpackCallback, N>{m_storedBuffer} { this->SetLimits(DecodeLimits::Ignore); } /** * Gets a span pointing to the storage buffer. * * @return storage buffer span */ std::span Items() noexcept { return m_storedBuffer; } /** * Gets a const span pointing to the storage buffer. * * @return storage buffer span */ std::span Items() const noexcept { return m_storedBuffer; } /** * Gets a reference to the backing vector. * * @return vector reference */ std::vector& Vec() noexcept { return m_storedBuffer; } private: std::vector m_storedBuffer; }; /** * A wrapper around a wpi::util::array that lets us * treat it as a limited sized vector. */ template struct WpiArrayEmplaceWrapper { wpi::util::array m_array{wpi::util::empty_array_t{}}; size_t m_currentIndex = 0; size_t size() const { return m_currentIndex; } template T& emplace_back(ArgTypes&&... Args) { m_array[m_currentIndex] = T(std::forward(Args)...); m_currentIndex++; return m_array[m_currentIndex - 1]; } }; /** * A DirectUnpackCallback backed by a wpi::util::array. * * Any elements in the packed buffer past N will * be cause decoding to fail. * * @tparam T object type * @tparam N small vector small size/number of expected elements */ template struct WpiArrayUnpackCallback : public DirectUnpackCallback, N> { /** * Constructs a WpiArrayUnpackCallback. */ WpiArrayUnpackCallback() : DirectUnpackCallback, N>{m_array} { this->SetLimits(DecodeLimits::Fail); } /** * Returns if the buffer is completely filled up. * * @return true if buffer is full */ bool IsFull() const noexcept { return m_array.m_currentIndex == N; } /** * Returns the number of elements in the buffer. * * @return number of elements */ size_t Size() const noexcept { return m_array.m_currentIndex; } /** * Returns a reference to the backing array. * * @return array reference */ wpi::util::array& Array() noexcept { return m_array.m_array; } private: WpiArrayEmplaceWrapper m_array; }; /** * A callback method that will pack elements when called. * * @tparam T object type */ template class PackCallback { public: /** * Constructs a pack callback from a span of elements. The elements in the * buffer _MUST_ stay alive throughout the entire encode call. */ explicit PackCallback(std::span buffer) : m_buffer{buffer} { m_callback.funcs.encode = CallbackFunc; m_callback.arg = this; } /** * Constructs a pack callback from a pointer to a single element. * This element _MUST_ stay alive throughout the entire encode call. * Do not pass a temporary here (This is why its a pointer and not a * reference) */ explicit PackCallback(const T* element) : m_buffer{std::span{element, 1}} { m_callback.funcs.encode = CallbackFunc; m_callback.arg = this; } PackCallback(const PackCallback&) = delete; PackCallback(PackCallback&&) = delete; PackCallback& operator=(const PackCallback&) = delete; PackCallback& operator=(PackCallback&&) = delete; /** * Gets the nanopb callback pointing to this object. * * @return nanopb callback */ pb_callback_t Callback() const { return m_callback; } /** * Gets a span pointing to the items * * @return span */ std::span Bufs() const { return m_buffer; } private: static auto EncodeStreamTypeFinder() { if constexpr (ProtobufSerializable) { return ProtoOutputStream(nullptr); } else { return pb_ostream_t{}; } } using EncodeStreamType = decltype(EncodeStreamTypeFinder()); bool EncodeItem(EncodeStreamType& stream, const pb_field_t* field, const T& value) const { if constexpr (std::floating_point) { pb_type_t fieldType = PB_LTYPE(field->type); switch (fieldType) { case PB_LTYPE_FIXED32: { float flt = static_cast(value); return pb_encode_fixed32(&stream, &flt); } case PB_LTYPE_FIXED64: { double dbl = static_cast(value); return pb_encode_fixed64(&stream, &dbl); } default: return false; } } else if constexpr (std::integral || ProtoEnumeration) { pb_type_t fieldType = PB_LTYPE(field->type); switch (fieldType) { case PB_LTYPE_BOOL: case PB_LTYPE_VARINT: case PB_LTYPE_UVARINT: return pb_encode_varint(&stream, value); case PB_LTYPE_SVARINT: return pb_encode_svarint(&stream, value); case PB_LTYPE_FIXED32: { uint32_t f = value; return pb_encode_fixed32(&stream, &f); } case PB_LTYPE_FIXED64: { uint64_t f = value; return pb_encode_fixed64(&stream, &f); } default: return false; } } else if constexpr (StringLike) { std::string_view view{value}; return pb_encode_string(&stream, reinterpret_cast(view.data()), view.size()); } else if constexpr (ConstVectorLike) { std::span view{value}; return pb_encode_string(&stream, reinterpret_cast(view.data()), view.size()); } else if constexpr (ProtobufSerializable) { return wpi::util::Protobuf::Pack(stream, value); } } bool EncodeLoop(pb_ostream_t* stream, const pb_field_t* field, bool writeTag) const { if constexpr (ProtobufSerializable) { ProtoOutputStream ostream{stream}; for (auto&& i : m_buffer) { if (writeTag) { if (!pb_encode_tag_for_field(stream, field)) { return false; } } if (!EncodeItem(ostream, field, i)) { return false; } } } else { for (auto&& i : m_buffer) { if (writeTag) { if (!pb_encode_tag_for_field(stream, field)) { return false; } } if (!EncodeItem(*stream, field, i)) { return false; } } } return true; } bool PackedEncode(pb_ostream_t* stream, const pb_field_t* field) const { // We're always going to used packed encoding. // So first we need to get the packed size. pb_ostream_t substream = PB_OSTREAM_SIZING; if (!EncodeLoop(&substream, field, false)) { return false; } // Encode as a string tag if (!pb_encode_tag(stream, PB_WT_STRING, field->tag)) { return false; } // Write length as varint size_t size = substream.bytes_written; if (!pb_encode_varint(stream, static_cast(size))) { return false; } return EncodeLoop(stream, field, false); } bool CallbackFunc(pb_ostream_t* stream, const pb_field_t* field) const { // First off, if we're empty, do nothing, but say we were successful if (m_buffer.empty()) { return true; } pb_type_t fieldType = PB_LTYPE(field->type); if (!detail::ValidateType(fieldType)) { return false; } if constexpr (ProtoPackable) { return PackedEncode(stream, field); } else { return EncodeLoop(stream, field, true); } } static bool CallbackFunc(pb_ostream_t* stream, const pb_field_t* field, void* const* arg) { return reinterpret_cast(*arg)->CallbackFunc(stream, field); } std::span m_buffer; pb_callback_t m_callback; }; } // namespace wpi::util