File size: 3,989 Bytes
362a4c3
a887836
362a4c3
 
 
 
 
a887836
 
 
 
362a4c3
 
 
 
5fe520a
605b2c5
 
5fe520a
6ea1a16
 
 
 
 
dd7b175
 
362a4c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a887836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8849d7d
8b91ed4
a887836
 
 
 
362a4c3
 
 
 
 
 
 
 
 
a887836
362a4c3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from atomworks.io.utils.visualize import view
from lightning.fabric import seed_everything
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine


# Download model weights at import time (foundry skips models that are already
# present). In total ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); this may
# take a few minutes depending on connection speed.
CHECKPOINT_DIR = "/data/checkpoints"

# argv is passed as a list (no shell): the command is fixed, so a shell only adds
# quoting hazards.
cmd = ["foundry", "install", "rfd3", "ligandmpnn", "rf3", "--checkpoint-dir", CHECKPOINT_DIR]
print("Global PATH:", os.environ.get("PATH"))

try:
    result = subprocess.run(cmd, capture_output=True, text=True)
except OSError as err:
    # Without a shell, a missing/non-executable `foundry` raises here instead of
    # yielding exit code 127. Report it and keep the app importable so the UI
    # still starts.
    result = None
    print(f"Error installing models: {err}")
else:
    if result.returncode == 0:
        print("Models installed successfully.")
    else:
        print(f"Error installing models: {result.stderr}")
        print(result.stdout)
        print(result.returncode)

#download_dir = "./models/"
#if not os.path.exists(download_dir):
#    cmd  = "foundry install rfd3 ligandmpnn rf3"
#    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
#
## Run once on startup: Install models if missing
#checkpoint_dir = Path.home() / ".foundry" / "checkpoints"
#os.environ["FOUNDRY_CHECKPOINT_DIRS"] = str(checkpoint_dir)
#
#def install_models():
#    """Download rfd3, ligandmpnn, rf3 weights once."""
#    #models = ["rfd3", "ligandmpnn", "rf3"]
#    models = ["ligandmpnn"] # let's start with only ligand mpnn for testing
#    for model in models:
#        if not (checkpoint_dir / model).exists():
#            print(f"Installing {model}...")
#            subprocess.check_call(["foundry", "install", model])
#    print("All models installed.")
#
#install_models()  # Executes on app.py load

@spaces.GPU(duration=300)
def test_rfd3():
    """Smoke-test the ``rfd3`` command-line tool with a tiny, short design run.

    Invokes ``rfd3 design`` with seed 42, one design and 10 diffusion steps,
    waiting at most 300 seconds.

    Returns:
        A status string: the last 500 characters of stdout on exit code 0,
        the full stderr on a non-zero exit, or ``"Error: <message>"`` if the
        process could not be started or timed out.
    """
    argv = ["rfd3", "design", "--seed", "42"]
    # Presumably a tiny 10-residue monomer -- verify contig syntax against the rfd3 CLI.
    argv.append("contigmap.contigs=[A10]")
    argv += ["--num_designs", "1", "inference.output_prefix=test_output"]
    argv += ["--inference.num_diffusion_steps", "10"]  # few steps keeps the test fast

    try:
        proc = subprocess.run(argv, capture_output=True, text=True, timeout=300)
        if proc.returncode != 0:
            return f"RFD3 test failed: {proc.stderr}"
        return "RFD3 test passed! Check test_output.pdb. Logs:\n" + proc.stdout[-500:]
    except Exception as exc:
        return f"Error: {str(exc)}"

@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Run a small unconditional RFD3 generation in memory and summarize it.

    Seeds all RNGs with 0, configures 40-residue designs with a diffusion
    batch size of 2, runs a single batch through ``RFD3InferenceEngine.run``
    without writing files, and reports the atom count of every returned
    structure.

    Returns:
        On success, a multi-line string starting with
        ``"RFD3 test passed! Generated structures:"`` followed by one line per
        batch and one line per structure. If any step raises (including
        seeding or config/engine construction), ``"Error: <message>"``.
    """
    # Everything is inside the try so that every failure is reported through
    # the returned string shown in the UI, rather than escaping to Gradio.
    try:
        # Set seed for reproducibility
        seed_everything(0)

        config = RFD3InferenceConfig(
            specification={
                'length': 40,  # Generate 40-residue proteins
            },
            diffusion_batch_size=2,  # Generate 2 structures per batch
        )
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,      # None for unconditional generation
            out_dir=None,     # None to return in memory (no file output)
            n_batches=1,      # Generate 1 batch
        )

        lines = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            lines.extend(
                f"Structure {i}: {struct.atom_array.array_length()} Atoms"
                for i, struct in enumerate(data, start=1)
            )
        # Trailing newline matches the original line-by-line "+= ...\n" output.
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error: {str(e)}"

    
# Gradio UI: a single button that runs the in-memory RFD3 smoke test and shows
# its status string in a text box.
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown("# RFdiffusion3 (RFD3) Model Checker")
    gr.Markdown("Models auto-downloaded on launch. Click to test.")

    run_button = gr.Button(value="Run RFD3 Test")
    result_box = gr.Textbox(label="Test Result")

    # No inputs: the handler takes no arguments and returns the text to display.
    run_button.click(fn=test_rfd3_from_notebook, outputs=result_box)

if __name__ == "__main__":
    demo.launch()