File size: 6,184 Bytes
1b5c7e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
"""Put the GLM-5.2 DSA fix INTO every local mlx_lm — including LM Studio's
vendored backend — so the demolished GLM-5.2 loads WITHOUT depending on the
LM Studio app (closed source) shipping it.

The LM Studio *app* is closed, but the engine it loads MLX models with is
**open-source mlx_lm**, vendored on disk under ~/.lmstudio. Stock mlx_lm ships
only a 53-line stub for `glm_moe_dsa` (GLM-5.2's DeepSeek-Sparse-Attention MoE),
which is why GLM-5.2 fails to load ("Missing 285 parameters"). This script
overwrites that stub with our full 238-line implementation (full/shared indexer
handling), in place, with a one-time `.orig` backup.

  python dist/install_glm_dsa_patch.py            # patch every install found
  python dist/install_glm_dsa_patch.py --dry-run  # show targets, change nothing
  python dist/install_glm_dsa_patch.py --revert    # restore the .orig stubs

After patching, FULLY QUIT and reopen LM Studio so the backend reloads the engine.
"""
import argparse
import glob
import os
import shutil
import sys

# Folder containing this installer; the replacement module ships right next to it.
HERE = os.path.split(os.path.abspath(__file__))[0]
# The full implementation that gets copied over each install's stub.
PATCH = os.path.join(HERE, "glm_moe_dsa.py")
# Where that stub lives relative to a site-packages directory.
REL = os.path.join("mlx_lm", "models", os.path.basename(PATCH))


def targets():
    """Return the sorted paths of every ``glm_moe_dsa.py`` slot to patch.

    Candidates come from two places: glob patterns over the usual homes of
    an mlx_lm install (LM Studio, Ollama, conda/virtualenv environments,
    project ``.venv`` folders under the home directory, Homebrew,
    ``/usr/local`` and framework Pythons), plus the mlx_lm importable from
    the interpreter running this script.

    Only the parent ``models`` directory is required to exist, so a returned
    path may point at a file that is not there yet.
    """
    home = os.path.expanduser("~")
    patterns = (
        f"{home}/.lmstudio/**/site-packages/{REL}",          # LM Studio's vendored engines
        f"{home}/.ollama/**/site-packages/{REL}",            # Ollama (if it vendors mlx_lm)
        f"{home}/miniconda3/envs/*/lib/python*/site-packages/{REL}",
        f"{home}/anaconda3/envs/*/lib/python*/site-packages/{REL}",
        f"{home}/.conda/envs/*/lib/python*/site-packages/{REL}",
        f"{home}/.virtualenvs/*/lib/python*/site-packages/{REL}",
        f"{home}/.venv/lib/python*/site-packages/{REL}",
        f"{home}/*/.venv/lib/python*/site-packages/{REL}",   # project venvs under ~
        f"/opt/homebrew/lib/python*/site-packages/{REL}",
        f"/usr/local/lib/python*/site-packages/{REL}",
        f"/Library/Frameworks/Python.framework/Versions/*/lib/python*/site-packages/{REL}",
    )

    hits = set()
    for pattern in patterns:
        # Only the "**" patterns need the (slower) recursive walk.
        hits.update(glob.glob(pattern, recursive="**" in pattern))

    # Also cover whatever mlx_lm this very interpreter would import (pip/venv).
    try:
        import mlx_lm
        package_dir = os.path.dirname(mlx_lm.__file__)
        hits.add(os.path.join(package_dir, "models", "glm_moe_dsa.py"))
    except Exception:  # noqa: BLE001
        # Best effort: mlx_lm missing or broken here just means no extra target.
        pass

    return sorted(path for path in hits if os.path.exists(os.path.dirname(path)))


# Search/replace pair for the deepseek_v32.py MoE fix that unblocks
# gradient-based RL (GRPO) on the quantized MoE. The routed expert indices are
# non-differentiable (GatherQMM VJP), so the top-K selection has to be wrapped
# in stop_gradient. The replacement never contains the search text, which
# keeps the patch targeted and idempotent.
DSV32_OLD = "inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]"
DSV32_NEW = (
    "inds = mx.stop_gradient("
    "mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]"
    ")"
    "  # stop_gradient: MoE top-K indices non-diff (GatherQMM VJP); needed for GRPO"
)


