
Databricks Model Serving: Real-Time ML Inference at Scale

Databricks Model Serving provides serverless, real-time inference for machine learning models. It handles scaling and monitoring automatically and integrates seamlessly with the Feature Store.

Model Serving Overview

Key features:

  • Serverless: No infrastructure management
  • Auto-scaling: Handles variable traffic automatically
  • Low latency: Sub-second response times
  • Feature Store integration: Automatic feature lookups
  • A/B testing: Built-in traffic splitting

Enabling Model Serving

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput

w = WorkspaceClient()

# Create a model serving endpoint
endpoint_config = EndpointCoreConfigInput(
    name="churn-prediction-endpoint",
    served_models=[
        ServedModelInput(
            model_name="churn-predictor",
            model_version="3",
            workload_size="Small",
            scale_to_zero_enabled=True
        )
    ]
)

w.serving_endpoints.create(
    name="churn-prediction-endpoint",
    config=endpoint_config
)
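
Endpoint creation is asynchronous, so it can take a few minutes before the endpoint serves traffic. A minimal status check, assuming the SDK's serving_endpoints.get method (field names may vary slightly between databricks-sdk versions):

import time

# Poll until the endpoint reports a ready state
while True:
    endpoint = w.serving_endpoints.get(name="churn-prediction-endpoint")
    print(f"Endpoint state: {endpoint.state}")
    if endpoint.state and endpoint.state.ready and endpoint.state.ready.value == "READY":
        break
    time.sleep(30)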

REST API Configuration

# Via REST API
import os
import requests

workspace_url = "https://adb-xxx.azuredatabricks.net"
token = os.environ["DATABRICKS_TOKEN"]

headers = {
    "Authorization": f"Bearer {token}",
    "Content-Type": "application/json"
}

# Create endpoint
endpoint_config = {
    "name": "recommendation-endpoint",
    "config": {
        "served_models": [
            {
                "model_name": "product-recommender",
                "model_version": "1",
                "workload_size": "Small",
                "scale_to_zero_enabled": True
            }
        ],
        "traffic_config": {
            "routes": [
                {
                    "served_model_name": "product-recommender-1",
                    "traffic_percentage": 100
                }
            ]
        }
    }
}

response = requests.post(
    f"{workspace_url}/api/2.0/serving-endpoints",
    headers=headers,
    json=endpoint_config
)

print(response.json())

Querying the Endpoint

# Python client
import requests

def predict(endpoint_name, instances):
    url = f"{workspace_url}/serving-endpoints/{endpoint_name}/invocations"

    payload = {
        "instances": instances
    }

    response = requests.post(
        url,
        headers=headers,
        json=payload
    )

    return response.json()

# Example: Predict churn for customers
predictions = predict(
    "churn-prediction-endpoint",
    [
        {"customer_id": "C001", "total_orders": 5, "lifetime_value": 500.0},
        {"customer_id": "C002", "total_orders": 20, "lifetime_value": 2500.0}
    ]
)

print(predictions)
# {"predictions": [0.75, 0.12]}

Calling from Different Languages

// JavaScript/Node.js
const axios = require('axios');

async function predict(endpointName, instances) {
    const response = await axios.post(
        `${workspaceUrl}/serving-endpoints/${endpointName}/invocations`,
        { instances },
        {
            headers: {
                'Authorization': `Bearer ${token}`,
                'Content-Type': 'application/json'
            }
        }
    );
    return response.data;
}

// C#/.NET
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;

public async Task<PredictionResponse> PredictAsync(string endpointName, object[] instances)
{
    var client = new HttpClient();
    client.DefaultRequestHeaders.Add("Authorization", $"Bearer {token}");

    var payload = new { instances };
    var content = new StringContent(
        JsonSerializer.Serialize(payload),
        Encoding.UTF8,
        "application/json"
    );

    var response = await client.PostAsync(
        $"{workspaceUrl}/serving-endpoints/{endpointName}/invocations",
        content
    );

    var responseBody = await response.Content.ReadAsStringAsync();
    return JsonSerializer.Deserialize<PredictionResponse>(responseBody);
}

Feature Store Integration

Serve models that automatically look up features:

# Log model with feature store
import mlflow
from databricks.feature_store import FeatureStoreClient

fs = FeatureStoreClient()

# When logging the model, include the training set
fs.log_model(
    model=trained_model,
    artifact_path="model",
    flavor=mlflow.sklearn,
    training_set=training_set,  # Contains feature lookup definitions
    registered_model_name="churn-predictor-with-features"
)
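
# For reference, the training_set above is typically built from FeatureLookup
# definitions. Illustrative sketch only: labels_df and the "churned" label column
# are assumptions, and the feature table name mirrors the one used later in this post.
from databricks.feature_store import FeatureLookup

feature_lookups = [
    FeatureLookup(
        table_name="production.features.customer_features",
        lookup_key="customer_id"
    )
]

training_set = fs.create_training_set(
    df=labels_df,                    # DataFrame containing customer_id and the label
    feature_lookups=feature_lookups,
    label="churned"
)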

# Create endpoint - it will automatically look up features
endpoint_config = {
    "name": "churn-with-features",
    "config": {
        "served_models": [{
            "model_name": "churn-predictor-with-features",
            "model_version": "1",
            "workload_size": "Small"
        }]
    }
}

# When calling the endpoint, only provide lookup keys
predictions = predict(
    "churn-with-features",
    [{"customer_id": "C001"}, {"customer_id": "C002"}]
)
# Features are automatically fetched from the Feature Store

A/B Testing

Deploy multiple model versions:

# Update endpoint with A/B test configuration
ab_test_config = {
    "served_models": [
        {
            "name": "model-a",
            "model_name": "churn-predictor",
            "model_version": "2",
            "workload_size": "Small"
        },
        {
            "name": "model-b",
            "model_name": "churn-predictor",
            "model_version": "3",
            "workload_size": "Small"
        }
    ],
    "traffic_config": {
        "routes": [
            {"served_model_name": "model-a", "traffic_percentage": 80},
            {"served_model_name": "model-b", "traffic_percentage": 20}
        ]
    }
}

# Update endpoint
response = requests.put(
    f"{workspace_url}/api/2.0/serving-endpoints/churn-prediction-endpoint/config",
    headers=headers,
    json=ab_test_config
)

# Monitor which model served each request
def predict_with_metadata(endpoint_name, instances):
    response = requests.post(
        f"{workspace_url}/serving-endpoints/{endpoint_name}/invocations",
        headers={**headers, "X-Databricks-Return-Metadata": "true"},
        json={"instances": instances}
    )
    return response.json()

# Response includes which model version served the request

Canary Deployments

Gradually roll out new models:

class CanaryDeployment:
    def __init__(self, endpoint_name, current_model, new_model):
        self.endpoint_name = endpoint_name
        self.current_model = current_model
        self.new_model = new_model
        self.traffic_percentages = [5, 10, 25, 50, 100]
        self.current_stage = 0

    def advance_stage(self):
        if self.current_stage >= len(self.traffic_percentages):
            print("Deployment complete")
            return

        new_traffic = self.traffic_percentages[self.current_stage]
        old_traffic = 100 - new_traffic

        config = {
            "traffic_config": {
                "routes": [
                    {"served_model_name": self.current_model, "traffic_percentage": old_traffic},
                    {"served_model_name": self.new_model, "traffic_percentage": new_traffic}
                ]
            }
        }

        response = requests.put(
            f"{workspace_url}/api/2.0/serving-endpoints/{self.endpoint_name}/config",
            headers=headers,
            json=config
        )

        self.current_stage += 1
        print(f"Advanced to {new_traffic}% traffic on new model")

    def rollback(self):
        config = {
            "traffic_config": {
                "routes": [
                    {"served_model_name": self.current_model, "traffic_percentage": 100},
                    {"served_model_name": self.new_model, "traffic_percentage": 0}
                ]
            }
        }

        requests.put(
            f"{workspace_url}/api/2.0/serving-endpoints/{self.endpoint_name}/config",
            headers=headers,
            json=config
        )
        print("Rolled back to current model")

Monitoring and Observability

# Get endpoint metrics
from datetime import datetime, timedelta

def get_endpoint_metrics(endpoint_name, hours=24):
    end_time = datetime.utcnow()
    start_time = end_time - timedelta(hours=hours)

    response = requests.get(
        f"{workspace_url}/api/2.0/serving-endpoints/{endpoint_name}/metrics",
        headers=headers,
        params={
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat()
        }
    )

    return response.json()

# Parse metrics
metrics = get_endpoint_metrics("churn-prediction-endpoint")
print(f"Total requests: {metrics['request_count']}")
print(f"Average latency: {metrics['avg_latency_ms']}ms")
print(f"Error rate: {metrics['error_rate']}")

# Set up alerts via Azure Monitor
# Model serving metrics are automatically exported

Custom Metrics Logging

# Log custom metrics from within your model
import mlflow

class MonitoredModel(mlflow.pyfunc.PythonModel):
    def __init__(self, model):
        self.model = model

    def predict(self, context, model_input):
        import time

        start = time.time()
        predictions = self.model.predict(model_input)
        latency = time.time() - start

        # Log custom metrics (note: mlflow.log_metric needs an active run and tracking configuration at serving time)
        mlflow.log_metric("prediction_latency", latency)
        mlflow.log_metric("batch_size", len(model_input))

        # Log prediction distribution
        mlflow.log_metric("mean_prediction", predictions.mean())

        return predictions
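
To serve the wrapper, log it as a pyfunc model so the endpoint loads the wrapper rather than the raw estimator. A minimal sketch, assuming base_model is the fitted model and the registered name is yours to choose:

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=MonitoredModel(base_model),
        registered_model_name="churn-predictor-monitored"
    )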

Cost Optimization

# Configure auto-scaling for cost efficiency
scaling_config = {
    "served_models": [{
        "model_name": "churn-predictor",
        "model_version": "3",
        "workload_size": "Small",
        "scale_to_zero_enabled": True,  # Save costs during idle periods
        "min_provisioned_throughput": 0,
        "max_provisioned_throughput": 10000
    }]
}

# Workload sizes and their characteristics:
# - Small: Up to 4 concurrent requests
# - Medium: Up to 16 concurrent requests
# - Large: Up to 64 concurrent requests

# For predictable traffic, provision capacity
high_traffic_config = {
    "served_models": [{
        "model_name": "churn-predictor",
        "model_version": "3",
        "workload_size": "Medium",
        "scale_to_zero_enabled": False,  # Keep warm for low latency
        "min_provisioned_throughput": 1000
    }]
}

Custom Python Environments

# Define custom environment for complex dependencies
import mlflow

# Create conda environment
conda_env = {
    "channels": ["defaults", "conda-forge"],
    "dependencies": [
        "python=3.9",
        "pip",
        {
            "pip": [
                "scikit-learn==1.0.2",
                "xgboost==1.5.0",
                "shap==0.40.0",
                "custom-package==1.0.0"
            ]
        }
    ]
}

# Log model with custom environment
mlflow.sklearn.log_model(
    model,
    "model",
    conda_env=conda_env,
    registered_model_name="model-with-custom-env"
)
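
If a full conda spec is more than you need, MLflow also accepts plain pip requirements, which keeps the environment definition shorter:

# Equivalent pinning via pip_requirements instead of conda_env
mlflow.sklearn.log_model(
    model,
    "model",
    pip_requirements=[
        "scikit-learn==1.0.2",
        "xgboost==1.5.0",
        "shap==0.40.0"
    ],
    registered_model_name="model-with-custom-env"
)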

Real-Time Feature Engineering

# Combine pre-computed features with real-time features
from pyspark.sql.functions import col

class RealTimeFeatureModel(mlflow.pyfunc.PythonModel):
    def __init__(self, model, feature_store_client):
        self.model = model
        self.fs = feature_store_client

    def predict(self, context, model_input):
        # model_input contains raw request data
        customer_ids = model_input["customer_id"].tolist()

        # Fetch pre-computed features from Feature Store
        stored_features = self.fs.read_table(
            "production.features.customer_features"
        ).filter(col("customer_id").isin(customer_ids))

        # Compute real-time features
        real_time_features = self.compute_real_time_features(model_input)

        # Combine features
        combined = stored_features.join(
            real_time_features,
            "customer_id"
        )

        # Predict
        return self.model.predict(combined)

    def compute_real_time_features(self, input_data):
        # Compute features from request data
        # e.g., time since last session, current cart value
        return input_data

Best Practices

High Availability

# Deploy to multiple endpoints for redundancy
endpoints = ["prod-east", "prod-west"]

# Use a load balancer to distribute traffic
# Configure health checks
health_check_config = {
    "health_check_enabled": True,
    "health_check_interval_seconds": 30
}
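
A simple client-side failover sketch for the redundant endpoints above (endpoint names, timeout, and retry policy are assumptions):

# Try each endpoint in order and fall back on failure
def predict_with_failover(instances, endpoints=("prod-east", "prod-west")):
    last_error = None
    for endpoint_name in endpoints:
        try:
            response = requests.post(
                f"{workspace_url}/serving-endpoints/{endpoint_name}/invocations",
                headers=headers,
                json={"instances": instances},
                timeout=5
            )
            response.raise_for_status()
            return response.json()
        except requests.RequestException as error:
            last_error = error
    raise RuntimeError("All serving endpoints failed") from last_error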

Request Validation

# Validate inputs before prediction
class ValidatedModel(mlflow.pyfunc.PythonModel):
    def __init__(self, model, schema):
        self.model = model
        self.schema = schema

    def predict(self, context, model_input):
        # Validate input schema
        for col, dtype in self.schema.items():
            if col not in model_input.columns:
                raise ValueError(f"Missing required column: {col}")
            if model_input[col].dtype != dtype:
                raise ValueError(f"Invalid type for {col}")

        # Validate value ranges
        if (model_input["age"] < 0).any():
            raise ValueError("Age cannot be negative")

        return self.model.predict(model_input)

Conclusion

Databricks Model Serving simplifies the deployment of ML models to production:

  • Serverless infrastructure eliminates operational overhead
  • Feature Store integration ensures consistency
  • A/B testing and canary deployments enable safe releases
  • Auto-scaling handles variable traffic efficiently

The integration with the broader Databricks ecosystem - MLflow, Feature Store, and Unity Catalog - provides a complete MLOps solution.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.