机器学习,详解 predict_proba() 置信度 推导预测结果的过程
发布于 作者:苏南大叔 来源:程序如此灵动~
继续关于机器学习的各种模型预测的话题,预测数据使用的方法是model.predict(X_test),本文要写的方法是model.predict_proba(X_test)。两者看起来很相似,对不?两者有什么联系呢?这就是本文主要要探讨的问题。

苏南大叔的“程序如此灵动”博客,记录苏南大叔和计算机代码的故事。本文将对model.predict_proba(X_test)置信度结果,进行简要的分析。本文测试环境:win10,python@3.12.0,scikit-learn@1.2.2。
前文回顾
本文的行文分析基于“梯度提升模型”的鸢尾花数据集,其预测过程可以参考:
.decision_function(X_test)表述的是模型预测数据的置信度,.predict_proba(X_test)表述的也是置信度。两者在具体的数据展示形式上,是有所不同的。参考文章:
.predict()以及.decision_function()
依然先观测一下已有的结论代码:
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=8)
model = GradientBoostingClassifier()
model.fit(X_train, y_train)
pred = model.predict(X_test)
print("预测值:", pred) # [0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
func = model.decision_function(X_test)
print("置信度:", func) # [[ 6.2513459 -7.98794458 -8.02471493]...].predict_proba()置信度
proba = model.predict_proba(X_test)
print("置信度2:", proba) # [[9.99998714e-01 6.54567078e-07 6.30935551e-07]...]predict_proba()返回的是一个n 行 k 列的数组,第i行第j列上的数值是模型预测:第i个预测样本为某个标签的概率,并且每一行的概率和为"1"。
科学计数法
上面.predict_proba()返回数据比较奇怪,带着e-0x字样,实际上是一种科学计数法的写法。根据百度百科的描述:
在科学计数法中,为了使公式简便,可以用带“E”的格式表示。当用该格式表示时,E前面的数字和“E+”后面要精确到十分位,(位数不够末尾补0)。例如:
- 2乘10的 7次方,正常写法为:2x10^7,简写为“2E+07”的形式。
- 2乘10的-7次方,正常写法为:2x10^(-7),简写为“2E-07”的形式。
苏南大叔是这么理解的:
E+0x就是个非常大的数,用于形容天体之间的距离。x越大,真实的数值越大。E-0x就是个非常小的数,用于形容电子质子之类的质量。x越大,真实的数值越小。
所以,观察一下.predict_proba()的返回结果,下面的数据“[3.99998714e-01 9.54567078e-07 6.30935551e-07]”,最大的数据实际是:3.99998714e-01。(因为e-01中的01最小)。
从.predict_proba()到.pred()
proba = model.predict_proba(X_test)
# print("置信度2:", proba) # [[9.99998714e-01 6.54567078e-07 6.30935551e-07]...]
import numpy as np
_pred = np.argmax(proba, axis=1)
print("推断值2:",_pred)
print('逻辑推理是否准确:{}'.format(np.all(_pred == pred)))输出值:
预测值: [0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
推断值2: [0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 1 1 0 1 1 1 0 2 0 0 2]
逻辑推理是否准确:True这里的推断过程,则非常简单,一个np.argmax(proba, axis=1)就建立了从predict_proba()到pred的联系。在几个分类层面上,哪个置信度高,哪个分类就是最终结果。
结语
更多苏南大叔的sklearn相关经验文章,请点击苏南大叔的经验链接: