Skip to content
Back to Blog
1 min read

Entity Extraction at Scale with LLMs

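This post shares practical, production-minded guidance on entity extraction at scale with LLMs: batching texts into prompts, domain-specific extraction, entity linking and deduplication, relationship extraction, and a Spark-based pipeline that persists results to Delta tables.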

Scalable Entity Extraction Pipeline

import json
import logging
from dataclasses import dataclass
from typing import List, Optional

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

@dataclass
class ExtractedEntity:
    """One entity mention returned by an extraction step.

    Instances are built directly from the JSON objects an LLM returns
    (``ExtractedEntity(**obj)``), so the field names below are also the keys
    the extraction prompt asks the model to emit.
    """

    # The entity text as it was extracted.
    text: str
    # Label for the mention; the batch prompt asks for a type from the
    # caller-supplied list of entity types.
    entity_type: str
    # Extraction score; the batch prompt asks for a value in 0.0-1.0.
    confidence: float
    # The batch prompt asks for a "character position" for each of these.
    # NOTE(review): presumably offsets into the source text - confirm.
    start_pos: int
    end_pos: int
    # Optional free-form extras; None when not supplied.
    metadata: Optional[dict] = None

class ScalableEntityExtractor:
    """Distributed entity extraction with LLMs.

    Provides three entry points:

    * ``extract_entities_batch`` - async, prompt-based extraction through
      ``llm_client``, several texts per request.
    * ``extract_with_spark`` - attaches a Spark UDF column (the UDF is still
      a placeholder that returns no entities).
    * ``extract_with_cognitive_services`` - runs the SynapseML ``NER``
      transformer over a DataFrame.
    """

    def __init__(self, spark: SparkSession, llm_client):
        """Store the collaborators used by the extraction methods.

        Args:
            spark: Spark session kept on the instance as ``self.spark``.
            llm_client: Object with an awaitable
                ``chat_completion(model=..., messages=..., temperature=...)``
                whose result exposes the reply text as ``.content``.
        """
        self.spark = spark
        self.client = llm_client

    async def extract_entities_batch(
        self,
        texts: List[str],
        entity_types: List[str],
        batch_size: int = 10
    ) -> List[List[ExtractedEntity]]:
        """Extract entities from multiple texts.

        The texts are cut into consecutive groups of ``batch_size``; each
        group is sent in one LLM request and the requests are awaited one
        after another.

        Args:
            texts: Input texts.
            entity_types: Entity type labels the model may use.
            batch_size: Number of texts per LLM request; must be >= 1.

        Returns:
            One list of entities per input text, in input order. A text whose
            entities could not be parsed gets an empty list.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # range() rejects a zero step with an opaque message and a negative
        # step would silently produce no batches at all.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        results: List[List[ExtractedEntity]] = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            results.extend(await self._process_batch(batch, entity_types))
        return results

    async def _process_batch(
        self,
        texts: List[str],
        entity_types: List[str]
    ) -> List[List[ExtractedEntity]]:
        """Send one group of texts to the LLM and parse its JSON reply.

        The reply is expected to be a JSON object keyed ``text_0``,
        ``text_1``, ... with one list of entity objects per text.

        Returns:
            One list per text. If the reply is not a JSON object, every text
            gets an empty list; a malformed entity only drops that entity.
            Both situations are logged instead of being silently swallowed.
        """
        if not texts:
            return []

        texts_formatted = "\n---\n".join(
            f"[TEXT {i}]: {text}" for i, text in enumerate(texts)
        )

        prompt = f"""Extract named entities from these texts.

Entity Types to Extract: {', '.join(entity_types)}

Texts:
{texts_formatted}

For each text, extract entities with:
- text: the entity text
- entity_type: type from the list above
- confidence: 0.0-1.0
- start_pos: character position
- end_pos: character position

Return as JSON: {{"text_0": [...], "text_1": [...], ...}}"""

        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        log = logging.getLogger(__name__)
        try:
            parsed = json.loads(response.content)
        except (AttributeError, TypeError, ValueError) as exc:
            # Missing/None content or invalid JSON: stay best-effort, but
            # leave a trace. Only parsing errors are caught, so task
            # cancellation and interrupts are no longer swallowed.
            log.warning(
                "Unparseable LLM reply for a batch of %d texts: %s",
                len(texts), exc
            )
            return [[] for _ in texts]

        if not isinstance(parsed, dict):
            log.warning(
                "Expected a JSON object keyed by text_<i>, got %s",
                type(parsed).__name__
            )
            return [[] for _ in texts]

        return [
            self._build_entities(parsed.get(f"text_{i}"))
            for i, _ in enumerate(texts)
        ]

    @staticmethod
    def _build_entities(raw_entities) -> List[ExtractedEntity]:
        """Turn one text's raw JSON entity list into ExtractedEntity objects."""
        if not isinstance(raw_entities, list):
            return []

        built = []
        for raw in raw_entities:
            try:
                built.append(ExtractedEntity(**raw))
            except TypeError:
                # Not a JSON object, or its keys do not match the dataclass
                # fields. Skip just this item rather than the whole batch.
                logging.getLogger(__name__).warning(
                    "Skipping malformed entity: %r", raw
                )
        return built

    def extract_with_spark(
        self,
        df,
        text_col: str,
        entity_types: List[str],
        output_col: str = "entities"
    ):
        """Extract entities using Spark UDF.

        NOTE: the UDF below is a placeholder that returns an empty array for
        every row, and ``entity_types`` is not used yet.

        Args:
            df: Input DataFrame.
            text_col: Name of the column holding the text.
            entity_types: Entity type labels (currently unused).
            output_col: Name of the column to add.

        Returns:
            ``df`` with ``output_col`` added.
        """

        # Define schema for entities
        entity_schema = ArrayType(StructType([
            StructField("text", StringType()),
            StructField("entity_type", StringType()),
            StructField("confidence", DoubleType()),
            StructField("start_pos", IntegerType()),
            StructField("end_pos", IntegerType())
        ]))

        # Create extraction UDF (simplified - actual impl needs async handling)
        @udf(entity_schema)
        def extract_entities(text):
            # In practice, this would call the LLM service
            # Using placeholder for demonstration
            return []

        return df.withColumn(output_col, extract_entities(col(text_col)))

    def extract_with_cognitive_services(
        self,
        df,
        text_col: str,
        output_col: str = "entities"
    ):
        """Extract entities using Azure Cognitive Services.

        Errors are routed to a ``ner_error`` column via ``setErrorCol``.

        Args:
            df: Input DataFrame.
            text_col: Name of the column holding the text.
            output_col: Name of the column receiving the NER output.

        Returns:
            The DataFrame produced by ``NER.transform``.
        """
        # Imported lazily so the class can be used without SynapseML.
        from synapse.ml.cognitive import NER

        # NOTE(review): no subscription key or endpoint is configured here;
        # confirm that credentials are supplied some other way.
        ner = (NER()
            .setTextCol(text_col)
            .setOutputCol(output_col)
            .setErrorCol("ner_error"))

        return ner.transform(df)

Domain-Specific Entity Extraction

class DomainEntityExtractor:
    """Prompt-driven entity extraction specialised for one named domain.

    ``llm_client`` must offer an awaitable ``chat_completion(...)`` whose
    result carries the reply text in ``.content``; every reply is parsed as
    JSON and returned as-is.
    """

    def __init__(self, llm_client, domain: str):
        """Remember the client and domain, then resolve the domain's definitions."""
        self.client = llm_client
        self.domain = domain
        self.entity_definitions = self._load_domain_definitions()

    def _load_domain_definitions(self) -> dict:
        """Return the entity types and examples configured for ``self.domain``.

        Known domains are ``healthcare``, ``finance`` and ``legal``; any other
        domain yields empty definitions.
        """
        healthcare = {
            "entities": [
                "MEDICATION", "DOSAGE", "CONDITION", "PROCEDURE",
                "ANATOMY", "SYMPTOM", "PROVIDER", "FACILITY"
            ],
            "examples": {
                "MEDICATION": ["aspirin", "metformin", "lisinopril"],
                "CONDITION": ["diabetes", "hypertension", "asthma"]
            }
        }
        finance = {
            "entities": [
                "COMPANY", "TICKER", "AMOUNT", "CURRENCY",
                "DATE", "PERCENTAGE", "FINANCIAL_METRIC", "TRANSACTION"
            ],
            "examples": {
                "TICKER": ["AAPL", "GOOGL", "MSFT"],
                "FINANCIAL_METRIC": ["revenue", "EBITDA", "P/E ratio"]
            }
        }
        legal = {
            "entities": [
                "PERSON", "ORGANIZATION", "CASE_NUMBER", "COURT",
                "DATE", "STATUTE", "JURISDICTION", "LEGAL_TERM"
            ],
            "examples": {
                "COURT": ["Supreme Court", "District Court"],
                "LEGAL_TERM": ["plaintiff", "defendant", "jurisdiction"]
            }
        }
        catalog = {"healthcare": healthcare, "finance": finance, "legal": legal}

        if self.domain in catalog:
            return catalog[self.domain]
        # Unknown domain: nothing predefined.
        return {"entities": [], "examples": {}}

    async def _request_json(self, prompt: str):
        """Send ``prompt`` as a single user message and JSON-decode the reply."""
        reply = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        return json.loads(reply.content)

    async def extract_domain_entities(
        self,
        text: str,
        include_relationships: bool = False
    ) -> dict:
        """Ask the model for this domain's entities found in ``text``.

        Args:
            text: Text to analyse.
            include_relationships: When true, the prompt additionally asks for
                relationships between the entities.

        Returns:
            The JSON-decoded model reply.
        """
        definitions = self.entity_definitions
        entity_names = definitions["entities"]
        examples_json = json.dumps(definitions["examples"], indent=2)

        # The two optional prompt fragments are empty unless relationships
        # were requested.
        if include_relationships:
            relationship_request = 'Also extract relationships between entities.'
            relationship_suffix = ' and "relationships" array'
        else:
            relationship_request = ''
            relationship_suffix = ''

        prompt = f"""Extract {self.domain} domain entities from this text.

Text:
{text}

Entity Types:
{', '.join(entity_names)}

Examples:
{examples_json}

Extract all entities with:
- text: exact text
- type: entity type
- normalized: normalized form (if applicable)
- confidence: 0.0-1.0
- context: surrounding context

{relationship_request}

Return as JSON with 'entities' array{relationship_suffix}."""

        return await self._request_json(prompt)

    async def extract_medical_entities(self, clinical_note: str) -> dict:
        """Ask the model for medical entities in a clinical note.

        Returns:
            The JSON-decoded model reply.
        """
        prompt = f"""Extract medical entities from this clinical note.

Clinical Note:
{clinical_note}

Extract:
1. MEDICATIONS: drug names, dosages, frequencies, routes
2. CONDITIONS: diagnoses, symptoms, medical history
3. PROCEDURES: surgeries, tests, treatments
4. ANATOMY: body parts, organs mentioned
5. MEASUREMENTS: vital signs, lab values
6. PROVIDERS: doctors, nurses, specialists
7. TEMPORAL: dates, durations, frequencies

Also identify:
- Negations (e.g., "no fever" means fever is negated)
- Uncertainties (e.g., "possible pneumonia")
- Historical vs current

Return as structured JSON."""

        return await self._request_json(prompt)

    async def extract_financial_entities(self, document: str) -> dict:
        """Ask the model for financial entities in a document.

        Returns:
            The JSON-decoded model reply.
        """
        prompt = f"""Extract financial entities from this document.

Document:
{document}

Extract:
1. COMPANIES: company names with tickers if mentioned
2. AMOUNTS: monetary values with currencies
3. DATES: relevant dates and time periods
4. METRICS: financial metrics (revenue, profit, etc.)
5. EVENTS: earnings, acquisitions, IPOs, etc.
6. PERCENTAGES: growth rates, changes, ratios
7. ANALYSTS: names of analysts or firms

Also extract:
- Sentiment towards companies (positive/negative/neutral)
- Forward-looking statements
- Risk factors mentioned

Return as structured JSON."""

        return await self._request_json(prompt)

Entity Linking and Resolution

class EntityLinker:
    """Link extracted entities to knowledge bases.

    All methods prompt ``llm_client`` (an object with an awaitable
    ``chat_completion(...)`` whose result exposes ``.content``) and parse the
    reply as JSON.
    """

    def __init__(self, llm_client, spark: SparkSession):
        """Store the collaborators and start with an empty link cache.

        Args:
            llm_client: LLM client used for every request.
            spark: Spark session; stored but not used by the methods below.
        """
        self.client = llm_client
        self.spark = spark
        # (knowledge_base, text, entity_type) -> knowledge-base record
        # returned by the model. The per-mention "original" data is NOT
        # stored here, so a cache hit can never leak another mention's
        # offsets into the result.
        self.entity_cache = {}

    async def link_entities(
        self,
        entities: List[ExtractedEntity],
        knowledge_base: str = "wikidata"
    ) -> List[dict]:
        """Link entities to knowledge base entries.

        Each distinct (knowledge base, text, type) combination is looked up
        once and then served from ``self.entity_cache``.

        Args:
            entities: Entities to link.
            knowledge_base: Name of the target knowledge base; it is part of
                the prompt and of the cache key.

        Returns:
            One dict per input entity, in input order: the model's record
            plus an ``"original"`` key holding a copy of that entity's fields.

        Raises:
            ValueError: If a reply is not valid JSON (``json.JSONDecodeError``).
            TypeError: If a reply is valid JSON but not a JSON object.
        """
        linked_entities = []

        for entity in entities:
            cache_key = (knowledge_base, entity.text, entity.entity_type)
            kb_record = self.entity_cache.get(cache_key)
            if kb_record is None:
                kb_record = await self._query_knowledge_base(
                    entity, knowledge_base
                )
                self.entity_cache[cache_key] = kb_record

            # Fresh dict per mention; copy the entity's fields so later
            # changes to the entity do not alter the returned record.
            linked_entities.append(
                {**kb_record, "original": dict(entity.__dict__)}
            )

        return linked_entities

    async def _query_knowledge_base(self, entity, knowledge_base: str) -> dict:
        """Ask the model for the knowledge-base record of one entity."""
        prompt = f"""Link this entity to {knowledge_base}.

Entity: {entity.text}
Type: {entity.entity_type}

Provide:
- canonical_name: standardized name
- kb_id: {knowledge_base} identifier (if known)
- aliases: other known names
- description: brief description
- confidence: 0.0-1.0

Return as JSON."""

        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        record = json.loads(response.content)
        if not isinstance(record, dict):
            raise TypeError(
                "expected a JSON object for entity link, got "
                f"{type(record).__name__}"
            )
        return record

    async def resolve_coreferences(
        self,
        text: str,
        entities: List[dict]
    ) -> dict:
        """Resolve coreferences in text.

        Args:
            text: Source text.
            entities: Entities already extracted from ``text``; they are
                embedded in the prompt as JSON.

        Returns:
            The JSON-decoded model reply; the prompt asks for
            ``coreference_chains`` and ``resolved_text``.
        """

        entities_str = json.dumps(entities, indent=2)

        prompt = f"""Resolve coreferences in this text.

Text:
{text}

Extracted Entities:
{entities_str}

Identify:
1. Pronouns referring to entities (he, she, it, they)
2. Definite descriptions referring to entities (the company, the CEO)
3. Abbreviated references

Return as JSON with:
- coreference_chains: groups of mentions referring to same entity
- resolved_text: text with coreferences replaced by entity names"""

        # temperature=0 to match the other extraction calls in this file.
        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        return json.loads(response.content)

    async def deduplicate_entities(
        self,
        entities: List[dict]
    ) -> List[dict]:
        """Deduplicate and merge similar entities.

        Args:
            entities: Entity dicts; they are embedded in the prompt as JSON.

        Returns:
            The JSON-decoded model reply (the prompt asks for an array of
            merged entities), or ``[]`` without calling the model when
            ``entities`` is empty.
        """
        if not entities:
            return []

        prompt = f"""Deduplicate these entities.

Entities:
{json.dumps(entities, indent=2)}

Group entities that refer to the same real-world entity.
For each group, create a merged entity with:
- canonical_name: best name to use
- all_mentions: list of all text mentions
- merged_metadata: combined metadata
- mention_count: number of mentions

Return as JSON array of merged entities."""

        # temperature=0 to match the other extraction calls in this file.
        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        return json.loads(response.content)

Relationship Extraction

class RelationshipExtractor:
    """Extract relationships between entities.

    ``llm_client`` must offer an awaitable ``chat_completion(...)`` whose
    result exposes the reply text as ``.content``; replies are parsed as JSON.
    """

    def __init__(self, llm_client):
        """Store the LLM client used for every request."""
        self.client = llm_client

    async def _extract_entities(self, text: str) -> List[dict]:
        """Ask the model for the entities mentioned in ``text``.

        ``build_knowledge_graph`` relies on this helper, which was previously
        missing from the class.

        Returns:
            A list of entity dicts. A reply shaped as an object with an
            ``"entities"`` array is unwrapped; any other shape yields ``[]``.
        """
        prompt = f"""Extract named entities from this text.

Text:
{text}

For each entity provide:
- text: exact text
- type: entity type
- confidence: 0.0-1.0

Return as JSON array of entities."""

        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        parsed = json.loads(response.content)
        if isinstance(parsed, dict):
            parsed = parsed.get("entities", [])
        return parsed if isinstance(parsed, list) else []

    async def extract_relationships(
        self,
        text: str,
        entities: List[dict],
        relationship_types: Optional[List[str]] = None
    ) -> List[dict]:
        """Extract relationships between entities.

        Args:
            text: Source text.
            entities: Entities found in ``text``; embedded in the prompt as JSON.
            relationship_types: Allowed predicates. ``None`` or an empty list
                selects the built-in default set.

        Returns:
            The JSON-decoded model reply (the prompt asks for an array of
            relationships).
        """

        default_types = [
            "WORKS_FOR", "LOCATED_IN", "OWNS", "PRODUCES",
            "PART_OF", "FOUNDED_BY", "ACQUIRED", "COMPETES_WITH"
        ]
        rel_types = relationship_types or default_types

        prompt = f"""Extract relationships between entities in this text.

Text:
{text}

Entities:
{json.dumps(entities, indent=2)}

Relationship Types: {', '.join(rel_types)}

For each relationship provide:
- subject: source entity
- predicate: relationship type
- object: target entity
- confidence: 0.0-1.0
- evidence: text supporting the relationship

Return as JSON array of relationships."""

        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        return json.loads(response.content)

    async def build_knowledge_graph(
        self,
        documents: List[str]
    ) -> dict:
        """Build knowledge graph from documents.

        For every document the entities and then the relationships are
        extracted; all entities are finally merged through
        ``EntityLinker.deduplicate_entities``.

        Args:
            documents: Document texts.

        Returns:
            Dict with ``nodes`` (merged entities), ``edges`` (all
            relationships) and ``document_count``.
        """
        # Nothing to extract: skip every LLM call.
        if not documents:
            return {"nodes": [], "edges": [], "document_count": 0}

        all_entities = []
        all_relationships = []

        for doc in documents:
            # Extract entities
            entities = await self._extract_entities(doc)
            all_entities.extend(entities)

            # Extract relationships
            relationships = await self.extract_relationships(doc, entities)
            all_relationships.extend(relationships)

        # Deduplicate and merge (no Spark session is needed for this step)
        linker = EntityLinker(self.client, None)
        merged_entities = await linker.deduplicate_entities(all_entities)

        return {
            "nodes": merged_entities,
            "edges": all_relationships,
            "document_count": len(documents)
        }

    async def extract_events(
        self,
        text: str
    ) -> List[dict]:
        """Extract events and their participants.

        Returns:
            The JSON-decoded model reply (the prompt asks for an array of
            events).
        """

        prompt = f"""Extract events from this text.

Text:
{text}

For each event identify:
- event_type: what kind of event
- trigger: word/phrase indicating the event
- participants: entities involved with their roles
- time: when it happened (if mentioned)
- location: where it happened (if mentioned)
- outcome: result of the event (if mentioned)

Return as JSON array of events."""

        # temperature=0 to match the other extraction calls in this class.
        response = await self.client.chat_completion(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        return json.loads(response.content)

Production Pipeline

class ProductionEntityPipeline:
    """Production-ready entity extraction pipeline.

    Wires together ``ScalableEntityExtractor``, ``DomainEntityExtractor`` and
    ``EntityLinker`` around one Spark session and one LLM client.
    """

    def __init__(self, spark: SparkSession, llm_client, config: dict):
        """Create the pipeline and its component extractors.

        Args:
            spark: Spark session used to read and write tables.
            llm_client: LLM client handed to every component.
            config: Settings dict; only ``"domain"`` is read here
                (default ``"general"``).
        """
        self.spark = spark
        self.client = llm_client
        self.config = config

        self.extractor = ScalableEntityExtractor(spark, llm_client)
        self.domain_extractor = DomainEntityExtractor(
            llm_client, config.get("domain", "general")
        )
        self.linker = EntityLinker(llm_client, spark)

    def process_documents(
        self,
        input_table: str,
        output_table: str,
        text_column: str
    ):
        """Process documents end-to-end.

        Reads ``input_table``, runs NER on ``text_column``, explodes the
        ``entities`` array into one row per entity and overwrites
        ``output_table`` (Delta format) with the result.

        Args:
            input_table: Name of the table to read.
            output_table: Name of the Delta table to overwrite.
            text_column: Column of ``input_table`` holding the text.

        Returns:
            Number of rows written, i.e. one per extracted entity (not one
            per document).
        """

        # Read input
        df = self.spark.table(input_table)

        # Extract entities using Cognitive Services for scale
        df_with_entities = self.extractor.extract_with_cognitive_services(
            df, text_column
        )

        # Flatten entities: one output row per element of the array
        df_flattened = df_with_entities.select(
            "*",
            explode(col("entities")).alias("entity")
        ).select(
            "*",
            col("entity.text").alias("entity_text"),
            col("entity.category").alias("entity_type"),
            col("entity.confidenceScore").alias("confidence")
        )

        # Write to Delta table
        df_flattened.write \
            .format("delta") \
            .mode("overwrite") \
            .saveAsTable(output_table)

        # Count what was persisted. Calling count() on df_flattened would be
        # a second action on an uncached, lazily evaluated DataFrame and
        # would run the whole lineage - including the NER step - again.
        return self.spark.table(output_table).count()

    async def enrich_entities(
        self,
        entity_table: str,
        enriched_table: str
    ):
        """Enrich extracted entities with linking.

        Collects the distinct (``entity_text``, ``entity_type``) pairs of
        ``entity_table`` to the driver, links them through the LLM-backed
        linker and overwrites ``enriched_table`` (Delta format).

        If there is nothing to link, the method returns without writing:
        Spark cannot infer a schema from an empty list of rows.

        Args:
            entity_table: Table with ``entity_text`` / ``entity_type`` columns.
            enriched_table: Name of the Delta table to overwrite.
        """

        # Read entities
        df = self.spark.table(entity_table)
        rows = df.select("entity_text", "entity_type").distinct().collect()

        # Link entities. Rows without text are skipped: they cannot be
        # linked and len(None) would raise.
        entity_list = [
            ExtractedEntity(
                text=row.entity_text,
                entity_type=row.entity_type,
                confidence=1.0,
                start_pos=0,
                end_pos=len(row.entity_text)
            )
            for row in rows
            if row.entity_text is not None
        ]
        if not entity_list:
            return

        linked = await self.linker.link_entities(entity_list)

        # Convert to DataFrame and save
        # NOTE(review): the schema is inferred from the linked dicts, which
        # include a nested "original" dict with mixed value types - confirm
        # the inference behaves as intended or pass an explicit schema.
        linked_df = self.spark.createDataFrame(linked)
        linked_df.write \
            .format("delta") \
            .mode("overwrite") \
            .saveAsTable(enriched_table)

# Usage
# `spark` and `llm_client` are not created in this snippet; they are expected
# to exist already (for example in a notebook session).
pipeline = ProductionEntityPipeline(
    spark, llm_client,
    {"domain": "finance"}
)

# Process documents
# process_documents returns the number of flattened rows (one per extracted
# entity), so the result is reported as entities rather than documents. The
# variable is not called `count` because that would shadow
# pyspark.sql.functions.count from the star import.
entity_count = pipeline.process_documents(
    input_table="bronze.financial_reports",
    output_table="silver.report_entities",
    text_column="report_text"
)
print(f"Extracted {entity_count} entities")

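Entity extraction at scale transforms unstructured text into structured knowledge. Combining LLM intelligence with distributed processing makes it practical to draw insights from large document collections.

## Takeaways

- Send several texts per LLM request and parse the reply defensively: a malformed response should cost one batch at most, not the whole job.
- Give the model domain-specific entity types and examples rather than a generic prompt.
- Cache entity-linking results and deduplicate mentions before building a knowledge graph.
- Persist flattened entities to Delta tables so enrichment can run as a separate step.

Next steps: replace the placeholder Spark UDF with a real call to your LLM service, and add rate limiting and retries before running this against production volumes.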

Michael John Pena

Michael John Pena

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.