Back to Blog
7 min read

Cost Controls for AI Applications: Budgeting and Optimization

AI API costs can escalate quickly without proper controls. Let’s explore strategies for budgeting, monitoring, and optimizing AI application costs.

Cost Tracking Foundation

from dataclasses import dataclass, field
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from decimal import Decimal
import threading

@dataclass
class ModelPricing:
    """Token pricing for one model, quoted per one million tokens.

    CostTracker.calculate_cost reads ``input_per_million`` and
    ``output_per_million``; the amounts are treated as dollars in the
    budget messages elsewhere in this file.
    """
    # Model identifier; PRICING uses the same string as its dictionary key.
    model: str
    # Price charged per 1,000,000 input (prompt) tokens.
    input_per_million: Decimal
    # Price charged per 1,000,000 output (completion) tokens.
    output_per_million: Decimal
    # Optional price per 1,000,000 cached input tokens; None when not set.
    # NOTE(review): no code in this file reads this field yet.
    cached_input_per_million: Optional[Decimal] = None

# Current pricing (as of September 2024), keyed by model name.
# Each row is (model, input price, output price), both per one million tokens.
PRICING = {
    name: ModelPricing(
        model=name,
        input_per_million=Decimal(input_rate),
        output_per_million=Decimal(output_rate),
    )
    for name, input_rate, output_rate in (
        ("gpt-4o", "2.50", "10.00"),
        ("gpt-4o-mini", "0.15", "0.60"),
        ("o1-preview", "15.00", "60.00"),
        ("o1-mini", "3.00", "12.00"),
        # The embedding model is listed with a zero output price.
        ("text-embedding-3-small", "0.02", "0"),
    )
}

@dataclass
class UsageRecord:
    """One recorded API call: when it happened, what it used, and what it cost."""
    # When the usage was recorded; CostTracker.record fills this with
    # datetime.now() (a naive, local-time datetime).
    timestamp: datetime
    # Model name the request was billed against.
    model: str
    # Number of input (prompt) tokens consumed.
    input_tokens: int
    # Number of output (completion) tokens produced.
    output_tokens: int
    # Cost of the request as computed by CostTracker.calculate_cost.
    cost: Decimal
    # Optional caller identifier; get_cost_by_user groups records without
    # one under "anonymous".
    user_id: Optional[str] = None
    # Optional request identifier; stored only, not read by code in this file.
    request_id: Optional[str] = None

class CostTracker:
    """Track AI API usage and cost in memory.

    Each call to ``record`` appends a ``UsageRecord`` to ``self.records``.
    The ``get_*`` methods aggregate those records, optionally restricted to
    a trailing time window. This class takes ``self._lock`` whenever it
    reads or appends to ``self.records``.
    """

    def __init__(self):
        self.records: List[UsageRecord] = []
        self._lock = threading.Lock()

    def calculate_cost(self, model: str, input_tokens: int,
                      output_tokens: int) -> Decimal:
        """Return the cost of a request given its token counts.

        Args:
            model: Model name to look up in ``PRICING``.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.

        Returns:
            Input cost plus output cost as a ``Decimal``.
        """
        # Unknown models are deliberately billed at gpt-4o rates.
        pricing = PRICING.get(model, PRICING["gpt-4o"])

        input_cost = (Decimal(input_tokens) / 1_000_000) * pricing.input_per_million
        output_cost = (Decimal(output_tokens) / 1_000_000) * pricing.output_per_million

        return input_cost + output_cost

    def record(self, model: str, input_tokens: int, output_tokens: int,
              user_id: Optional[str] = None,
              request_id: Optional[str] = None) -> Decimal:
        """Store one usage record and return its computed cost.

        Args:
            model: Model name the request used.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            user_id: Optional identifier used by ``get_cost_by_user``.
            request_id: Optional identifier stored on the record.

        Returns:
            The cost computed by ``calculate_cost`` for this request.
        """
        cost = self.calculate_cost(model, input_tokens, output_tokens)

        record = UsageRecord(
            timestamp=datetime.now(),
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            user_id=user_id,
            request_id=request_id
        )

        with self._lock:
            self.records.append(record)

        return cost

    def _records_since(self, period: Optional[timedelta]) -> "List[UsageRecord]":
        """Return a snapshot of the records inside the trailing ``period``.

        ``None`` selects every record. Any other value, including
        ``timedelta(0)``, keeps only records whose timestamp is strictly
        later than ``datetime.now() - period``.
        """
        with self._lock:
            # Compare against None explicitly: timedelta(0) is falsy and
            # must not be mistaken for "no filter".
            if period is None:
                return list(self.records)
            cutoff = datetime.now() - period
            return [r for r in self.records if r.timestamp > cutoff]

    def get_total_cost(self, period: timedelta = None) -> Decimal:
        """Return the total recorded cost, optionally limited to ``period``.

        Always returns a ``Decimal`` (``Decimal("0")`` when nothing matches).
        """
        # The Decimal start value keeps the result a Decimal when empty.
        return sum((r.cost for r in self._records_since(period)), Decimal("0"))

    def get_cost_by_model(self, period: timedelta = None) -> Dict[str, Decimal]:
        """Return total cost per model name, optionally limited to ``period``."""
        costs: Dict[str, Decimal] = {}
        for record in self._records_since(period):
            costs[record.model] = costs.get(record.model, Decimal("0")) + record.cost
        return costs

    def get_cost_by_user(self, period: timedelta = None) -> Dict[str, Decimal]:
        """Return total cost per user id, optionally limited to ``period``.

        Records with no (or an empty) ``user_id`` are grouped under
        ``"anonymous"``.
        """
        costs: Dict[str, Decimal] = {}
        for record in self._records_since(period):
            user = record.user_id or "anonymous"
            costs[user] = costs.get(user, Decimal("0")) + record.cost
        return costs

Budget Enforcement

from decimal import Decimal
from enum import Enum

class BudgetAction(Enum):
    """Outcome of a budget check, as returned by BudgetEnforcer.check_budget."""
    # No budget is set, or projected utilization is below the warning threshold.
    ALLOW = "allow"
    # Projected utilization reached the budget's warning_threshold;
    # call_with_budget does not treat this differently from ALLOW.
    WARN = "warn"
    # Projected utilization reached the budget's limit_threshold;
    # call_with_budget switches the request to "gpt-4o-mini".
    LIMIT = "limit"
    # Projected utilization is at or above 100%; call_with_budget raises
    # BudgetExceededError.
    BLOCK = "block"

@dataclass
class Budget:
    """Spending limit over a trailing time window, with escalation thresholds."""
    # Human-readable label; stored only, not read by code in this file.
    name: str
    # Maximum spend allowed within `period`.
    amount: Decimal
    # Length of the trailing window over which spend is totalled.
    period: timedelta
    # Fraction of `amount` (projected spend / amount) at which
    # check_budget returns BudgetAction.WARN.
    warning_threshold: float = 0.75
    # Fraction of `amount` at which check_budget returns BudgetAction.LIMIT.
    limit_threshold: float = 0.9

class BudgetEnforcer:
    """Enforce budget limits using spend reported by a CostTracker."""

    def __init__(self, cost_tracker: CostTracker):
        self.cost_tracker = cost_tracker
        self.budgets: Dict[str, Budget] = {}

    def set_budget(self, entity_id: str, budget: Budget):
        """Register (or replace) the budget for ``entity_id``."""
        self.budgets[entity_id] = budget

    def check_budget(self, entity_id: str, estimated_cost: Decimal) -> tuple[BudgetAction, str]:
        """Decide whether a request with ``estimated_cost`` fits the budget.

        Args:
            entity_id: Key previously passed to ``set_budget``.
            estimated_cost: Expected cost of the request about to be made.

        Returns:
            ``(action, message)``. The action is ALLOW when no budget is
            set; otherwise it is chosen from the projected utilization
            ``(current spend + estimated cost) / budget amount``:
            BLOCK at >= 1.0, LIMIT at >= ``limit_threshold``, WARN at
            >= ``warning_threshold``, else ALLOW. A budget whose amount is
            zero or negative always yields BLOCK.
        """
        budget = self.budgets.get(entity_id)
        if budget is None:
            return BudgetAction.ALLOW, "No budget set"

        # NOTE(review): get_total_cost() totals every record in the window,
        # not only those belonging to entity_id, so all entities share one
        # spend figure. Confirm whether per-entity accounting is intended.
        current_spend = self.cost_tracker.get_total_cost(budget.period)
        exhausted = f"Budget exhausted: ${current_spend:.2f}/${budget.amount:.2f}"

        # A non-positive budget cannot be divided into a utilization ratio
        # (Decimal division by zero raises); nothing fits in it, so block.
        if budget.amount <= 0:
            return BudgetAction.BLOCK, exhausted

        projected_spend = current_spend + estimated_cost
        utilization = float(projected_spend / budget.amount)

        if utilization >= 1.0:
            return BudgetAction.BLOCK, exhausted

        if utilization >= budget.limit_threshold:
            return BudgetAction.LIMIT, f"Budget nearly exhausted: {utilization:.1%} used"

        if utilization >= budget.warning_threshold:
            return BudgetAction.WARN, f"Budget warning: {utilization:.1%} used"

        return BudgetAction.ALLOW, f"Budget available: {utilization:.1%} used"

    def get_budget_status(self, entity_id: str) -> dict:
        """Return a summary of the budget registered for ``entity_id``.

        Returns:
            ``{"error": "No budget set"}`` when no budget is registered,
            otherwise a dict with float ``budget``, ``spent``,
            ``remaining`` and ``utilization`` plus integer ``period_days``
            (whole days only; a sub-day period reports 0).
        """
        budget = self.budgets.get(entity_id)
        if budget is None:
            return {"error": "No budget set"}

        current_spend = self.cost_tracker.get_total_cost(budget.period)

        # Avoid dividing by a zero budget; report it as fully used, which
        # matches check_budget blocking every request in that case.
        if budget.amount > 0:
            utilization = float(current_spend / budget.amount)
        else:
            utilization = 1.0

        return {
            "budget": float(budget.amount),
            "spent": float(current_spend),
            "remaining": float(budget.amount - current_spend),
            "utilization": utilization,
            "period_days": budget.period.days
        }

# Usage: module-level tracker and enforcer, shared by call_with_budget below.
cost_tracker = CostTracker()
budget_enforcer = BudgetEnforcer(cost_tracker)

# Set daily budget of $100 for "org_123". The warning and limit thresholds
# are left at the Budget defaults (0.75 and 0.9).
budget_enforcer.set_budget("org_123", Budget(
    name="daily_budget",
    amount=Decimal("100.00"),
    period=timedelta(days=1)
))

# Before each request
def call_with_budget(org_id: str, prompt: str, model: str = "gpt-4o") -> str:
    """Send ``prompt`` to the chat API after checking ``org_id``'s budget.

    The request cost is estimated up front and passed to the budget
    enforcer. A BLOCK verdict raises ``BudgetExceededError``; a LIMIT
    verdict switches the request to ``gpt-4o-mini``. The actual token
    usage reported by the API is recorded afterwards under ``org_id``.

    Returns:
        The content of the first choice's message.
    """
    # Rough estimate: about four characters per token plus a fixed
    # 1000-token allowance, and 1000 output tokens.
    prompt_token_guess = 1000 + len(prompt) // 4
    projected_cost = cost_tracker.calculate_cost(
        model, input_tokens=prompt_token_guess, output_tokens=1000
    )

    verdict, reason = budget_enforcer.check_budget(org_id, projected_cost)
    if verdict == BudgetAction.BLOCK:
        raise BudgetExceededError(reason)

    # Close to the limit: fall back to the cheaper model.
    chosen_model = "gpt-4o-mini" if verdict == BudgetAction.LIMIT else model

    completion = client.chat.completions.create(
        model=chosen_model,
        messages=[{"role": "user", "content": prompt}],
    )

    # Record what the call really consumed, not the estimate.
    reported = completion.usage
    cost_tracker.record(
        model=chosen_model,
        input_tokens=reported.prompt_tokens,
        output_tokens=reported.completion_tokens,
        user_id=org_id,
    )

    return completion.choices[0].message.content

class BudgetExceededError(Exception):
    """Raised by call_with_budget when the budget check returns BLOCK."""
    pass

Cost Optimization Strategies

class CostOptimizer:
    """Plan cheaper AI requests: response cache, model choice, prompt trimming.

    ``optimize_request`` does not call any API itself. It returns either a
    cached response or a plan (model, prompt, max_tokens) for the caller to
    execute. Call ``cache_response`` with the result afterwards so that
    later identical prompts are served from the cache.
    """

    def __init__(self):
        # Maps the SHA-256 hex digest of a prompt to its stored response.
        self.cache = {}
        # Used only for its stateless calculate_cost() estimate.
        self.cost_tracker = CostTracker()

    def cache_response(self, prompt: str, response) -> None:
        """Store ``response`` so later calls with the same ``prompt`` hit the cache.

        The key is derived from the prompt exactly as passed, the same way
        ``optimize_request`` looks it up.
        """
        self.cache[self._hash_prompt(prompt)] = response

    def optimize_request(self, prompt: str, required_quality: str = "standard") -> dict:
        """Return a cached response or a cost-optimized request plan.

        Args:
            prompt: The raw user prompt.
            required_quality: ``"high"`` selects ``gpt-4o``; any other
                value selects ``gpt-4o-mini``.

        Returns:
            On a cache hit: ``{"source": "cache", "cost": Decimal("0"),
            "response": <cached value>}``. Otherwise a dict with
            ``model``, ``prompt`` (whitespace-normalized), ``max_tokens``,
            ``strategies_applied`` and ``estimated_cost``.
        """
        # Strategy 1: Check cache
        cache_key = self._hash_prompt(prompt)
        if cache_key in self.cache:
            return {
                "source": "cache",
                "cost": Decimal("0"),
                "response": self.cache[cache_key]
            }
        strategies = ["cache_miss"]

        # Strategy 2: Model selection based on quality. Only "high"
        # justifies the larger model; gpt-4o-mini is roughly 16x cheaper.
        model = "gpt-4o" if required_quality == "high" else "gpt-4o-mini"
        strategies.append(f"model:{model}")

        # Strategy 3: Prompt optimization
        optimized_prompt = self._optimize_prompt(prompt)
        strategies.append(f"prompt_reduced:{len(prompt)-len(optimized_prompt)}_chars")

        # Strategy 4: Output length control
        max_tokens = self._estimate_required_tokens(prompt)
        strategies.append(f"max_tokens:{max_tokens}")

        return {
            "model": model,
            "prompt": optimized_prompt,
            "max_tokens": max_tokens,
            "strategies_applied": strategies,
            # Input tokens are approximated as four characters per token.
            "estimated_cost": self.cost_tracker.calculate_cost(
                model, len(optimized_prompt) // 4, max_tokens
            )
        }

    def _hash_prompt(self, prompt: str) -> str:
        """Return the SHA-256 hex digest of ``prompt``, used as the cache key."""
        import hashlib
        return hashlib.sha256(prompt.encode()).hexdigest()

    def _optimize_prompt(self, prompt: str) -> str:
        """Collapse every run of whitespace to one space and strip the ends."""
        return " ".join(prompt.split())

    def _estimate_required_tokens(self, prompt: str) -> int:
        """Pick an output-token cap from the prompt's approximate length.

        Heuristic: shorter prompts usually need shorter responses. The
        prompt length is approximated as four characters per token.
        """
        prompt_tokens = len(prompt) // 4

        if prompt_tokens < 100:
            return 500
        if prompt_tokens < 500:
            return 1000
        return 2000

class ModelRouter:
    """Route requests to a model based on prompt complexity and remaining budget."""

    def __init__(self):
        # Ordered keyword rules consulted by _estimate_complexity.
        self.complexity_classifier = self._build_classifier()

    def _build_classifier(self):
        """Return the keyword rules used to classify a prompt.

        Each rule is ``(label, keywords)``. Rules are checked in order and
        the first rule with a keyword found in the prompt wins, so
        reasoning keywords take precedence over complex-task keywords.
        """
        return (
            # Keywords indicating reasoning tasks
            ("reasoning", ("prove", "derive", "analyze", "explain why", "step by step")),
            # Keywords indicating complex tasks
            ("high", ("compare", "contrast", "evaluate", "synthesize")),
        )

    def route(self, prompt: str, budget_remaining: Decimal) -> str:
        """Choose a model name for ``prompt`` given the remaining budget.

        Args:
            prompt: The user prompt to classify.
            budget_remaining: Remaining budget, compared against fixed
                thresholds (0.01, 0.10, 1.00).

        Returns:
            ``"gpt-4o-mini"``, ``"gpt-4o"`` or ``"o1-mini"``.
        """
        complexity = self._estimate_complexity(prompt)

        # Almost nothing left: always take the cheapest option.
        if budget_remaining < Decimal("0.01"):
            return "gpt-4o-mini"

        if complexity == "high":
            if budget_remaining > Decimal("0.10"):
                return "gpt-4o"
            return "gpt-4o-mini"

        if complexity == "reasoning":
            if budget_remaining > Decimal("1.00"):
                return "o1-mini"
            return "gpt-4o"

        return "gpt-4o-mini"

    def _estimate_complexity(self, prompt: str) -> str:
        """Classify ``prompt`` as ``"reasoning"``, ``"high"`` or ``"standard"``.

        Matching is a case-insensitive substring search over the rules in
        ``self.complexity_classifier``; ``"standard"`` is the fallback.
        """
        prompt_lower = prompt.lower()

        for label, keywords in self.complexity_classifier:
            if any(kw in prompt_lower for kw in keywords):
                return label

        return "standard"

