From d6fbd89bc1317ca11beebd819ce90081797ff2b1 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Feb 2026 20:31:11 +0800 Subject: [PATCH] update load --- src/twinkle/server/tinker/common/megatron_model.py | 4 ++-- src/twinkle/server/tinker/common/transformers_model.py | 2 +- src/twinkle/server/utils/io_utils.py | 10 ++++------ src/twinkle/server/utils/validation.py | 2 +- src/twinkle_client/__init__.py | 1 + src/twinkle_client/http/http_utils.py | 1 + 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py index 1ac9f428..7a436403 100644 --- a/src/twinkle/server/tinker/common/megatron_model.py +++ b/src/twinkle/server/tinker/common/megatron_model.py @@ -175,13 +175,13 @@ def load(self, checkpoint_dir: str, **kwargs): # Create checkpoint manager with the token checkpoint_manager = create_checkpoint_manager(token) - + # Use resolve_load_path to handle path resolution resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) if resolved.is_twinkle_path: # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=str(resolved.checkpoint_dir), **kwargs) + return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) else: # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py index eecc2a7f..feff9036 100644 --- a/src/twinkle/server/tinker/common/transformers_model.py +++ b/src/twinkle/server/tinker/common/transformers_model.py @@ -137,7 +137,7 @@ def load(self, checkpoint_dir: str, **kwargs): if resolved.is_twinkle_path: # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=str(resolved.checkpoint_dir), **kwargs) + return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) else: # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py index 8a4e9e7b..3926bd5a 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/io_utils.py @@ -912,18 +912,16 @@ def resolve_load_path(self, path: str, validate_exists: bool = True) -> Resolved f"Checkpoint not found or access denied: {path}" ) - # Get the checkpoint directory - checkpoint_dir = str(self.get_ckpt_dir(training_run_id, checkpoint_id)) + # Get the checkpoint directory parent path (no checkpoint name in the path) + checkpoint_dir = self.get_ckpt_dir(training_run_id, checkpoint_id).parent if validate_exists: - # Verify the directory exists - from pathlib import Path as PathLib - if not PathLib(checkpoint_dir).exists(): + if not checkpoint_dir.exists(): raise ValueError(f"Checkpoint directory not found: {checkpoint_dir}") return ResolvedLoadPath( checkpoint_name=checkpoint_name, - checkpoint_dir=checkpoint_dir, + checkpoint_dir=checkpoint_dir.as_posix(), is_twinkle_path=True, training_run_id=training_run_id, checkpoint_id=checkpoint_id diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index 1f553da5..1f63b44c 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -22,7 +22,7 @@ async def verify_request_token(request: Request, call_next): Returns: JSONResponse with error if validation fails, otherwise the response from call_next """ - authorization = request.headers.get("Authorization") + authorization = request.headers.get("Twinkle-Authorization") token = authorization[7:] if authorization and authorization.startswith("Bearer ") else authorization if not is_token_valid(token): return JSONResponse(status_code=403, content={"detail": "Invalid token"}) diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index 782956c1..1ad6812d 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -29,6 +29,7 @@ def init_tinker_compat_client(base_url: Optional[str] = None, api_key: Optional[ default_headers = { "X-Ray-Serve-Request-Id": get_request_id(), "Authorization": 'Bearer ' + api_key, + "Twinkle-Authorization": 'Bearer ' + api_key, # For server compatibility } | kwargs.pop("default_headers", {}) service_client = ServiceClient(base_url=base_url, api_key=api_key, default_headers=default_headers, **kwargs) diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 1e927c2a..0743ca2c 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -17,6 +17,7 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ headers = { "X-Ray-Serve-Request-Id": get_request_id(), "Authorization": 'Bearer ' + get_api_key(), + "Twinkle-Authorization": 'Bearer ' + get_api_key(), # For server compatibility } if additional_headers: