Databricks Feature Store: Centralized Feature Management for ML

The Databricks Feature Store provides centralized management for machine learning features, enabling feature reuse, discovery, and consistent serving across training and inference.

Why Feature Stores Matter

Without a feature store:

  • Teams duplicate feature engineering work
  • Training/serving skew causes prediction errors
  • Feature discovery is difficult
  • Point-in-time correctness is hard to achieve

The Databricks Feature Store addresses each of these problems by managing features as governed, discoverable tables that are shared across training and serving.

Creating Feature Tables

from databricks.feature_store import FeatureStoreClient, FeatureLookup

fs = FeatureStoreClient()

# Create customer features
customer_features_df = spark.sql("""
    SELECT
        customer_id,
        COUNT(DISTINCT order_id) as total_orders,
        SUM(order_amount) as lifetime_value,
        AVG(order_amount) as avg_order_value,
        DATEDIFF(current_date(), MAX(order_date)) as days_since_last_order,
        DATEDIFF(current_date(), MIN(order_date)) as customer_tenure_days
    FROM production.sales.orders
    GROUP BY customer_id
""")

# Create feature table
fs.create_table(
    name="production.features.customer_features",
    primary_keys=["customer_id"],
    df=customer_features_df,
    description="Customer behavior features derived from order history",
    tags={"team": "data-science", "domain": "customer"}
)

Time-Series Feature Tables

For features that change over time, add a timestamp key so lookups can be resolved point-in-time correctly:

# Create time-series feature table with timestamp key
product_price_features = spark.sql("""
    SELECT
        product_id,
        effective_date,
        price,
        discount_pct,
        price * (1 - discount_pct/100) as effective_price
    FROM production.catalog.price_history
""")

fs.create_table(
    name="production.features.product_pricing",
    primary_keys=["product_id"],
    timestamp_keys=["effective_date"],
    df=product_price_features,
    description="Historical product pricing for point-in-time joins"
)

Updating Features

# Compute updated features
updated_features = compute_customer_features(spark)

# Write to feature table (merge/upsert)
fs.write_table(
    name="production.features.customer_features",
    df=updated_features,
    mode="merge"  # or "overwrite"
)

# For streaming updates
streaming_features = (
    spark.readStream
    .table("production.sales.orders")
    .groupBy("customer_id")
    .agg(...)
)

fs.write_table(
    name="production.features.customer_features",
    df=streaming_features,
    mode="merge",
    trigger={"processingTime": "1 hour"}
)

Feature Discovery

# Search for features
features = fs.search_feature_tables(
    filter_string="tags.domain = 'customer'"
)

for feature in features:
    print(f"Table: {feature.name}")
    print(f"Description: {feature.description}")
    print(f"Primary keys: {feature.primary_keys}")
    print("---")

# Get feature table details
table = fs.get_table("production.features.customer_features")
print(f"Features: {[f.name for f in table.features]}")

Training with Feature Lookups

# Define training data with labels
training_labels = spark.sql("""
    SELECT
        customer_id,
        order_date as label_date,
        CASE WHEN churned_within_30_days THEN 1 ELSE 0 END as label
    FROM production.analytics.churn_labels
""")

# Define feature lookups
feature_lookups = [
    FeatureLookup(
        table_name="production.features.customer_features",
        lookup_key=["customer_id"],
        feature_names=["total_orders", "lifetime_value", "avg_order_value", "days_since_last_order"]
    ),
    FeatureLookup(
        table_name="production.features.customer_demographics",
        lookup_key=["customer_id"],
        feature_names=["age_group", "region", "account_type"]
    )
]

# Create training set with automatic feature joins
training_set = fs.create_training_set(
    df=training_labels,
    feature_lookups=feature_lookups,
    label="label",
    exclude_columns=["customer_id", "label_date"]  # Don't use as features
)

# Load as pandas DataFrame
training_df = training_set.load_df().toPandas()

# Or load as Spark DataFrame for distributed training
training_spark_df = training_set.load_df()

Point-in-Time Lookups

# For time-series features, ensure point-in-time correctness
training_labels = spark.sql("""
    SELECT
        customer_id,
        product_id,
        transaction_date,
        label
    FROM training_data
""")

feature_lookups = [
    # Customer features at time of transaction
    FeatureLookup(
        table_name="production.features.customer_features_timeseries",
        lookup_key=["customer_id"],
        timestamp_lookup_key=["transaction_date"],
        feature_names=["orders_last_30d", "spend_last_30d"]
    ),
    # Product pricing at time of transaction
    FeatureLookup(
        table_name="production.features.product_pricing",
        lookup_key=["product_id"],
        timestamp_lookup_key=["transaction_date"],
        feature_names=["price", "effective_price"]
    )
]

# Features are joined as-of the timestamp, preventing data leakage
training_set = fs.create_training_set(
    df=training_labels,
    feature_lookups=feature_lookups,
    label="label"
)

Training and Logging Models

from sklearn.ensemble import GradientBoostingClassifier
import mlflow

# Prepare data
X = training_df.drop(columns=["label"])
y = training_df["label"]

# Train model
model = GradientBoostingClassifier(n_estimators=100, max_depth=5)
model.fit(X, y)

# Log model with feature store metadata
fs.log_model(
    model=model,
    artifact_path="model",
    flavor=mlflow.sklearn,
    training_set=training_set,
    registered_model_name="churn-predictor"
)

Batch Scoring with Feature Store

# Score new customers - features are automatically looked up
customers_to_score = spark.sql("""
    SELECT customer_id FROM production.customers.active_users
""")

# Load model logged with feature store
model_uri = "models:/churn-predictor/Production"

# Score batch - feature lookups happen automatically
scored_df = fs.score_batch(
    model_uri=model_uri,
    df=customers_to_score,
    result_type="float"  # Spark type of the prediction column
)

# Save predictions
scored_df.write.mode("overwrite").saveAsTable("analytics.predictions.churn_scores")

Real-Time Serving

# Publish features to an online store for real-time serving
from databricks.feature_store.online_store_spec import AzureCosmosDBSpec

# Configure online store
online_store_spec = AzureCosmosDBSpec(
    account_uri="https://myaccount.documents.azure.com:443/",
    write_secret_prefix="feature-store/cosmosdb",
    read_secret_prefix="feature-store/cosmosdb",
    database_name="feature_store",
    container_name="customer_features"
)

# Publish to online store
fs.publish_table(
    name="production.features.customer_features",
    online_store=online_store_spec,
    mode="merge"
)

Feature Engineering Patterns

Aggregate Features

def compute_aggregate_features(spark, window_days=(7, 30, 90)):
    """Compute rolling aggregate features"""

    base_query = """
    SELECT
        customer_id,
        SUM(CASE WHEN order_date >= date_sub(current_date(), {days}) THEN 1 ELSE 0 END) as orders_last_{days}d,
        SUM(CASE WHEN order_date >= date_sub(current_date(), {days}) THEN order_amount ELSE 0 END) as spend_last_{days}d,
        AVG(CASE WHEN order_date >= date_sub(current_date(), {days}) THEN order_amount END) as avg_order_last_{days}d
    FROM production.sales.orders
    GROUP BY customer_id
    """

    dfs = []
    for days in window_days:
        df = spark.sql(base_query.format(days=days))
        dfs.append(df)

    # Join all windows
    result = dfs[0]
    for df in dfs[1:]:
        result = result.join(df, "customer_id", "outer")

    return result
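
As a usage sketch (reusing the customer feature table defined earlier), the rolling-window aggregates can be merged back into the feature table like any other update:

# Compute 7/30/90-day aggregates and merge them into the feature table
aggregate_df = compute_aggregate_features(spark, window_days=[7, 30, 90])

fs.write_table(
    name="production.features.customer_features",
    df=aggregate_df,
    mode="merge"
)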

Embedding Features

# Store model embeddings as features
import pandas as pd
from pyspark.sql.functions import pandas_udf, col
from sentence_transformers import SentenceTransformer

def compute_text_embeddings(texts):
    # Note: the model is reloaded on every call; cache it outside the function for large jobs
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model.encode(texts)

# Compute embeddings for product descriptions
@pandas_udf("array<float>")
def embed_text_udf(texts: pd.Series) -> pd.Series:
    embeddings = compute_text_embeddings(texts.tolist())
    return pd.Series([e.tolist() for e in embeddings])

product_embeddings = spark.sql("""
    SELECT product_id, description FROM production.catalog.products
""").withColumn("description_embedding", embed_text_udf(col("description")))

# Store as feature table
fs.create_table(
    name="production.features.product_embeddings",
    primary_keys=["product_id"],
    df=product_embeddings.select("product_id", "description_embedding"),
    description="Product description embeddings using sentence-transformers"
)

Feature Monitoring

def monitor_feature_drift(feature_table, reference_date, current_date):
    """Compare feature distributions between two time periods"""
    import numpy as np
    from scipy import stats

    reference_df = spark.sql(f"""
        SELECT * FROM {feature_table}
        WHERE snapshot_date = '{reference_date}'
    """).toPandas()

    current_df = spark.sql(f"""
        SELECT * FROM {feature_table}
        WHERE snapshot_date = '{current_date}'
    """).toPandas()

    drift_results = {}
    numeric_columns = reference_df.select_dtypes(include=[np.number]).columns

    for feature_col in numeric_columns:
        # KS test for distribution shift between the two snapshots
        ks_stat, p_value = stats.ks_2samp(reference_df[feature_col], current_df[feature_col])
        drift_results[feature_col] = {
            "ks_statistic": ks_stat,
            "p_value": p_value,
            "drift_detected": p_value < 0.05
        }

    return drift_results

# Run drift detection
drift = monitor_feature_drift(
    "production.features.customer_features",
    "2022-02-01",
    "2022-03-01"
)

for feature, result in drift.items():
    if result["drift_detected"]:
        print(f"DRIFT DETECTED: {feature} (p={result['p_value']:.4f})")

Best Practices

Feature Table Design

# Good: Clear, documented feature tables
fs.create_table(
    name="production.features.customer_lifetime_metrics",
    primary_keys=["customer_id"],
    df=features_df,
    description="""
    Customer lifetime metrics computed from order history.
    Updated daily via scheduled job.
    Owner: data-science-team
    """,
    tags={
        "domain": "customer",
        "update_frequency": "daily",
        "data_quality": "production"
    }
)

# Good: Versioned feature definitions
feature_version = "v2"
fs.create_table(
    name=f"production.features.customer_features_{feature_version}",
    ...
)

Avoiding Training-Serving Skew

# Always use feature store for both training and serving
# DON'T: Manually join features for training
# DO: Use create_training_set

# DON'T: Compute features differently at serving time
# DO: Use score_batch or publish to online store

# DON'T: Include raw features in model
# DO: Log model with training_set reference
fs.log_model(
    model=model,
    training_set=training_set,  # Captures feature lineage
    ...
)

Conclusion

The Databricks Feature Store brings discipline and organization to feature engineering:

  • Centralized storage enables discovery and reuse
  • Automatic feature lookups prevent training/serving skew
  • Point-in-time joins ensure correctness
  • Integration with MLflow provides complete lineage

By investing in feature infrastructure, teams can iterate faster, avoid duplicate work, and deploy more reliable models.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.