Back to Blog
10 min read

Recursive Retrieval for Complex RAG Queries

Introduction

Recursive retrieval handles complex queries that require multiple retrieval steps or following references between documents. This technique enables RAG systems to answer questions that span multiple documents or require reasoning chains.

Recursive Retrieval Architecture

import re
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set

class RetrievalAction(Enum):
    """Kinds of actions a retrieval step can represent."""
    RETRIEVE = "retrieve"
    FOLLOW_REFERENCE = "follow_reference"
    AGGREGATE = "aggregate"
    STOP = "stop"

@dataclass
class RetrievalStep:
    """Record of a single step in a recursive retrieval run."""
    step_id: str                 # unique identifier, "step_<depth>_<hex>"
    action: RetrievalAction      # what the step did (RETRIEVE in this module)
    query: str                   # query text executed for this step
    results: List[Dict] = field(default_factory=list)          # newly retrieved documents
    references_found: List[str] = field(default_factory=list)  # references extracted from results
    reasoning: str = ""          # human-readable summary of the step

@dataclass
class RecursiveResult:
    """Aggregated outcome of a recursive retrieval run."""
    final_context: str                   # concatenated content of collected documents
    retrieval_path: List[RetrievalStep]  # all executed steps, in order
    total_documents: int                 # number of unique documents collected
    depth_reached: int                   # recursion levels actually executed

class RecursiveRetriever:
    """Retriever that recursively follows references between documents.

    Starting from a seed query, each depth level retrieves documents for
    the pending queries, extracts reference strings from the retrieved
    content, and uses those references as the queries for the next level.
    Recursion stops at ``max_depth``, when no new references are found,
    or once ``max_documents`` have been collected.
    """

    # Phrases that signal an explicit cross-document reference. Compiled
    # once at class level instead of being re-imported and rebuilt inside
    # the per-document loop as the original did.
    _REF_PATTERNS = [
        re.compile(p, re.IGNORECASE)
        for p in (
            r"see also[:\s]+([^.]+)",
            r"refer(?:s|ring)? to[:\s]+([^.]+)",
            r"as described in[:\s]+([^.]+)",
            r"for more (?:information|details)[,\s]+(?:see|read)[:\s]+([^.]+)",
        )
    ]
    # Markdown-style links: the bracketed text is treated as a reference.
    _LINK_PATTERN = re.compile(r'\[([^\]]+)\]\([^)]+\)')

    def __init__(
        self,
        base_retriever,
        llm_client=None,
        max_depth: int = 3,
        max_documents: int = 10
    ):
        """
        Args:
            base_retriever: object exposing ``retrieve(query, top_k)`` that
                returns a list of dicts with at least a ``content`` key and
                optionally ``id`` and ``source`` keys.
            llm_client: unused here; kept for interface symmetry with the
                other retrievers in this module.
            max_depth: maximum number of recursion levels.
            max_documents: cap on the total number of collected documents.
        """
        self.retriever = base_retriever
        self.llm = llm_client
        self.max_depth = max_depth
        self.max_documents = max_documents

    def retrieve(self, query: str) -> RecursiveResult:
        """Perform recursive retrieval and return the aggregated result."""
        retrieval_path: List[RetrievalStep] = []
        collected_docs: List[Dict] = []
        seen_doc_ids: Set[str] = set()
        current_queries = [query]
        depth = 0

        while current_queries and depth < self.max_depth:
            depth += 1
            next_queries = []

            for q in current_queries:
                step = self._execute_retrieval_step(q, depth, seen_doc_ids)
                retrieval_path.append(step)

                # Collect documents not seen at an earlier step.
                for result in step.results:
                    doc_id = self._doc_id(result)
                    if doc_id not in seen_doc_ids:
                        seen_doc_ids.add(doc_id)
                        collected_docs.append(result)

                        if len(collected_docs) >= self.max_documents:
                            break

                # References feed the next depth level.
                next_queries.extend(step.references_found)

                if len(collected_docs) >= self.max_documents:
                    break

            # Bug fix: stop recursing entirely once the document budget is
            # spent -- the original kept issuing retrievals at deeper levels
            # that could never add documents.
            if len(collected_docs) >= self.max_documents:
                break

            current_queries = next_queries[:5]  # Limit queries per depth

        return RecursiveResult(
            final_context=self._build_context(collected_docs),
            retrieval_path=retrieval_path,
            total_documents=len(collected_docs),
            depth_reached=depth
        )

    @staticmethod
    def _doc_id(result: Dict) -> str:
        """Stable deduplication key for a document.

        Bug fix: the original used ``result.get("id", str(uuid.uuid4()))``,
        minting a fresh random UUID on every lookup, so documents without an
        ``id`` could never be recognised as duplicates. Fall back to a
        content-derived key instead (``hash`` is stable within one run,
        which is all deduplication needs).
        """
        doc_id = result.get("id")
        if doc_id is not None:
            return str(doc_id)
        return f"anon:{hash(result.get('content', ''))}"

    def _execute_retrieval_step(
        self,
        query: str,
        depth: int,
        seen_ids: Set[str]
    ) -> RetrievalStep:
        """Execute one step: retrieve, drop already-seen docs, extract refs."""
        step_id = f"step_{depth}_{uuid.uuid4().hex[:6]}"

        results = self.retriever.retrieve(query, top_k=5)

        # Filter out documents already collected at an earlier step.
        new_results = [r for r in results if self._doc_id(r) not in seen_ids]

        references = self._extract_references(new_results)

        return RetrievalStep(
            step_id=step_id,
            action=RetrievalAction.RETRIEVE,
            query=query,
            results=new_results,
            references_found=references,
            reasoning=f"Retrieved {len(new_results)} new documents, found {len(references)} references"
        )

    def _extract_references(self, results: List[Dict]) -> List[str]:
        """Extract reference strings from retrieved document content."""
        references: List[str] = []

        for result in results:
            content = result.get("content", "")

            # Explicit textual references ("see also: ...", etc.).
            for pattern in self._REF_PATTERNS:
                references.extend(pattern.findall(content))

            # Markdown link text.
            references.extend(self._LINK_PATTERN.findall(content))

        # Clean, drop trivially short strings, deduplicate.
        references = list({
            ref.strip() for ref in references
            if len(ref.strip()) > 3
        })

        return references[:5]  # Limit references per step

    def _build_context(self, documents: List[Dict]) -> str:
        """Concatenate document contents, each prefixed by its source label."""
        context_parts = []

        for i, doc in enumerate(documents):
            content = doc.get("content", "")
            source = doc.get("source", f"Document {i+1}")
            context_parts.append(f"[{source}]\n{content}")

        return "\n\n---\n\n".join(context_parts)

