/**
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <dlfcn.h>
#include <string.h>
#include <openssl/ssl.h>
#include <openssl/crypto.h>
#include <openssl/bn.h>
#include <memory>
#include "../includes/common.h"

/**
 * Allocation-failure injection parameters for the BIGNUM parsed from kTest.
 * Both values are specific to kTest and to the target ABI, and must be
 * updated if kTest is changed.
 *
 * ALLOCATION_SIZE:  request size (in bytes) that the malloc() override may
 *                   fail while sOverloadMalloc is set.
 * sMallocSkipCount: threshold for iteration loopIndex; once sCount reaches
 *                   it, ALLOCATION_SIZE requests return nullptr.
 *
 * NOTE(review): _32_BIT is presumably defined by the build or common.h for
 * 32-bit targets; when it is undefined, #if treats it as 0 and the 64-bit
 * values below are used. Confirm against the build configuration.
 */
#if _32_BIT
#define ALLOCATION_SIZE      52
static const int sMallocSkipCount[] = {1,0};
#else
#define ALLOCATION_SIZE      56
static const int sMallocSkipCount[] = {0,0};
#endif

/** Decimal string round-tripped through BN_dec2bn() / BN_bn2dec(). */
static const char *const kTest =
    "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890";

/** ALLOCATION_SIZE requests the malloc() override has let through; compared against sMallocSkipCount[loopIndex]. */
static int sCount = 0;
/** When true, the malloc() override may inject failures; otherwise it only forwards. */
static bool sOverloadMalloc = false;
/** Current test iteration; selects the active entry of sMallocSkipCount. */
int loopIndex = 0;

/** unique_ptr deleter that releases OpenSSL-allocated buffers via OPENSSL_free(). */
template<typename T>
struct OpenSSLFree {
    void operator()(T *ptr) const { OPENSSL_free(ptr); }
};

/** Owning handle for a C string returned by OpenSSL, e.g. by BN_bn2dec(). */
using ScopedOpenSSLString = std::unique_ptr<char, OpenSSLFree<char>>;

namespace crypto {

/** unique_ptr deleter that forwards the pointer to the release function `Release`. */
template<typename T, void (*Release)(T*)>
struct OpenSSLDeleter {
    void operator()(T *obj) { Release(obj); }
};

/**
 * const-callable deleter with the same effect as OpenSSLDeleter.
 * NOTE(review): the two deleters are functionally identical; both are kept
 * because ScopedOpenSSLType and ScopedOpenSSL each name one of them.
 */
template<typename T, void (*Release)(T*)>
struct OpenSSLDestroyer {
    void operator()(T *obj) const { Release(obj); }
};

/** Owning pointer to T, released with `Release` (OpenSSLDeleter flavour). */
template<typename T, void (*Release)(T*)>
using ScopedOpenSSLType = std::unique_ptr<T, OpenSSLDeleter<T, Release>>;

/** Owning pointer to T, released with `Release` (OpenSSLDestroyer flavour). */
template<typename T, void (*Release)(T*)>
using ScopedOpenSSL = std::unique_ptr<T, OpenSSLDestroyer<T, Release>>;

/**
 * Deleter for uint8_t buffers allocated by OpenSSL.
 * NOTE(review): not referenced in this file. Because the file later applies
 * `using namespace crypto;`, an unqualified `OpenSSLFree` after that point
 * would be ambiguous with the global OpenSSLFree<T> template.
 */
struct OpenSSLFree {
    void operator()(uint8_t *ptr) const { OPENSSL_free(ptr); }
};

/** Owning BIGNUM, released with BN_free(). */
using ScopedBIGNUM = ScopedOpenSSL<BIGNUM, BN_free>;
/** Owning BN_CTX, released with BN_CTX_free(). */
using ScopedBN_CTX = ScopedOpenSSLType<BN_CTX, BN_CTX_free>;

}  // namespace crypto

/**
 * Parses the decimal string `in` with BN_dec2bn() and hands the resulting
 * BIGNUM (null if none was produced) to `out`, releasing whatever `out`
 * previously owned.
 *
 * @return BN_dec2bn()'s result unchanged (per the OpenSSL documentation: the
 *         number of characters parsed, or 0 on error).
 */
static int DecimalToBIGNUM(crypto::ScopedBIGNUM *out, const char *in) {
    BIGNUM *parsed = nullptr;  // null asks BN_dec2bn to allocate a fresh BIGNUM
    const int result = BN_dec2bn(&parsed, in);
    out->reset(parsed);
    return result;
}

/** Signature shared by malloc() and the pointer to the real implementation. */
using MallocFn = void *(*)(size_t);

/** The malloc() that this file's override shadows; resolved lazily by mtraceInit(). */
MallocFn realMalloc = nullptr;

/**
 * Looks up the next "malloc" definition after this object in symbol
 * resolution order (dlsym with RTLD_NEXT) and stores it in realMalloc.
 * If the lookup fails, realMalloc is set to null.
 */
void mtraceInit(void) {
    void *const symbol = dlsym(RTLD_NEXT, "malloc");
    realMalloc = reinterpret_cast<MallocFn>(symbol);
}

/**
 * Interposed malloc().
 *
 * Resolves the real allocator via mtraceInit() on first use, then forwards
 * every request to it, except while sOverloadMalloc is set. In that window,
 * a request of exactly ALLOCATION_SIZE bytes returns nullptr once sCount has
 * reached sMallocSkipCount[loopIndex]; otherwise the request is forwarded
 * and sCount is incremented. This function never resets sCount.
 */
void *malloc(size_t size) {
    if (!realMalloc) {
        mtraceInit();
    }
    const bool isTarget = sOverloadMalloc && size == ALLOCATION_SIZE;
    if (isTarget) {
        const bool budgetSpent = sCount >= sMallocSkipCount[loopIndex];
        if (budgetSpent) {
            return nullptr;
        }
        sCount += 1;
    }
    return realMalloc(size);
}

using namespace crypto;

int main() {
    CRYPTO_library_init();
    ScopedBN_CTX ctx(BN_CTX_new());
    if (!ctx) {
        return EXIT_FAILURE;
    }
    for(loopIndex = 0; loopIndex < 2; ++loopIndex) {
        ScopedBIGNUM bn;
        int ret = DecimalToBIGNUM(&bn, kTest);
        if (!ret) {
            return EXIT_FAILURE;
        }
        sOverloadMalloc = true;
        ScopedOpenSSLString dec(BN_bn2dec(bn.get()));
        sOverloadMalloc = false;
        if (!dec) {
            return EXIT_FAILURE;
        }
        if (strcmp(dec.get(), kTest)) {
            return EXIT_FAILURE;
        }
    }
    return EXIT_SUCCESS;
}