# RWKV-Gradio-1 / app.py
# Source: Hugging Face Space by BlinkDL, commit 9ddec9f ("Update app.py", verified).
import gc, html, re, threading, time
import gradio as gr
import torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from rwkv.utils import PIPELINE
import rwkv7_fast_v3a as v3a
# ---- generation limits ---------------------------------------------------------
ctx_limit = 7000  # prompt is truncated to its last `ctx_limit` tokens before prefill
gen_limit = 1000  # upper bound of the Raw Generation tab's "Max Tokens" slider
max_bsz = 64  # cap on the batch size accepted by generate_batch_text (and its slider)
CHUNK_LEN = 512 # chunk prefill, save VRAM
SAMPLER_TOP_K = 500  # candidates kept by torch.topk before top-p sampling
YIELD_EVERY = 16  # Raw tab: stream an update to the UI every N decode steps
USE_CUDA_GRAPH = True  # try to capture the batched decode step as a CUDA graph
# ---- HTML grid tab -------------------------------------------------------------
html_gen_limit = 8000  # upper bound (and default) of the HTML tab's "Max Tokens" slider
MAX_HTML_PREVIEWS = 15  # number of preview cells; also caps the HTML batch size
HTML_GRID_UPDATE_EVERY = 32  # HTML tab: generator yields every N decode steps
HTML_IFRAME_MIN_DELTA = 1800  # chars of new HTML needed before an incomplete iframe reloads
HTML_IFRAME_MAX_STALE = 2.0  # seconds after which an incomplete iframe reloads anyway
HTML_COMPONENT_MIN_INTERVAL = 0.5  # min seconds between updates of one preview / the raw box
# Preview cell layout in CSS pixels: a caption bar above a body; when a page was
# extracted, the body holds the iframe (HTML_FRAME_HEIGHT) over a raw-text pane
# that gets whatever height is left (HTML_RAW_HEIGHT).
HTML_PREVIEW_HEIGHT = 300
HTML_CAPTION_HEIGHT = 15
HTML_BODY_HEIGHT = HTML_PREVIEW_HEIGHT - HTML_CAPTION_HEIGHT
HTML_FRAME_HEIGHT = 228
HTML_RAW_HEIGHT = HTML_BODY_HEIGHT - HTML_FRAME_HEIGHT
# Canned subjects for the HTML tab's dropdown; each is wrapped by html_prompt_from_choice.
HTML_PROMPT_CHOICES = [
    "3D animation of cars in forest with animals",
    "interactive weather map with animated clouds and rain",
    "retro arcade RPG start screen",
    "3D animation of a SpaceX rocket landing on Mars",
    "storybook scene with a dragon flying over a castle",
    "interactive dashboard for a city traffic system",
    "animated aquarium with colorful fish and coral",
    "sci-fi spaceship navigation interface",
    "cozy cafe menu with animated steam and pastries",
    "character sheet for a high fantasy RPG",
    "a fancy hotel homepage",
]
def html_prompt_from_choice(choice):
    """Wrap a dropdown subject in the chat prompt used by the HTML tab.

    The prompt deliberately ends with an unterminated ``</think`` tag. The
    generated text is expected to supply the closing ``>``, and
    ``extract_html`` looks for ``</think>`` in ``prompt + text``.
    """
    head = "User: Write HTML: "
    tail = "\n\nAssistant: <think></think"
    return f"{head}{choice}{tail}"
# Initial contents of the HTML tab's prompt box, built from the first dropdown choice.
DEFAULT_HTML_PROMPT = html_prompt_from_choice(HTML_PROMPT_CHOICES[0])
# Stylesheet passed to demo.launch(css=...). It lays out the HTML tab as a CSS grid:
# a narrow left column (controls in row 1, raw output in row 2) and a wide right
# column of preview pages spanning both rows. On screens <= 768px wide it collapses
# to a single column (controls, pages, output). It also strips Gradio's default
# padding and margins around the preview cells and the prompt dropdown.
HTML_GRID_CSS = """
div.main { padding-left: 0 !important; padding-right: 0 !important; }
.html-grid-tab { padding-top: 0 !important; }
.html-grid-main { display: grid !important; grid-template-columns: minmax(220px, 17.5%) 1fr !important; grid-template-rows: auto auto !important; gap: 4px !important; margin-top: 0 !important; align-items: start !important; }
.html-grid-main > div { gap: 4px !important; }
.html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-pages { grid-column: 2 !important; grid-row: 1 / span 2 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-output { grid-column: 1 !important; grid-row: 2 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-preview-row { gap: 4px !important; margin: 0 !important; }
.html-grid-preview { margin: 0 !important; }
.html-grid-preview > div { margin: 0 !important; }
.html-container { padding: 0 !important; }
.html-prompt-choice { margin: 0 !important; padding: 0 !important; min-height: 0 !important; }
.html-prompt-choice .wrap,
.html-prompt-choice .wrap-inner,
.html-prompt-choice .secondary-wrap { margin: 0 !important; padding: 0 !important; min-height: 0 !important; }
.html-prompt-choice label { margin: 0 !important; padding: 0 !important; }
@media (max-width: 768px) {
.html-grid-main { grid-template-columns: 1fr !important; grid-template-rows: auto auto auto !important; }
.html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; }
.html-grid-pages { grid-column: 1 !important; grid-row: 2 !important; }
.html-grid-output { grid-column: 1 !important; grid-row: 3 !important; }
}
"""
# Preview view settings shared by the whole process, i.e. by every browser session.
# The slider handlers write them. A running evaluate_html_grid re-reads them on each
# pass, so slider changes apply to an in-flight stream. html_view_lock guards both.
html_view_lock = threading.Lock()
html_view_scale = 35  # preview zoom in percent (render_preview divides by 100)
html_view_scroll_seconds = 5  # seconds per auto-scroll leg in preview iframes; 0 disables
def clamp_html_scale(scale):
    """Coerce *scale* to an int percentage and clamp it to [20, 100]."""
    pct = int(scale)
    if pct < 20:
        return 20
    if pct > 100:
        return 100
    return pct
def clamp_scroll_seconds(seconds):
    """Coerce *seconds* to an int and clamp it to [0, 10] (0 disables auto-scroll)."""
    secs = int(seconds)
    if secs < 0:
        return 0
    if secs > 10:
        return 10
    return secs
