Spaces:
Running on Zero
Running on Zero
File size: 4,057 Bytes
362a4c3 a887836 362a4c3 a887836 f19aa8b 94ff1b9 743042b a887836 362a4c3 f19aa8b 5fe520a 362a4c3 27305e9 362a4c3 27305e9 362a4c3 a887836 362a4c3 27305e9 aced7ae 27305e9 5d306a8 aced7ae 65e7711 5d306a8 65e7711 5d306a8 65e7711 5d306a8 743042b 65e7711 5d306a8 65e7711 5d306a8 65e7711 5d306a8 65e7711 5d306a8 743042b 5d306a8 743042b 5d306a8 743042b 65e7711 743042b 27305e9 362a4c3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from atomworks.io.utils.visualize import view
from lightning.fabric import seed_everything
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
from utils import download_weights
from utils.pipelines import test_rfd3_from_notebook, unconditional_generation
#from gradio_molecule3d import Molecule3D
# Fetch the model weights at import time, before the UI is built (the page
# text below tells users that models are auto-downloaded on launch).
download_weights()

# Gradio UI
# NOTE(review): the original indentation of this file was lost; the Row
# membership below is reconstructed from the statement order — confirm it
# against the deployed layout.
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown("# RFdiffusion3 (RFD3) for Backbone generation")
    gr.Markdown("Models auto-downloaded on launch. Click to test.")

    # Smoke test: runs test_rfd3_from_notebook and shows its text output.
    test_btn = gr.Button("Run RFD3 Test")
    output = gr.Textbox(label="Test Result")
    test_btn.click(test_rfd3_from_notebook, outputs=output)

    gr.Markdown("Unconditional generation of backbones")
    with gr.Row():
        # precision=0 makes each Number field yield an integer.
        num_designs_per_batch = gr.Number(
            value=2,
            label="Number of Designs per Batch",
            precision=0,
            minimum=1,
            maximum=8,
        )
        num_batches = gr.Number(
            value=5,
            label="Number of Batches",
            precision=0,
            minimum=1,
            maximum=10,
        )
        length = gr.Number(
            value=40,
            label="Length of Protein (number of residues)",
            precision=0,
            minimum=10,
            maximum=200,
        )

    # Per-session state written by unconditional_generation: an output
    # directory and a list of result dicts. The handlers below only read the
    # "batch", "design" and "pdb" keys of each dict; "pdb" is opened as a file
    # path. Presumably "batch"/"design" are ints — verify against the pipeline.
    gen_directory = gr.State(None)
    gen_results = gr.State(None)
    gen_btn = gr.Button("Run Unconditional Generation")

    # New visualize section
    with gr.Row():
        batch_dropdown = gr.Dropdown(
            choices=[],
            label="Select Batch",
            visible=True,
        )
        design_dropdown = gr.Dropdown(
            choices=[],
            label="Select Design",
            visible=True,
        )
    visualize_btn = gr.Button("Visualize", visible=True)
    # Give the initial text to the constructor rather than assigning
    # ``.value`` after the component has been created.
    display_state = gr.Textbox(
        label="Selected Batch and Design",
        value="Please Select a Batch and Design number to show sequence",
        visible=True,
    )
    #viewer = Molecule3D(visible=True)

    def update_batch_choices(result):
        """Rebuild the batch dropdown after a generation run.

        Args:
            result: The list of result dicts stored in ``gen_results``, or
                None / empty when nothing has been generated.

        Returns:
            A Gradio update for ``batch_dropdown``. The current selection is
            cleared so a batch number from a previous run cannot linger
            outside the new choice list; clearing it also fires
            ``batch_dropdown.change``, which resets the design dropdown.
        """
        if not result:
            return gr.update(choices=[], value=None)
        batches = sorted({d["batch"] for d in result})
        return gr.update(choices=batches, value=None, visible=True)

    # NOTE(review): inputs are passed as (num_batches, num_designs_per_batch,
    # length) — the first two are in the reverse of their on-screen order.
    # Confirm this matches unconditional_generation's parameter order.
    gen_btn.click(
        unconditional_generation,
        inputs=[num_batches, num_designs_per_batch, length],
        outputs=[gen_directory, gen_results],
    ).then(
        update_batch_choices,
        inputs=gen_results,
        outputs=batch_dropdown,
    )

    def update_designs(batch, result):
        """List the design numbers available for the selected batch.

        Args:
            batch: The value selected in ``batch_dropdown`` (None when
                cleared). Converted with ``int()`` so it compares the same
                way as in ``visualize_selection``.
            result: The list of result dicts stored in ``gen_results``.

        Returns:
            A Gradio update for ``design_dropdown`` with the matching design
            numbers, sorted and de-duplicated. The previous design selection
            is cleared because it may not exist in the newly chosen batch.
        """
        if batch is None or not result:
            return gr.update(choices=[], value=None)
        batch = int(batch)
        designs = sorted({d["design"] for d in result if d["batch"] == batch})
        return gr.update(choices=designs, value=None)

    batch_dropdown.change(
        update_designs,
        inputs=[batch_dropdown, gen_results],
        outputs=[design_dropdown],
    )

    def visualize_selection(batch, design, result):
        """Show the PDB file for the selected batch/design in the textbox.

        Args:
            batch: Selected batch value (None when nothing is selected).
            design: Selected design value (None when nothing is selected).
            result: The list of result dicts stored in ``gen_results``.

        Returns:
            A Gradio update for ``display_state``. It is a no-op update when
            the selection is incomplete or nothing has been generated, and a
            short notice when no result matches the selection. Otherwise it
            holds the file location followed by the file's contents.
        """
        if batch is None or design is None or not result:
            return gr.update()
        batch, design = int(batch), int(design)
        # A default of None avoids StopIteration when no result matches.
        pdb_path = next(
            (
                d["pdb"]
                for d in result
                if d["batch"] == batch and d["design"] == design
            ),
            None,
        )
        if pdb_path is None:
            return gr.update(
                value=f"No design found for Batch: {batch}, Design: {design}",
                visible=True,
            )
        with open(pdb_path, "r", encoding="utf-8") as f:
            pdb_str = f.read()
        # Report where the file lives, then its contents (previously the
        # contents were interpolated in place of the path as well).
        return gr.update(
            value=f"Selected Batch: {batch}, Design: {design}, saved at {pdb_path}:\n {pdb_str}",
            visible=True,
        )

    visualize_btn.click(
        visualize_selection,
        inputs=[batch_dropdown, design_dropdown, gen_results],
        outputs=display_state,
    )
    #def load_viewer(batch, design, result):
    #    if batch is None or design is None:
    #        return gr.update()
    #    pdb_data = next(d["pdb"] for d in result if d["batch"] == int(batch) and d["design"] == int(design))
    #    return gr.update(value=pdb_data, visible=True, reps=[{"style": "cartoon"}])  # Customize style
    #
    #visualize_btn.click(load_viewer, inputs=[batch_dropdown, design_dropdown, gen_results], outputs=viewer)

if __name__ == "__main__":
    demo.launch()
|