MindSpore Learning Notes 4-03: Generative Models -- DDPM Diffusion Models

Abstract:

        Notes on how DDPM in the MindSpore AI framework works: progressively adding noise to image data in the forward process and progressively removing it in the reverse process, together with the practical usage and steps.

I. Concepts

1. Diffusion models

DDPM (Denoising Diffusion Probabilistic Model)

Used for (un)conditional image/audio/video generation, for example:

        OpenAI

                GLIDE

                DALL-E

        Heidelberg University

                Latent Diffusion

        Google Brain

                image generation

2. The diffusion process

A diffusion model transforms noise from a simple distribution into a data sample using two processes:

A fixed (or predefined) forward diffusion process q

        gradually adds Gaussian noise to an image until only pure noise remains

A learned reverse denoising diffusion process p_\theta

        a neural network is trained to gradually denoise an image starting from pure noise, until an actual image is obtained

3. How diffusion models work

(1) Forward process

        add noise to the images

        then have a neural network optimize a tractable loss function

Let q(x_0) be the real data distribution

        an image is obtained by sampling x_0 \sim q(x_0)

Define the forward diffusion process q(x_t|x_{t-1})

        with a variance schedule 0 < \beta_1 < \beta_2 < ... < \beta_T < 1 over time steps t

        Gaussian noise is added at each time step t

        forming a Markov chain: q(x_{1:T}|x_0) = \prod_{t=1}^{T} q(x_t|x_{t-1})

A normal (Gaussian) distribution is defined by two parameters:

        a mean \mu

        a variance \sigma^2 \geq 0

        At each time step t, a new, slightly noisier image is drawn from the conditional Gaussian q(x_t|x_{t-1}) = N(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_t I)

        i.e., sample \epsilon \sim N(0, I)

        and set x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon (a single step is sketched in code below)

                \beta_t is not constant across time steps t

                        it follows a variance schedule

                        the \beta_t per time step can be linear, quadratic, cosine, etc.

                        once the schedule is set, starting from x_0 we obtain x_1, ..., x_t, ..., x_T

                        for sufficiently large T, x_T is (nearly) pure Gaussian noise
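A minimal NumPy sketch of the iterative noising step (illustrative only; the toy image and seed are arbitrary, and the linear schedule matches the one defined later in this note):

import numpy as np

rng = np.random.default_rng(0)

def forward_step(x_prev, beta_t):
    # x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps, with eps ~ N(0, I)
    eps = rng.standard_normal(x_prev.shape)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps

x = np.ones((8, 8))                    # a toy "image" x_0
betas = np.linspace(1e-4, 0.02, 200)   # linear variance schedule
for beta_t in betas:
    x = forward_step(x, beta_t)
# the signal is shrunk by sqrt(alpha_bar_T) ~ 0.37 while the noise std grows to ~ 0.93,
# so after 200 steps x is already close to pure Gaussian noise
print(x.mean(), x.std())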

(2) Reverse process

        the conditional distribution p(x_{t-1}|x_t)

        start by sampling random Gaussian noise x_T

        gradually denoise it

        until a sample x_0 from the real distribution is obtained

A neural network is trained to approximate the conditional distribution p_\theta(x_{t-1}|x_t)

        with network parameters \theta

Gaussian distribution parameters:

        a mean parameterized by \mu_\theta

        a variance parameterized by \Sigma_\theta

The reverse process formula is p_\theta(x_{t-1}|x_t) = N(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))

        the mean and variance depend on the noise level t

        the neural network learns these means and variances

        DDPM keeps the variance fixed

        so the neural network only learns the mean \mu_\theta of the conditional distribution

An objective function is derived to learn the mean of the reverse process

The combination of q and p_\theta can be seen as a variational auto-encoder (VAE)

        minimize the negative log-likelihood of the ground-truth data sample x_0

        the variational lower bound (ELBO) is a sum of per-time-step losses

                 L = L_0 + L_1 + ... + L_T

                each term L_t (except L_0) is a KL divergence between two Gaussians

                which reduces to an L2 loss with respect to the means!

A direct consequence of how the forward process is constructed:
x_t can be sampled at any noise level, conditioned on x_0

        define \alpha_t := 1 - \beta_t

        and \bar{\alpha}_t := \prod_{s=1}^{t} \alpha_s, so that        q(x_t|x_0) = N(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1-\bar{\alpha}_t) I)

Sampling Gaussian noise, scaling it appropriately, and adding it to x_0 gives x_t directly

\bar{\alpha}_t is a function of the known \beta_t variance schedule and can be precomputed (see the sketch below)

During training, t is sampled at random to optimize the stochastic term L_t of the loss L
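A small NumPy sketch of this shortcut (illustrative; same toy setup as above), drawing x_t in one shot from the precomputed \bar{\alpha}_t:

import numpy as np

rng = np.random.default_rng(0)

betas = np.linspace(1e-4, 0.02, 200)
alphas_bar = np.cumprod(1.0 - betas)    # precomputed once from the known schedule

def q_sample_closed_form(x0, t):
    # q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alphas_bar[t]) * x0 + np.sqrt(1.0 - alphas_bar[t]) * eps

x0 = np.ones((8, 8))
x_t = q_sample_closed_form(x0, t=100)   # jump straight to noise level t = 100
print(x_t.shape)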

Advantages

Reparameterize the mean

so that the neural network learns the added noise via the noise term in the KL loss

the network becomes a noise predictor rather than a mean predictor

The mean is then computed as: \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)

The objective L_t becomes: \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 = \left\| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, t) \right\|^2

                        for a random time step t, with \epsilon \sim N(0, I)

                        x_0 is the initial image

                        \epsilon is the pure noise sampled at time step t

                        \epsilon_\theta(x_t, t) is the neural network

The neural network is optimized with a simple mean squared error (MSE) between the true noise and the predicted Gaussian noise

The training algorithm (DDPM Algorithm 1) is thus: repeat until converged -- sample x_0 \sim q(x_0), t \sim Uniform(\{1, ..., T\}) and \epsilon \sim N(0, I), then take a gradient step on \nabla_\theta \left\| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, t) \right\|^2.
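A minimal NumPy sketch of one such training step; eps_theta here is a hypothetical placeholder standing in for the noise-prediction network built later in this note:

import numpy as np

rng = np.random.default_rng(0)
T = 200
betas = np.linspace(1e-4, 0.02, T)
alphas_bar = np.cumprod(1.0 - betas)

def eps_theta(x_t, t):
    # placeholder for the noise-prediction network epsilon_theta(x_t, t)
    return np.zeros_like(x_t)

def training_step(x0):
    t = rng.integers(0, T)                          # t ~ Uniform({1, ..., T})
    eps = rng.standard_normal(x0.shape)             # eps ~ N(0, I)
    x_t = np.sqrt(alphas_bar[t]) * x0 + np.sqrt(1.0 - alphas_bar[t]) * eps
    return np.mean((eps - eps_theta(x_t, t)) ** 2)  # MSE; in practice, take a gradient step on it

print(training_step(np.ones((8, 8))))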

4. A neural network to predict the noise (U-Net)

The neural network receives a noisy image at a particular time step and returns the predicted noise.

The predicted noise is a tensor with the same size/resolution as the input image.

So the network consumes and produces tensors of the same shape. Its structure resembles an autoencoder:

        the encoder encodes the image into a smaller hidden representation, the "bottleneck"

        the decoder decodes the bottleneck back into an actual image

Residual connections are used to improve gradient flow

The forward and reverse processes run over a finite number of time steps T (the DDPM paper uses T = 1000)

Starting at t = 0, a real image x_0 is sampled from the data distribution

Below, noise is added to a cat image for illustration

Forward process

        at each time step t, Gaussian noise is sampled

        and added to the image from the previous step

        given a sufficiently large T and a well-behaved noise-adding schedule

        at t = T the result is an isotropic Gaussian distribution

II. Environment setup

Install and import the required libraries: MindSpore, download, dataset, matplotlib, and tqdm.

%%capture captured_output
# mindspore==2.2.14 is preinstalled; change the version number below to switch mindspore versions
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# show the current mindspore version
!pip show mindspore

Output: (pip package information, omitted)

import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import download

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScaler

ms.set_seed(0)

III. Building the Diffusion model

1. Define helper functions and classes

def rearrange(head, inputs):
    b, hc, x, y = inputs.shape
    c = hc // head
    return inputs.reshape((b, head, c, x * y))

def rsqrt(x):
    res = ops.sqrt(x)
    return ops.inv(res)

def randn_like(x, dtype=None):
    if dtype is None:
        dtype = x.dtype
    res = ops.standard_normal(x.shape).astype(dtype)
    return res

def randn(shape, dtype=None):
    if dtype is None:
        dtype = ms.float32
    res = ops.standard_normal(shape).astype(dtype)
    return res

def randint(low, high, size, dtype=ms.int32):
    res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
    return res

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def _check_dtype(d1, d2):
    if ms.float32 in (d1, d2):
        return ms.float32
    if d1 == d2:
        return d1
    raise ValueError('dtype is not supported.')

class Residual(nn.Cell):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
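A quick, illustrative sanity check of these helpers (they wrap idioms that MindSpore does not provide directly):

print(default(None, 7))              # -> 7 (falls back to the default)
print(default(3, 7))                 # -> 3 (keeps the given value)
t = randint(0, 10, (4,))             # 4 random integer time steps in [0, 10)
x = randn((4, 3, 8, 8))              # standard-normal tensor
print(t.shape, randn_like(x).shape)  # (4,) (4, 3, 8, 8)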

2. Define aliases for the upsampling and downsampling operations

def Upsample(dim):
    return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)
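With kernel 4, stride 2 and padding 1, Downsample halves and Upsample doubles the spatial resolution; a quick shape check (illustrative):

x = randn((1, 16, 32, 32))
print(Downsample(16)(x).shape)  # (1, 16, 16, 16): spatial size halved
print(Upsample(16)(x).shape)    # (1, 16, 64, 64): spatial size doubled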

3. Position embeddings

The network encodes the time step t (the noise level) with sinusoidal position embeddings.

The SinusoidalPositionEmbeddings module

takes a tensor of shape (batch_size, 1) as input

        the noise levels of the noisy images in the batch

and turns it into a tensor of shape (batch_size, dim)

        where dim is the embedding dimensionality

This embedding is then added inside each residual block (a shape check follows the class below).

class SinusoidalPositionEmbeddings(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        emb = x[:, None] * self.emb[None, :]
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb
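Shape check for the embedding (illustrative):

t = Tensor([1., 10., 100.], ms.float32)   # three noise levels
emb = SinusoidalPositionEmbeddings(32)(t)
print(emb.shape)                          # (3, 32)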

4. ResNet/ConvNeXT blocks

Here ConvNeXT blocks are chosen as the building blocks of the U-Net model.

class Block(nn.Cell):
    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ConvNextBlock(nn.Cell):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)
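Shape check (illustrative): the block preserves spatial size, maps dim to dim_out channels, and accepts an optional time embedding:

block = ConvNextBlock(16, 32, time_emb_dim=8)
x = randn((2, 16, 28, 28))
t_emb = randn((2, 8))
print(block(x, t_emb).shape)  # (2, 32, 28, 28)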

5. Attention modules

Attention (multi-head self-attention)

        the standard scaled dot-product variant

LinearAttention

        a variant whose time and memory requirements scale linearly in the sequence length

class Attention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
​
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()
​
    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
​
        q = q * self.scale
​
        # 'b h d i, b h d j -> b h i j'
        sim = ops.bmm(q.swapaxes(2, 3), k)
        attn = ops.softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        out = ops.bmm(attn, v.swapaxes(2, 3))
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
​
        return self.to_out(out)
​
​
class LayerNorm(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')
​
    def construct(self, x):
        eps = 1e-5
        var = x.var(1, keepdims=True)
        mean = x.mean(1, keep_dims=True)
        return (x - mean) * rsqrt((var + eps)) * self.g
​
​
class LinearAttention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
​
        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim)
        )
​
        self.map = ops.Map()
        self.partial = ops.Partial()
​
    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
​
        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)
​
        q = q * self.scale
        v = v / (h * w)
​
        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)
​
        out = out.reshape((b, -1, h, w))
        return self.to_out(out)
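Both variants map a (b, dim, h, w) feature map to the same shape; LinearAttention contracts k with v first, so its cost is linear rather than quadratic in the number of pixels h*w. A quick check (illustrative):

x = randn((1, 32, 16, 16))
print(Attention(32)(x).shape)        # (1, 32, 16, 16)
print(LinearAttention(32)(x).shape)  # (1, 32, 16, 16)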

6. Group normalization

The U-Net interleaves its convolution/attention layers with group normalization.

Define a PreNorm class

        which applies GroupNorm before the attention layer

class PreNorm(nn.Cell):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)
​
    def construct(self, x):
        x = self.norm(x)
        return self.fn(x)

7. Conditional U-Net

The network \epsilon _\theta (x_t,t)

        input

                a batch of noisy images, of shape (batch_size, num_channels, height, width)

                a batch of noise levels, of shape (batch_size, 1)

        output

                the predicted noise, a tensor of shape (batch_size, num_channels, height, width)

8. Network construction

First, a convolutional layer is applied to the batch of noisy images

and position embeddings are computed for the noise levels

Then a sequence of downsampling stages is applied

        each downsampling stage consists of

                2 ResNet/ConvNeXT blocks

                group normalization

                attention

                a residual connection

                a downsampling operation

At the middle of the network, ResNet/ConvNeXT blocks are applied again,

interleaved with attention

Next, a sequence of upsampling stages is applied

        each upsampling stage consists of

                2 ResNet/ConvNeXT blocks

                group normalization

                attention

                a residual connection

                an upsampling operation

Finally, a ResNet/ConvNeXT block is applied,

followed by a convolutional layer (an end-to-end shape check follows the class below)

class Unet(nn.Cell):
    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()
​
        self.channels = channels
​
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
​
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
​
        block_klass = partial(ConvNextBlock, mult=convnext_mult)
​
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
​
        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)
​
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
​
            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )
​
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
​
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
​
            self.ups.append(
                nn.CellList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )
​
        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )
​
    def construct(self, x, time):
        x = self.init_conv(x)
​
        t = self.time_mlp(time) if exists(self.time_mlp) else None
​
        h = []
​
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
​
            x = downsample(x)
​
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
​
        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
​
            x = upsample(x)
        return self.final_conv(x)
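End-to-end shape check of the U-Net (illustrative; the sizes match the Fashion-MNIST setup used later in this note): a batch of noisy images plus time steps goes in, a predicted-noise tensor of the same shape comes out:

net = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
x = randn((2, 1, 28, 28))
t = Tensor([5, 100], ms.int32)
print(net(x, t).shape)  # (2, 1, 28, 28)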

IV. Forward diffusion

1. Define the schedule over T time steps

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)

We first use a linear schedule with T = 200 time steps

and define the various quantities derived from \beta _t

        such as the cumulative product \bar{\alpha}_t

        each is a 1-D tensor storing the values for t = 1, ..., T

        the extract function gathers the proper t index for a batch of indices

# diffuse for 200 steps
timesteps = 200

# define the beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# variance of the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# p2 loss reweighting with exponent 0, i.e. all time steps are weighted equally
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)
​
def extract(a, t, x_shape):
    b = t.shape[0]
    out = Tensor(a).gather(t, -1)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
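extract gathers each batch element's schedule value at its own t and reshapes the result to broadcast against an image batch; a quick check (illustrative):

t = Tensor([0, 99, 199], ms.int32)
coeff = extract(sqrt_alphas_cumprod, t, (3, 1, 28, 28))
print(coeff.shape)  # (3, 1, 1, 1), ready to broadcast over a (3, 1, 28, 28) batch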

2. Add noise to a cat image at each time step of the diffusion process

# download the cat image
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip (170 kB)

file_sizes: 100%|████████████████████████████| 174k/174k [00:00<00:00, 1.45MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

from PIL import Image
​
image = Image.open('./image_cat/jpg/000000039769.jpg')
base_width = 160
image = image.resize((base_width, int(float(image.size[1]) * float(base_width / float(image.size[0])))))
image.show()

Output: (the resized cat image is displayed)

Noise is added to MindSpore tensors rather than PIL images.

Define the image transformations, which

        convert a PIL image to a MindSpore tensor

        ToTensor divides by 255 to map pixels into [0, 1]; the final lambda then rescales them into the [-1, 1] range (assuming the source image consists of integers in {0, 1, ..., 255})

from mindspore.dataset import ImageFolderDataset
​
image_size = 128
transforms = [
    Resize(image_size, Inter.BILINEAR),
    CenterCrop(image_size),
    ToTensor(),
    lambda t: (t * 2) - 1
]
​
​
path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
                             extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)

Output:

(1, 3, 128, 128)

3. Define the reverse transform

It takes a tensor containing values in [−1, 1]

and outputs a PIL image.

import numpy as np
​
reverse_transform = [
    lambda t: (t + 1) / 2,
    lambda t: ops.permute(t, (1, 2, 0)), # CHW to HWC
    lambda t: t * 255.,
    lambda t: t.asnumpy().astype(np.uint8),
    ToPIL()
]
​
def compose(transform, x):
    for d in transform:
        x = d(x)
    return x

Verification:

reverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()

Output: (the reconstructed cat image is displayed)

4. Define the forward diffusion process

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

Test:

def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # turn back into a PIL image
    noisy_image = compose(reverse_transform, x_noisy[0])

    return noisy_image

# set the time step
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()

Output:

<PIL.Image.Image image mode=RGB size=128x128 at 0x7F54569F3950>

Visualize the result at several time steps:

import matplotlib.pyplot as plt
​
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        imgs = [imgs]
​
    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    _, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
​
    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])
​
    plt.tight_layout()

plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])

Define the loss function:

def p_losses(unet_model, x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = unet_model(x_noisy, t)

    loss = nn.SmoothL1Loss()(noise, predicted_noise)  # Huber loss between true and predicted noise
    loss = loss.reshape(loss.shape[0], -1)
    loss = loss * extract(p2_loss_weight, t, loss.shape)
    return loss.mean()
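A smoke test of the loss on random data (illustrative; toy_unet is an ad-hoc instance, not the model trained below):

toy_unet = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
x0 = randn((2, 1, 28, 28))
t = randint(0, timesteps, (2,), dtype=ms.int32)
print(p_losses(toy_unet, x0, t))  # a scalar loss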

V. Data preparation and processing

1. Download the dataset

Fashion-MNIST images

        linearly scaled into [−1, 1]

        all of the same size, 28x28

        randomly flipped horizontally

Download with download

and unzip to the path ./

# download the Fashion-MNIST dataset
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip (29.4 MB)

file_sizes: 100%|██████████████████████████| 30.9M/30.9M [00:00<00:00, 43.4MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

from mindspore.dataset import FashionMnistDataset
​
image_size = 28
channels = 1
batch_size = 16
​
fashion_mnist_dataset_dir = "./dataset"
dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train", num_parallel_workers=cpu_count(), shuffle=True, num_shards=1, shard_id=0)

2. Define the transform operations

Image preprocessing:

        random horizontal flips

        rescaling

        so that values lie in the [−1, 1] range

transforms = [
    RandomHorizontalFlip(),
    ToTensor(),
    lambda t: (t * 2) - 1
]
dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
dataset = dataset.batch(16, drop_remainder=True)

x = next(dataset.create_dict_iterator())
print(x.keys())

Output:

dict_keys(['image'])

3. Sampling

Sampling is used during training to inspect the model's progress.

Sampling (DDPM Algorithm 2) reverses the diffusion process:

        start at T by sampling pure Gaussian noise

        the neural network gradually denoises it using the learned conditional probabilities, finishing at time step t = 0

        using the reparameterization

                the noise predictor is plugged into the mean

        a slightly less noisy image x_{t-1} is derived at each step

        finally yielding an image that appears to come from the real data distribution

def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
​
    if t_index == 0:
        return model_mean
    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = randn_like(x)
    return model_mean + ops.sqrt(posterior_variance_t) * noise
​
def p_sample_loop(model, shape):
    b = shape[0]
    # start from pure noise
    img = randn(shape, dtype=None)
    imgs = []
​
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
        imgs.append(img.asnumpy())
    return imgs
​
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

VI. Training

# define the dynamic learning rate
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)
​
# define the Unet model
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
​
name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1
​
# define the optimizer
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)
​
# define the forward (loss) function
def forward_fn(data, t, noise=None):
    loss = p_losses(unet_model, data, t, noise)
    return loss
​
# compute gradients
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
​
# gradient update step
def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss
import time
​
# epochs is set to 1 to save time; adjust as needed (the output below comes from a 10-epoch run)
epochs = 1
​
for epoch in range(epochs):
    begin_time = time.time()
    for step, batch in enumerate(dataset.create_tuple_iterator()):
        unet_model.set_train()
        batch_size = batch[0].shape[0]
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)
​
        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
    # show randomly sampled results
    unet_model.set_train(False)
    samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")

Output:

 epoch:  0  step:  0  Loss:  0.43375123
 epoch:  0  step:  500  Loss:  0.113769315
 epoch:  0  step:  1000  Loss:  0.08649178
 epoch:  0  step:  1500  Loss:  0.067664884
 epoch:  0  step:  2000  Loss:  0.07234038
 epoch:  0  step:  2500  Loss:  0.043936778
 epoch:  0  step:  3000  Loss:  0.058127824
 epoch:  0  step:  3500  Loss:  0.049789283
training time: 922.3438229560852 s
 epoch:  1  step:  0  Loss:  0.05088563
 epoch:  1  step:  500  Loss:  0.051174678
 epoch:  1  step:  1000  Loss:  0.04455947
 epoch:  1  step:  1500  Loss:  0.055165425
 epoch:  1  step:  2000  Loss:  0.043942295
 epoch:  1  step:  2500  Loss:  0.03274461
 epoch:  1  step:  3000  Loss:  0.048117325
 epoch:  1  step:  3500  Loss:  0.063063145
training time: 937.5596783161163 s
 epoch:  2  step:  0  Loss:  0.052893892
 epoch:  2  step:  500  Loss:  0.05721748
 epoch:  2  step:  1000  Loss:  0.057248186
 epoch:  2  step:  1500  Loss:  0.048806388
 epoch:  2  step:  2000  Loss:  0.05007638
 epoch:  2  step:  2500  Loss:  0.04337231
 epoch:  2  step:  3000  Loss:  0.043207955
 epoch:  2  step:  3500  Loss:  0.034530163
training time: 947.6374666690826 s
 epoch:  3  step:  0  Loss:  0.04867614
 epoch:  3  step:  500  Loss:  0.051636297
 epoch:  3  step:  1000  Loss:  0.03338969
 epoch:  3  step:  1500  Loss:  0.0420174
 epoch:  3  step:  2000  Loss:  0.052145053
 epoch:  3  step:  2500  Loss:  0.03905913
 epoch:  3  step:  3000  Loss:  0.07621498
 epoch:  3  step:  3500  Loss:  0.06484105
training time: 957.7780408859253 s
 epoch:  4  step:  0  Loss:  0.046281893
 epoch:  4  step:  500  Loss:  0.03783619
 epoch:  4  step:  1000  Loss:  0.0587488
 epoch:  4  step:  1500  Loss:  0.06974746
 epoch:  4  step:  2000  Loss:  0.04299112
 epoch:  4  step:  2500  Loss:  0.027945498
 epoch:  4  step:  3000  Loss:  0.045338146
 epoch:  4  step:  3500  Loss:  0.06362417
training time: 955.6116819381714 s
 epoch:  5  step:  0  Loss:  0.04781142
 epoch:  5  step:  500  Loss:  0.032488734
 epoch:  5  step:  1000  Loss:  0.061507083
 epoch:  5  step:  1500  Loss:  0.039130375
 epoch:  5  step:  2000  Loss:  0.034972396
 epoch:  5  step:  2500  Loss:  0.039485026
 epoch:  5  step:  3000  Loss:  0.06690869
 epoch:  5  step:  3500  Loss:  0.05355365
training time: 951.7758958339691 s
 epoch:  6  step:  0  Loss:  0.04807706
 epoch:  6  step:  500  Loss:  0.021469856
 epoch:  6  step:  1000  Loss:  0.035354104
 epoch:  6  step:  1500  Loss:  0.044303045
 epoch:  6  step:  2000  Loss:  0.040063944
 epoch:  6  step:  2500  Loss:  0.02970439
 epoch:  6  step:  3000  Loss:  0.041152682
 epoch:  6  step:  3500  Loss:  0.02062454
training time: 955.2220208644867 s
 epoch:  7  step:  0  Loss:  0.029668871
 epoch:  7  step:  500  Loss:  0.028485576
 epoch:  7  step:  1000  Loss:  0.029675964
 epoch:  7  step:  1500  Loss:  0.052743085
 epoch:  7  step:  2000  Loss:  0.03664278
 epoch:  7  step:  2500  Loss:  0.04454907
 epoch:  7  step:  3000  Loss:  0.043067697
 epoch:  7  step:  3500  Loss:  0.0619511
training time: 952.6654670238495 s
 epoch:  8  step:  0  Loss:  0.055328347
 epoch:  8  step:  500  Loss:  0.035807922
 epoch:  8  step:  1000  Loss:  0.026412832
 epoch:  8  step:  1500  Loss:  0.051044375
 epoch:  8  step:  2000  Loss:  0.05474911
 epoch:  8  step:  2500  Loss:  0.044595096
 epoch:  8  step:  3000  Loss:  0.034082986
 epoch:  8  step:  3500  Loss:  0.02653109
training time: 961.9374921321869 s
 epoch:  9  step:  0  Loss:  0.039675284
 epoch:  9  step:  500  Loss:  0.046295933
 epoch:  9  step:  1000  Loss:  0.031403508
 epoch:  9  step:  1500  Loss:  0.028816734
 epoch:  9  step:  2000  Loss:  0.06530296
 epoch:  9  step:  2500  Loss:  0.051451046
 epoch:  9  step:  3000  Loss:  0.037913296
 epoch:  9  step:  3500  Loss:  0.030541396
training time: 974.643147945404 s
Training Success!

VII. Inference (sampling from the model)

To sample from the model, we simply use the sampling function defined above:

# sample 64 images
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)

Output:

sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]

# show one random sample
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

Output:

<matplotlib.image.AxesImage at 0x7f5175ea1690>

The model generates a piece of clothing!

Create a GIF of the denoising process:

import matplotlib.animation as animation
​
random_index = 53
​
fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])
​
animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()

Output: (the animated denoising GIF, saved as diffusion.gif)
