WGAN的实现代码(pytorch版)

更新时间:2023-07-05 00:25:19 阅读: 评论:0

WGAN的实现代码(pytorch版)
WGAN的实现⽅法
⼀、GAN存在的问题
1、原始的GAN训练困难。需要很⼩⼼地平衡⽣成器和判别器的训练程度,如果判别器过强,会导致⽣成器梯度消失严重,难以进化,进⽽⼤⼤增加训练所需时间。
2、⽣成器和判别器的loss⽆法指⽰进程,也就是说,我们⽆法通过⽣成器与判别器的loss来判断我们⽣成的图像是否到达了我们所满意的情况。只能通过显⽰训练图像⾃⾏感受训练程度。
3、⽣成样本缺乏多样性。容易产⽣模型崩坏,即⽣成的图像中有着⼤量的重复图像。
⼆、WGAN的优点所在
1、彻底解决GAN训练不稳定的问题,不再需要⼩⼼平衡⽣成器和判别器的训练程度。
2、基本解决了collap mode的问题,确保了⽣成样本的多样性 。
3、训练过程中终于有⼀个像交叉熵、准确率这样的数值来指⽰训练的进程,这个数值越⼩代表GAN训
练得越好,代表⽣成器产⽣的图像质量越⾼。
4、以上⼀切好处不需要精⼼设计的⽹络架构,最简单的多层全连接⽹络就可以做到。(DCGAN依靠的是对判别器和⽣成器的架构进⾏实验枚举,最终找到⼀组⽐较好的⽹络架构设置)
三、改进流程
1、判别器最后⼀层去掉sigmoid。sigmoid函数容易出现梯度消失的情况。
2、⽣成器和判别器的loss不取log
3、每次更新判别器的参数之后把它们的绝对值截断到不超过⼀个固定常数c
4、不要⽤基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也⾏
算法流程如下
四、代码实现(pytorch)
代码添加了⼀些便于理解的注释
import argpar
tigerwoodsimport os
import numpy as np
anchoring
import math
import sysbaerman
ansforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datats
from torch.autograd import Variable
as nn
functional as F
import torch
os.makedirs("images", exist_ok=True)#⽅法是递归⽬录创建功能。如果exists_ok为Fal(默认值),则如果⽬标⽬录已存在,则引发OSError错误,True则不会
parr = argpar.ArgumentParr()
parr.add_argument("--n_epochs",type=int, default=200,help="number of epochs of training")
parr.add_argument("--batch_size",type=int, default=64,help="size of the batches")
parr.add_argument("--lr",type=float, default=0.00005,help="learning rate")
parr.add_argument("--n_cpu",type=int, default=8,help="number of cpu threads to u during batch generation")
parr.add_argument("--latent_dim",type=int, default=100,help="dimensionality of the latent space")
parr.add_argument("--img_size",type=int, default=28,help="size of each image dimension")
parr.add_argument("--channels",type=int, default=1,help="number of image channels")
parr.add_argument("--n_critic",type=int, default=5,help="number of training steps for discriminator per iter")
parr.add_argument("--clip_value",type=float, default=0.01,help="lower and upper clip value for dis
c. weights")
parr.add_argument("--sample_interval",type=int, default=400,help="interval betwen image samples")
opt = parr.par_args()
print(opt)
img_shape =(opt.channels, opt.img_size, opt.img_size)
cuda =True if torch.cuda.is_available()el Fal
class Generator(nn.Module):
def__init__(lf):
你令爱了不起
super(Generator, lf).__init__()
def block(in_feat, out_feat, normalize=True):
layers =[nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat,0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
*block(opt.latent_dim,128, normalize=Fal),
*block(128,256),
济南
*block(256,512),
杰克逊的歌
*block(512,1024),
nn.Linear(1024,int(np.prod(img_shape))),
nn.Tanh()
)
def forward(lf, z):
img = lf.model(z)
img = img.view(img.shape[0],*img_shape)
return img
class Discriminator(nn.Module):
def__init__(lf):
super(Discriminator, lf).__init__()
nn.Linear(int(np.prod(img_shape)),512),#np.prod(img_shape)连乘操作,长*宽*深度
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512,256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256,1),#改进1、判别器最后⼀层去掉sigmoid。sigmoid函数容易出现梯度消失的情况。
)
def forward(lf, img):
img_flat = img.view(img.shape[0],-1)
validity = lf.model(img_flat)
return validity
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
generator.cuda()
discriminator.cuda()
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datats.MNIST(
"../../data/mnist",configuration
train=True,
download=True,
transform=transforms.Compo([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])]),
),
batch_size=opt.batch_size,
shuffle=True,
)
# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)#改进4、不要⽤基于动量的优化算法(包括momentum和Adam),推荐RMSProp,S GD也⾏
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)
Tensor = torch.cuda.FloatTensor if cuda el torch.FloatTensor
# ----------
#  Training
# ----------
batches_done =0
north是什么意思for epoch in range(opt.n_epochs):
for i,(imgs, _)in enumerate(dataloader):
# Configure input
real_imgs = pe(Tensor))
# ---------------------
#  Train Discriminator
# ---------------------
_grad()
# Sample noi as generator input
z = Variable(Tensor(al(0,1,(imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
fake_imgs = generator(z).detach()#detach 的意思是,这个数据和⽣成它的计算图“脱钩”了,即梯度传到它那个地⽅就停了,不再继续往前传播#要少计算⼀次 generator 的所有参数的梯度,同时,也不必刻意保存⼀次计算图,占⽤不必要的内存。
# Adversarial loss
loss_D =-an(discriminator(real_imgs))+ an(discriminator(fake_imgs))#改进2、⽣成器和判别器的loss不取log
loss_D.backward()
optimizer_D.step()#只更新discriminator的参数
# Clip weights of discriminator
for p in discriminator.parameters():
p.data.clamp_(-opt.clip_value, opt.clip_value)
# Train the generator every n_critic iterations
if i % opt.n_critic ==0:
# -----------------
#  Train Generator
# -----------------
_grad()
# Generate a batch of images
gen_imgs = generator(z)
# Adversarial loss
loss_G =-an(discriminator(gen_imgs))
loss_G.backward()
optimizer_G.step()#只更新 generator 的参数
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
%(epoch, opt.n_epochs, batches_done %len(dataloader),len(dataloader), loss_D.item(), loss_G.item()) )
if batches_done % opt.sample_interval ==0:
save_image(gen_imgs.data[:25],"images/%d.png"% batches_done, nrow=5, normalize=True)
batches_done +=1
注意此处
fake_imgs = generator(z).detach()
要少计算⼀次 generator 的所有参数的梯度,同时,也不必刻意保存⼀次计算图,占⽤不必要的内存。discriminator ⽐ generator 简单才使⽤此⽅案(通常也是如此)
isa

本文发布于:2023-07-05 00:25:19,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/90/167291.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:训练   判别   成器   梯度
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图