seslami-pplx commited on
Commit
43188fa
·
verified ·
1 Parent(s): 4a5d176

Upload SLERP-merged checkpoint (alpha=0.5) from two adversarial-FT runs at step-1500

Browse files
added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
config.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "PPLXQwen3ContextualModel"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration.PPLXQwen3Config",
9
+ "AutoModel": "modeling.PPLXQwen3ContextualModel"
10
+ },
11
+ "bos_token_id": 151643,
12
+ "dtype": "float32",
13
+ "eos_token_id": 151643,
14
+ "head_dim": 128,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2560,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 9728,
19
+ "layer_types": [
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention",
44
+ "full_attention",
45
+ "full_attention",
46
+ "full_attention",
47
+ "full_attention",
48
+ "full_attention",
49
+ "full_attention",
50
+ "full_attention",
51
+ "full_attention",
52
+ "full_attention",
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention"
56
+ ],
57
+ "max_position_embeddings": 32768,
58
+ "max_window_layers": 36,
59
+ "model_type": "bidirectional_pplx_qwen3",
60
+ "num_attention_heads": 32,
61
+ "num_hidden_layers": 36,
62
+ "num_key_value_heads": 8,
63
+ "rms_norm_eps": 1e-06,
64
+ "rope_parameters": {
65
+ "rope_theta": 1000000,
66
+ "rope_type": "default"
67
+ },
68
+ "rope_theta": 1000000,
69
+ "sliding_window": null,
70
+ "tie_word_embeddings": true,
71
+ "transformers_version": "5.0.0.dev0",
72
+ "use_bidirectional_attention": true,
73
+ "use_cache": false,
74
+ "use_sliding_window": false,
75
+ "vocab_size": 151936
76
+ }
configuration.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
2
+
3
+
4
+ class PPLXQwen3Config(Qwen3Config):
5
+ model_type = "bidirectional_pplx_qwen3"
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcb5ec8de5d8ca71dbea35cd7d055942537d70ac817d9346f7005a31e2082fec
3
+ size 16089915848
modeling.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Literal
2
+ import numpy as np
3
+ import torch
4
+ from transformers import Qwen3Model
5
+ from transformers.cache_utils import Cache
6
+ from transformers.masking_utils import create_causal_mask
7
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
8
+ from transformers.processing_utils import Unpack
9
+ from transformers.utils import TransformersKwargs
10
+ from .configuration import PPLXQwen3Config
11
+ from transformers import AutoTokenizer
12
+ from .st_quantize import FlexibleQuantizer
13
+
14
+
15
+ # From modeling_t5gemma.py
16
+ def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
17
+ """
18
+ This creates bidirectional attention mask.
19
+ """
20
+
21
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
22
+ if attention_mask is None:
23
+ return torch.ones((), dtype=torch.bool)
24
+ return attention_mask[batch_idx, kv_idx].to(torch.bool)
25
+
26
+ return inner_mask
27
+
28
+
29
+ class PPLXQwen3Model(Qwen3Model):
30
+ _supports_flash_attn = True
31
+ _supports_sdpa = True
32
+
33
+ config_class = PPLXQwen3Config
34
+
35
+ def __init__(self, config):
36
+ super().__init__(config)
37
+ self.post_init()
38
+
39
+ def post_init(self):
40
+ super().post_init()
41
+ # Override to set all layers to non-causal attention. This'll work with attn_implementation="flash_attention_2" or "sdpa"
42
+ for layer in self.layers:
43
+ layer.self_attn.is_causal = False
44
+
45
+ def forward(
46
+ self,
47
+ input_ids: torch.LongTensor | None = None,
48
+ attention_mask: torch.Tensor | None = None,
49
+ position_ids: torch.LongTensor | None = None,
50
+ past_key_values: Cache | None = None,
51
+ inputs_embeds: torch.FloatTensor | None = None,
52
+ use_cache: bool | None = None,
53
+ cache_position: torch.LongTensor | None = None,
54
+ **kwargs: Unpack[TransformersKwargs],
55
+ ) -> BaseModelOutputWithPooling:
56
+ if inputs_embeds is None:
57
+ inputs_embeds = self.embed_tokens(input_ids)
58
+ input_ids = None
59
+
60
+ # We construct a dummy tensor imitating initial positions
61
+ dummy_cache_position = torch.arange(
62
+ inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
63
+ )
64
+ attention_mask = {
65
+ "full_attention": create_causal_mask(
66
+ config=self.config,
67
+ input_embeds=inputs_embeds,
68
+ attention_mask=attention_mask,
69
+ cache_position=dummy_cache_position,
70
+ past_key_values=None,
71
+ position_ids=position_ids,
72
+ or_mask_function=bidirectional_mask_function(attention_mask),
73
+ )
74
+ }
75
+
76
+ outputs = super().forward(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ position_ids=position_ids,
80
+ past_key_values=past_key_values,
81
+ inputs_embeds=inputs_embeds,
82
+ use_cache=use_cache,
83
+ cache_position=cache_position,
84
+ **kwargs,
85
+ )
86
+ return outputs
87
+
88
+
89
+ class PPLXQwen3ContextualModel(PPLXQwen3Model):
90
+ """
91
+ Qwen3 model with contextual encoding support for late chunking.
92
+
93
+ This model extends PPLXQwen3Model with an encode() method that supports both
94
+ standard encoding (list[str]) and contextual encoding (list[list[str]]) with late chunking.
95
+
96
+ IMPORTANT: This model MUST be loaded with trust_remote_code=True:
97
+
98
+ from transformers import AutoModel
99
+
100
+ model = AutoModel.from_pretrained(
101
+ "path/to/model",
102
+ trust_remote_code=True # REQUIRED!
103
+ )
104
+
105
+ embeddings = model.encode([["chunk1", "chunk2"]])
106
+
107
+ Loading without trust_remote_code=True will fail to load this custom model class.
108
+ """
109
+
110
+ config_class = PPLXQwen3Config
111
+
112
+ def __init__(self, config):
113
+ super().__init__(config)
114
+
115
+ if not isinstance(config, PPLXQwen3Config):
116
+ raise TypeError(
117
+ f"PPLXQwen3ContextualModel requires PPLXQwen3Config, got {type(config).__name__}. "
118
+ f"Did you forget to load with trust_remote_code=True?"
119
+ )
120
+
121
+ self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
122
+ self._flexible_quantizer = FlexibleQuantizer()
123
+
124
+ @staticmethod
125
+ def mean_pooling(
126
+ token_embeddings: torch.Tensor, attention_mask: torch.Tensor
127
+ ) -> torch.Tensor:
128
+ """Apply mean pooling to token embeddings."""
129
+ input_mask_expanded = (
130
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
131
+ )
132
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
133
+ input_mask_expanded.sum(1), min=1e-9
134
+ )
135
+
136
+ @torch.inference_mode()
137
+ def encode(
138
+ self,
139
+ documents: list[list[str]],
140
+ batch_size: int = 32,
141
+ show_progress_bar: bool = False,
142
+ device: str | torch.device | None = None,
143
+ normalize_embeddings: bool = False,
144
+ convert_to_numpy: bool = True,
145
+ quantization: Literal["int8", "binary", "ubinary"] = "int8",
146
+ ) -> list[np.ndarray] | list[torch.Tensor]:
147
+ """
148
+ Encode documents with late chunking (contextual embeddings).
149
+
150
+ This model is designed specifically for contextual encoding and always expects
151
+ documents as nested lists where each document is a list of text chunks.
152
+
153
+ The encoding process:
154
+ 1. Concatenate chunks with separator tokens
155
+ 2. Run forward pass to get token embeddings
156
+ 3. Extract and pool individual chunk embeddings (late chunking)
157
+ 4. Apply quantization (Int8 or binary, always enabled)
158
+ 5. Normalize embeddings if requested (applied after quantization)
159
+ 6. Convert to numpy or return as tensors
160
+
161
+ Args:
162
+ documents: List of documents, where each document is a list of text chunks.
163
+ Example: [["chunk1", "chunk2"], ["chunk1", "chunk2", "chunk3"]]
164
+ batch_size: Batch size for encoding
165
+ show_progress_bar: Show progress bar during encoding
166
+ device: Device to use for computation (defaults to model's device)
167
+ normalize_embeddings: Normalize embeddings to unit length (applied after quantization)
168
+ convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
169
+ quantization: Quantization type to apply. Options:
170
+ - "int8": Int8 tanh quantization (default)
171
+ - "binary": Binary tanh quantization (-1.0 or 1.0)
172
+ - "ubinary": Unsigned packed binary (uint8, 8x compression)
173
+
174
+ Returns:
175
+ List of numpy arrays or tensors (preserves document structure).
176
+ Each element has shape (n_chunks, hidden_dim) or (n_chunks, hidden_dim // 8) for ubinary.
177
+ Example: embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
178
+ Output type depends on quantization method:
179
+ - "int8": int8 dtype, values in range [-128, 127], shape (..., hidden_dim)
180
+ - "binary": float32 dtype, values -1.0 or 1.0, shape (..., hidden_dim)
181
+ - "ubinary": uint8 dtype, packed bits (8x smaller), shape (..., hidden_dim // 8)
182
+ """
183
+
184
+ if not isinstance(documents, list) or not all(
185
+ isinstance(doc, list) for doc in documents
186
+ ):
187
+ raise TypeError(
188
+ "Input 'documents' must be a list of lists of strings for contextual encoding."
189
+ )
190
+
191
+ if quantization not in ["int8", "binary", "ubinary"]:
192
+ raise ValueError(
193
+ f"Unsupported quantization type: '{quantization}'. "
194
+ f"Supported types are: 'int8', 'binary', 'ubinary'. "
195
+ f"Got: {type(quantization).__name__} = '{quantization}'"
196
+ )
197
+
198
+ if normalize_embeddings and quantization == "ubinary":
199
+ raise ValueError(
200
+ "normalize_embeddings=True is incompatible with quantization='ubinary'. "
201
+ "Packed binary embeddings (uint8) cannot be normalized because each byte "
202
+ "represents 8 packed bits, not a single dimension. "
203
+ "Either set normalize_embeddings=False or use 'binary' quantization instead."
204
+ )
205
+
206
+ self.eval()
207
+
208
+ if device is None:
209
+ device = next(self.parameters()).device
210
+
211
+ all_embeddings = []
212
+
213
+ range_iter = range(0, len(documents), batch_size)
214
+ if show_progress_bar:
215
+ try:
216
+ from tqdm import tqdm
217
+
218
+ range_iter = tqdm(range_iter, desc="Encoding documents")
219
+ except ImportError:
220
+ pass
221
+
222
+ for i in range_iter:
223
+ batch_docs = documents[i : i + batch_size]
224
+
225
+ doc_strings = [
226
+ self.tokenizer.sep_token.join(chunks) for chunks in batch_docs
227
+ ]
228
+
229
+ inputs = self.tokenizer(
230
+ doc_strings,
231
+ padding=True,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+ inputs = {k: v.to(device) for k, v in inputs.items()}
236
+
237
+ outputs = self.forward(**inputs)
238
+ token_embeddings = outputs.last_hidden_state
239
+
240
+ batch_chunk_embeddings = self._extract_chunks_from_concatenated(
241
+ input_ids=inputs["input_ids"],
242
+ token_embeddings=token_embeddings,
243
+ attention_mask=inputs["attention_mask"],
244
+ )
245
+
246
+ batch_chunk_embeddings = [
247
+ torch.stack([chunk for chunk in doc_chunks], dim=0)
248
+ for doc_chunks in batch_chunk_embeddings
249
+ ]
250
+
251
+ batch_chunk_embeddings = [
252
+ self._flexible_quantizer(
253
+ {"sentence_embedding": emb}, quantization=quantization
254
+ )["sentence_embedding"]
255
+ for emb in batch_chunk_embeddings
256
+ ]
257
+
258
+ if normalize_embeddings:
259
+ batch_chunk_embeddings = [
260
+ torch.nn.functional.normalize(emb, p=2, dim=-1)
261
+ for emb in batch_chunk_embeddings
262
+ ]
263
+
264
+ batch_chunk_embeddings = [emb.cpu() for emb in batch_chunk_embeddings]
265
+
266
+ all_embeddings.extend(batch_chunk_embeddings)
267
+
268
+ if convert_to_numpy:
269
+ all_embeddings = [emb.numpy() for emb in all_embeddings]
270
+
271
+ return all_embeddings
272
+
273
+ def _extract_chunks_from_concatenated(
274
+ self,
275
+ input_ids: torch.Tensor,
276
+ token_embeddings: torch.Tensor,
277
+ attention_mask: torch.Tensor,
278
+ ) -> list[list[torch.Tensor]]:
279
+ """
280
+ Extract individual chunk embeddings from concatenated sequence using late chunking.
281
+
282
+ This method splits concatenated sequences like "[chunk1][SEP][chunk2][SEP]..."
283
+ back into individual chunk embeddings by finding SEP token positions.
284
+
285
+ Args:
286
+ input_ids: Token IDs (batch_size, seq_len)
287
+ token_embeddings: Token embeddings (batch_size, seq_len, hidden_dim)
288
+ attention_mask: Attention mask (batch_size, seq_len)
289
+
290
+ Returns:
291
+ list[list[torch.Tensor]]: List of documents, each containing list of chunk embeddings
292
+
293
+ Note:
294
+ The sep_token_id is retrieved from self.tokenizer.sep_token_id.
295
+ Common values: Qwen2=151643, BERT=102, varies by tokenizer.
296
+ """
297
+ sep_token_id = self.tokenizer.sep_token_id
298
+ batch_size = input_ids.shape[0]
299
+
300
+ all_doc_chunks = []
301
+
302
+ for batch_idx in range(batch_size):
303
+ # non-pad sep tokens
304
+ valid_positions = attention_mask[batch_idx].bool()
305
+ sep_positions = (
306
+ (input_ids[batch_idx] == sep_token_id) & valid_positions
307
+ ).nonzero(as_tuple=True)[0]
308
+
309
+ chunk_embeddings = []
310
+ start_pos = 0
311
+
312
+ for sep_pos in sep_positions:
313
+ chunk_tokens = token_embeddings[batch_idx, start_pos:sep_pos]
314
+ chunk_mask = attention_mask[batch_idx, start_pos:sep_pos]
315
+
316
+ chunk_emb = self.mean_pooling(
317
+ chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
318
+ ).squeeze(0)
319
+
320
+ chunk_embeddings.append(chunk_emb)
321
+
322
+ start_pos = sep_pos + 1
323
+
324
+ # Handle the last chunk (after the last SEP token)
325
+ last_valid_pos = attention_mask[batch_idx].sum().item()
326
+
327
+ chunk_tokens = token_embeddings[batch_idx, start_pos:last_valid_pos]
328
+ chunk_mask = attention_mask[batch_idx, start_pos:last_valid_pos]
329
+
330
+ if chunk_mask.sum() > 0:
331
+ chunk_emb = self.mean_pooling(
332
+ chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
333
+ ).squeeze(0)
334
+ else:
335
+ # Empty chunk - create zero embedding
336
+ chunk_emb = torch.zeros(
337
+ token_embeddings.shape[-1],
338
+ device=token_embeddings.device,
339
+ dtype=token_embeddings.dtype,
340
+ )
341
+
342
+ chunk_embeddings.append(chunk_emb)
343
+
344
+ all_doc_chunks.append(chunk_embeddings)
345
+
346
+ return all_doc_chunks
special_tokens_map.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "mask_token": {
25
+ "content": "â½Ĺ",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "pad_token": {
32
+ "content": "<|endoftext|>",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ },
38
+ "sep_token": {
39
+ "content": "<|endoftext|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false
44
+ }
45
+ }
st_quantize.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Literal
4
+ from sentence_transformers.models import Module
5
+
6
+
7
+ class Quantizer(torch.nn.Module):
8
+ def __init__(self, hard: bool = True):
9
+ """
10
+ Args:
11
+ hard: Whether to use hard or soft quantization. Defaults to True.
12
+ """
13
+ super().__init__()
14
+ self._hard = hard
15
+
16
+ def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
17
+ raise NotImplementedError
18
+
19
+ def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
20
+ raise NotImplementedError
21
+
22
+ def forward(self, x, *args, **kwargs) -> torch.Tensor:
23
+ soft = self._soft_quantize(x, *args, **kwargs)
24
+
25
+ if not self._hard:
26
+ result = soft
27
+ else:
28
+ result = (
29
+ self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach()
30
+ )
31
+
32
+ return result
33
+
34
+
35
+ class Int8TanhQuantizer(Quantizer):
36
+ def __init__(
37
+ self,
38
+ hard: bool = True,
39
+ ):
40
+ super().__init__(hard=hard)
41
+ self.qmin = -128
42
+ self.qmax = 127
43
+
44
+ def _soft_quantize(self, x, *args, **kwargs):
45
+ return torch.tanh(x)
46
+
47
+ def _hard_quantize(self, x, *args, **kwargs):
48
+ soft = self._soft_quantize(x)
49
+ int_x = torch.round(soft * self.qmax)
50
+ int_x = torch.clamp(int_x, self.qmin, self.qmax)
51
+ return int_x
52
+
53
+
54
+ class BinaryTanhQuantizer(Quantizer):
55
+ def __init__(
56
+ self,
57
+ hard: bool = True,
58
+ scale: float = 1.0,
59
+ ):
60
+ super().__init__(hard)
61
+ self._scale = scale
62
+
63
+ def _soft_quantize(self, x, *args, **kwargs):
64
+ return torch.tanh(self._scale * x)
65
+
66
+ def _hard_quantize(self, x, *args, **kwargs):
67
+ return torch.where(x >= 0, 1.0, -1.0)
68
+
69
+
70
+ class PackedBinaryQuantizer:
71
+ """
72
+ Packs binary embeddings into uint8 format for efficient storage.
73
+
74
+ This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
75
+ bits into a single uint8 byte using numpy.packbits. This reduces memory
76
+ usage by 8x compared to float32 and by 4x compared to int8.
77
+
78
+ IMPORTANT: This is an inference-only quantizer - it is not differentiable
79
+ and should only be used for encoding/inference, not during training.
80
+
81
+ Args:
82
+ x: Input tensor of any float dtype, shape (..., embedding_dim)
83
+
84
+ Returns:
85
+ Packed binary tensor of dtype uint8, shape (..., embedding_dim // 8)
86
+
87
+ Example:
88
+ >>> quantizer = PackedBinaryQuantizer()
89
+ >>> embeddings = torch.randn(2, 1024) # float32
90
+ >>> packed = quantizer(embeddings) # uint8, shape (2, 128)
91
+ """
92
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
93
+ bits = np.where(x.cpu().numpy() >= 0, True, False)
94
+ packed = np.packbits(bits, axis=-1)
95
+ return torch.from_numpy(packed).to(x.device)
96
+
97
+
98
+ class FlexibleQuantizer(Module):
99
+ def __init__(self):
100
+ super().__init__()
101
+ self._int8_quantizer = Int8TanhQuantizer()
102
+ self._binary_quantizer = BinaryTanhQuantizer()
103
+ self._packed_binary_quantizer = PackedBinaryQuantizer()
104
+
105
+ def forward(
106
+ self,
107
+ features: dict[str, torch.Tensor],
108
+ quantization: Literal["int8", "binary", "ubinary"] = "int8",
109
+ **kwargs,
110
+ ) -> dict[str, torch.Tensor]:
111
+ if quantization == "int8":
112
+ features["sentence_embedding"] = self._int8_quantizer(
113
+ features["sentence_embedding"]
114
+ )
115
+ elif quantization == "binary":
116
+ features["sentence_embedding"] = self._binary_quantizer(
117
+ features["sentence_embedding"]
118
+ )
119
+ elif quantization == "ubinary":
120
+ features["sentence_embedding"] = self._packed_binary_quantizer(
121
+ features["sentence_embedding"]
122
+ )
123
+ else:
124
+ raise ValueError(
125
+ f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
126
+ )
127
+ return features
128
+
129
+ @classmethod
130
+ def load(
131
+ cls,
132
+ model_name_or_path: str,
133
+ subfolder: str = "",
134
+ token: bool | str | None = None,
135
+ cache_folder: str | None = None,
136
+ revision: str | None = None,
137
+ local_files_only: bool = False,
138
+ **kwargs,
139
+ ):
140
+ return cls()
141
+
142
+ def save(self, output_path: str, *args, **kwargs) -> None:
143
+ return
tokenizer_config.json ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151642": {
6
+ "content": "â½Ĺ",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151643": {
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151644": {
22
+ "content": "<|im_start|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151645": {
30
+ "content": "<|im_end|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151646": {
38
+ "content": "<|object_ref_start|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151647": {
46
+ "content": "<|object_ref_end|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151648": {
54
+ "content": "<|box_start|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151649": {
62
+ "content": "<|box_end|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151650": {
70
+ "content": "<|quad_start|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151651": {
78
+ "content": "<|quad_end|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151652": {
86
+ "content": "<|vision_start|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151653": {
94
+ "content": "<|vision_end|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151654": {
102
+ "content": "<|vision_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151655": {
110
+ "content": "<|image_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151656": {
118
+ "content": "<|video_pad|>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": true
124
+ },
125
+ "151657": {
126
+ "content": "<tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151658": {
134
+ "content": "</tool_call>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151659": {
142
+ "content": "<|fim_prefix|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151660": {
150
+ "content": "<|fim_middle|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151661": {
158
+ "content": "<|fim_suffix|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151662": {
166
+ "content": "<|fim_pad|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151663": {
174
+ "content": "<|repo_name|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151664": {
182
+ "content": "<|file_sep|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151665": {
190
+ "content": "<tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151666": {
198
+ "content": "</tool_response>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151667": {
206
+ "content": "<think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "151668": {
214
+ "content": "</think>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": false
220
+ }
221
+ },
222
+ "additional_special_tokens": [
223
+ "<|im_start|>",
224
+ "<|im_end|>",
225
+ "<|object_ref_start|>",
226
+ "<|object_ref_end|>",
227
+ "<|box_start|>",
228
+ "<|box_end|>",
229
+ "<|quad_start|>",
230
+ "<|quad_end|>",
231
+ "<|vision_start|>",
232
+ "<|vision_end|>",
233
+ "<|vision_pad|>",
234
+ "<|image_pad|>",
235
+ "<|video_pad|>"
236
+ ],
237
+ "bos_token": null,
238
+ "clean_up_tokenization_spaces": false,
239
+ "eos_token": "<|endoftext|>",
240
+ "errors": "replace",
241
+ "extra_special_tokens": {},
242
+ "mask_token": "â½Ĺ",
243
+ "model_max_length": 131072,
244
+ "pad_token": "<|endoftext|>",
245
+ "sep_token": "<|endoftext|>",
246
+ "split_special_tokens": false,
247
+ "tokenizer_class": "Qwen2Tokenizer",
248
+ "unk_token": null
249
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff