multimodalart HF Staff committed on
Commit
4bec2d3
·
verified ·
1 Parent(s): a65d8ba

Update gradio_app_streaming.py

Browse files
Files changed (1) hide show
  1. gradio_app_streaming.py +43 -14
gradio_app_streaming.py CHANGED
@@ -84,21 +84,11 @@ def _save_chunk_audio_to_wav(audio_array, wav_path, sample_rate=16000):
84
  wav_file.writeframes(samples.tobytes())
85
  return wav_path
86
 
87
- ckpt_dir = "models/SoulX-FlashHead-1_3B"
88
- wav2vec_dir = "models/wav2vec2-base-960h"
89
- model_type = "lite"
90
- pipeline = get_pipeline(
91
- world_size=1,
92
- ckpt_dir=ckpt_dir,
93
- model_type=model_type,
94
- wav2vec_dir=wav2vec_dir,
95
- )
96
- loaded_ckpt_dir = ckpt_dir
97
- loaded_wav2vec_dir = wav2vec_dir
98
- loaded_model_type = model_type
99
-
100
  @spaces.GPU
101
  def run_inference_streaming(
 
 
 
102
  cond_image,
103
  audio_path,
104
  seed,
@@ -109,7 +99,30 @@ def run_inference_streaming(
109
  流式推理:主程序监控 res_queue,有 frames 就保存并 yield;
110
  推理在独立线程中执行,按 chunk 顺序 infer,结果放入 res_queue。
111
  """
112
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  progress(0.5, desc="Preparing Data...")
114
  base_seed = int(seed) if seed >= 0 else 9999
115
  try:
@@ -285,6 +298,19 @@ with gr.Blocks(title="SoulX-FlashHead 流式视频生成", theme=gr.themes.Soft(
285
  )
286
  generate_btn = gr.Button("🚀 流式生成视频", variant="primary", size="lg")
287
  with gr.Accordion("⚙️ 高级设置", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  use_face_crop_input = gr.Checkbox(label="Use Face Crop", value=False)
289
  seed_input = gr.Number(label="Random Seed", value=9999, precision=0)
290
  with gr.Column(scale=1):
@@ -300,6 +326,9 @@ with gr.Blocks(title="SoulX-FlashHead 流式视频生成", theme=gr.themes.Soft(
300
  generate_btn.click(
301
  fn=run_inference_streaming,
302
  inputs=[
 
 
 
303
  cond_image_input,
304
  audio_path_input,
305
  seed_input,
 
84
  wav_file.writeframes(samples.tobytes())
85
  return wav_path
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  @spaces.GPU
88
  def run_inference_streaming(
89
+ ckpt_dir,
90
+ wav2vec_dir,
91
+ model_type,
92
  cond_image,
93
  audio_path,
94
  seed,
 
99
  流式推理:主程序监控 res_queue,有 frames 就保存并 yield;
100
  推理在独立线程中执行,按 chunk 顺序 infer,结果放入 res_queue。
101
  """
102
+ global pipeline, loaded_ckpt_dir, loaded_wav2vec_dir, loaded_model_type
103
+
104
+ if (
105
+ pipeline is None
106
+ or loaded_ckpt_dir != ckpt_dir
107
+ or loaded_wav2vec_dir != wav2vec_dir
108
+ or loaded_model_type != model_type
109
+ ):
110
+ progress(0.2, desc="Loading Model...")
111
+ logger.info(f"Loading pipeline with ckpt_dir={ckpt_dir}, wav2vec_dir={wav2vec_dir}")
112
+ try:
113
+ pipeline = get_pipeline(
114
+ world_size=1,
115
+ ckpt_dir=ckpt_dir,
116
+ model_type=model_type,
117
+ wav2vec_dir=wav2vec_dir,
118
+ )
119
+ loaded_ckpt_dir = ckpt_dir
120
+ loaded_wav2vec_dir = wav2vec_dir
121
+ loaded_model_type = model_type
122
+ except Exception as e:
123
+ logger.error(f"Failed to load model: {e}")
124
+ raise gr.Error(f"Failed to load model: {e}")
125
+
126
  progress(0.5, desc="Preparing Data...")
127
  base_seed = int(seed) if seed >= 0 else 9999
128
  try:
 
298
  )
299
  generate_btn = gr.Button("🚀 流式生成视频", variant="primary", size="lg")
300
  with gr.Accordion("⚙️ 高级设置", open=False):
301
+ ckpt_dir_input = gr.Textbox(
302
+ label="FlashHead Checkpoint Directory",
303
+ value="models/SoulX-FlashHead-1_3B",
304
+ )
305
+ wav2vec_dir_input = gr.Textbox(
306
+ label="Wav2Vec Directory",
307
+ value="models/wav2vec2-base-960h",
308
+ )
309
+ model_type_input = gr.Dropdown(
310
+ label="Model Type",
311
+ choices=["pro", "lite"],
312
+ value="lite",
313
+ )
314
  use_face_crop_input = gr.Checkbox(label="Use Face Crop", value=False)
315
  seed_input = gr.Number(label="Random Seed", value=9999, precision=0)
316
  with gr.Column(scale=1):
 
326
  generate_btn.click(
327
  fn=run_inference_streaming,
328
  inputs=[
329
+ ckpt_dir_input,
330
+ wav2vec_dir_input,
331
+ model_type_input,
332
  cond_image_input,
333
  audio_path_input,
334
  seed_input,