跳到内容

添加损失函数

概括来说,损失函数评估模型预测数据集的效果。损失函数应始终输出一个标量。损失越低意味着拟合越好,因此训练的目标是最小化损失。

Ludwig 的损失函数符合 torch.nn.Module 接口,并在 ludwig/modules/loss_modules.py 中声明。在从头实现新的损失函数之前,请查阅 torch.nn 损失函数 的文档,看看是否存在所需的损失函数。将 PyTorch 损失函数添加到 Ludwig 比从头实现损失函数更简单。

将 PyTorch 损失函数添加到 Ludwig

签名接受模型输出和目标(即 loss(model(input), target))的 PyTorch 损失函数,可以通过在 ludwig/modules/loss_modules.py 中声明一个简单的子类并为一种或多种输出特征类型注册该损失函数来轻松添加到 Ludwig 中。此示例将 MAELoss(平均绝对误差损失)添加到 Ludwig

@register_loss("mean_absolute_error", [NUMBER, TIMESERIES, VECTOR])
class MAELoss(torch.nn.L1Loss, LogitsInputsMixin):
    def __init__(self, **kwargs):
        super().__init__()

@register_loss 装饰器将损失函数注册在名称 mean_absolute_error 下,并指示它支持 NUMBERTIMESERIESVECTOR 输出特征。

从头实现损失函数

实现损失函数

要实现新的损失函数,我们建议首先将其实现为一个以 logits 和 labels 为参数,以及其他配置参数的函数。例如,假设我们实现了来自 "Robust Bi-Tempered Logistic Loss Based on Bregman Divergences" 的 tempered softmax。这个损失函数接受两个常数参数 t1t2,我们希望允许用户在配置中指定它们。

假设我们有以下函数

def tempered_softmax_cross_entropy_loss(
        logits: torch.Tensor,
        labels: torch.Tensor,
        t1: float, t2: float) -> torch.Tensor:
    # Computes the loss, returns the result as a torch.Tensor.

定义并注册模块

接下来,我们将定义一个计算损失函数的模块类,并使用 @register_loss 将其添加到用于 CATEGORY 输出特征的损失注册表中。LogitsInputsMixin 告诉 Ludwig 这个损失函数应该使用输出特征 logits 来调用,logits 是特征解码器在归一化为概率分布之前的输出。

@register_loss("tempered_softmax_cross_entropy", [CATEGORY])
class TemperedSoftmaxCrossEntropy(torch.nn.Module, LogitsInputsMixin):

注意

可以在 logits 之外的其他输出上定义损失函数,但这目前在 Ludwig 中尚未使用。例如,可以在 probabilities 上计算损失,但这通常从 logits 计算更具数值稳定性(而不是通过 softmax 函数反向传播损失)。

构造函数

损失函数的构造函数将接收配置中指定的任何参数作为 kwargs。它必须为所有参数提供合理的默认值。

def __init__(self, t1: float = 1.0, t2: float = 1.0, **kwargs):
    super().__init__()
    self.t1 = t1
    self.t2 = t2

forward

forward 方法负责计算损失函数。在这里,我们将在确保其输入类型正确后调用 tempered_softmax_cross_entropy_loss,并返回按批量平均后的输出。

def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    labels = target.long()
    loss = tempered_softmax_cross_entropy_loss(logits, labels, self.t1, self.t2)
    return torch.mean(loss)

定义损失 Schema 类

为了根据您定义的新损失函数的预期输入和输入类型验证用户输入,我们需要创建一个 Schema 类,该类将自动生成验证所需的 json schema。这个类应该在 ludiwg.schema.features.loss.loss.py 中定义。此示例为上面定义的 MAELoss 类添加一个 Schema 类

@dataclass
class MAELossConfig(BaseLossConfig):

    type: str = schema_utils.StringOptions(
        options=[MEAN_ABSOLUTE_ERROR],
        description="Type of loss.",
    )

    weight: float = schema_utils.NonNegativeFloat(
        default=1.0,
        description="Weight of the loss.",
    )

最后,我们需要在损失类上添加对这个 Schema 类的引用。例如,在上面定义的 MAELoss 类上,我们会添加

    @staticmethod
    def get_schema_cls():
        return MAELossConfig