下载mnist库
- 用官方链接下载四个文件,分别是训练的图片、标签和识别的图片、标签,下载完后放在程序目录下的MNIST_data文件夹下
- 如果不下载通过导包的方式,也会自动下载好
1
from tensorflow.examples.tutorials.mnist import input_data
导入数据
TF提供了方便的封装,可以直接加载MNIST数据为我们期望的格式
1 | mnist = input_data.read_data_sets('MNIST_data', one_hot = True) |
建立一层神经网络,由于经常需要添加神经层,干脆封成一个函数
多分类任务,通常使用softmax regression模型。工作原理就是讲某一类的特征相加,然后把这些特征转换为判定是这一类的概率。
1 | import tensorflow as tf |
Train方法(优化算法)
这里采用梯度下降法
1 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) |
训练
1 | # 开始训练之前,首先要构建图,InteractiveSession可以被注册为默认的session,这样之后运算会方便 |
验证
1 | correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y_, 1)) |