Spot Instances for ML: Cost-Effective Training at Scale
Spot instances offer significant cost savings for ML training, typically 60-90% below on-demand pricing. Today we explore how to use them effectively while handling preemption gracefully.
Understanding Spot Instances
spot_overview = {
    "pricing": "60-90% cheaper than on-demand",
    "risk": "Can be preempted with 30s notice",
    "best_for": [
        "Training jobs with checkpointing",
        "Hyperparameter tuning",
        "Batch processing",
        "Development and experimentation",
    ],
}
Azure Spot Configuration
from azure.ai.ml.entities import AmlCompute

# Spot (low-priority) compute cluster; ml_client is an
# authenticated azure.ai.ml.MLClient for the workspace
spot_cluster = AmlCompute(
    name="spot-gpu-training",
    size="Standard_NC24ads_A100_v4",
    min_instances=0,
    max_instances=8,
    tier="LowPriority",
    idle_time_before_scale_down=120,
)

ml_client.compute.begin_create_or_update(spot_cluster).result()
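Once the cluster exists, training jobs target it by name. A minimal sketch of submitting a job to the spot cluster, assuming a ./src folder with a train.py entry point; the environment name is a placeholder for whichever curated or custom environment fits the workload:

from azure.ai.ml import command

# Hypothetical job submission targeting the spot cluster above
job = command(
    code="./src",                       # placeholder source folder
    command="python train.py --epochs 10",
    environment="my-training-env@latest",  # assumption: any suitable env
    compute="spot-gpu-training",
    display_name="spot-training-run",
)
ml_client.jobs.create_or_update(job)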
Checkpoint-Based Training
import os

import torch
from transformers import Trainer, TrainingArguments


class CheckpointTrainer(Trainer):
    def __init__(self, *args, checkpoint_dir="./checkpoints", **kwargs):
        super().__init__(*args, **kwargs)
        self.checkpoint_dir = checkpoint_dir

    def training_step(self, model, inputs, *args, **kwargs):
        # Extra args keep the override compatible across transformers versions
        loss = super().training_step(model, inputs, *args, **kwargs)
        # Save a checkpoint every 500 steps
        if self.state.global_step % 500 == 0:
            self.save_checkpoint()
        return loss

    def save_checkpoint(self):
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f"checkpoint-{self.state.global_step}"
        )
        self.save_model(checkpoint_path)
        torch.save({
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.lr_scheduler.state_dict(),
            "step": self.state.global_step,
            "epoch": self.state.epoch,
        }, os.path.join(checkpoint_path, "training_state.pt"))

    @classmethod
    def resume_from_checkpoint(cls, checkpoint_path, *args, **kwargs):
        trainer = cls(*args, **kwargs)
        # Load model weights (note: newer transformers versions save
        # model.safetensors instead of pytorch_model.bin)
        trainer.model.load_state_dict(
            torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
        )
        # The optimizer and scheduler are created lazily, so build them
        # before restoring their state (assumes args.max_steps is set)
        trainer.create_optimizer_and_scheduler(
            num_training_steps=trainer.args.max_steps
        )
        state = torch.load(os.path.join(checkpoint_path, "training_state.pt"))
        trainer.optimizer.load_state_dict(state["optimizer"])
        trainer.lr_scheduler.load_state_dict(state["scheduler"])
        trainer.state.global_step = state["step"]
        trainer.state.epoch = state["epoch"]
        return trainer
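Wiring the class into a run looks like any other Trainer setup. A minimal sketch, assuming model and train_dataset are already constructed (both are placeholders); max_steps should be set explicitly because the resume logic above uses it to rebuild the scheduler:

args = TrainingArguments(
    output_dir="./outputs",
    max_steps=10_000,  # resume_from_checkpoint above relies on this
    per_device_train_batch_size=8,
)
trainer = CheckpointTrainer(
    model=model,                  # placeholder: any transformers model
    args=args,
    train_dataset=train_dataset,  # placeholder dataset
    checkpoint_dir="./checkpoints",
)
trainer.train()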
Handling Preemption
import signal
import sys


class PreemptionHandler:
    def __init__(self, trainer, checkpoint_dir):
        self.trainer = trainer
        self.checkpoint_dir = checkpoint_dir
        # Catch SIGTERM, which schedulers typically send before eviction
        signal.signal(signal.SIGTERM, self.handle_preemption)

    def handle_preemption(self, signum, frame):
        print("Preemption signal received. Saving checkpoint...")
        self.trainer.save_checkpoint()
        sys.exit(0)


# Usage
handler = PreemptionHandler(trainer, "./checkpoints")
trainer.train()
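Whether a SIGTERM actually reaches the training process depends on how the job is launched, so on spot VMs it is worth also polling Azure's Scheduled Events endpoint, which advertises a Preempt event roughly 30 seconds before eviction. The endpoint and Metadata header below are Azure's documented Instance Metadata Service interface; the watchdog loop itself is an illustrative sketch:

import time

import requests

SCHEDULED_EVENTS_URL = (
    "http://169.254.169.254/metadata/scheduledevents?api-version=2020-07-01"
)


def preemption_pending():
    # IMDS is only reachable from inside the VM and requires this header
    resp = requests.get(
        SCHEDULED_EVENTS_URL, headers={"Metadata": "true"}, timeout=2
    )
    events = resp.json().get("Events", [])
    return any(e.get("EventType") == "Preempt" for e in events)


# Illustrative watchdog: checkpoint and stop once eviction is scheduled
while True:
    if preemption_pending():
        trainer.save_checkpoint()
        break
    time.sleep(10)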
Automatic Resume Job
# Training script with auto-resume
def main():
    # Check for an existing checkpoint
    checkpoint_dir = "./checkpoints"
    latest_checkpoint = find_latest_checkpoint(checkpoint_dir)

    if latest_checkpoint:
        print(f"Resuming from {latest_checkpoint}")
        trainer = CheckpointTrainer.resume_from_checkpoint(latest_checkpoint, ...)
    else:
        print("Starting fresh training")
        trainer = CheckpointTrainer(...)

    try:
        trainer.train()
    except Exception as e:
        print(f"Training interrupted: {e}")
        trainer.save_checkpoint()
        raise


def find_latest_checkpoint(checkpoint_dir):
    if not os.path.exists(checkpoint_dir):
        return None
    checkpoints = [
        d for d in os.listdir(checkpoint_dir)
        if d.startswith("checkpoint-")
    ]
    if not checkpoints:
        return None
    # Directories are named checkpoint-<step>; pick the highest step
    latest = max(checkpoints, key=lambda x: int(x.split("-")[1]))
    return os.path.join(checkpoint_dir, latest)
Cost Analysis
def calculate_spot_savings(job_hours, vm_size):
    # Illustrative USD/hour rates; actual Azure prices vary by region
    # and over time, and spot prices fluctuate with demand
    pricing = {
        "Standard_NC6s_v3": {"ondemand": 0.90, "spot": 0.27},
        "Standard_NC24ads_A100_v4": {"ondemand": 3.67, "spot": 1.10},
        "Standard_ND96asr_v4": {"ondemand": 27.20, "spot": 8.16},
    }
    if vm_size not in pricing:
        return None

    ondemand_cost = pricing[vm_size]["ondemand"] * job_hours
    spot_cost = pricing[vm_size]["spot"] * job_hours

    # Account for potential restarts (assume 20% overhead)
    spot_cost_with_overhead = spot_cost * 1.2

    return {
        "ondemand_cost": f"${ondemand_cost:.2f}",
        "spot_cost": f"${spot_cost_with_overhead:.2f}",
        "savings": f"${ondemand_cost - spot_cost_with_overhead:.2f}",
        "savings_percent": f"{(1 - spot_cost_with_overhead / ondemand_cost) * 100:.0f}%",
    }


# Example: 100-hour training job on an A100 VM
print(calculate_spot_savings(100, "Standard_NC24ads_A100_v4"))
# {'ondemand_cost': '$367.00', 'spot_cost': '$132.00', 'savings': '$235.00', 'savings_percent': '64%'}
Best Practices
spot_best_practices = {
    "checkpointing": {
        "frequency": "Every 10-30 minutes",
        "storage": "Use Azure Blob for durability (see sketch below)",
        "cleanup": "Remove old checkpoints to save storage",
    },
    "job_design": {
        "modular": "Break work into smaller jobs",
        "resumable": "Always support resume from checkpoint",
        "idempotent": "Jobs should be safely restartable",
    },
    "monitoring": {
        "alerts": "Set up preemption alerts",
        "logging": "Log checkpoint saves",
        "metrics": "Track training progress externally",
    },
}
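Local disks vanish with the VM, so a checkpoint only counts once it reaches durable storage. A minimal sketch of pushing a checkpoint directory to Blob Storage, assuming a connection string in the environment (the container name is a placeholder):

import os

from azure.storage.blob import BlobServiceClient


def upload_checkpoint(checkpoint_path, container="checkpoints"):
    # Assumes AZURE_STORAGE_CONNECTION_STRING is set in the environment
    service = BlobServiceClient.from_connection_string(
        os.environ["AZURE_STORAGE_CONNECTION_STRING"]
    )
    client = service.get_container_client(container)
    for root, _, files in os.walk(checkpoint_path):
        for name in files:
            local_path = os.path.join(root, name)
            # Preserve the checkpoint-<step>/... layout in blob names
            blob_name = os.path.relpath(
                local_path, os.path.dirname(checkpoint_path)
            )
            with open(local_path, "rb") as f:
                client.upload_blob(blob_name, f, overwrite=True)

Inside an Azure ML job there is an even simpler route: writing checkpoints under the job's ./outputs directory persists them to the workspace's blob store automatically.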
Tomorrow we’ll explore cost optimization strategies for AI workloads.