chua committed on
Commit
7e53c9a
·
0 Parent(s):

Upload GPA v1.5 model package

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message.role == "system" %}
3
+ <|system|>
4
+ {{ message.content }}
5
+ {% elif message.role == "user" %}
6
+ <|user|>
7
+ {{ message.content }}
8
+ {% elif message.role == "assistant" %}
9
+ <|assistant|>
10
+ {{ message.content }}
11
+ {% endif %}
12
+ {% endfor %}
13
+ {% if add_generation_prompt %}
14
+ <|assistant|>
15
+ {% endif %}
config.json ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "adapter_type": "mlp",
3
+ "architectures": [
4
+ "ArkasrForConditionalGeneration"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "audio_token_id": 151663,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_arkasr.ArkasrConfig",
10
+ "AutoModelForCausalLM": "modeling_arkasr.ArkasrForConditionalGeneration"
11
+ },
12
+ "dtype": "float32",
13
+ "eos_token_id": 151665,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 896,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 4864,
18
+ "layer_types": [
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention"
43
+ ],
44
+ "max_position_embeddings": 32768,
45
+ "max_whisper_length": 1500,
46
+ "max_window_layers": 24,
47
+ "merge_factor": 4,
48
+ "mlp_adapter_act": "gelu",
49
+ "model_type": "arkasr",
50
+ "num_attention_heads": 14,
51
+ "num_hidden_layers": 24,
52
+ "num_key_value_heads": 2,
53
+ "pad_token_id": 151643,
54
+ "rms_norm_eps": 1e-06,
55
+ "rope_scaling": null,
56
+ "rope_theta": 1000000.0,
57
+ "sliding_window": null,
58
+ "spec_aug": false,
59
+ "tie_word_embeddings": true,
60
+ "transformers_version": "4.57.3",
61
+ "use_cache": false,
62
+ "use_mrope": false,
63
+ "use_rope": true,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 163958,
66
+ "whisper_config": {
67
+ "activation_dropout": 0.0,
68
+ "activation_function": "gelu",
69
+ "apply_spec_augment": false,
70
+ "architectures": [
71
+ "WhisperForConditionalGeneration"
72
+ ],
73
+ "attention_dropout": 0.0,
74
+ "begin_suppress_tokens": [
75
+ 220,
76
+ 50257
77
+ ],
78
+ "bos_token_id": 50257,
79
+ "classifier_proj_size": 256,
80
+ "d_model": 1280,
81
+ "decoder_attention_heads": 20,
82
+ "decoder_ffn_dim": 5120,
83
+ "decoder_layerdrop": 0.0,
84
+ "decoder_layers": 32,
85
+ "decoder_start_token_id": 50258,
86
+ "dropout": 0.0,
87
+ "dtype": "bfloat16",
88
+ "encoder_attention_heads": 20,
89
+ "encoder_ffn_dim": 5120,
90
+ "encoder_layerdrop": 0.0,
91
+ "encoder_layers": 32,
92
+ "eos_token_id": 50257,
93
+ "init_std": 0.02,
94
+ "mask_feature_length": 10,
95
+ "mask_feature_min_masks": 0,
96
+ "mask_feature_prob": 0.0,
97
+ "mask_time_length": 10,
98
+ "mask_time_min_masks": 2,
99
+ "mask_time_prob": 0.05,
100
+ "max_length": 448,
101
+ "max_source_positions": 1500,
102
+ "max_target_positions": 448,
103
+ "median_filter_width": 7,
104
+ "model_type": "whisper",
105
+ "num_hidden_layers": 32,
106
+ "num_mel_bins": 128,
107
+ "scale_embedding": false,
108
+ "use_cache": true,
109
+ "use_weighted_layer_sum": false,
110
+ "vocab_size": 51866
111
+ }
112
+ }
configuration_arkasr.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Union
2
+ from transformers import Qwen2Config, WhisperConfig
3
+
4
+
5
class ArkasrConfig(Qwen2Config):
    """Configuration for ArkASR: Qwen2 language-model settings plus a nested Whisper config.

    Args:
        whisper_config: Settings for the audio encoder. A ``dict`` is expanded into
            ``WhisperConfig(**whisper_config)``; a ``WhisperConfig`` is stored as is;
            anything else (including ``None``) yields a default ``WhisperConfig()``.
        adapter_type: Name of the audio adapter variant (stored verbatim).
        merge_factor: Number of consecutive encoder frames merged into one adapter
            input; coerced with ``int``.
        spec_aug: Stored as ``bool``.
        use_rope: Stored as ``bool``.
        max_whisper_length: Stored as ``int``.
        mlp_adapter_act: Name of the adapter activation (stored verbatim).
        **kwargs: Everything else is forwarded untouched to ``Qwen2Config.__init__``.
    """

    model_type = "arkasr"
    is_composition = True

    def __init__(
        self,
        whisper_config: Optional[Union[Dict[str, Any], WhisperConfig]] = None,
        adapter_type: str = "mlp",
        merge_factor: int = 4,
        spec_aug: bool = False,
        use_rope: bool = True,
        max_whisper_length: int = 1500,
        mlp_adapter_act: str = "gelu",
        **kwargs,
    ):
        # The language-model half is configured first; it consumes every keyword
        # that is not one of the ArkASR-specific arguments above.
        super().__init__(**kwargs)

        # Normalise the audio sub-config to a WhisperConfig instance.
        if isinstance(whisper_config, WhisperConfig):
            audio_cfg = whisper_config
        elif isinstance(whisper_config, dict):
            audio_cfg = WhisperConfig(**whisper_config)
        else:
            audio_cfg = WhisperConfig()
        self.whisper_config = audio_cfg

        # ArkASR-specific options, coerced to their declared types.
        self.adapter_type = adapter_type
        self.merge_factor = int(merge_factor)
        self.spec_aug = bool(spec_aug)
        self.use_rope = bool(use_rope)
        self.max_whisper_length = int(max_whisper_length)
        self.mlp_adapter_act = mlp_adapter_act

    def to_dict(self):
        """Serialize to a plain dict, with ``whisper_config`` rendered as a nested dict."""
        serialized = super().to_dict()
        serialized["whisper_config"] = self.whisper_config.to_dict()
        return serialized
