Python代码实现简单的MNIST⼿写数字识别(适合初学者看)初学机器学习,第⼀步是做⼀个简单的⼿写数字识别,我选⽤的是MNIST数据集(⽤其他数据集也可以,原理都差不多),算法是
KNN(下载库直接调⽤函数,算法的具体实现没有过多关⼼)。在⽹上也看到过MNIST数据集的Python代码,但是感觉有些复杂,作为初学者见到那么多代码就头⼤……这⾥分享⼀下我的代码,虽然并不完善,但是可以为其他初学者提供⼀点简单的思路吧。
⾸先明确⼀下我的思路:解析图⽚和标签——处理图⽚和标签——加载KNN分类器训练——读⼊处理后的测试图⽚和标签——得出正确率。
我写了两个程序,第⼀个⽤来解析并保存图⽚,第⼆个对图⽚进⾏处理、解析标签、训练、预测、得出结果。
准备⼯作:
0.安装Python:最好默认安装位置,安装的时候勾上,下⼀步还有个为所有Urs安装,也勾上。win7以上如果默认安装到C盘某个⽬录下,需要更改⼀下⽂件夹的权限,在python⽂件夹上点击右键>>属性>>安全>>编辑,把Urs和ALL APPLICATION PACKAGES的权限的“完全控制”都打上勾,确定。
1.Python做数字⼿写识别需要⽤到的库:numpy,scipy,scikit_learn,也可以再加上openCV,因为我把图⽚都解析出来保存下来了。点击可以寻找并下载这些库(.whl⽂件),注意要对应⾃⼰的python版本。下载后放在python安装⽬录的Scripts下,按着shift点⿏标右键,点击”在此处打开命令窗⼝”,输⼊pip install ****** ,这⾥******代表要安装的⽂件名,注意不要更改那些⽂件名,直接把⽂件名带着后缀.whl复制粘贴在命令⾏⾥就⾏,粘贴的时候不可以使⽤ctrl+v,直接⿏标右键粘贴就⾏。要先安装numpy和scipy,再安装scikit_learn 和openCV。安装成功会有提⽰的,失败的话……当然也会有提⽰,如果有错误就百度⼀下。如果你安装的是python3.5或者3.6的话,联⽹的情况下不⽤下载.whl⽂件直接输⼊pip install ******应该程序就会⾃⼰在⽹上下载安装,特别⽅便。但是下载的numpy只有numpy,⽽在我提供的那个⽹址上下载的都是numpy+mkl,后续安装⼀些库需要⽤到mkl,所以建议numpy还是⾃⼰下载。
2.下载MNIST数据集,⽹上很好搜到(很久前下载的了,这⾥就不贴链接了)。下载后有四个⽂件,两个.idx3-ubyte⽂件,分别是训练⽤和测试⽤的图像⽂件,还有两个.idx1-ubyte⽂件,分别是训练和测试⽤的标签⽂件。这些⽂件都是⼆进制⽂件,没办法直接打开,需要在Python ⾥写程序解析。
正式开始:
为了验证我的解析结果是否正确,我把解析出的图⽚进⾏了保存(这⾥⽤到了openCV),然后处理的是保存后的图⽚,其实不保存就可以,解析出来直接⽤,反⽽会节省很多步骤。
解析图⽚:解析图⽚和标签的原理在这⾥我就不多说了,⽹上可以搜到,介绍的很详细。直接上代码(如果你不⽤openCV的话,就要修改⼀下代码了,可以不转换成图⽚)。
import numpy as np
import struct
import cv2
def readfile():#读取源图⽚⽂件
with open('E:\\t10k-images.idx3-ubyte','rb') as f1:
buf1 = f1.read()
return buf1
def get_image(buf1):#解析并保存图⽚
image_index = 0
image_index += struct.calcsize('>IIII')
magic,numImages,imgRows,imgCols=struct.unpack_from(">IIII",buf1,0)
im = []
for i in range(numImages):
temp = struct.unpack_from('>784B', buf1, image_index)
im=np.array(temp)
shape(28,28)
cv2.imwrite("E:\\testImages\\testIM"+str(i)+".jpg",im2)#保存路径⾃⼰设置
image_index += struct.calcsize('>784B') # 28*28=784(B)
if i%20==0:#知道图⽚保存的进度
孟浩然的简介资料print i
el:
print i,
if __name__ == "__main__":
image_data = readfile()
get_image(image_data)
cv2.waitKey(0)
cv2.destroyAllWindows()
然后是主要程序了(我的程序有点啰嗦,见谅)。
import cv2
import os
import numpy as np
from sklearn import neighbors
import struct
print "Now start,"#程序运⾏时间灰常漫长……
#当时脑抽⽤了四个⾃定义函数,其实⽤两个就够了,玩家可以⾃定义
def getImages():#处理训练图⽚
莲子芯
s([60000,784],int)#建⽴⼀个60000*784的0矩阵
for i in range(60000):
img1=cv2.imread("E:\\Images\\IM"+str(i)+".jpg",0)#读取每⼀张图⽚(路径⾃定义)
for rows in range(28):
for cols in range(28):#访问每张图⽚的每个像素,这种⽅法简单易懂但是效率⽐较低
if img1[rows,cols]>=127:#⼆值化处理,把⼀整张图⽚的像素处理成只有0和1
img1[rows,cols]=1
el:
img1[rows,cols]=0#这⾥选择的临界点是127,正好是0-255的中间值
imgs[i,rows*28+cols]=img1[rows,cols]#把每张图⽚(28*28)展开成⼀⾏(1*784), #然后把每张图⽚的像素逐⾏放到(60000*784)的⼤矩阵中
return imgs#返回所有图⽚的像素重构的矩阵
def getLabels():#解析训练标签(解析出来的标签和图⽚顺序是⼀⼀对应的)
f1=open("E:\\train-labels.idx1-ubyte",'rb')
ad()
f1.clo()
index=0
magic,num=struct.unpack_from(">II",buf1,0)
index+=struct.calcsize('>II')
labs=[]
labs=struct.unpack_from('>'+str(num)+'B',buf1,index)
return labs#返回训练标签。之前没有单独解析出来保存在⽂本⽂件中,因为解析标签⽐较简单。
def getTestImages():#处理测试图⽚,和处理训练图⽚是⼀样的
s([10000,784],int)#
for i in range(10000):#
img1=cv2.imread("E:\\testImages\\testIM"+str(i)+".jpg",0)
for rows in range(28):
for cols in range(28):
if img1[rows,cols]>=127:
img1[rows,cols]=1
el:
img1[rows,cols]=0
imgs[i,rows*28+cols]=img1[rows,cols]
return imgs
def getTestLabels():#处理测试标签,和处理训练标签是⼀样的面包简笔画
f1=open("E:\\t10k-labels.idx1-ubyte",'rb')
ad()
f1.clo()
index=0
magic,num=struct.unpack_from(">II",buf1,0)
index+=struct.calcsize('>II')
弱小的近义词labs=[]
labs=struct.unpack_from('>'+str(num)+'B',buf1,index)
return labs
if __name__=="__main__":#主函数
print "Getting "#print的⽬的就是知道进度
会挽雕弓如满月
train_imgs=getImages()#train_imgs保存60000*784的⼤矩阵
print "Getting "
train_labels=getLabels()#train_labels保存60000个训练标签
print "Creating "
knn=neighbors.KNeighborsClassifier(algorithm='kd_tree',n_neighbors=3)#重点来了,这⾥就是加载KNN分类器,具体的⽤法可以上⽹搜索 print ""
knn.fit(train_imgs,train_labels)#读⼊训练图⽚和标签进⾏训练
print "Getting "
test_imgs=getTestImages()#test_imgs保存10000*784的⼤矩阵
print "Getting "
蝙蝠图片test_labels=getTestLabels()#test_labels保存10000个训练标签
print ""
result=knn.predict(test_imgs)#对测试图⽚进⾏预测
溺水事例wrongNum=np.sum(result!=test_labels)#得出错误个数
快速性num=len(test_imgs)#训练图⽚的总数
print "Total number:",num
print "Wrong number:",wrongNum
print "RightRate:",1-wrongNum/float(num)#得出正确率
#英语部分可能有些表述错误,能理解就好~_~
我第⼀次运⾏⽤了⼤约半个⼩时的时间,这个程序使⽤的内存升到了⼏乎700M,运⾏完⼀次我再也不想运⾏了,⼼疼我电脑……在下初学者,不⾜之处甚多,恳请批评指正。