Machine Learning(13)- Random Forest

引言

Machine Learning(13)- Random Forest

Random Forest Algorithm 是另一个非常流行的 Machine Learning 技术,主要应用于 Regression 和 Classication.

Random Forest 的名字其实是来自于上一节学习的 Decision Tree,它是基于一个数据集, 根据不同的规则划分, 一级一级创建成一颗树的形式.
这里要学的 Random Forest 的内在实现就是把一个数据集拆分成 n 个 tree, n 是可以调节的参数, 用以创建更准确的模型。

Machine Learning(13)- Random Forest

正文

引入数据集

下面以手写数字的分类为例,来学习 Random Forest

from sklearn.datasets import load_digits
digits = load_digits()

拆分测试数据和训练数据

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.2)

用 RandomForest 训练模型

// 引入 RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=40)
model.fit(X_train, y_train)

// 输出
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=40, n_jobs=None,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)   

查看模型准确度

可以通过微调参数寻找更好的准确度, 参数 n_estimators 就是代表分多少棵树来创建模型。

model.score(X_test, y_test) // 0.9805555555555555
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

请勿发布不友善或者负能量的内容。与人为善,比聪明更重要!