Spaces:
Running on T4
Running on T4
| import gc, html, re, threading, time | |
| import gradio as gr | |
| import torch | |
| from datetime import datetime | |
| from huggingface_hub import hf_hub_download | |
| from rwkv.utils import PIPELINE | |
| import rwkv7_fast_v3a as v3a | |
# --- Text generation limits ---------------------------------------------------
ctx_limit = 7000  # prompts are truncated to their last `ctx_limit` tokens
gen_limit = 1000  # upper bound of the "Max Tokens" slider on the raw tab
max_bsz = 64  # largest batch size a single request may use
CHUNK_LEN = 512 # chunk prefill, save VRAM
SAMPLER_TOP_K = 500  # sampling only looks at this many highest logits
YIELD_EVERY = 16  # default number of decode steps between streamed updates
USE_CUDA_GRAPH = True  # try to capture the one-token decode step as a CUDA graph
# --- HTML grid tab ------------------------------------------------------------
html_gen_limit = 8000  # upper bound (and default) of the HTML tab "Max Tokens" slider
MAX_HTML_PREVIEWS = 15  # number of preview components in the grid
HTML_GRID_UPDATE_EVERY = 32  # decode steps between updates while streaming the grid
HTML_IFRAME_MIN_DELTA = 1800  # reload an iframe once its page grew by this many characters
HTML_IFRAME_MAX_STALE = 2.0  # ...or once it has not been reloaded for this many seconds
HTML_COMPONENT_MIN_INTERVAL = 0.5  # minimum seconds between updates of one component
# Pixel sizes used by render_preview(): caption bar + body == preview height,
# and the body is split into the iframe area and the raw-text area.
HTML_PREVIEW_HEIGHT = 300
HTML_CAPTION_HEIGHT = 15
HTML_BODY_HEIGHT = HTML_PREVIEW_HEIGHT - HTML_CAPTION_HEIGHT
HTML_FRAME_HEIGHT = 228
HTML_RAW_HEIGHT = HTML_BODY_HEIGHT - HTML_FRAME_HEIGHT
# Entries of the prompt dropdown on the HTML tab; the first one is the default.
HTML_PROMPT_CHOICES = [
    "3D animation of cars in forest with animals",
    "interactive weather map with animated clouds and rain",
    "retro arcade RPG start screen",
    "3D animation of a SpaceX rocket landing on Mars",
    "storybook scene with a dragon flying over a castle",
    "interactive dashboard for a city traffic system",
    "animated aquarium with colorful fish and coral",
    "sci-fi spaceship navigation interface",
    "cozy cafe menu with animated steam and pastries",
    "character sheet for a high fantasy RPG",
    "a fancy hotel homepage",
]
def html_prompt_from_choice(choice):
    """Wrap a dropdown choice in the chat template used for HTML generation.

    The returned prompt ends in an unterminated ``</think`` so the model
    continues right after the (empty) thinking section.
    """
    template = "User: Write HTML: {}\n\nAssistant: <think></think"
    return template.format(choice)
# Prompt pre-filled in the HTML tab: the first dropdown choice in chat-template form.
DEFAULT_HTML_PROMPT = html_prompt_from_choice(HTML_PROMPT_CHOICES[0])
# Custom CSS handed to demo.launch(css=...). The selectors are the elem_classes
# assigned to the components of the HTML Generation tab.
HTML_GRID_CSS = """
div.main { padding-left: 0 !important; padding-right: 0 !important; }
.html-grid-tab { padding-top: 0 !important; }
.html-grid-main { display: grid !important; grid-template-columns: minmax(220px, 17.5%) 1fr !important; grid-template-rows: auto auto !important; gap: 4px !important; margin-top: 0 !important; align-items: start !important; }
.html-grid-main > div { gap: 4px !important; }
.html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-pages { grid-column: 2 !important; grid-row: 1 / span 2 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-output { grid-column: 1 !important; grid-row: 2 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-preview-row { gap: 4px !important; margin: 0 !important; }
.html-grid-preview { margin: 0 !important; }
.html-grid-preview > div { margin: 0 !important; }
.html-container { padding: 0 !important; }
.html-prompt-choice { margin: 0 !important; padding: 0 !important; min-height: 0 !important; }
.html-prompt-choice .wrap,
.html-prompt-choice .wrap-inner,
.html-prompt-choice .secondary-wrap { margin: 0 !important; padding: 0 !important; min-height: 0 !important; }
.html-prompt-choice label { margin: 0 !important; padding: 0 !important; }
@media (max-width: 768px) {
.html-grid-main { grid-template-columns: 1fr !important; grid-template-rows: auto auto auto !important; }
.html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; }
.html-grid-pages { grid-column: 1 !important; grid-row: 2 !important; }
.html-grid-output { grid-column: 1 !important; grid-row: 3 !important; }
}
"""
# Process-wide preview settings. They are module globals so that slider changes
# made while a generation is streaming can be picked up by the running
# generator; all access goes through the get_*/set_* helpers under this lock.
html_view_lock = threading.Lock()
html_view_scale = 35
html_view_scroll_seconds = 5
def clamp_html_scale(scale):
    """Coerce *scale* to ``int`` and limit it to the range 20..100 (percent)."""
    value = int(scale)
    if value < 20:
        return 20
    if value > 100:
        return 100
    return value
def clamp_scroll_seconds(seconds):
    """Coerce *seconds* to ``int`` and limit it to the range 0..10."""
    value = int(seconds)
    if value < 0:
        return 0
    if value > 10:
        return 10
    return value
def set_html_scale(scale):
    """Store the clamped preview scale in the shared setting and return it."""
    global html_view_scale
    clamped = clamp_html_scale(scale)
    # Clamp first so the lock is held only for the assignment itself.
    with html_view_lock:
        html_view_scale = clamped
    return clamped
def get_html_scale():
    """Return the current shared preview scale (read under the view lock)."""
    with html_view_lock:
        current = html_view_scale
    return current
