import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

MODEL_ID = "inclusionAI/LLaDA2.1-mini"
REVISION = "refs/pr/4"

# Model, tokenizer and pipeline are created once at import time and shared by
# every request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REVISION)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    revision=REVISION,
    attn_implementation="sdpa",
)
model.eval()

scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
pipe.to("cuda")

# Placeholder shown in the UI for positions that are still masked.
MASK_TOKEN = "[MASK]"
MASK_ID = tokenizer.mask_token_id


def extract_text(content):
    """Extract plain text from Gradio message content.

    ``content`` may be a string (returned unchanged), a list whose dict items
    contribute their ``"text"`` value (missing key -> empty string) and whose
    other items are stringified, or anything else (stringified).
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            item.get("text", "") if isinstance(item, dict) else str(item)
            for item in content
        )
    return str(content)


def parse_constraints(constraints_text):
    """Parse constraints written as ``'position:word, position:word, ...'``.

    Returns a dict mapping generation-relative token positions to token IDs.
    A word that tokenizes into several tokens occupies consecutive positions
    starting at the given one. Entries without a colon, with a non-integer or
    negative position, or with an empty word are ignored.
    """
    constraints = {}
    if not constraints_text:
        return constraints

    for part in constraints_text.split(","):
        if ":" not in part:
            continue
        pos_str, word = part.split(":", 1)
        # Only the integer parse is expected to fail on malformed user input,
        # so the try body is limited to it.
        try:
            pos = int(pos_str.strip())
        except ValueError:
            continue
        word = word.strip()
        if not word or pos < 0:
            continue
        # The word is encoded with a leading space, as in the original code.
        token_ids = tokenizer.encode(" " + word, add_special_tokens=False)
        for offset, tid in enumerate(token_ids):
            constraints[pos + offset] = tid
    return constraints


def confidence_label(prob):
    """Map a probability to a confidence category.

    Below 0.3 is ``"low"``, below 0.7 is ``"mid"``, everything else ``"high"``.
    """
    if prob < 0.3:
        return "low"
    elif prob < 0.7:
        return "mid"
    return "high"


def build_vis_state(x, prompt_length, gen_length, confidences=None):
    """Build ``(highlighted_text_state, plain_text)`` from the token sequence.

    Args:
        x: token-id tensor; row 0 is read from ``prompt_length`` onwards.
        prompt_length: number of leading prompt tokens to skip.
        gen_length: number of generation positions to render.
        confidences: optional dict mapping generation-relative position to a
            probability in [0, 1]; positions without an entry get no label.

    Returns:
        A list of ``(token_text, label_or_None)`` tuples, one per generation
        position, and the concatenation of the token texts. Positions that are
        still masked, or that lie beyond the end of ``x``, are rendered as
        ``MASK_TOKEN`` with no label.
    """
    # One slice and one host transfer instead of an .item() call per token.
    token_ids = x[0, prompt_length:prompt_length + gen_length].tolist()

    highlighted = []
    tokens = []
    for i in range(gen_length):
        if i >= len(token_ids) or token_ids[i] == MASK_ID:
            highlighted.append((MASK_TOKEN, None))
            tokens.append(MASK_TOKEN)
            continue
        token = tokenizer.decode([token_ids[i]], skip_special_tokens=True)
        # Special tokens decode to "", so keep a visible placeholder for them.
        token = token or " "
        label = confidence_label(confidences[i]) if confidences and i in confidences else None
        highlighted.append((token, label))
        tokens.append(token)
    return highlighted, "".join(tokens)


@spaces.GPU
@torch.no_grad()
def generate_streaming(messages, gen_length, num_inference_steps, temperature,
                       block_length, threshold, editing_threshold, max_post_steps,
                       constraints=None):
    """Generator that yields ``(highlighted, plain, response_text_or_None)``.

    One state is yielded before refinement starts, one after every refinement
    step, and a final one whose third element is the decoded response text
    (cut after the first EOS token, if any). All earlier yields carry ``None``
    as the third element.

    Args:
        messages: chat messages passed to ``tokenizer.apply_chat_template``.
        gen_length: number of tokens to generate.
        num_inference_steps: refinement steps, capped at ``gen_length``.
        temperature, threshold, editing_threshold: forwarded to
            ``pipe.scheduler.step``.
        block_length: size of each block refined in turn.
        max_post_steps: forwarded to
            ``pipe.scheduler.check_block_should_continue``.
        constraints: optional dict of generation-relative position -> token id
            that is written into the sequence up front and re-applied after
            every step.
    """
    if constraints is None:
        constraints = {}

    # These values are used as tensor shapes and range() bounds, so coerce
    # them to int in case the UI hands over floats.
    gen_length = int(gen_length)
    num_inference_steps = int(num_inference_steps)
    block_length = int(block_length)
    max_post_steps = int(max_post_steps)

    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    prompt_ids = encoded["input_ids"].to("cuda")
    prompt_length = prompt_ids.shape[1]

    mask_token_id = MASK_ID
    eos_token_id = tokenizer.eos_token_id

    num_inference_steps = min(num_inference_steps, gen_length)
    pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")

    # Round the total sequence length up to a whole number of blocks.
    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length

    attn_mask = torch.ones((1, total_length), device="cuda", dtype=torch.long)
    position_ids = torch.arange(total_length, device="cuda", dtype=torch.long).unsqueeze(0)

    x = torch.full((1, total_length), mask_token_id, device="cuda", dtype=torch.long)
    x[:, :prompt_length] = prompt_ids

    # Apply constraints to initial sequence
    for gen_pos, tid in constraints.items():
        abs_pos = prompt_length + gen_pos
        if abs_pos < total_length:
            x[0, abs_pos] = tid

    # Blocks that lie entirely inside the prompt are not refined.
    prefill_blocks = prompt_length // block_length
    finished = torch.zeros((1,), device="cuda", dtype=torch.bool)
    editing_enabled = editing_threshold is not None and editing_threshold > 0.0

    # Track per-position confidence (gen-relative pos -> probability)
    token_confidences = {}

    # Initial masked state
    highlighted, plain = build_vis_state(x, prompt_length, gen_length, token_confidences)
    yield highlighted, plain, None

    # Block-wise refinement
    for num_block in range(prefill_blocks, num_blocks):
        current_window_end = (num_block + 1) * block_length
        # block_x is a view into x, so writes to it update x as well.
        block_x = x[:, :current_window_end]
        block_attn_mask = attn_mask[:, :current_window_end]
        block_position_ids = position_ids[:, :current_window_end]

        # Flag the positions of this block that still belong to the prompt.
        block_start_pos = num_block * block_length
        prompt_mask_in_block = torch.zeros(block_length, device="cuda", dtype=torch.bool)
        if block_start_pos < prompt_length:
            prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
            prompt_mask_in_block[:prompt_end_in_block] = True

        post_steps = 0
        step_idx = 0
        should_continue = True

        while should_continue:
            block_tokens = block_x[:, -block_length:]
            masks_remaining = (block_tokens == mask_token_id).any()

            # Steps taken once no mask is left count as post-editing steps.
            if not masks_remaining:
                post_steps += 1

            logits = pipe.model(
                block_x,
                attention_mask=block_attn_mask,
                position_ids=block_position_ids,
            ).logits
            block_logits = logits[:, -block_length:, :]

            scheduler_output = pipe.scheduler.step(
                model_output=block_logits,
                timestep=step_idx,
                sample=block_tokens,
                mask_token_id=mask_token_id,
                temperature=temperature,
                threshold=threshold,
                editing_threshold=editing_threshold,
                minimal_topk=1,
                prompt_mask=prompt_mask_in_block,
                return_dict=True,
            )

            transfer_index = scheduler_output.transfer_index
            editing_transfer_index = scheduler_output.editing_transfer_index
            final_transfer = transfer_index | editing_transfer_index

            if final_transfer.any():
                block_x[:, -block_length:] = scheduler_output.prev_sample

                # Record confidence for newly committed tokens. Only the
                # committed indices are visited, fetched in one transfer.
                sampled_probs = scheduler_output.sampled_probs[0].cpu()
                committed = final_transfer[0].nonzero(as_tuple=True)[0].tolist()
                for j in committed:
                    gen_pos = block_start_pos + j - prompt_length
                    if 0 <= gen_pos < gen_length:
                        token_confidences[gen_pos] = float(sampled_probs[j])

            # Re-apply constraints after each step
            for gen_pos, tid in constraints.items():
                abs_pos = prompt_length + gen_pos
                if block_start_pos <= abs_pos < block_start_pos + block_length:
                    block_x[0, abs_pos] = tid
                    token_confidences[gen_pos] = 1.0

            if eos_token_id is not None:
                finished = pipe.scheduler.check_eos_finished(
                    cur_x=block_x,
                    sampled_tokens=scheduler_output.sampled_tokens,
                    final_transfer=final_transfer,
                    finished=finished,
                    eos_token_id=eos_token_id,
                    mask_token_id=mask_token_id,
                    prompt_length=prompt_length,
                )

            # Yield current state for real-time visualization
            highlighted, plain = build_vis_state(x, prompt_length, gen_length, token_confidences)
            yield highlighted, plain, None

            if masks_remaining:
                step_idx += 1

            should_continue = pipe.scheduler.check_block_should_continue(
                step_idx=step_idx,
                masks_remaining=masks_remaining,
                editing_enabled=editing_enabled,
                editing_transfer_index=editing_transfer_index,
                post_steps=post_steps,
                max_post_steps=max_post_steps,
                finished=finished,
            )

        x[:, :current_window_end] = block_x
        if finished.all():
            break

    # Decode final response
    generated = x[:, prompt_length:prompt_length + gen_length]
    if eos_token_id is not None:
        eos_positions = (generated[0] == eos_token_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) > 0:
            # Keep everything up to and including the first EOS token.
            generated = generated[:, :int(eos_positions[0].item()) + 1]

    response_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    highlighted, plain = build_vis_state(x, prompt_length, gen_length, token_confidences)
    yield highlighted, plain, response_text


css = """
.category-legend{display:none}
button{height: 60px}
.legend{margin-bottom: 5px}
.legend-item{height: 25px}
"""


def create_chatbot_demo():
    """Build and return the Gradio Blocks app for the chat demo.

    The layout has a conversation pane with message and constraint inputs, a
    highlighted-text pane that colours tokens by confidence, a collapsible
    group of generation settings and a clear button.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA2.1 - Large Language Diffusion Model Demo")
        gr.Markdown(
            "[model LLaDA 2.1-mini](https://huggingface.co/inclusionAI/LLaDA2.1-mini), "
            "[project page](https://github.com/inclusionAI/LLaDA2)"
        )

        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Place specific words at specific positions: 'position:word' format. "
                         "Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Once, 5:upon, 10:time",
                    value="",
                )

            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Confidence",
                    combine_adjacent=False,
                    show_legend=True,
                    color_map={
                        "low": "#FF6666",
                        "mid": "#FFAA33",
                        "high": "#66CC66",
                    },
                )

        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8,
                    label="Generation Length",
                )
                steps = gr.Slider(
                    minimum=8, maximum=64, value=32, step=4,
                    label="Denoising Steps",
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.0, step=0.1,
                    label="Temperature",
                )
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8,
                    label="Block Length",
                )
            with gr.Row():
                threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.7, step=0.05,
                    label="Confidence Threshold",
                )
                editing_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Editing Threshold",
                )
            with gr.Row():
                max_post_steps = gr.Slider(
                    minimum=0, maximum=32, value=16, step=1,
                    label="Max Post-Editing Steps",
                )

        clear_btn = gr.Button("Clear Conversation")

        def user_message_submitted(message, history):
            """Append the user's message to the history and clear the inputs.

            A blank message leaves the history untouched.
            """
            if not message.strip():
                return history, "", []
            history = history + [{"role": "user", "content": message}]
            return history, "", []

        def bot_response(
            history, gen_length, steps, temperature,
            block_length, threshold, editing_threshold, max_post_steps,
            constraints_text,
        ):
            """Stream the assistant reply for the latest user message.

            Yields ``(history, vis_state)`` pairs: the history with the
            in-progress assistant message and the highlighted-token state.
            """
            # This handler is chained after every submit, including a blank
            # one that adds no user turn. Without this guard the model would
            # produce a second reply to an already answered conversation.
            if not history or history[-1].get("role") != "user":
                yield history, []
                return

            try:
                messages = [
                    {"role": msg["role"], "content": extract_text(msg["content"])}
                    for msg in history
                    if msg["role"] in ("user", "assistant") and msg.get("content")
                ]

                constraints = parse_constraints(constraints_text)
                history = history + [{"role": "assistant", "content": ""}]

                response_text = None
                for vis_state, plain, text in generate_streaming(
                    messages, gen_length, steps, temperature,
                    block_length, threshold, editing_threshold, max_post_steps,
                    constraints,
                ):
                    if text is not None:
                        response_text = text
                    history[-1]["content"] = plain
                    yield history, vis_state

                # Replace the raw token dump with the cleanly decoded text.
                if response_text:
                    history[-1]["content"] = response_text
                yield history, vis_state

            except Exception as e:
                # UI boundary: report the failure in the visualisation pane
                # instead of letting the event handler crash.
                error_msg = f"Error: {str(e)}"
                print(error_msg)
                yield history, [(error_msg, "red")]

        def clear_conversation():
            """Reset the chat history, the message box and the visualisation."""
            return [], "", []

        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chatbot_ui, user_input, output_vis],
        )

        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chatbot_ui],
            outputs=[chatbot_ui, user_input, output_vis],
        )

        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chatbot_ui],
            outputs=[chatbot_ui, user_input, output_vis],
        )

        bot_inputs = [
            chatbot_ui, gen_length, steps, temperature,
            block_length, threshold, editing_threshold, max_post_steps,
            constraints_input,
        ]
        bot_outputs = [chatbot_ui, output_vis]

        # Generation runs after the user message has been added to the chat.
        msg_submit.then(fn=bot_response, inputs=bot_inputs, outputs=bot_outputs)
        send_click.then(fn=bot_response, inputs=bot_inputs, outputs=bot_outputs)

    return demo


if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(share=True)