Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion claas/modal/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def distill(self, request: DistillBatchRequestPayload) -> DistillResponse:
Distillation response payload.
"""
try:
return self.trainer.distill(request)
return self.trainer.distill(request).response
finally:
self.trainer.offload_base_model()

Expand Down
114 changes: 107 additions & 7 deletions claas/training/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
import torch

from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput
from claas.training.engine.local.cache import (
DistillStepResult,
LoraAdapterConfig,
LoraCacheEntry,
cpu_optimizer_state,
gpu_optimizer_state,
)
from claas.training.sdpo_loss import compute_sdpo_loss
from claas.training.storage import (
cleanup_local_lora,
Expand Down Expand Up @@ -86,6 +93,10 @@ def load_base_model(self) -> None:
self.optimizer_cls = torch.optim.AdamW
self.functional = torch.nn.functional

def reload_base_model(self) -> None:
"""Move base model from CPU back to CUDA."""
self.base_model.to(self.device) # type: ignore[arg-type] # functools.wraps confuses ty

def offload_base_model(self) -> None:
"""Move base model to CPU and release CUDA memory."""

Expand Down Expand Up @@ -129,6 +140,33 @@ def _load_or_create_lora(self, lora_path: str) -> "PeftModel | PeftMixedModel":
)
return get_peft_model(self.base_model, lora_config)

def _load_lora_from_cache(
self,
cached: LoraCacheEntry,
) -> "PeftModel | PeftMixedModel":
"""Restore a LoRA adapter from a CPU cache entry.

Args:
cached: CPU-resident snapshot of adapter state.