def set_scroll_seconds(seconds):
    """Store the clamped auto-scroll duration in the shared setting and return it."""
    global html_view_scroll_seconds
    clamped = clamp_scroll_seconds(seconds)
    # Clamp first so the lock is held only for the assignment itself.
    with html_view_lock:
        html_view_scroll_seconds = clamped
    return clamped
def get_scroll_seconds():
    """Return the current shared auto-scroll duration (read under the view lock)."""
    with html_view_lock:
        current = html_view_scroll_seconds
    return current
########################## text rwkv ################################################################
# Module-level model setup: runs once at import time and downloads the weights.
title = "rwkv7-g1g-13.3b-20260523-ctx8192"
# The checkpoint file name is derived from the title; hf_hub_download returns the local path.
model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
# Configure the rwkv7_fast_v3a module through its globals before building the model.
# NOTE(review): the meaning of the individual mode strings is defined by
# rwkv7_fast_v3a and is not visible in this file - confirm there before changing.
v3a.MODEL_PATH = model_path
v3a.WKV_MODE = "fp32io16"
v3a.EMB_DEVICE = "cuda"
v3a.RKV_MODE = "off"
v3a.CMIX_SPARSE = "no-fc"
v3a.LOWRANK_WEIGHT = "transpose"
v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
v3a.load_extensions(v3a.WKV_MODE)
model = v3a.RWKV7()
# Release whatever temporary memory loading left behind before serving requests.
gc.collect()
torch.cuda.empty_cache()
# PIPELINE supplies encode()/decode() for the named vocabulary.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
def sample_logits_batch_cuda(logits, temperature: float, top_p: float, k: int):
    """Sample one token id per row with top-k + top-p (nucleus) sampling.

    Args:
        logits: 2-D tensor, one row of vocabulary scores per batch element.
        temperature: softmax temperature applied to the top-k logits.
        top_p: nucleus mass; ``<= 0`` selects greedy decoding.
        k: number of highest logits kept per row; ``1`` selects greedy decoding.

    Returns:
        1-D tensor with one sampled vocabulary index per row.
    """
    greedy = top_p <= 0.0 or k == 1
    if greedy:
        return torch.argmax(logits, dim=-1)

    # Restrict to the k best candidates, sorted by descending score.
    top_vals, top_ids = torch.topk(logits.float(), k=k, dim=-1, sorted=True)
    scaled = top_vals if temperature == 1.0 else top_vals / temperature
    cum_probs = torch.cumsum(torch.softmax(scaled, dim=-1), dim=-1)

    if top_p < 1.0:
        # First position whose cumulative probability reaches top_p; the mass
        # up to and including it is the nucleus we sample from.
        cutoff = torch.argmax((cum_probs >= top_p).to(torch.int32), dim=-1)
        nucleus_mass = cum_probs.gather(1, cutoff.view(-1, 1)).view(-1)
    else:
        nucleus_mass = cum_probs[:, -1]

    # Inverse-CDF sampling: draw uniformly inside the nucleus mass and locate
    # the draw in the cumulative distribution.
    draws = torch.rand((logits.size(0), 1), device=logits.device) * nucleus_mass.view(-1, 1)
    picked = torch.searchsorted(cum_probs, draws).view(-1, 1)
    return top_ids.gather(1, picked).view(-1)
def get_decode_ctx(B: int, decode_cache):
    """Return the reusable single-token decode buffers for batch size *B*.

    Args:
        B: batch size of the decode step.
        decode_cache: dict used as a cache, keyed by batch size; a new entry
            is stored in it on a miss.

    Returns:
        Tuple ``(state, x, graph, output)``: the batched model state, the
        input embedding buffer, the captured CUDA graph (``None`` when graphs
        are disabled or capture failed) and the graph's output (``None``
        whenever ``graph`` is ``None``).
    """
    cached = decode_cache.get(B)
    if cached is not None:
        return cached
    state = model.zero_state(B)
    # Uninitialised input buffer; callers copy real embeddings into it before each step.
    x = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half)
    path = v3a.select_path(B, 1)
    # Two eager runs before capture (presumably warm-up; their results are discarded).
    for _ in range(2):
        model.forward_from_x(x, state, path)
    torch.cuda.synchronize()
    graph = None
    output = None
    if USE_CUDA_GRAPH:
        try:
            graph = torch.cuda.CUDAGraph()
            # Capture one decode step; replaying the graph later rewrites `output` in place.
            with torch.cuda.graph(graph):
                output = model.forward_from_x(x, state, path)
            torch.cuda.synchronize()
        except Exception as exc:
            # Best effort: fall back to eager decoding when capture is not possible.
            print(f"CUDA graph disabled for B={B}: {exc}", flush=True)
            graph = None
            output = None
    cached = (state, x, graph, output)
    decode_cache[B] = cached
    return cached
def copy_state_to_batch(dst, src):
    """Broadcast a batch-of-one state into a batched state, in place.

    ``dst`` and ``src`` are sequences of three tensors. The batch size is
    read from ``dst[2]``; each ``src`` tensor is expanded along its batch
    axis (dim 2, dim 1 and dim 0 respectively) and copied into ``dst``.
    """
    batch = dst[2].shape[0]
    # Per-slot expand shape; -1 keeps the source size of that dimension.
    expand_shapes = (
        (-1, -1, batch, -1),
        (-1, batch, -1, -1, -1),
        (batch,),
    )
    for slot, shape in enumerate(expand_shapes):
        dst[slot].copy_(src[slot].expand(*shape))
def tokens_to_x(tokens):
    """Embed a batch of token ids, one token per row.

    *tokens* may be a tensor or any sequence accepted by ``torch.tensor``.
    The ids are placed on the device the embedding table lives on (CPU when
    ``model.emb_cpu`` is set, CUDA otherwise) and reshaped to a column
    before being passed to ``model.embed``.
    """
    device = "cuda"
    if model.emb_cpu:
        device = "cpu"
    if isinstance(tokens, torch.Tensor):
        ids = tokens.to(device=device, dtype=torch.long, non_blocking=True)
    else:
        ids = torch.tensor(tokens, dtype=torch.long, device=device)
    return model.embed(ids.view(-1, 1))
