添加解码器
1. 添加新的解码器类¶
解码器的源代码位于 ludwig/decoders/
目录下。解码器根据其输出特征类型被分组到不同的模块中。例如,所有新的序列解码器都应该添加到 ludwig/decoders/sequence_decoders.py
文件中。
注意
一个解码器可能支持多种输出类型,如果是这样,它应该在其支持的最通用类型对应的模块中定义。如果一个解码器对输出类型是通用的,则将其添加到 ludwig/decoders/generic_decoders.py
文件中。
创建新解码器
- 定义一个新的解码器类。继承自
ludwig.decoders.base.Decoder
或其子类之一。 - 在调用
super().__init__()
后,在__init__
方法中创建所有层和状态。 - 在
def forward(self, combiner_outputs, **kwargs):
方法中实现解码器的前向传播。 - 定义一个 schema 类。
注意:Decoder
继承自 LudwigModule
,后者本身是一个 torch.nn.Module,因此开发 Torch 模块的所有常见问题都适用。
所有解码器参数都应作为关键字参数提供给构造函数,并且必须具有默认值。例如,SequenceGeneratorDecoder
解码器在其构造函数中接受以下参数列表:
from ludwig.constants import SEQUENCE, TEXT
from ludwig.decoders.base import Decoder
from ludwig.decoders.registry import register_decoder
@register_decoder("generator", [SEQUENCE, TEXT])
class SequenceGeneratorDecoder(Decoder):
def __init__(
self,
vocab_size: int,
max_sequence_length: int,
cell_type: str = "gru",
input_size: int = 256,
reduce_input: str = "sum",
num_layers: int = 1,
**kwargs,
):
super().__init__()
# Initialize any modules, layers, or variable state
2. 实现 forward
方法¶
激活值的实际计算发生在解码器的 forward
方法内部。所有解码器都应该具有以下签名:
def forward(self, combiner_outputs, **kwargs):
# perform forward pass
# combiner_hidden_output = combiner_outputs[HIDDEN]
# ...
# logits = result of decoder forward pass
return {LOGITS: logits}
输入
- combiner_outputs (Dict[str, torch.Tensor]):输入张量,它是组合器的输出,或组合器与任何依赖的输出解码器的激活值的组合。组合器输出的字典包含形状为
b x h
的张量(其中b
是批量大小,h
是嵌入大小),或形状为b x s x h
的嵌入序列(其中s
是序列长度)。
返回值
- (Dict[str, torch.Tensor]):解码器输出张量的字典。典型的解码器将返回键为
LOGITS
、PREDICTION
或两者的值(在ludwig.constants
中定义)。
3. 将新的解码器类添加到相应的解码器注册表¶
模型定义中的解码器名称与解码器类之间的映射是通过在解码器注册表中注册类来完成的。解码器注册表定义在 ludwig/decoders/registry.py
文件中。要注册您的类,请在其类定义上方的行添加 @register_decoder
装饰器,指定解码器的名称和支持的输出特征类型列表。
@register_decoder("generator", [SEQUENCE, TEXT])
class SequenceGeneratorDecoder(Decoder):
4. 定义 schema 类¶
为了确保您的自定义解码器的用户配置验证按预期工作,我们需要定义一个 schema 类来配合新定义的解码器。为此,我们使用 marshmallow_dataclass
装饰器修饰一个类定义,该类包含自定义解码器的所有输入作为属性。对于每个属性,我们使用 ludwig.schema.utils
目录中的工具函数来验证输入。最后,我们需要在自定义解码器类上引用这个 schema 类。例如:
from marshmallow_dataclass import dataclass
from ludwig.constants import SEQUENCE, TEXT
from ludwig.schema.decoders.base import BaseDecoderConfig
from ludwig.schema.decoders.utils import register_decoder_config
import ludwig.schema.utils as schema_utils
@register_decoder_config("generator", [SEQUENCE, TEXT])
@dataclass
class SequenceGeneratorDecoderConfig(BaseDecoderConfig):
type: str = schema_utils.StringOptions(options=["generator"], default="generator")
vocab_size: int = schema_utils.Integer(default=None, description="")
max_sequence_length: int = schema_utils.Integer(default=None, description="")
cell_type: str = schema_utils.String(default="gru", description="")
input_size: int = schema_utils.Integer(default=256, description="")
reduce_input: str = schema_utils.ReductionOptions(default="sum")
num_layers: int = schema_utils.Integer(default=1, description="")
最后,您应该在自定义解码器上添加对 schema 类的引用。
@staticmethod
def get_schema_cls():
return SequenceGeneratorDecoderConfig