xlr8harder commited on
Commit
294862a
·
verified ·
1 Parent(s): 15d5bc6

Fix vLLM CUDA graph capture in forward path

Browse files
Files changed (2) hide show
  1. configuration_talkie.py +54 -0
  2. modeling_talkie.py +221 -42
configuration_talkie.py CHANGED
@@ -1,5 +1,7 @@
1
  from __future__ import annotations
2
 
 
 
3
  from transformers import PretrainedConfig
4
 
5
 
@@ -15,6 +17,8 @@ class TalkieConfig(PretrainedConfig):
15
  head_dim: int = 128,
16
  max_position_embeddings: int = 2048,
17
  rope_base: int = 1_000_000,
 
 
18
  logit_scale: float = 1.0,
19
  use_cache: bool = True,
20
  tie_word_embeddings: bool = False,
@@ -23,6 +27,11 @@ class TalkieConfig(PretrainedConfig):
23
  pad_token_id: int | None = None,
24
  **kwargs,
25
  ):
 
 
 
 
 
26
  super().__init__(
27
  bos_token_id=bos_token_id,
28
  eos_token_id=eos_token_id,
@@ -37,6 +46,8 @@ class TalkieConfig(PretrainedConfig):
37
  self.head_dim = head_dim
38
  self.max_position_embeddings = max_position_embeddings
39
  self.rope_base = rope_base
 
 
40
  self.logit_scale = logit_scale
41
  self.use_cache = use_cache
42
 
@@ -44,3 +55,46 @@ class TalkieConfig(PretrainedConfig):
44
  self.hidden_size = n_embd
45
  self.num_hidden_layers = n_layer
46
  self.num_attention_heads = n_head
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from collections.abc import Mapping
4
+
5
  from transformers import PretrainedConfig
6
 
7
 
 
17
  head_dim: int = 128,
18
  max_position_embeddings: int = 2048,
19
  rope_base: int = 1_000_000,
20
+ rope_scaling: dict | None = None,
21
+ rope_parameters: dict | None = None,
22
  logit_scale: float = 1.0,
23
  use_cache: bool = True,
24
  tie_word_embeddings: bool = False,
 
27
  pad_token_id: int | None = None,
28
  **kwargs,
29
  ):
