卷积神经网络的前世今生

来自:CSDN 2020-08-29

来源 |《Python人工智能开发从入门到精通》

作者 |杨柳、郭坦、鲁银芝

责编 | 晋兆雨

深度学习在技术与应用上的突破引发了第三次人工智能浪潮,获得了空前成功。在前述章节的基础上,本章将主要介绍训练卷积神经网络和深度神经网络的重要方法与技巧,深度神经网络的迁移学习策略,以及如何训练深度神经网络以解决实际问题等内容。作为人工智能的核心研究内容,以卷积神经网络(ConvolutionalNeural Networks, CNNs)为代表的深度学习技术已在计算机视觉应用,如智能监控、智慧医疗及机器人自动驾驶等领域取得突破性进展,而这些应用的成功落地很大程度上依赖于视觉识别模块。

结合前文内容,本章将详细介绍如何构建并利用CNNs 这一功能强大的深度学习模型解决实际的图像识别问题。

受20世纪中期兴起的神经科学及脑科学研究的启发,通过模拟生物神经元接收和处理信息的基本特性,研究人员提出并设计了人工神经元。作为计算机科学、生物学和数学的交叉融合,卷积神经网络已经发展成为计算机视觉领域中最具影响力和有效的基础技术。

早在20 世纪60 年代,生物学家Hubel 和Wiesel 通过研究猫的视觉皮层,发现每个视觉神经元都只对一个小区域范围内的视觉图像产生响应,即感受野(Receptive Field)。初级视觉皮层中的神经元能够响应视觉环境中特定的简单特征,除此之外,Hubel 和Wiesel 通过研究发现了简单和复杂两种不同类型的细胞,其中简单细胞只在特定的空间位置对它们偏好的方向产生最强烈响应,而复杂细胞具有更大的空间不变性。

根据这些实验和分析,他们得出结论:复杂细胞通过在来自多个简单细胞(每个都有一个不同的偏好位置)的输入进行池化而实现这种不变性,这两个特性,即对特定特征的选择性和通过前馈连接增大空间不变性,构成了CNN人工视觉系统的生物及神经学基础。

发展至80年代,日本科学家Kunihiko Fukushima 通过研究并融合有关生物视觉的相关领域知识,提出了Neocognitron 神经认知机的概念,该神经认知机由S 细胞和C 细胞构成,可通过无监督的方式学习识别简单的图像。20 世纪90 年代,Yann LeCun 等人发表论文,确立了CNN 影响至今的经典网络结构,后来经过对网络结构的不断完善与改进,得到一种多层的人工神经网络,命名为LeNet-5,在手写数字识别任务上取得良好效果。和其他神经网络一样,LeNet-5能够使用反向传播算法(Back Propagation)训练。

LeNet-5 网络虽然较小,但它含有诸多神经网络学习的关键模块,具体包括卷积层、池化层及全连接层,这些基本模块构成当前深度神经网络模型的基础,下文将对LeNet-5 的结构及工作原理进行深入分析。同时,借助实例加深读者对卷积神经网络各个模块功能的理解。

卷积神经网络与LeNet-5

LeNet-5 出自Yann LeCun 教授于1998 发表的论文Gradient-Based Learning Applied to DocumentRecognition 中,LeNet-5 模型共有7 层,如图11-1 所示为LeNet-5 的基本网络架构。

LeNet-5 的基本网络架构

该模型除了输入层之外,每层都包含可训练参数,每个网络层产生多个特征图,每个特征图可通过一种卷积滤波器提取输入数据一种类型的特征。各个网络层的功能与参数情况介绍如下。

1. 输入层

首先是输入数据网络层,上例中输入图像尺寸统一归一化为32×32×1,其中1 表示输入图像为单通道的灰度图,一般不将该层作为LeNet-5 网络的基本构成,即不将输入层视为网络层次结构之一。

2. C1 层

C 取自Convolutional 的首字母,指卷积。读者可能对卷积的概念并不陌生,对数字图像做卷积运算,本质上是通过卷积核(卷积模板)在图像上滑动,将图像上的像素灰度值与对应卷积核上的数值相乘,然后将所有相乘后的值相加作为卷积核中间像素对应像素的灰度值,以此方式遍历完成对整张图像像素的卷积计算。如图11-2显示了图像卷积计算过程中一次相乘后相加的运算过程,该卷积核大小为3×3,卷积核内共有9 个数值,数值个数即为图像像素值与卷积核上数值相乘次数,运算结果-4 代替了原图像中对应位置处的值。按此方式,沿着图片以步长为1 滑动,每次滑动1个像素都进行一次相乘再相加的操作,即可得到最终的输出结果。

卷积计算过程

图像卷积计算中,卷积核的设计十分重要,一般需遵循如下基本规则。

卷积核大小一般是奇数,奇数大小的卷积核使得卷积核关于中间像素点中心对称,因此卷积核尺寸一般是3×3、5×5或7×7。卷积核有中心,相应地就有半径的概念,如7×7 大小的卷积核,其半径为3。

卷积核所有的元素之和一般应等于1,这是为了保持图像卷积计算过程中像素能量(亮度)的守恒。若滤波器矩阵所有元素之和大于1,那么滤波后的图像就会比原图像更亮;反之,若小于1,那么得到的图像将会变暗。

滤波后可能会出现负数或大于255 的数值。对这种情况,通常将它们直接截断到0~255之间即可。而对于负数,也可以取绝对值。

经卷积计算所得输出通常被称为"响应",如果是边缘检测算子,那么响应为图像边缘,能够检测到特定的图像边缘。在LeNet-5网络中得到的响应是特征图(Feature Map),计算结果为输入图像的特征表达,卷积核的参数权重可以通过优化算法在监督信息的指导下自适应地学习得到。LeNet-5 网络中C1 层输入图像尺寸为32×32×1,卷积核大小为5×5,一共包括6 种大小为5×5 的卷积核,卷积核滑动一行之后,得到的结果的边长变为32-5+1,提取的特征映射大小是28×28,即(32-5+1)=28。6种不同的卷积核,可以从不同的角度提取图像不同特性的特征。

神经元数量为28×28×6,则可训练参数为(5×5+1)×6,即每个滤波器含5×5=25个单元权值参数和1个偏置参数,一共6 个滤波器,因此总的连接数为(5×5+1)×6×28×28=122 304。针对122 304 个连接,通过权值共享策略,只需学习156 个参数。

3. S2 层

S 指的是Subsamples,该网络层完成下采样操作,得到对应的特征图。下采样的原则是在减少数据量的同时尽可能保留有用的信息。与普通插值下采样的方式不同,该层实际采用的是一种被称为池化(Pooling)的方法。具体是将一幅图像分割成若干块,每个图像块的输出是该图像块原有像素的统计结果。

图像下采样池化方法有很多,如Mean-pooling( 均值采样)、Max-pooling( 最大值采样)、Overlapping ( 重叠采样)、L2-pooling( 均方采样)、Local Contrast Normalization( 局部对比归一化)、Stochastic-pooling( 随机采样) 和Def-pooling( 形变约束采样) 等,其中最经典的是最大池化,也是

最常用的,下面简要介绍最大池化的实现原理。

为直观起见,假设有如图11-3(a)中大小为4×4 的图像,图像中每个像素点的值是上面各个格子中的数值。现在对这张4×4 大小的图像进行池化操作,池化的大小为(2,2),步长为2。采用最大池化操作,首先对图像进行分块,每个图像块大小为2×2,然后按照图11-3(b)中方式统计每个图像块的最大值,作为下采样后图像的像素值,得到图11-3(c)中结果,该过程即为最大池化。

除此之外,还有其他池化方法,如均值池化,具体是对每个块求取平均值作为下采样的新像素值。上述例子未涉及重叠采样,即每个图像块之间没有相互重叠的部分,而步长为2 时,图像分块不重叠。

最大池化操作示意图

LeNet-5 网络中的S2 层的输入是上一层的输出,共有6 个特征映射,每个特征映射的尺寸为28×28,使用2×2 大小的核进行池化操作,得到S2,即6 个14×14 大小的特征映射(28/2=14)。换言之,S2 中的池化层是对C1 中的2×2 区域内的像素求和乘以一个权值系数再加上一个偏置,然后将这个结果再做一次映射。与卷积层连接数的计算方法一样,连接数=参数个数×特征映射大小,即(2×2+1)×6×14×14=5880。

