Paddle手写数字识别

安装PaddlePaddle

环境准备

  • windows 64
  • acconda3 2022.05
1.1创建虚拟环境
  • 安装环境

    首先根据具体的Python版本创建Anaconda虚拟环境,PaddlePaddle的Anaconda安装支持以下四种python安装环境。

    python版本为3.6

1
conda create -n paddle_env python=3.6

​ python版本为3.7

1
conda create -n paddle_env python=3.6

​ python版本为3.8

1
conda create -n paddle_env python=3.6

​ python版本为3.9

1
conda create -n paddle_env python=3.6
进入创建的虚拟的环境
1
activate paddle_env

开始安装

  • cpu 版

    1
    conda install paddlepaddle==2.3.2 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/

​ 后面的网址是指定安装源,用清华源速度快一点

  • gpu 版

    1
    conda install paddlepaddle-gpu==2.3.2 cudatoolkit=11.6 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge

    关于cuda版本的问题不用担心,将cuda版本更新到最新即可,他是向下兼容的,具体可以看看这篇文章,查看cuda版本可以看看这篇文章

    安装过程中可能耗时比较久,中间可能会卡顿一下,按下回车即可

    安装完成

验证安装

输入python进入python解释器,输入import paddle,再输入paddle.utils.run_check(),如果出现PaddlePaddle is installed successfully!,说明已成功安装。

参考资料:Paddle 飞桨官网文档

实践:手写数字识别任务

准备

安装 Python 的 matplotlib 库和 numpy 库,matplotlib 库用于可视化图片,numpy 库用于处理数据。

1
2
# 使用 pip 工具安装 matplotlib 和 numpy
! python3 -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simple

这里可以也可以用conda 命令的,这里图省事,就直接粘贴过来了

实践

  • 数据集定义与加载

    1
    2
    3
    4
    5
    6
    7
    import paddle
    from paddle.vision.transforms import Normalize

    transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
    # 下载数据集并初始化 DataSet
    train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
    test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

    这里可能会遇到下载不了的情况,建议先考察自己的电脑网络问题,再考虑其他,笔者这里耽误了很多时间,就是因为网络问题。

    在执行代码时会遇到DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses import imp可以不用管他,这里是一个警告,是指imp已经被废弃,建议更换imporlib,但是我更换之后会出现一个包无法引入的情况,网上暂时没有搜索到有用的信息,所以暂时搁置了,

  • 模型组网

    1
    2
    3
    4
    5
    # 模型组网并初始化网络
    lenet = paddle.vision.models.LeNet(num_classes=10)

    # 可视化模型组网结构和参数
    paddle.summary(lenet,(1, 1, 28, 28))
  • 模型训练与评估

    • 模型训练

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      # 封装模型,便于进行后续的训练、评估和推理
      model = paddle.Model(lenet)

      # 模型训练的配置准备,准备损失函数,优化器和评价指标
      model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
      paddle.nn.CrossEntropyLoss(),
      paddle.metric.Accuracy())

      # 开始训练
      model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)

    • 模型评估

      1
      2
      # 进行模型评估
      model.evaluate(test_dataset, batch_size=64, verbose=1)

  • 模型推理

    • 模型保存

      1
      2
      # 保存模型,文件夹会自动创建
      model.save('./output/mnist')

      output
      ├── mnist.pdopt # 优化器的参数
      └── mnist.pdparams # 模型的参数

      以上代码执行后会在output目录下保存两个文件,mnist.pdopt为优化器的参数,mnist.pdparams为模型的参数。

    • 模型加载推理

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      # 加载模型
      model.load('output/mnist')

      # 从测试集中取出一张图片
      img, label = test_dataset[0]
      # 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求
      img_batch = np.expand_dims(img.astype('float32'), axis=0)

      # 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果
      out = model.predict_batch(img_batch)[0]
      pred_label = out.argmax()
      print('true label: {}, pred label: {}'.format(label[0], pred_label))
      # 可视化图片
      from matplotlib import pyplot as plt
      plt.imshow(img[0])

      这里可以看到机器识别出了这个数字为2

​ 可能最后的图片的无法弹出来,可以参考一下这篇文章

​ 其中一个解决方法只有pycharm专业版才有的。


Paddle手写数字识别
https://lijusting.top/posts/dab415f3/
作者
lijusting,
发布于
2022年10月17日
许可协议