diff --git a/src/Base64.cpp b/src/Base64.cpp index 643cb2d70b..b4bdbd962f 100644 --- a/src/Base64.cpp +++ b/src/Base64.cpp @@ -91,7 +91,7 @@ static const unsigned char pr2six[256] = std::size_t Base64Decode(llvm::StringRef encoded, std::string* plain) { const unsigned char *bufin = encoded.bytes_begin(); while (pr2six[*bufin] <= 63 && bufin != encoded.bytes_end()) ++bufin; - std::size_t nprbytes = (bufin - encoded.bytes_begin()) - 1; + std::size_t nprbytes = bufin - encoded.bytes_begin(); std::size_t nbytesdecoded = ((nprbytes + 3) / 4) * 3; plain->clear(); @@ -121,18 +121,21 @@ static const char basis_64[] = void Base64Encode(llvm::StringRef plain, std::string* encoded) { encoded->clear(); + if (plain.empty()) + return; std::size_t len = plain.size(); encoded->reserve(((len + 2) / 3 * 4) + 1); - std::size_t i; - for (i = 0; i < len - 2; i += 3) { - (*encoded) += basis_64[(plain[i] >> 2) & 0x3F]; - (*encoded) += - basis_64[((plain[i] & 0x3) << 4) | ((int)(plain[i + 1] & 0xF0) >> 4)]; - (*encoded) += basis_64[((plain[i + 1] & 0xF) << 2) | - ((int)(plain[i + 2] & 0xC0) >> 6)]; - (*encoded) += basis_64[plain[i + 2] & 0x3F]; - } + std::size_t i = 0; + if (len >= 2) + for (; i < len - 2; i += 3) { + (*encoded) += basis_64[(plain[i] >> 2) & 0x3F]; + (*encoded) += + basis_64[((plain[i] & 0x3) << 4) | ((int)(plain[i + 1] & 0xF0) >> 4)]; + (*encoded) += basis_64[((plain[i + 1] & 0xF) << 2) | + ((int)(plain[i + 2] & 0xC0) >> 6)]; + (*encoded) += basis_64[plain[i + 2] & 0x3F]; + } if (i < len) { (*encoded) += basis_64[(plain[i] >> 2) & 0x3F]; if (i == (len - 1)) { diff --git a/test/unit/test_Base64.cpp b/test/unit/test_Base64.cpp new file mode 100644 index 0000000000..5f74cdebf2 --- /dev/null +++ b/test/unit/test_Base64.cpp @@ -0,0 +1,62 @@ +#include "Base64.h" + +#include "gtest/gtest.h" + +namespace ntimpl { + +struct Base64TestParam { + int plain_len; + const char* plain; + const char* encoded; +}; + +class Base64Test : public ::testing::TestWithParam { + protected: + llvm::StringRef GetPlain() { + if (GetParam().plain_len < 0) + return llvm::StringRef(GetParam().plain); + else + return llvm::StringRef(GetParam().plain, GetParam().plain_len); + } +}; + +TEST_P(Base64Test, Encode) { + std::string s; + Base64Encode(GetPlain(), &s); + ASSERT_EQ(GetParam().encoded, s); +} + +TEST_P(Base64Test, Decode) { + std::string s; + Base64Decode(GetParam().encoded, &s); + ASSERT_EQ(GetPlain(), s); +} + +static Base64TestParam sample[] = { + {-1, "Send reinforcements", "U2VuZCByZWluZm9yY2VtZW50cw=="}, + {-1, "Now is the time for all good coders\n to learn C++", + "Tm93IGlzIHRoZSB0aW1lIGZvciBhbGwgZ29vZCBjb2RlcnMKIHRvIGxlYXJuIEMrKw=="}, + {-1, + "This is line one\nThis is line two\nThis is line three\nAnd so on...\n", + "VGhpcyBpcyBsaW5lIG9uZQpUaGlzIGlzIGxpbmUgdHdvClRoaXMgaXMgbGluZSB0aHJlZQpBb" + "mQgc28gb24uLi4K"}, +}; + +INSTANTIATE_TEST_CASE_P(Base64Sample, Base64Test, + ::testing::ValuesIn(sample)); + +static Base64TestParam standard[] = { + {0, "", ""}, + {1, "\0", "AA=="}, + {2, "\0\0", "AAA="}, + {3, "\0\0\0", "AAAA"}, + {1, "\377", "/w=="}, + {2, "\377\377", "//8="}, + {3, "\377\377\377", "////"}, + {2, "\xff\xef", "/+8="}, +}; + +INSTANTIATE_TEST_CASE_P(Base64Standard, Base64Test, + ::testing::ValuesIn(standard)); + +} // namespace ntimpl