45
+
46
+
47
+ __all__ = ["ArkasrConfig"]
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 151665,
4
+ "pad_token_id": 151665,
5
+ "transformers_version": "4.57.3"
6
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a7d69140f7b75d815e2136a83fc307dc353f6d5ecd2134a1187d6948ae6aa37
3
+ size 2305207856
modeling_arkasr.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, List, Tuple, Union, Dict
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+ from transformers import Qwen2ForCausalLM
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+
10
+ from .configuration_arkasr import ArkasrConfig
11
+ from .modeling_audio import WhisperSpecialEncoder
12
+
13
+
14
class AudioMLPAdapter(nn.Module):
    """Audio front end: Whisper-style encoder, LayerNorm, frame merging, then a 2-layer MLP.

    The encoder's own final ``layer_norm`` is replaced by ``nn.Identity`` and a
    separate ``nn.LayerNorm`` owned by this module is applied instead.
    """

    def __init__(self, config: ArkasrConfig):
        super().__init__()
        audio_cfg = config.whisper_config
        self.merge_factor = int(config.merge_factor)

        # Audio encoder with its built-in final LayerNorm switched off.
        self.whisper = WhisperSpecialEncoder(audio_cfg, use_rope=getattr(config, "use_rope", False))
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(audio_cfg.hidden_size)

        # Unknown activation names fall back to GELU.
        act_name = getattr(config, "mlp_adapter_act", "gelu")
        if act_name == "relu":
            activation = nn.ReLU()
        elif act_name == "selu":
            activation = nn.SELU()
        else:
            activation = nn.GELU()

        llm_dim = config.hidden_size
        merged_dim = audio_cfg.hidden_size * self.merge_factor

        self.adapting = nn.Sequential(
            nn.Linear(merged_dim, llm_dim * 2),
            activation,
            nn.Linear(llm_dim * 2, llm_dim),
        )

    def forward(self, audios: Tensor) -> Tensor:
        """Encode audio and project it into the language model's hidden size.

        Args:
            audios: Batched encoder input; its first dimension is the batch size and
                the remaining layout is whatever ``WhisperSpecialEncoder`` expects.

        Returns:
            Tensor of shape (batch, merged_frames, config.hidden_size), where
            ``merge_factor`` consecutive encoder frames are concatenated per step.
        """
        batch = audios.size(0)
        frames = self.layer_norm(self.whisper(audios)[0])  # (B, T, D)

        group = self.merge_factor
        n_frames = frames.size(1)
        if n_frames % group != 0:
            keep = (n_frames // group) * group
            if keep <= 0:
                # Fewer frames than one group: grow to exactly one group.
                keep = group
            if n_frames < keep:
                filler = frames.new_zeros((batch, keep - n_frames, frames.size(-1)))
                frames = torch.cat([frames, filler], dim=1)
            else:
                # Drop the trailing frames that do not fill a whole group.
                frames = frames[:, :keep, :]

        merged = frames.reshape(batch, -1, frames.size(-1) * group)
        return self.adapting(merged)  # (B, T/k, hidden)
73
+
74
+
75
class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM whose audio placeholder tokens are overwritten with audio features.

    Positions in ``input_ids`` equal to ``config.audio_token_id`` have their token
    embeddings replaced by the output of ``AudioMLPAdapter`` before the decoder runs.
    """

    config_class = ArkasrConfig
    _no_split_modules = ["WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        """Build the language model and the audio adapter.

        Raises:
            ValueError: if ``config`` has no ``audio_token_id``.
        """
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)

        self.audio_token_id = getattr(config, "audio_token_id", None)
        if self.audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return how many positions the KV cache already holds; 0 when that cannot be determined.

        Supports cache objects exposing ``get_seq_length()`` as well as the legacy
        tuple-of-tuples layout. This is deliberately best effort: any failure to
        read the length is treated as an empty cache.
        """
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            try:
                return int(past_key_values.get_seq_length())
            except Exception:
                return 0
        try:
            return int(past_key_values[0][0].shape[-2])
        except Exception:
            return 0

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,  # (B, S)
        inputs_embeds: torch.FloatTensor,  # (B, S, H)
        audios: Tensor,  # (B, ...)
    ) -> torch.FloatTensor:
        """Encode audio once for the rows that need it, then scatter features row by row.

        Rows of ``input_ids`` that contain no audio token are left untouched and their
        audio is not encoded. For each remaining row the encoder output (length Sa) is
        aligned to that row's number of audio tokens n_i: truncated when Sa > n_i and
        zero-padded when Sa < n_i, instead of raising an error.

        Note:
            ``inputs_embeds`` is written IN PLACE and the same tensor is returned.
            Callers that must not have their tensor modified should pass a copy.

        Args:
            input_ids: Token ids, shape (B, S).
            inputs_embeds: Token embeddings, shape (B, S, H).
            audios: Batched audio inputs aligned with the rows of ``input_ids``.

        Returns:
            ``inputs_embeds`` with audio-token positions overwritten.
        """
        hidden = inputs_embeds.size(-1)

        # Locate the rows that contain at least one audio placeholder.
        mask = input_ids == self.audio_token_id  # (B, S)
        per_counts = mask.sum(dim=1)  # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)

        if need_idx.numel() == 0:
            return inputs_embeds

        # Single batched encoder pass over only the rows that need injection.
        audios_sub = audios.index_select(0, need_idx)
        feats_sub = self.audio_encoder(audios_sub)  # (K, Sa, H)
        feats_sub = feats_sub.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        enc_len = feats_sub.size(1)

        # Per-row scatter keeps each row's features confined to that row.
        for k, row in enumerate(need_idx.tolist()):
            n_slots = int(per_counts[row].item())  # > 0 by construction of need_idx
            feat = feats_sub[k]  # (Sa, H)

            if enc_len < n_slots:
                pad = feat.new_zeros((n_slots - enc_len, hidden))
                feat = torch.cat([feat, pad], dim=0)
            elif enc_len > n_slots:
                feat = feat[:n_slots]

            positions = mask[row].nonzero(as_tuple=False).squeeze(1)  # (n_slots,)
            inputs_embeds[row, positions, :] = feat

        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder, injecting audio features on the first (cache-empty) step.

        Args:
            input_ids: Token ids; required when ``inputs_embeds`` is not given and
                required for audio injection (it locates the audio tokens).
            audios: Batched audio inputs; ignored when ``input_ids`` is ``None`` or
                when the KV cache already holds tokens.
            inputs_embeds: Optional precomputed embeddings. The caller's tensor is
                never modified; a copy is made before audio features are written.
            logits_to_keep: Positive int keeps only the last N positions for the LM
                head; a tensor is used as an index along the sequence dimension;
                0 keeps every position.
            **kwargs: Forwarded to ``self.loss_function`` when ``labels`` is given.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is provided.
        """
        caller_owns_embeds = inputs_embeds is not None
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (past_len == 0) to avoid re-encoding during generation.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            if caller_owns_embeds:
                # The scatter below writes in place. Copy so the caller's tensor is not
                # silently altered (and so a leaf tensor requiring grad does not make
                # the in-place write fail).
                inputs_embeds = inputs_embeds.clone()
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = outputs[0]

        # Restrict logits computation when possible to avoid redundant lm_head work.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states

        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once the cache holds tokens, only the newest token id is fed. ``audios`` is
        passed through on every step; ``forward`` only encodes it while the cache is
        empty. ``inputs_embeds`` replaces ``input_ids`` on the first step only.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # Injection happens only when past_len == 0 in forward,
            # so later generation steps do not re-encode.
            "audios": kwargs.get("audios", None),
        }

        # "First step" is decided by cache length rather than `past_key_values is None`:
        # an empty cache object is not None but still means nothing has been decoded,
        # and testing for None would make this branch unreachable in that case.
        if inputs_embeds is not None and past_len == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]

        return model_inputs
264
+
265
+
266
+ __all__ = ["ArkasrForConditionalGeneration", "AudioMLPAdapter"]
modeling_audio.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from torch.nn.functional import scaled_dot_product_attention
6
+ from transformers import WhisperConfig
7
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
8
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+ # ==========================================
14
+ # 1. Core Rotary Embedding components
15
+ # ==========================================
16
+
17
class RotaryEmbedding(nn.Module):
    """Computes rotary position embedding (RoPE) cos/sin tables on request.

    Args:
        dim: Number of channels that are rotated; the table has ``dim / 2`` frequencies.
        rope_ratio: Multiplier applied to the RoPE base.
    """

    def __init__(self, dim, rope_ratio=1):
        super().__init__()
        self.dim = dim
        self.rope_ratio = rope_ratio

    @torch.no_grad()
    def get_emb(self, seq_len: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
        """Build the RoPE table for ``seq_len`` positions.

        Returns:
            Tensor of shape [seq_len, dim/2, 2]; ``[..., 0]`` holds cosines and
            ``[..., 1]`` holds sines. The table is computed in float32 and cast to
            ``dtype`` only when ``dtype`` is float16 or bfloat16.
        """
        scaled_base = base * self.rope_ratio

        # One inverse frequency per rotated channel pair.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        inv_freq = 1.0 / (scaled_base ** exponents)

        # angle[p, i] = position p * inverse frequency i  ->  [seq_len, dim/2]
        positions = torch.arange(seq_len, device=device, dtype=torch.float)
        angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)

        table = torch.stack((angles.cos(), angles.sin()), dim=-1)

        half_precision = (torch.float16, torch.bfloat16)
        return table.to(dtype) if dtype in half_precision else table
41
+
42
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Rotate the leading channels of ``x`` by the angles stored in ``rope_cache``.

    Args:
        x: [batch, num_heads, seq_len, head_dim].
        rope_cache: [1, seq_len, rot_dim/2, 2] with cosines at ``[..., 0]`` and sines
            at ``[..., 1]``.

    Returns:
        Tensor shaped like ``x``. The first ``rot_dim`` channels are treated as
        adjacent (real, imaginary) pairs and rotated; the remaining channels are
        copied through unchanged.
    """
    batch, heads, seq, _ = x.shape
    rot_dim = rope_cache.shape[-2] * 2

    rotating = x[..., :rot_dim]
    passthrough = x[..., rot_dim:]

    # View the rotating channels as pairs: [b, nh, sq, rot_dim/2, 2]
    pairs = rotating.reshape(batch, heads, seq, rot_dim // 2, 2)
    real = pairs[..., 0]
    imag = pairs[..., 1]

    # Insert a head axis so the table broadcasts: [1, 1, sq, rot_dim/2]
    cos = rope_cache[..., 0].unsqueeze(1)
    sin = rope_cache[..., 1].unsqueeze(1)

    # Complex multiplication (a+bi)(c+di) = (ac-bd) + (ad+bc)i
    new_real = real * cos - imag * sin
    new_imag = imag * cos + real * sin

    # Interleave the pairs back into a flat rot_dim axis.
    rotated = torch.stack([new_real, new_imag], dim=-1).flatten(3)
    return torch.cat([rotated, passthrough], dim=-1)
71
+
72
+ # ==========================================
73
+ # 2. RoPE attention built on SDPA
74
+ # ==========================================
75
+
76
class WhisperRoPESdpaAttention(nn.Module):
    """Self-attention with optional rotary embeddings, computed by PyTorch SDPA.

    Uses Whisper's projection layout (the key projection has no bias).
    ``layer_head_mask`` and ``output_attentions`` are accepted for interface
    compatibility but are not used; attention weights are never returned.
    """

    def __init__(self, config: WhisperConfig, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.config = config
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        # Standard Whisper projection layers.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.is_causal = False

    def _split_heads(self, projected: torch.Tensor, bsz: int, length: int) -> torch.Tensor:
        """Reshape [batch, seq, embed] to a contiguous [batch, heads, seq, head_dim]."""
        return projected.view(bsz, length, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], None]:
        """Attend over ``hidden_states``.

        Args:
            hidden_states: [batch, seq, embed_dim].
            attention_mask: Passed straight to SDPA as ``attn_mask``.
            rotary_pos_emb: When given, applied to queries and keys before attention.

        Returns:
            ``(attn_output, None, None)`` with ``attn_output`` of shape
            [batch, seq, embed_dim].
        """
        bsz, q_len, _ = hidden_states.size()

        queries = self._split_heads(self.q_proj(hidden_states), bsz, q_len)
        keys = self._split_heads(self.k_proj(hidden_states), bsz, -1)
        values = self._split_heads(self.v_proj(hidden_states), bsz, -1)

        if rotary_pos_emb is not None:
            queries = apply_rotary_pos_emb(queries, rotary_pos_emb)
            keys = apply_rotary_pos_emb(keys, rotary_pos_emb)

        # Bring q/k/v to the projection weight dtype so SDPA sees matching dtypes.
        weight_dtype = self.q_proj.weight.dtype
        queries = queries.to(weight_dtype)
        keys = keys.to(weight_dtype)
        values = values.to(weight_dtype)

        # No manual scaling here; SDPA applies its default scale.
        drop_p = self.dropout if self.training else 0.0
        context = scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=attention_mask,
            dropout_p=drop_p,
            is_causal=self.is_causal,
        )

        # Back to [batch, seq, embed_dim], then the output projection.
        context = context.transpose(1, 2).contiguous().reshape(bsz, q_len, self.embed_dim)
        return self.out_proj(context), None, None
145
+
146
+ # ==========================================
147
+ # 3. Wrapped encoder layer and encoder
148
+ # ==========================================
149
+
150
class WhisperSpecialEncoderLayer(WhisperEncoderLayer):
    """Whisper encoder layer whose self-attention is swapped for the RoPE + SDPA variant."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace the parent's self-attention module.
        self.self_attn = WhisperRoPESdpaAttention(
            config=config,
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Any]:
        """Pre-norm attention block followed by a pre-norm feed-forward block.

        ``position_ids`` is accepted but not used by this layer.

        Returns:
            ``(hidden_states, None)``; the second slot is always ``None`` because
            attention weights are not produced.
        """
        drop = nn.functional.dropout

        # --- self-attention sub-block with residual connection ---
        attn_out, _attn_weights, _ = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            rotary_pos_emb=rotary_pos_emb,
        )
        attn_out = drop(attn_out, p=self.dropout, training=self.training)
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-block with residual connection ---
        ffn_out = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn_out = drop(ffn_out, p=self.activation_dropout, training=self.training)
        ffn_out = drop(self.fc2(ffn_out), p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn_out

        # Keep fp16 activations away from overflow.
        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        return (hidden_states, None)  # Keep the tuple length aligned with the Whisper interface.
198
+
199
class WhisperSpecialEncoder(WhisperEncoder):
    """Whisper encoder built from ``WhisperSpecialEncoderLayer`` with optional RoPE.

    With ``use_rope=True`` positions are encoded by rotary embeddings applied inside
    each attention layer; otherwise the parent's absolute ``embed_positions`` are
    added to the convolutional features.

    Args:
        config: Whisper configuration.
        use_rope: Whether to use rotary position embeddings.
        rope_ratio: Multiplier for the RoPE base, passed to ``RotaryEmbedding``.
    """

    def __init__(self, config: WhisperConfig, use_rope=True, rope_ratio=1):
        super().__init__(config)
        self.use_rope = use_rope
        # Override the parent layer stack.
        self.layers = nn.ModuleList(
            [WhisperSpecialEncoderLayer(config) for _ in range(config.encoder_layers)]
        )

        if use_rope:
            # Only half of each head's channels are rotated.
            head_dim = config.d_model // config.encoder_attention_heads
            self.rotary_embedding = RotaryEmbedding(head_dim // 2, rope_ratio)

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        position_ids=None,
    ):
        """Encode ``input_features``.

        ``attention_mask`` is accepted for interface compatibility but is not
        forwarded: every layer is called with ``attention_mask=None``.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is
            truthy, else a tuple of the non-``None`` values among
            (last_hidden_state, hidden_states, attentions). When
            ``output_attentions`` is set, ``attentions`` contains one entry per
            layer; each entry is whatever the layer reports in its second output
            slot (``None`` for ``WhisperSpecialEncoderLayer``).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Whisper convolutional feature extraction.
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)  # [B, T_down, D]

        if self.use_rope:
            # Build the rotary table for this sequence length.
            rotary_embs = self.rotary_embedding.get_emb(
                seq_len=inputs_embeds.shape[1],
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            )
            # Reshape to [1, seq_len, dim/2, 2] for broadcasting.
            rotary_embs = rotary_embs.unsqueeze(0)
            hidden_states = inputs_embeds
        else:
            rotary_embs = None
            # Fall back to absolute positional embeddings.
            embed_pos = self.embed_positions.weight[:inputs_embeds.shape[1]]
            hidden_states = inputs_embeds + embed_pos

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_head_mask = head_mask[idx] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the layer's forward signature.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,  # attention_mask
                    layer_head_mask,
                    output_attentions,
                    rotary_embs,
                    position_ids,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=None,
                    layer_head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                    rotary_pos_emb=rotary_embs,
                    position_ids=position_ids,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                # Layers return (hidden_states, attn_weights): the weights slot is
                # index 1. Index 2 does not exist and would raise IndexError.
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
preprocessor_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "return_attention_mask": false,
12
+ "sampling_rate": 16000
13
+ }
processing_arkasr.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from __future__ import annotations
3
+
4
+ import base64
5
+ import io
6
+ import json
7
+ import os
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import librosa
13
+ import soundfile as sf # Explicitly import soundfile to handle BytesIO.
14
+
15
+ from transformers import AutoTokenizer, WhisperFeatureExtractor
16
+ from transformers.feature_extraction_utils import BatchFeature
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ _AUDIO_MARKER = "<<AUDIO_TOKENS>>"
23
+
24
+ def _normalize_dtype_name(name: str) -> str:
25
+ name = name.strip().lower()
26
+ alias = {
27
+ "fp16": "float16",
28
+ "float16": "float16",
29
+ "half": "float16",
30
+ "bf16": "bfloat16",
31
+ "bfloat16": "bfloat16",
32
+ "fp32": "float32",
33
+ "float32": "float32",
34
+ "float": "float32",
35
+ }
36
+ return alias.get(name, name)
37
+
38
+
39
def _resolve_torch_dtype(x: Any, default: str = "float32") -> torch.dtype:
    """Turn a dtype specification into a ``torch.dtype``.

    Args:
        x: A ``torch.dtype`` (returned unchanged), a dtype name such as
            ``"bf16"`` or ``"float32"``, or ``None`` to use ``default``.
        default: Dtype name used when ``x`` is ``None``.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: If the (normalized) name is not a dtype exposed by ``torch``.
        TypeError: If ``x`` is neither a string, a ``torch.dtype`` nor ``None``.
    """
    if isinstance(x, torch.dtype):
        return x
    if x is None:
        x = default
    if isinstance(x, str):
        name = _normalize_dtype_name(x)
        resolved = getattr(torch, name, None)
        # ``torch`` also exposes modules and functions (``torch.nn``,
        # ``torch.tensor``, ...); only accept attributes that really are dtypes.
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch dtype string: {x} (normalized: {name})")
        return resolved
    raise TypeError(f"audio_dtype/audio_torch_dtype must be str or torch.dtype or None, got {type(x)}")
