duel / gemma_npc.py
sankalphs's picture
Upload gemma_npc.py with huggingface_hub
068d371 verified
Raw
History Blame Contribute Delete
10.4 kB
"""NPC move-selection model: Gemma 3 270M + Sathvik0101/cyber-duel-tiny-users LoRA.
Two execution paths:
1. **Modal GPU** (preferred): when GEMMA_SERVER is set, /pick_move is called
remotely. Cold start ~1s, warm ~600-1500ms on T4.
2. **Local CPU** (fallback): when GEMMA_SERVER is unset, loads the model
in-process. Used for local dev and as a safety net.
Public API (kept stable for app.py):
- MOVES: list of legal move names
- get_model(): returns (None, None) when remote, otherwise (model, tok)
- pick_counter_move(sequence) -> (move, reasoning, source)
"""
from __future__ import annotations
import os
import re
import time
from pathlib import Path
from typing import Optional, Tuple
import torch
# LoRA adapter weights live in adapters/ref/ next to this file.
ADAPTER_DIR = Path(__file__).resolve().parent / "adapters" / "ref"

# Use the un-gated unsloth mirror of the same Gemma 3 270M model.
BASE_MODEL_ID = os.getenv("GEMMA_BASE_MODEL", "unsloth/gemma-3-270m-it")

# Base URL of the remote inference server with trailing slashes removed;
# an empty string (falsy) when the variable is unset.
GEMMA_SERVER = os.getenv("GEMMA_SERVER", "").rstrip("/")

# The legal move names, in index order.
CANONICAL_MOVES = (
    "jab",
    "cross",
    "low_kick",
    "roundhouse",
    "uppercut",
    "parry",
    "backstep",
    "clinch",
    "throw",
)
MOVE_TO_IDX = {name: index for index, name in enumerate(CANONICAL_MOVES)}
MOVES = [*CANONICAL_MOVES]
NUM_MOVES = len(MOVES)

SYSTEM_PROMPT = (
    "You are an expert fighting game NPC AI. "
    "Given the player's last 5 moves, output a single best counter-move "
    "from: jab, cross, low_kick, roundhouse, uppercut, parry, backstep, "
    "clinch, throw. Format your reply as:\n"
    "[one short sentence of reasoning]\n"
    "counter_move: <one_move_name>"
)

# Module-level singletons; the model/tokenizer pair is filled in by get_model().
_model = None
_tokenizer = None
_model_lock_path: Optional[Path] = None
_load_error: Optional[str] = None
def _get_hf_token() -> Optional[str]:
    """Return a Hugging Face access token, or None when none is configured.

    Lookup order: the ``HF_TOKEN`` environment variable, then
    ``HUGGINGFACE_TOKEN``, then the contents of
    ``~/.cache/huggingface/token`` (whitespace-stripped). Empty values are
    treated as missing.
    """
    for var_name in ("HF_TOKEN", "HUGGINGFACE_TOKEN"):
        env_value = os.environ.get(var_name)
        if env_value:
            return env_value

    token_file = Path.home().joinpath(".cache", "huggingface", "token")
    if not token_file.exists():
        return None

    stored = token_file.read_text(encoding="utf-8").strip()
    return stored if stored else None
