Learning to Disentangle Latent Reasoning Rules with Language VAEs: A Systematic Study

Yingji Zhang, Marco Valentino, Danilo S. Carvalho, André Freitas

TL;DR

This work tackles the gap between memorisation and rule-based inference in natural language inference by explicitly encoding reasoning rules in a language VAE's latent space. It presents an NTK-inspired framework that treats rules as distinct latent subspaces, demonstrating that explicit supervision yields disentangled rule representations and rule-specific clustering. The authors implement an end-to-end Transformer-based VAE with three latent-injection strategies, finding that injecting latent information into the Query yields the best rule separation, and that FFN components preserve rule separation better than attention components. The approach improves interpretability and controllability of latent reasoning, with practical implications for safer and more auditable NLI systems, and points to diffusion-based extensions for future decoding control.

Abstract

Incorporating explicit reasoning rules within the latent space of language models (LMs) offers a promising pathway to enhance generalisation, interpretability, and controllability. While current Transformer-based language models have shown strong performance on Natural Language Inference (NLI) tasks, they often rely on memorisation rather than rule-based inference. This work investigates how reasoning rules can be explicitly embedded and memorised within LMs through Language Variational Autoencoders (VAEs). We propose a complete pipeline for learning reasoning rules within Transformer-based language VAEs. This pipeline encompasses three rule-based reasoning tasks, a supporting theoretical framework, and a practical end-to-end architecture. The experiments yield the following findings:

  • Disentangled reasoning: Under explicit signal supervision, reasoning rules, viewed as functional mappings, can be disentangled within the encoder's parametric space. This separation results in distinct clustering of rules in the output feature space.
  • Prior knowledge injection: Injecting reasoning information into the Query enables the model to more effectively retrieve the stored Value from memory based on the Key. This approach offers a simple method for integrating prior knowledge into decoder-only language models.
  • Performance bottleneck: In mathematical reasoning tasks using Qwen2.5 (0.5B), increasing the sample count does not improve performance beyond a certain point. Moreover, FFN layers are better than attention layers at preserving the separation of reasoning rules in the model's parameters.
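
The Query injection mentioned in the abstract can be pictured with a minimal single-head attention sketch. This is not the authors' implementation: the additive form (projected latent added to each query), the function name `attention_with_query_add`, and the projection matrix `W_z` are assumptions made purely for illustration, loosely following the `query_add` setup named in the figure captions.

```python
import math


def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_with_query_add(queries, keys, values, z, W_z):
    """Single-head scaled dot-product attention with a latent added to the query.

    ASSUMPTION: the latent z is linearly projected by W_z and added to every
    query vector before the scores are computed. The paper's exact
    formulation is not given in this summary.
    """
    d = len(keys[0])
    z_proj = matvec(W_z, z)
    outputs, weights = [], []
    for q in queries:
        q_inj = [qi + zi for qi, zi in zip(q, z_proj)]
        scores = [
            sum(a * b for a, b in zip(q_inj, k)) / math.sqrt(d) for k in keys
        ]
        w = softmax(scores)
        out = [
            sum(wj * v[i] for wj, v in zip(w, values))
            for i in range(len(values[0]))
        ]
        outputs.append(out)
        weights.append(w)
    return outputs, weights


# Toy memory: two one-hot keys, each storing a distinct value.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
queries = [[0.0, 0.0]]  # an uninformative query
identity = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical projection W_z

_, base_weights = attention_with_query_add(queries, keys, values, [0.0, 0.0], identity)
_, inj_weights = attention_with_query_add(queries, keys, values, [0.0, 4.0], identity)
```

In this toy setting the uninformative query attends uniformly when the latent is zero, while a latent aligned with the second key shifts attention towards that key's stored value. This is the retrieval intuition behind Query injection, shown here only under the stated assumptions.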

Paper Structure

This paper contains 34 sections, 10 equations, 13 figures, 4 tables.

Figures (13)

  • Figure 1: Overview, where $(\pi, x, c)$ denotes (rule, input premise(s), conclusion). To systematically evaluate rule-based learning within a VAE framework, we first examine three rule-based NLI tasks. Second, we formalise the hypothesis that reasoning rules can be functionally and separately encoded within the encoder's parametric space, enabling rule learning in the latent space, grounded in the theoretical framework of neural tangent kernels. Third, we introduce an end-to-end VAE architecture, with three different latent injection setups, designed to capture coarse-grained reasoning patterns in its latent space while remaining sensitive to the lexical semantics of the input.
  • Figure 2: Gradient heatmap for the last posterior encoder layer (query_add setup). Left: math derivation; middle: explanatory reasoning; right: syllogistic reasoning. Top: cls_weight is 1.0; bottom: cls_weight is 0.1. The off-diagonal values are notably close to 0 with the higher cls_weight (the red elements are less scattered), suggesting that incorporating rule information during training enhances the separation of rule subspaces in the encoder's parameter space. We provide the heatmaps of all layers in the supplementary material.
  • Figure 3: PCA visualisation for the query_add injection setup. Left three: cls_weight is 1.0; right three: cls_weight is 0.1. The model struggles to learn the rules when the weight is close to zero, indicating that the neural network tries to produce reasoning behaviour via memorisation rather than rule-based learning. Visualisations for the other injection setups are provided in Figures \ref{fig:math_pca}, \ref{fig:explanation_pca}, and \ref{fig:syllogistic_pca} in the supplementary material.
  • Figure 4: Case study for the Math Reasoning task. Left: how varying the number of training samples for each operation affects reasoning capability. Right: comparison of the parametric rule separation between attn and ffn at the last layer in Qwen2.5-0.5B, a pretrained checkpoint with a training sample size of 4k.
  • Figure 5: Math Reasoning: gradient heatmap, where cls_weight is 0.1.
  • ...and 8 more figures
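
The gradient heatmaps in the captions above can be read as a similarity matrix between per-example parameter gradients, in the spirit of a neural tangent kernel $K(x_i, x_j) = \langle \nabla_\theta f(x_i), \nabla_\theta f(x_j) \rangle$. The sketch below is an illustration under stated assumptions, not the paper's procedure: the toy model $f(x; w) = \tanh(w \cdot x)$, the cosine normalisation, and the helper names `grad_tanh_unit` and `gradient_kernel` are all hypothetical.

```python
import math


def grad_tanh_unit(w, x):
    """Analytic gradient of the toy model f(x; w) = tanh(w . x) w.r.t. w.

    ASSUMPTION: this one-unit model stands in for an encoder layer.
    """
    pre = sum(wi * xi for wi, xi in zip(w, x))
    scale = 1.0 - math.tanh(pre) ** 2
    return [scale * xi for xi in x]


def gradient_kernel(grads):
    """Cosine-normalised Gram matrix of per-example gradient vectors."""
    norms = [math.sqrt(sum(g * g for g in gv)) for gv in grads]
    n = len(grads)
    return [
        [
            sum(a * b for a, b in zip(grads[i], grads[j])) / (norms[i] * norms[j])
            for j in range(n)
        ]
        for i in range(n)
    ]


# Two hypothetical "rules" whose examples touch disjoint input coordinates,
# so their gradients occupy disjoint parameter subspaces.
w = [0.5, -0.3, 0.2, 0.1]
rule_a = [[1.0, 2.0, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]
rule_b = [[0.0, 0.0, 1.0, -1.0], [0.0, 0.0, 2.0, 0.5]]

grads = [grad_tanh_unit(w, x) for x in rule_a + rule_b]
kernel = gradient_kernel(grads)

for row in kernel:
    print(" ".join(f"{value:+.2f}" for value in row))
```

When the two rules use disjoint parameter subspaces, the cross-rule entries of the matrix vanish and the matrix becomes block-diagonal. That is the pattern the captions associate with better rule separation at the higher cls_weight.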