"""Make a local copy of syvai/cohere-transcribe-diarize loadable by vLLM 0.19.0. Three fixes: 1. tokenizer_config.json: drop legacy `extra_special_tokens` list (transformers 4.57+ expects a dict; the actual tokens are already in tokenizer.json). 2. config.json: set head.num_classes + transf_decoder.config_dict.vocab_size to 16684 to match the resized embedding/LM head. 3. model.safetensors: strip the `model.` prefix from every key and drop the `*.num_batches_tracked` tensors that vLLM's CohereAsr model doesn't register. Idempotent — re-running is safe. Run once after downloading the model: python fix_for_vllm.py /path/to/syvai-cohere-transcribe-diarize """ import json, os, sys import safetensors.torch as st if len(sys.argv) != 2: sys.exit("usage: python fix_for_vllm.py ") MODEL_DIR = sys.argv[1].rstrip("/") assert os.path.isdir(MODEL_DIR), f"not a directory: {MODEL_DIR}" # 1. tokenizer_config.json tcfg_path = f"{MODEL_DIR}/tokenizer_config.json" tcfg = json.load(open(tcfg_path)) if "extra_special_tokens" in tcfg: tcfg.pop("extra_special_tokens") json.dump(tcfg, open(tcfg_path, "w"), indent=2, ensure_ascii=False) print(" ✓ tokenizer_config.json: removed extra_special_tokens") else: print(" · tokenizer_config.json: already clean") # 2. config.json cfg_path = f"{MODEL_DIR}/config.json" cfg = json.load(open(cfg_path)) changed = False if cfg.get("head", {}).get("num_classes") != 16684: cfg["head"]["num_classes"] = 16684 changed = True if cfg.get("transf_decoder", {}).get("config_dict", {}).get("vocab_size") != 16684: cfg["transf_decoder"]["config_dict"]["vocab_size"] = 16684 changed = True if changed: json.dump(cfg, open(cfg_path, "w"), indent=2, ensure_ascii=False) print(" ✓ config.json: head.num_classes + decoder vocab_size = 16684") else: print(" · config.json: already correct") # 3. model.safetensors — strip `model.` prefix + drop num_batches_tracked sf_path = f"{MODEL_DIR}/model.safetensors" sd = st.load_file(sf_path) n_before = len(sd) needs_rewrite = any(k.startswith("model.") for k in sd) or any( k.endswith("num_batches_tracked") for k in sd ) if needs_rewrite: sd = {(k[6:] if k.startswith("model.") else k): v for k, v in sd.items() if not k.endswith("num_batches_tracked")} st.save_file(sd, sf_path) print(f" ✓ model.safetensors: rewrote ({n_before} → {len(sd)} tensors)") else: print(" · model.safetensors: already clean") print("\nDone. Now apply the vLLM-side patches with `python vllm_diarized_patch.py`.")