#ifndef SlmNet_H
#define SlmNet_H
/**
* @file slm_defs.h
* @brief Optimized Transformer neural architecture for language processing
*
* Implements a Transformer architecture with multi-head attention and RMS
* normalization, designed for efficient learning and inference. All positions
* of a sequence are processed in parallel, while causal self-attention selects
* the relevant context for each prediction.
*
* Key features:
* - RMS normalization for enhanced stability
* - Residual connections around the attention and feed-forward sub-layers
* - Causal masking for autoregressive attention
*/
#include <dlib/dnn.h>
#include <sstream>
#include <string>
namespace transformer
{
using namespace dlib;
// Scale Weights Layer
template <long d_k_>
class scale_weights_ : public multiply_ {
public:
explicit scale_weights_() : multiply_(1.0f / std::sqrt(static_cast<float>(d_k_))) {}
};
template <long d_k, typename SUBNET>
using scale_weights = add_layer<scale_weights_<d_k>, SUBNET>;
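// Illustration (not part of the architecture itself): with the default
// d_model = 128 and num_heads = 8, each head has d_k = 16, so
// scale_weights<16, SUBNET> multiplies the attention scores by
// 1.0f / sqrt(16) = 0.25f.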
namespace def {
// Slice the fused QKV projection produced below (fc_no_bias<d_model * 3>)
// into per-head query, key and value tensors. The key slice is laid out
// transposed so that the following matrix multiplication yields Q*K^T.
template <long num_heads, long d_model, typename SUBNET>
using query = extract<0, num_heads, d_model / num_heads, 1, SUBNET>;
template <long num_heads, long d_model, typename SUBNET>
using key = extract<d_model, num_heads, 1, d_model / num_heads, SUBNET>;
template <long num_heads, long d_model, typename SUBNET>
using value = extract<(d_model * 2), num_heads, d_model / num_heads, 1, SUBNET>;
/**
* Multi-Head Attention Layer
*
* Structure:
* 1. Input processing
* - RMS normalization
* - Single linear projection (d_model -> 3*d_model) for Q,K,V
* 2. Parallel head processing (num_heads)
* - Split into Q, K, V tensors
* - Key transposition for attention computation
* 3. Attention mechanism
* - Scaled dot-product (Q*K^T / sqrt(d_k))
* - Causal masking (tril_mask)
* - Softmax normalization
* - Value weighting
* 4. Output
* - Head concatenation
* - Residual connection
*
* Template parameters:
* @param ACT: Activation function type (not used by this layer; kept for a uniform interface)
* @param DO: Dropout layer type
* @param d_model: Model dimension
* @param num_heads: Number of attention heads
* @param SUBNET: Input subnet type
*/
template <template <typename> class ACT, template <typename> class DO,
long d_model, long num_heads, typename SUBNET>
using multihead_attention = add_prev1<DO<extract<0, 1, 1, d_model, multm_prev3<
DO<softmaxm<tril_mask<
scale_weights<d_model / num_heads,
multm_prev4<query<num_heads, d_model, skip2<
tag4<key<num_heads, d_model, skip2<
tag3<value<num_heads, d_model,
tag2<fc_no_bias<d_model * 3, rms_norm<
tag1<SUBNET>>>>>>>>>>>>>>>>>>>>;
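/**
* Illustrative forward-pass walk-through (a sketch only, assuming the default
* configuration d_model = 128, num_heads = 8, hence d_k = 16):
*
* tag1 -> rms_norm -> fc_no_bias<384> : fused Q|K|V projection (3 x 128)
* query / key / value                 : per-head slices of width d_k = 16
*                                       (key laid out transposed)
* multm_prev4                         : Q*K^T attention scores
* scale_weights<16, ...>              : scores scaled by 1/sqrt(16) = 0.25
* tril_mask -> softmaxm -> dropout    : causal masking, row-wise softmax
* multm_prev3                         : weights applied to the value slices
* extract<0, 1, 1, 128, ...>          : heads recombined to d_model = 128
* dropout -> add_prev1                : residual connection to the block input (tag1)
*/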
/**
* Feed-Forward Network Layer
*
* Structure:
* 1. Input processing
* - RMS normalization
* - Input tagged for residual connection
* 2. Transformation
* - Expansion layer (d_model -> 4*d_model)
* - Activation function
* - Projection layer (4*d_model -> d_model)
* 3. Output
* - Dropout
* - Residual connection
*
* Template parameters:
* @param ACT: Activation function type
* @param DO: Dropout layer type
* @param d_model: Model dimension
* @param SUBNET: Input subnet type
*/
template <template <typename> class ACT, template <typename> class DO, long d_model, typename SUBNET>
using feed_forward =
add_prev5<
DO<extract<0, 1, 1, d_model,
fc<d_model, ACT<fc<d_model * 4, rms_norm<
tag5<SUBNET>>>>>>>>;
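/**
* Illustrative expansion for the default d_model = 128 (a sketch only):
* tag5 -> rms_norm -> fc<512> -> ACT -> fc<128> -> dropout -> add_prev5,
* i.e. the usual 4x inner expansion followed by a projection back to
* d_model and a residual connection to the tagged input.
*/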
/**
* Transformer Block
*
* Combines sequentially:
* 1. Multi-head attention layer
* 2. Feed-forward network
*
* Template parameters:
* @param ACT: Activation function type
* @param DO: Dropout layer type
* @param seq_len: Maximum sequence length (not used by the current definition)
* @param d_model: Model dimension
* @param num_heads: Number of attention heads
* @param SUBNET: Input subnet type
*/
template <template <typename> class ACT, template <typename> class DO, long seq_len, long d_model, long num_heads, typename SUBNET>
using transformer_block =
feed_forward<ACT, DO, d_model,
multihead_attention<ACT, DO, d_model, num_heads, SUBNET>>;
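/**
* A minimal usage sketch (hypothetical alias name my_block): blocks are
* stacked with dlib's repeat layer, exactly as transformer_config does below.
*
* template <typename SUBNET>
* using my_block = def::transformer_block<gelu, dropout_10, 100, 128, 8, SUBNET>;
* // repeat<6, my_block, ...> then chains six such blocks.
*/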
}
// Positional Embeddings
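// Maps integer token IDs to embedding_length-dimensional vectors and adds
// positional encodings, producing the input representation consumed by the
// first transformer block.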
template <long num_embeddings, long embedding_length, typename SUBNET>
using positional_embeddings = positional_encodings<embeddings<num_embeddings, embedding_length, SUBNET>>;
// Classification Head
template <template <typename> class ACT, long embedding_length, typename SUBNET>
using squeezing = fc<embedding_length / 4, ACT<fc<embedding_length / 8, SUBNET>>>;
template <bool USE_SQUEEZING, template <typename> class ACT, long num_logits, long embedding_length, typename SUBNET>
struct classification_head_impl;
template <template <typename> class ACT, long num_logits, long embedding_length, typename SUBNET>
struct classification_head_impl<true, ACT, num_logits, embedding_length, SUBNET>
{
using type = loss_multiclass_log<fc<num_logits, squeezing<ACT, embedding_length, rms_norm<SUBNET>>>>;
};
template <template <typename> class ACT, long num_logits, long embedding_length, typename SUBNET>
struct classification_head_impl<false, ACT, num_logits, embedding_length, SUBNET>
{
using type = loss_multiclass_log<fc<num_logits, rms_norm<SUBNET>>>;
};
template <bool USE_SQUEEZING, template <typename> class ACT, long num_logits, long embedding_length, typename SUBNET>
using classification_head = typename classification_head_impl<USE_SQUEEZING, ACT, num_logits, embedding_length, SUBNET>::type;
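/**
* The boolean selects one of the two specializations above at compile time.
* For instance, with the defaults (a sketch; SUBNET_TYPE is a placeholder):
*
* using head = classification_head<true, gelu, 5000, 128, SUBNET_TYPE>;
* // expands to loss_multiclass_log<fc<5000, fc<32, gelu<fc<16, rms_norm<SUBNET_TYPE>>>>>>
*/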
/**
* @brief Transformer Model Configuration Template
*
* Provides a flexible and type-safe configuration mechanism for Transformer models
* with compile-time parameter validation and network generation.
*
* Template parameters:
* @param vocab_size Vocabulary size for token embedding
* @param num_layers Number of Transformer layers
* @param num_heads Number of attention heads
* @param embedding_dim Dimension of token embeddings
* @param max_seq_len Maximum sequence length
* @param use_squeezing Add the squeezing layers to the classification head
* @param activation_func Activation function type
* @param dropout_policy Dropout regularization policy
*/
template <
long vocab_size = 5000, // Default vocabulary size
long num_layers = 6, // Default number of layers
long num_heads = 8, // Default number of attention heads
long embedding_dim = 128, // Default embedding dimension
long max_seq_len = 100, // Default maximum sequence length
bool use_squeezing = false, // Default use squeezing layer
template <typename> class activation_func = gelu, // Default activation function
template <typename> class dropout_policy = dropout_10 // Default dropout policy
>
struct transformer_config {
// Core model parameters
static constexpr long VOCAB_SIZE = vocab_size;
static constexpr long NUM_LAYERS = num_layers;
static constexpr long NUM_HEADS = num_heads;
static constexpr long EMBEDDING_DIM = embedding_dim;
static constexpr long MAX_SEQ_LEN = max_seq_len;
static constexpr bool USE_SQUEEZING = use_squeezing;
/**
* @brief Compile-time validation of model configuration
*
* Performs static assertions to ensure valid model parameters
*/
struct validation {
static_assert(VOCAB_SIZE > 0, "Vocabulary size must be positive");
static_assert(NUM_LAYERS > 0, "Number of layers must be positive");
static_assert(NUM_HEADS > 0, "Number of attention heads must be positive");
static_assert(EMBEDDING_DIM % NUM_HEADS == 0, "Embedding dimension must be divisible by number of heads");
};
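/**
* For example, a configuration with EMBEDDING_DIM = 128 and NUM_HEADS = 7
* violates the divisibility check (128 % 7 != 0), whereas the defaults give
* 128 / 8 = 16 dimensions per attention head and pass all assertions.
*/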
/**
* @brief Network type generation based on training/inference mode
*
* The training block uses the configured dropout policy, while the inference
* block replaces dropout with a plain multiply layer; network_type selects
* between the two variants at compile time.
*
* Template parameters:
* @tparam is_training Determines training or inference network type
*/
template <typename SUBNET>
using t_transformer_block = def::transformer_block<activation_func, dropout_policy, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS, SUBNET>;
template <typename SUBNET>
using i_transformer_block = def::transformer_block<activation_func, multiply, MAX_SEQ_LEN, EMBEDDING_DIM, NUM_HEADS, SUBNET>;
template<bool is_training>
using network_type = std::conditional_t<is_training,
classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
repeat<NUM_LAYERS, t_transformer_block,
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
repeat<NUM_LAYERS, i_transformer_block,
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
>;
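/**
* A sketch of the intended hand-over between the two variants, assuming
* dlib's standard dropout-to-multiply layer conversion on assignment:
*
* network_type<true>  training_net;                  // dropout active
* // ... train ...
* network_type<false> inference_net = training_net;  // dropout becomes multiply
*/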
/**
* @brief Model configuration information and debugging utility
*
* Provides methods to generate human-readable model configuration details
*/
struct model_info {
/**
* @brief Generate a detailed description of the model configuration
*
* @return String containing model configuration details
*/
static std::string describe() {
std::stringstream ss;
ss << "Transformer model configuration:\n"
<< "- vocabulary size: " << VOCAB_SIZE << "\n"
<< "- layers: " << NUM_LAYERS << "\n"
<< "- attention heads: " << NUM_HEADS << "\n"
<< "- embedding dimension: " << EMBEDDING_DIM << "\n"
<< "- max sequence length: " << MAX_SEQ_LEN;
return ss.str();
}
};
};
using vslm = transformer_config<>; // Very Small Language Model
/**
* @example Configuration and Usage Examples
*
* // Creating different transformer configurations
* using default_transformer = transformer_config<>;
* using large_transformer_with_squeezing = transformer_config<
* 50000, // Larger vocabulary
* 8, // More layers
* 8, // Attention heads (same as the default)
* 512, // Larger embedding dimension
* 128, // Longer sequences
* true // Use squeezing
* >;
*
* // Network type instantiations for different modes
* using train_network = default_transformer::network_type<true>;
* using inference_network = default_transformer::network_type<false>;
*
* // Utility function to print model configuration
* void print_model_info() {
* std::cout << default_transformer::model_info::describe() << std::endl;
* }
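*
* // A minimal training sketch (illustrative only; assumes the caller has
* // already tokenized the corpus into matrix<int, 0, 1> column vectors and
* // next-token labels of type unsigned long):
* void train_example(const std::vector<matrix<int, 0, 1>>& samples,
*                    const std::vector<unsigned long>& labels) {
*     train_network net;
*     dnn_trainer<train_network, adam> trainer(net, adam(0.0005, 0.9, 0.999));
*     trainer.set_learning_rate(1e-4);
*     trainer.set_mini_batch_size(32);
*     trainer.train(samples, labels);
*     net.clean();
*     serialize("vslm_model.dat") << net;
* }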
*
* @note
* - Supports compile-time configuration
* - Provides static validation of model parameters
* - Enables dynamic network type generation
* - Includes a model_info utility for human-readable configuration summaries
*
* @author Cydral
* @site https://github.com/Cydral/ERNIE
* @version 1.0
* @date 11/2024
*/
}
#endif // SlmNet_H