"""
data_engine.py — DuckDB lifecycle, schema introspection, and safe query execution.

Handles:
- Creating per-request in-memory DuckDB connections (thread-safe)
- Seeding schema + data (from Parquet files, seed.sql, or programmatically)
- Schema introspection for prompt context
- extract_sql(): JSON envelope → ```sql``` block → raw fallback
- validate_sql(): forbidden-token and single-statement checks + schema-aware
  column validation via EXPLAIN
- execute_safe(): extraction, validation, subquery wrapping, execution with timeout
"""

from __future__ import annotations

import json
import re
import threading
import time
from pathlib import Path

import duckdb

# ── Forbidden SQL terms (case-insensitive, whole-word) ─────────────────
# Statements that mutate state, touch files/network, load extensions, or
# change settings. `analyze` is listed because `EXPLAIN ANALYZE <stmt>`
# (which validate_sql would build from an input starting with ANALYZE)
# actually executes the statement. Matching is deliberately conservative:
# a forbidden word inside a string literal is rejected too.
FORBIDDEN_TOKENS = [
    "drop", "delete", "insert", "update", "alter", "truncate",
    "create", "attach", "detach", "pragma",
    "copy", "export", "import", "install", "load", "call", "set",
    "analyze",
]

# ── Execution limits ───────────────────────────────────────────────────
MAX_RESULT_ROWS = 1000
QUERY_TIMEOUT_SEC = 10

DATA_DIR = Path(__file__).parent / "data"
SEED_SQL_PATH = DATA_DIR / "seed.sql"

# Parquet files (pre-generated once via data/export_parquet.py).
# On HF Spaces, place them in the persistent /data/ directory.
_PARQUET_DIRS = [
    Path("/data"),   # HF Space persistent storage
    DATA_DIR,        # local dev (data/)
]
# Not referenced in this module (seeding iterates _PARQUET_TABLES);
# kept in case other code imports them.
_ENROLLMENT_PQ = "enrollment.parquet"
_ATTENDANCE_PQ = "attendance.parquet"


# ── Connection factory ─────────────────────────────────────────────────

def get_connection(read_only: bool = False) -> duckdb.DuckDBPyConnection:
    """
    Return a fresh in-memory DuckDB connection with safety defaults.

    Each request gets its own connection for thread safety.

    Args:
        read_only: Currently ignored. The connection is always a writable
            in-memory database, because seeding has to create tables.
            The parameter is kept for API compatibility.
    """
    conn = duckdb.connect(database=":memory:")
    conn.execute("SET enable_progress_bar = false;")
    conn.execute("SET max_memory = '256MB';")
    conn.execute("SET threads = 2;")
    return conn


# ── Database seeding ───────────────────────────────────────────────────

# All 5 tables (expanded schema for Day 1+)
_PARQUET_TABLES = ["enrollment", "attendance", "students", "discipline", "grades"]


def _sql_string_literal(value: str) -> str:
    """Quote *value* as a SQL string literal, doubling embedded single quotes."""
    return "'" + value.replace("'", "''") + "'"


def _load_parquet_tables(conn: duckdb.DuckDBPyConnection) -> bool:
    """Create every table in _PARQUET_TABLES from the first directory that has all of them.

    Returns True if a complete set was found and loaded, False otherwise.
    """
    for base in _PARQUET_DIRS:
        paths = {table: base / f"{table}.parquet" for table in _PARQUET_TABLES}
        if all(path.exists() for path in paths.values()):
            for table, pq_path in paths.items():
                # Table names come from the module constant above; the path is
                # escaped so a quote in the directory name cannot break the SQL.
                conn.execute(
                    f"CREATE TABLE {table} AS "
                    f"SELECT * FROM read_parquet({_sql_string_literal(str(pq_path))})"
                )
            return True
    return False


def _run_sql_script(conn: duckdb.DuckDBPyConnection, script: str) -> None:
    """Execute each ';'-separated statement of *script*, ignoring comment-only chunks.

    Whole-line `--` comments are removed before the emptiness check. A chunk
    such as "-- header\\nCREATE TABLE ..." therefore still executes its
    CREATE. The split is naive: a ';' inside a string literal splits the
    statement.
    """
    for chunk in script.split(";"):
        code_lines = [
            line for line in chunk.splitlines()
            if not line.lstrip().startswith("--")
        ]
        statement = "\n".join(code_lines).strip()
        if statement:
            conn.execute(statement)


