
Databricks Vector Search Deep Dive: Production Patterns

Taking Vector Search to production requires careful attention to performance, reliability, and maintenance. This guide walks through production-ready patterns for chunking, metadata filtering, caching, and monitoring.

Production Architecture

PRODUCTION_PATTERNS = {
    "index_design": {
        "chunking_strategy": "Optimize chunk size for your use case",
        "metadata_filtering": "Include filterable columns",
        "embedding_model": "Match model to content type"
    },
    "scaling": {
        "endpoint_sizing": "Right-size for query volume",
        "index_partitioning": "Partition large indexes",
        "caching": "Cache frequent queries"
    },
    "reliability": {
        "monitoring": "Track latency, throughput, errors",
        "fallbacks": "Handle index unavailability",
        "sync_management": "Monitor and manage syncs"
    }
}
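
These design decisions come together when the index is created. As a sketch (the endpoint, table, index, and embedding model names below are placeholders, not fixed values), a Delta Sync index with managed embeddings looks roughly like this:

from databricks.vector_search.client import VectorSearchClient

vsc = VectorSearchClient()

# Placeholder names -- substitute your own endpoint, catalog, and model
index = vsc.create_delta_sync_index(
    endpoint_name="prod-vector-search",
    index_name="catalog.schema.docs_index",
    source_table_name="catalog.schema.doc_chunks",
    pipeline_type="TRIGGERED",  # or "CONTINUOUS" for near-real-time sync
    primary_key="chunk_id",
    embedding_source_column="content",
    embedding_model_endpoint_name="databricks-gte-large-en"
)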

Document Chunking Strategies

from typing import List, Dict
import tiktoken

class DocumentChunker:
    """Production-grade document chunking"""

    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        encoding_name: str = "cl100k_base"  # a tiktoken encoding, not a model name
    ):
        self.chunk_size = chunk_size
        self.overlap = chunk_overlap
        self.tokenizer = tiktoken.get_encoding(encoding_name)

    def chunk_by_tokens(self, text: str) -> List[str]:
        """Chunk text by token count"""
        tokens = self.tokenizer.encode(text)
        chunks = []

        start = 0
        while start < len(tokens):
            end = min(start + self.chunk_size, len(tokens))
            chunk_text = self.tokenizer.decode(tokens[start:end])
            chunks.append(chunk_text)
            if end == len(tokens):
                break  # otherwise the overlap would re-emit the tail as a duplicate chunk
            start = end - self.overlap

        return chunks

    def chunk_by_paragraphs(
        self,
        text: str,
        max_tokens: int = 500
    ) -> List[str]:
        """Chunk by paragraphs, respecting max tokens"""
        paragraphs = text.split('\n\n')
        chunks = []
        current_chunk = []
        current_tokens = 0

        for para in paragraphs:
            para_tokens = len(self.tokenizer.encode(para))

            if para_tokens > max_tokens:
                # A single paragraph longer than max_tokens is split by
                # tokens rather than emitted as one oversized chunk
                if current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                    current_chunk = []
                    current_tokens = 0
                chunks.extend(self.chunk_by_tokens(para))
            elif current_tokens + para_tokens <= max_tokens:
                current_chunk.append(para)
                current_tokens += para_tokens
            else:
                if current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                current_chunk = [para]
                current_tokens = para_tokens

        if current_chunk:
            chunks.append('\n\n'.join(current_chunk))

        return chunks

    def chunk_with_metadata(
        self,
        doc_id: str,
        title: str,
        text: str,
        source: str
    ) -> List[Dict]:
        """Chunk and add metadata for each chunk"""
        chunks = self.chunk_by_paragraphs(text)

        return [
            {
                "chunk_id": f"{doc_id}_chunk_{i}",
                "doc_id": doc_id,
                "title": title,
                "content": chunk,
                "chunk_index": i,
                "total_chunks": len(chunks),
                "source": source
            }
            for i, chunk in enumerate(chunks)
        ]

# Usage with PySpark
def process_documents_for_vector_search(df):
    """Process documents for vector search indexing"""
    from pyspark.sql.functions import udf, explode
    from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType

    # The chunker (including its tokenizer) is captured in the UDF closure and
    # serialized to executors; if serialization fails, construct it lazily inside the UDF
    chunker = DocumentChunker(chunk_size=500, chunk_overlap=50)

    chunk_schema = ArrayType(StructType([
        StructField("chunk_id", StringType()),
        StructField("doc_id", StringType()),
        StructField("title", StringType()),
        StructField("content", StringType()),
        StructField("chunk_index", IntegerType()),
        StructField("total_chunks", IntegerType()),
        StructField("source", StringType())
    ]))

    @udf(chunk_schema)
    def chunk_document(doc_id, title, content, source):
        return chunker.chunk_with_metadata(doc_id, title, content, source)

    return df.withColumn(
        "chunks",
        chunk_document("doc_id", "title", "content", "source")
    ).select(explode("chunks").alias("chunk")).select("chunk.*")
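
The chunked DataFrame then becomes the source table for a Delta Sync index. A minimal sketch (raw_docs_df and the table name are placeholders; Change Data Feed must be enabled on the source table for sync to work):

chunks_df = process_documents_for_vector_search(raw_docs_df)

chunks_df.write.format("delta").mode("overwrite").saveAsTable(
    "catalog.schema.doc_chunks"
)

# Delta Sync indexes require Change Data Feed on the source table
spark.sql("""
    ALTER TABLE catalog.schema.doc_chunks
    SET TBLPROPERTIES (delta.enableChangeDataFeed = true)
""")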

Metadata Filtering

from typing import Dict, List

from databricks.vector_search.client import VectorSearchClient

class FilteredVectorSearch:
    """Vector search with metadata filtering"""

    def __init__(self, endpoint_name: str, index_name: str):
        self.vsc = VectorSearchClient()
        self.index = self.vsc.get_index(endpoint_name, index_name)

    def search_with_filters(
        self,
        query: str,
        filters: Dict = None,
        num_results: int = 10
    ) -> List[Dict]:
        """Search with metadata filters"""

        # Build filter expression
        filter_expr = self._build_filter_expression(filters) if filters else None

        results = self.index.similarity_search(
            query_text=query,
            columns=["chunk_id", "doc_id", "title", "content", "source"],
            num_results=num_results,
            filters=filter_expr
        )

        return self._parse_results(results)

    def _build_filter_expression(self, filters: Dict) -> Dict:
        """Build filter expression for Vector Search"""
        # Simple equality filters pass through as-is.
        # For complex conditions, use the filter key syntax:
        #   {"column_name NOT": "value"}  # not equal
        #   {"column_name <": value}      # less than
        #   {"column_name >": value}      # greater than
        return filters

    def _parse_results(self, results: Dict) -> List[Dict]:
        """Parse raw results into structured format"""
        # Each row holds the requested columns in order, with the
        # similarity score appended as the last element
        parsed = []
        for row in results['result']['data_array']:
            parsed.append({
                "chunk_id": row[0],
                "doc_id": row[1],
                "title": row[2],
                "content": row[3],
                "source": row[4],
                "score": row[5]
            })
        return parsed

    def search_by_source(
        self,
        query: str,
        sources: List[str],
        num_results: int = 10
    ) -> List[Dict]:
        """Search within specific sources"""
        # Querying each source separately guarantees every source is
        # represented; a single list-valued filter would instead return
        # one globally ranked result set
        all_results = []

        per_source = max(1, num_results // len(sources))
        for source in sources:
            results = self.search_with_filters(
                query=query,
                filters={"source": source},
                num_results=per_source
            )
            all_results.extend(results)

        # Re-sort by score
        all_results.sort(key=lambda x: x['score'], reverse=True)
        return all_results[:num_results]
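
Putting the class to work is straightforward. A usage sketch with placeholder endpoint and index names:

searcher = FilteredVectorSearch(
    endpoint_name="prod-vector-search",     # placeholder
    index_name="catalog.schema.docs_index"  # placeholder
)

results = searcher.search_with_filters(
    query="How do I configure cluster autoscaling?",
    filters={"source": "docs"},
    num_results=5
)

for r in results:
    print(f"{r['score']:.3f}  {r['title']}")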

Caching and Performance

import hashlib
import time

class CachedVectorSearch:
    """Vector search with query caching"""

    def __init__(self, index, cache_ttl: int = 300):
        self.index = index
        self.cache = {}
        self.cache_ttl = cache_ttl

    def search(
        self,
        query: str,
        filters: Dict = None,
        num_results: int = 10,
        use_cache: bool = True
    ) -> List[Dict]:
        """Search with optional caching"""

        cache_key = self._make_cache_key(query, filters, num_results)

        # Check cache
        if use_cache and cache_key in self.cache:
            cached_result, timestamp = self.cache[cache_key]
            if time.time() - timestamp < self.cache_ttl:
                return cached_result

        # Execute search
        results = self.index.similarity_search(
            query_text=query,
            columns=["chunk_id", "title", "content"],
            num_results=num_results,
            filters=filters
        )

        parsed = self._parse_results(results)

        # Update cache
        self.cache[cache_key] = (parsed, time.time())

        # Clean old cache entries
        self._cleanup_cache()

        return parsed

    def _make_cache_key(
        self,
        query: str,
        filters: Dict,
        num_results: int
    ) -> str:
        """Create cache key"""
        # Sort filter items so dicts with the same content hash identically
        filter_part = str(sorted(filters.items())) if filters else ""
        key_data = f"{query}|{filter_part}|{num_results}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _cleanup_cache(self):
        """Remove expired cache entries"""
        current_time = time.time()
        expired = [
            k for k, (_, ts) in self.cache.items()
            if current_time - ts > self.cache_ttl
        ]
        for k in expired:
            del self.cache[k]

    def _parse_results(self, results: Dict) -> List[Dict]:
        # Requested columns come first; the similarity score is appended last
        return [
            {"chunk_id": row[0], "title": row[1], "content": row[2], "score": row[3]}
            for row in results['result']['data_array']
        ]

    def invalidate_cache(self):
        """Clear all cached results"""
        self.cache.clear()
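
Wiring the cache in front of a real index is a one-liner. A sketch with placeholder names:

from databricks.vector_search.client import VectorSearchClient

vsc = VectorSearchClient()
index = vsc.get_index("prod-vector-search", "catalog.schema.docs_index")  # placeholders

cached = CachedVectorSearch(index, cache_ttl=300)

# The first call hits the index; an identical query within the TTL is served from memory
results = cached.search("cluster autoscaling", num_results=5)
results_again = cached.search("cluster autoscaling", num_results=5)

# After a sync or reindex, drop stale entries
cached.invalidate_cache()

Note that this cache is in-process and not thread-safe; behind a multi-replica serving endpoint, an external store such as Redis is the usual substitute.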

Monitoring and Alerting

import time
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional

@dataclass
class SearchMetrics:
    query: str
    latency_ms: float
    num_results: int
    timestamp: datetime
    success: bool
    error: Optional[str] = None

class MonitoredVectorSearch:
    """Vector search with monitoring"""

    def __init__(self, index):
        self.index = index
        self.metrics: List[SearchMetrics] = []
        self.alert_threshold_ms = 1000  # Alert if latency > 1s

    def search(self, query: str, **kwargs) -> Dict:
        """Search with monitoring"""

        start = time.time()

        try:
            results = self.index.similarity_search(
                query_text=query,
                **kwargs
            )
            latency = (time.time() - start) * 1000

            metric = SearchMetrics(
                query=query[:100],
                latency_ms=latency,
                num_results=len(results.get('result', {}).get('data_array', [])),
                timestamp=datetime.now(),
                success=True
            )

            if latency > self.alert_threshold_ms:
                self._alert_slow_query(metric)

        except Exception as e:
            latency = (time.time() - start) * 1000
            metric = SearchMetrics(
                query=query[:100],
                latency_ms=latency,
                num_results=0,
                timestamp=datetime.now(),
                success=False,
                error=str(e)
            )
            self._alert_error(metric)
            raise

        finally:
            self.metrics.append(metric)

        return results

    def get_metrics_summary(self, last_n: int = 100) -> Dict:
        """Get metrics summary"""
        recent = self.metrics[-last_n:]

        if not recent:
            return {}

        latencies = [m.latency_ms for m in recent if m.success]
        success_count = sum(1 for m in recent if m.success)

        return {
            "total_queries": len(recent),
            "success_rate": success_count / len(recent),
            "avg_latency_ms": sum(latencies) / len(latencies) if latencies else 0,
            "p95_latency_ms": sorted(latencies)[int(len(latencies) * 0.95)] if latencies else 0,
            "errors": [m.error for m in recent if not m.success]
        }

    def _alert_slow_query(self, metric: SearchMetrics):
        """Alert on slow query"""
        print(f"ALERT: Slow query ({metric.latency_ms:.0f}ms): {metric.query}")

    def _alert_error(self, metric: SearchMetrics):
        """Alert on error"""
        print(f"ERROR: Query failed: {metric.error}")

Conclusion

Production vector search requires proper chunking strategies, filtering capabilities, caching, and monitoring. Build these patterns into your application from the start for reliable semantic search at scale.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.