#include "DataStructure.h"

#include <cstddef>
#include <memory>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/kdf.h>
#include <openssl/params.h>
#include <openssl/rand.h>

namespace Older {
int PBKDF2_ITERATIONS = 100000;
constexpr int SALT_LENGTH = 16;
constexpr int HASH_LENGTH = 32;

Account Account::hashPassword(const std::string &password) {
    Account ret;

    // 生成随机盐
    ret.salt.resize(SALT_LENGTH);
    if (RAND_bytes(ret.salt.data(), SALT_LENGTH) != 1) {
        throw std::runtime_error("Salt generation failed");
    }

    // PBKDF2 派生
    EVP_KDF *kdf = EVP_KDF_fetch(nullptr, "PBKDF2", nullptr);
    EVP_KDF_CTX *ctx = EVP_KDF_CTX_new(kdf);

    OSSL_PARAM params[] = {OSSL_PARAM_construct_utf8_string("digest", "SHA256", 0),
                           OSSL_PARAM_construct_octet_string("salt", ret.salt.data(), ret.salt.size()),
                           OSSL_PARAM_construct_octet_string("pass", (void *)password.data(), password.size()),
                           OSSL_PARAM_construct_int("iter", &PBKDF2_ITERATIONS), OSSL_PARAM_construct_end()};

    ret.passwordHash.resize(HASH_LENGTH);
    if (EVP_KDF_derive(ctx, ret.passwordHash.data(), HASH_LENGTH, params) != 1) {
        EVP_KDF_CTX_free(ctx);
        EVP_KDF_free(kdf);
        throw std::runtime_error("PBKDF2 failed");
    }

    EVP_KDF_CTX_free(ctx);
    EVP_KDF_free(kdf);
    return ret;
}

bool Account::verifyPassword(const Account &account, const std::string &password) {
    // 重新计算哈希
    EVP_KDF *kdf = EVP_KDF_fetch(nullptr, "PBKDF2", nullptr);
    EVP_KDF_CTX *ctx = EVP_KDF_CTX_new(kdf);

    OSSL_PARAM params[5] = {OSSL_PARAM_construct_utf8_string("digest", const_cast<char *>("SHA256"), 0),
                            OSSL_PARAM_construct_octet_string("salt", (void *)account.salt.data(), account.salt.size()),
                            OSSL_PARAM_construct_octet_string("pass", const_cast<char *>(password.data()), password.size()),
                            OSSL_PARAM_construct_int("iter", &PBKDF2_ITERATIONS), OSSL_PARAM_construct_end()};

    std::vector<unsigned char> new_hash(HASH_LENGTH);
    if (EVP_KDF_derive(ctx, new_hash.data(), HASH_LENGTH, params) != 1) {
        EVP_KDF_CTX_free(ctx);
        EVP_KDF_free(kdf);
        return false;
    }

    EVP_KDF_CTX_free(ctx);
    EVP_KDF_free(kdf);

    // 安全比较(防时序攻击)
    return CRYPTO_memcmp(account.passwordHash.data(), new_hash.data(), HASH_LENGTH) == 0;
}

bool Account::validateEmail(const std::string &email) {
    const std::regex pattern(R"([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})");
    return std::regex_match(email, pattern);
}
} // namespace Older