vllm.model_executor.layers.mamba.linear_attn

MiniMaxText01LinearAttention

Bases: Module, MambaBase

Source code in vllm/model_executor/layers/mamba/linear_attn.py
class MiniMaxText01LinearAttention(nn.Module, MambaBase):

    @property
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.linear_attn import (
            LinearAttentionBackend)
        return LinearAttentionBackend

    def get_state_dtype(self) -> tuple[torch.dtype]:
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=self.num_heads,
            tp_size=self.tp_size,
            head_dim=self.head_dim)

    def __init__(
        self,
        hidden_size: int,
        hidden_inner_size: int,
        num_heads: int,
        head_dim: int,
        max_position: int,
        block_size: int,
        num_hidden_layer: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = 0,
        linear_layer_idx: int = 0,
        prefix: str = "linear_attn",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.BLOCK = block_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.hidden_inner_size = hidden_inner_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        assert self.total_num_heads % self.tp_size == 0
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads
        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size * 3,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.output_gate = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_gate",
        )
        self.out_proj = RowParallelLinear(
            self.hidden_inner_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.norm = MiniMaxText01RMSNormTP(
            self.hidden_inner_size,
            eps=1e-5,
        )

        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
            self.num_heads)
        if num_hidden_layer <= 1:
            self.slope_rate = slope_rate * (1 + 1e-5)
        else:
            self.slope_rate = slope_rate * (1 - layer_idx /
                                            (num_hidden_layer - 1) + 1e-5)
        self.tp_slope = self.slope_rate[self.tp_rank *
                                        self.tp_heads:(self.tp_rank + 1) *
                                        self.tp_heads].contiguous()

        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self

    @staticmethod
    def weight_direct_load(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)
        return

    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):

        def get_slopes(n):

            def get_slopes_power_of_2(n):
                start = 2**(-(2**-(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            else:
                closest_power_of_2 = 2**math.floor(math.log2(n))
                return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                    2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

        slopes = torch.tensor(get_slopes(n_attention_heads),
                              dtype=torch.float32).reshape(
                                  n_attention_heads, 1, 1)
        return slopes

    def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                               attn_metadata):
        hidden = []
        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_idx >= len(attn_metadata.query_start_loc):
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
            # prefills are packed at end of batch in V1
            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
            slot_id = state_indices_tensor[offset + _prefill_idx]
            qs = q[_start:_end].transpose(0, 1).contiguous()
            ks = k[_start:_end].transpose(0, 1).contiguous()
            vs = v[_start:_end].transpose(0, 1).contiguous()
            slice_layer_cache = kv_cache[slot_id, ...]

            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
                qs,
                ks,
                vs,
                slice_layer_cache,
                self.tp_slope,
                self.BLOCK,
                layer_idx=self.layer_idx)
            hidden.append(out_slice.contiguous())
        if attn_metadata.num_decode_tokens > 0:
            hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                               state_indices_tensor,
                                               attn_metadata)
            if envs.VLLM_USE_V1:
                hidden.insert(0, hidden_decode)
            else:
                hidden.append(hidden_decode)

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

        hidden = torch.concat(hidden, dim=0).contiguous()
        return hidden

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                      attn_metadata):
        if not envs.VLLM_USE_V1:
            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            slot_id = state_indices_tensor[num_prefills:]
        else:
            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                              slot_id, 32)
        return hidden

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: MinimaxCacheParams) -> None:
        if not envs.VLLM_USE_V1:
            self._forward(hidden_states, output, positions, kv_caches)
        else:
            torch.ops.vllm.linear_attention(
                hidden_states,
                output,
                positions,
                self.prefix,
            )

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: Optional[MinimaxCacheParams]) -> None:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1 and attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, LinearAttentionMetadata)
            num_actual_tokens = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            num_actual_tokens = hidden_states.shape[0]

        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                state_indices_tensor = attn_metadata.state_indices_tensor

                num_prefills = getattr(attn_metadata, "num_prefills", 0)
                if num_prefills > 0:
                    num_decode_tokens = getattr(attn_metadata,
                                                "num_decode_tokens", 0)
                    for prefill_idx in range(num_prefills):
                        q_start = attn_metadata.query_start_loc[
                            num_decode_tokens + prefill_idx]
                        q_end = attn_metadata.query_start_loc[num_decode_tokens
                                                              + prefill_idx +
                                                              1]
                        query_len = q_end - q_start
                        context_len = attn_metadata.seq_lens[
                            num_decode_tokens + prefill_idx] - query_len
                        if context_len == 0:
                            block_to_clear = state_indices_tensor[
                                num_decode_tokens + prefill_idx]
                            kv_cache[block_to_clear, ...] = 0
        else:
            assert kv_caches is not None
            kv_cache = kv_caches.minimax_cache
            state_indices_tensor = kv_caches.state_indices_tensor

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if attn_metadata is None:
            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
                                 device=q.device,
                                 dtype=q.dtype)
        else:
            if not decode_only:
                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                                     state_indices_tensor,
                                                     attn_metadata)
            else:
                hidden = self._decode_infer(q, k, v, kv_cache,
                                            state_indices_tensor,
                                            attn_metadata)
        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)

        output[:num_actual_tokens], _ = self.out_proj(hidden)

BLOCK instance-attribute

BLOCK = block_size

cache_config instance-attribute

cache_config = cache_config

head_dim instance-attribute

head_dim = head_dim

hidden_inner_size instance-attribute

hidden_inner_size = hidden_inner_size

hidden_size instance-attribute

hidden_size = hidden_size

layer_idx instance-attribute

layer_idx = layer_idx

mamba_type property

mamba_type: str

model_config instance-attribute

model_config = model_config

norm instance-attribute

norm = MiniMaxText01RMSNormTP(hidden_inner_size, eps=1e-05)

num_heads instance-attribute

num_heads = num_heads

out_proj instance-attribute

out_proj = RowParallelLinear(
    hidden_inner_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

output_gate instance-attribute

output_gate = ColumnParallelLinear(
    hidden_size,
    hidden_inner_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.output_gate",
)

prefix instance-attribute

prefix = prefix

qkv_proj instance-attribute

qkv_proj = ColumnParallelLinear(
    hidden_size,
    hidden_inner_size * 3,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

qkv_size instance-attribute

qkv_size = num_heads * head_dim

slope_rate instance-attribute

slope_rate = slope_rate * (1 + 1e-05)

total_num_heads instance-attribute

total_num_heads = num_heads

tp_heads instance-attribute

tp_heads = total_num_heads // tp_size

tp_hidden instance-attribute

tp_hidden = head_dim * tp_heads

tp_rank instance-attribute

tp_size instance-attribute

tp_slope instance-attribute

tp_slope = slope_rate[tp_rank * tp_heads:(tp_rank + 1) * tp_heads].contiguous()

__init__

__init__(
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def __init__(
    self,
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None:
    super().__init__()

    self.layer_idx = layer_idx
    self.BLOCK = block_size
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.total_num_heads = num_heads
    self.hidden_inner_size = hidden_inner_size
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()

    assert self.total_num_heads % self.tp_size == 0
    self.tp_heads = self.total_num_heads // self.tp_size
    self.qkv_size = self.num_heads * self.head_dim
    self.tp_hidden = self.head_dim * self.tp_heads
    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix

    self.qkv_proj = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size * 3,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.output_gate = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.output_gate",
    )
    self.out_proj = RowParallelLinear(
        self.hidden_inner_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )
    self.norm = MiniMaxText01RMSNormTP(
        self.hidden_inner_size,
        eps=1e-5,
    )

    slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
        self.num_heads)
    if num_hidden_layer <= 1:
        self.slope_rate = slope_rate * (1 + 1e-5)
    else:
        self.slope_rate = slope_rate * (1 - layer_idx /
                                        (num_hidden_layer - 1) + 1e-5)
    self.tp_slope = self.slope_rate[self.tp_rank *
                                    self.tp_heads:(self.tp_rank + 1) *
                                    self.tp_heads].contiguous()

    if envs.VLLM_USE_V1:
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
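
The factor multiplying the base slopes decays linearly with depth, so deeper layers get a smaller effective decay rate. A minimal sketch of that factor, using a hypothetical num_hidden_layer = 4 purely for illustration:

# Illustrative only: depth-dependent scaling of the base slopes.
num_hidden_layer = 4
for layer_idx in range(num_hidden_layer):
    scale = 1 - layer_idx / (num_hidden_layer - 1) + 1e-5
    print(layer_idx, round(scale, 5))
# prints 1.00001, 0.66668, 0.33334, 1e-05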

_build_slope_tensor staticmethod

_build_slope_tensor(n_attention_heads: int)
Source code in vllm/model_executor/layers/mamba/linear_attn.py
@staticmethod
def _build_slope_tensor(n_attention_heads: int):

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = 2**(-(2**-(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    slopes = torch.tensor(get_slopes(n_attention_heads),
                          dtype=torch.float32).reshape(
                              n_attention_heads, 1, 1)
    return slopes
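
For a power-of-two head count these are the ALiBi-style slopes, i.e. a geometric series with closed form 2 ** (-8 * (i + 1) / n). A quick sanity check of that reading (an illustrative sketch, not part of the module):

import math

n = 8  # hypothetical head count (power of two)
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
from_helper = [start * start ** i for i in range(n)]
closed_form = [2 ** (-8 * (i + 1) / n) for i in range(n)]
assert all(math.isclose(a, b) for a, b in zip(from_helper, closed_form))
# both give [0.5, 0.25, 0.125, ..., 0.00390625] for n = 8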

_decode_infer

_decode_infer(
    q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                  attn_metadata):
    if not envs.VLLM_USE_V1:
        q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        num_prefills = getattr(attn_metadata, "num_prefills", 0)
        slot_id = state_indices_tensor[num_prefills:]
    else:
        q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        slot_id = state_indices_tensor[:attn_metadata.num_decodes]
    hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                          slot_id, 32)
    return hidden

_forward

_forward(
    hidden_states: Tensor,
    output: Tensor,
    positions: Tensor,
    kv_caches: Optional[MinimaxCacheParams],
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
             positions: torch.Tensor,
             kv_caches: Optional[MinimaxCacheParams]) -> None:
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata
    if envs.VLLM_USE_V1 and attn_metadata is not None:
        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, LinearAttentionMetadata)
        num_actual_tokens = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        num_actual_tokens = hidden_states.shape[0]

    qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
    qkv32 = qkv.to(torch.float32)
    qkvact = torch.nn.functional.silu(qkv32)
    qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
    q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
    if envs.VLLM_USE_V1:
        if attn_metadata is not None:
            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
            state_indices_tensor = attn_metadata.state_indices_tensor

            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            if num_prefills > 0:
                num_decode_tokens = getattr(attn_metadata,
                                            "num_decode_tokens", 0)
                for prefill_idx in range(num_prefills):
                    q_start = attn_metadata.query_start_loc[
                        num_decode_tokens + prefill_idx]
                    q_end = attn_metadata.query_start_loc[num_decode_tokens
                                                          + prefill_idx +
                                                          1]
                    query_len = q_end - q_start
                    context_len = attn_metadata.seq_lens[
                        num_decode_tokens + prefill_idx] - query_len
                    if context_len == 0:
                        block_to_clear = state_indices_tensor[
                            num_decode_tokens + prefill_idx]
                        kv_cache[block_to_clear, ...] = 0
    else:
        assert kv_caches is not None
        kv_cache = kv_caches.minimax_cache
        state_indices_tensor = kv_caches.state_indices_tensor

    decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
    if attn_metadata is None:
        hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
                             device=q.device,
                             dtype=q.dtype)
    else:
        if not decode_only:
            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                                 state_indices_tensor,
                                                 attn_metadata)
        else:
            hidden = self._decode_infer(q, k, v, kv_cache,
                                        state_indices_tensor,
                                        attn_metadata)
    hidden = self.norm._forward(hidden)
    gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
    hidden = F.sigmoid(gate) * hidden
    hidden = hidden.to(hidden_states.dtype)

    output[:num_actual_tokens], _ = self.out_proj(hidden)
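
After the linear-attention step, _forward normalizes the per-head outputs with the TP-aware RMSNorm, gates them with a sigmoid of output_gate(hidden_states), and projects back to the hidden size via out_proj. A minimal single-GPU sketch of just that gating math (shapes are made up; tp_size == 1 is assumed so the plain RMSNorm formula applies):

import torch

tokens, heads, head_dim = 4, 8, 16
attn_out = torch.randn(tokens, heads * head_dim)    # stand-in for the linear-attention output
gate = torch.randn(tokens, heads * head_dim)        # stand-in for output_gate(hidden_states)
norm_weight = torch.ones(heads * head_dim)

variance = attn_out.float().pow(2).mean(dim=-1, keepdim=True)
normed = attn_out.float() * torch.rsqrt(variance + 1e-5) * norm_weight
gated = torch.sigmoid(gate) * normed                # F.sigmoid(gate) * hidden in _forward
# out_proj then maps (tokens, heads * head_dim) back to (tokens, hidden_size)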

_prefill_and_mix_infer

_prefill_and_mix_infer(
    q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                           attn_metadata):
    hidden = []
    for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
        if _prefill_idx >= len(attn_metadata.query_start_loc):
            break
        if _prefill_idx >= len(state_indices_tensor):
            break
        # prefills are packed at end of batch in V1
        offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
        _start = attn_metadata.query_start_loc[offset + _prefill_idx]
        _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
        slot_id = state_indices_tensor[offset + _prefill_idx]
        qs = q[_start:_end].transpose(0, 1).contiguous()
        ks = k[_start:_end].transpose(0, 1).contiguous()
        vs = v[_start:_end].transpose(0, 1).contiguous()
        slice_layer_cache = kv_cache[slot_id, ...]

        out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
            qs,
            ks,
            vs,
            slice_layer_cache,
            self.tp_slope,
            self.BLOCK,
            layer_idx=self.layer_idx)
        hidden.append(out_slice.contiguous())
    if attn_metadata.num_decode_tokens > 0:
        hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                           state_indices_tensor,
                                           attn_metadata)
        if envs.VLLM_USE_V1:
            hidden.insert(0, hidden_decode)
        else:
            hidden.append(hidden_decode)

    if not hidden:
        return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

    hidden = torch.concat(hidden, dim=0).contiguous()
    return hidden
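
The offset arithmetic above relies on the V1 batch layout: decode tokens are packed first (one token per decode request), followed by the prefill requests, which is also why hidden_decode is inserted at the front of the list before concatenation. A small worked example of the indexing, under that one-token-per-decode assumption:

# Hypothetical batch: 2 decode requests, then prefills of length 3 and 5.
num_decode_tokens = 2
query_start_loc = [0, 1, 2, 5, 10]   # cumulative per-request token counts

offset = num_decode_tokens
# prefill 0 spans rows query_start_loc[offset + 0] : query_start_loc[offset + 1] -> 2:5
# prefill 1 spans rows query_start_loc[offset + 1] : query_start_loc[offset + 2] -> 5:10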

forward

forward(
    hidden_states: Tensor,
    output: Tensor,
    positions: Tensor,
    kv_caches: MinimaxCacheParams,
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: MinimaxCacheParams) -> None:
    if not envs.VLLM_USE_V1:
        self._forward(hidden_states, output, positions, kv_caches)
    else:
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def get_attn_backend(self) -> type["AttentionBackend"]:
    from vllm.v1.attention.backends.linear_attn import (
        LinearAttentionBackend)
    return LinearAttentionBackend

get_state_dtype

get_state_dtype() -> tuple[dtype]
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def get_state_dtype(self) -> tuple[torch.dtype]:
    assert self.model_config is not None
    assert self.cache_config is not None
    return MambaStateDtypeCalculator.linear_attention_state_dtype(
        self.model_config.dtype,
        self.cache_config.mamba_cache_dtype,
    )

get_state_shape

get_state_shape() -> tuple[tuple[int, int, int], ...]
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
    return MambaStateShapeCalculator.linear_attention_state_shape(
        num_heads=self.num_heads,
        tp_size=self.tp_size,
        head_dim=self.head_dim)
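
The exact shape is delegated to MambaStateShapeCalculator; based on how the prefill kernel reshapes each cache slot to (1, h, d, e) with e == d, the per-rank state is presumably a (num_heads // tp_size, head_dim, head_dim) decayed KV accumulator per slot. A hedged sketch of that expectation:

# Assumed shape only, inferred from the kernel's reshape; the authoritative
# value comes from MambaStateShapeCalculator.linear_attention_state_shape.
num_heads, tp_size, head_dim = 64, 8, 128            # hypothetical config
expected = (num_heads // tp_size, head_dim, head_dim)  # (8, 128, 128) per slot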

weight_direct_load staticmethod

weight_direct_load(
    param: Tensor, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
@staticmethod
def weight_direct_load(param: torch.Tensor,
                       loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
    return

MiniMaxText01LinearKernel

Source code in vllm/model_executor/layers/mamba/linear_attn.py
class MiniMaxText01LinearKernel:

    @staticmethod
    def jit_linear_forward_prefix(q: torch.Tensor,
                                  k: torch.Tensor,
                                  v: torch.Tensor,
                                  kv_caches: torch.Tensor,
                                  slope_rate: torch.Tensor,
                                  block_size: int,
                                  layer_idx: Optional[int] = None,
                                  **kwargs) -> torch.Tensor:

        slope_rate = slope_rate.to(torch.float32)
        should_pad_dim = q.dim() == 3
        if should_pad_dim:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        b, h, n, d = q.shape
        e = d
        kv_history = kv_caches.reshape(1, h, d, e).contiguous()
        output, kv_history = lightning_attention(q,
                                                 k,
                                                 v,
                                                 slope_rate,
                                                 block_size=block_size,
                                                 kv_history=kv_history)
        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
        assert output.shape[0] == 1, "batch size must be 1"
        return rearrange(output.squeeze(0), "h n d -> n (h d)")

jit_linear_forward_prefix staticmethod

jit_linear_forward_prefix(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kv_caches: Tensor,
    slope_rate: Tensor,
    block_size: int,
    layer_idx: Optional[int] = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/layers/mamba/linear_attn.py
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
                              k: torch.Tensor,
                              v: torch.Tensor,
                              kv_caches: torch.Tensor,
                              slope_rate: torch.Tensor,
                              block_size: int,
                              layer_idx: Optional[int] = None,
                              **kwargs) -> torch.Tensor:

    slope_rate = slope_rate.to(torch.float32)
    should_pad_dim = q.dim() == 3
    if should_pad_dim:
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
    b, h, n, d = q.shape
    e = d
    kv_history = kv_caches.reshape(1, h, d, e).contiguous()
    output, kv_history = lightning_attention(q,
                                             k,
                                             v,
                                             slope_rate,
                                             block_size=block_size,
                                             kv_history=kv_history)
    kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
    assert output.shape[0] == 1, "batch size must be 1"
    return rearrange(output.squeeze(0), "h n d -> n (h d)")
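
lightning_attention processes the sequence block-wise with an exponentially decaying per-head state. An unoptimized per-token reference of that recurrence, written as a sketch under the assumption that the kernel implements standard decayed linear attention with the slope acting as the per-head decay rate:

import torch

def decayed_linear_attention_reference(q, k, v, slope, kv_state):
    # q, k, v: (seq_len, head_dim); slope: per-head decay rate; kv_state: (head_dim, head_dim)
    decay = torch.exp(torch.tensor(-float(slope)))
    outs = []
    for t in range(q.shape[0]):
        kv_state = decay * kv_state + torch.outer(k[t], v[t])  # accumulate decayed KV outer products
        outs.append(q[t] @ kv_state)                           # read out with the query
    return torch.stack(outs), kv_state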

MiniMaxText01RMSNormTP

Bases: CustomOp

Source code in vllm/model_executor/layers/mamba/linear_attn.py
class MiniMaxText01RMSNormTP(CustomOp):
    name = "MiniMaxText01RMSNormTP"

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.weight = nn.Parameter(torch.ones(int(hidden_size /
                                                  self.tp_world)))

        self.weight.weight_loader = self.weight_loader
        self.variance_epsilon = eps
        return

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        tp_world = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        shard_size = loaded_weight.shape[0] // tp_world
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        param.data.copy_(loaded_weight[shard])
        return

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            variance = tensor_model_parallel_all_reduce(
                variance) / self.tp_world
        x = x * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.weight
        if x.size(-1) != self.weight.size(0):
            if self.weight.size(0) < x.size(-1):
                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
                full_weight = self.weight.repeat(repeat_count)
                weight = full_weight[:x.size(-1)]
            else:
                weight = self.weight[:x.size(-1)]

        x = x.to(orig_dtype) * weight
        return x

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)

name class-attribute instance-attribute

name = 'MiniMaxText01RMSNormTP'

tp_rank instance-attribute

tp_world instance-attribute

variance_epsilon instance-attribute

variance_epsilon = eps

weight instance-attribute

weight = Parameter(ones(int(hidden_size / tp_world)))

__init__

__init__(hidden_size: int, eps: float = 1e-06) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
    super().__init__()
    self.tp_world = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.weight = nn.Parameter(torch.ones(int(hidden_size /
                                              self.tp_world)))

    self.weight.weight_loader = self.weight_loader
    self.variance_epsilon = eps
    return

_forward

_forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def _forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
    if self.tp_world > 1:
        variance = tensor_model_parallel_all_reduce(
            variance) / self.tp_world
    x = x * torch.rsqrt(variance + self.variance_epsilon)

    weight = self.weight
    if x.size(-1) != self.weight.size(0):
        if self.weight.size(0) < x.size(-1):
            repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
            full_weight = self.weight.repeat(repeat_count)
            weight = full_weight[:x.size(-1)]
        else:
            weight = self.weight[:x.size(-1)]

    x = x.to(orig_dtype) * weight
    return x
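
Because each rank only holds hidden_size / tp_world features, the per-rank variance is all-reduced and divided by tp_world; for equal shard sizes this recovers the mean square over the full hidden dimension. A quick check of that identity (illustrative only, no distributed setup needed):

import torch

x = torch.randn(4, 8)
full = x.pow(2).mean(dim=-1)
shard_a, shard_b = x.chunk(2, dim=-1)                 # pretend tp_world = 2
per_shard_sum = shard_a.pow(2).mean(dim=-1) + shard_b.pow(2).mean(dim=-1)
assert torch.allclose(full, per_shard_sum / 2)        # all-reduce then divide by tp_world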

forward

forward(
    x: Tensor, residual: Optional[Tensor] = None
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def forward(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    assert residual is None, "RMSNorm does not support residual connection."
    return self._forward(x)

weight_loader staticmethod

weight_loader(
    param: Parameter, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
@staticmethod
def weight_loader(
    param: nn.Parameter,
    loaded_weight: torch.Tensor,
) -> None:
    tp_world = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    shard_size = loaded_weight.shape[0] // tp_world
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    param.data.copy_(loaded_weight[shard])
    return

linear_attention

linear_attention(
    hidden_states: Tensor,
    output: Tensor,
    positions: Tensor,
    layer_name: str,
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def linear_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self._forward(hidden_states=hidden_states,
                  output=output,
                  positions=positions,
                  kv_caches=None)

linear_attention_fake

linear_attention_fake(
    hidden_states: Tensor,
    output: Tensor,
    positions: Tensor,
    layer_name: str,
) -> None
Source code in vllm/model_executor/layers/mamba/linear_attn.py
def linear_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    return