網格搜尋與交叉驗證

2022-08-27 00:15:18 字數 2606 閱讀 5655

一. 網格搜尋驗證

sklearn.model_selection.gridsearchcv(estimator, param_grid, scoring=none, fit_params=none, n_jobs=1, iid=true, refit=true, cv=none, verbose=0, pre_dispatch=『2*n_jobs』, error_score=』raise』, return_train_score=』warn』)

2. 常用方法和屬性

3. 使用示例(以randomforestclassifier為例, 其它的分類模型也能按這個方法調參)

1 param_test1 = 

2 gsearch1 = gridsearchcv(estimator = randomforestclassifier(min_samples_split=100,

3 min_samples_leaf=20,max_depth=8,max_features='

sqrt

' ,random_state=10),

4 param_grid = param_test1, scoring='

roc_auc

',cv=5)

5gsearch1.fit(x_train,y_train)

6print( gsearch1.best_params_, gsearch1.best_score_) #

得到最優n_estimators引數

1 param_test2 = #

, 'min_samples_split':[100,120,150,180,200,300]}

2 gsearch2 = gridsearchcv(estimator = randomforestclassifier(n_estimators=50, min_samples_split=100,

3 min_samples_leaf=20,max_features='

sqrt

' ,oob_score=true, random_state=10),

4 param_grid = param_test2, scoring='

roc_auc

',iid=false, cv=5)

5gsearch2.fit(x_train,y_train)

6print( gsearch2.best_params_, gsearch2.best_score_) #

得到最優max_depth引數

1 rf1 = randomforestclassifier(n_estimators= 50, max_depth=2, min_samples_split=100, 

2 min_samples_leaf=20,max_features='

sqrt

',oob_score=true, random_state=10)

3rf1.fit(x_train,y_train)

4print( rf1.oob_score_) #

列印袋外分數

#假設輸出結果為0.984, 預設情況為0.972

#相對於預設情況,袋外分數有提高,也就是說模型的泛化能力變好了

二. 交叉驗證

1

from sklearn.neighbors import

kneighborsclassifier

2from sklearn.model_selection import

cross_val_score

34 k_range = [1, 5, 9, 15]

5 cv_scores =

6for k in

k_range:

7 knn = kneighborsclassifier(n_neighbors=k)

8 scores = cross_val_score(knn, x_train, y_train, cv=5)

9 cv_score =np.mean(scores)

10print('

k={},驗證集上的準確率=

'.format(k, cv_score))

1112

#k=1,驗證集上的準確率=0.94713#

k=5,驗證集上的準確率=0.95514#

k=9,驗證集上的準確率=0.96415#

k=15,驗證集上的準確率=0.964

1617 best_k = k_range[np.argmax(cv_scores)] #

從交叉驗證中的最優score中取出最優引數, 代入模型重新fit,score

18 best_knn = kneighborsclassifier(n_neighbors=best_k)

19best_knn.fit(x_train, y_train)

20print('

測試集準確率:

', best_knn.score(x_test, y_test))21#

測試集準確率: 0.9736842105263158

交叉驗證與網格搜尋

交叉驗證與網格搜尋是機器學習中的兩個非常重要且基本的概念,但是這兩個概念在剛入門的時候並不是非常容易理解與掌握,自己開始學習的時候,對這兩個概念理解的並不到位,現在寫一篇關於交叉驗證與網格搜尋的文章,將這兩個基本的概念做一下梳理。網格搜尋 grid search 名字非常大氣,但是用簡答的話來說就是...

網格搜尋和交叉驗證

在介紹網格搜尋和交叉驗證以前先要介紹下什麼是機器學習的超引數。我們常說的機器學習的引數指的是和特徵相關的係數,超引數指的是對於模型的整體規劃具有重要意義的指標 例如支援向量機中的乘法因子c 用於權衡經驗風險和模型複雜度 當支援向量機核函式是為徑向基rbf核函式,對應的鐘型函式的寬度gamma就是核函...

網格搜尋調參法與交叉驗證

網格搜尋法是指定引數值的一種 窮舉搜尋方法 通過將估計函式的 引數通過 交叉驗證進行優化 來得到最優的學習演算法。將各個引數可能的取值進行 排列組合 列出所有可能的組合結果生成 網格 然後將各組合用於svm 訓練,使用 交叉驗證 對表現進行評估。在擬合函式嘗試了所有的引數組合後,返回乙個合適的分類器...