Train a CNN on MNIST with TensorFlow/Keras (Step-by-Step)
This is a bilingual post (English first, then Chinese) with a single, copy-paste-ready TensorFlow CNN script for MNIST, plus concise guidance, visualization tips, and an FAQ.
Tags: Deep Learning, TensorFlow, Keras, Tutorial, MNIST
Reading time: ~8–12 minutes
Why this post
A compact, production-friendly starter for image classification using a small CNN on MNIST. You’ll get: environment setup, clean code, training/evaluation, plots, confusion matrix, and how to save/reload models.
1) Environment
python -m pip install -U pip
python -m pip install -U tensorflow matplotlib scikit-learn
# Optional (for TensorBoard):
# python -m pip install -U tensorboard
GPU note: Works fine on CPU. If you have an NVIDIA GPU, install the proper driver/CUDA/cuDNN first, then pip install tensorflow. On Apple Silicon, recent tensorflow wheels also work; if not, try tensorflow-macos.
Quick check:
python - <<'PY'
import tensorflow as tf
print("TF:", tf.__version__)
print("Devices:", tf.config.list_physical_devices())
PY
2) What we’ll build
A compact CNN:
- Conv2D(32) → Conv2D(64) → MaxPool → Dropout(0.25)
- Flatten → Dense(128) → Dropout(0.5) → Dense(10, softmax)
- Loss: SparseCategoricalCrossentropy; Optimizer: Adam; Metric: accuracy.
Expected test accuracy: ~99% after ~5–10 epochs.
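If you prefer the Sequential API, the same architecture can be sketched compactly as follows (the full functional-API script is in the next section; this version is equivalent, just shorter):

from tensorflow import keras
from tensorflow.keras import layers

# Compact Sequential sketch of the architecture described above.
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, (3, 3), activation="relu", padding="same"),
    layers.Conv2D(64, (3, 3), activation="relu", padding="same"),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.5),
    layers.Dense(10, activation="softmax"),
])
model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=["accuracy"])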
3) Full, Ready-to-Run Script
Save as tf_mnist_cnn.py and run: python tf_mnist_cnn.py.
# tf_mnist_cnn.py
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# ---------- 0) Reproducibility (optional) ----------
SEED = 42
random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)
# Limit GPU memory growth (optional)
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)  # allocate GPU memory on demand
    except RuntimeError:
        pass  # must be set before the GPU is initialized
# ---------- 1) Data ----------
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32")/255.0
x_test = x_test.astype("float32")/255.0
x_train = np.expand_dims(x_train, -1) # (N,28,28,1)
x_test = np.expand_dims(x_test, -1)
# Train/val split
from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(
x_train, y_train, test_size=0.1, random_state=SEED, stratify=y_train
)
print("Shapes:", x_train.shape, x_val.shape, x_test.shape)
# ---------- 2) Model ----------
def build_model():
inputs = keras.Input(shape=(28,28,1))
x = layers.Conv2D(32, (3,3), activation="relu", padding="same")(inputs)
x = layers.Conv2D(64, (3,3), activation="relu", padding="same")(x)
x = layers.MaxPooling2D((2,2))(x)
x = layers.Dropout(0.25)(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation="softmax")(x)
return keras.Model(inputs, outputs, name="mnist_cnn")
model = build_model()
model.summary()
model.compile(
optimizer=keras.optimizers.Adam(1e-3),
loss=keras.losses.SparseCategoricalCrossentropy(),
metrics=["accuracy"]
)
# ---------- 3) Callbacks ----------
os.makedirs("checkpoints", exist_ok=True)
ckpt_path = "checkpoints/best_mnist_cnn.keras"
callbacks = [
keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=3, restore_best_weights=True),
keras.callbacks.ModelCheckpoint(ckpt_path, monitor="val_accuracy", save_best_only=True),
# keras.callbacks.TensorBoard(log_dir="logs") # optional
]
# ---------- 4) Train ----------
history = model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=10,
batch_size=128,
callbacks=callbacks,
verbose=1
)
# ---------- 5) Evaluate ----------
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"[Test] loss={test_loss:.4f}, acc={test_acc:.4f}")
# ---------- 6) Curves ----------
def plot_history(h):
hist = h.history
fig, ax = plt.subplots(1,2, figsize=(10,4))
ax[0].plot(hist["loss"], label="train"); ax[0].plot(hist["val_loss"], label="val")
ax[0].set_title("Loss"); ax[0].legend()
ax[1].plot(hist["accuracy"], label="train"); ax[1].plot(hist["val_accuracy"], label="val")
ax[1].set_title("Accuracy"); ax[1].legend()
plt.tight_layout(); plt.show()
plot_history(history)
# ---------- 7) Predictions Grid ----------
def visualize_predictions(m, x, y, n=25):
idx = np.random.choice(len(x), n, replace=False)
imgs, labels = x[idx], y[idx]
preds = m.predict(imgs, verbose=0).argmax(axis=1)
rows = cols = int(np.sqrt(n))
plt.figure(figsize=(10,10))
for i in range(n):
plt.subplot(rows, cols, i+1)
plt.imshow(imgs[i].squeeze(), cmap="gray")
title = f"T:{labels[i]} P:{preds[i]}"
color = "green" if labels[i]==preds[i] else "red"
plt.title(title, fontsize=10, color=color)
plt.axis("off")
plt.tight_layout(); plt.show()
visualize_predictions(model, x_test, y_test, n=25)
# ---------- 8) Confusion Matrix ----------
from sklearn.metrics import confusion_matrix, classification_report
y_pred = model.predict(x_test, verbose=0).argmax(axis=1)
cm = confusion_matrix(y_test, y_pred)
print("Classification Report:\n", classification_report(y_test, y_pred, digits=4))
plt.figure(figsize=(6,6))
plt.imshow(cm, cmap="Blues")
plt.title("Confusion Matrix"); plt.xlabel("Predicted"); plt.ylabel("True")
plt.colorbar()
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, str(cm[i, j]), ha="center", va="center")
plt.tight_layout(); plt.show()
# ---------- 9) Save & Reload ----------
save_dir = "saved_mnist_cnn"
model.save(save_dir)
print(f"Model saved to: {save_dir}")
reloaded = keras.models.load_model(save_dir)
loss2, acc2 = reloaded.evaluate(x_test, y_test, verbose=0)
print(f"[Reloaded] loss={loss2:.4f}, acc={acc2:.4f}")
4) Upgrade ideas
- Data augmentation: small random translation/rotation to improve robustness.
- LR schedule: cosine decay or ReduceLROnPlateau.
- Deeper nets: add a third conv block or BatchNorm layers (see the sketch after this list).
- TensorBoard: monitor metrics and graph.
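Below is a minimal sketch of how these upgrades could be combined, reusing the data pipeline from the script above; the augmentation factors, layer sizes, and ReduceLROnPlateau settings are illustrative assumptions, not tuned values.

from tensorflow import keras
from tensorflow.keras import layers

# Sketch only: augmentation + BatchNorm conv layers + a third conv block + LR scheduling.
augment = keras.Sequential([
    layers.RandomRotation(0.05),         # small random rotations (assumed factor)
    layers.RandomTranslation(0.1, 0.1),  # up to ~10% shift in each direction (assumed factor)
], name="augment")

inputs = keras.Input(shape=(28, 28, 1))
x = augment(inputs)                      # active only during training
x = layers.Conv2D(32, (3, 3), padding="same", use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.Conv2D(64, (3, 3), padding="same", use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(128, (3, 3), padding="same", activation="relu")(x)  # third conv block
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs, name="mnist_cnn_plus")

model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=["accuracy"])

callbacks = [
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2, min_lr=1e-5),
    keras.callbacks.TensorBoard(log_dir="logs"),  # optional: view with `tensorboard --logdir logs`
]
# Then train as before: model.fit(x_train, y_train, validation_data=(x_val, y_val),
#                                 epochs=15, batch_size=128, callbacks=callbacks)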
5) FAQ
Q: No GPU—still OK?
A: Yes. MNIST is tiny; CPU is fine (usually <5 minutes).
Q: Accuracy <99%?
A: Train for 15–20 epochs, add another conv block, lower the LR to 5e-4, or make sure the validation split is stratified (stratify=y_train, as in the script).
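For example, lowering the learning rate is a one-line change to the compile step of the script above (5e-4 is a suggested starting point, not a tuned value):

model.compile(
    optimizer=keras.optimizers.Adam(5e-4),  # lower LR than the 1e-3 used above
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)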
Q: Overfitting?
A: Increase Dropout, add light augmentation, or rely on EarlyStopping.
Q: Why SparseCategoricalCrossentropy?
A: Labels are integer class IDs (0–9), no need to one-hot.
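As a quick illustration (not part of the original script): the same predictions yield the same loss whether you pass integer labels to SparseCategoricalCrossentropy or one-hot labels to CategoricalCrossentropy.

import numpy as np
from tensorflow import keras

y_int = np.array([3, 7])                           # integer class IDs, as MNIST provides them
y_onehot = keras.utils.to_categorical(y_int, 10)   # one-hot alternative

probs = np.full((2, 10), 0.05); probs[0, 3] = 0.55; probs[1, 7] = 0.55  # toy predicted probabilities

sparse = keras.losses.SparseCategoricalCrossentropy()(y_int, probs)
dense = keras.losses.CategoricalCrossentropy()(y_onehot, probs)
print(float(sparse), float(dense))                 # identical values, different label formats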
6) License
Code is MIT-friendly for learning and reuse. Cite this post if helpful. 👍
Train a CNN on MNIST with TensorFlow/Keras (Hands-On Guide)
Tags: Deep Learning, TensorFlow, Keras, Tutorial, MNIST
Reading time: ~8–12 minutes
Why this post
A concise, reusable starter template: MNIST handwritten-digit classification with a small CNN. You'll get environment setup, clean code, training/evaluation, visualizations, a confusion matrix, and model saving/loading.
1) Environment
python -m pip install -U pip
python -m pip install -U tensorflow matplotlib scikit-learn
# Optional (TensorBoard):
# python -m pip install -U tensorboard
GPU note: Runs fine on CPU. If you have an NVIDIA GPU, install the driver/CUDA/cuDNN first, then pip install tensorflow. On Apple Silicon, installing tensorflow usually works out of the box; if not, try tensorflow-macos.
Quick check:
python - <<'PY'
import tensorflow as tf
print("TF:", tf.__version__)
print("Devices:", tf.config.list_physical_devices())
PY
2) What we'll build
A small but capable CNN:
- Conv2D(32) → Conv2D(64) → MaxPool → Dropout(0.25)
- Flatten → Dense(128) → Dropout(0.5) → Dense(10, softmax)
- Loss: SparseCategoricalCrossentropy; Optimizer: Adam; Metric: accuracy.
Typically reaches ~99% test accuracy after 5–10 epochs.
3) Full, Ready-to-Run Script
Save as tf_mnist_cnn.py and run: python tf_mnist_cnn.py.
The code is identical to the English version (the single-file script above), so it is not repeated here; copy and run it directly.
4) Upgrade ideas
- Data augmentation: small translations/rotations to improve robustness.
- LR schedule: cosine decay or ReduceLROnPlateau.
- Deeper nets: add another conv block or BatchNorm layers.
- TensorBoard: a cleaner way to monitor training curves and the computation graph.
5) FAQ
Q: Can I run this without a GPU?
A: No problem. MNIST is tiny; a CPU usually finishes within a few minutes.
Q: Accuracy below 99%?
A: Train for more epochs (15–20), add another conv block, lower the learning rate to 5e-4, and make sure the validation split is stratified (stratify=y_train).
Q: What about overfitting?
A: Increase Dropout, add light data augmentation, and rely on EarlyStopping (already enabled above).
Q: Why SparseCategoricalCrossentropy?
A: Labels are integer class IDs 0–9, so there is no need to one-hot encode them manually; it keeps things simpler.
6) License
The code is free to use and learn from under MIT-style terms. If this post helped you, feel free to cite it. 🙌
Need a Jupyter-Notebook version or an extended CNN (with augmentation/LR schedule/BatchNorm)?
If you'd like a Notebook version or a more advanced CNN recipe, I can provide a runnable .ipynb (with bilingual English/Chinese comments).