Rethinking Weight Tying: Pseudo-Inverse Tying for Stable LM Training and Updates

Jian Gu, Aldeida Aleti, Chunyang Chen, Hongyu Zhang

TL;DR

This work tackles instability from token-interface drift in weight-tied language models by introducing Pseudo-Inverse Tying (PIT). PIT constructs a shared token memory Z and a learnable SPD transform T to realize E = Z T^{-1} and W_out = T Z^T, guaranteeing a pseudo-inverse-consistent interface with W_out E = I_d whenever Z^T Z = I_d. The approach uses a thin polar decomposition to obtain an orthonormal Z and parameterizes T via a Cholesky factor T = L L^T, enabling efficient, stable forward passes without explicit pseudo-inverse recomputation. Across on-device to edge-server scales (256M–1.3B) and both teacher-mode and scratch-mode setups, PIT yields stronger layerwise semantic consistency, improved training stability, and reduced side effects during lightweight post-training updates, demonstrating a practical inductive bias for robust LM training and deployment. This framework also strengthens mechanistic interpretability by providing a stable, comparable vocabulary geometry across layers and checkpoints, facilitating safer model editing and adaptation in resource-constrained settings.
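
The identity W_out E = I_d follows in one line from the definitions above. As a worked check (assuming Z has shape V x d with orthonormal columns and T is a symmetric positive definite d x d matrix, which is how the summary describes them), the same algebra also shows that W_out is the Moore-Penrose pseudo-inverse of E:

```latex
\begin{align}
W_{\mathrm{out}} E &= T Z^{\top} Z\, T^{-1} = T\, I_d\, T^{-1} = I_d, \\
E^{+} &= (E^{\top} E)^{-1} E^{\top}
       = \left(T^{-1} Z^{\top} Z\, T^{-1}\right)^{-1} T^{-1} Z^{\top}
       = T^{2}\, T^{-1} Z^{\top}
       = T Z^{\top} = W_{\mathrm{out}}.
\end{align}
```

The second line uses T^T = T and the fact that E has full column rank, so the left pseudo-inverse formula applies.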

Abstract

Weight tying is widely used in compact language models to reduce parameters by sharing the token table between the input embedding and the output projection. However, weight sharing does not guarantee a stable token interface: during training, the correspondence between encoding tokens into hidden states and decoding hidden states into logits can drift, worsening optimization sensitivity and making post-training interventions such as editing, patching, and lightweight adaptation less predictable. We propose Pseudo-Inverse Tying (PIT), which synchronizes embedding and unembedding as coupled projections of a shared latent token memory, guaranteeing a pseudo-inverse-consistent interface throughout training. PIT maintains an orthonormal shared memory, obtained by thin polar decomposition for teacher initialization or random orthonormal initialization from scratch, and introduces a fully learned symmetric positive definite hidden-space transform parameterized via a Cholesky factor. The output head applies this transform to hidden states before the vocabulary projection, while the embedding applies the inverse transform to token vectors using stable triangular solves, avoiding explicit pseudo-inverse recomputation and any vocabulary-sized auxiliary parameters. We evaluate PIT on on-device models spanning 256M-1.3B parameters across pretraining and adaptation, and consistently observe improved training stability, stronger layerwise semantic consistency, and substantially reduced side effects.
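
The mechanism described in the abstract can be illustrated with a small NumPy sketch. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names (`polar_orthonormal`, `cholesky_factor`, `embed`, `unembed`) are hypothetical, the positive diagonal of L is enforced with an exponential (one common choice; the paper may use another), and `np.linalg.solve` stands in for a dedicated triangular solver.

```python
import numpy as np

def polar_orthonormal(W):
    """Orthonormal-column polar factor of W (V x d), computed as U V^T from the thin SVD."""
    U, _, Vt = np.linalg.svd(W, full_matrices=False)
    return U @ Vt

def cholesky_factor(L_raw):
    """Lower-triangular L with positive diagonal, so that T = L L^T is SPD.
    Assumption: positivity is enforced by exponentiating the diagonal."""
    return np.tril(L_raw, k=-1) + np.diag(np.exp(np.diag(L_raw)))

def embed(Z, L, token_ids):
    """Rows of E = Z T^{-1} for the given tokens, using two solves with L
    instead of forming T^{-1} explicitly."""
    z = Z[token_ids].T                # (d, n) columns are token vectors
    y = np.linalg.solve(L, z)         # solve L y = z
    x = np.linalg.solve(L.T, y)       # solve L^T x = y, so x = T^{-1} z
    return x.T                        # (n, d)

def unembed(Z, L, h):
    """Logits h W_out with W_out = T Z^T: transform hidden states by T, then project."""
    return ((h @ L) @ L.T) @ Z.T

rng = np.random.default_rng(0)
V, d = 50, 8                          # toy vocabulary size and hidden width
Z = polar_orthonormal(rng.standard_normal((V, d)))
L = cholesky_factor(0.1 * rng.standard_normal((d, d)))

E = embed(Z, L, np.arange(V))         # full embedding table, (V, d)
W_out = (L @ L.T) @ Z.T               # output head, (d, V)
interface_error = np.abs(W_out @ E - np.eye(d)).max()
```

In this toy setting `interface_error` should be at floating-point noise level, reflecting W_out E = I_d. A random Gaussian matrix is used here in place of a teacher embedding table; in practice a triangular-specific solver (for example `scipy.linalg.solve_triangular`) would exploit the structure of L.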

Paper Structure (48 sections, 20 equations, 4 figures, 3 tables, 1 algorithm)

This paper contains 48 sections, 20 equations, 4 figures, 3 tables, 1 algorithm.

Figures (4)

  • Figure 1: Two perspectives of token semantics in the residual stream: input-side ($B_{\mathrm{in}}$) vs. output-side ($B_{\mathrm{out}}$, Logit Lens).
  • Figure 2: Layerwise semantic transition (transition trace) of the medium token during next-token prediction, under a stable vs. drifting token interface.
  • Figure 3: In Scratch-Mode, both PIT and TT show stable loss curves, and TT maintains a lower loss throughout training.
  • Figure 4: In Teacher-Mode, PIT trains more stably than TT and reaches a lower loss. TT shows an anomalous loss increase in the early phase, which may be caused by a misaligned token interface.