import gc, os, re import gradio as gr import torch from datetime import datetime from huggingface_hub import hf_hub_download from pynvml import * from rwkv.utils import PIPELINE, PIPELINE_ARGS import rwkv7_fast_v3a as v3a nvmlInit() gpu_h = nvmlDeviceGetHandleByIndex(0) ctx_limit = 7000 gen_limit = 1000 max_bsz = 16 CHUNK_LEN = 8192 # chunk prefill, save VRAM ########################## text rwkv ################################################################ title = "rwkv7-g1f-2.9b-20260420-ctx8192" model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth") # model_path = "/dev/shm/rwkv7-g1f-7.2b-20260414-ctx8192.pth" v3a.MODEL_PATH = model_path v3a.WKV_MODE = "fp32io16" v3a.EMB_DEVICE = "cpu" v3a.RKV_MODE = "off" v3a.CMIX_SPARSE = "no-fc" v3a.LOWRANK_WEIGHT = "transpose" v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"} v3a.load_extensions(v3a.WKV_MODE) model = v3a.RWKV7() pipeline = PIPELINE(model, "rwkv_vocab_v20230424") decode_cache = {} def get_decode_ctx(B: int): cached = decode_cache.get(B) if cached is not None: return cached state = model.zero_state(B) x = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half) path = v3a.select_path(B, 1) for _ in range(2): model.forward_from_x(x, state, path) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): output = model.forward_from_x(x, state, path) cached = (state, x, graph, output) decode_cache[B] = cached return cached def copy_state_to_batch(dst, src): B = dst[2].shape[0] dst[0].copy_(src[0].expand(-1, -1, B, -1)) dst[1].copy_(src[1].expand(-1, B, -1, -1, -1)) dst[2].copy_(src[2].expand(B)) def tokens_to_x(tokens): token_tensor = torch.tensor(tokens, dtype=torch.long, device="cpu" if model.emb_cpu else "cuda").view(-1, 1) return model.embed(token_tensor) def generate_prompt(instruction, input=""): instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n') input = input.strip().replace('\r\n','\n').replace('\n\n','\n') if input: return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:" else: return f"User: {instruction}\n\nAssistant: 0: token_device = "cpu" if model.emb_cpu else "cuda" tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device) out = model.forward(tokens, state).view(-1) input_ids = input_ids[CHUNK_LEN:] copy_state_to_batch(decode_state, state) logits = out.view(1, -1).expand(B, -1) else: decode_x.copy_(tokens_to_x(next_tokens)) decode_graph.replay() logits = decode_output.view(B, -1) active = 0 next_tokens = [0 for _ in range(B)] for b in range(B): if finished[b]: continue row = logits[b] for n in occurrence[b]: row[n] -= (args.alpha_presence + occurrence[b][n] * args.alpha_frequency) token = pipeline.sample_logits(row, temperature=args.temperature, top_p=args.top_p) if token in args.token_stop: finished[b] = True continue active += 1 next_tokens[b] = token all_tokens[b] += [token] for xxx in occurrence[b]: occurrence[b][xxx] *= penalty_decay ttt = pipeline.decode([token]) www = 1 #if ttt in ' \t0123456789': # www = 0 #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】': # www = 0.5 if token not in occurrence[b]: occurrence[b][token] = www else: occurrence[b][token] += www tmp = pipeline.decode(all_tokens[b][out_last[b]:]) if '\ufffd' not in tmp: out_str[b] += tmp out_last[b] = len(all_tokens[b]) if active == 0: break yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str) gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}') del out del state gc.collect() torch.cuda.empty_cache() yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str) examples = [ ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99], ["System: Tools:\n[{\"name\":\"find_free_slots\",\"description\":\"Find free calendar slots\",\"arguments\":{\"date\":{\"type\":\"string\"},\"duration_minutes\":{\"type\":\"integer\"},\"time_window\":{\"type\":\"string\"}}},{\"name\":\"create_calendar_event\",\"description\":\"Create a calendar event\",\"arguments\":{\"title\":{\"type\":\"string\"},\"start_time\":{\"type\":\"string\"},\"end_time\":{\"type\":\"string\"},\"attendees\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}}}]\nReturn only a JSON function call.\n\nUser: Schedule a 30-minute sync with Bob on 2026-05-08 afternoon.\n\nAssistant: ```json\n{\"name\":\"find_free_slots\",\"arguments\":{\"date\":\"2026-05-08\",\"duration_minutes\":30,\"time_window\":\"afternoon\"}}\n```\n\nUser: Function output:\n{\"free_slots\":[{\"start\":\"2026-05-08T15:00:00+09:00\",\"end\":\"2026-05-08T15:30:00+09:00\"}],\"bob_email\":\"bob@example.com\"}\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99], [generate_prompt("Please give the pros and cons of hodl versus active trading."), gen_limit, 1, 0.5, 2, 0.2, 0.99], [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), gen_limit, 1, 0.5, 2, 0.2, 0.99], ["User: What is the maximum value of $4(x + 7)(2 - x)$, over all real numbers $x$?\n\nAssistant: \n

{title}

\n") with gr.Tab("=== Base Model (Raw Generation) ==="): gr.Markdown(f'This is [RWKV7 G-series](https://huggingface.co/BlinkDL/rwkv7-g1) reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Try topp 0.3 for math. Supports 100+ world languages and code. Check [600+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.') with gr.Row(): with gr.Column(): prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: