diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 99a7a5e15..31607875f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,7 +24,9 @@ repos:
    hooks:
      - id: ruff-check
        args: [--fix, --exit-non-zero-on-fix]
+        exclude: ^examples/specdec_bench/specdec_bench/datasets/speed\.py$
      - id: ruff-format
+        exclude: ^examples/specdec_bench/specdec_bench/datasets/speed\.py$

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.17.1
diff --git a/examples/specdec_bench/README.md b/examples/specdec_bench/README.md
index 770edf8d7..3040f89f9 100644
--- a/examples/specdec_bench/README.md
+++ b/examples/specdec_bench/README.md
@@ -41,6 +41,52 @@ python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b -
 ```
+### Running [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) on Llama 3.3 70B + Eagle 3
+
+1. Install the requirements using `pip install -r requirements.txt`
+
+2. Prepare the data using the provided script:
+
+```bash
+python3 prepare_data.py --dataset speed --config all
+```
+
+The data will be saved to the `data/` directory, with each config type (qualitative, throughput_1k, ...) in its own subdirectory.
+
+#### License
+
+GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement.
+
+ADDITIONAL INFORMATION: MIT for bigcode/humanevalpack, RUCAIBox/MMATH, RUCAIBox/BAMBOO and EQ-Bench. Apache 2.0 for Writing Bench and Spec-Bench. CC BY 4.0 for FBK-MT/MCIF. MIT and Apache 2.0 for tianyang/repobench_python_v1.1, JetBrains-Research/lca-project-level-code-completion and tianyang/repobench_java_v1.1.
+
+NOTICE: For each dataset a user elects to use, the user is responsible for checking if the dataset license is fit for the intended purpose. The `prepare_data.py` script automatically fetches data from all the source datasets.
+
+Additional details are in the [HuggingFace dataset repository](https://huggingface.co/datasets/nvidia/SPEED-Bench).
+
+#### Qualitative split
+
+```bash
+python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/qualitative --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
+```
+
+#### Throughput split
+
+```bash
+python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_1k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
+```
+
+For longer context (>8192 tokens), please use the following configuration when using TRTLLM:
+
+```yaml
+engine_args:
+  max_seq_len: 131072 # Model max context length (for Llama 3.3 70B)
+  enable_chunked_prefill: true
+```
+
+```bash
+python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_16k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress --runtime_params runtime_args_long_context.yaml
+```
+
 ## Notes
 
 The goal of this benchmark is to provide an easy way to configure, run, and compare speculative implementations across frameworks in an apples-to-apples method.
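To sanity-check the prepared data before launching a run, note that each split produced by `prepare_data.py` is written as a plain `test.parquet` file under `data/<dataset>/<config>/`. A minimal inspection sketch, assuming the default `--output_dir data/` and the `qualitative` config:

```python
from datasets import load_dataset

# Load one prepared SPEED-Bench split (path assumes the default output layout
# produced by `python3 prepare_data.py --dataset speed --config qualitative`).
ds = load_dataset(
    "parquet",
    data_files={"test": "data/speed/qualitative/test.parquet"},
    split="test",
)

print(ds)  # row count and column names
example = ds[0]
print(example["question_id"], example["category"])
print(example["turns"][0][:200])  # first 200 characters of the first user turn
```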
diff --git a/examples/specdec_bench/SPECBENCH_PORTING.md b/examples/specdec_bench/SPECBENCH_PORTING.md new file mode 100644 index 000000000..bdc2be5c3 --- /dev/null +++ b/examples/specdec_bench/SPECBENCH_PORTING.md @@ -0,0 +1,329 @@ +# Porting Spec-Bench Inference Runners to specdec_bench + +This guide explains how to convert any `inference_*.py` runner from [Spec-Bench](https://github.com/hemingkx/Spec-Bench) to a model class compatible with `specdec_bench`. + +## Overview + +Spec-Bench inference runners follow a pattern where: + +1. A `*_forward()` function handles the speculative decoding logic +2. The `run_eval()` function orchestrates evaluation with tokenized inputs +3. Models are loaded in `__main__` and passed to `run_eval()` + +In contrast, `specdec_bench` uses a class-based approach where: + +1. Models inherit from the `Model` base class +2. `__init__()` handles model loading +3. `run()` is an async method that processes single requests +4. `stop()` handles cleanup + +## The specdec_bench Model Interface + +```python +class Model: + def __init__(self, model_dir, tokenizer, max_draft_length): + raise NotImplementedError + + async def run(self, prompt_ids, sampling_params, request_id, turn_id): + """ + prompt_ids: list of token IDs (not a tensor!) + Returns dict with: + - output_ids: list of list of token chunks per step [[chunk1, chunk2, ...]] + - output_logits: optional logits (usually None) + - token_times: list of timestamps per decoding step + """ + raise NotImplementedError + + def stop(self): + pass +``` + +## Step-by-Step Porting Guide + +### Step 1: Identify the Key Components in Spec-Bench + +Look at the `inference_*.py` file and identify: + +1. **The forward function** (e.g., `medusa_forward`, `ea_forward`) + - This contains the core speculative decoding loop + - Signature: `forward_func(inputs, model, tokenizer, max_new_tokens, **kwargs)` + - Returns: `(output_ids, new_token_count, num_steps, accept_length_list)` + +2. **The model class** (e.g., `MedusaModel`, `EaModel`) + - Found in `model//` directory + - Has a `from_pretrained()` class method + +3. **Required utilities** from the method's module: + - Buffer generation (e.g., `generate_medusa_buffers`) + - Initialization functions (e.g., `initialize_medusa`, `initialize_past_key_values`) + - Decoding functions (e.g., `tree_decoding`, `generate_candidates`) + - State update functions (e.g., `update_inference_inputs`) + +4. **Method-specific choices/configs** (e.g., `mc_sim_7b_63` for Medusa) + +### Step 2: Create the specdec_bench Model Class + +```python +# specdec_bench/specdec_bench/models/specbench_.py + +from .base import Model +import asyncio +import time +import torch + +# Import dependencies from Spec-Bench +try: + import sys + import os + spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench") + sys.path.insert(0, spec_bench_path) + from model.. import + from model..kv_cache import initialize_past_key_values + from model..utils import ( + # Import all required utilities + ) + from model.. import +except ImportError as e: + print(f" dependencies not found: {e}") + = None + + +class SpecBenchModel(Model): + def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs): + # 1. Validate dependencies + if is None: + raise ImportError(" dependencies not found.") + + # 2. Extract configuration from kwargs + self.dtype = kwargs.get("dtype", "float16") + self.max_steps = kwargs.get("max_steps", 512) + self.temperature = sampling_kwargs.get("temperature", 0.0) + # ... 
other method-specific parameters + + # 3. Set up device (avoid device_map="auto" for multi-GPU issues) + self.device = torch.device(kwargs.get("device", "cuda:0")) + + # 4. Convert dtype string to torch dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map.get(self.dtype, torch.float16) + + # 5. Load the model + self.model = .from_pretrained( + model_dir, + # ... other args from Spec-Bench's __main__ + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(self.device) + + self.sampling_kwargs = sampling_kwargs +``` + +### Step 3: Port the Forward Function + +Convert the standalone `*_forward()` function to an internal method: + +```python + def _forward(self, input_ids, max_new_tokens, end_id): + """ + Port of the original *_forward function. + + Key changes from Spec-Bench: + 1. input_ids is already a tensor (converted in run()) + 2. Add timing list to track per-step timestamps + 3. Use self.device instead of model.base_model.device + 4. Return timing along with other outputs + """ + accept_length_list = [] + timing = [time.perf_counter()] # ADD: Track timing + + # === COPY THE FORWARD LOGIC FROM SPEC-BENCH === + # Replace: device=model.base_model.device + # With: device=self.device + + # Initialize buffers... + # Initialize KV cache... + # Main decoding loop... + + for idx in range(self.max_steps): + # Generate candidates... + # Tree decoding... + # Evaluate posterior... + # Update inputs... + + timing.append(time.perf_counter()) # ADD: Record time per step + + # Check for EOS + if end_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + + return input_ids, new_token, idx + 1, accept_length_list, timing # ADD timing +``` + +### Step 4: Implement the run() Method + +```python + async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + """ + Async interface for specdec_bench. 
+ + Args: + prompt_ids: List of input token IDs (NOT a tensor) + max_length: Maximum new tokens to generate + end_id: EOS token ID + request_id: Request identifier + turn_id: Turn identifier + + Returns: + dict with output_ids, output_logits, token_times + """ + output_dict = {} + + # Convert prompt_ids list to tensor + input_ids = torch.tensor( + [prompt_ids], dtype=torch.long, device=self.device + ) + + # Run forward pass (use asyncio.to_thread for sync code) + result = await asyncio.to_thread( + self._forward, input_ids, max_length, end_id + ) + input_ids_out, new_token, num_steps, accept_length_list, timing = result + + # Extract generated tokens (excluding prompt) + original_len = len(prompt_ids) + generated_tokens = input_ids_out[0, original_len:].tolist() + + # Remove EOS token if present + if end_id in generated_tokens: + eos_idx = generated_tokens.index(end_id) + generated_tokens = generated_tokens[:eos_idx] + + # Format output_ids as list of token chunks per step + # This matches specdec_bench's expected format + reformatted_output_ids = [[]] + start = 0 + for accept_len in accept_length_list: + if accept_len > 0 and start < len(generated_tokens): + chunk = generated_tokens[start:start + accept_len] + if chunk: + reformatted_output_ids[0].append(chunk) + start += accept_len + + # Handle remaining tokens + if start < len(generated_tokens): + reformatted_output_ids[0].append(generated_tokens[start:]) + + output_dict['output_ids'] = reformatted_output_ids + output_dict['output_logits'] = None + output_dict['token_times'] = timing + + return output_dict +``` + +### Step 5: Implement stop() for Cleanup + +```python + def stop(self): + """Clean up resources.""" + # Clear any cached states + if hasattr(self.model, "past_key_values"): + del self.model.past_key_values + del self.model.past_key_values_data + del self.model.current_length_data + + # Clear method-specific buffers + if hasattr(self.model, "_buffers"): + del self.model._buffers + + # Free GPU memory + if hasattr(self, 'model') and self.model is not None: + del self.model + torch.cuda.empty_cache() +``` + +### Step 6: Register the Model (Optional) + +Add to `specdec_bench/specdec_bench/models/__init__.py`: + +```python +from .specbench_ import SpecBenchModel +``` + +## Key Differences Summary + +| Aspect | Spec-Bench | specdec_bench | +|--------|-----------|---------------| +| Input format | `inputs.input_ids` (tensor from tokenizer) | `prompt_ids` (list of ints) | +| Output format | `(output_ids, new_token, steps, accept_lengths)` | `dict` with `output_ids`, `output_logits`, `token_times` | +| Output IDs | Full sequence tensor | List of token chunks per step | +| Timing | External (in `run_eval`) | Internal (in `run()`) | +| Device | `device_map="auto"` | Explicit single device | +| Interface | Function-based | Class-based with async `run()` | + +## Common Pitfalls + +1. **Device Mismatch**: Avoid `device_map="auto"` which spreads model across GPUs. Use explicit `.to(device)`. + +2. **Tensor vs List**: `prompt_ids` in specdec_bench is a Python list, not a tensor. Convert it in `run()`. + +3. **Output Format**: specdec_bench expects `output_ids` as `[[chunk1, chunk2, ...]]` (list of lists of lists for beam_width=1). + +4. **Timing**: Add `time.perf_counter()` calls to track per-step latency. + +5. **EOS Handling**: Strip EOS tokens from output before formatting. + +6. **Async Wrapper**: Use `asyncio.to_thread()` to wrap synchronous forward passes. 
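To make pitfalls 3 and 4 above concrete, the sketch below (using made-up token IDs and acceptance lengths) shows how accepted tokens map onto the nested `output_ids` structure and per-step `token_times` that specdec_bench expects for beam_width=1:

```python
import time

# Hypothetical decoder output: 6 generated tokens, accepted in steps of 2, 3, and 1.
generated_tokens = [101, 102, 103, 104, 105, 106]
accept_length_list = [2, 3, 1]

# One timestamp before the loop plus one per decoding step (pitfall 4).
token_times = [time.perf_counter()]

# beam_width=1 -> outer list of length 1, containing one chunk (list of ints) per step.
output_ids = [[]]
start = 0
for accept_len in accept_length_list:
    chunk = generated_tokens[start:start + accept_len]
    if chunk:
        output_ids[0].append(chunk)
    start += accept_len
    token_times.append(time.perf_counter())

assert output_ids == [[[101, 102], [103, 104, 105], [106]]]
result = {"output_ids": output_ids, "output_logits": None, "token_times": token_times}
```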
+
+## Example: Mapping Spec-Bench Methods
+
+| Spec-Bench File | Model Class | Forward Function | Key Utils |
+|-----------------|-------------|------------------|-----------|
+| `inference_medusa.py` | `MedusaModel` | `medusa_forward` | `generate_medusa_buffers`, `initialize_medusa` |
+| `inference_eagle.py` | `EaModel` | `ea_forward` | `generate_tree_buffers`, `initialize_tree` |
+| `inference_eagle2.py` | `EaModel` | `ea_forward` | Same as EAGLE |
+| `inference_hydra.py` | `HydraModel` | `hydra_forward` | `generate_hydra_buffers`, `initialize_hydra` |
+| `inference_lookahead.py` | `LookaheadModel` | `lookahead_forward` | Lookahead-specific utils |
+
+## Testing Your Port
+
+```python
+import asyncio
+
+async def test():
+    model = SpecBenchModel(
+        model_dir="/path/to/model",
+        max_concurrent_requests=1,
+        sampling_kwargs={"temperature": 0.0},
+        # method-specific kwargs...
+    )
+
+    result = await model.run(
+        prompt_ids=[1, 2, 3, 4, 5],  # Example token IDs
+        max_length=100,
+        end_id=2,  # EOS token
+        request_id="test",
+        turn_id=0
+    )
+
+    print("Output chunks:", result['output_ids'])
+    print("Timing:", result['token_times'])
+
+    model.stop()
+
+asyncio.run(test())
+```
+
+If you are porting a Vicuna-based method, make sure the Vicuna chat template is present in the model's `tokenizer_config.json` so that specdec_bench can apply it when encoding chats.
+
+Insert the following into `tokenizer_config.json` (for Vicuna):
+
+```json
+"chat_template": "{% set ns = namespace(system='') %}{% for m in messages %}{% if m['role'] == 'system' %}{% set ns.system = m['content'] %}{% endif %}{% endfor %}{{ ns.system | trim }}{% if ns.system | trim != '' %} {% endif %}{% for m in messages %}{% if m['role'] == 'user' %}USER: {{ m['content'] | trim }} ASSISTANT:{% elif m['role'] == 'assistant' %}{{ m['content'] | trim }}{% endif %}{% endfor %}"
+```
diff --git a/examples/specdec_bench/prepare_data.py b/examples/specdec_bench/prepare_data.py
new file mode 100644
index 000000000..67fe89898
--- /dev/null
+++ b/examples/specdec_bench/prepare_data.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+from typing import get_args
+
+from specdec_bench import datasets
+from specdec_bench.datasets.speed import config_type
+
+datasets_available = {
+    "speed": datasets.SPEEDBench,
+}
+
+
+def prepare_data(args: argparse.Namespace) -> None:
+    """Prepare and save benchmark data to disk.
+
+    Calls the dataset's ``prepare_data`` classmethod which downloads and
+    resolves all external data references, then saves the fully-resolved
+    result as a parquet file so that subsequent benchmark runs can load
+    directly from disk without re-downloading.
+
+    Args:
+        args: Parsed CLI arguments containing dataset type, config,
+            output directory, and optional filtering parameters.
+ """ + configs = get_args(config_type) if args.config == "all" else [args.config] + + dataset_cls = datasets_available[args.dataset] + + for config in configs: + print(f"Preparing config '{config}' ...") + + output_path = dataset_cls.prepare_data( + output_dir=args.output_dir / args.dataset / config, + config_name=config, + ) + + print(f" -> Saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Download and prepare benchmark datasets for specdec_bench.", + ) + parser.add_argument( + "--dataset", + type=str, + default="speed", + choices=list(datasets_available.keys()), + help="Dataset to prepare (default: %(default)s)", + ) + parser.add_argument( + "--config", + type=str, + default="all", + choices=[*list(get_args(config_type)), "all"], + help='SPEED-Bench configuration to prepare. Use "all" to prepare all configs. (default: %(default)s)', + ) + parser.add_argument( + "--output_dir", + type=Path, + default=Path("data/"), + help="Directory to save the prepared dataset files (default: %(default)s)", + ) + + args = parser.parse_args() + prepare_data(args) diff --git a/examples/specdec_bench/requirements.txt b/examples/specdec_bench/requirements.txt new file mode 100644 index 000000000..17bf6f8fd --- /dev/null +++ b/examples/specdec_bench/requirements.txt @@ -0,0 +1,4 @@ +datasets~=4.4.0 +rich~=14.2.0 +seaborn~=0.13.2 +tiktoken~=0.12.0 diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py index dd57a5142..e2ea21758 100644 --- a/examples/specdec_bench/run.py +++ b/examples/specdec_bench/run.py @@ -25,15 +25,48 @@ postprocess_base, postprocess_gptoss, ) +from tqdm.asyncio import tqdm engines_available = { "TRTLLM": models.TRTLLMPYTModel, "VLLM": models.VLLMModel, "SGLANG": models.SGLANGModel, + "AUTO_DEPLOY": models.AutoDeployModel, + "SPECBENCH_MEDUSA": models.SpecBenchMedusaModel, } +datasets_available = { + "mtbench": datasets.MTBench, + "random": datasets.RandomToken, + "specbench": datasets.SpecBench, + "speed": datasets.SPEEDBench, +} + + +async def tqdm_gather(*fs, return_exceptions=False, **kwargs): + if not return_exceptions: + return await tqdm.gather(*fs, **kwargs) + async def wrap(f): + try: + return await f + except Exception as e: + return e -async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10): + return await tqdm.gather(*map(wrap, fs), **kwargs) + + +async def run_loop( + runner, + dataset, + tokenizer, + output_length, + postprocess, + concurrency=10, + end_id=-1, + show_progress=False, + completions=False, + chat_template_args={}, +): """ Async version of run_loop with concurrency control using a semaphore. 
@@ -46,7 +79,6 @@ async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concu """ semaphore = asyncio.Semaphore(concurrency) max_length = output_length - end_id = tokenizer.eos_token_id async def process_single_request(request, i): """Process a single request with all its conversation turns.""" @@ -57,7 +89,12 @@ async def process_single_request(request, i): for turn_id, question in enumerate(request.turns): messages.append({"role": "user", "content": question}) - entry_encoded = encode_chat(tokenizer, messages) + entry_encoded = encode_chat( + tokenizer, + messages, + chat_template_args=chat_template_args, + completions=completions, + ) # Run the async runner.run directly output_tokens = await runner.run( @@ -70,12 +107,19 @@ async def process_single_request(request, i): return messages tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)] - text_outputs = await asyncio.gather(*tasks, return_exceptions=True) + if show_progress: + text_outputs = await tqdm_gather( + *tasks, + return_exceptions=True, + desc=f"Running requests (concurrency={concurrency})", + ) + else: + text_outputs = await asyncio.gather(*tasks, return_exceptions=True) # Check for any exceptions and handle them for i, result in enumerate(text_outputs): if isinstance(result, Exception): - print(f"Error processing request {i}: {result}") + print(f"Error processing request {i}/{dataset.data[i].question_id}: {result}") raise result runner.process_metrics_final(text_outputs) @@ -84,13 +128,22 @@ async def process_single_request(request, i): def run_simple(args): tokenizer = get_tokenizer(args.tokenizer) + chat_template_args = args.runtime_params.get("chat_template_args", {}) dataset_kwargs = args.runtime_params.get("dataset_kwargs", {}) - if args.mtbench is not None: - dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs) + if args.num_requests is not None: + dataset_kwargs["num_samples"] = args.num_requests + if args.dataset is not None: + if args.dataset == "random": + assert args.random_isl is not None, "Random input length must be provided" + dataset = datasets.RandomToken(tokenizer, args.random_isl, **dataset_kwargs) + else: + dataset = datasets_available[args.dataset](args.dataset_path, **dataset_kwargs) + elif args.mtbench is not None: + dataset = datasets.MTBench(args.mtbench, **dataset_kwargs) elif args.random_isl is not None: - dataset = datasets.RandomToken( - tokenizer, args.random_isl, args.num_requests, **dataset_kwargs - ) + dataset = datasets.RandomToken(tokenizer, args.random_isl, **dataset_kwargs) + elif args.specbench is not None: + dataset = datasets.SpecBench(args.specbench, **dataset_kwargs) engine_args = args.runtime_params.get("engine_args", {}) sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0}) model_class = engines_available[args.engine] @@ -111,8 +164,15 @@ def run_simple(args): metrics_list.append(metrics.AATiming(tokenizer)) if args.mtbench is not None: metrics_list.insert(0, metrics.MTBench()) + elif args.specbench is not None or args.dataset == "speed": + metrics_list.insert(0, metrics.SpecBench(requests=dataset.data)) else: metrics_list.insert(0, metrics.AcceptanceRate()) + + if args.save_dir is not None: + for metric in metrics_list: + metric.update_directory(args.save_dir) + runner = runners.SimpleRunner(model, metrics=metrics_list) if args.postprocess == "base": @@ -122,8 +182,21 @@ def run_simple(args): else: raise ValueError(f"Invalid postprocess: {args.postprocess}") + end_id = 
tokenizer.eos_token_id if not args.ignore_eos else -1 + asyncio.run( - run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency) + run_loop( + runner, + dataset, + tokenizer, + args.output_length, + postprocess, + args.concurrency, + end_id, + args.show_progress, + args.completions, + chat_template_args, + ) ) runner.clear_metrics() @@ -135,7 +208,18 @@ def run_simple(args): "--tokenizer", type=str, required=True, help="Path to the tokenizer directory" ) parser.add_argument( - "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset" + "--mtbench", + type=str, + required=False, + default=None, + help="Path to the mtbench dataset", + ) + parser.add_argument( + "--specbench", + type=str, + required=False, + default=None, + help="Path to the specbench dataset", ) parser.add_argument( "--random_isl", @@ -144,7 +228,28 @@ def run_simple(args): default=None, help="How many tokens random input should be.", ) - parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run") + parser.add_argument( + "--dataset", + type=str, + required=False, + default=None, + choices=list(datasets_available.keys()), + help="Dataset to use", + ) + parser.add_argument( + "--dataset_path", + type=str, + required=False, + default=None, + help="Path to the dataset or config name for SPEEDBench", + ) + parser.add_argument( + "--num_requests", + type=int, + required=False, + default=None, + help="Number of requests to run. If not provided, all requests from the dataset will be run.", + ) parser.add_argument( "--engine", type=str, @@ -194,6 +299,13 @@ def run_simple(args): help="Maximum number of concurrent requests", ) parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric") + parser.add_argument("--ignore_eos", action="store_true", help="Ignore EOS token") + parser.add_argument("--show_progress", action="store_true", help="Show progress bar") + parser.add_argument( + "--completions", + action="store_true", + help="Skip chat template, tokenize the message directly", + ) parser.add_argument( "--postprocess", type=str, @@ -202,7 +314,13 @@ def run_simple(args): choices=["base", "gptoss"], help="Postprocess to use", ) - + parser.add_argument( + "--save_dir", + type=str, + required=False, + default=None, + help="Directory to save the results", + ) args = parser.parse_args() if args.runtime_params is not None: @@ -210,9 +328,20 @@ def run_simple(args): args.runtime_params = yaml.safe_load(f) else: args.runtime_params = {} + if args.dataset is None: + assert ( + args.mtbench is not None or args.random_isl is not None or args.specbench is not None + ), "Either mtbench or random_isl or specbench must be provided" + else: + assert args.dataset_path is not None, "Dataset path must be provided" + if args.dataset == "specbench": + args.specbench = args.dataset_path + elif args.dataset == "mtbench": + args.mtbench = args.dataset_path - assert args.mtbench is not None or args.random_isl is not None, ( - "Either mtbench or random_isl must be provided" - ) + if args.ignore_eos: + print( + "Warning: Ignore EOS should only be used in certain cases, do no activate unless necessary" + ) run_simple(args) diff --git a/examples/specdec_bench/specdec_bench/__init__.py b/examples/specdec_bench/specdec_bench/__init__.py index 3159bfe65..47f1c65a1 100644 --- a/examples/specdec_bench/specdec_bench/__init__.py +++ b/examples/specdec_bench/specdec_bench/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,3 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + diff --git a/examples/specdec_bench/specdec_bench/datasets/__init__.py b/examples/specdec_bench/specdec_bench/datasets/__init__.py index 64449d2b5..aefc2605b 100644 --- a/examples/specdec_bench/specdec_bench/datasets/__init__.py +++ b/examples/specdec_bench/specdec_bench/datasets/__init__.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import Dataset -from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat from .mtbench import MTBench from .random_token import RandomToken +from .specbench import SpecBench +from .speed import SPEEDBench + +__all__ = ["MTBench", "RandomToken", "SPEEDBench", "SpecBench"] diff --git a/examples/specdec_bench/specdec_bench/datasets/base.py b/examples/specdec_bench/specdec_bench/datasets/base.py index 587c04b07..eb72affa4 100644 --- a/examples/specdec_bench/specdec_bench/datasets/base.py +++ b/examples/specdec_bench/specdec_bench/datasets/base.py @@ -14,11 +14,14 @@ # limitations under the License. from dataclasses import dataclass, field +from pathlib import Path from typing import Any @dataclass class Request: + question_id: int | None = None + category: str | None = None system_prompt: str | None = None turns: list[str] = field(default_factory=list) mm_content: Any | None = None # TODO @@ -35,3 +38,22 @@ def __init__(self, path, **kwargs): def _preprocess(self): raise NotImplementedError + + @classmethod + def prepare_data(cls, output_dir: str | Path, **kwargs) -> Path: + """Prepare and save the dataset to the specified output directory. + + Downloads any external data, resolves all references, and persists + the fully-resolved dataset so that subsequent loads are self-contained. + + Args: + output_dir: Directory where the prepared data file will be saved. + **kwargs: Dataset-specific parameters (e.g. config_name, category). + + Returns: + Path to the saved dataset file. + + Raises: + NotImplementedError: Subclasses must override this method. + """ + raise NotImplementedError diff --git a/examples/specdec_bench/specdec_bench/datasets/base_hf.py b/examples/specdec_bench/specdec_bench/datasets/base_hf.py deleted file mode 100644 index 6c7be3d8c..000000000 --- a/examples/specdec_bench/specdec_bench/datasets/base_hf.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -try: - from datasets import load_dataset -except ImportError: - print("datasets is not installed.") - datasets = None - - -from .base import Dataset, Request - - -class BaseHF(Dataset): - def __init__(self, num_samples=100, **kwargs): - self.data: list[Request] = [] # list of list of questions. - self.num_samples = num_samples - self._preprocess() - - def _preprocess(self): - dataset = self._load_dataset(self.num_samples) - for i, line in enumerate(dataset): - if i == self.num_samples: - break - self.data.append(self._single_line_process(line)) - - def _single_line_process(self, line): - raise NotImplementedError - - def _load_dataset(self, num_samples): - raise NotImplementedError - - -class OpenOrca(BaseHF): - def _single_line_process(self, line, **kwargs): - return Request(system_prompt=line["system_prompt"], turns=[line["question"]]) - - def _load_dataset(self, num_samples): - return load_dataset("Open-Orca/OpenOrca", split="train", streaming=True) - - -class OpenMathInstructv2(BaseHF): - def _single_line_process(self, line, **kwargs): - return Request(system_prompt=None, turns=[line["problem"]]) - - def _load_dataset(self, num_samples): - return load_dataset("nvidia/OpenMathInstruct-2", split="train_1M", streaming=True) - - -class UltraChat(BaseHF): - def _single_line_process(self, line, **kwargs): - return Request( - system_prompt=None, turns=[q for i, q in enumerate(line["data"]) if i % 2 == 0] - ) - - def _load_dataset(self, num_samples): - return load_dataset("stingning/ultrachat", split="train", streaming=True) diff --git a/examples/specdec_bench/specdec_bench/datasets/mtbench.py b/examples/specdec_bench/specdec_bench/datasets/mtbench.py index 53295bdbb..cb58dd210 100644 --- a/examples/specdec_bench/specdec_bench/datasets/mtbench.py +++ b/examples/specdec_bench/specdec_bench/datasets/mtbench.py @@ -33,11 +33,10 @@ class MTBench(Dataset): def __init__(self, path, num_samples=80, **kwargs): self.data: list[Request] = [] # list of list of questions. self.num_samples = num_samples - self.path = path - self._preprocess() + self._preprocess(path) - def _preprocess(self): - with open(self.path) as f: + def _preprocess(self, path): + with open(path) as f: for json_line in f: line = json.loads(json_line) key = "turns" if "turns" in line else "prompt" diff --git a/examples/specdec_bench/specdec_bench/datasets/random_token.py b/examples/specdec_bench/specdec_bench/datasets/random_token.py index 972a0455c..521db57ee 100644 --- a/examples/specdec_bench/specdec_bench/datasets/random_token.py +++ b/examples/specdec_bench/specdec_bench/datasets/random_token.py @@ -24,12 +24,10 @@ def __init__(self, tokenizer, input_len, num_samples=20, **kwargs): self.data: list[Request] = [] # list of list of questions. self.num_samples = num_samples self.input_len = input_len - self.tokenizer = tokenizer - self._preprocess() + self._preprocess(tokenizer) - def _preprocess(self): + def _preprocess(self, tokenizer): np.random.seed(0) - tokenizer = self.tokenizer num_prompts = self.num_samples offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) for i in range(num_prompts): diff --git a/examples/specdec_bench/specdec_bench/datasets/specbench.py b/examples/specdec_bench/specdec_bench/datasets/specbench.py new file mode 100644 index 000000000..a14d34039 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/specbench.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from .base import Dataset, Request + + +class SpecBench(Dataset): + def __init__(self, path, num_samples=480, **kwargs): + self.data: list[Request] = [] # list of list of questions. + self.num_samples = num_samples + self._preprocess(path) + + def _preprocess(self, path): + with open(path) as f: + for json_line in f: + line = json.loads(json_line) + self.data.append( + Request( + question_id=line["question_id"], + category=line["category"], + system_prompt=None, + turns=line["turns"], + ) + ) + self.data = self.data[: self.num_samples] diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py new file mode 100644 index 000000000..acb082507 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/speed.py @@ -0,0 +1,802 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: disable-error-code="index" +import random +import re +from enum import Enum +from pathlib import Path +from typing import Any, Literal, get_args + +from .base import Dataset, Request + +try: + import numpy as np + import pandas as pd + import tiktoken + from datasets import concatenate_datasets, load_dataset + + not_installed = False +except ImportError: + not_installed = True + + +config_type = Literal[ + "qualitative", + "throughput_1k", + "throughput_2k", + "throughput_8k", + "throughput_16k", + "throughput_32k", +] +TURNS_PLACEHOLDER = "FULL BENCHMARK DATA SHOULD BE FETCHED FROM THE SOURCE USING SPECDEC_BENCH" + + +class BenchmarkDataset(str, Enum): + """Enum for benchmark datasets used in SPEED-Bench. + + Each enum value represents a HuggingFace dataset identifier used for + loading external benchmark datasets. 
+ """ + + BAMBOO = "RUCAIBox/BAMBOO" + CNN_DAILYMAIL = "abisee/cnn_dailymail" + HLE = "cais/hle" + LIVECODEBENCH = "livecodebench/code_generation_lite" + CODE_CONTESTS = "deepmind/code_contests" + MTBENCH_101 = "mtbench101/mt-bench-101" + OPUS100 = "Helsinki-NLP/opus-100" + CHATRAG_BENCH = "nvidia/ChatRAG-Bench" + MMLU_PRO = "TIGER-Lab/MMLU-Pro" + ADALEVAL_STACKSELECT = "AdaLEval/stackselect" + ADALEVAL_TEXTSORT = "AdaLEval/textsort" + ROLEBENCH = "ZenMoore/RoleBench" + ROLEBENCH_ROLES = "ZenMoore/RoleBench/roles" + COSER = "Neph0s/CoSER" + + +DATASETS_AND_LOADERS_FUNCTIONS = { + BenchmarkDataset.BAMBOO.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.CNN_DAILYMAIL.value: lambda dataset_name, config_name: load_dataset( + dataset_name, config_name, split="test" + ), + BenchmarkDataset.HLE.value: lambda dataset_name, config_name: load_dataset( + dataset_name, split="test", revision="021a3d71f516a7ac28ceb8d284969902edf1edeb" + ) + if config_name != "train_test_split" + else load_dataset( + dataset_name, split="test", revision="021a3d71f516a7ac28ceb8d284969902edf1edeb" + ).train_test_split(test_size=0.5, shuffle=True, seed=42), + BenchmarkDataset.LIVECODEBENCH.value: lambda dataset_name, config_name: load_dataset( + "json", + data_files={ + "test": [ + f"https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/0fe84c3912ea0c4d4a78037083943e8f0c4dd505/{file_name}.jsonl" + for file_name in ["test", "test2", "test3", "test4", "test5", "test6"] + ] + }, + split="test", + ), + BenchmarkDataset.CODE_CONTESTS.value: lambda dataset_name, config_name: load_dataset( + dataset_name, split="test", revision="802411c3010cb00d1b05bad57ca77365a3c699d6" + ), + BenchmarkDataset.MTBENCH_101.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.OPUS100.value: lambda dataset_name, config_name: load_dataset( + dataset_name, + config_name, + split="test", + revision="805090dc28bf78897da9641cdf08b61287580df9", + ), + BenchmarkDataset.CHATRAG_BENCH.value: lambda dataset_name, config_names: concatenate_datasets( + [ + load_dataset( + dataset_name, + config_name, + split="test", + revision="af6c7d420ddddf21f54f8ab3394bbf462aad2577", + ) + for config_name in config_names + ] + ), + BenchmarkDataset.MMLU_PRO.value: lambda dataset_name, config_name: load_dataset( + dataset_name, split="test", revision="30527804ea8854662078e457808040d872ecdf29" + ), + BenchmarkDataset.ADALEVAL_STACKSELECT.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.ADALEVAL_TEXTSORT.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.ROLEBENCH.value: lambda dataset_name, config_name: pd.read_json( + config_name, lines=True + ), + BenchmarkDataset.ROLEBENCH_ROLES.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.COSER.value: lambda dataset_name, config_name: load_dataset( + "json", + data_files={"test": config_name.replace("tree", "raw") + "/test/test_set.json"}, + split="test", + ), +} + + +class SPEEDBench(Dataset): + def __init__( + self, + config_name: config_type = "qualitative", + num_samples: int | None = None, + _prepare_mode: bool = False, + **kwargs, + ): + if not_installed: + raise ImportError( + "datasets, 
tiktoken, and numpy packages are required to use SPEED-Bench. Please run `pip install -r requirements.txt`" + ) + self.data: list[Request] = [] + self.num_samples = num_samples + self.external_datasets: dict[str, Any] = {} + self._config_name = config_name + self._resolved_dataset = None + self._preprocess(config_name, _prepare_mode=_prepare_mode) + + def _get_external_dataset(self, dataset_name: str, config_name: str | list[str] = "default"): + full_name = f"{dataset_name}_{config_name}" + if full_name not in self.external_datasets: + self.external_datasets[full_name] = DATASETS_AND_LOADERS_FUNCTIONS[dataset_name]( + dataset_name, config_name + ) + if config_name == "train_test_split": + self.external_datasets[full_name] = ( + self.external_datasets[full_name]["train"], + self.external_datasets[full_name]["test"], + ) + return self.external_datasets[full_name] + + @staticmethod + def _generate_stackselect_prompt( + question: str, answers: list[str], answer: str, num_tokens: int + ) -> str: + random.seed(42) + encoder = tiktoken.get_encoding("o200k_base") + prompt = """ +You are an AI assistant. Your job is to find out the most helpful answer to a given question. +Each time, you will be provided with a question and n answers to this question. +Each answer begins with an 'A' and a number(e.g. A4), which represents its designation. +You need to determine which answer is the most helpful one to the question. +The case sample is shown below and you should give me the answer in the format exactly the same as the sample. + +However, you should NOT focus on the content of sample answer. + +Sample Input (format only): + +The question is given below. +XXX(The content of question) +Possible answers are given below. +A1: +XXX(The content of answer 1) +A2: +XXX(The content of answer 2) +. +. +. +An: +XXX(The content of answer n) +Now the answers are over, please decide which answer is the most helpful one to the question. +You must give me the designation of the MOST helpful answer and the reason why you choose this answer. +For every other answer, you must give me the reason why you do not choose this answer. + +Sample Output (format only): + +Answer: The designation of the most helpful answer.(e.g. A4 means answer 4 is the most helpful answer) +Explanation: +A4: The reason why you choose this answer. +A1: The reason why you do not choose this answer. +A2: The reason why you do not choose this answer. +. +. +. +An: The reason why you do not choose this answer. +""" + prompt += "The question is given below.\n" + prompt += question + "\n\n" + prompt += "Possible answers are given below.\n" + tokens_prompt = len(encoder.encode(prompt, disallowed_special=())) + end_prompt = "Now the answers are over, please decide which answer is the most helpful one to the question. 
\n" + end_prompt += "You must give me the designation of the MOST helpful answer and the reason why you choose this answer.\n" + end_prompt += "For every other answer, you must give me the reason why you do not choose this answer.\n" + end_prompt_tokens = len(encoder.encode(end_prompt, disallowed_special=())) + correct_answer_i = int(answer.strip("A")) - 1 + correct_answer_tokens = len( + encoder.encode( + answer + ":\n\n" + answers[correct_answer_i] + "\n\n", + disallowed_special=(), + ) + ) + all_tokens = tokens_prompt + end_prompt_tokens + correct_answer_tokens + answers_to_add_stop = 0 + for i, answer in enumerate(answers): + if i == correct_answer_i: + continue + answer_to_add = f"A{i + 1}:\n\n{answer}\n\n" + answer_to_add_tokens = len(encoder.encode(answer_to_add, disallowed_special=())) + if all_tokens + answer_to_add_tokens > num_tokens: + break + answers_to_add_stop = i + answers_to_add = ( + answers[: answers_to_add_stop + 1] + if answers_to_add_stop >= correct_answer_i + else [answers[correct_answer_i]] + answers[: answers_to_add_stop + 1] + ) + random.shuffle(answers_to_add) + for i, answer in enumerate(answers_to_add): + prompt += f"A{i + 1}:\n\n{answer}\n\n" + prompt += end_prompt + return prompt + + @staticmethod + def _generate_textsort_prompt(prompt: str) -> str: + original_instruction = "\n You are an AI assistant. Your job is to sort multiple book sections into the correct order.\n Each time, you will be provided with 4 pieces of text.\n These texts form a continuous part of a book, but are provided in random order.\n You need to find the correct order and return the answer in a string.\n For example, if you output [4, 1, 3, 2], that means the correct order is: Part 4 -> Part 1 -> Part 3 -> Part 2.\n You will also be provided with the neighboring paragraphs before and after the 4 pieces of texts. \n\n The case sample is shown below and you should give me the answer in the format exactly the same as the sample. \n\n However, you should NOT focus on the content of sample answer. \n\n Please do NOT output any extra content. \n Sample Input (format only): \n\n Before: XXX (Text before the continuous book part)\n\n\n Part 1: XXX\n\n\n Part 2: XXX\n\n\n Part 3: XXX\n\n\n Part 4: XXX\n\n\n After: XXX (Text after the continuous book part)\n\n\n Sample Output (format only): \n\n Answer: [4, 1, 3, 2] \n\n\n\n" + + new_instruction = """ +You are an AI assistant. Your job is to sort multiple book sections into the correct order. + Each time, you will be provided with 4 pieces of text. + These texts form a continuous part of a book, but are provided in random order. + You need to find the correct order and write the all the parts in the correct order. + For example, if the correct order is: Part 4 -> Part 1 -> Part 3 -> Part 2, you need to answer with a continous text of all the parts in the correct order. + You should NOT change the text, just write it in the order it should appear. + You will also be provided with the neighboring paragraphs before and after the 4 pieces of texts. + You should NOT output the before and after paragraphs, just the text in the correct order. + + The case sample is shown below and you should give me the answer in the format exactly the same as the sample. + + However, you should NOT focus on the content of sample answer. + + Please do NOT output any extra content. 
+ + Sample Input (format only): + + Before: BBB (Text before the continuous book part) + + + Part 1: XXX + + + Part 2: YYY + + + Part 3: ZZZ + + + Part 4: WWW + + + After: AAA (Text after the continuous book part) + + Sample Output (format only): + + Answer: + + + WWW + + XXX + + ZZZ + + YYY + """ + return prompt.replace(original_instruction, new_instruction, 1) + + @staticmethod + def _generate_writing_prompt(contents: list[str]) -> str: + content = "\n\n".join( + [ + f"START CONTENT {i + 1}\n\n{content}\n\nEND CONTENT" + for i, content in enumerate(contents) + ] + ) + prompt = f""" +I want you to act as a long dialogue completer. +Given a long dialogue(s), your objectives are: +1. Add one speaker mentioned in the past dialogue(s) at the end of the last sentence of each dialogue (between START CONTENT and END CONTENT) to complete the sentence and ensure its semantic integrity. At here, the added word must be a person's name which appears in the dialogue. +2. Continue the dialogue(s) with one or more speakers who appeared in the dialogue(s) before. Be coherent with the previous dialogue(s) and be creative in your response. +The content of the dialogue(s) is given below. + + +{content} +""" + return prompt + + @staticmethod + def _pad_or_truncate_prompt( + prompt: str, target_num_tokens: int, padding: str = "Answer now please.\n" + ) -> str: + encoder = tiktoken.get_encoding("o200k_base") + + tokens = encoder.encode(prompt, disallowed_special=()) + current_num_tokens = len(tokens) + + if current_num_tokens > target_num_tokens: + # Truncate if too long + tokens = encoder.encode(prompt, disallowed_special=()) + return encoder.decode(tokens[:target_num_tokens]) + elif current_num_tokens < target_num_tokens: + # Add padding if too short + padding_tokens = encoder.encode(padding, disallowed_special=()) + tokens_needed = target_num_tokens - current_num_tokens + # Calculate how many full padding sequences we need + num_padding_repeats = (tokens_needed + len(padding_tokens) - 1) // len(padding_tokens) + padded_prompt = prompt + (padding * num_padding_repeats) + # Truncate to exact target length + padded_tokens = encoder.encode(padded_prompt, disallowed_special=()) + return encoder.decode(padded_tokens[:target_num_tokens]) + else: + return prompt + + @staticmethod + def _generate_bamboo_prompt(external_dataset: "Dataset", num_tokens: int) -> str: + prompt = SPEEDBench._generate_writing_prompt(external_dataset["content"]) + return SPEEDBench._pad_or_truncate_prompt(prompt, num_tokens) + + @staticmethod + def _generate_chatrag_bench_prompt(external_dataset: "Dataset") -> list[Any]: + prompt = "Please give a full and complete answer for the questions. \n\nContext:\n{context}\n\nQuestion:\n{question}" + context = "\n\n".join([ctx["text"] for ctx in external_dataset["ctxs"][0]]) + questions = [ + message["content"] + for message in external_dataset["messages"][0] + if message["role"] == "user" + ] + + return [prompt.format(context=context, question=questions[0])] + questions[1:] + + @staticmethod + def _generate_coser_prompt(external_dataset: "Dataset") -> str: + rng = np.random.default_rng(seed=12347) + prompt = """You are {character} from {book_name}. +==={character}'s Profile=== +{character_profile} +===Current Scenario=== +{scenario} +===Information about the other Characters=== +{other_character_profiles_str} +===Your Inner Thoughts=== +{motivation} + +===Requirements=== +Your output should include **thought**, **speech**, and **action**. 
Use [your thought] +for thoughts, which others can't see, e.g. [I'm terrified, but I must appear strong.]. Use +(your action) for actions, which others can see, such as (watches silently, trying to control +her fear and anger).""" + character = rng.choice(external_dataset["major_characters"][0]) + character_profile = external_dataset["character_profiles"][0][character] + scenario = external_dataset["scenario"][0] + book_name = external_dataset["book"][0] + motivation = next( + ( + key_character["motivation"] + for key_character in external_dataset["key_characters"][0] + if key_character["name"] == character + ), + "No motivation provided", + ) + if motivation == "No motivation provided": + print("warning: no motivation provided for character", character) + other_character_profiles_str = "\n\n".join( + [ + f"{character_name}: {character_profile}" + for character_name, character_profile in external_dataset["character_profiles"][ + 0 + ].items() + if character_name != character and character_profile is not None + ] + ) + return prompt.format( + character=character, + character_profile=character_profile, + book_name=book_name, + scenario=scenario, + other_character_profiles_str=other_character_profiles_str, + motivation=motivation, + ) + + @staticmethod + def _generate_mmlu_pro_prompt(external_dataset: "Dataset", subject: str) -> list[Any]: + def get_question_and_options(question, options): + options = [(chr(ord("A") + i), a) for i, a in enumerate(options)] + options_str = "\n".join([f"({letter}) {option}" for letter, option in options]) + return f"Question: {question}\n\nOptions: {options_str}\n\n" + + prompt = 'The following are multiple choice questions (with answers) about {subject}. Think step by step and then finish your answer with "the answer is (X)" where X is the correct letter choice.\n\n' + first_question = prompt.format(subject=subject) + get_question_and_options( + external_dataset["question"][0], external_dataset["options"][0] + ) + return [first_question] + [ + get_question_and_options(question, options) + for question, options in zip( + external_dataset["question"][1:], external_dataset["options"][1:] + ) + ] + + @staticmethod + def _generate_hle_prompt( + example: dict[str, Any], + hle_train: "pd.DataFrame", + num_tokens: int, + rng: "np.random.Generator", + ) -> str: + encoder = tiktoken.get_encoding("o200k_base") + prompt = ( + "Please answer the question below.\n\nHere are some examples of question and answer pairs in the category of " + + example["category"] + + ":\n\n" + ) + prompt_tokens = encoder.encode(prompt) + example_tokens = encoder.encode(example["question"]) + current_num_tokens = len(prompt_tokens) + len(example_tokens) + hle_train_category = hle_train[hle_train["category"] == example["category"]] + + while current_num_tokens < num_tokens: + hle_train_category_sample = hle_train_category.sample(1, random_state=rng) + prompt += hle_train_category_sample["demonstration"].iloc[0] + current_num_tokens += len(hle_train_category_sample["tokens"].iloc[0]) + prompt_tokens += list(hle_train_category_sample["tokens"].iloc[0]) + + return encoder.decode( + prompt_tokens[: num_tokens - len(example_tokens) + 1] + example_tokens + ) + + @staticmethod + def _get_num_tokens_from_config(speed_config: config_type | str) -> int: + match = re.search(r"throughput_(\d+)k", speed_config) + if match: + return int(match.group(1)) * 1000 + else: + raise ValueError(f"Could not determine num_tokens from speed_config: {speed_config}") + + def _fetch_all_turns_data( + self, example: 
dict[str, Any], speed_config: config_type | str + ) -> dict[str, Any]: + turns = example["turns"] + if not turns[0].startswith(TURNS_PLACEHOLDER): + return example + + if BenchmarkDataset.BAMBOO.value in example["source"]: + num_tokens = self._get_num_tokens_from_config(speed_config) + src_ids = [int(match) for match in re.findall(r"_(\d+)", example["src_id"])] + external_dataset = self._get_external_dataset( + BenchmarkDataset.BAMBOO.value, config_name=example["source"] + ) + external_dataset = external_dataset.select(src_ids) + example["turns"] = [self._generate_bamboo_prompt(external_dataset, num_tokens)] + + elif BenchmarkDataset.CNN_DAILYMAIL.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.CNN_DAILYMAIL.value, config_name="3.0.0" + ).to_pandas() + src_id = example["src_id"] + article = external_dataset[external_dataset["id"] == src_id]["article"].iloc[0] + example["turns"] = [ + example["turns"][0].removeprefix(f"{TURNS_PLACEHOLDER}\n\n").format(article=article) + ] + + elif BenchmarkDataset.HLE.value in example["source"]: + if "qualitative" in speed_config: + external_dataset = self._get_external_dataset( + BenchmarkDataset.HLE.value, config_name="test" + ).to_pandas() + src_id = example["src_id"] + example["turns"] = [ + external_dataset[external_dataset["id"] == src_id]["question"].iloc[0] + ] + elif "throughput" in speed_config: + num_tokens = self._get_num_tokens_from_config(speed_config) + hle_train, hle_test = self._get_external_dataset( + BenchmarkDataset.HLE.value, config_name="train_test_split" + ) + hle_train = hle_train.to_pandas() + hle_train = hle_train[hle_train["image"] == ""] + hle_train["demonstration"] = hle_train.apply( + lambda e: "Question: " + + e["question"] + + "\n\nAnswer: " + + e["rationale"] + + "\n\n", + axis=1, + ) + hle_train["tokens"] = hle_train["demonstration"].apply( + lambda e: tiktoken.get_encoding("o200k_base").encode(e, disallowed_special=()) + ) + src_id = example["src_id"] + hle_test = hle_test.to_pandas() + external_dataset_example = hle_test[hle_test["id"] == src_id].iloc[0] + self.hle_rng = getattr(self, "hle_rng", np.random.default_rng(42)) + example["turns"] = [ + self._generate_hle_prompt( + external_dataset_example, hle_train, num_tokens, self.hle_rng + ) + ] + else: + raise ValueError(f"Invalid speed_config: {speed_config}") + + elif BenchmarkDataset.LIVECODEBENCH.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.LIVECODEBENCH.value, config_name="test" + ).to_pandas() + src_id = example["src_id"] + external_dataset_example = external_dataset[ + external_dataset["question_id"] == src_id + ].iloc[0] + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format( + question=external_dataset_example["question_content"], + starter_code=external_dataset_example["starter_code"], + ) + ] + + elif BenchmarkDataset.CODE_CONTESTS.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.CODE_CONTESTS.value, config_name="test" + ).to_pandas() + src_id = example["src_id"] + external_dataset_example = external_dataset[external_dataset["name"] == src_id].iloc[0] + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format(question=external_dataset_example["description"]) + ] + + elif BenchmarkDataset.MTBENCH_101.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.MTBENCH_101.value, 
config_name=example["source"] + ) + src_id = example["src_id"].rsplit("_", 1)[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [entry["user"] for entry in external_dataset_example["history"][0]] + + elif BenchmarkDataset.OPUS100.value in example["source"]: + _, config_name, src_id = example["src_id"].split("_") + external_dataset = self._get_external_dataset( + BenchmarkDataset.OPUS100.value, config_name=config_name + ) + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format(question=external_dataset_example["translation"][0]) + ] + + elif BenchmarkDataset.CHATRAG_BENCH.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.CHATRAG_BENCH.value, config_name=["hybridial", "sqa"] + ) + src_id = example["src_id"].rsplit("_", 1)[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = self._generate_chatrag_bench_prompt(external_dataset_example) + + elif BenchmarkDataset.MMLU_PRO.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.MMLU_PRO.value, config_name="test" + ) + src_id = int(example["src_id"].split("(")[1].split(",")[0]) + external_dataset_example = external_dataset.select( + range(src_id, src_id + len(example["turns"])) + ) + example["turns"] = self._generate_mmlu_pro_prompt( + external_dataset_example, example["sub_category"] + ) + + elif BenchmarkDataset.ADALEVAL_STACKSELECT.value in example["source"]: + num_tokens = self._get_num_tokens_from_config(speed_config) + external_dataset = self._get_external_dataset( + BenchmarkDataset.ADALEVAL_STACKSELECT.value, + config_name=example["source"], + ).to_pandas() + src_id = example["src_id"] + external_dataset_example = external_dataset[ + external_dataset["question_id"] == src_id + ].iloc[0] + example["turns"] = [ + self._pad_or_truncate_prompt( + self._generate_stackselect_prompt( + question=external_dataset_example["question"], + answers=external_dataset_example["all_answers"], + answer=external_dataset_example["answer"], + num_tokens=num_tokens, + ), + num_tokens, + ) + ] + + elif BenchmarkDataset.ADALEVAL_TEXTSORT.value in example["source"]: + num_tokens = self._get_num_tokens_from_config(speed_config) + external_dataset = self._get_external_dataset( + BenchmarkDataset.ADALEVAL_TEXTSORT.value, config_name=example["source"] + ) + src_id = example["src_id"].split("_")[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [ + self._pad_or_truncate_prompt( + self._generate_textsort_prompt(external_dataset_example["prompt"][0]), + num_tokens, + ) + ] + + elif BenchmarkDataset.ROLEBENCH.value in example["source"]: + config_name = example["src_id"].split("_")[1] + external_dataset = self._get_external_dataset( + BenchmarkDataset.ROLEBENCH.value, + config_name=example["source"].replace("tree", "raw") + + f"/{config_name}/role_specific/test.jsonl", + ) + roles_dataset = self._get_external_dataset( + BenchmarkDataset.ROLEBENCH_ROLES.value, + config_name="https://huggingface.co/datasets/ZenMoore/RoleBench/raw/a57ed54f9613921e4a5f1b63601a558cd5acf971/profiles-eng/desc.json", + ) + src_ids = [int(match) for match in re.findall(r"_(\d+)", example["src_id"])][ + : len(example["turns"]) + ] + external_dataset_example = external_dataset.iloc[src_ids] + role_name = external_dataset_example["role"].iloc[0] + role_description_and_catchphrases = 
roles_dataset[role_name][0] + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format( + role_name=role_name, + role_description_and_catchphrases=role_description_and_catchphrases, + ) + + "\n" + + external_dataset_example["question"].iloc[0] + ] + [ + question.removeprefix(f"{role_name}, ").removeprefix(f" {role_name},") + for question in external_dataset_example["question"].iloc[1:] + ] + + elif BenchmarkDataset.COSER.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.COSER.value, config_name=example["source"] + ) + src_id = example["src_id"].split("_")[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [self._generate_coser_prompt(external_dataset_example)] + + return example + + def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Dataset": + """Load the raw HuggingFace dataset from a config name or local path. + + Args: + config_name_or_dataset_path: Either a SPEED-Bench config name + (e.g. ``"qualitative"``) or a path to a local parquet file / + directory. + category: If provided, filter the dataset to this category only. + + Returns: + The loaded (and optionally filtered / truncated) HuggingFace dataset. + """ + if config_name_or_dataset_path in get_args(config_type): + dataset = load_dataset("nvidia/SPEED-Bench", config_name_or_dataset_path, split="test") + else: + config_name_or_dataset_path_path = Path(config_name_or_dataset_path) + if not config_name_or_dataset_path_path.exists(): + msg = ", ".join(get_args(config_type)) + raise ValueError( + f"Dataset path {config_name_or_dataset_path_path} does not exist or not one of the supported configs {msg}" + ) + if config_name_or_dataset_path_path.is_dir(): + data_files = { + "test": [ + str(path) for path in config_name_or_dataset_path_path.rglob("*.parquet") + ] + } + else: + data_files = {"test": [str(config_name_or_dataset_path_path)]} + dataset = load_dataset("parquet", data_files=data_files, split="test") + if self.num_samples is not None: + dataset = dataset.select(range(self.num_samples)) + return dataset + + def _resolve_external_data( + self, dataset: "Dataset", speed_config: config_type | str + ) -> "Dataset": + """Resolve all external data references in the dataset. + + Applies ``_fetch_all_turns_data`` to every example so that turn + placeholders are replaced with fully-resolved prompt text. + + Args: + dataset: The HuggingFace dataset with potentially unresolved turns. + speed_config: The SPEED-Bench config name used to determine + token-length parameters for throughput configs. + + Returns: + The dataset with all turns fully resolved. + """ + return dataset.map(self._fetch_all_turns_data, fn_kwargs={"speed_config": speed_config}) + + def _preprocess( + self, + config_name_or_dataset_path: config_type | str, + *, + _prepare_mode: bool = False, + ): + dataset = self._load_dataset(config_name_or_dataset_path) + + if _prepare_mode: + # Resolve all external data references (only allowed during prepare) + dataset = self._resolve_external_data(dataset, config_name_or_dataset_path) + else: + # Validate that all turns are fully resolved (no placeholders remaining) + for example in dataset: + for turn in example["turns"]: + if turn.startswith(TURNS_PLACEHOLDER): + raise ValueError( + f"Unresolved data placeholder found in question_id={example['question_id']} " + f"(category={example['category']}). 
Please run " + f"`python prepare_data.py --config ` first to download " + f"and resolve all external data references." + ) + + self._resolved_dataset = dataset + self.data = [ + Request( + system_prompt=None, + turns=example["turns"], + category=example["category"], + question_id=example["question_id"], + ) + for example in dataset + ] + assert len(self.data) == len(dataset), ( # type: ignore[arg-type] + f"Number of requests {len(self.data)} does not match number of requests in the dataset {len(dataset)}" # type: ignore[arg-type] + ) + + @classmethod + def prepare_data( + cls, + output_dir: str | Path, + config_name: config_type = "qualitative", + ) -> Path: + """Download, resolve, and save the SPEED-Bench dataset as parquet. + + This is the **only** entry-point that fetches external data and + resolves turn placeholders. The resulting parquet file can then be + loaded directly by the normal ``SPEEDBench(config_name=)`` + constructor without any further network access. + + Args: + output_dir: Directory where the parquet file will be written. + config_name: SPEED-Bench configuration to prepare. + + Returns: + Path to the saved parquet file. + """ + instance = cls(config_name=config_name, _prepare_mode=True) + + # Persist to parquet + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / "test.parquet" + instance._resolved_dataset.to_parquet(output_path) + return output_path diff --git a/examples/specdec_bench/specdec_bench/metrics/__init__.py b/examples/specdec_bench/specdec_bench/metrics/__init__.py index b61616830..1f6ac79fc 100644 --- a/examples/specdec_bench/specdec_bench/metrics/__init__.py +++ b/examples/specdec_bench/specdec_bench/metrics/__init__.py @@ -15,6 +15,8 @@ from .aa_timing import AATiming from .acceptance_rate import AcceptanceRate -from .base import Metric from .mtbench import MTBench +from .specbench import SpecBench from .timing import Timing + +__all__ = ["AATiming", "AcceptanceRate", "MTBench", "SpecBench", "Timing"] diff --git a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py index 21af35112..cce735d5f 100644 --- a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py +++ b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py @@ -30,7 +30,7 @@ def __init__(self, base_tokenizer): raise ImportError( "Please install tiktoken to use the AATiming metric, or remove the metric from the run command" ) - self.enc = tiktoken.get_encoding("cl100k_base") + self.enc = tiktoken.get_encoding("o200k_base") self.base_tokenizer = base_tokenizer self.total_tokens = [] diff --git a/examples/specdec_bench/specdec_bench/metrics/specbench.py b/examples/specdec_bench/specdec_bench/metrics/specbench.py new file mode 100644 index 000000000..32ab3d1c7 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/metrics/specbench.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections import defaultdict +from itertools import chain +from pathlib import Path +from statistics import mean + +try: + import matplotlib.pyplot as plt + import pandas as pd + from rich.console import Console + from rich.table import Table + + not_installed = False +except ImportError: + not_installed = True + +from .acceptance_rate import AcceptanceRate + + +class SpecBench(AcceptanceRate): + def __init__(self, requests): + super().__init__() + if not_installed: + raise ImportError( + "Please install rich, matplotlib, seaborn, and pandas to use the SpecBench metric" + ) + self.requests = requests + + def process_final(self, text_outputs): + lengths = {} + self.out["Request_AR"] = {} + for request_id, request in enumerate(self.requests): + turns = self.prompt_ar[request_id].values() + assert len(turns) == len(request.turns), ( + f"Number of turns {len(turns)} does not match number of turns in request {len(request.turns)}" + ) + self.out["Request_AR"][request.question_id] = mean(list(chain(*turns))) + for turn in turns: + self._get_lengths(turn, lengths) + print(request.category, self.out["Request_AR"][request.question_id]) + per_category = defaultdict(list) + for request in self.requests: + per_category[request.category].append(self.out["Request_AR"][request.question_id]) + self.out["Category_AR"] = {} + for category_name, category_ar in per_category.items(): + if len(category_ar) > 0: + category_ar = mean(category_ar) + self.out["Category_AR"][category_name] = category_ar + average_ar = mean(self.out["Request_AR"].values()) + self.out["Average_AR"] = average_ar + self._process_lengths(lengths) + self.write() + self._format_write_output(text_outputs) + self._pretty_print_results() + self._dump_results() + self._create_visualizations(text_outputs) + + def _format_write_output(self, outputs): + with open(os.path.join(self.directory, "specbench_responses.jsonl"), "w") as outfile: + for i, messages in enumerate(outputs): + out_line = {} + out_line["question_id"] = self.requests[i].question_id + out_line["category"] = self.requests[i].category + q_turns = [c["content"] for c in messages if c["role"] == "user"] + a_turns = [c["content"] for c in messages if c["role"] == "assistant"] + out_line["turns"] = q_turns + out_line["choices"] = [{"index": 0, "turns": a_turns}] + json.dump(out_line, outfile) + outfile.write("\n") + + def _pretty_print_results(self): + # Create and display results table + console = Console() + table = Table( + title="Acceptance Rate Results", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Category", style="cyan", no_wrap=True) + table.add_column("Average AR", justify="right", style="green") + + # Add category rows + for category_name, category_ar in sorted(self.out["Category_AR"].items()): + table.add_row(category_name, f"{category_ar:.4f}") + + # Add separator and summary row + table.add_section() + table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AR']:.4f}[/bold]") + + console.print(table) + + def _dump_results(self): + with open(os.path.join(self.directory, "specbench_results.json"), "w") as outfile: + json.dump(self.out, outfile, indent=4) + + def _create_visualizations( + self, + text_outputs: list[list[dict[str, str]]], + title: str = "Speculative Decoding Acceptance Rate Analysis", + ): + """ + Create professional plots for acceptance rates. + Completely generated by Cursor. 
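+ Produces a three-panel figure (acceptance rate by category, by response-length bin, and the overall distribution) and saves it as acceptance_rate_analysis.png in the metric's output directory.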
+ """ + + # Set style + plt.style.use("seaborn-v0_8") + + df_clean = pd.DataFrame.from_dict( + { + "question_id": list(self.out["Request_AR"].keys()), + "acceptance_rate": list(self.out["Request_AR"].values()), + "category": [request.category for request in self.requests], + "response_length": [ + mean([len(c["content"]) for c in messages if c["role"] == "assistant"]) + for messages in text_outputs + ], + } + ) + + if len(df_clean) == 0: + print("Warning: No successful results to plot") + return + + # 1. Acceptance rate by category + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + fig.suptitle(title, fontsize=16, fontweight="bold") + + # Plot 1: Acceptance rate by category + ax1 = axes[0] + category_stats = ( + df_clean.groupby("category") + .agg({"acceptance_rate": ["mean", "std"], "question_id": "count"}) + .round(3) + ) + + categories = category_stats.index.tolist() + means = category_stats[("acceptance_rate", "mean")].values + stds = category_stats[("acceptance_rate", "std")].values + counts = category_stats[("question_id", "count")].values + + bars = ax1.bar(range(len(categories)), means, yerr=stds, capsize=5, alpha=0.8) + ax1.set_xlabel("Category") + ax1.set_ylabel("Acceptance Rate") + ax1.set_title("Acceptance Rate by Category") + ax1.set_xticks(range(len(categories))) + ax1.set_xticklabels(categories, rotation=45, ha="right") + + # Add count labels on bars + for i, (bar, count) in enumerate(zip(bars, counts)): + ax1.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.01, + f"n={count}", + ha="center", + va="bottom", + fontsize=8, + ) + + # Plot 2: Acceptance rate vs response length + ax2 = axes[1] + # Bin response lengths + df_clean["response_length_bin"] = pd.cut( + df_clean["response_length"], + bins=[0, 100, 300, 500, 1000, float("inf")], + labels=["0-100", "100-300", "300-500", "500-1000", "1000+"], + ) + + length_stats = ( + df_clean.groupby("response_length_bin") + .agg({"acceptance_rate": ["mean", "std"], "question_id": "count"}) + .round(3) + ) + + length_bins = length_stats.index.tolist() + length_means = length_stats[("acceptance_rate", "mean")].values + length_stds = length_stats[("acceptance_rate", "std")].values + length_counts = length_stats[("question_id", "count")].values + + bars2 = ax2.bar( + range(len(length_bins)), + length_means, + yerr=length_stds, + capsize=5, + alpha=0.8, + ) + ax2.set_xlabel("Response Length (characters)") + ax2.set_ylabel("Acceptance Rate") + ax2.set_title("Acceptance Rate by Response Length") + ax2.set_xticks(range(len(length_bins))) + ax2.set_xticklabels(length_bins) + + for i, (bar, count) in enumerate(zip(bars2, length_counts)): + ax2.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.01, + f"n={count}", + ha="center", + va="bottom", + fontsize=8, + ) + + # Plot 3: Distribution of acceptance rates + ax3 = axes[2] + ax3.hist(df_clean["acceptance_rate"], bins=20, alpha=0.7, edgecolor="black") + ax3.axvline( + df_clean["acceptance_rate"].mean(), + color="red", + linestyle="--", + label=f"Mean: {df_clean['acceptance_rate'].mean():.3f}", + ) + ax3.set_xlabel("Acceptance Rate") + ax3.set_ylabel("Frequency") + ax3.set_title("Distribution of Acceptance Rates") + ax3.legend() + + plt.tight_layout() + plot_path = Path(self.directory) / "acceptance_rate_analysis.png" + plt.savefig(plot_path, dpi=300, bbox_inches="tight") + plt.close() + print(f"Plots saved to {plot_path}") diff --git a/examples/specdec_bench/specdec_bench/metrics/timing.py b/examples/specdec_bench/specdec_bench/metrics/timing.py index 
023aaf785..5bf33c604 100644 --- a/examples/specdec_bench/specdec_bench/metrics/timing.py +++ b/examples/specdec_bench/specdec_bench/metrics/timing.py @@ -53,6 +53,7 @@ def process_final(self, text_outputs): if tpot_time: self.out["Request Generation Step Time"] = compute_statistics(tpot_time) self.out["Request Generation Tokens Per Second"] = compute_statistics(gen_tp_time) + self.out["Number of Output Tokens"] = compute_statistics(self.total_tokens) for k, v in self.out.items(): print(k, v) self.write() diff --git a/examples/specdec_bench/specdec_bench/models/__init__.py b/examples/specdec_bench/specdec_bench/models/__init__.py index 5fa1260ab..e103a9d92 100644 --- a/examples/specdec_bench/specdec_bench/models/__init__.py +++ b/examples/specdec_bench/specdec_bench/models/__init__.py @@ -13,7 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import Model +from .auto_deploy import AutoDeployModel from .sglang import SGLANGModel +from .specbench_medusa import SpecBenchMedusaModel from .trtllm_torch_api import TRTLLMPYTModel from .vllm import VLLMModel + +__all__ = [ + "AutoDeployModel", + "SGLANGModel", + "SpecBenchMedusaModel", + "TRTLLMPYTModel", + "VLLMModel", +] diff --git a/examples/specdec_bench/specdec_bench/models/auto_deploy.py b/examples/specdec_bench/specdec_bench/models/auto_deploy.py new file mode 100644 index 000000000..bd030e783 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/auto_deploy.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
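+# AutoDeployModel adapts the TensorRT-LLM AutoDeploy LLM API to the specdec_bench Model interface;
+# when speculative_algorithm="DRAFT_TARGET" is configured, the draft_model_dir is wrapped in a
+# DraftTargetDecodingConfig and handed to the AutoDeploy LLM constructor below.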
+ +import contextlib +import itertools +import time +from typing import Any + +try: + from tensorrt_llm._torch.auto_deploy.llm import LLM + from tensorrt_llm.llmapi import DraftTargetDecodingConfig + from tensorrt_llm.sampling_params import SamplingParams +except ImportError: + print("tensorrt_llm._torch.auto_deploy is not installed.") + LLM = None + +from .base import Model + + +class AutoDeployModel(Model): + def __init__(self, model_path, max_concurrent_requests, sampling_kwargs, **kwargs): + self.model = create_auto_deploy_model(model_path, max_concurrent_requests, kwargs) + self.sampling_kwargs = sampling_kwargs + + async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + output_dict = {} + sampling_config = check_sampling_config(self.sampling_kwargs, max_length, end_id) + outputs = [] + timing = [time.perf_counter()] + beam_lens = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + + async for output in self.model.generate_async( + prompt_ids, + streaming=not sampling_config.use_beam_search, + sampling_params=sampling_config, + ): + for beam in output.outputs: + beam_lens[beam.index].append(len(beam.token_ids)) + outputs.append(output.outputs) + timing.append(time.perf_counter()) + + reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + for beam_idx, beam_len in enumerate(beam_lens): + response = outputs[-1][beam_idx] + if beam_len[0] != 0: + reformatted_output_ids[beam_idx].append(response.token_ids[: beam_len[0]]) + for s, e in itertools.pairwise(beam_len): + reformatted_output_ids[beam_idx].append(response.token_ids[s:e]) + if len(response.token_ids) > beam_len[-1]: + reformatted_output_ids[beam_idx].append(response.token_ids[beam_len[-1] :]) + + output_dict["output_ids"] = reformatted_output_ids + output_dict["output_logits"] = None + output_dict["token_times"] = timing + return output_dict + + def stop(self): + """Stop and cleanup the model.""" + if hasattr(self, "model") and self.model is not None: + with contextlib.suppress(Exception): + del self.model + + +def create_auto_deploy_model(model_path: str, max_concurrent_requests: int, kwargs: dict[str, Any]): + world_size = kwargs.get("world_size", kwargs.get("tensor_parallel_size", 1)) + + max_seq_len = kwargs.get("max_seq_len", 8192) + + kv_cache_config = { + "enable_block_reuse": kwargs.get("prefix_cache", False), + "free_gpu_memory_fraction": kwargs.get("free_gpu_memory_fraction", 0.75), + } + + specdec = None + speculative_algorithm = kwargs.get("speculative_algorithm") + + if speculative_algorithm == "DRAFT_TARGET": + specdec = DraftTargetDecodingConfig( + max_draft_len=kwargs.get("speculative_num_steps", 3), + speculative_model_dir=kwargs.get("draft_model_dir"), + ) + elif speculative_algorithm == "NONE": + specdec = None + + max_num_tokens = kwargs.get("max_num_tokens", 8192) + + llm_kwargs = { + "model": model_path, + "world_size": world_size, + "max_batch_size": max_concurrent_requests, + "max_seq_len": max_seq_len, + "max_num_tokens": max_num_tokens, + "skip_tokenizer_init": kwargs.get("skip_tokenizer_init", True), + "kv_cache_config": kv_cache_config, + "runtime": "trtllm", + "disable_overlap_scheduler": kwargs.get("disable_overlap_scheduler", True), + "speculative_config": specdec, + } + + if kwargs.get("attn_backend"): + llm_kwargs["attn_backend"] = kwargs["attn_backend"] + + if kwargs.get("compile_backend"): + llm_kwargs["compile_backend"] = kwargs["compile_backend"] + + # Optimization mode: "graph" uses full torch.export, "transformers" is simpler 
+ # Default to "transformers" to avoid torch.export dimension specialization issues + llm_kwargs["mode"] = kwargs.get("mode", "transformers") + + if kwargs.get("cuda_graph_batch_sizes"): + llm_kwargs["cuda_graph_batch_sizes"] = kwargs["cuda_graph_batch_sizes"] + + model = LLM(**llm_kwargs) + return model + + +def check_sampling_config(sampling_config: dict[str, Any], max_length: int, end_id: int): + return SamplingParams( + use_beam_search=sampling_config.get("beam_width", 1) > 1, + n=sampling_config.get("beam_width", 1), + top_k=sampling_config.get("top_k"), + top_p=sampling_config.get("top_p"), + seed=sampling_config.get("seed"), + temperature=sampling_config.get("temperature", 1.0), + max_tokens=max_length, + end_id=end_id, + detokenize=False, + ) diff --git a/examples/specdec_bench/specdec_bench/models/base.py b/examples/specdec_bench/specdec_bench/models/base.py index 42186fef0..ab26a4704 100644 --- a/examples/specdec_bench/specdec_bench/models/base.py +++ b/examples/specdec_bench/specdec_bench/models/base.py @@ -18,7 +18,7 @@ class Model: def __init__(self, model_dir, tokenizer, max_draft_length): raise NotImplementedError - async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + async def run(self, prompt_ids, sampling_params, request_id, turn_id): """ prompt_ids is list of tokens output is list of list of tokens diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py index 4840a0eda..1e8c53446 100644 --- a/examples/specdec_bench/specdec_bench/models/sglang.py +++ b/examples/specdec_bench/specdec_bench/models/sglang.py @@ -27,7 +27,12 @@ class SGLANGModel(Model): def __init__( - self, model_dir, max_concurrent_requests, sampling_kwargs, use_draft_logits=False, **kwargs + self, + model_dir, + max_concurrent_requests, + sampling_kwargs, + use_draft_logits=False, + **kwargs, ): speculative_algorithm = kwargs.get("speculative_algorithm") if speculative_algorithm == "MTP": @@ -43,35 +48,44 @@ def __init__( self.model = sgl.Engine( model_path=model_dir, skip_tokenizer_init=True, - mem_fraction_static=0.7, - disable_overlap_schedule=kwargs.get("disable_overlap_schedule", True), + trust_remote_code=True, + mem_fraction_static=0.8, + disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False), tp_size=kwargs.get("tensor_parallel_size", 1), + ep_size=kwargs.get("moe_expert_parallel_size", 1), speculative_algorithm=speculative_algorithm, speculative_num_steps=kwargs.get("speculative_num_steps", 3), speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1), speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4), speculative_draft_model_path=kwargs.get("draft_model_dir"), torch_compile_max_bs=max_concurrent_requests, + max_running_requests=max_concurrent_requests, attention_backend=kwargs.get("attention_backend"), enable_torch_compile=kwargs.get("enable_torch_compile", False), cuda_graph_max_bs=max_concurrent_requests, + disable_cuda_graph=False, ) else: self.model = sgl.Engine( model_path=model_dir, skip_tokenizer_init=True, - mem_fraction_static=0.7, - disable_overlap_schedule=kwargs.get("disable_overlap_schedule", True), + trust_remote_code=True, + mem_fraction_static=0.8, + disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False), tp_size=kwargs.get("tensor_parallel_size", 1), + ep_size=kwargs.get("moe_expert_parallel_size", 1), torch_compile_max_bs=max_concurrent_requests, + max_running_requests=max_concurrent_requests, 
attention_backend=kwargs.get("attention_backend"), enable_torch_compile=kwargs.get("enable_torch_compile", False), cuda_graph_max_bs=max_concurrent_requests, + disable_cuda_graph=False, ) self.sampling_config = sampling_kwargs async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + """Synchronous version of run for use with asyncio.to_thread""" timing = [] output_dict = {} self.sampling_config["max_new_tokens"] = max_length @@ -79,7 +93,7 @@ async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): timing.append(time.perf_counter()) assert self.sampling_config.get("beam_width", 1) == 1 beam_lens = [[] for _ in range(self.sampling_config.get("beam_width", 1))] - outputs = [] + outputs = [None] result = await self.model.async_generate( sampling_params=self.sampling_config, input_ids=prompt_ids, stream=True ) diff --git a/examples/specdec_bench/specdec_bench/models/specbench_medusa.py b/examples/specdec_bench/specdec_bench/models/specbench_medusa.py new file mode 100644 index 000000000..8334a50b2 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/specbench_medusa.py @@ -0,0 +1,284 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 Heming Xia. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time + +import torch + +from .base import Model + +# Medusa dependencies from Spec-Bench +try: + import os + import sys + + spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench") + sys.path.insert(0, spec_bench_path) + from model.medusa.kv_cache import initialize_past_key_values + from model.medusa.medusa_choices import mc_sim_7b_63 + from model.medusa.medusa_model import MedusaModel + from model.medusa.utils import ( + evaluate_posterior, + generate_candidates, + generate_medusa_buffers, + initialize_medusa, + reset_medusa_mode, + tree_decoding, + update_inference_inputs, + ) +except ImportError as e: + print(f"Medusa dependencies not found: {e}") + MedusaModel = None + + +class SpecBenchMedusaModel(Model): + def __init__( + self, + model_dir, + max_concurrent_requests, + sampling_kwargs, + use_draft_logits=False, + **kwargs, + ): + if MedusaModel is None: + raise ImportError( + "Medusa dependencies not found. Please ensure Spec-Bench is available." 
+ ) + assert max_concurrent_requests == 1, "Only support batch size 1 for now!" + self.medusa_num_heads = kwargs.get("medusa_num_heads", 4) + self.draft_model_path = kwargs.get("draft_model_dir") + self.dtype = kwargs.get("dtype", "float16") + self.max_steps = kwargs.get("max_steps", 512) + + # Medusa decoding parameters + self.temperature = sampling_kwargs.get("temperature", 0.0) + self.posterior_threshold = kwargs.get("posterior_threshold", 0.09) + self.posterior_alpha = kwargs.get("posterior_alpha", 0.3) + self.medusa_choices = kwargs.get("medusa_choices", mc_sim_7b_63) + + # Convert dtype string to torch dtype + dtype_map = { + "float32": torch.float32, + "float64": torch.float64, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map.get(self.dtype, torch.float16) + + # Load the Medusa model + # Use single GPU to avoid device mismatch issues with device_map="auto" + self.device = torch.device(kwargs.get("device", "cuda:0")) + self.model = MedusaModel.from_pretrained( + self.draft_model_path, + model_dir, + medusa_num_heads=self.medusa_num_heads, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(self.device) + + self.sampling_kwargs = sampling_kwargs + + def _medusa_forward(self, input_ids, max_new_tokens, end_id): + """ + Run Medusa speculative decoding forward pass. + + Returns: + tuple: (output_ids, new_token_count, num_steps, accept_length_list, timing) + """ + # Avoid modifying the input_ids in-place + accept_length_list = [] + input_ids = input_ids.clone() + timing = [time.perf_counter()] + + # Cache medusa buffers (the fixed patterns for tree attention) + if ( + hasattr(self.model, "medusa_choices") + and self.model.medusa_choices == self.medusa_choices + ): + medusa_buffers = self.model.medusa_buffers + else: + medusa_buffers = generate_medusa_buffers(self.medusa_choices, device=self.device) + self.model.medusa_buffers = medusa_buffers + self.model.medusa_choices = self.medusa_choices + + # Initialize the past key and value states + if hasattr(self.model, "past_key_values"): + past_key_values = self.model.past_key_values + past_key_values_data = self.model.past_key_values_data + current_length_data = self.model.current_length_data + current_length_data.zero_() + else: + ( + past_key_values, + past_key_values_data, + current_length_data, + ) = initialize_past_key_values(self.model.base_model) + self.model.past_key_values = past_key_values + self.model.past_key_values_data = past_key_values_data + self.model.current_length_data = current_length_data + + input_len = input_ids.shape[1] + cur_length = input_len + reset_medusa_mode(self.model) + medusa_logits, logits = initialize_medusa( + input_ids, self.model, medusa_buffers["medusa_attn_mask"], past_key_values + ) + new_token = 0 + + for idx in range(self.max_steps): + candidates, tree_candidates = generate_candidates( + medusa_logits, + logits, + medusa_buffers["tree_indices"], + medusa_buffers["retrieve_indices"], + ) + medusa_logits, logits, outputs = tree_decoding( + self.model, + tree_candidates, + past_key_values, + medusa_buffers["medusa_position_ids"], + input_ids, + medusa_buffers["retrieve_indices"], + ) + best_candidate, accept_length = evaluate_posterior( + logits, + candidates, + self.temperature, + self.posterior_threshold, + self.posterior_alpha, + ) + input_ids, logits, medusa_logits, new_token = update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + medusa_buffers["retrieve_indices"], + outputs, + logits, + 
medusa_logits, + new_token, + past_key_values_data, + current_length_data, + ) + accept_length_tree = input_ids.shape[1] - cur_length + cur_length = accept_length_tree + cur_length + accept_length_list.append(accept_length_tree) + timing.append(time.perf_counter()) + + if end_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + + return input_ids, new_token, idx + 1, accept_length_list, timing + + async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + """ + Run inference on the given prompt. + + Args: + prompt_ids: List of input token IDs + max_length: Maximum number of new tokens to generate + end_id: End of sequence token ID + request_id: Request identifier + turn_id: Turn identifier + + Returns: + dict with output_ids, output_logits, and token_times + """ + output_dict = {} + + # Convert prompt_ids to tensor + input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=self.device) + + # Run medusa forward pass (synchronously, but wrapped for async interface) + ( + input_ids_out, + new_token, + num_steps, + accept_length_list, + timing, + ) = await asyncio.to_thread(self._medusa_forward, input_ids, max_length, end_id) + + # Extract generated tokens (excluding the prompt) + original_len = len(prompt_ids) + generated_tokens = input_ids_out[0, original_len:].tolist() + + # Remove EOS token from output if present + if end_id in generated_tokens: + eos_idx = generated_tokens.index(end_id) + generated_tokens = generated_tokens[:eos_idx] + # Also adjust accept_length_list and timing + # Count how many tokens we're removing + tokens_to_remove = len(input_ids_out[0, original_len:].tolist()) - len(generated_tokens) + if tokens_to_remove > 0 and len(accept_length_list) > 0: + # Adjust the last accept length + accept_length_list[-1] = max(0, accept_length_list[-1] - tokens_to_remove) + if accept_length_list[-1] == 0: + accept_length_list.pop() + if len(timing) > 1: + timing.pop() + + # Format output_ids as list of list of tokens per step (for beam_width=1) + reformatted_output_ids = [[]] + start = 0 + for accept_len in accept_length_list: + if accept_len > 0: + reformatted_output_ids[0].append(generated_tokens[start : start + accept_len]) + start += accept_len + + # Handle any remaining tokens + if start < len(generated_tokens): + reformatted_output_ids[0].append(generated_tokens[start:]) + + output_dict["output_ids"] = reformatted_output_ids + output_dict["output_logits"] = None + output_dict["token_times"] = timing + + return output_dict + + def stop(self): + """Cleanup resources.""" + # Clear cached KV states to free memory + if hasattr(self.model, "past_key_values"): + del self.model.past_key_values + del self.model.past_key_values_data + del self.model.current_length_data + + # Clear medusa buffers + if hasattr(self.model, "medusa_buffers"): + del self.model.medusa_buffers + + # Move model to CPU or delete to free GPU memory + if hasattr(self, "model") and self.model is not None: + del self.model + torch.cuda.empty_cache() diff --git a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py index 11ceeb207..25a2aed63 100644 --- a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py +++ b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py @@ -38,7 +38,12 @@ class TRTLLMPYTModel(Model): def __init__( - self, model_path, max_concurrent_requests, sampling_kwargs, use_draft_logits=False, **kwargs + self, + model_path, + 
max_concurrent_requests, + sampling_kwargs, + use_draft_logits=False, + **kwargs, ): self.model = create_executor(model_path, max_concurrent_requests, kwargs) self.sampling_kwargs = sampling_kwargs @@ -80,16 +85,23 @@ def create_executor(model_path: str, max_concurrent_requests, kwargs): max_draft_len=kwargs.get("speculative_num_steps", 3), speculative_model_dir=kwargs.get("draft_model_dir", None), ) - disable_overlap_schedule = True elif kwargs.get("speculative_algorithm", None) == "EAGLE3": + extra_params = {} + if "allow_advanced_sampling" in EagleDecodingConfig.model_fields: + extra_params["allow_advanced_sampling"] = kwargs.get("allow_advanced_sampling", False) + elif "allow_advanced_sampling" in kwargs: + print( + f"WARNING: allow_advanced_sampling unsupported in tensorrt_llm version: {trtllm.__version__}" + ) specdec = EagleDecodingConfig( max_draft_len=kwargs.get("speculative_num_steps", 3), speculative_model_dir=kwargs.get("draft_model_dir", None), eagle3_one_model=kwargs.get("use_one_model", True), eagle3_layers_to_capture=kwargs.get("eagle3_layers_to_capture", None), + num_eagle_layers=kwargs.get("num_eagle_layers", 1), + **extra_params, ) - disable_overlap_schedule = not kwargs.get("use_one_model", True) elif kwargs.get("speculative_algorithm", None) == "MTP": specdec = MTPDecodingConfig( @@ -127,13 +139,15 @@ def create_executor(model_path: str, max_concurrent_requests, kwargs): moe_expert_parallel_size=kwargs.get("moe_expert_parallel_size", 2), disable_overlap_scheduler=disable_overlap_schedule, cuda_graph_config=cuda_graph_config, - enable_chunked_prefill=kwargs.get("enable_chunked_prefill", False), + enable_chunked_prefill=kwargs.get("enable_chunked_prefill", True), kv_cache_config=kv_cache_config, speculative_config=specdec, enable_attention_dp=kwargs.get("enable_attention_dp", False), max_batch_size=max_concurrent_requests, moe_config=MoeConfig(backend=kwargs.get("moe_backend", "TRTLLM")), sampler_type="TorchSampler", + max_seq_len=kwargs.get("max_seq_len", None), + max_num_tokens=kwargs.get("max_num_tokens", 8192), ) return model diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py index deb79ed89..81ad3cd24 100644 --- a/examples/specdec_bench/specdec_bench/models/vllm.py +++ b/examples/specdec_bench/specdec_bench/models/vllm.py @@ -51,10 +51,13 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs } elif kwargs.get("speculative_algorithm") == "DRAFT_TARGET": specdec = { - "method": "draft_target", + "method": "draft_model", "model": kwargs.get("draft_model_dir"), "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), } + if kwargs.get("parallel_draft_block_sizes") is not None: + specdec["disable_padded_drafter_batch"] = True + specdec["parallel_draft_block_sizes"] = kwargs.get("parallel_draft_block_sizes") elif kwargs.get("speculative_algorithm") == "MTP": specdec = { "method": "mtp", @@ -62,6 +65,11 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs } elif kwargs.get("speculative_algorithm") == "NONE": specdec = None + + if specdec is None: + num_speculative_tokens = 1 + else: + num_speculative_tokens = specdec.get("num_speculative_tokens", 3) engine_args = AsyncEngineArgs( model=model_dir, trust_remote_code=True, @@ -69,8 +77,10 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1, enable_prefix_caching=kwargs.get("prefix_cache", 
False), speculative_config=specdec, - max_num_seqs=max_concurrent_requests, + max_num_seqs=max_concurrent_requests * num_speculative_tokens, skip_tokenizer_init=False, + async_scheduling=kwargs.get("async_scheduling", True), + enforce_eager=False, ) self.model = AsyncLLM.from_engine_args(engine_args) self.sampling_kwargs = sampling_kwargs @@ -88,6 +98,8 @@ async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): output_dict = {} self.sampling_config.max_tokens = max_length self.sampling_config.stop_token_ids = [end_id] + if end_id == -1: + self.sampling_config.ignore_eos = True outputs, timing, full_tokens = await self.generate(prompt_ids, request_id, turn_id) diff --git a/examples/specdec_bench/specdec_bench/runners/__init__.py b/examples/specdec_bench/specdec_bench/runners/__init__.py index 61a85c769..17832bb99 100644 --- a/examples/specdec_bench/specdec_bench/runners/__init__.py +++ b/examples/specdec_bench/specdec_bench/runners/__init__.py @@ -13,5 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import BaseRunner from .simple import SimpleRunner + +__all__ = ["SimpleRunner"] diff --git a/examples/specdec_bench/specdec_bench/runners/base.py b/examples/specdec_bench/specdec_bench/runners/base.py index c481a0fd0..ee9062e39 100644 --- a/examples/specdec_bench/specdec_bench/runners/base.py +++ b/examples/specdec_bench/specdec_bench/runners/base.py @@ -21,7 +21,7 @@ def __init__(self, model, metrics): self.metrics = metrics self.prompt_ar = [] - async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + async def run(self, prompt_ids, max_length, end_id, sampling_kwargs): raise NotImplementedError() def process_metrics_final(self, text_outputs): diff --git a/examples/specdec_bench/specdec_bench/utils.py b/examples/specdec_bench/specdec_bench/utils.py index d605f0b4b..e0e8b285a 100644 --- a/examples/specdec_bench/specdec_bench/utils.py +++ b/examples/specdec_bench/specdec_bench/utils.py @@ -19,12 +19,16 @@ def get_tokenizer(path): - return AutoTokenizer.from_pretrained(path) + return AutoTokenizer.from_pretrained(path, trust_remote_code=True) -def encode_chat(tokenizer, messages): +def encode_chat(tokenizer, messages, chat_template_args={}, completions=False): + if completions: + return tokenizer.encode(messages[-1]["content"], add_special_tokens=False) return tokenizer.encode( - tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True), + tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, **chat_template_args + ), add_special_tokens=False, ) @@ -46,4 +50,11 @@ def postprocess_base(text): def postprocess_gptoss(text): - return text.split("<|channel|>final<|message|>")[-1] + final_message = text.split("<|channel|>final<|message|>")[-1] + if "<|end|>" in final_message: + final_message = final_message.split("<|end|>")[0] + if "<|return|>" in final_message: + final_message = final_message.split("<|return|>")[0] + if "<|channel|>" in final_message: + final_message = final_message.split("<|channel|>")[0] + return final_message
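
A minimal usage sketch of the revised `postprocess_gptoss` helper, assuming the `specdec_bench` package is importable and using a made-up Harmony-style completion string:

```python
# Illustrative only: the sample text and import path are assumptions, not taken from the benchmark.
from specdec_bench.utils import postprocess_gptoss

raw = (
    "<|channel|>analysis<|message|>scratch reasoning...<|end|>"
    "<|channel|>final<|message|>The capital of France is Paris.<|return|>"
)
# Everything before the final channel is dropped and the trailing special token is stripped.
assert postprocess_gptoss(raw) == "The capital of France is Paris."
```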