diff --git a/src/main/native/cpp/support/Base64.cpp b/src/main/native/cpp/support/Base64.cpp index 88fe28a47b..82f062dd6b 100644 --- a/src/main/native/cpp/support/Base64.cpp +++ b/src/main/native/cpp/support/Base64.cpp @@ -64,6 +64,9 @@ #include "support/Base64.h" +#include "llvm/SmallVector.h" +#include "llvm/raw_ostream.h" + namespace wpi { // aaaack but it's fast and const should make it shared text page. @@ -88,65 +91,91 @@ static const unsigned char pr2six[256] = 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 }; -std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain) { +std::size_t Base64Decode(llvm::raw_ostream& os, llvm::StringRef encoded) { const unsigned char *end = encoded.bytes_begin(); while (pr2six[*end] <= 63 && end != encoded.bytes_end()) ++end; std::size_t nprbytes = end - encoded.bytes_begin(); - - plain->clear(); - if (nprbytes == 0) - return 0; - plain->reserve(((nprbytes + 3) / 4) * 3); + if (nprbytes == 0) return 0; const unsigned char *cur = encoded.bytes_begin(); while (nprbytes > 4) { - (*plain) += (pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4); - (*plain) += (pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2); - (*plain) += (pr2six[cur[2]] << 6 | pr2six[cur[3]]); + os << static_cast(pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4); + os << static_cast(pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2); + os << static_cast(pr2six[cur[2]] << 6 | pr2six[cur[3]]); cur += 4; nprbytes -= 4; } // Note: (nprbytes == 1) would be an error, so just ignore that case - if (nprbytes > 1) (*plain) += (pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4); - if (nprbytes > 2) (*plain) += (pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2); - if (nprbytes > 3) (*plain) += (pr2six[cur[2]] << 6 | pr2six[cur[3]]); + if (nprbytes > 1) + os << static_cast(pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4); + if (nprbytes > 2) + os << static_cast(pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2); + if (nprbytes > 3) + os << static_cast(pr2six[cur[2]] << 6 | pr2six[cur[3]]); return (end - encoded.bytes_begin()) + ((4 - nprbytes) & 3); } +std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain) { + plain->resize(0); + llvm::raw_string_ostream os(*plain); + std::size_t rv = Base64Decode(os, encoded); + os.flush(); + return rv; +} + +llvm::StringRef Base64Decode(llvm::StringRef encoded, std::size_t* num_read, + llvm::SmallVectorImpl& buf) { + buf.clear(); + llvm::raw_svector_ostream os(buf); + *num_read = Base64Decode(os, encoded); + return os.str(); +} + static const char basis_64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -void Base64Encode(llvm::StringRef plain, std::string* encoded) { - encoded->clear(); - if (plain.empty()) - return; +void Base64Encode(llvm::raw_ostream& os, llvm::StringRef plain) { + if (plain.empty()) return; std::size_t len = plain.size(); - encoded->reserve(((len + 2) / 3 * 4) + 1); std::size_t i; for (i = 0; (i + 2) < len; i += 3) { - (*encoded) += basis_64[(plain[i] >> 2) & 0x3F]; - (*encoded) += - basis_64[((plain[i] & 0x3) << 4) | ((int)(plain[i + 1] & 0xF0) >> 4)]; - (*encoded) += basis_64[((plain[i + 1] & 0xF) << 2) | - ((int)(plain[i + 2] & 0xC0) >> 6)]; - (*encoded) += basis_64[plain[i + 2] & 0x3F]; + os << basis_64[(plain[i] >> 2) & 0x3F]; + os << basis_64[((plain[i] & 0x3) << 4) | ((int)(plain[i + 1] & 0xF0) >> 4)]; + os << basis_64[((plain[i + 1] & 0xF) << 2) | + ((int)(plain[i + 2] & 0xC0) >> 6)]; + os << basis_64[plain[i + 2] & 0x3F]; } if (i < len) { - (*encoded) += basis_64[(plain[i] >> 2) & 0x3F]; + os << basis_64[(plain[i] >> 2) & 0x3F]; if (i == (len - 1)) { - (*encoded) += basis_64[((plain[i] & 0x3) << 4)]; - (*encoded) += '='; + os << basis_64[((plain[i] & 0x3) << 4)]; + os << '='; } else { - (*encoded) += - basis_64[((plain[i] & 0x3) << 4) | ((int)(plain[i + 1] & 0xF0) >> 4)]; - (*encoded) += basis_64[((plain[i + 1] & 0xF) << 2)]; + os << basis_64[((plain[i] & 0x3) << 4) | + ((int)(plain[i + 1] & 0xF0) >> 4)]; + os << basis_64[((plain[i + 1] & 0xF) << 2)]; } - (*encoded) += '='; + os << '='; } } +void Base64Encode(llvm::StringRef plain, std::string* encoded) { + encoded->resize(0); + llvm::raw_string_ostream os(*encoded); + Base64Encode(os, plain); + os.flush(); +} + +llvm::StringRef Base64Encode(llvm::StringRef plain, + llvm::SmallVectorImpl& buf) { + buf.clear(); + llvm::raw_svector_ostream os(buf); + Base64Encode(os, plain); + return os.str(); +} + } // namespace wpi diff --git a/src/main/native/include/support/Base64.h b/src/main/native/include/support/Base64.h index 40cb2528f0..85bada6060 100644 --- a/src/main/native/include/support/Base64.h +++ b/src/main/native/include/support/Base64.h @@ -13,11 +13,28 @@ #include "llvm/StringRef.h" +namespace llvm { +template +class SmallVectorImpl; +class raw_ostream; +} + namespace wpi { +std::size_t Base64Decode(llvm::raw_ostream& os, llvm::StringRef encoded); + std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain); + +llvm::StringRef Base64Decode(llvm::StringRef encoded, std::size_t* num_read, + llvm::SmallVectorImpl& buf); + +void Base64Encode(llvm::raw_ostream& os, llvm::StringRef plain); + void Base64Encode(llvm::StringRef plain, std::string* encoded); +llvm::StringRef Base64Encode(llvm::StringRef plain, + llvm::SmallVectorImpl& buf); + } // namespace wpi #endif // WPIUTIL_SUPPORT_BASE64_H_ diff --git a/src/test/native/cpp/Base64Test.cpp b/src/test/native/cpp/Base64Test.cpp index 972ca487ad..fd6ddd5cbc 100644 --- a/src/test/native/cpp/Base64Test.cpp +++ b/src/test/native/cpp/Base64Test.cpp @@ -8,6 +8,7 @@ #include "support/Base64.h" #include "gtest/gtest.h" +#include "llvm/SmallString.h" namespace wpi { @@ -34,17 +35,45 @@ class Base64Test : public ::testing::TestWithParam { } }; -TEST_P(Base64Test, Encode) { +TEST_P(Base64Test, EncodeStdString) { std::string s; Base64Encode(GetPlain(), &s); ASSERT_EQ(GetParam().encoded, s); + + // text already in s + Base64Encode(GetPlain(), &s); + ASSERT_EQ(GetParam().encoded, s); } -TEST_P(Base64Test, Decode) { +TEST_P(Base64Test, EncodeSmallString) { + llvm::SmallString<128> buf; + ASSERT_EQ(GetParam().encoded, Base64Encode(GetPlain(), buf)); + // reuse buf + ASSERT_EQ(GetParam().encoded, Base64Encode(GetPlain(), buf)); +} + +TEST_P(Base64Test, DecodeStdString) { std::string s; llvm::StringRef encoded = GetParam().encoded; EXPECT_EQ(encoded.size(), Base64Decode(encoded, &s)); ASSERT_EQ(GetPlain(), s); + + // text already in s + Base64Decode(encoded, &s); + ASSERT_EQ(GetPlain(), s); +} + +TEST_P(Base64Test, DecodeSmallString) { + llvm::SmallString<128> buf; + llvm::StringRef encoded = GetParam().encoded; + std::size_t len; + llvm::StringRef plain = Base64Decode(encoded, &len, buf); + EXPECT_EQ(encoded.size(), len); + ASSERT_EQ(GetPlain(), plain); + + // reuse buf + plain = Base64Decode(encoded, &len, buf); + ASSERT_EQ(GetPlain(), plain); } static Base64TestParam sample[] = {