Skip to content

Router Implementation Details

This document provides detailed insights into the core routing algorithms, classification logic, and implementation specifics of the Semantic Router.

Classification Pipeline

Multi-Stage Classification Architecture

The Semantic Router employs a multi-stage classification pipeline that combines several specialized models:

```mermaid
graph TB
    Query[User Query] --> Preprocessor[Text Preprocessing<br/>Tokenization & Cleaning]

    Preprocessor --> ParallelClassification{Parallel Classification}

    ParallelClassification --> CategoryClassifier[Category Classifier<br/>ModernBERT - 10 categories]
    ParallelClassification --> PIIDetector[PII Detector<br/>ModernBERT - Token Classification]
    ParallelClassification --> JailbreakGuard[Jailbreak Guard<br/>ModernBERT - Binary Classification]

    CategoryClassifier --> RoutingEngine[Routing Decision Engine]
    PIIDetector --> SecurityGate[Security Gate]
    JailbreakGuard --> SecurityGate

    SecurityGate -->|Pass| RoutingEngine
    SecurityGate -->|Block| SecurityResponse[Security Block Response]

    RoutingEngine --> ModelSelection[Model Selection<br/>Based on Category + Confidence]
    ModelSelection --> ToolsSelector[Tools Auto-Selection]
    ToolsSelector --> FinalDecision[Final Routing Decision]
```

Implementation Details

Category Classification Logic

// CategoryClassifier assigns a query to a category label using a ModernBERT
// classification model (see ClassifyIntent).
type CategoryClassifier struct {
    // model returns logits for a tokenized query; ClassifyIntent softmaxes them.
    model           *ModernBERTModel
    // tokenizer turns the raw query text into model input.
    tokenizer       *ModernBERTTokenizer
    // labelMapping maps the index of the most probable output to its category name.
    labelMapping    map[int]string
    // NOTE(review): not read by ClassifyIntent as shown; confirm where this
    // threshold is applied.
    confidenceThreshold float64
}

// ClassifyIntent runs the category model on query and returns the most
// probable category, its probability as Confidence, the full probability
// distribution, and the wall-clock time spent classifying.
//
// An error from the model's forward pass is returned unchanged.
func (cc *CategoryClassifier) ClassifyIntent(query string) (*Classification, error) {
    // Start the clock before any work so ProcessingTime covers tokenization,
    // inference and post-processing.
    start := time.Now()

    tokens := cc.tokenizer.Tokenize(query)

    logits, err := cc.model.Forward(tokens)
    if err != nil {
        return nil, err
    }

    // Turn raw logits into a probability distribution over categories.
    probabilities := softmax(logits)

    // The highest-probability index selects the category label.
    maxIdx, maxProb := argmax(probabilities)
    category := cc.labelMapping[maxIdx]

    return &Classification{
        Category:       category,
        Confidence:     maxProb,
        Probabilities:  probabilities,
        ProcessingTime: time.Since(start),
    }, nil
}

// makeRoutingDecision turns a classification into a routing decision using
// two fixed confidence cut-offs:
//   - confidence > 0.85: the category's specialized model, no fallback;
//   - confidence > 0.6:  the category model, with the configured default
//     model as fallback;
//   - otherwise:         the configured default model.
//
// The classification's confidence is copied into the decision unchanged.
func (r *OpenAIRouter) makeRoutingDecision(classification *Classification) *RoutingDecision {
    score := classification.Confidence

    switch {
    case score > 0.85:
        return &RoutingDecision{
            SelectedModel: r.getSpecializedModel(classification.Category),
            Reason:        "High confidence specialized routing",
            Confidence:    score,
        }
    case score > 0.6:
        return &RoutingDecision{
            SelectedModel: r.getCategoryModel(classification.Category),
            FallbackModel: r.Config.DefaultModel,
            Reason:        "Medium confidence routing with fallback",
            Confidence:    score,
        }
    default:
        return &RoutingDecision{
            SelectedModel: r.Config.DefaultModel,
            Reason:        "Low confidence, using general model",
            Confidence:    score,
        }
    }
}

Semantic Caching Implementation

Cache Architecture

// SemanticCache serves cached responses for queries whose embedding is close
// to that of a previously cached query.
type SemanticCache struct {
    // embeddings maps each cached query's text to its embedding.
    embeddings          map[string][]float32
    // responses maps the same query text to its cached response.
    responses           map[string]CachedResponse
    // similarity compares a query embedding with cached embeddings.
    similarity          SimilarityCalculator
    // similarityThreshold is the similarity a cached query must strictly
    // exceed to count as a hit. Read by Get.
    similarityThreshold float64
    // ttl: entries older than this are treated as misses.
    ttl                 time.Duration
    // NOTE(review): maxEntries is not enforced by Get; confirm eviction elsewhere.
    maxEntries          int
    // mutex guards the maps above.
    mutex               sync.RWMutex
}

// CachedResponse is one entry of SemanticCache.responses.
type CachedResponse struct {
    // Response is the payload SemanticCache.Get returns on a hit.
    Response    interface{}
    // Timestamp is compared against the cache TTL to expire the entry.
    Timestamp   time.Time
    // Model presumably names the model that produced Response — confirm.
    Model       string
    // NOTE(review): Get reads embeddings from SemanticCache.embeddings, not
    // from this field; it may be redundant.
    Embeddings  []float32
    // HitCount is meant to count cache hits. Map values are copies, so an
    // increment only sticks if the updated entry is stored back in the map.
    HitCount    int
}

// Get returns the cached response whose query embedding is most similar to
// query's, provided that similarity exceeds sc.similarityThreshold and the
// entry is younger than sc.ttl. The boolean reports whether a hit was found.
//
// Expired entries, and embeddings with no stored response, are skipped during
// the scan. That way an expired near-duplicate cannot hide a fresh, slightly
// less similar match. On a hit, the entry's HitCount is incremented and
// written back to the map.
func (sc *SemanticCache) Get(query string) (interface{}, bool) {
    key, response, found := sc.bestMatch(query)
    if !found {
        return nil, false
    }

    // Map values are copies, so the incremented entry must be stored back,
    // and that write needs the exclusive lock. The entry is looked up again
    // because it may have been replaced or removed after the read lock was
    // released; a removed entry is simply not counted.
    sc.mutex.Lock()
    if entry, exists := sc.responses[key]; exists {
        entry.HitCount++
        sc.responses[key] = entry
    }
    sc.mutex.Unlock()

    return response, true
}

// bestMatch scans the cache under the read lock. It returns the key and
// response of the most similar unexpired entry above the similarity threshold.
func (sc *SemanticCache) bestMatch(query string) (string, interface{}, bool) {
    sc.mutex.RLock()
    defer sc.mutex.RUnlock()

    queryEmbedding := sc.generateEmbedding(query)

    bestSimilarity := 0.0
    bestKey := ""
    var bestResponse interface{}
    found := false

    for cachedQuery, embedding := range sc.embeddings {
        similarity := sc.similarity.CosineSimilarity(queryEmbedding, embedding)
        // Positive comparisons (not negated) so a NaN similarity is never accepted.
        if similarity > bestSimilarity && similarity > sc.similarityThreshold {
            entry, exists := sc.responses[cachedQuery]
            if exists && time.Since(entry.Timestamp) < sc.ttl {
                bestSimilarity = similarity
                bestKey = cachedQuery
                bestResponse = entry.Response
                found = true
            }
        }
    }
    return bestKey, bestResponse, found
}

Tools Auto-Selection

Tool Relevance Algorithm

// ToolsSelector picks the tools most relevant to a query (see SelectRelevantTools).
type ToolsSelector struct {
    // NOTE(review): toolsDB is not read by the methods shown here.
    toolsDB           *tools.ToolsDatabase
    // relevanceModel presumably backs the semantic relevance score — confirm.
    relevanceModel    *RelevanceModel
    // maxTools caps how many tools SelectRelevantTools returns.
    maxTools          int
    // confidenceThreshold is the relevance score a tool must strictly exceed
    // to be selected.
    confidenceThreshold float64
}

