MK0727 commited on
Commit
134df9b
·
verified ·
1 Parent(s): 660391c

Upload lambda-160m pretrained model

Browse files
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MyLLMForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_myllm.MyLLMConfig",
7
+ "AutoModelForCausalLM": "modeling_myllm.MyLLMForCausalLM",
8
+ "AutoModel": "modeling_myllm.MyLLMForCausalLM"
9
+ },
10
+ "bos_token_id": 2,
11
+ "d_ff": 3072,
12
+ "d_model": 768,
13
+ "dtype": "float32",
14
+ "eos_token_id": 3,
15
+ "hidden_size": 768,
16
+ "intermediate_size": 3072,
17
+ "learning_rate": 0.0002,
18
+ "max_len": 1024,
19
+ "max_position_embeddings": 1024,
20
+ "model_type": "myllm",
21
+ "num_attention_heads": 12,
22
+ "num_heads": 12,
23
+ "num_hidden_layers": 16,
24
+ "num_layers": 16,
25
+ "pad_token_id": 0,
26
+ "tie_word_embeddings": true,
27
+ "transformers_version": "5.8.0",
28
+ "vocab_size": 65536
29
+ }
configuration_myllm.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedConfig
2
+
3
+
4
+ class MyLLMConfig(PreTrainedConfig):
5
+ model_type = "myllm"
6
+
7
+ def __init__(
8
+ self,
9
+ vocab_size: int = 4,
10
+ max_len: int = 6,
11
+ d_model: int = 2,
12
+ num_layers: int = 2,
13
+ num_heads: int = 1,
14
+ d_ff: int = 8,
15
+ learning_rate: float = 0.1,
16
+ pad_token_id: int = 0,
17
+ bos_token_id: int = 2,
18
+ eos_token_id: int = 3,
19
+ tie_word_embeddings: bool = True,
20
+ **kwargs: object,
21
+ ) -> None:
22
+ # ---------------------------------------------------------
23
+ # Store the architecture values needed to rebuild the
24
+ # PyTorch decoder-only Transformer during AutoModel loading.
25
+ # ---------------------------------------------------------
26
+ self.vocab_size = vocab_size
27
+ self.max_len = max_len
28
+ self.d_model = d_model
29
+ self.num_layers = num_layers
30
+ self.num_heads = num_heads
31
+ self.d_ff = d_ff
32
+ self.learning_rate = learning_rate
33
+ self.tie_word_embeddings = tie_word_embeddings
34
+ self.hidden_size = d_model
35
+ self.num_hidden_layers = num_layers
36
+ self.num_attention_heads = num_heads
37
+ self.intermediate_size = d_ff
38
+ self.max_position_embeddings = max_len
39
+
40
+ # ---------------------------------------------------------
41
+ # Pass standard token ids to the Transformers base config so
42
+ # generation utilities can resolve special tokens normally.
43
+ # ---------------------------------------------------------
44
+ super().__init__(
45
+ pad_token_id=pad_token_id,
46
+ bos_token_id=bos_token_id,
47
+ eos_token_id=eos_token_id,
48
+ **kwargs,
49
+ )
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 2,
4
+ "eos_token_id": 3,
5
+ "output_attentions": false,
6
+ "output_hidden_states": false,
7
+ "pad_token_id": 0,
8
+ "transformers_version": "5.8.0"
9
+ }
kv_cache.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch
2
+
3
+ LayerKeyValueCache = tuple[torch.Tensor, torch.Tensor]
4
+ KeyValueCache = list[LayerKeyValueCache]
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccf44ef8dd3ef60402ff149195e31321f994ced33109e58b09e5e596196b4e05
3
+ size 658115811
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11f43b55acbb2069c12d4b2bfe9fb3d4ee523ebba0b1bded93f460dea404a4d7
3
+ size 658085248
model_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"max_len": 1024, "d_model": 768, "num_layers": 16, "num_heads": 12, "d_ff": 3072, "learning_rate": 0.0002, "lr_schedule": "warmup_cosine", "lr_warmup_steps": 2000, "min_learning_rate": 2e-05, "min_learning_rate_ratio": 0.1, "loss_chunk_size": 32, "pad_token_id": 0, "bos_token_id": 2, "eos_token_id": 3, "corpus_signature": "551ac72eceb57f5f", "dataset_cases": [{"name": "fineweb2-edu-ja", "genre": "web", "language": "ja", "dataset_path": "hotchpotch/fineweb-2-edu-japanese", "config_name": "default", "split": "train", "text_column": "text", "token_percentage": 30.0, "is_ramped": false, "repeat_on_end": true, "excluded_url_domains": ["wikipedia.org"]}, {"name": "cleanedwiki-jp", "genre": "wiki", "language": "ja", "dataset_path": "MK0727/CleanedWiki-jp", "config_name": "all", "split": "train", "text_column": "text", "token_percentage": 70.0, "is_ramped": true, "repeat_on_end": true, "excluded_url_domains": []}], "mix_cycle_tokens": 100000, "ramp_start_progress": 0.5, "val_split_modulo": 100, "val_split_index": 0, "validation_cache_path": "models/lambda-160m/validation-cache-551ac72eceb57f5f-bos-eos-text-hash-len1024-samples6144-split100-0.pt", "validation_sample_count": 6144, "trained_steps": 40960}
modeling_myllm.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers.generation import GenerationMixin
5
+ from transformers.modeling_outputs import CausalLMOutputWithPast
6
+
7
+ from .configuration_myllm import MyLLMConfig
8
+ from .kv_cache import KeyValueCache
9
+ from .position_encoding import PositionEncoding
10
+ from .self_attention import Attention
11
+ from .transformer import DecoderOnlyTransformer
12
+
13
+ # ---------------------------------------------------------
14
+ # Reference nested remote-code dependencies directly so local
15
+ # AutoModel loading copies every file needed by relative imports.
16
+ # ---------------------------------------------------------
17
+ REMOTE_CODE_DEPENDENCIES = (Attention, PositionEncoding)
18
+
19
+
20
+ class MyLLMForCausalLM(PreTrainedModel, GenerationMixin):
21
+ config_class = MyLLMConfig
22
+ main_input_name = "input_ids"
23
+ _tied_weights_keys = {"transformer.fc_layer.weight": "transformer.we.weight"}
24
+
25
+ def __init__(self, config: MyLLMConfig) -> None:
26
+ super().__init__(config)
27
+
28
+ # ---------------------------------------------------------
29
+ # Reuse the existing PyTorch Transformer implementation and
30
+ # keep the HF wrapper responsible only for AutoModel APIs.
31
+ # ---------------------------------------------------------
32
+ self.transformer = DecoderOnlyTransformer(
33
+ num_tokens=config.vocab_size,
34
+ d_model=config.d_model,
35
+ max_len=config.max_len,
36
+ num_layers=config.num_layers,
37
+ num_heads=config.num_heads,
38
+ d_ff=config.d_ff,
39
+ learning_rate=config.learning_rate,
40
+ pad_token_id=config.pad_token_id,
41
+ )
42
+ self.post_init()
43
+
44
+ def get_input_embeddings(self) -> nn.Embedding:
45
+ # ---------------------------------------------------------
46
+ # Expose input embeddings through the standard Transformers
47
+ # interface used by resizing and generation helpers.
48
+ # ---------------------------------------------------------
49
+ return self.transformer.we
50
+
51
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
52
+ # ---------------------------------------------------------
53
+ # Keep tied output weights aligned when callers replace the
54
+ # token embedding module through the Transformers interface.
55
+ # ---------------------------------------------------------
56
+ self.transformer.we = value
57
+ self.transformer.fc_layer.weight = value.weight
58
+
59
+ def get_output_embeddings(self) -> nn.Linear:
60
+ # ---------------------------------------------------------
61
+ # Expose the tied LM head through the standard Transformers
62
+ # interface used by causal language model utilities.
63
+ # ---------------------------------------------------------
64
+ return self.transformer.fc_layer
65
+
66
+ def set_output_embeddings(self, value: nn.Linear) -> None:
67
+ # ---------------------------------------------------------
68
+ # Allow Transformers utilities to replace the LM head while
69
+ # preserving the module expected by the existing model.
70
+ # ---------------------------------------------------------
71
+ self.transformer.fc_layer = value
72
+
73
+ def _supports_default_dynamic_cache(self) -> bool:
74
+ # ---------------------------------------------------------
75
+ # Use the existing list-based KV cache instead of letting
76
+ # Transformers allocate its DynamicCache implementation.
77
+ # ---------------------------------------------------------
78
+ return False
79
+
80
+ def prepare_inputs_for_generation(
81
+ self,
82
+ input_ids: torch.Tensor,
83
+ past_key_values: KeyValueCache | None = None,
84
+ **kwargs: object,
85
+ ) -> dict[str, torch.Tensor | KeyValueCache | bool | None]:
86
+ # ---------------------------------------------------------
87
+ # Feed only the newest token after the cache is populated so
88
+ # generate can reuse the existing incremental forward path.
89
+ # ---------------------------------------------------------
90
+ del kwargs
91
+ model_input_ids = input_ids[:, -1:] if past_key_values is not None else input_ids
92
+ return {
93
+ "input_ids": model_input_ids,
94
+ "past_key_values": past_key_values,
95
+ "use_cache": True,
96
+ }
97
+
98
+ def forward(
99
+ self,
100
+ input_ids: torch.Tensor | None = None,
101
+ labels: torch.Tensor | None = None,
102
+ past_key_values: KeyValueCache | None = None,
103
+ use_cache: bool | None = None,
104
+ return_dict: bool | None = None,
105
+ **kwargs: object,
106
+ ) -> CausalLMOutputWithPast | tuple[torch.Tensor, ...]:
107
+ # ---------------------------------------------------------
108
+ # Accept the standard AutoModelForCausalLM argument names and
109
+ # delegate the actual tensor computation to the PyTorch model.
110
+ # ---------------------------------------------------------
111
+ del kwargs
112
+
113
+ if input_ids is None:
114
+ raise ValueError("input_ids is required")
115
+
116
+ should_use_cache = bool(use_cache)
117
+
118
+ if past_key_values is not None or should_use_cache:
119
+ logits, next_key_values = self.transformer.forward_with_cache(
120
+ token_ids=input_ids,
121
+ past_key_values=past_key_values,
122
+ )
123
+ else:
124
+ logits = self.transformer(token_ids=input_ids)
125
+ next_key_values = None
126
+
127
+ # ---------------------------------------------------------
128
+ # Follow causal LM convention for labels supplied by HF
129
+ # Trainer and examples: predict token n+1 from position n.
130
+ # ---------------------------------------------------------
131
+ loss = None
132
+
133
+ if labels is not None:
134
+ shift_logits = logits[:, :-1, :].contiguous()
135
+ shift_labels = labels[:, 1:].contiguous()
136
+ loss = nn.functional.cross_entropy(
137
+ shift_logits.view(-1, self.config.vocab_size),
138
+ shift_labels.view(-1),
139
+ ignore_index=self.config.pad_token_id,
140
+ )
141
+
142
+ # ---------------------------------------------------------
143
+ # Return either the standard modeling output or a tuple for
144
+ # callers that explicitly disable dictionary-style outputs.
145
+ # ---------------------------------------------------------
146
+ if return_dict is False:
147
+ output = (logits,)
148
+ return (loss, *output) if loss is not None else output
149
+
150
+ return CausalLMOutputWithPast(
151
+ loss=loss,
152
+ logits=logits,
153
+ past_key_values=next_key_values,
154
+ )
position_encoding.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class PositionEncoding(nn.Module):
6
+ def __init__(self, d_model: int = 2, max_len: int = 6) -> None:
7
+ super().__init__()
8
+
9
+ # ---------------------------------------------------------
10
+ # Precompute sinusoidal positions once so token embeddings
11
+ # can be shifted cheaply during training and inference.
12
+ # ---------------------------------------------------------
13
+ pe = torch.zeros(max_len, d_model)
14
+ position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
15
+ embedding_index = torch.arange(start=0, end=d_model, step=2).float()
16
+ div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)
17
+
18
+ pe[:, 0::2] = torch.sin(position * div_term)
19
+ pe[:, 1::2] = torch.cos(position * div_term)
20
+
21
+ self.register_buffer("pe", pe)
22
+
23
+ def forward(self, word_embeddings: torch.Tensor, position_offset: int = 0) -> torch.Tensor:
24
+ # ---------------------------------------------------------
25
+ # Add positions for the visible slice, starting at the cache
26
+ # length when incremental inference supplies an offset.
27
+ # ---------------------------------------------------------
28
+ seq_len = word_embeddings.size(1)
29
+ position_end = position_offset + seq_len
30
+ return word_embeddings + self.pe[position_offset:position_end, :].unsqueeze(0)
31
+
32
+
33
+ if __name__ == "__main__":
34
+ n = PositionEncoding()
self_attention.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from src.pretraining.kv_cache import LayerKeyValueCache
6
+
7
+
8
+ class Attention(nn.Module):
9
+ def __init__(self, d_model: int = 2, num_heads: int = 1) -> None:
10
+ super().__init__()
11
+
12
+ # ---------------------------------------------------------
13
+ # Split the model dimension into multiple heads so the same
14
+ # attention module can be reused in a more general structure.
15
+ # ---------------------------------------------------------
16
+ if d_model % num_heads != 0:
17
+ raise ValueError("d_model must be divisible by num_heads")
18
+
19
+ self.d_model = d_model
20
+ self.num_heads = num_heads
21
+ self.head_dim = d_model // num_heads
22
+
23
+ # ---------------------------------------------------------
24
+ # Project inputs into query, key, and value spaces and merge
25
+ # the heads back into the model dimension after attention.
26
+ # ---------------------------------------------------------
27
+ self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
28
+ self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
29
+ self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
30
+ self.W_o = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
31
+
32
+ def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
33
+ # ---------------------------------------------------------
34
+ # Rearrange the last dimension into head count and head size
35
+ # so attention can be computed independently per head.
36
+ # ---------------------------------------------------------
37
+ batch_size, seq_len, _ = x.size()
38
+ reshaped = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
39
+ return reshaped.transpose(1, 2)
40
+
41
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
42
+ # ---------------------------------------------------------
43
+ # Restore the tensor to the original model dimension after
44
+ # per-head attention has been combined.
45
+ # ---------------------------------------------------------
46
+ batch_size, _, seq_len, _ = x.size()
47
+ transposed = x.transpose(1, 2).contiguous()
48
+ return transposed.view(batch_size, seq_len, self.d_model)
49
+
50
+ def forward(
51
+ self,
52
+ encoding_for_q: torch.Tensor,
53
+ encoding_for_k: torch.Tensor,
54
+ encoding_for_v: torch.Tensor,
55
+ is_causal: bool = False,
56
+ ) -> torch.Tensor:
57
+ # ---------------------------------------------------------
58
+ # Create the projected queries, keys, and values for each
59
+ # attention head from the incoming hidden states.
60
+ # ---------------------------------------------------------
61
+ q = self._split_heads(self.W_q(encoding_for_q))
62
+ k = self._split_heads(self.W_k(encoding_for_k))
63
+ v = self._split_heads(self.W_v(encoding_for_v))
64
+
65
+ # ---------------------------------------------------------
66
+ # Use PyTorch's fused scaled dot-product attention so large
67
+ # score and softmax tensors do not need to be materialized.
68
+ # ---------------------------------------------------------
69
+ attention_scores = F.scaled_dot_product_attention(
70
+ q,
71
+ k,
72
+ v,
73
+ is_causal=is_causal,
74
+ )
75
+
76
+ # ---------------------------------------------------------
77
+ # Merge the attended heads and project the result back into
78
+ # the model dimension for the next layer.
79
+ # ---------------------------------------------------------
80
+ merged_scores = self._merge_heads(attention_scores)
81
+ return self.W_o(merged_scores)
82
+
83
+ def forward_with_cache(
84
+ self,
85
+ encoding_for_q: torch.Tensor,
86
+ encoding_for_k: torch.Tensor,
87
+ encoding_for_v: torch.Tensor,
88
+ past_key_value: LayerKeyValueCache | None,
89
+ is_causal: bool = False,
90
+ ) -> tuple[torch.Tensor, LayerKeyValueCache]:
91
+ # ---------------------------------------------------------
92
+ # Project the current tokens and append previous keys and
93
+ # values so generation can avoid recomputing old states.
94
+ # ---------------------------------------------------------
95
+ q = self._split_heads(self.W_q(encoding_for_q))
96
+ current_k = self._split_heads(self.W_k(encoding_for_k))
97
+ current_v = self._split_heads(self.W_v(encoding_for_v))
98
+
99
+ k = current_k
100
+ v = current_v
101
+
102
+ if past_key_value is not None:
103
+ past_k, past_v = past_key_value
104
+ k = torch.cat((past_k, current_k), dim=2)
105
+ v = torch.cat((past_v, current_v), dim=2)
106
+
107
+ # ---------------------------------------------------------
108
+ # Attend the current query positions over cached and current
109
+ # keys with the fused scaled dot-product implementation.
110
+ # ---------------------------------------------------------
111
+ attention_scores = F.scaled_dot_product_attention(
112
+ q,
113
+ k,
114
+ v,
115
+ is_causal=is_causal,
116
+ )
117
+
118
+ # ---------------------------------------------------------
119
+ # Return both the attention result and the updated cache for
120
+ # this layer so the caller can feed the next token directly.
121
+ # ---------------------------------------------------------
122
+ merged_scores = self._merge_heads(attention_scores)
123
+ return self.W_o(merged_scores), (k, v)
special_tokens_map.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": "|<pad>|",
3
+ "unk_token": "|<unknown>|",
4
+ "bos_token": "|<bos>|",
5
+ "eos_token": "|<eos>|",
6
+ "sep_token": "|<sep>|",
7
+ "cls_token": "|<cls>|",
8
+ "mask_token": "|<mask>|",
9
+ "extra_special_tokens": [
10
+ "|<system>|",
11
+ "|<user>|",
12
+ "|<assistant>|",
13
+ "|<thinking>|",
14
+ "|<end_of_thinking>|",
15
+ "|<end_of_turn>|"
16
+ ]
17
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "|<bos>|",
4
+ "cls_token": "|<cls>|",
5
+ "eos_token": "|<eos>|",
6
+ "extra_special_tokens": [
7
+ "|<system>|",
8
+ "|<user>|",
9
+ "|<assistant>|",
10
+ "|<thinking>|",
11
+ "|<end_of_thinking>|",
12
+ "|<end_of_turn>|"
13
+ ],
14
+ "mask_token": "|<mask>|",
15
+ "model_max_length": 1000000000000000019884624838656,
16
+ "pad_token": "|<pad>|",
17
+ "sep_token": "|<sep>|",
18
+ "tokenizer_class": "TokenizersBackend",
19
+ "unk_token": "|<unknown>|"
20
+ }
transformer.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.optim import AdamW
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+ import lightning as L
8
+
9
+ from .kv_cache import KeyValueCache, LayerKeyValueCache
10
+ from .position_encoding import PositionEncoding
11
+ from .self_attention import Attention
12
+
13
+
14
+ class FeedForward(nn.Module):
15
+ def __init__(self, d_model: int, d_ff: int) -> None:
16
+ super().__init__()
17
+
18
+ # ---------------------------------------------------------
19
+ # Use the standard Transformer feed-forward sublayer so each
20
+ # token can be transformed independently after attention.
21
+ # ---------------------------------------------------------
22
+ self.linear_1 = nn.Linear(in_features=d_model, out_features=d_ff)
23
+ self.activation = nn.GELU()
24
+ self.linear_2 = nn.Linear(in_features=d_ff, out_features=d_model)
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ # ---------------------------------------------------------
28
+ # Expand the channel dimension, apply a non-linearity, and
29
+ # project back to the model dimension.
30
+ # ---------------------------------------------------------
31
+ hidden = self.linear_1(x)
32
+ activated = self.activation(hidden)
33
+ return self.linear_2(activated)
34
+
35
+
36
+ class DecoderBlock(nn.Module):
37
+ def __init__(self, d_model: int, num_heads: int, d_ff: int) -> None:
38
+ super().__init__()
39
+
40
+ # ---------------------------------------------------------
41
+ # Compose one decoder block from attention, feed-forward, and
42
+ # RMS normalization layers with residual connections.
43
+ # ---------------------------------------------------------
44
+ self.norm_1 = nn.RMSNorm(normalized_shape=d_model)
45
+ self.attention = Attention(d_model=d_model, num_heads=num_heads)
46
+ self.norm_2 = nn.RMSNorm(normalized_shape=d_model)
47
+ self.feed_forward = FeedForward(d_model=d_model, d_ff=d_ff)
48
+
49
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
50
+ # ---------------------------------------------------------
51
+ # Apply pre-norm self-attention so multiple decoder blocks can
52
+ # be stacked without changing the external interface.
53
+ # ---------------------------------------------------------
54
+ attention_input = self.norm_1(x)
55
+ attention_output = self.attention(
56
+ attention_input,
57
+ attention_input,
58
+ attention_input,
59
+ is_causal=True,
60
+ )
61
+ attention_residual = x + attention_output
62
+
63
+ # ---------------------------------------------------------
64
+ # Apply the position-wise feed-forward network as the second
65
+ # sublayer inside the decoder block.
66
+ # ---------------------------------------------------------
67
+ feed_forward_input = self.norm_2(attention_residual)
68
+ feed_forward_output = self.feed_forward(feed_forward_input)
69
+ return attention_residual + feed_forward_output
70
+
71
+ def forward_with_cache(
72
+ self,
73
+ x: torch.Tensor,
74
+ past_key_value: LayerKeyValueCache | None,
75
+ ) -> tuple[torch.Tensor, LayerKeyValueCache]:
76
+ # ---------------------------------------------------------
77
+ # Apply self-attention with a layer-local cache, then keep the
78
+ # feed-forward path identical to the full sequence forward.
79
+ # ---------------------------------------------------------
80
+ attention_input = self.norm_1(x)
81
+ attention_output, key_value_cache = self.attention.forward_with_cache(
82
+ attention_input,
83
+ attention_input,
84
+ attention_input,
85
+ past_key_value,
86
+ is_causal=past_key_value is None,
87
+ )
88
+ attention_residual = x + attention_output
89
+
90
+ # ---------------------------------------------------------
91
+ # Transform only the visible token states because old states
92
+ # have already been folded into the cached keys and values.
93
+ # ---------------------------------------------------------
94
+ feed_forward_input = self.norm_2(attention_residual)
95
+ feed_forward_output = self.feed_forward(feed_forward_input)
96
+ return attention_residual + feed_forward_output, key_value_cache
97
+
98
+
99
+ class DecoderOnlyTransformer(L.LightningModule):
100
+ def __init__(
101
+ self,
102
+ num_tokens: int = 4,
103
+ d_model: int = 2,
104
+ max_len: int = 6,
105
+ num_layers: int = 2,
106
+ num_heads: int = 1,
107
+ d_ff: int = 8,
108
+ learning_rate: float = 0.1,
109
+ pad_token_id: int = 0,
110
+ use_fused_optimizer: bool = False,
111
+ loss_chunk_size: int = 32,
112
+ lr_warmup_steps: int | None = None,
113
+ lr_total_steps: int | None = None,
114
+ min_learning_rate: float | None = None,
115
+ ) -> None:
116
+ super().__init__()
117
+
118
+ # ---------------------------------------------------------
119
+ # Embed tokens and positions before passing them through a
120
+ # stack of decoder blocks.
121
+ # ---------------------------------------------------------
122
+ self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
123
+ self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
124
+ self.blocks = nn.ModuleList(
125
+ [DecoderBlock(d_model=d_model, num_heads=num_heads, d_ff=d_ff) for _ in range(num_layers)]
126
+ )
127
+ self.final_norm = nn.RMSNorm(normalized_shape=d_model)
128
+ self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)
129
+
130
+ # ---------------------------------------------------------
131
+ # Share token embedding weights with the output projection
132
+ # so small models spend more parameters inside the blocks.
133
+ # ---------------------------------------------------------
134
+ self.fc_layer.weight = self.we.weight
135
+ self.learning_rate = learning_rate
136
+ self.pad_token_id = pad_token_id
137
+ self.use_fused_optimizer = use_fused_optimizer
138
+ self.loss_chunk_size = loss_chunk_size
139
+ self.lr_warmup_steps = lr_warmup_steps
140
+ self.lr_total_steps = lr_total_steps
141
+ self.min_learning_rate = min_learning_rate
142
+
143
+ # ---------------------------------------------------------
144
+ # Reject partially configured schedules so posttraining can
145
+ # keep fixed LR while pretraining opts into full scheduling.
146
+ # ---------------------------------------------------------
147
+ lr_schedule_values = [lr_warmup_steps, lr_total_steps, min_learning_rate]
148
+
149
+ if any(value is None for value in lr_schedule_values) and any(
150
+ value is not None for value in lr_schedule_values
151
+ ):
152
+ raise ValueError("LR schedule requires warmup steps, total steps, and minimum learning rate")
153
+
154
+ # ---------------------------------------------------------
155
+ # Keep summed token loss local so large vocabulary logits
156
+ # can be reduced chunk by chunk during training.
157
+ # ---------------------------------------------------------
158
+ self.loss = nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction="sum")
159
+
160
+ def forward_hidden(self, token_ids: torch.Tensor) -> torch.Tensor:
161
+ # ---------------------------------------------------------
162
+ # Convert token ids into hidden states and apply positional
163
+ # information before the decoder stack.
164
+ # ---------------------------------------------------------
165
+ word_embeddings = self.we(token_ids)
166
+ hidden_states = self.pe(word_embeddings)
167
+
168
+ # ---------------------------------------------------------
169
+ # Reuse the same decoder block interface for every layer to
170
+ # make the model depth configurable.
171
+ # ---------------------------------------------------------
172
+ for block in self.blocks:
173
+ hidden_states = block(hidden_states)
174
+
175
+ # ---------------------------------------------------------
176
+ # Normalize the final hidden states and map them into token
177
+ # logits for next-token prediction.
178
+ # ---------------------------------------------------------
179
+ return self.final_norm(hidden_states)
180
+
181
+ def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
182
+ # ---------------------------------------------------------
183
+ # Keep the public forward path returning full vocabulary
184
+ # logits for inference and compatibility with callers.
185
+ # ---------------------------------------------------------
186
+ hidden_states = self.forward_hidden(token_ids)
187
+ return self.fc_layer(hidden_states)
188
+
189
+ def forward_with_cache(
190
+ self,
191
+ token_ids: torch.Tensor,
192
+ past_key_values: KeyValueCache | None,
193
+ ) -> tuple[torch.Tensor, KeyValueCache]:
194
+ # ---------------------------------------------------------
195
+ # Offset positions by the cached sequence length so one-token
196
+ # inference matches full-sequence absolute positions.
197
+ # ---------------------------------------------------------
198
+ position_offset = 0
199
+
200
+ if past_key_values is not None:
201
+ position_offset = past_key_values[0][0].size(dim=2)
202
+
203
+ word_embeddings = self.we(token_ids)
204
+ hidden_states = self.pe(word_embeddings, position_offset=position_offset)
205
+ next_key_values: KeyValueCache = []
206
+
207
+ # ---------------------------------------------------------
208
+ # Pass each layer its own cache entry and collect the updated
209
+ # entries in the same order for the next generation step.
210
+ # ---------------------------------------------------------
211
+ for layer_index, block in enumerate(self.blocks):
212
+ past_key_value = None if past_key_values is None else past_key_values[layer_index]
213
+ hidden_states, key_value_cache = block.forward_with_cache(
214
+ hidden_states,
215
+ past_key_value,
216
+ )
217
+ next_key_values.append(key_value_cache)
218
+
219
+ # ---------------------------------------------------------
220
+ # Produce logits only for the currently supplied token slice
221
+ # while returning cache tensors that include all past tokens.
222
+ # ---------------------------------------------------------
223
+ normalized_hidden_states = self.final_norm(hidden_states)
224
+ return self.fc_layer(normalized_hidden_states), next_key_values
225
+
226
+ def configure_optimizers(self) -> AdamW | dict[str, object]:
227
+ # ---------------------------------------------------------
228
+ # Use AdamW for decoupled weight decay and enable the fused
229
+ # CUDA implementation only when the training script requests it.
230
+ # ---------------------------------------------------------
231
+ optimizer = AdamW(
232
+ self.parameters(),
233
+ lr=self.learning_rate,
234
+ fused=self.use_fused_optimizer,
235
+ )
236
+
237
+ # ---------------------------------------------------------
238
+ # Keep callers without scheduler settings on fixed learning
239
+ # rate while pretraining uses step-wise warmup and cosine decay.
240
+ # ---------------------------------------------------------
241
+ if self.lr_warmup_steps is None or self.lr_total_steps is None or self.min_learning_rate is None:
242
+ return optimizer
243
+
244
+ scheduler = LambdaLR(
245
+ optimizer=optimizer,
246
+ lr_lambda=lambda step: resolve_warmup_cosine_learning_rate(
247
+ step=step,
248
+ max_learning_rate=self.learning_rate,
249
+ min_learning_rate=self.min_learning_rate,
250
+ warmup_steps=self.lr_warmup_steps,
251
+ total_steps=self.lr_total_steps,
252
+ )
253
+ / self.learning_rate,
254
+ )
255
+ return {
256
+ "optimizer": optimizer,
257
+ "lr_scheduler": {
258
+ "scheduler": scheduler,
259
+ "interval": "step",
260
+ "frequency": 1,
261
+ },
262
+ }
263
+
264
+ def compute_chunked_loss(self, input_tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
265
+ # ---------------------------------------------------------
266
+ # Run the Transformer stack once, then split only the large
267
+ # vocabulary projection and cross-entropy over token positions.
268
+ # ---------------------------------------------------------
269
+ hidden_states = self.forward_hidden(input_tokens)
270
+ seq_len = hidden_states.size(dim=1)
271
+ chunk_starts = range(0, seq_len, self.loss_chunk_size)
272
+
273
+ # ---------------------------------------------------------
274
+ # Accumulate summed token losses so padding can be ignored
275
+ # with the same weighting as a single full cross-entropy call.
276
+ # ---------------------------------------------------------
277
+ loss_chunks = [
278
+ self.loss(
279
+ self.fc_layer(
280
+ hidden_states[:, chunk_start : chunk_start + self.loss_chunk_size, :]
281
+ ).transpose(1, 2),
282
+ labels[:, chunk_start : chunk_start + self.loss_chunk_size],
283
+ )
284
+ for chunk_start in chunk_starts
285
+ ]
286
+ total_loss = torch.stack(loss_chunks).sum()
287
+ valid_token_count = labels.ne(self.pad_token_id).sum()
288
+ return total_loss / valid_token_count
289
+
290
+ def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
291
+ # ---------------------------------------------------------
292
+ # Run the forward pass and compute token-level cross-entropy
293
+ # against the shifted labels.
294
+ # ---------------------------------------------------------
295
+ del batch_idx
296
+ input_tokens, labels = batch
297
+ loss = self.compute_chunked_loss(input_tokens=input_tokens, labels=labels)
298
+ self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
299
+ return loss
300
+
301
+ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
302
+ # ---------------------------------------------------------
303
+ # Reuse the same autoregressive loss during validation so
304
+ # checkpoints can monitor held-out next-token accuracy.
305
+ # ---------------------------------------------------------
306
+ del batch_idx
307
+ input_tokens, labels = batch
308
+ loss = self.compute_chunked_loss(input_tokens=input_tokens, labels=labels)
309
+ self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
310
+ return loss
311
+
312
+
313
+ def resolve_warmup_cosine_learning_rate(
314
+ step: int,
315
+ max_learning_rate: float,
316
+ min_learning_rate: float,
317
+ warmup_steps: int,
318
+ total_steps: int,
319
+ ) -> float:
320
+ # ---------------------------------------------------------
321
+ # Raise the learning rate linearly at the start, then decay it
322
+ # smoothly to the configured minimum by the final training step.
323
+ # ---------------------------------------------------------
324
+ if step < warmup_steps:
325
+ return max_learning_rate * step / warmup_steps
326
+
327
+ decay_progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
328
+ cosine_scale = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
329
+ return min_learning_rate + (max_learning_rate - min_learning_rate) * cosine_scale