def set_html_scale(scale):
    """Clamp *scale*, store it as the shared preview scale and return the stored value."""
    global html_view_scale
    clamped = clamp_html_scale(scale)
    with html_view_lock:
        html_view_scale = clamped
    return clamped
def get_html_scale():
    """Return the shared preview scale (percent), read under ``html_view_lock``."""
    with html_view_lock:
        current = html_view_scale
    return current
def set_scroll_seconds(seconds):
    """Clamp *seconds*, store it as the shared auto-scroll duration and return it."""
    global html_view_scroll_seconds
    clamped = clamp_scroll_seconds(seconds)
    with html_view_lock:
        html_view_scroll_seconds = clamped
    return clamped
def get_scroll_seconds():
    """Return the shared auto-scroll duration (seconds), read under ``html_view_lock``."""
    with html_view_lock:
        current = html_view_scroll_seconds
    return current
########################## text rwkv ################################################################
# The checkpoint name doubles as the page title. The .pth file is fetched from the
# BlinkDL/rwkv7-g1 Hugging Face repo; hf_hub_download returns the local file path.
title = "rwkv7-g1g-13.3b-20260523-ctx8192"
model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
# Configure the rwkv7_fast_v3a backend through its module-level settings, then load
# its extensions and build the model.
# NOTE(review): the meaning of each option is defined in rwkv7_fast_v3a. The order
# (settings -> load_extensions -> RWKV7()) is presumably required; confirm there.
v3a.MODEL_PATH = model_path
v3a.WKV_MODE = "fp32io16"
v3a.EMB_DEVICE = "cuda"
v3a.RKV_MODE = "off"
v3a.CMIX_SPARSE = "no-fc"
v3a.LOWRANK_WEIGHT = "transpose"
v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
v3a.load_extensions(v3a.WKV_MODE)
model = v3a.RWKV7()
# Free unreferenced Python objects and return cached CUDA blocks left over from loading.
gc.collect()
torch.cuda.empty_cache()
# In this file `pipeline` is only used for tokenization (encode/decode); generation
# calls `model` directly.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
@torch.jit.script
def sample_logits_batch_cuda(logits, temperature: float, top_p: float, k: int):
    """Draw one token id per row of ``logits`` with top-k then top-p (nucleus) sampling.

    Args:
        logits: 2-D tensor with rows along dim 0 and vocabulary scores along dim -1.
        temperature: divides the top-k scores before the softmax (1.0 leaves them as is).
        top_p: nucleus mass; ``<= 0`` means greedy argmax.
        k: number of highest-scoring candidates kept; ``1`` also means greedy argmax.

    Returns:
        1-D tensor of sampled token ids, one per row.
    """
    greedy = top_p <= 0.0 or k == 1
    if greedy:
        return torch.argmax(logits, dim=-1)
    top_vals, top_ids = torch.topk(logits.float(), k=k, dim=-1, sorted=True)
    if temperature != 1.0:
        top_vals = top_vals / temperature
    cdf = torch.cumsum(torch.softmax(top_vals, dim=-1), dim=-1)
    # Mass of the smallest sorted prefix whose cumulative probability reaches top_p.
    # Drawing u uniformly in [0, mass) and bisecting the CDF samples only inside it.
    if top_p < 1.0:
        cutoff = torch.argmax((cdf >= top_p).to(torch.int32), dim=-1)
        mass = cdf.gather(1, cutoff.view(-1, 1)).view(-1)
    else:
        mass = cdf[:, -1]
    u = torch.rand((logits.size(0), 1), device=logits.device) * mass.view(-1, 1)
    pick = torch.searchsorted(cdf, u).view(-1, 1)
    return top_ids.gather(1, pick).view(-1)
def get_decode_ctx(B: int, decode_cache):
    """Return the decode context for batch size ``B``, building and memoizing it in ``decode_cache``.

    The context is a tuple ``(state, x, graph, output)``:

    * ``state``: ``model.zero_state(B)``, the batched state passed to every decode
      step (presumably updated in place by ``forward_from_x``; confirm in v3a).
    * ``x``: a static ``(B, 1, v3a.C)`` half-precision CUDA input buffer. Callers
      copy the next step's embeddings into it.
    * ``graph``: a ``torch.cuda.CUDAGraph`` capturing
      ``model.forward_from_x(x, state, path)``. It is ``None`` when
      ``USE_CUDA_GRAPH`` is off or capture failed, and callers then run eagerly.
    * ``output``: the captured call's output tensor, refreshed by
      ``graph.replay()``. It is ``None`` when there is no graph.
    """
    entry = decode_cache.get(B)
    if entry is not None:
        return entry
    batch_state = model.zero_state(B)
    x_buf = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half)
    fwd_path = v3a.select_path(B, 1)
    # Two eager warm-up passes before capture (presumably so any lazy setup inside
    # the forward happens outside the graph).
    for _warmup in range(2):
        model.forward_from_x(x_buf, batch_state, fwd_path)
    torch.cuda.synchronize()
    graph, graph_out = None, None
    if USE_CUDA_GRAPH:
        try:
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                graph_out = model.forward_from_x(x_buf, batch_state, fwd_path)
            torch.cuda.synchronize()
        except Exception as exc:
            # Best effort: report and fall back to eager decoding for this batch size.
            print(f"CUDA graph disabled for B={B}: {exc}", flush=True)
            graph, graph_out = None, None
    entry = (batch_state, x_buf, graph, graph_out)
    decode_cache[B] = entry
    return entry
def copy_state_to_batch(dst, src):
    """Broadcast a batch-1 state ``src`` into every row of the batched state ``dst``, in place.

    Both are indexable containers of three tensors. The batch axis is dim 2 of
    element 0, dim 1 of element 1 and dim 0 of element 2. ``src`` has size 1
    along it; the batch size ``B`` is read from ``dst[2]``.
    """
    rows = dst[2].shape[0]
    plan = (
        (0, (-1, -1, rows, -1)),
        (1, (-1, rows, -1, -1, -1)),
        (2, (rows,)),
    )
    for slot, shape in plan:
        dst[slot].copy_(src[slot].expand(*shape))
