Skip to content
WONGCW 網誌
  • 首頁
  • 論壇
  • 微博
  • 壁紙下載
  • 免費圖床
  • 視頻下載
  • 聊天室
  • SEO工具
  • 支援中心
  • 表格製作
  • More
    • 在線名片
    • 網頁搜索
    • 天氣預報
    • 二維碼生成器
    • WordPress 插件及主題下載
  • Search Icon

WONGCW 網誌

記錄生活經驗與點滴

TensorFlow – 基于 CNN 数字识别

TensorFlow – 基于 CNN 数字识别

2018-09-24 Comments 0 Comment

TensorFlow 实现基于 CNN 数字识别的代码

任务时间:时间未知

前期准备

TensorFlow 相关 API 可以到在实验 TensorFlow – 相关 API 中学习。

训练数据下载:

wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/t10k-images-idx3-ubyte.gz
wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/t10k-labels-idx1-ubyte.gz
wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/train-images-idx3-ubyte.gz
wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/train-labels-idx1-ubyte.gz

CNN 模型构建

示例代码:

现在您可以在 /home/ubuntu 目录下创建源文件 mnist_model.py,内容可参考:

示例代码:/home/ubuntu/mnist_model.py
#!/usr/bin/python
# -*- coding: utf-8 -*

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def deepnn(x):

  with tf.name_scope('reshape'):
    x_image = tf.reshape(x, [-1, 28, 28, 1])

  #第一层卷积层,卷积核为5*5,生成32个feature maps.
  with tf.name_scope('conv1'):
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) #激活函数采用relu

  # 第一层池化层,下采样2.
  with tf.name_scope('pool1'):
    h_pool1 = max_pool_2x2(h_conv1)

  # 第二层卷积层,卷积核为5*5,生成64个feature maps
  with tf.name_scope('conv2'):
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)#激活函数采用relu

  # 第二层池化层,下采样2.
  with tf.name_scope('pool2'):
    h_pool2 = max_pool_2x2(h_conv2)

  #第一层全连接层,将7x7x64个feature maps与1024个features全连接
  with tf.name_scope('fc1'):
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

  #dropout层,训练时候随机让某些隐含层节点权重不工作
  with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

  # 第二层全连接层,1024个features和10个features全连接
  with tf.name_scope('fc2'):
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
  return y_conv, keep_prob

#卷积
def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

#池化
def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')
#权重
def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

#偏置
def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

训练 CNN 模型

示例代码:

现在您可以在 /home/ubuntu 目录下创建源文件 train_mnist_model.py,内容可参考:

示例代码:/home/ubuntu/train_mnist_model.py
#!/usr/bin/python
# -*- coding: utf-8 -*

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

import mnist_model

FLAGS = None


def main(_):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  #输入变量,mnist图片大小为28*28
  x = tf.placeholder(tf.float32, [None, 784])

  #输出变量,数字是1-10
  y_ = tf.placeholder(tf.float32, [None, 10])

  # 构建网络,输入—>第一层卷积—>第一层池化—>第二层卷积—>第二层池化—>第一层全连接—>第二层全连接
  y_conv, keep_prob = mnist_model.deepnn(x)

  #第一步对网络最后一层的输出做一个softmax,第二步将softmax输出和实际样本做一个交叉熵
  #cross_entropy返回的是向量
  with tf.name_scope('loss'):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                            logits=y_conv)

  #求cross_entropy向量的平均值得到交叉熵
  cross_entropy = tf.reduce_mean(cross_entropy)

  #AdamOptimizer是Adam优化算法:一个寻找全局最优点的优化算法,引入二次方梯度校验
  with tf.name_scope('adam_optimizer'):
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

  #在测试集上的精确度
  with tf.name_scope('accuracy'):
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    correct_prediction = tf.cast(correct_prediction, tf.float32)
  accuracy = tf.reduce_mean(correct_prediction)

  #将神经网络图模型保存本地,可以通过浏览器查看可视化网络结构
  graph_location = tempfile.mkdtemp()
  print('Saving graph to: %s' % graph_location)
  train_writer = tf.summary.FileWriter(graph_location)
  train_writer.add_graph(tf.get_default_graph())

  #将训练的网络保存下来
  saver = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5000):
      batch = mnist.train.next_batch(50)
      if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})#输入是字典,表示tensorflow被feed的值
        print('step %d, training accuracy %g' % (i, train_accuracy))
      train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    test_accuracy = 0
    for i in range(200):
      batch = mnist.test.next_batch(50)
      test_accuracy += accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}) / 200;

    print('test accuracy %g' % test_accuracy)

    save_path = saver.save(sess,"mnist_cnn_model.ckpt")

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='./',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

