详解从decision_function()置信度推导预测结果的过程
发布于 作者:苏南大叔 来源:程序如此灵动~在机器学习中,模型置信度model.decision_function()
的结果,是可以推导出预测结果的,两者是存在着联系的。文章是基于"梯度上升模型"的鸢尾花数据集的,但是关于如何从置信度推算出大家所需要的预测结果的过程,应该是放到各种模型里面都是适用的过程。欢迎在不同的模型下,验证本文的推理过程。
苏南大叔的“程序如此灵动”博客,记录苏南大叔和计算机代码的故事。本文将对model.decision_function()
置信度结果,进行简要的分析。本文测试环境:win10
,python@3.12.0
,scikit-learn@1.2.2
。
模型预测结果
理论上来说,苏南大叔以前写过的任何机器学习模型都是可以的。现在使用最近的一篇文章“梯度上升模型”进行推理。代码如下:
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=8)
model = GradientBoostingClassifier()
model.fit(X_train, y_train)
pred = model.predict(X_test)
print("预测值:", pred) # [0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
print("实际值:", y_test) # [0 0 0 2 1 0 0 2 2 1 1 0 2 1 2 2 2 2 1 2 1 0 1 1 1 0 2 0 0 2]
print("准确度:", model.score(X_test, y_test)) # 准确度: 0.8666666666666667 (26/30)
预测结果是pred
,和真实值进行对比,可以知道:有四个值预测错了,所以准确度是26/30=0.8666666666666667
。
参考文章:
模型置信度
置信度是通过model.decision_function(X_test)
来获得的。
print(model.decision_function(X_test))
输出结果是:
[[ 6.2513459 -7.98794458 -8.02471493]
...
[-8.00690033 6.05561728 -8.02581679]
[ 6.2513459 -7.98794458 -8.02471493]]
从置信度到预测结果
那么,从置信度是如何推断出预测值的呢?这就是本文的主要讨论内容。下面的推导,是苏南大叔根据资料自己总结的,不一定准确,欢迎大家进行验证。
import numpy as np
_pred = np.where(model.classes_[(model.decision_function(X_test)>0).astype(int)]>=1)[1]
print("推断值:",_pred) # 对比 pred 结果
输出值:
预测值: [0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
推断值: [0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
详解推断过程
下面做置信度到预测结果的分解动作:
func = model.decision_function(X_test) # 每个预测值的置信度
print("置信度:",func,type(func))
print("置信度到预测结果的计算过程:")
_func = func > 0 # 置信度转化为布尔
print(_func)
_func = (_func).astype(int) # 布尔值True/False转化为1或0
print(_func)
_func = model.classes_[_func] # 0和1作为classes的索引
print(model.classes_)
print(_func) # 决策边界置信度结果
import numpy as np
sn = np.where(_func >= 1) # 找出所有>=1的元素
print(sn)
_pred = sn[1]
print(_pred)
print('逻辑推理是否准确:{}'.format(np.all(_pred == pred)))
输出:
置信度: [[ 6.2513459 -7.98794458 -8.02471493]
...
[ 6.2513459 -7.98794458 -8.02471493]
[-8.00722706 -7.98930769 6.16392638]]
<class 'numpy.ndarray'>
置信度到预测结果的计算过程:
[[ True False False]
...
[ True False False]
[False False True]]
[[1 0 0]
...
[1 0 0]
[0 0 1]]
[0 1 2]
[[1 0 0]
...
[1 0 0]
[0 0 1]]
(array([ 0, ... 28, 29], dtype=int64), array([0, ... 0, 2], dtype=int64))
[0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
逻辑推理是否准确:True
.where()
筛选数据的参考文章:
置信度结论
置信度的结果是个ndarray
类型,其维度是[预测数据条数,分类数量]。也可以这么说:每一条被预测的数据,对应每个分类都有个预测的结果,比如:
第一条被预测的数据[5. 3.6 1.4 0.2]
,经过模型预测后,其置信度是[ 6.2513459 -7.98794458 -8.02471493]
。那么,分别对应:属于分类0("山鸢尾")的概率是6.2513459
,属于分类1("变色鸢尾")的概率是-7.98794458
,属于分类2("维吉尼亚鸢尾")的概率是-8.02471493
。
值得注意的是:对应着三个分类的置信度数据,只有一个是正数(>0
),推断过程中,这个数据被置为True
,后通过astype(int)
转化为1
,最终被np.where(>=1)
分化出来。
最后通过ndarray
是否相等的判断np.all(_pred == pred)
,验证推断过程。
结语
更多sklearn
的相关经验文章内容,请参考苏南大叔的文章:
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。