具體含義不再解釋,這是乙個我們比較常用的乙個多分類器.深度學習的一大優點就是特徵的自動構建,也正是因為該優點,使得分類器層顯得不再那麼重要,在tensorflow的官方原始碼中,softmax是很常見的乙個多分類器.其呼叫也十分的簡單.此處再此單獨拿出來介紹,是為了下一步的學習做準備.
使用方法
cross_entropy = tf.reduce_mean(
tf.nn
.softmax_cross_entropy_with_logits(labels=y_, logits=y))
用於損失函式的定義.
# 引用,官網自帶的原始碼有很多特殊之處,但是沒啥影響,自己寫的時候,完全沒必要這麼多引用
# 額外新增了控制警告訊息等級的code
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
os.environ['tf_cpp_min_log_level'] = '2'
flags = none
mnist = input_data.read_data_sets("/home/fonttian/data/mnist_data/", one_hot=true)
# create the model,可以看出此處的model非常簡單,就是一層y=wx+b,你也可以繼續增加層數,或者將其替代為卷積層,但是此處對於展示softmax並沒有什麼意義
x = tf.placeholder(tf.float32, [none, 784])
w = tf.variable(tf.zeros([784, 10]))
b = tf.variable(tf.zeros([10]))
y = tf.matmul(x, w) + b
# define loss and optimizer
y_ = tf.placeholder(tf.float32, [none, 10])
# 這部分**很簡單,一些細節我在之前已經介紹過了.
cross_entropy = tf.reduce_mean(
tf.nn
.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train
.gradientdescentoptimizer(0.5).minimize(cross_entropy)
sess = tf.interactivesession()
tf.global_variables_initializer().run()
# train
for _ in range(1000):
batch_xs, batch_ys = mnist.train
.next_batch(100)
sess.run(train_step, feed_dict=)
# test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict=))
關於main的部分之前已經有介紹了:
if __name__ == '__main__':parser = argparse.argumentparser()
parser.add_argument('--data_dir', type=str, default='/home/fonttian/data/mnist_data',
help='directory for storing input data')
flags, unparsed = parser.parse_known_args()
softMax交叉熵多分類引數調優之批次大小選擇
import tensorflow as tf import os import numpy as np import numpy as np os.environ tf cpp min log level 3 輸入隨機種子 myseed eval input learning rate eval ...
神經網路多分類任務的損失函式 交叉熵
神經網路解決多分類問題最常用的方法是設定n個輸出節點,其中n為類別的個數。對於每乙個樣例,神經網路可以得到的乙個n維陣列作為輸出結果。陣列中的每乙個維度 也就是每乙個輸出節點 對應乙個類別。在理想情況下,如果乙個樣本屬於類別k,那麼這個類別所對應的輸出節點的輸出值應該為1,而其他節點的輸出都為0。以...
神經網路多分類任務的損失函式 交叉熵
神經網路解決多分類問題最常用的方法是設定n個輸出節點,其中n為類別的個數。對於每乙個樣例,神經網路可以得到的乙個n維陣列作為輸出結果。陣列中的每乙個維度 也就是每乙個輸出節點 對應乙個類別。在理想情況下,如果乙個樣本屬於類別k,那麼這個類別所對應的輸出節點的輸出值應該為1,而其他節點的輸出都為0。以...