// SelectRelevantTools scores each tool in availableTools against query. It
// returns the tools whose score exceeds ts.confidenceThreshold, with
// RelevanceScore filled in, ordered from most to least relevant and truncated
// to at most ts.maxTools entries. It returns nil when no tool qualifies.
//
// Tools with equal scores keep their order from availableTools, so which
// tools survive the maxTools cut is deterministic.
func (ts *ToolsSelector) SelectRelevantTools(
    query string,
    availableTools []Tool,
) []Tool {
    var selectedTools []Tool

    for _, tool := range availableTools {
        relevanceScore := ts.calculateRelevance(query, tool)
        if relevanceScore > ts.confidenceThreshold {
            // tool is the loop's copy; the scored copy is what gets returned.
            tool.RelevanceScore = relevanceScore
            selectedTools = append(selectedTools, tool)
        }
    }

    // A stable sort keeps ties in input order instead of an unspecified
    // order. This matters because the slice is truncated immediately after.
    sort.SliceStable(selectedTools, func(i, j int) bool {
        return selectedTools[i].RelevanceScore > selectedTools[j].RelevanceScore
    })

    if len(selectedTools) > ts.maxTools {
        selectedTools = selectedTools[:ts.maxTools]
    }
    return selectedTools
}

// calculateRelevance blends three relevance signals for tool against query:
// keyword and semantic relevance weigh 40% each, category relevance 20%.
func (ts *ToolsSelector) calculateRelevance(query string, tool Tool) float64 {
    const (
        keywordWeight  = 0.4
        semanticWeight = 0.4
        categoryWeight = 0.2
    )

    return keywordWeight*ts.calculateKeywordRelevance(query, tool) +
        semanticWeight*ts.calculateSemanticRelevance(query, tool) +
        categoryWeight*ts.calculateCategoryRelevance(query, tool)
}

Security Implementation

PII Detection Pipeline

// PIIDetector finds personally identifiable information in text by combining
// a ModernBERT token classifier with regular-expression patterns (see DetectPII).
type PIIDetector struct {
    // tokenClassifier tokenizes text and produces predictions for those tokens.
    tokenClassifier  *ModernBERTTokenClassifier
    // piiPatterns holds compiled regexes keyed by name — presumably PII type; confirm.
    piiPatterns     map[string]*regexp.Regexp
    // NOTE(review): confidence is not read by DetectPII as shown; confirm
    // where it is applied.
    confidence      float64
}

// DetectPII reports PII found in text. Entities extracted from the token
// classifier's predictions come first, followed by pattern-based matches.
// HasPII is set when at least one entity was found; otherwise Entities is an
// empty, non-nil slice. A prediction error is returned unchanged.
func (pd *PIIDetector) DetectPII(text string) (*PIIDetectionResult, error) {
    toks := pd.tokenClassifier.Tokenize(text)
    preds, err := pd.tokenClassifier.Predict(toks)
    if err != nil {
        return nil, err
    }

    // Model-based entities, then regex hits as a high-precision supplement.
    found := append(pd.extractEntities(toks, preds), pd.detectWithPatterns(text)...)

    if len(found) == 0 {
        return &PIIDetectionResult{
            HasPII:   false,
            Entities: []PIIEntity{},
        }, nil
    }
    return &PIIDetectionResult{
        HasPII:   true,
        Entities: found,
    }, nil
}

Jailbreak Detection

// JailbreakGuard scores queries for jailbreak attempts by combining an ML
// classifier with pattern matching (see AssessRisk).
type JailbreakGuard struct {
    // classifier supplies the ML risk score.
    classifier     *ModernBERTBinaryClassifier
    // patterns presumably drive calculatePatternScore — confirm.
    patterns       []JailbreakPattern
    // riskThreshold: a combined risk strictly above this flags a jailbreak.
    riskThreshold  float64
}

// AssessRisk scores query for jailbreak risk. It blends the ML classifier's
// score (70%) with a pattern-based score (30%) and flags the query as a
// jailbreak when the blend strictly exceeds jg.riskThreshold. A classifier
// error is returned unchanged and no assessment is produced.
func (jg *JailbreakGuard) AssessRisk(query string) (*SecurityAssessment, error) {
    const mlWeight, patternWeight = 0.7, 0.3

    mlScore, err := jg.classifier.PredictRisk(query)
    if err != nil {
        return nil, err
    }
    patternScore := jg.calculatePatternScore(query)

    risk := mlWeight*mlScore + patternWeight*patternScore

    assessment := &SecurityAssessment{
        RiskScore:    risk,
        MLScore:      mlScore,
        PatternScore: patternScore,
    }
    assessment.IsJailbreak = risk > jg.riskThreshold
    assessment.Reasoning = jg.explainDecision(risk, mlScore, patternScore)
    return assessment, nil
}

