Faster Language Models with Better Multi-Token Prediction Using Tensor Decomposition

Artem Basharin, Andrei Chertkov, Ivan Oseledets

TL;DR

The paper tackles the latency of sampling in autoregressive transformers by modeling the joint distribution of the next $n$ tokens with a rank-$r$ canonical polyadic (CP) decomposition. The model uses $n$ heads that produce the CP factor matrices and a set of mixture weights computed from the context embedding, so the joint prediction can be read as a mixture of experts. This enables simultaneous multi-token generation with low overhead. Empirically, higher CP ranks lower the joint loss and raise draft acceptance rates in speculative decoding. This translates into faster inference for text and code generation, and the effect holds across model sizes and fine-tuning regimes. The method integrates with self-speculative decoding and supports efficient training through an auxiliary load-balancing loss, making it practical for large-scale LLMs.
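As we read the summary, the rank-$r$ model scores a block of $n$ future tokens as $p_\theta(x_{t+1}, \ldots, x_{t+n} \mid x_1, \ldots, x_t) \approx \sum_{\alpha=1}^{r} w_\alpha \prod_{s=1}^{n} P_{\theta}^{(s)}[\alpha, x_{t+s}]$, where each row of a factor matrix $P_{\theta}^{(s)}$ is a distribution over the vocabulary. With $r = 1$ this reduces to independent per-position heads. The sketch below is a minimal pure-Python illustration under that reading. The function names `joint_prob` and `sample_block` are hypothetical, and the paper's exact indexing and normalization may differ. It shows that a rank-2 mixture can represent two perfectly correlated tokens, which a rank-1 model with the same marginals cannot.

```python
import random
from itertools import product
from math import isclose


def joint_prob(w, factors, tokens):
    """Hypothetical helper: p(tokens) = sum_a w[a] * prod_s factors[s][a][tokens[s]].

    w:       r mixture weights (non-negative, summing to 1), assumed to come
             from a linear layer on the context embedding
    factors: n factor matrices, each r x V, whose rows are distributions over V
    tokens:  n token ids for positions t+1 .. t+n
    """
    total = 0.0
    for a, w_a in enumerate(w):
        term = w_a
        for s, tok in enumerate(tokens):
            term *= factors[s][a][tok]
        total += term
    return total


def sample_block(w, factors, rng):
    """Hypothetical helper: draw n tokens jointly from the mixture.

    First pick a component (the "expert") a ~ w, then draw each token
    independently from row a of its own factor matrix.
    """
    a = rng.choices(range(len(w)), weights=w)[0]
    return [rng.choices(range(len(P[a])), weights=P[a])[0] for P in factors]


# Toy case: r = 2 components, n = 2 tokens, vocabulary {0, 1}.
# The two tokens are perfectly correlated: only (0, 0) and (1, 1) occur.
w = [0.5, 0.5]
P = [[[1.0, 0.0], [0.0, 1.0]],  # P^(1): one row per component
     [[1.0, 0.0], [0.0, 1.0]]]  # P^(2)

# Rank-1 model with the same marginals, i.e. independent heads.
w1 = [1.0]
P1 = [[[0.5, 0.5]], [[0.5, 0.5]]]

for x in product(range(2), repeat=2):
    # (0, 1): rank 2 gives 0.0, rank 1 gives 0.25
    print(x, joint_prob(w, P, x), joint_prob(w1, P1, x))
```

Sampling the component first and then each position independently is what makes joint generation cheap: after one choice of $\alpha$, the $n$ tokens can be drawn in parallel.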

Abstract

We propose a new model for multi-token prediction in transformers, aiming to enhance sampling efficiency without compromising accuracy. Motivated by recent work that predicts the probabilities of subsequent tokens using multiple heads, we connect this approach to rank-$1$ canonical tensor decomposition. By generalizing it to a rank-$r$ canonical probability decomposition, we develop an improved model that predicts multiple tokens simultaneously. This model can also be interpreted as a mixture of experts, allowing us to leverage successful techniques from that domain for efficient and robust training. Importantly, the overall overhead for training and sampling remains low. Our method demonstrates significant improvements in inference speed for both text and code generation tasks, proving particularly beneficial within the self-speculative decoding paradigm. It maintains its effectiveness across various model sizes and training epochs, highlighting its robustness and scalability.
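In self-speculative decoding, the multi-token heads propose a draft, and the model's own next-token head checks it in a single forward pass. A higher draft acceptance rate therefore means more tokens emitted per pass. The sketch below shows only the verification step, and it assumes greedy (argmax) acceptance. The paper may use a different acceptance rule, such as sampling-based verification. The function name `verify_greedy` is hypothetical.

```python
def verify_greedy(draft, target_argmax):
    """Greedy verification step of (self-)speculative decoding (a sketch).

    draft:         k tokens proposed by the multi-token heads
    target_argmax: k + 1 greedy tokens from the base next-token head, all
                   computed in one forward pass over context + draft; entry i
                   is the token that should follow the context plus draft[:i]
    Returns the tokens emitted in this step (between 1 and k + 1 tokens).
    """
    if len(target_argmax) != len(draft) + 1:
        raise ValueError("need one verifier prediction per draft token, plus one")
    accepted = []
    for proposed, target in zip(draft, target_argmax):
        if proposed != target:
            break  # first disagreement: reject this and every later draft token
        accepted.append(proposed)
    # Always emit the verifier's token at the first rejected position (or the
    # extra token after a fully accepted draft), so each step makes progress.
    accepted.append(target_argmax[len(accepted)])
    return accepted


print(verify_greedy([5, 6, 9], [5, 6, 7, 8]))  # [5, 6, 7]
print(verify_greedy([5, 6, 7], [5, 6, 7, 8]))  # [5, 6, 7, 8]
```

Under greedy acceptance the output is identical to plain greedy decoding with the base head. Better drafts change only how many tokens each forward pass yields, not which tokens are produced.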

Paper Structure

This paper contains 15 sections, 15 equations, 6 figures, 4 tables.

Figures (6)

  • Figure 1: Schematic representation of the proposed model, which predicts several tokens at once for a given sequence $x_{1}, x_{2}, \ldots, x_{t}$. The case of $n = 3$ predicted tokens $x_{t+1}, x_{t+2}, x_{t+3}$ is shown, with three heads that generate the factor matrices $P_{\theta}^{(1)}$, $P_{\theta}^{(2)}$, and $P_{\theta}^{(3)}$ of the canonical decomposition, and a linear layer that generates the weights $w$.
  • Figure 2: Comparison of the original (left) and reduced (right) tensor head designs. $W$ is kept fixed during fine-tuning.
  • Figure 3: Losses for the tiny transformer model with different CP-rank values trained on the TinyStories dataset.
  • Figure 4: Losses for the rank-8 tiny transformer model trained on the TinyStories dataset with different auxiliary loss penalties compared to the baseline (i.e., the rank-1 model).
  • Figure 5: Speculative decoding performance for the tiny transformer model with different CP-rank values trained on the TinyStories dataset from scratch.
  • ...and 1 more figure
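Figure 4 compares auxiliary-loss penalties for the rank-8 model. The summary says only that training uses auxiliary load balancing over the mixture components, so the sketch below substitutes the standard Switch Transformer load-balancing loss as an assumption. The paper's exact penalty may differ. The names `load_balancing_loss` and `aux_penalty` are hypothetical.

```python
def load_balancing_loss(batch_weights):
    """Switch-Transformer-style auxiliary loss over r mixture components (assumed form).

    batch_weights: B lists, each holding the r mixture weights w for one
                   position (non-negative, summing to 1).
    Returns r * sum_a f_a * P_a, where f_a is the fraction of positions whose
    largest weight is component a and P_a is the mean weight assigned to a.
    The value is 1.0 when routing is spread evenly and rises (up to r) as
    the mixture collapses onto a single component.
    """
    B = len(batch_weights)
    r = len(batch_weights[0])
    counts = [0] * r
    mean_w = [0.0] * r
    for w in batch_weights:
        top = max(range(r), key=lambda a: w[a])  # argmax; ties go to the lowest index
        counts[top] += 1
        for a in range(r):
            mean_w[a] += w[a] / B
    return r * sum((counts[a] / B) * mean_w[a] for a in range(r))


balanced = [[0.9, 0.1], [0.1, 0.9]]   # each component wins once
collapsed = [[0.9, 0.1], [0.8, 0.2]]  # component 0 wins everywhere
print(load_balancing_loss(balanced))   # approximately 1.0
print(load_balancing_loss(collapsed))  # approximately 1.7
```

In training, this term would be added to the main joint loss with a hypothetical coefficient, e.g. `total = joint_loss + aux_penalty * load_balancing_loss(weights)`. The penalty discourages the rank-$r$ mixture from degenerating into a single effective component, which would reduce it to the rank-1 baseline.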