Query Decomposition

class QueryDecomposer:
    """Decompose complex queries into simpler sub-queries via an LLM."""

    def __init__(self, llm_client=None):
        self.llm = llm_client

    def decompose(self, query: str) -> List[Dict]:
        """Return a list of sub-query dicts with query/type/dependency keys.

        Without an LLM client the query is passed through unchanged as a
        single "original" entry.
        """
        if not self.llm:
            return [{"query": query, "type": "original", "dependency": None}]

        decomposition_prompt = f"""Analyze this question and break it down into simpler sub-questions if needed.
If the question is already simple, return it as is.

Question: {query}

For each sub-question, indicate:
1. The sub-question itself
2. Whether it depends on the answer to a previous sub-question

Format your response as:
SUB1: [question]
DEPENDS: [none or SUB number]

Sub-questions:"""

        response = self.llm.generate(decomposition_prompt)
        return self._parse_decomposition(response, query)

    def _parse_decomposition(
        self,
        response: str,
        original_query: str
    ) -> List[Dict]:
        """Parse the LLM's SUB/DEPENDS formatted response into dicts."""
        import re

        parsed: List[Dict] = []
        pending_query = None
        pending_dep = None

        for raw_line in response.strip().split('\n'):
            if raw_line.startswith('SUB'):
                # A new SUB line closes out the previous pending entry.
                if pending_query:
                    parsed.append({
                        "query": pending_query,
                        "type": "decomposed",
                        "dependency": pending_dep
                    })

                found = re.match(r'SUB\d+:\s*(.+)', raw_line)
                if found:
                    pending_query = found.group(1)
                pending_dep = None

            elif raw_line.startswith('DEPENDS:'):
                dep_text = raw_line.replace('DEPENDS:', '').strip().lower()
                if dep_text != 'none':
                    pending_dep = dep_text

        # Flush the last pending entry, if any.
        if pending_query:
            parsed.append({
                "query": pending_query,
                "type": "decomposed",
                "dependency": pending_dep
            })

        # Fall back to the original query when nothing could be parsed.
        return parsed or [
            {"query": original_query, "type": "original", "dependency": None}
        ]

class RecursiveQueryRetriever:
    """Retriever combining query decomposition with recursive retrieval.

    Complex questions are first split into sub-queries; each sub-query is
    answered via recursive retrieval, and intermediate answers can be fed
    into later, dependent sub-queries.
    """

    def __init__(
        self,
        base_retriever,
        llm_client=None,
        max_depth: int = 3
    ):
        self.retriever = base_retriever
        self.llm = llm_client
        self.decomposer = QueryDecomposer(llm_client)
        self.recursive = RecursiveRetriever(base_retriever, llm_client, max_depth)

    def retrieve(self, query: str) -> Dict:
        """Decompose the query, retrieve per sub-query, combine contexts."""
        sub_queries = self.decomposer.decompose(query)

        per_sub_results: List[Dict] = []
        intermediate_answers: Dict[str, str] = {}

        for index, sub in enumerate(sub_queries):
            # Inject any earlier answer this sub-query depends on.
            resolved_query = self._substitute_dependencies(
                sub["query"],
                sub["dependency"],
                intermediate_answers
            )

            recursive_result = self.recursive.retrieve(resolved_query)
            per_sub_results.append({
                "sub_query": sub,
                "actual_query": resolved_query,
                "result": recursive_result
            })

            # Produce an intermediate answer so dependent sub-queries can
            # reference it; only possible when an LLM client is configured.
            if self.llm and recursive_result.final_context:
                intermediate_answers[f"sub{index + 1}"] = (
                    self._generate_intermediate_answer(
                        resolved_query,
                        recursive_result.final_context
                    )
                )

        step_total = sum(
            len(entry["result"].retrieval_path) for entry in per_sub_results
        )

        return {
            "combined_context": self._combine_contexts(per_sub_results),
            "sub_query_results": per_sub_results,
            "decomposition": sub_queries,
            "total_retrieval_steps": step_total
        }

    def _substitute_dependencies(
        self,
        query: str,
        dependency: Optional[str],
        answers: Dict[str, str]
    ) -> str:
        """Append the answer of a dependency sub-query as inline context."""
        if dependency and dependency in answers:
            return f"{query} (Context: {answers[dependency]})"
        return query

    def _generate_intermediate_answer(
        self,
        query: str,
        context: str
    ) -> str:
        """Ask the LLM for a brief answer to an intermediate sub-query."""
        prompt = f"""Based on the context, briefly answer the question.

Context: {context[:2000]}

Question: {query}

Brief answer:"""

        return self.llm.generate(prompt)

    def _combine_contexts(self, results: List[Dict]) -> str:
        """Join per-sub-query contexts into one labelled context string."""
        sections = [
            f"[Sub-query {idx + 1}: {entry['sub_query']['query']}]\n"
            f"{entry['result'].final_context}"
            for idx, entry in enumerate(results)
        ]
        return "\n\n===\n\n".join(sections)

Multi-Hop Retrieval

class MultiHopRetriever:
    """Retriever for multi-hop reasoning questions.

    Each hop retrieves documents for the current query, asks the LLM
    whether the original question can now be answered, and either stops
    or continues with an LLM-proposed follow-up query, up to ``max_hops``.
    """

    def __init__(
        self,
        base_retriever,
        llm_client,
        max_hops: int = 3
    ):
        """
        Args:
            base_retriever: object exposing ``retrieve(query, top_k)`` that
                returns a list of dicts with a ``content`` key.
            llm_client: object exposing ``generate(prompt) -> str``; if
                None, retrieval stops after the first hop.
            max_hops: maximum number of retrieval hops.
        """
        self.retriever = base_retriever
        self.llm = llm_client
        self.max_hops = max_hops

    def retrieve(self, query: str) -> Dict:
        """Multi-hop retrieval with per-hop LLM reasoning.

        Returns a dict with the cumulative ``final_context``, the list of
        ``hops`` (query, results, analysis per hop), ``total_hops``, and
        the ``original_query``.
        """
        hops = []
        current_context = ""
        current_query = query

        for hop in range(self.max_hops):
            # Retrieve documents for this hop's query.
            results = self.retriever.retrieve(current_query, top_k=3)
            hop_context = self._build_hop_context(results)

            # Ask the LLM whether we can stop or need another hop.
            analysis = self._analyze_hop(
                query,
                current_query,
                hop_context,
                current_context,
                hop
            )

            hops.append({
                "hop_number": hop + 1,
                "query": current_query,
                "results": results,
                "context": hop_context,
                "analysis": analysis
            })

            # Accumulate context across hops for the final answer stage.
            current_context += f"\n\n[Hop {hop + 1}]\n{hop_context}"

            if analysis["should_stop"]:
                break

            # Bug fix: the original used analysis.get("next_query",
            # current_query), but the key is always present (possibly None),
            # so a missing follow-up set the next query to None. Keep the
            # current query instead when no follow-up was proposed.
            current_query = analysis.get("next_query") or current_query

        return {
            "final_context": current_context,
            "hops": hops,
            "total_hops": len(hops),
            "original_query": query
        }

    def _build_hop_context(self, results: List[Dict]) -> str:
        """Concatenate the content of a hop's retrieved documents."""
        parts = []
        for r in results:
            content = r.get("content", "")
            parts.append(content)
        return "\n\n".join(parts)

    def _analyze_hop(
        self,
        original_query: str,
        current_query: str,
        hop_context: str,
        cumulative_context: str,
        hop_number: int
    ) -> Dict:
        """Ask the LLM whether the question is answerable and what to do next.

        Returns a dict with ``should_stop``, ``answerable``, ``next_query``
        (or None), and the ``raw_analysis`` text.
        """
        if not self.llm:
            return {"should_stop": True, "reason": "No LLM for analysis"}

        analysis_prompt = f"""Analyze the retrieval results for a multi-hop question.

Original Question: {original_query}
Current Query: {current_query}
Hop Number: {hop_number + 1}

New Information Retrieved:
{hop_context[:1500]}

Previous Context:
{cumulative_context[:1000]}

Determine:
1. Can the original question be answered with current information?
2. If not, what additional information is needed?
3. What should be the next search query?

Respond in format:
ANSWERABLE: [yes/no]
MISSING: [what information is missing, or "none"]
NEXT_QUERY: [next search query, or "none"]"""

        response = self.llm.generate(analysis_prompt)

        # Bug fix: the original lowercased the response BEFORE splitting on
        # "ANSWERABLE:", so the uppercase marker never matched and any "yes"
        # anywhere in the response counted as answerable. Split first, then
        # lowercase only the value segment.
        answer_field = response.split("ANSWERABLE:")[-1].split("\n")[0]
        answerable = "yes" in answer_field.lower()
        should_stop = answerable or hop_number >= self.max_hops - 1

        # Bug fix: compare the extracted query line against the literal
        # "none" rather than a substring test, which wrongly rejected any
        # query merely containing "none" (e.g. "nonexistent").
        next_query = None
        if "NEXT_QUERY:" in response:
            candidate = response.split("NEXT_QUERY:")[-1].strip()
            candidate = candidate.split("\n")[0].strip()
            if candidate and candidate.lower() != "none":
                next_query = candidate

        return {
            "should_stop": should_stop,
            "answerable": answerable,
            "next_query": next_query,
            "raw_analysis": response
        }

Self-Querying Recursive Retrieval

