RFdiffusion3 / app.py
gabboud's picture
revert to non-persistence of weights, describe persistence issue in issues_doc
c47162e
Raw
History Blame
3.99 kB
import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from atomworks.io.utils.visualize import view
from lightning.fabric import seed_everything
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
# Download model weights at import time, i.e. once per app start-up
# (already-downloaded models are skipped automatically by `foundry install`).
# In total ~6 GB (3 GB for RFD3, 3 GB for RF3, <100 MB for LigandMPNN); this may
# take a few minutes depending on connection speed.
CHECKPOINT_DIR = "/data/checkpoints"
cmd = [
    "foundry", "install", "rfd3", "ligandmpnn", "rf3",
    "--checkpoint-dir", CHECKPOINT_DIR,
]
print("Global PATH:", os.environ.get("PATH"))
try:
    # Argument list + shell=False: no shell parsing/quoting of the command line.
    result = subprocess.run(cmd, capture_output=True, text=True)
except FileNotFoundError as exc:
    # Without a shell, a missing executable raises instead of exiting with the
    # shell's "command not found" status. Mirror that status (127) so start-up
    # continues and the failure is reported below, as it was with shell=True.
    result = subprocess.CompletedProcess(cmd, 127, stdout="", stderr=str(exc))
if result.returncode == 0:
    print("Models installed successfully.")
else:
    print(f"Error installing models: {result.stderr}")
print(result.stdout)
print(result.returncode)
#download_dir = "./models/"
#if not os.path.exists(download_dir):
# cmd = "foundry install rfd3 ligandmpnn rf3"
# result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
#
## Run once on startup: Install models if missing
#checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
#os.environ["FOUNDRY_CHECKPOINT_DIRS"] = str(checkpoint_dir)
#
#def install_models():
# """Download rfd3, ligandmpnn, rf3 weights once."""
# #models = ["rfd3", "ligandmpnn", "rf3"]
# models = ["ligandmpnn"] # let's start with only ligand mpnn for testing
# for model in models:
# if not (checkpoint_dir / model).exists():
# print(f"Installing {model}...")
# subprocess.check_call(["foundry", "install", model])
# print("All models installed.")
#
#install_models() # Executes on app.py load
@spaces.GPU(duration=300)
def test_rfd3():
    """Run a quick RFD3 smoke test through the ``rfd3`` command-line tool.

    Launches ``rfd3 design`` as a subprocess. The run uses a fixed seed (42),
    requests a single tiny design (contig ``[A10]``) and uses 10 diffusion steps.

    Returns:
        A human-readable status string with one of the following:

        * the last 500 characters of stdout, on exit code 0;
        * the captured stderr, on a non-zero exit code;
        * ``"Error: ..."``, if the subprocess could not be started or timed out.
    """
    # Keep the subprocess timeout below the 300 s GPU duration requested above.
    # A hung CLI run then produces a message for the caller instead of outliving
    # the GPU budget. NOTE(review): 30 s margin chosen heuristically; tune if needed.
    cli_timeout_s = 270
    # NOTE(review): this mixes argparse-style flags (--seed) with Hydra-style
    # overrides (key=value); confirm the accepted syntax against `rfd3 design --help`.
    cmd = [
        "rfd3",
        "design",
        "--seed", "42",
        "contigmap.contigs=[A10]",  # Tiny 10-res monomer
        "--num_designs", "1",
        "inference.output_prefix=test_output",
        "--inference.num_diffusion_steps", "10",  # Fast test
    ]
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=cli_timeout_s
        )
    except Exception as e:
        # Includes FileNotFoundError (rfd3 not on PATH) and TimeoutExpired.
        return f"Error: {str(e)}"
    if result.returncode == 0:
        return "RFD3 test passed! Check test_output.pdb. Logs:\n" + result.stdout[-500:]
    return f"RFD3 test failed: {result.stderr}"
@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Run a small unconditional RFD3 generation in-process and summarize it.

    Seeds all RNGs with 0 and configures a 40-residue unconditional design with
    ``diffusion_batch_size=2``. It then runs a single batch and keeps the results
    in memory (``out_dir=None``).

    Returns:
        A multi-line report. For each entry of the engine output, it gives the
        number of structures and the atom count of each structure. If
        configuration, engine construction or generation raises, returns
        ``"Error: <message>"`` instead.
    """
    # Set seed for reproducibility
    seed_everything(0)
    try:
        # Config construction sits inside the try so invalid settings are
        # reported in the UI instead of escaping the Gradio callback.
        config = RFD3InferenceConfig(
            specification={
                "length": 40,  # Generate 40-residue proteins
            },
            diffusion_batch_size=2,  # Generate 2 structures per batch
        )
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None for unconditional generation
            out_dir=None,  # None to return in memory (no file output)
            n_batches=1,  # Generate 1 batch
        )
        lines = ["RFD3 test passed! Generated structures:"]
        # NOTE(review): keys are labelled "Batch" in the report; confirm what
        # the engine actually uses as keys of `outputs`.
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            for i, struct in enumerate(data, start=1):
                lines.append(
                    f"Structure {i}: {struct.atom_array.array_length()} Atoms"
                )
        # Each report line is newline-terminated, including the last one.
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Top-level UI boundary: surface any failure as text in the result box.
        return f"Error: {e}"
# --- Gradio UI ---------------------------------------------------------------
# Layout, top to bottom: title, short instructions, a button, then a textbox.
# Clicking the button runs the in-process RFD3 check and shows its text report.
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown(value="# RFdiffusion3 (RFD3) Model Checker")
    gr.Markdown(value="Models auto-downloaded on launch. Click to test.")
    test_btn = gr.Button(value="Run RFD3 Test")
    output = gr.Textbox(label="Test Result")
    # The callback takes no inputs; its return string fills the textbox.
    test_btn.click(fn=test_rfd3_from_notebook, inputs=None, outputs=output)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()