4. C3 层

C3 层同样是卷积层,输入为S2 中所有6 个或若干个特征图的组合。具体地,该层卷积核大小为5×5,一共有6种卷积核,输出特征图大小为10×10,即(14-5+1)=10。需要注意的是,C3 中每个特征图是连接到S2 中的所有6 个或若干个特征图的,即该层的特征图是上一层提取到的特征图的不同组合。如图11-4 所示,LeCun 在原论文中给出的一种组合连接方式。

LeNet-5 网络C3 层特征映射组合方式

图11-4 中共有6 行16 列,横轴代表C3 特征映射索引,纵轴代表S2 特征图索引。每列的X表示C3 中的每个特征映射与S2 中的特征图的连接情况,可以看到C3 的前6 个特征图,对应上图第1 个红框的6 列,以S2 中3 个相邻的特征图子集为输入,紧接着6 个特征图(对应上图第2 个红框的6 列)以S2 中4 个相邻特征图子集为输入。

然后,接下来的3 个特征图(对应上图第3 个红框的3 列)以S2 中不相邻的4 个特征图子集为输入,C3 中的最后一个特征图对应上图第4 个红框的1 列将S2 中所有特征图为输入。这里得到的每一个特征图为多核多通道卷积,将每一列称为一个卷积核,它由若干个卷积模板构成,因为是多个卷积核模板卷积的结果得到一个特征图,仍然认为是一个卷积核,所以每列只有一个偏置参数。之所以采取这种组合方式,LeCun 主要是基于以下两点考虑:减少参数;采用不对称的组合连接方式有利于提取多种组合特征。

5. S4 层

S4 层为下采样层,即池化层,窗口大小为2×2,包括16 个特征图,C3 层的16 个10×10 的特征图分别进行以2×2 为单位的池化得到16 个5×5 的特征图,步长为2,即本网络层的输出张量大小为5×5×16,一共有5×5×5×16=2000 个连接,连接的方式与S2 层类似。6. C5 层

C5 层是一个卷积层,输入为S4 层的全部16 个特征图,该层卷积原理与普通卷积层一致,只是因为恰巧卷积核大小与输入特征图尺寸一样,因此得到一维,即1×1(5-5+1)的输出,卷积核种类为120,得到120 维的卷积结果,每个都与上一层的16 个特征图相连,因此一共有(5×5×16+1)×120=48 120 个可训练参数。7. F6 层

F6 层是全连接层,采用全连接的方式与C5 层连接,由对C5 层的输入乘以权重加上偏置,结果通过激活Sigmoid 函数输出。F6 层有84 个节点,对应于一个7×12 的比特图,-1 表示白色,1 表示黑色,这样每个符号的比特图的黑白色对应一个编码,F6 层的训练参数/ 连接数为(120+1)×84=10 164。

8. Output 输出层

Output 输出层同样是全连接层,共有10 个节点,分别代表数字0~9,且如果节点i 的值为0,则网络识别的结果是数字i。采用的是径向基函数(RBF)的网络连接方式。假设x 是上一层的输入,y 是RBF 的输出,则RBF 输出的计算方式如式(11-1)所示。

式中wij 的值由i 的比特图编码确定,i 取值为0~9,j 取值为0~(7*12-1)。RBF 输出的值越接近于0,表示当前网络的输入越接近于i,即越接近于i 的ASCII 编码,该层一共包含84×10=840 个可学习参数。卷积神经网络在本质上是在不需要获取输入和输出之间精确的数学表达的情况下,学习从输入数据到目标输出的复杂映射,卷积神经网络的优势在于能够很好地利用图像的二维结构信息,LeNet-5 在银行支票手写体字符识别问题上得到成功应用。