def get_model():
    """Lazy singleton loader for the local Gemma + LoRA model.

    Returns:
        ``(None, None)`` when ``GEMMA_SERVER`` is set -- inference is remote
        in that mode, so nothing is downloaded or loaded in-process (this is
        the contract stated in the module docstring). Otherwise
        ``(model, tokenizer)``, loaded on the first call and reused after.

    Raises:
        FileNotFoundError: local mode only, when ``ADAPTER_DIR`` is missing.
        Exception: anything raised by the hub download or by the
            transformers / peft loaders propagates unchanged.
    """
    global _model, _tokenizer

    if GEMMA_SERVER:
        return None, None

    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer

    if not ADAPTER_DIR.exists():
        raise FileNotFoundError(
            f"LoRA adapter not found at {ADAPTER_DIR}. "
            "Make sure adapters/ref/ is bundled with the Space."
        )

    # Heavy third-party imports stay local so importing this module is cheap.
    from huggingface_hub import snapshot_download
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    token = _get_hf_token()

    # Pick bf16 when the probe below succeeds, fp32 otherwise.
    # NOTE(review): `hasattr(torch, "to_bf16")` looks like it is always False,
    # which would mean CPU-only hosts always get fp32 and bf16 is selected
    # only when CUDA is visible -- confirm the intended probe.
    try:
        bf16_ok = torch.cuda.is_available() or hasattr(torch, "to_bf16")
    except Exception:
        bf16_ok = False
    dtype = torch.bfloat16 if bf16_ok else torch.float32

    print(f"[gemma_npc] Loading tokenizer + base {BASE_MODEL_ID} (dtype={dtype})...")
    t0 = time.perf_counter()

    # Pull the base model. snapshot_download lets us reuse the cached HF store
    # so cold-starts aren't 500MB of network on every restart.
    base_path = snapshot_download(repo_id=BASE_MODEL_ID, token=token)
    tokenizer = AutoTokenizer.from_pretrained(base_path, token=token)
    base = AutoModelForCausalLM.from_pretrained(
        base_path,
        token=token,
        torch_dtype=dtype,
    )
    base.eval()

    print(f"[gemma_npc] Loading LoRA adapter from {ADAPTER_DIR}...")
    model = PeftModel.from_pretrained(base, str(ADAPTER_DIR))
    model.eval()

    # Publish both halves together, only after every load step succeeded, so
    # a failure above never leaves a tokenizer without a model behind.
    _model, _tokenizer = model, tokenizer

    # Warmup so the first user request doesn't pay dispatch cost. This is
    # best-effort: the singleton is already published, so letting a warmup
    # error escape would make this call fail while the next one succeeds.
    try:
        _generate("jab,cross,low_kick,roundhouse,uppercut")
    except Exception as exc:  # noqa: BLE001
        print(f"[gemma_npc] Warmup failed (continuing): {type(exc).__name__}: {exc}")

    print(f"[gemma_npc] Model ready in {time.perf_counter() - t0:.1f}s")
    return _model, _tokenizer
