
SOCKET: SOft Collision Kernel EsTimator for Sparse Attention

Sahil Joshi, Agniva Chowdhury, Wyatt Bellinger, Amar Kanakamedala, Ekam Singh, Hoang Anh Duy Le, Aditya Desai, Anshumali Shrivastava

TL;DR

SOCKET tackles the heavy cost of dense attention in long-context inference by introducing a soft Locality-Sensitive Hashing (LSH)–based scoring kernel that treats collisions as probabilistic, similarity-aware signals rather than binary matches. By transforming LSH into a ranker, SOCKET enables stable top-$k$ token selection without data-dependent training, delivering principled angular-attention–style behavior with theoretical guarantees and practical speedups. The approach is backed by a sampling-based estimator and a tight end-to-end error bound that decomposes the approximation error into sampling variance, finite-table effects, and a bias term controlled by the temperature $\tau$. Empirically, SOCKET shows strong accuracy on the LongBench and RULER benchmarks across multiple models and achieves up to $1.5\times$ higher decode throughput than FlashAttention, with robust, data-agnostic performance that reduces reliance on retraining or calibration. The work provides a concrete, open-source pathway to efficient long-context inference using soft-collision kernels.

Abstract

Exploiting sparsity during long-context inference is central to scaling large language models, as attention dominates the cost of autoregressive decoding. Sparse attention reduces this cost by restricting computation to a subset of tokens, but its effectiveness depends critically on efficient scoring and selection of relevant tokens at inference time. We revisit Locality-Sensitive Hashing (LSH) as a sparsification primitive and introduce SOCKET, a SOft Collision Kernel EsTimator that replaces hard bucket matches with probabilistic, similarity-aware aggregation. Our key insight is that hard LSH produces discrete collision signals and is therefore poorly suited for ranking. In contrast, soft LSH aggregates graded collision evidence across hash tables, preserving the stability of relative ordering among the true top-$k$ tokens. This transformation elevates LSH from a candidate-generation heuristic to a principled and mathematically grounded scoring kernel for sparse attention. Leveraging this property, SOCKET enables efficient token selection without an ad-hoc voting mechanism and matches or surpasses established sparse attention baselines across multiple long-context benchmarks using a diverse set of models. With a custom CUDA kernel for scoring keys and a Flash Decode Triton backend for sparse attention, SOCKET achieves up to 1.5$\times$ higher throughput than FlashAttention, making it an effective tool for long-context inference. Code is open-sourced at https://github.com/amarka8/SOCKET.
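The contrast between hard and soft collisions can be illustrated with random-hyperplane (SimHash) LSH. The sketch below is a minimal pure-Python illustration under our own assumptions, not SOCKET's CUDA kernel: each of $L$ tables hashes a vector to $P$ sign bits; the hard score counts exact bucket matches, while the soft score (a hypothetical form chosen for illustration) replaces the match indicator with the probability that the query's relaxed bucket, modeled with independent per-bit sigmoids at temperature $\tau$, equals the key's bucket. The top-$k$ keys are then passed to ordinary scaled dot-product softmax attention (the paper's target is angular attention, which we do not reproduce). All function names are hypothetical.

```python
import math
import random

def sigmoid(t):
    # numerically stable logistic function
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def project(x, planes):
    # projections of x onto the P random hyperplanes of one hash table
    return [sum(w * v for w, v in zip(p, x)) for p in planes]

def bucket(x, planes):
    # hard SimHash bucket: one sign bit (+1/-1) per hyperplane
    return [1 if z >= 0 else -1 for z in project(x, planes)]

def bucket_prob(z, b, tau):
    # ASSUMED soft collision: probability that the query's relaxed bucket equals b,
    # with independent per-bit sigmoids at temperature tau (tends to a hard match as tau -> 0)
    p = 1.0
    for zi, bi in zip(z, b):
        p *= sigmoid(bi * zi / tau)
    return p

def hard_score(q, key, tables):
    # fraction of tables in which the key falls in exactly the query's bucket
    return sum(bucket(q, t) == bucket(key, t) for t in tables) / len(tables)

def soft_score(q, key, tables, tau):
    # graded collision evidence averaged over all L tables
    return sum(bucket_prob(project(q, t), bucket(key, t), tau) for t in tables) / len(tables)

def sparse_attention(q, keys, values, scores, k):
    # keep the k highest-scoring keys, then run softmax attention over them only
    top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:k]
    logits = [sum(a * b for a, b in zip(q, keys[i])) / math.sqrt(len(q)) for i in top]
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    s = sum(w)
    out = [sum(wj * values[i][c] for wj, i in zip(w, top)) / s for c in range(len(values[0]))]
    return out, top

rng = random.Random(0)
d, P, L, N, k = 16, 6, 32, 200, 8
tables = [[[rng.gauss(0, 1) for _ in range(d)] for _ in range(P)] for _ in range(L)]
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
q = [rng.gauss(0, 1) for _ in range(d)]

hard = [hard_score(q, key, tables) for key in keys]
soft = [soft_score(q, key, tables, tau=0.5) for key in keys]
# hard scores are multiples of 1/L, so many keys tie; soft scores are graded
print("distinct hard scores:", len(set(hard)), "distinct soft scores:", len(set(soft)))
out, top = sparse_attention(q, keys, values, soft, k)
```

Because the hard score can take at most $L+1$ values, large groups of keys tie and top-$k$ selection among them is arbitrary; the soft score breaks these ties using how confidently each bit agrees, which is the ranking property the abstract emphasizes.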

Paper Structure (30 sections, 8 theorems, 91 equations, 4 figures, 6 tables, 4 algorithms)


Key Result

Theorem 3

Let $\mathbf{q}\in\mathbb{R}^d$ be a fixed query, with keys $\mathbf{k}_1,\dots,\mathbf{k}_N\in\mathbb{R}^d$ and values $\mathbf{v}_1,\dots,\mathbf{v}_N\in\mathbb{R}^d$. Under Assumptions 1 and 2, and for parameters $L$, $M$, and $\tau$ with $L\ge\frac{2B^2\log(8/\delta)}{Z_{\tau,\min}^2}$, the end-to-end error bound […] holds with probability at least $1-\delta$, where $\mathbf{y}^*(\mathbf{q})$ denotes the target (angular) …
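The table-count condition in Theorem 3 can be evaluated directly once $B$ and $Z_{\tau,\min}$ are known. Neither constant is defined in this summary (both come from the paper's assumptions), so the values below are purely illustrative, and `min_tables` is a hypothetical helper name.

```python
import math

def min_tables(B, z_min, delta):
    """Smallest integer L with L >= 2 B^2 log(8/delta) / z_min^2 (the condition in Theorem 3)."""
    return math.ceil(2 * B ** 2 * math.log(8 / delta) / z_min ** 2)

# illustrative constants only; B and Z_{tau,min} are not given in this summary
print(min_tables(B=1.0, z_min=0.5, delta=0.05))   # 41
print(min_tables(B=1.0, z_min=0.25, delta=0.05))  # 163
print(min_tables(B=1.0, z_min=0.5, delta=0.005))  # 60
```

Because $L$ scales as $1/Z_{\tau,\min}^2$, halving $Z_{\tau,\min}$ roughly quadruples the required number of tables, whereas tightening $\delta$ by a factor of 10 adds only a logarithmic cost.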

Figures (4)

  • Figure 1: Performance across sparsity levels on RULER-32K (Llama-3.1-8B-Instruct). Mem denotes additional memory (bits/token) beyond the KV cache. Spr denotes sparsity.
  • Figure 2: Ranking quality comparison between SOCKET and traditional LSH as a function of top-$k$. Keys are drawn from a standard Gaussian distribution, and ground-truth relevance is defined by the dot-product similarity between the query and each key. (a)–(b) measure overlap with the ground-truth top-$k$ set, while (c) additionally evaluates agreement with the ground-truth ranking. (d) shows ground-truth scores for the top-$k$ keys selected by SOCKET and hard LSH, with the dashed line indicating the $k$-th cutoff. These metrics are defined formally in the Appendix.
  • Figure 3: (a) GPU index construction time comparison between SOCKET and PQCache. (b–c) Decode-only throughput versus context length for SOCKET (33× sparsity) and FlashAttention, evaluated using GPT-FAST on Llama-2-7b-hf with a single layer and a batch size of 1.
  • Figure 4: (a) compares the index construction time of SOCKET against a CPU-based PQCache baseline. (b–c) show the throughput speedup of SOCKET over FlashAttention across different hardware platforms.
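Figure 2 reports set overlap with the ground-truth top-$k$ and agreement with the ground-truth ranking. The paper's exact metric definitions live in its appendix, which is not reproduced here; as an assumption, the sketch below uses two standard choices, overlap@$k$ and Kendall's $\tau_a$ rank correlation. The function names and the toy scores are hypothetical.

```python
def top_k_indices(scores, k):
    # indices of the k largest scores (stable sort, so ties keep index order)
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def overlap_at_k(pred, truth, k):
    # fraction of the ground-truth top-k set recovered by the predicted top-k set
    hit = set(top_k_indices(pred, k)) & set(top_k_indices(truth, k))
    return len(hit) / k

def kendall_tau(a, b):
    # Kendall's tau-a: (concordant - discordant) pairs divided by all pairs
    n = len(a)
    s = 0
    for i in range(n):
        for j in range(i + 1, n):
            x = (a[i] - a[j]) * (b[i] - b[j])
            s += (x > 0) - (x < 0)
    return s / (n * (n - 1) / 2)

truth = [0.9, 0.8, 0.1, 0.5, 0.3]  # e.g. dot products q . k_i
pred = [0.7, 0.9, 0.2, 0.1, 0.4]   # scores from some ranker
print(overlap_at_k(pred, truth, 2))  # 1.0
print(overlap_at_k(pred, truth, 3))  # 2/3, about 0.667
print(kendall_tau(pred, truth))      # 0.4
```

Overlap@$k$ ignores order inside the selected set, which is why a rank-agreement metric such as Kendall's $\tau$ is a useful complement: here the ranker recovers the true top-2 set exactly yet still swaps three of the ten pairs.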

Theorems & Definitions (10)

  • Theorem 3
  • Remark 4: Error bound without sampling
  • Lemma 5
  • Lemma 6
  • Remark 7: Implication for ranking and top-$k$ selection
  • Lemma 8: Concentration of $\widetilde{Z}$ and $\widetilde{\mathbf{n}}(\mathbf{q})$
  • Lemma 9: Final bound for $\mathbf{y}_{\tau,L}(\mathbf{q})$
  • Lemma 10
  • Lemma 11
  • Theorem 12: Final end-to-end bound via union bound
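The table-count condition in Theorem 3 has the shape of a Hoeffding bound. The derivation below is our own reading, not taken from the paper: it assumes that Lemma 8 concentrates $\widetilde{Z}$, an average of $L$ independent per-table terms bounded in $[0,B]$, and that each event in the final union bound (Theorem 12) receives a failure budget of $\delta/4$. Under those assumptions the stated threshold falls out exactly.

```latex
% Assumption (ours): \widetilde{Z} is the mean of L i.i.d. per-table terms in [0, B].
% Hoeffding's inequality for that mean:
\Pr\!\left(\bigl|\widetilde{Z} - \mathbb{E}\widetilde{Z}\bigr| \ge t\right)
  \le 2\exp\!\left(-\frac{2Lt^2}{B^2}\right).
% Take t = Z_{\tau,\min}/2 and require failure probability at most \delta/4:
2\exp\!\left(-\frac{L Z_{\tau,\min}^2}{2B^2}\right) \le \frac{\delta}{4}
\iff \frac{L Z_{\tau,\min}^2}{2B^2} \ge \log\frac{8}{\delta}
\iff L \ge \frac{2B^2\log(8/\delta)}{Z_{\tau,\min}^2}.
```

If in addition $\mathbb{E}\widetilde{Z}\ge Z_{\tau,\min}$, the same event gives $\widetilde{Z}\ge Z_{\tau,\min}/2$, which keeps the estimated normalizer bounded away from zero; this is the kind of control a ratio estimator such as $\mathbf{y}_{\tau,L}(\mathbf{q})$ needs.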