"""Gradio app: chat with Luminia LLMs and turn SD-style metadata into images.

Text models are loaded on demand (4-bit NF4 via bitsandbytes, optional
``model_id:branch`` revision) and kept in a small least-recently-used cache.
The Chroma1-HD image pipeline is loaded once at import time.
"""
import gc
import importlib.util
import os
import re
import subprocess
import sys
import time
import traceback
import uuid
from queue import Empty
from threading import Event, Thread

# Best-effort install of flash-attn without compiling its CUDA kernels.
# `env` must EXTEND os.environ: passing only the flag would drop PATH, HOME,
# CUDA variables, etc. `sys.executable -m pip` installs into the interpreter
# running this app. A failed install is tolerated: model loading then falls
# back to the default attention implementation (see _pick_attn_implementation).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Workaround for gradio_client crashes on bool schemas (e.g. additionalProperties: True/False).
# Must run BEFORE `import gradio as gr` so the patched functions are used.
import gradio_client.utils as _gcu  # noqa: E402

_orig_get_type = _gcu.get_type


def _safe_get_type(schema):
    """Return "any" for boolean JSON schemas, otherwise defer to gradio_client."""
    if isinstance(schema, bool):
        return "any"
    return _orig_get_type(schema)


_gcu.get_type = _safe_get_type

_orig_json_schema = _gcu._json_schema_to_python_type


def _safe_json_schema(schema, defs=None):
    """Return "any" for boolean JSON schemas, otherwise defer to gradio_client."""
    if isinstance(schema, bool):
        return "any"
    return _orig_json_schema(schema, defs)


_gcu._json_schema_to_python_type = _safe_json_schema

import spaces  # noqa: E402
import gradio as gr  # noqa: E402
import torch  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from diffusers import ChromaPipeline  # noqa: E402

# Pre-load ONLY Chroma (not LLMs, to support custom models)
print("Loading Chroma1-HD...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device at module level: {device}")
chroma_pipe = ChromaPipeline.from_pretrained(
    "lodestones/Chroma1-HD",
    torch_dtype=torch.bfloat16,
)
chroma_pipe = chroma_pipe.to(device)
print("✓ Chroma1-HD ready")

# Per-model UI/generation defaults:
#   "system"             default system prompt
#   "examples"           prefilled instructions (first one seeds the textbox)
#   "supports_image_gen" flag (not read here; image availability is decided
#                        in check_image_availability)
#   "sd_temp"/"sd_top_p" sampling overrides for SD-metadata requests
#   "branch"             HF revision to load; None = default branch
MODEL_CONFIGS = {
    "Nekochu/Luminia-13B-v3": {
        "system": "",
        "examples": [
            "### Instruction:\nCreate stable diffusion metadata based on the given english description. Luminia\n\n### Input:\nfavorites and popular SFW",
            "### Instruction:\nProvide tips on stable diffusion to optimize low token prompts and enhance quality include prompt example.",
        ],
        "supports_image_gen": True,
        "sd_temp": 0.3,
        "sd_top_p": 0.8,
        "branch": None,  # Uses main/default branch
    },
    "Nekochu/Luminia-8B-v4-Chan": {
        "system": "write a response like a 4chan user",
        "examples": [],
        "supports_image_gen": False,
        "branch": "Llama-3-8B-4Chan_SD_QLoRa",
    },
    "Nekochu/Luminia-8B-RP": {
        "system": "You are a knowledgeable and empathetic mental health professional.",
        "examples": ["How to cope with anxiety?"],
        "supports_image_gen": False,
        "branch": None,
    },
}

DEFAULT_MODELS = list(MODEL_CONFIGS.keys())
models_cache = {}  # cache_key -> {"model", "tokenizer", "last_used"}
stop_event = Event()  # set to halt the running generation (Stop button / early exit)
current_thread = None  # most recently started generation thread
MAX_CACHE_SIZE = 2
DEFAULT_MODEL = DEFAULT_MODELS[0]

DEFAULT_SD_PROMPT = '(masterpiece, best quality), 1girl'
# First generation-parameter label; text before it is prompt text.
_SD_PARAM_START = re.compile(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)')


def parse_model_id(model_id_str):
    """Split a dropdown value into ``(model_id, branch)``.

    Accepts ``"model_id"`` or ``"model_id:branch"``. A known model without an
    explicit branch uses its configured ``"branch"``; otherwise branch is None.
    Surrounding whitespace is ignored and None is treated as an empty string.
    """
    model_id_str = (model_id_str or "").strip()
    if ':' in model_id_str:
        model_id, branch = model_id_str.split(':', 1)
        return model_id, branch
    if model_id_str in MODEL_CONFIGS:
        return model_id_str, MODEL_CONFIGS[model_id_str].get('branch', None)
    return model_id_str, None


def _clamp(value, low, high):
    """Return *value* limited to the inclusive range [low, high]."""
    return min(max(value, low), high)


def _parse_groups(pattern, text, convert):
    """Return ``convert`` applied to every group of the first match of *pattern*.

    Returns None when there is no match or a group is malformed (ValueError,
    e.g. "7.5." passed to float).
    """
    match = re.search(pattern, text)
    if match is None:
        return None
    try:
        return [convert(group) for group in match.groups()]
    except ValueError:
        return None


def parse_sd_metadata(text: str):
    """Extract Chroma generation settings from A1111-style metadata text.

    Expected shape (every part optional)::

        <prompt> Negative prompt: <negative> Steps: N, Sampler: ..., CFG scale: X, Seed: S, Size: WxH

    Returns a dict with keys prompt, negative_prompt, steps, cfg_scale, seed,
    width and height. Prompt/negative are cut to 500/300 characters, steps are
    clamped to 1..30 (0 steps cannot produce an image), width/height to
    512..1536 and seed is reduced modulo 2**32. Missing or malformed values keep
    their defaults; an empty prompt falls back to DEFAULT_SD_PROMPT.
    """
    metadata = {
        'prompt': '',
        'negative_prompt': '',
        'steps': 25,
        'cfg_scale': 7.0,
        'seed': 42,
        'width': 1024,
        'height': 1024,
    }
    if not text:
        metadata['prompt'] = DEFAULT_SD_PROMPT
        return metadata

    if "Negative prompt:" in text:
        positive, negative = text.split("Negative prompt:", 1)
        metadata['prompt'] = positive.strip().rstrip('.,;')[:500]
        param_match = _SD_PARAM_START.search(negative)
        if param_match:
            negative = negative[:param_match.start()]
        metadata['negative_prompt'] = negative.strip().rstrip('.,;')[:300]
    else:
        param_match = _SD_PARAM_START.search(text)
        if param_match:
            metadata['prompt'] = text[:param_match.start()].strip().rstrip('.,;')[:500]
        else:
            metadata['prompt'] = text.strip()[:500]

    steps = _parse_groups(r'Steps:\s*(\d+)', text, int)
    if steps:
        metadata['steps'] = _clamp(steps[0], 1, 30)
    cfg = _parse_groups(r'CFG scale:\s*([\d.]+)', text, float)
    if cfg:
        metadata['cfg_scale'] = cfg[0]
    seed = _parse_groups(r'Seed:\s*(\d+)', text, int)
    if seed:
        metadata['seed'] = seed[0] % (2**32)
    size = _parse_groups(r'Size:\s*(\d+)x(\d+)', text, int)
    if size:
        width, height = size
        metadata['width'] = _clamp(width, 512, 1536)
        metadata['height'] = _clamp(height, 512, 1536)

    if not metadata['prompt']:
        metadata['prompt'] = DEFAULT_SD_PROMPT
    return metadata


