Hello folks,
Diffusion language models outperform autoregressive (AR) models on synthetic reasoning tasks, but why?
It turns out the answer can be traced to a surprising mechanism: diffusion models naturally maintain "latent tokens" (joint predictions over positions they won't decode in the current step), which enable planning and lookahead. In particular, modulating the number of latent tokens gives a smooth tradeoff between inference speed and sample quality, and introducing latent tokens into AR models through an auxiliary multi-token prediction objective yields large gains on the same reasoning tasks where AR models have traditionally struggled.
This Monday, Andre He (LTI @ CMU) will present his paper "Reasoning with Latent Tokens in Diffusion Language Models."
Title: Reasoning with Latent Tokens in Diffusion Language Models
Meeting Link: click here
Time: Mar 2 (Monday) 1pm ET / 10am PT / 7pm CET / 11:30pm IST
Paper: https://arxiv.org/abs/2602.03769
Prior knowledge:
Fundamentals of discrete diffusion (video by Sasha Rush)
Abstract: Discrete diffusion models have recently become competitive with autoregressive models for language modeling, even outperforming them on reasoning tasks requiring planning and global coherence, but they require more computation at inference time. We trace this trade-off to a key mechanism: diffusion models are trained to jointly predict a distribution over all unknown tokens, including those that will not actually be decoded in the current step. Ablating this joint prediction yields faster inference but degrades performance, revealing that accurate prediction at the decoded position relies on joint reasoning about the distribution of undecoded tokens. We interpret these as latent tokens and introduce a method for modulating their number, demonstrating empirically that this enables a smooth tradeoff between inference speed and sample quality. Furthermore, we demonstrate that latent tokens can be introduced into autoregressive models through an auxiliary multi-token prediction objective, yielding substantial improvements on the same reasoning tasks where they have traditionally struggled. Our results suggest that latent tokens, while arising naturally in diffusion, represent a general mechanism for improving performance on tasks requiring global coherence or lookahead.
Yours truly,
Subham, Justin, Zhihan
Website, Twitter, Discord, YouTube