
Model Training Overview

The Semantic Router relies on multiple specialized classification models to make intelligent routing decisions. This section provides a comprehensive overview of the training process, datasets used, and the purpose of each model in the routing pipeline.

Training Architecture Overview

The Semantic Router uses ModernBERT as the foundation model for several classification tasks. Each task has its own fine-tuning run and classification head, trained for a specific purpose in the routing pipeline:

graph TB
    subgraph "Training Pipeline"
        Dataset1[MMLU-Pro Dataset<br/>Academic Domains] --> CategoryTraining[Category Classification<br/>Training]
        Dataset2[Microsoft Presidio<br/>PII Dataset] --> PIITraining[PII Detection<br/>Training]
        Dataset3[Jailbreak Classification<br/>Dataset] --> JailbreakTraining[Jailbreak Detection<br/>Training]
        Dataset4[Glaive Function Calling<br/>Dataset] --> IntentTraining[Intent Classification<br/>Training]
    end

    subgraph "Model Architecture"
        ModernBERT[ModernBERT Base<br/>Shared Backbone]

        CategoryTraining --> CategoryHead[Category Classifier<br/>10 Classes]
        PIITraining --> PIIHead[PII Token Classifier<br/>6 Entity Types]
        JailbreakTraining --> JailbreakHead[Jailbreak Binary Classifier<br/>2 Classes]
        IntentTraining --> IntentHead[Intent Classifier<br/>8 Function Categories]

        CategoryHead --> ModernBERT
        PIIHead --> ModernBERT  
        JailbreakHead --> ModernBERT
        IntentHead --> ModernBERT
    end

    subgraph "Deployment"
        ModernBERT --> SemanticRouter[Semantic Router<br/>Production System]
    end

Why ModernBERT?

Technical Advantages

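ModernBERT is a recent redesign of the BERT encoder architecture, with several improvements over the original BERT models:

1. Enhanced Architecture

  • Rotary Position Embeddings (RoPE): Encode position by rotating query and key vectors, which handles long sequences better than learned absolute positions
  • GeGLU Activation: A gated feed-forward activation that improves gradient flow and representation capacity
  • Bias Removal: Bias terms are dropped from most linear and normalization layers, spending the parameter budget elsewhere
  • Pre-normalization: Layer normalization is applied before each sublayer for better training stability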

2. Training Improvements

  • Longer Context: Supports sequences of up to 8,192 tokens, versus 512 for the original BERT
  • Better Data: Trained on a larger, more recent corpus that includes code
  • Improved Tokenization: A modern BPE tokenizer with a more efficient vocabulary
  • Anti-overfitting Techniques: Built-in regularization improvements

3. Performance Benefits

# Performance comparison on classification tasks
model_performance = {
    "bert-base": {
        "accuracy": 89.2,
        "inference_speed": "100ms",  # per-query latency
        "memory_usage": "400MB"
    },
    "modernbert-base": {
        "accuracy": 92.7,           # +3.5 percentage points
        "inference_speed": "85ms",  # 15% lower latency
        "memory_usage": "380MB"     # 5% less memory
    }
}
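The figures above can be compared in two ways: as an absolute difference (percentage points, for accuracy) and as a change relative to the baseline. The minimal sketch below computes both from data shaped like `model_performance`; the helper name `compare_models` and the conversion of the string values to plain numbers are illustrative assumptions.

```python
def compare_models(baseline: dict, candidate: dict) -> dict:
    """Return absolute and relative differences for each shared numeric metric."""
    report = {}
    for metric, base_value in baseline.items():
        new_value = candidate[metric]
        report[metric] = {
            "absolute": round(new_value - base_value, 4),
            # Relative change as a percentage of the baseline value
            "relative_pct": round(100.0 * (new_value - base_value) / base_value, 2),
        }
    return report


# Same figures as model_performance, with units stripped to plain numbers
bert_base = {"accuracy": 89.2, "latency_ms": 100.0, "memory_mb": 400.0}
modernbert_base = {"accuracy": 92.7, "latency_ms": 85.0, "memory_mb": 380.0}

for metric, delta in compare_models(bert_base, modernbert_base).items():
    print(metric, delta)
```

Accuracy rises by 3.5 percentage points, which is roughly a 3.9% relative gain, while latency and memory fall by 15% and 5% relative to BERT.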

Why Not GPT-based Models?

| Aspect | ModernBERT | GPT-3.5/4 |
|---|---|---|
| Latency | ~20ms | ~200-500ms |
| Cost | $0.0001/query | $0.002-0.03/query |
| Specialization | Fine-tuned for classification | General purpose |
| Consistency | Deterministic outputs | Variable outputs |
| Deployment | Self-hosted | API dependency |
| Context Understanding | Bidirectional | Left-to-right |
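To put the per-query prices from the Cost row into monthly terms, the sketch below multiplies them by a query volume. The volume of one million queries per day and the 30-day month are assumptions for illustration, and the prices are the table's figures, not current vendor pricing.

```python
def monthly_cost(price_per_query: float, queries_per_day: int, days: int = 30) -> float:
    """Cost in dollars for a month of traffic at a flat per-query price."""
    return price_per_query * queries_per_day * days


QUERIES_PER_DAY = 1_000_000  # assumed traffic level

modernbert = monthly_cost(0.0001, QUERIES_PER_DAY)
gpt_low = monthly_cost(0.002, QUERIES_PER_DAY)
gpt_high = monthly_cost(0.03, QUERIES_PER_DAY)

print(f"ModernBERT: ${modernbert:,.0f}/month")
print(f"GPT-3.5/4:  ${gpt_low:,.0f} to ${gpt_high:,.0f}/month")
```

At that volume the table's prices work out to about $3,000 per month against $60,000 to $900,000 per month, a 20x to 300x gap.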

Training Methodology

Unified Fine-tuning Framework

Our training approach uses a unified fine-tuning framework that applies consistent methodologies across all classification tasks:

Anti-Overfitting Strategy

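from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfig:
    epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float
    warmup_ratio: float
    eval_strategy: str                # "epoch" or "steps"
    early_stopping_patience: int
    eval_steps: Optional[int] = None  # only used when eval_strategy == "steps"


# Adaptive training configuration based on dataset size:
# smaller datasets get fewer epochs, a lower learning rate and stronger weight decay
def get_training_config(dataset_size):
    if dataset_size < 1000:
        return TrainingConfig(
            epochs=2,
            batch_size=4,
            learning_rate=1e-5,
            weight_decay=0.15,
            warmup_ratio=0.1,
            eval_strategy="epoch",
            early_stopping_patience=1
        )
    elif dataset_size < 5000:
        return TrainingConfig(
            epochs=3,
            batch_size=8,
            learning_rate=2e-5,
            weight_decay=0.1,
            warmup_ratio=0.06,
            eval_strategy="steps",
            eval_steps=100,
            early_stopping_patience=2
        )
    else:
        return TrainingConfig(
            epochs=4,
            batch_size=16,
            learning_rate=3e-5,
            weight_decay=0.05,
            warmup_ratio=0.03,
            eval_strategy="steps",
            eval_steps=200,
            early_stopping_patience=3
        )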
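The `warmup_ratio` values above are fractions of the total number of optimizer steps, so the absolute warmup length depends on dataset size, batch size and epochs. The sketch below computes those counts for a single device without gradient accumulation, discarding the last incomplete batch as the `dataloader_drop_last=True` setting in `UnifiedBERTFinetuning` does. Rounding the warmup up matches how Hugging Face `TrainingArguments` converts a ratio to steps in the releases we know of, but treat that detail, and the `schedule_for` helper itself, as assumptions.

```python
import math
from dataclasses import dataclass


@dataclass
class StepSchedule:
    steps_per_epoch: int
    total_steps: int
    warmup_steps: int


def schedule_for(dataset_size: int, epochs: int, batch_size: int,
                 warmup_ratio: float, drop_last: bool = True) -> StepSchedule:
    """Optimizer-step counts for one device with no gradient accumulation."""
    if drop_last:
        steps_per_epoch = dataset_size // batch_size  # incomplete final batch discarded
    else:
        steps_per_epoch = math.ceil(dataset_size / batch_size)
    total_steps = steps_per_epoch * epochs
    # Warmup expressed as a fraction of all training steps, rounded up
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return StepSchedule(steps_per_epoch, total_steps, warmup_steps)


# Small-dataset tier: 2 epochs, batch size 4, warmup_ratio 0.1
print(schedule_for(dataset_size=800, epochs=2, batch_size=4, warmup_ratio=0.1))
# Large-dataset tier: 4 epochs, batch size 16, warmup_ratio 0.03
print(schedule_for(dataset_size=20_000, epochs=4, batch_size=16, warmup_ratio=0.03))
```

With 800 examples in the small-data tier this gives 200 steps per epoch, 400 steps in total and 40 warmup steps; 20,000 examples in the large tier give 1,250, 5,000 and 150.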

Training Pipeline Implementation

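import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)


class UnifiedBERTFinetuning:
    def __init__(self, model_name="modernbert-base", task_type="classification"):
        self.model_name = model_name  # local path or Hub ID, e.g. "answerdotai/ModernBERT-base"
        self.task_type = task_type
        self.model = None
        self.tokenizer = None
        self.trainer = None

    def train_model(self, dataset, config):
        # 1. Load the pre-trained tokenizer and model with a fresh classification head
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=len(dataset.label_names),
            id2label=dict(enumerate(dataset.label_names)),
            label2id={name: i for i, name in enumerate(dataset.label_names)},
            problem_type="single_label_classification"
        )
        run_name = self.model_name.split("/")[-1]  # keep Hub IDs out of directory names

        # Evaluate and checkpoint on the same schedule; load_best_model_at_end requires it
        schedule = {"eval_strategy": config.eval_strategy, "save_strategy": config.eval_strategy}
        if config.eval_strategy == "steps":
            schedule.update(eval_steps=config.eval_steps, save_steps=config.eval_steps)

        # 2. Set up training arguments with anti-overfitting measures
        training_args = TrainingArguments(
            output_dir=f"./models/{self.task_type}_classifier_{run_name}_model",
            num_train_epochs=config.epochs,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            warmup_ratio=config.warmup_ratio,

            # Evaluation and early stopping (older transformers releases
            # call eval_strategy "evaluation_strategy")
            **schedule,
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,

            # Memory and speed
            fp16=True,  # mixed precision training (requires a CUDA GPU)
            gradient_checkpointing=True,  # trade extra compute for lower activation memory
            dataloader_drop_last=True,

            # Logging
            logging_dir=f"./logs/{self.task_type}_{run_name}",
            logging_steps=50,
            report_to="tensorboard"
        )

        # 3. Set up the trainer with custom metrics and early stopping
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset.train_dataset,
            eval_dataset=dataset.eval_dataset,
            tokenizer=self.tokenizer,  # newer releases name this argument processing_class
            data_collator=DataCollatorWithPadding(self.tokenizer),
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=config.early_stopping_patience)]
        )

        # 4. Train the model
        self.trainer.train()

        # 5. Save model and evaluation results
        self.save_trained_model(self.trainer)

        return self.trainer

    def save_trained_model(self, trainer):
        trainer.save_model()  # writes the best model and tokenizer to args.output_dir
        trainer.save_metrics("eval", trainer.evaluate())

    def evaluate_model(self, test_dataset):
        # Strip the "test_" prefix so callers can read metrics["f1"]
        metrics = self.trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        return {key.removeprefix("test_"): value for key, value in metrics.items()}

    def compute_metrics(self, eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)

        return {
            'accuracy': accuracy_score(labels, predictions),
            'f1': f1_score(labels, predictions, average='weighted'),
            'precision': precision_score(labels, predictions, average='weighted', zero_division=0),
            'recall': recall_score(labels, predictions, average='weighted', zero_division=0)
        }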
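`compute_metrics` reports scikit-learn's `average='weighted'` scores, which average the per-class values weighted by each class's support (its number of true examples). The dependency-free sketch below makes that weighting explicit for F1; it is illustrative only and assumes integer class labels. Because large classes dominate the average, a weak minority class (such as the jailbreak class later in this section) can sit behind a high weighted score, so per-class scores are worth reading alongside it.

```python
from collections import Counter
from typing import Sequence


def weighted_f1(labels: Sequence[int], predictions: Sequence[int]) -> float:
    """Support-weighted mean of per-class F1 scores."""
    support = Counter(labels)
    total = len(labels)
    score = 0.0
    for cls, n_true in support.items():
        tp = sum(1 for y, p in zip(labels, predictions) if y == cls and p == cls)
        fp = sum(1 for y, p in zip(labels, predictions) if y != cls and p == cls)
        fn = n_true - tp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += f1 * n_true / total  # weight each class by its support
    return score


labels = [0, 0, 1, 1, 2]
predictions = [0, 1, 1, 1, 2]
print(round(weighted_f1(labels, predictions), 4))
```

This prints 0.7867: per-class F1 scores of 0.667, 0.8 and 1.0 weighted by supports of 2, 2 and 1, the same value `f1_score(labels, predictions, average='weighted')` returns for these inputs.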

Model Specifications

1. Category Classification Model

Purpose: Route queries to specialized models based on academic/professional domains.

Dataset: MMLU-Pro Academic Domains

# Dataset composition
mmlu_categories = {
    "mathematics": {
        "samples": 1547,
        "subcategories": ["algebra", "calculus", "geometry", "statistics"],
        "example": "Solve the integral of x^2 from 0 to 1"
    },
    "physics": {
        "samples": 1231, 
        "subcategories": ["mechanics", "thermodynamics", "electromagnetism"],
        "example": "Calculate the force needed to accelerate a 10kg mass at 5m/s^2"
    },
    "computer_science": {
        "samples": 1156,
        "subcategories": ["algorithms", "data_structures", "programming"],
        "example": "Implement a binary search algorithm in Python"
    },
    "biology": {
        "samples": 1089,
        "subcategories": ["genetics", "ecology", "anatomy"],
        "example": "Explain the process of photosynthesis in plants"
    },
    "chemistry": {
        "samples": 1034,
        "subcategories": ["organic", "inorganic", "physical"],
        "example": "Balance the chemical equation: H2 + O2 → H2O"
    },
    # ... additional categories
}

Training Configuration

model_config:
  base_model: "modernbert-base"
  task_type: "sequence_classification" 
  num_labels: 10

training_config:
  epochs: 3
  batch_size: 8
  learning_rate: 2e-5
  weight_decay: 0.1

evaluation_metrics:
  - accuracy: 94.2%
  - f1_weighted: 93.8%
  - per_category_precision: ">90% for all categories"

Model Performance

category_performance = {
    "overall_accuracy": 0.942,
    "per_category_results": {
        "mathematics": {"precision": 0.956, "recall": 0.943, "f1": 0.949},
        "physics": {"precision": 0.934, "recall": 0.928, "f1": 0.931},
        "computer_science": {"precision": 0.948, "recall": 0.952, "f1": 0.950},
        "biology": {"precision": 0.925, "recall": 0.918, "f1": 0.921},
        "chemistry": {"precision": 0.941, "recall": 0.935, "f1": 0.938}
    },
    "confusion_matrix_insights": {
        "most_confused": "physics <-> mathematics (12% cross-classification)",
        "best_separated": "biology <-> computer_science (2% cross-classification)"
    }
}
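Each per-category F1 above is the harmonic mean of its precision and recall, F1 = 2PR / (P + R); for mathematics, 2 × 0.956 × 0.943 / (0.956 + 0.943) ≈ 0.949. A quick consistency check over the reported values can be sketched as follows (the `f1_from` name is illustrative). The tolerance of 0.001 allows for the inputs themselves being rounded to three decimals.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


per_category = {
    "mathematics": {"precision": 0.956, "recall": 0.943, "f1": 0.949},
    "physics": {"precision": 0.934, "recall": 0.928, "f1": 0.931},
    "computer_science": {"precision": 0.948, "recall": 0.952, "f1": 0.950},
    "biology": {"precision": 0.925, "recall": 0.918, "f1": 0.921},
    "chemistry": {"precision": 0.941, "recall": 0.935, "f1": 0.938},
}

for name, m in per_category.items():
    recomputed = f1_from(m["precision"], m["recall"])
    # Reported values are rounded to three decimals, so allow a small tolerance
    status = "ok" if abs(recomputed - m["f1"]) < 0.001 else "mismatch"
    print(f"{name}: reported {m['f1']:.3f}, recomputed {recomputed:.4f} ({status})")
```

All five categories pass, and the PII entity scores in the next model section are consistent in the same way.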

2. PII Detection Model

Purpose: Identify personally identifiable information to protect user privacy.

Dataset: Microsoft Presidio + Custom Synthetic Data

# PII entity types and examples
pii_entities = {
    "PERSON": {
        "count": 15420,
        "examples": ["John Smith", "Dr. Sarah Johnson", "Ms. Emily Chen"],
        "patterns": ["First Last", "Title First Last", "First Middle Last"]
    },
    "EMAIL_ADDRESS": {
        "count": 8934,
        "examples": ["user@domain.com", "john.doe@company.org"],
        "patterns": ["Local@Domain", "FirstLast@Company"]
    },
    "PHONE_NUMBER": {
        "count": 7234,
        "examples": ["(555) 123-4567", "+1-800-555-0123", "555.123.4567"],
        "patterns": ["US format", "International", "Dotted"]
    },
    "US_SSN": {
        "count": 5123,
        "examples": ["123-45-6789", "123456789"],
        "patterns": ["XXX-XX-XXXX", "XXXXXXXXX"]
    },
    "LOCATION": {
        "count": 6789,
        "examples": ["123 Main St, New York, NY", "San Francisco, CA"],
        "patterns": ["Street Address", "City, State", "Geographic locations"]
    },
    "NO_PII": {
        "count": 45678,
        "examples": ["The weather is nice today", "Please help me with coding"],
        "description": "Text containing no personal information"
    }
}

Training Approach: Token Classification

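from transformers import AutoModelForTokenClassification, AutoTokenizer

# BIO label set: "O" covers NO_PII text; each entity type gets B- (begin) and I- (inside) tags
ENTITY_TYPES = [name for name in pii_entities if name != "NO_PII"]  # 5 entity types
BIO_LABELS = ["O"] + [f"{prefix}-{name}" for name in ENTITY_TYPES for prefix in ("B", "I")]


class PIITokenClassifier:
    def __init__(self, model_name="modernbert-base"):
        # A fast tokenizer is required for word_ids()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name,
            num_labels=len(BIO_LABELS),  # 1 + 2 x 5 = 11 labels
            id2label=dict(enumerate(BIO_LABELS)),
            label2id={label: i for i, label in enumerate(BIO_LABELS)}
        )

    @staticmethod
    def align_labels_with_tokens(labels, word_ids):
        # Give every sub-word token its word's label; special tokens get -100 (ignored by the loss)
        aligned, previous_word = [], None
        for word_id in word_ids:
            if word_id is None:
                aligned.append(-100)
            else:
                label = labels[word_id]
                if word_id == previous_word and BIO_LABELS[label].startswith("B-"):
                    label += 1  # continuation of a word: B-X becomes I-X (next index)
                aligned.append(label)
            previous_word = word_id
        return aligned

    def preprocess_data(self, examples):
        # examples["ner_tags"] holds word-level BIO tag ids (indices into BIO_LABELS)
        tokenized_inputs = self.tokenizer(
            examples["tokens"],
            truncation=True,
            is_split_into_words=True
        )

        # Align word-level labels with sub-word tokens
        labels = []
        for i, label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            labels.append(self.align_labels_with_tokens(label, word_ids))

        tokenized_inputs["labels"] = labels
        return tokenized_inputs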
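`preprocess_data` expects word-level tags in `examples["ner_tags"]`. If the raw annotations are character spans instead, as Presidio analyzer results are (each carries an `entity_type`, `start` and `end`), they first need converting to BIO tags. The sketch below does this for whitespace-separated words; the `spans_to_bio` name, the `(start, end, label)` tuple format and the whitespace word splitting are assumptions, and a word that only partly overlaps a span still receives that span's label.

```python
import re
from typing import List, Sequence, Tuple

Span = Tuple[int, int, str]  # (start_char, end_char_exclusive, entity_label)


def spans_to_bio(text: str, spans: Sequence[Span]) -> Tuple[List[str], List[str]]:
    """Split text on whitespace and tag each word B-<label>, I-<label> or O."""
    words, tags = [], []
    opened = set()  # indices of spans whose first word has already been tagged
    for match in re.finditer(r"\S+", text):
        word_start, word_end = match.span()
        tag = "O"
        for i, (span_start, span_end, label) in enumerate(spans):
            if word_start < span_end and word_end > span_start:  # word overlaps span
                tag = ("I-" if i in opened else "B-") + label
                opened.add(i)
                break
        words.append(match.group())
        tags.append(tag)
    return words, tags


text = "Call John Smith at 555-123-4567 today"
name_start = text.index("John Smith")
phone_start = text.index("555-123-4567")
spans = [
    (name_start, name_start + len("John Smith"), "PERSON"),
    (phone_start, phone_start + len("555-123-4567"), "PHONE_NUMBER"),
]
words, tags = spans_to_bio(text, spans)
print(list(zip(words, tags)))
```

Mapping each tag string to its id in the model's `label2id` then yields the `ner_tags` column.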

Performance Metrics

pii_performance = {
    "overall_f1": 0.957,
    "entity_level_performance": {
        "PERSON": {"precision": 0.961, "recall": 0.954, "f1": 0.957},
        "EMAIL_ADDRESS": {"precision": 0.989, "recall": 0.985, "f1": 0.987},
        "PHONE_NUMBER": {"precision": 0.978, "recall": 0.972, "f1": 0.975},
        "US_SSN": {"precision": 0.995, "recall": 0.991, "f1": 0.993},
        "LOCATION": {"precision": 0.943, "recall": 0.938, "f1": 0.940},
        "NO_PII": {"precision": 0.967, "recall": 0.971, "f1": 0.969}
    },
    "false_positive_analysis": {
        "common_errors": "Business names confused with person names",
        "mitigation": "Post-processing with business entity recognition"
    }
}

3. Jailbreak Detection Model

Purpose: Identify and block attempts to circumvent AI safety measures.

Dataset: Jailbreak Classification Dataset

jailbreak_dataset = {
    "benign": {
        "count": 25000,
        "examples": [
            "Please help me write a professional email",
            "Can you explain quantum computing?",
            "I need help with my math homework"
        ],
        "characteristics": "Normal, helpful requests"
    },
    "jailbreak": {
        "count": 8000,
        "examples": [
            # Descriptions of prompt styles; raw examples are withheld from documentation
            "DAN (Do Anything Now) style prompts",
            "Role-playing to bypass restrictions", 
            "Hypothetical scenario circumvention"
        ],
        "characteristics": "Attempts to bypass AI safety measures",
        "categories": ["role_playing", "hypothetical", "character_injection", "system_override"]
    }
}

Training Strategy

import torch
from transformers import AutoModelForSequenceClassification


class JailbreakDetector:
    def __init__(self, model_name="modernbert-base"):
        # Binary classification with class imbalance handling
        self.num_labels = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=self.num_labels,
            id2label={0: "benign", 1: "jailbreak"},
            label2id={"benign": 0, "jailbreak": 1}
        )

        # Handle class imbalance with weighted loss: 25,000 / 8,000 = 3.125
        self.class_weights = torch.tensor([1.0, 3.125])

    def compute_loss(self, outputs, labels):
        logits = outputs.logits
        # Move the weights to the logits' device and dtype (e.g. GPU, fp16)
        weights = self.class_weights.to(device=logits.device, dtype=logits.dtype)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weights)
        return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
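The weight 3.125 is the majority-to-minority ratio (25,000 / 8,000). The dependency-free sketch below derives such weights from class counts and reproduces what `torch.nn.CrossEntropyLoss(weight=...)` computes with its default `reduction='mean'`: a weighted sum of per-example losses divided by the summed weights of the target classes. The helper names are illustrative.

```python
import math
from typing import List, Sequence


def majority_ratio_weights(counts: Sequence[int]) -> List[float]:
    """Weight each class by (largest class count) / (its count)."""
    largest = max(counts)
    return [largest / c for c in counts]


def weighted_cross_entropy(logits: Sequence[Sequence[float]],
                           targets: Sequence[int],
                           weights: Sequence[float]) -> float:
    """Mean weighted cross-entropy, normalized by the total target weight."""
    total_loss, total_weight = 0.0, 0.0
    for row, y in zip(logits, targets):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_norm - row[y]  # negative log-softmax of the target class
        total_loss += weights[y] * nll
        total_weight += weights[y]
    return total_loss / total_weight


weights = majority_ratio_weights([25_000, 8_000])  # [benign, jailbreak]
print(weights)
loss = weighted_cross_entropy([[2.0, 0.0], [0.0, 0.0]], [0, 1], weights)
print(round(loss, 4))
```

Each jailbreak example contributes 3.125 times as much to the loss as a benign one, which pushes the decision boundary toward flagging more prompts.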

Performance Analysis

jailbreak_performance = {
    "overall_metrics": {
        "accuracy": 0.967,
        "precision": 0.923,  # Lower due to conservative approach
        "recall": 0.891,     # Prioritize catching jailbreaks
        "f1": 0.907,
        "auc_roc": 0.984
    },
    "confusion_matrix": {
        "true_negatives": 4750,  # Correctly identified benign
        "false_positives": 250,  # Benign flagged as jailbreak (acceptable)
        "false_negatives": 87,   # Missed jailbreaks (concerning)
        "true_positives": 713    # Correctly caught jailbreaks
    },
    "business_impact": {
        "false_positive_rate": "5% - Users may experience occasional blocking",
        "false_negative_rate": "10.9% - Some jailbreaks may pass through",
        "tuning_strategy": "Bias toward false positives for safety"
    }
}
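Standard metrics can be derived directly from the four confusion-matrix counts, with "jailbreak" as the positive class; the `metrics_from_confusion` name is illustrative. Applied to the counts above, the sketch reproduces the reported recall (0.891), false positive rate (5%) and false negative rate (10.9%), but yields a precision of about 0.740, an accuracy of about 0.942 and an F1 of about 0.809, rather than the 0.923, 0.967 and 0.907 listed under `overall_metrics`, so the two blocks do not appear to describe the same evaluation run.

```python
def metrics_from_confusion(tn: int, fp: int, fn: int, tp: int) -> dict:
    """Binary-classification metrics with 'jailbreak' as the positive class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)  # share of jailbreaks caught
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
        "false_positive_rate": fp / (fp + tn),  # benign prompts blocked
        "false_negative_rate": fn / (fn + tp),  # jailbreaks missed
    }


counts = {"tn": 4750, "fp": 250, "fn": 87, "tp": 713}
for name, value in metrics_from_confusion(**counts).items():
    print(f"{name}: {value:.3f}")
```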

4. Intent Classification Model

Purpose: Classify queries for tool selection and function calling optimization.

Dataset: Glaive Function Calling v2

intent_categories = {
    "information_retrieval": {
        "count": 18250,
        "examples": ["What's the weather like?", "Search for recent news about AI"],
        "tools": ["web_search", "weather_api", "knowledge_base"]
    },
    "data_transformation": {
        "count": 8340,
        "examples": ["Convert this JSON to CSV", "Format this text"],
        "tools": ["format_converter", "data_processor"]
    },
    "calculation": {
        "count": 12150,
        "examples": ["Calculate compound interest", "Solve this equation"],
        "tools": ["calculator", "math_solver", "statistics"]
    },
    "communication": {
        "count": 6420,
        "examples": ["Send an email to John", "Post this to Slack"],
        "tools": ["email_client", "messaging_apis"]
    },
    "scheduling": {
        "count": 4680,
        "examples": ["Book a meeting for tomorrow", "Set a reminder"],
        "tools": ["calendar_api", "reminder_system"]
    },
    "file_operations": {
        "count": 7890,
        "examples": ["Read this document", "Save data to file"],
        "tools": ["file_reader", "file_writer", "cloud_storage"]
    },
    "analysis": {
        "count": 5420,
        "examples": ["Analyze this dataset", "Summarize the document"],
        "tools": ["data_analyzer", "text_summarizer"]
    },
    "no_function_needed": {
        "count": 15230,
        "examples": ["Tell me a joke", "Explain quantum physics"],
        "tools": []  # No external tools needed
    }
}
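Downstream, the predicted intent can narrow the set of tools offered to the model. The sketch below is one hypothetical way to do that, using the `tools` lists from a subset of `intent_categories`; the `select_tools` name, the 0.5 confidence threshold and the fallback to the full tool list on low confidence are illustrative assumptions, not documented router behavior.

```python
from typing import Dict, List

# Tool lists copied from a subset of intent_categories above
INTENT_TOOLS: Dict[str, List[str]] = {
    "information_retrieval": ["web_search", "weather_api", "knowledge_base"],
    "calculation": ["calculator", "math_solver", "statistics"],
    "scheduling": ["calendar_api", "reminder_system"],
    "no_function_needed": [],
}


def select_tools(intent: str, confidence: float, threshold: float = 0.5) -> List[str]:
    """Return the tools for a confident intent prediction, else every known tool."""
    if confidence >= threshold and intent in INTENT_TOOLS:
        return INTENT_TOOLS[intent]
    # Low confidence or unknown intent: do not narrow the tool set
    return sorted({tool for tools in INTENT_TOOLS.values() for tool in tools})


print(select_tools("calculation", 0.92))
print(select_tools("no_function_needed", 0.88))
print(len(select_tools("calculation", 0.31)))
```

Falling back to every tool when the classifier is unsure costs some prompt length but avoids hiding a tool the query actually needs.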

Training Infrastructure

Hardware Requirements

training_infrastructure:
  gpu_requirements:
    minimum: "NVIDIA RTX 3080 (10GB VRAM)"
    recommended: "NVIDIA A100 (40GB VRAM)"

  memory_requirements:
    system_ram: "32GB minimum, 64GB recommended"
    storage: "500GB SSD for datasets and models"

  training_time_estimates:
    category_classifier: "2-4 hours on RTX 3080"
    pii_detector: "4-6 hours on RTX 3080"
    jailbreak_guard: "1-2 hours on RTX 3080" 
    intent_classifier: "3-5 hours on RTX 3080"

Training Pipeline Automation

class TrainingPipeline:
    def __init__(self, config_path):
        # load_config (not shown) should map each model type to a TrainingConfig
        self.config = self.load_config(config_path)
        self.models_to_train = ["category", "pii", "jailbreak", "intent"]

    def run_full_pipeline(self):
        results = {}

        for model_type in self.models_to_train:
            print(f"Training {model_type} classifier...")

            # 1. Load and preprocess data (load_dataset, not shown, returns
            #    train/eval/test splits plus label_names)
            dataset = self.load_dataset(model_type)

            # 2. Initialize the fine-tuning wrapper
            finetuner = UnifiedBERTFinetuning(
                model_name="modernbert-base",
                task_type=model_type
            )

            # 3. Train model (returns the underlying Hugging Face Trainer)
            result = finetuner.train_model(dataset, self.config[model_type])

            # 4. Evaluate performance on the held-out test split
            evaluation = finetuner.evaluate_model(dataset.test_dataset)

            # 5. Save results
            results[model_type] = {
                "training_result": result,
                "evaluation_metrics": evaluation
            }

            print(f"{model_type} training completed. F1: {evaluation['f1']:.3f}")

        return results

This training approach keeps the methodology consistent across all models while optimizing each one for its specific role in the Semantic Router. The next sections cover the Classification Models and Datasets in more depth.