Spaces:
Running
Running
"""NPC move-selection model: Gemma 3 270M + Sathvik0101/cyber-duel-tiny-users LoRA.

Two execution paths:

1. **Modal GPU** (preferred): when GEMMA_SERVER is set, /pick_move is called
   remotely. Cold start ~1s, warm ~600-1500ms on T4.
2. **Local CPU** (fallback): when GEMMA_SERVER is unset, the model is loaded
   in-process. Used for local dev and as a safety net.

Public API (kept stable for app.py):

- MOVES: list of legal move names
- get_model(): returns (None, None) when remote, otherwise (model, tok)
- pick_counter_move(sequence) -> (move, reasoning, source)
"""
| from __future__ import annotations | |
| import os | |
| import re | |
| import time | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import torch | |
# Directory holding the LoRA adapter weights, resolved relative to this file.
ADAPTER_DIR = Path(__file__).resolve().parent / "adapters" / "ref"

# Use the un-gated unsloth mirror of the same Gemma 3 270M model.
# Override with the GEMMA_BASE_MODEL environment variable.
BASE_MODEL_ID = os.environ.get("GEMMA_BASE_MODEL", "unsloth/gemma-3-270m-it")

# Base URL of the remote inference server; empty string means "run locally".
# Any trailing slash is removed so paths can be appended with a single "/".
GEMMA_SERVER = os.environ.get("GEMMA_SERVER", "").rstrip("/")

# The closed set of move names; the tuple order defines each move's index.
CANONICAL_MOVES = (
    "jab",
    "cross",
    "low_kick",
    "roundhouse",
    "uppercut",
    "parry",
    "backstep",
    "clinch",
    "throw",
)
MOVE_TO_IDX = {name: position for position, name in enumerate(CANONICAL_MOVES)}
MOVES = [*CANONICAL_MOVES]
NUM_MOVES = len(MOVES)

# NOTE(review): nothing in the visible part of this module reads
# SYSTEM_PROMPT -- confirm whether another module uses it.
SYSTEM_PROMPT = (
    "You are an expert fighting game NPC AI. "
    "Given the player's last 5 moves, output a single best counter-move "
    "from: jab, cross, low_kick, roundhouse, uppercut, parry, backstep, "
    "clinch, throw. Format your reply as:\n"
    "[one short sentence of reasoning]\n"
    "counter_move: <one_move_name>"
)

# Lazily populated singleton state for the in-process model.
_model = None
_tokenizer = None
_model_lock_path: Optional[Path] = None
_load_error: Optional[str] = None
def _get_hf_token() -> Optional[str]:
    """Find a Hugging Face access token, or return None if there is none.

    Lookup order: the HF_TOKEN environment variable, then HUGGINGFACE_TOKEN,
    then the contents of ``~/.cache/huggingface/token``. Empty values are
    treated as missing at every step.
    """
    for variable in ("HF_TOKEN", "HUGGINGFACE_TOKEN"):
        from_env = os.environ.get(variable)
        if from_env:
            return from_env

    token_file = Path.home() / ".cache" / "huggingface" / "token"
    if not token_file.exists():
        return None

    stored = token_file.read_text(encoding="utf-8").strip()
    return stored if stored else None
def get_model():
    """Return the lazily loaded ``(model, tokenizer)`` singleton pair.

    When ``GEMMA_SERVER`` is set, inference is delegated to the remote
    endpoint, so nothing is loaded and ``(None, None)`` is returned -- this
    is the contract stated in the module docstring. Otherwise the base model
    is downloaded (or reused from the local Hugging Face cache), the LoRA
    adapter in ``ADAPTER_DIR`` is applied, one warmup generation is run, and
    the pair is cached in module globals for subsequent calls.

    Returns:
        ``(None, None)`` in remote mode, else ``(model, tokenizer)``.

    Raises:
        FileNotFoundError: if ``ADAPTER_DIR`` does not exist.
        Exception: anything raised by the download, model loading, adapter
            loading or the warmup generation propagates to the caller.
    """
    global _model, _tokenizer

    # Remote mode: the in-process model is never needed, so do not pay the
    # download / load cost (or fail on a missing adapter directory).
    if GEMMA_SERVER:
        return None, None

    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer

    if not ADAPTER_DIR.exists():
        raise FileNotFoundError(
            f"LoRA adapter not found at {ADAPTER_DIR}. "
            "Make sure adapters/ref/ is bundled with the Space."
        )

    # Heavy third-party imports are deferred so importing this module stays
    # cheap when the remote path is used.
    from huggingface_hub import snapshot_download
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    token = _get_hf_token()

    # Pick bf16 when CUDA is available, fp32 otherwise.
    # NOTE(review): `torch.to_bf16` does not appear to be a real torch
    # attribute, so on CPU-only hosts this most likely always selects fp32
    # -- confirm whether bf16-on-CPU was actually intended.
    bf16_ok = False
    try:
        bf16_ok = torch.cuda.is_available() or hasattr(torch, "to_bf16")
    except Exception:
        bf16_ok = False
    dtype = torch.bfloat16 if bf16_ok else torch.float32

    print(f"[gemma_npc] Loading tokenizer + base {BASE_MODEL_ID} (dtype={dtype})...")
    t0 = time.perf_counter()

    # Pull the base model. snapshot_download lets us reuse the cached HF store
    # so cold-starts aren't 500MB of network on every restart.
    base_path = snapshot_download(repo_id=BASE_MODEL_ID, token=token)
    _tokenizer = AutoTokenizer.from_pretrained(base_path, token=token)
    base = AutoModelForCausalLM.from_pretrained(
        base_path,
        token=token,
        torch_dtype=dtype,
    )
    base.eval()

    print(f"[gemma_npc] Loading LoRA adapter from {ADAPTER_DIR}...")
    _model = PeftModel.from_pretrained(base, str(ADAPTER_DIR))
    _model.eval()

    # Warmup so the first user request doesn't pay dispatch cost. Both
    # globals are set by now, so the nested get_model() call inside
    # _generate returns immediately instead of recursing into a reload.
    _generate("jab,cross,low_kick,roundhouse,uppercut")

    print(f"[gemma_npc] Model ready in {time.perf_counter() - t0:.1f}s")
    return _model, _tokenizer
