Conversation
Keep the DistillationTrainer and base model across distill() calls instead of recreating them each time. Cache LoRA adapter weights and optimizer state on CPU between steps so the second call for a given lora_id skips all disk I/O. GPU memory is still fully released after each step via the existing offload_base_model() pattern. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 Walkthrough

The PR introduces CPU-resident caching for LoRA adapter states between distillation training steps. New immutable data structures capture LoRA configuration, optimizer states, and distillation results. The distillation trainer and local engine are updated to optionally load from cache, with thread-safe management and proper device placement for optimizer state tensors.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client as Client Request
    participant Engine as LocalTrainingEngine
    participant Cache as LoRA Cache
    participant Trainer as DistillationTrainer
    participant Model as LoRA Model
    participant Optimizer as Optimizer State

    Client->>Engine: distill(lora_id, request)
    Engine->>Engine: _ensure_model_loaded()
    Engine->>Model: reload_base_model()
    Engine->>Cache: fetch cached entry for lora_id
    alt Cache Hit
        Cache-->>Engine: LoraCacheEntry
        Engine->>Trainer: distill(payload, cached=entry)
        Trainer->>Model: _load_lora_from_cache(entry)
        Trainer->>Optimizer: _gpu_optimizer_state(cached.optimizer_state_dict)
    else Cache Miss
        Engine->>Trainer: distill(payload, cached=None)
        Trainer->>Model: load LoRA from disk
        Trainer->>Optimizer: initialize optimizer
    end
    Trainer->>Trainer: train()
    Trainer->>Model: get trained model state
    Trainer->>Optimizer: get optimizer state
    Trainer->>Trainer: _build_cache_entry()
    Trainer-->>Engine: DistillStepResult(response, cache_entry)
    Engine->>Cache: update cache[lora_id] = cache_entry
    Engine->>Model: offload_base_model()
    Engine-->>Client: result.response
```
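The "immutable data structures" in the walkthrough read most naturally as frozen dataclasses. A minimal sketch, with field names inferred from the diffs quoted later in this thread rather than copied from the source:

```python
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class LoraAdapterConfig:
    r: int
    lora_alpha: int
    target_modules: list[str]
    lora_dropout: float
    bias: str
    task_type: str


@dataclass(frozen=True)
class LoraCacheEntry:
    lora_state_dict: dict[str, torch.Tensor]   # adapter weights as CPU tensors
    optimizer_state_dict: dict[str, object]    # optimizer state with CPU tensors
    adapter_config: LoraAdapterConfig


@dataclass(frozen=True)
class DistillStepResult:
    response: object               # the DistillResponse returned to the caller
    cache_entry: LoraCacheEntry    # CPU snapshot stored by the engine after the step
```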
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
claas/training/cache.py
Outdated
```diff
@@ -0,0 +1,38 @@
+"""Typed cache structures for CPU-resident LoRA state between training steps."""
```
should this live in claas/training/engine/local/cache.py?
Moved to claas/training/engine/local/cache.py along with the optimizer state helpers in 0747c76. These types and helpers are only used by the local engine's CPU caching path, so they belong in its namespace.
Comment generated by Claude Code
claas/training/distillation.py
Outdated
```python
logger = logging.getLogger(__name__)


def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
```
Should this live in the local engine imp?
Moved to claas/training/engine/local/cache.py in 0747c76. distillation.py imports from the new location — the functions are only called when the local engine passes cached=.
Comment generated by Claude Code
claas/training/distillation.py
Outdated
```python
    return result


def _gpu_optimizer_state(
```
Should this live in the local engine implementation?
Same move as above — now in claas/training/engine/local/cache.py in 0747c76.
Comment generated by Claude Code
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e6b04caaad
```python
await asyncio.to_thread(self._trainer.offload_base_model)

with self._cache_lock:
    self._lora_cache[resolved_id] = result.cache_entry
```
Cache post-step state under the returned LoRA id
When save_in_place=False (the default on DistillBatchRequestPayload), the trainer saves to a new versioned lora_id, but this code stores the cache entry under the original resolved_id from the request. If a caller submits another distill request for that same original ID, the engine will now reuse the cached newer weights/optimizer state instead of reloading the requested adapter from disk, silently changing training lineage and producing incorrect branch behavior.
```python
clip_fraction = sum(batch_clip_fraction) / len(batch_clip_fraction)
grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)

cache_entry = self._build_cache_entry(model, optimizer)
```
Make cache snapshot optional for non-caching callers
This call unconditionally clones LoRA and optimizer state to CPU every distill step, but not all callers use that cache result (the Modal worker immediately returns .response and drops it). In those paths, each request pays a full extra state copy with no benefit, which can materially increase step latency and peak memory for larger adapters.
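One hedged way to act on that suggestion (the `build_cache` flag and function name are hypothetical, not the PR's actual API) is to gate the snapshot on whether the caller will reuse it:

```python
from typing import Optional

# Hypothetical gate: skip the CPU snapshot when the caller will not reuse it
# (e.g. the Modal worker, which only returns .response and drops the entry).
def maybe_build_cache_entry(trainer, model, optimizer, build_cache: bool) -> Optional["LoraCacheEntry"]:
    if not build_cache:
        return None  # avoids cloning LoRA weights + optimizer state to CPU every step
    return trainer._build_cache_entry(model, optimizer)
```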
Relocate LoraCacheEntry, LoraAdapterConfig, DistillStepResult, and the cpu/gpu_optimizer_state helpers from the shared training module into claas/training/engine/local/cache.py since they are only used by the local engine's CPU caching path. The Modal worker never uses caching. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 3
🧹 Nitpick comments (3)
claas/training/distillation.py (2)
140-142: `reload_base_model` assumes `self.base_model` and `self.device` are set.

If called before `load_base_model()`, this will raise `AttributeError`. The engine's `_ensure_model_loaded` gate makes this safe today, but a defensive check or docstring noting the precondition would help.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 140 - 142, The reload_base_model method assumes self.base_model and self.device exist; add a defensive precondition check at the start of reload_base_model that verifies hasattr(self, "base_model") and hasattr(self, "device") (or self.base_model is not None / self.device is not None) and either return early (no-op) or raise a clear RuntimeError indicating load_base_model must be called first; reference the existing _ensure_model_loaded guard in the comment or docstring to make the precondition explicit and update the docstring of reload_base_model to state that load_base_model must be called prior.
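A minimal guard in the spirit of that prompt, assuming the attribute names above (the actual reload body is elided):

```python
def reload_base_model(self) -> None:
    """Move the base model back onto the training device. Requires load_base_model() first."""
    if getattr(self, "base_model", None) is None or getattr(self, "device", None) is None:
        raise RuntimeError("load_base_model() must be called before reload_base_model()")
    ...  # existing reload logic unchanged
```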
43-85: `_cpu_optimizer_state` and `_gpu_optimizer_state` are nearly identical — consider a shared helper.

The two functions differ only in the tensor placement expression (`v.detach().cpu().clone()` vs `v.detach().to(device).clone()`). A single `_remap_optimizer_state(state_dict, device)` would eliminate the duplication.

♻️ Proposed consolidation

```diff
-def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to CPU."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            cpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                cpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        cpu_param[k] = v.detach().cpu().clone()
-                    else:
-                        cpu_param[k] = copy.deepcopy(v)
-                cpu_states[param_id] = cpu_param
-            result[key] = cpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
-
-
-def _gpu_optimizer_state(
-    state_dict: dict[str, object],
-    device: torch.device,
-) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to a target device."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            gpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                gpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        gpu_param[k] = v.detach().to(device).clone()
-                    else:
-                        gpu_param[k] = copy.deepcopy(v)
-                gpu_states[param_id] = gpu_param
-            result[key] = gpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
+def _remap_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to *device*."""
+    result: dict[str, object] = {}
+    for key, value in state_dict.items():
+        if key == "state":
+            param_states = cast("dict[int, dict[str, object]]", value)
+            new_states: dict[int, dict[str, object]] = {}
+            for param_id, param_state in param_states.items():
+                new_param: dict[str, object] = {}
+                for k, v in param_state.items():
+                    if isinstance(v, torch.Tensor):
+                        new_param[k] = v.detach().to(device).clone()
+                    else:
+                        new_param[k] = copy.deepcopy(v)
+                new_states[param_id] = new_param
+            result[key] = new_states
+        else:
+            result[key] = copy.deepcopy(value)
+    return result
+
+
+def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to CPU."""
+    return _remap_optimizer_state(state_dict, torch.device("cpu"))
+
+
+def _gpu_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to a target device."""
+    return _remap_optimizer_state(state_dict, device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 43 - 85, Consolidate duplicated logic in _cpu_optimizer_state and _gpu_optimizer_state by extracting a single helper _remap_optimizer_state(state_dict: dict[str, object], device: torch.device) -> dict[str, object] that walks the state_dict exactly as today, copies non-tensor values with copy.deepcopy, and for torch.Tensor values uses v.detach().to(device).clone(); then implement _cpu_optimizer_state as a thin wrapper that calls _remap_optimizer_state(state_dict, torch.device("cpu")) and _gpu_optimizer_state as a thin wrapper that forwards the provided device to _remap_optimizer_state, keeping the existing type casts ("dict[int, dict[str, object]]") and return shape unchanged.

claas/training/cache.py (1)
12-31: Frozen dataclasses with mutable containers provide only shallow immutability.
`frozen=True` prevents attribute reassignment but callers can still mutate the inner `list` and `dict` objects in place (e.g., `entry.lora_state_dict["new_key"] = ...`). This is fine as-is since `_build_cache_entry` creates defensive copies, but worth noting for future maintainers.

If you want deeper guarantees later, consider `tuple[str, ...]` for `target_modules` and `types.MappingProxyType` wrappers for the dicts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/cache.py` around lines 12 - 31, The dataclasses LoraAdapterConfig and LoraCacheEntry are declared frozen but contain mutable containers (list and dict) allowing in-place mutation; update the types and construction to provide stronger immutability by changing LoraAdapterConfig.target_modules from list[str] to tuple[str, ...] and make LoraCacheEntry.lora_state_dict and optimizer_state_dict immutable views (e.g., wrap with types.MappingProxyType) when building entries; also update the builder function (_build_cache_entry or wherever entries are created) to convert incoming lists to tuples and wrap dicts in MappingProxyType so the frozen dataclass truly prevents mutations of internal containers.
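For illustration, a read-only wrapper along the lines the nitpick suggests (the helper name is hypothetical):

```python
import types
from typing import Mapping


def as_readonly(d: dict[str, object]) -> Mapping[str, object]:
    """Return a read-only view of d; key insertion or deletion raises TypeError."""
    return types.MappingProxyType(dict(d))


# Entries built this way reject in-place key mutation (tensor contents stay mutable).
frozen_state = as_readonly({"lora_A.weight": object()})
try:
    frozen_state["new_key"] = object()  # type: ignore[index]
except TypeError:
    pass  # expected: mappingproxy does not support item assignment
```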
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (6)
- claas/modal/worker.py
- claas/training/cache.py
- claas/training/distillation.py
- claas/training/engine/local/engine.py
- tests/test_distillation_optimizer_state.py
- tests/test_local_training_engine.py
```python
def _build_cache_entry(
    self,
    model: "PeftModel | PeftMixedModel",
    optimizer: "torch.optim.Optimizer",
) -> LoraCacheEntry:
    """Snapshot current model + optimizer state into a CPU-resident cache entry."""
    from peft import PeftModel as PeftModelCls

    peft_config = model.peft_config["default"]
    adapter_config = LoraAdapterConfig(
        r=peft_config.r,
        lora_alpha=peft_config.lora_alpha,
        target_modules=list(peft_config.target_modules),
        lora_dropout=peft_config.lora_dropout,
        bias=peft_config.bias,
        task_type=peft_config.task_type,
    )

    # Determine state dict — use PEFT's adapter-only extraction if available
    if isinstance(model, PeftModelCls):
        from peft import get_peft_model_state_dict

        raw_state = get_peft_model_state_dict(model)
    else:
        raw_state = model.state_dict()

    lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()}
    opt_state = _cpu_optimizer_state(optimizer.state_dict())

    return LoraCacheEntry(
        lora_state_dict=lora_state,
        optimizer_state_dict=opt_state,
        adapter_config=adapter_config,
    )
```
🧩 Analysis chain
🏁 Script executed:
```bash
# Find LoraAdapterConfig definition
find . -name "*.py" -type f | xargs rg "class LoraAdapterConfig" -A 10
```

Repository: kfallah/CLaaS
Length of output: 550
🏁 Script executed:
```bash
# Also check the imports in the distillation.py file
head -50 claas/training/distillation.py
```

Repository: kfallah/CLaaS
Length of output: 1452
🌐 Web query:
PEFT LoraConfig task_type attribute return type TaskType enum or string
💡 Result:
In Hugging Face PEFT, LoraConfig.task_type is not limited to the TaskType enum.
- The config base class (`PeftConfigMixin`) documents `task_type` as `Union[TaskType, str]` (optionally `None`). (huggingface.co)
- PEFT docs/examples show passing either `TaskType.SEQ_2_SEQ_LM` (enum) (huggingface.co) or a string like `"SEQ_2_SEQ_LM"`. (huggingface.co)
So, the effective “return type” you should expect when reading config.task_type is TaskType | str | None (depending on how it was created/loaded). (huggingface.co)
Citations:
- 1: https://huggingface.co/docs/peft/main/en/package_reference/config?utm_source=openai
- 2: https://huggingface.co/docs/peft/v0.16.0/quicktour?utm_source=openai
- 3: https://huggingface.co/docs/peft/main/package_reference/lora?utm_source=openai
- 4: https://huggingface.co/docs/peft/main/en/package_reference/config?utm_source=openai
🏁 Script executed:
```bash
# Check what PEFT's LoraConfig actually returns for these fields
python3 <<'EOF'
import sys
try:
    from peft import TaskType, LoraConfig

    # Check if TaskType is an enum
    print(f"TaskType type: {type(TaskType)}")
    print(f"TaskType.CAUSAL_LM: {TaskType.CAUSAL_LM}")
    print(f"TaskType.CAUSAL_LM type: {type(TaskType.CAUSAL_LM)}")
    print(f"TaskType.CAUSAL_LM.value: {TaskType.CAUSAL_LM.value if hasattr(TaskType.CAUSAL_LM, 'value') else 'N/A'}")

    # Create a LoraConfig to check what task_type looks like
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        bias="none"
    )
    print(f"\nconfig.task_type: {config.task_type}")
    print(f"config.task_type type: {type(config.task_type)}")
    print(f"config.bias: {config.bias}")
    print(f"config.bias type: {type(config.bias)}")

    # Check if bias is already a string
    print(f"\nIs task_type an enum after LoraConfig? {hasattr(config.task_type, 'value')}")
    print(f"Is bias a string? {isinstance(config.bias, str)}")
except ImportError:
    print("PEFT not installed, skipping verification", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)
EOF
```

Repository: kfallah/CLaaS
Length of output: 98
🏁 Script executed:
```bash
# Also check the actual PEFT LoraConfig source to understand type handling
python3 <<'EOF'
from peft import LoraConfig
import inspect

# Get the source code location
print(f"LoraConfig location: {inspect.getfile(LoraConfig)}")

# Check LoraConfig's __init__ signature
sig = inspect.signature(LoraConfig.__init__)
for param_name in ['task_type', 'bias']:
    if param_name in sig.parameters:
        param = sig.parameters[param_name]
        print(f"{param_name}: annotation={param.annotation}, default={param.default}")
EOF
```

Repository: kfallah/CLaaS
Length of output: 173
🏁 Script executed:
```bash
# Check how peft_config is used throughout the codebase
rg "peft_config\." --type py -B 2 -A 2 | head -100
```

Repository: kfallah/CLaaS
Length of output: 740
🏁 Script executed:
```bash
# Search for TaskType usage in the repository
rg "TaskType" --type py -B 2 -A 2
```

Repository: kfallah/CLaaS
Length of output: 39
🏁 Script executed:
```bash
# Check requirements or setup files for PEFT version
find . -name "requirements*.txt" -o -name "setup.py" -o -name "pyproject.toml" | xargs cat 2>/dev/null | grep -i peft
```

Repository: kfallah/CLaaS
Length of output: 76
🌐 Web query:
PEFT LoraConfig bias attribute type enum string representation
💡 Result:
In Hugging Face PEFT peft.LoraConfig, the bias attribute is a string constrained to specific values (typed as a Literal in the docs):
`bias: Literal["none", "all", "lora_only"] = "none"` [1]

- `"none"`: don't train any bias terms
- `"all"`: train all model biases
- `"lora_only"`: train only the biases belonging to LoRA layers [1][2]
So the “enum” is effectively these three string representations: "none", "all", "lora_only" (default "none"). [1][2]
Sources: [1] [2]
🌐 Web query:
PEFT 0.18.1 LoraConfig TaskType task_type attribute type enum
💡 Result:
In PEFT 0.18.1, LoraConfig.task_type is typed as:
`Optional[Union[str, TaskType]]` (so you can pass either a `TaskType` enum value or the equivalent string). (huggingface.co)

`TaskType` is a str-backed `enum.Enum` with these values:

- `TaskType.SEQ_CLS = "SEQ_CLS"`
- `TaskType.SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"`
- `TaskType.CAUSAL_LM = "CAUSAL_LM"`
- `TaskType.TOKEN_CLS = "TOKEN_CLS"`
- `TaskType.QUESTION_ANS = "QUESTION_ANS"`
- `TaskType.FEATURE_EXTRACTION = "FEATURE_EXTRACTION"`

(raw.githubusercontent.com)
Citations:
- 1: https://huggingface.co/docs/peft/main/package_reference/lora?utm_source=openai
- 2: https://raw.githubusercontent.com/huggingface/peft/v0.18.1/src/peft/utils/peft_types.py
🏁 Script executed:
```bash
# Check how LoraAdapterConfig is actually used after creation in the codebase
rg "LoraAdapterConfig" --type py -B 3 -A 3 | head -80
```

Repository: kfallah/CLaaS
Length of output: 3726
🏁 Script executed:
```bash
# Look for any serialization or conversion of adapter_config
rg "adapter_config" claas/training/distillation.py -B 2 -A 2
```

Repository: kfallah/CLaaS
Length of output: 738
Coerce task_type enum to string for type consistency
peft_config.task_type may be a TaskType enum value (e.g. TaskType.CAUSAL_LM), but LoraAdapterConfig.task_type is annotated as str. While this works at runtime because TaskType is a string-backed enum, the type mismatch should be corrected by extracting the string value.
Note: bias is already a string per PEFT 0.18.1 and does not need coercion.
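A quick local sanity check of the coercion (assumes peft is importable; the value follows the PEFT docs cited above):

```python
from peft import TaskType

task_type = TaskType.CAUSAL_LM  # what peft_config.task_type may hold
plain = task_type.value if hasattr(task_type, "value") else str(task_type)
assert plain == "CAUSAL_LM" and isinstance(plain, str)
```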
🔧 Proposed fix
```diff
 adapter_config = LoraAdapterConfig(
     r=peft_config.r,
     lora_alpha=peft_config.lora_alpha,
     target_modules=list(peft_config.target_modules),
     lora_dropout=peft_config.lora_dropout,
     bias=peft_config.bias,
-    task_type=peft_config.task_type,
+    task_type=peft_config.task_type.value if hasattr(peft_config.task_type, "value") else str(peft_config.task_type),
 )
```

🤖 Prompt for AI Agents
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@claas/training/distillation.py` around lines 298 - 331, In
_build_cache_entry, peft_config.task_type can be a TaskType enum while
LoraAdapterConfig.task_type is annotated as str; fix by converting
peft_config.task_type to its string value when constructing adapter_config
(e.g., use str(peft_config.task_type) or peft_config.task_type.value) so
adapter_config.task_type is a plain string; update the adapter_config
construction in the _build_cache_entry function to coerce peft_config.task_type
accordingly.
```python
async def _ensure_model_loaded(self) -> None:
    """One-time base model load on first distill() call."""
    if not self._model_loaded:
        await asyncio.to_thread(self._trainer.load_base_model)
        self._model_loaded = True
```
_ensure_model_loaded has a race window under concurrent async callers.
Between the if not self._model_loaded check and self._model_loaded = True assignment, there's an await yield point. A second concurrent distill() call can observe _model_loaded is False and also trigger load_base_model. The result is a harmless but wasteful double-load at startup.
An asyncio.Lock would close the gap:
🔧 Proposed fix
Add to __init__:
```python
self._load_lock = asyncio.Lock()
```

Then:

```diff
 async def _ensure_model_loaded(self) -> None:
     """One-time base model load on first distill() call."""
-    if not self._model_loaded:
-        await asyncio.to_thread(self._trainer.load_base_model)
-        self._model_loaded = True
+    async with self._load_lock:
+        if not self._model_loaded:
+            await asyncio.to_thread(self._trainer.load_base_model)
+            self._model_loaded = True
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@claas/training/engine/local/engine.py` around lines 60 - 64, There's a race
in _ensure_model_loaded: multiple concurrent distill() callers can pass the "if
not self._model_loaded" check and each call self._trainer.load_base_model via
asyncio.to_thread, causing redundant loads; fix by creating an asyncio.Lock in
__init__ (e.g. self._load_lock = asyncio.Lock()) and wrap the check/load/flag
assignment inside "async with self._load_lock" in _ensure_model_loaded so only
one caller runs asyncio.to_thread(self._trainer.load_base_model) and sets
self._model_loaded = True while holding the lock.
```diff
 resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
 with self._cache_lock:
     cached = self._lora_cache.get(resolved_id)

 try:
-    return await asyncio.to_thread(trainer.distill, payload)
+    result = await asyncio.to_thread(
+        self._trainer.distill, payload, cached=cached
+    )
 finally:
-    await asyncio.to_thread(trainer.offload_base_model)
+    await asyncio.to_thread(self._trainer.offload_base_model)

 with self._cache_lock:
     self._lora_cache[resolved_id] = result.cache_entry

 return result.response
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# Find save_in_place usage in the codebase
rg -n 'save_in_place' --type py -C 2
```

Repository: kfallah/CLaaS
Length of output: 1283
🏁 Script executed:
```bash
#!/bin/bash
# Read the distillation.py file around the save_in_place logic to see what happens in both branches
sed -n '485,510p' claas/training/distillation.py
```

Repository: kfallah/CLaaS
Length of output: 1235
🏁 Script executed:
```bash
#!/bin/bash
# Find the definitions of save_lora and save_lora_inplace functions
rg -n 'def save_lora' --type py -A 15
```

Repository: kfallah/CLaaS
Length of output: 2048
🏁 Script executed:
```bash
#!/bin/bash
# Read the full save_lora function implementation
sed -n '292,345p' claas/training/storage.py
```

Repository: kfallah/CLaaS
Length of output: 1645
🏁 Script executed:
```bash
#!/bin/bash
# Read the save_lora_inplace function implementation
sed -n '346,390p' claas/training/storage.py
```

Repository: kfallah/CLaaS
Length of output: 1718
🏁 Script executed:
```bash
#!/bin/bash
# Read the rest of save_lora_inplace function to see the return statement
sed -n '390,400p' claas/training/storage.py
```

Repository: kfallah/CLaaS
Length of output: 308
🏁 Script executed:
```bash
#!/bin/bash
# Find where DistillResponse is created and how new_lora_id is used
sed -n '508,525p' claas/training/distillation.py
```

Repository: kfallah/CLaaS
Length of output: 815
🏁 Script executed:
```bash
#!/bin/bash
# Check if save_in_place=False is actually used anywhere (excluding tests)
rg -n 'save_in_place\s*=\s*False' --type py
```

Repository: kfallah/CLaaS
Length of output: 39
🏁 Script executed:
```bash
#!/bin/bash
# Check if there are any other callers of distill() that might use save_in_place=False
rg -n '\.distill\(' --type py -B 5 -A 2
```

Repository: kfallah/CLaaS
Length of output: 8135
Key cache by output lora_id instead of input to handle non-in-place saves correctly.
resolved_id (line 81) is derived from payload.lora_id. The distill() method returns result.response.lora_id which may differ from the input when save_in_place=False:
- `save_in_place=True`: `save_lora_inplace()` returns the same resolved ID
- `save_in_place=False`: `save_lora()` auto-generates a timestamp suffix and returns a different ID (e.g., `base-20250214-120530`)
The current cache key (line 93) uses the input resolved_id, causing mismatches when non-in-place saves occur:
- Request with the new lora_id → cache miss
- Request reusing the old lora_id → stale cache hit
While save_in_place=False is currently unused (API always sets save_in_place=True), the code explicitly supports this mode. Apply the suggested fix to ensure correctness:
Cache key fix
```diff
 with self._cache_lock:
-    self._lora_cache[resolved_id] = result.cache_entry
+    new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
+    self._lora_cache[new_resolved] = result.cache_entry
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
with self._cache_lock:
    cached = self._lora_cache.get(resolved_id)

try:
    result = await asyncio.to_thread(
        self._trainer.distill, payload, cached=cached
    )
finally:
    await asyncio.to_thread(self._trainer.offload_base_model)

with self._cache_lock:
    new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
    self._lora_cache[new_resolved] = result.cache_entry

return result.response
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@claas/training/engine/local/engine.py` around lines 81 - 95, The cache is
being stored under the input-derived resolved_id (from
resolve_lora_id(payload.lora_id)) but distill() may return a different output id
(result.response.lora_id) when save_in_place=False; fix by using the output id
as the cache key: after awaiting asyncio.to_thread(self._trainer.distill, ...),
read final_id = result.response.lora_id (or from result.cache_entry if more
appropriate) and then under self._cache_lock set self._lora_cache[final_id] =
result.cache_entry (optionally remove the old resolved_id if final_id !=
resolved_id to avoid stale entries); keep the initial cache read using
resolved_id unchanged but always write using final_id.
Summary
- `LocalTrainingEngine` now creates the `DistillationTrainer` once in `__init__` and reuses it across `distill()` calls, eliminating redundant base model loads after the first call
- LoRA adapter weights and optimizer state are cached on CPU between steps (`LoraCacheEntry`). Subsequent calls for the same `lora_id` skip all disk I/O (`load_lora`/`load_optimizer_state`)
- New `claas/training/cache.py` with frozen `@dataclass` types (`LoraCacheEntry`, `LoraAdapterConfig`, `DistillStepResult`) — no loose dicts, no optional types where invariants can be enforced
- The existing `offload_base_model()` + `del model, optimizer` + `cuda.empty_cache()` pattern is unchanged. Cache holds CPU-only tensors by construction

Test plan
- `tests/test_distillation_optimizer_state.py` — `_cpu_optimizer_state`/`_gpu_optimizer_state` round-trip, deep-copy isolation, `LoraCacheEntry` immutability
- `tests/test_local_training_engine.py` — eager trainer creation, one-time `load_base_model`, per-call `reload_base_model`, cache miss→hit, cache eviction on delete, offload error propagation
- `uv run pytest tests/ -v -m "not integration"` (103 passed, 25 skipped for torch)
- `uv run ruff check claas/ tests/`
- `uv run ty check` (only pre-existing unresolved-import errors for torch/peft/transformers)

🤖 Generated with Claude Code
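For reference, a round-trip check in the spirit of tests/test_distillation_optimizer_state.py might look like this sketch (the import path follows the relocation described above; the exact assertions are assumptions):

```python
import torch

from claas.training.engine.local.cache import _cpu_optimizer_state, _gpu_optimizer_state


def test_optimizer_state_cpu_round_trip() -> None:
    param = torch.nn.Parameter(torch.randn(4, 4))
    optimizer = torch.optim.AdamW([param], lr=1e-3)
    param.grad = torch.randn(4, 4)
    optimizer.step()  # populates exp_avg / exp_avg_sq state tensors

    snapshot = _cpu_optimizer_state(optimizer.state_dict())

    # Every tensor in the snapshot is a detached CPU clone.
    for param_state in snapshot["state"].values():
        for value in param_state.values():
            if isinstance(value, torch.Tensor):
                assert value.device.type == "cpu"
                assert not value.requires_grad

    # Remapping the snapshot back onto a device and reloading it should be lossless.
    restored = _gpu_optimizer_state(snapshot, torch.device("cpu"))
    optimizer.load_state_dict(restored)
```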
Summary by CodeRabbit
Release Notes
New Features
Refactor