Discriminative Class Tokens for Text-to-Image Diffusion Models

Idan Schwartz, Vésteinn Snæbjarnarson, Hila Chefer, Ryan Cotterell, Serge Belongie, Lior Wolf, Sagie Benaim

TL;DR

This paper tackles lexical ambiguity and fine-grained detail in text-to-image diffusion by learning discriminative class tokens. It introduces a token-based fine-tuning approach that optimizes only the embedding of a new class token $S_c$ using a pretrained classifier, without additional in-domain images or full-model retraining, and employs gradient skipping to reduce resources. The method yields higher classification accuracy and better FID scores than baselines, while enabling data augmentation in low-resource settings and revealing insights into training data through classifier inversion. Overall, the technique offers a fast, flexible, and privacy-conscious way to steer diffusion models toward precise class representations and detailed imagery while preserving the underlying model’s diversity and capabilities.
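
One way to write the optimization summarized above, using assumed notation that may differ from the paper's own equations: let $y(v_c)$ be the prompt embedding with the new token's vector $v_c$ inserted, $G_\theta(z, y)$ the frozen diffusion sampler that maps noise $z$ and a prompt embedding $y$ to an image, and $f_\phi$ the frozen classifier's logits. Training the token then amounts to minimizing the classifier's cross-entropy for the target class $c$:

$$
\mathcal{L}(v_c) = \mathbb{E}_{z \sim \mathcal{N}(0, I)}\Big[ -\log \operatorname{softmax}\big(f_\phi(G_\theta(z, y(v_c)))\big)_c \Big],
\qquad
v_c \leftarrow v_c - \eta \, \nabla_{v_c} \mathcal{L}(v_c).
$$

Here $\theta$ (diffusion model) and $\phi$ (classifier) stay fixed, so the only trainable parameters are the entries of $v_c$; the plain gradient step with learning rate $\eta$ is shown for simplicity, and the optimizer actually used may differ.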

Abstract

Recent advances in text-to-image diffusion models have enabled the generation of diverse and high-quality images. While impressive, the images often fall short of depicting subtle details and are susceptible to errors due to ambiguity in the input text. One way of alleviating these issues is to train diffusion models on class-labeled datasets. This approach has two disadvantages: (i) supervised datasets are generally small compared to the large-scale scraped text-image datasets on which text-to-image models are trained, which affects the quality and diversity of the generated images, and (ii) the input is a hard-coded label, as opposed to free-form text, which limits control over the generated images. In this work, we propose a non-invasive fine-tuning technique that capitalizes on the expressive potential of free-form text while achieving high accuracy through discriminative signals from a pretrained classifier. This is done by iteratively modifying the embedding of an added input token of a text-to-image diffusion model so that the generated images are steered toward a given target class according to the classifier. Our method is fast compared to prior fine-tuning methods and requires neither a collection of in-class images nor retraining of a noise-tolerant classifier. We evaluate our method extensively, showing that the generated images (i) are more accurate and of higher quality than those of standard diffusion models, (ii) can be used to augment training data in a low-resource setting, and (iii) reveal information about the data used to train the guiding classifier. The code is available at https://github.com/idansc/discriminative_class_tokens.
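
As a rough illustration of the loop described in the abstract (not the authors' code), the toy sketch below replaces the diffusion model and the classifier with tiny frozen linear maps. All names (`A`, `B`, `C`, `forward`, `optimize_token`) and the three class labels are hypothetical; the point is only that gradient descent on the classifier's cross-entropy updates the new token embedding `v` while every other weight stays fixed.

```python
import math

# Hypothetical toy stand-ins, not the paper's models: a frozen linear
# "generator" maps the new token embedding v to image features, and a frozen
# linear classifier scores those features over three classes.
CLASSES = ["tiger", "tiger cat", "tabby"]
A = [[1.0, 0.5], [0.0, 1.0]]                 # frozen generator weights
B = [0.2, -0.1]                              # frozen contribution of the rest of the prompt
C = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]   # frozen classifier weights


def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(z):
    zmax = max(z)                             # subtract max for numerical stability
    e = [math.exp(zi - zmax) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]


def forward(v):
    """'Generate' toy image features from v and return class probabilities."""
    x = [xi + bi for xi, bi in zip(matvec(A, v), B)]
    return softmax(matvec(C, x))


def optimize_token(target, steps=200, lr=0.3):
    """Gradient descent on the token embedding v only; A, B and C are never updated."""
    v = [0.0, 0.0]
    losses = []
    for _ in range(steps):
        p = forward(v)
        losses.append(-math.log(p[target]))   # cross-entropy for the target class
        # Gradient of the cross-entropy w.r.t. the logits: p - one_hot(target).
        g_logits = [pi - (1.0 if i == target else 0.0) for i, pi in enumerate(p)]
        g_x = matvec(transpose(C), g_logits)  # backpropagate through the classifier
        g_v = matvec(transpose(A), g_x)       # backpropagate through the generator
        v = [vi - lr * gi for vi, gi in zip(v, g_v)]
    return v, losses


if __name__ == "__main__":
    target = CLASSES.index("tiger cat")
    p0 = forward([0.0, 0.0])
    v, losses = optimize_token(target)
    p = forward(v)
    print(CLASSES[p0.index(max(p0))], "->", CLASSES[p.index(max(p))])
    # prints: tiger -> tiger cat
```

Starting from `v = 0`, the frozen toy classifier labels the output "tiger", echoing the Figure 2 example; after optimizing only `v` it labels it "tiger cat". In the real method each forward pass is a full diffusion sampling run, which is what makes the cost of backpropagation a concern.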

Paper Structure

This paper contains 17 sections, 4 equations, 8 figures, and 3 tables.

Figures (8)

  • Figure 1: We propose a technique that introduces a token ($S_c$) corresponding to a class $c$ of an external classifier. This improves text-to-image alignment when the prompt is lexically ambiguous and enhances the depiction of intricate details.
  • Figure 2: An overview of our method for optimizing a new discriminative token representation ($v_c$) using a pretrained classifier. For the prompt 'A photo of a $S_c$ tiger cat,' we expect the image generated with the token for class $c$ to be classified as 'tiger cat'. The classifier, however, indicates that the class of the generated image is 'tiger'. We generate images iteratively and optimize the token representation with a cross-entropy loss. Once $v_c$ has been trained, more images of the target class can be generated by including the token in the input text.
  • Figure 3: An illustration of the gradient skipping technique (indicated by the red line). During backpropagation, the gradient is propagated only through the final denoising step of the diffusion procedure.
  • Figure 4: Images generated based on ImageNet classes, using SD or our method. Real images are shown for comparison.
  • Figure 5: A selection of images based on iNat classes generated with Stable Diffusion (SD) and our method. A real image is shown for comparison. (a) Yellow pine chipmunk, (b) Jelly antler, (c) Salamander, (d) Pacific lion's mane jelly, (e) Leaf mite, (f) Red sea urchin, (g) Seashore mallow, (h) Sheepshead minnow.
  • ...and 3 more figures
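
The gradient skipping idea from Figure 3 can be illustrated on a scalar toy chain. In this hypothetical sketch, `step` stands in for one denoising step conditioned on the token embedding `v`, and the loss treats the final sample as the target-class logit. The exact gradient accumulates contributions from every step; the skipped gradient runs all but the last step without tracking derivatives (as a `torch.no_grad()` block would in PyTorch), so only the final step contributes. Forward-mode derivatives are used here only to make the chain rule explicit; with reverse-mode autodiff the practical saving is that activations of the earlier steps need not be stored.

```python
import math

ALPHA = 0.2  # hypothetical step size of the toy denoiser


def step(x, v):
    """One toy 'denoising' step that pulls x toward the token embedding v (hypothetical)."""
    return x + ALPHA * math.tanh(v - x)


def step_partials(x, v):
    """Partial derivatives (d step / dx, d step / dv) of the toy step."""
    s = 1.0 - math.tanh(v - x) ** 2
    return 1.0 - ALPHA * s, ALPHA * s


def loss_and_dloss(x0):
    """Treat x0 as the target-class logit: loss = -log sigmoid(x0), plus d loss / d x0."""
    sig = 1.0 / (1.0 + math.exp(-x0))
    return -math.log(sig), sig - 1.0


def grad_v(x_T, v, T, skip):
    """Return (loss, d loss / d v) after T steps starting from noise x_T.

    skip=False: exact chain rule through every step.
    skip=True: steps T..2 run without derivative tracking (like torch.no_grad()),
    so only the final step x_1 -> x_0 contributes to the gradient.
    """
    x, dx_dv = x_T, 0.0
    for t in range(T, 0, -1):
        if skip and t > 1:
            x = step(x, v)                   # value only, derivative stays 0
            continue
        gx, gv = step_partials(x, v)         # partials at the current x
        x, dx_dv = step(x, v), gx * dx_dv + gv
    loss, dl_dx0 = loss_and_dloss(x)
    return loss, dl_dx0 * dx_dv


if __name__ == "__main__":
    loss_full, g_full = grad_v(-1.0, 0.5, T=10, skip=False)
    loss_skip, g_skip = grad_v(-1.0, 0.5, T=10, skip=True)
    print(f"loss={loss_full:.4f}  full grad={g_full:.4f}  skipped grad={g_skip:.4f}")
```

Both variants compute the same forward pass and loss. In this toy the two gradients share a sign and the skipped one is smaller in magnitude, because it drops the positive contributions of earlier steps; the toy says nothing about how well the approximation works for a real diffusion model.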