多模态分类
本示例展示了如何使用 Ludwig 构建一个多模态分类器。
如果您想在 Colab 中交互式运行此示例,请打开以下任一笔记本并尝试:
注意:您需要您的 Kaggle API 令牌
我们将使用由 David Martín Gutiérrez 最初上传到 Kaggle 的 twitter human-bots 数据集。该数据集包含 37438 行,每一行对应一个 Twitter 用户账户。每行包含通过 Twitter API 收集的 20 个特征列。这些特征包含多种数据模态,包括账户描述和个人资料图片。
目标列 account_type 有两个唯一值:bot 或 human。其中 25013 个用户账户被标注为人类账户,其余 12425 个是机器人账户。
该数据集包含 20 列,但我们只使用这 16 列(15 个输入 + 1 个目标)
列 | 类型 | 描述 |
---|---|---|
default_profile | 二元 | 账户是否有默认资料 |
default_profile_image | 二元 | 账户是否有默认资料图片 |
描述 | 文本 | 用户账户描述 |
favorites_count | 数值 | 点赞推文总数 |
followers_count | 数值 | 粉丝总数 |
friends_count | 数值 | 朋友总数 |
geo_enabled | 二元 | 账户是否启用了地理位置 |
lang | 类别 | 账户语言 |
location | 类别 | 账户位置 |
profile_background_image_path | 图像 | 资料背景图片路径 |
profile_image_path | 图像 | 资料图片路径 |
statuses_count | 数值 | 推文总数 |
verified | 二元 | 账户是否已验证 |
average_tweets_per_day | 数值 | 平均每日推文数 |
account_age_days | 数值 | 账户年龄(天) |
account_type | 二元 | "human" 或 "bot",如果账户是机器人则为 true |
Kaggle API 令牌 (kaggle.json)¶
要使用 Kaggle CLI 下载数据集,您需要一个 Kaggle API 令牌。
如果您已经有一个令牌,它应该安装在 ~/.kaggle/kaggle.json
。在 shell 中运行此命令,并复制输出
cat ~/.kaggle/kaggle.json
如果您没有 kaggle.json
文件
- 登录 Kaggle。如果您还没有账户,请创建一个。
- 前往“账户”,然后点击“创建新的 API 令牌”按钮。这将开始下载。
- 按照 Kaggle 的说明,将您的
kaggle.json
从下载位置复制到您主目录中名为.kaggle
的目录。 - 如果您想在示例 Colab 笔记本中运行此示例,请打开 kaggle.json 并将其内容复制到剪贴板。kaggle.json 文件应类似于
{"username":"your_user_name","key":"_______________________________"}
下载数据集¶
下载数据集并在当前目录中创建 twitter_human_bots_dataset.csv
文件。
# Downloads the dataset to the current working directory
kaggle datasets download danieltreiman/twitter-human-bots-dataset
# Unzips the downloaded dataset, creates twitter_human_bots_dataset.csv
unzip -q -o twitter-human-bots-dataset.zip
训练¶
定义 ludwig 配置¶
Ludwig 配置声明了机器学习任务:要使用哪些列,它们的数据类型,以及要预测哪些列。
注意
只有 20 张独特的背景图片,因此我们将 profile_background_image_path
声明为类别类型而不是图像类型。图像编码器需要大量独特的图像才能表现良好,并且在如此小的样本量下会很快过拟合。
使用 config.yaml
input_features:
- name: default_profile
type: binary
- name: default_profile_image
type: binary
- name: description
type: text
- name: favourites_count
type: number
- name: followers_count
type: number
- name: friends_count
type: number
- name: geo_enabled
type: binary
- name: lang
type: category
- name: location
type: category
- name: profile_background_image_path
type: category
- name: profile_image_path
type: image
- name: statuses_count
type: number
- name: verified
type: binary
- name: average_tweets_per_day
type: number
- name: account_age_days
type: number
output_features:
- name: account_type
type: binary
使用 Python 字典中定义的配置
config = {
"input_features": [
{
"name": "default_profile",
"type": "binary",
},
{
"name": "default_profile_image",
"type": "binary",
},
{
"name": "description",
"type": "text",
},
{
"name": "favourites_count",
"type": "number",
},
{
"name": "followers_count",
"type": "number",
},
{
"name": "friends_count",
"type": "number",
},
{
"name": "geo_enabled",
"type": "binary",
},
{
"name": "lang",
"type": "category",
},
{
"name": "location",
"type": "category",
},
{
"name": "profile_background_image_path",
"type": "category",
},
{
"name": "profile_image_path",
"type": "image",
},
{
"name": "statuses_count",
"type": "number",
},
{
"name": "verified",
"type": "binary",
},
{
"name": "average_tweets_per_day",
"type": "number",
},
{
"name": "account_age_days",
"type": "number",
},
],
"output_features": [
{
"name": "account_type",
"type": "binary",
}
]
}
创建并训练模型¶
ludwig train --dataset twitter_human_bots_dataset.csv -c config.yaml
import pandas as pd
# Reads the dataset from CSV file.
dataset_df = pd.read_csv("twitter_human_bots_dataset.csv")
# Constructs Ludwig model from config dictionary
model = LudwigModel(config, logging_level=logging.INFO)
# Trains the model. This cell might take a few minutes.
train_stats, preprocessed_data, output_directory = model.train(dataset=dataset_df)
评估¶
为测试集生成预测和性能统计信息。
ludwig evaluate \
--model_path results/experiment_run/model \
--dataset twitter_human_bots_dataset.csv \
--split test \
--output_directory test_results
# Generates predictions and performance statistics for the test set.
test_stats, predictions, output_directory = model.evaluate(
dataset_df[dataset_df.split == 1],
collect_predictions=True,
collect_overall_stats=True
)
可视化指标¶
可视化混淆矩阵,它提供了分类器在每个类别上的性能概览。
ludwig visualize \
--visualization confusion_matrix \
--ground_truth_metadata results/experiment_run/model/training_set_metadata.json \
--test_statistics test_results/test_statistics.json \
--output_directory visualizations \
--file_format png
from ludwig.visualize import confusion_matrix
confusion_matrix(
[test_stats],
model.training_set_metadata,
'account_type',
top_n_classes=[2],
model_names=[''],
normalize=True,
)
混淆矩阵 | 类别熵 |
---|---|
![]() |
![]() |
可视化学习曲线,展示了训练期间性能指标如何随时间变化。
ludwig visualize \
--visualization learning_curves \
--ground_truth_metadata results/experiment_run/model/training_set_metadata.json \
--training_statistics results/experiment_run/training_statistics.json \
--file_format png \
--output_directory visualizations
# Visualizes learning curves, which show how performance metrics changed over time during training.
from ludwig.visualize import learning_curves
learning_curves(train_stats, output_feature_name='account_type')
损失 | 指标 |
---|---|
![]() |
![]() |