DALL-E:Zero-Shot Text-to-Image Generation

分别使用 CLIP、VAE 提取文本编码、图像编码,然后使用 transformer 使用自注意力学习。推理时,使用 CLIP 提取文本编码,然后输入
transformert 提取图片编码,最后使用 dVAE decoder 生成图片

  1. 文本 ->CLIP-> 文本编码;
  2. 图像 ->dVAE encoder-> 图像编码;
  3. 文本编码 +(图像编码)-> 自回归 -> 新图像编码 ->dVAE decoder-> 条件图片,文生图

什么是 DALL-E ?

  • 分别使用 CLIP、VAE 提取文本编码、图像编码,然后使用 transformer 使用自注意力学习。推理时,使用 CLIP 提取文本编码,然后输入 transformert 提取图片编码,最后使用 dVAE decoder 生成图片

DALL-E 的训练过程?

  • 结构:DALL-E 是一个两阶段的模型:它的第一个阶段是离散变分自编码器(Discrete Variance Auto-Encoder,dVAE),用于生成图像的 token,它的第二个阶段是混合了图像和文本特征的,以 Transformer 为基础的生成模型;最后使用 CLIP 筛选图片
  • 训练步骤:DALL-E 涉及 3 个网络,dVAE、transformer、CLIP,每个网络单独训练

DALL-E 的 dVAE 的作用?

  • 由于图片的分辨率很大,如果把单个 pixel 当成一个 token 处理,会导致计算量过于庞大,于是 DALL・E 引入了一个 dVAE 模型来降低图片的分辨率
  • dVAE 把每张 256x256 的 RGB 图片压缩成 32x32 的图片 token,每个位置有 8192 种可能的取值 (也就是说 dVAE 的 encoder 输出是维度为 32x32x8192 的 logits,然后通过 logits 索引 codebook 的特征进行组合,codebook 的 embedding 是可学习的)
  • 和 VQVAE 方法相似,dVAE 的 encoder 是将图像的 patch 映射到 8192 的词表中,由于不可导的问题,此时不能采用重参数技巧,DALL・E 使用 Gumbel Softmax trick 来解决这个问题

DALL-E 的 Transformer 的作用?

  • DALL・E 中的 Transformer 结构由 64 层 attention 层组成,每层的注意力头数为 62,每个注意力头的维度为 64,因此,每个 token 的向量表示维度为 3968。如图所示,attention 层使用了行注意力 mask、列注意力 mask 和卷积注意力 mask 三种稀疏注意力

DALL-E 的 CLIP 的作用?

  • 通过输入不同的首个图像的 token 可生成很多各种类型的图片(设置 max=512),需要根据 CLIP 来对得到的图文对进行重排

DALL-E 的 BPE Encoder?

  • 用 BPE Encoder 对文本进行编码,得到最多 256 个文本 token,token 数不满 256 的话 padding 到 256,然后将 256 个文本 token 与 1024 个图像 token 进行拼接,得到长度为 1280 的数据
  • BPE 是一种减少词表的算法

DALL-E 的推理?

  • 首先将输入文本编码成特征向量,然将特征向量送入到自回归的 Transformer 中生成图像的 token,再后将图像的 token 送入到 dVAE 的解码器中得到生成图像,最后通过 CLIP 对生成样本进行评估,得到最终的生成结果
  • 1)生成图像的模块只有 dVAE 的解码器,其要求输入是图像的 token
  • 2)文本输入,其目地是得到图像 token,做法是文本输入得到文本 token,通过 Transformer 生成图像的 token。先 2 后 1 即可根据文本生成图像
  • 3)通过 Transformer 生成图像的 token 有多个,所以经过 dVAE 编码器后生成多张候选图片,将文本 + 候选图片输入 CLIP,计算出与文本越接近的候选图片,即为最终生成图片

DALL-E 使用 Gumbel-SoftMax 解决离散变量求导问题?

  • dVAE 计算离散的 one-hot 编码时,需要使用到 argmax,但是这个 argmax 是不可导的,因此无法用来更新模型。DALL-E 解决这个问题的策略是引入了 Gumbel-Softmax
  • Gumbel-Softmax 在 DALL-E 中可以理解为通过向 softmax 中引入超参数 τ\tau 来使 argmax 可导。超参数 τ\tau 在深度学习中有一个专业术语叫做温度(Temperature),它可以通过调整 softmax 曲线的平滑程度来实现不同的功能。加入超参 τ\tau 的 softmax 可以表示为

στ(pj)=exp(pj/τ)i=1Npi/τ\sigma_\tau(p_j)=\frac{\exp(p_j/\tau)}{\sum_{i=1}^Np_i/\tau}

  • 可以对比蒸馏学习中,通过在 softmax 输出后引入温度 τ\tau,减少不同 way 的差异,使得 teacher 和 student 的有效信息得以传递

DALL-E 中使用 logit-Laplace 分布的原理?

  • 在构建生成图像时,图像的像素是有值域范围的,而 VAE 中通过拉普拉斯分布或者高斯分布得到的值域是整个实数集,这就造成了模型目标和实际生成内容的不匹配问题
  • 为了解决这个问题,DALL-E 提出了拉普拉斯分布的变体:log - 拉普拉斯分布。它的核心思想是将 sigmoid 作用到拉普拉斯分布的随机变量上,从而得到一个值域是 (0,1) 的随机变

f(xμ,b)=12bx(1x)exp(logit(x)μb)f(x\mid\mu,b)=\frac{1}{2bx(1-x)}\exp\biggl(-\frac{|\operatorname{logit}(x)-\mu|}{b}\biggr)

DALL-E 训练、推理理解?

  • Drawing-20231112170733.excalidraw
  • 训练过程:通过已经训练好的 dVAE、BPE 分布提取图像、文本的 token 表示,通过 transformer 学习 N+M 个 token 表示,然后使用原始输入监督 transformer 学习,这一过程为了训练 transformer
  • 推理过程:将文本编码后的 token 输入 BPE,输出 N 个 token T1,T2...TnT_1,T_2...T_n,搭配图片 Token(n=0,1,2…)输入 transformer,输出 M+n token 预测,取均值得到 1 个 token 预测,对应图片第一个 token 预测 I1I_1,再次输入 T1,T2...TnT_1,T_2...T_nI1I_1,预测图片的下一个 token,当执行 N 次时,输出 N 个图片预测 token,然后使用 dVAE 的解码器还原为图片,最后使用 CLIP 比较文本和图像的相似程度

参考:

  1. DALL・E— 从文本到图像,超现实主义的图像生成器 - 知乎
  2. DALL・E 原理通俗理解 - 知乎
  3. OpenAI 的 DALL-E 模型原理 - 知乎
  4. GitHub - lucidrains/DALLE-pytorch at 58c1e1a4fef10725a79bd45cdb5581c03e3e59e7