50
+
51
+
52
+ class ArkasrProcessor(ProcessorMixin):
53
+ attributes = ["feature_extractor", "tokenizer"]
54
+ valid_kwargs = ["merge_factor", "audio_token", "audio_dtype"]
55
+ feature_extractor_class = ("WhisperFeatureExtractor", "SequenceFeatureExtractor")
56
+ tokenizer_class = ("PreTrainedTokenizerFast", "PreTrainedTokenizer")
57
+
58
def __init__(
    self,
    feature_extractor,
    tokenizer,
    merge_factor: int = 4,
    audio_token: str = "<|audio|>",
    audio_dtype: str = "float32",
    **kwargs,
):
    """Register the feature extractor and tokenizer and record the audio settings.

    Args:
        feature_extractor: Audio feature extractor handed to the parent class.
        tokenizer: Text tokenizer handed to the parent class.
        merge_factor: Stored as ``int``; used when computing audio token counts.
        audio_token: Placeholder token string repeated once per audio token.
        audio_dtype: Name of the dtype used for the ``audios`` tensor; stored as ``str``.
        **kwargs: Accepted for signature compatibility and not used here.
    """
    super().__init__(feature_extractor, tokenizer)

    # Normalise the configurable settings to plain Python types.
    self.merge_factor, self.audio_token, self.audio_dtype = (
        int(merge_factor),
        str(audio_token),
        str(audio_dtype),
    )

    # Fixed delimiter tokens used when rendering conversations into prompts.
    self.bos_audio_token, self.eos_audio_token = "<|begin_of_audio|>", "<|end_of_audio|>"
    self.user_token, self.assistant_token = "<|user|>", "<|assistant|>"
76
+
77
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "ArkasrProcessor":
    """Build the processor from a model directory.

    ``processor_config.json`` is read from the local directory (inside
    ``subfolder`` when one is given) if it exists; otherwise the built-in
    defaults are used. The feature extractor and tokenizer are loaded with
    their own ``from_pretrained`` and then patched with the optional
    ``feature_extractor_config`` / ``tokenizer_config`` sections.

    Args:
        pretrained_model_name_or_path: Model directory (or identifier forwarded
            to the feature extractor / tokenizer loaders).
        **kwargs: ``trust_remote_code`` is forwarded to the tokenizer loader;
            ``cache_dir``, ``force_download``, ``local_files_only``, ``token``,
            ``revision`` and ``subfolder`` are forwarded to both loaders;
            ``merge_factor``, ``audio_token`` and ``audio_dtype`` override the
            values from ``processor_config.json``. Anything else is ignored.

    Returns:
        A configured ``ArkasrProcessor``.
    """
    trust_remote_code = bool(kwargs.pop("trust_remote_code", False))
    passthrough_keys = {"cache_dir", "force_download", "local_files_only", "token", "revision", "subfolder"}
    shared_kwargs = {k: v for k, v in kwargs.items() if k in passthrough_keys}

    def _apply_overrides(target: Any, overrides: Dict[str, Any], label: str) -> None:
        # Best-effort: unknown attributes are skipped, failures are reported
        # instead of being swallowed silently.
        for attr, value in overrides.items():
            if not hasattr(target, attr):
                continue
            try:
                setattr(target, attr, value)
            except Exception as exc:
                logger.warning(f"Could not set {label}.{attr} from processor_config.json: {exc}")

    merge_factor = 4
    audio_token = "<|audio|>"
    audio_dtype = "float32"
    tokenizer_cfg: Dict[str, Any] = {}
    feat_cfg: Dict[str, Any] = {}

    # Look for the config where the sub-loaders will look: inside ``subfolder`` if given.
    subfolder = shared_kwargs.get("subfolder") or ""
    proc_cfg_path = os.path.join(pretrained_model_name_or_path, subfolder, "processor_config.json")
    if os.path.isfile(proc_cfg_path):
        with open(proc_cfg_path, "r", encoding="utf-8") as f:
            proc_cfg = json.load(f)
        merge_factor = int(proc_cfg.get("merge_factor", merge_factor))
        audio_token = str(proc_cfg.get("audio_token", audio_token))
        audio_dtype = str(proc_cfg.get("audio_dtype", audio_dtype))
        tokenizer_cfg = proc_cfg.get("tokenizer_config", {}) or {}
        feat_cfg = proc_cfg.get("feature_extractor_config", {}) or {}

    # Explicit keyword arguments win over the config file.
    if kwargs.get("merge_factor") is not None:
        merge_factor = int(kwargs["merge_factor"])
    if kwargs.get("audio_token") is not None:
        audio_token = str(kwargs["audio_token"])
    if kwargs.get("audio_dtype") is not None:
        audio_dtype = str(kwargs["audio_dtype"])

    feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **shared_kwargs)
    _apply_overrides(feature_extractor, feat_cfg, "feature_extractor")

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, use_fast=True, trust_remote_code=trust_remote_code, **shared_kwargs
    )
    _apply_overrides(tokenizer, tokenizer_cfg, "tokenizer")

    return cls(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        merge_factor=merge_factor,
        audio_token=audio_token,
        audio_dtype=audio_dtype,
    )
120
+
121
+ # =========================
122
+ # audio helpers (Modified)
123
+ # =========================
124
def _load_audio_file(self, path: str, sampling_rate: int = 16000, offset: float = 0.0, duration: Optional[float] = None) -> np.ndarray:
    """Load an audio file through ``librosa.load`` and return float32 samples.

    Args:
        path: Location of the audio file.
        sampling_rate: Passed to ``librosa.load`` as ``sr`` (after ``int()``).
        offset: Start reading after this many seconds.
        duration: Load at most this many seconds; ``None`` loads to the end.

    Returns:
        The loaded samples as a ``np.float32`` array (mono is requested).
    """
    load_options = {
        "sr": int(sampling_rate),
        "mono": True,
        "offset": offset,
        "duration": duration,
    }
    samples, _rate = librosa.load(path, **load_options)
    return np.asarray(samples, dtype=np.float32)
130
+
131
+ def _strip_data_url_prefix(self, b64: str) -> str:
132
+ if "," in b64 and b64[:30].lower().startswith("data:"):
133
+ return b64.split(",", 1)[1]
134
+ return b64
135
+
136
def _load_audio_base64(self, b64: str, sampling_rate: int = 16000, offset: float = 0.0, duration: Optional[float] = None) -> np.ndarray:
    """Decode base64-encoded audio into float32 samples.

    The payload (with any ``data:`` URL header removed) is first handed to
    ``librosa.load``. If that fails, the bytes are read with ``soundfile``,
    averaged across channels, resampled when needed and sliced manually.

    Args:
        b64: Base64 text, optionally prefixed with a ``data:...,`` header.
        sampling_rate: Target sampling rate.
        offset: Start of the wanted window, in seconds.
        duration: Length of the wanted window, in seconds; ``None`` keeps
            everything after ``offset``.

    Returns:
        The decoded window as a ``np.float32`` array.

    Raises:
        ValueError: If both decoders fail; the message reports both causes.
    """
    raw = base64.b64decode(self._strip_data_url_prefix(b64))
    bio = io.BytesIO(raw)
    target_sr = int(sampling_rate)

    # Preferred path: librosa handles offset and duration itself.
    try:
        wav, _sr = librosa.load(bio, sr=target_sr, mono=True, offset=offset, duration=duration)
        return np.asarray(wav, dtype=np.float32)
    except Exception as exc:
        # Keep the cause: the name bound by ``except`` is cleared on exit.
        librosa_error = exc

    # Fallback path: decode everything, then slice manually (slower).
    try:
        bio.seek(0)
        data, sr = sf.read(bio, dtype="float32", always_2d=True)
        wav = data.mean(axis=1)
        if int(sr) != target_sr:
            wav = librosa.resample(wav, orig_sr=int(sr), target_sr=target_sr)

        start_sample = int(offset * sampling_rate)
        end_sample = None
        if duration is not None:
            end_sample = start_sample + int(duration * sampling_rate)

        return np.asarray(wav[start_sample:end_sample], dtype=np.float32)
    except Exception as fallback_error:
        raise ValueError(
            "Failed to decode base64 audio. "
            f"(librosa: {librosa_error!r}; soundfile: {fallback_error!r})"
        ) from fallback_error
