跳到内容

添加编码器

1. 添加一个新的编码器类

编码器的源代码位于 ludwig/encoders/ 下。编码器根据其输入特征类型分组到不同的模块中。例如,所有新的序列编码器都应该添加到 ludwig/encoders/sequence_encoders.py 中。

注意

一个编码器可能支持多种类型,如果是这样,它应该在其最通用的支持类型所对应的模块中定义。如果一个编码器对于输入类型是通用的,则将其添加到 ludwig/encoders/generic_encoders.py 中。

要创建一个新的编码器:

  1. 定义一个新的编码器类。继承自 ludwig.encoders.base.Encoder 或其子类之一。
  2. 在调用 super().__init__() 后,在 __init__ 方法中创建所有层和状态。
  3. def forward(self, inputs, mask=None): 中实现编码器的前向传播。
  4. 定义 @property input_shape@property output_shape
  5. 定义一个 schema 类。

注意:Encoder 继承自 LudwigModule,它本身就是一个 torch.nn.Module,因此开发 Torch 模块的所有常见注意事项都适用。

所有编码器参数都应作为关键字参数提供给构造函数,并且必须具有默认值。例如,StackedRNN 编码器在其构造函数中接受以下参数列表:

from ludwig.constants import AUDIO, SEQUENCE, TEXT, TIMESERIES
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder

@register_encoder("rnn", [AUDIO, SEQUENCE, TEXT, TIMESERIES])
class StackedRNN(Encoder):
    def __init__(
        self,
        should_embed=True,
        vocab=None,
        representation="dense",
        embedding_size=256,
        embeddings_trainable=True,
        pretrained_embeddings=None,
        embeddings_on_cpu=False,
        num_layers=1,
        max_sequence_length=None,
        state_size=256,
        cell_type="rnn",
        bidirectional=False,
        activation="tanh",
        recurrent_activation="sigmoid",
        unit_forget_bias=True,
        recurrent_initializer="orthogonal",
        dropout=0.0,
        recurrent_dropout=0.0,
        fc_layers=None,
        num_fc_layers=0,
        output_size=256,
        use_bias=True,
        weights_initializer="xavier_uniform",
        bias_initializer="zeros",
        norm=None,
        norm_params=None,
        fc_activation="relu",
        fc_dropout=0,
        reduce_output="last",
        **kwargs,
    ):
    super().__init__()
    # Initialize any modules, layers, or variable state

2. 实现 forwardinput_shapeoutput_shape

实际的激活计算发生在编码器的 forward 方法内部。所有编码器都应该具有以下签名:

    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None):
        # perform forward pass
        # ...
        # output_tensor = result of forward pass
        return {"encoder_output": output_tensor}

输入

  • inputs (torch.Tensor):输入张量。
  • mask (torch.Tensor, 默认值: None):二进制张量,指示输入中哪些值应该被掩盖。注意:mask 不是必需的,并且在大多数编码器类型中未实现。

返回

  • (dict):一个字典,包含键 encoder_output,其值是编码器输出张量。{"encoder_output": output_tensor}

input_shapeoutput_shape 属性必须返回编码器预期输入和输出的完整指定形状,不包含批次维度。

    @property
    def input_shape(self) -> torch.Size:
        return torch.Size([self.max_sequence_length])

    @property
    def output_shape(self) -> torch.Size:
        return self.recurrent_stack.output_shape

3. 将新的编码器类添加到编码器注册表中

模型定义中的编码器名称与编码器类之间的映射是通过在编码器注册表中注册类来完成的。编码器注册表定义在 ludwig/encoders/registry.py 中。要注册您的类,在其类定义上方添加 @register_encoder 装饰器,指定编码器的名称和支持的输入特征类型列表。

@register_encoder("rnn", [AUDIO, SEQUENCE, TEXT, TIMESERIES])
class StackedRNN(Encoder):

4. 定义一个 schema 类

为了确保您自定义的编码器功能按预期进行用户配置验证,我们需要为新定义的编码器定义一个 schema 类。为此,我们在一个类定义上使用 marshmallow_dataclass 装饰器,该类定义包含您自定义编码器的所有输入作为属性。对于每个属性,我们使用 ludwig.schema.utils 目录中的实用函数来验证该输入。最后,我们需要在自定义编码器类上添加对此 schema 类的引用。例如:

from marshmallow_dataclass import dataclass

from ludwig.constants import SEQUENCE, TEXT
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.utils import register_encoder_config
import ludwig.schema.utils as schema_utils

@register_encoder_config("stacked_rnn", [SEQUENCE, TEXT])
@dataclass
class StackedRNNConfig(BaseEncoderConfig):
        type: str = schema_utils.StringOptions(options=["stacked_rnn"], default="stacked_rnn")
        should_embed: bool = schema_utils.Boolean(default=True, description="")
        vocab: list = schema_utils.List(list_type=str, default=None, description="")
        representation: str = schema_utils.StringOptions(options=["sparse", "dense"], default="dense", description="")
        embedding_size: int = schema_utils.Integer(default=256, description="")
        embeddings_trainable: bool = schema_utils.Boolean(default=True, description="")
        pretrained_embeddings: str = schema_utils.String(default=None, description="")
        embeddings_on_cpu: bool = schema_utils.Boolean(default=False, description="")
        num_layers: int = schema_utils.Integer(default=1, description="")
        max_sequence_length: int = schema_utils.Integer(default=None, description="")
        state_size: int = schema_utils.Integer(default=256, description="")
        cell_type: str = schema_utils.StringOptions(
            options=["rnn", "lstm", "lstm_block", "ln", "lstm_cudnn", "gru", "gru_block", "gru_cudnn"], 
            default="rnn", description=""
        )
        bidirectional: bool = schema_utils.Boolean(default=False, description="")
        activation: str = schema_utils.ActivationOptions(default="tanh", description="")
        recurrent_activation: str = schema_utils.activations(default="sigmoid", description="")
        unit_forget_bias: bool = schema_utils.Boolean(default=True, description="")
        recurrent_initializer: str = schema_utils.InitializerOptions(default="orthogonal", description="")
        dropout: float = schema_utils.FloatRange(default=0.0, min=0, max=1, description="")
        recurrent_dropout: float = schema_utils.FloatRange(default=0.0, min=0, max=1, description="")
        fc_layers: list = schema_utils.DictList(default=None, description="")
        num_fc_layers: int = schema_utils.NonNegativeInteger(default=0, description="")
        output_size: int = schema_utils.Integer(default=256, description="")
        use_bias: bool = schema_utils.Boolean(default=True, description="")
        weights_initialize: str = schema_utils.InitializerOptions(default="xavier_uniform", description="")
        bias_initializer: str = schema_utils.InitializerOptions(default="zeros", description="")
        norm: str = schema_utils.StringOptions(options=["batch", "layer"], default=None, description="")
        norm_params: dict = schema_utils.Dict(default=None, description="")
        fc_activation: str = schema_utils.ActivationOptions(default="relu", description="")
        fc_dropout: float = schema_utils.FloatRange(default=0.0, min=0, max=1, description="")
        reduce_output: str = schema_utils.ReductionOptions(default="last", description="")

最后,您应该在自定义编码器上添加对此 schema 类的引用

    @staticmethod
    def get_schema_cls():
        return StackedRNNConfig