Train a CNN on MNIST with TensorFlow/Keras (Step-by-Step)

A single, copy-paste-ready TensorFlow/Keras CNN script for MNIST, plus concise setup guidance, visualization tips, and an FAQ.


Tags: Deep Learning, TensorFlow, Keras, Tutorial, MNIST
Reading time: ~8–12 minutes

Why this post

A compact, production-friendly starter for image classification with a small CNN on MNIST. You'll get environment setup, clean code, training and evaluation, training curves, a confusion matrix, and model saving/reloading.


1) Environment

python -m pip install -U pip
python -m pip install -U tensorflow matplotlib scikit-learn
# Optional (for TensorBoard):
# python -m pip install -U tensorboard

GPU note: everything runs fine on CPU. If you have an NVIDIA GPU, install the matching driver/CUDA/cuDNN first, then pip install tensorflow. On Apple Silicon, recent tensorflow wheels work out of the box; if yours doesn't, try tensorflow-macos.

Quick check:

python - <<'PY'
import tensorflow as tf
print("TF:", tf.__version__)
print("Devices:", tf.config.list_physical_devices())
PY

2) What we’ll build

A compact CNN:

  • Conv2D(32) → Conv2D(64) → MaxPool → Dropout(0.25)
  • Flatten → Dense(128) → Dropout(0.5) → Dense(10, softmax)
  • Loss: SparseCategoricalCrossentropy; Optimizer: Adam; Metric: accuracy.

Expected test accuracy: ~99% after ~5–10 epochs.
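
A quick parameter sanity check (compare with model.summary() in the script below): with padding="same" and a single 2×2 max-pool, the tensor entering Flatten is 14×14×64 = 12,544 values, so Dense(128) alone holds 12,544 × 128 + 128 ≈ 1.61M weights, while the two conv layers contribute only ~19K. The Flatten → Dense layer dominates the ~1.63M total.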


3) Full, Ready-to-Run Script

Save as tf_mnist_cnn.py and run: python tf_mnist_cnn.py.

# tf_mnist_cnn.py
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# ---------- 0) Reproducibility (optional) ----------
SEED = 42
random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

# Limit GPU memory growth (optional)
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError:
        pass  # memory growth must be set before the GPU is initialized

# ---------- 1) Data ----------
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32")/255.0
x_test  = x_test.astype("float32")/255.0
x_train = np.expand_dims(x_train, -1)  # (N,28,28,1)
x_test  = np.expand_dims(x_test,  -1)

# Train/val split
from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.1, random_state=SEED, stratify=y_train
)
print("Shapes:", x_train.shape, x_val.shape, x_test.shape)

# ---------- 2) Model ----------
def build_model():
    inputs = keras.Input(shape=(28,28,1))
    x = layers.Conv2D(32, (3,3), activation="relu", padding="same")(inputs)
    x = layers.Conv2D(64, (3,3), activation="relu", padding="same")(x)
    x = layers.MaxPooling2D((2,2))(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    return keras.Model(inputs, outputs, name="mnist_cnn")

model = build_model()
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

# ---------- 3) Callbacks ----------
os.makedirs("checkpoints", exist_ok=True)
ckpt_path = "checkpoints/best_mnist_cnn.keras"
callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=3, restore_best_weights=True),
    keras.callbacks.ModelCheckpoint(ckpt_path, monitor="val_accuracy", save_best_only=True),
    # keras.callbacks.TensorBoard(log_dir="logs")  # optional
]

# ---------- 4) Train ----------
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=10,
    batch_size=128,
    callbacks=callbacks,
    verbose=1
)

# ---------- 5) Evaluate ----------
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"[Test] loss={test_loss:.4f}, acc={test_acc:.4f}")

# ---------- 6) Curves ----------
def plot_history(h):
    hist = h.history
    fig, ax = plt.subplots(1,2, figsize=(10,4))
    ax[0].plot(hist["loss"], label="train"); ax[0].plot(hist["val_loss"], label="val")
    ax[0].set_title("Loss"); ax[0].legend()
    ax[1].plot(hist["accuracy"], label="train"); ax[1].plot(hist["val_accuracy"], label="val")
    ax[1].set_title("Accuracy"); ax[1].legend()
    plt.tight_layout(); plt.show()

plot_history(history)

# ---------- 7) Predictions Grid ----------
def visualize_predictions(m, x, y, n=25):
    idx = np.random.choice(len(x), n, replace=False)
    imgs, labels = x[idx], y[idx]
    preds = m.predict(imgs, verbose=0).argmax(axis=1)
    rows = cols = int(np.sqrt(n))
    plt.figure(figsize=(10,10))
    for i in range(n):
        plt.subplot(rows, cols, i+1)
        plt.imshow(imgs[i].squeeze(), cmap="gray")
        title = f"T:{labels[i]} P:{preds[i]}"
        color = "green" if labels[i]==preds[i] else "red"
        plt.title(title, fontsize=10, color=color)
        plt.axis("off")
    plt.tight_layout(); plt.show()

visualize_predictions(model, x_test, y_test, n=25)

# ---------- 8) Confusion Matrix ----------
from sklearn.metrics import confusion_matrix, classification_report
y_pred = model.predict(x_test, verbose=0).argmax(axis=1)
cm = confusion_matrix(y_test, y_pred)
print("Classification Report:\n", classification_report(y_test, y_pred, digits=4))

plt.figure(figsize=(6,6))
plt.imshow(cm, cmap="Blues")
plt.title("Confusion Matrix"); plt.xlabel("Predicted"); plt.ylabel("True")
plt.colorbar()
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, str(cm[i, j]), ha="center", va="center")
plt.tight_layout(); plt.show()

# ---------- 9) Save & Reload ----------
save_path = "saved_mnist_cnn.keras"  # Keras 3 (TF >= 2.16) requires a .keras or .h5 suffix
model.save(save_path)
print(f"Model saved to: {save_path}")
reloaded = keras.models.load_model(save_path)
loss2, acc2 = reloaded.evaluate(x_test, y_test, verbose=0)
print(f"[Reloaded] loss={loss2:.4f}, acc={acc2:.4f}")

4) Upgrade ideas

  • Data augmentation: small random translations/rotations to improve robustness (see the sketch after this list).
  • LR schedule: cosine decay or ReduceLROnPlateau (also shown below).
  • Deeper nets: add a third conv block or BatchNorm layers.
  • TensorBoard: monitor metrics and graph.
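
A minimal sketch of the first two ideas, assuming the same imports and setup as the script above (the augmentation ranges and schedule values are illustrative, not tuned):

# Keras preprocessing layers: active during fit(), no-ops at inference.
augment = keras.Sequential([
    layers.RandomTranslation(0.1, 0.1),  # shift up to 10% of height/width
    layers.RandomRotation(0.05),         # rotate up to ~±18 degrees
], name="augment")

inputs = keras.Input(shape=(28, 28, 1))
x = augment(inputs)
x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
x = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Dropout(0.25)(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation="softmax")(x)
aug_model = keras.Model(inputs, outputs, name="mnist_cnn_aug")

# Halve the learning rate whenever val_loss plateaus for 2 epochs.
lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=2, min_lr=1e-5
)

Compile, fit, and evaluate exactly as before, appending lr_callback to the callbacks list. With augmentation you can usually train longer (15–20 epochs) before overfitting sets in.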

5) FAQ

Q: No GPU—still OK?
A: Yes. MNIST is tiny; CPU is fine (usually <5 minutes).

Q: Accuracy below 99%?
A: Train for 15–20 epochs, add a third conv block, lower the learning rate to 5e-4, or make sure the validation split is stratified (stratify=y_train, as in the script).

Q: Overfitting?
A: Increase the Dropout rates, add light augmentation (see section 4), or rely on the EarlyStopping callback already enabled in the script.

Q: Why SparseCategoricalCrossentropy?
A: The labels are integer class IDs (0–9), so there is no need to one-hot encode them; a quick comparison follows.
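
If you prefer one-hot labels, the dense variant is equivalent; a minimal illustration (keras.utils.to_categorical is standard Keras):

# Integer labels, as MNIST ships them (e.g. 5, 0, 4, ...):
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# One-hot labels: shape (N, 10) instead of (N,)
y_train_oh = keras.utils.to_categorical(y_train, num_classes=10)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])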


6) License

The code is free to learn from and reuse under MIT-style terms. Cite this post if it helped. 👍


