[sklearn] Decision Trees

Learning how to use decision trees in sklearn.

Principle

A decision tree is a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value (or class) of a target variable by learning simple decision rules inferred from the data features.

There are four main decision tree algorithms: ID3, C4.5, C5.0, and CART. sklearn implements CART (Classification and Regression Trees).

Iris Classification

Use a decision tree to classify the iris dataset. The workflow is as follows:

  1. Load the data
  2. Create the decision tree
  3. Predict results
  4. Visualize the decision tree
  5. Find the best tree depth

The required modules and functions to import are as follows:

import numpy as np
import pandas as pd
import graphviz
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz

Loading the Data

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris


def load_data():
    data = load_iris()

    # Wrap the feature matrix in a DataFrame and attach the labels
    df = pd.DataFrame(data.data, columns=data.feature_names)
    df['target'] = data.target

    # Split into training and test sets (default 75% / 25%)
    X_train, X_test, Y_train, Y_test = train_test_split(df[data.feature_names], df['target'], random_state=0)
    return X_train, X_test, Y_train, Y_test

Creating the Decision Tree

def create_classifier(X_train, Y_train, max_depth=3):
    clf = DecisionTreeClassifier(max_depth=max_depth, random_state=0)
    clf.fit(X_train, Y_train)

    return clf

Predicting Results

def predict(clf, X_test, Y_test):
    # Predict a single sample
    print(clf.predict(X_test.iloc[0].values.reshape(1, -1)))
    # Compute its class probabilities
    print(clf.predict_proba(X_test.iloc[0].values.reshape(1, -1)))

    # Predict multiple samples
    print(clf.predict(X_test[0:10]))

    # Compute the classification accuracy
    score = clf.score(X_test, Y_test)
    print(score)


if __name__ == '__main__':
    X_train, X_test, Y_train, Y_test = load_data()
    clf = create_classifier(X_train, Y_train)
    predict(clf, X_test, Y_test)
############################## Output
[2]
[[0. 0. 1.]]
[2 1 0 2 0 2 0 1 1 1]
0.9736842105263158

Visualizing the Decision Tree

def create_graph(clf):
    dot_data = export_graphviz(clf, out_file=None)
    graph = graphviz.Source(dot_data)
    graph.render("iris")


if __name__ == '__main__':
    X_train, X_test, Y_train, Y_test = load_data()
    clf = create_classifier(X_train, Y_train)
    # predict(clf, X_test, Y_test)
    create_graph(clf)

This generates two files: iris and iris.pdf. The iris file describes the tree structure in text (DOT) form, while iris.pdf renders the tree graphically:

digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 0.8\ngini = 0.665\nsamples = 112\nvalue = [37, 34, 41]"] ;
1 [label="gini = 0.0\nsamples = 37\nvalue = [37, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[2] <= 4.95\ngini = 0.496\nsamples = 75\nvalue = [0, 34, 41]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="X[3] <= 1.65\ngini = 0.153\nsamples = 36\nvalue = [0, 33, 3]"] ;
2 -> 3 ;
4 [label="gini = 0.0\nsamples = 32\nvalue = [0, 32, 0]"] ;
3 -> 4 ;
5 [label="gini = 0.375\nsamples = 4\nvalue = [0, 1, 3]"] ;
3 -> 5 ;
6 [label="X[2] <= 5.05\ngini = 0.05\nsamples = 39\nvalue = [0, 1, 38]"] ;
2 -> 6 ;
7 [label="gini = 0.375\nsamples = 4\nvalue = [0, 1, 3]"] ;
6 -> 7 ;
8 [label="gini = 0.0\nsamples = 35\nvalue = [0, 0, 35]"] ;
6 -> 8 ;
}
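
The imports above already include plot_tree, which draws the tree with matplotlib and avoids the graphviz dependency. A minimal sketch (the figure size and output filename are my own choices):

import matplotlib.pyplot as plt


def create_plot(clf, feature_names):
    # Draw the fitted tree and save it as a PNG image
    plt.figure(figsize=(12, 8))
    plot_tree(clf, feature_names=feature_names, filled=True)
    plt.savefig("iris_tree.png")


if __name__ == '__main__':
    X_train, X_test, Y_train, Y_test = load_data()
    clf = create_classifier(X_train, Y_train)
    create_plot(clf, list(X_train.columns))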

Finding the Best Tree Depth

A tree depth of 3 was used above by default. We can search for the best depth by iterating over candidate values:

def get_best_depth(X_train, X_test, Y_train, Y_test, max_depth=6):
    # List of values to try for max_depth:
    max_depth_range = list(range(1, max_depth))
    # List to store the accuracy for each value of max_depth:
    accuracy = []
    for depth in max_depth_range:
        clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
        clf.fit(X_train, Y_train)
        score = clf.score(X_test, Y_test)
        accuracy.append(score)
    print(accuracy)


if __name__ == '__main__':
    X_train, X_test, Y_train, Y_test = load_data()
    get_best_depth(X_train, X_test, Y_train, Y_test)
############################################# Output
[0.5789473684210527, 0.8947368421052632, 0.9736842105263158, 0.9736842105263158, 0.9736842105263158]

The results show that max_depth=3 already achieves the best accuracy.
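
Note that selecting the depth by test-set accuracy reuses the test set for model selection. A common alternative, sketched below as my own addition, is to score each candidate depth with cross-validation on the training set:

from sklearn.model_selection import cross_val_score


def get_best_depth_cv(X_train, Y_train, max_depth=6):
    # Average 5-fold cross-validation accuracy for each candidate depth
    for depth in range(1, max_depth):
        clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
        scores = cross_val_score(clf, X_train, Y_train, cv=5)
        print(depth, scores.mean())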

Data Preprocessing

  1. Unlike neural networks, decision trees do not require data normalization, because each attribute is handled independently (the Gini impurity or entropy is computed separately for each attribute)
  2. Decision trees can in principle handle both numerical and categorical data, but the sklearn implementation only supports numerical data (and does not support missing values)

Requires little data preparation. Other techniques often require data normalisation, dummy variables need to be created and blank values to be removed. Note however that this module does not support missing values. scikit-learn uses an optimised version of the CART algorithm; however, scikit-learn implementation does not support categorical variables for now.
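
Because of this, categorical features have to be encoded as numbers before training. A minimal sketch using pandas.get_dummies on hypothetical toy data:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data with one categorical column
df = pd.DataFrame({
    'size': [1.2, 3.4, 2.2, 4.1],
    'color': ['red', 'blue', 'red', 'green'],
    'target': [0, 1, 0, 1],
})

# One-hot encode the categorical column into numeric dummy columns
X = pd.get_dummies(df[['size', 'color']], columns=['color'])
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X, df['target'])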

Custom Data

Following the pattern of load_data, prepare three attributes (a sketch follows the list):

  1. data.data: the data array (numpy.ndarray) of shape (150, 4) (150 samples, each with 4 features)
  2. data.feature_names: the list of feature names. For iris: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
  3. data.target: the class array (numpy.ndarray) of shape (150,), where each value is the class label of the corresponding sample
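
A minimal sketch of a custom dataset with the same three attributes (the values are made up for illustration), using sklearn.utils.Bunch, which is also what load_iris returns:

import numpy as np
from sklearn.utils import Bunch


def load_custom_data():
    # 4 hypothetical samples with 2 features each
    X = np.array([[1.0, 2.0],
                  [1.5, 1.8],
                  [5.0, 8.0],
                  [6.0, 9.0]])
    return Bunch(data=X,
                 feature_names=['feature a', 'feature b'],
                 target=np.array([0, 0, 1, 1]))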

Regression

Decision trees support both classification and regression...
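
A minimal regression sketch on toy data of my own (not part of the iris workflow above), using DecisionTreeRegressor:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy 1-D regression problem: fit a sine curve
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()

reg = DecisionTreeRegressor(max_depth=3, random_state=0)
reg.fit(X, y)
print(reg.predict([[2.5]]))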

Further Reading