OpenCV + sklearnSVM 实现手写数字分割和识别

opencv,sklearnsvm · 浏览次数 : 9

小编点评

本文主要介绍了如何使用Python和OpenCV库完成手写数字的分割和识别任务,并使用支持向量机(SVM)进行分类。文章首先介绍了MNIST数据集的准备,然后详细阐述了SVM训练过程、数字分割方法以及如何展示识别结果。 1. **数据集准备**:文章首先介绍了MNIST数据集的来源和结构,以及如何使用PyTorch的 torchvision.datasets 模块读取数据集。接着,详细描述了数据预处理的过程,包括二值化、数据增强等步骤。 2. **SVM训练**:文章详细讲解了SVM的工作原理和构造函数,以及如何使用sklearn库中的SVC类进行SVM训练。讨论了SVM中的一些重要参数,如kernel、C、decision_function_shape等,并给出了相应的解释。 3. **数字分割**:文章介绍了如何使用OpenCV库进行边缘检测和轮廓提取,以分离出图像中的数字部分。通过拟合外接矩形来提取ROI,从而完成数字分割。 4. **杂项与代码**:文章还介绍了如何使用pickle模块保存和加载训练好的SVM模型,以及如何使用SummaryWriter进行模型训练过程的可视化。 总的来说,本文提供了一个完整的基于MNIST数据集的手写数字识别项目的实现过程,包括数据预处理、SVM训练、数字分割和模型保存等多个环节。

正文

这学期机器学习考核方式以大作业的形式进行考核,而且只能使用一些传统的机器学习算法。
综合再三,选择了自己比较熟悉的MNIST数据集以及OpenCV来完成手写数字的分割和识别作为大作业。

1. 数据集准备

MNIST数据集是一个手写数字的数据库,包含60000张训练图片和10000张测试图片,每张图片大小为28x28像素,每张图片都是一个
灰度图,像素取值范围在0-255之间。

这里使用pytorch的torchvision.datasets模块来读取MNIST数据集。

from torchvision import datasets
mnist_set = datasets.MNIST(root="./MNIST", train=True, download=True)

具体参数说明请自行搜索。注意若donwload=True,则torchvision会通过内置链接自动下载数据集,
但是有时会失效。因此可以自己去网络上下载并解压后排列成指定文件树,如下

MNIST
├── MNSIT
│   ├── raw
│   │   ├── t10k-images-idx3-ubyte.gz
│   │   ├── t10k-labels-idx1-ubyte.gz
│   │   ├── train-images-idx3-ubyte.gz
│   │   └── train-labels-idx1-ubyte.gz

然后使用如下语句去读取数据集

img, target = minst_set[0]

其中每个img类型为PILimage,target类型为int,代表该图片对应的数字。

但是在喂给SVM训练时需要的是[batch_size, data]大小的numpy数组,因此需要做一些预处理

   x_, y_ = list(zip(*([(np.array(img).reshape(28*28), target) for img, target in mnist_set])))

上面的语句实现了将MNIST数据集转换成numpy数组的形式,其中x_是每个成员为[1, 784]的numpy数组,y_为对应的数字所组成的列表。

2. SVM训练

支持向量机(support vector machine,SVM)是经典的机器学习算法,其通过选取两个n维支持向量(support vector)之间的n维超平面来对两类对象进行二分类。而专注于分类的SVM又称作Support Vector Classification,SVC。

求解SVM是一个很复杂的问题,但是万幸的是sklearn中有封装的很好的模块,可以很简单的直接使用

from sklearn.svm import SVC

svc = SVC(kernel='rbf', C=1)
 
svc.fit(x_, y_)

其中fit接口接受两个参数,第一个参数为训练数据[batch_size, data],第二个参数为训练标签[batch_size,1]。
SVC的构造函数如下

SVC(C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', random_state=None)

比较重要的参数有kernel,C,decison_function_shape等。

  • kernel参数指定了核函数,常用的有linear,poly,rbf,sigmoid等。
  • C为惩罚系数,C越大,对误分类的惩罚越大,模型越保守,C越小,对误分类的惩罚越小,模型越宽松,也就是较大的C在训练集上会有更高的正确率,较小的C会容许噪声的存在,泛化能力较强。
  • decision_function_shape参数指定了决策函数的形状,ovr表示one-vs-rest,ovo表示one-vs-one,具体的意思可以网络查阅

4. 数字分割

数字分割是指将图像中的数字部分分割出来,然后一个一个喂给SVM进行分类

这里就是使用opencv对拍摄的图像进行轮廓提取后拟合外接矩形,借此来提取数字部分的ROI。

这里选择进行Canny边缘检测后去进行轮廓提取,然后拟合外接矩形,因为相较于直接二值化后去提取数字部分的ROI,
边缘检测对数字与纸张的边界更加敏感,即便在光照不均匀的情况下,也能较好的提取出数字的边缘。鲁棒性强。

5. 杂项与代码

这里还有一些杂项,比如保存模型,加载模型

使用pickle模块对训练好的模型对象进行序列化保存与加载,可以将训练好的模型保存到本地,以便后续使用。

最后贴出代码

代码
import os.path
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision import datasets
from torchvision import transforms
from sklearn import svm
from sklearn import preprocessing
from sklearnex import patch_sklearn
import pickle
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import learning_curve

'''
    @brief  加载MNIST数据集并转换格式成二值图
    
    @param train: 是否为训练集
    @param data_enhance: 是否进行数据增强
    
    @return 二值图集和标签集
'''
def LoadMnistDataset(train=True, data_enhance=False):
    mnist_set = datasets.MNIST(root="./MNIST", train=train, download=True)
    x_, y_ = list(zip(*([(np.array(img), target) for img, target in mnist_set])))
    sets_raw = []
    sets_r20 = []
    sets_invr20 = []
    y = []
    y_r20 = []
    y_invr20 = []
    sets = []
    matrix_r20 = cv2.getRotationMatrix2D((14, 14), 25, 1.0)
    matrix_invr20 = cv2.getRotationMatrix2D((14, 14), -25, 1.0)
    select = 0
    for idx in range(len(x_)):
        # 对图像进行二值化以及数据增强
        _, img = cv2.threshold(x_[idx], 255, 255, cv2.THRESH_OTSU)
        sets_raw.append(np.array(img.data).reshape(784))
        y.append(y_[idx])
        if data_enhance:
            if select % 2 == 0:
                img_r20 = ~cv2.warpAffine(~img, matrix_r20, (28, 28), borderValue=(255, 255, 255))
                sets_r20.append(np.array(img_r20.data).reshape(784))
                y_r20.append(y_[idx])
            else:
                img_invr20 = ~cv2.warpAffine(~img, matrix_invr20, (28, 28), borderValue=(255, 255, 255))
                sets_invr20.append(np.array(img_invr20.data).reshape(784))
                y_invr20.append(y_[idx])
            select += 1

    # 数据增强
    sets = sets_raw + sets_r20 + sets_invr20
    sets = np.array(sets)
    print(sets.shape)
    if data_enhance:
        y = y + y_r20 + y_invr20
    return sets, y

'''
    @brief  保存SVM模型
    
    @param svc_model: SVM模型 
    @param file_path: 模型保存路径,默认为./SVC
    
    @return None
'''
def SaveSvcModel(svc_model, file_path="./SVC"):
    with open(file_path, 'wb') as fs:
        pickle.dump(svc_model, fs)

'''
     @brief  加载SVM模型
     
     @param file_path: 模型保存路径,默认为./SVC
     
     @return SVM模型
'''
def LoadSvcModel(file_path="./SVC"):
    if not os.path.exists(file_path):
        assert "Model Do Not Exist"
    with open(file_path, 'rb') as fs:
        svc_model = pickle.load(fs)
    return svc_model

'''
     @brief  训练SVM模型
     
     @param c: SVM参数C
     @param enhance: 是否进行数据增强
     
     @return acc: 在测试集上的准确率
             svc_model: SVM模型
'''
def TrainSvc(c, enhance):
    # 读取数据集,训练集及测试集
    images_train, targets_train = LoadMnistDataset(train=True, data_enhance=enhance)
    images_test, targets_test = LoadMnistDataset(train=False, data_enhance=enhance)

    # 训练
    svc_model = svm.SVC(C=c,kernel='rbf', decision_function_shape='ovr')
    svc_model.fit(images_train, targets_train)

    # 在测试集上测试准确度

    res = svc_model.predict(images_test)
    correct = (res == targets_test).sum()
    accuracy = correct / len(images_test)
    print(f"测试集上的准确率为{accuracy * 100}%")
    return svc_model

'''
     @brief  预处理比较粗的字体
     
     @param image: 输入图像
     @:param show: 是否显示预处理后的图像
     @:param thresh: 二值化阈值
     
     @return 预处理后的图像数据
'''
def PreProcessFatFont(image, show=False):
    # 白底黑字转黑底白字
    pre_ = ~image

    # 转单通道灰度
    pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)
    # 二值化
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # resize后添加黑色边框,亲测可提高识别率
    pre_ = cv2.resize(pre_, (112, 112))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    back = np.zeros((300, 300), np.uint8)
    back[29:141, 29:141] = pre_
    pre_ = back

    if show:
        cv2.imshow("show", pre_)
        cv2.waitKey(0)

    # 做一次开运算(腐蚀 + 膨胀)
    kernel = np.ones((2, 2), np.uint8)
    pre_ = cv2.erode(pre_, kernel, iterations=1)
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=1)

    # 第二次resize
    pre_ = cv2.resize(pre_, (56, 56))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 做一次开运算(腐蚀 + 膨胀)
    kernel = np.ones((2, 2), np.uint8)
    pre_ = cv2.erode(pre_, kernel, iterations=1)
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=1)

    # resize成输入规格
    pre_ = cv2.resize(pre_, (28, 28))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 转换为SVM的输入格式
    pre_ = np.array(pre_).flatten().reshape(1, -1)
    return pre_

