| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from pathlib import Path |
| from urllib.request import urlretrieve |
|
|
| import numpy as np |
| import torch |
| from chemprop import featurizers, nn |
| from chemprop.data import BatchMolGraph |
| from chemprop.models import MPNN |
| from chemprop.nn import RegressionFFN |
| from rdkit.Chem import Mol, MolFromSmiles |
|
|
|
|
class CheMeleonFingerprint:
    """Callable that turns molecules into fingerprints using CheMeleon weights.

    On construction the message-passing checkpoint ``chemeleon_mp.pt`` is
    fetched from Zenodo into ``~/.chemprop`` (only if it is not already
    there), loaded into a ``BondMessagePassing`` module and wrapped in an
    ``MPNN`` together with a mean aggregation. Only the message-passing
    weights come from the checkpoint; the ``RegressionFFN`` predictor handed
    to ``MPNN`` keeps its freshly initialised weights.
    """

    def __init__(self, device: str | torch.device | None = None):
        """Download (if needed) and load the checkpoint, then build the model.

        Args:
            device: Optional device to move the model to. When ``None`` the
                model is left where it was constructed.
        """
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
        agg = nn.MeanAggregation()
        ckpt_dir = Path.home() / ".chemprop"
        ckpt_dir.mkdir(exist_ok=True)
        mp_path = ckpt_dir / "chemeleon_mp.pt"
        if not mp_path.exists():
            # Download next to the final location and move into place only
            # once the transfer finished, so an interrupted download never
            # leaves a truncated file that later runs would mistake for a
            # valid cached checkpoint.
            partial_path = mp_path.with_name(mp_path.name + ".part")
            try:
                urlretrieve(
                    r"https://zenodo.org/records/15460715/files/chemeleon_mp.pt",
                    partial_path,
                )
                partial_path.replace(mp_path)
            finally:
                # No-op after a successful replace; cleans up after a failure.
                partial_path.unlink(missing_ok=True)
        # Load onto the CPU regardless of the device the tensors were saved
        # from; the values are copied into the new module below and the whole
        # model is moved to ``device`` afterwards.
        chemeleon_mp = torch.load(mp_path, weights_only=True, map_location="cpu")
        mp = nn.BondMessagePassing(**chemeleon_mp["hyper_parameters"])
        mp.load_state_dict(chemeleon_mp["state_dict"])
        self.model = MPNN(
            message_passing=mp,
            agg=agg,
            predictor=RegressionFFN(input_dim=mp.output_dim),
        )
        # Inference only: switch the module out of training mode.
        self.model.eval()
        if device is not None:
            self.model.to(device=device)

    def __call__(self, molecules: list[str | Mol]) -> np.ndarray:
        """Return the model fingerprint for a batch of molecules.

        Args:
            molecules: SMILES strings and/or RDKit ``Mol`` objects; strings
                are parsed with ``MolFromSmiles``.

        Returns:
            The output of ``self.model.fingerprint`` converted to a NumPy
            array (presumably one row per input molecule -- confirm against
            the chemprop documentation).

        Raises:
            ValueError: If an entry is ``None`` or a SMILES string that
                RDKit cannot parse.
        """
        parsed = []
        for index, entry in enumerate(molecules):
            mol = MolFromSmiles(entry) if isinstance(entry, str) else entry
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None;
                # fail here with the offending input instead of deep inside
                # the featurizer.
                raise ValueError(
                    f"could not obtain a molecule from entry {index}: {entry!r}"
                )
            parsed.append(mol)
        bmg = BatchMolGraph([self.featurizer(mol) for mol in parsed])
        bmg.to(device=self.model.device)
        with torch.no_grad():
            # force=True detaches and copies to host memory when necessary.
            return self.model.fingerprint(bmg).numpy(force=True)
|
|
|
|
if __name__ == "__main__":
    # Smoke run: build the fingerprinter and push a small mixed batch of
    # SMILES strings and an RDKit molecule through it.
    sample_inputs = ["C", "CC", MolFromSmiles("CCC")]
    fingerprinter = CheMeleonFingerprint()
    fingerprinter(sample_inputs)
|
|