我们相信:世界是美好的,你是我也是。平行空间的世界里面,不同版本的生活也在继续...

苏南大叔再次重新审视在基于sklearn的机器学习中,最常用的数据集切割函数train_test_split()。同以前的文章一样,本文还是基于对鸢尾花数据集的再一次加工处理和审视。那么,在本文中,会有什么新的观点呢?

苏南大叔:以鸢尾花数据集为例,重新审视数据切割函数train_test_split - sklearn-train_test_split
以鸢尾花数据集为例,重新审视数据切割函数train_test_split(图1-1)

大家好,这里是苏南大叔的“程序如此灵动”博客,记录苏南大叔和计算机代码的故事。本文将对train_test_split()函数的用法,进行简要的分析。本文测试环境:win10python@3.12.0scikit-learn@1.2.2

前文回顾

本文的新观点,是基于下面的两篇文章的。

函数原型

函数定义在<python>\Lib\site-packages\sklearn\model_selection\_split.py文件中,函数定义如下:

def train_test_split(
    *arrays,
    test_size=None,
    train_size=None,
    random_state=None,
    shuffle=True,
    stratify=None,
):

*arrays

本文主要聚焦于第一个参数*arrays
在平常的代码应用中,它相当于X,y,被切割为X_train,X_test,y_train,y_test。但是,并不是说,可以被切割的就仅仅是Xy。这里是可以无限输入参数的,任何数量的数据,都可以被顺序切割为_train_test两部分。

train_test_split()其它的参数,请参考:

新的待切割数据

对于鸢尾花数据集里面的数字标签0,1,2,变成了中文标签山鸢尾,变色鸢尾,维吉尼亚鸢尾

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target

import numpy as np
y_cn = np.array(["山鸢尾","变色鸢尾","维吉尼亚鸢尾"])[y]

下面切割传统的Xy,以及新的y_cn数据。

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test, y_cn_train, y_cn_test = \
       train_test_split(X, y, y_cn, test_size=0.2, random_state=8)

print(X_train[0:3])
print(y_train[0:3])
print(y_cn_train[0:3])

输出:

[[5.1 3.8 1.9 0.4]
 [4.8 3.  1.4 0.1]
 [4.6 3.6 1.  0.2]]
[0 0 0]
['山鸢尾' '山鸢尾' '山鸢尾']

可能存在的限制

对于被分割的训练集数据X,yy_cn,要求它们具有相同的数据条数。否则,可能会碰到下面的类似错误提示信息:

ValueError: Found input variables with inconsistent numbers of samples: [50, 150, 150]

其实,仅仅是对X做了数据截取处理。

# ...
X = X[0:50]
# ...
X_train, X_test, y_train, y_test, y_cn_train, y_cn_test = \
       train_test_split(X, y, y_cn, test_size=0.2, random_state=8)

这样的数据输出结果,就是报错。

结语

更多基于python的经验文章,请点击苏南大叔的系列文章:

如果本文对您有帮助,或者节约了您的时间,欢迎打赏瓶饮料,建立下友谊关系。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。

 【福利】 腾讯云最新爆款活动!1核2G云服务器首年50元!

 【源码】本文代码片段及相关软件,请点此获取更多信息

 【绝密】秘籍文章入口,仅传授于有缘之人   python    sklearn