Don't Pay Attention, PLANT It: Pretraining Attention via Learning-to-Rank
Debjyoti Saha Roy, Byron C. Wallace, Javed A. Aslam
TL;DR
PLANT tackles the long-tail challenge of extreme multi-label text classification (XMTC). In Stage 1, it pretrains per-label attention with a Learning-to-Rank objective guided by Mutual Information Gain (MIG). In Stage 2, it uses this pretrained attention as the seed for full end-to-end training. The two-stage approach is architecture-agnostic and integrates with diverse LLM backbones. It yields substantial improvements on rare labels and better sample efficiency. Ablation studies quantify the central role of Stage 1 initialization, MIG signals, and the ranking loss, while Stage 2 refinements provide complementary gains. In practice, PLANT lets smaller models outperform larger baselines in many settings. This makes attention initialization a transferable, low-cost knob for XMTC in domains beyond ICD coding, such as legal and web-scale categorization.
Abstract
State-of-the-art Extreme Multi-Label Text Classification (XMTC) models rely on multi-label attention to focus on key tokens in the input text, but learning good attention weights is challenging. We introduce PLANT (Pretrained and Leveraged Attention), a plug-and-play strategy for initializing attention. PLANT works by planting label-specific attention using a pretrained Learning-to-Rank model guided by mutual information gain. This architecture-agnostic approach integrates seamlessly with large language model backbones such as Mistral-7B, LLaMA3-8B, DeepSeek-V3, and Phi-3. PLANT outperforms state-of-the-art methods across tasks including ICD coding, legal topic classification, and content recommendation. Gains are especially pronounced in few-shot settings, with substantial improvements on rare labels. Ablation studies confirm that attention initialization is a key driver of these gains. For code and trained models, see https://github.com/debjyotiSRoy/xcube/tree/plant
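The idea of seeding label-specific attention with a ranking model guided by mutual information gain can be illustrated with a minimal Python sketch. It rests on clearly stated assumptions. Mutual information gain is modeled here as the standard mutual information between two binary events: "token occurs in the document" and "label is assigned to the document." The Learning-to-Rank objective is modeled as a generic pairwise hinge loss over per-label attention scores. Neither is claimed to be PLANT's exact formulation. The function names `mutual_information` and `pairwise_hinge_rank_loss`, the toy corpus, and the attention scores are all hypothetical. The paper and repository define the actual objective.

```python
import math
from itertools import combinations


def mutual_information(docs_tokens, docs_labels, token, label):
    """Mutual information (in nats) between two binary events per document:
    'token occurs in the document' and 'label is assigned to the document'.
    Assumption: a stand-in for PLANT's MIG signal, not its exact definition."""
    n = len(docs_tokens)
    # Count documents in each of the four (token present?, label present?) cells.
    counts = {(a, b): 0 for a in (0, 1) for b in (0, 1)}
    for toks, labs in zip(docs_tokens, docs_labels):
        counts[(int(token in toks), int(label in labs))] += 1
    mi = 0.0
    for (a, b), c in counts.items():
        if c == 0:
            continue  # by convention, 0 * log 0 contributes nothing
        p_ab = c / n
        p_a = (counts[(a, 0)] + counts[(a, 1)]) / n  # marginal of the token event
        p_b = (counts[(0, b)] + counts[(1, b)]) / n  # marginal of the label event
        mi += p_ab * math.log(p_ab / (p_a * p_b))
    return mi


def pairwise_hinge_rank_loss(scores, relevance, margin=1.0):
    """Mean hinge loss over token pairs with different relevance: a pair is
    penalized unless the more relevant token outscores the other by `margin`.
    Assumption: a generic pairwise Learning-to-Rank loss, hypothetical here."""
    losses = []
    for i, j in combinations(range(len(scores)), 2):
        if relevance[i] == relevance[j]:
            continue  # ties carry no ordering signal
        hi, lo = (i, j) if relevance[i] > relevance[j] else (j, i)
        losses.append(max(0.0, margin - (scores[hi] - scores[lo])))
    return sum(losses) / len(losses) if losses else 0.0


# Toy corpus (hypothetical): the token set and label set of each document.
docs_tokens = [{"fever", "cough"}, {"fever", "rash"}, {"cough"}, {"rash"}]
docs_labels = [{"flu"}, {"flu"}, set(), set()]
vocab = ["fever", "cough", "rash"]

# MI-based relevance of each vocabulary token for the label "flu".
relevance = [mutual_information(docs_tokens, docs_labels, t, "flu") for t in vocab]

# Hypothetical per-label attention scores, one per vocabulary token.
good_scores = [2.0, 0.5, 0.1]  # "fever" clearly ranked first: loss 0.0
bad_scores = [0.0, 0.5, 0.1]   # "fever" ranked last: positive loss
print(relevance)
print(pairwise_hinge_rank_loss(good_scores, relevance))
print(pairwise_hinge_rank_loss(bad_scores, relevance))
```

In this sketch, the mutual information values serve only as a teacher signal that orders tokens for each label. The loss asks the attention scores to reproduce that ordering, not to match the mutual information values in scale. That ordering requirement is why a ranking objective is a natural fit for initializing attention in this illustration.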
