signsur4739379373 commited on
Commit
428cf46
·
verified ·
1 Parent(s): e3a6c9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +440 -306
app.py CHANGED
@@ -1,340 +1,474 @@
1
  import gradio as gr
2
- import os
3
- import subprocess
4
- import shutil
5
- import json
6
- import time
7
- from pathlib import Path
8
  import torch
9
  import spaces
10
- from diffusers import DiffusionPipeline
11
-
12
- # ==========================================
13
- # 1. SETUP & GLOBAL VARS
14
- # ==========================================
15
-
16
- DATASET_DIR = Path("./datasets")
17
- OUTPUT_DIR = Path("./output")
18
- DATASET_DIR.mkdir(exist_ok=True)
19
- OUTPUT_DIR.mkdir(exist_ok=True)
20
-
21
- # global tracking for loras
22
- # key: friendly name, value: path
23
- AVAILABLE_LORAS = {}
24
-
25
- print("loading z-image-turbo pipeline...")
26
- pipe = DiffusionPipeline.from_pretrained(
27
- "Tongyi-MAI/Z-Image-Turbo",
28
- torch_dtype=torch.bfloat16,
29
- low_cpu_mem_usage=False,
30
- )
31
- pipe.to("cuda")
32
- print("pipeline loaded!")
33
 
34
- # ==========================================
35
- # 2. TRAINING LOGIC
36
- # ==========================================
 
 
 
37
 
38
def check_gpu():
    """Report whether torch can see a CUDA device.

    Returns:
        A status string naming CUDA device 0 when CUDA is available,
        otherwise a warning string.
    """
    if not torch.cuda.is_available():
        return "⚠️ no gpu detected"
    device_name = torch.cuda.get_device_name(0)
    return f"✅ gpu available: {device_name}"
42
 
43
def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Copy uploaded images into a dataset folder and write one caption per image.

    Args:
        files: Uploaded files; each item is either a path (``str``/``Path``)
            or an object exposing its path as ``.name``.
        dataset_name: Folder name to create under ``DATASET_DIR``. Only the
            final path component is used; a timestamped name is generated
            when it is empty.
        trigger_word: Caption text written next to every image; "a photo" is
            used when it is empty.

    Returns:
        ``(status_message, dataset_path_or_None, dataset_name)``.
    """
    if not files:
        return "❌ upload images first", None, ""

    # The name comes from a free-text UI field: keep only the last path
    # component so values such as "../x" cannot escape DATASET_DIR.
    safe_name = Path(dataset_name).name if dataset_name else ""
    if safe_name in ("", ".", ".."):
        safe_name = f"dataset_{int(time.time())}"
    dataset_name = safe_name

    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    image_suffixes = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
    # The caption is identical for every image, so compute it once.
    caption_text = trigger_word if trigger_word else "a photo"

    image_count = 0
    for file in files:
        # Accept both plain paths and upload objects that carry the path in .name.
        source = Path(file if isinstance(file, (str, Path)) else file.name)
        if not str(source).lower().endswith(image_suffixes):
            continue

        dest = dataset_path / source.name
        shutil.copy(source, dest)

        # Explicit encoding: the trigger word may contain non-ASCII characters.
        dest.with_suffix('.txt').write_text(caption_text, encoding="utf-8")
        image_count += 1

    if image_count == 0:
        return "❌ no valid images found", None, ""

    return f"✅ ready: {image_count} images in {dataset_name}", str(dataset_path), dataset_name
71
-
72
# Requests a 200-second GPU allocation per call (duration is given in seconds).
# NOTE(review): the training subprocess below is allowed 3500 s, far longer
# than this allocation - confirm which of the two values is intended.
@spaces.GPU(duration=200)
def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train a LoRA with ostris/ai-toolkit and register the resulting weights.

    Writes a training config under ``OUTPUT_DIR/<project_name>``, clones and
    installs ai-toolkit on first use, runs its ``run.py`` in a subprocess and
    records the produced ``.safetensors`` file in ``AVAILABLE_LORAS``.

    Args:
        dataset_path: Folder holding the images and ``.txt`` captions.
        project_name: Output folder / LoRA name; only the final path component
            is used, and a timestamped name is generated when it is empty.
        trigger_word: Trigger word forwarded to the trainer config.
        steps: Number of training steps (also used as the save interval).
        learning_rate: Optimizer learning rate.
        lora_rank: LoRA rank, used for both ``linear`` and ``linear_alpha``.
        resolution: Square training resolution in pixels.
        progress: Gradio progress callback.

    Returns:
        ``(status_message, lora_file_path_or_None)``.
    """
    import sys  # local import: only needed to locate the running interpreter

    if not dataset_path:
        return "❌ no dataset", None

    # The name comes from a free-text UI field: keep only the last path
    # component so it cannot point outside OUTPUT_DIR.
    safe_name = Path(project_name).name if project_name else ""
    if safe_name in ("", ".", ".."):
        safe_name = f"lora_{int(time.time())}"
    project_name = safe_name

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    # config generation
    config = {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": int(lora_rank),
                    "linear_alpha": int(lora_rank),
                },
                "save": {
                    "dtype": "float16",
                    "save_every": int(steps),  # save only at end to save space
                    "max_step_saves_to_keep": 1,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [int(resolution), int(resolution)],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": int(steps),
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {"use_ema": True, "ema_decay": 0.99},
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
            }]
        }
    }

    config_path = output_path / "config.json"
    with open(config_path, 'w', encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # install ai-toolkit
    progress(0.1, desc="setting up environment...")
    toolkit_dir = Path("./ai-toolkit")
    if not toolkit_dir.exists():
        try:
            subprocess.run(["git", "clone", "https://github.com/ostris/ai-toolkit.git"], check=True)
            # Install into the interpreter that is running this app rather than
            # whatever "pip" happens to be first on PATH.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", "-r", "ai-toolkit/requirements.txt"],
                check=True,
            )
        except Exception as e:
            # Drop a partial checkout; otherwise the existence check above
            # would skip the dependency install on every later call.
            shutil.rmtree(toolkit_dir, ignore_errors=True)
            return f" setup failed: {e}", None

    progress(0.2, desc="training (this takes time)...")

    try:
        # run training script
        # explicitly passing environment to ensure cuda visibility in subprocess
        env = os.environ.copy()

        proc = subprocess.run(
            [sys.executable, "ai-toolkit/run.py", str(config_path)],
            capture_output=True,
            text=True,
            env=env,
            timeout=3500
        )

        if proc.returncode != 0:
            return f"❌ training crashed:\n{proc.stderr}", None

        # find result
        lora_files = list(output_path.glob("*.safetensors"))
        if lora_files:
            # glob order is unspecified, so pick the most recently written file
            # instead of whichever happens to come last.
            lora_file = max(lora_files, key=lambda candidate: candidate.stat().st_mtime)
            AVAILABLE_LORAS[project_name] = str(lora_file)
            return f"✅ trained: {project_name}", str(lora_file)

        return "⚠️ finished but no safetensors found", None

    except Exception as e:
        # Includes subprocess.TimeoutExpired when training exceeds the timeout.
        return f" fatal error: {e}", None
187
-
188
- # ==========================================
189
- # 3. INFERENCE LOGIC
190
- # ==========================================
191
-
192
@spaces.GPU
def generate_image(
    prompt,
    height,
    width,
    steps,
    seed,
    randomize_seed,
    lora_path,
    lora_scale
):
    """Generate one image with the global pipeline, optionally with a LoRA loaded.

    Args:
        prompt: Text prompt.
        height: Output height in pixels (cast to int).
        width: Output width in pixels (cast to int).
        steps: Number of inference steps (cast to int).
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, a fresh random seed replaces ``seed``.
        lora_path: Path of a LoRA file to load; ignored when empty or missing.
        lora_scale: Accepted for interface compatibility; not used by this
            function.

    Returns:
        ``(image, seed_used)``.
    """
    # Start every call from the bare pipeline so an earlier LoRA is dropped.
    pipe.unload_lora_weights()

    lora_requested = bool(lora_path) and os.path.exists(lora_path)
    if lora_requested:
        print(f"loading lora: {lora_path}")
        try:
            # Loaded as-is; no explicit scaling is applied here.
            pipe.load_lora_weights(lora_path)
        except Exception as load_error:
            # Best effort: a broken LoRA falls back to the base model.
            print(f"lora load failed: {load_error}")

    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()

    rng = torch.Generator("cuda").manual_seed(int(seed))

    call_kwargs = dict(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=0.0,
        generator=rng,
    )
    result = pipe(**call_kwargs)

    return result.images[0], seed
231
 
232
def update_lora_list():
    """Rebuild the LoRA dropdown from the LoRAs registered in ``AVAILABLE_LORAS``.

    Returns:
        A ``gr.Dropdown`` whose choices are a leading ``("None", None)`` entry
        followed by one ``(name, path)`` pair per registered LoRA.
    """
    dropdown_choices = [("None", None)]
    dropdown_choices.extend(AVAILABLE_LORAS.items())
    return gr.Dropdown(choices=dropdown_choices)
236
 
237
- # ==========================================
238
- # 4. UI CONSTRUCTION
239
- # ==========================================
240
 
241
- custom_theme = gr.themes.Soft(primary_hue="yellow", secondary_hue="slate")
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
- with gr.Blocks(theme=custom_theme, title="Z-Image ZeroGPU Trainer") as demo:
244
-
245
- gr.Markdown("# ⚡ Z-Image-Turbo: Train & Test")
246
-
247
- with gr.Tabs():
248
-
249
- # TAB 1: INFERENCE
250
- with gr.Tab("🎨 Generate"):
251
- with gr.Row():
252
- with gr.Column():
253
- prompt_input = gr.Textbox(label="Prompt", lines=3)
254
-
255
- with gr.Row():
256
- lora_selector = gr.Dropdown(
257
- label="Select LoRA",
258
- choices=[("None", None)],
259
- value=None,
260
- interactive=True
261
- )
262
- refresh_btn = gr.Button("🔄", size="sm", scale=0)
263
-
264
- with gr.Accordion("Settings", open=False):
265
- h_slider = gr.Slider(512, 2048, 1024, step=64, label="Height")
266
- w_slider = gr.Slider(512, 2048, 1024, step=64, label="Width")
267
- steps_slider = gr.Slider(1, 50, 9, step=1, label="Steps")
268
- seed_num = gr.Number(42, label="Seed")
269
- rand_seed = gr.Checkbox(True, label="Randomize Seed")
270
-
271
- gen_btn = gr.Button("Generate", variant="primary")
272
-
273
- with gr.Column():
274
- out_img = gr.Image(label="Result")
275
- out_seed = gr.Number(label="Seed Used")
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
- # TAB 2: TRAINING
278
- with gr.Tab("🏋️ Train LoRA"):
279
- gr.Markdown("⚠️ **Note:** Requires paid GPU space for long timeouts.")
280
-
281
  with gr.Row():
282
- with gr.Column():
283
- train_files = gr.Files(label="Images", file_types=["image"])
284
- train_name = gr.Textbox(label="Project Name", value="my_lora")
285
- train_trigger = gr.Textbox(label="Trigger Word", value="ohwx")
286
-
287
- # hidden state for dataset path
288
- dataset_path_state = gr.State()
289
-
290
- upload_btn = gr.Button("1. Process Dataset")
291
- upload_status = gr.Textbox(label="Dataset Status")
292
-
293
- gr.Markdown("---")
294
-
295
- train_steps = gr.Slider(100, 2000, 500, step=100, label="Steps")
296
- train_lr = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Learning Rate")
297
- train_rank = gr.Slider(4, 128, 16, step=4, label="Rank")
298
-
299
- start_train_btn = gr.Button("2. Start Training", variant="stop")
300
 
301
- with gr.Column():
302
- train_log = gr.Textbox(label="Training Log", lines=10)
303
- lora_file_download = gr.File(label="Download LoRA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
- # WIRING
306
-
307
- # Refresh LoRA list
308
- refresh_btn.click(update_lora_list, outputs=lora_selector)
309
-
310
- # Upload
311
- upload_btn.click(
312
- upload_and_prepare_dataset,
313
- [train_files, train_name, train_trigger],
314
- [upload_status, dataset_path_state, train_name]
315
- )
316
-
317
- # Train
318
- def on_train_complete(status, file_path):
319
- # Update available loras list immediately after training
320
- new_choices = [("None", None)] + [(k, v) for k, v in AVAILABLE_LORAS.items()]
321
- return status, file_path, gr.Dropdown(choices=new_choices)
322
-
323
- start_train_btn.click(
324
- train_lora,
325
- [dataset_path_state, train_name, train_trigger, train_steps, train_lr, train_rank, h_slider], # reusing h_slider for res
326
- [train_log, lora_file_download]
327
- ).then(
328
- update_lora_list,
329
- outputs=[lora_selector]
330
  )
331
-
332
- # Generate
333
- gen_btn.click(
334
- generate_image,
335
- [prompt_input, h_slider, w_slider, steps_slider, seed_num, rand_seed, lora_selector, train_lr], # train_lr dummy
336
- [out_img, out_seed]
337
  )
338
 
339
  if __name__ == "__main__":
340
- demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import random
 
 
 
 
4
  import torch
5
  import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ from PIL import Image
8
+ from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
9
+ # from optimization import optimize_pipeline_
10
+ # from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
11
+ # from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
12
+ # from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
13
 
14
+ from huggingface_hub import InferenceClient
15
+ import math
 
 
16
 
17
+ import os
18
+ import base64
19
+ from io import BytesIO
20
+ import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ SYSTEM_PROMPT = '''
23
+ # Edit Instruction Rewriter
24
+ You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.
25
+
26
+ Please strictly follow the rewriting rules below:
27
+
28
+ ## 1. General Principles
29
+ - Keep the rewritten prompt **concise and comprehensive**. Avoid overly long sentences and unnecessary descriptive language.
30
+ - If the instruction is contradictory, vague, or unachievable, prioritize reasonable inference and correction, and supplement details when necessary.
31
+ - Keep the main part of the original instruction unchanged, only enhancing its clarity, rationality, and visual feasibility.
32
+ - All added objects or modifications must align with the logic and style of the scene in the input images.
33
+ - If multiple sub-images are to be generated, describe the content of each sub-image individually.
34
+
35
+ ## 2. Task-Type Handling Rules
36
+
37
+ ### 1. Add, Delete, Replace Tasks
38
+ - If the instruction is clear (already includes task type, target entity, position, quantity, attributes), preserve the original intent and only refine the grammar.
39
+ - If the description is vague, supplement with minimal but sufficient details (category, color, size, orientation, position, etc.). For example:
40
+ > Original: "Add an animal"
41
+ > Rewritten: "Add a light-gray cat in the bottom-right corner, sitting and facing the camera"
42
+ - Remove meaningless instructions: e.g., "Add 0 objects" should be ignored or flagged as invalid.
43
+ - For replacement tasks, specify "Replace Y with X" and briefly describe the key visual features of X.
44
+
45
+ ### 2. Text Editing Tasks
46
+ - All text content must be enclosed in English double quotes `" "`. Keep the original language of the text, and keep the capitalization.
47
+ - Both adding new text and replacing existing text are text replacement tasks, For example:
48
+ - Replace "xx" to "yy"
49
+ - Replace the mask / bounding box to "yy"
50
+ - Replace the visual object to "yy"
51
+ - Specify text position, color, and layout only if user has required.
52
+ - If font is specified, keep the original language of the font.
53
+
54
+ ### 3. Human Editing Tasks
55
+ - Make the smallest changes to the given user's prompt.
56
+ - If changes to background, action, expression, camera shot, or ambient lighting are required, please list each modification individually.
57
+ - **Edits to makeup or facial features / expression must be subtle, not exaggerated, and must preserve the subject's identity consistency.**
58
+ > Original: "Add eyebrows to the face"
59
+ > Rewritten: "Slightly thicken the person's eyebrows with little change, look natural."
60
+
61
+ ### 4. Style Conversion or Enhancement Tasks
62
+ - If a style is specified, describe it concisely using key visual features. For example:
63
+ > Original: "Disco style"
64
+ > Rewritten: "1970s disco style: flashing lights, disco ball, mirrored walls, vibrant colors"
65
+ - For style reference, analyze the original image and extract key characteristics (color, composition, texture, lighting, artistic style, etc.), integrating them into the instruction.
66
+ - **Colorization tasks (including old photo restoration) must use the fixed template:**
67
+ "Restore and colorize the old photo."
68
+ - Clearly specify the object to be modified. For example:
69
+ > Original: Modify the subject in Picture 1 to match the style of Picture 2.
70
+ > Rewritten: Change the girl in Picture 1 to the ink-wash style of Picture 2 — rendered in black-and-white watercolor with soft color transitions.
71
+
72
+ ### 5. Material Replacement
73
+ - Clearly specify the object and the material. For example: "Change the material of the apple to papercut style."
74
+ - For text material replacement, use the fixed template:
75
+ "Change the material of text "xxxx" to laser style"
76
+
77
+ ### 6. Logo/Pattern Editing
78
+ - Material replacement should preserve the original shape and structure as much as possible. For example:
79
+ > Original: "Convert to sapphire material"
80
+ > Rewritten: "Convert the main subject in the image to sapphire material, preserving similar shape and structure"
81
+ - When migrating logos/patterns to new scenes, ensure shape and structure consistency. For example:
82
+ > Original: "Migrate the logo in the image to a new scene"
83
+ > Rewritten: "Migrate the logo in the image to a new scene, preserving similar shape and structure"
84
+
85
+ ### 7. Multi-Image Tasks
86
+ - Rewritten prompts must clearly point out which image's element is being modified. For example:
87
+ > Original: "Replace the subject of picture 1 with the subject of picture 2"
88
+ > Rewritten: "Replace the girl of picture 1 with the boy of picture 2, keeping picture 2's background unchanged"
89
+ - For stylization tasks, describe the reference image's style in the rewritten prompt, while preserving the visual content of the source image.
90
+
91
+ ## 3. Rationale and Logic Check
92
+ - Resolve contradictory instructions: e.g., "Remove all trees but keep all trees" requires logical correction.
93
+ - Supplement missing critical information: e.g., if position is unspecified, choose a reasonable area based on composition (near subject, blank space, center/edge, etc.).
94
+
95
+ # Output Format Example
96
+ ```json
97
+ {
98
+ "Rewritten": "..."
99
+ }
100
+ '''
101
+
102
def _image_to_data_url(img):
    """Return a base64 ``data:`` URL for a PIL-like image or an image file path, else None."""
    if hasattr(img, 'save'):
        # Anything with .save() is treated as a PIL image and re-encoded as PNG.
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        payload = buffered.getvalue()
        mime = "image/png"
    elif isinstance(img, str):
        import mimetypes  # local import: only needed for path inputs

        with open(img, "rb") as image_file:
            payload = image_file.read()
        # Label the bytes with the file's real type instead of always PNG.
        mime = mimetypes.guess_type(img)[0] or "image/png"
    else:
        print(f"Warning: Unexpected image type: {type(img)}, skipping...")
        return None

    encoded = base64.b64encode(payload).decode('utf-8')
    return f"data:{mime};base64,{encoded}"


def _extract_rewritten_prompt(raw_reply):
    """Pull the "Rewritten" value out of the model reply, falling back to the reply text."""
    if '"Rewritten"' not in raw_reply:
        return raw_reply

    cleaned = raw_reply.replace('```json', '').replace('```', '')
    # The model may surround the JSON object with prose; parse only the
    # outermost {...} span when one is present.
    start, end = cleaned.find('{'), cleaned.rfind('}')
    candidate = cleaned[start:end + 1] if start != -1 and end > start else cleaned
    try:
        parsed = json.loads(candidate)
    except ValueError:
        return cleaned

    rewritten = parsed.get('Rewritten') if isinstance(parsed, dict) else None
    return rewritten if isinstance(rewritten, str) else cleaned


def polish_prompt_hf(original_prompt, img_list):
    """Rewrite an edit instruction with a vision-language model.

    The instruction and the input images are sent to
    ``Qwen/Qwen2.5-VL-72B-Instruct`` through ``InferenceClient``; the reply's
    "Rewritten" field (or the raw reply when no JSON is found) is returned on
    a single line.

    Args:
        original_prompt: The user's edit instruction.
        img_list: One image or a list/tuple of images; each is a PIL image or
            a file path. ``None`` sends the text only.

    Returns:
        The rewritten prompt, or ``original_prompt`` unchanged when the API key
        is missing, the call fails, or the reply is empty.
    """
    # The API key is read from the "inference_providers" environment variable.
    api_key = os.environ.get("inference_providers")
    if not api_key:
        print("Warning: 'inference_providers' environment variable not set. "
              "Falling back to original prompt.")
        return original_prompt

    prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
    system_prompt = "you are a helpful assistant, you should provide useful answers to users."
    try:
        # Initialize the client
        client = InferenceClient(
            provider="nebius",
            api_key=api_key,
        )

        # Convert the images to base64 data URLs
        image_urls = []
        if img_list is not None:
            # A single image is wrapped so the loop below handles both cases.
            if not isinstance(img_list, (list, tuple)):
                img_list = [img_list]
            for img in img_list:
                image_url = _image_to_data_url(img)
                if image_url:
                    image_urls.append(image_url)

        # Build the content array with text first, then all images
        content = [{"type": "text", "text": prompt}]
        content.extend(
            {"type": "image_url", "image_url": {"url": image_url}}
            for image_url in image_urls
        )

        # Format the messages for the chat completions API
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": content},
        ]

        # Call the API
        completion = client.chat.completions.create(
            model="Qwen/Qwen2.5-VL-72B-Instruct",
            messages=messages,
        )

        # Parse the response
        result = completion.choices[0].message.content or ""
        polished_prompt = _extract_rewritten_prompt(result).strip().replace("\n", " ")

        # An empty rewrite would silently discard the user's instruction.
        return polished_prompt if polished_prompt else original_prompt

    except Exception as e:
        print(f"Error during API call to Hugging Face: {e}")
        # Fallback to original prompt if enhancement fails
        return original_prompt
