NeverlandPeter committed on
Commit
9403a3d
·
1 Parent(s): f9eb4f2

batch inference

Browse files
Files changed (1) hide show
  1. app.py +86 -53
app.py CHANGED
@@ -13,6 +13,8 @@ gpu_h = nvmlDeviceGetHandleByIndex(0)
13
 
14
  ctx_limit = 7000
15
  gen_limit = 1000
 
 
16
 
17
  ########################## text rwkv ################################################################
18
 
@@ -31,18 +33,33 @@ v3a.load_extensions(v3a.WKV_MODE)
31
  model = v3a.RWKV7()
32
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
33
 
34
- decode_state = model.zero_state(1)
35
- decode_x = torch.empty((1, 1, v3a.C), device="cuda", dtype=torch.half)
36
- decode_path = v3a.select_path(1, 1)
37
- for _ in range(2):
38
- model.forward_from_x(decode_x, decode_state, decode_path)
39
- torch.cuda.synchronize()
40
- decode_graph = torch.cuda.CUDAGraph()
41
- with torch.cuda.graph(decode_graph):
42
- decode_output = model.forward_from_x(decode_x, decode_state, decode_path)
43
-
44
- def token_to_x(token: int):
45
- token_tensor = torch.tensor([[int(token)]], dtype=torch.long, device="cpu" if model.emb_cpu else "cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  return model.embed(token_tensor)
47
 
48
  def generate_prompt(instruction, input=""):
@@ -61,6 +78,7 @@ def qa_prompt(instruction):
61
  def evaluate(
62
  ctx,
63
  token_count=200,
 
64
  temperature=1.0,
65
  top_p=0.5,
66
  presencePenalty = 2,
@@ -73,56 +91,69 @@ def evaluate(
73
  token_ban = [], # ban the generation of some tokens
74
  token_stop = [0]) # stop generation whenever you see any token here
75
  ctx = ctx.strip()
76
- all_tokens = []
77
- out_last = 0
78
- out_str = ''
79
- occurrence = {}
 
 
80
  state = model.zero_state(1)
 
 
81
  out = None
82
  for i in range(int(token_count)):
83
-
84
  if i == 0:
85
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
86
- CHUNK_LEN = 8192 # chunk prefill, save VRAM
87
  while len(input_ids) > 0:
88
  token_device = "cpu" if model.emb_cpu else "cuda"
89
  tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
90
  out = model.forward(tokens, state).view(-1)
91
  input_ids = input_ids[CHUNK_LEN:]
92
- for dst, src in zip(decode_state, state):
93
- dst.copy_(src)
94
- logits = out
95
  else:
96
- decode_x.copy_(token_to_x(token))
97
  decode_graph.replay()
98
- logits = decode_output.view(-1)
99
-
100
- for n in occurrence:
101
- logits[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
102
-
103
- token = pipeline.sample_logits(logits, temperature=args.temperature, top_p=args.top_p)
104
- if token in args.token_stop:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  break
106
- all_tokens += [token]
107
- for xxx in occurrence:
108
- occurrence[xxx] *= penalty_decay
109
-
110
- ttt = pipeline.decode([token])
111
- www = 1
112
- #if ttt in ' \t0123456789':
113
- # www = 0
114
- #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
115
- # www = 0.5
116
- if token not in occurrence:
117
- occurrence[token] = www
118
- else:
119
- occurrence[token] += www
120
-
121
- tmp = pipeline.decode(all_tokens[out_last:])
122
- if '\ufffd' not in tmp:
123
- out_str += tmp
124
- yield out_str.strip()
125
- out_last = i + 1
126
 
127
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
128
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -131,7 +162,7 @@ def evaluate(
131
  del state
132
  gc.collect()
133
  torch.cuda.empty_cache()
134
- yield out_str.strip()
135
 
136
  examples = [
137
  ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
@@ -150,6 +181,7 @@ examples = [
150
  ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
151
  ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
152
  ]
 
153
 
154
  ##################################################################################################################
155
  with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
@@ -161,6 +193,7 @@ with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
161
  with gr.Column():
162
  prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: <think></think")
163
  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
 
164
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
165
  top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
166
  presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=2)
@@ -171,10 +204,10 @@ with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
171
  submit = gr.Button("Submit", variant="primary")
172
  clear = gr.Button("Clear", variant="secondary")
173
  output = gr.Textbox(label="Output", lines=20, max_lines=100)
174
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty", "Penalty Decay"])
175
- submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay], [output])
176
  clear.click(lambda: None, [], [output])
177
- data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
178
 
179
  demo.queue(default_concurrency_limit=1, max_size=10)
180
  demo.launch(share=False, server_name="0.0.0.0")
 
13
 
14
  ctx_limit = 7000
15
  gen_limit = 1000
16
+ max_bsz = 16
17
+ CHUNK_LEN = 8192 # chunk prefill, save VRAM
18
 
19
  ########################## text rwkv ################################################################
20
 
 
33
  model = v3a.RWKV7()
34
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
35
 
36
+ decode_cache = {}
37
+
38
def get_decode_ctx(B: int):
    """Return the cached single-step decode context for batch size ``B``.

    On first use for a given ``B`` this allocates a zero state and an
    uninitialised half-precision input buffer of shape ``(B, 1, v3a.C)`` on
    CUDA, runs two eager forward passes, and then captures one
    ``model.forward_from_x`` call into a CUDA graph. The result is stored in
    the module-level ``decode_cache`` so later calls with the same ``B``
    reuse it.

    Returns:
        A tuple ``(state, x, graph, output)``: the state and input buffer the
        graph was captured with, the captured graph, and the value returned
        by the captured forward call.
    """
    if B in decode_cache:
        return decode_cache[B]

    state = model.zero_state(B)
    x = torch.empty((B, 1, v3a.C), device="cuda", dtype=torch.half)
    path = v3a.select_path(B, 1)

    # Two eager runs before capture (presumably warm-up so lazy
    # initialisation does not happen inside the captured region).
    warmup_runs = 2
    for _ in range(warmup_runs):
        model.forward_from_x(x, state, path)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        output = model.forward_from_x(x, state, path)

    entry = (state, x, graph, output)
    decode_cache[B] = entry
    return entry
54
+
55
def copy_state_to_batch(dst, src):
    """Broadcast the first three tensors of ``src`` into ``dst`` in place.

    The batch size is read from ``dst[2].shape[0]``. Each source tensor is
    expanded along one axis to that size before the copy: axis 2 for element
    0, axis 1 for element 1, and axis 0 for element 2. ``expand`` requires
    the source to have size 1 on the expanded axis. Nothing is returned;
    ``dst`` is modified in place.
    """
    batch = dst[2].shape[0]
    broadcast_views = (
        src[0].expand(-1, -1, batch, -1),
        src[1].expand(-1, batch, -1, -1, -1),
        src[2].expand(batch),
    )
    for target, view in zip(dst, broadcast_views):
        target.copy_(view)
60
+
61
def tokens_to_x(tokens):
    """Embed a flat sequence of token ids as a column of single-token rows.

    The ids are placed on the CPU when ``model.emb_cpu`` is truthy and on
    CUDA otherwise, reshaped to ``(-1, 1)``, and passed to ``model.embed``.

    Returns:
        Whatever ``model.embed`` returns for the ``(N, 1)`` id tensor.
    """
    device = "cpu" if model.emb_cpu else "cuda"
    ids = torch.tensor(tokens, dtype=torch.long, device=device)
    return model.embed(ids.view(-1, 1))
64
 
65
  def generate_prompt(instruction, input=""):
 
78
  def evaluate(
79
  ctx,
80
  token_count=200,
81
+ batch_size=1,
82
  temperature=1.0,
83
  top_p=0.5,
84
  presencePenalty = 2,
 
91
  token_ban = [], # ban the generation of some tokens
92
  token_stop = [0]) # stop generation whenever you see any token here
93
  ctx = ctx.strip()
94
+ B = max(1, int(batch_size))
95
+ all_tokens = [[] for _ in range(B)]
96
+ out_last = [0 for _ in range(B)]
97
+ out_str = ['' for _ in range(B)]
98
+ occurrence = [{} for _ in range(B)]
99
+ finished = [False for _ in range(B)]
100
  state = model.zero_state(1)
101
+ decode_state, decode_x, decode_graph, decode_output = get_decode_ctx(B)
102
+ next_tokens = [0 for _ in range(B)]
103
  out = None
104
  for i in range(int(token_count)):
105
+
106
  if i == 0:
107
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
 
108
  while len(input_ids) > 0:
109
  token_device = "cpu" if model.emb_cpu else "cuda"
110
  tokens = torch.tensor(input_ids[:CHUNK_LEN], dtype=torch.long, device=token_device)
111
  out = model.forward(tokens, state).view(-1)
112
  input_ids = input_ids[CHUNK_LEN:]
113
+ copy_state_to_batch(decode_state, state)
114
+ logits = out.view(1, -1).expand(B, -1)
 
115
  else:
116
+ decode_x.copy_(tokens_to_x(next_tokens))
117
  decode_graph.replay()
118
+ logits = decode_output.view(B, -1)
119
+
120
+ active = 0
121
+ next_tokens = [0 for _ in range(B)]
122
+ for b in range(B):
123
+ if finished[b]:
124
+ continue
125
+ row = logits[b]
126
+ for n in occurrence[b]:
127
+ row[n] -= (args.alpha_presence + occurrence[b][n] * args.alpha_frequency)
128
+
129
+ token = pipeline.sample_logits(row, temperature=args.temperature, top_p=args.top_p)
130
+ if token in args.token_stop:
131
+ finished[b] = True
132
+ continue
133
+ active += 1
134
+ next_tokens[b] = token
135
+ all_tokens[b] += [token]
136
+ for xxx in occurrence[b]:
137
+ occurrence[b][xxx] *= penalty_decay
138
+
139
+ ttt = pipeline.decode([token])
140
+ www = 1
141
+ #if ttt in ' \t0123456789':
142
+ # www = 0
143
+ #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
144
+ # www = 0.5
145
+ if token not in occurrence[b]:
146
+ occurrence[b][token] = www
147
+ else:
148
+ occurrence[b][token] += www
149
+
150
+ tmp = pipeline.decode(all_tokens[b][out_last[b]:])
151
+ if '\ufffd' not in tmp:
152
+ out_str[b] += tmp
153
+ out_last[b] = len(all_tokens[b])
154
+ if active == 0:
155
  break
156
+ yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
159
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
162
  del state
163
  gc.collect()
164
  torch.cuda.empty_cache()
165
+ yield out_str[0].strip() if B == 1 else "\n====\n".join(x.strip() for x in out_str)
166
 
167
  examples = [
168
  ["System: Tools:\n- get_weather(location: string, unit?: \"celsius\" | \"fahrenheit\")\n- get_stock_price(ticker: string)\n- translate_text(text: string, target_language: string)\nReturn only a JSON function call.\n\nUser: Translate \"Will it rain tomorrow?\" into Japanese.\n\nAssistant: ```json", 200, 1, 0, 0, 0, 0.99],
 
181
  ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
182
  ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.5, 2, 0.2, 0.99],
183
  ]
184
+ examples = [[x[0], x[1], 1, *x[2:]] for x in examples]
185
 
186
  ##################################################################################################################
187
  with gr.Blocks(title=title, theme=gr.themes.Base()) as demo:
 
193
  with gr.Column():
194
  prompt = gr.Textbox(lines=6, label="Prompt", value="User: simulate SpaceX mars landing using python\n\nAssistant: <think></think")
195
  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
196
+ batch_size = gr.Slider(1, max_bsz, label="Batch Size", step=1, value=max_bsz)
197
  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
198
  top_p = gr.Slider(0.0, 0.95, label="Top P", step=0.05, value=0.5)
199
  presence_penalty = gr.Slider(0.0, 2.0, label="Presence Penalty", step=0.1, value=2)
 
204
  submit = gr.Button("Submit", variant="primary")
205
  clear = gr.Button("Clear", variant="secondary")
206
  output = gr.Textbox(label="Output", lines=20, max_lines=100)
207
+ data = gr.Dataset(components=[prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Batch Size", "Temperature", "Top P", "Presence Penalty", "Count Penalty", "Penalty Decay"])
208
+ submit.click(evaluate, [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay], [output])
209
  clear.click(lambda: None, [], [output])
210
+ data.click(lambda x: x, [data], [prompt, token_count, batch_size, temperature, top_p, presence_penalty, count_penalty, penalty_decay])
211
 
212
  demo.queue(default_concurrency_limit=1, max_size=10)
213
  demo.launch(share=False, server_name="0.0.0.0")