'''
     @brief  预处理细的字体
     
     @param image: 输入图像
     @param show: 是否显示预处理后的图像
     @param thresh: 二值化阈值
     
     
     @return 预处理后的图像数据
'''
def PreProcessThinFont(image, show=False):
    # 白底黑字转黑底白字
    pre_ = ~image

    # 转灰度图
    pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)

    # 增加黑色边框
    pre_ = cv2.resize(pre_, (112, 112))
    _, pre_ = cv2.threshold(pre_,thresh=0, maxval=255, type=cv2.THRESH_OTSU)
    back = np.zeros((170, 170), dtype=np.uint8) # 这里不指明类型会导致后续矩阵强转为float64,无法使用大津法阈值
    back[29:141, 29:141] = pre_
    pre_ = back

    if show:
        cv2.imshow("show", pre_)
        cv2.waitKey(0)

    # 对细字体先膨胀一下
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=2)



    # 第二次resize
    pre_ = cv2.resize(pre_, (56, 56))

    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 做一次开运算(腐蚀 + 膨胀)
    kernel = np.ones((2, 2), np.uint8)
    pre_ = cv2.erode(pre_, kernel, iterations=1)
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=1)

    # resize成输入规格
    pre_ = cv2.resize(pre_, (28, 28))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 转换为SVM输入格式
    pre_ = np.array(pre_).flatten().reshape(1, -1)

    return pre_

