Spaces:
Running
Running
| import torch | |
| import huggingface_hub | |
| from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, EsmTokenizer | |
def cache_model_weights(model_id):
    """
    Download a model repository into the local Hugging Face cache without loading it into memory.

    Meant to be called when the Space (re)starts, so that the weights are already
    on disk by the time inference first needs them. Loading into (GPU) memory is
    deferred until the model is actually requested.

    Parameters:
    -----------
    model_id : str
        Model identifier on the Hugging Face Hub (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    str : Path to the cached model directory
    """
    # snapshot_download only fetches files to disk; nothing is instantiated here.
    cache_dir = huggingface_hub.snapshot_download(model_id)
    print(f"Model {model_id} cached at: {cache_dir}")
    return cache_dir
def cache_all_models(models):
    """
    Download the weights of every model in ``models`` into the local cache.

    Parameters:
    -----------
    models : dict or iterable of str
        A dictionary where keys are model identifiers (e.g., "facebook/esm2_t6_8M_UR50D")
        and values are human-readable model names (e.g., "ESM2-8M"). Only the keys are
        used, so any iterable of model identifiers is accepted as well.

    Returns:
    --------
    dict : A dictionary mapping model identifiers to their cache directories,
        in the iteration order of ``models``.
    """
    # Iterating a dict yields its keys, so this is equivalent to models.keys()
    # for dict input while also accepting a plain list/tuple of ids.
    return {model_id: cache_model_weights(model_id) for model_id in models}
def load_model(model_id, *, device=None):
    """
    Load an ESM masked-LM model and its tokenizer with ``from_pretrained``.

    ``from_pretrained`` initializes from the default Hugging Face cache directory and
    downloads the files if they are missing. Call ``cache_model_weights`` first to
    control when the download happens.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")
    device : str or torch.device, optional
        Device to move the model to. When omitted, "cuda" is used if available,
        otherwise "cpu".

    Returns:
    --------
    tuple : (model, tokenizer). The model is in eval mode and placed on ``device``.

    Raises:
    -------
    RuntimeError
        If the tokenizer or model cannot be loaded. The original exception is
        chained as ``__cause__``.
    """
    try:
        print(f"Loading {model_id} from local cache...")
        tokenizer = EsmTokenizer.from_pretrained(model_id)
        model = EsmForMaskedLM.from_pretrained(
            model_id,
            # Forwarded to the model config so forward passes also return the
            # hidden states of every layer, not just the logits.
            output_hidden_states=True,
        )
    except Exception as e:
        # Chain the original error so the underlying cause (missing files,
        # network failure, bad repo id, ...) stays visible in the traceback.
        raise RuntimeError(f"Failed to load model {model_id} from cache: {e}") from e

    # Inference only: eval() switches off training-time behavior such as dropout.
    model = model.eval()
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"{model_id} loaded on {device}")
    return model, tokenizer
def load_all_models(models):
    """
    Load every model in ``models`` together with its tokenizer.

    Parameters:
    -----------
    models : dict or iterable of str
        A dictionary where keys are model identifiers (e.g., "facebook/esm2_t6_8M_UR50D")
        and values are human-readable model names (e.g., "ESM2-8M"). Only the keys are
        used, so any iterable of model identifiers is accepted as well.

    Returns:
    --------
    dict : A dictionary mapping model identifiers to their loaded (model, tokenizer)
        tuples, in the iteration order of ``models``.
    """
    # Iterating a dict yields its keys, so this is equivalent to models.keys()
    # for dict input while also accepting a plain list/tuple of ids.
    return {model_id: load_model(model_id) for model_id in models}
| #def cache_models(models): | |
| # """ | |
| # Download weights to ESM models in cache to be loaded later. | |
| # We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session. | |
| # | |
| # Parameters: | |
| # ---------- | |
| # models: dict | |
| # A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M"). | |
| # | |
| # Returns: | |
| # ------- | |
| # | |
| # """ | |
| # loaded_models = {} | |
| # for model_id, model_name in models.items(): | |
| # print(f"Loading {model_name}...") | |
| # try: | |
| # #load from local cache if available; on the Space's first startup this fails and the model is loaded from HF instead | |
| # model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id) | |
| # except: | |
| # print(f"Loading {model_name} from HuggingFace...") | |
| # model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id) | |
| # | |
| # model = model.eval() | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # model = model.to(device) | |
| # loaded_models[model_id] = { | |
| # "model": model, | |
| # "alphabet": alphabet, | |
| # "batch_converter": alphabet.get_batch_converter() | |
| # } | |
| # print(f"{model_name} loaded on {device}") | |
| # | |
| #def download_models(models): | |
| # """ | |
| # Download weights to ESM models in cache to be loaded later. | |
| # We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session. | |
| # | |
| # Parameters: | |
| # ---------- | |
| # models: dict | |
| # A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M"). | |
| # | |
| # Returns: | |
| # ------- | |
| # | |
| # """ | |
| # loaded_models = {} | |
| # for model_id, model_name in models.items(): | |
| # print(f"Loading {model_name}...") | |
| # try: | |
| # #load from local cache if available; on the Space's first startup this fails and the model is loaded from HF instead | |
| # model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id) | |
| # except: | |
| # print(f"Loading {model_name} from HuggingFace...") | |
| # model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id) | |
| # | |
| # model = model.eval() | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # model = model.to(device) | |
| # loaded_models[model_id] = { | |
| # "model": model, | |
| # "alphabet": alphabet, | |
| # "batch_converter": alphabet.get_batch_converter() | |
| # } | |
| # print(f"{model_name} loaded on {device}") |