vllm.compilation.activation_quant_fusion

FP4_DTYPE module-attribute

FP4_DTYPE = uint8

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

FUSED_OPS module-attribute

FUSED_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: default
}

SILU_MUL_OP module-attribute

SILU_MUL_OP = default

logger module-attribute

logger = init_logger(__name__)

silu_and_mul_nvfp4_quant_supported module-attribute

silu_and_mul_nvfp4_quant_supported = is_cuda() and hasattr(
    _C, "silu_and_mul_nvfp4_quant"
)

ActivationQuantFusionPass

Bases: VllmInductorPass

This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them.

Because patterns can only be registered once, the pass is a singleton. This will be addressed in a future version of PyTorch: https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
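A minimal sketch of driving the pass by hand, for orientation only; the bare VllmConfig() construction and the standalone fuse helper are illustrative assumptions, since in vLLM the pass is built and invoked by the compilation machinery rather than by user code.

import torch
from vllm.config import VllmConfig
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass

config = VllmConfig()  # assumption: a default config is enough for illustration
fusion_pass = ActivationQuantFusionPass(config)

def fuse(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # The pass rewrites the FX graph in place, replacing matched
    # SiluMul+quant subgraphs with the corresponding fused custom op.
    fusion_pass(gm.graph)
    gm.recompile()
    return gm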

Source code in vllm/compilation/activation_quant_fusion.py
class ActivationQuantFusionPass(VllmInductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass")

        pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
        pattern_silu_mul_fp8.register(self.patterns)

        if silu_and_mul_nvfp4_quant_supported:
            pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
            pattern_silu_mul_nvfp4.register(self.patterns)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_act_quant_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
                     count)

        self.dump_graph(graph, "after_act_quant_fusion")
        self.end_and_log()

    def uuid(self):
        return VllmInductorPass.hash_source(self, ActivationQuantPattern,
                                            SiluMulFp8StaticQuantPattern,
                                            SiluMulNvfp4QuantPattern)

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="activation_quant_fusion_pass"
)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/activation_quant_fusion.py
def __call__(self, graph: torch.fx.Graph):
    self.begin()
    self.dump_graph(graph, "before_act_quant_fusion")

    count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
                 count)

    self.dump_graph(graph, "after_act_quant_fusion")
    self.end_and_log()

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/activation_quant_fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig):
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="activation_quant_fusion_pass")

    pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
    pattern_silu_mul_fp8.register(self.patterns)

    if silu_and_mul_nvfp4_quant_supported:
        pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
        pattern_silu_mul_nvfp4.register(self.patterns)

uuid

uuid()
Source code in vllm/compilation/activation_quant_fusion.py
def uuid(self):
    return VllmInductorPass.hash_source(self, ActivationQuantPattern,
                                        SiluMulFp8StaticQuantPattern,
                                        SiluMulNvfp4QuantPattern)

ActivationQuantPattern

Bases: ABC

The base class for Activation+Quant fusions. Should not be used directly.
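As a hedged sketch of the subclass contract, using a made-up DemoPattern class and reusing kFp8StaticTensorSym only because it already has entries in both QUANT_OPS and FUSED_OPS:

from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.activation_quant_fusion import (
    ActivationQuantPattern, kFp8StaticTensorSym)

class DemoPattern(ActivationQuantPattern):
    def __init__(self):
        # The QuantKey passed to super().__init__ must be present in both
        # QUANT_OPS and FUSED_OPS, or the assertions in __init__ fail.
        super().__init__(kFp8StaticTensorSym)

    def register(self, pm_pass: PatternMatcherPass):
        # A real subclass builds a pattern/replacement pair here and calls
        # register_replacement; see SiluMulFp8StaticQuantPattern.register.
        pass

demo = DemoPattern()
# empty_quant allocates sample CUDA tensors in the pattern's quant dtype;
# subclasses use them as tracing inputs for register_replacement.
sample = demo.empty_quant(5, 4)
assert sample.dtype == demo.quant_dtype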

Source code in vllm/compilation/activation_quant_fusion.py
class ActivationQuantPattern(ABC):
    """
    The base class for Activation+Quant fusions.
    Should not be used directly.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ):
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

        assert self.quant_key in FUSED_OPS, \
            f"unsupported fusion scheme {self.quant_key}"
        self.FUSED_OP = FUSED_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass):
        raise NotImplementedError

FUSED_OP instance-attribute

FUSED_OP = FUSED_OPS[quant_key]

QUANT_OP instance-attribute

QUANT_OP = QUANT_OPS[quant_key]

quant_dtype instance-attribute

quant_dtype = dtype

quant_key instance-attribute

quant_key = quant_key

__init__

__init__(quant_key: QuantKey)
Source code in vllm/compilation/activation_quant_fusion.py
def __init__(
    self,
    quant_key: QuantKey,
):
    self.quant_key = quant_key
    self.quant_dtype = quant_key.dtype

    assert self.quant_key in QUANT_OPS, \
        f"unsupported quantization scheme {self.quant_key}"
    self.QUANT_OP = QUANT_OPS[self.quant_key]

    assert self.quant_key in FUSED_OPS, \
        f"unsupported fusion scheme {self.quant_key}"
    self.FUSED_OP = FUSED_OPS[self.quant_key]

empty_quant

empty_quant(*args, **kwargs)
Source code in vllm/compilation/activation_quant_fusion.py
def empty_quant(self, *args, **kwargs):
    kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
    return torch.empty(*args, **kwargs)

register abstractmethod

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/activation_quant_fusion.py
@abstractmethod
def register(self, pm_pass: PatternMatcherPass):
    raise NotImplementedError

SiluMulFp8StaticQuantPattern

Bases: ActivationQuantPattern

Fusion for SiluMul+Fp8StaticQuant Pattern
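For reference, the unfused math this pattern matches can be written in plain PyTorch as below; this is a numerical sketch assuming FP8_DTYPE is torch.float8_e4m3fn, not the vLLM kernels themselves.

import torch

def silu_mul_fp8_static_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # SiluAndMul: SiLU on the first half of the last dim times the second half.
    d = x.shape[-1] // 2
    y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    # Static per-tensor FP8 quantization: divide by the precomputed scale and
    # clamp to the representable range before casting.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (y / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

The fused op emitted by the replacement performs both steps in one kernel, so the intermediate activation never needs to be materialized at full precision.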

Source code in vllm/compilation/activation_quant_fusion.py
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Fp8StaticQuant Pattern
    """

    def __init__(self, symmetric: bool = True):
        quant_key = QuantKey(dtype=FP8_DTYPE,
                             scale=kStaticTensorScale,
                             symmetric=symmetric)
        super().__init__(quant_key)

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
                    input: torch.Tensor, scale: torch.Tensor):
            at1 = auto_functionalized(SILU_MUL_OP,
                                      result=result_silu_mul,
                                      input=input)
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=result,
                                      input=at1[1],
                                      scale=scale)
            return at2[1]

        def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
                        input: torch.Tensor, scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     input=input,
                                     scale=scale)
            return at[1]

        inputs = [
            self.empty_quant(5, 4),  # result
            empty_bf16(5, 4),  # result_silu_mul
            empty_bf16(5, 4),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)

__init__

__init__(symmetric: bool = True)
Source code in vllm/compilation/activation_quant_fusion.py
def __init__(self, symmetric: bool = True):
    quant_key = QuantKey(dtype=FP8_DTYPE,
                         scale=kStaticTensorScale,
                         symmetric=symmetric)
    super().__init__(quant_key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/activation_quant_fusion.py
def register(self, pm_pass: PatternMatcherPass):

    def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
                input: torch.Tensor, scale: torch.Tensor):
        at1 = auto_functionalized(SILU_MUL_OP,
                                  result=result_silu_mul,
                                  input=input)
        at2 = auto_functionalized(self.QUANT_OP,
                                  result=result,
                                  input=at1[1],
                                  scale=scale)
        return at2[1]

    def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
                    input: torch.Tensor, scale: torch.Tensor):
        at = auto_functionalized(self.FUSED_OP,
                                 result=result,
                                 input=input,
                                 scale=scale)
        return at[1]

    inputs = [
        self.empty_quant(5, 4),  # result
        empty_bf16(5, 4),  # result_silu_mul
        empty_bf16(5, 4),  # input
        empty_fp32(1, 1)  # scale
    ]

    register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)

SiluMulNvfp4QuantPattern

Bases: ActivationQuantPattern

Fusion for SiluMul+Nvfp4Quant Pattern
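As context for the shapes used below, FP4_DTYPE is uint8 because two 4-bit FP4 values are packed into each output byte; the packing helper here is purely illustrative (nibble order and scaling are handled by the CUDA kernel, and pack_fp4_codes is not a vLLM function).

import torch

def pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    # Pack pairs of 4-bit codes (stored one per byte) into single bytes,
    # halving the last dimension.
    assert codes.dtype == torch.uint8 and codes.shape[-1] % 2 == 0
    lo, hi = codes[..., 0::2], codes[..., 1::2]
    return lo | (hi << 4)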

Source code in vllm/compilation/activation_quant_fusion.py
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Nvfp4Quant Pattern
    """

    def __init__(self):
        super().__init__(kNvfp4Quant)

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(result: torch.Tensor, output_scale: torch.Tensor,
                    result_silu_mul: torch.Tensor, input: torch.Tensor,
                    scale: torch.Tensor):
            at1 = auto_functionalized(SILU_MUL_OP,
                                      result=result_silu_mul,
                                      input=input)
            at2 = auto_functionalized(self.QUANT_OP,
                                      output=result,
                                      input=at1[1],
                                      output_scale=output_scale,
                                      input_scale=scale)
            return at2[1], at2[2]

        def replacement(result: torch.Tensor, output_scale: torch.Tensor,
                        result_silu_mul: torch.Tensor, input: torch.Tensor,
                        scale: torch.Tensor):
            at = auto_functionalized(self.FUSED_OP,
                                     result=result,
                                     result_block_scale=output_scale,
                                     input=input,
                                     input_global_scale=scale)
            return at[1], at[2]

        inputs = [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # result_silu_mul
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1)  # scale
        ]

        register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)

__init__

__init__()
Source code in vllm/compilation/activation_quant_fusion.py
def __init__(self):
    super().__init__(kNvfp4Quant)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/activation_quant_fusion.py
def register(self, pm_pass: PatternMatcherPass):

    def pattern(result: torch.Tensor, output_scale: torch.Tensor,
                result_silu_mul: torch.Tensor, input: torch.Tensor,
                scale: torch.Tensor):
        at1 = auto_functionalized(SILU_MUL_OP,
                                  result=result_silu_mul,
                                  input=input)
        at2 = auto_functionalized(self.QUANT_OP,
                                  output=result,
                                  input=at1[1],
                                  output_scale=output_scale,
                                  input_scale=scale)
        return at2[1], at2[2]

    def replacement(result: torch.Tensor, output_scale: torch.Tensor,
                    result_silu_mul: torch.Tensor, input: torch.Tensor,
                    scale: torch.Tensor):
        at = auto_functionalized(self.FUSED_OP,
                                 result=result,
                                 result_block_scale=output_scale,
                                 input=input,
                                 input_global_scale=scale)
        return at[1], at[2]

    inputs = [
        self.empty_quant(5, 32),  # result
        empty_i32(128, 4),  # output_scale
        empty_bf16(5, 64),  # result_silu_mul
        empty_bf16(5, 64),  # input
        empty_fp32(1, 1)  # scale
    ]

    register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)