import gc, os, re
import gradio as gr
import torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
from rwkv.utils import PIPELINE, PIPELINE_ARGS
import rwkv7_fast_v3a as v3a
# Initialise NVML once at import time so GPU memory can be queried later.
nvmlInit()
# Handle for GPU index 0; passed to nvmlDeviceGetMemoryInfo when VRAM usage is logged.
gpu_h = nvmlDeviceGetHandleByIndex(0)
# Context-length limit quoted in the UI description text ("Demo limited to ctxlen ...").
ctx_limit = 7000
# Used as the second field of some bundled example rows.
# NOTE(review): presumably the generation token budget - confirm against the UI inputs.
gen_limit = 1000
# NOTE(review): presumably the largest decode batch size allowed; its use is not visible here - confirm.
max_bsz = 16
# Prompt token ids are consumed in slices of at most CHUNK_LEN during prefill.
CHUNK_LEN = 8192 # chunk prefill, save VRAM
########################## text rwkv ################################################################
# Checkpoint name: used as the stem of the .pth filename fetched below.
title = "rwkv7-g1f-2.9b-20260420-ctx8192"
# Fetch the checkpoint from the Hugging Face Hub repo and get a local file path to it.
model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
# model_path = "/dev/shm/rwkv7-g1f-7.2b-20260414-ctx8192.pth"
# Module-level settings on rwkv7_fast_v3a, assigned before load_extensions() and RWKV7().
# NOTE(review): presumably these are read when the extensions are loaded and the model
# is constructed, so the ordering matters; the meaning of each value is defined by
# rwkv7_fast_v3a and is not visible in this file - confirm there before changing any.
v3a.MODEL_PATH = model_path
v3a.WKV_MODE = "fp32io16"
# NOTE(review): presumably what makes model.emb_cpu true (see tokens_to_x) - confirm.
v3a.EMB_DEVICE = "cpu"
v3a.RKV_MODE = "off"
v3a.CMIX_SPARSE = "no-fc"
v3a.LOWRANK_WEIGHT = "transpose"
v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
v3a.load_extensions(v3a.WKV_MODE)
model = v3a.RWKV7()
# Pipeline wrapper around the model; used in this file for sample_logits() and decode().
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Cache filled by get_decode_ctx: batch size B -> (state, x, graph, output).
decode_cache = {}
def get_decode_ctx(B: int):
    """Return the cached CUDA-graph decode context for batch size ``B``.

    The first call for a given ``B`` builds the context and stores it in
    ``decode_cache``; later calls return the same tuple object.

    Building the context:
      1. allocates a state via ``model.zero_state(B)`` and an uninitialised
         half-precision CUDA input buffer of shape ``(B, 1, v3a.C)``;
      2. runs ``model.forward_from_x`` twice eagerly as warm-up and
         synchronises the device;
      3. records one more ``model.forward_from_x`` call into a
         ``torch.cuda.CUDAGraph``.

    Args:
        B: Batch size the context is built for; also the cache key.

    Returns:
        Tuple ``(state, x, graph, output)``: the state passed to the captured
        call, the input buffer, the captured graph, and the value returned by
        the captured forward call.
        NOTE(review): callers presumably write into ``x``/``state`` in place
        and read ``output`` after ``graph.replay()`` - confirm at the call site.
    """
    if B in decode_cache:
        return decode_cache[B]

    batch_state = model.zero_state(B)
    step_input = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half)
    exec_path = v3a.select_path(B, 1)

    # Eager warm-up passes before capture, then wait for them to finish.
    warmup_passes = 2
    for _ in range(warmup_passes):
        model.forward_from_x(step_input, batch_state, exec_path)
    torch.cuda.synchronize()

    # Record a single forward step; the tensors used here are the ones the
    # graph will read and write on every replay.
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph):
        step_output = model.forward_from_x(step_input, batch_state, exec_path)

    entry = (batch_state, step_input, cuda_graph, step_output)
    decode_cache[B] = entry
    return entry
def copy_state_to_batch(dst, src):
    """Broadcast the first three tensors of ``src`` into ``dst`` in place.

    The batch size ``B`` is taken from ``dst[2].shape[0]``. Each source
    tensor is expanded along its batch dimension and copied into the
    matching destination tensor:

      * index 0: 4-D, batch on dim 2;
      * index 1: 5-D, batch on dim 1;
      * index 2: expanded to shape ``(B,)``.

    Args:
        dst: Indexable of at least three tensors that receive the data.
        src: Indexable of at least three tensors to broadcast from; the
            expanded dimension must be expandable to ``B`` (size 1 or ``B``).

    Returns:
        None. ``dst`` tensors are modified in place.
    """
    batch = dst[2].shape[0]
    # One expand() target per state component, in index order.
    expand_targets = (
        (-1, -1, batch, -1),
        (-1, batch, -1, -1, -1),
        (batch,),
    )
    for idx, target_shape in enumerate(expand_targets):
        dst[idx].copy_(src[idx].expand(*target_shape))
