
Online Inference Patterns: Low-Latency ML Predictions at Scale

Online inference requires balancing latency, throughput, and accuracy. Let’s explore patterns for serving ML predictions at scale with minimal latency.

Online Inference Architecture

┌──────────────────────────────────────────────────────────┐
│                 Online Inference System                  │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  Request → Load Balancer → Inference Servers → Response  │
│                  │                 │                     │
│                  ↓                 ↓                     │
│              [Caching]      [Feature Store]              │
│                                    │                     │
│                              [Model Cache]               │
│                                                          │
│  Metrics → Prometheus → Grafana → Alerts                 │
│                                                          │
└──────────────────────────────────────────────────────────┘
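
The monitoring path in the diagram (metrics scraped by Prometheus, dashboards and alerts in Grafana) is typically wired into the inference server itself. A minimal sketch using the prometheus_client library; the metric and label names here are illustrative, not part of any pattern below.

from prometheus_client import Counter, Histogram, start_http_server

# Illustrative metric names; adjust to your own conventions
PREDICTION_LATENCY = Histogram(
    "inference_request_latency_seconds",
    "End-to-end latency of prediction requests"
)
PREDICTION_COUNT = Counter(
    "inference_requests_total",
    "Total prediction requests served",
    ["model_name", "status"]
)

# Expose /metrics for Prometheus to scrape
start_http_server(9100)

def timed_predict(model, features, model_name: str):
    """Run a prediction while recording latency and outcome metrics."""
    with PREDICTION_LATENCY.time():
        try:
            prediction = model.predict(features)
            PREDICTION_COUNT.labels(model_name=model_name, status="ok").inc()
            return prediction
        except Exception:
            PREDICTION_COUNT.labels(model_name=model_name, status="error").inc()
            raise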

Pattern 1: Simple Inference Endpoint

from azure.ai.ml import MLClient
from azure.ai.ml.entities import (
    ManagedOnlineEndpoint,
    ManagedOnlineDeployment,
    OnlineRequestSettings,
    ProbeSettings
)
from azure.identity import DefaultAzureCredential

# Workspace details: replace with your own values
subscription_id = "<subscription-id>"
resource_group = "<resource-group>"
workspace = "<workspace-name>"

# Create ML client
credential = DefaultAzureCredential()
ml_client = MLClient(
    credential=credential,
    subscription_id=subscription_id,
    resource_group_name=resource_group,
    workspace_name=workspace
)

# Create endpoint
endpoint = ManagedOnlineEndpoint(
    name="product-recommender",
    description="Real-time product recommendations",
    auth_mode="key"
)

ml_client.online_endpoints.begin_create_or_update(endpoint).result()

# Create deployment
deployment = ManagedOnlineDeployment(
    name="recommender-v1",
    endpoint_name="product-recommender",
    model=ml_client.models.get("recommender-model", version="1"),
    instance_type="Standard_DS3_v2",
    instance_count=3,
    request_settings=OnlineRequestSettings(
        request_timeout_ms=90000,
        max_concurrent_requests_per_instance=100
    ),
    liveness_probe=ProbeSettings(
        initial_delay=30,
        period=10
    ),
    readiness_probe=ProbeSettings(
        initial_delay=10,
        period=10
    )
)

ml_client.online_deployments.begin_create_or_update(deployment).result()
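
After the deployment succeeds, route traffic to it and send a test request. A short sketch; sample-request.json is a placeholder file containing a scoring payload.

# Route all traffic to the new deployment
endpoint.traffic = {"recommender-v1": 100}
ml_client.online_endpoints.begin_create_or_update(endpoint).result()

# Smoke-test the endpoint with the placeholder request payload
response = ml_client.online_endpoints.invoke(
    endpoint_name="product-recommender",
    deployment_name="recommender-v1",
    request_file="sample-request.json"
)
print(response)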

Pattern 2: Multi-Model Serving

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import asyncio
from typing import Dict, Any, Optional

app = FastAPI()

class ModelRegistry:
    """Manage multiple models for inference."""

    def __init__(self):
        self.models: Dict[str, Any] = {}
        self.model_versions: Dict[str, str] = {}

    async def load_model(self, name: str, version: str, path: str):
        """Load a model into memory."""
        import mlflow

        model_uri = f"{path}/{name}/{version}"
        model = mlflow.pyfunc.load_model(model_uri)

        self.models[f"{name}:{version}"] = model
        self.model_versions[name] = version

    async def get_model(self, name: str, version: Optional[str] = None):
        """Get a loaded model."""
        if version is None:
            version = self.model_versions.get(name)

        key = f"{name}:{version}"
        if key not in self.models:
            raise HTTPException(404, f"Model {key} not found")

        return self.models[key]

    async def predict(self, model_name: str, features: dict, version: Optional[str] = None):
        """Run prediction with the specified model."""
        model = await self.get_model(model_name, version)

        # Convert features to model input format
        import pandas as pd
        input_df = pd.DataFrame([features])

        # Run prediction
        prediction = model.predict(input_df)

        # Convert numpy output to native Python types for JSON serialisation
        if hasattr(prediction, "tolist"):
            prediction = prediction.tolist()

        return prediction[0] if len(prediction) == 1 else prediction

registry = ModelRegistry()

class PredictionRequest(BaseModel):
    model_name: str
    features: dict
    version: Optional[str] = None

class PredictionResponse(BaseModel):
    prediction: Any
    model_version: str
    latency_ms: float

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    import time
    start = time.time()

    prediction = await registry.predict(
        model_name=request.model_name,
        features=request.features,
        version=request.version
    )

    latency = (time.time() - start) * 1000

    return PredictionResponse(
        prediction=prediction,
        model_version=request.version or registry.model_versions[request.model_name],
        latency_ms=latency
    )
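
Models need to be in the registry before the first request arrives. One option, sketched below with a placeholder model name and an MLflow-style path, is to load them in a startup hook:

@app.on_event("startup")
async def load_models_on_startup():
    # Placeholder name/version; "models:" composes with the registry's
    # f"{path}/{name}/{version}" into an MLflow registry URI (models:/...)
    await registry.load_model(
        name="recommender-model",
        version="1",
        path="models:"
    )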

Pattern 3: Feature Store Integration

from feast import FeatureStore
from typing import List, Dict
import numpy as np

class FeatureEnrichedInference:
    """Inference with real-time feature enrichment."""

    def __init__(self, model, feature_store_path: str):
        self.model = model
        self.fs = FeatureStore(repo_path=feature_store_path)

    async def predict(
        self,
        entity_ids: List[str],
        entity_type: str,
        feature_views: List[str],
        request_features: Dict = None
    ):
        """Get features and make predictions."""

        # Get features from the online store. Wildcard refs like "view:*" may
        # need to be expanded to explicit feature names depending on the Feast
        # version; pandas is imported here for the conversion below.
        import pandas as pd

        feature_vector = self.fs.get_online_features(
            features=[
                f"{view}:*" for view in feature_views
            ],
            entity_rows=[{f"{entity_type}_id": eid} for eid in entity_ids]
        ).to_dict()

        # Combine with request-time features
        if request_features:
            for key, value in request_features.items():
                feature_vector[key] = [value] * len(entity_ids)

        # Convert to model input
        feature_df = pd.DataFrame(feature_vector)

        # Predict
        predictions = self.model.predict(feature_df)

        return {
            "predictions": predictions.tolist(),
            "entity_ids": entity_ids
        }

# Usage (loaded_model is a model object loaded elsewhere, e.g. via mlflow)
inference_engine = FeatureEnrichedInference(
    model=loaded_model,
    feature_store_path="feature_repo/"
)

# Run inside an async context, e.g. an endpoint handler
result = await inference_engine.predict(
    entity_ids=["user_123", "user_456"],
    entity_type="user",
    feature_views=["user_features", "user_activity_features"],
    request_features={"current_hour": 14, "is_weekend": False}
)
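
The feature views referenced above have to be defined in the Feast repository. A hypothetical definition for the "user" entity and the "user_features" view, assuming a recent Feast release and placeholder feature names:

from datetime import timedelta
from feast import Entity, FeatureView, Field, FileSource
from feast.types import Float32, Int64

# Hypothetical entity and offline source
user = Entity(name="user", join_keys=["user_id"])

user_source = FileSource(
    path="data/user_features.parquet",  # placeholder offline source
    timestamp_field="event_timestamp"
)

user_features = FeatureView(
    name="user_features",
    entities=[user],
    ttl=timedelta(days=1),
    schema=[
        Field(name="purchase_count_30d", dtype=Int64),   # placeholder feature
        Field(name="avg_order_value", dtype=Float32),    # placeholder feature
    ],
    source=user_source
)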

Pattern 4: Caching for Low Latency

import redis
import hashlib
import json
from typing import Optional
import asyncio

class CachedInferenceServer:
    """Inference server with prediction caching."""

    def __init__(self, model, redis_url: str, cache_ttl: int = 300):
        self.model = model
        self.redis = redis.from_url(redis_url)
        self.cache_ttl = cache_ttl
        self.cache_hits = 0
        self.cache_misses = 0

    def _cache_key(self, model_name: str, features: dict) -> str:
        """Generate cache key from features."""
        feature_str = json.dumps(features, sort_keys=True)
        hash_val = hashlib.md5(feature_str.encode()).hexdigest()
        return f"pred:{model_name}:{hash_val}"

    async def predict(
        self,
        features: dict,
        model_name: str,
        use_cache: bool = True
    ) -> dict:
        """Make prediction with optional caching."""

        import time
        start = time.time()

        cache_key = self._cache_key(model_name, features)

        # Check cache
        if use_cache:
            cached = self.redis.get(cache_key)
            if cached:
                self.cache_hits += 1
                return {
                    "prediction": json.loads(cached),
                    "cached": True,
                    "latency_ms": (time.time() - start) * 1000
                }

        self.cache_misses += 1

        # Run inference
        prediction = self.model.predict(features)

        # numpy outputs are not JSON-serialisable; convert before caching
        if hasattr(prediction, "tolist"):
            prediction = prediction.tolist()

        latency = (time.time() - start) * 1000

        # Cache result
        if use_cache:
            self.redis.setex(
                cache_key,
                self.cache_ttl,
                json.dumps(prediction)
            )

        return {
            "prediction": prediction,
            "cached": False,
            "latency_ms": latency
        }

    def get_cache_stats(self) -> dict:
        """Get cache performance statistics."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0

        return {
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate
        }
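
A usage sketch; loaded_model and the Redis URL are placeholders.

# Usage (loaded_model and the Redis URL are placeholders)
server = CachedInferenceServer(
    model=loaded_model,
    redis_url="redis://localhost:6379/0",
    cache_ttl=300
)

# Run inside an async context
result = await server.predict(
    features={"user_id": "user_123", "item_count": 5},
    model_name="recommender-model"
)
print(result["cached"], result["latency_ms"])
print(server.get_cache_stats())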

Pattern 5: Batch and Real-Time Hybrid

from typing import List
import asyncio

class HybridInferenceServer:
    """Combine batch and real-time inference."""

    def __init__(self, model, batch_size: int = 32, max_wait_ms: int = 50):
        self.model = model
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.pending_requests = []
        self.batch_task = None

    async def predict(self, features: dict) -> dict:
        """Queue prediction request for batching."""

        # Create future for this request
        loop = asyncio.get_event_loop()
        future = loop.create_future()

        self.pending_requests.append({
            "features": features,
            "future": future
        })

        # Start batch processing if needed
        if self.batch_task is None or self.batch_task.done():
            self.batch_task = asyncio.create_task(self._process_batch())

        # Flush immediately if the batch is full, instead of waiting out the window
        if len(self.pending_requests) >= self.batch_size:
            self.batch_task.cancel()
            self.batch_task = asyncio.create_task(self._process_batch(wait=False))

        # Wait for result
        result = await future
        return result

    async def _process_batch(self, wait: bool = True):
        """Process accumulated requests as a batch."""

        # Wait for more requests to accumulate, unless flushing a full batch
        if wait:
            await asyncio.sleep(self.max_wait_ms / 1000)

        if not self.pending_requests:
            return

        # Get all pending requests
        batch = self.pending_requests[:]
        self.pending_requests.clear()

        # Prepare batch input
        import pandas as pd
        batch_features = pd.DataFrame([r["features"] for r in batch])

        # Run batch prediction; propagate failures to every waiting caller
        # so their futures do not hang
        try:
            predictions = self.model.predict(batch_features)
        except Exception as exc:
            for request in batch:
                request["future"].set_exception(exc)
            return

        # Distribute results
        for i, request in enumerate(batch):
            request["future"].set_result({
                "prediction": predictions[i].tolist() if hasattr(predictions[i], 'tolist') else predictions[i],
                "batch_size": len(batch)
            })
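
A usage sketch showing concurrent callers sharing a single batch; loaded_model is a placeholder.

# Usage (loaded_model is a placeholder model object)
server = HybridInferenceServer(model=loaded_model, batch_size=32, max_wait_ms=50)

async def main():
    results = await asyncio.gather(*[
        server.predict({"user_id": f"user_{i}", "item_count": i})
        for i in range(10)
    ])
    print(results[0]["batch_size"])  # typically 10: all requests share one batch

asyncio.run(main())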

Pattern 6: Model Ensembling

import numpy as np
from typing import Any, Dict, Optional

class EnsembleInferenceServer:
    """Serve an ensemble of models."""

    def __init__(
        self,
        models: Dict[str, Any],
        weights: Optional[Dict[str, float]] = None,
        meta_model: Any = None
    ):
        self.models = models
        self.weights = weights or {name: 1.0 / len(models) for name in models}
        self.meta_model = meta_model  # only needed for the "stacking" method

    async def predict(
        self,
        features: dict,
        ensemble_method: str = "weighted_average"
    ) -> dict:
        """Run ensemble prediction."""

        # Get predictions from all models
        predictions = {}
        for name, model in self.models.items():
            predictions[name] = model.predict([features])[0]

        # Combine predictions
        if ensemble_method == "weighted_average":
            combined = sum(
                pred * self.weights[name]
                for name, pred in predictions.items()
            )
        elif ensemble_method == "voting":
            # For classification
            votes = [np.argmax(pred) for pred in predictions.values()]
            combined = max(set(votes), key=votes.count)
        elif ensemble_method == "stacking":
            # Use meta-model
            stacked_features = np.concatenate(list(predictions.values()))
            combined = self.meta_model.predict([stacked_features])[0]

        return {
            "prediction": combined,
            "model_predictions": predictions,
            "ensemble_method": ensemble_method
        }
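
A usage sketch with two placeholder base models and explicit weights:

# Usage (model_a / model_b are placeholder model objects)
ensemble = EnsembleInferenceServer(
    models={"gbm": model_a, "nn": model_b},
    weights={"gbm": 0.7, "nn": 0.3}
)

# Run inside an async context
result = await ensemble.predict(
    features={"user_id": "user_123", "item_count": 5},
    ensemble_method="weighted_average"
)
print(result["prediction"])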

Online inference requires careful attention to latency, reliability, and scalability. Choose patterns based on your specific requirements and constraints.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.