Instructions to use kuleshov-group/e2d2-wmt with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use kuleshov-group/e2d2-wmt with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="kuleshov-group/e2d2-wmt", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("kuleshov-group/e2d2-wmt", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| from functools import partial | |
| from typing import Any, Dict, Literal, Optional, Tuple, Union | |
| import torch | |
| from tqdm.auto import tqdm | |
| from transformers import ( | |
| GenerationConfig, | |
| LogitsProcessorList, | |
| PreTrainedTokenizer, | |
| StoppingCriteriaList, | |
| ) | |
| from transformers.cache_utils import Cache, DynamicCache | |
| try: | |
| from torch.nn.attention.flex_attention import ( | |
| BlockMask, | |
| and_masks, | |
| create_block_mask, | |
| ) | |
| except ImportError: | |
| BlockMask, and_masks, create_block_mask = None, None, None | |
| from .denoiser_base import ( | |
| Denoiser, | |
| DenoiserConfig, | |
| DenoiserInput, | |
| LossAndNllOutput, | |
| ) | |
| def create_attn_mask(attn_mask): | |
| # noinspection PyUnusedLocal | |
| def padding(b, h, q_idx, kv_idx): | |
| return attn_mask[b, q_idx] & attn_mask[b, kv_idx] | |
| return padding | |
| class DiffusionGenerationConfig(GenerationConfig): | |
| def __init__( | |
| self, | |
| num_steps: int = 1000, | |
| min_t: float = 1e-5, | |
| block_size: Optional[int] = None, | |
| first_hitting: bool = False, | |
| sampling_strategy: Literal["posterior", "predict_then_noise"] = "posterior", | |
| confidence_based_noising: bool = False, | |
| confidence_margin_based_noising: bool = False, | |
| confidence_threshold: float = 1e6, | |
| use_model_output_cache: bool = True, | |
| align_inputs_to_blocks: bool = True, | |
| **kwargs, | |
| ): | |
| """Generation config with additional parameters relevant for diffusion model | |
| sampling. | |
| Args: | |
| num_steps (int): Number of diffusion / iterative refinement steps. | |
| Defaults to 1000. | |
| min_t (float): Minimum time to use. | |
| Diffusion models use t=1 for noise and t=0 for signal. | |
| Setting t=0 exactly can lead to certain numerical instabilities. | |
| Defaults to 1e-5. | |
| block_size (int): Block size to use for semi-autoregressive decoding. | |
| Defaults to None (in which case block_size is set to max_new_tokens). | |
| first_hitting (bool): Whether to use first hitting sampler. | |
| When set to true, rather than following the diffusion time and sampling | |
| from posterior, which can result in no tokens changing between steps, | |
| e.g., for masked diffusion, we explicitly determine the next time step | |
| at which a token will be decoded / generated. | |
| Note: this will negate the `num_steps` parameter, as we will decode one | |
| token at a time, hence, when True, num_steps = seq_length | |
| (or block_size, for semi-autoregressive). | |
| See https://arxiv.org/abs/2409.02908 for details. | |
| Defaults to False. | |
| sampling_strategy (str): Method for transitioning between latents. | |
| Options: | |
| - "posterior" - Compute and sample from the posterior | |
| q(x_s | x_t, x_theta). | |
| - "predict_then_noise" - Sample from the denoising model x_theta, | |
| then add back noise to produce x_s. | |
| Only implemented for absorbing diffusion. | |
| Defaults to "posterior". | |
| confidence_based_noising (bool): When using the "predict_then_noise" | |
| strategy, whether to add noise to random positions or to those that have | |
| the lowest probability under x_theta. | |
| Cannot be used in conjunction with confidence_margin_based_noising. | |
| Defaults to False. | |
| confidence_margin_based_noising (bool): When using the "predict_then_noise" | |
| strategy, whether to add noise to random positions or to those that have | |
| the lowest probability margins under x_theta, where margin is defined as | |
| the absolute difference between the top two probabilities at a given | |
| position. | |
| See https://arxiv.org/abs/2502.06768 for details. | |
| Cannot be used in conjunction with confidence_based_noising. | |
| Defaults to False. | |
| confidence_threshold (float): Confidence threshold to use for sampling. | |
| Any tokens that exceed threshold are decoded. | |
| See https://arxiv.org/abs/2505.22618 for details. | |
| Defaults to 1e6. | |
| use_model_output_cache (bool): Whether to re-use model's output, if sequence | |
| is unchanged, because if xt == xs, we can simply re-use the denoising | |
| model's outputs and save a function evaluation. | |
| Relevant if model.backbone is not time/noise-conditioned. | |
| Defaults to True. | |
| align_inputs_to_blocks (bool): Whether to align input tokens to block size, | |
| e.g., for an input of length C and block size S, context will be C // S, | |
| and generation will begin with a block whose first C % S tokens come | |
| from the input. | |
| kwargs: Keyword arguments passed to `GenerationConfig`. | |
| """ | |
| super().__init__(**kwargs) | |
| self.num_steps = num_steps | |
| self.min_t = min_t | |
| # TODO: assumes we are setting max_new_tokens, which may not be the case! | |
| self.block_size = block_size if block_size is not None else self.max_new_tokens | |
| self.first_hitting = first_hitting | |
| if self.first_hitting: | |
| # TODO: log.warn that this is being overridden | |
| self.num_steps = min(num_steps, self.block_size) | |
| self.sampling_strategy = sampling_strategy | |
| assert not confidence_based_noising or not confidence_margin_based_noising, ( | |
| "Cannot use both `confidence_based_noising` and" | |
| " `confidence_margin_based_noising`." | |
| ) | |
| self.confidence_based_noising = confidence_based_noising | |
| self.confidence_margin_based_noising = confidence_margin_based_noising | |
| self.confidence_threshold = confidence_threshold | |
| self.use_model_output_cache = use_model_output_cache | |
| self.align_inputs_to_blocks = align_inputs_to_blocks | |
| class D3PMConfig(DenoiserConfig): | |
| """Configuration class for D3PM models.""" | |
| model_type = "d3pm" | |
| auto_map = { | |
| "AutoConfig": "diffusion.D3PMConfig", | |
| "AutoModel": "diffusion.D3PM", | |
| "AutoModelForMaskedLM": "diffusion.D3PM", | |
| } | |
| def __init__( | |
| self, | |
| keep_clean_bos: Optional[bool] = None, # Whether to enforce un-noised BOS token | |
| T: int = 1000, | |
| diffusion_type: Literal["absorbing", "uniform"] = "absorbing", | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| self.keep_clean_bos = keep_clean_bos | |
| self.diffusion_type = diffusion_type | |
| self.T = T | |
| class D3PM(Denoiser): | |
| """Denoiser class for D3PM models. | |
| This class implements the Denoiser interface for D3PM models. | |
| """ | |
| config_class = D3PMConfig | |
| def __init__(self, config: D3PMConfig, **kwargs): | |
| super().__init__(config, **kwargs) | |
| self.T = config.T | |
| self.diffusion_type = config.diffusion_type | |
| self._create_static_mask() | |
| def _create_static_mask(self) -> None: | |
| static_mask = torch.ones( | |
| self.config.length, self.config.length, dtype=torch.bool | |
| ) | |
| self.register_buffer( | |
| "static_attention_mask", | |
| static_mask, | |
| ) | |
| self.skip_params_for_push.append("static_attention_mask") | |
| def _sample_q_xt( | |
| self, | |
| x0: torch.LongTensor, | |
| alpha_t: torch.FloatTensor, | |
| context_mask: torch.FloatTensor, | |
| ) -> torch.LongTensor: | |
| """Sample from the pre-defined forward / noising process. | |
| Parameters: | |
| x0 (Tensor): Signal / data sample; | |
| can potentially include context tokens. | |
| alpha_t (Tensor): Amount of signal to retain. | |
| context_mask (Tensor): Indicator of context tokens (to remain | |
| unchanged). | |
| """ | |
| move_indices = torch.rand(*x0.shape, device=x0.device) < (1.0 - alpha_t) | |
| if self.diffusion_type == "absorbing": | |
| xt = torch.where( | |
| (move_indices * (1 - context_mask)).bool(), self.mask_token_id, x0 | |
| ) | |
| if self.config.keep_clean_bos: | |
| xt[..., 0] = x0[..., 0] | |
| return xt # type: ignore | |
| if self.diffusion_type == "uniform": | |
| xt = torch.randint(0, self.vocab_size, x0.shape, device=x0.device) | |
| xt = torch.where(context_mask.bool(), x0, xt) | |
| if self.config.keep_clean_bos: | |
| xt[..., 0] = x0[..., 0] | |
| return xt # type: ignore | |
| raise NotImplementedError( | |
| f"Diffusion type '{self.diffusion_type}' not implemented." | |
| ) | |
| def _prepare_inputs( | |
| self, | |
| input_ids: torch.LongTensor, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| t: Optional[torch.FloatTensor] = None, | |
| past_key_values: Optional[Cache] = None, | |
| ): | |
| # Prepare inputs for D3PM model | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(input_ids) | |
| if context_mask is None: | |
| context_mask = torch.zeros_like(attention_mask) | |
| if torch.is_floating_point(attention_mask): | |
| attention_mask = attention_mask.to(torch.int) | |
| context_mask = context_mask.to(torch.int) | |
| if t is None: | |
| t = torch.rand(input_ids.shape[0], device=input_ids.device) | |
| alpha_t, alpha_t_prime = self.noise_schedule(t) | |
| while alpha_t.ndim < 2: | |
| alpha_t = alpha_t[..., None] | |
| alpha_t_prime = alpha_t_prime[..., None] | |
| xt = self._sample_q_xt( | |
| x0=input_ids, | |
| alpha_t=alpha_t, | |
| context_mask=context_mask, | |
| ) | |
| if ( | |
| context_mask is not None | |
| and context_mask.sum() == 0 | |
| and (attention_mask == 1).all() | |
| ): | |
| processed_attention_mask = None | |
| else: | |
| processed_attention_mask = ( | |
| self.static_attention_mask[None, ...] | |
| & attention_mask[:, None, :] | |
| & attention_mask[..., None] | |
| )[:, None, ...] # Make attention mask 4D | |
| processed_attention_mask = self._preprocess_attention_mask( | |
| processed_attention_mask, dtype=torch.float | |
| ) | |
| if self.training and self.config.train_on_context: | |
| tokens_mask = attention_mask | |
| else: | |
| tokens_mask = attention_mask * (1 - context_mask) | |
| return DenoiserInput( | |
| xt=xt, | |
| x0=input_ids, | |
| attention_mask=processed_attention_mask, | |
| context_mask=context_mask, | |
| tokens_mask=tokens_mask, | |
| t=t, | |
| alpha_t=alpha_t, | |
| alpha_t_prime=alpha_t_prime, | |
| ) | |
| def _prepare_inputs_inference( | |
| self, | |
| input_ids: Optional[torch.LongTensor] = None, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| context: Optional[torch.LongTensor] = None, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| cache: Optional[Dict[str, Any]] = None, | |
| **backbone_kwargs: Any, | |
| ) -> Tuple[DenoiserInput, Dict[str, Any]]: | |
| assert input_ids is not None or context is not None, ( | |
| "Must provide either input_ids or context." | |
| ) | |
| cache = cache if cache is not None else {} | |
| past_key_values = cache.pop("past_key_values", DynamicCache()) | |
| if context is not None: | |
| if input_ids is not None: | |
| if context_mask is None: | |
| context_mask = torch.cat( | |
| [torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1 | |
| ) | |
| input_ids = torch.cat([context, input_ids], dim=-1) | |
| else: | |
| input_ids = context | |
| context_mask = torch.ones_like(input_ids) | |
| if attention_mask is None: | |
| cache_length = self._get_past_key_values_seq_length(past_key_values) | |
| full_seq_length = cache_length + input_ids.shape[-1] | |
| attention_mask = torch.ones( | |
| (input_ids.shape[0], 1, input_ids.shape[1], full_seq_length), | |
| device=input_ids.device, | |
| ) # Make attention mask 4D | |
| attention_mask = self._preprocess_attention_mask( | |
| attention_mask, dtype=torch.float | |
| ) | |
| return DenoiserInput( | |
| xt=input_ids, | |
| attention_mask=attention_mask, | |
| past_key_values=past_key_values, | |
| context_mask=context_mask, | |
| backbone_kwargs=backbone_kwargs | {"use_cache": False}, | |
| ), cache | |
| def _forward( | |
| self, | |
| backbone_output: torch.FloatTensor, | |
| denoiser_inputs: DenoiserInput, | |
| **kwargs, | |
| ) -> torch.FloatTensor: | |
| return torch.log_softmax(backbone_output, dim=-1) # type: ignore | |
| def _compute_loss( | |
| self, | |
| model_output: torch.FloatTensor, | |
| denoiser_inputs: DenoiserInput, | |
| **kwargs: Any, | |
| ) -> LossAndNllOutput: | |
| raise NotImplementedError | |
| def _sample_prior(self, device, batch_size, length): | |
| """Samples from prior / limiting distribution.""" | |
| if self.diffusion_type == "absorbing": | |
| return self.mask_token_id * torch.ones( | |
| (batch_size, length), dtype=torch.int64, device=device | |
| ) | |
| if self.diffusion_type == "uniform": | |
| return torch.randint( | |
| 0, | |
| self.vocab_size, | |
| (batch_size, length), | |
| device=device, | |
| dtype=torch.int64, | |
| ) | |
| raise NotImplementedError( | |
| f"Diffusion type '{self.diffusion_type}' not implemented." | |
| ) | |
| def _compute_posterior( | |
| self, | |
| x: Union[torch.FloatTensor, torch.LongTensor], | |
| xt: torch.LongTensor, | |
| alpha_t: torch.FloatTensor, | |
| alpha_s: torch.FloatTensor, | |
| ) -> torch.FloatTensor: | |
| """Computes posterior / approximate posterior q(x_s | x_t, x), | |
| where x represents clean sequence (as one-hots) or the output of the | |
| denoising model. | |
| Args: | |
| x (Tensor): True (one-hot) / predicted clean signal (B, L, V). | |
| xt (Tensor): Noised signal at time t (B, L). | |
| alpha_t (Tensor): Noise schedule parameter at time t (B, 1, 1). | |
| alpha_s (Tensor): Noise schedule parameter at time s (B, 1, 1). | |
| """ | |
| if self.diffusion_type == "absorbing": | |
| q_xs = x * (alpha_s - alpha_t) | |
| q_xs[..., self.mask_token_id] = 1 - alpha_s[..., 0] | |
| q_xs /= 1 - alpha_t | |
| return q_xs # type: ignore | |
| alpha_ts = alpha_t / alpha_s | |
| d_alpha = alpha_s - alpha_t | |
| xt_one_hot = torch.nn.functional.one_hot(x, self.vocab_size) | |
| limiting_distribution = torch.ones_like(xt_one_hot) / self.vocab_size | |
| if self.diffusion_type == "uniform": | |
| return ( | |
| alpha_t * self.vocab_size * x * xt_one_hot | |
| + (alpha_ts - alpha_t) * xt_one_hot | |
| + d_alpha * x | |
| + (1 - alpha_ts) * (1 - alpha_s) * limiting_distribution | |
| ) / ( | |
| alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None]) | |
| + (1 - alpha_t) | |
| ) | |
| raise NotImplementedError( | |
| f"Diffusion type {self.diffusion_type} not implemented." | |
| ) | |
| def _sample_generation_timesteps( | |
| generation_config: DiffusionGenerationConfig, | |
| max_length: Optional[int] = None, | |
| device: Optional[str] = None, | |
| ) -> torch.FloatTensor: | |
| """Sample timesteps for diffusion generation process.""" | |
| if device is None: | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if max_length is None: | |
| max_length = generation_config.max_new_tokens | |
| if ( | |
| generation_config.first_hitting | |
| # TODO: first-hitting does not work with posterior | |
| and generation_config.sampling_strategy == "posterior" | |
| ): | |
| timesteps = torch.FloatTensor([1.0]) | |
| for i in range(max_length, 0, -1): | |
| u = torch.rand(1) | |
| next_t = timesteps[-1] * u ** (1 / i) | |
| timesteps = torch.cat((timesteps, next_t), dim=0) | |
| return timesteps[1:].to(device) # type: ignore | |
| return torch.linspace( # type: ignore | |
| 1.0, | |
| generation_config.min_t, | |
| generation_config.num_steps + 1, | |
| device=device, | |
| )[:-1] | |
| def _generate_unconditional( | |
| self, | |
| generation_config: DiffusionGenerationConfig, | |
| alpha_t: torch.FloatTensor, | |
| alpha_s: torch.FloatTensor, | |
| denoiser_inputs: Optional[DenoiserInput] = None, | |
| model_output_cache: Optional[Dict[str, torch.FloatTensor]] = None, | |
| cache: Optional[Dict[str, Any]] = None, | |
| running_generation: Optional[torch.LongTensor] = None, | |
| logits_processor: Optional[LogitsProcessorList] = None, | |
| **kwargs: Any, | |
| ) -> Tuple[torch.LongTensor, Dict[str, torch.FloatTensor], Dict[str, Any]]: | |
| cache = cache if cache is not None else {} | |
| if model_output_cache is None: # execute function evaluation | |
| backbone_output = self._backbone_forward( | |
| denoiser_inputs, | |
| fix_cache_length=True, # Do not let kv cache grow on each forward call | |
| **cache, | |
| **kwargs, | |
| ) | |
| backbone_output = {k: v for k, v in backbone_output.items()} | |
| logits = backbone_output.pop("logits") | |
| cache = cache | backbone_output | |
| log_x_theta = self._forward(logits, denoiser_inputs, **kwargs) | |
| if logits_processor is not None: | |
| for token_idx in range(log_x_theta.shape[1]): | |
| # TODO: Looping over token positions like this does not allow for | |
| # some processors, e.g. length penalty which could be applied all | |
| # at once to the entire block, to be applied in parallel. | |
| log_x_theta[:, token_idx] = logits_processor( | |
| input_ids=running_generation, | |
| scores=log_x_theta[:, token_idx], # type: ignore | |
| ) | |
| log_x_theta = torch.log_softmax(log_x_theta, dim=-1) # re-normalize | |
| x_theta = log_x_theta.exp() | |
| else: | |
| x_theta = model_output_cache["x_theta"] | |
| model_output_cache = {"x_theta": x_theta} | |
| prob_check_denom = denoiser_inputs.xt.numel() | |
| if generation_config.sampling_strategy == "posterior": | |
| q_xs = self._compute_posterior( | |
| x_theta, denoiser_inputs.xt, alpha_t, alpha_s | |
| ) | |
| assert abs((q_xs.sum() / prob_check_denom).item() - 1.0) < 1e-6, ( | |
| "Posterior probabilities not summing to 1." | |
| ) | |
| assert q_xs.isnan().sum().item() == 0, "NaN found in the posterior." | |
| xs = self._sample_categorical(q_xs, generation_config.do_sample) | |
| output = torch.where( | |
| (denoiser_inputs.xt != self.mask_token_id).bool(), # type: ignore | |
| denoiser_inputs.xt, | |
| xs, | |
| ) | |
| elif generation_config.sampling_strategy == "predict_and_noise": | |
| assert self.config.diffusion_type == "absorbing", ( | |
| "predict_and_noise decoding strategy only supports absorbing diffusion." | |
| ) | |
| # assert ( | |
| # abs((x_theta.sum() / prob_check_denom).item() - 1.0) < 1e-6 | |
| # ), "Denoising output probabilities not summing to 1." | |
| # assert x_theta.isnan().sum().item() == 0, ( | |
| # "NaN found in the denoising output." | |
| # ) | |
| # Predict | |
| xs = self._sample_categorical(x_theta, generation_config.do_sample) | |
| xs_probs = x_theta.gather(-1, xs[..., None]).squeeze(dim=-1) | |
| output = xs.clone() | |
| # Noise | |
| num_noise_indices = torch.minimum( | |
| ((1 - alpha_s) * generation_config.block_size).to(torch.int), | |
| (denoiser_inputs.xt == self.mask_token_id).sum() - 1, # type: ignore | |
| ) | |
| if generation_config.confidence_based_noising: | |
| conf = x_theta.gather(-1, xs[..., None]).squeeze(-1) | |
| conf = torch.where( # already decoded tokens have 'inf' confidence | |
| (denoiser_inputs.xt == self.mask_token_id).bool(), # type: ignore | |
| conf, | |
| torch.inf, | |
| ) | |
| noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices] | |
| elif generation_config.confidence_margin_based_noising: | |
| top2 = torch.topk(x_theta, k=2, dim=-1).values # shape: (B, L, 2) | |
| conf = (top2[..., 0] - top2[..., 1]).abs() | |
| conf = torch.where( # already decoded tokens have 'inf' confidence | |
| (denoiser_inputs.xt == self.mask_token_id).bool(), # type: ignore | |
| conf, | |
| torch.inf, | |
| ) | |
| noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices] | |
| else: | |
| # TODO: implement random noise indices selection | |
| raise NotImplementedError | |
| output[..., noise_indices] = self.mask_token_id | |
| output = torch.where( | |
| xs_probs >= generation_config.confidence_threshold, xs, output | |
| ) | |
| else: | |
| raise NotImplementedError( | |
| f"Sampling strategy {generation_config.sampling_strategy} not" | |
| " implemented." | |
| ) | |
| return output, model_output_cache, cache # type: ignore | |
| def generate( | |
| self, | |
| inputs: Optional[torch.LongTensor] = None, | |
| generation_config: Optional[DiffusionGenerationConfig] = None, | |
| logits_processor: Optional[LogitsProcessorList] = None, | |
| stopping_criteria: Optional[StoppingCriteriaList] = None, | |
| max_length: Optional[int] = None, | |
| max_new_tokens: Optional[int] = None, | |
| batch_size: Optional[int] = None, | |
| device: Optional[str] = None, | |
| tokenizer: Optional[PreTrainedTokenizer] = None, | |
| disable_pbar: bool = False, | |
| **kwargs: Any, | |
| ) -> torch.LongTensor: | |
| # Setup sampling variables | |
| if generation_config is None: | |
| assert getattr(self, "generation_config", None) is not None, ( | |
| "Generation config must be provided if not present in the model." | |
| ) | |
| generation_config = self.generation_config | |
| if inputs is None: | |
| inputs = torch.ones((batch_size, 1), device=device) * self.bos_token_id | |
| if max_length is None: | |
| if hasattr(generation_config, "max_length"): | |
| max_length = generation_config.max_length | |
| else: | |
| max_length = self.max_length | |
| if max_new_tokens is None: | |
| if hasattr(generation_config, "max_new_tokens"): | |
| max_new_tokens = generation_config.max_new_tokens | |
| else: | |
| max_new_tokens = max_length - inputs.shape[-1] | |
| batch_size = batch_size if batch_size is not None else inputs.shape[0] | |
| assert batch_size == 1, "Batched sampling not supported yet" | |
| if device is None: | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| block_size = generation_config.block_size | |
| max_blocks = max_new_tokens // block_size | |
| # Sample max generation length tensor from prior | |
| accumulated_samples = self._sample_prior( | |
| device=device, | |
| batch_size=batch_size, | |
| length=max_blocks * block_size, | |
| ) | |
| accumulated_samples = torch.cat([inputs, accumulated_samples], dim=-1) | |
| if generation_config.use_cache and inputs.numel() > 0: | |
| cache = self.update_cache( | |
| inputs=inputs[:, : block_size * (inputs.shape[-1] // block_size)] | |
| if generation_config.align_inputs_to_blocks | |
| else inputs, | |
| cache={}, | |
| ) | |
| else: | |
| cache = None | |
| if generation_config.align_inputs_to_blocks: | |
| inputs_offset = ( | |
| block_size * (inputs.shape[-1] // block_size) | |
| if inputs.numel() > 0 | |
| else 0 | |
| ) | |
| else: | |
| inputs_offset = inputs.shape[-1] if inputs.numel() > 0 else 0 | |
| total_NFEs = 0 | |
| timesteps = self._sample_generation_timesteps( # Re-use in every block | |
| generation_config, max_length=block_size, device=device | |
| ) | |
| dt = (1 - generation_config.min_t) / len(timesteps) | |
| block_pbar = tqdm( | |
| range(max_blocks), | |
| desc="Blocks", | |
| leave=True, | |
| disable=disable_pbar, | |
| ) | |
| for block_id in block_pbar: | |
| block_NFEs = 0 | |
| xt = accumulated_samples[ | |
| :, | |
| inputs_offset + (block_id * block_size) : inputs_offset | |
| + ((block_id + 1) * block_size), | |
| ] | |
| if self.mask_token_id not in xt: | |
| continue | |
| step_pbar = tqdm( | |
| timesteps, | |
| desc="T", | |
| total=timesteps.shape[0], | |
| leave=False, | |
| disable=disable_pbar, | |
| ) | |
| model_output_cache = None | |
| context = ( | |
| accumulated_samples[:, : (block_id * block_size) + inputs_offset] | |
| if not generation_config.use_cache | |
| else None | |
| ) | |
| # Used for logit processing | |
| running_generation = accumulated_samples[ | |
| :, | |
| inputs_offset : inputs_offset + (block_id * block_size), | |
| ] | |
| for t in step_pbar: | |
| if model_output_cache is None: | |
| block_NFEs += 1 | |
| total_NFEs += 1 | |
| # t is 0-dim tensor, reshape to (1, 1, 1) for broadcasting | |
| alpha_t, _ = self.noise_schedule(t) | |
| alpha_s, _ = self.noise_schedule(t - dt) | |
| alpha_t = alpha_t[None, None, None] | |
| alpha_s = alpha_s[None, None, None] | |
| denoiser_inputs, cache = self._prepare_inputs_inference( | |
| input_ids=xt, | |
| context=context, | |
| cache=cache if generation_config.use_cache else None, | |
| ) | |
| xs, model_output_cache, cache = self._generate_unconditional( | |
| generation_config=generation_config, | |
| alpha_t=alpha_t, | |
| alpha_s=alpha_s, | |
| denoiser_inputs=denoiser_inputs, | |
| model_output_cache=model_output_cache, | |
| cache=cache, | |
| running_generation=running_generation, # type: ignore | |
| logits_processor=logits_processor, | |
| tokenizer=tokenizer, | |
| **kwargs, | |
| ) | |
| block_pbar.set_postfix( | |
| NFEs=total_NFEs, | |
| block_NFEs=block_NFEs, | |
| ) | |
| if ( | |
| not torch.allclose(xs, denoiser_inputs.xt) | |
| or not generation_config.use_model_output_cache | |
| ): | |
| model_output_cache = None | |
| if not generation_config.use_cache: | |
| xt[..., -block_size:] = xs[..., -block_size:] | |
| else: | |
| xt = xs | |
| if ( | |
| xt == self.mask_token_id | |
| ).sum().item() == 0 and self.config.diffusion_type == "absorbing": | |
| break | |
| accumulated_samples[ | |
| :, | |
| inputs_offset + (block_id * block_size) : inputs_offset | |
| + ((block_id + 1) * block_size), | |
| ] = xt | |
| if tokenizer is not None: # Useful for debugging | |
| print(tokenizer.batch_decode(accumulated_samples)) | |
| if stopping_criteria is not None: | |
| is_done = stopping_criteria( | |
| input_ids=accumulated_samples[ # type: ignore | |
| :, | |
| inputs_offset : inputs_offset + ((block_id + 1) * block_size), | |
| ], | |
| scores=None, # type: ignore | |
| ) | |
| if torch.any(is_done): | |
| accumulated_samples = accumulated_samples[ | |
| :, | |
| : inputs_offset + ((block_id + 1) * block_size), | |
| ] | |
| break | |
| if generation_config.use_cache: | |
| cache = self.update_cache( | |
| inputs=xt, | |
| cache=cache, | |
| ) | |
| return accumulated_samples # type: ignore | |
| class MDLMConfig(D3PMConfig): | |
| """Configuration class for MDLM models.""" | |
| model_type = "mdlm" | |
| auto_map = { | |
| "AutoConfig": "diffusion.MDLMConfig", | |
| "AutoModel": "diffusion.MDLM", | |
| "AutoModelForMaskedLM": "diffusion.MDLM", | |
| } | |
| class MDLM(D3PM): | |
| """Denoiser class for MDLM models.""" | |
| config_class = MDLMConfig | |
| def __init__(self, config: MDLMConfig, **kwargs): | |
| super().__init__(config, **kwargs) | |
| self.neg_infinity = -1e12 | |
| def _forward( | |
| self, | |
| backbone_output: torch.FloatTensor, | |
| denoiser_inputs: DenoiserInput, | |
| **kwargs, | |
| ) -> torch.FloatTensor: | |
| # Zero-mask probability | |
| backbone_output[..., self.mask_token_id] = self.neg_infinity | |
| log_probs = backbone_output - torch.logsumexp( | |
| backbone_output, dim=-1, keepdim=True | |
| ) | |
| # Copy-over unmasked: For the log_probs of the unmasked tokens, set all values | |
| # to -infinity except for the indices corresponding to | |
| # the unmasked tokens. | |
| xt = denoiser_inputs.xt | |
| unmasked_indices = xt != self.mask_token_id | |
| log_probs[unmasked_indices] = self.neg_infinity | |
| log_probs[unmasked_indices, xt[unmasked_indices]] = 0 | |
| return log_probs # type: ignore | |
| def _compute_loss( | |
| self, | |
| model_output: torch.FloatTensor, | |
| denoiser_inputs: DenoiserInput, | |
| **kwargs: Any, | |
| ) -> LossAndNllOutput: | |
| log_p_theta = torch.gather( | |
| input=model_output, dim=-1, index=denoiser_inputs.x0[:, :, None] | |
| ).squeeze(-1) | |
| nlls = ( | |
| log_p_theta | |
| * denoiser_inputs.alpha_t_prime | |
| / (1 - denoiser_inputs.alpha_t) | |
| * denoiser_inputs.tokens_mask | |
| ) | |
| if self.training: | |
| batch_nll = -(log_p_theta * denoiser_inputs.tokens_mask).sum(dim=-1) | |
| else: | |
| batch_nll = nlls.sum(dim=-1) | |
| count = denoiser_inputs.tokens_mask.sum(dim=-1) | |
| token_nll = (batch_nll / count).mean() | |
| return LossAndNllOutput( | |
| loss=token_nll, # type: ignore | |
| nlls=nlls, | |
| other_loss_terms={ | |
| "masked_tokens": (denoiser_inputs.xt == self.mask_token_id).int() | |
| }, | |
| ) | |
| class BD3LMConfig(MDLMConfig): | |
| """Configuration class for BD3LM models.""" | |
| model_type = "bd3lm" | |
| auto_map = { | |
| "AutoConfig": "diffusion.BD3LMConfig", | |
| "AutoModel": "diffusion.BD3LM", | |
| "AutoModelForMaskedLM": "diffusion.BD3LM", | |
| } | |
| def __init__( | |
| self, | |
| block_size: Optional[int] = None, | |
| eval_block_size: Optional[int] = None, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| self.block_size = block_size | |
| self.eval_block_size = ( | |
| eval_block_size if eval_block_size is not None else block_size | |
| ) | |
| class BD3LM(MDLM): | |
| """Denoiser class for BD3LM models.""" | |
| config_class = BD3LMConfig | |
| def __init__(self, config: BD3LMConfig, **kwargs): | |
| super().__init__(config, **kwargs) | |
| # noinspection PyUnusedLocal | |
| def _block_mask( | |
| b, | |
| h, | |
| q_idx, | |
| kv_idx, | |
| block_size: Optional[int] = None, | |
| seq_length: Optional[int] = None, | |
| ) -> torch.Tensor: | |
| del b, h | |
| # Indicate whether token belongs to xt or x0: | |
| xt_flag_q = (q_idx >= seq_length).bool() | |
| xt_flag_kv = (kv_idx >= seq_length).bool() | |
| # Compute block indices | |
| block_q = torch.where( | |
| xt_flag_q, (q_idx - seq_length) // block_size, q_idx // block_size | |
| ) | |
| block_kv = torch.where( | |
| xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size | |
| ) | |
| # **1. Offset Block-Causal Mask (M_OBC) ** | |
| offset_block_causal = (block_q > block_kv) & ~xt_flag_kv & xt_flag_q | |
| # **2. Block Diagonal Mask (M_BD) ** | |
| block_diagonal = (block_q == block_kv) & (xt_flag_q == xt_flag_kv) | |
| # **3. Block-Causal Mask (M_BC) ** | |
| block_causal = (block_q >= block_kv) & ~xt_flag_kv & ~xt_flag_q | |
| # **3. Combine Masks ** | |
| return block_diagonal | offset_block_causal | block_causal | |
| def _create_static_mask(self) -> None: | |
| if self.config.attn_backend == "sdpa": | |
| static_mask = self._block_mask( | |
| b=None, | |
| h=None, | |
| q_idx=torch.arange(self.config.length * 2)[:, None], | |
| kv_idx=torch.arange(self.config.length * 2)[None, :], | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| seq_length=self.config.length, | |
| ) | |
| self.register_buffer( | |
| "static_attention_mask", | |
| static_mask, | |
| ) | |
| self.skip_params_for_push.append("static_attention_mask") | |
| elif self.config.attn_backend == "flex_attention": | |
| mask = partial( | |
| self._block_mask, | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| seq_length=self.config.length, | |
| ) | |
| self.static_attention_mask = create_block_mask( | |
| mask, | |
| B=None, | |
| H=None, | |
| Q_LEN=self.config.length * 2, | |
| KV_LEN=self.config.length * 2, | |
| ) | |
| def _ensure_no_unmasked_blocks( | |
| self, | |
| input_ids: torch.LongTensor, | |
| xt: torch.LongTensor, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| ) -> torch.Tensor: | |
| n_blocks = xt.shape[1] // self.config.block_size | |
| # If context overlaps w/block, ignore it | |
| blocks_without_masks = ((xt == self.mask_token_id) + context_mask).reshape( | |
| -1, n_blocks, self.config.block_size | |
| ).sum(dim=-1) == 0 | |
| if blocks_without_masks.sum() > 0: | |
| num_remasks_per_block = torch.randint( | |
| 0, | |
| self.config.block_size, | |
| blocks_without_masks.shape, | |
| device=xt.device, | |
| ) | |
| rand = torch.rand(xt.shape[0], xt.shape[1], device=xt.device) | |
| perm_indices = torch.argsort( | |
| rand.view(xt.shape[0], n_blocks, self.config.block_size), | |
| stable=True, | |
| dim=-1, | |
| ) | |
| remask_indices = perm_indices <= num_remasks_per_block[..., None] | |
| xt = torch.where( | |
| remask_indices.view(xt.shape[0], xt.shape[1]) | |
| * blocks_without_masks.repeat_interleave(self.config.block_size, dim=1), | |
| self.mask_token_id, | |
| xt, | |
| ) | |
| if self.config.keep_clean_bos: | |
| xt[..., 0] = input_ids[..., 0] | |
| return xt | |
| def _prepare_inputs( | |
| self, | |
| input_ids: torch.LongTensor, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| t: Optional[torch.FloatTensor] = None, | |
| past_key_values: Optional[Cache] = None, | |
| ): | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(input_ids) | |
| if context_mask is None: | |
| context_mask = torch.zeros_like(attention_mask) | |
| if torch.is_floating_point(attention_mask): | |
| attention_mask = attention_mask.to(torch.int) | |
| context_mask = context_mask.to(torch.int) | |
| if t is None: | |
| t = torch.rand( | |
| input_ids.shape[0], | |
| input_ids.shape[1] // self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| device=input_ids.device, | |
| ).repeat_interleave( | |
| self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| dim=-1, | |
| ) | |
| alpha_t, alpha_t_prime = self.noise_schedule(t) | |
| while alpha_t.ndim < 2: | |
| alpha_t = alpha_t[..., None] | |
| alpha_t_prime = alpha_t_prime[..., None] | |
| xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask) | |
| # Ensure each block has at least 1 masked token | |
| if self.training: | |
| xt = self._ensure_no_unmasked_blocks( | |
| input_ids, | |
| xt, | |
| context_mask, | |
| ) | |
| if self.config.attn_backend == "sdpa": | |
| decoder_attention_mask = ( | |
| self.static_attention_mask[None, ...] | |
| & attention_mask.repeat(1, 2)[:, None, :] | |
| & attention_mask.repeat(1, 2)[..., None] | |
| )[:, None, ...] # Make attention mask 4D | |
| decoder_attention_mask = self._preprocess_attention_mask( | |
| decoder_attention_mask, dtype=torch.float | |
| ) | |
| elif self.config.attn_backend == "flex_attention": | |
| if context_mask.any(): | |
| raise NotImplementedError( | |
| "flex_attention with context_mask not implemented yet." | |
| ) | |
| elif attention_mask is not None and (attention_mask != 1).any(): | |
| padding_mask = create_attn_mask( | |
| attention_mask.bool().repeat(2, 2).bool() | |
| ) | |
| dec_masks = [ | |
| partial( | |
| self._block_mask, | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| seq_length=self.config.length, | |
| ), | |
| padding_mask, | |
| ] | |
| decoder_attention_mask = create_block_mask( | |
| and_masks(*dec_masks), | |
| B=input_ids.shape[0], | |
| H=None, | |
| Q_LEN=input_ids.shape[1] * 2, | |
| KV_LEN=input_ids.shape[1] * 2, | |
| ) | |
| else: | |
| decoder_attention_mask = self.static_attention_mask | |
| else: | |
| raise ValueError("Unknown backbone backend") | |
| backbone_input_ids = torch.cat((input_ids, xt), dim=-1) | |
| position_ids = ( | |
| torch.arange(input_ids.shape[1]).repeat(2).to(input_ids.device)[None, :] | |
| ) | |
| if self.training and self.config.train_on_context: | |
| tokens_mask = attention_mask | |
| else: | |
| tokens_mask = attention_mask * (1 - context_mask) | |
| return DenoiserInput( | |
| xt=backbone_input_ids, # type: ignore | |
| x0=input_ids, | |
| attention_mask=decoder_attention_mask, # type: ignore | |
| tokens_mask=tokens_mask, | |
| t=t, | |
| alpha_t=alpha_t, | |
| alpha_t_prime=alpha_t_prime, | |
| backbone_kwargs={ | |
| "cache_position": position_ids[0], | |
| "position_ids": position_ids, | |
| }, | |
| ) | |
| def _prepare_inputs_inference( | |
| self, | |
| input_ids: Optional[torch.LongTensor] = None, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| context: Optional[torch.LongTensor] = None, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| cache: Optional[Dict[str, Any]] = None, | |
| return_updated_cache: bool = False, | |
| **backbone_kwargs: Dict[str, Any], | |
| ) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]: | |
| device = input_ids.device if input_ids is not None else context.device | |
| assert input_ids is not None or context is not None, ( | |
| "Must provide either input_ids or context." | |
| ) | |
| cache = cache if cache is not None else {} | |
| past_key_values = cache.pop("past_key_values", DynamicCache()) | |
| if context is not None: | |
| if input_ids is not None: | |
| input_ids = torch.cat([context, input_ids], dim=-1) | |
| else: | |
| input_ids = context | |
| cache_length = self._get_past_key_values_seq_length(past_key_values) | |
| full_seq_length = cache_length + input_ids.shape[-1] | |
| decoder_attention_mask = self.static_attention_mask[ | |
| None, | |
| None, | |
| cache_length:full_seq_length, | |
| :full_seq_length, | |
| ] # Make attention mask 4D | |
| decoder_attention_mask = self._preprocess_attention_mask( | |
| decoder_attention_mask, dtype=torch.float | |
| ) | |
| position_ids = torch.arange(cache_length, full_seq_length).to(device)[None, :] | |
| return DenoiserInput( | |
| xt=input_ids, | |
| attention_mask=decoder_attention_mask, | |
| context_mask=context_mask, | |
| past_key_values=past_key_values, | |
| backbone_kwargs={ | |
| "position_ids": position_ids, | |
| } | |
| | backbone_kwargs, | |
| ), cache | |
| def _compute_loss( | |
| self, | |
| model_output: torch.FloatTensor, | |
| denoiser_inputs: DenoiserInput, | |
| **kwargs: Any, | |
| ) -> LossAndNllOutput: | |
| input_length = denoiser_inputs.xt.shape[1] // 2 | |
| model_output = model_output[:, input_length:, ...] | |
| return super()._compute_loss( | |
| model_output=model_output, # type: ignore | |
| denoiser_inputs=denoiser_inputs, | |
| **kwargs, | |
| ) | |
| class E2D2Config(BD3LMConfig): | |
| """Configuration class for E2D2 models.""" | |
| model_type = "e2d2" | |
| auto_map = { | |
| "AutoConfig": "diffusion.E2D2Config", | |
| "AutoModel": "diffusion.E2D2", | |
| "AutoModelForMaskedLM": "diffusion.E2D2", | |
| } | |
| def __init__( | |
| self, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| class E2D2(BD3LM): | |
| """Denoiser class for E2D2 models.""" | |
| config_class = E2D2Config | |
| def __init__(self, config: E2D2Config, **kwargs): | |
| super().__init__(config, **kwargs) | |
| # noinspection PyUnusedLocal | |
| def _encoder_block_mask( | |
| b, | |
| h, | |
| q_idx, | |
| kv_idx, | |
| block_size: Optional[int] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Args: | |
| q_idx (Tensor): Query indices. | |
| kv_idx (Tensor): Key indices | |
| b (Optional: int): batch size | |
| h (Optional: int): number of heads | |
| block_size (Optional: int): Defines the block structure. | |
| Returns: | |
| Encoder block-causal attention mask. | |
| """ | |
| # Compute block indices | |
| block_q = q_idx // block_size | |
| block_kv = kv_idx // block_size | |
| # ** Block-Causal Mask ** | |
| return block_q >= block_kv | |
| # noinspection PyUnusedLocal | |
| def _decoder_block_mask( | |
| b, | |
| h, | |
| q_idx, | |
| kv_idx, | |
| block_size: Optional[int] = None, | |
| seq_length: Optional[int] = None, | |
| ) -> torch.Tensor: | |
| # Indicate whether token belongs to xt or x0: | |
| xt_flag_kv = (kv_idx >= seq_length).bool() | |
| # Compute block indices | |
| block_q = q_idx // block_size | |
| block_kv = torch.where( | |
| xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size | |
| ) | |
| # **1. Offset Block-Causal Mask (M_OBC) ** | |
| offset_block_causal = (block_q > block_kv) & ~xt_flag_kv | |
| # **2. Block Diagonal Mask (M_BD) ** | |
| block_diagonal = (block_q == block_kv) & xt_flag_kv | |
| # **3. Combine Masks ** | |
| return block_diagonal | offset_block_causal | |
| def _create_static_mask(self) -> None: | |
| if self.config.attn_backend == "flex_attention": | |
| enc_mask = partial( | |
| self._encoder_block_mask, | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| ) | |
| encoder_attention_mask = create_block_mask( | |
| enc_mask, | |
| B=None, | |
| H=None, | |
| Q_LEN=self.config.length, | |
| KV_LEN=self.config.length, | |
| ) | |
| dec_mask = partial( | |
| self._decoder_block_mask, | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| seq_length=self.config.length, | |
| ) | |
| decoder_attention_mask = create_block_mask( | |
| dec_mask, | |
| B=None, | |
| H=None, | |
| Q_LEN=self.config.length, | |
| KV_LEN=self.config.length * 2, | |
| ) | |
| self.encoder_static_attention_mask = encoder_attention_mask | |
| self.static_attention_mask = decoder_attention_mask | |
| else: | |
| encoder_static_mask = self._encoder_block_mask( | |
| b=None, # type: ignore | |
| h=None, # type: ignore | |
| q_idx=torch.arange(self.config.length)[:, None], | |
| kv_idx=torch.arange(self.config.length)[None, :], | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| ) | |
| decoder_static_mask = self._decoder_block_mask( | |
| b=None, | |
| h=None, | |
| q_idx=torch.arange(self.config.length)[:, None], | |
| kv_idx=torch.arange(self.config.length * 2)[None, :], | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| seq_length=self.config.length, | |
| ) | |
| self.register_buffer( | |
| "encoder_static_attention_mask", | |
| encoder_static_mask, | |
| ) | |
| self.register_buffer( | |
| "static_attention_mask", | |
| decoder_static_mask, | |
| ) | |
| self.skip_params_for_push.append("encoder_static_attention_mask") | |
| self.skip_params_for_push.append("static_attention_mask") | |
| def _prepare_inputs( | |
| self, | |
| input_ids: torch.LongTensor, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| t: Optional[torch.FloatTensor] = None, | |
| past_key_values: Optional[Cache] = None, | |
| ): | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(input_ids) | |
| if context_mask is None: | |
| context_mask = torch.zeros_like(attention_mask) | |
| if torch.is_floating_point(attention_mask): | |
| attention_mask = attention_mask.to(torch.int) | |
| context_mask = context_mask.to(torch.int) | |
| if t is None: | |
| t = torch.rand( | |
| input_ids.shape[0], | |
| input_ids.shape[1] // self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| device=input_ids.device, | |
| ).repeat_interleave( | |
| self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| dim=-1, | |
| ) | |
| alpha_t, alpha_t_prime = self.noise_schedule(t) | |
| while alpha_t.ndim < 2: | |
| alpha_t = alpha_t[..., None] | |
| alpha_t_prime = alpha_t_prime[..., None] | |
| xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask) | |
| # Ensure each block has at least 1 masked token | |
| if self.training: | |
| xt = self._ensure_no_unmasked_blocks( | |
| input_ids, | |
| xt, | |
| context_mask, | |
| ) | |
| if self.config.attn_backend == "sdpa": | |
| decoder_attention_mask = ( | |
| self.static_attention_mask[None, ...] | |
| & attention_mask.repeat(1, 2)[:, None, :] | |
| & attention_mask[..., None] | |
| )[:, None, ...] # Make attention mask 4D | |
| encoder_attention_mask = ( | |
| ( | |
| self.encoder_static_attention_mask[None, ...] | |
| | context_mask[:, None, :] | |
| ) | |
| & attention_mask[:, None, :] | |
| & attention_mask[..., None] | |
| )[:, None, ...] # Make attention mask 4D | |
| encoder_attention_mask = self._preprocess_attention_mask( | |
| encoder_attention_mask, dtype=torch.float | |
| ) | |
| decoder_attention_mask = self._preprocess_attention_mask( | |
| decoder_attention_mask, dtype=torch.float | |
| ) | |
| elif self.config.attn_backend == "flex_attention": | |
| # TODO enable bidirectional attention on context for seq2seq tasks | |
| if context_mask.any(): | |
| raise NotImplementedError( | |
| "flex_attention with context_mask not implemented yet." | |
| ) | |
| elif attention_mask is not None and (attention_mask != 1).any(): | |
| padding_mask = create_attn_mask(attention_mask.bool()) | |
| dec_padding_mask = create_attn_mask(attention_mask.repeat(1, 2).bool()) | |
| enc_masks = [ | |
| partial( | |
| self._encoder_block_mask, | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| ), | |
| padding_mask, | |
| ] | |
| encoder_attention_mask = create_block_mask( | |
| and_masks(*enc_masks), | |
| B=input_ids.shape[0], | |
| H=None, | |
| Q_LEN=input_ids.shape[1], | |
| KV_LEN=input_ids.shape[1], | |
| ) | |
| dec_masks = [ | |
| partial( | |
| self._decoder_block_mask, | |
| block_size=self.config.block_size | |
| if self.training | |
| else self.config.eval_block_size, | |
| seq_length=input_ids.shape[1], | |
| ), | |
| dec_padding_mask, | |
| ] | |
| decoder_attention_mask = create_block_mask( | |
| and_masks(*dec_masks), | |
| B=input_ids.shape[0], | |
| H=None, | |
| Q_LEN=input_ids.shape[1], | |
| KV_LEN=input_ids.shape[1] * 2, | |
| ) | |
| else: | |
| encoder_attention_mask = self.encoder_static_attention_mask | |
| decoder_attention_mask = self.static_attention_mask | |
| else: | |
| raise ValueError("Unknown backbone backend") | |
| position_ids = torch.arange(input_ids.shape[1]).to(input_ids.device)[None, :] | |
| if self.training and self.config.train_on_context: | |
| tokens_mask = attention_mask | |
| else: | |
| tokens_mask = attention_mask * (1 - context_mask) | |
| return DenoiserInput( | |
| xt=xt, | |
| x0=input_ids, | |
| attention_mask=decoder_attention_mask, | |
| tokens_mask=tokens_mask, | |
| t=t, | |
| alpha_t=alpha_t, | |
| alpha_t_prime=alpha_t_prime, | |
| backbone_kwargs={ | |
| "encoder_input_ids": input_ids, | |
| "encoder_attention_mask": encoder_attention_mask, | |
| "encoder_position_ids": position_ids, | |
| "encoder_cache_position": position_ids[0], | |
| }, | |
| ) | |
| def _prepare_inputs_inference( | |
| self, | |
| input_ids: Optional[torch.LongTensor] = None, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| context: Optional[torch.LongTensor] = None, | |
| context_mask: Optional[torch.FloatTensor] = None, | |
| cache: Optional[Dict[str, Any]] = None, | |
| return_updated_cache: bool = False, | |
| **backbone_kwargs: Dict[str, Any], | |
| ) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]: | |
| device = input_ids.device if input_ids is not None else context.device | |
| batch_size = input_ids.shape[0] if input_ids is not None else context.shape[0] | |
| assert input_ids is not None or context is not None, ( | |
| "Must provide either input_ids or context." | |
| ) | |
| if return_updated_cache: # Indicates this is a cache update step | |
| context = input_ids | |
| input_ids = None | |
| position_ids, encoder_position_ids = None, None | |
| if cache is not None: | |
| past_key_values = cache.pop("past_key_values", DynamicCache()) | |
| encoder_past_key_values = cache.pop( | |
| "encoder_past_key_values", DynamicCache() | |
| ) | |
| encoder_last_hidden_state = cache.pop("encoder_last_hidden_state", None) | |
| if input_ids is not None: # Skip enc: nothing new to cache | |
| cache_length = self._get_past_key_values_seq_length(past_key_values) | |
| if encoder_last_hidden_state is not None: | |
| full_seq_length = ( | |
| cache_length | |
| + encoder_last_hidden_state.shape[1] # type: ignore | |
| + input_ids.shape[-1] | |
| ) | |
| else: | |
| full_seq_length = cache_length + input_ids.shape[-1] | |
| encoder_attention_mask = None | |
| position_ids = torch.arange( | |
| cache_length, full_seq_length, device=device | |
| )[None, :] | |
| else: # Caching new tokens in the enc | |
| encoder_cache_length = self._get_past_key_values_seq_length( | |
| encoder_past_key_values | |
| if len(encoder_past_key_values) > 0 | |
| else past_key_values | |
| ) | |
| encoder_full_seq_length = encoder_cache_length + context.shape[-1] | |
| encoder_attention_mask = torch.ones( | |
| ( | |
| 1, | |
| 1, | |
| encoder_full_seq_length - encoder_cache_length, | |
| encoder_full_seq_length, | |
| ), | |
| device=context.device, | |
| ) | |
| encoder_position_ids = torch.arange( | |
| encoder_cache_length, encoder_full_seq_length | |
| ).to(device)[None, :] | |
| encoder_attention_mask = self._preprocess_attention_mask( | |
| encoder_attention_mask, dtype=torch.float | |
| ) | |
| full_seq_length = -1 # Not used | |
| else: # Not using kv-cache | |
| past_key_values = None | |
| encoder_past_key_values, encoder_last_hidden_state = None, None | |
| if context is not None: | |
| context_len = context.shape[1] | |
| encoder_attention_mask = torch.ones( | |
| (1, 1, context_len, context_len), device=context.device | |
| ) | |
| encoder_attention_mask = self._preprocess_attention_mask( | |
| encoder_attention_mask, dtype=torch.float | |
| ) | |
| encoder_position_ids = torch.arange(context_len).to(device)[None, :] | |
| else: | |
| context_len = 0 | |
| encoder_attention_mask = None | |
| if input_ids is not None: | |
| full_seq_length = context_len + input_ids.shape[1] | |
| else: | |
| full_seq_length = context_len | |
| position_ids = torch.arange(context_len, full_seq_length).to(device)[ | |
| None, : | |
| ] | |
| if input_ids is not None: | |
| decoder_attention_mask = torch.ones( | |
| (batch_size, 1, input_ids.shape[1], full_seq_length), | |
| device=device, | |
| ) # Make attention mask 4D | |
| decoder_attention_mask = self._preprocess_attention_mask( | |
| decoder_attention_mask, dtype=torch.float | |
| ) | |
| else: | |
| decoder_attention_mask = None | |
| return DenoiserInput( | |
| xt=input_ids, | |
| attention_mask=decoder_attention_mask, | |
| context_mask=context_mask, | |
| past_key_values=past_key_values, | |
| backbone_kwargs={ | |
| "position_ids": position_ids, | |
| "encoder_input_ids": context, | |
| "encoder_position_ids": encoder_position_ids, | |
| "encoder_attention_mask": encoder_attention_mask, | |
| "encoder_past_key_values": encoder_past_key_values, | |
| "encoder_last_hidden_state": encoder_last_hidden_state, | |
| } | |
| | backbone_kwargs, | |
| ), cache # TODO: potentially returning cache None, violates return type | |
| def _compute_loss( | |
| self, | |
| model_output: torch.FloatTensor, | |
| denoiser_inputs: DenoiserInput, | |
| **kwargs: Any, | |
| ) -> LossAndNllOutput: | |
| # Use MDLM `_compute_loss`, since BD3LM method splits model_output | |
| return super(BD3LM, self)._compute_loss( | |
| model_output=model_output, | |
| denoiser_inputs=denoiser_inputs, | |
| **kwargs, | |
| ) | |