from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import HubertConfig, HubertModel, PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_mhubert_ipa_ctc_ft import MHuBERTIPACTCFTConfig
|
|
|
|
@dataclass
class MHuBERTIPACTCFTOutput(ModelOutput):
    """Output container for ``MHuBERTIPACTCFTModel.forward``.

    Attributes:
        logits: Per-frame scores produced by the linear head on top of the
            LSTM, or ``None`` if not populated.
        hidden_states: The HuBERT backbone's ``last_hidden_state`` as filled
            in by ``MHuBERTIPACTCFTModel.forward``. Note this is a single
            tensor, not the per-layer tuple that other Transformers outputs
            use under this name.
    """

    # Optional[...] because ModelOutput fields default to None until set.
    logits: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
|
|
|
|
class MHuBERTIPACTCFTModel(PreTrainedModel):
    """HuBERT backbone followed by projection -> LSTM -> dropout -> linear head.

    Expected config fields (read in ``__init__``):
        config.backbone_config: dict passed to ``HubertConfig.from_dict``.
        config.architecture: dict with keys ``blank_id``, ``input_dim``,
            ``proj_dim``, ``lstm_hidden``, ``lstm_layers``,
            ``lstm_bidirectional``, ``dropout`` and ``output_dim``.

    ``input_dim`` is presumably the backbone's hidden size, since ``proj``
    consumes ``last_hidden_state`` directly; this is not checked here.
    """

    config_class = MHuBERTIPACTCFTConfig
    base_model_prefix = "mhubert_ipa_ctc_ft"

    def __init__(self, config):
        super().__init__(config)
        arch = config.architecture
        backbone_cfg = HubertConfig.from_dict(config.backbone_config)
        # Stored for callers (e.g. CTC decoding); forward() does not use it.
        self.blank_id = int(arch["blank_id"])
        self.backbone = HubertModel(backbone_cfg)
        self.proj = nn.Linear(arch["input_dim"], arch["proj_dim"])
        self.lstm = nn.LSTM(
            arch["proj_dim"],
            arch["lstm_hidden"],
            num_layers=arch["lstm_layers"],
            bidirectional=arch["lstm_bidirectional"],
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers, so a
            # single-layer LSTM gets 0.0 to avoid PyTorch's warning.
            dropout=arch["dropout"] if arch["lstm_layers"] > 1 else 0.0,
        )
        self.drop = nn.Dropout(arch["dropout"])
        out_dim = arch["lstm_hidden"] * (2 if arch["lstm_bidirectional"] else 1)
        self.head = nn.Linear(out_dim, arch["output_dim"])
        self.post_init()

    def _frame_lengths(self, attention_mask, num_frames):
        """Return per-example valid frame counts at the backbone output rate.

        Uses the backbone's conv-stack length formula (the same helper
        ``HubertForCTC`` uses for CTC input lengths). The result is clamped
        to ``[1, num_frames]`` so that packing never sees a zero or
        overlong length. It is returned as a CPU int64 tensor, which is what
        ``pack_padded_sequence`` requires.
        """
        lengths = self.backbone._get_feat_extract_output_lengths(attention_mask.sum(-1))
        return lengths.to(torch.long).clamp(min=1, max=num_frames).cpu()

    def forward(self, input_values, attention_mask=None, **kwargs):
        """Run the backbone and the LSTM/CTC head.

        Args:
            input_values: Raw audio batch as accepted by ``HubertModel``.
            attention_mask: Optional sample-level mask as accepted by
                ``HubertModel`` (1 = real audio, 0 = padding). When given, it
                is also used so the LSTM skips padded frames.
            **kwargs: Forwarded to the backbone. ``return_dict`` is always
                forced to ``True`` because this method reads
                ``last_hidden_state`` by attribute.

        Returns:
            MHuBERTIPACTCFTOutput with ``logits`` of shape
            (batch, frames, output_dim) and ``hidden_states`` set to the
            backbone's ``last_hidden_state``. Padded frames, when masked,
            come out of the LSTM as zeros and should be ignored by decoders.
        """
        # kwargs is a fresh dict per call, so this does not leak to callers.
        kwargs["return_dict"] = True
        backbone_out = self.backbone(input_values=input_values, attention_mask=attention_mask, **kwargs)
        hidden = backbone_out.last_hidden_state
        x = self.proj(hidden)
        num_frames = x.size(1)

        if attention_mask is None:
            out, _ = self.lstm(x)
        else:
            # Pack so the (possibly bidirectional) LSTM never reads padding:
            # otherwise the reverse direction starts on padded frames and the
            # logits of real frames depend on the batch's padding.
            lengths = self._frame_lengths(attention_mask, num_frames)
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=num_frames
            )

        logits = self.head(self.drop(out))
        return MHuBERTIPACTCFTOutput(logits=logits, hidden_states=hidden)
|
|
|
|