The Strong Lottery Ticket Hypothesis for Multi-Head Attention Mechanisms
Hikari Otsuka, Daiki Chijiwa, Yasuyuki Okoshi, Daichi Fujiki, Susumu Takeuchi, Masato Motomura
TL;DR
The paper extends the strong lottery ticket hypothesis (SLTH) to transformers by proving that strong lottery tickets (SLTs) exist within multi-head attention (MHA) mechanisms. It shows that a randomly initialized MHA with $H$ heads and input dimension $d$ contains, with high probability, an SLT that approximates any target MHA to within error $\epsilon$ when the key/value hidden dimensions satisfy $n_K,n_V = O\big(d\log(Hd^{3/2}/\epsilon)\big)$. The proof leverages a weight-merge reinterpretation and a two-layers-for-one approximation variant. The result is extended to transformers without normalization layers, with a block-wise error that decays exponentially in the hidden dimensions and is independent of sequence length. Empirically, the authors validate the theory on synthetic angular-velocity tasks and language-modeling settings, identifying weight-initialization scales that improve SLTs in practice and demonstrating exponential error decay as the hidden dimensions grow. The findings offer a theoretical foundation for pruning-based, training-free subnetworks in modern transformer architectures and suggest new directions for understanding overparameterization in sequence models.
Abstract
The strong lottery ticket hypothesis (SLTH) conjectures that high-performing subnetworks, called strong lottery tickets (SLTs), are hidden in randomly initialized neural networks. Although recent theoretical studies have established the SLTH across various neural architectures, the SLTH for transformer architectures still lacks theoretical understanding. In particular, the current theory of the SLTH does not yet account for the multi-head attention (MHA) mechanism, a core component of transformers. To address this gap, we introduce a theoretical analysis of the existence of SLTs within MHAs. We prove that, if a randomly initialized MHA with $H$ heads and input dimension $d$ has hidden dimension $O(d\log(Hd^{3/2}))$ for the key and value, then with high probability it contains an SLT that approximates an arbitrary MHA with the same input dimension. Furthermore, by leveraging this theory for MHAs, we extend the SLTH to transformers without normalization layers. We empirically validate our theoretical findings, demonstrating that the approximation error between the SLT within a source model (MHA and transformer) and the target counterpart it approximates decreases exponentially as the hidden dimension of the source model increases.
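To make the notion of an SLT inside an attention head concrete, the following is a minimal sketch in plain Python, not the authors' construction. It assumes a single head of the form softmax(X W_Q (X W_K)^T) X W_V W_O with no score scaling, and all names (`attention_head`, `prune`, `n_source`, `n_target`) are hypothetical. The source here is deliberately contrived: the target weights are embedded in a wider source so that a binary mask recovers the target exactly. In the setting of the paper the source weights are purely random and the mask only approximates the target.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def prune(w, mask):
    """Apply a binary mask: entries whose mask bit is 0 are set to zero."""
    return [[x * m for x, m in zip(w_row, m_row)] for w_row, m_row in zip(w, mask)]


def softmax_rows(a):
    out = []
    for row in a:
        peak = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention_head(x, w_q, w_k, w_v, w_o):
    """Assumed form: softmax(X W_Q (X W_K)^T) X W_V W_O, without score scaling."""
    q, k = matmul(x, w_q), matmul(x, w_k)
    scores = softmax_rows(matmul(q, transpose(k)))
    return matmul(matmul(matmul(scores, x), w_v), w_o)


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


rng = random.Random(0)
seq_len, d, n_target, n_source = 3, 2, 2, 5
extra = n_source - n_target
x = random_matrix(seq_len, d, rng)

# Target head: W_Q, W_K, W_V are d x n_target, W_O is n_target x d.
target = [random_matrix(d, n_target, rng) for _ in range(3)]
target.append(random_matrix(n_target, d, rng))

# Contrived wider source: target weights plus extra random hidden units.
source = [[row + [rng.uniform(-1.0, 1.0) for _ in range(extra)] for row in w]
          for w in target[:3]]
source.append(target[3] + random_matrix(extra, d, rng))

# Binary masks that keep the first n_target hidden units and prune the rest.
col_mask = [[1] * n_target + [0] * extra for _ in range(d)]
row_mask = [[1] * d for _ in range(n_target)] + [[0] * d for _ in range(extra)]
masks = [col_mask, col_mask, col_mask, row_mask]
pruned = [prune(w, m) for w, m in zip(source, masks)]

y_target = attention_head(x, *target)
err_pruned = max_abs_diff(attention_head(x, *pruned), y_target)
err_unpruned = max_abs_diff(attention_head(x, *source), y_target)
print("error of pruned source:  ", err_pruned)
print("error of unpruned source:", err_unpruned)
```

The pruned source keeps its original weight values and only removes entries, which is the sense in which an SLT is found without training. The theoretical question addressed by the paper is how large the source hidden dimension must be for such a mask to exist when the source weights are random rather than hand-built as in this sketch.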
