Instructions to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="RWKV-Red-Team/ARWKV-7B-Preview-0.1", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("RWKV-Red-Team/ARWKV-7B-Preview-0.1", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "RWKV-Red-Team/ARWKV-7B-Preview-0.1"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker
docker model run hf.co/RWKV-Red-Team/ARWKV-7B-Preview-0.1
- SGLang
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "RWKV-Red-Team/ARWKV-7B-Preview-0.1" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "RWKV-Red-Team/ARWKV-7B-Preview-0.1" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
- Docker Model Runner
How to use RWKV-Red-Team/ARWKV-7B-Preview-0.1 with Docker Model Runner:
docker model run hf.co/RWKV-Red-Team/ARWKV-7B-Preview-0.1
| import os | |
| import torch | |
| from einops import rearrange | |
| import math | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from .configuration_rwkv_hybrid import RwkvHybridConfig | |
| from typing import Optional | |
| from .hybrid_cache import HybridCache, AttnState, BlockState | |
| try: | |
| import triton # pylint: disable=F401 | |
| from rwkvfla.ops.rwkv7 import ( | |
| fused_recurrent_rwkv7, | |
| chunk_rwkv7, | |
| native_recurrent_rwkv7, | |
| fused_addcmul_rwkv7, | |
| ) # pylint: disable=C0411 | |
| from rwkvfla.ops.rwkv6 import ( | |
| fused_recurrent_rwkv6, | |
| chunk_rwkv6, | |
| native_recurrent_rwkv6, | |
| ) | |
| except ImportError: | |
| from rwkvfla.ops.rwkv7 import native_recurrent_rwkv7 # pylint: disable=C0411 | |
| from rwkvfla.ops.rwkv6 import native_recurrent_rwkv6 | |
| from rwkvfla.ops.rwkv7 import torch_addcmul_rwkv7 | |
| fused_recurrent_rwkv7 = native_recurrent_rwkv7 | |
| chunk_rwkv7 = native_recurrent_rwkv7 | |
| chunk_rwkv6 = native_recurrent_rwkv6 | |
| fused_recurrent_rwkv6 = native_recurrent_rwkv6 | |
| fused_addcmul_rwkv7 = torch_addcmul_rwkv7 | |
| from rwkvfla.utils import check_pytorch_version | |
# Pick a compile decorator. When check_pytorch_version("2.6") is true
# (presumably PyTorch >= 2.6) use torch.compile; otherwise fall back to an
# identity decorator.
if check_pytorch_version("2.6"):
    compile_decorator = torch.compile
    # Global TorchDynamo setting: raise the number of compiled variants kept
    # per code object before Dynamo stops recompiling it.
    torch._dynamo.config.cache_size_limit = 512
else:
    def compile_decorator(func):
        """Return ``func`` unchanged (stand-in for ``torch.compile``)."""
        return func

# Backend for the RWKV-7 WKV recurrence, chosen with the WKV_MODE environment
# variable (case-insensitive): "fused" (default), "chunk" or "pytorch".
wkv_mode = os.environ.get("WKV_MODE", "fused")
wkv_mode = wkv_mode.lower()
# Validate with an explicit raise: an ``assert`` is stripped under ``python -O``,
# and an unknown mode would then silently run the fused kernel.
if wkv_mode not in ['fused', 'chunk', 'pytorch']:
    raise ValueError(
        f"Unsupported WKV_MODE {wkv_mode!r}; expected one of 'fused', 'chunk', 'pytorch'."
    )
class Rwkv_Tmix_x070(nn.Module):
    """RWKV-7 ("x070") time-mixing layer.

    From the current token and the previous token (token shift) it builds the
    receptance ``r``, decay ``w``, key ``k``, value ``v`` and in-context
    learning rate ``a`` streams (plus an optional output gate ``g``). It then
    runs the WKV-7 recurrence with the kernel selected by the module-level
    ``wkv_mode`` and projects the result back to ``hidden_size``.

    Args:
        args: model config. Reads ``hidden_size``, ``head_size``,
            ``num_wkv_heads``, ``num_hidden_layers``, ``wkv_has_gate``,
            ``wkv_has_group_norm`` and ``head_size_divisor``.
        layer_id: index of this layer. Layer 0 produces ``v_first``; later
            layers mix their values towards it.
        **kwargs: accepted and ignored.
    """

    def __init__(self, args: "RwkvHybridConfig", layer_id, **kwargs):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.hidden_size = args.hidden_size
        self.head_size = args.head_size
        self.n_head = args.num_wkv_heads
        assert args.hidden_size % self.n_head == 0
        H = self.n_head
        N = self.head_size

        # Per-channel token-shift interpolation weights, one per stream.
        self.x_r = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_w = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        # Low-rank adapter widths. post_init() reads shapes from the
        # parameters themselves, so these constants are the only definition.
        D_DECAY_LORA = 64
        D_AAA_LORA = 64
        D_MV_LORA = 32
        D_GATE_LORA = 128

        # Decay: w = -softplus(-(w0 + tanh(x @ w1) @ w2)) - 0.5
        self.w1 = nn.Parameter(torch.Tensor(args.hidden_size, D_DECAY_LORA))
        self.w2 = nn.Parameter(torch.Tensor(D_DECAY_LORA, args.hidden_size))
        self.w0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        # In-context learning rate: a = sigmoid(a0 + (x @ a1) @ a2)
        self.a1 = nn.Parameter(torch.Tensor(args.hidden_size, D_AAA_LORA))
        self.a2 = nn.Parameter(torch.Tensor(D_AAA_LORA, args.hidden_size))
        self.a0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        # Value-residual mix weight towards v_first: sigmoid(v0 + (x @ v1) @ v2)
        self.v1 = nn.Parameter(torch.Tensor(args.hidden_size, D_MV_LORA))
        self.v2 = nn.Parameter(torch.Tensor(D_MV_LORA, args.hidden_size))
        self.v0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        if self.args.wkv_has_gate:
            # Output gate: g = sigmoid(x @ g1) @ g2
            self.x_g = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
            self.g1 = nn.Parameter(torch.Tensor(args.hidden_size, D_GATE_LORA))
            self.g2 = nn.Parameter(torch.Tensor(D_GATE_LORA, args.hidden_size))

        self.k_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))  # scales k into kk
        self.k_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))  # blends k towards k * a
        self.r_k = nn.Parameter(torch.Tensor(H, N))  # per-head weights of the r*k bonus term
        # Not referenced by forward(); kept so the module structure is unchanged.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.receptance = nn.Linear(
            args.hidden_size, args.hidden_size, bias=False)
        self.key = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.value = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.output = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        if self.args.wkv_has_group_norm:
            self.ln_x = nn.GroupNorm(
                H, args.hidden_size, eps=(1e-5) * (args.head_size_divisor**2)
            )

    def post_init(self):
        """Initialise all parameters in place following the RWKV-7 recipe.

        The token-shift weights and the decay bias ``w0`` depend on the layer
        depth. LoRA "down" matrices start at zero and "up" matrices are
        orthogonal with scale 0.1. The linear projections and ``r_k`` are
        zeroed, and the group norm is reset to identity.

        Every parameter keeps its identity, device and dtype; values are
        written with ``copy_`` / ``nn.init`` inside ``torch.no_grad()``.
        """
        with torch.no_grad():
            C = self.args.hidden_size
            n_layers = self.args.num_hidden_layers
            # max(..., 1) avoids ZeroDivisionError for a single-layer model.
            ratio_0_to_1 = self.layer_id / max(n_layers - 1, 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (self.layer_id / n_layers)  # 1 to ~0
            # Channel ramp ddd[0, 0, i] = i / C in the default dtype.
            ddd = torch.tensor([i / C for i in range(C)]).reshape(1, 1, C)

            # nn.init.constant_ (Tensor.fill_) only accepts scalars, so the
            # per-channel initial values are written with copy_ instead.
            self.x_r.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
            self.x_w.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
            self.x_k.copy_(
                1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1))
            self.x_v.copy_(
                1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1))
            self.x_a.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))

            def ortho_init(x, scale):
                """Fill 2D/3D tensor ``x`` orthogonally (computed in fp32) and return it."""
                shape = x.shape
                original_dtype = x.dtype
                x_fp32 = x.float()
                if len(shape) == 2:
                    gain = math.sqrt(shape[0] / shape[1]
                                     ) if shape[0] > shape[1] else 1
                    nn.init.orthogonal_(x_fp32, gain=gain * scale)
                elif len(shape) == 3:
                    gain = math.sqrt(shape[1] / shape[2]
                                     ) if shape[1] > shape[2] else 1
                    for i in range(shape[0]):
                        nn.init.orthogonal_(x_fp32[i], gain=gain * scale)
                else:
                    raise ValueError(
                        "ortho_init only supports 2D or 3D tensors")
                x.data.copy_(x_fp32.to(original_dtype))
                return x

            def ortho_init_param_(param, scale):
                """Orthogonally initialise ``param`` in place, keeping its shape/device/dtype."""
                param.copy_(ortho_init(torch.zeros(param.shape), scale))

            # Decay LoRA and per-channel decay bias.
            nn.init.zeros_(self.w1)
            ortho_init_param_(self.w2, 0.1)
            decay_exponent = 0.85 + 1.0 * ratio_0_to_1**0.5
            decay_speed = torch.tensor(
                [-7 + 5 * (n / (C - 1)) ** decay_exponent for n in range(C)]
            )
            self.w0.copy_(decay_speed.reshape(1, 1, C) + 0.5)

            # In-context learning-rate LoRA.
            nn.init.zeros_(self.a1)
            ortho_init_param_(self.a2, 0.1)
            nn.init.zeros_(self.a0)

            # Value-residual LoRA.
            nn.init.zeros_(self.v1)
            ortho_init_param_(self.v2, 0.1)
            nn.init.constant_(self.v0, 1.0)

            if self.args.wkv_has_gate:
                nn.init.zeros_(self.g1)
                ortho_init_param_(self.g2, 0.1)
                self.x_g.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))

            nn.init.constant_(self.k_k, 0.85)
            nn.init.constant_(self.k_a, 1.0)
            nn.init.zeros_(self.r_k)
            nn.init.zeros_(self.receptance.weight)
            nn.init.zeros_(self.key.weight)
            nn.init.zeros_(self.value.weight)
            nn.init.zeros_(self.output.weight)
            if self.args.wkv_has_group_norm:
                nn.init.ones_(self.ln_x.weight)
                nn.init.zeros_(self.ln_x.bias)

    def apply_wkv7_state(
        self, r, k, v, w, a, b, s,
        output_final_state,
        cu_seqlens
    ):
        """Run the WKV-7 recurrence on streams laid out as ``(batch, seq, hidden)``.

        Args:
            r, k, v, w, a, b: per-token streams with heads flattened into the
                last dimension; they are split into ``self.n_head`` heads here.
            s: recurrent state passed to the kernel as ``initial_state``.
            output_final_state: passed to the fused/chunk kernels.
            cu_seqlens: cumulative sequence lengths for the fused/chunk kernels.

        Returns:
            ``(x, state)``. ``x`` has heads flattened back into the last
            dimension, and ``state`` is what the kernel returns as its final
            state (presumably ``None`` when not requested).
        """
        if wkv_mode == 'pytorch':
            # NOTE(review): this path always requests the final state and does
            # not forward cu_seqlens, so packed variable-length batches are
            # not separated. Confirm 'pytorch' mode is never combined with
            # cu_seqlens.
            r, w, k, v, a, b = map(lambda x: rearrange(
                x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
            o, state = native_recurrent_rwkv7(
                r=r, k=k, v=v, w=w,
                a=a, b=b,
                scale=1.0,
                initial_state=s,
                output_final_state=True,
                head_first=True,
            )
            x = rearrange(o, "b h l d -> b l (h d)")
        else:
            r, w, k, v, a, b = map(lambda x: rearrange(
                x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
            wkv7_func = chunk_rwkv7 if wkv_mode == 'chunk' else fused_recurrent_rwkv7
            o, state = wkv7_func(
                r=r, k=k, v=v, w=w,
                a=a, b=b,
                scale=1.0,
                initial_state=s,
                output_final_state=output_final_state,
                cu_seqlens=cu_seqlens,
                head_first=False,
            )
            x = rearrange(o, "b l h d -> b l (h d)")
        return x, state

    def forward(
        self,
        hidden_states,
        last_state: "AttnState",
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Apply RWKV-7 time mixing to ``hidden_states`` of shape ``(B, T, C)``.

        Args:
            hidden_states: input activations, unpacked as ``B, T, C``.
            last_state: previous state. ``shift_state`` is the last input
                token of the previous call, prepended for the token shift, and
                ``wkv_state`` is the recurrent state fed to the kernel.
            use_cache: forwarded to the kernel as ``output_final_state``.
            cu_seqlens: forwarded to the kernel.
            v_first: values of layer 0; must be given when ``layer_id != 0``.
            attention_mask: optional mask; its last ``T`` columns multiply ``v``.
            **kwargs: ignored.

        Returns:
            ``(output, AttnState(last_token, wkv_state), v_first)``.
        """
        shift_state = last_state.shift_state
        B, T, C = hidden_states.size()
        # Token shift: difference between the previous token and the current one.
        xx = torch.concat((shift_state.unsqueeze(
            1), hidden_states[:, :-1]), dim=1) - hidden_states
        lx = hidden_states[:, -1]  # becomes the next call's shift_state
        if self.args.wkv_has_gate:
            xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
        else:
            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(
                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)
        r = self.receptance(xr)
        w = (
            -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5
        )  # soft-clamp to (-inf, -0.5)
        k = self.key(xk)
        v = self.value(xv)
        if self.layer_id == 0:
            v_first = v
        else:
            v = torch.lerp(v, v_first, torch.sigmoid(
                self.v0 + (xv @ self.v1) @ self.v2
            ))  # add value residual
        if attention_mask is not None:
            v = v.mul(attention_mask[:, -v.shape[-2]:, None])
        a = torch.sigmoid(
            self.a0 + (xa @ self.a1) @ self.a2
        )  # a is "in-context learning rate"
        if self.args.wkv_has_gate:
            g = torch.sigmoid(xg @ self.g1) @ self.g2
        kk = k * self.k_k
        # Per-head L2 normalisation; a larger eps is used for fp16.
        kk = F.normalize(kk.view(B, T, self.n_head, -1),
                         p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
        k = torch.lerp(k, k * a, self.k_a)
        wkv_state = last_state.wkv_state
        hidden_states, wkv_state = self.apply_wkv7_state(
            r,
            k,
            v,
            w,
            -kk,
            (kk * a),
            s=wkv_state,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens
        )
        if self.args.wkv_has_group_norm:
            hidden_states = self.ln_x(
                hidden_states.view(B * T, C)).view(B, T, C)
        # Per-head bonus sum_d(r * k * r_k), equivalent to
        # (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(-1, keepdim=True).
        weighted_sum_rk = torch.einsum('btij,btij,ij->btij', r.view(B, T, self.n_head, -1),
                                       k.view(B, T, self.n_head, -1), self.r_k).sum(dim=-1, keepdim=True)
        hidden_states = hidden_states + \
            (weighted_sum_rk * v.view(B, T, self.n_head, -1)).view(B, T, C)
        hidden_states = self.output(
            hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
        return hidden_states, AttnState(lx, wkv_state), v_first
class Rwkv7Attention(nn.Module):
    """Attention-slot wrapper that drives ``Rwkv_Tmix_x070`` and its cache.

    The layer looks up (or freshly creates) its recurrent state, runs the
    RWKV-7 time mixer and, when ``use_cache`` is set, stores the updated
    state back into ``past_key_value``. ``forward`` returns
    ``(output, None, v_first)``; the ``None`` fills the attention-weights slot.
    """

    def __init__(self, args: RwkvHybridConfig, layer_id):
        super().__init__()
        self.args = args
        self.layer_idx = layer_id
        self.time_mixer = Rwkv_Tmix_x070(args, layer_id)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[HybridCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):
        # NOTE(review): attention_mask is accepted but not passed on to the
        # time mixer (whose forward can apply it to v). Confirm this is intended.
        n_batch, n_tokens, _ = hidden_states.shape

        # Reuse this layer's cached state when available, else start from zeros.
        cache_hit = use_cache and len(past_key_value) > self.layer_idx
        if cache_hit:
            block_state = past_key_value[self.layer_idx][0]
        else:
            block_state = self.init_state(
                n_batch, hidden_states.device, hidden_states.dtype
            )

        mixed, new_attn_state, v_first = self.time_mixer(
            hidden_states=hidden_states,
            last_state=block_state.attn_state,
            use_cache=use_cache,
            cu_seqlens=cu_seqlens,
            v_first=v_first,
            **kwargs,
        )

        if use_cache:
            block_state.attn_state = new_attn_state
            past_key_value.update(n_tokens, block_state, self.layer_idx)
        return mixed, None, v_first

    def init_state(self, batch_size, device, dtype) -> BlockState:
        """Build a zeroed ``BlockState`` for ``batch_size`` sequences.

        The WKV state has shape ``(batch, num_wkv_heads, head_size, head_size)``
        and is always float32. The token-shift state has shape
        ``(batch, hidden_size)`` and uses ``dtype``.
        """
        cfg = self.args
        wkv_shape = (batch_size, cfg.num_wkv_heads, cfg.head_size, cfg.head_size)
        wkv_states = torch.zeros(wkv_shape, device=device, dtype=torch.float32)
        shift_states = torch.zeros(
            (batch_size, cfg.hidden_size), device=device, dtype=dtype
        )
        return BlockState(AttnState(shift_states, wkv_states), None)
class Rwkv_Tmix_x060(nn.Module):
    """RWKV-6 ("x060") time-mixing layer.

    Uses data-dependent token-shift mixing: a LoRA produces five offsets for
    the w, k, v, r and g streams. It also uses a data-dependent decay LoRA,
    then runs the WKV-6 recurrence through ``apply_wkv6_state``.

    NOTE(review): projections here are sized with ``args.head_size``, and the
    depth ratios read ``args.n_layer`` (``Rwkv_Tmix_x070`` reads
    ``num_hidden_layers``). As written:
      * ``time_faaaa`` reshapes a tensor of ``head_size`` elements to
        ``(n_head, head_size)``, which only succeeds when ``num_wkv_heads == 1``;
      * the group-norm path views a ``head_size``-wide output as
        ``(B * T, C)`` with ``C == hidden_size``, which appears to require
        ``head_size == hidden_size``.
    TODO confirm against the intended config before relying on this class.
    """

    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.hidden_size = args.hidden_size
        self.head_size = args.head_size
        self.n_head = args.num_wkv_heads
        assert args.hidden_size % self.n_head == 0
        with torch.no_grad():
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            # Channel ramp ddd[0, 0, i] = i / hidden_size.
            ddd = torch.ones(1, 1, args.hidden_size)
            for i in range(args.hidden_size):
                ddd[0, 0, i] = i / args.hidden_size
            # fancy time_mix
            self.time_maa_x = nn.Parameter(
                1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(
                1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(
                1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(
                1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            )
            self.time_maa_r = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )
            self.time_maa_g = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )
            D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
            # Doubled LoRA width when hidden_size is exactly 4096.
            if args.hidden_size == 4096:
                D_MIX_LORA = D_MIX_LORA * 2
            # Shared down-projection for all five streams ...
            self.time_maa_w1 = nn.Parameter(
                torch.zeros(args.hidden_size, D_MIX_LORA * 5)
            )
            # ... and one up-projection per stream (batched with bmm in forward).
            self.time_maa_w2 = nn.Parameter(
                torch.zeros(5, D_MIX_LORA,
                            args.hidden_size).uniform_(-0.01, 0.01)
            )
            # fancy time_decay
            decay_speed = torch.ones(args.head_size)
            for n in range(args.head_size):
                decay_speed[n] = -6 + 5 * (n / (args.head_size - 1)) ** (
                    0.7 + 1.3 * ratio_0_to_1
                )
            self.time_decay = nn.Parameter(
                decay_speed.reshape(1, 1, args.head_size))
            D_DECAY_LORA = 64
            if args.hidden_size == 4096:
                D_DECAY_LORA = D_DECAY_LORA * 2
            self.time_decay_w1 = nn.Parameter(
                torch.zeros(args.hidden_size, D_DECAY_LORA)
            )
            self.time_decay_w2 = nn.Parameter(
                torch.zeros(D_DECAY_LORA, args.head_size).uniform_(-0.01, 0.01)
            )
            # Bonus ("u") term: depth-dependent ramp plus a small zigzag pattern.
            tmp = torch.zeros(args.head_size)
            for n in range(args.head_size):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.head_size - 1))) + zigzag
            # NOTE(review): tmp has head_size elements, so this reshape only
            # works when n_head == 1 (see class docstring).
            self.time_faaaa = nn.Parameter(
                tmp.reshape(self.n_head, self.head_size))
        # Not referenced by forward() in this class.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(
            args.hidden_size, args.head_size, bias=False)
        self.key = nn.Linear(args.hidden_size, args.head_size, bias=False)
        self.value = nn.Linear(args.hidden_size, args.head_size, bias=False)
        self.output = nn.Linear(args.head_size, args.hidden_size, bias=False)
        self.gate = nn.Linear(args.hidden_size, args.head_size, bias=False)
        if self.args.wkv_has_group_norm:
            self.ln_x = nn.GroupNorm(
                self.n_head, args.head_size, eps=(
                    1e-5) * (args.head_size_divisor**2)
            )

    def post_init(self):
        """No-op: all RWKV-6 parameters are initialised in ``__init__``."""
        pass

    def forward(
        self,
        hidden_states,
        last_state: AttnState,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Apply RWKV-6 time mixing to ``hidden_states`` unpacked as ``(B, T, C)``.

        ``use_cache``, ``cu_seqlens``, ``v_first`` and ``kwargs`` are accepted
        for interface parity with the RWKV-7 mixer but are not used here.

        Returns:
            ``(output, AttnState(last_token, wkv_state), None)``.
        """
        shift_state = last_state.shift_state
        B, T, C = hidden_states.size()
        H = self.n_head
        # Token shift: previous token minus current token.
        xx = torch.concat((shift_state.unsqueeze(
            1), hidden_states[:, :-1]), dim=1) - hidden_states
        lx = hidden_states[:, -1]  # becomes the next call's shift_state
        # Data-dependent mix: one shared down-projection split into five
        # (B*T, D_MIX_LORA) slices, each with its own up-projection via bmm.
        xxx = hidden_states + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B *
                                                      T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
        xw = hidden_states + xx * (self.time_maa_w + mw)
        xk = hidden_states + xx * (self.time_maa_k + mk)
        xv = hidden_states + xx * (self.time_maa_v + mv)
        xr = hidden_states + xx * (self.time_maa_r + mr)
        xg = hidden_states + xx * (self.time_maa_g + mg)
        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))
        # Data-dependent decay offset added to the static time_decay.
        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww
        wkv_state = last_state.wkv_state
        hidden_states, wkv_state = self.apply_wkv6_state(
            B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
        )
        if self.args.wkv_has_group_norm:
            # NOTE(review): output width is head_size but is viewed with
            # C == hidden_size; appears to require head_size == hidden_size.
            hidden_states = self.ln_x(
                hidden_states.view(B * T, C)).view(B, T, C)
        hidden_states = self.output(hidden_states * g)
        return hidden_states, AttnState(lx, wkv_state), None

    def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
        """Run the WKV-6 recurrence and return ``(output, final_state)``.

        ``B``, ``T``, ``C`` and ``H`` are accepted but not used; the head split
        comes from ``self.n_head``. The kernel is picked per call:
        ``native_recurrent_rwkv6`` on CPU, ``chunk_rwkv6`` in training mode and
        ``fused_recurrent_rwkv6`` otherwise. The decay argument passed is
        ``-exp(w)``, and the final state is always requested.

        NOTE(review): inputs are rearranged to a head-first ``b h l d`` layout
        but no ``head_first`` flag is passed; this relies on the kernels'
        default layout. Confirm against the installed rwkvfla version.
        """
        r, w, k, v = map(lambda x: rearrange(
            x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))
        if r.device.type == "cpu":
            wkv6_func = native_recurrent_rwkv6
        elif self.training:
            wkv6_func = chunk_rwkv6
        else:
            wkv6_func = fused_recurrent_rwkv6
        o, state = wkv6_func(
            r,
            k,
            v,
            -torch.exp(w),
            u=u,
            scale=1.0,
            initial_state=s,
            output_final_state=True,
        )
        x = rearrange(o, "b h l d -> b l (h d)")
        return x, state
class Rwkv6Attention(nn.Module):
    """Attention-slot wrapper that drives ``Rwkv_Tmix_x060`` and its cache.

    Each call fetches this layer's cached state, or builds a zeroed one, and
    runs the RWKV-6 time mixer. With ``use_cache`` the updated state is
    written back to ``past_key_value``. The second element of the returned
    triple is always ``None``, filling the attention-weights slot.
    """

    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
        super().__init__()
        self.args = args
        self.layer_idx = layer_id
        self.time_mixer = Rwkv_Tmix_x060(args, layer_id, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[HybridCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):
        bsz, seq_len, _ = hidden_states.shape
        reuse_cached = use_cache and len(past_key_value) > self.layer_idx
        state = (
            past_key_value[self.layer_idx][0]
            if reuse_cached
            else self.init_state(bsz, hidden_states.device, hidden_states.dtype)
        )

        out, fresh_attn_state, v_first = self.time_mixer(
            hidden_states=hidden_states,
            last_state=state.attn_state,
            use_cache=use_cache,
            cu_seqlens=cu_seqlens,
            v_first=v_first,
            **kwargs,
        )

        if use_cache:
            state.attn_state = fresh_attn_state
            past_key_value.update(seq_len, state, self.layer_idx)
        return out, None, v_first

    def init_state(self, batch_size, device, dtype) -> BlockState:
        """Return a zero-filled ``BlockState`` for ``batch_size`` sequences.

        The WKV tensor is float32 with shape
        ``(batch, num_wkv_heads, head_size, head_size)``. The shift tensor has
        shape ``(batch, hidden_size)`` in the requested ``dtype``.
        """
        heads = self.args.num_wkv_heads
        width = self.args.head_size
        wkv = torch.zeros(
            (batch_size, heads, width, width), device=device, dtype=torch.float32
        )
        shift = torch.zeros(
            (batch_size, self.args.hidden_size), device=device, dtype=dtype
        )
        return BlockState(AttnState(shift, wkv), None)