vllm.multimodal.cache

MultiModalProcessorCacheInItem module-attribute

MultiModalProcessorCacheInItem: TypeAlias = Optional[
    tuple[
        MultiModalKwargsItem,
        Sequence["ResolvedPromptUpdate"],
    ]
]

MultiModalProcessorCacheOutItem module-attribute

MultiModalProcessorCacheOutItem: TypeAlias = tuple[
    Optional[MultiModalKwargsItem],
    Sequence["ResolvedPromptUpdate"],
]

_I module-attribute

_I = TypeVar('_I', contravariant=True)

_O module-attribute

_O = TypeVar('_O', covariant=True)

_V module-attribute

_V = TypeVar('_V', bound=MultiModalCacheValue)

logger module-attribute

logger = init_logger(__name__)

BaseMultiModalCache

Bases: ABC, Generic[_I, _O]

Abstract base class to read/write multi-modal items from cache.

The idea of multi-modal caching is based on having a client and server where the client executes in the frontend process (=P0) and the server in the core process (=P1). The data flow is as follows:

              is_cached() x N    get_and_update()
P0: From API -----------------> -----------------> To P1

             get_and_update()
P1: From P0 -----------------> To model

is_cached() can be called any number of times in P0. However, get_and_update() must be called in P0 and P1 one after another so that their cache eviction order remains the same.

This ensures that the keys in P0 and P1 caches are mirrored, allowing us to determine whether a key is cached in P1 by looking up the P0 cache, without having to communicate with P1.
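
For illustration, the sketch below implements this contract with a hypothetical dict-backed subclass (ToyCache is not part of vLLM). It shows the mirrored call order: P0 and P1 each call get_and_update() on their own cache with the same items in the same order, which is what keeps the two caches in sync.

# Hypothetical toy subclass, for illustration only (not part of vLLM).
from typing import Optional

from vllm.multimodal.cache import BaseMultiModalCache


class ToyCache(BaseMultiModalCache[Optional[str], str]):

    def __init__(self) -> None:
        self._cache: dict[str, str] = {}

    def get_and_update_item(self, mm_item: Optional[str],
                            mm_hash: str) -> str:
        if (cached := self._cache.get(mm_hash)) is not None:
            return cached  # Hit: reuse the cached value.
        assert mm_item is not None
        self._cache[mm_hash] = mm_item  # Miss: populate the cache.
        return mm_item

    def clear_cache(self) -> None:
        self._cache.clear()


# P0 and P1 each hold their own cache. Calling get_and_update() with the
# same items in the same order on both sides keeps their contents (and,
# for size-bounded caches, their eviction order) mirrored.
p0_cache, p1_cache = ToyCache(), ToyCache()
hashes = ["hash-a", "hash-b"]
sent = p0_cache.get_and_update(["item-a", "item-b"], hashes)  # P0
received = p1_cache.get_and_update(sent, hashes)              # P1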

Source code in vllm/multimodal/cache.py
class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    The idea of multi-modal caching is based on having a client and server
    where the client executes in the frontend process (=P0) and
    the server in the core process (=P1). The data flow is as follows:

    ```
                  is_cached() x N    get_and_update()
    P0: From API -----------------> -----------------> To P1

                 get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This ensures that the keys in P0 and P1 caches are mirrored,
    allowing us to determine whether a key is cached in P1 by looking
    up the P0 cache, without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether they
        are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        return [
            self.get_and_update_item(mm_item, mm_hash)
            for mm_item, mm_hash in zip(mm_items, mm_hashes)
        ]

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError

clear_cache abstractmethod

clear_cache() -> None

Clear the underlying cache.

Source code in vllm/multimodal/cache.py
@abstractmethod
def clear_cache(self) -> None:
    """Clear the underlying cache."""
    raise NotImplementedError

get_and_update

get_and_update(
    mm_items: Sequence[_I], mm_hashes: list[str]
) -> list[_O]

Possibly update a sequence of multi-modal items based on whether they are in the underlying cache.

This update is done out-of-place and updates the cache eviction order.

Parameters:

Name Type Description Default
mm_items Sequence[_I]

The multi-modal items to update.

required
mm_hashes list[str]

The hash of each item in mm_items.

required

Returns:

Type Description
list[_O]

A new list of updated multi-modal items.

Source code in vllm/multimodal/cache.py
def get_and_update(
    self,
    mm_items: Sequence[_I],
    mm_hashes: list[str],
) -> list[_O]:
    """
    Possibly update a sequence of multi-modal items based on whether they
    are in the underlying cache.

    This update is done out-of-place and updates the cache eviction order.

    Args:
        mm_items: The multi-modal items to update.
        mm_hashes: The hash of each item in `mm_items`.

    Returns:
        A new list of updated multi-modal items.
    """
    assert len(mm_items) == len(mm_hashes)

    return [
        self.get_and_update_item(mm_item, mm_hash)
        for mm_item, mm_hash in zip(mm_items, mm_hashes)
    ]

get_and_update_item abstractmethod

get_and_update_item(mm_item: _I, mm_hash: str) -> _O

Possibly update a multi-modal item based on whether it is in the underlying cache.

This update is done out-of-place and updates the cache eviction order.

Parameters:

Name Type Description Default
mm_item _I

The multi-modal item to update.

required
mm_hash str

The hash of mm_item.

required

Returns:

Type Description
_O

The updated multi-modal item.

Source code in vllm/multimodal/cache.py
@abstractmethod
def get_and_update_item(
    self,
    mm_item: _I,
    mm_hash: str,
) -> _O:
    """
    Possibly update a multi-modal item based on whether it is
    in the underlying cache.

    This update is done out-of-place and updates the cache eviction order.

    Args:
        mm_item: The multi-modal item to update.
        mm_hash: The hash of `mm_item`.

    Returns:
        The updated multi-modal item.
    """
    raise NotImplementedError

BaseMultiModalProcessorCache

Bases: BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]

The required interface for caches on P0.

Source code in vllm/multimodal/cache.py
class BaseMultiModalProcessorCache(
        BaseMultiModalCache[MultiModalProcessorCacheInItem,
                            MultiModalProcessorCacheOutItem]):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a multi-modal item is
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether each item in a sequence of multi-modal items
        is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]

is_cached

is_cached(mm_hashes: list[str]) -> list[bool]

Check whether each item in a sequence of multi-modal items is in the underlying cache.

This DOES NOT update the cache eviction order.

Parameters:

Name Type Description Default
mm_hashes list[str]

The hash of each item to check.

required

Returns:

Type Description
list[bool]

For each item, True if the item is cached, otherwise False.

Source code in vllm/multimodal/cache.py
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
    """
    Check whether each item in a sequence of multi-modal items
    is in the underlying cache.

    This **DOES NOT** update the cache eviction order.

    Args:
        mm_hashes: The hash of each item to check.

    Returns:
        For each item, `True` if the item is cached, otherwise `False`.
    """
    return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]

is_cached_item abstractmethod

is_cached_item(mm_hash: str) -> bool

Check whether a multi-modal item is in the underlying cache.

This DOES NOT update the cache eviction order.

Parameters:

Name Type Description Default
mm_hash str

The hash of the item to check.

required

Returns:

Type Description
bool

True if the item is cached, otherwise False.

Source code in vllm/multimodal/cache.py
@abstractmethod
def is_cached_item(self, mm_hash: str) -> bool:
    """
    Check whether a multi-modal item is
    in the underlying cache.

    This **DOES NOT** update the cache eviction order.

    Args:
        mm_hash: The hash of the item to check.

    Returns:
        `True` if the item is cached, otherwise `False`.
    """
    raise NotImplementedError

BaseMultiModalReceiverCache

Bases: BaseMultiModalCache[Optional[MultiModalKwargsItem], MultiModalKwargsItem]

The required interface for caches on P1.

Source code in vllm/multimodal/cache.py
class BaseMultiModalReceiverCache(
        BaseMultiModalCache[Optional[MultiModalKwargsItem],
                            MultiModalKwargsItem]):
    """The required interface for caches on P1."""

    def get_and_update_features(
        self,
        mm_features: list["MultiModalFeatureSpec"],
    ) -> list["MultiModalFeatureSpec"]:
        """Update multimodal features with cached encoder outputs."""
        for feature in mm_features:
            feature.data = self.get_and_update_item(feature.data,
                                                    feature.identifier)
        return mm_features

get_and_update_features

get_and_update_features(
    mm_features: list[MultiModalFeatureSpec],
) -> list[MultiModalFeatureSpec]

Update multimodal features with cached encoder outputs.

Source code in vllm/multimodal/cache.py
def get_and_update_features(
    self,
    mm_features: list["MultiModalFeatureSpec"],
) -> list["MultiModalFeatureSpec"]:
    """Update multimodal features with cached encoder outputs."""
    for feature in mm_features:
        feature.data = self.get_and_update_item(feature.data,
                                                feature.identifier)
    return mm_features

MultiModalCache

Source code in vllm/multimodal/cache.py
class MultiModalCache:

    @classmethod
    def get_leaf_size(
        cls,
        leaf: object,
        *,
        debug: bool = False,
    ) -> int:
        if isinstance(leaf, MultiModalProcessorCacheItem):
            return cls.get_leaf_size(leaf.item)
        if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
            return leaf.item_size

        # These are not subclasses of dict
        if isinstance(leaf, MultiModalKwargsItems):
            return cls.get_item_size(leaf.data)  # type: ignore
        if isinstance(leaf, MultiModalKwargsItem):
            return cls.get_item_size(leaf.data)  # type: ignore
        if isinstance(leaf, MultiModalKwargs):
            return cls.get_item_size(leaf.data)  # type: ignore

        if isinstance(leaf, MultiModalFieldElem):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        size = json_reduce_leaves(
            lambda a, b: a + b,
            json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug),
                            value),
        )

        if debug:
            leaf_count = json_count_leaves(value)
            logger.debug(
                "Calculated size of %s to be %.2f GiB (%d leaves)",
                type(value),
                size / GiB_bytes,
                leaf_count,
            )

        return size

    @classmethod
    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
        """
        Get the number of leaf elements in a multi-modal cache value.

        This provides a measure of structural complexity that can be useful
        for debugging cache performance and understanding data patterns.

        Args:
            value: The multi-modal cache value to analyze.

        Returns:
            The number of leaf elements in the nested structure.
        """
        return json_count_leaves(value)

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )

get_item_complexity classmethod

get_item_complexity(value: MultiModalCacheValue) -> int

Get the number of leaf elements in a multi-modal cache value.

This provides a measure of structural complexity that can be useful for debugging cache performance and understanding data patterns.

Parameters:

Name Type Description Default
value MultiModalCacheValue

The multi-modal cache value to analyze.

required

Returns:

Type Description
int

The number of leaf elements in the nested structure.

Source code in vllm/multimodal/cache.py
@classmethod
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
    """
    Get the number of leaf elements in a multi-modal cache value.

    This provides a measure of structural complexity that can be useful
    for debugging cache performance and understanding data patterns.

    Args:
        value: The multi-modal cache value to analyze.

    Returns:
        The number of leaf elements in the nested structure.
    """
    return json_count_leaves(value)
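
A small usage sketch, assuming a plain nested dict of tensors as a stand-in for a real multi-modal cache value:

# Illustrative only: a plain nested dict of tensors stands in for a
# real multi-modal cache value.
import torch

from vllm.multimodal.cache import MultiModalCache

value = {
    "pixel_values": torch.zeros(3, 224, 224),
    "image_grid_thw": torch.tensor([[1, 16, 16]]),
}

# Number of leaf elements in the nested structure (here: 2 tensors).
num_leaves = MultiModalCache.get_item_complexity(value)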

get_item_size classmethod

get_item_size(
    value: MultiModalCacheValue, *, debug: bool = False
) -> int
Source code in vllm/multimodal/cache.py
@classmethod
def get_item_size(
    cls,
    value: MultiModalCacheValue,
    *,
    debug: bool = False,
) -> int:
    size = json_reduce_leaves(
        lambda a, b: a + b,
        json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug),
                        value),
    )

    if debug:
        leaf_count = json_count_leaves(value)
        logger.debug(
            "Calculated size of %s to be %.2f GiB (%d leaves)",
            type(value),
            size / GiB_bytes,
            leaf_count,
        )

    return size
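
A sketch of estimating an item's memory footprint under the same assumption (a plain dict of tensors as the value); tensor leaves are sized via Tensor.nbytes and other leaves via sys.getsizeof, as shown in get_leaf_size below:

import torch

from vllm.multimodal.cache import MultiModalCache

value = {"pixel_values": torch.zeros(3, 224, 224, dtype=torch.float32)}

# About 0.57 MiB for the single tensor leaf (3 * 224 * 224 * 4 bytes).
size_in_bytes = MultiModalCache.get_item_size(value)

# With debug=True, the computed size and leaf count are also logged.
MultiModalCache.get_item_size(value, debug=True)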

get_leaf_size classmethod

get_leaf_size(leaf: object, *, debug: bool = False) -> int
Source code in vllm/multimodal/cache.py
@classmethod
def get_leaf_size(
    cls,
    leaf: object,
    *,
    debug: bool = False,
) -> int:
    if isinstance(leaf, MultiModalProcessorCacheItem):
        return cls.get_leaf_size(leaf.item)
    if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
        return leaf.item_size

    # These are not subclasses of dict
    if isinstance(leaf, MultiModalKwargsItems):
        return cls.get_item_size(leaf.data)  # type: ignore
    if isinstance(leaf, MultiModalKwargsItem):
        return cls.get_item_size(leaf.data)  # type: ignore
    if isinstance(leaf, MultiModalKwargs):
        return cls.get_item_size(leaf.data)  # type: ignore

    if isinstance(leaf, MultiModalFieldElem):
        return cls.get_item_size(leaf.data)  # type: ignore

    # sys.getsizeof doesn't work for tensors
    if isinstance(leaf, torch.Tensor):
        return leaf.nbytes

    return sys.getsizeof(leaf)

get_lru_cache classmethod

get_lru_cache(
    capacity_gb: float,
    value_type: type[_V],
    *,
    debug: bool = False,
) -> LRUCache[str, _V]
Source code in vllm/multimodal/cache.py
@classmethod
def get_lru_cache(
    cls,
    capacity_gb: float,
    value_type: type[_V],
    *,
    debug: bool = False,
) -> LRUCache[str, _V]:
    return LRUCache(
        GiB_bytes * capacity_gb,
        getsizeof=lambda x: cls.get_item_size(x, debug=debug),
    )
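
A sketch of building a size-bounded cache with this helper. The dict-style assignment, get() and clear() calls mirror how the caches in this module use the returned LRUCache; value_type is only used for typing in the source above:

import torch

from vllm.multimodal.cache import MultiModalCache

# 0.25 GiB of total capacity; stored values are sized via get_item_size().
cache = MultiModalCache.get_lru_cache(0.25, dict)

cache["mm_hash_0"] = {"pixel_values": torch.zeros(3, 224, 224)}

cached = cache.get("mm_hash_0")  # Hit: returns the stored value.
cache.clear()                    # Drop everything, as clear_cache() does.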

MultiModalProcessorCacheItem

The data to store inside MultiModalProcessorOnlyCache.

Parameters:

Name Type Description Default
item MultiModalKwargsItem

The processed tensor data corresponding to a multi-modal item.

required
prompt_updates Sequence[ResolvedPromptUpdate]

The prompt updates corresponding to item.

required
Source code in vllm/multimodal/cache.py
class MultiModalProcessorCacheItem:
    """
    The data to store inside `MultiModalProcessorOnlyCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
        prompt_updates: The prompt updates corresponding to `item`.
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.item = item
        self.prompt_updates = prompt_updates

item instance-attribute

item = item

prompt_updates instance-attribute

prompt_updates = prompt_updates

__init__

__init__(
    item: MultiModalKwargsItem,
    prompt_updates: Sequence[ResolvedPromptUpdate],
) -> None
Source code in vllm/multimodal/cache.py
def __init__(
    self,
    item: MultiModalKwargsItem,
    prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
    super().__init__()

    self.item = item
    self.prompt_updates = prompt_updates

MultiModalProcessorCacheItemMetadata

The metadata to store inside MultiModalProcessorSenderCache.

Parameters:

Name Type Description Default
item MultiModalKwargsItem

The processed tensor data corresponding to a multi-modal item. Since P1 already stores the tensor data, we only store its size metadata in P0 to reduce memory usage. The size metadata is still needed to keep the same cache eviction policy as P1.

required
prompt_updates Sequence[ResolvedPromptUpdate]

The prompt updates corresponding to item. This needs to stay on P0 because for some models, they are dependent on the processed tensor data (cached on P1).

required
Source code in vllm/multimodal/cache.py
class MultiModalProcessorCacheItemMetadata:
    """
    The metadata to store inside `MultiModalProcessorSenderCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            Since P1 already stores the tensor data, we only store its size
            metadata in P0 to reduce memory usage. The size metadata is still
            needed to keep the same cache eviction policy as P1.
        prompt_updates: The prompt updates corresponding to `item`.
            This needs to stay on P0 because for some models, they are
            dependent on the processed tensor data (cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.item_size = MultiModalCache.get_item_size(item)
        self.prompt_updates = prompt_updates

item_size instance-attribute

item_size = get_item_size(item)

prompt_updates instance-attribute

prompt_updates = prompt_updates

__init__

__init__(
    item: MultiModalKwargsItem,
    prompt_updates: Sequence[ResolvedPromptUpdate],
) -> None
Source code in vllm/multimodal/cache.py
def __init__(
    self,
    item: MultiModalKwargsItem,
    prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
    super().__init__()

    self.item_size = MultiModalCache.get_item_size(item)
    self.prompt_updates = prompt_updates

MultiModalProcessorOnlyCache

Bases: BaseMultiModalProcessorCache

The cache which is used on P0 when IPC caching is disabled.

How to update each item:

  • If the item is in the cache, replace the input with the cached item.
  • If the item is not in the cache, store that item (which includes tensor data and metadata) into the cache, and return the input.
Source code in vllm/multimodal/cache.py
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes
      tensor data and metadata) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return cached_item.item, cached_item.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)

        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

_cache instance-attribute

_cache = get_lru_cache(
    mm_processor_cache_gb, MultiModalProcessorCacheItem
)

__init__

__init__(model_config: ModelConfig) -> None
Source code in vllm/multimodal/cache.py
def __init__(self, model_config: "ModelConfig") -> None:
    super().__init__()

    mm_config = model_config.get_multimodal_config()

    self._cache = MultiModalCache.get_lru_cache(
        mm_config.mm_processor_cache_gb,
        MultiModalProcessorCacheItem,
    )

clear_cache

clear_cache() -> None
Source code in vllm/multimodal/cache.py
@override
def clear_cache(self) -> None:
    self._cache.clear()

get_and_update_item

get_and_update_item(
    mm_item: MultiModalProcessorCacheInItem, mm_hash: str
) -> MultiModalProcessorCacheOutItem
Source code in vllm/multimodal/cache.py
@override
def get_and_update_item(
    self,
    mm_item: MultiModalProcessorCacheInItem,
    mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
    if (cached_item := self._cache.get(mm_hash)) is not None:
        return cached_item.item, cached_item.prompt_updates

    assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

    self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)

    return mm_item

is_cached_item

is_cached_item(mm_hash: str) -> bool
Source code in vllm/multimodal/cache.py
@override
def is_cached_item(self, mm_hash: str) -> bool:
    return mm_hash in self._cache

MultiModalProcessorSenderCache

Bases: BaseMultiModalProcessorCache

The cache which is used on P0 when IPC caching is enabled.

How to update each item:

  • If the item is already in the cache, clear the input to avoid unnecessary IPC.

  • If the item is not in the cache, store the metadata of that item so that the eviction policy remains the same as the cache on P1, and return the input. By only storing the metadata, we avoid keeping the data itself in memory inside P0.

Source code in vllm/multimodal/cache.py
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.

    - If the item is not in the cache, store the metadata of that item so
      that the eviction policy remains the same as the cache on P1,
      and return the input.
      By only storing the metadata, we avoid keeping the data itself in
      memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return None, cached_item.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)

        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

_cache instance-attribute

_cache = get_lru_cache(
    mm_processor_cache_gb,
    MultiModalProcessorCacheItemMetadata,
)

__init__

__init__(model_config: ModelConfig) -> None
Source code in vllm/multimodal/cache.py
def __init__(self, model_config: "ModelConfig") -> None:
    super().__init__()

    mm_config = model_config.get_multimodal_config()

    self._cache = MultiModalCache.get_lru_cache(
        mm_config.mm_processor_cache_gb,
        MultiModalProcessorCacheItemMetadata,
    )

clear_cache

clear_cache() -> None
Source code in vllm/multimodal/cache.py
@override
def clear_cache(self) -> None:
    self._cache.clear()

get_and_update_item

get_and_update_item(
    mm_item: MultiModalProcessorCacheInItem, mm_hash: str
) -> MultiModalProcessorCacheOutItem
Source code in vllm/multimodal/cache.py
@override
def get_and_update_item(
    self,
    mm_item: MultiModalProcessorCacheInItem,
    mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
    if (cached_item := self._cache.get(mm_hash)) is not None:
        return None, cached_item.prompt_updates

    assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

    self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)

    return mm_item

is_cached_item

is_cached_item(mm_hash: str) -> bool
Source code in vllm/multimodal/cache.py
@override
def is_cached_item(self, mm_hash: str) -> bool:
    return mm_hash in self._cache

MultiModalReceiverCache

Bases: BaseMultiModalReceiverCache

The cache which is used on P1 when IPC caching is enabled.

How to update each item:

  • If the item is in the cache, replace the input with the cached item.
  • If the item is not in the cache, store that item (which includes tensor data) into the cache, and return the input.
Source code in vllm/multimodal/cache.py
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 when IPC caching is enabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes tensor
      data) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: Optional[MultiModalKwargsItem],
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return cached_item

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
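
The sketch below traces a single item through a P0 MultiModalProcessorSenderCache and a P1 MultiModalReceiverCache when IPC caching is enabled. It is illustrative only: the caches are passed in as parameters (in vLLM they are built from the engine config via the factory helpers at the end of this module) and the IPC hop between the two calls is elided.

from vllm.multimodal.cache import (
    MultiModalProcessorCacheInItem,
    MultiModalProcessorSenderCache,
    MultiModalReceiverCache,
)


def send_and_receive_item(
    sender: MultiModalProcessorSenderCache,  # lives in P0
    receiver: MultiModalReceiverCache,       # lives in P1
    mm_item: MultiModalProcessorCacheInItem,
    mm_hash: str,
):
    # P0: on a hit, the tensor data is replaced with None so it is not
    # serialized over IPC; only the prompt updates are forwarded.
    item_for_ipc, prompt_updates = sender.get_and_update_item(mm_item, mm_hash)

    # P1: a None item is restored from the receiver's own cache; a new
    # item is stored so later requests with the same hash can be dropped
    # on P0 and restored here.
    full_item = receiver.get_and_update_item(item_for_ipc, mm_hash)

    return full_item, prompt_updates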

_cache instance-attribute

_cache = get_lru_cache(
    mm_processor_cache_gb, MultiModalKwargsItem
)

__init__

__init__(model_config: ModelConfig) -> None
Source code in vllm/multimodal/cache.py
def __init__(self, model_config: "ModelConfig") -> None:
    super().__init__()

    mm_config = model_config.get_multimodal_config()

    self._cache = MultiModalCache.get_lru_cache(
        mm_config.mm_processor_cache_gb,
        MultiModalKwargsItem,
    )

clear_cache

clear_cache() -> None
Source code in vllm/multimodal/cache.py
@override
def clear_cache(self) -> None:
    self._cache.clear()

get_and_update_item

get_and_update_item(
    mm_item: Optional[MultiModalKwargsItem], mm_hash: str
) -> MultiModalKwargsItem
Source code in vllm/multimodal/cache.py
@override
def get_and_update_item(
    self,
    mm_item: Optional[MultiModalKwargsItem],
    mm_hash: str,
) -> MultiModalKwargsItem:
    if (cached_item := self._cache.get(mm_hash)) is not None:
        return cached_item

    assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

    self._cache[mm_hash] = mm_item
    return mm_item

_enable_ipc_cache

_enable_ipc_cache(vllm_config: VllmConfig) -> bool
Source code in vllm/multimodal/cache.py
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    parallel_config = vllm_config.parallel_config
    supports_ipc_cache = (parallel_config.data_parallel_size == 1
                          or parallel_config.data_parallel_external_lb)

    return supports_ipc_cache

_enable_processor_cache

_enable_processor_cache(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
) -> bool
Source code in vllm/multimodal/cache.py
def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    if not mm_registry.supports_multimodal_inputs(model_config):
        return False

    mm_config = model_config.get_multimodal_config()
    return mm_config.mm_processor_cache_gb > 0

processor_cache_from_config

processor_cache_from_config(
    vllm_config: VllmConfig, mm_registry: MultiModalRegistry
) -> Optional[BaseMultiModalProcessorCache]

Return a BaseMultiModalProcessorCache, if enabled.

Source code in vllm/multimodal/cache.py
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalProcessorCache]:
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return MultiModalProcessorOnlyCache(model_config)

    return MultiModalProcessorSenderCache(model_config)

processor_only_cache_from_config

processor_only_cache_from_config(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
)

Return a MultiModalProcessorOnlyCache, if enabled.

Source code in vllm/multimodal/cache.py
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
):
    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)

receiver_cache_from_config

receiver_cache_from_config(
    vllm_config: VllmConfig, mm_registry: MultiModalRegistry
) -> Optional[BaseMultiModalReceiverCache]

Return a BaseMultiModalReceiverCache, if enabled.

Source code in vllm/multimodal/cache.py
def receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalReceiverCache]:
    """Return a `BaseMultiModalReceiverCache`, if enabled."""
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    return MultiModalReceiverCache(model_config)
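
A minimal sketch of wiring these factory helpers together, assuming vllm_config and mm_registry are provided by the engine; it only shows which cache, if any, each process ends up with and is not how vLLM itself wires them:

from typing import Optional

from vllm.multimodal.cache import (
    BaseMultiModalProcessorCache,
    BaseMultiModalReceiverCache,
    processor_cache_from_config,
    receiver_cache_from_config,
)


def build_mm_caches(vllm_config, mm_registry):
    # P0: None if the processor cache is disabled; otherwise
    # MultiModalProcessorOnlyCache (IPC caching disabled) or
    # MultiModalProcessorSenderCache (IPC caching enabled).
    p0_cache: Optional[BaseMultiModalProcessorCache] = (
        processor_cache_from_config(vllm_config, mm_registry))

    # P1: MultiModalReceiverCache only when IPC caching is enabled;
    # otherwise None, since P0 already holds the full items.
    p1_cache: Optional[BaseMultiModalReceiverCache] = (
        receiver_cache_from_config(vllm_config, mm_registry))

    return p0_cache, p1_cache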