@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
"""Sparse Attention Indexer Custom Op Layer. This layer is extracted as a
separate custom op since it involves heavy custom kernels like `mqa_logits`,
`paged_mqa_logits` and `top_k_per_row`, etc. Those kernels maybe requires
specific memory layout or implementation for different hardware backends to
achieve optimal performance.
For now, the default native path will use CUDA backend path. Other platform
may requires add the corresponding Custom Op name `sparse_attn_indexer` to
`custom_ops` in `CompilationConfig` to enable the platform specific path.
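
For example (a sketch assuming the usual "+<op_name>" convention of
`CompilationConfig.custom_ops`):

    CompilationConfig(custom_ops=["+sparse_attn_indexer"])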
"""

def __init__(
self,
k_cache,
quant_block_size: int,
scale_fmt: str,
topk_tokens: int,
head_dim: int,
max_model_len: int,
max_total_seq_len: int,
topk_indices_buffer: torch.Tensor,
):
super().__init__()
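# Static indexer configuration; `k_cache` is the KV-cache layer whose
# `prefix` and `kv_cache` buffer are forwarded to the custom op below, and
# `topk_indices_buffer` is a preallocated buffer for the selected indices.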
self.k_cache = k_cache
self.quant_block_size = quant_block_size
self.scale_fmt = scale_fmt
self.topk_tokens = topk_tokens
self.head_dim = head_dim
self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer
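# The CUDA path relies on DeepGEMM for kernels such as `mqa_logits` and
# `paged_mqa_logits` (see class docstring), so fail fast if it is missing.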
if current_platform.is_cuda() and not has_deep_gemm():
raise RuntimeError(
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
)
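
# Native path: dispatch to the CUDA or ROCm implementation based on the
# current platform; other platforms must enable their own custom op path.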
def forward_native(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights)
else:
raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for "
"CUDA and ROCm platform."
)
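
# CUDA path: invoke the registered `sparse_attn_indexer` torch custom op
# with the KV-cache buffer and the static indexer configuration.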
def forward_cuda(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
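
# ROCm path: same argument layout as the CUDA op, but backed by AITER
# kernels; only available when rocm_aiter_ops is enabled.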
def forward_hip(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
else:
raise RuntimeError(
"Sparse attention indexer ROCm custom op requires ROCm "
"Aiter ops to be enabled."
)