對於tensorflow中的優化器(optimizer),目前已有的有以下:
不同的優化器有各自的特點,不能說誰好誰壞,有的收斂速度慢,有的收斂速度快。
此處以mnist資料集識別分類為例進行不同優化器的測試
1、梯度下降法:tf.train.gradientdescentoptimizer()
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#載入資料集
mnist = input_data.read_data_sets("mnist_data",one_hot=true)
#每個批次的大小
batch_size = 100
#計算一共需要多少個批次
n_batch = mnist.train.num_examples // batch_size
#定義兩個佔位符(placeholder)
x = tf.placeholder(tf.float32,[none,784])
y = tf.placeholder(tf.float32,[none,10])
#建立乙個簡單的神經網路
w = tf.variable(tf.zeros([784,10]))
b = tf.variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,w)+b)
#方法一:二次代價函式
#loss = tf.reduce_mean(tf.square(y - prediction))
#方法二:交叉熵代價函式(cross-entropy)的使用,加快收斂速度,迭代較少次數就能達到滿意的效果
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#優化器的使用
#方法一:使用梯度下降法
train_step = tf.train.gradientdescentoptimizer(0.2).minimize(loss)
#方法二:使用adam方法
#train_step = tf.train.adamoptimizer(1e-2).minimize(loss)
#初始化變數
init = tf.global_variables_initializer()
#結果存放在乙個布林型列表中
#equal中的兩個值,若是一樣,則返回true,否則返回false。argmax函式:返回最大值所在的索引值,即位置
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求準確率
#將上一步的布林型別轉化為32位浮點型,即true轉換為1.0,false轉換為0.0,然後計算這些值的平均值作為準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#定義會話
with tf.session() as sess:
#初始化變數
sess.run(init)
#迭代21個週期
for epoch in range(21):
#n_batch:之前定義的批次
for batch in range(n_batch):
#獲得100張,的資料儲存在batch_xs中,的標籤儲存在batch_ys中
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
#使用feed操作,此步執行訓練操作的op,將資料餵給他
sess.run(train_step,feed_dict = )
#訓練乙個週期後就可以看下準確率,使用feed方法,此步執行計算準確度的op操作,將其對應的引數餵給它
acc = sess.run(accuracy,feed_dict = )
print("iter " + str(epoch) + ",testing accuracy " + str(acc))
訓練測試結果:
extracting mnist_data/train-images-idx3-ubyte.gz2、tf.train.adamoptimizer() 優化器的使用extracting mnist_data/train-labels-idx1-ubyte.gz
extracting mnist_data/t10k-images-idx3-ubyte.gz
extracting mnist_data/t10k-labels-idx1-ubyte.gz
iter 0,testing accuracy 0.8237
iter 1,testing accuracy 0.8902
iter 2,testing accuracy 0.8997
iter 3,testing accuracy 0.9048
iter 4,testing accuracy 0.9089
iter 5,testing accuracy 0.9107
iter 6,testing accuracy 0.9119
iter 7,testing accuracy 0.914
iter 8,testing accuracy 0.9148
iter 9,testing accuracy 0.9161
iter 10,testing accuracy 0.9173
iter 11,testing accuracy 0.9182
iter 12,testing accuracy 0.9203
iter 13,testing accuracy 0.9199
iter 14,testing accuracy 0.9192
iter 15,testing accuracy 0.9196
iter 16,testing accuracy 0.9196
iter 17,testing accuracy 0.9207
iter 18,testing accuracy 0.9214
iter 19,testing accuracy 0.9209
iter 20,testing accuracy 0.9217
即將
#方法一:使用梯度下降法
train_step = tf.train.gradientdescentoptimizer(0.2).minimize(loss)
換成
#方法二:使用adam方法
train_step = tf.train.adamoptimizer(1e-2).minimize(loss)
訓練測試結果:
extracting mnist_data/train-images-idx3-ubyte.gz可見,在mnist資料下,使用設計的神經網路,此種優化器效能較好。extracting mnist_data/train-labels-idx1-ubyte.gz
extracting mnist_data/t10k-images-idx3-ubyte.gz
extracting mnist_data/t10k-labels-idx1-ubyte.gz
iter 0,testing accuracy 0.917
iter 1,testing accuracy 0.9242
iter 2,testing accuracy 0.9287
iter 3,testing accuracy 0.9287
iter 4,testing accuracy 0.9285
iter 5,testing accuracy 0.9317
iter 6,testing accuracy 0.931
iter 7,testing accuracy 0.9308
iter 8,testing accuracy 0.9325
iter 9,testing accuracy 0.929
iter 10,testing accuracy 0.9293
iter 11,testing accuracy 0.933
iter 12,testing accuracy 0.9282
iter 13,testing accuracy 0.9297
iter 14,testing accuracy 0.9323
iter 15,testing accuracy 0.9309
iter 16,testing accuracy 0.9296
iter 17,testing accuracy 0.9317
iter 18,testing accuracy 0.9304
iter 19,testing accuracy 0.9281
iter 20,testing accuracy 0.9314
tensorflow的優化器比較
標準梯度下降法 彙總所有樣本的總誤差,然後根據總誤差更新權值 隨機梯度下降 隨機抽取乙個樣本誤差,然後更新權值 每個樣本都更新一次權值,可能造成的誤差比較大 批量梯度下降法 相當於前兩種的折中方案,抽取乙個批次的樣本計算總誤差,比如總樣本有10000個,可以抽取1000個作為乙個批次,然後根據該批次...
tensorflow中的優化器
1.tf.train.gradientdescentoptimizer 標準梯度下降優化器 標準梯度下降先計算所有樣本彙總誤差,然後根據總誤差來更新權值 2.tf.train.adadeltaoptimizer adadelta優化器,在sgd的基礎上 3.tf.train.adagradoptim...
tensorflow常用的優化器
tf.train.momentumoptimizer learning rate,momentum,use locking false name momentum minimize loss learning rate 學習率,資料型別為tensor或float。momentum 動量引數,mome...