refactor fp8 linear #210
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| import torch | ||
| import torch.nn as nn | ||
| from typing import Tuple | ||
|
|
||
| from diffsynth_engine.utils.platform import DTYPE_FP8, FP8_MAX | ||
|
|
||
|
|
||
| def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| x_max = x.abs().float().amax(dim=-1, keepdim=True).clamp(min=1e-4) | ||
| scale = x_max / FP8_MAX | ||
| x_scaled = x / scale | ||
| return x_scaled, scale | ||
|
|
||
|
|
||
| def fp8_linear( | ||
| input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None, scaling: bool = True | ||
| ) -> torch.Tensor: | ||
| device = input.device | ||
| origin_dtype = input.dtype | ||
| origin_shape = input.shape | ||
| input = input.reshape(-1, origin_shape[-1]) | ||
| out_features, _ = weight.shape | ||
|
|
||
| if scaling: | ||
| input, scale_a = per_token_cast_to_fp8(input) | ||
| scale_b = torch.ones((out_features, 1), device=device) | ||
| else: | ||
| scale_a = torch.tensor(1.0, device=device) | ||
| scale_b = torch.tensor(1.0, device=device) | ||
| input = input.to(DTYPE_FP8) | ||
| weight = weight.to(DTYPE_FP8) | ||
|
|
||
| result = torch._scaled_mm( | ||
| input, | ||
| weight.T, | ||
| scale_a=scale_a, | ||
| scale_b=scale_b.T, | ||
| bias=bias, | ||
| out_dtype=origin_dtype, | ||
| ) | ||
| new_shape = origin_shape[:-1] + result.shape[-1:] | ||
| result = result.reshape(new_shape) | ||
| return result | ||
|
|
||
|
|
||
| class FP8Linear(nn.Linear): | ||
| def __init__( | ||
| self, | ||
| in_features: int, | ||
| out_features: int, | ||
| bias: bool = True, | ||
| device=None, | ||
| dtype=None, | ||
| scaling: bool = True, | ||
| ): | ||
| super().__init__(in_features, out_features, bias, device, dtype) | ||
| self.weight.data = self.weight.data.to(DTYPE_FP8) | ||
| self.scaling = scaling | ||
|
|
||
| def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
| return fp8_linear(input, self.weight, self.bias, self.scaling) |
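For reviewers without FP8-capable hardware, the numeric effect of this path can be approximated on CPU: quantize to float8, dequantize, and do the matmul in float32. The sketch below is only an emulation of what `fp8_linear` computes, not the `torch._scaled_mm` kernel path, and it assumes `DTYPE_FP8` resolves to `torch.float8_e4m3fn` with `FP8_MAX` as its maximum value:

```python
import torch
import torch.nn.functional as F

DTYPE_FP8 = torch.float8_e4m3fn       # assumed value of DTYPE_FP8
FP8_MAX = torch.finfo(DTYPE_FP8).max  # 448.0 for e4m3fn


def emulated_fp8_linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor | None = None) -> torch.Tensor:
    # Per-token (row-wise) scaling of activations, mirroring per_token_cast_to_fp8.
    x_max = x.abs().float().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale_a = x_max / FP8_MAX
    x_q = (x / scale_a).to(DTYPE_FP8).float() * scale_a  # quantize, then dequantize
    # Weights are cast directly; scale_b is all-ones in fp8_linear when scaling=True.
    w_q = w.to(DTYPE_FP8).float()
    out = x_q @ w_q.T
    return out if b is None else out + b


x = torch.randn(8, 256)
w = torch.randn(128, 256)
b = torch.randn(128)

ref = F.linear(x, w, b)
approx = emulated_fp8_linear(x, w, b)
print(((approx - ref).abs().mean() / ref.abs().mean()).item())  # roughly a few percent for e4m3
```

Per-token scaling lets each row of the activation tensor use the full e4m3 range, which is the motivation for `per_token_cast_to_fp8` rather than a single global scale.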
**Qwen image DiT model** — switch from the `fp8_inference` context to `LoRAFP8Linear` module replacement:

```diff
@@ -7,8 +7,8 @@
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
+from diffsynth_engine.models.basic.lora import LoRAFP8Linear
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -441,10 +441,8 @@ def forward(
         attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         h, w = image.shape[-2:]
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = image.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel(
                 (
@@ -540,3 +538,9 @@ def compile_repeated_blocks(self, *args, **kwargs):
 
     def get_fsdp_module_cls(self):
         return {QwenImageTransformerBlock}
+
+    def enable_fp8_linear(self):
+        target_names = ["transformer_blocks"]
+        for name, module in self.named_modules():
+            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
```
**Contributor** (review comment on the `if any([t in name for t in target_names])` line): The condition `any([t in name for t in target_names])` used to identify target modules is a bit loose and could lead to incorrect module replacement if a module name merely contains one of the target names. Using `startswith` would be more robust and ensure that only modules within the specified parent modules are matched.

Suggested change:

```diff
-            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
```
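To make the point concrete, a quick standalone illustration with made-up module names (not from the repository):

```python
target_names = ["transformer_blocks"]
names = ["transformer_blocks.0.attn.to_q", "old_transformer_blocks.0.proj", "norm_out"]

loose = [n for n in names if any(t in n for t in target_names)]
strict = [n for n in names if any(n.startswith(f"{t}.") for t in target_names)]

print(loose)   # ['transformer_blocks.0.attn.to_q', 'old_transformer_blocks.0.proj']  <- false positive
print(strict)  # ['transformer_blocks.0.attn.to_q']
```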
**Wan DiT model** — same switch from `fp8_inference` to `LoRAFP8Linear`:

```diff
@@ -9,6 +9,7 @@
 from diffsynth_engine.models.basic.attention import attention
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
+from diffsynth_engine.models.basic.lora import LoRAFP8Linear
 from diffsynth_engine.utils.constants import (
     WAN2_1_DIT_T2V_1_3B_CONFIG_FILE,
     WAN2_1_DIT_I2V_14B_CONFIG_FILE,
@@ -20,7 +21,6 @@
     WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -386,10 +386,8 @@ def forward(
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
         attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
         with (
-            fp8_inference(fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
         ):
@@ -541,3 +539,9 @@ def compile_repeated_blocks(self, *args, **kwargs):
 
     def get_fsdp_module_cls(self):
         return {DiTBlock}
+
+    def enable_fp8_linear(self):
+        target_names = ["blocks"]
+        for name, module in self.named_modules():
+            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
```
**Contributor** (review comment on the `if any([t in name for t in target_names])` line): The condition `any([t in name for t in target_names])` used to identify target modules is a bit loose and could lead to incorrect module replacement if a module name incidentally contains one of the target names (e.g., a new module named `archived_blocks`). Using `startswith` would be more robust and ensure that only modules within the specified parent modules are matched.

Suggested change:

```diff
-            if any([t in name for t in target_names]) and isinstance(module, nn.Linear):
+            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
```
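For completeness, here is a toy, self-contained version of the replacement loop itself. `FakeFP8Linear.from_linear` is a stand-in for `LoRAFP8Linear.from_linear`, whose implementation is not part of this diff, and the loop uses the `startswith` form suggested above:

```python
import torch.nn as nn


class FakeFP8Linear(nn.Linear):
    """Stand-in for LoRAFP8Linear: builds a new layer from an existing nn.Linear."""

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FakeFP8Linear":
        new = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
        new.load_state_dict(linear.state_dict())  # keep the original weights
        return new


class ToyDiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Linear(16, 32)  # outside the target scope, must stay untouched
        self.blocks = nn.ModuleList(
            [nn.ModuleDict({"to_q": nn.Linear(32, 32)}) for _ in range(2)]
        )

    def enable_fp8_linear(self):
        target_names = ["blocks"]
        for name, module in self.named_modules():
            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, FakeFP8Linear.from_linear(module))


model = ToyDiT()
model.enable_fp8_linear()
print(type(model.patch_embed).__name__)        # Linear
print(type(model.blocks[0]["to_q"]).__name__)  # FakeFP8Linear
```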
**Wan S2V pipeline** — call the new per-model method instead of the removed utility function:

```diff
@@ -24,7 +24,6 @@
 from diffsynth_engine.tokenizers import WanT5Tokenizer
 from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
 from diffsynth_engine.utils.download import fetch_model
-from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils.image import resize_and_center_crop
 from diffsynth_engine.utils.video import read_n_frames
 from diffsynth_engine.utils.parallel import ParallelWrapper
@@ -666,7 +665,7 @@ def _from_state_dict(
             use_vsa=(config.dit_attn_impl.value == "vsa"),
         )
         if config.use_fp8_linear:
-            enable_fp8_linear(dit)
+            dit.enable_fp8_linear()
 
         pipe = cls(
             config=config,
```
**Contributor** (review comment on `dit.enable_fp8_linear()`): This call to `dit.enable_fp8_linear()` requires the DiT class constructed here to implement the method. To fix this, you should implement the `enable_fp8_linear` method on `WanS2VDiT` as well. You can add a method similar to the one in `WanDiT`:

```python
from diffsynth_engine.models.basic.lora import LoRAFP8Linear
...

class WanS2VDiT(WanDiT):
    ...
    def enable_fp8_linear(self):
        target_names = ["blocks"]
        for name, module in self.named_modules():
            if any(name.startswith(f"{t}.") for t in target_names) and isinstance(module, nn.Linear):
                self.set_submodule(name, LoRAFP8Linear.from_linear(module))
```

Note that you'll also need to import `LoRAFP8Linear` as shown.