import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Chat-format tokens that also terminate generation when the vocabulary has them.
_EXTRA_STOP_TOKENS = ("<|im_end|>", "<|endoftext|>")

# ASCII and CJK full-width sentence terminators. The CJK ones are written as
# escapes (U+3002 "。", U+FF01 "！", U+FF1F "？") so re-encoding the file
# cannot silently turn them into duplicates of the ASCII characters.
_SENTENCE_ENDERS = (".", "!", "?", "\u3002", "\uff01", "\uff1f")


def build_generation_eos_token_id(tokenizer):
    """Collect the token ids that should stop generation.

    Combines ``tokenizer.eos_token_id`` with the ids of ``<|im_end|>`` and
    ``<|endoftext|>`` when the tokenizer's vocabulary defines them. Negative
    or non-int ids are ignored.

    Returns:
        ``None`` when no valid id was found, a single ``int`` when exactly one
        was found, otherwise a sorted ``list`` of ints.
    """
    stop_ids = set()
    eos_id = tokenizer.eos_token_id
    if isinstance(eos_id, int) and eos_id >= 0:
        stop_ids.add(int(eos_id))
    vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else {}
    for special_tok in _EXTRA_STOP_TOKENS:
        tok_id = vocab.get(special_tok)
        if isinstance(tok_id, int) and tok_id >= 0:
            stop_ids.add(int(tok_id))
    if not stop_ids:
        return None
    stop_list = sorted(stop_ids)
    return stop_list[0] if len(stop_list) == 1 else stop_list


def normalize_eos_id_set(eos_token_id):
    """Convert an ``int`` / list / tuple / ``None`` eos spec into a set of ints.

    Non-int members of a list/tuple, and any other input type, are dropped.
    """
    if eos_token_id is None:
        return set()
    if isinstance(eos_token_id, int):
        return {int(eos_token_id)}
    if isinstance(eos_token_id, (list, tuple)):
        return {int(x) for x in eos_token_id if isinstance(x, int)}
    return set()


def is_likely_harmful_prompt(prompt: str) -> bool:
    """Return True if the lowercased prompt contains any listed keyword.

    This is a plain substring match (e.g. "hack" also matches "hackathon").
    ``generate_with_policies`` uses it only to choose which minimum
    generation length applies.
    """
    p = prompt.lower()
    keywords = [
        "fake news", "misinformation", "phishing", "hack", "exploit", "malware",
        "ransomware", "steal", "identity theft", "bomb", "weapon", "suicide",
        "kill", "poison", "drug", "tax evasion", "fraud", "scam", "sabotage",
        "terror", "ddos", "bypass law",
    ]
    return any(k in p for k in keywords)


def trim_to_last_complete_sentence(text: str, min_keep_ratio: float = 0.35) -> str:
    """Cut ``text`` after its last sentence terminator, if that keeps enough.

    Args:
        text: Text to trim; surrounding whitespace is stripped first.
        min_keep_ratio: The cut is applied only when the kept prefix is at
            least this fraction of the stripped text's length. Otherwise the
            stripped text is returned unchanged, so a single late terminator
            cannot discard most of the output.

    Returns:
        The (possibly) trimmed, stripped text.
    """
    s = text.strip()
    if not s:
        return s
    cut = max(s.rfind(x) for x in _SENTENCE_ENDERS)
    if cut < 0:
        return s
    if cut + 1 >= int(len(s) * min_keep_ratio):
        return s[: cut + 1].strip()
    return s


def _decode_kwargs(
    tokenizer,
    *,
    max_new_tokens,
    repetition_penalty,
    no_repeat_ngram_size,
    eos_token_id,
    min_new_tokens=0,
):
    """Build greedy ``model.generate`` kwargs; optional knobs only when active."""
    kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if min_new_tokens > 0:
        kwargs["min_new_tokens"] = int(min_new_tokens)
    if repetition_penalty > 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    if no_repeat_ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if eos_token_id is not None:
        kwargs["eos_token_id"] = eos_token_id
    return kwargs


