File size: 4,057 Bytes
362a4c3
a887836
362a4c3
 
 
 
 
a887836
 
 
f19aa8b
94ff1b9
743042b
a887836
362a4c3
f19aa8b
5fe520a
362a4c3
 
 
27305e9
362a4c3
27305e9
362a4c3
 
 
 
a887836
362a4c3
27305e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aced7ae
 
27305e9
5d306a8
 
aced7ae
65e7711
 
 
 
5d306a8
65e7711
 
 
 
5d306a8
65e7711
5d306a8
 
 
743042b
 
65e7711
5d306a8
65e7711
5d306a8
65e7711
5d306a8
 
 
 
 
 
 
 
 
65e7711
 
 
 
 
5d306a8
 
 
 
 
 
743042b
5d306a8
743042b
 
 
 
 
5d306a8
743042b
 
65e7711
 
743042b
 
 
 
 
 
 
27305e9
362a4c3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from atomworks.io.utils.visualize import view
from lightning.fabric import seed_everything
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
from utils import download_weights
from utils.pipelines import test_rfd3_from_notebook, unconditional_generation
#from gradio_molecule3d import Molecule3D


# Runs at import time, before the UI below is constructed. What it fetches is
# defined in utils (not visible here); the UI text further down describes it
# as "Models auto-downloaded on launch".
download_weights()

    
# Gradio UI
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown("# RFdiffusion3 (RFD3) for Backbone generation")
    gr.Markdown("Models auto-downloaded on launch. Click to test.")

    # Smoke test: one button whose handler's return value is shown as text.
    test_btn = gr.Button("Run RFD3 Test")
    output = gr.Textbox(label="Test Result")

    test_btn.click(test_rfd3_from_notebook, outputs=output)

    # Inputs for unconditional generation. All three are integer-valued
    # (precision=0) and bounded by the minimum/maximum given here.
    gr.Markdown("Unconditional generation of backbones")
    with gr.Row():
        num_designs_per_batch = gr.Number(
            value=2,
            label="Number of Designs per Batch",
            precision=0,
            minimum=1,
            maximum=8,
        )
        num_batches = gr.Number(
            value=5,
            label="Number of Batches",
            precision=0,
            minimum=1,
            maximum=10,
        )
        length = gr.Number(
            value=40,
            label="Length of Protein (number of residues)",
            precision=0,
            minimum=10,
            maximum=200,
        )

    # Per-session state written by the generation click handler. The callbacks
    # in this file iterate gen_results as dicts with "batch", "design" and
    # "pdb" keys; both states are None until a generation has run.
    gen_directory = gr.State(None)
    gen_results = gr.State(None)
    gen_btn = gr.Button("Run Unconditional Generation")

    # New visualize section
    with gr.Row():
        batch_dropdown = gr.Dropdown(
            choices=[],
            label="Select Batch",
            visible=True,
        )
        design_dropdown = gr.Dropdown(
            choices=[],
            label="Select Design",
            visible=True,
        )
        visualize_btn = gr.Button("Visualize", visible=True)

    # The placeholder text is passed through the constructor so it goes
    # through the component's normal initial-value handling instead of being
    # patched onto the instance afterwards.
    display_state = gr.Textbox(
        label="Selected Batch and Design",
        value="Please Select a Batch and Design number to show sequence",
        visible=True,
    )
    #viewer = Molecule3D(visible=True)

    def update_batch_choices(result):
        """Return a dropdown update listing the batches present in ``result``.

        Args:
            result: Value of the ``gen_results`` state. ``None`` until a
                generation has run; otherwise iterated as dicts that each
                carry a ``"batch"`` key.

        Returns:
            A ``gr.update`` for the batch dropdown with the sorted, de-duplicated
            batch values as choices (empty when there are no results).
        """
        entries = [] if result is None else result
        batches = sorted({d["batch"] for d in entries})
        # Clear the current selection as well: after a re-run the previously
        # selected batch may no longer be one of the available choices.
        return gr.update(choices=batches, value=None, visible=True)
        
    # Step 1: run generation and store its two outputs in the State components.
    generation_event = gen_btn.click(
        unconditional_generation,
        inputs=[num_batches, num_designs_per_batch, length],
        outputs=[gen_directory, gen_results],
    )
    # Step 2: once step 1 has finished, rebuild the batch dropdown from the
    # freshly stored results.
    generation_event.then(
        update_batch_choices,
        inputs=gen_results,
        outputs=batch_dropdown,
    )



    def update_designs(batch, result):
        """Return a dropdown update listing the designs of the selected batch.

        Args:
            batch: Current value of the batch dropdown, or ``None`` when
                nothing is selected.
            result: Value of the ``gen_results`` state. ``None`` until a
                generation has run; otherwise iterated as dicts carrying
                ``"batch"`` and ``"design"`` keys.

        Returns:
            A ``gr.update`` for the design dropdown. Choices are the sorted,
            de-duplicated designs whose ``"batch"`` equals ``batch``; they are
            empty when no batch is selected or no results exist yet.
        """
        # The selection is cleared in every case: a design chosen for the
        # previous batch may not exist in the newly selected one.
        if batch is None or result is None:
            return gr.update(choices=[], value=None)
        # NOTE(review): compares the raw dropdown value, whereas
        # visualize_selection coerces with int() first — confirm the type of
        # d["batch"] and align the two if needed.
        designs = sorted({d["design"] for d in result if d["batch"] == batch})
        return gr.update(choices=designs, value=None)
    

    
    # Refresh the design choices whenever a different batch is selected.
    # No listener is attached to design_dropdown itself: its value is only
    # read when the Visualize button is clicked.
    batch_dropdown.change(update_designs, inputs=[batch_dropdown, gen_results], outputs=[design_dropdown])

    def visualize_selection(batch, design, result):
        """Return a textbox update showing the selected design's PDB file.

        Args:
            batch: Current value of the batch dropdown, or ``None``.
            design: Current value of the design dropdown, or ``None``.
            result: Value of the ``gen_results`` state. ``None`` until a
                generation has run; otherwise iterated as dicts carrying
                ``"batch"``, ``"design"`` and ``"pdb"`` keys, where ``"pdb"``
                is used as a path passed to ``open``.

        Returns:
            A ``gr.update``: a no-op when the selection is incomplete or no
            results exist, a "not found" message when no entry matches, and
            otherwise the selection, the file path and the file's contents.
        """
        if batch is None or design is None or not result:
            return gr.update()

        batch_id, design_id = int(batch), int(design)
        # A default of None keeps a missing match from escaping as StopIteration.
        pdb_path = next(
            (d["pdb"] for d in result if d["batch"] == batch_id and d["design"] == design_id),
            None,
        )
        if pdb_path is None:
            return gr.update(value=f"No design found for Batch: {batch}, Design: {design}", visible=True)

        with open(pdb_path, "r") as f:
            pdb_str = f.read()
        # "saved at" reports the file path; the file contents follow on the next line.
        return gr.update(value=f"Selected Batch: {batch}, Design: {design}, saved at {pdb_path}:\n {pdb_str}", visible=True)

    
    # Pressing the button feeds both dropdown values plus the stored results
    # to visualize_selection and writes its return value into the textbox.
    visualize_btn.click(
        fn=visualize_selection,
        inputs=[batch_dropdown, design_dropdown, gen_results],
        outputs=display_state,
    )


    #def load_viewer(batch, design, result):
    #    if batch is None or design is None:
    #        return gr.update()
    #    pdb_data = next(d["pdb"] for d in result if d["batch"] == int(batch) and d["design"] == int(design))
    #    return gr.update(value=pdb_data, visible=True, reps=[{"style": "cartoon"}])  # Customize style
#
    #visualize_btn.click(load_viewer, inputs=[batch_dropdown, design_dropdown, gen_results], outputs=viewer)

# Launch the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()