'''
     @brief  在空白背景上显示提取出的roi
     
     @param res_list: roi列表
     
     @return None
'''
def ShowRoi(res_list):
    back = 255 * np.ones((1000, 1500, 3), dtype=np.uint8)
    # 图片x轴偏移量
    tlx = 0

    for roi in res_list:
        if tlx + roi.shape[1] > back.shape[1]:
            break
        # 每次在原图上加上一个roi
        back[0:roi.shape[0], tlx:tlx + roi.shape[1], :] = roi
        tlx += roi.shape[1]

    cv2.imshow("show", back)
    cv2.waitKey(0)

'''
     @brief  寻找数字轮廓并提取roi
     
     @param src: 输入图像
     @param thin: 是否为细字体
     @param thresh: 二值化阈值
     
     @return roi列表
'''
def FindNumbers(src, thin=True):
    # 拷贝
    dst = src.copy()
    paint = src.copy()
    roi = src.copy()
    dst = ~dst

    # 预处理
    paint = cv2.resize(paint, (448, 448))
    dst = cv2.resize(dst, (448, 448))

    # 记录缩放比例,后来看这一步好像没啥意义
    fx = src.shape[1] / 448
    fy = src.shape[0] / 448

    # 转单通道
    dst = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY)

    # 边缘检测后二值化,直接二值化的话由于采光不同的原因灰度直方图峰与峰之间可能会差距过大,导致二值图的分割不准确
    # 而边缘检测对像素突变更加敏感,因此采用Canny边缘检测后二值化
    cv2.Canny(dst, 200, 200, dst)

    # 对于平常笔写的字太细,膨胀一下
    if thin:
        kernel = np.ones((5, 5), np.uint8)
        dst = cv2.dilate(dst, kernel, iterations=1)

    # 寻找外围轮廓
    contours, _ = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 提取roi
    roi_list = []
    rect_list = []
    for contour in contours:
        rect = cv2.boundingRect(contour)
        if not ((rect[2] * rect[3] < 400 or rect[2] * rect[3] > 448 * 448 / 2.5) or (rect[3] < rect[2])):
            cv2.rectangle(paint, rect, (255, 0, 0), 1)
            x_min = rect[0] * fx
            x_max = (rect[0] + rect[2]) * fx
            y_min = rect[1] * fy
            y_max = (rect[1] + rect[3]) * fy
            roi_list.append(roi[int(y_min):int(y_max), int(x_min):int(x_max)].copy())
            rect_list.append(rect)
    return paint, roi_list, rect_list

'''
     @brief  以txt形式显示数据
     
     @param data: 数据集
     
     @return None   
'''
def ShowDataTxt(data):
    print("----------------------------------------------------------")
    for i in range(28):
        for j in range(28):
            print(0 if data[0][i * 28 + j] == 255 else 1, end='')
        print('\n')
    print("----------------------------------------------------------")



