RFdiffusion3 / utils /pipelines.py
gabboud's picture
fix path
8705e46
Raw
History Blame
3.45 kB
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
from Bio.PDB import MMCIFParser, PDBIO
import gzip
@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Run a small unconditional RFD3 generation as a smoke test.

    Seeds the RNGs, generates one batch of two 40-residue designs in memory
    (no files are written), and summarizes what came back.

    Returns:
        str: A human-readable report listing, per batch, the number of
        structures and each structure's atom count. If anything in the run
        or in the summary fails, the string ``"Error: <message>"`` is
        returned instead of raising, so the caller can display it directly.
    """
    # Set seed for reproducibility
    seed_everything(0)

    # Configure RFD3 inference
    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # Generate 40-residue proteins
        },
        diffusion_batch_size=2,  # Generate 2 structures per batch
    )

    # Engine construction, the run, and the summary all stay inside the try:
    # this is a best-effort check whose failures are reported as text.
    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,  # None for unconditional generation
            out_dir=None,  # None to return in memory (no file output)
            n_batches=1,  # Generate 1 batch
        )

        report = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            report.append(f"Batch {idx}: {len(data)} structure(s)")
            report.extend(
                f"Structure {i}: {struct.atom_array.array_length()} Atoms"
                for i, struct in enumerate(data, start=1)
            )
        # Every line, including the last, is newline-terminated.
        return "\n".join(report) + "\n"
    except Exception as e:
        return f"Error: {str(e)}"
# Unconditional generation: run the engine, then collect the written structures
@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length, request: gr.Request = None):
    """Generate unconditional designs, write them to disk, and load them as PDB text.

    Args:
        num_batches: Number of batches to run (coerced to ``int``).
        num_designs_per_batch: Designs per batch, passed to the config as
            ``diffusion_batch_size`` (coerced to ``int``).
        length: Value for the ``'length'`` entry of the RFD3 specification;
            passed through unchanged.
        request: Optional ``gr.Request``. Gradio supplies it automatically to
            event handlers whose parameter is annotated ``gr.Request``; when
            the function is called directly it may be omitted, in which case
            the session hash in the directory name is ``None``.

    Returns:
        tuple: ``(directory, results)`` where ``directory`` is the output
        folder created for this call and ``results`` is a list of dicts with
        keys ``"batch"``, ``"design"``, ``"file"`` (path to the ``.cif.gz``)
        and ``"pdb"`` (the structure converted to PDB text).

    Raises:
        RuntimeError: If the engine run or the output conversion fails; the
            original exception is chained as the cause.
    """
    # UI numeric widgets may hand over floats; range() below requires ints.
    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)

    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,
    )

    # A freshly constructed gr.Request() carries no session information, so
    # the hash must come from the request object Gradio passes in.
    session_hash = getattr(request, "session_hash", None)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    base_directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"

    # Each call needs its own fresh folder. The timestamp only has
    # one-second resolution, so on a name clash append a counter rather
    # than failing.
    directory = base_directory
    attempt = 0
    while True:
        try:
            os.makedirs(directory, exist_ok=False)
            break
        except FileExistsError:
            attempt += 1
            directory = f"{base_directory}_{attempt}"

    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,  # None for unconditional generation
            out_dir=directory,  # Write the generated structures into this folder
            n_batches=num_batches,
        )

        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # File naming expected from the engine for this run;
                # presumably "<input name>_<batch>_model_<design>.cif.gz"
                # with an empty input name -- TODO confirm against rfd3.
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({
                    "batch": batch,
                    "design": design,
                    "file": file_name,
                    "pdb": cif_gz_to_pdb(file_name),
                })
        print(results)
        return directory, results
    except Exception as e:
        raise RuntimeError(f"Error during generation: {str(e)}") from e
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of a generation directory.

    Args:
        gen_directory: Directory to list.
        num_batches: Unused; kept so existing callers keep working.
        num_designs_per_batch: Unused; kept so existing callers keep working.

    Returns:
        str: The decoded output of ``ls -R`` for ``gen_directory``, or
        ``"Error: <message>"`` if the command fails (for example when the
        directory does not exist).
    """
    try:
        # Pass an argument list instead of a shell string: the path is then
        # taken literally (spaces, ';', '$' ... are not interpreted), and
        # "--" keeps a path starting with '-' from being read as an option.
        listing = subprocess.check_output(["ls", "-R", "--", str(gen_directory)])
        return listing.decode()
    except Exception as e:
        return f"Error: {str(e)}"
def cif_gz_to_pdb(cif_gz_path):
    """Convert .cif.gz to PDB string for viewer.

    Args:
        cif_gz_path: Path to a gzip-compressed mmCIF file.

    Returns:
        str: The structure serialized in PDB format.
    """
    # Function-scope import keeps this helper self-contained; the standard
    # in-memory text buffer replaces a hand-written write() collector.
    from io import StringIO

    # Decompress & parse (text mode, so the parser receives str lines)
    parser = MMCIFParser(QUIET=True)
    with gzip.open(cif_gz_path, 'rt') as handle:
        structure = parser.get_structure('model', handle)

    # Write to an in-memory buffer instead of a file on disk
    writer = PDBIO()
    writer.set_structure(structure)
    buffer = StringIO()
    writer.save(buffer)
    return buffer.getvalue()