class SelfQueryingRetriever:
    """Retriever that generates its own follow-up queries.

    After each retrieval, the LLM judges whether the accumulated context is
    sufficient; if not, it names the gaps and a follow-up query is generated
    to fill them, up to ``max_iterations`` rounds.
    """

    def __init__(
        self,
        base_retriever,
        llm_client,
        max_iterations: int = 3
    ):
        """
        Args:
            base_retriever: object exposing ``retrieve(query, top_k)``.
            llm_client: object exposing ``generate(prompt) -> str``; required
                for completeness evaluation and follow-up generation.
            max_iterations: maximum number of retrieval rounds.
        """
        self.retriever = base_retriever
        self.llm = llm_client
        self.max_iterations = max_iterations

    def retrieve(self, query: str) -> Dict:
        """Iteratively retrieve until the LLM judges the context complete."""
        iterations = []
        all_contexts = []
        current_query = query

        for i in range(self.max_iterations):
            # Retrieve for the current (possibly follow-up) query.
            results = self.retriever.retrieve(current_query, top_k=3)
            context = self._results_to_context(results)
            all_contexts.append(context)

            # Ask the LLM whether the accumulated context suffices.
            evaluation = self._evaluate_completeness(query, all_contexts)

            iterations.append({
                "iteration": i + 1,
                "query": current_query,
                "results_count": len(results),
                "evaluation": evaluation
            })

            if evaluation["is_complete"]:
                break

            # Generate a follow-up query targeting the identified gaps.
            follow_up = self._generate_follow_up(
                query,
                all_contexts,
                evaluation["gaps"]
            )

            if not follow_up:
                break

            current_query = follow_up

        return {
            "final_context": "\n\n---\n\n".join(all_contexts),
            "iterations": iterations,
            "total_iterations": len(iterations),
            "original_query": query
        }

    def _results_to_context(self, results: List[Dict]) -> str:
        """Concatenate result contents into one context string."""
        return "\n\n".join(r.get("content", "") for r in results)

    def _evaluate_completeness(
        self,
        original_query: str,
        contexts: List[str]
    ) -> Dict:
        """Ask the LLM whether the contexts suffice; return gaps if not."""
        combined = "\n".join(contexts)

        prompt = f"""Evaluate if the retrieved information is sufficient to answer the question.

Question: {original_query}

Retrieved Information:
{combined[:3000]}

Analyze:
1. Is the information sufficient to fully answer the question?
2. What specific information is missing (if any)?

SUFFICIENT: [yes/no]
MISSING: [list specific gaps, or "none"]"""

        response = self.llm.generate(prompt)

        # Bug fix: the original lowercased the response BEFORE splitting on
        # "SUFFICIENT:", so the uppercase marker never matched and any "yes"
        # anywhere in the response counted as sufficient. Split first, then
        # lowercase only the value segment.
        sufficient_field = response.split("SUFFICIENT:")[-1].split("\n")[0]
        is_complete = "yes" in sufficient_field.lower()

        # Bug fix: compare the MISSING value against the literal "none"
        # instead of a substring test, which wrongly suppressed gap lists
        # that merely contained the word "none".
        gaps = []
        if "MISSING:" in response:
            missing_part = response.split("MISSING:")[-1].strip()
            if missing_part.lower() != "none":
                gaps = [g.strip() for g in missing_part.split(",")]

        return {
            "is_complete": is_complete,
            "gaps": gaps
        }

    def _generate_follow_up(
        self,
        original_query: str,
        contexts: List[str],
        gaps: List[str]
    ) -> Optional[str]:
        """Generate a search query targeting the most important gap."""
        if not gaps:
            return None

        prompt = f"""Generate a search query to find the missing information.

Original Question: {original_query}

Information Gaps:
{chr(10).join(f"- {g}" for g in gaps)}

Generate a single, focused search query to address the most important gap:"""

        response = self.llm.generate(prompt)

        return response.strip() if response.strip() else None

Complete Recursive RAG