162
+
163
def calculate_audio_token_count(self, mel_frames: int) -> int:
    """Convert a mel-frame count into the number of audio placeholder tokens.

    The frame count is halved (rounding up), then integer-divided by
    ``self.merge_factor`` (treated as at least 1). The result is never
    smaller than 1.
    """
    group_size = max(self.merge_factor, 1)
    halved_frames = (int(mel_frames) + 1) // 2
    token_count = int(halved_frames // group_size)
    return max(token_count, 1)
167
+
168
+ def _build_templates_and_audios(
169
+ self,
170
+ conversations: List[List[dict]],
171
+ sampling_rate: int,
172
+ add_generation_prompt: bool,
173
+ ) -> tuple[List[str], List[np.ndarray], List[int]]:
174
+ prompts_template: List[str] = []
175
+ audios_raw: List[np.ndarray] = []
176
+ prompt_audio_counts: List[int] = []
177
+
178
+ for conv in conversations:
179
+ conv_str = ""
180
+ last_role = None
181
+ audio_count_this_conv = 0
182
+
183
+ for msg in conv:
184
+ role = msg["role"]
185
+ last_role = role
186
+ content = msg["content"]
187
+
188
+ if role == "user": conv_str += f"{self.user_token}"
189
+ elif role == "assistant": conv_str += f"{self.assistant_token}"
190
+ else: conv_str += f"<|{role}|>"
191
+
192
+ if isinstance(content, str):
193
+ conv_str += f"{content}"
194
+ elif isinstance(content, list):
195
+ for part in content:
196
+ ptype = part.get("type")
197
+ if ptype == "audio":
198
+ # ------------------------------------------------------------
199
+ # Parse begin_time and end_time when present.
200
+ # ------------------------------------------------------------
201
+ begin_time = part.get("begin_time", -1)
202
+ end_time = part.get("end_time", -1)
203
+
204
+ offset = 0.0
205
+ duration = None
206
+
207
+ # Apply slicing only when begin_time is valid and non-negative.
208
+ if begin_time is not None and begin_time >= 0:
209
+ offset = float(begin_time)
210
+ if end_time is not None and end_time > begin_time:
211
+ duration = float(end_time) - float(begin_time)
212
+
213
+ audio_raw_this = None
214
+ if "array" in part:
215
+ arr = part["array"]
216
+ if isinstance(arr, torch.Tensor):
217
+ arr = arr.detach().cpu().numpy()
218
+ full_arr = np.asarray(arr, dtype=np.float32).reshape(-1)
219
+
220
+ # Slice the in-memory audio array.
221
+ start_idx = int(offset * sampling_rate)
222
+ end_idx = None
223
+ if duration is not None:
224
+ end_idx = start_idx + int(duration * sampling_rate)
225
+ audio_raw_this = full_arr[start_idx:end_idx]
226
+
227
+ elif "path" in part:
228
+ audio_raw_this = self._load_audio_file(
229
+ part["path"],
230
+ sampling_rate=sampling_rate,
231
+ offset=offset,
232
+ duration=duration
233
+ )
234
+ elif "base64" in part:
235
+ audio_raw_this = self._load_audio_base64(
236
+ part["base64"],
237
+ sampling_rate=sampling_rate,
238
+ offset=offset,
239
+ duration=duration
240
+ )
241
+ else:
242
+ raise ValueError("Audio part must contain 'path' or 'array' or 'base64'.")
243
+
244
+ audios_raw.append(audio_raw_this)
245
+ audio_count_this_conv += 1
246
+ conv_str += f"{self.bos_audio_token}{_AUDIO_MARKER}{self.eos_audio_token}"
247
+
248
+ elif ptype == "text":
249
+ conv_str += f"{part.get('text', '')}"
250
+ else:
251
+ raise ValueError(f"Unknown content part type: {ptype}")
252
+ else:
253
+ raise ValueError(f"Unsupported message content type: {type(content)}")
254
+
255
+ if add_generation_prompt:
256
+ if last_role == "user": conv_str += f"{self.assistant_token}"
257
+ elif last_role == "assistant": conv_str += f"{self.user_token}"
258
+ else: conv_str += f"{self.assistant_token}"
259
+
260
+ prompts_template.append(conv_str)
261
+ prompt_audio_counts.append(audio_count_this_conv)
262
+
263
+ return prompts_template, audios_raw, prompt_audio_counts
264
+
265
+ def _calculate_audio_token_counts_per_sample(
266
+ self,
267
+ audios_raw: List[np.ndarray],
268
+ sampling_rate: int,
269
+ audio_max_length: Optional[int],
270
+ audio_pad_to_multiple_of: Optional[int],
271
+ ) -> List[int]:
272
+ del sampling_rate, audio_pad_to_multiple_of
273
+
274
+ hop_length = int(getattr(self.feature_extractor, "hop_length", 160))
275
+ max_audio_samples = int(audio_max_length) if audio_max_length is not None else None
276
+ token_counts: List[int] = []
277
+
278
+ for audio_raw in audios_raw:
279
+ audio_np = np.asarray(audio_raw, dtype=np.float32).reshape(-1)
280
+ effective_len = int(audio_np.shape[0])
281
+ if max_audio_samples is not None:
282
+ effective_len = min(effective_len, max_audio_samples)
283
+
284
+ mel_frames = effective_len // max(hop_length, 1)
285
+ token_counts.append(self.calculate_audio_token_count(int(mel_frames)))
286
+
287
+ return token_counts
288
+
289
+ # =========================
290
+ # apply_chat_template
291
+ # =========================
292
def apply_chat_template(
    self,
    conversation: Union[List[dict], List[List[dict]]],
    chat_template: Optional[str] = None,
    add_generation_prompt: bool = True,
    **kwargs,
) -> Union[BatchFeature, str, List[str]]:
    """Render one conversation or a batch into prompts and, optionally, model inputs.

    Args:
        conversation: A single conversation (list of message dicts) or a
            list of conversations.
        chat_template: Ignored; a warning is logged when it is given.
        add_generation_prompt: Append the next speaker's role token.
        **kwargs: Recognised keys are ``tokenize`` (default ``True``),
            ``return_tensors`` (default ``"pt"``), ``return_dict`` (dropped),
            ``audio_torch_dtype`` / ``audio_dtype``, ``text_kwargs``,
            ``padding``, ``truncation``, ``max_length``,
            ``add_special_tokens``, ``sampling_rate`` (default 16000),
            ``audio_padding`` (default ``"longest"``), ``audio_max_length``
            and ``audio_pad_to_multiple_of``. Anything left over is reported
            in a warning and ignored.

    Returns:
        With ``tokenize=False``: the prompt string (single conversation) or
        list of prompt strings. Otherwise a ``BatchFeature`` holding the
        tokenizer output plus an ``audios`` tensor when any clip was found.

    Raises:
        ValueError: If audio markers and audio token counts do not line up.
    """
    if chat_template is not None:
        logger.warning("chat_template argument is ignored.")

    tokenize = kwargs.pop("tokenize", True)
    return_tensors = kwargs.pop("return_tensors", "pt")
    kwargs.pop("return_dict", None)

    # ``audio_torch_dtype`` takes precedence over ``audio_dtype``; both fall
    # back to the processor-level setting.
    audio_torch_dtype = kwargs.pop("audio_torch_dtype", None)
    audio_dtype_override = kwargs.pop("audio_dtype", None)
    if audio_torch_dtype is not None:
        dtype_source = audio_torch_dtype
    else:
        dtype_source = audio_dtype_override
    target_dtype = _resolve_torch_dtype(dtype_source, default=getattr(self, "audio_dtype", "float32"))

    # Tokenizer options: explicit ``text_kwargs`` win over loose keyword arguments.
    text_kwargs = dict(kwargs.pop("text_kwargs", {}) or {})
    loose_text_keys = [
        key
        for key in ("padding", "truncation", "max_length", "add_special_tokens")
        if key in kwargs and key not in text_kwargs
    ]
    for key in loose_text_keys:
        text_kwargs[key] = kwargs.pop(key)

    sampling_rate = int(kwargs.pop("sampling_rate", 16000))
    audio_padding = kwargs.pop("audio_padding", "longest")
    audio_max_length = kwargs.pop("audio_max_length", None)
    audio_pad_to_multiple_of = kwargs.pop("audio_pad_to_multiple_of", None)

    if kwargs:
        logger.warning(f"Ignored unused kwargs: {list(kwargs.keys())}")

    # A list whose first element is a dict is one conversation, not a batch.
    is_single = bool(isinstance(conversation, list) and conversation and isinstance(conversation[0], dict))
    conversations = [conversation] if is_single else conversation

    prompt_templates, audios_raw, prompt_audio_counts = self._build_templates_and_audios(
        conversations=conversations,
        sampling_rate=sampling_rate,
        add_generation_prompt=add_generation_prompt,
    )

    input_features = None
    audio_token_counts: List[int] = []

    if len(audios_raw) > 0:
        extracted = self.feature_extractor(
            audios_raw,
            sampling_rate=sampling_rate,
            return_tensors="np",
            return_attention_mask=False,
            padding=audio_padding,
            max_length=audio_max_length,
            pad_to_multiple_of=audio_pad_to_multiple_of,
        )
        input_features = extracted["input_features"]
        if not isinstance(input_features, np.ndarray):
            input_features = np.asarray(input_features)

        audio_token_counts = self._calculate_audio_token_counts_per_sample(
            audios_raw=audios_raw,
            sampling_rate=sampling_rate,
            audio_max_length=audio_max_length,
            audio_pad_to_multiple_of=audio_pad_to_multiple_of,
        )

    # Expand every marker, in order, into the right number of audio tokens.
    prompts: List[str] = []
    consumed = 0
    total_counts = len(audio_token_counts)
    for template, clips_in_prompt in zip(prompt_templates, prompt_audio_counts):
        rendered = template
        for _ in range(clips_in_prompt):
            if consumed >= total_counts:
                raise ValueError("Audio token count mismatch while building prompts.")
            expansion = "".join([self.audio_token] * audio_token_counts[consumed])
            rendered = rendered.replace(_AUDIO_MARKER, expansion, 1)
            consumed += 1
        if _AUDIO_MARKER in rendered:
            raise ValueError("Unresolved audio marker remained in prompt.")
        prompts.append(rendered)

    if consumed != total_counts:
        raise ValueError("Unused audio token counts remained after prompt construction.")

    if not tokenize:
        if is_single:
            return prompts[0]
        return prompts

    text_kwargs.setdefault("padding", "longest")
    text_kwargs.setdefault("add_special_tokens", False)
    text_kwargs["return_tensors"] = return_tensors

    data: Dict[str, Any] = dict(self.tokenizer(prompts, **text_kwargs))
    if input_features is not None:
        data["audios"] = torch.tensor(input_features, dtype=target_dtype)

    return BatchFeature(data=data, tensor_type=return_tensors)
392
+
393
+ # Tokenizer pass-throughs (batch_decode, decode), the direct __call__ entry point, and model_input_names.
394
def batch_decode(self, *args, **kwargs):
    """Forward all arguments to the tokenizer's ``batch_decode`` and return its result."""
    tokenizer_batch_decode = self.tokenizer.batch_decode
    return tokenizer_batch_decode(*args, **kwargs)
396
+
397
def decode(self, *args, **kwargs):
    """Forward all arguments to the tokenizer's ``decode`` and return its result."""
    tokenizer_decode = self.tokenizer.decode
    return tokenizer_decode(*args, **kwargs)
399
+
400
def __call__(
    self,
    text: Union[str, List[str]],
    audios: Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]],
    sampling_rate: int = 16000,
    return_tensors: str = "pt",
    **tokenizer_kwargs,
) -> BatchFeature:
    """Tokenize ``text`` and extract features from raw audio in one call.

    No time slicing is done here: the caller passes raw audio arrays
    directly. ``audios`` may be a single array/tensor or an arbitrarily
    nested list/tuple of them; a list or tuple whose first element is a plain
    number is treated as one clip. Other objects are skipped.

    Args:
        text: Prompt string(s) handed to the tokenizer unchanged.
        audios: Raw audio clip(s).
        sampling_rate: Sampling rate reported to the feature extractor.
        return_tensors: Tensor type for the tokenizer and the result.
        **tokenizer_kwargs: Extra tokenizer options; ``padding`` defaults to
            ``"longest"`` and ``add_special_tokens`` to ``False``.

    Returns:
        A ``BatchFeature`` with the tokenizer output and, when at least one
        clip was found, an ``audios`` tensor in the processor's audio dtype.
    """

    def _iter_clips(obj):
        # Depth-first, left-to-right walk yielding one object per clip.
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            yield obj
        elif isinstance(obj, (list, tuple)):
            if len(obj) > 0 and isinstance(obj[0], (float, int)):
                yield obj
            else:
                for item in obj:
                    yield from _iter_clips(item)

    collected = list(_iter_clips(audios))

    audios_np: List[np.ndarray] = []
    for clip in collected:
        if isinstance(clip, torch.Tensor):
            clip = clip.detach().cpu().numpy()
        audios_np.append(np.asarray(clip, dtype=np.float32).reshape(-1))

    input_features = None
    if audios_np:
        extracted = self.feature_extractor(
            audios_np,
            sampling_rate=int(sampling_rate),
            return_tensors="np",
            return_attention_mask=False,
            padding="longest",
        )
        input_features = extracted["input_features"]
        if not isinstance(input_features, np.ndarray):
            input_features = np.asarray(input_features)

    options = dict(tokenizer_kwargs or {})
    options.setdefault("padding", "longest")
    options.setdefault("add_special_tokens", False)
    options["return_tensors"] = return_tensors

    data: Dict[str, Any] = dict(self.tokenizer(text, **options))
    if input_features is not None:
        audio_dtype = _resolve_torch_dtype(getattr(self, "audio_dtype", "float32"))
        data["audios"] = torch.tensor(input_features, dtype=audio_dtype)
    return BatchFeature(data=data, tensor_type=return_tensors)
442
+
443
@property
def model_input_names(self):
    """Input names reported by this processor: the two text entries plus ``audios``."""
    text_inputs = ["input_ids", "attention_mask"]
    return text_inputs + ["audios"]
processor_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "ArkasrProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_arkasr.ArkasrProcessor"
5
+ },
6
+
7
+ "feature_extractor_type": "WhisperFeatureExtractor",
8
+ "tokenizer_class": "Qwen2Tokenizer",
9
+
10
+ "merge_factor": 4,
11
+ "audio_token": "<|audio|>",
12
+
13
+ "audio_dtype": "bfloat16",
14
+
15
+ "tokenizer_config": {
16
+ "padding_side": "left",
17
+ "model_max_length": 8192
18
+ }
19
+ }
20
+
spark_tokenizer_model/config.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "apply_spec_augment": true,
4
+ "architectures": [
5
+ "Wav2Vec2ForPreTraining"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "bos_token_id": 1,
9
+ "codevector_dim": 768,
10
+ "contrastive_logits_temperature": 0.1,
11
+ "conv_bias": true,
12
+ "conv_dim": [
13
+ 512,
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512,
18
+ 512,
19
+ 512
20
+ ],
21
+ "conv_kernel": [
22
+ 10,
23
+ 3,
24
+ 3,
25
+ 3,
26
+ 3,
27
+ 2,
28
+ 2
29
+ ],
30
+ "conv_stride": [
31
+ 5,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2,
36
+ 2,
37
+ 2
38
+ ],
39
+ "ctc_loss_reduction": "sum",
40
+ "ctc_zero_infinity": false,
41
+ "diversity_loss_weight": 0.1,
42
+ "do_stable_layer_norm": true,
43
+ "eos_token_id": 2,
44
+ "feat_extract_activation": "gelu",
45
+ "feat_extract_dropout": 0.0,
46
+ "feat_extract_norm": "layer",
47
+ "feat_proj_dropout": 0.1,
48
+ "feat_quantizer_dropout": 0.0,
49
+ "final_dropout": 0.0,
50
+ "gradient_checkpointing": false,
51
+ "hidden_act": "gelu",
52
+ "hidden_dropout": 0.1,
53
+ "hidden_size": 1024,
54
+ "initializer_range": 0.02,
55
+ "intermediate_size": 4096,
56
+ "layer_norm_eps": 1e-05,
57
+ "layerdrop": 0.1,
58
+ "mask_channel_length": 10,
59
+ "mask_channel_min_space": 1,
60
+ "mask_channel_other": 0.0,
61
+ "mask_channel_prob": 0.0,
62
+ "mask_channel_selection": "static",
63
+ "mask_feature_length": 10,
64
+ "mask_feature_prob": 0.0,
65
+ "mask_time_length": 10,
66
+ "mask_time_min_space": 1,
67
+ "mask_time_other": 0.0,
68
+ "mask_time_prob": 0.075,
69
+ "mask_time_selection": "static",
70
+ "model_type": "wav2vec2",
71
+ "num_attention_heads": 16,
72
+ "num_codevector_groups": 2,
73
+ "num_codevectors_per_group": 320,
74
+ "num_conv_pos_embedding_groups": 16,
75
+ "num_conv_pos_embeddings": 128,
76
+ "num_feat_extract_layers": 7,
77
+ "num_hidden_layers": 24,
78
+ "num_negatives": 100,
79
+ "pad_token_id": 0,
80
+ "proj_codevector_dim": 768,
81
+ "transformers_version": "4.7.0.dev0",
82
+ "vocab_size": 32
83
+ }
spark_tokenizer_model/config.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ audio_tokenizer:
2
+ mel_params:
3
+ sample_rate: 16000
4
+ n_fft: 1024
5
+ win_length: 640
6
+ hop_length: 320
7
+ mel_fmin: 10
8
+ mel_fmax: null
9
+ num_mels: 128
10
+
11
+ encoder:
12
+ input_channels: 1024
13
+ vocos_dim: 384
14
+ vocos_intermediate_dim: 2048
15
+ vocos_num_layers: 12
16
+ out_channels: 1024
17
+ sample_ratios: [1,1]
18
+
19
+ decoder:
20
+ input_channel: 1024
21
+ channels: 1536
22
+ rates: [8, 5, 4, 2]
23
+ kernel_sizes: [16,11,8,4]
24
+
25
+ quantizer:
26
+ input_dim: 1024
27
+ codebook_size: 8192
28
+ codebook_dim: 8
29
+ commitment: 0.25
30
+ codebook_loss_weight: 2.0
31
+ use_l2_normlize: True
32
+ threshold_ema_dead_code: 0.2
33
+
34
+ speaker_encoder:
35
+ input_dim: 128
36
+ out_dim: 1024
37
+ latent_dim: 128
38
+ token_num: 32
39
+ fsq_levels: [4, 4, 4, 4, 4, 4]
40
+ fsq_num_quantizers: 1
41
+
42
+ prenet:
43
+ input_channels: 1024
44
+ vocos_dim: 384
45
+ vocos_intermediate_dim: 2048
46
+ vocos_num_layers: 12
47
+ out_channels: 1024
48
+ condition_dim: 1024
49
+ sample_ratios: [1,1]
50
+ use_tanh_at_final: False
51
+
52
+ postnet:
53
+ input_channels: 1024
54
+ vocos_dim: 384
55
+ vocos_intermediate_dim: 2048
56
+ vocos_num_layers: 6
57
+ out_channels: 1024
58
+ use_tanh_at_final: False
59
+ highpass_cutoff_freq: 40
60
+ sample_rate: 16000
61
+ segment_duration: 2.4 # (s)
62
+ max_val_duration: 12 # (s)
63
+ latent_hop_length: 320
64
+ ref_segment_duration: 6
65
+ volume_normalize: true
66
+
spark_tokenizer_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
3
+ size 625518756
spark_tokenizer_model/wav2vec2-large-xlsr-53/config.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "apply_spec_augment": true,
4
+ "architectures": [
5
+ "Wav2Vec2ForPreTraining"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "bos_token_id": 1,
9
+ "codevector_dim": 768,
10
+ "contrastive_logits_temperature": 0.1,
11
+ "conv_bias": true,
12
+ "conv_dim": [
13
+ 512,
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512,
18
+ 512,
19
+ 512
20
+ ],
21
+ "conv_kernel": [
22
+ 10,
23
+ 3,
24
+ 3,
25
+ 3,
26
+ 3,
27
+ 2,
28
+ 2
29
+ ],
30
+ "conv_stride": [
31
+ 5,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2,
36
+ 2,
37
+ 2
38
+ ],
39
+ "ctc_loss_reduction": "sum",
40
+ "ctc_zero_infinity": false,
41
+ "diversity_loss_weight": 0.1,
42
+ "do_stable_layer_norm": true,
43
+ "eos_token_id": 2,
44
+ "feat_extract_activation": "gelu",
45
+ "feat_extract_dropout": 0.0,
46
+ "feat_extract_norm": "layer",
47
+ "feat_proj_dropout": 0.1,
48
+ "feat_quantizer_dropout": 0.0,
49
+ "final_dropout": 0.0,
50
+ "gradient_checkpointing": false,
51
+ "hidden_act": "gelu",
52
+ "hidden_dropout": 0.1,
53
+ "hidden_size": 1024,
54
+ "initializer_range": 0.02,
55
+ "intermediate_size": 4096,
56
+ "layer_norm_eps": 1e-05,
57
+ "layerdrop": 0.1,
58
+ "mask_channel_length": 10,
59
+ "mask_channel_min_space": 1,
60
+ "mask_channel_other": 0.0,
61
+ "mask_channel_prob": 0.0,
62
+ "mask_channel_selection": "static",
63
+ "mask_feature_length": 10,
64
+ "mask_feature_prob": 0.0,
65
+ "mask_time_length": 10,
66
+ "mask_time_min_space": 1,
67
+ "mask_time_other": 0.0,
68
+ "mask_time_prob": 0.075,
69
+ "mask_time_selection": "static",
70
+ "model_type": "wav2vec2",
71
+ "num_attention_heads": 16,
72
+ "num_codevector_groups": 2,
73
+ "num_codevectors_per_group": 320,
74
+ "num_conv_pos_embedding_groups": 16,
75
+ "num_conv_pos_embeddings": 128,
76
+ "num_feat_extract_layers": 7,
77
+ "num_hidden_layers": 24,
78
+ "num_negatives": 100,
79
+ "pad_token_id": 0,
80
+ "proj_codevector_dim": 768,
81
+ "transformers_version": "4.7.0.dev0",
82
+ "vocab_size": 32
83
+ }
spark_tokenizer_model/wav2vec2-large-xlsr-53/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
spark_tokenizer_model/wav2vec2-large-xlsr-53/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb
3
+ size 1269737156
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|user|>",
4
+ "<|begin_of_audio|>",
5
+ "<|end_of_audio|>",
6
+ "<|assistant|>",
7
+ "<|system|>"
8
+ ],
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff07bfc6cf4ed2365e9ae107e5118c89170363624f2036c85a27904d368efd87
3
+ size 13894630
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:347b151ccd4d11234353c5ba9417ccbc5ac6e3003a9a3e1481f87080d978d782
3
+ size 7313
vocab.json ADDED
The diff for this file is too large to render. See raw diff