
Tracing Frameworks for AI: Distributed Tracing in LLM Applications

Distributed tracing is essential for understanding complex AI applications. Let’s explore how to implement effective tracing for LLM-powered systems.

Tracing Concepts for AI

A trace is a tree of spans: each span records a single operation, its timing, and a pointer to its parent. For AI workloads we extend the standard span with model names, token counts, and cost, so one trace can answer both "where did the time go?" and "where did the money go?". The sketch below builds a minimal tracer around these ideas.

from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional
from contextlib import contextmanager
from contextvars import ContextVar
import time
import uuid

@dataclass
class TraceContext:
    """Context that flows through a distributed trace"""
    trace_id: str
    span_id: str
    parent_span_id: Optional[str] = None
    baggage: Dict[str, str] = field(default_factory=dict)

@dataclass
class AISpan:
    """Span specialized for AI operations"""
    context: TraceContext
    operation_name: str
    service_name: str
    start_time: float
    end_time: Optional[float] = None
    status: str = "UNSET"
    attributes: Dict[str, Any] = field(default_factory=dict)
    events: List[Dict] = field(default_factory=list)

    # AI-specific fields
    model: Optional[str] = None
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    cost_usd: float = 0.0

    def set_ai_metrics(self, model: str, prompt_tokens: int,
                      completion_tokens: int, cost_usd: float = 0.0):
        """Set AI-specific metrics"""
        self.model = model
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.total_tokens = prompt_tokens + completion_tokens
        self.cost_usd = cost_usd
        self.attributes.update({
            "ai.model": model,
            "ai.prompt_tokens": prompt_tokens,
            "ai.completion_tokens": completion_tokens,
            "ai.total_tokens": self.total_tokens,
            "ai.cost_usd": cost_usd
        })

    def add_event(self, name: str, attributes: Optional[Dict] = None):
        self.events.append({
            "timestamp": time.time(),
            "name": name,
            "attributes": attributes or {}
        })

    def end(self, status: str = "OK"):
        self.end_time = time.time()
        self.status = status

    @property
    def duration_ms(self) -> float:
        if self.end_time:
            return (self.end_time - self.start_time) * 1000
        return 0

class AITracer:
    """Tracer specialized for AI applications"""

    def __init__(self, service_name: str):
        self.service_name = service_name
        self.spans: Dict[str, List[AISpan]] = {}
        # ContextVar rather than threading.local so the current context
        # propagates correctly across async/await boundaries
        self._context_var: ContextVar[Optional[TraceContext]] = ContextVar(
            f"{service_name}_trace_context", default=None
        )

    def get_current_context(self) -> Optional[TraceContext]:
        return self._context_var.get()

    def set_current_context(self, context: Optional[TraceContext]):
        self._context_var.set(context)

    @contextmanager
    def start_span(self, operation_name: str, **attributes) -> Iterator[AISpan]:
        """Start a new span"""
        current_context = self.get_current_context()

        if current_context:
            # Child span
            trace_id = current_context.trace_id
            parent_span_id = current_context.span_id
        else:
            # Root span
            trace_id = str(uuid.uuid4())
            parent_span_id = None

        span_id = str(uuid.uuid4())

        context = TraceContext(
            trace_id=trace_id,
            span_id=span_id,
            parent_span_id=parent_span_id
        )

        span = AISpan(
            context=context,
            operation_name=operation_name,
            service_name=self.service_name,
            start_time=time.time(),
            attributes=attributes
        )

        # Store span under its trace id
        self.spans.setdefault(trace_id, []).append(span)

        # Set as current context
        old_context = self.get_current_context()
        self.set_current_context(context)

        try:
            yield span
            span.end("OK")
        except Exception as e:
            span.end("ERROR")
            span.add_event("exception", {
                "type": type(e).__name__,
                "message": str(e)
            })
            raise
        finally:
            # Restore whatever context was active before this span (None for a root)
            self.set_current_context(old_context)

    @contextmanager
    def start_llm_span(self, operation_name: str, model: str) -> Iterator[AISpan]:
        """Start a span specifically for LLM calls"""
        with self.start_span(operation_name, ai_model=model) as span:
            span.add_event("llm_call_start", {"model": model})
            yield span

    def get_trace(self, trace_id: str) -> Dict:
        """Get complete trace data"""
        spans = self.spans.get(trace_id, [])

        return {
            "trace_id": trace_id,
            "spans": [
                {
                    "span_id": s.context.span_id,
                    "parent_span_id": s.context.parent_span_id,
                    "operation": s.operation_name,
                    "service": s.service_name,
                    "duration_ms": s.duration_ms,
                    "status": s.status,
                    "attributes": s.attributes,
                    "events": s.events,
                    "ai_metrics": {
                        "model": s.model,
                        "tokens": s.total_tokens,
                        "cost": s.cost_usd
                    } if s.model else None
                }
                for s in spans
            ],
            "total_duration_ms": self._calculate_trace_duration(spans),
            "total_tokens": sum(s.total_tokens for s in spans),
            "total_cost_usd": sum(s.cost_usd for s in spans)
        }

    def _calculate_trace_duration(self, spans: List[AISpan]) -> float:
        root_spans = [s for s in spans if not s.context.parent_span_id]
        if root_spans:
            return root_spans[0].duration_ms
        return 0
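
Before wiring the tracer into model calls, a quick standalone usage sketch (the operation names here are illustrative):

tracer = AITracer("demo-service")

with tracer.start_span("handle_request") as root:
    with tracer.start_span("retrieve_documents") as child:
        child.add_event("cache_miss")
        time.sleep(0.01)  # stand-in for real work

trace = tracer.get_trace(root.context.trace_id)
print(trace["total_duration_ms"])                # root span duration
print([s["operation"] for s in trace["spans"]])  # ['handle_request', 'retrieve_documents']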

LLM Call Instrumentation

With the tracer in place, we can wrap an LLM client so that every call automatically records latency, token usage, and estimated cost.

from openai import OpenAI

class TracedOpenAI:
    """OpenAI client with automatic tracing"""

    def __init__(self, tracer: AITracer):
        self.client = OpenAI()
        self.tracer = tracer

    def chat_completion(self, **kwargs):
        """Make a traced chat completion call; returns the raw OpenAI response"""
        model = kwargs.get("model", "gpt-4o")

        with self.tracer.start_llm_span("chat_completion", model) as span:
            # Add request attributes
            span.attributes["ai.messages_count"] = len(kwargs.get("messages", []))
            span.attributes["ai.max_tokens"] = kwargs.get("max_tokens")
            span.attributes["ai.temperature"] = kwargs.get("temperature")

            # Add event for prompt
            if kwargs.get("messages"):
                span.add_event("prompt", {
                    "system": self._extract_system_message(kwargs["messages"]),
                    "user_length": self._user_message_length(kwargs["messages"])
                })

            # Make the call
            response = self.client.chat.completions.create(**kwargs)

            # Record metrics
            usage = response.usage
            cost = self._calculate_cost(model, usage.prompt_tokens, usage.completion_tokens)

            span.set_ai_metrics(
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                cost_usd=cost
            )

            # Add response event (content can be None, e.g. for tool calls)
            span.add_event("response", {
                "finish_reason": response.choices[0].finish_reason,
                "response_length": len(response.choices[0].message.content or "")
            })

            return response

    def _extract_system_message(self, messages: List[Dict]) -> str:
        for m in messages:
            if m.get("role") == "system":
                return m["content"][:100]  # First 100 chars
        return ""

    def _user_message_length(self, messages: List[Dict]) -> int:
        for m in reversed(messages):
            if m.get("role") == "user":
                return len(m.get("content", ""))
        return 0

    def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        # Illustrative (input, output) pricing per million tokens in USD; verify current rates
        pricing = {
            "gpt-4o": (2.50, 10.00),
            "gpt-4o-mini": (0.15, 0.60),
            "o1-preview": (15.00, 60.00)
        }

        input_rate, output_rate = pricing.get(model, (2.50, 10.00))
        return (input_tokens * input_rate + output_tokens * output_rate) / 1_000_000
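
A usage sketch for the traced client (assumes OPENAI_API_KEY is set in the environment; the prompt is arbitrary):

tracer = AITracer("llm-service")
client = TracedOpenAI(tracer)

response = client.chat_completion(
    model="gpt-4o-mini",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize distributed tracing in one sentence."}
    ],
    max_tokens=100
)

# The finished span now carries token counts and estimated cost
trace_id = next(iter(tracer.spans))
print(tracer.get_trace(trace_id)["total_cost_usd"])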

Multi-Service Tracing

A single request often crosses service boundaries: an agent service calls a tool service, which may call others. To stitch these calls into one trace, the caller injects its trace context into outgoing headers and the callee extracts it, continuing the same trace under a new child span.

class DistributedTracer:
    """Handle tracing across multiple services"""

    def __init__(self):
        self.tracers: Dict[str, AITracer] = {}

    def get_tracer(self, service_name: str) -> AITracer:
        """Get or create a tracer for a service"""
        if service_name not in self.tracers:
            self.tracers[service_name] = AITracer(service_name)
        return self.tracers[service_name]

    def inject_context(self, headers: Dict) -> Dict:
        """Inject trace context into headers for propagation"""
        context = None
        for tracer in self.tracers.values():
            context = tracer.get_current_context()
            if context:
                break

        if context:
            headers["X-Trace-ID"] = context.trace_id
            headers["X-Span-ID"] = context.span_id
            if context.baggage:
                headers["X-Baggage"] = ",".join(f"{k}={v}" for k, v in context.baggage.items())

        return headers

    def extract_context(self, headers: Dict) -> Optional[TraceContext]:
        """Extract trace context from headers"""
        trace_id = headers.get("X-Trace-ID")
        span_id = headers.get("X-Span-ID")

        if trace_id and span_id:
            baggage = {}
            if "X-Baggage" in headers:
                for item in headers["X-Baggage"].split(","):
                    k, v = item.split("=", 1)  # values may themselves contain "="
                    baggage[k] = v

            return TraceContext(
                trace_id=trace_id,
                span_id=span_id,
                baggage=baggage
            )

        return None

# Example: Tracing across agent and tool services
distributed_tracer = DistributedTracer()

# Agent service
agent_tracer = distributed_tracer.get_tracer("agent-service")

async def agent_process_request(request: Dict):
    with agent_tracer.start_span("process_request") as span:
        # Propagate trace context to the tool service via headers
        headers = distributed_tracer.inject_context({})
        result = await call_tool_service(request, headers)  # downstream call (stubbed in the demo below)
        return result

# Tool service
tool_tracer = distributed_tracer.get_tracer("tool-service")

async def tool_endpoint(request: Dict, headers: Dict):
    # Extract context from parent
    parent_context = distributed_tracer.extract_context(headers)

    if parent_context:
        tool_tracer.set_current_context(parent_context)

    with tool_tracer.start_span("execute_tool") as span:
        result = await execute_tool(request)
        return result
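
To see the propagation end to end without real HTTP services, you can stub the two placeholder calls and run the flow in-process (the stub bodies are hypothetical, purely for demonstration):

import asyncio

async def execute_tool(request: Dict) -> Dict:
    # Hypothetical stub standing in for real tool logic
    return {"result": f"processed {request.get('query')}"}

async def call_tool_service(request: Dict, headers: Dict) -> Dict:
    # Hypothetical stub standing in for an HTTP call to the tool service
    return await tool_endpoint(request, headers)

asyncio.run(agent_process_request({"query": "weather"}))

# Both services recorded spans under the same trace id
agent_trace_id = next(iter(agent_tracer.spans))
print(agent_trace_id in tool_tracer.spans)  # True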

Trace Analysis

Raw traces become useful when you can query them: which chain of spans dominated latency, and which models drove token spend. The analyzer below answers both.

class TraceAnalyzer:
    """Analyze traces for insights"""

    def __init__(self, tracer: AITracer):
        self.tracer = tracer

    def get_critical_path(self, trace_id: str) -> List[Dict]:
        """Get the critical path through a trace"""
        trace = self.tracer.get_trace(trace_id)
        spans = trace["spans"]

        # Build span tree keyed by parent span id (None for the root)
        children: Dict[Optional[str], List[Dict]] = {}
        for span in spans:
            children.setdefault(span["parent_span_id"], []).append(span)

        # Find critical path (longest duration at each level)
        def find_path(parent_id: Optional[str]) -> List[Dict]:
            if parent_id not in children:
                return []

            child_spans = children[parent_id]
            if not child_spans:
                return []

            # Get span with longest duration
            longest = max(child_spans, key=lambda s: s["duration_ms"])

            return [longest] + find_path(longest["span_id"])

        root = next((s for s in spans if not s["parent_span_id"]), None)
        if root:
            return [root] + find_path(root["span_id"])

        return []

    def get_ai_summary(self, trace_id: str) -> Dict:
        """Get AI-specific summary for a trace"""
        trace = self.tracer.get_trace(trace_id)

        llm_spans = [s for s in trace["spans"] if s.get("ai_metrics")]

        return {
            "llm_calls": len(llm_spans),
            "total_tokens": trace["total_tokens"],
            "total_cost_usd": trace["total_cost_usd"],
            "models_used": list(set(s["ai_metrics"]["model"] for s in llm_spans if s["ai_metrics"]["model"])),
            "by_model": self._aggregate_by_model(llm_spans)
        }

    def _aggregate_by_model(self, spans: List[Dict]) -> Dict:
        by_model = {}
        for span in spans:
            metrics = span.get("ai_metrics", {})
            model = metrics.get("model")
            if model:
                if model not in by_model:
                    by_model[model] = {"calls": 0, "tokens": 0, "cost": 0}
                by_model[model]["calls"] += 1
                by_model[model]["tokens"] += metrics.get("tokens", 0)
                by_model[model]["cost"] += metrics.get("cost", 0)
        return by_model
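
Tying it together (reusing the tracer from the instrumentation example; assumes at least one traced call has been made):

analyzer = TraceAnalyzer(tracer)
trace_id = next(iter(tracer.spans))

# Longest chain of spans from root to leaf
for span in analyzer.get_critical_path(trace_id):
    print(f"{span['operation']}: {span['duration_ms']:.1f}ms")

# Token and cost breakdown per model
summary = analyzer.get_ai_summary(trace_id)
print(summary["llm_calls"], summary["total_cost_usd"])
print(summary["by_model"])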

Effective tracing reveals the inner workings of AI applications, helping you understand performance bottlenecks, cost drivers, and error sources across your entire system.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.