不考虑输入层,LeNet-5 是一个7 层的网络,卷积层的参数较少,这得益于卷积层的若干重要特性,即局部连接和共享权重。现在常用的LeNet-5 结构和Yann LeCun 教授在1998 年论文中提出的结构在某些细节上存在一定的区别,如激活函数的使用,现在一般使用ReLU 作为激活函数,而输出层一般选用Softmax。CNN 能够提取原始图像的有效表征,这赋予CNN 经过较少的预处理,即可从原始像素中学习和识别视觉规律的能力。然而,由于LeNet-5 提出伊始,缺乏大规模的训练数据,计算机的计算能力也难以满足要求,CNN 的网络架构在不同文献中的描述略有差异。

不过,CNN 的基本组成单元和模块相对一致,可以像搭积木一样将不同功能的网络层组合起来,从而实现规模更大、深度更深的网络。因此,从某种意义上说,CNN 或深度学习中的网络层本质上是能够进行信息处理的积木单元。LeNet-5 对于更复杂问题的处理效果并不理想,但通过对LeNet-5 的网络结构的分析与研究,可以直观地了解卷积神经网络的构建方法,能够为分析和构建更复杂、更深层的卷积神经网络打下坚实的基础。

总结卷积神经网络的成功经验,主要在于局部连接(LocalConnection)、权值共享(Weight Sharing)和池化层(Pooling)中的降采样(Down-Sampling)。

(1)卷积层(Convolutions Layer)。卷积层由很多的卷积核(Convolutional Kernel)组成,卷积核用来计算不同的特征图,卷积层卷积神经网络的核心。在图像识别里用到的卷积是二维卷积,具体是二维滤波器滑动到二维图像上所有位置,并在每个位置上与该像素点及其领域像素点做内积。卷积操作被广泛应用于图像处理领域,不同类型的卷积核可以提取图像不同类型的特征,例如,边缘、角点等特征。在深层卷积神经网络中,通过卷积操作可以提取出图像低级简单到抽象复杂的特征,学习输入数据具有较强普适性的特征表达。除此之外,激活函数能够为CNN 卷积神经网络引入非线性,增强网络的复杂建模能力,常用的非线性激活函数有Sigmoid、Tanh 和ReLu 等,前两者常见于全连接层,后者ReLu 则多用于卷积层。

(2)池化层(Pooling Layer)。池化是非线性下采样的一种形式,主要作用是通过减少网络的参数来减小计算量,同时池化层能降低卷积层输出的特征向量,通常在卷积层的后面会加上一个池化层,通过卷积层与池化层交替使用可以获得更复杂的高层抽象特征,并且能够在一定程度上避免和缓解过拟合现象。常用的池化操作包括最大池化、平均池化等,其中最大池化是用不重叠的矩形框将输入层分成不同的区域,对于每个矩形框内的数值取最大值作为统计输出。

(3)全连接层(Full Connected Layer)。如果说卷积层、池化层和激活函数映射等操作是将原始数据映射到隐层特征空间的话,那么全连接层则起到将学到的"分布式特征表示"映射到样本标记空间的作用,将多层的特征表达拉直成一个一维的向量,实现神经网络的高层抽象推理能力,在整个卷积神经网络中起到"分类器"的作用。

(4)局部连接(Local Connection)。局部连接指的是每个神经元仅与输入神经元的一块区域相连,该局部区域也被称为感受野(Receptive Field)。局部连接的思想可追溯至生物学里面的视觉系统结构,即视觉皮层的神经元实质上是局部接收信息的。在图像卷积操作中,神经元在空间维度上是局部连接的,但在深度上是全部连接的。对于二维图像本身而言,局部像素关联较强,这种局部连接保证了学习后的过滤器能够对于局部的输入特征有最强的响应。

(5)权重共享(Weight Sharing)。实际中,图像的底层边缘特征与特征在图中的具体位置无关,即特征可能出现在图像的任意位置,权重共享正是利用这一特点,具体是指卷积核内权重参数被整张图共享,而不会因图像内位置的不同而改变,可在图像中的不同位置学习到同样的特征,权重共享可以在很大程度上减少参数数量。

LeNet-5 的TensorFlow 实现

前文介绍了LeNet-5 的基本网络结构,以及各个网络功能层的特点与作用,本节将利用TensorFlow 具体实现这一网络。首先需要说明以下几点。

