Steered Generation via Gradient Descent on Sparse Features
Sumanta Bhattacharyya, Pedram Rooshenas
TL;DR
This work addresses controllable text generation by learning sparse, interpretable latent representations of the query activations in LLM attention heads. Sparse autoencoders (SAEs) are trained on query-head activations and embedded in a prototype-based latent space; the output is then steered toward a desired cognitive style by gradient-based updates that move the latent code toward the corresponding style prototype, without altering model weights. The approach emphasizes mid-layer representations and sparsity regularization. It demonstrates improved alignment to Bloom's taxonomy levels in educational feedback tasks and is reported to generalize to other models such as Mistral. Key contributions include a novel SAE-based steering framework, a synthetic Bloom-style educational dataset, and extensive analyses of layer choice, sparsity, and dimensionality, with implications for interpretable, on-demand style control in LLMs. Limitations such as polysemantic activations are discussed, and future work points to co-training SAEs with LLMs and to evaluation on a broader range of models.
Abstract
Large language models (LLMs) encode a diverse range of linguistic features within their latent representations, which can be harnessed to steer their output toward specific target characteristics. In this paper, we modify the internal structure of LLMs by training sparse autoencoders to learn a sparse representation of the query embedding, allowing precise control over the model's attention distribution. We demonstrate that manipulating this sparse representation effectively transforms the output toward different stylistic and cognitive targets. Specifically, in an educational setting, we show that the cognitive complexity of LLM-generated feedback can be systematically adjusted by modifying the encoded query representation at a specific layer. To achieve this, we guide the learned sparse embedding toward the representation of samples from the desired cognitive complexity level, using gradient-based optimization in the latent space.
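
The steering step described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes a one-layer ReLU sparse autoencoder with randomly initialised weights standing in for a trained one, a prototype defined as the mean latent code of target-level samples, and a squared-distance objective. The names `encode`, `decode`, `steer`, `d_model`, `d_latent`, and all hyperparameters are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_latent = 16, 64  # hypothetical query and latent sizes

# Assumption: random weights stand in for a trained sparse autoencoder.
W_enc = rng.normal(scale=0.1, size=(d_model, d_latent))
b_enc = np.zeros(d_latent)
W_dec = rng.normal(scale=0.1, size=(d_latent, d_model))
b_dec = np.zeros(d_model)


def encode(q):
    """Map query activations to a non-negative latent code (ReLU encoder)."""
    return np.maximum(q @ W_enc + b_enc, 0.0)


def decode(z):
    """Map a latent code back to the query activation space."""
    return z @ W_dec + b_dec


def steer(z, prototype, lr=0.1, steps=50):
    """Gradient descent on 0.5 * ||z - prototype||^2 in latent space."""
    z = z.copy()
    for _ in range(steps):
        grad = z - prototype                # gradient of the squared distance
        z = np.maximum(z - lr * grad, 0.0)  # keep the code non-negative
    return z


# Assumption: synthetic activations stand in for real query-head activations.
q = rng.normal(size=d_model)
target_samples = rng.normal(size=(32, d_model))

# Prototype: mean latent code of samples from the desired level.
prototype = encode(target_samples).mean(axis=0)

z0 = encode(q)
z_steered = steer(z0, prototype)
q_steered = decode(z_steered)  # would replace the original query at the chosen layer

print(np.linalg.norm(z0 - prototype), np.linalg.norm(z_steered - prototype))
```

Only the latent code changes here; the autoencoder weights and, by extension, the model weights stay fixed, which matches the abstract's description of steering without retraining. The squared-distance objective is one simple choice, and the paper's actual loss and prototype construction may differ.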