Performance Optimizations

Model Loading and Caching

// ModelManager keeps loaded models by name and loads them on first use (see GetModel).
type ModelManager struct {
    // models maps a model name to its loaded instance; guarded by modelLock.
    models     map[string]*LoadedModel
    modelLock  sync.RWMutex
    // NOTE(review): warmupPool is not used by GetModel; confirm its purpose.
    warmupPool sync.Pool
}

// GetModel returns the model named modelName, loading it on first use.
//
// The fast path takes only the read lock. On a miss, the read lock is
// released and the write lock acquired; sync.RWMutex cannot upgrade in place.
// The map is therefore checked again, in case another goroutine loaded the
// model in between. Loading runs while the write lock is held, so concurrent
// GetModel calls for any model wait until it finishes.
//
// After a successful load, warm-up starts in a separate goroutine and the
// model is returned immediately, so a caller may receive it before warm-up
// completes. A load error is returned unchanged and nothing is cached.
func (mm *ModelManager) GetModel(modelName string) (*LoadedModel, error) {
    mm.modelLock.RLock()
    model, exists := mm.models[modelName]
    mm.modelLock.RUnlock()
    if exists {
        return model, nil
    }

    mm.modelLock.Lock()
    defer mm.modelLock.Unlock()

    // Re-check: another goroutine may have loaded it between the two locks.
    if cached, ok := mm.models[modelName]; ok {
        return cached, nil
    }

    model, err := mm.loadModel(modelName)
    if err != nil {
        return nil, err
    }

    // A zero-value ModelManager has a nil map, and writing to it would panic.
    if mm.models == nil {
        mm.models = make(map[string]*LoadedModel)
    }
    mm.models[modelName] = model

    // NOTE(review): this goroutine is neither tracked nor cancellable; confirm
    // warmupModel is bounded and safe to run alongside inference on model.
    go mm.warmupModel(model)
    return model, nil
}

Batch Processing

// BatchProcessor groups individual ProcessingRequests into batches. A batch is
// flushed when it reaches batchSize or when batchTimeout elapses.
// NOTE(review): flushBatch calls bp.classifier.ProcessBatch, but no classifier
// field is declared here; add it with the appropriate type.
type BatchProcessor struct {
    // batchSize: a batch is flushed as soon as it holds this many requests.
    batchSize     int
    // batchTimeout: delay after which a partial batch is flushed.
    batchTimeout  time.Duration
    // pendingBatch: requests awaiting processing; meant to be guarded by batchMutex.
    pendingBatch  []ProcessingRequest
    batchMutex    sync.Mutex
    // flushTimer: timeout timer for the current batch; nil when none is armed.
    flushTimer    *time.Timer
}

// ProcessRequest adds req to the pending batch. The batch is flushed at once
// when it reaches batchSize. Otherwise, the first request of a batch arms a
// one-shot timer that flushes the batch after batchTimeout.
//
// NOTE(review): flushBatch delivers results while batchMutex is held. If
// ResultChannel is unbuffered and read by the caller only after this method
// returns, a size-triggered flush deadlocks. Confirm the channels are buffered.
func (bp *BatchProcessor) ProcessRequest(req ProcessingRequest) {
    bp.batchMutex.Lock()
    defer bp.batchMutex.Unlock()

    bp.pendingBatch = append(bp.pendingBatch, req)

    if len(bp.pendingBatch) >= bp.batchSize {
        bp.flushBatch()
        return
    }

    if bp.flushTimer == nil {
        // time.AfterFunc runs its callback on its own goroutine. flushBatch
        // reads and resets pendingBatch and flushTimer, so the callback must
        // hold batchMutex too. If the timer fires just as a size-triggered
        // flush completes, the callback may flush a newer partial batch early;
        // that is harmless, because every request is still delivered.
        bp.flushTimer = time.AfterFunc(bp.batchTimeout, func() {
            bp.batchMutex.Lock()
            defer bp.batchMutex.Unlock()
            bp.flushBatch()
        })
    }
}

// flushBatch sends all pending requests to the classifier as one batch. It
// delivers each result on the matching request's ResultChannel, empties the
// batch and stops any armed flush timer. It does nothing when no request is
// pending.
//
// The caller must hold batchMutex: pendingBatch and flushTimer are read and
// reset here without locking.
// NOTE(review): results are sent while the caller holds the lock, so an
// unbuffered ResultChannel whose reader is blocked on this lock deadlocks.
func (bp *BatchProcessor) flushBatch() {
    if len(bp.pendingBatch) == 0 {
        return
    }

    // Process entire batch together for better GPU utilization
    results := bp.classifier.ProcessBatch(bp.pendingBatch)

    // Distribute results back to individual requests
    // NOTE(review): assumes results[i] belongs to pendingBatch[i]; a results
    // slice longer than pendingBatch would panic on the index below.
    for i, result := range results {
        bp.pendingBatch[i].ResultChannel <- result
    }

    // Reset batch
    // Truncating keeps the backing array. If ProcessBatch retains its
    // argument, later appends will overwrite those elements.
    bp.pendingBatch = bp.pendingBatch[:0]
    if bp.flushTimer != nil {
        bp.flushTimer.Stop()
        bp.flushTimer = nil
    }
}

Monitoring and Observability

Request Tracing

// RequestTracer tracks in-flight spans keyed by "requestID:operation".
type RequestTracer struct {
    // spans holds spans started by StartSpan and not yet finished; guarded by mutex.
    spans map[string]*Span
    mutex sync.RWMutex
}

// StartSpan creates a span for operation within requestID. The span is stamped
// with the current time and an empty tag map, registered under the key
// "requestID:operation", and returned. Starting another span with the same
// pair replaces the earlier one in the registry.
func (rt *RequestTracer) StartSpan(requestID, operation string) *Span {
    key := requestID + ":" + operation
    s := &Span{
        RequestID: requestID,
        Operation: operation,
        StartTime: time.Now(),
        Tags:      make(map[string]interface{}),
    }

    rt.mutex.Lock()
    defer rt.mutex.Unlock()
    rt.spans[key] = s
    return s
}

// FinishSpan records the span's end time and duration, logs its timing at
// info level, and removes it from the tracer's registry.
func (rt *RequestTracer) FinishSpan(span *Span) {
    end := time.Now()
    span.EndTime = end
    span.Duration = end.Sub(span.StartTime)

    fields := log.Fields{
        "request_id": span.RequestID,
        "operation":  span.Operation,
        "duration":   span.Duration.Milliseconds(),
        "tags":       span.Tags,
    }
    log.WithFields(fields).Info("Operation completed")

    key := span.RequestID + ":" + span.Operation
    rt.mutex.Lock()
    delete(rt.spans, key)
    rt.mutex.Unlock()
}

Performance Metrics

// PerformanceTracker holds the Prometheus collectors used to record router
// performance.
type PerformanceTracker struct {
    // classificationLatency receives classification durations in seconds.
    classificationLatency prometheus.Histogram
    cacheHitRatio         prometheus.Gauge
    securityCheckLatency  prometheus.Histogram
    // routingAccuracy is labelled by category. RecordClassification selects
    // a child via WithLabelValues, which prometheus.Gauge does not provide,
    // so this must be a GaugeVec.
    // NOTE(review): RecordClassification stores classifier confidence here,
    // not a measured accuracy; consider renaming the metric.
    routingAccuracy       *prometheus.GaugeVec
}

// RecordClassification records one classification. duration is observed in
// the latency histogram in seconds, and confidence becomes the current value
// of the routingAccuracy gauge for category.
func (pt *PerformanceTracker) RecordClassification(
    category string,
    confidence float64,
    duration time.Duration,
) {
    seconds := duration.Seconds()
    pt.classificationLatency.Observe(seconds)

    // One gauge child per category label.
    pt.routingAccuracy.WithLabelValues(category).Set(confidence)
}

This implementation provides the foundation for intelligent, secure, and performant LLM routing. The next section covers Model Training, detailing how the classification models are developed and optimized.