def seed_database(
    conn: duckdb.DuckDBPyConnection,
    seed_sql_path: Path | None = None,
) -> None:
    """
    Create tables and load seed data.

    Tries, in order:
      1. Parquet files (fastest — pre-generated, ~260 KB total)
         → /data/*.parquet  (HF Space persistent storage)
         → data/*.parquet   (local dev)
      2. data/seed.sql      (custom overrides; *seed_sql_path* if given)
      3. Python generator   (slow fallback, ~20s)

    Args:
        conn: Connection to populate.
        seed_sql_path: Optional path (Path or str) to a SQL seed script.
            Defaults to SEED_SQL_PATH.
    """
    # ── 1. Parquet files ─────────────────────────────────────────
    if _load_parquet_tables(conn):
        return

    # ── 2. seed.sql ──────────────────────────────────────────────
    seed_sql_path = SEED_SQL_PATH if seed_sql_path is None else Path(seed_sql_path)
    if seed_sql_path.exists():
        _run_sql_script(conn, seed_sql_path.read_text(encoding="utf-8"))
        return

    # ── 3. Python generator (slow) ───────────────────────────────
    from data.generate_seed import generate_seed_data
    generate_seed_data(conn)


# ── Schema introspection ───────────────────────────────────────────────

def get_schema_info(conn: duckdb.DuckDBPyConnection) -> dict[str, list[tuple[str, str, str]]]:
    """
    Introspect the database schema for prompt context.

    Returns:
        dict: table_name -> [(column_name, type, "")], tables in name order
        and columns in ordinal order, for tables in the 'main' schema.
        The description field is empty — we rely on the prompt's table docs.
    """
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'main' ORDER BY table_name"
    ).fetchall()

    schema: dict[str, list[tuple[str, str, str]]] = {}
    for (table_name,) in tables:
        # Bound parameter instead of string interpolation; also restricted to
        # 'main' so a same-named table in another schema is not merged in.
        cols = conn.execute(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = 'main' AND table_name = ? "
            "ORDER BY ordinal_position",
            [table_name],
        ).fetchall()
        schema[table_name] = [(name, dtype, "") for name, dtype in cols]

    return schema


# ── JSON envelope parsing ──────────────────────────────────────────────

def _try_parse_json_envelope(text: str) -> str | None:
    """
    Try to parse the LLM output as a JSON envelope like:
        {"sql": "SELECT ...", "explanation": "..."}

    Every '{' in *text* is tried as the start of a JSON value. This handles
    envelopes embedded in prose or code fences, and values that contain braces.

    Returns the first non-blank string found under an "sql" key of a
    decoded object, or None.
    """
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        try:
            obj, _ = decoder.raw_decode(text, match.start())
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(obj, dict):
            sql = obj.get("sql")
            # Reject non-strings (e.g. null/list) and blank strings so the
            # caller's .strip() is safe and other extraction paths still run.
            if isinstance(sql, str) and sql.strip():
                return sql
    return None


# ── SQL extraction ─────────────────────────────────────────────────────

def _strip_trailing_semicolons(sql: str) -> str:
    """Strip surrounding whitespace and every trailing ';' (with whitespace between them)."""
    cleaned = sql.strip()
    while cleaned.endswith(";"):
        cleaned = cleaned[:-1].rstrip()
    return cleaned


def extract_sql(raw_llm_output: str) -> str:
    """
    Extract SQL from LLM output.

    Tries, in order:
      1. JSON envelope: {"sql": "...", "explanation": "..."}
      2. ```sql ... ``` markdown block
      3. Generic ``` ... ``` code block
      4. Raw text fallback

    Always strips trailing semicolons (they break subquery wrapping).
    """
    # 1. Try JSON envelope first
    json_sql = _try_parse_json_envelope(raw_llm_output)
    if json_sql:
        return _strip_trailing_semicolons(json_sql)

    # 2. Try ```sql ... ```
    sql_match = re.search(r"```sql\s*\n?(.*?)```", raw_llm_output, re.DOTALL | re.IGNORECASE)
    if sql_match:
        return _strip_trailing_semicolons(sql_match.group(1))

    # 3. Try generic ``` ... ```
    code_match = re.search(r"```\s*\n?(.*?)```", raw_llm_output, re.DOTALL)
    if code_match:
        return _strip_trailing_semicolons(code_match.group(1))

    # 4. Fallback: return raw text, stripped
    return _strip_trailing_semicolons(raw_llm_output)


# ── SQL validation ─────────────────────────────────────────────────────