202
+
203
+
204
+
205
def encode_image(pil_image):
    """Serialize an image as PNG and return the bytes as base64 text.

    Args:
        pil_image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64-encoded PNG data as a ``str`` (no ``data:`` URL prefix).
    """
    import io

    with io.BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode("utf-8")
210
+
211
+ # --- Model Loading ---
212
+ dtype = torch.bfloat16
213
+ device = "cuda" if torch.cuda.is_available() else "cpu"
214
+
215
+ # Scheduler configuration for Lightning
216
+ scheduler_config = {
217
+ "base_image_seq_len": 256,
218
+ "base_shift": math.log(3),
219
+ "invert_sigmas": False,
220
+ "max_image_seq_len": 8192,
221
+ "max_shift": math.log(3),
222
+ "num_train_timesteps": 1000,
223
+ "shift": 1.0,
224
+ "shift_terminal": None,
225
+ "stochastic_sampling": False,
226
+ "time_shift_type": "exponential",
227
+ "use_beta_sigmas": False,
228
+ "use_dynamic_shifting": True,
229
+ "use_exponential_sigmas": False,
230
+ "use_karras_sigmas": False,
231
+ }
232
+
233
+ # Initialize scheduler with Lightning config
234
+ scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
235
+
236
+ # Load the model pipeline
237
+ pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2511",
238
+ scheduler=scheduler,
239
+ torch_dtype=dtype).to(device)
240
+ pipe.load_lora_weights(
241
+ "lightx2v/Qwen-Image-Edit-2511-Lightning",
242
+ weight_name="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors"
243
+ )
244
+ pipe.fuse_lora()
245
+
246
+ # # Apply the same optimizations from the first version
247
+ # pipe.transformer.__class__ = QwenImageTransformer2DModel
248
+ # pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
249
+
250
+ # # --- Ahead-of-time compilation ---
251
+ # optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
252
+
253
+ # --- UI Constants and Helpers ---
254
+ MAX_SEED = np.iinfo(np.int32).max
255
+
256
def use_output_as_input(output_images):
    """Forward the result gallery's images so they can fill the input gallery.

    Returns:
        ``output_images`` unchanged when it holds at least one item, otherwise
        an empty list.
    """
    has_images = output_images is not None and len(output_images) > 0
    return output_images if has_images else []