def tokens_to_x(tokens):
    """Embed one token per batch row and return the result of ``model.embed``.

    ``tokens`` may be a tensor or a Python sequence of ids. It is turned into a
    ``long`` column vector on the device holding the embedding table: CPU when
    ``model.emb_cpu`` is set, otherwise CUDA.
    """
    device = "cuda"
    if model.emb_cpu:
        device = "cpu"
    if not isinstance(tokens, torch.Tensor):
        ids = torch.tensor(tokens, dtype=torch.long, device=device)
    else:
        ids = tokens.to(device=device, dtype=torch.long, non_blocking=True)
    return model.embed(ids.view(-1, 1))
def generate_prompt(instruction, input=""):
    """Build a model prompt from an instruction and an optional input.

    Without ``input`` the chat format ``User: ...`` / ``Assistant: <think></think``
    is used, leaving the closing ``>`` for the model. With ``input`` the
    ``Instruction: / Input: / Response:`` format is used.

    Both formats use a blank line (``\\n\\n``) to separate turns and sections.
    Each text is therefore stripped, ``\\r\\n`` is normalized to ``\\n``, and
    every run of newlines is collapsed to a single ``\\n``, as ``qa_prompt``
    does. The old ``.replace('\\n\\n', '\\n')`` only halved each run, so three
    or more newlines still leaked a blank line into the prompt.

    Note: ``input`` shadows the builtin; the name is kept for API compatibility.
    """
    instruction = re.sub(r'\n+', '\n', instruction.strip().replace('\r\n', '\n'))
    input = re.sub(r'\n+', '\n', input.strip().replace('\r\n', '\n'))
    if input:
        return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
    return f"User: {instruction}\n\nAssistant: <think></think"
def qa_prompt(instruction):
    """Build a single-turn chat prompt from *instruction*.

    The text is stripped, CRLF line endings are normalized, and blank lines are
    dropped so that no ``\\n\\n`` turn separator appears inside the user text.
    """
    normalized = instruction.strip().replace('\r\n', '\n')
    # After strip() there is no leading/trailing newline, so dropping empty lines
    # is the same as collapsing every run of newlines to one.
    body = "\n".join(line for line in normalized.split("\n") if line)
    return "User: " + body + "\n\nAssistant: <think></think"
def output_update(text, speed="", compact=False):
    """Return a ``gr.update`` for an output textbox, with the speed string in its label.

    With ``compact`` only the part of ``speed`` before ``" @ "`` (the rate) is kept.
    """
    if compact and speed:
        speed = speed.partition(" @ ")[0]
    label = "Output"
    if speed:
        label = f"Output {speed}"
    return gr.update(value=text, label=label)
def output_text(B, out_str):
    """Join per-row outputs into one display string.

    A single row is returned stripped. Several rows are each stripped and joined
    with a ``"\\n====\\n"`` separator (split back apart by ``split_batch_output``).
    """
    if B == 1:
        return out_str[0].strip()
    return "\n====\n".join([piece.strip() for piece in out_str])
def speed_text(rate, B, tokens, chars):
    """Format the throughput line shown in the output label."""
    return "{:.1f} token/s @ bsz {} = {} tokens, {} chars".format(rate, B, tokens, chars)
def gib(n):
    """Convert a byte count to decimal gigabytes.

    Note: despite the name, this divides by 10**9 (GB), not 2**30 (GiB).
    """
    return n / 1e9
