from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import HubertConfig, HubertModel, PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_mhubert_ipa_ctc_ft import MHuBERTIPACTCFTConfig


@dataclass
class MHuBERTIPACTCFTOutput(ModelOutput):
    """Container returned by ``MHuBERTIPACTCFTModel.forward``."""

    # Output of the final linear head; last dimension is ``arch["output_dim"]``.
    logits: Optional[torch.Tensor] = None
    # The backbone's last hidden state (a single tensor, not a per-layer tuple).
    hidden_states: Optional[torch.Tensor] = None


class MHuBERTIPACTCFTModel(PreTrainedModel):
    """HuBERT backbone followed by a projection, an LSTM and a linear head.

    The layer sizes are read from ``config.architecture`` (a mapping) and the
    backbone is built from ``config.backbone_config`` via
    ``HubertConfig.from_dict``.
    """

    config_class = MHuBERTIPACTCFTConfig
    base_model_prefix = "mhubert_ipa_ctc_ft"

    def __init__(self, config):
        """Build the backbone and the projection / LSTM / head stack.

        Args:
            config: Model configuration. ``config.architecture`` must provide
                the keys ``blank_id``, ``input_dim``, ``proj_dim``,
                ``lstm_hidden``, ``lstm_layers``, ``lstm_bidirectional``,
                ``dropout`` and ``output_dim``; ``config.backbone_config`` is
                handed to ``HubertConfig.from_dict``.
        """
        super().__init__(config)
        arch = config.architecture
        backbone_cfg = HubertConfig.from_dict(config.backbone_config)

        # Stored for callers; not used by forward() in this module.
        # Presumably the CTC blank index - confirm against the decoding code.
        self.blank_id = int(arch["blank_id"])

        self.backbone = HubertModel(backbone_cfg)
        self.proj = nn.Linear(arch["input_dim"], arch["proj_dim"])
        self.lstm = nn.LSTM(
            arch["proj_dim"],
            arch["lstm_hidden"],
            num_layers=arch["lstm_layers"],
            bidirectional=arch["lstm_bidirectional"],
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers, so it is
            # disabled for a single-layer LSTM.
            dropout=arch["dropout"] if arch["lstm_layers"] > 1 else 0.0,
        )
        self.drop = nn.Dropout(arch["dropout"])

        # A bidirectional LSTM concatenates both directions on the feature axis.
        out_dim = arch["lstm_hidden"] * (2 if arch["lstm_bidirectional"] else 1)
        self.head = nn.Linear(out_dim, arch["output_dim"])

        self.post_init()

    def forward(self, input_values, attention_mask=None, **kwargs):
        """Run the backbone and the projection / LSTM / head stack.

        Args:
            input_values: Forwarded to the HuBERT backbone as ``input_values``.
            attention_mask: Optional mask forwarded to the backbone only.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                backbone's ``forward``.

        Returns:
            MHuBERTIPACTCFTOutput: ``logits`` from the linear head and
            ``hidden_states`` holding the backbone's last hidden state.
        """
        backbone_out = self.backbone(
            input_values=input_values,
            attention_mask=attention_mask,
            **kwargs,
        )
        # Index instead of using ``.last_hidden_state``: the first element is
        # the last hidden state for both ModelOutput and plain-tuple returns,
        # so a forwarded ``return_dict=False`` no longer raises AttributeError.
        features = backbone_out[0]

        x = self.proj(features)
        # NOTE(review): the LSTM runs over every frame; ``attention_mask`` is
        # not applied here, so padded frames (if any) are processed as well.
        out, _ = self.lstm(x)
        logits = self.head(self.drop(out))
        return MHuBERTIPACTCFTOutput(logits=logits, hidden_states=features)