Food Desert commited on
Commit
dbdfd45
·
1 Parent(s): 797364d

Fix HF startup: add torch deps and lazy-load T5 imports

Browse files
psq_rag/llm/rewrite_local_t5.py CHANGED
@@ -5,9 +5,6 @@ import threading
5
  from pathlib import Path
6
  from typing import Optional
7
 
8
- import torch
9
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
10
-
11
  _LOCK = threading.Lock()
12
  _MODEL = None
13
  _TOKENIZER = None
@@ -32,6 +29,10 @@ def _load_model(model_dir: Path):
32
  if not model_dir.is_dir():
33
  raise FileNotFoundError(f"T5 rewrite model directory not found: {model_dir}")
34
 
 
 
 
 
35
  tokenizer = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=False)
36
  model = AutoModelForSeq2SeqLM.from_pretrained(str(model_dir), local_files_only=True)
37
  model.eval()
@@ -55,6 +56,8 @@ def local_t5_rewrite_prompt(
55
  return ""
56
 
57
  try:
 
 
58
  resolved_dir = _resolve_model_dir(model_dir).resolve()
59
  model, tokenizer = _load_model(resolved_dir)
60
  task_prefix = os.environ.get("PSQ_T5_REWRITE_TASK_PREFIX", "caption_to_tags: ").strip()
 
5
  from pathlib import Path
6
  from typing import Optional
7
 
 
 
 
8
  _LOCK = threading.Lock()
9
  _MODEL = None
10
  _TOKENIZER = None
 
29
  if not model_dir.is_dir():
30
  raise FileNotFoundError(f"T5 rewrite model directory not found: {model_dir}")
31
 
32
+ # Lazy import so missing deps never crash app import/startup.
33
+ import torch
34
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
35
+
36
  tokenizer = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=False)
37
  model = AutoModelForSeq2SeqLM.from_pretrained(str(model_dir), local_files_only=True)
38
  model.eval()
 
56
  return ""
57
 
58
  try:
59
+ import torch
60
+
61
  resolved_dir = _resolve_model_dir(model_dir).resolve()
62
  model, tokenizer = _load_model(resolved_dir)
63
  task_prefix = os.environ.get("PSQ_T5_REWRITE_TASK_PREFIX", "caption_to_tags: ").strip()
requirements.txt CHANGED
@@ -12,3 +12,6 @@ huggingface_hub<1.0
12
  rapidfuzz>=3.0
13
  langchain-core>=0.3,<0.4
14
  langchain-openai>=0.2,<0.4
 
 
 
 
12
  rapidfuzz>=3.0
13
  langchain-core>=0.3,<0.4
14
  langchain-openai>=0.2,<0.4
15
+ torch
16
+ transformers>=4.46,<5
17
+ sentencepiece