From ae42eee8e13e47167ca054f1bb501439622fe6b1 Mon Sep 17 00:00:00 2001 From: Peter Johnson Date: Fri, 3 Jul 2015 18:01:50 -0700 Subject: [PATCH] More Base64 fixes (correct decode return value). Change-Id: Ic8f5eb7efd39e1d155a458aa41e430232bee7c7d --- src/Base64.cpp | 30 +++++++++++++++--------------- test/unit/test_Base64.cpp | 10 +++++++++- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/Base64.cpp b/src/Base64.cpp index d97b7011a7..d3b9c0b982 100644 --- a/src/Base64.cpp +++ b/src/Base64.cpp @@ -89,31 +89,31 @@ static const unsigned char pr2six[256] = }; std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain) { - const unsigned char *bufin = encoded.bytes_begin(); - while (pr2six[*bufin] <= 63 && bufin != encoded.bytes_end()) ++bufin; - std::size_t nprbytes = bufin - encoded.bytes_begin(); - std::size_t nbytesdecoded = ((nprbytes + 3) / 4) * 3; + const unsigned char *end = encoded.bytes_begin(); + while (pr2six[*end] <= 63 && end != encoded.bytes_end()) ++end; + std::size_t nprbytes = end - encoded.bytes_begin(); plain->clear(); - plain->reserve(nbytesdecoded); + if (nprbytes == 0) + return 0; + plain->reserve(((nprbytes + 3) / 4) * 3); - bufin = encoded.bytes_begin(); + const unsigned char *cur = encoded.bytes_begin(); while (nprbytes > 4) { - (*plain) += (pr2six[*bufin] << 2 | pr2six[bufin[1]] >> 4); - (*plain) += (pr2six[bufin[1]] << 4 | pr2six[bufin[2]] >> 2); - (*plain) += (pr2six[bufin[2]] << 6 | pr2six[bufin[3]]); - bufin += 4; + (*plain) += (pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4); + (*plain) += (pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2); + (*plain) += (pr2six[cur[2]] << 6 | pr2six[cur[3]]); + cur += 4; nprbytes -= 4; } // Note: (nprbytes == 1) would be an error, so just ignore that case - if (nprbytes > 1) (*plain) += (pr2six[*bufin] << 2 | pr2six[bufin[1]] >> 4); - if (nprbytes > 2) (*plain) += (pr2six[bufin[1]] << 4 | pr2six[bufin[2]] >> 2); - if (nprbytes > 3) (*plain) += (pr2six[bufin[2]] << 6 | pr2six[bufin[3]]); + if (nprbytes > 1) (*plain) += (pr2six[cur[0]] << 2 | pr2six[cur[1]] >> 4); + if (nprbytes > 2) (*plain) += (pr2six[cur[1]] << 4 | pr2six[cur[2]] >> 2); + if (nprbytes > 3) (*plain) += (pr2six[cur[2]] << 6 | pr2six[cur[3]]); - nbytesdecoded -= (4 - nprbytes) & 3; - return nbytesdecoded; + return (end - encoded.bytes_begin()) + ((4 - nprbytes) & 3); } static const char basis_64[] = diff --git a/test/unit/test_Base64.cpp b/test/unit/test_Base64.cpp index d400266dc2..274e202788 100644 --- a/test/unit/test_Base64.cpp +++ b/test/unit/test_Base64.cpp @@ -10,6 +10,13 @@ struct Base64TestParam { const char* encoded; }; +std::ostream& operator<<(std::ostream& os, const Base64TestParam& param) { + os << "Base64TestParam(Len: " << param.plain_len << ", " + << "Plain: \"" << param.plain << "\", " + << "Encoded: \"" << param.encoded << "\")"; + return os; +} + class Base64Test : public ::testing::TestWithParam { protected: llvm::StringRef GetPlain() { @@ -28,7 +35,8 @@ TEST_P(Base64Test, Encode) { TEST_P(Base64Test, Decode) { std::string s; - Base64Decode(GetParam().encoded, &s); + llvm::StringRef encoded = GetParam().encoded; + EXPECT_EQ(encoded.size(), Base64Decode(encoded, &s)); ASSERT_EQ(GetPlain(), s); }