def validate_sql(sql: str, conn: duckdb.DuckDBPyConnection | None = None) -> None:
    """
    Validate that the SQL is safe and refers to real columns.

    Layer 1 — static checks (always run):
      - Not empty / not blank
      - No forbidden tokens (DROP, DELETE, INSERT, COPY, SET, ...)
      - A single statement: no ';' except trailing ones
      - Contains SELECT

    Layer 2 — schema-aware validation (if conn provided):
      - Runs EXPLAIN against the actual schema to catch missing columns,
        unknown tables, and syntax errors before execution.

    Raises ValueError with a user-facing message on any failure.
    """
    if not sql or not sql.strip():
        raise ValueError("Empty SQL query — nothing to execute.")

    # Check forbidden tokens FIRST (before SELECT check — DROP/INSERT
    # statements don't contain SELECT but are more dangerous)
    sql_lower = sql.lower()
    for token in FORBIDDEN_TOKENS:
        if re.search(rf"\b{token}\b", sql_lower):
            raise ValueError(
                f"Forbidden operation detected: '{token}'. Only SELECT queries are allowed."
            )

    # DuckDB's Python execute() runs every ';'-separated statement it is
    # given. Without this check, the EXPLAIN below (and the subquery wrap in
    # execute_safe) could smuggle in and run a second, non-SELECT statement.
    # This is conservative: a ';' inside a string literal is rejected as well.
    if ";" in _strip_trailing_semicolons(sql):
        raise ValueError(
            "Multiple SQL statements are not allowed. Only a single SELECT query is accepted."
        )

    if "SELECT" not in sql.upper():
        raise ValueError("Only SELECT queries are allowed. No SELECT found.")

    # Schema-aware validation via DuckDB EXPLAIN (plans, does not execute)
    if conn is not None:
        try:
            conn.execute(f"EXPLAIN {sql}")
        except duckdb.Error as e:
            # Surface the DuckDB error (e.g., "column 'foo' does not exist")
            msg = str(e).strip()
            # Clean up common DuckDB error prefixes for user-friendliness
            for prefix in ("Parser Error: ", "Catalog Error: ", "Binder Error: "):
                if msg.startswith(prefix):
                    msg = msg[len(prefix):]
            raise ValueError(f"SQL validation failed: {msg}") from e


# ── Timeout helper ─────────────────────────────────────────────────────

class QueryTimeoutError(TimeoutError):
    """Raised when a query exceeds the time budget."""


def _execute_with_timeout(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    timeout_sec: float,
):
    """
    Execute SQL and return the result as a DataFrame, enforcing a timeout.

    DuckDB doesn't have a built-in SET query_timeout. Instead, the query
    runs on a daemon worker thread while this thread waits. On the deadline,
    conn.interrupt() is called and this thread waits up to 2 more seconds
    for the worker to stop.

    Raises:
        QueryTimeoutError: if the query does not finish within timeout_sec.
            If the worker has still not stopped after the extra 2 seconds,
            the connection may still be busy when this is raised.
        Exception: whatever the query itself raised, re-raised here.
    """
    result = {"df": None, "error": None}
    done = threading.Event()

    def run():
        try:
            result["df"] = conn.execute(sql).fetchdf()
        except Exception as e:  # propagated to the calling thread below
            result["error"] = e
        finally:
            done.set()

    thread = threading.Thread(target=run, daemon=True)
    thread.start()

    if not done.wait(timeout=timeout_sec):
        # Timed out — interrupt the DuckDB connection
        conn.interrupt()
        thread.join(timeout=2)
        raise QueryTimeoutError(f"Query timed out after {timeout_sec}s.")

    if result["error"] is not None:
        raise result["error"]
    return result["df"]


# ── Safe SQL execution ─────────────────────────────────────────────────

def execute_safe(
    conn: duckdb.DuckDBPyConnection,
    raw_llm_output: str,
    timeout_sec: int = QUERY_TIMEOUT_SEC,
) -> tuple[str, "DataFrame"]:
    """
    Extract, validate, and execute LLM-generated SQL.

    Pipeline:
      1. extract_sql()   — parse JSON / ```sql``` / raw
      2. validate_sql()  — static checks + schema-aware EXPLAIN
      3. Wrap in SELECT * FROM (<sql>) AS _safe LIMIT {MAX_RESULT_ROWS}
      4. Execute with a timeout of *timeout_sec* seconds
         (a value of 0, a negative value, or None disables the timeout)
      5. Return (cleaned_sql, dataframe)

    Returns:
        (cleaned_sql, pandas DataFrame produced by DuckDB's fetchdf())

    Raises:
        ValueError: if SQL is invalid or references unknown columns/tables.
        QueryTimeoutError: if execution exceeds *timeout_sec*.
        duckdb.Error: on database-level failures.
    """
    sql = extract_sql(raw_llm_output)
    validate_sql(sql, conn=conn)

    # Safety wrap: caps the result size. The newlines keep a trailing
    # `-- comment` in the LLM SQL from swallowing the closing parenthesis.
    safe_sql = f"SELECT * FROM (\n{sql}\n) AS _safe LIMIT {MAX_RESULT_ROWS}"

    if timeout_sec and timeout_sec > 0:
        df = _execute_with_timeout(conn, safe_sql, timeout_sec)
    else:
        df = conn.execute(safe_sql).fetchdf()
    return sql, df


# ── Full pipeline (for use in app.py) ──────────────────────────────────

def create_session() -> duckdb.DuckDBPyConnection:
    """Create a seeded DuckDB connection ready for queries."""
    conn = get_connection()
    seed_database(conn)
    # NOTE(review): read-only table functions such as read_csv('/path') still
    # pass validate_sql. Consider running `SET enable_external_access = false`
    # here, after seeding — confirm that no caller needs file or pandas scans
    # on this connection first.
    return conn