Selection, Reflection and Self-Refinement: Revisit Reasoning Tasks via a Causal Lens

Yunlong Deng, Boyang Sun, Yan Li, Lingjing Kong, Zeyu Tang, Kun Zhang, Guangyi Chen

TL;DR

This work reframes reasoning as a causal selection problem: latent rules $\mathbf{z}$ act as constraints $S(\mathbf{z})=1$ that select which pairs of observed inputs $\mathbf{x}$ and outputs $\mathbf{y}$ are valid, formalized through $p(\mathbf{x}, \mathbf{y})$ and related expressions. It introduces SR$^2$, a flat recurrent Transformer framework with three components (Reflective Representation Learning, Dependency Self-Refinement, and Periodic Alignment) that iteratively refines latent representations and enforces dense dependencies. Empirically, SR$^2$ achieves state-of-the-art or near state-of-the-art results on Sudoku-Extreme and Maze-Hard with substantially fewer parameters (e.g., ~3.4M vs ~27.3M), and ablations demonstrate the necessity and balance of its modules and iterative design. The approach offers a principled alternative to scaling by exploiting structured latent reasoning and fixed-point refinement, with potential implications for robust, human-like reasoning in AI systems.

Abstract

Due to their inherent complexity, reasoning tasks have long been regarded as rigorous benchmarks for assessing the capabilities of machine learning models, especially large language models (LLMs). Although humans can solve these tasks with ease, existing models, even after extensive pre-training and post-training at scale, still fail to reason reliably. In this paper, we revisit reasoning tasks from a causal perspective, seeking to understand their behavior in latent space and to offer insights for addressing their challenges. Specifically, we cast reasoning tasks as a selection mechanism, in which high-level logical concepts function as selection operators on the given observations, such as identifying the correct answer in a math problem or filling the appropriate entry in Sudoku. We emphasize two key properties of this formulation that shed light on the difficulty of reasoning tasks. First, the latent space exceeds the observation space in complexity, even when the correct answer is fully determined by the observed input. Second, the latent variables, corresponding to logical thought, are densely structured and exhibit strong dependencies. Building on this formulation, we introduce a framework, called SR$^2$, that incorporates the estimated latent variables as feedback into the selection mechanism, thereby facilitating the learning of dense dependencies among latent representations. The framework consists of three key modules: reflective representation learning, dependency self-refinement, and periodic intermediate alignment. Experimentally, we show that our approach yields significant gains in reasoning accuracy, for example, attaining over 10$\%$ improvement in performance with 8$\times$ fewer parameters on the Sudoku and Maze tasks compared with recent methods.
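
To make the selection view concrete, the following is a minimal sketch of how Sudoku rules can act as binary selection indicators on a filled grid. It is an illustration only, not the paper's implementation: the function names (`selection_indicators`, `is_selected`) and the example grid are hypothetical, and the sketch checks a complete grid rather than modeling latent variables.

```python
from typing import List, Tuple

Grid = List[List[int]]  # 9x9 grid of digits 1-9 (hypothetical representation)


def _is_permutation(cells: List[int]) -> bool:
    """True if the nine cells contain each digit 1-9 exactly once."""
    return sorted(cells) == list(range(1, 10))


def selection_indicators(grid: Grid) -> Tuple[List[int], List[int], List[int]]:
    """Return one 0/1 indicator per row, column, and 3x3 block."""
    rows = [int(_is_permutation(grid[i])) for i in range(9)]
    cols = [int(_is_permutation([grid[i][j] for i in range(9)])) for j in range(9)]
    blocks = []
    for b in range(9):
        r0, c0 = 3 * (b // 3), 3 * (b % 3)  # top-left corner of block b
        cells = [grid[r0 + dr][c0 + dc] for dr in range(3) for dc in range(3)]
        blocks.append(int(_is_permutation(cells)))
    return rows, cols, blocks


def is_selected(grid: Grid) -> bool:
    """A grid is 'selected' (valid) only if every indicator equals 1."""
    rows, cols, blocks = selection_indicators(grid)
    return all(rows) and all(cols) and all(blocks)


# A valid grid built from a standard shifting pattern (illustrative data).
solved: Grid = [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)]
                for r in range(9)]

# Swapping two cells within a row keeps that row valid but breaks two columns.
broken: Grid = [row[:] for row in solved]
broken[0][0], broken[0][1] = broken[0][1], broken[0][0]

print(is_selected(solved), is_selected(broken))
```

The swap shows why the constraints are densely coupled: a change that keeps one row valid still violates the column indicators that share its cells.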

Paper Structure

This paper contains 35 sections, 24 equations, 6 figures, 2 tables.

Figures (6)

  • Figure 1: Illustration of reasoning tasks and the selection mechanism, using Sudoku as an example. (a) A sample $9\times9$ Sudoku puzzle with a subset of given clues; the goal is to fill the remaining cells so that each row, column, and $3\times3$ subgrid contains the digits $1$–$9$ exactly once. (b) A single unfilled cell $Y_{ij}$ with its row (purple), column (blue), and $3\times3$ block (orange) highlighted; the digits within these groups impose constraints that determine the admissible values for $Y_{ij}$. (c) Selection mechanism: $Y$ is valid iff the validity criteria are satisfied $S^{(i)}_{\text{Row}}=S^{(j)}_{\text{Col}}=S^{(b)}_{\text{Block}}=1$.
  • Figure 2: Causal graph of the selection mechanism, where $({\mathbf{x}},{\mathbf{y}})$ are selected by the constraints $S({\mathbf{z}})$.
  • Figure 3: Overall framework of the SR$^2$ method. The framework consists of three main modules: Reflective Representation Learning (in blue), Dependency Self-Refinement (in orange), and Periodic Alignment (in green). $f$ denotes the weight-shared atomic block that updates the latent state, and $g$ projects the latent state to the final answers. In the representation learning stage, $f$ recurrently updates the latent state with the observation injected, ${\mathbf{z}}^{(t+1)}=f({\mathbf{z}}^{(t)},{\mathbf{x}})$, for $M$ steps to obtain a refined initialization. Next, in the self-refinement stage, the model drops the observation signal, ${\mathbf{z}}^{(t+1)}=f({\mathbf{z}}^{(t)},\mathbf{0})$, and updates for a further $M\times (N\!-\!1)$ steps to resolve dense dependencies and approach a fixed point. Throughout training, supervision is applied periodically (e.g., every $M$ steps) to stabilize long recurrences and mitigate vanishing gradients. When supervision is added, the gradients from future states are blocked (indicated by the red arrows).
  • Figure 4: Choosing $N$ and $M$ under a fixed compute budget.
  • Figure 5: Accuracy scaling with the hyperparameters $N$ and $M$: the left panel varies $N$ at $M{=}16$, and the right panel varies $M$ at $N{=}16$.
  • ...and 1 more figures
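
The two-stage schedule described in the Figure 3 caption can be sketched as a plain loop. This is a minimal sketch under stated assumptions, not the authors' implementation: `sr2_schedule`, `toy_f`, and `toy_g` are hypothetical names, the toy update stands in for the weight-shared Transformer block $f$, and there is no autograd here, so the point where gradients would be blocked is marked only by a comment.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def sr2_schedule(
    f: Callable[[Vector, Vector], Vector],
    g: Callable[[Vector], Vector],
    x: Vector,
    z0: Vector,
    M: int,
    N: int,
) -> Tuple[Vector, List[Vector]]:
    """Run M steps with x injected, then M*(N-1) steps with a zero injection.

    A readout g(z) is collected every M steps, mirroring periodic supervision.
    """
    zero = [0.0] * len(x)
    z = list(z0)
    readouts: List[Vector] = []
    for t in range(M * N):
        inject = x if t < M else zero  # stage 1: observation, stage 2: none
        z = f(z, inject)
        if (t + 1) % M == 0:
            readouts.append(g(z))
            # In training, a loss would be applied to this readout and the
            # state detached here so gradients from later steps are blocked.
    return z, readouts


def toy_f(z: Vector, inject: Vector) -> Vector:
    """Hypothetical stand-in for the shared block: a simple contraction."""
    return [0.5 * zi + xi for zi, xi in zip(z, inject)]


def toy_g(z: Vector) -> Vector:
    """Hypothetical readout: identity copy of the latent state."""
    return list(z)


final_z, outs = sr2_schedule(toy_f, toy_g, x=[1.0, 2.0], z0=[0.0, 0.0], M=3, N=2)
print(len(outs), final_z)
```

The loop runs $M \times N$ updates in total and yields $N$ readouts, one per supervision point; with the toy contraction the state simply decays once the observation is dropped, whereas a learned $f$ is meant to settle toward a task-relevant fixed point.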