#include "RkAudio.h"
#include "BoostLog.h"
#include <cstring>
#include <rkmedia/rkmedia_api.h>

namespace RkAudio {

// VQE processing block size in samples (256 samples == 16 ms at a 16 kHz work sample rate).
constexpr RK_U32 VqeFrameSample = 256;
// Audio node (ALSA-style device string) used for both the AI (capture) and AO (playback) channels.
constexpr const char *AudioNode = "hw:0,0";
// Parameter file copied into the VQE configs' aParamFilePath.
constexpr const char *ParamFilePath = "/system/etc/RKAP_3A_Para.bin";

/// Translates a Format::SampleType into the matching rkmedia SAMPLE_FORMAT_E.
/// Unknown maps silently to RK_SAMPLE_FMT_NONE; any value outside the enumerated
/// cases is logged and also mapped to RK_SAMPLE_FMT_NONE.
static SAMPLE_FORMAT_E rkAiFormat(Format::SampleType sampleType) {
    switch (sampleType) {
    case Format::SampleType::SignedInt16:
        return RK_SAMPLE_FMT_S16;
    case Format::SampleType::SignedInt:
        return RK_SAMPLE_FMT_S32;
    case Format::SampleType::Float:
        return RK_SAMPLE_FMT_FLT;
    case Format::SampleType::Unknown:
        return RK_SAMPLE_FMT_NONE;
    default:
        break;
    }
    LOG(error) << "unkonwn sample type: " << static_cast<int>(sampleType);
    return RK_SAMPLE_FMT_NONE;
}

/// Creates a closed capture input; call open() to start streaming.
Input::Input() = default;

/// Stops the reader thread and releases the AI channel if the input is still open.
Input::~Input() {
    if (m_channel < 0) return; // never opened, or already stopped
    stop();
}
bool Input::open(const Format &format, bool enableVqe) {
    bool ret = false;
    m_channel = 0;

    AI_CHN_ATTR_S parameter = {0};
    parameter.pcAudioNode = (RK_CHAR *)AudioNode;
    parameter.enAiLayout = AI_LAYOUT_MIC_REF; // remove ref channel, and output mic mono
    parameter.enSampleFormat = rkAiFormat(format.sampleType);
    parameter.u32Channels = format.channels;
    parameter.u32SampleRate = format.sampleRate;
    parameter.u32NbSamples = format.sampleRate / 1000 * format.period;
    int status = RK_MPI_AI_SetChnAttr(m_channel, &parameter);
    if (status) {
        LOG(error) << "RK_MPI_AI_SetChnAttr() failed, status: " << status;
        return ret;
    }

    status = RK_MPI_AI_EnableChn(m_channel);
    if (status) {
        LOG(error) << "RK_MPI_AI_EnableChn() failed, status: " << status;
        return ret;
    }

    if (enableVqe) {
        AI_TALKVQE_CONFIG_S config = {0};
        status = RK_MPI_AI_GetTalkVqeAttr(m_channel, &config);
        if (status) {
            LOG(error) << "RK_MPI_AI_GetTalkVqeAttr() failed, status: " << status;
            return ret;
        }
        LOG(info) << "param file: " << config.aParamFilePath;
        config.s32WorkSampleRate = format.sampleRate;
        config.s32FrameSample = VqeFrameSample;
        config.u32OpenMask = AI_TALKVQE_MASK_AEC | AI_TALKVQE_MASK_ANR | AI_TALKVQE_MASK_AGC;
        strncpy(config.aParamFilePath, ParamFilePath, sizeof(config.aParamFilePath));
        RK_MPI_AI_SetTalkVqeAttr(m_channel, &config);
        if (status) {
            LOG(error) << "RK_MPI_AI_SetTalkVqeAttr() failed, status: " << status;
            return ret;
        }
        status = RK_MPI_AI_EnableVqe(m_channel);
        if (status) {
            LOG(error) << "RK_MPI_AI_EnableVqe() failed, status: " << status;
            return ret;
        }
    }

    status = RK_MPI_AI_StartStream(0);
    if (status) {
        LOG(info) << "start AI failed, status: " << status;
        return ret;
    }

    m_exit = false;
    m_thread = std::thread(&Input::run, this);
    ret = true;
    return ret;
}

/// Signals the reader thread to exit, waits for it, then disables VQE and the
/// AI channel. Repeated calls are harmless: the thread is joined at most once and
/// m_channel is reset to -1 after the channel has been released.
void Input::stop() {
    m_exit = true;
    if (m_thread.joinable()) {
        m_thread.join();
    }

    if (m_channel < 0) return;

    RK_MPI_AI_DisableVqe(m_channel);
    RK_MPI_AI_DisableChn(m_channel);
    m_channel = -1;
}

void Input::setDataCallback(const ReadCallback &callback) {
    m_callback = callback;
}

void Input::run() {
    while (!m_exit) {
        auto mediaBuffer = RK_MPI_SYS_GetMediaBuffer(RK_ID_AI, 0, -1);
        if (!mediaBuffer) {
            LOG(error) << "RK_MPI_SYS_GetMediaBuffer() failed.";
            continue;
        }
        if (m_callback) {
            Frame frame;
            frame.data = reinterpret_cast<uint8_t *>(RK_MPI_MB_GetPtr(mediaBuffer));
            frame.byteSize = RK_MPI_MB_GetSize(mediaBuffer);
            frame.frameSize = frame.byteSize / m_format.channels / sizeof(uint16_t);
            frame.timestamp = std::chrono::system_clock::now();
            m_callback(frame);
        }
        RK_MPI_MB_ReleaseBuffer(mediaBuffer);
    }
}

/// Creates a closed playback output; call open() before write().
Output::Output() = default;

/// Releases the AO channel; close() itself does nothing if the output is not open.
Output::~Output() {
    this->close();
}

bool Output::open(uint32_t sampleSize, uint32_t sampleRate, uint32_t channels, uint32_t period, bool enableVqe) {

    m_channel = 0;
    AO_CHN_ATTR_S parameter = {0};
    parameter.pcAudioNode = (RK_CHAR *)AudioNode;
    parameter.enSampleFormat = RK_SAMPLE_FMT_S16;
    parameter.u32NbSamples = sampleRate / 1000 * period;
    parameter.u32SampleRate = sampleRate;
    parameter.u32Channels = channels;

    RK_MPI_AO_SetChnAttr(m_channel, &parameter);
    auto status = RK_MPI_AO_EnableChn(m_channel);
    if (status != 0) {
        LOG(error) << "RK_MPI_AO_EnableChn() failed, status: " << status;
        return false;
    }

    if (enableVqe) {
        AO_VQE_CONFIG_S config = {0};
        config.s32WorkSampleRate = sampleRate;
        config.s32FrameSample = VqeFrameSample;
        config.u32OpenMask = AO_VQE_MASK_ANR | AO_VQE_MASK_AGC;
        strncpy(config.aParamFilePath, ParamFilePath, sizeof(config.aParamFilePath));

        RK_MPI_AO_SetVqeAttr(m_channel, &config);
        RK_MPI_AO_EnableVqe(m_channel);
    }

    return true;
}

/// Disables VQE and the AO channel and marks the output closed (m_channel = -1).
/// Does nothing when the output is not open.
void Output::close() {
    if (m_channel < 0) return;

    RK_MPI_AO_DisableVqe(m_channel);
    RK_MPI_AO_DisableChn(m_channel);
    m_channel = -1;
}

void Output::write(const uint8_t *data, uint32_t byteSize) {
    if (m_channel < 0) return;
    auto buffer = RK_MPI_MB_CreateAudioBuffer(byteSize, RK_FALSE);
    if (buffer != nullptr) {
        memcpy(RK_MPI_MB_GetPtr(buffer), data, byteSize);
        RK_MPI_MB_SetSize(buffer, byteSize);
        RK_MPI_SYS_SendMediaBuffer(RK_ID_AO, m_channel, buffer);
        RK_MPI_MB_ReleaseBuffer(buffer);
    } else {
        LOG(error) << "RK_MPI_MB_CreateAudioBuffer() failed.";
    }
}

} // namespace RkAudio

