Don't Pay Attention, PLANT It: Pretraining Attention via Learning-to-Rank

Debjyoti Saha Roy, Byron C. Wallace, Javed A. Aslam

TL;DR

PLANT tackles the long-tail challenge in Extreme Multi-Label Text Classification (XMTC) by pretraining per-label attention with a Learning-to-Rank objective guided by Mutual Information Gain (MIG), then using the pretrained attention weights as the starting point for full end-to-end training. The two-stage approach is architecture-agnostic and integrates with diverse LLM backbones, yielding substantial improvements on rare labels and improved sample efficiency. Ablation studies quantify the central role of Stage 1 initialization, MIG signals, and the ranking loss, while Stage 2 refinements provide complementary gains. Practically, PLANT enables smaller models to outperform larger baselines in many settings, making attention initialization a transferable, low-cost knob for XMTC that extends beyond ICD coding to legal and web-scale categorization.

Abstract

State-of-the-art Extreme Multi-Label Text Classification models rely on multi-label attention to focus on key tokens in input text, but learning good attention weights is challenging. We introduce PLANT (Pretrained and Leveraged Attention), a plug-and-play strategy for initializing attention. PLANT works by planting label-specific attention using a pretrained Learning-to-Rank model guided by mutual information gain. This architecture-agnostic approach integrates seamlessly with large language model backbones such as Mistral-7B, LLaMA3-8B, DeepSeek-V3, and Phi-3. PLANT outperforms state-of-the-art methods across tasks including ICD coding, legal topic classification, and content recommendation. Gains are especially pronounced in few-shot settings, with substantial improvements on rare labels. Ablation studies confirm that attention initialization is a key driver of these gains. For code and trained models, see https://github.com/debjyotiSRoy/xcube/tree/plant.

Paper Structure

This paper contains 35 sections, 11 equations, 8 figures, 18 tables.

Figures (8)

  • Figure 1: PLANT Attention. On the left, the $\mathsf{MultiHead\!-\!Attention}$ module (Vaswani et al., 2017), parameterized by $\boldsymbol{W}_{\mathsf{attn}}$, takes as input queries $\mathbf{Q}=\mathbf{E}$ (label embeddings), keys $\mathbf{K}=\mathbf{H}_i$, and values $\mathbf{V}=\mathbf{H}_i$, and produces $\mathbf{S}\in\mathbb{R}^{|\mathcal{L}|\times n}$, representing the token-level attention distribution for each label. The orange box highlights the set of top-$k$ tokens per label, $\mathcal{T}_l$, selected via Mutual Information Gain $r_{lj}$ between labels and tokens. Within this set, two tokens $j$ (red) and $h$ (blue) are compared, with $j$ being more relevant than $h$. The $\mathsf{MultiHead\!-\!Attention}$ module is trained to maximize the probability of correctly ranking tokens $j$ and $h$ ($P(j \succ h)$), while penalizing incorrect rankings in proportion to their impact on the $\mathsf{nDCG@k}$ metric if $j$ and $h$ were swapped ($|\Delta \mathsf{nDCG}@k|_{jh}$). Finally, the summation box aggregates over all token pairs in $\mathcal{T}_l$, yielding the PLANT objective (nDCG term $\times$ probability term), which is optimized to initialize $\boldsymbol{W}_{\mathsf{attn}}$.
  • Figure 2: (Left) Rare codes have near-zero macro-F1. (Right) Macro-F1 distribution on MIMIC-III-few for rare codes across Co-Relation (Luo et al., 2024) (mean=0.054), Mistral-7B (0.309), and Mistral-7B+PLANT (0.663). Mistral-7B+PLANT yields far more rare codes with higher F1. See Section \ref{subsection:fig2vsfig3} (RQ4).
  • Figure 3: (Left) Random initialization yields diffuse, inconsistent patterns for rare codes (broad orange peak near 0.75), whereas PLANT restores consistency (sharp orange peak at 0.985). (Right) Rare-F1 when training only on common labels ($>1\%$ frequency). PLANT retains strong zero-shot performance (7.3--8.1%); random attention initialization collapses (0.5--1.1%). See Section \ref{subsection:freq_binned_transfer} (RQ6).
  • Figure 4: (Left) On MIMIC-IV-full with LLaMA3-8B, better Stage 1 $\mathsf{nDCG}@k$ (attn. init. quality) leads to higher Stage 2 (downstream) macro-F1. The same trend for rare-F1 is shown in App. \ref{sec:appendix_results_detailed}, Fig. \ref{fig:stage1_quality}. Extended results (MIMIC-III-full+Mistral-7B, MIMIC-IV-full+LLaMA3-8B; macro-F1, $\mathsf{P@15}$) appear in App. \ref{sec:appendix_results_detailed}, Fig. \ref{fig:effect_plant_initialization}. (Right) PLANT consistently boosts Mistral-7B on MIMIC-IV-full across training sizes: solid lines (Mistral-7B+PLANT) beat dashed baselines on $\mathsf{P@5}$/$\mathsf{P@15}$, with the largest gains in low-data regimes. Paired MIMIC-III-full+MIMIC-IV-full results are in App. \ref{sec:appendix_results_detailed}, Fig. \ref{fig:plant_vary_trn_size}.
  • Figure 5: PLANT's Stage 1 attention initialization is critical for downstream performance. Insets show performance degradation when Stage 1 is absent (weights $\mathbf{W}_{\mathsf{attn}}$ initialized randomly). The left panel is shown in the main paper as Figure \ref{fig:effect_plant_initialization_plus_plant_vary_trn_size} (left).
  • ...and 3 more figures
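
The Figure 1 caption describes the Stage 1 objective as a sum over token pairs of an nDCG term times a probability term. The snippet below is a minimal sketch of a loss with that shape for a single label, not the authors' implementation. It assumes a RankNet-style pairwise probability $P(j \succ h) = \sigma(s_j - s_h)$, a linear gain equal to the relevance $r_{lj}$, and a $1/\log_2(\text{rank}+1)$ discount truncated at $k$; the function name `pairwise_ndcg_weighted_loss` and its arguments are hypothetical.

```python
import numpy as np


def pairwise_ndcg_weighted_loss(scores, relevance, k):
    """LambdaRank-style loss for one label (illustrative sketch).

    scores:    attention scores s_j for the candidate tokens, shape [m]
    relevance: non-negative relevance r_lj (e.g. MIG) per token, shape [m]
    k:         nDCG truncation level
    """
    scores = np.asarray(scores, dtype=float)
    relevance = np.asarray(relevance, dtype=float)
    m = scores.shape[0]

    # 0-based rank of each token under the current scores (descending).
    order = np.argsort(-scores, kind="stable")
    rank = np.empty(m, dtype=int)
    rank[order] = np.arange(m)

    # DCG discount per token; positions at or beyond k contribute nothing.
    discount = np.where(rank < k, 1.0 / np.log2(rank + 2.0), 0.0)

    # Ideal DCG@k normalizes the swap deltas (assumption: linear gain).
    ideal = np.sort(relevance)[::-1][:k]
    idcg = float(np.sum(ideal / np.log2(np.arange(ideal.size) + 2.0)))
    if idcg == 0.0:
        return 0.0

    loss = 0.0
    for j in range(m):
        for h in range(m):
            if relevance[j] > relevance[h]:  # j should rank above h
                # |change in nDCG@k| if tokens j and h swapped positions.
                delta = abs((relevance[j] - relevance[h])
                            * (discount[j] - discount[h])) / idcg
                # -log sigmoid(s_j - s_h), computed stably.
                neg_log_p = np.logaddexp(0.0, -(scores[j] - scores[h]))
                loss += delta * neg_log_p
    return float(loss)


relevance = [3.0, 2.0, 1.0]
loss_good = pairwise_ndcg_weighted_loss([3.0, 2.0, 1.0], relevance, k=3)
loss_bad = pairwise_ndcg_weighted_loss([1.0, 2.0, 3.0], relevance, k=3)
```

Under these assumptions, scores that agree with the relevance ordering (`loss_good`) incur a smaller loss than scores that reverse it (`loss_bad`), and pairs whose swap would barely change nDCG@k contribute little. A trainable version would compute the same quantity on differentiable tensors so that gradients reach $\boldsymbol{W}_{\mathsf{attn}}$.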