[wpiutil] Change C++ protobuf to nanopb (#7309)

The Google C++ protobuf implementation has issues with dynamic linkage across DLL boundaries because it uses global variables.  It also has a compile-time dependency because the protoc version must exactly match the libprotobuf version.  Using nanopb with a customized generator fixes both of these issues.

Co-authored-by: Gold856 <117957790+Gold856@users.noreply.github.com>
This commit is contained in:
Thad House
2024-11-07 22:42:50 -08:00
committed by GitHub
parent fd2e0c0427
commit 8b8b634f65
166 changed files with 17522 additions and 1571 deletions

View File

@@ -14,18 +14,12 @@
#include <utility>
#include <vector>
#include "pb.h"
#include "pb_decode.h"
#include "pb_encode.h"
#include "wpi/array.h"
#include "wpi/function_ref.h"
namespace google::protobuf {
class Arena;
class Message;
template <typename T>
class RepeatedPtrField;
template <typename T>
class RepeatedField;
} // namespace google::protobuf
namespace wpi {
template <typename T>
@@ -41,43 +35,235 @@ class SmallVectorImpl;
template <typename T>
struct Protobuf {};
namespace detail {
using SmallVectorType = wpi::SmallVectorImpl<uint8_t>;
using StdVectorType = std::vector<uint8_t>;
bool WriteFromSmallVector(pb_ostream_t* stream, const pb_byte_t* buf,
size_t count);
bool WriteFromStdVector(pb_ostream_t* stream, const pb_byte_t* buf,
size_t count);
} // namespace detail
/**
* Class for wrapping a nanopb istream.
*/
template <typename T>
class ProtoInputStream {
public:
/**
* Constructs a nanopb istream from an existing istream object.
* Generally used internally for decoding submessages
*
* @param[in] stream the nanopb istream
*/
explicit ProtoInputStream(pb_istream_t* stream)
: m_streamMsg{stream},
m_msgDesc{
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()} {
}
/**
* Constructs a nanopb istream from a buffer.
*
* @param[in] stream the stream buffer
*/
explicit ProtoInputStream(std::span<const uint8_t> stream)
: m_streamLocal{pb_istream_from_buffer(
reinterpret_cast<const pb_byte_t*>(stream.data()), stream.size())},
m_msgDesc{
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()} {
}
/**
* Gets the backing nanopb stream object.
*
* @return nanopb stream
*/
pb_istream_t* Stream() noexcept {
return m_streamMsg ? m_streamMsg : &m_streamLocal;
}
/**
* Gets the nanopb message descriptor
*
* @return the nanopb message descriptor
*/
const pb_msgdesc_t* MsgDesc() const noexcept { return m_msgDesc; }
/**
* Decodes a protobuf. Flags are the same flags passed to pb_decode_ex.
*
* @param[in] msg The message to decode into
* @param[in] flags Flags to pass
* @return true if decoding was successful, false otherwise
*/
bool Decode(typename Protobuf<std::remove_cvref_t<T>>::MessageStruct& msg,
unsigned int flags = 0) {
return pb_decode_ex(Stream(), m_msgDesc, &msg, flags);
}
private:
pb_istream_t m_streamLocal;
pb_istream_t* m_streamMsg{nullptr};
const pb_msgdesc_t* m_msgDesc;
};
/**
* Class for wrapping a nanopb ostream
*/
template <typename T>
class ProtoOutputStream {
public:
/**
* Constructs a nanopb ostream from an existing ostream object
* Generally used internally for encoding messages.
*
* This constructor will cause `Encode` to call pb_encode_submessage
* instead of `pb_encode_ex`
*
* @param[in] stream the nanopb ostream
*/
explicit ProtoOutputStream(pb_ostream_t* stream)
: m_streamMsg{stream},
m_msgDesc{
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()} {
}
/**
* Constructs a nanopb ostream from a buffer.
*
* This constructor will cause `Encode` to call pb_encode_ex`
*
* @param[in] out the stream buffer
*/
explicit ProtoOutputStream(detail::SmallVectorType& out)
: m_msgDesc{
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()} {
m_streamLocal.callback = detail::WriteFromSmallVector;
m_streamLocal.state = &out;
m_streamLocal.max_size = SIZE_MAX;
m_streamLocal.bytes_written = 0;
m_streamLocal.errmsg = nullptr;
}
/**
* Constructs a nanopb ostream from a buffer.
*
* This constructor will cause `Encode` to call pb_encode_ex`
*
* @param[in] out the stream buffer
*/
explicit ProtoOutputStream(detail::StdVectorType& out)
: m_msgDesc{
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()} {
m_streamLocal.callback = detail::WriteFromStdVector;
m_streamLocal.state = &out;
m_streamLocal.max_size = SIZE_MAX;
m_streamLocal.bytes_written = 0;
m_streamLocal.errmsg = nullptr;
}
/**
* Constructs a empty nanopb stream. You must fill out the stream
* returned from `Stream` before calling Encode.
*
* This constructor exists to cause `Encode` to call pb_encode_ex`,
* but allow manipulating the stream manually.
*/
ProtoOutputStream()
: m_msgDesc{Protobuf<
std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()} {}
/**
* Gets the backing nanopb stream object.
*
* @return nanopb stream
*/
pb_ostream_t* Stream() noexcept {
return m_streamMsg ? m_streamMsg : &m_streamLocal;
}
/**
* Gets if this stream points to a submessage, and will call
* pb_encode_submessage instead of pb_encode
*
* @return true if submessage, false otherwise
*/
bool IsSubmessage() const noexcept { return m_streamMsg; }
/**
* Gets the nanopb message descriptor
*
* @return the nanopb message descriptor
*/
const pb_msgdesc_t* MsgDesc() const noexcept { return m_msgDesc; }
/**
* Decodes a protobuf. Flags are the same flags passed to pb_decode_ex.
*
* @param[in] msg The message to encode from
* @return true if encoding was successful, false otherwise
*/
bool Encode(
const typename Protobuf<std::remove_cvref_t<T>>::MessageStruct& msg) {
if (m_streamMsg) {
return pb_encode_submessage(m_streamMsg, m_msgDesc, &msg);
}
return pb_encode(&m_streamLocal, m_msgDesc, &msg);
}
private:
pb_ostream_t m_streamLocal;
pb_ostream_t* m_streamMsg{nullptr};
const pb_msgdesc_t* m_msgDesc;
};
/**
* Specifies that a type is capable of protobuf serialization and
* deserialization.
*
* This is designed for serializing complex flexible data structures using
* code generated from a .proto file. Serialization consists of writing
* values into a mutable protobuf Message and deserialization consists of
* reading values from an immutable protobuf Message.
* values into a nanopb Stream and deserialization consists of
* reading values from nanopb Stream.
*
* Implementations must define a template specialization for wpi::Protobuf with
* T being the type that is being serialized/deserialized, with the following
* static members (as enforced by this concept):
* - google::protobuf::Message* New(google::protobuf::Arena*): create a protobuf
* message
* - T Unpack(const google::protobuf::Message&): function for deserialization
* - void Pack(google::protobuf::Message*, T&& value): function for
* serialization
* - using MessageStruct = nanopb_message_struct_here: typedef to the wpilib
* modified nanopb message struct
* - std::optional<T> Unpack(wpi::ProtoInputStream<T>&): function
* for deserialization
* - bool Pack(wpi::ProtoOutputStream<T>&, T&& value): function
* for serialization
*
* To avoid pulling in the protobuf headers, these functions use
* google::protobuf::Message instead of a more specific type; implementations
* will need to static_cast to the correct type as created by New().
*
* Additionally: In a static block, call StructRegistry.registerClass() to
* register the class
* As a suggestion, 2 extra type usings can be added to simplify the stream
* definitions, however these are not required.
* - using InputStream = wpi::ProtoInputStream<T>;
* - using OutputStream = wpi::ProtoOutputStream<T>;
*/
template <typename T>
concept ProtobufSerializable = requires(
google::protobuf::Arena* arena, const google::protobuf::Message& inmsg,
google::protobuf::Message* outmsg, const T& value) {
wpi::ProtoOutputStream<std::remove_cvref_t<T>>& ostream,
wpi::ProtoInputStream<std::remove_cvref_t<T>>& istream, const T& value) {
typename Protobuf<typename std::remove_cvref_t<T>>;
{
Protobuf<typename std::remove_cvref_t<T>>::New(arena)
} -> std::same_as<google::protobuf::Message*>;
Protobuf<typename std::remove_cvref_t<T>>::Unpack(istream)
} -> std::same_as<std::optional<typename std::remove_cvref_t<T>>>;
{
Protobuf<typename std::remove_cvref_t<T>>::Unpack(inmsg)
} -> std::same_as<typename std::remove_cvref_t<T>>;
Protobuf<typename std::remove_cvref_t<T>>::Pack(outmsg, value);
Protobuf<typename std::remove_cvref_t<T>>::Pack(ostream, value)
} -> std::same_as<bool>;
typename Protobuf<typename std::remove_cvref_t<T>>::MessageStruct;
{
Protobuf<typename std::remove_cvref_t<T>>::MessageStruct::msg_descriptor()
} -> std::same_as<const pb_msgdesc_t*>;
{
Protobuf<typename std::remove_cvref_t<T>>::MessageStruct::msg_name()
} -> std::same_as<std::string_view>;
{
Protobuf<typename std::remove_cvref_t<T>>::MessageStruct::file_descriptor()
} -> std::same_as<pb_filedesc_t>;
};
/**
@@ -85,146 +271,22 @@ concept ProtobufSerializable = requires(
*
* In addition to meeting ProtobufSerializable, implementations must define a
* wpi::Protobuf<T> static member
* `void UnpackInto(T*, const google::protobuf::Message&)` to update the
* pointed-to T with the contents of the message.
* - bool UnpackInto(T*, wpi::ProtoInputStream<T>&)` to update the
* pointed-to T with the contents of the message.
*/
template <typename T>
concept MutableProtobufSerializable =
ProtobufSerializable<T> &&
requires(T* out, const google::protobuf::Message& msg) {
Protobuf<typename std::remove_cvref_t<T>>::UnpackInto(out, msg);
requires(T* out, wpi::ProtoInputStream<T>& istream) {
{
Protobuf<typename std::remove_cvref_t<T>>::UnpackInto(out, istream)
} -> std::same_as<bool>;
};
/**
* Unpack a serialized protobuf message.
*
* @tparam T object type
* @param msg protobuf message
* @return Deserialized object
*/
template <ProtobufSerializable T>
inline T UnpackProtobuf(const google::protobuf::Message& msg) {
return Protobuf<T>::Unpack(msg);
}
/**
* Unpack a serialized protobuf array message.
*
* @tparam Proto element type of the protobuf array
* @tparam T object type
* @tparam N number of objects
* @param msg protobuf array message
* @return Deserialized array
*/
template <std::derived_from<google::protobuf::Message> Proto,
ProtobufSerializable T, size_t N>
wpi::array<T, N> UnpackProtobufArray(
const google::protobuf::RepeatedPtrField<Proto>& msg) {
if (N != std::dynamic_extent && msg.size() != N) {
// TODO
}
wpi::array<T, N> arr(wpi::empty_array);
for (size_t i = 0; i < N; i++) {
arr[i] = wpi::UnpackProtobuf<T>(msg.Get(i));
}
return arr;
}
/**
* Unpack a serialized protobuf array message.
*
* @tparam T element type of the protobuf array
* @tparam N number of objects
* @param msg protobuf array message
* @return Deserialized array
*/
template <typename T, size_t N>
wpi::array<T, N> UnpackProtobufArray(
const google::protobuf::RepeatedField<T>& msg) {
if (N != std::dynamic_extent && msg.size() != N) {
// TODO
}
wpi::array<T, N> arr(wpi::empty_array);
for (size_t i = 0; i < N; i++) {
arr[i] = msg.Get(i);
}
return arr;
}
/**
* Pack a serialized protobuf message.
*
* @param msg protobuf message (mutable, output)
* @param value object
*/
template <ProtobufSerializable T>
inline void PackProtobuf(google::protobuf::Message* msg, const T& value) {
Protobuf<typename std::remove_cvref_t<T>>::Pack(msg, value);
}
/**
* Pack a serialized protobuf array message.
*
* @tparam Proto element type of the protobuf array
* @tparam T object type
* @tparam N number of objects
* @param msg protobuf message (mutable, output)
* @param arr array of objects
*/
template <std::derived_from<google::protobuf::Message> Proto,
ProtobufSerializable T, size_t N>
void PackProtobufArray(google::protobuf::RepeatedPtrField<Proto>* msg,
const wpi::array<T, N>& arr) {
msg->Clear();
msg->Reserve(N);
for (const auto& obj : arr) {
PackProtobuf(msg->Add(), obj);
}
}
/**
* Pack a serialized protobuf array message.
*
* @tparam T object type
* @tparam N number of objects
* @param msg protobuf message (mutable, output)
* @param arr array of objects
*/
template <typename T, size_t N>
void PackProtobufArray(google::protobuf::RepeatedField<T>* msg,
const wpi::array<T, N>& arr) {
msg->Clear();
msg->Reserve(N);
msg->Add(arr.begin(), arr.end());
}
/**
* Unpack a serialized struct into an existing object, overwriting its contents.
*
* @param out object (output)
* @param msg protobuf message
*/
template <ProtobufSerializable T>
inline void UnpackProtobufInto(T* out, const google::protobuf::Message& msg) {
if constexpr (MutableProtobufSerializable<T>) {
Protobuf<T>::UnpackInto(out, msg);
} else {
*out = UnpackProtobuf<T>(msg);
}
}
// these detail functions avoid the need to include protobuf headers
namespace detail {
void DeleteProtobuf(google::protobuf::Message* msg);
bool ParseProtobuf(google::protobuf::Message* msg,
std::span<const uint8_t> data);
bool SerializeProtobuf(wpi::SmallVectorImpl<uint8_t>& out,
const google::protobuf::Message& msg);
bool SerializeProtobuf(std::vector<uint8_t>& out,
const google::protobuf::Message& msg);
std::string GetTypeString(const google::protobuf::Message& msg);
std::string GetTypeString(const pb_msgdesc_t* msg);
void ForEachProtobufDescriptor(
const google::protobuf::Message& msg,
const pb_msgdesc_t* msg,
function_ref<bool(std::string_view filename)> wants,
function_ref<void(std::string_view filename,
std::span<const uint8_t> descriptor)>
@@ -232,48 +294,23 @@ void ForEachProtobufDescriptor(
} // namespace detail
/**
* Owning wrapper (ala std::unique_ptr) for google::protobuf::Message* that does
* not require the protobuf headers be included. Note this object is not thread
* safe; users of this object are required to provide any necessary thread
* safety.
* Ease of use wrapper to make nanopb streams more opaque to the user.
* This class is stateless and thread safe.
*
* @tparam T serialized object type
*/
template <ProtobufSerializable T>
class ProtobufMessage {
public:
explicit ProtobufMessage(google::protobuf::Arena* arena = nullptr)
: m_msg{Protobuf<T>::New(arena)} {}
~ProtobufMessage() { detail::DeleteProtobuf(m_msg); }
ProtobufMessage(const ProtobufMessage&) = delete;
ProtobufMessage& operator=(const ProtobufMessage&) = delete;
ProtobufMessage(ProtobufMessage&& rhs) : m_msg{rhs.m_msg} {
rhs.m_msg = nullptr;
}
ProtobufMessage& operator=(ProtobufMessage&& rhs) {
std::swap(m_msg, rhs.m_msg);
return *this;
}
/**
* Gets the stored message object.
*
* @return google::protobuf::Message*
*/
google::protobuf::Message* GetMessage() { return m_msg; }
const google::protobuf::Message* GetMessage() const { return m_msg; }
/**
* Unpacks from a byte array.
*
* @param data byte array
* @return Optional; empty if parsing failed
*/
std::optional<T> Unpack(std::span<const uint8_t> data) {
if (!detail::ParseProtobuf(m_msg, data)) {
return std::nullopt;
}
return Protobuf<T>::Unpack(*m_msg);
std::optional<std::remove_cvref_t<T>> Unpack(std::span<const uint8_t> data) {
ProtoInputStream<std::remove_cvref_t<T>> stream{data};
return Protobuf<std::remove_cvref_t<T>>::Unpack(stream);
}
/**
@@ -284,11 +321,17 @@ class ProtobufMessage {
* @return true if successful
*/
bool UnpackInto(T* out, std::span<const uint8_t> data) {
if (!detail::ParseProtobuf(m_msg, data)) {
return false;
if constexpr (MutableProtobufSerializable<T>) {
ProtoInputStream<std::remove_cvref_t<T>> stream{data};
return Protobuf<std::remove_cvref_t<T>>::UnpackInto(out, stream);
} else {
auto unpacked = Unpack(data);
if (!unpacked) {
return false;
}
*out = std::move(unpacked.value());
return true;
}
UnpackProtobufInto(out, *m_msg);
return true;
}
/**
@@ -299,8 +342,8 @@ class ProtobufMessage {
* @return true if successful
*/
bool Pack(wpi::SmallVectorImpl<uint8_t>& out, const T& value) {
Protobuf<T>::Pack(m_msg, value);
return detail::SerializeProtobuf(out, *m_msg);
ProtoOutputStream<std::remove_cvref_t<T>> stream{out};
return Protobuf<std::remove_cvref_t<T>>::Pack(stream, value);
}
/**
@@ -311,8 +354,8 @@ class ProtobufMessage {
* @return true if successful
*/
bool Pack(std::vector<uint8_t>& out, const T& value) {
Protobuf<T>::Pack(m_msg, value);
return detail::SerializeProtobuf(out, *m_msg);
ProtoOutputStream<std::remove_cvref_t<T>> stream{out};
return Protobuf<std::remove_cvref_t<T>>::Pack(stream, value);
}
/**
@@ -320,7 +363,10 @@ class ProtobufMessage {
*
* @return type string
*/
std::string GetTypeString() const { return detail::GetTypeString(*m_msg); }
std::string GetTypeString() const {
return detail::GetTypeString(
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor());
}
/**
* Loops over all protobuf descriptors including nested/referenced
@@ -335,11 +381,10 @@ class ProtobufMessage {
function_ref<void(std::string_view filename,
std::span<const uint8_t> descriptor)>
fn) {
detail::ForEachProtobufDescriptor(*m_msg, exists, fn);
detail::ForEachProtobufDescriptor(
Protobuf<std::remove_cvref_t<T>>::MessageStruct::msg_descriptor(),
exists, fn);
}
private:
google::protobuf::Message* m_msg = nullptr;
};
} // namespace wpi

View File

@@ -0,0 +1,670 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
#pragma once
#include <span>
#include <utility>
#include <vector>
#include <fmt/format.h>
#include "pb.h"
#include "wpi/SmallVector.h"
#include "wpi/array.h"
#include "wpi/protobuf/Protobuf.h"
namespace wpi {
/**
* The behavior to use when more elements are in the message then expected when
* decoding.
*/
enum class DecodeLimits {
// Ignore any extra elements
Ignore,
// Add any extra elements to the backing vector
Add,
// Cause decoding to fail if extra elements exist
Fail,
};
template <class T>
concept StringLike = std::is_convertible_v<T, std::string_view>;
template <class T>
concept ConstVectorLike = std::is_convertible_v<T, std::span<const uint8_t>>;
template <class T>
concept MutableVectorLike = std::is_convertible_v<T, std::span<uint8_t>>;
template <typename T>
concept PackBytes = StringLike<T> || ConstVectorLike<T>;
template <typename T>
concept UnpackBytes = requires(T& t) {
{ t.resize(size_t()) }; // NOLINT
{ t.size() } -> std::same_as<size_t>;
{ t.data() } -> std::convertible_to<void*>;
} && (PackBytes<T> || MutableVectorLike<T>);
template <typename T>
concept ProtoEnumeration = std::is_enum_v<T>;
template <typename T>
concept ProtoPackable =
ProtoEnumeration<T> || std::integral<T> || std::floating_point<T>;
template <typename T>
concept ProtoCallbackPackable =
ProtobufSerializable<T> || PackBytes<T> || ProtoPackable<T>;
template <typename T>
concept ProtoCallbackUnpackable =
ProtobufSerializable<T> || UnpackBytes<T> || ProtoPackable<T>;
namespace detail {
template <typename T>
concept Validatable = ProtoCallbackPackable<T> || ProtoCallbackUnpackable<T>;
template <Validatable T>
constexpr bool ValidateType(pb_type_t type) {
switch (type) {
case PB_LTYPE_BOOL:
return std::integral<T>;
case PB_LTYPE_VARINT:
return std::signed_integral<T> || ProtoEnumeration<T>;
case PB_LTYPE_UVARINT:
return std::unsigned_integral<T>;
case PB_LTYPE_SVARINT:
return std::signed_integral<T>;
case PB_LTYPE_FIXED32:
return std::integral<T> || std::floating_point<T>;
case PB_LTYPE_FIXED64:
return std::integral<T> || std::floating_point<T>;
case PB_LTYPE_BYTES:
case PB_LTYPE_STRING:
return PackBytes<T> || UnpackBytes<T>;
case PB_LTYPE_SUBMESSAGE:
return ProtobufSerializable<T>;
default:
return false;
}
}
} // namespace detail
/**
* A callback method that will directly unpack elements into
* the specified vector like data structure. The size passed
* is the expected number of elements.
*
* By default, any elements in the packed buffer past N will
* still be added to the vector.
*
* @tparam T object type
* @tparam U vector type to pack into
* @tparam N number of elements
*/
template <ProtoCallbackUnpackable T, typename U, size_t N = 1>
class DirectUnpackCallback {
public:
/**
* Constructs a callback from a vector like type.
*
* @param storage the vector to store into
*/
explicit DirectUnpackCallback(U& storage) : m_storage{storage} {
m_callback.funcs.decode = CallbackFunc;
m_callback.arg = this;
}
DirectUnpackCallback(const DirectUnpackCallback&) = delete;
DirectUnpackCallback(DirectUnpackCallback&&) = delete;
DirectUnpackCallback& operator=(const DirectUnpackCallback&) = delete;
DirectUnpackCallback& operator=(DirectUnpackCallback&&) = delete;
/**
* Set the limits on what happens if more elements exist in the buffer then
* expected.
*
* @param limit the limit to set
*/
void SetLimits(DecodeLimits limit) noexcept { m_limits = limit; }
/**
* Gets the nanopb callback pointing to this object.
*
* @return nanopb callback
*/
pb_callback_t Callback() const { return m_callback; }
private:
bool SizeCheck(bool* retVal) const {
if (m_storage.size() >= N) {
switch (m_limits) {
case DecodeLimits::Ignore:
*retVal = true;
return false;
case DecodeLimits::Add:
break;
default:
*retVal = false;
return false;
}
}
return true;
}
bool Decode(pb_istream_t* stream, pb_type_t fieldType) {
if constexpr (ProtoPackable<T>) {
switch (fieldType) {
case PB_LTYPE_BOOL:
if constexpr (std::integral<T>) {
bool val = false;
if (!pb_decode_bool(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else {
return false;
}
case PB_LTYPE_VARINT:
if constexpr (std::signed_integral<T> || ProtoEnumeration<T>) {
int64_t val = 0;
if (!pb_decode_varint(stream, reinterpret_cast<uint64_t*>(&val))) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else {
return false;
}
case PB_LTYPE_UVARINT:
if constexpr (std::unsigned_integral<T>) {
uint64_t val = 0;
if (!pb_decode_varint(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else {
return false;
}
case PB_LTYPE_SVARINT:
if constexpr (std::signed_integral<T>) {
int64_t val = 0;
if (!pb_decode_svarint(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else {
return false;
}
case PB_LTYPE_FIXED32:
if constexpr (std::signed_integral<T>) {
int32_t val = 0;
if (!pb_decode_fixed32(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else if constexpr (std::unsigned_integral<T>) {
uint32_t val = 0;
if (!pb_decode_fixed32(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
}
if constexpr (std::floating_point<T>) {
float val = 0;
if (!pb_decode_fixed32(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else {
return false;
}
case PB_LTYPE_FIXED64:
if constexpr (std::signed_integral<T>) {
int64_t val = 0;
if (!pb_decode_fixed64(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else if constexpr (std::unsigned_integral<T>) {
uint64_t val = 0;
if (!pb_decode_fixed64(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
}
if constexpr (std::floating_point<T>) {
double val = 0;
if (!pb_decode_fixed64(stream, &val)) {
return false;
}
m_storage.emplace_back(static_cast<T>(val));
return true;
} else {
return false;
}
default:
return false;
}
} else if constexpr (UnpackBytes<T>) {
T& space = m_storage.emplace_back(T{});
space.resize(stream->bytes_left);
return pb_read(stream, reinterpret_cast<pb_byte_t*>(space.data()),
space.size());
} else if constexpr (ProtobufSerializable<T>) {
ProtoInputStream<T> istream{stream};
auto decoded = wpi::Protobuf<T>::Unpack(istream);
if (decoded.has_value()) {
m_storage.emplace_back(std::move(decoded.value()));
return true;
}
return false;
}
}
bool CallbackFunc(pb_istream_t* stream, const pb_field_t* field) {
pb_type_t fieldType = PB_LTYPE(field->type);
if (!detail::ValidateType<T>(fieldType)) {
return false;
}
// Validate our types
if constexpr (ProtoPackable<T>) {
// Handle decode loop
while (stream->bytes_left > 0) {
bool sizeRetVal = 0;
if (!SizeCheck(&sizeRetVal)) {
return sizeRetVal;
}
if (!Decode(stream, fieldType)) {
return false;
}
}
return true;
} else {
// At this point, do the size check
bool sizeRetVal = 0;
if (!SizeCheck(&sizeRetVal)) {
return sizeRetVal;
}
// At this point, we're good to decode
return Decode(stream, fieldType);
}
}
static bool CallbackFunc(pb_istream_t* stream, const pb_field_t* field,
void** arg) {
return reinterpret_cast<DirectUnpackCallback*>(*arg)->CallbackFunc(stream,
field);
}
U& m_storage;
pb_callback_t m_callback;
DecodeLimits m_limits{DecodeLimits::Add};
};
/**
* A DirectUnpackCallback backed by a SmallVector<T, N>.
*
* By default, any elements in the packed buffer past N will
* be ignored, but decoding will still succeed
*
* @tparam T object type
* @tparam N small vector small size/number of expected elements
*/
template <ProtoCallbackUnpackable T, size_t N = 1>
class UnpackCallback
: public DirectUnpackCallback<T, wpi::SmallVector<T, N>, N> {
public:
/**
* Constructs an UnpackCallback.
*/
UnpackCallback()
: DirectUnpackCallback<T, wpi::SmallVector<T, N>, N>{m_storedBuffer} {
this->SetLimits(DecodeLimits::Ignore);
}
/**
* Gets a span pointing to the storage buffer.
*
* @return storage buffer span
*/
std::span<T> Items() noexcept { return m_storedBuffer; }
/**
* Gets a const span pointing to the storage buffer.
*
* @return storage buffer span
*/
std::span<const T> Items() const noexcept { return m_storedBuffer; }
/**
* Gets a reference to the backing small vector.
*
* @return small vector reference
*/
wpi::SmallVector<T, N>& Vec() noexcept { return m_storedBuffer; }
private:
wpi::SmallVector<T, N> m_storedBuffer;
};
/**
* A DirectUnpackCallback backed by a std::vector.
*
* By default, any elements in the packed buffer past N will
* be ignored, but decoding will still succeed
*
* @tparam T object type
* @tparam N number of expected elements
*/
template <ProtoCallbackUnpackable T, size_t N = 1>
class StdVectorUnpackCallback
: public DirectUnpackCallback<T, std::vector<T>, N> {
public:
/**
* Constructs a StdVectorUnpackCallback.
*/
StdVectorUnpackCallback()
: DirectUnpackCallback<T, std::vector<T>, N>{m_storedBuffer} {
this->SetLimits(DecodeLimits::Ignore);
}
/**
* Gets a span pointing to the storage buffer.
*
* @return storage buffer span
*/
std::span<T> Items() noexcept { return m_storedBuffer; }
/**
* Gets a const span pointing to the storage buffer.
*
* @return storage buffer span
*/
std::span<const T> Items() const noexcept { return m_storedBuffer; }
/**
* Gets a reference to the backing vector.
*
* @return vector reference
*/
std::vector<T>& Vec() noexcept { return m_storedBuffer; }
private:
std::vector<T> m_storedBuffer;
};
/**
* A wrapper around a wpi::array that lets us
* treat it as a limited sized vector.
*/
template <ProtoCallbackUnpackable T, size_t N>
struct WpiArrayEmplaceWrapper {
wpi::array<T, N> m_array{wpi::empty_array_t{}};
size_t m_currentIndex = 0;
size_t size() const { return m_currentIndex; }
template <typename... ArgTypes>
T& emplace_back(ArgTypes&&... Args) {
m_array[m_currentIndex] = T(std::forward<ArgTypes>(Args)...);
m_currentIndex++;
return m_array[m_currentIndex - 1];
}
};
/**
* A DirectUnpackCallback backed by a wpi::array<T, N>.
*
* Any elements in the packed buffer past N will
* be cause decoding to fail.
*
* @tparam T object type
* @tparam N small vector small size/number of expected elements
*/
template <ProtoCallbackUnpackable T, size_t N>
struct WpiArrayUnpackCallback
: public DirectUnpackCallback<T, WpiArrayEmplaceWrapper<T, N>, N> {
/**
* Constructs a WpiArrayUnpackCallback.
*/
WpiArrayUnpackCallback()
: DirectUnpackCallback<T, WpiArrayEmplaceWrapper<T, N>, N>{m_array} {
this->SetLimits(DecodeLimits::Fail);
}
/**
* Returns if the buffer is completely filled up.
*
* @return true if buffer is full
*/
bool IsFull() const noexcept { return m_array.m_currentIndex == N; }
/**
* Returns the number of elements in the buffer.
*
* @return number of elements
*/
size_t Size() const noexcept { return m_array.m_currentIndex; }
/**
* Returns a reference to the backing array.
*
* @return array reference
*/
wpi::array<T, N>& Array() noexcept { return m_array.m_array; }
private:
WpiArrayEmplaceWrapper<T, N> m_array;
};
/**
* A callback method that will pack elements when called.
*
* @tparam T object type
*/
template <ProtoCallbackPackable T>
class PackCallback {
public:
/**
* Constructs a pack callback from a span of elements. The elements in the
* buffer _MUST_ stay alive throughout the entire encode call.
*/
explicit PackCallback(std::span<const T> buffer) : m_buffer{buffer} {
m_callback.funcs.encode = CallbackFunc;
m_callback.arg = this;
}
/**
* Constructs a pack callback from a pointer to a single element.
* This element _MUST_ stay alive throughout the entire encode call.
* Do not pass a temporary here (This is why its a pointer and not a
* reference)
*/
explicit PackCallback(const T* element)
: m_buffer{std::span<const T>{element, 1}} {
m_callback.funcs.encode = CallbackFunc;
m_callback.arg = this;
}
PackCallback(const PackCallback&) = delete;
PackCallback(PackCallback&&) = delete;
PackCallback& operator=(const PackCallback&) = delete;
PackCallback& operator=(PackCallback&&) = delete;
/**
* Gets the nanopb callback pointing to this object.
*
* @return nanopb callback
*/
pb_callback_t Callback() const { return m_callback; }
/**
* Gets a span pointing to the items
*
* @return span
*/
std::span<const T> Bufs() const { return m_buffer; }
private:
static auto EncodeStreamTypeFinder() {
if constexpr (ProtobufSerializable<T>) {
return ProtoOutputStream<T>(nullptr);
} else {
return pb_ostream_t{};
}
}
using EncodeStreamType = decltype(EncodeStreamTypeFinder());
bool EncodeItem(EncodeStreamType& stream, const pb_field_t* field,
const T& value) const {
if constexpr (std::floating_point<T>) {
pb_type_t fieldType = PB_LTYPE(field->type);
switch (fieldType) {
case PB_LTYPE_FIXED32: {
float flt = static_cast<float>(value);
return pb_encode_fixed32(&stream, &flt);
}
case PB_LTYPE_FIXED64: {
double dbl = static_cast<double>(value);
return pb_encode_fixed64(&stream, &dbl);
}
default:
return false;
}
} else if constexpr (std::integral<T> || ProtoEnumeration<T>) {
pb_type_t fieldType = PB_LTYPE(field->type);
switch (fieldType) {
case PB_LTYPE_BOOL:
case PB_LTYPE_VARINT:
case PB_LTYPE_UVARINT:
return pb_encode_varint(&stream, value);
case PB_LTYPE_SVARINT:
return pb_encode_svarint(&stream, value);
case PB_LTYPE_FIXED32: {
uint32_t f = value;
return pb_encode_fixed32(&stream, &f);
}
case PB_LTYPE_FIXED64: {
uint64_t f = value;
return pb_encode_fixed64(&stream, &f);
}
default:
return false;
}
} else if constexpr (StringLike<T>) {
std::string_view view{value};
return pb_encode_string(&stream,
reinterpret_cast<const pb_byte_t*>(view.data()),
view.size());
} else if constexpr (ConstVectorLike<T>) {
std::span<const uint8_t> view{value};
return pb_encode_string(&stream,
reinterpret_cast<const pb_byte_t*>(view.data()),
view.size());
} else if constexpr (ProtobufSerializable<T>) {
return wpi::Protobuf<T>::Pack(stream, value);
}
}
bool EncodeLoop(pb_ostream_t* stream, const pb_field_t* field,
bool writeTag) const {
if constexpr (ProtobufSerializable<T>) {
ProtoOutputStream<T> ostream{stream};
for (auto&& i : m_buffer) {
if (writeTag) {
if (!pb_encode_tag_for_field(stream, field)) {
return false;
}
}
if (!EncodeItem(ostream, field, i)) {
return false;
}
}
} else {
for (auto&& i : m_buffer) {
if (writeTag) {
if (!pb_encode_tag_for_field(stream, field)) {
return false;
}
}
if (!EncodeItem(*stream, field, i)) {
return false;
}
}
}
return true;
}
bool PackedEncode(pb_ostream_t* stream, const pb_field_t* field) const {
// We're always going to used packed encoding.
// So first we need to get the packed size.
pb_ostream_t substream = PB_OSTREAM_SIZING;
if (!EncodeLoop(&substream, field, false)) {
return false;
}
// Encode as a string tag
if (!pb_encode_tag(stream, PB_WT_STRING, field->tag)) {
return false;
}
// Write length as varint
size_t size = substream.bytes_written;
if (!pb_encode_varint(stream, static_cast<uint64_t>(size))) {
return false;
}
return EncodeLoop(stream, field, false);
}
bool CallbackFunc(pb_ostream_t* stream, const pb_field_t* field) const {
// First off, if we're empty, do nothing, but say we were successful
if (m_buffer.empty()) {
return true;
}
pb_type_t fieldType = PB_LTYPE(field->type);
if (!detail::ValidateType<T>(fieldType)) {
return false;
}
if constexpr (ProtoPackable<T>) {
return PackedEncode(stream, field);
} else {
return EncodeLoop(stream, field, true);
}
}
static bool CallbackFunc(pb_ostream_t* stream, const pb_field_t* field,
void* const* arg) {
return reinterpret_cast<const PackCallback*>(*arg)->CallbackFunc(stream,
field);
}
std::span<const T> m_buffer;
pb_callback_t m_callback;
};
} // namespace wpi