SAMSA: Efficient Transformer for Many Data Modalities

Minh Lenhat, Viet Anh Nguyen, Khoa Nguyen, Duong Duc Hieu, Dao Huu Hung, Truong Son Hy

TL;DR

SAMSA introduces a modality-agnostic, context-aware self-attention mechanism that uses differentiable sampling without replacement to reach linear-time inference. By learning token importance and applying a mixture of attention experts across heads, SAMSA expands receptive fields while maintaining expressivity. Across sequences, graphs, and point clouds, it attains competitive or state-of-the-art results without modality-specific inductive biases, and it runs faster at inference than full self-attention. The approach balances speed and accuracy; the authors note optimization challenges with hard sampling and suggest improved gradient estimators as a direction for future work.

Abstract

The versatility of the self-attention mechanism has earned transformers great success in almost all data modalities, despite limitations from quadratic complexity and difficulty of training. Efficient transformers, on the other hand, often rely on clever data-modality-dependent constructions to overcome the quadratic complexity of transformers. This greatly hinders their application to different data modalities, which is one of the pillars of contemporary foundational modeling. In this paper, we lay the groundwork for efficient foundational modeling by proposing SAMSA (SAMpling-Self-Attention), a context-aware, linear-complexity self-attention mechanism that works well on multiple data modalities. Our mechanism is based on a differentiable sampling-without-replacement method we discovered. This enables the self-attention module to attend to the most important token set, where importance is defined by the data. Moreover, as differentiability is not needed in inference, the sparse formulation of our method costs little time overhead, further lowering computational costs. In short, SAMSA achieves competitive or even SOTA results on many benchmarks while being faster in inference than other, very specialized models. Against full self-attention, real inference time decreases significantly while performance ranges from negligible degradation to outperformance. We release our source code in the repository: https://github.com/HySonLab/SAMSA
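
To make the "linear complexity" claim concrete, the following back-of-the-envelope sketch counts query-key scores per head: full self-attention scores every query against every key, while attending to a fixed number of sampled keys grows only linearly with sequence length. The function name and the sample size of 256 are illustrative assumptions, not values taken from the paper.

```python
def attention_score_count(n_tokens, n_sampled=None):
    """Number of query-key scores computed per attention head.

    n_sampled=None models full self-attention (every token is a key).
    Otherwise every query attends to at most n_sampled selected keys.
    """
    n_keys = n_tokens if n_sampled is None else min(n_sampled, n_tokens)
    return n_tokens * n_keys


# Hypothetical sample size, chosen only for illustration.
SAMPLE_SIZE = 256

for n in (1_000, 10_000, 100_000):
    full = attention_score_count(n)
    sampled = attention_score_count(n, n_sampled=SAMPLE_SIZE)
    print(f"n={n}: full={full}, sampled={sampled}, ratio={full / sampled:.1f}")
```

With the sample size held fixed, the ratio between the two counts grows in proportion to the sequence length, which is the sense in which the sampled formulation is linear rather than quadratic.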

Paper Structure

This paper contains 47 sections, 18 equations, 4 figures, 8 tables, and 1 algorithm.

Figures (4)

  • Figure 1: Overview of SAMSA, our proposed sampling-self-attention module. The key-value vectors are selected by top-k over importance scores computed from the tokens' latent representations. The Gumbel-Sigmoid reparameterization provides gradients that guide optimization towards the most important key-value pairs (left). The sampled key-value vectors are then fed into Flash Attention, where the query vectors attend to them (right).
  • Figure 2: Learning Curves of SAMSA models in Sequence Tasks
  • Figure 3: Learning Curves of SAMSA models in Graph Tasks
  • Figure 4: Learning Curves of SAMSA models in Point Cloud Tasks
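
The pipeline described in the Figure 1 caption can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the names `sampled_attention`, `gumbel_sigmoid`, and `w_score` are hypothetical, the importance score is assumed to be a single linear projection of each token, plain softmax attention stands in for Flash Attention, and the way the relaxed gate scales the selected values in the training-style branch is a guess at how a Gumbel-Sigmoid gate could be wired in. NumPy computes no gradients, so that branch only shows the forward pass.

```python
import numpy as np


def gumbel_sigmoid(logits, rng, temperature=1.0):
    """Relaxed Bernoulli gate: sigmoid((logits + logistic noise) / temperature)."""
    u = rng.uniform(1e-6, 1.0 - 1e-6, size=logits.shape)
    noise = np.log(u) - np.log1p(-u)  # logistic noise (difference of two Gumbels)
    return 1.0 / (1.0 + np.exp(-(logits + noise) / temperature))


def sampled_attention(x, w_q, w_k, w_v, w_score, k, rng=None):
    """All n queries attend to k distinct key-value pairs chosen by importance.

    rng=None : inference-style path, deterministic top-k on the raw scores.
    rng given: training-style path (assumption), top-k on noisy relaxed gates,
               with the gate values scaling the selected value vectors.
    """
    scores = x @ w_score  # (n,) one importance logit per token
    if rng is None:
        gate = np.ones_like(scores)
        ranking = scores
    else:
        gate = gumbel_sigmoid(scores, rng)
        ranking = gate

    # Indices of the k largest entries; distinct, so no token is picked twice.
    top = np.argpartition(-ranking, k - 1)[:k]

    q = x @ w_q                                   # (n, d)
    keys = x[top] @ w_k                           # (k, d)
    values = (x[top] @ w_v) * gate[top, None]     # (k, d)

    logits = q @ keys.T / np.sqrt(q.shape[-1])    # (n, k) instead of (n, n)
    logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values, top


# Small demo with random inputs (sizes are arbitrary).
demo_rng = np.random.default_rng(0)
n, d, k = 16, 8, 4
x = demo_rng.normal(size=(n, d))
w_q, w_k, w_v = (demo_rng.normal(size=(d, d)) for _ in range(3))
w_score = demo_rng.normal(size=d)

out, top = sampled_attention(x, w_q, w_k, w_v, w_score, k)
print(out.shape, sorted(top.tolist()))
```

The score matrix has shape (n, k) rather than (n, n), so for a fixed k the attention cost grows linearly with the number of tokens. Because the inference-style path skips the noise and the gate entirely, it adds only a top-k selection on top of ordinary attention.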