diff --git a/claas/modal/worker.py b/claas/modal/worker.py index 1cfcb78..63bf5b4 100644 --- a/claas/modal/worker.py +++ b/claas/modal/worker.py @@ -90,7 +90,7 @@ def distill(self, request: DistillBatchRequestPayload) -> DistillResponse: Distillation response payload. """ try: - return self.trainer.distill(request) + return self.trainer.distill(request).response finally: self.trainer.offload_base_model() diff --git a/claas/training/distillation.py b/claas/training/distillation.py index b64c964..13fa63e 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -11,6 +11,13 @@ import torch from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput +from claas.training.engine.local.cache import ( + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, + cpu_optimizer_state, + gpu_optimizer_state, +) from claas.training.sdpo_loss import compute_sdpo_loss from claas.training.storage import ( cleanup_local_lora, @@ -86,6 +93,10 @@ def load_base_model(self) -> None: self.optimizer_cls = torch.optim.AdamW self.functional = torch.nn.functional + def reload_base_model(self) -> None: + """Move base model from CPU back to CUDA.""" + self.base_model.to(self.device) # type: ignore[arg-type] # functools.wraps confuses ty + def offload_base_model(self) -> None: """Move base model to CPU and release CUDA memory.""" @@ -129,6 +140,33 @@ def _load_or_create_lora(self, lora_path: str) -> "PeftModel | PeftMixedModel": ) return get_peft_model(self.base_model, lora_config) + def _load_lora_from_cache( + self, + cached: LoraCacheEntry, + ) -> "PeftModel | PeftMixedModel": + """Restore a LoRA adapter from a CPU cache entry. + + Args: + cached: CPU-resident snapshot of adapter state. + + Returns: + Trainable PEFT model with cached weights loaded. + """ + from peft import LoraConfig, get_peft_model, set_peft_model_state_dict + + cfg = cached.adapter_config + lora_config = LoraConfig( + r=cfg.r, + lora_alpha=cfg.lora_alpha, + target_modules=cfg.target_modules, + lora_dropout=cfg.lora_dropout, + bias=cfg.bias, + task_type=cfg.task_type, + ) + model = get_peft_model(self.base_model, lora_config) + set_peft_model_state_dict(model, cached.lora_state_dict) + return model + def _load_optimizer_state( self, lora_path: str, @@ -213,14 +251,58 @@ def _build_self_teacher_topk( torch.cuda.empty_cache() return top_logprobs, top_indices - def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: + def _build_cache_entry( + self, + model: "PeftModel | PeftMixedModel", + optimizer: "torch.optim.Optimizer", + ) -> LoraCacheEntry: + """Snapshot current model + optimizer state into a CPU-resident cache entry.""" + from peft import PeftModel as PeftModelCls + + peft_config = model.peft_config["default"] + adapter_config = LoraAdapterConfig( + r=peft_config.r, + lora_alpha=peft_config.lora_alpha, + target_modules=list(peft_config.target_modules), + lora_dropout=peft_config.lora_dropout, + bias=peft_config.bias, + task_type=peft_config.task_type, + ) + + # Determine state dict — use PEFT's adapter-only extraction if available + if isinstance(model, PeftModelCls): + from peft import get_peft_model_state_dict + + raw_state = get_peft_model_state_dict(model) + else: + raw_state = model.state_dict() + + lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} + opt_state = cpu_optimizer_state(optimizer.state_dict()) + + return LoraCacheEntry( + lora_state_dict=lora_state, + optimizer_state_dict=opt_state, + adapter_config=adapter_config, + ) + + def distill( + self, + payload: DistillBatchRequestPayload, + *, + cached: LoraCacheEntry | None = None, + ) -> DistillStepResult: """Run one SDPO distillation step. Args: payload: Distillation request payload. + cached: When provided, skip disk reads and load LoRA + optimizer + state from this CPU-resident cache entry. When ``None``, + load from disk (cold start). Returns: - Distillation response with metrics. + Result containing both the distillation response and a cache + entry for the post-step state. """ torch.cuda.empty_cache() @@ -231,10 +313,18 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: if len(payload.samples) == 0: raise ValueError("samples must contain at least one item") - lora_local_path = load_lora(payload.lora_id) + # Disk path (cold start) or cache path + lora_local_path: str | None = None + if cached is None: + lora_local_path = load_lora(payload.lora_id) + try: try: - model = self._load_or_create_lora(lora_local_path) + if cached is not None: + model = self._load_lora_from_cache(cached) + else: + assert lora_local_path is not None + model = self._load_or_create_lora(lora_local_path) model.train() model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False}, @@ -249,7 +339,13 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: betas=(0.9, 0.999), weight_decay=0.01, ) - self._load_optimizer_state(lora_local_path, optimizer) + + if cached is not None: + optimizer.load_state_dict( + gpu_optimizer_state(cached.optimizer_state_dict, self.device) + ) + elif lora_local_path is not None: + self._load_optimizer_state(lora_local_path, optimizer) batch_loss_tensors: list[torch.Tensor] = [] batch_distill_loss: list[float] = [] @@ -362,10 +458,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: clip_fraction = sum(batch_clip_fraction) / len(batch_clip_fraction) grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm) + cache_entry = self._build_cache_entry(model, optimizer) + del model, optimizer, batch_loss_tensors torch.cuda.empty_cache() - return DistillResponse.model_validate( + response = DistillResponse.model_validate( { "lora_id": new_lora_id, "metadata": { @@ -380,5 +478,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: }, } ) + return DistillStepResult(response=response, cache_entry=cache_entry) finally: - cleanup_local_lora(lora_local_path) + if lora_local_path is not None: + cleanup_local_lora(lora_local_path) diff --git a/claas/training/engine/local/cache.py b/claas/training/engine/local/cache.py new file mode 100644 index 0000000..309e47d --- /dev/null +++ b/claas/training/engine/local/cache.py @@ -0,0 +1,85 @@ +"""Typed cache structures and helpers for CPU-resident LoRA state between training steps.""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import cast + +import torch + +from claas.core.types import DistillResponse + + +@dataclass(frozen=True, slots=True) +class LoraAdapterConfig: + """Typed representation of LoRA adapter configuration.""" + + r: int + lora_alpha: int + target_modules: list[str] + lora_dropout: float + bias: str + task_type: str + + +@dataclass(frozen=True, slots=True) +class LoraCacheEntry: + """CPU-resident snapshot of LoRA adapter state between training steps.""" + + lora_state_dict: dict[str, torch.Tensor] + optimizer_state_dict: dict[str, object] + adapter_config: LoraAdapterConfig + + +@dataclass(frozen=True, slots=True) +class DistillStepResult: + """Result of a distillation step with both response and cache entry.""" + + response: DistillResponse + cache_entry: LoraCacheEntry + + +def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to CPU.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + cpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + cpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + cpu_param[k] = v.detach().cpu().clone() + else: + cpu_param[k] = copy.deepcopy(v) + cpu_states[param_id] = cpu_param + result[key] = cpu_states + else: + result[key] = copy.deepcopy(value) + return result + + +def gpu_optimizer_state( + state_dict: dict[str, object], + device: torch.device, +) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to a target device.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + gpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + gpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + gpu_param[k] = v.detach().to(device).clone() + else: + gpu_param[k] = copy.deepcopy(v) + gpu_states[param_id] = gpu_param + result[key] = gpu_states + else: + result[key] = copy.deepcopy(value) + return result diff --git a/claas/training/engine/local/engine.py b/claas/training/engine/local/engine.py index 0533127..4f850ed 100644 --- a/claas/training/engine/local/engine.py +++ b/claas/training/engine/local/engine.py @@ -3,7 +3,9 @@ from __future__ import annotations import asyncio +import logging import re +import threading from claas.core.config import LocalConfig from claas.core.types import ( @@ -20,6 +22,7 @@ ) from claas.training.distillation import DistillationTrainer from claas.training.engine.base import TrainingEngine +from claas.training.engine.local.cache import LoraCacheEntry from claas.training.storage import ( configure_storage_backend, create_initial_lora, @@ -31,14 +34,34 @@ resolve_lora_id, ) +logger = logging.getLogger(__name__) + class LocalTrainingEngine(TrainingEngine): """Executes training and LoRA operations on local infrastructure.""" + _trainer: DistillationTrainer + _lora_cache: dict[str, LoraCacheEntry] + _cache_lock: threading.Lock + _model_loaded: bool + def __init__(self, cfg: LocalConfig) -> None: configure_storage_backend("local_fs") self._base_model_id = cfg.base_model_id self._attn_implementation = cfg.attn_implementation + self._trainer = DistillationTrainer( + base_model_id=cfg.base_model_id, + attn_implementation=cfg.attn_implementation, + ) + self._lora_cache = {} + self._cache_lock = threading.Lock() + self._model_loaded = False + + async def _ensure_model_loaded(self) -> None: + """One-time base model load on first distill() call.""" + if not self._model_loaded: + await asyncio.to_thread(self._trainer.load_base_model) + self._model_loaded = True async def distill( self, @@ -52,15 +75,24 @@ async def distill( Returns: Distillation response. """ - trainer = DistillationTrainer( - base_model_id=self._base_model_id, - attn_implementation=self._attn_implementation, - ) - await asyncio.to_thread(trainer.load_base_model) + await self._ensure_model_loaded() + await asyncio.to_thread(self._trainer.reload_base_model) + + resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id) + with self._cache_lock: + cached = self._lora_cache.get(resolved_id) + try: - return await asyncio.to_thread(trainer.distill, payload) + result = await asyncio.to_thread( + self._trainer.distill, payload, cached=cached + ) finally: - await asyncio.to_thread(trainer.offload_base_model) + await asyncio.to_thread(self._trainer.offload_base_model) + + with self._cache_lock: + self._lora_cache[resolved_id] = result.cache_entry + + return result.response async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: """Initialize a LoRA adapter locally. @@ -82,7 +114,11 @@ async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: return LoraInitResponse(lora_id=lora_id) async def delete_lora(self, lora_id: str) -> LoraDeleteResponse: + resolved_id = await asyncio.to_thread(resolve_lora_id, lora_id) deleted = await asyncio.to_thread(delete_lora, lora_id) + if deleted: + with self._cache_lock: + self._lora_cache.pop(resolved_id, None) return LoraDeleteResponse(deleted=deleted) async def list_loras(self, prefix: str) -> LoraListResponse: diff --git a/tests/test_distillation_optimizer_state.py b/tests/test_distillation_optimizer_state.py index ecce841..cdeebd0 100644 --- a/tests/test_distillation_optimizer_state.py +++ b/tests/test_distillation_optimizer_state.py @@ -7,6 +7,12 @@ torch = pytest.importorskip("torch") from claas.training.distillation import DistillationTrainer # noqa: E402 +from claas.training.engine.local.cache import ( # noqa: E402 + LoraAdapterConfig, + LoraCacheEntry, + cpu_optimizer_state, + gpu_optimizer_state, +) class _SimpleLoraModel(torch.nn.Module): @@ -90,3 +96,81 @@ def test_optimizer_state_missing_gracefully_skips(trainer: DistillationTrainer, trainer._load_optimizer_state(str(tmp_path), optimizer) assert len(optimizer.state) == 0 + + +def testcpu_optimizer_state_moves_tensors_to_cpu() -> None: + """cpu_optimizer_state produces a state dict with all tensors on CPU.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + cpu_state = cpu_optimizer_state(original) + + for param_state in cpu_state["state"].values(): + for v in param_state.values(): + if isinstance(v, torch.Tensor): + assert v.device == torch.device("cpu") + + +def test_cpugpu_optimizer_state_roundtrip() -> None: + """cpu_optimizer_state / gpu_optimizer_state round-trip preserves values.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + cpu_state = cpu_optimizer_state(original) + roundtripped = gpu_optimizer_state(cpu_state, torch.device("cpu")) + + # Step counts match + for param_id in original["state"]: + assert roundtripped["state"][param_id]["step"] == original["state"][param_id]["step"] + + # Tensor values match + for param_id in original["state"]: + for key in ("exp_avg", "exp_avg_sq"): + orig_tensor = original["state"][param_id][key] + rt_tensor = roundtripped["state"][param_id][key] + assert torch.equal(orig_tensor, rt_tensor) + + +def testcpu_optimizer_state_does_not_mutate_original() -> None: + """cpu_optimizer_state deep-copies — mutating the copy leaves the original intact.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + original_exp_avg = original["state"][0]["exp_avg"].clone() + + cpu_state = cpu_optimizer_state(original) + # Mutate the copy + cpu_state["state"][0]["exp_avg"].zero_() + + # Original is unchanged + assert torch.equal(original["state"][0]["exp_avg"], original_exp_avg) + + +def test_lora_cache_entry_is_frozen() -> None: + """LoraCacheEntry is immutable — attribute assignment raises.""" + entry = LoraCacheEntry( + lora_state_dict={"w": torch.zeros(2)}, + optimizer_state_dict={"state": {}, "param_groups": []}, + adapter_config=LoraAdapterConfig( + r=8, + lora_alpha=16, + target_modules=["q_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ), + ) + with pytest.raises(AttributeError): + entry.lora_state_dict = {} # type: ignore[misc] diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py index 0a2d764..8648fc5 100644 --- a/tests/test_local_training_engine.py +++ b/tests/test_local_training_engine.py @@ -13,49 +13,154 @@ DistillResponse, TrainingConfig, ) +from claas.training.engine.local.cache import ( # noqa: E402 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) from claas.training.engine.local.engine import LocalTrainingEngine # noqa: E402 +_DUMMY_CACHE_ENTRY = LoraCacheEntry( + lora_state_dict={"w": torch.zeros(2)}, + optimizer_state_dict={"state": {}, "param_groups": []}, + adapter_config=LoraAdapterConfig( + r=8, + lora_alpha=16, + target_modules=["q_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ), +) + + +def _make_payload(lora_id: str = "user/model") -> DistillBatchRequestPayload: + return DistillBatchRequestPayload( + lora_id=lora_id, + training=TrainingConfig(), + samples=[ + DistillBatchItem( + prompt="p", + response="r", + feedback="f", + response_logprobs=[-0.1], + prompt_token_ids=[1, 2], + response_token_ids=[3], + user_prompt="p", + ) + ], + ) + class _Trainer: + """Fake trainer that records method calls.""" + def __init__(self, base_model_id: str, attn_implementation: str): self.base_model_id = base_model_id self.attn_implementation = attn_implementation + self.load_base_model_count = 0 + self.reload_count = 0 + self.offload_count = 0 + self.distill_calls: list[dict] = [] def load_base_model(self) -> None: - return None + self.load_base_model_count += 1 - def distill(self, _payload: DistillBatchRequestPayload) -> DistillResponse: - return DistillResponse(lora_id="user/model", metadata={}) + def reload_base_model(self) -> None: + self.reload_count += 1 + + def distill( + self, + _payload: DistillBatchRequestPayload, + *, + cached: LoraCacheEntry | None = None, + ) -> DistillStepResult: + self.distill_calls.append({"cached": cached}) + return DistillStepResult( + response=DistillResponse(lora_id="user/model", metadata={}), + cache_entry=_DUMMY_CACHE_ENTRY, + ) + + def offload_base_model(self) -> None: + self.offload_count += 1 + + +class _FailingOffloadTrainer(_Trainer): + """Trainer whose offload raises to test error propagation.""" def offload_base_model(self) -> None: raise OSError("cleanup failed") -def test_local_engine_distill_propagates_cleanup_error(monkeypatch): +def _build_engine(monkeypatch, trainer_cls=_Trainer): + from claas.training.engine.local import engine as local_engine + + monkeypatch.setattr(local_engine, "DistillationTrainer", trainer_cls) + monkeypatch.setattr(local_engine, "resolve_lora_id", lambda lid: lid.strip("/")) + cfg = LocalConfig(base_model_id="Qwen/Qwen3-8B", attn_implementation="sdpa") + return LocalTrainingEngine(cfg) + + +def test_trainer_created_eagerly_in_init(monkeypatch): + """Trainer is created in __init__, not lazily on first distill().""" + engine = _build_engine(monkeypatch) + assert isinstance(engine._trainer, _Trainer) + assert engine._model_loaded is False + + +def test_load_base_model_called_once(monkeypatch): + """load_base_model is called exactly once across multiple distill() calls.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + asyncio.run(engine.distill(_make_payload())) + + assert engine._trainer.load_base_model_count == 1 + + +def test_reload_called_every_distill(monkeypatch): + """reload_base_model is called on every distill() call.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + asyncio.run(engine.distill(_make_payload())) + + assert engine._trainer.reload_count == 2 + + +def test_cache_miss_then_hit(monkeypatch): + """First call has cached=None, second call uses the cached entry.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + # First call: no cache + assert engine._trainer.distill_calls[0]["cached"] is None + + asyncio.run(engine.distill(_make_payload())) + # Second call: cache hit + assert engine._trainer.distill_calls[1]["cached"] is _DUMMY_CACHE_ENTRY + + +def test_cache_evicted_on_delete(monkeypatch): + """delete_lora() evicts the cache entry for that lora_id.""" from claas.training.engine.local import engine as local_engine - monkeypatch.setenv("CLAAS_BASE_MODEL_ID", "Qwen/Qwen3-8B") - monkeypatch.setenv("CLAAS_ATTN_IMPLEMENTATION", "sdpa") monkeypatch.setattr(local_engine, "DistillationTrainer", _Trainer) + monkeypatch.setattr(local_engine, "resolve_lora_id", lambda lid: lid.strip("/")) + monkeypatch.setattr(local_engine, "delete_lora", lambda lid: True) cfg = LocalConfig(base_model_id="Qwen/Qwen3-8B", attn_implementation="sdpa") + engine = LocalTrainingEngine(cfg) + + asyncio.run(engine.distill(_make_payload())) + assert "user/model" in engine._lora_cache + + asyncio.run(engine.delete_lora("user/model")) + assert "user/model" not in engine._lora_cache + + +def test_offload_error_propagates(monkeypatch): + """Errors from offload_base_model propagate to the caller.""" + engine = _build_engine(monkeypatch, trainer_cls=_FailingOffloadTrainer) with pytest.raises(OSError, match="cleanup failed"): - asyncio.run( - LocalTrainingEngine(cfg).distill( - DistillBatchRequestPayload( - lora_id="user/model", - training=TrainingConfig(), - samples=[ - DistillBatchItem( - prompt="p", - response="r", - feedback="f", - response_logprobs=[-0.1], - prompt_token_ids=[1, 2], - response_token_ids=[3], - user_prompt="p", - ) - ], - ) - ) - ) + asyncio.run(engine.distill(_make_payload()))