Add nix shell config

Yan Lin 2025-05-15 22:46:36 +02:00
parent f747f57145
commit ac07362535
5 changed files with 120 additions and 47 deletions


@@ -21,6 +21,21 @@
processHtmlClass: 'arithmatex'
}
};
window.addEventListener('load', function() {
document.querySelectorAll('script[type^="math/tex"]').forEach(function(script) {
const isDisplay = script.type.includes('mode=display');
const math = script.textContent;
const span = document.createElement('span');
span.className = isDisplay ? 'mathjax-block' : 'mathjax-inline';
span.innerHTML = isDisplay ? `\\[${math}\\]` : `\\(${math}\\)`;
script.parentNode.replaceChild(span, script);
});
if (typeof MathJax !== 'undefined' && MathJax.typesetPromise) {
MathJax.typesetPromise();
}
});
</script>
<style>
a {
@@ -120,25 +135,26 @@
<blockquote>
<p>Van Den Oord, Aaron, and Oriol Vinyals. "Neural discrete representation learning." NeurIPS, 2017.</p>
</blockquote>
<p>Vector quantization maintains a "codebook" <span class="arithmatex">\(\boldsymbol C \in \mathbb R^{n\times d}\)</span>, which functions similarly to the index-fetching embedding layer, where <span class="arithmatex">\(n\)</span> is the total number of unique tokens, and <span class="arithmatex">\(d\)</span> is the embedding size. A given continuous vector <span class="arithmatex">\(\boldsymbol{z}\in\mathbb R^{d}\)</span> is quantized into a discrete value <span class="arithmatex">\(i\in\{0,\dots,n-1\}\)</span> by finding the closest row vector in <span class="arithmatex">\(\boldsymbol C\)</span> to <span class="arithmatex">\(\boldsymbol{z}\)</span>, and that row vector <span class="arithmatex">\(\boldsymbol C_i\)</span> is fetched as the embedding for <span class="arithmatex">\(\boldsymbol{z}\)</span>. Formally:
$$
<p>Vector quantization maintains a "codebook" <script type="math/tex">\boldsymbol C \in \mathbb R^{n\times d}</script>, which functions similarly to the index-fetching embedding layer, where <script type="math/tex">n</script> is the total number of unique tokens, and <script type="math/tex">d</script> is the embedding size. A given continuous vector <script type="math/tex">\boldsymbol{z}\in\mathbb R^{d}</script> is quantized into a discrete value <script type="math/tex">i\in\{0,\dots,n-1\}</script> by finding the closest row vector in <script type="math/tex">\boldsymbol C</script> to <script type="math/tex">\boldsymbol{z}</script>, and that row vector <script type="math/tex">\boldsymbol C_i</script> is fetched as the embedding for <script type="math/tex">\boldsymbol{z}</script>. Formally:
<script type="math/tex; mode=display">
i = \arg\min_j \|\boldsymbol z - \boldsymbol C_j\|_2
$$
</script>
<img alt="Screen_Shot_2020-06-28_at_4.26.40_PM" src="/blog/md/multi-modal-transformer.assets/Screen_Shot_2020-06-28_at_4.26.40_PM.png" /></p>
<h2>Lookup-Free Quantization</h2>
<p>A significant limitation of vector quantization is that it requires calculating distances between the given continuous vectors and the entire codebook, which becomes computationally expensive for large-scale codebooks. This creates tension with the need for expanded codebooks to represent complex modalities such as images and videos. Research has shown that simply increasing the number of unique tokens doesn't always improve codebook performance.</p>
<blockquote>
<p>“A simple trick for training a larger codebook involves decreasing the code embedding dimension when increasing the vocabulary size.” Source: <em>Yu, Lijun, Jose Lezama, et al. “Language Model Beats Diffusion - Tokenizer Is Key to Visual Generation,” ICLR, 2024.</em></p>
</blockquote>
<p>Building on this insight, <strong>Lookup-Free Quantization</strong> (LFQ) eliminates the embedding dimension of codebooks (essentially reducing the embedding dimension to 0) and directly calculates the discrete index <span class="arithmatex">\(i\)</span> by individually quantizing each dimension of <span class="arithmatex">\(\boldsymbol z\)</span> into a binary digit. The index <span class="arithmatex">\(i\)</span> can then be computed by converting the binary representation to decimal. Formally:
$$
i=\sum_{j=1}^{d} 2^{(j-1)}\cdot \mathbb{1}(z_j &gt; 0)
$$</p>
<p>Building on this insight, <strong>Lookup-Free Quantization</strong> (LFQ) eliminates the embedding dimension of codebooks (essentially reducing the embedding dimension to 0) and directly calculates the discrete index <script type="math/tex">i</script> by individually quantizing each dimension of <script type="math/tex">\boldsymbol z</script> into a binary digit. The index <script type="math/tex">i</script> can then be computed by converting the binary representation to decimal. Formally:
<script type="math/tex; mode=display">
i=\sum_{j=1}^{d} 2^{(j-1)}\cdot \mathbb{1}(z_j > 0)
</script>
</p>
<blockquote>
<p>For example, given a continuous vector <span class="arithmatex">\(\boldsymbol z=\langle -0.52, 1.50, 0.53, -1.32\rangle\)</span>, we first quantize each dimension into <span class="arithmatex">\(\langle 0, 1, 1, 0\rangle\)</span>, based on the sign of each dimension. The token index of <span class="arithmatex">\(\boldsymbol z\)</span> is simply the decimal equivalent of the binary 0110, which is 6.</p>
<p>For example, given a continuous vector <script type="math/tex">\boldsymbol z=\langle -0.52, 1.50, 0.53, -1.32\rangle</script>, we first quantize each dimension into <script type="math/tex">\langle 0, 1, 1, 0\rangle</script>, based on the sign of each dimension. The token index of <script type="math/tex">\boldsymbol z</script> is simply the decimal equivalent of the binary 0110, which is 6.</p>
</blockquote>
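<p>The sketch below reproduces this example in PyTorch; the helper name is made up for illustration.</p>
<pre><code class="language-python">import torch

def lfq_index(z: torch.Tensor) -> torch.Tensor:
    """Lookup-free quantization: i = sum_j 2^(j-1) * 1(z_j > 0)."""
    bits = (z > 0).long()                      # binarize each dimension by its sign
    weights = 2 ** torch.arange(z.shape[-1])   # 2^(j-1) for j = 1, ..., d
    return (bits * weights).sum(dim=-1)

z = torch.tensor([-0.52, 1.50, 0.53, -1.32])
print(lfq_index(z))   # tensor(6): the bits 0, 1, 1, 0 encode the binary number 0110
</code></pre>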
<p>However, this approach introduces another challenge: we still need an index-fetching embedding layer to map these token indices into embedding vectors for the Transformer. This, combined with the typically large number of unique tokens when using LFQ—a 32-dimensional <span class="arithmatex">\(\boldsymbol z\)</span> will result in <span class="arithmatex">\(2^{32}=4,294,967,296\)</span> unique tokens—creates significant efficiency problems. One solution is to factorize the token space. Effectively, this means splitting the binary digits into multiple parts, embedding each part separately, and concatenating the resulting embedding vectors. For example, with a 32-dimensional <span class="arithmatex">\(\boldsymbol z\)</span>, if we quantize and embed its first and last 16 dimensions separately, we “only” need to handle <span class="arithmatex">\(2^{16}*2= 131,072\)</span> unique tokens.</p>
<p>Note that this section doesn't extensively explain how to map raw continuous features into the vector <span class="arithmatex">\(\boldsymbol{z}\)</span>, as these techniques are relatively straightforward and depend on the specific feature type—for example, fully-connected layers for numerical features, or CNN/GNN with feature flattening for structured data.</p>
<p>However, this approach introduces another challenge: we still need an index-fetching embedding layer to map these token indices into embedding vectors for the Transformer. This, combined with the typically large number of unique tokens when using LFQ—a 32-dimensional <script type="math/tex">\boldsymbol z</script> will result in <script type="math/tex">2^{32}=4,294,967,296</script> unique tokens—creates significant efficiency problems. One solution is to factorize the token space. Effectively, this means splitting the binary digits into multiple parts, embedding each part separately, and concatenating the resulting embedding vectors. For example, with a 32-dimensional <script type="math/tex">\boldsymbol z</script>, if we quantize and embed its first and last 16 dimensions separately, we “only” need to handle <script type="math/tex">2^{16}*2= 131,072</script> unique tokens.</p>
<p>Note that this section doesn't extensively explain how to map raw continuous features into the vector <script type="math/tex">\boldsymbol{z}</script>, as these techniques are relatively straightforward and depend on the specific feature type—for example, fully-connected layers for numerical features, or CNN/GNN with feature flattening for structured data.</p>
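<p>A sketch of the factorized variant for a 32-dimensional <em>z</em>, splitting the code into two 16-bit halves with separate embedding tables as in the example above. The module and its sizes are assumptions for illustration, not a reference implementation.</p>
<pre><code class="language-python">import torch
import torch.nn as nn

class FactorizedLFQEmbedding(nn.Module):
    """Embed a 32-bit LFQ code as two 16-bit sub-tokens, then concatenate."""
    def __init__(self, emb_dim: int = 128):
        super().__init__()
        self.emb_lo = nn.Embedding(2 ** 16, emb_dim)   # table for the first 16 dimensions
        self.emb_hi = nn.Embedding(2 ** 16, emb_dim)   # table for the last 16 dimensions

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        bits = (z > 0).long()                          # (..., 32) binary code
        weights = 2 ** torch.arange(16, device=z.device)
        idx_lo = (bits[..., :16] * weights).sum(-1)    # token index of the first half
        idx_hi = (bits[..., 16:] * weights).sum(-1)    # token index of the second half
        return torch.cat([self.emb_lo(idx_lo), self.emb_hi(idx_hi)], dim=-1)

emb = FactorizedLFQEmbedding()
out = emb(torch.randn(4, 32))   # two tables of 65,536 entries instead of one of 2^32
print(out.shape)                # torch.Size([4, 256])
</code></pre>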
<h2>Quantization over Linear Projection</h2>
<p>You might be asking—why can't we simply use linear projections to map the raw continuous features into the embedding space? What are the benefits of quantizing continuous features into discrete tokens?</p>
<p>Although Transformers are regarded as universal sequence models, they were designed for discrete tokens when first introduced in <em>Vaswani et al., "Attention Is All You Need"</em>. Empirically, they perform better on discrete tokens than on raw continuous features. This is supported by many research papers reporting that quantizing continuous features improves the performance of Transformers, as well as by works demonstrating Transformers' subpar performance when applied directly to continuous features.</p>
@@ -169,7 +185,7 @@ $$</p>
<h1>Output Layer</h1>
<p>For language generation, Transformers typically use classifier output layers, mapping the latent vector of each item in the output sequence back to tokens. As we've established in the "modality embedding" section, the optimal method to embed continuous features is to quantize them into discrete tokens. Correspondingly, an intuitive method to output continuous features is to map these discrete tokens back to the continuous feature space, essentially reversing the vector quantization process.</p>
<h2>Reverse Vector Quantization</h2>
<p>One approach to reverse vector quantization is readily available in VQ-VAE, since it is an auto-encoder. Given a token <span class="arithmatex">\(i\)</span>, we can look up its embedding in the codebook as <span class="arithmatex">\(\boldsymbol C_i\)</span>, then apply a decoder network to map <span class="arithmatex">\(\boldsymbol C_i\)</span> back to the continuous feature vector <span class="arithmatex">\(\boldsymbol z\)</span>. The decoder network can either be pre-trained within the VQ-VAE framework (training the tokenizer, encoder, and decoder with auto-encoding losses) or trained end-to-end along with the whole Transformer. In the NLP and CV communities, the pre-training approach is more popular, since there are many large-scale pre-trained auto-encoders available.</p>
<p>One approach to reverse vector quantization is readily available in VQ-VAE, since it is an auto-encoder. Given a token <script type="math/tex">i</script>, we can look up its embedding in the codebook as <script type="math/tex">\boldsymbol C_i</script>, then apply a decoder network to map <script type="math/tex">\boldsymbol C_i</script> back to the continuous feature vector <script type="math/tex">\boldsymbol z</script>. The decoder network can either be pre-trained within the VQ-VAE framework (training the tokenizer, encoder, and decoder with auto-encoding losses) or trained end-to-end along with the whole Transformer. In the NLP and CV communities, the pre-training approach is more popular, since there are many large-scale pre-trained auto-encoders available.</p>
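<p>A minimal sketch of this reverse mapping under the assumptions above; the decoder here is a stand-in MLP, whereas real systems such as MAGVIT use convolutional decoders.</p>
<pre><code class="language-python">import torch
import torch.nn as nn

n, d, feat_dim = 512, 64, 256      # codebook size, code dimension, feature dimension (illustrative)
C = torch.randn(n, d)              # codebook from the (pre-)trained tokenizer

decoder = nn.Sequential(           # stand-in for a pre-trained VQ-VAE decoder
    nn.Linear(d, 512), nn.ReLU(),
    nn.Linear(512, feat_dim),
)

i = torch.tensor([3, 141, 59])     # discrete tokens produced by the Transformer
z_hat = decoder(C[i])              # look up C_i, then decode back to the feature space
print(z_hat.shape)                 # torch.Size([3, 256])
</code></pre>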
<figure class="figure">
<img alt="image (4)" src="/blog/md/multi-modal-transformer.assets/image (4).png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">The encoder-decoder structure of MAGVIT (<em>Yu et al., “MAGVIT”</em>), a visual VQ-VAE model. A 3D-VQ encoder quantizes a video into discrete tokens, and a 3D-VQ decoder maps them back to the pixel space.</figcaption>


@@ -21,6 +21,21 @@
processHtmlClass: 'arithmatex'
}
};
window.addEventListener('load', function() {
document.querySelectorAll('script[type^="math/tex"]').forEach(function(script) {
const isDisplay = script.type.includes('mode=display');
const math = script.textContent;
const span = document.createElement('span');
span.className = isDisplay ? 'mathjax-block' : 'mathjax-inline';
span.innerHTML = isDisplay ? `\\[${math}\\]` : `\\(${math}\\)`;
script.parentNode.replaceChild(span, script);
});
if (typeof MathJax !== 'undefined' && MathJax.typesetPromise) {
MathJax.typesetPromise();
}
});
</script>
<style>
a {
@@ -83,13 +98,13 @@
<hr />
<h1>Background</h1>
<p>Diffusion models (DMs), or more broadly speaking, score-matching generative models, have become the de facto framework for building deep generative models. They demonstrate exceptional generation performance, especially on continuous modalities including images, videos, audio, and spatiotemporal data.</p>
<p>Most diffusion models work by coupling a forward diffusion process and a reverse denoising diffusion process. The forward diffusion process gradually adds noise to the ground truth clean data <span class="arithmatex">\(X_0\)</span>, until noisy data <span class="arithmatex">\(X_T\)</span> that follows a relatively simple distribution is reached. The reverse denoising diffusion process starts from the noisy data <span class="arithmatex">\(X_T\)</span>, and removes the noise component step-by-step until clean generated data <span class="arithmatex">\(X_0\)</span> is reached. The reverse process is inherently sequential: each step depends on the output of the previous one, so a single generation cannot be parallelized across its steps, which is inefficient when many steps are required.</p>
<p>Most diffusion models work by coupling a forward diffusion process and a reverse denoising diffusion process. The forward diffusion process gradually adds noise to the ground truth clean data <script type="math/tex">X_0</script>, until noisy data <script type="math/tex">X_T</script> that follows a relatively simple distribution is reached. The reverse denoising diffusion process starts from the noisy data <script type="math/tex">X_T</script>, and removes the noise component step-by-step until clean generated data <script type="math/tex">X_0</script> is reached. The reverse process is inherently sequential: each step depends on the output of the previous one, so a single generation cannot be parallelized across its steps, which is inefficient when many steps are required.</p>
<figure class="figure">
<img alt="image-20250503125941212" src="/blog/md/one-step-diffusion-models.assets/image-20250503125941212.png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">The two processes in a typical diffusion model. <em>Source: Ho, Jain, and Abbeel, “Denoising Diffusion Probabilistic Models.”</em></figcaption>
</figure>
<h2>Understanding DMs</h2>
<p>There are many ways to understand how Diffusion Models (DMs) work. One of the most common and intuitive approaches is that a DM learns an ordinary differential equation (ODE) that transforms noise into data. Imagine an ODE vector field between the noise <span class="arithmatex">\(X_T\)</span> and clean data <span class="arithmatex">\(X_0\)</span>. By training on sufficiently large numbers of timesteps <span class="arithmatex">\(t\in [0,T]\)</span>, a DM is able to learn the vector (tangent) towards the cleaner data <span class="arithmatex">\(X_{t-\Delta t}\)</span>, given any specific timestep <span class="arithmatex">\(t\)</span> and the corresponding noisy data <span class="arithmatex">\(X_t\)</span>. This idea is easy to illustrate in a simplified 1-dimensional data scenario.</p>
<p>There are many ways to understand how Diffusion Models (DMs) work. One of the most common and intuitive approaches is that a DM learns an ordinary differential equation (ODE) that transforms noise into data. Imagine an ODE vector field between the noise <script type="math/tex">X_T</script> and clean data <script type="math/tex">X_0</script>. By training on sufficiently large numbers of timesteps <script type="math/tex">t\in [0,T]</script>, a DM is able to learn the vector (tangent) towards the cleaner data <script type="math/tex">X_{t-\Delta t}</script>, given any specific timestep <script type="math/tex">t</script> and the corresponding noisy data <script type="math/tex">X_t</script>. This idea is easy to illustrate in a simplified 1-dimensional data scenario.</p>
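<p>Under this view, generation amounts to numerically integrating the learned ODE from noise to data. Below is a minimal Euler-integration sketch; the vector field <em>v</em> is a placeholder for a trained network, and the step count and shapes are illustrative.</p>
<pre><code class="language-python">import torch

def sample(v, x_T: torch.Tensor, num_steps: int = 50, T: float = 1.0) -> torch.Tensor:
    """Integrate the learned ODE from noise x_T (t = T) down to clean data x_0 (t = 0)."""
    x, dt = x_T, T / num_steps
    for k in range(num_steps):
        t = T - k * dt
        x = x + v(x, t) * dt   # v(x_t, t): the learned tangent toward the cleaner data x_{t - dt}
    return x

# usage with a dummy vector field standing in for a trained denoiser network
v = lambda x, t: -x
x_gen = sample(v, torch.randn(4, 3, 32, 32))
</code></pre>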
<figure class="figure">
<img alt="image-20250503132738122" src="/blog/md/one-step-diffusion-models.assets/image-20250503132738122.png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">Illustrated ODE flow of a diffusion model on 1-dimensional data. <em>Source: Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations.”</em> It should be noted that as the figure suggests, there are differences between ODEs and DMs in a narrow sense. Flow matching models, a variant of DMs, more closely resemble ODEs.</figcaption>
@@ -106,7 +121,7 @@ Song et al., “Score-Based Generative Modeling through Stochastic Differential
<img alt="image-20250503135351246" src="/blog/md/one-step-diffusion-models.assets/image-20250503135351246.png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">Images generated by conventional DMs with only a few steps of reverse process. <em>Source: Frans et al., “One Step Diffusion via Shortcut Models.”</em></figcaption>
</figure>
<p>To understand why DMs scale poorly with few reverse process steps, we can return to the ODE vector field perspective of DMs. When the target data distribution is complex, the vector field typically contains numerous intersections. When a given <span class="arithmatex">\(X_t\)</span> and <span class="arithmatex">\(t\)</span> is at these intersections, the vector points to the averaged direction of all candidates. This causes the generated data to approach the mean of the training data when only a few reverse process steps are used. Another explanation is that the learned vector field is highly curved. Using only a few reverse process steps means attempting to approximate these curves with polylines, which is inherently difficult.</p>
<p>To understand why DMs scale poorly with few reverse process steps, we can return to the ODE vector field perspective of DMs. When the target data distribution is complex, the vector field typically contains numerous intersections. When a given <script type="math/tex">X_t</script> and <script type="math/tex">t</script> is at these intersections, the vector points to the averaged direction of all candidates. This causes the generated data to approach the mean of the training data when only a few reverse process steps are used. Another explanation is that the learned vector field is highly curved. Using only a few reverse process steps means attempting to approximate these curves with polylines, which is inherently difficult.</p>
<figure class="figure">
<img alt="image-20250503141422791" src="/blog/md/one-step-diffusion-models.assets/image-20250503141422791.png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">Illustration of the why DMs scale poorly with few reverse process steps. <em>Source: Frans et al., “One Step Diffusion via Shortcut Models.”</em></figcaption>
@@ -114,18 +129,21 @@ Song et al., “Score-Based Generative Modeling through Stochastic Differential
<p>We will introduce two branches of methods that aim to scale DMs down to a few or even a single reverse process step: <strong>distillation-based</strong>, which distills a pre-trained DM into a one-step model; and <strong>end-to-end-based</strong>, which trains a one-step DM from scratch.</p>
<h1>Distillation</h1>
<p>Distillation-based methods are also called <strong>rectified flow</strong> methods. Their idea follows the above insight of "curved ODE vector field": if the curved vectors (flows) are hindering the scaling of reverse process steps, can we try to straighten these vectors so that they are easy to approximate with polylines or even straight lines?</p>
<p><em>Liu, Gong, and Liu, "Flow Straight and Fast"</em> implements this idea, focusing on learning an ODE that follows straight vectors as much as possible. In the context of continuous-time DMs, where <span class="arithmatex">\(T=1\)</span> and <span class="arithmatex">\(t\in[0,1]\)</span>, suppose the clean data <span class="arithmatex">\(X_0\)</span> and noise <span class="arithmatex">\(X_1\)</span> each follow a data distribution, <span class="arithmatex">\(X_0\sim \pi_0\)</span> and <span class="arithmatex">\(X_1\sim \pi_1\)</span>. The "straight vectors" can be achieved by solving a nonlinear least squares optimization problem:
$$
\min_{v} \int_{0}^{1} \mathbb{E}\left[\left\|\left(X_{1}-X_{0}\right)-v\left(X_{t}, t\right)\right\|^{2}\right] \mathrm{d} t,
$$</p>
<div class="arithmatex">\[
<p><em>Liu, Gong, and Liu, "Flow Straight and Fast"</em> implements this idea, focusing on learning an ODE that follows straight vectors as much as possible. In the context of continuous-time DMs, where <script type="math/tex">T=1</script> and <script type="math/tex">t\in[0,1]</script>, suppose the clean data <script type="math/tex">X_0</script> and noise <script type="math/tex">X_1</script> each follow a data distribution, <script type="math/tex">X_0\sim \pi_0</script> and <script type="math/tex">X_1\sim \pi_1</script>. The "straight vectors" can be achieved by solving a nonlinear least squares optimization problem:
<script type="math/tex; mode=display">
\min_{v} \int_{0}^{1} \mathbb{E}\left[\left\|\left(X_{1}-X_{0}\right)-v\left(X_{t}, t\right)\right\|^{2}\right] \mathrm{d} t,
</script>
</p>
<p>
<script type="math/tex; mode=display">
\quad X_{t}=t X_{1}+(1-t) X_{0}
\]</div>
<p>Where <span class="arithmatex">\(v\)</span> is the vector field of the ODE <span class="arithmatex">\(dZ_t = v(Z_t,t)dt\)</span>.</p>
<p>Though the idea is straightforward, when the clean data distribution <span class="arithmatex">\(\pi_0\)</span> is very complicated, the ideal result of completely straight vectors can be hard to achieve. To address this, a "reflow" procedure is introduced. This procedure iteratively trains new rectified flows using data generated by previously obtained flows:
$$
</script>
</p>
<p>Where <script type="math/tex">v</script> is the vector field of the ODE <script type="math/tex">dZ_t = v(Z_t,t)dt</script>.</p>
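<p>A minimal training-step sketch of this objective on toy 2-D data. The velocity network, batch, and optimizer below are placeholders, and the convention follows the equations above (<em>t</em> = 0 is clean data, <em>t</em> = 1 is noise).</p>
<pre><code class="language-python">import torch
import torch.nn as nn

class Velocity(nn.Module):
    """Stand-in for v(x_t, t); real models are typically U-Nets or Transformers."""
    def __init__(self, dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 128), nn.SiLU(), nn.Linear(128, dim))

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))

v = Velocity()
opt = torch.optim.Adam(v.parameters(), lr=1e-3)

x0 = torch.randn(256, 2)                      # clean data X_0 ~ pi_0 (toy example)
x1 = torch.randn(256, 2)                      # noise X_1 ~ pi_1
t = torch.rand(256, 1)                        # t sampled uniformly in [0, 1]
xt = t * x1 + (1 - t) * x0                    # X_t = t X_1 + (1 - t) X_0

loss = ((x1 - x0) - v(xt, t)).pow(2).mean()   # ||(X_1 - X_0) - v(X_t, t)||^2
opt.zero_grad(); loss.backward(); opt.step()
</code></pre>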
<p>Though the idea is straightforward, when the clean data distribution <script type="math/tex">\pi_0</script> is very complicated, the ideal result of completely straight vectors can be hard to achieve. To address this, a "reflow" procedure is introduced. This procedure iteratively trains new rectified flows using data generated by previously obtained flows:
<script type="math/tex; mode=display">
Z^{(k+1)} = \text{RectFlow}((Z_0^{(k)}, Z_1^{(k)}))
$$
</script>
This procedure produces increasingly straight flows that can be simulated with very few steps, ideally one step after several iterations.</p>
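<p>A sketch of one reflow iteration under the same toy setup as the previous snippet: the current flow is simulated to pair each noise sample with the data point it maps to, and the next rectified flow is trained on these couplings. Function names here are illustrative.</p>
<pre><code class="language-python">import torch

@torch.no_grad()
def simulate(v, z1: torch.Tensor, num_steps: int = 100) -> torch.Tensor:
    """Run the current flow dZ_t = v(Z_t, t) dt from noise z1 (t = 1) back to data (t = 0)."""
    z, dt = z1, 1.0 / num_steps
    for k in range(num_steps):
        t = torch.full((z.shape[0], 1), 1.0 - k * dt)
        z = z - v(z, t) * dt   # stepping t from 1 toward 0 subtracts v dt
    return z

# Reflow: couple each noise sample Z_1^k with the Z_0^k the previous flow produces,
# then train the next rectified flow on (Z_0^k, Z_1^k) exactly as before.
z1 = torch.randn(256, 2)
z0 = simulate(v, z1)           # v is the flow obtained in the previous iteration
</code></pre>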
<figure class="figure">
<img alt="image-20250504142749208" src="/blog/md/one-step-diffusion-models.assets/image-20250504142749208.png" / class="figure-img img-fluid rounded">
@@ -135,45 +153,52 @@ This procedure produces increasingly straight flows that can be simulated with v
<h1>End-to-end</h1>
<p>Compared to distillation-based methods, end-to-end-based methods train a one-step-capable diffusion model (DM) within a single training run. Various techniques are used to implement such methods. We will focus on two of them: <strong>consistency models</strong> and <strong>shortcut models</strong>.</p>
<h2>Consistency Models</h2>
<p>In discrete-timestep diffusion models (DMs), three components in the reverse denoising diffusion process are interchangeable through reparameterization: the noise component <span class="arithmatex">\(\epsilon_t\)</span> to remove, the less noisy previous step <span class="arithmatex">\(x_{t-1}\)</span>, and the predicted clean sample <span class="arithmatex">\(x_0\)</span>. This interchangeability is enabled by the following equation:
$$
<p>In discrete-timestep diffusion models (DMs), three components in the reverse denoising diffusion process are interchangeable through reparameterization: the noise component <script type="math/tex">\epsilon_t</script> to remove, the less noisy previous step <script type="math/tex">x_{t-1}</script>, and the predicted clean sample <script type="math/tex">x_0</script>. This interchangeability is enabled by the following equation:
<script type="math/tex; mode=display">
x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon_t
$$
In theory, without altering the fundamental formulation of DMs, the learnable denoiser network can be designed to predict any of these three components. Consistency models (CMs) follow this principle by training the denoiser to specifically predict the clean sample <span class="arithmatex">\(x_0\)</span>. The benefit of this approach is that CMs can naturally scale to perform the reverse process with few steps or even a single step.</p>
</script>
In theory, without altering the fundamental formulation of DMs, the learnable denoiser network can be designed to predict any of these three components. Consistency models (CMs) follow this principle by training the denoiser to specifically predict the clean sample <script type="math/tex">x_0</script>. The benefit of this approach is that CMs can naturally scale to perform the reverse process with few steps or even a single step.</p>
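<p>For instance, a noise-prediction network can be reparameterized into an x_0-prediction on the fly. A small sketch of that algebra follows; the noise schedule and the denoiser are placeholders.</p>
<pre><code class="language-python">import torch

def predict_x0(eps_model, x_t, t, alpha_bar):
    """Invert x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps to recover the implied x_0 estimate."""
    abar_t = alpha_bar[t].view(-1, 1, 1, 1)          # broadcast over image dimensions
    eps = eps_model(x_t, t)                          # predicted noise component
    return (x_t - torch.sqrt(1 - abar_t) * eps) / torch.sqrt(abar_t)

# illustrative linear-beta schedule and a dummy denoiser
T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alpha_bar = torch.cumprod(1 - betas, dim=0)          # cumulative product of (1 - beta_t)
eps_model = lambda x, t: torch.zeros_like(x)         # placeholder for a trained network
x0_hat = predict_x0(eps_model, torch.randn(8, 3, 32, 32), torch.randint(0, T, (8,)), alpha_bar)
</code></pre>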
<figure class="figure">
<img alt="image-20250504161430743" src="/blog/md/one-step-diffusion-models.assets/image-20250504161430743.png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">A consistency model that learns to map any point on the ODE trajectory to the clean sample. <em>Source: Song et al., “Consistency Models.”</em></figcaption>
</figure>
<p>Formally, CMs learn a function <span class="arithmatex">\(f_\theta(x_t,t)\)</span> that maps noisy data <span class="arithmatex">\(x_t\)</span> at time <span class="arithmatex">\(t\)</span> directly to the clean data <span class="arithmatex">\(x_0\)</span>, satisfying:
$$
<p>Formally, CMs learn a function <script type="math/tex">f_\theta(x_t,t)</script> that maps noisy data <script type="math/tex">x_t</script> at time <script type="math/tex">t</script> directly to the clean data <script type="math/tex">x_0</script>, satisfying:
<script type="math/tex; mode=display">
f_\theta(x_t, t) = f_\theta(x_{t'}, t') \quad \forall t, t'
$$
</script>
The model must also obey the differential consistency condition:
$$
<script type="math/tex; mode=display">
\frac{d}{dt} f_\theta(x_t, t) = 0
$$
</script>
CMs are trained by minimizing the discrepancy between outputs at adjacent times, with the loss function:
$$
<script type="math/tex; mode=display">
\mathcal{L} = \mathbb{E} \left[ d\left(f_\theta(x_t, t), f_\theta(x_{t'}, t')\right) \right]
$$
</script>
Similar to continuous-timestep DMs and discrete-timestep DMs, CMs also have continuous-time and discrete-time variants. Discrete-time CMs are easier to train, but are more sensitive to timestep scheduling and suffer from discretization errors. Continuous-time CMs, on the other hand, suffer from instability during training.</p>
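<p>A minimal sketch of a discrete-time consistency training step, using two common assumptions from the CM literature: an EMA "target" copy of the network, and the same noise shared between the two adjacent times. The names and the noising function are illustrative, not the exact recipe of any specific paper.</p>
<pre><code class="language-python">import torch
import torch.nn.functional as F

def consistency_loss(f_theta, f_ema, x0, t, t_prime, add_noise):
    """d(f_theta(x_t, t), f_ema(x_t', t')) for adjacent times t > t' on the same trajectory."""
    eps = torch.randn_like(x0)
    x_t  = add_noise(x0, t, eps)                 # noisier point on the trajectory
    x_tp = add_noise(x0, t_prime, eps)           # adjacent, less noisy point (same noise)
    with torch.no_grad():
        target = f_ema(x_tp, t_prime)            # target network output, no gradient
    return F.mse_loss(f_theta(x_t, t), target)   # d(., .) chosen here as squared L2

# e.g. add_noise = lambda x0, t, eps: x0 + t * eps   (an EDM-style parameterization)
</code></pre>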
<p>For a deeper discussion of the differences between the two variants of CMs, and how to stabilize continuous-time CMs, please refer to <em>Lu and Song, "Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models."</em></p>
<h2>Shortcut Models</h2>
<p>Similar to distillation-based methods, the core idea of shortcut models is inspired by the "curved vector field" problem, but the shortcut models take a different approach to solve it.</p>
<p>Shortcut models are introduced in <em>Frans et al., "One Step Diffusion via Shortcut Models."</em> The paper presents the insight that conventional DMs perform badly when jumping with large step sizes because they are unaware of the step size they are asked to jump forward. Since they are only trained to comply with small step sizes, they only learn the tangents of the curved vector field, not the "correct direction" when a large step size is used.</p>
<p>Based on this insight, on top of <span class="arithmatex">\(x_t\)</span> and <span class="arithmatex">\(t\)</span>, shortcut models additionally include step size <span class="arithmatex">\(d\)</span> as part of the condition for the denoiser network. At small step sizes (<span class="arithmatex">\(d\rightarrow 0\)</span>), the model behaves like a standard flow-matching model, learning the expected tangent from noise to data. For larger step sizes, the model learns that one large step should equal two consecutive smaller steps (self-consistency), creating a binary recursive formulation. The model is trained by combining the standard flow matching loss when <span class="arithmatex">\(d=0\)</span> and the self-consistency loss when <span class="arithmatex">\(d&gt;0\)</span>:
$$
\mathcal{L} = \mathbb{E} [ \underbrace{\| s_\theta(x_t, t, 0) - (x_1 - x_0)\|^2}_{\text{Flow-Matching}} +
$$</p>
<div class="arithmatex">\[
<p>Based on this insight, on top of <script type="math/tex">x_t</script> and <script type="math/tex">t</script>, shortcut models additionally include step size <script type="math/tex">d</script> as part of the condition for the denoiser network. At small step sizes (<script type="math/tex">d\rightarrow 0</script>), the model behaves like a standard flow-matching model, learning the expected tangent from noise to data. For larger step sizes, the model learns that one large step should equal two consecutive smaller steps (self-consistency), creating a binary recursive formulation. The model is trained by combining the standard flow matching loss when <script type="math/tex">d=0</script> and the self-consistency loss when <script type="math/tex">d>0</script>:
<script type="math/tex; mode=display">
\mathcal{L} = \mathbb{E} [ \underbrace{\| s_\theta(x_t, t, 0) - (x_1 - x_0)\|^2}_{\text{Flow-Matching}} +
</script>
</p>
<p>
<script type="math/tex; mode=display">
\underbrace{\|s_\theta(x_t, t, 2d) - \mathbf{s}_{\text{target}}\|^2}_{\text{Self-Consistency}}],
\]</div>
<div class="arithmatex">\[
</script>
</p>
<p>
<script type="math/tex; mode=display">
\quad \mathbf{s}_{\text{target}} = s_\theta(x_t, t, d)/2 + s_\theta(x'_{t+d}, t + d, d)/2 \quad
\]</div>
<div class="arithmatex">\[
</script>
</p>
<p>
<script type="math/tex; mode=display">
\text{and} \quad x'_{t+d} = x_t + s_\theta(x_t, t, d)d
\]</div>
</script>
</p>
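<p>A sketch of how the two loss terms can be assembled, following the equations above rather than the paper's exact training code; the network s_theta, the batch tensors, and the step-size sampling are placeholders.</p>
<pre><code class="language-python">import torch

def shortcut_loss(s_theta, x0, x1, t, d):
    """Flow-matching term at step size 0 plus self-consistency term at step size 2d."""
    xt = t * x1 + (1 - t) * x0                       # point on the noise-data path
    flow_match = (s_theta(xt, t, torch.zeros_like(d)) - (x1 - x0)).pow(2).mean()

    with torch.no_grad():                            # the two-small-step target is not back-propagated
        s1 = s_theta(xt, t, d)
        x_next = xt + s1 * d                         # x'_{t+d} = x_t + s_theta(x_t, t, d) d
        s2 = s_theta(x_next, t + d, d)
        s_target = (s1 + s2) / 2                     # average of the two consecutive small steps
    self_consistency = (s_theta(xt, t, 2 * d) - s_target).pow(2).mean()

    return flow_match + self_consistency
</code></pre>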
<figure class="figure">
<img alt="image-20250504180714955" src="/blog/md/one-step-diffusion-models.assets/image-20250504180714955.png" / class="figure-img img-fluid rounded">
<figcaption class="figure-caption">Illustration of the training process of shortcut models. <em>Source: Frans et al., “One Step Diffusion via Shortcut Models.”</em></figcaption>


@@ -21,6 +21,21 @@
processHtmlClass: 'arithmatex'
}
};
window.addEventListener('load', function() {
document.querySelectorAll('script[type^="math/tex"]').forEach(function(script) {
const isDisplay = script.type.includes('mode=display');
const math = script.textContent;
const span = document.createElement('span');
span.className = isDisplay ? 'mathjax-block' : 'mathjax-inline';
span.innerHTML = isDisplay ? `\\[${math}\\]` : `\\(${math}\\)`;
script.parentNode.replaceChild(span, script);
});
if (typeof MathJax !== 'undefined' && MathJax.typesetPromise) {
MathJax.typesetPromise();
}
});
</script>
<style>
a {