Latent Flow Transformer

Yen-Chen Wu, Feng-Ting Liao, Meng-Hsi Chen, Pei-Chen Ho, Farhang Nabiei, Da-shan Shiu

TL;DR

The paper tackles the inefficiency of deep transformer stacks by introducing the Latent Flow Transformer (LFT), which replaces a block of layers with a learned latent transport operator trained via Flow Matching. It further addresses velocity-path ambiguity with Flow Walking (FW), enabling more reliable latent transport and stronger compression. Empirical results on Pythia-410M show substantial layer reduction (6 or 12 of 24 layers replaced) with favorable KL divergence, and suggest that FW can recover further performance, narrowing the gap between autoregressive and flow-based generation paradigms. The work advances parameter-efficient language modeling by leveraging continuous-time flow concepts, offering a pathway toward leaner LLMs that preserve functional fidelity, along with the Recoupling Ratio as a practical metric for selecting which layers to replace.
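The flow-matching objective and the velocity-path ambiguity mentioned above can be illustrated with a minimal sketch. It assumes the common linear-interpolation path: x0 is the hidden state entering the replaced block, x1 the hidden state leaving it, x_t = (1 - t) x0 + t x1, and the regression target is the constant velocity x1 - x0. The helper names (`lerp`, `flow_matching_loss`) and the pure-Python vector handling are illustrative assumptions, not the authors' code.

```python
import random


def lerp(x0, x1, t):
    """Point on the straight path from x0 to x1 at time t in [0, 1]."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def flow_matching_loss(velocity, pairs, rng):
    """Monte-Carlo estimate of a conditional flow-matching loss.

    velocity: callable (x_t, t) -> list of floats, standing in for the learned estimator.
    pairs:    list of (x0, x1) vectors, e.g. hidden states before/after the replaced block.
    rng:      random.Random used to draw one time t per pair.
    """
    total = 0.0
    for x0, x1 in pairs:
        t = rng.random()                              # t ~ U[0, 1)
        xt = lerp(x0, x1, t)                          # sample on the straight path
        target = [b - a for a, b in zip(x0, x1)]      # its constant velocity x1 - x0
        pred = velocity(xt, t)
        total += sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)
    return total / len(pairs)


# Two 1-D pairs whose straight paths cross at x = 1, t = 0.5.
crossing = [([0.0], [2.0]), ([2.0], [0.0])]
print(flow_matching_loss(lambda x, t: [0.0], crossing, random.Random(0)))  # 4.0
```

In the crossing example the two pairs demand velocities +2 and -2 at the same point (x = 1, t = 0.5), so the loss-minimizing velocity there is their average, 0, and a deterministic flow cannot send each input to its own partner once the paths meet. This is the kind of velocity-path ambiguity that the summary attributes to standard flow matching.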

Abstract

Transformers, the standard implementation for large language models (LLMs), typically consist of tens to hundreds of discrete layers. While more layers can lead to better performance, this approach has been challenged as far from efficient, especially given the superiority of continuous layers demonstrated by diffusion and flow-based models for image generation. We propose the Latent Flow Transformer (LFT), which replaces a block of layers with a single learned transport operator trained via flow matching, offering significant compression while maintaining compatibility with the original architecture. Additionally, we address the limitations of existing flow-based methods in preserving coupling by introducing the Flow Walking (FW) algorithm. On the Pythia-410M model, LFT trained with flow matching compresses 6 of 24 layers and outperforms directly skipping 2 layers (KL divergence of LM logits: 0.407 vs. 0.529), demonstrating the feasibility of this design. When trained with FW, LFT further distills 12 layers into one while achieving a KL of 0.736, lower than that from skipping 3 layers (0.932), significantly narrowing the gap between autoregressive and flow-based generation paradigms.
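The abstract compares variants by the KL divergence between the LM logits of the original model and those of the compressed one. A minimal pure-Python sketch of that metric, assuming the direction KL(p_x || p_x̂) suggested by Figure 5's label (the original model's next-token distribution as the reference) and a plain mean over token positions; the exact averaging used in the paper is not stated here, and the function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def kl_from_logits(ref_logits, pred_logits):
    """KL(p_ref || p_pred) between the softmax distributions of two logit vectors."""
    log_p = log_softmax(ref_logits)
    log_q = log_softmax(pred_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(log_p, log_q))


def mean_kl(ref_seq, pred_seq):
    """Average per-position KL over a sequence of next-token logit vectors."""
    return sum(kl_from_logits(r, p) for r, p in zip(ref_seq, pred_seq)) / len(ref_seq)


# Uniform reference vs. a prediction that puts 3/4 of the mass on token 1.
print(mean_kl([[0.0, 0.0]], [[0.0, math.log(3.0)]]))  # about 0.1438 (= 0.5 * ln(4/3))
```

Working in log space via `log_softmax` avoids overflow for large logits, which matters for real vocabularies with tens of thousands of entries.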

Paper Structure

This paper contains 19 sections, 9 equations, 6 figures, 1 table, 3 algorithms.

Figures (6)

  • Figure 1: Velocity field estimator. (a) The DiT block, mapping an input hidden state to an output hidden state conditioned on time $t$. (b) Velocity field estimator derived from the DiT block, mapping an input state $x_t$ and time $t$ to a velocity $v_t$. (c) Velocity field estimator for Pythia.
  • Figure 2: Static structure of an unrolled LFT. We highlight only the latent flow layer. The simple reconstruction rule of the paper's take-one-step equation is assumed. (a) LFT based on Pythia with single-step reconstruction. (b) LFT based on Pythia with two-step reconstruction.
  • Figure 3: Toy trajectories for paired data. Faded lines show the ground‐truth trajectories of paired points, while solid curves depict predictions from various flow‐matching algorithms. (a) Standard flow matching fails to maintain pairwise correspondence. (b) FW results for $k=1$ (top) and $k=2$ (bottom): using fewer integration steps leads to poor trajectory estimates. (c) FW with $k=3$: projected view (top) and full trajectories (bottom) demonstrate how the learned velocity field generates smooth, curved paths that avoid intersections. (d) A hybrid of FW and standard flow matching with $\alpha=0.001$ preserves the constant velocity of paired data while allowing curvature at intersections to prevent flow crossings.
  • Figure 4: Latent distillation with flow matching vs. layer-skipping baselines. Validation NMSE between predicted and original hidden states vs. training tokens. The dotted lines show Pythia layer-skipping baselines. Flow-matching layers approximate the original representations with lower reconstruction error than the layer-skipping baselines in both settings.
  • Figure 5: Inference performance ($\textbf{KL}_{x||\hat{x}}$ and NMSE) vs. number of discrete points $k$. Results are reported for LFT-SFM and LFT-FW with the latent flow layer replacing layers 6-18 of Pythia-410M.
  • (One additional figure not listed.)
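Figures 2, 3 and 5 all depend on how many discrete points k the latent flow layer is integrated with. The sketch below assumes plain forward-Euler steps on a uniform grid (so k = 1 would presumably correspond to the single-step reconstruction of Figure 2a and k = 2 to the two-step one in 2b) and NMSE defined as squared error over the reference's squared norm. `walk_endpoint_loss` is a hypothetical endpoint-matching objective in the spirit of FW, walking from x0 to t = 1 through k randomly spaced steps; the paper's Flow Walking algorithm may differ in its step rule, time sampling and weighting.

```python
import math
import random


def euler_unroll(velocity, x0, k):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with k uniform Euler steps."""
    x, dt = list(x0), 1.0 / k
    for i in range(k):
        v = velocity(x, i * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


def nmse(pred, ref):
    """Squared error normalized by the squared norm of the reference vector."""
    err = sum((p - r) ** 2 for p, r in zip(pred, ref))
    return err / sum(r * r for r in ref)


def walk_endpoint_loss(velocity, x0, x1, k, rng):
    """HYPOTHETICAL FW-style objective (illustration only, not the paper's algorithm):
    walk from x0 to t = 1 through k Euler steps at randomly drawn intermediate
    times, then score the endpoint against the paired target x1 with a mean squared error."""
    times = [0.0] + sorted(rng.random() for _ in range(k - 1)) + [1.0]
    x = list(x0)
    for t0, t1 in zip(times[:-1], times[1:]):
        v = velocity(x, t0)
        x = [xi + (t1 - t0) * vi for xi, vi in zip(x, v)]
    return sum((a - b) ** 2 for a, b in zip(x, x1)) / len(x1)


def grow(x, t):
    """Curved toy flow dx/dt = x; the exact endpoint at t = 1 is e * x0."""
    return x


if __name__ == "__main__":
    for k in (1, 2, 4):
        x_hat = euler_unroll(grow, [1.0], k)
        print(k, x_hat[0], nmse(x_hat, [math.e]))
    # A constant field equal to x1 - x0 reaches x1 regardless of where the steps fall.
    print(walk_endpoint_loss(lambda x, t: [1.0, 2.0], [0.0, 0.0], [1.0, 2.0], 3, random.Random(0)))
```

On the toy flow the Euler endpoints are 2.0, 2.25 and 2.44140625 for k = 1, 2 and 4, approaching e ≈ 2.718, which shows why a curved latent trajectory needs more discrete points than a straight one.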