/// Byte ring buffer for interleaved PCM that hands out fixed-size chunks via pop().
///
/// @param sampleRate   sample rate in Hz.
/// @param channels     interleaved channel count.
/// @param sampleType   sample encoding; selects bytes per sample (2 for SignedInt16,
///                     4 for SignedInt/Float, otherwise the member's default is kept).
/// @param popDuration  length in milliseconds of each chunk returned by pop().
/// @param capacity     ring buffer size in bytes.
PcmStreamBuffer::PcmStreamBuffer(uint32_t sampleRate, uint32_t channels, RkAudio::Format::SampleType sampleType, uint32_t popDuration,
                                 uint32_t capacity)
    : m_sampleRate(sampleRate), m_channels(channels), m_buffer(capacity), m_capacity(capacity) {
    switch (sampleType) {
    case RkAudio::Format::SampleType::SignedInt16:
        m_pointByteSize = 2;
        break;
    case RkAudio::Format::SampleType::SignedInt: // RK_SAMPLE_FMT_S32
    case RkAudio::Format::SampleType::Float:     // RK_SAMPLE_FMT_FLT
        m_pointByteSize = 4;
        break;
    default:
        break; // keep the in-class default
    }

    // 64-bit intermediate: rate * channels * bytes * ms can exceed 2^32 for long pop durations.
    const uint64_t popBytes = static_cast<uint64_t>(sampleRate) * channels * m_pointByteSize * popDuration / 1000;
    m_popBuffer.assign(static_cast<size_t>(popBytes), 0);
    m_popFrame.data = m_popBuffer.data();
}

