Leveraging Low-Rank Structures for High-Dimensional Score-Based Sampling
Abstract
Diffusion models offer a powerful framework for sampling from complex probability densities by learning to reverse a noising process. A common approach is to solve the time-reversed stochastic differential equation (SDE), which requires the score function of the evolving sample distribution. The logarithm of this distribution's density is governed by a Hamilton-Jacobi-Bellman (HJB) type partial differential equation (PDE). However, current methods for solving this PDE, such as physics-informed neural networks (PINNs) or trajectory-based techniques, often suffer from long training times and significant sensitivity to hyperparameter tuning. In this work, we introduce a novel and efficient solver for the underlying HJB equation based on the functional tensor train (FTT) format. The FTT representation leverages latent low-rank structures to approximate high-dimensional functions efficiently, enabling both model compression and rapid computation. By integrating this representation with a backward-in-time iterative scheme derived from backward stochastic differential equations (BSDEs), we develop a fast, robust, and accurate sampling method. Our approach overcomes the primary bottlenecks of existing techniques, enabling high-fidelity sampling from challenging target distributions with improved efficiency.
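To make the FTT idea concrete, the following is a minimal sketch of how a function of d variables can be stored as a chain of small three-way cores and evaluated at a point with a sequence of small matrix products. It is an illustration only, not the solver proposed in this work: the Legendre basis, the core layout (r_{k-1}, n_basis, r_k), and the helper names `basis`, `ftt_eval`, and `sum_cores` are all assumptions chosen for the example.

```python
import numpy as np
from numpy.polynomial import legendre


def basis(x, n_basis):
    """Evaluate the first n_basis Legendre polynomials at a scalar x.

    The Legendre basis is an assumption made for this sketch.
    """
    # legvander returns shape (1, n_basis) for a length-1 input
    return legendre.legvander(np.atleast_1d(x), n_basis - 1)[0]


def ftt_eval(cores, x):
    """Evaluate a functional tensor train at a point x of length d.

    cores[k] has shape (r_{k-1}, n_basis, r_k), with r_0 = r_d = 1.
    """
    result = np.ones(1)
    for core, xk in zip(cores, x):
        phi = basis(xk, core.shape[1])            # basis values, (n_basis,)
        mat = np.einsum("inj,n->ij", core, phi)   # contracted core, (r_{k-1}, r_k)
        result = result @ mat                     # running row vector, (r_k,)
    return float(result[0])


def sum_cores(d):
    """Hypothetical toy example: rank-2 cores for f(x) = x_1 + ... + x_d, d >= 2."""
    one = [1.0, 0.0]   # Legendre coefficients of the constant 1
    lin = [0.0, 1.0]   # Legendre coefficients of x

    first = np.zeros((1, 2, 2))
    first[0, :, 0] = lin           # row vector [x_1, 1]
    first[0, :, 1] = one

    mid = np.zeros((2, 2, 2))
    mid[0, :, 0] = one             # matrix [[1, 0], [x_k, 1]]
    mid[1, :, 0] = lin
    mid[1, :, 1] = one

    last = np.zeros((2, 2, 1))
    last[0, :, 0] = one            # column vector [1, x_d]
    last[1, :, 0] = lin

    return [first] + [mid.copy() for _ in range(d - 2)] + [last]


if __name__ == "__main__":
    d = 20
    cores = sum_cores(d)
    x = np.linspace(-0.9, 0.9, d)
    print("FTT value:", ftt_eval(cores, x), "direct sum:", x.sum())
    print("stored coefficients:", sum(c.size for c in cores))
```

In this toy case the number of stored coefficients grows linearly in d, whereas a full coefficient tensor over the same two-function basis would have 2^d entries. That gap is the kind of compression a low-rank representation can offer when the target function admits one.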