How do you use a CatBoost model to make predictions on the iris dataset?
Published by 苏南大叔, from the blog 程序如此灵动. There are many kinds of AI algorithm models, and 苏南大叔 has already written plenty of articles about them. This article applies the CatBoost algorithm to the iris dataset, to see whether the CatBoost model has anything special about it.
苏南大叔's blog 程序如此灵动 records 苏南大叔's coding stories. Test environment for this article: win10, python@3.12.0, pandas@2.1.3, catboost@1.2.5.
Introduction to CatBoost
CatBoost is an ensemble model built on gradient-boosted decision trees. Gradient boosting builds a strong ensemble by training a sequence of weak learners, each one correcting the residual errors of those before it. CatBoost is a later refinement in the same family of ideas as XGBoost and LightGBM.
CatBoost's most distinctive trait is how it treats categorical features. Traditional gradient-boosted decision trees usually require converting categorical features into numeric ones first, for example with one-hot encoding or label encoding. CatBoost instead uses a special ordered boosting scheme and can consume categorical features directly, with no manual conversion.
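To make that concrete, here is a minimal, hypothetical sketch (the toy data and column choice are made up for illustration, not taken from this article's dataset) of feeding a string-valued categorical column straight into CatBoost via the cat_features argument of Pool:
from catboost import CatBoostClassifier, Pool
# Hypothetical toy data: column 0 is numeric, column 1 is categorical (strings).
X = [[1.0, "red"], [2.0, "blue"], [3.0, "red"],
     [4.0, "green"], [5.0, "blue"], [6.0, "green"]]
y = [0, 1, 0, 1, 1, 0]
# Declaring column 1 as categorical lets CatBoost consume the strings
# directly -- no one-hot or label encoding beforehand.
train_pool = Pool(X, y, cat_features=[1])
model = CatBoostClassifier(iterations=10, verbose=0)
model.fit(train_pool)
print(model.predict(Pool([[2.5, "blue"]], cat_features=[1])))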
CatBoost is still not bundled with sklearn, so it needs to be installed separately:
pip install catboost
As of publication, its latest version is 1.2.5.
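If you are not sure which version actually got installed, a quick check:
import catboost
print(catboost.__version__)  # prints e.g. 1.2.5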
The iris dataset
The iris dataset has been covered here countless times, and so has how to load it, so here is the code straight away:
from sklearn.model_selection import train_test_split
import pandas as pd
# Load the iris training CSV (120 rows) and name the columns.
data_url = "http://download.tensorflow.org/data/iris_training.csv"
column_names = ["萼长", "萼宽", "瓣长", "瓣宽", "种类"]  # sepal length/width, petal length/width, species
data = pd.read_csv(data_url, header=0, names=column_names)
# Features are the first four columns, the label is the last one.
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values  # take the last column as a flat 1-D array
# Hold out 20% of the rows as the test set.
X_train, X_true, y_train, y_true = train_test_split(X, y, test_size=0.2, random_state=8)
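As a quick sanity check (the expected shapes below assume the 120-row CSV and the 0.2 split above):
print(X_train.shape, X_true.shape)  # expected: (96, 4) (24, 4)
print(y_train.shape, y_true.shape)  # expected: (96,) (24,)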
Readers who are unsure about any of this can refer to the following articles:
- https://newsn.net/say/sklearn-csv.html
- https://newsn.net/say/sklearn-load_iris.html
- https://newsn.net/say/sklearn-train_test_split.html
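If the download URL above happens to be unreachable, a rough equivalent can be loaded offline through scikit-learn itself (note this built-in copy has 150 rows rather than the CSV's 120, so the numbers later in this article would shift slightly):
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target  # 150 samples, 4 features, 3 classes
X_train, X_true, y_train, y_true = train_test_split(X, y, test_size=0.2, random_state=8)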
Prediction with catboost
Apart from how the model is initialized, catboost is used the same way as other models: the usual .fit() and .predict() methods. The difference is in the argument to .fit(): instead of the usual (X, y) pair, it is given a Pool.
from catboost import CatBoostClassifier, Pool
# 100 boosting iterations, trees of depth 3, multi-class loss for the 3 species.
model = CatBoostClassifier(iterations=100, depth=3, learning_rate=0.1, classes_count=3, loss_function='MultiClass')
# CatBoost's own container for features + labels.
data_pool = Pool(X_train, y_train)
model.fit(data_pool)
y_pred = model.predict(X_true)
print(y_pred)
print("CatBoost accuracy:", model.score(X_true, y_true))
Output:
0: learn: 0.9794429 total: 142ms remaining: 14.1s
1: learn: 0.8742464 total: 144ms remaining: 7.04s
2: learn: 0.7822673 total: 145ms remaining: 4.7s
3: learn: 0.7118345 total: 147ms remaining: 3.53s
4: learn: 0.6483435 total: 149ms remaining: 2.83s
5: learn: 0.5950788 total: 151ms remaining: 2.37s
6: learn: 0.5635209 total: 153ms remaining: 2.03s
7: learn: 0.5277414 total: 155ms remaining: 1.78s
8: learn: 0.4853018 total: 156ms remaining: 1.58s
9: learn: 0.4562100 total: 158ms remaining: 1.42s
10: learn: 0.4224053 total: 160ms remaining: 1.29s
11: learn: 0.3976666 total: 163ms remaining: 1.2s
12: learn: 0.3708765 total: 165ms remaining: 1.1s
13: learn: 0.3510984 total: 167ms remaining: 1.02s
14: learn: 0.3325293 total: 168ms remaining: 954ms
15: learn: 0.3248074 total: 170ms remaining: 891ms
16: learn: 0.3079852 total: 171ms remaining: 833ms
17: learn: 0.2934140 total: 171ms remaining: 781ms
18: learn: 0.2786572 total: 173ms remaining: 736ms
19: learn: 0.2705533 total: 173ms remaining: 694ms
20: learn: 0.2584368 total: 174ms remaining: 656ms
21: learn: 0.2477025 total: 175ms remaining: 621ms
22: learn: 0.2358616 total: 176ms remaining: 589ms
23: learn: 0.2251890 total: 177ms remaining: 560ms
24: learn: 0.2147225 total: 178ms remaining: 533ms
25: learn: 0.2047610 total: 178ms remaining: 508ms
26: learn: 0.1994996 total: 179ms remaining: 484ms
27: learn: 0.1911167 total: 180ms remaining: 464ms
28: learn: 0.1827716 total: 181ms remaining: 444ms
29: learn: 0.1756403 total: 182ms remaining: 425ms
30: learn: 0.1686401 total: 183ms remaining: 408ms
31: learn: 0.1630934 total: 184ms remaining: 391ms
32: learn: 0.1579780 total: 185ms remaining: 375ms
33: learn: 0.1533013 total: 186ms remaining: 361ms
34: learn: 0.1486099 total: 187ms remaining: 347ms
35: learn: 0.1432425 total: 187ms remaining: 333ms
36: learn: 0.1382741 total: 188ms remaining: 321ms
37: learn: 0.1346168 total: 189ms remaining: 308ms
38: learn: 0.1299909 total: 190ms remaining: 297ms
39: learn: 0.1270192 total: 191ms remaining: 286ms
40: learn: 0.1233524 total: 191ms remaining: 275ms
41: learn: 0.1191749 total: 192ms remaining: 265ms
42: learn: 0.1176893 total: 192ms remaining: 255ms
43: learn: 0.1148606 total: 193ms remaining: 245ms
44: learn: 0.1118652 total: 193ms remaining: 236ms
45: learn: 0.1095958 total: 194ms remaining: 227ms
46: learn: 0.1069031 total: 195ms remaining: 220ms
47: learn: 0.1045801 total: 195ms remaining: 211ms
48: learn: 0.1017750 total: 196ms remaining: 204ms
49: learn: 0.1001975 total: 196ms remaining: 196ms
50: learn: 0.0984423 total: 197ms remaining: 189ms
51: learn: 0.0955003 total: 197ms remaining: 182ms
52: learn: 0.0944250 total: 198ms remaining: 175ms
53: learn: 0.0919019 total: 198ms remaining: 169ms
54: learn: 0.0894862 total: 198ms remaining: 162ms
55: learn: 0.0873643 total: 199ms remaining: 156ms
56: learn: 0.0853279 total: 200ms remaining: 151ms
57: learn: 0.0837705 total: 200ms remaining: 145ms
58: learn: 0.0816314 total: 201ms remaining: 139ms
59: learn: 0.0791867 total: 201ms remaining: 134ms
60: learn: 0.0778391 total: 202ms remaining: 129ms
61: learn: 0.0754883 total: 202ms remaining: 124ms
62: learn: 0.0743293 total: 203ms remaining: 119ms
63: learn: 0.0729998 total: 203ms remaining: 114ms
64: learn: 0.0720266 total: 204ms remaining: 110ms
65: learn: 0.0709492 total: 204ms remaining: 105ms
66: learn: 0.0699846 total: 205ms remaining: 101ms
67: learn: 0.0690457 total: 205ms remaining: 96.5ms
68: learn: 0.0677621 total: 206ms remaining: 92.3ms
69: learn: 0.0669052 total: 206ms remaining: 88.3ms
70: learn: 0.0663736 total: 206ms remaining: 84.3ms
71: learn: 0.0654374 total: 207ms remaining: 80.5ms
72: learn: 0.0646364 total: 207ms remaining: 76.7ms
73: learn: 0.0636969 total: 208ms remaining: 73ms
74: learn: 0.0623601 total: 208ms remaining: 69.5ms
75: learn: 0.0608246 total: 209ms remaining: 66ms
76: learn: 0.0598671 total: 209ms remaining: 62.6ms
77: learn: 0.0585025 total: 210ms remaining: 59.2ms
78: learn: 0.0571492 total: 210ms remaining: 56ms
79: learn: 0.0563438 total: 211ms remaining: 52.7ms
80: learn: 0.0554539 total: 211ms remaining: 49.6ms
81: learn: 0.0548807 total: 212ms remaining: 46.5ms
82: learn: 0.0537179 total: 212ms remaining: 43.5ms
83: learn: 0.0531897 total: 213ms remaining: 40.5ms
84: learn: 0.0526296 total: 213ms remaining: 37.6ms
85: learn: 0.0520202 total: 214ms remaining: 34.8ms
86: learn: 0.0515390 total: 214ms remaining: 32.1ms
87: learn: 0.0510733 total: 215ms remaining: 29.4ms
88: learn: 0.0500396 total: 216ms remaining: 26.7ms
89: learn: 0.0497724 total: 216ms remaining: 24ms
90: learn: 0.0486054 total: 217ms remaining: 21.4ms
91: learn: 0.0477191 total: 217ms remaining: 18.9ms
92: learn: 0.0469996 total: 218ms remaining: 16.4ms
93: learn: 0.0464555 total: 218ms remaining: 13.9ms
94: learn: 0.0458702 total: 219ms remaining: 11.5ms
95: learn: 0.0450066 total: 219ms remaining: 9.14ms
96: learn: 0.0446415 total: 220ms remaining: 6.8ms
97: learn: 0.0441281 total: 220ms remaining: 4.49ms
98: learn: 0.0434056 total: 221ms remaining: 2.23ms
99: learn: 0.0430703 total: 221ms remaining: 0us
[[1]
 [2]
 [2]
 ...]
(a 24×1 column of predicted class labels, one per test row; each entry is 0, 1 or 2)
CatBoost accuracy: 0.9166666666666666
Judging from actual runs, this model takes noticeably longer to train than the other models covered on this blog.
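Incidentally, the long per-iteration log shown above can be silenced; a minimal sketch, assuming the same model settings as before:
model = CatBoostClassifier(iterations=100, depth=3, learning_rate=0.1,
                           classes_count=3, loss_function='MultiClass',
                           verbose=False)  # or logging_level='Silent'
model.fit(data_pool)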
Model evaluation
The evaluation again uses the classic f1 / recall / precision trio. Code:
from sklearn.metrics import classification_report
# Per-class precision / recall / f1, plus overall accuracy.
report = classification_report(y_true, y_pred)
print(report)
Output:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         8
           1       1.00      0.75      0.86         8
           2       0.80      1.00      0.89         8

    accuracy                           0.92        24
   macro avg       0.93      0.92      0.92        24
weighted avg       0.93      0.92      0.92        24
Reference articles:
- https://newsn.net/say/sklearn-classification_report.html
- https://newsn.net/say/sklearn-classification_report-2.html
- https://newsn.net/say/sklearn-score.html
- https://newsn.net/say/sklearn-score-2.html
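As a complement to classification_report, a confusion matrix shows exactly which species get mistaken for which. A minimal sketch, reusing the y_true and y_pred arrays from the prediction step above:
from sklearn.metrics import confusion_matrix
# Rows are true classes, columns are predicted classes.
# Given the report above, this should come out as:
#   [[8 0 0]
#    [0 6 2]
#    [0 0 8]]
print(confusion_matrix(y_true, y_pred))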
A possible difference
Compared with all the model-prediction workflows covered so far, the biggest difference with catboost is the extra Pool(). Reference code:
data_pool = Pool(X_train, y_train)
model.fit(data_pool)
Dropping Pool() and passing the data the same way as with other algorithms also runs fine. For example:
model.fit(X_train, y_train)
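Symmetrically, the prediction side also accepts a Pool; a minimal sketch:
test_pool = Pool(X_true)           # a Pool without labels is fine for inference
y_pred = model.predict(test_pool)
For this all-numeric dataset Pool adds little, but it becomes necessary once categorical columns have to be declared via cat_features, which plain arrays cannot express.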
Conclusion
As the code above shows, the way catboost is used differs only slightly from the earlier algorithms. More machine-learning algorithm articles can be found elsewhere on this blog.