Cost Reporting

class CostReporter:
    """Generate cost reports from the usage held by a CostTracker.

    Every monetary figure in the returned dicts is converted to ``float``
    so the reports can be serialized directly (for example to JSON).
    """

    def __init__(self, cost_tracker: CostTracker):
        self.cost_tracker = cost_tracker

    def daily_report(self) -> dict:
        """Return totals for the trailing 24 hours, broken down by model and user."""
        period = timedelta(days=1)
        tracker = self.cost_tracker

        return {
            "period": "daily",
            "total_cost": float(tracker.get_total_cost(period)),
            "by_model": {
                k: float(v)
                for k, v in tracker.get_cost_by_model(period).items()
            },
            "by_user": {
                k: float(v)
                for k, v in tracker.get_cost_by_user(period).items()
            },
            "generated_at": datetime.now().isoformat()
        }

    def monthly_summary(self) -> dict:
        """Return a trailing-30-day summary with a per-day cost trend.

        ``daily_trend[i]`` is the cost of records whose timestamp lies in
        ``[now - (i + 1) days, now - i days)``, so index 0 is the most
        recent 24 hours. ``top_users`` holds at most the ten highest
        spenders, ordered from highest to lowest, with float values.
        """
        period = timedelta(days=30)
        one_day = timedelta(days=1)
        tracker = self.cost_tracker

        # Fix "now" once so the 30 windows are contiguous and do not drift,
        # and copy the record list so the totals below are computed from a
        # stable snapshot even if new usage is appended meanwhile.
        now = datetime.now()
        snapshot = list(tracker.records)

        daily_costs = []
        for i in range(30):
            day_end = now - i * one_day
            day_start = day_end - one_day
            day_cost = sum(
                r.cost for r in snapshot
                if day_start <= r.timestamp < day_end
            )
            daily_costs.append(float(day_cost))

        ranked_users = sorted(
            tracker.get_cost_by_user(period).items(),
            key=lambda item: item[1],
            reverse=True
        )[:10]

        return {
            "period": "monthly",
            "total_cost": float(tracker.get_total_cost(period)),
            # daily_costs always has exactly 30 entries.
            "daily_average": sum(daily_costs) / len(daily_costs),
            "daily_trend": daily_costs,
            "by_model": {
                k: float(v)
                for k, v in tracker.get_cost_by_model(period).items()
            },
            # Floats, like every other figure in the report.
            "top_users": {user: float(cost) for user, cost in ranked_users},
            "recommendations": self._generate_recommendations()
        }

    def _generate_recommendations(self) -> List[str]:
        """Return advice strings based on per-model spend over the last 7 days."""
        recommendations = []

        by_model = self.cost_tracker.get_cost_by_model(timedelta(days=7))
        no_spend = Decimal("0")

        # Check if expensive models are overused
        if by_model.get("gpt-4o", no_spend) > Decimal("100"):
            recommendations.append(
                "Consider using gpt-4o-mini for simpler tasks - potential 90% savings"
            )

        if by_model.get("o1-preview", no_spend) > Decimal("50"):
            recommendations.append(
                "Review o1-preview usage - ensure it's used only for complex reasoning tasks"
            )

        return recommendations

Cost control in AI applications requires a multi-layered approach: track everything, set budgets, optimize requests, and review regularly. The strategies here help you build cost-effective AI applications without sacrificing quality.

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.