Back to Blog
10 min read

Auto-Merge Retrieval for RAG Systems

Introduction

Auto-merge retrieval combines retrieved chunks when several of them belong to the same parent section, or when adjacent chunks are both relevant. The generator then receives more complete context with less redundancy, which tends to improve answer quality.

Auto-Merge Architecture

import re
import uuid
import zlib
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

@dataclass
class HierarchicalNode:
    """One node in the document -> section -> leaf hierarchy.

    Attributes:
        id: Unique node id.
        content: Text covered by this node.
        level: Depth in the hierarchy (see the comment on the field).
        parent_id: Id of the parent node, or None for a root document node.
        children_ids: Ids of direct children, in the order they were added.
        embedding: Vector for ``content``, if one has been computed.
        metadata: Free-form per-node metadata.
    """

    id: str
    content: str
    level: int  # 0 = document, 1 = section, 2 = leaf chunk
    parent_id: Optional[str] = None
    # default_factory gives every node its own list/dict instead of a shared one.
    children_ids: List[str] = field(default_factory=list)
    embedding: Optional[List[float]] = None
    metadata: Dict = field(default_factory=dict)

class AutoMergeIndexer:
    """Index documents into a document -> section -> leaf hierarchy.

    Each indexed document becomes a level-0 node. Its sections become level-1
    nodes, and the sentence-grouped chunks of each section become level-2
    (leaf) nodes. Nodes from every indexed document share ``self.nodes``.
    """

    def __init__(
        self,
        leaf_chunk_size: int = 256,
        merge_threshold: float = 0.6,
        min_section_size: int = 500,
    ):
        """
        Args:
            leaf_chunk_size: Maximum length in characters of a leaf chunk.
                A single sentence longer than this becomes its own,
                oversized chunk.
            merge_threshold: Stored for callers; the indexer does not read it.
            min_section_size: Consecutive sections are concatenated while
                their combined length stays below this many characters.
        """
        self.leaf_chunk_size = leaf_chunk_size
        self.merge_threshold = merge_threshold
        self.min_section_size = min_section_size
        self.nodes: Dict[str, HierarchicalNode] = {}
        # Parent id -> child ids; mirrors each parent's ``children_ids``.
        self.parent_to_children: Dict[str, List[str]] = {}

    def _new_id(self, prefix: str) -> str:
        """Return an id ``{prefix}_{8 hex chars}`` not yet used in ``self.nodes``."""
        # Eight hex chars is only 32 bits, so regenerate on a collision
        # instead of silently overwriting an existing node.
        while True:
            node_id = f"{prefix}_{uuid.uuid4().hex[:8]}"
            if node_id not in self.nodes:
                return node_id

    def index_document(
        self,
        document: str,
        metadata: Optional[Dict] = None
    ) -> Dict:
        """Split ``document`` into sections and leaves and register all nodes.

        Args:
            document: Raw document text.
            metadata: Optional metadata stored on the document node.

        Returns:
            Dict with ``document_id``, ``sections`` (section count) and
            ``total_leaves`` (leaf count across all sections).
        """
        # Level 0: the whole document.
        doc_id = self._new_id("doc")
        doc_node = HierarchicalNode(
            id=doc_id,
            content=document,
            level=0,
            metadata=metadata or {}
        )
        self.nodes[doc_id] = doc_node

        total_leaves = 0
        # Level 1: sections (split by headers or paragraph breaks).
        for i, section in enumerate(self._split_into_sections(document)):
            section_id = self._new_id("sec")
            section_node = HierarchicalNode(
                id=section_id,
                content=section,
                level=1,
                parent_id=doc_id,
                metadata={"section_index": i}
            )
            self.nodes[section_id] = section_node
            doc_node.children_ids.append(section_id)

            # Level 2: leaf chunks.
            for j, leaf in enumerate(self._split_into_leaves(section)):
                leaf_id = self._new_id("leaf")
                self.nodes[leaf_id] = HierarchicalNode(
                    id=leaf_id,
                    content=leaf,
                    level=2,
                    parent_id=section_id,
                    metadata={"leaf_index": j}
                )
                section_node.children_ids.append(leaf_id)

            total_leaves += len(section_node.children_ids)
            # Register only this document's nodes. Rebuilding from every node
            # in ``self.nodes`` duplicated children on the second document.
            self.parent_to_children[section_id] = list(section_node.children_ids)

        self.parent_to_children[doc_id] = list(doc_node.children_ids)

        return {
            "document_id": doc_id,
            "sections": len(doc_node.children_ids),
            "total_leaves": total_leaves,
        }

    def _split_into_sections(self, text: str) -> List[str]:
        """Split ``text`` into sections, then merge small neighbours.

        Markdown headers (``#`` to ``###`` at line start) are preferred as
        boundaries; otherwise paragraph breaks (blank lines) are used.
        Returns an empty list when ``text`` is blank.
        """
        if re.search(r'^#{1,3}\s', text, re.MULTILINE):
            pieces = re.split(r'(?=^#{1,3}\s)', text, flags=re.MULTILINE)
        else:
            pieces = text.split('\n\n')

        pieces = [p.strip() for p in pieces if p.strip()]

        # Concatenate consecutive pieces while their combined length stays
        # below min_section_size, so tiny sections don't become separate nodes.
        merged: List[str] = []
        current = ""
        for piece in pieces:
            if len(current) + len(piece) < self.min_section_size:
                current = f"{current}\n\n{piece}" if current else piece
            else:
                if current:
                    merged.append(current)
                current = piece

        if current:
            merged.append(current)

        return merged

    def _split_into_leaves(self, text: str) -> List[str]:
        """Group the sentences of ``text`` into chunks of at most ``leaf_chunk_size`` chars.

        Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace.
        A sentence longer than the limit is kept whole as its own chunk.
        """
        sentences = re.split(r'(?<=[.!?])\s+', text)

        chunks: List[str] = []
        current = ""
        for sentence in sentences:
            if not sentence:
                continue
            # Count the joining space so chunks never exceed the limit.
            joined_len = len(current) + 1 + len(sentence) if current else len(sentence)
            if joined_len <= self.leaf_chunk_size:
                current = f"{current} {sentence}" if current else sentence
            else:
                if current:
                    chunks.append(current.strip())
                current = sentence

        if current:
            chunks.append(current.strip())

        return chunks

    def get_leaves(self) -> List[HierarchicalNode]:
        """Return all level-2 (leaf) nodes across every indexed document."""
        return [n for n in self.nodes.values() if n.level == 2]

    def get_parent(self, node_id: str) -> Optional[HierarchicalNode]:
        """Return the parent of ``node_id``, or None if unknown or a root."""
        node = self.nodes.get(node_id)
        if node and node.parent_id:
            return self.nodes.get(node.parent_id)
        return None

    def get_siblings(self, node_id: str) -> List[HierarchicalNode]:
        """Return every child of ``node_id``'s parent, including the node itself."""
        node = self.nodes.get(node_id)
        if not node or not node.parent_id:
            return []

        sibling_ids = self.parent_to_children.get(node.parent_id, [])
        return [self.nodes[sid] for sid in sibling_ids if sid in self.nodes]

