
Controlling Computation versus Quality for Neural Sequence Models

Ankur Bapna, Naveen Arivazhagan, Orhan Firat

TL;DR

This work tackles the inefficiency of fixed compute in neural sequence models by introducing the Conditional Computation Transformer (CCT), which gates sub-networks with learned control networks and is trained to operate across multiple inference budgets. The approach trains gates as noisy continuous values that become discrete at inference. It also uses a multi-task objective in which each task corresponds to a computation budget. CCT is applied to Transformer-based machine translation and BERT-style representation learning. Empirical results show that CCT is competitive with vanilla Transformers at full budget and outperforms computationally equivalent baselines at lower budgets. Analyses also reveal that compute is allocated adaptively according to input difficulty and decoding dynamics. The method offers a practical path to on-demand, computation-aware inference for large sequence models, using a single model rather than a separately trained model for each budget.

Abstract

Most neural networks utilize the same amount of compute for every example, independent of the inherent complexity of the input. Further, methods that adapt the amount of computation to the example focus on finding a fixed inference-time computational graph per example, ignoring any external computational budgets or varying inference-time limitations. In this work, we utilize conditional computation to make neural sequence models (Transformers) more efficient and computation-aware during inference. We first modify the Transformer architecture, making each set of operations conditionally executable depending on the output of a learned control network. We then train this model in a multi-task setting, where each task corresponds to a particular computation budget. This allows us to train a single model that can be controlled to operate on different points of the computation-quality trade-off curve, depending on the available computation budget at inference time. We evaluate our approach on two tasks: (i) WMT English-French translation and (ii) unsupervised representation learning (BERT). Our experiments demonstrate that the proposed Conditional Computation Transformer (CCT) is competitive with vanilla Transformers when allowed to utilize its full computational budget, while improving significantly over computationally equivalent baselines when operating on smaller computational budgets.
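The multi-task training over computation budgets described above can be sketched as follows. This is a minimal illustration under stated assumptions and is not the authors' objective. The budget set `BUDGETS`, the cost-weighted expected-compute estimate computed from gate probabilities, and the squared-hinge penalty in `budget_penalty` are all hypothetical choices, because this summary does not give the paper's exact loss.

```python
import random
from typing import Sequence

# Hypothetical budget "tasks": fraction of total sub-network compute allowed to run.
BUDGETS = (0.25, 0.5, 0.75, 1.0)


def expected_compute_fraction(gate_probs: Sequence[float], costs: Sequence[float]) -> float:
    """Cost-weighted fraction of compute expected to execute, given gate open-probabilities."""
    total = sum(costs)
    return sum(p * c for p, c in zip(gate_probs, costs)) / total


def budget_penalty(used: float, budget: float) -> float:
    """Assumed squared hinge: zero within budget, grows quadratically once it is exceeded."""
    return max(0.0, used - budget) ** 2


def multitask_loss(task_loss: float, gate_probs: Sequence[float], costs: Sequence[float],
                   budget: float, weight: float = 1.0) -> float:
    """Task loss plus a weighted budget penalty for the budget sampled at this step."""
    used = expected_compute_fraction(gate_probs, costs)
    return task_loss + weight * budget_penalty(used, budget)


def sample_budget(rng: random.Random) -> float:
    """Each training step treats one budget, drawn uniformly here, as the current task."""
    return rng.choice(BUDGETS)
```

For example, gate probabilities [0.75, 0.75, 0.25, 0.25] over four equal-cost sub-networks give an expected compute fraction of 0.5. That exceeds a 0.25 budget and adds 0.0625 to the task loss. A penalty that is zero below the budget leaves the model free to use less compute on easy inputs.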

Paper Structure

This paper contains 22 sections, 15 equations, 14 figures.

Figures (14)

  • Figure 1: Our approach for adapting models for conditional computation: During training, sub-network outputs are gated by noised continuous outputs from control networks trained end-to-end with the model. During inference, sub-networks are conditionally executed depending on discrete outputs from control networks. Outputs are optionally short-circuited with residual connections.
  • Figure 2: Conditional Computation Attention Layer.
  • Figure 3: Conditional Computation Feedforward Layer.
  • Figure 4: Comparing the performance of CCT (red) at different encoder-decoder computation budgets against Transformer baselines (blue). The x-axis shows the average encoder-decoder per-token FLOPs (in millions). The size of each Transformer baseline is given next to its data point in the format (hidden layer size, number of layers). Note: the computation required for embedding lookups and softmax operations is excluded from the comparison.
  • Figure 5: Comparing the performance of CCT at different decoder computation budgets against Transformer baselines, when CCT is allowed to use full encoder computation. The x-axis shows the decoder per-token FLOPs (in millions). Blue dots denote the quality of individual Transformer baselines. The decoder size is given next to each data point in the format (hidden layer size, number of layers). Note: the computation required for embedding lookups and softmax operations is excluded from the comparison.
  • ...and 9 more figures
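The gating scheme described in the Figure 1 caption can be sketched for a single sub-network. The sketch assumes that a control network (not modeled here) produces a scalar gate logit. It also assumes that training noise is Gaussian and added to that logit, and that inference uses a 0.5 threshold on the sigmoid. `gated_subnetwork`, `noise_std`, `threshold`, and `executed_flops` are hypothetical names. The helper `executed_flops` shows how skipping closed sub-networks lowers the per-token FLOPs plotted in Figures 4 and 5. As a simplification, it ignores the cost of the control networks themselves.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated_subnetwork(
    x: Vector,
    f: Callable[[Vector], Vector],
    gate_logit: float,
    training: bool,
    rng: random.Random,
    noise_std: float = 1.0,   # hypothetical noise scale
    threshold: float = 0.5,   # hypothetical inference threshold
    residual: bool = True,
) -> Tuple[Vector, bool]:
    """Apply sub-network f to x behind a gate; return (output, whether f was executed)."""
    if training:
        # Training: noised continuous gate in (0, 1). f always runs so that, under
        # autodiff, the control network would receive gradients through g.
        g = sigmoid(gate_logit + rng.gauss(0.0, noise_std))
        fx = f(x)
        if residual:
            return [xi + g * fi for xi, fi in zip(x, fx)], True
        return [g * fi for fi in fx], True
    # Inference: discrete gate. A closed gate skips f entirely, so no compute is spent on it.
    if sigmoid(gate_logit) < threshold:
        # The residual short-circuit passes x through; without it the branch contributes zeros.
        return (list(x) if residual else [0.0] * len(x)), False
    fx = f(x)
    if residual:
        return [xi + fi for xi, fi in zip(x, fx)], True
    return list(fx), True


def executed_flops(executed: List[bool], costs: List[float]) -> float:
    """Per-token FLOPs counting only sub-networks whose gates were open (controller cost ignored)."""
    return sum(c for ran, c in zip(executed, costs) if ran)
```

When the gate is closed at inference, the input passes through unchanged on the residual path and `f` is never called. This is where the compute savings come from. During training every sub-network still runs, so training cost is not reduced in this sketch, only inference cost.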