File size: 11,694 Bytes
4d79fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
(* ============================================================================
   Certified robustness of the dna-origin-classifier, machine-checked in Rocq 9.

   The published model scores a sequence as a linear function of its k-mer
   counts: score(s) = (1/T) * sum over windows w of s of u(w) + bias, where
   u(w) is the effective weight of the k-mer w (the stored weight divided by
   its feature scale) and T is the number of windows.

   Because the score is linear, a single base substitution perturbs the counts
   of at most k windows, each by an effective weight bounded by Umax in
   magnitude, so one edit moves the unnormalized margin by at most k*(2*Umax),
   and n edits by at most n*k*(2*Umax). The main theorem turns that bound into
   a certificate: if n * k * (2*Umax) < |margin(s)|, no n-substitution edit of s
   can change the sign of the call.

   This proven radius is a lower bound on the adversarial edit distance: no
   substitution count within it can flip the call. The Python `certify` exhibits
   an actual flip by greedy search, an upper bound on the same distance, so the
   two bracket it. For the published host head the effective weights are
   heavy-tailed (max |u| about 95) and the achievable margin is small (|score|
   at most about 9), so the proven radius is a few edits at most.

   Everything is proved over the rationals with no axioms and no admits.
   ========================================================================== *)

From Stdlib Require Import List Arith Lia QArith Qabs.
From Stdlib Require Import Lqa.
Import ListNotations.
Open Scope Q_scope.

Section Detector.

(* An abstract alphabet. The published model instantiates it with the four
   DNA bases; nothing below depends on the cardinality. *)
Variable Base : Type.

(* k-mer width (8 in the published model). No positivity assumption is made:
   every statement in this section is stated for an arbitrary k. *)
Variable k : nat.

(* Effective per-k-mer weight: the stored weight divided by its feature scale.
   The function is total on lists of any length, although the definitions
   below only ever apply it to windows built from seq 0 k. *)
Variable u : list Base -> Q.

(* A uniform bound on the magnitude of every effective weight. The published
   host head has a concrete, finite Umax computed from its 65,536 weights.
   Hu is required for every list w, not only for lists of length k; since
   Qabs is nonnegative it also forces 0 <= Umax, which single_subst uses. *)
Variable Umax : Q.
Hypothesis Hu : forall w, Qabs (u w) <= Umax.

(* A sequence is a total map from index to symbol. When k <= L the windows
   below only inspect indices below the length L; when k > L the single
   window at index 0 also reads the indices L .. k-1. *)

(* [win f j] is the window of [f] that starts at index [j]: the symbols
   f j, f (j+1), ..., f (j+k-1), listed in that order. *)
Definition win (f : nat -> Base) (j : nat) : list Base :=
  List.map (fun offset : nat => f (j + offset)%nat) (List.seq 0 k).

(* Number of windows that [raw] sums for a length-L sequence: S (L - k) with
   truncated subtraction. For k <= L this is L - k + 1; for k > L it is 1,
   so the count is never zero. *)
Definition nwin : nat -> nat :=
  fun L => S (L - k).

(* Unnormalized score of [f]: the effective weights u (win f start), added
   up over the window starts 0, 1, ..., nwin L - 1. *)
Definition raw (L : nat) (f : nat -> Base) : Q :=
  List.fold_right Qplus 0%Q
    (List.map (fun start : nat => u (win f start)) (List.seq 0 (nwin L))).

(* [upd f i b] agrees with [f] everywhere except at index [i], where it
   returns [b]. Nothing requires [b] to differ from [f i]. *)
Definition upd (f : nat -> Base) (i : nat) (b : Base) (j : nat) : Base :=
  if Nat.eqb j i then b else f j.

(* Boolean test that the window starting at [j] covers position [i],
   that is, j <= i and i < j + k. *)
Definition contains (i j : nat) : bool :=
  (Nat.leb j i && Nat.ltb i (j + k))%bool.

(* Per-substitution budget: k windows, each of which can change by at most
   Umax + Umax. single_subst proves that one call to upd moves raw by at
   most this amount. *)
Definition K : Q :=
  (inject_Z (Z.of_nat k) * (Umax + Umax))%Q.

(* ---------- generic rational / list helpers ---------- *)