def generate_batch_text(
    ctx,
    token_count=200,
    batch_size=1,
    temperature=1.0,
    top_p=0.5,
    presencePenalty = 1,
    countPenalty = 0.1,
    penalty_decay = 0.99,
    yield_every = YIELD_EVERY,
):
    """Stream ``B`` sampled continuations of ``ctx`` decoded in lock-step.

    The prompt is prefilled once at batch size 1, in ``CHUNK_LEN``-token chunks.
    Its state is broadcast to ``B`` rows with ``copy_state_to_batch``, and each
    later step decodes one token per row, through the cached CUDA graph when one
    was captured. A row finishes when it samples token 0. Generation stops early
    once every row has finished.

    Args:
        ctx: prompt text; stripped, encoded, truncated to the last ``ctx_limit`` tokens.
        token_count: maximum number of decode steps.
        batch_size: number of parallel samples, clamped to ``[1, max_bsz]``.
        temperature: ``<= 0`` selects greedy decoding; otherwise raised to at least 0.2.
        top_p: nucleus mass passed to ``sample_logits_batch_cuda``.
        presencePenalty: flat amount subtracted from the logit of every token
            already sampled in that row.
        countPenalty: amount subtracted per (decayed) occurrence of a token in that row.
        penalty_decay: per-step multiplier applied to the occurrence counts.
        yield_every: emit an intermediate update every this many steps (and at step 0).

    Yields:
        ``(text, speed_info, meta)``:

        * ``text`` is ``output_text(B, out_str)``.
        * ``speed_info`` is a ``speed_text`` string and may be ``""`` early on.
        * Intermediate ``meta`` is ``{"done": False, "token_counts": [...]}``.
        * The last ``meta`` has ``"done": True`` plus the summary fields read by
          ``print_summary``.
    """
    req_t0 = time.perf_counter()
    user_token_count = int(token_count)
    rwkv_model = model
    pipe = pipeline
    sample_temperature = float(temperature)
    sample_top_p = float(top_p)
    # temperature <= 0 means greedy: top_p = 0 makes the sampler take the argmax.
    if sample_temperature <= 0:
        sample_temperature = 1.0
        sample_top_p = 0
    else:
        sample_temperature = max(0.2, sample_temperature)
    alpha_frequency = float(countPenalty)
    alpha_presence = float(presencePenalty)
    ctx = ctx.strip()
    input_ids = pipe.encode(ctx)[-ctx_limit:]
    input_token_count = len(input_ids)
    B = min(max_bsz, max(1, int(batch_size)))
    batch_rows = None
    all_tokens = [[] for _ in range(B)]
    out_last = [0 for _ in range(B)]  # per row: index in all_tokens of the first not-yet-decoded token
    out_str = ['' for _ in range(B)]
    occurrence_count = None
    occurrence_presence = None
    finished = [False for _ in range(B)]
    speed_t0 = None
    speed_tokens = 0
    total_tokens = 0
    speed_info = ""
    # Per-request cache: the decode context (and any CUDA graph) is rebuilt on every
    # call and released by `del decode_cache` at the end.
    decode_cache = {}
    state = rwkv_model.zero_state(1)
    decode_state, decode_x, decode_graph, decode_output = get_decode_ctx(B, decode_cache)
    next_tokens = [0 for _ in range(B)]
    out = None
    for i in range(int(token_count)):
        if i == 0:
            if len(input_ids) == 0:
                yield "", "", {"done": True, "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "B": B, "T": user_token_count, "In": input_token_count, "TPS": 0.0, "Time": 0.0, "Token": 0, "Char": 0, "VRAMUsed": 0, "VRAMTotal": 0, "token_counts": [0 for _ in range(B)]}
                return
            # Prefill the prompt once at batch size 1, CHUNK_LEN tokens at a time.
            while len(input_ids) > 0:
                token_device = "cpu" if rwkv_model.emb_cpu else "cuda"
                tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
                out = rwkv_model.forward(tokens, state).view(-1)
                input_ids = input_ids[CHUNK_LEN:]
            torch.cuda.synchronize()
            # Broadcast the prompt state to all B rows. The final prefill output,
            # viewed as one logits row, is repeated B times for the first sample.
            copy_state_to_batch(decode_state, state)
            logits = out.view(1, -1).repeat(B, 1)
        else:
            decode_x.copy_(tokens_to_x(next_tokens))
            if decode_graph is None:
                decode_output = rwkv_model.forward_from_x(decode_x, decode_state, v3a.select_path(B, 1))
            else:
                decode_graph.replay()
            logits = decode_output.view(B, -1)
        # Allocate the per-row penalty buffers once the vocabulary size is known.
        if occurrence_count is None:
            occurrence_count = torch.zeros((B, logits.size(-1)), device=logits.device, dtype=logits.dtype)
            occurrence_presence = torch.zeros_like(occurrence_count)
            batch_rows = torch.arange(B, device=logits.device)
        # Penalties are applied in place. With a CUDA graph, `logits` aliases the
        # graph's static output, which the next replay overwrites anyway.
        if alpha_frequency:
            logits.sub_(occurrence_count, alpha=alpha_frequency)
        if alpha_presence:
            logits.sub_(occurrence_presence)
        assert logits.is_cuda and logits.dim() == 2
        sampled_tensor = sample_logits_batch_cuda(
            logits,
            sample_temperature,
            sample_top_p,
            min(SAMPLER_TOP_K, logits.size(-1)),
        )
        sampled = sampled_tensor.detach().cpu().tolist()
        active = 0
        # Finished rows keep being fed token 0 (their output is ignored).
        next_tokens = [0 for _ in range(B)]
        if penalty_decay != 1:
            occurrence_count.mul_(penalty_decay)
        occurrence_count[batch_rows, sampled_tensor] += 1
        if alpha_presence:
            occurrence_presence[batch_rows, sampled_tensor] = alpha_presence
        for b in range(B):
            if finished[b]:
                continue
            token = sampled[b]
            # Token 0 ends this row.
            if token == 0:
                finished[b] = True
                continue
            active += 1
            next_tokens[b] = token
            all_tokens[b].append(token)
            # Decode only the not-yet-emitted tail. Commit it once it has no
            # U+FFFD (presumably a character still split across tokens), so
            # multi-token characters are emitted whole.
            tmp = pipe.decode(all_tokens[b][out_last[b]:])
            if '\ufffd' not in tmp:
                out_str[b] += tmp
                out_last[b] = len(all_tokens[b])
        total_tokens += active
        if active == 0:
            break
        # The throughput clock starts after step 0 (which includes the prefill).
        # It counts B tokens per step, including rows that have already finished.
        if speed_t0 is None:
            speed_t0 = time.perf_counter()
        else:
            speed_tokens += B
            elapsed = max(1e-9, time.perf_counter() - speed_t0)
            current_text = output_text(B, out_str)
            speed_info = speed_text(speed_tokens / elapsed, B, total_tokens, len(current_text))
        if i == 0 or i % max(1, int(yield_every)) == 0:
            current_text = output_text(B, out_str)
            yield current_text, speed_info, {"done": False, "token_counts": [len(tokens) for tokens in all_tokens]}
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    free, total = torch.cuda.mem_get_info()
    # Drop references to the GPU buffers (state, CUDA graph) built for this request.
    del out
    del state
    del decode_cache
    current_text = output_text(B, out_str)
    if speed_t0 is not None and not speed_info:
        speed_info = speed_text(0.0, B, total_tokens, len(current_text))
    elapsed = time.perf_counter() - req_t0
    final_tps = speed_tokens / max(1e-9, time.perf_counter() - speed_t0) if speed_t0 is not None else 0.0
    used = total - free
    meta = {"done": True, "timestamp": timestamp, "B": B, "T": user_token_count, "In": input_token_count, "TPS": final_tps, "Time": elapsed, "Token": total_tokens, "Char": len(current_text), "VRAMUsed": used, "VRAMTotal": total, "token_counts": [len(tokens) for tokens in all_tokens]}
    yield current_text, speed_info, meta
def print_summary(prefix, meta):
    """Print a one-line run summary built from a generator's final ``meta`` dict.

    ``prefix`` tags the line (e.g. which tab produced it). VRAM figures go
    through ``gib`` (decimal GB).
    """
    line = (
        f"[{prefix}] {meta['timestamp']} B={meta['B']} T={meta['T']} In={meta['In']} "
        f"TPS={meta['TPS']:.1f} Time={meta['Time']:.3f}s Token={meta['Token']} "
        f"Char={meta['Char']} VRAM={gib(meta['VRAMUsed']):.2f}G/{gib(meta['VRAMTotal']):.2f}G"
    )
    print(line, flush=True)
def evaluate_raw(
    ctx,
    token_count=200,
    batch_size=1,
    temperature=1.0,
    top_p=0.5,
    presencePenalty=1,
    countPenalty=0.1,
    penalty_decay=0.99,
):
    """Gradio generator for the Raw Generation tab.

    Streams ``generate_batch_text`` (updating every ``YIELD_EVERY`` steps) as
    ``gr.update`` objects for the output textbox, and prints a summary line
    when the final chunk arrives.
    """
    stream = generate_batch_text(
        ctx, token_count, batch_size, temperature, top_p,
        presencePenalty, countPenalty, penalty_decay, YIELD_EVERY,
    )
    for text, speed_info, meta in stream:
        is_final = bool(meta) and bool(meta.get("done"))
        if is_final:
            print_summary("app3", meta)
        yield output_update(text, speed_info)
def split_batch_output(text, count):
    """Split an ``output_text`` string back into exactly ``count`` per-row strings.

    Missing rows are padded with ``""``; extra rows are dropped.
    """
    pieces = [] if not text else text.split("\n====\n")
    missing = count - len(pieces)
    if missing > 0:
        pieces.extend([""] * missing)
    return pieces[:count]
# Length-preserving ASCII-only lowercase table. str.lower() can change the length
# of a string (e.g. "İ".lower() is two code points), so offsets found in a
# .lower() copy cannot safely be used to slice the original.
_ASCII_LOWER = str.maketrans("ABCDEFGHIJKLMNOPQRSTUVWXYZ", "abcdefghijklmnopqrstuvwxyz")


def extract_html(text, prompt=""):
    """Return the HTML page the model wrote after its ``</think>`` tag, or ``""``.

    The search runs over ``prompt + text`` because the prompt ends with an
    unterminated ``</think`` whose ``>`` comes from the generated text. The
    prompt is stripped first, as ``generate_batch_text`` strips ``ctx`` before
    generating. Otherwise trailing whitespace in the prompt box would split the
    marker and no page would ever be found.

    After the first ``</think>``, a page must start with ``<!doctype html`` and
    contain ``<body``. All tags are matched case-insensitively. The page runs
    through the first ``</html>``, or to the end of the text if the page is
    still incomplete. The result is stripped.
    """
    stream = (prompt or "").strip() + text
    # Same length as `stream`, so every offset found here is valid in `stream`.
    folded = stream.translate(_ASCII_LOWER)
    marker = folded.find("</think>")
    if marker < 0:
        return ""
    start = folded.find("<!doctype html", marker + len("</think>"))
    if start < 0:
        return ""
    if folded.find("<body", start) < 0:
        return ""
    end = folded.find("</html>", start)
    if end >= 0:
        return stream[start:end + len("</html>")].strip()
    return stream[start:].strip()
def html_complete(page):
    """Return True once *page* contains a closing ``</html>`` tag (any case)."""
    return page.lower().find("</html>") != -1
def inject_iframe_scroll_script(page, index, scroll_seconds):
    """Insert an auto-scroll ``<script>`` into *page* for display in a preview iframe.

    The script scrolls the page's largest scrollable elements (up to 8) down,
    then up, and so on, each leg taking ``scroll_seconds`` seconds. The first
    leg starts after a per-cell delay derived from ``index``, so cells do not
    move in unison. If no element is scrollable it retries every second.

    The script goes right after ``<head ...>``. Failing that, it goes after
    ``<html ...>`` wrapped in a new ``<head>``, then after ``<!doctype html ...>``,
    then at the very start.

    Args:
        page: HTML document text.
        index: preview cell index, used only to stagger the start (0-4999 ms).
        scroll_seconds: leg duration, clamped to [0, 10]; 0 returns *page* unchanged.
    """
    scroll_seconds = clamp_scroll_seconds(scroll_seconds)
    if scroll_seconds <= 0:
        return page
    delay = int((index * 733) % 5000)
    leg_ms = scroll_seconds * 1000
    # `candidates()` is recomputed at the start of every leg. (An earlier version
    # also called it from a MutationObserver and a 1s setInterval and discarded the
    # result; that only burned CPU on animated pages, so it was removed.)
    script = f"""<script>
(() => {{
const delay = {delay};
const legMs = {leg_ms};
let down = true;
let running = false;
function unique(list) {{
const out = [];
const seen = new Set();
for (const el of list) {{
if (!el || seen.has(el)) continue;
seen.add(el);
out.push(el);
}}
return out;
}}
function candidates() {{
const base = [document.scrollingElement, document.documentElement, document.body, document.body && document.body.parentElement];
const all = Array.from(document.querySelectorAll("*"));
return unique(base.concat(all))
.filter(el => Math.max(0, el.scrollHeight - el.clientHeight) > 4)
.sort((a, b) => (b.scrollHeight - b.clientHeight) - (a.scrollHeight - a.clientHeight))
.slice(0, 8);
}}
function getTop(el) {{
if (el === document.scrollingElement || el === document.documentElement || el === document.body) {{
return window.scrollY || document.documentElement.scrollTop || document.body.scrollTop || 0;
}}
return el.scrollTop || 0;
}}
function setTop(el, y) {{
if (el === document.scrollingElement || el === document.documentElement || el === document.body) {{
window.scrollTo(0, y);
document.documentElement.scrollTop = y;
document.body.scrollTop = y;
}} else {{
el.scrollTop = y;
}}
}}
function animate() {{
const targets = candidates();
if (!targets.length) {{
setTimeout(animate, 1000);
return;
}}
const starts = targets.map(getTop);
const ends = targets.map(el => down ? Math.max(0, el.scrollHeight - el.clientHeight) : 0);
down = !down;
const t0 = performance.now();
function step(t) {{
const p = Math.min(1, (t - t0) / legMs);
for (let i = 0; i < targets.length; i++) {{
setTop(targets[i], starts[i] + (ends[i] - starts[i]) * p);
}}
requestAnimationFrame(p < 1 ? step : animate);
}}
requestAnimationFrame(step);
}}
function start() {{
if (running) return;
running = true;
setTimeout(animate, delay);
}}
if (document.readyState === "complete") start();
else window.addEventListener("load", start, {{once: true}});
setTimeout(start, delay + 1500);
}})();
</script>"""
    # Search `page` itself, case-insensitively. Matching on page.lower() and then
    # slicing `page` with those offsets is wrong whenever lower() changes the
    # string length (e.g. "İ" lowercases to two code points): the script would be
    # spliced into the middle of a tag.
    anchors = (
        (r"<head\b[^>]*>", False),
        (r"<html\b[^>]*>", True),
        (r"<!doctype\s+html[^>]*>", False),
    )
    for pattern, needs_head in anchors:
        m = re.search(pattern, page, re.IGNORECASE)
        if m:
            snippet = "<head>" + script + "</head>" if needs_head else script
            return page[:m.end()] + snippet + page[m.end():]
    return script + page
def render_preview(text="", index=0, scale=35, active=True, prompt="", token_count=None, scroll_seconds=5):
    """Return the HTML markup for one cell of the HTML preview grid.

    The cell is ``HTML_PREVIEW_HEIGHT`` px tall. A dark caption bar
    (``#n | tokens, bytes``, blank for inactive cells) sits above a body that is:

    * an empty placeholder when ``text`` is empty;
    * otherwise, if ``extract_html(text, prompt)`` finds a page, a sandboxed
      iframe showing it (auto-scroll script injected) above a small pane with
      the escaped raw text;
    * otherwise just the escaped raw text.

    ``scale`` is a percentage, converted to a zoom factor clamped to [0.2, 1.2].
    Inactive cells are drawn at reduced opacity.
    """
    token_label = "-" if token_count is None else f"{token_count:,}"
    if active:
        caption = f"#{index + 1} | {token_label} tokens, {len(text.encode('utf-8')):,} bytes"
        opacity = "1"
    else:
        caption = ""
        opacity = ".35"
    zoom = max(0.2, min(1.2, scale / 100.0))
    page = extract_html(text, prompt)
    if not text:
        body = f'<div style="height:{HTML_BODY_HEIGHT}px;background:#fafafa;"></div>'
    elif not page:
        # No page yet: show only the raw stream, pinned to its newest text.
        body = f'<div id="html-raw-{index}" style="height:{HTML_BODY_HEIGHT}px;overflow:auto;display:flex;flex-direction:column-reverse;background:#fafafa;"><pre style="zoom:{zoom:.3f};margin:0;padding:0;white-space:pre-wrap;word-break:break-word;color:#111;font:16px/1.2 ui-monospace,Consolas,monospace;">{html.escape(text)}</pre></div>'
    else:
        # The iframe gets scripts but no same-origin access; srcdoc must be attribute-escaped.
        srcdoc = html.escape(inject_iframe_scroll_script(page, index, scroll_seconds), quote=True)
        raw = html.escape(text)
        body = f"""<div style="height:{HTML_BODY_HEIGHT}px;display:flex;flex-direction:column;background:white;">
<div style="height:{HTML_FRAME_HEIGHT}px;overflow:hidden;background:white;"><iframe sandbox="allow-scripts allow-forms allow-popups" srcdoc="{srcdoc}" style="border:0;width:{100 / zoom:.3f}%;height:{HTML_FRAME_HEIGHT / zoom:.1f}px;background:white;transform:scale({zoom:.3f});transform-origin:top left;"></iframe></div>
<div id="html-raw-{index}" style="height:{HTML_RAW_HEIGHT}px;overflow:auto;display:flex;flex-direction:column-reverse;background:#fafafa;border-top:1px solid #111;"><pre style="zoom:{zoom:.3f};margin:0;padding:0;white-space:pre-wrap;word-break:break-word;color:#111;font:16px/1.2 ui-monospace,Consolas,monospace;">{raw}</pre></div>
</div>"""
    return f"""<div class="html-container" style="outline:1px solid #111;background:#fff;opacity:{opacity};height:{HTML_PREVIEW_HEIGHT}px;display:flex;flex-direction:column;padding:0;">
<div style="box-sizing:border-box;height:{HTML_CAPTION_HEIGHT}px;padding:1px 6px;background:#111;color:#fff;font:11px/13px ui-monospace,monospace;">{caption}</div>
{body}
</div>"""
def empty_html_grid():
    """Return ``MAX_HTML_PREVIEWS`` blank, dimmed (inactive) preview cells."""
    grid = []
    for slot in range(MAX_HTML_PREVIEWS):
        grid.append(render_preview("", slot, active=False))
    return grid
def render_html_grid_from_raw(prompt, raw_output, page_count, scale, scroll_seconds, token_counts):
    """Re-render every preview cell from the text already in the raw output box.

    Wired to the scale and scroll-seconds sliders. It also stores the clamped
    slider values in the shared view settings, which a running
    ``evaluate_html_grid`` picks up. Returns ``MAX_HTML_PREVIEWS`` HTML strings;
    cells beyond ``page_count`` are rendered inactive.
    """
    shown = max(1, min(MAX_HTML_PREVIEWS, int(page_count)))
    scale = set_html_scale(scale)
    scroll_seconds = set_scroll_seconds(scroll_seconds)
    pages = split_batch_output(raw_output, shown)
    counts = token_counts or []
    previews = []
    for idx in range(MAX_HTML_PREVIEWS):
        is_active = idx < shown
        page_text = pages[idx] if is_active else ""
        page_tokens = counts[idx] if idx < len(counts) else None
        previews.append(render_preview(page_text, idx, scale, is_active, prompt, page_tokens, scroll_seconds))
    return previews
def evaluate_html_grid(
    prompt,
    token_count=8000,
    page_count=15,
    scale=35,
    scroll_seconds=5,
    temperature=1.0,
    top_p=0.5,
    presence_penalty=1.0,
    count_penalty=0.1,
    penalty_decay=0.99,
):
    """Gradio generator for the HTML tab: batch-generate pages and stream previews.

    Generates ``page_count`` (clamped to [1, MAX_HTML_PREVIEWS]) pages in one batch.
    Each yield is a list matching ``html_outputs``:
    ``[*MAX_HTML_PREVIEWS preview updates, raw-output update, token_counts, page_count]``.
    ``gr.skip()`` leaves a component unchanged.

    Updates are throttled:

    * An intermediate chunk is ignored unless the text grew by >= 300 chars.
    * Each cell and the raw box update at most once per
      ``HTML_COMPONENT_MIN_INTERVAL`` seconds. This is bypassed when the batch
      is done or the shared scale/scroll setting changed.
    * A cell without an extracted page re-renders only when its text length or
      active state changed.
    * A cell with a page reloads its iframe on:
      - the first page;
      - the page becoming complete;
      - the batch finishing;
      - while incomplete, HTML growth of ``HTML_IFRAME_MIN_DELTA`` chars or
        ``HTML_IFRAME_MAX_STALE`` seconds since the last reload.

      After completion a cell is not refreshed until the batch is done.

    Scale and scroll speed are re-read from the shared view settings on each pass,
    so slider changes apply while generation is still running. A final full
    re-render is yielded after the stream ends.
    """
    page_count = max(1, min(MAX_HTML_PREVIEWS, int(page_count)))
    scale = set_html_scale(scale)
    scroll_seconds = set_scroll_seconds(scroll_seconds)
    last_text = ""
    # Per-cell cache: last rendered markup, last HTML loaded into the iframe and
    # when, whether that HTML was complete, the text length it came from, and
    # when the cell was last sent to the browser.
    cached_previews = empty_html_grid()
    cached_html = ["" for _ in range(MAX_HTML_PREVIEWS)]
    cached_html_at = [0.0 for _ in range(MAX_HTML_PREVIEWS)]
    cached_complete = [False for _ in range(MAX_HTML_PREVIEWS)]
    cached_text_len = [0 for _ in range(MAX_HTML_PREVIEWS)]
    cached_preview_at = [0.0 for _ in range(MAX_HTML_PREVIEWS)]
    last_raw_at = 0.0
    final_text = ""
    final_token_counts = [0 for _ in range(MAX_HTML_PREVIEWS)]
    # Clear the grid and raw box immediately.
    yield [*cached_previews, output_update("", compact=True), final_token_counts, page_count]
    for text, speed_info, meta in generate_batch_text(prompt, token_count, page_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay, HTML_GRID_UPDATE_EVERY):
        final_text = text
        done_batch = bool(meta and meta.get("done"))
        token_counts = meta.get("token_counts", final_token_counts) if meta else final_token_counts
        if token_counts:
            # Pad to one entry per preview cell.
            final_token_counts = token_counts + [0 for _ in range(max(0, MAX_HTML_PREVIEWS - len(token_counts)))]
        if done_batch:
            print_summary("app3-html", meta)
        if len(text) - len(last_text) < 300 and not done_batch:
            continue
        last_text = text
        now = time.monotonic()
        # Pick up slider changes made while this generator is running.
        render_scale = get_html_scale()
        render_scroll_seconds = get_scroll_seconds()
        scale_changed = render_scale != scale
        scroll_changed = render_scroll_seconds != scroll_seconds
        scale = render_scale
        scroll_seconds = render_scroll_seconds
        pages = split_batch_output(text, page_count)
        skip = gr.skip()
        updates = [skip for _ in range(MAX_HTML_PREVIEWS)]
        for i in range(MAX_HTML_PREVIEWS):
            page_text = pages[i] if i < page_count else ""
            active = i < page_count
            page = extract_html(page_text, prompt) if active else ""
            page_tokens = final_token_counts[i] if i < len(final_token_counts) else None
            if now - cached_preview_at[i] < HTML_COMPONENT_MIN_INTERVAL and not done_batch and not scale_changed and not scroll_changed:
                continue
            if not page:
                # Raw-text-only cell: re-render only if something visible changed.
                if len(page_text) == cached_text_len[i] and active == bool(cached_text_len[i]) and not done_batch and not scale_changed and not scroll_changed:
                    continue
                cached_previews[i] = render_preview(page_text, i, scale, active, prompt, page_tokens, scroll_seconds)
                cached_html[i] = ""
                cached_html_at[i] = now
                cached_complete[i] = False
                cached_text_len[i] = len(page_text)
                cached_preview_at[i] = now
                updates[i] = cached_previews[i]
                continue
            done = html_complete(page)
            # Reloading an iframe restarts the page, so it is done sparingly.
            should_reload = (
                not cached_html[i]
                or (done and not cached_complete[i])
                or done_batch
                or (
                    not cached_complete[i]
                    and (
                        len(page) - len(cached_html[i]) >= HTML_IFRAME_MIN_DELTA
                        or now - cached_html_at[i] >= HTML_IFRAME_MAX_STALE
                    )
                )
            )
            if should_reload:
                cached_previews[i] = render_preview(page_text, i, scale, active, prompt, page_tokens, scroll_seconds)
                cached_html[i] = page
                cached_html_at[i] = now
                cached_complete[i] = done
                cached_text_len[i] = len(page_text)
                cached_preview_at[i] = now
                updates[i] = cached_previews[i]
        raw_update = skip
        if now - last_raw_at >= HTML_COMPONENT_MIN_INTERVAL or done_batch:
            raw_update = output_update(text, speed_info, compact=True)
            last_raw_at = now
        if any(update is not skip for update in updates) or raw_update is not skip:
            yield [*updates, raw_update, final_token_counts, page_count]
    # Final consistent render of every cell with the latest view settings.
    if final_text:
        pages = split_batch_output(final_text, page_count)
        scale = get_html_scale()
        scroll_seconds = get_scroll_seconds()
        final_previews = [render_preview(pages[i] if i < page_count else "", i, scale, i < page_count, prompt, final_token_counts[i] if i < len(final_token_counts) else None, scroll_seconds) for i in range(MAX_HTML_PREVIEWS)]
        yield [*final_previews, output_update(final_text, speed_info, compact=True), final_token_counts, page_count]
# Example rows for the Raw Generation tab's Dataset. As written, each row is
# [prompt, max tokens, temperature, top_p, presence penalty, count penalty, penalty decay].
# The two tool-calling rows use top_p 0, which sample_logits_batch_cuda treats as greedy.
examples = [
    ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
    ["System: Tools:\n[{\"name\":\"find_free_slots\",\"description\":\"Find free calendar slots\",\"arguments\":{\"date\":{\"type\":\"string\"},\"duration_minutes\":{\"type\":\"integer\"},\"time_window\":{\"type\":\"string\"}}},{\"name\":\"create_calendar_event\",\"description\":\"Create a calendar event\",\"arguments\":{\"title\":{\"type\":\"string\"},\"start_time\":{\"type\":\"string\"},\"end_time\":{\"type\":\"string\"},\"attendees\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}}}]\nReturn only a JSON function call.\n\nUser: Schedule a 30-minute sync with Bob on 2026-05-08 afternoon.\n\nAssistant: ```json\n{\"name\":\"find_free_slots\",\"arguments\":{\"date\":\"2026-05-08\",\"duration_minutes\":30,\"time_window\":\"afternoon\"}}\n```\n\nUser: Function output:\n{\"free_slots\":[{\"start\":\"2026-05-08T15:00:00+09:00\",\"end\":\"2026-05-08T15:30:00+09:00\"}],\"bob_email\":\"bob@example.com\"}\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
    [generate_prompt("Please give the pros and cons of hodl versus active trading."), 1000, 1, 0.5, 1, 0.1, 0.99],
    [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 1000, 1, 0.5, 1, 0.1, 0.99],
    ["User: What is the maximum value of $4(x + 7)(2 - x)$, over all real numbers $x$?\n\nAssistant: <think", 1000, 1, 0.5, 1, 0.1, 0.99],
    ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", 1000, 1, 0.5, 2, 0.2, 0.99],
    ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response:", 1000, 1, 0.5, 1, 0.1, 0.99],
    [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 1000, 1, 0.5, 1, 0.1, 0.99],
    [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 1000, 1, 0.5, 2, 0.2, 0.99],
    ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境内は、特別な雰囲気に包まれていた。\n\nEnglish:''', 1000, 1, 0.5, 1, 0.1, 0.99],
    ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", 1000, 1, 0.5, 1, 0.1, 0.99],
    ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", 1000, 1, 0.5, 1, 0.1, 0.99],
    ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", 1000, 1, 0.5, 1, 0.1, 0.99],
    ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', 1000, 1, 0.5, 2, 0.2, 0.99],
    ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', 1000, 1, 0.5, 1, 0.1, 0.99],
]
# Insert batch size 1 as the third field so each row matches the Dataset's 8 components
# (prompt, max tokens, batch size, temperature, top_p, presence, count, decay).
examples = [[x[0], x[1], 1, *x[2:]] for x in examples]
##################################################################################################################
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title}</h1>\n</div>")
    # ---- Tab 1: plain batched text generation ------------------------------------
    with gr.Tab("Raw Generation"):
        gr.Markdown(f'This is [RWKV7 G-series](https://huggingface.co/BlinkDL/rwkv7-g1) reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Try topp 0.3 for math. Supports 100+ world languages and code. Check [600+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.')
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: <think></think")
                token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=500)
                batch_size = gr.Slider(1, max_bsz, label="Batch Size", step=1, value=16)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
                presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.1)
                penalty_decay = gr.Slider(0.99, 0.999, label="Penalty Decay", step=0.001, value=0.99)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit", variant="primary")
                    stop = gr.Button("Stop", variant="secondary")
                output = gr.Textbox(label="Output", lines=20, max_lines=100)
        # Clicking an example row copies its 8 values into the 8 input components.
        data = gr.Dataset(components=[prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Batch Size", "Temperature", "Top P", "Presence Penalty", "Count Penalty", "Penalty Decay"])
        submit_event = submit.click(evaluate_raw, [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], [output])
        # queue=False: the cancel is handled right away instead of waiting in the
        # concurrency-1 queue behind the generation it cancels.
        stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event], queue=False)
        data.click(lambda x: x, [data], [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
    # ---- Tab 2: batched HTML generation with a live preview grid ------------------
    with gr.Tab("✨HTML Generation", elem_classes="html-grid-tab"):
        with gr.Row(elem_classes="html-grid-main"):
            with gr.Column(scale=7, elem_classes="html-grid-controls"):
                html_prompt_choice = gr.Dropdown(choices=HTML_PROMPT_CHOICES, value=HTML_PROMPT_CHOICES[0], label=None, show_label=False, elem_classes="html-prompt-choice")
                html_prompt = gr.Textbox(lines=6, label="Prompt", value=DEFAULT_HTML_PROMPT)
                with gr.Row():
                    html_submit = gr.Button("Generate HTML Grid", variant="primary")
                    html_stop = gr.Button("Stop", variant="secondary")
                html_token_count = gr.Slider(50, html_gen_limit, label="Max Tokens", step=50, value=html_gen_limit)
                html_page_count = gr.Slider(1, MAX_HTML_PREVIEWS, label="Batch Size", step=1, value=MAX_HTML_PREVIEWS)
                html_scale = gr.Slider(20, 100, label="Preview Scale %", step=5, value=35)
                html_scroll_seconds = gr.Slider(0, 10, label="Preview Scroll Seconds", step=1, value=5)
                html_temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                html_top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
                html_presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=1.0)
                html_count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.1)
                html_penalty_decay = gr.Slider(0.99, 0.999, label="Penalty Decay", step=0.001, value=0.99)
                # Per-page token counts and the page count of the last run. The slider
                # handlers use them to re-render the grid from the raw output.
                html_token_counts = gr.State([0 for _ in range(MAX_HTML_PREVIEWS)])
                html_render_count = gr.State(MAX_HTML_PREVIEWS)
            with gr.Column(scale=33, elem_classes="html-grid-pages"):
                # The HTML-tab handlers return exactly MAX_HTML_PREVIEWS preview
                # updates, so derive the grid from that constant (3 cells per row)
                # instead of hard-coding its shape.
                html_previews = []
                preview_cols = 3
                for row_start in range(0, MAX_HTML_PREVIEWS, preview_cols):
                    with gr.Row(elem_classes="html-grid-preview-row"):
                        for _ in range(min(preview_cols, MAX_HTML_PREVIEWS - row_start)):
                            html_previews.append(gr.HTML(render_preview(active=False), elem_classes="html-grid-preview"))
            with gr.Column(scale=7, elem_classes="html-grid-output"):
                html_raw_output = gr.Textbox(label="Output", lines=10, max_lines=40)
        html_outputs = [*html_previews, html_raw_output, html_token_counts, html_render_count]
        html_inputs = [html_prompt, html_token_count, html_page_count, html_scale, html_scroll_seconds, html_temperature, html_top_p, html_presence_penalty, html_count_penalty, html_penalty_decay]
        html_event = html_submit.click(evaluate_html_grid, html_inputs, html_outputs, show_progress="hidden", stream_every=0.5)
        html_stop.click(fn=None, inputs=None, outputs=None, cancels=[html_event], queue=False)
        # The view sliders re-render from the raw output box without queueing, so they
        # also work while a generation is streaming.
        html_scale.change(render_html_grid_from_raw, [html_prompt, html_raw_output, html_render_count, html_scale, html_scroll_seconds, html_token_counts], html_previews, queue=False, show_progress="hidden")
        html_scroll_seconds.change(render_html_grid_from_raw, [html_prompt, html_raw_output, html_render_count, html_scale, html_scroll_seconds, html_token_counts], html_previews, queue=False, show_progress="hidden")
        html_prompt_choice.change(html_prompt_from_choice, html_prompt_choice, html_prompt, queue=False, show_progress="hidden")
# One queued event runs at a time; up to 10 more may wait.
demo.queue(default_concurrency_limit=1, max_size=10)
demo.launch(share=False, server_name="0.0.0.0", theme=gr.themes.Base(), css=HTML_GRID_CSS)