本文實例為大家分享了Tensorflow訓練MNIST手寫數字識別模型的具體代碼,供大家參考,具體內容如下

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
INPUT_NODE = 784 # 輸入層節點=圖片像素=28x28=784
OUTPUT_NODE = 10 # 輸出層節點數=圖片類別數目
LAYER1_NODE = 500 # 隱藏層節點數,只有一個隱藏層
BATCH_SIZE = 100 # 一個訓練包中的數據個數,數字越小
# 越接近隨機梯度下降,越大越接近梯度下降
LEARNING_RATE_BASE = 0.8 # 基礎學習率
LEARNING_RATE_DECAY = 0.99 # 學習率衰減率
REGULARIZATION_RATE = 0.0001 # 正則化項系數
TRAINING_STEPS = 30000 # 訓練輪數
MOVING_AVG_DECAY = 0.99 # 滑動平均衰減率
# 定義一個輔助函數,給定神經網絡的輸入和所有參數,計算神經網絡的前向傳播結果
def inference(input_tensor, avg_class, weights1, biases1,
weights2, biases2):
# 當沒有提供滑動平均類時,直接使用參數當前取值
if avg_class == None:
# 計算隱藏層前向傳播結果
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
# 計算輸出層前向傳播結果
return tf.matmul(layer1, weights2) + biases2
else:
# 首先計算變量的滑動平均值,然后計算前向傳播結果
layer1 = tf.nn.relu(
tf.matmul(input_tensor, avg_class.average(weights1)) +
avg_class.average(biases1))
return tf.matmul(
layer1, avg_class.average(weights2)) + avg_class.average(biases2)
# 訓練模型的過程
def train(mnist):
x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
# 生成隱藏層參數
weights1 = tf.Variable(
tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
# 生成輸出層參數
weights2 = tf.Variable(
tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
# 計算前向傳播結果,不使用參數滑動平均值 avg_class=None
y = inference(x, None, weights1, biases1, weights2, biases2)
# 定義訓練輪數變量,指定為不可訓練
global_step = tf.Variable(0, trainable=False)
# 給定滑動平均衰減率和訓練輪數的變量,初始化滑動平均類
variable_avgs = tf.train.ExponentialMovingAverage(
MOVING_AVG_DECAY, global_step)
# 在所有代表神經網絡參數的可訓練變量上使用滑動平均
variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
# 計算使用滑動平均值后的前向傳播結果
avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)
# 計算交叉熵作為損失函數
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
# 計算L2正則化損失函數
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
regularization = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularization
# 設置指數衰減的學習率
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step, # 當前迭代輪數
mnist.train.num_examples / BATCH_SIZE, # 過完所有訓練數據的迭代次數
LEARNING_RATE_DECAY)
# 優化損失函數
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
loss, global_step=global_step)
# 反向傳播同時更新神經網絡參數及其滑動平均值
with tf.control_dependencies([train_step, variables_avgs_op]):
train_op = tf.no_op(name='train')
# 檢驗使用了滑動平均模型的神經網絡前向傳播結果是否正確
correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 初始化會話并開始訓練
with tf.Session() as sess:
tf.global_variables_initializer().run()
# 準備驗證數據,用于判斷停止條件和訓練效果
validate_feed = {x: mnist.validation.images,
y_: mnist.validation.labels}
# 準備測試數據,用于模型優劣的最后評價標準
test_feed = {x: mnist.test.images, y_: mnist.test.labels}
# 迭代訓練神經網絡
for i in range(TRAINING_STEPS):
if i%1000 == 0:
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
print("After %d training step(s), validation accuracy using average "
"model is %g " % (i, validate_acc))
xs, ys = mnist.train.next_batch(BATCH_SIZE)
sess.run(train_op, feed_dict={x: xs, y_: ys})
# 訓練結束后在測試集上檢測模型的最終正確率
test_acc = sess.run(accuracy, feed_dict=test_feed)
print("After %d training steps, test accuracy using average model "
"is %g " % (TRAINING_STEPS, test_acc))
# 主程序入口
def main(argv=None):
mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
train(mnist)
# Tensorflow主程序入口
if __name__ == '__main__':
tf.app.run()另外有需要云服務器可以了解下創新互聯scvps.cn,海內外云服務器15元起步,三天無理由+7*72小時售后在線,公司持有idc許可證,提供“云服務器、裸金屬服務器、高防服務器、香港服務器、美國服務器、虛擬主機、免備案服務器”等云主機租用服務以及企業上云的綜合解決方案,具有“安全穩定、簡單易用、服務可用性高、性價比高”等特點與優勢,專為企業上云打造定制,能夠滿足用戶豐富、多元化的應用場景需求。
分享文章:Tensorflow訓練MNIST手寫數字識別模型-創新互聯
網站網址:http://www.yijiale78.com/article32/ddhcsc.html
成都網站建設公司_創新互聯,為您提供響應式網站、網站設計公司、網頁設計公司、App設計、微信公眾號、靜態網站
聲明:本網站發布的內容(圖片、視頻和文字)以用戶投稿、用戶轉載內容為主,如果涉及侵權請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網站立場,如需處理請聯系客服。電話:028-86922220;郵箱:631063699@qq.com。內容未經允許不得轉載,或轉載時需注明來源: 創新互聯