ESM2 / utils /download_models.py
gabboud's picture
fix locally and implement PPL
ae38197
Raw
History Blame Contribute Delete
5.43 kB
import torch
import huggingface_hub
from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, EsmTokenizer
def cache_model_weights(model_id):
    """
    Fetch the Hub snapshot of an ESM2 model into the local cache.

    Only the download happens here; the model is not loaded into memory.
    Meant to be called when the Space restarts so the weights are already
    on disk once inference is requested.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    str : Path to cached model directory
    """
    snapshot_path = huggingface_hub.snapshot_download(model_id)
    print(f"Model {model_id} cached at: {snapshot_path}")
    return snapshot_path
def cache_all_models(models):
    """
    Download the weights of every model in ``models`` to the cache.

    Parameters:
    -----------
    models : dict
        Maps model identifiers (e.g., "facebook/esm2_t6_8M_UR50D") to
        human-readable model names (e.g., "ESM2-8M"). Only the keys are used.

    Returns:
    --------
    dict : Maps each model identifier to the cache directory returned by
        cache_model_weights for it.
    """
    # One download per identifier, in the dictionary's iteration order.
    return {
        identifier: cache_model_weights(identifier)
        for identifier in models.keys()
    }
def load_model(model_id):
    """
    Load an ESM model and its tokenizer using from_pretrained.

    Initializes from the default cache directory or downloads if missing.
    To be used after cache_model_weights for control over when models are
    downloaded. The model is put in eval mode and moved to CUDA when
    available, otherwise it stays on CPU.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    tuple : (model, tokenizer), with the model in eval mode on the selected device

    Raises:
    -------
    RuntimeError
        If the tokenizer or the model cannot be loaded. The underlying
        exception is attached as ``__cause__``.
    """
    try:
        print(f"Loading {model_id} from local cache...")
        tokenizer = EsmTokenizer.from_pretrained(model_id)
        model = EsmForMaskedLM.from_pretrained(
            model_id,
            output_hidden_states=True,
        )
    except Exception as e:
        # Chain explicitly so the original failure is reported as the direct
        # cause instead of as an error raised while handling another error.
        raise RuntimeError(
            f"Failed to load model {model_id} from cache: {e}"
        ) from e

    model = model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"{model_id} loaded on {device}")
    return model, tokenizer
def load_all_models(models):
    """
    Load every model in ``models`` together with its tokenizer.

    Parameters:
    -----------
    models : dict
        Maps model identifiers (e.g., "facebook/esm2_t6_8M_UR50D") to
        human-readable model names (e.g., "ESM2-8M"). Only the keys are used.

    Returns:
    --------
    dict : Maps each model identifier to the (model, tokenizer) tuple
        returned by load_model for it.
    """
    # One load per identifier, in the dictionary's iteration order.
    return {
        identifier: load_model(identifier)
        for identifier in models.keys()
    }
#def cache_models(models):
# """
# Download weights to ESM models in cache to be loaded later.
# We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
#
# Parameters:
# ----------
# models: dict
# A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
#
# Returns:
# -------
#
# """
# loaded_models = {}
# for model_id, model_name in models.items():
# print(f"Loading {model_name}...")
# try:
# #load from local cache if available, upon startup of space will fail and load from HF
# model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
# except:
# print(f"Loading {model_name} from HuggingFace...")
# model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
#
# model = model.eval()
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)
# loaded_models[model_id] = {
# "model": model,
# "alphabet": alphabet,
# "batch_converter": alphabet.get_batch_converter()
# }
# print(f"{model_name} loaded on {device}")
#
#def download_models(models):
# """
# Download weights to ESM models in cache to be loaded later.
# We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
#
# Parameters:
# ----------
# models: dict
# A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
#
# Returns:
# -------
#
# """
# loaded_models = {}
# for model_id, model_name in models.items():
# print(f"Loading {model_name}...")
# try:
# #load from local cache if available, upon startup of space will fail and load from HF
# model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
# except:
# print(f"Loading {model_name} from HuggingFace...")
# model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
#
# model = model.eval()
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)
# loaded_models[model_id] = {
# "model": model,
# "alphabet": alphabet,
# "batch_converter": alphabet.get_batch_converter()
# }
# print(f"{model_name} loaded on {device}")