Inference-Time Rethinking with Latent Thought Vectors for Math Reasoning
Deqian Kong, Minglu Zhao, Aoyang Qin, Bo Pang, Chenxin Tao, David Hartmann, Edouardo Honig, Dehong Xu, Amit Kumar, Matt Sarte, Chuan Li, Jianwen Xie, Ying Nian Wu
TL;DR
This work tackles the brittleness of standard chain-of-thought reasoning by decoupling declarative content from procedural generation, using latent thought vectors and a Gibbs-style inference-time rethinking loop. It presents a two-part Transformer-based model whose learned latent space guides generation; the latent vector is optimized per instance at inference time rather than predicted by an amortized inference network. Across GSM8K, SVAMP, and MultiArith, a 0.2B model with 30 rethinking iterations achieves state-of-the-art results, outperforming substantially larger baselines and demonstrating robustness to distribution shifts. The findings highlight inference-time computation on a structured latent manifold as a scaling axis complementary to parameter count, with implications for more reliable mathematical reasoning in AI systems.
Abstract
Standard chain-of-thought reasoning generates a solution in a single forward pass, committing irrevocably to each token and lacking a mechanism to recover from early errors. We introduce Inference-Time Rethinking, a generative framework that enables iterative self-correction by decoupling declarative latent thought vectors from procedural generation. We factorize reasoning into a continuous latent thought vector (what to reason about) and a decoder that verbalizes the trace conditioned on this vector (how to reason). Beyond serving as a declarative buffer, latent thought vectors compress the reasoning structure into a continuous representation that abstracts away surface-level token variability, making gradient-based optimization over reasoning strategies well-posed. Our prior model maps unstructured noise to a learned manifold of valid reasoning patterns, and at test time we employ a Gibbs-style procedure that alternates between generating a candidate trace and optimizing the latent vector to better explain that trace, effectively navigating the latent manifold to refine the reasoning strategy. Training a 0.2B-parameter model from scratch on GSM8K, our method with 30 rethinking iterations surpasses baselines with 10 to 15 times more parameters, including a 3B counterpart. This result demonstrates that effective mathematical reasoning can emerge from sophisticated inference-time computation rather than solely from massive parameter counts.
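The alternation described above (generate a candidate trace, then optimize the latent vector to better explain that trace) can be illustrated with a deliberately tiny sketch. This is not the paper's implementation. It assumes a standard Gaussian prior on the latent vector in place of the learned prior model, a random linear-softmax "decoder" in place of the Transformer, and i.i.d. tokens in place of an autoregressive trace. All names (`make_decoder`, `rethink`, `inner_steps`, `lr`) and all sizes are hypothetical, chosen only for illustration.

```python
import math
import random

VOCAB, DIM, TRACE_LEN = 4, 3, 6  # toy sizes, chosen only for illustration


def make_decoder(rng):
    """Hypothetical stand-in decoder: token logits are W @ z."""
    return [[rng.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(VOCAB)]


def log_probs(W, z):
    """Log-softmax over the toy vocabulary, computed stably."""
    logits = [sum(w * x for w, x in zip(row, z)) for row in W]
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_norm for l in logits]


def sample_trace(W, z, rng):
    """Generation step: sample a candidate trace conditioned on z."""
    weights = [math.exp(lp) for lp in log_probs(W, z)]
    return rng.choices(range(VOCAB), weights=weights, k=TRACE_LEN)


def log_posterior(W, z, trace):
    """log p(trace | z) + log N(z; 0, I), up to an additive constant."""
    lp = log_probs(W, z)
    return sum(lp[t] for t in trace) - 0.5 * sum(x * x for x in z)


def grad_log_posterior(W, z, trace):
    """Analytic gradient of log_posterior with respect to z."""
    probs = [math.exp(lp) for lp in log_probs(W, z)]
    expected = [sum(probs[v] * W[v][d] for v in range(VOCAB)) for d in range(DIM)]
    grad = [-x for x in z]  # Gaussian prior term (an assumption of this sketch)
    for t in trace:
        for d in range(DIM):
            grad[d] += W[t][d] - expected[d]
    return grad


def rethink(W, rng, iterations=30, inner_steps=5, lr=0.05):
    """Alternate between sampling a trace and refining z to explain it."""
    z = [rng.gauss(0.0, 1.0) for _ in range(DIM)]  # initial draw from the prior
    history = []
    for _ in range(iterations):
        trace = sample_trace(W, z, rng)          # generate a candidate trace
        for _ in range(inner_steps):             # gradient ascent on z
            g = grad_log_posterior(W, z, trace)
            z = [x + lr * gi for x, gi in zip(z, g)]
        history.append((trace, log_posterior(W, z, trace)))
    return z, history


rng = random.Random(0)
W = make_decoder(rng)
z_final, history = rethink(W, rng)
print(len(history), [round(x, 3) for x in z_final])
```

The sketch only shows the control flow: each outer iteration commits to a sampled trace, and the inner loop moves the latent vector toward values under which that trace is more probable while staying plausible under the prior. How the actual method conditions on the problem, parameterizes the prior, and selects a final answer is not specified in the abstract and is not modeled here.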