261
+
262
+ # --- Main Inference Function (with hardcoded negative prompt) ---
263
def _load_pil_images(images):
    """Convert gallery items to a list of RGB PIL images, skipping unreadable ones."""
    pil_images = []
    for item in images or []:
        # Entries are expected as (image_or_path, caption) pairs; bare images,
        # paths and file objects are accepted as well.
        source = item[0] if isinstance(item, (tuple, list)) else item
        try:
            if isinstance(source, Image.Image):
                pil_images.append(source.convert("RGB"))
            elif isinstance(source, str):
                # convert() returns a new image, so the file can be closed.
                with Image.open(source) as opened:
                    pil_images.append(opened.convert("RGB"))
            elif hasattr(source, "name"):
                with Image.open(source.name) as opened:
                    pil_images.append(opened.convert("RGB"))
            else:
                print(f"Warning: Unexpected image type: {type(source)}, skipping...")
        except Exception as load_error:
            # Best effort: one bad input must not abort the whole request,
            # but the failure should be visible in the logs.
            print(f"Warning: could not load input image, skipping: {load_error}")
    return pil_images


@spaces.GPU()
def infer(
    images,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=1.0,
    num_inference_steps=4,
    height=None,
    width=None,
    rewrite_prompt=True,
    num_images_per_prompt=1,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Run image-editing inference using the Qwen-Image-Edit pipeline.

    Parameters:
        images (list): Input images from the Gradio gallery (PIL or path-based).
        prompt (str): Editing instruction (may be rewritten by LLM if enabled).
        seed (int): Random seed for reproducibility.
        randomize_seed (bool): If True, overrides seed with a random value.
        true_guidance_scale (float): CFG scale used by Qwen-Image.
        num_inference_steps (int): Number of diffusion steps.
        height (int | None): Optional output height override.
        width (int | None): Optional output width override.
        rewrite_prompt (bool): Whether to rewrite the prompt using Qwen-2.5-VL.
        num_images_per_prompt (int): Number of images to generate.
        progress: Gradio progress callback.

    Returns:
        tuple: (generated_images, seed_used, UI_visibility_update)
    """

    # Hardcode the negative prompt as requested
    negative_prompt = " "

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI values may arrive as floats; manual_seed() only accepts an int.
    seed = int(seed)

    # Set up the generator for reproducibility
    generator = torch.Generator(device=device).manual_seed(seed)

    # Load input images into PIL Images
    pil_images = _load_pil_images(images)

    # Both sliders at their 256 minimum is treated as "unset", letting the
    # pipeline choose the output size itself.
    if height == 256 and width == 256:
        height, width = None, None
    else:
        height = int(height) if height is not None else None
        width = int(width) if width is not None else None
    num_inference_steps = int(num_inference_steps)

    print(f"Calling pipeline with prompt: '{prompt}'")
    print(f"Negative Prompt: '{negative_prompt}'")
    print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {width}x{height}")
    if rewrite_prompt and len(pil_images) > 0:
        prompt = polish_prompt_hf(prompt, pil_images)
        print(f"Rewritten Prompt: {prompt}")

    # Generate the image
    image = pipe(
        image=pil_images if len(pil_images) > 0 else None,
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
    ).images

    # Return images, seed, and make button visible
    return image, seed, gr.update(visible=True)
 
 
345
 
346
+ # --- Examples and UI Layout ---
347
+ examples = []
 
348
 
349
+ css = """
350
+ #col-container {
351
+ margin: 0 auto;
352
+ max-width: 1024px;
353
+ }
354
+ #logo-title {
355
+ text-align: center;
356
+ }
357
+ #logo-title img {
358
+ width: 400px;
359
+ }
360
+ #edit_text{margin-top: -62px !important}
361
+ """
362
 
363
+ with gr.Blocks(css=css) as demo:
364
+ with gr.Column(elem_id="col-container"):
365
+ gr.HTML("""
366
+ <div id="logo-title">
367
+ <img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_edit_logo.png" alt="Qwen-Image Edit Logo" width="400" style="display: block; margin: 0 auto;">
368
+ <h2 style="font-style: italic;color: #5b47d1;margin-top: -27px !important;margin-left: 96px">[Plus] Fast, 4-steps with LightX2V LoRA</h2>
369
+ </div>
370
+ """)
371
+ gr.Markdown("""
372
+ [Learn more](https://github.com/QwenLM/Qwen-Image) about the Qwen-Image series.
373
+ This demo uses the new [Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) with the [Qwen-Image-Lightning-2511](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning) LoRA for accelerated inference.
374
+ Try on [Qwen Chat](https://chat.qwen.ai/), or [download model](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) to run locally with ComfyUI or diffusers.
375
+ """)
376
+ with gr.Row():
377
+ with gr.Column():
378
+ input_images = gr.Gallery(label="Input Images",
379
+ show_label=False,
380
+ type="pil",
381
+ interactive=True)
382
+
383
+ with gr.Column():
384
+ result = gr.Gallery(label="Result", show_label=False, type="pil", interactive=False)
385
+ # Add this button right after the result gallery - initially hidden
386
+ use_output_btn = gr.Button("↗️ Use as input", variant="secondary", size="sm", visible=False)
387
+
388
+ with gr.Row():
389
+ prompt = gr.Text(
390
+ label="Prompt",
391
+ show_label=False,
392
+ placeholder="describe the edit instruction",
393
+ container=False,
394
+ )
395
+ run_button = gr.Button("Edit!", variant="primary")
396
+
397
+ with gr.Accordion("Advanced Settings", open=False):
398
+ # Negative prompt UI element is removed here
399
+
400
+ seed = gr.Slider(
401
+ label="Seed",
402
+ minimum=0,
403
+ maximum=MAX_SEED,
404
+ step=1,
405
+ value=0,
406
+ )
407
+
408
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
409
 
 
 
 
 
410
  with gr.Row():
411
+
412
+ true_guidance_scale = gr.Slider(
413
+ label="True guidance scale",
414
+ minimum=1.0,
415
+ maximum=10.0,
416
+ step=0.1,
417
+ value=1.0
418
+ )
419
+
420
+ num_inference_steps = gr.Slider(
421
+ label="Number of inference steps",
422
+ minimum=1,
423
+ maximum=40,
424
+ step=1,
425
+ value=4,
426
+ )
 
 
427
 
428
+ height = gr.Slider(
429
+ label="Height",
430
+ minimum=256,
431
+ maximum=2048,
432
+ step=8,
433
+ value=None,
434
+ )
435
+
436
+ width = gr.Slider(
437
+ label="Width",
438
+ minimum=256,
439
+ maximum=2048,
440
+ step=8,
441
+ value=None,
442
+ )
443
+
444
+
445
+ rewrite_prompt = gr.Checkbox(label="Rewrite prompt", value=True)
446
 
447
+ # gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=False)
448
+
449
+ gr.on(
450
+ triggers=[run_button.click, prompt.submit],
451
+ fn=infer,
452
+ inputs=[
453
+ input_images,
454
+ prompt,
455
+ seed,
456
+ randomize_seed,
457
+ true_guidance_scale,
458
+ num_inference_steps,
459
+ height,
460
+ width,
461
+ rewrite_prompt,
462
+ ],
463
+ outputs=[result, seed, use_output_btn], # Added use_output_btn to outputs
 
 
 
 
 
 
 
 
464
  )
465
+
466
+ # Add the new event handler for the "Use Output as Input" button
467
+ use_output_btn.click(
468
+ fn=use_output_as_input,
469
+ inputs=[result],
470
+ outputs=[input_images]
471
  )
472
 
473
  if __name__ == "__main__":
474
+ demo.launch(mcp_server=True)