def patch_dsv32(glm_target, dry_run=False, revert=False):
    """Apply or undo the stop_gradient fix in the sibling ``deepseek_v32.py``.

    The file is looked up in the same ``models/`` directory as ``glm_target``.
    The operation is idempotent, and the stock file is backed up to
    ``deepseek_v32.py.orig`` once, before the first write.

    Args:
        glm_target: Path of a ``glm_moe_dsa.py`` target; only its directory
            is used.
        dry_run: When true, report what would happen and write nothing. This
            also holds when combined with ``revert``.
        revert: When true, restore the file from its ``.orig`` backup instead
            of patching.

    Returns:
        ``None`` if ``deepseek_v32.py`` is absent, or if a revert was asked
        for and there is no backup. Otherwise one of ``"reverted"``,
        ``"would-revert"``, ``"already"``, ``"pattern-not-found"``,
        ``"would-patch"`` or ``"patched"``.
    """
    d = os.path.join(os.path.dirname(glm_target), "deepseek_v32.py")
    backup = d + ".orig"
    if not os.path.exists(d):
        return None

    if revert:
        if not os.path.exists(backup):
            return None
        if dry_run:
            # A dry run must never touch the file, not even to restore it.
            return "would-revert"
        shutil.copy(backup, d)
        return "reverted"

    # newline="" keeps the file's own line endings intact across the rewrite.
    with open(d, encoding="utf-8", newline="") as fh:
        src = fh.read()
    if "stop_gradient(mx.argpartition(-scores" in src:
        return "already"
    if DSV32_OLD not in src:
        return "pattern-not-found"
    if dry_run:
        return "would-patch"

    if not os.path.exists(backup):  # keep the very first (stock) version only
        shutil.copy(d, backup)
    with open(d, "w", encoding="utf-8", newline="") as fh:
        fh.write(src.replace(DSV32_OLD, DSV32_NEW))
    return "patched"


def main():
    """Command-line entry point: patch, preview, or revert every mlx_lm found.

    Flags:
        --dry-run  List what would be patched (or, together with --revert,
                   what would be restored) and write nothing.
        --revert   Restore each file from its ``.orig`` backup.

    For every target from ``targets()`` the bundled ``glm_moe_dsa.py`` is
    copied over the installed file, backing up the previous file once as
    ``.orig``. The ``deepseek_v32.py`` stop_gradient fix is then applied to
    the same directories. Exits with a message if the bundled fix is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true")
    ap.add_argument("--revert", action="store_true")
    args = ap.parse_args()
    if not os.path.exists(PATCH):
        sys.exit(f"  [stop] bundled fix not found: {PATCH}")
    # Compare as bytes: the install step is a byte copy, so byte equality is
    # the exact "already patched" test and can never hit a decode error.
    with open(PATCH, "rb") as fh:
        patch_bytes = fh.read()
    tg = targets()
    if not tg:
        print("  no mlx_lm installs found (LM Studio not installed, no venv?).")
        return
    changed = 0  # files changed, or files that would change in a dry run
    for t in tg:
        tag = "lmstudio" if ".lmstudio" in t else "mlx_lm"
        backup = t + ".orig"
        if args.revert:
            if os.path.exists(backup):
                if not args.dry_run:
                    shutil.copy(backup, t)
                print(f"  [{'would revert' if args.dry_run else 'revert'} {tag}] {t}")
                changed += 1
            else:
                print(f"  [skip {tag}: no .orig] {t}")
            continue
        # targets() only guarantees the models/ directory, not the stub itself.
        stub_exists = os.path.exists(t)
        already = False
        if stub_exists:
            with open(t, "rb") as fh:
                already = fh.read() == patch_bytes
        if already:
            print(f"  [ok {tag}: already patched] {t}")
            continue
        print(f"  [{'would patch' if args.dry_run else 'PATCH'} {tag}] {t}")
        changed += 1
        if args.dry_run:
            continue
        # Back up the stub once; with no stub there is nothing to preserve.
        if stub_exists and not os.path.exists(backup):
            shutil.copy(t, backup)
        shutil.copy(PATCH, t)
    # also apply the deepseek_v32.py MoE stop_gradient fix (GRPO/RL portability)
    for t in tg:
        dsv = os.path.join(os.path.dirname(t), "deepseek_v32.py")
        if args.dry_run and args.revert:
            # Decide a dry-run revert right here so nothing is written,
            # whatever patch_dsv32 does with that flag combination.
            restorable = os.path.exists(dsv) and os.path.exists(dsv + ".orig")
            st = "would-revert" if restorable else None
        else:
            st = patch_dsv32(t, args.dry_run, args.revert)
        if st in ("patched", "reverted", "would-patch", "would-revert"):
            print(f"  [dsv32 {st}] {dsv}")
            changed += 1
        elif st == "pattern-not-found":
            print(f"  [dsv32 SKIP: line not found — different mlx_lm version] "
                  f"{dsv}")
    print(f"\n  {'(dry-run) ' if args.dry_run else ''}{changed} file(s) "
          f"{'to change' if args.dry_run else 'changed'}.")
    if changed and not args.dry_run and not args.revert:
        print("  -> FULLY QUIT and reopen LM Studio so it reloads the patched engine.")


# Run the installer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()