# Startup section of a Gradio demo around an RWKV7 model loaded through
# rwkv7_fast_v3a (imported as v3a). The statements on the line below, in order:
#   1. Imports, then nvmlInit().
#   2. The ctx_limit and gen_limit constants.
#   3. `title` names the checkpoint; hf_hub_download fetches f"{title}.pth" from
#      the BlinkDL/rwkv7-g1 repo, and the resulting path is assigned to
#      v3a.MODEL_PATH together with the other v3a.* settings before
#      v3a.load_extensions(v3a.WKV_MODE) is called.
#   4. Builds v3a.RWKV7() and a PIPELINE using the "rwkv_vocab_v20230424" vocab.
#   5. Allocates decode_state and decode_x, calls model.forward_from_x twice
#      (presumably as warm-up -- confirm), then records one more
#      forward_from_x call inside a torch.cuda.CUDAGraph; the result of the
#      recorded call is kept in decode_output.
#   6. token_to_x(token): wraps the id in a 1x1 long tensor (on CPU when
#      model.emb_cpu is truthy, otherwise on CUDA) and returns model.embed(...).
#   7. generate_prompt(instruction, input=""): strips both strings, rewrites
#      "\r\n" as "\n" and "\n\n" as "\n", then returns the
#      Instruction/Input/Response template when `input` is non-empty and the
#      "User: ..." template otherwise.
#
# NOTE(review): the physical line breaks of this section are missing. As a
# single line it is not valid Python, and the first `#` turns the rest of the
# line into a comment. Restore the original line structure before running.
# NOTE(review): text appears to be missing between `Assistant: ` and ` 0:` --
# the end of generate_prompt's last f-string and the beginning of the code
# that follows it. That following code reads input_ids, CHUNK_LEN and state,
# and no initial definition of any of them is visible here.
# NOTE(review): the only visible assignment to gpu_h is inside a comment, yet
# nvmlDeviceGetMemoryInfo(gpu_h) is called later in the file -- confirm gpu_h
# is defined before that call.
# NOTE(review): .replace('\n\n','\n') is a single pass, so a run of three or
# more newlines still leaves a "\n\n" in the prompt -- confirm whether that
# matters given that "\n\n" also separates the prompt sections.
import gc, os, re import gradio as gr import torch from datetime import datetime from huggingface_hub import hf_hub_download from pynvml import * from rwkv.utils import PIPELINE, PIPELINE_ARGS import rwkv7_fast_v3a as v3a nvmlInit() # gpu_h = nvmlDeviceGetHandleByIndex(0) ctx_limit = 7000 gen_limit = 1000 ########################## text rwkv ################################################################ title = "rwkv7-g1f-2.9b-20260420-ctx8192" model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth") # model_path = "/dev/shm/rwkv7-g1f-7.2b-20260414-ctx8192.pth" v3a.MODEL_PATH = model_path v3a.WKV_MODE = "fp32io16" v3a.EMB_DEVICE = "cpu" v3a.RKV_MODE = "off" v3a.CMIX_SPARSE = "no-fc" v3a.LOWRANK_WEIGHT = "transpose" v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"} v3a.load_extensions(v3a.WKV_MODE) model = v3a.RWKV7() pipeline = PIPELINE(model, "rwkv_vocab_v20230424") decode_state = model.zero_state(1) decode_x = torch.empty((1, 1, v3a.C), device="cuda", dtype=torch.half) decode_path = v3a.select_path(1, 1) for _ in range(2): model.forward_from_x(decode_x, decode_state, decode_path) torch.cuda.synchronize() decode_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(decode_graph): decode_output = model.forward_from_x(decode_x, decode_state, decode_path) def token_to_x(token: int): token_tensor = torch.tensor([[int(token)]], dtype=torch.long, device="cpu" if model.emb_cpu else "cuda") return model.embed(token_tensor) def generate_prompt(instruction, input=""): instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n') input = input.strip().replace('\r\n','\n').replace('\n\n','\n') if input: return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:" else: return f"User: {instruction}\n\nAssistant: 0: token_device = "cpu" if model.emb_cpu else "cuda" tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device) out = model.forward(tokens, state).view(-1) input_ids = input_ids[CHUNK_LEN:] 
for dst, src in zip(decode_state, state): dst.copy_(src) logits = out else: decode_x.copy_(token_to_x(token)) decode_graph.replay() logits = decode_output.view(-1) for n in occurrence: logits[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) token = pipeline.sample_logits(logits, temperature=args.temperature, top_p=args.top_p) if token in args.token_stop: break all_tokens += [token] for xxx in occurrence: occurrence[xxx] *= penalty_decay ttt = pipeline.decode([token]) www = 1 #if ttt in ' \t0123456789': # www = 0 #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】': # www = 0.5 if token not in occurrence: occurrence[token] = www else: occurrence[token] += www tmp = pipeline.decode(all_tokens[out_last:]) if '\ufffd' not in tmp: out_str += tmp yield out_str.strip() out_last = i + 1 gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}') del out del state gc.collect() torch.cuda.empty_cache() yield out_str.strip() examples = [ ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99], ["System: Tools:\n[{\"name\":\"find_free_slots\",\"description\":\"Find free calendar slots\",\"arguments\":{\"date\":{\"type\":\"string\"},\"duration_minutes\":{\"type\":\"integer\"},\"time_window\":{\"type\":\"string\"}}},{\"name\":\"create_calendar_event\",\"description\":\"Create a calendar event\",\"arguments\":{\"title\":{\"type\":\"string\"},\"start_time\":{\"type\":\"string\"},\"end_time\":{\"type\":\"string\"},\"attendees\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}}}]\nReturn only a JSON function call.\n\nUser: Schedule a 30-minute sync with Bob on 
2026-05-08 afternoon.\n\nAssistant: ```json\n{\"name\":\"find_free_slots\",\"arguments\":{\"date\":\"2026-05-08\",\"duration_minutes\":30,\"time_window\":\"afternoon\"}}\n```\n\nUser: Function output:\n{\"free_slots\":[{\"start\":\"2026-05-08T15:00:00+09:00\",\"end\":\"2026-05-08T15:30:00+09:00\"}],\"bob_email\":\"bob@example.com\"}\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99], [generate_prompt("Please give the pros and cons of hodl versus active trading."), gen_limit, 1, 0.5, 2, 0.2, 0.99], [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), gen_limit, 1, 0.5, 2, 0.2, 0.99], ["User: What is the maximum value of $4(x + 7)(2 - x)$, over all real numbers $x$?\n\nAssistant: \n

{title}

\n") with gr.Tab("=== Base Model (Raw Generation) ==="): gr.Markdown(f'This is [RWKV7 G-series](https://huggingface.co/BlinkDL/rwkv7-g1) reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Try topp 0.3 for math. Supports 100+ world languages and code. Check [600+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.') with gr.Row(): with gr.Column(): prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: