Iris classification with the BP (backpropagation) algorithm
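There are plenty of iris classification examples online. I studied other people's versions and then wrote my own, because running their code gave me overflow errors and other bugs.
The code below reaches 95% accuracy.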
Project link (extraction code: y07d)
If you find a mistake or something is unclear, feel free to email me.
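```python
import math
import random
import pandas as pd

random.seed(0)  # fixed seed: reproducible weight initialization and shuffling

def rand(a, b):
    # uniform random number in [a, b)
    return (b - a) * random.random() + a

def dsigmoid(y):
    # sigmoid derivative written in terms of the sigmoid OUTPUT y = sigmoid(x)
    return y * (1 - y)

def sigmoid(gamma):
    # numerically stable sigmoid: math.exp only ever receives a non-positive
    # argument, so it cannot overflow
    if gamma < 0:
        return 1 - 1 / (1 + math.exp(gamma))
    else:
        return 1 / (1 + math.exp(-gamma))

def makematrix(rows, cols):
    # rows x cols matrix (list of lists) filled with 0.0
    m = []
    for _ in range(rows):
        m.append([0.0] * cols)
    return m
```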
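The two-branch sigmoid is what avoids the overflow mentioned at the start: the textbook form 1/(1 + exp(-x)) hands a large positive number to math.exp whenever x is very negative, and Python's math.exp raises OverflowError once the result would exceed the largest double (an argument above about 709.8). The comparison below is a minimal sketch; naive_sigmoid and stable_sigmoid are hypothetical names used only here. The finite-difference check at the end shows why dsigmoid() is fed the sigmoid output rather than the raw input.

```python
import math

def naive_sigmoid(x):
    # textbook form: math.exp(-x) overflows for x below roughly -709.8
    return 1 / (1 + math.exp(-x))

def stable_sigmoid(x):
    # same branching as sigmoid() in the listing: exp only sees non-positive arguments
    if x < 0:
        return 1 - 1 / (1 + math.exp(x))
    return 1 / (1 + math.exp(-x))

def dsigmoid(y):
    # derivative expressed through the output y = sigmoid(x)
    return y * (1 - y)

for x in (-1000.0, -2.0, 0.0, 2.0, 1000.0):
    try:
        naive = naive_sigmoid(x)
    except OverflowError:
        naive = 'OverflowError'
    print(x, naive, stable_sigmoid(x))

# dsigmoid takes the sigmoid OUTPUT: compare it with a central finite difference
x, h = 0.5, 1e-6
numeric = (stable_sigmoid(x + h) - stable_sigmoid(x - h)) / (2 * h)
print(numeric, dsigmoid(stable_sigmoid(x)))
```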
```python
class bp:
    '''Three-layer BP network: input, hidden and output layers.'''
    def __init__(self, ni, nh, no):
        self.ni = ni + 1  # +1 for the input bias unit (fixed at 1.0)
        self.nh = nh + 1  # +1 for the hidden bias unit (fixed at 1.0)
        self.no = no
        self.ai = [1.0] * self.ni  # input activations
        self.ah = [1.0] * self.nh  # hidden activations
        self.ao = [1.0] * self.no  # output activations
        self.wi = makematrix(self.ni, self.nh)  # input -> hidden weights
        self.wo = makematrix(self.nh, self.no)  # hidden -> output weights
        for i in range(self.ni):
            for j in range(self.nh):
                self.wi[i][j] = rand(-2, 2)
        for i in range(self.nh):
            for j in range(self.no):
                self.wo[i][j] = rand(-2, 2)

    def forward(self, inputs):
        # copy the features element by element so the last slot keeps the bias value 1.0
        for i in range(self.ni - 1):
            self.ai[i] = inputs[i]
        # hidden layer: weighted sum over all inputs, bias included;
        # the last hidden unit is the bias and stays 1.0
        for j in range(self.nh - 1):
            total = 0.0
            for i in range(self.ni):
                total += self.wi[i][j] * self.ai[i]
            self.ah[j] = sigmoid(total)
        # output layer: weighted sum over all hidden units, bias included
        for k in range(self.no):
            total = 0.0
            for j in range(self.nh):
                total += self.wo[j][k] * self.ah[j]
            self.ao[k] = sigmoid(total)
        return self.ao

    def backward(self, target, lr):
        # output deltas: o_k * (1 - o_k) * (t_k - o_k)
        outerror = [0.0] * self.no
        for k in range(self.no):
            outerror[k] = dsigmoid(self.ao[k]) * (target[k] - self.ao[k])
        # hidden deltas, computed from the weights before they are updated
        herror = [0.0] * self.nh
        for j in range(self.nh - 1):
            error = 0.0
            for k in range(self.no):
                error += outerror[k] * self.wo[j][k]
            herror[j] = dsigmoid(self.ah[j]) * error
        # update hidden -> output weights
        for j in range(self.nh):
            for k in range(self.no):
                self.wo[j][k] += lr * outerror[k] * self.ah[j]
        # update input -> hidden weights, including the bias row
        for i in range(self.ni):
            for j in range(self.nh - 1):
                self.wi[i][j] += lr * herror[j] * self.ai[i]
        # squared error of this sample, summed over every output
        return 0.5 * sum((target[k] - self.ao[k]) ** 2 for k in range(self.no))

    def train(self, patterns, iterations, lr=0.009):
        # lr: learning rate
        for it in range(iterations):
            error = 0.0
            for inputs, targets in patterns:
                self.forward(inputs)
                error += self.backward(targets, lr)
            if it % 100 == 0:
                print('error: %-.9f' % error)

    def test(self, patterns):
        count = 0
        for inputs, target in patterns:
            result = self.forward(inputs)
            # predicted class = output unit with the largest activation
            if result.index(max(result)) == target.index(1):
                count += 1
        accuracy = count / len(patterns)
        print('accuracy: %-.9f' % accuracy)
        return accuracy
```
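backward() applies the delta rule: for each output k, delta_k = o_k(1 - o_k)(t_k - o_k); for each hidden unit j, delta_j = h_j(1 - h_j) * sum_k delta_k * w_jk; each weight then moves by lr times its delta times the activation feeding it. Because the error is E = 0.5 * sum_k (t_k - o_k)^2, those products should equal -dE/dw exactly, which makes a numerical gradient check a cheap way to catch indexing slips such as summing over the wrong layer size. The sketch below is not the author's code: it uses a tiny hypothetical 2-3-1 network without bias units, and forward, loss, delta_rule_terms and numeric_neg_grad are illustrative names.

```python
import math
import random

def sigmoid(x):
    # same numerically stable form as in the listing
    if x < 0:
        return 1 - 1 / (1 + math.exp(x))
    return 1 / (1 + math.exp(-x))

def forward(wi, wo, x):
    # hidden activations h_j and output activations o_k (no bias units in this toy net)
    h = [sigmoid(sum(wi[i][j] * x[i] for i in range(len(x)))) for j in range(len(wi[0]))]
    o = [sigmoid(sum(wo[j][k] * h[j] for j in range(len(h)))) for k in range(len(wo[0]))]
    return h, o

def loss(wi, wo, x, t):
    # E = 0.5 * sum_k (t_k - o_k)^2, the quantity backward() returns
    _, o = forward(wi, wo, x)
    return 0.5 * sum((t[k] - o[k]) ** 2 for k in range(len(t)))

def delta_rule_terms(wi, wo, x, t):
    # the terms backward() multiplies by lr; each should equal -dE/dw
    h, o = forward(wi, wo, x)
    dout = [o[k] * (1 - o[k]) * (t[k] - o[k]) for k in range(len(o))]
    dhid = [h[j] * (1 - h[j]) * sum(dout[k] * wo[j][k] for k in range(len(o)))
            for j in range(len(h))]
    gwi = [[dhid[j] * x[i] for j in range(len(h))] for i in range(len(x))]
    gwo = [[dout[k] * h[j] for k in range(len(o))] for j in range(len(h))]
    return gwi, gwo

def numeric_neg_grad(wi, wo, x, t, eps=1e-6):
    # -dE/dw by central differences, perturbing one weight at a time
    def grad(w):
        g = [[0.0] * len(w[0]) for _ in w]
        for a in range(len(w)):
            for b in range(len(w[0])):
                old = w[a][b]
                w[a][b] = old + eps
                up = loss(wi, wo, x, t)
                w[a][b] = old - eps
                down = loss(wi, wo, x, t)
                w[a][b] = old  # restore the original weight
                g[a][b] = -(up - down) / (2 * eps)
        return g
    return grad(wi), grad(wo)

random.seed(1)
wi = [[random.uniform(-2, 2) for _ in range(3)] for _ in range(2)]  # 2 inputs -> 3 hidden
wo = [[random.uniform(-2, 2)] for _ in range(3)]                     # 3 hidden -> 1 output
x, t = [0.5, -1.2], [1.0]

analytic = delta_rule_terms(wi, wo, x, t)
numeric = numeric_neg_grad(wi, wo, x, t)
max_diff = max(abs(a - n)
               for A, N in zip(analytic, numeric)
               for ra, rn in zip(A, N)
               for a, n in zip(ra, rn))
print('max |delta rule - numeric|:', max_diff)
```

A large max_diff would point to a delta or an index that does not match the error being minimized; the same check can be pointed at the bp class by reading its wi/wo before and after a single backward() call.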
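```python
def iris():
    # assumes iris.csv has a leading row-index column, the four measurements
    # in columns 1-4 and the species name in column 5
    data = []
    raw = pd.read_csv('iris.csv')
    raw_data = raw.values
    raw_feature = raw_data[:, 1:5]  # the four measurement columns only
    for i in range(len(raw_feature)):
        ele = [[float(v) for v in raw_feature[i]]]  # features
        # one-hot target
        if raw_data[i][5] == 'setosa':
            ele.append([1, 0, 0])
        elif raw_data[i][5] == 'versicolor':
            ele.append([0, 1, 0])
        else:
            ele.append([0, 0, 1])
        data.append(ele)
    random.shuffle(data)
    training = data[0:100]       # first 100 shuffled samples for training
    test = data[100:]            # remaining samples for testing
    nn = bp(4, 7, 3)             # 4 inputs, 7 hidden units, 3 outputs
    nn.train(training, 10000)    # train on the training split only
    nn.test(test)

if __name__ == '__main__':
    iris()
```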
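The preprocessing in iris() does two things: it one-hot encodes the species name and splits the shuffled samples into training and test sets. A file-free sketch of those two steps, assuming rows arrive as (four measurements, species name) pairs; one_hot, split_patterns and the SPECIES ordering are hypothetical names for illustration, and the five sample rows are a few standard iris measurements typed in by hand rather than read from iris.csv.

```python
import random

SPECIES = ['setosa', 'versicolor', 'virginica']  # list position defines the one-hot slot

def one_hot(name):
    # position of the species in SPECIES becomes the 1; unknown names raise ValueError
    # (stricter than the else branch in iris(), which maps anything else to virginica)
    index = SPECIES.index(name)
    return [1 if i == index else 0 for i in range(len(SPECIES))]

def split_patterns(rows, n_train, seed=0):
    # rows: iterable of (features, species_name); returns (train, test) lists of
    # [features, one_hot] patterns, the format bp.train() and bp.test() consume
    data = [[[float(v) for v in features], one_hot(name)] for features, name in rows]
    random.Random(seed).shuffle(data)  # local RNG: reproducible, leaves the global seed alone
    # slicing at the same index on both sides puts every sample in exactly one set
    return data[:n_train], data[n_train:]

rows = [([5.1, 3.5, 1.4, 0.2], 'setosa'),
        ([7.0, 3.2, 4.7, 1.4], 'versicolor'),
        ([6.3, 3.3, 6.0, 2.5], 'virginica'),
        ([4.9, 3.0, 1.4, 0.2], 'setosa'),
        ([6.4, 3.2, 4.5, 1.5], 'versicolor')]
train, test = split_patterns(rows, n_train=3)
print(len(train), len(test))
print(train[0])
```

With the 150 rows of the iris data, split_patterns(rows, 100) gives a 100/50 split in which no sample is dropped or shared between the two sets.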