File size: 10,442 Bytes
068d371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""NPC move-selection model: Gemma 3 270M + Sathvik0101/cyber-duel-tiny-users LoRA.

Two execution paths:
  1. **Modal GPU** (preferred): when GEMMA_SERVER is set, /pick_move is called
     remotely. Cold start ~1s, warm ~600-1500ms on T4.
  2. **Local CPU** (fallback): when GEMMA_SERVER is unset, loads the model
     in-process. Used for local dev and as a safety net.

Public API (kept stable for app.py):
    - MOVES: list of legal move names
    - get_model(): returns (None, None) when remote, otherwise (model, tok)
    - pick_counter_move(sequence) -> (move, reasoning, source)
"""
from __future__ import annotations

import os
import re
import time
from pathlib import Path
from typing import Optional, Tuple

import torch

# Location of the LoRA adapter weights, resolved relative to this source file.
ADAPTER_DIR = Path(__file__).resolve().parent.joinpath("adapters", "ref")

# Base checkpoint; defaults to the un-gated unsloth mirror of Gemma 3 270M and
# can be overridden through the GEMMA_BASE_MODEL environment variable.
BASE_MODEL_ID = os.environ.get("GEMMA_BASE_MODEL", "unsloth/gemma-3-270m-it")

# Base URL of the remote inference server (trailing slashes removed). An empty
# string means "no remote server": inference runs in-process.
GEMMA_SERVER = os.environ.get("GEMMA_SERVER", "").rstrip("/")

# The closed set of legal move names, in canonical order.
CANONICAL_MOVES = (
    "jab",
    "cross",
    "low_kick",
    "roundhouse",
    "uppercut",
    "parry",
    "backstep",
    "clinch",
    "throw",
)

# Derived lookups: name -> position, a mutable list copy, and the move count.
MOVE_TO_IDX = dict(zip(CANONICAL_MOVES, range(len(CANONICAL_MOVES))))
MOVES = [*CANONICAL_MOVES]
NUM_MOVES = len(CANONICAL_MOVES)

# Three-line instruction text: task description, then the two-line reply format.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an expert fighting game NPC AI. "
        "Given the player's last 5 moves, output a single best counter-move "
        "from: jab, cross, low_kick, roundhouse, uppercut, parry, backstep, "
        "clinch, throw. Format your reply as:",
        "[one short sentence of reasoning]",
        "counter_move: <one_move_name>",
    )
)

# Lazily populated module-level singletons (see get_model) and related state.
_model = None
_tokenizer = None
_model_lock_path: Optional[Path] = None
_load_error: Optional[str] = None


def _get_hf_token() -> Optional[str]:
    """Locate a Hugging Face access token, or return None if there is none.

    Lookup order: the ``HF_TOKEN`` environment variable, then
    ``HUGGINGFACE_TOKEN``, then the contents of
    ``~/.cache/huggingface/token``. Empty environment values are skipped, and
    a token file containing only whitespace counts as "no token".
    """
    for var_name in ("HF_TOKEN", "HUGGINGFACE_TOKEN"):
        env_value = os.environ.get(var_name)
        if env_value:
            return env_value

    token_file = Path.home().joinpath(".cache", "huggingface", "token")
    if not token_file.exists():
        return None

    stored = token_file.read_text(encoding="utf-8").strip()
    return stored if stored else None


def get_model():
    """Lazily load and memoise the in-process Gemma base model + LoRA adapter.

    Returns:
        ``(None, None)`` when ``GEMMA_SERVER`` is configured: inference is
        delegated to the remote server, so nothing is downloaded or loaded
        locally (this is the contract stated in the module docstring).
        Otherwise ``(model, tokenizer)``; the pair is built on the first call
        and the same objects are returned on every later call.

    Raises:
        FileNotFoundError: if ``ADAPTER_DIR`` does not exist (local mode only).
        Exception: anything raised by huggingface_hub / transformers / peft
            while downloading or loading propagates unchanged.
    """
    global _model, _tokenizer, _load_error

    # Remote mode: never pay for a local download/load.
    if GEMMA_SERVER:
        return None, None

    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer

    if not ADAPTER_DIR.exists():
        raise FileNotFoundError(
            f"LoRA adapter not found at {ADAPTER_DIR}. "
            "Make sure adapters/ref/ is bundled with the Space."
        )

    # Heavy third-party imports are deferred so that importing this module
    # (e.g. in remote mode) does not require them.
    from huggingface_hub import snapshot_download
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    token = _get_hf_token()
    # bf16 is selected when CUDA is available or torch exposes `to_bf16`;
    # otherwise fp32.
    # NOTE(review): `torch.to_bf16` does not look like a real torch attribute,
    # so CPU-only hosts presumably always get fp32 -- confirm the intended
    # CPU dtype before changing this check.
    bf16_ok = False
    try:
        bf16_ok = torch.cuda.is_available() or hasattr(torch, "to_bf16")
    except Exception:
        bf16_ok = False
    dtype = torch.bfloat16 if bf16_ok else torch.float32

    print(f"[gemma_npc] Loading tokenizer + base {BASE_MODEL_ID} (dtype={dtype})...")
    t0 = time.perf_counter()

    # Pull the base model. snapshot_download lets us reuse the cached HF store
    # so cold-starts aren't 500MB of network on every restart.
    base_path = snapshot_download(repo_id=BASE_MODEL_ID, token=token)

    tokenizer = AutoTokenizer.from_pretrained(base_path, token=token)
    base = AutoModelForCausalLM.from_pretrained(
        base_path,
        token=token,
        torch_dtype=dtype,
    )
    base.eval()

    print(f"[gemma_npc] Loading LoRA adapter from {ADAPTER_DIR}...")
    model = PeftModel.from_pretrained(base, str(ADAPTER_DIR))
    model.eval()

    # Publish both singletons together, only once everything loaded, so a
    # failed load never leaves a tokenizer without a model behind.
    _model, _tokenizer = model, tokenizer

    # Warmup so the first user request doesn't pay dispatch cost. _generate
    # calls get_model() again, which now hits the memoised pair above. The
    # warmup is best-effort: the model is already loaded, so a failed dry run
    # must not make the loader itself fail.
    try:
        _generate("jab,cross,low_kick,roundhouse,uppercut")
    except Exception as e:  # noqa: BLE001
        print(f"[gemma_npc] Warmup failed (continuing): {type(e).__name__}: {e}")

    print(f"[gemma_npc] Model ready in {time.perf_counter() - t0:.1f}s")
    return _model, _tokenizer


def _generate(sequence: str, max_new_tokens: int = 18) -> str:
    """Run one greedy generation and return the raw decoded completion.

    Args:
        sequence: Comma-separated player moves; turned into the user prompt
            by ``_format_user_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The newly generated text only (prompt tokens removed), decoded with
        special tokens skipped and surrounding whitespace stripped.

    Raises:
        RuntimeError: if ``get_model()`` yields no local model/tokenizer
            (its documented ``(None, None)`` result in remote mode).
        Exception: loader or generation errors propagate to the caller.

    This function deliberately does not consult ``_LRU_CACHE``: that cache
    holds finished ``(move, reasoning, source)`` tuples, whose reasoning
    element is post-processed text with a timing suffix rather than raw model
    output. Caching is handled by ``pick_counter_move``.
    """
    model, tokenizer = get_model()
    if model is None or tokenizer is None:
        raise RuntimeError("local model is not loaded (remote mode)")

    messages = [
        {"role": "user", "content": _format_user_prompt(sequence)},
    ]
    # Per the original author's note: with `tokenize=True, return_tensors="pt"`
    # transformers 5.x returns a `BatchEncoding` rather than a bare tensor, so
    # we ask for `return_dict=True` and pull out `input_ids` /
    # `attention_mask` ourselves.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
    # Fall back to an all-ones mask when the tokenizer did not supply one.
    attn_mask = encoded.get("attention_mask", torch.ones_like(input_ids))
    # Stop at the chat template's end-of-turn marker.
    eos_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    with torch.inference_mode():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=eos_id,
        )
    # generate() returns prompt + completion; keep only the completion.
    new_tokens = out[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _format_user_prompt(sequence: str) -> str:
    """Build the user-turn prompt text for a comma-separated move sequence.

    The sequence is split on commas, each item is whitespace-trimmed, and
    empty items are dropped. Only the last five remaining items are kept; if
    none remain (empty or ``None`` input) the single move ``jab`` is used.
    """
    tokens = []
    for raw in (sequence or "").split(","):
        cleaned = raw.strip()
        if cleaned:
            tokens.append(cleaned)

    # Slicing already copes with short lists; an empty result falls back to jab.
    recent = tokens[-5:] or ["jab"]
    joined = ",".join(recent)

    return (
        f"Player's last moves: {joined}.\n"
        "Choose the best counter from: jab, cross, low_kick, roundhouse, "
        "uppercut, parry, backstep, clinch, throw.\n"
        "Respond with EXACTLY one line: counter_move: <name>"
    )


# Result cache keyed by the raw sequence string passed to pick_counter_move;
# each value is the (move, reasoning, source) tuple that was returned for it.
_LRU_CACHE: "dict[str, Tuple[str, str, str]]" = {}
# Capacity bound: once the cache holds this many entries, pick_counter_move
# drops the entry at the front of the dict's insertion order before adding.
_LRU_CACHE_MAX = 64


# Case-insensitive, whole-word matcher for any canonical move name; used by
# _parse_move when the model output has no usable `counter_move:` line.
_MOVE_PATTERN = re.compile(r"\b(" + "|".join(re.escape(m) for m in CANONICAL_MOVES) + r")\b", re.IGNORECASE)


def _parse_move(text: str) -> Tuple[str, str]:
    """Split raw model output into ``(move, reasoning)``.

    Resolution order:
      1. A ``counter_move: <name>`` tag whose name is a canonical move.
      2. Otherwise, the first canonical move name found in the reasoning text.
      3. Otherwise, ``jab`` with a "could not parse" note as the reasoning.

    When a tag is present, the reasoning is the text before it (or the whole
    output if nothing precedes the tag). Reasoning is capped at 200 chars.
    """
    raw = text.strip()
    explanation = raw

    tagged = re.search(r"counter_move\s*:\s*([A-Za-z_]+)", raw, re.IGNORECASE)
    if tagged is not None:
        name = tagged.group(1).strip().lower()
        prefix = raw[: tagged.start()].strip()
        if prefix:
            explanation = prefix
        if name in MOVE_TO_IDX:
            return name, explanation[:200]

    # No valid tag: take the first canonical move mentioned anywhere.
    found = _MOVE_PATTERN.search(explanation)
    if found is not None:
        return found.group(1).lower(), explanation[:200]

    # Nothing recognisable at all: deterministic default.
    note = f"could not parse model output: {explanation[:80]}"
    return "jab", note[:200]


def _cache_store(sequence: str, result: Tuple[str, str, str]) -> None:
    """Record *result* as the most recent cache entry, evicting the oldest if full."""
    _LRU_CACHE.pop(sequence, None)
    while _LRU_CACHE and len(_LRU_CACHE) >= _LRU_CACHE_MAX:
        # Dicts iterate in insertion order, so the first key is the stalest.
        _LRU_CACHE.pop(next(iter(_LRU_CACHE)))
    _LRU_CACHE[sequence] = result


def pick_counter_move(sequence: str) -> Tuple[str, str, str]:
    """Choose the NPC's counter-move for a player move sequence.

    Args:
        sequence: Comma-separated player moves. It is used verbatim as the
            cache key and sent verbatim to the remote server.

    Returns:
        ``(move, reasoning, source)``. ``move`` is always a canonical move
        name. ``source`` is:
          * whatever the remote server reports (``"gemma_modal"`` if it
            reports nothing) when ``GEMMA_SERVER`` is set,
          * ``"gemma_lora"`` for a successful local inference,
          * ``"fallback"`` when the remote call, model load, or inference
            fails; the move is then ``"jab"`` and the reasoning carries the
            error.

    Never raises: every failure is converted into a ``"fallback"`` result.
    Successful results are cached (bounded by ``_LRU_CACHE_MAX``); fallback
    results are not, so a later call can retry. A cached result is returned
    as-is, including the timing suffix from when it was first computed.
    """
    cached = _LRU_CACHE.pop(sequence, None)
    if cached is not None:
        # Re-insert at the end so a hit refreshes the entry's recency.
        _LRU_CACHE[sequence] = cached
        return cached

    # --- Modal GPU path: fast remote inference on T4 -------------------
    if GEMMA_SERVER:
        try:
            # Imported inside the try so a missing httpx degrades to the
            # fallback result instead of raising out of this function.
            import httpx

            t0 = time.perf_counter()
            r = httpx.post(
                f"{GEMMA_SERVER}/pick_move",
                json={"sequence": sequence},
                timeout=20.0,
            )
            r.raise_for_status()
            data = r.json()
            ms = (time.perf_counter() - t0) * 1000
            # Never trust the server to return a legal move.
            move = data.get("move") or "jab"
            if move not in MOVE_TO_IDX:
                move = "jab"
            reasoning = (data.get("reasoning") or "")[:200]
            src = data.get("source") or "gemma_modal"
            result = (move, f"{reasoning}  (modal {ms:.0f}ms)", src)
        except Exception as e:  # noqa: BLE001
            return "jab", f"modal failed: {type(e).__name__}: {e}", "fallback"
        _cache_store(sequence, result)
        return result

    # --- Local CPU fallback --------------------------------------------
    try:
        # Load (or reuse) the model up front so load errors are reported
        # separately from inference errors.
        get_model()
    except Exception as e:  # noqa: BLE001
        return "jab", f"model unavailable: {type(e).__name__}: {e}", "fallback"

    try:
        t0 = time.perf_counter()
        text = _generate(sequence)
        ms = (time.perf_counter() - t0) * 1000
        move, reasoning = _parse_move(text)
        result = (move, f"{reasoning}  ({ms:.0f}ms)", "gemma_lora")
    except Exception as e:  # noqa: BLE001
        return "jab", f"inference failed: {type(e).__name__}: {e}", "fallback"

    _cache_store(sequence, result)
    return result


def make_move_mask(distance: str) -> torch.Tensor:
    """Return an all-ones float32 mask with one entry per canonical move.

    Retained only so callers of the old TinyFighter API keep working;
    ``distance`` is accepted but ignored, so every move is always allowed.
    """
    return torch.ones(NUM_MOVES, dtype=torch.float32)


# ---------------------------------------------------------------------------
# Backwards-compat shims for the old TinyFighter state_to_features, in case
# any of the Gradio panel helpers still expect them.
# ---------------------------------------------------------------------------
def state_to_features(
    last_npc_moves,
    last_player_moves,
    player_hp=100.0,
    npc_hp=100.0,
    player_stamina=100.0,
    npc_stamina=100.0,
    distance="mid",
    aggression=0.5,
    defense=0.5,
    parry_affinity=0.4,
    kick_affinity=0.3,
    grapple_affinity=0.3,
    round_num=1,
    history_len=5,
):
    """Compatibility stub for the old TinyFighter feature extractor.

    Every argument is accepted and ignored. The result is always a fresh
    float32 tensor of 168 zeros, so legacy callers that expect a feature
    vector keep running even though the current model consumes move
    sequences instead.
    """
    feature_dim = 168
    return torch.zeros((feature_dim,), dtype=torch.float32)


def remap_bn_state_to_ln(state_dict: dict) -> dict:
    """Compatibility stub: hand back ``state_dict`` untouched.

    The Gemma + LoRA path needs no key remapping, so the very same object
    that was passed in is returned (no copy is made).
    """
    unchanged = state_dict
    return unchanged


if __name__ == "__main__":
    # Manual smoke test: run one fixed five-move sequence and print the outcome.
    move, reasoning, source = pick_counter_move(
        "jab,cross,low_kick,roundhouse,uppercut"
    )
    print(f"move={move}  source={source}")
    print(f"reasoning={reasoning}")