Data-Driven Loss Functions for Inference-Time Optimization in Text-to-Image

Sapir Esther Yiflach, Yuval Atzmon, Gal Chechik

TL;DR

This work addresses the challenge of accurately enforcing spatial relations in text-to-image diffusion. It introduces Learn-to-Steer, a data-driven framework that learns a relation classifier from cross-attention maps and uses it as a learned loss during test-time optimization, avoiding model fine-tuning. A key contribution is the dual-inversion augmentation that mitigates relation leakage, enabling reliable learning of spatial patterns directly from internal representations. The approach yields substantial gains in spatial accuracy across multiple diffusion models and supports multiple relations, out-of-distribution generalization, and zero-shot personalization, all while preserving the base model's broader capabilities. This data-driven, inference-time steering paradigm offers a scalable path to reliable compositional scene generation without retraining the model.
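
The augmentation idea can be illustrated with a minimal sketch of how training records might be assembled. It assumes a four-relation vocabulary and that the "incorrect" prompt uses the opposite relation word. Both are illustrative choices, and `Sample`, `OPPOSITE`, and `make_augmented_samples` are hypothetical names, not the authors' code. The key property is that every record is labeled with the relation actually depicted in the image, so the relation word in the prompt cannot predict the label. Extracting attention maps by inverting each image with each prompt is omitted.

```python
from dataclasses import dataclass

# Assumption: the "incorrect" prompt swaps in the opposite relation word.
OPPOSITE = {
    "to the left of": "to the right of",
    "to the right of": "to the left of",
    "above": "below",
    "below": "above",
}


@dataclass(frozen=True)
class Sample:
    prompt: str  # prompt used when inverting/denoising the image
    label: str   # relation actually depicted in the image


def make_augmented_samples(subject, obj, true_relation):
    """Return one correct-prompt and one incorrect-prompt sample for an image.

    Both samples share the image's true relation as their label, so the
    relation word in the prompt carries no information about the label.
    """
    if true_relation not in OPPOSITE:
        raise ValueError(f"unknown relation: {true_relation!r}")
    wrong = OPPOSITE[true_relation]
    return [
        Sample(f"a {subject} {true_relation} a {obj}", true_relation),
        Sample(f"a {subject} {wrong} a {obj}", true_relation),
    ]


samples = make_augmented_samples("dog", "teddy bear", "to the right of")
for s in samples:
    print(s.prompt, "->", s.label)
```

Because the same image appears once with a matching prompt and once with a contradicting one, a classifier that only read the relation word off the attention maps would be wrong on one of the two records in this setup, which pushes it toward spatial cues instead.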

Abstract

Text-to-image diffusion models can generate stunning visuals, yet they often fail at tasks that children find trivial, such as placing a dog to the right of a teddy bear rather than to the left. When combinations become more unusual, such as a giraffe above an airplane, these failures become even more pronounced. Existing methods attempt to fix these spatial reasoning failures through model fine-tuning or through test-time optimization with handcrafted losses that are suboptimal. Rather than imposing our assumptions about how spatial relations are encoded, we propose learning these objectives directly from the model's internal representations. We introduce Learn-to-Steer, a novel framework that learns data-driven objectives for test-time optimization rather than handcrafting them. Our key insight is to train a lightweight classifier that decodes spatial relationships from the diffusion model's cross-attention maps, then deploy this classifier as a learned loss function during inference. Training such classifiers poses a surprising challenge: they can take shortcuts by detecting linguistic traces in the cross-attention maps rather than learning true spatial patterns. We solve this by augmenting our training data with samples generated using prompts with incorrect relation words, which encourages the classifier to avoid linguistic shortcuts and learn spatial patterns from the attention maps. Our method dramatically improves spatial accuracy across standard benchmarks: from 20% to 61% on FLUX.1-dev and from 7% to 54% on SD2.1. It also generalizes to multiple relations with significantly improved accuracy.
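
To make "a classifier deployed as a learned loss" concrete, here is a dependency-free sketch under strong simplifying assumptions: attention maps are synthetic 3x3 grids with a single active cell, the classifier is a plain softmax regression over the concatenated subject and object maps, and it is trained only on layouts where the two objects share a row or a column. None of these choices come from the paper, and all names (`one_hot_map`, `train`, `learned_loss`, and so on) are hypothetical. Once trained, the classifier's negative log-probability of the target relation is the quantity that test-time optimization would minimize.

```python
import math

RELATIONS = ["left", "right", "above", "below"]
N = 3  # hypothetical attention-map resolution (N x N)


def one_hot_map(r, c):
    """Flattened N x N attention map with all of its mass on cell (r, c)."""
    m = [0.0] * (N * N)
    m[r * N + c] = 1.0
    return m


def make_dataset():
    """Synthetic (features, label) pairs; features = subject map + object map."""
    cells = [(r, c) for r in range(N) for c in range(N)]
    data = []
    for rs, cs in cells:
        for ro, co in cells:
            if rs == ro and cs != co:
                label = "left" if cs < co else "right"
            elif cs == co and rs != ro:
                label = "above" if rs < ro else "below"  # row 0 is the top
            else:
                continue  # skip identical cells and diagonal layouts
            data.append((one_hot_map(rs, cs) + one_hot_map(ro, co),
                         RELATIONS.index(label)))
    return data


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def logits(W, b, x):
    """Linear classifier scores, one per relation."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def train(data, epochs=300, lr=0.5):
    """Full-batch gradient descent on the mean softmax cross-entropy."""
    k, d = len(RELATIONS), len(data[0][0])
    W = [[0.0] * d for _ in range(k)]
    b = [0.0] * k
    for _ in range(epochs):
        gW = [[0.0] * d for _ in range(k)]
        gb = [0.0] * k
        for x, y in data:
            p = softmax(logits(W, b, x))
            for j in range(k):
                err = p[j] - (1.0 if j == y else 0.0)  # dLoss/dlogit_j
                gb[j] += err
                for i, xi in enumerate(x):
                    gW[j][i] += err * xi
        for j in range(k):
            b[j] -= lr * gb[j] / len(data)
            for i in range(d):
                W[j][i] -= lr * gW[j][i] / len(data)
    return W, b


def learned_loss(W, b, subj_map, obj_map, target):
    """The trained classifier as a loss: -log p(target | attention maps)."""
    p = softmax(logits(W, b, subj_map + obj_map))
    return -math.log(p[RELATIONS.index(target)])


W, b = train(make_dataset())
subj, obj = one_hot_map(1, 0), one_hot_map(1, 2)  # subject left of object
print(learned_loss(W, b, subj, obj, "left"), learned_loss(W, b, subj, obj, "right"))
```

The loss is low when the maps already depict the requested relation and high otherwise, so its gradient with respect to whatever produced the maps points toward the target layout.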

Paper Structure

This paper contains 31 sections, 24 figures, and 10 tables.

Figures (24)

  • Figure 1: Learn-to-Steer learns how spatial relationships are encoded in attention maps to guide generation. (a) Correctly renders all four orientations (above/below/left/right). (b, c) Handles complex scenes with multiple spatial relationships, with up to five objects and three relations. Prompts are for illustration purposes; see the multi-relation section (sec:multi) for the actual prompt structure.
  • Figure 2: Classifier training pipeline. Given a spatially-aligned image, we augment the data with both correct and incorrect relation prompts to prevent relation leakage. During denoising, we extract relevant attention maps to train our classifier.
  • Figure 3: Test-time optimization pipeline. During inference, we extract the relevant cross-attention maps when denoising $z_t$ and evaluate their relationship using our trained relation classifier. We then update the latent noise with backpropagation.
  • Figure 4: Comparison using FLUX.1-dev (left) and SD 2.1 (right) as base models with prompts from the GenEval benchmark (ghosh2023geneval). For each prompt, the same seed is used for all methods.
  • Figure 5: Evolution of cross-attention maps across denoising steps. Maps shown are after the steering step (see Fig. 3). At early steps ($t=T$), maps are diffuse and overlapping. During denoising, the corgi localizes on the left and the teapot on the right, converging to the intended spatial arrangement by $t=0$.
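
The steering step in Figure 3 can be sketched as gradient descent on the latent while the classifier stays frozen. This toy assumes a two-entry latent whose values map, through a sigmoid, to the horizontal attention centers of the subject and the object, and a fixed logistic "classifier" for "subject to the left of object". The paper backpropagates through the diffusion model and a trained classifier; central finite differences stand in for autograd here only to keep the example dependency-free. All function names are hypothetical.

```python
import math


def attention_centers(z):
    """Toy generator: map each latent entry to an attention center in (0, 1)."""
    return [1.0 / (1.0 + math.exp(-v)) for v in z]


def relation_loss(z, scale=10.0):
    """Frozen toy classifier for "subject left of object": -log sigmoid(logit)."""
    x_subj, x_obj = attention_centers(z)
    logit = scale * (x_obj - x_subj)  # positive when the subject is on the left
    return math.log1p(math.exp(-logit))


def numeric_grad(f, z, eps=1e-5):
    """Central finite differences, standing in for backpropagation."""
    grad = []
    for i in range(len(z)):
        zp, zm = list(z), list(z)
        zp[i] += eps
        zm[i] -= eps
        grad.append((f(zp) - f(zm)) / (2 * eps))
    return grad


def steer(z, steps=50, lr=1.0):
    """Update only the latent; the classifier and generator stay fixed."""
    z = list(z)
    for _ in range(steps):
        g = numeric_grad(relation_loss, z)
        z = [v - lr * gv for v, gv in zip(z, g)]
    return z


z0 = [0.5, -0.5]  # subject initially to the right of the object
z1 = steer(z0)
print(attention_centers(z0), attention_centers(z1))
print(relation_loss(z0), relation_loss(z1))
```

Only the latent changes across iterations, mirroring the paper's inference-time, no-fine-tuning setup; in the real pipeline this update is interleaved with denoising steps rather than run in isolation.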