def _generate(sequence: str, max_new_tokens: int = 18) -> str:
    """Run greedy decoding for one player-move sequence; return the raw reply.

    Args:
        sequence: Comma-separated player moves; turned into the user prompt
            by ``_format_user_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens only (prompt tokens
        and special tokens removed), whitespace-stripped.

    Raises:
        RuntimeError: if ``get_model()`` yields no local model/tokenizer.
        Exception: anything raised while loading the model or generating.
    """
    # Deliberately no cache lookup here: `_LRU_CACHE` stores finished
    # (move, reasoning, source) results, so handing back an entry's second
    # field would return formatted reasoning text where callers expect raw
    # model output. `pick_counter_move` consults the cache before calling.
    model, tokenizer = get_model()
    if model is None or tokenizer is None:
        raise RuntimeError("local Gemma model is not loaded")

    messages = [
        {"role": "user", "content": _format_user_prompt(sequence)},
    ]

    # With `tokenize=True, return_tensors="pt"` the chat template call may
    # return a `BatchEncoding` instead of a bare tensor (the original author
    # observed this on transformers 5.x), so request `return_dict=True` and
    # pull `input_ids` / `attention_mask` out explicitly.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
    attn_mask = encoded.get("attention_mask", torch.ones_like(input_ids))

    # Stop at the chat template's end-of-turn marker.
    eos_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")

    with torch.inference_mode():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=eos_id,
        )

    # `out` holds prompt + continuation; keep only the continuation.
    new_tokens = out[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def _format_user_prompt(sequence: str) -> str:
    """Build the user-turn prompt from a comma-separated move history.

    Blank entries are dropped, only the five most recent moves are kept, and
    an empty (or ``None``) history is replaced by a single ``jab``.
    """
    cleaned = []
    for raw in (sequence or "").split(","):
        name = raw.strip()
        if name:
            cleaned.append(name)

    # Slicing already copes with short lists; an empty result falls back.
    recent = cleaned[-5:] or ["jab"]
    history = ",".join(recent)

    return (
        f"Player's last moves: {history}.\n"
        "Choose the best counter from: jab, cross, low_kick, roundhouse, "
        "uppercut, parry, backstep, clinch, throw.\n"
        "Respond with EXACTLY one line: counter_move: <name>"
    )
# Memo of finished results keyed by the raw sequence string; each value is a
# (move, reasoning, source) tuple. Capped at _LRU_CACHE_MAX entries.
_LRU_CACHE: "dict[str, Tuple[str, str, str]]" = {}
_LRU_CACHE_MAX = 64

# Case-insensitive, whole-word matcher for any canonical move name.
_MOVE_PATTERN = re.compile(
    r"\b(" + "|".join(map(re.escape, CANONICAL_MOVES)) + r")\b",
    re.IGNORECASE,
)
# Matches a `counter_move:` marker followed by a move name. Optional wrapper
# punctuation is tolerated because the prompt shows the answer as
# `counter_move: <name>` and replies may echo the brackets or add markdown.
_COUNTER_MOVE_MARKER = re.compile(
    r"counter_move\s*:\s*[<\[(\"'`*]*\s*([A-Za-z_]+)",
    re.IGNORECASE,
)


def _parse_move(text: str) -> Tuple[str, str]:
    """Extract the model's chosen move and reasoning from its output.

    Resolution order for the move:
      1. the first ``counter_move:`` marker whose name is a canonical move
         (later markers are tried if an earlier one names something illegal);
      2. otherwise the first canonical move name anywhere in the output;
      3. otherwise the deterministic default ``"jab"``.

    Args:
        text: Raw model reply. ``None`` is treated as an empty reply.

    Returns:
        ``(move, reasoning)``. ``move`` is always a canonical move name.
        ``reasoning`` is the text preceding the first ``counter_move:``
        marker (the whole reply if there is no marker or nothing precedes
        it), or a "could not parse" note in case 3; capped at 200 characters.
    """
    full_text = (text or "").strip()
    reasoning = full_text
    chosen: Optional[str] = None

    first_marker_start: Optional[int] = None
    for marker in _COUNTER_MOVE_MARKER.finditer(full_text):
        if first_marker_start is None:
            first_marker_start = marker.start()
        candidate = marker.group(1).lower()
        if candidate in MOVE_TO_IDX:
            chosen = candidate
            break

    if first_marker_start is not None:
        # Keep only the prose in front of the answer line, unless that would
        # leave nothing to show.
        reasoning = full_text[:first_marker_start].strip() or full_text

    if chosen is None:
        # Look for the first canonical move name in the whole output, so a
        # usable name after a malformed marker is not lost.
        mention = _MOVE_PATTERN.search(full_text)
        if mention:
            chosen = mention.group(1).lower()
        else:
            # Final fallback: deterministic.
            chosen = "jab"
            reasoning = f"could not parse model output: {full_text[:80]}"

    return chosen, reasoning[:200]
def _cache_store(sequence: str, result: Tuple[str, str, str]) -> None:
    """Record `result` as the newest cache entry, evicting the oldest if full."""
    _LRU_CACHE.pop(sequence, None)
    while _LRU_CACHE and len(_LRU_CACHE) >= _LRU_CACHE_MAX:
        # dicts iterate in insertion order, so the first key is the entry
        # that was used least recently.
        _LRU_CACHE.pop(next(iter(_LRU_CACHE)))
    _LRU_CACHE[sequence] = result


def pick_counter_move(sequence: str) -> Tuple[str, str, str]:
    """Choose the NPC's counter-move for a player-move sequence.

    Uses the remote server when ``GEMMA_SERVER`` is set, otherwise the
    in-process model. Successful results are memoised per ``sequence``
    string; failures are never cached, so the next call retries.

    Args:
        sequence: Comma-separated player moves.

    Returns:
        ``(move, reasoning, source)``. ``move`` is always a canonical move.
        ``source`` is ``"gemma_lora"`` for local inference, the server's own
        ``source`` field (default ``"gemma_modal"``) for remote inference,
        and ``"fallback"`` (with ``move == "jab"``) when anything failed.
        This function does not raise for inference or transport errors.
    """
    cached = _LRU_CACHE.pop(sequence, None)
    if cached is not None:
        # Re-insert so a hit counts as "recently used"; without this the
        # cache would evict in plain insertion (FIFO) order.
        _LRU_CACHE[sequence] = cached
        return cached

    # --- Modal GPU path: fast remote inference on T4 -------------------
    if GEMMA_SERVER:
        try:
            # Imported inside the try so a missing `httpx` degrades to the
            # fallback answer instead of raising into the caller.
            import httpx

            t0 = time.perf_counter()
            r = httpx.post(
                f"{GEMMA_SERVER}/pick_move",
                json={"sequence": sequence},
                timeout=20.0,
            )
            r.raise_for_status()
            data = r.json()
            ms = (time.perf_counter() - t0) * 1000

            # Never trust the server to return a legal move.
            move = data.get("move") or "jab"
            if move not in MOVE_TO_IDX:
                move = "jab"
            reasoning = (data.get("reasoning") or "")[:200]
            src = data.get("source") or "gemma_modal"
            result = (move, f"{reasoning} (modal {ms:.0f}ms)", src)
        except Exception as e:  # noqa: BLE001
            return "jab", f"modal failed: {type(e).__name__}: {e}", "fallback"
        _cache_store(sequence, result)
        return result

    # --- Local CPU fallback --------------------------------------------
    try:
        get_model()
    except Exception as e:  # noqa: BLE001
        return "jab", f"model unavailable: {type(e).__name__}: {e}", "fallback"

    try:
        t0 = time.perf_counter()
        text = _generate(sequence)
        ms = (time.perf_counter() - t0) * 1000
        move, reasoning = _parse_move(text)
        result = (move, f"{reasoning} ({ms:.0f}ms)", "gemma_lora")
    except Exception as e:  # noqa: BLE001
        return "jab", f"inference failed: {type(e).__name__}: {e}", "fallback"
    _cache_store(sequence, result)
    return result
def make_move_mask(distance: str) -> torch.Tensor:
    """Return an all-ones float32 mask with one entry per move.

    Kept for backwards compatibility with the old TinyFighter API;
    ``distance`` is accepted but ignored, so every move is always allowed.
    """
    return torch.ones(NUM_MOVES, dtype=torch.float32)
# ---------------------------------------------------------------------------
# Backwards-compat shims for the old TinyFighter state_to_features, in case
# any of the Gradio panel helpers still expect them.
# ---------------------------------------------------------------------------
def state_to_features(
    last_npc_moves,
    last_player_moves,
    player_hp=100.0,
    npc_hp=100.0,
    player_stamina=100.0,
    npc_stamina=100.0,
    distance="mid",
    aggression=0.5,
    defense=0.5,
    parry_affinity=0.4,
    kick_affinity=0.3,
    grapple_affinity=0.3,
    round_num=1,
    history_len=5,
):
    """Legacy shim that keeps old feature-vector callers from crashing.

    Every argument is ignored. The result is always a fresh 168-element
    float32 tensor of zeros, because the current model works from the move
    sequence rather than from a feature vector.
    """
    feature_dim = 168
    return torch.zeros((feature_dim,), dtype=torch.float32)
def remap_bn_state_to_ln(state_dict: dict) -> dict:
    """Legacy shim: hand back ``state_dict`` itself, untouched.

    The Gemma+LoRA path needs no key remapping, so this is a pass-through
    that returns the very same object it was given (no copy is made).
    """
    passthrough = state_dict
    return passthrough
if __name__ == "__main__":
    # Manual smoke test: run one fixed sequence through the picker and print
    # what came back.
    demo_sequence = "jab,cross,low_kick,roundhouse,uppercut"
    move, reasoning, source = pick_counter_move(demo_sequence)
    print(f"move={move} source={source}")
    print(f"reasoning={reasoning}")