以鸢尾花数据集为例,重新审视数据切割函数train_test_split
发布于 作者:苏南大叔 来源:程序如此灵动~ 我们相信:世界是美好的,你是我也是。平行空间的世界里面,不同版本的生活也在继续...
苏南大叔再次重新审视在基于sklearn
的机器学习中,最常用的数据集切割函数train_test_split()
。同以前的文章一样,本文还是基于对鸢尾花数据集的再一次加工处理和审视。那么,在本文中,会有什么新的观点呢?
大家好,这里是苏南大叔的“程序如此灵动”博客,记录苏南大叔和计算机代码的故事。本文将对train_test_split()
函数的用法,进行简要的分析。本文测试环境:win10
,python@3.12.0
,scikit-learn@1.2.2
。
前文回顾
本文的新观点,是基于下面的两篇文章的。
函数原型
函数定义在<python>\Lib\site-packages\sklearn\model_selection\_split.py
文件中,函数定义如下:
def train_test_split(
*arrays,
test_size=None,
train_size=None,
random_state=None,
shuffle=True,
stratify=None,
):
*arrays
本文主要聚焦于第一个参数*arrays
,
在平常的代码应用中,它相当于X,y
,被切割为X_train,X_test,y_train,y_test
。但是,并不是说,可以被切割的就仅仅是X
和y
。这里是可以无限输入参数的,任何数量的数据,都可以被顺序切割为_train
和_test
两部分。
train_test_split()
其它的参数,请参考:
新的待切割数据
对于鸢尾花数据集里面的数字标签0,1,2
,变成了中文标签山鸢尾,变色鸢尾,维吉尼亚鸢尾
。
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
import numpy as np
y_cn = np.array(["山鸢尾","变色鸢尾","维吉尼亚鸢尾"])[y]
下面切割传统的X
和y
,以及新的y_cn
数据。
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test, y_cn_train, y_cn_test = \
train_test_split(X, y, y_cn, test_size=0.2, random_state=8)
print(X_train[0:3])
print(y_train[0:3])
print(y_cn_train[0:3])
输出:
[[5.1 3.8 1.9 0.4]
[4.8 3. 1.4 0.1]
[4.6 3.6 1. 0.2]]
[0 0 0]
['山鸢尾' '山鸢尾' '山鸢尾']
可能存在的限制
对于被分割的训练集数据X
,y
和y_cn
,要求它们具有相同的数据条数。否则,可能会碰到下面的类似错误提示信息:
ValueError: Found input variables with inconsistent numbers of samples: [50, 150, 150]
其实,仅仅是对X
做了数据截取处理。
# ...
X = X[0:50]
# ...
X_train, X_test, y_train, y_test, y_cn_train, y_cn_test = \
train_test_split(X, y, y_cn, test_size=0.2, random_state=8)
这样的数据输出结果,就是报错。
结语
更多基于python
的经验文章,请点击苏南大叔的系列文章:
如果本文对您有帮助,或者节约了您的时间,欢迎打赏瓶饮料,建立下友谊关系。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。