Tree Cross Attention

Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Yoshua Bengio, Mohamed Osama Ahmed

TL;DR

Tree Cross Attention (TCA) is a proposed module based on Cross Attention that retrieves information from only a logarithmic $\mathcal{O}(\log(N))$ number of tokens when performing inference.

Abstract

Cross Attention is a popular method for retrieving information from a set of context tokens for making predictions. At inference time, for each prediction, Cross Attention scans the full set of $\mathcal{O}(N)$ tokens. In practice, however, often only a small subset of tokens is required for good performance. Methods such as Perceiver IO are cheap at inference as they distill the information to a smaller set of $L < N$ latent tokens on which cross attention is then applied, resulting in only $\mathcal{O}(L)$ complexity. However, in practice, as the number of input tokens and the amount of information to distill increase, the number of latent tokens needed also increases significantly. In this work, we propose Tree Cross Attention (TCA), a module based on Cross Attention that retrieves information from only a logarithmic $\mathcal{O}(\log(N))$ number of tokens when performing inference. TCA organizes the data in a tree structure and performs a tree search at inference time to retrieve the relevant tokens for prediction. Leveraging TCA, we introduce ReTreever, a flexible architecture for token-efficient inference. We show empirically that Tree Cross Attention (TCA) performs comparably to Cross Attention across various classification and uncertainty regression tasks while being significantly more token-efficient. Furthermore, we compare ReTreever against Perceiver IO, showing significant gains while using the same number of tokens for inference.
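The per-prediction cost contrast described in the abstract can be sketched in plain Python. This is a minimal, hypothetical illustration, not the paper's code: it uses single-head scaled dot-product attention with no learned projections (tokens serve as both keys and values), and it replaces Perceiver IO's iterative attention encoder with a single cross-attention step from randomly drawn latent queries. The names `cross_attention` and `distill_to_latents` are our own.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cross_attention(query: Vector, tokens: List[Vector]) -> Tuple[Vector, int]:
    """Single-head scaled dot-product cross attention for one query.

    Keys and values are the tokens themselves (no learned projections, an
    assumption for brevity). Returns the attended vector and how many tokens
    were read to produce it.
    """
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, t) * scale for t in tokens]
    top = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    out = [sum(w * t[d] for w, t in zip(weights, tokens)) / total
           for d in range(len(tokens[0]))]
    return out, len(tokens)


def distill_to_latents(latent_queries: List[Vector], tokens: List[Vector]) -> List[Vector]:
    """Hypothetical one-step stand-in for Perceiver IO's iterative attention encoder:
    each of the L latent queries cross-attends once over all N tokens."""
    return [cross_attention(q, tokens)[0] for q in latent_queries]


random.seed(0)
D, N, L = 8, 1024, 32
tokens = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
latent_queries = [[random.gauss(0, 1) for _ in range(D)] for _ in range(L)]
query = [random.gauss(0, 1) for _ in range(D)]

_, read_full = cross_attention(query, tokens)         # Cross Attention: reads all N tokens
latents = distill_to_latents(latent_queries, tokens)  # one-time cost per context set
_, read_latent = cross_attention(query, latents)      # Perceiver IO style: reads only L latents
print(read_full, read_latent)  # 1024 32
```

The second return value is the point of the sketch: full Cross Attention reads all $N$ tokens for every query, while the latent variant reads only $L$ after paying a one-time distillation cost. TCA instead targets an $\mathcal{O}(\log(N))$ read per query.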

Paper Structure

This paper contains 39 sections, 6 equations, 14 figures, 16 tables, and 1 algorithm.

Figures (14)

  • Figure 1: Architecture Diagram of ReTreever. Input Array comprises a set of $N$ context tokens which are fed through an encoder to compute a set of context encodings. Query Array denotes a batch of $M$ query feature vectors. Tree Cross Attention organizes the encodings and constructs a tree $\mathcal{T}$. At inference time, given a query feature vector, a logarithmic-sized subset of nodes (encodings) is retrieved from the tree $\mathcal{T}$. The query feature vector retrieves information from the subset of encodings via Cross Attention and makes a prediction.
  • Figure 2: Diagram of the aggregation procedure performed during the Tree Construction phase. The aggregation procedure is performed bottom-up, beginning from the parents of the leaves and ending at the root of the tree. The complexity of this procedure is $\mathcal{O}(N)$, but it only needs to be performed once for a set of context tokens. Compared to the cost of performing multiple predictions, the one-time cost of the aggregation process is minor.
  • Figure 3: Example result of a Retrieval phase. The policy creates a path from the tree’s root to its leaves, selecting a subset of nodes: the terminal leaves in the path and the highest-level unexplored ancestors of the other leaves. The green arrows represent the path (actions) chosen by the policy $\pi$. The red arrows represent the actions rejected by the policy. The green nodes denote the subset of nodes selected, i.e., $\mathbb{S} = \{h_2, h_3, h_9, h_{10}\}$. The grey nodes denote nodes that were explored at some point but not selected. The red nodes denote nodes that were neither explored nor selected.
  • Figure 4: Analyses Plots. (left) Memory usage plot comparing the rate at which memory usage grows at inference time relative to the number of tokens. (middle) Training curve with varying weights ($\lambda_{CA}$) for the Cross Attention loss term. (right) Training curve with varying weights ($\lambda_{RL}$) for the RL retrieval loss term.
  • Figure 5: Architecture Diagram of Perceiver IO. The model is composed of an iterative attention encoder ($\mathbb{R}^{N \times D} \rightarrow \mathbb{R}^{L \times D}$) and a Cross Attention module ($\mathbb{R}^{M \times D} \times \mathbb{R}^{L \times D} \rightarrow \mathbb{R}^{M \times D}$).
  • ...and 9 more figures
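The Tree Construction and Retrieval phases described in Figures 2 and 3 can be sketched as follows. This is a sketch under stated assumptions, not the authors' implementation. It assumes a complete binary tree over a power-of-two number of encodings stored in a heap-ordered array. The mean of the two children stands in for the paper's aggregation function. A hypothetical greedy rule, which follows the child that scores higher against the query, replaces the RL-trained policy $\pi$. The names `build_tree`, `greedy_policy`, and `retrieve` are our own. If Figure 3's nodes are numbered breadth-first starting from $h_0$ at the root (our assumption about the figure), the scripted path below yields the same labels as its $\mathbb{S}$.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]
Policy = Callable[[int, int, int], int]  # (node, left_child, right_child) -> chosen child


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def build_tree(leaves: List[Vector]) -> List[Vector]:
    """Tree Construction: heap-ordered array where node i has children 2i+1 and 2i+2.

    The N leaves sit at indices N-1 .. 2N-2. Internal nodes are filled bottom-up
    (parents of leaves first, root last), i.e. N-1 aggregations: O(N), done once
    per context set. The mean is a hypothetical stand-in for the aggregator.
    """
    n = len(leaves)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch assumes a power-of-two number of leaves")
    nodes: List[Vector] = [[] for _ in range(n - 1)] + [list(v) for v in leaves]
    for i in range(n - 2, -1, -1):
        left, right = nodes[2 * i + 1], nodes[2 * i + 2]
        nodes[i] = [(a + b) / 2 for a, b in zip(left, right)]
    return nodes


def greedy_policy(nodes: List[Vector], query: Vector) -> Policy:
    """Hypothetical policy: follow the child whose encoding scores higher for the query."""
    def choose(node: int, left: int, right: int) -> int:
        return left if dot(query, nodes[left]) >= dot(query, nodes[right]) else right
    return choose


def retrieve(num_nodes: int, policy: Policy) -> List[int]:
    """Retrieval: walk from the root toward the leaves, keeping every rejected
    sibling (the highest-level unexplored ancestor of the leaves under it) and
    both leaves at the end of the path."""
    if num_nodes == 1:
        return [0]
    first_leaf = (num_nodes + 1) // 2 - 1
    selected: List[int] = []
    node = 0
    while True:
        left, right = 2 * node + 1, 2 * node + 2
        if left >= first_leaf:  # children are leaves: keep both and stop
            selected.extend([left, right])
            return selected
        chosen = policy(node, left, right)
        selected.append(right if chosen == left else left)
        node = chosen


# A scripted path 0 -> 1 -> 4 over 8 leaves (15 nodes).
scripted = {0: 1, 1: 4}
fig3 = retrieve(15, lambda node, left, right: scripted[node])
print(sorted(fig3))  # [2, 3, 9, 10]

# Greedy retrieval over 1024 random encodings.
random.seed(0)
D, N = 4, 1024
leaves = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
nodes = build_tree(leaves)
query = [random.gauss(0, 1) for _ in range(D)]
subset = retrieve(len(nodes), greedy_policy(nodes, query))
print(len(subset))  # 11 = log2(1024) + 1
```

In this scheme each query reads $\log_2(N) + 1$ nodes (11 for $N = 1024$), after the one-time $\mathcal{O}(N)$ construction. The query would then cross-attend only to the retrieved encodings, as in Figure 1.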