
Implementing Generative Adversarial Networks with MindSpore: GAN (2)

It's been a while since my last post! Today I'm making good on a promise: the MindSpore implementation of CGAN is finally here~~~
In the previous two posts, I introduced how to implement a GAN in MindSpore and used a DCGAN to generate handwritten digits:

Implementing Generative Networks with MindSpore (1)
Implementing Generative Networks with MindSpore (2)

But what those networks generate is random. Naturally, people started to wonder: can we make a GAN generate exactly the content we want? That is how CGAN came about.

A Brief Overview of CGAN

The core idea of CGAN is to take an attribute y as an extra input and feed it into both the discriminator and the generator, as shown in the figure below (image from the web):

As the figure shows, the discriminator now receives the attribute label y in addition to the sample x (for handwritten digit generation, y can be the one-hot encoding of the digit). The learning targets of the discriminator and generator thus become conditional distributions given y. In the discriminator, the condition y is attached to the input regardless of whether the sample is real or fake.
One more thing to note: the label y should be fused not only with z and x at the input, but also with the features at every layer of the discriminator and generator; otherwise the network may fail to learn the label properly.
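For reference, this conditioning corresponds to the standard CGAN objective from the original paper (Mirza and Osindero, 2014), where both players see the condition y:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]$$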

Next comes the exciting part: the code.

Code Implementation

First, import the required packages:

import os
import numpy as np
import matplotlib.pyplot as plt
from mindspore import nn
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore import context
from cells import SigmoidCrossEntropyWithLogits  # cells is a .py file defined in the previous two posts
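The cells module itself isn't shown in this post. For completeness, here is a minimal sketch of what SigmoidCrossEntropyWithLogits can look like, consistent with how it is used below (my reconstruction, not necessarily the exact code from the earlier posts):

import mindspore.nn as nn
import mindspore.ops.operations as P

class SigmoidCrossEntropyWithLogits(nn.Cell):
    """Mean sigmoid cross-entropy computed over raw (pre-sigmoid) logits."""
    def __init__(self):
        super().__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_mean = P.ReduceMean()

    def construct(self, logits, labels):
        loss = self.cross_entropy(logits, labels)
        return self.reduce_mean(loss)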

Next, define the discriminator and generator. Compared with the earlier posts, the main differences are the inputs and the construct function.
The generator's input now includes the label y in addition to the latent code z, and the same is true for the discriminator. As mentioned above, each layer's input is fused with the label, so P.Concat(1) is used to concatenate the features with y. The +10 in the code is because my labels are one-hot encoded 10-dimensional vectors; you could also keep the original integer encoding. A quick shape check follows.
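To make the +10 concrete, here is a standalone NumPy illustration of the shapes involved (not part of the training code):

import numpy as np

feat = np.zeros((128, 256), dtype=np.float32)                          # a batch of layer features
label = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 128)]   # one-hot labels, shape (128, 10)
fused = np.concatenate((feat, label), axis=1)                          # what P.Concat(1) does
print(fused.shape)  # (128, 266) -> the next Dense layer needs in_channels = 256 + 10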

class Discriminator(nn.Cell):
    def __init__(self, input_dims, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        # every layer consumes the previous features concatenated with the 10-dim one-hot label
        self.fc1 = nn.Dense(input_dims + 10, 256)
        self.fc2 = nn.Dense(256 + 10, 128)
        self.fc3 = nn.Dense(128 + 10, 1)
        self.lrelu = nn.LeakyReLU()
        self.concat = P.Concat(1)

    def construct(self, x, label):
        x = self.concat((x, label))
        x = self.fc1(x)
        x = self.lrelu(x)
        x = self.concat((x, label))
        x = self.fc2(x)
        x = self.lrelu(x)
        x = self.concat((x, label))
        x = self.fc3(x)
        return x


class Generator(nn.Cell):
    def __init__(self, input_dims, output_dim, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.fc1 = nn.Dense(input_dims + 10, 128)
        self.fc2 = nn.Dense(128 + 10, 256)
        self.fc3 = nn.Dense(256 + 10, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.concat = P.Concat(1)

    def construct(self, x, label):
        x = self.concat((x, label))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.concat((x, label))
        x = self.fc2(x)
        x = self.relu(x)
        x = self.concat((x, label))
        x = self.fc3(x)
        x = self.tanh(x)  # outputs in [-1, 1], matching the image normalization below
        return x
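As a quick sanity check of the shapes, a hypothetical snippet (assuming the imports above) would be:

z = Tensor(np.random.normal(size=(8, 100)).astype(np.float32))
y = Tensor(np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 8)])
netG = Generator(100, 28 * 28)
netD = Discriminator(28 * 28)
fake = netG(z, y)       # (8, 784), tanh output in [-1, 1]
score = netD(fake, y)   # (8, 1) raw logits, fed to the sigmoid cross-entropy loss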

Since the discriminator and generator have changed, DisWithLossCell and GenWithLossCell need corresponding updates, and so does TrainOneStepCell. The main change is simply adding label to the inputs and updating the call sites accordingly; nothing else is special.

class DisWithLossCell(nn.Cell):
    def __init__(self, netG, netD, loss_fn, auto_prefix=True):
        super(DisWithLossCell, self).__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netD = netD
        self.loss_fn = loss_fn

    def construct(self, real_data, latent_code, label):
        # real samples should score 1, generated samples should score 0
        real_out = self.netD(real_data, label)
        real_loss = self.loss_fn(real_out, F.ones_like(real_out))
        fake_data = self.netG(latent_code, label)
        fake_out = self.netD(fake_data, label)
        fake_loss = self.loss_fn(fake_out, F.zeros_like(fake_out))
        loss_D = real_loss + fake_loss
        return loss_D


class GenWithLossCell(nn.Cell):
    def __init__(self, netG, netD, loss_fn, auto_prefix=True):
        super(GenWithLossCell, self).__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netD = netD
        self.loss_fn = loss_fn

    def construct(self, latent_code, label):
        # the generator tries to make the discriminator output 1 on fake samples
        fake_data = self.netG(latent_code, label)
        fake_out = self.netD(fake_data, label)
        loss_G = self.loss_fn(fake_out, F.ones_like(fake_out))
        return loss_G


class TrainOneStepCell(nn.Cell):
    def __init__(self,
                 netG,
                 netD,
                 optimizerG: nn.Optimizer,
                 optimizerD: nn.Optimizer,
                 sens=1.0,
                 auto_prefix=True):
        super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix)
        # netG / netD here are the loss cells defined above
        self.netG = netG
        self.netG.set_grad()
        self.netG.add_flags(defer_inline=True)
        self.netD = netD
        self.netD.set_grad()
        self.netD.add_flags(defer_inline=True)
        self.weights_G = optimizerG.parameters
        self.optimizerG = optimizerG
        self.weights_D = optimizerD.parameters
        self.optimizerD = optimizerD
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        # distributed training support: reduce gradients across devices when needed
        self.reducer_flag = False
        self.grad_reducer_G = F.identity
        self.grad_reducer_D = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL,
                                  ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer_G = DistributedGradReducer(self.weights_G, mean, degree)
            self.grad_reducer_D = DistributedGradReducer(self.weights_D, mean, degree)

    def trainD(self, real_data, latent_code, label, loss, loss_net, grad,
               optimizer, weights, grad_reducer):
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = grad(loss_net, weights)(real_data, latent_code, label, sens)
        grads = grad_reducer(grads)
        return F.depend(loss, optimizer(grads))

    def trainG(self, latent_code, label, loss, loss_net, grad, optimizer,
               weights, grad_reducer):
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = grad(loss_net, weights)(latent_code, label, sens)
        grads = grad_reducer(grads)
        return F.depend(loss, optimizer(grads))

    def construct(self, real_data, latent_code, label):
        loss_D = self.netD(real_data, latent_code, label)
        loss_G = self.netG(latent_code, label)
        d_out = self.trainD(real_data, latent_code, label, loss_D, self.netD,
                            self.grad, self.optimizerD, self.weights_D,
                            self.grad_reducer_D)
        g_out = self.trainG(latent_code, label, loss_G, self.netG, self.grad,
                            self.optimizerG, self.weights_G,
                            self.grad_reducer_G)
        return d_out, g_out
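If the sens machinery in trainD/trainG looks opaque: with sens_param=True, the gradient function takes an extra sensitivity tensor, shaped like the loss, that seeds backpropagation, and filling it with 1.0 (as P.Fill does above) reproduces the plain gradient. A minimal standalone sketch of the pattern (my own illustration, not from the original post):

import numpy as np
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor, ParameterTuple
from mindspore.common import dtype as mstype

class GradNet(nn.Cell):
    """Returns d(net(x))/d(weights), seeded with an explicit sens tensor."""
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.weights = ParameterTuple(net.trainable_params())
        self.grad_op = C.GradOperation(get_by_list=True, sens_param=True)

    def construct(self, x, sens):
        return self.grad_op(self.net, self.weights)(x, sens)

net = nn.Dense(4, 1)
x = Tensor(np.ones((2, 4)), mstype.float32)
sens = Tensor(np.ones((2, 1)), mstype.float32)  # same shape/dtype as net(x); all ones = plain gradient
grads = GradNet(net)(x, sens)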

Next is the training part, which is almost identical to before. After every epoch I run a test that generates 4 columns of the digits 0 through 9. The results show that by controlling the input label y, we can get exactly the digit we want.

def create_dataset(data_path,
                   flatten_size,
                   batch_size,
                   repeat_size=1,
                   num_parallel_workers=1):
    mnist_ds = ds.MnistDataset(data_path)
    type_cast_op = CT.TypeCast(mstype.float32)
    onehot_op = CT.OneHot(num_classes=10)
    # one-hot encode the labels, then cast them to float32
    mnist_ds = mnist_ds.map(input_columns="label",
                            operations=onehot_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="label",
                            operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    # normalize the images to [-1, 1] and flatten them into 784-dim vectors
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=lambda x: ((x - 127.5) / 127.5).astype("float32"),
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=lambda x: (x.reshape((flatten_size, ))),
                            num_parallel_workers=num_parallel_workers)
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


def one_hot(num_classes=10, arr=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]):
    return np.eye(num_classes)[arr]


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
batch_size = 128
input_dim = 100
epochs = 100
lr = 0.001

dataset = create_dataset(os.path.join("./data/MNIST_Data", "train"),
                         flatten_size=28 * 28,
                         batch_size=batch_size,
                         num_parallel_workers=2)

netG = Generator(input_dim, 28 * 28)
netD = Discriminator(28 * 28)
loss = SigmoidCrossEntropyWithLogits()
netG_with_loss = GenWithLossCell(netG, netD, loss)
netD_with_loss = DisWithLossCell(netG, netD, loss)
optimizerG = nn.Adam(netG.trainable_params(), lr)
optimizerD = nn.Adam(netD.trainable_params(), lr)
net_train = TrainOneStepCell(netG_with_loss, netD_with_loss, optimizerG,
                             optimizerD)

dataset_helper = DatasetHelper(dataset, epoch_num=epochs, dataset_sink_mode=True)
net_train = connect_network_with_dataset(net_train, dataset_helper)
netG.set_train()
netD.set_train()

for epoch in range(epochs):
    step = 1
    for data in dataset_helper:
        imgs = data[0]
        label = data[1]
        latent_code = Tensor(np.random.normal(size=(batch_size, input_dim)),
                             dtype=mstype.float32)
        dout, gout = net_train(imgs, latent_code, label)
        if step % 100 == 0:
            print("epoch {} step {}, d_loss is {:.4f}, g_loss is {:.4f}".format(
                epoch, step, dout.asnumpy(), gout.asnumpy()))
        step += 1
    # after each epoch, generate 4 samples of every digit 0-9 and save the grid
    for digit in range(10):
        for i in range(4):
            latent_code = Tensor(np.random.normal(size=(1, input_dim)),
                                 dtype=mstype.float32)
            label = Tensor(one_hot(arr=[digit]), dtype=mstype.float32)
            gen_imgs = netG(latent_code, label).asnumpy()
            gen_imgs = gen_imgs.reshape((28, 28))
            plt.subplot(10, 4, digit * 4 + i + 1)
            plt.imshow(gen_imgs * 127.5 + 127.5, cmap="gray")
            plt.axis("off")
    plt.savefig("./images/{}.jpg".format(epoch))
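Note that ds.MnistDataset expects the raw MNIST files (train-images-idx3-ubyte and friends) under ./data/MNIST_Data/train. And if you want to reuse the trained generator later without retraining, a checkpoint round trip might look like this (my own addition; cgan_generator.ckpt is a hypothetical file name):

from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net

# after training: persist the generator weights
save_checkpoint(netG, "cgan_generator.ckpt")

# later / elsewhere: rebuild the generator and load the weights back
netG2 = Generator(input_dim, 28 * 28)
load_param_into_net(netG2, load_checkpoint("cgan_generator.ckpt"))
netG2.set_train(False)
z = Tensor(np.random.normal(size=(1, input_dim)), dtype=mstype.float32)
label = Tensor(one_hot(arr=[7]), dtype=mstype.float32)  # ask for the digit 7
img = netG2(z, label).asnumpy().reshape((28, 28))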

At first glance, the results look decent. If you replace the fully connected backbone with a DCGAN-style network, the results will be even better; feel free to try it.
That concludes the MindSpore GAN tutorials for now (only for now; I may implement other GANs later). Next I might tackle some other fun projects, such as pix2pix, playing Flappy Bird with DQN, and neural style transfer, to gradually enrich the MindSpore ecosystem.
Things have been hectic lately, so I'm not sure when I'll get to them. These three GAN posts were originally planned for a single week, but well... let's not talk about that...
The code for all three tutorials is available from the linked repository: mindspore实现gan.
Finally, please give it a star~~~ Thanks!

Original post: https://blog.csdn.net/jijianzhang2204/article/details/109897030
