/*
 * Copyright 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// #define LOG_NDEBUG 0
#define LOG_TAG "audio_utils_MelAggregator"

#include <audio_utils/MelAggregator.h>
#include <audio_utils/power.h>
#include <cinttypes>
#include <iterator>
#include <utils/Log.h>

namespace android::audio_utils {
namespace {

/** Min value after which the MEL values are aggregated to CSD. */
constexpr float kMinCsdRecordToStore = 0.01f;

/** Threshold for 100% CSD expressed in Pa^2s. */
constexpr float kCsdThreshold = 5760.0f; // 1.6f(Pa^2h) * 3600.0f(s);

/** Reference energy used for dB calculation in Pa^2. */
constexpr float kReferenceEnergyPa = 4e-10;

/**
 * Checking the intersection of the time intervals of v1 and v2. Each MelRecord v
 * spawns an interval [t1, t2) if and only if:
 *    v.timestamp == t1 && v.mels.size() == t2 - t1
 **/
std::pair<int64_t, int64_t> intersectRegion(const MelRecord& v1, const MelRecord& v2)
{
    const int64_t maxStart = std::max(v1.timestamp, v2.timestamp);
    const int64_t v1End = v1.timestamp + v1.mels.size();
    const int64_t v2End = v2.timestamp + v2.mels.size();
    const int64_t minEnd = std::min(v1End, v2End);
    return {maxStart, minEnd};
}

float aggregateMels(const float mel1, const float mel2) {
    return audio_utils_power_from_energy(powf(10.f, mel1 / 10.f) + powf(10.f, mel2 / 10.f));
}

float averageMelEnergy(const float mel1,
                       const int64_t duration1,
                       const float mel2,
                       const int64_t duration2) {
    return audio_utils_power_from_energy((powf(10.f, mel1 / 10.f) * duration1
        + powf(10.f, mel2 / 10.f) * duration2) / (duration1 + duration2));
}

float melToCsd(float mel) {
    float energy = powf(10.f, mel / 10.0f);
    return kReferenceEnergyPa * energy / kCsdThreshold;
}

CsdRecord createRevertedRecord(const CsdRecord& record) {
    return {record.timestamp, record.duration, -record.value, record.averageMel};
}

}  // namespace

int64_t MelAggregator::csdTimeIntervalStored_l()
{
    return mCsdRecords.rbegin()->second.timestamp + mCsdRecords.rbegin()->second.duration
        - mCsdRecords.begin()->second.timestamp;
}

std::map<int64_t, CsdRecord>::iterator MelAggregator::addNewestCsdRecord_l(int64_t timestamp,
                                                                           int64_t duration,
                                                                           float csdRecord,
                                                                           float averageMel)
{
    ALOGV("%s: add new csd[%" PRId64 ", %" PRId64 "]=%f for MEL avg %f",
                      __func__,
                      timestamp,
                      duration,
                      csdRecord,
                      averageMel);

    mCurrentCsd += csdRecord;
    return mCsdRecords.emplace_hint(mCsdRecords.end(),
                                    timestamp,
                                    CsdRecord(timestamp,
                                              duration,
                                              csdRecord,
                                              averageMel));
}

void MelAggregator::removeOldCsdRecords_l(std::vector<CsdRecord>& removeRecords) {
    // Remove older CSD values
    while (!mCsdRecords.empty() && csdTimeIntervalStored_l() > mCsdWindowSeconds) {
        mCurrentCsd -= mCsdRecords.begin()->second.value;
        removeRecords.emplace_back(createRevertedRecord(mCsdRecords.begin()->second));
        mCsdRecords.erase(mCsdRecords.begin());
    }
}

std::vector<CsdRecord> MelAggregator::updateCsdRecords_l()
{
    std::vector<CsdRecord> newRecords;

    // only update if we are above threshold
    if (mCurrentMelRecordsCsd < kMinCsdRecordToStore) {
        removeOldCsdRecords_l(newRecords);
        return newRecords;
    }

    float converted = 0.f;
    float averageMel = 0.f;
    float csdValue = 0.f;
    int64_t duration = 0;
    int64_t timestamp = mMelRecords.begin()->first;
    for (const auto& storedMel: mMelRecords) {
        int melsIdx = 0;
        for (const auto& mel: storedMel.second.mels) {
            averageMel = averageMelEnergy(averageMel, duration, mel, 1.f);
            csdValue += melToCsd(mel);
            ++duration;
            if (csdValue >= kMinCsdRecordToStore
                && mCurrentMelRecordsCsd - converted - csdValue >= kMinCsdRecordToStore) {
                auto it = addNewestCsdRecord_l(timestamp,
                                               duration,
                                               csdValue,
                                               averageMel);
                newRecords.emplace_back(it->second);

                duration = 0;
                averageMel = 0.f;
                converted += csdValue;
                csdValue = 0.f;
                timestamp = storedMel.first + melsIdx;
            }
            ++ melsIdx;
        }
    }

    if(csdValue > 0) {
        auto it = addNewestCsdRecord_l(timestamp,
                                       duration,
                                       csdValue,
                                       averageMel);
        newRecords.emplace_back(it->second);
    }

    removeOldCsdRecords_l(newRecords);

    // reset mel values
    mCurrentMelRecordsCsd = 0.0f;
    mMelRecords.clear();

    return newRecords;
}

