Spaces:
Running on Zero
Running on Zero
| from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine | |
| import gradio as gr | |
| from lightning.fabric import seed_everything | |
| import time | |
| import os | |
| import spaces | |
| import subprocess | |
| from Bio.PDB import MMCIFParser, PDBIO | |
| import gzip | |
def test_rfd3_from_notebook():
    """Run a tiny unconditional RFD3 generation as a smoke test.

    Generates one batch of two 40-residue designs in memory (``out_dir=None``)
    and reports how many atoms each generated structure contains.

    Returns:
        str: A multi-line report starting with "RFD3 test passed!" on success,
        or "Error: <message>" if building or running the engine raised.
        Errors are returned rather than raised so the text can be displayed
        directly.
    """
    # Fixed seed so repeated smoke tests are reproducible.
    seed_everything(0)

    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # Generate 40-residue proteins
        },
        diffusion_batch_size=2,  # Generate 2 structures per batch
    )
    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,   # None for unconditional generation
            out_dir=None,  # None to return results in memory (no file output)
            n_batches=1,   # Generate 1 batch
        )
        # NOTE(review): assumes `outputs` maps a batch key to a sequence of
        # structures exposing `.atom_array` (as used below) - confirm against
        # the rfd3 engine docs.
        lines = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            for i, struct in enumerate(data, start=1):
                lines.append(f"Structure {i}: {struct.atom_array.array_length()} Atoms")
        # Every line, including the last, ends with a newline.
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error: {str(e)}"
def unconditional_generation(num_batches, num_designs_per_batch, length, request: gr.Request = None):
    """Generate unconditional RFD3 designs and convert each one to PDB text.

    The engine writes its outputs into a new directory under
    ``./outputs/unconditional_generation/``. Each expected ``.cif.gz`` file is
    then converted to a PDB-format string with ``cif_gz_to_pdb``.

    Args:
        num_batches: Number of batches to run (``n_batches``); coerced to int.
        num_designs_per_batch: Designs per batch (``diffusion_batch_size``);
            coerced to int.
        length: Value stored under ``'length'`` in the RFD3 specification.
        request: The Gradio request object. Gradio injects it for parameters
            annotated as ``gr.Request`` when this function is an event
            handler. It is only used to tag the output directory with the
            session hash, and may be None (e.g. for direct calls).

    Returns:
        tuple[str, list[dict]]: The output directory, plus one dict per design
        with keys "batch", "design", "file" (path to the ``.cif.gz``) and
        "pdb" (the converted PDB text).

    Raises:
        RuntimeError: If running the engine or converting an output fails.
            The original exception is chained as ``__cause__``.
    """
    import uuid

    # UI widgets may deliver numbers as floats; range() requires ints.
    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)

    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,
    )

    # A hand-constructed gr.Request() has no session data (session_hash is None),
    # so use the injected request instead.
    # NOTE(review): confirm Gradio injects the request for a parameter with a
    # None default on the deployed Python version. If it does not,
    # session_hash is None and the random suffix below still keeps
    # directories unique.
    session_hash = getattr(request, "session_hash", None)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    # The timestamp only has one-second resolution and makedirs uses
    # exist_ok=False. The random suffix prevents two runs started in the same
    # second from colliding.
    run_id = uuid.uuid4().hex[:8]
    directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}_{run_id}"
    os.makedirs(directory, exist_ok=False)
    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,          # None for unconditional generation
            out_dir=directory,    # write generated structures to this run's directory
            n_batches=num_batches,
        )
        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # NOTE(review): file naming assumed to match the engine's
                # output layout - confirm against the files it writes.
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({"batch": batch, "design": design, "file": file_name, "pdb": cif_gz_to_pdb(file_name)})
        print(f"Generated {len(results)} design(s) in {directory}")
        return directory, results
    except Exception as e:
        raise RuntimeError(f"Error during generation: {str(e)}") from e
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of a generation output directory.

    Args:
        gen_directory: Directory to list.
        num_batches: Unused; kept so existing callers keep working.
        num_designs_per_batch: Unused; kept so existing callers keep working.

    Returns:
        str: The decoded stdout of ``ls -R``, or "Error: <message>" if the
        command fails (e.g. the directory does not exist).
    """
    try:
        # Pass argv as a list (no shell). The path is then never parsed by a
        # shell, so spaces work and metacharacters cannot inject commands.
        # "--" stops a path beginning with "-" from being read as an option.
        listing = subprocess.check_output(["ls", "-R", "--", str(gen_directory)])
        return listing.decode(errors="replace")
    except Exception as e:
        return f"Error: {str(e)}"
def cif_gz_to_pdb(cif_gz_path):
    """Convert a gzip-compressed mmCIF file into PDB-format text for a viewer.

    Args:
        cif_gz_path: Path to a ``.cif.gz`` file.

    Returns:
        str: The parsed structure serialised in PDB format.
    """
    import io

    parser = MMCIFParser(QUIET=True)
    # Decompress on the fly; the parser reads from the open text-mode handle.
    with gzip.open(cif_gz_path, 'rt') as handle:
        structure = parser.get_structure('model', handle)

    # PDBIO.save accepts a file-like object, so serialise into an in-memory
    # buffer instead of a temporary file.
    writer = PDBIO()
    writer.set_structure(structure)
    buffer = io.StringIO()
    writer.save(buffer)
    return buffer.getvalue()