if __name__ == "__main__":
    # 加载
    patch_sklearn()
    model_path = "./SVC_C1_enhance.pkl"

    if os.path.exists(model_path):
        print("Model Exist, Load Form Serialization")
        model = LoadSvcModel(model_path)
    else:
        print("Model Do Not Exist, Train")

        # 训练
        model = TrainSvc(1, False)


        # 保存
        SaveSvcModel(model, model_path)

    # 测试
    paint, nums, rects = FindNumbers(cv2.imread("test_final.jpg"))
    predict_nums = []
    for img in nums:
        data = PreProcessThinFont(img, show=False)
       # ShowDataTxt(data)
        predict_nums.append(model.predict(data)[0])
    for i in range(len(predict_nums)):
        cv2.putText(paint,str(predict_nums[i]), (rects[i][0], rects[i][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    cv2.imshow("show", paint)
    cv2.waitKey(0)

给出几个识别后的效果:
image

与OpenCV + sklearnSVM 实现手写数字分割和识别相似的内容:

OpenCV + sklearnSVM 实现手写数字分割和识别

这学期机器学习考核方式以大作业的形式进行考核,而且只能使用一些传统的机器学习算法。 综合再三,选择了自己比较熟悉的MNIST数据集以及OpenCV来完成手写数字的分割和识别作为大作业。 1. 数据集准备 MNIST数据集是一个手写数字的数据库,包含60000张训练图片和10000张测试图片,每张图片

增补博客 第七篇 python 比较不同Python图形处理库或图像处理库的异同点

OpenCV、Pillow 和 scikit image OpenCV(OpenCV 是一个强大的计算机视觉库,它提供了各种图像处理和计算机视觉算法的实现,可以处理各种图像和视频数据。 异同点 跨平台性: OpenCV 支持多种操作系统,包括 Windows、Linux 和 macOS。 功能丰富:

OpenCV实战:从图像处理到深度学习的全面指南

> 本文深入浅出地探讨了OpenCV库在图像处理和深度学习中的应用。从基本概念和操作,到复杂的图像变换和深度学习模型的使用,文章以详尽的代码和解释,带领大家步入OpenCV的实战世界。 # 1. OpenCV简介 ## 什么是OpenCV? ![file](https://img2023.cnblo

OpenCV计算机视觉学习(14)——浅谈常见图像后缀(png, jpg, bmp)的区别(opencv读取语义分割mask的坑)

如果需要处理的原图及代码,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/ComputerVisionPractice 本来不想碎碎念,但是我已经在图像后缀上栽倒两次了。而且因为无意犯错,根本找不到问题。不论是在深度学习的语

Android无障碍自动化结合opencv实现支付宝能量自动收集

Android无障碍服务可以操作元素,手势模拟,实现基本的控制。opencv可以进行图像识别。两者结合在一起即可实现支付宝能量自动收集。opencv用于识别能量,无障碍服务用于模拟手势,即点击能量。 当然这两者结合不单单只能实现这些,还能做很多自动化的程序,如芭芭农场自动施肥、蚂蚁庄园等等的自动化,

Python OpenCV #1 - OpenCV介绍

一、OpenCV介绍 1.1 OpenCV-Python教程简介 OpenCV由 Gary Bradsky 于1999年在英特尔创立,第一个版本于2000年发布。 Vadim Pisarevsky 加入了Gary Bradsky,管理英特尔的俄罗斯软件OpenCV团队。2005年,OpenCV被用于

QtCreator 跨平台开发添加动态库教程(以OpenCV库举例)- Windows篇

Qt具有跨平台的特性,即Qt数据结构与算法库本身跨平台和编译脚本(.pro)跨平台。在同时具有Windows下和Linux开发的需求时,最好的建议是使用QtCreator来开发,虽然也可以使用其他的IDE配合CMake等方式,但使用QtCreator更加方便,并且操作环境完全一致。QtCreator

Building wheel for opencv-python (pyproject.toml) ,安装命令增加 --verbose 参数

Mac 安装 paddlehub 出现 Building wheels for collected packages: opencv-python, ffmpy, jieba, seqeval, future Building wheel for opencv-python (pyproject.t

Python从零到壹丨带你了解图像直方图理论知识和绘制实现

摘要:本文将从OpenCV和Matplotlib两个方面介绍如何绘制直方图,这将为图像处理像素对比提供有效支撑。 本文分享自华为云社区《[Python从零到壹] 五十.图像增强及运算篇之图像直方图理论知识和绘制实现》,作者:eastmount。 一.图像直方图理论知识 灰度直方图是灰度级的函数,描述

python进阶:带你学习实时目标跟踪

摘要:本程序主要实现了python的opencv人工智能视觉模块的目标跟踪功能。 本文分享自华为云社区《python进阶——人工智能实时目标跟踪,这一篇就够用了!》,作者:lqj_本人 。 前言 本程序主要实现了python的opencv人工智能视觉模块的目标跟踪功能。 项目介绍 区域性锁定目标实时