philipjohnbasile commited on
Commit
eac8caf
·
verified ·
1 Parent(s): a88adac

Upload dist/install_glm_dsa_patch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dist/install_glm_dsa_patch.py +147 -0
dist/install_glm_dsa_patch.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Put the GLM-5.2 DSA fix INTO every local mlx_lm — including LM Studio's
3
+ vendored backend — so the demolished GLM-5.2 loads WITHOUT depending on the
4
+ LM Studio app (closed source) shipping it.
5
+
6
+ The LM Studio *app* is closed, but the engine it loads MLX models with is
7
+ **open-source mlx_lm**, vendored on disk under ~/.lmstudio. Stock mlx_lm ships
8
+ only a 53-line stub for `glm_moe_dsa` (GLM-5.2's DeepSeek-Sparse-Attention MoE),
9
+ which is why GLM-5.2 fails to load ("Missing 285 parameters"). This script
10
+ overwrites that stub with our full 238-line implementation (full/shared indexer
11
+ handling), in place, with a one-time `.orig` backup.
12
+
13
+ python dist/install_glm_dsa_patch.py # patch every install found
14
+ python dist/install_glm_dsa_patch.py --dry-run # show targets, change nothing
15
+ python dist/install_glm_dsa_patch.py --revert # restore the .orig stubs
16
+
17
+ After patching, FULLY QUIT and reopen LM Studio so the backend reloads the engine.
18
+ """
19
+ import argparse
20
+ import glob
21
+ import os
22
+ import shutil
23
+ import sys
24
+
25
+ HERE = os.path.dirname(os.path.abspath(__file__))
26
+ PATCH = os.path.join(HERE, "glm_moe_dsa.py")
27
+ REL = os.path.join("mlx_lm", "models", "glm_moe_dsa.py")
28
+
29
+
30
+ def targets():
31
+ found = set()
32
+ home = os.path.expanduser("~")
33
+ # scan EVERY common location an mlx_lm install may live, so we patch them ALL on this machine
34
+ # (LM Studio, Ollama, conda/venv/virtualenv, project venvs, homebrew/usr-local/framework pythons)
35
+ for g in [
36
+ home + "/.lmstudio/**/site-packages/" + REL, # LM Studio's vendored engines
37
+ home + "/.ollama/**/site-packages/" + REL, # Ollama (if it vendors mlx_lm)
38
+ home + "/miniconda3/envs/*/lib/python*/site-packages/" + REL,
39
+ home + "/anaconda3/envs/*/lib/python*/site-packages/" + REL,
40
+ home + "/.conda/envs/*/lib/python*/site-packages/" + REL,
41
+ home + "/.virtualenvs/*/lib/python*/site-packages/" + REL,
42
+ home + "/.venv/lib/python*/site-packages/" + REL,
43
+ home + "/*/.venv/lib/python*/site-packages/" + REL, # project venvs under ~
44
+ "/opt/homebrew/lib/python*/site-packages/" + REL,
45
+ "/usr/local/lib/python*/site-packages/" + REL,
46
+ "/Library/Frameworks/Python.framework/Versions/*/lib/python*/site-packages/" + REL,
47
+ ]:
48
+ found.update(glob.glob(g, recursive=("**" in g)))
49
+ # plus the mlx_lm importable in THIS interpreter (pip/venv)
50
+ try:
51
+ import mlx_lm
52
+ found.add(os.path.join(os.path.dirname(mlx_lm.__file__), "models", "glm_moe_dsa.py"))
53
+ except Exception: # noqa: BLE001
54
+ pass
55
+ return sorted(p for p in found if os.path.exists(os.path.dirname(p)))
56
+
57
+
58
+ # The deepseek_v32.py MoE fix that unblocks gradient-based RL (GRPO) on the
59
+ # quantized MoE: the routed expert indices are non-differentiable (GatherQMM VJP),
60
+ # so they must be stop_gradient'd. Targeted, idempotent patch.
61
+ DSV32_OLD = "inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]"
62
+ DSV32_NEW = ("inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])"
63
+ " # stop_gradient: MoE top-K indices non-diff (GatherQMM VJP); needed for GRPO")
64
+
65
+
66
+ def patch_dsv32(glm_target, dry_run=False, revert=False):
67
+ """Patch deepseek_v32.py (same models/ dir as glm_moe_dsa.py). Idempotent;
68
+ backs up .orig. Returns a status string or None if the file isn't there."""
69
+ d = os.path.join(os.path.dirname(glm_target), "deepseek_v32.py")
70
+ if not os.path.exists(d):
71
+ return None
72
+ if revert:
73
+ if os.path.exists(d + ".orig"):
74
+ shutil.copy(d + ".orig", d)
75
+ return "reverted"
76
+ return None
77
+ src = open(d).read()
78
+ if "stop_gradient(mx.argpartition(-scores" in src:
79
+ return "already"
80
+ if DSV32_OLD not in src:
81
+ return "pattern-not-found"
82
+ if dry_run:
83
+ return "would-patch"
84
+ if not os.path.exists(d + ".orig"):
85
+ shutil.copy(d, d + ".orig")
86
+ open(d, "w").write(src.replace(DSV32_OLD, DSV32_NEW))
87
+ return "patched"
88
+
89
+
90
+ def main():
91
+ ap = argparse.ArgumentParser()
92
+ ap.add_argument("--dry-run", action="store_true")
93
+ ap.add_argument("--revert", action="store_true")
94
+ args = ap.parse_args()
95
+ if not os.path.exists(PATCH):
96
+ sys.exit(f" [stop] bundled fix not found: {PATCH}")
97
+ patch_src = open(PATCH).read()
98
+ tg = targets()
99
+ if not tg:
100
+ print(" no mlx_lm installs found (LM Studio not installed, no venv?).")
101
+ return
102
+ changed = 0
103
+ for t in tg:
104
+ tag = "lmstudio" if ".lmstudio" in t else "mlx_lm"
105
+ if args.revert:
106
+ if os.path.exists(t + ".orig"):
107
+ shutil.copy(t + ".orig", t)
108
+ print(f" [revert {tag}] {t}")
109
+ changed += 1
110
+ else:
111
+ print(f" [skip {tag}: no .orig] {t}")
112
+ continue
113
+ already = os.path.exists(t) and open(t).read() == patch_src
114
+ if already:
115
+ print(f" [ok {tag}: already patched] {t}")
116
+ continue
117
+ print(f" [{'would patch' if args.dry_run else 'PATCH'} {tag}] {t}")
118
+ if args.dry_run:
119
+ continue
120
+ if not os.path.exists(t + ".orig"): # back up the stub once
121
+ shutil.copy(t, t + ".orig")
122
+ shutil.copy(PATCH, t)
123
+ changed += 1
124
+ # also apply the deepseek_v32.py MoE stop_gradient fix (GRPO/RL portability)
125
+ for t in tg:
126
+ st = patch_dsv32(t, args.dry_run, args.revert)
127
+ if st in ("patched", "reverted", "would-patch"):
128
+ print(f" [dsv32 {st}] {os.path.join(os.path.dirname(t), 'deepseek_v32.py')}")
129
+ if st in ("patched", "reverted"):
130
+ changed += 1
131
+ elif st == "pattern-not-found":
132
+ print(f" [dsv32 SKIP: line not found — different mlx_lm version] "
133
+ f"{os.path.join(os.path.dirname(t), 'deepseek_v32.py')}")
134
+ elif st is None and not args.revert: # deepseek_v32.py absent → glm_moe_dsa can't import it
135
+ dsv = os.path.join(os.path.dirname(t), "deepseek_v32.py")
136
+ if not os.path.exists(dsv):
137
+ print(f" [⚠ dsv32 MISSING — the model WILL fail to load] {dsv}\n"
138
+ f" glm_moe_dsa extends deepseek_v32, but your mlx_lm predates DeepSeek-V3.2 support.\n"
139
+ f" Fix: pip install -U mlx-lm (then re-run this patch). Otherwise: ImportError at load.")
140
+ print(f"\n {'(dry-run) ' if args.dry_run else ''}{changed} file(s) "
141
+ f"{'to change' if args.dry_run else 'changed'}.")
142
+ if changed and not args.dry_run and not args.revert:
143
+ print(" -> FULLY QUIT and reopen LM Studio so it reloads the patched engine.")
144
+
145
+
146
+ if __name__ == "__main__":
147
+ main()