(1)LeNet-5 主要采用Tanh 和Sigmoid 作为非线性激活函数,但相对这两者采用ReLu 激活函数的卷积神经网络更加有效。

(2)LeNet-5 采用平均池化作为下采样操作,但是目前最大池化操作应用更为广泛。

(3)LeNet-5 网络最后一层采用Gaussian 连接层,用于输出0~9 这10 个类别中的一类,但是目前分类器操作已经被Softmax 层取代。

第1 步:建立config.py 文件,可以将超参数设置在config.py 中,方便后期对模型进行调整。代码实现与说明如程序清单11-1 所示。

程序清单11-1 config.py 文件建立及超参数设置

1."""

2.设置模型的超参数

5.KEEP_PROB:网络随机失活的概率

6.LEARNING_RATE:学习的速率,即梯度下降的速率

7.BATCH_SIZE:一次训练所选取的样本数

8.PARAMETER_FILE:模型参数保存的路径

9.MAX_ITER:最大迭代次数

10."""

11.

12.KEEP_PROB=0.5

13.LEARNING_RATE=1e-5

14.BATCH_SIZE=50

15.PARAMETER_FILE="checkpoint/variable.ckpt"

16.MAX_ITER=50000

第2 步:构建LeNet 模型的LeNet.py 文件,建立一个名为Lenet 的类,类中实现模型的初始化与构建,代码实现与说明如程序清单11-2 所示。

程序清单11-2 构建LeNet 模型与LeNet.py 文件

1.importtensorflowastf

2.importtensorflow.contrib.slimasslim

3.importconfigascfg

5.classLenet:

6.def__init__(self):

7."""

8.初始化LeNet网络

9."""

10.#设置网络输入的图片为二维张量,数据的类型为float32,行数不固定,列固定为784

11.self.raw_input_image=tf.placeholder(tf.float32,[None,784])

12.

13.#改变网络输入张量的形状为四维,-1表示数值不固定

14.self.input_images=tf.reshape(self.raw_input_image,[-1,28,28,1])

15.

16.#设置网络输入标签为二维张量,数据类型为float,行数不固定,列固定为10

17.self.raw_input_label=tf.placeholder("float",[None,10])

18.

19.#改变标签的数据类型为int32

20.self.input_labels=tf.cast(self.raw_input_label,tf.int32)

21.

22.#设置网络的随机失活概率

23.self.dropout=cfg.KEEP_PROB

24.

25.#构建两个网络

26.#train_digits为训练网络,开启dropout

27.#pred_digits为预测网络,关闭dropout

28.withtf.variable_scope("Lenet")asscope:

29.self.train_digits=self.construct_net(True)

30.scope.reuse_variables()

31.self.pred_digits=self.construct_net(False)

32.

33.#获取网络的预测数值

34.self.prediction=tf.argmax(self.pred_digits,1)

35.

36.#获取网络的预测数值与标签的匹配程度

37.self.correct_prediction=tf.equal(tf.argmax(self.pred_digits,1),tf.argmax

(self.input_labels,1))

38.

39.#将匹配程度转换为float类型,表示为精度

40.self.train_accuracy=tf.reduce_mean(tf.cast(self.correct_prediction,"float"))

41.

42.#计算train_digits与labels之间的系数softmax交叉熵,定义为loss

43.self.loss=slim.losses.softmax_cross_entropy(self.train_digits,self.

input_labels)

44.

45.#设置学习速率

46.self.lr=cfg.LEARNING_RATE

47.self.train_op=tf.train.AdamOptimizer(self.lr).minimize(self.loss)

48.

49.

50.defconstruct_net(self,is_trained=True):

51."""

52.接收is_trained参数判断是否开启dropout

53.用slim构建LeNet模型

54.第一、三、五层为卷积层、第二、四层为池化层

55.接下来对第五层扁平化,再接入全连接

56.接着进行随机失活防止过拟合,再次接入全连接层

57.最后返回构建的网络

58."""

59.withslim.arg_scope([slim.conv2d],padding='VALID',

60.weights_initializer=tf.truncated_normal_initializer(stddev=0.01),

61.weights_regularizer=slim.l2_regularizer(0.0005)):

