KerasCV and KerasNLP: Vision and Language Power-Ups
Matthew Watson, Divyashree Shivakumar Sreepathihalli, Francois Chollet, Martin Gorner, Kiranbir Sodhia, Ramesh Sampath, Tirth Patel, Haifeng Jin, Neel Kovelamudi, Gabriel Rasskin, Samaneh Saadat, Luke Wood, Chen Qian, Jonathan Bischof, Ian Stenbit, Abheesht Sharma, Anshuman Mishra
TL;DR
KerasCV and KerasNLP address fragmentation in CV/NLP tooling by providing multi-backend support (JAX, TensorFlow, PyTorch) and a layered API built on Keras 3 that unifies preprocessing, backbones, and task-specific models. The paper introduces Foundational Components, Pretrained Backbones, and Task Models, plus a Presets API and Kaggle Models integration to accelerate experimentation and deployment. A performance framework demonstrates speedups from backend selection and XLA compilation, while revealing current gaps (e.g., LLM padding on the PyTorch backend) and confirming scalable, cross-backend workflows. The work emphasizes open-source development with easy deployment paths (TensorFlow Serving, TFLite, etc.) and aims to enable rapid, scalable CV/NLP research and production across diverse hardware and environments.
Abstract
We present the Keras domain packages KerasCV and KerasNLP, extensions of the Keras API for Computer Vision and Natural Language Processing workflows, capable of running on JAX, TensorFlow, or PyTorch. These domain packages are designed to enable fast experimentation, with a focus on ease-of-use and performance. We adopt a modular, layered design: at the library's lowest level of abstraction, we provide building blocks for creating models and data preprocessing pipelines, and at the library's highest level of abstraction, we provide pretrained "task" models for popular architectures such as Stable Diffusion, YOLOv8, GPT2, BERT, Mistral, CLIP, Gemma, T5, etc. Task models come with built-in preprocessing and pretrained weights, and can be fine-tuned on raw inputs. To enable efficient training, we support XLA compilation for all models, and run all preprocessing via a compiled graph of TensorFlow operations using the tf.data API. The libraries are fully open-source (Apache 2.0 license) and available on GitHub.