def _generate(sequence: str, max_new_tokens: int = 18) -> str:
    """Run one greedy generation for ``sequence`` and return the raw reply.

    Args:
        sequence: comma-separated player moves; turned into the user prompt
            by ``_format_user_prompt``.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded newly generated text (prompt tokens excluded, special
        tokens skipped, surrounding whitespace stripped).

    Result caching is deliberately NOT done here: the module cache stores
    ``(move, reasoning, source)`` tuples, not raw model text, so serving a
    cache entry from this function would hand callers already-parsed
    reasoning (with a latency suffix) instead of model output.
    ``pick_counter_move`` consults the cache before calling this function.
    """
    model, tokenizer = get_model()

    messages = [
        {"role": "user", "content": _format_user_prompt(sequence)},
    ]

    # With `tokenize=True, return_tensors="pt"` newer transformers releases
    # may return a dict-like `BatchEncoding` rather than a bare tensor, so we
    # ask for `return_dict=True` and pull out `input_ids` / `attention_mask`
    # ourselves.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
    # Fall back to an all-ones mask if the template did not produce one.
    attn_mask = encoded.get("attention_mask", torch.ones_like(input_ids))

    # Stop at the chat template's end-of-turn marker.
    eos_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")

    with torch.inference_mode():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=eos_id,
        )

    # Keep only the tokens produced after the prompt.
    new_tokens = out[0][input_ids.shape[-1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return text
def _format_user_prompt(sequence: str) -> str:
    """Build the user-turn prompt from a comma-separated move history.

    Blank entries are discarded, only the five most recent moves are kept,
    and an empty (or None) history is replaced by the single move ``jab``.
    """
    recent = []
    for chunk in (sequence or "").split(","):
        name = chunk.strip()
        if name:
            recent.append(name)

    # Slicing a short list with [-5:] returns it unchanged, so no length
    # check is needed; an empty list falls through to the default.
    recent = recent[-5:] or ["jab"]
    history = ",".join(recent)

    return (
        f"Player's last moves: {history}.\n"
        "Choose the best counter from: jab, cross, low_kick, roundhouse, "
        "uppercut, parry, backstep, clinch, throw.\n"
        "Respond with EXACTLY one line: counter_move: <name>"
    )
# Result cache keyed by the raw sequence string; values are the
# (move, reasoning, source) tuples returned by pick_counter_move.
_LRU_CACHE: "dict[str, Tuple[str, str, str]]" = {}
# Maximum number of entries kept in _LRU_CACHE.
_LRU_CACHE_MAX = 64

# Case-insensitive matcher for any canonical move name as a whole word.
_MOVE_ALTERNATION = "|".join(re.escape(name) for name in CANONICAL_MOVES)
_MOVE_PATTERN = re.compile(r"\b(" + _MOVE_ALTERNATION + r")\b", re.IGNORECASE)
def _parse_move(text: str) -> Tuple[str, str]:
    """Split raw model output into ``(move, reasoning)``.

    Resolution order:

    1. A ``counter_move: <name>`` line naming a canonical move. The reasoning
       is whatever precedes that line (or the whole text if nothing does).
    2. Otherwise, the first canonical move name appearing anywhere in the
       text; the reasoning is the whole text.
    3. Otherwise ``jab``, with a reasoning string recording the first 80
       characters of the unparseable output.

    The reasoning is capped at 200 characters in every case.
    """
    raw = text.strip()

    explicit = re.search(r"counter_move\s*:\s*([A-Za-z_]+)", raw, re.IGNORECASE)
    if explicit is not None:
        named = explicit.group(1).strip().lower()
        if named in MOVE_TO_IDX:
            preamble = raw[: explicit.start()].strip()
            return named, (preamble or raw)[:200]

    # No usable counter_move line: look for the first canonical move name
    # anywhere in the output.
    scanned = _MOVE_PATTERN.search(raw)
    if scanned is not None:
        return scanned.group(1).lower(), raw[:200]

    # Final fallback: deterministic.
    return "jab", f"could not parse model output: {raw[:80]}"[:200]
def _remember_result(sequence, result):
    """Store ``result`` under ``sequence`` and return it.

    When the cache is already at capacity, the entry that was inserted
    first is dropped to make room.
    """
    if len(_LRU_CACHE) >= _LRU_CACHE_MAX:
        oldest_key = next(iter(_LRU_CACHE))
        _LRU_CACHE.pop(oldest_key)
    _LRU_CACHE[sequence] = result
    return result


def pick_counter_move(sequence: str) -> Tuple[str, str, str]:
    """Choose the NPC's counter-move for a player-move sequence.

    Returns ``(move, reasoning, source)``. A previously computed result for
    the same ``sequence`` string is returned straight from the cache.
    Otherwise the remote server is queried when ``GEMMA_SERVER`` is set, and
    the in-process model is used when it is not. On any failure the result is
    ``("jab", <error description>, "fallback")`` and is not cached.

    ``source`` is whatever the server reports (default ``"gemma_modal"``) on
    the remote path and ``"gemma_lora"`` on the local path.
    """
    hit = _LRU_CACHE.get(sequence)
    if hit is not None:
        return hit

    # --- Modal GPU path: fast remote inference on T4 -------------------
    if GEMMA_SERVER:
        import httpx

        try:
            started = time.perf_counter()
            response = httpx.post(
                f"{GEMMA_SERVER}/pick_move",
                json={"sequence": sequence},
                timeout=20.0,
            )
            response.raise_for_status()
            payload = response.json()
            elapsed_ms = (time.perf_counter() - started) * 1000

            # Anything missing or outside the canonical move set becomes "jab".
            move = payload.get("move") or "jab"
            if move not in MOVE_TO_IDX:
                move = "jab"
            reasoning = (payload.get("reasoning") or "")[:200]
            origin = payload.get("source", "gemma_modal")
            outcome = (move, f"{reasoning} (modal {elapsed_ms:.0f}ms)", origin)
        except Exception as exc:  # noqa: BLE001
            return "jab", f"modal failed: {type(exc).__name__}: {exc}", "fallback"

        return _remember_result(sequence, outcome)

    # --- Local CPU fallback --------------------------------------------
    try:
        get_model()
    except Exception as exc:  # noqa: BLE001
        return "jab", f"model unavailable: {type(exc).__name__}: {exc}", "fallback"

    try:
        started = time.perf_counter()
        raw_text = _generate(sequence)
        elapsed_ms = (time.perf_counter() - started) * 1000
        move, reasoning = _parse_move(raw_text)
        outcome = (move, f"{reasoning} ({elapsed_ms:.0f}ms)", "gemma_lora")
        return _remember_result(sequence, outcome)
    except Exception as exc:  # noqa: BLE001
        return "jab", f"inference failed: {type(exc).__name__}: {exc}", "fallback"
def make_move_mask(distance: str) -> torch.Tensor:
    """Return an all-ones float32 mask with one entry per move.

    Kept for backwards compatibility with the old TinyFighter API; the
    ``distance`` argument is accepted but ignored, so every move is allowed.
    """
    allow_all = [1.0 for _ in range(NUM_MOVES)]
    return torch.tensor(allow_all, dtype=torch.float32)
| # --------------------------------------------------------------------------- | |
| # Backwards-compat shims for the old TinyFighter state_to_features, in case | |
| # any of the Gradio panel helpers still expect them. | |
| # --------------------------------------------------------------------------- | |
def state_to_features(
    last_npc_moves,
    last_player_moves,
    player_hp=100.0,
    npc_hp=100.0,
    player_stamina=100.0,
    npc_stamina=100.0,
    distance="mid",
    aggression=0.5,
    defense=0.5,
    parry_affinity=0.4,
    kick_affinity=0.3,
    grapple_affinity=0.3,
    round_num=1,
    history_len=5,
):
    """Legacy shim: always return a 168-element float32 tensor of zeros.

    Every argument is accepted for signature compatibility and ignored. The
    current model works from move sequences rather than feature vectors, so
    this exists only to keep old callers from crashing.
    """
    legacy_feature_dim = 168
    features = torch.zeros(legacy_feature_dim, dtype=torch.float32)
    return features
def remap_bn_state_to_ln(state_dict: dict) -> dict:
    """Legacy shim: hand back ``state_dict`` itself, unmodified.

    The Gemma+LoRA path needs no remapping, so this is a pass-through kept
    only so that old call sites keep working.
    """
    return state_dict
if __name__ == "__main__":
    # Manual smoke test: run one fixed sequence and print the outcome.
    demo_sequence = "jab,cross,low_kick,roundhouse,uppercut"
    move, reasoning, source = pick_counter_move(demo_sequence)
    print(f"move={move} source={source}")
    print(f"reasoning={reasoning}")