diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index c495809bb..ccecd9cb0 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -4,7 +4,7 @@
 Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed.
 
-This folder contains an end-to-end runnable speculative decoding fine‑tuning pipeline in which Llama‑3.2‑1B (Hugging Face) is trained on the Daring‑Anteater dataset.
+This folder contains an end-to-end runnable speculative decoding fine‑tuning pipeline in which Llama‑3.2‑1B (Hugging Face) is trained on the [UltraChat-200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset.
 
 This example focuses on training with Hugging Face. To train with Megatron‑LM, see the [Megatron‑LM example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).
 
@@ -45,14 +45,16 @@ pip install -r requirements.txt
 
 ### Data Preparation
 
-We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Prepare data by:
+We support a range of input datasets. In this example, we will use the [UltraChat-200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset.
 
 ```bash
-python prepare_input_conversations/add_daring_anteater.py
+python prepare_input_conversations/make_dataset.py -f prepare_input_conversations/example_data_config.yaml --full-conversations
 ```
 
 See the [other-datasets](#other-datasets) section for other dataset options and instructions for user-provided data.
 
+Omit `--full-conversations` if you plan to run synthetic data generation (see [data-synthesis](#data-synthesis)).
+
 ## Getting Started: Simplified Workflow
 
 ```bash
@@ -62,7 +64,7 @@ bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --
 ```
 
 This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in ModelOpt. Specifically, it
 
 - Initializes the draft model with [default settings](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18)
-- Fine-tunes the model on the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset
+- Fine-tunes the draft model on the prepared dataset
 - Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts)
 - Exports a checkpoint ready for deployment
 
@@ -73,7 +75,7 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ```bash
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
-    --data input_conversations/daring-anteater.jsonl \
+    --data input_conversations/train.jsonl \
     --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
@@ -93,7 +95,7 @@ We support two backends for generating base model hidden states.
 For better efficiency, we recommend using the TRT-LLM backend:
 
 ```bash
 python collect_hidden_states/compute_hidden_states_trtllm.py \
     --model $BASE_MODEL \
-    --input-file input_conversations/daring-anteater.jsonl \
+    --input-file input_conversations/train.jsonl \
     --output-dir $HIDDEN_STATES_DIR
 ```
 
@@ -104,7 +106,7 @@ Alternatively, you can generate the same hidden states with HF:
 
 ```bash
 python collect_hidden_states/compute_hidden_states_hf.py \
     --model $BASE_MODEL \
-    --input-file input_conversations/daring-anteater.jsonl \
+    --input-file input_conversations/train.jsonl \
     --output-dir $HIDDEN_STATES_DIR
 ```
 
@@ -201,16 +203,14 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE
 
 ### Other Datasets
 
-In addition to `daring-anteater`, we provide scripts for adding several other commonly used datasets in `prepare_input_conversations`:
+In addition to the default dataset, `prepare_input_conversations/make_dataset.py` supports several other commonly used datasets:
 
-```text
-prepare_input_conversations/
-  ├── add_daring_anteater.py
-  ├── add_mtbench.py
-  ├── add_sharegpt.py
-  ├── add_ultrachat.py
-  └── example_make_prompt_dataset.sh
-```
+- MTBench (for debugging)
+- ShareGPT
+- UltraChat
+- Daring-Anteater
+- Magpie (full 1M, plus the 500k and 300k filtered variants)
+- Nemotron Post-Training Dataset V2
 
 To use your own datasets, please preprocess your data into a `.jsonl` file with each line in the format:
 
@@ -234,10 +234,10 @@
 vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000
 ```
 
 Note: Add the `--quantization=modelopt` flag for quantized models.
 
-Then, we generate conversations with the base model using prompts from Daring-Anteater:
+Then, we generate conversations with the base model using the prepared prompts:
 
 ```bash
-python scripts/server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl
+python scripts/server_generate.py --data_path input_conversations/train.jsonl --output_path synthetic/train.jsonl
 ```
 
 To add a system prompt, use the `--system_prompt ` argument.
 
@@ -249,7 +249,7 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d
 
 We can optionally use a smaller vocab size for the draft model for faster training and inference. E.g., Llama-3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequently used tokens in our training set:
 
 ```bash
-python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
+python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
 ```
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
diff --git a/examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py b/examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
deleted file mode 100644
index c78739edb..000000000
--- a/examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Add Daring-Anteater conversations to a conversation dataset.""" - -import argparse -from pathlib import Path - -from datasets import load_dataset -from tqdm import tqdm -from utils import ( - dataset_splits_explanation, - id_for_conversation, - update_dataset_file_with_conversations, -) - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Load Daring-Anteater conversations.") - - parser.add_argument( - "--output-split-name", - type=str, - default="daring-anteater", - help=dataset_splits_explanation("daring-anteater"), - ) - - parser.add_argument( - "--output-dir", - type=Path, - default=Path("input_conversations/"), - help="Path to save the conversations file(s) into. Default is 'input_conversations/'.", - ) - - return parser.parse_args() - - -async def main(args: argparse.Namespace) -> None: - ds = load_dataset("nvidia/Daring-Anteater", split="train", streaming=False) - input_conversations = [] - for i in tqdm( - range(len(ds)), - desc="Loading Daring-Anteater dataset", - total=len(ds), - ): - conversations = ds[i]["conversations"] - if conversations and isinstance(conversations, list): - prompt_id = f"daring-anteater-{i:05}_" + id_for_conversation(conversations) - processed_conversations = [] - for msg in conversations: - if "from" in msg: - role = msg["from"].lower() - elif "role" in msg: - role = msg["role"].lower() - else: - continue - if role == "human": - role = "user" - elif role == "gpt": - role = "assistant" - - if "value" in msg: - content = msg["value"] - elif "text" in msg: - content = msg["text"] - elif "content" in msg: - content = msg["content"] - else: - continue - content = content.strip() - if content: - processed_conversations.append({"role": role, "content": content}) - - input_conversations.append( - {"conversation_id": prompt_id, "conversations": processed_conversations} - ) - - print(f"Loaded {len(input_conversations)} prompts from Daring-Anteater.") - - update_dataset_file_with_conversations( - input_conversations, args.output_dir, args.output_split_name - ) - - -if __name__ == "__main__": - import asyncio - - args = parse_args() - asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/add_mtbench.py b/examples/speculative_decoding/prepare_input_conversations/add_mtbench.py deleted file mode 100644 index 76f090cd0..000000000 --- a/examples/speculative_decoding/prepare_input_conversations/add_mtbench.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Add MTBench conversations to a conversation dataset.""" - -import argparse -import json -from pathlib import Path - -from tqdm import tqdm -from utils import ( - dataset_splits_explanation, - download_file, - id_for_conversation, - update_dataset_file_with_conversations, -) - -MTBENCH_QUESTIONS_URL = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Load MTBench conversations.") - - parser.add_argument( - "--mtbench-questions-file", - type=Path, - required=False, - help="""Path to the MTBench questions.jsonl file. - If not provided, it will be downloaded and saved to ~/.cache/""", - ) - - parser.add_argument( - "--output-split-name", - type=str, - default="mtbench", - help=dataset_splits_explanation("mtbench"), - ) - - parser.add_argument( - "--output-dir", - type=Path, - default=Path("input_conversations/"), - help="Path to save the conversations file(s) into. Default is 'input_conversations/'.", - ) - - return parser.parse_args() - - -async def main(args: argparse.Namespace) -> None: - # Download the MTBench questions file if not provided - if not args.mtbench_questions_file: - args.mtbench_questions_file = ( - Path("~/.cache/mtbench_questions.jsonl").expanduser().resolve() - ) - if not args.mtbench_questions_file.exists(): - print("Downloading MTBench questions dataset...") - await download_file(MTBENCH_QUESTIONS_URL, args.mtbench_questions_file) - else: - print(f"Using existing MTBench questions file {args.mtbench_questions_file}") - - # Error if we failed to download the file or if it was provided but does not exist - if not args.mtbench_questions_file.exists(): - err_msg = f"MTBench questions file {args.mtbench_questions_file} does not exist." 
- raise FileNotFoundError(err_msg) - - with args.mtbench_questions_file.open("r", encoding="utf-8") as f: - mtbench_raw = [json.loads(line) for line in f] - - input_conversations: list[dict] = [] - for entry in tqdm(mtbench_raw, desc="Loading MTBench", total=len(mtbench_raw)): - if not entry: - continue - prompt = entry.get("turns", [""])[0] - if not prompt: - continue - prompt_id = f"mtbench-{entry['question_id']:03}_" + id_for_conversation(prompt) - input_conversations.append( - {"conversation_id": prompt_id, "conversations": [{"role": "user", "content": prompt}]} - ) - - print(f"Loaded {len(input_conversations)} filtered conversations from MTBench.") - - update_dataset_file_with_conversations( - input_conversations, args.output_dir, args.output_split_name - ) - - -if __name__ == "__main__": - import asyncio - - args = parse_args() - asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py b/examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py deleted file mode 100644 index 5ea90cfe6..000000000 --- a/examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py +++ /dev/null @@ -1,131 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Add ShareGPT conversations to a conversation dataset.""" - -import argparse -import json -from pathlib import Path - -from tqdm import tqdm -from utils import ( - dataset_splits_explanation, - download_file, - id_for_conversation, - update_dataset_file_with_conversations, -) - -SHAREGPT_DATASET_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Load ShareGPT conversations.") - - parser.add_argument( - "--sharegpt-file", - type=Path, - required=False, - help="""Path to the ShareGPT JSON file containing conversations. - If not provided, it will be downloaded and saved to ~/.cache/""", - ) - - parser.add_argument( - "--output-split-name", - type=str, - default="sharegpt", - help=dataset_splits_explanation("sharegpt"), - ) - - parser.add_argument( - "--output-dir", - type=Path, - default=Path("input_conversations/"), - help="Path to save the conversations file(s) into. 
Default is 'input_conversations/'.", - ) - - return parser.parse_args() - - -def parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None: - """Parse a ShareGPT conversation into a list of messages.""" - msgs = [] - for turn in sharegpt_conv.get("conversations", []): - if turn.get("from") in ["human", "user"]: - role = "user" - elif turn.get("from") in ["gpt", "chatgpt", "bard"]: - role = "assistant" - elif turn.get("from") == "system": - # ShareGPT system messages are metadata, skip them - continue - elif turn.get("from") == "bing": - # Bing conversations are skipped for training, omit it - return None - else: - err_msg = f"Unknown role in conversation: {turn.get('from')}" - raise ValueError(err_msg) - - value = turn.get("value", "").strip() - if value: - msgs.append({"role": role, "content": value}) - - return msgs - - -async def main(args: argparse.Namespace) -> None: - # Download the ShareGPT dataset if not provided - if not args.sharegpt_file: - args.sharegpt_file = Path("~/.cache/sharegpt.json").expanduser().resolve() - if not args.sharegpt_file.exists(): - print("Downloading ShareGPT dataset...") - await download_file(SHAREGPT_DATASET_URL, args.sharegpt_file) - else: - print(f"Using existing ShareGPT file at {args.sharegpt_file}") - - # Error if we failed to download the file or if it was provided but does not exist - if not args.sharegpt_file.exists(): - err_msg = f"ShareGPT file {args.sharegpt_file} does not exist." - raise FileNotFoundError(err_msg) - - with args.sharegpt_file.open("r", encoding="utf-8") as f: - sharegpt_raw = json.load(f) - - input_conversations: list[dict] = [] - for source_conv in tqdm(sharegpt_raw, desc="Loading ShareGPT", total=len(sharegpt_raw)): - msgs = parse_sharegpt_conversation(source_conv) - if not msgs: - continue - cid = source_conv.get("id") - conv_id = id_for_conversation(msgs) - if cid: - cid = f"{cid}_{conv_id}" - else: - cid = conv_id - cid = f"sharegpt-{cid}" - - input_conversations.append({"conversation_id": cid, "conversations": msgs}) - - print(f"Loaded {len(input_conversations)} filtered conversations from ShareGPT.") - - update_dataset_file_with_conversations( - input_conversations, args.output_dir, args.output_split_name - ) - - -if __name__ == "__main__": - import asyncio - - args = parse_args() - asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py b/examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py deleted file mode 100644 index 2c5f5c748..000000000 --- a/examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Add UltraChat conversations to a conversation dataset.""" - -import argparse -from pathlib import Path - -from datasets import load_dataset -from tqdm import tqdm -from utils import ( - dataset_splits_explanation, - id_for_conversation, - update_dataset_file_with_conversations, -) - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Load UltraChat conversations.") - - parser.add_argument( - "--ultrachat-split", - type=str, - default="train_sft", - help="Split of the HuggingFace UltraChat dataset to load. Default is 'train_sft'.", - ) - - parser.add_argument( - "--output-split-name", - type=str, - default="ultrachat", - help=dataset_splits_explanation("ultrachat"), - ) - - parser.add_argument( - "--output-dir", - type=Path, - default=Path("input_conversations/"), - help="Path to save the conversations file(s) into. Default is 'input_conversations/'.", - ) - - return parser.parse_args() - - -async def main(args: argparse.Namespace) -> None: - ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=args.ultrachat_split, streaming=False) - input_conversations = [] - for i in tqdm( - range(len(ds)), - desc=f"Loading UltraChat split {args.ultrachat_split}", - total=len(ds), - ): - prompt = ds[i]["prompt"].strip() - prompt_id = ds[i]["prompt_id"].strip() - if prompt and prompt_id: - msgs = [{"role": "user", "content": prompt}] - prompt_id = ( - f"ultrachat-{args.ultrachat_split}_{i:06}-{prompt_id}_" + id_for_conversation(msgs) - ) - input_conversations.append({"conversation_id": prompt_id, "conversations": msgs}) - - print(f"Loaded {len(input_conversations)} prompts from UltraChat.") - - update_dataset_file_with_conversations( - input_conversations, args.output_dir, args.output_split_name - ) - - -if __name__ == "__main__": - import asyncio - - args = parse_args() - asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/example_data_config.yaml b/examples/speculative_decoding/prepare_input_conversations/example_data_config.yaml new file mode 100644 index 000000000..a77c91f73 --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/example_data_config.yaml @@ -0,0 +1,33 @@ +outputs: + - filename: "input_conversations/train.jsonl" + global_limit: 50000 + sources: + - name: "sharegpt" + splits: + all: 0 + - name: "ultrachat" + splits: + train_gen: 25000 + train_sft: 25000 + - name: "mtbench" + splits: + all: 0 + - name: "daring-anteater" + splits: + all: 0 + - name: "magpie" + splits: + 300k: 0 + 500k: 0 + 1M: 0 + - name: "nemotron-post-training-v2" + splits: + chat: 0 + stem: 0 + math: 0 + code: 0 + multilingual_ja: 0 + multilingual_it: 0 + multilingual_de: 0 + multilingual_es: 0 + multilingual_fr: 0 diff --git a/examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh b/examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh deleted file mode 100644 index fa1319b8e..000000000 --- a/examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Example script to prepare a dataset of prompts for generation -# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset. - -python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train -# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test -# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test -python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test diff --git a/examples/speculative_decoding/prepare_input_conversations/make_dataset.py b/examples/speculative_decoding/prepare_input_conversations/make_dataset.py new file mode 100644 index 000000000..716634902 --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/make_dataset.py @@ -0,0 +1,532 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Using a YAML file as an outline, initialize one or more conversation dataset files, +each as a JSONL file containing a list of conversations sampled from multiple source datasets. + +Each source dataset is specified in the YAML file with its name, and which splits +from that dataset to include in the output conversation dataset files, as well as +bounds on the number of conversations to include from each split. + +A global limit can also be placed on the total number of conversations in each output dataset file. 
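+
+Split constraints may be an absolute number of conversations (int), a fraction
+of the split (float), or the string "all"; a constraint of 0 skips that split.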
+
+The dataset choices available are:
+- "mtbench"
+- "sharegpt"
+- "ultrachat"
+- "daring-anteater"
+- "magpie"
+- "nemotron-post-training-v2"
+
+Here is an example YAML file:
+
+```
+outputs:
+  - filename: "mixed_conversation.jsonl"
+    global_limit: 5000 # downsample to 5000 total samples
+    sources:
+      - name: "mtbench"
+        splits: ["all"]
+      - name: "ultrachat"
+        splits:
+          train_gen: 0.5 # 50% of examples from train_gen split
+          train_sft: 100 # 100 examples from train_sft split
+          test_gen: "all" # all examples from test_gen split
+```
+"""
+
+import argparse
+import asyncio
+import json
+import logging
+import random
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import yaml
+from datasets import load_dataset
+from utils import download_file, id_for_conversation
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S"
+)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SourceDatasetSpec:
+    """Defines which dataset to read and how to sample from its splits."""
+
+    name: str
+    # Accepts list (implied "all") or dict (specific counts/percentages)
+    splits: list[str] | dict[str, int | float | str]
+
+    def __post_init__(self):
+        # Normalize list format ["train"] -> {"train": "all"}
+        if isinstance(self.splits, list):
+            self.splits = dict.fromkeys(self.splits, "all")
+
+
+@dataclass
+class OutputConfig:
+    """Defines a single target JSONL file to generate."""
+
+    filename: str
+    sources: list[SourceDatasetSpec]
+    global_limit: int | None = None
+
+    def __post_init__(self):
+        # Convert raw source dictionaries into strongly typed objects
+        self.sources = [SourceDatasetSpec(**s) if isinstance(s, dict) else s for s in self.sources]
+
+
+@dataclass
+class DataMixingConfig:
+    """The top-level configuration containing all output jobs."""
+
+    outputs: list[OutputConfig]
+
+    @classmethod
+    def load(cls, path: str):
+        path = Path(path)
+        with path.open("r", encoding="utf-8") as f:
+            data = yaml.safe_load(f)
+        # Instantiate objects from the parsed YAML dict
+        return cls(outputs=[OutputConfig(**o) for o in data.get("outputs", [])])
+
+
+def check_row_constraint(constraint) -> int | float | None:
+    """Validate a split constraint; None means "all", 0 means skip the split."""
+    if constraint == "all":
+        return None
+    if isinstance(constraint, (int, float)):
+        if constraint < 0:
+            raise ValueError("Number of samples to use for a split cannot be negative.")
+        return constraint
+    # Unrecognized constraint values select no samples
+    return 0
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Prepare conversation datasets based on a YAML configuration."
+    )
+
+    parser.add_argument(
+        "--config-file",
+        "-c",
+        "-f",
+        type=Path,
+        required=True,
+        help="Path to the YAML configuration file specifying dataset construction.",
+    )
+    parser.add_argument(
+        "--full-conversations",
+        "--full",
+        action="store_true",
+        help="If set, include full conversations including assistant completions. 
" + "By default, the last assistant completion is stripped to use the conversation as a prompt.", + ) + + return parser.parse_args() + + +def max_samples_for_constraint(total_size: int, row_constraint: int | float | None) -> int: + """Get the maximum number of samples to draw from a dataset split based on the constraint.""" + if row_constraint is None: + # "all" + return total_size + elif isinstance(row_constraint, float): + # Percentage + return int(total_size * row_constraint) + else: + # Absolute number + return min(row_constraint, total_size) + + +MTBENCH_QUESTIONS_URL = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" + + +async def _load_mtbench_conversations( + split_name: str, +) -> AsyncGenerator[int | dict[str, Any], None]: + if split_name != "all": + logger.warning("MTBench dataset has no splits; you should provide it as 'all'. Skipping.") + yield 0 + return + + # Download the MTBench questions file if not provided + mtbench_questions_file = ( + Path("~/.cache/modelopt/mtbench_questions.jsonl").expanduser().resolve() + ) + if not mtbench_questions_file.exists(): + logger.info("Downloading MTBench questions dataset...") + await download_file(MTBENCH_QUESTIONS_URL, mtbench_questions_file) + + # Error if we failed to download the file + if not mtbench_questions_file.exists(): + err_msg = f"MTBench questions file {mtbench_questions_file} does not exist." + raise FileNotFoundError(err_msg) + + with mtbench_questions_file.open("r", encoding="utf-8") as f: + mtbench_raw = [json.loads(line) for line in f] + + random.shuffle(mtbench_raw) + yield len(mtbench_raw) + + for entry in mtbench_raw: + if not entry: + continue + prompt = entry.get("turns", [""])[0] + if not prompt: + continue + prompt_id = f"mtbench-{entry['question_id']:03}-" + id_for_conversation(prompt) + yield {"conversation_id": prompt_id, "conversations": [{"role": "user", "content": prompt}]} + logger.info("Finished loading MTBench conversations.") + + +SHAREGPT_DATASET_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + +def _parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None: + """Parse a ShareGPT conversation into a list of messages.""" + msgs = [] + for turn in sharegpt_conv.get("conversations", []): + if turn.get("from") in ["human", "user"]: + role = "user" + elif turn.get("from") in ["gpt", "chatgpt", "bard"]: + role = "assistant" + elif turn.get("from") == "system": + # ShareGPT system messages are metadata, skip them + continue + elif turn.get("from") == "bing": + # Bing conversations are skipped for training, omit it + return None + else: + err_msg = f"Unknown role in conversation: {turn.get('from')}" + raise ValueError(err_msg) + + value = turn.get("value", "").strip() + if value: + msgs.append({"role": role, "content": value}) + + return msgs + + +async def _load_sharegpt_conversations( + split_name: str, +) -> AsyncGenerator[int | dict[str, Any], None]: + if split_name != "all": + logger.warning("ShareGPT dataset has no splits; you should provide it as 'all'. 
Skipping.") + yield 0 + return + + # Download the ShareGPT dataset if not provided + sharegpt_file = Path("~/.cache/modelopt/sharegpt.json").expanduser().resolve() + if not sharegpt_file.exists(): + logger.info("Downloading ShareGPT dataset...") + await download_file(SHAREGPT_DATASET_URL, sharegpt_file) + + # Error if we failed to download the file + if not sharegpt_file.exists(): + err_msg = f"ShareGPT file {sharegpt_file} does not exist." + raise FileNotFoundError(err_msg) + + with sharegpt_file.open("r", encoding="utf-8") as f: + sharegpt_raw = json.load(f) + + random.shuffle(sharegpt_raw) + yield len(sharegpt_raw) + + for source_conv in sharegpt_raw: + msgs = _parse_sharegpt_conversation(source_conv) + if not msgs: + continue + cid = source_conv.get("id") + conv_id = id_for_conversation(msgs) + if cid: + cid = f"{cid}-{conv_id}" + else: + cid = conv_id + cid = f"sharegpt-{cid}" + + yield {"conversation_id": cid, "conversations": msgs} + logger.info("Finished loading ShareGPT conversations.") + + +async def _load_ultrachat_conversations( + split_name: str, +) -> AsyncGenerator[int | dict[str, Any], None]: + ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split_name) + ds = ds.shuffle(seed=42) + yield len(ds) + for i in range(len(ds)): + prompt = ds[i]["prompt"].strip() + prompt_id = ds[i]["prompt_id"].strip() + if prompt: + msgs = [{"role": "user", "content": prompt}] + if not prompt_id: + prompt_id = id_for_conversation(msgs) + prompt_id = f"ultrachat-{split_name}-{prompt_id}" + yield {"conversation_id": prompt_id, "conversations": msgs} + logger.info(f"Finished loading UltraChat {split_name} conversations.") + + +def _parse_daring_anteater_conversation(daring_anteater_conv: list) -> list[dict] | None: + """Parse a DaringAnteater conversation into a list of messages.""" + msgs = [] + for turn in daring_anteater_conv: + if "from" in turn: + role = turn["from"].lower() + elif "role" in turn: + role = turn["role"].lower() + else: + continue + if role == "human": + role = "user" + elif role == "gpt": + role = "assistant" + + if "value" in turn: + content = turn["value"] + elif "text" in turn: + content = turn["text"] + elif "content" in turn: + content = turn["content"] + else: + continue + content = content.strip() + if content: + msgs.append({"role": role, "content": content}) + + return msgs + + +async def _load_daring_anteater_conversations( + split_name: str, +) -> AsyncGenerator[int | dict[str, Any], None]: + ds = load_dataset("nvidia/Daring-Anteater", split=split_name) + ds = ds.shuffle(seed=42) + yield len(ds) + for i in range(len(ds)): + conversations = ds[i]["conversations"] + if conversations and isinstance(conversations, list): + prompt_id = f"daring-anteater-{split_name}-" + id_for_conversation(conversations) + processed_conversations = _parse_daring_anteater_conversation(conversations) + if processed_conversations: + yield {"conversation_id": prompt_id, "conversations": processed_conversations} + logger.info(f"Finished loading Daring-Anteater {split_name} conversations.") + + +async def _load_magpie_conversations( + split_name: str, +) -> AsyncGenerator[int | dict[str, Any], None]: + if split_name not in ("300k", "500k", "1M"): + logger.warning("Only Magpie splits '300k', '500k' and '1M' are available. 
Skipping.") + yield 0 + return + if split_name == "500k": + ds = load_dataset("Magpie-Align/Magpie-Llama-3.3-Pro-500K-Filtered", split="train") + elif split_name == "300k": + ds = load_dataset("Magpie-Align/Magpie-Llama-3.1-Pro-300K-Filtered", split="train") + else: + assert split_name == "1M" + ds = load_dataset("Magpie-Align/Magpie-Llama-3.3-Pro-1M-v0.1", split="train") + ds = ds.shuffle(seed=42) + yield len(ds) + for i in range(len(ds)): + prompt = ds[i]["instruction"].strip() + if prompt: + conversations = [{"role": "user", "content": prompt}] + prompt_id = f"magpie-{split_name}-" + id_for_conversation(conversations) + yield {"conversation_id": prompt_id, "conversations": conversations} + logger.info(f"Finished loading Magpie {split_name} conversations.") + + +async def load_nemotron_post_training_v2_conversations( + split_name: str, +) -> AsyncGenerator[int | dict[str, Any], None]: + nemotron_splits = [ + "math", + "code", + "chat", + "stem", + "multilingual_ja", + "multilingual_it", + "multilingual_de", + "multilingual_es", + "multilingual_fr", + ] + if split_name not in nemotron_splits: + logger.warning( + f"Nemotron Post-Training V2 splits are: {', '.join(nemotron_splits)}. Skipping." + ) + yield 0 + return + + ds = load_dataset("nvidia/Nemotron-Post-Training-Dataset-v2", split=split_name) + ds = ds.shuffle(seed=42) + yield len(ds) + + for i in range(len(ds)): + conversations = ds[i]["messages"] + if conversations and isinstance(conversations, list): + # Strip leading empty system messages + while ( + conversations + and conversations[0]["role"] == "system" + and not conversations[0]["content"].strip() + ): + conversations.pop(0) + prompt_id = f"nemotron-post-training-v2-{split_name}-" + id_for_conversation( + conversations + ) + yield {"conversation_id": prompt_id, "conversations": conversations} + logger.info(f"Finished loading Nemotron Post-Training V2 {split_name} conversations.") + + +async def load_conversations_for_split( + dataset_name: str, + split_name: str, + row_constraint: int | float | None, + strip_last_completion: bool = True, +) -> list[dict]: + if dataset_name == "mtbench": + samples_it = _load_mtbench_conversations(split_name) + elif dataset_name == "sharegpt": + samples_it = _load_sharegpt_conversations(split_name) + elif dataset_name == "ultrachat": + samples_it = _load_ultrachat_conversations(split_name) + elif dataset_name == "daring-anteater": + samples_it = _load_daring_anteater_conversations(split_name) + elif dataset_name == "magpie": + samples_it = _load_magpie_conversations(split_name) + elif dataset_name == "nemotron-post-training-v2": + samples_it = load_nemotron_post_training_v2_conversations(split_name) + else: + logger.warning(f"Dataset {dataset_name} is not yet implemented. Ignoring.") + return [] + + num_samples = await samples_it.__anext__() + assert isinstance(num_samples, int), "First yielded value must be the total number of samples." + deduplication_ids = set() + unique_samples = [] + max_num_samples = max_samples_for_constraint(num_samples, row_constraint) + async for sample in samples_it: + assert isinstance(sample, dict) and "conversations" in sample, ( + "Each conversation sample must be a dict with a 'conversations' field." + ) + + # Strip the last turn of the conversation as long as it is an assistant completion, + # since we want to use these conversations as prompts only. 
+ if strip_last_completion: + while sample["conversations"] and sample["conversations"][-1]["role"] != "user": + sample["conversations"].pop() + + if not sample["conversations"]: + continue + + sample["source_dataset"] = dataset_name + sample["source_split"] = split_name + + # Deduplicate based on the first 512 characters from each turn. + # To avoid too many similar conversations with minor differences. + truncated_conversations = [ + {"role": msg["role"], "content": msg["content"][0:512]} + for msg in sample["conversations"] + ] + dedup_id = id_for_conversation(truncated_conversations) + if dedup_id not in deduplication_ids: + deduplication_ids.add(dedup_id) + unique_samples.append(sample) + if len(unique_samples) >= max_num_samples: + break + return unique_samples + + +async def main(args: argparse.Namespace) -> None: + config = DataMixingConfig.load(args.config_file) + + for output in config.outputs: + all_conversations_promises = [] + for source in output.sources: + for split_name, constraint in source.splits.items(): + row_constraint = check_row_constraint(constraint) + if row_constraint == 0: + continue # Skip this split, no samples requested + all_conversations_promises.append( + load_conversations_for_split( + source.name, + split_name, + row_constraint, + strip_last_completion=not args.full_conversations, + ) + ) + + all_conversations_results = await asyncio.gather(*all_conversations_promises) + all_conversations = [] + num_conversations_per_split = {} + for conversations in all_conversations_results: + all_conversations.extend(conversations) + + total_num_conversations = len(all_conversations) + if output.global_limit is not None and total_num_conversations > output.global_limit: + random_indices = random.sample(range(total_num_conversations), output.global_limit) + all_conversations = [all_conversations[i] for i in random_indices] + logger.info( + "Subsampling uniformly from %d to global limit of %d conversations.", + total_num_conversations, + output.global_limit, + ) + else: + random.shuffle(all_conversations) + + for conversation in all_conversations: + key = (conversation["source_dataset"], conversation["source_split"]) + num_conversations_per_split[key] = num_conversations_per_split.get(key, 0) + 1 + + # Metadata for pretty-printing + max_ds_len = max((len(ds) for ds, _ in num_conversations_per_split), default=0) + max_split_len = max((len(sp) for _, sp in num_conversations_per_split), default=0) + logger.info("Dataset splits used for output '%s':", output.filename) + num_conversations_per_split = dict( + sorted(num_conversations_per_split.items(), key=lambda item: (item[0][0], item[0][1])) + ) + for (dataset_name, split_name), num_convs in num_conversations_per_split.items(): + logger.info( + f" - {dataset_name:<{max_ds_len}} / {split_name:<{max_split_len}} : {num_convs:<8} conversations" + ) + + logger.info(f"Writing {len(all_conversations)} conversations to {output.filename}") + output_path = Path(output.filename) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + for entry in all_conversations: + assert "conversations" in entry, ( + "Each conversation entry must have a 'conversations' field." 
+ ) + if "conversation_id" not in entry: + entry["conversation_id"] = id_for_conversation(entry["conversations"]) + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + +if __name__ == "__main__": + import asyncio + + args = parse_args() + asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/utils.py b/examples/speculative_decoding/prepare_input_conversations/utils.py index 6a3698f36..ca14bd6a9 100644 --- a/examples/speculative_decoding/prepare_input_conversations/utils.py +++ b/examples/speculative_decoding/prepare_input_conversations/utils.py @@ -17,7 +17,6 @@ import hashlib import json -import random from pathlib import Path import aiohttp @@ -40,133 +39,3 @@ def id_for_conversation(conversation: list) -> str: json_str = json.dumps(conversation, sort_keys=True, separators=(",", ":"), ensure_ascii=False) json_bytes = json_str.encode("utf-8") return hashlib.sha256(json_bytes).hexdigest() - - -def add_conversations_to_split(conversations: list, dataset_dir: Path, split: str) -> None: - """Add conversations to a specific split in the dataset.""" - if len(conversations) == 0: - return - - # Open the dataset file for the specified split, or create it if it doesn't exist - dataset_file = dataset_dir / f"{split}.jsonl" - all_conversations = [] - if dataset_file.exists(): - # load the existing conversations - with dataset_file.open("r", encoding="utf-8") as f: - all_conversations.extend([json.loads(line) for line in f if line.strip()]) - - if any(not entry.get("conversation_id") for entry in all_conversations): - msg = "All existing conversations must have a 'conversation_id' field." - raise ValueError(msg) - - existing_ids = {entry["conversation_id"] for entry in all_conversations} - num_new_entries = 0 - num_duplicates = 0 - for entry in conversations: - if entry.get("conversation_id") is None: - raise ValueError("Each conversation must have a 'conversation_id' field.") - if entry["conversation_id"] not in existing_ids: - all_conversations.append( - { - "conversation_id": entry["conversation_id"], - "conversations": entry["conversations"], - } - ) - num_new_entries += 1 - else: - num_duplicates += 1 - - if num_duplicates > 0: - print( - f"Added {num_new_entries} new conversations to {dataset_file}, " - f"skipped {num_duplicates} existing entries." 
- ) - else: - print(f"Added {num_new_entries} new conversations to {dataset_file}.") - - dataset_dir.mkdir(parents=True, exist_ok=True) - with dataset_file.open("w", encoding="utf-8") as f: - for entry in all_conversations: - f.write(json.dumps(entry, ensure_ascii=False) + "\n") - - -def mix_conversations_and_add_to_splits( - conversations: list, - dataset_dir: Path, - train_ratio: float, - val_ratio: float, - test_ratio: float, - *, - shuffle: bool = True, - seed: int = 42, -) -> None: - """Mix the conversations and add to the dataset's train, val, and test splits.""" - if train_ratio + val_ratio + test_ratio != 1.0: - msg = "Ratios must sum to 1.0" - raise ValueError(msg) - if any(ratio < 0 for ratio in [train_ratio, val_ratio, test_ratio]): - msg = "Ratios must be non-negative" - raise ValueError(msg) - - total_conversations = len(conversations) - train_count = int(total_conversations * train_ratio) - val_count = int(total_conversations * val_ratio) - - if shuffle: - random.seed(seed) - random.shuffle(conversations) - - train_conversations = conversations[:train_count] - val_conversations = conversations[train_count : train_count + val_count] - test_conversations = conversations[train_count + val_count :] - add_conversations_to_split(train_conversations, dataset_dir, "train") - add_conversations_to_split(val_conversations, dataset_dir, "val") - add_conversations_to_split(test_conversations, dataset_dir, "test") - - -def update_dataset_file_with_conversations( - conversations: list, dataset_dir: Path, dataset_split: str -) -> None: - """ - Update a dataset file with new conversations. The conversations are added to the specified - split in the dataset. If the split is 'mix' or 'mix_test', the conversations are mixed and - distributed into train, val, and test splits according to predefined ratios. - """ - if dataset_split == "mix": - print("Mixing conversations and adding to train, val, and test splits.") - mix_conversations_and_add_to_splits( - conversations, - dataset_dir, - train_ratio=0.8, - val_ratio=0.1, - test_ratio=0.1, - ) - elif dataset_split == "mix_test": - print("Mixing conversations and adding to val and test splits.") - mix_conversations_and_add_to_splits( - conversations, - dataset_dir, - train_ratio=0.0, - val_ratio=0.5, - test_ratio=0.5, - ) - else: - add_conversations_to_split(conversations, dataset_dir, dataset_split) - - -def dataset_splits_explanation(default_split: str) -> str: - """Return an explanation string for the dataset split argument.""" - return f"""Split to assign the processed conversations to. - Can be any name, or one of ['mix', 'mix_test']. - Default is '{default_split}'. - - If the provided split name matches an existing file in the dataset directory, - the new conversations will be added to that file, - avoiding duplicates based on conversation IDs. - - Special split names: - - 'mix': Conversations will be randomly mixed and distributed into - 'train' (80%%), 'val' (10%%), and 'test' (10%%) splits. - - 'mix_test': Conversations will be randomly mixed and distributed into - 'val' (50%%) and 'test' (50%%) splits. 
- """ diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index 60c57d1c3..7c801345f 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -20,7 +20,7 @@ set -eo pipefail # Set default values for BASE_MODEL, NUM_GPU, and DATA BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct NUM_GPU=1 -DATA=input_conversations/daring-anteater.jsonl +DATA=input_conversations/train.jsonl # Parse input arguments --base_model, --num_gpu, and --data while [[ $# -gt 0 ]]; do