Returns:
Trainable PEFT model with cached weights loaded.
"""
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

cfg = cached.adapter_config
lora_config = LoraConfig(
r=cfg.r,
lora_alpha=cfg.lora_alpha,
target_modules=cfg.target_modules,
lora_dropout=cfg.lora_dropout,
bias=cfg.bias,
task_type=cfg.task_type,
)
model = get_peft_model(self.base_model, lora_config)
set_peft_model_state_dict(model, cached.lora_state_dict)
return model

def _load_optimizer_state(
self,
lora_path: str,
Expand Down Expand Up @@ -213,14 +251,58 @@ def _build_self_teacher_topk(
torch.cuda.empty_cache()
return top_logprobs, top_indices

def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
def _build_cache_entry(
self,
model: "PeftModel | PeftMixedModel",
optimizer: "torch.optim.Optimizer",
) -> LoraCacheEntry:
"""Snapshot current model + optimizer state into a CPU-resident cache entry."""
from peft import PeftModel as PeftModelCls

peft_config = model.peft_config["default"]
adapter_config = LoraAdapterConfig(
r=peft_config.r,
lora_alpha=peft_config.lora_alpha,
target_modules=list(peft_config.target_modules),
lora_dropout=peft_config.lora_dropout,
bias=peft_config.bias,
task_type=peft_config.task_type,
)

# Determine state dict — use PEFT's adapter-only extraction if available
if isinstance(model, PeftModelCls):
from peft import get_peft_model_state_dict

raw_state = get_peft_model_state_dict(model)
else:
raw_state = model.state_dict()

lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()}
opt_state = cpu_optimizer_state(optimizer.state_dict())

return LoraCacheEntry(
lora_state_dict=lora_state,
optimizer_state_dict=opt_state,
adapter_config=adapter_config,
)
Comment on lines 298 to 331
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Find LoraAdapterConfig definition
find . -name "*.py" -type f | xargs rg "class LoraAdapterConfig" -A 10

Repository: kfallah/CLaaS

Length of output: 550


🏁 Script executed:

# Also check the imports in the distillation.py file
head -50 claas/training/distillation.py

Repository: kfallah/CLaaS

Length of output: 1452


🌐 Web query:

PEFT LoraConfig task_type attribute return type TaskType enum or string

💡 Result:

In Hugging Face PEFT, LoraConfig.task_type is not limited to the TaskType enum.

  • The config base class (PeftConfigMixin) documents task_type as Union[TaskType, str] (optionally None). (huggingface.co)
  • PEFT docs/examples show passing either TaskType.SEQ_2_SEQ_LM (enum) (huggingface.co) or a string like "SEQ_2_SEQ_LM". (huggingface.co)

So, the effective “return type” you should expect when reading config.task_type is TaskType | str | None (depending on how it was created/loaded). (huggingface.co)

Citations:


🏁 Script executed:

# Check what PEFT's LoraConfig actually returns for these fields
python3 <<'EOF'
import sys
try:
    from peft import TaskType, LoraConfig
    
    # Check if TaskType is an enum
    print(f"TaskType type: {type(TaskType)}")
    print(f"TaskType.CAUSAL_LM: {TaskType.CAUSAL_LM}")
    print(f"TaskType.CAUSAL_LM type: {type(TaskType.CAUSAL_LM)}")
    print(f"TaskType.CAUSAL_LM.value: {TaskType.CAUSAL_LM.value if hasattr(TaskType.CAUSAL_LM, 'value') else 'N/A'}")
    
    # Create a LoraConfig to check what task_type looks like
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        bias="none"
    )
    
    print(f"\nconfig.task_type: {config.task_type}")
    print(f"config.task_type type: {type(config.task_type)}")
    print(f"config.bias: {config.bias}")
    print(f"config.bias type: {type(config.bias)}")
    
    # Check if bias is already a string
    print(f"\nIs task_type an enum after LoraConfig? {hasattr(config.task_type, 'value')}")
    print(f"Is bias a string? {isinstance(config.bias, str)}")
    
except ImportError:
    print("PEFT not installed, skipping verification", file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)
EOF

Repository: kfallah/CLaaS

Length of output: 98


🏁 Script executed:

# Also check the actual PEFT LoraConfig source to understand type handling
python3 <<'EOF'
from peft import LoraConfig
import inspect

# Get the source code location
print(f"LoraConfig location: {inspect.getfile(LoraConfig)}")

# Check LoraConfig's __init__ signature
sig = inspect.signature(LoraConfig.__init__)
for param_name in ['task_type', 'bias']:
    if param_name in sig.parameters:
        param = sig.parameters[param_name]
        print(f"{param_name}: annotation={param.annotation}, default={param.default}")
EOF

Repository: kfallah/CLaaS

Length of output: 173


🏁 Script executed:

# Check how peft_config is used throughout the codebase
rg "peft_config\." --type py -B 2 -A 2 | head -100

Repository: kfallah/CLaaS

Length of output: 740


🏁 Script executed:

# Search for TaskType usage in the repository
rg "TaskType" --type py -B 2 -A 2

Repository: kfallah/CLaaS

Length of output: 39


🏁 Script executed:

# Check requirements or setup files for PEFT version
find . -name "requirements*.txt" -o -name "setup.py" -o -name "pyproject.toml" | xargs cat 2>/dev/null | grep -i peft

Repository: kfallah/CLaaS

Length of output: 76


🌐 Web query:

PEFT LoraConfig bias attribute type enum string representation

💡 Result:

In Hugging Face PEFT peft.LoraConfig, the bias attribute is a string constrained to specific values (typed as a Literal in the docs):

  • bias: Literal["none", "all", "lora_only"] = "none" [1]
    • "none": don’t train any bias terms
    • "all": train all model biases
    • "lora_only": train only the biases belonging to LoRA layers [1][2]

So the “enum” is effectively these three string representations: "none", "all", "lora_only" (default "none"). [1][2]

Sources: [1] [2]


🌐 Web query:

PEFT 0.18.1 LoraConfig TaskType task_type attribute type enum

💡 Result:

In PEFT 0.18.1, LoraConfig.task_type is typed as:

  • Optional[Union[str, TaskType]] (so you can pass either a TaskType enum value or the equivalent string). (huggingface.co)

TaskType is a str-backed enum.Enum with these values:

  • TaskType.SEQ_CLS = "SEQ_CLS"
  • TaskType.SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
  • TaskType.CAUSAL_LM = "CAUSAL_LM"
  • TaskType.TOKEN_CLS = "TOKEN_CLS"
  • TaskType.QUESTION_ANS = "QUESTION_ANS"
  • TaskType.FEATURE_EXTRACTION = "FEATURE_EXTRACTION" (raw.githubusercontent.com)

Citations:


🏁 Script executed:

# Check how LoraAdapterConfig is actually used after creation in the codebase
rg "LoraAdapterConfig" --type py -B 3 -A 3 | head -80

Repository: kfallah/CLaaS

Length of output: 3726


🏁 Script executed:

# Look for any serialization or conversion of adapter_config
rg "adapter_config" claas/training/distillation.py -B 2 -A 2

Repository: kfallah/CLaaS

Length of output: 738


Coerce task_type enum to string for type consistency

peft_config.task_type may be a TaskType enum value (e.g. TaskType.CAUSAL_LM), but LoraAdapterConfig.task_type is annotated as str. While this works at runtime because TaskType is a string-backed enum, the type mismatch should be corrected by extracting the string value.

Note: bias is already a string per PEFT 0.18.1 and does not need coercion.

🔧 Proposed fix
         adapter_config = LoraAdapterConfig(
             r=peft_config.r,
             lora_alpha=peft_config.lora_alpha,
             target_modules=list(peft_config.target_modules),
             lora_dropout=peft_config.lora_dropout,
             bias=peft_config.bias,
-            task_type=peft_config.task_type,
+            task_type=peft_config.task_type.value if hasattr(peft_config.task_type, "value") else str(peft_config.task_type),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/distillation.py` around lines 298 - 331, In
_build_cache_entry, peft_config.task_type can be a TaskType enum while
LoraAdapterConfig.task_type is annotated as str; fix by converting
peft_config.task_type to its string value when constructing adapter_config
(e.g., use str(peft_config.task_type) or peft_config.task_type.value) so
adapter_config.task_type is a plain string; update the adapter_config
construction in the _build_cache_entry function to coerce peft_config.task_type
accordingly.


def distill(
self,
payload: DistillBatchRequestPayload,
*,
cached: LoraCacheEntry | None = None,
) -> DistillStepResult:
"""Run one SDPO distillation step.

