Back to Blog
2 min read

Model Versioning: Managing AI Model Lifecycle

Effective model versioning ensures reproducibility and enables safe deployments. Here’s how to implement it with Azure Machine Learning and MLflow.

Model Versioning System

import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List, Optional

import mlflow
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Model

@dataclass
class ModelVersion:
    """One model version plus the metadata tracked alongside it.

    Field order is significant: positional construction follows
    ``(name, version, path, metrics, parameters, tags)``.
    """

    # Registered model name.
    name: str
    # Version identifier.
    version: str
    # Artifact location handed to MLflow/Azure ML when registering.
    path: str
    # Metric name -> value (logged with mlflow.log_metrics on registration).
    metrics: Dict
    # Parameter name -> value (logged with mlflow.log_params on registration).
    parameters: Dict
    # Free-form tags attached to the registered model.
    tags: Dict

class ModelVersionManager:
    """Register, fetch, compare, promote and roll back Azure ML model versions.

    Metrics and parameters are logged to an MLflow run and also stored as
    JSON text in the registered model's ``properties`` (Azure ML properties
    are string key/value pairs).

    Stages (e.g. "staging", "production") are tracked with model tags:

    * ``tags["stage"]``: the stage a version currently holds. Promoting a
      version removes this tag from whichever version held that stage before.
    * ``tags["promoted_at"]``: UTC ISO-8601 time of the version's latest
      promotion to any stage.
    * ``tags["<stage>_promoted_at"]``: UTC ISO-8601 time of the version's
      latest promotion to ``<stage>``. ``get_stage_history`` and ``rollback``
      read this tag.
    """

    def __init__(self, ml_client: MLClient):
        """Bind to a workspace and point MLflow at that workspace's tracking server.

        Args:
            ml_client: Client scoped to a workspace (``workspace_name`` must be set).
        """
        self.ml_client = ml_client
        # MLClient has no ``tracking_uri`` attribute; the workspace entity
        # carries the MLflow tracking URI.
        workspace = ml_client.workspaces.get(ml_client.workspace_name)
        mlflow.set_tracking_uri(workspace.mlflow_tracking_uri)

    def register_model(self, model: ModelVersion) -> str:
        """Log ``model`` to a new MLflow run and register it in Azure ML.

        ``model.version`` is not sent; Azure ML assigns the version.

        Args:
            model: Name, artifact path, metrics, parameters and tags to record.

        Returns:
            The version Azure ML assigned to the new registration.
        """
        with mlflow.start_run() as run:
            mlflow.log_params(model.parameters)
            mlflow.log_metrics(model.metrics)
            mlflow.log_artifact(model.path)

            # Register inside the run so that a registration failure also
            # ends the MLflow run as failed. Property values must be strings,
            # so nested dicts are stored as JSON text.
            registered = self.ml_client.models.create_or_update(
                Model(
                    name=model.name,
                    path=model.path,
                    tags=model.tags,
                    properties={
                        "metrics": json.dumps(model.metrics),
                        # MLflow stringifies params as well; default=str keeps
                        # non-JSON values (e.g. Path objects) from failing here.
                        "parameters": json.dumps(model.parameters, default=str),
                        "mlflow_run_id": run.info.run_id,
                    },
                )
            )

        return registered.version

    def get_model(self, name: str, version: str = "latest") -> ModelVersion:
        """Fetch a registered model version.

        Args:
            name: Registered model name.
            version: Explicit version, or "latest" for the newest version.

        Returns:
            A ``ModelVersion`` whose metrics/parameters are decoded from the
            model's properties (empty dicts when absent or unparseable).
        """
        if version == "latest":
            model = self.ml_client.models.get(name, label="latest")
        else:
            model = self.ml_client.models.get(name, version=version)

        properties = model.properties or {}
        return ModelVersion(
            name=model.name,
            version=model.version,
            path=model.path,
            metrics=self._load_json_property(properties, "metrics"),
            parameters=self._load_json_property(properties, "parameters"),
            tags=dict(model.tags or {}),
        )

    def compare_versions(self, name: str, v1: str, v2: str) -> Dict:
        """Compare metrics and parameters of two versions of the same model.

        Returns:
            A dict with ``version1``, ``version2``, ``metric_comparison`` and
            ``parameter_changes``. Each metric entry has ``v1``/``v2`` values
            (``None`` when absent) and ``diff`` = v2 - v1. ``diff`` is
            ``None`` unless both values are numbers; a missing metric is not
            treated as 0.
        """
        model1 = self.get_model(name, v1)
        model2 = self.get_model(name, v2)

        metric_comparison = {}
        for metric in sorted(set(model1.metrics) | set(model2.metrics), key=str):
            before = model1.metrics.get(metric)
            after = model2.metrics.get(metric)
            both_numeric = isinstance(before, (int, float)) and isinstance(after, (int, float))
            metric_comparison[metric] = {
                "v1": before,
                "v2": after,
                "diff": after - before if both_numeric else None,
            }

        return {
            "version1": v1,
            "version2": v2,
            "metric_comparison": metric_comparison,
            "parameter_changes": self.diff_params(model1.parameters, model2.parameters),
        }

    @staticmethod
    def diff_params(old: Dict, new: Dict) -> Dict:
        """Describe how parameter dict ``new`` differs from ``old``.

        Returns:
            ``{"added": {k: new_value}, "removed": {k: old_value},
            "changed": {k: {"v1": old_value, "v2": new_value}}}``.
        """
        return {
            "added": {key: value for key, value in new.items() if key not in old},
            "removed": {key: value for key, value in old.items() if key not in new},
            "changed": {
                key: {"v1": old[key], "v2": value}
                for key, value in new.items()
                if key in old and old[key] != value
            },
        }

    def promote_model(self, name: str, version: str, stage: str) -> None:
        """Promote ``version`` of ``name`` to ``stage`` (e.g. staging, production).

        The target version gets ``stage``, ``promoted_at`` and
        ``<stage>_promoted_at`` tags. Any other version still tagged with
        this stage has its ``stage`` tag removed, so one version holds a
        stage at a time. Earlier versions are not archived, so ``rollback``
        can still find them.
        """
        # Fetch the target first so an unknown version fails before anything changes.
        model = self.ml_client.models.get(name, version=version)
        promoted_at = datetime.now(timezone.utc).isoformat(timespec="microseconds")

        for candidate in self.ml_client.models.list(name=name):
            if str(candidate.version) == str(version):
                continue
            if (candidate.tags or {}).get("stage") != stage:
                continue
            previous = self.ml_client.models.get(name, version=candidate.version)
            # NOTE(review): relies on create_or_update replacing the tag set of an
            # existing version (so a dropped key is removed) - confirm against the SDK.
            previous.tags = {
                key: value for key, value in (previous.tags or {}).items() if key != "stage"
            }
            self.ml_client.models.create_or_update(previous)

        tags = dict(model.tags or {})
        tags["stage"] = stage
        tags["promoted_at"] = promoted_at
        tags[self._stage_history_tag(stage)] = promoted_at
        model.tags = tags
        self.ml_client.models.create_or_update(model)

    def get_stage_history(self, name: str, stage: str) -> List[str]:
        """Return the versions ever promoted to ``stage``, oldest promotion first.

        Each version appears once, positioned by its latest promotion to
        ``stage``. The last entry is the version promoted most recently.
        """
        history_tag = self._stage_history_tag(stage)
        promotions = []
        for candidate in self.ml_client.models.list(name=name):
            stamp = (candidate.tags or {}).get(history_tag)
            if stamp:
                promotions.append((datetime.fromisoformat(stamp), str(candidate.version)))
        promotions.sort()
        return [version for _, version in promotions]

    def rollback(self, name: str, stage: str) -> str:
        """Re-promote the version that held ``stage`` before the current one.

        A rollback is itself a promotion, so calling this twice in a row
        switches back to the version that was rolled back from.

        Returns:
            The version now holding ``stage``.

        Raises:
            ValueError: Fewer than two versions have been promoted to ``stage``.
        """
        history = self.get_stage_history(name, stage)

        if len(history) < 2:
            raise ValueError("No previous version to rollback to")

        previous_version = history[-2]
        self.promote_model(name, previous_version, stage)

        return previous_version

    @staticmethod
    def _stage_history_tag(stage: str) -> str:
        """Tag key recording when a version was last promoted to ``stage``."""
        return f"{stage}_promoted_at"

    @staticmethod
    def _load_json_property(properties: Dict, key: str) -> Dict:
        """Decode a dict stored under ``properties[key]``.

        Accepts JSON text (what ``register_model`` writes) or an existing
        dict. A missing, unparseable or non-dict value yields ``{}``.
        """
        raw = properties.get(key)
        if isinstance(raw, dict):
            return dict(raw)
        if not isinstance(raw, str):
            return {}
        try:
            decoded = json.loads(raw)
        except ValueError:
            return {}
        return decoded if isinstance(decoded, dict) else {}

Robust model versioning enables confident deployments and quick rollbacks.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.