diff --git a/.gitignore b/.gitignore index 111a4280..58f495d4 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ megatron_output/ ast_index_file.py test_cookbook/ /test*.py +swanlog/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 198c5575..f1979a9a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,23 +22,23 @@ repos: hooks: - id: pyupgrade args: [--py38-plus] - exclude: ^client_tools/ + exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/) - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: trailing-whitespace - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) - id: check-yaml - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) - id: end-of-file-fixer - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) - id: requirements-txt-fixer - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) - id: double-quote-string-fixer - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) - id: check-merge-conflict - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) - id: mixed-line-ending args: ["--fix=lf"] - exclude: ^client_tools/ + exclude: ^(client_tools/|src/twinkle_client/) diff --git a/cookbook/client/tinker/grpo.py b/cookbook/client/tinker/grpo.py deleted file mode 100644 index db67b04a..00000000 --- a/cookbook/client/tinker/grpo.py +++ /dev/null @@ -1,278 +0,0 @@ -# Tinker-Compatible Client - GRPO (Group Relative Policy Optimization) Training Example -# -# This script demonstrates GRPO reinforcement learning training using the -# Tinker-compatible client API with save_weights_for_sampler for weight sync. -# Instead of calling sync_weights directly, it periodically saves weights and -# creates a sampling client for generation. -# -# Flow: -# 1. Prepare Countdown dataset (client-side) -# 2. Initialize Tinker-compatible training & sampling clients -# 3. Training loop: -# a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client -# b. Sample completions from the sampling client -# c. Compute rewards and advantages (client-side) -# d. Train on sampled data weighted by advantages -# e. Optimizer step -# -# The server must be running first (see server.py and server_config.yaml). -# Requires both model and sampler services to be configured. 
- -import gc -import numpy as np -import os -from modelscope import AutoTokenizer -from tinker import types -from typing import List, Tuple - -from twinkle import get_logger -from twinkle.advantage import GRPOAdvantage -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.metric import CompletionRewardMetric -from twinkle_client import init_tinker_compat_client - -logger = get_logger() - -# ========== Configuration ========== -BASE_MODEL = 'Qwen/Qwen3-30B-A3B-Instruct-2507' -NUM_GENERATIONS = 4 -MAX_NEW_TOKENS = 1024 -LEARNING_RATE = 1e-5 -MAX_STEPS = 100 -BATCH_SIZE = 2 -TEMPERATURE = 1.0 -SYNC_INTERVAL = 2 # Save weights for sampler every N steps -LORA_RANK = 8 - - -def create_countdown_dataset(): - """Create Countdown Game dataset for GRPO training.""" - logger.info('Loading Countdown dataset...') - - dataset = Dataset(DatasetMeta('ms://zouxuhong/Countdown-Tasks-3to4', data_slice=range(500))) - dataset.set_template('Template', model_id=f'ms://{BASE_MODEL}', max_length=8192) - dataset.map('CountdownProcessor') - dataset.encode(add_generation_prompt=True) - - logger.info(f'Dataset loaded with {len(dataset)} samples') - return dataset - - -def compute_rewards(trajectories: List[dict], ) -> Tuple[List[float], List[float], List[float]]: - """Compute format and accuracy rewards for Countdown game.""" - from twinkle.reward import CountDownAccuracy, FormatReward - format_rewards = FormatReward()(trajectories, []) - accuracy_rewards = CountDownAccuracy()(trajectories, []) - total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)] - return total_rewards, format_rewards, accuracy_rewards - - -def main(): - logger.info('Starting GRPO training...') - - # Step 1: Prepare dataset and dataloader (client-side) - dataset = create_countdown_dataset() - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) - - logger.info('Dataset and tokenizer initialized') - - # Step 2: Initialize the Tinker-compatible client - logger.info('Connecting to Tinker server...') - service_client = init_tinker_compat_client( - base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_SDK_TOKEN')) - - logger.info('Creating LoRA training client...') - # Create a LoRA training client for GRPO - training_client = service_client.create_lora_training_client( - base_model=BASE_MODEL, - rank=LORA_RANK, - ) - - logger.info('Training client created successfully') - - # Step 3: Setup metrics and advantage function - advantage_fn = GRPOAdvantage() - metrics = CompletionRewardMetric() - - sampling_params = types.SamplingParams( - max_tokens=MAX_NEW_TOKENS, - temperature=TEMPERATURE, - top_p=0.95, - ) - - # The sampling client is created on-demand via save_weights_for_sampler - sampling_client = None - - step = 0 - for batch in dataloader: - if step >= MAX_STEPS: - break - - metrics.reset() - prompts = batch if isinstance(batch, list) else [batch] - - # ========== 1. Save weights for sampler (instead of sync_weights) ========== - if step % SYNC_INTERVAL == 0: - logger.info(f'Step {step}: Saving weights for sampler...') - - sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'grpo-step-{step}')) - logger.info(f'Step {step}: Sampling client ready') - - if sampling_client is None: - logger.warning('No sampling client available, skipping step') - step += 1 - continue - - # ========== 2. 
Sample completions ========== - # Convert input features to token prompts for the sampling client - all_sequences = [] - for prompt_feature in prompts: - input_ids = prompt_feature['input_ids'] - if hasattr(input_ids, 'tolist'): - input_ids = input_ids.tolist() - prompt = types.ModelInput.from_ints(input_ids) - future = sampling_client.sample( - prompt=prompt, - sampling_params=sampling_params, - num_samples=NUM_GENERATIONS, - ) - result = future.result() - all_sequences.extend(result.sequences) - - if not all_sequences: - logger.warning(f'Step {step}: No valid samples, skipping') - step += 1 - continue - - # ========== 3. Build trajectories and collect logprobs ========== - trajectories = [] - old_logps_list = [] - completion_lengths = [] - - for seq in all_sequences: - decoded_text = tokenizer.decode(seq.tokens, skip_special_tokens=True) - trajectories.append({'messages': [{'role': 'assistant', 'content': decoded_text}]}) - old_logps_list.append([lp for lp in seq.logprobs] if seq.logprobs else []) - completion_lengths.append(len(seq.tokens)) - - # ========== 4. Compute rewards ========== - total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories) - metrics.accumulate( - None, - None, - completion_lengths=completion_lengths, - rewards={ - 'total': total_rewards, - 'format': format_rewards, - 'accuracy': accuracy_rewards, - }) - - # ========== 5. Compute advantages ========== - advantages = advantage_fn( - total_rewards, - num_generations=NUM_GENERATIONS, - scale='group', - ).tolist() - - frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0) - if frac_zero_std == 1.0: - logger.info(f'Step {step}: All advantages are zero, skipping training') - step += 1 - continue - - # ========== 6. Train the policies with GRPO loss ========== - # Train the policies with the Advantage-Regularized policy - # gradient (GRPO) loss function. - # - # The GRPO loss function requires: - # 1. logprobs: The log probabilities of the tokens under the current policy - # 2. 
advantages: The advantage values for each completion - # - # The training data is constructed with: - # - model_input: The full prompt + completion tokens - # - target_tokens: The shifted tokens for next-token prediction - # - logprobs: The log probabilities from the sampling step - # - advantages: The computed advantage values - training_data = [] - for i, seq in enumerate(all_sequences): - # Build a Datum from the completion tokens with logprobs and advantages - prompt_feature = prompts[i // NUM_GENERATIONS] - prompt_ids = prompt_feature['input_ids'] - if hasattr(prompt_ids, 'tolist'): - prompt_ids = prompt_ids.tolist() - - sampled_tokens = list(seq.tokens) - logprobs = seq.logprobs if seq.logprobs else [0.0] * len(sampled_tokens) - advantage = float(advantages[i]) - - ob_len = len(prompt_ids) - 1 - input_tokens = prompt_ids + sampled_tokens[:-1] - target_tokens = [0] * ob_len + sampled_tokens - padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens) - padded_logprobs = [0.0] * ob_len + logprobs - - # Verify lengths match - assert len(input_tokens) == len(target_tokens) == len(padded_logprobs) == len(padded_advantages), \ - f'Length mismatch: input={len(input_tokens)}, target={len(target_tokens)}, ' \ - f'logprobs={len(padded_logprobs)}, advantages={len(padded_advantages)}' - - datum = types.Datum( - model_input=types.ModelInput.from_ints(input_tokens), - loss_fn_inputs={ - 'target_tokens': target_tokens, - 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)), - 'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)), - }, - ) - training_data.append(datum) - - if not training_data: - logger.info(f'Step {step}: No training data constructed, skipping') - step += 1 - continue - - # Forward-backward pass with importance_sampling (GRPO) loss - # The training data already contains logprobs and advantages for the GRPO loss - fwdbwd_future = training_client.forward_backward(training_data, 'importance_sampling') - optim_future = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)) - - fwdbwd_result = fwdbwd_future.result() - optim_result = optim_future.result() - - # Compute metrics from the forward-backward result - # For importance_sampling, we get logprobs and elementwise_loss - logprobs_list = [] - elementwise_losses = [] - for output in fwdbwd_result.loss_fn_outputs: - if output.get('logprobs') is not None: - logprobs_list.append(output['logprobs'].to_numpy()) - if output.get('elementwise_loss') is not None: - elementwise_losses.append(output['elementwise_loss'].to_numpy()) - - # Compute average loss per token (weighted by advantages) - if elementwise_losses: - all_losses = np.concatenate(elementwise_losses) - avg_loss = np.mean(all_losses) if len(all_losses) > 0 else 0.0 - else: - avg_loss = 0.0 - - gc.collect() - - # ========== 7. 
Log ========== - log_dict = metrics.calculate() - log_dict['train/loss_per_token'] = float(avg_loss) - log_dict['train/frac_reward_zero_std'] = frac_zero_std - log_dict['train/num_training_samples'] = len(training_data) - logger.info(f'Step {step}: {log_dict}') - step += 1 - - # Save final checkpoint - save_future = training_client.save_state('grpo-countdown-final') - save_result = save_future.result() - logger.info(f'Saved final checkpoint to {save_result.path}') - - -if __name__ == '__main__': - main() diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 852fd85c..69f749fd 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -56,6 +56,9 @@ applications: device_mesh: device_type: cuda dp_size: 4 + queue_config: + rps_limit: 20 # Max requests per second + tps_limit: 10000 # Max tokens per second deployments: - name: SamplerManagement autoscaling_config: @@ -77,7 +80,9 @@ applications: args: use_megatron: true # Use HuggingFace Transformers backend model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" # ModelScope model identifier - nproc_per_node: 4 # Number of GPU processes per node + max_length: 10240 # model max length + max_loras: 5 # model max loras + nproc_per_node: 4 # Number of GPU processes per node device_group: name: model ranks: [4,5,6,7] # GPU rank indices @@ -88,11 +93,12 @@ applications: ep_size: 2 queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second + rps_limit: 20 # Max requests per second + tps_limit: 10000 # Max tokens per second adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters - adapter_timeout: 1800 # Seconds before idle adapter unload + per_token_adapter_limit: 3 # Max concurrent LoRA adapters + adapter_timeout: 30 # Seconds before idle adapter unload + adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) deployments: - name: ModelManagement autoscaling_config: diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index 252e031d..cad014c9 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -50,10 +50,12 @@ applications: dp_size: 2 queue_config: rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second + tps_limit: 10000 # Max tokens per second for a single user + max_input_tokens: 10000 # Maximum input tokens per request adapter_config: - per_token_adapter_limit: 30 # Max concurrent LoRA adapters - adapter_timeout: 1800 # Seconds before idle adapter unload + adapter_timeout: 30 # Seconds before idle adapter unload + adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours) + per_token_adapter_limit: 30 deployments: - name: ModelManagement autoscaling_config: diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/sample.py index 210c3819..e995123f 100644 --- a/cookbook/client/tinker/sample.py +++ b/cookbook/client/tinker/sample.py @@ -4,7 +4,6 @@ # for text generation (sampling) via the Tinker-compatible client API. # The server must be running first (see server.py and server_config.yaml). 
-from modelscope import AutoTokenizer
 from tinker import types

 from twinkle.data_format import Message, Trajectory
@@ -14,33 +13,39 @@
 # Step 1: Define the base model and connect to the server
 base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
 service_client = init_tinker_compat_client(
-    base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_SDK_TOKEN'))
-
+    base_url='http://www.modelscope.cn/twinkle',
+    api_key=os.environ.get('MODELSCOPE_SDK_TOKEN')
+)

 # Step 2: Create a sampling client by loading weights from a saved checkpoint.
 # The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
 # The server will load the base model and apply the LoRA adapter weights.
-sampling_client = service_client.create_sampling_client(
-    model_path='twinkle://xxx-Qwen_Qwen3-30B-A3B-Instruct-2507-xxx/weights/twinkle-lora-1', base_model=base_model)
+sampling_client = service_client.create_sampling_client(
+    model_path='twinkle://xxx-Qwen_Qwen3-30B-A3B-Instruct-2507-xxx/weights/twinkle-lora-1',
+    base_model=base_model
+)

-# Step 3: Load the tokenizer locally to encode the prompt and decode the results
+# Step 3: Build the chat template locally to encode the prompt and decode the results
 print(f'Using model {base_model}')
-template = Template(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
-trajectory = Trajectory(messages=[
-    Message(role='system', content='You are a helpful assistant'),
-    Message(role='user', content='你是谁?'),
-])
+template = Template(model_id=f'ms://{base_model}')
+
+trajectory = Trajectory(
+    messages=[
+        Message(role='system', content='You are a helpful assistant'),
+        Message(role='user', content='你是谁?'),
+    ]
+)

-input_features = template.batch_encode([trajectory], add_generation_prompt=True)
+input_feature = template.encode(trajectory, add_generation_prompt=True)

-input_ids = input_features[0]['input_ids']
+input_ids = input_feature['input_ids'].tolist()

 # Step 4: Prepare the prompt and sampling parameters
-prompt = types.ModelInput.from_ints(list(input_ids))
+prompt = types.ModelInput.from_ints(input_ids)
 params = types.SamplingParams(
-    max_tokens=128,  # Maximum number of tokens to generate
-    temperature=0.0,  # Greedy sampling (deterministic, always pick the top token)
-    stop=['\n']  # Stop generation when a newline character is produced
+    max_tokens=128,  # Maximum number of tokens to generate
+    temperature=0.7,
+    stop=['\n']  # Stop generation when a newline character is produced
 )

 # Step 5: Send the sampling request to the server.
diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py
index bf5b1b01..5a565cc5 100644
--- a/cookbook/client/tinker/self_congnition.py
+++ b/cookbook/client/tinker/self_congnition.py
@@ -8,15 +8,15 @@
 # The server must be running first (see server.py and server_config.yaml).
import numpy as np import os -from modelscope import AutoTokenizer -from tinker import types from tqdm import tqdm - +from tinker import types +from twinkle_client import init_tinker_compat_client +from twinkle.data_format import Message, Trajectory +from twinkle.template import Template from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor from twinkle.server.tinker.common import input_feature_to_datum -from twinkle_client import init_tinker_compat_client # The base model to fine-tune / evaluate base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507' @@ -82,33 +82,28 @@ def eval(): # Step 1: Load the trained LoRA checkpoint for inference # Path to a previously saved LoRA checkpoint (twinkle:// URI) - weight_path = 'twinkle://20260211_112719-Qwen_Qwen2_5-7B-Instruct-a74a4826/weights/twinkle-lora-2' + weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2' # Connect to the server and create a sampling client with the trained weights service_client = init_tinker_compat_client(base_url='http://localhost:8000') sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model) - # Load the tokenizer for encoding the prompt and decoding the output - tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) - # Step 2: Prepare the chat prompt # Build a multi-turn conversation to test the model's self-cognition - inputs = [{ - 'role': 'system', - 'content': 'You are a helpful assistant.' - }, { - 'role': 'user', - 'content': 'what is your name?' - }] - - # Apply the model's chat template to format the conversation - input_ids = tokenizer.apply_chat_template( - inputs, - tokenize=True, - add_generation_prompt=True # Adds the assistant prompt prefix + template = Template(model_id=f'ms://{base_model}') + + trajectory = Trajectory( + messages=[ + Message(role='system', content='You are a helpful assistant'), + Message(role='user', content='你是谁?'), + ] ) + input_feature = template.encode(trajectory, add_generation_prompt=True) + + input_ids = input_feature['input_ids'].tolist() + # Step 3: Generate responses prompt = types.ModelInput.from_ints(input_ids) @@ -126,9 +121,9 @@ def eval(): # Decode and print each response print('Responses:') for i, seq in enumerate(result.sequences): - print(f'{i}: {repr(tokenizer.decode(seq.tokens))}') + print(f'{i}: {repr(template.decode(seq.tokens))}') if __name__ == '__main__': - train() # Uncomment to run training + train() # Uncomment to run training # eval() # Run evaluation / inference diff --git a/cookbook/client/tinker/short_math_grpo.py b/cookbook/client/tinker/short_math_grpo.py index 7134039b..6ab037f3 100644 --- a/cookbook/client/tinker/short_math_grpo.py +++ b/cookbook/client/tinker/short_math_grpo.py @@ -21,19 +21,19 @@ import numpy as np import os import re -from modelscope import AutoTokenizer from tinker import types from typing import List, Tuple +from twinkle_client import init_tinker_compat_client from twinkle import get_logger from twinkle.advantage import GRPOAdvantage from twinkle.data_format import Message, Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta -from twinkle.metric import CompletionRewardMetric from twinkle.preprocessor import Preprocessor from twinkle.reward.base import Reward -from twinkle_client import init_tinker_compat_client +from twinkle.metric import CompletionRewardMetric +from twinkle.template 
import Template

logger = get_logger()
@@ -209,9 +209,9 @@ def main():
     # Step 1: Prepare dataset and dataloader (client-side)
     dataset = create_Math_dataset()
     dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+    template = Template(model_id=f'ms://{BASE_MODEL}')

-    logger.info('Dataset and tokenizer initialized')
+    logger.info('Dataset and template initialized')

     # Step 2: Initialize the Tinker-compatible client
     logger.info('Connecting to Tinker server...')
@@ -291,7 +291,7 @@ def main():
         completion_lengths = []

         for idx, seq in enumerate(all_sequences):
-            decoded_text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
+            decoded_text = template.decode(seq.tokens, skip_special_tokens=True)
             # Use the corresponding user data for this sequence
             trajectories.append({
                 'messages': [
@@ -334,6 +334,10 @@ def main():
         ).tolist()

         frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)
+        if frac_zero_std == 1.0:
+            logger.info(f'Step {step}: All advantages are zero, skipping training')
+            step += 1
+            continue

         # ========== 6. Train the policies with GRPO loss ==========
         # Train the policies with the Advantage-Regularized policy
diff --git a/cookbook/client/twinkle/grpo.py b/cookbook/client/twinkle/grpo.py
index 9d374beb..30a33c0e 100644
--- a/cookbook/client/twinkle/grpo.py
+++ b/cookbook/client/twinkle/grpo.py
@@ -42,13 +42,13 @@

 # ========== Configuration ==========
 MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct'
-NUM_GENERATIONS = 8
+NUM_GENERATIONS = 4
 MAX_NEW_TOKENS = 1024
 LEARNING_RATE = 1e-5
 MAX_STEPS = 10
-BATCH_SIZE = 4
+BATCH_SIZE = 2
 TEMPERATURE = 1.0
-SYNC_INTERVAL = 5  # Save weights for sampler every N steps
+SYNC_INTERVAL = 1  # Save weights for sampler every N steps
 GRADIENT_ACCUMULATION_STEPS = 4
diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml
index 680e6f63..93fe8592 100644
--- a/cookbook/client/twinkle/transformer/server_config.yaml
+++ b/cookbook/client/twinkle/transformer/server_config.yaml
@@ -49,8 +49,7 @@ applications:
         device_type: cuda
       device_mesh:                    # Distributed training mesh configuration
         device_type: cuda
-        mesh: [0,1]                   # Device indices in the mesh
-        mesh_dim_names: ['dp']        # Mesh dimension names: 'dp' = data parallel
+        dp_size: 2                    # Data parallel size
     deployments:
       - name: ModelManagement
         autoscaling_config:
@@ -59,6 +58,10 @@
           target_ongoing_requests: 16
         ray_actor_options:
           num_cpus: 0.1
+        runtime_env:
+          env_vars:
+            TWINKLE_TRUST_REMOTE_CODE: "0"
+            DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 3. Processor Service - Handles data preprocessing on CPU
   #    Runs tokenization, template application, and other CPU-bound tasks.
@@ -84,6 +87,10 @@
           target_ongoing_requests: 128
         ray_actor_options:
           num_cpus: 0.1
+        runtime_env:
+          env_vars:
+            TWINKLE_TRUST_REMOTE_CODE: "0"
+            DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 4. Sampler Service - Handles text generation inference
   #    Uses vLLM for efficient batched generation with optional LoRA adapters.
@@ -93,7 +100,7 @@ applications: args: model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load sampler_type: vllm # Sampler backend (vllm or torch) - nproc_per_node: 1 # Number of GPU processes per node + nproc_per_node: 2 # Number of GPU processes per node engine_args: # vLLM engine configuration gpu_memory_utilization: 0.4 max_model_len: 1024 @@ -102,12 +109,11 @@ applications: adapter_timeout: 1800 # Seconds before idle adapter is unloaded device_group: name: sampler - ranks: [0] # GPU rank indices to use + ranks: [2] # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 deployments: - name: SamplerManagement autoscaling_config: @@ -116,3 +122,7 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md index 54b2d46b..ec7b4b42 100644 --- a/docs/source_en/Usage Guide/Server and Client/Server.md +++ b/docs/source_en/Usage Guide/Server and Client/Server.md @@ -50,9 +50,78 @@ This configuration starts 3 nodes: - **Node 1** (Worker): 4 GPUs (cards 4-7) - **Node 2** (Worker): CPU-only node +#### 4. Set Environment Variables + +Before starting the Server, you need to set the following environment variables: + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Specify the total number of GPUs on each physical machine +export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code (security consideration) +``` + +> **Important Note**: `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to the actual number of physical GPUs on the machine, which is crucial for correctly parsing the `ranks` configuration. + ### Node Rank in YAML Configuration -In the YAML configuration file, **each component needs to occupy a separate Node**, and the `ranks` within each Node are numbered starting from 0. +In the YAML configuration file, **each component needs to occupy a separate Node**. 
+ +**Example configuration:** + +```yaml +applications: + # Model service occupies GPU 0-3 (physical card numbers) + - name: models-Qwen2.5-7B-Instruct + route_prefix: /models/Qwen/Qwen2.5-7B-Instruct + import_path: model + args: + nproc_per_node: 4 + device_group: + name: model + ranks: [0, 1, 2, 3] # Physical GPU card numbers + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 4 # Data parallel size + # tp_size: 1 # Tensor parallel size (optional) + # pp_size: 1 # Pipeline parallel size (optional) + # ep_size: 1 # Expert parallel size (optional) + + # Sampler service occupies GPU 4-5 (physical card numbers) + - name: sampler-Qwen2.5-7B-Instruct + route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct + import_path: sampler + args: + nproc_per_node: 2 + device_group: + name: sampler + ranks: [4, 5] # Physical GPU card numbers 4-5 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 2 # Data parallel size + + # Processor service occupies CPU + - name: processor + route_prefix: /processors + import_path: processor + args: + ncpu_proc_per_node: 4 + device_group: + name: processor + ranks: 0 # CPU index + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 4 # Data parallel size +``` +**Important notes:** +- The `ranks` configuration uses **physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine +- The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` instead of the original `mesh` and `mesh_dim_names` +- The environment variable `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to inform the system of the total number of physical GPUs on each machine +- Different components will be automatically assigned to different Nodes +- Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`) + +In the YAML configuration file, **each component needs to occupy a separate Node**. **Example configuration:** @@ -204,8 +273,9 @@ applications: device_type: cuda device_mesh: # Distributed training mesh device_type: cuda - mesh: [0, 1] # Device indices in the mesh - mesh_dim_names: ['dp'] # Mesh dimensions: dp=data parallel + dp_size: 2 # Data parallel size + # tp_size: 1 # Tensor parallel size (optional) + # pp_size: 1 # Pipeline parallel size (optional) deployments: - name: ModelManagement autoscaling_config: @@ -229,8 +299,7 @@ applications: device_type: CPU device_mesh: device_type: CPU - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # Data parallel size deployments: - name: ProcessorManagement autoscaling_config: @@ -260,8 +329,7 @@ The difference from the Transformers backend is only in the `use_megatron` param device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # Data parallel size ``` > **Note**: The Megatron backend does not need `adapter_config` (LoRA adapter management is handled internally by Megatron). @@ -314,8 +382,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # Data parallel size deployments: - name: ModelManagement autoscaling_config: @@ -324,6 +391,9 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine # 3. 
Sampler service (optional, for inference sampling) - name: sampler-Qwen2.5-0.5B-Instruct @@ -343,8 +413,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 # Data parallel size deployments: - name: SamplerManagement autoscaling_config: @@ -354,6 +423,9 @@ applications: ray_actor_options: num_cpus: 0.1 num_gpus: 1 # Sampler needs independent GPU + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine ``` ## Configuration Item Description @@ -375,11 +447,30 @@ applications: ```yaml device_group: name: model # Device group name - ranks: [0, 1] # GPU card number list + ranks: [0, 1] # Physical GPU card number list device_type: cuda # Device type: cuda / CPU device_mesh: device_type: cuda - mesh: [0, 1] # Device indices in the mesh - mesh_dim_names: ['dp'] # Dimension names, commonly used: dp (data parallel), tp (tensor parallel), pp (pipeline parallel) + dp_size: 2 # Data parallel size + # tp_size: 1 # Tensor parallel size (optional) + # pp_size: 1 # Pipeline parallel size (optional) + # ep_size: 1 # Expert parallel size (optional) +``` + +**Important configuration parameters:** + +| Parameter | Type | Description | +|------|------|------| +| `ranks` | list[int] | **Physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine | +| `dp_size` | int | Data parallel size | +| `tp_size` | int (optional) | Tensor parallel size | +| `pp_size` | int (optional) | Pipeline parallel size | +| `ep_size` | int (optional) | Expert parallel size (for MoE models) | + +**Environment variables:** + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Total number of GPUs on each physical machine (must be set) +export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code ``` diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" index d194159d..ab7a2436 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" @@ -50,15 +50,26 @@ ray start --address=10.28.252.9:6379 --num-gpus=0 - **Node 1**(Worker):4 个 GPU(卡 4-7) - **Node 2**(Worker):纯 CPU 节点 +#### 4. 
设置环境变量 + +在启动 Server 之前,需要设置以下环境变量: + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 指定每台物理机上的 GPU 总数 +export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码(安全考虑) +``` + +> **重要提示**:`DEVICE_COUNT_PER_PHYSICAL_NODE` 必须设置为机器上实际的物理 GPU 数量,这对于正确解析 `ranks` 配置至关重要。 + ### YAML 配置中的 Node Rank -在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**,`ranks` 配置在各自的 Node 内都是从 0 开始编号的。 +在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**。 **示例配置:** ```yaml applications: - # 模型服务占用 Node 0(Head 节点,GPU 0-3) + # 模型服务占用 GPU 0-3(物理卡号) - name: models-Qwen2.5-7B-Instruct route_prefix: /models/Qwen/Qwen2.5-7B-Instruct import_path: model @@ -66,14 +77,16 @@ applications: nproc_per_node: 4 device_group: name: model - ranks: [0, 1, 2, 3] # Node 0 内的 GPU 编号 + ranks: [0, 1, 2, 3] # 物理 GPU 卡号 device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1, 2, 3] - mesh_dim_names: ['dp'] + dp_size: 4 # 数据并行大小 + # tp_size: 1 # 张量并行大小(可选) + # pp_size: 1 # 流水线并行大小(可选) + # ep_size: 1 # 专家并行大小(可选) - # Sampler 服务占用 Node 1(Worker 节点,GPU 4-7) + # Sampler 服务占用 GPU 4-5(物理卡号) - name: sampler-Qwen2.5-7B-Instruct route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct import_path: sampler @@ -81,14 +94,13 @@ applications: nproc_per_node: 2 device_group: name: sampler - ranks: [0, 1] # Node 1 内的 GPU 编号(对应物理 GPU 4-5) + ranks: [4, 5] # 物理 GPU 卡号 4-5 device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 - # Processor 服务占用 Node 2(CPU 节点) + # Processor 服务占用 CPU - name: processor route_prefix: /processors import_path: processor @@ -96,16 +108,16 @@ applications: ncpu_proc_per_node: 4 device_group: name: processor - ranks: 0 # Node 2 内的 CPU 编号 + ranks: 0 # CPU 编号 device_type: CPU device_mesh: device_type: CPU - mesh: [0, 1, 2, 3] - mesh_dim_names: ['dp'] + dp_size: 4 # 数据并行大小 ``` - **重要提示:** -- 每个组件的 `ranks` 配置都是相对于其所占用的 Ray Node 而言 +- `ranks` 配置使用**物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 +- `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数替代原来的 `mesh` 和 `mesh_dim_names` +- 必须设置环境变量 `DEVICE_COUNT_PER_PHYSICAL_NODE` 来告知系统每台机器的物理 GPU 总数 - 不同组件会自动分配到不同的 Node 上 - Ray 会根据资源需求(`ray_actor_options` 中的 `num_gpus`、`num_cpus`)自动调度到合适的 Node @@ -200,12 +212,13 @@ applications: nproc_per_node: 2 # 每节点 GPU 进程数 device_group: # 逻辑设备组 name: model - ranks: [0, 1] # 使用的 GPU 卡号 + ranks: [0, 1] # 物理 GPU 卡号 device_type: cuda device_mesh: # 分布式训练网格 device_type: cuda - mesh: [0, 1] # 网格中的设备索引 - mesh_dim_names: ['dp'] # 网格维度:dp=数据并行 + dp_size: 2 # 数据并行大小 + # tp_size: 1 # 张量并行大小(可选) + # pp_size: 1 # 流水线并行大小(可选) deployments: - name: ModelManagement autoscaling_config: @@ -229,8 +242,7 @@ applications: device_type: CPU device_mesh: device_type: CPU - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 deployments: - name: ProcessorManagement autoscaling_config: @@ -260,8 +272,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 ``` > **注意**:Megatron 后端不需要 `adapter_config`(LoRA 适配器管理由 Megatron 内部处理)。 @@ -314,8 +325,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0, 1] - mesh_dim_names: ['dp'] + dp_size: 2 # 数据并行大小 deployments: - name: ModelManagement autoscaling_config: @@ -324,6 +334,9 @@ applications: target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数 # 3. 
Sampler 服务(可选,用于推理采样) - name: sampler-Qwen2.5-0.5B-Instruct @@ -343,8 +356,7 @@ applications: device_type: cuda device_mesh: device_type: cuda - mesh: [0] - mesh_dim_names: ['dp'] + dp_size: 1 # 数据并行大小 deployments: - name: SamplerManagement autoscaling_config: @@ -354,6 +366,9 @@ applications: ray_actor_options: num_cpus: 0.1 num_gpus: 1 # Sampler 需要独立 GPU + runtime_env: + env_vars: + DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数 ``` ## 配置项说明 @@ -375,11 +390,30 @@ applications: ```yaml device_group: name: model # 设备组名称 - ranks: [0, 1] # GPU 卡号列表 + ranks: [0, 1] # 物理 GPU 卡号列表 device_type: cuda # 设备类型:cuda / CPU device_mesh: device_type: cuda - mesh: [0, 1] # 网格中的设备索引 - mesh_dim_names: ['dp'] # 维度名称,常用:dp(数据并行), tp(张量并行), pp(流水线并行) + dp_size: 2 # 数据并行大小 + # tp_size: 1 # 张量并行大小(可选) + # pp_size: 1 # 流水线并行大小(可选) + # ep_size: 1 # 专家并行大小(可选) +``` + +**重要配置参数说明:** + +| 参数 | 类型 | 说明 | +|------|------|------| +| `ranks` | list[int] | **物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 | +| `dp_size` | int | 数据并行大小 | +| `tp_size` | int (可选) | 张量并行大小 | +| `pp_size` | int (可选) | 流水线并行大小 | +| `ep_size` | int (可选) | 专家并行大小(用于 MoE 模型) | + +**环境变量:** + +```bash +export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 每台物理机上的 GPU 总数(必须设置) +export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码 ``` diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 7aacc90e..d9bf4207 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -848,13 +848,13 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): Args: name: Checkpoint name or HuggingFace Hub model id. output_dir: Parent directory that contains the checkpoint folder. - If None **and** ``resume`` is False, downloads from Hub. - resume: If True, restore optimizer, lr_scheduler and RNG state + If None **and** ``load_optimizer`` is False, downloads from Hub. + load_optimizer: If True, restore optimizer, lr_scheduler and RNG state from the mcore sub-checkpoint for training resumption. **kwargs: Additional arguments (``adapter_name``, ``no_load_optim``, ``no_load_rng``, etc.). 
""" - resume = kwargs.pop('resume', False) + resume = kwargs.pop('load_optimizer', False) if output_dir is None and not resume: # Load from hub token = kwargs.pop('token', None) diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py index d1bdca03..1e476bbb 100644 --- a/src/twinkle/server/tinker/common/compat_base.py +++ b/src/twinkle/server/tinker/common/compat_base.py @@ -68,15 +68,43 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh): def clean_metrics(metrics: dict) -> dict: + import re + from numbers import Number + + def _to_float(v): + # python numeric / numpy scalar + if isinstance(v, (float, int, Number, np.generic, str)): + try: + return float(v) + except Exception: + return None + # 0-d torch tensor + if isinstance(v, torch.Tensor) and v.numel() == 1: + try: + return float(v.item()) + except Exception: + return None + return None + cleaned = {} for key, value in metrics.items(): + fv = _to_float(value) + if fv is not None: + cleaned[key] = fv + continue + + # handle common metric strings: "123 seconds", "1.23 iters/s" if isinstance(value, str): - import re - match = re.match(r'^([+-]?\d*\.?\d+)', value.strip()) - if match: - cleaned[key] = float(match.group(1)) - else: - cleaned[key] = value + s = value.strip() + if s: + try: + head, unit = s.split() # ignore unit/tail + cleaned[f'{key}/{unit}'] = float(head) + except Exception: + m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) + if m: + cleaned[key] = float(m.group(1)) + return cleaned diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 1d086ee4..2a119162 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -56,11 +56,6 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ - # adapter_config can be None; expanding with ** would raise TypeError and break Serve init. - # Normalize to {} so AdapterManagerMixin uses its default timeout/limits. - if adapter_config is None: - adapter_config = {} - app = FastAPI() @app.middleware('http') @@ -123,6 +118,37 @@ def __init__(self, self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + def _cleanup_adapter(self, adapter_name: str) -> None: + """Common adapter cleanup logic used by both manual unload and automatic expiration. + + This method handles: + 1. Clearing adapter state + 2. Removing adapter from model + 3. Unregistering from adapter manager + 4. Removing from server state + + Args: + adapter_name: Name of the adapter to clean up + """ + # Remove from model if it exists + if self.get_adapter_info(adapter_name): + # Clear adapter state + self.clear_adapter_state(adapter_name) + + self.model.remove_adapter(adapter_name) + # Unregister from adapter manager + self.unregister_adapter(adapter_name) + + # Remove from server state + self.state.unload_model(adapter_name) + + def _on_adapter_expired(self, adapter_name: str) -> None: + # Called from AdapterManagerMixin's countdown thread. + # Fail any pending tasks for this adapter/model. + self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') + # Perform common cleanup (without token since it's automatic) + self._cleanup_adapter(adapter_name) + @app.post('/create_model') async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture: """Create a new model adapter for training. 
@@ -146,31 +172,32 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) - async def _create_adapter(): try: if body.lora_config: - # Check adapter limit before creating - allowed, reason = self.check_adapter_limit(request.state.token, True) - if not allowed: - raise RuntimeError(reason) - # TODO: support more lora config parameters, train_unembed, etc. lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') adapter_name = self.get_adapter_name(adapter_name=model_id) - self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) - # Register adapter with rate limiter for lifecycle tracking - self.register_adapter(adapter_name, request.state.token) + # Register adapter FIRST (limit check happens inside register_adapter) + self.register_adapter(adapter_name, request.state.token, session_id=body.session_id) + + # Create adapter AFTER successful registration + self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) self.model.set_processor('InputProcessor', adapter_name=adapter_name) self.model.set_optimizer('Adam', adapter_name=adapter_name) + # Fresh adapter has no accumulated gradients. + self.set_adapter_state(adapter_name, 'grad_ready', False) + training_run_manager = create_training_run_manager(request.state.token) training_run_manager.save(model_id, body) return types.CreateModelResponse(model_id=model_id) except Exception: - # If adapter creation fails, decrement the count - self.check_adapter_limit(request.state.token, False) + # Ensure we don't leave stale grad state. + adapter_name = self.get_adapter_name(adapter_name=model_id) + self._cleanup_adapter(adapter_name) logger.error(traceback.format_exc()) return types.RequestFailedResponse( @@ -179,9 +206,10 @@ async def _create_adapter(): ) return await self.schedule_task( - _create_adapter(), + _create_adapter, model_id=model_id, token=request.state.token, + task_type='create_model', ) @app.post('/get_info') @@ -230,19 +258,15 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) - async def _do_unload(): # Only remove adapter, not the base model adapter_name = self.get_adapter_name(adapter_name=body.model_id) - if self.get_adapter_info(adapter_name): - self.model.remove_adapter(adapter_name) - # Unregister adapter from rate limiter - self.unregister_adapter(adapter_name) - # Decrement adapter count via rate limiter - self.check_adapter_limit(request.state.token, False) - self.state.unload_model(body.model_id) + # Use common cleanup logic + self._cleanup_adapter(adapter_name) return types.UnloadModelResponse(model_id=body.model_id) return await self.schedule_task( - _do_unload(), + _do_unload, model_id=body.model_id, token=request.state.token, + task_type='unload_model', ) @app.post('/forward') @@ -268,10 +292,6 @@ async def _do_forward(): self.touch_adapter(adapter_name) datum_list = body.forward_input.data - assert len(datum_list) >= self.device_mesh.data_world_size, (f'Batch size {len(datum_list)} must ' - f'be greater than data world size ' - f'{self.device_mesh.data_world_size}') - loss_fn_config = body.forward_input.loss_fn_config or {} output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name) @@ -288,13 +308,18 @@ async def _do_forward(): category=types.RequestErrorCategory.Server, ) - # Calculate input tokens for rate limiting - input_tokens = sum(len(d.model_input.to_ints()) for d in 
body.forward_input.data) + # Calculate input tokens and batch size for validation + datum_list = body.forward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) return await self.schedule_task( - _do_forward(), + _do_forward, model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward', ) @app.post('/forward_backward') @@ -324,18 +349,18 @@ async def _do_forward_backward(): self.touch_adapter(adapter_name) datum_list = body.forward_backward_input.data - assert len(datum_list) >= self.device_mesh.data_world_size, ( - f'Batch size {len(datum_list)} must be greater ' - f'than data world size {self.device_mesh.data_world_size}') - loss_fn = body.forward_backward_input.loss_fn loss_fn_config = body.forward_backward_input.loss_fn_config or {} # Unified forward_backward for both Megatron and Transformers output, loss = self.model.forward_backward( inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) - output_type = ('ImportanceSamplingLossReturn' - if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn') + if loss_fn == 'importance_sampling': + output_type = 'ImportanceSamplingLossReturn' + else: + output_type = 'CrossEntropyLossReturn' + # Mark gradients as ready after a successful forward_backward. + self.set_adapter_state(adapter_name, 'grad_ready', True) return types.ForwardBackwardOutput( loss_fn_output_type=output_type, loss_fn_outputs=output, @@ -348,13 +373,18 @@ async def _do_forward_backward(): category=types.RequestErrorCategory.Server, ) - # Calculate input tokens for rate limiting - input_tokens = sum(len(d.model_input.to_ints()) for d in body.forward_backward_input.data) + # Calculate input tokens and batch size for validation + datum_list = body.forward_backward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) return await self.schedule_task( - _do_forward_backward(), + _do_forward_backward, model_id=body.model_id, token=request.state.token, input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward_backward', ) @app.post('/optim_step') @@ -376,10 +406,18 @@ async def _do_optim(): adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) + # Disallow empty step (must have at least one forward_backward since last step) + if not self.get_adapter_state(adapter_name, 'grad_ready', False): + raise RuntimeError( + f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501 + ) + # Touch adapter to reset inactivity counter self.touch_adapter(adapter_name) self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) + # Clear grad-ready after a successful step. 
+ self.set_adapter_state(adapter_name, 'grad_ready', False) metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) return types.OptimStepResponse(metrics=metrics) except Exception: @@ -390,9 +428,10 @@ async def _do_optim(): ) return await self.schedule_task( - _do_optim(), + _do_optim, model_id=body.model_id, token=request.state.token, + task_type='optim_step', ) @app.post('/save_weights') @@ -440,9 +479,10 @@ async def _do_save(): ) return await self.schedule_task( - _do_save(), + _do_save, model_id=body.model_id, token=request.state.token, + task_type='save_weights', ) @app.post('/save_weights_for_sampler') @@ -504,9 +544,10 @@ async def _do_save_for_sampler(): ) return await self.schedule_task( - _do_save_for_sampler(), + _do_save_for_sampler, model_id=body.model_id, token=request.state.token, + task_type='save_weights_for_sampler', ) @app.post('/load_weights') @@ -545,6 +586,9 @@ async def _do_load(): load_optimizer=load_optimizer, adapter_name=adapter_name, token=token) + + # Loading a checkpoint should reset step readiness. + self.set_adapter_state(adapter_name, 'grad_ready', False) return types.LoadWeightsResponse(path=body.path, type='load_weights') except Exception: logger.error(traceback.format_exc()) @@ -554,9 +598,10 @@ async def _do_load(): ) return await self.schedule_task( - _do_load(), + _do_load, model_id=body.model_id, token=request.state.token, + task_type='load_weights', ) return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index 22f69ba5..bf4108c9 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -9,6 +9,7 @@ 3. Multi-user inference with rate limiting 4. Flexible sampling parameters """ +import os import traceback from fastapi import FastAPI, Request from ray import serve @@ -160,6 +161,13 @@ async def _do_sample(): checkpoint_manager = create_checkpoint_manager(token) adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) + # Validate adapter URI existence if provided + if not adapter_uri or not os.path.exists(adapter_uri): + return types.RequestFailedResponse( + error=f'Adapter URI {model_path} does not exist. Please check the model_path.', + category=types.RequestErrorCategory.User, + ) + # Convert tinker SamplingParams to twinkle SamplingParams if needed sampling_params = None if body.sampling_params: @@ -213,9 +221,10 @@ async def _do_sample(): # Calculate input tokens for rate limiting input_tokens = len(body.prompt.to_ints()) return await self.schedule_task( - _do_sample(), + _do_sample, token=request.state.token, input_tokens=input_tokens, + task_type='sample', ) return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 3044994c..2e669f56 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -31,6 +31,7 @@ def build_server_app(deploy_options: dict[str, Any], supported_models: list[types.SupportedModel] | None = None, + server_config: dict[str, Any] = {}, **kwargs): """Build and configure the Tinker-compatible server application. @@ -40,23 +41,12 @@ def build_server_app(deploy_options: dict[str, Any], Args: deploy_options: Ray Serve deployment configuration (num_replicas, etc.) 
        supported_models: List of supported base models for validation
+        server_config: Server configuration options (per_token_adapter_limit, etc.)
        **kwargs: Additional keyword arguments (route_prefix, etc.)

    Returns:
        Configured Ray Serve deployment bound with options
    """
-    # Normalize supported_models to objects; passing raw dicts can trigger internal errors
-    # when creating LoRA training clients via the tinker API.
-    if supported_models:
-        normalized = []
-        for item in supported_models:
-            if isinstance(item, types.SupportedModel):
-                normalized.append(item)
-            elif isinstance(item, dict):
-                normalized.append(types.SupportedModel(**item))
-            else:
-                raise TypeError(...)
-        supported_models = normalized
     app = FastAPI()

     @app.middleware('http')
@@ -76,18 +66,22 @@ class TinkerCompatServer:
        - Training run and checkpoint CRUD operations
     """

-    def __init__(self, supported_models: list[types.SupportedModel] | None = None, **kwargs) -> None:
+    def __init__(self,
+                 supported_models: list[types.SupportedModel] | None = None,
+                 server_config: dict[str, Any] = {},
+                 **kwargs) -> None:
         """Initialize the Tinker-compatible server.

         Args:
             supported_models: List of supported base models for validation
             **kwargs: Additional configuration (route_prefix, etc.)
         """
-        self.state = get_server_state()
+        # Pass server-level options (e.g. per_token_adapter_limit) into the shared server state
+        self.state = get_server_state(**server_config)
         # Disable proxy for internal requests to avoid routing through external proxies
         self.client = httpx.AsyncClient(timeout=None, trust_env=False)
         self.route_prefix = kwargs.get('route_prefix', '/api/v1')
-        self.supported_models = supported_models or [
+        self.supported_models = self.normalize_models(supported_models) or [
             types.SupportedModel(model_name='Qwen/Qwen2.5-0.5B-Instruct'),
             types.SupportedModel(model_name='Qwen/Qwen2.5-3B-Instruct'),
             types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'),
@@ -97,6 +91,20 @@ def __init__(self, supported_models: list[types.SupportedModel] | None = None, *
         # Lock for ModelScope config file operations (login writes, get_user_info reads)
         self._modelscope_config_lock = asyncio.Lock()

+    def normalize_models(self, supported_models):
+        # Normalize supported_models to objects; passing raw dicts can trigger internal errors
+        # when creating LoRA training clients via the tinker API.
+        if supported_models:
+            normalized = []
+            for item in supported_models:
+                if isinstance(item, types.SupportedModel):
+                    normalized.append(item)
+                elif isinstance(item, dict):
+                    normalized.append(types.SupportedModel(**item))
+                else:
+                    normalized.append(types.SupportedModel(model_name=item))
+            return normalized
+
     def _validate_base_model(self, base_model: str) -> None:
         """Validate that base_model is in supported_models list.
@@ -675,4 +683,5 @@ async def save_weights_for_sampler(self, request: Request, body: types.SaveWeigh base_model = self._get_base_model(body.model_id) return await self._proxy_to_model(request, 'save_weights_for_sampler', base_model) - return TinkerCompatServer.options(**deploy_options).bind(supported_models=supported_models, **kwargs) + return TinkerCompatServer.options(**deploy_options).bind( + supported_models=supported_models, server_config=server_config, **kwargs) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 486ec201..1660cd10 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -187,6 +187,27 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + def _on_adapter_expired(self, adapter_name: str) -> None: + """Handle adapter expiration by removing it from the model. + + This method is called automatically by AdapterManagerMixin when + an adapter exceeds its timeout or TTL. + + Args: + adapter_name: Name of the expired adapter to remove. + """ + # Remove from model if it exists + if self.get_adapter_info(adapter_name): + # Clear adapter state + self.clear_adapter_state(adapter_name) + # Unregister from adapter manager + self.unregister_adapter(adapter_name) + + # Remove from server state + self.state.unload_model(adapter_name) + # Remove adapter from model + self.model.remove_adapter(adapter_name) + @app.post('/create') def create(self, request: Request, body: CreateRequest): return {'status': 'ok'} @@ -485,16 +506,11 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): token = request.state.token training_run_manager = create_training_run_manager(token) - with self._adapter_lock: - self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - - # Register adapter for lifecycle tracking + # Register adapter FIRST (limit check happens inside register_adapter) self.register_adapter(adapter_name, token) - # Check adapter limit (raises if exceeded) - allowed, reason = self.check_adapter_limit(token, True) - if not allowed: - raise RuntimeError(reason) + # Create adapter AFTER successful registration + self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) # Save training run metadata (similar to tinker's create_model) # Create a training run config from the adapter configuration diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py index 3454b069..857c53f6 100644 --- a/src/twinkle/server/twinkle/sampler.py +++ b/src/twinkle/server/twinkle/sampler.py @@ -13,7 +13,7 @@ from fastapi import FastAPI, Request from pydantic import BaseModel, Field from ray import serve -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import twinkle from twinkle import DeviceGroup, DeviceMesh @@ -103,24 +103,6 @@ def build_sampler_app(model_id: str, **kwargs): """Build a sampler application for text generation inference. 
- Args: - model_id: Model identifier (e.g., "Qwen/Qwen2.5-7B-Instruct") - nproc_per_node: Number of GPU processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for parallelism - deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') - engine_args: Additional engine arguments for the sampler - adapter_config: Adapter lifecycle config (adapter_timeout, per_token_adapter_limit) - **kwargs: Additional arguments passed to the sampler - - Returns: - Ray Serve deployment bound with configuration - """ - app = FastAPI( - title='Twinkle Sampler', description='REST API for distributed text generation inference', version='1.0.0') - """Build a sampler application for text generation inference. - Args: model_id: Model identifier (e.g., "Qwen/Qwen2.5-7B-Instruct") nproc_per_node: Number of GPU processes per node @@ -200,17 +182,15 @@ def _on_adapter_expired(self, adapter_name: str, token: str) -> None: try: self.sampler.remove_adapter(adapter_name) logger.info(f'Removed expired adapter {adapter_name}') - self.check_adapter_limit(token, False) + # Adapter count is now tracked dynamically, no manual update needed except Exception as e: logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') - _adapter_config = adapter_config or {} - self._init_adapter_manager(**_adapter_config) - self.start_adapter_countdown() @staticmethod def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: if adapter_name is None or adapter_name == '': return None + return request.state.request_id + '-' + adapter_name @app.post('/create', response_model=CreateResponse) def create(self, request: Request) -> CreateResponse: @@ -304,13 +284,9 @@ def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> A from peft import LoraConfig config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - with self._adapter_lock: - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - self.register_adapter(full_adapter_name, token) - allowed, reason = self.check_adapter_limit(token, True) - if not allowed: - raise RuntimeError(reason) + + self.sampler.add_adapter_to_sampler(full_adapter_name, config) return AddAdapterResponse(adapter_name=full_adapter_name) diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 3244c931..04e56922 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: from twinkle.server.utils.state import ServerStateProxy - from twinkle.model import TwinkleModel from twinkle.utils.logger import get_logger @@ -31,61 +30,99 @@ class AdapterManagerMixin: that have been inactive for longer than the configured timeout period. Inheriting classes should: - 1. Have a `self.model` attribute for model operations - 2. Call _init_adapter_manager() in __init__ - 3. Optionally override _on_adapter_expired() to customize expiration handling + 1. Call _init_adapter_manager() in __init__ + 2. Override _on_adapter_expired() to customize expiration handling Attributes: _adapter_timeout: Timeout in seconds for inactive adapters. - model: Model instance for adapter operations (must be set by inheriting class). 
""" # Type hint for state attribute that inheriting classes must provide state: ServerStateProxy - model: TwinkleModel - def _init_adapter_manager(self, adapter_timeout: float = 1800.0, per_token_adapter_limit: int = 30) -> None: + def _init_adapter_manager( + self, + adapter_timeout: float = 1800.0, + per_token_adapter_limit: int = 30, + adapter_max_lifetime: float = 12 * 60 * 60, + ) -> None: """Initialize the adapter manager. This should be called in the __init__ of the inheriting class. Args: - adapter_timeout: Timeout in seconds for inactive adapters. - Default is 1800.0 (30 minutes). + adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration. + Default is 1800.0 (30 minutes). Adapters linked to sessions will expire + when their session hasn't been touched for this duration. per_token_adapter_limit: Maximum number of adapters per user token. Default is 30. + adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation. + Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled. """ self._adapter_timeout = adapter_timeout self._per_token_adapter_limit = per_token_adapter_limit + self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking - # Dict mapping adapter_name -> {'token': str, 'last_activity': float, 'created_at': float, - # 'inactivity_counter': int} + # Dict mapping adapter_name -> + # {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} self._adapter_records: dict[str, dict[str, Any]] = {} # Track adapter count per token self._adapter_counts: dict[str, int] = {} - self._adapter_lock = threading.Lock() # Countdown thread self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None: """Register a new adapter for lifecycle tracking. Args: adapter_name: Name of the adapter to register. token: User token that owns this adapter. + session_id: Optional session ID to associate with this adapter. + If provided, adapter will expire when the session expires. + + Raises: + RuntimeError: If adapter limit is exceeded for this token. + """ + # Check adapter limit BEFORE registering + allowed, reason = self.check_adapter_limit(token) + if not allowed: + raise RuntimeError(reason) + + current_time = time.time() + self._adapter_records[adapter_name] = { + 'token': token, + 'session_id': session_id, + 'last_activity': current_time, + 'created_at': current_time, + 'inactivity_counter': 0, + 'state': {}, + 'expiring': False, + } + logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...' + + (f' (session: {session_id})' if session_id else '')) + + def _is_session_alive(self, session_id: str) -> bool: + """Check if a session is still alive via state proxy. 
+ + Args: + session_id: Session ID to check + + Returns: + True if session is alive, False if expired or not found """ - with self._adapter_lock: - current_time = time.time() - self._adapter_records[adapter_name] = { - 'token': token, - 'last_activity': current_time, - 'created_at': current_time, - 'inactivity_counter': 0, - } - logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...') + if not session_id: + return True # No session association means always alive + + # Get session last heartbeat through proxy + last_heartbeat = self.state.get_session_last_heartbeat(session_id) + if last_heartbeat is None: + return False # Session doesn't exist + + # Check if session has timed out using adapter_timeout + return (time.time() - last_heartbeat) < self._adapter_timeout def unregister_adapter(self, adapter_name: str) -> bool: """Unregister an adapter from lifecycle tracking. @@ -96,14 +133,52 @@ def unregister_adapter(self, adapter_name: str) -> bool: Returns: True if adapter was found and removed, False otherwise. """ - with self._adapter_lock: - if adapter_name in self._adapter_records: - adapter_info = self._adapter_records.pop(adapter_name) - token = adapter_info.get('token') - logger.debug(f'[AdapterManager] Unregistered adapter {adapter_name} for ' - f"token {token[:8] if token else 'unknown'}...") - return True - return False + if adapter_name in self._adapter_records: + adapter_info = self._adapter_records.pop(adapter_name) + token = adapter_info.get('token') + logger.debug( + f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}..." + ) + return True + return False + + def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None: + """Set a per-adapter state value. + + This is intentionally generic so higher-level services can store + adapter-scoped state (e.g., training readiness) without maintaining + separate side maps. + """ + info = self._adapter_records.get(adapter_name) + if info is None: + return + state = info.setdefault('state', {}) + state[key] = value + + def get_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: + """Get a per-adapter state value.""" + info = self._adapter_records.get(adapter_name) + if info is None: + return default + state = info.get('state') or {} + return state.get(key, default) + + def pop_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any: + """Pop a per-adapter state value.""" + info = self._adapter_records.get(adapter_name) + if info is None: + return default + state = info.get('state') + if not isinstance(state, dict): + return default + return state.pop(key, default) + + def clear_adapter_state(self, adapter_name: str) -> None: + """Clear all per-adapter state values.""" + info = self._adapter_records.get(adapter_name) + if info is None: + return + info['state'] = {} def touch_adapter(self, adapter_name: str) -> bool: """Update adapter activity timestamp to prevent timeout. @@ -114,12 +189,14 @@ def touch_adapter(self, adapter_name: str) -> bool: Returns: True if adapter was found and touched, False otherwise. 
""" - with self._adapter_lock: - if adapter_name in self._adapter_records: - self._adapter_records[adapter_name]['last_activity'] = time.time() - self._adapter_records[adapter_name]['inactivity_counter'] = 0 - return True + info = self._adapter_records.get(adapter_name) + if not info: + return False + if info.get('expiring'): return False + info['last_activity'] = time.time() + info['inactivity_counter'] = 0 + return True def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None: """Get information about a registered adapter. @@ -130,42 +207,21 @@ def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None: Returns: Dict with adapter information or None if not found. """ - with self._adapter_lock: - return self._adapter_records.get(adapter_name) + return self._adapter_records.get(adapter_name) - def list_adapters(self, token: str | None = None) -> list[str]: - """List all registered adapters, optionally filtered by token. - - Args: - token: Optional user token to filter by. - - Returns: - List of adapter names. - """ - with self._adapter_lock: - if token is None: - return list(self._adapter_records.keys()) - return [name for name, info in self._adapter_records.items() if info.get('token') == token] - - def _on_adapter_expired(self, adapter_name: str, token: str) -> None: + def _on_adapter_expired(self, adapter_name: str) -> None: """Hook method called when an adapter expires. - Default implementation removes the adapter from the model and updates adapter count. - This is called from the countdown thread, so be careful with blocking operations. + This method must be overridden by inheriting classes to handle + adapter expiration logic. The base implementation raises NotImplementedError. Args: adapter_name: Name of the expired adapter. - token: User token that owns this adapter. - """ - try: - # Remove adapter from model - self.model.remove_adapter(adapter_name) - logger.info(f'[AdapterManager] Removed expired adapter {adapter_name} for token {token[:8]}...') - # Decrement adapter count - self.check_adapter_limit(token, False) - except Exception as e: - logger.warning(f'[AdapterManager] Failed to remove expired adapter {adapter_name}: {e}') + Raises: + NotImplementedError: If not overridden by inheriting class. + """ + raise NotImplementedError(f'_on_adapter_expired must be implemented by {self.__class__.__name__}') @staticmethod def get_adapter_name(adapter_name: str) -> str: @@ -182,30 +238,11 @@ def get_adapter_name(adapter_name: str) -> str: return adapter_name def assert_adapter_exists(self, adapter_name: str) -> None: - """Validate that an adapter exists. - - Args: - adapter_name: The adapter name to check - - Raises: - AssertionError: If adapter doesn't exist - """ - assert adapter_name and self.get_adapter_info(adapter_name) is not None, \ + """Validate that an adapter exists and is not expiring.""" + info = self._adapter_records.get(adapter_name) + assert adapter_name and info is not None and not info.get('expiring'), \ f'Adapter {adapter_name} not found' - def assert_adapter_valid(self, adapter_name: str | None) -> None: - """Validate that an adapter name is valid. - - Args: - adapter_name: The adapter name to validate (can be None or empty) - - Raises: - AssertionError: If adapter name is invalid - """ - assert (adapter_name is None or adapter_name == '' or - self.get_adapter_info(adapter_name) is not None), \ - f'Adapter {adapter_name} is invalid' - def _adapter_countdown_loop(self) -> None: """Background thread that monitors and handles inactive adapters. 
@@ -218,29 +255,74 @@ def _adapter_countdown_loop(self) -> None: while self._adapter_countdown_running: try: time.sleep(1) - - # Find and process expired adapters - expired_adapters = [] - with self._adapter_lock: - for adapter_name, info in list(self._adapter_records.items()): - # Increment inactivity counter + now = time.time() + + expired_adapters: list[tuple[str, str | None]] = [] + # Create snapshot to avoid modification during iteration + adapter_snapshot = list(self._adapter_records.items()) + for adapter_name, info in adapter_snapshot: + if info.get('expiring'): + continue + + session_id = info.get('session_id') + created_at = info.get('created_at') + + # Check TTL for both cases + exceeded_ttl = ( + self._adapter_max_lifetime and self._adapter_max_lifetime > 0 + and (now - created_at) > self._adapter_max_lifetime) + + # Different logic based on session association + if session_id: + # Has session: check session expiration and TTL + session_expired = not self._is_session_alive(session_id) + should_expire = session_expired or exceeded_ttl + logger.debug( + f'[AdapterManager] Adapter {adapter_name} session expiration check ' + f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})' # noqa:E501 + ) + expiration_reasons = [] + if exceeded_ttl: + expiration_reasons.append('ttl_exceeded') + if session_expired: + expiration_reasons.append('session_expired') + else: + # No session: check inactivity timeout and TTL info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 + exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout + should_expire = exceeded_ttl or exceeded_inactivity + logger.debug( + f'[AdapterManager] Adapter {adapter_name} inactivity check ' + f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})' # noqa:E501 + ) + expiration_reasons = [] + if exceeded_ttl: + expiration_reasons.append('ttl_exceeded') + if exceeded_inactivity: + expiration_reasons.append('inactivity_timeout') + + if should_expire: + info['expiring'] = True + info['state'] = {} # best-effort clear + token = info.get('token') + expired_adapters.append((adapter_name, token)) - # Check if adapter has timed out - if info['inactivity_counter'] > self._adapter_timeout: - token = info.get('token') - expired_adapters.append((adapter_name, token)) - self._adapter_records.pop(adapter_name, None) - logger.debug(f'[AdapterManager] Adapter {adapter_name} timed out after ' - f"{info['inactivity_counter']}s of inactivity") - - # Call hook method outside the lock for adapter_name, token in expired_adapters: + success = False try: - self._on_adapter_expired(adapter_name, token) + self._on_adapter_expired(adapter_name) + logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' + f"(reasons={','.join(expiration_reasons)}, session={session_id})") + success = True except Exception as e: - logger.warning(f'[AdapterManager] Error in _on_adapter_expired() ' - f'for {adapter_name}: {e}') + logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}') + finally: + if success: + self._adapter_records.pop(adapter_name, None) + else: + info = self._adapter_records.get(adapter_name) + if info is not None: + info['expiring'] = False except Exception as e: logger.warning(f'[AdapterManager] Error in countdown loop: {e}') @@ -272,63 +354,24 @@ def stop_adapter_countdown(self) -> None: self._adapter_countdown_thread.join(timeout=2.0) logger.debug('[AdapterManager] Countdown thread 
stopped') - def get_adapter_stats(self) -> dict[str, Any]: - """Get adapter manager statistics. - - Returns: - Dict with registered adapter count and configuration. - """ - with self._adapter_lock: - return { - 'registered_adapters': len(self._adapter_records), - 'tracked_adapter_counts': len(self._adapter_counts), - 'countdown_running': self._adapter_countdown_running, - 'adapter_timeout_seconds': self._adapter_timeout, - 'per_token_adapter_limit': self._per_token_adapter_limit, - } - - def check_adapter_limit(self, token: str, add: bool) -> tuple[bool, str | None]: - """Check and update adapter count for a user token. + def check_adapter_limit(self, token: str) -> tuple[bool, str | None]: + """Check adapter count for a user token. This method enforces per-user adapter limits to prevent resource exhaustion. + Counts adapters directly from _adapter_records instead of using state storage. Args: - token: User token to check/update. - add: True to add an adapter (increment count), False to remove (decrement count). + token: User token to check. Returns: Tuple of (allowed: bool, reason: Optional[str]). If allowed is False, reason contains the explanation. """ - user_key = token + '_' + 'model_adapter' - with self._adapter_lock: - current_count = self.state.get_config(user_key) or 0 - - if add: - # Check if adding would exceed limit - if current_count >= self._per_token_adapter_limit: - return False, f'Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters' - # Increment count in global state - self.state.add_config(user_key, current_count + 1) - return True, None - else: - # Decrement count in global state - if current_count > 0: - current_count -= 1 - self.state.add_config(user_key, current_count) - if current_count <= 0: - self.state.pop_config(user_key) - return True, None - - def get_adapter_count(self, token: str) -> int: - """Get current adapter count for a user token. - - Args: - token: User token to query. - - Returns: - Current number of adapters for this token. 
- """ - user_key = token + '_' + 'model_adapter' - with self._adapter_lock: - return self.state.get_config(user_key) or 0 + # Count adapters directly from _adapter_records + current_count = sum(1 for record in self._adapter_records.values() + if record.get('token') == token and not record.get('expiring', False)) + + # Check if current count exceeds limit + if current_count >= self._per_token_adapter_limit: + return False, f'Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters' + return True, None diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py index faf4e9a7..e191d80a 100644 --- a/src/twinkle/server/utils/state.py +++ b/src/twinkle/server/utils/state.py @@ -31,7 +31,8 @@ class ServerState: def __init__( self, expiration_timeout: float = 86400.0, # 24 hours in seconds - cleanup_interval: float = 3600.0) -> None: # 1 hour in seconds + cleanup_interval: float = 3600.0, + **kwargs) -> None: # 1 hour in seconds # Session tracking self.sessions: dict[str, dict[str, Any]] = {} # Model registration @@ -67,6 +68,7 @@ def create_session(self, payload: dict[str, Any]) -> str: 'user_metadata': payload.get('user_metadata') or {}, 'sdk_version': payload.get('sdk_version'), 'created_at': datetime.now().isoformat(), + 'last_heartbeat': time.time(), } return session_id @@ -85,6 +87,21 @@ def touch_session(self, session_id: str) -> bool: self.sessions[session_id]['last_heartbeat'] = time.time() return True + def get_session_last_heartbeat(self, session_id: str) -> float | None: + """ + Get the last heartbeat timestamp for a session. + + Args: + session_id: The session ID to query + + Returns: + Last heartbeat timestamp, or None if session doesn't exist + """ + session_info = self.sessions.get(session_id) + if not session_info: + return None + return session_info.get('last_heartbeat') + # ----- Model Registration ----- def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: @@ -461,6 +478,9 @@ def create_session(self, payload: dict[str, Any]) -> str: def touch_session(self, session_id: str) -> bool: return ray.get(self._actor.touch_session.remote(session_id)) + def get_session_last_heartbeat(self, session_id: str) -> float | None: + return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) + # ----- Model Registration ----- def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str: @@ -553,7 +573,9 @@ def get_cleanup_stats(self) -> dict[str, Any]: return ray.get(self._actor.get_cleanup_stats.remote()) -def get_server_state(actor_name: str = 'twinkle_server_state', auto_start_cleanup: bool = True) -> ServerStateProxy: +def get_server_state(actor_name: str = 'twinkle_server_state', + auto_start_cleanup: bool = True, + **server_state_kwargs) -> ServerStateProxy: """ Get or create the ServerState Ray actor. 
@@ -563,6 +585,8 @@ def get_server_state(actor_name: str = 'twinkle_server_state', auto_start_cleanu Args: actor_name: Name for the Ray actor (default: 'twinkle_server_state') auto_start_cleanup: Whether to automatically start the cleanup task (default: True) + **server_state_kwargs: Additional keyword arguments passed to ServerState constructor + (e.g., expiration_timeout, cleanup_interval, per_token_adapter_limit) Returns: A ServerStateProxy for interacting with the actor @@ -572,7 +596,7 @@ def get_server_state(actor_name: str = 'twinkle_server_state', auto_start_cleanu except ValueError: try: _ServerState = ray.remote(ServerState) - actor = _ServerState.options(name=actor_name, lifetime='detached').remote() + actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs) # Start cleanup task for newly created actor if auto_start_cleanup: try: diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index 6ef5fe66..4d272a07 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -10,11 +10,13 @@ from __future__ import annotations import asyncio +import time import traceback import uuid -from dataclasses import dataclass, field +from collections import deque +from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Dict, Optional from twinkle.utils.logger import get_logger from .rate_limiter import RateLimiter @@ -59,8 +61,7 @@ class TaskQueueConfig: enabled: Whether rate limiting is enabled. token_cleanup_multiplier: Multiplier for token cleanup threshold. token_cleanup_interval: How often to run cleanup task (seconds). - per_token_adapter_limit: Maximum number of adapters per user token. - adapter_timeout: Timeout in seconds for inactive adapters (default 30 minutes). + max_input_tokens: Maximum allowed input tokens per request (default 10000). """ rps_limit: float = 100.0 # 10 requests per second tps_limit: float = 10000.0 # 10000 input tokens per second @@ -70,6 +71,7 @@ class TaskQueueConfig: # Remove tokens after 10x window inactivity token_cleanup_multiplier: float = 10.0 token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds + max_input_tokens: int = 10000 # Maximum input tokens per request @classmethod def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig: @@ -84,6 +86,7 @@ def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig - enabled: whether rate limiting is enabled - token_cleanup_multiplier: multiplier for token cleanup threshold - token_cleanup_interval: cleanup task interval in seconds + - max_input_tokens: maximum input tokens per request Returns: TaskQueueConfig instance with values from dict merged with defaults. 
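For reference, the extended config is still built the same way from a plain dict; `max_input_tokens` is simply one more recognized key. The values below are examples only, not recommended settings.

# --- illustrative sketch (not part of the diff) ---
from twinkle.server.utils.task_queue import TaskQueueConfig

config = TaskQueueConfig.from_dict({
    'rps_limit': 20,
    'tps_limit': 5000,
    'max_input_tokens': 4096,  # new field: cap on input tokens per request
})
print(config.max_input_tokens)  # -> 4096
print(config.queue_timeout)     # keys absent from the dict keep their defaults
# --- end sketch ---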
@@ -104,9 +107,23 @@ def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier']) if 'token_cleanup_interval' in config_dict: config.token_cleanup_interval = float(config_dict['token_cleanup_interval']) + if 'max_input_tokens' in config_dict: + config.max_input_tokens = int(config_dict['max_input_tokens']) return config +@dataclass +class _QueuedTask: + request_id: str + coro_factory: Callable[[], Coroutine] + model_id: str | None + token: str | None + input_tokens: int + task_type: str | None + created_at: float + first_rate_limited_at: float | None = None + + class TaskQueueMixin: """Mixin providing task queue management, rate limiting, and status tracking. @@ -130,7 +147,7 @@ async def my_endpoint(self, request, body): async def _do_work(): return await some_operation() return await self.schedule_task( - _do_work(), + _do_work, model_id=body.model_id, token=request.state.token, input_tokens=len(body.tokens) @@ -147,7 +164,10 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None: config: Optional TaskQueueConfig. If None, uses default config. """ self._task_queue_config = config or TaskQueueConfig() - self._task_queue: asyncio.Queue = asyncio.Queue() + # Per-key queues, but executed by a single global worker. + self._task_queues: dict[str, asyncio.Queue] = {} + self._queue_order: Deque[str] = deque() + self._new_task_event: asyncio.Event = asyncio.Event() # Initialize rate limiter for RPS/TPS control self._rate_limiter = RateLimiter( @@ -160,77 +180,258 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None: # Start the rate limiter cleanup task self._rate_limiter.start_cleanup_task() + # Single worker to ensure model operations remain serial. self._worker_task: asyncio.Task | None = None self._worker_started = False self._worker_start_lock = asyncio.Lock() - async def _ensure_worker_started(self) -> None: - """Ensure the background worker is running. + # Event loop reference for thread-safe callbacks (e.g., adapter expiration thread) + self._event_loop: asyncio.AbstractEventLoop | None = None - Thread-safe: Uses asyncio.Lock to prevent race conditions when - multiple concurrent requests try to start the worker simultaneously. 
- """ - # Fast path: avoid lock if already started - if self._worker_started: + @staticmethod + def _queue_key( + model_id: str | None, + token: str | None, + ) -> str: + if model_id: + return f'model:{model_id}' + if token: + return f'token:{token}' + return 'default' + + async def _ensure_worker_started(self) -> None: + """Ensure the single background worker is running.""" + if self._worker_started and self._worker_task is not None and not self._worker_task.done(): return - # Slow path: acquire lock to safely check and start async with self._worker_start_lock: - # Double-check after acquiring lock (another coroutine might have started it) - if not self._worker_started: - logger.debug('[TaskQueue] Starting background worker...') - self._worker_task = asyncio.create_task(self._queue_worker()) - self._worker_started = True - logger.debug(f'[TaskQueue] Background worker started: {self._worker_task}') + if self._worker_started and self._worker_task is not None and not self._worker_task.done(): + return + self._worker_task = asyncio.create_task(self._queue_worker()) + self._worker_started = True + + def _ensure_queue_registered(self, queue_key: str) -> None: + if queue_key not in self._task_queues: + self._task_queues[queue_key] = asyncio.Queue() + if queue_key not in self._queue_order: + self._queue_order.append(queue_key) async def _queue_worker(self) -> None: - """Background worker that processes tasks from the queue serially. + """Single background worker that processes tasks serially across all queues. - This worker runs indefinitely, pulling tasks from the queue and - executing them one at a time. This ensures thread-safe execution - of model operations that cannot be parallelized. + Selection policy: round-robin across queue keys. If a task is rate-limited + at execution time, it is requeued and the worker tries other queues. """ logger.debug('[TaskQueue] Worker started') while True: try: - # Wait for a task from the queue - logger.debug(f'[TaskQueue] Waiting for task... 
(queue size: {self._task_queue.qsize()})') - request_id, coro, model_id = await self._task_queue.get() - - logger.debug(f'[TaskQueue] Processing task {request_id}') - try: - # Update status to RUNNING + # Wait until there is at least one queue with a task + while True: + if any(q.qsize() > 0 for q in self._task_queues.values()): + break + self._new_task_event.clear() + await self._new_task_event.wait() + + executed_any = False + # Try each queue at most once per loop for fairness + for _ in range(len(self._queue_order)): + queue_key = self._queue_order[0] + self._queue_order.rotate(-1) + + q = self._task_queues.get(queue_key) + if q is None: + continue + + try: + task: _QueuedTask = q.get_nowait() + except asyncio.QueueEmpty: + continue + + now = time.monotonic() + + # Global queue timeout + if (now - task.created_at) > self._task_queue_config.queue_timeout: + error_payload = { + 'error': f'Queue timeout exceeded: waited {now - task.created_at:.2f}s', + 'category': 'Server' + } + self.state.store_future_status( + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result=error_payload, + queue_state=QueueState.PAUSED_CAPACITY.value, + queue_state_reason=error_payload['error'], + ) + q.task_done() + continue + + # Rate limiting check has been moved to schedule_task(), so tasks here should pass rate limits + + # Execute + executed_any = True self.state.store_future_status( - request_id, TaskStatus.RUNNING.value, model_id, queue_state=QueueState.ACTIVE.value) - - # Execute the task - result = await coro - - logger.debug(f'[TaskQueue] Task {request_id} completed successfully') - # Store completed result - self.state.store_future_status(request_id, TaskStatus.COMPLETED.value, model_id, result=result) - except Exception: - # Store error result - logger.debug(f'[TaskQueue] Task {request_id} failed with error') - error_payload = {'error': traceback.format_exc(), 'category': 'Server'} - self.state.store_future_status(request_id, TaskStatus.FAILED.value, model_id, result=error_payload) - finally: - self._task_queue.task_done() + task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value) + + try: + coro = task.coro_factory() + result = await coro + self.state.store_future_status( + task.request_id, + TaskStatus.COMPLETED.value, + task.model_id, + result=result, + queue_state=QueueState.ACTIVE.value) + except Exception: + error_payload = {'error': traceback.format_exc(), 'category': 'Server'} + self.state.store_future_status( + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result=error_payload, + queue_state=QueueState.ACTIVE.value) + finally: + q.task_done() + + # Keep serial semantics: execute at most one runnable task per loop + break + + if not executed_any: + # All available tasks were rate-limited; avoid busy looping. 
+ await asyncio.sleep(min(self._task_queue_config.window_seconds, 0.1)) except asyncio.CancelledError: logger.warning('[TaskQueue] Worker cancelled') break except Exception: - # Log but don't crash the worker logger.warning('Error in task queue worker') continue + async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: + q = self._task_queues.get(queue_key) + if q is None: + return + + drained: list[_QueuedTask] = [] + while True: + try: + drained.append(q.get_nowait()) + except asyncio.QueueEmpty: + break + + for task in drained: + error_payload = {'error': reason, 'category': 'Server'} + self.state.store_future_status( + task.request_id, + TaskStatus.FAILED.value, + task.model_id, + result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=reason, + ) + q.task_done() + + # Remove queue structures + self._task_queues.pop(queue_key, None) + try: + while queue_key in self._queue_order: + self._queue_order.remove(queue_key) + except ValueError: + pass + + def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None: + """Fail and drop queued tasks for a model. Safe to call from non-async threads.""" + queue_key = self._queue_key(model_id=model_id, token=None) + if self._event_loop is None: + # Best-effort: nothing we can do safely without a loop. + logger.warning(f'[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}') + return + + def _schedule() -> None: + asyncio.create_task(self._fail_queue_tasks_async(queue_key, reason)) + + self._event_loop.call_soon_threadsafe(_schedule) + + async def _perform_preflight_checks( + self, + request_id: str, + model_id: str | None, + token: str | None, + input_tokens: int, + batch_size: int | None = None, + data_world_size: int | None = None, + ) -> dict[str, Any] | None: + """Perform pre-flight checks including rate limiting and token validation. + + Args: + request_id: The request ID for status tracking. + model_id: Optional model_id for error reporting. + token: Optional user token for rate limiting. + input_tokens: Number of input tokens for validation. + batch_size: Optional batch size for validation. + data_world_size: Optional data world size for batch size validation. + + Returns: + None if checks pass, or error response dict if checks fail. 
+ """ + if not token or not self._task_queue_config.enabled: + return None + + # Check max input tokens + if input_tokens > self._task_queue_config.max_input_tokens: + error_msg = f'Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})' # noqa: E501 + error_payload = {'error': error_msg, 'category': 'User'} + self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + # Check batch size if provided + if batch_size is not None and data_world_size is not None: + if batch_size < data_world_size: + error_msg = f'Batch size {batch_size} must be greater than or equal to data world size {data_world_size}' # noqa: E501 + error_payload = {'error': error_msg, 'category': 'User'} + self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.UNKNOWN.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + # Check rate limits + allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) + if not allowed: + error_msg = f'Rate limit exceeded: {reason}' + error_payload = {'error': error_msg, 'category': 'User'} + self.state.store_future_status( + request_id, + TaskStatus.FAILED.value, + model_id, + result=error_payload, + queue_state=QueueState.PAUSED_RATE_LIMIT.value, + queue_state_reason=error_msg, + ) + return {'request_id': request_id, 'model_id': model_id} + + return None + async def schedule_task( self, - coro: Coroutine, + coro_factory: Callable[[], Coroutine], model_id: str | None = None, token: str | None = None, input_tokens: int = 0, + batch_size: int | None = None, + data_world_size: int | None = None, + task_type: str | None = None, ) -> dict[str, Any]: """Schedule an async task with rate limiting and status tracking. @@ -244,48 +445,66 @@ async def schedule_task( 3. Execute tasks serially through a queue Args: - coro: The coroutine to execute. + coro_factory: Factory that creates the coroutine to execute. The coroutine + will be created only after passing rate limiting and when it's time + to execute the queued task. model_id: Optional model_id to associate with the result. token: Optional user token for rate limiting. input_tokens: Number of input tokens for tps rate limiting. + batch_size: Optional batch size for validation. + data_world_size: Optional data world size for batch size validation. + task_type: Optional task type for logging/observability. Returns: Dict containing request_id and model_id for future retrieval. """ + # Generate request_id first so it can be included in error responses request_id = f'req_{uuid.uuid4().hex}' - logger.debug(f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, ' - f'enabled={self._task_queue_config.enabled}') + # 1. Pre-flight checks: rate limiting, max token validation, and batch size validation + preflight_result = await self._perform_preflight_checks(request_id, model_id, token, input_tokens, batch_size, + data_world_size) + if preflight_result is not None: + return preflight_result + + if self._event_loop is None: + self._event_loop = asyncio.get_running_loop() - # 1. 
Register PENDING status FIRST (fixes race condition) + logger.debug( + f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}' # noqa: E501 + ) + + # 2. Register PENDING status FIRST self.state.store_future_status( request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value) - # 2. Check rate limiting if enabled and token provided - if self._task_queue_config.enabled and token: - logger.debug(f'[TaskQueue] Checking rate limit for token={token[:8]}... input_tokens={input_tokens}') - allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens) - if not allowed: - logger.debug(f'[TaskQueue] Rate limited: {reason}') - self.state.store_future_status( - request_id, - TaskStatus.RATE_LIMITED.value, - model_id, - reason=reason, - queue_state=QueueState.PAUSED_RATE_LIMIT.value, - queue_state_reason=reason) - return {'request_id': request_id, 'model_id': model_id} - logger.debug('[TaskQueue] Rate limit check passed') + # 3. Route to per-model/per-token queue + queue_key = self._queue_key(model_id=model_id, token=token) + self._ensure_queue_registered(queue_key) - # 3. Ensure worker is started + # 4. Ensure worker is started await self._ensure_worker_started() - # 4. Put task in queue and update status - logger.debug(f'[TaskQueue] Adding task {request_id} to queue (current size: {self._task_queue.qsize()})') - await self._task_queue.put((request_id, coro, model_id)) + # 5. Put task in queue and update status + q = self._task_queues[queue_key] + logger.debug( + f'[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}' # noqa: E501 + ) + await q.put( + _QueuedTask( + request_id=request_id, + coro_factory=coro_factory, + model_id=model_id, + token=token, + input_tokens=input_tokens, + task_type=task_type, + created_at=time.monotonic(), + )) self.state.store_future_status( request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value) - logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {self._task_queue.qsize()}') + logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}') + + self._new_task_event.set() return {'request_id': request_id, 'model_id': model_id} @@ -296,8 +515,9 @@ def get_queue_stats(self) -> dict[str, Any]: Dict with queue size and worker status. 
""" return { - 'queue_size': self._task_queue.qsize(), - 'worker_running': self._worker_started and self._worker_task is not None, + 'queue_size': sum(q.qsize() for q in self._task_queues.values()), + 'queue_count': len(self._task_queues), + 'worker_running': self._worker_task is not None and not self._worker_task.done(), 'rate_limit_config': { 'rps_limit': self._task_queue_config.rps_limit, 'tps_limit': self._task_queue_config.tps_limit, @@ -341,4 +561,10 @@ async def shutdown_task_queue(self) -> None: except asyncio.CancelledError: pass + self._worker_task = None + self._worker_started = False + + self._task_queues.clear() + self._queue_order.clear() + logger.debug('[TaskQueue] Task queue shutdown complete') diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index e600023c..5a6928e9 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -21,7 +21,7 @@ def init_tinker_compat_client(base_url: str | None = None, api_key: str | None = # Apply patch to bypass tinker:// prefix validation patch_tinker() - if api_key is None: + if not api_key: api_key = get_api_key() if base_url and not base_url.startswith(('http://', 'https://')): diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index f43ecea9..3cd2b564 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -10,13 +10,11 @@ # ============================================================================ from typing import Callable, Type, Union - +from twinkle_client.http import http_post, heartbeat_manager from twinkle.dataset import Dataset from twinkle.processor import InputProcessor -from twinkle_client.http import heartbeat_manager, http_post - -class DataLoader: +class DataLoader(object): """Client wrapper for DataLoader that calls server HTTP endpoints.""" def __init__(self, dataset: Union[Dataset, Callable], **kwargs): @@ -28,11 +26,9 @@ def __init__(self, dataset: Union[Dataset, Callable], **kwargs): json_data={ 'processor_type': 'dataloader', 'class_type': 'DataLoader', - **{ - 'dataset': dataset - }, - **kwargs - }) + **{'dataset': dataset}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -43,6 +39,7 @@ def __del__(self): except: pass + def __len__(self): response = http_post( url=f'{self.server_url}/processors/call', @@ -50,9 +47,11 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs): response = http_post( @@ -60,13 +59,13 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputPro json_data={ 'processor_id': self.processor_id, 'function': 'set_processor', - **{ - 'processor_cls': processor_cls - }, + **{'processor_cls': processor_cls}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __iter__(self): response = http_post( @@ -75,16 +74,19 @@ def __iter__(self): 'processor_id': self.processor_id, 'function': '__iter__', **{}, - }) + } + ) response.raise_for_status() return self - + def __next__(self): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 
'function': '__next__', - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index 351a5a3b..3d5b5062 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -10,14 +10,14 @@ # ============================================================================ from typing import Any, Callable, Dict, Type, Union - -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.preprocessor import DataFilter, Preprocessor +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta +from twinkle.preprocessor import DataFilter +from twinkle.preprocessor import Preprocessor from twinkle.template import Template -from twinkle_client.http import heartbeat_manager, http_post - -class Dataset: +class Dataset(object): """Client wrapper for Dataset that calls server HTTP endpoints.""" def __init__(self, dataset_meta: DatasetMeta, **kwargs): @@ -29,11 +29,9 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_type': 'dataset', 'class_type': 'Dataset', - **{ - 'dataset_meta': dataset_meta - }, - **kwargs - }) + **{'dataset_meta': dataset_meta}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -44,19 +42,20 @@ def __del__(self): except: pass + def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', - **{ - 'template_func': template_func - }, + **{'template_func': template_func}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def encode(self, add_generation_prompt: bool = False, **kwargs): response = http_post( @@ -64,13 +63,13 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): json_data={ 'processor_id': self.processor_id, 'function': 'encode', - **{ - 'add_generation_prompt': add_generation_prompt - }, + **{'add_generation_prompt': add_generation_prompt}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def check(self, **kwargs): response = http_post( @@ -80,49 +79,39 @@ def check(self, **kwargs): 'function': 'check', **{}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + - def map(self, - preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], - dataset_meta: DatasetMeta = None, - init_args: Dict[str, Any] = None, - **kwargs): + def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'map', - **{ - 'preprocess_func': preprocess_func, - 'dataset_meta': dataset_meta, - 'init_args': init_args - }, + **{'preprocess_func': preprocess_func, 'dataset_meta': dataset_meta, 'init_args': init_args}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + - def filter(self, - filter_func: 
Union[Callable, str, Type[DataFilter], DataFilter], - dataset_meta: DatasetMeta = None, - init_args: Dict[str, Any] = None, - **kwargs): + def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'filter', - **{ - 'filter_func': filter_func, - 'dataset_meta': dataset_meta, - 'init_args': init_args - }, + **{'filter_func': filter_func, 'dataset_meta': dataset_meta, 'init_args': init_args}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( @@ -130,26 +119,26 @@ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', - **{ - 'dataset_meta': dataset_meta - }, + **{'dataset_meta': dataset_meta}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + - def mix_dataset(self, interleave=True): + def mix_dataset(self, interleave = True): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'mix_dataset', - **{ - 'interleave': interleave - }, - }) + **{'interleave': interleave}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __getitem__(self, idx): response = http_post( @@ -157,12 +146,12 @@ def __getitem__(self, idx): json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', - **{ - 'idx': idx - }, - }) + **{'idx': idx}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __len__(self): response = http_post( @@ -171,6 +160,8 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py index f8bd650a..347d1012 100644 --- a/src/twinkle_client/dataset/iterable_dataset.py +++ b/src/twinkle_client/dataset/iterable_dataset.py @@ -9,12 +9,11 @@ # 2. 
Run: python client_tools/client_generator.py # ============================================================================ +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta from torch.utils.data import IterableDataset -from twinkle.dataset import Dataset, DatasetMeta -from twinkle_client.http import heartbeat_manager, http_post - - class IterableDataset(IterableDataset): """Client wrapper for IterableDataset that calls server HTTP endpoints.""" @@ -27,11 +26,9 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): json_data={ 'processor_type': 'dataset', 'class_type': 'IterableDataset', - **{ - 'dataset_meta': dataset_meta - }, - **kwargs - }) + **{'dataset_meta': dataset_meta}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -42,19 +39,20 @@ def __del__(self): except: pass + def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', - **{ - 'dataset_meta': dataset_meta - }, + **{'dataset_meta': dataset_meta}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __len__(self): response = http_post( @@ -63,9 +61,11 @@ def __len__(self): 'processor_id': self.processor_id, 'function': '__len__', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __getitem__(self, idx): response = http_post( @@ -73,12 +73,12 @@ def __getitem__(self, idx): json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', - **{ - 'idx': idx - }, - }) + **{'idx': idx}, + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __iter__(self): response = http_post( @@ -87,16 +87,19 @@ def __iter__(self): 'processor_id': self.processor_id, 'function': '__iter__', **{}, - }) + } + ) response.raise_for_status() return self - + def __next__(self): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py index 8383e55a..ce2d918d 100644 --- a/src/twinkle_client/dataset/iterable_packing_dataset.py +++ b/src/twinkle_client/dataset/iterable_packing_dataset.py @@ -9,23 +9,17 @@ # 2. 
Run: python client_tools/client_generator.py # ============================================================================ -from torch.utils.data import IterableDataset from typing import Type, Union - -from twinkle.dataset import Dataset, DatasetMeta +from twinkle_client.http import http_post, heartbeat_manager +from twinkle.dataset import Dataset +from twinkle.dataset import DatasetMeta from twinkle.template import Template -from twinkle_client.http import heartbeat_manager, http_post - +from torch.utils.data import IterableDataset class IterablePackingDataset(IterableDataset): """Client wrapper for IterablePackingDataset that calls server HTTP endpoints.""" - def __init__(self, - dataset_meta: DatasetMeta, - packing_interval: int = 128, - packing_num_proc: int = 1, - cyclic: bool = False, - **kwargs): + def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packing_num_proc: int = 1, cyclic: bool = False, **kwargs): from twinkle_client.http import get_base_url self.server_url = get_base_url() @@ -34,14 +28,9 @@ def __init__(self, json_data={ 'processor_type': 'dataset', 'class_type': 'IterablePackingDataset', - **{ - 'dataset_meta': dataset_meta, - 'packing_interval': packing_interval, - 'packing_num_proc': packing_num_proc, - 'cyclic': cyclic - }, - **kwargs - }) + **{'dataset_meta': dataset_meta, 'packing_interval': packing_interval, 'packing_num_proc': packing_num_proc, 'cyclic': cyclic}, **kwargs + } + ) response.raise_for_status() self.processor_id = response.json()['processor_id'] heartbeat_manager.register_processor(self.processor_id) @@ -52,19 +41,20 @@ def __del__(self): except: pass + def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', - **{ - 'template_cls': template_cls - }, + **{'template_cls': template_cls}, **kwargs - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def pack_dataset(self): response = http_post( @@ -73,9 +63,11 @@ def pack_dataset(self): 'processor_id': self.processor_id, 'function': 'pack_dataset', **{}, - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + def __iter__(self): response = http_post( @@ -84,16 +76,19 @@ def __iter__(self): 'processor_id': self.processor_id, 'function': '__iter__', **{}, - }) + } + ) response.raise_for_status() return self - + def __next__(self): response = http_post( url=f'{self.server_url}/processors/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', - }) + } + ) response.raise_for_status() - return response.json()['result'] + return response.json()["result"] + \ No newline at end of file diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py index 586e6927..ce8178b1 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -9,11 +9,11 @@ # 2. 
Run: python client_tools/client_generator.py
 # ============================================================================
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle_client.http import heartbeat_manager, http_post
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
 
 from .base import Dataset
-
 
 class LazyDataset(Dataset):
     """Client wrapper for LazyDataset that calls server HTTP endpoints."""
@@ -26,11 +26,9 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
             json_data={
                 'processor_type': 'dataset',
                 'class_type': 'LazyDataset',
-                **{
-                    'dataset_meta': dataset_meta
-                },
-                **kwargs
-            })
+                **{'dataset_meta': dataset_meta}, **kwargs
+            }
+        )
         response.raise_for_status()
         self.processor_id = response.json()['processor_id']
         heartbeat_manager.register_processor(self.processor_id)
@@ -41,6 +39,7 @@ def __del__(self):
         except:
             pass
 
+
     def encode(self, **kwargs):
         response = http_post(
             url=f'{self.server_url}/processors/call',
@@ -49,9 +48,11 @@ def encode(self, **kwargs):
                 'function': 'encode',
                 **{},
                 **kwargs
-            })
+            }
+        )
         response.raise_for_status()
-        return response.json()['result']
+        return response.json()["result"]
+
 
     def check(self, **kwargs):
         response = http_post(
@@ -61,9 +62,11 @@ def check(self, **kwargs):
                 'processor_id': self.processor_id,
                 'function': 'check',
                 **{},
                 **kwargs
-            })
+            }
+        )
         response.raise_for_status()
-        return response.json()['result']
+        return response.json()["result"]
+
 
     def __getitem__(self, idx):
         response = http_post(
@@ -71,12 +74,12 @@ def __getitem__(self, idx):
             json_data={
                 'processor_id': self.processor_id,
                 'function': '__getitem__',
-                **{
-                    'idx': idx
-                },
-            })
+                **{'idx': idx},
+            }
+        )
         response.raise_for_status()
-        return response.json()['result']
+        return response.json()["result"]
+
 
     def __len__(self):
         response = http_post(
@@ -85,6 +88,8 @@ def __len__(self):
                 'processor_id': self.processor_id,
                 'function': '__len__',
                 **{},
-            })
+            }
+        )
         response.raise_for_status()
-        return response.json()['result']
+        return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py
index 8783c2ab..0d91546f 100644
--- a/src/twinkle_client/dataset/packing_dataset.py
+++ b/src/twinkle_client/dataset/packing_dataset.py
@@ -9,11 +9,11 @@
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle_client.http import heartbeat_manager, http_post
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
 
 from .base import Dataset
-
 
 class PackingDataset(Dataset):
     """Client wrapper for PackingDataset that calls server HTTP endpoints."""
@@ -39,7 +39,7 @@ def __del__(self):
         except:
             pass
 
-
+
     def pack_dataset(self):
         response = http_post(
             url=f'{self.server_url}/processors/call',
@@ -50,8 +50,8 @@ def pack_dataset(self):
             }
         )
         response.raise_for_status()
-        return response.json()['result']
-
+        return response.json()["result"]
+
 
     def __getitem__(self, index):
         response = http_post(
@@ -63,8 +63,8 @@ def __getitem__(self, index):
             }
         )
         response.raise_for_status()
-        return response.json()['result']
-
+        return response.json()["result"]
+
 
     def __len__(self):
         response = http_post(
@@ -76,4 +76,5 @@ def __len__(self):
             }
         )
         response.raise_for_status()
-        return response.json()['result']
+        return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py
index 35f8c849..f681c96b 100644
--- a/src/twinkle_client/model/multi_lora_transformers.py
+++ b/src/twinkle_client/model/multi_lora_transformers.py
@@ -8,12 +8,11 @@
 # 1. Modify the source files in src/twinkle/
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
+from typing import Any, Optional, Union, Type, Dict, Literal, List
 import uuid
-from typing import Any, Dict, List, Literal, Optional, Type, Union
-
+from twinkle_client.http import http_post, heartbeat_manager
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature, Trajectory
-from twinkle_client.http import heartbeat_manager, http_post
 
 
 class MultiLoraTransformersModel:
diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py
index 47e28fd2..d59572a7 100644
--- a/src/twinkle_client/processor/base.py
+++ b/src/twinkle_client/processor/base.py
@@ -10,20 +10,14 @@
 # ============================================================================
 from typing import List, Literal, Optional, Union
-
+from twinkle_client.http import http_post, heartbeat_manager
 from twinkle import DeviceMesh
 from twinkle.data_format import InputFeature
-from twinkle_client.http import heartbeat_manager, http_post
-
-
-class InputProcessor:
+class InputProcessor(object):
     """Client wrapper for InputProcessor that calls server HTTP endpoints."""
 
-    def __init__(self,
-                 device_mesh: Optional[DeviceMesh] = None,
-                 padding_free: bool = False,
-                 framework: Literal['transformers', 'megatron'] = 'transformers',
-                 **kwargs):
+    def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool = False, framework: Literal['transformers', 'megatron'] = 'transformers', **kwargs):
         from twinkle_client.http import get_base_url
 
         self.server_url = get_base_url()
@@ -32,13 +26,9 @@ def __init__(self,
             json_data={
                 'processor_type': 'processor',
                 'class_type': 'InputProcessor',
-                **{
-                    'device_mesh': device_mesh,
-                    'padding_free': padding_free,
-                    'framework': framework
-                },
-                **kwargs
-            })
+                **{'device_mesh': device_mesh, 'padding_free': padding_free, 'framework': framework}, **kwargs
+            }
+        )
         response.raise_for_status()
         self.processor_id = response.json()['processor_id']
         heartbeat_manager.register_processor(self.processor_id)
@@ -49,16 +39,17 @@ def __del__(self):
         except:
             pass
 
+
     def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs):
         response = http_post(
             url=f'{self.server_url}/processors/call',
             json_data={
                 'processor_id': self.processor_id,
                 'function': '__call__',
-                **{
-                    'inputs': inputs
-                },
+                **{'inputs': inputs},
                 **kwargs
-            })
+            }
+        )
         response.raise_for_status()
-        return response.json()['result']
+        return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py
index 5faed5e1..907881a4 100644
--- a/src/twinkle_client/sampler/vllm_sampler.py
+++ b/src/twinkle_client/sampler/vllm_sampler.py
@@ -8,12 +8,11 @@
 # 1. Modify the source files in src/twinkle/
 # 2. Run: python client_tools/client_generator.py
 # ============================================================================
-from peft import PeftConfig
-from typing import Any, Dict, List, Optional, Union
-
-from twinkle.data_format import InputFeature, Trajectory
+from typing import Any, Optional, List, Dict, Union
+from twinkle_client.http import http_post, heartbeat_manager
 from twinkle.sampler.base import Sampler
-from twinkle_client.http import heartbeat_manager, http_post
+from peft import PeftConfig
+from twinkle.data_format import Trajectory, InputFeature
 
 
 class vLLMSampler(Sampler):
@@ -32,14 +31,20 @@ def __init__(self, model_id: str, **kwargs):
         if '://' in model_id:
             model_id = model_id.split('://')[1]
         self.server_url = f'{self.server_url}/samplers/{model_id}'
-        response = http_post(url=f'{self.server_url}/create', json_data=kwargs)
+        response = http_post(
+            url=f'{self.server_url}/create',
+            json_data=kwargs
+        )
         response.raise_for_status()
 
     def _send_adapter_heartbeat(self):
         """Internal method to send adapter heartbeat."""
         if not self.adapter_name:
             return
-        response = http_post(url=f'{self.server_url}/heartbeat', json_data={'adapter_name': self.adapter_name})
+        response = http_post(
+            url=f'{self.server_url}/heartbeat',
+            json_data={'adapter_name': self.adapter_name}
+        )
         response.raise_for_status()
 
     def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs):
@@ -48,16 +53,16 @@ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs
             config = config.__dict__
         response = http_post(
             url=f'{self.server_url}/add_adapter_to_sampler',
-            json_data={
-                'adapter_name': adapter_name,
-                'config': config,
-                **kwargs
-            })
+            json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
+        )
         response.raise_for_status()
 
         # Register adapter for automatic heartbeat after successful creation
         self.adapter_name = adapter_name
-        heartbeat_manager.register_adapter(self.adapter_name, self._send_adapter_heartbeat)
+        heartbeat_manager.register_adapter(
+            self.adapter_name,
+            self._send_adapter_heartbeat
+        )
 
         return response.json()
@@ -98,7 +103,10 @@ def sample(
         if adapter_uri is not None:
             json_data['adapter_uri'] = adapter_uri
 
-        response = http_post(url=f'{self.server_url}/sample', json_data=json_data)
+        response = http_post(
+            url=f'{self.server_url}/sample',
+            json_data=json_data
+        )
         response.raise_for_status()
         return response.json()
@@ -106,10 +114,7 @@ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):
         """Set the template for encoding trajectories."""
         response = http_post(
             url=f'{self.server_url}/set_template',
-            json_data={
-                'template_cls': template_cls,
-                'adapter_name': adapter_name,
-                **kwargs
-            })
+            json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs}
+        )
         response.raise_for_status()
         return response.json()