You Do Not Fully Utilize Transformer's Representation Capacity

Gleb Gerasimov, Yaroslav Aksenov, Nikita Balagansky, Viacheslav Sinii, Daniil Gavrilov

TL;DR

The paper tackles representation collapse in decoder-only Transformers by showing that relying solely on the previous layer's hidden state limits long-range and multi-step reasoning. It proposes Layer-Integrated Memory (LIMe), a lightweight extension that routes and blends representations from all earlier layers using a trainable per-head, per-layer router over pre-allocated Key–Value buffers, incurring minimal overhead. Empirically, LIMe delivers faster convergence and lower perplexity per FLOP, improves synthetic reasoning benchmarks, and enables very deep architectures to scale more effectively, while preserving higher value-vector entropy and better token separability. Analyses of learned routings reveal systematic reuse of local and long-distance features, demonstrating LIMe's capacity to mitigate collapse without increasing hidden-state size and suggesting new directions for latent-space reasoning in deep transformers.

Abstract

In contrast to RNNs, which compress their history into a single hidden state, Transformers can attend to all past tokens directly. However, standard Transformers rely solely on the hidden state from the previous layer to represent the entire context. We show that this design choice induces representation collapse and degrades performance. To address this issue, we introduce Layer-Integrated Memory (LIMe), a lightweight extension that leverages existing key-value buffers and learns per-head, per-layer routing weights to integrate representations from all previous layers with negligible overhead. Through extensive experiments on language modeling, synthetic reasoning benchmarks, and very deep architectures, LIMe consistently achieves faster convergence, lower perplexity per FLOP, and substantial accuracy improvements on synthetic tasks, while preserving higher value-vector entropy and better token separability. Finally, our analysis of the learned routing weights reveals systematic reuse of both local and long-distance features, demonstrating how LIMe mitigates collapse, unlocks richer representations without increasing hidden-state size, and points to promising directions for future research.
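
The per-head, per-layer routing described above can be illustrated with a minimal NumPy sketch. It assumes that, at a given layer, each attention head builds its keys and values from a softmax-weighted mixture of the buffered hidden states of all layers up to and including the current one, while queries come from the current hidden state only. The names `lime_attention` and `router_logits`, the softmax normalization of the routing weights, and mixing hidden states before the key/value projections are illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax; entries equal to -inf get exactly zero weight."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def lime_attention(hidden_buffer, router_logits, w_q, w_k, w_v):
    """Causal multi-head attention with per-head routing over earlier layers.

    Hypothetical sketch (not the paper's code):
      hidden_buffer: (L, T, d) hidden states of layers 0..L-1; the last entry
                     is the current layer's input (the residual stream).
      router_logits: (H, L) per-head routing logits for this layer.
      w_q, w_k, w_v: (H, d, d_h) per-head projection matrices.
    Returns the per-head outputs (H, T, d_h) and routing weights (H, L).
    """
    _, t, _ = hidden_buffer.shape
    d_h = w_q.shape[-1]
    weights = softmax(router_logits, axis=-1)  # each head's weights sum to 1
    # Per-head blend of the buffered representations: (H, T, d)
    mixed = np.einsum("hl,ltd->htd", weights, hidden_buffer)
    q = np.einsum("td,hde->hte", hidden_buffer[-1], w_q)  # queries: current layer only
    k = np.einsum("htd,hde->hte", mixed, w_k)             # keys: routed mixture
    v = np.einsum("htd,hde->hte", mixed, w_v)             # values: routed mixture
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_h)      # (H, T, T)
    future = np.triu(np.ones((t, t), dtype=bool), k=1)    # True above the diagonal
    scores = np.where(future, -np.inf, scores)            # causal mask
    return softmax(scores, axis=-1) @ v, weights


# Usage: 4 buffered layers, 5 tokens, model width 8, 2 heads of width 4.
rng = np.random.default_rng(0)
buffer = rng.normal(size=(4, 5, 8))
w_q, w_k, w_v = (rng.normal(size=(2, 8, 4)) for _ in range(3))
router_logits = rng.normal(size=(2, 4))
out, weights = lime_attention(buffer, router_logits, w_q, w_k, w_v)
print(out.shape)                # (2, 5, 4)
print(weights.sum(axis=-1))     # one routing distribution per head, each summing to 1
```

In this sketch, setting every router logit except the last to `-inf` puts all routing weight on the current hidden state and recovers ordinary causal self-attention, which makes a convenient sanity check; any learned weight on earlier buffers then acts as extra access to lower-layer features rather than a change to the attention mechanism itself.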

Paper Structure

This paper contains 30 sections, 10 equations, 12 figures, 9 tables.

Figures (12)

  • Figure 1: Training loss versus FLOPs for LLaMa and LIMe. At a comparable FLOP budget, LIMe reaches a substantially lower loss. See Section \ref{sec:lm} for more details.
  • Figure 2: (a) Matrix entropy of values on the FineWeb Edu subset, by layer. LIMe has more diverse values than LLaMa, which indicates that more information is stored in its hidden states. (b) Classification accuracy of values, with standard deviation over five cross-validation folds. Values from LIMe's later layers can be linearly separated with nearly 1.0 accuracy, whereas values from LLaMa reach much lower accuracy. See Section \ref{subsection:representation} for more details.
  • Figure 3: (a) t-SNE of the hidden states of similar tokens across layers. Although hidden states are not separable in later layers for either model, LIMe, unlike LLaMa, can make updates by attending to earlier-layer representations, which makes its values highly separable. (b) t-SNE of the values of similar tokens across layers shows higher separability for LIMe's representations. See Section \ref{subsection:representation} for more details.
  • Figure 4: (a) LIMe exhibits consistently higher entropy of value vectors across layers, particularly in the final layer, indicating reduced representation collapse compared to LLaMa. (b) On the Arithmetic Expressions task, LIMe significantly outperforms the LLaMa baseline, maintaining high accuracy even as the number of operands increases, while LLaMa's performance deteriorates. For details, see Section \ref{sec:arithmetic}.
  • Figure 5: Mean retrieval weight for each buffered representation across subsequent layers. Larger diagonal values confirm reliance on the current residual stream, while the pronounced off-diagonal weights for the earliest buffers and the repeated reuse of intermediate ones show that the model systematically retrieves earlier features, providing auxiliary memory and helping to mitigate representation collapse. See Section \ref{subsec:analysing_routings} for more details.
  • ...and 7 more figures
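
The two value diagnostics in Figure 2 can be approximated with the NumPy sketch below. It assumes that matrix entropy is the Shannon entropy of the trace-normalized eigenvalue spectrum of the Gram matrix of the value vectors (one common definition; the paper's exact normalization may differ), and it swaps in a plain softmax-regression linear probe with 5-fold cross-validation for whatever linear classifier the authors used. The function names `matrix_entropy` and `linear_probe_accuracy` and all hyperparameters are hypothetical.

```python
import numpy as np


def softmax(z):
    """Row-wise numerically stable softmax for a 2-D array of logits."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def matrix_entropy(values, eps=1e-12):
    """Assumed definition: entropy of the Gram-matrix spectrum of N vectors (N, d).

    Eigenvalues of Z Z^T are the squared singular values of Z; after dividing
    by the trace they form a distribution. Collapsed (rank-1) values give 0,
    N mutually orthogonal vectors of equal norm give log(N).
    """
    s = np.linalg.svd(np.asarray(values, dtype=float), compute_uv=False)
    p = s**2 / np.sum(s**2)
    p = p[p > eps]  # drop numerical zeros before taking the log
    return float(-np.sum(p * np.log(p)))


def linear_probe_accuracy(x, y, folds=5, steps=300, lr=0.5, seed=0):
    """Hypothetical probe: mean and std of held-out accuracy of a linear softmax classifier."""
    x = np.asarray(x, dtype=float)
    classes, y_idx = np.unique(y, return_inverse=True)
    order = np.random.default_rng(seed).permutation(len(x))
    accs = []
    for test_idx in np.array_split(order, folds):
        train = np.ones(len(x), dtype=bool)
        train[test_idx] = False
        # Standardize features with training-fold statistics only.
        mu, sd = x[train].mean(axis=0), x[train].std(axis=0) + 1e-8
        x_tr, x_te = (x[train] - mu) / sd, (x[test_idx] - mu) / sd
        onehot = np.eye(len(classes))[y_idx[train]]
        w = np.zeros((x.shape[1], len(classes)))
        b = np.zeros(len(classes))
        for _ in range(steps):  # full-batch gradient descent on cross-entropy
            grad = (softmax(x_tr @ w + b) - onehot) / len(x_tr)
            w -= lr * x_tr.T @ grad
            b -= lr * grad.sum(axis=0)
        pred = np.argmax(x_te @ w + b, axis=1)
        accs.append(np.mean(pred == y_idx[test_idx]))
    return float(np.mean(accs)), float(np.std(accs))


# Usage on synthetic "values": collapsed vs. spread-out vectors, and two clusters.
rng = np.random.default_rng(0)
collapsed = np.tile(rng.normal(size=(1, 16)), (50, 1))  # 50 identical vectors
spread = rng.normal(size=(50, 16))
print(matrix_entropy(collapsed), matrix_entropy(spread))  # near 0 vs. clearly positive
labels = np.repeat([0, 1], 50)
x = np.vstack([rng.normal(-3, 1, (50, 16)), rng.normal(3, 1, (50, 16))])
print(linear_probe_accuracy(x, labels))  # close to perfect accuracy for separable clusters
```

Under these assumptions, lower matrix entropy signals that value vectors occupy fewer effective directions (the collapse the captions describe), while the cross-validated probe accuracy measures how linearly separable the values of different tokens remain.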