Rich README + certify/attribute methods, design.py, bundled 8-mer atlas
Browse files
model.py
CHANGED
|
@@ -19,12 +19,17 @@ def _index(kmer):
|
|
| 19 |
|
| 20 |
|
| 21 |
class DnaOriginClassifier:
|
| 22 |
-
"""Discriminative 8-mer classifier of DNA origin
|
|
|
|
| 23 |
|
| 24 |
A fixed featurizer counts all 65,536 8-mers and normalizes to within-sequence
|
| 25 |
frequency; three discriminatively trained linear heads read it: a 5-class origin
|
| 26 |
head and two binary detectors (host vs non-host, engineered vs natural). No
|
| 27 |
alignment, no database. Requires only numpy and safetensors.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"""
|
| 29 |
|
| 30 |
def __init__(self, path="model.safetensors"):
|
|
@@ -34,6 +39,7 @@ class DnaOriginClassifier:
|
|
| 34 |
self.HW, self.Hb = t["host.weight"], t["host.bias"]
|
| 35 |
self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]
|
| 36 |
|
|
|
|
| 37 |
def features(self, seq):
|
| 38 |
seq = "".join(c for c in seq.upper() if c in _B)
|
| 39 |
v = np.zeros(VOCAB, dtype=np.float32)
|
|
@@ -61,8 +67,76 @@ class DnaOriginClassifier:
|
|
| 61 |
"""Higher means more likely engineered/synthetic (engineered vs natural head)."""
|
| 62 |
return float(self.EW @ self.features(seq) + self.Eb[0])
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
if __name__ == "__main__":
|
| 66 |
clf = DnaOriginClassifier()
|
| 67 |
seq = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
|
| 68 |
-
print("origin:", clf.classify(seq), "host_score:", round(clf.host_score(seq), 3)
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class DnaOriginClassifier:
|
| 22 |
+
"""Discriminative 8-mer classifier of DNA origin, with exact closed-form
|
| 23 |
+
interpretability and robustness because the model is linear in 8-mer counts.
|
| 24 |
|
| 25 |
A fixed featurizer counts all 65,536 8-mers and normalizes to within-sequence
|
| 26 |
frequency; three discriminatively trained linear heads read it: a 5-class origin
|
| 27 |
head and two binary detectors (host vs non-host, engineered vs natural). No
|
| 28 |
alignment, no database. Requires only numpy and safetensors.
|
| 29 |
+
|
| 30 |
+
Beyond classify/host_score/engineered_score, the linear form gives:
|
| 31 |
+
- attribute(seq): exact per-base contribution to a head (sums to the score)
|
| 32 |
+
- certify(seq): minimum base substitutions to flip a call (greedy, exact deltas)
|
| 33 |
"""
|
| 34 |
|
| 35 |
def __init__(self, path="model.safetensors"):
|
|
|
|
| 39 |
self.HW, self.Hb = t["host.weight"], t["host.bias"]
|
| 40 |
self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]
|
| 41 |
|
| 42 |
+
# ---- core ----
|
| 43 |
def features(self, seq):
|
| 44 |
seq = "".join(c for c in seq.upper() if c in _B)
|
| 45 |
v = np.zeros(VOCAB, dtype=np.float32)
|
|
|
|
| 67 |
"""Higher means more likely engineered/synthetic (engineered vs natural head)."""
|
| 68 |
return float(self.EW @ self.features(seq) + self.Eb[0])
|
| 69 |
|
| 70 |
+
# ---- closed-form interpretability and robustness ----
|
| 71 |
+
def _eff(self, head):
|
| 72 |
+
w = {"host": self.HW, "engineered": self.EW}[head]
|
| 73 |
+
return w / self.scale
|
| 74 |
+
|
| 75 |
+
def _bias(self, head):
|
| 76 |
+
return float({"host": self.Hb, "engineered": self.Eb}[head][0])
|
| 77 |
+
|
| 78 |
+
def attribute(self, seq, head="host"):
|
| 79 |
+
"""Exact per-base contribution of each position to the head score.
|
| 80 |
+
|
| 81 |
+
The score is a sum over 8-mer windows; this distributes each window's weight
|
| 82 |
+
across its 8 bases, so the contributions sum to (score - bias) with no
|
| 83 |
+
approximation. Returns an array of length len(seq).
|
| 84 |
+
"""
|
| 85 |
+
seq = "".join(c for c in seq.upper() if c in _B)
|
| 86 |
+
w = self._eff(head)
|
| 87 |
+
n = max(1, len(seq) - K + 1)
|
| 88 |
+
contrib = np.zeros(len(seq))
|
| 89 |
+
for i in range(len(seq) - K + 1):
|
| 90 |
+
j = _index(seq[i:i + K])
|
| 91 |
+
if j is None:
|
| 92 |
+
continue
|
| 93 |
+
per = w[j] / n / K
|
| 94 |
+
contrib[i:i + K] += per
|
| 95 |
+
return contrib
|
| 96 |
+
|
| 97 |
+
def certify(self, seq, head="host", max_edits=80):
|
| 98 |
+
"""Minimum base substitutions (greedy, with exact per-edit deltas) to flip the
|
| 99 |
+
head's sign. Returns the edit count, or None if not flipped within max_edits.
|
| 100 |
+
A near-tight upper bound on the true minimum adversarial radius.
|
| 101 |
+
"""
|
| 102 |
+
seq = [c for c in seq.upper() if c in _B]
|
| 103 |
+
w = self._eff(head)
|
| 104 |
+
b = self._bias(head)
|
| 105 |
+
n = max(1, len(seq) - K + 1)
|
| 106 |
+
|
| 107 |
+
def score(s):
|
| 108 |
+
tot = 0.0
|
| 109 |
+
for i in range(len(s) - K + 1):
|
| 110 |
+
j = _index(s[i:i + K])
|
| 111 |
+
if j is not None:
|
| 112 |
+
tot += w[j]
|
| 113 |
+
return tot / n + b
|
| 114 |
+
|
| 115 |
+
sign = 1 if score("".join(seq)) > 0 else -1
|
| 116 |
+
edits = 0
|
| 117 |
+
while sign * score("".join(seq)) > 0 and edits < max_edits:
|
| 118 |
+
s = "".join(seq)
|
| 119 |
+
best_d, best = 0.0, None
|
| 120 |
+
for p in range(len(seq)):
|
| 121 |
+
wins = range(max(0, p - K + 1), min(p, n - 1) + 1)
|
| 122 |
+
old = sum(w[_index(s[a:a + K])] for a in wins if _index(s[a:a + K]) is not None)
|
| 123 |
+
for nb in BASES:
|
| 124 |
+
if nb == seq[p]:
|
| 125 |
+
continue
|
| 126 |
+
s2 = s[:p] + nb + s[p + 1:]
|
| 127 |
+
new = sum(w[_index(s2[a:a + K])] for a in wins if _index(s2[a:a + K]) is not None)
|
| 128 |
+
d = (new - old) / n
|
| 129 |
+
if sign * d < best_d:
|
| 130 |
+
best_d, best = sign * d, (p, nb)
|
| 131 |
+
if best is None:
|
| 132 |
+
break
|
| 133 |
+
seq[best[0]] = best[1]
|
| 134 |
+
edits += 1
|
| 135 |
+
return edits if sign * score("".join(seq)) <= 0 else None
|
| 136 |
+
|
| 137 |
|
| 138 |
if __name__ == "__main__":
|
| 139 |
clf = DnaOriginClassifier()
|
| 140 |
seq = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
|
| 141 |
+
print("origin:", clf.classify(seq), "host_score:", round(clf.host_score(seq), 3),
|
| 142 |
+
"edits_to_flip:", clf.certify(seq), "top_base_contrib:", round(float(clf.attribute(seq).max()), 4))
|