File size: 3,502 Bytes
94ff1b9
 
 
 
 
c1e206f
5b8fa42
aced7ae
c49e7b8
aced7ae
94ff1b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f1fe3
94ff1b9
aced7ae
94ff1b9
 
 
 
 
 
 
 
 
 
464a533
aced7ae
 
8705e46
c49e7b8
aced7ae
 
 
 
 
4f422de
aced7ae
 
 
 
5b8fa42
 
94ff1b9
aced7ae
 
 
c49e7b8
 
 
aced7ae
c49e7b8
 
 
 
 
 
 
 
 
 
5d306a8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
import gemmi



@spaces.GPU(duration=300)
def test_rfd3_from_notebook():
    """Smoke-test RFD3 with a tiny unconditional run and summarize the output.

    Seeds all RNGs with 0, generates one batch of two 40-residue designs with
    ``out_dir=None`` (results are returned from ``run`` rather than written),
    and reports the atom count of every generated structure.

    Returns:
        A multi-line human-readable summary on success, or
        ``"Error: <message>"`` if engine construction or generation raises.
        Errors are returned as text (not propagated) so the UI can show them.
    """
    # Fixed seed so repeated smoke tests are comparable.
    seed_everything(0)

    config = RFD3InferenceConfig(
        specification={
            'length': 40,  # 40-residue designs
        },
        diffusion_batch_size=2,  # 2 structures per batch
    )

    try:
        model = RFD3InferenceEngine(**config)
        outputs = model.run(
            inputs=None,   # no input structure: unconditional generation
            out_dir=None,  # no file output; results come back from run()
            n_batches=1,
        )

        # Collect lines and join once instead of repeated string concatenation.
        lines = ["RFD3 test passed! Generated structures:"]
        for idx, data in outputs.items():
            lines.append(f"Batch {idx}: {len(data)} structure(s)")
            for i, struct in enumerate(data, start=1):
                lines.append(
                    f"Structure {i}: {struct.atom_array.array_length()} Atoms"
                )
        return "\n".join(lines) + "\n"
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error: {str(e)}"
    


# Initialize engine and run generation
@spaces.GPU(duration=300)
def unconditional_generation(num_batches, num_designs_per_batch, length, request: gr.Request = None):
    """Run unconditional RFD3 generation and convert every design to PDB.

    Args:
        num_batches: Number of diffusion batches to run (``n_batches``).
            Coerced to ``int`` because number widgets may deliver floats.
        num_designs_per_batch: Structures per batch (``diffusion_batch_size``).
            Coerced to ``int`` for the same reason.
        length: Passed through unchanged as ``specification['length']``.
        request: Injected by Gradio because of the ``gr.Request`` annotation;
            only its ``session_hash`` is used, to tag the output directory.
            Defaults to ``None`` so direct callers keep working.

    Returns:
        ``(directory, results)``: the output directory and one dict per design
        with keys ``batch``, ``design``, ``file`` (expected ``.cif.gz`` path)
        and ``pdb`` (value returned by ``mcif_gz_to_pdb_string_gemmi``).

    Raises:
        RuntimeError: if engine construction, generation, or PDB conversion
            fails (original exception chained as ``__cause__``).
        FileExistsError: if the output directory already exists.
    """
    import uuid

    num_batches = int(num_batches)
    num_designs_per_batch = int(num_designs_per_batch)

    config = RFD3InferenceConfig(
        specification={
            'length': length,
        },
        diffusion_batch_size=num_designs_per_batch,
    )

    # A hand-built gr.Request() carries no session, which made every run use
    # "session_None" and let concurrent runs collide on the same directory.
    # Use the Gradio-injected request when present, otherwise a random id.
    session_hash = getattr(request, "session_hash", None) or uuid.uuid4().hex
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    os.makedirs(directory, exist_ok=False)

    try:
        model = RFD3InferenceEngine(**config)
        model.run(
            inputs=None,            # no input structure: unconditional generation
            out_dir=directory,      # designs are written under this directory
            n_batches=num_batches,
        )

        results = []
        for batch in range(num_batches):
            for design in range(num_designs_per_batch):
                # NOTE(review): assumes RFD3 names its outputs
                # "_<batch>_model_<design>.cif.gz" inside out_dir; confirm.
                file_name = os.path.join(directory, f"_{batch}_model_{design}.cif.gz")
                results.append({"batch": batch, "design": design, "file": file_name, "pdb": mcif_gz_to_pdb_string_gemmi(file_name)})

        print(results)
        return directory, results

    except Exception as e:
        raise RuntimeError(f"Error during generation: {str(e)}") from e
    
def collect_outputs(gen_directory, num_batches, num_designs_per_batch):
    """Return a recursive ``ls -R`` listing of a generation output directory.

    Args:
        gen_directory: Directory to list (converted with ``str``).
        num_batches: Unused; kept for interface compatibility.
        num_designs_per_batch: Unused; kept for interface compatibility.

    Returns:
        The decoded ``ls -R`` output, or ``"Error: <message>"`` if the
        command fails (e.g. the directory does not exist).
    """
    try:
        # Pass an argument list with no shell so the path is never parsed by a
        # shell: spaces or metacharacters (";", "$(...)") cannot break the
        # command or inject new ones. "--" stops a leading "-" being an option.
        file_list = subprocess.check_output(["ls", "-R", "--", str(gen_directory)])
        return file_list.decode()
    except Exception as e:
        # UI boundary: surface the failure as text rather than raising.
        return f"Error: {str(e)}"
    

def mcif_gz_to_pdb_string_gemmi(file_path: str) -> str:
    """
    Converts a .mcif.gz file to a PDB-formatted string.
    
    Args:
        file_path (str): Path to the .mcif.gz file.
    
    Returns:
        str: PDB content as string.
    
    Requires: pip install gemmi
    """
    st = gemmi.read_structure(file_path)
    st.setup_entities()  # Recommended for consistent entity handling [web:18]
    pdb_path = file_path.replace(".cif.gz", ".pdb")
    st.write_minimal_pdb(pdb_path)
    return pdb_path