def _stopped_on_eos(completion, eos_id_set):
    """Return True when the last generated token is one of ``eos_id_set``."""
    return bool(eos_id_set) and int(completion.shape[0]) > 0 and int(completion[-1].item()) in eos_id_set


def _auto_continue(model, gen, prompt_len, eos_id_set, *, passes, step_tokens, continue_kwargs):
    """Extend a generation that hit its token cap, in greedy chunks.

    Each pass feeds the whole sequence so far (prompt + completion) back into
    ``model.generate``. The attention mask is all ones, which assumes a single
    unpadded sequence, as produced by ``generate_with_policies``.

    Returns:
        ``(sequence, still_capped)``. ``still_capped`` is False once a pass
        ends on a stop id or produces fewer than ``step_tokens`` new tokens.
    """
    hit_cap = True
    for _ in range(passes):
        next_gen = model.generate(
            input_ids=gen,
            attention_mask=torch.ones_like(gen),
            **continue_kwargs,
        )
        new_tok_count = int(next_gen.shape[1] - gen.shape[1])
        gen = next_gen
        if new_tok_count <= 0:
            # Nothing was produced; keep treating the output as capped.
            break
        if _stopped_on_eos(gen[0, prompt_len:], eos_id_set) or new_tok_count < step_tokens:
            # Generation ended on its own rather than at the chunk limit.
            hit_cap = False
            break
    return gen, hit_cap