class RecursiveRAG:
    """RAG system with recursive retrieval capabilities.

    Wraps a base retriever with three advanced strategies (recursive,
    multi-hop, self-querying) and picks one heuristically per question.
    """

    def __init__(self, base_retriever, generator):
        """
        Args:
            base_retriever: object exposing ``retrieve(query, top_k)``.
            generator: object exposing ``generate(prompt) -> str``; used both
                as the LLM client for the sub-retrievers and for the final
                answer generation.
        """
        self.retriever = base_retriever
        self.generator = generator
        self.recursive = RecursiveRetriever(base_retriever, generator)
        self.multi_hop = MultiHopRetriever(base_retriever, generator)
        self.self_query = SelfQueryingRetriever(base_retriever, generator)

    def query(
        self,
        question: str,
        mode: str = "auto"
    ) -> Dict:
        """Answer a question, selecting a retrieval strategy by mode.

        Args:
            question: the user question.
            mode: "auto" (heuristic choice), "simple", "recursive",
                "multi_hop", or anything else for self-querying retrieval.

        Returns:
            Dict with "answer", "retrieval_info", and "context_length".
        """
        if mode == "auto":
            mode = self._determine_mode(question)

        context, retrieval_info = self._run_retrieval(question, mode)

        prompt = f"""Answer the question based on the provided context.

Context:
{context[:4000]}

Question: {question}

Provide a comprehensive answer:"""

        answer = self.generator.generate(prompt)

        return {
            "answer": answer,
            "retrieval_info": retrieval_info,
            "context_length": len(context)
        }

    def _run_retrieval(self, question: str, mode: str):
        """Dispatch to the retriever matching *mode*; return (context, info)."""
        if mode == "simple":
            results = self.retriever.retrieve(question, top_k=5)
            context = "\n\n".join(r.get("content", "") for r in results)
            return context, {"mode": "simple", "steps": 1}

        if mode == "recursive":
            result = self.recursive.retrieve(question)
            return result.final_context, {
                "mode": "recursive",
                "steps": len(result.retrieval_path),
                "depth": result.depth_reached
            }

        if mode == "multi_hop":
            result = self.multi_hop.retrieve(question)
            return result["final_context"], {
                "mode": "multi_hop",
                "hops": result["total_hops"]
            }

        # Any other mode falls through to self-querying retrieval.
        result = self.self_query.retrieve(question)
        return result["final_context"], {
            "mode": "self_query",
            "iterations": result["total_iterations"]
        }

    def _determine_mode(self, question: str) -> str:
        """Heuristically choose a retrieval mode from the question's shape."""
        question_lower = question.lower()

        # Phrases that usually require chaining facts across documents.
        # Bug fix: the original list contained "how does X affect Y" with
        # uppercase X/Y, which can never match a lowercased question.
        multi_hop_indicators = [
            "relationship between",
            "how does x affect y",
            "compare",
            "what led to"
        ]
        if any(ind in question_lower for ind in multi_hop_indicators):
            return "multi_hop"

        # Long questions tend to touch multiple documents.
        if len(question.split()) > 20:
            return "recursive"

        # Multiple question marks suggest several information needs.
        if question.count("?") > 1:
            return "self_query"

        return "simple"

# Usage
class MockRetriever:
    """Stand-in retriever that returns a single canned document per query."""

    def retrieve(self, query, top_k=5):
        # Echo the query back in the content so callers can see what was
        # searched; top_k is accepted but ignored.
        document = {"content": f"Result for: {query}", "id": "1"}
        return [document]

class MockGenerator:
    """Stand-in LLM client that ignores the prompt and returns fixed text."""

    def generate(self, prompt):
        return "Generated answer."

# Build the recursive RAG pipeline from mock components for demonstration.
rag = RecursiveRAG(MockRetriever(), MockGenerator())

# "relationship between" in the question routes it to multi-hop mode.
result = rag.query("What is the relationship between X and Y?")
print(f"Mode: {result['retrieval_info']['mode']}")
print(f"Answer: {result['answer']}")

Conclusion

Recursive retrieval enables RAG systems to handle complex queries requiring multiple steps, reference following, or multi-hop reasoning. By decomposing queries, following references, and iteratively retrieving until sufficient information is gathered, recursive retrieval significantly extends the capabilities of RAG systems for answering complex, real-world questions.

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.