
Learning to Focus: Focal Attention for Selective and Scalable Transformers

Dhananjay Ram, Wei Xia, Stefano Soatto

TL;DR

Attention in transformers can be noisy, especially over long contexts. Focal Attention sharpens the softmax with a temperature control, either fixed or learned, to improve token selection. It yields consistent gains across model sizes and data scales, matching or exceeding baseline accuracy with substantially fewer parameters or less training data. In long-context settings, it delivers large relative improvements (17% to 82%) and strengthens in-context learning and retrieval-augmented generation, while integrating easily with FlashAttention. The approach is practical and broadly applicable to real-world NLP tasks that require focused reasoning over extended contexts.
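As a minimal sketch of the mechanism, assume the temperature $t$ divides the usual scaled dot-product logits (the symbol $t$ and this exact placement are our notation, not necessarily the paper's):

```latex
s_{ij} = \frac{q_i^\top k_j}{\sqrt{d_k}}, \qquad
a_{ij}(t) = \frac{\exp\left(s_{ij}/t\right)}{\sum_{l} \exp\left(s_{il}/t\right)}, \qquad t > 0,

\mathrm{FocalAttn}_t(Q, K, V)
  = \mathrm{softmax}\!\left(\frac{Q K^\top}{t\sqrt{d_k}}\right) V
  = \mathrm{softmax}\!\left(\alpha\, Q K^\top\right) V,
  \qquad \alpha = \frac{1}{t\sqrt{d_k}}.
```

Setting $t = 1$ recovers standard attention, $0 < t < 1$ sharpens the distribution, and $t \to 0$ approaches hard argmax selection. Under this assumption a fixed $t$ only changes the constant $\alpha$ multiplying $QK^\top$, so it can be passed to a fused kernel that exposes a softmax scale argument, which is one plausible reading of why FlashAttention integration is simple. A temperature that varies per query could similarly be absorbed by rescaling each query vector before the kernel call; whether the paper does this is not stated here.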

Abstract

Attention is a core component of the transformer architecture, whether in encoder-only, decoder-only, or encoder-decoder models. However, standard softmax attention often produces noisy probability distributions, which can impair effective feature selection at every layer of these models, particularly for long contexts. We propose Focal Attention, a simple yet effective modification that sharpens the attention distribution by controlling the softmax temperature, either as a fixed hyperparameter or as a parameter learned during training. This sharpening enables the model to concentrate on the most relevant tokens while suppressing irrelevant ones. Empirically, Focal Attention scales more favorably than the standard transformer with respect to model size, training data, and context length. Across diverse benchmarks, it achieves the same accuracy with up to 42% fewer parameters or 33% less training data. On long-context tasks, it delivers substantial relative improvements ranging from 17% to 82%, demonstrating its effectiveness in real-world applications.
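The temperature-controlled softmax described in the abstract can be illustrated with a small single-query sketch in plain Python. It assumes the temperature divides the scaled dot-product logits, so `temperature=1.0` is standard attention and values below 1 sharpen the weights. The function names and the `exp(log_t)` parameterization for a learnable temperature are hypothetical illustrations, not the paper's implementation.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def attention_weights(
    query: Sequence[float],
    keys: Sequence[Sequence[float]],
    temperature: float = 1.0,
) -> List[float]:
    """softmax(q . k_j / (t * sqrt(d_k))) for one query; t = 1 is standard attention."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    d_k = len(query)
    # Assumed formulation: the temperature folds into the usual 1/sqrt(d_k) scale.
    scale = 1.0 / (temperature * math.sqrt(d_k))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    return softmax(scores)


def focal_attention(
    query: Sequence[float],
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    temperature: float = 1.0,
) -> List[float]:
    """Attention output: weighted sum of value vectors under the sharpened weights."""
    weights = attention_weights(query, keys, temperature)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]


def learned_temperature(log_t: float) -> float:
    """Hypothetical learnable form: optimizing log_t keeps t = exp(log_t) > 0."""
    return math.exp(log_t)


if __name__ == "__main__":
    q = [1.0, 0.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    print("standard:", attention_weights(q, keys, temperature=1.0))
    print("focal:   ", attention_weights(q, keys, temperature=0.5))
    t = learned_temperature(math.log(0.5))
    print("output:  ", focal_attention(q, keys, values, temperature=t))
```

With these toy inputs, lowering `temperature` from 1.0 to 0.5 raises the weight on the best-matching key and lowers it on the worst-matching one. In a real model the temperature would be a tuned hyperparameter or a trained parameter rather than a hand-picked constant.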

Paper Structure

This paper contains 32 sections, 7 equations, 9 figures, 4 tables.

Figures (9)

  • Figure 1: Comparison of the attention probability distributions of a baseline transformer model and the proposed Focal Attention. Focal Attention reduces attention noise and redistributes attention probability from irrelevant tokens to relevant ones.
  • Figure 2: Scaling model size from 400M to 9.5B parameters: Focal Attention scales better than the base transformer model as model size grows, and the performance gap widens for larger models.
  • Figure 3: Scaling total training tokens up to 315B: Focal Attention scales better than the base transformer model with increasing amounts of training data. The performance gain grows as more training tokens are used.
  • Figure 4: Context length expansion: three model sizes trained from scratch. Focal Attention improves performance for all model sizes, and the improvement is larger for longer context lengths.
  • Figure 5: Long-context capability: Focal Attention performs significantly better than the transformer baseline on 5 of 6 task families in HELMET at different context lengths; each task family consists of multiple datasets.
  • ...and 4 more figures
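Figure 1 describes the sharpening effect qualitatively. One simple way to quantify it, which is our assumption rather than the paper's stated metric, is to compare the entropy of an attention distribution with the probability mass it places on tokens designated as relevant. The scores and the `relevant` index set below are hypothetical.

```python
import math
from typing import List, Sequence, Set


def softmax_with_temperature(scores: Sequence[float], temperature: float) -> List[float]:
    """Softmax of scores / temperature, shifted by the max for numerical stability."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; lower means a more concentrated (less noisy) distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def relevant_mass(probs: Sequence[float], relevant: Set[int]) -> float:
    """Total attention probability assigned to the (hypothetical) relevant tokens."""
    return sum(probs[i] for i in relevant)


if __name__ == "__main__":
    # Hypothetical scores: tokens 0 and 1 are "relevant", the rest are distractors.
    scores = [2.0, 1.8, 0.5, 0.4, 0.3, 0.2]
    relevant = {0, 1}
    for t in (1.0, 0.5):
        p = softmax_with_temperature(scores, t)
        print(f"t={t}: entropy={entropy(p):.3f}, relevant mass={relevant_mass(p, relevant):.3f}")
```

Lowering the temperature from 1.0 to 0.5 reduces the entropy and moves probability onto the highest-scoring tokens. Temperature alone cannot promote a relevant token that receives a low score; it only amplifies the ranking the scores already encode.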