Align-to-Distill: Trainable Attention Alignment for Knowledge Distillation in Neural Machine Translation

Heegon Jin, Seonil Son, Jemin Park, Youngseok Kim, Hyungjong Noh, Yeonsoo Lee

TL;DR

This paper introduces the “Align-to-Distill” (A2D) strategy, which addresses the feature mapping problem by adaptively aligning student attention heads with their teacher counterparts during training.

Abstract

The advent of scalable deep models and large datasets has improved the performance of Neural Machine Translation. Knowledge Distillation (KD) enhances efficiency by transferring knowledge from a teacher model to a more compact student model. However, KD approaches for the Transformer architecture often rely on heuristics, particularly when deciding which teacher layers to distill from. In this paper, we introduce the 'Align-to-Distill' (A2D) strategy, designed to address the feature mapping problem by adaptively aligning student attention heads with their teacher counterparts during training. The Attention Alignment Module in A2D performs a dense head-by-head comparison between student and teacher attention heads across layers, turning the combinatorial mapping heuristics into a learning problem. Our experiments show the efficacy of A2D, demonstrating gains of up to +3.61 and +0.63 BLEU points for WMT-2022 De->Dsb and WMT-2014 En->De, respectively, compared to Transformer baselines.

Paper Structure (23 sections, 10 equations, 2 figures, 6 tables)

Figures (2)

  • Figure 1: Attention Transfer with A2D. The Attention Alignment Module (AAM), implemented as a pointwise convolution layer, produces intermediate attention maps from a collection of student attention maps. The number of intermediate maps matches the total attention maps of the teacher model, encompassing all layers and heads. These intermediate maps are then directly compared to the teacher's attention maps using KL-Divergence, without any form of reduction.
  • Figure 2: Attention head connection weights in the trained AAM. Axes indicate attention head numbers in the student (3 layers of 8 heads) and teacher (6 layers of 4 heads) models. The dashed grid shows layer boundaries. Darker colors signify stronger connections. Best viewed in color.
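
The attention transfer described in Figure 1 can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the pointwise convolution is written as a linear mix over the stacked student heads, the mixed maps are renormalized with a softmax over the key axis (the captions do not say how the intermediate maps are normalized), and the KL direction is taken as teacher-to-intermediate. The names `aam_forward` and `attention_kl`, the random inputs, and the sequence length are all hypothetical. The head counts follow Figure 2 (student: 3 layers of 8 heads; teacher: 6 layers of 4 heads).

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def aam_forward(student_maps, weight, bias):
    """Hypothetical Attention Alignment Module forward pass.

    student_maps: (S, Tq, Tk) all student heads from all layers, stacked.
    weight:       (N, S) pointwise (1x1) convolution kernel, one row per teacher head.
    bias:         (N,)
    Returns N intermediate maps of shape (N, Tq, Tk).
    """
    # A pointwise convolution over the head axis is a per-position linear mix of heads.
    mixed = np.einsum("ns,sqk->nqk", weight, student_maps) + bias[:, None, None]
    # Assumption: renormalize each query row so it is a distribution over keys.
    return softmax(mixed, axis=-1)

def attention_kl(teacher_maps, intermediate_maps, eps=1e-9):
    # KL(teacher || intermediate), summed over keys, averaged over heads and queries.
    # Assumption: the direction of the divergence is not given in the captions.
    kl = teacher_maps * (np.log(teacher_maps + eps) - np.log(intermediate_maps + eps))
    return kl.sum(axis=-1).mean()

rng = np.random.default_rng(0)
n_student = 3 * 8   # student: 3 layers of 8 heads (Figure 2)
n_teacher = 6 * 4   # teacher: 6 layers of 4 heads (Figure 2)
seq_len = 5         # illustrative sequence length

# Random stand-ins for real attention maps.
student_maps = softmax(rng.normal(size=(n_student, seq_len, seq_len)))
teacher_maps = softmax(rng.normal(size=(n_teacher, seq_len, seq_len)))

weight = rng.normal(scale=0.1, size=(n_teacher, n_student))
bias = np.zeros(n_teacher)

intermediate = aam_forward(student_maps, weight, bias)
loss = attention_kl(teacher_maps, intermediate)
```

In this sketch, `weight` plays the role of the head-to-head connection matrix that Figure 2 visualizes: entry `[n, s]` says how strongly student head `s` contributes to the map compared against teacher head `n`. In training, `weight` and `bias` would be learned jointly with the student rather than fixed at random values.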