Interview Question (Hard) — Asked at: Uber, Netflix, Spotify, Apple, Tesla
"Design a model serving architecture that supports batch, real-time, and edge inference. How do you implement A/B testing and shadow mode deployments while maintaining low latency?"
Model Serving Architecture Overview
Model serving is the process of deploying trained ML models to production for inference. The serving pattern depends on latency requirements, throughput, cost constraints, and deployment environment.
Serving Pattern Decision Matrix
| Pattern | Latency | Throughput | Cost | Use Case |
|---|---|---|---|---|
| Batch Inference | Hours | Very High | Low | Report generation, ETL |
| Real-Time REST | 10-100ms | Medium | High | Web APIs, mobile |
| Real-Time gRPC | 1-10ms | High | High | Internal service-to-service scoring (ranking, fraud checks) |
| Edge Inference | 1-5ms | Low | Low | IoT, mobile offline |
| Streaming | 100ms-1s | High | Medium | Event-driven systems |
Batch Inference Patterns
PySpark Batch Inference
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, DoubleType
import mlflow
import mlflow.spark
class BatchInferencePipeline:
    """Score Parquet data with an MLflow Spark model and write results as Parquet.

    Two modes are offered: a full rescore of an input path
    (``run_batch_inference``) and an append-only incremental pass driven by a
    watermark column (``run_incremental_inference``).
    """

    def __init__(self, spark: "SparkSession", model_uri: str):
        self.spark = spark
        self.model = mlflow.spark.load_model(model_uri)
        # Stamped on every output row by run_batch_inference.
        self.model_version = self._model_version_from_uri(model_uri)

    @staticmethod
    def _model_version_from_uri(model_uri: str) -> str:
        """Return the last path segment of ``model_uri``.

        For ``models:/name/3`` this is the version number; for stage URIs such
        as ``models:/name/Production`` it is the stage name.
        """
        return model_uri.rstrip("/").split("/")[-1]

    def run_batch_inference(self, input_path: str,
                            output_path: str,
                            batch_size: int = 1000000):
        """Score every row under ``input_path`` and overwrite ``output_path``.

        Args:
            input_path: Parquet input location.
            output_path: Parquet output location, partitioned by ``batch_id``.
                Written with ``mode("overwrite")``, which (under Spark's default
                static partition-overwrite mode) can also remove a
                ``_metadata/watermark`` left there by the incremental mode.
            batch_size: Target rows per partition; the input is repartitioned
                into ``max(1, rows // batch_size)`` partitions.

        Returns:
            Dict with ``total_rows``, ``total_batches``, ``output_path`` and an
            ISO-8601 ``timestamp``.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        from datetime import datetime

        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        input_df = self.spark.read.parquet(input_path)
        # current_timestamp() is fixed for the duration of a query, so all rows
        # written by this run share one batch_id.
        input_df = input_df.withColumn(
            "batch_id",
            F.date_format(F.current_timestamp(), "yyyyMMdd_HHmmss"),
        )

        # count() launches a full Spark job: run it once and reuse the result.
        total_rows = input_df.count()
        num_partitions = max(1, total_rows // batch_size)
        input_df = input_df.repartition(num_partitions)

        predictions = (
            self.model.transform(input_df)
            .withColumn("inference_timestamp", F.current_timestamp())
            .withColumn("model_version", F.lit(self.model_version))
        )

        predictions.write \
            .mode("overwrite") \
            .partitionBy("batch_id") \
            .parquet(output_path)

        return {
            "total_rows": total_rows,
            "total_batches": num_partitions,
            "output_path": output_path,
            "timestamp": datetime.now().isoformat(),
        }

    def run_incremental_inference(self, input_path: str,
                                  output_path: str,
                                  watermark_col: str,
                                  interval_minutes: int = 60):
        """Score only rows newer than the stored watermark and append them.

        Rows are selected with a strict ``>`` on ``watermark_col``, so rows that
        arrive late with a watermark value at or before the stored one are not
        scored. The watermark advances only after the append has succeeded.

        Args:
            input_path: Parquet input location.
            output_path: Parquet output location (appended to).
            watermark_col: Monotonic column (e.g. event time) used for selection.
            interval_minutes: Accepted for API compatibility; not used here —
                scheduling is the caller's responsibility.
        """
        last_watermark = self._get_last_watermark(output_path)
        input_df = self.spark.read.parquet(input_path) \
            .filter(F.col(watermark_col) > last_watermark)

        # take(1) stops at the first matching row instead of counting them all.
        if not input_df.take(1):
            print("No new data to process")
            return

        new_watermark = input_df.agg(F.max(watermark_col)).collect()[0][0]

        predictions = self.model.transform(input_df)
        predictions.write \
            .mode("append") \
            .parquet(output_path)

        self._update_watermark(output_path, new_watermark)

    def _get_last_watermark(self, output_path: str):
        """Return the stored watermark, or ``"1970-01-01"`` if none exists yet.

        Only Spark's ``AnalysisException`` (raised e.g. when the watermark path
        does not exist) falls back to the epoch. Other failures propagate
        instead of silently restarting from the beginning, which would
        re-append rows that were already scored.
        """
        try:
            from pyspark.errors import AnalysisException
        except ImportError:  # pyspark < 3.4
            from pyspark.sql.utils import AnalysisException

        try:
            rows = self.spark.read.parquet(
                f"{output_path}/_metadata/watermark"
            ).collect()
        except AnalysisException:
            return "1970-01-01"
        if not rows:
            return "1970-01-01"
        return rows[0]["watermark"]

    def _update_watermark(self, output_path: str, watermark):
        """Overwrite the single-row watermark table under ``output_path``."""
        from datetime import datetime

        watermark_df = self.spark.createDataFrame(
            [(watermark, datetime.now().isoformat())],
            ["watermark", "updated_at"],
        )
        watermark_df.write \
            .mode("overwrite") \
            .parquet(f"{output_path}/_metadata/watermark")
Real-Time Serving with FastAPI
REST API Model Server
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import uvicorn
import mlflow
import numpy as np
import pandas as pd
from datetime import datetime
import asyncio
import logging
from contextlib import asynccontextmanager
import redis.asyncio as redis
import json
# Root logging at INFO and above; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class ModelServer:
    """Holds the serving model and an optional Redis-backed feature cache.

    The cache is best-effort. When Redis is not connected, or a cache call
    fails, reads return ``None`` and writes are skipped, so predictions keep
    working without it.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.model = None
        self.preprocessor = None  # placeholder; not set by this class
        self.model_version = None
        self.redis = None
        self.redis_url = redis_url
        self.feature_cache_ttl = 300  # seconds (5 minutes)

    async def load_model(self, model_uri: str):
        """Load ``model_uri`` with MLflow pyfunc and connect the Redis client.

        ``model_version`` is the URI's last path segment. That is a version
        number for ``models:/name/3``, but a stage name (e.g. ``Production``)
        for stage URIs. The MLflow load is synchronous and blocks the event
        loop while it runs.
        """
        self.model = mlflow.pyfunc.load_model(model_uri)
        self.model_version = model_uri.rstrip("/").split("/")[-1]
        self.redis = await redis.from_url(
            self.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
        self._log.info("Model loaded: %s", self.model_version)

    async def get_cached_features(self, entity_id: str) -> Optional[Dict]:
        """Return cached features for ``entity_id``, or ``None`` on miss/error."""
        if not self.redis:
            return None
        try:
            cached = await self.redis.get(f"features:{entity_id}")
            return json.loads(cached) if cached else None
        except Exception:
            # Deliberately best-effort: a cache outage or a corrupt entry must
            # degrade to "no cached features", not fail the prediction.
            self._log.warning("Feature cache read failed for %s",
                              entity_id, exc_info=True)
            return None

    async def set_cached_features(self, entity_id: str,
                                  features: Dict, ttl: Optional[int] = None):
        """Cache ``features`` for ``entity_id`` for ``ttl`` seconds.

        A falsy ``ttl`` falls back to ``feature_cache_ttl``. Write errors are
        logged and ignored.
        """
        if not self.redis:
            return
        try:
            await self.redis.setex(
                f"features:{entity_id}",
                ttl or self.feature_cache_ttl,
                json.dumps(features),
            )
        except Exception:
            self._log.warning("Feature cache write failed for %s",
                              entity_id, exc_info=True)
# Process-wide instance shared by the lifespan hook and the route handlers.
model_server = ModelServer()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving and release the Redis client on shutdown.

    The model URI is read from the ``MODEL_URI`` environment variable. It
    defaults to ``models:/fraud_detection/Production``. Cleanup runs in a
    ``finally`` block, so the Redis client is also closed when the
    application exits because of an exception.
    """
    import os

    model_uri = os.environ.get("MODEL_URI", "models:/fraud_detection/Production")
    await model_server.load_model(model_uri)
    try:
        yield
    finally:
        if model_server.redis:
            await model_server.redis.close()
app = FastAPI(
    title="ML Model Server",
    version="1.0.0",
    lifespan=lifespan,
)

# Any origin may call the API, but credentialed (cookie / auth) CORS requests
# are not enabled. Starlette documents that a wildcard origin cannot be
# combined with allow_credentials. Granting credentials to every origin would
# let any website make authenticated calls on a user's behalf. If browser
# credentials are ever needed, list the allowed origins explicitly and then
# enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PredictionRequest(BaseModel):
    """Body of ``POST /predict``: one entity and its feature values."""

    entity_id: str = Field(..., description="Unique entity identifier")
    features: Dict[str, float] = Field(
        ...,
        description="Feature values for prediction"
    )
    request_id: Optional[str] = Field(
        None,
        description="Optional request ID for tracking"
    )

    # Pydantic v2 configuration. ``json_schema_extra`` is the v2 key; the
    # inner ``class Config`` form is deprecated in v2.
    model_config = {
        "json_schema_extra": {
            "example": {
                "entity_id": "user_12345",
                "features": {
                    "transaction_amount": 150.00,
                    "time_since_last_transaction": 3600,
                    "merchant_category": 1,
                    "user_account_age_days": 365
                },
                "request_id": "req_abc123"
            }
        }
    }


class PredictionResponse(BaseModel):
    """Result for one entity; ``latency_ms`` is measured server-side."""

    prediction: float
    probability: float
    confidence: float  # max(probability, 1 - probability)
    model_version: str
    latency_ms: float
    request_id: Optional[str] = None
    timestamp: datetime


class BatchPredictionRequest(BaseModel):
    """Body of ``POST /predict/batch``."""

    instances: List[PredictionRequest]


class BatchPredictionResponse(BaseModel):
    """Batch results in request order, plus the whole-batch latency."""

    predictions: List[PredictionResponse]
    total_latency_ms: float
    batch_size: int
import time


def _split_model_output(row):
    """Turn one row of model output into ``(prediction, probability)``.

    A single-element row is treated as a positive-class score and returned as
    both values. A multi-element row is treated as per-class scores (e.g.
    ``predict_proba`` output). In that case the probability is element 1 (the
    positive class of a binary classifier) and the prediction is the arg-max
    class index.
    """
    values = np.asarray(row, dtype=float).ravel()
    if values.size == 1:
        score = float(values[0])
        return score, score
    return float(np.argmax(values)), float(values[1])


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Score a single entity.

    Features sent in the request are merged over any features cached for the
    entity. On a key conflict, the request's value wins. Any failure is logged
    and returned as HTTP 500 carrying the error text.
    """
    start = time.perf_counter()
    try:
        cached_features = await model_server.get_cached_features(
            request.entity_id
        )
        # Build a fresh dict so request.features is not mutated: cached values
        # only fill gaps, the caller's values take precedence.
        features = {**(cached_features or {}), **request.features}
        input_df = pd.DataFrame([features])

        # model.predict is synchronous and runs on the event-loop thread.
        raw_output = model_server.model.predict(input_df)
        prediction, probability = _split_model_output(np.asarray(raw_output)[0])
        confidence = max(probability, 1 - probability)

        latency_ms = (time.perf_counter() - start) * 1000
        logger.info("Prediction: %s, latency: %.2fms",
                    request.entity_id, latency_ms)

        return PredictionResponse(
            prediction=prediction,
            probability=probability,
            confidence=confidence,
            model_version=model_server.model_version,
            latency_ms=latency_ms,
            request_id=request.request_id,
            timestamp=datetime.now(),
        )
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_batch(request: BatchPredictionRequest):
    """Score many entities with one model call.

    Unlike ``/predict``, only the features sent in the request are used (no
    feature-cache lookup). Each item's ``latency_ms`` is the batch latency
    divided evenly across items. An empty batch returns an empty result.
    """
    start = time.perf_counter()
    try:
        instances = request.instances
        if not instances:
            return BatchPredictionResponse(
                predictions=[],
                total_latency_ms=(time.perf_counter() - start) * 1000,
                batch_size=0,
            )

        input_df = pd.DataFrame([inst.features for inst in instances])
        rows = np.asarray(model_server.model.predict(input_df))
        scored = [_split_model_output(row) for row in rows]

        total_latency_ms = (time.perf_counter() - start) * 1000
        per_item_latency = total_latency_ms / len(instances)

        responses = [
            PredictionResponse(
                prediction=prediction,
                probability=probability,
                confidence=max(probability, 1 - probability),
                model_version=model_server.model_version,
                latency_ms=per_item_latency,
                request_id=inst.request_id,
                timestamp=datetime.now(),
            )
            for inst, (prediction, probability) in zip(instances, scored)
        ]

        return BatchPredictionResponse(
            predictions=responses,
            total_latency_ms=total_latency_ms,
            batch_size=len(responses),
        )
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health():
    """Report liveness together with the currently loaded model version."""
    loaded = model_server.model is not None
    payload = {
        "status": "healthy",
        "model_loaded": loaded,
        "model_version": model_server.model_version,
        "timestamp": datetime.now(),
    }
    return payload
@app.get("/metrics")
async def metrics():
    """Placeholder metrics endpoint.

    Returns JSON (not the Prometheus text exposition format). The request
    counter and latency percentiles are hard-coded to zero; nothing is
    tracked yet.
    """
    not_tracked = 0
    return dict(
        model_version=model_server.model_version,
        requests_total=not_tracked,
        latency_p50=not_tracked,
        latency_p99=not_tracked,
    )
if __name__ == "__main__":
    # When executed as a script, serve on every interface at port 8000.
    _host, _port = "0.0.0.0", 8000
    uvicorn.run(app, host=_host, port=_port)
ℹ️
For sub-10ms latency, use gRPC instead of REST. NVIDIA Triton Inference Server provides optimized serving with dynamic batching, model ensembles, and multi-GPU support.
A/B Testing Framework
Statistical A/B Testing Implementation
import numpy as np
from scipy import stats
from typing import Dict, List, Tuple
from dataclasses import dataclass
from datetime import datetime, timedelta
import hashlib
import json
@dataclass
class ABTestConfig:
    """Configuration of one A/B test between a control and a treatment model.

    Raises:
        ValueError: if ``traffic_split`` is outside [0, 1] or
            ``significance_level`` is outside (0, 1).
    """

    name: str  # also salts the variant-assignment hash
    model_a: str  # control
    model_b: str  # treatment
    traffic_split: float  # fraction in [0, 1] of entities sent to B (0.1 = 10%)
    min_samples: int = 1000  # labelled results needed per variant before analysis
    significance_level: float = 0.05
    primary_metric: str = "conversion_rate"
    secondary_metrics: List[str] = None  # None means no secondary metrics
    duration_days: int = 7

    def __post_init__(self):
        # A percentage such as 10 passed by mistake would otherwise send every
        # entity to the treatment.
        if not 0.0 <= self.traffic_split <= 1.0:
            raise ValueError(
                f"traffic_split must be a fraction in [0, 1], got {self.traffic_split}"
            )
        if not 0.0 < self.significance_level < 1.0:
            raise ValueError(
                f"significance_level must be in (0, 1), got {self.significance_level}"
            )
class ABTestManager:
    """Collect per-variant results for an A/B test and analyse them.

    Results are kept in memory. Variant assignment is a deterministic hash of
    the test name and entity id, so an entity always lands in the same arm.
    """

    def __init__(self, config: "ABTestConfig"):
        self.config = config
        self.results_a = []  # result dicts logged for control (A)
        self.results_b = []  # result dicts logged for treatment (B)
        self.start_time = datetime.now()

    def assign_variant(self, entity_id: str) -> str:
        """Deterministically assign ``entity_id`` to ``"A"`` or ``"B"``.

        Buckets are whole percents (0-99). A ``traffic_split`` that is not a
        whole percent is therefore effectively rounded up to one.
        """
        hash_value = int(
            hashlib.md5(
                f"{self.config.name}:{entity_id}".encode()
            ).hexdigest(),
            16
        ) % 100
        return "B" if hash_value < self.config.traffic_split * 100 else "A"

    def log_result(self, variant: str, entity_id: str,
                   prediction: float, actual: Optional[float] = None):
        """Record one prediction; ``actual`` stays ``None`` until the outcome is known.

        Any variant other than ``"A"`` is stored as B.
        """
        result = {
            'entity_id': entity_id,
            'prediction': prediction,
            'actual': actual,
            'timestamp': datetime.now()
        }
        if variant == "A":
            self.results_a.append(result)
        else:
            self.results_b.append(result)

    def analyze_results(self) -> Dict:
        """Compare the arms on the primary metric using results with known outcomes.

        ``conversion_rate`` is tested with a chi-squared test on the 2x2
        success/failure table. Any other metric is treated as continuous and
        tested with Welch's t-test. The returned confidence interval is for
        ``metric_b - metric_a`` at the configured significance level. When
        either arm has fewer than ``max(min_samples, 2)`` completed results,
        an ``insufficient_data`` status dict is returned instead.
        """
        completed_a = [r for r in self.results_a if r['actual'] is not None]
        completed_b = [r for r in self.results_b if r['actual'] is not None]
        n_a, n_b = len(completed_a), len(completed_b)

        # At least two observations per arm are needed for a sample variance.
        required = max(self.config.min_samples, 2)
        if n_a < required or n_b < required:
            return {
                'status': 'insufficient_data',
                'samples_a': n_a,
                'samples_b': n_b,
                'required': required,
            }

        values_a = np.array([r['actual'] for r in completed_a], dtype=float)
        values_b = np.array([r['actual'] for r in completed_b], dtype=float)
        metric_a = float(values_a.mean())
        metric_b = float(values_b.mean())
        alpha = self.config.significance_level
        is_proportion = self.config.primary_metric == "conversion_rate"

        if is_proportion:
            successes_a = values_a.sum()
            successes_b = values_b.sum()
            contingency = np.array([
                [successes_a, n_a - successes_a],
                [successes_b, n_b - successes_b],
            ])
            if (contingency.sum(axis=0) == 0).any():
                # Every outcome is identical in both arms (all 0 or all 1):
                # there is no difference to test, and chi2_contingency rejects
                # tables with zero expected frequencies.
                p_value = 1.0
            else:
                _, p_value, _, _ = stats.chi2_contingency(contingency)
        else:
            # Welch's t-test (no equal-variance assumption), consistent with
            # the unpooled standard error used for the interval below.
            _, p_value = stats.ttest_ind(values_a, values_b, equal_var=False)
        p_value = float(p_value)

        # Two-sided interval for B - A at the configured significance level.
        z = float(stats.norm.ppf(1 - alpha / 2))
        diff = metric_b - metric_a
        se_diff = float(np.sqrt(values_a.var(ddof=1) / n_a
                                + values_b.var(ddof=1) / n_b))
        ci_lower = diff - z * se_diff
        ci_upper = diff + z * se_diff

        # NaN p-values (e.g. zero variance in both arms) compare False here.
        significant = bool(p_value < alpha)
        winner = None
        if significant:
            winner = "B" if metric_b > metric_a else "A"

        lift = diff / abs(metric_a) * 100 if metric_a != 0 else 0.0

        # The sample-size formula is for proportions only.
        required_n = (
            self._calculate_required_sample_size(metric_a, effect_size=0.05)
            if is_proportion else None
        )

        return {
            'status': 'completed' if significant else 'running',
            'significant': significant,
            'winner': winner,
            'p_value': p_value,
            'metric_a': metric_a,
            'metric_b': metric_b,
            'lift': lift,
            'confidence_interval': [ci_lower, ci_upper],
            'samples_a': n_a,
            'samples_b': n_b,
            'duration_days': (datetime.now() - self.start_time).days,
            'required_samples_per_variant': required_n,
        }

    def _calculate_required_sample_size(self, baseline_rate: float,
                                        effect_size: float,
                                        power: float = 0.8) -> Optional[int]:
        """Per-variant sample size to detect a relative lift of ``effect_size``.

        This is a two-sided two-proportion test at ``significance_level`` with
        the given ``power``. Returns ``None`` when the formula is undefined:
        baseline or target rate outside (0, 1), or zero effect.
        """
        alpha = self.config.significance_level
        p1 = baseline_rate
        p2 = baseline_rate * (1 + effect_size)
        if not (0 < p1 < 1 and 0 < p2 < 1) or p1 == p2:
            return None
        pooled_p = (p1 + p2) / 2
        z_alpha = stats.norm.ppf(1 - alpha / 2)
        z_beta = stats.norm.ppf(power)
        numerator = (z_alpha * np.sqrt(2 * pooled_p * (1 - pooled_p)) +
                     z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) ** 2
        denominator = (p2 - p1) ** 2
        return int(np.ceil(numerator / denominator))

    def should_stop_early(self, p_value_threshold: float = 0.001) -> Tuple[bool, str]:
        """Decide whether the test should stop now.

        Stops when the p-value is below ``p_value_threshold``, or once the
        elapsed time reaches ``duration_days``. The elapsed time is measured
        directly, so the deadline also applies while data is still
        insufficient. Note that repeatedly checking significance ("peeking")
        inflates the false-positive rate; keep the threshold strict.
        """
        results = self.analyze_results()
        p_value = results.get('p_value')
        if p_value is not None and p_value < p_value_threshold:
            return True, "Highly significant result detected"

        elapsed_days = (datetime.now() - self.start_time).days
        if elapsed_days >= self.config.duration_days:
            return True, "Maximum duration reached"

        return False, "Continue test"
A/B Testing Traffic Management
# kubernetes/ab-testing.yaml
# NOTE: weighted routing picks a subset per request, so the same user can be
# served by model-a on one call and model-b on the next. If the experiment
# needs a stable per-user assignment (deterministic hashing of the user id),
# have the gateway set a header from that assignment and add a header `match`
# route ahead of the weighted route below.
apiVersion: networking.istio.io/v1beta1
kind: VirtualService
metadata:
  name: model-serving-vs
spec:
  hosts:
    - model-serving
  http:
    - route:
        - destination:
            host: model-serving
            subset: model-a
          weight: 90
        - destination:
            host: model-serving
            subset: model-b
          weight: 10
      retries:
        attempts: 3
        perTryTimeout: 2s
      timeout: 10s
---
apiVersion: networking.istio.io/v1beta1
kind: DestinationRule
metadata:
  name: model-serving-dr
spec:
  host: model-serving
  subsets:
    # Control: pods labelled model-version=v1.
    - name: model-a
      labels:
        model-version: v1
      trafficPolicy:
        connectionPool:
          tcp:
            maxConnections: 100
          http:
            h2UpgradePolicy: DEFAULT
            http1MaxPendingRequests: 100
            http2MaxRequests: 1000
    # Treatment: pods labelled model-version=v2, with half the connection
    # budget to match its smaller traffic share.
    - name: model-b
      labels:
        model-version: v2
      trafficPolicy:
        connectionPool:
          tcp:
            maxConnections: 50
          http:
            h2UpgradePolicy: DEFAULT
            http1MaxPendingRequests: 50
            http2MaxRequests: 500
Shadow Mode Deployment
Shadow Mode Implementation
import asyncio
from typing import Optional
import time
class ShadowModeServer:
    """Serve a primary model while mirroring requests to a shadow model.

    Callers only ever receive the primary model's prediction. For a sampled
    share of requests, the shadow model is scored in the background on a
    dedicated single worker thread. Both outputs and their latencies are
    recorded for ``get_comparison_metrics``.

    Args:
        primary_model: Object with ``predict(DataFrame)`` whose output is
            returned. Its ``version`` attribute is reported if present.
        shadow_model: Candidate model with the same ``predict`` interface.
        shadow_sample_rate: Fraction in [0, 1] of requests mirrored to the
            shadow model.
        max_history: Maximum comparisons retained (oldest dropped first);
            ``None`` keeps everything.
        max_pending_shadows: Shadow runs allowed in flight at once. Beyond
            that, requests skip the shadow instead of queueing behind it.
    """

    def __init__(self, primary_model, shadow_model,
                 shadow_sample_rate: float = 1.0,
                 max_history: Optional[int] = 10_000,
                 max_pending_shadows: int = 1000):
        from collections import deque
        from concurrent.futures import ThreadPoolExecutor

        if not 0.0 <= shadow_sample_rate <= 1.0:
            raise ValueError(
                f"shadow_sample_rate must be in [0, 1], got {shadow_sample_rate}"
            )
        self.primary_model = primary_model
        self.shadow_model = shadow_model
        self.shadow_sample_rate = shadow_sample_rate
        self.max_pending_shadows = max_pending_shadows
        self.shadow_predictions = []  # not populated here; kept for compatibility
        self.primary_predictions = []  # not populated here; kept for compatibility
        # Bounded so a long-running server does not grow memory without limit.
        self.comparison_results = deque(maxlen=max_history)
        # The event loop keeps only weak references to tasks, so fire-and-forget
        # shadow tasks are held here until they finish.
        self._shadow_tasks = set()
        # A single worker keeps shadow scoring off the event loop and ensures
        # the shadow model is never called from two threads at once.
        self._shadow_executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="shadow-model"
        )

    async def predict_with_shadow(self, features: dict) -> dict:
        """Return the primary prediction and schedule a shadow comparison.

        The response never waits for, or depends on, the shadow model.
        """
        import random

        start_primary = time.perf_counter()
        primary_prediction = await self._run_prediction(
            self.primary_model, features
        )
        primary_latency = (time.perf_counter() - start_primary) * 1000

        sampled = (self.shadow_sample_rate >= 1.0
                   or random.random() < self.shadow_sample_rate)
        if sampled and len(self._shadow_tasks) < self.max_pending_shadows:
            # Snapshot the features: the shadow runs later, on another thread.
            task = asyncio.create_task(
                self._run_shadow(dict(features), primary_prediction, primary_latency)
            )
            self._shadow_tasks.add(task)
            task.add_done_callback(self._on_shadow_done)

        return {
            'prediction': primary_prediction,
            'latency_ms': primary_latency,
            'model_version': getattr(self.primary_model, 'version', None),
        }

    async def wait_for_shadows(self) -> None:
        """Wait for every shadow run started so far (e.g. before shutdown)."""
        if self._shadow_tasks:
            await asyncio.gather(*list(self._shadow_tasks), return_exceptions=True)

    async def _run_shadow(self, features: dict, primary_prediction: float,
                          primary_latency: float) -> float:
        """Score the shadow model on the worker thread and record the comparison."""
        loop = asyncio.get_running_loop()
        start_shadow = time.perf_counter()
        shadow_prediction = await loop.run_in_executor(
            self._shadow_executor, self._predict_sync, self.shadow_model, features
        )
        shadow_latency = (time.perf_counter() - start_shadow) * 1000
        self.comparison_results.append({
            'primary_prediction': primary_prediction,
            'shadow_prediction': shadow_prediction,
            'primary_latency': primary_latency,
            'shadow_latency': shadow_latency,
            'timestamp': time.time(),
        })
        return shadow_prediction

    def _on_shadow_done(self, task) -> None:
        """Drop the finished task and log (thereby retrieve) any shadow error."""
        import logging

        self._shadow_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logging.getLogger(__name__).warning(
                "Shadow prediction failed", exc_info=exc
            )

    @staticmethod
    def _predict_sync(model, features: dict) -> float:
        """Run ``model.predict`` on a one-row DataFrame; return the first output."""
        import pandas as pd
        input_df = pd.DataFrame([features])
        prediction = model.predict(input_df)
        return float(prediction[0])

    async def _run_prediction(self, model, features: dict) -> float:
        """Run ``model`` synchronously on the event-loop thread (primary path)."""
        return self._predict_sync(model, features)

    def get_comparison_metrics(self, agreement_tol: float = 0.0) -> dict:
        """Summarise primary-vs-shadow agreement, correlation and latency.

        Args:
            agreement_tol: Two predictions "agree" when they differ by at most
                this much. The default 0.0 means exact equality; pass a small
                tolerance when comparing probabilities.

        Returns:
            ``{'status': 'no_data'}`` before any comparison exists. Otherwise a
            dict of metrics. ``correlation`` is NaN when fewer than two
            comparisons exist or either series is constant.
        """
        import numpy as np

        if not self.comparison_results:
            return {'status': 'no_data'}

        results = list(self.comparison_results)
        primary = np.array([r['primary_prediction'] for r in results], dtype=float)
        shadow = np.array([r['shadow_prediction'] for r in results], dtype=float)
        abs_diff = np.abs(primary - shadow)

        if len(results) < 2 or primary.std() == 0 or shadow.std() == 0:
            correlation = float("nan")
        else:
            correlation = float(np.corrcoef(primary, shadow)[0, 1])

        return {
            'agreement_rate': float(np.mean(abs_diff <= agreement_tol)),
            'correlation': correlation,
            'primary_avg_latency': float(np.mean([r['primary_latency'] for r in results])),
            'shadow_avg_latency': float(np.mean([r['shadow_latency'] for r in results])),
            'total_comparisons': len(results),
            'mean_absolute_difference': float(np.mean(abs_diff)),
        }
# Shadow mode Kubernetes deployment (Deployment + Service for the v2 shadow).
# Keep the `mode: shadow` label distinct: a primary Service that selected only
# `app: ml-model` would also send live traffic to these pods.
SHADOW_DEPLOYMENT = """
apiVersion: apps/v1
kind: Deployment
metadata:
  name: model-shadow
spec:
  replicas: 2
  selector:
    matchLabels:
      app: ml-model
      mode: shadow
  template:
    metadata:
      labels:
        app: ml-model
        mode: shadow
    spec:
      containers:
        - name: model-server
          image: registry.example.com/ml-model:v2
          ports:
            - containerPort: 8080
          env:
            - name: MODEL_MODE
              value: "shadow"
            - name: PRIMARY_MODEL_URL
              value: "http://model-primary:8080"
          resources:
            requests:
              memory: "2Gi"
              cpu: "1000m"
            limits:
              memory: "4Gi"
              cpu: "2000m"
---
apiVersion: v1
kind: Service
metadata:
  name: model-shadow
spec:
  selector:
    app: ml-model
    mode: shadow
  ports:
    - port: 8080
      targetPort: 8080
"""
⚠️
Shadow mode deployments double your inference load. Monitor resource usage carefully and set appropriate resource limits. Consider sampling (10-20% of traffic) for cost optimization.
Edge Deployment Patterns
TensorFlow Lite Edge Deployment
import tensorflow as tf
import numpy as np
from typing import Dict
class EdgeModelConverter:
    """Convert a saved Keras model to edge formats (TFLite, ONNX) and emit a C++ runner."""

    def __init__(self, model_path: str):
        self.model = tf.keras.models.load_model(model_path)

    def convert_to_tflite(self, quantize: bool = True,
                          optimize: bool = True,
                          output_path: str = "model_edge.tflite") -> bytes:
        """Convert the model to TensorFlow Lite, write it to ``output_path``.

        Args:
            quantize: Apply post-training float16 quantization. This requires
                the default optimization, which is therefore enabled whenever
                ``quantize`` is true, even if ``optimize`` is false. No
                representative dataset is needed for float16.
            optimize: Enable ``tf.lite.Optimize.DEFAULT``.
            output_path: Where the ``.tflite`` flatbuffer is written.

        Returns:
            The serialized TFLite model.
        """
        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
        if optimize or quantize:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if quantize:
            converter.target_spec.supported_types = [tf.float16]

        tflite_model = converter.convert()

        with open(output_path, 'wb') as f:
            f.write(tflite_model)
        print(f"Model converted: {len(tflite_model) / 1024:.1f} KB")
        return tflite_model

    def convert_to_onnx(self, output_path: str = "model.onnx"):
        """Convert the model to ONNX via tf2onnx and return ``output_path``.

        Assumes a single float32 input, whose shape is taken from
        ``model.input_shape``.
        """
        import tf2onnx

        spec = (tf.TensorSpec(self.model.input_shape, tf.float32, name="input"),)
        tf2onnx.convert.from_keras(
            self.model, input_signature=spec, output_path=output_path
        )
        print(f"ONNX model saved: {output_path}")
        return output_path

    def create_edge_inference_code(self, tflite_path: str,
                                   output_path: str = "edge_inference.cpp") -> str:
        """Write a small C++ TFLite runner to ``output_path`` and return the path.

        ``tflite_path`` becomes the default model path of the generated
        ``EdgeModel`` constructor. The generated ``Predict`` scores one example
        (input resized to ``{1, n}``). It reads input and output through the
        interpreter's input/output tensor accessors, and re-allocates tensors
        only when the input length changes.
        """
        # Escape for a C++ string literal.
        escaped_path = tflite_path.replace("\\", "\\\\").replace('"', '\\"')
        template = """// edge_inference.cpp (generated by EdgeModelConverter)
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <vector>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

class EdgeModel {
 public:
  explicit EdgeModel(const char* model_path = "@MODEL_PATH@") {
    model_ = tflite::FlatBufferModel::BuildFromFile(model_path);
    if (!model_) {
      throw std::runtime_error("failed to load TFLite model");
    }
    tflite::InterpreterBuilder(*model_, resolver_)(&interpreter_);
    if (!interpreter_) {
      throw std::runtime_error("failed to build TFLite interpreter");
    }
  }

  // Runs one forward pass on a single example of input.size() features.
  std::vector<float> Predict(const std::vector<float>& input) {
    const int n = static_cast<int>(input.size());
    // AllocateTensors() is expensive: only resize when the length changes.
    if (n != allocated_size_) {
      if (interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1, n}) != kTfLiteOk ||
          interpreter_->AllocateTensors() != kTfLiteOk) {
        throw std::runtime_error("failed to allocate TFLite tensors");
      }
      allocated_size_ = n;
    }

    float* input_ptr = interpreter_->typed_input_tensor<float>(0);
    std::copy(input.begin(), input.end(), input_ptr);

    if (interpreter_->Invoke() != kTfLiteOk) {
      throw std::runtime_error("TFLite inference failed");
    }

    const float* output_ptr = interpreter_->typed_output_tensor<float>(0);
    const size_t output_size = interpreter_->output_tensor(0)->bytes / sizeof(float);
    return std::vector<float>(output_ptr, output_ptr + output_size);
  }

 private:
  // Declared first so it outlives the interpreter built from it.
  tflite::ops::builtin::BuiltinOpResolver resolver_;
  std::unique_ptr<tflite::FlatBufferModel> model_;
  std::unique_ptr<tflite::Interpreter> interpreter_;
  int allocated_size_ = -1;
};
"""
        cpp_code = template.replace("@MODEL_PATH@", escaped_path)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(cpp_code)
        return output_path
Model Serving with NVIDIA Triton
Triton Configuration
# model_repository/fraud_detection/config.pbtxt
# Ensemble: Triton runs "preprocessing" then "xgboost_model" server-side, so a
# client sends one request with INPUT and receives OUTPUT.
name: "fraud_detection"
platform: "ensemble"
max_batch_size: 64
input [
  {
    name: "INPUT"
    data_type: TYPE_FP32
    dims: [ 128 ]
  }
]
output [
  {
    name: "OUTPUT"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "preprocessing"
      model_version: -1  # latest version
      input_map {
        key: "INPUT"
        value: "INPUT"
      }
      output_map {
        key: "OUTPUT"
        value: "preprocessed"
      }
    },
    {
      model_name: "xgboost_model"
      model_version: -1  # latest version
      input_map {
        key: "INPUT"
        value: "preprocessed"
      }
      output_map {
        key: "OUTPUT"
        value: "OUTPUT"
      }
    }
  ]
}

# dynamic_batching and instance_group do not belong in the ensemble config.
# In ModelConfig, dynamic_batching and ensemble_scheduling are alternatives of
# the same scheduling oneof. Instances are created for the models that
# actually execute. Put these settings in the composing models instead, e.g.
# model_repository/xgboost_model/config.pbtxt:
#
# dynamic_batching {
#   preferred_batch_size: [ 16, 32, 64 ]
#   max_queue_delay_microseconds: 100
# }
# instance_group [
#   {
#     count: 2
#     kind: KIND_GPU
#     gpus: [ 0 ]
#   }
# ]
Triton Client
import tritonclient.grpc as grpcclient
import numpy as np
from typing import Dict, List
import time
class TritonClient:
    """Synchronous wrapper around Triton's gRPC ``InferenceServerClient``.

    Requests send one FP32 input named ``INPUT`` and read one output named
    ``OUTPUT``.
    """

    def __init__(self, server_url: str = "localhost:8001"):
        self.client = grpcclient.InferenceServerClient(url=server_url)

    @staticmethod
    def _as_batch(data) -> np.ndarray:
        """Return ``data`` as contiguous float32, adding a batch axis to 1-D input."""
        arr = np.ascontiguousarray(data, dtype=np.float32)
        return arr[np.newaxis, ...] if arr.ndim == 1 else arr

    def _infer(self, batch: np.ndarray, model_name: str):
        """Send one inference request; return ``(OUTPUT array, latency_ms)``."""
        inputs = [grpcclient.InferInput("INPUT", list(batch.shape), "FP32")]
        inputs[0].set_data_from_numpy(batch)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT")]

        start = time.perf_counter()
        response = self.client.infer(
            model_name=model_name,
            inputs=inputs,
            outputs=outputs,
        )
        latency_ms = (time.perf_counter() - start) * 1000
        return response.as_numpy("OUTPUT"), latency_ms

    def predict(self, input_data: np.ndarray,
                model_name: str = "fraud_detection") -> Dict:
        """Score one example and return its first output value.

        ``input_data`` may be a single 1-D example (a batch axis is added) or
        an already batched array. Values are cast to float32 to match FP32.
        """
        output, latency_ms = self._infer(self._as_batch(input_data), model_name)
        return {
            'prediction': float(np.ravel(output)[0]),
            'latency_ms': latency_ms,
            'model_name': model_name,
        }

    def batch_predict(self, batch_data: np.ndarray,
                      model_name: str = "fraud_detection") -> List[Dict]:
        """Score a batch (first axis = examples) in one request.

        Each result carries the first output value of its row. ``latency_ms``
        is the request latency split evenly across the batch. An empty batch
        returns ``[]`` without contacting the server.
        """
        batch = np.ascontiguousarray(batch_data, dtype=np.float32)
        batch_size = batch.shape[0]
        if batch_size == 0:
            return []

        output, total_latency_ms = self._infer(batch, model_name)
        rows = np.asarray(output).reshape(batch_size, -1)
        per_item_latency = total_latency_ms / batch_size
        return [
            {
                'prediction': float(rows[i, 0]),
                'latency_ms': per_item_latency,
                'batch_index': i,
            }
            for i in range(batch_size)
        ]

    def get_model_metadata(self, model_name: str) -> Dict:
        """Return name, versions, platform and tensor specs for ``model_name``.

        Triton reports ``versions`` (all available version strings). Triton
        version directories are integers, so ``version`` is the highest of
        them, or ``None`` if none is reported.
        """
        metadata = self.client.get_model_metadata(model_name=model_name)
        versions = list(metadata.versions)

        def _tensors(specs):
            return [
                {'name': s.name, 'datatype': s.datatype, 'shape': list(s.shape)}
                for s in specs
            ]

        return {
            'name': metadata.name,
            'version': max(versions, key=int) if versions else None,
            'versions': versions,
            'platform': metadata.platform,
            'inputs': _tensors(metadata.inputs),
            'outputs': _tensors(metadata.outputs),
        }
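ℹ️
NVIDIA Triton supports dynamic batching, model ensembles, and multi-GPU inference. Use it for high-throughput, low-latency serving. Dynamic batching deliberately adds a small, bounded queueing delay (`max_queue_delay_microseconds`) in exchange for throughput, so budget for that delay rather than expecting sub-millisecond end-to-end latency.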
Summary
Model serving patterns depend on your requirements:
- Batch Inference: PySpark for large-scale offline processing
- Real-Time REST: FastAPI/Flask for synchronous web APIs
- Real-Time gRPC: Triton for high-throughput, low-latency serving
- A/B Testing: Statistical frameworks with traffic splitting
- Shadow Mode: Validate new models without affecting users
- Edge Deployment: TFLite/ONNX for on-device inference
Choose the pattern that matches your latency, throughput, and cost requirements.