
Retraining: Triggered, Scheduled, Online, Active Learning


Interview Question (Hard). Asked at: Google, Netflix, Uber, Amazon, Spotify

"Design an automated model retraining system that balances model freshness with computational cost. How do you determine when to retrain, what data to use, and how to validate the new model?"

Retraining Strategy Overview

Model retraining is critical for maintaining performance as data distributions change over time. The choice of retraining strategy depends on data velocity, computational resources, and latency requirements.

Retraining Strategy Comparison

Strategy        | Trigger          | Latency   | Cost   | Best For
Scheduled       | Time-based       | Hours     | Medium | Stable distributions
Triggered       | Performance drop | Hours     | Medium | Drift detection
Online          | Per-example      | Real-time | High   | Streaming data
Active Learning | Uncertainty      | Variable  | Low    | Label scarcity
Incremental     | Batch arrival    | Minutes   | Low    | High-frequency data
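
As a rough illustration of how the comparison can be read, the sketch below maps a few workload traits to a strategy. The function name `suggest_strategy`, its parameters, and the order in which the checks run are assumptions made for this example; the table itself does not prescribe a precedence.

```python
def suggest_strategy(streaming: bool, labels_scarce: bool,
                     drift_monitored: bool, frequent_batches: bool) -> str:
    """Map workload traits to a strategy from the comparison table.

    The parameters and the precedence of the checks are illustrative
    assumptions, not rules taken from the table.
    """
    if labels_scarce:
        return "active_learning"  # spend the labeling budget on uncertain examples
    if streaming:
        return "online"           # per-example updates for streaming data
    if frequent_batches:
        return "incremental"      # update as each batch arrives
    if drift_monitored:
        return "triggered"        # retrain when drift or a performance drop is detected
    return "scheduled"            # stable distributions: retrain on a timer


print(suggest_strategy(streaming=False, labels_scarce=False,
                       drift_monitored=True, frequent_batches=False))
```

In practice several strategies are often combined, for example a scheduled retrain as a safety net behind a drift-based trigger.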

Retraining Architecture Diagram

┌───────────────────────────────────────────────────────────────┐
│                    Model Retraining System                    │
├───────────────────────────────────────────────────────────────┤
│                                                               │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐ │
│  │  Drift   │───▶│ Retrain  │───▶│ Training │───▶│ Validate │ │
│  │  Monitor │    │ Trigger  │    │ Pipeline │    │ & Gate   │ │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘ │
│       │               │               │               │       │
│       ▼               ▼               ▼               ▼       │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐ │
│  │  Data    │    │ Schedule │    │ Resource │    │ Deploy/  │ │
│  │  Version │    │ Manager  │    │ Manager  │    │ Rollback │ │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘ │
└───────────────────────────────────────────────────────────────┘

Scheduled Retraining

Airflow Scheduled Pipeline

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.sensors.external_task import ExternalTaskSensor
from datetime import datetime, timedelta
import json

default_args = {
    'owner': 'ml-team',
    'depends_on_past': False,
    'email_on_failure': True,
    'retries': 2,
    'retry_delay': timedelta(minutes=10),
}

with DAG(
    'scheduled_model_retraining',
    default_args=default_args,
    description='Scheduled model retraining pipeline',
    schedule_interval='0 2 * * 0',  # Weekly on Sunday at 2 AM
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['retraining', 'scheduled'],
) as dag:
    
    def prepare_training_data(**context):
        """Prepare data for retraining."""
        import pandas as pd
        from datetime import datetime, timedelta
        
        # Get date range for training
        execution_date = context['execution_date']
        training_window = 90  # days
        
        start_date = execution_date - timedelta(days=training_window)
        end_date = execution_date
        
        # Load data
        df = pd.read_parquet(
            f"s3://ml-data/training/{start_date:%Y%m%d}_{end_date:%Y%m%d}/"
        )
        
        # Save prepared data
        output_path = f"/tmp/training_data_{execution_date:%Y%m%d}.parquet"
        df.to_parquet(output_path)
        
        return output_path
    
    def train_model(**context):
        """Train model with new data."""
        import xgboost as xgb
        import mlflow
        import pandas as pd
        from sklearn.model_selection import train_test_split
        
        ti = context['ti']
        data_path = ti.xcom_pull(task_ids='prepare_data')
        
        df = pd.read_parquet(data_path)
        
        mlflow.set_experiment("scheduled_retraining")
        
        with mlflow.start_run(run_name=f"retrain_{context['ds']}"):
            # Prepare features
            X = df.drop(columns=['label'])
            y = df['label']
            
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            
            dtrain = xgb.DMatrix(X_train, label=y_train)
            dval = xgb.DMatrix(X_val, label=y_val)
            
            params = {
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'max_depth': 6,
                'learning_rate': 0.1,
            }
            
            model = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,
                evals=[(dval, 'val')],
                early_stopping_rounds=50
            )
            
            # Log metrics
            val_pred = model.predict(dval)
            from sklearn.metrics import roc_auc_score
            auc = roc_auc_score(y_val, val_pred)
            
            mlflow.log_metric("auc_roc", auc)
            mlflow.log_param("training_samples", len(X_train))
            mlflow.log_param("validation_samples", len(X_val))
            
            # Save model
            model_path = f"/tmp/model_{context['ds']}.json"
            model.save_model(model_path)
            
            return {'model_path': model_path, 'auc': auc}
    
    def evaluate_and_compare(**context):
        """Compare new model with current production model."""
        import mlflow
        import json
        
        ti = context['ti']
        training_results = ti.xcom_pull(task_ids='train_model')
        
        # Load current production model metrics
        current_metrics = json.loads(
            open('/models/production/metrics.json').read()
        )
        
        new_auc = training_results['auc']
        current_auc = current_metrics['auc_roc']
        
        # Decision logic
        improvement_threshold = 0.01  # 1% improvement required
        
        if new_auc > current_auc + improvement_threshold:
            action = "promote"
            reason = f"New model improved AUC by {new_auc - current_auc:.4f}"
        elif new_auc < current_auc - 0.02:
            action = "keep"
            reason = f"New model degraded AUC by {current_auc - new_auc:.4f}"
        else:
            action = "keep"
            reason = "Improvement below threshold"
        
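        decision = {
            'action': action,
            'reason': reason,
            'new_auc': new_auc,
            'current_auc': current_auc
        }
        
        # Gate the pipeline: skipping this task also skips the downstream
        # promote task (its default trigger rule is all_success).
        if action != "promote":
            from airflow.exceptions import AirflowSkipException
            raise AirflowSkipException(reason)
        
        return decision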
    
    # Task definitions
    prepare_data = PythonOperator(
        task_id='prepare_data',
        python_callable=prepare_training_data,
    )
    
    # Run the train_model callable defined above so that its return value
    # (model_path, auc) is pushed to XCom for the evaluate and promote tasks.
    train = PythonOperator(
        task_id='train_model',
        python_callable=train_model,
    )
    
    evaluate = PythonOperator(
        task_id='evaluate_model',
        python_callable=evaluate_and_compare,
    )
    
    promote = DockerOperator(
        task_id='promote_model',
        image='registry.example.com/ml/deployment:latest',
        command='python promote.py --model-path {{ ti.xcom_pull(task_ids="train_model")["model_path"] }}',
    )
    
    prepare_data >> train >> evaluate >> promote
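
The promotion rule inside `evaluate_and_compare` can also be isolated as a small pure function, which makes the gate easy to unit test outside Airflow. This is a minimal sketch: the function name `promotion_decision` and its parameter names are hypothetical, and the default thresholds simply mirror the values used in the DAG above.

```python
from typing import Tuple


def promotion_decision(new_auc: float, current_auc: float,
                       min_gain: float = 0.01,
                       max_drop: float = 0.02) -> Tuple[str, str]:
    """Return (action, reason) for a candidate model.

    Mirrors the decision logic of evaluate_and_compare: promote only when
    the AUC gain exceeds min_gain, otherwise keep the production model.
    """
    gain = new_auc - current_auc
    if gain > min_gain:
        return "promote", f"New model improved AUC by {gain:.4f}"
    if gain < -max_drop:
        return "keep", f"New model degraded AUC by {-gain:.4f}"
    return "keep", "Improvement below threshold"


action, reason = promotion_decision(new_auc=0.90, current_auc=0.85)
print(action, "-", reason)
```

Keeping the rule free of I/O means the thresholds can be tested with plain assertions and reused by the triggered and orchestrated pipelines.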

Triggered Retraining

Event-Driven Retraining System

from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional
import json
import logging
from datetime import datetime, timedelta
from kafka import KafkaConsumer, KafkaProducer
import redis

logger = logging.getLogger(__name__)

class RetrainingTrigger(Enum):
    DATA_DRIFT = "data_drift"
    PERFORMANCE_DECAY = "performance_decay"
    SCHEDULE = "schedule"
    MANUAL = "manual"
    DATA_VOLUME = "data_volume"

@dataclass
class RetrainingEvent:
    trigger: RetrainingTrigger
    timestamp: datetime
    metadata: Dict
    priority: int  # 1 (low) to 5 (high)
    
class TriggeredRetrainingManager:
    def __init__(self, config: dict):
        self.config = config
        self.redis = redis.Redis(
            host=config['redis_host'],
            port=config['redis_port']
        )
        
        self.consumer = KafkaConsumer(
            'retraining-triggers',
            bootstrap_servers=config['kafka_servers'],
            value_deserializer=lambda m: json.loads(m.decode('utf-8'))
        )
        
        self.producer = KafkaProducer(
            bootstrap_servers=config['kafka_servers'],
            value_serializer=lambda v: json.dumps(v, default=str).encode('utf-8')
        )
        
        # Retraining cooldown
        self.cooldown_hours = config.get('cooldown_hours', 24)
    
    def should_retrain(self, event: RetrainingEvent) -> tuple:
        """Determine if retraining should be triggered."""
        
        # Check cooldown
        last_retrain = self.redis.get("last_retrain_time")
        if last_retrain:
            last_time = datetime.fromisoformat(last_retrain.decode())
            if datetime.now() - last_time < timedelta(hours=self.cooldown_hours):
                return False, "Cooldown period active"
        
        # Check priority
        if event.priority < 3:
            return False, "Priority too low"
        
        # Check trigger-specific conditions
        if event.trigger == RetrainingTrigger.DATA_DRIFT:
            drift_score = event.metadata.get('drift_score', 0)
            if drift_score < 0.3:
                return False, "Drift score below threshold"
        
        elif event.trigger == RetrainingTrigger.PERFORMANCE_DECAY:
            current_accuracy = event.metadata.get('current_accuracy', 1)
            baseline_accuracy = event.metadata.get('baseline_accuracy', 1)
            decay = baseline_accuracy - current_accuracy
            
            if decay < 0.05:  # 5% decay threshold
                return False, "Performance decay below threshold"
        
        return True, "Retraining recommended"
    
    def process_event(self, event: RetrainingEvent):
        """Process a retraining trigger event."""
        
        should_retrain, reason = self.should_retrain(event)
        
        logger.info(
            f"Processing trigger: {event.trigger.value}, "
            f"Decision: {'RETRAIN' if should_retrain else 'SKIP'}, "
            f"Reason: {reason}"
        )
        
        if should_retrain:
            # Create retraining job
            job = {
                'trigger': event.trigger.value,
                'timestamp': datetime.now().isoformat(),
                'metadata': event.metadata,
                'priority': event.priority,
                'status': 'queued'
            }
            
            # Queue retraining job
            self.producer.send(
                'retraining-jobs',
                value=job
            )
            
            # Update cooldown
            self.redis.setex(
                "last_retrain_time",
                timedelta(hours=self.cooldown_hours),
                datetime.now().isoformat()
            )
            
            # Log event
            self._log_retraining_event(event, "triggered", reason)
    
    def listen_for_events(self):
        """Listen for retraining trigger events."""
        
        logger.info("Listening for retraining events...")
        
        for message in self.consumer:
            try:
                event_data = message.value
                
                event = RetrainingEvent(
                    trigger=RetrainingTrigger(event_data['trigger']),
                    timestamp=datetime.fromisoformat(event_data['timestamp']),
                    metadata=event_data['metadata'],
                    priority=event_data.get('priority', 3)
                )
                
                self.process_event(event)
                
            except Exception as e:
                logger.error(f"Error processing event: {e}")
    
    def _log_retraining_event(self, event: RetrainingEvent, 
                               action: str, reason: str):
        """Log retraining event for audit."""
        
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'trigger': event.trigger.value,
            'action': action,
            'reason': reason,
            'metadata': event.metadata
        }
        
        self.redis.lpush(
            "retraining_log",
            json.dumps(log_entry, default=str)
        )
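
The `DATA_DRIFT` branch of `should_retrain` reads a `drift_score` from the event metadata, but this section does not say how that score is produced. One common option is the Population Stability Index (PSI). The sketch below is a minimal pure-Python version under that assumption; the function name `psi`, the bin count, and the epsilon floor are illustrative choices, and the 0.3 threshold used by the manager would need to be checked against whatever drift metric is actually emitted.

```python
import math
from typing import List, Sequence


def psi(expected: Sequence[float], actual: Sequence[float],
        n_bins: int = 10) -> float:
    """Population Stability Index between a reference and a recent sample.

    Bins are equal-width over the range of the reference sample.
    """
    lo, hi = min(expected), max(expected)
    width = (hi - lo) / n_bins or 1.0  # fall back to 1.0 for constant features

    def bin_fractions(values: Sequence[float]) -> List[float]:
        counts = [0] * n_bins
        for v in values:
            idx = int((v - lo) / width)
            idx = min(max(idx, 0), n_bins - 1)  # clamp out-of-range values to edge bins
            counts[idx] += 1
        # Floor each fraction so empty bins do not produce log(0)
        return [max(c / len(values), 1e-6) for c in counts]

    e = bin_fractions(expected)
    a = bin_fractions(actual)
    return sum((ai - ei) * math.log(ai / ei) for ai, ei in zip(a, e))


reference = [i / 100 for i in range(100)]
shifted = [v + 0.5 for v in reference]
print(f"PSI (no shift): {psi(reference, reference):.3f}")
print(f"PSI (shifted):  {psi(reference, shifted):.3f}")
```

A monitor could compute this per feature and publish the maximum as `drift_score` in the trigger event, though how scores are aggregated across features is a design decision this sketch leaves open.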

Performance Decay Detection

import numpy as np
from typing import Dict, List, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
import pandas as pd

@dataclass
class PerformanceMetric:
    timestamp: datetime
    value: float
    window: str  # '1h', '24h', '7d'

class PerformanceDecayDetector:
    def __init__(self, baseline_metrics: Dict[str, float],
                 decay_thresholds: Dict[str, float]):
        """
        Args:
            baseline_metrics: Baseline performance metrics
            decay_thresholds: Threshold for each metric to trigger retraining
        """
        self.baseline_metrics = baseline_metrics
        self.decay_thresholds = decay_thresholds
        self.metric_history = {}
    
    def add_metric(self, metric_name: str, value: float, 
                   window: str = '24h'):
        """Add a performance metric measurement."""
        
        if metric_name not in self.metric_history:
            self.metric_history[metric_name] = []
        
        self.metric_history[metric_name].append(
            PerformanceMetric(
                timestamp=datetime.now(),
                value=value,
                window=window
            )
        )
    
    def detect_decay(self) -> Dict[str, Dict]:
        """Detect performance decay across all metrics."""
        
        decay_results = {}
        
        for metric_name, history in self.metric_history.items():
            if not history:
                continue
            
            # Get recent values
            recent_window = timedelta(hours=24)
            recent_values = [
                m.value for m in history
                if datetime.now() - m.timestamp < recent_window
            ]
            
            if len(recent_values) < 10:
                continue
            
            # Calculate statistics
            current_mean = np.mean(recent_values)
            current_std = np.std(recent_values)
            baseline = self.baseline_metrics.get(metric_name, current_mean)
            
            # Calculate decay
            decay = baseline - current_mean
            decay_percentage = (decay / baseline * 100) if baseline > 0 else 0
            
            # Check threshold
            threshold = self.decay_thresholds.get(metric_name, 0.05)
            decay_detected = decay > threshold
            
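            # One-sided significance test: is the recent mean below baseline?
            # (the `alternative` argument requires SciPy >= 1.6)
            from scipy import stats
            t_stat, p_value = stats.ttest_1samp(
                recent_values, baseline, alternative='less'
            )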
            
            decay_results[metric_name] = {
                'current_value': current_mean,
                'baseline_value': baseline,
                'decay': decay,
                'decay_percentage': decay_percentage,
                'std': current_std,
                'p_value': p_value,
                'statistically_significant': p_value < 0.05,
                'decay_detected': decay_detected,
                'threshold': threshold
            }
        
        return decay_results
    
    def calculate_retraining_priority(self, decay_results: Dict) -> int:
        """Calculate retraining priority based on decay severity."""
        
        priority = 1
        
        for metric_name, result in decay_results.items():
            if result['decay_detected']:
                if result['decay_percentage'] > 20:
                    priority = 5
                elif result['decay_percentage'] > 10:
                    priority = max(priority, 4)
                elif result['decay_percentage'] > 5:
                    priority = max(priority, 3)
                else:
                    priority = max(priority, 2)
        
        return priority

Note: Triggered retraining balances computational cost with model freshness. Use performance decay detection with statistical significance tests to avoid false positives from normal metric variance.
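
To illustrate the significance check without extra dependencies, the sketch below tests whether a recent window of metric values sits significantly below the baseline. It uses a normal approximation (a z-test) from the standard library rather than the t-test used in `PerformanceDecayDetector`, so treat it as an approximation for reasonably large windows. The function name `decay_is_significant` and the sample values are made up for the example.

```python
from statistics import NormalDist, mean, stdev
from typing import Sequence


def decay_is_significant(recent: Sequence[float], baseline: float,
                         alpha: float = 0.05) -> bool:
    """One-sided test: is the recent mean significantly below the baseline?

    Uses a normal approximation, so it needs only the standard library.
    """
    n = len(recent)
    if n < 2:
        return False  # not enough data to estimate the variance
    standard_error = stdev(recent) / n ** 0.5
    if standard_error == 0:
        return mean(recent) < baseline  # no variance: compare directly
    z = (mean(recent) - baseline) / standard_error
    p_value = NormalDist().cdf(z)  # small when the mean is well below baseline
    return p_value < alpha


noisy = [0.90, 0.91, 0.89, 0.90, 0.92, 0.88, 0.91, 0.90, 0.89, 0.90]
decayed = [0.80, 0.81, 0.79, 0.80, 0.82, 0.78, 0.81, 0.80, 0.79, 0.80]
print(decay_is_significant(noisy, baseline=0.90))    # ordinary variance around baseline
print(decay_is_significant(decayed, baseline=0.90))  # sustained drop
```

Combining this check with the absolute decay threshold means a retraining event fires only when the drop is both large enough to matter and unlikely to be noise.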

Online Learning

Online Gradient Descent

import numpy as np
from typing import Optional
from collections import deque

class OnlineLinearModel:
    """Online learning model using gradient descent."""
    
    def __init__(self, n_features: int, learning_rate: float = 0.01,
                 regularization: float = 0.001):
        self.n_features = n_features
        self.learning_rate = learning_rate
        self.regularization = regularization
        
        # Initialize weights
        self.weights = np.zeros(n_features)
        self.bias = 0
        
        # Running statistics
        self.n_samples = 0
        self.loss_history = deque(maxlen=1000)
    
    def _sigmoid(self, z):
        """Sigmoid activation function."""
        return 1 / (1 + np.exp(-np.clip(z, -500, 500)))
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        z = X @ self.weights + self.bias
        return self._sigmoid(z)
    
    def update(self, X: np.ndarray, y: float):
        """Update model with a single example (online learning)."""
        
        # Forward pass
        prediction = self.predict(X.reshape(1, -1))[0]
        
        # Calculate gradient
        error = prediction - y
        
        # Update weights
        gradient = error * X
        self.weights -= self.learning_rate * (
            gradient + self.regularization * self.weights
        )
        self.bias -= self.learning_rate * error
        
        # Track loss
        loss = -y * np.log(prediction + 1e-7) - (1 - y) * np.log(1 - prediction + 1e-7)
        self.loss_history.append(loss)
        
        self.n_samples += 1
        
        return loss
    
    def update_batch(self, X_batch: np.ndarray, y_batch: np.ndarray):
        """Update model with a mini-batch."""
        
        total_loss = 0
        for X, y in zip(X_batch, y_batch):
            loss = self.update(X, y)
            total_loss += loss
        
        return total_loss / len(y_batch)
    
    def get_metrics(self) -> dict:
        """Get current model metrics."""
        
        return {
            'n_samples': self.n_samples,
            'avg_loss': np.mean(self.loss_history) if self.loss_history else 0,
            'weights_norm': np.linalg.norm(self.weights),
            'bias': self.bias
        }

class OnlineEnsemble:
    """Ensemble of online learning models."""
    
    def __init__(self, n_models: int, n_features: int,
                 learning_rate: float = 0.01):
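        # Give each member a different learning rate; identical members would
        # make identical predictions and the ensemble weights would never move.
        self.models = [
            OnlineLinearModel(n_features, learning_rate * (0.5 ** i))
            for i in range(n_models)
        ]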
        self.weights = np.ones(n_models) / n_models
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Weighted ensemble prediction."""
        
        predictions = np.array([
            model.predict(X) for model in self.models
        ])
        
        return np.average(predictions, axis=0, weights=self.weights)
    
    def update(self, X: np.ndarray, y: float):
        """Update all models and adjust ensemble weights."""
        
        # Get individual predictions
        predictions = np.array([
            model.predict(X.reshape(1, -1))[0]
            for model in self.models
        ])
        
        # Calculate losses
        losses = np.array([
            -y * np.log(pred + 1e-7) - (1 - y) * np.log(1 - pred + 1e-7)
            for pred in predictions
        ])
        
        # Update ensemble weights (exponential weighting)
        self.weights *= np.exp(-losses)
        self.weights /= self.weights.sum()
        
        # Update individual models
        for model in self.models:
            model.update(X, y)
    
    def get_model_weights(self) -> dict:
        """Get current ensemble weights."""
        return {
            f'model_{i}': float(w)
            for i, w in enumerate(self.weights)
        }

River - Online Machine Learning Library

from river import (
    linear_model, preprocessing, metrics, compose, utils
)
from river import ensemble as river_ensemble
from datetime import datetime  # used by OnlineFeatureUpdater below
import numpy as np

class RiverOnlineClassifier:
    """Online classifier using River library."""
    
    def __init__(self):
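        # Create pipeline with preprocessing.
        # Note: in recent River releases the adaptive random forest is exposed
        # as river.forest.ARFClassifier; adjust the import if this class is
        # missing from river.ensemble in your installed version.
        self.model = compose.Pipeline(
            preprocessing.StandardScaler(),
            river_ensemble.AdaptiveRandomForestClassifier(
                n_models=10,
                seed=42
            )
        )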
        
        # Metrics
        self.metric = metrics.Accuracy()
        self.auc_metric = metrics.ROCAUC()
        
        # Buffer for predictions
        self.prediction_buffer = []
    
    def learn_one(self, x: dict, y: int):
        """Learn from a single example."""
        
        # Get prediction before learning
        y_pred = self.model.predict_one(x)
        
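        if y_pred is not None:
            self.metric.update(y, y_pred)
            # Use .get so a missing positive class (before the model has seen
            # one) yields 0.0 instead of raising KeyError
            self.auc_metric.update(
                y, self.model.predict_proba_one(x).get(True, 0.0)
            )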
        
        # Learn from example
        self.model.learn_one(x, y)
    
    def predict_one(self, x: dict) -> int:
        """Predict a single example."""
        return self.model.predict_one(x)
    
    def predict_proba_one(self, x: dict) -> dict:
        """Predict class probabilities."""
        return self.model.predict_proba_one(x)
    
    def get_metrics(self) -> dict:
        """Get current metrics."""
        return {
            'accuracy': self.metric.get(),
            'auc': self.auc_metric.get()
        }

class OnlineFeatureUpdater:
    """Update features in real-time for online learning."""
    
    def __init__(self, feature_store_url: str):
        self.feature_store_url = feature_store_url
        self.feature_cache = {}
    
    def update_features(self, entity_id: str, features: dict):
        """Update features in the feature store."""
        
        # Update local cache
        self.feature_cache[entity_id] = {
            'features': features,
            'updated_at': datetime.now()
        }
        
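        # Update feature store (async). create_task needs a running event
        # loop, so fall back to a blocking call when invoked from sync code.
        import asyncio
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(self._update_feature_store(entity_id, features))
        else:
            loop.create_task(self._update_feature_store(entity_id, features))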
    
    async def _update_feature_store(self, entity_id: str, features: dict):
        """Update feature store asynchronously."""
        
        import aiohttp
        
        async with aiohttp.ClientSession() as session:
            async with session.put(
                f"{self.feature_store_url}/features/{entity_id}",
                json=features
            ) as response:
                if response.status != 200:
                    print(f"Failed to update features for {entity_id}")

⚠️

Online learning is sensitive to data quality. Implement input validation and outlier detection before updating the model. Consider using bounded buffers and gradient clipping to prevent instability.
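
The safeguards mentioned above can be sketched as a single guarded update step for a logistic model. This is an illustrative, dependency-free version rather than a drop-in replacement for `OnlineLinearModel.update`: the function name `safe_online_update`, the feature magnitude limit, and the gradient norm cap are all assumed values chosen for the example.

```python
import math
from typing import List, Optional


def safe_online_update(weights: List[float], x: List[float], y: float,
                       learning_rate: float = 0.01,
                       max_abs_feature: float = 1e3,
                       max_grad_norm: float = 1.0) -> Optional[List[float]]:
    """One logistic-regression SGD step with validation and gradient clipping.

    Returns the updated weights, or None if the example is rejected.
    """
    # Input validation: reject malformed examples instead of learning from them
    if len(x) != len(weights) or y not in (0, 1):
        return None
    # Crude outlier check: non-finite or implausibly large feature values
    if any(math.isnan(v) or math.isinf(v) or abs(v) > max_abs_feature for v in x):
        return None

    z = sum(w * v for w, v in zip(weights, x))
    z = max(-500.0, min(500.0, z))  # keep exp() in a safe range
    prediction = 1.0 / (1.0 + math.exp(-z))
    gradient = [(prediction - y) * v for v in x]

    # Gradient clipping: rescale if the L2 norm exceeds the cap
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm > max_grad_norm:
        gradient = [g * max_grad_norm / norm for g in gradient]

    return [w - learning_rate * g for w, g in zip(weights, gradient)]


print(safe_online_update([0.0, 0.0], [100.0, 0.0], 1))          # clipped step
print(safe_online_update([0.0, 0.0], [float("nan"), 1.0], 1))   # rejected
```

Rejected examples are worth counting and logging, since a rising rejection rate is itself a useful data quality signal.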

Active Learning

Uncertainty Sampling

import numpy as np
from typing import List, Tuple
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV

class ActiveLearner:
    """Active learning with uncertainty sampling."""
    
    def __init__(self, base_model, n_initial: int = 100):
        self.base_model = base_model
        self.n_initial = n_initial
        
        # Labeled data
        self.X_labeled = None
        self.y_labeled = None
        
        # Unlabeled pool
        self.X_unlabeled = None
        
        # Query budget
        self.query_budget = 1000
        self.queries_made = 0
    
    def initialize(self, X_pool: np.ndarray, y_pool: np.ndarray):
        """Initialize with a small labeled set."""
        
        # Randomly select initial samples
        initial_idx = np.random.choice(
            len(X_pool), 
            size=min(self.n_initial, len(X_pool)),
            replace=False
        )
        
        self.X_labeled = X_pool[initial_idx]
        self.y_labeled = y_pool[initial_idx]
        
        # Remove from pool
        mask = np.ones(len(X_pool), dtype=bool)
        mask[initial_idx] = False
        self.X_unlabeled = X_pool[mask]
        
        # Fit initial model
        self._fit_model()
    
    def _fit_model(self):
        """Fit the model on labeled data."""
        
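        # Fit the base model directly so its per-tree estimators_ are available
        # to query_random_forest_uncertainty (CalibratedClassifierCV fits clones)
        self.base_model.fit(self.X_labeled, self.y_labeled)
        
        # Calibrate probabilities
        self.model = CalibratedClassifierCV(
            self.base_model, 
            cv=3
        )
        self.model.fit(self.X_labeled, self.y_labeled)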
    
    def query_uncertainty(self, n_samples: int = 10) -> np.ndarray:
        """Query samples with highest uncertainty."""
        
        # Get predictions
        probabilities = self.model.predict_proba(self.X_unlabeled)
        
        # Calculate uncertainty (entropy)
        entropy = -np.sum(
            probabilities * np.log(probabilities + 1e-7),
            axis=1
        )
        
        # Select most uncertain samples
        uncertain_idx = np.argsort(entropy)[-n_samples:]
        
        return uncertain_idx
    
    def query_margin(self, n_samples: int = 10) -> np.ndarray:
        """Query samples with smallest margin between top 2 classes."""
        
        probabilities = self.model.predict_proba(self.X_unlabeled)
        
        # Sort probabilities
        sorted_probs = np.sort(probabilities, axis=1)[:, ::-1]
        
        # Calculate margin
        margin = sorted_probs[:, 0] - sorted_probs[:, 1]
        
        # Select samples with smallest margin
        margin_idx = np.argsort(margin)[:n_samples]
        
        return margin_idx
    
    def query_random_forest_uncertainty(self, n_samples: int = 10) -> np.ndarray:
        """Query based on random forest disagreement."""
        
        # Use base model's trees for disagreement
        if hasattr(self.base_model, 'estimators_'):
            # Get predictions from each tree
            tree_predictions = np.array([
                tree.predict_proba(self.X_unlabeled)[:, 1]
                for tree in self.base_model.estimators_
            ])
            
            # Calculate variance (disagreement)
            variance = np.var(tree_predictions, axis=0)
            
            # Select most disagreed samples
            uncertain_idx = np.argsort(variance)[-n_samples:]
            
            return uncertain_idx
        
        return self.query_uncertainty(n_samples)
    
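    def add_labels(self, X_new: np.ndarray, y_new: np.ndarray,
                   pool_idx=None):
        """Add newly labeled data.
        
        pool_idx: optional indices into X_unlabeled (as returned by the
        query_* methods) identifying the rows that were just labeled.
        """
        
        self.X_labeled = np.vstack([self.X_labeled, X_new])
        self.y_labeled = np.concatenate([self.y_labeled, y_new])
        
        # Remove the labeled rows from the unlabeled pool so they are not
        # queried again
        if pool_idx is not None:
            self.X_unlabeled = np.delete(self.X_unlabeled, pool_idx, axis=0)
        
        # Refit model
        self._fit_model()
        
        self.queries_made += len(X_new)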
    
    def should_query(self) -> bool:
        """Check if we should query more samples."""
        
        return (
            self.queries_made < self.query_budget and
            len(self.X_unlabeled) > 0
        )
    
    def get_statistics(self) -> dict:
        """Get active learning statistics."""
        
        return {
            'n_labeled': len(self.X_labeled),
            'n_unlabeled': len(self.X_unlabeled),
            'queries_made': self.queries_made,
            'query_budget': self.query_budget,
            'class_distribution': {
                'positive': int(np.sum(self.y_labeled == 1)),
                'negative': int(np.sum(self.y_labeled == 0))
            }
        }
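
To make the two scoring rules concrete, the sketch below computes entropy and margin for a few hand-written probability vectors in plain Python. The helper names and the example probabilities are illustrative and not part of `ActiveLearner`; the point is only that the two rules can select different samples.

```python
import math
from typing import List


def entropy(probs: List[float]) -> float:
    """Shannon entropy of a class-probability vector (higher = more uncertain)."""
    return -sum(p * math.log(p + 1e-7) for p in probs)


def margin(probs: List[float]) -> float:
    """Gap between the two most likely classes (smaller = more uncertain)."""
    top = sorted(probs, reverse=True)
    return top[0] - top[1]


pool = [
    [0.980, 0.010, 0.010],  # confident prediction
    [0.500, 0.495, 0.005],  # two classes nearly tied
    [0.340, 0.330, 0.330],  # probability spread over all classes
]

by_entropy = max(range(len(pool)), key=lambda i: entropy(pool[i]))
by_margin = min(range(len(pool)), key=lambda i: margin(pool[i]))

print("entropy picks sample", by_entropy)
print("margin picks sample", by_margin)
```

Entropy favors samples where the model is unsure across all classes, while margin favors samples sitting on the boundary between two specific classes, so the choice depends on which kind of confusion is more costly to leave unlabeled.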

Retraining Pipeline Orchestration

Complete Retraining Orchestrator

from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional
import json
import logging
from datetime import datetime, timedelta
import pandas as pd
import numpy as np

class RetrainingStrategy(Enum):
    SCHEDULED = "scheduled"
    TRIGGERED = "triggered"
    ONLINE = "online"
    ACTIVE_LEARNING = "active_learning"

@dataclass
class RetrainingJob:
    job_id: str
    strategy: RetrainingStrategy
    model_name: str
    trigger_reason: str
    priority: int
    created_at: datetime
    status: str = "pending"
    metadata: Optional[Dict] = None

class RetrainingOrchestrator:
    def __init__(self, config: dict):
        self.config = config
        self.job_queue = []
        self.active_jobs = {}
        self.completed_jobs = []
    
    def create_retraining_job(self, strategy: RetrainingStrategy,
                               model_name: str, trigger_reason: str,
                               priority: int = 3) -> RetrainingJob:
        """Create a new retraining job."""
        
        job = RetrainingJob(
            job_id=f"job_{datetime.now():%Y%m%d%H%M%S}",
            strategy=strategy,
            model_name=model_name,
            trigger_reason=trigger_reason,
            priority=priority,
            created_at=datetime.now(),
            metadata={}
        )
        
        self.job_queue.append(job)
        
        # Sort by priority
        self.job_queue.sort(key=lambda x: x.priority, reverse=True)
        
        return job
    
    def execute_job(self, job: RetrainingJob) -> dict:
        """Execute a retraining job."""
        
        job.status = "running"
        self.active_jobs[job.job_id] = job
        
        try:
            # Step 1: Prepare data
            data = self._prepare_training_data(job)
            
            # Step 2: Train model
            training_results = self._train_model(job, data)
            
            # Step 3: Evaluate model
            evaluation_results = self._evaluate_model(job, training_results)
            
            # Step 4: Deploy if successful
            if evaluation_results['passed']:
                self._deploy_model(job, training_results)
                job.status = "completed"
            else:
                job.status = "failed"
            
            result = {
                'job_id': job.job_id,
                'status': job.status,
                'training_results': training_results,
                'evaluation_results': evaluation_results
            }
            
        except Exception as e:
            job.status = "error"
            result = {
                'job_id': job.job_id,
                'status': 'error',
                'error': str(e)
            }
        
        finally:
            # Move to completed
            self.completed_jobs.append(job)
            if job.job_id in self.active_jobs:
                del self.active_jobs[job.job_id]
        
        return result
    
    def _prepare_training_data(self, job: RetrainingJob) -> pd.DataFrame:
        """Prepare training data based on strategy."""
        
        if job.strategy == RetrainingStrategy.SCHEDULED:
            # Use last N days of data
            end_date = datetime.now()
            start_date = end_date - timedelta(days=90)
            
            data = pd.read_parquet(
                f"s3://ml-data/training/{start_date:%Y%m%d}_{end_date:%Y%m%d}/"
            )
            
        elif job.strategy == RetrainingStrategy.TRIGGERED:
            # Use data from last retraining to now
            last_retrain = self._get_last_retrain_time(job.model_name)
            
            data = pd.read_parquet(
                f"s3://ml-data/training/{last_retrain:%Y%m%d}_{datetime.now():%Y%m%d}/"
            )
            
        elif job.strategy == RetrainingStrategy.ONLINE:
            # Use recent streaming data
            data = self._get_recent_streaming_data(hours=24)
            
        else:
            # Default: use last 30 days
            data = self._get_training_data(days=30)
        
        return data
    
    def _train_model(self, job: RetrainingJob, data: pd.DataFrame) -> dict:
        """Train model using appropriate method."""
        
        import xgboost as xgb
        from sklearn.model_selection import train_test_split
        import mlflow
        
        mlflow.set_experiment(f"retraining_{job.model_name}")
        
        with mlflow.start_run(run_name=job.job_id):
            X = data.drop(columns=['label'])
            y = data['label']
            
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            
            dtrain = xgb.DMatrix(X_train, label=y_train)
            dval = xgb.DMatrix(X_val, label=y_val)
            
            params = {
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'max_depth': 6,
                'learning_rate': 0.1,
            }
            
            model = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,
                evals=[(dval, 'val')],
                early_stopping_rounds=50
            )
            
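            # Score the validation split and log metrics. This split also
            # drove early stopping, so the AUC is somewhat optimistic; a
            # separate holdout set would give a fairer input to the quality gate.
            val_pred = model.predict(dval)
            from sklearn.metrics import roc_auc_score
            auc = roc_auc_score(y_val, val_pred)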
            
            mlflow.log_metric("auc_roc", auc)
            mlflow.log_param("strategy", job.strategy.value)
            mlflow.log_param("trigger_reason", job.trigger_reason)
            
            return {
                'model': model,
                'auc': auc,
                'training_samples': len(X_train),
                'validation_samples': len(X_val)
            }
    
    def _evaluate_model(self, job: RetrainingJob, 
                        training_results: dict) -> dict:
        """Evaluate model against quality gates."""
        
        thresholds = {
            'auc_roc': 0.90,
            'min_improvement': 0.01
        }
        
        # Get current production model performance
        current_metrics = self._get_production_metrics(job.model_name)
        
        new_auc = training_results['auc']
        current_auc = current_metrics.get('auc_roc', 0)
        
        # Check absolute threshold
        if new_auc < thresholds['auc_roc']:
            return {
                'passed': False,
                'reason': f"AUC {new_auc:.4f} below threshold {thresholds['auc_roc']}"
            }
        
        # Check improvement for triggered retraining
        if job.strategy == RetrainingStrategy.TRIGGERED:
            improvement = new_auc - current_auc
            if improvement < thresholds['min_improvement']:
                return {
                    'passed': False,
                    'reason': f"Improvement {improvement:.4f} below threshold"
                }
        
        return {
            'passed': True,
            'new_auc': new_auc,
            'current_auc': current_auc,
            'improvement': new_auc - current_auc
        }
    
    def _deploy_model(self, job: RetrainingJob, 
                      training_results: dict):
        """Deploy new model to production."""
        
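        import os
        import tempfile
        import fsspec  # assumption: fsspec + s3fs are installed (pandas needs them for s3:// reads)

        # Save model. Booster.save_model writes to the local filesystem, so
        # save to a temporary file first and then copy the bytes to S3.
        model_path = f"s3://ml-models/{job.model_name}/{job.job_id}/model.json"
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, "model.json")
            training_results['model'].save_model(local_path)
            with open(local_path, "rb") as src, fsspec.open(model_path, "wb") as dst:
                dst.write(src.read())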
        
        # Update model registry
        self._update_model_registry(
            job.model_name,
            job.job_id,
            training_results
        )
        
        # Trigger deployment
        self._trigger_deployment(job.model_name, job.job_id)
    
    def _get_last_retrain_time(self, model_name: str) -> datetime:
        """Get last retraining time for a model."""
        
        # Placeholder: a real implementation would query the model registry.
        # Here the last retrain is assumed to be 7 days ago.
        return datetime.now() - timedelta(days=7)
    
    def _get_production_metrics(self, model_name: str) -> dict:
        """Get current production model metrics."""
        
        # Placeholder: a real implementation would query the monitoring system.
        # The fixed value below is illustrative only.
        return {'auc_roc': 0.92}
    
    def _update_model_registry(self, model_name: str, 
                                version: str, metrics: dict):
        """Update model registry with new version."""
        
        pass
    
    def _trigger_deployment(self, model_name: str, version: str):
        """Trigger deployment pipeline."""
        
        pass
    
    def get_statistics(self) -> dict:
        """Get retraining statistics."""
        
        return {
            'queued_jobs': len(self.job_queue),
            'active_jobs': len(self.active_jobs),
            'completed_jobs': len(self.completed_jobs),
            'jobs_by_strategy': {
                strategy.value: len([
                    j for j in self.completed_jobs
                    if j.strategy == strategy
                ])
                for strategy in RetrainingStrategy
            }
        }

A robust retraining system usually combines several strategies: scheduled retraining for baseline freshness, triggered retraining to respond to drift, and online learning for real-time adaptation. Track both compute cost and model performance over time so the retraining schedule can be tuned against the measured benefit of each run.
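One way to combine the scheduled and triggered strategies is a small policy function that inspects monitoring signals and decides whether to enqueue a job. The sketch below is illustrative only: `MonitoringSnapshot`, `choose_retraining_action`, and every threshold except the 0.90 AUC floor (taken from the quality gate above) are assumptions, not part of the orchestrator shown earlier. Online learning is left out because it updates continuously rather than through queued jobs.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional


@dataclass
class MonitoringSnapshot:
    """Hypothetical monitoring inputs; field names are illustrative."""
    drift_score: float      # score reported by the drift monitor
    current_auc: float      # latest production AUC
    last_retrain: datetime  # when the model was last retrained


def choose_retraining_action(
    snapshot: MonitoringSnapshot,
    now: datetime,
    drift_threshold: float = 0.2,              # assumed value
    min_auc: float = 0.90,                     # same floor as the quality gate
    max_age: timedelta = timedelta(days=7),    # assumed schedule interval
    cooldown: timedelta = timedelta(hours=6),  # assumed minimum gap between runs
) -> Optional[str]:
    """Return "triggered", "scheduled", or None if no retraining is needed."""
    age = now - snapshot.last_retrain

    # A cooldown caps compute cost when drift alerts fire repeatedly.
    if age < cooldown:
        return None

    # Drift or a performance drop takes precedence over the schedule.
    if snapshot.drift_score > drift_threshold or snapshot.current_auc < min_auc:
        return "triggered"

    # Otherwise fall back to time-based retraining for baseline freshness.
    if age >= max_age:
        return "scheduled"

    return None
```

The cooldown is the cost lever here: lengthening it trades freshness for fewer training runs, while lowering `drift_threshold` does the opposite.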

Summary

Model retraining strategies include:

  1. Scheduled: Time-based retraining for predictable workloads
  2. Triggered: Event-driven retraining based on drift or performance decay
  3. Online: Continuous learning from streaming data
  4. Active Learning: Strategic sample selection for label efficiency

Choose the strategy that balances model freshness, computational cost, and labeling requirements for your use case.
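Where labeling cost dominates, the sample-selection step of active learning can be sketched in a few lines. The example below assumes a binary classifier and uses least-confidence (uncertainty) sampling, which is one common selection rule among several; the function name `select_for_labeling` and the scores are hypothetical.

```python
from typing import List, Sequence


def select_for_labeling(probabilities: Sequence[float], budget: int) -> List[int]:
    """Return indices of the `budget` predictions closest to 0.5.

    For a binary classifier, a predicted probability near 0.5 means the
    model is least certain, so a label for that example is assumed to be
    the most informative.
    """
    if budget <= 0:
        return []
    ranked = sorted(
        range(len(probabilities)),
        key=lambda i: abs(probabilities[i] - 0.5),
    )
    return ranked[:budget]


# Hypothetical predicted probabilities for five unlabeled examples.
scores = [0.98, 0.52, 0.10, 0.45, 0.70]
print(select_for_labeling(scores, budget=2))  # → [1, 3]
```

Only the selected examples are sent for labeling and added to the next training set, which is why this strategy suits cases where labels are scarce.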
