Machine Learning (8) - Logistic Regression (Multiclass Classification)

引言

Machine Learning (8) - Logistic Regression (Multiclass Classification)

上一节我们重点关注了 binary classification, 也就是预测结果非是则非,这一节我们来学习多种分类值的模型。

正文

Machine Learning (8) - Logistic Regression (Multiclass Classification)

我们通过识别手写数字的例子, 学习 Multiclass Classification。需要借助 The Digit Dataset 的数据集, 它是由 1797 个 8*8 的图像组成的。如下图所示, 每一个图像都是一个手写数字。也就是说, 我们通过这个数据集, 可以获得 1797 条数据用来训练模型。

Machine Learning (8) - Logistic Regression (Multiclass Classification)

与 binary classification 的不同之处在于, 这里获得的结果不是 yes or no 的二选一, 而是数字从 0-9 的 10 种分类.

第一步:引入数据, 并对其特性做简单的了解

%matplotlib inline
import matplotlib.pyplot as plt

// 引入 load_digits 数据集
from sklearn.datasets import load_digits
digits = load_digits()

// 查看这个数据集都有哪些属性
dir(digits)
// 输出
['DESCR', 'data', 'images', 'target', 'target_names']

digits.data
// 输出
array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ..., 10.,  0.,  0.],
       [ 0.,  0.,  0., ..., 16.,  9.,  0.],
       ...,
       [ 0.,  0.,  1., ...,  6.,  0.,  0.],
       [ 0.,  0.,  2., ..., 12.,  0.,  0.],
       [ 0.,  0., 10., ..., 12.,  1.,  0.]])

// 查看这个数据集的整体结构, 从输出可以看到这是一个二维数组, 一共是有 1797 条数据, 每条数据由 64 个数字组层
digits.data.shape  // 输出 (1797, 64)

// 查看二维数据的第一个, 也就是代表了第一个数字
digits.data[0]
// 输出
array([ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,
       15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,
       12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,
        0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,
       10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.])

// 再来看下 images 属性, 输出前 4 个数字看看      
for i in range(4):
    plt.matshow(digits.images[i])

Machine Learning (8) - Logistic Regression (Multiclass Classification)

Machine Learning (8) - Logistic Regression (Multiclass Classification)

// target 属性, 也就是 data 里一堆数字到底是代表数字几
digits.target[0:4] // 输出 array([0, 1, 2, 3])
digits.target_names[0:4] // 输出 array([0, 1, 2, 3])

第二步:准备训练模型

// 按照 20% 测试数据的比例, 把数据集分为训练数据和测试数据
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.2)
len(X_train) // 输出 1437
len(X_test) // 输出 360

// 训练模型
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)

// 查看模型准确度
model.score(X_test, y_test) // 输出 0.9777777777777777

以上, 就是借助于 load_digits 数据集完成了对 multiclass 模型的训练, 并且得出了精确度还比较准确. 如果想更加细致地了解误差的位置, 可以通过 confusion_matrix 类实现:

y_predicted = model.predict(X_test)
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, y_predicted)
cm
// 输出
array([[36,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0, 45,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0, 36,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0, 39,  0,  0,  0,  0,  0,  0],
       [ 0,  1,  0,  0, 27,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0, 35,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0, 36,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0, 34,  0,  1],
       [ 0,  2,  0,  0,  0,  1,  0,  0, 26,  1],
       [ 0,  0,  0,  0,  0,  0,  0,  2,  0, 38]])

为了将上面的输出可视化更强, 引入 seaborn 包, 需要在终端安装 pip3 install seaborn

import seaborn as sn
plt.figure(figsize = (10, 7))
sn.heatmap(cm, annot=True)
plt.xlabel('Predicted')
plt.ylabel('Truth')

Machine Learning (8) - Logistic Regression (Multiclass Classification)

简单解释这个图表:
x 轴是预测的值, y 轴是实际的值
以左上角的 36 为例, 表示一共有 36 个数字 0 (看 y 轴), 而预测的也都是数字 0 (x 轴), 表示对 0 的预测没有误差.
再看最下面一行的 2, 表示有两个数字 9 (y 轴) 被预测成了数字 7 (x 轴), 也就是说在数字 9 的预测上有两个错误.
这个表上所有数字加起来正好是 360, 也就是测试数据的数据量.
这就非常直观地看出我们这个数据模型的准确度表现了.

讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

请勿发布不友善或者负能量的内容。与人为善,比聪明更重要!