最简单易懂的GAN(生成对抗网络)教程:从理论到实践(附代码)

更新时间:2023-07-14 03:02:53 阅读: 评论:0

最简单易懂的GAN(⽣成对抗⽹络)教程:从理论到实践(附
代码)
之前
GAN⽹络是近两年深度学习领域的新秀,⽕的不⾏,本⽂旨在浅显理解传统GAN,分享学习⼼得。现有GAN⽹络⼤多数代码实现使⽤Python、torch等语⾔,这⾥,后⾯⽤matlab搭建⼀个简单的GAN⽹络,便于理解GAN原理。
头寸调拨GAN的⿐祖之作是2014年NIPS⼀篇⽂章:,可以细细品味。
●分享⼀个⽬前各类GAN的⼀个论⽂整理集合
●再分享⼀个⽬前各类GAN的⼀个代码整理集合
开始
我们知道GAN的思想是是⼀种⼆⼈零和博弈思想(two-player game),博弈双⽅的利益之和是⼀个常数,⽐如两个⼈掰⼿腕,假设总的空间是⼀定的,你的⼒⽓⼤⼀点,那你就得到的空间多⼀点,相应的
我的空间就少⼀点,相反我⼒⽓⼤我就得到的多⼀点,但有⼀点是确定的就是,我两的总空间是⼀定的,这就是⼆⼈博弈,但是呢总利益是⼀定的。
引申到GAN⾥⾯就是可以看成,GAN中有两个这样的博弈者,⼀个⼈名字是⽣成模型(G),另⼀个⼈名字是判别模型(D)。他们各⾃有各⾃的功能。
相同点是:
●这两个模型都可以看成是⼀个⿊匣⼦,接受输⼊然后有⼀个输出,类似⼀个函数,⼀个输⼊输出映射。
不同点是:
●⽣成模型功能:⽐作是⼀个样本⽣成器,输⼊⼀个噪声/样本,然后把它包装成⼀个逼真的样本,也就是输出。
●判别模型:⽐作⼀个⼆分类器(如同0-1分类器),来判断输⼊的样本是真是假。(就是输出值⼤于0.5还是⼩于0.5);
直接上⼀张个⼈觉得解释的好的图说明:
在之前,我们⾸先明⽩在使⽤GAN的时候的2个问题
●我们有什么?
我成长了作文
⽐如上⾯的这个图,我们有的只是真实采集⽽来的⼈脸样本数据集,仅此⽽已,⽽且很关键的⼀点是我们连⼈脸数据集的类标签都没有,也就是我们不知道那个⼈脸对应的是谁。
●我们要得到什么
⾄于要得到什么,不同的任务得到的东西不⼀样,我们只说最原始的GAN⽬的,那就是我们想通过输⼊⼀个噪声,模拟得到⼀个⼈脸图像,这个图像可以⾮常逼真以⾄于以假乱真。有关读书的作文
好了再来理解下GAN的两个模型要做什么。⾸先判别模型,就是图中右半部分的⽹络,直观来看就是⼀个简单的神经⽹络结构,输⼊就是⼀副图像,输出就是⼀个概率值,⽤于判断真假使⽤(概率值⼤于0.5那就是真,⼩于0.5那就是假),真假也不过是⼈们定义的概率⽽已。其次是⽣成模型,⽣成模型要做什么呢,同样也可以看成是⼀个神经⽹络模型,输⼊是⼀组随机数Z,输出是⼀个图像,不再是⼀个数值⽽已。从图中可以看到,会存在两个数据集,⼀个是真实数据集,这好说,另⼀个是假的数据集,那这个数据集就是有⽣成⽹络造出来的数据集。好了根据这个图我们再来理解⼀下GAN的⽬标是要⼲什么:
●判别⽹络的⽬的:就是能判别出来属于的⼀张图它是来⾃真实样本集还是假样本集。假如输⼊的是真样本,⽹络输出就接近1,输⼊的是假样本,⽹络输出接近0,那么很完美,达到了很好判别的⽬的。
●⽣成⽹络的⽬的:⽣成⽹络是造样本的,它的⽬的就是使得⾃⼰造样本的能⼒尽可能强,强到什么程度呢,你判别⽹络没法判断我是真样本还是假样本。
有了这个理解我们再来看看为什么叫做对抗⽹络了。判别⽹络说,我很强,来⼀个样本我就知道它是来⾃真样本集还是假样本集。⽣成⽹络就不服了,说我也很强,我⽣成⼀个假样本,虽然我⽣成⽹络知道是假的,但是你判别⽹络不知道呀,我包装的⾮常逼真,以⾄于判别⽹络⽆法判断真假,那么⽤
输出数值来解释就是,⽣成⽹络⽣成的假样本进去了判别⽹络以后,判别⽹络给出的结果是⼀个接近0.5的值,极限情况就是0.5,也就是说判别不出来了,这就是纳什平衡了。
由这个分析可以发现,⽣成⽹络与判别⽹络的⽬的正好是相反的,⼀个说我能判别的好,⼀个说我让你判别不好。所以叫做对抗,叫做博弈。那么最后的结果到底是谁赢呢?这就要归结到设计者,也就是我们希望谁赢了。作为设计者的我们,我们的⽬的是要得到以假乱真的样本,那么很⾃然的我们希望⽣成样本赢了,也就是希望⽣成样本很真,判别⽹络能⼒不⾜以区分真假样本位置。
再理解
知道了GAN⼤概的⽬的与设计思路,那么⼀个很⾃然的问题来了就是我们该如何⽤数学⽅法解决这么⼀个对抗问题。这就涉及到如何训练这样⼀个⽣成对抗⽹络模型了,还是先上⼀个图,⽤图来解释最直接:
需要注意的是⽣成模型与对抗模型可以说是完全独⽴的两个模型,好⽐就是完全独⽴的两个神经⽹络模型,他们之间没有什么联系。
好了那么训练这样的两个模型的⼤⽅法就是:单独交替迭代训练。
什么意思?因为是2个⽹络,不好⼀起训练,所以才去交替迭代训练,我们⼀⼀来看。
假设现在⽣成⽹络模型已经有了(当然可能不是最好的⽣成⽹络),那么给⼀堆随机数组,就会得到⼀堆假的样本集(因为不是最终的⽣成模型,那么现在⽣成⽹络可能就处于劣势,导致⽣成的样本就不咋地,可能很容易就被判别⽹络判别出来了说这货是假冒的),但是先不管这个,假设我们现在有了这样的假样本集,真样本集⼀直都有,现在我们⼈为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,很明显这⾥我们就已经默认真样本集所有的类标签都为1,⽽假样本集的所有类标签都为0. 有⼈会说,在真样本集⾥⾯的⼈脸中,可能张三⼈脸和李四⼈脸不⼀样呀,对于这个问题我们需要理解的是,我们现在的任务是什么,我们是想分样本真假,⽽不是分真样本中那个是张三label、那个是李四label。况且我们也知道,原始真样本的label我们是不知道的。回过头来,我们现在有了真样本集以及它们的label(都是1)、假样本集以及它们的label(都是0),这样单就判别⽹络来说,此时问题就变成了⼀个再简单不过的有监督的⼆分类问题了,直接送到神经⽹络模型中训练就完事了。假设训练完了,下⾯我们来看⽣成⽹络。
对于⽣成⽹络,想想我们的⽬的,是⽣成尽可能逼真的样本。那么原始的⽣成⽹络⽣成的样本你怎么知道它真不真呢?就是送到判别⽹络中,所以在训练⽣成⽹络的时候,我们需要联合判别⽹络⼀起才能达到训练的⽬的。什么意思?就是如果我们单单只⽤⽣成⽹络,那么想想我们怎么去训练?误差来源在哪⾥?细想⼀下没有,但是如果我们把刚才的判别⽹络串接在⽣成⽹络的后⾯,这样我们就知道真假了,也就有了误差了。所以对于⽣成⽹络的训练其实是对⽣成-判别⽹络串接的训练,就像图中显⽰的那样。好了那么现在来分析⼀下样本,原始的噪声数组Z我们有,也就是⽣成了假样本我们有,此时很关键的⼀点来了,我们要把这些假样本的标签都设置为1,也就是认为这些假样本在⽣成⽹络训练的时候是真样本。那么为什么要这样呢?我们想想,是不是这样才能起到迷惑判别器的⽬的,也才能使得⽣成的假样本逐渐逼近为正样本。好了,重新顺⼀下思路,现在对于⽣成⽹络的训练,我们有了样本集(只有假样本集,没有真样本集),有了对应的label(全为1),是不是就可以训练了?有⼈会问,这样只有⼀类样本,训练啥呀?谁说⼀类样本就不能训练了?只要有误差就⾏。还有⼈说,你这样⼀训练,判别⽹络的⽹络参数不是也跟着变吗?没错,这很关键,所以在训练这个串接的⽹络的时候,⼀个很重要的操作就是不要判别⽹络的参数发⽣变化,也就是不让它参数发⽣更新,只是把误差⼀直传,传到⽣成⽹络那块后更新⽣成⽹络的参数。这样就完成了⽣成⽹络的训练了。
在完成⽣成⽹络训练好,那么我们是不是可以根据⽬前新的⽣成⽹络再对先前的那些噪声Z⽣成新的假样本了,没错,并且训练后的假样本应该是更真了才对。然后⼜有了新的真假样本集(其实是新的
假样本集),这样⼜可以重复上述过程了。我们把这个过程称作为单独交替训练。我们可以实现定义⼀个迭代次数,交替迭代到⼀定次数后停⽌即可。这个时候我们再去看⼀看噪声Z⽣成的假样本会发现,原来它已经很真了。
看完了这个过程是不是感觉GAN的设计真的很巧妙,个⼈觉得最值得称赞的地⽅可能在于这种假样本在训练过程中的真假变换,这也是博弈得以进⾏的关键之处。
进⼀步
⽂字的描述相信已经让⼤多数的⼈知道了这个过程,下⾯我们来看看原⽂中⼏个重要的数学公式描述,⾸先我们直接上原始论⽂中的⽬标公式吧:
上述这个公式说⽩了就是⼀个最⼤最⼩优化问题,其实对应的也就是上述的两个优化过程。有⼈说如果不看别的,能达看到这个公式就拍案叫绝的地步,那就是机器学习的顶级专家,哈哈,真是前路漫漫。同时也说明这个简单的公式意义重⼤。
这个公式既然是最⼤最⼩的优化,那就不是⼀步完成的,其实对⽐我们的分析过程也是这样的,这⾥现优化D,然后在取优化G,本质上是两个优化问题,把拆解就如同下⾯两个公式:
优化D:
优化G:
初中生赚钱
可以看到,优化D的时候,也就是判别⽹络,其实没有⽣成⽹络什么事,后⾯的G(z)这⾥就相当于已经得到的假样本。优化D的公式的第⼀项,使的真样本x输⼊的时候,得到的结果越⼤越好,可以理解,因为需要真样本的预测结果越接近于1越好嘛。对于假样本,需要优化是的其结果越⼩越好,也就是D(G(z))越⼩越好,因为它的标签为0。但是呢第⼀项是越⼤,第⼆项是越⼩,这不⽭盾了,所以呢把第⼆项改成1-D(G(z)),这样就是越⼤越好,两者合起来就是越⼤越好。那么同样在优化G的时候,这个时候没有真样本什么事,所以把第⼀项直接却掉了。这个时候只有假样本,但是我们说这个时候是希望假样本的标签是1的,所以是D(G(z))越⼤越好,但是呢为了统⼀成1-D(G(z))的形式,那么只能是最⼩化1-D(G(z)),本质上没有区别,只是为了形式的统⼀。之后这两个优化模型可以合并起来写,就变成了最开始的那个最⼤最⼩⽬标函数了。
所以回过头来我们来看这个最⼤最⼩⽬标函数,⾥⾯包含了判别模型的优化,包含了⽣成模型的以假
乱真的优化,完美的阐释了这样⼀个优美的理论。
再进⼀步
有⼈说GAN强⼤之处在于可以⾃动的学习原始真实样本集的数据分布,不管这个分布多么的复杂,只要训练的⾜够好就可以学出来。针对这⼀点,感觉有必要好好理解⼀下为什么别⼈会这么说。
我们知道,传统的机器学习⽅法,我们⼀般都会定义⼀个什么模型让数据去学习。⽐如说假设我们知道原始数据属于⾼斯分布呀,只是不知道⾼斯分布的参数,这个时候我们定义⾼斯分布,然后利⽤数据去学习⾼斯分布的参数得到我们最终的模型。再⽐如说我们定义⼀个分类器,⽐如SVM,然后强⾏让数据进⾏东变西变,进⾏各种⾼维映射,最后可以变成⼀个简单的分布,SVM可以很轻易的进⾏⼆分类分开,其实SVM已经放松了这种映射关系了,但是也是给了⼀个模型,这个模型就是核映射(什么径向基函数等等),说⽩了其实也好像是你事先知道让数据该怎么映射⼀样,只是核映射的参数可以学习罢了。所有的这些⽅法都在直接或者间接的告诉数据你该怎么映射⼀样,只是不同的映射⽅法能⼒不⼀样。那么我们再来看看GAN,⽣成模型最后可以通过噪声⽣成⼀个完整的真实数据(⽐如⼈脸),说明⽣成模型已经掌握了从随机噪声到⼈脸数据的分布规律了,有了这个规律,想⽣成⼈脸还不容易。然⽽这个规律我们开始知道吗?显然不知道,如果让你说从随机噪声到⼈脸应该服从什么分布,你不可能知道。这是⼀层层映射之后组合起来的⾮常复杂的分布映射规律。然⽽GAN的机制可以学习到,也就是说GAN学习到了真实样本集的数据分布。
再拿原论⽂中的⼀张图来解释:
这张图表明的是GAN的⽣成⽹络如何⼀步步从均匀分布学习到正太分布的。原始数据x服从正太分布,这个过程你也没告诉⽣成⽹络说你得⽤正太分布来学习,但是⽣成⽹络学习到了。假设你改⼀下x的分布,不管什么分布,⽣成⽹络可能也能学到。这就是GAN可以⾃动学习真实数据的分布的强⼤之处。
还有⼈说GAN强⼤之处在于可以⾃动的定义潜在损失函数。什么意思呢,这应该说的是判别⽹络可以⾃动学习到⼀个好的判别⽅法,其实就是等效的理解为可以学习到好的损失函数,来⽐较好或者不好的判别出来结果。虽然⼤的loss函数还是我们⼈为定义的,基本上对于多数GAN也都这么定义就可以了,但是判别⽹络潜在学习到的损失函数隐藏在⽹络之中,不同的问题这个函数就不⼀样,所以说可以⾃动学习这个潜在的损失函数。
开始做⼩实验
本节主要实验⼀下如何通过随机数组⽣成mnist图像。mnist⼿写体数据库应该都熟悉的。这⾥简单的使⽤matlab来实现,⽅便看到整个实现过程。这⾥⽤到了⼀个⼯具箱,关于该⼯具箱的⼀些其他使⽤。
⽹络结构很简单,就定义成下⾯这样⼦:
将上述⼯具箱添加到路径,然后运⾏下⾯代码:
clc
clc
clear
%% 构造真实训练样本 60000个样本 1*784维(28*28展开)传统的反义词
load mnist_uint8;
train_x = double(train_x(1:60000,:)) / 255;
% 真实样本认为为标签 [1 0];⽣成样本为[0 1];
train_y = double(ones(size(train_x,1),1));
% normalize
train_x = mapminmax(train_x, 0, 1);
rand('state',0)
%% 构造模拟训练样本 60000个样本 1*100维
test_x = normrnd(0,1,[60000,100]); % 0-255的整数
test_x = mapminmax(test_x, 0, 1);
test_y = double(zeros(size(test_x,1),1));
test_y_rel = double(ones(size(test_x,1),1));
%%
nn_G_t = nntup([100 784]);
画画步骤
nn_G_t.activation_function = 'sigm';
nn_G_t.output = 'sigm';
nn_D = nntup([784 100 1]);
nn_D.weightPenaltyL2 = 1e-4;  %  L2 weight decay
nn.dropoutFraction = 0.5;  %  Dropout fraction
nn.learningRate = 0.01;                %  Sigm require a lower learning rate
萧七
nn_D.activation_function = 'sigm';
nn_D.output = 'sigm';
% nn_D.weightPenaltyL2 = 1e-4;  %  L2 weight decay
nn_G = nntup([100 784 100 1]);
nn_D.weightPenaltyL2 = 1e-4;  %  L2 weight decay
nn.dropoutFraction = 0.5;  %  Dropout fraction
nn.learningRate = 0.01;                %  Sigm require a lower learning rate
nn_G.activation_function = 'sigm';
nn_G.output = 'sigm';
% nn_G.weightPenaltyL2 = 1e-4;  %  L2 weight decay
opts.numepochs =  1;        %  Number of full sweeps through data
opts.batchsize = 100;      %  Take a mean gradient step over this many samples %%
num = 1000;
tic
for each = 1:1500
%----------计算G的输出:假样本-------------------
for i = 1:length(nn_G_t.W)  %共享⽹络参数
nn_G_t.W{i} = nn_G.W{i};
end
G_output = nn_G_out(nn_G_t, test_x);
%-----------训练D------------------------------
index = randperm(60000);
train_data_D = [train_x(index(1:num),:);G_output(index(1:num),:)];
train_y_D = [train_y(index(1:num),:);test_y(index(1:num),:)];
nn_D = nntrain(nn_D, train_data_D, train_y_D, opts);%训练D
%-----------训练G-------------------------------
for i = 1:length(nn_D.W)  %共享训练的D的⽹络参数
肌酸激酶低nn_G.W{length(nn_G.W)-i+1} = nn_D.W{length(nn_D.W)-i+1};
end
%训练G:此时假样本标签为1,认为是真样本
nn_G = nntrain(nn_G, test_x(index(1:num),:), test_y_rel(index(1:num),:), opts); end
toc
for i = 1:length(nn_G_t.W)
nn_G_t.W{i} = nn_G.W{i};

本文发布于:2023-07-14 03:02:53,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/89/1080648.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:样本   模型   判别   学习   训练   分布   知道
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图