def tokens_to_x(tokens):
    """Embed token ids, one token per row.

    The ids are put into an int64 tensor on the CPU when ``model.emb_cpu`` is
    truthy and on CUDA otherwise, reshaped to ``(-1, 1)``, and passed to
    ``model.embed``.

    Args:
        tokens: Token ids accepted by ``torch.tensor`` (e.g. a list of ints).

    Returns:
        Whatever ``model.embed`` returns for the ``(N, 1)`` id tensor.
    """
    emb_device = "cpu" if model.emb_cpu else "cuda"
    ids = torch.tensor(tokens, dtype=torch.long, device=emb_device)
    return model.embed(ids.view(-1, 1))
def generate_prompt(instruction, input=""):
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
if input:
return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
else:
return f"User: {instruction}\n\nAssistant: 0:
token_device = "cpu" if model.emb_cpu else "cuda"
tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
out = model.forward(tokens, state).view(-1)
input_ids = input_ids[CHUNK_LEN:]
copy_state_to_batch(decode_state, state)
logits = out.view(1, -1).expand(B, -1)
else:
decode_x.copy_(tokens_to_x(next_tokens))
decode_graph.replay()
logits = decode_output.view(B, -1)
active = 0
next_tokens = [0 for _ in range(B)]
for b in range(B):
if finished[b]:
continue
row = logits[b]
for n in occurrence[b]:
row[n] -= (args.alpha_presence + occurrence[b][n] * args.alpha_frequency)
token = pipeline.sample_logits(row, temperature=args.temperature, top_p=args.top_p)
if token in args.token_stop:
finished[b] = True
continue
active += 1
next_tokens[b] = token
all_tokens[b] += [token]
for xxx in occurrence[b]:
occurrence[b][xxx] *= penalty_decay
ttt = pipeline.decode([token])
www = 1
#if ttt in ' \t0123456789':
# www = 0
#elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
# www = 0.5
if token not in occurrence[b]:
occurrence[b][token] = www
else:
occurrence[b][token] += www
tmp = pipeline.decode(all_tokens[b][out_last[b]:])
if '\ufffd' not in tmp:
out_str[b] += tmp
out_last[b] = len(all_tokens[b])
if active == 0:
break
yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str)
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
del out
del state
gc.collect()
torch.cuda.empty_cache()
yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str)
examples = [
["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
["System: Tools:\n[{\"name\":\"find_free_slots\",\"description\":\"Find free calendar slots\",\"arguments\":{\"date\":{\"type\":\"string\"},\"duration_minutes\":{\"type\":\"integer\"},\"time_window\":{\"type\":\"string\"}}},{\"name\":\"create_calendar_event\",\"description\":\"Create a calendar event\",\"arguments\":{\"title\":{\"type\":\"string\"},\"start_time\":{\"type\":\"string\"},\"end_time\":{\"type\":\"string\"},\"attendees\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}}}]\nReturn only a JSON function call.\n\nUser: Schedule a 30-minute sync with Bob on 2026-05-08 afternoon.\n\nAssistant: ```json\n{\"name\":\"find_free_slots\",\"arguments\":{\"date\":\"2026-05-08\",\"duration_minutes\":30,\"time_window\":\"afternoon\"}}\n```\n\nUser: Function output:\n{\"free_slots\":[{\"start\":\"2026-05-08T15:00:00+09:00\",\"end\":\"2026-05-08T15:30:00+09:00\"}],\"bob_email\":\"bob@example.com\"}\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
[generate_prompt("Please give the pros and cons of hodl versus active trading."), gen_limit, 1, 0.5, 2, 0.2, 0.99],
[generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), gen_limit, 1, 0.5, 2, 0.2, 0.99],
["User: What is the maximum value of $4(x + 7)(2 - x)$, over all real numbers $x$?\n\nAssistant: \n{title}
\n")
with gr.Tab("=== Base Model (Raw Generation) ==="):
gr.Markdown(f'This is [RWKV7 G-series](https://huggingface.co/BlinkDL/rwkv7-g1) reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Try topp 0.3 for math. Supports 100+ world languages and code. Check [600+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.')
with gr.Row():
with gr.Column():
prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: