(* ============================================================================
Certified robustness of the dna-origin-classifier, machine-checked in Rocq 9.
The published model scores a sequence as a linear function of its k-mer
counts: score(s) = (1/T) * sum over windows w of s of u(w) + bias, where
u(w) is the effective weight of the k-mer w (the stored weight divided by
its feature scale) and T is the number of windows.
Because the score is linear, a single base substitution perturbs the counts
of at most k windows, each by an effective weight bounded by Umax in
magnitude, so one edit moves the unnormalized margin by at most k*(2*Umax),
and n edits by at most n*k*(2*Umax). The main theorem turns that bound into
a certificate: if n * k * (2*Umax) < |margin(s)|, no n-substitution edit of s
can change the sign of the call.
This proven radius is a lower bound on the adversarial edit distance: no
substitution count within it can flip the call. The Python `certify` exhibits
an actual flip by greedy search, an upper bound on the same distance, so the
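two bracket it. For the published host head the effective weights are
heavy-tailed (max |u| about 95) and the achievable score is small (|score|
at most about 9). Since margin = T * score and k = 8, the proven radius is
the largest n with n * 16 * Umax < T * |score|, so n < 9 * T / 1520, about
one edit per 170 windows: only a few edits for sequences of up to roughly a
thousand windows.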
Everything is proved over the rationals with no axioms and no admits.
========================================================================== *)
From Stdlib Require Import List Arith Lia QArith Qabs.
From Stdlib Require Import Lqa.
Import ListNotations.
Open Scope Q_scope.
(* Section Detector: every definition and lemma up to End Detector is
parametric in the alphabet, the k-mer width, the weight function and the
weight bound declared here. *)
Section Detector.
(* An abstract alphabet. The published model instantiates it with the four
DNA bases; nothing below depends on the cardinality, and no equality test
on Base is ever needed (substitution compares positions, not symbols). *)
Variable Base : Type.
(* k-mer width (8 in the published model). *)
Variable k : nat.
(* Effective per-k-mer weight: the stored weight divided by its feature scale. *)
Variable u : list Base -> Q.
(* A uniform bound on the magnitude of every effective weight. The published
host head has a concrete, finite Umax computed from its 65,536 weights. *)
Variable Umax : Q.
(* Hu is the only assumption on u. Because Qabs is nonnegative it forces
0 <= Umax, which single_subst derives from Hu at u nil; a negative Umax
would make Hu unsatisfiable and every result below vacuous. *)
Hypothesis Hu : forall w, Qabs (u w) <= Umax.
(* A sequence is a total map from index to symbol. When k <= L, the windows
counted by nwin L read only indices below the length L. When L < k, the
single window that nwin L still counts reads indices up to k - 1, past
the end of the sequence. *)
(* The j-th length-k window of f: the list of f (j + p) for p = 0 .. k - 1.
The %nat annotation keeps j + p in nat_scope; Q_scope is open in this
file. *)
Definition win (f : nat -> Base) (j : nat) : list Base :=
map (fun p => f (j + p)%nat) (seq 0 k).
(* Number of windows of a length-L sequence (for k <= L this is L - k + 1).
Natural-number subtraction truncates at 0, so for L < k this still yields
1 rather than 0. Being a successor, nwin L is never 0, so the window count
used to normalize the score is always positive. *)
Definition nwin (L : nat) : nat := S (L - k).
(* Unnormalized score: the sum of effective weights over all windows
j = 0 .. nwin L - 1, with no division by the window count and no bias
(the bias enters in margin). The sum is a fold_right Qplus 0 over a mapped
list so the list lemmas below apply to it directly. *)
Definition raw (L : nat) (f : nat -> Base) : Q :=
fold_right Qplus 0 (map (fun j => u (win f j)) (seq 0 (nwin L))).
(* Substitute symbol b at position i: the result agrees with f at every
index except i, where it returns b. Only positions are compared (with
Nat.eqb), so no equality test on Base is required. win_eq_off destructs
on exactly the test Nat.eqb j i, so keep this shape. *)
Definition upd (f : nat -> Base) (i : nat) (b : Base) : nat -> Base :=
fun j => if Nat.eqb j i then b else f j.
(* A window of index j contains position i when j <= i < j + k, i.e. when
substituting at i may change win f j (see win_eq_off). It is a bool so it
can drive filter in count_le_k and the case split in single_subst. *)
Definition contains (i j : nat) : bool := andb (Nat.leb j i) (Nat.ltb i (j + k)).
(* The per-edit margin bound K = k * (2 * Umax): one substitution moves raw
by at most this (single_subst). The factor is written Umax + Umax rather
than 2 * Umax because single_subst rewrites with exactly that shape. *)
Definition K : Q := inject_Z (Z.of_nat k) * (Umax + Umax).
(* ---------- generic rational / list helpers ---------- *)
(* Triangle inequality for a difference: |a - b| <= |a| + |b|. After
rewriting |b| as |-b| the goal is exactly Qabs_triangle for a + -b. *)
Lemma Qabs_minus_le : forall a b : Q, Qabs (a - b) <= Qabs a + Qabs b.
Proof.
intros x y. unfold Qminus.
rewrite <- (Qabs_opp y).
apply Qabs_triangle.
Qed.
(* |0| <= 0, stated as an inequality so it can close goals of the form
Qabs t <= ... once t has been rewritten to 0. *)
Lemma Qabs_0_le : Qabs 0 <= 0.
Proof. rewrite Qabs_pos by apply Qle_refl. apply Qle_refl. Qed.
(* Triangle inequality over a list: the absolute value of a sum is at most
the sum of the absolute values. Induction on l, using Qabs_triangle at
each cons cell. *)
Lemma Qabs_sum_le : forall l : list Q,
Qabs (fold_right Qplus 0 l) <= fold_right Qplus 0 (map Qabs l).
Proof.
induction l as [|x l IH]; simpl.
- apply Qle_refl.
- eapply Qle_trans; [apply Qabs_triangle|].
apply Qplus_le_compat; [apply Qle_refl| exact IH].
Qed.
(* Pointwise monotonicity of list sums: if a j <= b j for every j in l,
then the sum of a over l is at most the sum of b over l. The bound is
only required on members of l, so callers may use In j l. *)
Lemma fold_right_Qplus_le : forall (l : list nat) (a b : nat -> Q),
(forall j, In j l -> a j <= b j) ->
fold_right Qplus 0 (map a l) <= fold_right Qplus 0 (map b l).
Proof.
induction l as [|x l IH]; intros a b H; simpl.
- apply Qle_refl.
- apply Qplus_le_compat.
+ apply H. left. reflexivity.
+ apply IH. intros j Hj. apply H. right. exact Hj.
Qed.
(* The difference of two sums over the same index list equals the sum of
the pointwise differences. The equation is rational equality (==), not
Leibniz equality, so use it with setoid rewriting. *)
Lemma diff_of_folds : forall (l : list nat) (A B : nat -> Q),
fold_right Qplus 0 (map A l) - fold_right Qplus 0 (map B l)
== fold_right Qplus 0 (map (fun j => A j - B j) l).
Proof.
induction l as [|x l IH]; intros A B; simpl.
- ring.
- rewrite <- IH. ring.
Qed.
(* The rational image of S n is one more than the rational image of n.
indicator_sum uses it to step a length count, and reach_bound uses it to
step the edit count. *)
Lemma inject_Z_Snat : forall n : nat,
inject_Z (Z.of_nat (S n)) == inject_Z (Z.of_nat n) + 1.
Proof.
intros n.
(* Do the successor step in Z, where lia proves it. *)
assert (HZ : Z.of_nat (S n) = (Z.of_nat n + 1)%Z) by lia.
rewrite HZ. rewrite inject_Z_plus.
(* inject_Z 1 and the Q literal 1 are convertible. *)
replace (inject_Z 1) with 1 by reflexivity. reflexivity.
Qed.
(* Summing c over the elements of l that satisfy P, and 0 over the rest,
gives c times the number of elements of l that satisfy P. *)
Lemma indicator_sum : forall (l : list nat) (P : nat -> bool) (c : Q),
fold_right Qplus 0 (map (fun j => if P j then c else 0) l)
== c * inject_Z (Z.of_nat (length (filter P l))).
Proof.
induction l as [|x l IH]; intros P c; simpl.
- ring.
- destruct (P x) eqn:Hx; simpl.
(* P x holds: x adds c to the sum and 1 to the filtered length. *)
+ rewrite IH. rewrite (inject_Z_Snat (length (filter P l))). ring.
+ rewrite IH. ring.
Qed.
(* ---------- structural facts about windows and substitution ---------- *)
(* A window that does not contain the edited position is unchanged.
Each index j + p read by window j has p < k. If it equalled i, then
j <= i < j + k, so window j would contain i, contradicting the
hypothesis. *)
Lemma win_eq_off : forall f i b j,
contains i j = false -> win (upd f i b) j = win f j.
Proof.
intros f i b j Hc. unfold win. apply map_ext_in.
intros p Hp. apply in_seq in Hp. destruct Hp as [_ Hp]. simpl in Hp.
unfold upd. destruct (Nat.eqb (j + p) i) eqn:E; [|reflexivity].
apply Nat.eqb_eq in E. exfalso.
assert (Hct : contains i j = true).
{ unfold contains. apply andb_true_intro. split.
- apply Nat.leb_le. lia.
- apply Nat.ltb_lt. lia. }
rewrite Hct in Hc. discriminate.
Qed.
(* At most k windows contain a given position. Every window j that contains
i satisfies i + 1 - k <= j <= i (nat subtraction clamps at 0 when
i + 1 < k). The containing windows are therefore distinct members of the
k-element list seq (i + 1 - k) k, and NoDup_incl_length bounds their
number by k. *)
Lemma count_le_k : forall L i,
(length (filter (contains i) (seq 0 (nwin L))) <= k)%nat.
Proof.
intros L i.
apply Nat.le_trans with (length (seq (i + 1 - k) k)).
- apply NoDup_incl_length.
+ apply NoDup_filter. apply seq_NoDup.
+ intros x Hx. apply filter_In in Hx. destruct Hx as [_ Hc].
unfold contains in Hc. apply andb_true_iff in Hc. destruct Hc as [Hle Hlt].
apply Nat.leb_le in Hle. apply Nat.ltb_lt in Hlt.
apply in_seq. split; lia.
- rewrite length_seq. apply Nat.le_refl.
Qed.
(* ---------- one substitution moves raw by at most K ---------- *)
(* One substitution moves raw by at most K = k * (2 * Umax).
Proof outline:
1. write the change in raw as a sum of per-window differences;
2. bound the absolute value of that sum by the sum of absolute values;
3. bound each term by 2 * Umax if the window contains i, else by 0
(win_eq_off);
4. count the containing windows with indicator_sum and count_le_k. *)
Lemma single_subst : forall L f i b,
Qabs (raw L f - raw L (upd f i b)) <= K.
Proof.
intros L f i b. unfold raw.
(* Difference of sums becomes a sum of per-window differences. *)
rewrite (diff_of_folds (seq 0 (nwin L))
(fun j => u (win f j)) (fun j => u (win (upd f i b) j))).
eapply Qle_trans.
{ apply Qabs_sum_le. }
rewrite map_map.
eapply Qle_trans.
(* Termwise bound: 2 * Umax on windows containing i, 0 elsewhere. *)
{ apply (fold_right_Qplus_le (seq 0 (nwin L))
(fun j => Qabs (u (win f j) - u (win (upd f i b) j)))
(fun j => if contains i j then Umax + Umax else 0)).
intros j Hj. cbv beta. destruct (contains i j) eqn:Hcj.
- eapply Qle_trans; [apply Qabs_minus_le|].
apply Qplus_le_compat; apply Hu.
(* The window does not see position i, so both weights coincide. *)
- assert (Hw : win (upd f i b) j = win f j) by (apply win_eq_off; exact Hcj).
rewrite Hw.
setoid_replace (u (win f j) - u (win f j)) with 0 by ring.
apply Qabs_0_le. }
(* The indicator sum is 2 * Umax times the number of containing windows. *)
rewrite (indicator_sum (seq 0 (nwin L)) (contains i) (Umax + Umax)).
unfold K.
rewrite (Qmult_comm (Umax + Umax)).
apply Qmult_le_compat_r.
(* That number is at most k; compare in Z so lia can use count_le_k. *)
- rewrite <- Zle_Qle. pose proof (count_le_k L i) as Hck. lia.
(* Scaling preserves the order only for a nonnegative factor; Hu at u nil
gives 0 <= Qabs (u nil) <= Umax. *)
- assert (HU0 : 0 <= Umax).
{ eapply Qle_trans; [apply (Qabs_nonneg (u nil))| apply Hu]. }
lra.
Qed.
(* ---------- n substitutions move raw by at most n * K ---------- *)
(* g is reachable from f by exactly n single-symbol substitutions, applied
one after another. A substitution may rewrite a symbol with itself or hit
a position already edited, and positions are not restricted to indices
below L. So reach L f g n says that g differs from f by at most n
effective edits. The length L is a parameter that the constructors never
inspect. *)
Inductive reach (L : nat) (f : nat -> Base) : (nat -> Base) -> nat -> Prop :=
| reach0 : reach L f f 0
| reachS : forall g i b n, reach L f g n -> reach L f (upd g i b) (S n).
(* n substitutions move raw by at most n * K. Proof by induction on the
reach derivation: with zero edits the difference is 0, and each further
edit adds at most K by the triangle inequality and single_subst. *)
Lemma reach_bound : forall L f g n,
reach L f g n ->
Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K.
Proof.
intros L f g n H. induction H.
(* Zero edits: both sides are 0. *)
- setoid_replace (raw L f - raw L f) with 0 by ring.
setoid_replace (inject_Z (Z.of_nat 0) * K) with 0
by (replace (inject_Z (Z.of_nat 0)) with 0 by reflexivity; ring).
apply Qabs_0_le.
(* One more edit: split through the intermediate g, bound the two parts by
the induction hypothesis and single_subst, then use S n * K == n * K + K. *)
- eapply Qle_trans.
+ setoid_replace (raw L f - raw L (upd g i b))
with ((raw L f - raw L g) + (raw L g - raw L (upd g i b))) by ring.
apply Qabs_triangle.
+ eapply Qle_trans.
{ apply Qplus_le_compat; [exact IHreach| apply single_subst]. }
setoid_replace (inject_Z (Z.of_nat (S n)) * K)
with (inject_Z (Z.of_nat n) * K + K) by (rewrite (inject_Z_Snat n); ring).
apply Qle_refl.
Qed.
(* ---------- the certificate ---------- *)
(* The unnormalized margin: positive favors the class, negative opposes it.
It shares the sign of the normalized score because the window count is
positive, and a same-length edit leaves the window count fixed.
With b0 the bias and T = nwin L, the header's score is raw / T + b0, so
margin = T * score. Since nwin L is a successor, T is never 0. *)
Definition margin (L : nat) (b0 : Q) (f : nat -> Base) : Q :=
raw L f + b0 * inject_Z (Z.of_nat (nwin L)).
(* For the same L and b0 the bias term b0 * nwin L is shared by both
margins, so it cancels and only the raw scores remain. *)
Lemma margin_diff : forall L b0 f g,
margin L b0 f - margin L b0 g == raw L f - raw L g.
Proof.
intros L b0 f g.
unfold margin.
ring.
Qed.
(* If |d| < |x|, then x - d has the same strict sign as x.
For x > 0: d <= |d| < |x| = x.
For x < 0: x = -|x| < -|d| <= d, where -|d| <= d comes from
Qabs_Qle_condition applied to |d| <= |d|. *)
Lemma sign_preserved : forall x d : Q,
Qabs d < Qabs x ->
(0 < x -> 0 < x - d) /\ (x < 0 -> x - d < 0).
Proof.
intros x d Hlt. split; intro Hx.
- assert (Hax : Qabs x == x) by (apply Qabs_pos; lra).
assert (Hd : d <= Qabs d) by apply Qle_Qabs.
lra.
- assert (Hax : Qabs x == - x) by (apply Qabs_neg; lra).
assert (Hcc := proj1 (Qabs_Qle_condition d (Qabs d)) (Qle_refl (Qabs d))).
destruct Hcc as [Hc _].
lra.
Qed.
(* Main theorem. If g is reachable from f by n substitutions and the
certificate n * K < |margin(f)| holds, the call's sign cannot change.
K = k * (2 * Umax), so the certified radius is the largest n with
n * k * 2 * Umax < |margin|, which is what `certify` returns.
NOTE(review): the file header describes the Python `certify` as
exhibiting an actual flip by greedy search, an upper bound on the edit
distance. Confirm whether it reports this proven radius, the greedy flip,
or both.
The inequality is strict, so a margin of exactly 0 is never certified. *)
Theorem certified_robust : forall L b0 f g n,
reach L f g n ->
inject_Z (Z.of_nat n) * K < Qabs (margin L b0 f) ->
(0 < margin L b0 f -> 0 < margin L b0 g) /\
(margin L b0 f < 0 -> margin L b0 g < 0).
Proof.
intros L b0 f g n Hreach Hcert.
(* n edits move raw, and therefore margin, by at most n * K. *)
assert (Hb : Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K)
by (apply reach_bound; exact Hreach).
set (x := margin L b0 f). set (y := margin L b0 g).
(* The change x - y in margin is strictly smaller than |x|. *)
assert (Hd : Qabs (x - y) < Qabs x).
{ setoid_replace (x - y) with (raw L f - raw L g)
by (unfold x, y; apply margin_diff).
eapply Qle_lt_trans; [exact Hb| exact Hcert]. }
(* y = x - (x - y); subtracting a smaller-magnitude change keeps x's sign. *)
destruct (sign_preserved x (x - y) Hd) as [Hpos Hneg].
split; intro Hsgn.
- assert (Hy : 0 < x - (x - y)) by (apply Hpos; exact Hsgn).
setoid_replace y with (x - (x - y)) by ring. exact Hy.
- assert (Hy : x - (x - y) < 0) by (apply Hneg; exact Hsgn).
setoid_replace y with (x - (x - y)) by ring. exact Hy.
Qed.
End Detector.
(* Instantiation to the published model's setting: the four DNA bases and k = 8.
The host head's effective weights have a concrete finite bound Umax_host, so
the guarantee holds with per-substitution margin constant K = 8 * (2 * Umax_host).
The radius `certify` reports is the largest n with n * K < |margin|
(NOTE(review): confirm this against the Python `certify`). *)
(* The four DNA bases as a plain enumeration. No structure on it is needed:
the section's proofs never compare symbols. *)
Inductive DNA := dA | dC | dG | dT.
(* Specialization of certified_robust to Base := DNA and k := 8, for any
weight function u bounded by Umax. Closing the section made the section
variables explicit arguments, each definition abstracting only the ones
it uses. Hence the forms K 8 Umax, margin DNA 8 u, reach DNA and
certified_robust DNA 8 u Umax Hu. *)
Corollary certified_robust_dna :
forall (u : list DNA -> Q) (Umax : Q),
(forall w, Qabs (u w) <= Umax) ->
forall (L : nat) (b0 : Q) (f g : nat -> DNA) (n : nat),
reach DNA L f g n ->
inject_Z (Z.of_nat n) * K 8 Umax < Qabs (margin DNA 8 u L b0 f) ->
(0 < margin DNA 8 u L b0 f -> 0 < margin DNA 8 u L b0 g) /\
(margin DNA 8 u L b0 f < 0 -> margin DNA 8 u L b0 g < 0).
Proof.
intros u Umax Hu. exact (certified_robust DNA 8 u Umax Hu).
Qed.
|