从零开始玩转TensorFlow:小明的机器学习故事 5

news/2025/2/25 9:40:44

图像识别的挑战

1 故事引入:小明的“图像识别”大赛

小明从学校里听说了一个有趣的比赛:“美食图像识别”。参赛者需要训练计算机,看一张食物照片(例如披萨、苹果、汉堡等),就能猜出这是什么食物。听起来非常酷,但是怎么让计算机“看懂”图片呢?

小明想:
“传统的神经网络(全连接网络)之前用来做数字分类效果还行,但这些美食图像不但颜色复杂,而且分辨率更高,直接用全连接网络可能会效果一般。听说有个东西叫卷积神经网络(CNN),特别适合图像识别。我就来好好研究一下吧!”


2 为什么CNN更擅长看图?

2.1 普通神经网络:所有像素一锅端

在最简单的全连接网络里,每个神经元都要处理图像里所有像素的信息。想象下:

  • 你有一大张图,每个位置都有信息。
  • “侦探们”都拥挤在一起,每个人想“盯”整个图像。
  • 数据一多就会非常混乱,大家根本不好分工,效率也低。
2.2 卷积神经网络:给侦探们分配“放大镜”

CNN 里有一层又一层的小“滤镜”(卷积核),它们就像给侦探们每人发了一个“放大镜”,让他们专注查看图像上的某一块区域,从而提取“边缘”“角”“颜色块”等重要特征。

  • 卷积层(Convolution Layer):这个层就负责把一张图分成小块挨个扫描,发现局部特征。
  • 池化层(Pooling Layer):把提取到的“局部发现”做简化,让整体数据量更小;同时,对位置的小幅变化也更有耐心。
  • 全连接层(Dense Layer):将各个层提取到的特征进行综合决策,最终输出“这是苹果呢?还是披萨?还是汉堡?”。

正因为这样分工协作,CNN 在图像识别任务上往往能大显身手。


3 小明的热身实验:CIFAR-10

比赛数据通常比较大,小明想先试试 CNN 的“套路”。他找来 CIFAR-10 这个小数据集做热身。CIFAR-10 共有 10 个类别(如飞机、汽车、鸟、猫、船等),每张图都是彩色的,但只有 32×32 像素,大小和图像内容都比较简单,正好适合入门。

3.1 准备数据与环境
import tensorflow as tf
from tensorflow import keras

# 下载并加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# 缩放像素值到[0,1]区间
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

num_classes = 10  # 一共10个类别
  1. x_train, y_train:训练数据,包含图像及其对应标签。
  2. x_test, y_test:测试数据,用来在训练结束后考核模型。
  3. 归一化:像素值从 0~255 变成 0~1,加快模型收敛。

