maximilian-schall-ppx commited on
Commit
33bcb86
·
1 Parent(s): dfb20d4

Fix transformers 5.x compatibility (create_causal_mask kwargs)

Browse files
Files changed (1) hide show
  1. modeling.py +21 -9
modeling.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Callable, Literal
2
  import numpy as np
3
  import torch
@@ -11,6 +12,14 @@ from .configuration import PPLXQwen3Config
11
  from transformers import AutoTokenizer
12
  from .st_quantize import FlexibleQuantizer
13
 
 
 
 
 
 
 
 
 
14
 
15
  # From modeling_t5gemma.py
16
  def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
@@ -57,16 +66,19 @@ class PPLXQwen3Model(Qwen3Model):
57
  inputs_embeds = self.embed_tokens(input_ids)
58
  input_ids = None
59
 
60
- attention_mask = {
61
- "full_attention": create_causal_mask(
62
- config=self.config,
63
- inputs_embeds=inputs_embeds,
64
- attention_mask=attention_mask,
65
- past_key_values=None,
66
- position_ids=position_ids,
67
- or_mask_function=bidirectional_mask_function(attention_mask),
68
- )
69
  }
 
 
 
 
 
70
 
71
  outputs = super().forward(
72
  input_ids=input_ids,
 
1
+ import inspect
2
  from typing import Callable, Literal
3
  import numpy as np
4
  import torch
 
12
  from transformers import AutoTokenizer
13
  from .st_quantize import FlexibleQuantizer
14
 
15
+ # The transformers `create_causal_mask` signature has shifted over releases
16
+ # (the embeds kwarg was renamed `input_embeds` -> `inputs_embeds`, and
17
+ # `cache_position` was eventually dropped). Probe the actual signature at import
18
+ # time so this works on any installed release, including dev/main builds.
19
+ _CCM_PARAMS = inspect.signature(create_causal_mask).parameters
20
+ _CCM_EMBEDS_KEY = "inputs_embeds" if "inputs_embeds" in _CCM_PARAMS else "input_embeds"
21
+ _CCM_ACCEPTS_CACHE_POSITION = "cache_position" in _CCM_PARAMS
22
+
23
 
24
  # From modeling_t5gemma.py
25
  def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
 
66
  inputs_embeds = self.embed_tokens(input_ids)
67
  input_ids = None
68
 
69
+ mask_kwargs = {
70
+ "config": self.config,
71
+ _CCM_EMBEDS_KEY: inputs_embeds,
72
+ "attention_mask": attention_mask,
73
+ "past_key_values": None,
74
+ "position_ids": position_ids,
75
+ "or_mask_function": bidirectional_mask_function(attention_mask),
 
 
76
  }
77
+ if _CCM_ACCEPTS_CACHE_POSITION:
78
+ mask_kwargs["cache_position"] = torch.arange(
79
+ inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
80
+ )
81
+ attention_mask = {"full_attention": create_causal_mask(**mask_kwargs)}
82
 
83
  outputs = super().forward(
84
  input_ids=input_ids,