Skip to content

vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope

Ernie4_5_VLRotaryEmbedding

Bases: MRotaryEmbedding

3D rotary positional embedding. The three dimensions are t (time), h (height), and w (width).

Source code in vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
    """Rotary positional embedding over three axes: t=time, h=height, w=width."""

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Rotate ``query`` and ``key`` according to ``positions``.

        Args:
            positions: 1-D or 2-D index tensor into ``self.cos_sin_cache``.
                When 2-D, rows 0, 1 and 2 of the leading axis supply the
                t, h and w components respectively; its last axis is the
                token count.
            query: Tensor viewable as ``(num_tokens, -1, head_size)``.
            key: Tensor viewable as ``(num_tokens, -1, head_size)``; despite
                the ``Optional`` annotation it must not be ``None``.

        Returns:
            ``(query, key)`` with the first ``rotary_dim`` features of every
            head rotated and the remainder untouched, each reshaped back to
            the shape it was passed in with.
        """
        assert positions.ndim in (1, 2)
        assert key is not None

        token_count = positions.shape[-1]
        # The cache holds cos in the first half and sin in the second half
        # of its last axis.
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section

            h_width = self.mrope_section[0]  # e.g. 22
            w_width = self.mrope_section[1]  # e.g. 22
            t_width = self.mrope_section[2]  # e.g. 20
            assert h_width == w_width
            hw_end = h_width + w_width

            def _regroup(table: torch.Tensor) -> torch.Tensor:
                # The last axis is laid out as [h w h w ... t t t]: even
                # slots of the leading span belong to h, odd slots to w,
                # and the trailing span to t.
                t_slice = table[..., -t_width:]
                h_slice = table[..., 0:hw_end:2]
                w_slice = table[..., 1:hw_end:2]
                # Each component is read from its own row of positions.
                from_t, from_h, from_w = t_slice[0], h_slice[1], w_slice[2]
                # Re-interleave h and w, then append t.
                paired = torch.stack([from_h, from_w], dim=-1)
                interleaved = paired.reshape(
                    from_h.shape[:-1] + (from_h.shape[-1] * 2, ))
                return torch.cat([interleaved, from_t], dim=-1)

            cos = _regroup(cos)
            sin = _regroup(sin)

        def _rotate(states: torch.Tensor) -> torch.Tensor:
            # Rotate the leading rotary_dim features of each head and glue
            # the untouched tail back on, restoring the caller's shape.
            original_shape = states.shape
            per_head = states.view(token_count, -1, self.head_size)
            rotary_part = per_head[..., :self.rotary_dim]
            passthrough = per_head[..., self.rotary_dim:]
            rotary_part = apply_rotary_emb_dispatch(rotary_part, cos, sin,
                                                    self.is_neox_style)
            merged = torch.cat((rotary_part, passthrough), dim=-1)
            return merged.reshape(original_shape)

        query = _rotate(query)
        key = _rotate(key)
        return query, key

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Rotate ``query`` and ``key`` according to ``positions``.

    Args:
        positions: 1-D or 2-D index tensor into ``self.cos_sin_cache``.
            When 2-D, rows 0, 1 and 2 of the leading axis supply the
            t, h and w components respectively; its last axis is the
            token count.
        query: Tensor viewable as ``(num_tokens, -1, head_size)``.
        key: Tensor viewable as ``(num_tokens, -1, head_size)``; despite
            the ``Optional`` annotation it must not be ``None``.

    Returns:
        ``(query, key)`` with the first ``rotary_dim`` features of every
        head rotated and the remainder untouched, each reshaped back to
        the shape it was passed in with.
    """
    assert positions.ndim in (1, 2)
    assert key is not None

    token_count = positions.shape[-1]
    # The cache holds cos in the first half and sin in the second half of
    # its last axis.
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

    if positions.ndim == 2:
        assert self.mrope_section

        h_width = self.mrope_section[0]  # e.g. 22
        w_width = self.mrope_section[1]  # e.g. 22
        t_width = self.mrope_section[2]  # e.g. 20
        assert h_width == w_width
        hw_end = h_width + w_width

        def _regroup(table: torch.Tensor) -> torch.Tensor:
            # The last axis is laid out as [h w h w ... t t t]: even slots
            # of the leading span belong to h, odd slots to w, and the
            # trailing span to t.
            t_slice = table[..., -t_width:]
            h_slice = table[..., 0:hw_end:2]
            w_slice = table[..., 1:hw_end:2]
            # Each component is read from its own row of positions.
            from_t, from_h, from_w = t_slice[0], h_slice[1], w_slice[2]
            # Re-interleave h and w, then append t.
            paired = torch.stack([from_h, from_w], dim=-1)
            interleaved = paired.reshape(
                from_h.shape[:-1] + (from_h.shape[-1] * 2, ))
            return torch.cat([interleaved, from_t], dim=-1)

        cos = _regroup(cos)
        sin = _regroup(sin)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        # Rotate the leading rotary_dim features of each head and glue the
        # untouched tail back on, restoring the caller's shape.
        original_shape = states.shape
        per_head = states.view(token_count, -1, self.head_size)
        rotary_part = per_head[..., :self.rotary_dim]
        passthrough = per_head[..., self.rotary_dim:]
        rotary_part = apply_rotary_emb_dispatch(rotary_part, cos, sin,
                                                self.is_neox_style)
        merged = torch.cat((rotary_part, passthrough), dim=-1)
        return merged.reshape(original_shape)

    query = _rotate(query)
    key = _rotate(key)
    return query, key