"""NPC move-selection model: Gemma 3 270M + Sathvik0101/cyber-duel-tiny-users LoRA. Two execution paths: 1. **Modal GPU** (preferred): when GEMMA_SERVER is set, /pick_move is called remotely. Cold start ~1s, warm ~600-1500ms on T4. 2. **Local CPU** (fallback): when GEMMA_SERVER is unset, loads the model in-process. Used for local dev and as a safety net. Public API (kept stable for app.py): - MOVES: list of legal move names - get_model(): returns (None, None) when remote, otherwise (model, tok) - pick_counter_move(sequence) -> (move, reasoning, source) """ from __future__ import annotations import os import re import time from pathlib import Path from typing import Optional, Tuple import torch ADAPTER_DIR = Path(__file__).resolve().parent / "adapters" / "ref" # Use the un-gated unsloth mirror of the same Gemma 3 270M model. BASE_MODEL_ID = os.environ.get("GEMMA_BASE_MODEL", "unsloth/gemma-3-270m-it") GEMMA_SERVER = os.environ.get("GEMMA_SERVER", "").rstrip("/") CANONICAL_MOVES = ( "jab", "cross", "low_kick", "roundhouse", "uppercut", "parry", "backstep", "clinch", "throw", ) MOVE_TO_IDX = {m: i for i, m in enumerate(CANONICAL_MOVES)} MOVES = list(CANONICAL_MOVES) NUM_MOVES = len(MOVES) SYSTEM_PROMPT = ( "You are an expert fighting game NPC AI. " "Given the player's last 5 moves, output a single best counter-move " "from: jab, cross, low_kick, roundhouse, uppercut, parry, backstep, " "clinch, throw. Format your reply as:\n" "[one short sentence of reasoning]\n" "counter_move: " ) _model = None _tokenizer = None _model_lock_path: Optional[Path] = None _load_error: Optional[str] = None def _get_hf_token() -> Optional[str]: tok = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") if tok: return tok cache = Path.home() / ".cache" / "huggingface" / "token" if cache.exists(): return cache.read_text(encoding="utf-8").strip() or None return None def get_model(): """Lazy singleton loader. Returns (model, tokenizer) or raises.""" global _model, _tokenizer, _load_error if _model is not None and _tokenizer is not None: return _model, _tokenizer if not ADAPTER_DIR.exists(): raise FileNotFoundError( f"LoRA adapter not found at {ADAPTER_DIR}. " "Make sure adapters/ref/ is bundled with the Space." ) from huggingface_hub import snapshot_download from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel token = _get_hf_token() # Use bf16 on CPU (newer torch supports it; falls back to fp32 otherwise). bf16_ok = False try: bf16_ok = torch.cuda.is_available() or hasattr(torch, "to_bf16") except Exception: bf16_ok = False dtype = torch.bfloat16 if bf16_ok else torch.float32 print(f"[gemma_npc] Loading tokenizer + base {BASE_MODEL_ID} (dtype={dtype})...") t0 = time.perf_counter() # Pull the base model. snapshot_download lets us reuse the cached HF store # so cold-starts aren't 500MB of network on every restart. base_path = snapshot_download(repo_id=BASE_MODEL_ID, token=token) _tokenizer = AutoTokenizer.from_pretrained(base_path, token=token) base = AutoModelForCausalLM.from_pretrained( base_path, token=token, torch_dtype=dtype, ) base.eval() print(f"[gemma_npc] Loading LoRA adapter from {ADAPTER_DIR}...") _model = PeftModel.from_pretrained(base, str(ADAPTER_DIR)) _model.eval() # Warmup so the first user request doesn't pay dispatch cost. _generate("jab,cross,low_kick,roundhouse,uppercut") print(f"[gemma_npc] Model ready in {time.perf_counter() - t0:.1f}s") return _model, _tokenizer def _generate(sequence: str, max_new_tokens: int = 18) -> str: # Tiny LRU cache -- repeated move sequences are common in fighting games, # so a second request with the same sequence returns instantly. cached = _LRU_CACHE.get(sequence) if cached is not None: return cached[1] model, tokenizer = get_model() messages = [ {"role": "user", "content": _format_user_prompt(sequence)}, ] # `apply_chat_template` with `tokenize=True, return_tensors="pt"` returns # a `BatchEncoding` (subclass of tokenizers.Encoding) in transformers # 5.x, not a bare tensor -- so we ask for `return_dict=True` and index # `input_ids` / `attention_mask` ourselves. encoded = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True, ) input_ids = encoded["input_ids"] attn_mask = encoded.get("attention_mask", torch.ones_like(input_ids)) eos_id = tokenizer.convert_tokens_to_ids("") with torch.inference_mode(): out = model.generate( input_ids=input_ids, attention_mask=attn_mask, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, pad_token_id=tokenizer.eos_token_id, eos_token_id=eos_id, ) new_tokens = out[0][input_ids.shape[-1]:] text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip() return text def _format_user_prompt(sequence: str) -> str: parts = [p.strip() for p in (sequence or "").split(",") if p.strip()] parts = parts[-5:] if len(parts) > 5 else parts if not parts: parts = ["jab"] return ( f"Player's last moves: {','.join(parts)}.\n" f"Choose the best counter from: jab, cross, low_kick, roundhouse, " f"uppercut, parry, backstep, clinch, throw.\n" f"Respond with EXACTLY one line: counter_move: " ) _LRU_CACHE: "dict[str, Tuple[str, str, str]]" = {} _LRU_CACHE_MAX = 64 _MOVE_PATTERN = re.compile(r"\b(" + "|".join(re.escape(m) for m in CANONICAL_MOVES) + r")\b", re.IGNORECASE) def _parse_move(text: str) -> Tuple[str, str]: """Extract the model's chosen move and reasoning from its output. Falls back to scanning the text for any canonical move name if the `counter_move:` line is missing. """ reasoning = text.strip() chosen: Optional[str] = None m = re.search(r"counter_move\s*:\s*([A-Za-z_]+)", reasoning, re.IGNORECASE) if m: candidate = m.group(1).strip().lower() if candidate in MOVE_TO_IDX: chosen = candidate reasoning = reasoning[: m.start()].strip() or reasoning if not chosen: # Look for the first canonical move name in the output. m2 = _MOVE_PATTERN.search(reasoning) if m2: chosen = m2.group(1).lower() else: # Final fallback: deterministic. chosen = "jab" reasoning = f"could not parse model output: {reasoning[:80]}" return chosen, reasoning[:200] def pick_counter_move(sequence: str) -> Tuple[str, str, str]: """Run the model on a player-move sequence. Returns (move, reasoning, source) where `source` is "gemma_lora" on success and "fallback" if the model can't be loaded. """ cached = _LRU_CACHE.get(sequence) if cached is not None: return cached # --- Modal GPU path: fast remote inference on T4 ------------------- if GEMMA_SERVER: import httpx try: t0 = time.perf_counter() r = httpx.post( f"{GEMMA_SERVER}/pick_move", json={"sequence": sequence}, timeout=20.0, ) r.raise_for_status() data = r.json() ms = (time.perf_counter() - t0) * 1000 move = data.get("move") or "jab" if move not in MOVE_TO_IDX: move = "jab" reasoning = (data.get("reasoning") or "")[:200] src = data.get("source", "gemma_modal") result = (move, f"{reasoning} (modal {ms:.0f}ms)", src) except Exception as e: # noqa: BLE001 return "jab", f"modal failed: {type(e).__name__}: {e}", "fallback" if len(_LRU_CACHE) >= _LRU_CACHE_MAX: _LRU_CACHE.pop(next(iter(_LRU_CACHE))) _LRU_CACHE[sequence] = result return result # --- Local CPU fallback -------------------------------------------- try: model, _ = get_model() except Exception as e: # noqa: BLE001 return "jab", f"model unavailable: {type(e).__name__}: {e}", "fallback" try: t0 = time.perf_counter() text = _generate(sequence) ms = (time.perf_counter() - t0) * 1000 move, reasoning = _parse_move(text) result = (move, f"{reasoning} ({ms:.0f}ms)", "gemma_lora") if len(_LRU_CACHE) >= _LRU_CACHE_MAX: _LRU_CACHE.pop(next(iter(_LRU_CACHE))) _LRU_CACHE[sequence] = result return result except Exception as e: # noqa: BLE001 return "jab", f"inference failed: {type(e).__name__}: {e}", "fallback" def make_move_mask(distance: str) -> torch.Tensor: """Kept for backwards compatibility with the old TinyFighter API.""" mask = [1.0] * NUM_MOVES return torch.tensor(mask, dtype=torch.float32) # --------------------------------------------------------------------------- # Backwards-compat shims for the old TinyFighter state_to_features, in case # any of the Gradio panel helpers still expect them. # --------------------------------------------------------------------------- def state_to_features( last_npc_moves, last_player_moves, player_hp=100.0, npc_hp=100.0, player_stamina=100.0, npc_stamina=100.0, distance="mid", aggression=0.5, defense=0.5, parry_affinity=0.4, kick_affinity=0.3, grapple_affinity=0.3, round_num=1, history_len=5, ): """Legacy shim. Returns a 168-dim zero tensor (no real features) so old callers don't crash. The new model is move-sequence based, not feature-vector based. """ return torch.zeros(168, dtype=torch.float32) def remap_bn_state_to_ln(state_dict: dict) -> dict: """Legacy shim. The Gemma+LoRA path doesn't need this; pass-through.""" return state_dict if __name__ == "__main__": move, reasoning, source = pick_counter_move("jab,cross,low_kick,roundhouse,uppercut") print(f"move={move} source={source}") print(f"reasoning={reasoning}")