xlr8harder commited on
Commit
15d5bc6
·
verified ·
1 Parent(s): bd5fa0b

Upload Transformers safetensors conversion

Browse files
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
1
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
3
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ model_name: talkie-1930-13b-base-tf
8
+ base_model:
9
+ - talkie-lm/talkie-1930-13b-base
10
+ tags:
11
+ - transformers
12
+ - safetensors
13
+ - bfloat16
14
+ - custom_code
15
+ - text-generation
16
+ - conversion
17
+ - talkie
18
+ - pre-1931
19
+ ---
20
+
21
+ # talkie-1930-13b-base-tf (BF16 Transformers + safetensors conversion)
22
+
23
+ This repository is a Transformers-compatible conversion of
24
+ [`talkie-lm/talkie-1930-13b-base`](https://huggingface.co/talkie-lm/talkie-1930-13b-base), the original Talkie base completion model.
25
+
26
+ The upstream model is a 13B vintage language model trained on 260B tokens of pre-1931 English-language text, according to the original model card.
27
+
28
+ The original base checkpoint is FP32. This repository stores a BF16 conversion of those weights and packages them for Transformers with custom `trust_remote_code` modules and BF16 sharded safetensors.
29
+
30
+ This is not an official Talkie release; refer to the upstream model card for
31
+ the author-provided provenance and usage notes.
32
+
33
+ ## Source Model
34
+
35
+ - Original model: [talkie-lm/talkie-1930-13b-base](https://huggingface.co/talkie-lm/talkie-1930-13b-base)
36
+ - Talkie report: [talkie-lm.com](https://talkie-lm.com/)
37
+ - Reference code: [github.com/talkie-lm/talkie](https://github.com/talkie-lm/talkie)
38
+
39
+ ## Conversion Details
40
+
41
+ - Weight dtype: BF16
42
+ - Weight format: sharded safetensors
43
+ - Context length: 2048 tokens
44
+ - Architecture: custom Talkie code loaded with `trust_remote_code=True`
45
+ - Tokenizer: Talkie tiktoken-compatible tokenizer exposed through `AutoTokenizer`
46
+
47
+ ## Usage
48
+
49
+ ```python
50
+ import torch
51
+ from transformers import AutoModelForCausalLM, AutoTokenizer
52
+
53
+ path = "xlr8harder/talkie-1930-13b-base-tf"
54
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
55
+ model = AutoModelForCausalLM.from_pretrained(
56
+ path,
57
+ trust_remote_code=True,
58
+ dtype=torch.bfloat16,
59
+ device_map={"": "cuda"},
60
+ use_safetensors=True,
61
+ )
62
+ ```
63
+
64
+ For base completions:
65
+
66
+ ```python
67
+ inputs = tokenizer("The latest discoveries in physics suggest that", return_tensors="pt").to("cuda")
68
+ output = model.generate(**inputs, max_new_tokens=64)
69
+ print(tokenizer.decode(output[0], skip_special_tokens=True))
70
+ ```
71
+
72
+ ## vLLM
73
+
74
+ The included remote-code model implements the Transformers attention-interface
75
+ hooks expected by vLLM's Transformers modeling backend. For compatibility with
76
+ that backend, the original single-scalar `lm_head_gain` is folded into
77
+ `lm_head.weight` during conversion; the other Talkie gain parameters remain
78
+ explicit model parameters. Using vLLM's `logit_scale`-style approach was not
79
+ used because it applies scaling after the output matmul, while Talkie applies
80
+ the gain to the head weight before the matmul. In BF16 this can introduce small
81
+ rounding differences and, in smoke tests, changed one near-tied top-token
82
+ ordering.
83
+
84
+ ```bash
85
+ vllm serve xlr8harder/talkie-1930-13b-base-tf \
86
+ --task generate \
87
+ --model-impl transformers \
88
+ --trust-remote-code \
89
+ --dtype bfloat16 \
90
+ --max-model-len 2048
91
+ ```
92
+
93
+ ## Validation
94
+
95
+ The BF16 checkpoint matched a runtime BF16 cast from the original FP32 checkpoint exactly on the tested forward pass. The Transformers safetensors model was also compared against the Talkie reference architecture; the top-10 next-token ordering matched exactly, with observed max absolute logit difference `0.03125`.
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "TalkieForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_talkie.TalkieConfig",
7
+ "AutoModel": "modeling_talkie.TalkieModel",
8
+ "AutoModelForCausalLM": "modeling_talkie.TalkieForCausalLM"
9
+ },
10
+ "bos_token_id": null,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 65535,
13
+ "head_dim": 128,
14
+ "hidden_size": 5120,
15
+ "max_position_embeddings": 2048,
16
+ "model_type": "talkie",
17
+ "n_embd": 5120,
18
+ "n_head": 40,
19
+ "n_layer": 40,
20
+ "num_attention_heads": 40,
21
+ "num_hidden_layers": 40,
22
+ "pad_token_id": 65535,
23
+ "rope_base": 1000000,
24
+ "style": "base",
25
+ "tie_word_embeddings": false,
26
+ "transformers_version": "5.8.0",
27
+ "use_cache": true,
28
+ "vocab_size": 65536,
29
+ "logit_scale": 1.0
30
+ }
configuration_talkie.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class TalkieConfig(PretrainedConfig):
7
+ model_type = "talkie"
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size: int = 65536,
12
+ n_layer: int = 40,
13
+ n_head: int = 40,
14
+ n_embd: int = 5120,
15
+ head_dim: int = 128,
16
+ max_position_embeddings: int = 2048,
17
+ rope_base: int = 1_000_000,
18
+ logit_scale: float = 1.0,
19
+ use_cache: bool = True,
20
+ tie_word_embeddings: bool = False,
21
+ bos_token_id: int | None = None,
22
+ eos_token_id: int | list[int] = 65535,
23
+ pad_token_id: int | None = None,
24
+ **kwargs,
25
+ ):
26
+ super().__init__(
27
+ bos_token_id=bos_token_id,
28
+ eos_token_id=eos_token_id,
29
+ pad_token_id=pad_token_id,
30
+ tie_word_embeddings=tie_word_embeddings,
31
+ **kwargs,
32
+ )
33
+ self.vocab_size = vocab_size
34
+ self.n_layer = n_layer
35
+ self.n_head = n_head
36
+ self.n_embd = n_embd
37
+ self.head_dim = head_dim
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.rope_base = rope_base
40
+ self.logit_scale = logit_scale
41
+ self.use_cache = use_cache
42
+
43
+ # Common Transformers aliases used by generation/cache helpers.
44
+ self.hidden_size = n_embd
45
+ self.num_hidden_layers = n_layer
46
+ self.num_attention_heads = n_head
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token_id": 65535,
3
+ "pad_token_id": 65535,
4
+ "do_sample": true,
5
+ "temperature": 0.7,
6
+ "use_cache": true
7
+ }
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7d004876dc650b041351e9c6685d1d8a35aee37c0600736e5573c5bd9cf44e6
3
+ size 4984675468
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09c5ca9bae34ab4bb361745ebfb9116e937fd40a0de1d111b5f2d1ff40e40f07
3
+ size 4903413297
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cf144f6957fb50b1970fb8fe60137744805da6b1aa3247563773fc1a85bce79
3
+ size 4903413354
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:127687e083d2b74d4ebfb83d9873e985ea001a2bf220ff60df5862ec154533b1
3
+ size 4991231319
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:513652b18305abf2c3d5e1d36fc252bcb862048205213d0952acdcd0430a2aef
3
+ size 4991231586
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e31d93dec37dc197396486d7c2fd23fbb9cc716f207655a7a73cf9ba73ce8bbf
3
+ size 1786514866
model.safetensors.index.json ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 26560433520
4
+ },
5
+ "weight_map": {
6
+ "blocks.0.attn.attn_key.weight": "model-00001-of-00006.safetensors",
7
+ "blocks.0.attn.attn_query.weight": "model-00001-of-00006.safetensors",
8
+ "blocks.0.attn.attn_resid.weight": "model-00001-of-00006.safetensors",
9
+ "blocks.0.attn.attn_value.weight": "model-00001-of-00006.safetensors",
10
+ "blocks.0.attn.head_gain.head_g": "model-00001-of-00006.safetensors",
11
+ "blocks.0.attn_gain.a_g": "model-00001-of-00006.safetensors",
12
+ "blocks.0.embed_skip.a_g": "model-00001-of-00006.safetensors",
13
+ "blocks.0.mlp.mlp_gate.weight": "model-00001-of-00006.safetensors",
14
+ "blocks.0.mlp.mlp_linear.weight": "model-00001-of-00006.safetensors",
15
+ "blocks.0.mlp.mlp_resid.weight": "model-00001-of-00006.safetensors",
16
+ "blocks.0.mlp_gain.a_g": "model-00001-of-00006.safetensors",
17
+ "blocks.1.attn.attn_key.weight": "model-00001-of-00006.safetensors",
18
+ "blocks.1.attn.attn_query.weight": "model-00001-of-00006.safetensors",
19
+ "blocks.1.attn.attn_resid.weight": "model-00001-of-00006.safetensors",
20
+ "blocks.1.attn.attn_value.weight": "model-00001-of-00006.safetensors",
21
+ "blocks.1.attn.head_gain.head_g": "model-00001-of-00006.safetensors",
22
+ "blocks.1.attn_gain.a_g": "model-00001-of-00006.safetensors",
23
+ "blocks.1.embed_skip.a_g": "model-00001-of-00006.safetensors",
24
+ "blocks.1.mlp.mlp_gate.weight": "model-00001-of-00006.safetensors",
25
+ "blocks.1.mlp.mlp_linear.weight": "model-00001-of-00006.safetensors",
26
+ "blocks.1.mlp.mlp_resid.weight": "model-00001-of-00006.safetensors",
27
+ "blocks.1.mlp_gain.a_g": "model-00001-of-00006.safetensors",
28
+ "blocks.10.attn.attn_key.weight": "model-00002-of-00006.safetensors",
29
+ "blocks.10.attn.attn_query.weight": "model-00002-of-00006.safetensors",
30
+ "blocks.10.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
31
+ "blocks.10.attn.attn_value.weight": "model-00002-of-00006.safetensors",
32
+ "blocks.10.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
33
+ "blocks.10.attn_gain.a_g": "model-00002-of-00006.safetensors",
34
+ "blocks.10.embed_skip.a_g": "model-00002-of-00006.safetensors",
35
+ "blocks.10.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
36
+ "blocks.10.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
37
+ "blocks.10.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
38
+ "blocks.10.mlp_gain.a_g": "model-00002-of-00006.safetensors",
39
+ "blocks.11.attn.attn_key.weight": "model-00002-of-00006.safetensors",
40
+ "blocks.11.attn.attn_query.weight": "model-00002-of-00006.safetensors",
41
+ "blocks.11.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
42
+ "blocks.11.attn.attn_value.weight": "model-00002-of-00006.safetensors",
43
+ "blocks.11.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
44
+ "blocks.11.attn_gain.a_g": "model-00002-of-00006.safetensors",
45
+ "blocks.11.embed_skip.a_g": "model-00002-of-00006.safetensors",
46
+ "blocks.11.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
47
+ "blocks.11.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
48
+ "blocks.11.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
49
+ "blocks.11.mlp_gain.a_g": "model-00002-of-00006.safetensors",
50
+ "blocks.12.attn.attn_key.weight": "model-00002-of-00006.safetensors",
51
+ "blocks.12.attn.attn_query.weight": "model-00002-of-00006.safetensors",
52
+ "blocks.12.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
53
+ "blocks.12.attn.attn_value.weight": "model-00002-of-00006.safetensors",
54
+ "blocks.12.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
55
+ "blocks.12.attn_gain.a_g": "model-00002-of-00006.safetensors",
56
+ "blocks.12.embed_skip.a_g": "model-00002-of-00006.safetensors",
57
+ "blocks.12.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
58
+ "blocks.12.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
59
+ "blocks.12.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
60
+ "blocks.12.mlp_gain.a_g": "model-00002-of-00006.safetensors",
61
+ "blocks.13.attn.attn_key.weight": "model-00002-of-00006.safetensors",
62
+ "blocks.13.attn.attn_query.weight": "model-00002-of-00006.safetensors",
63
+ "blocks.13.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
64
+ "blocks.13.attn.attn_value.weight": "model-00002-of-00006.safetensors",
65
+ "blocks.13.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
66
+ "blocks.13.attn_gain.a_g": "model-00002-of-00006.safetensors",
67
+ "blocks.13.embed_skip.a_g": "model-00003-of-00006.safetensors",
68
+ "blocks.13.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
69
+ "blocks.13.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
70
+ "blocks.13.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
71
+ "blocks.13.mlp_gain.a_g": "model-00003-of-00006.safetensors",
72
+ "blocks.14.attn.attn_key.weight": "model-00003-of-00006.safetensors",
73
+ "blocks.14.attn.attn_query.weight": "model-00003-of-00006.safetensors",
74
+ "blocks.14.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
75
+ "blocks.14.attn.attn_value.weight": "model-00003-of-00006.safetensors",
76
+ "blocks.14.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
77
+ "blocks.14.attn_gain.a_g": "model-00003-of-00006.safetensors",
78
+ "blocks.14.embed_skip.a_g": "model-00003-of-00006.safetensors",
79
+ "blocks.14.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
80
+ "blocks.14.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
81
+ "blocks.14.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
82
+ "blocks.14.mlp_gain.a_g": "model-00003-of-00006.safetensors",
83
+ "blocks.15.attn.attn_key.weight": "model-00003-of-00006.safetensors",
84
+ "blocks.15.attn.attn_query.weight": "model-00003-of-00006.safetensors",
85
+ "blocks.15.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
86
+ "blocks.15.attn.attn_value.weight": "model-00003-of-00006.safetensors",
87
+ "blocks.15.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
88
+ "blocks.15.attn_gain.a_g": "model-00003-of-00006.safetensors",
89
+ "blocks.15.embed_skip.a_g": "model-00003-of-00006.safetensors",
90
+ "blocks.15.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
91
+ "blocks.15.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
92
+ "blocks.15.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
93
+ "blocks.15.mlp_gain.a_g": "model-00003-of-00006.safetensors",
94
+ "blocks.16.attn.attn_key.weight": "model-00003-of-00006.safetensors",
95
+ "blocks.16.attn.attn_query.weight": "model-00003-of-00006.safetensors",
96
+ "blocks.16.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
97
+ "blocks.16.attn.attn_value.weight": "model-00003-of-00006.safetensors",
98
+ "blocks.16.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
99
+ "blocks.16.attn_gain.a_g": "model-00003-of-00006.safetensors",
100
+ "blocks.16.embed_skip.a_g": "model-00003-of-00006.safetensors",
101
+ "blocks.16.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
102
+ "blocks.16.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
103
+ "blocks.16.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
104
+ "blocks.16.mlp_gain.a_g": "model-00003-of-00006.safetensors",
105
+ "blocks.17.attn.attn_key.weight": "model-00003-of-00006.safetensors",
106
+ "blocks.17.attn.attn_query.weight": "model-00003-of-00006.safetensors",
107
+ "blocks.17.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
108
+ "blocks.17.attn.attn_value.weight": "model-00003-of-00006.safetensors",
109
+ "blocks.17.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
110
+ "blocks.17.attn_gain.a_g": "model-00003-of-00006.safetensors",
111
+ "blocks.17.embed_skip.a_g": "model-00003-of-00006.safetensors",
112
+ "blocks.17.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
113
+ "blocks.17.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
114
+ "blocks.17.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
115
+ "blocks.17.mlp_gain.a_g": "model-00003-of-00006.safetensors",
116
+ "blocks.18.attn.attn_key.weight": "model-00003-of-00006.safetensors",
117
+ "blocks.18.attn.attn_query.weight": "model-00003-of-00006.safetensors",
118
+ "blocks.18.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
119
+ "blocks.18.attn.attn_value.weight": "model-00003-of-00006.safetensors",
120
+ "blocks.18.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
121
+ "blocks.18.attn_gain.a_g": "model-00003-of-00006.safetensors",
122
+ "blocks.18.embed_skip.a_g": "model-00003-of-00006.safetensors",
123
+ "blocks.18.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
124
+ "blocks.18.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
125
+ "blocks.18.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
126
+ "blocks.18.mlp_gain.a_g": "model-00003-of-00006.safetensors",
127
+ "blocks.19.attn.attn_key.weight": "model-00003-of-00006.safetensors",
128
+ "blocks.19.attn.attn_query.weight": "model-00003-of-00006.safetensors",
129
+ "blocks.19.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
130
+ "blocks.19.attn.attn_value.weight": "model-00003-of-00006.safetensors",
131
+ "blocks.19.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
132
+ "blocks.19.attn_gain.a_g": "model-00003-of-00006.safetensors",
133
+ "blocks.19.embed_skip.a_g": "model-00003-of-00006.safetensors",
134
+ "blocks.19.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
135
+ "blocks.19.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
136
+ "blocks.19.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
137
+ "blocks.19.mlp_gain.a_g": "model-00003-of-00006.safetensors",
138
+ "blocks.2.attn.attn_key.weight": "model-00001-of-00006.safetensors",
139
+ "blocks.2.attn.attn_query.weight": "model-00001-of-00006.safetensors",
140
+ "blocks.2.attn.attn_resid.weight": "model-00001-of-00006.safetensors",
141
+ "blocks.2.attn.attn_value.weight": "model-00001-of-00006.safetensors",
142
+ "blocks.2.attn.head_gain.head_g": "model-00001-of-00006.safetensors",
143
+ "blocks.2.attn_gain.a_g": "model-00001-of-00006.safetensors",
144
+ "blocks.2.embed_skip.a_g": "model-00001-of-00006.safetensors",
145
+ "blocks.2.mlp.mlp_gate.weight": "model-00001-of-00006.safetensors",
146
+ "blocks.2.mlp.mlp_linear.weight": "model-00001-of-00006.safetensors",
147
+ "blocks.2.mlp.mlp_resid.weight": "model-00001-of-00006.safetensors",
148
+ "blocks.2.mlp_gain.a_g": "model-00001-of-00006.safetensors",
149
+ "blocks.20.attn.attn_key.weight": "model-00003-of-00006.safetensors",
150
+ "blocks.20.attn.attn_query.weight": "model-00003-of-00006.safetensors",
151
+ "blocks.20.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
152
+ "blocks.20.attn.attn_value.weight": "model-00003-of-00006.safetensors",
153
+ "blocks.20.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
154
+ "blocks.20.attn_gain.a_g": "model-00003-of-00006.safetensors",
155
+ "blocks.20.embed_skip.a_g": "model-00003-of-00006.safetensors",
156
+ "blocks.20.mlp.mlp_gate.weight": "model-00003-of-00006.safetensors",
157
+ "blocks.20.mlp.mlp_linear.weight": "model-00003-of-00006.safetensors",
158
+ "blocks.20.mlp.mlp_resid.weight": "model-00003-of-00006.safetensors",
159
+ "blocks.20.mlp_gain.a_g": "model-00003-of-00006.safetensors",
160
+ "blocks.21.attn.attn_key.weight": "model-00003-of-00006.safetensors",
161
+ "blocks.21.attn.attn_query.weight": "model-00003-of-00006.safetensors",
162
+ "blocks.21.attn.attn_resid.weight": "model-00003-of-00006.safetensors",
163
+ "blocks.21.attn.attn_value.weight": "model-00003-of-00006.safetensors",
164
+ "blocks.21.attn.head_gain.head_g": "model-00003-of-00006.safetensors",
165
+ "blocks.21.attn_gain.a_g": "model-00003-of-00006.safetensors",
166
+ "blocks.21.embed_skip.a_g": "model-00004-of-00006.safetensors",
167
+ "blocks.21.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
168
+ "blocks.21.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
169
+ "blocks.21.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
170
+ "blocks.21.mlp_gain.a_g": "model-00004-of-00006.safetensors",
171
+ "blocks.22.attn.attn_key.weight": "model-00004-of-00006.safetensors",
172
+ "blocks.22.attn.attn_query.weight": "model-00004-of-00006.safetensors",
173
+ "blocks.22.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
174
+ "blocks.22.attn.attn_value.weight": "model-00004-of-00006.safetensors",
175
+ "blocks.22.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
176
+ "blocks.22.attn_gain.a_g": "model-00004-of-00006.safetensors",
177
+ "blocks.22.embed_skip.a_g": "model-00004-of-00006.safetensors",
178
+ "blocks.22.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
179
+ "blocks.22.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
180
+ "blocks.22.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
181
+ "blocks.22.mlp_gain.a_g": "model-00004-of-00006.safetensors",
182
+ "blocks.23.attn.attn_key.weight": "model-00004-of-00006.safetensors",
183
+ "blocks.23.attn.attn_query.weight": "model-00004-of-00006.safetensors",
184
+ "blocks.23.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
185
+ "blocks.23.attn.attn_value.weight": "model-00004-of-00006.safetensors",
186
+ "blocks.23.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
187
+ "blocks.23.attn_gain.a_g": "model-00004-of-00006.safetensors",
188
+ "blocks.23.embed_skip.a_g": "model-00004-of-00006.safetensors",
189
+ "blocks.23.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
190
+ "blocks.23.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
191
+ "blocks.23.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
192
+ "blocks.23.mlp_gain.a_g": "model-00004-of-00006.safetensors",
193
+ "blocks.24.attn.attn_key.weight": "model-00004-of-00006.safetensors",
194
+ "blocks.24.attn.attn_query.weight": "model-00004-of-00006.safetensors",
195
+ "blocks.24.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
196
+ "blocks.24.attn.attn_value.weight": "model-00004-of-00006.safetensors",
197
+ "blocks.24.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
198
+ "blocks.24.attn_gain.a_g": "model-00004-of-00006.safetensors",
199
+ "blocks.24.embed_skip.a_g": "model-00004-of-00006.safetensors",
200
+ "blocks.24.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
201
+ "blocks.24.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
202
+ "blocks.24.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
203
+ "blocks.24.mlp_gain.a_g": "model-00004-of-00006.safetensors",
204
+ "blocks.25.attn.attn_key.weight": "model-00004-of-00006.safetensors",
205
+ "blocks.25.attn.attn_query.weight": "model-00004-of-00006.safetensors",
206
+ "blocks.25.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
207
+ "blocks.25.attn.attn_value.weight": "model-00004-of-00006.safetensors",
208
+ "blocks.25.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
209
+ "blocks.25.attn_gain.a_g": "model-00004-of-00006.safetensors",
210
+ "blocks.25.embed_skip.a_g": "model-00004-of-00006.safetensors",
211
+ "blocks.25.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
212
+ "blocks.25.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
213
+ "blocks.25.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
214
+ "blocks.25.mlp_gain.a_g": "model-00004-of-00006.safetensors",
215
+ "blocks.26.attn.attn_key.weight": "model-00004-of-00006.safetensors",
216
+ "blocks.26.attn.attn_query.weight": "model-00004-of-00006.safetensors",
217
+ "blocks.26.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
218
+ "blocks.26.attn.attn_value.weight": "model-00004-of-00006.safetensors",
219
+ "blocks.26.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
220
+ "blocks.26.attn_gain.a_g": "model-00004-of-00006.safetensors",
221
+ "blocks.26.embed_skip.a_g": "model-00004-of-00006.safetensors",
222
+ "blocks.26.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
223
+ "blocks.26.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
224
+ "blocks.26.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
225
+ "blocks.26.mlp_gain.a_g": "model-00004-of-00006.safetensors",
226
+ "blocks.27.attn.attn_key.weight": "model-00004-of-00006.safetensors",
227
+ "blocks.27.attn.attn_query.weight": "model-00004-of-00006.safetensors",
228
+ "blocks.27.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
229
+ "blocks.27.attn.attn_value.weight": "model-00004-of-00006.safetensors",
230
+ "blocks.27.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
231
+ "blocks.27.attn_gain.a_g": "model-00004-of-00006.safetensors",
232
+ "blocks.27.embed_skip.a_g": "model-00004-of-00006.safetensors",
233
+ "blocks.27.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
234
+ "blocks.27.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
235
+ "blocks.27.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
236
+ "blocks.27.mlp_gain.a_g": "model-00004-of-00006.safetensors",
237
+ "blocks.28.attn.attn_key.weight": "model-00004-of-00006.safetensors",
238
+ "blocks.28.attn.attn_query.weight": "model-00004-of-00006.safetensors",
239
+ "blocks.28.attn.attn_resid.weight": "model-00004-of-00006.safetensors",
240
+ "blocks.28.attn.attn_value.weight": "model-00004-of-00006.safetensors",
241
+ "blocks.28.attn.head_gain.head_g": "model-00004-of-00006.safetensors",
242
+ "blocks.28.attn_gain.a_g": "model-00004-of-00006.safetensors",
243
+ "blocks.28.embed_skip.a_g": "model-00004-of-00006.safetensors",
244
+ "blocks.28.mlp.mlp_gate.weight": "model-00004-of-00006.safetensors",
245
+ "blocks.28.mlp.mlp_linear.weight": "model-00004-of-00006.safetensors",
246
+ "blocks.28.mlp.mlp_resid.weight": "model-00004-of-00006.safetensors",
247
+ "blocks.28.mlp_gain.a_g": "model-00004-of-00006.safetensors",
248
+ "blocks.29.attn.attn_key.weight": "model-00004-of-00006.safetensors",
249
+ "blocks.29.attn.attn_query.weight": "model-00004-of-00006.safetensors",
250
+ "blocks.29.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
251
+ "blocks.29.attn.attn_value.weight": "model-00004-of-00006.safetensors",
252
+ "blocks.29.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
253
+ "blocks.29.attn_gain.a_g": "model-00005-of-00006.safetensors",
254
+ "blocks.29.embed_skip.a_g": "model-00005-of-00006.safetensors",
255
+ "blocks.29.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
256
+ "blocks.29.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
257
+ "blocks.29.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
258
+ "blocks.29.mlp_gain.a_g": "model-00005-of-00006.safetensors",
259
+ "blocks.3.attn.attn_key.weight": "model-00001-of-00006.safetensors",
260
+ "blocks.3.attn.attn_query.weight": "model-00001-of-00006.safetensors",
261
+ "blocks.3.attn.attn_resid.weight": "model-00001-of-00006.safetensors",
262
+ "blocks.3.attn.attn_value.weight": "model-00001-of-00006.safetensors",
263
+ "blocks.3.attn.head_gain.head_g": "model-00001-of-00006.safetensors",
264
+ "blocks.3.attn_gain.a_g": "model-00001-of-00006.safetensors",
265
+ "blocks.3.embed_skip.a_g": "model-00001-of-00006.safetensors",
266
+ "blocks.3.mlp.mlp_gate.weight": "model-00001-of-00006.safetensors",
267
+ "blocks.3.mlp.mlp_linear.weight": "model-00001-of-00006.safetensors",
268
+ "blocks.3.mlp.mlp_resid.weight": "model-00001-of-00006.safetensors",
269
+ "blocks.3.mlp_gain.a_g": "model-00001-of-00006.safetensors",
270
+ "blocks.30.attn.attn_key.weight": "model-00005-of-00006.safetensors",
271
+ "blocks.30.attn.attn_query.weight": "model-00005-of-00006.safetensors",
272
+ "blocks.30.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
273
+ "blocks.30.attn.attn_value.weight": "model-00005-of-00006.safetensors",
274
+ "blocks.30.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
275
+ "blocks.30.attn_gain.a_g": "model-00005-of-00006.safetensors",
276
+ "blocks.30.embed_skip.a_g": "model-00005-of-00006.safetensors",
277
+ "blocks.30.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
278
+ "blocks.30.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
279
+ "blocks.30.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
280
+ "blocks.30.mlp_gain.a_g": "model-00005-of-00006.safetensors",
281
+ "blocks.31.attn.attn_key.weight": "model-00005-of-00006.safetensors",
282
+ "blocks.31.attn.attn_query.weight": "model-00005-of-00006.safetensors",
283
+ "blocks.31.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
284
+ "blocks.31.attn.attn_value.weight": "model-00005-of-00006.safetensors",
285
+ "blocks.31.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
286
+ "blocks.31.attn_gain.a_g": "model-00005-of-00006.safetensors",
287
+ "blocks.31.embed_skip.a_g": "model-00005-of-00006.safetensors",
288
+ "blocks.31.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
289
+ "blocks.31.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
290
+ "blocks.31.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
291
+ "blocks.31.mlp_gain.a_g": "model-00005-of-00006.safetensors",
292
+ "blocks.32.attn.attn_key.weight": "model-00005-of-00006.safetensors",
293
+ "blocks.32.attn.attn_query.weight": "model-00005-of-00006.safetensors",
294
+ "blocks.32.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
295
+ "blocks.32.attn.attn_value.weight": "model-00005-of-00006.safetensors",
296
+ "blocks.32.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
297
+ "blocks.32.attn_gain.a_g": "model-00005-of-00006.safetensors",
298
+ "blocks.32.embed_skip.a_g": "model-00005-of-00006.safetensors",
299
+ "blocks.32.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
300
+ "blocks.32.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
301
+ "blocks.32.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
302
+ "blocks.32.mlp_gain.a_g": "model-00005-of-00006.safetensors",
303
+ "blocks.33.attn.attn_key.weight": "model-00005-of-00006.safetensors",
304
+ "blocks.33.attn.attn_query.weight": "model-00005-of-00006.safetensors",
305
+ "blocks.33.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
306
+ "blocks.33.attn.attn_value.weight": "model-00005-of-00006.safetensors",
307
+ "blocks.33.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
308
+ "blocks.33.attn_gain.a_g": "model-00005-of-00006.safetensors",
309
+ "blocks.33.embed_skip.a_g": "model-00005-of-00006.safetensors",
310
+ "blocks.33.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
311
+ "blocks.33.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
312
+ "blocks.33.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
313
+ "blocks.33.mlp_gain.a_g": "model-00005-of-00006.safetensors",
314
+ "blocks.34.attn.attn_key.weight": "model-00005-of-00006.safetensors",
315
+ "blocks.34.attn.attn_query.weight": "model-00005-of-00006.safetensors",
316
+ "blocks.34.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
317
+ "blocks.34.attn.attn_value.weight": "model-00005-of-00006.safetensors",
318
+ "blocks.34.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
319
+ "blocks.34.attn_gain.a_g": "model-00005-of-00006.safetensors",
320
+ "blocks.34.embed_skip.a_g": "model-00005-of-00006.safetensors",
321
+ "blocks.34.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
322
+ "blocks.34.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
323
+ "blocks.34.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
324
+ "blocks.34.mlp_gain.a_g": "model-00005-of-00006.safetensors",
325
+ "blocks.35.attn.attn_key.weight": "model-00005-of-00006.safetensors",
326
+ "blocks.35.attn.attn_query.weight": "model-00005-of-00006.safetensors",
327
+ "blocks.35.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
328
+ "blocks.35.attn.attn_value.weight": "model-00005-of-00006.safetensors",
329
+ "blocks.35.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
330
+ "blocks.35.attn_gain.a_g": "model-00005-of-00006.safetensors",
331
+ "blocks.35.embed_skip.a_g": "model-00005-of-00006.safetensors",
332
+ "blocks.35.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
333
+ "blocks.35.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
334
+ "blocks.35.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
335
+ "blocks.35.mlp_gain.a_g": "model-00005-of-00006.safetensors",
336
+ "blocks.36.attn.attn_key.weight": "model-00005-of-00006.safetensors",
337
+ "blocks.36.attn.attn_query.weight": "model-00005-of-00006.safetensors",
338
+ "blocks.36.attn.attn_resid.weight": "model-00005-of-00006.safetensors",
339
+ "blocks.36.attn.attn_value.weight": "model-00005-of-00006.safetensors",
340
+ "blocks.36.attn.head_gain.head_g": "model-00005-of-00006.safetensors",
341
+ "blocks.36.attn_gain.a_g": "model-00005-of-00006.safetensors",
342
+ "blocks.36.embed_skip.a_g": "model-00005-of-00006.safetensors",
343
+ "blocks.36.mlp.mlp_gate.weight": "model-00005-of-00006.safetensors",
344
+ "blocks.36.mlp.mlp_linear.weight": "model-00005-of-00006.safetensors",
345
+ "blocks.36.mlp.mlp_resid.weight": "model-00005-of-00006.safetensors",
346
+ "blocks.36.mlp_gain.a_g": "model-00005-of-00006.safetensors",
347
+ "blocks.37.attn.attn_key.weight": "model-00005-of-00006.safetensors",
348
+ "blocks.37.attn.attn_query.weight": "model-00005-of-00006.safetensors",
349
+ "blocks.37.attn.attn_resid.weight": "model-00006-of-00006.safetensors",
350
+ "blocks.37.attn.attn_value.weight": "model-00006-of-00006.safetensors",
351
+ "blocks.37.attn.head_gain.head_g": "model-00006-of-00006.safetensors",
352
+ "blocks.37.attn_gain.a_g": "model-00006-of-00006.safetensors",
353
+ "blocks.37.embed_skip.a_g": "model-00006-of-00006.safetensors",
354
+ "blocks.37.mlp.mlp_gate.weight": "model-00006-of-00006.safetensors",
355
+ "blocks.37.mlp.mlp_linear.weight": "model-00006-of-00006.safetensors",
356
+ "blocks.37.mlp.mlp_resid.weight": "model-00006-of-00006.safetensors",
357
+ "blocks.37.mlp_gain.a_g": "model-00006-of-00006.safetensors",
358
+ "blocks.38.attn.attn_key.weight": "model-00006-of-00006.safetensors",
359
+ "blocks.38.attn.attn_query.weight": "model-00006-of-00006.safetensors",
360
+ "blocks.38.attn.attn_resid.weight": "model-00006-of-00006.safetensors",
361
+ "blocks.38.attn.attn_value.weight": "model-00006-of-00006.safetensors",
362
+ "blocks.38.attn.head_gain.head_g": "model-00006-of-00006.safetensors",
363
+ "blocks.38.attn_gain.a_g": "model-00006-of-00006.safetensors",
364
+ "blocks.38.embed_skip.a_g": "model-00006-of-00006.safetensors",
365
+ "blocks.38.mlp.mlp_gate.weight": "model-00006-of-00006.safetensors",
366
+ "blocks.38.mlp.mlp_linear.weight": "model-00006-of-00006.safetensors",
367
+ "blocks.38.mlp.mlp_resid.weight": "model-00006-of-00006.safetensors",
368
+ "blocks.38.mlp_gain.a_g": "model-00006-of-00006.safetensors",
369
+ "blocks.39.attn.attn_key.weight": "model-00006-of-00006.safetensors",
370
+ "blocks.39.attn.attn_query.weight": "model-00006-of-00006.safetensors",
371
+ "blocks.39.attn.attn_resid.weight": "model-00006-of-00006.safetensors",
372
+ "blocks.39.attn.attn_value.weight": "model-00006-of-00006.safetensors",
373
+ "blocks.39.attn.head_gain.head_g": "model-00006-of-00006.safetensors",
374
+ "blocks.39.attn_gain.a_g": "model-00006-of-00006.safetensors",
375
+ "blocks.39.embed_skip.a_g": "model-00006-of-00006.safetensors",
376
+ "blocks.39.mlp.mlp_gate.weight": "model-00006-of-00006.safetensors",
377
+ "blocks.39.mlp.mlp_linear.weight": "model-00006-of-00006.safetensors",
378
+ "blocks.39.mlp.mlp_resid.weight": "model-00006-of-00006.safetensors",
379
+ "blocks.39.mlp_gain.a_g": "model-00006-of-00006.safetensors",
380
+ "blocks.4.attn.attn_key.weight": "model-00001-of-00006.safetensors",
381
+ "blocks.4.attn.attn_query.weight": "model-00001-of-00006.safetensors",
382
+ "blocks.4.attn.attn_resid.weight": "model-00001-of-00006.safetensors",
383
+ "blocks.4.attn.attn_value.weight": "model-00001-of-00006.safetensors",
384
+ "blocks.4.attn.head_gain.head_g": "model-00001-of-00006.safetensors",
385
+ "blocks.4.attn_gain.a_g": "model-00001-of-00006.safetensors",
386
+ "blocks.4.embed_skip.a_g": "model-00001-of-00006.safetensors",
387
+ "blocks.4.mlp.mlp_gate.weight": "model-00001-of-00006.safetensors",
388
+ "blocks.4.mlp.mlp_linear.weight": "model-00001-of-00006.safetensors",
389
+ "blocks.4.mlp.mlp_resid.weight": "model-00001-of-00006.safetensors",
390
+ "blocks.4.mlp_gain.a_g": "model-00001-of-00006.safetensors",
391
+ "blocks.5.attn.attn_key.weight": "model-00001-of-00006.safetensors",
392
+ "blocks.5.attn.attn_query.weight": "model-00001-of-00006.safetensors",
393
+ "blocks.5.attn.attn_resid.weight": "model-00001-of-00006.safetensors",
394
+ "blocks.5.attn.attn_value.weight": "model-00001-of-00006.safetensors",
395
+ "blocks.5.attn.head_gain.head_g": "model-00001-of-00006.safetensors",
396
+ "blocks.5.attn_gain.a_g": "model-00001-of-00006.safetensors",
397
+ "blocks.5.embed_skip.a_g": "model-00002-of-00006.safetensors",
398
+ "blocks.5.mlp.mlp_gate.weight": "model-00001-of-00006.safetensors",
399
+ "blocks.5.mlp.mlp_linear.weight": "model-00001-of-00006.safetensors",
400
+ "blocks.5.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
401
+ "blocks.5.mlp_gain.a_g": "model-00002-of-00006.safetensors",
402
+ "blocks.6.attn.attn_key.weight": "model-00002-of-00006.safetensors",
403
+ "blocks.6.attn.attn_query.weight": "model-00002-of-00006.safetensors",
404
+ "blocks.6.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
405
+ "blocks.6.attn.attn_value.weight": "model-00002-of-00006.safetensors",
406
+ "blocks.6.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
407
+ "blocks.6.attn_gain.a_g": "model-00002-of-00006.safetensors",
408
+ "blocks.6.embed_skip.a_g": "model-00002-of-00006.safetensors",
409
+ "blocks.6.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
410
+ "blocks.6.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
411
+ "blocks.6.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
412
+ "blocks.6.mlp_gain.a_g": "model-00002-of-00006.safetensors",
413
+ "blocks.7.attn.attn_key.weight": "model-00002-of-00006.safetensors",
414
+ "blocks.7.attn.attn_query.weight": "model-00002-of-00006.safetensors",
415
+ "blocks.7.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
416
+ "blocks.7.attn.attn_value.weight": "model-00002-of-00006.safetensors",
417
+ "blocks.7.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
418
+ "blocks.7.attn_gain.a_g": "model-00002-of-00006.safetensors",
419
+ "blocks.7.embed_skip.a_g": "model-00002-of-00006.safetensors",
420
+ "blocks.7.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
421
+ "blocks.7.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
422
+ "blocks.7.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
423
+ "blocks.7.mlp_gain.a_g": "model-00002-of-00006.safetensors",
424
+ "blocks.8.attn.attn_key.weight": "model-00002-of-00006.safetensors",
425
+ "blocks.8.attn.attn_query.weight": "model-00002-of-00006.safetensors",
426
+ "blocks.8.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
427
+ "blocks.8.attn.attn_value.weight": "model-00002-of-00006.safetensors",
428
+ "blocks.8.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
429
+ "blocks.8.attn_gain.a_g": "model-00002-of-00006.safetensors",
430
+ "blocks.8.embed_skip.a_g": "model-00002-of-00006.safetensors",
431
+ "blocks.8.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
432
+ "blocks.8.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
433
+ "blocks.8.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
434
+ "blocks.8.mlp_gain.a_g": "model-00002-of-00006.safetensors",
435
+ "blocks.9.attn.attn_key.weight": "model-00002-of-00006.safetensors",
436
+ "blocks.9.attn.attn_query.weight": "model-00002-of-00006.safetensors",
437
+ "blocks.9.attn.attn_resid.weight": "model-00002-of-00006.safetensors",
438
+ "blocks.9.attn.attn_value.weight": "model-00002-of-00006.safetensors",
439
+ "blocks.9.attn.head_gain.head_g": "model-00002-of-00006.safetensors",
440
+ "blocks.9.attn_gain.a_g": "model-00002-of-00006.safetensors",
441
+ "blocks.9.embed_skip.a_g": "model-00002-of-00006.safetensors",
442
+ "blocks.9.mlp.mlp_gate.weight": "model-00002-of-00006.safetensors",
443
+ "blocks.9.mlp.mlp_linear.weight": "model-00002-of-00006.safetensors",
444
+ "blocks.9.mlp.mlp_resid.weight": "model-00002-of-00006.safetensors",
445
+ "blocks.9.mlp_gain.a_g": "model-00002-of-00006.safetensors",
446
+ "embed.weight": "model-00001-of-00006.safetensors",
447
+ "lm_head.weight": "model-00001-of-00006.safetensors"
448
+ }
449
+ }
modeling_talkie.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.cache_utils import Cache, DynamicCache
7
+ from transformers import GenerationMixin
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
10
+
11
+ try:
12
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
13
+ except ImportError: # pragma: no cover - compatibility with older Transformers.
14
+ ALL_ATTENTION_FUNCTIONS = None
15
+
16
+ from .configuration_talkie import TalkieConfig
17
+
18
+
19
+ def eager_attention_forward(
20
+ module: nn.Module,
21
+ query: torch.Tensor,
22
+ key: torch.Tensor,
23
+ value: torch.Tensor,
24
+ attention_mask: torch.Tensor | None,
25
+ dropout: float = 0.0,
26
+ scaling: float | None = None,
27
+ is_causal: bool | None = None,
28
+ **kwargs,
29
+ ) -> tuple[torch.Tensor, None]:
30
+ del kwargs
31
+ is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
32
+ output = F.scaled_dot_product_attention(
33
+ query,
34
+ key,
35
+ value,
36
+ attn_mask=attention_mask,
37
+ dropout_p=dropout,
38
+ scale=scaling,
39
+ is_causal=is_causal and attention_mask is None,
40
+ )
41
+ return output.transpose(1, 2).contiguous(), None
42
+
43
+
44
+ def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
45
+ d = x.shape[3] // 2
46
+ x1 = x[..., :d]
47
+ x2 = x[..., d:]
48
+ y1 = x1 * cos + x2 * sin
49
+ y2 = x1 * (-sin) + x2 * cos
50
+ return torch.cat([y1, y2], 3).type_as(x)
51
+
52
+
53
+ class HeadGain(nn.Module):
54
+ def __init__(self, n_head: int):
55
+ super().__init__()
56
+ self.head_g = nn.Parameter(torch.ones([n_head]))
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ return x * self.head_g.type_as(x).view(1, 1, -1, 1)
60
+
61
+
62
+ class WeightGain(nn.Module):
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.w_g = nn.Parameter(torch.ones(1))
66
+
67
+ def forward(self, w: torch.Tensor) -> torch.Tensor:
68
+ return w * self.w_g.type_as(w)
69
+
70
+
71
+ class ActGain(nn.Module):
72
+ def __init__(self, init_value: float):
73
+ super().__init__()
74
+ self.a_g = nn.Parameter(torch.ones(1) * init_value)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ return x * self.a_g.type_as(x)
78
+
79
+
80
+ class CausalSelfAttention(nn.Module):
81
+ is_causal = True
82
+
83
+ def __init__(self, config: TalkieConfig, layer_idx: int):
84
+ super().__init__()
85
+ self.config = config
86
+ self.layer_idx = layer_idx
87
+ self.n_head = config.n_head
88
+ self.head_dim = config.head_dim
89
+ n_state = config.n_embd
90
+
91
+ self.attn_query = nn.Linear(n_state, n_state, bias=False)
92
+ self.attn_key = nn.Linear(n_state, n_state, bias=False)
93
+ self.attn_value = nn.Linear(n_state, n_state, bias=False)
94
+ self.attn_resid = nn.Linear(n_state, n_state, bias=False)
95
+ self.head_gain = HeadGain(config.n_head)
96
+
97
+ def forward(
98
+ self,
99
+ x: torch.Tensor,
100
+ cos_sin: tuple[torch.Tensor, torch.Tensor],
101
+ attention_mask: torch.Tensor | None = None,
102
+ **kwargs,
103
+ ) -> torch.Tensor:
104
+ bsz, seq_len, _ = x.size()
105
+ q = self.attn_query(x).view(bsz, seq_len, self.n_head, self.head_dim)
106
+ k = self.attn_key(x).view(bsz, seq_len, self.n_head, self.head_dim)
107
+ v = self.attn_value(x).view(bsz, seq_len, self.n_head, self.head_dim)
108
+
109
+ cos, sin = cos_sin
110
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
111
+ q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))
112
+ q = self.head_gain(q)
113
+
114
+ key_states = k.transpose(1, 2)
115
+ value_states = v.transpose(1, 2)
116
+ if kwargs.get("past_key_values") is not None:
117
+ key_states, value_states = kwargs["past_key_values"].update(
118
+ key_states, value_states, self.layer_idx
119
+ )
120
+
121
+ if ALL_ATTENTION_FUNCTIONS is None:
122
+ attention_interface = eager_attention_forward
123
+ elif hasattr(ALL_ATTENTION_FUNCTIONS, "get_interface"):
124
+ attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
125
+ self.config._attn_implementation, eager_attention_forward
126
+ )
127
+ else: # pragma: no cover - compatibility with older Transformers.
128
+ attention_interface = ALL_ATTENTION_FUNCTIONS.get(
129
+ self.config._attn_implementation, eager_attention_forward
130
+ )
131
+ is_causal = attention_mask is None and key_states.shape[-2] == q.shape[1]
132
+ y, _ = attention_interface(
133
+ self,
134
+ q.transpose(1, 2),
135
+ key_states,
136
+ value_states,
137
+ attention_mask,
138
+ is_causal=is_causal,
139
+ **kwargs,
140
+ )
141
+ y = y.contiguous().view_as(x)
142
+ return self.attn_resid(y)
143
+
144
+
145
+ class MLP(nn.Module):
146
+ def __init__(self, config: TalkieConfig):
147
+ super().__init__()
148
+ n_state = config.n_embd
149
+ n_mlp = int(round(((8 / 3) * n_state) / 128) * 128)
150
+
151
+ self.mlp_gate = nn.Linear(n_state, n_mlp, bias=False)
152
+ self.mlp_linear = nn.Linear(n_state, n_mlp, bias=False)
153
+ self.mlp_resid = nn.Linear(n_mlp, n_state, bias=False)
154
+
155
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
156
+ x = F.silu(self.mlp_gate(x)) * self.mlp_linear(x)
157
+ return self.mlp_resid(x)
158
+
159
+
160
+ class Block(nn.Module):
161
+ def __init__(self, config: TalkieConfig, layer_idx: int):
162
+ super().__init__()
163
+ self.attn = CausalSelfAttention(config, layer_idx)
164
+ self.attn_gain = ActGain((2 * config.n_layer) ** -0.5)
165
+ self.mlp = MLP(config)
166
+ self.mlp_gain = ActGain((2 * config.n_layer) ** -0.5)
167
+ self.embed_skip = ActGain(0.0)
168
+
169
+ def forward(
170
+ self,
171
+ e_x: torch.Tensor,
172
+ x: torch.Tensor,
173
+ cos_sin: tuple[torch.Tensor, torch.Tensor],
174
+ attention_mask: torch.Tensor | None = None,
175
+ **kwargs,
176
+ ) -> torch.Tensor:
177
+ x = x + self.attn_gain(
178
+ self.attn(F.rms_norm(x, (x.shape[-1],)), cos_sin, attention_mask, **kwargs)
179
+ )
180
+ x = x + self.mlp_gain(self.mlp(F.rms_norm(x, (x.shape[-1],))))
181
+ x = x + self.embed_skip(e_x)
182
+ return x
183
+
184
+
185
+ class TalkiePreTrainedModel(PreTrainedModel):
186
+ config_class = TalkieConfig
187
+ base_model_prefix = ""
188
+ supports_gradient_checkpointing = False
189
+ _supports_sdpa = True
190
+ _supports_attention_backend = True
191
+ _no_split_modules = ["Block"]
192
+ _tied_weights_keys = None
193
+
194
+ def _init_weights(self, module: nn.Module) -> None:
195
+ return
196
+
197
+
198
+ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
199
+ def __init__(self, config: TalkieConfig):
200
+ super().__init__(config)
201
+ self.embed = nn.Embedding(config.vocab_size, config.n_embd)
202
+ self.blocks = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
203
+
204
+ cos, sin = self._precompute_rotary_embeddings(
205
+ config.max_position_embeddings, config.head_dim, config.rope_base
206
+ )
207
+ self.register_buffer("cos", cos, persistent=False)
208
+ self.register_buffer("sin", sin, persistent=False)
209
+ self._rotary_initialized = cos.device.type != "meta"
210
+ self.post_init()
211
+
212
+ def _precompute_rotary_embeddings(
213
+ self, seq_len: int, head_dim: int, base: int
214
+ ) -> tuple[torch.Tensor, torch.Tensor]:
215
+ device = self.embed.weight.device if hasattr(self, "embed") else "cpu"
216
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
217
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
218
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
219
+ freqs = torch.outer(t, inv_freq)
220
+ cos, sin = freqs.cos(), freqs.sin()
221
+ cos, sin = cos.bfloat16(), sin.bfloat16()
222
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :]
223
+ return cos, sin
224
+
225
+ def _ensure_rotary_embeddings(self, seq_len: int) -> None:
226
+ device = self.embed.weight.device
227
+ needs_init = (
228
+ not self._rotary_initialized
229
+ or self.cos.device != device
230
+ or self.sin.device != device
231
+ or self.cos.shape[1] < seq_len
232
+ )
233
+ if needs_init:
234
+ max_seq_len = max(seq_len, self.config.max_position_embeddings)
235
+ cos, sin = self._precompute_rotary_embeddings(
236
+ max_seq_len, self.config.head_dim, self.config.rope_base
237
+ )
238
+ self.cos = cos.to(device=device)
239
+ self.sin = sin.to(device=device)
240
+ self._rotary_initialized = True
241
+
242
+ def get_input_embeddings(self) -> nn.Embedding:
243
+ return self.embed
244
+
245
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
246
+ self.embed = value
247
+
248
+ def _position_ids(
249
+ self,
250
+ input_ids: torch.LongTensor,
251
+ position_ids: torch.LongTensor | None = None,
252
+ cache_position: torch.LongTensor | None = None,
253
+ past_key_values: Cache | None = None,
254
+ ) -> torch.LongTensor:
255
+ batch_size, seq_len = input_ids.shape
256
+ if position_ids is not None:
257
+ if position_ids.dim() == 1:
258
+ position_ids = position_ids.unsqueeze(0)
259
+ return position_ids.to(device=input_ids.device, dtype=torch.long)
260
+ if cache_position is not None:
261
+ if cache_position.dim() == 1:
262
+ cache_position = cache_position.unsqueeze(0)
263
+ if cache_position.shape[0] == 1 and batch_size != 1:
264
+ cache_position = cache_position.expand(batch_size, -1)
265
+ return cache_position.to(device=input_ids.device, dtype=torch.long)
266
+ past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
267
+ position_ids = torch.arange(seq_len, device=input_ids.device, dtype=torch.long) + past_seen
268
+ return position_ids.unsqueeze(0)
269
+
270
+ def _attention_mask(
271
+ self,
272
+ attention_mask: torch.Tensor | None,
273
+ input_ids: torch.Tensor,
274
+ position_ids: torch.Tensor,
275
+ past_key_values: Cache | None,
276
+ dtype: torch.dtype,
277
+ ) -> torch.Tensor | None:
278
+ if attention_mask is not None and attention_mask.dim() >= 4:
279
+ return attention_mask
280
+ batch_size, query_length = input_ids.shape
281
+ past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
282
+ key_length = past_seen + query_length
283
+
284
+ if attention_mask is not None and attention_mask.dim() != 2:
285
+ return attention_mask
286
+ if attention_mask is not None:
287
+ if attention_mask.shape[-1] == query_length and past_seen:
288
+ prefix = torch.ones(
289
+ attention_mask.shape[0],
290
+ past_seen,
291
+ dtype=attention_mask.dtype,
292
+ device=attention_mask.device,
293
+ )
294
+ attention_mask = torch.cat([prefix, attention_mask], dim=-1)
295
+ key_length = attention_mask.shape[-1]
296
+ has_padding = not bool(torch.all(attention_mask == 1))
297
+ else:
298
+ has_padding = False
299
+
300
+ if attention_mask is None and past_seen == 0:
301
+ return None
302
+
303
+ key_positions = torch.arange(key_length, device=input_ids.device, dtype=torch.long)
304
+ future_mask = key_positions.view(1, 1, 1, key_length) > position_ids.view(
305
+ batch_size, 1, query_length, 1
306
+ )
307
+ if attention_mask is not None and has_padding:
308
+ padding_mask = attention_mask[:, None, None, :].to(device=input_ids.device) == 0
309
+ mask = future_mask | padding_mask
310
+ else:
311
+ mask = future_mask
312
+
313
+ if not bool(mask.any()):
314
+ return None
315
+ min_value = torch.finfo(dtype).min
316
+ causal_mask = torch.zeros(
317
+ batch_size, 1, query_length, key_length, dtype=dtype, device=input_ids.device
318
+ )
319
+ return causal_mask.masked_fill(mask, min_value)
320
+
321
+ def forward(
322
+ self,
323
+ input_ids: torch.LongTensor | None = None,
324
+ inputs_embeds: torch.FloatTensor | None = None,
325
+ attention_mask: torch.Tensor | None = None,
326
+ position_ids: torch.LongTensor | None = None,
327
+ past_key_values: Cache | None = None,
328
+ use_cache: bool | None = None,
329
+ return_dict: bool | None = None,
330
+ **kwargs,
331
+ ) -> BaseModelOutputWithPast | tuple[torch.Tensor, ...]:
332
+ cache_position = kwargs.pop("cache_position", None)
333
+ if input_ids is None and inputs_embeds is None:
334
+ raise ValueError("input_ids or inputs_embeds is required")
335
+ if input_ids is not None and inputs_embeds is not None:
336
+ raise ValueError("provide only one of input_ids or inputs_embeds")
337
+ if input_ids is None:
338
+ input_ids = torch.empty(
339
+ inputs_embeds.shape[:2],
340
+ dtype=torch.long,
341
+ device=inputs_embeds.device,
342
+ )
343
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
344
+ if use_cache and past_key_values is None:
345
+ past_key_values = DynamicCache(config=self.config)
346
+
347
+ position_ids = self._position_ids(input_ids, position_ids, cache_position, past_key_values)
348
+ needed_seq_len = int(position_ids.max().item()) + 1
349
+ self._ensure_rotary_embeddings(needed_seq_len)
350
+ if needed_seq_len > self.cos.shape[1]:
351
+ raise ValueError(
352
+ f"Sequence length {needed_seq_len} exceeds max_position_embeddings "
353
+ f"{self.cos.shape[1]}"
354
+ )
355
+
356
+ cos = self.cos[0, position_ids, :, :]
357
+ sin = self.sin[0, position_ids, :, :]
358
+ cos_sin = cos, sin
359
+ x = inputs_embeds if inputs_embeds is not None else self.embed(input_ids)
360
+ x = F.rms_norm(x, (x.shape[-1],))
361
+ attention_mask = self._attention_mask(attention_mask, input_ids, position_ids, past_key_values, x.dtype)
362
+ e_x = x
363
+ for block in self.blocks:
364
+ x = block(
365
+ e_x,
366
+ x,
367
+ cos_sin,
368
+ attention_mask=attention_mask,
369
+ past_key_values=past_key_values if use_cache else None,
370
+ **kwargs,
371
+ )
372
+ x = F.rms_norm(x, (x.shape[-1],))
373
+ past_key_values = past_key_values if use_cache else None
374
+ use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
375
+ if use_return_dict:
376
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values)
377
+ output = (x,)
378
+ return output + ((past_key_values,) if past_key_values is not None else ())
379
+
380
+
381
+ class TalkieForCausalLM(TalkieModel):
382
+ _tied_weights_keys = None
383
+
384
+ def __init__(self, config: TalkieConfig):
385
+ super().__init__(config)
386
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
387
+ self.post_init()
388
+
389
+ def get_output_embeddings(self) -> nn.Linear:
390
+ return self.lm_head
391
+
392
+ def set_output_embeddings(self, value: nn.Linear) -> None:
393
+ self.lm_head = value
394
+
395
+ def forward(
396
+ self,
397
+ input_ids: torch.LongTensor | None = None,
398
+ attention_mask: torch.Tensor | None = None,
399
+ inputs_embeds: torch.FloatTensor | None = None,
400
+ labels: torch.LongTensor | None = None,
401
+ return_dict: bool | None = None,
402
+ past_key_values: Cache | None = None,
403
+ use_cache: bool | None = None,
404
+ position_ids: torch.LongTensor | None = None,
405
+ logits_to_keep: int | torch.Tensor = 0,
406
+ **kwargs,
407
+ ) -> CausalLMOutputWithPast | tuple[torch.Tensor, ...]:
408
+ if input_ids is None and inputs_embeds is None:
409
+ raise ValueError("input_ids or inputs_embeds is required")
410
+ cache_position = kwargs.pop("cache_position", None)
411
+ outputs = super().forward(
412
+ input_ids,
413
+ inputs_embeds=inputs_embeds,
414
+ attention_mask=attention_mask,
415
+ position_ids=position_ids,
416
+ past_key_values=past_key_values,
417
+ use_cache=use_cache,
418
+ cache_position=cache_position,
419
+ return_dict=True,
420
+ **kwargs,
421
+ )
422
+ hidden_states = outputs.last_hidden_state
423
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
424
+ logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
425
+ if self.config.logit_scale != 1.0:
426
+ logits = logits * self.config.logit_scale
427
+
428
+ loss = None
429
+ if labels is not None:
430
+ shift_logits = logits[..., :-1, :].contiguous()
431
+ shift_labels = labels[..., 1:].contiguous()
432
+ loss = F.cross_entropy(
433
+ shift_logits.view(-1, shift_logits.size(-1)),
434
+ shift_labels.view(-1),
435
+ ignore_index=-100,
436
+ )
437
+
438
+ use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
439
+ if use_return_dict:
440
+ return CausalLMOutputWithPast(
441
+ loss=loss,
442
+ logits=logits,
443
+ past_key_values=outputs.past_key_values,
444
+ )
445
+ output = (logits,)
446
+ if outputs.past_key_values is not None:
447
+ output += (outputs.past_key_values,)
448
+ return ((loss,) + output) if loss is not None else output
449
+
450
+
451
+ __all__ = ["TalkieConfig", "TalkieForCausalLM", "TalkieModel"]
special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "eos_token": "<|endoftext|>"
3
+ }
tokenization_talkie.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import tiktoken
8
+ from tiktoken.load import load_tiktoken_bpe
9
+ from transformers import PreTrainedTokenizer
10
+
11
+
12
+ BASE_VOCAB_SIZE = 65536
13
+ IT_VOCAB_SIZE = BASE_VOCAB_SIZE + 4
14
+
15
+ _PAT_STR = "|".join(
16
+ [
17
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
18
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
19
+ r"""\p{N}{1,3}""",
20
+ r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
21
+ r"""\s*[\r\n]+""",
22
+ r"""\s+(?!\S)""",
23
+ r"""\s+""",
24
+ ]
25
+ )
26
+
27
+ _BASE_SPECIAL_TOKENS = {
28
+ "<|endoftext|>": BASE_VOCAB_SIZE - 1,
29
+ }
30
+
31
+ _IT_SPECIAL_TOKENS = {
32
+ "<|endoftext|>": BASE_VOCAB_SIZE - 1,
33
+ "<|end|>": BASE_VOCAB_SIZE,
34
+ "<|user|>": BASE_VOCAB_SIZE + 1,
35
+ "<|assistant|>": BASE_VOCAB_SIZE + 2,
36
+ "<|system|>": BASE_VOCAB_SIZE + 3,
37
+ }
38
+
39
+
40
+ class TalkieTokenizer(PreTrainedTokenizer):
41
+ vocab_files_names = {"vocab_file": "vocab.txt"}
42
+ model_input_names = ["input_ids", "attention_mask"]
43
+
44
+ def __init__(
45
+ self,
46
+ vocab_file: str,
47
+ style: str = "base",
48
+ model_max_length: int = 2048,
49
+ **kwargs,
50
+ ):
51
+ self.vocab_file = str(vocab_file)
52
+ self.style = style
53
+
54
+ mergeable_ranks = load_tiktoken_bpe(self.vocab_file)
55
+ mergeable_ranks = {
56
+ key: value for key, value in mergeable_ranks.items() if value < BASE_VOCAB_SIZE - 1
57
+ }
58
+ if style == "it":
59
+ special_tokens = dict(_IT_SPECIAL_TOKENS)
60
+ vocab_size = IT_VOCAB_SIZE
61
+ name = "talkie-it"
62
+ elif style == "base":
63
+ special_tokens = dict(_BASE_SPECIAL_TOKENS)
64
+ vocab_size = BASE_VOCAB_SIZE
65
+ name = "talkie-base"
66
+ else:
67
+ raise ValueError(f"unknown Talkie tokenizer style: {style!r}")
68
+
69
+ self.encoder = tiktoken.Encoding(
70
+ name=name,
71
+ pat_str=_PAT_STR,
72
+ mergeable_ranks=mergeable_ranks,
73
+ special_tokens=special_tokens,
74
+ )
75
+ self._vocab_size = vocab_size
76
+ self._special_token_to_id = special_tokens
77
+ self._id_to_special_token = {value: key for key, value in special_tokens.items()}
78
+
79
+ if style == "it":
80
+ kwargs.setdefault("eos_token", "<|end|>")
81
+ kwargs.setdefault(
82
+ "additional_special_tokens",
83
+ ["<|endoftext|>", "<|user|>", "<|assistant|>", "<|system|>"],
84
+ )
85
+ else:
86
+ kwargs.setdefault("eos_token", "<|endoftext|>")
87
+ super().__init__(model_max_length=model_max_length, **kwargs)
88
+
89
+ @property
90
+ def vocab_size(self) -> int:
91
+ return self._vocab_size
92
+
93
+ def get_vocab(self) -> dict[str, int]:
94
+ vocab = {str(index): index for index in range(self._vocab_size)}
95
+ vocab.update(self._special_token_to_id)
96
+ vocab.update(self.get_added_vocab())
97
+ return vocab
98
+
99
+ def _tokenize(self, text: str, **kwargs) -> list[str]:
100
+ return [str(token_id) for token_id in self.encoder.encode(text, allowed_special="all")]
101
+
102
+ def _convert_token_to_id(self, token: str) -> int:
103
+ if token in self._special_token_to_id:
104
+ return self._special_token_to_id[token]
105
+ try:
106
+ token_id = int(token)
107
+ except ValueError:
108
+ return self.eos_token_id
109
+ if 0 <= token_id < self._vocab_size:
110
+ return token_id
111
+ return self.eos_token_id
112
+
113
+ def _convert_id_to_token(self, index: int) -> str:
114
+ index = int(index)
115
+ return self._id_to_special_token.get(index, str(index))
116
+
117
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
118
+ ids = [self._convert_token_to_id(token) for token in tokens]
119
+ return self.encoder.decode(ids)
120
+
121
+ def _decode(
122
+ self,
123
+ token_ids,
124
+ skip_special_tokens: bool = False,
125
+ clean_up_tokenization_spaces: bool | None = None,
126
+ **kwargs,
127
+ ) -> str:
128
+ if isinstance(token_ids, int):
129
+ token_ids = [token_ids]
130
+ ids = [int(token_id) for token_id in token_ids]
131
+ if skip_special_tokens:
132
+ specials = set(self._special_token_to_id.values())
133
+ ids = [token_id for token_id in ids if token_id not in specials]
134
+ return self.encoder.decode(ids)
135
+
136
+ def build_inputs_with_special_tokens(
137
+ self, token_ids_0: list[int], token_ids_1: list[int] | None = None
138
+ ) -> list[int]:
139
+ if token_ids_1 is None:
140
+ return list(token_ids_0)
141
+ return list(token_ids_0) + list(token_ids_1)
142
+
143
+ def get_special_tokens_mask(
144
+ self,
145
+ token_ids_0: list[int],
146
+ token_ids_1: list[int] | None = None,
147
+ already_has_special_tokens: bool = False,
148
+ ) -> list[int]:
149
+ special_ids = set(self._special_token_to_id.values())
150
+ if already_has_special_tokens:
151
+ return [1 if token_id in special_ids else 0 for token_id in token_ids_0]
152
+ token_ids = list(token_ids_0) if token_ids_1 is None else list(token_ids_0) + list(token_ids_1)
153
+ return [1 if token_id in special_ids else 0 for token_id in token_ids]
154
+
155
+ def create_token_type_ids_from_sequences(
156
+ self, token_ids_0: list[int], token_ids_1: list[int] | None = None
157
+ ) -> list[int]:
158
+ length = len(token_ids_0) if token_ids_1 is None else len(token_ids_0) + len(token_ids_1)
159
+ return [0] * length
160
+
161
+ def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None):
162
+ if not os.path.isdir(save_directory):
163
+ raise ValueError(f"Vocabulary path {save_directory!r} is not a directory")
164
+ name = "vocab.txt" if filename_prefix is None else f"{filename_prefix}-vocab.txt"
165
+ out = Path(save_directory) / name
166
+ if Path(self.vocab_file).resolve() != out.resolve():
167
+ shutil.copyfile(self.vocab_file, out)
168
+ return (str(out),)
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "TalkieTokenizer",
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_talkie.TalkieTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "model_max_length": 2048,
10
+ "style": "base",
11
+ "eos_token": "<|endoftext|>"
12
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff