// Copyright (C) 2025 Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_BPE_TOKENIZER_H
#define DLIB_BPE_TOKENIZER_H

#include <iostream>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
#include <cstdint>
#include <future>
#include <mutex>
#include <thread>
#include <map>
#include <unordered_map>
#include <queue>
#include "../base64.h"
#include "../serialize.h"
#include "bpe_tokenizer_abstract.h"

namespace dlib
{
    constexpr size_t BPE_TOKENIZER_MAX_TOKEN_LENGTH = 8;
    constexpr int BPE_TOKENIZER_BASE_VOCAB_SIZE = 256;

    class bpe_tokenizer
    {
    public:
        bpe_tokenizer() : vocab_size(BPE_TOKENIZER_BASE_VOCAB_SIZE)
        {
            // Initialize the base vocabulary with single bytes
            for (int i = 0; i < BPE_TOKENIZER_BASE_VOCAB_SIZE; ++i)
                vocab[i] = std::vector<uint8_t>{ static_cast<uint8_t>(i) };

            // Initialize special tokens with sequential IDs
            special_tokens = {
                {"<text>",     BPE_TOKENIZER_BASE_VOCAB_SIZE},
                {"</text>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 1},
                {"<url>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 2},
                {"</url>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 3},
                {"<image>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 4},
                {"</image>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 5},
                {"<video>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 6},
                {"</video>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 7},
                {"<audio>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 8},
                {"</audio>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 9},
                {"<file>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 10},
                {"</file>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 11},
                {"<code>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 12},
                {"</code>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 13},
                {"<summary>",  BPE_TOKENIZER_BASE_VOCAB_SIZE + 14},
                {"</summary>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 15},
                {"<think>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 16},
                {"</think>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 17},
                {"<start>",    BPE_TOKENIZER_BASE_VOCAB_SIZE + 18},
                {"<end>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 19},
                {"<user>",     BPE_TOKENIZER_BASE_VOCAB_SIZE + 20},
                {"<bot>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 21},
                {"<system>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 22},
                {"<question>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 23},
                {"<answer>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 24},
                {"<search>",   BPE_TOKENIZER_BASE_VOCAB_SIZE + 25},
                {"<unk>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 26},
                {"<pad>",      BPE_TOKENIZER_BASE_VOCAB_SIZE + 27}
            };

            // Build the reverse map from special token IDs to their string form
            for (const auto& token : special_tokens)
                special_token_map[token.second] = token.first;
        }

        // Train the tokenizer on the given text
        void train(const std::string& text, int vocab_size, bool verbose = false)
        {
            DLIB_CASSERT(vocab_size >= BPE_TOKENIZER_BASE_VOCAB_SIZE);
            this->vocab_size = vocab_size;
            int num_merges = vocab_size - BPE_TOKENIZER_BASE_VOCAB_SIZE;

            // Convert text to byte IDs
            std::vector<int> ids;
            for (char c : text) ids.push_back(static_cast<uint8_t>(c));

            // Perform BPE merges
            for (int i = 0; i < num_merges; ++i)
            {
                auto stats = get_stats(ids);
                if (stats.empty()) break;

                // Find the most frequent pair that does not exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
                auto pair = get_most_frequent_pair(stats);

                // Check if the resulting token would exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
                size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH)
                {
                    if (verbose)
                    {
                        std::cout << "\r" << std::setw(100) << std::flush
                            << "\rskipping merge " << std::to_string(i + 1) << "/" << std::to_string(num_merges)
                            << ": (" << std::to_string(pair.first) << "," << std::to_string(pair.second)
                            << ") -> new token length " << std::to_string(new_token_length)
                            << " exceeds limit of " <<
                            std::to_string(BPE_TOKENIZER_MAX_TOKEN_LENGTH) << std::flush;
                    }
                    continue;  // Skip this merge
                }

                // New token IDs start after the byte values and the special tokens
                int idx = (BPE_TOKENIZER_BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
                ids = merge(ids, pair, idx);
                merges[pair] = idx;
                vocab[idx].insert(vocab[idx].end(), vocab[pair.first].begin(), vocab[pair.first].end());
                vocab[idx].insert(vocab[idx].end(), vocab[pair.second].begin(), vocab[pair.second].end());

                if (verbose)
                {
                    std::cout << "\r" << std::setw(100) << std::flush
                        << "\rmerge " << std::to_string(i + 1) << "/" << std::to_string(num_merges)
                        << ": (" << std::to_string(pair.first) << "," << std::to_string(pair.second)
                        << ") -> " << std::to_string(idx) << " (" << bytes_to_string(vocab[idx])
                        << ") had " << std::to_string(stats[pair]) << " occurrences" << std::endl;
                }
            }
        }

        // Encode the given text into subword tokens
        std::vector<int> encode(const std::string& text) const
        {
            std::vector<int> result_ids;
            std::mutex result_mutex;

            // Split the text into paragraphs based on newline characters
            std::vector<std::string> paragraphs;
            size_t start = 0, end = text.find('\n');
            while (end != std::string::npos)
            {
                std::string paragraph = text.substr(start, end - start);
                if (!paragraph.empty()) paragraphs.push_back(paragraph);
                start = end + 1;
                end = text.find('\n', start);
            }
            // Add the last paragraph (if any) and only if it's not empty
            if (start < text.size())
            {
                std::string paragraph = text.substr(start);
                if (!paragraph.empty()) paragraphs.push_back(paragraph);
            }

            // Function to encode a single paragraph
            auto encode_paragraph = [this](const std::string& paragraph) -> std::vector<int>
            {
                std::vector<int> ids;
                ids.reserve(paragraph.size());
                for (char c : paragraph) ids.push_back(static_cast<uint8_t>(c));

                // Seed the priority queue with every adjacent pair that has a learned merge
                auto stats = get_stats(ids);
                std::priority_queue<std::pair<int, std::pair<int, int>>> pq;
                for (const auto& stat : stats)
                {
                    const std::pair<int, int>& pair = stat.first;
                    if (merges.count(pair)) pq.push({ merges.at(pair), pair });
                }

                while (!pq.empty())
                {
                    // Copy the pair out before popping, otherwise the reference would dangle
                    const std::pair<int, int> pair = pq.top().second;
                    pq.pop();

                    // Make sure the pair still occurs in the current sequence
                    bool pair_found = false;
                    for (size_t i = 0; i < ids.size() - 1; ++i)
                    {
                        if (ids[i] == pair.first && ids[i + 1] == pair.second)
                        {
                            pair_found = true;
                            break;
                        }
                    }
                    if (!pair_found) continue;

                    int idx = merges.at(pair);
                    ids = merge(ids, pair, idx);

                    // Re-scan the merged sequence and queue any newly created mergeable pairs
                    stats = get_stats(ids);
                    for (const auto& stat : stats)
                    {
                        const std::pair<int, int>& new_pair = stat.first;
                        if (merges.count(new_pair)) pq.push({ merges.at(new_pair), new_pair });
                    }
                }
                return ids;
            };

            // Special case: if there's only one paragraph, no need for threads
            int sot_tok = get_special_token_id("<text>");
            int eot_tok = get_special_token_id("</text>");
            if (paragraphs.size() == 1)
            {
                std::vector<int> paragraph_ids = encode_paragraph(paragraphs[0]);
                result_ids.push_back(sot_tok);
                result_ids.insert(result_ids.end(), paragraph_ids.begin(), paragraph_ids.end());
                result_ids.push_back(eot_tok);
                return result_ids;
            }

            // Launch encoding tasks in parallel for multiple paragraphs
            std::vector<std::future<std::vector<int>>> futures;
            for (const auto& paragraph : paragraphs)
                futures.push_back(std::async(std::launch::async, encode_paragraph, paragraph));

            // Collect results in order
            for (auto& future : futures)
            {
                std::vector<int> paragraph_ids = future.get();
                std::lock_guard<std::mutex> lock(result_mutex);
                result_ids.push_back(sot_tok);
                result_ids.insert(result_ids.end(), paragraph_ids.begin(), paragraph_ids.end());
                result_ids.push_back(eot_tok);
            }

            return result_ids;
        }

        // Decode a single token ID back into text
        std::string decode(int id, bool
            display_special_tokens = true) const
        {
            return decode(std::vector<int>({ id }), display_special_tokens);
        }

        // Decode a sequence of token IDs back into text
        std::string decode(const std::vector<int>& ids, bool display_special_tokens = true) const
        {
            std::vector<uint8_t> bytes;
            int vocab_size = static_cast<int>(get_vocab_size());
            for (int id : ids)
            {
                if (id < vocab_size)
                {
                    // Check if the ID is a special token
                    auto it = special_token_map.find(id);
                    if (it != special_token_map.end())
                    {
                        // It's a special token, get the corresponding string
                        if (display_special_tokens)
                            bytes.insert(bytes.end(), it->second.begin(), it->second.end());
                    }
                    else
                    {
                        // It's a regular token, get the bytes from the vocabulary
                        auto& token = vocab.at(id);
                        bytes.insert(bytes.end(), token.begin(), token.end());
                    }
                }
            }
            return std::string(bytes.begin(), bytes.end());
        }

        // Serialize the tokenizer model and vocabulary to the given stream
        friend void serialize(const bpe_tokenizer& tok, std::ostream& out)
        {
            serialize("bpe_tokenizer2_", out);
            serialize(tok.special_tokens, out);
            serialize(tok.special_token_map, out);
            serialize(tok.merges, out);
            serialize(tok.vocab, out);
            serialize(tok.vocab_size, out);
        }

        // Deserialize the tokenizer model and vocabulary from the given stream
        friend void deserialize(bpe_tokenizer& tok, std::istream& in)
        {
            std::string version;
            dlib::deserialize(version, in);
            if (version != "bpe_tokenizer2_")
                throw dlib::serialization_error("Unexpected version '" + version +
                    "' found while deserializing dlib::bpe_tokenizer.");
            deserialize(tok.special_tokens, in);
            deserialize(tok.special_token_map, in);
            deserialize(tok.merges, in);
            deserialize(tok.vocab, in);
            deserialize(tok.vocab_size, in);
        }

        // Get the ID of a special token
        int get_special_token_id(const std::string& token) const
        {
            auto it = special_tokens.find(token);
            if (it != special_tokens.end()) return it->second;
            throw std::runtime_error("Special token not found: " + token);
        }

        // Get the total vocabulary size
        size_t get_vocab_size() const { return (vocab.size() + special_tokens.size()); }

    private:
        std::map<std::string, int> special_tokens;
        std::unordered_map<int, std::string> special_token_map;
        std::map<std::pair<int, int>, int> merges;
        std::map<int, std::vector<uint8_t>> vocab;
        int vocab_size;

        // Hash functor so a pair of token IDs can be used as an unordered_map key
        struct pair_hash
        {
            template <class T1, class T2>
            std::size_t operator()(const std::pair<T1, T2>& p) const
            {
                auto hash1 = std::hash<T1>{}(p.first);
                auto hash2 = std::hash<T2>{}(p.second);
                return hash1 ^ (hash2 << 1);
            }
        };

        // Get frequency statistics of adjacent token pairs (counted in parallel)
        std::unordered_map<std::pair<int, int>, int, pair_hash> get_stats(const std::vector<int>& ids) const
        {
            std::unordered_map<std::pair<int, int>, int, pair_hash> global_stats;
            std::mutex global_stats_mutex;

            // Each worker counts pairs in its own segment, then folds into the global map
            auto worker = [&](size_t start, size_t end)
            {
                std::unordered_map<std::pair<int, int>, int, pair_hash> local_stats;
                for (size_t i = start; i < end - 1 && i + 1 < ids.size(); ++i)
                    local_stats[{ids[i], ids[i + 1]}]++;
                std::lock_guard<std::mutex> lock(global_stats_mutex);
                for (const auto& pair : local_stats)
                    global_stats[pair.first] += pair.second;
            };

            size_t num_threads = std::thread::hardware_concurrency();
            if (num_threads == 0) num_threads = 1;  // hardware_concurrency() may report 0
            size_t segment_size = ids.size() / num_threads;
            std::vector<std::thread> threads;
            for (size_t t = 0; t < num_threads; ++t)
            {
                size_t start = t * segment_size;
                size_t end = (t == num_threads - 1) ?
                    ids.size() : start + segment_size;
                threads.emplace_back(worker, start, end);
            }
            for (auto& thread : threads) thread.join();

            return global_stats;
        }

        // Find the best-scoring pair of tokens in the given statistics map whose merge would not exceed the maximum token length
        std::pair<int, int> get_most_frequent_pair(const std::unordered_map<std::pair<int, int>, int, pair_hash>& stats) const
        {
            std::pair<int, int> best_pair = { -1, -1 };  // Initialize the best pair to an invalid value
            double max_score = 0;                        // Initialize the maximum score to 0

            // Iterate over all pairs in the statistics map
            for (const auto& stat : stats)
            {
                const std::pair<int, int>& pair = stat.first;  // Extract the token pair
                int count = stat.second;                       // Extract the frequency count

                // Check if the new token formed by merging the pair would exceed the maximum allowed length
                size_t new_token_length = vocab.at(pair.first).size() + vocab.at(pair.second).size();
                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH)
                    continue;  // Skip this pair if it exceeds the maximum token length

                // Calculate the score for this pair (frequency, boosted for longer tokens)
                double score = static_cast<double>(count) * (new_token_length > (BPE_TOKENIZER_MAX_TOKEN_LENGTH / 2) ? 1.75 : 1.0);

                // Update the best pair if the current pair has a higher score
                if (score > max_score)
                {
                    best_pair = pair;
                    max_score = score;
                }
            }

            return best_pair;  // Return the pair with the highest score
        }

        // Merge every occurrence of the given pair in the token sequence into a single token
        std::vector<int> merge(std::vector<int>& ids, const std::pair<int, int>& pair, int idx) const
        {
            std::vector<int> new_ids;
            new_ids.reserve(ids.size());  // Reserve space to avoid reallocations
            for (size_t i = 0; i < ids.size(); ++i)
            {
                if (i < ids.size() - 1 && ids[i] == pair.first && ids[i + 1] == pair.second)
                {
                    new_ids.push_back(idx);  // Replace the pair with the new token ID
                    i++;                     // Skip the next token
                }
                else
                    new_ids.push_back(ids[i]);  // Keep the current token
            }
            return new_ids;
        }

        // Base64-encode a raw byte string (used only for verbose training output)
        static std::string base64_encode(const std::string& input)
        {
            dlib::base64 encoder;
            std::istringstream sin(input);
            std::ostringstream sout;
            encoder.encode(sin, sout);
            return sout.str();
        }

        // Convert a sequence of bytes to a readable string
        static std::string bytes_to_string(const std::vector<uint8_t>& bytes)
        {
            std::string data(bytes.begin(), bytes.end());
            return base64_encode(data);
        }
    };
}

#endif // DLIB_BPE_TOKENIZER_H
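
/*
    Example usage: an illustrative sketch of training, encoding, decoding, and saving a
    tokenizer with the class above. This is not taken from the dlib documentation; the
    `corpus_text` variable and the "bpe_tokenizer.dat" file name are placeholders.

        std::string corpus_text = ...;                     // some training text
        dlib::bpe_tokenizer tok;
        tok.train(corpus_text, 2000, true);                // learn merges up to a 2000-entry vocabulary
        std::vector<int> ids = tok.encode("Hello\nWorld"); // each paragraph is wrapped in <text>...</text>
        std::string text = tok.decode(ids);                // pass false as the 2nd argument to drop special tokens
        dlib::serialize("bpe_tokenizer.dat") << tok;       // persist the trained model
        dlib::deserialize("bpe_tokenizer.dat") >> tok;     // and load it back
*/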