Generative AI 新世界 | 扩散模型原理的代码实践之模型训练篇

PyTorch
云计算
生成式人工智能
0
0
<!--StartFragment--> 在上一期的[“扩散模型工作原理”代码实践系列的第一篇中](http://mp.weixin.qq.com/s?\\__biz=Mzg5Mzg1NDc2NQ==\\&mid=2247489218\\&idx=1\\&sn=c2a8452ae1b4b4667bcf1230fd5726bd\\&chksm=c0292380f75eaa96b8737a801b9c54ef41264e4b0ef69b659721f55f43ff2e3ce5c664bad14a\\&scene=21#wechat_redirect?trk=cndc-detail),我们通过两段不同代码块的实现,来对比了两种扩散模型的采样方法,让大家通过代码亲身体验了每次迭代后,需要添加额外采样噪声来保证输入噪声符合正态分布的重要性。本期文章,我们将继续用代码一起来实践扩散模型的训练过程。 我将会用四集的篇幅,逐行代码来构建扩散模型(Diffusion Model)。这四部分分别是: 1. 噪声采样(Sampling) 2. 训练扩散模型(Training) 3. 添加上下文(Embedding & Adding Context) 4. 噪声快速采样(Fast Sampling) 这四部分的完整代码可参考: * https\://github.com/hanyun2019/difussion-model-code-implementation?trk=cndc-detail 本文是第二部分:训练扩散模型(Training)。 ### **1 扩散模型的神经网络:U-Net** 本节我们将介绍扩散模型的神经网络架构,以及如何将其他信息整合到神经网络中。 我们**用于扩散模型的神经网络架构是 U-Net**。U-Net 将原始图像作为输入,其输出的尺寸会与原始图像大小相同,我们因此借助这一特点来预测噪声。 U-Net 论文发布于 2015 年,如下图所示: ![640.png](https://dev-media.amazoncloud.cn/5790238e18b542028ca88671d4dc90a4_640.png "640.png") ![640 (21).png](https://dev-media.amazoncloud.cn/9a977b8f7f7d460bb2a75a1af2029b02_640%20%2821%29.png "640 (21).png") Source: https\://arxiv.org/pdf/1505.04597.pdf?trk=cndc-detail 它最开始的用途是图像分割:例如,把照片或图像,在画面上分成行人、汽车、交通灯等,因此 U-Net 在汽车自动驾驶等研究领域被广泛使用。 U-Net 首先把输入信息做 Embedding,用卷积层向下采样(down-sampling),把信息压缩到更少量的空间。然后,它使用相同数量的采样块在输出上做向上采样(up-sampling),来预测图像上的噪声,而预测的噪声维度与原始输入图像相同,例如如下图所示,输入和输出的维度都是 16x16x3。  ![640 (1).png](https://dev-media.amazoncloud.cn/530797b28e4b4084abd61265844d894e_640%20%281%29.png "640 (1).png") Source:How Diffusion Models Work,by DeepLearning.AI U-Net 的另一个优点是可以接受额外的信息。那么我们需要哪些额外信息呢? 有两类重要的额外信息: 1. **时间嵌入(Time Embedding)** 2. **上下文嵌入(Context Embedding)** 如下图所示: ![640 (2).png](https://dev-media.amazoncloud.cn/243287de1dd7474990a8717912c28eb7_640%20%282%29.png "640 (2).png") Source: How Diffusion Models Work, by DeepLearning.AI 第一个额外信息是**时间嵌入**。它可以告诉模型时间步长(time step),并由此得知此时需要的噪声水平。对于这个时间嵌入,需要将其嵌入到一个向量中,然后将其添加到上采样块中。 另一个额外信息是**上下文嵌入**。它的主要作用是控制模型生成的内容。例如,通过一段文字描述,我们希望模型生成一些精灵图像,或者某种颜色头发的精灵图像等。 以下是关于时间嵌入和上下文嵌入的一段伪代码: ![640 (3).png](https://dev-media.amazoncloud.cn/23f2050b8dc54cba9444f105497c3fac_640%20%283%29.png "640 (3).png") 变量 cemb1 定义了上下文嵌入,而变量 temb1 定义了时间嵌入。 ### **2 扩散模型的训练策略** 训练神经网络(Neural Network)的目标是让网络预测噪声,真正的任务是让它学习图像上的噪声分布(也包括需要学习什么是游戏角色图片的特征)。 **训练策略**是:**从训练数据中去一张游戏角色图片,然后添加随机噪声,然后让神经网络预测这个噪声。之后,将预测的噪声与实际的噪声进行比较,计算损失函数。通过反向传播算法不断迭代,让神经网络学会更好的预测噪声**。 ![640 (4).png](https://dev-media.amazoncloud.cn/00a7dcba84ab45cdb774df6de9c122f1_640%20%284%29.png "640 (4).png")  Source: https\://learn.deeplearning.ai/diffusion-models/lesson/5/training\?trk=cndc-detail 那么如何确定这里的噪声是什么? 可以通过时间和采样,给它不同的噪声级别。但在实际的训练过程中,我们不希望神经网络一直观察同一个游戏角色图片,因为如果在一个周期内观察到不同的游戏角色图片,神经网络会更稳定,更均匀。所以,我们实际上是随机采样一个可能的时间步长,然后获取相应的噪声级别,添加到图像中,再让神经网络做预测。 之后,选择下一张游戏角色图片,执行同样的过程。这样就得到了一个稳定的训练过程。 例如,假设我们有一张关于巫师帽精灵的真实图像,如下所示:  ![640 (5).png](https://dev-media.amazoncloud.cn/0497203d233241a88f72ebd6e9d63a96_640%20%285%29.png "640 (5).png") Source: https\://learn.deeplearning.ai/diffusion-models/lesson/5/training?trk=cndc-detail 而下图从左到右分别是: 1. 加了噪声的输入图像 2.  经过第一轮迭代后,减去模型预测噪声后的图像 3. 经过第 31 轮迭代后,减去模型预测噪声后的图像  ![640 (6).png](https://dev-media.amazoncloud.cn/d553970aca2f43fb96d3ffbe446fced0_640%20%286%29.png "640 (6).png") Source: https\://learn.deeplearning.ai/diffusion-models/lesson/5/training\?trk=cndc-detail 当经历第一轮迭代后,神经网络还没有真正太了解精灵是什么,预测的噪声并不能完全改变输入的样子;因此当它减去模型预测噪声后,图像其实和最初的输入图像也差不多。 但当经历第 31 轮迭代后,神经网络已经更好地了解了这个精灵的样子。然后它预测噪点,然后从这个输入中减去噪点,产生看起来确实像这个巫师帽精灵的东西。 上面是针对一个样本的情况。 下图显示了多个不同的样本时的场景:多个不同的精灵、跨越多个迭代周期、以及不同迭代周期时图像的样子。正如你在第一个迭代周期时所看到的那样,它离精灵还有很长的路要走,但是当你进入第 32 迭代周期时,它已经看起来很像电子游戏中的角色形象啦! ![640 (7).png](https://dev-media.amazoncloud.cn/313388e22f4248cfa7a299af286cc7b0_640%20%287%29.png "640 (7).png") Source: https\://learn.deeplearning.ai/diffusion-models/lesson/5/training\?trk=cndc-detail ### **3 模型训练的代码实践** #### **3.1 创建 Amazon SageMaker Notebook 实例** 篇幅所限,本文不再赘述创建 [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook 实例的过程。如需了解可参考以下文档: * https\://docs.aws.amazon.com/zh_cn/sagemaker/latest/dg/gs-setup-working-env.html?trk=cndc-detail #### **3.2 环境设置参考** 本篇里的代码会涉及模型训练,在没有 GPU 的实例上运行这段代码,需要几个小时才能完成。因此,本文的示例代码建议运行在具有 GPU 的实例上。我选择了 ml.p3.8xlarge 实例,该实例配置了 4 颗 V100 的 GPU,总 GPU 内存为64GB。关于 p3 系列的 GPU 配置情况,可参考以下亚马逊云科技的在线文档。 ![640 (8).png](https://dev-media.amazoncloud.cn/435367a8e6a64f05b12a7a9a8d9dc607_640%20%288%29.png "640 (8).png") Source: https\://aws.amazon.com/sagemaker/pricing/?trk=cndc-detail 示例代码的 notebook 在 [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook 测试通过,内核为 conda_pytorch_p310,实例为一台 ml.p3.8xlarge 实例,如下图所示:  ![640 (9).png](https://dev-media.amazoncloud.cn/fe3cfd999a3648c3b7f893270cc07f75_640%20%289%29.png "640 (9).png") 我比较习惯使用“Open JupyterLab”,点击“Open JupyterLab”进入 SageMaker Notebook:  ![640 (10).png](https://dev-media.amazoncloud.cn/fb1c63e879464b73bcfc8b68cdd48609_640%20%2810%29.png "640 (10).png") 点击“Terminal”图标打开一个终端,输入 nvidia-smi 命令查看实例上的 GPU 情况:  ![640 (11).png](https://dev-media.amazoncloud.cn/a5cf167643fc4dd0aa0a259dcf7634c5_640%20%2811%29.png "640 (11).png") 如上图所见,配置在 ml.p3.8xlarge 实例上的四张 Tesla V100 GPU 卡的情况。GPU 卡已经齐备,我们可以进入训练代码了。 本实验的完整示例代码可参考: * https\://github.com/hanyun2019/difussion-model-code-implementation/blob/dm-project-haowen-mac/L2\_Training.ipynb?trk=cndc-detail #### **3.3 导入所需的库文件** 现在我们进入通过代码解读扩散模型的部分。 首先,我们需要导入 PyTorch 和一些 PyTorch 相关的实用库,以及导入帮助我们设计神经网络的一些辅助函数(helper functions)。 ``` from typing import Dict, Tuple from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import models, transforms from torchvision.utils import save_image, make_grid import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation, PillowWriter import numpy as np from IPython.display import HTML from diffusion_utilities import * ``` #### **3.4 定义神经网络架构和参数** ##### **3.4.1 定义 U-Net 网络架构** 首先,我们先定义一个 U-Net 网络架构。 ``` class ContextUnet(nn.Module): def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features super(ContextUnet, self).__init__() # number of input channels, number of intermediate feature maps and number of classes self.in_channels = in_channels self.n_feat = n_feat self.n_cfeat = n_cfeat self.h = height  #assume h == w. must be divisible by 4, so 28,24,20,16... # Initialize the initial convolutional layer self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True) # Initialize the down-sampling path of the U-Net with two levels self.down1 = UnetDown(n_feat, n_feat)        # down1 #[10, 256, 8, 8] self.down2 = UnetDown(n_feat, 2 * n_feat)    # down2 #[10, 256, 4, 4] # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU()) self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU()) # Embed the timestep and context labels with a one-layer fully connected neural network self.timeembed1 = EmbedFC(1, 2*n_feat) self.timeembed2 = EmbedFC(1, 1*n_feat) self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat) self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat) # Initialize the up-sampling path of the U-Net with three levels self.up0 = nn.Sequential( nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample nn.GroupNorm(8, 2 * n_feat), # normalize                         nn.ReLU(), ) self.up1 = UnetUp(4 * n_feat, n_feat) self.up2 = UnetUp(2 * n_feat, n_feat) # Initialize the final convolutional layers to map to the same number of channels as the input image self.out = nn.Sequential( nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps   #in_channels, out_channels, kernel_size, stride=1, padding=0 nn.GroupNorm(8, n_feat), # normalize nn.ReLU(), nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input ) def forward(self, x, t, c=None): """ x : (batch, n_feat, h, w) : input image t : (batch, n_cfeat)      : time step c : (batch, n_classes)    : context label """ # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on # pass the input image through the initial convolutional layer x = self.init_conv(x) # pass the result through the down-sampling path down1 = self.down1(x)       #[10, 256, 8, 8] down2 = self.down2(down1)   #[10, 256, 4, 4] # convert the feature maps to a vector and apply an activation hiddenvec = self.to_vec(down2) # mask out context if context_mask == 1 if c is None: c = torch.zeros(x.shape[0], self.n_cfeat).to(x) # embed context and timestep cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1,1) temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1) cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1) temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1) #print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}") up1 = self.up0(hiddenvec) up2 = self.up1(cemb1*up1 + temb1, down2)  # add and multiply embeddings up3 = self.up2(cemb2*up2 + temb2, down1) out = self.out(torch.cat((up3, x), 1)) return out ``` ##### **3.4.2 定义模型训练的超参数** 接下来,我们将设置模型训练需要的一些超参数,例如图像的长度和高度 height 值设置为 16,则表示 16 乘 16 的正方形图像等,如下所示。 ``` # hyperparameters # diffusion hyperparameters timesteps = 500 beta1 = 1e-4 beta2 = 0.02 # network hyperparameters device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu')) n_feat = 64 # 64 hidden dimension feature n_cfeat = 5 # context vector is of size 5 height = 16 # 16x16 image save_dir = './weights/' # training hyperparameters batch_size = 100 n_epoch = 32 lrate=1e-3 ``` ##### **3.4.3 定义添加噪声的策略(noise schedule)** 定义 DDPM 论文中定义的添加噪声策略(noise schedule): ``` # construct DDPM noise schedule b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1 a_t = 1 - b_t ab_t = torch.cumsum(a_t.log(), dim=0).exp() ab_t[0] = 1 ``` 其中 beta1 和 beta2 是 DDPM 算法的超参数。 ##### **3.4.4 实例化模型** 接下来实例化模型: ``` # construct model nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device) ``` #### **3.5 启动模型训练** 加载数据集和构建优化器: ``` # load dataset and construct optimizer dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1) optim = torch.optim.Adam(nn_model.parameters(), lr=lrate) ``` 将图像扰动到指定的噪点水平: ``` # helper function: perturbs an image to a specified noise level def perturb_input(x, t, noise): return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise ``` 开始模型训练。我们在本文中先使用无上下文策略的代码,在下一集的文章中我们还将对比有上下文策略的代码做对比。无上下文策略的代码如下所示: ``` # training without context code # set into train mode nn_model.train() for ep in range(n_epoch): print(f'epoch {ep}') # linearly decay learning rate optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch) pbar = tqdm(dataloader, mininterval=2 ) for x, _ in pbar:   # x: images optim.zero_grad() x = x.to(device) # perturb data noise = torch.randn_like(x) t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device) x_pert = perturb_input(x, t, noise) # use network to recover noise pred_noise = nn_model(x_pert, t / timesteps) # loss is mean squared error between the predicted and true noise loss = F.mse_loss(pred_noise, noise) loss.backward() optim.step() # save model periodically if ep%4==0 or ep == int(n_epoch-1): if not os.path.exists(save_dir): os.mkdir(save_dir) torch.save(nn_model.state_dict(), save_dir + f"model_{ep}.pth") print('saved model at ' + save_dir + f"model_{ep}.pth") ``` 你可以在 weights 这个生成的目录下面,看到一些正不断生成的 xxx.pth 文件如下图所示。如果对照模型训练过程的日志输出,你会发现这正对应着每 5 次迭代(epoch)后,会把模型保存到某个 xxx.pth 文件的操作:例如:saved model at ./weights/model\_8.pth 有好奇心的同学可能会问:**为什么要每隔 5 个训练回合,保存这些中间训练过程的模型输出呢?**(如下图所示)。这其实是个非常好非常重要的问题。我们在此先卖个关子,**谜底会在后面内容中揭开!** ![640 (12).png](https://dev-media.amazoncloud.cn/a4ad7c04098f42fe8a92487f55256c0a_640%20%2812%29.png "640 (12).png") #### **3.6 采样** 模型训练结束后,进入采样环节。 ##### **3.6.1 定义 DDPM 采样算法** ``` # helper function; removes the predicted noise (but adds some noise back in to avoid collapse) def denoise_add_noise(x, t, pred_noise, z=None): if z is None: z = torch.randn_like(x) noise = b_t.sqrt()[t] * z mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt() return mean + noise # sample using standard algorithm @torch.no_grad() def sample_ddpm(n_sample, save_rate=20): # x_T ~ N(0, 1), sample initial noise samples = torch.randn(n_sample, 3, height, height).to(device)   # array to keep track of generated steps for plotting intermediate = [] for i in range(timesteps, 0, -1): print(f'sampling timestep {i:3d}', end='\\r') # reshape time tensor t = torch.tensor([i / timesteps])[:, None, None, None].to(device) # sample some random noise to inject back in. For i = 1, don't add back in noise z = torch.randn_like(samples) if i > 1 else 0 eps = nn_model(samples, t)    # predict noise e_(x_t,t) samples = denoise_add_noise(samples, i, eps, z) if i % save_rate ==0 or i==timesteps or i<8: intermediate.append(samples.detach().cpu().numpy()) intermediate = np.stack(intermediate) return samples, intermediate ``` ##### **3.6.2. 可视化第 1 个训练回合(Epoch 0)的模型输出** 首先加载之前保存的中间过程的模型输出结果: ``` # load in model weights and set to eval mode nn_model.load_state_dict(torch.load(f"{save_dir}/model_0.pth", map_location=device)) nn_model.eval() print("Loaded in Model") ``` 然后把这些模型输出结果可视化,一帧一帧地播放出来: ``` # visualize samples plt.clf() samples, intermediate_ddpm = sample_ddpm(32) animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False) HTML(animation_ddpm.to_jshtml()) ``` 以下展示第 1 个训练回合第一帧 (Epoch 0, frame 0):  ![640 (13).png](https://dev-media.amazoncloud.cn/bd0eacc228fd488b96f95477e339026d_640%20%2813%29.png "640 (13).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook,by Haowen 以下展示第 1 个训练回合最后一帧 (Epoch 0, frame 31): ![640 (14).png](https://dev-media.amazoncloud.cn/815dba75577e42198f97e7922e7578bd_640%20%2814%29.png "640 (14).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook ,by Haowen 聪明的读者们,我想你们可能已经知道之前我们保存中间训练回合的模型输出结构的意义了:**是为了从这里开始,做对比展示**。 ##### **3.6.3 可视化第 5 个训练回合(Epoch 4)的模型输出** ``` # load in model weights and set to eval mode nn_model.load_state_dict(torch.load(f"{save_dir}/model_4.pth", map_location=device)) nn_model.eval() print("Loaded in Model") # visualize samples plt.clf() samples, intermediate_ddpm = sample_ddpm(32) animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False) HTML(animation_ddpm.to_jshtml()) ``` 以下展示第 5 个训练回合第一帧 (Epoch 4, frame 0): ![640 (15).png](https://dev-media.amazoncloud.cn/5455957329a7420aa612b88009a741d1_640%20%2815%29.png "640 (15).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook  by Haowen 以下展示第 5 个训练回合最后一帧 (Epoch 5, frame 31): ![640 (16).png](https://dev-media.amazoncloud.cn/61cb177deb3c43c6b00846ef89d07c2b_640%20%2816%29.png "640 (16).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook,by Haowen 和第一个训练回合相比,经历了 5 个训练回合后,模型输出的图像,开始有了一点轮廓了。 ##### **3.6.4 可视化第 25 个训练回合(Epoch 24)的模型输出** ``` # load in model weights and set to eval mode nn_model.load_state_dict(torch.load(f"{save_dir}/model_24.pth", map_location=device)) nn_model.eval() print("Loaded in Model") # visualize samples plt.clf() samples, intermediate_ddpm = sample_ddpm(32) animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False) HTML(animation_ddpm.to_jshtml()) ``` 以下展示第 25 个训练回合第一帧 (Epoch 16, frame 0): ![640 (17).png](https://dev-media.amazoncloud.cn/7e7214f75536492d82c11dd044afe288_640%20%2817%29.png "640 (17).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook ,by Haowen 以下展示第 25 个训练回合最后一帧 (Epoch 16, frame 31): ![640 (18).png](https://dev-media.amazoncloud.cn/7f9bf63daa5144e398d8869eb3b0f0a7_640%20%2818%29.png "640 (18).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook ,by Haowen 和第 5 个训练回合相比,经历了 25 个训练回合后,模型输出的图像,轮廓形象已经基本逼真呈现出来了,我们期待最后几个训练回合对图像质量的进一步改善。 ##### **3.6.5. 可视化第 32 个训练回合(Epoch 31)的模型输出** ``` # load in model weights and set to eval mode nn_model.load_state_dict(torch.load(f"{save_dir}/model_31.pth", map_location=device)) nn_model.eval() print("Loaded in Model") # visualize samples plt.clf() samples, intermediate_ddpm = sample_ddpm(32) animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False) HTML(animation_ddpm.to_jshtml()) ``` 以下展示第 32 个训练回合第一帧 (Epoch 32, frame 0): ![640 (19).png](https://dev-media.amazoncloud.cn/5ee92343db1547218aa0e425ef485bb8_640%20%2819%29.png "640 (19).png") Source: Screenshot of using [Amazon SageMaker](https://aws.amazon.com/cn/sagemaker/?trk=cndc-detail) Notebook ,by Haowen 以下展示第 32 个训练回合最后一帧 (Epoch 31, frame 31): ![640 (20).png](https://dev-media.amazoncloud.cn/24f4b42c700444899e67abd69ee80adf_640%20%2820%29.png "640 (20).png") 和第 25 个训练回合相比,经历了 32 个训练回合后,模型输出的图像…… 说实话,我觉得差不多。可能是肉眼已经分辨不出,或者在第 25 个训练回合的图像已经很不错了。 ### **总结** 作为“扩散模型工作原理”代码实践系列的第二篇,本文用代码实践了扩散模型的训练过程,希望帮助读者们揭开扩散模型训练的神秘面纱。 下一篇我们将探讨在模型训练过程中,通过添加上下文(Embedding & Adding Context)的策略来优化模型训练过程,敬请期待。 ### **参考资料** 1\. DeepLearning.AI short course “How Diffusion Models Work” * https\://www\.deeplearning.ai/short-courses/how-diffusion-models-work/?trk=cndc-detail 2.Sprites by ElvGames, FrootsnVeggies and kyrise\ * FrootsnVeggies:https\://zrghr.itch.io/froots-and-veggies-culinary-pixels?trk=cndc-detail * kyrise:https\://kyrise.itch.io/?trk=cndc-detail 3.Code reference, This code is modified from  * https\://github.com/cloneofsimo/minDiffusion\?trk=cndc-detail 4\. DDPM & DDIM papers Diffusion model is based onDenoising Diffusion Probabilistic Models and Denoising Diffusion Implicit Models * Denoising Diffusion Probabilistic Models :https\://arxiv.org/abs/2006.11239?trk=cndc-detail * Denoising Diffusion Implicit Models:https\://arxiv.org/abs/2010.02502?trk=cndc-detail 5. 上采样和下采样 * https\://www\.jianshu.com/p/fd9e2166cfcc?trk=cndc-detail 6. 理解扩散模型(DDPM) * https\://www\.zhihu.com/question/545764550/answer/2670611518?trk=cndc-detail ![640.gif](https://dev-media.amazoncloud.cn/80e76329e39f413887e523beecce6570_640.gif "640.gif") <!--EndFragment-->
目录
亚马逊云科技解决方案 基于行业客户应用场景及技术领域的解决方案
联系亚马逊云科技专家
亚马逊云科技解决方案
基于行业客户应用场景及技术领域的解决方案
联系专家
0
目录
关闭