1 min read
Cost Controls for AI Applications: Budgeting and Optimization
I wrote “Cost Controls for AI Applications: Budgeting and Optimization” to share practical, production-minded guidance on this topic.
Cost Tracking Foundation
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from decimal import Decimal
import threading
@dataclass
class ModelPricing:
    """Token pricing for a single model, expressed per one million tokens.

    Prices are ``Decimal`` so that cost arithmetic stays exact.
    """

    # Model identifier; PRICING uses the same string as its dictionary key.
    model: str
    # Price charged per one million input tokens.
    input_per_million: Decimal
    # Price charged per one million output tokens.
    output_per_million: Decimal
    # Optional separate rate for cached input tokens. NOTE(review): no code
    # visible in this file reads this field yet - confirm intended use.
    cached_input_per_million: Optional[Decimal] = None
# Current pricing (as of September 2024).
# Each entry is (model, input price, output price) per one million tokens.
# The table is keyed by the entry's own model name, so the key and
# ModelPricing.model are written only once per model.
PRICING = {
    entry.model: entry
    for entry in (
        ModelPricing("gpt-4o", Decimal("2.50"), Decimal("10.00")),
        ModelPricing("gpt-4o-mini", Decimal("0.15"), Decimal("0.60")),
        ModelPricing("o1-preview", Decimal("15.00"), Decimal("60.00")),
        ModelPricing("o1-mini", Decimal("3.00"), Decimal("12.00")),
        # Embedding models produce no output tokens, hence the zero rate.
        ModelPricing("text-embedding-3-small", Decimal("0.02"), Decimal("0")),
    )
}
@dataclass
class UsageRecord:
    """One recorded API call: when it happened, what it used, what it cost."""

    # When the usage was recorded (CostTracker.record uses datetime.now()).
    timestamp: datetime
    # Model name the request was made against.
    model: str
    # Number of input tokens billed for the request.
    input_tokens: int
    # Number of output tokens billed for the request.
    output_tokens: int
    # Cost of the request as computed from the pricing table.
    cost: Decimal
    # Optional caller identity; per-user reports group None as "anonymous".
    user_id: Optional[str] = None
    # Optional identifier of the originating request.
    request_id: Optional[str] = None
class CostTracker:
    """In-memory ledger of AI API usage and its cost.

    Records are kept in ``self.records``; appends and reads of that list are
    guarded by ``self._lock``. Timestamps are naive ``datetime.now()`` values.
    """

    def __init__(self):
        self.records: List[UsageRecord] = []
        self._lock = threading.Lock()

    def calculate_cost(self, model: str, input_tokens: int,
                       output_tokens: int) -> Decimal:
        """Return the cost of a request from the ``PRICING`` table.

        Args:
            model: Model name to look up in ``PRICING``.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.

        Returns:
            Input cost plus output cost as a ``Decimal``.
        """
        # Default to gpt-4o pricing if the model is unknown.
        pricing = PRICING.get(model, PRICING["gpt-4o"])
        input_cost = (Decimal(input_tokens) / 1_000_000) * pricing.input_per_million
        output_cost = (Decimal(output_tokens) / 1_000_000) * pricing.output_per_million
        return input_cost + output_cost

    def record(self, model: str, input_tokens: int, output_tokens: int,
               user_id: Optional[str] = None,
               request_id: Optional[str] = None):
        """Append a usage record and return the cost that was computed for it.

        Args:
            model: Model the request used.
            input_tokens: Input tokens consumed.
            output_tokens: Output tokens consumed.
            user_id: Optional identity to attribute the cost to.
            request_id: Optional identifier of the request.

        Returns:
            The cost stored on the new record.
        """
        cost = self.calculate_cost(model, input_tokens, output_tokens)
        entry = UsageRecord(
            timestamp=datetime.now(),
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            user_id=user_id,
            request_id=request_id,
        )
        with self._lock:
            self.records.append(entry)
        return cost

    def _records_in_window(self, period: Optional[timedelta]) -> "List[UsageRecord]":
        """Return a snapshot of the records newer than ``now - period``.

        ``None`` means "all records". The comparison is ``period is not None``
        rather than truthiness so that a zero-length window selects nothing
        instead of silently selecting the whole history.
        """
        cutoff = None if period is None else datetime.now() - period
        with self._lock:
            if cutoff is None:
                # Copy so callers aggregate a stable list outside the lock.
                return list(self.records)
            return [rec for rec in self.records if rec.timestamp > cutoff]

    @staticmethod
    def _total_by(records, key) -> Dict[str, Decimal]:
        """Sum ``cost`` over ``records`` grouped by ``key(record)``."""
        totals: Dict[str, Decimal] = {}
        for rec in records:
            group = key(rec)
            totals[group] = totals.get(group, Decimal("0")) + rec.cost
        return totals

    def get_total_cost(self, period: timedelta = None) -> Decimal:
        """Return the total cost over ``period`` (all time when ``None``).

        Always returns a ``Decimal``; an empty selection yields ``Decimal("0")``.
        """
        selected = self._records_in_window(period)
        # Explicit start value: a bare sum() of nothing would return int 0.
        return sum((rec.cost for rec in selected), Decimal("0"))

    def get_cost_by_model(self, period: timedelta = None) -> Dict[str, Decimal]:
        """Return cost per model name over ``period`` (all time when ``None``)."""
        return self._total_by(self._records_in_window(period),
                              lambda rec: rec.model)

    def get_cost_by_user(self, period: timedelta = None) -> Dict[str, Decimal]:
        """Return cost per user over ``period`` (all time when ``None``).

        Records without a ``user_id`` are grouped under ``"anonymous"``.
        """
        return self._total_by(self._records_in_window(period),
                              lambda rec: rec.user_id or "anonymous")
Budget Enforcement
from decimal import Decimal
from enum import Enum
class BudgetAction(Enum):
    """Outcome of a budget check, from least to most restrictive."""

    # Projected spend is below the warning threshold (or no budget is set).
    ALLOW = "allow"
    # Projected spend reached the budget's warning threshold.
    WARN = "warn"
    # Projected spend reached the limit threshold; call_with_budget reacts
    # by switching to a cheaper model.
    LIMIT = "limit"
    # Projected spend would meet or exceed the budget; call_with_budget
    # raises BudgetExceededError.
    BLOCK = "block"
@dataclass
class Budget:
    """Spending limit over a rolling time window."""

    # Human-readable label for the budget (e.g. "daily_budget").
    name: str
    # Maximum spend allowed within one period.
    amount: Decimal
    # Length of the look-back window the spend is measured over.
    period: timedelta
    # Fraction of the budget at which a check reports WARN.
    warning_threshold: float = 0.75
    # Fraction of the budget at which a check reports LIMIT.
    limit_threshold: float = 0.9
class BudgetEnforcer:
    """Check requests against per-entity budgets.

    An entity's spend is the cost of the tracker records whose ``user_id``
    equals the entity id (this is how ``call_with_budget`` records usage), so
    one entity's traffic never consumes another entity's budget.
    """

    def __init__(self, cost_tracker: CostTracker):
        self.cost_tracker = cost_tracker
        self.budgets: Dict[str, Budget] = {}

    def set_budget(self, entity_id: str, budget: Budget):
        """Set (or replace) the budget for ``entity_id``."""
        self.budgets[entity_id] = budget

    def _entity_spend(self, entity_id: str, period: timedelta) -> Decimal:
        """Return what ``entity_id`` spent within the last ``period``."""
        spend_by_entity = self.cost_tracker.get_cost_by_user(period)
        return spend_by_entity.get(entity_id, Decimal("0"))

    def check_budget(self, entity_id: str, estimated_cost: Decimal) -> tuple[BudgetAction, str]:
        """Decide whether a request of ``estimated_cost`` fits the budget.

        Args:
            entity_id: Entity whose budget is consulted.
            estimated_cost: Expected cost of the request about to be made.

        Returns:
            ``(action, message)``. ``ALLOW`` with "No budget set" when the
            entity has no budget; otherwise the action is derived from the
            projected utilization (current spend plus the estimate, divided
            by the budget amount) against the budget's thresholds.
        """
        budget = self.budgets.get(entity_id)
        if budget is None:
            return BudgetAction.ALLOW, "No budget set"

        current_spend = self._entity_spend(entity_id, budget.period)
        exhausted = f"Budget exhausted: ${current_spend:.2f}/${budget.amount:.2f}"

        # A zero or negative budget can never admit a request; blocking here
        # also avoids dividing by zero below.
        if budget.amount <= 0:
            return BudgetAction.BLOCK, exhausted

        projected_spend = current_spend + estimated_cost
        utilization = float(projected_spend / budget.amount)

        if utilization >= 1.0:
            return BudgetAction.BLOCK, exhausted
        if utilization >= budget.limit_threshold:
            return BudgetAction.LIMIT, f"Budget nearly exhausted: {utilization:.1%} used"
        if utilization >= budget.warning_threshold:
            return BudgetAction.WARN, f"Budget warning: {utilization:.1%} used"
        return BudgetAction.ALLOW, f"Budget available: {utilization:.1%} used"

    def get_budget_status(self, entity_id: str) -> dict:
        """Return a float-valued snapshot of the entity's budget.

        Returns:
            ``{"error": "No budget set"}`` when the entity has no budget,
            otherwise a dict with ``budget``, ``spent``, ``remaining``,
            ``utilization`` and ``period_days`` (whole days of the period).
        """
        budget = self.budgets.get(entity_id)
        if budget is None:
            return {"error": "No budget set"}

        current_spend = self._entity_spend(entity_id, budget.period)
        # A non-positive budget is reported as fully utilized instead of
        # raising a division error.
        if budget.amount > 0:
            utilization = float(current_spend / budget.amount)
        else:
            utilization = 1.0
        return {
            "budget": float(budget.amount),
            "spent": float(current_spend),
            "remaining": float(budget.amount - current_spend),
            "utilization": utilization,
            "period_days": budget.period.days,
        }
# Usage: one shared tracker feeds the enforcer that guards every request.
cost_tracker = CostTracker()
budget_enforcer = BudgetEnforcer(cost_tracker)

# Set daily budget of $100 for the organization "org_123".
budget_enforcer.set_budget(
    "org_123",
    Budget(
        name="daily_budget",
        amount=Decimal("100.00"),
        period=timedelta(days=1),
    ),
)
# Before each request
def call_with_budget(org_id: str, prompt: str, model: str = "gpt-4o") -> str:
    """Run a chat completion for ``org_id`` only if its budget allows it.

    Relies on the module-level ``cost_tracker`` and ``budget_enforcer`` and
    on a ``client`` object that is not defined in this section
    (NOTE(review): presumably an OpenAI-style client - confirm).

    Args:
        org_id: Entity whose budget is checked and charged.
        prompt: User message sent to the model.
        model: Requested model; replaced by "gpt-4o-mini" when the budget
            check returns LIMIT.

    Returns:
        The content of the first choice in the response.

    Raises:
        BudgetExceededError: When the budget check returns BLOCK.
    """
    # Estimate cost: roughly four characters per token plus a fixed
    # 1000-token allowance on the input side, and 1000 output tokens.
    estimated_tokens = len(prompt) // 4 + 1000
    estimated_cost = cost_tracker.calculate_cost(
        model,
        input_tokens=estimated_tokens,
        output_tokens=1000,
    )

    # Check budget
    action, message = budget_enforcer.check_budget(org_id, estimated_cost)
    if action == BudgetAction.BLOCK:
        raise BudgetExceededError(message)
    if action == BudgetAction.LIMIT:
        # Use cheaper model
        model = "gpt-4o-mini"

    # Make request
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )

    # Record actual cost from the token counts the API reported, attributed
    # to the organization so later budget checks can see it.
    usage = response.usage
    cost_tracker.record(
        model=model,
        input_tokens=usage.prompt_tokens,
        output_tokens=usage.completion_tokens,
        user_id=org_id,
    )
    return response.choices[0].message.content
class BudgetExceededError(Exception):
    """Raised when a request is refused because its budget is exhausted."""
Cost Optimization Strategies
class CostOptimizer:
    """Optimize AI costs through caching, model choice and prompt trimming."""

    def __init__(self):
        # Maps the SHA-256 hex digest of a prompt to its stored response.
        self.cache = {}
        self.cost_tracker = CostTracker()

    def cache_response(self, prompt: str, response) -> None:
        """Store ``response`` so a later identical ``prompt`` is served from cache.

        Without this the cache checked by ``optimize_request`` is never
        filled and can never produce a hit.
        """
        self.cache[self._hash_prompt(prompt)] = response

    def optimize_request(self, prompt: str, required_quality: str = "standard") -> dict:
        """Build a cost-optimized plan for ``prompt``.

        Args:
            prompt: The raw prompt text.
            required_quality: "high" selects gpt-4o; any other value
                selects gpt-4o-mini.

        Returns:
            On a cache hit: ``{"source": "cache", "cost": Decimal("0"),
            "response": <cached value>}``. Otherwise a dict with ``model``,
            ``prompt`` (whitespace-normalized), ``max_tokens``,
            ``strategies_applied`` and ``estimated_cost``.
        """
        # Strategy 1: Check cache
        cache_key = self._hash_prompt(prompt)
        if cache_key in self.cache:
            return {
                "source": "cache",
                "cost": Decimal("0"),
                "response": self.cache[cache_key],
            }
        strategies = ["cache_miss"]

        # Strategy 2: Model selection based on quality. Only "high" justifies
        # gpt-4o; gpt-4o-mini is roughly 16x cheaper per the pricing table.
        model = "gpt-4o" if required_quality == "high" else "gpt-4o-mini"
        strategies.append(f"model:{model}")

        # Strategy 3: Prompt optimization
        optimized_prompt = self._optimize_prompt(prompt)
        strategies.append(f"prompt_reduced:{len(prompt)-len(optimized_prompt)}_chars")

        # Strategy 4: Output length control (sized from the original prompt)
        max_tokens = self._estimate_required_tokens(prompt)
        strategies.append(f"max_tokens:{max_tokens}")

        return {
            "model": model,
            "prompt": optimized_prompt,
            "max_tokens": max_tokens,
            "strategies_applied": strategies,
            # About four characters per token is used as the input estimate.
            "estimated_cost": self.cost_tracker.calculate_cost(
                model, len(optimized_prompt) // 4, max_tokens
            ),
        }

    def _hash_prompt(self, prompt: str) -> str:
        """Return the SHA-256 hex digest of ``prompt`` (used as cache key)."""
        import hashlib
        return hashlib.sha256(prompt.encode()).hexdigest()

    def _optimize_prompt(self, prompt: str) -> str:
        """Collapse every run of whitespace to one space and trim the ends."""
        return " ".join(prompt.split())

    def _estimate_required_tokens(self, prompt: str) -> int:
        """Estimate required output tokens from the prompt length.

        Heuristic: shorter prompts usually need shorter responses.
        """
        prompt_tokens = len(prompt) // 4
        if prompt_tokens < 100:
            return 500
        if prompt_tokens < 500:
            return 1000
        return 2000
class ModelRouter:
    """Route requests to optimal models based on cost/quality tradeoffs"""

    # Keywords indicating reasoning tasks
    _REASONING_KEYWORDS = ("prove", "derive", "analyze", "explain why", "step by step")
    # Keywords indicating complex tasks
    _COMPLEX_KEYWORDS = ("compare", "contrast", "evaluate", "synthesize")

    def __init__(self):
        self.complexity_classifier = self._build_classifier()

    def _build_classifier(self):
        """Return the ordered ``(label, keywords)`` rules used for classification.

        Order matters: the first rule with a matching keyword wins, so
        reasoning keywords take precedence over complex-task keywords.
        """
        return (
            ("reasoning", self._REASONING_KEYWORDS),
            ("high", self._COMPLEX_KEYWORDS),
        )

    def route(self, prompt: str, budget_remaining: Decimal) -> str:
        """Choose a model for ``prompt`` given the remaining budget.

        Args:
            prompt: The request text, used to estimate complexity.
            budget_remaining: Budget still available.

        Returns:
            A model name: "gpt-4o-mini", "gpt-4o" or "o1-mini".
        """
        # Estimate complexity
        complexity = self._estimate_complexity(prompt)

        # With almost nothing left, always fall back to the cheapest option.
        if budget_remaining < Decimal("0.01"):
            return "gpt-4o-mini"

        if complexity == "high":
            if budget_remaining > Decimal("0.10"):
                return "gpt-4o"
            return "gpt-4o-mini"

        if complexity == "reasoning":
            # o1-mini only when there is comfortable headroom; otherwise
            # gpt-4o handles the reasoning task.
            if budget_remaining > Decimal("1.00"):
                return "o1-mini"
            return "gpt-4o"

        return "gpt-4o-mini"

    def _estimate_complexity(self, prompt: str) -> str:
        """Classify ``prompt`` as "reasoning", "high" or "standard".

        Matching is a case-insensitive substring search over the keyword
        rules; a prompt that matches none of them is "standard".
        """
        prompt_lower = prompt.lower()
        for label, keywords in self.complexity_classifier:
            if any(kw in prompt_lower for kw in keywords):
                return label
        return "standard"
Cost Reporting
class CostReporter:
    """Generate cost reports from a cost tracker.

    All monetary values in the returned reports are plain floats so the
    reports can be serialized directly.
    """

    def __init__(self, cost_tracker: "CostTracker"):
        self.cost_tracker = cost_tracker

    @staticmethod
    def _as_floats(costs) -> dict:
        """Convert a ``{key: Decimal}`` mapping to ``{key: float}``."""
        return {key: float(value) for key, value in costs.items()}

    def daily_report(self) -> dict:
        """Generate a cost report covering the last 24 hours.

        Returns:
            Dict with ``period``, ``total_cost``, ``by_model``, ``by_user``
            and an ISO-formatted ``generated_at`` timestamp.
        """
        period = timedelta(days=1)
        return {
            "period": "daily",
            "total_cost": float(self.cost_tracker.get_total_cost(period)),
            "by_model": self._as_floats(self.cost_tracker.get_cost_by_model(period)),
            "by_user": self._as_floats(self.cost_tracker.get_cost_by_user(period)),
            "generated_at": datetime.now().isoformat(),
        }

    def monthly_summary(self) -> dict:
        """Generate a summary of the last 30 days.

        Returns:
            Dict with ``period``, ``total_cost``, ``daily_average``,
            ``daily_trend`` (30 floats, index 0 = most recent 24 hours),
            ``by_model``, ``top_users`` (up to ten highest spenders, in
            descending order) and ``recommendations``.
        """
        days = 30
        period = timedelta(days=days)
        one_day = timedelta(days=1)

        # Take "now" once so all day buckets share the same boundaries, and
        # work on a snapshot so concurrent appends cannot disturb the scan.
        now = datetime.now()
        snapshot = list(self.cost_tracker.records)

        # Single pass: floor-dividing a record's age by one day yields its
        # bucket index (0 = within the last 24 hours).
        daily_totals = [Decimal("0")] * days
        for rec in snapshot:
            age_in_days = (now - rec.timestamp) // one_day
            if 0 <= age_in_days < days:
                daily_totals[age_in_days] += rec.cost
        daily_costs = [float(total) for total in daily_totals]

        ranked_users = sorted(
            self.cost_tracker.get_cost_by_user(period).items(),
            key=lambda item: item[1],
            reverse=True,
        )[:10]

        return {
            "period": "monthly",
            "total_cost": float(self.cost_tracker.get_total_cost(period)),
            "daily_average": sum(daily_costs) / len(daily_costs) if daily_costs else 0,
            "daily_trend": daily_costs,
            "by_model": self._as_floats(self.cost_tracker.get_cost_by_model(period)),
            # Floats, like every other monetary value in the report.
            "top_users": {user: float(cost) for user, cost in ranked_users},
            "recommendations": self._generate_recommendations(),
        }

    def _generate_recommendations(self) -> List[str]:
        """Generate cost optimization recommendations from the last 7 days."""
        recommendations = []
        by_model = self.cost_tracker.get_cost_by_model(timedelta(days=7))
        no_spend = Decimal("0")

        # Check if expensive models are overused
        if by_model.get("gpt-4o", no_spend) > Decimal("100"):
            recommendations.append(
                "Consider using gpt-4o-mini for simpler tasks - potential 90% savings"
            )
        if by_model.get("o1-preview", no_spend) > Decimal("50"):
            recommendations.append(
                "Review o1-preview usage - ensure it's used only for complex reasoning tasks"
            )
        return recommendations
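Cost control in AI applications requires a multi-layered approach: track everything, set budgets, optimize requests, and review regularly. The strategies here help you build cost-effective AI applications without sacrificing quality.

## Takeaways

- Record the token usage and cost of every request, so that budgets and reports are based on measured spend rather than estimates.
- Check the budget before each call: warn early, switch to a cheaper model as the limit approaches, and block once the budget is exhausted.
- Reduce the cost of each request with caching, quality-based model selection, prompt trimming, and output-length limits.
- Review daily and monthly reports to spot expensive models and heavy users, then adjust routing rules and budgets accordingly.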