silveroxides commited on
Commit
c81c2ee
·
verified ·
1 Parent(s): 4709f01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import safetensors
4
  import timm
5
  from transformers import AutoProcessor
6
  import gradio as gr
 
7
  import torch
8
  import time
9
  from florence2_implementation.modeling_florence2 import Florence2ForConditionalGeneration
@@ -239,6 +240,7 @@ for idx, tag in enumerate(allowed_tags):
239
 
240
  pruner = Pruner("tags-2024-05-05.csv")
241
 
 
242
  def generate_prompt(image, expected_caption_length):
243
  global THRESHOLD, tree, tokenizer, model, tagger_model, tagger_transform
244
  tagger_input = tagger_transform(image.convert('RGBA')).unsqueeze(0)
@@ -254,7 +256,7 @@ def generate_prompt(image, expected_caption_length):
254
  task_prompt = pruner.prompt_construction_pipeline_florence2(final_tags, expected_caption_length)
255
  return task_prompt
256
 
257
-
258
  def inference_caption(image, expected_caption_length, seq_len=512,):
259
  start_time = time.time()
260
  prompt_input = generate_prompt(image, expected_caption_length)
 
4
  import timm
5
  from transformers import AutoProcessor
6
  import gradio as gr
7
+ import spaces
8
  import torch
9
  import time
10
  from florence2_implementation.modeling_florence2 import Florence2ForConditionalGeneration
 
240
 
241
  pruner = Pruner("tags-2024-05-05.csv")
242
 
243
+ @spaces.GPU
244
  def generate_prompt(image, expected_caption_length):
245
  global THRESHOLD, tree, tokenizer, model, tagger_model, tagger_transform
246
  tagger_input = tagger_transform(image.convert('RGBA')).unsqueeze(0)
 
256
  task_prompt = pruner.prompt_construction_pipeline_florence2(final_tags, expected_caption_length)
257
  return task_prompt
258
 
259
+ @spaces.GPU
260
  def inference_caption(image, expected_caption_length, seq_len=512,):
261
  start_time = time.time()
262
  prompt_input = generate_prompt(image, expected_caption_length)