62.net=slim.conv2d(self.input_images,6,[5,5],1,padding='SAME',scope='conv1')

63.net=slim.max_pool2d(net,[2,2],scope='pool2')

64.net=slim.conv2d(net,16,[5,5],1,scope='conv3')

65.net=slim.max_pool2d(net,[2,2],scope='pool4')

66.net=slim.conv2d(net,120,[5,5],1,scope='conv5')

67.net=slim.flatten(net,scope='flat6')

68.net=slim.fully_connected(net,84,scope='fc7')

69.net=slim.dropout(net,self.dropout,is_training=is_trained,scope=

'dropout8')

70.digits=slim.fully_connected(net,10,scope='fc9')

71.returndigits

本模型是对著名的手写字体MNIST数据集进行训练,可以在网站//yann.lecun.com/exdb/mnist/ 上很方便地直接下载数据,得到如图11-5 的MNIST 数据集。

MNIST 数据集列表

第3 步:建立模型训练文件Train.py,主要实现数据读取、模型训练等功能,代码实现与说明如程序清单11-3 所示。

程序清单11-3 模型训练文件Train.py 的建立

1.importtensorflow.examples.tutorials.mnist.input_dataasinput_data

2.importtensorflowastf

3.importconfigascfg

4.importos

5.importlenet

6.fromlenetimportLenet

9.defmain():

10.#从指定路径加载训练数据

11.mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

12.

13.#开启TensorFlow会话

14.sess=tf.Session()

15.

16.#设置超参数

17.batch_size=cfg.BATCH_SIZE

18.parameter_path=cfg.PARAMETER_FILE

19.lenet=Lenet()

20.max_iter=cfg.MAX_ITER

21.

22.#加载已保存的模型参数文件,如果不存在则调用初始化函数生成初始网络

23.saver=tf.train.Saver()

24.ifos.path.exists(parameter_path):

25.saver.restore(parameter_path)

26.else:

27.sess.run(tf.initialize_all_variables())

28.

29.#迭代训练max_iter次,每次抽取50个样本进行训练

30.#每100次打印出当前数据的精度

31.#训练完成后保存模型参数

32.foriinrange(max_iter):

33.batch=mnist.train.next_batch(50)

34.ifi%100==0:

35.train_accuracy=sess.run(lenet.train_accuracy,feed_dict={

36.lenet.raw_input_image:batch[0],lenet.raw_input_label:batch[1]

37.})

38.print("step%d,trainingaccuracy%g"%(i,train_accuracy))

39.sess.run(lenet.train_op,feed_dict={lenet.raw_input_image:batch[0],lenet.

raw_input_label:batch[1]})

40.save_path=saver.save(sess,parameter_path)

41.

42.if__name__=='__main__':