/// Appends all of @p frame's bytes to the ring buffer, or nothing at all.
/// @return false (with a warning logged) when the free space is smaller than
///         frame.byteSize; true otherwise, including for an empty frame.
bool PcmStreamBuffer::push(const RkAudio::Frame &frame) {
    // An empty frame would leave m_tail == m_head and then set m_full, turning an
    // empty buffer into a "full" one. It would also hit % m_capacity with capacity 0.
    if (frame.byteSize == 0) return true;

    std::lock_guard<std::mutex> locker(m_mutex);
    const uint32_t freeSize = m_capacity - availableByteSize();
    if (freeSize < frame.byteSize) {
        LOG_FORMAT(warning, "buffer is full, capacity: %u, free size: %u, need size: %u", static_cast<unsigned>(m_capacity),
                   static_cast<unsigned>(freeSize), static_cast<unsigned>(frame.byteSize));
        return false;
    }

    if ((m_tail + frame.byteSize) > m_capacity) {
        // Wraps: fill up to the physical end of the storage, then continue at the front.
        const uint32_t size1 = m_capacity - m_tail;
        std::memcpy(m_buffer.data() + m_tail, frame.data, size1);
        std::memcpy(m_buffer.data(), frame.data + size1, frame.byteSize - size1);
    } else {
        std::memcpy(m_buffer.data() + m_tail, frame.data, frame.byteSize);
    }
    m_tail = (m_tail + frame.byteSize) % m_capacity;
    // head == tail is ambiguous (empty vs. full); after a non-empty push it means full.
    m_full = (m_tail == m_head);
    return true;
}

/// Playback duration of the currently buffered PCM, rounded down to whole
/// milliseconds; 0 if the byte rate is 0.
/// NOTE(review): this const method reads the ring indices without taking m_mutex, so
/// under concurrent push()/pop() it is a racy snapshot — consider a mutable mutex.
std::chrono::milliseconds PcmStreamBuffer::availableDuration() const {
    const uint64_t bytesPerSecond = static_cast<uint64_t>(m_sampleRate) * m_channels * m_pointByteSize;
    if (bytesPerSecond == 0) return std::chrono::milliseconds(0);
    // 64-bit numerator: 1000 * byteSize overflows 32 bits once more than ~4.29 MB is buffered.
    const uint64_t bytes = availableByteSize();
    return std::chrono::milliseconds(1000 * bytes / bytesPerSecond);
}

/// Removes exactly one chunk of m_popBuffer.size() bytes (popDuration ms, see the
/// constructor) from the ring buffer.
///
/// @return pointer to an internal Frame describing the chunk, or nullptr when fewer
///         bytes than one chunk are buffered or the chunk size is 0. The Frame and the
///         bytes it points to belong to this object and are overwritten by the next
///         successful pop().
const RkAudio::Frame *PcmStreamBuffer::pop() {
    // A zero-size chunk would leave m_head unchanged but still clear m_full, silently
    // turning a full buffer into an empty one. m_popBuffer is only sized in the constructor.
    if (m_popBuffer.empty()) return nullptr;

    std::lock_guard<std::mutex> locker(m_mutex);
    const uint32_t chunkSize = static_cast<uint32_t>(m_popBuffer.size());
    if (availableByteSize() < chunkSize) return nullptr;

    if ((m_head + chunkSize) > m_capacity) {
        // Wraps: copy from m_head to the physical end, then the remainder from the front.
        const uint32_t size1 = m_capacity - m_head;
        std::memcpy(m_popBuffer.data(), m_buffer.data() + m_head, size1);
        std::memcpy(m_popBuffer.data() + size1, m_buffer.data(), chunkSize - size1);
    } else {
        std::memcpy(m_popBuffer.data(), m_buffer.data() + m_head, chunkSize);
    }
    m_head = (m_head + chunkSize) % m_capacity;
    m_full = false;

    m_popFrame.byteSize = chunkSize;
    // chunkSize > 0 implies m_channels and m_pointByteSize are non-zero (see constructor).
    m_popFrame.frameSize = m_popFrame.byteSize / m_channels / m_pointByteSize;
    return &m_popFrame;
}

/// Number of bytes currently stored in the ring buffer. m_full disambiguates the
/// head == tail case (full vs. empty). push() and pop() call this with m_mutex held.
uint32_t PcmStreamBuffer::availableByteSize() const {
    if (m_full) return m_capacity;
    return (m_tail >= m_head) ? m_tail - m_head : m_capacity + m_tail - m_head;
}