diff --git a/README.md b/README.md
index 39cb0832..34c915e0 100644
--- a/README.md
+++ b/README.md
@@ -203,7 +203,7 @@ if __name__ == '__main__':
 import os
 from tqdm import tqdm
 from tinker import types
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
diff --git a/README_ZH.md b/README_ZH.md
index 65edf58d..8462f148 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -186,7 +186,7 @@ if __name__ == '__main__':
 import os
 from tqdm import tqdm
 from tinker import types
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
diff --git a/cookbook/client/tinker/lora.py b/cookbook/client/tinker/custom_service/lora.py
similarity index 88%
rename from cookbook/client/tinker/lora.py
rename to cookbook/client/tinker/custom_service/lora.py
index e94719bc..e1eab4d3 100644
--- a/cookbook/client/tinker/lora.py
+++ b/cookbook/client/tinker/custom_service/lora.py
@@ -8,13 +8,10 @@
 # Step 1: Load environment variables from a .env file (e.g., API tokens)
 import dotenv
-
 dotenv.load_dotenv('.env')
-import os
-
 # Step 2: Initialize Tinker client before importing ServiceClient
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 init_tinker_client()
 
@@ -22,8 +19,12 @@
 from tinker import ServiceClient
 
 service_client = ServiceClient(
-    base_url='http://www.modelscope.cn/twinkle',
-    api_key=os.environ.get('MODELSCOPE_TOKEN')
+    # BASE_URL can be a local server endpoint such as http://localhost:8000,
+    # a previously deployed remote server, or the ModelScope server,
+    # e.g. 'http://www.modelscope.cn/twinkle'
+    base_url='http://localhost:8000',
+    # API_KEY can be empty or a meaningful token, depending on the server configuration
+    api_key='EMPTY-TOKEN'
 )
 
 # Step 4: List models available on the server to verify the connection
@@ -40,10 +41,12 @@
 # You can resume from either:
 # 1. A twinkle path: "twinkle://<run_id>/weights/<checkpoint_name>"
-# 2. A model id on hub: "<owner>/<model_id>"
+# 2. A model id on ModelScope hub: "ms://<owner>/<model_id>"
+# 3. A local path to a checkpoint directory
 # Example:
 # resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
-# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1"
+# resume_path = "ms://AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1"
+# resume_path = "/path/to/local/checkpoint/directory"
 resume_path = ''
 
 print(f'Found {len(response.training_runs)} training runs')
@@ -58,7 +61,7 @@
 # Step 6: Create or resume a training client.
 # If resume_path is set, it restores both model weights and optimizer state.
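+# (When resuming, base_model should presumably match the base model of the run being resumed.)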
-base_model = 'Qwen/Qwen2.5-7B-Instruct'
+base_model = 'Qwen/Qwen3-4B'
 if not resume_path:
     training_client = service_client.create_lora_training_client(base_model=base_model)
 else:
@@ -85,19 +88,7 @@
     {
         'input': 'pickle jar',
         'output': 'ickle-pay ar-jay'
-    },
-    {
-        'input': 'space exploration',
-        'output': 'ace-spay exploration-way'
-    },
-    {
-        'input': 'rubber duck',
-        'output': 'ubber-ray uck-day'
-    },
-    {
-        'input': 'coding wizard',
-        'output': 'oding-cay izard-way'
-    },
+    }
 ]
 
 from modelscope import AutoTokenizer
@@ -181,6 +172,7 @@ def process_example(example: dict, tokenizer) -> types.Datum:
 # Step 9: Publish the final checkpoint to ModelScope Hub.
 # NOTE: Requires a valid ModelScope token set as api_key when initializing the client.
-# The published model name will be: {run_id}_{checkpoint_name}
+# The model will be published under the owner of the supplied ModelScope token,
+# with the model name formatted as: {run_id}_{checkpoint_name}
 rest_client.publish_checkpoint_from_tinker_path(save_result.path).result()
 print('Published checkpoint')
diff --git a/cookbook/client/tinker/megatron/server.py b/cookbook/client/tinker/custom_service/megatron/server.py
similarity index 100%
rename from cookbook/client/tinker/megatron/server.py
rename to cookbook/client/tinker/custom_service/megatron/server.py
diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/custom_service/megatron/server_config.yaml
similarity index 58%
rename from cookbook/client/tinker/megatron/server_config_7b.yaml
rename to cookbook/client/tinker/custom_service/megatron/server_config.yaml
index dc47d796..04f8c12c 100644
--- a/cookbook/client/tinker/megatron/server_config_7b.yaml
+++ b/cookbook/client/tinker/custom_service/megatron/server_config.yaml
@@ -22,9 +22,9 @@ applications:
     import_path: server  # Python module to import
     args:
       server_config:
-        per_token_model_limit: 1  # Maximum number of models (adapters) per token (server-globally enforced)
+        per_token_model_limit: 3  # Maximum number of models (adapters) per token (server-globally enforced)
         supported_models:
-          - Qwen/Qwen2.5-7B-Instruct
+          - Qwen/Qwen3-4B
     deployments:
       - name: TinkerCompatServer
         autoscaling_config:
@@ -36,17 +36,17 @@ applications:
   # 2. Model Service (commented out) - Would host the base model for training.
   # Uncomment and configure if you need a training model worker.
-  - name: models-Qwen2.5-7B-Instruct
-    route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
+  - name: models-Qwen3-4B
+    route_prefix: /api/v1/model/Qwen/Qwen3-4B
     import_path: model
     args:
       use_megatron: true
-      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier
       max_length: 10240
       nproc_per_node: 2  # Number of GPU processes per node
       device_group:
         name: model
-        ranks: [0,1]  # GPU rank indices
+        ranks: 2  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -58,11 +58,12 @@
       adapter_config:
         adapter_timeout: 30  # Seconds before idle adapter unload
         adapter_max_lifetime: 36000  # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
+        max_loras: 1  # Maximum number of LoRA adapters per model
     deployments:
       - name: ModelManagement
         autoscaling_config:
-          min_replicas: 1
-          max_replicas: 1
+          min_replicas: 2
+          max_replicas: 2
           target_ongoing_requests: 16
         ray_actor_options:
           num_cpus: 0.1
@@ -72,36 +73,36 @@
   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   # Used for generating text from the model (e.g., evaluating LoRA results).
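+  # NOTE: enabling this sampler alongside the model service above presumably
+  # requires GPUs beyond the model's `ranks: 2` (one more here, per `ranks: 1`).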
-  # - name: sampler-Qwen2.5-7B-Instruct
-  #   route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
-  #   import_path: sampler
-  #   args:
-  #     model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
-  #     nproc_per_node: 2  # Number of GPU processes per node
-  #     sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
-  #     engine_args:  # vLLM engine-specific settings
-  #       max_model_len: 4096  # Maximum sequence length the engine supports
-  #       gpu_memory_utilization: 0.5  # Fraction of GPU memory to use (0.0-1.0)
-  #       enable_lora: true  # Allow loading LoRA adapters during inference
-  #       logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
-  #     device_group:  # Logical device group for the sampler
-  #       name: sampler
-  #       ranks: [2]  # GPU rank indices to use
-  #       device_type: cuda
-  #     device_mesh:
-  #       device_type: cuda
-  #       dp_size: 1
-  #     queue_config:
-  #       rps_limit: 100  # Max requests per second
-  #       tps_limit: 100000  # Max tokens per second
-  #   deployments:
-  #     - name: SamplerManagement
-  #       autoscaling_config:
-  #         min_replicas: 1
-  #         max_replicas: 1
-  #         target_ongoing_requests: 16
-  #       ray_actor_options:
-  #         num_cpus: 0.1
-  #       runtime_env:
-  #         env_vars:
-  #           TWINKLE_TRUST_REMOTE_CODE: "0"
+  - name: sampler-Qwen3-4B
+    route_prefix: /api/v1/sampler/Qwen/Qwen3-4B
+    import_path: sampler
+    args:
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier
+      nproc_per_node: 2  # Number of GPU processes per node
+      sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
+      engine_args:  # vLLM engine-specific settings
+        max_model_len: 4096  # Maximum sequence length the engine supports
+        gpu_memory_utilization: 0.5  # Fraction of GPU memory to use (0.0-1.0)
+        enable_lora: true  # Allow loading LoRA adapters during inference
+        logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
+      device_group:  # Logical device group for the sampler
+        name: sampler
+        ranks: 1  # Number of GPUs to use
+        device_type: cuda
+      device_mesh:
+        device_type: cuda
+        dp_size: 1
+      queue_config:
+        rps_limit: 100  # Max requests per second
+        tps_limit: 100000  # Max tokens per second
+    deployments:
+      - name: SamplerManagement
+        autoscaling_config:
+          min_replicas: 1
+          max_replicas: 1
+          target_ongoing_requests: 16
+        ray_actor_options:
+          num_cpus: 0.1
+        runtime_env:
+          env_vars:
+            TWINKLE_TRUST_REMOTE_CODE: "0"
diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/custom_service/sample.py
similarity index 92%
rename from cookbook/client/tinker/sample.py
rename to cookbook/client/tinker/custom_service/sample.py
index 84931a59..dc48833c 100644
--- a/cookbook/client/tinker/sample.py
+++ b/cookbook/client/tinker/custom_service/sample.py
@@ -9,7 +9,7 @@
 from twinkle.data_format import Message, Trajectory
 from twinkle.template import Template
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 # Step 1: Initialize Tinker client
 init_tinker_client()
@@ -17,10 +17,10 @@
 from tinker import ServiceClient
 
 # Step 2: Define the base model and connect to the server
-base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
+base_model = 'Qwen/Qwen3-4B'
 service_client = ServiceClient(
-    base_url='http://www.modelscope.cn/twinkle',
-    api_key=os.environ.get('MODELSCOPE_TOKEN')
+    base_url='http://localhost:8000',
+    api_key='EMPTY-TOKEN'
 )
 
 # Step 3: Create a sampling client by loading weights from a saved checkpoint.
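+# For example (hypothetical values; the run id segment comes from your own training run):
+# sampling_client = service_client.create_sampling_client(
+#     model_path='twinkle://<run_id>/weights/<checkpoint_name>', base_model=base_model)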
diff --git a/cookbook/client/tinker/custom_service/self_cognition.py b/cookbook/client/tinker/custom_service/self_cognition.py
new file mode 100644
index 00000000..9b78a14f
--- /dev/null
+++ b/cookbook/client/tinker/custom_service/self_cognition.py
@@ -0,0 +1,137 @@
+# Tinker-Compatible Client - Self-Cognition Training & Evaluation Example
+#
+# This script demonstrates two workflows using the Tinker-compatible client:
+#   1. train(): Fine-tune a model on a self-cognition dataset so it learns
+#      a custom identity (name, author).
+#   2. eval(): Load a trained checkpoint and sample from it to verify
+#      that the model has learned the custom identity.
+# The server must be running first (see server.py and server_config.yaml).
+import os
+from tqdm import tqdm
+from tinker import types
+from twinkle import init_tinker_client
+from twinkle.data_format import Message, Trajectory
+from twinkle.template import Template
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.server.tinker.common import input_feature_to_datum
+
+# Initialize the Tinker client before importing ServiceClient
+init_tinker_client()
+
+from tinker import ServiceClient
+
+# The base model to fine-tune / evaluate
+base_model = 'Qwen/Qwen3-4B'
+base_url = 'http://localhost:8000'
+api_key = 'EMPTY_API_KEY'
+
+
+def train():
+    # Step 1: Prepare the dataset
+
+    # Load the self-cognition dataset from ModelScope (first 500 examples)
+    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+
+    # Apply the chat template matching the base model (max 256 tokens per sample)
+    dataset.set_template('Template', model_id=f'ms://{base_model}', max_length=256)
+
+    # Replace placeholder names with custom model/author identity
+    dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'), load_from_cache_file=False)
+
+    # Tokenize and encode the dataset into model-ready input features
+    dataset.encode(batched=True, load_from_cache_file=False)
+
+    # Wrap the dataset into a DataLoader that yields batches of size 8
+    dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+    # Step 2: Initialize the training client
+
+    service_client = ServiceClient(
+        base_url=base_url,
+        api_key=api_key
+    )
+
+    # Create a LoRA training client for the base model (rank=16 for the LoRA adapter)
+    training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)
+
+    # Step 3: Run the training loop
+
+    for epoch in range(3):
+        print(f'Epoch {epoch}')
+        for step, batch in tqdm(enumerate(dataloader)):
+            # Convert each InputFeature into a Datum for the Tinker API
+            input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]
+
+            # Send data to server: forward + backward pass (computes gradients)
+            fwdbwd_future = training_client.forward_backward(input_datum, 'cross_entropy')
+
+            # Optimizer step: update model weights with Adam
+            optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+            # Wait for both operations to complete
+            fwdbwd_result = fwdbwd_future.result()
+            optim_result = optim_future.result()
+
+            # Compute weighted average log-loss per token for monitoring
+            # logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
+            # weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
+            # print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
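+            # (Note: the commented-out block above assumes `import numpy as np`
+            # at the top of this file; add that import before enabling it.)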
+
+            print(f'Training Metrics: {optim_result}')
+
+        # Save a checkpoint after each epoch
+        save_future = training_client.save_state(f'twinkle-lora-{epoch}')
+        save_result = save_future.result()
+        print(f'Saved checkpoint to {save_result.path}')
+
+
+def eval():
+    # Step 1: Load the trained LoRA checkpoint for inference
+
+    # Path to a previously saved LoRA checkpoint (twinkle:// URI)
+    weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2'
+
+    service_client = ServiceClient(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))
+    sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model)
+
+    # Step 2: Prepare the chat prompt
+
+    # Build a multi-turn conversation to test the model's self-cognition
+    template = Template(model_id=f'ms://{base_model}')
+
+    trajectory = Trajectory(
+        messages=[
+            Message(role='system', content='You are a helpful assistant'),
+            Message(role='user', content='你是谁?'),
+        ]
+    )
+
+    input_feature = template.encode(trajectory, add_generation_prompt=True)
+
+    input_ids = input_feature['input_ids'].tolist()
+
+    # Step 3: Generate responses
+
+    prompt = types.ModelInput.from_ints(input_ids)
+    params = types.SamplingParams(
+        max_tokens=50,  # Maximum tokens to generate
+        temperature=0.2,  # Low temperature for more focused responses
+        stop=['\n']  # Stop at newline
+    )
+
+    # Sample 8 independent completions
+    print('Sampling...')
+    future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
+    result = future.result()
+
+    # Decode and print each response
+    print('Responses:')
+    for i, seq in enumerate(result.sequences):
+        print(f'{i}: {repr(template.decode(seq.tokens))}')
+
+
+if __name__ == '__main__':
+    train()  # Run training
+    # eval()  # Uncomment to run evaluation / inference
diff --git a/cookbook/client/tinker/custom_service/short_math_grpo.py b/cookbook/client/tinker/custom_service/short_math_grpo.py
new file mode 100644
index 00000000..d35102b7
--- /dev/null
+++ b/cookbook/client/tinker/custom_service/short_math_grpo.py
@@ -0,0 +1,410 @@
+# Tinker-Compatible Client - Math GRPO Training Example
+#
+# This script demonstrates Math problem training using the
+# Tinker-compatible client API with save_weights_for_sampler for weight sync.
+# Instead of calling sync_weights directly, it periodically saves weights and
+# creates a sampling client for generation.
+#
+# Flow:
+#   1. Prepare Math dataset (client-side)
+#   2. Initialize Tinker-compatible training & sampling clients
+#   3. Training loop:
+#      a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client
+#      b. Sample completions from the sampling client
+#      c. Compute rewards and advantages (client-side)
+#      d. Train on sampled data weighted by advantages
+#      e. Optimizer step
+#
+# The server must be running first (see server.py and server_config.yaml).
+# Requires both model and sampler services to be configured.
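+# A rough sketch of the group-relative (GRPO) advantage that GRPOAdvantage with
+# scale='group' is assumed to compute below; the library's exact behavior may differ:
+#   for the G = NUM_GENERATIONS rewards r_1..r_G sampled for one prompt,
+#   A_i = (r_i - mean(r)) / (std(r) + eps)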
+import gc
+import numpy as np
+import os
+import re
+from tinker import types
+from typing import List, Tuple
+
+from twinkle import init_tinker_client
+from twinkle import get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.data_format import Message, Trajectory
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import Preprocessor
+from twinkle.reward.base import Reward
+from twinkle.metric import CompletionRewardMetric
+from twinkle.template import Template
+
+logger = get_logger()
+
+# ========== Configuration ==========
+BASE_MODEL = 'Qwen/Qwen3-4B'
+NUM_GENERATIONS = 8
+MAX_NEW_TOKENS = 4096
+LEARNING_RATE = 1e-4
+MAX_STEPS = 1000
+BATCH_SIZE = 2
+TEMPERATURE = 1.0
+SYNC_INTERVAL = 1  # Save weights for sampler every N steps
+LORA_RANK = 8
+DATA_NUM = 2000  # Number of Math samples to use
+
+SYSTEM_PROMPT = ('You are a math assistant that values brevity. '
+                 'Solve problems with minimal but correct reasoning.\n\n'
+                 'Rules:\n'
+                 '1. Use <think></think> tags for reasoning\n'
+                 '2. Final answer after ####\n\n'
+                 'Example:\n<think>Key step 1 -> Key step 2 -> conclusion</think>\n#### 42')
+
+
+class MathPreprocessor(Preprocessor):
+
+    def __call__(self, sample):
+        if sample['level'] not in ('Level 4', 'Level 5'):
+            return Trajectory(messages=[], user_data=[])
+
+        def get_boxed_answer(text):
+            match = re.search(r'\\boxed{([^}]*)}', text)
+            return match.group(1) if match else None
+
+        ground_truth = get_boxed_answer(sample['solution'])
+        if ground_truth is None:
+            return Trajectory(messages=[], user_data=[])
+        problem = sample['problem']
+        return Trajectory(
+            messages=[
+                Message(role='system', content=SYSTEM_PROMPT),
+                Message(role='user', content=problem),
+            ],
+            user_data=[('ground_truth', ground_truth)],
+        )
+
+
+# ========== Math Reward Functions ==========
+class MathAccuracyReward(Reward):
+    """Accuracy reward for Math: checks if the model's answer matches ground truth.
+
+    Extracts the last '#### <answer>' from model output and compares it with the ground truth.
+    Returns 1.0 for correct, 0.0 for incorrect.
+    """
+
+    @staticmethod
+    def extract_answer(completion: str) -> str:
+        """Extract the last #### answer from model completion."""
+        # Only check last 500 chars for efficiency
+        text = completion[-500:] if len(completion) > 500 else completion
+        matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
+        if matches:
+            return matches[-1].replace(',', '').replace(' ', '').strip()
+        return ''
+
+    def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+        rewards = []
+        for trajectory in trajectories:
+            messages = trajectory.get('messages', [])
+            # Get model completion (last assistant message)
+            completion = ''
+            for msg in reversed(messages):
+                if msg.get('role') == 'assistant':
+                    completion = msg.get('content', '')
+                    break
+
+            # Get ground truth from user_data
+            gt = ''
+            user_data = trajectory.get('user_data', [])
+            if isinstance(user_data, list):
+                for item in user_data:
+                    if isinstance(item, (list, tuple)) and len(item) == 2:
+                        if item[0] == 'ground_truth':
+                            gt = str(item[1])
+                            break
+
+            predicted = self.extract_answer(completion)
+
+            # Numeric comparison
+            correct = False
+            if predicted and gt:
+                try:
+                    correct = abs(float(predicted) - float(gt)) < 1e-5
+                except (ValueError, OverflowError):
+                    correct = predicted == gt
+
+            rewards.append(1.0 if correct else 0.0)
+        return rewards
+
+
+class MathFormatReward(Reward):
+    """Format reward: checks format and rewards shorter completions.
+
+    Returns higher score for shorter completions (1.0 at length 100 or less).
+    Returns 0.0 if format is incorrect.
+    """
+
+    def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+        rewards = []
+        for trajectory in trajectories:
+            messages = trajectory.get('messages', [])
+            completion = ''
+            for msg in reversed(messages):
+                if msg.get('role') == 'assistant':
+                    completion = msg.get('content', '')
+                    break
+
+            has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
+            has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
+
+            if not (has_think and has_answer):
+                rewards.append(0.0)
+            else:
+                length = len(completion)
+                if length <= 100:
+                    rewards.append(1.0)
+                else:
+                    reward = max(0.0, 1.0 - (length - 100) / 2000)
+                    rewards.append(reward)
+
+        return rewards
+
+
+def create_math_dataset():
+    """Create Math dataset."""
+    meta = DatasetMeta(
+        'ms://modelscope/competition_math',
+        subset_name='default',
+        split='train',
+        data_slice=range(DATA_NUM),
+    )
+    dataset = Dataset(meta)
+    dataset.set_template('Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete')
+    dataset.map(MathPreprocessor())
+    dataset.filter(lambda row: bool(row['messages']))
+    dataset.encode(add_generation_prompt=True)
+    return dataset
+
+
+def compute_rewards(trajectories: List[Trajectory]) -> Tuple[List[float], List[float], List[float]]:
+    """Compute accuracy and format rewards for Math."""
+    accuracy_reward_fn = MathAccuracyReward()
+    format_reward_fn = MathFormatReward()
+
+    accuracy_rewards = accuracy_reward_fn(trajectories, [])
+    format_rewards = format_reward_fn(trajectories, [])
+    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
+    return total_rewards, format_rewards, accuracy_rewards
+
+
+def main():
+    logger.info('Starting Math GRPO training...')
+
+    # Step 1: Prepare dataset and dataloader (client-side)
+    dataset = create_math_dataset()
+    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+    template = Template(model_id=f'ms://{BASE_MODEL}')
+
+    logger.info('Dataset and template initialized')
+
+    # Step 2: Initialize the Tinker-compatible client
+    logger.info('Connecting to Tinker server...')
+    init_tinker_client()
+
+    from tinker import ServiceClient
+    service_client = ServiceClient(
+        base_url='http://localhost:8000',
+        api_key=os.environ.get('MODELSCOPE_TOKEN')
+    )
+
+    logger.info('Creating LoRA training client...')
+    # Create a LoRA training client for GRPO
+    training_client = service_client.create_lora_training_client(
+        base_model=BASE_MODEL,
+        rank=LORA_RANK,
+    )
+
+    logger.info('Training client created successfully')
+
+    # Step 3: Setup metrics and advantage function
+    advantage_fn = GRPOAdvantage()
+    metrics = CompletionRewardMetric()
+
+    sampling_params = types.SamplingParams(
+        max_tokens=MAX_NEW_TOKENS,
+        temperature=TEMPERATURE,
+        top_p=0.95,
+    )
+
+    # The sampling client is created on-demand via save_weights_for_sampler
+    sampling_client = None
+
+    step = 0
+    for batch in dataloader:
+        if step >= MAX_STEPS:
+            break
+
+        metrics.reset()
+        prompts = batch if isinstance(batch, list) else [batch]
+
+        # ========== 1. Save weights for sampler (instead of sync_weights) ==========
+        if step % SYNC_INTERVAL == 0:
+            logger.info(f'Step {step}: Saving weights for sampler...')
+
+            sampling_client = training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}')
+            logger.info(f'Step {step}: Sampling client ready')
+
+        if sampling_client is None:
+            logger.warning('No sampling client available, skipping step')
+            step += 1
+            continue
+
+        # ========== 2. Sample completions ==========
+        # Convert input features to token prompts for the sampling client
+        all_sequences = []
+        all_user_data = []
+        for prompt_feature in prompts:
+            input_ids = prompt_feature['input_ids']
+            if hasattr(input_ids, 'tolist'):
+                input_ids = input_ids.tolist()
+            prompt = types.ModelInput.from_ints(input_ids)
+            future = sampling_client.sample(
+                prompt=prompt,
+                sampling_params=sampling_params,
+                num_samples=NUM_GENERATIONS,
+            )
+            result = future.result()
+            # Store both sequences and user data
+            for _ in range(NUM_GENERATIONS):
+                all_user_data.append(prompt_feature.get('user_data', []))
+            all_sequences.extend(result.sequences)
+
+        if not all_sequences:
+            logger.warning(f'Step {step}: No valid samples, skipping')
+            step += 1
+            continue
+
+        # ========== 3. Build trajectories and collect logprobs ==========
+        trajectories = []
+        old_logps_list = []
+        completion_lengths = []
+
+        for idx, seq in enumerate(all_sequences):
+            decoded_text = template.decode(seq.tokens, skip_special_tokens=True)
+            # Use the corresponding user data for this sequence
+            trajectories.append({
+                'messages': [
+                    {
+                        'role': 'system',
+                        'content': SYSTEM_PROMPT
+                    },
+                    {
+                        'role': 'user',
+                        'content': 'Math problem'
+                    },  # Placeholder
+                    {
+                        'role': 'assistant',
+                        'content': decoded_text
+                    }
+                ],
+                'user_data': all_user_data[idx]
+            })
+            old_logps_list.append([lp for lp in seq.logprobs] if seq.logprobs else [])
+            completion_lengths.append(len(seq.tokens))
+
+        # ========== 4. Compute rewards ==========
+        total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories)
+        metrics.accumulate(
+            None,
+            None,
+            completion_lengths=completion_lengths,
+            rewards={
+                'total': total_rewards,
+                'format': format_rewards,
+                'accuracy': accuracy_rewards,
+            })
+
+        # ========== 5. Compute advantages ==========
+        advantages = advantage_fn(
+            total_rewards,
+            num_generations=NUM_GENERATIONS,
+            scale='group',
+        ).tolist()
+
+        frac_zero_std = 1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0
+        if frac_zero_std == 1.0:
+            logger.info(f'Step {step}: All advantages are zero, skipping training')
+            step += 1
+            continue
+
+        # ========== 6. Train the policy with GRPO loss ==========
+        # Train the policy with the Group Relative Policy Optimization (GRPO)
+        # loss function.
+        #
+        # The GRPO loss function requires:
+        #   1. logprobs: The log probabilities of the tokens under the current policy
+        #   2. advantages: The advantage values for each completion
+        #
+        # The training data is constructed with:
+        #   - model_input: The full prompt + completion tokens
+        #   - target_tokens: The shifted tokens for next-token prediction
+        #   - logprobs: The log probabilities from the sampling step
+        #   - advantages: The computed advantage values
+        training_data = []
+        for i, seq in enumerate(all_sequences):
+            # Build a Datum from the completion tokens with logprobs and advantages
+            prompt_feature = prompts[i // NUM_GENERATIONS]
+            prompt_ids = prompt_feature['input_ids']
+            if hasattr(prompt_ids, 'tolist'):
+                prompt_ids = prompt_ids.tolist()
+
+            sampled_tokens = list(seq.tokens)
+            logprobs = seq.logprobs if seq.logprobs else [0.0] * len(sampled_tokens)
+            advantage = float(advantages[i])
+
+            ob_len = len(prompt_ids) - 1
+            input_tokens = prompt_ids + sampled_tokens[:-1]
+            target_tokens = [0] * ob_len + sampled_tokens
+            weights = [0] * ob_len + [1] * len(sampled_tokens)
+            padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens)
+            padded_logprobs = [0.0] * ob_len + logprobs
+
+            datum = types.Datum(
+                model_input=types.ModelInput.from_ints(input_tokens),
+                loss_fn_inputs={
+                    'target_tokens': target_tokens,
+                    'weights': weights,
+                    'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)),
+                    'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)),
+                },
+            )
+            training_data.append(datum)
+
+        if not training_data:
+            logger.info(f'Step {step}: No training data constructed, skipping')
+            step += 1
+            continue
+
+        # Forward-backward pass with importance_sampling (GRPO) loss
+        # The training data already contains logprobs and advantages for the GRPO loss
+        fwdbwd_result = training_client.forward_backward(training_data, 'importance_sampling').result()
+
+        optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()
+
+        gc.collect()
+
+        # ========== 7. Log ==========
+        log_dict = metrics.calculate()
+        if optim_result.metrics:
+            log_dict.update(optim_result.metrics)
+        log_dict['train/frac_reward_zero_std'] = frac_zero_std
+        log_dict['train/num_training_samples'] = len(training_data)
+        logger.info(f'Step {step}: {log_dict}')
+        step += 1
+
+    # Save final checkpoint
+    save_future = training_client.save_state('Math-grpo-final')
+    save_result = save_future.result()
+    logger.info(f'Saved final checkpoint to {save_result.path}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cookbook/client/tinker/transformer/server.py b/cookbook/client/tinker/custom_service/transformer/server.py
similarity index 100%
rename from cookbook/client/tinker/transformer/server.py
rename to cookbook/client/tinker/custom_service/transformer/server.py
diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/custom_service/transformer/server_config.yaml
similarity index 85%
rename from cookbook/client/tinker/transformer/server_config.yaml
rename to cookbook/client/tinker/custom_service/transformer/server_config.yaml
index f9c7a690..5009ce08 100644
--- a/cookbook/client/tinker/transformer/server_config.yaml
+++ b/cookbook/client/tinker/custom_service/transformer/server_config.yaml
@@ -23,6 +23,8 @@ applications:
     args:
       server_config:
         per_token_model_limit: 3  # Maximum number of models (adapters) per token (server-globally enforced)
+        supported_models:
+          - Qwen/Qwen3-4B
     deployments:
       - name: TinkerCompatServer
         autoscaling_config:
@@ -34,26 +36,26 @@
   # 2. Model Service (commented out) - Would host the base model for training.
   # Uncomment and configure if you need a training model worker.
-  - name: models-Qwen2.5-7B-Instruct
-    route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
+  - name: models-Qwen3-4B
+    route_prefix: /api/v1/model/Qwen/Qwen3-4B
     import_path: model
     args:
       use_megatron: false  # Use HuggingFace Transformers backend
-      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier
       max_length: 10240
       nproc_per_node: 2  # Number of GPU processes per node
       device_group:
         name: model
-        ranks: [0,1]  # GPU rank indices
+        ranks: 2  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
         dp_size: 2
       queue_config:
-        rps_limit: 100    # Max requests per second
+        rps_limit: 100  # Max requests per second
         tps_limit: 100000  # Max tokens per second
       adapter_config:
-        adapter_timeout: 1800  # Seconds before idle adapter unload
+        adapter_timeout: 30  # Seconds before idle adapter unload
     deployments:
       - name: ModelManagement
         autoscaling_config:
@@ -68,11 +70,11 @@
   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   # Used for generating text from the model (e.g., evaluating LoRA results).
-  - name: sampler-Qwen2.5-7B-Instruct
-    route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
+  - name: sampler-Qwen3-4B
+    route_prefix: /api/v1/sampler/Qwen/Qwen3-4B
     import_path: sampler
     args:
-      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier
       nproc_per_node: 2  # Number of GPU processes per node
      sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
       engine_args:  # vLLM engine-specific settings
@@ -82,7 +84,7 @@
         logprobs_mode: processed_logprobs  # Logprobs mode for sampling results
       device_group:  # Logical device group for the sampler
         name: sampler
-        ranks: [2]  # GPU rank indices to use
+        ranks: 1  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/modelscope_service/self_cognition.py
similarity index 96%
rename from cookbook/client/tinker/self_congnition.py
rename to cookbook/client/tinker/modelscope_service/self_cognition.py
index 326a6f78..f8d2a607 100644
--- a/cookbook/client/tinker/self_congnition.py
+++ b/cookbook/client/tinker/modelscope_service/self_cognition.py
@@ -9,7 +9,7 @@
 import os
 from tqdm import tqdm
 from tinker import types
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 from twinkle.data_format import Message, Trajectory
 from twinkle.template import Template
 from twinkle.dataloader import DataLoader
@@ -23,9 +23,8 @@
 from tinker import ServiceClient
 
 # The base model to fine-tune / evaluate
-# base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
-base_model = 'Qwen/Qwen2.5-7B-Instruct'
-base_url = 'http://localhost:8000'
+base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
+base_url = 'http://www.modelscope.cn/twinkle'
 
 
 def train():
diff --git a/cookbook/client/tinker/modelscope_service/server.py b/cookbook/client/tinker/modelscope_service/server.py
new file mode 100644
index 00000000..e38f43a4
--- /dev/null
+++ b/cookbook/client/tinker/modelscope_service/server.py
@@ -0,0 +1,21 @@
+# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
+#
+# This script starts the Twinkle server with Tinker-compatible API support
+# using the Megatron model backend.
+# It reads the server_config.yaml in the same directory for all
+# configuration (model, deployment settings, etc.).
+# Run this script BEFORE running the client training script (lora.py).
+
+import os
+
+# Allow loading custom model code from the hub (trust_remote_code)
+os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '1'
+
+from twinkle.server import launch_server
+
+# Resolve the path to server_config.yaml relative to this script's location
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch the Twinkle server; this call blocks until the server is shut down
+launch_server(config_path=config_path)
diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/modelscope_service/server_config.yaml
similarity index 100%
rename from cookbook/client/tinker/megatron/server_config.yaml
rename to cookbook/client/tinker/modelscope_service/server_config.yaml
diff --git a/cookbook/client/tinker/short_math_grpo.py b/cookbook/client/tinker/modelscope_service/short_math_grpo.py
similarity index 99%
rename from cookbook/client/tinker/short_math_grpo.py
rename to cookbook/client/tinker/modelscope_service/short_math_grpo.py
index 43647ab3..3d3ad2c2 100644
--- a/cookbook/client/tinker/short_math_grpo.py
+++ b/cookbook/client/tinker/modelscope_service/short_math_grpo.py
@@ -24,7 +24,7 @@
 from tinker import types
 from typing import List, Tuple
 
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 from twinkle import get_logger
 from twinkle.advantage import GRPOAdvantage
 from twinkle.data_format import Message, Trajectory
diff --git a/cookbook/client/twinkle/grpo.py b/cookbook/client/twinkle/grpo.py
index ee874db6..6db8cee2 100644
--- a/cookbook/client/twinkle/grpo.py
+++ b/cookbook/client/twinkle/grpo.py
@@ -44,7 +44,7 @@
 logger = get_logger()
 
 # ========== Configuration ==========
-MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct'
+MODEL_ID = 'ms://Qwen/Qwen3-4B'
 NUM_GENERATIONS = 4
 MAX_NEW_TOKENS = 1024
 LEARNING_RATE = 1e-5
diff --git a/cookbook/client/twinkle/megatron/server_config.yaml b/cookbook/client/twinkle/megatron/server_config.yaml
index f431bb21..91c300f9 100644
--- a/cookbook/client/twinkle/megatron/server_config.yaml
+++ b/cookbook/client/twinkle/megatron/server_config.yaml
@@ -34,22 +34,21 @@ applications:
   # 2. Model Service - Hosts the base model for training (Megatron backend)
   # This is the actual model worker that performs forward/backward passes.
-  - name: models-Qwen2.5-3B-Instruct
-    route_prefix: /models/Qwen/Qwen2.5-3B-Instruct  # REST path for this model
+  - name: models-Qwen3-4B
+    route_prefix: /models/Qwen/Qwen3-4B  # REST path for this model
     import_path: model
     args:
       use_megatron: true  # Use Megatron-LM backend (not HuggingFace)
       mixed_precision: bf16
-      model_id: "ms://Qwen/Qwen2.5-3B-Instruct"  # ModelScope model identifier to load
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier to load
       nproc_per_node: 2  # Number of GPU processes per node
       device_group:  # Logical device group for this model
         name: model
-        ranks: [0,1]  # GPU rank indices to use
+        ranks: 2  # Number of GPUs to use
         device_type: cuda
       device_mesh:  # Distributed training mesh configuration
         device_type: cuda
-        mesh: [0,1]  # Device indices in the mesh
-        mesh_dim_names: ['dp']  # Mesh dimension names: 'dp' = data parallel
+        dp_size: 2  # Data parallel size
       adapter_config:
         adapter_timeout: 1800  # Seconds before idle adapter unload
     deployments:
@@ -71,12 +70,11 @@ applications:
       ncpu_proc_per_node: 2  # Number of CPU processes per node
       device_group:
         name: model
-        ranks: 2  # CPU rank index
+        ranks: 2  # Number of CPU workers to use
         device_type: CPU
       device_mesh:
         device_type: CPU
-        mesh: [0,1]
-        mesh_dim_names: ['dp']
+        dp_size: 2  # Data parallel size
     deployments:
       - name: ProcessorManagement
         autoscaling_config:
diff --git a/cookbook/client/twinkle/sample.py b/cookbook/client/twinkle/sample.py
index 27f22fba..75aeb5b1 100644
--- a/cookbook/client/twinkle/sample.py
+++ b/cookbook/client/twinkle/sample.py
@@ -22,7 +22,7 @@
 logger = get_logger()
 
-MODEL_ID = 'Qwen/Qwen2.5-3B-Instruct'
+MODEL_ID = 'Qwen/Qwen3-4B'
 
 # Optional: adapter URI for LoRA inference
 # This can be a twinkle:// path from a training run checkpoint
diff --git a/cookbook/client/twinkle/self_congnition.py b/cookbook/client/twinkle/self_congnition.py
index fd23726f..781b809f 100644
--- a/cookbook/client/twinkle/self_congnition.py
+++ b/cookbook/client/twinkle/self_congnition.py
@@ -26,7 +26,7 @@
 # Step 2: Initialize the Twinkle client to communicate with the remote server.
 # - base_url: the address of the running Twinkle server
 # - api_key: authentication token (loaded from environment variable)
-client = init_twinkle_client(base_url='http://127.0.0.1:8000', api_key=os.environ.get('MODELSCOPE_TOKEN'))
+client = init_twinkle_client(base_url='http://127.0.0.1:8000', api_key='EMPTY_TOKEN')
 
 # Step 3: Query the server for existing training runs and their checkpoints.
 # This is useful for resuming a previous training session.
@@ -51,7 +51,7 @@ def train():
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
 
     # Apply a chat template so the data matches the model's expected input format
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-3B-Instruct', max_length=512)
+    dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B', max_length=512)
 
     # Replace placeholder names in the dataset with custom model/author names
     dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'})
@@ -65,7 +65,7 @@ def train():
     # Step 5: Configure the model
 
     # Create a multi-LoRA Transformers model pointing to the base model on ModelScope
-    model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen2.5-3B-Instruct')
+    model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3-4B')
 
     # Define LoRA configuration: apply low-rank adapters to all linear layers
     lora_config = LoraConfig(target_modules='all-linear')
diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml
index 3e9e1472..f10b5b3e 100644
--- a/cookbook/client/twinkle/transformer/server_config.yaml
+++ b/cookbook/client/twinkle/transformer/server_config.yaml
@@ -34,18 +34,18 @@ applications:
   # 2. Model Service - Hosts the base model for training
   # This is the actual model worker that performs forward/backward passes.
-  - name: models-Qwen2.5-3B-Instruct
-    route_prefix: /models/Qwen/Qwen2.5-3B-Instruct  # REST path for this model
+  - name: models-Qwen3-4B
+    route_prefix: /models/Qwen/Qwen3-4B  # REST path for this model
     import_path: model
     args:
       use_megatron: false  # Use HuggingFace Transformers (not Megatron)
-      model_id: "ms://Qwen/Qwen2.5-3B-Instruct"  # ModelScope model identifier to load
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier to load
       adapter_config:
         adapter_timeout: 1800  # Seconds before an idle adapter is unloaded
       nproc_per_node: 2  # Number of GPU processes per node
       device_group:  # Logical device group for this model
         name: model
-        ranks: [0,1]  # GPU rank indices to use
+        ranks: 2  # Number of GPUs to use
         device_type: cuda
       device_mesh:  # Distributed training mesh configuration
         device_type: cuda
@@ -72,12 +72,11 @@ applications:
       ncpu_proc_per_node: 2  # Number of CPU processes per node
       device_group:
         name: model
-        ranks: 2  # CPU rank index
+        ranks: 2  # Number of CPU workers to use
         device_type: CPU
       device_mesh:
         device_type: CPU
-        mesh: [0,1]
-        mesh_dim_names: ['dp']
+        dp_size: 2  # Data parallel size
     deployments:
       - name: ProcessorManagement
         autoscaling_config:
@@ -92,11 +91,11 @@ applications:
   # 4. Sampler Service - Handles text generation inference
   # Uses vLLM for efficient batched generation with optional LoRA adapters.
-  - name: sampler-Qwen2.5-3B-Instruct
-    route_prefix: /samplers/Qwen/Qwen2.5-3B-Instruct  # REST path for this sampler
+  - name: sampler-Qwen3-4B
+    route_prefix: /samplers/Qwen/Qwen3-4B  # REST path for this sampler
     import_path: sampler
     args:
-      model_id: "ms://Qwen/Qwen2.5-3B-Instruct"  # ModelScope model identifier to load
+      model_id: "ms://Qwen/Qwen3-4B"  # ModelScope model identifier to load
       sampler_type: vllm  # Sampler backend (vllm or torch)
       nproc_per_node: 2  # Number of GPU processes per node
       engine_args:  # vLLM engine configuration
@@ -106,7 +105,7 @@ applications:
         adapter_timeout: 1800  # Seconds before idle adapter is unloaded
       device_group:
         name: sampler
-        ranks: [2]  # GPU rank indices to use
+        ranks: 1  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md
index e5b80023..39f1d029 100644
--- a/docs/source_en/Usage Guide/Server and Client/Server.md
+++ b/docs/source_en/Usage Guide/Server and Client/Server.md
@@ -74,7 +74,7 @@ applications:
       nproc_per_node: 4
       device_group:
         name: model
-        ranks: [0, 1, 2, 3]  # Physical GPU card numbers
+        ranks: 4  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -91,7 +91,7 @@ applications:
       nproc_per_node: 2
       device_group:
         name: sampler
-        ranks: [4, 5]  # Physical GPU card numbers 4-5
+        ranks: 2  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -112,68 +112,11 @@ applications:
         dp_size: 4  # Data parallel size
 ```
 **Important notes:**
-- The `ranks` configuration uses **physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine
-- The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` instead of the original `mesh` and `mesh_dim_names`
+- The `ranks` configuration specifies the **number of GPUs** to allocate for the component
+- The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` to define the parallelization strategy
 - Different components will be automatically assigned to different Nodes
 - Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`)
 
-In the YAML configuration file, **each component needs to occupy a separate Node**.
-
-**Example configuration:**
-
-```yaml
-applications:
-  # Model service occupies Node 0 (Head node, GPU 0-3)
-  - name: models-Qwen2.5-7B-Instruct
-    route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
-    import_path: model
-    args:
-      nproc_per_node: 4
-      device_group:
-        name: model
-        ranks: [0, 1, 2, 3]  # GPU indices within Node 0
-        device_type: cuda
-      device_mesh:
-        device_type: cuda
-        mesh: [0, 1, 2, 3]
-        mesh_dim_names: ['dp']
-
-  # Sampler service occupies Node 1 (Worker node, GPU 4-7)
-  - name: sampler-Qwen2.5-7B-Instruct
-    route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
-    import_path: sampler
-    args:
-      nproc_per_node: 2
-      device_group:
-        name: sampler
-        ranks: [0, 1]  # GPU indices within Node 1 (corresponding to physical GPU 4-5)
-        device_type: cuda
-      device_mesh:
-        device_type: cuda
-        mesh: [0, 1]
-        mesh_dim_names: ['dp']
-
-  # Processor service occupies Node 2 (CPU node)
-  - name: processor
-    route_prefix: /processors
-    import_path: processor
-    args:
-      ncpu_proc_per_node: 4
-      device_group:
-        name: processor
-        ranks: 0  # CPU index within Node 2
-        device_type: CPU
-      device_mesh:
-        device_type: CPU
-        mesh: [0, 1, 2, 3]
-        mesh_dim_names: ['dp']
-```
-
-**Important notes:**
-- The `ranks` configuration for each component is relative to the Ray Node it occupies
-- Different components are automatically assigned to different Nodes
-- Ray automatically schedules components to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`)
-
 ## Startup Methods
 
 The Server is uniformly launched through the `launch_server` function or CLI command, with YAML configuration files.
@@ -263,7 +206,7 @@ applications:
       nproc_per_node: 2  # Number of GPU processes per node
       device_group:  # Logical device group
         name: model
-        ranks: [0, 1]  # GPU card numbers to use
+        ranks: 2  # Number of GPUs to use
         device_type: cuda
       device_mesh:  # Distributed training mesh
         device_type: cuda
@@ -319,7 +262,7 @@ The difference from the Transformers backend is only in the `use_megatron` param
       nproc_per_node: 2
       device_group:
         name: model
-        ranks: [0, 1]
+        ranks: 2
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -374,7 +317,7 @@ applications:
       nproc_per_node: 2
       device_group:
         name: model
-        ranks: [0, 1]
+        ranks: 2
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -404,7 +347,7 @@ applications:
         enable_lora: true  # Support loading LoRA during inference
       device_group:
         name: sampler
-        ranks: [0]
+        ranks: 1
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -435,13 +378,13 @@ applications:
 
 ### device_group and device_mesh
 
-- **device_group**: Defines logical device groups, specifying which GPU cards to use
+- **device_group**: Defines logical device groups, specifying how many GPUs to use
 - **device_mesh**: Defines distributed training mesh, controls parallelization strategy
 
 ```yaml
 device_group:
   name: model  # Device group name
-  ranks: [0, 1]  # Physical GPU card number list
+  ranks: 2  # Number of GPUs to use
   device_type: cuda  # Device type: cuda / CPU
 
 device_mesh:
@@ -456,7 +399,7 @@ device_mesh:
 
 | Parameter | Type | Description |
 |------|------|------|
-| `ranks` | list[int] | **Physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine |
+| `ranks` | int | **Number of GPUs to use** for this component |
 | `dp_size` | int | Data parallel size |
 | `tp_size` | int (optional) | Tensor parallel size |
 | `pp_size` | int (optional) | Pipeline parallel size |
diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
index 8978e2a3..2e781ad9 100644
--- a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
+++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
@@ -6,7 +6,7 @@ The Tinker Client is suitable for scenarios with existing Tinker training code.
 
 ```python
 # Initialize Tinker client before importing ServiceClient
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 init_tinker_client()
 
@@ -41,7 +41,7 @@ import dotenv
 dotenv.load_dotenv('.env')
 
 # Step 1: Initialize Tinker client before importing ServiceClient
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 init_tinker_client()
 from tinker import types, ServiceClient
@@ -139,7 +139,7 @@ Tinker compatible mode can also leverage Twinkle's dataset components to simplify
 ```python
 from tqdm import tqdm
 from tinker import types
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
@@ -216,7 +216,7 @@ You can also load saved checkpoints for inference:
 import os
 from tinker import types
 from modelscope import AutoTokenizer
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 # Initialize Tinker client before importing ServiceClient
 init_tinker_client()
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
index 7c204f2c..e11ded44 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
@@ -6,7 +6,7 @@ Tinker Client 适用于已有 Tinker 训练代码的场景。通过 `init_tinker
 
 ```python
 # 在导入 ServiceClient 之前,先初始化 Tinker 客户端
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 init_tinker_client()
 
@@ -41,7 +41,7 @@ import dotenv
 dotenv.load_dotenv('.env')
 
 # Step 1: 在导入 ServiceClient 之前,先初始化 Tinker 客户端
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 init_tinker_client()
 from tinker import types, ServiceClient
@@ -139,7 +139,7 @@ Tinker 兼容模式也可以利用 Twinkle 的数据集组件来简化数据准
 ```python
 from tqdm import tqdm
 from tinker import types
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.preprocessor import SelfCognitionProcessor
@@ -216,7 +216,7 @@ for i, seq in enumerate(result.sequences):
 import os
 from tinker import types
 from modelscope import AutoTokenizer
-from twinkle_client import init_tinker_client
+from twinkle import init_tinker_client
 
 # 在导入 ServiceClient 之前,先初始化 Tinker 客户端
 init_tinker_client()
"a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" index 73915e8c..7fa547e2 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" @@ -74,7 +74,7 @@ applications: nproc_per_node: 4 device_group: name: model - ranks: [0, 1, 2, 3] # 物理 GPU 卡号 + ranks: 4 # 使用的 GPU 数量 device_type: cuda device_mesh: device_type: cuda @@ -91,7 +91,7 @@ applications: nproc_per_node: 2 device_group: name: sampler - ranks: [4, 5] # 物理 GPU 卡号 4-5 + ranks: 2 # 使用的 GPU 数量 device_type: cuda device_mesh: device_type: cuda @@ -112,8 +112,8 @@ applications: dp_size: 4 # 数据并行大小 ``` **重要提示:** -- `ranks` 配置使用**物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 -- `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数替代原来的 `mesh` 和 `mesh_dim_names` +- `ranks` 配置指定为该组件分配的 **GPU 数量** +- `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数定义并行策略 - 不同组件会自动分配到不同的 Node 上 - Ray 会根据资源需求(`ray_actor_options` 中的 `num_gpus`、`num_cpus`)自动调度到合适的 Node @@ -206,7 +206,7 @@ applications: nproc_per_node: 2 # 每节点 GPU 进程数 device_group: # 逻辑设备组 name: model - ranks: [0, 1] # 物理 GPU 卡号 + ranks: 2 # 使用的 GPU 数量 device_type: cuda device_mesh: # 分布式训练网格 device_type: cuda @@ -262,7 +262,7 @@ applications: nproc_per_node: 2 device_group: name: model - ranks: [0, 1] + ranks: 2 device_type: cuda device_mesh: device_type: cuda @@ -317,7 +317,7 @@ applications: nproc_per_node: 2 device_group: name: model - ranks: [0, 1] + ranks: 2 device_type: cuda device_mesh: device_type: cuda @@ -347,7 +347,7 @@ applications: enable_lora: true # 支持推理时加载 LoRA device_group: name: sampler - ranks: [0] + ranks: 1 device_type: cuda device_mesh: device_type: cuda @@ -378,13 +378,13 @@ applications: ### device_group 与 device_mesh -- **device_group**:定义逻辑设备组,指定使用哪些 GPU 卡 +- **device_group**:定义逻辑设备组,指定使用多少 GPU - **device_mesh**:定义分布式训练网格,控制并行策略 ```yaml device_group: name: model # 设备组名称 - ranks: [0, 1] # 物理 GPU 卡号列表 + ranks: 2 # 使用的 GPU 数量 device_type: cuda # 设备类型:cuda / CPU device_mesh: @@ -399,7 +399,7 @@ device_mesh: | 参数 | 类型 | 说明 | |------|------|------| -| `ranks` | list[int] | **物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 | +| `ranks` | int | **使用的 GPU 数量** | | `dp_size` | int | 数据并行大小 | | `tp_size` | int (可选) | 张量并行大小 | | `pp_size` | int (可选) | 流水线并行大小 | diff --git a/pyproject.toml b/pyproject.toml index 76ca660d..53f29f4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ kernels = ["kernels"] megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"] vllm = ["vllm>=0.11"] ray = ["ray[serve]"] +tinker = ["tinker==0.14.0"] docs = [ "sphinx>=5.3.0,<6.0.0", "docutils>=0.16.0,<0.17.0", diff --git a/src/twinkle/__init__.py b/src/twinkle/__init__.py index 63ffb66a..f64917a5 100644 --- a/src/twinkle/__init__.py +++ b/src/twinkle/__init__.py @@ -4,11 +4,11 @@ from .utils.import_utils import _LazyModule # noqa if TYPE_CHECKING: + from twinkle_client import init_tinker_client, init_twinkle_client from .infra 
     from .infra import get_device_placement, initialize, is_master, remote_class, remote_function
     from .utils import (GPU, NPU, DeviceGroup, DeviceMesh, Platform, Plugin, check_unsafe, exists, find_free_port,
                         find_node_ip, framework_util, get_logger, requires, torch_util, trust_remote_code)
     from .version import __release_datetime__, __version__
-
 else:
     _import_structure = {
         'version': ['__release_datetime__', '__version__'],
@@ -21,10 +21,15 @@
     import sys
 
+    from twinkle_client import init_tinker_client, init_twinkle_client
+
     sys.modules[__name__] = _LazyModule(
         __name__,
         globals()['__file__'],
         _import_structure,
         module_spec=__spec__,  # noqa
-        extra_objects={},
+        extra_objects={
+            'init_tinker_client': init_tinker_client,
+            'init_twinkle_client': init_twinkle_client
+        },
     )
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index aa74e72e..71d76056 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -880,9 +880,12 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
         """
         resume = kwargs.pop('load_optimizer', False)
         if output_dir is None and not resume:
-            # Load from hub
-            token = kwargs.pop('token', None)
-            checkpoint_dir = HubOperation.download_model(name, token=token)
+            if os.path.exists(name):
+                checkpoint_dir = name
+            else:
+                # load from hub
+                token = kwargs.pop('token', None)
+                checkpoint_dir = HubOperation.download_model(name, token=token)
         else:
             if output_dir is None:
                 output_dir = 'output'
diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py
index bc74786d..dc16f624 100644
--- a/src/twinkle/model/multi_lora.py
+++ b/src/twinkle/model/multi_lora.py
@@ -115,7 +115,7 @@ def acquire_lora(self, tenant_adapter_name: str, config: LoraConfig) -> str:
             raise ValueError(f'Lora {tenant_adapter_name} already exists')
         _available_lora = self._get_available_lora()
         if _available_lora is None:
-            raise RuntimeError(f'No lora available for tenant {tenant_adapter_name}')
+            raise RuntimeError(f'No lora available for tenant {tenant_adapter_name}. Max loras: {self.max_loras}')
         if config.r > self.max_r:
             raise RuntimeError(f'Too big rank for lora: {config.r}')
         _available_lora.tenant_config = config
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 45df3082..1e8c1700 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -842,9 +842,12 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
         adapter_name = kwargs.pop('adapter_name', self._get_default_group())
 
         if output_dir is None:
-            # load from hub
-            token = kwargs.pop('token', None)
-            checkpoint_dir = HubOperation.download_model(name, token=token)
+            if os.path.exists(name):
+                checkpoint_dir = name
+            else:
+                # load from hub
+                token = kwargs.pop('token', None)
+                checkpoint_dir = HubOperation.download_model(name, token=token)
         else:
             checkpoint_dir = os.path.join(output_dir, name)
         model = self.strategy.unwrap_model(self.model)
diff --git a/src/twinkle/server/tinker/common/router.py b/src/twinkle/server/tinker/common/router.py
new file mode 100644
index 00000000..19ec8650
--- /dev/null
+++ b/src/twinkle/server/tinker/common/router.py
@@ -0,0 +1,76 @@
+from ray.serve.request_router import (FIFOMixin, MultiplexMixin, PendingRequest, ReplicaID, ReplicaResult,
+                                      RequestRouter, RunningReplica)
+from typing import Dict, List, Optional
+
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class StickyLoraRequestRouter(FIFOMixin, MultiplexMixin, RequestRouter):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.state: ServerStateProxy = get_server_state()
+
+    async def choose_replicas(
+        self,
+        candidate_replicas: List[RunningReplica],
+        pending_request: Optional[PendingRequest] = None,
+    ) -> List[List[RunningReplica]]:
+        """
+        Choose the best replica for the request based on the request's
+        multiplexed model id and each replica's available LoRA slot count.
+        The algorithm works as follows:
+
+        1. Populate top_ranked_replicas with the replicas ranked for the
+           request's multiplexed_model_id (only one replica is chosen).
+        2. Populate and override top_ranked_replicas based on the available
+           LoRA slots of each replica.
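+        3. Fall back to all candidates when no replica has remaining LoRA
+           capacity; otherwise return the replica with the lowest observed
+           throughput.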
+ """ + + # Take the best set of replicas for the multiplexed model + if (pending_request is not None and pending_request.metadata.multiplexed_model_id): + ranked_replicas_multiplex: List[RunningReplica] = (self.rank_replicas_via_multiplex( + replicas=candidate_replicas, + multiplexed_model_id=pending_request.metadata.multiplexed_model_id, + ))[0] + + # If any replica was found, return it + if ranked_replicas_multiplex: + logger.debug('[Router] Found replica for multiplexed model') + return [ranked_replicas_multiplex] + + # Dictionary to hold the top-ranked replicas + top_ranked_replicas: Dict[ReplicaID, RunningReplica] = {} + + # Filter out replicas that are not available (queue length exceeds max ongoing requests) + ranked_replicas_locality = self.select_available_replicas(candidates=candidate_replicas) + + for replica in ranked_replicas_locality: + top_ranked_replicas[replica.replica_id] = replica + + # Filter out replicas that exceed the max LoRA count (queried from server state) + candidate_ids = [r.replica_id.unique_id for r in top_ranked_replicas.values()] + available_ids = set(self.state.get_available_replica_ids(candidate_ids)) + if available_ids: + top_ranked_replicas = { + rid: r + for rid, r in top_ranked_replicas.items() if r.replica_id.unique_id in available_ids + } + + if not top_ranked_replicas: + # No replica has remaining LoRA capacity – fall back to all candidates + logger.debug('[Router] No replica has remaining LoRA capacity') + return [candidate_replicas] + + logger.debug('[Router] StickyLoraRequestRouter choosing replica for request') + + # Take the replica with the minimum throughput. + min_throughput_replica = min( + top_ranked_replicas.values(), + key=lambda r: r.routing_stats.get('throughput', 0), + ) + return [[min_throughput_replica]] diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 55d7e3bd..30ced15e 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -13,6 +13,7 @@ from fastapi import FastAPI, Request from peft import LoraConfig from ray import serve +from ray.serve.config import RequestRouterConfig from tinker import types from typing import Any, Dict, Optional @@ -21,9 +22,10 @@ from twinkle.server.utils.adapter_manager import AdapterManagerMixin from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin -from twinkle.server.utils.validation import verify_request_token +from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger from .common.io_utils import create_checkpoint_manager, create_training_run_manager +from .common.router import StickyLoraRequestRouter logger = get_logger() @@ -62,7 +64,10 @@ async def verify_token(request: Request, call_next): """Middleware to verify authentication token for all requests.""" return await verify_request_token(request=request, call_next=call_next) - @serve.deployment(name='ModelManagement') + @serve.deployment( + name='ModelManagement', + request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), + ) @serve.ingress(app) class ModelManagement(TaskQueueMixin, AdapterManagerMixin): """Model management service handling training operations. 
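A note on how the pieces above fit together: Ray Serve's multiplexed routing, which StickyLoraRequestRouter extends, keys off the serve_multiplexed_model_id request header, and the patched tinker ServiceClient (see patch_tinker.py further below) injects exactly that header. A minimal client-side sketch; the URL, route, token, and body here are illustrative placeholders, not part of this patch:

import requests

# Hypothetical endpoint and credentials; only the header names are real.
resp = requests.post(
    'http://localhost:8000/api/v1/model/Qwen/Qwen3-4B/get_info',
    headers={
        'Authorization': 'Bearer EMPTY-TOKEN',
        'Twinkle-Authorization': 'Bearer EMPTY-TOKEN',
        # Requests sharing this id are ranked toward the replica that
        # already holds the matching sticky entry / LoRA adapter.
        'serve_multiplexed_model_id': 'client-session-1',
    },
    json={'model_id': 'model-abc'},
)
print(resp.status_code)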
@@ -99,8 +104,8 @@ def __init__(self, else: self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.use_megatron = use_megatron - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id + self.replica_id = serve.get_replica_context().replica_id.unique_id + self.max_loras = kwargs.get('max_loras', 5) # Initialize model immediately - choose backend based on use_megatron if use_megatron: from .common.megatron_model import TwinkleCompatMegatronModel @@ -108,7 +113,7 @@ def __init__(self, model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, - instance_id=replica_id, + instance_id=self.replica_id, **kwargs) else: from .common.transformers_model import TwinkleCompatTransformersModel @@ -116,11 +121,14 @@ def __init__(self, model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, - instance_id=replica_id, + instance_id=self.replica_id, **kwargs) self.base_model = model_id self.state: ServerStateProxy = get_server_state() + # Register this replica so the router can track capacity + self.state.register_replica(self.replica_id, self.max_loras) + # Initialize task queue self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) @@ -128,7 +136,7 @@ def __init__(self, self.start_adapter_countdown() """ - TODO This is a cache system, we must change to sticky routing + This is a cache system used to implement sticky routing Reference docs: 1. [Now]https://docs.ray.io/en/latest/serve/model-multiplexing.html 2. https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html @@ -136,9 +144,21 @@ def __init__(self, 4. Direct call actor instead of http or handler in server.py """ - # @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) - # async def get_multiplexed_adapter(self, request_id: str): - # return request_id + @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + + def __del__(self): + self.state.unregister_replica(self.replica_id) def _cleanup_adapter(self, adapter_name: str) -> None: """Common adapter cleanup logic used by both manual unload and automatic expiration. 
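_ensure_sticky above works by touching a @serve.multiplexed cache entry keyed by the request's multiplexed model id, which is what lets Ray's router pin later requests with the same id to this replica. A self-contained sketch of that primitive, following the pattern from Ray's model-multiplexing docs; the deployment and key names are illustrative:

from ray import serve
from starlette.requests import Request


@serve.deployment
class StickyDemo:

    # LRU cache keyed by the multiplexed model id; Ray's router prefers
    # replicas that already hold a given key.
    @serve.multiplexed(max_num_models_per_replica=5)
    async def _load(self, model_id: str):
        return model_id  # a real service would load a LoRA adapter here

    async def __call__(self, request: Request) -> str:
        # Ray populates this from the serve_multiplexed_model_id header.
        key = serve.get_multiplexed_model_id()
        await self._load(key)
        return key


app = StickyDemo.bind()
# serve.run(app), then send requests carrying a serve_multiplexed_model_id header.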
@@ -188,12 +208,13 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) - Returns: UntypedAPIFuture wrapping CreateModelResponse with model_id """ + token = await self._on_request_start(request) async def _create_adapter(): model_id = None try: # Register a new model_id for each create_model call - model_id = self.state.register_model(body.model_dump(), token=request.state.token) + model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) # Create a new LoRA adapter for the model if body.lora_config: @@ -203,7 +224,7 @@ async def _create_adapter(): adapter_name = self.get_adapter_name(adapter_name=model_id) # Register adapter FIRST - self.register_adapter(adapter_name, request.state.token, session_id=body.session_id) + self.register_adapter(adapter_name, token, session_id=body.session_id) # Create adapter AFTER successful registration self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) @@ -215,7 +236,7 @@ async def _create_adapter(): # Fresh adapter has no accumulated gradients. self.set_adapter_state(adapter_name, 'grad_ready', False) - training_run_manager = create_training_run_manager(request.state.token) + training_run_manager = create_training_run_manager(token) training_run_manager.save(model_id, body) return types.CreateModelResponse(model_id=model_id) @@ -233,7 +254,7 @@ async def _create_adapter(): return await self.schedule_task( _create_adapter, - token=request.state.token, + token=token, task_type='create_model', ) @@ -248,9 +269,10 @@ async def get_info(self, request: Request, body: types.GetInfoRequest) -> types. Returns: GetInfoResponse with model metadata (name, lora_rank, etc.) """ + token = await self._on_request_start(request) # Note: get_info doesn't require token for reading metadata in tinker # Using a default token or None since this is read-only - training_run_manager = create_training_run_manager(request.state.token) + training_run_manager = create_training_run_manager(token) metadata = training_run_manager.get(str(body.model_id)) model_name = metadata.base_model if metadata else model_id lora_rank = None @@ -279,6 +301,7 @@ async def unload_model(self, request: Request, body: types.UnloadModelRequest) - Returns: UntypedAPIFuture wrapping UnloadModelResponse """ + token = await self._on_request_start(request) async def _do_unload(): # Only remove adapter, not the base model @@ -290,7 +313,7 @@ async def _do_unload(): return await self.schedule_task( _do_unload, model_id=body.model_id, - token=request.state.token, + token=token, task_type='unload_model', ) @@ -307,6 +330,7 @@ async def forward(self, request: Request, body: types.ForwardRequest) -> types.U Returns: UntypedAPIFuture wrapping ForwardBackwardOutput with loss """ + token = await self._on_request_start(request) async def _do_forward(): try: @@ -340,7 +364,7 @@ async def _do_forward(): return await self.schedule_task( _do_forward, model_id=body.model_id, - token=request.state.token, + token=token, input_tokens=input_tokens, batch_size=batch_size, data_world_size=self.device_mesh.data_world_size, @@ -364,6 +388,7 @@ async def forward_backward(self, request: Request, Returns: UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics """ + token = await self._on_request_start(request) async def _do_forward_backward(): try: @@ -405,7 +430,7 @@ async def _do_forward_backward(): return await self.schedule_task( _do_forward_backward, model_id=body.model_id, - token=request.state.token, + token=token, 
input_tokens=input_tokens, batch_size=batch_size, data_world_size=self.device_mesh.data_world_size, @@ -425,6 +450,7 @@ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> ty Returns: UntypedAPIFuture wrapping OptimStepResponse """ + token = await self._on_request_start(request) async def _do_optim(): try: @@ -455,7 +481,7 @@ async def _do_optim(): return await self.schedule_task( _do_optim, model_id=body.model_id, - token=request.state.token, + token=token, task_type='optim_step', ) @@ -473,6 +499,7 @@ async def save_weights(self, request: Request, body: types.SaveWeightsRequest) - Returns: UntypedAPIFuture wrapping SaveWeightsResponse with saved path """ + token = await self._on_request_start(request) async def _do_save(): try: @@ -482,8 +509,6 @@ async def _do_save(): # Touch adapter to reset inactivity counter self.touch_adapter(adapter_name) - # Extract token from request for user isolation - token = request.state.token checkpoint_manager = create_checkpoint_manager(token) # get save dir with token-based isolation @@ -506,7 +531,7 @@ async def _do_save(): return await self.schedule_task( _do_save, model_id=body.model_id, - token=request.state.token, + token=token, task_type='save_weights', ) @@ -525,6 +550,7 @@ async def save_weights_for_sampler(self, request: Request, Returns: UntypedAPIFuture wrapping SaveWeightsForSamplerResponseInternal """ + token = await self._on_request_start(request) async def _do_save_for_sampler(): try: @@ -535,8 +561,6 @@ async def _do_save_for_sampler(): # Touch adapter to reset inactivity counter self.touch_adapter(adapter_name) - # Extract token from request for user isolation - token = request.state.token checkpoint_manager = create_checkpoint_manager(token) # get save dir with token-based isolation @@ -571,7 +595,7 @@ async def _do_save_for_sampler(): return await self.schedule_task( _do_save_for_sampler, model_id=body.model_id, - token=request.state.token, + token=token, task_type='save_weights_for_sampler', ) @@ -589,6 +613,7 @@ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) - Returns: UntypedAPIFuture wrapping LoadWeightsResponse """ + token = await self._on_request_start(request) async def _do_load(): try: @@ -600,9 +625,6 @@ async def _do_load(): # Touch adapter to reset inactivity counter self.touch_adapter(adapter_name) - # Extract token from request for user isolation - token = request.state.token - weight_path = body.path load_optimizer = body.optimizer @@ -625,7 +647,7 @@ async def _do_load(): return await self.schedule_task( _do_load, model_id=body.model_id, - token=request.state.token, + token=token, task_type='load_weights', ) diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index 8ab6fd91..20b0a5a1 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -21,7 +21,7 @@ from twinkle.data_format import SamplingParams from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin -from twinkle.server.utils.validation import verify_request_token +from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger from .common.io_utils import create_checkpoint_manager @@ -126,6 +126,19 @@ def __init__(self, self.state: ServerStateProxy = get_server_state() self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) + 
@serve.multiplexed(max_num_models_per_replica=5) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + @app.post('/asample') async def asample(self, request: Request, body: types.SampleRequest) -> types.UntypedAPIFuture: """Execute text generation (inference). @@ -144,6 +157,7 @@ async def asample(self, request: Request, body: types.SampleRequest) -> types.Un Returns: UntypedAPIFuture wrapping SampleResponse with generated sequences """ + token = await self._on_request_start(request) async def _do_sample(): try: @@ -160,7 +174,6 @@ async def _do_sample(): # Parse and resolve adapter URI from model_path adapter_uri = None if model_path: - token = request.state.token checkpoint_manager = create_checkpoint_manager(token) adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) @@ -225,7 +238,7 @@ async def _do_sample(): input_tokens = len(body.prompt.to_ints()) return await self.schedule_task( _do_sample, - token=request.state.token, + token=token, input_tokens=input_tokens, task_type='sample', ) diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py index 203540b9..1a95b6c2 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/io_utils.py @@ -6,6 +6,8 @@ file-based storage of training run metadata and checkpoint information. Both tinker and twinkle servers inherit from these classes. """ +import hashlib +import hmac import json import os import re @@ -25,6 +27,20 @@ CHECKPOINT_INFO_FILENAME = 'checkpoint_metadata.json' TRAIN_RUN_INFO_FILENAME = 'twinkle_metadata.json' +# Salt used when hashing tokens for directory isolation. +# Override via env var TWINKLE_TOKEN_SALT to customise per-deployment. +_TOKEN_SALT = os.environ.get('TWINKLE_TOKEN_SALT', 'twinkle-path-salt-v1').encode('utf-8') + + +def _hash_token(token: str) -> str: + """Return a salted HMAC-SHA256 hex digest of *token*. + + The digest is used as the per-user base directory name so that the raw + token value is never written to the filesystem. + """ + return hmac.new(_TOKEN_SALT, token.encode('utf-8'), hashlib.sha256).hexdigest()[:16] + + # ----- Common Pydantic Models ----- @@ -275,13 +291,15 @@ def get_base_dir(self) -> Path: """ Get base directory with token-based isolation. + The token is never written to disk in plaintext; instead a salted + HMAC-SHA256 digest is used as the directory name so that the real + token cannot be recovered by inspecting the filesystem. + Returns: Path to token-specific base directory """ base_path = Path(TWINKLE_DEFAULT_SAVE_DIR).absolute() - # Sanitize token to avoid filesystem issues - sanitized_token = re.sub(r'[^\w\-]', '_', self.token) - return base_path / sanitized_token + return base_path / _hash_token(self.token) def get_model_dir(self, model_id: str) -> Path: """ diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py index 9e0d02b8..586e4868 100644 --- a/src/twinkle/server/utils/state/model_manager.py +++ b/src/twinkle/server/utils/state/model_manager.py @@ -13,6 +13,9 @@ class ModelManager(BaseManager[ModelRecord]): its owning session has already been removed (cascade expiry). Enforces a per-token model limit across all model instances (server-global). 
+ + Also tracks replica registrations so the router can query which replicas + still have capacity (i.e. their loaded-model count < max_loras). """ def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) -> None: @@ -20,6 +23,58 @@ def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) - self._per_token_model_limit = per_token_model_limit # token -> set of model_ids owned by that token self._token_models: dict[str, set[str]] = {} + # replica_id -> set of model_ids currently loaded on that replica + self._replica_models: dict[str, set[str]] = {} + # replica_id -> max_loras limit declared at registration time + self._replica_max_loras: dict[str, int] = {} + + # ----- Replica Registration ----- + + def register_replica(self, replica_id: str, max_loras: int) -> None: + """Register a replica and its LoRA capacity. + + Args: + replica_id: Unique identifier for the replica. + max_loras: Maximum number of LoRA adapters the replica can hold. + """ + self._replica_max_loras[replica_id] = max_loras + self._replica_models.setdefault(replica_id, set()) + + def unregister_replica(self, replica_id: str) -> None: + """Remove a replica from the registry. + + Any model associations for this replica are also cleared. + + Args: + replica_id: Unique identifier for the replica to remove. + """ + self._replica_max_loras.pop(replica_id, None) + self._replica_models.pop(replica_id, None) + + def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + """Return the subset of candidate replica IDs that still have capacity. + + A replica has capacity when its current loaded-model count is strictly + less than its declared ``max_loras``. Replicas that are not registered + (unknown to this manager) are included as-is (conservative fallback). + + Args: + candidate_ids: Replica IDs to evaluate. + + Returns: + Filtered list preserving the original order. + """ + available = [] + for rid in candidate_ids: + max_loras = self._replica_max_loras.get(rid) + if max_loras is None: + # Unknown replica – include conservatively + available.append(rid) + continue + current = len(self._replica_models.get(rid, set())) + if current < max_loras: + available.append(rid) + return available # ----- CRUD ----- @@ -39,10 +94,12 @@ def add(self, model_id: str, record: ModelRecord) -> None: raise RuntimeError(f'Model limit exceeded: ' f'{len(current_ids)}/{self._per_token_model_limit} models') self._token_models.setdefault(token, set()).add(model_id) + if record.replica_id is not None: + self._replica_models.setdefault(record.replica_id, set()).add(model_id) self._store[model_id] = record def remove(self, model_id: str) -> bool: - """Remove a record by ID and clean up token ownership. + """Remove a record by ID and clean up token and replica ownership. Returns: True if the record existed and was removed, False otherwise. 
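To make the capacity rule above concrete, here is a standalone illustration of the same semantics implemented by register_replica and get_available_replica_ids; plain dicts and made-up replica ids, not the ModelManager itself:

# Simplified stand-in for the ModelManager bookkeeping.
replica_max_loras = {'replica-a': 2, 'replica-b': 1}  # declared capacity
replica_models = {'replica-a': {'model-1'}, 'replica-b': {'model-2'}}  # loaded

def available(candidate_ids):
    out = []
    for rid in candidate_ids:
        limit = replica_max_loras.get(rid)
        if limit is None:
            out.append(rid)  # unknown replica: conservative fallback, keep it
        elif len(replica_models.get(rid, set())) < limit:
            out.append(rid)  # strictly below max_loras: still has capacity
    return out

print(available(['replica-a', 'replica-b', 'replica-c']))
# ['replica-a', 'replica-c']: b is at capacity; c is unregistered, kept as-is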
@@ -50,11 +107,7 @@ def remove(self, model_id: str) -> bool: record = self._store.pop(model_id, None) if record is None: return False - token = record.token - if token and token in self._token_models: - self._token_models[token].discard(model_id) - if not self._token_models[token]: - del self._token_models[token] + self._cleanup_ownership(model_id, record) return True # ----- Cleanup ----- @@ -87,10 +140,23 @@ def cleanup_expired(self, cutoff_time: float, expired_session_ids: list[str] | N for model_id in expired_ids: record = self._store.pop(model_id) - token = record.token - if token and token in self._token_models: - self._token_models[token].discard(model_id) - if not self._token_models[token]: - del self._token_models[token] + self._cleanup_ownership(model_id, record) return len(expired_ids) + + # ----- Internal helpers ----- + + def _cleanup_ownership(self, model_id: str, record: ModelRecord) -> None: + """Remove token and replica ownership entries for a model record. + + Args: + model_id: The model ID being removed. + record: The associated ModelRecord. + """ + token = record.token + if token and token in self._token_models: + self._token_models[token].discard(model_id) + if not self._token_models[token]: + del self._token_models[token] + if record.replica_id and record.replica_id in self._replica_models: + self._replica_models[record.replica_id].discard(model_id) diff --git a/src/twinkle/server/utils/state/models.py b/src/twinkle/server/utils/state/models.py index d8499ff8..71279b89 100644 --- a/src/twinkle/server/utils/state/models.py +++ b/src/twinkle/server/utils/state/models.py @@ -30,6 +30,7 @@ class ModelRecord(BaseModel): base_model: str | None = None user_metadata: dict[str, Any] = Field(default_factory=dict) lora_config: Any = None + replica_id: str | None = None created_at: str = Field(default_factory=_now_iso) diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 82605410..7588c65d 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -89,13 +89,18 @@ def get_session_last_heartbeat(self, session_id: str) -> float | None: # ----- Model Registration ----- - def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None) -> str: + def register_model(self, + payload: dict[str, Any], + token: str, + model_id: str | None = None, + replica_id: str | None = None) -> str: """Register a new model with the server state. Args: payload: Model configuration containing base_model, lora_config, etc. token: User token that owns this model. Required. model_id: Optional explicit model_id; otherwise auto-generated. + replica_id: Optional replica that is hosting this model. Returns: The model_id for the registered model. @@ -112,6 +117,7 @@ def register_model(self, payload: dict[str, Any], token: str, model_id: str | No user_metadata=payload.get('user_metadata') or {}, lora_config=payload.get('lora_config'), token=token, + replica_id=replica_id, ) self._model_mgr.add(_model_id, record) return _model_id @@ -129,6 +135,36 @@ def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: record = self._model_mgr.get(model_id) return record.model_dump() if record is not None else None + # ----- Replica Management ----- + + def register_replica(self, replica_id: str, max_loras: int) -> None: + """Register a replica and its LoRA capacity. + + Args: + replica_id: Unique identifier for the replica. 
+ max_loras: Maximum number of LoRA adapters the replica can hold. + """ + self._model_mgr.register_replica(replica_id, max_loras) + + def unregister_replica(self, replica_id: str) -> None: + """Remove a replica from the registry. + + Args: + replica_id: Unique identifier for the replica to remove. + """ + self._model_mgr.unregister_replica(replica_id) + + def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + """Return candidate replica IDs that have not reached their max_loras limit. + + Args: + candidate_ids: Replica IDs to evaluate. + + Returns: + Filtered list of replica IDs with remaining capacity. + """ + return self._model_mgr.get_available_replica_ids(candidate_ids) + # ----- Sampling Session Management ----- def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: @@ -344,8 +380,12 @@ def get_session_last_heartbeat(self, session_id: str) -> float | None: # ----- Model Registration ----- - def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None) -> str: - return ray.get(self._actor.register_model.remote(payload, token, model_id)) + def register_model(self, + payload: dict[str, Any], + token: str, + model_id: str | None = None, + replica_id: str | None = None) -> str: + return ray.get(self._actor.register_model.remote(payload, token, model_id, replica_id)) def unload_model(self, model_id: str) -> bool: return ray.get(self._actor.unload_model.remote(model_id)) @@ -353,6 +393,17 @@ def unload_model(self, model_id: str) -> bool: def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: return ray.get(self._actor.get_model_metadata.remote(model_id)) + # ----- Replica Management ----- + + def register_replica(self, replica_id: str, max_loras: int) -> None: + ray.get(self._actor.register_replica.remote(replica_id, max_loras)) + + def unregister_replica(self, replica_id: str) -> None: + ray.get(self._actor.unregister_replica.remote(replica_id)) + + def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + return ray.get(self._actor.get_available_replica_ids.remote(candidate_ids)) + # ----- Sampling Session Management ----- def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str: diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index 25564306..58c43a37 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -1,12 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations -from twinkle.utils import requires -from .http.utils import get_api_key, get_base_url, set_api_key, set_base_url -from .manager import TwinkleClient, TwinkleClientError - - - def init_tinker_client(**kwargs) -> None: """Initialize Tinker client with Twinkle-specific headers. @@ -20,11 +14,13 @@ def init_tinker_client(**kwargs) -> None: **kwargs: Additional keyword arguments (currently unused, reserved for future) Example: - >>> from twinkle_client import init_tinker_client + >>> from twinkle import init_tinker_client >>> init_tinker_client() >>> from tinker import ServiceClient >>> client = ServiceClient(base_url='http://localhost:8000', api_key='your_token') """ + from twinkle.utils import requires + requires('tinker') from twinkle_client.utils.patch_tinker import patch_tinker @@ -36,6 +32,9 @@ def init_twinkle_client(base_url: str | None = None, api_key: str | None = None, """ Initialize a Twinkle client and setup context variables. 
""" + from .http.utils import get_api_key, get_base_url, set_api_key, set_base_url + from .manager import TwinkleClient, TwinkleClientError + if base_url is not None: set_base_url(base_url) else: @@ -49,4 +48,4 @@ def init_twinkle_client(base_url: str | None = None, api_key: str | None = None, return TwinkleClient(base_url=base_url, api_key=api_key, **kwargs) -__all__ = ['TwinkleClient', 'TwinkleClientError', 'init_tinker_client', 'init_twinkle_client'] +__all__ = ['init_tinker_client', 'init_twinkle_client'] diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py index 73363472..d0245c20 100644 --- a/src/twinkle_client/utils/patch_tinker.py +++ b/src/twinkle_client/utils/patch_tinker.py @@ -12,6 +12,8 @@ import os from typing import TYPE_CHECKING, Any, Mapping, Union +from twinkle_client.http.utils import get_api_key, get_request_id + _patched = False @@ -115,6 +117,29 @@ def _patched_from_tinker_path(cls, tinker_path: str) -> Any: ) +def _make_patched_service_client_init(original): + def _patched_service_client_init(self, user_metadata=None, **kwargs): + """Patched version of ServiceClient.__init__ that injects Twinkle-specific headers.""" + # Resolve api_key with the same priority order used by AsyncTinker: + # 1. explicit kwarg 2. TINKER_API_KEY env var 3. TWINKLE_SERVER_TOKEN env var + api_key = kwargs.get('api_key') + if api_key is None: + api_key = get_api_key() + + twinkle_headers = { + 'serve_multiplexed_model_id': get_request_id(), + 'Authorization': 'Bearer ' + api_key, + 'Twinkle-Authorization': 'Bearer ' + api_key, + } + # Merge: caller-supplied default_headers take precedence over twinkle_headers + user_default_headers = kwargs.pop('default_headers', {}) + kwargs['default_headers'] = twinkle_headers | user_default_headers + + original(self, user_metadata=user_metadata, **kwargs) + + return _patched_service_client_init + + def patch_tinker(): """ Apply patches to tinker library. @@ -146,29 +171,7 @@ def patch_tinker(): # Patch 4: inject Twinkle-specific headers by patching ServiceClient.__init__. from tinker.lib.public_interfaces.service_client import ServiceClient - from twinkle_client.http.utils import get_request_id, get_api_key - - _original_service_client_init = ServiceClient.__init__ - - def _patched_service_client_init(self, user_metadata=None, **kwargs): - # Resolve api_key with the same priority order used by AsyncTinker: - # 1. explicit kwarg 2. TINKER_API_KEY env var 3. TWINKLE_SERVER_TOKEN env var - api_key = kwargs.get('api_key') - if api_key is None: - api_key = get_api_key() - - twinkle_headers = { - 'serve_multiplexed_model_id': get_request_id(), - 'Authorization': 'Bearer ' + api_key, - 'Twinkle-Authorization': 'Bearer ' + api_key, - } - # Merge: caller-supplied default_headers take precedence over twinkle_headers - user_default_headers = kwargs.pop('default_headers', {}) - kwargs['default_headers'] = twinkle_headers | user_default_headers - - _original_service_client_init(self, user_metadata=user_metadata, **kwargs) - - ServiceClient.__init__ = _patched_service_client_init + ServiceClient.__init__ = _make_patched_service_client_init(ServiceClient.__init__) _patched = True except ImportError: