Spaces:
Running on Zero
| import gradio as gr | |
| import warnings | |
| import os | |
| import subprocess | |
| from pathlib import Path | |
| import shutil | |
| import spaces | |
| from atomworks.io.utils.visualize import view | |
| from lightning.fabric import seed_everything | |
| from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine | |
| from utils import download_weights | |
| from utils.pipelines import test_rfd3_from_notebook, unconditional_generation | |
| #from gradio_molecule3d import Molecule3D | |
# Fetch model weights at import time (module top level), before the UI is
# built, so the "auto-downloaded on launch" note shown in the UI holds.
# What exactly is downloaded is defined in utils.download_weights.
download_weights()

# Gradio UI
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown("# RFdiffusion3 (RFD3) for Backbone generation")
    gr.Markdown("Models auto-downloaded on launch. Click to test.")

    # --- Smoke test ---------------------------------------------------------
    test_btn = gr.Button("Run RFD3 Test")
    output = gr.Textbox(label="Test Result")
    test_btn.click(test_rfd3_from_notebook, outputs=output)

    # --- Unconditional generation -------------------------------------------
    gr.Markdown("Unconditional generation of backbones")
    with gr.Row():
        num_designs_per_batch = gr.Number(
            value=2,
            label="Number of Designs per Batch",
            precision=0,
            minimum=1,
            maximum=8,
        )
        num_batches = gr.Number(
            value=5,
            label="Number of Batches",
            precision=0,
            minimum=1,
            maximum=10,
        )
        length = gr.Number(
            value=40,
            label="Length of Protein (number of residues)",
            precision=0,
            minimum=10,
            maximum=200,
        )
    # Per-session state filled by unconditional_generation (its two outputs).
    # gen_results is consumed below as an iterable of mappings with "batch",
    # "design" and "pdb" keys; "batch"/"design" are compared against int(...)
    # so they are presumably ints -- confirm against utils.pipelines.
    gen_directory = gr.State(None)
    gen_results = gr.State(None)
    gen_btn = gr.Button("Run Unconditional Generation")

    # --- Visualize section --------------------------------------------------
    with gr.Row():
        batch_dropdown = gr.Dropdown(
            choices=[],
            label="Select Batch",
            visible=True,
        )
        design_dropdown = gr.Dropdown(
            choices=[],
            label="Select Design",
            visible=True,
        )
    visualize_btn = gr.Button("Visualize", visible=True)
    # Pass the initial text through the constructor instead of assigning
    # .value after the component has been created.
    display_state = gr.Textbox(
        value="Please Select a Batch and Design number to show sequence",
        label="Selected Batch and Design",
        visible=True,
    )
    # TODO: re-enable once gradio_molecule3d is available.
    # viewer = Molecule3D(visible=True)

    def update_batch_choices(result):
        """Fill the batch dropdown with the distinct batch ids in *result*.

        The selected value is reset to None so a batch picked in a previous
        run cannot linger once the choices are replaced.
        """
        if not result:
            return gr.update(choices=[], value=None)
        batches = sorted({d["batch"] for d in result})
        return gr.update(choices=batches, value=None, visible=True)

    def update_designs(batch, result):
        """Fill the design dropdown with the design ids of the chosen batch.

        ``batch`` may legitimately be 0, so only ``None`` counts as "no
        selection". The comparison uses ``int(batch)`` to match the lookup in
        ``visualize_selection``.
        """
        if batch is None or not result:
            return gr.update(choices=[], value=None)
        designs = sorted({d["design"] for d in result if d["batch"] == int(batch)})
        return gr.update(choices=designs, value=None)

    def visualize_selection(batch, design, result):
        """Show the saved PDB file of the selected design in ``display_state``.

        Returns a no-op update when nothing is selected or no results exist,
        and an explanatory message when no record matches the selection.
        """
        if batch is None or design is None or not result:
            return gr.update()
        # Default of None avoids StopIteration escaping from the callback.
        pdb_path = next(
            (
                d["pdb"]
                for d in result
                if d["batch"] == int(batch) and d["design"] == int(design)
            ),
            None,
        )
        if pdb_path is None:
            return gr.update(
                value=f"No design found for Batch: {batch}, Design: {design}",
                visible=True,
            )
        with open(pdb_path, "r", encoding="utf-8") as f:
            pdb_str = f.read()
        return gr.update(
            value=f"Selected Batch: {batch}, Design: {design}, saved at {pdb_path}:\n {pdb_str}",
            visible=True,
        )

    # Generate, then refresh the batch choices from the new results. Resetting
    # the batch value triggers batch_dropdown.change, which clears designs.
    gen_btn.click(
        unconditional_generation,
        inputs=[num_batches, num_designs_per_batch, length],
        outputs=[gen_directory, gen_results],
    ).then(
        update_batch_choices,
        inputs=gen_results,
        outputs=batch_dropdown,
    )
    batch_dropdown.change(
        update_designs,
        inputs=[batch_dropdown, gen_results],
        outputs=[design_dropdown],
    )
    visualize_btn.click(
        visualize_selection,
        inputs=[batch_dropdown, design_dropdown, gen_results],
        outputs=display_state,
    )

    # TODO: 3D viewer wiring, kept for when Molecule3D is re-enabled.
    # def load_viewer(batch, design, result):
    #     if batch is None or design is None:
    #         return gr.update()
    #     pdb_data = next(d["pdb"] for d in result if d["batch"] == int(batch) and d["design"] == int(design))
    #     return gr.update(value=pdb_data, visible=True, reps=[{"style": "cartoon"}])  # Customize style
    #
    # visualize_btn.click(load_viewer, inputs=[batch_dropdown, design_dropdown, gen_results], outputs=viewer)

if __name__ == "__main__":
    demo.launch()