std::vector<CsdRecord> MelAggregator::aggregateAndAddNewMelRecord(const MelRecord& mel)
{
    std::lock_guard _l(mLock);
    return aggregateAndAddNewMelRecord_l(mel);
}

std::vector<CsdRecord> MelAggregator::aggregateAndAddNewMelRecord_l(const MelRecord& mel)
{
    for (const auto& m : mel.mels) {
        mCurrentMelRecordsCsd += melToCsd(m);
    }
    ALOGV("%s: current mel values CSD %f", __func__, mCurrentMelRecordsCsd);

    auto mergeIt = mMelRecords.lower_bound(mel.timestamp);

    if (mergeIt != mMelRecords.begin()) {
        auto prevMergeIt = std::prev(mergeIt);
        if (prevMergeIt->second.overlapsEnd(mel)) {
            mergeIt = prevMergeIt;
        }
    }

    int64_t newTimestamp = mel.timestamp;
    std::vector<float> newMels = mel.mels;
    auto mergeStart = mergeIt;
    int overlapStart = 0;
    while(mergeIt != mMelRecords.end()) {
        const auto& [melRecordStart, melRecord] = *mergeIt;
        const auto [regionStart, regionEnd] = intersectRegion(melRecord, mel);
        if (regionStart >= regionEnd) {
            // no intersection
            break;
        }

        if (melRecordStart < regionStart) {
            newTimestamp = melRecordStart;
            overlapStart = regionStart - melRecordStart;
            newMels.insert(newMels.begin(), melRecord.mels.begin(),
                           melRecord.mels.begin() + overlapStart);
        }

        for (int64_t aggregateTime = regionStart; aggregateTime < regionEnd; ++aggregateTime) {
            const int offsetStored = aggregateTime - melRecordStart;
            const int offsetNew = aggregateTime - mel.timestamp;
            newMels[overlapStart + offsetNew] =
                aggregateMels(melRecord.mels[offsetStored], mel.mels[offsetNew]);
        }

        const int64_t mergeEndTime = melRecordStart + melRecord.mels.size();
        if (mergeEndTime > regionEnd) {
            newMels.insert(newMels.end(),
                           melRecord.mels.end() - mergeEndTime + regionEnd,
                           melRecord.mels.end());
        }

        ++mergeIt;
    }

    auto hint = mergeIt;
    if (mergeStart != mergeIt) {
        hint = mMelRecords.erase(mergeStart, mergeIt);
    }

    mMelRecords.emplace_hint(hint,
                             newTimestamp,
                             MelRecord(mel.portId, newMels, newTimestamp));

    return updateCsdRecords_l();
}

void MelAggregator::reset(float newCsd, const std::vector<CsdRecord>& newRecords)
{
    std::lock_guard _l(mLock);
    mCsdRecords.clear();
    mMelRecords.clear();

    mCurrentCsd = newCsd;
    for (const auto& record : newRecords) {
        mCsdRecords.emplace_hint(mCsdRecords.end(), record.timestamp, record);
    }
}

size_t MelAggregator::getCachedMelRecordsSize() const
{
    std::lock_guard _l(mLock);
    return mMelRecords.size();
}

void MelAggregator::foreachCachedMel(const std::function<void(const MelRecord&)>& f) const
{
     std::lock_guard _l(mLock);
     for (const auto &melRecord : mMelRecords) {
         f(melRecord.second);
     }
}

float MelAggregator::getCsd() {
    std::lock_guard _l(mLock);
    return mCurrentCsd;
}

size_t MelAggregator::getCsdRecordsSize() const {
    std::lock_guard _l(mLock);
    return mCsdRecords.size();
}

void MelAggregator::foreachCsd(const std::function<void(const CsdRecord&)>& f) const
{
     std::lock_guard _l(mLock);
     for (const auto &csdRecord : mCsdRecords) {
         f(csdRecord.second);
     }
}

}  // namespace android::audio_utils