import contextlib

import torch
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration


class AbliteratedQwen3_5ForCausalLM(Qwen3_5ForConditionalGeneration):
    """Qwen3.5 model that projects a fixed direction out of selected decoder-layer inputs.

    The direction vector, the target layer indices and the ablation strength come from
    the model config (``refusal_direction``, ``ablation_layers``, ``ablation_strength``).
    On every ``forward`` call, forward pre-hooks are registered on the selected decoder
    layers and removed again when the call returns.
    """

    def __init__(self, config):
        # Copy text_config fields onto the top-level config when they are absent there,
        # so code that reads them directly from `config` finds them.
        text_cfg = getattr(config, "text_config", None)
        if text_cfg is not None:
            for key, value in text_cfg.to_dict().items():
                if not hasattr(config, key):
                    setattr(config, key, value)
        super().__init__(config)
        direction_t = self._direction_from_config(config)
        # Non-persistent: the vector is carried by the config, not by the state dict.
        self.register_buffer("_refusal_direction", direction_t, persistent=False)
        self.ablation_layers = self._parse_layers(getattr(config, "ablation_layers", None))
        self.ablation_strength = float(self._config_value(config, "ablation_strength", 1.0))

    @staticmethod
    def _config_value(config, name, default):
        """Return ``config.<name>``, or *default* when it is absent or ``None``."""
        value = getattr(config, name, None)
        return default if value is None else value

    @staticmethod
    def _parse_layers(value):
        """Convert an ``ablation_layers`` value to a list of ints (``None`` -> ``[]``)."""
        return [] if value is None else [int(x) for x in value]

    @staticmethod
    def _direction_from_config(config):
        """Build a float32 tensor from ``config.refusal_direction``.

        Raises:
            ValueError: if the config has no direction or it is not a 1-D vector.
        """
        direction = getattr(config, "refusal_direction", None)
        if direction is None:
            raise ValueError(
                "Missing `refusal_direction` in config. "
                "This checkpoint expects an embedded refusal direction."
            )
        direction_t = torch.tensor(direction, dtype=torch.float32)
        if direction_t.ndim != 1:
            # The projection contracts the hidden dim against a single vector.
            raise ValueError(
                "`refusal_direction` must be a 1-D vector, "
                f"got shape {tuple(direction_t.shape)}."
            )
        return direction_t

    def _refresh_refusal_direction_buffer(self):
        """Ensure the ``_refusal_direction`` buffer holds usable values and return it.

        The buffer is rebuilt from ``config.refusal_direction`` when it is missing,
        still on the meta device, shaped differently from the config vector, or has a
        zero / non-finite norm. Otherwise the existing buffer is kept unchanged. As a
        consequence, a same-shaped edit to ``config.refusal_direction`` is not picked up
        once the buffer is valid.

        Raises:
            ValueError: if the config has no usable ``refusal_direction``.
        """
        direction_t = self._direction_from_config(self.config)
        existing = self._buffers.get("_refusal_direction")
        target_device = None
        if existing is None or existing.device.type == "meta":
            # Meta tensors carry no data (and `.item()` on them raises), so always rebuild.
            needs_reset = True
        else:
            target_device = existing.device
            needs_reset = (
                existing.shape != direction_t.shape
                # Catches zeroed buffers as well as NaN/inf garbage.
                or not 0.0 < float(existing.detach().float().norm()) < float("inf")
            )
        if needs_reset:
            # Only pay for a device copy when the buffer is actually replaced.
            if target_device is not None:
                direction_t = direction_t.to(target_device)
            self._buffers["_refusal_direction"] = direction_t
        return self._buffers["_refusal_direction"]

    def _decoder_layers(self):
        """Return the decoder layer list of the wrapped language model.

        Raises:
            AttributeError: if neither ``model.language_model.layers`` nor
                ``model.layers`` exists.
        """
        if hasattr(self.model, "language_model") and hasattr(self.model.language_model, "layers"):
            return self.model.language_model.layers
        if hasattr(self.model, "layers"):
            return self.model.layers
        raise AttributeError("Could not find decoder layers in model.")

    def _make_hook(self):
        """Create a forward pre-hook (for ``with_kwargs=True``) that ablates the direction.

        The hook subtracts ``ablation_strength`` times the projection of the layer's
        input hidden states onto the unit-normalised direction. The direction is
        normalised once, in float32, and cast per (device, dtype) on first use, so
        layers placed on different devices each get a matching copy.
        """
        strength = float(self.ablation_strength)
        direction = self._refusal_direction.detach().float()
        unit = direction / (direction.norm() + 1e-8)
        cast_cache = {}

        def ablate(hidden):
            key = (hidden.device, hidden.dtype)
            d = cast_cache.get(key)
            if d is None:
                d = cast_cache[key] = unit.to(device=hidden.device, dtype=hidden.dtype)
            # (..., D) @ (D,) -> (...,): works for any number of leading dims.
            coeff = torch.matmul(hidden, d).unsqueeze(-1)
            return hidden - strength * (coeff * d)

        def hook_fn(module, args, kwargs):
            # Assumes the hidden states are the first positional argument when one is
            # given, otherwise the `hidden_states` keyword.
            if args:
                return (ablate(args[0]), *args[1:]), kwargs
            if "hidden_states" in kwargs:
                kwargs = dict(kwargs)
                kwargs["hidden_states"] = ablate(kwargs["hidden_states"])
                return args, kwargs
            raise RuntimeError(
                f"{type(module).__name__} was called without hidden states; "
                "cannot apply direction ablation."
            )

        return hook_fn

    @contextlib.contextmanager
    def _temporary_hooks(self):
        """Keep the ablation pre-hook registered on the configured layers inside the block.

        Indices outside ``[0, len(layers))`` are ignored, duplicates are applied once,
        and nothing is registered when no index remains or the strength is <= 0.
        Handles are always removed on exit, including on exceptions.
        """
        layers = self._decoder_layers()
        # dict.fromkeys de-duplicates while preserving order; a repeated index would
        # otherwise hook the same layer twice and compound the ablation.
        layer_ids = [i for i in dict.fromkeys(self.ablation_layers) if 0 <= i < len(layers)]
        if not layer_ids or self.ablation_strength <= 0:
            yield
            return
        handles = []
        try:
            hook_fn = self._make_hook()
            for layer_id in layer_ids:
                handles.append(
                    layers[layer_id].register_forward_pre_hook(hook_fn, with_kwargs=True)
                )
            yield
        finally:
            for handle in handles:
                handle.remove()

    def forward(self, *args, **kwargs):
        """Run the base forward pass with the ablation hooks active.

        ``ablation_strength`` and ``ablation_layers`` are re-read from the config on
        every call. A missing or ``None`` strength keeps the current value; a missing
        layer list keeps the current list, while ``None`` means no layers. The
        direction buffer is validated before the hooks are installed.
        """
        self.ablation_strength = float(
            self._config_value(self.config, "ablation_strength", self.ablation_strength)
        )
        self.ablation_layers = self._parse_layers(
            getattr(self.config, "ablation_layers", self.ablation_layers)
        )
        self._refresh_refusal_direction_buffer()
        with self._temporary_hooks():
            return super().forward(*args, **kwargs)