
Persistent trainer + CPU cache for local engine#32

Open
kfallah wants to merge 2 commits into main from persistent-trainer-cpu-cache

Conversation

@kfallah
Owner

@kfallah kfallah commented Feb 23, 2026

Summary

  • Persistent trainer: LocalTrainingEngine now creates the DistillationTrainer once in __init__ and reuses it across distill() calls, eliminating redundant base model loads after the first call
  • CPU LoRA cache: After each training step, LoRA adapter weights and optimizer state are snapshotted to CPU memory (LoraCacheEntry). Subsequent calls for the same lora_id skip all disk I/O (load_lora / load_optimizer_state)
  • Typed cache structures: New claas/training/cache.py with frozen @dataclass types (LoraCacheEntry, LoraAdapterConfig, DistillStepResult) — no loose dicts, no optional types where invariants can be enforced
  • GPU memory guarantee preserved: The existing offload_base_model() + del model, optimizer + cuda.empty_cache() pattern is unchanged. Cache holds CPU-only tensors by construction
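
The caching behavior described above can be sketched with plain stdlib types. This is an illustrative model only — the real LoraCacheEntry holds CPU torch tensors plus an adapter config, and the real engine runs an actual training step between cache read and write:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LoraCacheEntry:
    """Illustrative stand-in; the PR's entry holds CPU tensors, not plain dicts."""
    lora_state_dict: dict
    optimizer_state_dict: dict


class LocalEngineSketch:
    def __init__(self) -> None:
        self._lora_cache: dict[str, LoraCacheEntry] = {}
        self.disk_loads = 0  # instrumentation: counts simulated disk I/O

    def distill(self, lora_id: str) -> LoraCacheEntry:
        cached = self._lora_cache.get(lora_id)
        if cached is None:
            # Cache miss: simulate load_lora / load_optimizer_state from disk
            self.disk_loads += 1
            cached = LoraCacheEntry({"weights": [0.0]}, {"step": 0})
        # ... training step would run here on GPU, then snapshot back to CPU ...
        entry = LoraCacheEntry(dict(cached.lora_state_dict), dict(cached.optimizer_state_dict))
        self._lora_cache[lora_id] = entry
        return entry
```

The second distill() call for the same lora_id hits the in-memory entry and performs no disk I/O, which is the behavior the tests below exercise as "cache miss→hit".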

Test plan

  • tests/test_distillation_optimizer_state.py — _cpu_optimizer_state / _gpu_optimizer_state round-trip, deep-copy isolation, LoraCacheEntry immutability
  • tests/test_local_training_engine.py — eager trainer creation, one-time load_base_model, per-call reload_base_model, cache miss→hit, cache eviction on delete, offload error propagation
  • Full test suite passes: uv run pytest tests/ -v -m "not integration" (103 passed, 25 skipped for torch)
  • Lint clean: uv run ruff check claas/ tests/
  • Type check: uv run ty check (only pre-existing unresolved-import errors for torch/peft/transformers)

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced in-memory caching mechanism for LoRA adapter state and optimizer configuration between training steps
    • Added optimizer state serialization and device placement utilities for improved GPU/CPU memory management
  • Refactor

    • Enhanced training engine to support cache-driven training workflows, reducing redundant model loading and state reconstruction

Keep the DistillationTrainer and base model across distill() calls instead
of recreating them each time. Cache LoRA adapter weights and optimizer state
on CPU between steps so the second call for a given lora_id skips all disk
I/O. GPU memory is still fully released after each step via the existing
offload_base_model() pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai

coderabbitai bot commented Feb 23, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

📝 Walkthrough

The PR introduces CPU-resident caching for LoRA adapter states between distillation training steps. New immutable data structures capture LoRA configuration, optimizer states, and distillation results. The distillation trainer and local engine are updated to optionally load from cache with thread-safe management and proper device placement for optimizer state tensors.

Changes

  • Cache Infrastructure — claas/training/cache.py: Added three new frozen dataclasses: LoraAdapterConfig (LoRA hyperparameters), LoraCacheEntry (LoRA and optimizer state snapshot), and DistillStepResult (response with cache entry) to support immutable, serializable caching structures.
  • Distillation Trainer — claas/training/distillation.py: Enhanced the distill method to accept an optional cached parameter; added optimizer state utilities (_cpu_optimizer_state, _gpu_optimizer_state); introduced LoRA loading from cache (_load_lora_from_cache) and cache entry building (_build_cache_entry); changed the return type to DistillStepResult; added a public reload_base_model method.
  • Local Training Engine — claas/training/engine/local/engine.py: Implemented thread-safe LoRA caching with _lora_cache and _cache_lock; updated the distill flow to reload the base model, resolve the LoRA ID, fetch/update the cache under the lock, and return the response; modified delete_lora to evict cache entries; added one-time model loading via _ensure_model_loaded.
  • Worker Integration — claas/modal/worker.py: Modified the distill return value to extract the response attribute from DistillationResponse, now returning the inner payload instead of the wrapper.
  • Test Coverage — tests/test_distillation_optimizer_state.py, tests/test_local_training_engine.py: Added tests for optimizer state serialization/deserialization, cache immutability, and caching behavior including cache hit/miss, eviction via delete_lora, and offload error propagation.
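
The locking discipline summarized above — dict access under _cache_lock, eviction in delete_lora — can be sketched as follows. The class and entry type here are illustrative stand-ins for the engine, not its actual API:

```python
import threading


class CacheSketch:
    """Stand-in for the engine's _lora_cache / _cache_lock pair."""

    def __init__(self) -> None:
        self._lora_cache: dict[str, object] = {}
        self._cache_lock = threading.Lock()

    def fetch(self, lora_id: str):
        # Hold the lock only for the dict access, never across a training step
        with self._cache_lock:
            return self._lora_cache.get(lora_id)

    def store(self, lora_id: str, entry: object) -> None:
        with self._cache_lock:
            self._lora_cache[lora_id] = entry

    def delete_lora(self, lora_id: str) -> None:
        # Evict so a recreated adapter with the same id starts fresh from disk
        with self._cache_lock:
            self._lora_cache.pop(lora_id, None)
```

Keeping the critical sections this small is what lets the GPU-bound distill step run outside the lock.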

Sequence Diagram

sequenceDiagram
    participant Client as Client Request
    participant Engine as LocalTrainingEngine
    participant Cache as LoRA Cache
    participant Trainer as DistillationTrainer
    participant Model as LoRA Model
    participant Optimizer as Optimizer State

    Client->>Engine: distill(lora_id, request)
    Engine->>Engine: _ensure_model_loaded()
    Engine->>Model: reload_base_model()
    Engine->>Cache: fetch cached entry for lora_id
    
    alt Cache Hit
        Cache-->>Engine: LoraCacheEntry
        Engine->>Trainer: distill(payload, cached=entry)
        Trainer->>Model: _load_lora_from_cache(entry)
        Trainer->>Optimizer: _gpu_optimizer_state(cached.optimizer_state_dict)
    else Cache Miss
        Engine->>Trainer: distill(payload, cached=None)
        Trainer->>Model: load LoRA from disk
        Trainer->>Optimizer: initialize optimizer
    end
    
    Trainer->>Trainer: train()
    Trainer->>Model: get trained model state
    Trainer->>Optimizer: get optimizer state
    Trainer->>Trainer: _build_cache_entry()
    Trainer-->>Engine: DistillStepResult(response, cache_entry)
    Engine->>Cache: update cache[lora_id] = cache_entry
    Engine->>Model: offload_base_model()
    Engine-->>Client: result.response

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 Hop, hop! Cache entries now stay,
LoRA states snapped away,
No disk reads slow the train,
Optimizer states ordained,
Thread-safe caches light the way! 🌟

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 72.97%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title clearly summarizes the main architectural changes: adding a persistent trainer instance and CPU-based LoRA caching to the local engine, which are the core objectives of this PR.


@@ -0,0 +1,38 @@
"""Typed cache structures for CPU-resident LoRA state between training steps."""
Owner Author

should this live in claas/training/engine/local/cache.py?

Owner Author

Moved to claas/training/engine/local/cache.py along with the optimizer state helpers in 0747c76. These types and helpers are only used by the local engine's CPU caching path, so they belong in its namespace.

Comment generated by Claude Code

logger = logging.getLogger(__name__)


def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
Owner Author

Should this live in the local engine imp?

Owner Author

Moved to claas/training/engine/local/cache.py in 0747c76. distillation.py imports from the new location — the functions are only called when the local engine passes cached=.

Comment generated by Claude Code

return result


def _gpu_optimizer_state(
Owner Author

Should this live in the local engine imp?

Owner Author

Same move as above — now in claas/training/engine/local/cache.py in 0747c76.

Comment generated by Claude Code


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e6b04caaad


await asyncio.to_thread(self._trainer.offload_base_model)

with self._cache_lock:
    self._lora_cache[resolved_id] = result.cache_entry


P2 Badge Cache post-step state under the returned LoRA id

When save_in_place=False (the default on DistillBatchRequestPayload), the trainer saves to a new versioned lora_id, but this code stores the cache entry under the original resolved_id from the request. If a caller submits another distill request for that same original ID, the engine will now reuse the cached newer weights/optimizer state instead of reloading the requested adapter from disk, silently changing training lineage and producing incorrect branch behavior.

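
A minimal sketch of the fix this comment points at — keying the cache on the id the trainer actually saved to, rather than the id the request came in with. StepResult and saved_lora_id are hypothetical names for illustration:

```python
class StepResult:
    """Stand-in for the trainer's result; saved_lora_id may differ from the
    requested id when save_in_place=False produces a new versioned adapter."""

    def __init__(self, saved_lora_id: str, cache_entry: dict):
        self.saved_lora_id = saved_lora_id
        self.cache_entry = cache_entry


def update_cache(cache: dict, requested_id: str, result: StepResult) -> None:
    # Write under the output id; drop a stale entry for the input id if they differ,
    # so a later request for the original adapter reloads it from disk
    cache[result.saved_lora_id] = result.cache_entry
    if result.saved_lora_id != requested_id:
        cache.pop(requested_id, None)
```

With this keying, branching from the original adapter keeps its lineage intact instead of silently reusing the newer weights.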

clip_fraction = sum(batch_clip_fraction) / len(batch_clip_fraction)
grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)

cache_entry = self._build_cache_entry(model, optimizer)


P2 Badge Make cache snapshot optional for non-caching callers

This call unconditionally clones LoRA and optimizer state to CPU every distill step, but not all callers use that cache result (the Modal worker immediately returns .response and drops it). In those paths, each request pays a full extra state copy with no benefit, which can materially increase step latency and peak memory for larger adapters.

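
One way to make the snapshot opt-in, as this comment suggests — `build_cache` is a hypothetical parameter name, and the bodies are stand-ins for the real training step and state clone:

```python
def distill_step(batch, *, build_cache: bool = False):
    """Run one step; only snapshot state when the caller will keep it."""
    response = {"loss": 0.0, "batch_size": len(batch)}  # stand-in for training
    cache_entry = None
    if build_cache:
        # Only pay the CPU clone cost (LoRA weights + optimizer state) on request;
        # callers like the Modal worker that drop the entry skip this entirely
        cache_entry = {"lora_state": {}, "optimizer_state": {}}
    return response, cache_entry
```

The local engine would pass build_cache=True; the Modal worker, which returns only the response, would leave the default.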

Relocate LoraCacheEntry, LoraAdapterConfig, DistillStepResult, and the
cpu/gpu_optimizer_state helpers from the shared training module into
claas/training/engine/local/cache.py since they are only used by the
local engine's CPU caching path. The Modal worker never uses caching.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
claas/training/distillation.py (2)

140-142: reload_base_model assumes self.base_model and self.device are set.

If called before load_base_model(), this will raise AttributeError. The engine's _ensure_model_loaded gate makes this safe today, but a defensive check or docstring noting the precondition would help.

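
The suggested precondition check might look like this; the class is a stripped-down stand-in for DistillationTrainer, not its real interface:

```python
class TrainerSketch:
    """Illustrative stand-in; the real trainer loads an actual base model."""

    def __init__(self) -> None:
        self.base_model = None
        self.device = None

    def load_base_model(self) -> None:
        self.base_model, self.device = object(), "cuda:0"  # stand-in values

    def reload_base_model(self) -> None:
        """Move the base model back to its device.

        Precondition: load_base_model() must have been called first (the
        engine's _ensure_model_loaded gate guarantees this in practice).
        """
        if self.base_model is None or self.device is None:
            raise RuntimeError("load_base_model() must be called before reload_base_model()")
        # ... move self.base_model to self.device ...
```

Raising a clear RuntimeError turns a confusing AttributeError into an actionable message if the gate is ever bypassed.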

43-85: _cpu_optimizer_state and _gpu_optimizer_state are nearly identical — consider a shared helper.

The two functions differ only in the tensor placement expression (v.detach().cpu().clone() vs v.detach().to(device).clone()). A single _remap_optimizer_state(state_dict, device) would eliminate the duplication.

♻️ Proposed consolidation
-def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to CPU."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            cpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                cpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        cpu_param[k] = v.detach().cpu().clone()
-                    else:
-                        cpu_param[k] = copy.deepcopy(v)
-                cpu_states[param_id] = cpu_param
-            result[key] = cpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
-
-
-def _gpu_optimizer_state(
-    state_dict: dict[str, object],
-    device: torch.device,
-) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to a target device."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            gpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                gpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        gpu_param[k] = v.detach().to(device).clone()
-                    else:
-                        gpu_param[k] = copy.deepcopy(v)
-                gpu_states[param_id] = gpu_param
-            result[key] = gpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
+def _remap_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to *device*."""
+    result: dict[str, object] = {}
+    for key, value in state_dict.items():
+        if key == "state":
+            param_states = cast("dict[int, dict[str, object]]", value)
+            new_states: dict[int, dict[str, object]] = {}
+            for param_id, param_state in param_states.items():
+                new_param: dict[str, object] = {}
+                for k, v in param_state.items():
+                    if isinstance(v, torch.Tensor):
+                        new_param[k] = v.detach().to(device).clone()
+                    else:
+                        new_param[k] = copy.deepcopy(v)
+                new_states[param_id] = new_param
+            result[key] = new_states
+        else:
+            result[key] = copy.deepcopy(value)
+    return result
+
+
+def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to CPU."""
+    return _remap_optimizer_state(state_dict, torch.device("cpu"))
+
+
+def _gpu_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to a target device."""
+    return _remap_optimizer_state(state_dict, device)
claas/training/cache.py (1)

12-31: Frozen dataclasses with mutable containers provide only shallow immutability.

frozen=True prevents attribute reassignment but callers can still mutate the inner list and dict objects in place (e.g., entry.lora_state_dict["new_key"] = ...). This is fine as-is since _build_cache_entry creates defensive copies, but worth noting for future maintainers.

If you want deeper guarantees later, consider tuple[str, ...] for target_modules and types.MappingProxyType wrappers for the dicts.
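
The deeper-immutability option mentioned here can be demonstrated with stdlib tools alone — a tuple field plus a MappingProxyType view over the state dict:

```python
import types
from dataclasses import dataclass


@dataclass(frozen=True)
class AdapterConfigSketch:
    # tuple[str, ...] instead of list[str]: no in-place append/assignment
    target_modules: tuple


def make_entry_state(state: dict) -> types.MappingProxyType:
    # A read-only view over a private copy: writes through the view raise
    # TypeError, and the copy shields the entry from mutation of the caller's dict
    return types.MappingProxyType(dict(state))
```

frozen=True already blocks attribute reassignment; these two changes additionally block mutation of the containers themselves.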

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/local/engine.py`:
- Around line 81-95: The cache is being stored under the input-derived
resolved_id (from resolve_lora_id(payload.lora_id)) but distill() may return a
different output id (result.response.lora_id) when save_in_place=False; fix by
using the output id as the cache key: after awaiting
asyncio.to_thread(self._trainer.distill, ...), read final_id =
result.response.lora_id (or from result.cache_entry if more appropriate) and
then under self._cache_lock set self._lora_cache[final_id] = result.cache_entry
(optionally remove the old resolved_id if final_id != resolved_id to avoid stale
entries); keep the initial cache read using resolved_id unchanged but always
write using final_id.


📥 Commits

Reviewing files that changed from the base of the PR and between dedadf6 and e6b04ca.

📒 Files selected for processing (6)
  • claas/modal/worker.py
  • claas/training/cache.py
  • claas/training/distillation.py
  • claas/training/engine/local/engine.py
  • tests/test_distillation_optimizer_state.py
  • tests/test_local_training_engine.py

Comment on lines 298 to 331
def _build_cache_entry(
    self,
    model: "PeftModel | PeftMixedModel",
    optimizer: "torch.optim.Optimizer",
) -> LoraCacheEntry:
    """Snapshot current model + optimizer state into a CPU-resident cache entry."""
    from peft import PeftModel as PeftModelCls

    peft_config = model.peft_config["default"]
    adapter_config = LoraAdapterConfig(
        r=peft_config.r,
        lora_alpha=peft_config.lora_alpha,
        target_modules=list(peft_config.target_modules),
        lora_dropout=peft_config.lora_dropout,
        bias=peft_config.bias,
        task_type=peft_config.task_type,
    )

    # Determine state dict — use PEFT's adapter-only extraction if available
    if isinstance(model, PeftModelCls):
        from peft import get_peft_model_state_dict

        raw_state = get_peft_model_state_dict(model)
    else:
        raw_state = model.state_dict()

    lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()}
    opt_state = _cpu_optimizer_state(optimizer.state_dict())

    return LoraCacheEntry(
        lora_state_dict=lora_state,
        optimizer_state_dict=opt_state,
        adapter_config=adapter_config,
    )

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Find LoraAdapterConfig definition
find . -name "*.py" -type f | xargs rg "class LoraAdapterConfig" -A 10

Repository: kfallah/CLaaS

Length of output: 550


🏁 Script executed:

# Also check the imports in the distillation.py file
head -50 claas/training/distillation.py

Repository: kfallah/CLaaS

Length of output: 1452


🌐 Web query:

PEFT LoraConfig task_type attribute return type TaskType enum or string

💡 Result:

In Hugging Face PEFT, LoraConfig.task_type is not limited to the TaskType enum.

  • The config base class (PeftConfigMixin) documents task_type as Union[TaskType, str] (optionally None). (huggingface.co)
  • PEFT docs/examples show passing either TaskType.SEQ_2_SEQ_LM (enum) (huggingface.co) or a string like "SEQ_2_SEQ_LM". (huggingface.co)

So, the effective “return type” you should expect when reading config.task_type is TaskType | str | None (depending on how it was created/loaded). (huggingface.co)



🏁 Script executed:

# Check what PEFT's LoraConfig actually returns for these fields
python3 <<'EOF'
import sys
try:
    from peft import TaskType, LoraConfig
    
    # Check if TaskType is an enum
    print(f"TaskType type: {type(TaskType)}")
    print(f"TaskType.CAUSAL_LM: {TaskType.CAUSAL_LM}")
    print(f"TaskType.CAUSAL_LM type: {type(TaskType.CAUSAL_LM)}")
    print(f"TaskType.CAUSAL_LM.value: {TaskType.CAUSAL_LM.value if hasattr(TaskType.CAUSAL_LM, 'value') else 'N/A'}")
    
    # Create a LoraConfig to check what task_type looks like
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        bias="none"
    )
    
    print(f"\nconfig.task_type: {config.task_type}")
    print(f"config.task_type type: {type(config.task_type)}")
    print(f"config.bias: {config.bias}")
    print(f"config.bias type: {type(config.bias)}")
    
    # Check if bias is already a string
    print(f"\nIs task_type an enum after LoraConfig? {hasattr(config.task_type, 'value')}")
    print(f"Is bias a string? {isinstance(config.bias, str)}")
    
except ImportError:
    print("PEFT not installed, skipping verification", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)
EOF

Repository: kfallah/CLaaS

Length of output: 98


🏁 Script executed:

# Also check the actual PEFT LoraConfig source to understand type handling
python3 <<'EOF'
from peft import LoraConfig
import inspect

# Get the source code location
print(f"LoraConfig location: {inspect.getfile(LoraConfig)}")

# Check LoraConfig's __init__ signature
sig = inspect.signature(LoraConfig.__init__)
for param_name in ['task_type', 'bias']:
    if param_name in sig.parameters:
        param = sig.parameters[param_name]
        print(f"{param_name}: annotation={param.annotation}, default={param.default}")
EOF

Repository: kfallah/CLaaS

Length of output: 173


🏁 Script executed:

# Check how peft_config is used throughout the codebase
rg "peft_config\." --type py -B 2 -A 2 | head -100

Repository: kfallah/CLaaS

Length of output: 740


🏁 Script executed:

# Search for TaskType usage in the repository
rg "TaskType" --type py -B 2 -A 2

Repository: kfallah/CLaaS

Length of output: 39


🏁 Script executed:

# Check requirements or setup files for PEFT version
find . -name "requirements*.txt" -o -name "setup.py" -o -name "pyproject.toml" | xargs cat 2>/dev/null | grep -i peft

Repository: kfallah/CLaaS

Length of output: 76


🌐 Web query:

PEFT LoraConfig bias attribute type enum string representation

💡 Result:

In Hugging Face PEFT peft.LoraConfig, the bias attribute is a string constrained to specific values (typed as a Literal in the docs):

  • bias: Literal["none", "all", "lora_only"] = "none" [1]
    • "none": don’t train any bias terms
    • "all": train all model biases
    • "lora_only": train only the biases belonging to LoRA layers [1][2]

So the “enum” is effectively these three string representations: "none", "all", "lora_only" (default "none"). [1][2]

Sources: [1] [2]


🌐 Web query:

PEFT 0.18.1 LoraConfig TaskType task_type attribute type enum

💡 Result:

In PEFT 0.18.1, LoraConfig.task_type is typed as:

  • Optional[Union[str, TaskType]] (so you can pass either a TaskType enum value or the equivalent string). (huggingface.co)

TaskType is a str-backed enum.Enum with these values:

  • TaskType.SEQ_CLS = "SEQ_CLS"
  • TaskType.SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
  • TaskType.CAUSAL_LM = "CAUSAL_LM"
  • TaskType.TOKEN_CLS = "TOKEN_CLS"
  • TaskType.QUESTION_ANS = "QUESTION_ANS"
  • TaskType.FEATURE_EXTRACTION = "FEATURE_EXTRACTION" (raw.githubusercontent.com)



🏁 Script executed:

# Check how LoraAdapterConfig is actually used after creation in the codebase
rg "LoraAdapterConfig" --type py -B 3 -A 3 | head -80

Repository: kfallah/CLaaS

Length of output: 3726


🏁 Script executed:

# Look for any serialization or conversion of adapter_config
rg "adapter_config" claas/training/distillation.py -B 2 -A 2

Repository: kfallah/CLaaS

Length of output: 738


Coerce task_type enum to string for type consistency

peft_config.task_type may be a TaskType enum value (e.g. TaskType.CAUSAL_LM), but LoraAdapterConfig.task_type is annotated as str. While this works at runtime because TaskType is a string-backed enum, the type mismatch should be corrected by extracting the string value.

Note: bias is already a string per PEFT 0.18.1 and does not need coercion.

🔧 Proposed fix
         adapter_config = LoraAdapterConfig(
             r=peft_config.r,
             lora_alpha=peft_config.lora_alpha,
             target_modules=list(peft_config.target_modules),
             lora_dropout=peft_config.lora_dropout,
             bias=peft_config.bias,
-            task_type=peft_config.task_type,
+            task_type=peft_config.task_type.value if hasattr(peft_config.task_type, "value") else str(peft_config.task_type),
         )
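
Since TaskType is a str-backed enum (per the query results above), a stdlib stand-in shows why `.value` is the robust coercion:

```python
import enum


# Stdlib reproduction of a str-backed enum like peft.TaskType; this class is a
# stand-in for illustration, not PEFT's actual definition.
class TaskType(str, enum.Enum):
    CAUSAL_LM = "CAUSAL_LM"


# .value always yields the plain string, and because of the str mixin the
# member itself compares equal to (and isinstance-checks as) a str.
# Note that str(member) may render as "TaskType.CAUSAL_LM" rather than the
# plain value on some Python versions, which is why .value is preferable
# to str(...) when a plain string is required.
```

This is the behavior the proposed fix relies on when coercing peft_config.task_type before storing it in LoraAdapterConfig.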

Comment on lines +60 to +64
async def _ensure_model_loaded(self) -> None:
    """One-time base model load on first distill() call."""
    if not self._model_loaded:
        await asyncio.to_thread(self._trainer.load_base_model)
        self._model_loaded = True

⚠️ Potential issue | 🟡 Minor

_ensure_model_loaded has a race window under concurrent async callers.

Between the if not self._model_loaded check and self._model_loaded = True assignment, there's an await yield point. A second concurrent distill() call can observe _model_loaded is False and also trigger load_base_model. The result is a harmless but wasteful double-load at startup.

An asyncio.Lock would close the gap:

🔧 Proposed fix

Add to __init__:

self._load_lock = asyncio.Lock()

Then:

     async def _ensure_model_loaded(self) -> None:
         """One-time base model load on first distill() call."""
-        if not self._model_loaded:
-            await asyncio.to_thread(self._trainer.load_base_model)
-            self._model_loaded = True
+        async with self._load_lock:
+            if not self._model_loaded:
+                await asyncio.to_thread(self._trainer.load_base_model)
+                self._model_loaded = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/local/engine.py` around lines 60 - 64, There's a race
in _ensure_model_loaded: multiple concurrent distill() callers can pass the "if
not self._model_loaded" check and each call self._trainer.load_base_model via
asyncio.to_thread, causing redundant loads; fix by creating an asyncio.Lock in
__init__ (e.g. self._load_lock = asyncio.Lock()) and wrap the check/load/flag
assignment inside "async with self._load_lock" in _ensure_model_loaded so only
one caller runs asyncio.to_thread(self._trainer.load_base_model) and sets
self._model_loaded = True while holding the lock.
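A runnable sketch of the locked one-time load, showing that concurrent callers trigger exactly one load once the check-then-set section is held under an `asyncio.Lock` (class and attribute names mirror the review's proposal, not the repository's full engine):

```python
import asyncio

class Engine:
    """Minimal sketch of the locked one-time base model load."""

    def __init__(self) -> None:
        self.load_calls = 0
        self._model_loaded = False
        self._load_lock = asyncio.Lock()

    def _load_base_model(self) -> None:
        self.load_calls += 1  # stands in for the expensive disk load

    async def _ensure_model_loaded(self) -> None:
        async with self._load_lock:  # serializes the check-then-load section
            if not self._model_loaded:
                await asyncio.to_thread(self._load_base_model)
                self._model_loaded = True

async def main() -> int:
    engine = Engine()
    # Ten concurrent callers race on first use; the lock guarantees one load.
    await asyncio.gather(*(engine._ensure_model_loaded() for _ in range(10)))
    return engine.load_calls

print(asyncio.run(main()))  # 1
```

Without the lock, every caller that reaches the `await` before the flag flips would schedule its own `load_base_model`, which is exactly the wasteful double-load described above.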

Comment on lines +81 to +95
    resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
    with self._cache_lock:
        cached = self._lora_cache.get(resolved_id)

    try:
-       return await asyncio.to_thread(trainer.distill, payload)
+       result = await asyncio.to_thread(
+           self._trainer.distill, payload, cached=cached
+       )
    finally:
-       await asyncio.to_thread(trainer.offload_base_model)
+       await asyncio.to_thread(self._trainer.offload_base_model)

    with self._cache_lock:
        self._lora_cache[resolved_id] = result.cache_entry

    return result.response

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find save_in_place usage in the codebase
rg -n 'save_in_place' --type py -C 2



🏁 Script executed:

#!/bin/bash
# Read the distillation.py file around the save_in_place logic to see what happens in both branches
sed -n '485,510p' claas/training/distillation.py



🏁 Script executed:

#!/bin/bash
# Find the definitions of save_lora and save_lora_inplace functions
rg -n 'def save_lora' --type py -A 15



🏁 Script executed:

#!/bin/bash
# Read the full save_lora function implementation
sed -n '292,345p' claas/training/storage.py



🏁 Script executed:

#!/bin/bash
# Read the save_lora_inplace function implementation
sed -n '346,390p' claas/training/storage.py



🏁 Script executed:

#!/bin/bash
# Read the rest of save_lora_inplace function to see the return statement
sed -n '390,400p' claas/training/storage.py



🏁 Script executed:

#!/bin/bash
# Find where DistillResponse is created and how new_lora_id is used
sed -n '508,525p' claas/training/distillation.py



🏁 Script executed:

#!/bin/bash
# Check if save_in_place=False is actually used anywhere (excluding tests)
rg -n 'save_in_place\s*=\s*False' --type py



🏁 Script executed:

#!/bin/bash
# Check if there are any other callers of distill() that might use save_in_place=False
rg -n '\.distill\(' --type py -B 5 -A 2

Repository: kfallah/CLaaS

Length of output: 8135


Key cache by output lora_id instead of input to handle non-in-place saves correctly.

resolved_id (line 81) is derived from payload.lora_id. The distill() method returns result.response.lora_id which may differ from the input when save_in_place=False:

  • save_in_place=True: save_lora_inplace() returns the same resolved ID
  • save_in_place=False: save_lora() auto-generates a timestamp suffix and returns a different ID (e.g., base-20250214-120530)

The current cache key (line 93) uses the input resolved_id, causing mismatches when non-in-place saves occur:

  • Request with the new lora_id → cache miss
  • Request reusing the old lora_id → stale cache hit

While save_in_place=False is currently unused (API always sets save_in_place=True), the code explicitly supports this mode. Apply the suggested fix to ensure correctness:

Cache key fix
 with self._cache_lock:
-    self._lora_cache[resolved_id] = result.cache_entry
+    new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
+    self._lora_cache[new_resolved] = result.cache_entry
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

    resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
    with self._cache_lock:
        cached = self._lora_cache.get(resolved_id)

    try:
        result = await asyncio.to_thread(
            self._trainer.distill, payload, cached=cached
        )
    finally:
        await asyncio.to_thread(self._trainer.offload_base_model)

    with self._cache_lock:
        new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
        self._lora_cache[new_resolved] = result.cache_entry

    return result.response
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/local/engine.py` around lines 81 - 95, The cache is
being stored under the input-derived resolved_id (from
resolve_lora_id(payload.lora_id)) but distill() may return a different output id
(result.response.lora_id) when save_in_place=False; fix by using the output id
as the cache key: after awaiting asyncio.to_thread(self._trainer.distill, ...),
read final_id = result.response.lora_id (or from result.cache_entry if more
appropriate) and then under self._cache_lock set self._lora_cache[final_id] =
result.cache_entry (optionally remove the old resolved_id if final_id !=
resolved_id to avoid stale entries); keep the initial cache read using
resolved_id unchanged but always write using final_id.
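The pitfall and the fix can be sketched with a plain dict; `distill`, `run_step`, and the cache values below are illustrative stand-ins, not the repository's API (the timestamp suffix format follows the `base-20250214-120530` example above):

```python
# Sketch of the cache-keying fix: the trainer may return a new lora_id when
# save_in_place=False, so the cache must be written under the OUTPUT id.

def distill(lora_id: str, save_in_place: bool) -> str:
    # Non-in-place saves append a timestamp suffix, producing a new id.
    return lora_id if save_in_place else f"{lora_id}-20250214-120530"

cache: dict[str, str] = {}

def run_step(lora_id: str, save_in_place: bool) -> str:
    entry = cache.get(lora_id)  # reading under the input id is fine
    new_id = distill(lora_id, save_in_place)
    cache[new_id] = "cache-entry"  # write under the output id
    if new_id != lora_id:
        cache.pop(lora_id, None)  # drop the stale input-keyed entry
    return new_id

out = run_step("base", save_in_place=False)
print(out in cache, "base" in cache)  # True False
```

Keying the write by the output id means a follow-up request for the new `lora_id` hits the cache, while a request that mistakenly reuses the old id misses instead of reading a stale entry.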
