diff --git a/src/headers/tomcrypt_misc.h b/src/headers/tomcrypt_misc.h index 2f670cc..4f5e8fa 100644 --- a/src/headers/tomcrypt_misc.h +++ b/src/headers/tomcrypt_misc.h @@ -5,14 +5,20 @@ int base64_encode(const unsigned char *in, unsigned long len, int base64_decode(const unsigned char *in, unsigned long len, unsigned char *out, unsigned long *outlen); +int base64_strict_decode(const unsigned char *in, unsigned long len, + unsigned char *out, unsigned long *outlen); #endif #ifdef LTC_BASE64_URL int base64url_encode(const unsigned char *in, unsigned long len, unsigned char *out, unsigned long *outlen); +int base64url_strict_encode(const unsigned char *in, unsigned long inlen, + unsigned char *out, unsigned long *outlen); int base64url_decode(const unsigned char *in, unsigned long len, unsigned char *out, unsigned long *outlen); +int base64url_strict_decode(const unsigned char *in, unsigned long len, + unsigned char *out, unsigned long *outlen); #endif /* ===> LTC_HKDF -- RFC5869 HMAC-based Key Derivation Function <=== */ diff --git a/src/misc/base64/base64_decode.c b/src/misc/base64/base64_decode.c index 423dc43..1babfbc 100644 --- a/src/misc/base64/base64_decode.c +++ b/src/misc/base64/base64_decode.c @@ -71,9 +71,14 @@ static const unsigned char map_base64url[256] = { 255, 255, 255, 255 }; #endif /* LTC_BASE64_URL */ +enum { + relaxed = 0, + strict = 1 +}; + static int _base64_decode_internal(const unsigned char *in, unsigned long inlen, unsigned char *out, unsigned long *outlen, - const unsigned char *map) + const unsigned char *map, int is_strict) { unsigned long t, x, y, z; unsigned char c; @@ -83,36 +88,42 @@ static int _base64_decode_internal(const unsigned char *in, unsigned long inlen LTC_ARGCHK(out != NULL); LTC_ARGCHK(outlen != NULL); - g = 3; + g = 0; /* '=' counter */ for (x = y = z = t = 0; x < inlen; x++) { c = map[in[x]&0xFF]; - if (c == 255) continue; - /* the final = symbols are read and used to trim the remaining bytes */ if (c == 254) { - c = 0; - /* prevent g < 0 which would potentially allow an overflow later */ - if (--g < 0) { - return CRYPT_INVALID_PACKET; - } - } else if (g != 3) { - /* we only allow = to be at the end */ + g++; + continue; + } + else if (is_strict && g > 0) { + /* we only allow '=' to be at the end */ return CRYPT_INVALID_PACKET; } + if (c == 255) { + if (is_strict) + return CRYPT_INVALID_PACKET; + else + continue; + } t = (t<<6)|c; if (++y == 4) { - if (z + g > *outlen) { - return CRYPT_BUFFER_OVERFLOW; - } + if (z + 3 > *outlen) return CRYPT_BUFFER_OVERFLOW; out[z++] = (unsigned char)((t>>16)&255); - if (g > 1) out[z++] = (unsigned char)((t>>8)&255); - if (g > 2) out[z++] = (unsigned char)(t&255); + out[z++] = (unsigned char)((t>>8)&255); + out[z++] = (unsigned char)(t&255); y = t = 0; } } + if (y != 0) { - return CRYPT_INVALID_PACKET; + if (y == 1) return CRYPT_INVALID_PACKET; + if ((y + g) != 4 && is_strict && map != map_base64url) return CRYPT_INVALID_PACKET; + t = t << (6 * (4 - y)); + if (z + y - 1 > *outlen) return CRYPT_BUFFER_OVERFLOW; + if (y >= 2) out[z++] = (unsigned char) ((t >> 16) & 255); + if (y == 3) out[z++] = (unsigned char) ((t >> 8) & 255); } *outlen = z; return CRYPT_OK; @@ -120,7 +131,7 @@ static int _base64_decode_internal(const unsigned char *in, unsigned long inlen #if defined(LTC_BASE64) /** - base64 decode a block of memory + Relaxed base64 decode a block of memory @param in The base64 data to decode @param inlen The length of the base64 data @param out [out] The destination of the binary decoded data @@ -130,13 +141,27 @@ static int _base64_decode_internal(const unsigned char *in, unsigned long inlen int base64_decode(const unsigned char *in, unsigned long inlen, unsigned char *out, unsigned long *outlen) { - return _base64_decode_internal(in, inlen, out, outlen, map_base64); + return _base64_decode_internal(in, inlen, out, outlen, map_base64, relaxed); +} + +/** + Strict base64 decode a block of memory + @param in The base64 data to decode + @param inlen The length of the base64 data + @param out [out] The destination of the binary decoded data + @param outlen [in/out] The max size and resulting size of the decoded data + @return CRYPT_OK if successful +*/ +int base64_strict_decode(const unsigned char *in, unsigned long inlen, + unsigned char *out, unsigned long *outlen) +{ + return _base64_decode_internal(in, inlen, out, outlen, map_base64, strict); } #endif /* LTC_BASE64 */ #if defined(LTC_BASE64_URL) /** - base64 (URL Safe, RFC 4648 section 5) decode a block of memory + Relaxed base64 (URL Safe, RFC 4648 section 5) decode a block of memory @param in The base64 data to decode @param inlen The length of the base64 data @param out [out] The destination of the binary decoded data @@ -146,7 +171,21 @@ int base64_decode(const unsigned char *in, unsigned long inlen, int base64url_decode(const unsigned char *in, unsigned long inlen, unsigned char *out, unsigned long *outlen) { - return _base64_decode_internal(in, inlen, out, outlen, map_base64url); + return _base64_decode_internal(in, inlen, out, outlen, map_base64url, relaxed); +} + +/** + Strict base64 (URL Safe, RFC 4648 section 5) decode a block of memory + @param in The base64 data to decode + @param inlen The length of the base64 data + @param out [out] The destination of the binary decoded data + @param outlen [in/out] The max size and resulting size of the decoded data + @return CRYPT_OK if successful +*/ +int base64url_strict_decode(const unsigned char *in, unsigned long inlen, + unsigned char *out, unsigned long *outlen) +{ + return _base64_decode_internal(in, inlen, out, outlen, map_base64url, strict); } #endif /* LTC_BASE64_URL */ diff --git a/src/misc/base64/base64_encode.c b/src/misc/base64/base64_encode.c index 0ed0aa3..c87f302 100644 --- a/src/misc/base64/base64_encode.c +++ b/src/misc/base64/base64_encode.c @@ -110,6 +110,12 @@ int base64url_encode(const unsigned char *in, unsigned long inlen, { return _base64_encode_internal(in, inlen, out, outlen, codes_base64url, 0); } + +int base64url_strict_encode(const unsigned char *in, unsigned long inlen, + unsigned char *out, unsigned long *outlen) +{ + return _base64_encode_internal(in, inlen, out, outlen, codes_base64url, 1); +} #endif /* LTC_BASE64_URL */ #endif diff --git a/testprof/base64_test.c b/testprof/base64_test.c index 8c15d3c..9868a0f 100644 --- a/testprof/base64_test.c +++ b/testprof/base64_test.c @@ -5,6 +5,10 @@ int base64_test(void) { unsigned char in[64], out[256], tmp[64]; unsigned long x, l1, l2, slen1; + const char special_case[] = { + 0xbe, 0xe8, 0x92, 0x3c, 0xa2, 0x25, 0xf0, 0xf8, + 0x91, 0xe4, 0xef, 0xab, 0x0b, 0x8c, 0xfd, 0xff, + 0x14, 0xd0, 0x29, 0x9d, 0x00 }; /* TEST CASES SOURCE: @@ -24,38 +28,87 @@ int base64_test(void) {"foo", "Zm9v" }, {"foob", "Zm9vYg==" }, {"fooba", "Zm9vYmE=" }, - {"foobar", "Zm9vYmFy"} + {"foobar", "Zm9vYmFy"}, + {special_case,"vuiSPKIl8PiR5O+rC4z9/xTQKZ0="} + }; + + const struct { + const char* s; + int is_strict; + } url_cases[] = { + {"vuiSPKIl8PiR5O-rC4z9_xTQKZ0", 0}, + {"vuiSPKIl8PiR5O-rC4z9_xTQKZ0=", 1}, + {"vuiS*PKIl8P*iR5O-rC4*z9_xTQKZ0", 0}, + {"vuiS*PKIl8P*iR5O-rC4*z9_xTQKZ0=", 0}, + {"vuiS*PKIl8P*iR5O-rC4*z9_xTQKZ0==", 0}, + {"vuiS*PKIl8P*iR5O-rC4*z9_xTQKZ0===", 0}, + {"vuiS*PKIl8P*iR5O-rC4*z9_xTQKZ0====", 0}, + {"vuiS*=PKIl8P*iR5O-rC4*z9_xTQKZ0=", 0}, + {"vuiS*==PKIl8P*iR5O-rC4*z9_xTQKZ0=", 0}, + {"vuiS*===PKIl8P*iR5O-rC4*z9_xTQKZ0=", 0}, }; for (x = 0; x < sizeof(cases)/sizeof(cases[0]); ++x) { + memset(out, 0, sizeof(out)); + memset(tmp, 0, sizeof(tmp)); slen1 = strlen(cases[x].s); l1 = sizeof(out); DO(base64_encode((unsigned char*)cases[x].s, slen1, out, &l1)); l2 = sizeof(tmp); - DO(base64_decode(out, l1, tmp, &l2)); - if (l2 != slen1 || l1 != strlen(cases[x].b64) || memcmp(tmp, cases[x].s, l2) || memcmp(out, cases[x].b64, l1)) { - fprintf(stderr, "\nbase64 failed case %lu", x); - fprintf(stderr, "\nbase64 should: %s", cases[x].b64); - out[sizeof(out)-1] = '\0'; - fprintf(stderr, "\nbase64 is: %s", out); - fprintf(stderr, "\nplain should: %s", cases[x].s); - tmp[sizeof(tmp)-1] = '\0'; - fprintf(stderr, "\nplain is: %s\n", tmp); + DO(base64_strict_decode(out, l1, tmp, &l2)); + if (compare_testvector(out, l1, cases[x].b64, strlen(cases[x].b64), "base64 encode", x) || + compare_testvector(tmp, l2, cases[x].s, slen1, "base64 decode", x)) { return 1; } } + for (x = 0; x < sizeof(url_cases)/sizeof(url_cases[0]); ++x) { + slen1 = strlen(url_cases[x].s); + l1 = sizeof(out); + if(url_cases[x].is_strict) + DO(base64url_strict_decode((unsigned char*)url_cases[x].s, slen1, out, &l1)); + else + DO(base64url_decode((unsigned char*)url_cases[x].s, slen1, out, &l1)); + if (compare_testvector(out, l1, special_case, strlen(special_case), "base64url decode", x)) { + return 1; + } + if(x < 2) { + l2 = sizeof(tmp); + if(x == 0) + DO(base64url_encode(out, l1, tmp, &l2)); + else + DO(base64url_strict_encode(out, l1, tmp, &l2)); + if (compare_testvector(tmp, l2, url_cases[x].s, strlen(url_cases[x].s), "base64url encode", x)) { + return 1; + } + } + } + + DO(base64url_strict_decode((unsigned char*)url_cases[4].s, slen1, out, &l1) == CRYPT_INVALID_PACKET ? CRYPT_OK : CRYPT_INVALID_PACKET); + for (x = 0; x < 64; x++) { yarrow_read(in, x, &yarrow_prng); l1 = sizeof(out); DO(base64_encode(in, x, out, &l1)); l2 = sizeof(tmp); DO(base64_decode(out, l1, tmp, &l2)); - if (l2 != x || memcmp(tmp, in, x)) { - fprintf(stderr, "base64 failed %lu %lu %lu", x, l1, l2); + if (compare_testvector(tmp, x, in, x, "random base64", x)) { return 1; } } + + x--; + memmove(&out[11], &out[10], l1 - 10); + out[10] = '='; + l1++; + l2 = sizeof(tmp); + DO(base64_decode(out, l1, tmp, &l2)); + if (compare_testvector(tmp, l2, in, l2, "relaxed base64 decoding", -1)) { + print_hex("input ", out, l1); + return 1; + } + l2 = sizeof(tmp); + DO(base64_strict_decode(out, l1, tmp, &l2) == CRYPT_INVALID_PACKET ? CRYPT_OK : CRYPT_INVALID_PACKET); return 0; } #endif