def generate_with_policies(
    model,
    tokenizer,
    prompt: str,
    *,
    max_new_tokens: int,
    min_new_tokens_harmful: int,
    min_new_tokens_harmless: int,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    auto_continue_on_cap_hit: bool,
    auto_continue_tokens: int,
    auto_continue_max_passes: int,
    trim_incomplete_on_max_hit: bool,
):
    """Greedily answer a single-turn chat ``prompt`` and return the decoded reply.

    Steps:
    1. Render the prompt with the tokenizer's chat template.
    2. Generate up to ``max_new_tokens``, with a minimum length chosen by
       ``is_likely_harmful_prompt``.
    3. If the output hit the cap without ending on a stop id, optionally run
       up to ``auto_continue_max_passes`` extra chunks of
       ``auto_continue_tokens`` tokens each.
    4. If the reply is still capped, optionally trim it to its last complete
       sentence.

    Returns:
        The decoded completion with special tokens removed, stripped of
        surrounding whitespace.
    """
    rendered = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=False,
    )
    # The rendered template already contains the special tokens it needs.
    # Letting the tokenizer add its own (e.g. BOS) would duplicate them.
    inputs = tokenizer(rendered, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    eos_token_id = build_generation_eos_token_id(tokenizer)
    eos_id_set = normalize_eos_id_set(eos_token_id)
    min_for_prompt = min_new_tokens_harmful if is_likely_harmful_prompt(prompt) else min_new_tokens_harmless

    shared = {
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "eos_token_id": eos_token_id,
    }

    with torch.inference_mode():
        gen = model.generate(
            **inputs,
            **_decode_kwargs(tokenizer, max_new_tokens=max_new_tokens, min_new_tokens=min_for_prompt, **shared),
        )
        completion = gen[0, prompt_len:]
        # "Capped" = used the full budget without ending on a stop id.
        hit_cap = int(completion.shape[0]) >= int(max_new_tokens) and not _stopped_on_eos(completion, eos_id_set)

        if auto_continue_on_cap_hit and hit_cap and auto_continue_tokens > 0 and auto_continue_max_passes > 0:
            gen, hit_cap = _auto_continue(
                model,
                gen,
                prompt_len,
                eos_id_set,
                passes=int(auto_continue_max_passes),
                step_tokens=int(auto_continue_tokens),
                # Continuation passes intentionally apply no minimum length.
                continue_kwargs=_decode_kwargs(tokenizer, max_new_tokens=auto_continue_tokens, **shared),
            )

    text = tokenizer.decode(gen[0, prompt_len:], skip_special_tokens=True).strip()
    if trim_incomplete_on_max_hit and hit_cap:
        text = trim_to_last_complete_sentence(text)
    return text


def main():
    """CLI entry point: load the model, then answer ``--prompt`` or run a REPL.

    Decoding policy is read from environment variables; their defaults are
    the string literals below.
    """
    parser = argparse.ArgumentParser(description="Run checkpoint-style abliterated model with serving defaults.")
    parser.add_argument("--repo-id", type=str, default=".", help="HF repo id or local folder")
    parser.add_argument("--device-map", type=str, default=("cuda" if torch.cuda.is_available() else "cpu"))
    parser.add_argument("--load-in-4bit", action="store_true")
    parser.add_argument("--prompt", type=str, default=None)
    args = parser.parse_args()

    max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "384"))
    min_new_tokens_harmful = int(os.getenv("MIN_NEW_TOKENS_HARMFUL", "48"))
    min_new_tokens_harmless = int(os.getenv("MIN_NEW_TOKENS_HARMLESS", "0"))
    repetition_penalty = float(os.getenv("REPETITION_PENALTY", "1.10"))
    no_repeat_ngram_size = int(os.getenv("NO_REPEAT_NGRAM_SIZE", "4"))
    auto_continue_on_cap_hit = os.getenv("AUTO_CONTINUE_ON_CAP_HIT", "1") == "1"
    auto_continue_tokens = int(os.getenv("AUTO_CONTINUE_TOKENS", "96"))
    auto_continue_max_passes = int(os.getenv("AUTO_CONTINUE_MAX_PASSES", "2"))
    trim_incomplete_on_max_hit = os.getenv("TRIM_INCOMPLETE_ON_MAX_HIT", "1") == "1"
    # LOAD_IN_4BIT defaults to "1", so 4-bit loading is on unless that env var
    # is set to something else; the --load-in-4bit flag can only force it on.
    load_in_4bit = args.load_in_4bit or (os.getenv("LOAD_IN_4BIT", "1") == "1")

    on_cpu = args.device_map == "cpu"
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    half_dtype = torch.bfloat16 if bf16_ok else torch.float16
    dtype = torch.float32 if on_cpu else half_dtype

    # Quantization is skipped entirely when running on CPU.
    quantization_config = None
    if load_in_4bit and not on_cpu:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=half_dtype,
        )

    model = AutoModelForCausalLM.from_pretrained(
        args.repo_id,
        trust_remote_code=True,
        dtype=dtype,
        device_map=args.device_map,
        quantization_config=quantization_config,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(args.repo_id, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    print(f"model: {args.repo_id}")
    # getattr for both fields: configs without ablation metadata print "n/a"
    # instead of raising AttributeError.
    print(
        f"type: {getattr(model.config, 'ablation_scope', 'n/a')} "
        f"(strength={getattr(model.config, 'ablation_strength', 'n/a')})"
    )
    print(
        f"decode: max={max_new_tokens}, min_harmful={min_new_tokens_harmful}, min_harmless={min_new_tokens_harmless}, "
        f"rep_pen={repetition_penalty}, no_repeat_ngram={no_repeat_ngram_size}, "
        f"auto_continue={auto_continue_on_cap_hit}/{auto_continue_tokens}x{auto_continue_max_passes}"
    )
    print("commands: /exit (quit)")

    policy = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens_harmful": min_new_tokens_harmful,
        "min_new_tokens_harmless": min_new_tokens_harmless,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "auto_continue_on_cap_hit": auto_continue_on_cap_hit,
        "auto_continue_tokens": auto_continue_tokens,
        "auto_continue_max_passes": auto_continue_max_passes,
        "trim_incomplete_on_max_hit": trim_incomplete_on_max_hit,
    }

    if args.prompt:
        out = generate_with_policies(model, tokenizer, args.prompt, **policy)
        print(f"model: {out}")
        return

    while True:
        try:
            prompt = input("type: ").strip()
        except EOFError:
            break
        if not prompt:
            continue
        if prompt.lower() in {"/exit", "/quit"}:
            break
        out = generate_with_policies(model, tokenizer, prompt, **policy)
        print(f"model: {out}")


if __name__ == "__main__":
    main()