vllm.model_executor.layers.fused_moe.trtllm_moe

TrtLlmGenExperts

Bases: FusedMoEPermuteExpertsUnpermute

Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        moe: FusedMoEConfig,
        gemm1_alpha,
        gemm1_beta,
        gemm1_clamp_limit,
        w13_bias,
        w2_bias,
        max_capture_size,
    ):
        super().__init__(moe.quant_config)
        self.moe = moe
        self.gemm1_alpha = gemm1_alpha
        self.gemm1_beta = gemm1_beta
        self.gemm1_clamp_limit = gemm1_clamp_limit
        self.w13_bias = w13_bias
        self.w2_bias = w2_bias
        self.max_capture_size = max_capture_size

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # The workspaces for this implementation are managed by flashinfer.
        # TODO(varun): workspace1 could be used as the output tensor. This is
        # error-prone. Allow `workspace_shapes` to return None workspaces.
        workspace1 = (M, K)
        workspace2 = (0, 0)
        output = (M, K)
        return (workspace1, workspace2, output, a.dtype)

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
                             local_num_experts: int):
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for expert load imbalance. It equals
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
        # 1.0 means a perfect expert distribution, > 1.0 means some experts
        # receive more tokens than the perfect distribution, and < 1.0 does
        # not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert assuming perfect
        # distribution.
        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Clamp to 8-64 tokens per CTA tile, the range supported by the
        # kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ):
        topk = topk_ids.size(-1)
        local_num_experts = w1.size(0)
        intermediate_size = w2.size(1)
        local_expert_offset = self.moe.ep_rank * local_num_experts

        x_quant = hidden_states
        x_scale = a1q_scale
        if x_scale is not None:
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                *x_quant.shape[:-1], -1)

        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16).view(torch.int16)

        assert w1_scale is not None
        assert w2_scale is not None
        kwargs = {
            "topk_ids": packed_tensor,
            "routing_bias": None,
            "hidden_states": x_quant,
            "hidden_states_scale": x_scale,
            "gemm1_weights": w1,
            "gemm1_weights_scale": w1_scale,
            "gemm1_bias": self.w13_bias,
            "gemm1_alpha": self.gemm1_alpha,
            "gemm1_beta": self.gemm1_beta,
            "gemm1_clamp_limit": self.gemm1_clamp_limit,
            "gemm2_weights": w2,
            "gemm2_weights_scale": w2_scale,
            "gemm2_bias": self.w2_bias,
            "output1_scale_scalar": None,
            "output1_scale_gate_scalar": None,
            "output2_scale_scalar": None,
            "num_experts": global_num_experts,
            "top_k": topk,
            "n_group": None,
            "topk_group": None,
            "intermediate_size": intermediate_size,
            "local_expert_offset": local_expert_offset,
            "local_num_experts": local_num_experts,
            "routed_scaling_factor": None,
            "tile_tokens_dim": self._get_tile_tokens_dim(
                x_quant, topk, local_num_experts),
            "routing_method_type": 1,
            "do_finalize": True,
            "output": output,
            "tune_max_num_tokens": self.max_capture_size,
        }

        from flashinfer import trtllm_fp4_block_scale_routed_moe
        trtllm_fp4_block_scale_routed_moe(**kwargs)
        return output

activation_formats property

gemm1_alpha instance-attribute

gemm1_alpha = gemm1_alpha

gemm1_beta instance-attribute

gemm1_beta = gemm1_beta

gemm1_clamp_limit instance-attribute

gemm1_clamp_limit = gemm1_clamp_limit

max_capture_size instance-attribute

max_capture_size = max_capture_size

moe instance-attribute

moe = moe

w13_bias instance-attribute

w13_bias = w13_bias

w2_bias instance-attribute

w2_bias = w2_bias

__init__

__init__(
    moe: FusedMoEConfig,
    gemm1_alpha,
    gemm1_beta,
    gemm1_clamp_limit,
    w13_bias,
    w2_bias,
    max_capture_size,
)
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def __init__(
    self,
    moe: FusedMoEConfig,
    gemm1_alpha,
    gemm1_beta,
    gemm1_clamp_limit,
    w13_bias,
    w2_bias,
    max_capture_size,
):
    super().__init__(moe.quant_config)
    self.moe = moe
    self.gemm1_alpha = gemm1_alpha
    self.gemm1_beta = gemm1_beta
    self.gemm1_clamp_limit = gemm1_clamp_limit
    self.w13_bias = w13_bias
    self.w2_bias = w2_bias
    self.max_capture_size = max_capture_size

_get_tile_tokens_dim

_get_tile_tokens_dim(
    x: Tensor, top_k: int, local_num_experts: int
)
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
                         local_num_experts: int):
    # Number of tokens in the input tensor.
    num_tokens = x.shape[0]
    # Factor to account for expert load imbalance. It equals
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
    # 1.0 means a perfect expert distribution, > 1.0 means some experts
    # receive more tokens than the perfect distribution, and < 1.0 does
    # not make sense.
    imbalance_factor = 1.3
    # Calculate the number of tokens per expert assuming perfect
    # distribution.
    num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
    # Clamp to 8-64 tokens per CTA tile, the range supported by the
    # kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

    return tile_tokens_dim
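
A quick worked example of the heuristic above. The numbers are hypothetical, and next_power_of_2 is assumed to round up to the nearest power of two (reimplemented here via bit_length for self-containment):

# Illustrative only: 256 tokens routed with top_k=8 over 64 local experts.
num_tokens, top_k, local_num_experts = 256, 8, 64

num_tokens_per_expert = (num_tokens * top_k) // local_num_experts   # 32
num_tokens_per_expert = int(num_tokens_per_expert * 1.3)            # 41
tile_tokens_dim = 1 << (num_tokens_per_expert - 1).bit_length()     # 64
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)                  # stays 64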

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[Tensor],
    w1_scale: Optional[Tensor],
    w2_scale: Optional[Tensor],
    w1_zp: Optional[Tensor],
    w2_zp: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    workspace13: Tensor,
    workspace2: Tensor,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
)
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    apply_router_weight_on_input: bool,
):
    topk = topk_ids.size(-1)
    local_num_experts = w1.size(0)
    intermediate_size = w2.size(1)
    local_expert_offset = self.moe.ep_rank * local_num_experts

    x_quant = hidden_states
    x_scale = a1q_scale
    if x_scale is not None:
        x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
            *x_quant.shape[:-1], -1)

    packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
        torch.bfloat16).view(torch.int16)

    assert w1_scale is not None
    assert w2_scale is not None
    kwargs = {
        "topk_ids": packed_tensor,
        "routing_bias": None,
        "hidden_states": x_quant,
        "hidden_states_scale": x_scale,
        "gemm1_weights": w1,
        "gemm1_weights_scale": w1_scale,
        "gemm1_bias": self.w13_bias,
        "gemm1_alpha": self.gemm1_alpha,
        "gemm1_beta": self.gemm1_beta,
        "gemm1_clamp_limit": self.gemm1_clamp_limit,
        "gemm2_weights": w2,
        "gemm2_weights_scale": w2_scale,
        "gemm2_bias": self.w2_bias,
        "output1_scale_scalar": None,
        "output1_scale_gate_scalar": None,
        "output2_scale_scalar": None,
        "num_experts": global_num_experts,
        "top_k": topk,
        "n_group": None,
        "topk_group": None,
        "intermediate_size": intermediate_size,
        "local_expert_offset": local_expert_offset,
        "local_num_experts": local_num_experts,
        "routed_scaling_factor": None,
        "tile_tokens_dim": self._get_tile_tokens_dim(
            x_quant, topk, local_num_experts),
        "routing_method_type": 1,
        "do_finalize": True,
        "output": output,
        "tune_max_num_tokens": self.max_capture_size,
    }

    from flashinfer import trtllm_fp4_block_scale_routed_moe
    trtllm_fp4_block_scale_routed_moe(**kwargs)
    return output
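
The packed_tensor built in apply stores each expert id in the upper 16 bits of an int32 and the router weight's bfloat16 bit pattern in the lower 16 bits. A minimal sketch of that encoding and its round trip, with made-up values (non-negative weights keep the int16 bit pattern non-negative, so the OR does not disturb the upper bits):

import torch

# Hypothetical routing output: one token, top_k=2.
topk_ids = torch.tensor([[3, 17]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]], dtype=torch.float32)

# Pack: expert id in bits 16-31, bfloat16 weight bits in bits 0-15.
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)

# Unpack to verify the layout.
recovered_ids = packed >> 16
recovered_weights = (packed & 0xFFFF).to(torch.int16).view(torch.bfloat16)
assert torch.equal(recovered_ids, topk_ids)
assert torch.equal(recovered_weights, topk_weights.to(torch.bfloat16))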

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    return TopKWeightAndReduceNoOP()

supports_chunking

supports_chunking() -> bool
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def supports_chunking(self) -> bool:
    return True

supports_expert_map

supports_expert_map() -> bool
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def supports_expert_map(self) -> bool:
    return True

workspace_shapes

workspace_shapes(
    a: Tensor,
    aq: Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...], dtype
]
Source code in vllm/model_executor/layers/fused_moe/trtllm_moe.py
def workspace_shapes(
    self,
    a: torch.Tensor,
    aq: torch.Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
    # The workspaces for this implementation are managed by flashinfer.
    # TODO(varun): workspace1 could be used as the output tensor. This is
    # error-prone. Allow `workspace_shapes` to return None workspaces.
    workspace1 = (M, K)
    workspace2 = (0, 0)
    output = (M, K)
    return (workspace1, workspace2, output, a.dtype)
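
For a hypothetical batch, the shapes returned above work out as follows (sizes are made up for illustration):

# Illustrative only: M=128 tokens, hidden size K=4096.
M, K = 128, 4096
workspace1 = (M, K)   # currently also reused as the output staging buffer
workspace2 = (0, 0)   # flashinfer manages its own scratch internally
output = (M, K)       # final output matches the unquantized activation shape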