(* Triangle inequality for a difference: |a - b| <= |a| + |b|. *)
Lemma Qabs_minus_le : forall a b : Q, Qabs (a - b) <= Qabs a + Qabs b.
Proof.
  intros x y. unfold Qminus.
  (* Go through |x| + |- y|, then drop the sign inside the second term. *)
  apply Qle_trans with (Qabs x + Qabs (- y)).
  - apply Qabs_triangle.
  - rewrite Qabs_opp. apply Qle_refl.
Qed.

(* The absolute value of zero is at most zero; used to close the cases in
   which a difference has been shown to vanish. *)
Lemma Qabs_0_le : Qabs 0 <= 0.
Proof.
  rewrite (Qabs_pos 0 (Qle_refl 0)).
  apply Qle_refl.
Qed.

(* Triangle inequality for a finite sum: the absolute value of a sum is at
   most the sum of the absolute values. *)
Lemma Qabs_sum_le : forall l : list Q,
  Qabs (fold_right Qplus 0 l) <= fold_right Qplus 0 (map Qabs l).
Proof.
  intros l. induction l as [|q rest IHrest].
  - (* Empty sum: both sides are zero. *)
    simpl. apply Qle_refl.
  - (* Peel off the head with the binary triangle inequality, then use the
       induction hypothesis on the tail. *)
    simpl. eapply Qle_trans.
    + apply Qabs_triangle.
    + apply Qplus_le_compat.
      * apply Qle_refl.
      * exact IHrest.
Qed.

(* Monotonicity of a finite sum: if a j <= b j for every index j occurring
   in l, then the sum of a over l is at most the sum of b over l. *)
Lemma fold_right_Qplus_le : forall (l : list nat) (a b : nat -> Q),
  (forall j, In j l -> a j <= b j) ->
  fold_right Qplus 0 (map a l) <= fold_right Qplus 0 (map b l).
Proof.
  intros l. induction l as [|hd tl IHtl]; intros a b Hle; simpl.
  - apply Qle_refl.
  - apply Qplus_le_compat.
    + (* Head term: hd is the first element of hd :: tl. *)
      apply Hle. apply in_eq.
    + (* Tail: membership in tl implies membership in hd :: tl. *)
      apply IHtl. intros j Hj. apply Hle. apply in_cons. exact Hj.
Qed.

(* The difference of two sums over the same index list is the sum of the
   termwise differences. *)
Lemma diff_of_folds : forall (l : list nat) (A B : nat -> Q),
  fold_right Qplus 0 (map A l) - fold_right Qplus 0 (map B l)
  == fold_right Qplus 0 (map (fun j => A j - B j) l).
Proof.
  intros l. induction l as [|hd tl IHtl].
  - intros A B. simpl. ring.
  - (* Fold the tail back into a difference of sums, then rearrange. *)
    intros A B. simpl. rewrite <- IHtl. ring.
Qed.

(* Successor on nat becomes + 1 after the embedding nat -> Z -> Q. *)
Lemma inject_Z_Snat : forall n : nat,
  inject_Z (Z.of_nat (S n)) == inject_Z (Z.of_nat n) + 1.
Proof.
  intros m.
  (* Move the successor through Z.of_nat; lia knows this identity. *)
  replace (Z.of_nat (S m)) with (Z.of_nat m + 1)%Z by lia.
  (* Split the embedding over the integer sum. *)
  rewrite inject_Z_plus.
  (* inject_Z 1 and the rational literal 1 coincide by computation. *)
  replace (inject_Z 1) with 1 by reflexivity.
  reflexivity.
Qed.

(* Summing the constant c over exactly those elements of l that satisfy P
   (and 0 over the others) gives c times the number of elements kept by
   filter P l. *)
Lemma indicator_sum : forall (l : list nat) (P : nat -> bool) (c : Q),
  fold_right Qplus 0 (map (fun j => if P j then c else 0) l)
  == c * inject_Z (Z.of_nat (length (filter P l))).
Proof.
  (* Induction on l, with P and c kept general in the hypothesis. *)
  induction l as [|x l IH]; intros P c; simpl.
  - ring.
  - (* Split on whether the head is counted. *)
    destruct (P x) eqn:Hx; simpl.
    (* Head counted: the filtered length grows by one, which inject_Z_Snat
       turns into an extra + 1 on the rational side. *)
    + rewrite IH. rewrite (inject_Z_Snat (length (filter P l))). ring.
    (* Head not counted: it contributes 0 and the length is unchanged. *)
    + rewrite IH. ring.
