phanerozoic commited on
Commit
7f7d423
·
verified ·
1 Parent(s): 125c662

Rich README + certify/attribute methods, design.py, bundled 8-mer atlas

Browse files
Files changed (1) hide show
  1. model.py +76 -2
model.py CHANGED
@@ -19,12 +19,17 @@ def _index(kmer):
19
 
20
 
21
  class DnaOriginClassifier:
22
- """Discriminative 8-mer classifier of DNA origin.
 
23
 
24
  A fixed featurizer counts all 65,536 8-mers and normalizes to within-sequence
25
  frequency; three discriminatively trained linear heads read it: a 5-class origin
26
  head and two binary detectors (host vs non-host, engineered vs natural). No
27
  alignment, no database. Requires only numpy and safetensors.
 
 
 
 
28
  """
29
 
30
  def __init__(self, path="model.safetensors"):
@@ -34,6 +39,7 @@ class DnaOriginClassifier:
34
  self.HW, self.Hb = t["host.weight"], t["host.bias"]
35
  self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]
36
 
 
37
  def features(self, seq):
38
  seq = "".join(c for c in seq.upper() if c in _B)
39
  v = np.zeros(VOCAB, dtype=np.float32)
@@ -61,8 +67,76 @@ class DnaOriginClassifier:
61
  """Higher means more likely engineered/synthetic (engineered vs natural head)."""
62
  return float(self.EW @ self.features(seq) + self.Eb[0])
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  if __name__ == "__main__":
66
  clf = DnaOriginClassifier()
67
  seq = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
68
- print("origin:", clf.classify(seq), "host_score:", round(clf.host_score(seq), 3))
 
 
19
 
20
 
21
  class DnaOriginClassifier:
22
+ """Discriminative 8-mer classifier of DNA origin, with exact closed-form
23
+ interpretability and robustness because the model is linear in 8-mer counts.
24
 
25
  A fixed featurizer counts all 65,536 8-mers and normalizes to within-sequence
26
  frequency; three discriminatively trained linear heads read it: a 5-class origin
27
  head and two binary detectors (host vs non-host, engineered vs natural). No
28
  alignment, no database. Requires only numpy and safetensors.
29
+
30
+ Beyond classify/host_score/engineered_score, the linear form gives:
31
+ - attribute(seq): exact per-base contribution to a head (sums to the score)
32
+ - certify(seq): minimum base substitutions to flip a call (greedy, exact deltas)
33
  """
34
 
35
  def __init__(self, path="model.safetensors"):
 
39
  self.HW, self.Hb = t["host.weight"], t["host.bias"]
40
  self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]
41
 
42
+ # ---- core ----
43
  def features(self, seq):
44
  seq = "".join(c for c in seq.upper() if c in _B)
45
  v = np.zeros(VOCAB, dtype=np.float32)
 
67
  """Higher means more likely engineered/synthetic (engineered vs natural head)."""
68
  return float(self.EW @ self.features(seq) + self.Eb[0])
69
 
70
+ # ---- closed-form interpretability and robustness ----
71
+ def _eff(self, head):
72
+ w = {"host": self.HW, "engineered": self.EW}[head]
73
+ return w / self.scale
74
+
75
+ def _bias(self, head):
76
+ return float({"host": self.Hb, "engineered": self.Eb}[head][0])
77
+
78
+ def attribute(self, seq, head="host"):
79
+ """Exact per-base contribution of each position to the head score.
80
+
81
+ The score is a sum over 8-mer windows; this distributes each window's weight
82
+ across its 8 bases, so the contributions sum to (score - bias) with no
83
+ approximation. Returns an array of length len(seq).
84
+ """
85
+ seq = "".join(c for c in seq.upper() if c in _B)
86
+ w = self._eff(head)
87
+ n = max(1, len(seq) - K + 1)
88
+ contrib = np.zeros(len(seq))
89
+ for i in range(len(seq) - K + 1):
90
+ j = _index(seq[i:i + K])
91
+ if j is None:
92
+ continue
93
+ per = w[j] / n / K
94
+ contrib[i:i + K] += per
95
+ return contrib
96
+
97
+ def certify(self, seq, head="host", max_edits=80):
98
+ """Minimum base substitutions (greedy, with exact per-edit deltas) to flip the
99
+ head's sign. Returns the edit count, or None if not flipped within max_edits.
100
+ A near-tight upper bound on the true minimum adversarial radius.
101
+ """
102
+ seq = [c for c in seq.upper() if c in _B]
103
+ w = self._eff(head)
104
+ b = self._bias(head)
105
+ n = max(1, len(seq) - K + 1)
106
+
107
+ def score(s):
108
+ tot = 0.0
109
+ for i in range(len(s) - K + 1):
110
+ j = _index(s[i:i + K])
111
+ if j is not None:
112
+ tot += w[j]
113
+ return tot / n + b
114
+
115
+ sign = 1 if score("".join(seq)) > 0 else -1
116
+ edits = 0
117
+ while sign * score("".join(seq)) > 0 and edits < max_edits:
118
+ s = "".join(seq)
119
+ best_d, best = 0.0, None
120
+ for p in range(len(seq)):
121
+ wins = range(max(0, p - K + 1), min(p, n - 1) + 1)
122
+ old = sum(w[_index(s[a:a + K])] for a in wins if _index(s[a:a + K]) is not None)
123
+ for nb in BASES:
124
+ if nb == seq[p]:
125
+ continue
126
+ s2 = s[:p] + nb + s[p + 1:]
127
+ new = sum(w[_index(s2[a:a + K])] for a in wins if _index(s2[a:a + K]) is not None)
128
+ d = (new - old) / n
129
+ if sign * d < best_d:
130
+ best_d, best = sign * d, (p, nb)
131
+ if best is None:
132
+ break
133
+ seq[best[0]] = best[1]
134
+ edits += 1
135
+ return edits if sign * score("".join(seq)) <= 0 else None
136
+
137
 
138
  if __name__ == "__main__":
139
  clf = DnaOriginClassifier()
140
  seq = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
141
+ print("origin:", clf.classify(seq), "host_score:", round(clf.host_score(seq), 3),
142
+ "edits_to_flip:", clf.certify(seq), "top_base_contrib:", round(float(clf.attribute(seq).max()), 4))