6 min read

Serverless Model Serving: Pay-Per-Request ML Inference

Serverless model serving eliminates infrastructure management while providing cost-effective, scalable ML inference. This guide covers implementing serverless patterns in Databricks.

Serverless Benefits

SERVERLESS_ADVANTAGES = {
    "cost_efficiency": {
        "description": "Pay only for actual inference requests",
        "best_for": ["Variable workloads", "Dev/test environments", "Batch-heavy patterns"]
    },
    "scalability": {
        "description": "Automatic scaling from zero to high throughput",
        "best_for": ["Unpredictable traffic", "Burst handling", "Global deployments"]
    },
    "simplicity": {
        "description": "No infrastructure management",
        "best_for": ["Small teams", "Rapid prototyping", "Focus on ML not ops"]
    }
}

Creating Serverless Endpoints

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    ServedEntityInput
)

client = WorkspaceClient()

def create_serverless_endpoint(
    name: str,
    model_name: str,
    model_version: str,
    scale_to_zero: bool = True,
    workload_size: str = "Small"
):
    """Create a serverless model endpoint"""

    endpoint = client.serving_endpoints.create_and_wait(
        name=name,
        config=EndpointCoreConfigInput(
            served_entities=[
                ServedEntityInput(
                    entity_name=model_name,
                    entity_version=model_version,
                    # Serverless configuration: workload size sets the
                    # per-instance concurrency band, scale-to-zero avoids
                    # paying for idle compute
                    workload_size=workload_size,
                    scale_to_zero_enabled=scale_to_zero
                )
            ]
        )
    )

    print(f"Endpoint created: {name}")
    print(f"State: {endpoint.state.ready}")

    return endpoint

# Create an endpoint that scales to zero when idle
endpoint = create_serverless_endpoint(
    name="recommendation-engine",
    model_name="main.ml_models.product_recommender",
    model_version="1",
    scale_to_zero=True
)
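
Once an endpoint exists, the same SDK can reconfigure it in place, for example to keep it warm ahead of an expected traffic spike by switching scale-to-zero off. A minimal sketch, assuming the client and ServedEntityInput imports above; the set_scale_to_zero helper name is mine:

def set_scale_to_zero(endpoint_name: str, model_name: str, model_version: str, enabled: bool):
    """Toggle scale-to-zero on an existing serverless endpoint."""
    return client.serving_endpoints.update_config_and_wait(
        name=endpoint_name,
        served_entities=[
            ServedEntityInput(
                entity_name=model_name,
                entity_version=model_version,
                workload_size="Small",
                scale_to_zero_enabled=enabled
            )
        ]
    )

# Keep the endpoint warm before business hours, re-enable scale-to-zero afterwards
set_scale_to_zero(
    "recommendation-engine",
    "main.ml_models.product_recommender",
    "1",
    enabled=False
)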

Cold Start Optimization

import time
import requests
from typing import Optional

class ServerlessEndpointClient:
    """Client optimized for serverless endpoints"""

    def __init__(
        self,
        workspace_url: str,
        token: str,
        warm_up_enabled: bool = True
    ):
        self.workspace_url = workspace_url
        self.token = token
        self.warm_up_enabled = warm_up_enabled
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }

    def predict(
        self,
        endpoint_name: str,
        inputs: list,
        timeout: int = 60
    ) -> dict:
        """Make prediction with cold start handling"""

        url = f"{self.workspace_url}/serving-endpoints/{endpoint_name}/invocations"

        # Retry logic for cold starts
        max_retries = 3
        retry_delay = 5

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    url,
                    headers=self.headers,
                    json={"inputs": inputs},
                    timeout=timeout
                )

                if response.status_code == 200:
                    return response.json()
                elif response.status_code == 503:
                    # Endpoint scaling up
                    print(f"Endpoint scaling up, retry in {retry_delay}s...")
                    time.sleep(retry_delay)
                else:
                    response.raise_for_status()

            except requests.exceptions.Timeout:
                if attempt < max_retries - 1:
                    print(f"Request timed out, retry {attempt + 1}/{max_retries}")
                    time.sleep(retry_delay)
                else:
                    raise

        raise Exception("Max retries exceeded")

    def warm_up(self, endpoint_name: str, sample_input: dict):
        """Send warm-up request to prevent cold start"""

        if not self.warm_up_enabled:
            return

        try:
            self.predict(endpoint_name, [sample_input], timeout=120)
            print(f"Endpoint {endpoint_name} warmed up")
        except Exception as e:
            print(f"Warm-up failed (non-critical): {e}")

    def batch_predict(
        self,
        endpoint_name: str,
        inputs: list,
        batch_size: int = 100
    ) -> list:
        """Batch predictions with optimal batching"""

        results = []

        for i in range(0, len(inputs), batch_size):
            batch = inputs[i:i + batch_size]
            result = self.predict(endpoint_name, batch)
            results.extend(result.get("predictions", []))

        return results

# Usage
serving_client = ServerlessEndpointClient(
    workspace_url="https://adb-xxx.azuredatabricks.net",
    token="your-token"
)

# Warm up endpoint (e.g., before expected traffic)
serving_client.warm_up(
    "recommendation-engine",
    sample_input={"user_id": "test", "product_id": "test"}
)

# Make predictions
predictions = serving_client.predict(
    "recommendation-engine",
    inputs=[{"user_id": "user123", "product_id": "prod456"}]
)
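
For batch-heavy workloads, grouping records means at most the first chunk pays any cold-start penalty. A short usage sketch of the batch_predict helper above (the input records are illustrative):

# Score a large list of user/product pairs in chunks of 100; only the first
# chunk can hit a cold start, later chunks reuse the warm instance
pairs = [{"user_id": f"user{i}", "product_id": f"prod{i}"} for i in range(1000)]

batch_results = serving_client.batch_predict(
    "recommendation-engine",
    inputs=pairs,
    batch_size=100
)
print(f"Scored {len(batch_results)} records")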

Cost Optimization Strategies

class CostOptimizedServing:
    """Cost optimization patterns for serverless serving"""

    def __init__(self, workspace_client):
        self.client = workspace_client

    def get_usage_report(self, endpoint_name: str, days: int = 30) -> dict:
        """Generate usage and cost report"""

        # Query usage metrics (assumes a Databricks notebook/job where `spark`
        # is available; the monitoring table name is illustrative)
        usage = spark.sql(f"""
            SELECT
                DATE(timestamp) as date,
                COUNT(*) as request_count,
                SUM(compute_duration_ms) / 1000 / 60 as compute_minutes,
                AVG(input_tokens + output_tokens) as avg_tokens
            FROM main.ml_monitoring.{endpoint_name.replace('-', '_')}_usage
            WHERE timestamp >= current_date() - INTERVAL {days} DAYS
            GROUP BY DATE(timestamp)
            ORDER BY date
        """).collect()

        # Calculate costs (simplified)
        total_compute_minutes = sum(row["compute_minutes"] for row in usage)
        total_requests = sum(row["request_count"] for row in usage)

        # Cost estimates (example rates)
        cost_per_minute = 0.01
        estimated_cost = total_compute_minutes * cost_per_minute

        return {
            "days": days,
            "total_requests": total_requests,
            "total_compute_minutes": total_compute_minutes,
            "estimated_cost": estimated_cost,
            "cost_per_request": estimated_cost / total_requests if total_requests > 0 else 0,
            "daily_breakdown": usage
        }

    def recommend_configuration(self, usage_report: dict) -> dict:
        """Recommend endpoint configuration based on usage"""

        avg_daily_requests = usage_report["total_requests"] / usage_report["days"]

        if avg_daily_requests < 100:
            recommendation = {
                "config": "serverless_scale_to_zero",
                "workload_size": "Small",
                "min_instances": 0,
                "reason": "Low traffic - maximize cost savings with scale to zero"
            }
        elif avg_daily_requests < 10000:
            recommendation = {
                "config": "serverless_warm",
                "workload_size": "Small",
                "min_instances": 1,
                "reason": "Moderate traffic - keep one instance warm to avoid cold starts"
            }
        else:
            recommendation = {
                "config": "provisioned",
                "workload_size": "Medium",
                "min_instances": 2,
                "reason": "High traffic - consider provisioned throughput for cost predictability"
            }

        return recommendation

    def implement_caching(self, endpoint_name: str, cache_ttl: int = 300):
        """Implement response caching to reduce inference calls"""

        # Cache configuration
        return {
            "strategy": "lru_cache",
            "max_size": 10000,
            "ttl_seconds": cache_ttl,
            "key_function": "hash(input_features)"
        }

# Usage
optimizer = CostOptimizedServing(client)

# Get usage report
report = optimizer.get_usage_report("recommendation-engine", days=30)
print(f"Total cost: ${report['estimated_cost']:.2f}")
print(f"Cost per request: ${report['cost_per_request']:.4f}")

# Get recommendations
recommendation = optimizer.recommend_configuration(report)
print(f"Recommended config: {recommendation['config']}")
print(f"Reason: {recommendation['reason']}")

Hybrid Serving Patterns

class HybridServingRouter:
    """Route between serverless and provisioned endpoints"""

    def __init__(self, serverless_endpoint: str, provisioned_endpoint: str):
        self.serverless = serverless_endpoint
        self.provisioned = provisioned_endpoint
        self.request_count = 0
        # Cumulative request counter (simplified; a production router would use
        # a per-minute sliding window, see the rate-tracker sketch below)
        self.threshold = 100

    def route(self, inputs: list, latency_requirement: str = "normal") -> str:
        """Determine which endpoint to use"""

        if latency_requirement == "low":
            # Use provisioned for low-latency requirements
            return self.provisioned
        elif self.request_count > self.threshold:
            # High traffic - use provisioned for cost efficiency
            return self.provisioned
        else:
            # Normal traffic - use serverless
            return self.serverless

    def predict(
        self,
        inputs: list,
        latency_requirement: str = "normal"
    ) -> dict:
        """Make prediction with automatic routing"""

        endpoint = self.route(inputs, latency_requirement)

        # Make prediction
        result = self._call_endpoint(endpoint, inputs)

        # Track request for routing decisions
        self.request_count += 1

        return {
            "predictions": result,
            "endpoint_used": endpoint
        }

    def _call_endpoint(self, endpoint: str, inputs: list) -> dict:
        # Placeholder: in practice, delegate to a serving client such as
        # the ServerlessEndpointClient.predict shown earlier
        pass

# Usage
router = HybridServingRouter(
    serverless_endpoint="model-serverless",
    provisioned_endpoint="model-provisioned"
)

sample_inputs = [{"user_id": "user123", "product_id": "prod456"}]

# Normal request - may use serverless
result = router.predict(sample_inputs, latency_requirement="normal")

# Low-latency request - uses provisioned
result = router.predict(sample_inputs, latency_requirement="low")

Monitoring Serverless Endpoints

class ServerlessMonitor:
    """Monitor serverless endpoint health and costs"""

    def __init__(self, endpoint_name: str):
        self.endpoint_name = endpoint_name

    def get_scaling_metrics(self, hours: int = 24) -> list:
        """Get scaling behavior metrics"""

        metrics = spark.sql(f"""
            SELECT
                DATE_TRUNC('hour', timestamp) as hour,
                COUNT(*) as requests,
                MAX(active_instances) as max_instances,
                MIN(active_instances) as min_instances,
                SUM(CASE WHEN cold_start THEN 1 ELSE 0 END) as cold_starts
            FROM main.ml_monitoring.{self.endpoint_name.replace('-', '_')}_scaling
            WHERE timestamp >= current_timestamp() - INTERVAL {hours} HOURS
            GROUP BY DATE_TRUNC('hour', timestamp)
            ORDER BY hour
        """)

        return metrics.toPandas().to_dict('records')

    def get_cold_start_analysis(self) -> dict:
        """Analyze cold start patterns"""

        analysis = spark.sql(f"""
            SELECT
                HOUR(timestamp) as hour_of_day,
                COUNT(*) as total_requests,
                SUM(CASE WHEN cold_start THEN 1 ELSE 0 END) as cold_starts,
                AVG(CASE WHEN cold_start THEN latency_ms ELSE NULL END) as avg_cold_start_latency,
                AVG(CASE WHEN NOT cold_start THEN latency_ms ELSE NULL END) as avg_warm_latency
            FROM main.ml_monitoring.{self.endpoint_name.replace('-', '_')}_requests
            WHERE timestamp >= current_date() - INTERVAL 7 DAYS
            GROUP BY HOUR(timestamp)
            ORDER BY hour_of_day
        """).collect()

        return {
            "by_hour": analysis,
            "recommendation": self._get_warm_up_recommendation(analysis)
        }

    def _get_warm_up_recommendation(self, analysis) -> str:
        """Recommend warm-up schedule based on cold start patterns"""

        high_cold_start_hours = [
            row["hour_of_day"] for row in analysis
            if row["total_requests"] > 0
            and row["cold_starts"] / row["total_requests"] > 0.1
        ]

        if high_cold_start_hours:
            return f"Consider warm-up requests at hours: {high_cold_start_hours}"
        return "Cold starts are minimal, no warm-up needed"

# Usage
monitor = ServerlessMonitor("recommendation-engine")
cold_start_analysis = monitor.get_cold_start_analysis()
print(cold_start_analysis["recommendation"])
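
Putting the cold-start analysis to work, here is a sketch of a warm-up routine that could run as a scheduled job shortly before the high-cold-start hours it reports, reusing the serving_client and warm_up helper from earlier; the hours and sample payload are placeholders:

from datetime import datetime

# Hours (0-23) where the analysis above reported cold-start ratios over 10%
WARM_UP_HOURS = {8, 9, 13}   # illustrative values
SAMPLE_INPUT = {"user_id": "warmup", "product_id": "warmup"}

def scheduled_warm_up():
    """Send a warm-up request if the current hour precedes expected traffic."""
    current_hour = datetime.now().hour
    if current_hour in WARM_UP_HOURS:
        serving_client.warm_up("recommendation-engine", sample_input=SAMPLE_INPUT)
    else:
        print(f"Hour {current_hour} not in warm-up window, skipping")

scheduled_warm_up()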

Conclusion

Serverless model serving provides cost-effective, scalable ML inference. Handle cold starts with warm-up strategies, optimize costs with caching and intelligent routing, and monitor scaling behavior for continuous improvement.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.