43.main(

第4 步:在上述完成步骤的基础上,运行Train.py。如图11-6 所示,可以看到随着不断迭代优化,模型精度在逐步提高。

模型迭代优化过程

第5 步:建立测试文件Inference.py。具体地,建立一个Inference 类完成对图片的识别,成员函数predict 接收图片作为参数,返回预测值,代码实现与说明如程序清单11-4 所示。

程序清单11-4 测试文件Inference.py 的建立

1.importtensorflowastf

2.fromPILimportImage,ImageOps

3.importnumpyasnp

4.fromlenetimportLenet

5.importconfigascfg

7.classinference:

8.def__init__(self):

9."""

10.构建Lenet网络,设置模型参数文件路径,开启TensorFlow会话

11."""

12.self.lenet=Lenet()

13.self.sess=tf.Session()

14.self.parameter_path=cfg.PARAMETER_FILE

15.self.saver=tf.train.Saver()

16.

17.defpredict(self,image):

18."""

19.接收要测试的图片作为参数,返回预测值

20."""

21.#将图片转换为合适的大小进行输入

22.img=image.convert('L')

23.img=img.resize([28,28],Image.ANTIALIAS)

24.image_input=np.array(img,dtype="float32")/255

25.image_input=np.reshape(image_input,[-1,784])

26.

27.#读取模型参数并对图片进行预测,返回预测值

28.self.saver.restore(self.sess,self.parameter_path)

29.predition=self.sess.run(self.lenet.prediction,feed_dict={self.lenet.raw_

input_image:image_input})

30.returnpredition

第6 步:为了方便地实现对手写数字的识别,可以使用python 的tkinter方便地绘制一个UI 进行识别,具体实现代码如程序清单11-5 所示。

程序清单11-5 利用tkinter 建立UI

1.importtkinter

2.fromPILimportImage,ImageDraw

3.fromInferenceimportinference

5.classMyCanvas:

6."""

7.设置一个256*256大小的容器进行手写界面的绘制

8.背景色设置为黑色,绘制轨迹设置为白色

9."""

10.def__init__(self,root):

11.self.root=root

12.self.canvas=tkinter.Canvas(root,width=256,height=256,bg='black')

13.self.canvas.pack()

14.self.image1=Image.new("RGB",(256,256),"black")

15.self.draw=ImageDraw.Draw(self.image1)

16.self.canvas.bind('<B1-Motion>',self.Draw)

17.

18.#绘制轨迹

19.defDraw(self,event):

20.self.canvas.create_oval(event.x,event.y,event.x,event.y,outline="white",

width=20)

21.self.draw.ellipse((event.x-10,event.y-10,event.x+10,event.y+10),fill=(255,

255,255))

22.

23.

24.defmain():

25.#建立一个tkinter对象,设置大小为380*300

26.root=tkinter.Tk()

27.root.geometry('380x300')

28.#创建一个256*256的框架容纳手写的容器,位于tkinter对象的左边,填充y方向

29.frame=tkinter.Frame(root,width=256,height=256)

30.frame.pack_propagate(0)

31.frame.pack(side="left",fill='y')

32.#将frame导入canvas容器

33.canvas1=MyCanvas(frame)

34.#创建一个图像识别的实例

35.infer=inference()

36.

37.#定义识别按钮触发函数

38.#按下的时候将cavas导出为图片,放入infer中进行图像识别,并将结果显示在label2中

39.definference_click():

40.img=canvas1.image1

41.result=infer.predict(img)

42.result=int(result)

43.label2["text"]=str(result)

44.

45.#定义清除按钮的触发函数

46.#按下的时候将canvas情况并重新绘制背景,并将label设置为空

47.defclear_click():

48.canvas1.canvas.delete("all")

49.canvas1.image1=Image.new("RGB",(256,256),"black")

50.canvas1.draw=ImageDraw.Draw(canvas1.image1)

51.label2["text"]=""

52.

53.#定义识别按钮的样式

54.botton_Inference=tkinter.Button(root,

55.text="检测",

56.width=14,

57.height=2,

58.command=inference_click

59.)

60.#定义清除按钮的样式

61.botton_Clear=tkinter.Button(root,

62.text="清屏",

63.width=14,

64.height=2,

65.command=clear_click

66.)

67.#绑定识别按钮到tkinter中,设置位置为顶层

68.botton_Inference.pack(side="top")

69.

70.#绑定清除按钮到tkinter中

71.botton_Clear.pack(side="top")

72.

73.#定义label1

74.label1=tkinter.Label(root,justify="center",text="检测结果为:")

75.label1.pack(side="top")

76.

77.#定义label2

78.label2=tkinter.Label(root,justify="center")

79.

80.#设置字体样式与大小

81.label2["font"]=("Arial,48")

82.label2.pack(side="top")

83.root.mainloop()

84.

85.if__name__=='__main__':

86.main()

第7 步:运行代码并进行如下的几组测试,测试结果如图11-7 所示。

手写数字测试示例

更多精彩推荐

"Talk is cheap, show me the code"你一行代码有多少漏洞?

融资 2000 万美元后,他竟将核心代码全开源,这……能行吗?

打破定制化语音技术落地怪圈?从讲一口标准英音的语音助手说起

赠书 | 人工智能识万物:卷积神经网络的前世今生

MySQL 索引分析除了 EXPLAIN 还有什么方法?

医疗数字化:区块链或成最强辅助

点分享

点点赞

点在看