Skip to content
Back to Blog
1 min read

Spark ML Patterns for Production Systems

This post collects practical, production-minded patterns for Spark ML: feature engineering, model training and tuning, model evaluation, pipeline persistence and versioning, and batch inference.

Feature Engineering Patterns

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import *
from pyspark.sql.functions import *

class FeatureEngineeringPipeline:
    """Build reusable Spark ML feature-engineering stages.

    Each ``create_*_pipeline`` method returns a list of *unfitted* stages that
    can be concatenated and passed to ``pyspark.ml.Pipeline(stages=...)``, so
    training and inference apply exactly the same transformations.
    ``create_datetime_features`` is a plain DataFrame transformation, not a
    pipeline stage.
    """

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def create_numeric_pipeline(
        self,
        numeric_cols: list,
        strategy: str = "mean",
        output_col: str = "numeric_scaled",
    ) -> list:
        """Return stages that impute, assemble and standardize numeric columns.

        Args:
            numeric_cols: Names of the numeric input columns.
            strategy: ``Imputer`` strategy (e.g. ``"mean"`` or ``"median"``).
            output_col: Name of the final standardized vector column.

        Returns:
            ``[Imputer, VectorAssembler, StandardScaler]``. Intermediate
            columns are ``<col>_imputed`` and ``numeric_assembled``.

        Raises:
            ValueError: If ``numeric_cols`` is empty (the stages would
                otherwise only fail later, at fit time).
        """
        if not numeric_cols:
            raise ValueError("numeric_cols must contain at least one column")

        imputed_cols = [f"{c}_imputed" for c in numeric_cols]

        return [
            # Impute first: VectorAssembler rejects nulls by default.
            Imputer(
                inputCols=list(numeric_cols),
                outputCols=imputed_cols,
                strategy=strategy,
            ),
            VectorAssembler(
                inputCols=imputed_cols,
                outputCol="numeric_assembled",
            ),
            # withMean=True centres each feature, which yields dense vectors.
            StandardScaler(
                inputCol="numeric_assembled",
                outputCol=output_col,
                withMean=True,
                withStd=True,
            ),
        ]

    def create_categorical_pipeline(
        self,
        categorical_cols: list,
        max_categories: int = 100,
    ) -> list:
        """Return index + one-hot stages for each categorical column.

        For every column ``c`` this emits a ``StringIndexer`` writing
        ``c_indexed`` (most frequent value gets index 0; unseen values are
        kept as an extra index) and a ``OneHotEncoder`` writing ``c_encoded``.

        Args:
            categorical_cols: Names of the string columns to encode.
            max_categories: Currently unused: every column is one-hot encoded
                regardless of its cardinality. Kept for interface
                compatibility.

        Returns:
            A flat list of ``StringIndexer``/``OneHotEncoder`` pairs.
        """
        stages = []
        # Loop variable is named ``name`` so it does not shadow the
        # ``col`` function imported from pyspark.sql.functions.
        for name in categorical_cols:
            indexed = f"{name}_indexed"
            stages.append(
                StringIndexer(
                    inputCol=name,
                    outputCol=indexed,
                    handleInvalid="keep",
                    stringOrderType="frequencyDesc",
                )
            )
            stages.append(
                OneHotEncoder(
                    inputCols=[indexed],
                    outputCols=[f"{name}_encoded"],
                    dropLast=True,
                    handleInvalid="keep",
                )
            )
        return stages

    def create_text_pipeline(
        self,
        text_col: str,
        output_col: str = "text_features",
        num_features: int = 10000,
        min_doc_freq: int = 5,
    ) -> list:
        """Return tokenize -> stop-word removal -> TF-IDF stages for a text column.

        Args:
            text_col: Name of the raw text column.
            output_col: Name of the final TF-IDF vector column.
            num_features: Hashing space size for ``HashingTF``.
            min_doc_freq: Minimum number of documents a term must appear in
                for ``IDF`` to assign it a non-zero weight.

        Returns:
            ``[Tokenizer, StopWordsRemover, HashingTF, IDF]`` with
            intermediate columns ``<text_col>_tokens``, ``_filtered``, ``_tf``.
        """
        tokens_col = f"{text_col}_tokens"
        filtered_col = f"{text_col}_filtered"
        tf_col = f"{text_col}_tf"

        return [
            Tokenizer(inputCol=text_col, outputCol=tokens_col),
            StopWordsRemover(inputCol=tokens_col, outputCol=filtered_col),
            HashingTF(
                inputCol=filtered_col,
                outputCol=tf_col,
                numFeatures=num_features,
            ),
            IDF(
                inputCol=tf_col,
                outputCol=output_col,
                minDocFreq=min_doc_freq,
            ),
        ]

    def create_datetime_features(self, df, datetime_col: str):
        """Return ``df`` with calendar features derived from ``datetime_col``.

        Adds ``<col>_year``, ``_month``, ``_day``, ``_dow`` (1 = Sunday),
        ``_hour``, ``_week`` and ``_quarter``; all existing columns are kept.
        """
        return df.select(
            "*",
            year(datetime_col).alias(f"{datetime_col}_year"),
            month(datetime_col).alias(f"{datetime_col}_month"),
            dayofmonth(datetime_col).alias(f"{datetime_col}_day"),
            dayofweek(datetime_col).alias(f"{datetime_col}_dow"),
            hour(datetime_col).alias(f"{datetime_col}_hour"),
            weekofyear(datetime_col).alias(f"{datetime_col}_week"),
            quarter(datetime_col).alias(f"{datetime_col}_quarter"),
        )

Model Training Patterns

from pyspark.ml.classification import *
from pyspark.ml.regression import *
from pyspark.ml.tuning import *
from pyspark.ml.evaluation import *

class ModelTrainingPipeline:
    """Hyper-parameter tuning helpers and ready-made estimator/grid pairs."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def train_with_cross_validation(
        self,
        df,
        estimator,
        param_grid,
        evaluator,
        num_folds: int = 5,
        parallelism: int = 4,
        seed: int = None,
    ):
        """Tune ``estimator`` over ``param_grid`` with k-fold cross-validation.

        Args:
            df: Training DataFrame.
            estimator: Any Spark ML estimator (or Pipeline).
            param_grid: List of param maps, e.g. from ``ParamGridBuilder``.
            evaluator: Evaluator whose metric selects the best model.
            num_folds: Number of folds.
            parallelism: Number of models fitted concurrently.
            seed: Optional seed for the fold assignment; when ``None`` Spark's
                default seed is used.

        Returns:
            dict with ``best_model``, ``avg_metrics`` (one per param map) and
            ``std_metrics`` (``None`` on Spark versions without it).
        """
        cv = CrossValidator(
            estimator=estimator,
            estimatorParamMaps=param_grid,
            evaluator=evaluator,
            numFolds=num_folds,
            parallelism=parallelism,
            collectSubModels=False,
        )
        # Only set the seed when given: passing seed=None explicitly would
        # store a None param value instead of keeping Spark's default.
        if seed is not None:
            cv.setSeed(seed)

        cv_model = cv.fit(df)

        return {
            "best_model": cv_model.bestModel,
            "avg_metrics": cv_model.avgMetrics,
            "std_metrics": getattr(cv_model, "stdMetrics", None),
        }

    def train_with_train_validation_split(
        self,
        df,
        estimator,
        param_grid,
        evaluator,
        train_ratio: float = 0.8,
        parallelism: int = 4,
        seed: int = None,
    ):
        """Tune with a single train/validation split (cheaper than k-fold CV).

        Args:
            df: Training DataFrame.
            estimator: Any Spark ML estimator (or Pipeline).
            param_grid: List of param maps, e.g. from ``ParamGridBuilder``.
            evaluator: Evaluator whose metric selects the best model.
            train_ratio: Fraction of ``df`` used for training; the rest
                validates.
            parallelism: Number of models fitted concurrently.
            seed: Optional seed for the random split.

        Returns:
            dict with ``best_model`` and ``validation_metrics`` (one per param
            map).
        """
        tvs = TrainValidationSplit(
            estimator=estimator,
            estimatorParamMaps=param_grid,
            evaluator=evaluator,
            trainRatio=train_ratio,
            parallelism=parallelism,
        )
        if seed is not None:
            tvs.setSeed(seed)

        tvs_model = tvs.fit(df)

        return {
            "best_model": tvs_model.bestModel,
            "validation_metrics": tvs_model.validationMetrics,
        }

    def create_gradient_boosted_trees(
        self,
        label_col: str,
        feature_col: str = "features",
    ):
        """Return a ``GBTClassifier`` and a 27-combination param grid.

        Grid: maxDepth x maxIter x stepSize (3 values each). With 5-fold CV
        this means 135 fits, so budget accordingly.
        """
        gbt = GBTClassifier(
            featuresCol=feature_col,
            labelCol=label_col,
            predictionCol="prediction",
        )

        param_grid = (
            ParamGridBuilder()
            .addGrid(gbt.maxDepth, [3, 5, 7])
            .addGrid(gbt.maxIter, [10, 20, 50])
            .addGrid(gbt.stepSize, [0.05, 0.1, 0.2])
            .build()
        )

        return gbt, param_grid

    def create_random_forest(
        self,
        label_col: str,
        feature_col: str = "features",
    ):
        """Return a ``RandomForestClassifier`` and a 27-combination param grid.

        Grid: numTrees x maxDepth x featureSubsetStrategy (3 values each).
        """
        rf = RandomForestClassifier(
            featuresCol=feature_col,
            labelCol=label_col,
            predictionCol="prediction",
        )

        param_grid = (
            ParamGridBuilder()
            .addGrid(rf.numTrees, [50, 100, 200])
            .addGrid(rf.maxDepth, [5, 10, 15])
            .addGrid(rf.featureSubsetStrategy, ["sqrt", "log2", "onethird"])
            .build()
        )

        return rf, param_grid

    def create_logistic_regression(
        self,
        label_col: str,
        feature_col: str = "features",
    ):
        """Return a ``LogisticRegression`` and a 27-combination param grid.

        Grid: regParam x elasticNetParam (0 = L2, 1 = L1) x maxIter.
        """
        lr = LogisticRegression(
            featuresCol=feature_col,
            labelCol=label_col,
            predictionCol="prediction",
            probabilityCol="probability",
        )

        param_grid = (
            ParamGridBuilder()
            .addGrid(lr.regParam, [0.01, 0.1, 1.0])
            .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
            .addGrid(lr.maxIter, [50, 100, 200])
            .build()
        )

        return lr, param_grid

Model Evaluation Patterns

class ModelEvaluation:
    """Compute evaluation metrics and diagnostic tables from prediction DataFrames."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def evaluate_binary_classification(
        self,
        predictions,
        label_col: str,
        probability_col: str = "probability",
    ) -> dict:
        """Return AUC and classification metrics for a binary classifier.

        AUC metrics are computed from the ``rawPrediction`` column; the
        others use the ``prediction`` column. ``precision``/``recall`` are the
        *weighted* averages over both classes, not positive-class precision
        and recall.

        Each metric launches its own Spark job, so cache ``predictions``
        beforehand if it is expensive to compute.

        Args:
            predictions: Output of ``model.transform``.
            label_col: Name of the true-label column.
            probability_col: Currently unused; kept for interface
                compatibility.

        Returns:
            dict with keys auc_roc, auc_pr, accuracy, f1, precision, recall.
        """
        from pyspark.ml.evaluation import (
            BinaryClassificationEvaluator,
            MulticlassClassificationEvaluator,
        )

        binary_evaluator = BinaryClassificationEvaluator(
            labelCol=label_col,
            rawPredictionCol="rawPrediction",
        )
        multi_evaluator = MulticlassClassificationEvaluator(
            labelCol=label_col,
            predictionCol="prediction",
        )

        binary_metrics = {"auc_roc": "areaUnderROC", "auc_pr": "areaUnderPR"}
        multi_metrics = {
            "accuracy": "accuracy",
            "f1": "f1",
            "precision": "weightedPrecision",
            "recall": "weightedRecall",
        }

        metrics = {
            key: binary_evaluator.evaluate(
                predictions, {binary_evaluator.metricName: metric}
            )
            for key, metric in binary_metrics.items()
        }
        metrics.update(
            {
                key: multi_evaluator.evaluate(
                    predictions, {multi_evaluator.metricName: metric}
                )
                for key, metric in multi_metrics.items()
            }
        )
        return metrics

    def compute_confusion_matrix(
        self,
        predictions,
        label_col: str,
        prediction_col: str = "prediction",
    ):
        """Return a confusion matrix as a DataFrame.

        One row per true label (ordered by label) and one column per
        predicted value; each cell holds the number of rows with that
        (label, prediction) pair, 0 when the pair never occurs.
        """
        confusion = predictions.groupBy(label_col, prediction_col).count()

        matrix = confusion.groupBy(label_col).pivot(prediction_col).sum("count")

        # Pivot leaves null for pairs that never occur; report them as 0.
        # The label column itself is excluded so a null label group stays null.
        count_cols = [c for c in matrix.columns if c != label_col]
        return matrix.fillna(0, subset=count_cols).orderBy(label_col)

    def evaluate_regression(
        self,
        predictions,
        label_col: str,
        prediction_col: str = "prediction",
    ) -> dict:
        """Return rmse, mse, mae and r2 for a regression model's predictions."""
        from pyspark.ml.evaluation import RegressionEvaluator

        evaluator = RegressionEvaluator(
            labelCol=label_col,
            predictionCol=prediction_col,
        )

        return {
            name: evaluator.evaluate(predictions, {evaluator.metricName: name})
            for name in ("rmse", "mse", "mae", "r2")
        }

    def compute_lift_chart(
        self,
        predictions,
        label_col: str,
        probability_col: str = "probability",
        num_buckets: int = 10,
    ):
        """Return per-bucket lift-chart data for a binary classifier.

        Rows are ranked by the positive-class probability (element 1 of the
        probability vector) and split into ``num_buckets`` equal-sized buckets
        (bucket 1 = highest scores).

        Returns:
            DataFrame with columns decile, count, positives, avg_score and
            lift = (positives / count) / overall positive rate. ``lift`` is
            null when there are no positive labels (or no rows), since the base
            rate is then zero or undefined.
        """
        from pyspark.ml.functions import vector_to_array
        from pyspark.sql import functions as F
        from pyspark.sql.window import Window

        # Assumes a binary probability vector where index 1 is the positive class.
        with_prob = predictions.withColumn(
            "positive_prob",
            vector_to_array(F.col(probability_col))[1],
        )

        # NOTE: a window with orderBy and no partitionBy moves all rows into a
        # single partition; fine for moderate volumes, a bottleneck for huge ones.
        ranking = Window.orderBy(F.col("positive_prob").desc())
        with_decile = with_prob.withColumn(
            "decile", F.ntile(num_buckets).over(ranking)
        )

        lift_data = (
            with_decile.groupBy("decile")
            .agg(
                F.count("*").alias("count"),
                F.sum(F.col(label_col)).alias("positives"),
                F.avg("positive_prob").alias("avg_score"),
            )
            .orderBy("decile")
        )

        # Overall base rate, computed in one pass instead of two count() jobs.
        totals = predictions.agg(
            F.count(F.lit(1)).alias("total_count"),
            F.sum(F.when(F.col(label_col) == 1, 1).otherwise(0)).alias(
                "total_positives"
            ),
        ).first()
        total_count = totals["total_count"]
        total_positives = totals["total_positives"] or 0

        # Per-bucket (not cumulative) lift; undefined when the base rate is 0.
        if total_positives == 0:
            lift = F.lit(None).cast("double")
        else:
            base_rate = total_positives / total_count
            lift = (F.col("positives") / F.col("count")) / F.lit(base_rate)

        return lift_data.withColumn("lift", lift)

Pipeline Persistence

class PipelinePersistence:
    """Save, load and register versioned Spark ML pipeline models.

    Layout: ``<base_path>/<name>/v<version>/model`` holds the Spark model and
    ``.../metadata`` holds a one-line JSON text file describing it.
    """

    def __init__(self, spark: SparkSession, base_path: str):
        self.spark = spark
        self.base_path = base_path

    def save_pipeline(
        self,
        pipeline_model,
        name: str,
        version: str,
        metadata: dict = None,
    ):
        """Save ``pipeline_model`` and a JSON metadata record under a versioned path.

        Re-saving the same ``name``/``version`` overwrites both the model and
        its metadata.

        Args:
            pipeline_model: A fitted model exposing ``write()``.
            name: Pipeline name (directory under ``base_path``).
            version: Version string, e.g. ``"1.0"``; stored as ``v<version>``.
            metadata: Extra JSON-serializable fields. They are merged last, so
                they override the built-in keys if names collide.
        """
        import json
        from datetime import datetime, timezone

        path = f"{self.base_path}/{name}/v{version}"

        pipeline_model.write().overwrite().save(f"{path}/model")

        meta = {
            "name": name,
            "version": version,
            # Timezone-aware UTC timestamp (datetime.utcnow() is naive and deprecated).
            "saved_at": datetime.now(timezone.utc).isoformat(),
            "spark_version": self.spark.version,
            **(metadata or {}),
        }

        # RDD.saveAsTextFile fails if the directory already exists, which made
        # re-saving a version crash after the model had already been
        # overwritten. A DataFrame text write supports overwrite; coalesce(1)
        # avoids a spray of empty part files.
        (
            self.spark.createDataFrame([(json.dumps(meta),)], ["value"])
            .coalesce(1)
            .write.mode("overwrite")
            .text(f"{path}/metadata")
        )

        print(f"Pipeline saved to {path}")

    def load_pipeline(self, name: str, version: str = "latest"):
        """Load a ``PipelineModel`` by name and version (``"latest"`` = highest)."""
        from pyspark.ml import PipelineModel

        if version == "latest":
            version = self._get_latest_version(name)

        path = f"{self.base_path}/{name}/v{version}"
        return PipelineModel.load(f"{path}/model")

    def _get_latest_version(self, name: str) -> str:
        """Return the highest numeric version saved for ``name``.

        Versions are compared numerically per dot-separated component, so
        ``"1.10"`` > ``"1.9"``. Entries that are not ``v<int>[.<int>...]`` are
        ignored. Uses ``os.listdir``, so ``base_path`` must be reachable
        through the local filesystem (e.g. a mounted path), not a URI such as
        ``s3://...``.

        Raises:
            FileNotFoundError: If no version directories exist for ``name``.
        """
        import os

        path = f"{self.base_path}/{name}"

        candidates = []
        # os.listdir returns plain strings (the old code called .name on them).
        for entry in os.listdir(path):
            if not entry.startswith("v"):
                continue
            version = entry[1:]  # strip only the leading "v"
            try:
                sort_key = tuple(int(part) for part in version.split("."))
            except ValueError:
                continue
            candidates.append((sort_key, version))

        if not candidates:
            raise FileNotFoundError(
                f"No saved versions of pipeline '{name}' found under {path}"
            )

        return max(candidates)[1]

    def register_model(
        self,
        pipeline_model,
        name: str,
        metrics: dict,
        stage: str = "staging",
    ):
        """Log ``pipeline_model`` to MLflow, register it, and move it to ``stage``.

        Args:
            pipeline_model: Fitted Spark ML model.
            name: Registered model name.
            metrics: Metric name -> numeric value, logged on the run.
            stage: Target registry stage for the newly registered version.

        Raises:
            RuntimeError: If the version registered by this run cannot be found.
        """
        import mlflow
        from mlflow.spark import log_model

        with mlflow.start_run() as run:
            mlflow.log_metrics(metrics)

            log_model(
                pipeline_model,
                artifact_path="model",
                registered_model_name=name,
            )

            client = mlflow.tracking.MlflowClient()
            # get_latest_versions(name) returns the newest version *per stage*
            # in no guaranteed order, so [0] may not be the version just
            # registered. Look up the version created by this run instead.
            run_id = run.info.run_id
            versions = client.search_model_versions(
                f"name='{name}' and run_id='{run_id}'"
            )
            if not versions:
                raise RuntimeError(
                    f"No registered version of model '{name}' found for run {run_id}"
                )
            new_version = max(versions, key=lambda mv: int(mv.version)).version

            client.transition_model_version_stage(
                name=name,
                version=new_version,
                stage=stage,
            )

Batch Inference Pattern

class BatchInference:
    """Score tables with a fitted Spark ML model and write the results back."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def run_batch_inference(
        self,
        model,
        input_table: str,
        output_table: str,
        feature_pipeline=None,
        partition_cols: list = None,
    ):
        """Score ``input_table`` and overwrite ``output_table`` with the results.

        The ``features``, ``rawPrediction`` and ``probability`` vector columns
        are dropped before writing.

        Args:
            model: Fitted model/PipelineModel exposing ``transform``.
            input_table: Source table name.
            output_table: Destination table (overwritten).
            feature_pipeline: Optional fitted transformer applied before ``model``.
            partition_cols: Optional columns to partition the output by; they
                must exist in the scored output.

        Returns:
            Number of rows in ``output_table`` after the write.
        """
        df = self.spark.table(input_table)

        if feature_pipeline is not None:
            df = feature_pipeline.transform(df)

        predictions = model.transform(df)

        dropped = {"features", "rawPrediction", "probability"}
        predictions = predictions.select(
            [c for c in predictions.columns if c not in dropped]
        )

        writer = predictions.write.mode("overwrite")
        if partition_cols:
            writer = writer.partitionBy(*partition_cols)
        writer.saveAsTable(output_table)

        # Count the written table rather than calling predictions.count(),
        # which would re-run the whole feature + inference plan a second time.
        return self.spark.table(output_table).count()

    def incremental_inference(
        self,
        model,
        source_table: str,
        target_table: str,
        watermark_col: str,
        last_watermark,
    ):
        """Score only rows with ``watermark_col > last_watermark`` and append them.

        Rows whose watermark is null never satisfy the filter. The caller is
        responsible for tracking and advancing ``last_watermark``.

        Returns:
            Number of prediction rows appended to ``target_table`` (0 if there
            was no new data).
        """
        from pyspark.sql.functions import col

        new_rows = self.spark.table(source_table).filter(
            col(watermark_col) > last_watermark
        )

        # head(1) stops after the first matching row instead of counting all.
        if not new_rows.head(1):
            print("No new data to process")
            return 0

        # Persist so the returned count reuses the rows produced by the write
        # instead of re-running the source scan and inference (which could also
        # pick up rows that arrived after the append).
        predictions = model.transform(new_rows).persist()
        try:
            predictions.write.mode("append").saveAsTable(target_table)
            return predictions.count()
        finally:
            predictions.unpersist()

# Usage
from pyspark.ml import PipelineModel  # not covered by `from pyspark.ml import Pipeline`

# Reuses the active session (e.g. in a notebook) or creates one.
spark = SparkSession.builder.getOrCreate()
inference = BatchInference(spark)

# Load model
model = PipelineModel.load("/models/customer_churn/v1.0/model")

# Run batch inference.
# NOTE: partition columns must exist in the scored output, so
# "prediction_date" is assumed to be a column of silver.customer_features.
count = inference.run_batch_inference(
    model,
    input_table="silver.customer_features",
    output_table="gold.churn_predictions",
    partition_cols=["prediction_date"],
)
print(f"Processed {count} records")

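Spark ML provides battle-tested building blocks for production machine learning. From feature engineering to model persistence, these patterns help keep ML systems reliable at scale.

## Takeaways

- Express feature engineering as pipeline stages so training and inference apply identical transformations.
- Use `TrainValidationSplit` for fast tuning, and `CrossValidator` when you can afford k times the training cost for a more reliable estimate.
- Version saved pipelines and store metadata (Spark version, metrics, timestamp) next to each model.
- For growing source tables, prefer watermark-based incremental inference over rescoring everything.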

Michael John Pena

Michael John Pena

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.