Auto-Merge Retriever

class AutoMergeRetriever:
    """Leaf-level retriever that merges related hits before returning them.

    Leaves are scored against the query by cosine similarity. The best
    candidates are grouped by parent section. A group is either replaced by
    the whole section (when enough of the section's leaves matched) or
    collapsed into runs of adjacent leaves.
    """

    def __init__(
        self,
        indexer: AutoMergeIndexer,
        embedding_model=None,
        merge_threshold: float = 0.5,
        min_merge_score: float = 0.3,
    ):
        """
        Args:
            indexer: Provides the node hierarchy and leaf nodes.
            embedding_model: Optional object with an ``encode(text)`` method.
                When None, a hashed bag-of-words embedding is used.
            merge_threshold: Minimum fraction of a section's leaves that must
                be among the candidates before the whole section is returned.
            min_merge_score: The candidates' average score must exceed this
                for a section-level merge.
        """
        self.indexer = indexer
        self.embedding_model = embedding_model
        self.merge_threshold = merge_threshold
        self.min_merge_score = min_merge_score
        self.embeddings: Dict[str, List[float]] = {}

    def _embed(self, text: str) -> List[float]:
        """Embed ``text`` with the configured model, or the fallback embedding."""
        if self.embedding_model is not None:
            return self.embedding_model.encode(text)
        return self._simple_embedding(text)

    def build_index(self):
        """Compute and store embeddings for every leaf in the indexer."""
        for leaf in self.indexer.get_leaves():
            embedding = self._embed(leaf.content)
            self.embeddings[leaf.id] = embedding
            leaf.embedding = embedding

    def _simple_embedding(self, text: str) -> List[float]:
        """Return a 128-dim L2-normalised hashed bag-of-words vector (demo only)."""
        dims = 128
        vector = [0.0] * dims
        for word in text.lower().split():
            # CRC32 is stable across runs. The built-in hash() of str is salted
            # per process (PYTHONHASHSEED), which made scores irreproducible.
            vector[zlib.crc32(word.encode("utf-8", "surrogatepass")) % dims] += 1.0

        norm = sum(x * x for x in vector) ** 0.5
        if norm > 0:
            vector = [x / norm for x in vector]

        return vector

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` results for ``query`` after auto-merging.

        Each result is a dict with ``type``, ``content``, ``score``,
        ``merged_count``, ``level`` and either ``node_id`` or ``node_ids``.
        """
        if top_k <= 0 or not self.embeddings:
            return []

        query_emb = self._embed(query)
        scored = sorted(
            (
                (leaf_id, self._cosine_similarity(query_emb, embedding))
                for leaf_id, embedding in self.embeddings.items()
            ),
            key=lambda item: item[1],
            reverse=True,
        )

        # Take more candidates than needed so merging has material to work with.
        candidates = scored[:top_k * 3]
        return self._auto_merge(candidates, top_k)

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        denom = norm_a * norm_b
        return dot / denom if denom > 0 else 0.0

    def _auto_merge(
        self,
        candidates: List[tuple],
        top_k: int
    ) -> List[Dict]:
        """Group ``(leaf_id, score)`` candidates by parent and merge each group."""
        by_parent: Dict[str, List[tuple]] = {}
        for leaf_id, score in candidates:
            parent_id = self.indexer.nodes[leaf_id].parent_id
            by_parent.setdefault(parent_id, []).append((leaf_id, score))

        results: List[Dict] = []
        for parent_id, group in by_parent.items():
            if self._should_merge_to_parent(group):
                parent = self.indexer.nodes[parent_id]
                results.append({
                    "type": "merged",
                    "content": parent.content,
                    "score": sum(s for _, s in group) / len(group),
                    "merged_count": len(group),
                    "node_id": parent_id,
                    "level": "section"
                })
            else:
                results.extend(self._merge_adjacent_leaves(group))

        results.sort(key=lambda r: r["score"], reverse=True)
        return results[:top_k]

    def _should_merge_to_parent(
        self,
        candidates: List[tuple]
    ) -> bool:
        """Decide whether same-parent candidates should be replaced by the parent.

        Requires at least two candidates, coverage of the parent's children
        of at least ``merge_threshold``, and an average score above
        ``min_merge_score``.
        """
        if len(candidates) < 2:
            return False

        parent = self.indexer.get_parent(candidates[0][0])
        if parent is None or not parent.children_ids:
            return False

        coverage = len(candidates) / len(parent.children_ids)
        avg_score = sum(s for _, s in candidates) / len(candidates)

        return coverage >= self.merge_threshold and avg_score > self.min_merge_score

    def _merge_adjacent_leaves(
        self,
        candidates: List[tuple]
    ) -> List[Dict]:
        """Collapse same-parent candidates with consecutive ``leaf_index`` into one result.

        Single leaves are returned as ``"leaf"`` results. Runs of adjacent
        leaves become ``"merged_adjacent"`` results whose score is the mean
        of the run.
        """
        if not candidates:
            return []

        # (leaf_id, score, position) ordered by position within the section.
        positioned = sorted(
            (
                (leaf_id, score, self.indexer.nodes[leaf_id].metadata.get("leaf_index", 0))
                for leaf_id, score in candidates
            ),
            key=lambda item: item[2],
        )

        groups: List[List[tuple]] = [[positioned[0]]]
        for prev, item in zip(positioned, positioned[1:]):
            if item[2] - prev[2] <= 1:  # adjacent
                groups[-1].append(item)
            else:
                groups.append([item])

        results: List[Dict] = []
        for group in groups:
            if len(group) == 1:
                leaf_id, score, _ = group[0]
                results.append({
                    "type": "leaf",
                    "content": self.indexer.nodes[leaf_id].content,
                    "score": score,
                    "merged_count": 1,
                    "node_id": leaf_id,
                    "level": "leaf"
                })
            else:
                results.append({
                    "type": "merged_adjacent",
                    "content": " ".join(self.indexer.nodes[g[0]].content for g in group),
                    "score": sum(g[1] for g in group) / len(group),
                    "merged_count": len(group),
                    "node_ids": [g[0] for g in group],
                    "level": "merged_leaves"
                })

        return results

Intelligent Merge Strategies

class IntelligentMergeRetriever:
    """Retriever offering several named merge strategies.

    Strategies for ``retrieve_with_strategy``:
        - ``"auto"``: delegate to ``AutoMergeRetriever.retrieve``.
        - ``"always_merge"``: return whole parent sections.
        - ``"never_merge"``: return individual leaf chunks only.
        - ``"semantic_merge"``: cluster candidate leaves by embedding similarity.
    Any other name falls back to ``"auto"``.
    """

    def __init__(
        self,
        indexer: AutoMergeIndexer,
        embedding_model=None
    ):
        self.indexer = indexer
        self.embedding_model = embedding_model
        self.base_retriever = AutoMergeRetriever(
            indexer, embedding_model, merge_threshold=0.5
        )
        # Embeds every leaf currently known to the indexer.
        self.base_retriever.build_index()

    def retrieve_with_strategy(
        self,
        query: str,
        strategy: str = "auto",
        top_k: int = 3
    ) -> List[Dict]:
        """Retrieve up to ``top_k`` results for ``query`` using ``strategy``."""
        if strategy == "always_merge":
            return self._always_merge_strategy(query, top_k)
        if strategy == "never_merge":
            return self._never_merge_strategy(query, top_k)
        if strategy == "semantic_merge":
            return self._semantic_merge_strategy(query, top_k)
        return self.base_retriever.retrieve(query, top_k)

    def _embed_query(self, query: str) -> List[float]:
        """Embed ``query`` the same way the leaves were embedded."""
        # Leaves are embedded with embedding_model when one is set. Always
        # using the hash embedding for the query compared vectors from
        # different spaces.
        model = self.base_retriever.embedding_model
        if model is not None:
            return model.encode(query)
        return self.base_retriever._simple_embedding(query)

    def _score_leaves(self, query: str) -> List[tuple]:
        """Return ``(leaf_id, score)`` for every indexed leaf, best first."""
        query_emb = self._embed_query(query)
        similarity = self.base_retriever._cosine_similarity
        scored = [
            (leaf_id, similarity(query_emb, embedding))
            for leaf_id, embedding in self.base_retriever.embeddings.items()
        ]
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored

    def _leaf_result(self, leaf_id: str, score: float) -> Dict:
        """Build the result dict for a single, unmerged leaf."""
        return {
            "type": "leaf",
            "content": self.indexer.nodes[leaf_id].content,
            "score": score,
            "merged_count": 1,
            "node_id": leaf_id,
            "level": "leaf"
        }

    def _always_merge_strategy(
        self,
        query: str,
        top_k: int
    ) -> List[Dict]:
        """Map every auto-merge result to its parent section (deduplicated).

        Each section keeps the score of the first (best) result that led to it.
        """
        base_results = self.base_retriever.retrieve(query, top_k * 2)

        seen_parents = set()
        results: List[Dict] = []

        for result in base_results:
            if result["type"] == "merged":
                if result["node_id"] not in seen_parents:
                    results.append(result)
                    seen_parents.add(result["node_id"])
                continue

            # Leaf results carry node_id; merged_adjacent results carry node_ids.
            node_id = result.get("node_id") or result.get("node_ids", [None])[0]
            if not node_id:
                continue
            leaf = self.indexer.nodes.get(node_id)
            if not leaf or leaf.parent_id in seen_parents:
                continue
            parent = self.indexer.nodes.get(leaf.parent_id)
            if parent:
                results.append({
                    "type": "merged",
                    "content": parent.content,
                    "score": result["score"],
                    "merged_count": len(parent.children_ids),
                    "node_id": parent.id,
                    "level": "section"
                })
                seen_parents.add(parent.id)

        return results[:top_k]

    def _never_merge_strategy(
        self,
        query: str,
        top_k: int
    ) -> List[Dict]:
        """Return the ``top_k`` best individual leaf chunks."""
        return [
            self._leaf_result(leaf_id, score)
            for leaf_id, score in self._score_leaves(query)[:top_k]
        ]

    def _semantic_merge_strategy(
        self,
        query: str,
        top_k: int
    ) -> List[Dict]:
        """Cluster the top candidates by embedding similarity and merge each cluster."""
        candidates = self._score_leaves(query)[:top_k * 3]
        clusters = self._semantic_cluster(candidates)

        results: List[Dict] = []
        for cluster in clusters[:top_k]:
            if len(cluster) == 1:
                leaf_id, score = cluster[0]
                results.append(self._leaf_result(leaf_id, score))
                continue

            results.append({
                "type": "semantic_merged",
                "content": " ".join(self.indexer.nodes[lid].content for lid, _ in cluster),
                "score": sum(s for _, s in cluster) / len(cluster),
                "merged_count": len(cluster),
                # Node ids match the "merged_adjacent" results of the base retriever.
                "node_ids": [lid for lid, _ in cluster],
                "level": "semantic_cluster"
            })

        return results

    def _semantic_cluster(
        self,
        candidates: List[tuple],
        similarity_threshold: float = 0.7
    ) -> List[List[tuple]]:
        """Greedily cluster ``(leaf_id, score)`` candidates by embedding similarity.

        Each still-unclustered candidate, taken in input order, seeds a
        cluster. Every remaining candidate whose cosine similarity to the
        seed exceeds ``similarity_threshold`` joins that cluster. Clusters are
        returned sorted by their best score, highest first.
        """
        embeddings = self.base_retriever.embeddings
        similarity = self.base_retriever._cosine_similarity

        clusters: List[List[tuple]] = []
        used = set()

        for leaf_id, score in candidates:
            if leaf_id in used:
                continue

            cluster = [(leaf_id, score)]
            used.add(leaf_id)
            seed_emb = embeddings[leaf_id]

            for other_id, other_score in candidates:
                if other_id in used:
                    continue
                if similarity(seed_emb, embeddings[other_id]) > similarity_threshold:
                    cluster.append((other_id, other_score))
                    used.add(other_id)

            clusters.append(cluster)

        clusters.sort(key=lambda c: max(s for _, s in c), reverse=True)
        return clusters

Complete Auto-Merge RAG

class AutoMergeRAG:
    """End-to-end RAG pipeline: hierarchical indexing, auto-merge retrieval, generation."""

    def __init__(self, generator, embedding_model=None, leaf_chunk_size: int = 200):
        """
        Args:
            generator: Object with a ``generate(prompt)`` method.
            embedding_model: Optional object with ``encode(text)``; None uses
                the retriever's built-in demo embedding.
            leaf_chunk_size: Maximum leaf chunk length passed to the indexer.
        """
        self.indexer = AutoMergeIndexer(leaf_chunk_size=leaf_chunk_size)
        self.embedding_model = embedding_model
        self.retriever = None
        self.intelligent_retriever = None
        self.generator = generator

    def add_document(self, document: str, metadata: Optional[Dict] = None):
        """Index ``document`` and rebuild the retrievers; return the indexer's summary."""
        result = self.indexer.index_document(document, metadata)

        # IntelligentMergeRetriever builds and indexes its own AutoMergeRetriever
        # with the same settings (merge_threshold=0.5). Reuse it for "auto"
        # queries instead of embedding every leaf a second time.
        # NOTE: every leaf is still re-embedded on each call.
        self.intelligent_retriever = IntelligentMergeRetriever(
            self.indexer, self.embedding_model
        )
        self.retriever = self.intelligent_retriever.base_retriever

        return result

    def query(
        self,
        question: str,
        strategy: str = "auto",
        top_k: int = 3
    ) -> Dict:
        """Retrieve context for ``question`` with ``strategy`` and generate an answer.

        Returns a dict with ``answer``, ``strategy_used`` and ``sources``.
        If nothing has been indexed, returns ``{"error": ...}`` instead.
        """
        if self.retriever is None:
            return {"error": "No documents indexed"}

        if strategy == "auto":
            results = self.retriever.retrieve(question, top_k)
        else:
            results = self.intelligent_retriever.retrieve_with_strategy(
                question, strategy, top_k
            )

        context = "\n\n---\n\n".join(
            f"Source {i+1} [{r['type']}, {r['merged_count']} chunk(s)]:\n{r['content']}"
            for i, r in enumerate(results)
        )

        prompt = f"""Answer based on the following context.
The context may include merged sections for comprehensive coverage.

{context}

Question: {question}

Answer:"""

        response = self.generator.generate(prompt)

        return {
            "answer": response,
            "strategy_used": strategy,
            "sources": [
                {
                    "type": r["type"],
                    "merged_count": r["merged_count"],
                    "score": r["score"],
                    # Only add an ellipsis when the content was actually cut.
                    "preview": (
                        r["content"][:100] + "..."
                        if len(r["content"]) > 100
                        else r["content"]
                    ),
                }
                for r in results
            ]
        }

# Usage
class MockGenerator:
    """Stand-in generator that always returns a fixed string."""

    def generate(self, prompt):
        return "Generated response."


SAMPLE_DOCUMENT = """
# Introduction to Machine Learning

Machine learning is a branch of artificial intelligence.
It allows systems to learn from data.
The field has grown significantly in recent years.

# Types of Machine Learning

Supervised learning uses labeled data.
Unsupervised learning finds patterns without labels.
Reinforcement learning learns through interaction.

# Deep Learning

Deep learning uses neural networks.
These networks have multiple layers.
They excel at complex pattern recognition.
"""


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    rag = AutoMergeRAG(MockGenerator())
    rag.add_document(SAMPLE_DOCUMENT)

    result = rag.query("What is machine learning?", strategy="auto")
    print(f"Answer: {result['answer']}")
    print(f"Sources: {len(result['sources'])}")
    for source in result['sources']:
        print(f"  - Type: {source['type']}, Merged: {source['merged_count']}")

Conclusion

Auto-merge retrieval combines related chunks so the generator receives fuller context with less redundancy. It uses the document hierarchy and the position of each chunk within its section. When most of a section's leaves match, it returns the whole section; otherwise it stitches adjacent matching leaves into a single passage. The alternative merge strategies let you trade precision (individual leaf chunks) against coverage (whole sections) to suit the use case.

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.