如何使用 Gradio 快速创建机器学习模型的预测界面?
发布于 作者:苏南大叔 来源:程序如此灵动~

在这篇文章中,苏南大叔将演示如何使用Gradio
为机器学习模型创建一个预测界面。Gradio
是一个Python
库,它允许你快速创建可定制的UI
组件,并与他人分享你的机器学习模型。
苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验总结。测试环境:win10
,python@3.13.2
,sklearn@1.6.1
,gradio@5.17.1
。
前文回顾
本文讲述的重点是给机器学习的模型,做个预测界面。所以,首先,需要有个机器学习的模型,这里引用前一篇文章里面线性回归模型。参考文章:
另外,把训练好的模型,使用joblib
保存起来。使用的时候,再加载就好。如果想用新的数据集训练模型,直接把joblib
生成的模型删除即可。参考文章:
pip install joblib
基础框架代码如下:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import pandas as pd
import joblib
import os
// 总之,这里得到一个可用的模型。
// 如果模型已经存在,我们将加载它而不是重新训练。
model_path = "linear_regression_model.joblib"
if os.path.exists(model_path):
model = joblib.load(model_path)
else:
// 这里是模型训练的逻辑,可更改,具体参考相关文章
ad_data = pd.read_csv("marketing.csv", encoding="gbk")
ad_data.drop("期数", axis=1, inplace=True)
X = ad_data.drop("净利润", axis=1) # 特征值
y = ad_data["净利润"] # 目标值
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
model = LinearRegression()
model.fit(X_train, y_train)
joblib.dump(model, model_path)
//...
创建一个函数,该函数接受输入特征并返回预测的净利润。写成函数的原因是因为要适配Gradio
。
def predict_profit(douyin, kuaishou, baidu):
new_df = pd.DataFrame({"抖音": [douyin], "快手": [kuaishou], "百度": [baidu]})
y_pred = model.predict(new_df)
return y_pred[0].round(2)
安装 gradio
Gradio
的官方网站是:
Gradio
是一个用于快速构建和部署机器学习模型的框架,主要用于以下用途:
- 快速开发界面:允许用户无需代码即可创建交互式AI应用。
- 模型部署:将预训练模型集成到Web界面中,供用户访问和使用。
- 与后端集成:与Flask、FastAPI等框架无缝连接,实现API集成。
- 教育与演示:用于演示AI模型的工作原理,帮助教育者和开发者理解AI。
- 扩展功能:提供工具支持自定义布局、用户交互和数据分析。
通过Gradio
,开发者可以将复杂的AI
模型转化为易于理解的Web
界面,适用于各种场景,包括数据分析、图像识别和自然语言处理等。
使用pip
安装Gradio
和其他所需的库:
pip install gradio
构建 Gradio 界面
使用Gradio
为预测函数创建一个网页界面。该界面将有三个输入字段,用于输入在不同平台上的广告支出,以及一个输出字段,用于显示预测的净利润。
import gradio as gr
demo = gr.Interface(
fn = predict_profit,
inputs = [
gr.Number(label="抖音"),
gr.Number(label="快手"),
gr.Number(label="百度")
],
outputs = gr.Number(label="净利润预测"),
title = "广告支出预测净利润",
description = "输入广告支出,预测净利润",
flagging_mode="never",
)
demo.launch(inbrowser=True)
这个函数其实比较好理解,对标的模型预测函数predict_profit()
。inputs
对应的是这个函数的输入,outputs
对应的预测函数的输出。flagging_mode
这个设置为never
,右侧就会少一个按钮。这个按钮的作用是保存创建新的数据集的。暂留后续讨论。
完整代码
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import joblib
import os
import gradio as gr
model_path = "linear_regression_model.joblib"
if os.path.exists(model_path):
model = joblib.load(model_path)
else:
ad_data = pd.read_csv("marketing.csv", encoding="gbk")
ad_data.drop("期数", axis=1, inplace=True)
X = ad_data.drop("净利润", axis=1) # 特征值
y = ad_data["净利润"] # 目标值
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
# Train the model and save it
model = LinearRegression()
model.fit(X_train, y_train)
joblib.dump(model, model_path)
def predict_profit(douyin, kuaishou, baidu):
new_df = pd.DataFrame({"抖音": [douyin], "快手": [kuaishou], "百度": [baidu]})
y_pred = model.predict(new_df)
return y_pred[0].round(2)
demo = gr.Interface(
fn=predict_profit,
inputs=[
gr.Number(label="抖音"),
gr.Number(label="快手"),
gr.Number(label="百度")
],
outputs=gr.Number(label="净利润预测"),
title="广告支出预测净利润",
description="输入广告支出,预测净利润",
flagging_mode="never",
)
demo.launch(inbrowser=True)
结论
在这篇博客中,苏南大叔演示了如何使用Gradio
为机器学习模型创建一个预测界面。这个界面允许用户输入广告支出,并实时获取预测的净利润。Gradio
使得与他人分享你的机器学习模型和创建交互式演示变得非常容易。
更多苏南大叔的精彩文章,请点击:
https://newsn.net/tag/ai/


