phanerozoic commited on
Commit
6f4134a
·
verified ·
1 Parent(s): 7f7d423

Rich README + certify/attribute methods, design.py, bundled 8-mer atlas

Browse files
Files changed (1) hide show
  1. design.py +81 -0
design.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Closed-form inverse design against the classifier's linear heads.
2
+
3
+ Because each head is linear in 8-mer counts, sequences that maximize or minimize a
4
+ head score can be found by coordinate ascent. Two modes:
5
+ - free: design a sequence from scratch toward maximum or minimum score
6
+ - synonymous: re-choose codons of a given coding sequence to push the score while
7
+ preserving the encoded protein
8
+
9
+ Usage:
10
+ from design import free_design, synonymous_design
11
+ from model import DnaOriginClassifier
12
+ clf = DnaOriginClassifier("model.safetensors")
13
+ seq = free_design(clf, length=300, direction="max") # maximally host-like
14
+ recoded = synonymous_design(clf, cds, direction="min") # least host-like, same protein
15
+ """
16
+ import random
17
+ import itertools
18
+
19
+ BASES = "ACGT"
20
+ _CODON = {}
21
+ _aa = "KNKNTTTTRSRSIIMIQHQHPPPPRRRRLLLLEDEDAAAAGGGGVVVV*Y*YSSSSLFLF*CWCLFLF" # standard code, see below
22
+ # build the standard genetic code explicitly to avoid ordering ambiguity
23
+ _BASES4 = "TCAG"
24
+ _AAS = ("FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG")
25
+ _i = 0
26
+ for a in _BASES4:
27
+ for b in _BASES4:
28
+ for c in _BASES4:
29
+ _CODON[a + b + c] = _AAS[_i]; _i += 1
30
+ _AA2COD = {}
31
+ for cod, aa in _CODON.items():
32
+ _AA2COD.setdefault(aa, []).append(cod)
33
+
34
+
35
+ def _protein(seq):
36
+ return "".join(_CODON.get(seq[i:i + 3], "X") for i in range(0, len(seq) - 2, 3))
37
+
38
+
39
+ def free_design(clf, length=300, direction="max", passes=6, seed=0):
40
+ rng = random.Random(seed)
41
+ seq = [rng.choice(BASES) for _ in range(length)]
42
+ d = 1 if direction == "max" else -1
43
+ for _ in range(passes):
44
+ for p in range(length):
45
+ best, bs = seq[p], clf.host_score("".join(seq))
46
+ for nb in BASES:
47
+ seq[p] = nb
48
+ sc = clf.host_score("".join(seq))
49
+ if d * sc > d * bs:
50
+ bs, best = sc, nb
51
+ seq[p] = best
52
+ return "".join(seq)
53
+
54
+
55
+ def synonymous_design(clf, cds, direction="max", passes=4):
56
+ cod = [cds[i:i + 3] for i in range(0, len(cds) - 2, 3)]
57
+ d = 1 if direction == "max" else -1
58
+ for _ in range(passes):
59
+ for ci in range(len(cod)):
60
+ aa = _CODON.get(cod[ci])
61
+ if aa not in _AA2COD:
62
+ continue
63
+ best, bs = cod[ci], clf.host_score("".join(cod))
64
+ for cand in _AA2COD[aa]:
65
+ cod[ci] = cand
66
+ sc = clf.host_score("".join(cod))
67
+ if d * sc > d * bs:
68
+ bs, best = sc, cand
69
+ cod[ci] = best
70
+ out = "".join(cod)
71
+ assert _protein(out) == _protein(cds), "protein not preserved"
72
+ return out
73
+
74
+
75
+ if __name__ == "__main__":
76
+ from model import DnaOriginClassifier
77
+ clf = DnaOriginClassifier()
78
+ mx = free_design(clf, 300, "max")
79
+ mn = free_design(clf, 300, "min")
80
+ print("max-host design score:", round(clf.host_score(mx), 2))
81
+ print("min-host design score:", round(clf.host_score(mn), 2))