[多分类]PR曲线

计算多分类任务的PR曲线

解析

计算多分类任务的PR曲线面积AP(平均精确率,即按阶梯方式计算的PR曲线下面积)时,通常使用微平均(micro)方式:将所有类别的标签二值化后合并,相当于累加所有类别的混淆矩阵,再统一计算精确率(precision)和召回率(recall)

python

sklearn库提供了函数sklearn.metrics.average_precision_score计算AP,设置参数average="micro"即可得到微平均AP

1
def average_precision_score(y_true, y_score, average="macro", pos_label=1, sample_weight=None):

将二值化后的标签矩阵和得分矩阵展平(ravel)后传入函数sklearn.metrics.precision_recall_curve,即可得到微平均方式下的精确率和召回率

1
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(), y_score.ravel())

示例

使用iris数据集和单层神经网络进行PR曲线绘制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# -*- coding: utf-8 -*-

"""
@author: zj
@file: multi-pr-nn.py
@time: 2020-01-11
"""

import numpy as np
import matplotlib.pyplot as plt
from nn_classifier import NN
from sklearn.preprocessing import label_binarize
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score


def load_data(test_size=.5, random_state=None):
    """Load the iris dataset and split it into shuffled train/test sets.

    Args:
        test_size: fraction of samples placed in the test split
            (default 0.5, same as the original hard-coded value).
        random_state: seed forwarded to ``train_test_split``. Pass an int for a
            reproducible split; the default ``None`` keeps the original
            non-deterministic shuffling.

    Returns:
        X_train, X_test, y_train, y_test
    """
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target

    # Shuffle and split into training and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    return X_train, X_test, y_train, y_test


def _standardize(x_train, x_test, eps=1e-8):
    """Z-score normalize both splits using statistics of the training split only.

    Args:
        x_train: training features, one row per sample.
        x_test: test features, same columns as ``x_train``.
        eps: floor applied to the per-feature variance so that a constant
            feature does not cause a division by zero.

    Returns:
        (x_train_std, x_test_std) as float64 arrays.
    """
    x_train = np.asarray(x_train).astype(np.float64)
    x_test = np.asarray(x_test).astype(np.float64)
    mu = np.mean(x_train, axis=0)
    std = np.sqrt(np.maximum(np.var(x_train, axis=0), eps))
    # The test split is scaled with the *training* statistics so no test
    # information leaks into preprocessing.
    return (x_train - mu) / std, (x_test - mu) / std


def _compute_pr(y_true_bin, y_score):
    """Compute per-class and micro-averaged precision/recall curves and AP.

    Args:
        y_true_bin: one-vs-rest binarized labels, one column per class.
        y_score: classifier scores, same shape as ``y_true_bin``; column i is
            assumed to score class i.

    Returns:
        (precision, recall, average_precision): dicts keyed by class index,
        plus the key "micro" for the micro-average over all classes.
    """
    precision, recall, average_precision = {}, {}, {}
    for i in range(y_true_bin.shape[1]):
        precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], y_score[:, i])
        average_precision[i] = average_precision_score(y_true_bin[:, i], y_score[:, i])

    # Micro-average: flatten every (sample, class) pair into one binary problem,
    # which is equivalent to summing the per-class confusion matrices.
    precision["micro"], recall["micro"], _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
    average_precision["micro"] = average_precision_score(y_true_bin, y_score, average="micro")
    return precision, recall, average_precision


def _plot_pr_curves(precision, recall, average_precision, n_classes):
    """Plot the micro-averaged PR curve and one PR curve per class in one figure."""
    colors = ['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal']
    plt.figure(figsize=(7, 8))
    lines = []
    labels = []

    # A PR curve is a step function (and AP is a step-wise sum), so draw it with
    # steps-post; linear interpolation between points can overstate precision.
    line, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2,
                     drawstyle='steps-post')
    lines.append(line)
    labels.append('micro-average Precision-recall (area = {0:0.2f})'.format(average_precision["micro"]))

    for i in range(n_classes):
        # Wrap around the palette so classes beyond len(colors) are still drawn
        # (zip() would silently drop them).
        color = colors[i % len(colors)]
        line, = plt.plot(recall[i], precision[i], color=color, lw=2,
                         drawstyle='steps-post')
        lines.append(line)
        labels.append('Precision-recall for class {0} (area = {1:0.2f})'
                      ''.format(i, average_precision[i]))

    fig = plt.gcf()
    # The legend is anchored below the axes at loc=(0, -.38); with bottom=0.15
    # its lower entries fall outside the figure, so reserve more space.
    fig.subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Extension of Precision-Recall curve to multi-class')
    plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
    plt.show()


if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_data()

    # Derive the class set from the data instead of hard-coding it.
    # NOTE: assumes labels are 0..n_classes-1 so that column i of the
    # classifier's scores corresponds to class i (true for iris).
    classes = np.unique(np.concatenate([y_train, y_test]))
    n_classes = len(classes)

    # Standardize features.
    x_train, x_test = _standardize(X_train, X_test)

    # Define, train and run the classifier; only the scores are needed here.
    classifier = NN(None, input_dim=x_train.shape[1], num_classes=n_classes)
    classifier.train(x_train, y_train, num_iters=100, batch_size=8, verbose=True)
    _, y_score = classifier.predict(x_test)

    # One-vs-rest binarization, kept in a separate variable so y_test stays intact.
    # NOTE(review): with exactly 2 classes label_binarize returns a single
    # column, which would not match a 2-column y_score.
    y_test_bin = label_binarize(y_test, classes=classes)
    precision, recall, average_precision = _compute_pr(y_test_bin, y_score)
    print('Average precision score, micro-averaged over all classes: {0:0.2f}'.format(average_precision["micro"]))

    _plot_pr_curves(precision, recall, average_precision, n_classes)

实现结果如下:

1
2
3
4
5
6
7
8
9
10
11
iteration 0 / 100: loss 1.080253
iteration 10 / 100: loss 0.848652
iteration 20 / 100: loss 0.703864
iteration 30 / 100: loss 0.611669
iteration 40 / 100: loss 0.549452
iteration 50 / 100: loss 0.504679
iteration 60 / 100: loss 0.470638
iteration 70 / 100: loss 0.443626
iteration 80 / 100: loss 0.421484
iteration 90 / 100: loss 0.402878
Average precision score, micro-averaged over all classes: 0.92

相关阅读