
CLAX: Fast and Flexible Neural Click Models in JAX

Philipp Hager, Onno Zoeter, Maarten de Rijke

TL;DR

CLAX addresses the scalability gap in training classic probabilistic graphical model (PGM) click models by replacing EM with direct gradient-based optimization of the marginal log-likelihood in a JAX-based, modular framework. It demonstrates substantial speedups and scalable training on datasets with over a billion sessions, while supporting embeddings, neural modules, and mixture models that can surpass traditional two-tower baselines in ranking performance. The work provides a comprehensive API, numerical-stability techniques, and embedding-compression strategies to enable end-to-end optimization at scale. Overall, CLAX is a practical, extensible tool for practitioners and researchers who want to develop and deploy advanced, interpretable click models in large-scale information retrieval settings.
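To make the core idea concrete, here is a minimal JAX sketch, not CLAX's API, of fitting a position-based model (PBM) by gradient descent on the marginal log-likelihood instead of EM. It assumes P(click) = θ_k · γ_d, with an examination probability θ_k per rank and an attractiveness γ_d per document, both parameterized as sigmoid-transformed logits. The log1mexp helper is one standard way to compute log(1 - θγ) stably in log-space; it is our assumption, not necessarily the numerical-stability technique CLAX uses. All function names and the synthetic data are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1.0  # hypothetical step size for plain gradient descent


def log1mexp(x):
    """Numerically stable log(1 - exp(x)) for x < 0.

    Uses log(-expm1(x)) near zero and log1p(-exp(x)) otherwise. Each branch
    sees a clamped input so the unused branch never yields inf/NaN gradients.
    """
    cutoff = -jnp.log(2.0)
    near_zero = x > cutoff
    x_near = jnp.where(near_zero, x, cutoff)
    x_far = jnp.where(near_zero, cutoff, x)
    return jnp.where(near_zero, jnp.log(-jnp.expm1(x_near)), jnp.log1p(-jnp.exp(x_far)))


def init_params(n_positions, n_docs):
    # Unconstrained logits; sigmoid maps them to probabilities.
    return {"exam": jnp.zeros(n_positions), "attr": jnp.zeros(n_docs)}


def neg_log_likelihood(params, doc_ids, clicks):
    """Mean negative marginal log-likelihood of a PBM.

    doc_ids: int [n_sessions, n_positions]; clicks: float {0, 1}, same shape.
    """
    log_exam = jax.nn.log_sigmoid(params["exam"])[None, :]  # log theta_k
    log_attr = jax.nn.log_sigmoid(params["attr"][doc_ids])  # log gamma_d
    log_click = log_exam + log_attr  # log P(C=1) = log(theta_k * gamma_d)
    log_no_click = log1mexp(log_click)  # log P(C=0) = log(1 - theta_k * gamma_d)
    ll = clicks * log_click + (1.0 - clicks) * log_no_click
    return -ll.mean()


@jax.jit
def sgd_step(params, doc_ids, clicks):
    # One gradient-descent step; an optimizer such as Adam would also work.
    loss, grads = jax.value_and_grad(neg_log_likelihood)(params, doc_ids, clicks)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Synthetic sessions (hypothetical): examination decays with rank.
n_sessions, n_positions, n_docs = 2000, 5, 20
k_docs, k_attr, k_click = jax.random.split(jax.random.PRNGKey(0), 3)
doc_ids = jax.random.randint(k_docs, (n_sessions, n_positions), 0, n_docs)
true_exam = 1.0 / (1.0 + jnp.arange(n_positions))
true_attr = jax.random.uniform(k_attr, (n_docs,))
click_prob = true_exam[None, :] * true_attr[doc_ids]
clicks = jax.random.bernoulli(k_click, click_prob).astype(jnp.float32)

params = init_params(n_positions, n_docs)
losses = []
for _ in range(200):
    params, loss = sgd_step(params, doc_ids, clicks)
    losses.append(float(loss))
```

Because the loss is an ordinary differentiable function, the same step works unchanged on mini-batches, which is what lets gradient-based training stream over very large click logs instead of running full EM passes.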

Abstract

CLAX is a JAX-based library that implements classic click models using modern gradient-based optimization. While neural click models have emerged over the past decade, complex click models based on probabilistic graphical models (PGMs) have not systematically adopted gradient-based optimization, preventing practitioners from leveraging modern deep learning frameworks while preserving the interpretability of classic models. CLAX addresses this gap by replacing EM-based optimization with direct gradient-based optimization in a numerically stable manner. The framework's modular design enables the integration of any component, from embeddings and deep networks to custom modules, into classic click models for end-to-end optimization. We demonstrate CLAX's efficiency by running experiments on the full Baidu-ULTR dataset comprising over a billion user sessions in $\approx$ 2 hours on a single GPU, orders of magnitude faster than traditional EM approaches. CLAX implements ten classic click models, serving both industry practitioners seeking to understand user behavior and improve ranking performance at scale and researchers developing new click models. CLAX is available at: https://github.com/philipphager/clax
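The claim that arbitrary components can be plugged into a classic model can be illustrated with the cascade model. In the hypothetical sketch below, attractiveness comes from a linear layer over per-document features rather than a per-document table, and the whole model is differentiated end to end with jax.value_and_grad. The feature shape, the linear module, and the convention of ignoring positions after the first click are our assumptions, not CLAX's implementation. Note that log(1 - sigmoid(x)) equals log_sigmoid(-x), which avoids computing 1 - p explicitly.

```python
import jax
import jax.numpy as jnp


def cascade_log_likelihood(log_attr, log_no_attr, clicks):
    """Log-likelihood of one ranked list under the cascade model.

    The user scans top-down and stops at the first click. Positions before the
    first click contribute log(1 - alpha), the first click contributes
    log(alpha), and later positions are ignored (an assumption for multi-click
    sessions). All inputs have shape [n_positions].
    """
    n = clicks.shape[0]
    positions = jnp.arange(n)
    first = jnp.where(clicks.sum() > 0, jnp.argmax(clicks), n)  # n if no click
    before = jnp.where(positions < first, log_no_attr, 0.0).sum()
    at_click = jnp.where(positions == first, log_attr, 0.0).sum()
    return before + at_click


def attraction_logits(params, features):
    # Hypothetical pluggable module: a linear layer over per-document features.
    # Any differentiable function (e.g., an MLP) could replace it.
    return features @ params["w"] + params["b"]


def batch_nll(params, features, clicks):
    """features: [batch, n_positions, n_features]; clicks: [batch, n_positions]."""
    logits = attraction_logits(params, features)
    log_attr = jax.nn.log_sigmoid(logits)  # log alpha
    log_no_attr = jax.nn.log_sigmoid(-logits)  # log(1 - alpha), computed stably
    ll = jax.vmap(cascade_log_likelihood)(log_attr, log_no_attr, clicks)
    return -ll.mean()


features = jax.random.normal(jax.random.PRNGKey(0), (4, 5, 3))
clicks = jnp.array(
    [[0, 0, 1, 0, 1],  # first click at rank 3; the later click is ignored
     [0, 0, 0, 0, 0],  # no click: all five positions were skipped
     [1, 0, 0, 0, 0],
     [0, 1, 0, 0, 0]],
    dtype=jnp.float32,
)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
loss, grads = jax.value_and_grad(batch_nll)(params, features, clicks)
```

Swapping attraction_logits for a deeper network (for example, a model over BERT features as in Figure 4) would leave the likelihood code unchanged, which is the practical benefit of keeping the click model and its parameterization separate.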

Paper Structure

This paper contains 29 sections, 31 equations, and 4 figures.

Figures (4)

  • Figure 1: CLAX matches or exceeds the click-prediction performance of PyClick over three folds of 10M training sessions on WSCD-2012.
  • Figure 2: Kendall's $\tau$ between ranking models trained with and without embedding compression on WSCD-2012.
  • Figure 3: Embedding-based CLAX models on the Baidu-ULTR dataset (three folds of 800M / 200M / 200M sessions for training, validation, and testing) [Zou2022Baidu]. All models complete training in under 2 hours using the hashing trick with 10× compression.
  • Figure 4: On the Baidu-ULTR-UVA dataset [Hager2024ULTR], CLAX models that generalize over BERT features using a deep-cross network achieve strong ranking performance and a different model fit than embedding-based models.
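Figure 3 relies on the hashing trick to compress embeddings. A minimal sketch of the idea, assuming integer document IDs: each ID is hashed into one of n_buckets rows, so a table ten times smaller than the ID space serves every document at the cost of occasional collisions. The multiplicative hash followed by a modulo is our illustrative choice; the hash function CLAX actually uses is not stated here, and hash_ids and lookup are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Knuth's multiplicative-hashing constant, a prime close to 2^32 / golden ratio.
MULTIPLIER = jnp.asarray(2654435761, dtype=jnp.uint32)


def hash_ids(ids, n_buckets):
    """Map non-negative integer IDs to bucket indices in [0, n_buckets)."""
    h = ids.astype(jnp.uint32) * MULTIPLIER  # uint32 multiplication wraps modulo 2^32
    return (h % jnp.asarray(n_buckets, dtype=jnp.uint32)).astype(jnp.int32)


def lookup(table, ids):
    """Embedding lookup in a compressed table; colliding IDs share a row."""
    return table[hash_ids(ids, table.shape[0])]


n_docs, compression, dim = 1_000_000, 10, 8
n_buckets = n_docs // compression  # 10x fewer rows than documents
table = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (n_buckets, dim))
ids = jnp.array([0, 1, 42, 999_999])
embeddings = lookup(table, ids)  # one row per ID, shape (4, 8)
```

Colliding IDs share parameters, so compression trades some model capacity for a tenfold reduction in embedding memory; Figure 2's Kendall's $\tau$ comparison is how the paper reports the effect of compression on the resulting rankings.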