def clear_old_cache():
    """Evict the least-recently-used model once the cache is full and free its memory."""
    if len(models_cache) < MAX_CACHE_SIZE:
        return
    # Select by key only: holding the (key, entry) pair from .items() would keep
    # the evicted model referenced through gc.collect()/empty_cache() below.
    oldest_key = min(models_cache, key=lambda key: models_cache[key].get('last_used', 0))
    del models_cache[oldest_key]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once *event* is set."""

    def __init__(self, event):
        super().__init__()
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One "done" flag per sequence in the batch.
        return torch.full(
            (input_ids.shape[0],), self.event.is_set(), dtype=torch.bool, device=input_ids.device
        )


def _pick_attn_implementation():
    """Use flash-attention 2 only when CUDA is present AND flash_attn is installed.

    None lets transformers choose its default implementation.
    """
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return None


def _load_model(model_id, branch):
    """Load the tokenizer and a 4-bit NF4 model for *model_id* at revision *branch*."""
    load_kwargs = {"trust_remote_code": True}
    if branch:
        load_kwargs["revision"] = branch
        print(f"Loading from branch: {branch}")

    tokenizer = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
    tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    # Over-long prompts must lose their OLDEST text: the end holds the current
    # instruction and the "### Response:" cue the model continues from.
    tokenizer.truncation_side = "left"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
        "attn_implementation": _pick_attn_implementation(),
        "low_cpu_mem_usage": True,
    }
    if branch:
        model_kwargs["revision"] = branch
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    return model, tokenizer


def _build_prompt(system, history, message):
    """Assemble an Alpaca-style "### Instruction: / ### Response:" prompt.

    Turns already containing "### Instruction:" are used verbatim, others are
    wrapped. The prompt ends with an open "### Response:" for the model to
    fill, unless *message* already contains both markers.
    """
    parts = [f"{system}\n\n"] if system else []
    for user_msg, assistant_msg in history:
        if "### Instruction:" in user_msg:
            parts.append(f"{user_msg}\n### Response:\n{assistant_msg}\n\n")
        else:
            parts.append(f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}\n\n")

    if "### Instruction:" in message and "### Response:" not in message:
        parts.append(f"{message}\n### Response:\n")
    elif "### Instruction:" not in message:
        parts.append(f"### Instruction:\n{message}\n\n### Response:\n")
    else:
        parts.append(message)
    return "".join(parts)


@spaces.GPU(duration=119)
def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
    """Stream a chat reply from the selected model (``model_id`` or ``model_id:branch``).

    Yields ``(chat_history, status_markdown)`` pairs: a loading notice while a
    model is fetched, partial replies with a throughput readout while tokens
    stream, and a final summary or error message.

    NOTE(review): on ZeroGPU this function may run in a separate worker
    process; if so, stop_generation()'s stop_event.set() in the UI process
    may not reach it — confirm.
    """
    global current_thread
    stop_event.clear()

    model_id, branch = parse_model_id(model_id_str)  # Parse model ID and branch
    cache_key = f"{model_id}:{branch}" if branch else model_id
    config = MODEL_CONFIGS.get(model_id, {})

    # SD-metadata requests to the 13B model use more deterministic sampling.
    if "Luminia-13B-v3" in model_id and ("stable diffusion" in message.lower() or "metadata" in message.lower()):
        temp = config.get('sd_temp', 0.3)
        top_p = config.get('sd_top_p', 0.8)
        print(f"Using SD settings: temp={temp}, top_p={top_p}")

    if cache_key not in models_cache:
        clear_old_cache()
        try:
            yield history + [[message, f"📥 Loading {model_id}{f' ({branch})' if branch else ''}..."]], "Loading..."
            model, tokenizer = _load_model(model_id, branch)
            models_cache[cache_key] = {
                "model": model,
                "tokenizer": tokenizer,
                "last_used": time.time(),
            }
        except Exception as e:  # UI boundary: report any load failure in the chat
            traceback.print_exc()
            yield history + [[message, f"❌ Failed: {str(e)[:200]}"]], "Error"
            return

    entry = models_cache[cache_key]
    entry['last_used'] = time.time()
    model = entry["model"]
    tokenizer = entry["tokenizer"]

    prompt = _build_prompt(system, history, message)
    print(f"Prompt ending: ...{prompt[-200:]}")

    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        input_tokens = inputs['input_ids'].shape[1]
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    except Exception as e:
        yield history + [[message, f"❌ Tokenization failed: {str(e)}"]], "Error"
        return

    print(f"📝 {input_tokens} tokens | Temp: {temp} | Top-p: {top_p}")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5)
    gen_kwargs = {
        **inputs,
        "streamer": streamer,
        # Slider values may arrive as floats; top_k / max_new_tokens must be ints.
        "max_new_tokens": min(int(max_tokens), 2048),
        "temperature": max(temp, 0.01),
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_penalty,
        "do_sample": temp > 0.01,
        "pad_token_id": tokenizer.pad_token_id,
        # Lets stop_event halt generate() itself, not just the reader loop below.
        "stopping_criteria": StoppingCriteriaList([_StopOnEvent(stop_event)]),
    }

    gen_errors = []

    def _run_generate():
        """Run generate(); on failure record the error and unblock the streamer."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            gen_errors.append(exc)
            streamer.end()  # wake the reader instead of waiting for the timeout

    thread = Thread(target=_run_generate)
    current_thread = thread
    thread.start()

    start_time = time.time()
    partial = ""
    token_count = 0
    timed_out = False
    try:
        for text in streamer:
            if stop_event.is_set():
                break
            partial += text
            token_count = len(tokenizer.encode(partial, add_special_tokens=False))
            elapsed = time.time() - start_time
            if elapsed > 0:
                yield history + [[message, partial]], f"⚡ {token_count} @ {token_count/elapsed:.1f} t/s"
    except Empty:
        # No new text within the streamer timeout.
        timed_out = True
    finally:
        # Runs on normal end, stop, timeout and generator close (GeneratorExit
        # is not caught above, so it propagates after this cleanup).
        if thread.is_alive():
            stop_event.set()
            thread.join(timeout=2)

    final_time = time.time() - start_time
    if gen_errors:
        yield history + [[message, partial]], f"❌ Generation failed: {str(gen_errors[0])[:200]}"
        return
    note = " (stream timed out)" if timed_out else ""
    yield history + [[message, partial]], f"✅ {token_count} tokens in {final_time:.1f}s{note}"


@spaces.GPU()
def generate_image_gpu(text_output):
    """Render an image with the pre-loaded Chroma pipeline from SD-style metadata text.

    Returns ``(image or None, status markdown, update hiding the loading indicator)``.
    """
    global chroma_pipe
    if not text_output or text_output.isspace():
        return None, "❌ No valid text", gr.update(visible=False)

    try:
        metadata = parse_sd_metadata(text_output)
        print(f"Generating: {metadata['width']}x{metadata['height']} | Steps: {metadata['steps']}")

        if torch.cuda.is_available():
            chroma_pipe = chroma_pipe.to("cuda")
        generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(metadata['seed'])

        image = chroma_pipe(
            prompt=metadata['prompt'],
            negative_prompt=metadata['negative_prompt'],
            generator=generator,
            num_inference_steps=metadata['steps'],
            guidance_scale=metadata['cfg_scale'],
            width=metadata['width'],
            height=metadata['height'],
        ).images[0]

        status = f"✅ {metadata['width']}x{metadata['height']} | {metadata['steps']} steps | CFG: {metadata['cfg_scale']} | Seed: {metadata['seed']}"
        return image, status, gr.update(visible=False)
    except Exception as e:  # UI boundary: show the failure instead of crashing the event
        traceback.print_exc()
        return None, f"❌ Failed: {str(e)[:200]}", gr.update(visible=False)


def stop_generation():
    """Signal the running generation to stop; show ▶ again and hide ⏹."""
    stop_event.set()
    thread = current_thread
    if thread is not None and thread.is_alive():
        thread.join(timeout=2)
    return gr.update(visible=True), gr.update(visible=False)


css = """
#chatbot {height: 305px;}
#input-row {display: flex; gap: 4px;}
#input-box {flex-grow: 1;}
#button-group {display: inline-flex; flex-direction: column; gap: 2px; width: 45px;}
#button-group button {width: 40px; height: 28px; padding: 2px; font-size: 14px;}
#status {font-size: 11px; color: #666; margin-top: 2px;}
#image-output {max-height: 400px; margin-top: 8px;}
#img-loading {font-size: 11px; color: #666; margin: 4px 0;}
"""

_default_examples = MODEL_CONFIGS[DEFAULT_MODEL]["examples"]

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(value=[], elem_id="chatbot", type="tuples")
            with gr.Row(elem_id="input-row"):
                msg = gr.Textbox(
                    label="Instruction",
                    lines=3,
                    elem_id="input-box",
                    value=_default_examples[0] if _default_examples else "",
                    scale=10,
                )
                with gr.Column(elem_id="button-group", scale=1, min_width=45):
                    submit = gr.Button("▶", variant="primary", size="sm")
                    stop = gr.Button("⏹", variant="stop", size="sm", visible=False)
                    undo = gr.Button("↩", size="sm")
                    clear = gr.Button("🗑", size="sm")
            status = gr.Markdown("", elem_id="status")
            with gr.Row():
                image_btn = gr.Button("🎨 Generate Image using Chroma1-HD", visible=False, variant="secondary")
            last_text = gr.Textbox(visible=False)
            img_loading = gr.Markdown("", visible=False, elem_id="img-loading")
            image_output = gr.Image(visible=False, elem_id="image-output")
            image_status = gr.Markdown("", visible=False)
            examples = gr.Examples(
                examples=[[ex] for ex in _default_examples if ex],
                inputs=msg,
                label="Examples",
            )
        with gr.Column(scale=1):
            model = gr.Dropdown(
                DEFAULT_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                allow_custom_value=True,
                info="Custom HF ID + optional :branch",
            )
            with gr.Accordion("Settings", open=False):
                system = gr.Textbox(
                    label="System Prompt",
                    value=MODEL_CONFIGS[DEFAULT_MODEL]["system"],
                    lines=2,
                )
                temp = gr.Slider(0.1, 1.0, 0.35, label="Temperature")
                top_p = gr.Slider(0.5, 1.0, 0.85, label="Top-p")
                top_k = gr.Slider(10, 100, 40, label="Top-k")
                rep_penalty = gr.Slider(1.0, 1.5, 1.1, label="Repetition Penalty")
                max_tokens = gr.Slider(256, 2048, 1024, label="Max Tokens")
            export_btn = gr.Button("💾 Export", size="sm")
            export_file = gr.File(visible=False)

    def update_ui_on_model_change(model_id_str):
        """Reset prompt fields to the new model's defaults and clear/hide image widgets."""
        model_id, _ = parse_model_id(model_id_str)
        config = MODEL_CONFIGS.get(model_id, {"system": "", "examples": [""], "supports_image_gen": False})
        return (
            config["system"],
            config["examples"][0] if config["examples"] else "",
            gr.update(visible=False),  # image_btn
            "",  # last_text
            gr.update(value=None, visible=False),  # image_output: clear and hide
            gr.update(value="", visible=False),  # image_status: clear and hide
            gr.update(visible=False),  # img_loading
        )

    def check_image_availability(model_id_str, history):
        """Offer image generation (and stash the last reply) only for the 13B model."""
        model_id, _ = parse_model_id(model_id_str)
        if "Luminia-13B-v3" in model_id and history:
            return gr.update(visible=True), history[-1][1]
        return gr.update(visible=False), ""

    submit.click(
        lambda: (gr.update(visible=False), gr.update(visible=True)), None, [submit, stop]
    ).then(
        generate_text_gpu,
        [model, msg, chatbot, system, temp, top_p, top_k, max_tokens, rep_penalty],
        [chatbot, status],
    ).then(
        lambda: (gr.update(visible=True), gr.update(visible=False)), None, [submit, stop]
    ).then(
        check_image_availability, [model, chatbot], [image_btn, last_text]
    )

    stop.click(stop_generation, None, [submit, stop])

    image_btn.click(
        lambda: gr.update(value="🎨 Generating...", visible=True), None, img_loading
    ).then(
        generate_image_gpu, last_text, [image_output, image_status, img_loading]
    ).then(
        lambda img: (gr.update(visible=img is not None), gr.update(visible=True)),
        image_output,
        [image_output, image_status],
    )

    # Each component appears once; value and visibility travel in one update.
    model.change(
        update_ui_on_model_change,
        model,
        [system, msg, image_btn, last_text, image_output, image_status, img_loading],
    )

    undo.click(
        lambda h: h[:-1] if h else h, chatbot, chatbot
    ).then(
        check_image_availability, [model, chatbot], [image_btn, last_text]
    )

    clear.click(
        lambda: (
            [],  # chatbot
            "",  # msg
            "",  # status
            gr.update(value=None, visible=False),  # image_output
            gr.update(value="", visible=False),  # image_status
            gr.update(visible=False),  # image_btn
            "",  # last_text
            gr.update(visible=False),  # img_loading
        ),
        None,
        [chatbot, msg, status, image_output, image_status, image_btn, last_text, img_loading],
    )

    def export_chat(history):
        """Write the conversation to a uniquely named UTF-8 text file and return its path."""
        if not history:
            return None
        content = "\n\n".join(f"User: {u}\n\nAssistant: {a}" for u, a in history)
        path = f"chat_{uuid.uuid4().hex[:8]}.txt"
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        return path

    export_btn.click(export_chat, chatbot, export_file).then(
        lambda: gr.update(visible=True), None, export_file
    )

demo.queue().launch(show_api=False)