10. Imagen

本章我们讨论谷歌发表的图像生成扩散模型 Imagen(Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding)。谷歌虽然发表了论文,但并没有开源模型,所以之前讨论的人比较少。但是近期,StabilityAI 的 DeepFloyd 实验室复现了这篇论文中的模型,并且开源了预训练模型,使我们这些“穷人”也能用起来了,所以这里也讨论交流一下这个模型,学习一下它的精髓。

Imagen 与 Stable diffusion、GLIDE 类似,并不是一种新的图像生成算法,而同样是基于 DDPM 的应用模型。相比于 Stable diffusion,Imagen 主要有以下不同:

  1. 文本编码使用谷歌的 T5 预训练语言模型。直接使用通用的文本语言模型,比 CLIP 这种相对“残缺”的文本编码器表达能力更强,理论上对完整句子(相比 CLIP 那种短语和关键词)的理解更好。

  2. 没有潜空间(Latent Space)的概念,直接扩散生成比较小( \(64 \times 64\))的图像,这个尺寸和 Stable diffusion 的潜空间尺寸是相同的。前面讨论 Stable diffusion 时提到过,它的潜空间其实就可以看做是一张压缩图片。

  3. 通过级联两个图像放大模型解决高分辨率图片生成的问题,其实就是先扩散生成小尺寸图片,再逐级放大成更大尺寸。

除了上述 3 点外,Imagen 论文还提出了一些其它改进点:1)对 \(\hat{x}\) 的数值范围进行修正,使其保持在 \([-1,1]\) 之间,当然这一点不是这篇论文首次提出的,这篇论文强调了它的作用,并提出了一个动态修正方法(a new dynamic thresholding method)。2)对 U-Net 网络结构也做了一定的改进和优化。论文 [1] 中表示 Imagen 生成的图片更加写实和逼真,图 10.1 是一些生成样例。

../_images/imagen-examples.png

图 10.1 Imagen 生成图像示例(图片来自 Imagen Home)

Imagen 实际是 4 个模型串联在一起的,顺序执行,如 图 10.2 所示。

  1. 先执行文本编码器(T5 预训练模型),对文本提示语进行编码,得到文本特征向量。

  2. 用一个 DDPM 扩散模型生成一个 \(64 \times 64\) 的小尺寸图像。

  3. 用一个 4 倍的放大模型,把 \(64 \times 64\) 的图像放大到 \(256 \times 256\),这里的放大模型也是一个 DDPM 扩散模型,只不过它的输入条件同时包含文本向量和小尺寸图像。

  4. 同第三步,用另一个 4 倍放大模型把 \(256 \times 256\) 的图像放大到 \(1024 \times 1024\)。

../_images/imagen-flow.jpg

图 10.2 Imagen 流程示意图(图片来自 Imagen Home)

在模型训练阶段,文本编码器同样是冻结的,3个扩散模型是可以分开独立训练的。 StabilityAI 公司的 DeepFloyd 实验室复现了 Imagen 模型,并且开源出来, github地址为 https://github.com/deep-floyd/IF

10.1. 代码实现解读

由于这篇论文并不是提出一个新的扩散算法,而是基于 DDPM 的应用模型,所以不涉及复杂公式的推导,这里我们就用代码解读凑博客字数吧。同样我们选择研读 Diffusers 的实现代码,先看一下官方给出的使用示例代码,从中可以看出,生成一张图片分三个阶段(stage),对应着 图 10.2,接下来分别解读一下每个阶段的实现。

from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch

# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()

prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'

# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)

generator = torch.manual_seed(0)

# stage 1
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")

# stage 2
image = stage_2(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
pt_to_pil(image)[0].save("./if_stage_II.png")

# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save("./if_stage_III.png")

10.1.1. 第一阶段

第一阶段就是一个朴素的 DDPM 模型,当然和 Stable diffusion 不一样的地方是它的文本编码器是谷歌自家的 T5 模型。 第一阶段的 Pipeline 实现代码在 diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline

先看 IFPipeline 的组件,可以看到和 Stable diffusion 相比,主要的不同就是:

  • text_encoder 是 T5EncoderModel。

  • 没有 VAE 模型。

class IFPipeline(DiffusionPipeline):

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        ...
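
顺便一提,可以用前面示例中加载好的 stage_1 简单验证这两点,下面是示意性代码,不同版本 Diffusers 的组件列表可能略有差异:

# 查看 IFPipeline 包含的组件,可以看到有 text_encoder、unet、scheduler 等,但没有 vae
print(stage_1.components.keys())
print(type(stage_1.text_encoder).__name__)  # 预期输出:T5EncoderModel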

接下来是 __call__ 方法的实现,也就是主流程:

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: int = 100,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`, *optional*, defaults to self.unet.config.sample_size):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size):
            The width in pixels of the generated image.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
        returning a tuple, the first element is a list with the generated images, and the second element is a list
        of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
        or watermarked content, according to the `safety_checker`.
    """

    # 1. Check inputs. Raise error if not correct
    # 检查输入参数合法性,这里不用关注
    self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    # 生成图像的尺寸,默认是 64 x 64
    height = height or self.unet.config.sample_size
    width = width or self.unet.config.sample_size
    # batch_size 的判断
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    # 这里判断是否启动 classifier_free_guidance 特性
    # 注意,负提示词是否生效和它相关,只有启用 classifier_free_guidance 负提示词才会生效,
    # 否则负提示词不起作用。
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    # 对输入提示语进行文本编码
    # 这里没有什么特别的,只是调用文本编码器对文本提示语进行编码
    # 需要注意的是:如果 do_classifier_free_guidance==False,那么 negative_prompt_embeds=None
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )
    # 如果启用 do_classifier_free_guidance 则拼接正负提示语
    # 这里是在batch 维度进行的拼接,相当于batch 扩大了两倍,
    # 前部分是根据负提示词生成,后部分是根据正提示语生成
    # 后面会把这两部分再拆开,稍后能看到是如何处理的。
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # 5. Prepare intermediate images
    # 生成初始随机噪声数据,没啥特别的
    # shape = (batch_size=len(prompt) * num_images_per_prompt, num_channels=3, height=64, width=64)
    intermediate_images = self.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        self.unet.config.in_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 7. Denoising loop
    # 核心的降噪循环
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # 如果启用 classifier_free_guidance
            # 就要把输入扩大两倍,这和前面的 prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 是对应的。
            # 这里其实有个 classifier_free_guidance 的trick
            # 原本 classifier_free_guidance 每一步需要做两次噪声预测,一次有提示语,一次没有提示语
            # 这里 batch 的前面一半,理论上对应没有提示语(无条件)的部分,但在实现上处理的是负提示词。
            # 后面一半是正常的 text Prompt -> image 的过程。
            # 拼在一起进行处理是为了方便,一个批次同时处理,而不用分开两次处理。
            model_input = (
                torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
            )
            # 这里可以对输入数据进行一定的缩放处理,但对 DDPMScheduler 来说实际上没做任何处理
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            # 调用UNET网络,进行噪声预测
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            # 这里把拼接的两部分进行切分
            if do_classifier_free_guidance:
                # 把batch平均分开,前面的是无条件(实际上是负提示词)的部分,后面的是正常的提示词部分
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # 每一份在channel维度又一分为二,分别是预测的噪声(期望)和方差
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                # 应用classifier_free_guidance的公式
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                # 重新在 channel维度,把预测的噪声和方差拼接回去
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            # 根据预测噪声更新生成图片,即  x_t -> x_t-1
            # 论文中提到的 thresholding 修正就是在这个 step 里面实现的
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL
        image = self.numpy_to_pil(image)

        # 11. Apply watermark
        if self.watermarker is not None:
            image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)

整个实现非常清晰简洁,很容易看懂。这里稍微复杂的地方就是 classifier_free_guidance 的处理,先回顾一下它的核心公式:

(10.1.32)\[\bar{\epsilon}_{\theta}(z_t, c) = w \epsilon_{\theta}(z_t, c) + (1 − w) \epsilon_{\theta}(z_t)\]

前面讨论过 classifier free guidance 技术,其实实现起来很简单,就是在逆过程(采样)中,用同一个 UNET 网络分别进行有条件和无条件两个噪声预测,然后两者加权求和作为最终的预测噪声。这里实现的时候有两个小 trick:

  1. 没有分开调用两次 UNET,而是把输入 batch 扩大两倍,前面部分作为无条件,后面部分作为有条件,反正都是同一个 UNET 网络,一个批次一起处理效率更高。

  2. 无条件部分,并不是真的没有任何条件,而是把负提示词作为条件。

Stable Diffusion 的实现也是这样做的,不稀奇。
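
下面用一个极简的 PyTorch 片段示意这个批次拼接的 trick。张量尺寸和文本特征都是随意构造的示例值,UNet 也用随机张量代替,只用来说明数据的组织方式,并不是 Diffusers 的真实代码:

import torch

guidance_scale = 7.0
batch, ch, h, w = 2, 3, 64, 64

x_t = torch.randn(batch, ch, h, w)            # 当前时刻的噪声图像 x_t
neg_embeds = torch.randn(batch, 77, 4096)     # 负提示词的文本特征(充当“无条件”部分)
pos_embeds = torch.randn(batch, 77, 4096)     # 正提示词的文本特征

# trick 1: 在 batch 维度拼接,前一半对应负提示词,后一半对应正提示词
model_input = torch.cat([x_t, x_t], dim=0)
prompt_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)

# 真实代码中这里是 self.unet(model_input, t, encoder_hidden_states=prompt_embeds)
# 这里用随机张量代替 UNet 的输出
noise_pred = torch.randn(2 * batch, ch, h, w)

# 拆开两部分,套用 classifier-free guidance 公式
# eps = eps_uncond + w * (eps_text - eps_uncond),和式 (10.1.32) 是等价写法
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([2, 3, 64, 64])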

关于 thresholding 修正的功能,因为它在一定程度上是个通用功能,所以放到了 Scheduler 中实现:

def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487
    """
    dtype = sample.dtype
    batch_size, channels, height, width = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each image
    sample = sample.reshape(batch_size, channels * height * width)

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, height, width)
    sample = sample.to(dtype)

    return sample
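
在 Diffusers 中,是否启用这个修正由 Scheduler 的配置项控制。下面是一个示意性的配置例子,参数取值仅作演示;IF 预训练模型自带的 scheduler 配置里已经打开了 thresholding,一般不需要自己设置:

from diffusers import DDPMScheduler

# 手动构造一个启用 dynamic thresholding 的 DDPMScheduler(仅作示意)
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    thresholding=True,                  # 启用上面的 _threshold_sample 修正
    dynamic_thresholding_ratio=0.995,   # 取绝对值像素的分位数,对应论文中的百分位 p
    sample_max_value=1.5,               # s 的上限,示例值
)

# 也可以直接查看 stage_1 自带 scheduler 的相关配置
# print(stage_1.scheduler.config.thresholding, stage_1.scheduler.config.dynamic_thresholding_ratio)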

10.1.2. 第二阶段

第二阶段的 Pipeline 实现代码在 diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 250,
    clean_caption: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
            The image to be upscaled.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        noise_level (`int`, *optional*, defaults to 250):
            The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
        returning a tuple, the first element is a list with the generated images, and the second element is a list
        of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
        or watermarked content, according to the `safety_checker`.
    """
    # 1. Check inputs. Raise error if not correct

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    self.check_inputs(
        prompt,
        image,
        batch_size,
        noise_level,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    # 图像尺寸,默认 256 x 256
    height = self.unet.config.sample_size
    width = self.unet.config.sample_size

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    # 正负提示词的处理,这里和第一阶段没什么不同
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )
    # 和第一阶段同样的逻辑,只有 classifier_free_guidance 生效,负提示词才有效
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # 5. Prepare intermediate images
    # 这里注意, self.unet.config.in_channels ==6 ,默认是6
    # 图片的 channel 不是3么?这里为什么是6?
    # 这个模型是用来放大图片的,也就是输入小尺寸图片,放大到大尺寸,
    # 这里其实把输入图片和初始噪声在通道维度拼在一起了,3个通道给小图片,3个通道给初始噪声
    num_channels = self.unet.config.in_channels // 2
    # 这里是生成3个通道的初始噪声,尺寸是 256 x 256
    intermediate_images = self.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        num_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare upscaled image and noise level
    image = self.preprocess_image(image, num_images_per_prompt, device)
    # 使用双线性插值把输入的小图放大到 256 x 256
    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
    # 对放大后的图片,添加一定程度的随机噪声
    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    if do_classifier_free_guidance:
        noise_level = torch.cat([noise_level] * 2)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # 把初始噪声图片和输入图片拼接在一起。注意,这里是在 channel 维度进行的拼接。
            # 拼接后 channel = 6
            model_input = torch.cat([intermediate_images, upscaled], dim=1)

            model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            # 这里逻辑和阶段1 一样,
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 9. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 10. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 11. Convert to PIL
        image = self.numpy_to_pil(image)

        # 12. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 9. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 10. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)

这一阶段是一个图片放大模型(也称为上采样模型),逻辑上也可以看做是一个 image+text -> image 的过程,小尺寸的 image 也起到条件的作用。只是它没有像 ControlNet 那样引入一个独立的模块来处理作为条件的图片,而是直接在通道维度上和噪声图片拼接在一起。
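
下面几行 PyTorch 代码示意这种在通道维度拼接条件图片的方式,张量都是随机构造的,只用来说明尺寸关系:

import torch
import torch.nn.functional as F

small = torch.randn(1, 3, 64, 64)       # 第一阶段生成的小图
noise = torch.randn(1, 3, 256, 256)     # 第二阶段的初始高斯噪声

# 先把小图插值放大到目标尺寸,再在 channel 维度与噪声拼接
upscaled = F.interpolate(small, (256, 256), mode="bilinear", align_corners=True)
model_input = torch.cat([noise, upscaled], dim=1)
print(model_input.shape)  # torch.Size([1, 6, 256, 256]),对应 unet.config.in_channels == 6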

10.1.3. 第三阶段

第三阶段的 Pipeline 实现代码在 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline ,这一阶段没有独立训练模型,而是直接使用了 Stable Diffusion 的 4 倍放大模型(stable-diffusion-x4-upscaler)。果然如我当初所想,Stable Diffusion 内含一个先缩小再放大的过程,完全可以作为一个图片放大器使用。
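
这个放大模型也可以脱离 IF 单独使用,对任意图片做 4 倍放大。下面是一个示意性的用法,输入文件名是假设的,参数沿用常见默认值,效果未经验证:

import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

upscaler = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

low_res = Image.open("cat_128.png").convert("RGB")   # 假设是一张 128 x 128 的小图
upscaled = upscaler(prompt="a photo of a cat", image=low_res, noise_level=20).images[0]
upscaled.save("cat_512.png")                          # 输出为 512 x 512

回到 IF 的第三阶段,下面就是 StableDiffusionUpscalePipeline 的 __call__ 方法实现: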

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
    num_inference_steps: int = 75,
    guidance_scale: float = 9.0,
    noise_level: int = 20,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
            `Image`, or tensor representing an image batch which will be upscaled. *
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will ge generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """

    # 1. Check inputs
    self.check_inputs(
        prompt,
        image,
        noise_level,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    # 这里处理文本编码,和前两个阶段是一样的过程,
    # 只是这里把正负提示词拼接的过程放到方法 _encode_prompt 内部了。
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 4. Preprocess image
    image = preprocess(image)
    image = image.to(dtype=prompt_embeds.dtype, device=device)

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Add noise to image
    # 这里和第二阶段一样,对待放大的图片加噪处理
    noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
    noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
    image = self.low_res_scheduler.add_noise(image, noise, noise_level)
    # batch 放大逻辑,和前面一样
    batch_multiplier = 2 if do_classifier_free_guidance else 1
    image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
    noise_level = torch.cat([noise_level] * image.shape[0])

    # 6. Prepare latent variables
    # 这里就是生成初始随机噪声数据
    # 只不过在 Stable Diffusion 的世界里,初始噪声在潜空间(latent space),所以看上去是生成 latents
    # 当然尺寸要和放大前的图片一致
    height, width = image.shape[2:]
    num_channels_latents = self.vae.config.latent_channels  # latent_channels = 4
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Check that sizes of image and latents match
    num_channels_image = image.shape[1]
    if num_channels_latents + num_channels_image != self.unet.config.in_channels:
        raise ValueError(
            f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
            f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_image`: {num_channels_image} "
            f" = {num_channels_latents+num_channels_image}. Please verify the config of"
            " `pipeline.unet` or your `image` input."
        )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # concat latents, mask, masked_image_latents in the channel dimension
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            # 在 channel 维度拼接,这和前面的模型是一样的逻辑
            latent_model_input = torch.cat([latent_model_input, image], dim=1)

            # predict the noise residual
            #
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 10. Post-processing
    # make sure the VAE is in float32 mode, as it overflows in float16
    self.vae.to(dtype=torch.float32)

    # TODO(Patrick, William) - clean up when attention is refactored
    use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
    use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    if not use_torch_2_0_attn and not use_xformers:
        self.vae.post_quant_conv.to(latents.dtype)
        self.vae.decoder.conv_in.to(latents.dtype)
        self.vae.decoder.mid_block.to(latents.dtype)
    else:
        latents = latents.float()

    # 11. Convert to PIL
    # 这里,最后有个利用 VAE做图片放大的过程
    if output_type == "pil":
        #  这里把 latent 解码,其实就是图像放大的过程
        image = self.decode_latents(latents)

        image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)

        image = self.numpy_to_pil(image)

        # 11. Apply watermark
        if self.watermarker is not None:
            image = self.watermarker.apply_watermark(image)
    elif output_type == "pt":
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        has_nsfw_concept = None
    else:
        image = self.decode_latents(latents)
        has_nsfw_concept = None

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

StableDiffusion 中表面上看 VAE 是一个编解码的过程,其实也可以看做是一个图像缩小放大的过程, 所以可以利用 VAE 的解码过程实现图像放大。
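
可以用下面的小片段直观感受一下:把一个与小图同尺寸的随机潜变量送进这个放大模型自带的 VAE 解码器,输出的空间尺寸正好放大 4 倍。数值没有任何意义,只看形状;这里假设该 VAE 的下采样倍率为 4,和上面 pipeline 中的用法一致:

import torch
from diffusers import AutoencoderKL

# 加载 stable-diffusion-x4-upscaler 自带的 VAE(仅用于观察尺寸变化)
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="vae")

latents = torch.randn(1, 4, 128, 128)    # 潜变量的空间尺寸与待放大的小图一致
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample
print(image.shape)                        # torch.Size([1, 3, 512, 512]),解码即放大 4 倍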

10.2. Imagen 总结

其实刚看到 Stable Diffusion 时就在想,为什么采用 CLIP 的文本编码器,而不是更成熟的大语言模型。理论上,CLIP 的文本编码器是比较弱的,训练语料有限,而且多是短语和关键词,对于复杂长句子的理解肯定不足。后来谷歌发表了 Imagen,果然对文本编码器做了调整。个人理解,Imagen 的主要价值是验证了完整语言模型作为文本编码器能带来多大的效果提升;此外,Imagen 再次强调了 classifier_free_guidance 的重要性,并引入了动态阈值(dynamic thresholding),这使得可以使用比以前工作高得多的引导权重。

当然 Imagen 的缺点也很明显:级联三个模型,生成效率差得可怜。截止到目前,对于超高清图片的生成貌似还没有特别高效的方法,当然相信未来一定会解决。

这里提出一个小疑问:放大模型中是把小尺寸图片(待放大图片)加噪后和随机高斯噪声在通道维度拼接成一个 6 通道的输入,然后进行降噪生成新图片,这里为什么不能直接在小尺寸图片加噪的结果上降噪呢?暂时没有足够的资源和精力去做实验验证,留待有缘人吧。

10.3. 参考文献