使用sklearn進行增量學習

2021-08-20 12:02:44 字數 1854 閱讀 6025

sklearn.*****_bayes.bernoullinb

sklearn.linear_model.perceptron

sklearn.linear_model.sgdclassifier

sklearn.linear_model.passiveaggressiveclassifier

regression 

clustering 

decomposition / feature extraction 

def iter_minibatches(data_stream, minibatch_size=1000):

'''迭代器

給定檔案流(比如乙個大檔案),每次輸出minibatch_size行,預設選擇1k行

將輸出轉化成numpy輸出,返回x, y

'''x =

y =

cur_line_num = 0

csvfile = file(data_stream, 'rb')

reader = csv.reader(csvfile)

for line in reader:

cur_line_num += 1

if cur_line_num >= minibatch_size:

x, y = np.array(x), np.array(y) # 將資料轉成numpy的array型別並返回

yield x, y

x, y = ,

cur_line_num = 0

csvfile.close()

# 生成測試檔案

minibatch_test_iterators = iter_minibatches(test_file, minibatch_size=5000)

x_test, y_test = minibatch_test_iterators.next() # 得到乙份測試檔案

from sklearn.linear_model import sgdclassifier

sgd_clf = sgdclassifier() # sgdclassifier的引數設定可以參考sklearn官網

minibatch_train_iterators = iter_minibatches(data_part_file, minibatch_size=2000)

for i, (x_train, y_train) in enumerate(minibatch_train_iterators):

# 使用 partial_fit ,並在第一次呼叫 partial_fit 的時候指定 classes

sgd_clf.partial_fit(x_train, y_train, classes=np.array([0, 1]))

print("{} time".format(i)) # 當前次數

print("{} score".format(sgd_clf.score(x_test, y_test))) # 在測試集上看效果

0 time

0.679 score

1 time

0.6954 score

2 time

0.712 score

3 time

0.7248 score

...57 time

0.745 score

58 time

0.7394 score

59 time

0.7398 score

資料只迭代一次分類器可能還沒完全收斂,可以多迭代幾次

mini-batch的量不要設定太小,太小的話,需要多迭代幾次才能收斂

**

使用sklearn進行增量學習

sklearn.bayes.bernoullinb sklearn.linear model.perceptron sklearn.linear model.sgdclassifier sklearn.linear model.passiveaggressiveclassifier regressi...

使用sklearn進行K Means文字聚類

k means演算法 中文名字叫做k 均值演算法,演算法的目的是將n個向量分別歸屬到k個中心點裡面去。演算法首先會隨機選擇k個中心向量,然後通過迭代計算以及重新選擇k個中心向量,使得n個向量各自被分配到距離最近的k中心點,並且所有向量距離各自中心點的和最小。步驟一 在輸入資料集裡面隨機選擇k個向量作...

使用sklearn進行mnist資料集分類

深度之眼 西瓜書課後 import time import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import fetch openml from sklearn.linear model import l...