Text Generation
Transformers
Safetensors
English
nemotron-nas
nvidia
llama3.3
conversational
custom_code
Instructions to use nvidia/Llama-3_3-Nemotron-Super-49B-GenRM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/Llama-3_3-Nemotron-Super-49B-GenRM with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="nvidia/Llama-3_3-Nemotron-Super-49B-GenRM", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)
```
```python
# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("nvidia/Llama-3_3-Nemotron-Super-49B-GenRM", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use nvidia/Llama-3_3-Nemotron-Super-49B-GenRM with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```
Use Docker
docker model run hf.co/nvidia/Llama-3_3-Nemotron-Super-49B-GenRM
- SGLang
How to use nvidia/Llama-3_3-Nemotron-Super-49B-GenRM with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```
- Docker Model Runner
How to use nvidia/Llama-3_3-Nemotron-Super-49B-GenRM with Docker Model Runner:
docker model run hf.co/nvidia/Llama-3_3-Nemotron-Super-49B-GenRM
| import copy | |
| import importlib.metadata | |
| import json | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| from packaging import version | |
| from transformers.configuration_utils import PretrainedConfig | |
| from transformers.utils import is_torchdynamo_compiling, logging | |
| logger = logging.get_logger(__name__) | |
class Cache(torch.nn.Module):
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.

    `update`, `get_seq_length` and `get_max_length` raise `NotImplementedError` here and must be
    implemented by subclasses. `reorder_cache` reads and writes `self.key_cache` / `self.value_cache`,
    which this base class does not create; subclasses using it must provide both as per-layer sequences
    of tensors.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """
        Given the sequence length of the new inputs, returns the usable length of the cache.

        Parameters:
            new_seq_length (`int`):
                Number of tokens about to be added.
            layer_idx (`int`, `optional`, defaults to 0):
                Layer whose cached length is queried through `get_seq_length`.

        Return:
            The currently cached length when the cache is unbounded or the new tokens still fit; otherwise
            `max_length - new_seq_length`.
        """
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        #   length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.

        Every layer's key and value tensor is replaced by `index_select(0, beam_idx)`. The index tensor is
        moved to each tensor's own device first, since keys and values are looked up independently.
        """
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    # Exposed as a property: the deprecation message below calls it an *attribute*, and without the
    # decorator `cache.seen_tokens` would evaluate to a bound method instead of the token count.
    @property
    def seen_tokens(self):
        """Deprecated accessor for `_seen_tokens`; returns `None` when the subclass never set it."""
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        return getattr(self, "_seen_tokens", None)
class CacheConfig:
    """
    Base class for cache configs.

    The configuration parameters are the instance attributes: every serialization helper below works on
    `self.__dict__`.
    """

    cache_implementation: None

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a CacheConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters. It is expanded into
                the constructor call, so its keys must be accepted by `cls.__init__`.
            **kwargs: Additional keyword arguments to override dictionary values. Only keys that already exist
                as attributes on the constructed instance are applied; the others are ignored.

        Returns:
            CacheConfig: Instance of CacheConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        # Serialize before opening the file so a non-serializable attribute does not leave a truncated file.
        json_string = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(json_string)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Deep copy of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        yield from copy.deepcopy(self.__dict__).items()

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        unused_kwargs = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                # Collected separately so the caller's kwargs are never modified
                unused_kwargs[key] = value
        return unused_kwargs
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = DynamicCache()
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.

        Raises:
            KeyError: if `layer_idx` is not smaller than the number of cached layers.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens (counted once per step, on the first layer only)
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # First time this layer is seen: the new states become the layer's cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Otherwise grow the existing cache along the sequence dimension
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format (one
        `(key, value)` tuple per layer). Used for backward compatibility."""
        return tuple((self.key_cache[layer_idx], self.value_cache[layer_idx]) for layer_idx in range(len(self)))

    # Alternate constructor: takes `cls`, so it must be a classmethod to be callable as
    # `DynamicCache.from_legacy_cache(past_key_values)`.
    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility. `None` yields an empty cache."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        # Nothing to crop when the cache is already short enough
        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = DynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            out.append(current_split)
        return out

    # Alternate constructor: takes `cls`, so it must be a classmethod.
    @classmethod
    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`. The number of layers is taken from the first split."""
        cache = cls()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
            cache.update(layer_keys, layer_values, idx)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
    """

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        # Device each layer's states lived on when first cached; prefetching restores tensors there
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

    def prefetch_layer(self, layer_idx: int):
        """Starts prefetching the next layer cache"""
        if layer_idx >= len(self):
            return
        with torch.cuda.stream(self.prefetch_stream):
            # Prefetch next layer tensors to GPU
            target = self.original_device[layer_idx]
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(target, non_blocking=True)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(target, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        """Moves the previous layer cache to the CPU"""
        num_layers = len(self)
        if num_layers <= 2:
            return
        # We do it on the default stream so it occurs after all earlier computations on these tensors are done
        previous = (layer_idx - 1) % num_layers
        self.key_cache[previous] = self.key_cache[previous].to("cpu", non_blocking=True)
        self.value_cache[previous] = self.value_cache[previous].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."""
        if not layer_idx < len(self):
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

        # Evict the previous layer if necessary
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)

        # Load current layer cache to its original device if not already there
        layer_device = self.original_device[layer_idx]
        self.prefetch_stream.synchronize()
        keys = self.key_cache[layer_idx]
        values = self.value_cache[layer_idx]

        # Now deal with beam search ops which were delayed
        if self.beam_idx is not None:
            self.beam_idx = self.beam_idx.to(layer_device)
            keys = keys.index_select(0, self.beam_idx)
            values = values.index_select(0, self.beam_idx)

        # Prefetch the next layer
        self.prefetch_layer((layer_idx + 1) % len(self))
        return (keys, values)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        layer_is_new = len(self.key_cache) <= layer_idx
        if layer_is_new:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(layer_idx)
        else:
            # Indexing `self` brings the layer back to its device (and applies any pending beam reorder)
            cached_keys, cached_values = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([cached_keys, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([cached_values, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
    # if a method is not supposed to be supported in a subclass we should set it to None
    from_legacy_cache = None

    to_legacy_cache = None
class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Memoized re-rotation (cos, sin) pairs, keyed by the number of new tokens in an update
        self.cos_sin_rerotation_cache = {}
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    # Takes no `self`, yet is invoked as `self._rotate_half(...)`: it has to be a staticmethod, otherwise
    # the instance would be passed as `x` together with the tensor and the call would raise a TypeError.
    @staticmethod
    def _rotate_half(x):
        """Splits the last dimension of `x` in two halves `(x1, x2)` and returns `cat(-x2, x1)`."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Applies the rotation described by `cos` / `sin` to `key_states`."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the `(cos, sin)` pair used to re-rotate the kept keys when the window shifts by
        `key_states.shape[-2]` positions. The pair is computed once per shift size and memoized in
        `self.cos_sin_rerotation_cache`.
        """
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. `None` is treated as an empty dictionary.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to None in the signature; normalize it so the lookups below cannot fail
        if cache_kwargs is None:
            cache_kwargs = {}

        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` features are rotated; the rest pass through
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    All per-layer key/value tensors are pre-allocated once with a fixed shape of
    `(max_batch_size, num_key_value_heads, max_cache_len, head_dim)` and are only ever mutated in place.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. When `None`,
            `config.max_position_embeddings` is used instead.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            # Allocate each layer exactly once, always with the resolved `self.dtype`, so the registered buffers
            # and the tensors used under compilation agree on dtype.
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            # Notes:
            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that
            #    case it is not needed anyway)
            # 2. `torch.export()` requires mutations to be registered as buffers.
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                self.register_buffer(f"value_cache_{idx}", new_layer_value_cache)
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache. When `cache_kwargs` is `None` or has no `cache_position`,
                the new states are copied over the whole layer cache.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens (counted once per forward pass, on the first layer)
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # `cache_kwargs` defaults to `None`, so it must not be dereferenced unconditionally.
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None

        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # The length is tracked by the `_seen_tokens` counter maintained in `update`; `layer_idx` is not used.
        # TODO: deprecate this function in favor of `cache_position`
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        self._seen_tokens = 0
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
class SlidingWindowCache(StaticCache):
    """
    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
    Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
    if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
    we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.

    The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
            37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
            55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

    We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. The allocated cache length is
            `min(config.sliding_window, max_cache_len)`.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        # Validate before initializing the parent so that no (potentially very large) full-length cache is
        # allocated only to be thrown away: the parent is initialized exactly once, with the clamped length.
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        max_cache_len = min(config.sliding_window, max_cache_len)
        super().__init__(
            config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """
        Writes `key_states` / `value_states` into the window of layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`):
                Must contain `cache_position`, the positions of the new tokens.

        Return:
            A tuple with the key and value states to attend to. When more positions than the window length are
            passed at once, the full, un-truncated input states are returned.

        Raises:
            ValueError: if `cache_kwargs` is `None` or does not provide `cache_position`.
        """
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        if cache_position is None:
            raise ValueError("`SlidingWindowCache.update` requires `cache_position` to be passed in `cache_kwargs`.")

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
        if cache_position.shape[0] > self.max_cache_len:
            k_out = key_states[:, :, -self.max_cache_len :, :]
            v_out = value_states[:, :, -self.max_cache_len :, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
        # Once the last written position reaches the end of the window, roll every slot one step to the left.
        to_shift = cache_position >= self.max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        # `Tensor.to` is not in-place: the result has to be kept for the device move to take effect.
        cache_position = cache_position.to(device=k_out.device)
        try:
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out

    def get_max_length(self) -> Optional[int]:
        """Returns `None`: the window has a fixed size, so there is no limit on the sequence length."""
        # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
        return None

    def reset(self):
        """Zeroes every layer's key/value tensors in place, preserving the tensor objects."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.

    Example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
        >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

        >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

        >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
        >>> self_attention_cache = DynamicCache()
        >>> cross_attention_cache = DynamicCache()
        >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        super().__init__()
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

        # Per-layer flag: `True` once the cross-attention cache of that layer already holds states.
        self.is_updated = {}
        for layer_idx in range(len(cross_attention_cache.key_cache)):
            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.key_cache[layer_idx],
                self.self_attention_cache.value_cache[layer_idx],
                self.cross_attention_cache.key_cache[layer_idx],
                self.cross_attention_cache.value_cache[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        if len(self.cross_attention_cache) > 0:
            # Each legacy layer entry is the self-attention tuple followed by the cross-attention tuple.
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            ):
                legacy_cache += (self_attn + cross_attn,)
        else:
            legacy_cache = self.self_attention_cache.to_legacy_cache()
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                # The first two entries of a layer are the self-attention key/value states ...
                key_states, value_states = past_key_values[layer_idx][:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                # ... and the remaining entries, when present, are the cross-attention key/value states.
                if len(past_key_values[layer_idx]) > 2:
                    key_states, value_states = past_key_values[layer_idx][2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # `key_cache` may hold tensors (static caches), so emptiness is checked by length rather than truthiness.
        if len(self.self_attention_cache.key_cache) <= layer_idx:
            return 0
        # Delegate to the self-attention cache, which knows how to measure its own length. Counting non-zero
        # slots of the first batch member/head would return a tensor and miscount all-zero key vectors.
        return self.self_attention_cache.get_seq_length(layer_idx)

    def reset(self):
        """
        Resets whichever of the two underlying caches expose a `.reset()` method and clears the `is_updated` flags.

        Raises:
            ValueError: if neither the self-attention nor the cross-attention cache has a `.reset()` method.
        """
        self_resettable = hasattr(self.self_attention_cache, "reset")
        cross_resettable = hasattr(self.cross_attention_cache, "reset")
        if not self_resettable and not cross_resettable:
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        if self_resettable:
            self.self_attention_cache.reset()
        if cross_resettable:
            self.cross_attention_cache.reset()
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        """Raises a `ValueError` naming `method` unless both underlying caches are `DynamicCache` instances."""
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
        self.check_dynamic_cache(self.crop.__name__)
        # Only the self-attention cache is cropped here.
        self.self_attention_cache.crop(maximum_length)

    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        self.check_dynamic_cache(self.batch_split.__name__)
        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

        out = []
        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
            out.append(EncoderDecoderCache(self_attn, cross_attn))
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        self_attention_cache = DynamicCache()
        cross_attention_cache = DynamicCache()
        for idx in range(len(splits[0])):
            # Stitch the per-split tensors of this layer back together along the batch dimension.
            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
            self_attention_cache.update(layer_keys, layer_values, idx)

            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
            cross_attention_cache.update(layer_keys, layer_values, idx)
        return cls(self_attention_cache, cross_attention_cache)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)
class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
    and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
    and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`, *optional*, defaults to `"cpu"`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
        ```
    """

    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
        super().__init__()
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )
        # Even-indexed layers use the sliding-window layout, odd-indexed layers the global one. The Python list
        # is kept around for the allocation loop so the device tensor is not converted to bool once per layer.
        layer_is_sliding = [not bool(i % 2) for i in range(config.num_hidden_layers)]
        self.is_sliding = torch.tensor(layer_is_sliding, dtype=torch.bool, device=device)
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        for sliding in layer_is_sliding:
            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
            cache_shape = sliding_cache_shape if sliding else global_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        """Sliding-window write: rolls the layer cache when the window is full, then writes the new states."""
        if cache_position.shape[0] > max_cache_len:
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        # Once the last written position reaches the end of the window, roll every slot one step to the left.
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        """Global-attention write: stores the new states at `cache_position`; `max_cache_len` is unused."""
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """
        Updates the cache of layer `layer_idx` with the new `key_states` and `value_states`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`):
                Must contain `cache_position`. A truthy `sliding_window` entry selects the sliding-window update,
                otherwise the static (global) update is used.

        Return:
            A tuple containing the key and value states to attend to.

        Raises:
            ValueError: if `cache_kwargs` is `None` or does not provide `cache_position`.
        """
        if cache_kwargs is None or cache_kwargs.get("cache_position") is None:
            raise ValueError("`HybridCache.update` requires `cache_position` to be passed in `cache_kwargs`.")
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        # The allocated length of this layer's cache (window size or full length) is passed as `max_cache_len`.
        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )

    def get_max_length(self) -> Optional[int]:
        """Returns the `max_cache_len` the cache was created with."""
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        """Always returns `None`: this cache does not track how many tokens it has seen."""
        return None

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
class MambaCache:
    """
    Cache for mamba model which does not have attention mechanism and key value states.

    Arguments:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        dtype (*optional*, defaults to `torch.float16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used to initializing the cache.
        intermediate_size: (`int`):
            Model's intermediate_size taken from config.
        ssm_state_size: (`int`):
            Model's state_size taken from config.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config
        conv_states: (`torch.Tensor`):
            A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states

    Example:

        ```python
        >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache

        >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

        >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv = outputs.past_key_values
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[str] = None,
        **kwargs,
    ):
        self.dtype = dtype
        self.max_batch_size = max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel

        # Both state tensors share their three leading dimensions and differ only in the last one.
        shared_dims = (config.num_hidden_layers, self.max_batch_size, self.intermediate_size)
        self.conv_states: torch.Tensor = torch.zeros(
            *shared_dims, self.conv_kernel_size, device=device, dtype=dtype
        )
        self.ssm_states: torch.Tensor = torch.zeros(*shared_dims, self.ssm_state_size, device=device, dtype=dtype)

        # Tag both tensors as static addresses for torch dynamo.
        for state in (self.conv_states, self.ssm_states):
            torch._dynamo.mark_static_address(state)

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        """Shifts the convolution state of `layer_idx` one step left, writes `new_conv_state` at the clamped
        `cache_position` and returns the updated state of that layer."""
        write_index = cache_position.clamp(0, self.conv_kernel_size - 1)

        # `roll` returns a new tensor, so the stored state is untouched until it is overwritten below.
        rolled = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
        rolled[:, :, write_index] = new_conv_state.to(rolled.device)

        # Zero then add in place: equivalent to an assignment while keeping the underlying storage.
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += rolled
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        """Stores `new_ssm_state` as the SSM state of `layer_idx` and returns the stored state."""
        target_device = self.ssm_states.device
        self.ssm_states[layer_idx] = new_ssm_state.to(target_device)
        return self.ssm_states[layer_idx]

    def reset(self):
        """Zeroes the convolution and SSM states in place."""
        for state in (self.conv_states, self.ssm_states):
            state.zero_()