Nekochu committed on
Commit
8ecadee
·
1 Parent(s): d572d1b

Drop flash-attn install (was broken stub); attn_implementation=sdpa; GPU duration 119->60 (free tier max)

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import os
2
  import gc
3
- import subprocess
4
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
5
 
6
  # Workaround for gradio_client crashes on bool schemas (e.g. additionalProperties: True/False).
7
  # Must run BEFORE `import gradio as gr` so the patched functions are used.
@@ -156,7 +154,7 @@ def clear_old_cache():
156
  gc.collect()
157
  torch.cuda.empty_cache()
158
 
159
- @spaces.GPU(duration=119)
160
  def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
161
  """Text generation with branch support"""
162
  global models_cache, stop_event, current_thread
@@ -196,7 +194,7 @@ def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k
196
  "quantization_config": bnb_config,
197
  "device_map": "auto",
198
  "trust_remote_code": True,
199
- "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else None,
200
  "low_cpu_mem_usage": True
201
  }
202
  if branch:
 
1
  import os
2
  import gc
 
 
3
 
4
  # Workaround for gradio_client crashes on bool schemas (e.g. additionalProperties: True/False).
5
  # Must run BEFORE `import gradio as gr` so the patched functions are used.
 
154
  gc.collect()
155
  torch.cuda.empty_cache()
156
 
157
+ @spaces.GPU(duration=60)
158
  def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
159
  """Text generation with branch support"""
160
  global models_cache, stop_event, current_thread
 
194
  "quantization_config": bnb_config,
195
  "device_map": "auto",
196
  "trust_remote_code": True,
197
+ "attn_implementation": "sdpa" if torch.cuda.is_available() else None,
198
  "low_cpu_mem_usage": True
199
  }
200
  if branch: