"""Gradio playground that streams LFM2 autocompletions into a text editor."""

import gradio as gr
import torch
from transformers import GPT2Tokenizer, Lfm2Config, Lfm2ForCausalLM

# Load model and tokenizer
model_name = "MostLime/LFM2-150M-1.5B"
config = Lfm2Config.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Reuse EOS as the pad token so the `padding=True` tokenizer call below has one.
tokenizer.pad_token = tokenizer.eos_token
model = Lfm2ForCausalLM.from_pretrained(model_name, config=config)
model.eval()


def _apply_repetition_penalty(logits, seen_ids, penalty):
    """Penalize, in place, the logits of tokens already present in the text.

    Args:
        logits: Tensor of shape [1, vocab] holding next-token scores.
        seen_ids: 1-D tensor of token ids that appear in the context so far.
        penalty: Values > 1.0 discourage repeats; 1.0 is a no-op.
    """
    if penalty == 1.0 or seen_ids.numel() == 0:
        return
    seen = torch.unique(seen_ids)
    scores = logits[0, seen]
    # Dividing a negative logit would move it toward zero and make the token
    # MORE likely, so negative scores are multiplied instead.
    logits[0, seen] = torch.where(scores < 0, scores * penalty, scores / penalty)


def _apply_top_p(logits, top_p):
    """Mask, in place, every token outside the top-p nucleus with -inf.

    The nucleus is the smallest set of highest-probability tokens whose
    cumulative probability reaches ``top_p``; the single most likely token is
    always kept.

    Args:
        logits: Tensor of shape [1, vocab] holding next-token scores.
        top_p: Cumulative-probability threshold in [0, 1].
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    drop = cumulative_probs > top_p
    # Shift the mask right by one so the token that crosses the threshold is
    # kept; otherwise the surviving mass could be smaller than top_p.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False
    logits[0, sorted_indices[drop]] = float("-inf")


def _unlock_editor():
    """Re-enable the editor; a cancelled run never reaches its final yield."""
    return gr.update(interactive=True)


def generate_text(prompt, max_length, temperature, top_p, rep_penalty):
    """Generate a continuation of ``prompt``, streaming it token by token.

    Args:
        prompt: Current contents of the text editor.
        max_length: Maximum number of new tokens to sample.
        temperature: Softmax temperature; logits are divided by it.
        top_p: Nucleus-sampling threshold.
        rep_penalty: Repetition penalty applied to tokens already in the text.

    Yields:
        ``gr.update`` payloads for the text editor: the text so far with the
        editor locked while sampling, then a final update that unlocks it.
    """
    prompt = prompt or ""
    if not prompt:
        # Nothing to condition on: an empty prompt tokenizes to zero tokens
        # and the model cannot run on an empty sequence.
        yield gr.update(value=prompt, interactive=True)
        return

    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    generated_ids = []
    generated_text = prompt
    past_key_values = None

    # Disable textbox during generation
    yield gr.update(value=generated_text, interactive=False)

    # Slider values may arrive as floats; range() needs an int.
    for _ in range(int(max_length)):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids, attention_mask=attention_mask)
            else:
                # Cached steps feed only the newest token. The batch holds a
                # single unpadded prompt, so no mask is passed here.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :] / temperature
            past_key_values = outputs.past_key_values

            _apply_repetition_penalty(logits, input_ids[0], rep_penalty)
            _apply_top_p(logits, top_p)

            # Sample next token
            probs = torch.softmax(logits, dim=-1)
            next_token_tensor = torch.multinomial(probs, num_samples=1)  # Shape [1, 1]

        next_token_id = next_token_tensor.item()

        # Stop if EOS
        if next_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token_tensor], dim=1)
        generated_ids.append(next_token_id)

        # Decode everything generated so far rather than the single new token:
        # with a byte-level BPE vocabulary one character can span several
        # tokens, and decoding them separately yields replacement characters.
        generated_text = prompt + tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # Stream the updated text
        yield gr.update(value=generated_text, interactive=False)

    # Re-enable textbox after generation
    yield gr.update(value=generated_text, interactive=True)


with gr.Blocks(title="LFM2 Autocompletion Playground") as demo:
    gr.Markdown("# LFM2 Autocompletion Playground")

    with gr.Row():
        with gr.Column(scale=1):
            max_length = gr.Slider(50, 500, value=100, label="Max Tokens")
            temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
            rep_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.05, label="Repetition Penalty")
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", scale=2)
                cancel_btn = gr.Button("Cancel", scale=1)

        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Text Editor",
                lines=25,
                placeholder="Start typing here...",
                autoscroll=True,
            )

    generate_event = generate_btn.click(
        fn=generate_text,
        inputs=[text_input, max_length, temperature, top_p, rep_penalty],
        outputs=text_input,
        show_progress=True,
    )
    # Cancelling stops the generator before its final "unlock" yield, so the
    # cancel handler itself re-enables the editor (text is left untouched).
    cancel_btn.click(
        fn=_unlock_editor,
        inputs=None,
        outputs=text_input,
        cancels=[generate_event],
    )

demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)