indexcache.patch
212 lines (203 loc) · 9.49 KB
diff --git a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py
index 9eca43ce..a4d4ca4e 100644
--- a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py
+++ b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py
@@ -90,6 +90,7 @@ class DeepseekMLAForwardMixin:
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
@@ -181,18 +182,7 @@ class DeepseekMLAForwardMixin:
q = self.q_b_proj(q)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
- topk_indices = self.indexer(
- x=hidden_states,
- q_lora=q_lora,
- positions=positions,
- forward_batch=forward_batch,
- layer_id=self.layer_id,
- )
- current_stream.wait_stream(self.alt_stream)
- else:
- k_nope = k_nope.unsqueeze(1)
- q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
- if q_lora is not None:
+ if not self.skip_topk:
topk_indices = self.indexer(
x=hidden_states,
q_lora=q_lora,
@@ -200,6 +190,23 @@ class DeepseekMLAForwardMixin:
forward_batch=forward_batch,
layer_id=self.layer_id,
)
+ else:
+ topk_indices = prev_topk_indices
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ k_nope = k_nope.unsqueeze(1)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+ if q_lora is not None:
+ if not self.skip_topk:
+ topk_indices = self.indexer(
+ x=hidden_states,
+ q_lora=q_lora,
+ positions=positions,
+ forward_batch=forward_batch,
+ layer_id=self.layer_id,
+ )
+ else:
+ topk_indices = prev_topk_indices
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
@@ -494,7 +501,10 @@ class DeepseekMLAForwardMixin:
)
output, _ = self.o_proj(attn_bmm_output)
- return output
+ if not self.next_skip_topk:
+ return output, None
+ else:
+ return output, topk_indices
def _fuse_rope_for_trtllm_mla(
self: DeepseekV2AttentionMLA, forward_batch: ForwardBatch
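
Taken together, the forward_mla.py changes make the indexer call conditional: a layer with skip_topk set reuses the indices handed in as prev_topk_indices instead of recomputing them, and a layer only returns its own indices when the following layer is going to skip. A minimal sketch of that control flow, with the attention layer and indexer reduced to stand-ins (names such as SparseAttnLayer and attend are illustrative, not the real SGLang APIs):

from typing import Optional
import torch


class SparseAttnLayer(torch.nn.Module):
    """Toy stand-in for the MLA attention layer touched by the patch."""

    def __init__(self, indexer, skip_topk: bool, next_skip_topk: bool):
        super().__init__()
        self.indexer = indexer                # produces top-k KV indices for this layer
        self.skip_topk = skip_topk            # reuse the previous layer's indices
        self.next_skip_topk = next_skip_topk  # the next layer will reuse ours

    def forward(
        self,
        hidden_states: torch.Tensor,
        prev_topk_indices: Optional[torch.Tensor] = None,
    ):
        if not self.skip_topk:
            topk_indices = self.indexer(hidden_states)   # recompute indices
        else:
            topk_indices = prev_topk_indices             # reuse cached indices
        output = self.attend(hidden_states, topk_indices)
        # Only hand the indices back when the next layer is going to skip,
        # mirroring the `return output, topk_indices` branch in the patch.
        return (output, topk_indices) if self.next_skip_topk else (output, None)

    def attend(self, hidden_states, topk_indices):
        # placeholder for the sparse attention computation
        return hidden_states
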
diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index d57eb882..0327434f 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -166,7 +166,7 @@ class DeepseekModelNextN(nn.Module):
positions = cp_split_and_rebuild_position(forward_batch, positions)
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
- hidden_states, residual = self.decoder(
+ hidden_states, residual, topk_indices = self.decoder(
positions,
hidden_states,
forward_batch,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8f061714..4541527f 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1074,6 +1074,7 @@ class DeepseekV2AttentionMLA(
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
skip_rope: bool = False,
+ is_nextn: bool = False,
) -> None:
super().__init__()
self.layer_id = layer_id
@@ -1163,6 +1164,22 @@ class DeepseekV2AttentionMLA(
layer_id=layer_id,
alt_stream=alt_stream,
)
+ if is_nextn:
+ self.skip_topk = False
+ self.next_skip_topk = False
+ else:
+ self.index_topk_freq = getattr(config, "index_topk_freq", 1)
+ self.index_topk_pattern = getattr(config, "index_topk_pattern", None)
+ if self.index_topk_pattern is None:
+ self.skip_topk = (max(layer_id-1, 0) % self.index_topk_freq != 0)
+ self.next_skip_topk = (layer_id % self.index_topk_freq != 0)
+ else:
+ self.skip_topk = self.index_topk_pattern[layer_id] == 'S'
+ if layer_id < len(self.index_topk_pattern) - 1:
+ self.next_skip_topk = self.index_topk_pattern[layer_id+1] == 'S'
+ else:
+ self.next_skip_topk = False
+ print('layer_id {} DSA skip_topk {} next_skip_topk {} is_nextn {}'.format(layer_id, self.skip_topk, self.next_skip_topk, is_nextn))
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
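
The constructor change above derives skip_topk and next_skip_topk either from a recompute frequency (index_topk_freq) or from an explicit per-layer pattern string (index_topk_pattern, where 'S' marks a layer that skips). A standalone sketch of the same selection logic, handy for checking which layers recompute under a given config; the helper name and example are illustrative only:

from typing import Optional


def topk_schedule(layer_id: int,
                  index_topk_freq: int = 1,
                  index_topk_pattern: Optional[str] = None):
    """Return (skip_topk, next_skip_topk) for one layer, mirroring the patch."""
    if index_topk_pattern is None:
        skip_topk = (max(layer_id - 1, 0) % index_topk_freq != 0)
        next_skip_topk = (layer_id % index_topk_freq != 0)
    else:
        skip_topk = index_topk_pattern[layer_id] == 'S'
        if layer_id < len(index_topk_pattern) - 1:
            next_skip_topk = index_topk_pattern[layer_id + 1] == 'S'
        else:
            next_skip_topk = False
    return skip_topk, next_skip_topk


# e.g. inspect which layers skip the indexer under a given frequency:
# for i in range(8):
#     print(i, topk_schedule(i, index_topk_freq=2))
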
@@ -1309,6 +1326,7 @@ class DeepseekV2AttentionMLA(
zero_allocator: BumpAllocator,
layer_scatter_modes: LayerScatterModes = None,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
):
s = self.forward_prepare(
positions=positions,
@@ -1317,6 +1335,7 @@ class DeepseekV2AttentionMLA(
zero_allocator=zero_allocator,
layer_scatter_modes=layer_scatter_modes,
llama_4_scaling=llama_4_scaling,
+ prev_topk_indices=prev_topk_indices,
)
return self.forward_core(s)
@@ -1328,6 +1347,7 @@ class DeepseekV2AttentionMLA(
zero_allocator: BumpAllocator,
layer_scatter_modes: LayerScatterModes = None,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
):
if self.attn_mha.kv_b_proj is None:
self.attn_mha.kv_b_proj = self.kv_b_proj
@@ -1367,7 +1387,7 @@ class DeepseekV2AttentionMLA(
)
elif attn_forward_method == AttnForwardMethod.MLA:
inner_state = self.forward_absorb_prepare(
- positions, hidden_states, forward_batch, zero_allocator, llama_4_scaling
+ positions, hidden_states, forward_batch, zero_allocator, llama_4_scaling, prev_topk_indices
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_ROCM:
inner_state = self.forward_absorb_fused_mla_rope_prepare(
@@ -1528,6 +1548,7 @@ class DeepseekV2DecoderLayer(nn.Module):
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
alt_stream=alt_stream,
+ is_nextn=is_nextn,
)
if not hasattr(config, "q_lora_rank") and envs.SGLANG_USE_AG_AFTER_QLORA.get():
raise ValueError(
@@ -1614,6 +1635,7 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
quant_format = (
"mxfp4"
@@ -1656,7 +1678,12 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator=zero_allocator,
llama_4_scaling=llama_4_scaling,
layer_scatter_modes=self.layer_scatter_modes,
+ prev_topk_indices=prev_topk_indices,
)
+ if isinstance(hidden_states, tuple):
+ hidden_states, topk_indices = hidden_states
+ else:
+ topk_indices = None
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
@@ -1692,7 +1719,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
- return hidden_states, residual
+ return hidden_states, residual, topk_indices
def op_comm_prepare_attn(
self,
@@ -1969,6 +1996,7 @@ class DeepseekV2Model(nn.Module):
elif self.first_k_dense_replace < normal_start_layer:
normal_end_layer = normal_start_layer = 0
aux_hidden_states = []
+ topk_indices = None
for i in range(normal_start_layer, normal_end_layer):
# NOTE: torch dynamo does not support graph break in context manager
ctx = (
@@ -1986,7 +2014,7 @@ class DeepseekV2Model(nn.Module):
else:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
- hidden_states, residual = layer(
+ hidden_states, residual, topk_indices = layer(
positions,
hidden_states,
forward_batch,
@@ -1994,6 +2022,7 @@ class DeepseekV2Model(nn.Module):
zero_allocator,
gemm_output_zero_allocator,
llama_4_scaling,
+ prev_topk_indices=topk_indices,
)
if normal_end_layer != self.end_layer:
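
At the model level, the patch threads the cached indices through the layer loop: each decoder layer now returns a third value, and the loop feeds the value produced by the previous iteration back in as prev_topk_indices. A simplified sketch of that data flow; layer construction and the other per-layer arguments are elided, so this is an illustration of the pattern rather than the actual DeepseekV2Model code:

def run_decoder_layers(layers, positions, hidden_states, forward_batch):
    residual = None
    topk_indices = None  # nothing cached before the first layer
    for layer in layers:
        # `topk_indices` still holds the previous layer's output at call time,
        # so the current layer sees it as `prev_topk_indices` before
        # overwriting it with its own result.
        hidden_states, residual, topk_indices = layer(
            positions,
            hidden_states,
            forward_batch,
            residual,
            prev_topk_indices=topk_indices,
        )
    return hidden_states, residual
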