30
+ if rope_scaling is None:
31
+ rope_scaling = rope_parameters
32
+ self.max_position_embeddings = max_position_embeddings
33
+ self.rope_scaling = self._normalize_rope_scaling(rope_scaling)
34
+ self.rope_parameters = self.rope_scaling
35
  super().__init__(
36
  bos_token_id=bos_token_id,
37
  eos_token_id=eos_token_id,
 
46
  self.head_dim = head_dim
47
  self.max_position_embeddings = max_position_embeddings
48
  self.rope_base = rope_base
49
+ self.rope_scaling = self._normalize_rope_scaling(rope_scaling)
50
+ self.rope_parameters = self.rope_scaling
51
  self.logit_scale = logit_scale
52
  self.use_cache = use_cache
53
 
 
55
  self.hidden_size = n_embd
56
  self.num_hidden_layers = n_layer
57
  self.num_attention_heads = n_head
58
+
59
+ @staticmethod
60
+ def _normalize_rope_scaling(rope_scaling: dict | None) -> dict | None:
61
+ if rope_scaling is None:
62
+ return None
63
+ if not isinstance(rope_scaling, Mapping):
64
+ raise TypeError("rope_scaling must be a dictionary")
65
+
66
+ scaling = dict(rope_scaling)
67
+ rope_type = scaling.get("rope_type", scaling.get("type"))
68
+ if rope_type is None:
69
+ raise ValueError("rope_scaling must include 'rope_type' or 'type'")
70
+
71
+ rope_type = str(rope_type).lower()
72
+ if rope_type == "ntk":
73
+ rope_type = "dynamic"
74
+ supported = {"default", "linear", "dynamic", "yarn"}
75
+ if rope_type not in supported:
76
+ raise ValueError(
77
+ f"unsupported rope_scaling type {rope_type!r}; expected one of {sorted(supported)}"
78
+ )
79
+
80
+ if rope_type == "default":
81
+ return None
82
+
83
+ factor = float(scaling.get("factor", 1.0))
84
+ if factor < 1.0:
85
+ raise ValueError("rope_scaling factor must be >= 1.0")
86
+
87
+ scaling["rope_type"] = rope_type
88
+ scaling.pop("type", None)
89
+ scaling["factor"] = factor
90
+ if "original_max_position_embeddings" in scaling:
91
+ scaling["original_max_position_embeddings"] = int(
92
+ scaling["original_max_position_embeddings"]
93
+ )
94
+ if "beta_fast" in scaling:
95
+ scaling["beta_fast"] = float(scaling["beta_fast"])
96
+ if "beta_slow" in scaling:
97
+ scaling["beta_slow"] = float(scaling["beta_slow"])
98
+ if "attention_factor" in scaling and scaling["attention_factor"] is not None:
99
+ scaling["attention_factor"] = float(scaling["attention_factor"])
100
+ return scaling
modeling_talkie.py CHANGED
@@ -1,5 +1,7 @@
1
  from __future__ import annotations
2
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
@@ -185,7 +187,7 @@ class Block(nn.Module):
185
  class TalkiePreTrainedModel(PreTrainedModel):
186
  config_class = TalkieConfig
187
  base_model_prefix = ""
188
- supports_gradient_checkpointing = False
189
  _supports_sdpa = True
190
  _supports_attention_backend = True
191
  _no_split_modules = ["Block"]
@@ -200,28 +202,153 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
200
  super().__init__(config)
201
  self.embed = nn.Embedding(config.vocab_size, config.n_embd)
202
  self.blocks = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
 
203
 
204
- cos, sin = self._precompute_rotary_embeddings(
205
- config.max_position_embeddings, config.head_dim, config.rope_base
206
- )
207
  self.register_buffer("cos", cos, persistent=False)
208
  self.register_buffer("sin", sin, persistent=False)
209
  self._rotary_initialized = cos.device.type != "meta"
210
  self.post_init()
211
 
212
  def _precompute_rotary_embeddings(
213
- self, seq_len: int, head_dim: int, base: int
 
 
 
214
  ) -> tuple[torch.Tensor, torch.Tensor]:
215
  device = self.embed.weight.device if hasattr(self, "embed") else "cpu"
216
- channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
217
- inv_freq = 1.0 / (base ** (channel_range / head_dim))
 
218
  t = torch.arange(seq_len, dtype=torch.float32, device=device)
219
  freqs = torch.outer(t, inv_freq)
220
  cos, sin = freqs.cos(), freqs.sin()
 
 
 
221
  cos, sin = cos.bfloat16(), sin.bfloat16()
222
  cos, sin = cos[None, :, None, :], sin[None, :, None, :]
223
  return cos, sin
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  def _ensure_rotary_embeddings(self, seq_len: int) -> None:
226
  device = self.embed.weight.device
227
  needs_init = (
@@ -232,13 +359,14 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
232
  )
233
  if needs_init:
234
  max_seq_len = max(seq_len, self.config.max_position_embeddings)
235
- cos, sin = self._precompute_rotary_embeddings(
236
- max_seq_len, self.config.head_dim, self.config.rope_base
237
- )
238
  self.cos = cos.to(device=device)
239
  self.sin = sin.to(device=device)
240
  self._rotary_initialized = True
241
 
 
 
 
242
  def get_input_embeddings(self) -> nn.Embedding:
243
  return self.embed
244
 
@@ -265,7 +393,7 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
265
  return cache_position.to(device=input_ids.device, dtype=torch.long)
266
  past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
267
  position_ids = torch.arange(seq_len, device=input_ids.device, dtype=torch.long) + past_seen
268
- return position_ids.unsqueeze(0)
269
 
270
  def _attention_mask(
271
  self,
@@ -279,10 +407,13 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
279
  return attention_mask
280
  batch_size, query_length = input_ids.shape
281
  past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
282
- key_length = past_seen + query_length
283
 
284
  if attention_mask is not None and attention_mask.dim() != 2:
285
  return attention_mask
 
 
 
 
286
  if attention_mask is not None:
287
  if attention_mask.shape[-1] == query_length and past_seen:
288
  prefix = torch.ones(
@@ -293,25 +424,17 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
293
  )
294
  attention_mask = torch.cat([prefix, attention_mask], dim=-1)
295
  key_length = attention_mask.shape[-1]
296
- has_padding = not bool(torch.all(attention_mask == 1))
297
- else:
298
- has_padding = False
299
-
300
- if attention_mask is None and past_seen == 0:
301
- return None
302
 
303
  key_positions = torch.arange(key_length, device=input_ids.device, dtype=torch.long)
304
  future_mask = key_positions.view(1, 1, 1, key_length) > position_ids.view(
305
  batch_size, 1, query_length, 1
306
  )
307
- if attention_mask is not None and has_padding:
308
  padding_mask = attention_mask[:, None, None, :].to(device=input_ids.device) == 0
309
  mask = future_mask | padding_mask
310
  else:
311
  mask = future_mask
312
 
313
- if not bool(mask.any()):
314
- return None
315
  min_value = torch.finfo(dtype).min
316
  causal_mask = torch.zeros(
317
  batch_size, 1, query_length, key_length, dtype=dtype, device=input_ids.device
@@ -341,17 +464,15 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
341
  device=inputs_embeds.device,
342
  )
343
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
344
  if use_cache and past_key_values is None:
345
  past_key_values = DynamicCache(config=self.config)
346
 
347
  position_ids = self._position_ids(input_ids, position_ids, cache_position, past_key_values)
348
- needed_seq_len = int(position_ids.max().item()) + 1
349
- self._ensure_rotary_embeddings(needed_seq_len)
350
- if needed_seq_len > self.cos.shape[1]:
351
- raise ValueError(
352
- f"Sequence length {needed_seq_len} exceeds max_position_embeddings "
353
- f"{self.cos.shape[1]}"
354
- )
355
 
356
  cos = self.cos[0, position_ids, :, :]
357
  sin = self.sin[0, position_ids, :, :]
@@ -361,14 +482,34 @@ class TalkieModel(TalkiePreTrainedModel, GenerationMixin):
361
  attention_mask = self._attention_mask(attention_mask, input_ids, position_ids, past_key_values, x.dtype)
362
  e_x = x
363
  for block in self.blocks:
364
- x = block(
365
- e_x,
366
- x,
367
- cos_sin,
368
- attention_mask=attention_mask,
369
- past_key_values=past_key_values if use_cache else None,
370
- **kwargs,
371
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  x = F.rms_norm(x, (x.shape[-1],))
373
  past_key_values = past_key_values if use_cache else None
374
  use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -392,6 +533,32 @@ class TalkieForCausalLM(TalkieModel):
392
  def set_output_embeddings(self, value: nn.Linear) -> None:
393
  self.lm_head = value
394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  def forward(
396
  self,
397
  input_ids: torch.LongTensor | None = None,
@@ -403,6 +570,8 @@ class TalkieForCausalLM(TalkieModel):
403
  use_cache: bool | None = None,
404
  position_ids: torch.LongTensor | None = None,
405
  logits_to_keep: int | torch.Tensor = 0,
 
 
406
  **kwargs,
407
  ) -> CausalLMOutputWithPast | tuple[torch.Tensor, ...]:
408
  if input_ids is None and inputs_embeds is None:
@@ -420,13 +589,23 @@ class TalkieForCausalLM(TalkieModel):
420
  **kwargs,
421
  )
422
  hidden_states = outputs.last_hidden_state
423
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
424
- logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
425
- if self.config.logit_scale != 1.0:
426
- logits = logits * self.config.logit_scale
427
-
428
  loss = None
429
- if labels is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  shift_logits = logits[..., :-1, :].contiguous()
431
  shift_labels = labels[..., 1:].contiguous()
432
  loss = F.cross_entropy(
 
1
  from __future__ import annotations
2
 
3
+ import math
4
+
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
 
187
  class TalkiePreTrainedModel(PreTrainedModel):
188
  config_class = TalkieConfig
189
  base_model_prefix = ""
190
+ supports_gradient_checkpointing = True
191
  _supports_sdpa = True
192
  _supports_attention_backend = True
193
  _no_split_modules = ["Block"]
 
202
  super().__init__(config)
203
  self.embed = nn.Embedding(config.vocab_size, config.n_embd)
204
  self.blocks = nn.ModuleList([Block(config, i) for i in range(config.n_layer)])
205
+ self.gradient_checkpointing = False
206
 
207
+ cos, sin = self._precompute_rotary_embeddings(config.max_position_embeddings)
 
 
208
  self.register_buffer("cos", cos, persistent=False)
209
  self.register_buffer("sin", sin, persistent=False)
210
  self._rotary_initialized = cos.device.type != "meta"
211
  self.post_init()
212
 
213
  def _precompute_rotary_embeddings(
214
+ self,
215
+ seq_len: int,
216
+ head_dim: int | None = None,
217
+ base: int | float | None = None,
218
  ) -> tuple[torch.Tensor, torch.Tensor]:
219
  device = self.embed.weight.device if hasattr(self, "embed") else "cpu"
220
+ head_dim = head_dim if head_dim is not None else self.config.head_dim
221
+ base = base if base is not None else self.config.rope_base
222
+ inv_freq, attention_factor = self._rotary_inv_freq(seq_len, head_dim, float(base), device)
223
  t = torch.arange(seq_len, dtype=torch.float32, device=device)
224
  freqs = torch.outer(t, inv_freq)
225
  cos, sin = freqs.cos(), freqs.sin()
226
+ if attention_factor != 1.0:
227
+ cos = cos * attention_factor
228
+ sin = sin * attention_factor
229
  cos, sin = cos.bfloat16(), sin.bfloat16()
230
  cos, sin = cos[None, :, None, :], sin[None, :, None, :]
231
  return cos, sin
232
 
233
+ def _rotary_inv_freq(
234
+ self,
235
+ seq_len: int,
236
+ head_dim: int,
237
+ base: float,
238
+ device: torch.device | str,
239
+ ) -> tuple[torch.Tensor, float]:
240
+ scaling = self.config.rope_scaling
241
+ rope_type = scaling.get("rope_type") if scaling else None
242
+ if rope_type in (None, "default"):
243
+ return self._default_rotary_inv_freq(head_dim, base, device), 1.0
244
+ if rope_type == "linear":
245
+ inv_freq = self._default_rotary_inv_freq(head_dim, base, device)
246
+ return inv_freq / float(scaling["factor"]), 1.0
247
+ if rope_type == "dynamic":
248
+ return self._dynamic_rotary_inv_freq(seq_len, head_dim, base, device, scaling), 1.0
249
+ if rope_type == "yarn":
250
+ return self._yarn_rotary_inv_freq(head_dim, base, device, scaling)
251
+ raise ValueError(f"unsupported rope_scaling type {rope_type!r}")
252
+
253
+ @staticmethod
254
+ def _default_rotary_inv_freq(
255
+ head_dim: int, base: float, device: torch.device | str
256
+ ) -> torch.Tensor:
257
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
258
+ return 1.0 / (base ** (channel_range / head_dim))
259
+
260
+ def _original_max_position_embeddings(self, scaling: dict | None) -> int:
261
+ if scaling and "original_max_position_embeddings" in scaling:
262
+ return int(scaling["original_max_position_embeddings"])
263
+ return int(self.config.max_position_embeddings)
264
+
265
+ def _dynamic_rotary_inv_freq(
266
+ self,
267
+ seq_len: int,
268
+ head_dim: int,
269
+ base: float,
270
+ device: torch.device | str,
271
+ scaling: dict,
272
+ ) -> torch.Tensor:
273
+ original_max_position_embeddings = self._original_max_position_embeddings(scaling)
274
+ scaled_seq_len = max(seq_len, original_max_position_embeddings)
275
+ factor = float(scaling["factor"])
276
+ base = base * (
277
+ (factor * scaled_seq_len / original_max_position_embeddings) - (factor - 1.0)
278
+ ) ** (head_dim / (head_dim - 2.0))
279
+ return self._default_rotary_inv_freq(head_dim, base, device)
280
+
281
+ def _yarn_rotary_inv_freq(
282
+ self,
283
+ head_dim: int,
284
+ base: float,
285
+ device: torch.device | str,
286
+ scaling: dict,
287
+ ) -> tuple[torch.Tensor, float]:
288
+ factor = float(scaling["factor"])
289
+ original_max_position_embeddings = self._original_max_position_embeddings(scaling)
290
+ beta_fast = float(scaling.get("beta_fast", 32.0))
291
+ beta_slow = float(scaling.get("beta_slow", 1.0))
292
+ attention_factor = scaling.get("attention_factor")
293
+ if attention_factor is None:
294
+ attention_factor = 1.0 if factor <= 1.0 else 0.1 * math.log(factor) + 1.0
295
+
296
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
297
+ pos_freqs = base ** (channel_range / head_dim)
298
+ inv_freq_extrapolation = 1.0 / pos_freqs
299
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
300
+
301
+ low, high = self._yarn_correction_range(
302
+ beta_fast,
303
+ beta_slow,
304
+ head_dim,
305
+ base,
306
+ original_max_position_embeddings,
307
+ truncate=bool(scaling.get("truncate", True)),
308
+ )
309
+ ramp = self._yarn_linear_ramp(low, high, head_dim // 2, device)
310
+ extrapolation_factor = 1.0 - ramp
311
+ inv_freq = (
312
+ inv_freq_interpolation * (1.0 - extrapolation_factor)
313
+ + inv_freq_extrapolation * extrapolation_factor
314
+ )
315
+ return inv_freq, float(attention_factor)
316
+
317
+ @staticmethod
318
+ def _yarn_correction_range(
319
+ low_rot: float,
320
+ high_rot: float,
321
+ head_dim: int,
322
+ base: float,
323
+ original_max_position_embeddings: int,
324
+ truncate: bool,
325
+ ) -> tuple[float, float]:
326
+ def correction_dim(num_rotations: float) -> float:
327
+ return (
328
+ head_dim
329
+ * math.log(original_max_position_embeddings / (num_rotations * 2.0 * math.pi))
330
+ / (2.0 * math.log(base))
331
+ )
332
+
333
+ low = correction_dim(low_rot)
334
+ high = correction_dim(high_rot)
335
+ if truncate:
336
+ low = math.floor(low)
337
+ high = math.ceil(high)
338
+ return max(low, 0.0), min(high, float(head_dim - 1))
339
+
340
+ @staticmethod
341
+ def _yarn_linear_ramp(
342
+ low: float,
343
+ high: float,
344
+ dim: int,
345
+ device: torch.device | str,
346
+ ) -> torch.Tensor:
347
+ if low == high:
348
+ high += 0.001
349
+ ramp = (torch.arange(dim, dtype=torch.float32, device=device) - low) / (high - low)
350
+ return torch.clamp(ramp, 0.0, 1.0)
351
+
352
  def _ensure_rotary_embeddings(self, seq_len: int) -> None:
353
  device = self.embed.weight.device
354
  needs_init = (
 
359
  )
360
  if needs_init:
361
  max_seq_len = max(seq_len, self.config.max_position_embeddings)
362
+ cos, sin = self._precompute_rotary_embeddings(max_seq_len)
 
 
363
  self.cos = cos.to(device=device)
364
  self.sin = sin.to(device=device)
365
  self._rotary_initialized = True
366
 
367
+ def reset_rotary_embeddings(self) -> None:
368
+ self._rotary_initialized = False
369
+
370
  def get_input_embeddings(self) -> nn.Embedding:
371
  return self.embed
372
 
 
393
  return cache_position.to(device=input_ids.device, dtype=torch.long)
394
  past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
395
  position_ids = torch.arange(seq_len, device=input_ids.device, dtype=torch.long) + past_seen
396
+ return position_ids.unsqueeze(0).expand(batch_size, -1)
397
 
398
  def _attention_mask(
399
  self,
 
407
  return attention_mask
408
  batch_size, query_length = input_ids.shape
409
  past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
 
410
 
411
  if attention_mask is not None and attention_mask.dim() != 2:
412
  return attention_mask
413
+ if attention_mask is None and past_seen == 0:
414
+ return None
415
+
416
+ key_length = past_seen + query_length
417
  if attention_mask is not None:
418
  if attention_mask.shape[-1] == query_length and past_seen:
419
  prefix = torch.ones(
 
424
  )
425
  attention_mask = torch.cat([prefix, attention_mask], dim=-1)
426
  key_length = attention_mask.shape[-1]
 
 
 
 
 
 
427
 
428
  key_positions = torch.arange(key_length, device=input_ids.device, dtype=torch.long)
429
  future_mask = key_positions.view(1, 1, 1, key_length) > position_ids.view(
430
  batch_size, 1, query_length, 1
431
  )
432
+ if attention_mask is not None:
433
  padding_mask = attention_mask[:, None, None, :].to(device=input_ids.device) == 0
434
  mask = future_mask | padding_mask
435
  else:
436
  mask = future_mask
437
 
 
 
438
  min_value = torch.finfo(dtype).min
439
  causal_mask = torch.zeros(
440
  batch_size, 1, query_length, key_length, dtype=dtype, device=input_ids.device
 
464
  device=inputs_embeds.device,
465
  )
466
  use_cache = use_cache if use_cache is not None else self.config.use_cache
467
+ if self.gradient_checkpointing and self.training:
468
+ use_cache = False
469
  if use_cache and past_key_values is None:
470
  past_key_values = DynamicCache(config=self.config)
471
 
472
  position_ids = self._position_ids(input_ids, position_ids, cache_position, past_key_values)
473
+ # Keep graph capture free of CUDA tensor -> Python scalar syncs. The
474
+ # configured context length is the static serving/training contract.
475
+ self._ensure_rotary_embeddings(int(self.config.max_position_embeddings))
 
 
 
 
476
 
477
  cos = self.cos[0, position_ids, :, :]
478
  sin = self.sin[0, position_ids, :, :]
 
482
  attention_mask = self._attention_mask(attention_mask, input_ids, position_ids, past_key_values, x.dtype)
483
  e_x = x
484
  for block in self.blocks:
485
+ if self.gradient_checkpointing and self.training:
486
+ def custom_forward(
487
+ e_x: torch.Tensor,
488
+ x: torch.Tensor,
489
+ cos: torch.Tensor,
490
+ sin: torch.Tensor,
491
+ attention_mask: torch.Tensor | None,
492
+ block: Block = block,
493
+ ) -> torch.Tensor:
494
+ return block(e_x, x, (cos, sin), attention_mask=attention_mask)
495
+
496
+ x = self._gradient_checkpointing_func(
497
+ custom_forward,
498
+ e_x,
499
+ x,
500
+ cos,
501
+ sin,
502
+ attention_mask,
503
+ )
504
+ else:
505
+ x = block(
506
+ e_x,
507
+ x,
508
+ cos_sin,
509
+ attention_mask=attention_mask,
510
+ past_key_values=past_key_values if use_cache else None,
511
+ **kwargs,
512
+ )
513
  x = F.rms_norm(x, (x.shape[-1],))
514
  past_key_values = past_key_values if use_cache else None
515
  use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
533
  def set_output_embeddings(self, value: nn.Linear) -> None:
534
  self.lm_head = value
535
 
536
+ def _chunked_lm_loss(
537
+ self,
538
+ hidden_states: torch.Tensor,
539
+ labels: torch.Tensor,
540
+ chunk_size: int,
541
+ ) -> torch.Tensor:
542
+ if chunk_size <= 0:
543
+ raise ValueError("chunk_size must be positive")
544
+
545
+ total_loss = hidden_states.new_zeros((), dtype=torch.float32)
546
+ total_tokens = hidden_states.new_zeros((), dtype=torch.float32)
547
+ for start in range(0, hidden_states.shape[1], chunk_size):
548
+ end = min(start + chunk_size, hidden_states.shape[1])
549
+ logits = self.lm_head(hidden_states[:, start:end, :]).float()
550
+ if self.config.logit_scale != 1.0:
551
+ logits = logits * self.config.logit_scale
552
+ chunk_labels = labels[:, start:end].contiguous()
553
+ total_loss = total_loss + F.cross_entropy(
554
+ logits.reshape(-1, logits.size(-1)),
555
+ chunk_labels.reshape(-1),
556
+ ignore_index=-100,
557
+ reduction="sum",
558
+ )
559
+ total_tokens = total_tokens + (chunk_labels != -100).sum(dtype=torch.float32)
560
+ return total_loss / total_tokens.clamp_min(1.0)
561
+
562
  def forward(
563
  self,
564
  input_ids: torch.LongTensor | None = None,
 
570
  use_cache: bool | None = None,
571
  position_ids: torch.LongTensor | None = None,
572
  logits_to_keep: int | torch.Tensor = 0,
573
+ loss_chunk_size: int = 0,
574
+ return_logits: bool = True,
575
  **kwargs,
576
  ) -> CausalLMOutputWithPast | tuple[torch.Tensor, ...]:
577
  if input_ids is None and inputs_embeds is None:
 
589
  **kwargs,
590
  )
591
  hidden_states = outputs.last_hidden_state
 
 
 
 
 
592
  loss = None
593
+ logits = None
594
+ if labels is not None and loss_chunk_size > 0:
595
+ loss = self._chunked_lm_loss(
596
+ hidden_states[:, :-1, :],
597
+ labels[:, 1:],
598
+ loss_chunk_size,
599
+ )
600
+ if return_logits:
601
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
602
+ logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
603
+ if self.config.logit_scale != 1.0:
604
+ logits = logits * self.config.logit_scale
605
+
606
+ if labels is not None and loss is None:
607
+ if logits is None:
608
+ raise ValueError("return_logits must be true when loss_chunk_size is not used")
609
  shift_logits = logits[..., :-1, :].contiguous()
610
  shift_labels = labels[..., 1:].contiguous()
611
  loss = F.cross_entropy(