如何使用lightgbm的plot_importance函数?特征重要性排序
发布于 作者:苏南大叔 来源:程序如此灵动~ 我们相信:世界是美好的,你是我也是。平行空间的世界里面,不同版本的生活也在继续...
微软出品的lightgbm
模型,也带着plot_importance()
函数。那么,在使用lightgbm
模型,如何对特征的重要性进行排序呢?本文的龙套配角还是鸢尾花数据集。lightgbm
本身也自带importance_type
设置,虽然对本文的鸢尾花数据集没有特别的区别。但是,在处理其它复杂数据集的时候,还是有作用的。
苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码感想感悟。本文测试环境:win10
,python@3.12.0
,pandas@2.1.3
,matplotlib@3.8.2
,LightGBM@4.1.0
。
前文回顾
本文使用LightGBM
来处理鸢尾花数据集,所以主要的内容,可以参考下面的文章:
本文的主体内容逻辑,都是按照xgboost
的plot_importance()
函数写的,参考文章:
本文的结果可合理推断:在选择任意两个特征画散点图的时候,可以有依据了。参考文章:
plot
画图显示中文的事情,可以参考:
加载鸢尾花数据集
from sklearn.model_selection import train_test_split
import pandas as pd
data_url = "http://download.tensorflow.org/data/iris_training.csv"
column_names = ["萼长", "萼宽", "瓣长", "瓣宽", "种类"]
data = pd.read_csv(data_url, header=0, names=column_names)
X = data.iloc[:, :-1] # dataframe 带column信息
y = data.iloc[:, -1:].values.flatten() # ndarray
X_train, X_true, y_train, y_true = train_test_split(X, y, test_size=0.2, random_state=8)
这里其实有个细节,需要特别说明一下:
拿到的鸢尾花数据X
,从原来的ndarray
类型改成了dataframe
类型。dataframe
类型变量里面,携带了列名column
的信息。而对于y
如果使用dataframe
类型的话,会受到column_or_1d()
函数的影响。参考文章:
lightgbm的importance_type
本文是以讨论特征的重要性为目的的,而lightgbm.LGBMClassifier()
的importance_type
有着不同的取值。理论上是可以影响预测结果的。【当然,对于本文的鸢尾花数据集来说,似乎没有影响】
weight
默认值,表示一个特征在所有树中被使用的次数。这个参数反映了该特征的重要性,因为如果一个特征被用于更多的树中,那么它对最终预测结果的贡献就更大。gain
,表示一个特征在所有树中对预测结果的平均增益。这个参数反映了该特征在每个节点上的分裂能力,因为如果一个特征在每个节点上的分裂能力越强,它对最终预测结果的贡献就越大。cover
,表示一个特征在所有树中对样本的平均覆盖度。这个参数反映了该特征对模型的覆盖能力,因为如果一个特征对更多的样本有影响,它对最终预测结果的贡献就更大。
import lightgbm as lgb
# model = lgb.LGBMClassifier(verbose=-1, num_threads=2, importance_type="weight")
# model = lgb.LGBMClassifier(verbose=-1, num_threads=2, importance_type="gain")
# model = lgb.LGBMClassifier(verbose=-1, num_threads=2, importance_type="cover")
model = lgb.LGBMClassifier(verbose=-1, num_threads=2)
model.fit(X_train, y_train)
y_pred = model.predict(X_true)
画图排序重要性【本文重点】
import matplotlib.pyplot as plt
plt.rcParams["font.sans-serif"] = ["SimHei"] # 显示中文
from lightgbm import plot_importance
fig,ax = plt.subplots(figsize=(6,3.5)) # 纸张大小
plot_importance(model,ax=ax) # 必须写ax,否则画图奇怪输出
# plot_importance(model,ax=ax,max_num_features=2) # max_num_features控制输出的特征柱状图的数量
plt.show()
这里需要说明的是:
lightgbm
模型进行数据填充的时候,是带着column
信息的。所以,这里的图会显示对应的标签名。ax
是把plot_importance()
和默认的plt
画图对象进行关联的关键点。- 如果特征点很多的话,
max_num_features
就可以控制显示的特征点的数量。
重大结论
lightgbm
和xgboost
虽然都有这个plot_importance()
函数。但是,两者的结论却是不同的!
- 从
lightgbm
的结果上来看,最重要的两个特征是:"瓣长"和"瓣宽"。 - 然而,从
xgboost
的结果上看,最重要的两个特征是:"萼长"和"瓣长"。
结语
不想写结语,不知道写啥。
如果本文对您有帮助,或者节约了您的时间,欢迎打赏瓶饮料,建立下友谊关系。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。