MNIST資料集是機器學習領域中非常經典的一個資料集,由60000個訓練樣本和10000個測試樣本組成,每個樣本都是一張28×28畫素的灰度手寫數字圖像。
import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_openml  # downloads datasets from openml.org
from sklearn.utils import check_random_state

# Train an L1-penalized (sparse) logistic regression on a 5000-sample subset
# of MNIST, report test accuracy and coefficient sparsity, then plot each
# class's learned weight vector as a 28x28 image.

t0 = time.time()
train_samples = 5000

# Load the dataset. as_frame=False makes fetch_openml return NumPy arrays;
# recent scikit-learn versions otherwise return a pandas DataFrame, on which
# x[permutation] below would select columns instead of rows.
x, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
print(x.shape, y.shape)

# Shuffle all samples with a fixed seed before picking the train/test subsets.
random_state = check_random_state(0)
permutation = random_state.permutation(x.shape[0])  # random row order
x = x[permutation]
y = y[permutation]
x = x.reshape((x.shape[0], -1))  # make sure there is one flat row per sample
x_train, x_test, y_train, y_test = train_test_split(
    x, y, train_size=train_samples, test_size=10000)

# Standardize features: fit the scaler on the training set only, then apply
# the same transform to the test set so no test statistics leak into training.
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)

# Train the model. C is the inverse regularization strength, so a small C
# means a strong L1 penalty (many exactly-zero weights). 'saga' is a solver
# that supports the L1 penalty. The loose tol presumably keeps the run short
# — NOTE(review): confirm this accuracy/speed trade-off is intended.
clf = LogisticRegression(C=50. / train_samples, penalty='l1', solver='saga', tol=0.1)
clf.fit(x_train, y_train)
score = clf.score(x_test, y_test)  # mean accuracy on the held-out samples
sparsity = np.mean(clf.coef_ == 0) * 100  # percentage of weights that are exactly 0
print("sparsity with l1 penalty: %.2f%%" % sparsity)
print("test score with l1 penalty: %.4f" % score)

# Plot the coefficients. Each row of coef_ is one class's weight vector with
# one entry per input pixel, so reshaping it to 28x28 shows which pixel
# positions raise (one colour) or lower (the other colour) that class's score.
# vmin/vmax are symmetric, so zero sits at the neutral centre of the diverging
# colormap; pixels the L1 penalty zeroed out therefore show up as blank areas.
coef = clf.coef_.copy()
print(coef.shape)
plt.figure(figsize=(10, 5))
scale = np.abs(coef).max()
for i in range(10):
    l1_plot = plt.subplot(2, 5, i + 1)
    l1_plot.imshow(coef[i].reshape(28, 28), interpolation='nearest',
                   cmap='RdBu', vmin=-scale, vmax=scale)
    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
    l1_plot.set_xlabel('class %i' % i)
plt.suptitle('classification vector for...')
run_time = time.time() - t0
print('example run in %.3f s' % run_time)
plt.show()
執行結果
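問題:為什麼要給係數畫圖呢?
答:每個類別的係數是一個784維向量,與輸入的784個畫素一一對應。把它重塑成28×28,並用以0為中心的發散色圖顯示,就能直觀看出哪些畫素位置會提高或降低該類別的得分。L1懲罰會把大量係數壓成0,圖中接近空白的區域就是被模型捨棄的畫素,因此這張圖同時呈現了模型學到的「模板」和它的稀疏性。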
使用sklearn進行mnist資料集分類
深度之眼 西瓜書課後 import time import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import fetch_openml from sklearn.linear_model import l...
用MNIST簡單訓練一個邏輯迴歸(單層神經網路)
先說個小事情 這個功能挺方便。import requests data path path data path data path mnist path.mkdir parents true,exist ok true url filename mnist.pkl.gz if not path fi...
mnist資料集進行自編碼
自動編碼的核心就是各種全連線的組合,它是一種無監督的形式,因為它的標籤就是輸入本身。import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as data import t...