Score-based SDE 生成模型从入门到出师系列(一):用随机微分方程建模图像生成任务并统一分数和扩散模型

极市平台

共 92853字,需浏览 186分钟

 · 2024-04-14

↑ 点击蓝字 关注极市平台
作者丨CW不要無聊的風格
编辑丨极市平台

极市导读

 

超详细解读如何使用随机微分方程(SDE)框架来统一分数模型和扩散模型(DDPM)进行生成建模。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

一上来先吹水

久违的这篇文章终于来了!本来在去年春节写完 ScoreNetwork(分数模型) 那篇文章时,就有计划要接着解析这篇以 SDE(随机微分方程) 框架统一分数模型和扩散模型(DDPM)的佳作——Score-Based Generative Modeling through Stochastic Differential Equations,奈何当时看完 paper 后,感觉自己有许多地方并不理解,特别是其中涉及的数学推导,比如:

  • reverse-time SDE 的公式(论文公式(6)&(16))是怎么推导出来的
  • 作者提出的三种 SDE 形式所对应的 perturbation kernel 的均值与方差(论文公式(29))是怎么计算的
  • 评估似然(likelihood)的指标 bpd(bits per dimension)在计算时最后为何通常要加上 offset(通常是  log 2 ( 256 ) = 8  或  log 2 ( 128 ) = 7  ))

...

奔着互联网是个好东西的想法,CW 第一时间就在全網“搜刮”答案,结果却令我吃瘪,不仅没有让我 get 到满意的答案,甚至连讨论以上几个问题的都寥寥无几..

吾辈自强!靠别人不行,那就唯有靠自己了。于是就在概率统计和随机过程那片地里好好地玩了一段时间,甚至不小心还差点跨界到了物理学那边(那可是大佬们的地盘,不敢轻易踩线..)。

在这闭关修炼期间,我接连发了解析 Neural ODE(https://zhuanlan.zhihu.com/p/644097554) 和 Bayesian Flow Networks(https://zhuanlan.zhihu.com/p/659111986) 的文章,以定期检验自己的三脚猫功夫处于第几层。后面,我回过头来看 SDE 这篇 paper,感觉一切都明朗了,之前脑海中的迷雾已烟消云散,感觉甚是良好,怡然自得~

小时候老师就教导 CW 要与大家团结友好互助,我深刻地不敢忘记,并且努力贯彻着,如今我就把自己的理解与江湖上各位朋友分享,在自己揭开迷雾后也希望能为大家扫除障碍尽一份力,对于论文中涉及但却没有讲明白的数学推导,我一定会在文中摊开来解析(杜绝默认大家都懂然而自己都不懂这种现象)。但如果是已有其它出处给出非常详细的过程,那么我就会指明出处(比如贴链接)而不在文中赘述。同时,也欢迎大家一起热烈交流。当然,要友好哦(但可以不那么正经)~

来点正经的——文章内容安排

由于这篇 paper 值得仔细推敲的细节众多,因此 CW 计划分三篇进行解析。本文作为首篇,会把大方面给覆盖掉,主要将目光聚焦于 SDE 对图像生成任务的建模方式 以及 分数模型和扩散模型是如何被统一到该框架下的

为了贯彻不无聊的风格,三篇文章都会在吹水(理论解析)的同时会结合代码实现,每讲完一部分就立马把对应的代码给撸了,以达到 from scratch & step by step 的效果。

至于另外两篇文章的主题,现在先暂时埋个坑(哦,两篇文章是两个坑才对),但是该有的会有,绝不拉下!

这篇论文讲了什么

在 CW 细嚼慢咽并反复斟酌后,认为全篇可归纳为以下四方面:

i. 以 SDE 视角来看待图像生成问题

相比于以往生成模型的做法,这篇 paper 的方法比较“数学&物理 style”—— 使用 SDE(随机微分方程) 来对图像生成任务进行建模。概括起来,这种玩法就是利用 SDE 将复杂的图像分布“平滑地过渡”到简单的先验分布(比如将一张图映射为标准高斯噪声);同时利用 SDE 的逆向形式("reverse-time SDE")将先验分布映射回原来的图像分布

SDE 建模示意图

这两个互逆的 SDE 过程实际在做的事情就是加噪与去噪,这点与 SMLD(Score Matching with Langevin Dynamics, 其实这里指的是 NCSN) 和 DDPM 的套路一样,但通过 SDE 这种视角来建模,在凸显高级感(笑)的同时,还能顺带将 score-based models 和 diffusion models 变为“门徒”,也就是将这两种生成模型的建模方式纳入到 SDE 这个统一的框架下,可谓是非常秀了~

特别地,作者还分别为 SMLD 和 DDPM 所对应的 SDE 形式起了名字,前者名曰 _"Variance Exploding(VE) SDE";_后者名为 _"Variance Preserving(VP) SDE"_。另外,在 VP SDE 的基础上,作者还额外提出了一种叫 "sub-VP SDE" 的形式,它通常能够取得更好的采样质量和更高的似然(likelihood),可以认为其是 VP SDE 的改进版。

ii. 提出一种含校正机制的新采样方法

由于模型是跑在计算机上,因此在采样生成样本时就会用数值方法去解逆向 SDE,这种做法本质上是以离散近似连续,势必会引入误差。幸运的是,求解逆向 SDE 的过程依赖于 score,这就使得我们可以在使用逆向 SDE 采样的基础上加入 score-based MCMC 这类采样方法(因为这类方法也依赖于 score),比如 SMLD 中用到的 Langevin dynamics。

Langevin dynamics

所谓人多力量大嘛,于是拍脑袋就会觉得综合两种采样方法会比仅使用一种来得强。同时又因为 SDE 在该论文中是一种主角的存在,所以在采样过程的每个时间步(time step)会以逆向 SDE 的求解结果作为主要的预测(采样)结果,然后再基于这个结果使用 score-based MCMC 作进一步校正。这么搞法,逆向 SDE 就相当于一种 "predictor" 角色;而 score-based MCMC 则承担起 "corrector" 的责任,是不是非常形象~

综上,作者就将这种采样方式命名为 _"Predictor-Corrector(PC) sampling"_。

iii. 使用 ODE 加速采样 & 计算似然

作者还惊喜地发现,用于建模的 SDE 还有等价的 ODE(常微分方程) 形式。直观来看,ODE 是一种特殊形式的 SDE,它相当于把 SDE 中的随机项给去掉了(可以当作是其随机项的系数为0),从而成为一种“确定性的轨迹”。对比起逆向 SDE,如果在采样时使用这个等价的 ODE 形式,通常会有更高的采样效率, 因为目前市面上已有许多成熟的黑盒 ODE 求解器,它们可是经过业界大佬们优化的,计算效率当然杠杠滴!

另外,使用这个 ODE,还**能够对模型生成的样本计算似然(likelihood)**,这就非常香了,因为似然虽然作为模型生成效果的评估指标之一,但并非每种生成模型都有能力去计算它(说的就是你——GAN)。可以简单认为,似然越高,模型越靠谱——生成的图像越符合真实图像的分布。

iv. 广泛且通用的条件生成能力

在 SDE 的建模方式下,还能够很方便地拓展至各种条件生成的场景。在大多数时候,模型在无条件生成的情况下完成训练后,就能够直接在条件生成的场景下进行推理(采样生成),无需重新训练,这种适配能力包括图像修复(image impainting)和图像上色(colorization)等。

图像上色

或许..可以拿来祛痘?

以 SDE 的视角进行生成式建模

接下来一起深入了解下在 SDE 的视角下,模型的训练与推理(采样)流程到底是如何运作的,具体包括模型的输出如何适配到 SDE(是个方程) 中、模型的学习目标是什么、损失函数如何设计以及与 SDE 相关的数学推导。

Itô SDE:扩散(加噪)过程的抽象表示

以往像 SMLD 和 DDPM 这些老流派在训练时加噪的方式是在每个时间步对图像施加高斯噪声:

以物理学的角度来看,以上过程本质上是在离散时间 内的扩散过程,由一开始“稳定的状态”不断叠加噪声(随机性)直至完全混乱(纯噪声)。

加噪示意图

而这篇论文的做法就是将这个扩散过程拓展至连续时间 ,从而就可以用伊藤过程(Itô SDE) 来建模:

也就是说, 顺着以上公式玩下去, 由时间 , 就能将一幅图像变成纯噪声图。在这个过程中图像 将成为时间变量的函数: , 其中初值 是原始图像, 代表原始图像分布; 终值 是高斯噪声, 代表标准高斯分布。

基于连续时间的随机过程进行加噪

在 (iii) 式所示的伊藤 SDE 中 被称为 的 “漂移系数(drift coefficient)", 则被称为 的“扩散系数(diffusion coefficient)"。需要注意的是, 作者为了简化这个建模过程, 就将扩散系数设置为标量(scalar)函数, 但实际上它可以是 这样的矩阵值函数, 这是更 general 的形式

最后,剩下的那项 代表_标准维纳过程(布朗运动)_,这种运动过程的变化服从正态分布,方差随时间间隔的长度线性增长。简单粗暴来说,就是以下这个意思:

布朗运动示意图

另外,这种运动过程在任意不同时间区间上变化的概率分布相互独立。比如 就是相互独立的。

(iii)式是伊藤 SDE 的通用表达形式,设置不同的漂移系数和扩散系数将导致不同具体形式的 SDE,SMLD 和 DDPM 就恰好是两种不同 SDE 的数值离散形式(莫急,后文会详谈)。

Reverse-Time SDE: 采样(去噪)过程的抽象表示

在扩散过程中时间变量是 0→T0 \rightarrow T0 \rightarrow T 正向递增的,对应建模的是网络的训练过程,这个过程将图像变成纯噪声。那么,推理(采样)过程就应该是它的逆过程,从而将纯噪声“变回”像原图那样有意义的图像**(并非变回和训练时一模一样的原图,否则就不是生成模型而是具有纯压缩功能的 AE(http://dl.acm.org/doi/10.5555/2987189.2987190) 了)。**

去噪(生成)示意图

直觉告诉我们(直觉演得好!),采样过程应该也会有一个对应的 SDE 形式。年轻人敢想就是好事!这不,古老的这篇 paper(http://www.sciencedirect.com/science/article/pii/0304414982900515) 指出,形如 (iii)式的扩散过程所对应的逆过程也是一个伊藤过程,并且由以下逆向时间(reverse-time) SDE 所表示:

这表明, 顺着以上公式不断迭代计算, 就能够由纯噪声生成有意义的图像。请注意, 这个逆向 的时间“流动方向”是 , 因此这里 是绝对值无穷小的负值。 其中 的概率密度 (说明概率密度会随时间变化), 从而 就是 score 了。 则对应与 (iii) 式的一模一样, 而 也同样是标准维纳过程, 只不过其运动过程中的时间值是 逆向“流动”的, 所以在 头上加了一条横杠, 以和 (iii) 的 作区分。

CW 在以往的文章(https://zhuanlan.zhihu.com/p/597490389)中解析过,score 与 噪声的方向相反(两者都是向量),且两者成倍数关系,因此 reverse-time SDE 就是一个动态的去噪过程。

基于逆向连续时间的随机过程进行去噪(生成)

有朋友或许要着急了——为何会有 (iv) 式这样的 reverse-time SDE,怎么证明的?乖,先忍忍,后文会谈的~

Score: 网络学习的目标

到目前为止,你貌似还不清楚网络的训练目标(loss 函数)是什么,毕竟盯着与训练过程所对应的 (iii)式只感觉一头雾水,对吧(别虚,CW 也是)..

当你在一条路子上走不通时, 不妨慢下来, 扭头看看旁边的道路一 (iv) 式 ( (iii) 旁边自然就是 (iv) 嘛 ), 容易看出, 它依赖于 时刻下的 score 一 , 也就是没有 score 的话 (iv) 式就玩不下去了。

而 score 从哪里来? 从已知条件中没法来, 因为 这个边缘分布, 它需要“考虑周全”原始图像 的各种情况才能计算出来, 更何况在采样时我们是不知道 的(提前知道了还采样个 哦!),因为它是 reverse-time SDE 想要努力达成的 KPI,也是采样的“终点”。

在深度学习这种大力出奇迹的时代, 凡是不知道的事情就让神经网络来估(瞎)计(猜)好了。自然地, 在这种玩法下, 训练目标就是 score 了。让模型学会估计 score 后, 就可以根据 (iv) 式, 结合模型的估计值 使用下式来进行完成采样生成:

问题来了! 如何让模型靠谱地估计(而非瞎猜) score 呢? 本质上这个问题就是要设计合适的 loss 函数让模型学会预测 score。可是刚才又说到边缘分布 是 intractable 的, 那么就会造成真实的 score 一 无法计算, 学习目标都不明确, 你让模型学个毛线.. KAO! 这样一来这个问题岂不是无解了?

别方!解法就藏在:

图像生成别只知道扩散模型(Diffusion Models),还有基于梯度去噪的分数模型:NCSN(Noise Conditional Score Networks)(https://zhuanlan.zhihu.com/p/597490389)

在以上那篇文章中,CW 介绍了几种可以在不计(知)算(道)真实 score 的情况下又可以让模型学会估计 score 的方法,其中最“平民化”的方法当属 sliced score matching 和 denoising score matching 了,后者也正是 NCSN(现在经常被叫作 SMLD) 所用的招数。

SMLD 的多尺度 denoise (采样)过程

假如我们使用 denoising score matching,那么 loss 函数就是:

注意, (v) 式用的是条件分布 , 在 SMLD/DDPM 的场景下, 它正是扩散过程中为加噪而显式构造的分布, 比如在 NCSN 中, 它就是 , 概率密度函数是有解析形式的, 于是这个“条件 score" - 就很容易计算了。

是权重参数, 根据 NCSN 的结论, 它通常设置为 ; 至于时间变量 , 则从连续的均匀分布 中采样。

另外, 当漂移系数 和扩散系数 不是 affine function 时(NCSN 和 DDPM 所对应的 SDE 形式中, 漂移系数和扩散系数都恰好是 affine function), 条件分布 可能会没有解析形式, 这时候就可以转为采用 sliced score matching 从而避免计算条件 score, 此时 loss 函数对应为:

其中 是服从均值为 0 、协方差矩阵为单位阵 的某种简单分布(比如标准正态分布)的随机向量。借助诸如 torch.autograd 这类自动微分框架, 这部分也是很容易进行计算的。

有关以上两种 loss 形式的详细推导及计算方法均在 CW 前面贴的那篇文章中有详细解析,没看过的是不是忍不住想要看看了(doge)~

更 General 的 SDE 形式

方才 CW 提到, (iii) 式是伊藤过程的简化形式, 而在更 general 的形式下, 扩散系数是与 相关 (而非仅仅与时间 相关)的矩阵值函数一 , 从而 (iii) 式就对应变为:

相应地,其逆向 SDE(即 (iv) 式)则变为:

其中 nabla 算子(https://www.zhihu.com/question/58797500) 作用于矩阵 的每个 维向量, 即每个向量的分量对相应维度的 的分量求偏导后再求和。将 记为矩阵 ,那么:

其中 , 于是最终 是一个 维向量。

Reverse-Time SDE 是怎么来的

对于求知欲旺盛的朋友们来说,到目前为止还不知道 reverse-time SDE 是怎么来的肯定非常纳闷(甚至恼火)。是的,CW 当时在看这篇 paper 的时候也是同样的心情,翻遍了全网也没有看到哪位好心人帮我破解这个问题(提出 reverse-time SDE 的那篇 paper 太过正经,实在看不下去..)。无奈之下,只能自力更生(揾食艰难呀~),经历一番头脑风暴后,好不容易算是捣鼓出了一个勉强令自己信服的答案,具体都记录在 CW 对下面这个问题的回答之中了:

https://www.zhihu.com/question/629085800/answer/3421808487

推导的基本思路是利用Fokker–Planck–Kolmogorov(FPK) 方程,根据概率密度对时间的导数与漂移和扩散系数之间的关系凑成等价形式的 SDE。

Coding Time

兄弟(或许还有靓女)们,前面 CW 对 SDE 的建模方式进行了一番介绍,现在是时候贯彻不无聊的风格——写代码了!我们先就前面提及的主要内容进行代码实现,在后文吹水时若发现还有相关的重要内容,再补充对应的实现,一步步来,以实现迭代更新(coder 都是这样玩的嘛~)。

    import abc
import torch

from typing import Callable, Union, Tuple

from torch import Tensor
from torch.nn import Module

class SDE(abc.ABC):
    def __init__(self):
        super().__init__()

    @property
    @abc.abstractmethod
    def T(self) -> int:
        """ 正向 SDE 的终止时刻, 整个过程时间的流动方向是 0 -> T """
        pass

    @abc.abstractmethod
    def sde(self, x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """ 计算漂移和扩散系数: f, g """
        pass
    
    @abc.abstractmethod
    def p_0t(self, x: Tensor, t: Tensor):
        """ 计算决定条件分布 p(x(t) | x(0)) 的参数,这里计划会返回均值和标准差 """
        pass
    
    def prior_sampling(self, shape) -> Tensor:
        """ 从先验分布 p_T(x) 中采样(作为采样起点),先验通常为标准高斯分布 """
        return torch.randn(*shape)
    
    def reverse(self, score_fn: Union[Module, Callable]):
        """ 构建逆向 SDE, 返回代表 reverse-time SDE 的对象 """

        T = self.T
        # 用于计算正向 SDE 的漂移和扩散系数的函数
        fw_sde = self.sde

        class RSDE(self.__class__):
            @property
            def T(self) -> int:
                return T
            
            def sde(self, x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
                # 正向 SDE 的漂移和扩散系数
                f, g = fw_sde(x, t)
                score = score_fn(x, t)

                # 根据正向 SDE 的漂移和扩散系数计算逆向 SDE 的漂移系数
                f = f - g[:, None, None, None] ** 2 * score

                return f, g
        
        return RSDE() 

正、逆两个方向的 SDE 实现如上,此外还额外封装了两个方法:p_0t()prior_sampling(), 前者在训练过程中为数据加噪提供了便利(在下面一段代码中就可以看到);而后者的返回结果则可以作为采样起点。

接着把 loss 函数也搞定,我们就选用 denoising score matching 来训练吧~

主要过程是先采样时间变量和高斯噪声,然后根据条件分布 对数据进行加噪,接着模型根据含噪声的数据和当前时间估计 score,最后根据公式 (v)计算 loss。

    def sde_loss_fn(sde: SDE, score_fn: Union[Module, Callable], data: Tensor, eps: float = 1e-5) -> Tensor:
    """ loss 函数, 其中时间变量是连续数值而非离散的时间步 """

    bs = data.size(0)

    T = sde.T
    # 时间变量从连续的均匀分布中采样
    # 这里做了特殊处理,使得最小值为 eps 而非 0,
    # 有助于稳定训练效果
    t = torch.rand(bs, device=data.device) * (T - eps) + eps

    # 从标准高斯分布中采样噪声
    noise = torch.randn_like(data)
    mean, std = sde.p_0t(data, t)
    # 生成加噪后的数据, 其服从均值为 mean, 标准差为 std 的高斯分布
    perturbed_data = mean + std[:, None, None, None] * noise

    # 模型根据含噪声的数据及当前时间估计出对应的 score
    score = score_fn(perturbed_data, t)
    # loss 函数化简后的形式, 计算出 loss 后独立在每个样本的所有维度求平均
    loss = ((score * std[:, None, None, None] + noise) ** 2).reshape(bs, -1).mean(dim=1)

    # 最后返回所有样本的 loss 均值
    return loss.mean() 

注意, 倒数第 2 行是公式 的化简结果, 权重参数 设置为条件分布 的方差, 也就是代码中的 std ** 2 。由于加噪使用的是标准高斯噪声, 因此加噪后数据也服从高斯分布, 并且均值和标准差就是以上代码中的 mean, std, 于是代入概率密度公式进行计算, 得到条件 score 即 就恰好是负的噪声除以标准差: -noise / std , 最后将权重参数作用上去就能得到代码中的化简结果了。

Tips: 对于这方面仍不太理解的朋友们,可以参考下前面贴出来的 CW 解析 NCSN 那篇文章。

同时还需要注意的是,时间变量t的最小值并非是0,而是取一个极小值 eps,具体原因待后文揭晓。至于 score_fn,它可以是一个函数,也可以是一个类,只要其中封装了模型的前向过程(foward)能够输出对 score 的预测值即可。

由离散 Markov Chain 至连续 SDE

Markov Chain -> SDE

前面 CW 说过,SMLD 和 DDPM 的加噪方式(离散的马尔科夫链)其实是两种不同 SDE 的数值离散形式。当然,这不是瞎说的,论文里有推导过程。但是,其中的推导比较简略,特别是将 SMLD 和 DDPM 推广至 SDE 形式时所对应的 perturbation kernel( 即 ) 的推导,会让你感觉有种看了等于没看的感觉..

基于这种尴尬的局面,CW 的 motivation 油然而生——将这部分的推导过程扒得明明白白,与大家坦诚相待。

SMLD -> VE SDE

SMLD 在每个时间步加噪的样子如 (ii) 式, 总共使用 步, 这是离散的形式; 如果将其拓展到连续的形式, 并且能表示无限个细粒度状态(而非直接从 跨度1个整数那么大), 那么就有 , 原来离散的时间步变成连续的浮点值: , 从而 。于是整个马尔科夫链 将成为连续的随机过程 , 同时噪声标准差 变成时间的函数 。同理, 随机噪声 也对应变为

现在, 让我们将原来状态之间的跨度从 改为从 , 其中 , 刚好等于 , 于是 perturbation kernel 即 (ii) 式就变为:

,上式就转变成:

最后一步是基于布朗运动的性质,因为 , 以上得到的结果就是 VE 。进一步, 根据 (vi) 式, 令 , 就能得到:

这就是 perturbation kernel 所服从的分布。由于 SMLD 中使用的 是递增的几何序列(等比数列), 因此容易发现, 随着 , perturbation kernel 的方差会愈发变大, 以至于 “爆炸" (explode), 这就是为何 SMLD 所对应的 SDE 被称为 "VE(Variance Exploding) SDE" 的原因。

其实, 这么推导 perturbation kernel 是不严谨的, 更给力的应该是根据这篇 paper 的公式 (5.50) & (5.51) 来推导 (这里为了方便理解, 就不再根据这两个公式来推导了, 有兴趣的玩家们可以尝试一下, 不难, 没有接下来 VP SDE 和 sub-VP SDE 的 perturbation kernel 的推导难)。这两个公式讲的是, 对于形如 (iii) 式的伊藤 SDE, 记 的期望和协方差矩阵分别为 , 那么这两者随时间的变化则是:

在接下来推导剩余两个 SDE 的 perturbation kernel 所服从的分布时,CW 会拥抱这两个公式的大腿。至于这两个式子的证明过程,对应的论文里有,而且写得也算明白,这里就不再赘述了。

DDPM -> VP SDE

DDPM 在每个时间步加噪的离散马尔科夫链如 (i) 式, 同理, 我们要将其拓展至连续时间的情况: ; 接着, 作者比较调皮地玩了个 trick, 额外定义了一个辅助集 ,使得 ; 然后, 同样让 对应起来, 并且 (注意, 这里的 已经不是原来的了, 而是 在连续时间上的表示); 最后, 如前面的情况, , 于是 (i) 式就变为:

将等式右边第一项的 看作函数 。当 时, 就有 , 于是就可以对 处进行泰勒展开。我们要求低一些, 将其展开到一阶即可, 于是:

别忘了回来——将 代入,当 时,就能继续在前面的式子上得到:

同样, 最后一步 的出现也是源于布朗运动的性质, 这便得到了 VP SDE, 其漂移系数 , 扩散系数

接下来推导 perturbation kernel: 所服从的分布。首先明确的是,这个分布是个高斯分布,因为在离散时间步的情况下,DDPM 加噪时的这个分布就是高斯分布,现在拓展到连续时间上面,可以认为是无穷多个细粒度的高斯分布的叠加(类似于混合高斯分布),其结果依然会是个高斯分布,而现在的关键是要确定这个分布的均值和方差。

那么如何确定呢?前面 CW 说到要利用公式 (5.50) & (5.51),所以就得解那两个微分方程,我们不妨先来关注方差,但在此之前,需要先引入一个很给力的结论。

由于 , 因此乘上常数 后其结果依然是 0 , 即:

从而:

将 VP SDE 的漂移系数和扩散系数代入上式后一并再代入到式 (5.51),就会得到:

以上就是 perturbation kernel 方差随时间变化的微分方程,使用分离变量法可解:

其中 均表示常数(后文均以此约定俗成)。注意第 2 步为了方便表示为 时刻的解, 于是将积分微元换为 。进一步将 的初值代入, 可得到常数 , 从而解得:

容易看出, 由于 , 因此 会介于 之间, 这就代表它是有界的,会一直“维持” (preserve) 在一定范围内 (而不会像 VE SDE 的方差那样发散), 这就是为何这个 SDE 被称为 "V(Variance) (Preserving) SDE" 的原因。特别是, 若 的话, 那么 就会恒等于 , 这就更能体现出 "preserve" 的意思了。

接下来推导 perturbation kernel 的均值,根据式 (5.50),有:

同样利用分离变量法,并结合初值 ,解得:

综合以上,若规定 ,那么 perturbation kernel 所服从的分布就是

sub-VP SDE

Inspired by VP SDE,作者进一步提出了 sub-VP SDE:

这个形式也不知道是怎么造出来的,它除了扩散系数与 VP SDE 不同之外,其余部分与 VP SDE 长得一模一样(好一个 inspired by..)。sub-VP SDE 为何长这样我们不管,我们要做的正经事是根据式 (5.50)&(5.51) 来推导出 perturbation kernel 的均值与方差。

由于 sub-VP SDE 与他哥(VP SDE)的漂移系数长得一样,而式 (5.50) 又仅与漂移系数相关,因此 perturbation kernel 的均值也自然会与他哥的一样,即:(白嫖成功,YEAH!!)

成功拿下一个目标之后,我们紧接着来推导方差 。根据式 (5.51) 以及前面所使用的 ** 式,可得:

Damn! 这下可不能再用分离变量法了, 咋办.. 冷静下来, 一不小心你就发现这其实是一阶非齐次线性 ODE 的形式: , 只不过在我们这里 。而一阶非齐次线性 ODE 是有通解形式的, 所以我们不妨先来研究下通解, 然后再转过头来将我们这里的具体形式代进去得到目标解。

先将一阶非齐次线性 ODE 变形为 , 同时在等式右边最后一项中, 将 表示为 的函数。这么一来, 就实现了“分家”一等式左边全是 “家族”、而右边则全是 “家族"。哈! 不用我说你应该也想到了, 这时候又可以使用前面的老套路—分离变量法了:

将上面得到的 y 代入原 ODE:

在顺利得到这个通解形式后,根据前面的对应关系将我们这里实际的形式代入,得到:

进一步将初值 代入, 从得到常数 。现在这种程度还不行,必须得把等式右边第一项那一大串给干掉,不然着实看着难受.. 先来看看其中的积分项能不能简化一下:

咦, 这样看着感觉还挺有规律, 都是 及其积分(你看自然底数的幂和整体的一个大积分)。抖一机灵, 来个换元法试试一令 , 则 , 从而上面的积分项变为:

干脆将上面 表达式中的 都表达为 , 于是整个表达式就变为:

还原为以 来表示, 最终得到 为:

与 paper 中公式(28)完全一致。若初值 , 则上式其实是一个完全平方: 。综上, sub-VP SDE 的 perturbation kernel 为:

恭喜你们!本文最艰苦的阶段已顺利渡过~

三个 SDE 及其 perturbation kernel 的具体表达式

享受完数学游戏后,是时候开启代码游戏了。不过,在正式写代码实现之前,还有部分预备工作需要搞一下~

前面几节 CW 和大家一起推导了三种 SDE 所对应的 perturbation kernel,总结起来就是论文中公式(29)的结果:

perturbation kernels

但是, 其中加噪所使用的 noise scale 到底是什么呢? 想必各位帅哥靓女也想到了——直接沿(白)用(嫖) SMLD 和 DDPM 的方案就好, 只不过将它们原先是离散的 换上另外一套衣服变为现在连续的 :

sub-VP SDE 呢?当然是跟着他哥的节奏——使用同样的 啦!

将以上两个式子代入三个 SDE 的表达式以及 刚才推导出的三个 perturbation kernel(如上面的 (29) 式) 中,就会得到论文公式 (30)~(34) 的结果,即:

  • VE SDE
VE SDE
perturbation kernel of VE SDE
  • VP SDE
VP SDE
perturbation kernel of VP SDE
  • sub-VP SDE
perturbation kernel of sub-VP SDE

细心的你应该注意到了,VE SDE 的最小时间值是大于0的,而其它两位则可以为0,这是为啥捏?

原因是 对应为原始图像分布, 也就是还没有进行加噪, 于是这个值为 0 ; 但由于采用了 SMLD 的加噪方案 (见前面 的式子), 根据该方案的设置, 在 时, ( 代表 SMLD 的 noise scales 中最小的噪声强度), 因此就会导致 不是连续可微的, 进而导致 VE SDE 在 时没法玩(undefined)。基于这个现象, 在实操写代码时, 就需要将时间的最小值“截胡"至某个趋近于 0 的极小值 , 作者选取了

另外还有个坑! 那就是作者在实验中发现, 不论是训练还是采样, VP SDE 在 时均会出现数值不稳定的现象, 于是也学 VE SDE 那样(老子一把梭 ! ) 设置一个正极小值并将其作为时间变量可取的最小值。在实现时, 作者选取了

至于你问 sub-VP SDE ? 还用说吗! ? 他哥(VP SDE)都那样子了, 他肯定也得随着改变呀! 也就是同样使用

另外,作者还提到,使用较小的 进行训练通常最终能导致模型获得较高的似然(likelihood);然而在采样时却要酌情选取合适的 (也就是不是一昧地设小)才可以获得较好的 IS(Inception Score) 和 FID(Fréchet Inception Distance)。深度学习,玄之又玄~

动手深度学习

OK,预备工作做足了,现在可以真正开撸了(兴奋不)!

  • VE SDE

先来干掉 VE SDE!既然是 SDE,它就得继承上一章所实现的 SDE class。另外,由于 VE SDE 是受 SMLD 启发而来的,因此这里还额外实现了该 SDE 的数值离散形式,即 SMLD 加噪(扩散)过程的马尔科夫链。

另外,需要注意下其先验分布并非是标准正态分布,而是以 为方差的 0 均值正态分布,因此需要重载(overwrite)父类所对应的方法prior_sampling()

    import numpy as np

class VESDE(SDE):
    def __init__(self, sigma_min: float = 0.01,sigma_max: float = 50., N: int = 1000):
        super().__init__()

        # 最小的 noise scale
        self.sigma_min = sigma_min
        # 最大的 noise scale
        self.sigma_max = sigma_max

        # NCSN 的 N 个 noise scales
        self.N = N
        # 幂形成等差数列, 则最终结果就是等比数列
        self.discrete_sigmas = torch.exp(
            torch.linspace(np.log(sigma_min), np.log(sigma_max), N)
        )

    @property
    def T(self) -> int:
        # 在由离散的 SMLD 拓展至连续的 VE SDE 后,
        # 时间的 t 的取值范围为 [0,1]
        return 1
    
    def sde(self, x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """ 计算 VE SDE 的漂移和扩散系数, 对应论文公式(30) """

        sigma_t = self.sigma_min * (self.sigma_max / self.sigma_min) ** t

        f = torch.zeros_like(x)
        g = sigma_t * torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=x.device).sqrt()

        return f, g

    def p_0t(self, x_0, t) -> Tuple[Tensor, Union[float, Tensor]]:
        """ 计算 VE SDE 的 perturbation kernel 均值和标准差, 参照论文公式(31) """

        return x_0, self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    
    def prior_samping(self, shape) -> Tensor:
        """ 先验分布是 $\mathcal N(0, sigma_max^2 I)$ """

        return torch.randn(*shape) * self.sigma_max
    
    def discretize(self, x: Tensor, t: Tensor):
        """ VE SDE 的数值离散形式, 即 SMLD 加噪的马尔科夫链, 对应论文公式(8) 
 相当于 sde() 方法的离散版本 "
""

        # 将当前连续的时间变量转换为离散的时间步
        timestep_i = (t / self.T * (self.N - 1)).long()
        sigma_i = self.discrete_sigmas.to(x.device)[timestep_i]
        # $\sigma_{i-1}$
        adj_sigma = torch.where(
            timestep_i == 0, 
            torch.zeros_like(sigma_i), 
            self.discrete_sigmas.to(sigma_i.device)[timestep_i - 1]
        )

        # 因为将 SMLD 的马尔科夫链看作是 VE SDE 的数值离散化过程,
        # 所以这里依照伊藤 SDE 的惯例返回漂移和扩散系数
        f = torch.zeros_like(x)
        g = (sigma_i ** 2 - adj_sigma ** 2).sqrt()

        return f, g
  • VP SDE

接下来轮到 VP SDE,所要实现的内容与前面的 VE SDE 类似,只不过这里不需要重新实现 prior_sampling() 方法,因为其先验就是标准高斯分布,与父类 SDE的一致。

    class VPSDE(SDE):
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20., N: int = 1000):
        super().__init__()

        # $\bar{\beta}$ VP SDE 的 \beta(0) 和 \beta(1)
        self.beta_0 = beta_min
        self.beta_1 = beta_max

        # DDPM 的 N 个 noise scales ${\beta_i}$
        # 与 VE SDE 的 \bar{\beta} 的关系是:
        # $N \beta = \bar{\beta}$
        self.N = N
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)

        # DDPM 加噪过程中使用的 \alpha_i
        self.alphas = 1 - self.discrete_betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = self.alphas_cumprod.sqrt()
        self.sqrt_1m_alphas_cumprod = (1. - self.alphas_cumprod).sqrt()

    @property
    def T(self) -> int:
        return 1
    
    def sde(self, x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """ 计算 VP SDE 的漂移和扩散系数, 对应论文公式(32) """

        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)

        f = -0.5 * beta_t[:, None, None, None] * x
        g = beta_t.sqrt()

        return f, g
    
    def p_0t(self, x_0: Tensor, t: Tensor) -> Tuple[Tensor, Union[float, Tensor]]:
        """ 计算 VP SDE 的 perturbation kernel 的均值和标准差, 参考论文公式(33) """

        exponential = -0.25 * t ** 2 * \
            (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        
        mean = torch.exp(exponential[:, None, None, None]) * x_0
        std = (1. - torch.exp(2. * exponential)).sqrt()

        return mean, std
    
    def discretize(self, x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """ VP SDE 的离散过程, 即 DDPM 加噪的马尔科夫链过程, 对应论文公式(10) """

        timestep_i = (t / self.T * (self.N - 1)).long()

        sqrt_beta = self.discrete_betas.to(x.device)[timestep_i].sqrt()
        sqrt_alpha = self.alphas.to(x.device)[timestep_i].sqrt()

        # 因为将 DDPM 的马尔科夫链看作是 VP SDE 的数值离散化过程,
        # 所以这里依照伊藤 SDE 的惯例返回漂移和扩散系数
        f = (sqrt_alpha - 1.)[:, None, None, None] * x
        g = sqrt_beta

        return f, g
  • sub-VP SDE

VP SDE 之后,就轮到他弟 sub-VP SDE 了。不愧是弟弟,所需实现的部分更少了——不需要实现 discretize() 方法,因为它并非像 VE, VP SDE 一样由离散的马尔科夫链启发而来,并且作者设计这个 SDE 更多是想将其用于 continuous 训练(这算是 CW 个人的读后感~)。

    class subVPSDE(SDE):
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20, N: int = 1000):
        super().__init__()

        self.beta_0 = beta_min
        self.beta_1 = beta_max

        self.N = N
    
    @property
    def T(self):
        return 1
    
    def sde(self, x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """ 计算 sub-VP SDE 的漂移和扩散系数, 对应论文公式(12) """

        beta_t = self.beta_0 * t * (self.beta_1 - self.beta_0)
        f = -0.5 * beta_t[:, None, None, None] * x

        # 将 $\beta(t) = \beta_0 + t(\beta_1 - \beta_0)$ 代入论文公式(12)计算出自然底数的幂
        coeff = 1. - torch.exp(
            -2 * self.beta_0 * t
            -(self.beta_1 - self.beta_0) * t ** 2
        )
        g = (beta_t * coeff).sqrt()

        return f, g
    
    def p_0t(self, x_0: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """ 计算 sub-VP SDE 的 perturbation kernel 的均值和标准差, 参考论文公式(34) """

        exponential = -0.25 * t ** 2 * \
            (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        
        mean = torch.exp(exponential[:, None, None, None]) * x_0
        std = 1. - torch.exp(2. * exponential)

        return mean, std

这里需要特别小心扩散系数的计算(也就是 方法中变量 的计算), 需要自行将 代入论文中的公式(12)去计算积分的结果, 然后再写入对应代码中(对应以上 coeff 变量)。

炼丹:Play with the SDE Score Model

通过前面的内容我们已经把积木搭好了,现在就一起来“炼一炼”这种以 SDE 进行建模的 score model 叭~ 至于数据集,就选最可爱(新手友好)的 MNIST 好了,玩起来比较快。

首先把 train loop 给秒掉,毕竟都是那些套路:

  • 设置超参
  • 构建 dataset 和 data loader
  • 设置 optimizer
  • 不断从 data loader 取出 batched data 喂给模型
  • 将模型输出送给 loss 函数计算 loss 并回传梯度
  • 用 optimizer 更新模型参数
  • 打印一些日志以便观测训练情况是否健康
    from tqdm import tqdm

from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision.transforms import transforms

device = "cuda"

bs = 32  
lr = 1e-4  
n_epochs = 50 

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(
    dataset, batch_size=bs, shuffle=True,
    num_workers=2, pin_memory=True
)

def train(sde: SDE, model: Module, loss_fn: Callable):
    optimizer = Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()

    tqdm_epochs = tqdm(range(1, n_epochs + 1))
    for epoch in tqdm_epochs:
        total_loss = 0.
        num_samples = 0

        for x, _ in data_loader:
            x = x.to(device)
            loss = loss_fn(sde, model, x)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            num_samples += x.size(0)
            total_loss += loss.item() * x.size(0)

        avg_loss = total_loss / num_samples
        tqdm_epochs.set_description(
            f"Epoch:[{epoch}]/[{n_epochs}]; Avg Loss: {avg_loss:5f}; "
            f"Num Samples: {num_samples}"
        )
    
    torch.save(model.state_dict(), f"{sde.__class__.__name__}-{n_epochs}_epochs.pth")

OK,万事俱备,只欠模型!现在就只剩能够预测 score 的模型还未实现了,咱就 old school 一点——采用 U-Net 架构,别搞什么 transformer 了(毕竟它已经足够火);对于时间变量,我们采用傅里叶特征编码的方式将其编码为 embeddings。

    import torch.nn as nn
import torch.nn.functional as F

class GaussianFourierProjection(nn.Module):
    """ Gaussian random features for encoding time steps. """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        # \omega \sim \mathcal N(0, s^2 I), s = 30.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
  
    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class Dense(nn.Module):
    """ A fully connected layer that reshapes outputs to feature maps. """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.dense(x)[..., None, None]

class ScoreModel(nn.Module):
    """ A time-dependent score-based model built upon U-Net architecture. """

    def __init__(self, p_0t: Callable, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.

 Args:
 p_0t: A function that takes time t and gives the standard
 deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
 channels: The number of channels for feature maps of each resolution.
 embed_dim: The dimensionality of Gaussian random feature embeddings.
 "
""

        super().__init__()

        # Gaussian random feature embedding layer for time
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])

        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)

        self.p_0t = p_0t

    def forward(self, x, t):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.embed(t))

        # Encoding path
        h1 = self.conv1(x)
        ## Incorporate information from t
        h1 += self.dense1(embed)
        ## Group normalization
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)

        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)

        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)

        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path with kip connection from the encoding path
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)

        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)

        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)

        h = self.tconv1(torch.cat([h, h1], dim=1))

        # 由于真实 score 的尺度(2-norm)在 1/std 水平,
        # 因此这里用 1/std 来 rescale 模型输出, 就会鼓励模型的输出具有单位尺度
        std = self.p_0t(x, t)[1]
        h = h / std[:, None, None, None]

        return h

打起精神来!以上最后对模型的输出使用 1/std 进行 rescale,这会鼓励模型的输出具有单位尺度,因为根据 NCSN 那篇文章的理论,真实 score 的尺度(l2-norm)在 1/std 水平,所以提前将模型输出除以 1/std 再拿这个整体去优化,模型输出就会维持在单位尺度(因为这个整体会被要求在 1/std 尺度)。

最后,我们就选择 VE SDE 来开启训练吧~!

ve_sde = VESDE()
model = ScoreModel(ve_sde.p_0t).to(device)

train(ve_sde, model, sde_loss_fn)
第 21 个 epoch
第 26 个 epoch
第 34 个 epoch
第 44 个 epoch
最后 1 个 epoch

也不知道这种玩法好使不,不如采样看看什么效果:

在 MNIST 上训练 50 个 epochs 后的采样结果

Em.. 比鬼画符强一丢丢,作为喝手冲(咖啡)时的娱乐游戏,差不多是这么个意思就行,在内卷时代稍微佛系一些才能凸显个性。哦?你想知道采样过程的代码是怎么写的?冇问题,它就长以下这样:

    from torchvision.utils import make_grid

%matplotlib inline
import matplotlib.pyplot as plt

ckpt = f"{ve_sde.__class__.__name__}-{n_epochs}_epochs.pth"
state_dict = torch.load(ckpt, map_location=device)
model.load_state_dict(state_dict)

num_samples = 64  #@param {'type': 'integer'}
sample_shape = (num_samples,) + (1, 28, 28)

signal_to_noise_ratio = 0.16  #@param {'type':'number'}
n_corrector_steps = 1  #@param {'tpye': 'integer'}

predictor = ReverseDiffusionPredictor(ve_sde, model)
corrector = LangevinDynamicsCorrector(
    ve_sde, model, 
    snr=signal_to_noise_ratio, n_steps=n_corrector_steps
)

samples = pc_sampling(
    ve_sde, sample_shape, 
    predictor.update_fn, corrector.update_fn,
    eps=1e-5, device=device
)
samples = samples.clamp(0., 1.)
grids = make_grid(samples, nrow=int(np.sqrt(num_samples)))

plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(grids.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()

至于以上代码中 predictorcorrectorpc_sampling() 等这些是什么,这里就暂且搁(放心,不鸽)一下,待下一篇文章再深入聊聊。

What's Next

相必聪明的你也知道下一篇文章的主题了,没错!就是在 SDE 建模方式下的各种采样方法,文章内容的安排也会像本篇一样理论推导和代码实现相结合,至于文风嘛~ 那当然是不正经的,毕竟 CW 不要无聊的风格~



公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货


浏览 122
10点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
10点赞
评论
收藏
分享

手机扫一扫分享

举报