sklearn数据集,月亮数据集如何分类?如何画月亮散点图?
发布于 作者:苏南大叔 来源:程序如此灵动~
目光再次回到sklearn的数据集上,除了常见的鸢尾花数据集,也有一些其它数据集。本文讲述其中的月亮数据集,故名意思,就是对应的数据点可以组成两个弯月亮形状。当然,这个弯月形状,也可以说是半环形状。

苏南大叔的“程序如此灵动”技术博客,记录苏南大叔的代码经验总结。本文测试环境:win10,python@3.11.0,numpy@1.24.2。
加载月亮数据集
函数原型:
xy,label = sklearn.datasets.make_moons(n_samples=100, shuffle=True, noise=None, random_state=None)参数:
n_samples:设置样本数量。noise:设置噪声。噪点越小,半环的环越窄。random_state:设置随机种子参数,相同的参数会随机到相同的值。实际上写啥都可以。
返回值:
xy返回的是坐标值。label是上面数据点的标签,分0和1。
散点图
可以直接画散点图:
import sklearn.datasets
import matplotlib.pyplot as plt
xy, label = sklearn.datasets.make_moons(n_samples=300, noise=0.02, random_state=99)
plt.scatter(xy[:, 0], xy[:, 1], c=label)
plt.show()参考文章:


如果噪点值稍大一些,这个月亮形状半环型就看不出来了。

散点图二
因为make_moons()月亮数据集实际上返回的是两个弯月形状。那么,对应的点本来就是有两类的。在make_moons()的返回Y值中有所体现。所以,根据Y的0/1来进行区分两个月亮。所以,下面这种写法,适合对“区分”有着很强烈需求的情况。
import sklearn.datasets
import matplotlib.pyplot as plt
import numpy as np
xy, label = sklearn.datasets.make_moons(n_samples=300, noise=0.02, random_state=99)
moon1 = np.squeeze(np.argwhere(label == 0)) # 第1组数据索引
moon2 = np.squeeze(np.argwhere(label == 1)) # 第2组数据索引
# print(xy[:10]) # 坐标 [[-0.46569255 0.81721295] [ 0.14305122 -0.26400186]]
# print(label[:10]) # 分类 [0 1 1 0 1 0 1 0 0 1]
# print(np.argwhere(label==0)[:10]) # [[ 0][ 3][ 5][ 7][ 8][12][13][15][18][20]]
# print(np.argwhere(label==1)[:10])
plt.scatter(xy[moon1, 0], xy[moon1, 1], s=50, c="b", marker="+", label="moon1") # s表示线条粗细
plt.scatter(xy[moon2, 0], xy[moon2, 1], s=10, c="r", marker="o", label="moon2")
plt.legend()
plt.show()
这里用到了np.argwhere()获得坐标,还使用了np.squezz()来移除单维度元素。参考文章:
结语
这个月亮数据集,应该没有CSV文件,只是个函数来生成对应规律的数据点,还是比较简单的。更多sklearn的相关文章,请参考苏南大叔的链接: