跳到内容

添加一个指标

指标用于报告模型在训练和评估期间的性能,也可作为 超参数优化 的优化目标。

具体来说,指标是模块,用于计算模型在每个批次输出的函数,并聚合所有批次的结果。一个常见的指标例子是 LossMetric,它计算平均批次损失。指标在 ludwig/modules/metric_modules.py 中定义。Ludwig 的指标旨在与 torchmetrics 保持一致,并符合 torchmetrics.Metric 的接口。

注意

从头开始实现新指标之前,请查看 torchmetrics 文档,看看所需功能是否已存在。Torchmetrics 通常可以非常简单地添加到 Ludwig 中,例如查看 ludwig/modules/metric_modules.py 中的 RMSEMetric

1. 添加一个新的指标类

对于大多数用例,指标应该在批次之间进行平均,为此 Ludwig 提供了一个 MeanMetric 类,它维护其值的运行平均。以下示例将假设需要平均,并继承自 MeanMetric。如果您需要不同的聚合行为,请将 MeanMetric 替换为 LudwigMetric 并根据需要累积指标值。

我们将使用 TokenAccuracyMetric 作为示例,它将序列的每个 token 视为独立的预测,并计算序列上的平均准确率。

首先,在 ludwig/modules/metric_modules.py 中声明新的指标类

class TokenAccuracyMetric(MeanMetric):

2. 实现所需方法

get_current_value

如果使用 MeanMetric,请在 get_current_value 中根据一批特征输出和目标值计算指标的值。

def get_current_value(
        self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Compute metric over a batch of predictions (preds) and truth values (target).
    # Aggregate metric over batch.
    return metric_value

输入

  • preds (torch.Tensor): 一批来自输出特征的输出,根据 get_inputs 的返回值,可以是预测值、概率或 logits。
  • target (torch.Tensor): 与指标的输出特征相对应的数据集列的真实标签批次。

返回值

  • (torch.Tensor): 计算出的指标,在大多数情况下这是一个标量值。

update 和 reset

如果不使用 MeanMetric,则实现 updatereset 方法,而不是 get_current_value

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
    # Compute metric over a batch of predictions (preds) and truth values (target).
    # Accumulate metric values or aggregate statistics.

输入

  • preds (torch.Tensor): 一批来自输出特征的输出,根据 get_inputs 的返回值,可以是预测值、概率或 logits。
  • target (torch.Tensor): 与指标的输出特征相对应的数据集列的真实标签批次。
def reset(self) -> None:
    # Reset accumulated values.

注意

MeanMetric 的 update 方法只是将指标计算委托给 get_current_value

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
    self.avg.update(self.get_current_value(preds, target))

get_objective

get_objective 的返回值告诉 Ludwig 在超参数优化中应该最小化还是最大化此指标。

@classmethod
def get_objective(cls):
    return MAXIMIZE

返回值

  • (str): 如何优化此指标,可以是 MINIMIZE 或 MAXIMIZE 之一。

get_inputs

确定哪个特征输出被传递到此指标的 updateget_current_value 方法中。有效的返回值包括:

  • PREDICTIONS: 输出特征的预测值。
  • PROBABILITIES: 概率向量。
  • LOGITS: 特征解码器最后一层(应用任何 sigmoid 或 softmax 函数之前)的输出向量。
@classmethod
def get_inputs(cls):
    return PREDICTIONS

返回值

  • (str): 此指标值来源的输出,可以是 PREDICTIONS, PROBABILITIES, 或 LOGITS 之一。

3. 将新的指标类添加到注册表

config 中的指标名称与指标类之间的映射通过在指标注册表中注册类来完成。指标注册表在 ludwig/modules/metric_registry.py 中定义。要注册您的类,在其类定义上方的行添加 @register_metric 装饰器,指定指标的名称和支持的输出特征类型列表

@register_metric(TOKEN_ACCURACY, [SEQUENCE, TEXT])
class TokenAccuracyMetric(MeanMetric):