TensorFlow自编码器(AE)实战

更新时间:2023-06-09 10:42:36 阅读: 评论:0

TensorFlow 自编码器(AE)实战
# Part 1: code taken from the book "TensorFlow" by Long Liangqu.
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# The Keras loss module is named `losses`; `loss` does not exist.
from tensorflow.keras import losses, optimizers, Sequential
1
2
3
4
5
# Load the Fashion MNIST image dataset.
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize: cast to float32 and divide by 255.0.
train_x, test_x = train_x.astype(np.float32) / 255.0, test_x.astype(np.float32) / 255.0

EPOCHES = 2
batch_size = 64
learning_rate = 0.0001

# Only the images are needed to build the dataset object; labels are not used.
train_db = tf.data.Dataset.from_tensor_slices(train_x)
train_db = train_db.shuffle(10000)
train_db = train_db.batch(batch_size)
# train_db = train_db.repeat(5)

# Build the test dataset object.
test_db = tf.data.Dataset.from_tensor_slices(test_x)
test_db = test_db.batch(batch_size)
1
2
3
4
5
6
7
8
9
10
11
class AE(tf.keras.Model):
    """Fully connected autoencoder made of an encoder and a decoder sub-network.

    The encoder ends in a 20-unit Dense layer and the decoder ends in a
    784-unit Dense layer with no final activation, so `call` returns raw
    outputs (the caller applies the sigmoid / logits-aware loss).
    """

    def __init__(self):
        super(AE, self).__init__()
        # Encoder network, created in the initializer of the autoencoder class.
        self.encoder = Sequential([
            tf.keras.layers.Dense(256),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(128),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(20),
        ])
        # Decoder network.
        self.decoder = Sequential([
            tf.keras.layers.Dense(128),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(256),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Forward pass: encode `inputs`, then decode the hidden vector."""
        # Encode to the hidden vector h: [b, 784] -> [b, 20].
        out = self.encoder(inputs)
        # Decode to the reconstructed image: [b, 20] -> [b, 784].
        out_put = self.decoder(out)
        return out_put
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
进口英文19
20
21
22
23
24
25
26
# Create the network object.
model = AE()
# Specify the input size.
model.build(input_shape=(4, 784))
# Print the network summary.
model.summary()

# Create the optimizer and set its learning rate.
optimizer = optimizers.Adam(learning_rate=learning_rate)
1
# Save images.
def Save_Image(img, filename):
    """Paste the images in `img` onto a 280x280 grayscale canvas and save it.

    Images are placed at 28-pixel spacing (a 10x10 grid), filling each column
    from top to bottom before moving one column to the right.

    Args:
        img: sequence of 2-D arrays accepted by `Image.fromarray(..., mode='L')`.
        filename: destination path passed to `Image.save`.

    If `img` holds fewer than 100 images the remaining cells stay black
    instead of raising IndexError; images beyond the 100th are ignored.
    """
    canvas = Image.new('L', (280, 280))
    # Top-left corners in fill order: x is the outer loop, y the inner loop.
    corners = [(x, y) for x in range(0, 280, 28) for y in range(0, 280, 28)]
    # zip stops at the shorter sequence, so a short batch cannot overrun img.
    for tile, corner in zip(img, corners):
        canvas.paste(Image.fromarray(tile, mode='L'), corner)
    canvas.save(filename)
1
2
3
4
5
6
7
8林氏
9
10
复兴遂州11
12
# Loss values sampled every 100 training steps, used for the plot below.
LOSS = []
for epoch in range(EPOCHES):
    for step, x in enumerate(train_db):
        # Flatten each image: [b, 28, 28] -> [b, 784].
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            # Forward pass to obtain the reconstructed images.
            out = model(x)
            # Loss between the reconstruction and the input; the model output
            # is treated as logits.
            loss = tf.losses.binary_crossentropy(x, out, from_logits=True)
            loss = tf.reduce_mean(loss)
        # Automatic differentiation; covers the gradients of both sub-networks.
        grads = tape.gradient(loss, model.trainable_variables)
        # Apply the update to both sub-networks at once.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if step % 100 == 0:
            LOSS.append(float(loss))
            print(epoch, step, float(loss))

    # After each epoch, reconstruct the first test batch and save an image grid.
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    x_hat = tf.sigmoid(logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    x_concat = tf.concat([x[:50], x_hat[:50], x_hat], axis=0)
    # Scale back to 0-255 and convert to uint8 for PIL.
    x_concat = x_concat.numpy() * 255.0
    x_concat = x_concat.astype(np.uint8)
    Save_Image(x_concat, r'E:\python 教学\图像识别\⾃编码器\images_AE\\%d.png' % epoch)

# Plot the sampled training loss.
plt.figure(figsize=(6, 6))
x = list(range(len(LOSS)))
plt.plot(x, LOSS, label='loss', linestyle='-', color='blue')
plt.xlabel('X')
plt.ylabel('loss')
plt.legend()
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
speak的过去式和过去分词
31
32
33
34
35
36
实验效果:以上在写代码的时候需要注意一个问题:就是在Save_Image 中的图像点阵图大小的问题,Save_Image中定义的步长为28,所以最后输出的图像是10*10的图像阵列,每一次x=next(iter(test_db))迭代得到图像是(32,28,28),意味着每一次的迭代得到的图像数为32,通过后面的拼接:x_concat=tf.concat([x[:50],x_hat[:50]],axis=0)之后的图像为(64,28,28),虽然这个地方写的是x[:50]读取前50张图像,但是实际上得到的是前32张,所以10*10张图像肯定是超出了索引的范围,也就是在运行代码的时候会出现:索引越界的错误。
所以这个地方该怎么改呢:可以将Save_Image中的步长设置为56,增加步长,只不过得到的图像数少一点,25张。将这个x_concat=tf.concat([x[:50],x_hat[:50]],axis=0)可以修改为这个x_concat=tf.concat([x[:15],x_hat[:15]],axis=0)。就可以了,只要索引的范围不越界就行。
第二部分:对上面的代码进行改进
import tensorflow as tf
import numpy as np
from PIL import Image
# The Keras loss module is named `losses`; `loss` does not exist.
from tensorflow.keras import losses, optimizers, Sequential
1
2
3
# Load the Fashion MNIST image dataset.
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize: cast to float32 and divide by 255.0.
train_x, test_x = train_x.astype(np.float32) / 255.0, test_x.astype(np.float32) / 255.0

# Flatten every training image to a 784-element row.
train_x = tf.reshape(train_x, [-1, 784])
print(np.shape(train_x))

batch_size = 64
learning_rate = 0.0001
1
2
# Encoder: Dense 256 -> 128 -> 20 with ReLU between the layers.
encoder = Sequential([
    tf.keras.layers.Dense(256),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(128),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(20),
])

# Decoder: Dense 128 -> 256 -> 784 with ReLU between the layers and no
# activation on the last layer.
decoder = Sequential([
    tf.keras.layers.Dense(128),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(256),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(784),
])

# Full autoencoder: the encoder followed by the decoder.
autoencoder = Sequential([
    encoder,
    decoder,
])
1
2
3
4
5
6
7
8
9
10
11
蔬菜牛肉粥
12
13
14
15
16
17
18
19
# Create the optimizer and set its learning rate.
optimizer = optimizers.Adam(learning_rate=learning_rate)


# Save images.
def Save_Image(img, filename):
    """Paste the images in `img` onto a 280x280 grayscale canvas and save it.

    Images are placed at 56-pixel spacing (a 5x5 grid of positions), filling
    each column from top to bottom before moving one column to the right.

    Args:
        img: sequence of 2-D arrays accepted by `Image.fromarray(..., mode='L')`.
        filename: destination path passed to `Image.save`.

    If `img` holds fewer than 25 images the remaining positions stay black
    instead of raising IndexError; images beyond the 25th are ignored.
    """
    canvas = Image.new('L', (280, 280))
    # Top-left corners in fill order: x is the outer loop, y the inner loop.
    corners = [(x, y) for x in range(0, 280, 56) for y in range(0, 280, 56)]
    # zip stops at the shorter sequence, so a short batch cannot overrun img.
    for tile, corner in zip(img, corners):
        canvas.paste(Image.fromarray(tile, mode='L'), corner)
    canvas.save(filename)
1
2
3
4
5
6
牛奶棒7
8
9
10
11
自己的家
12
# The decoder's last layer has no activation, so its outputs are logits; the
# loss must be told so (the sigmoid is only applied below, at prediction time).
autoencoder.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
autoencoder.fit(train_x, train_x, batch_size=batch_size * 2, epochs=2, verbose=1)

# Build the test dataset object here; this part of the script does not define
# one before this point.
test_db = tf.data.Dataset.from_tensor_slices(test_x).batch(batch_size)

for epoch in range(2):
    # Take the first test batch.
    x = next(iter(test_db))
    print(np.shape(x))
    # Reconstruct with the trained autoencoder (flattened input: [b, 784]).
    logits = autoencoder.predict(tf.reshape(x, [-1, 784]))
    x_hat = tf.sigmoid(logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    # 15 originals followed by their 15 reconstructions.
    x_concat = tf.concat([x[:15], x_hat[:15]], axis=0)
    print(np.shape(x_concat))
    # Scale back to 0-255 and convert to uint8 for PIL.
    x_concat = x_concat.numpy() * 255.0
    x_concat = x_concat.astype(np.uint8)
    Save_Image(x_concat, r'E:\python 教学\图像识别\⾃编码器\images_AE\%d.png' % epoch)
2
3
4
5
6
7
8
9
10
11
12

本文发布于:2023-06-09 10:42:36,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/89/1030498.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:图像   得到   代码   获得   需要
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图