From 36a40dfa4735f72f2ffee27d8f05136c5d891612 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 25 Feb 2026 10:28:24 +0800 Subject: [PATCH 01/14] wip --- cookbook/megatron/tp_moe.sh | 2 + cookbook/megatron/tp_moe_qwen35.py | 49 ++ cookbook/rl/grpo_qwen3_5.py | 176 +++++++ src/twinkle/model/megatron/args.py | 122 +++-- src/twinkle/model/megatron/model/constant.py | 3 + .../model/megatron/model/gpt_bridge.py | 2 +- src/twinkle/model/megatron/model/gpt_model.py | 6 +- .../model/megatron/model/gpts/qwen3_next.py | 456 ++++++++++++++++++ .../model/megatron/model/mm_gpt_model.py | 3 +- .../model/megatron/model/mm_gpts/__init__.py | 2 +- .../model/megatron/model/mm_gpts/qwen3_5.py | 159 ++++++ src/twinkle/model/megatron/utils/config.py | 20 +- 12 files changed, 958 insertions(+), 42 deletions(-) create mode 100644 cookbook/megatron/tp_moe_qwen35.py create mode 100644 cookbook/rl/grpo_qwen3_5.py create mode 100644 src/twinkle/model/megatron/model/gpts/qwen3_next.py create mode 100644 src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py diff --git a/cookbook/megatron/tp_moe.sh b/cookbook/megatron/tp_moe.sh index 58e58646..27132b8d 100644 --- a/cookbook/megatron/tp_moe.sh +++ b/cookbook/megatron/tp_moe.sh @@ -1 +1,3 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe.py +CUDA_VISIBLE_DEVICES=4,5,6,7 nohup torchrun --nproc_per_node=4 /mnt/nas2/hujinghan.hjh/twinkle/cookbook/megatron/tp_moe.py > tp_moe.log 2>&1 & + diff --git a/cookbook/megatron/tp_moe_qwen35.py b/cookbook/megatron/tp_moe_qwen35.py new file mode 100644 index 00000000..8b85540b --- /dev/null +++ b/cookbook/megatron/tp_moe_qwen35.py @@ -0,0 +1,49 @@ +import os +from peft import LoraConfig +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import MegatronModel +from twinkle.preprocessor import SelfCognitionProcessor + +# tp=2, pp=2, ep=2 on 4 GPUs, dp=1 +device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, ep_size=2) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + +MODEL_ID = '/root/.cache/modelscope/hub/models/Qwen/Qwen3.5-35B-A3B' + +def train(): + dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) + dataset.set_template('Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode() + dataloader = DataLoader(dataset=dataset, batch_size=4) + + model = MegatronModel(model_id=MODEL_ID) + lora_config = LoraConfig(r=8, lora_alpha=16, target_modules='all-linear') + model.add_adapter_to_model('default', lora_config) + model.set_optimizer(optimizer_cls='default', lr=1e-4) + model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=2, lr_decay_steps=len(dataloader)) + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + if step % 5 == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Step {step}/{len(dataloader)}, metric: {metric}') + if step >= 10: + break + model.save('last-checkpoint') + logger.info('Training completed.') + + +if __name__ == '__main__': + train() diff --git a/cookbook/rl/grpo_qwen3_5.py b/cookbook/rl/grpo_qwen3_5.py new file mode 100644 index 00000000..8d9eec8a --- 
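A note on the mesh comment in tp_moe_qwen35.py above: the dp/tp/pp sizes must multiply out to the visible GPU count, while expert parallelism re-partitions existing ranks rather than adding a factor. A minimal sanity check under that assumption (`check_mesh` is an illustrative helper, not a twinkle API, and the exact EP divisibility rule depends on the Megatron-Core version):

```python
def check_mesh(world_size: int, dp: int, tp: int, pp: int, ep: int) -> None:
    """Rough validation of Megatron-style parallelism sizes vs. GPU count."""
    # dp * tp * pp must cover the world size exactly
    assert dp * tp * pp == world_size, f'{dp}x{tp}x{pp} != {world_size}'
    # ep reuses dp/tp ranks for MoE layers; a common constraint is divisibility
    assert (dp * tp) % ep == 0, f'ep={ep} must divide dp*tp={dp * tp}'

check_mesh(world_size=4, dp=1, tp=2, pp=2, ep=2)  # the tp_moe_qwen35.py layout
```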
/dev/null +++ b/cookbook/rl/grpo_qwen3_5.py @@ -0,0 +1,176 @@ +import os +from typing import List, Tuple, Dict, Any + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import SamplingParams +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward +from twinkle.sampler import vLLMSampler +from twinkle.template import Template +from twinkle.metric import CompletionRewardMetric +from twinkle.preprocessor.llm import GSM8KProcessor + +logger = get_logger() + +MODEL_ID = '/root/.cache/modelscope/hub/models/Qwen/Qwen3.5-35B-A3B' +USE_MEGATRON = True + +MODEL_GPUS = 4 +SAMPLER_GPUS = 4 +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 20)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +ADAPTER_NAME = 'default' + + +def create_gsm8k_dataset(): + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Template', model_id=MODEL_ID, max_length=1024) + dataset.map(GSM8KProcessor()) + dataset.encode(add_generation_prompt=True) + return dataset + + +def compute_rewards( + trajectories: List[Dict[str, Any]], +) -> Tuple[List[float], List[float], List[float]]: + accuracy_reward_fn = GSM8KAccuracyReward() + format_reward_fn = GSM8KFormatReward() + + accuracy_rewards = accuracy_reward_fn(trajectories) + format_rewards = format_reward_fn(trajectories) + total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)] + return total_rewards, format_rewards, accuracy_rewards + + +def main(): + device_groups = [ + DeviceGroup(name='model', ranks=list(range(4, 8)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(4)), device_type='GPU'), + ] + # tp=2, pp=2, ep=2 for model group (4 GPUs) + # dp = world_size / (tp * pp) = 4 / (2 * 2) = 1 + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, tp_size=2, pp_size=2, ep_size=2) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + lora_config = LoraConfig(target_modules='all-linear', r=8, lora_alpha=16, lora_dropout=0.05) + + from twinkle.model.megatron import MegatronModel + model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16') + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) + model.set_loss('GRPOLoss', epsilon=0.2) + model.set_processor(InputProcessor) + model.set_template('Template', model_id=MODEL_ID) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 
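compute_rewards above sums the accuracy and format channels element-wise. A weighted variant is a small generalization; this sketch is illustrative (the helper name and weights are not part of the recipe):

```python
from typing import Dict, List

def combine_rewards(channels: Dict[str, List[float]],
                    weights: Dict[str, float]) -> List[float]:
    """Element-wise weighted sum of per-sample reward channels."""
    names = list(channels)
    n = len(channels[names[0]])
    assert all(len(channels[k]) == n for k in names), 'channel length mismatch'
    return [sum(weights.get(k, 1.0) * channels[k][i] for k in names) for i in range(n)]

# e.g. down-weight the format reward relative to accuracy
combined = combine_rewards(
    {'accuracy': [1.0, 0.0], 'format': [1.0, 1.0]},
    {'accuracy': 1.0, 'format': 0.5},
)  # -> [1.5, 0.5]
```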
0.8, + 'max_model_len': 2048, + 'max_lora_rank': 8, + 'enable_lora': True, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template(Template, model_id=MODEL_ID) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + dataloader = DataLoader( + dataset=create_gsm8k_dataset, + batch_size=GLOBAL_BATCH_SIZE, + min_batch_size=GLOBAL_BATCH_SIZE, + device_mesh=model_mesh, + remote_group='model', + ) + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) + + optim_step = 0 + logger.info(get_device_placement()) + + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + metrics.reset() + global_prompts = batch if isinstance(batch, list) else [batch] + ckpt_manager.sync_weights(merge_and_sync=False) + sampler.reset_prefix_cache() + sample_response = sampler.sample( + global_prompts * NUM_GENERATIONS, + sampling_params, + num_samples=1, + ) + + all_input_data: List[Dict[str, Any]] = [] + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] + + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append(sequence.logprobs) + all_completion_lengths.append(len(sequence.tokens)) + total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data) + metrics.accumulate( + completion_lengths=all_completion_lengths, + rewards={ + 'total': total_rewards, + 'format': format_rewards, + 'accuracy': accuracy_rewards, + }, + ) + + advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + total_completions = len(all_input_data) + for mb_start in range(0, total_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) + mb_inputs = all_input_data[mb_start:mb_end] + mb_old_logps = all_old_logps[mb_start:mb_end] + mb_advantages = advantages[mb_start:mb_end] + + model.forward_backward( + inputs=mb_inputs, + old_logps=mb_old_logps, + advantages=mb_advantages, + micro_batch_size=MICRO_BATCH_SIZE, + ) + model.clip_grad_and_step() + optim_step += 1 + + if optim_step >= MAX_STEPS: + break + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + metrics.reset() + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') + + logger.info(f'Training completed. 
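For reference, the group-relative scaling that GRPOAdvantage applies further down (via `advantage_fn(total_rewards, num_generations=..., scale='group')`) normalizes each prompt's NUM_GENERATIONS completions against their own mean and standard deviation. A NumPy sketch of that idea; twinkle's implementation may differ in ordering and epsilon handling:

```python
import numpy as np

def group_relative_advantage(rewards, num_generations, eps=1e-6):
    """Normalize rewards within each group of completions for one prompt.

    Assumes rewards can be reshaped to (num_prompts, num_generations);
    the sample ordering must match how the sampler grouped completions.
    """
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)
    return ((r - mean) / (std + eps)).reshape(-1)

group_relative_advantage([1.0, 0.0, 0.5, 0.5], num_generations=2)
# first group -> roughly [+1, -1]; second group has zero spread -> [0, 0]
```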
optim_steps={optim_step}') + model.save('grpo-qwen35-gsm8k-checkpoint') + + +if __name__ == '__main__': + main() diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index 858c2f0d..a0f817a3 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -332,7 +332,8 @@ def from_hf_config( # Detect multimodal model model_type = getattr(hf_config, 'model_type', 'qwen2') - is_multimodal = 'vl' in model_type.lower() or 'vision' in model_type.lower() or 'omni' in model_type.lower() + is_multimodal = ('vl' in model_type.lower() or 'vision' in model_type.lower() or 'omni' in model_type.lower() + or hasattr(hf_config, 'vision_config')) # Determine QKV bias if hasattr(text_config, 'attention_bias'): @@ -562,6 +563,20 @@ def _get_base_model(m): bias_activation_fusion = use_swiglu and not has_bias if 'moe_token_dispatcher_type' not in moe_kwargs: moe_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather' + + # Handle use_shared_expert_gate from config + use_shared_expert_gate = mg_config_dict.get('use_shared_expert_gate', False) + + # Handle rotary_interleaved for models like Qwen3.5 with mrope + rotary_interleaved = mg_config_dict.get('rotary_interleaved', False) + partial_rotary_factor = mg_config_dict.get('partial_rotary_factor') + + # Determine position_embedding_type + position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope') + apply_rope_fusion = True + if position_embedding_type != 'rope' or rotary_interleaved: + apply_rope_fusion = False + config = TransformerConfig( num_layers=num_layers, hidden_size=mg_config_dict['hidden_size'], @@ -578,61 +593,104 @@ def _get_base_model(m): params_dtype=self.params_dtype, fp16=self.params_dtype == torch.float16, bf16=self.params_dtype == torch.bfloat16, - pipeline_dtype=self.params_dtype, # Required when using pipeline parallelism + pipeline_dtype=self.params_dtype, use_cpu_initialization=self.use_cpu_initialization, add_qkv_bias=self.add_qkv_bias, variable_seq_lengths=self.variable_seq_lengths, add_bias_linear=not mg_config_dict.get('disable_bias_linear', True), gated_linear_unit=use_swiglu, - activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise - bias_activation_fusion=bias_activation_fusion, # Fused SwiGLU for performance + activation_func=activation_func, + bias_activation_fusion=bias_activation_fusion, normalization='RMSNorm', layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6), qk_layernorm=mg_config_dict.get('qk_layernorm', False), hidden_dropout=0.0, attention_dropout=0.0, - # Performance optimizations - masked_softmax_fusion=True, # Fused attention softmax - bias_dropout_fusion=True, # Fused bias + dropout - apply_rope_fusion=True, # Fused RoPE application - attention_softmax_in_fp32=True, # Numerical stability - attention_backend=AttnBackend.flash, # FlashAttention for speed - # Activation recomputation for memory efficiency + masked_softmax_fusion=True, + bias_dropout_fusion=True, + apply_rope_fusion=apply_rope_fusion, + attention_softmax_in_fp32=True, + attention_backend=AttnBackend.flash, + rotary_interleaved=rotary_interleaved, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, recompute_method=recompute_method, recompute_num_layers=recompute_num_layers, - # Critical: Set finalize_model_grads_func for DP gradient synchronization - # Uses custom wrapper that handles both DDP and PEFT/LoRA models 
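A few lines above, `apply_rope_fusion` is disabled for anything other than plain rope: Megatron's fused RoPE kernel does not cover mrope or interleaved rotary layouts, so those must take the unfused path. The gate as a standalone predicate:

```python
def can_fuse_rope(position_embedding_type: str, rotary_interleaved: bool) -> bool:
    """Fused RoPE is only safe for plain, non-interleaved rotary embeddings."""
    return position_embedding_type == 'rope' and not rotary_interleaved

assert can_fuse_rope('rope', False)
assert not can_fuse_rope('mrope', False)   # Qwen3.5 multimodal rope
assert not can_fuse_rope('rope', True)     # interleaved rotary
```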
finalize_model_grads_func=finalize_model_grads_for_lora, - # MoE configuration **moe_kwargs, ) if exists('megatron_core>=0.13'): config.expert_tensor_parallel_size = self.etp_size - # Save transformer config for later use (e.g., DDP wrapping) + # Store layer_types on config for Qwen3-Next/Qwen3.5 heterogeneous layers + layer_types = mg_config_dict.get('layer_types') + if layer_types is not None: + config.layer_types = layer_types + self.layer_types = layer_types + for attr in ['linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', + 'linear_value_head_dim', 'linear_conv_kernel_dim']: + val = mg_config_dict.get(attr) + if val is not None: + setattr(config, attr, val) + + # Store partial_rotary_factor on config + if partial_rotary_factor is not None: + config.partial_rotary_factor = partial_rotary_factor + + # Store args reference on config for HuggingFaceModule compatibility + config.args = self + self.config = config - # Get layer spec - enable moe_grouped_gemm for MoE models + # Get layer spec moe_grouped_gemm = num_experts > 0 - try: - layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=mg_config_dict.get('num_experts'), - moe_grouped_gemm=moe_grouped_gemm, - qk_layernorm=mg_config_dict.get('qk_layernorm', False), - ) - except (ImportError, AttributeError): - raise RuntimeError( - 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.') + if layer_types is not None: + from .model.gpts.qwen3_next import get_qwen3_next_layer_spec, Qwen3NextGatedDeltaNet, Qwen3_5MoeGatedDeltaNet + hf_model_type = mg_config_dict.get('hf_model_type', '') + if hf_model_type in {'qwen3_5_moe', 'qwen3_5'}: + gated_delta_net_cls = Qwen3_5MoeGatedDeltaNet + else: + gated_delta_net_cls = Qwen3NextGatedDeltaNet + layer_spec = get_qwen3_next_layer_spec(config, self, gated_delta_net_cls) + else: + try: + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=mg_config_dict.get('num_experts'), + moe_grouped_gemm=moe_grouped_gemm, + qk_layernorm=mg_config_dict.get('qk_layernorm', False), + ) + except (ImportError, AttributeError): + raise RuntimeError( + 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.') + + # Set shared_expert_gate if needed + if use_shared_expert_gate and num_experts > 0 and moe_shared_expert_intermediate_size: + if hasattr(layer_spec, 'layer_specs'): + for ls in layer_spec.layer_specs: + if hasattr(ls.submodules.mlp.submodules, 'shared_experts'): + ls.submodules.mlp.submodules.shared_experts.params = {'gate': True} + elif hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): + layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} # Create model - max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) + text_config = hf_config + if hasattr(hf_config, 'text_config') and hf_config.text_config is not None: + text_config = hf_config.text_config + max_seq_length = getattr(text_config, 'max_position_embeddings', 4096) rotary_base = mg_config_dict.get('rotary_base', 10000) + rotary_percent = 1.0 + if partial_rotary_factor is not None: + rotary_percent = partial_rotary_factor extra_init_args = {} - if hasattr(hf_config, - 'rope_scaling') and hf_config.rope_scaling is not None and 'factor' in hf_config.rope_scaling: - extra_init_args = {'seq_len_interpolation_factor': hf_config.rope_scaling['factor']} + rope_scaling_dict = getattr(text_config, 'rope_scaling', None) or getattr(text_config, 'rope_parameters', None) + if rope_scaling_dict 
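The rope-scaling handling this hunk adds (the condition continues just below) pulls two optional values out of the HF config: a linear interpolation factor and an overriding theta. Condensed into one helper, a sketch of the same mapping rather than twinkle API:

```python
from typing import Optional, Tuple

def extract_rope_overrides(rope_scaling: Optional[dict]) -> Tuple[dict, Optional[int]]:
    """Map HF rope_scaling / rope_parameters entries onto Megatron init args.

    Returns (extra_init_args, rotary_base_override); either may be empty/None.
    """
    extra, base = {}, None
    if isinstance(rope_scaling, dict):
        if 'factor' in rope_scaling:
            extra['seq_len_interpolation_factor'] = rope_scaling['factor']
        if 'rope_theta' in rope_scaling:
            base = int(rope_scaling['rope_theta'])
    return extra, base

extract_rope_overrides({'factor': 4.0, 'rope_theta': 1000000})
# -> ({'seq_len_interpolation_factor': 4.0}, 1000000)
```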
is not None and isinstance(rope_scaling_dict, dict): + if 'factor' in rope_scaling_dict: + extra_init_args['seq_len_interpolation_factor'] = rope_scaling_dict['factor'] + if 'rope_theta' in rope_scaling_dict: + rotary_base = int(rope_scaling_dict['rope_theta']) + mrope_section = mg_config_dict.get('mrope_section') + if position_embedding_type == 'mrope' and mrope_section is not None: + config.mrope_section = mrope_section vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: model = [] @@ -651,7 +709,8 @@ def _get_base_model(m): post_process=mpu.is_pipeline_last_stage(**extra_kwargs), parallel_output=True, share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), - position_embedding_type='rope', + position_embedding_type=position_embedding_type, + rotary_percent=rotary_percent, rotary_base=rotary_base, **extra_init_args) model.append(_model) @@ -666,7 +725,8 @@ def _get_base_model(m): post_process=mpu.is_pipeline_last_stage(), parallel_output=True, share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), - position_embedding_type='rope', + position_embedding_type=position_embedding_type, + rotary_percent=rotary_percent, rotary_base=rotary_base, **extra_init_args, ) diff --git a/src/twinkle/model/megatron/model/constant.py b/src/twinkle/model/megatron/model/constant.py index b3ea8807..33ac637c 100644 --- a/src/twinkle/model/megatron/model/constant.py +++ b/src/twinkle/model/megatron/model/constant.py @@ -14,6 +14,8 @@ class MLLMModelType: qwen2_5_vl = 'qwen2_5_vl' qwen3_vl = 'qwen3_vl' qwen3_vl_moe = 'qwen3_vl_moe' + qwen3_5 = 'qwen3_5' + qwen3_5_moe = 'qwen3_5_moe' class ModelType(LLMModelType, MLLMModelType): @@ -29,6 +31,7 @@ class MLLMMegatronModelType: qwen2_vl = 'qwen2_vl' qwen2_5_vl = 'qwen2_5_vl' qwen3_vl = 'qwen3_vl' + qwen3_5 = 'qwen3_5' class MegatronModelType(LLMMegatronModelType, MLLMMegatronModelType): diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index 58e40440..37c2593d 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -1439,7 +1439,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_state_dict = {} self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer - if not to_mcore and not self.args.hf_model_type.startswith('qwen3_next'): + if not to_mcore and not self.args.hf_model_type.startswith(('qwen3_next', 'qwen3_5')): self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) if self.args.untie_embeddings_and_output_weights: diff --git a/src/twinkle/model/megatron/model/gpt_model.py b/src/twinkle/model/megatron/model/gpt_model.py index 477ccaf5..85e3f251 100644 --- a/src/twinkle/model/megatron/model/gpt_model.py +++ b/src/twinkle/model/megatron/model/gpt_model.py @@ -211,10 +211,12 @@ def _preprocess( f'current_attention_scaling: {attention_scaling}.') packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.position_embedding_type == 'mrope': + mrope_position_ids = position_ids + if mrope_position_ids.dim() == 2: + mrope_position_ids = mrope_position_ids.unsqueeze(0).expand(3, -1, -1) rotary_pos_emb = self.rotary_pos_emb( - position_ids, + mrope_position_ids, mrope_section=self.mrope_section, - 
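On the gpt_model.py change above: mrope consumes a (3, batch, seq) position tensor, one plane each for the temporal/height/width components, while text-only batches carry (batch, seq) positions, hence the unsqueeze/expand broadcast. In isolation:

```python
import torch

position_ids = torch.arange(8).unsqueeze(0)                       # (batch=1, seq=8)
mrope_position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # (3, 1, 8)
assert mrope_position_ids.shape == (3, 1, 8)
# all three mrope sections see identical positions for pure-text input
assert torch.equal(mrope_position_ids[0], mrope_position_ids[2])
```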
packed_seq=packed_seq, ) else: rotary_pos_emb = self.rotary_pos_emb( diff --git a/src/twinkle/model/megatron/model/gpts/qwen3_next.py b/src/twinkle/model/megatron/model/gpts/qwen3_next.py new file mode 100644 index 00000000..bf980854 --- /dev/null +++ b/src/twinkle/model/megatron/model/gpts/qwen3_next.py @@ -0,0 +1,456 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Reference: swift/swift/megatron/model/gpts/qwen3_next.py +# Qwen3-Next / Qwen3.5 series model support for Megatron + +import megatron.core +import torch +from copy import deepcopy +from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, _get_extra_te_kwargs +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec +from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region) +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.spec_utils import build_module +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import deprecate_inference_params, is_fa_min_version +from packaging import version +from typing import List, Optional, Tuple, Union + +from twinkle import get_logger +from twinkle.model.megatron.args import get_args + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0') +try: + from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import flash_attn_with_kvcache as flash_attn3_with_kvcache + HAVE_FA3 = True +except Exception: + HAVE_FA3 = False + +try: + from einops import rearrange +except ImportError: + rearrange = None + +try: + import transformer_engine # pylint: disable=unused-import + HAVE_TE = True + from megatron.core.extensions.transformer_engine import SplitAlongDim +except ImportError: + HAVE_TE = False + SplitAlongDim = None + +logger = get_logger() + + +class Qwen3NextRMSNorm(torch.nn.Module): + """ + Zero-Centered RMSNorm for Qwen3-Next/Qwen3.5. + Uses (1 + weight) scaling to match HuggingFace implementation exactly. + This eliminates the need for +1/-1 offset during weight conversion. + """ + + def __init__(self, config, hidden_size: int, eps: float = 1e-5): + super().__init__() + self.config = config + self.eps = eps + self.weight = torch.nn.Parameter(torch.zeros(hidden_size)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, hidden_states): + output = self._norm(hidden_states.float()) + output = output * (1.0 + self.weight.float()) + return output.type_as(hidden_states) + + +class Qwen3NextSelfAttention(SelfAttention): + """Full attention with output gate for Qwen3-Next/Qwen3.5 models. + + QKV projection produces [Q_heads, gate_heads, K_heads, V_heads] where + Q and gate are interleaved: Q0, gate0, Q1, gate1, ... 
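On Qwen3NextRMSNorm defined earlier in this file: storing `weight - 1` relative to a conventional RMSNorm means HF weights load unchanged, as its docstring notes. A quick numerical check of the equivalence (standalone, not using the class itself):

```python
import torch

x = torch.randn(4, 16, dtype=torch.float32)
w = torch.randn(16)   # conventional RMSNorm weight
eps = 1e-5

rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
conventional = rms * w                      # standard RMSNorm(weight=w)
zero_centered = rms * (1.0 + (w - 1.0))     # Qwen3NextRMSNorm holding (w - 1)
assert torch.allclose(conventional, zero_centered)
```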
+ """ + + def __init__(self, config, submodules: SelfAttentionSubmodules, *args, **kwargs): + super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs) + kwargs_pg = {} + if mcore_015: + kwargs_pg['tp_group'] = self.pg_collection.tp + elif mcore_013: + kwargs_pg['tp_group'] = self.model_comm_pgs.tp + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + 2 * self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + **kwargs_pg, + ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.q_layernorm = None + + if submodules.k_layernorm is not None: + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.k_layernorm = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + try: + from megatron.core.utils import nvtx_range_pop, nvtx_range_push + except ImportError: + def nvtx_range_pop(*args, **kwargs): + return + def nvtx_range_push(*args, **kwargs): + return + + if hasattr(self.config, 'no_rope_freq'): + no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False) + if no_rope: + rotary_pos_emb = None + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert HAVE_FA3 or is_fa_min_version( + '2.7.3'), 'flash attn verion v2.7.3 and above is required for dynamic batching.' 
+ + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb, ) * 2 + + nvtx_range_push(suffix='qkv') + query, key, value, gate = self.get_query_key_value_tensors(hidden_states, key_value_states) + nvtx_range_pop(suffix='qkv') + + in_decode_mode = (inference_context is not None and inference_context.is_decode_only() and not self.training) + + nvtx_range_push(suffix='adjust_key_value') + if in_decode_mode and self.config.flash_decode: + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + rotary_interleaved=self.config.rotary_interleaved, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + if (in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching()): + raise ValueError('CUDA graphs must use flash decode with static batching!') + + result = self._adjust_key_value_for_inference( + inference_context, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset, + ) + if mcore_013: + query, key, value, rotary_pos_emb, attn_mask_type, block_table = result + else: + query, key, value, rotary_pos_emb, attn_mask_type = result + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + nvtx_range_pop(suffix='adjust_key_value') + + kwargs_cp = {} + if mcore_015: + kwargs_cp['cp_group'] = self.pg_collection.cp + elif mcore_013: + kwargs_cp['cp_group'] = self.model_comm_pgs.cp + nvtx_range_push(suffix='rotary_pos_emb') + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = (packed_seq_params.cu_seqlens_q_padded + if packed_seq_params.cu_seqlens_q_padded is not None + else packed_seq_params.cu_seqlens_q) + cu_seqlens_kv = (packed_seq_params.cu_seqlens_kv_padded + if packed_seq_params.cu_seqlens_kv_padded is not None + else packed_seq_params.cu_seqlens_kv) + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, + **kwargs_cp) + else: + query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q, + **kwargs_cp) + if k_pos_emb is not None: + key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv, **kwargs_cp) + nvtx_range_pop(suffix='rotary_pos_emb') + + nvtx_range_push(suffix='core_attention') + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, key, value, attention_mask, attn_mask_type=attn_mask_type, + attention_bias=attention_bias, 
packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + core_attn_out = self.core_attention( + query, key, value, attention_mask, attn_mask_type=attn_mask_type, + attention_bias=attention_bias, packed_seq_params=packed_seq_params, + ) + else: + q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, kv_lengths, kv_lengths_decode_only, max_seqlen_k = (inference_context.cu_kv_lengths()) + core_attn_out = self.flash_decode_and_prefill( + q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths, + kv_lengths, kv_lengths_decode_only, block_table, + ) + core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + nvtx_range_pop(suffix='core_attention') + + core_attn_out = core_attn_out * torch.sigmoid(gate.reshape_as(core_attn_out)) + nvtx_range_push(suffix='linear_proj') + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix='linear_proj') + + return output, bias + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + mixed_qkv, _ = self.linear_qkv(hidden_states) + + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ((self.num_attention_heads_per_partition // self.num_query_groups_per_partition * 2 + 2) + * self.hidden_size_per_attention_head), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + split_arg_list = [ + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head * 2), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + query, gate = query[:, :, ::2], query[:, :, 1::2] + if self.q_layernorm is not None: + query = self.q_layernorm(query) + if self.k_layernorm is not None: + key = self.k_layernorm(key) + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value, gate + + +def _gated_delta_net_forward(self, hidden_states: torch.Tensor, **kwargs): + """Shared forward logic for all GatedDeltaNet variants.""" + args = get_args() + if args.sequence_parallel and args.tensor_model_parallel_size > 1: + hidden_states = gather_from_sequence_parallel_region(hidden_states) + seq_len = hidden_states.shape[0] + packed_seq_params = kwargs.get('packed_seq_params') + thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if thd_format and not getattr(args, 'packing', False): + new_hidden_states = hidden_states.new_zeros( + (packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item(), hidden_states.shape[-1])) + attention_mask = hidden_states.new_zeros( + (packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item()), dtype=torch.bool) + cu_seqlens_q = packed_seq_params.cu_seqlens_q + for i in range(packed_seq_params.num_samples): + start, end = cu_seqlens_q[i], cu_seqlens_q[i + 1] + attention_mask[i, :end - start] = True + new_hidden_states[i, :end - start] = hidden_states[start:end, 0] + hidden_states = new_hidden_states + else: + hidden_states = hidden_states.transpose(0, 1) + 
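get_query_key_value_tensors above depends on the checkpoint layout in which query heads and their output-gate heads alternate along the head dimension (Q0, gate0, Q1, gate1, ...), so a stride-2 slice separates them. A shape-only illustration:

```python
import torch

seq, batch, n_heads, head_dim = 5, 2, 8, 64
# 2 * n_heads interleaved heads: even indices are Q, odd indices are gates
mixed = torch.randn(seq, batch, 2 * n_heads, head_dim)
query, gate = mixed[:, :, ::2], mixed[:, :, 1::2]
assert query.shape == gate.shape == (seq, batch, n_heads, head_dim)
```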
attention_mask = kwargs.get('attention_mask') + if attention_mask is not None: + attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 + res = super(type(self), self).forward(hidden_states=hidden_states, attention_mask=attention_mask) + if thd_format and not getattr(args, 'packing', False): + res = res[attention_mask][:, None] + res = torch.concat([res, res.new_zeros(seq_len - res.shape[0], 1, res.shape[2])]) + else: + res = res.transpose(0, 1) + if args.sequence_parallel and args.tensor_model_parallel_size > 1: + res = reduce_scatter_to_sequence_parallel_region(res) / args.tensor_model_parallel_size + return res, None + + +def _gated_delta_net_init(self, hf_cls, config, submodules, layer_number, **kwargs): + """Shared __init__ logic for all GatedDeltaNet variants.""" + assert config.context_parallel_size == 1, 'Qwen3-Next/Qwen3.5 currently does not support context parallel.' + hf_cls.__init__(self, config, layer_number) + self.config = config + extra_kwargs = _get_extra_te_kwargs(config) + self.to(dtype=extra_kwargs['params_dtype'], device=extra_kwargs['device']) + + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeGatedDeltaNet as _Qwen3_5MoeGatedDeltaNet +except ImportError: + _Qwen3_5MoeGatedDeltaNet = object + +try: + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet as _Qwen3NextGatedDeltaNet +except ImportError: + _Qwen3NextGatedDeltaNet = object + + +class Qwen3NextGatedDeltaNet(_HuggingFaceModule, _Qwen3NextGatedDeltaNet): + """GatedDeltaNet for linear attention layers in Qwen3-Next models.""" + + def __init__(self, config, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): + assert _Qwen3NextGatedDeltaNet is not object, 'please update the `transformers` version.' + _gated_delta_net_init(self, _Qwen3NextGatedDeltaNet, config, submodules, layer_number, **kwargs) + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return _gated_delta_net_forward(self, hidden_states, **kwargs) + + +class Qwen3_5MoeGatedDeltaNet(_HuggingFaceModule, _Qwen3_5MoeGatedDeltaNet): + """GatedDeltaNet for Qwen3.5-MoE linear attention layers.""" + + def __init__(self, config, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs): + assert _Qwen3_5MoeGatedDeltaNet is not object, 'please update the `transformers` version.' + _gated_delta_net_init(self, _Qwen3_5MoeGatedDeltaNet, config, submodules, layer_number, **kwargs) + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return _gated_delta_net_forward(self, hidden_states, **kwargs) + + +def get_local_layer_specs(config, layer_specs, vp_stage=None): + """Get the layer specs for layers assigned to this pipeline stage. + + Mirrors swift.megatron.utils.get_local_layer_specs for distributing + heterogeneous layer specs across pipeline stages. + """ + from megatron.core import mpu + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + if pp_size <= 1: + return layer_specs + num_layers = len(layer_specs) + layers_per_stage = num_layers // pp_size + remainder = num_layers % pp_size + start = pp_rank * layers_per_stage + min(pp_rank, remainder) + if pp_rank < remainder: + layers_per_stage += 1 + return layer_specs[start:start + layers_per_stage] + + +def get_qwen3_next_layer_spec(config, args, gated_delta_net_cls): + """Build the heterogeneous transformer layer specs for Qwen3-Next/Qwen3.5. 
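get_local_layer_specs above hands each pipeline stage a contiguous slice of the heterogeneous spec list, with the first `remainder` stages taking one extra layer. The same index arithmetic without any Megatron state:

```python
def stage_slice(num_layers: int, pp_size: int, pp_rank: int) -> range:
    """Contiguous layer indices owned by one pipeline stage."""
    per_stage, remainder = divmod(num_layers, pp_size)
    start = pp_rank * per_stage + min(pp_rank, remainder)
    length = per_stage + (1 if pp_rank < remainder else 0)
    return range(start, start + length)

# 10 layers over 4 stages -> sizes 3, 3, 2, 2
assert [len(stage_slice(10, 4, r)) for r in range(4)] == [3, 3, 2, 2]
```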
+ + Returns a TransformerBlockSubmodules with per-layer specs matching + the model's layer_types (linear_attention / full_attention). + """ + config.hetereogenous_dist_checkpoint = True + config.hidden_act = 'silu' + config.rms_norm_eps = config.layernorm_epsilon + config.dtype = args.params_dtype + + layer_norm_impl = Qwen3NextRMSNorm + kwargs = {'use_kitchen': config.use_kitchen} if hasattr(config, 'use_kitchen') and mcore_013 else {} + moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=getattr(config, 'moe_grouped_gemm', True), + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=getattr(config, 'moe_use_legacy_grouped_gemm', False), + **kwargs, + ) + layer_specs = [] + for layer_type in config.layer_types: + layer_spec = deepcopy(moe_layer_spec) + if layer_type == 'linear_attention': + layer_spec.submodules.self_attention.module = gated_delta_net_cls + elif layer_type == 'full_attention': + layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear + layer_spec.submodules.self_attention.module = Qwen3NextSelfAttention + layer_spec.submodules.input_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules, + 'pre_mlp_layernorm') and layer_spec.submodules.pre_mlp_layernorm is not IdentityOp: + layer_spec.submodules.pre_mlp_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules.self_attention.submodules, 'q_layernorm'): + layer_spec.submodules.self_attention.submodules.q_layernorm = layer_norm_impl + if hasattr(layer_spec.submodules.self_attention.submodules, 'k_layernorm'): + layer_spec.submodules.self_attention.submodules.k_layernorm = layer_norm_impl + layer_specs.append(layer_spec) + + local_layer_specs = get_local_layer_specs(config, layer_specs) + block_spec = TransformerBlockSubmodules(layer_specs=local_layer_specs, layer_norm=layer_norm_impl) + + return block_spec + + +def get_qwen3_next_mtp_block_spec(config, transformer_layer_spec, **kwargs): + """Build MTP block spec with Qwen3NextRMSNorm.""" + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=True, **kwargs) + for layer_spec in mtp_block_spec.layer_specs: + layer_spec.submodules.enorm = Qwen3NextRMSNorm + layer_spec.submodules.hnorm = Qwen3NextRMSNorm + layer_spec.submodules.layer_norm = Qwen3NextRMSNorm + return mtp_block_spec diff --git a/src/twinkle/model/megatron/model/mm_gpt_model.py b/src/twinkle/model/megatron/model/mm_gpt_model.py index 4e2aa4d1..83a86ef5 100644 --- a/src/twinkle/model/megatron/model/mm_gpt_model.py +++ b/src/twinkle/model/megatron/model/mm_gpt_model.py @@ -82,7 +82,8 @@ def forward(_self, input_): if reduce_scatter_embeddings: res = res.transpose(0, 1).contiguous() group_kwargs = {'group': _self.tp_group} if mcore_013 else {} - res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / args.tensor_model_parallel_size + tp_size = mpu.get_tensor_model_parallel_world_size() + res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / tp_size return res VocabParallelEmbedding.forward = forward diff --git a/src/twinkle/model/megatron/model/mm_gpts/__init__.py b/src/twinkle/model/megatron/model/mm_gpts/__init__.py index 2cee28f6..30f10d89 100644 --- a/src/twinkle/model/megatron/model/mm_gpts/__init__.py +++ b/src/twinkle/model/megatron/model/mm_gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . 
import qwen, qwen3_vl, utils +from . import qwen, qwen3_5, qwen3_vl, utils diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py new file mode 100644 index 00000000..489fea86 --- /dev/null +++ b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py @@ -0,0 +1,159 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Reference: swift/swift/megatron/model/mm_gpts/qwen3_5.py +# Qwen3.5 / Qwen3.5-MoE multimodal model support for Megatron + +import torch +from PIL import Image + +from twinkle.model.megatron.args import get_args +from twinkle.utils.torch_utils import to_device +from ..constant import MegatronModelType, ModelType +from ..gpt_bridge import GPTBridge, MultimodalGPTBridge +from ..register import MegatronModelMeta, register_megatron_model +from .utils import HuggingFaceModule + + +class Qwen3_5Vit(HuggingFaceModule): + """Vision module for Qwen3.5 / Qwen3.5-MoE models. + + Maps 'model.visual' from HF model to 'visual' in Megatron, + with merger as aligner. + """ + module_mapping = {'model.visual': 'visual'} + _vision_tower = ['visual'] + _aligner = ['visual.merger'] + + def __init__(self, config): + try: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel + except ImportError: + Qwen3_5TextModel = None + try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel + except ImportError: + Qwen3_5MoeTextModel = None + ignore_cls = [c for c in [Qwen3_5TextModel, Qwen3_5MoeTextModel] if c is not None] + super().__init__(config, ignore_cls) + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + return self._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config) + + def _get_inputs_embeds_hf(self, inputs_embeds, inputs, visual, processor, config): + input_ids = inputs['input_ids'] + pixel_values = inputs.get('pixel_values') + pixel_values_videos = inputs.get('pixel_values_videos') + image_grid_thw = inputs.get('image_grid_thw') + video_grid_thw = inputs.get('video_grid_thw') + dtype = visual.dtype + if pixel_values is None and pixel_values_videos is None: + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + media_inputs = processor.image_processor(images=images, return_tensors='pt') + media_inputs = to_device(media_inputs, input_ids.device) + pixel_values = media_inputs['pixel_values'].type(dtype) + image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. 
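The `* 0.` just above is the standard keep-alive trick for distributed multimodal training: a text-only batch still runs the vision tower on a dummy image and folds the result in with zero weight, so every vision parameter stays in the autograd graph and DDP/pipeline gradient sync does not stall on missing grads. In miniature:

```python
import torch

vision_out = torch.randn(10, requires_grad=True)  # stand-in for the vision tower output
text_embeds = torch.randn(10, requires_grad=True)

# numerically a no-op, but keeps vision_out in the graph
combined = text_embeds + vision_out.mean() * 0.
combined.sum().backward()
assert vision_out.grad is not None  # vision params still receive (zero) grads
```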
+ else: + if pixel_values is None: + pixel_values_mixed = pixel_values_videos + grid_thw = video_grid_thw + elif pixel_values_videos is None: + pixel_values_mixed = pixel_values + grid_thw = image_grid_thw + else: + pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) + grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) + pixel_values_mixed = pixel_values_mixed.type(dtype) + mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) + if hasattr(mixed_embeds, 'pooler_output'): + mixed_embeds = mixed_embeds.pooler_output + if pixel_values is None: + image_embeds = None + video_embeds = mixed_embeds + elif pixel_values_videos is None: + image_embeds = mixed_embeds + video_embeds = None + else: + merge_length = processor.image_processor.merge_size**2 + image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() + image_embeds = mixed_embeds[:image_tokens] + video_embeds = mixed_embeds[image_tokens:] + + if image_embeds is not None: + image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if video_embeds is not None: + video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask = video_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + return inputs_embeds + + +class Qwen3_5Bridge(MultimodalGPTBridge): + """Bridge for Qwen3.5 multimodal models. + + Uses language_model prefix for the LLM backbone since Qwen3.5 has a + multimodal architecture with model.language_model.layers structure. + + Overrides _set_layer_attn to handle the mixed linear/full attention + architecture specific to Qwen3-Next/Qwen3.5. 
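masked_scatter is what splices the vision embeddings into the token stream above: positions whose input id equals the image/video placeholder token are overwritten, in order, by rows of the vision output. A toy version (the token id is hypothetical):

```python
import torch

IMAGE_TOKEN_ID = 9  # hypothetical placeholder id
input_ids = torch.tensor([[1, 9, 9, 2]])
inputs_embeds = torch.zeros(1, 4, 3)
image_embeds = torch.tensor([[1., 1., 1.], [2., 2., 2.]])  # one row per image token

mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds)
assert torch.equal(inputs_embeds[0, 1], torch.tensor([1., 1., 1.]))
assert torch.equal(inputs_embeds[0, 2], torch.tensor([2., 2., 2.]))
```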
+ """ + hf_layers_prefix = 'model.language_model.layers' + hf_embed_key = 'model.language_model.embed_tokens.weight' + hf_final_layernorm_key = 'model.language_model.norm.weight' + hf_mtp_prefix = 'mtp.layers' + + def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): + args = self.args + layer_types = getattr(args, 'layer_types', None) + if layer_types is None: + return super()._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore) + + layer_type = layer_types[layer_idx] if 0 <= layer_idx < len(layer_types) else 'full_attention' + mg_attn = None if mg_layer is None else mg_layer.self_attention + if layer_type == 'linear_attention': + hf_state_dict.update(self._set_module(mg_attn, hf_state_dict, 'linear_attn.', to_mcore)) + elif layer_type == 'full_attention': + hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) + return hf_state_dict + + def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + hf_state_dict = self._remove_prefix(origin_hf_state_dict, 'mtp.') + for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'], + ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): + self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) + if not to_mcore: + origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.')) + + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForConditionalGeneration +except ImportError: + Qwen3_5MoeForConditionalGeneration = None + +_auto_model_cls = Qwen3_5MoeForConditionalGeneration +if _auto_model_cls is None: + try: + from transformers import AutoModel + _auto_model_cls = AutoModel + except ImportError: + _auto_model_cls = None + +register_megatron_model( + MegatronModelMeta( + MegatronModelType.qwen3_5, + [ + ModelType.qwen3_5, + ModelType.qwen3_5_moe, + ], + bridge_cls=Qwen3_5Bridge, + visual_cls=Qwen3_5Vit, + auto_model_cls=_auto_model_cls, + )) diff --git a/src/twinkle/model/megatron/utils/config.py b/src/twinkle/model/megatron/utils/config.py index ef44b4b1..eca0edbb 100644 --- a/src/twinkle/model/megatron/utils/config.py +++ b/src/twinkle/model/megatron/utils/config.py @@ -37,6 +37,13 @@ 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], 'qk_layernorm': ['use_qk_norm'], + # qwen3_next / qwen3_5 + 'linear_num_value_heads': ['linear_num_value_heads'], + 'linear_num_key_heads': ['linear_num_key_heads'], + 'linear_key_head_dim': ['linear_key_head_dim'], + 'linear_value_head_dim': ['linear_value_head_dim'], + 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], + 'full_attention_interval': ['full_attention_interval'], # other 'original_max_position_embeddings': ['original_max_position_embeddings'], 'partial_rotary_factor': ['partial_rotary_factor'], @@ -95,13 +102,14 @@ def convert_hf_config(config) -> Dict[str, Any]: interleave_moe_layer_step = res.pop('interleave_moe_layer_step', None) window_size = res.pop('window_size', None) rope_scaling = res.get('rope_scaling') or {} - if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next' - } or hf_model_type in {'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe'}: + if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next'} or hf_model_type 
in { + 'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe', 'qwen3_5', 'qwen3_5_moe' + }: res['qk_layernorm'] = True if llm_model_type in {'qwen2_moe', 'qwen3_moe', 'qwen3_next' - } or hf_model_type in {'qwen3_omni_moe', 'qwen3_vl_moe'}: + } or hf_model_type in {'qwen3_omni_moe', 'qwen3_vl_moe', 'qwen3_5_moe'}: res.pop('ffn_hidden_size', None) - if llm_model_type in {'qwen2_moe', 'qwen3_next'}: + if llm_model_type in {'qwen2_moe', 'qwen3_next'} or hf_model_type == 'qwen3_5_moe': res['use_shared_expert_gate'] = True if llm_model_type in { 'deepseek', @@ -145,8 +153,8 @@ def convert_hf_config(config) -> Dict[str, Any]: if llm_model_type == 'glm4_moe_lite': res['qk_layernorm'] = True res.pop('num_query_groups', None) - elif llm_model_type == 'qwen3_next': - full_attention_interval = res.pop('full_attention_interval') + elif llm_model_type == 'qwen3_next' or hf_model_type in {'qwen3_5', 'qwen3_5_moe'}: + full_attention_interval = res.pop('full_attention_interval', 4) num_layers = res['num_layers'] res['layer_types'] = [ 'full_attention' if (i + 1) % full_attention_interval == 0 else 'linear_attention' From d2ea7047a4ff165363a9d30f99d331a5bb17db50 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 16:33:55 +0800 Subject: [PATCH 02/14] fix --- cookbook/megatron/tp_moe.sh | 2 - cookbook/megatron/tp_moe_qwen35.py | 8 +- src/twinkle/model/megatron/args.py | 92 +++++------- .../model/megatron/model/gpt_bridge.py | 131 ++++++++++++++---- .../model/megatron/model/gpts/qwen3_next.py | 64 ++++++--- 5 files changed, 186 insertions(+), 111 deletions(-) diff --git a/cookbook/megatron/tp_moe.sh b/cookbook/megatron/tp_moe.sh index 27132b8d..58e58646 100644 --- a/cookbook/megatron/tp_moe.sh +++ b/cookbook/megatron/tp_moe.sh @@ -1,3 +1 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe.py -CUDA_VISIBLE_DEVICES=4,5,6,7 nohup torchrun --nproc_per_node=4 /mnt/nas2/hujinghan.hjh/twinkle/cookbook/megatron/tp_moe.py > tp_moe.log 2>&1 & - diff --git a/cookbook/megatron/tp_moe_qwen35.py b/cookbook/megatron/tp_moe_qwen35.py index 8b85540b..4e845a55 100644 --- a/cookbook/megatron/tp_moe_qwen35.py +++ b/cookbook/megatron/tp_moe_qwen35.py @@ -9,13 +9,13 @@ from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor -# tp=2, pp=2, ep=2 on 4 GPUs, dp=1 -device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, ep_size=2) +device_mesh = DeviceMesh.from_sizes(dp_size=4, tp_size=1, pp_size=1, ep_size=4) twinkle.initialize(mode='local', global_device_mesh=device_mesh) logger = get_logger() -MODEL_ID = '/root/.cache/modelscope/hub/models/Qwen/Qwen3.5-35B-A3B' +MODEL_ID = 'Qwen/Qwen3.5-35B-A3B' +MAX_STEPS = 100 def train(): dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) @@ -39,7 +39,7 @@ def train(): if step % 5 == 0: metric = model.calculate_metric(is_training=True) logger.info(f'Step {step}/{len(dataloader)}, metric: {metric}') - if step >= 10: + if step >= MAX_STEPS: break model.save('last-checkpoint') logger.info('Training completed.') diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index a0f817a3..d23d8020 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -107,6 +107,7 @@ class TwinkleMegatronArgs: num_experts: int = 0 num_experts_per_tok: int = 2 shared_expert_intermediate_size: int = 0 + moe_router_enable_expert_bias: bool = False # ========================================================================= # 
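Earlier in this hunk, convert_hf_config derives layer_types from full_attention_interval: with the default of 4, every fourth layer runs full attention and the rest take the linear GatedDeltaNet path. Spelled out:

```python
full_attention_interval = 4
num_layers = 8
layer_types = [
    'full_attention' if (i + 1) % full_attention_interval == 0 else 'linear_attention'
    for i in range(num_layers)
]
assert layer_types == (['linear_attention'] * 3 + ['full_attention']) * 2
```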
Training/inference settings @@ -260,6 +261,10 @@ def head_dim(self) -> int: def intermediate_size(self) -> int: return self.ffn_hidden_size + @property + def moe_shared_expert_intermediate_size(self) -> int: + return self.shared_expert_intermediate_size + @property def num_query_groups(self) -> int: """Alias for num_key_value_heads (Megatron naming).""" @@ -477,8 +482,6 @@ def finalize_model_grads_for_lora(model, *args, **kwargs): from megatron.core.distributed import DistributedDataParallel as MegatronDDP from peft import PeftModel as _PeftModel - # Check if model is DDP-wrapped (has ddp_config) - # Need to unwrap PeftModel to check the underlying model def _get_base_model(m): if isinstance(m, _PeftModel): return _get_base_model(m.base_model.model) @@ -563,20 +566,6 @@ def _get_base_model(m): bias_activation_fusion = use_swiglu and not has_bias if 'moe_token_dispatcher_type' not in moe_kwargs: moe_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather' - - # Handle use_shared_expert_gate from config - use_shared_expert_gate = mg_config_dict.get('use_shared_expert_gate', False) - - # Handle rotary_interleaved for models like Qwen3.5 with mrope - rotary_interleaved = mg_config_dict.get('rotary_interleaved', False) - partial_rotary_factor = mg_config_dict.get('partial_rotary_factor') - - # Determine position_embedding_type - position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope') - apply_rope_fusion = True - if position_embedding_type != 'rope' or rotary_interleaved: - apply_rope_fusion = False - config = TransformerConfig( num_layers=num_layers, hidden_size=mg_config_dict['hidden_size'], @@ -608,10 +597,10 @@ def _get_base_model(m): attention_dropout=0.0, masked_softmax_fusion=True, bias_dropout_fusion=True, - apply_rope_fusion=apply_rope_fusion, + apply_rope_fusion=True, attention_softmax_in_fp32=True, attention_backend=AttnBackend.flash, - rotary_interleaved=rotary_interleaved, + calculate_per_token_loss=True, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, recompute_method=recompute_method, @@ -622,32 +611,39 @@ def _get_base_model(m): if exists('megatron_core>=0.13'): config.expert_tensor_parallel_size = self.etp_size - # Store layer_types on config for Qwen3-Next/Qwen3.5 heterogeneous layers + if mg_config_dict.get('use_shared_expert_gate'): + config.moe_use_shared_expert_gate = True + if mg_config_dict.get('rotary_interleaved'): + config.rotary_interleaved = True + partial_rotary_factor = mg_config_dict.get('partial_rotary_factor') + if partial_rotary_factor is not None: + config.rotary_percent = partial_rotary_factor + config.apply_rope_fusion = False + mrope_section = mg_config_dict.get('mrope_section') + if mrope_section is not None: + config.mrope_section = mrope_section + if mg_config_dict.get('mrope_interleaved'): + config.mrope_interleaved = True + layer_types = mg_config_dict.get('layer_types') if layer_types is not None: config.layer_types = layer_types - self.layer_types = layer_types - for attr in ['linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', - 'linear_value_head_dim', 'linear_conv_kernel_dim']: + for attr in ('linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', + 'linear_value_head_dim', 'linear_conv_kernel_dim'): val = mg_config_dict.get(attr) if val is not None: setattr(config, attr, val) - # Store partial_rotary_factor on config - if partial_rotary_factor is not 
None: - config.partial_rotary_factor = partial_rotary_factor - - # Store args reference on config for HuggingFaceModule compatibility - config.args = self - self.config = config # Get layer spec moe_grouped_gemm = num_experts > 0 if layer_types is not None: - from .model.gpts.qwen3_next import get_qwen3_next_layer_spec, Qwen3NextGatedDeltaNet, Qwen3_5MoeGatedDeltaNet - hf_model_type = mg_config_dict.get('hf_model_type', '') - if hf_model_type in {'qwen3_5_moe', 'qwen3_5'}: + from .model.gpts.qwen3_next import (Qwen3_5MoeGatedDeltaNet, Qwen3NextGatedDeltaNet, + get_qwen3_next_layer_spec) + llm_model_type = mg_config_dict.get('llm_model_type', '') + hf_mt = mg_config_dict.get('hf_model_type', '') + if 'qwen3_5_moe' in (llm_model_type, hf_mt): gated_delta_net_cls = Qwen3_5MoeGatedDeltaNet else: gated_delta_net_cls = Qwen3NextGatedDeltaNet @@ -663,34 +659,14 @@ def _get_base_model(m): raise RuntimeError( 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.') - # Set shared_expert_gate if needed - if use_shared_expert_gate and num_experts > 0 and moe_shared_expert_intermediate_size: - if hasattr(layer_spec, 'layer_specs'): - for ls in layer_spec.layer_specs: - if hasattr(ls.submodules.mlp.submodules, 'shared_experts'): - ls.submodules.mlp.submodules.shared_experts.params = {'gate': True} - elif hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): - layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - # Create model - text_config = hf_config - if hasattr(hf_config, 'text_config') and hf_config.text_config is not None: - text_config = hf_config.text_config - max_seq_length = getattr(text_config, 'max_position_embeddings', 4096) + max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) rotary_base = mg_config_dict.get('rotary_base', 10000) - rotary_percent = 1.0 - if partial_rotary_factor is not None: - rotary_percent = partial_rotary_factor + position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope') extra_init_args = {} - rope_scaling_dict = getattr(text_config, 'rope_scaling', None) or getattr(text_config, 'rope_parameters', None) - if rope_scaling_dict is not None and isinstance(rope_scaling_dict, dict): - if 'factor' in rope_scaling_dict: - extra_init_args['seq_len_interpolation_factor'] = rope_scaling_dict['factor'] - if 'rope_theta' in rope_scaling_dict: - rotary_base = int(rope_scaling_dict['rope_theta']) - mrope_section = mg_config_dict.get('mrope_section') - if position_embedding_type == 'mrope' and mrope_section is not None: - config.mrope_section = mrope_section + if hasattr(hf_config, + 'rope_scaling') and hf_config.rope_scaling is not None and 'factor' in hf_config.rope_scaling: + extra_init_args = {'seq_len_interpolation_factor': hf_config.rope_scaling['factor']} vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: model = [] @@ -710,7 +686,6 @@ def _get_base_model(m): parallel_output=True, share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), position_embedding_type=position_embedding_type, - rotary_percent=rotary_percent, rotary_base=rotary_base, **extra_init_args) model.append(_model) @@ -726,7 +701,6 @@ def _get_base_model(m): parallel_output=True, share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False), position_embedding_type=position_embedding_type, - rotary_percent=rotary_percent, rotary_base=rotary_base, **extra_init_args, ) diff --git 
a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index 83828bad..e8ed547e 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -3,6 +3,8 @@ import math import os +import re +import shutil import torch import torch.distributed as dist import torch.nn.functional as F @@ -109,6 +111,35 @@ def get_hf_mlp_prefix(self, layer_idx): def _get_hf_mlp(self, layer_idx): return getattr(self.hf_layers[layer_idx], self.get_hf_mlp_prefix(layer_idx)) + _HF_GROUPED_FALSE_TYPES = { + 'qwen2_moe', + 'qwen3_moe', + 'deepseek_v2', + 'deepseek_v3', + 'dots1', + 'ernie4_5_moe', + 'glm4_moe', + 'glm4_moe_lite', + 'glm4v_moe', + 'minimax_m2', + 'olmoe', + 'qwen3_next', + 'kimi_vl', + 'qwen3_omni_moe', + 'qwen3_5_moe', + } + + def _get_hf_grouped(self, is_mtp_layer: bool = False): + if self.args.hf_model_type in self._HF_GROUPED_FALSE_TYPES: + return False, False + return None, None + + def _get_transpose(self): + if self.args.hf_model_type in {'qwen3_vl_moe', 'gpt_oss', 'llama4'}: + return True + else: + return False + def _init_meta_hf_model(self): import copy @@ -681,6 +712,7 @@ def _set_moe_state( hf_prefix: str, layer_idx: int, to_mcore: bool, + is_mtp_layer: bool = False, ): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) @@ -726,21 +758,33 @@ def _set_moe_state( else: mg_experts = None hf_state_dict.update( - self._set_mlp_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank=ep_rank)) + self._set_mlp_state( + mg_experts, + hf_state_dict, + 'experts.', + layer_idx, + to_mcore, + ep_rank=ep_rank, + is_mtp_layer=is_mtp_layer)) if to_mcore: hf_state_dict = {} else: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_mlp_state(self, - mg_mlp, - hf_state_dict, - hf_prefix: str, - layer_idx: int, - to_mcore: bool, - ep_rank: Optional[int] = None, - hf_mlp=None): + def _set_mlp_state( + self, + mg_mlp, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + ep_rank: Optional[int] = None, + hf_mlp=None, + is_mtp_layer: bool = False, + ): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) if hf_mlp is None: hf_mlp = self._get_hf_mlp(layer_idx) is_expert = ep_rank is not None @@ -748,17 +792,33 @@ def _set_mlp_state(self, hf_grouped = False args = self.args if is_expert: - hf_grouped = not hasattr(hf_mlp.experts, '__len__') - hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0] + hf_mlp = hf_mlp.experts + if to_mcore: + pattern = r'\d+\.down_proj' + hf_grouped = not any(re.match(pattern, k) is not None for k in hf_state_dict.keys()) + else: + hf_grouped = not hasattr(hf_mlp, '__len__') + if hasattr(hf_mlp, '__len__'): + hf_mlp = hf_mlp[0] num_local_experts = args.num_experts // self.ep_size - # TODO: Temporary modification for transformers 5.0 compatibility with GLM4.6v, to be fixed later - is_gate_up = hasattr(hf_mlp, 'gate_up_proj') - if self.is_transformers_5 and self.args.hf_model_type in {'glm4v_moe', 'glm4_moe_lite'}: - hf_grouped = False - is_gate_up = False - if to_mcore or hf_grouped: - hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + if to_mcore: + is_gate_up = any('gate_up_proj' in k for k in hf_state_dict.keys()) else: + is_gate_up = hasattr(hf_mlp, 'gate_up_proj') + if self.is_transformers_5 and not to_mcore and is_expert: + _hf_grouped, _is_gate_up = self._get_hf_grouped(is_mtp_layer) + if _hf_grouped is not None: + hf_grouped = _hf_grouped + 
if _is_gate_up is not None: + is_gate_up = _is_gate_up + + need_transpose = True + if self.is_transformers_5 and hf_grouped: + need_transpose = self._get_transpose() + + if hf_grouped and not to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + elif not to_mcore: hf_state_dict = {} # linear_fc1 if to_mcore: @@ -829,11 +889,15 @@ def _set_mlp_state(self, gate_up_proj_weight = self.mxfp4_quantizer.convert(blocks, scales) else: gate_up_proj_weight = hf_state_dict['gate_up_proj'].load() - gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + if need_transpose: + gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + gate_up_proj_weight = gate_up_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] if has_scale_inv: - gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load().transpose(1, 2) + gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load() + if need_transpose: + gate_up_scale_inv = gate_up_scale_inv.transpose(1, 2) gate_up_scale_inv = gate_up_scale_inv[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] if fc1_bias is not None: @@ -989,7 +1053,8 @@ def _set_mlp_state(self, if is_gate_up: if is_expert: if hf_grouped: - gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) + if need_transpose: + gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2) if 'gate_up_proj' in hf_state_dict: gate_up_proj_weight = torch.concat( [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) @@ -1003,7 +1068,8 @@ def _set_mlp_state(self, del new_gate_up_proj_weight, gate_proj_weight, up_proj_weight hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone() if scale_inv is not None: - scale_inv = scale_inv.transpose(1, 2) + if need_transpose: + scale_inv = scale_inv.transpose(1, 2) if 'gate_up_proj_scale_inv' in hf_state_dict: scale_inv = torch.concat([hf_state_dict['gate_up_proj_scale_inv'], scale_inv], dim=0) @@ -1095,12 +1161,15 @@ def _set_mlp_state(self, down_proj_weight = self.mxfp4_quantizer.convert(blocks, scales) else: down_proj_weight = hf_state_dict['down_proj'].load() - down_proj_weight = down_proj_weight.transpose(1, 2) + if need_transpose: + down_proj_weight = down_proj_weight.transpose(1, 2) down_proj_weight = down_proj_weight[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts].reshape( -1, down_proj_weight.shape[-1]) if has_scale_inv: - down_scale_inv = hf_state_dict['down_proj_scale_inv'].load().transpose(1, 2) + down_scale_inv = hf_state_dict['down_proj_scale_inv'].load() + if need_transpose: + down_scale_inv = down_scale_inv.transpose(1, 2) down_scale_inv = down_scale_inv[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts].reshape(-1, down_scale_inv.shape[-1]) if fc2_bias is not None: @@ -1180,12 +1249,14 @@ def _set_mlp_state(self, del fc2_weight, fc2_bias if down_proj_weight is not None: if hf_grouped: - down_proj_weight = down_proj_weight.transpose(1, 2) + if need_transpose: + down_proj_weight = down_proj_weight.transpose(1, 2) if 'down_proj' in hf_state_dict: down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0) hf_state_dict['down_proj'] = down_proj_weight.clone() if scale_inv is not None: - scale_inv = scale_inv.transpose(1, 2) + if need_transpose: + scale_inv = scale_inv.transpose(1, 2) if 'down_proj_scale_inv' in hf_state_dict: scale_inv = torch.concat([hf_state_dict['down_proj_scale_inv'], scale_inv], dim=0) hf_state_dict['down_proj_scale_inv'] = scale_inv.clone() @@ -1253,13 +1324,15 @@ def _set_layer_attn(self, mg_layer, 
hf_state_dict, layer_idx: int, to_mcore: boo 'input_layernorm.weight', to_mcore) return hf_state_dict - def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp_layer: bool = False): hf_mlp_prefix = self.get_hf_mlp_prefix(layer_idx) hf_mlp = self._get_hf_mlp(layer_idx) is_moe = self._is_moe(hf_mlp.state_dict()) mg_mlp = None if mg_layer is None else mg_layer.mlp if is_moe: - hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore)) + hf_state_dict.update( + self._set_moe_state( + mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp_layer=is_mtp_layer)) self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', to_mcore) else: @@ -1445,7 +1518,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore, is_mtp_layer=True)) if to_mcore: hf_state_dict = {} else: diff --git a/src/twinkle/model/megatron/model/gpts/qwen3_next.py b/src/twinkle/model/megatron/model/gpts/qwen3_next.py index bf980854..78f1f7e1 100644 --- a/src/twinkle/model/megatron/model/gpts/qwen3_next.py +++ b/src/twinkle/model/megatron/model/gpts/qwen3_next.py @@ -139,8 +139,10 @@ def forward( try: from megatron.core.utils import nvtx_range_pop, nvtx_range_push except ImportError: + def nvtx_range_pop(*args, **kwargs): return + def nvtx_range_push(*args, **kwargs): return @@ -194,7 +196,14 @@ def nvtx_range_push(*args, **kwargs): raise ValueError('CUDA graphs must use flash decode with static batching!') result = self._adjust_key_value_for_inference( - inference_context, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset, + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, ) if mcore_013: query, key, value, rotary_pos_emb, attn_mask_type, block_table = result @@ -217,19 +226,19 @@ def nvtx_range_push(*args, **kwargs): q_pos_emb, k_pos_emb = rotary_pos_emb if packed_seq_params is not None: - cu_seqlens_q = (packed_seq_params.cu_seqlens_q_padded - if packed_seq_params.cu_seqlens_q_padded is not None - else packed_seq_params.cu_seqlens_q) - cu_seqlens_kv = (packed_seq_params.cu_seqlens_kv_padded - if packed_seq_params.cu_seqlens_kv_padded is not None - else packed_seq_params.cu_seqlens_kv) + cu_seqlens_q = ( + packed_seq_params.cu_seqlens_q_padded + if packed_seq_params.cu_seqlens_q_padded is not None else packed_seq_params.cu_seqlens_q) + cu_seqlens_kv = ( + packed_seq_params.cu_seqlens_kv_padded + if packed_seq_params.cu_seqlens_kv_padded is not None else packed_seq_params.cu_seqlens_kv) else: cu_seqlens_q = cu_seqlens_kv = None if q_pos_emb is not None: if inference_context is None or inference_context.is_static_batching(): - query = apply_rotary_pos_emb(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, - **kwargs_cp) + query = apply_rotary_pos_emb( + query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, **kwargs_cp) else: query = 
inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q, **kwargs_cp) @@ -240,22 +249,40 @@ def nvtx_range_push(*args, **kwargs): nvtx_range_push(suffix='core_attention') if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( - query, key, value, attention_mask, attn_mask_type=attn_mask_type, - attention_bias=attention_bias, packed_seq_params=packed_seq_params, + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, ) else: if inference_context is None or inference_context.is_static_batching(): core_attn_out = self.core_attention( - query, key, value, attention_mask, attn_mask_type=attn_mask_type, - attention_bias=attention_bias, packed_seq_params=packed_seq_params, + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, ) else: q, k, v = (query, key, value) cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() cu_kv_lengths, kv_lengths, kv_lengths_decode_only, max_seqlen_k = (inference_context.cu_kv_lengths()) core_attn_out = self.flash_decode_and_prefill( - q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths, - kv_lengths, kv_lengths_decode_only, block_table, + q, + k, + v, + max_seqlen_q, + max_seqlen_k, + cu_query_lengths, + cu_kv_lengths, + kv_lengths, + kv_lengths_decode_only, + block_table, ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') @@ -315,8 +342,8 @@ def _gated_delta_net_forward(self, hidden_states: torch.Tensor, **kwargs): if thd_format and not getattr(args, 'packing', False): new_hidden_states = hidden_states.new_zeros( (packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item(), hidden_states.shape[-1])) - attention_mask = hidden_states.new_zeros( - (packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item()), dtype=torch.bool) + attention_mask = hidden_states.new_zeros((packed_seq_params.num_samples, packed_seq_params.max_seqlen_q.item()), + dtype=torch.bool) cu_seqlens_q = packed_seq_params.cu_seqlens_q for i in range(packed_seq_params.num_samples): start, end = cu_seqlens_q[i], cu_seqlens_q[i + 1] @@ -438,6 +465,9 @@ def get_qwen3_next_layer_spec(config, args, gated_delta_net_cls): layer_spec.submodules.self_attention.submodules.q_layernorm = layer_norm_impl if hasattr(layer_spec.submodules.self_attention.submodules, 'k_layernorm'): layer_spec.submodules.self_attention.submodules.k_layernorm = layer_norm_impl + if (getattr(config, 'moe_use_shared_expert_gate', False) and hasattr(layer_spec.submodules, 'mlp') + and hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts')): + layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} layer_specs.append(layer_spec) local_layer_specs = get_local_layer_specs(config, layer_specs) From 79302ddb966a67fb87ccbfb46438645b493ab072 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:07:33 +0800 Subject: [PATCH 03/14] fix --- cookbook/rl/grpo_qwen3_5.py | 176 ------------------ src/twinkle/model/megatron/args.py | 69 +++---- src/twinkle/model/megatron/model/__init__.py | 2 +- .../model/megatron/model/gpt_bridge.py | 29 +-- .../model/megatron/model/gpts/qwen3_next.py | 22 +++ .../model/megatron/model/mm_gpts/qwen3_5.py | 7 + src/twinkle/model/megatron/model/register.py | 52 +++++- src/twinkle/processor/base.py | 2 + 8 files changed, 110 insertions(+), 249 deletions(-) 
delete mode 100644 cookbook/rl/grpo_qwen3_5.py diff --git a/cookbook/rl/grpo_qwen3_5.py b/cookbook/rl/grpo_qwen3_5.py deleted file mode 100644 index 8d9eec8a..00000000 --- a/cookbook/rl/grpo_qwen3_5.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -from typing import List, Tuple, Dict, Any - -from peft import LoraConfig - -import twinkle -from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger -from twinkle.advantage import GRPOAdvantage -from twinkle.checkpoint_engine import CheckpointEngineManager -from twinkle.data_format import SamplingParams -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.processor import InputProcessor -from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward -from twinkle.sampler import vLLMSampler -from twinkle.template import Template -from twinkle.metric import CompletionRewardMetric -from twinkle.preprocessor.llm import GSM8KProcessor - -logger = get_logger() - -MODEL_ID = '/root/.cache/modelscope/hub/models/Qwen/Qwen3.5-35B-A3B' -USE_MEGATRON = True - -MODEL_GPUS = 4 -SAMPLER_GPUS = 4 -NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS - -NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4)) -MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 2048)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) -MAX_STEPS = int(os.environ.get('MAX_STEPS', 20)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) -GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) -ADAPTER_NAME = 'default' - - -def create_gsm8k_dataset(): - dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Template', model_id=MODEL_ID, max_length=1024) - dataset.map(GSM8KProcessor()) - dataset.encode(add_generation_prompt=True) - return dataset - - -def compute_rewards( - trajectories: List[Dict[str, Any]], -) -> Tuple[List[float], List[float], List[float]]: - accuracy_reward_fn = GSM8KAccuracyReward() - format_reward_fn = GSM8KFormatReward() - - accuracy_rewards = accuracy_reward_fn(trajectories) - format_rewards = format_reward_fn(trajectories) - total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)] - return total_rewards, format_rewards, accuracy_rewards - - -def main(): - device_groups = [ - DeviceGroup(name='model', ranks=list(range(4, 8)), device_type='GPU'), - DeviceGroup(name='sampler', ranks=list(range(4)), device_type='GPU'), - ] - # tp=2, pp=2, ep=2 for model group (4 GPUs) - # dp = world_size / (tp * pp) = 4 / (2 * 2) = 1 - model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, tp_size=2, pp_size=2, ep_size=2) - sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) - twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) - - lora_config = LoraConfig(target_modules='all-linear', r=8, lora_alpha=16, lora_dropout=0.05) - - from twinkle.model.megatron import MegatronModel - model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16') - - model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) - model.set_optimizer('default', lr=LEARNING_RATE) - model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) - model.set_loss('GRPOLoss', epsilon=0.2) - 
model.set_processor(InputProcessor) - model.set_template('Template', model_id=MODEL_ID) - - sampler = vLLMSampler( - model_id=MODEL_ID, - engine_args={ - 'gpu_memory_utilization': 0.8, - 'max_model_len': 2048, - 'max_lora_rank': 8, - 'enable_lora': True, - }, - device_mesh=sampler_mesh, - remote_group='sampler', - ) - sampler.set_template(Template, model_id=MODEL_ID) - - ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) - - GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS - dataloader = DataLoader( - dataset=create_gsm8k_dataset, - batch_size=GLOBAL_BATCH_SIZE, - min_batch_size=GLOBAL_BATCH_SIZE, - device_mesh=model_mesh, - remote_group='model', - ) - advantage_fn = GRPOAdvantage() - metrics = CompletionRewardMetric() - - sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS) - - optim_step = 0 - logger.info(get_device_placement()) - - for batch in dataloader: - if optim_step >= MAX_STEPS: - break - metrics.reset() - global_prompts = batch if isinstance(batch, list) else [batch] - ckpt_manager.sync_weights(merge_and_sync=False) - sampler.reset_prefix_cache() - sample_response = sampler.sample( - global_prompts * NUM_GENERATIONS, - sampling_params, - num_samples=1, - ) - - all_input_data: List[Dict[str, Any]] = [] - all_old_logps: List[List[float]] = [] - all_completion_lengths: List[int] = [] - - for sequence in sample_response.sequences: - all_input_data.append(sequence.new_input_feature) - all_old_logps.append(sequence.logprobs) - all_completion_lengths.append(len(sequence.tokens)) - total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data) - metrics.accumulate( - completion_lengths=all_completion_lengths, - rewards={ - 'total': total_rewards, - 'format': format_rewards, - 'accuracy': accuracy_rewards, - }, - ) - - advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() - - total_completions = len(all_input_data) - for mb_start in range(0, total_completions, MINI_BATCH_SIZE): - mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) - mb_inputs = all_input_data[mb_start:mb_end] - mb_old_logps = all_old_logps[mb_start:mb_end] - mb_advantages = advantages[mb_start:mb_end] - - model.forward_backward( - inputs=mb_inputs, - old_logps=mb_old_logps, - advantages=mb_advantages, - micro_batch_size=MICRO_BATCH_SIZE, - ) - model.clip_grad_and_step() - optim_step += 1 - - if optim_step >= MAX_STEPS: - break - log_dict = metrics.calculate() - log_dict.update(model.calculate_metric(is_training=True)) - metrics.reset() - logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') - - logger.info(f'Training completed. 
optim_steps={optim_step}') - model.save('grpo-qwen35-gsm8k-checkpoint') - - -if __name__ == '__main__': - main() diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index d23d8020..8e5c7bfd 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -161,7 +161,6 @@ class TwinkleMegatronArgs: # ========================================================================= untie_embeddings_and_output_weights: bool = True max_shard_size: str = '5GB' - llm_model_type: str = 'gpt' # For transformers 5.0 compatibility use_cpu_initialization: bool = False def __post_init__(self): @@ -335,10 +334,12 @@ def from_hf_config( # Get rope_scaling rope_scaling = getattr(text_config, 'rope_scaling', None) - # Detect multimodal model model_type = getattr(hf_config, 'model_type', 'qwen2') - is_multimodal = ('vl' in model_type.lower() or 'vision' in model_type.lower() or 'omni' in model_type.lower() - or hasattr(hf_config, 'vision_config')) + + # Detect multimodal model from the registered MegatronModelMeta + from .model.register import get_megatron_model_meta + model_meta = get_megatron_model_meta(model_type) + is_multimodal = model_meta.is_multimodal if model_meta is not None else False # Determine QKV bias if hasattr(text_config, 'attention_bias'): @@ -441,7 +442,6 @@ def create_model(self, ) -> List[nn.Module]: if self._model is not None: return self._model from megatron.core import parallel_state as mpu - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnBackend @@ -482,6 +482,8 @@ def finalize_model_grads_for_lora(model, *args, **kwargs): from megatron.core.distributed import DistributedDataParallel as MegatronDDP from peft import PeftModel as _PeftModel + # Check if model is DDP-wrapped (has ddp_config) + # Need to unwrap PeftModel to check the underlying model def _get_base_model(m): if isinstance(m, _PeftModel): return _get_base_model(m.base_model.model) @@ -582,24 +584,26 @@ def _get_base_model(m): params_dtype=self.params_dtype, fp16=self.params_dtype == torch.float16, bf16=self.params_dtype == torch.bfloat16, - pipeline_dtype=self.params_dtype, + pipeline_dtype=self.params_dtype, # Required when using pipeline parallelism use_cpu_initialization=self.use_cpu_initialization, add_qkv_bias=self.add_qkv_bias, variable_seq_lengths=self.variable_seq_lengths, add_bias_linear=not mg_config_dict.get('disable_bias_linear', True), gated_linear_unit=use_swiglu, - activation_func=activation_func, - bias_activation_fusion=bias_activation_fusion, + activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise + bias_activation_fusion=bias_activation_fusion, # Fused SwiGLU for performance normalization='RMSNorm', layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6), qk_layernorm=mg_config_dict.get('qk_layernorm', False), hidden_dropout=0.0, attention_dropout=0.0, - masked_softmax_fusion=True, - bias_dropout_fusion=True, - apply_rope_fusion=True, - attention_softmax_in_fp32=True, - attention_backend=AttnBackend.flash, + # Performance optimizations + masked_softmax_fusion=True, # Fused attention softmax + bias_dropout_fusion=True, # Fused bias + dropout + apply_rope_fusion=True, # Fused RoPE application + attention_softmax_in_fp32=True, # Numerical stability + attention_backend=AttnBackend.flash, # FlashAttention for speed + # Activation recomputation for memory efficiency 
calculate_per_token_loss=True, recompute_granularity=self.recompute_granularity, recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, @@ -625,39 +629,18 @@ def _get_base_model(m): if mg_config_dict.get('mrope_interleaved'): config.mrope_interleaved = True - layer_types = mg_config_dict.get('layer_types') - if layer_types is not None: - config.layer_types = layer_types - for attr in ('linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', - 'linear_value_head_dim', 'linear_conv_kernel_dim'): - val = mg_config_dict.get(attr) - if val is not None: - setattr(config, attr, val) - self.config = config - # Get layer spec - moe_grouped_gemm = num_experts > 0 - if layer_types is not None: - from .model.gpts.qwen3_next import (Qwen3_5MoeGatedDeltaNet, Qwen3NextGatedDeltaNet, - get_qwen3_next_layer_spec) - llm_model_type = mg_config_dict.get('llm_model_type', '') - hf_mt = mg_config_dict.get('hf_model_type', '') - if 'qwen3_5_moe' in (llm_model_type, hf_mt): - gated_delta_net_cls = Qwen3_5MoeGatedDeltaNet - else: - gated_delta_net_cls = Qwen3NextGatedDeltaNet - layer_spec = get_qwen3_next_layer_spec(config, self, gated_delta_net_cls) + # Delegate model-specific config & layer spec construction to the loader + loader = model_meta.loader() if model_meta else None + if loader is not None: + loader.post_config(config, self, mg_config_dict) + layer_spec = loader.get_layer_spec(config, self, mg_config_dict) else: - try: - layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=mg_config_dict.get('num_experts'), - moe_grouped_gemm=moe_grouped_gemm, - qk_layernorm=mg_config_dict.get('qk_layernorm', False), - ) - except (ImportError, AttributeError): - raise RuntimeError( - 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.') + from .model.register import MegatronModelLoader + default_loader = MegatronModelLoader() + default_loader.post_config(config, self, mg_config_dict) + layer_spec = default_loader.get_layer_spec(config, self, mg_config_dict) # Create model max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) diff --git a/src/twinkle/model/megatron/model/__init__.py b/src/twinkle/model/megatron/model/__init__.py index 28bae1ad..c61acef9 100644 --- a/src/twinkle/model/megatron/model/__init__.py +++ b/src/twinkle/model/megatron/model/__init__.py @@ -1,4 +1,4 @@ from . 
import gpts, mm_gpts from .constant import MegatronModelType from .gpt_bridge import GPTBridge -from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model +from .register import MegatronModelLoader, MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index e8ed547e..7f592dee 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -129,7 +129,7 @@ def _get_hf_mlp(self, layer_idx): 'qwen3_5_moe', } - def _get_hf_grouped(self, is_mtp_layer: bool = False): + def _get_hf_grouped(self): if self.args.hf_model_type in self._HF_GROUPED_FALSE_TYPES: return False, False return None, None @@ -712,7 +712,6 @@ def _set_moe_state( hf_prefix: str, layer_idx: int, to_mcore: bool, - is_mtp_layer: bool = False, ): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) @@ -758,14 +757,7 @@ def _set_moe_state( else: mg_experts = None hf_state_dict.update( - self._set_mlp_state( - mg_experts, - hf_state_dict, - 'experts.', - layer_idx, - to_mcore, - ep_rank=ep_rank, - is_mtp_layer=is_mtp_layer)) + self._set_mlp_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank=ep_rank)) if to_mcore: hf_state_dict = {} else: @@ -781,7 +773,6 @@ def _set_mlp_state( to_mcore: bool, ep_rank: Optional[int] = None, hf_mlp=None, - is_mtp_layer: bool = False, ): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) @@ -806,7 +797,7 @@ def _set_mlp_state( else: is_gate_up = hasattr(hf_mlp, 'gate_up_proj') if self.is_transformers_5 and not to_mcore and is_expert: - _hf_grouped, _is_gate_up = self._get_hf_grouped(is_mtp_layer) + _hf_grouped, _is_gate_up = self._get_hf_grouped() if _hf_grouped is not None: hf_grouped = _hf_grouped if _is_gate_up is not None: @@ -904,7 +895,7 @@ def _set_mlp_state( gate_up_proj_bias = hf_state_dict['gate_up_proj_bias'].load() gate_up_proj_bias = gate_up_proj_bias[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] - if args.llm_model_type == 'gpt_oss': + if args.hf_model_type == 'gpt_oss': gate_proj_weight = gate_up_proj_weight[:, ::2] up_proj_weight = gate_up_proj_weight[:, 1::2] gate_proj_bias, up_proj_bias = gate_up_proj_bias[:, ::2], gate_up_proj_bias[:, 1::2] @@ -1059,7 +1050,7 @@ def _set_mlp_state( gate_up_proj_weight = torch.concat( [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0) is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_experts - if args.llm_model_type == 'gpt_oss' and is_last_ckpt: + if args.hf_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_weight, up_proj_weight = gate_up_proj_weight.chunk(2, dim=2) new_gate_up_proj_weight = torch.empty_like(gate_up_proj_weight) new_gate_up_proj_weight[..., ::2] = gate_proj_weight @@ -1079,7 +1070,7 @@ def _set_mlp_state( if 'gate_up_proj_bias' in hf_state_dict: gate_up_proj_bias = torch.concat( [hf_state_dict['gate_up_proj_bias'], gate_up_proj_bias], dim=0) - if args.llm_model_type == 'gpt_oss' and is_last_ckpt: + if args.hf_model_type == 'gpt_oss' and is_last_ckpt: gate_proj_bias, up_proj_bias = gate_up_proj_bias.chunk(2, dim=1) new_gate_up_proj_bias = torch.empty_like(gate_up_proj_bias) new_gate_up_proj_bias[:, ::2] = gate_proj_bias @@ -1324,15 +1315,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo 'input_layernorm.weight', to_mcore) return hf_state_dict - def _set_layer_mlp(self, mg_layer, hf_state_dict, 
layer_idx: int, to_mcore: bool, is_mtp_layer: bool = False): + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): hf_mlp_prefix = self.get_hf_mlp_prefix(layer_idx) hf_mlp = self._get_hf_mlp(layer_idx) is_moe = self._is_moe(hf_mlp.state_dict()) mg_mlp = None if mg_layer is None else mg_layer.mlp if is_moe: - hf_state_dict.update( - self._set_moe_state( - mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp_layer=is_mtp_layer)) + hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', to_mcore) else: @@ -1518,7 +1507,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore, is_mtp_layer=True)) + hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore)) if to_mcore: hf_state_dict = {} else: diff --git a/src/twinkle/model/megatron/model/gpts/qwen3_next.py b/src/twinkle/model/megatron/model/gpts/qwen3_next.py index 78f1f7e1..b589a0f2 100644 --- a/src/twinkle/model/megatron/model/gpts/qwen3_next.py +++ b/src/twinkle/model/megatron/model/gpts/qwen3_next.py @@ -24,6 +24,7 @@ from twinkle import get_logger from twinkle.model.megatron.args import get_args +from twinkle.model.megatron.model.register import MegatronModelLoader mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') mcore_015 = version.parse(megatron.core.__version__) >= version.parse('0.15.0rc0') @@ -484,3 +485,24 @@ def get_qwen3_next_mtp_block_spec(config, transformer_layer_spec, **kwargs): layer_spec.submodules.hnorm = Qwen3NextRMSNorm layer_spec.submodules.layer_norm = Qwen3NextRMSNorm return mtp_block_spec + + +class Qwen3NextLoader(MegatronModelLoader): + """Loader for Qwen3-Next models with heterogeneous linear/full attention layers.""" + gated_delta_net = Qwen3NextGatedDeltaNet + + def post_config(self, config, args, mg_config_dict): + layer_types = mg_config_dict.get('layer_types') + if layer_types is not None: + config.layer_types = layer_types + for attr in ('linear_num_value_heads', 'linear_num_key_heads', 'linear_key_head_dim', + 'linear_value_head_dim', 'linear_conv_kernel_dim'): + val = mg_config_dict.get(attr) + if val is not None: + setattr(config, attr, val) + + def get_layer_spec(self, config, args, mg_config_dict): + return get_qwen3_next_layer_spec(config, args, self.gated_delta_net) + + def get_mtp_block_spec(self, config, layer_spec, **kwargs): + return get_qwen3_next_mtp_block_spec(config, layer_spec, **kwargs) diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py index 489fea86..f0dec64f 100644 --- a/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py +++ b/src/twinkle/model/megatron/model/mm_gpts/qwen3_5.py @@ -9,6 +9,7 @@ from twinkle.utils.torch_utils import to_device from ..constant import MegatronModelType, ModelType from ..gpt_bridge import GPTBridge, MultimodalGPTBridge +from ..gpts.qwen3_next import Qwen3_5MoeGatedDeltaNet, Qwen3NextLoader from ..register import MegatronModelMeta, register_megatron_model from .utils import HuggingFaceModule @@ -146,6 +147,11 @@ 
def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state except ImportError: _auto_model_cls = None + +class Qwen3_5MoeLoader(Qwen3NextLoader): + gated_delta_net = Qwen3_5MoeGatedDeltaNet + + register_megatron_model( MegatronModelMeta( MegatronModelType.qwen3_5, @@ -156,4 +162,5 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state bridge_cls=Qwen3_5Bridge, visual_cls=Qwen3_5Vit, auto_model_cls=_auto_model_cls, + loader=Qwen3_5MoeLoader, )) diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py index f7ef917d..ab59569d 100644 --- a/src/twinkle/model/megatron/model/register.py +++ b/src/twinkle/model/megatron/model/register.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import torch.nn as nn -from argparse import ArgumentParser from dataclasses import dataclass -from typing import Callable, List, Optional, Type +from typing import List, Optional, Type from .constant import MLLMMegatronModelType @@ -17,14 +16,9 @@ class MegatronModelMeta: is_multimodal: bool = False bridge_cls: Optional[Type] = None model_cls: Optional[Type[nn.Module]] = None - get_transformer_layer_spec: Optional[Callable] = None - model_provider: Optional[Callable[[], nn.Module]] = None visual_cls: Optional[Type[nn.Module]] = None - get_mtp_block_spec: Optional[Callable] = None - # AutoModel class for loading HF model (AutoModelForCausalLM for text, AutoModel for multimodal) auto_model_cls: Optional[Type] = None - - extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None + loader: Optional[Type['MegatronModelLoader']] = None def __post_init__(self): if self.megatron_model_type in MLLMMegatronModelType.__dict__: @@ -39,11 +33,51 @@ def __post_init__(self): if self.auto_model_cls is None: from transformers import AutoModel, AutoModelForCausalLM self.auto_model_cls = AutoModel if self.is_multimodal else AutoModelForCausalLM + if self.loader is None: + self.loader = MegatronModelLoader + + +class MegatronModelLoader: + """Default loader that builds TransformerConfig + layer specs for a model. + + Subclass this to customize layer spec construction (e.g. heterogeneous + attention types, custom layer norms). Register the subclass via + ``MegatronModelMeta(loader=MyLoader)``. + """ + + def get_layer_spec(self, config, args, mg_config_dict): + """Build a transformer layer spec from *config* (``TransformerConfig``). + + The default implementation delegates to Megatron-Core's + ``get_gpt_layer_with_transformer_engine_spec``. + + Returns: + A ``ModuleSpec`` or ``TransformerBlockSubmodules`` instance. + """ + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + num_experts = mg_config_dict.get('num_experts', 0) or 0 + return get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=num_experts > 0, + qk_layernorm=mg_config_dict.get('qk_layernorm', False), + ) + + def get_mtp_block_spec(self, config, layer_spec, **kwargs): + """Build MTP block spec. Override for custom layer norms etc.""" + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + return get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=True, **kwargs) + + def post_config(self, config, args, mg_config_dict): + """Hook called after TransformerConfig is created but before layer specs. + + Use this to set model-specific config attributes (e.g. ``layer_types``, + ``moe_use_shared_expert_gate``). 
+ """ + pass def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = megatron_model_meta.megatron_model_type - # diff here if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.') MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index b75603bb..cedd3dc2 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -364,6 +364,8 @@ def collate_fn(self, variable_seq_lengths=False, **kwargs) -> List[InputFeature]: if len(inputs) == 1: + if self.framework == 'megatron' and 'attention_mask' in inputs[0]: + inputs[0]['attention_mask'] = self._create_4d_attention_mask([inputs[0]['attention_mask'].squeeze(0)]) return inputs if micro_batch_size is None: # normal collate From da232f7aa42e1d76fbdd3851d294fd393502f670 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:11:59 +0800 Subject: [PATCH 04/14] fix --- src/twinkle/model/megatron/args.py | 3 +++ src/twinkle/processor/base.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index 8e5c7bfd..18e0679d 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -609,7 +609,10 @@ def _get_base_model(m): recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None, recompute_method=recompute_method, recompute_num_layers=recompute_num_layers, + # Critical: Set finalize_model_grads_func for DP gradient synchronization + # Uses custom wrapper that handles both DDP and PEFT/LoRA models finalize_model_grads_func=finalize_model_grads_for_lora, + # MoE configuration **moe_kwargs, ) if exists('megatron_core>=0.13'): diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index cedd3dc2..fe9733f1 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -363,9 +363,7 @@ def collate_fn(self, micro_batch_size: Optional[int] = None, variable_seq_lengths=False, **kwargs) -> List[InputFeature]: - if len(inputs) == 1: - if self.framework == 'megatron' and 'attention_mask' in inputs[0]: - inputs[0]['attention_mask'] = self._create_4d_attention_mask([inputs[0]['attention_mask'].squeeze(0)]) + if len(inputs) == 1 and self.framework != 'megatron': return inputs if micro_batch_size is None: # normal collate From a66fb6441ac1a2433fc737f2e8843c035de19de2 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:29:47 +0800 Subject: [PATCH 05/14] rename script --- cookbook/megatron/{tp_moe_qwen35.py => qwen3_5.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename cookbook/megatron/{tp_moe_qwen35.py => qwen3_5.py} (100%) diff --git a/cookbook/megatron/tp_moe_qwen35.py b/cookbook/megatron/qwen3_5.py similarity index 100% rename from cookbook/megatron/tp_moe_qwen35.py rename to cookbook/megatron/qwen3_5.py From e4d2bac3535e0343bb9e67eee9656d9b5a9dfd8b Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:34:16 +0800 Subject: [PATCH 06/14] revert --- .../model/megatron/model/gpt_bridge.py | 50 +++++++------------ 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index a7fc9d10..d04ed289 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ 
b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -111,29 +111,6 @@ def get_hf_mlp_prefix(self, layer_idx): def _get_hf_mlp(self, layer_idx): return getattr(self.hf_layers[layer_idx], self.get_hf_mlp_prefix(layer_idx)) - _HF_GROUPED_FALSE_TYPES = { - 'qwen2_moe', - 'qwen3_moe', - 'deepseek_v2', - 'deepseek_v3', - 'dots1', - 'ernie4_5_moe', - 'glm4_moe', - 'glm4_moe_lite', - 'glm4v_moe', - 'minimax_m2', - 'olmoe', - 'qwen3_next', - 'kimi_vl', - 'qwen3_omni_moe', - 'qwen3_5_moe', - } - - def _get_hf_grouped(self): - if self.args.hf_model_type in self._HF_GROUPED_FALSE_TYPES: - return False, False - return None, None - def _get_transpose(self): if self.args.hf_model_type in {'qwen3_vl_moe', 'gpt_oss', 'llama4'}: return True @@ -177,6 +154,15 @@ def _init_meta_hf_model(self): self.processor = auto_tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True) + def _get_hf_grouped(self): + if self.args.hf_model_type in { + 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', + 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe', + 'qwen3_5_moe' + }: + return False, False + return None, None + def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: if mg_key is None: return @@ -764,16 +750,14 @@ def _set_moe_state( hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_mlp_state( - self, - mg_mlp, - hf_state_dict, - hf_prefix: str, - layer_idx: int, - to_mcore: bool, - ep_rank: Optional[int] = None, - hf_mlp=None, - ): + def _set_mlp_state(self, + mg_mlp, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + ep_rank: Optional[int] = None, + hf_mlp=None): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) if hf_mlp is None: From d0532a8be7b682e7f573820433a32d4fd3d33a58 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:35:00 +0800 Subject: [PATCH 07/14] align --- .../model/megatron/model/gpt_bridge.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index d04ed289..ef6faed8 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -154,15 +154,6 @@ def _init_meta_hf_model(self): self.processor = auto_tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True) - def _get_hf_grouped(self): - if self.args.hf_model_type in { - 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', - 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe', - 'qwen3_5_moe' - }: - return False, False - return None, None - def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: if mg_key is None: return @@ -750,6 +741,16 @@ def _set_moe_state( hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict + + def _get_hf_grouped(self): + if self.args.hf_model_type in { + 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', + 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe', + 'qwen3_5_moe' + }: + return False, False + return None, None + def _set_mlp_state(self, mg_mlp, hf_state_dict, From a2f270d7bea0b62bf6e5893f6a502b1797895f4d Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:35:26 +0800 Subject: [PATCH 08/14] clean --- 
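Note (editorial sketch, not part of the diff below): the `_get_hf_grouped` table moved around in the
previous two patches exists because HF MoE checkpoints ship expert weights in two layouts — one
stacked 3D tensor per projection, or one small module per expert. The probe in `_set_mlp_state`
(`r'\d+\.down_proj'`) distinguishes them from state-dict keys. A toy illustration under made-up key
names:

    import re

    # Two ways HF checkpoints lay out MoE expert weights (illustrative keys only):
    grouped_keys = ['gate_up_proj', 'down_proj']                   # stacked [num_experts, ...] tensors
    per_expert_keys = ['0.gate_up_proj.weight', '0.down_proj.weight',
                       '1.gate_up_proj.weight', '1.down_proj.weight']

    def looks_grouped(state_dict_keys):
        # Mirrors the detection used in _set_mlp_state when loading to mcore:
        # if no key matches "<idx>.down_proj", the experts are stacked (grouped).
        return not any(re.match(r'\d+\.down_proj', k) for k in state_dict_keys)

    assert looks_grouped(grouped_keys) is True
    assert looks_grouped(per_expert_keys) is False

Model types listed in `_get_hf_grouped` are known per-expert layouts under transformers 5.x, so the
bridge skips the probe and forces `hf_grouped = False` for them.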
cookbook/megatron/qwen3_5.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cookbook/megatron/qwen3_5.py b/cookbook/megatron/qwen3_5.py index 4e845a55..406b2a3e 100644 --- a/cookbook/megatron/qwen3_5.py +++ b/cookbook/megatron/qwen3_5.py @@ -1,6 +1,4 @@ -import os from peft import LoraConfig -from tqdm import tqdm import twinkle from twinkle import DeviceMesh, get_device_placement, get_logger From 252b42a838d53a1ad42ed56f267b11d0a807e133 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Feb 2026 17:36:02 +0800 Subject: [PATCH 09/14] lint --- src/twinkle/model/megatron/model/gpt_bridge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index ef6faed8..cb66b746 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -741,7 +741,6 @@ def _set_moe_state( hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _get_hf_grouped(self): if self.args.hf_model_type in { 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', From 9212773d943f343eef843311d63f977c601c7346 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 27 Feb 2026 15:22:48 +0800 Subject: [PATCH 10/14] fix saving --- cookbook/megatron/qwen3_5.py | 7 ++- src/twinkle/model/megatron/args.py | 3 -- src/twinkle/model/megatron/megatron.py | 36 ++++++++++++-- .../model/megatron/model/gpt_bridge.py | 16 ++---- src/twinkle/utils/__init__.py | 2 +- src/twinkle/utils/transformers_utils.py | 49 ------------------- 6 files changed, 41 insertions(+), 72 deletions(-) diff --git a/cookbook/megatron/qwen3_5.py b/cookbook/megatron/qwen3_5.py index 406b2a3e..8807c066 100644 --- a/cookbook/megatron/qwen3_5.py +++ b/cookbook/megatron/qwen3_5.py @@ -13,7 +13,6 @@ logger = get_logger() MODEL_ID = 'Qwen/Qwen3.5-35B-A3B' -MAX_STEPS = 100 def train(): dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) @@ -37,9 +36,9 @@ def train(): if step % 5 == 0: metric = model.calculate_metric(is_training=True) logger.info(f'Step {step}/{len(dataloader)}, metric: {metric}') - if step >= MAX_STEPS: - break - model.save('last-checkpoint') + + # NOTE: you should merge lora for Qwen3.5 model when using Megatron + model.save('last-checkpoint', merge_lora=True) logger.info('Training completed.') diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index 68776fd9..80759d82 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -138,9 +138,6 @@ class TwinkleMegatronArgs: # ========================================================================= merge_lora: bool = False target_modules: List[str] = field(default_factory=list) - freeze_llm: bool = False - freeze_vit: bool = False - freeze_aligner: bool = False # ========================================================================= # FP8 quantization settings diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index aa74e72e..38b34a7e 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -819,6 +819,7 @@ def save(self, output_dir: Optional[str] = None, interval: int = 1, save_optimizer: bool = False, + merge_lora: bool = False, **kwargs): """Save model checkpoint. @@ -832,6 +833,9 @@ def save(self, interval: Save each *interval* steps. 
save_optimizer: If True, save optimizer + lr_scheduler + RNG state alongside the HF weights for checkpoint resumption. + merge_lora: If True, merge LoRA adapters into base weights and save + the full merged model instead of PEFT adapter format. The merge + is reversed after saving so training can continue. **kwargs: Additional arguments forwarded to the underlying save methods (e.g. ``adapter_name``). """ @@ -846,8 +850,16 @@ def save(self, output_dir = 'output' checkpoint_dir = os.path.join(output_dir, name) - # Always save HF-format weights (for inference / deployment). - self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name) + is_lora = (optimizer_config.adapter_name != _default_adapter_name) + + if merge_lora and is_lora: + self._merge_lora_adapters(optimizer_config.adapter_name) + self._save_hf_format(checkpoint_dir, _default_adapter_name) + self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) + self._unmerge_lora_adapters() + else: + self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name) + self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) # Optionally save mcore optimizer state (for training resumption). if save_optimizer: @@ -857,8 +869,6 @@ def save(self, **kwargs, ) - self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) - # Final synchronization to ensure all ranks complete save. if dist.is_initialized(): dist.barrier() @@ -1160,6 +1170,24 @@ def _read_iteration(tracker_path: str) -> int: iteration = iters_cuda[0].item() return iteration + def _merge_lora_adapters(self, adapter_name: str = 'default'): + """Merge LoRA adapters into base model weights.""" + from .tuners.lora import LoraParallelLinear + with torch.no_grad(): + for model in self.strategy.unwrap_model(self.model): + for module in model.modules(): + if isinstance(module, (LoraParallelLinear, LoraLinear)): + module.merge(adapter_names=[adapter_name]) + + def _unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + from .tuners.lora import LoraParallelLinear + with torch.no_grad(): + for model in self.strategy.unwrap_model(self.model): + for module in model.modules(): + if isinstance(module, (LoraParallelLinear, LoraLinear)): + module.unmerge() + def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=None): """Save in HuggingFace format using bridge adapter. diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index cb66b746..a5190ae0 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -20,7 +20,7 @@ from twinkle.hub import HubOperation from twinkle.model.megatron.args import get_args # Use twinkle's get_args from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger, - get_modules_to_not_convert, get_multimodal_target_regex, is_last_rank, requires) + get_modules_to_not_convert, is_last_rank, requires) logger = get_logger() @@ -306,6 +306,9 @@ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool): if self._is_peft_format: if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k: k = k.replace(f'{self._adapter_name}.', '') + if '.lora_A.' in k: + module_name = k.split('.lora_A.')[0].rsplit('.', 1)[-1] + self._peft_target_modules.add(module_name) new_state_dict[k] = v else: if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' 
in k: @@ -1568,16 +1571,7 @@ def save_weights(self, peft_config = copy(mg_models[0].peft_config[self._adapter_name]) if args.task_type == 'seq_cls': peft_config.task_type = 'SEQ_CLS' - if args.is_multimodal and 'all-linear' in args.target_modules: - peft_config.target_modules = get_multimodal_target_regex( - self.hf_model, - freeze_llm=args.freeze_llm, - freeze_vit=args.freeze_vit, - freeze_aligner=args.freeze_aligner, - include_embedding='all-embedding' in args.target_modules, - exclude_router='all-router' not in args.target_modules) - else: - peft_config.target_modules = self._peft_target_modules + peft_config.target_modules = self._peft_target_modules peft_config.modules_to_save = self._peft_modules_to_save peft_config.save_pretrained(output_dir) else: diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index edcefc34..1d7f9028 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -11,6 +11,6 @@ from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device -from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert, get_multimodal_target_regex +from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py index ee751c90..cecf6425 100644 --- a/src/twinkle/utils/transformers_utils.py +++ b/src/twinkle/utils/transformers_utils.py @@ -70,55 +70,6 @@ def _cond(name, module): return find_layers(model, _cond, sub_module=sub_module) - -def get_multimodal_target_regex( - model, - *, - freeze_llm: bool = False, - freeze_vit: bool = True, - freeze_aligner: bool = True, - include_embedding: bool = False, - exclude_router: bool = False, -) -> str: - import torch.nn as nn - model_arch = model.model_meta.model_arch - modules = [] - if not freeze_llm: - modules += model_arch.language_model - if not freeze_vit: - modules += model_arch.vision_tower - if not freeze_aligner: - modules += model_arch.aligner - assert len(modules) > 0, f'modules: {modules}' - - extra_layers = [] - if include_embedding: - extra_layers.append(nn.Embedding) - res = [] - for module in modules: - rejected_modules = [] - if not freeze_vit or not freeze_llm: - for aligner in model_arch.aligner: - if aligner.startswith(f'{module}.'): - rejected_modules.append(aligner) - - sub_module = deep_getattr(model, module) - if isinstance(sub_module, nn.Linear) and module.endswith('lm_head'): - target_modules = [] - else: - target_modules = find_all_linears(sub_module, model_arch, extra_layers) - if exclude_router and model.model_info.is_moe_model: - target_modules = [tm for tm in target_modules if tm not in {'gate'}] - if not target_modules: - continue - target_modules = [tm for tm in target_modules if tm] - target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else '' - rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else '' - res.append(rf'{rejected_pattern}{module}{target_pattern}') - - return rf'^({"|".join(res)})$' - - def get_modules_to_not_convert(model): if not hasattr(model, 'model_meta') or not hasattr(model, 'model_info'): return From 
1f459ff17a11f15aee898558687e9e71dda7d157 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 27 Feb 2026 15:32:09 +0800 Subject: [PATCH 11/14] lint --- src/twinkle/utils/transformers_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py index cecf6425..036f7538 100644 --- a/src/twinkle/utils/transformers_utils.py +++ b/src/twinkle/utils/transformers_utils.py @@ -70,6 +70,7 @@ def _cond(name, module): return find_layers(model, _cond, sub_module=sub_module) + def get_modules_to_not_convert(model): if not hasattr(model, 'model_meta') or not hasattr(model, 'model_info'): return From 7320668a4d0ed03755e903f6ea72e4507e545f9d Mon Sep 17 00:00:00 2001 From: root Date: Fri, 27 Feb 2026 16:58:42 +0800 Subject: [PATCH 12/14] fix dense --- src/twinkle/model/megatron/model/register.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py index ab59569d..07dfd82a 100644 --- a/src/twinkle/model/megatron/model/register.py +++ b/src/twinkle/model/megatron/model/register.py @@ -55,10 +55,10 @@ def get_layer_spec(self, config, args, mg_config_dict): A ``ModuleSpec`` or ``TransformerBlockSubmodules`` instance. """ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec - num_experts = mg_config_dict.get('num_experts', 0) or 0 + num_experts = mg_config_dict.get('num_experts') or None return get_gpt_layer_with_transformer_engine_spec( num_experts=num_experts, - moe_grouped_gemm=num_experts > 0, + moe_grouped_gemm=num_experts is not None, qk_layernorm=mg_config_dict.get('qk_layernorm', False), ) From 9c43b26a2bfc4956a2302b7c462c8a2e9ab3e0ba Mon Sep 17 00:00:00 2001 From: root Date: Sat, 28 Feb 2026 17:15:29 +0800 Subject: [PATCH 13/14] fix 4d attention device --- src/twinkle/processor/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index ff0fbabf..160f873b 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -246,8 +246,9 @@ def _create_4d_attention_mask(attention_mask): import torch seq_lens = [s.shape[0] for s in attention_mask] max_len = max(seq_lens) + device = attention_mask[0].device attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), - dtype=torch.bool)).view(len(seq_lens), 1, max_len, max_len) + dtype=torch.bool, device=device)).view(len(seq_lens), 1, max_len, max_len) assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}' for i, seq_len in enumerate(seq_lens): attention_mask[i, :, :, seq_len:] = 0 From d7a992dada696f45fd45a270c1355e7bb23edbd3 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 28 Feb 2026 17:36:58 +0800 Subject: [PATCH 14/14] lint --- src/twinkle/processor/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 160f873b..6f10ff63 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -247,8 +247,8 @@ def _create_4d_attention_mask(attention_mask): seq_lens = [s.shape[0] for s in attention_mask] max_len = max(seq_lens) device = attention_mask[0].device - attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), - dtype=torch.bool, device=device)).view(len(seq_lens), 1, max_len, max_len) + attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), 
dtype=torch.bool, + device=device)).view(len(seq_lens), 1, max_len, max_len) assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}' for i, seq_len in enumerate(seq_lens): attention_mask[i, :, :, seq_len:] = 0
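
A minimal, self-contained sketch of the 4D causal-mask construction that the last two patches
repair. The helper name `make_4d_causal_mask` is hypothetical; the logic mirrors
`_create_4d_attention_mask` in `src/twinkle/processor/base.py`, where the fix is to allocate the
mask on the same device as the incoming per-sample masks instead of letting `torch.ones` default
to CPU:

    import torch

    def make_4d_causal_mask(per_sample_masks):
        # per_sample_masks: one 1D mask per sample, possibly of different
        # lengths, already living on the training device (CPU or CUDA).
        seq_lens = [m.shape[0] for m in per_sample_masks]
        max_len = max(seq_lens)
        # Allocate on the inputs' device -- without `device=`, torch.ones
        # silently creates a CPU tensor and attention kernels downstream
        # fail with a device mismatch once the batch sits on CUDA.
        device = per_sample_masks[0].device
        mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len),
                                     dtype=torch.bool, device=device))
        mask = mask.view(len(seq_lens), 1, max_len, max_len)
        # Mask out key positions beyond each sample's true length.
        for i, seq_len in enumerate(seq_lens):
            mask[i, :, :, seq_len:] = 0
        return mask

    # e.g. two samples of lengths 5 and 3 -> mask of shape (2, 1, 5, 5)
    masks = [torch.ones(5, dtype=torch.bool), torch.ones(3, dtype=torch.bool)]
    assert make_4d_causal_mask(masks).shape == (2, 1, 5, 5)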