Spaces:
Running on Zero
Running on Zero
File size: 3,502 Bytes
94ff1b9 c1e206f 5b8fa42 aced7ae c49e7b8 aced7ae 94ff1b9 26f1fe3 94ff1b9 aced7ae 94ff1b9 464a533 aced7ae 8705e46 c49e7b8 aced7ae 4f422de aced7ae 5b8fa42 94ff1b9 aced7ae c49e7b8 aced7ae c49e7b8 5d306a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
import gemmi
@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Smoke-test RFD3 with one small, seeded, in-memory unconditional run.

    Returns:
        str: A summary with one line per batch and one line per generated
        structure (its atom count), or ``"Error: <message>"`` if engine
        construction, generation, or summarising raised.
    """
    # Set seed for reproducibility
    seed_everything(0)

    # Configure RFD3 inference
    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # Generate 40-residue proteins
        },
        diffusion_batch_size=2,  # Generate 2 structures per batch
    )

    # Initialize engine and run generation. Any failure is reported as text
    # instead of being raised, so the caller always receives a string.
    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None for unconditional generation
            out_dir=None,  # None to return in memory (no file output)
            n_batches=1,  # Generate 1 batch
        )

        # Collect the pieces and join once rather than growing a string.
        parts = ["RFD3 test passed! Generated structures:\n"]
        for idx, data in outputs.items():
            parts.append(f"Batch {idx}: {len(data)} structure(s)\n")
            parts.extend(
                f"Structure {i+1}: {struct.atom_array.array_length()} Atoms\n"
                for i, struct in enumerate(data)
            )
        return "".join(parts)
    except Exception as e:
        return f"Error: {str(e)}"
# Unconditional generation: configure the engine, run it, and convert the written designs
@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length):
    """Run unconditional RFD3 generation into a fresh output directory.

    Args:
        num_batches: Number of batches, passed to the engine as ``n_batches``.
        num_designs_per_batch: Designs per batch, passed to the config as
            ``diffusion_batch_size``.
        length: Value for the ``'length'`` entry of the specification.

    Returns:
        tuple: ``(directory, results)`` where ``directory`` is the output
        directory created for this call and ``results`` is a list with one
        dict per (batch, design) pair holding ``"batch"``, ``"design"``,
        ``"file"`` (expected ``.cif.gz`` path) and ``"pdb"`` (whatever
        ``mcif_gz_to_pdb_string_gemmi`` returns for that file).

    Raises:
        FileExistsError: If the output directory already exists.
        RuntimeError: If generation or conversion fails; the original
            exception is attached as ``__cause__``.
    """
    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,  # structures per batch
    )

    # NOTE(review): this constructs a new gr.Request instead of receiving the
    # one Gradio creates for the event, so session_hash is presumably not the
    # caller's real session (likely None) -- confirm and, if so, have Gradio
    # supply the request. Directory uniqueness then rests on the one-second
    # timestamp alone.
    session_hash = gr.Request().session_hash
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    # exist_ok=False: refuse to reuse a directory left by another run.
    os.makedirs(directory, exist_ok=False)

    try:
        model = RFD3InferenceEngine(**config)
        # Results are read back from disk below, so the return value of
        # run() is not needed here.
        model.run(
            inputs=None,  # None for unconditional generation
            out_dir=directory,  # write the designs into this directory
            n_batches=num_batches,
        )

        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # assumes the engine names its outputs
                # "_<batch>_model_<design>.cif.gz" -- TODO confirm
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({
                    "batch": batch,
                    "design": design,
                    "file": file_name,
                    "pdb": mcif_gz_to_pdb_string_gemmi(file_name),
                })
        print(results)
        return directory, results
    except Exception as e:
        # Chain explicitly so the original traceback stays attached.
        raise RuntimeError(f"Error during generation: {str(e)}") from e
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of a generation directory.

    Args:
        gen_directory: Directory to list.
        num_batches: Unused; kept so existing callers keep working.
        num_designs_per_batch: Unused; kept so existing callers keep working.

    Returns:
        str: The decoded output of ``ls -R``, or ``"Error: <message>"`` if
        the command could not be run or exited with a non-zero status.
    """
    try:
        # Pass an argument list and no shell: the path reaches ``ls`` as one
        # argument, so spaces or shell metacharacters in it are neither split
        # nor executed. ``--`` stops a leading "-" being read as an option.
        listing = subprocess.check_output(["ls", "-R", "--", str(gen_directory)])
        return listing.decode()
    except Exception as e:
        return f"Error: {str(e)}"
def mcif_gz_to_pdb_string_gemmi(file_path: str) -> str:
    """Convert a structure file to a minimal PDB file written next to it.

    The input is read with ``gemmi.read_structure`` and written back with
    ``write_minimal_pdb``. Despite the function name, the return value is
    the path of the written PDB file, not the PDB text.

    The output path is the input path with a trailing ``.gz`` and then a
    trailing ``.cif`` / ``.mmcif`` / ``.mcif`` removed and ``.pdb`` appended,
    so ``model.cif.gz`` becomes ``model.pdb``. Only the end of the path is
    rewritten, and the result always differs from ``file_path``, so the
    input file is never overwritten.

    Args:
        file_path (str): Path to the structure file (e.g. ``.cif.gz``).

    Returns:
        str: Path of the PDB file that was written.

    Requires: pip install gemmi
    """
    st = gemmi.read_structure(file_path)
    st.setup_entities()  # Recommended for consistent entity handling

    # Strip known suffixes from the END of the path only. A plain
    # str.replace would leave e.g. "x.mcif.gz" or "x.cif" unchanged, so the
    # PDB would overwrite the input; it would also rewrite a ".cif.gz" that
    # appears inside a directory name.
    stem = file_path
    if stem.lower().endswith(".gz"):
        stem = stem[:-len(".gz")]
    for ext in (".mmcif", ".mcif", ".cif"):
        if stem.lower().endswith(ext):
            stem = stem[:-len(ext)]
            break
    pdb_path = stem + ".pdb"

    st.write_minimal_pdb(pdb_path)
    return pdb_path
|