Spaces:
Running on T4
Running on T4
Commit ·
9403a3d
1
Parent(s): f9eb4f2
batch inference
Browse files
app.py
CHANGED
|
@@ -13,6 +13,8 @@ gpu_h = nvmlDeviceGetHandleByIndex(0)
|
|
| 13 |
|
| 14 |
ctx_limit = 7000
|
| 15 |
gen_limit = 1000
|
|
|
|
|
|
|
| 16 |
|
| 17 |
########################## text rwkv ################################################################
|
| 18 |
|
|
@@ -31,18 +33,33 @@ v3a.load_extensions(v3a.WKV_MODE)
|
|
| 31 |
model = v3a.RWKV7()
|
| 32 |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
return model.embed(token_tensor)
|
| 47 |
|
| 48 |
def generate_prompt(instruction, input=""):
|
|
@@ -61,6 +78,7 @@ def qa_prompt(instruction):
|
|
| 61 |
def evaluate(
|
| 62 |
ctx,
|
| 63 |
token_count=200,
|
|
|
|
| 64 |
temperature=1.0,
|
| 65 |
top_p=0.5,
|
| 66 |
presencePenalty = 2,
|
|
@@ -73,56 +91,69 @@ def evaluate(
|
|
| 73 |
token_ban = [], # ban the generation of some tokens
|
| 74 |
token_stop = [0]) # stop generation whenever you see any token here
|
| 75 |
ctx = ctx.strip()
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
state = model.zero_state(1)
|
|
|
|
|
|
|
| 81 |
out = None
|
| 82 |
for i in range(int(token_count)):
|
| 83 |
-
|
| 84 |
if i == 0:
|
| 85 |
input_ids = pipeline.encode(ctx)[-ctx_limit:]
|
| 86 |
-
CHUNK_LEN = 8192 # chunk prefill, save VRAM
|
| 87 |
while len(input_ids) > 0:
|
| 88 |
token_device = "cpu" if model.emb_cpu else "cuda"
|
| 89 |
tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
|
| 90 |
out = model.forward(tokens, state).view(-1)
|
| 91 |
input_ids = input_ids[CHUNK_LEN:]
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
logits = out
|
| 95 |
else:
|
| 96 |
-
decode_x.copy_(
|
| 97 |
decode_graph.replay()
|
| 98 |
-
logits = decode_output.view(-1)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
break
|
| 106 |
-
|
| 107 |
-
for xxx in occurrence:
|
| 108 |
-
occurrence[xxx] *= penalty_decay
|
| 109 |
-
|
| 110 |
-
ttt = pipeline.decode([token])
|
| 111 |
-
www = 1
|
| 112 |
-
#if ttt in ' \t0123456789':
|
| 113 |
-
# www = 0
|
| 114 |
-
#elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
|
| 115 |
-
# www = 0.5
|
| 116 |
-
if token not in occurrence:
|
| 117 |
-
occurrence[token] = www
|
| 118 |
-
else:
|
| 119 |
-
occurrence[token] += www
|
| 120 |
-
|
| 121 |
-
tmp = pipeline.decode(all_tokens[out_last:])
|
| 122 |
-
if '\ufffd' not in tmp:
|
| 123 |
-
out_str += tmp
|
| 124 |
-
yield out_str.strip()
|
| 125 |
-
out_last = i + 1
|
| 126 |
|
| 127 |
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
|
| 128 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
@@ -131,7 +162,7 @@ def evaluate(
|
|
| 131 |
del state
|
| 132 |
gc.collect()
|
| 133 |
torch.cuda.empty_cache()
|
| 134 |
-
yield out_str.strip()
|
| 135 |
|
| 136 |
examples = [
|
| 137 |
["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
|
|
@@ -150,6 +181,7 @@ examples = [
|
|
| 150 |
['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
|
| 151 |
['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
|
| 152 |
]
|
|
|
|
| 153 |
|
| 154 |
##################################################################################################################
|
| 155 |
with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
|
|
@@ -161,6 +193,7 @@ with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
|
|
| 161 |
with gr.Column():
|
| 162 |
prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: <think></think")
|
| 163 |
token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
|
|
|
|
| 164 |
temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
|
| 165 |
top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
|
| 166 |
presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=2)
|
|
@@ -171,10 +204,10 @@ with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
|
|
| 171 |
submit = gr.Button("Submit", variant="primary")
|
| 172 |
clear = gr.Button("Clear", variant="secondary")
|
| 173 |
output = gr.Textbox(label="Output", lines=20, max_lines=100)
|
| 174 |
-
data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty", "Penalty Decay"])
|
| 175 |
-
submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay], [output])
|
| 176 |
clear.click(lambda: None, [], [output])
|
| 177 |
-
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
|
| 178 |
|
| 179 |
demo.queue(default_concurrency_limit=1, max_size=10)
|
| 180 |
demo.launch(share=False, server_name="0.0.0.0")
|
|
|
|
| 13 |
|
| 14 |
ctx_limit = 7000
|
| 15 |
gen_limit = 1000
|
| 16 |
+
max_bsz = 16
|
| 17 |
+
CHUNK_LEN = 8192 # chunk prefill, save VRAM
|
| 18 |
|
| 19 |
########################## text rwkv ################################################################
|
| 20 |
|
|
|
|
| 33 |
model = v3a.RWKV7()
|
| 34 |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
| 35 |
|
| 36 |
+
decode_cache = {}
|
| 37 |
+
|
| 38 |
+
def get_decode_ctx(B: int):
    """Return the single-token decode context for batch size ``B``.

    A context is built the first time a batch size is requested and then
    kept in the module-level ``decode_cache`` dict, keyed by ``B``, so
    later calls with the same ``B`` return the very same objects.

    Args:
        B: Batch size used for ``model.zero_state`` and for the leading
            dimension of the input buffer.

    Returns:
        A tuple ``(state, x, graph, output)``:
        ``state`` is the result of ``model.zero_state(B)``;
        ``x`` is a half-precision CUDA tensor of shape ``(B, 1, v3a.C)``;
        ``graph`` is a ``torch.cuda.CUDAGraph`` capturing one
        ``model.forward_from_x(x, state, path)`` call;
        ``output`` is the value that call returned during capture.
    """
    if B in decode_cache:
        return decode_cache[B]

    batch_state = model.zero_state(B)
    # Allocated without initialisation; evaluate() copies embeddings into
    # this buffer before every graph replay.
    step_input = torch.empty((B, 1, v3a.C), dtype=torch.half, device="cuda")
    fwd_path = v3a.select_path(B, 1)

    # Run the forward pass eagerly twice and wait for the device before
    # starting the capture.
    for _warmup in range(2):
        model.forward_from_x(step_input, batch_state, fwd_path)
    torch.cuda.synchronize()

    captured_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured_graph):
        step_output = model.forward_from_x(step_input, batch_state, fwd_path)

    entry = (batch_state, step_input, captured_graph, step_output)
    decode_cache[B] = entry
    return entry
|
| 54 |
+
|
| 55 |
+
def copy_state_to_batch(dst, src):
    """Copy the first three tensors of ``src`` into ``dst`` in place.

    The batch size is read from ``dst[2].shape[0]``. Each source tensor is
    expanded along one axis to that size before the copy: axis 2 of the 4-D
    ``src[0]``, axis 1 of the 5-D ``src[1]``, and axis 0 of the 1-D
    ``src[2]``. ``Tensor.expand`` only accepts this when the source size on
    that axis is 1.

    Args:
        dst: Indexable collection of destination tensors; modified in place.
        src: Indexable collection of source tensors.

    Returns:
        None.
    """
    batch = dst[2].shape[0]
    # One expand() target per state component; -1 keeps the source size.
    expand_shapes = (
        (-1, -1, batch, -1),
        (-1, batch, -1, -1, -1),
        (batch,),
    )
    for idx, shape in enumerate(expand_shapes):
        dst[idx].copy_(src[idx].expand(*shape))
|
| 60 |
+
|
| 61 |
+
def tokens_to_x(tokens):
    """Embed token ids with ``model.embed``.

    The ids are converted to a ``torch.long`` tensor, placed on the CPU when
    ``model.emb_cpu`` is truthy and on CUDA otherwise, and reshaped to a
    column of shape ``(N, 1)`` before being passed to the embedding.

    Args:
        tokens: Token ids accepted by ``torch.tensor``.

    Returns:
        Whatever ``model.embed`` returns for the ``(N, 1)`` id tensor.
    """
    device = "cuda"
    if model.emb_cpu:
        device = "cpu"
    ids = torch.tensor(tokens, dtype=torch.long, device=device)
    return model.embed(ids.view(-1, 1))
|
| 64 |
|
| 65 |
def generate_prompt(instruction, input=""):
|
|
|
|
| 78 |
def evaluate(
|
| 79 |
ctx,
|
| 80 |
token_count=200,
|
| 81 |
+
batch_size=1,
|
| 82 |
temperature=1.0,
|
| 83 |
top_p=0.5,
|
| 84 |
presencePenalty = 2,
|
|
|
|
| 91 |
token_ban = [], # ban the generation of some tokens
|
| 92 |
token_stop = [0]) # stop generation whenever you see any token here
|
| 93 |
ctx = ctx.strip()
|
| 94 |
+
B = max(1, int(batch_size))
|
| 95 |
+
all_tokens = [[] for _ in range(B)]
|
| 96 |
+
out_last = [0 for _ in range(B)]
|
| 97 |
+
out_str = ['' for _ in range(B)]
|
| 98 |
+
occurrence = [{} for _ in range(B)]
|
| 99 |
+
finished = [False for _ in range(B)]
|
| 100 |
state = model.zero_state(1)
|
| 101 |
+
decode_state, decode_x, decode_graph, decode_output = get_decode_ctx(B)
|
| 102 |
+
next_tokens = [0 for _ in range(B)]
|
| 103 |
out = None
|
| 104 |
for i in range(int(token_count)):
|
| 105 |
+
|
| 106 |
if i == 0:
|
| 107 |
input_ids = pipeline.encode(ctx)[-ctx_limit:]
|
|
|
|
| 108 |
while len(input_ids) > 0:
|
| 109 |
token_device = "cpu" if model.emb_cpu else "cuda"
|
| 110 |
tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
|
| 111 |
out = model.forward(tokens, state).view(-1)
|
| 112 |
input_ids = input_ids[CHUNK_LEN:]
|
| 113 |
+
copy_state_to_batch(decode_state, state)
|
| 114 |
+
logits = out.view(1, -1).expand(B, -1)
|
|
|
|
| 115 |
else:
|
| 116 |
+
decode_x.copy_(tokens_to_x(next_tokens))
|
| 117 |
decode_graph.replay()
|
| 118 |
+
logits = decode_output.view(B, -1)
|
| 119 |
+
|
| 120 |
+
active = 0
|
| 121 |
+
next_tokens = [0 for _ in range(B)]
|
| 122 |
+
for b in range(B):
|
| 123 |
+
if finished[b]:
|
| 124 |
+
continue
|
| 125 |
+
row = logits[b]
|
| 126 |
+
for n in occurrence[b]:
|
| 127 |
+
row[n] -= (args.alpha_presence + occurrence[b][n] * args.alpha_frequency)
|
| 128 |
+
|
| 129 |
+
token = pipeline.sample_logits(row, temperature=args.temperature, top_p=args.top_p)
|
| 130 |
+
if token in args.token_stop:
|
| 131 |
+
finished[b] = True
|
| 132 |
+
continue
|
| 133 |
+
active += 1
|
| 134 |
+
next_tokens[b] = token
|
| 135 |
+
all_tokens[b] += [token]
|
| 136 |
+
for xxx in occurrence[b]:
|
| 137 |
+
occurrence[b][xxx] *= penalty_decay
|
| 138 |
+
|
| 139 |
+
ttt = pipeline.decode([token])
|
| 140 |
+
www = 1
|
| 141 |
+
#if ttt in ' \t0123456789':
|
| 142 |
+
# www = 0
|
| 143 |
+
#elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
|
| 144 |
+
# www = 0.5
|
| 145 |
+
if token not in occurrence[b]:
|
| 146 |
+
occurrence[b][token] = www
|
| 147 |
+
else:
|
| 148 |
+
occurrence[b][token] += www
|
| 149 |
+
|
| 150 |
+
tmp = pipeline.decode(all_tokens[b][out_last[b]:])
|
| 151 |
+
if '\ufffd' not in tmp:
|
| 152 |
+
out_str[b] += tmp
|
| 153 |
+
out_last[b] = len(all_tokens[b])
|
| 154 |
+
if active == 0:
|
| 155 |
break
|
| 156 |
+
yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
|
| 159 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
| 162 |
del state
|
| 163 |
gc.collect()
|
| 164 |
torch.cuda.empty_cache()
|
| 165 |
+
yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str)
|
| 166 |
|
| 167 |
examples = [
|
| 168 |
["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
|
|
|
|
| 181 |
['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
|
| 182 |
['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
|
| 183 |
]
|
| 184 |
+
examples = [[x[0], x[1], 1, *x[2:]] for x in examples]
|
| 185 |
|
| 186 |
##################################################################################################################
|
| 187 |
with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
|
|
|
|
| 193 |
with gr.Column():
|
| 194 |
prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: <think></think")
|
| 195 |
token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
|
| 196 |
+
batch_size = gr.Slider(1, max_bsz, label="Batch Size", step=1, value=max_bsz)
|
| 197 |
temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
|
| 198 |
top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
|
| 199 |
presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=2)
|
|
|
|
| 204 |
submit = gr.Button("Submit", variant="primary")
|
| 205 |
clear = gr.Button("Clear", variant="secondary")
|
| 206 |
output = gr.Textbox(label="Output", lines=20, max_lines=100)
|
| 207 |
+
data = gr.Dataset(components=[prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Batch Size", "Temperature", "Top P", "Presence Penalty", "Count Penalty", "Penalty Decay"])
|
| 208 |
+
submit.click(evaluate, [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], [output])
|
| 209 |
clear.click(lambda: None, [], [output])
|
| 210 |
+
data.click(lambda x: x, [data], [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
|
| 211 |
|
| 212 |
demo.queue(default_concurrency_limit=1, max_size=10)
|
| 213 |
demo.launch(share=False, server_name="0.0.0.0")
|