Real-Time ML Inference: Serving Models at Scale
Build low-latency, high-throughput ML inference systems with optimized model serving, caching strategies, and scalable architecture patterns.
The Real-Time Inference Challenge
While training machine learning models can take hours or days, production inference often requires sub-second responses—sometimes sub-100ms. Real-time ML powers recommendation engines, fraud detection, dynamic pricing, and personalization systems where every millisecond of latency impacts user experience and business outcomes.
Building production inference systems that serve millions of predictions per day with consistent low latency requires careful architecture design, model optimization, and infrastructure engineering.
Latency Requirements by Use Case
Latency Tiers
Ultra-Low Latency (< 10ms)
- Ad bidding and serving
- High-frequency trading signals
- Real-time fraud detection (payment processing)
- Game AI decisions
Low Latency (10-100ms)
- E-commerce product recommendations
- Search result ranking
- Chatbot responses
- Content personalization
Interactive (100-500ms)
- Credit scoring
- Risk assessment
- Lead scoring
- Dynamic pricing
Batch-Friendly (> 500ms)
- Email campaign targeting
- Churn prediction
- Customer segmentation updates
- Long-form content generation
Model Optimization for Inference
1. Model Quantization
Reduce Model Precision
Convert 32-bit floating point to 8-bit integers:
import tensorflow as tf
# Convert TensorFlow model to TFLite with quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Post-training quantization
tflite_quantized_model = converter.convert()
# Save quantized model
with open('model_quantized.tflite', 'wb') as f:
    f.write(tflite_quantized_model)
Benefits:
- 4x smaller model size
- 2-4x faster inference
- Lower memory footprint
- Minimal accuracy loss (typically < 1%)
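To verify the quantized model end to end, here is a minimal sketch of running it through the TFLite interpreter; the file path matches the example above, while the random input is a stand-in for real feature data:

import numpy as np
import tensorflow as tf

# Load the quantized model saved above
interpreter = tf.lite.Interpreter(model_path='model_quantized.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a dummy input with the expected shape and dtype, then read the prediction
sample = np.random.rand(*input_details[0]['shape']).astype(input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
prediction = interpreter.get_tensor(output_details[0]['index'])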
PyTorch Quantization
import torch
# Dynamic quantization (easiest)
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)
# Static quantization (better performance, requires calibration)
model.eval()  # Static quantization expects the model in eval mode
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# Calibrate on representative data
for data in calibration_data:
    model(data)
torch.quantization.convert(model, inplace=True)
2. Model Pruning
Remove Unnecessary Parameters
import tensorflow_model_optimization as tfmot
# Define pruning schedule
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.5,  # Remove 50% of weights
    begin_step=0,
    end_step=1000
)
# Apply pruning
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=pruning_schedule
)
# Train pruned model
pruned_model.compile(optimizer='adam', loss='mse')
pruned_model.fit(
    X_train, y_train, epochs=10,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]  # Required when training pruned models
)
# Strip pruning wrappers for deployment
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
Expected Results:
- 30-50% smaller model
- 20-40% faster inference
- Minimal accuracy degradation
3. Knowledge Distillation
Train Smaller Model from Larger “Teacher” Model
import tensorflow as tf

class DistillationTraining:
    def __init__(self, teacher_model, student_model, temperature=3.0, alpha=0.5):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha  # Weight between hard and soft targets
        self.optimizer = tf.keras.optimizers.Adam()

    def distillation_loss(self, y_true, y_pred, teacher_pred):
        """
        Combined loss: true labels + teacher predictions
        """
        # Hard loss (true labels)
        hard_loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        # Soft loss (teacher predictions, softened by temperature)
        soft_pred = tf.nn.softmax(y_pred / self.temperature)
        soft_teacher = tf.nn.softmax(teacher_pred / self.temperature)
        soft_loss = tf.keras.losses.categorical_crossentropy(soft_teacher, soft_pred)
        # Combine and average over the batch
        return tf.reduce_mean(self.alpha * hard_loss + (1 - self.alpha) * soft_loss)

    def train(self, X_train, y_train, epochs=10):
        # Get teacher predictions once up front
        teacher_predictions = self.teacher.predict(X_train)
        # Train student
        for epoch in range(epochs):
            # Training step
            with tf.GradientTape() as tape:
                student_predictions = self.student(X_train, training=True)
                loss = self.distillation_loss(y_train, student_predictions, teacher_predictions)
            gradients = tape.gradient(loss, self.student.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
Typical Results:
- Student model 10-100x smaller
- 10-50x faster inference
- Retains 90-98% of teacher accuracy
4. ONNX Runtime Optimization
Convert to Optimized Format
import torch
import onnx
import onnxruntime as ort
# Export PyTorch model to ONNX (input_size is the number of input features)
dummy_input = torch.randn(1, input_size)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
# Load with ONNX Runtime (optimized inference)
session = ort.InferenceSession(
    "model.onnx",
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
# Inference
def predict(input_data):
    ort_inputs = {session.get_inputs()[0].name: input_data}
    ort_outputs = session.run(None, ort_inputs)
    return ort_outputs[0]
Performance Gains:
- 2-5x faster than native frameworks
- Optimized graph execution
- Cross-platform compatibility
Serving Architecture Patterns
Pattern 1: REST API Serving
FastAPI Implementation
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
import time

app = FastAPI()

# Load model at startup
model = joblib.load('model.pkl')
scaler = joblib.load('scaler.pkl')

class PredictionRequest(BaseModel):
    features: list[float]

class PredictionResponse(BaseModel):
    prediction: float
    confidence: float
    latency_ms: float

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    start_time = time.time()
    try:
        # Preprocess
        features = np.array(request.features).reshape(1, -1)
        features_scaled = scaler.transform(features)
        # Predict
        prediction = model.predict(features_scaled)[0]
        confidence = model.predict_proba(features_scaled).max()
        latency_ms = (time.time() - start_time) * 1000
        return PredictionResponse(
            prediction=float(prediction),
            confidence=float(confidence),
            latency_ms=latency_ms
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Health check
@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}
Pattern 2: gRPC for Lower Latency
Protocol Buffer Definition
syntax = "proto3";
service MLInference {
  rpc Predict (PredictRequest) returns (PredictResponse);
}

message PredictRequest {
  repeated float features = 1;
}

message PredictResponse {
  float prediction = 1;
  float confidence = 2;
}
gRPC Server
import grpc
from concurrent import futures
import numpy as np

import ml_inference_pb2
import ml_inference_pb2_grpc

class MLInferenceServicer(ml_inference_pb2_grpc.MLInferenceServicer):
    def __init__(self, model):
        self.model = model

    def Predict(self, request, context):
        features = np.array(request.features).reshape(1, -1)
        prediction = self.model.predict(features)[0]
        confidence = self.model.predict_proba(features).max()
        return ml_inference_pb2.PredictResponse(
            prediction=float(prediction),
            confidence=float(confidence)
        )

# Start server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
ml_inference_pb2_grpc.add_MLInferenceServicer_to_server(
    MLInferenceServicer(model),
    server
)
server.add_insecure_port('[::]:50051')
server.start()
server.wait_for_termination()
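On the client side, a call through the stubs generated from the proto above looks like this sketch; the channel address mirrors the insecure port the server binds, and the feature values are placeholders:

import grpc

import ml_inference_pb2
import ml_inference_pb2_grpc

# Connect to the server started above and issue a single prediction
channel = grpc.insecure_channel('localhost:50051')
stub = ml_inference_pb2_grpc.MLInferenceStub(channel)

response = stub.Predict(ml_inference_pb2.PredictRequest(features=[1.0, 2.0, 3.0]))
print(response.prediction, response.confidence)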
Latency Comparison:
- REST API: 10-50ms
- gRPC: 2-10ms (5-10x faster)
Pattern 3: Batch Inference
Micro-Batching for Throughput
import asyncio
from collections import deque

import numpy as np

class BatchPredictor:
    def __init__(self, model, batch_size=32, max_wait_ms=10):
        self.model = model
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()
        self.lock = asyncio.Lock()

    async def predict(self, features):
        """
        Add request to batch queue and wait for result
        """
        # Create future for this request
        future = asyncio.get_running_loop().create_future()
        async with self.lock:
            self.queue.append((features, future))
            # Trigger batch processing if queue is full
            if len(self.queue) >= self.batch_size:
                await self._process_batch()
        # Wait for result; shield the future so a timeout doesn't cancel it
        try:
            return await asyncio.wait_for(asyncio.shield(future), timeout=self.max_wait_ms / 1000)
        except asyncio.TimeoutError:
            # Process partial batch
            async with self.lock:
                await self._process_batch()
            return await future

    async def _process_batch(self):
        """
        Process accumulated requests as a batch (callers hold self.lock)
        """
        if not self.queue:
            return
        # Collect batch
        batch = []
        futures = []
        while self.queue and len(batch) < self.batch_size:
            features, future = self.queue.popleft()
            batch.append(features)
            futures.append(future)
        # Batch inference
        batch_array = np.array(batch)
        predictions = self.model.predict(batch_array)
        # Distribute results
        for future, prediction in zip(futures, predictions):
            if not future.done():
                future.set_result(prediction)

# Start background batch processor
async def batch_processor(predictor):
    while True:
        await asyncio.sleep(predictor.max_wait_ms / 1000)
        async with predictor.lock:
            await predictor._process_batch()
Throughput Improvement:
- Single requests: 100 predictions/sec
- Batched (32): 1,000-2,000 predictions/sec (10-20x)
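Wiring the BatchPredictor into an asyncio service looks roughly like the sketch below; the model object and the random feature vectors are placeholders:

import asyncio

import numpy as np

async def main(model):
    predictor = BatchPredictor(model, batch_size=32, max_wait_ms=10)
    # Flush partial batches in the background
    asyncio.create_task(batch_processor(predictor))
    # Simulate 100 concurrent prediction requests
    tasks = [predictor.predict(np.random.rand(10)) for _ in range(100)]
    results = await asyncio.gather(*tasks)
    print(f"Received {len(results)} predictions")

# asyncio.run(main(model))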
Pattern 4: Model Caching
Feature-Based Caching
import redis
import hashlib
import pickle
class CachedPredictor:
    def __init__(self, model, redis_client, ttl_seconds=3600):
        self.model = model
        self.redis = redis_client
        self.ttl = ttl_seconds

    def predict(self, features):
        # Create cache key from features
        cache_key = self._make_cache_key(features)
        # Check cache
        cached_result = self.redis.get(cache_key)
        if cached_result:
            return pickle.loads(cached_result)
        # Compute prediction
        prediction = self.model.predict([features])[0]
        # Cache result
        self.redis.setex(
            cache_key,
            self.ttl,
            pickle.dumps(prediction)
        )
        return prediction

    def _make_cache_key(self, features):
        """Create deterministic key from features"""
        feature_str = ','.join(f"{f:.4f}" for f in features)  # Round to avoid float precision issues
        return f"pred:{hashlib.md5(feature_str.encode()).hexdigest()}"
Cache Hit Rates:
- Recommendation systems: 30-60% (significant latency reduction)
- Fraud detection: 5-15% (limited benefit, features always changing)
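Hooking the cache up to a local Redis instance is a one-liner; the host, port, and TTL below are illustrative, and model stands in for whatever estimator you serve:

import redis

# Assumes a Redis server reachable on localhost:6379
redis_client = redis.Redis(host='localhost', port=6379)
cached_predictor = CachedPredictor(model, redis_client, ttl_seconds=3600)

prediction = cached_predictor.predict([0.5, 1.2, 3.4])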
Deployment Strategies
Container-Based Deployment
Dockerfile for Model Serving
FROM python:3.10-slim
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy model and code
COPY model.pkl /app/model.pkl
COPY app.py /app/app.py
WORKDIR /app
# Expose port
EXPOSE 8000
# Run server with multiple workers
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
Kubernetes Deployment
Deployment Configuration
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-inference
spec:
  replicas: 5  # Scale based on load
  selector:
    matchLabels:
      app: ml-inference
  template:
    metadata:
      labels:
        app: ml-inference
    spec:
      containers:
      - name: inference
        image: ml-inference:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "2Gi"
            cpu: "1000m"
          limits:
            memory: "4Gi"
            cpu: "2000m"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 10
          periodSeconds: 5
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 3
---
apiVersion: v1
kind: Service
metadata:
  name: ml-inference-service
spec:
  selector:
    app: ml-inference
  ports:
  - port: 80
    targetPort: 8000
  type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-inference-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-inference
  minReplicas: 3
  maxReplicas: 20
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70
  - type: Resource
    resource:
      name: memory
      target:
        type: Utilization
        averageUtilization: 80
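Apply the manifests with kubectl apply -f ml-inference.yaml (the file name is illustrative); the HorizontalPodAutoscaler then keeps the deployment between 3 and 20 replicas based on CPU and memory utilization.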
Serverless Deployment
AWS Lambda Function
import json
import boto3
import joblib
import numpy as np
# Load model once (outside handler for reuse)
s3 = boto3.client('s3')
s3.download_file('my-bucket', 'model.pkl', '/tmp/model.pkl')
model = joblib.load('/tmp/model.pkl')
def lambda_handler(event, context):
    """
    Lambda function for inference
    """
    try:
        # Parse input
        body = json.loads(event['body'])
        features = np.array(body['features']).reshape(1, -1)
        # Predict
        prediction = model.predict(features)[0]
        return {
            'statusCode': 200,
            'body': json.dumps({
                'prediction': float(prediction)
            })
        }
    except Exception as e:
        return {
            'statusCode': 500,
            'body': json.dumps({'error': str(e)})
        }
Cold Start Optimization:
- Use provisioned concurrency (see the boto3 sketch after this list)
- Keep model size < 50MB
- Use Lambda layers for dependencies
- Typical cold start: 1-3 seconds
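As a sketch of the first item, provisioned concurrency can be configured from Python via boto3; the function name, alias, and concurrency level are placeholders:

import boto3

lambda_client = boto3.client('lambda')

# Keep a fixed pool of warm execution environments for the published alias
lambda_client.put_provisioned_concurrency_config(
    FunctionName='ml-inference',
    Qualifier='prod',  # published version or alias
    ProvisionedConcurrentExecutions=10
)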
Monitoring and Observability
Latency Tracking
from prometheus_client import Histogram, Counter
import time
# Define metrics
prediction_latency = Histogram(
    'prediction_latency_seconds',
    'Prediction latency in seconds',
    buckets=[.001, .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0]
)
prediction_counter = Counter('predictions_total', 'Total predictions')

def predict_with_monitoring(features):
    """
    Prediction with latency tracking
    """
    with prediction_latency.time():
        result = model.predict([features])
    prediction_counter.inc()
    return result
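To make these metrics scrapeable, prometheus_client can expose them over HTTP from the serving process; the port below is an arbitrary choice:

from prometheus_client import start_http_server

# Serve /metrics on port 8001 for Prometheus to scrape
start_http_server(8001)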
Performance Metrics Dashboard
import logging
import time
from collections import deque

import numpy as np

class PerformanceMonitor:
    def __init__(self):
        self.latencies = []
        self.throughput_window = deque(maxlen=1000)

    def record_prediction(self, latency_ms):
        self.latencies.append(latency_ms)
        self.throughput_window.append(time.time())
        # Alert if p95 latency exceeds SLA
        if len(self.latencies) >= 100:
            p95 = np.percentile(self.latencies[-100:], 95)
            if p95 > 100:  # 100ms SLA
                logging.warning(f"P95 latency ({p95:.1f}ms) exceeds SLA")

    def get_throughput(self):
        """Calculate requests per second"""
        if len(self.throughput_window) < 2:
            return 0
        time_span = self.throughput_window[-1] - self.throughput_window[0]
        return len(self.throughput_window) / time_span if time_span > 0 else 0
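A sketch of feeding the monitor from the serving path; model and the feature vector are placeholders:

monitor = PerformanceMonitor()

def monitored_predict(features):
    start = time.time()
    result = model.predict([features])
    monitor.record_prediction((time.time() - start) * 1000)
    return result

# e.g. log throughput periodically from a background thread or task
print(f"Throughput: {monitor.get_throughput():.1f} req/s")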
Best Practices
1. Load Testing
import asyncio
import time

import aiohttp
import numpy as np

async def load_test(url, num_requests=1000, concurrency=50):
    """
    Concurrent load test
    """
    semaphore = asyncio.Semaphore(concurrency)  # Cap in-flight requests

    async def make_request(session):
        async with semaphore:
            start = time.time()
            async with session.post(url, json={'features': [1.0, 2.0, 3.0]}) as response:
                await response.json()
                return response.status, (time.time() - start) * 1000

    async with aiohttp.ClientSession() as session:
        start_time = time.time()
        # Create concurrent requests
        tasks = [make_request(session) for _ in range(num_requests)]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        duration = time.time() - start_time

    # Analyze results
    successful = sum(1 for r in results if not isinstance(r, Exception))
    latencies = [r[1] for r in results if not isinstance(r, Exception)]
    print(f"Completed {successful}/{num_requests} requests in {duration:.2f}s")
    print(f"Throughput: {successful/duration:.1f} req/s")
    print(f"Latency P50: {np.percentile(latencies, 50):.1f}ms")
    print(f"Latency P95: {np.percentile(latencies, 95):.1f}ms")
    print(f"Latency P99: {np.percentile(latencies, 99):.1f}ms")
2. Model Warmup
import numpy as np

def warmup_model(model, num_iterations=100):
    """
    Warm up model before serving (avoid cold start latency)
    """
    # n_features_in_ is the scikit-learn attribute for input dimensionality
    dummy_input = np.random.rand(1, model.n_features_in_)
    for _ in range(num_iterations):
        _ = model.predict(dummy_input)
    print("Model warmup complete")
3. Circuit Breaker Pattern
from datetime import datetime, timedelta
class CircuitBreaker:
    def __init__(self, failure_threshold=5, timeout_seconds=60):
        self.failure_threshold = failure_threshold
        self.timeout_seconds = timeout_seconds
        self.failures = 0
        self.last_failure_time = None
        self.state = 'CLOSED'  # CLOSED, OPEN, HALF_OPEN

    def call(self, func, *args, **kwargs):
        if self.state == 'OPEN':
            # Check if timeout has passed
            if datetime.now() - self.last_failure_time > timedelta(seconds=self.timeout_seconds):
                self.state = 'HALF_OPEN'
            else:
                raise Exception("Circuit breaker is OPEN")
        try:
            result = func(*args, **kwargs)
            self._on_success()
            return result
        except Exception as e:
            self._on_failure()
            raise e

    def _on_success(self):
        self.failures = 0
        self.state = 'CLOSED'

    def _on_failure(self):
        self.failures += 1
        self.last_failure_time = datetime.now()
        if self.failures >= self.failure_threshold:
            self.state = 'OPEN'
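Usage is a thin wrapper around the model call; callers can fall back to a conservative default when the breaker is open (the None fallback is an assumption about what downstream code tolerates):

breaker = CircuitBreaker(failure_threshold=5, timeout_seconds=60)

def safe_predict(features):
    try:
        return breaker.call(model.predict, [features])
    except Exception:
        # Breaker open or model failing: degrade gracefully instead of erroring out
        return None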
Conclusion
Real-time ML inference requires balancing model accuracy, latency, throughput, and infrastructure costs. By optimizing models (quantization, pruning, distillation), choosing appropriate serving architectures (REST, gRPC, batch), and implementing robust monitoring, organizations can build inference systems that serve millions of predictions daily with consistent sub-100ms latency.
The key is understanding your specific latency requirements, optimizing accordingly, and continuously monitoring production performance to ensure SLAs are met.
Next Steps:
- Profile current model inference latency
- Implement model optimization (quantization as first step)
- Build containerized serving infrastructure
- Add comprehensive latency monitoring
- Load test and optimize for target throughput