"""Helpers for caching and loading ESM2 models from the Hugging Face Hub."""

import torch
import huggingface_hub
from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, EsmTokenizer


def cache_model_weights(model_id: str) -> str:
    """
    Download ESM2 model weights to the cache without loading them into memory.

    Intended to be called when the Space restarts so that the weights are
    already on disk by the time inference is requested.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    str : Path to cached model directory
    """
    cache_dir = huggingface_hub.snapshot_download(model_id)
    print(f"Model {model_id} cached at: {cache_dir}")
    return cache_dir


def cache_all_models(models: dict) -> dict:
    """
    Cache all models in the provided dictionary.

    Parameters:
    -----------
    models : dict
        A dictionary where keys are model identifiers
        (e.g., "facebook/esm2_t6_8M_UR50D") and values are human-readable
        model names (e.g., "ESM2-8M"). Only the keys are used.

    Returns:
    --------
    dict : A dictionary mapping model identifiers to their cache directories.
    """
    # Iterating the dict yields its keys, in insertion order.
    return {model_id: cache_model_weights(model_id) for model_id in models}


def load_model(model_id: str) -> tuple:
    """
    Load an ESM model and its tokenizer using from_pretrained.

    Initializes from the default cache directory, or downloads the files if
    they are missing. Meant to be used after cache_model_weights so that the
    caller controls when models are downloaded. The model is switched to
    eval mode and moved to CUDA when available, otherwise to the CPU.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    tuple : (model, tokenizer)

    Raises:
    -------
    RuntimeError
        If the tokenizer or the model cannot be loaded. The original
        exception is attached as ``__cause__``.
    """
    print(f"Loading {model_id} from local cache...")
    try:
        tokenizer = EsmTokenizer.from_pretrained(model_id)
        model = EsmForMaskedLM.from_pretrained(
            model_id,
            output_hidden_states=True,
        )
    except Exception as e:
        # Chain explicitly so the underlying failure stays in the traceback.
        raise RuntimeError(f"Failed to load model {model_id} from cache: {e}") from e

    model = model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"{model_id} loaded on {device}")
    return model, tokenizer


def load_all_models(models: dict) -> dict:
    """
    Load all models in the provided dictionary.

    Parameters:
    -----------
    models : dict
        A dictionary where keys are model identifiers
        (e.g., "facebook/esm2_t6_8M_UR50D") and values are human-readable
        model names (e.g., "ESM2-8M"). Only the keys are used.

    Returns:
    --------
    dict : A dictionary mapping model identifiers to their loaded
        (model, tokenizer) tuples.
    """
    # Every model is loaded eagerly, one after another, in key order.
    return {model_id: load_model(model_id) for model_id in models}


# ---------------------------------------------------------------------------
# Legacy implementations built on the `esm.pretrained` loaders.
# Commented out and retained for reference only; nothing below is executed.
# ---------------------------------------------------------------------------

#def cache_models(models):
#    """
#    Download weights to ESM models in cache to be loaded later.
#    We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
#
#    Parameters:
#    ----------
#    models: dict
#        A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
#
#    Returns:
#    -------
#
#    """
#    loaded_models = {}
#    for model_id, model_name in models.items():
#        print(f"Loading {model_name}...")
#        try:
#            #load from local cache if available, upon startup of space will fail and load from HF
#            model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
#        except:
#            print(f"Loading {model_name} from HuggingFace...")
#            model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
#
#        model = model.eval()
#        device = "cuda" if torch.cuda.is_available() else "cpu"
#        model = model.to(device)
#        loaded_models[model_id] = {
#            "model": model,
#            "alphabet": alphabet,
#            "batch_converter": alphabet.get_batch_converter()
#        }
#        print(f"{model_name} loaded on {device}")
#
#def download_models(models):
#    """
#    Download weights to ESM models in cache to be loaded later.
#    We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
#
#    Parameters:
#    ----------
#    models: dict
#        A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
#
#    Returns:
#    -------
#
#    """
#    loaded_models = {}
#    for model_id, model_name in models.items():
#        print(f"Loading {model_name}...")
#        try:
#            #load from local cache if available, upon startup of space will fail and load from HF
#            model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
#        except:
#            print(f"Loading {model_name} from HuggingFace...")
#            model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
#
#        model = model.eval()
#        device = "cuda" if torch.cuda.is_available() else "cpu"
#        model = model.to(device)
#        loaded_models[model_id] = {
#            "model": model,
#            "alphabet": alphabet,
#            "batch_converter": alphabet.get_batch_converter()
#        }
#        print(f"{model_name} loaded on {device}")