Spaces:
Running
Running
| from huggingface_hub import HfApi, ModelFilter | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from transformers.tokenization_utils_base import BatchEncoding | |
| from transformers.modeling_outputs import MaskedLMOutput | |
# Function to fetch suitable ESM models from HuggingFace Hub
def get_models() -> list[str]:
    """Fetch suitable ESM models from the HuggingFace Hub.

    Queries the Hub for ``fill-mask`` models authored by ``facebook`` whose
    name matches ``esm``, sorted by last modification, newest first.

    Returns:
        The matching model IDs.

    Raises:
        RuntimeError: If the query returns no models.
    """
    # NOTE(review): newer huggingface_hub releases deprecate `ModelFilter` in
    # favour of passing `author=`, `model_name=`, `task=` directly to
    # `list_models`; kept here because the file imports it — confirm against
    # the pinned huggingface_hub version.
    models = HfApi().list_models(
        filter=ModelFilter(author="facebook", model_name="esm", task="fill-mask"),
        sort="lastModified",
        direction=-1,
    )
    out = [m.modelId for m in models]
    # An empty result is reported as an error rather than returned silently.
    if not out:
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out
# Class to wrap ESM models
class Model:
    """Wrapper around a HuggingFace masked-LM (ESM) model and its tokenizer.

    When constructed with an empty ``model_name`` nothing is loaded, and the
    ``model``, ``batch_converter`` and ``alphabet`` attributes are not set.
    """

    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer.

        Args:
            model_name: HuggingFace Hub model ID. If empty, nothing is loaded.
        """
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            # Token string -> token ID mapping, used by __getitem__.
            self.alphabet = self.batch_converter.get_vocab()
            # Check if CUDA is available and if so, use it
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def tokenise(self, input: str) -> BatchEncoding:
        """Convert an input string into a batch of PyTorch token tensors."""
        return self.batch_converter(input, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run the model on a batch of token IDs.

        The tokens are first moved to the model's device: the tokenizer
        produces CPU tensors, while ``__init__`` may have moved the model to
        CUDA, and a forward pass with mixed devices raises an error.
        """
        return self.model(batch_tokens.to(self.model.device), **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get the token ID for a token string (e.g. a residue letter)."""
        return self.alphabet[key]

    def _label_row(self, row: str, token_probs: torch.Tensor) -> float:
        """Score one substitution string of the form ``<wt><pos><mt>``.

        ``pos`` is 1-based (it is decremented below). The score is
        ``log p(mt) - log p(wt)`` at that position of ``token_probs``.
        """
        # Extract wild type, 0-based index and mutant type from the row
        wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
        # NOTE(review): the +1 presumably skips a start token the tokenizer
        # prepends at index 0 — confirm for the tokenizer in use.
        pos = 1 + idx
        score = token_probs[0, pos, self[mt]] - token_probs[0, pos, self[wt]]
        return score.item()

    def _masked_marginals(self, batch_tokens, token_probs, positions):
        """Return log-probabilities where each selected position is masked.

        For each token index ``i`` with ``i in positions``, re-runs the model
        with only token ``i`` replaced by ``<mask>`` and keeps the
        log-probabilities at ``i``; every other index keeps its value from
        ``token_probs``. The result is shaped like ``token_probs``
        (1, seq_len, vocab).
        """
        columns = []
        # Inference only: no gradients are needed for any forward pass here.
        with torch.no_grad():
            for i in range(batch_tokens.shape[1]):
                if i in positions:
                    # Clone the batch tokens and mask the current token
                    masked = batch_tokens.clone()
                    masked[0, i] = self['<mask>']
                    probs = torch.log_softmax(self(masked).logits, dim=-1)
                else:
                    # Not selected: reuse the unmasked probabilities
                    probs = token_probs
                columns.append(probs[:, i])
        # (1, vocab) slices -> (seq_len, vocab) -> (1, seq_len, vocab)
        return torch.cat(columns, dim=0).unsqueeze(0)

    def run_model(self, data):
        """Score every substitution in ``data.sub`` with this model.

        Reads ``data.seq``, ``data.scoring_strategy``, ``data.resi`` and
        ``data.sub``. Writes ``data.token_probs`` (the unmasked
        log-probabilities as a NumPy array) and ``data.out[self.model_name]``
        (one score per row of ``data.sub``).

        NOTE(review): assumes ``data.seq`` is a single sequence string,
        ``data.sub`` is a DataFrame whose column ``'0'`` holds substitution
        strings such as ``"A23G"``, and ``data.resi`` holds the token indices
        to mask — confirm against the caller.
        """
        # Tokenise the sequence data
        batch_tokens = self.tokenise(data.seq).input_ids
        # Calculate the token probabilities without updating the model parameters
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the unmasked token probabilities in the data
        data.token_probs = token_probs.cpu().numpy()
        if data.scoring_strategy.startswith("masked-marginals"):
            token_probs = self._masked_marginals(batch_tokens, token_probs, data.resi)
        # Apply the scoring function to each row of the substitutions dataframe
        data.out[self.model_name] = data.sub.apply(
            lambda row: self._label_row(row['0'], token_probs),
            axis=1,
        )