3.2 CNN 模型结构:小明的“放大镜团队”
model = keras.Sequential([
    # 第一组:卷积 + 池化
    keras.layers.Conv2D(32, (3,3), activation='relu', padding='same',
                        input_shape=(32, 32, 3)),
    keras.layers.MaxPooling2D((2,2)),

    # 第二组:卷积 + 池化
    keras.layers.Conv2D(64, (3,3), activation='relu', padding='same'),
    keras.layers.MaxPooling2D((2,2)),

    # 将特征图展开,再接全连接层
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()
  1. Conv2D(32, 3×3):32 个滤镜,每个滤镜关注 3×3 区域;padding='same' 表示卷积后图大小不变。
  2. MaxPooling2D(2×2):对特征图做 2×2 的“精华提取”。
  3. Flatten:把卷积/池化后得到的多维特征图拉平成一维数组。
  4. Dense(128):中间的全连接层,用于更深层次的特征整合。
  5. Dense(num_classes):输出 10 个类别(CIFAR-10 就是 0~9 这 10 类)。

3.3 模型训练:让侦探团队学会识图
history = model.fit(
    x_train, y_train,
    epochs=10,
    batch_size=64,
    validation_split=0.2
)
  • 训练过程:喂给网络很多训练图片,网络会先猜测是哪一类,然后根据“猜的结果”和“真实标签”之间的差距来更新参数。
  • epochs:训练轮数;batch_size:一次要处理多少张图片。
  • validation_split=0.2:从训练集中划出 20% 的数据做验证,便于实时观察模型的泛化能力。

3.4 训练成果可视化:准确率与损失曲线

在实际操作中,我们通常还会绘制训练和验证的准确率(accuracy)和损失值(loss),看看模型是否正在稳步提高。

  • 如果验证准确率突然下降,说明可能出现 过拟合
  • 如果训练准确率和验证准确率都一起上升,恭喜你,模型健康成长。

以下是一段常见的可视化示例代码,演示如何使用 matplotlib 绘制训练过程中 准确率(accuracy)损失值(loss) 的变化曲线。

import matplotlib.pyplot as plt

# 从 history 对象中获取训练过程中的准确率和损失值
acc = history.history['accuracy']              # 训练准确率
val_acc = history.history['val_accuracy']      # 验证准确率
loss = history.history['loss']                 # 训练损失
val_loss = history.history['val_loss']         # 验证损失

epochs_range = range(len(acc))  # 横坐标:训练的轮数

plt.figure(figsize=(12, 5))

# 1. 绘制准确率曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.title('Accuracy Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

# 2. 绘制损失值曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.title('Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(loc='upper right')

plt.show()

代码解释

  1. 获取准确率和损失值

    • history.history['accuracy']:训练集上的准确率随 epoch 的变化。
    • history.history['val_accuracy']:验证集上的准确率随 epoch 的变化。
    • history.history['loss']history.history['val_loss']:分别是训练集和验证集的损失值。
  2. 绘制图像

    • plt.subplot(1, 2, 1)plt.subplot(1, 2, 2) 用于将图像一分为二,分别在左、右两边绘制准确率和损失值曲线。
    • plt.plot(epochs_range, ...):将不同 epoch 的值连成线,观察趋势。
    • plt.title()plt.xlabel()plt.ylabel():添加标题和坐标轴标签,便于阅读。
    • plt.legend():显示图例,区分训练曲线和验证曲线。

运行该段代码后,你会看到两个并排的折线图:左边是 准确率 随训练轮数的变化,右边是 损失值 随训练轮数的变化。通过它们,你可以直观判断模型是否持续学习收敛,或是否出现 过拟合(若验证准确率下降或验证损失上升,而训练集表现持续改善,往往说明过拟合)。

示例输出:
在这里插入图片描述

4 最终的“毕业考”:测试集 & 图片可视化

小明已经把 CNN 训练好了,现在是让模型“毕业考”的时候,也就是在测试集 (x_test, y_test) 上检验它的表现。

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print("Test accuracy:", test_acc)

示例输出:

Test accuracy: 0.7005000114440918

成功得到一个满意的准确率后,小明心中一阵窃喜:
“看来这‘放大镜团队’果然名不虚传!”

4.1 随机挑几张测试图,看看模型预测得如何
import random
import numpy as np
import matplotlib.pyplot as plt

indices = random.sample(range(len(x_test)), 5)
for i in indices:
    img = x_test[i]
    label = y_test[i][0]
    pred = model.predict(img[np.newaxis, ...])
    pred_label = np.argmax(pred, axis=1)[0]

    plt.imshow(img)
    plt.title(f"Real: {label}, Pred: {pred_label}")
    plt.show()
  • random.sample(...):随机挑几个测试样本进行展示。
  • plt.imshow(img):显示这张 32×32 像素的图像。
  • Real: {label}:数据集中给出的真实标签。
  • Pred: {pred_label}:模型预测的标签。

当图像分辨率低时,我们人眼看起来可能会觉得比较模糊。但只要模型学到了“像素间的相关性”,就能“看”得出哪张更像猫、哪张更像船。

示例输出:
在这里插入图片描述
在这里插入图片描述


5 解读示例:当模型识别到一艘船

下图是一个示例输出(当你运行上面代码时,可能得到不同的随机图)。假设标题写着:

Real: 8, Pred: 8

并且画面看上去有一点类似“海面上白色物体”的模糊画面。

  • CIFAR-10 标签 8 通常代表“船”(ship)
  • 虽然图像分辨率低,但海洋深蓝和船体白色的对比能给网络提供重要线索。
  • 模型预测也给出了“8”,说明它成功认定这是“船”。
  • 如果你在比赛里,这是一个 “预测正确” 的案例。

如果实际图像是“狗”,然而模型给了“猫”之类的预测,就可以通过观察图像特征、数据增强或调整网络结构来做进一步优化。这样的人机对照可以帮你快速找到模型的盲点。


6 故事总结:小明的收获

  1. 卷积神经网络为什么行?
    • 通过“局部扫描 + 池化精简”,CNN 能更好地从图像中提取关键特征,且参数量更少,不容易过拟合。
  2. CIFAR-10先热身
    • 小明在 CIFAR-10 上先试手,发现 CNN 能得到不错的效果,足以说明原理可行。
  3. 看结果很重要
    • 训练曲线可以帮助判断模型是否过拟合或仍在上升空间。
    • 最终在测试集上打印几张预测结果,能让人更直观地看到“模型脑子里想的”和“真实情况”一致性如何。
  4. 阶段性成就感
    • 完整走完这套流程,小明对 CNN 有了更深理解,也为他在 “美食图像识别” 大赛上取得好成绩打下坚实基础。

最终寄语

通过这一章节,我们以 小明的故事 为主线,先介绍为什么卷积神经网络(CNN)更适合图像,再到如何用一个小练习数据集(CIFAR-10)搭建模型、训练、测试,并可视化预测结果。

  • 逻辑链路:问题(美食识别)→ 为什么CNN → 如何CNN → 小实验→ 看结果→ 结果分析
  • 核心方法:理解卷积、池化、全连接的作用,知道如何用代码实现并调试。
  • 图像可视化:从中可以直观看到模型预测是否正确,尤其是像CIFAR-10这种低分辨率图像,对人眼来说模糊,但模型却能学到对应特征。如果出现“预测错”,就能帮助我们快速找到改进思路。

现在,小明和你都算踏进了 “计算机视觉” 这个广阔的领域,下一步可以在更大、更真实的数据上继续折腾!别忘了多实践、多观察,最后说不定你也会在比赛中取得令人惊喜的成绩。加油!


http://www.niftyadmin.cn/n/5865336.html

相关文章

flutter Column嵌套ListView高度自适应问题

1.限制最大高度500,当布局高度小于500时高度自适应包裹 //当布局外不需要包裹Container时,使用ConstrainedBox(constraints: BoxConstraints(maxHeight: 500,minHeight: 0),child: Column()) _body(){return Container(constraints: BoxConstraints(max…

vue3学习3-route

创建路由器: 应用路由器: 路由展示区RouterView 和 路由跳转RouterLink: 路由组件(在路由配置文件中配置的)一般放到pages/views文件夹下 路由组件切换的时候执行的是 挂载/卸载操作 onMounted / onUnmouted 路由器两…

3dtiles平移旋转工具制作

3dtiles平移旋转缩放原理及可视化工具实现 背景 平时工作中,通过cesium平台来搭建一个演示场景是很常见的事情。一般来说,演示场景不需要多完善的功能,但是需要一批三维模型搭建,如厂房、电力设备、园区等。在实际搭建过程中&…

一文讲解Redis为什么读写性能高以及I/O复用相关知识点

Redis为什么读写性能高呢? Redis 的速度⾮常快,单机的 Redis 就可以⽀撑每秒十几万的并发,性能是 MySQL 的⼏⼗倍。原因主要有⼏点: ①、基于内存的数据存储,Redis 将数据存储在内存当中,使得数据的读写操…

协方差(Covariance)与得分函数:从Fisher信息矩阵看统计关联

协方差与得分函数:从Fisher信息矩阵看统计关联 协方差(Covariance)是统计学中一个基础但强大的概念,它描述了两个随机变量之间的关系。在Fisher信息矩阵中,协方差以一种特别的形式出现:得分函数的协方差。…

vue js-web-screen-shot浏览器截取其他非全屏窗口界面

网页截屏 js-web-screen-shot 截取其他窗口 显示不全问题 npm 安装 js-web-screen-shot npm install js-web-screen-shot --savejs-web-screen-shot默认截屏是从左下角开始的,修改成左上角开始,然后编辑cropBoxInfo参数宽高进行截取,目前截…

计算机网络与通讯知识总结

计算机网络与通讯知识总结 基础知识总结 1)FTP:文件传输 SSH:远程登录 HTTP:网址访问 2)‌交换机 定义‌:一种基于MAC地址实现局域网(LAN)内数据高速转发的网络设备,可为接入设备提供独享通信通道‌。 -‌ 核心功能‌: 1.数据链路层(OSI第二层)工作,通过MAC地址…

Spring高级篇-Spring IOC容器 Aware 接口

一、概述 在Spring框架中,IOC(Inversion of Control)容器负责管理应用程序中的对象(即Bean)的生命周期和依赖关系。Spring提供了一系列的Aware接口,允许Bean在初始化时获取Spring容器中的某些资源或信息。…