Instructions to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="RWKV-Red-Team/ARWKV-7B-Preview-0.1", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("RWKV-Red-Team/ARWKV-7B-Preview-0.1", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "RWKV-Red-Team/ARWKV-7B-Preview-0.1"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker
docker model run hf.co/RWKV-Red-Team/ARWKV-7B-Preview-0.1
- SGLang
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "RWKV-Red-Team/ARWKV-7B-Preview-0.1" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "RWKV-Red-Team/ARWKV-7B-Preview-0.1" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
- Docker Model Runner
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with Docker Model Runner:
docker model run hf.co/RWKV-Red-Team/ARWKV-7B-Preview-0.1
Upload 3 files
Browse files
- configuration_rwkv_hybrid.py +8 -6
- hybrid_cache.py +31 -110
configuration_rwkv_hybrid.py
CHANGED
|
@@ -15,9 +15,9 @@
|
|
| 15 |
# limitations under the License.
|
| 16 |
"""RwkvHybrid model configuration"""
|
| 17 |
|
| 18 |
-
from
|
| 19 |
-
from
|
| 20 |
-
from
|
| 21 |
from typing import Optional, Union, List
|
| 22 |
|
| 23 |
|
|
@@ -218,15 +218,17 @@ class RwkvHybridConfig(PretrainedConfig):
|
|
| 218 |
raise NotImplementedError(f"Unsupported wkv_version: {self.wkv_version}, \
|
| 219 |
wkv_version must be 6 or 7")
|
| 220 |
|
| 221 |
-
if wkv_layers == "full" or wkv_layers
|
| 222 |
self.wkv_layers = list(range(num_hidden_layers))
|
| 223 |
elif isinstance(wkv_layers, list):
|
| 224 |
if all(isinstance(layer, int) for layer in wkv_layers):
|
| 225 |
self.wkv_layers = wkv_layers
|
| 226 |
else:
|
| 227 |
-
raise ValueError(
|
|
|
|
| 228 |
else:
|
| 229 |
-
raise TypeError(
|
|
|
|
| 230 |
|
| 231 |
# for backward compatibility
|
| 232 |
if num_key_value_heads is None:
|
|
|
|
| 15 |
# limitations under the License.
|
| 16 |
"""RwkvHybrid model configuration"""
|
| 17 |
|
| 18 |
+
from ...configuration_utils import PretrainedConfig
|
| 19 |
+
from ...modeling_rope_utils import rope_config_validation
|
| 20 |
+
from ...utils import logging
|
| 21 |
from typing import Optional, Union, List
|
| 22 |
|
| 23 |
|
|
|
|
| 218 |
raise NotImplementedError(f"Unsupported wkv_version: {self.wkv_version}, \
|
| 219 |
wkv_version must be 6 or 7")
|
| 220 |
|
| 221 |
+
if wkv_layers == "full" or wkv_layers is None:
|
| 222 |
self.wkv_layers = list(range(num_hidden_layers))
|
| 223 |
elif isinstance(wkv_layers, list):
|
| 224 |
if all(isinstance(layer, int) for layer in wkv_layers):
|
| 225 |
self.wkv_layers = wkv_layers
|
| 226 |
else:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"All elements in wkv_layers must be integers.")
|
| 229 |
else:
|
| 230 |
+
raise TypeError(
|
| 231 |
+
"wkv_layers must be either 'full', None, or a list of integers.")
|
| 232 |
|
| 233 |
# for backward compatibility
|
| 234 |
if num_key_value_heads is None:
|
hybrid_cache.py
CHANGED
|
@@ -3,109 +3,69 @@ from typing import Any, Dict, Optional, Union
|
|
| 3 |
from transformers.cache_utils import DynamicCache
|
| 4 |
|
| 5 |
|
| 6 |
-
class
|
| 7 |
def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
|
| 8 |
self.shift_state = shift_state
|
| 9 |
self.wkv_state = wkv_state
|
| 10 |
|
| 11 |
|
| 12 |
-
class
|
| 13 |
def __init__(self, shift_state: torch.Tensor):
|
| 14 |
self.shift_state = shift_state
|
| 15 |
|
| 16 |
|
| 17 |
class BlockState:
|
| 18 |
-
def __init__(
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def __init__(self, shift_states, wkv_states):
|
| 26 |
-
self.wkv_states = wkv_states
|
| 27 |
-
self.shift_states = shift_states
|
| 28 |
-
|
| 29 |
-
@staticmethod
|
| 30 |
-
def create(N, B, C, H, device, dtype):
|
| 31 |
-
result = BlockStateList.empty(N, B, C, H, device, dtype)
|
| 32 |
-
result.wkv_states[:] = 0
|
| 33 |
-
result.wkv_states[:] = 0
|
| 34 |
-
result.shift_states[:] = 0
|
| 35 |
-
return result
|
| 36 |
-
|
| 37 |
-
@staticmethod
|
| 38 |
-
def empty(N, B, C, H, device, dtype):
|
| 39 |
-
wkv_states = torch.empty((N, B, H, C//H, C//H),
|
| 40 |
-
device=device,
|
| 41 |
-
dtype=torch.bfloat16)
|
| 42 |
-
shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
|
| 43 |
-
return BlockStateList(shift_states, wkv_states)
|
| 44 |
-
|
| 45 |
-
def __getitem__(self, layer: int):
|
| 46 |
-
return BlockState(
|
| 47 |
-
TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
|
| 48 |
-
ChannelMixState(self.shift_states[layer, 1]))
|
| 49 |
-
|
| 50 |
-
def __setitem__(self, layer: int, state: BlockState):
|
| 51 |
-
self.shift_states[layer, 0] = state.time_mix_state.shift_state
|
| 52 |
-
self.wkv_states[layer] = state.time_mix_state.wkv_state
|
| 53 |
-
self.shift_states[layer, 1] = state.channel_mix_state.shift_state
|
| 54 |
-
|
| 55 |
|
| 56 |
class HybridCache(DynamicCache):
|
| 57 |
def __init__(self) -> None:
|
| 58 |
super().__init__()
|
| 59 |
self.rwkv_layers = set()
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
for data in cache:
|
| 72 |
-
if not isinstance(data, torch.Tensor):
|
| 73 |
-
memories += data.time_mix_state.wkv_state.numel()
|
| 74 |
-
else:
|
| 75 |
-
memories += data.numel()
|
| 76 |
-
count_info += f", memories={memories / 1024/1024}MB, seq_length={seq_length}"
|
| 77 |
-
return count_info
|
| 78 |
-
|
| 79 |
-
def update(self,
|
| 80 |
-
key_states: Union[int, torch.Tensor],
|
| 81 |
-
value_states: Union[torch.Tensor, BlockState],
|
| 82 |
-
layer_idx: int,
|
| 83 |
-
cache_kwargs: Optional[Dict[str, Any]] = None):
|
| 84 |
-
if isinstance(key_states, int) and not isinstance(value_states, torch.Tensor):
|
| 85 |
self.rwkv_layers.add(layer_idx)
|
| 86 |
-
|
|
|
|
| 87 |
self.key_cache.append([])
|
| 88 |
self.value_cache.append([])
|
| 89 |
-
|
| 90 |
-
if len(self.key_cache[layer_idx]) == 0:
|
| 91 |
self.key_cache[layer_idx].append(key_states)
|
| 92 |
self.value_cache[layer_idx].append(value_states)
|
|
|
|
|
|
|
| 93 |
else:
|
| 94 |
-
self.key_cache[layer_idx][0] =
|
| 95 |
self.value_cache[layer_idx][0] = value_states
|
| 96 |
|
| 97 |
return key_states, value_states
|
| 98 |
|
| 99 |
return super().update(key_states, value_states, layer_idx, cache_kwargs)
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
| 102 |
if layer_idx in self.rwkv_layers:
|
| 103 |
return self.key_cache[layer_idx][0]
|
| 104 |
return super().get_seq_length(layer_idx)
|
| 105 |
|
| 106 |
-
def get_max_length(self):
|
| 107 |
-
return super().get_max_length()
|
| 108 |
-
|
| 109 |
def reorder_cache(self, beam_idx):
|
| 110 |
return super().reorder_cache(beam_idx)
|
| 111 |
|
|
@@ -113,42 +73,3 @@ class HybridCache(DynamicCache):
|
|
| 113 |
if item in self.rwkv_layers:
|
| 114 |
return self.value_cache[item]
|
| 115 |
return super().__getitem__(item)
|
| 116 |
-
|
| 117 |
-
def offload_to_cpu(self):
|
| 118 |
-
for cache in self.value_cache:
|
| 119 |
-
for data in cache:
|
| 120 |
-
if isinstance(data, torch.Tensor):
|
| 121 |
-
data.cpu()
|
| 122 |
-
else:
|
| 123 |
-
data.time_mix_state.wkv_state.cpu()
|
| 124 |
-
data.time_mix_state.shift_state.cpu()
|
| 125 |
-
|
| 126 |
-
def offload_to_cuda(self, device: str):
|
| 127 |
-
for cache in self.value_cache:
|
| 128 |
-
for data in cache:
|
| 129 |
-
if isinstance(data, torch.Tensor):
|
| 130 |
-
data.cuda(device)
|
| 131 |
-
else:
|
| 132 |
-
data.time_mix_state.wkv_state.cuda(device)
|
| 133 |
-
data.time_mix_state.shift_state.cuda(device)
|
| 134 |
-
|
| 135 |
-
def offload_to_device(self, device_type: str, device_id: int = 0):
|
| 136 |
-
for cache in self.value_cache:
|
| 137 |
-
for data in cache:
|
| 138 |
-
if isinstance(data, torch.Tensor):
|
| 139 |
-
method = getattr(data, device_type)
|
| 140 |
-
if device_type == 'cpu':
|
| 141 |
-
method()
|
| 142 |
-
else:
|
| 143 |
-
method(device_id)
|
| 144 |
-
else:
|
| 145 |
-
wkv_state_method = getattr(
|
| 146 |
-
data.time_mix_state.wkv_state, device_type)
|
| 147 |
-
shift_state_method = getattr(
|
| 148 |
-
data.time_mix_state.shift_state, device_type)
|
| 149 |
-
if device_type == 'cpu':
|
| 150 |
-
wkv_state_method()
|
| 151 |
-
shift_state_method()
|
| 152 |
-
else:
|
| 153 |
-
wkv_state_method(device_id)
|
| 154 |
-
shift_state_method(device_id)
|
|
|
|
| 3 |
from transformers.cache_utils import DynamicCache
|
| 4 |
|
| 5 |
|
| 6 |
+
class AttnState:
|
| 7 |
def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
|
| 8 |
self.shift_state = shift_state
|
| 9 |
self.wkv_state = wkv_state
|
| 10 |
|
| 11 |
|
| 12 |
+
class FfnState:
|
| 13 |
def __init__(self, shift_state: torch.Tensor):
|
| 14 |
self.shift_state = shift_state
|
| 15 |
|
| 16 |
|
| 17 |
class BlockState:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
attn_state: AttnState,
|
| 21 |
+
ffn_state: FfnState
|
| 22 |
+
):
|
| 23 |
+
self.attn_state = attn_state
|
| 24 |
+
self.ffn_state = ffn_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
class HybridCache(DynamicCache):
|
| 27 |
def __init__(self) -> None:
|
| 28 |
super().__init__()
|
| 29 |
self.rwkv_layers = set()
|
| 30 |
+
self.key_cache_nums = 0
|
| 31 |
+
self.v_first_cache = None
|
| 32 |
+
|
| 33 |
+
def update(
|
| 34 |
+
self,
|
| 35 |
+
key_states: Union[int, torch.Tensor],
|
| 36 |
+
value_states: Union[torch.Tensor, BlockState],
|
| 37 |
+
layer_idx: int,
|
| 38 |
+
cache_kwargs: Optional[Dict[str, Any]] = None
|
| 39 |
+
):
|
| 40 |
+
if isinstance(key_states, int) and isinstance(value_states, BlockState):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
self.rwkv_layers.add(layer_idx)
|
| 42 |
+
|
| 43 |
+
if layer_idx >= self.key_cache_nums:
|
| 44 |
self.key_cache.append([])
|
| 45 |
self.value_cache.append([])
|
|
|
|
|
|
|
| 46 |
self.key_cache[layer_idx].append(key_states)
|
| 47 |
self.value_cache[layer_idx].append(value_states)
|
| 48 |
+
self.key_cache_nums += 1
|
| 49 |
+
|
| 50 |
else:
|
| 51 |
+
self.key_cache[layer_idx][0] += key_states
|
| 52 |
self.value_cache[layer_idx][0] = value_states
|
| 53 |
|
| 54 |
return key_states, value_states
|
| 55 |
|
| 56 |
return super().update(key_states, value_states, layer_idx, cache_kwargs)
|
| 57 |
|
| 58 |
+
def update_v_first(self, v_first: torch.Tensor):
|
| 59 |
+
self.v_first_cache = v_first
|
| 60 |
+
|
| 61 |
+
def get_v_first(self):
|
| 62 |
+
return self.v_first_cache
|
| 63 |
+
|
| 64 |
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
| 65 |
if layer_idx in self.rwkv_layers:
|
| 66 |
return self.key_cache[layer_idx][0]
|
| 67 |
return super().get_seq_length(layer_idx)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
| 69 |
def reorder_cache(self, beam_idx):
|
| 70 |
return super().reorder_cache(beam_idx)
|
| 71 |
|
|
|
|
| 73 |
if item in self.rwkv_layers:
|
| 74 |
return self.value_cache[item]
|
| 75 |
return super().__getitem__(item)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|