def generate_prompt(instruction, input=""):
    """Build a prompt from an instruction and an optional input.

    Both strings are stripped, CRLF is normalised to LF, and every run of
    newlines is collapsed to a single newline. A blank line is the turn
    separator of the prompt format, so none may survive inside the user
    text. (A single ``replace('\\n\\n', '\\n')`` pass left blank lines behind
    for runs of three or more newlines.)

    Args:
        instruction: the task description.
        input: optional extra context; selects the Instruction/Input/Response
            template when non-empty after cleaning.

    Returns:
        The formatted prompt string.
    """
    def _tidy(text):
        return re.sub(r'\n+', '\n', text.strip().replace('\r\n', '\n'))

    instruction = _tidy(instruction)
    input = _tidy(input)
    if input:
        return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
    return f"User: {instruction}\n\nAssistant: <think></think"
def qa_prompt(instruction):
    """Format a question as a User/Assistant prompt with an empty think block.

    The question is stripped, CRLF becomes LF, and runs of newlines are
    collapsed to one so the text contains no blank line.
    """
    cleaned = instruction.strip().replace('\r\n', '\n')
    cleaned = re.sub(r'\n+', '\n', cleaned)
    return "User: " + cleaned + "\n\nAssistant: <think></think"
def output_update(text, speed="", compact=False):
    """Build a Gradio update that sets the output text and its label.

    The label is ``"Output <speed>"`` when *speed* is non-empty, otherwise
    just ``"Output"``. With *compact*, only the part of *speed* before the
    first ``" @ "`` is shown.
    """
    if speed and compact:
        speed = speed.partition(" @ ")[0]
    label = "Output"
    if speed:
        label = f"Output {speed}"
    return gr.update(value=text, label=label)
def output_text(B, out_str):
    """Combine the per-sample texts into the single string shown in the UI.

    For a batch of one, the stripped first text is returned; otherwise all
    stripped texts are joined with a ``====`` separator line.
    """
    if B == 1:
        return out_str[0].strip()
    return "\n====\n".join(map(str.strip, out_str))
def speed_text(rate, B, tokens, chars):
    """Format the throughput line shown in the output label."""
    throughput = f"{rate:.1f} token/s @ bsz {B}"
    totals = f"{tokens} tokens, {chars} chars"
    return throughput + " = " + totals
def gib(n):
    """Convert a byte count to gigabytes.

    Despite the name, the divisor is decimal (10**9), not binary (2**30).
    """
    bytes_per_unit = 1_000_000_000.0
    return n / bytes_per_unit
def generate_batch_text(
    ctx,
    token_count=200,
    batch_size=1,
    temperature=1.0,
    top_p=0.5,
    presencePenalty = 1,
    countPenalty = 0.1,
    penalty_decay = 0.99,
    yield_every = YIELD_EVERY,
):
    """Stream a batch of independent continuations of one prompt.

    The prompt is prefilled once with batch size 1, the resulting state is
    broadcast to ``B`` rows, and all rows are then decoded in lockstep.

    Args:
        ctx: prompt text; it is stripped and truncated to its last
            ``ctx_limit`` tokens.
        token_count: maximum number of decode steps.
        batch_size: requested number of samples, clamped to ``1..max_bsz``.
        temperature: ``<= 0`` selects greedy decoding; positive values are
            raised to at least 0.2.
        top_p: nucleus mass passed to the sampler.
        presencePenalty: amount subtracted from the logit of every token
            already sampled in that row.
        countPenalty: amount subtracted per (decayed) occurrence.
        penalty_decay: factor applied to the occurrence counts every step.
        yield_every: number of decode steps between intermediate yields.

    Yields:
        ``(text, speed_info, meta)`` tuples. Intermediate items carry
        ``meta["done"] == False`` and the per-row token counts; the last item
        carries ``meta["done"] == True`` plus the request summary consumed by
        ``print_summary``.

    NOTE(review): the source this block was recovered from had lost its
    indentation; the nesting below was reconstructed from the control flow
    and should be confirmed against the original file.
    """
    req_t0 = time.perf_counter()
    user_token_count = int(token_count)
    rwkv_model = model
    pipe = pipeline
    sample_temperature = float(temperature)
    sample_top_p = float(top_p)
    if sample_temperature <= 0:
        # Non-positive temperature means greedy decoding (top_p == 0 -> argmax).
        sample_temperature = 1.0
        sample_top_p = 0
    else:
        sample_temperature = max(0.2, sample_temperature)
    alpha_frequency = float(countPenalty)
    alpha_presence = float(presencePenalty)
    ctx = ctx.strip()
    # Keep only the most recent ctx_limit tokens of the prompt.
    input_ids = pipe.encode(ctx)[-ctx_limit:]
    input_token_count = len(input_ids)
    B = min(max_bsz, max(1, int(batch_size)))
    batch_rows = None
    all_tokens = [[] for _ in range(B)]  # sampled token ids per row
    out_last = [0 for _ in range(B)]  # per row: tokens already turned into text
    out_str = ['' for _ in range(B)]  # decoded text per row
    occurrence_count = None
    occurrence_presence = None
    finished = [False for _ in range(B)]
    speed_t0 = None
    speed_tokens = 0
    total_tokens = 0
    speed_info = ""
    # The decode buffers live in a per-request cache and are dropped at the end.
    decode_cache = {}
    state = rwkv_model.zero_state(1)
    decode_state, decode_x, decode_graph, decode_output = get_decode_ctx(B, decode_cache)
    next_tokens = [0 for _ in range(B)]
    out = None
    for i in range(int(token_count)):
        if i == 0:
            if len(input_ids) == 0:
                # Nothing to condition on: emit an empty, already-finished result.
                yield "", "", {"done": True, "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "B": B, "T": user_token_count, "In": input_token_count, "TPS": 0.0, "Time": 0.0, "Token": 0, "Char": 0, "VRAMUsed": 0, "VRAMTotal": 0, "token_counts": [0 for _ in range(B)]}
                return
            # Prefill the prompt in chunks of CHUNK_LEN tokens with batch size 1.
            while len(input_ids) > 0:
                token_device = "cpu" if rwkv_model.emb_cpu else "cuda"
                tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
                out = rwkv_model.forward(tokens, state).view(-1)
                input_ids = input_ids[CHUNK_LEN:]
            torch.cuda.synchronize()
            # Every row starts from the same prompt state and the same logits.
            copy_state_to_batch(decode_state, state)
            logits = out.view(1, -1).repeat(B, 1)
        else:
            # Feed the previously sampled tokens (0 for finished rows) through one decode step.
            decode_x.copy_(tokens_to_x(next_tokens))
            if decode_graph is None:
                decode_output = rwkv_model.forward_from_x(decode_x, decode_state, v3a.select_path(B, 1))
            else:
                decode_graph.replay()
            logits = decode_output.view(B, -1)
        if occurrence_count is None:
            # Lazily allocated because the vocabulary size is only known from the logits.
            occurrence_count = torch.zeros((B, logits.size(-1)), device=logits.device, dtype=logits.dtype)
            occurrence_presence = torch.zeros_like(occurrence_count)
            batch_rows = torch.arange(B, device=logits.device)
        # Repetition penalties are applied to the logits in place.
        if alpha_frequency:
            logits.sub_(occurrence_count, alpha=alpha_frequency)
        if alpha_presence:
            logits.sub_(occurrence_presence)
        assert logits.is_cuda and logits.dim() == 2
        sampled_tensor = sample_logits_batch_cuda(
            logits,
            sample_temperature,
            sample_top_p,
            min(SAMPLER_TOP_K, logits.size(-1)),
        )
        sampled = sampled_tensor.detach().cpu().tolist()
        active = 0
        next_tokens = [0 for _ in range(B)]
        # Decay the old counts, then record this step's samples for every row.
        if penalty_decay != 1:
            occurrence_count.mul_(penalty_decay)
        occurrence_count[batch_rows, sampled_tensor] += 1
        if alpha_presence:
            occurrence_presence[batch_rows, sampled_tensor] = alpha_presence
        for b in range(B):
            if finished[b]:
                continue
            token = sampled[b]
            if token == 0:
                # Token id 0 ends the row.
                finished[b] = True
                continue
            active += 1
            next_tokens[b] = token
            all_tokens[b].append(token)
            tmp = pipe.decode(all_tokens[b][out_last[b]:])
            # Hold text back while it contains U+FFFD (presumably a character
            # whose bytes are split across tokens); retry with more tokens.
            if '\ufffd' not in tmp:
                out_str[b] += tmp
                out_last[b] = len(all_tokens[b])
        total_tokens += active
        if active == 0:
            break
        if speed_t0 is None:
            # The clock starts after the first step, so prefill is not part of the rate.
            speed_t0 = time.perf_counter()
        else:
            # Counts B tokens per step, including rows that have already finished.
            speed_tokens += B
            elapsed = max(1e-9, time.perf_counter() - speed_t0)
            current_text = output_text(B, out_str)
            speed_info = speed_text(speed_tokens / elapsed, B, total_tokens, len(current_text))
        if i == 0 or i % max(1, int(yield_every)) == 0:
            current_text = output_text(B, out_str)
            yield current_text, speed_info, {"done": False, "token_counts": [len(tokens) for tokens in all_tokens]}
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Sample memory usage before the request's buffers are released.
    free, total = torch.cuda.mem_get_info()
    del out
    del state
    del decode_cache
    current_text = output_text(B, out_str)
    if speed_t0 is not None and not speed_info:
        # Only one step ran, so no rate could be measured.
        speed_info = speed_text(0.0, B, total_tokens, len(current_text))
    elapsed = time.perf_counter() - req_t0
    final_tps = speed_tokens / max(1e-9, time.perf_counter() - speed_t0) if speed_t0 is not None else 0.0
    used = total - free
    meta = {"done": True, "timestamp": timestamp, "B": B, "T": user_token_count, "In": input_token_count, "TPS": final_tps, "Time": elapsed, "Token": total_tokens, "Char": len(current_text), "VRAMUsed": used, "VRAMTotal": total, "token_counts": [len(tokens) for tokens in all_tokens]}
    yield current_text, speed_info, meta
def print_summary(prefix, meta):
    """Print a one-line request summary built from a final ``meta`` dict.

    *meta* must provide the keys ``timestamp``, ``B``, ``T``, ``In``, ``TPS``,
    ``Time``, ``Token``, ``Char``, ``VRAMUsed`` and ``VRAMTotal``.
    """
    fields = [
        f"[{prefix}] {meta['timestamp']}",
        f"B={meta['B']}",
        f"T={meta['T']}",
        f"In={meta['In']}",
        f"TPS={meta['TPS']:.1f}",
        f"Time={meta['Time']:.3f}s",
        f"Token={meta['Token']}",
        f"Char={meta['Char']}",
        f"VRAM={gib(meta['VRAMUsed']):.2f}G/{gib(meta['VRAMTotal']):.2f}G",
    ]
    print(" ".join(fields), flush=True)
def evaluate_raw(
    ctx,
    token_count=200,
    batch_size=1,
    temperature=1.0,
    top_p=0.5,
    presencePenalty = 1,
    countPenalty = 0.1,
    penalty_decay = 0.99,
):
    """Gradio handler of the raw tab: stream output-box updates for a prompt.

    Forwards every argument to ``generate_batch_text`` (yielding every
    ``YIELD_EVERY`` steps), logs the summary when the final item arrives and
    yields one ``output_update`` per generated item.
    """
    stream = generate_batch_text(
        ctx,
        token_count,
        batch_size,
        temperature,
        top_p,
        presencePenalty,
        countPenalty,
        penalty_decay,
        YIELD_EVERY,
    )
    for text, speed_info, meta in stream:
        is_final = meta and meta.get("done")
        if is_final:
            print_summary("app3", meta)
        yield output_update(text, speed_info)
def split_batch_output(text, count):
    """Split joined batch output on the ``====`` separator lines.

    The result is padded with empty strings up to *count* entries and cut
    with ``[:count]``.
    """
    pieces = text.split("\n====\n") if text else []
    missing = count - len(pieces)
    if missing > 0:
        pieces.extend([""] * missing)
    return pieces[:count]
def extract_html(text, prompt=""):
    """Extract the HTML document that follows the ``</think>`` marker.

    *prompt* and *text* are concatenated, because the closing ``</think>``
    may be split between the end of the prompt and the start of the output.
    All markers are matched ASCII case-insensitively on the original string.
    Offsets are never taken from a lower-cased copy: ``str.lower()`` can
    change the length of a string (e.g. ``"İ"``), which would shift every
    slice and break the extraction.

    Args:
        text: generated text.
        prompt: prompt that preceded *text*.

    Returns:
        The document from ``<!doctype html`` up to and including the first
        ``</html>`` (or to the end of the text while it is still streaming),
        stripped. An empty string is returned when there is no ``</think>``,
        no doctype after it, or no ``<body`` after the doctype yet.
    """
    flags = re.IGNORECASE | re.ASCII

    def find_ci(needle, haystack, pos=0):
        # Case-insensitive literal search starting at `pos`; returns a match or None.
        return re.compile(re.escape(needle), flags).search(haystack, pos)

    stream = prompt + text
    think_end = find_ci("</think>", stream)
    if think_end is None:
        return ""
    visible = stream[think_end.end():]
    doctype = find_ci("<!doctype html", visible)
    if doctype is None:
        return ""
    start = doctype.start()
    # Wait until the body has begun; a bare <head> is not worth rendering.
    if find_ci("<body", visible, start) is None:
        return ""
    closing = find_ci("</html>", visible, start)
    if closing is not None:
        return visible[start:closing.end()].strip()
    return visible[start:].strip()
def html_complete(page):
    """Tell whether *page* contains a closing ``</html>`` tag (any letter case)."""
    lowered = page.lower()
    return lowered.find("</html>") != -1
def inject_iframe_scroll_script(page, index, scroll_seconds):
    """Insert an auto-scroll script into a generated HTML page.

    The script waits for the page to load, then repeatedly scrolls the (up
    to eight) most scrollable elements from top to bottom and back, taking
    ``scroll_seconds`` per direction. The start is delayed by
    ``index * 733 % 5000`` milliseconds so the previews do not move in sync.

    The insertion point is located with case-insensitive regular
    expressions on *page* itself. Searching a lower-cased copy and slicing
    the original with those offsets is wrong, because ``str.lower()`` can
    change the string length (e.g. ``"İ"``) and the script would then be
    spliced into the wrong position.

    Args:
        page: HTML document text.
        index: preview slot number, used only to stagger the start delay.
        scroll_seconds: seconds per scroll direction; clamped, and ``0``
            disables scrolling.

    Returns:
        *page* unchanged when scrolling is disabled. Otherwise the page with
        the script placed right after the opening ``<head>`` tag, or after
        ``<html>`` (wrapped in a new ``<head>``), or after the doctype, or
        at the very beginning when none of these is present.
    """
    scroll_seconds = clamp_scroll_seconds(scroll_seconds)
    if scroll_seconds <= 0:
        return page
    delay = int((index * 733) % 5000)
    leg_ms = scroll_seconds * 1000
    script = f"""<script>
(() => {{
const delay = {delay};
const legMs = {leg_ms};
let down = true;
let running = false;
function unique(list) {{
const out = [];
const seen = new Set();
for (const el of list) {{
if (!el || seen.has(el)) continue;
seen.add(el);
out.push(el);
}}
return out;
}}
function candidates() {{
const base = [document.scrollingElement, document.documentElement, document.body, document.body && document.body.parentElement];
const all = Array.from(document.querySelectorAll("*"));
return unique(base.concat(all))
.filter(el => Math.max(0, el.scrollHeight - el.clientHeight) > 4)
.sort((a, b) => (b.scrollHeight - b.clientHeight) - (a.scrollHeight - a.clientHeight))
.slice(0, 8);
}}
function getTop(el) {{
if (el === document.scrollingElement || el === document.documentElement || el === document.body) {{
return window.scrollY || document.documentElement.scrollTop || document.body.scrollTop || 0;
}}
return el.scrollTop || 0;
}}
function setTop(el, y) {{
if (el === document.scrollingElement || el === document.documentElement || el === document.body) {{
window.scrollTo(0, y);
document.documentElement.scrollTop = y;
document.body.scrollTop = y;
}} else {{
el.scrollTop = y;
}}
}}
function animate() {{
const targets = candidates();
if (!targets.length) {{
setTimeout(animate, 1000);
return;
}}
const starts = targets.map(getTop);
const ends = targets.map(el => down ? Math.max(0, el.scrollHeight - el.clientHeight) : 0);
down = !down;
const t0 = performance.now();
function step(t) {{
const p = Math.min(1, (t - t0) / legMs);
for (let i = 0; i < targets.length; i++) {{
setTop(targets[i], starts[i] + (ends[i] - starts[i]) * p);
}}
requestAnimationFrame(p < 1 ? step : animate);
}}
requestAnimationFrame(step);
}}
function start() {{
if (running) return;
running = true;
setTimeout(animate, delay);
try {{
new MutationObserver(() => candidates()).observe(document.documentElement, {{childList: true, subtree: true}});
}} catch (e) {{}}
setInterval(candidates, 1000);
}}
if (document.readyState === "complete") start();
else window.addEventListener("load", start, {{once: true}});
setTimeout(start, delay + 1500);
}})();
</script>"""
    # Prefer the earliest point where a script is valid: inside <head>.
    m = re.search(r"<head\b[^>]*>", page, re.IGNORECASE)
    if m:
        return page[:m.end()] + script + page[m.end():]
    # No <head>: create one right after the opening <html> tag.
    m = re.search(r"<html\b[^>]*>", page, re.IGNORECASE)
    if m:
        return page[:m.end()] + "<head>" + script + "</head>" + page[m.end():]
    m = re.search(r"<!doctype\s+html[^>]*>", page, re.IGNORECASE)
    if m:
        return page[:m.end()] + script + page[m.end():]
    return script + page
def render_preview(text="", index=0, scale=35, active=True, prompt="", token_count=None, scroll_seconds=5):
    """Render one grid cell (caption bar plus body) as an HTML string.

    Args:
        text: raw model output for this cell.
        index: zero-based cell number; shown as ``#index+1`` and used in element ids.
        scale: preview zoom in percent (limited to 20..120 here).
        active: inactive cells get no caption and are dimmed.
        prompt: prompt that preceded *text*; needed to locate ``</think>``.
        token_count: number of generated tokens for the caption, or ``None``.
        scroll_seconds: auto-scroll duration passed to the injected script.

    Returns:
        HTML whose body is an empty placeholder when there is no text, an
        iframe plus the raw text when an HTML page could be extracted, and
        the raw text alone otherwise.
    """
    tokens = f"{token_count:,}" if token_count is not None else "-"
    caption = f"#{index + 1} | {tokens} tokens, {len(text.encode('utf-8')):,} bytes" if active else ""
    opacity = "1" if active else ".35"
    zoom = max(0.2, min(1.2, scale / 100.0))
    page = extract_html(text, prompt)
    if not text:
        body = f'<div style="height:{HTML_BODY_HEIGHT}px;background:#fafafa;"></div>'
    elif page:
        # The page goes into srcdoc, so it is escaped as an attribute value. The
        # sandbox omits allow-same-origin, keeping generated scripts away from the host page.
        srcdoc = html.escape(inject_iframe_scroll_script(page, index, scroll_seconds), quote=True)
        raw = html.escape(text)
        # The iframe is laid out at 1/zoom of the cell size and scaled back down,
        # so the page renders as if in a larger viewport.
        body = f"""<div style="height:{HTML_BODY_HEIGHT}px;display:flex;flex-direction:column;background:white;">
<div style="height:{HTML_FRAME_HEIGHT}px;overflow:hidden;background:white;"><iframe sandbox="allow-scripts allow-forms allow-popups" srcdoc="{srcdoc}" style="border:0;width:{100 / zoom:.3f}%;height:{HTML_FRAME_HEIGHT / zoom:.1f}px;background:white;transform:scale({zoom:.3f});transform-origin:top left;"></iframe></div>
<div id="html-raw-{index}" style="height:{HTML_RAW_HEIGHT}px;overflow:auto;display:flex;flex-direction:column-reverse;background:#fafafa;border-top:1px solid #111;"><pre style="zoom:{zoom:.3f};margin:0;padding:0;white-space:pre-wrap;word-break:break-word;color:#111;font:16px/1.2 ui-monospace,Consolas,monospace;">{raw}</pre></div>
</div>"""
    else:
        # No renderable page yet: show only the escaped raw text.
        body = f'<div id="html-raw-{index}" style="height:{HTML_BODY_HEIGHT}px;overflow:auto;display:flex;flex-direction:column-reverse;background:#fafafa;"><pre style="zoom:{zoom:.3f};margin:0;padding:0;white-space:pre-wrap;word-break:break-word;color:#111;font:16px/1.2 ui-monospace,Consolas,monospace;">{html.escape(text)}</pre></div>'
    return f"""<div class="html-container" style="outline:1px solid #111;background:#fff;opacity:{opacity};height:{HTML_PREVIEW_HEIGHT}px;display:flex;flex-direction:column;padding:0;">
<div style="box-sizing:border-box;height:{HTML_CAPTION_HEIGHT}px;padding:1px 6px;background:#111;color:#fff;font:11px/13px ui-monospace,monospace;">{caption}</div>
{body}
</div>"""
def empty_html_grid():
    """Return one inactive, empty preview per grid slot."""
    placeholders = []
    for slot in range(MAX_HTML_PREVIEWS):
        placeholders.append(render_preview("", slot, active=False))
    return placeholders
def render_html_grid_from_raw(prompt, raw_output, page_count, scale, scroll_seconds, token_counts):
    """Re-render every grid cell from the raw output currently on screen.

    Called when the scale or scroll sliders change. The new values are also
    stored in the shared view settings. Returns ``MAX_HTML_PREVIEWS`` HTML
    strings; slots at or beyond *page_count* are rendered empty and inactive.
    """
    page_count = max(1, min(MAX_HTML_PREVIEWS, int(page_count)))
    scale = set_html_scale(scale)
    scroll_seconds = set_scroll_seconds(scroll_seconds)
    pages = split_batch_output(raw_output, page_count)
    counts = token_counts or []
    previews = []
    for slot in range(MAX_HTML_PREVIEWS):
        shown = slot < page_count
        page_text = pages[slot] if shown else ""
        slot_tokens = counts[slot] if slot < len(counts) else None
        previews.append(render_preview(page_text, slot, scale, shown, prompt, slot_tokens, scroll_seconds))
    return previews
def evaluate_html_grid(
    prompt,
    token_count=8000,
    page_count=15,
    scale=35,
    scroll_seconds=5,
    temperature=1.0,
    top_p=0.5,
    presence_penalty=1.0,
    count_penalty=0.1,
    penalty_decay=0.99,
):
    """Gradio handler of the HTML tab: stream a grid of live page previews.

    Generates ``page_count`` samples of *prompt* in one batch and yields
    lists shaped like ``html_outputs``: ``MAX_HTML_PREVIEWS`` preview
    updates, the raw-output update, the per-slot token counts and the page
    count. Components with nothing new receive ``gr.skip()``; reloading an
    iframe restarts the page inside it, so reloads are throttled.

    NOTE(review): the source this block was recovered from had lost its
    indentation; the nesting below was reconstructed from the control flow
    and should be confirmed against the original file.
    """
    page_count = max(1, min(MAX_HTML_PREVIEWS, int(page_count)))
    scale = set_html_scale(scale)
    scroll_seconds = set_scroll_seconds(scroll_seconds)
    last_text = ""
    # Per-slot cache describing what was last sent to the browser.
    cached_previews = empty_html_grid()
    cached_html = ["" for _ in range(MAX_HTML_PREVIEWS)]  # page last loaded into the iframe
    cached_html_at = [0.0 for _ in range(MAX_HTML_PREVIEWS)]  # when that page was loaded
    cached_complete = [False for _ in range(MAX_HTML_PREVIEWS)]  # it already contained </html>
    cached_text_len = [0 for _ in range(MAX_HTML_PREVIEWS)]  # raw text length at last render
    cached_preview_at = [0.0 for _ in range(MAX_HTML_PREVIEWS)]  # when the slot was last updated
    last_raw_at = 0.0
    final_text = ""
    final_token_counts = [0 for _ in range(MAX_HTML_PREVIEWS)]
    # Clear the grid and the output box before generation starts.
    yield [*cached_previews, output_update("", compact=True), final_token_counts, page_count]
    for text, speed_info, meta in generate_batch_text(prompt, token_count, page_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay, HTML_GRID_UPDATE_EVERY):
        final_text = text
        done_batch = bool(meta and meta.get("done"))
        token_counts = meta.get("token_counts", final_token_counts) if meta else final_token_counts
        if token_counts:
            # Pad to one entry per grid slot.
            final_token_counts = token_counts + [0 for _ in range(max(0, MAX_HTML_PREVIEWS - len(token_counts)))]
        if done_batch:
            print_summary("app3-html", meta)
        # Skip updates that add fewer than 300 characters in total, unless finished.
        if len(text) - len(last_text) < 300 and not done_batch:
            continue
        last_text = text
        now = time.monotonic()
        # Pick up slider changes made while this generator is running.
        render_scale = get_html_scale()
        render_scroll_seconds = get_scroll_seconds()
        scale_changed = render_scale != scale
        scroll_changed = render_scroll_seconds != scroll_seconds
        scale = render_scale
        scroll_seconds = render_scroll_seconds
        pages = split_batch_output(text, page_count)
        skip = gr.skip()
        updates = [skip for _ in range(MAX_HTML_PREVIEWS)]
        for i in range(MAX_HTML_PREVIEWS):
            page_text = pages[i] if i < page_count else ""
            active = i < page_count
            page = extract_html(page_text, prompt) if active else ""
            page_tokens = final_token_counts[i] if i < len(final_token_counts) else None
            # Per-component rate limit.
            if now - cached_preview_at[i] < HTML_COMPONENT_MIN_INTERVAL and not done_batch and not scale_changed and not scroll_changed:
                continue
            if not page:
                # Text-only slot: re-render only when the text actually changed.
                if len(page_text) == cached_text_len[i] and active == bool(cached_text_len[i]) and not done_batch and not scale_changed and not scroll_changed:
                    continue
                cached_previews[i] = render_preview(page_text, i, scale, active, prompt, page_tokens, scroll_seconds)
                cached_html[i] = ""
                cached_html_at[i] = now
                cached_complete[i] = False
                cached_text_len[i] = len(page_text)
                cached_preview_at[i] = now
                updates[i] = cached_previews[i]
                continue
            done = html_complete(page)
            # Reload the iframe on first render, when the page just completed, at
            # the end of the batch, or - while still incomplete - when it grew
            # enough or has become stale.
            should_reload = (
                not cached_html[i]
                or (done and not cached_complete[i])
                or done_batch
                or (
                    not cached_complete[i]
                    and (
                        len(page) - len(cached_html[i]) >= HTML_IFRAME_MIN_DELTA
                        or now - cached_html_at[i] >= HTML_IFRAME_MAX_STALE
                    )
                )
            )
            if should_reload:
                cached_previews[i] = render_preview(page_text, i, scale, active, prompt, page_tokens, scroll_seconds)
                cached_html[i] = page
                cached_html_at[i] = now
                cached_complete[i] = done
                cached_text_len[i] = len(page_text)
                cached_preview_at[i] = now
                updates[i] = cached_previews[i]
        raw_update = skip
        if now - last_raw_at >= HTML_COMPONENT_MIN_INTERVAL or done_batch:
            raw_update = output_update(text, speed_info, compact=True)
            last_raw_at = now
        # Yield only when at least one component has something new.
        if any(update is not skip for update in updates) or raw_update is not skip:
            yield [*updates, raw_update, final_token_counts, page_count]
    if final_text:
        # Final full render with the latest view settings.
        pages = split_batch_output(final_text, page_count)
        scale = get_html_scale()
        scroll_seconds = get_scroll_seconds()
        final_previews = [render_preview(pages[i] if i < page_count else "", i, scale, i < page_count, prompt, final_token_counts[i] if i < len(final_token_counts) else None, scroll_seconds) for i in range(MAX_HTML_PREVIEWS)]
        yield [*final_previews, output_update(final_text, speed_info, compact=True), final_token_counts, page_count]
# Example rows for the "Example Instructions" dataset. Each row is written as
# [prompt, max tokens, temperature, top_p, presence penalty, count penalty, penalty decay];
# the batch-size column is inserted by the comprehension that follows the list.
examples = [
    ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
    ["System: Tools:\n[{\"name\":\"find_free_slots\",\"description\":\"Find free calendar slots\",\"arguments\":{\"date\":{\"type\":\"string\"},\"duration_minutes\":{\"type\":\"integer\"},\"time_window\":{\"type\":\"string\"}}},{\"name\":\"create_calendar_event\",\"description\":\"Create a calendar event\",\"arguments\":{\"title\":{\"type\":\"string\"},\"start_time\":{\"type\":\"string\"},\"end_time\":{\"type\":\"string\"},\"attendees\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}}}]\nReturn only a JSON function call.\n\nUser: Schedule a 30-minute sync with Bob on 2026-05-08 afternoon.\n\nAssistant: ```json\n{\"name\":\"find_free_slots\",\"arguments\":{\"date\":\"2026-05-08\",\"duration_minutes\":30,\"time_window\":\"afternoon\"}}\n```\n\nUser: Function output:\n{\"free_slots\":[{\"start\":\"2026-05-08T15:00:00+09:00\",\"end\":\"2026-05-08T15:30:00+09:00\"}],\"bob_email\":\"bob@example.com\"}\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
    [generate_prompt("Please give the pros and cons of hodl versus active trading."), 1000, 1, 0.5, 1, 0.1, 0.99],
    [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 1000, 1, 0.5, 1, 0.1, 0.99],
    ["User: What is the maximum value of $4(x + 7)(2 - x)$, over all real numbers $x$?\n\nAssistant: <think", 1000, 1, 0.5, 1, 0.1, 0.99],
    ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", 1000, 1, 0.5, 2, 0.2, 0.99],
    ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response:", 1000, 1, 0.5, 1, 0.1, 0.99],
    [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 1000, 1, 0.5, 1, 0.1, 0.99],
    [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 1000, 1, 0.5, 2, 0.2, 0.99],
    ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境内は、特別な雰囲気に包まれていた。\n\nEnglish:''', 1000, 1, 0.5, 1, 0.1, 0.99],
    ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", 1000, 1, 0.5, 1, 0.1, 0.99],
    ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", 1000, 1, 0.5, 1, 0.1, 0.99],
    ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", 1000, 1, 0.5, 1, 0.1, 0.99],
    ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', 1000, 1, 0.5, 2, 0.2, 0.99],
    ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', 1000, 1, 0.5, 1, 0.1, 0.99],
]
# Insert batch size 1 as the third column so each row lines up with the
# eight components of the dataset (prompt, max tokens, batch size, ...).
examples = [[x[0], x[1], 1, *x[2:]] for x in examples]
##################################################################################################################
# Gradio UI: one tab for raw text generation and one for the HTML preview grid.
# NOTE(review): the source this block was recovered from had lost its
# indentation; the `with` nesting below was reconstructed from the layout
# classes and component order and should be confirmed against the original file.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f'This is [RWKV7 G-series](https://huggingface.co/BlinkDL/rwkv7-g1) reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Try topp 0.3 for math. Supports 100+ world languages and code. Check [600+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.')
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: <think></think")
                token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=500)
                batch_size = gr.Slider(1, max_bsz, label="Batch Size", step=1, value=16)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
                presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.1)
                penalty_decay = gr.Slider(0.99, 0.999, label="Penalty Decay", step=0.001, value=0.99)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit", variant="primary")
                    stop = gr.Button("Stop", variant="secondary")
                output = gr.Textbox(label="Output", lines=20, max_lines=100)
        data = gr.Dataset(components=[prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Batch Size", "Temperature", "Top P", "Presence Penalty", "Count Penalty", "Penalty Decay"])
        submit_event = submit.click(evaluate_raw, [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], [output])
        # "Stop" runs no function; it only cancels the running generation event.
        stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event], queue=False)
        # Clicking an example row copies its values into the input components.
        data.click(lambda x: x, [data], [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
    with gr.Tab("✨HTML Generation", elem_classes="html-grid-tab"):
        with gr.Row(elem_classes="html-grid-main"):
            with gr.Column(scale=7, elem_classes="html-grid-controls"):
                html_prompt_choice = gr.Dropdown(choices=HTML_PROMPT_CHOICES, value=HTML_PROMPT_CHOICES[0], label=None, show_label=False, elem_classes="html-prompt-choice")
                html_prompt = gr.Textbox(lines=6, label="Prompt", value=DEFAULT_HTML_PROMPT)
                with gr.Row():
                    html_submit = gr.Button("Generate HTML Grid", variant="primary")
                    html_stop = gr.Button("Stop", variant="secondary")
                html_token_count = gr.Slider(50, html_gen_limit, label="Max Tokens", step=50, value=html_gen_limit)
                html_page_count = gr.Slider(1, MAX_HTML_PREVIEWS, label="Batch Size", step=1, value=15)
                html_scale = gr.Slider(20, 100, label="Preview Scale %", step=5, value=35)
                html_scroll_seconds = gr.Slider(0, 10, label="Preview Scroll Seconds", step=1, value=5)
                html_temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                html_top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
                html_presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=1.0)
                html_count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.1)
                html_penalty_decay = gr.Slider(0.99, 0.999, label="Penalty Decay", step=0.001, value=0.99)
                # State remembered from the last generation so the grid can be
                # re-rendered when only the scale or scroll sliders change.
                html_token_counts = gr.State([0 for _ in range(MAX_HTML_PREVIEWS)])
                html_render_count = gr.State(15)
            with gr.Column(scale=33, elem_classes="html-grid-pages"):
                html_previews = []
                # 5 rows x 3 cells; the total must equal MAX_HTML_PREVIEWS (15).
                for _ in range(5):
                    with gr.Row(elem_classes="html-grid-preview-row"):
                        for _ in range(3):
                            html_previews.append(gr.HTML(render_preview(active=False), elem_classes="html-grid-preview"))
            with gr.Column(scale=7, elem_classes="html-grid-output"):
                html_raw_output = gr.Textbox(label="Output", lines=10, max_lines=40)
        # The order must match the lists yielded by evaluate_html_grid.
        html_outputs = [*html_previews, html_raw_output, html_token_counts, html_render_count]
        # The order must match the parameters of evaluate_html_grid.
        html_inputs = [html_prompt, html_token_count, html_page_count, html_scale, html_scroll_seconds, html_temperature, html_top_p, html_presence_penalty, html_count_penalty, html_penalty_decay]
        html_event = html_submit.click(evaluate_html_grid, html_inputs, html_outputs, show_progress="hidden", stream_every=0.5)
        html_stop.click(fn=None, inputs=None, outputs=None, cancels=[html_event], queue=False)
        # The slider handlers bypass the queue (queue=False), so they are not
        # held behind a generation that is still streaming.
        html_scale.change(render_html_grid_from_raw, [html_prompt, html_raw_output, html_render_count, html_scale, html_scroll_seconds, html_token_counts], html_previews, queue=False, show_progress="hidden")
        html_scroll_seconds.change(render_html_grid_from_raw, [html_prompt, html_raw_output, html_render_count, html_scale, html_scroll_seconds, html_token_counts], html_previews, queue=False, show_progress="hidden")
        html_prompt_choice.change(html_prompt_from_choice, html_prompt_choice, html_prompt, queue=False, show_progress="hidden")
# One queued event runs at a time by default; at most 10 requests may wait.
demo.queue(default_concurrency_limit=1, max_size=10)
demo.launch(share=False, server_name="0.0.0.0", theme=gr.themes.Base(), css=HTML_GRID_CSS)