Diffusion Transformer Captures Spatial-Temporal Dependencies: A Theory for Gaussian Process Data

Hengyu Fu, Zehao Dou, Jiawei Guo, Mengdi Wang, Minshuo Chen

TL;DR

This paper develops a theoretical framework for diffusion transformers that capture spatial-temporal dependencies in sequential data modeled as Gaussian processes with covariance $\bm{\Gamma} \otimes \bm{\Sigma}$. It introduces score-function approximation by unrolling gradient descent in a transformer and proves a transformer-based approximation theorem, along with a sample-complexity bound showing how decay of temporal correlations improves learning efficiency. The results are supported by numerical experiments on GP data and semi-synthetic video data, showing that attention layers learn and reveal the underlying temporal kernel and spatial covariance. The work provides a principled bridge between diffusion models and transformers for sequential data and suggests practical guidelines for leveraging correlation decay to improve learning efficiency in long sequences.

Abstract

Diffusion Transformer, the backbone of Sora for video generation, successfully scales the capacity of diffusion models, pioneering new avenues for high-fidelity sequential data generation. Unlike static data such as images, sequential data consists of consecutive data frames indexed by time, exhibiting rich spatial and temporal dependencies. These dependencies represent the underlying dynamic model and are critical to validate the generated data. In this paper, we make the first theoretical step towards bridging diffusion transformers for capturing spatial-temporal dependencies. Specifically, we establish score approximation and distribution estimation guarantees of diffusion transformers for learning Gaussian process data with covariance functions of various decay patterns. We highlight how the spatial-temporal dependencies are captured and affect learning efficiency. Our study proposes a novel transformer approximation theory, where the transformer acts to unroll an algorithm. We support our theoretical results by numerical experiments, providing strong evidence that spatial-temporal dependencies are captured within attention layers, aligning with our approximation theory.

Paper Structure (80 sections, 34 theorems, 216 equations, 13 figures, 4 tables)

This paper contains 80 sections, 34 theorems, 216 equations, 13 figures, 4 tables.

Key Result

Lemma 1

For an arbitrarily fixed $t \in [0, T]$ and $\mathbf{v}_t$, given an error tolerance $\epsilon > 0$ and any integer $J < N$, if $\bar{\bm{\Gamma}}$ with $\bar{\bm{\Gamma}}_{ij} = \bm{\Gamma}_{ij} \mathds{1}\{|i - j| < J\}$ is positive semidefinite, then running gradient descent in equ::GD with a sui where $\kappa_t=\kappa{\left( \alpha_t^2 (\bar{\bm{\Gamma}} \otimes \bm{\Sigma}) + \sigma_t^2 \math
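
The gradient descent scheme that Lemma 1 refers to can be illustrated with a minimal NumPy sketch. It assumes, beyond what the truncated statement above shows, that the noised marginal is Gaussian with mean $\alpha_t \bm{\mu}$ and covariance $\alpha_t^2 (\bar{\bm{\Gamma}} \otimes \bm{\Sigma}) + \sigma_t^2 \mathbf{I}$, so that the score is the minimizer of a quadratic objective. The function name `truncated_gd_score`, the toy kernel, and the step size $2/(\lambda_{\min} + \lambda_{\max})$ are illustrative choices, not taken from the paper.

```python
import numpy as np

def truncated_gd_score(v_t, mu, Gamma, Sigma, alpha_t, sigma_t, J, num_steps):
    """Approximate a Gaussian score by gradient descent on a quadratic.

    v_t, mu : arrays of shape (N * d,), stacked frame by frame.
    Gamma   : (N, N) temporal correlation; Sigma : (d, d) spatial covariance.
    Assumed objective: 0.5 * s^T A s + s^T b, minimized at s = -A^{-1} b.
    """
    N = Gamma.shape[0]
    idx = np.arange(N)
    # Correlation truncation: keep only entries with |i - j| < J.
    Gamma_bar = np.where(np.abs(idx[:, None] - idx[None, :]) < J, Gamma, 0.0)
    # Assumed covariance of the noised marginal.
    A = alpha_t**2 * np.kron(Gamma_bar, Sigma) + sigma_t**2 * np.eye(v_t.size)
    b = v_t - alpha_t * mu
    # Illustrative step size: the classical choice for a quadratic.
    eigs = np.linalg.eigvalsh(A)
    eta = 2.0 / (eigs[0] + eigs[-1])
    s = np.zeros_like(v_t)
    for _ in range(num_steps):
        # One iteration: s <- s - eta * (A s + b).
        s = s - eta * (A @ s + b)
    return s, A, b

# Toy problem (hypothetical values): geometric decay keeps the truncated
# matrix diagonally dominant, hence positive semidefinite.
rng = np.random.default_rng(0)
N, d, J = 6, 2, 3
Gamma = 0.3 ** np.abs(np.subtract.outer(np.arange(N), np.arange(N)))
Sigma = np.array([[1.0, 0.5], [0.5, 2.0]])
mu = np.zeros(N * d)
v_t = rng.standard_normal(N * d)

s_gd, A, b = truncated_gd_score(
    v_t, mu, Gamma, Sigma, alpha_t=0.8, sigma_t=0.6, J=J, num_steps=200
)
s_exact = -np.linalg.solve(A, b)  # closed-form minimizer for comparison
print(np.max(np.abs(s_gd - s_exact)))
```

In this sketch the iteration count needed for a given tolerance grows with the condition number of `A`, which is the role $\kappa_t$ plays in the lemma.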

Figures (13)

  • Figure 1: Diffusion transformer learns spatial-temporal dependencies. The diffusion transformer is trained with data sampled from a stationary Gaussian process consisting of $128$ time steps. At each time step, the data dimension is $8$. We obtain 1000 generated samples at each time step. The left large heat map demonstrates the estimated temporal correlation (see the paper's appendix on experimental details for the estimation method) in the process between different time steps, which aligns well with the ground truth on the right. The smaller heat maps are the estimated covariance matrices of data at a single time step, which demonstrate the spatial dependencies in data. They also align well with the ground truth.
  • Figure 2: Transformer architecture. Here $f_{\rm in}$ is a linear layer that lifts each input patch to $\mathbb{R}^D$ by appending the time index embedding and other useful information to the raw input data. After the input passes through $L$ transformer blocks, $f_{\rm out}$ projects each patch back to the original data dimension $\mathbb{R}^d$ and clips the output range by $R$. We allow the output range to depend on the diffusion timestep $t$ (denoted as $R_t$).
  • Figure 3: Construction of score function approximation using a transformer. By rewriting the score function as the optimizer of a quadratic objective function, we use gradient descent to approximate the optimizer. In Lemma 1, we allow correlation truncation to control the maximum length of temporal dependencies that are modeled. Then we construct a transformer architecture to unroll the gradient descent algorithm for score approximation in Theorem 1. Each gradient descent iteration is realized by a multiplication module followed by two transformer blocks. In the first transformer block, the attention layer calculates the correlation $\bm{\Gamma}_{ij}$ using the time embedding. The second transformer block calculates the linear offset $-\eta_t \sigma_t^2 \mathbf{s} - \eta_t (\mathbf{v}_t - \alpha_t \bm{\mu})$ in the gradient descent update.
  • Figure 4: In panel (a), we observe that the relative error decreases as the sample size increases. Meanwhile, larger $\nu$ or smaller $\ell$ leads to better performance, supporting our generalization bound in Theorem 2. In panel (b), we split the query matrix $\mathbf{Q}$ into $\mathbf{Q}=[\mathbf{Q}_x,\mathbf{Q}_e]$, where $\mathbf{Q}_x\in \mathbb{R}^{16\times32}$ and $\mathbf{Q}_e\in \mathbb{R}^{16\times32}$, and $\mathbf{Q}_e$ corresponds to the time embedding $\mathbf{e}_{i}$. We likewise split the key matrix $\mathbf{K}$ into $[\mathbf{K}_x,\mathbf{K}_e]$. The sub-block $\mathbf{Q}_e^\top \mathbf{K}_e$ has dominant weights compared to the other sub-blocks.
  • Figure 5: We show score matrices in different attention layers and at different backward denoising steps. The learned temporal dependencies become clearer as the backward denoising process proceeds. Meanwhile, we observe that the temporal dependencies are well captured starting from the 3rd layer.
  • ...and 8 more figures
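
The sub-block comparison described for Figure 4, panel (b), can be sketched as follows. This is only an illustration under stated assumptions: the matrices are synthetic stand-ins rather than trained weights, the time-embedding columns are deliberately scaled up so that the effect is visible, and the helper `qk_block_norms` is a hypothetical name. The split at column 32 follows the shapes $16\times32$ given in the caption, with the first half assumed to act on data coordinates.

```python
import numpy as np

def qk_block_norms(Q, K, split):
    """Frobenius norms of the four sub-blocks of Q^T K.

    Q, K  : arrays of shape (head_dim, D), read as [Q_x, Q_e] and [K_x, K_e].
    split : column index separating the data part from the time-embedding part.
    """
    Q_x, Q_e = Q[:, :split], Q[:, split:]
    K_x, K_e = K[:, :split], K[:, split:]
    return {
        "xx": np.linalg.norm(Q_x.T @ K_x),
        "xe": np.linalg.norm(Q_x.T @ K_e),
        "ex": np.linalg.norm(Q_e.T @ K_x),
        "ee": np.linalg.norm(Q_e.T @ K_e),
    }

# Synthetic stand-ins for learned weights (not the paper's values).
rng = np.random.default_rng(0)
Q = rng.standard_normal((16, 64))
K = rng.standard_normal((16, 64))
# Artificially emphasize the time-embedding columns to mimic the reported pattern.
Q[:, 32:] *= 10.0
K[:, 32:] *= 10.0

norms = qk_block_norms(Q, K, split=32)
dominant = max(norms, key=norms.get)
print(norms, dominant)
```

With real checkpoints, `Q` and `K` would be read from a trained attention layer; a dominant `"ee"` entry would then indicate that attention scores are driven mainly by the time embeddings.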

Theorems & Definitions (35)

  • Lemma 1: Gradient Descent Iterate Approximates the Score Function
  • Corollary 1: Correlation Truncation with Decay
  • Theorem 1: Score Approximation by Transformers
  • Theorem 2: Sample Complexity of Diffusion Transformer
  • Lemma 2: Theorem 3.12 in Bubeck (2015)
  • Lemma 3
  • Lemma 4
  • Remark 1
  • Proposition 1
  • Lemma 5
  • ...and 25 more