import gc, html, re, threading, time
import gradio as gr
import torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from rwkv.utils import PIPELINE
import rwkv7_fast_v3a as v3a
# ---- Text generation settings ----
# NOTE(review): ctx_limit, gen_limit and max_bsz are not read in the visible
# part of this file; presumably they cap prompt tokens, generated tokens and
# batch size respectively — confirm against the generation code.
ctx_limit = 7000
gen_limit = 1000
max_bsz = 64
CHUNK_LEN = 512 # prompt tokens are fed to the model in slices of this size (chunked prefill, saves VRAM)
SAMPLER_TOP_K = 500 # upper bound on the k handed to sample_logits_batch_cuda
YIELD_EVERY = 16 # forwarded by evaluate_raw to generate_batch_text as its last argument
USE_CUDA_GRAPH = True # when True, get_decode_ctx tries to capture the decode step in a CUDA graph
# ---- HTML preview settings ----
# NOTE(review): none of the HTML_* values below are read in the visible part
# of this file; units (tokens / characters / seconds / pixels) are inferred
# from the names only — confirm against the preview code.
html_gen_limit = 8000
MAX_HTML_PREVIEWS = 15
HTML_GRID_UPDATE_EVERY = 32
HTML_IFRAME_MIN_DELTA = 1800
HTML_IFRAME_MAX_STALE = 2.0
HTML_COMPONENT_MIN_INTERVAL = 0.5
HTML_PREVIEW_HEIGHT = 300
HTML_CAPTION_HEIGHT = 15
# Derived: 300 - 15 = 285.
HTML_BODY_HEIGHT = HTML_PREVIEW_HEIGHT - HTML_CAPTION_HEIGHT
HTML_FRAME_HEIGHT = 228
# Derived: 285 - 228 = 57.
HTML_RAW_HEIGHT = HTML_BODY_HEIGHT - HTML_FRAME_HEIGHT
# Preset page descriptions. html_prompt_from_choice interpolates one of these
# into a "User: Write HTML: ..." prompt.
HTML_PROMPT_CHOICES = [
    "3D animation of cars in forest with animals",
    "interactive weather map with animated clouds and rain",
    "retro arcade RPG start screen",
    "3D animation of a SpaceX rocket landing on Mars",
    "storybook scene with a dragon flying over a castle",
    "interactive dashboard for a city traffic system",
    "animated aquarium with colorful fish and coral",
    "sci-fi spaceship navigation interface",
    "cozy cafe menu with animated steam and pastries",
    "character sheet for a high fantasy RPG",
    "a fancy hotel homepage",
]
def html_prompt_from_choice(choice):
return f"User: Write HTML: {choice}\n\nAssistant: div { gap: 4px !important; }
.html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-pages { grid-column: 2 !important; grid-row: 1 / span 2 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-output { grid-column: 1 !important; grid-row: 2 !important; gap: 4px !important; min-width: 0 !important; }
.html-grid-preview-row { gap: 4px !important; margin: 0 !important; }
.html-grid-preview { margin: 0 !important; }
.html-grid-preview > div { margin: 0 !important; }
.html-container { padding: 0 !important; }
.html-prompt-choice { margin: 0 !important; padding: 0 !important; min-height: 0 !important; }
.html-prompt-choice .wrap,
.html-prompt-choice .wrap-inner,
.html-prompt-choice .secondary-wrap { margin: 0 !important; padding: 0 !important; min-height: 0 !important; }
.html-prompt-choice label { margin: 0 !important; padding: 0 !important; }
@media (max-width: 768px) {
.html-grid-main { grid-template-columns: 1fr !important; grid-template-rows: auto auto auto !important; }
.html-grid-controls { grid-column: 1 !important; grid-row: 1 !important; }
.html-grid-pages { grid-column: 1 !important; grid-row: 2 !important; }
.html-grid-output { grid-column: 1 !important; grid-row: 3 !important; }
}
"""
# Shared HTML view settings. html_view_lock is held by the set_*/get_*
# helpers below whenever html_view_scale or html_view_scroll_seconds is
# read or written.
html_view_lock = threading.Lock()
html_view_scale = 35  # initial value; set_html_scale keeps it within 20..100
html_view_scroll_seconds = 5  # initial value; set_scroll_seconds keeps it within 0..10
def clamp_html_scale(scale):
    """Convert ``scale`` with ``int()`` and limit the result to 20..100."""
    value = int(scale)
    if value < 20:
        return 20
    if value > 100:
        return 100
    return value
def clamp_scroll_seconds(seconds):
    """Convert ``seconds`` with ``int()`` and limit the result to 0..10."""
    value = int(seconds)
    if value < 0:
        return 0
    if value > 10:
        return 10
    return value
def set_html_scale(scale):
    """Store the clamped HTML scale in the shared setting and return it.

    The value is normalised by ``clamp_html_scale`` before the module-level
    ``html_view_scale`` is written under ``html_view_lock``.
    """
    global html_view_scale
    clamped = clamp_html_scale(scale)
    with html_view_lock:
        html_view_scale = clamped
    return clamped
def get_html_scale():
    """Return the current shared HTML scale, read under ``html_view_lock``."""
    with html_view_lock:
        current = html_view_scale
    return current
def set_scroll_seconds(seconds):
    """Store the clamped scroll duration in the shared setting and return it.

    The value is normalised by ``clamp_scroll_seconds`` before the
    module-level ``html_view_scroll_seconds`` is written under
    ``html_view_lock``.
    """
    global html_view_scroll_seconds
    clamped = clamp_scroll_seconds(seconds)
    with html_view_lock:
        html_view_scroll_seconds = clamped
    return clamped
def get_scroll_seconds():
    """Return the current shared scroll duration, read under ``html_view_lock``."""
    with html_view_lock:
        current = html_view_scroll_seconds
    return current
########################## text rwkv ################################################################
# Checkpoint name; also used to build the file name "<title>.pth" below.
title = "rwkv7-g1g-13.3b-20260523-ctx8192"
# Fetch the checkpoint from the Hugging Face Hub repo and keep the returned path.
model_path = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
# Module-level settings on rwkv7_fast_v3a, assigned before the extensions are
# loaded and the model is constructed.
# NOTE(review): the meaning of each value is defined inside rwkv7_fast_v3a and
# is not visible here — confirm there before changing any of them.
v3a.MODEL_PATH = model_path
v3a.WKV_MODE = "fp32io16"
v3a.EMB_DEVICE = "cuda"
v3a.RKV_MODE = "off"
v3a.CMIX_SPARSE = "no-fc"
v3a.LOWRANK_WEIGHT = "transpose"
v3a.ORIG_LINEAR_GROUPS = {"att_c2c", "ffn_key", "head"}
v3a.load_extensions(v3a.WKV_MODE)
model = v3a.RWKV7()
# Run a Python GC pass and release PyTorch's cached CUDA blocks after loading.
# NOTE(review): presumably to drop temporary load-time allocations — confirm.
gc.collect()
torch.cuda.empty_cache()
# Tokenizer/decoding helper built on the loaded model and the named vocabulary.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
@torch.jit.script
def sample_logits_batch_cuda(logits, temperature: float, top_p: float, k: int):
    """Sample one token id per row using top-k followed by top-p filtering.

    Args:
        logits: 2-D tensor ``(batch, vocab)`` of unnormalised scores.
        temperature: Softmax temperature. Values ``<= 0`` select greedy
            decoding instead of dividing by zero.
        top_p: Cumulative-probability cut-off applied to the top-k
            candidates. Values ``<= 0`` select greedy decoding; values
            ``>= 1`` keep all ``k`` candidates.
        k: Number of highest-scoring candidates considered per row. It is
            clamped to the vocabulary size; ``k <= 1`` selects greedy
            decoding.

    Returns:
        1-D tensor of length ``batch`` holding the chosen token ids.
    """
    # topk raises when k exceeds the last dimension, so clamp defensively.
    k = min(k, logits.size(-1))
    if top_p <= 0.0 or temperature <= 0.0 or k <= 1:
        return torch.argmax(logits, dim=-1)
    vals, ids = torch.topk(logits.float(), k=k, dim=-1, sorted=True)
    if temperature != 1.0:
        vals = vals / temperature
    probs = torch.softmax(vals, dim=-1)
    cdf = torch.cumsum(probs, dim=-1)
    if top_p < 1.0:
        reached = cdf >= top_p
        # Index of the first candidate whose cumulative mass reaches top_p.
        keep = torch.argmax(reached.to(torch.int32), dim=-1)
        # argmax over an all-false row returns 0, which would silently
        # collapse the nucleus to the top-1 token when rounding keeps the
        # cumulative sum just below top_p; keep every candidate instead.
        keep = torch.where(reached.any(dim=-1), keep, torch.full_like(keep, k - 1))
        mass = cdf.gather(1, keep.view(-1, 1)).view(-1)
    else:
        mass = cdf[:, -1]
    # Draw uniformly inside the retained mass and invert the CDF.
    r = torch.rand((logits.size(0), 1), device=logits.device) * mass.view(-1, 1)
    out = torch.searchsorted(cdf, r).view(-1, 1)
    return ids.gather(1, out).view(-1)
def get_decode_ctx(B: int, decode_cache):
    """Return the decode buffers for batch size ``B``, building them on first use.

    Looks ``B`` up in ``decode_cache`` and returns the stored entry when one
    exists. Otherwise it allocates a zero state and an input buffer, runs two
    eager forward passes, optionally captures one forward pass into a CUDA
    graph, stores the result in ``decode_cache[B]`` and returns it.

    Args:
        B: Batch size the buffers are built for.
        decode_cache: Mapping from batch size to a previously built entry;
            updated in place.

    Returns:
        A tuple ``(state, x, graph, output)``. ``graph`` and ``output`` are
        both ``None`` when ``USE_CUDA_GRAPH`` is false or graph capture raised.
    """
    cached = decode_cache.get(B)
    if cached is not None:
        return cached
    state = model.zero_state(B)
    # Input buffer of shape (B, 1, v3a.C); its contents are uninitialised here.
    x = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half)
    path = v3a.select_path(B, 1)
    # Two eager forward passes before any capture.
    # NOTE(review): presumably a warm-up for kernels/allocations — confirm.
    for _ in range(2):
        model.forward_from_x(x, state, path)
    torch.cuda.synchronize()
    graph = None
    output = None
    if USE_CUDA_GRAPH:
        try:
            graph = torch.cuda.CUDAGraph()
            # Record one forward pass over the x/state buffers; `output` is
            # the tensor produced during capture.
            with torch.cuda.graph(graph):
                output = model.forward_from_x(x, state, path)
            torch.cuda.synchronize()
        except Exception as exc:
            # Best-effort: report the failure and fall back to the non-graph path.
            print(f"CUDA graph disabled for B={B}: {exc}", flush=True)
            graph = None
            output = None
    cached = (state, x, graph, output)
    decode_cache[B] = cached
    return cached
def copy_state_to_batch(dst, src):
    """Broadcast the three tensors of ``src`` into the tensors of ``dst`` in place.

    The batch size is taken from ``dst[2].shape[0]``. Each source tensor is
    expanded along its batch dimension (dim 2 for item 0, dim 1 for item 1,
    dim 0 for item 2) and copied into the matching destination tensor.
    """
    batch = dst[2].shape[0]
    expand_shapes = (
        (-1, -1, batch, -1),
        (-1, batch, -1, -1, -1),
        (batch,),
    )
    for index, shape in enumerate(expand_shapes):
        dst[index].copy_(src[index].expand(*shape))
def tokens_to_x(tokens):
    """Embed token ids as a column of shape ``(-1, 1)`` via ``model.embed``.

    ``tokens`` may be a tensor or anything ``torch.tensor`` accepts. The ids
    are placed on the CPU when ``model.emb_cpu`` is truthy, otherwise on CUDA,
    and converted to ``torch.long`` before embedding.
    """
    target = "cpu" if model.emb_cpu else "cuda"
    if not isinstance(tokens, torch.Tensor):
        ids = torch.tensor(tokens, dtype=torch.long, device=target)
    else:
        ids = tokens.to(device=target, dtype=torch.long, non_blocking=True)
    return model.embed(ids.view(-1, 1))
def generate_prompt(instruction, input=""):
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
if input:
return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
else:
return f"User: {instruction}\n\nAssistant: 0:
token_device = "cpu" if rwkv_model.emb_cpu else "cuda"
tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
out = rwkv_model.forward(tokens, state).view(-1)
input_ids = input_ids[CHUNK_LEN:]
torch.cuda.synchronize()
copy_state_to_batch(decode_state, state)
logits = out.view(1, -1).repeat(B, 1)
else:
decode_x.copy_(tokens_to_x(next_tokens))
if decode_graph is None:
decode_output = rwkv_model.forward_from_x(decode_x, decode_state, v3a.select_path(B, 1))
else:
decode_graph.replay()
logits = decode_output.view(B, -1)
if occurrence_count is None:
occurrence_count = torch.zeros((B, logits.size(-1)), device=logits.device, dtype=logits.dtype)
occurrence_presence = torch.zeros_like(occurrence_count)
batch_rows = torch.arange(B, device=logits.device)
if alpha_frequency:
logits.sub_(occurrence_count, alpha=alpha_frequency)
if alpha_presence:
logits.sub_(occurrence_presence)
assert logits.is_cuda and logits.dim() == 2
sampled_tensor = sample_logits_batch_cuda(
logits,
sample_temperature,
sample_top_p,
min(SAMPLER_TOP_K, logits.size(-1)),
)
sampled = sampled_tensor.detach().cpu().tolist()
active = 0
next_tokens = [0 for _ in range(B)]
if penalty_decay != 1:
occurrence_count.mul_(penalty_decay)
occurrence_count[batch_rows, sampled_tensor] += 1
if alpha_presence:
occurrence_presence[batch_rows, sampled_tensor] = alpha_presence
for b in range(B):
if finished[b]:
continue
token = sampled[b]
if token == 0:
finished[b] = True
continue
active += 1
next_tokens[b] = token
all_tokens[b].append(token)
tmp = pipe.decode(all_tokens[b][out_last[b]:])
if '\ufffd' not in tmp:
out_str[b] += tmp
out_last[b] = len(all_tokens[b])
total_tokens += active
if active == 0:
break
if speed_t0 is None:
speed_t0 = time.perf_counter()
else:
speed_tokens += B
elapsed = max(1e-9, time.perf_counter() - speed_t0)
current_text = output_text(B, out_str)
speed_info = speed_text(speed_tokens / elapsed, B, total_tokens, len(current_text))
if i == 0 or i % max(1, int(yield_every)) == 0:
current_text = output_text(B, out_str)
yield current_text, speed_info, {"done": False, "token_counts": [len(tokens) for tokens in all_tokens]}
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
free, total = torch.cuda.mem_get_info()
del out
del state
del decode_cache
current_text = output_text(B, out_str)
if speed_t0 is not None and not speed_info:
speed_info = speed_text(0.0, B, total_tokens, len(current_text))
elapsed = time.perf_counter() - req_t0
final_tps = speed_tokens / max(1e-9, time.perf_counter() - speed_t0) if speed_t0 is not None else 0.0
used = total - free
meta = {"done": True, "timestamp": timestamp, "B": B, "T": user_token_count, "In": input_token_count, "TPS": final_tps, "Time": elapsed, "Token": total_tokens, "Char": len(current_text), "VRAMUsed": used, "VRAMTotal": total, "token_counts": [len(tokens) for tokens in all_tokens]}
yield current_text, speed_info, meta
def print_summary(prefix, meta):
    """Print a one-line, space-separated summary of a finished request.

    ``meta`` must provide the keys ``timestamp``, ``B``, ``T``, ``In``,
    ``TPS``, ``Time``, ``Token``, ``Char``, ``VRAMUsed`` and ``VRAMTotal``.
    The two VRAM values are converted with ``gib`` before formatting.
    """
    fields = [
        f"[{prefix}]",
        f"{meta['timestamp']}",
        f"B={meta['B']}",
        f"T={meta['T']}",
        f"In={meta['In']}",
        f"TPS={meta['TPS']:.1f}",
        f"Time={meta['Time']:.3f}s",
        f"Token={meta['Token']}",
        f"Char={meta['Char']}",
        f"VRAM={gib(meta['VRAMUsed']):.2f}G/{gib(meta['VRAMTotal']):.2f}G",
    ]
    print(" ".join(fields), flush=True)
def evaluate_raw(
    ctx,
    token_count=200,
    batch_size=1,
    temperature=1.0,
    top_p=0.5,
    presencePenalty = 1,
    countPenalty = 0.1,
    penalty_decay = 0.99,
):
    """Stream UI updates for a raw-prompt generation request.

    Forwards every argument to ``generate_batch_text`` together with
    ``YIELD_EVERY`` and yields ``output_update(text, speed_info)`` for each
    item it produces. When an item's ``meta`` reports ``done``, a summary
    line tagged ``app3`` is printed before that item is yielded.
    """
    stream = generate_batch_text(
        ctx,
        token_count,
        batch_size,
        temperature,
        top_p,
        presencePenalty,
        countPenalty,
        penalty_decay,
        YIELD_EVERY,
    )
    for text, speed_info, meta in stream:
        is_final = bool(meta) and meta.get("done")
        if is_final:
            print_summary("app3", meta)
        yield output_update(text, speed_info)
def split_batch_output(text, count):
    """Split batched output on ``"\\n====\\n"`` into exactly ``count`` pieces.

    Falsy ``text`` yields no pieces before padding. Missing pieces are filled
    with empty strings and surplus pieces are dropped by slicing.
    """
    chunks = text.split("\n====\n") if text else []
    missing = count - len(chunks)
    if missing > 0:
        chunks.extend([""] * missing)
    return chunks[:count]
def extract_html(text, prompt=""):
stream = prompt + text
lower_text = stream.lower()
marker = lower_text.find("")
if marker < 0:
return ""
visible = stream[marker + len(""):]
lower = visible.lower()
start = lower.find("", start)
if end >= 0:
return visible[start:end + len("