-
Notifications
You must be signed in to change notification settings - Fork 1
Persistent trainer + CPU cache for local engine #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,13 @@ | |
| import torch | ||
|
|
||
| from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput | ||
| from claas.training.engine.local.cache import ( | ||
| DistillStepResult, | ||
| LoraAdapterConfig, | ||
| LoraCacheEntry, | ||
| cpu_optimizer_state, | ||
| gpu_optimizer_state, | ||
| ) | ||
| from claas.training.sdpo_loss import compute_sdpo_loss | ||
| from claas.training.storage import ( | ||
| cleanup_local_lora, | ||
|
|
@@ -86,6 +93,10 @@ def load_base_model(self) -> None: | |
| self.optimizer_cls = torch.optim.AdamW | ||
| self.functional = torch.nn.functional | ||
|
|
||
| def reload_base_model(self) -> None: | ||
| """Move base model from CPU back to CUDA.""" | ||
| self.base_model.to(self.device) # type: ignore[arg-type] # functools.wraps confuses ty | ||
|
|
||
| def offload_base_model(self) -> None: | ||
| """Move base model to CPU and release CUDA memory.""" | ||
|
|
||
|
|
@@ -129,6 +140,33 @@ def _load_or_create_lora(self, lora_path: str) -> "PeftModel | PeftMixedModel": | |
| ) | ||
| return get_peft_model(self.base_model, lora_config) | ||
|
|
||
| def _load_lora_from_cache( | ||
| self, | ||
| cached: LoraCacheEntry, | ||
| ) -> "PeftModel | PeftMixedModel": | ||
| """Restore a LoRA adapter from a CPU cache entry. | ||
|
|
||
| Args: | ||
| cached: CPU-resident snapshot of adapter state. | ||
|
|
||
| Returns: | ||
| Trainable PEFT model with cached weights loaded. | ||
| """ | ||
| from peft import LoraConfig, get_peft_model, set_peft_model_state_dict | ||
|
|
||
| cfg = cached.adapter_config | ||
| lora_config = LoraConfig( | ||
| r=cfg.r, | ||
| lora_alpha=cfg.lora_alpha, | ||
| target_modules=cfg.target_modules, | ||
| lora_dropout=cfg.lora_dropout, | ||
| bias=cfg.bias, | ||
| task_type=cfg.task_type, | ||
| ) | ||
| model = get_peft_model(self.base_model, lora_config) | ||
| set_peft_model_state_dict(model, cached.lora_state_dict) | ||
| return model | ||
|
|
||
| def _load_optimizer_state( | ||
| self, | ||
| lora_path: str, | ||
|
|
@@ -213,14 +251,58 @@ def _build_self_teacher_topk( | |
| torch.cuda.empty_cache() | ||
| return top_logprobs, top_indices | ||
|
|
||
| def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: | ||
| def _build_cache_entry( | ||
| self, | ||
| model: "PeftModel | PeftMixedModel", | ||
| optimizer: "torch.optim.Optimizer", | ||
| ) -> LoraCacheEntry: | ||
| """Snapshot current model + optimizer state into a CPU-resident cache entry.""" | ||
| from peft import PeftModel as PeftModelCls | ||
|
|
||
| peft_config = model.peft_config["default"] | ||
| adapter_config = LoraAdapterConfig( | ||
| r=peft_config.r, | ||
| lora_alpha=peft_config.lora_alpha, | ||
| target_modules=list(peft_config.target_modules), | ||
| lora_dropout=peft_config.lora_dropout, | ||
| bias=peft_config.bias, | ||
| task_type=peft_config.task_type, | ||
| ) | ||
|
|
||
| # Determine state dict — use PEFT's adapter-only extraction if available | ||
| if isinstance(model, PeftModelCls): | ||
| from peft import get_peft_model_state_dict | ||
|
|
||
| raw_state = get_peft_model_state_dict(model) | ||
| else: | ||
| raw_state = model.state_dict() | ||
|
|
||
| lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} | ||
| opt_state = cpu_optimizer_state(optimizer.state_dict()) | ||
|
|
||
| return LoraCacheEntry( | ||
| lora_state_dict=lora_state, | ||
| optimizer_state_dict=opt_state, | ||
| adapter_config=adapter_config, | ||
| ) | ||
|
|
||
| def distill( | ||
| self, | ||
| payload: DistillBatchRequestPayload, | ||
| *, | ||
| cached: LoraCacheEntry | None = None, | ||
| ) -> DistillStepResult: | ||
| """Run one SDPO distillation step. | ||
|
|
||
| Args: | ||
| payload: Distillation request payload. | ||
| cached: When provided, skip disk reads and load LoRA + optimizer | ||
| state from this CPU-resident cache entry. When ``None``, | ||
| load from disk (cold start). | ||
|
|
||
| Returns: | ||
| Distillation response with metrics. | ||
| Result containing both the distillation response and a cache | ||
| entry for the post-step state. | ||
| """ | ||
|
|
||
| torch.cuda.empty_cache() | ||
|
|
@@ -231,10 +313,18 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: | |
| if len(payload.samples) == 0: | ||
| raise ValueError("samples must contain at least one item") | ||
|
|
||
| lora_local_path = load_lora(payload.lora_id) | ||
| # Disk path (cold start) or cache path | ||
| lora_local_path: str | None = None | ||
| if cached is None: | ||
| lora_local_path = load_lora(payload.lora_id) | ||
|
|
||
| try: | ||
| try: | ||
| model = self._load_or_create_lora(lora_local_path) | ||
| if cached is not None: | ||
| model = self._load_lora_from_cache(cached) | ||
| else: | ||
| assert lora_local_path is not None | ||
| model = self._load_or_create_lora(lora_local_path) | ||
| model.train() | ||
| model.gradient_checkpointing_enable( | ||
| gradient_checkpointing_kwargs={"use_reentrant": False}, | ||
|
|
@@ -249,7 +339,13 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: | |
| betas=(0.9, 0.999), | ||
| weight_decay=0.01, | ||
| ) | ||
| self._load_optimizer_state(lora_local_path, optimizer) | ||
|
|
||
| if cached is not None: | ||
| optimizer.load_state_dict( | ||
| gpu_optimizer_state(cached.optimizer_state_dict, self.device) | ||
| ) | ||
| elif lora_local_path is not None: | ||
| self._load_optimizer_state(lora_local_path, optimizer) | ||
|
|
||
| batch_loss_tensors: list[torch.Tensor] = [] | ||
| batch_distill_loss: list[float] = [] | ||
|
|
@@ -362,10 +458,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: | |
| clip_fraction = sum(batch_clip_fraction) / len(batch_clip_fraction) | ||
| grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm) | ||
|
|
||
| cache_entry = self._build_cache_entry(model, optimizer) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This call unconditionally clones LoRA and optimizer state to CPU every distill step, but not all callers use that cache result (the Modal worker immediately returns Useful? React with 👍 / 👎. |
||
|
|
||
| del model, optimizer, batch_loss_tensors | ||
| torch.cuda.empty_cache() | ||
|
|
||
| return DistillResponse.model_validate( | ||
| response = DistillResponse.model_validate( | ||
| { | ||
| "lora_id": new_lora_id, | ||
| "metadata": { | ||
|
|
@@ -380,5 +478,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: | |
| }, | ||
| } | ||
| ) | ||
| return DistillStepResult(response=response, cache_entry=cache_entry) | ||
| finally: | ||
| cleanup_local_lora(lora_local_path) | ||
| if lora_local_path is not None: | ||
| cleanup_local_lora(lora_local_path) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| """Typed cache structures and helpers for CPU-resident LoRA state between training steps.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import copy | ||
| from dataclasses import dataclass | ||
| from typing import cast | ||
|
|
||
| import torch | ||
|
|
||
| from claas.core.types import DistillResponse | ||
|
|
||
|
|
||
| @dataclass(frozen=True, slots=True) | ||
| class LoraAdapterConfig: | ||
| """Typed representation of LoRA adapter configuration.""" | ||
|
|
||
| r: int | ||
| lora_alpha: int | ||
| target_modules: list[str] | ||
| lora_dropout: float | ||
| bias: str | ||
| task_type: str | ||
|
|
||
|
|
||
| @dataclass(frozen=True, slots=True) | ||
| class LoraCacheEntry: | ||
| """CPU-resident snapshot of LoRA adapter state between training steps.""" | ||
|
|
||
| lora_state_dict: dict[str, torch.Tensor] | ||
| optimizer_state_dict: dict[str, object] | ||
| adapter_config: LoraAdapterConfig | ||
|
|
||
|
|
||
| @dataclass(frozen=True, slots=True) | ||
| class DistillStepResult: | ||
| """Result of a distillation step with both response and cache entry.""" | ||
|
|
||
| response: DistillResponse | ||
| cache_entry: LoraCacheEntry | ||
|
|
||
|
|
||
| def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: | ||
| """Deep-copy optimizer state with all tensors moved to CPU.""" | ||
| result: dict[str, object] = {} | ||
| for key, value in state_dict.items(): | ||
| if key == "state": | ||
| param_states = cast("dict[int, dict[str, object]]", value) | ||
| cpu_states: dict[int, dict[str, object]] = {} | ||
| for param_id, param_state in param_states.items(): | ||
| cpu_param: dict[str, object] = {} | ||
| for k, v in param_state.items(): | ||
| if isinstance(v, torch.Tensor): | ||
| cpu_param[k] = v.detach().cpu().clone() | ||
| else: | ||
| cpu_param[k] = copy.deepcopy(v) | ||
| cpu_states[param_id] = cpu_param | ||
| result[key] = cpu_states | ||
| else: | ||
| result[key] = copy.deepcopy(value) | ||
| return result | ||
|
|
||
|
|
||
| def gpu_optimizer_state( | ||
| state_dict: dict[str, object], | ||
| device: torch.device, | ||
| ) -> dict[str, object]: | ||
| """Deep-copy optimizer state with all tensors moved to a target device.""" | ||
| result: dict[str, object] = {} | ||
| for key, value in state_dict.items(): | ||
| if key == "state": | ||
| param_states = cast("dict[int, dict[str, object]]", value) | ||
| gpu_states: dict[int, dict[str, object]] = {} | ||
| for param_id, param_state in param_states.items(): | ||
| gpu_param: dict[str, object] = {} | ||
| for k, v in param_state.items(): | ||
| if isinstance(v, torch.Tensor): | ||
| gpu_param[k] = v.detach().to(device).clone() | ||
| else: | ||
| gpu_param[k] = copy.deepcopy(v) | ||
| gpu_states[param_id] = gpu_param | ||
| result[key] = gpu_states | ||
| else: | ||
| result[key] = copy.deepcopy(value) | ||
| return result |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import re | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claas.core.config import LocalConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claas.core.types import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -20,6 +22,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claas.training.distillation import DistillationTrainer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claas.training.engine.base import TrainingEngine | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claas.training.engine.local.cache import LoraCacheEntry | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claas.training.storage import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| configure_storage_backend, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| create_initial_lora, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -31,14 +34,34 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resolve_lora_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class LocalTrainingEngine(TrainingEngine): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Executes training and LoRA operations on local infrastructure.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _trainer: DistillationTrainer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _lora_cache: dict[str, LoraCacheEntry] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _cache_lock: threading.Lock | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _model_loaded: bool | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, cfg: LocalConfig) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| configure_storage_backend("local_fs") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._base_model_id = cfg.base_model_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._attn_implementation = cfg.attn_implementation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._trainer = DistillationTrainer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| base_model_id=cfg.base_model_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attn_implementation=cfg.attn_implementation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._lora_cache = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._cache_lock = threading.Lock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._model_loaded = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _ensure_model_loaded(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """One-time base model load on first distill() call.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self._model_loaded: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await asyncio.to_thread(self._trainer.load_base_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._model_loaded = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+60
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Between the An 🔧 Proposed fixAdd to self._load_lock = asyncio.Lock()Then: async def _ensure_model_loaded(self) -> None:
"""One-time base model load on first distill() call."""
- if not self._model_loaded:
- await asyncio.to_thread(self._trainer.load_base_model)
- self._model_loaded = True
+ async with self._load_lock:
+ if not self._model_loaded:
+ await asyncio.to_thread(self._trainer.load_base_model)
+ self._model_loaded = True🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def distill( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -52,15 +75,24 @@ async def distill( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Distillation response. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trainer = DistillationTrainer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| base_model_id=self._base_model_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attn_implementation=self._attn_implementation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await asyncio.to_thread(trainer.load_base_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await self._ensure_model_loaded() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await asyncio.to_thread(self._trainer.reload_base_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with self._cache_lock: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cached = self._lora_cache.get(resolved_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return await asyncio.to_thread(trainer.distill, payload) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = await asyncio.to_thread( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._trainer.distill, payload, cached=cached | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await asyncio.to_thread(trainer.offload_base_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await asyncio.to_thread(self._trainer.offload_base_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with self._cache_lock: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._lora_cache[resolved_id] = result.cache_entry | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result.response | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+81
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Find save_in_place usage in the codebase
rg -n 'save_in_place' --type py -C 2Repository: kfallah/CLaaS Length of output: 1283 🏁 Script executed: #!/bin/bash
# Read the distillation.py file around the save_in_place logic to see what happens in both branches
sed -n '485,510p' claas/training/distillation.pyRepository: kfallah/CLaaS Length of output: 1235 🏁 Script executed: #!/bin/bash
# Find the definitions of save_lora and save_lora_inplace functions
rg -n 'def save_lora' --type py -A 15Repository: kfallah/CLaaS Length of output: 2048 🏁 Script executed: #!/bin/bash
# Read the full save_lora function implementation
sed -n '292,345p' claas/training/storage.pyRepository: kfallah/CLaaS Length of output: 1645 🏁 Script executed: #!/bin/bash
# Read the save_lora_inplace function implementation
sed -n '346,390p' claas/training/storage.pyRepository: kfallah/CLaaS Length of output: 1718 🏁 Script executed: #!/bin/bash
# Read the rest of save_lora_inplace function to see the return statement
sed -n '390,400p' claas/training/storage.pyRepository: kfallah/CLaaS Length of output: 308 🏁 Script executed: #!/bin/bash
# Find where DistillResponse is created and how new_lora_id is used
sed -n '508,525p' claas/training/distillation.pyRepository: kfallah/CLaaS Length of output: 815 🏁 Script executed: #!/bin/bash
# Check if save_in_place=False is actually used anywhere (excluding tests)
rg -n 'save_in_place\s*=\s*False' --type pyRepository: kfallah/CLaaS Length of output: 39 🏁 Script executed: #!/bin/bash
# Check if there are any other callers of distill() that might use save_in_place=False
rg -n '\.distill\(' --type py -B 5 -A 2Repository: kfallah/CLaaS Length of output: 8135 Key cache by output lora_id instead of input to handle non-in-place saves correctly.
The current cache key (line 93) uses the input
While Cache key fix with self._cache_lock:
- self._lora_cache[resolved_id] = result.cache_entry
+ new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
+ self._lora_cache[new_resolved] = result.cache_entry📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Initialize a LoRA adapter locally. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -82,7 +114,11 @@ async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return LoraInitResponse(lora_id=lora_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def delete_lora(self, lora_id: str) -> LoraDeleteResponse: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resolved_id = await asyncio.to_thread(resolve_lora_id, lora_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| deleted = await asyncio.to_thread(delete_lora, lora_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if deleted: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with self._cache_lock: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._lora_cache.pop(resolved_id, None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return LoraDeleteResponse(deleted=deleted) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def list_loras(self, prefix: str) -> LoraListResponse: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 550
🏁 Script executed:
# Also check the imports in the distillation.py file head -50 claas/training/distillation.pyRepository: kfallah/CLaaS
Length of output: 1452
🌐 Web query:
PEFT LoraConfig task_type attribute return type TaskType enum or string💡 Result:
In Hugging Face PEFT,
LoraConfig.task_typeis not limited to theTaskTypeenum.PeftConfigMixin) documentstask_typeasUnion[TaskType, str](optionallyNone). (huggingface.co)TaskType.SEQ_2_SEQ_LM(enum) (huggingface.co) or a string like"SEQ_2_SEQ_LM". (huggingface.co)So, the effective “return type” you should expect when reading
config.task_typeisTaskType | str | None(depending on how it was created/loaded). (huggingface.co)Citations:
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 98
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 173
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 740
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 39
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 76
🌐 Web query:
PEFT LoraConfig bias attribute type enum string representation💡 Result:
In Hugging Face PEFT
peft.LoraConfig, thebiasattribute is a string constrained to specific values (typed as aLiteralin the docs):bias: Literal["none", "all", "lora_only"] = "none"[1]"none": don’t train any bias terms"all": train all model biases"lora_only": train only the biases belonging to LoRA layers [1][2]So the “enum” is effectively these three string representations:
"none","all","lora_only"(default"none"). [1][2]Sources: [1] [2]
🌐 Web query:
PEFT 0.18.1 LoraConfig TaskType task_type attribute type enum💡 Result:
In PEFT 0.18.1,
LoraConfig.task_typeis typed as:Optional[Union[str, TaskType]](so you can pass either aTaskTypeenum value or the equivalent string). (huggingface.co)TaskTypeis astr-backedenum.Enumwith these values:TaskType.SEQ_CLS = "SEQ_CLS"TaskType.SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"TaskType.CAUSAL_LM = "CAUSAL_LM"TaskType.TOKEN_CLS = "TOKEN_CLS"TaskType.QUESTION_ANS = "QUESTION_ANS"TaskType.FEATURE_EXTRACTION = "FEATURE_EXTRACTION"(raw.githubusercontent.com)Citations:
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 3726
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 738
Coerce
task_typeenum to string for type consistencypeft_config.task_typemay be aTaskTypeenum value (e.g.TaskType.CAUSAL_LM), butLoraAdapterConfig.task_typeis annotated asstr. While this works at runtime becauseTaskTypeis a string-backed enum, the type mismatch should be corrected by extracting the string value.Note:
biasis already a string per PEFT 0.18.1 and does not need coercion.🔧 Proposed fix
adapter_config = LoraAdapterConfig( r=peft_config.r, lora_alpha=peft_config.lora_alpha, target_modules=list(peft_config.target_modules), lora_dropout=peft_config.lora_dropout, bias=peft_config.bias, - task_type=peft_config.task_type, + task_type=peft_config.task_type.value if hasattr(peft_config.task_type, "value") else str(peft_config.task_type), )🤖 Prompt for AI Agents