MemoryFormer: Minimize Transformer Computation by Removing Fully-Connected Layers

Ning Ding, Yehui Tang, Haochen Qin, Zhenli Zhou, Chao Xu, Lin Li, Kai Han, Heng Liao, Yunhe Wang

TL;DR

MemoryFormer tackles the dominant compute bottleneck in transformers, the fully-connected (FC) projections, by replacing them with a Memory Layer that uses in-memory hash tables and locality-sensitive hashing to approximate linear projections. The architecture preserves standard multi-head attention (MHA) while replacing the FFN with two Memory Layers that first expand and then contract the dimensionality, reducing the FLOPs of each block while remaining trainable end-to-end. Experiments across multiple Pythia scales and NLP benchmarks show MemoryFormer matching or exceeding baseline accuracy with substantially fewer FLOPs, and it outperforms several efficient-transformer variants that mainly reduce attention cost. The work also offers hardware-design insights, suggesting that greater memory bandwidth and better cache efficiency could further improve performance and make large-scale inference more feasible. Overall, MemoryFormer demonstrates a viable path to compute-efficient transformers by reimagining feature transformation entirely in the embedding space through learnable hashing.
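To see why the FC projections dominate, the sketch below counts approximate FLOPs for one standard transformer block. It is a back-of-the-envelope model, not the accounting behind Figure 1: it assumes a multiply-add costs 2 FLOPs, the Q/K/V/O projections are d×d, the FFN expands d to 4d and back, and softmax, normalization, biases, and activations are ignored. The function name `block_flops` is hypothetical.

```python
def block_flops(s: int, d: int, ffn_mult: int = 4) -> dict:
    """Approximate FLOPs of one standard transformer block for s tokens of width d.

    Assumptions: multiply-add = 2 FLOPs; Q/K/V/O projections are d x d;
    FFN is d -> ffn_mult*d -> d; softmax, norms, biases, activations ignored.
    """
    qkvo = 4 * 2 * s * d * d                 # four d x d projections
    ffn = 2 * 2 * s * d * (ffn_mult * d)     # up- and down-projection of the FFN
    attn = 2 * 2 * s * s * d                 # QK^T scores plus weighting of V
    return {"fc": qkvo + ffn, "attention": attn}


if __name__ == "__main__":
    f = block_flops(s=2048, d=2048)
    print(f["fc"] / (f["fc"] + f["attention"]))  # ≈ 0.857
```

Under these assumptions the FC layers cost 24sd² FLOPs and attention costs 4s²d, so the projections dominate whenever s < 6d; at s = d = 2048 they account for 24/28 ≈ 86% of the block's FLOPs.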

Abstract

To reduce the computational complexity of large language models, great efforts have been made to improve the efficiency of transformer models, for example through linear attention and flash-attention. However, model size, and with it computational complexity, keeps being scaled up in pursuit of higher performance. In this work, we present MemoryFormer, a novel transformer architecture that significantly reduces computational complexity (FLOPs) from a new perspective. We eliminate nearly all the computation of the transformer model except for that required by the multi-head attention operation. This is made possible by an alternative method of feature transformation that replaces the linear projection of fully-connected layers. Specifically, we first construct a group of in-memory lookup tables that store a large number of discrete vectors, replacing the weight matrix used in linear projection. We then use a hash algorithm to dynamically retrieve a correlated subset of these vectors based on the input embedding. The retrieved vectors are combined to form the output embedding, which approximates the result of the matrix multiplication in a fully-connected layer. Compared with matrix multiplication, retrieving data blocks from memory is a much cheaper operation that requires little computation. We train MemoryFormer from scratch and conduct extensive experiments on various benchmarks to demonstrate the effectiveness of the proposed model.
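The lookup-and-sum idea can be illustrated with a minimal, forward-only sketch. It assumes the input of dimension d_in is split into K = d_in/τ chunks (the τ and K of Figure 2), that each chunk is hashed by reading the signs of its τ entries as a τ-bit bucket index, and that each of the K tables holds 2^τ vectors of dimension d_out; the output is the sum of the K retrieved vectors. The names `hash_chunk` and `MemoryLayerSketch`, the sign-based hash, and the random initialization are assumptions, and the sketch omits whatever mechanism the paper uses to make the tables trainable end-to-end.

```python
import random
from typing import List


def hash_chunk(z: List[float]) -> int:
    """Map a tau-dim chunk to a bucket index in [0, 2**tau).

    Assumption: bit = 1 if the entry is > 0, with z[0] as the most significant bit.
    """
    idx = 0
    for v in z:
        idx = (idx << 1) | (1 if v > 0 else 0)
    return idx


class MemoryLayerSketch:
    """Hypothetical, forward-only stand-in for a d_in -> d_out linear projection."""

    def __init__(self, d_in: int, d_out: int, tau: int, seed: int = 0):
        if d_in % tau != 0:
            raise ValueError("d_in must be divisible by tau")
        self.tau = tau
        self.num_tables = d_in // tau  # K in Figure 2
        self.d_out = d_out
        rng = random.Random(seed)
        # K tables; table k has 2**tau buckets, each holding a d_out-dim vector.
        self.tables = [
            [[rng.gauss(0.0, 0.02) for _ in range(d_out)] for _ in range(2 ** tau)]
            for _ in range(self.num_tables)
        ]

    def forward(self, x: List[float]) -> List[float]:
        if len(x) != self.num_tables * self.tau:
            raise ValueError("input length must equal d_in")
        y = [0.0] * self.d_out
        for k in range(self.num_tables):
            z_k = x[k * self.tau:(k + 1) * self.tau]  # k-th chunk of the input
            row = self.tables[k][hash_chunk(z_k)]     # retrieve one stored vector
            for j in range(self.d_out):
                y[j] += row[j]                         # sum the retrieved vectors
        return y


if __name__ == "__main__":
    layer = MemoryLayerSketch(d_in=6, d_out=4, tau=2)  # tau = 2, K = 3 as in Figure 2
    print(layer.forward([0.3, -1.2, -0.5, 0.8, 1.1, 0.4]))
```

Per token, this forward pass does d_in sign tests and (d_in/τ)·d_out additions with no multiplications, a factor of τ fewer accumulated terms than the d_in·d_out multiply-adds of a dense layer, while each table grows as 2^τ rows.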

Paper Structure

This paper contains 23 sections, 10 equations, 4 figures, 6 tables.

Figures (4)

  • Figure 1: FLOPs with different model hidden size and sequence lengths.
  • Figure 2: A demonstration with $\tau=2$ and $K=3$, where $\mathbf z_1$ is hashed to bucket 2 of $\mathbf T_1$, $\mathbf z_2$ is hashed to bucket 1 of $\mathbf T_2$, and $\mathbf z_3$ is hashed to bucket 2 of $\mathbf T_3$.
  • Figure 3: Left: The schematic diagram of the Memory Layer. Right: One building block of the MemoryFormer.
  • Figure 4: The frequency at which each bucket in the hash table is retrieved.
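Figure 4 reports how often each bucket in the hash table is retrieved. A statistic of this kind could be computed by hashing every token's chunks and counting hits per bucket, as in the sketch below. It assumes the same sign-based hash as above (τ-bit index, first entry as most significant bit); the function names are hypothetical, and the random Gaussian inputs in the usage example are synthetic, not the paper's data.

```python
import random
from collections import Counter
from typing import List


def hash_chunk(z: List[float]) -> int:
    """Assumed sign-based hash: tau entries -> tau-bit bucket index."""
    idx = 0
    for v in z:
        idx = (idx << 1) | (1 if v > 0 else 0)
    return idx


def bucket_frequencies(tokens: List[List[float]], tau: int) -> List[List[float]]:
    """For each of the K = d / tau tables, return the fraction of tokens hitting each bucket."""
    d = len(tokens[0])
    num_tables = d // tau
    counts = [Counter() for _ in range(num_tables)]
    for x in tokens:
        for k in range(num_tables):
            counts[k][hash_chunk(x[k * tau:(k + 1) * tau])] += 1
    n = len(tokens)
    return [[counts[k][b] / n for b in range(2 ** tau)] for k in range(num_tables)]


if __name__ == "__main__":
    rng = random.Random(0)
    # Synthetic tokens with d = 6, so tau = 2 gives K = 3 tables of 4 buckets each.
    tokens = [[rng.gauss(0.0, 1.0) for _ in range(6)] for _ in range(10_000)]
    for k, freqs in enumerate(bucket_frequencies(tokens, tau=2)):
        print(f"table {k}: " + ", ".join(f"{f:.3f}" for f in freqs))
```

For zero-mean, independent inputs each bucket should be hit roughly 1/2^τ of the time; a strongly skewed histogram would indicate that some table rows are rarely used.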