// Copyright (C) 2025 Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_BPE_TOKENIZER_H
#define DLIB_BPE_TOKENIZER_H

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <future>
#include <iomanip>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../base64.h"
#include "../serialize.h"
#include "bpe_tokenizer_abstract.h"

namespace dlib
{
    // Upper bound, in bytes, on the length of any token produced by a merge.
    constexpr size_t BPE_TOKENIZER_MAX_TOKEN_LENGTH{ 8 };
    // One seed token per possible byte value (0..255).
    constexpr int BPE_TOKENIZER_BASE_VOCAB_SIZE{ 1 << 8 };

    class bpe_tokenizer
    {
    public:
        bpe_tokenizer() : vocab_size(BPE_TOKENIZER_BASE_VOCAB_SIZE)
        {
            // Initialize the base vocabulary with single bytes
            for (int i = 0; i < BPE_TOKENIZER_BASE_VOCAB_SIZE; ++i)
                vocab[i] = std::vector<uint8_t>{ static_cast<uint8_t>(i) };
            
            // Initialize special tokens with sequential IDs
            special_tokens =
            {
                {"<text>",      BPE_TOKENIZER_BASE_VOCAB_SIZE},
                {"</text>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 1},
                {"<url>",       BPE_TOKENIZER_BASE_VOCAB_SIZE + 2},
                {"</url>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 3},
                {"<image>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 4},
                {"</image>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 5},
                {"<video>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 6},
                {"</video>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 7},
                {"<audio>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 8},
                {"</audio>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 9},
                {"<file>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 10},
                {"</file>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 11},
                {"<code>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 12},
                {"</code>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 13},
                {"<summary>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 14},
                {"</summary>",  BPE_TOKENIZER_BASE_VOCAB_SIZE + 15},
                {"<think>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 16},
                {"</think>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 17},
                {"<start>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 18},
                {"<end>",       BPE_TOKENIZER_BASE_VOCAB_SIZE + 19},
                {"<user>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 20},
                {"<bot>",       BPE_TOKENIZER_BASE_VOCAB_SIZE + 21},
                {"<system>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 22},
                {"<question>",  BPE_TOKENIZER_BASE_VOCAB_SIZE + 23},
                {"<answer>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 24},
                {"<search>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 25},
                {"<unk>",       BPE_TOKENIZER_BASE_VOCAB_SIZE + 26},
                {"<pad>",       BPE_TOKENIZER_BASE_VOCAB_SIZE + 27}
            };

            // Initialize the vector of special token IDs
            for (const auto& token : special_tokens)
                special_token_map[token.second] = token.first;
        }

        // Train the tokenizer on the given text
        void train(const std::string& text, int vocab_size, bool verbose = false)
        {
            DLIB_CASSERT(vocab_size >= BPE_TOKENIZER_BASE_VOCAB_SIZE);
            this->vocab_size = vocab_size;
            int num_merges = vocab_size - BPE_TOKENIZER_BASE_VOCAB_SIZE;

            // Convert text to byte IDs
            std::vector<int> ids;
            for (char c : text) ids.push_back(static_cast<uint8_t>(c));

            // Perform BPE merges
            for (int i = 0; i < num_merges; ++i) {
                auto stats = get_stats(ids);
                if (stats.empty()) break;

                // Find the most frequent pair that does not exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
                auto pair = get_most_frequent_pair(stats);

                // Check if the resulting token would exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
                size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) {
                    if (verbose)
                    {
                        std::cout << "\r"
                            << std::setw(100) << std::flush
                            << "\rskipping merge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
                            << std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> new token length "
                            << std::to_string(new_token_length) << " exceeds limit of " << std::to_string(BPE_TOKENIZER_MAX_TOKEN_LENGTH)
                            << std::flush;
                    }
                    continue; // Skip this merge
                }

                int idx = (BPE_TOKENIZER_BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
                ids = merge(ids, pair, idx);
                merges[pair] = idx;
                vocab[idx].insert(vocab[idx].end(), vocab[pair.first].begin(), vocab[pair.first].end());
                vocab[idx].insert(vocab[idx].end(), vocab[pair.second].begin(), vocab[pair.second].end());

                if (verbose)
                {
                    std::cout << "\r"
                        << std::setw(100) << std::flush
                        << "\rmerge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
                        << std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> " << std::to_string(idx)
                        << " (" << bytes_to_string(vocab[idx]) << ") had "
                        << std::to_string(stats[pair]) << " occurrences"
                        << std::endl;
                }
            }
        }

        // Encode the given text into subword tokens
        std::vector<int> encode(const std::string& text) const
        {
            std::vector<int> result_ids;
            std::mutex result_mutex;

            // Split the text into paragraphs based on newline characters
            std::vector<std::string> paragraphs;
            size_t start = 0, end = text.find('\n');
            while (end != std::string::npos) {
                std::string paragraph = text.substr(start, end - start);
                if (!paragraph.empty()) paragraphs.push_back(paragraph);
                start = end + 1;
                end = text.find('\n', start);
            }
            // Add the last paragraph (if any) and only if it's not empty
            if (start < text.size()) {
                std::string paragraph = text.substr(start);
                if (!paragraph.empty()) paragraphs.push_back(paragraph);
            }

            // Function to encode a single paragraph
            auto encode_paragraph = [this](const std::string& paragraph) -> std::vector<int> {
                std::vector<int> ids;
                ids.reserve(paragraph.size());
                for (char c : paragraph) ids.push_back(static_cast<uint8_t>(c));

                auto stats = get_stats(ids);
                std::priority_queue<std::pair<int, std::pair<int, int>>> pq;
                for (const auto& stat : stats) {
                    const std::pair<int, int>& pair = stat.first;
                    if (merges.count(pair)) pq.push({ merges.at(pair), pair });
                }

                while (!pq.empty()) {
                    const auto& top_element = pq.top();
                    const std::pair<int, int>& pair = top_element.second;
                    pq.pop();

                    bool pair_found = false;
                    for (size_t i = 0; i < ids.size() - 1; ++i) {
                        if (ids[i] == pair.first && ids[i + 1] == pair.second) {
                            pair_found = true;
                            break;
                        }
                    }
                    if (!pair_found) continue;

                    int idx = merges.at(pair);
                    ids = merge(ids, pair, idx);

                    stats = get_stats(ids);
                    for (const auto& stat : stats) {
                        const std::pair<int, int>& new_pair = stat.first;
                        if (merges.count(new_pair)) pq.push({ merges.at(new_pair), new_pair });
                    }
                }

                return ids;
            };

            // Special case: if there's only one paragraph, no need for threads
            int sot_tok = get_special_token_id("<text>");
            int eot_tok = get_special_token_id("</text>");
            if (paragraphs.size() == 1) {
                std::vector<int> paragraph_ids = encode_paragraph(paragraphs[0]);
                result_ids.push_back(sot_tok);
                result_ids.insert(result_ids.end(), paragraph_ids.begin(), paragraph_ids.end());
                result_ids.push_back(eot_tok);
                return result_ids;
            }

            // Launch encoding tasks in parallel for multiple paragraphs
            std::vector<std::future<std::vector<int>>> futures;
            for (const auto& paragraph : paragraphs)
                futures.push_back(std::async(std::launch::async, encode_paragraph, paragraph));

            // Collect results in order
            for (auto& future : futures) {
                std::vector<int> paragraph_ids = future.get();
                std::lock_guard<std::mutex> lock(result_mutex);
                result_ids.push_back(sot_tok);
                result_ids.insert(result_ids.end(), paragraph_ids.begin(), paragraph_ids.end());
                result_ids.push_back(eot_tok);
            }
            return result_ids;
        }

        // Decode a single token ID back into text
        std::string decode(int id, bool display_special_tokens = true) const
        {
            return decode(std::vector<int>({ id }), display_special_tokens);
        }

        // Decode a sequence of token IDs back into text
        std::string decode(const std::vector<int>& ids, bool display_special_tokens = true) const
        {
            std::vector<uint8_t> bytes;
            int vocab_size = static_cast<int>(get_vocab_size());
            for (int id : ids)
            {
                if (id < vocab_size)
                {
                    // Check if the ID is a special token
                    auto it = special_token_map.find(id);
                    if (it != special_token_map.end())
                    {
                        // It's a special token, get the corresponding string
                        if (display_special_tokens) bytes.insert(bytes.end(), it->second.begin(), it->second.end());
                    }
                    else
                    {
                        // It's a regular token, get the bytes from the vocabulary
                        auto& token = vocab.at(id);
                        bytes.insert(bytes.end(), token.begin(), token.end());
                    }
                }
            }
            return std::string(bytes.begin(), bytes.end());
        }

        // Save the tokenizer model and vocabulary to file
        friend void serialize(const bpe_tokenizer& tok, std::ostream& out)
        {
            serialize("bpe_tokenizer2_", out);
            serialize(tok.special_tokens, out);
            serialize(tok.special_token_map, out);
            serialize(tok.merges, out);
            serialize(tok.vocab, out);
            serialize(tok.vocab_size, out);
        }

        // Load the tokenizer model and vocabulary from file
        friend void deserialize(bpe_tokenizer& tok, std::istream& in) {
            std::string version;
            dlib::deserialize(version, in);
            if (version != "bpe_tokenizer2_")
                throw dlib::serialization_error("Unexpected version '" + version + "' found while deserializing dlib::bpe_tokenizer_.");
            deserialize(tok.special_tokens, in);
            deserialize(tok.special_token_map, in);
            deserialize(tok.merges, in);
            deserialize(tok.vocab, in);
            deserialize(tok.vocab_size, in);
        }

        // Get the ID of a special token
        int get_special_token_id(const std::string& token) const
        {
            auto it = special_tokens.find(token);
            if (it != special_tokens.end()) return it->second;
            throw std::runtime_error("Special token not found: " + token);
        }

        // Get the total vocabulary size
        size_t get_vocab_size() const
        {
            return (vocab.size() + special_tokens.size());
        }

    private:
        std::map<std::string, int> special_tokens;
        std::unordered_map<int, std::string> special_token_map;
        std::map<std::pair<int, int>, int> merges;
        std::map<int, std::vector<uint8_t>> vocab;
        int vocab_size;

        // Get frequency statistics of adjacent token pairs
        struct pair_hash {
            template <class T1, class T2>
            std::size_t operator()(const std::pair<T1, T2>& p) const
            {
                auto hash1 = std::hash<T1>{}(p.first);
                auto hash2 = std::hash<T2>{}(p.second);
                return hash1 ^ (hash2 << 1);
            }
        };
        std::unordered_map<std::pair<int, int>, int, pair_hash> get_stats(const std::vector<int>& ids) const
        {
            std::unordered_map<std::pair<int, int>, int, pair_hash> global_stats;
            std::mutex global_stats_mutex;

            auto worker = [&](size_t start, size_t end) {
                std::unordered_map<std::pair<int, int>, int, pair_hash> local_stats;
                for (size_t i = start; i < end - 1 && i + 1 < ids.size(); ++i)
                    local_stats[{ids[i], ids[i + 1]}]++;

                std::lock_guard<std::mutex> lock(global_stats_mutex);
                for (const auto& pair : local_stats)
                    global_stats[pair.first] += pair.second;
            };

            size_t num_threads = std::thread::hardware_concurrency();
            size_t segment_size = ids.size() / num_threads;
            std::vector<std::thread> threads;

            for (size_t t = 0; t < num_threads; ++t)
            {
                size_t start = t * segment_size;
                size_t end = (t == num_threads - 1) ? ids.size() : start + segment_size;
                threads.emplace_back(worker, start, end);
            }

            for (auto& thread : threads) thread.join();

            return global_stats;
        }

        // Finds the most frequent pair of tokens in the given statistics map that does not exceed the maximum token length
        std::pair<int, int> get_most_frequent_pair(const std::unordered_map<std::pair<int, int>, int, pair_hash>& stats) const 
        {
            std::pair<int, int> best_pair = { -1, -1 }; // Initialize the best pair to an invalid value
            double max_score = 0; // Initialize the maximum score to 0

            // Iterate over all pairs in the statistics map
            for (const auto& stat : stats) {
                const std::pair<int, int>& pair = stat.first; // Extract the token pair
                int count = stat.second; // Extract the frequency count

                // Check if the new token formed by merging the pair would exceed the maximum allowed length
                size_t new_token_length = vocab.at(pair.first).size() + vocab.at(pair.second).size();
                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length

                // Calculate the score for this pair (frequency * length_penalty)
                double score = (size_t)count * (new_token_length > (BPE_TOKENIZER_MAX_TOKEN_LENGTH / 2) ? 1.75 : 1.0);

                // Update the best pair if the current pair has a higher score
                if (score > max_score)
                {
                    best_pair = pair;
                    max_score = score;
                }
            }

            return best_pair; // Return the pair with the highest score
        }

        // Merge the most frequent pair in the token sequence
        std::vector<int> merge(std::vector<int>& ids, const std::pair<int, int>& pair, int idx) const
        {
            std::vector<int> new_ids;
            new_ids.reserve(ids.size()); // Reserve space to avoid reallocations

            for (size_t i = 0; i < ids.size(); ++i)
            {
                if (i < ids.size() - 1 && ids[i] == pair.first && ids[i + 1] == pair.second)
                {
                    new_ids.push_back(idx); // Replace the pair with the new token ID
                    i++; // Skip the next token
                }
                else new_ids.push_back(ids[i]); // Keep the current token
            }

            return new_ids;
        }

        static std::string base64_encode(const std::string& input) {
            dlib::base64 encoder;
            std::istringstream sin(input);
            std::ostringstream sout;
            encoder.encode(sin, sout);
            return sout.str();
        }

        // Convert a sequence of bytes to a readable string
        static std::string bytes_to_string(const std::vector<uint8_t>& bytes)
        {
            std::string data(bytes.begin(), bytes.end());
            return base64_encode(data);
        }

    };

}


#endif // DLIB_BPE_TOKENIZER_H