我们相信:世界是美好的,你是我也是。平行空间的世界里面,不同版本的生活也在继续...

以鸢尾花数据集为例,本文描述如何使用sklearn的逻辑回归模型LogisticRegression。本文仅仅是个逻辑回归模型的鸢尾花数据集使用范例,主打短平快,不对逻辑回归模型的细节进行谈论。

苏南大叔:逻辑回归模型,如何对鸢尾花数据集进行预测? - 逻辑回归预测鸢尾花数据集
逻辑回归模型,如何对鸢尾花数据集进行预测?(图2-1)

苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验文章。测试环境:win10python@3.11.0sklearn@1.2.2

获得鸢尾花训练集

从前面的文章里面,苏南大叔已经对鸢尾花数据集做了无数次各个角度的数据分析。其实主要的目标就是读取csv,拿到目标数据并进行数据集切分。参考文章:

拿到鸢尾花数据集的方式很多,本文中加载的是sklearn自带的鸢尾花数据集。

获取数据集的代码如下:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=8)

这里拿到的X/y数据类型都是ndarray

逻辑回归模型预测

逻辑回归与线性回归都是一种广义线性模型。具体的说,都是从指数分布族导出的线性模型,线性回归假设Y|X服从高斯分布,逻辑回归假设Y|X服从伯努利分布。

这里采用LogisticRegression逻辑回归模型,做鸢尾花预测。其中模型参数multi_classsolver都采用默认数值。

from sklearn.linear_model import LogisticRegression
model = LogisticRegression()                   
model.fit(X_train, y_train)
predictions = model.predict(X_test)
print(predictions)
print("逻辑回归模型准确度:", model.score(X_test, y_test))

输出:

[0 0 0 2 1 0 0 2 2 1 1 0 1 1 1 2 2 2 2 2 1 0 1 1 1 0 2 0 0 2]
逻辑回归模型: 0.9

苏南大叔:逻辑回归模型,如何对鸢尾花数据集进行预测? - 逻辑回归运算结果
逻辑回归模型,如何对鸢尾花数据集进行预测?(图2-2)

可能的报错

如果预测过程中(比如使用tensorflow的在线csv版本时),报错信息如下:

D:\Program Files\Python368\lib\site-packages\sklearn\linear_model\_logistic.py:765: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)

可以考虑更改solver参数值,从默认的lbfgs改成liblinearnewton-cg。例如:

model = LogisticRegression(solver='liblinear')
model = LogisticRegression(solver='newton-cg')
更换不同的solver是可能带来预测准确度的变化的。

更多取值可以参考:

结束语

机器学习的各种模型非常多,每个模型内部也有很多不同的预测方法,本文所使用的逻辑回归模型中,其实也有两个参数的,分别是multi_classsolver,待后续文章进行讨论。

如果本文对您有帮助,或者节约了您的时间,欢迎打赏瓶饮料,建立下友谊关系。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。

 【福利】 腾讯云最新爆款活动!1核2G云服务器首年50元!

 【源码】本文代码片段及相关软件,请点此获取更多信息

 【绝密】秘籍文章入口,仅传授于有缘之人   python    sklearn