vllm.attention.backends.abstract

T module-attribute

T = TypeVar('T', bound=AttentionMetadata)

AttentionBackend

Bases: ABC

Abstract class for attention backends.

Source code in vllm/attention/backends/abstract.py
class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(1)]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Get the physical (memory layout) ordering of the kv cache dimensions.
        e.g. if the KV cache shape is
        [2, num_blocks, block_size, num_heads, head_size],
        and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
        ordering of dimensions is
        [num_blocks, num_heads, 2, block_size, head_size].

        If this function is unimplemented / raises NotImplementedError,
        the physical layout of the KV cache will match the logical shape.

        Args:
            include_num_layers_dimension: if True, includes an additional
                num_layers dimension, which is assumed to be prepended
                to the logical KV cache shape.
                With the above example, a return value (2, 4, 0, 1, 3, 5)
                corresponds to
                [num_blocks, num_heads, num_layers, 2, block_size, head_size].

                If an additional dimension is NOT included in the returned
                tuple, the physical layout will not include a layers dimension.

        Returns:
            A tuple of ints which is a permutation of range(len(shape)).
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True

        valid_sizes = get_args(BlockSize)
        if block_size not in valid_sizes:
            return False

        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
        if not supported_kernel_block_sizes:
            return True

        for supported_size in supported_kernel_block_sizes:
            if isinstance(supported_size, MultipleOf):
                supported_size = supported_size.base
            # With hybrid_blocks feature, the framework-level block size
            # only needs to be a multiple of the kernel's requirement,
            # even if the kernel requires a fixed block_size.
            if block_size % supported_size == 0:
                return True
        return False

    @classmethod
    def is_mla(cls) -> bool:
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Check if backend supports a given attention type.

        By default, only supports decoder attention.
        Backends should override this to support other attention types.
        """
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        invalid_reasons = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            invalid_reasons.append(f"attention type {attn_type} not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        return None
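
The four abstract methods above are the minimum a concrete backend must provide. Below is a minimal sketch, not a real vLLM backend; MyAttentionImpl and MyAttentionMetadataBuilder are placeholder names for classes a backend author would supply.

from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl


class MyAttentionBackend(AttentionBackend):
    accept_output_buffer = True

    @staticmethod
    def get_name() -> str:
        return "MY_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        return MyAttentionImpl  # hypothetical AttentionImpl subclass, not shown

    @staticmethod
    def get_builder_cls():
        return MyAttentionMetadataBuilder  # hypothetical metadata builder, not shown

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # A common paged layout: separate K and V planes per block.
        return (2, num_blocks, block_size, num_kv_heads, head_size)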

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = False

supported_dtypes class-attribute

supported_dtypes: list[dtype] = [float16, bfloat16]

supported_kv_cache_dtypes class-attribute

supported_kv_cache_dtypes: list[CacheDType] = ['auto']

full_cls_name classmethod

full_cls_name() -> tuple[str, str]
Source code in vllm/attention/backends/abstract.py
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
    return (cls.__module__, cls.__qualname__)

get_builder_cls abstractmethod staticmethod

get_builder_cls()
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
    raise NotImplementedError

get_impl_cls abstractmethod staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
    raise NotImplementedError

get_kv_cache_shape abstractmethod staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    raise NotImplementedError

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]

Get the physical (memory layout) ordering of the kv cache dimensions. e.g. if the KV cache shape is [2, num_blocks, block_size, num_heads, head_size], and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical ordering of dimensions is [num_blocks, num_heads, 2, block_size, head_size].

If this function is unimplemented / raises NotImplementedError, the physical layout of the KV cache will match the logical shape.

Parameters:

    include_num_layers_dimension (bool, default: False):
        If True, includes an additional num_layers dimension, which is
        assumed to be prepended to the logical KV cache shape. With the
        above example, a return value (2, 4, 0, 1, 3, 5) corresponds to
        [num_blocks, num_heads, num_layers, 2, block_size, head_size].

        If an additional dimension is NOT included in the returned tuple,
        the physical layout will not include a layers dimension.

Returns:

    tuple[int, ...]:
        A tuple of ints which is a permutation of range(len(shape)).

Source code in vllm/attention/backends/abstract.py
@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """
    Get the physical (memory layout) ordering of the kv cache dimensions.
    e.g. if the KV cache shape is
    [2, num_blocks, block_size, num_heads, head_size],
    and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
    ordering of dimensions is
    [num_blocks, num_heads, 2, block_size, head_size].

    If this function is unimplemented / raises NotImplementedError,
    the physical layout of the KV cache will match the logical shape.

    Args:
        include_num_layers_dimension: if True, includes an additional
            num_layers dimension, which is assumed to be prepended
            to the logical KV cache shape.
            With the above example, a return value (2, 4, 0, 1, 3, 5)
            corresponds to
            [num_blocks, num_heads, num_layers, 2, block_size, head_size].

            If an additional dimension is NOT included in the returned
            tuple, the physical layout will not include a layers dimension.

    Returns:
        A tuple of ints which is a permutation of range(len(shape)).
    """
    raise NotImplementedError
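
A short sketch (not taken from vLLM itself) of how a returned stride order could be applied: allocate the cache with the permuted physical shape, then permute it back so the tensor still presents the logical shape.

import torch

# Logical shape: [2, num_blocks, block_size, num_heads, head_size]
logical_shape = (2, 4, 16, 8, 128)
# Physical ordering from the docstring example:
# [num_blocks, num_heads, 2, block_size, head_size]
stride_order = (1, 3, 0, 2, 4)

physical_shape = tuple(logical_shape[i] for i in stride_order)
kv_cache = torch.empty(physical_shape).permute(
    *[stride_order.index(i) for i in range(len(stride_order))]
)
assert kv_cache.shape == logical_shape  # logical view over permuted memory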

get_name abstractmethod staticmethod

get_name() -> str
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_name() -> str:
    raise NotImplementedError

get_required_kv_cache_layout classmethod

get_required_kv_cache_layout() -> KVCacheLayoutType | None
Source code in vllm/attention/backends/abstract.py
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
    return None

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/attention/backends/abstract.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return []

get_supported_kernel_block_sizes staticmethod

get_supported_kernel_block_sizes() -> list[
    int | MultipleOf
]
Source code in vllm/attention/backends/abstract.py
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    return [MultipleOf(1)]

is_mla classmethod

is_mla() -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def is_mla(cls) -> bool:
    return False

is_sparse classmethod

is_sparse() -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def is_sparse(cls) -> bool:
    return False

supports_attn_type classmethod

supports_attn_type(attn_type: str) -> bool

Check if backend supports a given attention type.

By default, only supports decoder attention. Backends should override this to support other attention types.

Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Check if backend supports a given attention type.

    By default, only supports decoder attention.
    Backends should override this to support other attention types.
    """
    return attn_type == AttentionType.DECODER

supports_block_size classmethod

supports_block_size(block_size: int | None) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
    from vllm.config.cache import BlockSize

    if block_size is None:
        return True

    valid_sizes = get_args(BlockSize)
    if block_size not in valid_sizes:
        return False

    supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
    if not supported_kernel_block_sizes:
        return True

    for supported_size in supported_kernel_block_sizes:
        if isinstance(supported_size, MultipleOf):
            supported_size = supported_size.base
        # With hybrid_blocks feature, the framework-level block size
        # only needs to be a multiple of the kernel's requirement,
        # even if the kernel requires a fixed block_size.
        if block_size % supported_size == 0:
            return True
    return False
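
A minimal sketch of the multiple-based check, assuming a hypothetical backend whose kernel accepts any block size that is a multiple of 16, and assuming 32 is one of the valid BlockSize values:

from vllm.attention.backends.abstract import AttentionBackend, MultipleOf


class Mult16Backend(AttentionBackend):
    # Abstract methods omitted; only the classmethod check is exercised here.
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]


assert Mult16Backend.supports_block_size(None)      # no block size constraint given
assert Mult16Backend.supports_block_size(32)        # 32 % 16 == 0
assert not Mult16Backend.supports_block_size(1)     # not a multiple of 16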

supports_combination classmethod

supports_combination(
    head_size: int,
    dtype: dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> str | None
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_combination(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: "CacheDType | None",
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: "DeviceCapability",
) -> str | None:
    return None

supports_compute_capability classmethod

supports_compute_capability(
    capability: DeviceCapability,
) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
    return True

supports_dtype classmethod

supports_dtype(dtype: dtype) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
    return dtype in cls.supported_dtypes

supports_head_size classmethod

supports_head_size(head_size: int) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
    supported_head_sizes = cls.get_supported_head_sizes()
    return (not supported_head_sizes) or head_size in supported_head_sizes

supports_kv_cache_dtype classmethod

supports_kv_cache_dtype(
    kv_cache_dtype: CacheDType | None,
) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
    if kv_cache_dtype is None:
        return True
    return (not cls.supported_kv_cache_dtypes) or (
        kv_cache_dtype in cls.supported_kv_cache_dtypes
    )

supports_sink classmethod

supports_sink() -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_sink(cls) -> bool:
    return False

validate_configuration classmethod

validate_configuration(
    head_size: int,
    dtype: dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
    attn_type: str,
) -> list[str]
Source code in vllm/attention/backends/abstract.py
@classmethod
def validate_configuration(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: "CacheDType | None",
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: "DeviceCapability",
    attn_type: str,
) -> list[str]:
    invalid_reasons = []
    if not cls.supports_head_size(head_size):
        invalid_reasons.append("head_size not supported")
    if not cls.supports_dtype(dtype):
        invalid_reasons.append("dtype not supported")
    if not cls.supports_kv_cache_dtype(kv_cache_dtype):
        invalid_reasons.append("kv_cache_dtype not supported")
    if not cls.supports_block_size(block_size):
        invalid_reasons.append("block_size not supported")
    if use_mla != cls.is_mla():
        if use_mla:
            invalid_reasons.append("MLA not supported")
        else:
            invalid_reasons.append("non-MLA not supported")
    if has_sink and not cls.supports_sink():
        invalid_reasons.append("sink setting not supported")
    if use_sparse != cls.is_sparse():
        if use_sparse:
            invalid_reasons.append("sparse not supported")
        else:
            invalid_reasons.append("non-sparse not supported")
    if not cls.supports_compute_capability(device_capability):
        invalid_reasons.append("compute capability not supported")
    if not cls.supports_attn_type(attn_type):
        invalid_reasons.append(f"attention type {attn_type} not supported")
    combination_reason = cls.supports_combination(
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        device_capability,
    )
    if combination_reason is not None:
        invalid_reasons.append(combination_reason)
    return invalid_reasons
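
A hedged usage sketch: MyAttentionBackend is the placeholder subclass sketched earlier, and DeviceCapability is assumed to be the (major, minor) named tuple from vllm.platforms.interface; the real backend-selection logic in vLLM lives elsewhere.

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.platforms.interface import DeviceCapability  # assumed import path

reasons = MyAttentionBackend.validate_configuration(
    head_size=128,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    block_size=16,
    use_mla=False,
    has_sink=False,
    use_sparse=False,
    device_capability=DeviceCapability(9, 0),
    attn_type=AttentionType.DECODER,
)
if reasons:
    print("Backend rejected:", "; ".join(reasons))
else:
    print("Backend accepts this configuration")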

AttentionImpl

Bases: ABC, Generic[T]

Source code in vllm/attention/backends/abstract.py
class AttentionImpl(ABC, Generic[T]):
    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # Whether this attention implementation supports pre-quantized query input.
    # When True, the attention layer will quantize queries before passing them
    # to this backend, allowing torch.compile to fuse the quantization with
    # previous operations. This is typically supported when using FP8 KV cache
    # with compatible attention kernels (e.g., TRT-LLM).
    # Subclasses should set this in __init__.
    # TODO add support to more backends:
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    dcp_world_size: int
    dcp_rank: int

    pcp_world_size: int
    pcp_rank: int

    total_cp_world_size: int
    total_cp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        try:
            from vllm.distributed.parallel_state import get_pcp_group

            self.pcp_world_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group
        except AssertionError:
            self.pcp_world_size = 1
            self.pcp_rank = 0
        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey"):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        pass

can_return_lse_for_decode class-attribute instance-attribute

can_return_lse_for_decode: bool = False

dcp_rank instance-attribute

dcp_rank: int

dcp_world_size instance-attribute

dcp_world_size: int

need_to_return_lse_for_decode class-attribute instance-attribute

need_to_return_lse_for_decode: bool = False

pcp_rank instance-attribute

pcp_rank: int

pcp_world_size instance-attribute

pcp_world_size: int

supports_quant_query_input class-attribute instance-attribute

supports_quant_query_input: bool = False

total_cp_rank instance-attribute

total_cp_rank: int

total_cp_world_size instance-attribute

total_cp_world_size: int

__init__ abstractmethod

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    sliding_window: int | None = None,
    kv_cache_dtype: str = "auto",
    logits_soft_cap: float | None = None,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: str | None = None,
) -> None
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    sliding_window: int | None = None,
    kv_cache_dtype: str = "auto",
    logits_soft_cap: float | None = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: str | None = None,
) -> None:
    raise NotImplementedError

__new__

__new__(*args, **kwargs)
Source code in vllm/attention/backends/abstract.py
def __new__(cls, *args, **kwargs):
    # use __new__ so that all subclasses will call this
    self = super().__new__(cls)
    try:
        from vllm.distributed.parallel_state import get_dcp_group

        self.dcp_world_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group
    except AssertionError:
        # DCP might not be initialized in testing
        self.dcp_world_size = 1
        self.dcp_rank = 0
    try:
        from vllm.distributed.parallel_state import get_pcp_group

        self.pcp_world_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group().rank_in_group
    except AssertionError:
        self.pcp_world_size = 1
        self.pcp_rank = 0
    self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
    self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

    self.need_to_return_lse_for_decode = (
        self.dcp_world_size > 1 and self.can_return_lse_for_decode
    )
    return self
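
A worked example of the flattened context-parallel rank computed above (pure arithmetic, independent of any initialized process groups):

pcp_world_size, dcp_world_size = 2, 4
pcp_rank, dcp_rank = 1, 3

total_cp_world_size = pcp_world_size * dcp_world_size  # 2 * 4 = 8
total_cp_rank = pcp_rank * dcp_world_size + dcp_rank    # 1 * 4 + 3 = 7
# DCP ranks vary fastest: total ranks 0-3 map to pcp_rank 0, ranks 4-7 to pcp_rank 1.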

forward abstractmethod

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: T,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: T,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    raise NotImplementedError

fused_output_quant_supported

fused_output_quant_supported(quant_key: QuantKey)

Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it.

:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization

Source code in vllm/attention/backends/abstract.py
def fused_output_quant_supported(self, quant_key: "QuantKey"):
    """
    Does this attention implementation support fused output quantization.
    This is used by the AttnFusionPass to only fuse output quantization
    onto implementations that support it.

    :param quant_key: QuantKey object that describes the quantization op
    :return: is fusion supported for this type of quantization
    """
    return False
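
A hypothetical override (sketch only): an implementation could keep a set of quantization schemes its kernel can fuse and report membership. SUPPORTED_QUANT_KEYS is an illustrative attribute, not part of the vLLM API.

from vllm.attention.backends.abstract import AttentionImpl


class FusedQuantAttentionImpl(AttentionImpl):
    # Abstract __init__/forward omitted; illustrative attribute, to be
    # populated with the QuantKey values the kernel actually supports.
    SUPPORTED_QUANT_KEYS: frozenset = frozenset()

    def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool:
        return quant_key in self.SUPPORTED_QUANT_KEYS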

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/attention/backends/abstract.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    pass

AttentionLayer

Bases: Protocol

Source code in vllm/attention/backends/abstract.py
class AttentionLayer(Protocol):
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...

_k_scale instance-attribute

_k_scale: Tensor

_k_scale_float instance-attribute

_k_scale_float: float

_prob_scale instance-attribute

_prob_scale: Tensor

_q_scale instance-attribute

_q_scale: Tensor

_q_scale_float instance-attribute

_q_scale_float: float

_v_scale instance-attribute

_v_scale: Tensor

_v_scale_float instance-attribute

_v_scale_float: float

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: AttentionMetadata,
) -> Tensor
Source code in vllm/attention/backends/abstract.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor: ...

AttentionMetadata

Source code in vllm/attention/backends/abstract.py
class AttentionMetadata:
    pass

AttentionType

Attention type. Use string to be compatible with torch.compile.

Source code in vllm/attention/backends/abstract.py
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""

DECODER class-attribute instance-attribute

DECODER = 'decoder'

Decoder attention between previous layer Q/K/V.

ENCODER class-attribute instance-attribute

ENCODER = 'encoder'

Encoder attention between previous layer Q/K/V for encoder-decoder.

ENCODER_DECODER class-attribute instance-attribute

ENCODER_DECODER = 'encoder_decoder'

Attention between dec. Q and enc. K/V for encoder-decoder.

ENCODER_ONLY class-attribute instance-attribute

ENCODER_ONLY = 'encoder_only'

Encoder attention between previous layer Q/K/V.

MLAAttentionImpl

Bases: AttentionImpl[T], Generic[T]

Source code in vllm/attention/backends/abstract.py
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        indexer: object | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

__init__ abstractmethod

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    indexer: object | None = None,
) -> None
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None,
    attn_type: str,
    kv_sharing_target_layer_name: str | None,
    # MLA Specific Arguments
    q_lora_rank: int | None,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: "ColumnParallelLinear",
    indexer: object | None = None,
) -> None:
    raise NotImplementedError

forward abstractmethod

forward(
    layer: AttentionLayer,
    hidden_states_or_cq: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: T,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor
Source code in vllm/attention/backends/abstract.py
@abstractmethod
def forward(
    self,
    layer: AttentionLayer,
    hidden_states_or_cq: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: T,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    raise NotImplementedError

MultipleOf

Source code in vllm/attention/backends/abstract.py
class MultipleOf:
    base: int

    def __init__(self, base: int):
        self.base = base

base instance-attribute

base: int = base

__init__

__init__(base: int)
Source code in vllm/attention/backends/abstract.py
def __init__(self, base: int):
    self.base = base

is_quantized_kv_cache

is_quantized_kv_cache(kv_cache_dtype: str) -> bool
Source code in vllm/attention/backends/abstract.py
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"