文本分类
本示例展示了如何使用 Ludwig 构建文本分类器。
这些交互式笔记本遵循本示例的步骤
我们将使用 AG 新闻主题分类数据集,这是一个常用的文本分类基准数据集。该数据集是完整的 AG 新闻数据集的子集,通过选择原始语料库中最大的四个类别构建而成。每个类别包含 30,000 个训练样本和 1,900 个测试样本。训练样本总数为 120,000 个,测试样本总数为 7,600 个。原始分割不包含验证集,因此我们将每个训练类别的前 5% 标记为验证集。
该数据集包含四列
列 | 描述 |
---|---|
class_index | 一个从 1 到 4 的整数,分别代表:“world”、“sports”、“business”、“sci_tech” |
class | 一个字符串,是“world”、“sports”、“business”、“sci_tech”之一 |
title | 新闻文章标题 |
描述 | 新闻文章描述 |
Ludwig 还提供了其他几个可用于文本分类的基准数据集,包括
下载数据集¶
下载数据集并将其写入当前目录下的 agnews.csv
文件。
ludwig datasets download agnews
将 AG 新闻数据集下载到 pandas DataFrame 中。
from ludwig.datasets import agnews
# Loads the dataset as a pandas.DataFrame
train_df, test_df, _ = agnews.load()
该数据集包含上述四列以及一个额外的 split
列,其取值为 0:训练集,1:测试集,2:验证集。
样本 (为节省空间省略了描述文本)
class_index,title,description,split,class
3,Carlyle Looks Toward Commercial Aerospace (Reuters),...,0,business
3,Oil and Economy Cloud Stocks' Outlook (Reuters),...,0,business
3,Iraq Halts Oil Exports from Main Southern Pipeline (Reuters),...,0,business
训练¶
定义 Ludwig 配置¶
Ludwig 配置文件声明了机器学习任务。它告诉 Ludwig 要预测什么,使用哪些列作为输入,并可选择指定模型类型和超参数。
在这里,为简单起见,我们将尝试从 title 预测 class。
使用 config.yaml
文件
input_features:
-
name: title
type: text
encoder:
type: parallel_cnn
output_features:
-
name: class
type: category
trainer:
epochs: 3
使用 Python 字典定义的配置
config = {
"input_features": [
{
"name": "title", # The name of the input column
"type": "text", # Data type of the input column
"encoder": {
"type": "parallel_cnn"
} # The model architecture we should use for encoding this column
}
],
"output_features": [
{
"name": "class",
"type": "category",
}
],
"trainer": {
"epochs": 3, # We'll train for three epochs. Training longer might give
# better performance.
}
}
创建和训练模型¶
ludwig train --dataset agnews.csv -c config.yaml
# Constructs Ludwig model from config dictionary
model = LudwigModel(config, logging_level=logging.INFO)
# Trains the model. This cell might take a few minutes.
train_stats, preprocessed_data, output_directory = model.train(dataset=train_df)
评估¶
生成测试集的预测结果和性能统计数据。
ludwig evaluate \
--model_path results/experiment_run/model \
--dataset agnews.csv \
--split test \
--output_directory test_results
# Generates predictions and performance statistics for the test set.
test_stats, predictions, output_directory = model.evaluate(
test_df,
collect_predictions=True,
collect_overall_stats=True
)
可视化指标¶
可视化混淆矩阵,提供分类器在每个类别上的性能概览。
ludwig visualize \
--visualization confusion_matrix \
--ground_truth_metadata results/experiment_run/model/training_set_metadata.json \
--test_statistics test_results/test_statistics.json \
--output_directory visualizations \
--file_format png
from ludwig.visualize import confusion_matrix
confusion_matrix(
[test_stats],
model.training_set_metadata,
'class',
top_n_classes=[5],
model_names=[''],
normalize=True,
)
混淆矩阵 | 类别熵 |
---|---|
![]() |
![]() |
可视化学习曲线,显示训练期间性能指标随时间的变化。
ludwig visualize \
--visualization learning_curves \
--ground_truth_metadata results/experiment_run/model/training_set_metadata.json \
--training_statistics results/experiment_run/training_statistics.json \
--file_format png \
--output_directory visualizations
# Visualizes learning curves, which show how performance metrics changed over
# time during training.
from ludwig.visualize import learning_curves
learning_curves(train_stats, output_feature_name='class')
损失 | 指标 |
---|---|
![]() |
![]() |
![]() |
![]() |
在新数据上进行预测¶
最后,我们将展示如何为新数据生成预测结果。
以下是一些最近的新闻标题。请随意编辑或添加您自己的字符串到 text_to_predict,看看新训练的模型如何对其进行分类。
使用 text_to_predict.csv
文件
title
Google may spur cloud cybersecurity M&A with $5.4B Mandiant buy
Europe struggles to meet mounting needs of Ukraine's fleeing millions
How the pandemic housing market spurred buyer's remorse across America
ludwig predict \
--model_path results/experiment_run/model \
--dataset text_to_predict.csv \
--output_directory predictions
text_to_predict = pd.DataFrame({
"title": [
"Google may spur cloud cybersecurity M&A with $5.4B Mandiant buy",
"Europe struggles to meet mounting needs of Ukraine's fleeing millions",
"How the pandemic housing market spurred buyer's remorse across America",
]
})
predictions, output_directory = model.predict(text_to_predict)
此命令会将预测结果写入 output_directory
。预测输出以多种格式写入,包括 csv 和 parquet。例如,predictions/predictions.parquet
包含每个样本的预测类别以及每个类别的伪概率。
class_predictions | class_probabilities | class_probability | class_probabilities_<UNK> | class_probabilities_sci_tech | class_probabilities_sports | class_probabilities_world | class_probabilities_business |
---|---|---|---|---|---|---|---|
sci_tech | [1.9864278277825775e-10, ... | 0.954650 | 1.986428e-10 | 0.954650 | 0.000033 | 0.002563 | 0.042754 |
world | [8.458710176739714e-09, ... | 0.995293 | 8.458710e-09 | 0.002305 | 0.000379 | 0.995293 | 0.002022 |
business | [3.710099008458201e-06, ... | 0.490741 | 3.710099e-06 | 0.447916 | 0.000815 | 0.060523 | 0.490741 |