Args:
payload: Distillation request payload.
cached: When provided, skip disk reads and load LoRA + optimizer
state from this CPU-resident cache entry. When ``None``,
load from disk (cold start).

Returns:
Distillation response with metrics.
Result containing both the distillation response and a cache
entry for the post-step state.
"""

torch.cuda.empty_cache()
Expand All @@ -231,10 +313,18 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
if len(payload.samples) == 0:
raise ValueError("samples must contain at least one item")

lora_local_path = load_lora(payload.lora_id)
# Disk path (cold start) or cache path
lora_local_path: str | None = None
if cached is None:
lora_local_path = load_lora(payload.lora_id)

try:
try:
model = self._load_or_create_lora(lora_local_path)
if cached is not None:
model = self._load_lora_from_cache(cached)
else:
assert lora_local_path is not None
model = self._load_or_create_lora(lora_local_path)
model.train()
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False},
Expand All @@ -249,7 +339,13 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
betas=(0.9, 0.999),
weight_decay=0.01,
)
self._load_optimizer_state(lora_local_path, optimizer)

if cached is not None:
optimizer.load_state_dict(
gpu_optimizer_state(cached.optimizer_state_dict, self.device)
)
elif lora_local_path is not None:
self._load_optimizer_state(lora_local_path, optimizer)

batch_loss_tensors: list[torch.Tensor] = []
batch_distill_loss: list[float] = []
Expand Down Expand Up @@ -362,10 +458,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
clip_fraction = sum(batch_clip_fraction) / len(batch_clip_fraction)
grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)

cache_entry = self._build_cache_entry(model, optimizer)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Make cache snapshot optional for non-caching callers

This call unconditionally clones LoRA and optimizer state to CPU every distill step, but not all callers use that cache result (the Modal worker immediately returns .response and drops it). In those paths, each request pays a full extra state copy with no benefit, which can materially increase step latency and peak memory for larger adapters.

Useful? React with 👍 / 👎.


del model, optimizer, batch_loss_tensors
torch.cuda.empty_cache()

return DistillResponse.model_validate(
response = DistillResponse.model_validate(
{
"lora_id": new_lora_id,
"metadata": {
Expand All @@ -380,5 +478,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
},
}
)
return DistillStepResult(response=response, cache_entry=cache_entry)
finally:
cleanup_local_lora(lora_local_path)
if lora_local_path is not None:
cleanup_local_lora(lora_local_path)
85 changes: 85 additions & 0 deletions claas/training/engine/local/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Typed cache structures and helpers for CPU-resident LoRA state between training steps."""

from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import cast

import torch

from claas.core.types import DistillResponse


@dataclass(frozen=True, slots=True)
class LoraAdapterConfig:
    """Typed representation of LoRA adapter configuration.

    A plain-data mirror of the fields needed to rebuild a PEFT ``LoraConfig``
    without importing peft at module load time.
    """

    # LoRA rank (dimension of the low-rank update matrices).
    r: int
    # Scaling factor applied to the LoRA update (alpha / r).
    lora_alpha: int
    # Names of the modules the adapter is attached to (e.g. "q_proj").
    target_modules: list[str]
    # Dropout probability applied inside the LoRA layers.
    lora_dropout: float
    # Bias training mode; PEFT accepts "none", "all", or "lora_only".
    bias: str
    # Task type string (e.g. "CAUSAL_LM"); note a PEFT TaskType enum value
    # may be stored here too since TaskType is a str-backed enum — callers
    # should coerce to a plain str when constructing this. TODO confirm.
    task_type: str


@dataclass(frozen=True, slots=True)
class LoraCacheEntry:
    """CPU-resident snapshot of LoRA adapter state between training steps.

    All tensors referenced here are expected to live on CPU so the entry can
    be held across steps without pinning CUDA memory.
    """

    # Adapter weights only (PEFT-extracted state dict), detached on CPU.
    lora_state_dict: dict[str, torch.Tensor]
    # Optimizer.state_dict()-shaped mapping with tensors moved to CPU
    # (see cpu_optimizer_state).
    optimizer_state_dict: dict[str, object]
    # Configuration needed to rebuild an equivalent PEFT LoraConfig.
    adapter_config: LoraAdapterConfig


@dataclass(frozen=True, slots=True)
class DistillStepResult:
    """Result of a distillation step with both response and cache entry."""

    # User-facing distillation response (metrics, new lora_id).
    response: DistillResponse
    # CPU snapshot of post-step adapter + optimizer state, usable as the
    # ``cached`` input to a subsequent step to skip disk loads.
    cache_entry: LoraCacheEntry


def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
    """Deep-copy optimizer state with all tensors moved to CPU.

    Args:
        state_dict: ``Optimizer.state_dict()``-shaped mapping whose ``"state"``
            entry maps parameter ids to per-parameter state dicts.

    Returns:
        A new mapping sharing no storage with the input; every tensor in the
        ``"state"`` entry is a detached CPU copy.
    """

    def _snapshot(item: object) -> object:
        # Tensors become owned CPU copies; anything else is deep-copied.
        if isinstance(item, torch.Tensor):
            return item.detach().cpu().clone()
        return copy.deepcopy(item)

    result: dict[str, object] = {}
    for key, value in state_dict.items():
        if key != "state":
            # Non-tensor sections (e.g. "param_groups") are copied wholesale.
            result[key] = copy.deepcopy(value)
            continue
        param_states = cast("dict[int, dict[str, object]]", value)
        result[key] = {
            param_id: {name: _snapshot(item) for name, item in param_state.items()}
            for param_id, param_state in param_states.items()
        }
    return result


def gpu_optimizer_state(
    state_dict: dict[str, object],
    device: torch.device,
) -> dict[str, object]:
    """Deep-copy optimizer state with all tensors moved to a target device.

    Args:
        state_dict: ``Optimizer.state_dict()``-shaped mapping whose ``"state"``
            entry maps parameter ids to per-parameter state dicts.
        device: Destination device for every tensor in the ``"state"`` entry.

    Returns:
        A new mapping sharing no storage with the input; every tensor in the
        ``"state"`` entry is a detached copy on ``device``.
    """

    def _snapshot(item: object) -> object:
        # Tensors become owned copies on the target device; everything else
        # is deep-copied.
        if isinstance(item, torch.Tensor):
            return item.detach().to(device).clone()
        return copy.deepcopy(item)

    result: dict[str, object] = {}
    for key, value in state_dict.items():
        if key != "state":
            # Non-tensor sections (e.g. "param_groups") are copied wholesale.
            result[key] = copy.deepcopy(value)
            continue
        param_states = cast("dict[int, dict[str, object]]", value)
        result[key] = {
            param_id: {name: _snapshot(item) for name, item in param_state.items()}
            for param_id, param_state in param_states.items()
        }
    return result
50 changes: 43 additions & 7 deletions claas/training/engine/local/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from __future__ import annotations

import asyncio
import logging
import re
import threading

from claas.core.config import LocalConfig
from claas.core.types import (
Expand All @@ -20,6 +22,7 @@
)
from claas.training.distillation import DistillationTrainer
from claas.training.engine.base import TrainingEngine
from claas.training.engine.local.cache import LoraCacheEntry
from claas.training.storage import (
configure_storage_backend,
create_initial_lora,
Expand All @@ -31,14 +34,34 @@
resolve_lora_id,
)

logger = logging.getLogger(__name__)


class LocalTrainingEngine(TrainingEngine):
"""Executes training and LoRA operations on local infrastructure."""

_trainer: DistillationTrainer
_lora_cache: dict[str, LoraCacheEntry]
_cache_lock: threading.Lock
_model_loaded: bool

def __init__(self, cfg: LocalConfig) -> None:
configure_storage_backend("local_fs")
self._base_model_id = cfg.base_model_id
self._attn_implementation = cfg.attn_implementation
self._trainer = DistillationTrainer(
base_model_id=cfg.base_model_id,
attn_implementation=cfg.attn_implementation,
)
self._lora_cache = {}
self._cache_lock = threading.Lock()
self._model_loaded = False

async def _ensure_model_loaded(self) -> None:
"""One-time base model load on first distill() call."""
if not self._model_loaded:
await asyncio.to_thread(self._trainer.load_base_model)
self._model_loaded = True
Comment on lines +60 to +64
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

_ensure_model_loaded has a race window under concurrent async callers.

Between the if not self._model_loaded check and self._model_loaded = True assignment, there's an await yield point. A second concurrent distill() call can observe _model_loaded is False and also trigger load_base_model. The result is a harmless but wasteful double-load at startup.

An asyncio.Lock would close the gap:

🔧 Proposed fix

Add to __init__:

self._load_lock = asyncio.Lock()

Then:

     async def _ensure_model_loaded(self) -> None:
         """One-time base model load on first distill() call."""
-        if not self._model_loaded:
-            await asyncio.to_thread(self._trainer.load_base_model)
-            self._model_loaded = True
+        async with self._load_lock:
+            if not self._model_loaded:
+                await asyncio.to_thread(self._trainer.load_base_model)
+                self._model_loaded = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/local/engine.py` around lines 60 - 64, There's a race
in _ensure_model_loaded: multiple concurrent distill() callers can pass the "if
not self._model_loaded" check and each call self._trainer.load_base_model via
asyncio.to_thread, causing redundant loads; fix by creating an asyncio.Lock in
__init__ (e.g. self._load_lock = asyncio.Lock()) and wrap the check/load/flag
assignment inside "async with self._load_lock" in _ensure_model_loaded so only
one caller runs asyncio.to_thread(self._trainer.load_base_model) and sets
self._model_loaded = True while holding the lock.


async def distill(
self,
Expand All @@ -52,15 +75,24 @@ async def distill(
Returns:
Distillation response.
"""
trainer = DistillationTrainer(
base_model_id=self._base_model_id,
attn_implementation=self._attn_implementation,
)
await asyncio.to_thread(trainer.load_base_model)
await self._ensure_model_loaded()
await asyncio.to_thread(self._trainer.reload_base_model)

resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
with self._cache_lock:
cached = self._lora_cache.get(resolved_id)

try:
return await asyncio.to_thread(trainer.distill, payload)
result = await asyncio.to_thread(
self._trainer.distill, payload, cached=cached
)
finally:
await asyncio.to_thread(trainer.offload_base_model)
await asyncio.to_thread(self._trainer.offload_base_model)

with self._cache_lock:
self._lora_cache[resolved_id] = result.cache_entry

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Cache post-step state under the returned LoRA id

When save_in_place=False (the default on DistillBatchRequestPayload), the trainer saves to a new versioned lora_id, but this code stores the cache entry under the original resolved_id from the request. If a caller submits another distill request for that same original ID, the engine will now reuse the cached newer weights/optimizer state instead of reloading the requested adapter from disk, silently changing training lineage and producing incorrect branch behavior.

Useful? React with 👍 / 👎.


return result.response
Comment on lines +81 to +95
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find save_in_place usage in the codebase
rg -n 'save_in_place' --type py -C 2

Repository: kfallah/CLaaS

Length of output: 1283


🏁 Script executed:

#!/bin/bash
# Read the distillation.py file around the save_in_place logic to see what happens in both branches
sed -n '485,510p' claas/training/distillation.py

Repository: kfallah/CLaaS

Length of output: 1235


🏁 Script executed:

#!/bin/bash
# Find the definitions of save_lora and save_lora_inplace functions
rg -n 'def save_lora' --type py -A 15

Repository: kfallah/CLaaS

Length of output: 2048


🏁 Script executed:

#!/bin/bash
# Read the full save_lora function implementation
sed -n '292,345p' claas/training/storage.py

Repository: kfallah/CLaaS

Length of output: 1645


🏁 Script executed:

#!/bin/bash
# Read the save_lora_inplace function implementation
sed -n '346,390p' claas/training/storage.py

Repository: kfallah/CLaaS

Length of output: 1718


🏁 Script executed:

#!/bin/bash
# Read the rest of save_lora_inplace function to see the return statement
sed -n '390,400p' claas/training/storage.py

Repository: kfallah/CLaaS

Length of output: 308


🏁 Script executed:

#!/bin/bash
# Find where DistillResponse is created and how new_lora_id is used
sed -n '508,525p' claas/training/distillation.py

Repository: kfallah/CLaaS

Length of output: 815


🏁 Script executed:

#!/bin/bash
# Check if save_in_place=False is actually used anywhere (excluding tests)
rg -n 'save_in_place\s*=\s*False' --type py

Repository: kfallah/CLaaS

Length of output: 39


🏁 Script executed:

#!/bin/bash
# Check if there are any other callers of distill() that might use save_in_place=False
rg -n '\.distill\(' --type py -B 5 -A 2

Repository: kfallah/CLaaS

Length of output: 8135


Key cache by output lora_id instead of input to handle non-in-place saves correctly.

resolved_id (line 81) is derived from payload.lora_id. The distill() method returns result.response.lora_id which may differ from the input when save_in_place=False:

  • save_in_place=True: save_lora_inplace() returns the same resolved ID
  • save_in_place=False: save_lora() auto-generates a timestamp suffix and returns a different ID (e.g., base-20250214-120530)

The current cache key (line 93) uses the input resolved_id, causing mismatches when non-in-place saves occur:

  • Request with the new lora_id → cache miss
  • Request reusing the old lora_id → stale cache hit

While save_in_place=False is currently unused (API always sets save_in_place=True), the code explicitly supports this mode. Apply the suggested fix to ensure correctness:

Cache key fix
 with self._cache_lock:
-    self._lora_cache[resolved_id] = result.cache_entry
+    new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
+    self._lora_cache[new_resolved] = result.cache_entry
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
with self._cache_lock:
cached = self._lora_cache.get(resolved_id)
try:
return await asyncio.to_thread(trainer.distill, payload)
result = await asyncio.to_thread(
self._trainer.distill, payload, cached=cached
)
finally:
await asyncio.to_thread(trainer.offload_base_model)
await asyncio.to_thread(self._trainer.offload_base_model)
with self._cache_lock:
self._lora_cache[resolved_id] = result.cache_entry
return result.response
resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
with self._cache_lock:
cached = self._lora_cache.get(resolved_id)
try:
result = await asyncio.to_thread(
self._trainer.distill, payload, cached=cached
)
finally:
await asyncio.to_thread(self._trainer.offload_base_model)
with self._cache_lock:
new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
self._lora_cache[new_resolved] = result.cache_entry
return result.response
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/local/engine.py` around lines 81 - 95, The cache is
being stored under the input-derived resolved_id (from
resolve_lora_id(payload.lora_id)) but distill() may return a different output id
(result.response.lora_id) when save_in_place=False; fix by using the output id
as the cache key: after awaiting asyncio.to_thread(self._trainer.distill, ...),
read final_id = result.response.lora_id (or from result.cache_entry if more
appropriate) and then under self._cache_lock set self._lora_cache[final_id] =
result.cache_entry (optionally remove the old resolved_id if final_id !=
resolved_id to avoid stale entries); keep the initial cache read using
resolved_id unchanged but always write using final_id.


async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse:
"""Initialize a LoRA adapter locally.
Expand All @@ -82,7 +114,11 @@ async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse:
return LoraInitResponse(lora_id=lora_id)

async def delete_lora(self, lora_id: str) -> LoraDeleteResponse:
resolved_id = await asyncio.to_thread(resolve_lora_id, lora_id)
deleted = await asyncio.to_thread(delete_lora, lora_id)
if deleted:
with self._cache_lock:
self._lora_cache.pop(resolved_id, None)
return LoraDeleteResponse(deleted=deleted)

async def list_loras(self, prefix: str) -> LoraListResponse:
Expand Down
Loading
Loading