
There are many kinds of algorithm models in machine learning, and 苏南大叔 has already written quite a few articles about them. This article applies the CatBoost algorithm to the iris dataset, and looks at whether the CatBoost model has anything special about it.

[Figure 3-1: CatBoost prediction on the iris dataset]

苏南大叔's blog "程序如此灵动" records 苏南大叔's coding stories. Test environment for this article: Win10, Python@3.12.0, pandas@2.1.3, catboost@1.2.5.

Introduction to CatBoost

CatBoost is an ensemble model based on gradient-boosted decision trees. Gradient boosting is a technique that builds a strong ensemble by training a sequence of weak learners, each one correcting the errors of the ensemble so far. CatBoost can be seen as an optimized descendant of the same line of work as XGBoost and LightGBM.
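The core idea of gradient boosting can be sketched in a few lines of plain Python. This toy example (squared loss, decision stumps, purely illustrative and unrelated to CatBoost's actual implementation) fits each new weak learner to the residuals of the current ensemble:

```python
# Toy gradient boosting for squared loss: each weak learner is fitted
# to the residuals (negative gradients) of the current ensemble.
# A conceptual sketch only, not how CatBoost is really implemented.

def fit_stump(x, residuals):
    """Find the one-feature threshold split that best reduces squared error."""
    best = None
    for threshold in x:
        left = [r for xi, r in zip(x, residuals) if xi <= threshold]
        right = [r for xi, r in zip(x, residuals) if xi > threshold]
        if not left or not right:
            continue  # degenerate split, skip
        lmean = sum(left) / len(left)
        rmean = sum(right) / len(right)
        err = (sum((r - lmean) ** 2 for r in left)
               + sum((r - rmean) ** 2 for r in right))
        if best is None or err < best[0]:
            best = (err, threshold, lmean, rmean)
    return best[1:]  # (threshold, left_value, right_value)

def boost(x, y, n_rounds=50, lr=0.1):
    pred = [sum(y) / len(y)] * len(y)  # start from the global mean
    for _ in range(n_rounds):
        residuals = [yi - pi for yi, pi in zip(y, pred)]
        t, lv, rv = fit_stump(x, residuals)
        # move predictions a small step (learning rate) toward the residuals
        pred = [p + lr * (lv if xi <= t else rv) for xi, p in zip(x, pred)]
    return pred

x = [1, 2, 3, 4, 5, 6]
y = [1.0, 1.2, 0.9, 3.1, 3.0, 2.9]
pred = boost(x, y)
```

After 50 rounds the ensemble's predictions sit close to the targets, which is the whole point: many weak stumps add up to a strong model.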


CatBoost's most distinctive feature is the way it handles categorical features. Traditional gradient-boosted decision trees usually require categorical features to be converted into numeric ones first, for example via one-hot encoding or label encoding. CatBoost instead uses ordered boosting together with ordered target statistics, which lets it consume categorical features directly, with no manual conversion.
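The "ordered" idea can be illustrated with a small sketch. The function below is a deliberate simplification of CatBoost's ordered target statistics (the real algorithm uses multiple permutations and a more elaborate prior): each row's category is encoded using only the targets of rows that appear before it in a random permutation, so a row never sees its own target and leakage is avoided.

```python
import random

def ordered_target_encoding(categories, targets, prior=0.5, seed=0):
    """Encode a categorical column using only 'past' rows.
    Simplified sketch of ordered target statistics; not CatBoost's real code."""
    order = list(range(len(categories)))
    random.Random(seed).shuffle(order)   # random permutation of the rows
    sums, counts = {}, {}                # running per-category statistics
    encoded = [0.0] * len(categories)
    for pos in order:
        c = categories[pos]
        # encode with statistics accumulated so far (no look-ahead at own target)
        encoded[pos] = (sums.get(c, 0.0) + prior) / (counts.get(c, 0) + 1)
        sums[c] = sums.get(c, 0.0) + targets[pos]
        counts[c] = counts.get(c, 0) + 1
    return encoded

cats = ["red", "blue", "red", "red", "blue"]
ys = [1, 0, 1, 0, 1]
enc = ordered_target_encoding(cats, ys)
```

The first row visited for each category has no history, so it simply receives the prior; later rows get progressively better estimates of the category's mean target.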


CatBoost is still not included in sklearn, so it has to be installed separately:

pip install catboost

As of this writing, its latest version is 1.2.5.

The Iris Dataset

The iris dataset has been discussed on this blog many times, including how to load it, so here is the code directly:

from sklearn.model_selection import train_test_split
import pandas as pd

# Load the iris training set (4 feature columns plus a label column)
data_url = "http://download.tensorflow.org/data/iris_training.csv"
column_names = ["萼长", "萼宽", "瓣长", "瓣宽", "种类"]
data = pd.read_csv(data_url, header=0, names=column_names)
X = data.iloc[:, :-1].values  # feature matrix
y = data.iloc[:, -1].values   # label vector
# Hold out 20% of the rows for evaluation
X_train, X_true, y_train, y_true = train_test_split(X, y, test_size=0.2, random_state=8)
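For readers curious what train_test_split actually does, the gist is shuffle-then-slice. Below is a minimal pure-Python sketch of that idea (sklearn's real implementation adds stratification, array handling, and more):

```python
import random

def simple_train_test_split(X, y, test_size=0.2, random_state=8):
    """Minimal re-implementation of the shuffle-and-slice idea behind
    sklearn's train_test_split (illustrative only)."""
    idx = list(range(len(X)))
    random.Random(random_state).shuffle(idx)   # reproducible shuffle
    n_test = int(round(len(X) * test_size))    # size of the held-out slice
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test

# 10 samples, 20% held out -> 8 train / 2 test
X_demo = [[i] for i in range(10)]
y_demo = list(range(10))
X_tr, X_te, y_tr, y_te = simple_train_test_split(X_demo, y_demo)
```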

Readers who are unsure about any of this can refer to the earlier articles on this blog.

CatBoost Prediction

Apart from how the model is initialized, CatBoost is used the same way as other models: the familiar .fit() and .predict() methods. One difference is that fit() here takes a Pool object rather than the usual (X, y) pair:

from catboost import CatBoostClassifier, Pool

# 100 boosting rounds of depth-3 trees, multi-class loss for the 3 iris species
model = CatBoostClassifier(iterations=100, depth=3, learning_rate=0.1, classes_count=3, loss_function='MultiClass')
data_pool = Pool(X_train, y_train)  # wrap features and labels in a Pool
model.fit(data_pool)
y_pred = model.predict(X_true)  # returns an (n, 1) column of predicted class labels
print(y_pred)
print("CATBOOST算法预测准确率:", model.score(X_true, y_true))

Output:

0:      learn: 0.9794429        total: 142ms    remaining: 14.1s
1:      learn: 0.8742464        total: 144ms    remaining: 7.04s
2:      learn: 0.7822673        total: 145ms    remaining: 4.7s
3:      learn: 0.7118345        total: 147ms    remaining: 3.53s
4:      learn: 0.6483435        total: 149ms    remaining: 2.83s
5:      learn: 0.5950788        total: 151ms    remaining: 2.37s
6:      learn: 0.5635209        total: 153ms    remaining: 2.03s
7:      learn: 0.5277414        total: 155ms    remaining: 1.78s
8:      learn: 0.4853018        total: 156ms    remaining: 1.58s
9:      learn: 0.4562100        total: 158ms    remaining: 1.42s
10:     learn: 0.4224053        total: 160ms    remaining: 1.29s
11:     learn: 0.3976666        total: 163ms    remaining: 1.2s
12:     learn: 0.3708765        total: 165ms    remaining: 1.1s
13:     learn: 0.3510984        total: 167ms    remaining: 1.02s
14:     learn: 0.3325293        total: 168ms    remaining: 954ms
15:     learn: 0.3248074        total: 170ms    remaining: 891ms
16:     learn: 0.3079852        total: 171ms    remaining: 833ms
17:     learn: 0.2934140        total: 171ms    remaining: 781ms
18:     learn: 0.2786572        total: 173ms    remaining: 736ms
19:     learn: 0.2705533        total: 173ms    remaining: 694ms
20:     learn: 0.2584368        total: 174ms    remaining: 656ms
21:     learn: 0.2477025        total: 175ms    remaining: 621ms
22:     learn: 0.2358616        total: 176ms    remaining: 589ms
23:     learn: 0.2251890        total: 177ms    remaining: 560ms
24:     learn: 0.2147225        total: 178ms    remaining: 533ms
25:     learn: 0.2047610        total: 178ms    remaining: 508ms
26:     learn: 0.1994996        total: 179ms    remaining: 484ms
27:     learn: 0.1911167        total: 180ms    remaining: 464ms
28:     learn: 0.1827716        total: 181ms    remaining: 444ms
29:     learn: 0.1756403        total: 182ms    remaining: 425ms
30:     learn: 0.1686401        total: 183ms    remaining: 408ms
31:     learn: 0.1630934        total: 184ms    remaining: 391ms
32:     learn: 0.1579780        total: 185ms    remaining: 375ms
33:     learn: 0.1533013        total: 186ms    remaining: 361ms
34:     learn: 0.1486099        total: 187ms    remaining: 347ms
35:     learn: 0.1432425        total: 187ms    remaining: 333ms
36:     learn: 0.1382741        total: 188ms    remaining: 321ms
37:     learn: 0.1346168        total: 189ms    remaining: 308ms
38:     learn: 0.1299909        total: 190ms    remaining: 297ms
39:     learn: 0.1270192        total: 191ms    remaining: 286ms
40:     learn: 0.1233524        total: 191ms    remaining: 275ms
41:     learn: 0.1191749        total: 192ms    remaining: 265ms
42:     learn: 0.1176893        total: 192ms    remaining: 255ms
43:     learn: 0.1148606        total: 193ms    remaining: 245ms
44:     learn: 0.1118652        total: 193ms    remaining: 236ms
45:     learn: 0.1095958        total: 194ms    remaining: 227ms
46:     learn: 0.1069031        total: 195ms    remaining: 220ms
47:     learn: 0.1045801        total: 195ms    remaining: 211ms
48:     learn: 0.1017750        total: 196ms    remaining: 204ms
49:     learn: 0.1001975        total: 196ms    remaining: 196ms
50:     learn: 0.0984423        total: 197ms    remaining: 189ms
51:     learn: 0.0955003        total: 197ms    remaining: 182ms
52:     learn: 0.0944250        total: 198ms    remaining: 175ms
53:     learn: 0.0919019        total: 198ms    remaining: 169ms
54:     learn: 0.0894862        total: 198ms    remaining: 162ms
55:     learn: 0.0873643        total: 199ms    remaining: 156ms
56:     learn: 0.0853279        total: 200ms    remaining: 151ms
57:     learn: 0.0837705        total: 200ms    remaining: 145ms
58:     learn: 0.0816314        total: 201ms    remaining: 139ms
59:     learn: 0.0791867        total: 201ms    remaining: 134ms
60:     learn: 0.0778391        total: 202ms    remaining: 129ms
61:     learn: 0.0754883        total: 202ms    remaining: 124ms
62:     learn: 0.0743293        total: 203ms    remaining: 119ms
63:     learn: 0.0729998        total: 203ms    remaining: 114ms
64:     learn: 0.0720266        total: 204ms    remaining: 110ms
65:     learn: 0.0709492        total: 204ms    remaining: 105ms
66:     learn: 0.0699846        total: 205ms    remaining: 101ms
67:     learn: 0.0690457        total: 205ms    remaining: 96.5ms
68:     learn: 0.0677621        total: 206ms    remaining: 92.3ms
69:     learn: 0.0669052        total: 206ms    remaining: 88.3ms
70:     learn: 0.0663736        total: 206ms    remaining: 84.3ms
71:     learn: 0.0654374        total: 207ms    remaining: 80.5ms
72:     learn: 0.0646364        total: 207ms    remaining: 76.7ms
73:     learn: 0.0636969        total: 208ms    remaining: 73ms
74:     learn: 0.0623601        total: 208ms    remaining: 69.5ms
75:     learn: 0.0608246        total: 209ms    remaining: 66ms
76:     learn: 0.0598671        total: 209ms    remaining: 62.6ms
77:     learn: 0.0585025        total: 210ms    remaining: 59.2ms
78:     learn: 0.0571492        total: 210ms    remaining: 56ms
79:     learn: 0.0563438        total: 211ms    remaining: 52.7ms
80:     learn: 0.0554539        total: 211ms    remaining: 49.6ms
81:     learn: 0.0548807        total: 212ms    remaining: 46.5ms
82:     learn: 0.0537179        total: 212ms    remaining: 43.5ms
83:     learn: 0.0531897        total: 213ms    remaining: 40.5ms
84:     learn: 0.0526296        total: 213ms    remaining: 37.6ms
85:     learn: 0.0520202        total: 214ms    remaining: 34.8ms
86:     learn: 0.0515390        total: 214ms    remaining: 32.1ms
87:     learn: 0.0510733        total: 215ms    remaining: 29.4ms
88:     learn: 0.0500396        total: 216ms    remaining: 26.7ms
89:     learn: 0.0497724        total: 216ms    remaining: 24ms
90:     learn: 0.0486054        total: 217ms    remaining: 21.4ms
91:     learn: 0.0477191        total: 217ms    remaining: 18.9ms
92:     learn: 0.0469996        total: 218ms    remaining: 16.4ms
93:     learn: 0.0464555        total: 218ms    remaining: 13.9ms
94:     learn: 0.0458702        total: 219ms    remaining: 11.5ms
95:     learn: 0.0450066        total: 219ms    remaining: 9.14ms
96:     learn: 0.0446415        total: 220ms    remaining: 6.8ms
97:     learn: 0.0441281        total: 220ms    remaining: 4.49ms
98:     learn: 0.0434056        total: 221ms    remaining: 2.23ms
99:     learn: 0.0430703        total: 221ms    remaining: 0us
[y_pred: a 24×1 column of predicted class labels (0, 1 or 2), one per test sample]
CATBOOST算法预测准确率: 0.9166666666666666

Judging from actual runs, this model takes noticeably longer than the other models covered so far.
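The printed accuracy is simply correct predictions divided by total predictions: with 24 held-out samples and 2 misses, 22/24 ≈ 0.9167. A minimal sketch of that computation (for a classifier, CatBoost's score() reports exactly this accuracy):

```python
def accuracy(y_true, y_pred):
    """Fraction of matching labels -- what a classifier's score() reports."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# 24 test samples, 2 of them misclassified -> 22/24
acc = accuracy([0] * 22 + [1, 1], [0] * 22 + [2, 2])
print(acc)  # 0.9166666666666666
```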

[Figure 3-2: CatBoost iris prediction results]

Model Evaluation

As usual, we evaluate with the classic f1 / recall / precision metrics. The code is as follows:

from sklearn.metrics import classification_report

# Per-class precision, recall and F1 on the held-out samples
report = classification_report(y_true, y_pred)
print(report)

Output:

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         8
           1       1.00      0.75      0.86         8
           2       0.80      1.00      0.89         8

    accuracy                           0.92        24
   macro avg       0.93      0.92      0.92        24
weighted avg       0.93      0.92      0.92        24
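The numbers in the report can be reproduced by hand. For class 2, for example, all 8 true samples were recovered (recall 1.00), but 10 samples were predicted as class 2 and only 8 of those were right (precision 0.80). The sketch below recomputes these per-class metrics; the label vectors are hypothetical, constructed only to be consistent with the report above:

```python
def per_class_metrics(y_true, y_pred, cls):
    """Precision, recall and F1 for one class -- the formulas behind
    sklearn's classification_report."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

# Hypothetical labels consistent with the report: 8 correct class-0,
# 6 of 8 class-1 correct (2 mistaken for class 2), 8 correct class-2.
yt = [0] * 8 + [1] * 8 + [2] * 8
yp = [0] * 8 + [1] * 6 + [2] * 2 + [2] * 8
p2, r2, f2 = per_class_metrics(yt, yp, 2)   # 0.80, 1.00, 0.89
p1, r1, f1 = per_class_metrics(yt, yp, 1)   # 1.00, 0.75, 0.86
```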


The Main Difference

Compared with the prediction workflows of the models covered so far, CatBoost's biggest difference is the extra Pool() object. Reference code:

data_pool = Pool(X_train, y_train)
model.fit(data_pool)

If you skip Pool() and pass the data the same way as with other algorithms, the code still runs fine. For example:

model.fit(X_train, y_train)

[Figure 3-3: the common calling pattern of machine-learning models]

Conclusion

As the code above shows, CatBoost's usage differs slightly from the algorithms covered before. More machine-learning articles can be found elsewhere on this blog.
