用keras创建拟合网络解决回归问题Regression

更新时间:2023-06-09 11:14:39 阅读: 评论:0

⽤keras创建拟合⽹络解决回归问题Regression
实现了正弦曲线的拟合,即regression问题。
创建的模型单输⼊单输出,两个隐层分别为100、50个神经元。
在keras的官⽅⽂档中,给的例⼦多是关于分类的。因此在测试regression时,遇到了⼀些问题。总结来说,应注意以下⼏个⽅⾯:1)训练数据需是矩阵型,这⾥的输⼊和输出是1000*1,即1000个样本;每个样本得到⼀个输出;
注意:训练数据的⽣成⾮常关键,⾸先需要检查输⼊数据和输出数据的维度匹配;
3)输出层的激活函数选择很重要,该拟合的输出有正负值,因此选择tanh⽐较合适;
4)regression问题中,训练函数compile中的误差函数通常选择mean_squared_error。
5)值得注意的是,在训练时,可以将测试数据的输⼊和输出绘制出来,这样可以帮助调试参数。
6)keras中实现回归问题,返回的准确率为0。
[python]
1. # -*- coding: utf-8 -*-
2. """
3. Created on Mon May 16 13:34:30 2016
掌上药通4. @author: Michelle
5. """
6. dels import Sequential
7. from import Den, Activation
8. from keras.optimizers import SGD
9. from keras.layers.advanced_activations import LeakyReLU
如何踢足球10. from sklearn import preprocessing
11. from keras.utils.visualize_plots import figures
12. import matplotlib.pyplot as plt
13. import numpy as np
14.
15. #part1: train data
16. #generate 100 numbers from -2pi to 2pi
17. x_train = np.linspace(-2*np.pi, 2*np.pi, 1000)  #array: [1000,]
18. x_train = np.array(x_train).reshape((len(x_train), 1)) #reshape to matrix with [100,1]
19. n=0.1*np.random.rand(len(x_train),1) #generate a matrix with size [len(x),1], value in (0,1),array: [1000,1]
20. y_train=np.sin(x_train)+n
21.
22. #训练数据集:零均值单位⽅差
月亮的英文怎么读23. x_train = preprocessing.scale(x_train)
24. scaler = preprocessing.StandardScaler().fit(x_train)
25. y_train = ansform(y_train)
26.
27. #part2: test data
28. x_test = np.linspace(-5,5,2000)
29. x_test = np.array(x_test).reshape((len(x_test), 1))
30. y_test=np.sin(x_test)
31.
32. #零均值单位⽅差
33. x_test = ansform(x_test)
自己的己组词34. #y_test = ansform(y_test)
35. ##plot testing data
36. #fig, ax = plt.subplots()
37. #ax.plot(x_test, y_test,'g')
38.
39. #prediction data
40. x_prd = np.linspace(-3,3,101)
牛肉汤底
41. x_prd = np.array(x_prd).reshape((len(x_prd), 1))
42. x_prd = ansform(x_prd)
43. y_prd=np.sin(x_prd)
44. #plot testing data
45. fig, ax = plt.subplots()
46. ax.plot(x_prd, y_prd,'r')
47.
48. #part3: create models, with 1hidden layers
49. model = Sequential()
50. model.add(Den(100, init='uniform', input_dim=1))
51. #model.add(Activation(LeakyReLU(alpha=0.01)))
52. model.add(Activation('relu'))
53.
54. model.add(Den(50))
55. #model.add(Activation(LeakyReLU(alpha=0.1)))
碘量瓶图片
56. model.add(Activation('relu'))
57.
58. model.add(Den(1))
59. #model.add(Activation(LeakyReLU(alpha=0.01)))
60. model.add(Activation('tanh'))
61.
虚线是预测值,红⾊是输⼊值;
绘制误差值随着迭代次数的曲线函数是Visualize_plots.py,
1)将其放在C:\Anaconda2\Lib\site-packages\keras\utils下⾯。
2)在使⽤时,需要添加这句话:from keras.utils.visualize_plots import figures,然后在程序中直接调⽤函数figures(hist)。垓函数的实现代码为:
[python]
1. # -*- coding: utf-8 -*-
2. """
3. Created on Sat May 21 22:26:24 2016
4.
5. @author: Shemmy
6. """
7.
8. def figures(history,figure_name="plots"):
9. """ method to visualize accuracies and loss vs epoch for training as well as testind data\n
10.        Argumets: history    = an instance returned by model.fit method\n
11.                  figure_name = a string reprenting file name to plots. By default it is t to "plots" \n
12.        Usage: hist = model.fit(X,y)\n              figures(hist) """
13. from keras.callbacks import History
佐卡伊珠宝14. if isinstance(history,History):
15. import matplotlib.pyplot as plt
16.        hist    = history.history
17.        epoch    = history.epoch
18.        acc      = hist['acc']
19.        loss    = hist['loss']
20.        val_loss = hist['val_loss']
21.        val_acc  = hist['val_acc']
22.        plt.figure(1)
23.
24.        plt.subplot(221)
25.        plt.plot(epoch,acc)
26.        plt.title("Training accuracy vs Epoch")
27.        plt.xlabel("Epoch")
28.        plt.ylabel("Accuracy")
29.
30.        plt.subplot(222)
31.        plt.plot(epoch,loss)
类乌齐寺
32.        plt.title("Training loss vs Epoch")
33.        plt.xlabel("Epoch")
34.        plt.ylabel("Loss")
35.
36.        plt.subplot(223)
37.        plt.plot(epoch,val_acc)
38.        plt.title("Validation Acc vs Epoch")
39.        plt.xlabel("Epoch")
40.        plt.ylabel("Validation Accuracy")
41.
42.        plt.subplot(224)
43.        plt.plot(epoch,val_loss)
44.        plt.title("Validation loss vs Epoch")
45.        plt.xlabel("Epoch")
46.        plt.ylabel("Validation Loss")
47.        plt.tight_layout()
48.        plt.savefig(figure_name)
49. el:
50. print"Input Argument is not an instance of class History"

本文发布于:2023-06-09 11:14:39,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/910731.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:输出   数据   函数   问题   训练   实现   拟合
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图