如圖,
大訓練集
分塊,使用不同的分塊方法分成n對小訓練集
和驗證集
。
使用小訓練集
進行訓練,使用驗證集
進行驗證,得到準確率,求n個驗證集
上的平均正確率
;
使用平均正確率
最高的超引數
,對整個大訓練集
進行訓練,訓練出引數。
在訓練集
上訓練。
十折交叉驗證
諸如你有多個可調節的超引數,那麼選擇超引數的方法通常是網格搜尋,即固定乙個參、變化其他參,像網格一樣去搜尋。
"""任務:鳶尾花識別
"""import
pandas as pd
from sklearn.model_selection import
train_test_split, gridsearchcv
from sklearn.neighbors import
kneighborsclassifier
from sklearn.linear_model import
logisticregression
from sklearn.svm import
svcdata_file = '
./data_ai/iris.csv
'species_label_dict =
#使用的特徵列
feat_cols = ['
sepallengthcm
', '
sepalwidthcm
', '
petallengthcm
', '
petalwidthcm']
defmain():
"""主函式
"""#
讀取資料集
iris_data = pd.read_csv(data_file, index_col='id'
) iris_data[
'label
'] = iris_data['
species
'].map(species_label_dict)
#獲取資料集特徵
x =iris_data[feat_cols].values
#獲取資料標籤
y = iris_data['
label
'].values
#劃分資料集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=1/3, random_state=10)
model_dict =
),'logistic regression':
(logisticregression(),
),'svm':
(svc(),
)}
#名稱+元組
for model_name, (model, model_params) in
model_dict.items():
#訓練模型
clf = gridsearchcv(estimator=model, param_grid=model_params, cv=5) #
模型、引數、折數
clf.fit(x_train, y_train) #
訓練 best_model = clf.best_estimator_ #
最佳模型的物件#驗證
acc =best_model.score(x_test, y_test)
print('
{}模型的**準確率:%
'.format(model_name, acc * 100))
print('
{}模型的最優引數:{}
'.format(model_name, clf.best_params_)) #
最好的模型名稱和引數
if__name__ == '
__main__':
main()
執行結果:
knn模型的**準確率:96.00%
knn模型的最優引數:
logistic regression模型的**準確率:96.00%
logistic regression模型的最優引數:
svm模型的**準確率:98.00%
svm模型的最優引數:
練習:使用交叉驗證對水果分類模型進行調參
可能的**
import執行結果:pandas as pd
from sklearn.model_selection import
gridsearchcv, train_test_split
from sklearn.neighbors import
kneighborsclassifier
from sklearn.linear_model import
logisticregression
from sklearn.svm import
svc#
讀取資料
data = pd.read_csv('
./data_ai/fruit_data.csv')
#資料處理
fruit_dict =
data[
'label
'] = data['
fruit_name
'].map(fruit_dict)
feat_cols = ['
mass
','width
','height
','color_score']
#資料提取
x =data[feat_cols].values
y = data['
label
'].values
x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=1/5, random_state= 3)
model_dict = ),
'logestic regression
': (logisticregression(), ),
'svm
': (svc(), )
}for model_name, (model, model_para) in
model_dict.items():
#訓練clf = gridsearchcv(estimator=model, param_grid=model_para, cv=5) #
模型、引數、折數
clf.fit(x_train,y_train)
best_model =clf.best_estimator_
#驗證acc =best_model.score(x_test, y_test)
print(f'
中選擇為引數的**準確率最好,準確率可達%
')
knn中選擇為引數的**準確率最好,準確率可達66.66666666666666%
logestic regression中選擇為引數的**準確率最好,準確率可達91.66666666666666%
svm中選擇為引數的**準確率最好,準確率可達50.0%
利用KNN對鳶尾花資料進行分類
knn k nearest neighbor 工作原理 存在乙個樣本資料集合,也稱為訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類對應的關係。輸入沒有標籤的資料後,將新資料中的每個特徵與樣本集中資料對應的特徵進行比較,提取出樣本集中特徵最相似資料 最近鄰 的分類標籤...
交叉驗證與網格搜尋(以KNN分類鳶尾花為例)
總結 import pandas as pd pd.set option display.max rows 6 1獲取資料 from sklearn.datasets import load iris iris load iris 新建乙個dataframe,把iris中data方進來,並且data...
使用KNN對鳶尾花資料集進行分類處理
author tao contact 1281538933 qq.com file knn.py time 2020 12 21 software vscode from sklearn.datasets import load iris 匯入資料集iris import matplotlib.py...