import gc, html, re, threading, time import gradio as gr import torch from datetime import datetime from huggingface_hub import hf_hub_download from rwkv.utils import PIPELINE import rwkv7_fast_v3a as v3a ctx_limit = 7000 gen_limit = 1000 max_bsz = 64 CHUNK_LEN = 512 # chunk prefill, save VRAM SAMPLER_TOP_K = 500 YIELD_EVERY = 16 USE_CUDA_GRAPH = True html_gen_limit = 8000 MAX_HTML_PREVIEWS = 15 HTML_GRID_UPDATE_EVERY = 32 HTML_IFRAME_MIN_DELTA = 1800 HTML_IFRAME_MAX_STALE = 2.0 HTML_COMPONENT_MIN_INTERVAL = 0.5 HTML_PREVIEW_HEIGHT = 300 HTML_CAPTION_HEIGHT = 15 HTML_BODY_HEIGHT = HTML_PREVIEW_HEIGHT - HTML_CAPTION_HEIGHT HTML_FRAME_HEIGHT = 228 HTML_RAW_HEIGHT = HTML_BODY_HEIGHT - HTML_FRAME_HEIGHT HTML_PROMPT_CHOICES = [ "3D animation of cars in forest with animals", "interactive weather map with animated clouds and rain", "retro arcade RPG start screen", "3D animation of a SpaceX rocket landing on Mars", "storybook scene with a dragon flying over a castle", "interactive dashboard for a city traffic system", "animated aquarium with colorful fish and coral", "sci-fi spaceship navigation interface", "cozy cafe menu with animated steam and pastries", "character sheet for a high fantasy RPG", "a fancy hotel homepage", ] def html_prompt_from_choice(choice): return f"User: Write HTML: {choice}\n\nAssistant: div { gap: 4px !important; } .html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; gap: 4px !important; min-width: 0 !important; } .html-grid-pages { grid-column: 2 !important; grid-row: 1 / span 2 !important; gap: 4px !important; min-width: 0 !important; } .html-grid-output { grid-column: 1 !important; grid-row: 2 !important; gap: 4px !important; min-width: 0 !important; } .html-grid-preview-row { gap: 4px !important; margin: 0 !important; } .html-grid-preview { margin: 0 !important; } .html-grid-preview > div { margin: 0 !important; } .html-container { padding: 0 !important; } .html-prompt-choice { margin: 0 !important; padding: 0 !important; min-height: 0 !important; } .html-prompt-choice .wrap, .html-prompt-choice .wrap-inner, .html-prompt-choice .secondary-wrap { margin: 0 !important; padding: 0 !important; min-height: 0 !important; } .html-prompt-choice label { margin: 0 !important; padding: 0 !important; } @media (max-width: 768px) { .html-grid-main { grid-template-columns: 1fr !important; grid-template-rows: auto auto auto !important; } .html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; } .html-grid-pages { grid-column: 1 !important; grid-row: 2 !important; } .html-grid-output { grid-column: 1 !important; grid-row: 3 !important; } } """ html_view_lock = threading.Lock() html_view_scale = 35 html_view_scroll_seconds = 5 def clamp_html_scale(scale): return max(20, min(100, int(scale))) def clamp_scroll_seconds(seconds): return max(0, min(10, int(seconds))) def set_html_scale(scale): global html_view_scale scale = clamp_html_scale(scale) with html_view_lock: html_view_scale = scale return scale def get_html_scale(): with html_view_lock: return html_view_scale def set_scroll_seconds(seconds): global html_view_scroll_seconds seconds = clamp_scroll_seconds(seconds) with html_view_lock: html_view_scroll_seconds = seconds return seconds def get_scroll_seconds(): with html_view_lock: return html_view_scroll_seconds ########################## text rwkv ################################################################ title = "rwkv7-g1g-13.3b-20260523-ctx8192" model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth") v3a.MODEL_PATH = model_path v3a.WKV_MODE = "fp32io16" v3a.EMB_DEVICE = "cuda" v3a.RKV_MODE = "off" v3a.CMIX_SPARSE = "no-fc" v3a.LOWRANK_WEIGHT = "transpose" v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"} v3a.load_extensions(v3a.WKV_MODE) model = v3a.RWKV7() gc.collect() torch.cuda.empty_cache() pipeline = PIPELINE(model, "rwkv_vocab_v20230424") @torch.jit.script def sample_logits_batch_cuda(logits, temperature: float, top_p: float, k: int): if top_p <= 0.0 or k == 1: return torch.argmax(logits, dim=-1) vals, ids = torch.topk(logits.float(), k=k, dim=-1, sorted=True) if temperature == 1.0: probs = torch.softmax(vals, dim=-1) else: probs = torch.softmax(vals / temperature, dim=-1) cdf = torch.cumsum(probs, dim=-1) if top_p < 1.0: keep = torch.argmax((cdf >= top_p).to(torch.int32), dim=-1) mass = cdf.gather(1, keep.view(-1, 1)).view(-1) else: mass = cdf[:, -1] r = torch.rand((logits.size(0), 1), device=logits.device) * mass.view(-1, 1) out = torch.searchsorted(cdf, r).view(-1, 1) return ids.gather(1, out).view(-1) def get_decode_ctx(B: int, decode_cache): cached = decode_cache.get(B) if cached is not None: return cached state = model.zero_state(B) x = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half) path = v3a.select_path(B, 1) for _ in range(2): model.forward_from_x(x, state, path) torch.cuda.synchronize() graph = None output = None if USE_CUDA_GRAPH: try: graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): output = model.forward_from_x(x, state, path) torch.cuda.synchronize() except Exception as exc: print(f"CUDA graph disabled for B={B}: {exc}", flush=True) graph = None output = None cached = (state, x, graph, output) decode_cache[B] = cached return cached def copy_state_to_batch(dst, src): B = dst[2].shape[0] dst[0].copy_(src[0].expand(-1, -1, B, -1)) dst[1].copy_(src[1].expand(-1, B, -1, -1, -1)) dst[2].copy_(src[2].expand(B)) def tokens_to_x(tokens): token_device = "cpu" if model.emb_cpu else "cuda" if isinstance(tokens, torch.Tensor): token_tensor = tokens.to(device=token_device, dtype=torch.long, non_blocking=True).view(-1, 1) else: token_tensor = torch.tensor(tokens, dtype=torch.long, device=token_device).view(-1, 1) return model.embed(token_tensor) def generate_prompt(instruction, input=""): instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n') input = input.strip().replace('\r\n','\n').replace('\n\n','\n') if input: return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:" else: return f"User: {instruction}\n\nAssistant: 0: token_device = "cpu" if rwkv_model.emb_cpu else "cuda" tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device) out = rwkv_model.forward(tokens, state).view(-1) input_ids = input_ids[CHUNK_LEN:] torch.cuda.synchronize() copy_state_to_batch(decode_state, state) logits = out.view(1, -1).repeat(B, 1) else: decode_x.copy_(tokens_to_x(next_tokens)) if decode_graph is None: decode_output = rwkv_model.forward_from_x(decode_x, decode_state, v3a.select_path(B, 1)) else: decode_graph.replay() logits = decode_output.view(B, -1) if occurrence_count is None: occurrence_count = torch.zeros((B, logits.size(-1)), device=logits.device, dtype=logits.dtype) occurrence_presence = torch.zeros_like(occurrence_count) batch_rows = torch.arange(B, device=logits.device) if alpha_frequency: logits.sub_(occurrence_count, alpha=alpha_frequency) if alpha_presence: logits.sub_(occurrence_presence) assert logits.is_cuda and logits.dim() == 2 sampled_tensor = sample_logits_batch_cuda( logits, sample_temperature, sample_top_p, min(SAMPLER_TOP_K, logits.size(-1)), ) sampled = sampled_tensor.detach().cpu().tolist() active = 0 next_tokens = [0 for _ in range(B)] if penalty_decay != 1: occurrence_count.mul_(penalty_decay) occurrence_count[batch_rows, sampled_tensor] += 1 if alpha_presence: occurrence_presence[batch_rows, sampled_tensor] = alpha_presence for b in range(B): if finished[b]: continue token = sampled[b] if token == 0: finished[b] = True continue active += 1 next_tokens[b] = token all_tokens[b].append(token) tmp = pipe.decode(all_tokens[b][out_last[b]:]) if '\ufffd' not in tmp: out_str[b] += tmp out_last[b] = len(all_tokens[b]) total_tokens += active if active == 0: break if speed_t0 is None: speed_t0 = time.perf_counter() else: speed_tokens += B elapsed = max(1e-9, time.perf_counter() - speed_t0) current_text = output_text(B, out_str) speed_info = speed_text(speed_tokens / elapsed, B, total_tokens, len(current_text)) if i == 0 or i % max(1, int(yield_every)) == 0: current_text = output_text(B, out_str) yield current_text, speed_info, {"done": False, "token_counts": [len(tokens) for tokens in all_tokens]} timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") free, total = torch.cuda.mem_get_info() del out del state del decode_cache current_text = output_text(B, out_str) if speed_t0 is not None and not speed_info: speed_info = speed_text(0.0, B, total_tokens, len(current_text)) elapsed = time.perf_counter() - req_t0 final_tps = speed_tokens / max(1e-9, time.perf_counter() - speed_t0) if speed_t0 is not None else 0.0 used = total - free meta = {"done": True, "timestamp": timestamp, "B": B, "T": user_token_count, "In": input_token_count, "TPS": final_tps, "Time": elapsed, "Token": total_tokens, "Char": len(current_text), "VRAMUsed": used, "VRAMTotal": total, "token_counts": [len(tokens) for tokens in all_tokens]} yield current_text, speed_info, meta def print_summary(prefix, meta): print( f"[{prefix}] {meta['timestamp']} B={meta['B']} T={meta['T']} In={meta['In']} " f"TPS={meta['TPS']:.1f} Time={meta['Time']:.3f}s Token={meta['Token']} " f"Char={meta['Char']} VRAM={gib(meta['VRAMUsed']):.2f}G/{gib(meta['VRAMTotal']):.2f}G", flush=True, ) def evaluate_raw( ctx, token_count=200, batch_size=1, temperature=1.0, top_p=0.5, presencePenalty = 1, countPenalty = 0.1, penalty_decay = 0.99, ): for text, speed_info, meta in generate_batch_text(ctx, token_count, batch_size, temperature, top_p, presencePenalty, countPenalty, penalty_decay, YIELD_EVERY): if meta and meta.get("done"): print_summary("app3", meta) yield output_update(text, speed_info) def split_batch_output(text, count): parts = text.split("\n====\n") if text else [] parts += [""] * max(0, count - len(parts)) return parts[:count] def extract_html(text, prompt=""): stream = prompt + text lower_text = stream.lower() marker = lower_text.find("") if marker < 0: return "" visible = stream[marker + len(""):] lower = visible.lower() start = lower.find("", start) if end >= 0: return visible[start:end + len("")].strip() return visible[start:].strip() def html_complete(page): return "" in page.lower() def inject_iframe_scroll_script(page, index, scroll_seconds): scroll_seconds = clamp_scroll_seconds(scroll_seconds) if scroll_seconds <= 0: return page delay = int((index * 733) % 5000) leg_ms = scroll_seconds * 1000 script = f"""""" lower = page.lower() m = re.search(r"]*>", lower) if m: return page[:m.end()] + script + page[m.end():] m = re.search(r"]*>", lower) if m: return page[:m.end()] + "" + script + "" + page[m.end():] m = re.search(r"]*>", lower) if m: return page[:m.end()] + script + page[m.end():] return script + page def render_preview(text="", index=0, scale=35, active=True, prompt="", token_count=None, scroll_seconds=5): tokens = f"{token_count:,}" if token_count is not None else "-" caption = f"#{index + 1} | {tokens} tokens, {len(text.encode('utf-8')):,} bytes" if active else "" opacity = "1" if active else ".35" zoom = max(0.2, min(1.2, scale / 100.0)) page = extract_html(text, prompt) if not text: body = f'
' elif page: srcdoc = html.escape(inject_iframe_scroll_script(page, index, scroll_seconds), quote=True) raw = html.escape(text) body = f"""
{raw}
""" else: body = f'
{html.escape(text)}
' return f"""
{caption}
{body}
""" def empty_html_grid(): return [render_preview("", i, active=False) for i in range(MAX_HTML_PREVIEWS)] def render_html_grid_from_raw(prompt, raw_output, page_count, scale, scroll_seconds, token_counts): page_count = max(1, min(MAX_HTML_PREVIEWS, int(page_count))) scale = set_html_scale(scale) scroll_seconds = set_scroll_seconds(scroll_seconds) pages = split_batch_output(raw_output, page_count) token_counts = token_counts or [] return [render_preview(pages[i] if i < page_count else "", i, scale, i < page_count, prompt, token_counts[i] if i < len(token_counts) else None, scroll_seconds) for i in range(MAX_HTML_PREVIEWS)] def evaluate_html_grid( prompt, token_count=8000, page_count=15, scale=35, scroll_seconds=5, temperature=1.0, top_p=0.5, presence_penalty=1.0, count_penalty=0.1, penalty_decay=0.99, ): page_count = max(1, min(MAX_HTML_PREVIEWS, int(page_count))) scale = set_html_scale(scale) scroll_seconds = set_scroll_seconds(scroll_seconds) last_text = "" cached_previews = empty_html_grid() cached_html = ["" for _ in range(MAX_HTML_PREVIEWS)] cached_html_at = [0.0 for _ in range(MAX_HTML_PREVIEWS)] cached_complete = [False for _ in range(MAX_HTML_PREVIEWS)] cached_text_len = [0 for _ in range(MAX_HTML_PREVIEWS)] cached_preview_at = [0.0 for _ in range(MAX_HTML_PREVIEWS)] last_raw_at = 0.0 final_text = "" final_token_counts = [0 for _ in range(MAX_HTML_PREVIEWS)] yield [*cached_previews, output_update("", compact=True), final_token_counts, page_count] for text, speed_info, meta in generate_batch_text(prompt, token_count, page_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay, HTML_GRID_UPDATE_EVERY): final_text = text done_batch = bool(meta and meta.get("done")) token_counts = meta.get("token_counts", final_token_counts) if meta else final_token_counts if token_counts: final_token_counts = token_counts + [0 for _ in range(max(0, MAX_HTML_PREVIEWS - len(token_counts)))] if done_batch: print_summary("app3-html", meta) if len(text) - len(last_text) < 300 and not done_batch: continue last_text = text now = time.monotonic() render_scale = get_html_scale() render_scroll_seconds = get_scroll_seconds() scale_changed = render_scale != scale scroll_changed = render_scroll_seconds != scroll_seconds scale = render_scale scroll_seconds = render_scroll_seconds pages = split_batch_output(text, page_count) skip = gr.skip() updates = [skip for _ in range(MAX_HTML_PREVIEWS)] for i in range(MAX_HTML_PREVIEWS): page_text = pages[i] if i < page_count else "" active = i < page_count page = extract_html(page_text, prompt) if active else "" page_tokens = final_token_counts[i] if i < len(final_token_counts) else None if now - cached_preview_at[i] < HTML_COMPONENT_MIN_INTERVAL and not done_batch and not scale_changed and not scroll_changed: continue if not page: if len(page_text) == cached_text_len[i] and active == bool(cached_text_len[i]) and not done_batch and not scale_changed and not scroll_changed: continue cached_previews[i] = render_preview(page_text, i, scale, active, prompt, page_tokens, scroll_seconds) cached_html[i] = "" cached_html_at[i] = now cached_complete[i] = False cached_text_len[i] = len(page_text) cached_preview_at[i] = now updates[i] = cached_previews[i] continue done = html_complete(page) should_reload = ( not cached_html[i] or (done and not cached_complete[i]) or done_batch or ( not cached_complete[i] and ( len(page) - len(cached_html[i]) >= HTML_IFRAME_MIN_DELTA or now - cached_html_at[i] >= HTML_IFRAME_MAX_STALE ) ) ) if should_reload: cached_previews[i] = render_preview(page_text, i, scale, active, prompt, page_tokens, scroll_seconds) cached_html[i] = page cached_html_at[i] = now cached_complete[i] = done cached_text_len[i] = len(page_text) cached_preview_at[i] = now updates[i] = cached_previews[i] raw_update = skip if now - last_raw_at >= HTML_COMPONENT_MIN_INTERVAL or done_batch: raw_update = output_update(text, speed_info, compact=True) last_raw_at = now if any(update is not skip for update in updates) or raw_update is not skip: yield [*updates, raw_update, final_token_counts, page_count] if final_text: pages = split_batch_output(final_text, page_count) scale = get_html_scale() scroll_seconds = get_scroll_seconds() final_previews = [render_preview(pages[i] if i < page_count else "", i, scale, i < page_count, prompt, final_token_counts[i] if i < len(final_token_counts) else None, scroll_seconds) for i in range(MAX_HTML_PREVIEWS)] yield [*final_previews, output_update(final_text, speed_info, compact=True), final_token_counts, page_count] examples = [ ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99], ["System: Tools:\n[{\"name\":\"find_free_slots\",\"description\":\"Find free calendar slots\",\"arguments\":{\"date\":{\"type\":\"string\"},\"duration_minutes\":{\"type\":\"integer\"},\"time_window\":{\"type\":\"string\"}}},{\"name\":\"create_calendar_event\",\"description\":\"Create a calendar event\",\"arguments\":{\"title\":{\"type\":\"string\"},\"start_time\":{\"type\":\"string\"},\"end_time\":{\"type\":\"string\"},\"attendees\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}}}]\nReturn only a JSON function call.\n\nUser: Schedule a 30-minute sync with Bob on 2026-05-08 afternoon.\n\nAssistant: ```json\n{\"name\":\"find_free_slots\",\"arguments\":{\"date\":\"2026-05-08\",\"duration_minutes\":30,\"time_window\":\"afternoon\"}}\n```\n\nUser: Function output:\n{\"free_slots\":[{\"start\":\"2026-05-08T15:00:00+09:00\",\"end\":\"2026-05-08T15:30:00+09:00\"}],\"bob_email\":\"bob@example.com\"}\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99], [generate_prompt("Please give the pros and cons of hodl versus active trading."), 1000, 1, 0.5, 1, 0.1, 0.99], [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 1000, 1, 0.5, 1, 0.1, 0.99], ["User: What is the maximum value of $4(x + 7)(2 - x)$, over all real numbers $x$?\n\nAssistant: \n

{title}

\n") with gr.Tab("Raw Generation"): gr.Markdown(f'This is [RWKV7 G-series](https://huggingface.co/BlinkDL/rwkv7-g1) reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Try topp 0.3 for math. Supports 100+ world languages and code. Check [600+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.') with gr.Row(): with gr.Column(): prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: