/** * Copyright (C) 2021 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <dlfcn.h> #include <string.h> #include <openssl/ssl.h> #include <openssl/crypto.h> #include <openssl/bn.h> #include <memory> #include "../includes/common.h" /** NOTE: These values are for the BIGNUM declared in kBN2DecTests and */ /** must be updated if kBN2DecTests is changed. */ #if _32_BIT #define ALLOCATION_SIZE 52 static const int sMallocSkipCount[] = {1,0}; #else #define ALLOCATION_SIZE 56 static const int sMallocSkipCount[] = {0,0}; #endif static const char *kTest = "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890"; static int sCount = 0; static bool sOverloadMalloc = false; int loopIndex = 0; template<typename T> struct OpenSSLFree { void operator()(T *buf) { OPENSSL_free(buf); } }; using ScopedOpenSSLString = std::unique_ptr<char, OpenSSLFree<char>>; namespace crypto { template<typename T, void (*func)(T*)> struct OpenSSLDeleter { void operator()(T *obj) { func(obj); } }; template<typename Type, void (*Destroyer)(Type*)> struct OpenSSLDestroyer { void operator()(Type* ptr) const { Destroyer(ptr); } }; template<typename T, void (*func)(T*)> using ScopedOpenSSLType = std::unique_ptr<T, OpenSSLDeleter<T, func>>; template<typename PointerType, void (*Destroyer)(PointerType*)> using ScopedOpenSSL = std::unique_ptr<PointerType, OpenSSLDestroyer<PointerType, Destroyer>>; struct OpenSSLFree { void operator()(uint8_t* ptr) const { OPENSSL_free(ptr); } }; using ScopedBIGNUM = ScopedOpenSSL<BIGNUM, BN_free>; using ScopedBN_CTX = ScopedOpenSSLType<BN_CTX, BN_CTX_free>; } // namespace crypto static int DecimalToBIGNUM(crypto::ScopedBIGNUM *out, const char *in) { BIGNUM *raw = nullptr; int ret = BN_dec2bn(&raw, in); out->reset(raw); return ret; } void* (*realMalloc)(size_t) = nullptr; void mtraceInit(void) { realMalloc = (void *(*)(size_t))dlsym(RTLD_NEXT, "malloc"); return; } void *malloc(size_t size) { if (realMalloc == nullptr) { mtraceInit(); } if (!sOverloadMalloc) { return realMalloc(size); } if (size == ALLOCATION_SIZE) { if (sCount >= sMallocSkipCount[loopIndex]) { return nullptr; } ++sCount; } return realMalloc(size); } using namespace crypto; int main() { CRYPTO_library_init(); ScopedBN_CTX ctx(BN_CTX_new()); if (!ctx) { return EXIT_FAILURE; } for(loopIndex = 0; loopIndex < 2; ++loopIndex) { ScopedBIGNUM bn; int ret = DecimalToBIGNUM(&bn, kTest); if (!ret) { return EXIT_FAILURE; } sOverloadMalloc = true; ScopedOpenSSLString dec(BN_bn2dec(bn.get())); sOverloadMalloc = false; if (!dec) { return EXIT_FAILURE; } if (strcmp(dec.get(), kTest)) { return EXIT_FAILURE; } } return EXIT_SUCCESS; }