然后执行:

cd /home/ubuntu;
python train_mnist_model.py

训练的时间会较长,可以喝杯茶耐心等待。

执行结果:

step 3600, training accuracy 0.98
step 3700, training accuracy 0.98
step 3800, training accuracy 0.96
step 3900, training accuracy 1
step 4000, training accuracy 0.98
step 4100, training accuracy 0.96
step 4200, training accuracy 1
step 4300, training accuracy 1
step 4400, training accuracy 0.98
step 4500, training accuracy 0.98
step 4600, training accuracy 0.98
step 4700, training accuracy 1
step 4800, training accuracy 0.98
step 4900, training accuracy 1
test accuracy 0.9862

测试 CNN 模型

下载测试图片

下载 test_num.zip

cd /home/ubuntu
wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/test_num.zip

解压测试图片包

解压 test_num.zip,其中 1-9.png 为 1-9 数字图片。

unzip test_num.zip

实现 predict 代码

现在您可以在 /home/ubuntu 目录下创建源文件 predict_mnist_model.py,内容可参考:

示例代码:/home/ubuntu/predict_mnist_model.py
#!/usr/bin/python
# -*- coding: utf-8 -*

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

import mnist_model
from PIL import Image, ImageFilter

def load_data(argv):

    grayimage = Image.open(argv).convert('L')
    width = float(grayimage.size[0])
    height = float(grayimage.size[1])
    newImage = Image.new('L', (28, 28), (255))

    if width > height:
        nheight = int(round((20.0/width*height),0))
        if (nheigth == 0):
            nheigth = 1
        img = grayimage.resize((20,nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
        wtop = int(round(((28 - nheight)/2),0))
        newImage.paste(img, (4, wtop))
    else:
        nwidth = int(round((20.0/height*width),0))
        if (nwidth == 0):
            nwidth = 1
        img = grayimage.resize((nwidth,20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
        wleft = int(round(((28 - nwidth)/2),0))
        newImage.paste(img, (wleft, 4))

    tv = list(newImage.getdata())
    tva = [ (255-x)*1.0/255.0 for x in tv]
    return tva

def main(argv):

    imvalue = load_data(argv)

    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    y_conv, keep_prob = mnist_model.deepnn(x)

    y_predict = tf.nn.softmax(y_conv)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, "mnist_cnn_model.ckpt")
        prediction=tf.argmax(y_predict,1)
        predint = prediction.eval(feed_dict={x: [imvalue],keep_prob: 1.0}, session=sess)
        print (predint[0])

if __name__ == "__main__":
    main(sys.argv[1])

然后执行:

cd /home/ubuntu;
python predict_mnist_model.py 1.png

执行结果:

1

你可以修改 1.png 为 1-9.png 中任意一个

完成

任务时间:时间未知

恭喜,您已完成本实验内容

您可进行更多 TensorFlow 的系列教程:

  • TensorFlow – 相关 API
  • TensorFlow – 线性回归

关于 TensorFlow 的更多资料可参考 TensorFlow 官网 。

分享此文:

  • 分享到 Twitter(在新視窗中開啟)
  • 按一下以分享至 Facebook(在新視窗中開啟)
  • 分享到 WhatsApp(在新視窗中開啟)
  • 按一下以分享到 Telegram(在新視窗中開啟)
  • 分享到 Pinterest(在新視窗中開啟)
  • 分享到 Reddit(在新視窗中開啟)
  • 按一下即可分享至 Skype(在新視窗中開啟)
  • 按一下即可以電子郵件傳送連結給朋友(在新視窗中開啟)
  • 點這裡列印(在新視窗中開啟)

相關


TensorFlow

Post navigation

PREVIOUS
TensorFlow – 深层神经网络 (deep neural networks)
NEXT
TensorFlow – 基于 CNN 破解验证码

發表迴響 取消回覆

這個網站採用 Akismet 服務減少垃圾留言。進一步了解 Akismet 如何處理網站訪客的留言資料。

More results...

Generic filters
Exact matches only
Search in title
Search in content
Search in excerpt
Filter by 分類
網站公告
Featured
限時免費
ESET NOD32
WINDOWS 10 &11 INSIDER PREVIEW
Windows 軟件下載
系統軟件
辦公軟件
圖像處理
影音媒體
網絡軟件
應用軟件
Mac 軟件下載
安卓軟件下載
網絡資訊
Mac資訊
Linux資訊
VPS資訊
NASA資訊
金融資訊
WhatsApp Stickers教學
WordPress資訊
WeChat資訊
PHP資訊
Plesk資訊
TensorFlow
教學資源
開源程序
網頁工具
SEO工具
醫療健康
旅遊及消閒
其他資訊
Content from
Content to
2018 年 9 月
一 二 三 四 五 六 日
 12
3456789
10111213141516
17181920212223
24252627282930
« 8 月   10 月 »

分類

  • 網站公告
  • 限時免費
  • ESET NOD32
  • WINDOWS 10 &11 INSIDER PREVIEW
  • Windows 軟件下載
  • 系統軟件
  • 辦公軟件
  • 圖像處理
  • 影音媒體
  • 網絡軟件
  • 應用軟件
  • Mac 軟件下載
  • 安卓軟件下載
  • 網絡資訊
  • Mac資訊
  • Linux資訊
  • VPS資訊
  • NASA資訊
  • WhatsApp Stickers教學
  • WordPress資訊
  • WeChat資訊
  • PHP資訊
  • Plesk資訊
  • TensorFlow
  • 教學資源
  • 開源程序
  • 網頁工具
  • SEO工具
  • 醫療健康
  • 旅遊及消閒
  • 其他資訊

彙整

近期文章

  • 微軟預計英國監管機構CMA將反對其收購動視暴雪將盡力贏得支持 2023-02-07
  • 地質學家突破性地發現地球構造板塊下隱藏的熔岩層 2023-02-07
  • 蘋果Mac mini還不如一台顯示器貴?我們跟它的產品營銷團隊聊了聊 2023-02-07
  • 一天內兩次遭遇7.8級大地震土耳其災區網絡流量暴跌94% 幾乎中斷 2023-02-07
  • NASA“踩下油門” 開始同時建造三枚巨大的月球火箭 2023-02-07
  • 蘋果高管討論M2芯片、Mac上的遊戲、來自英特爾的競爭等問題 2023-02-07
  • 研究發現氣候危機還極大地影響了孟加拉國重災區的心理健康 2023-02-07
  • 壞消息:海洋變暖使海洋寄生蟲數量減少 2023-02-07
  • Google宣布ChatGPT的競爭產品Bard 未來幾週將更廣泛上線 2023-02-07
  • AI輔助的地外生命搜索系統檢測到8個”非常可疑的”無線電信號 2023-02-07

熱門文章與頁面︰

  • DP vs HDMI 誰才是遊戲玩家最佳選擇?
  • ESET NOD32 LICENSE KEY (UPDATED 2023-01-17)
  • 打車叫到特斯拉不會開門很尷尬?官方介紹開關門方法
  • Explorer Patcher:讓Windows 11恢復Windows 10的行為特徵
  • 貝貝BiliBili – 在線一鍵下載B站視頻,可解析播放!
  • Koodo Reader – 免費開源的Epub 閱讀器,支持桌面版和網頁版
  • ESET NOD32 LICENSE KEY (UPDATED 2022-01-01)
  • AirTag機場實測:各項表現優異只有1個場景下無法生效
  • 下載WINDOWS10 登入畫面的圖片 (WINDOWS焦點)
  • YouTube By Click 2.2.86(含:註冊機序列號)

投遞稿件

歡迎各界人士投遞稿件到admin@wongcw.com

請提供以下資料:

1.你的名字

2.你的電郵

3.分類目錄

4.文章標題

5.文章摘要

6.文章內容

7.文章來源

 

聯繫我們

查詢,投稿,商務合作:
​admin@wongcw.com
​技術支援:
​support@wongcw.com
​客户服務:
​cs@wongcw.com

QQ群:833641851

快帆

MALUS

極度掃描

DMCA.com Protection Status

WONGCW 網誌

  • 免責聲明
  • 捐助我們
  • ThemeNcode PDF Viewer
  • ThemeNcode PDF Viewer SC
  • Events

服務器提供

本站使用之服務器由ikoula提供。

聯繫我們

查詢,投稿,商務合作:
​admin@wongcw.com
​技術支援:
​support@wongcw.com
​客户服務:
​cs@wongcw.com

QQ群:833641851

© 2023   All Rights Reserved.
X