Qed.

(* ---------- structural facts about windows and substitution ---------- *)

(* A window that does not contain the edited position is unchanged.
   [contains i j = false] says that i lies outside j <= i < j + k, so none
   of the indices j + p with p < k read by window j equals i, and [upd]
   returns the original symbol at each of them. *)
Lemma win_eq_off : forall f i b j,
  contains i j = false -> win (upd f i b) j = win f j.
Proof.
  (* Compare the two windows pointwise over the offsets p in seq 0 k. *)
  intros f i b j Hc. unfold win. apply map_ext_in.
  (* Membership in seq 0 k yields the bound p < k on the offset. *)
  intros p Hp. apply in_seq in Hp. destruct Hp as [_ Hp]. simpl in Hp.
  (* If j + p differs from i, upd leaves the symbol alone and we are done. *)
  unfold upd. destruct (Nat.eqb (j + p) i) eqn:E; [|reflexivity].
  (* Otherwise j + p = i with p < k, so window j does contain i, which
     contradicts the hypothesis Hc. *)
  apply Nat.eqb_eq in E. exfalso.
  assert (Hct : contains i j = true).
  { unfold contains. apply andb_true_intro. split.
    - apply Nat.leb_le. lia.
    - apply Nat.ltb_lt. lia. }
  rewrite Hct in Hc. discriminate.
Qed.

(* At most k windows contain a given position. The bound does not depend on
   L: any window start x with x <= i < x + k lies among the k consecutive
   naturals beginning at i + 1 - k (truncated subtraction, so the range
   starts at 0 when i + 1 <= k). *)
Lemma count_le_k : forall L i,
  (length (filter (contains i) (seq 0 (nwin L))) <= k)%nat.
Proof.
  intros L i.
  (* Compare against the length of that range of k candidates. *)
  apply Nat.le_trans with (length (seq (i + 1 - k) k)).
  - (* A duplicate-free list included in another list is no longer than it. *)
    apply NoDup_incl_length.
    (* Filtering a seq keeps it duplicate-free. *)
    + apply NoDup_filter. apply seq_NoDup.
    (* Every start index that passes the filter lies in the candidate range;
       the two boolean tests are turned into arithmetic facts for lia. *)
    + intros x Hx. apply filter_In in Hx. destruct Hx as [_ Hc].
      unfold contains in Hc. apply andb_true_iff in Hc. destruct Hc as [Hle Hlt].
      apply Nat.leb_le in Hle. apply Nat.ltb_lt in Hlt.
      apply in_seq. split; lia.
  - (* The candidate range has exactly k elements. *)
    rewrite length_seq. apply Nat.le_refl.
Qed.

(* ---------- one substitution moves raw by at most K ---------- *)

(* One call to upd, at any position i and with any symbol b, changes raw by
   at most K in absolute value. Outline: write the difference of the two
   scores as a sum of per-window differences, bound each window by
   Umax + Umax if it contains i and by 0 otherwise, then count the windows
   that contain i. *)
Lemma single_subst : forall L f i b,
  Qabs (raw L f - raw L (upd f i b)) <= K.
Proof.
  intros L f i b. unfold raw.
  (* Difference of sums = sum of per-window differences. *)
  rewrite (diff_of_folds (seq 0 (nwin L))
            (fun j => u (win f j)) (fun j => u (win (upd f i b) j))).
  (* Push the absolute value inside the sum. *)
  eapply Qle_trans.
  { apply Qabs_sum_le. }
  (* Fuse the two nested maps so the summand is Qabs of a difference. *)
  rewrite map_map.
  (* Bound each summand by an indicator-weighted constant. *)
  eapply Qle_trans.
  { apply (fold_right_Qplus_le (seq 0 (nwin L))
            (fun j => Qabs (u (win f j) - u (win (upd f i b) j)))
            (fun j => if contains i j then Umax + Umax else 0)).
    intros j Hj. cbv beta. destruct (contains i j) eqn:Hcj.
    (* Window touches i: each of the two weights is bounded by Umax. *)
    - eapply Qle_trans; [apply Qabs_minus_le|].
      apply Qplus_le_compat; apply Hu.
    (* Window misses i: it is unchanged, so the difference is zero. *)
    - assert (Hw : win (upd f i b) j = win f j) by (apply win_eq_off; exact Hcj).
      rewrite Hw.
      setoid_replace (u (win f j) - u (win f j)) with 0 by ring.
      apply Qabs_0_le. }
  (* The indicator sum equals (Umax + Umax) times the number of windows
     that contain i. *)
  rewrite (indicator_sum (seq 0 (nwin L)) (contains i) (Umax + Umax)).
  unfold K.
  (* Put the count on the left so both sides share the right-hand factor
     Umax + Umax, then compare the left-hand factors. *)
  rewrite (Qmult_comm (Umax + Umax)).
  apply Qmult_le_compat_r.
  (* The count is at most k, transported from nat through Z into Q. *)
  - rewrite <- Zle_Qle. pose proof (count_le_k L i) as Hck. lia.
  (* The common factor is nonnegative: 0 <= Qabs (u nil) <= Umax by Hu. *)
  - assert (HU0 : 0 <= Umax).
    { eapply Qle_trans; [apply (Qabs_nonneg (u nil))| apply Hu]. }
    lra.
Qed.

(* ---------- n substitutions move raw by at most n * K ---------- *)

(* [reach L f g n]: g is obtained from f by a chain of n calls to [upd].
   Neither the position nor the new symbol is constrained, so a step may
   rewrite a symbol to itself or touch an index that no window reads; n is
   therefore an upper bound on the number of symbols that actually differ.
   The parameter L is not inspected by either constructor. *)
Inductive reach (L : nat) (f : nat -> Base) : (nat -> Base) -> nat -> Prop :=
  | reach0 : reach L f f 0
  | reachS (g : nat -> Base) (i : nat) (b : Base) (n : nat) :
      reach L f g n -> reach L f (upd g i b) (S n).

(* A chain of n substitution steps moves raw by at most n times K. Proved by
   induction on the derivation of reach, adding one single_subst bound per
   step. *)
Lemma reach_bound : forall L f g n,
  reach L f g n ->
  Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K.
Proof.
  intros L f g n H. induction H.
  (* No step: the difference is zero and so is the bound. *)
  - setoid_replace (raw L f - raw L f) with 0 by ring.
    setoid_replace (inject_Z (Z.of_nat 0) * K) with 0
      by (replace (inject_Z (Z.of_nat 0)) with 0 by reflexivity; ring).
    apply Qabs_0_le.
  (* One more step: insert the intermediate sequence g and apply the
     triangle inequality to the two legs f -> g and g -> upd g i b. *)
  - eapply Qle_trans.
    + setoid_replace (raw L f - raw L (upd g i b))
        with ((raw L f - raw L g) + (raw L g - raw L (upd g i b))) by ring.
      apply Qabs_triangle.
    + eapply Qle_trans.
      (* First leg by the induction hypothesis, second by single_subst. *)
      { apply Qplus_le_compat; [exact IHreach| apply single_subst]. }
      (* The bound for S n unfolds to the bound for n plus one more K. *)
      setoid_replace (inject_Z (Z.of_nat (S n)) * K)
        with (inject_Z (Z.of_nat n) * K + K) by (rewrite (inject_Z_Snat n); ring).
      apply Qle_refl.
Qed.

(* ---------- the certificate ---------- *)

(* Unnormalized margin: the raw score plus the bias b0 scaled by the window
   count nwin L. Positive favors the class, negative opposes it. Because
   nwin L is a successor and hence positive, dividing by it gives the
   normalized score without changing the sign, and an edit made with upd
   keeps L, and therefore the window count, fixed. *)
Definition margin (L : nat) (b0 : Q) (f : nat -> Base) : Q :=
  (raw L f + b0 * inject_Z (Z.of_nat (nwin L)))%Q.

(* The bias term is the same for both sequences, so it cancels and the
   difference of margins equals the difference of raw scores. *)
Lemma margin_diff : forall L b0 f g,
  margin L b0 f - margin L b0 g == raw L f - raw L g.
Proof.
  intros L b0 f g.
  unfold margin.
  ring.
Qed.

(* Subtracting a perturbation d that is strictly smaller in magnitude than x
   cannot change the sign of x. Both directions are returned as a pair:
   positive stays positive and negative stays negative. *)
Lemma sign_preserved : forall x d : Q,
  Qabs d < Qabs x ->
  (0 < x -> 0 < x - d) /\ (x < 0 -> x - d < 0).
Proof.
  intros x d Hlt. split; intro Hx.
  (* x > 0: Qabs x is x itself and d <= Qabs d < x, so x - d > 0. *)
  - assert (Hax : Qabs x == x) by (apply Qabs_pos; lra).
    assert (Hd : d <= Qabs d) by apply Qle_Qabs.
    lra.
  (* x < 0: Qabs x is - x. The lower half of the two-sided bound
     - Qabs d <= d <= Qabs d gives d > x, so x - d < 0. *)
  - assert (Hax : Qabs x == - x) by (apply Qabs_neg; lra).
    assert (Hcc := proj1 (Qabs_Qle_condition d (Qabs d)) (Qle_refl (Qabs d))).
    destruct Hcc as [Hc _].
    lra.
Qed.

(* Main theorem. If g is reachable from f by n substitutions and the
   certificate n * K < |margin(f)| holds, the call's sign cannot change.
   K = k * (2 * Umax), so the certified radius is the largest n with
   n * k * 2 * Umax < |margin|, which is what `certify` returns.
   The conclusion is a pair of implications, one per sign of margin(f);
   a margin of exactly zero can never satisfy the strict certificate. *)
Theorem certified_robust : forall L b0 f g n,
  reach L f g n ->
  inject_Z (Z.of_nat n) * K < Qabs (margin L b0 f) ->
  (0 < margin L b0 f -> 0 < margin L b0 g) /\
  (margin L b0 f < 0 -> margin L b0 g < 0).
Proof.
  intros L b0 f g n Hreach Hcert.
  (* The raw scores of f and g differ by at most n times K. *)
  assert (Hb : Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K)
    by (apply reach_bound; exact Hreach).
  set (x := margin L b0 f). set (y := margin L b0 g).
  (* The margins differ by the same amount as the raw scores, so the
     perturbation x - y is strictly smaller in magnitude than x. *)
  assert (Hd : Qabs (x - y) < Qabs x).
  { setoid_replace (x - y) with (raw L f - raw L g)
      by (unfold x, y; apply margin_diff).
    eapply Qle_lt_trans; [exact Hb| exact Hcert]. }
  (* Apply sign preservation with d := x - y, then rewrite y as x - d. *)
  destruct (sign_preserved x (x - y) Hd) as [Hpos Hneg].
  split; intro Hsgn.
  - assert (Hy : 0 < x - (x - y)) by (apply Hpos; exact Hsgn).
    setoid_replace y with (x - (x - y)) by ring. exact Hy.
  - assert (Hy : x - (x - y) < 0) by (apply Hneg; exact Hsgn).
    setoid_replace y with (x - (x - y)) by ring. exact Hy.
Qed.

End Detector.

(* Instantiation to the published model's setting: the four DNA bases and k = 8.
   The host head's effective weights have a concrete finite bound Umax_host, so
   the guarantee holds with per-substitution margin constant K = 8 * (2 * Umax_host).
   The radius `certify` reports is the largest n with n * K < |margin|. *)
(* The four-symbol alphabet used to instantiate Base. *)
Inductive DNA : Set :=
  | dA : DNA
  | dC : DNA
  | dG : DNA
  | dT : DNA.

(* certified_robust specialized to Base := DNA and k := 8, for any weight
   function u and any bound Umax that satisfy the uniform-bound hypothesis. *)
Corollary certified_robust_dna :
  forall (u : list DNA -> Q) (Umax : Q),
    (forall w, Qabs (u w) <= Umax) ->
    forall (L : nat) (b0 : Q) (f g : nat -> DNA) (n : nat),
      reach DNA L f g n ->
      inject_Z (Z.of_nat n) * K 8 Umax < Qabs (margin DNA 8 u L b0 f) ->
      (0 < margin DNA 8 u L b0 f -> 0 < margin DNA 8 u L b0 g) /\
      (margin DNA 8 u L b0 f < 0 -> margin DNA 8 u L b0 g < 0).
Proof.
  intros weights bound Hbound len bias f g edits Hchain Hcert.
  (* Supply the section parameters first, then the per-instance data. *)
  exact (certified_robust DNA 8 weights bound Hbound
           len bias f g edits Hchain Hcert).
Qed.