# -*- coding: utf-8 -*- """ Created on Fri Apr 3 21:26:48 2026 @author: taten """ # -*- coding: utf-8 -*- """ Created on Fri Apr 3 13:46:18 2026 @author: taten """ #!/usr/bin/env python3 """ # -*- coding: utf-8 -*- """ """ ╔══════════════════════════════════════════════════════════════════════════════╗ ║ ║ ║ ██╗ ██╗ ██╗██████╗ ██╗ ██████╗ ██╗ ██╗ █████╗ ███╗ ██╗████████╗║ ║ ██╗ ██╗ ██╗██████╗ ██╗ ██████╗ ██╗ ██╗ █████╗ ███╗ ██╗████████╗║ ║ ██║ ██╔╝███║██╔══██╗██║ ██╔═══██╗██║ ██║██╔══██╗████╗ ██║╚══██╔══╝║ ║ █████╔╝ ╚██║██████╔╝██║ ██║ ██║██║ ██║███████║██╔██╗ ██║ ██║ ║ ║ ██╔═██╗ ██║██╔══██╗██║ ██║▄▄ ██║██║ ██║██╔══██║██║╚██╗██║ ██║ ║ ║ ██║ ██╗ ██║██║ ██║███████╗ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║ ██║ ║ ║ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ ╚══▀▀═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═╝ ║ ║ ║ ║ QUASAR AXRVI SIGNIFICANCE RANKER v6 ║ ║ ║ ║ Shreve Continuous-Time • Conditional Expectation • Jump-Diffusion ║ ║ ║ ║ ─────────────────────────────────────────────────────────────────────────── ║ ║ ║ ║ STATUS: ● ACTIVE LATENCY: 0.34ms MARTINGALE GATE: ε ≤ 0.01 ║ ║ MODELS: ENSEMBLE (18) SIGNALS: 7/min AUM: ***** ║ ║ ║ ║ "Where Quadratic Variation Meets Optimal Stopping Theory" ║ ║ ║ ║ [ RANKER ONLINE ] v6.0-shreve | 2026-03-28 ║ ║ ║ ╚══════════════════════════════════════════════════════════════════════════════╝ ╔══════════════════════════════════════════════════════════════════════════════════════╗ ║ ║ ║ QUASAR AXRVI SIGNIFICANCE RANKER v6 — SHREVE CONTINUOUS-TIME FRAMEWORK ║ ║ ───────────────────────────────────────────────────────────────────────────────── ║ ║ Quantum Cross-asset Similarity and Attention Mechanism ║ ║ QCSAM ║ ║ Built on v5 (integrated v3+v4). All nine Shreve upgrades applied below. ║ ║ ║ ║ Shreve Upgrades (Stochastic Calculus for Finance II, Shreve 2004) ║ ║ ───────────────────────────────────────────────────────────────────────────────── ║ ║ ║ ║ [S1] Conditional Expectation as Central Object (§1, tower property) ║ ║ Model output interpreted as V_t = E[R_{t+τ} | F_t]. ║ ║ HybridTrainer gains value-consistency loss L_CE = (V̂_t - R_{t+τ})². ║ ║ RankingEngine priority = discounted conditional-expectation estimate. ║ ║ ║ ║ [S2] Strict Non-Anticipation / Adapted Replay (§4, Itô integral adaptedness) ║ ║ Replay transitions are true (s_t, a_t, r_t, s_{t+1}) tuples with ║ ║ s_t ∈ F_t, s_{t+1} ∈ F_{t+1} captured AFTER the next market update. ║ ║ Pending-episode dict keyed by trade_id eliminates the next≡current bug. ║ ║ ║ ║ [S3] Quadratic Variation Volatility (§3, Brownian QV theorem) ║ ║ Feature [6]: σ̂_t = sqrt(QV_t / Δt), QV_t = Σ (Δ log S_j)² ║ ║ Feature [20]: ΔQV_t = QV_short − QV_long (vol acceleration in QV units) ║ ║ ║ ║ [S4] Directional Log-Return Reward (§1, GBM / Itô log-return identity) ║ ║ R = sgn(a_t) · log(S_{t+τ}/S_t) − fees − slippage_bps/10000 ║ ║ All reward & comparison logic is in log-return units (GBM consistent). ║ ║ ║ ║ [S5] Itô-Doeblin Curvature in Decision Layer (§4, Itô–Doeblin formula) ║ ║ Ranking priority includes a convexity correction: ║ ║ Π_t = D(t,τ)·Ê[R_{t+τ}|F_t] + ½·σ²_t·Δt·convexity_feature ║ ║ Jump-prone / high-convexity assets are ranked with explicit curvature. ║ ║ ║ ║ [S6] Discounted Risk-Neutral Priority (§5, Girsanov / risk-neutral measure) ║ ║ Π_t = Ẽ[D(t,t+τ)·R_{t+τ} | F_t], D(t,τ) = exp(−r·τ) ║ ║ Replaces confidence × significance alone. ║ ║ ║ ║ [S7] Martingale Null-Hypothesis Gate (§4, martingale definition) ║ ║ Gate E (hard filter): if DevMart(t) ≤ ε, market is consistent with ║ ║ E[ΔlogS_{t+1}|F_t] = 0 → no trade. Feature [23] becomes a veto. ║ ║ ║ ║ [S8] Optimal-Stopping Exit Rule (§8, American option / OST) ║ ║ τ* = arg sup_τ E[R_τ | F_t]. ║ ║ Fixed expiry replaced by: continue if C_t > G_t, stop if G_t ≥ C_t. ║ ║ G_t = log(current_price/entry_price) · sgn(direction) − fees ║ ║ C_t = model value estimate of next-tick return (adapted from log_var head) ║ ║ ║ ║ [S9] Jump-Diffusion Intensity Model (§11, compound Poisson) ║ ║ dS = μS dt + σS dW + S_{t−} dJ_t ║ ║ Feature [24]: Poisson intensity λ̂ estimated from empirical jump frequency. ║ ║ CRASH/SPIKE assets tagged in ShreveConfig.jump_diffusion_assets. ║ ║ ║ ║ Preserved from v5 ║ ║ ───────────────────────────────────────────────────────────────────────────────── ║ ║ • AXRVINet LSTM+attention architecture (26-dim input, unchanged) ║ ║ • HubSubscriber, DerivWebSocketClient, PriceStreamer ║ ║ • BanditSelector (UCB/Thompson/Greedy) ║ ║ • GirsanovReplayBuffer, MC dropout, ConservativeRanker ║ ║ • AssetStateBuffer, AdaptiveNormalizer ║ ║ • PositionManager, ranker_logging bridge ║ ║ • All CLI args and sync/async run modes ║ ║ ║ ║ Data Flow (v6) ║ ║ ───────────────────────────────────────────────────────────────────────────────── ║ ║ Hub WS → HubSubscriber → AssetSnapshot → AssetStateBuffer ║ ║ Deriv WS → DerivWebSocketClient → PriceStreamer → AssetStateBuffer ║ ║ AssetStateBuffer → UnifiedFeatureEngine (26-dim, QV vol + jump-diffusion) ║ ║ → AXRVINet + MC Dropout → significance_weight + V_t (conditional expectation) ║ ║ → ShreveRankingEngine: Π_t = D(t,τ)·Ê[R|F_t] + ½σ²Δt·convexity ║ ║ → ConservativeRanker → BanditSelector ║ ║ → Execution Gates A–E (A:signal, B:confidence, C:significance, ║ ║ D:vol/uncertainty/jump, E:martingale null-hypothesis) ║ ║ → Trade execution → PendingEpisodeStore (s_t captured at open) ║ ║ → On close: s_{t+1} captured → GirsanovReplayBuffer → HybridTrainer ║ ║ → OptimalStoppingMonitor: continue if C_t > G_t, else exit ║ ║ ║ ║ Version: v6.0-shreve | 2026-03-28 ║ ╚══════════════════════════════════════════════════════════════════════════════════════╝ """ import argparse import asyncio import json import logging import math import os import queue import shutil import sys import threading import time import traceback from collections import Counter, defaultdict, deque from dataclasses import dataclass, field, asdict from datetime import datetime from enum import Enum from pathlib import Path from threading import Lock from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim try: import websockets except ImportError: websockets = None import warnings # ─── Optional Qiskit / Aer (quantum circuit reference hardware path) ─────────────── try: from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister from qiskit_aer import AerSimulator from qiskit.quantum_info import Statevector QISKIT_AVAILABLE = True except ImportError: try: from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister from qiskit.providers.aer import AerSimulator from qiskit.quantum_info import Statevector QISKIT_AVAILABLE = True except ImportError: QISKIT_AVAILABLE = False # Stub classes so class bodies parse without import errors class QuantumCircuit: # type: ignore def __init__(self, *a, **kw): pass class QuantumRegister: # type: ignore def __init__(self, *a, **kw): pass class ClassicalRegister: # type: ignore def __init__(self, *a, **kw): pass class AerSimulator: # type: ignore def __init__(self, *a, **kw): pass class Statevector: # type: ignore def __init__(self, *a, **kw): pass @property def data(self): return np.zeros(2, dtype=complex) # ─── Optional sklearn (QCSAM data preprocessing) ────────────────────────────────── try: from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler SKLEARN_AVAILABLE = True except ImportError: SKLEARN_AVAILABLE = False class PCA: # type: ignore def __init__(self, *a, **kw): pass def fit_transform(self, x): return x def transform(self, x): return x class StandardScaler: # type: ignore def __init__(self, *a, **kw): pass def fit_transform(self, x): return x def transform(self, x): return x # ─── Quantum numeric constants ───────────────────────────────────────────────────── _EPS: float = 1e-9 # denominator guard for statevector normalisation _NORM_TOL: float = 1e-4 # tolerance for unit-norm assertions # ─── Optional Structured Logging Integration ─────────────────────────────────────── try: from ranker_logging import ( RankerLogger, RankerLogBridge, EventCategory, LogLevel, ) LOGGING_AVAILABLE = True except ImportError: LOGGING_AVAILABLE = False RankerLogger = None RankerLogBridge = None # ─── System Logging (single basicConfig call) ────────────────────────────────────── logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", stream=sys.stdout, force=True, ) logger = logging.getLogger("QuasarAXRVI") # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 0 — CONFIGURATION DATACLASSES # ══════════════════════════════════════════════════════════════════════════════════════ @dataclass class StochasticCalculusConfig: """Stochastic calculus grounding parameters (from v4).""" use_log_returns: bool = True # GBM-consistent log returns use_quadratic_variation: bool = True # QV-based volatility estimation use_ito_convexity: bool = True # Ito convexity proxy feature use_vol_regime: bool = True # Stochastic volatility regime detection use_leverage_effect: bool = True # Return-vol correlation feature use_martingale_deviation: bool = True # Variance-ratio martingale test use_jump_risk: bool = True # Jump-diffusion crash tension # Annualisation: 252 days × 24 hours × 12 five-minute bars vol_annualization: float = 252 * 24 * 12 qv_window: int = 20 vol_regime_short: int = 5 vol_regime_long: int = 20 martingale_test_window: int = 20 martingale_test_threshold: float = 1.96 # 95 % confidence jump_threshold_std: float = 3.0 # σ-multiples for jump detection @dataclass class PortfolioRiskConfig: """ Institutional position-sizing and portfolio-level risk parameters. total_capital — account equity in base currency (used for sizing) kelly_fraction — fractional Kelly multiplier; 0.5 = half-Kelly max_portfolio_risk — max fraction of capital deployed across all open trades drawdown_halt_pct — drawdown that triggers a temporary trading halt halt_duration_secs — how long the halt lasts before auto-resuming (default 3 min) drawdown_reduce_pct — drawdown that halves all new position sizes cvar_floor — CVaR@5% below this vetoes the trade (no capital deployed) min_notional — floor on trade size in base currency max_notional — ceiling on trade size in base currency """ total_capital: float = 6.0 # $6 micro account kelly_fraction: float = 0.0 # disabled — flat $1 stake always max_portfolio_risk: float = 0.50 # max 50% at risk ($3 = 3×$1 trades) drawdown_halt_pct: float = 0.33 # halt if $2 lost (~33% of $6) halt_duration_secs: float = 60.0 # 1-minute cooldown after halt drawdown_reduce_pct: float = 0.17 # reduce after ~$1 lost cvar_floor: float = -0.02 min_notional: float = 1.0 # Deriv multiplier floor = $1 max_notional: float = 1.0 # flat $1 stake — never size up on $6 @dataclass class UncertaintyConfig: """MC Dropout + conservative-bound parameters (from v4).""" use_mc_dropout: bool = True mc_samples: int = 10 confidence_level: float = 0.95 uncertainty_veto_threshold: float = 0.3 uncertainty_penalty_weight: float = 0.5 @dataclass class GirsanovReplayConfig: """Prioritised replay parameters based on Girsanov measure change (from v4).""" use_prioritized: bool = True priority_alpha: float = 0.6 # degree of prioritisation priority_beta: float = 0.4 # importance-sampling correction min_priority: float = 0.01 max_priority: float = 10.0 use_vol_weighting: bool = True # weight by 1/σ (Girsanov) use_td_error_weighting: bool = True # weight by TD error @dataclass class TradeConfig: amount: float = 1.0 # $1 flat stake — minimum for multipliers expiry_time: int = 60 # 1-minute max duration (optimal-stopping exits earlier) commission_rate: float = 0.001 slippage_bps: float = 2.0 # basis points, applied as log-return deduction [S4] @dataclass class ShreveConfig: """ Parameters governing the Shreve continuous-time upgrades (v6). [S1] value_consistency_loss_weight — λ_CE for L_CE = (V̂_t − R_{t+τ})² [S2] (no extra params — proper (s_t, s_{t+1}) is structural) [S3] qv_dt_seconds — Δt denominator for σ̂² = QV_t/Δt (bar period in seconds) [S4] (no extra params — slippage_bps lives in TradeConfig) [S5] ito_curvature_weight — weight of ½σ²Δt·convexity in priority score [S6] risk_free_rate — annualised r for discount factor D(t,τ) = e^{−rτ} horizon_seconds — τ (trade horizon) used in discounting [S7] martingale_gate_epsilon — hard floor on DevMart(t) for trade to pass [S8] min_holding_ticks — minimum ticks before optimal-stopping is evaluated stopping_value_buffer — G_t must exceed C_t by this margin to trigger [S9] jump_diffusion_assets — set of asset IDs modelled with dJ component jump_intensity_window — rolling window for Poisson intensity estimation λ̂ """ # [S1] value_consistency_loss_weight: float = 0.3 # [S3] qv_dt_seconds: float = 300.0 # 5-minute bars → Δt = 300 s # [S5] ito_curvature_weight: float = 0.05 # [S6] risk_free_rate: float = 0.05 # annualised (e.g. 5 %) horizon_seconds: float = 60.0 # τ = 1 minute — matches 1-min trade window # [S7] martingale_gate_epsilon: float = 0.05 # DevMart(t) must exceed this to trade # [S8] min_holding_ticks: int = 2 # legacy — kept for checkpoint compat, not used min_holding_seconds: float = 20.0 # wall-clock minimum before any exit is evaluated stopping_value_buffer: float = 0.0005 # tight buffer for 1-min window # [S9] jump_diffusion_assets: Tuple[str, ...] = ("CRASH500", "CRASH1000", "STEP200") jump_intensity_window: int = 200 # bars for Poisson intensity λ̂ # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 1 — SHARED CONSTANTS # ══════════════════════════════════════════════════════════════════════════════════════ DERIV_API_KEY = os.environ.get("DERIV_API_KEY", "mXZY9NxhIqIJyrM") DERIV_WS_URL = "wss://ws.binaryws.com/websockets/v3?app_id=1089" # Deriv API symbol → AXRVI internal symbol SYMBOL_MAP = { "R_25": "V25", "1HZ50V": "V50_1s", "R_75": "V75", "1HZ75V": "V75_1s", "JD100": "JD100", "R_100": "V100", "1HZ100V": "V100_1s", "CRASH500": "CRASH500", "CRASH1000": "CRASH1000", "stpRNG2": "STEP200", # CONFIRMED: live Deriv WS API returns "stpRNG2" } SYMBOL_MAP_REVERSE = {v: k for k, v in SYMBOL_MAP.items()} # FIX 4: guard against duplicate values in SYMBOL_MAP (two Deriv symbols # mapping to the same AXRVI id silently overwrites the reverse entry). _reverse_check: dict = {} for _deriv_sym, _axrvi_id in SYMBOL_MAP.items(): if _axrvi_id in _reverse_check: raise RuntimeError( f"[SYMBOL_MAP] Duplicate AXRVI id '{_axrvi_id}' mapped from both " f"'{_reverse_check[_axrvi_id]}' and '{_deriv_sym}'. " f"Fix SYMBOL_MAP so every Deriv symbol maps to a unique AXRVI id." ) _reverse_check[_axrvi_id] = _deriv_sym del _reverse_check, _deriv_sym, _axrvi_id # Per-asset metadata: base volatility and max position fraction ASSET_REGISTRY: Dict[str, dict] = { "V25": {"symbol": "R_25", "base_vol": 25.0, "max_pos": 0.006}, "V50_1s": {"symbol": "1HZ50V", "base_vol": 50.0, "max_pos": 0.004}, "V75": {"symbol": "R_75", "base_vol": 75.0, "max_pos": 0.005}, "V75_1s": {"symbol": "1HZ75V", "base_vol": 75.0, "max_pos": 0.003}, "JD100": {"symbol": "JD100", "base_vol": 100.0, "max_pos": 0.003}, "V100": {"symbol": "R_100", "base_vol": 100.0, "max_pos": 0.004}, "V100_1s": {"symbol": "1HZ100V", "base_vol": 100.0, "max_pos": 0.002}, "CRASH500": {"symbol": "CRASH500", "base_vol": 50.0, "max_pos": 0.003}, "CRASH1000":{"symbol": "CRASH1000", "base_vol": 100.0, "max_pos": 0.002}, "STEP200": {"symbol": "stpRNG2", "base_vol": 200.0, "max_pos": 0.002}, # CONFIRMED live symbol } # ── Per-asset MULTUP/MULTDOWN multipliers (BROKER-VALIDATED ACCEPTABLE RANGES) ── # CRITICAL BUG FIX: Previous hardcoded values were REJECTED by Deriv broker. # Each asset now uses the SMALLEST ACCEPTABLE multiplier from the broker's approved list. # This ensures all trades execute without "Multiplier is not in acceptable range" errors. # # Rationale for smallest multiplier: # - Minimizes required capital # - Reduces max loss per trade # - Ensures execution (any rejected multiplier blocks all trades) # # Accepted ranges per Deriv broker: # V25 — [160, 400, 800, 1200, 1600] → use 160 # V50_1s — [80, 200, 400, 600, 800] → use 80 # V75 — [50, 100, 200, 300, 500] → use 50 # V75_1s — [50, 100, 200, 300, 500] → use 50 # JD100 — [50, 100, 200, 300, 500] → use 50 # V100 — [40, 100, 200, 300, 400] → use 40 # V100_1s — [40, 100, 200, 300, 400] → use 40 # CRASH500 — [100, 150, 200, 300, 400] → use 100 # CRASH1000 — [100, 200, 300, 400, 500] → use 100 # STEP200 — [400, 1000, 2000, 3000, 4000] → use 400 ASSET_MULTIPLIER: Dict[str, int] = { "V25": 160, # FIXED: was 50 → rejected, now 160 ✓ "V50_1s": 80, # FIXED: was 30 → rejected, now 80 ✓ "V75": 50, # FIXED: was 30 → rejected, now 50 ✓ "V75_1s": 50, # FIXED: was 20 → rejected, now 50 ✓ "JD100": 50, # Jump Diffusion 100 — smallest accepted ✓ "V100": 40, # Volatility 100 Index — smallest accepted ✓ "V100_1s": 40, # FIXED: was 15 → rejected, now 40 ✓ "CRASH500": 100, # FIXED: was 10 → rejected, now 100 ✓ "CRASH1000":100, # FIXED: was 10 → rejected, now 100 ✓ "STEP200": 400, # FIXED: was 20 → rejected, now 400 ✓ } # ── Broker's acceptable multiplier ranges (for validation & future fallback) ── ASSET_ACCEPTABLE_MULTIPLIERS: Dict[str, List[int]] = { "V25": [160, 400, 800, 1200, 1600], "V50_1s": [80, 200, 400, 600, 800], "V75": [50, 100, 200, 300, 500], "V75_1s": [50, 100, 200, 300, 500], "JD100": [50, 100, 200, 300, 500], "V100": [40, 100, 200, 300, 400], "V100_1s": [40, 100, 200, 300, 400], "CRASH500": [100, 150, 200, 300, 400], "CRASH1000":[100, 200, 300, 400, 500], "STEP200": [400, 1000, 2000, 3000, 4000], } # Stop-loss as fraction of stake per asset (capped to protect $6 account) # e.g. 0.50 = close when $0.50 of the $1 stake is lost ASSET_STOP_LOSS_FRAC: Dict[str, float] = { "V25": 0.60, "V50_1s": 0.55, "V75": 0.50, "V75_1s": 0.45, "JD100": 0.50, "V100": 0.50, "V100_1s": 0.40, "CRASH500": 0.50, "CRASH1000":0.50, "STEP200": 0.55, } # Take-profit as fraction of stake (exit early when profit target hit) ASSET_TAKE_PROFIT_FRAC: Dict[str, float] = { "V25": 1.00, "V50_1s": 0.90, "V75": 0.80, "V75_1s": 0.75, "JD100": 0.80, "V100": 0.80, "V100_1s": 0.70, "CRASH500": 0.80, "CRASH1000":0.80, "STEP200": 0.90, } # Neural network hyper-parameters SEQ_LEN = 20 FEATURE_DIM = 26 # 19 base (v3) + 7 stochastic extras (v4) D_MODEL = 64 NUM_HEADS = 4 NUM_ENCODER_LAYERS = 2 # Transformer depth (was LSTM num_layers) NUM_QUANTILES = 9 # Distributional head quantile levels NUM_REGIMES = 4 # RegimeRouter: trending / mean-rev / high-vol / crash LATENT_DIM = 32 DROPOUT = 0.1 # Bandit / decision UCB_C = 1.4 THOMPSON_STD = 0.05 SCORE_THRESHOLD = 0.0 # Dynamic: starts permissive, adapts upward via DynamicExecutionGate MAX_CONCURRENT = 4 # always run across top-4 ranked assets # Training LEARNING_RATE = 3e-4 GAMMA = 0.99 LAMBDA_RANK = 0.4 LAMBDA_RISK = 0.3 REPLAY_CAPACITY = 10_000 TRAIN_BATCH = 2 # FIX: Lowered to 2 — trains after 2 closed trades TRAIN_EVERY_N = 2 # FIX: Check every 2 rank cycles # Connection WS_RECONNECT_DELAY = 5 WS_MAX_RETRIES = 10 PRICE_UPDATE_TIMEOUT = 30 @dataclass class AssetRankerConfig: """Top-level configuration for the ranker bridge (v3 extended).""" asset_symbols: List[str] = field( default_factory=lambda: list(ASSET_REGISTRY.keys()) ) asset_registry: Dict[str, dict] = field(default_factory=lambda: ASSET_REGISTRY) feature_window: int = SEQ_LEN feature_dim: int = FEATURE_DIM d_model: int = D_MODEL n_heads: int = NUM_HEADS n_encoder_layers: int = NUM_ENCODER_LAYERS n_cross_asset_layers: int = 2 n_quantiles: int = NUM_QUANTILES n_regimes: int = NUM_REGIMES latent_dim: int = LATENT_DIM dropout: float = DROPOUT bandit_strategy: str = "ucb" ucb_c: float = UCB_C thompson_std: float = THOMPSON_STD score_threshold: float = SCORE_THRESHOLD max_concurrent: int = MAX_CONCURRENT learning_rate: float = LEARNING_RATE gamma: float = GAMMA lambda_rank: float = LAMBDA_RANK lambda_risk: float = LAMBDA_RISK batch_size: int = TRAIN_BATCH train_every_n: int = TRAIN_EVERY_N buffer_size: int = REPLAY_CAPACITY update_frequency_seconds: float = 5.0 model_path: str = "quasar_axrvi_v6.pt" device: str = "cuda" if torch.cuda.is_available() else "cpu" default_price: float = 1500.0 # Stochastic calculus / uncertainty configs (v4 additions) stochastic_config: StochasticCalculusConfig = field( default_factory=StochasticCalculusConfig ) uncertainty_config: UncertaintyConfig = field( default_factory=UncertaintyConfig ) replay_config: GirsanovReplayConfig = field( default_factory=GirsanovReplayConfig ) # Shreve continuous-time framework config (v6) shreve_config: ShreveConfig = field( default_factory=ShreveConfig ) # Institutional position-sizing and portfolio-level risk config portfolio_risk_config: PortfolioRiskConfig = field( default_factory=PortfolioRiskConfig ) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 2 — ENUMS & DATA STRUCTURES # ══════════════════════════════════════════════════════════════════════════════════════ class TradeDirection(Enum): LONG = "long" SHORT = "short" class PositionState(Enum): PENDING = "pending" # buy sent to broker; awaiting buy confirmation OPEN = "open" # broker confirmed; contract live CLOSING = "closing" # sell sent to broker; awaiting terminal event CLOSED = "closed" # broker reported terminal state (won/lost/sold/expired) @dataclass class PriceTick: symbol: str bid: float ask: float mid: float timestamp: float bid_volume: Optional[int] = None ask_volume: Optional[int] = None @dataclass class Trade: """ Broker-backed trade record. All authoritative state comes from Deriv events. Lifecycle: PENDING — buy sent; no contract yet OPEN — broker confirmed; contract_id bound CLOSING — sell/early-exit sent; awaiting terminal event CLOSED — broker reported final outcome (won / lost / sold / expired) Fields marked [BROKER] must NOT be set locally; they are written only from incoming Deriv WebSocket messages. """ # ── Identity ───────────────────────────────────────────────────────────── trade_id: str asset: str # internal AXRVI symbol (e.g. "V75") direction: TradeDirection quantity: float entry_time: float # wall-clock time of buy send (for monitoring only) # ── Broker primary key [BROKER] ─────────────────────────────────────────── contract_id: Optional[str] = None # Deriv contract_id from buy response transaction_id: Optional[str] = None # Deriv transaction_id # ── Broker contract details [BROKER] ───────────────────────────────────── shortcode: Optional[str] = None # Deriv shortcode broker_symbol: Optional[str] = None # Deriv API symbol (e.g. "R_75") status: Optional[str] = None # open | won | lost | sold | expired # ── Price fields [BROKER] ───────────────────────────────────────────────── buy_price: Optional[float] = None # price paid for the contract sell_price: Optional[float] = None # price received on close/sell entry_tick: Optional[float] = None # spot price at contract open tick current_spot: Optional[float] = None # latest spot (updated from poc stream) profit: Optional[float] = None # broker-confirmed net P&L # ── Internal tracking (non-authoritative) ──────────────────────────────── # entry_price is kept for optional-stopping log-return calc until broker # provides the authoritative entry_tick. Updated from entry_tick on confirm. entry_price: float = 0.0 exit_price: Optional[float] = None exit_time: Optional[float] = None state: PositionState = PositionState.PENDING # These fields are retained as read-only mirrors of broker data and must # NEVER be used as authoritative execution truth. unrealized_pnl: float = 0.0 # live spot estimate only — NOT authoritative realized_pnl: float = 0.0 # set from broker profit field on close fees: float = 0.0 # ── Broker state update (called from _on_deriv_message on buy confirm) ─── def confirm_open( self, contract_id: str, buy_price: float, entry_tick: float, transaction_id: Optional[str] = None, shortcode: Optional[str] = None, broker_symbol: Optional[str] = None, ) -> None: """Transition PENDING → OPEN from broker buy confirmation.""" self.contract_id = contract_id self.buy_price = buy_price self.entry_price = entry_tick # use broker tick as authoritative entry self.entry_tick = entry_tick self.current_spot = entry_tick self.transaction_id = transaction_id self.shortcode = shortcode self.broker_symbol = broker_symbol self.status = "open" self.state = PositionState.OPEN def update_from_poc(self, poc: dict) -> None: """ Apply a proposal_open_contract (poc) stream update. Handles both live tick updates and terminal states. """ self.current_spot = float(poc.get("current_spot", self.current_spot or 0.0)) self.status = poc.get("status", self.status) # Terminal event: populate close fields from broker data if poc.get("is_expired") or poc.get("is_sold") or self.status in ("won", "lost", "sold", "expired"): sell_price = poc.get("sell_price") or poc.get("bid_price") if sell_price is not None: self.sell_price = float(sell_price) # NOTE: Do NOT assign exit_price = sell_price here. # sell_price is the Deriv contract's dollar value (e.g. 0.98), # NOT the underlying asset market price (e.g. 34074.02). # exit_price is set correctly in close_trade_from_broker() # using exit_tick / current_spot (the actual market price). self.exit_time = time.time() raw_profit = poc.get("profit") if raw_profit is None: raw_profit = poc.get("bid_price", 0.0) self.profit = float(raw_profit) if raw_profit is not None else 0.0 self.realized_pnl = self.profit # mirror broker profit self.state = PositionState.CLOSED def close(self, exit_price: float, exit_time: float, fees: float) -> None: """ Legacy compatibility close — only called when a broker terminal event has been received and the broker profit is authoritative. Uses broker sell_price if already set (preferred), otherwise falls back to the passed exit_price for reward calc continuity. """ if self.sell_price is not None: self.exit_price = self.sell_price else: self.exit_price = exit_price self.exit_time = exit_time self.fees = fees self.state = PositionState.CLOSED # Do NOT recompute realized_pnl from local arithmetic; use broker profit # if already available. If not (edge case), use local arithmetic as # last-resort fallback so callers never get None. if self.profit is not None: self.realized_pnl = self.profit else: if self.direction == TradeDirection.LONG: self.realized_pnl = (self.exit_price - self.entry_price) * self.quantity - fees else: self.realized_pnl = (self.entry_price - self.exit_price) * self.quantity - fees def compute_unrealized_pnl(self, current_price: float) -> float: """ NON-AUTHORITATIVE spot estimate for display/logging only. NOTE: For MULTUP/MULTDOWN contracts the true P&L is stake × multiplier × % move — this estimate omits the multiplier and is intentionally approximate. The [S8] G_t stopping rule uses log(price/entry) directly, not this function. NEVER used as execution truth. """ ref = self.entry_tick if self.entry_tick else self.entry_price if ref <= 0: return 0.0 if self.direction == TradeDirection.LONG: return (current_price - ref) * self.quantity return (ref - current_price) * self.quantity def compute_unrealized_return(self, current_price: float) -> float: """Non-authoritative spot return — for G_t [S8] only.""" ref = self.entry_tick if self.entry_tick else self.entry_price if ref <= 0: return 0.0 if self.direction == TradeDirection.LONG: return (current_price - ref) / ref return (ref - current_price) / ref @property def holding_duration(self) -> float: end = self.exit_time or time.time() return end - self.entry_time @dataclass class RankedAsset: """ Merged from v3 (RankedAsset) + v4 (RankedAsset). Primary field name follows v3 convention; v4 aliases noted in comments. """ space_name: str signal_confidence: float # = hub_confidence (v4 alias) significance_weight: float final_priority: float dominant_signal: str avn_accuracy: float training_steps: int = 0 score: float = 0.0 # backward-compat alias for final_priority (v3) # v4 uncertainty fields epistemic_std: float = 0.0 aleatoric_std: float = 0.0 rank: int = 0 @property def hub_confidence(self) -> float: """v4 alias for signal_confidence.""" return self.signal_confidence def to_dict(self) -> dict: return asdict(self) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 3 — ASSET SNAPSHOT (hub read-only data) # ══════════════════════════════════════════════════════════════════════════════════════ @dataclass class AssetSnapshot: """Latest training + voting metrics for one asset space, received from hub.""" space_name: str actor_loss: float = 0.0 critic_loss: float = 0.0 avn_loss: float = 0.0 avn_accuracy: float = 0.0 training_steps: int = 0 # `dominant_signal` is the public action field that all downstream code # (Gate A, _ensure_minimum_trades, ranking export, monitoring) reads from. # In v2.3+ it is populated SOLELY by realtime per-tick signals delivered # via /ws/signals → HubSubscriber.inject_signal(). The hub-snapshot path # (apply_update) NO LONGER writes to it — voting aggregates on /ws/subscribe # are intentionally ignored as a source of truth for the action, because # they carry the cumulative dominant of an EMA-style aggregation rather # than the raw per-tick AVN inference. dominant_signal: str = "NEUTRAL" buy_count: int = 0 sell_count: int = 0 # ── Realtime signal bookkeeping (populated by inject_signal) ────────── # latest_action mirrors the per-tick action that arrived on /ws/signals. # Domain: BUY | SELL | HOLD. dominant_signal is derived from it # (HOLD → NEUTRAL) so that the rest of the codebase keeps reading the # same field name with the same {BUY, SELL, NEUTRAL} domain it always had. latest_action: str = "HOLD" latest_action_price: float = 0.0 latest_action_ts: float = 0.0 latest_action_seq: int = 0 last_updated: float = 0.0 def apply_update(self, snapshot: dict) -> None: """ Merge a hub snapshot (from /ws/subscribe) into this AssetSnapshot. v2.3+ scope: • Training metrics (actor_loss, critic_loss, avn_loss/accuracy, steps) • Voting counters (buy_count, sell_count) — kept ONLY for the legacy vote-ratio confidence fallback. NOT used to set dominant_signal. The per-tick action is delivered out-of-band on /ws/signals and lands via inject_signal() / apply_signal(). Hub-snapshot voting aggregates are deliberately ignored here. """ training = snapshot.get("training", {}) voting = snapshot.get("voting", {}) if training: self.actor_loss = float(training.get("actor_loss", self.actor_loss)) self.critic_loss = float(training.get("critic_loss", self.critic_loss)) self.avn_loss = float(training.get("avn_loss", self.avn_loss)) self.avn_accuracy = max(0.0, min(1.0, float( training.get("avn_accuracy", self.avn_accuracy) ))) self.training_steps = int(training.get("training_steps", self.training_steps)) if voting: # Counts are still useful for the vote-ratio fallback in # signal_confidence when no realtime signal has arrived yet. self.buy_count = int(voting.get("buy_count", self.buy_count)) self.sell_count = int(voting.get("sell_count", self.sell_count)) self.last_updated = snapshot.get("last_updated", time.time()) def apply_signal( self, action: str, price: float = 0.0, ts: Optional[float] = None, seq: int = 0, ) -> None: """ Apply a realtime per-tick signal from /ws/signals. Sets latest_action and mirrors it into dominant_signal so all existing gate / ranking / execution code (which reads dominant_signal) sees the per-tick action with no further changes. BUY → dominant_signal = BUY, latest_action_confidence = 1.0 SELL → dominant_signal = SELL, latest_action_confidence = 1.0 HOLD → dominant_signal = NEUTRAL, confidence falls back to vote ratio """ if not isinstance(action, str): return action = action.upper() if action not in {"BUY", "SELL", "HOLD"}: action = "HOLD" self.latest_action = action self.latest_action_price = float(price or 0.0) self.latest_action_ts = float(ts) if ts is not None else time.time() self.latest_action_seq = int(seq or 0) if action in {"BUY", "SELL"}: self.dominant_signal = action self._latest_signal_confidence = 1.0 # direct AVN inference else: # HOLD self.dominant_signal = "NEUTRAL" self._latest_signal_confidence = None # fall back to vote ratio self.last_updated = self.latest_action_ts @property def total_votes(self) -> int: return self.buy_count + self.sell_count @property def signal_confidence(self) -> float: # ── CORE FIX 6b: use 1.0 when signal came from latest_signal ── if hasattr(self, "_latest_signal_confidence") and self._latest_signal_confidence is not None: return self._latest_signal_confidence if self.total_votes == 0: return 0.0 return max(self.buy_count, self.sell_count) / self.total_votes def to_dict(self) -> dict: return { "space_name": self.space_name, "actor_loss": self.actor_loss, "critic_loss": self.critic_loss, "avn_loss": self.avn_loss, "avn_accuracy": self.avn_accuracy, "training_steps": self.training_steps, "dominant_signal": self.dominant_signal, "buy_count": self.buy_count, "sell_count": self.sell_count, "signal_confidence": self.signal_confidence, # Realtime per-tick signal bookkeeping (v2.3+) "latest_action": self.latest_action, "latest_action_price": self.latest_action_price, "latest_action_ts": self.latest_action_ts, "latest_action_seq": self.latest_action_seq, "last_updated": self.last_updated, } # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 4 — HUB SUBSCRIBER (read-only WebSocket client) # ══════════════════════════════════════════════════════════════════════════════════════ class HubSubscriber: """ Subscribes to the Central WebSocket Hub and maintains per-asset AssetSnapshots. Constraint: NEVER writes back to the hub or any asset space. Merged from v3 (with logger support) + v4 (simpler, no logger). """ _MAX_BACKOFF = 30 def __init__( self, hub_url: str, on_update: Optional[Callable[[str, AssetSnapshot], None]] = None, ranker_logger: Optional[object] = None, ): self.hub_url = hub_url self.on_update = on_update self.ranker_logger = ranker_logger self._snapshots: Dict[str, AssetSnapshot] = {} self._lock = threading.Lock() self._running = False self._thread: Optional[threading.Thread] = None self._reconnect_count = 0 self.stats = { "messages_received": 0, "reconnect_count": 0, "last_message_time": 0.0, } def start(self) -> None: if self._running: return self._running = True self._thread = threading.Thread( target=self._run_loop, daemon=True, name="HubSubscriber" ) self._thread.start() logger.info(f"[HubSubscriber] Started → {self.hub_url}") if self.ranker_logger: self.ranker_logger.connection_event( "Hub WebSocket", "connected", "Subscriber started" ) def stop(self) -> None: self._running = False logger.info("[HubSubscriber] Stopping…") def get_snapshot(self, space_name: str) -> Optional[AssetSnapshot]: with self._lock: return self._snapshots.get(space_name) def get_all_snapshots(self) -> Dict[str, AssetSnapshot]: with self._lock: return dict(self._snapshots) def _run_loop(self) -> None: while self._running: try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(self._ws_session()) loop.close() # A clean session exit (e.g. hub closed connection gracefully) # is not a failure — reset the backoff counter so the next # reconnect starts from a short delay, not an exponential one. self._reconnect_count = 0 except Exception as e: logger.error(f"[HubSubscriber] Session error: {e}") self._reconnect_count += 1 self.stats["reconnect_count"] = self._reconnect_count if not self._running: break backoff = min(self._MAX_BACKOFF, 2 ** min(self._reconnect_count, 4)) logger.info(f"[HubSubscriber] Reconnecting in {backoff}s…") time.sleep(backoff) async def _ws_session(self) -> None: if websockets is None: logger.error("[HubSubscriber] websockets library not installed") await asyncio.sleep(5) return from websockets.exceptions import ConnectionClosed async with websockets.connect(self.hub_url) as ws: self._reconnect_count = 0 logger.info("[HubSubscriber] ✅ Connected to hub") while self._running: try: raw = await ws.recv() self._handle_message(raw) except ConnectionClosed: logger.info("[HubSubscriber] Connection closed by hub") break except Exception as e: logger.error(f"[HubSubscriber] Message error: {e}") def _handle_message(self, raw: str) -> None: try: data = json.loads(raw) except json.JSONDecodeError: logger.warning("[HubSubscriber] Malformed JSON") return msg_type = data.get("type", "") self.stats["messages_received"] += 1 self.stats["last_message_time"] = time.time() if msg_type == "initial_state": # New format: {"type": "initial_state", "assets": {"space_name": {"metadata": {...}, "snapshot": {...}}}} assets = data.get("assets", {}) if assets: for space_name, asset_data in assets.items(): if isinstance(asset_data, dict): snap_dict = asset_data.get("snapshot", {}) else: snap_dict = asset_data self._apply_snapshot(space_name, snap_dict) # Fallback: old format {"type": "initial_state", "snapshots": {"space_name": {...}}} else: for space_name, snap_dict in data.get("snapshots", {}).items(): self._apply_snapshot(space_name, snap_dict) elif msg_type == "metrics_update": # New format: {"type": "metrics_update", "asset": {"space_name": ..., "metadata": {...}, "snapshot": {...}}} asset = data.get("asset", {}) if asset: space_name = asset.get("space_name") snap_dict = asset.get("snapshot", {}) if space_name: self._apply_snapshot(space_name, snap_dict) # Fallback: old format {"type": "metrics_update", "space_name": ..., "snapshot": {...}} else: space_name = data.get("space_name") snap_dict = data.get("snapshot", {}) if space_name: self._apply_snapshot(space_name, snap_dict) def _apply_snapshot(self, space_name: str, snap_dict: dict) -> None: with self._lock: if space_name not in self._snapshots: self._snapshots[space_name] = AssetSnapshot(space_name=space_name) snap = self._snapshots[space_name] snap.apply_update(snap_dict) if self.on_update: try: self.on_update(space_name, snap) except Exception as e: logger.error(f"[HubSubscriber] on_update callback error: {e}") # ────────────────────────────────────────────────────────────────────────── # Public fast-path injector — called by SignalSubscriber when a per-tick # realtime signal arrives on the /ws/signals channel. Updates the SAME # AssetSnapshot instance the ranker reads from, so rank_and_gate sees the # latest action with ~30 ms latency instead of waiting for the next # metrics_update on /ws/subscribe (which carries cumulative aggregates, # not the per-tick action we actually want). # ────────────────────────────────────────────────────────────────────────── def inject_signal(self, space_name: str, signal: dict) -> None: """ Apply a realtime per-tick signal directly to the AssetSnapshot under the subscriber's lock. Expects `signal` shaped as: {"action": "BUY|SELL|HOLD", "price": float, "seq": int, "ts": float, "source": str, ...} Handling: • BUY / SELL → snap.dominant_signal is set accordingly (confidence=1.0) • HOLD → snap.dominant_signal reset to NEUTRAL so Gate A stops firing on stale direction """ if not isinstance(signal, dict): return action = signal.get("action", "HOLD") if not isinstance(action, str): action = "HOLD" with self._lock: if space_name not in self._snapshots: self._snapshots[space_name] = AssetSnapshot(space_name=space_name) snap = self._snapshots[space_name] snap.apply_signal( action = action, price = signal.get("price", 0.0), ts = signal.get("ts"), seq = signal.get("seq", 0), ) if self.on_update: try: self.on_update(space_name, snap) except Exception as e: logger.error(f"[HubSubscriber] on_update (signal-path) callback error: {e}") # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 4b — SIGNAL SUBSCRIBER (high-priority, low-latency side channel) # ══════════════════════════════════════════════════════════════════════════════════════ # # The SignalSubscriber is a dedicated WebSocket client that consumes only the # realtime per-tick AVN signal stream from the hub's /ws/signals endpoint. It # runs on its own background thread (matching the original sketch: # "DEDICATED THREAD, poll every 30ms") and bypasses the slower /ws/subscribe # snapshot path for per-tick action delivery. # # Why a separate subscriber: # • /ws/subscribe carries the full training-metrics firehose (losses, steps, # avn_accuracy) — large payloads at irregular cadence. Per-tick signals # would queue behind them. # • /ws/subscribe also only carries CUMULATIVE aggregates (the EMA-style # dominant_signal). The per-tick AVN action — what we actually want to # drive Gate A and the asset-buffer features off of — is not in those # payloads, only on /ws/signals. # • /ws/signals carries only {asset, action, price, source, seq, ts} — # tiny payloads, coalesced to 30 ms broadcasts. # • Running on its own thread means a slow rank_and_gate cycle in the main # loop cannot stall realtime-signal ingestion. # # Per-asset signal consistency: # • The hub assigns a monotonic `seq` per asset. The subscriber tracks # `last_seq` and drops replays / out-of-order deliveries. # • Updates land in the SAME AssetSnapshot dict the ranker reads from, via # HubSubscriber.inject_signal(), so there is exactly one source of truth # per asset at any point in time. # class SignalSubscriber: """ High-priority WS client for /ws/signals. Feeds per-tick realtime signals directly into the shared HubSubscriber snapshot store via inject_signal(). """ _MAX_BACKOFF = 30 def __init__( self, signal_url: str, hub_subscriber: "HubSubscriber", # snapshots are written through here ranker_logger: Optional[object] = None, ): self.signal_url = signal_url self.hub_subscriber = hub_subscriber self.ranker_logger = ranker_logger self._last_seq: Dict[str, int] = {} self._lock = threading.Lock() self._running = False self._thread: Optional[threading.Thread] = None self._reconnect_count = 0 self.stats = { "signals_received": 0, "signals_applied": 0, "signals_out_of_order": 0, "reconnect_count": 0, "last_signal_time": 0.0, } def start(self) -> None: if self._running: return self._running = True self._thread = threading.Thread( target=self._run_loop, daemon=True, name="SignalSubscriber" ) self._thread.start() logger.info(f"[SignalSubscriber] 📡 Started → {self.signal_url}") if self.ranker_logger: try: self.ranker_logger.connection_event( "Signal WebSocket", "connected", "SignalSubscriber started" ) except Exception: pass def stop(self) -> None: self._running = False logger.info("[SignalSubscriber] Stopping…") def _run_loop(self) -> None: while self._running: try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(self._ws_session()) loop.close() self._reconnect_count = 0 except Exception as e: logger.error(f"[SignalSubscriber] Session error: {e}") self._reconnect_count += 1 self.stats["reconnect_count"] = self._reconnect_count if not self._running: break backoff = min(self._MAX_BACKOFF, 2 ** min(self._reconnect_count, 4)) logger.info(f"[SignalSubscriber] Reconnecting in {backoff}s…") time.sleep(backoff) async def _ws_session(self) -> None: if websockets is None: logger.error("[SignalSubscriber] websockets library not installed") await asyncio.sleep(5) return from websockets.exceptions import ConnectionClosed async with websockets.connect(self.signal_url) as ws: self._reconnect_count = 0 logger.info("[SignalSubscriber] ✅ Connected to signal channel") while self._running: try: raw = await ws.recv() self._handle_message(raw) except ConnectionClosed: logger.info("[SignalSubscriber] Connection closed by hub") break except Exception as e: logger.error(f"[SignalSubscriber] Message error: {e}") def _handle_message(self, raw: str) -> None: try: data = json.loads(raw) except json.JSONDecodeError: logger.warning("[SignalSubscriber] Malformed JSON") return msg_type = data.get("type", "") if msg_type not in ("signal_snapshot", "signal_delta"): return signals = data.get("signals", []) if not isinstance(signals, list): return for sig in signals: self._apply_signal(sig) def _apply_signal(self, signal: dict) -> None: asset = signal.get("asset") seq = signal.get("seq", 0) if not asset: return self.stats["signals_received"] += 1 # Drop replays / out-of-order (per-asset signal consistency) with self._lock: last = self._last_seq.get(asset, 0) if seq <= last: self.stats["signals_out_of_order"] += 1 return self._last_seq[asset] = seq # Push directly into the shared AssetSnapshot store self.hub_subscriber.inject_signal(asset, signal) self.stats["signals_applied"] += 1 self.stats["last_signal_time"] = time.time() action = signal.get("action", "?") if action in ("BUY", "SELL"): logger.info( f"[SignalSubscriber] ⚡ {asset} signal#{seq} → {action} " f"@ {signal.get('price', 0):.5f}" ) def _derive_signal_url(hub_ws_url: str) -> str: """Derive the /ws/signals URL from whatever hub URL the ranker was given. Accepts ws://host:port/ws/subscribe → ws://host:port/ws/signals.""" url = hub_ws_url # Replace the path if present, otherwise append for known in ("/ws/subscribe", "/ws/metrics", "/subscribe", "/ws/flips"): if url.endswith(known): return url[: -len(known)] + "/ws/signals" # Fallback: strip any trailing path and append /ws/signals if "://" in url: scheme, rest = url.split("://", 1) host_and_path = rest.split("/", 1)[0] return f"{scheme}://{host_and_path}/ws/signals" return url.rstrip("/") + "/ws/signals" # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 5 — ADAPTIVE NORMALIZER # ══════════════════════════════════════════════════════════════════════════════════════ class AdaptiveNormalizer: """Online z-score normaliser with exponential moving statistics (from v3).""" def __init__(self, dim: int, momentum: float = 0.01, eps: float = 1e-6): self.dim = dim self.momentum = momentum self.eps = eps self.mean = np.zeros(dim, dtype=np.float32) self.var = np.ones(dim, dtype=np.float32) self.n = 0 def update_and_normalize(self, x: np.ndarray) -> np.ndarray: x = np.asarray(x, dtype=np.float32) if self.n == 0: self.mean = x.copy() self.var = np.ones_like(x) self.n = 1 return np.zeros_like(x) old_mean = self.mean self.mean = (1 - self.momentum) * self.mean + self.momentum * x # Use old_mean so the variance delta is computed before the mean shifts self.var = (1 - self.momentum) * self.var + self.momentum * (x - old_mean) ** 2 self.n += 1 return np.clip((x - self.mean) / (np.sqrt(self.var) + self.eps), -5.0, 5.0) def state_dict(self) -> dict: """Return serialisable state for checkpointing.""" return { "dim": self.dim, "momentum": self.momentum, "eps": self.eps, "mean": self.mean.tolist(), "var": self.var.tolist(), "n": self.n, } def load_state_dict(self, state: dict) -> None: """Restore state from a dict produced by state_dict().""" self.dim = state["dim"] self.momentum = state["momentum"] self.eps = state["eps"] self.mean = np.array(state["mean"], dtype=np.float32) self.var = np.array(state["var"], dtype=np.float32) self.n = state["n"] # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 6 — UNIFIED FEATURE ENGINE (26-dim) # ══════════════════════════════════════════════════════════════════════════════════════ class UnifiedFeatureEngine: """ 26-dimensional context vector for the significance model. Base features (v3 FeatureEngine, indices 0–18) ─────────────────────────────────────────────── Signal features (hub input) [0] AVN loss delta — is training loss improving? [1] AVN accuracy (reliability) — how reliable is this model's history? [2] Actor loss delta — policy gradient convergence [3] Critic loss delta — value estimator convergence [4] Signal flip rate — is the hub signal stable or noisy? Market context (Deriv price stream) [5] Latest log-return — last tick direction (GBM-consistent) [6] Realised volatility — annualised vol proxy (QV-based) [7] Vol-to-baseline ratio — is vol elevated vs asset norm? Performance context (trade outcomes) [8] PnL trend slope — are we making or losing money lately? [9] Rolling Sharpe ratio — risk-adjusted return quality [10] Drawdown from peak — current depth of drawdown [11] Rolling hit rate — fraction of winning trades Hub signal encoding [12] Match efficacy — signal_confidence proxy [13] Signal confidence (raw) — hub vote majority fraction [14] Direction × confidence — signed hub conviction Price context (added in v3) [15] Normalised price (z-score) — is price unusually high/low? [16] 5-period momentum — sum of last 5 log-returns [17] Spread-to-price ratio — cost-of-trade / liquidity proxy [18] Signal × price context — hub direction × price anomaly Stochastic calculus extras (v4 StochasticFeatureEngine, indices 19–25) ───────────────────────────────────────────────────────────────────── [19] Ito convexity proxy — ½f_xx σ² discrete second derivative [20] Volatility acceleration — d(QV)/dt: is vol speeding up? [21] Volatility regime flag — 1 if short-vol > 1.2 × long-vol [22] Leverage effect — lagged return-vol correlation [23] Martingale deviation — variance-ratio test deviation [24] Jump risk / crash tension — Poisson-weighted jump proximity [25] OU residual — mean-reversion deviation """ FEATURE_DIM = 26 def __init__( self, asset_id: str, base_vol: float = 75.0, history_len: int = 100, stoch_config: Optional[StochasticCalculusConfig] = None, shreve_config: Optional["ShreveConfig"] = None, ): self.asset_id = asset_id self.base_vol = base_vol self.stoch_config = stoch_config or StochasticCalculusConfig() self.shreve_config = shreve_config or ShreveConfig() self.normalizer = AdaptiveNormalizer(self.FEATURE_DIM) # Track whether this asset is modelled under jump-diffusion [S9] self._is_jump_diffusion: bool = ( asset_id in self.shreve_config.jump_diffusion_assets ) # Price-level online normaliser self._price_norm_mean: float = 0.0 self._price_norm_var: float = 1.0 self._price_norm_n: int = 0 # History buffers (shared across feature groups) self._prices: deque = deque(maxlen=history_len) self._returns: deque = deque(maxlen=history_len) self._signals: deque = deque(maxlen=50) self._pnl_history: deque = deque(maxlen=100) self._avn_loss_h: deque = deque(maxlen=20) self._actor_h: deque = deque(maxlen=20) self._critic_h: deque = deque(maxlen=20) self._win_hist: deque = deque(maxlen=100) self._vol_history: deque = deque(maxlen=100) # for feature [20] # Scalar state self._peak_pnl: float = 0.0 self._total_pnl: float = 0.0 self.latest_spread: float = 0.03 # Jump detection state (v4) self._ticks_since_last_jump: int = 0 self._last_jump_magnitude: float = 0.0 self._jump_history: deque = deque(maxlen=50) # Last raw (pre-normalisation) feature vector — needed for gate queries self._last_raw: np.ndarray = np.zeros(self.FEATURE_DIM, dtype=np.float32) # ── Data update methods ─────────────────────────────────────────────────────── def on_price(self, price: float, spread: float, pnl: float) -> None: """Ingest a new price tick — updates price & return buffers.""" self._prices.append(price) self.latest_spread = spread self._total_pnl = pnl self._pnl_history.append(pnl) self._peak_pnl = max(self._peak_pnl, pnl) if len(self._prices) >= 2: if self.stoch_config.use_log_returns: ret = math.log(self._prices[-1] / max(self._prices[-2], 1e-12)) else: ret = (self._prices[-1] - self._prices[-2]) / max(self._prices[-2], 1e-8) self._returns.append(ret) def on_signal( self, action: str, confidence: float, avn_loss: float, avn_accuracy: float, match_efficacy: float, actor_loss: float = 0.0, critic_loss: float = 0.0, win: Optional[bool] = None, ) -> None: """Ingest a new hub signal — updates signal-quality buffers.""" self._signals.append(action) self._avn_loss_h.append(avn_loss) self._actor_h.append(actor_loss) self._critic_h.append(critic_loss) if win is not None: self._win_hist.append(1.0 if win else 0.0) # ── Feature extraction ──────────────────────────────────────────────────────── def extract(self, latest_signal: dict) -> np.ndarray: """ Produce the normalised 26-dim feature vector for the current tick. Called once per hub-snapshot cycle from AssetStateBuffer. """ raw = np.zeros(self.FEATURE_DIM, dtype=np.float32) action = latest_signal.get("action", "HOLD") confidence = float(latest_signal.get("confidence", 0.0)) avn_acc = float(latest_signal.get("avn_accuracy", 0.5)) match_eff = float(latest_signal.get("match_efficacy", 0.5)) dir_map = {"BUY": 1.0, "SELL": -1.0, "HOLD": 0.0, "NEUTRAL": 0.0} # ── BASE 19 FEATURES (v3 FeatureEngine logic) ────────────────────────── # [0] AVN loss delta if len(self._avn_loss_h) >= 2: raw[0] = self._avn_loss_h[-1] - self._avn_loss_h[-2] # [1] AVN accuracy (training reliability) raw[1] = avn_acc # [2] Actor loss delta if len(self._actor_h) >= 2: raw[2] = self._actor_h[-1] - self._actor_h[-2] # [3] Critic loss delta if len(self._critic_h) >= 2: raw[3] = self._critic_h[-1] - self._critic_h[-2] # [4] Signal flip rate (last 10 signals) if len(self._signals) >= 2: recent = list(self._signals)[-10:] flips = sum(1 for i in range(1, len(recent)) if recent[i] != recent[i - 1]) raw[4] = flips / max(len(recent) - 1, 1) # [5] Latest log-return (GBM-consistent) raw[5] = float(self._returns[-1]) if self._returns else 0.0 # [6] Realised volatility — Shreve QV form [S3] # σ̂_t = sqrt( QV_t / Δt ), QV_t = Σ (Δ log S_j)² # Divide by the window length first to get the per-bar QV rate, then # multiply by vol_annualization so σ̂ is in consistent annualised units # regardless of how many bars are in the window. if len(self._returns) >= 5: arr = np.array(list(self._returns)[-self.stoch_config.qv_window:], dtype=np.float64) n_bars = max(len(arr), 1) qv_t = float(np.sum(arr ** 2)) / n_bars # per-bar QV rate raw[6] = float(np.sqrt(qv_t * self.stoch_config.vol_annualization)) # [FIX-6a] Clamp QV-vol: jump-diffusion assets can produce σ̂ > 500% # annualised; cap at 10.0 (1000% ann.) to prevent high-variance features # from blowing up the normalizer and causing downstream gradient explosion. raw[6] = float(np.clip(raw[6], 0.0, 10.0)) else: raw[6] = self.base_vol # [7] Vol-to-baseline ratio raw[7] = raw[6] / max(self.base_vol, 1.0) # [8] PnL trend slope (drift estimation) if len(self._pnl_history) >= 5: pnl_arr = np.array(list(self._pnl_history)[-10:], dtype=np.float64) t = np.arange(len(pnl_arr)) raw[8] = float(np.polyfit(t, pnl_arr, 1)[0]) # [9] Rolling Sharpe ratio (market price of risk Θ) if len(self._returns) >= 10: ret_arr = np.array(list(self._returns)[-20:]) std_r = ret_arr.std() raw[9] = float( ret_arr.mean() / (std_r + 1e-8) * np.sqrt(self.stoch_config.vol_annualization) ) # [10] Drawdown from peak if self._peak_pnl > 0: raw[10] = float(np.clip( (self._peak_pnl - self._total_pnl) / (self._peak_pnl + 1e-8), 0.0, 1.0 )) # [11] Rolling hit rate raw[11] = float(np.mean(list(self._win_hist))) if self._win_hist else 0.5 # [12] Match efficacy raw[12] = match_eff # [13] Signal confidence (raw hub voting fraction) raw[13] = confidence # [14] Direction × confidence (signed hub conviction) raw[14] = dir_map.get(action, 0.0) * confidence # [15] Normalised price (online z-score — asset-specific scale) current_price = float(self._prices[-1]) if self._prices else 0.0 mom = 0.01 if self._price_norm_n == 0: self._price_norm_mean = current_price self._price_norm_var = max(current_price * 0.01, 1.0) ** 2 else: self._price_norm_mean = (1 - mom) * self._price_norm_mean + mom * current_price self._price_norm_var = ( (1 - mom) * self._price_norm_var + mom * (current_price - self._price_norm_mean) ** 2 ) self._price_norm_n += 1 price_zscore = float(np.clip( (current_price - self._price_norm_mean) / (np.sqrt(self._price_norm_var) + 1e-8), -5.0, 5.0 )) raw[15] = price_zscore # [16] 5-period log-return momentum if len(self._returns) >= 5: raw[16] = float(np.sum(list(self._returns)[-5:])) # [17] Spread-to-price ratio (liquidity / cost-of-trade proxy) if current_price > 0: raw[17] = float(np.clip(self.latest_spread / current_price, 0.0, 0.1)) # [18] Signal-weighted price context raw[18] = dir_map.get(action, 0.0) * confidence * price_zscore # ── ADVANCED 7 STOCHASTIC CALCULUS FEATURES (v4) ────────────────────── # [19] Ito convexity proxy (½f_xx σ²) — discrete second derivative if self.stoch_config.use_ito_convexity and len(self._prices) >= 3: p = list(self._prices)[-3:] raw[19] = float(np.clip( (p[2] - 2 * p[1] + p[0]) / (p[1] + 1e-8), -0.5, 0.5 )) # [20] Volatility acceleration — ΔQV_t = QV_short − QV_long [S3] # d/dt [log S, log S]_t in discretised form. if self.stoch_config.use_vol_regime and len(self._returns) >= self.stoch_config.vol_regime_long: arr_short = np.array(list(self._returns)[-self.stoch_config.vol_regime_short:], dtype=np.float64) arr_long = np.array(list(self._returns)[-self.stoch_config.vol_regime_long:], dtype=np.float64) qv_short = float(np.sum(arr_short ** 2)) qv_long = float(np.sum(arr_long ** 2)) # Normalise by window length so the units are per-bar QV qv_short_rate = qv_short / self.stoch_config.vol_regime_short qv_long_rate = qv_long / self.stoch_config.vol_regime_long raw[20] = float(np.clip((qv_short_rate - qv_long_rate) / (qv_long_rate + 1e-12), -1.0, 1.0)) # [21] Volatility regime flag (short-vol > 1.2 × long-vol ?) if (self.stoch_config.use_vol_regime and len(self._returns) >= self.stoch_config.vol_regime_long): short_vol = ( np.std(list(self._returns)[-self.stoch_config.vol_regime_short:]) * np.sqrt(self.stoch_config.vol_annualization) ) long_vol = ( np.std(list(self._returns)[-self.stoch_config.vol_regime_long:]) * np.sqrt(self.stoch_config.vol_annualization) ) raw[21] = 1.0 if short_vol / (long_vol + 1e-8) > 1.2 else 0.0 # [22] Leverage effect (lagged return-vol correlation) if self.stoch_config.use_leverage_effect and len(self._returns) >= 20: rets = np.array(list(self._returns)[-20:]) vols = np.abs(rets) if len(rets) > 1: corr = np.corrcoef(rets[:-1], vols[1:])[0, 1] raw[22] = float(np.clip(corr, -1.0, 1.0)) # [23] Martingale deviation (variance-ratio test) if (self.stoch_config.use_martingale_deviation and len(self._returns) >= self.stoch_config.martingale_test_window): raw[23] = self._compute_martingale_deviation() # [24] Jump risk / crash tension (Poisson-weighted proximity) if self.stoch_config.use_jump_risk and len(self._returns) >= 10: raw[24] = self._compute_jump_risk() # [25] Ornstein-Uhlenbeck residual (mean-reversion deviation) if len(self._prices) >= 10: prices_np = np.array(list(self._prices)[-10:]) raw[25] = float(np.clip(self._compute_ou_residual(prices_np), -1.0, 1.0)) # Save current vol for the vol-acceleration feature on next tick self._vol_history.append(raw[6]) # Cache raw vector for external gate queries (e.g. jump_risk, vol_ratio) self._last_raw = raw.copy() return self.normalizer.update_and_normalize(raw) # ── Internal stochastic feature helpers ────────────────────────────────────── def _compute_martingale_deviation(self) -> float: """Variance-ratio test for deviation from martingale property.""" rets = list(self._returns)[-self.stoch_config.martingale_test_window:] if len(rets) < 2: return 0.0 k = min(5, len(rets) // 2) aggregated = [sum(rets[i:i + k]) for i in range(len(rets) - k)] var_single = np.var(rets) var_agg = np.var(aggregated) if aggregated else 0.0 if var_single > 0: vr = var_agg / (k * var_single) return min(1.0, abs(vr - 1.0) * 2) return 0.0 def _compute_jump_risk(self) -> float: """ [S9] Jump-diffusion feature under compound Poisson model. dS_t = μ S_t dt + σ S_t dW_t + S_{t−} dJ_t For plain diffusion assets : use inter-arrival Poisson proximity (as before). For jump-diffusion assets : estimate Poisson intensity λ̂ from empirical jump frequency over the intensity window, then return a risk metric that is proportional to both λ̂ and expected jump magnitude E[|Y|]. Feature value = min(1, λ̂_t · E[|Y|_t] · (ticks_since_jump + 1)^{-1}) """ self._ticks_since_last_jump += 1 recent_returns = list(self._returns) n_recent = min(20, len(recent_returns)) if n_recent < 5: return 0.0 recent_vol = float(np.sqrt(np.sum(np.array(recent_returns[-n_recent:]) ** 2))) current_ret = abs(recent_returns[-1]) threshold = self.stoch_config.jump_threshold_std * (recent_vol / math.sqrt(n_recent) + 1e-9) if current_ret > threshold: self._ticks_since_last_jump = 0 self._last_jump_magnitude = current_ret self._jump_history.append(current_ret) if self._is_jump_diffusion and len(self._jump_history) >= 2: # [S9] Compound-Poisson: λ̂ = #jumps_in_window / window_bars # Use the windowed count, not the total history length, so λ̂ # reflects only the recent jump_intensity_window bars. window_len = min(len(self._jump_history), self.shreve_config.jump_intensity_window) n_jumps = window_len # jumps within the window lambda_hat = n_jumps / max(window_len, 1) # jumps per bar (in window) exp_jump = float(np.mean(self._jump_history)) # E[|Y|] proximity = 1.0 / (self._ticks_since_last_jump + 1) return min(1.0, lambda_hat * exp_jump * proximity * 10.0) else: # Plain diffusion: inverse inter-arrival proximity (original heuristic) risk = 1.0 / (self._ticks_since_last_jump + 1) if self._jump_history: avg_jump = float(np.mean(self._jump_history)) risk *= min(2.0, avg_jump / (recent_vol / math.sqrt(n_recent) + 1e-9)) return min(1.0, risk) def _compute_ou_residual(self, prices: np.ndarray) -> float: """Ornstein-Uhlenbeck mean-reversion residual.""" rolling_mean = np.mean(prices) return float((prices[-1] - rolling_mean) / (rolling_mean + 1e-8)) # ── Convenience ──────────────────────────────────────────────────────────── def get_raw_feature(self, idx: int) -> float: """Return the most recent pre-normalisation value for feature index `idx`.""" if 0 <= idx < len(self._last_raw): return float(self._last_raw[idx]) return 0.0 def has_data(self, min_samples: int = 5) -> bool: return len(self._returns) >= min_samples def state_dict(self) -> dict: """Return serialisable rolling-history state for checkpointing.""" return { "asset_id": self.asset_id, "prices": list(self._prices), "returns": list(self._returns), "signals": list(self._signals), "pnl_history": list(self._pnl_history), "avn_loss_h": list(self._avn_loss_h), "actor_h": list(self._actor_h), "critic_h": list(self._critic_h), "win_hist": list(self._win_hist), "vol_history": list(self._vol_history), "jump_history": list(self._jump_history), "peak_pnl": self._peak_pnl, "total_pnl": self._total_pnl, "latest_spread": self.latest_spread, "ticks_since_last_jump": self._ticks_since_last_jump, "last_jump_magnitude": self._last_jump_magnitude, "price_norm_mean": self._price_norm_mean, "price_norm_var": self._price_norm_var, "price_norm_n": self._price_norm_n, "normalizer": self.normalizer.state_dict(), } def load_state_dict(self, state: dict) -> None: """Restore rolling-history state from a dict produced by state_dict().""" self._prices.extend(state.get("prices", [])) self._returns.extend(state.get("returns", [])) self._signals.extend(state.get("signals", [])) self._pnl_history.extend(state.get("pnl_history", [])) self._avn_loss_h.extend(state.get("avn_loss_h", [])) self._actor_h.extend(state.get("actor_h", [])) self._critic_h.extend(state.get("critic_h", [])) self._win_hist.extend(state.get("win_hist", [])) self._vol_history.extend(state.get("vol_history", [])) self._jump_history.extend(state.get("jump_history", [])) self._peak_pnl = state.get("peak_pnl", self._peak_pnl) self._total_pnl = state.get("total_pnl", self._total_pnl) self.latest_spread = state.get("latest_spread", self.latest_spread) self._ticks_since_last_jump = state.get("ticks_since_last_jump", self._ticks_since_last_jump) self._last_jump_magnitude = state.get("last_jump_magnitude", self._last_jump_magnitude) self._price_norm_mean = state.get("price_norm_mean", self._price_norm_mean) self._price_norm_var = state.get("price_norm_var", self._price_norm_var) self._price_norm_n = state.get("price_norm_n", self._price_norm_n) if "normalizer" in state: self.normalizer.load_state_dict(state["normalizer"]) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 7 — ASSET STATE BUFFER # ══════════════════════════════════════════════════════════════════════════════════════ class AssetStateBuffer: """ Per-asset rolling state buffer — bridges raw streams to the neural net. Responsibilities ──────────────── • Receives hub snapshots (signal + training metrics) via on_hub_snapshot() • Receives live price ticks via on_price() • Passes data to UnifiedFeatureEngine (26-dim context vectors) • Maintains a rolling sequence buffer of length SEQ_LEN for LSTM input • Reports has_data() = True once enough history exists to feed the model This class does NOT interpret or generate signals. """ STALE_TIMEOUT = 60.0 def __init__( self, asset_id: str, cfg: dict, seq_len: int = SEQ_LEN, stoch_config: Optional[StochasticCalculusConfig] = None, shreve_config: Optional[ShreveConfig] = None, ): self.asset_id = asset_id self.seq_len = seq_len self.feature_eng = UnifiedFeatureEngine( asset_id = asset_id, base_vol = cfg.get("base_vol", 75.0), stoch_config = stoch_config, shreve_config = shreve_config, ) self._seq_buffer = deque(maxlen=seq_len) self.latest_signal: dict = { "action": "HOLD", "confidence": 0.0, "avn_loss": 0.0, "avn_accuracy": 0.5, "match_efficacy": 0.5, "actor_loss": 0.0, "critic_loss": 0.0, } self.latest_price = 1500.0 self.latest_spread = 0.03 self.current_pnl = 0.0 self.last_signal_ts: Optional[float] = None self.last_price_ts: Optional[float] = None self.axrvi_score = 0.0 self.is_enabled = False @property def is_stale(self) -> bool: if self.last_signal_ts is None: return True return (time.time() - self.last_signal_ts) > self.STALE_TIMEOUT def on_hub_snapshot(self, snap: AssetSnapshot) -> None: """ Update from a hub AssetSnapshot — triggers feature extraction + buffer append. v2.3+ action sourcing: The per-tick action is taken from snap.latest_action (populated by SignalSubscriber → HubSubscriber.inject_signal from /ws/signals). snap.dominant_signal is now a mirror of latest_action (BUY/SELL pass-through, HOLD → NEUTRAL) so it stays usable as a fallback if a snapshot arrives via the cumulative path before any realtime signal has landed. """ # Prefer the realtime per-tick action from /ws/signals. action = snap.latest_action if snap.latest_action in {"BUY", "SELL", "HOLD"} else "HOLD" # Backward-compat fallback: if no realtime signal has landed yet but # dominant_signal is set (e.g. legacy code path or snapshot replay), # still use it. Domain mapping: NEUTRAL → HOLD. if action == "HOLD" and snap.dominant_signal in {"BUY", "SELL"}: action = snap.dominant_signal self.latest_signal = { "action": action, "confidence": snap.signal_confidence, "avn_loss": snap.avn_loss, "avn_accuracy": snap.avn_accuracy, "match_efficacy": snap.signal_confidence, "actor_loss": snap.actor_loss, "critic_loss": snap.critic_loss, } self.last_signal_ts = time.time() self.feature_eng.on_signal( action = self.latest_signal["action"], confidence = self.latest_signal["confidence"], avn_loss = snap.avn_loss, avn_accuracy = snap.avn_accuracy, match_efficacy = snap.signal_confidence, actor_loss = snap.actor_loss, critic_loss = snap.critic_loss, ) feat = self.feature_eng.extract(self.latest_signal) self._seq_buffer.append(feat) def on_price(self, price: float, spread: float, pnl: float) -> None: """Ingest a price tick — updates internal feature engine state.""" self.latest_price = price self.latest_spread = spread self.current_pnl = pnl self.last_price_ts = time.time() self.feature_eng.on_price(price=price, spread=spread, pnl=pnl) def get_sequence(self) -> np.ndarray: """Return (seq_len, feature_dim) float32 array, left-padded with zeros.""" buf = list(self._seq_buffer) pad = self.seq_len - len(buf) if pad > 0: buf = [np.zeros(self.feature_eng.FEATURE_DIM, dtype=np.float32)] * pad + buf return np.stack(buf, axis=0).astype(np.float32) def has_data(self) -> bool: return len(self._seq_buffer) >= max(self.seq_len // 4, 3) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 8 — NEURAL SIGNIFICANCE MODEL (AXRVINet v7) # ══════════════════════════════════════════════════════════════════════════════════════ @dataclass class AXRVIConfig: """Central configuration dataclass for all hyperparameters.""" # Core dimensions — ALIGNED WITH SYSTEM CONSTANTS (Section 1) # D_MODEL=64, NUM_HEADS=4, SEQ_LEN=20, NUM_ENCODER_LAYERS=2 feature_dim: int = 26 d_model: int = 64 # ✅ FIX: Changed from 128 → D_MODEL constant num_heads: int = 4 # ✅ FIX: Changed from 8 → NUM_HEADS constant seq_len: int = 20 # ✅ FIX: Changed from 50 → SEQ_LEN constant num_encoder_layers: int = 2 # ✅ FIX: Changed from 3 → NUM_ENCODER_LAYERS constant num_cross_layers: int = 2 latent_dim: int = 32 # Also reduced to match LATENT_DIM constant implied usage num_quantiles: int = 9 num_regimes: int = 4 # Dropout & regularization dropout: float = 0.1 attention_dropout: float = 0.1 ff_dropout: float = 0.1 # MoE parameters n_moe_experts: int = 4 moe_top_k: int = 2 moe_balance_coeff: float = 0.01 # ODE parameters ode_steps: int = 6 hyperbolic_curvature: float = 1.0 # STDP parameters stdp_lr: float = 0.005 stdp_decay: float = 0.995 stdp_hebbian_decay: float = 0.01 # Loss weights lambda_crps: float = 0.10 lambda_moe: float = 1.00 lambda_entropy: float = 0.01 lambda_gate_reg: float = 0.005 # Training stability gradient_clip_norm: float = 1.0 use_amp: bool = False # Automatic Mixed Precision (disabled — not yet wired in HybridTrainer) # QCSAM/FABLE integration parameters (live cross-asset quantum core) # n_qubits: number of qubits per patch for QuantumFeatureMap. # Hilbert dim = 2**n_qubits. Must satisfy 2**n_qubits >= d_model when # used as a readout projection source; we set n_qubits=4 → hilbert_dim=16 # and project back up to d_model with a learned linear layer. qcsam_n_qubits: int = 4 # 2^4 = 16 Hilbert dimensions per patch qcsam_n_layers: int = 2 # variational depth for QuantumFeatureMap qcsam_n_heads: int = 2 # heads for QuantumMultiHeadAttention qcsam_qffn_layers: int = 2 # QFFN depth qcsam_align_loss_weight: float = 0.05 # weight for L_align in total loss # Deployment torchscript_compatible: bool = False checkpoint_grads: bool = False # Gradient checkpointing for memory # Default config instance DEFAULT_CONFIG = AXRVIConfig() # ====================================================================================== # I. QUANTUM-INSPIRED MODULES (Refactored) # ====================================================================================== # ============================================================================= # SECTION 1 — GATE MATRIX HELPERS (PyTorch, autograd-safe) # ============================================================================= # These functions implement single-qubit rotation matrices and statevector # application in pure PyTorch so that gradients flow through angle parameters. # # NOTE: These were ABSENT from the patched file — all forward passes would # have raised NameError at runtime. They are added here as the foundational # primitive layer upon which QuantumFeatureMap and QFFN are built. # ============================================================================= def _rx_matrix(angle: torch.Tensor) -> torch.Tensor: """ Rx(θ) = [[cos(θ/2), -i·sin(θ/2)], [-i·sin(θ/2), cos(θ/2) ]] Returns: (2, 2) cdouble tensor, differentiable w.r.t. angle. """ a = angle.double() / 2 c = torch.cos(a) s = torch.sin(a) z = torch.zeros_like(c) c_c = torch.complex(c, z) # cos(θ/2) + 0i mis = torch.complex(z, -s) # 0 − i·sin(θ/2) return torch.stack([torch.stack([c_c, mis]), torch.stack([mis, c_c])]) def _ry_matrix(angle: torch.Tensor) -> torch.Tensor: """ Ry(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]] Returns: (2, 2) cdouble tensor, differentiable w.r.t. angle. """ a = angle.double() / 2 c = torch.cos(a) s = torch.sin(a) z = torch.zeros_like(c) c_c = torch.complex(c, z) s_c = torch.complex(s, z) ms_c = torch.complex(-s, z) return torch.stack([torch.stack([c_c, ms_c]), torch.stack([s_c, c_c])]) def _rz_matrix(angle: torch.Tensor) -> torch.Tensor: """ Rz(θ) = [[exp(-iθ/2), 0 ], [0, exp(+iθ/2)]] Returns: (2, 2) cdouble tensor, differentiable w.r.t. angle. """ a = angle.double() / 2 z = torch.zeros_like(a) e_neg = torch.complex(torch.cos(-a), torch.sin(-a)) # exp(-ia) e_pos = torch.complex(torch.cos( a), torch.sin( a)) # exp(+ia) zero = torch.complex(z, z) return torch.stack([torch.stack([e_neg, zero]), torch.stack([zero, e_pos])]) def _apply_single_qubit( state: torch.Tensor, n_qubits: int, qubit: int, gate: torch.Tensor, ) -> torch.Tensor: """ Apply a 2×2 single-qubit gate to one wire of a statevector. Args: state : (2**n_qubits,) cdouble — no batch dimension n_qubits : total number of qubits qubit : target wire index (0 = most significant) gate : (2, 2) cdouble gate matrix Returns: (2**n_qubits,) cdouble — updated statevector """ dim = 2 ** n_qubits # Reshape to tensor-product form: axes (q0, q1, ..., q_{n-1}) s = state.view(*([2] * n_qubits)) # Bring target qubit to axis 0 s = s.movedim(qubit, 0) # (2, 2, ..., 2) shape_rest = s.shape[1:] s = s.reshape(2, -1) # (2, dim/2) # Matrix multiply: gate (2,2) @ s (2, dim/2) → (2, dim/2) s = torch.matmul(gate.to(s.dtype), s) # Restore axes s = s.reshape(2, *shape_rest) s = s.movedim(0, qubit) return s.reshape(dim) def _apply_cnot( state: torch.Tensor, n_qubits: int, ctrl: int, tgt: int, ) -> torch.Tensor: """ Apply CNOT gate (ctrl → tgt) to a statevector. Args: state : (2**n_qubits,) cdouble n_qubits : total qubits ctrl : control qubit index tgt : target qubit index (ctrl ≠ tgt enforced by caller) Returns: (2**n_qubits,) cdouble """ dim = 2 ** n_qubits s = state.view(*([2] * n_qubits)) result = s.clone() # Slice selecting ctrl = 1 subspace idx = [slice(None)] * n_qubits idx[ctrl] = 1 idx_t = tuple(idx) # After indexing away the ctrl axis, the target axis shifts down by 1 # if tgt > ctrl (because ctrl axis is removed below tgt). tgt_dim = tgt if tgt < ctrl else tgt - 1 # Flip the target qubit within the ctrl=1 subspace result[idx_t] = torch.flip(s[idx_t], dims=[tgt_dim]) return result.reshape(dim) # ============================================================================= # SECTION 2 — UTILITIES # ============================================================================= def safe_normalise(v: torch.Tensor, eps: float = _EPS) -> torch.Tensor: """L2-normalise along the last dimension (batch-safe).""" norm = torch.norm(v, dim=-1, keepdim=True) return v / (norm + eps) def assert_normalised(v: torch.Tensor, name: str = "state", tol: float = _NORM_TOL): """Emit a warning if any vector in v deviates from unit norm.""" norms = torch.norm(v, dim=-1) deviation = torch.abs(norms - 1.0).max().item() if deviation > tol: warnings.warn( f"[QCSAM] {name} norm deviation = {deviation:.6f} > tol={tol}. " "State may not be physical." ) def polar_to_complex(r: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: """α = r · exp(i·θ) — convert (magnitude, phase) → cdouble.""" return torch.polar(r, theta) # ============================================================================= # SECTION 3 — QISKIT CIRCUIT HELPERS # ============================================================================= def build_feature_map_circuit( n_qubits: int, x: np.ndarray, theta: np.ndarray, n_layers: int, ) -> QuantumCircuit: """ Variational feature-map circuit (Eq. 7). Wire semantics: qubit[0..n_qubits-1] : data register Structure: Layer 0 : Rx(x_i) for each qubit ← first data encoding Layer l : Ry(θ[l,i,0]), Rz(θ[l,i,1]) ← variational CNOT(i, i+1) with extra Rz phase ← ZZ entangler Final : Rx(x_i) re-encoding pass Args: n_qubits : number of data qubits x : classical data vector (n_qubits,), values in [0, π] theta : variational params (n_layers, n_qubits, 2) n_layers : number of variational layers Returns: QuantumCircuit (no measurement — statevector output) """ qr = QuantumRegister(n_qubits, name="data") qc = QuantumCircuit(qr) for i in range(n_qubits): qc.rx(float(x[i]), qr[i]) for l in range(n_layers): for i in range(n_qubits): qc.ry(float(theta[l, i, 0]), qr[i]) qc.rz(float(theta[l, i, 1]), qr[i]) for i in range(n_qubits - 1): qc.cx(qr[i], qr[i + 1]) qc.rz(float(theta[l, i, 1]), qr[i + 1]) # extra ZZ phase qc.cx(qr[i], qr[i + 1]) for i in range(n_qubits): qc.rx(float(x[i]), qr[i]) return qc def statevector_from_circuit(qc: QuantumCircuit) -> np.ndarray: """Simulate a circuit on Aer and return the statevector (complex128).""" return Statevector(qc).data def build_hadamard_test_circuit( n_qubits: int, state_q: np.ndarray, state_k: np.ndarray, imaginary: bool = False, ) -> QuantumCircuit: """ Improved Hadamard-test circuit for Re() or Im(). Wire semantics: ancilla[0] : Hadamard ancilla — controls SWAP q_reg[0..n-1] : register holding |Q⟩ k_reg[0..n-1] : register holding |K⟩ Math (Eq. 10-11): P0_real = (1 + Re()) / 2 → Re = 2·P0 − 1 P0_imag = (1 − Im()) / 2 → Im = 1 − 2·P0 If imaginary=True, the ancilla is rotated by S† before H (selects Im part). NOTE: In simulation is computed directly in PyTorch. This circuit is provided as the hardware-faithful diagnostic. """ anc = QuantumRegister(1, name="ancilla") reg_q = QuantumRegister(n_qubits, name="q_reg") reg_k = QuantumRegister(n_qubits, name="k_reg") cr = ClassicalRegister(1, name="anc_meas") qc = QuantumCircuit(anc, reg_q, reg_k, cr) qc.initialize(state_q.tolist(), reg_q) qc.initialize(state_k.tolist(), reg_k) if imaginary: qc.sdg(anc[0]) # S†: shifts to imaginary channel qc.h(anc[0]) for i in range(n_qubits): # controlled-SWAP (Fredkin) qc.cswap(anc[0], reg_q[i], reg_k[i]) qc.h(anc[0]) qc.measure(anc[0], cr[0]) return qc def hadamard_test_overlap( n_qubits: int, state_q: np.ndarray, state_k: np.ndarray, n_shots: int = 4096, ) -> complex: """ Run real + imaginary Hadamard-test circuits; return complex . Re() = 2·P(anc=0)_real − 1 Im() = 1 − 2·P(anc=0)_imag """ sim = AerSimulator(method="statevector") def _run(imag: bool) -> float: qc = build_hadamard_test_circuit(n_qubits, state_q, state_k, imaginary=imag) job = sim.run(qc, shots=n_shots) counts = job.result().get_counts() return counts.get("0", 0) / n_shots p0_re = _run(imag=False) p0_im = _run(imag=True) return complex(2.0 * p0_re - 1.0, 1.0 - 2.0 * p0_im) # ============================================================================= # SECTION 4 — FABLE HELPERS # ============================================================================= def gray_code_sequence(num_bits: int) -> List[int]: """Gray code sequence for 0..2^num_bits − 1.""" return [i ^ (i >> 1) for i in range(1 << num_bits)] def gray_permutation_indices(num_bits: int) -> np.ndarray: """ Return permutation P_G that maps binary-order indices to Gray-code order. argsort of the Gray sequence gives the inverse-Gray permutation: P_G[k] = position at which Gray(k) falls in the sorted Gray sequence. This is the correct reordering needed by the FABLE angle synthesis (place WHT-transformed angle k at the Gray-code-ordered position). """ gray = gray_code_sequence(num_bits) return np.argsort(gray) def fast_walsh_hadamard_transform(x: np.ndarray) -> np.ndarray: """ Fast Walsh–Hadamard transform (normalised, operates on a copy). Returns H · x / √N where H is the N×N Hadamard matrix. """ x = x.copy().astype(float) n = len(x) h = 1 while h < n: for i in range(0, n, h * 2): for j in range(i, i + h): x[j], x[j + h] = x[j] + x[j + h], x[j] - x[j + h] h <<= 1 return x / np.sqrt(n) def fable_angles_real(A: np.ndarray) -> np.ndarray: """ Compute θ̂ for the uniformly-controlled Ry rotations implementing O_A. For real A with |a_ij| ≤ 1: 1. θ_ij = arccos(A_ij) (rotation angle per entry) 2. θ̂ = P_G · WHT^{-1} · θ (Gray-permuted WHT) The WHT diagonalises the uniformly-controlled structure; Gray ordering maximises CNOT cancellations (FABLE §3). Returns: θ̂ of length N² (N = A.shape[0], must be power of 2). """ N = A.shape[0] n = int(np.log2(N)) theta = np.arccos(np.clip(A, -1.0, 1.0)).flatten() theta_hat = fast_walsh_hadamard_transform(theta) perm = gray_permutation_indices(2 * n) return theta_hat[perm] def fable_angles_complex(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ For complex A (|a_ij| ≤ 1): magnitude → θ̂_mag via arccos + WHT + Gray (Ry rotations) phase → θ̂_ph via negated arg + WHT + Gray (Rz rotations) Returns: (θ̂_mag, θ̂_ph), each of length N². """ N = A.shape[0] n = int(np.log2(N)) perm = gray_permutation_indices(2 * n) theta_mag = np.arccos(np.clip(np.abs(A), 0.0, 1.0)).flatten() theta_ph = (-np.angle(A)).flatten() theta_hat_mag = fast_walsh_hadamard_transform(theta_mag)[perm] theta_hat_ph = fast_walsh_hadamard_transform(theta_ph )[perm] return theta_hat_mag, theta_hat_ph def compress_uniformly_controlled_rotations( theta_hat: np.ndarray, delta_c: float ) -> Tuple[np.ndarray, List[int]]: """ Zero out angles with |θ̂_i| ≤ delta_c (FABLE compression, Theorem 2). Returns: (compressed_theta_hat, list_of_kept_indices). The kept_indices list gives the angles that survive compression; its length is an upper bound on the number of Ry/Rz gates required. """ kept_mask = np.abs(theta_hat) > delta_c theta_comp = theta_hat.copy() theta_comp[~kept_mask] = 0.0 return theta_comp, np.where(kept_mask)[0].tolist() def simplify_cnot_parity(control_sequence: List[int]) -> List[int]: """ Keep only control qubits that appear an odd number of times in a sequence of same-target CNOTs (pairs cancel by CNOT² = I). """ cnt = Counter(control_sequence) return [q for q, c in cnt.items() if c % 2 == 1] def estimate_fable_gate_count( N: int, compressed: bool, kept_angles: Optional[int] = None, ) -> Dict[str, int]: """ Estimate CNOT and rotation counts for a FABLE block-encoding circuit. Full circuit (FABLE §3): CNOT ≤ 2N²−2, Ry = N², Rz = N² (complex). Compressed: counts scale with |kept_angles|. """ max_cnot = 2 * N * N - 2 max_ry = N * N if compressed and kept_angles is not None: cnot_est = max(0, 2 * kept_angles - 2) if kept_angles > 1 else 0 return { "cnot": cnot_est, "ry": kept_angles, "rz": 0, "max_cnot": max_cnot, "max_ry": max_ry, } return { "cnot": max_cnot, "ry": max_ry, "rz": max_ry, "max_cnot": max_cnot, "max_ry": max_ry, } # ============================================================================= # SECTION 5 — FABLE MATRIX QUERY ORACLE (hardware-reference circuit) # ============================================================================= class FABLEMatrixQueryOracle: """ Qiskit circuit for the FABLE matrix query oracle O_A (Definition 2). Supports real and complex matrices with optional compression. This circuit is the HARDWARE DEPLOYMENT REFERENCE — in simulation the weighted sum is computed directly (see FABLEBlockEncoder / FABLECLCU) without executing this circuit. Wire semantics: anc[0] : ancilla (block-encoding qubit) row[0..n-1] : row index register (n = log2(N)) col[0..n-1] : column index register """ def __init__( self, matrix: np.ndarray, delta_c: float = 0.0, is_complex: bool = False, ): self.matrix = matrix self.N = matrix.shape[0] self.n = int(np.log2(self.N)) self.delta_c = delta_c self.is_complex = is_complex self.compressed = delta_c > 0 self.kept_indices: Optional[List[int]] = None if not is_complex: self.theta_hat = fable_angles_real(matrix) if self.compressed: self.theta_hat, self.kept_indices = \ compress_uniformly_controlled_rotations(self.theta_hat, delta_c) else: self.theta_hat_mag, self.theta_hat_ph = fable_angles_complex(matrix) if self.compressed: self.theta_hat_mag, km = compress_uniformly_controlled_rotations( self.theta_hat_mag, delta_c) self.theta_hat_ph, kp = compress_uniformly_controlled_rotations( self.theta_hat_ph, delta_c) self.kept_indices = list(set(km) | set(kp)) def build_circuit(self) -> QuantumCircuit: """Return Qiskit circuit O_A on (anc, row_reg, col_reg).""" qr_anc = QuantumRegister(1, "anc") qr_row = QuantumRegister(self.n, "row") qr_col = QuantumRegister(self.n, "col") qc = QuantumCircuit(qr_anc, qr_row, qr_col) gray = gray_code_sequence(2 * self.n) for idx in range(len(gray) - 1): diff = gray[idx] ^ gray[idx + 1] ctrl_bit = diff.bit_length() - 1 # 0 .. 2n-1 # Map ctrl_bit to a physical qubit if ctrl_bit < self.n: ctrl_q = qr_row[ctrl_bit] else: ctrl_q = qr_col[ctrl_bit - self.n] qc.cx(ctrl_q, qr_anc[0]) if not self.is_complex: angle = self.theta_hat[idx] if idx < len(self.theta_hat) else 0.0 if abs(angle) > 1e-12: qc.ry(2.0 * angle, qr_anc[0]) else: mag_a = (self.theta_hat_mag[idx] if idx < len(self.theta_hat_mag) else 0.0) ph_a = (self.theta_hat_ph[idx] if idx < len(self.theta_hat_ph) else 0.0) if abs(mag_a) > 1e-12: qc.ry(2.0 * mag_a, qr_anc[0]) if abs(ph_a) > 1e-12: qc.rz(2.0 * ph_a, qr_anc[0]) return qc def compression_stats(self) -> Dict[str, Any]: """Return compression ratio and error bound.""" n_total = self.N ** 2 n_kept = len(self.kept_indices) if self.kept_indices is not None else n_total eps_bound = (n_total ** 1.5) * self.delta_c # Theorem 2 bound return { "n_total_angles": n_total, "n_kept_angles": n_kept, "compression_ratio": 1.0 - n_kept / n_total, "error_bound": eps_bound, "delta_c": self.delta_c, } # ============================================================================= # SECTION 6 — FABLE BLOCK ENCODER (simulation shortcut + hardware reference) # ============================================================================= class FABLEBlockEncoder(nn.Module): """ Block-encoding of matrix A via the FABLE oracle. SIMULATION SHORTCUT (default, use_simulation_shortcut=True): forward(x) = A @ x — exact, avoids circuit overhead. This is equivalent to the block-encoding action on the subspace where the ancilla is measured as |0⟩. HARDWARE PATH (use_simulation_shortcut=False): Raises NotImplementedError — placeholder for future QPU deployment. The oracle circuit from FABLEMatrixQueryOracle should be compiled and executed, with post-selection on ancilla = |0⟩. Args: A : (N, N) complex or real torch.Tensor delta_c : compression threshold (0 = no compression) use_simulation_shortcut : True → A@x | False → NotImplemented """ def __init__( self, A: torch.Tensor, delta_c: float = 0.0, use_simulation_shortcut: bool = True, ): super().__init__() self.register_buffer("A", A) self.N = A.shape[0] self.delta_c = delta_c self.use_sim = use_simulation_shortcut # Build oracle (for hardware reference / compression diagnostics) A_np = A.detach().cpu().numpy() self.oracle = FABLEMatrixQueryOracle( A_np, delta_c=delta_c, is_complex=torch.is_complex(A) ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply block-encoded A to state x. Args: x : (batch, N) cdouble statevector Returns: (batch, N) cdouble — A @ x (un-normalised; caller normalises) """ if self.use_sim: # Simulation shortcut: exact linear algebra # On hardware: execute oracle circuit + post-select on ancilla=0 A = self.A.to(dtype=x.dtype, device=x.device) return torch.einsum("ij,bj->bi", A, x) else: raise NotImplementedError( "Hardware FABLE execution is not implemented in this simulation. " "Use FABLEMatrixQueryOracle.build_circuit() to obtain the " "deployable Qiskit circuit." ) # ============================================================================= # SECTION 7 — FABLECLCU: FABLE-backed Complex Linear Combination of Unitaries # ============================================================================= class FABLECLCU(nn.Module): """ Complex Linear Combination of Unitaries using FABLE block-encoding. This replaces the legacy CLCU class with FABLE-informed aggregation. The FABLE upgrade adds: (1) FABLE angle synthesis for hardware-circuit reference/diagnostics. (2) Compression: when use_compression=True and delta_c > 0, coefficients with |α_j · a_{k,j}| ≤ delta_c are zeroed — directly simulating the sparse gate pattern of the compressed FABLE circuit. This DOES affect the computed output. (3) Error-bound reporting (Theorem 2) when compression is active. In simulation the aggregation is a weighted sum (mathematically equivalent to the FABLE block-encoding shortcut); the oracle circuit is the hardware-deployment reference produced by FABLEMatrixQueryOracle. Parameters: n_unitaries : number of states / attention heads (= n_patches) delta_c : compression threshold for coefficient zeroing use_compression : if False, delta_c is ignored """ def __init__( self, n_unitaries: int, delta_c: float = 0.0, use_compression: bool = False, ): super().__init__() self.n_unitaries = n_unitaries self.delta_c = delta_c if use_compression else 0.0 self.use_compression = use_compression # Trainable complex coefficients α_j = r_j · exp(i·θ_j) self._log_r = nn.Parameter(torch.zeros(n_unitaries, dtype=torch.double)) self.phase = nn.Parameter(torch.zeros(n_unitaries, dtype=torch.double)) @property def alpha(self) -> torch.Tensor: """α_j = softplus(log_r_j) · exp(i·phase_j) — (n_unitaries,) cdouble.""" r = F.softplus(self._log_r) return polar_to_complex(r, self.phase) def forward( self, states: List[torch.Tensor], attention_weights: torch.Tensor, ) -> torch.Tensor: """ FABLE-backed aggregation. Stage 1 — Coefficient construction: coeffs_j = α_j · a_{k,j} (batch, n_unitaries) cdouble Stage 2 — Compression (if enabled): Zero out coeffs_j where |coeffs_j| ≤ delta_c. This simulates the sparse gate count of the compressed FABLE circuit and DOES affect the numerical output. Stage 3 — Weighted superposition (simulation shortcut): raw = Σ_j coeffs_j · |V_j⟩ Stage 4 — Post-selection normalisation: |S⟩ = raw / ‖raw‖ Args: states : list of n_unitaries (batch, dim) cdouble tensors attention_weights: (batch, n_unitaries) cdouble Returns: |S⟩ : (batch, dim) cdouble — normalised aggregated state """ alpha = self.alpha.to(device=attention_weights.device) # (n_unitaries,) batch = states[0].shape[0] dim = states[0].shape[-1] device = states[0].device # Stage 1: α_j · a_{k,j} coeffs = alpha.unsqueeze(0) * attention_weights # (batch, n_unitaries) # Stage 2: compression — zero small coefficients if self.use_compression and self.delta_c > 0.0: mask = coeffs.abs() > self.delta_c # (batch, n_unitaries) bool coeffs = coeffs * mask.to(dtype=coeffs.dtype) n_kept = mask.float().sum(dim=-1).mean().item() eps_bound = (self.n_unitaries ** 1.5) * self.delta_c warnings.warn( f"[FABLECLCU] Compression δ_c={self.delta_c:.3e}: " f"kept {n_kept:.1f}/{self.n_unitaries} coeffs on avg; " f"FABLE Theorem-2 error bound ≈ {eps_bound:.3e}" ) # Stage 3: weighted sum — simulation shortcut # Equivalent to the FABLE block-encoding action on the |0⟩-ancilla subspace. raw = torch.zeros(batch, dim, dtype=torch.cdouble, device=device) for j, v_j in enumerate(states): raw = raw + coeffs[:, j].unsqueeze(-1) * v_j # Stage 4: post-selection normalisation norm = torch.norm(raw, dim=-1, keepdim=True) S = raw / (norm + _EPS) assert_normalised(S, "FABLECLCU output |S⟩") return S def fable_diagnostics( self, coeff_matrix: np.ndarray, ) -> Dict[str, Any]: """ Run FABLE angle synthesis + compression on a given coefficient matrix. Returns diagnostics without running a circuit. Args: coeff_matrix : (n_unitaries, n_unitaries) real or complex numpy array Returns: dict with keys: theta_hat, kept_indices, compression_ratio, error_bound, gate_count_estimate """ N = coeff_matrix.shape[0] is_cplx = np.iscomplexobj(coeff_matrix) oracle = FABLEMatrixQueryOracle( coeff_matrix, delta_c=self.delta_c, is_complex=is_cplx ) stats = oracle.compression_stats() gc = estimate_fable_gate_count( N, compressed=(self.delta_c > 0), kept_angles=stats["n_kept_angles"], ) stats["gate_count_estimate"] = gc return stats # ============================================================================= # SECTION 8 — QUANTUM FEATURE MAP # ============================================================================= class QuantumFeatureMap(nn.Module): """ Trainable quantum feature map: classical data → n-qubit statevector. Gradient computation: The statevector is computed analytically in PyTorch (equivalent to backprop through Aer) so autograd works. The Qiskit circuit is stored as a hardware-deployment reference. Wire semantics: qubit[0..n_qubits-1] : data register (one wire per PCA feature) """ def __init__(self, n_qubits: int, n_layers: int = 2): super().__init__() self.n_qubits = n_qubits self.n_layers = n_layers # (n_layers, n_qubits, 2) — [Ry angle, Rz angle] per qubit per layer self.theta = nn.Parameter( torch.randn(n_layers, n_qubits, 2, dtype=torch.double) * np.pi ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x : (batch, n_qubits) float/double, values in [0, π] Returns: states : (batch, 2**n_qubits) cdouble — each row is a unit vector """ x = x.double() return torch.stack([self._compute_statevector(x[b]) for b in range(x.shape[0])], dim=0) def _compute_statevector(self, x_single: torch.Tensor) -> torch.Tensor: """ Differentiable statevector |ψ(x, θ)⟩ (mirrors the Qiskit circuit). """ dim = 2 ** self.n_qubits state = torch.zeros(dim, dtype=torch.cdouble, device=x_single.device) state[0] = 1.0 + 0j # |0…0⟩ # First data encoding for i in range(self.n_qubits): state = _apply_single_qubit(state, self.n_qubits, i, _rx_matrix(x_single[i])) # Variational layers for l in range(self.n_layers): for i in range(self.n_qubits): state = _apply_single_qubit(state, self.n_qubits, i, _ry_matrix(self.theta[l, i, 0])) state = _apply_single_qubit(state, self.n_qubits, i, _rz_matrix(self.theta[l, i, 1])) for i in range(self.n_qubits - 1): state = _apply_cnot(state, self.n_qubits, i, i + 1) state = _apply_single_qubit(state, self.n_qubits, i + 1, _rz_matrix(self.theta[l, i, 1])) state = _apply_cnot(state, self.n_qubits, i, i + 1) # Second data encoding for i in range(self.n_qubits): state = _apply_single_qubit(state, self.n_qubits, i, _rx_matrix(x_single[i])) return state def get_qiskit_circuit(self, x: np.ndarray) -> QuantumCircuit: """Return the equivalent Qiskit circuit for hardware deployment.""" return build_feature_map_circuit( self.n_qubits, x, self.theta.detach().cpu().numpy(), self.n_layers, ) # ============================================================================= # SECTION 9 — SHARED Q/K FEATURE MAP # ============================================================================= class SharedQKFeatureMap(nn.Module): """ Shared-backbone feature map for Q and K (Hilbert-space alignment). Architecture: backbone : QuantumFeatureMap(n_qubits, n_layers) — shared parameters q_head : QuantumFeatureMap(n_qubits, n_head_layers) — Q-specific k_head : QuantumFeatureMap(n_qubits, n_head_layers) — K-specific v_map : QuantumFeatureMap(n_qubits, n_layers) — independent Alignment strategy: • tie_heads=True : q_head = k_head = backbone → Q ≡ K (hardest alignment) • tie_heads=False : q_head and k_head are separate BUT the alignment regulariser L_align = mean‖Q − K‖² is the binding force. The shared backbone object exists as a parameter pool for the "same basis origin" intention; in practice tie_heads=False relies entirely on the alignment loss. NOTE: The backbone statevector is NOT used as input to the heads — each head circuit independently encodes x from |0⟩. Both heads therefore start from identical classical input; their parameter similarity (tie_heads) or the alignment loss (tie_heads=False) keeps their outputs close. """ def __init__( self, n_qubits: int, n_layers: int = 2, n_head_layers: int = 1, tie_heads: bool = False, ): super().__init__() self.n_qubits = n_qubits self.tie_heads = tie_heads self.backbone = QuantumFeatureMap(n_qubits, n_layers) self.v_map = QuantumFeatureMap(n_qubits, n_layers) if tie_heads: # Strongest alignment: Q and K share all parameters self.q_head = self.backbone self.k_head = self.backbone else: self.q_head = QuantumFeatureMap(n_qubits, n_head_layers) self.k_head = QuantumFeatureMap(n_qubits, n_head_layers) def forward( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x : (batch, n_qubits) float or double Returns: Q, K, V : each (batch, 2**n_qubits) cdouble, unit-norm """ Q = safe_normalise(self.q_head(x)) K = safe_normalise(self.k_head(x)) V = safe_normalise(self.v_map(x)) assert_normalised(Q, "Q") assert_normalised(K, "K") assert_normalised(V, "V") return Q, K, V def alignment_loss( self, Q: torch.Tensor, K: torch.Tensor ) -> torch.Tensor: """ L_align = mean ‖Q − K‖² (real, positive, differentiable). Penalises Hilbert-space drift between Q and K to keep physically interpretable. """ diff = Q - K return torch.mean(torch.sum(diff.abs() ** 2, dim=-1)) # ============================================================================= # SECTION 10 — COMPLEX SIMILARITY # ============================================================================= class ComplexSimilarity(nn.Module): """ Compute = Re() + i·Im() (Eq. 9). In simulation: Σ_i conj(K_i) · Q_i (exact inner product). On hardware : use build_hadamard_test_circuit() per pair. Modes: use_real_only=False (default) : full complex attention use_real_only=True : magnitude-only ablation """ def __init__(self, use_real_only: bool = False, log_stats: bool = False): super().__init__() self.use_real_only = use_real_only self.log_stats = log_stats self.last_stats: Dict[str, float] = {} def forward(self, Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor: """ Args: Q, K : (batch, dim) cdouble — must be unit-norm Returns: attn : (batch,) cdouble (or double if use_real_only) """ assert_normalised(Q, "Q (similarity input)") assert_normalised(K, "K (similarity input)") inner = torch.sum(torch.conj(K) * Q, dim=-1) # (batch,) cdouble if self.log_stats: self.last_stats = { "mean_||": inner.abs().mean().item(), "mean_Re()": inner.real.mean().item(), "mean_Im()": inner.imag.mean().item(), } return inner.abs() if self.use_real_only else inner # ============================================================================= # SECTION 11 — QUANTUM MULTI-HEAD ATTENTION (uses FABLECLCU) # ============================================================================= class QuantumMultiHeadAttention(nn.Module): """ Multi-head quantum attention with FABLE-backed aggregation. Per head: 1. Each patch i → Q_i, K_i, V_i via SharedQKFeatureMap 2. Attention matrix: A[k, j] = (full complex) 3. Per-query aggregation: |S_k⟩ = FABLECLCU(V_list, A[k, :]) 4. Head output = mean_k |S_k⟩ (renormalised) Final output = Σ_h α_h · head_h output (renormalised) Alignment loss = mean over heads and patches of ‖Q_i − K_i‖². This is computed here and returned to QCSAModel for inclusion in the total training loss. CLCU modules are initialised lazily on the first forward call (n_patches is data-dependent). Subsequent calls reuse cached modules unless n_patches changes (rare in practice). Args: clcu_compression_delta : δ_c forwarded to every FABLECLCU instance """ def __init__( self, n_qubits: int, n_heads: int, n_layers: int = 2, use_complex_attention: bool = True, clcu_compression_delta: float = 0.0, n_patches: int = -1, # NEW: if > 0, eagerly initialise clcu_heads immediately ): super().__init__() self.n_qubits = n_qubits self.n_heads = n_heads self.use_complex_attention = use_complex_attention self.clcu_compression_delta = clcu_compression_delta # ── Persistent buffer so _clcu_n_patches survives state_dict save/load ─ # After load_state_dict, this buffer will be restored to the pre-save # value (N = num_assets), so _ensure_clcu will correctly detect that # clcu_heads is already initialised and skip re-random-init. self.register_buffer( '_clcu_n_patches_buf', torch.tensor(-1, dtype=torch.long) ) self.qkv_maps = nn.ModuleList([ SharedQKFeatureMap(n_qubits, n_layers) for _ in range(n_heads) ]) self.similarity = ComplexSimilarity( use_real_only=not use_complex_attention, log_stats=True, ) # Head-combination weights: α_h = r_h · exp(i·φ_h) self._head_log_r = nn.Parameter(torch.zeros(n_heads, dtype=torch.double)) self._head_phase = nn.Parameter(torch.zeros(n_heads, dtype=torch.double)) # ── Eager CLCU init (n_patches known at construction time) ─────────── # When num_assets is passed in from AXRVINet, we initialise clcu_heads # immediately so they are part of state_dict from the very first save. # This eliminates the lazy-init checkpoint race entirely. if n_patches > 0: self._ensure_clcu(n_patches) @property def head_weights(self) -> torch.Tensor: r = F.softplus(self._head_log_r) return polar_to_complex(r, self._head_phase) # (n_heads,) cdouble def _ensure_clcu(self, n_patches: int) -> None: """ Initialise FABLECLCU modules for the given n_patches count. CHECKPOINT-SAFE: uses ``_clcu_n_patches_buf`` (a persistent buffer that survives state_dict save/load) as the source of truth. After ``load_state_dict`` restores both ``_clcu_n_patches_buf = n_patches`` and the actual ``clcu_heads`` weights, this guard detects the match and returns immediately — preventing the silent random-reinit bug that previously discarded all loaded FABLECLCU weights. """ current = int(self._clcu_n_patches_buf.item()) if current == n_patches: # Buffer and n_patches agree → clcu_heads already correctly set up. # This branch is taken after load_state_dict restores both the buffer # and the weights, so we must NOT reinitialise here. return if current != -1: # n_patches changed (e.g. batch-size difference at inference) — # log a warning so it is visible but do not treat this as a bug. logger.warning( f"[QCSAM] _ensure_clcu: n_patches changed {current} → {n_patches}. " "FABLECLCU re-initialised with fresh weights. " "This is expected only if the number of tracked assets changed." ) self.clcu_heads = nn.ModuleList([ FABLECLCU( n_patches, delta_c=self.clcu_compression_delta, use_compression=(self.clcu_compression_delta > 0), ) for _ in range(self.n_heads) ]) self._clcu_n_patches_buf.fill_(n_patches) def forward( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: """ Args: x : (batch, n_patches, n_qubits) double Returns: final_state : (batch, 2**n_qubits) cdouble — normalised align_loss : scalar double — mean alignment loss across heads/patches diagnostics : dict with per-head stats """ x = x.double() batch_size, n_patches, _ = x.shape self._ensure_clcu(n_patches) head_outputs = [] total_align = torch.tensor(0.0, dtype=torch.double, device=x.device) diagnostics = {"heads": []} for h in range(self.n_heads): Q_list, K_list, V_list = [], [], [] align_h = torch.tensor(0.0, dtype=torch.double, device=x.device) # ── per-patch Q, K, V ────────────────────────────────────────── for i in range(n_patches): q_i, k_i, v_i = self.qkv_maps[h](x[:, i, :]) Q_list.append(q_i) K_list.append(k_i) V_list.append(v_i) align_h = align_h + self.qkv_maps[h].alignment_loss(q_i, k_i) align_h = align_h / n_patches total_align = total_align + align_h # ── attention matrix A[k, j] = ──────────────────── attn_matrix = torch.zeros( batch_size, n_patches, n_patches, dtype=torch.cdouble, device=x.device, ) for k in range(n_patches): for j in range(n_patches): attn_matrix[:, k, j] = self.similarity(Q_list[k], K_list[j]) attn_norm = attn_matrix.abs().norm().item() if attn_norm < 1e-6 or attn_norm > 1e4: warnings.warn( f"[QCSAM] Head {h}: attention matrix norm = {attn_norm:.4f}; " "possible attention collapse." ) head_diag = { "attn_matrix_norm": attn_norm, "align_loss_h": align_h.item(), **self.similarity.last_stats, } diagnostics["heads"].append(head_diag) # ── FABLE-backed aggregation per query ───────────────────────── S_patches = [ self.clcu_heads[h](V_list, attn_matrix[:, k, :]) for k in range(n_patches) ] head_state = torch.stack(S_patches, dim=1).mean(dim=1) head_state = safe_normalise(head_state) head_outputs.append(head_state) # ── combine heads ────────────────────────────────────────────────── hw = self.head_weights.to(x.device) final_state = sum(hw[h] * head_outputs[h] for h in range(self.n_heads)) final_state = safe_normalise(final_state) assert_normalised(final_state, "MHA output") mean_align = total_align / self.n_heads return final_state, mean_align, diagnostics # ============================================================================= # SECTION 12 — QUANTUM FEEDFORWARD NETWORK (QFFN) [single definition] # ============================================================================= class QFFN(nn.Module): """ Quantum Feedforward Network (Eq. 31). Wire semantics: qubit[0..n_qubits-1] : data register (same labelling as feature map) Layer structure per layer l: Rz(θ[l,i,0]) → Ry(θ[l,i,1]) → Rz(θ[l,i,2]) on each qubit i CNOT entanglers: even pairs, then odd pairs (brickwork) Parameters: (n_layers, n_qubits, 3) """ def __init__(self, n_qubits: int, n_layers: int = 2): super().__init__() self.n_qubits = n_qubits self.n_layers = n_layers self.params = nn.Parameter( torch.randn(n_layers, n_qubits, 3, dtype=torch.double) * np.pi ) def forward(self, state: torch.Tensor) -> torch.Tensor: """ Args: state : (batch, 2**n_qubits) cdouble Returns: state : (batch, 2**n_qubits) cdouble — renormalised """ batch = state.shape[0] out = [] for b in range(batch): s = state[b] for l in range(self.n_layers): for i in range(self.n_qubits): s = _apply_single_qubit(s, self.n_qubits, i, _rz_matrix(self.params[l, i, 0])) s = _apply_single_qubit(s, self.n_qubits, i, _ry_matrix(self.params[l, i, 1])) s = _apply_single_qubit(s, self.n_qubits, i, _rz_matrix(self.params[l, i, 2])) for i in range(0, self.n_qubits - 1, 2): # even pairs s = _apply_cnot(s, self.n_qubits, i, i + 1) for i in range(1, self.n_qubits - 1, 2): # odd pairs s = _apply_cnot(s, self.n_qubits, i, i + 1) out.append(s) return safe_normalise(torch.stack(out, dim=0)) def get_qiskit_circuit(self, init_state: np.ndarray) -> QuantumCircuit: """Hardware-reference Qiskit circuit for this QFFN instance.""" theta_np = self.params.detach().cpu().numpy() qr = QuantumRegister(self.n_qubits, name="data") qc = QuantumCircuit(qr) qc.initialize(init_state.tolist(), qr) for l in range(self.n_layers): for i in range(self.n_qubits): qc.rz(float(theta_np[l, i, 0]), qr[i]) qc.ry(float(theta_np[l, i, 1]), qr[i]) qc.rz(float(theta_np[l, i, 2]), qr[i]) for i in range(0, self.n_qubits - 1, 2): qc.cx(qr[i], qr[i + 1]) for i in range(1, self.n_qubits - 1, 2): qc.cx(qr[i], qr[i + 1]) return qc # ============================================================================= # SECTION 13 — QUANTUM MEASUREMENT (Eq. 33) # ============================================================================= class QuantumMeasurement(nn.Module): """ Projective measurement via Pauli observables (Eq. 33). ŷ_k = (1 + <ψ|M_k|ψ>) / Σ_j (1 + <ψ|M_j|ψ>) For n_classes = 2 : M = {Z₀, −Z₀} (numerators sum to 2) For n_classes = 3 : M = {X₀, Y₀, Z₀} For n_classes > 3 : tensor-products of Paulis on first two qubits Probabilities are guaranteed ≥ 0 (since <ψ|M|ψ> ∈ [−1, 1]) and normalised to 1 by the denominator. """ def __init__(self, n_qubits: int, n_classes: int): super().__init__() self.n_qubits = n_qubits self.n_classes = n_classes obs_mats = self._build_observable_matrices() for k, mat in enumerate(obs_mats): self.register_buffer(f"M_{k}", torch.tensor(mat, dtype=torch.cdouble)) self.n_observables = len(obs_mats) def _build_observable_matrices(self) -> List[np.ndarray]: I = np.eye(2, dtype=complex) X = np.array([[0, 1], [1, 0]], dtype=complex) Y = np.array([[0, -1j], [1j, 0]], dtype=complex) Z = np.array([[1, 0], [0, -1]], dtype=complex) def embed(op, qubit): full = np.eye(1, dtype=complex) for q in range(self.n_qubits): full = np.kron(full, op if q == qubit else I) return full if self.n_classes == 2: M0 = embed(Z, 0) return [M0, -M0] elif self.n_classes == 3: return [embed(P, 0) for P in (X, Y, Z)] else: paulis = [X, Y, Z] obs = [] for Pi in paulis: for Pj in paulis: if len(obs) >= self.n_classes: break M = np.kron(Pi, Pj) for _ in range(self.n_qubits - 2): M = np.kron(M, I) obs.append(M) return obs def forward(self, state: torch.Tensor) -> torch.Tensor: """ Args: state : (batch, 2**n_qubits) cdouble — unit-norm Returns: y_hat : (batch, n_classes) double — valid probability distribution """ m_vals = [] for k in range(self.n_observables): M_k = getattr(self, f"M_{k}") Mpsi = torch.einsum("ij,bj->bi", M_k, state) # <ψ|M_k|ψ> is real-valued for Hermitian M_k and unit-norm ψ m_k = torch.einsum("bi,bi->b", torch.conj(state), Mpsi).real m_vals.append(m_k) m = torch.stack(m_vals, dim=-1) # (batch, n_classes) double numerators = 1.0 + m # ∈ [0, 2] per entry y_hat = numerators / (numerators.sum(dim=-1, keepdim=True) + _EPS) return y_hat # ============================================================================= # SECTION 14 — QCSA MODEL # ============================================================================= # ============================================================================= # SECTION 15 — DATA PREPROCESSING # ============================================================================= def preprocess_data( images: np.ndarray, labels: np.ndarray, n_qubits: int, n_patches: int, train: bool = True, pca=None, scaler=None, ): """ Preprocess image patches for QCSAM input. Pipeline: 1. Extract √n_patches × √n_patches non-overlapping patches 2. PCA to n_qubits components 3. StandardScaler normalisation 4. Min-max rescale to [0, π] (suitable for Rx data encoding) Returns: (X_tensor, y_tensor, pca, scaler) X_tensor : (N, n_patches, n_qubits) float32 y_tensor : (N,) int64 """ N, H, W = images.shape sq = int(np.sqrt(n_patches)) patch_h = H // sq patch_w = W // sq patches = [] for i in range(0, H, patch_h): for j in range(0, W, patch_w): patch = images[:, i:i + patch_h, j:j + patch_w].reshape(N, -1) patches.append(patch) patches = np.stack(patches, axis=1) # (N, n_patches, patch_h*patch_w) flat = patches.reshape(N * n_patches, -1) if train: pca = PCA(n_components=n_qubits) reduced = pca.fit_transform(flat) scaler = StandardScaler() reduced = scaler.fit_transform(reduced) else: reduced = pca.transform(flat) reduced = scaler.transform(reduced) reduced = reduced.reshape(N, n_patches, n_qubits) mn, mx = reduced.min(), reduced.max() reduced = (reduced - mn) / (mx - mn + _EPS) * np.pi return ( torch.tensor(reduced, dtype=torch.float32), torch.tensor(labels, dtype=torch.long), pca, scaler, ) # ====================================================================================== # II. BIOPHYSICAL MODULES (Refactored for Stability) # ====================================================================================== class HodgkinHuxleyGate(nn.Module): """ Hodgkin-Huxley activation gate with anti-saturation and conductance clamping. Replaces GELU throughout the network. """ def __init__(self, dim: int, max_conductance: float = 1.2, min_conductance: float = 0.01): super().__init__() self.norm_m = nn.LayerNorm(dim) self.norm_h = nn.LayerNorm(dim) self.norm_n = nn.LayerNorm(dim) self.fc_m = nn.Linear(dim, dim) self.fc_h = nn.Linear(dim, dim) self.fc_n = nn.Linear(dim, dim) # Log conductances (ensures positivity via exp) self.log_g_Na = nn.Parameter(torch.full((dim,), math.log(0.8))) self.log_g_K = nn.Parameter(torch.full((dim,), math.log(0.2))) self.log_g_L = nn.Parameter(torch.full((dim,), math.log(0.05))) self.max_g = max_conductance self.min_g = min_conductance # Initialize biases for resting state nn.init.constant_(self.fc_h.bias, 1.5) nn.init.xavier_uniform_(self.fc_m.weight, gain=0.5) nn.init.xavier_uniform_(self.fc_h.weight, gain=0.5) nn.init.xavier_uniform_(self.fc_n.weight, gain=0.5) def forward(self, x: torch.Tensor) -> torch.Tensor: # Gate pre-activations with layer norm for stability m_pre = self.fc_m(self.norm_m(x)) h_pre = self.fc_h(self.norm_h(x)) n_pre = self.fc_n(self.norm_n(x)) m = torch.sigmoid(m_pre) h = torch.sigmoid(-h_pre) n = torch.sigmoid(n_pre) # Clamped conductances g_Na = torch.clamp(torch.exp(self.log_g_Na), self.min_g, self.max_g) g_K = torch.clamp(torch.exp(self.log_g_K), self.min_g, self.max_g) g_L = torch.clamp(torch.exp(self.log_g_L), self.min_g, self.max_g) # HH current: (excitatory - inhibitory - leak) * x hh_factor = (g_Na * (m ** 3) * h) - (g_K * (n ** 4)) - g_L return hh_factor * x * 0.5 # Scaling to match return magnitudes class DendriticFFN(nn.Module): """ Multi-compartment dendritic processing with NMDA coincidence detection. Includes branch regularization and entropy monitoring. """ def __init__(self, d_model: int, ffn_mult: int = 4, n_branches: int = 4, dropout: float = DEFAULT_CONFIG.dropout, gate_reg_weight: float = DEFAULT_CONFIG.lambda_gate_reg): super().__init__() hidden_dim = d_model * ffn_mult branch_dim = hidden_dim // n_branches self.n_branches = n_branches self.gate_reg_weight = gate_reg_weight self.register_buffer('_gate_entropy', torch.tensor(0.0)) # Projections per branch self.basal_proj = nn.ModuleList([nn.Linear(d_model, branch_dim) for _ in range(n_branches)]) self.apical_proj = nn.ModuleList([nn.Linear(d_model, branch_dim) for _ in range(n_branches)]) self.nmda_gate = nn.ModuleList([nn.Linear(d_model, branch_dim) for _ in range(n_branches)]) # Branch dropout (2D: batch, branches) self.branch_dropout = nn.Dropout2d(dropout) self.soma_hh = HodgkinHuxleyGate(hidden_dim) self.soma_norm = nn.LayerNorm(hidden_dim) self.out_proj = nn.Linear(hidden_dim, d_model) self.dropout = nn.Dropout(dropout) # Initialize NMDA gate biases for gate in self.nmda_gate: nn.init.constant_(gate.bias, 0.0) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: B, N, _ = x.shape basal_list = [] apical_list = [] gate_activations = [] for basal, apical, gate in zip(self.basal_proj, self.apical_proj, self.nmda_gate): basal_out = basal(x) gate_val = torch.sigmoid(gate(x)) apical_out = apical(x) * gate_val basal_list.append(basal_out) apical_list.append(apical_out) gate_activations.append(gate_val) # Stack and apply branch dropout basal_stack = torch.stack(basal_list, dim=2) # (B, N, branches, branch_dim) apical_stack = torch.stack(apical_list, dim=2) gate_stack = torch.stack(gate_activations, dim=2) if self.training: # Dropout2d zeroes entire channel slices (dim=1). We want to drop entire # *branches*, so permute branches into dim=1 before applying, then restore. # (B, N, branches, branch_dim) → (B, branches, N, branch_dim) → drop → back basal_stack = self.branch_dropout( basal_stack.permute(0, 2, 1, 3) ).permute(0, 2, 1, 3) apical_stack = self.branch_dropout( apical_stack.permute(0, 2, 1, 3) ).permute(0, 2, 1, 3) # Flatten branches basal_flat = basal_stack.view(B, N, -1) apical_flat = apical_stack.view(B, N, -1) # Gate regularization entropy (maximize diversity) if self.training and self.gate_reg_weight > 0: mean_gate = gate_stack.mean(dim=(0, 1, 3)) # (branches,) entropy = -(mean_gate * torch.log(mean_gate + 1e-8)).sum() self._gate_entropy = -entropy * self.gate_reg_weight else: self._gate_entropy = torch.tensor(0.0, device=x.device) soma_input = basal_flat + apical_flat soma_act = self.soma_hh(soma_input) soma_act = self.soma_norm(soma_act) out = self.out_proj(self.dropout(soma_act)) return out, self._gate_entropy class STDPAdjacencyLayer(nn.Module): """ Spike-Timing Dependent Plasticity with Hebbian decay and regime-aware reset. Enables online adaptation of GNN edge weights. """ def __init__(self, n_heads: int, max_assets: int = 16, lr_stdp: float = DEFAULT_CONFIG.stdp_lr, decay: float = DEFAULT_CONFIG.stdp_decay, hebbian_decay: float = DEFAULT_CONFIG.stdp_hebbian_decay, spike_ema: float = 0.9, spike_threshold: float = 0.5): super().__init__() self.n_heads = n_heads self.max_assets = max_assets self.lr_stdp = lr_stdp self.decay = decay self.hebbian_decay = hebbian_decay self.spike_ema = spike_ema self.spike_threshold = spike_threshold # STDP delta accumulator (non-trainable) self.register_buffer('stdp_delta', torch.zeros(n_heads, max_assets, max_assets)) self.register_buffer('running_spike', torch.zeros(max_assets)) self.register_buffer('step_count', torch.tensor(0)) @torch.no_grad() def stdp_update(self, asset_embeddings: torch.Tensor, regime_weights: Optional[torch.Tensor] = None) -> None: """ Update STDP weights based on asset activity patterns. Args: asset_embeddings: (B, N, d_model) regime_weights: (B, N) optional multiplier (e.g., crash probability) """ B, N, _ = asset_embeddings.shape # Clamp N to max_assets — extra assets are simply not tracked N = min(N, self.max_assets) # Use only the first N (clamped) assets asset_emb_n = asset_embeddings[:, :N, :] # Activity proxy: L2 norm normalized across batch activity = asset_emb_n.norm(dim=-1) mu = activity.mean(dim=0, keepdim=True) std = activity.std(dim=0, keepdim=True) + 1e-6 z_activity = (activity - mu) / std # Soft spike detection (smooth binary) spike = torch.sigmoid((z_activity - self.spike_threshold) * 5.0) # Update running spike EMA self.running_spike[:N] = (self.spike_ema * self.running_spike[:N] + (1 - self.spike_ema) * spike.mean(0)) # Pairwise lead-lag statistics ltp = torch.einsum('bi,bj->bij', spike, 1 - spike).mean(0) # i leads j ltd = torch.einsum('bi,bj->bij', 1 - spike, spike).mean(0) # j leads i # STDP delta (asymmetric) delta = self.lr_stdp * (ltp - 0.5 * ltd) # Hebbian decay (prevents runaway potentiation) hebb_term = self.hebbian_decay * torch.outer(self.running_spike[:N], self.running_spike[:N]) delta = delta - hebb_term # Regime modulation if regime_weights is not None: regime_mult = regime_weights.mean(dim=0) delta = delta * regime_mult.unsqueeze(0).unsqueeze(0) # Update with decay self.stdp_delta[:, :N, :N].mul_(self.decay).add_(delta.unsqueeze(0)) self.stdp_delta.clamp_(-2.0, 2.0) self.step_count += 1 def get_adapted_bias(self, base_adj_bias: torch.Tensor, N: int) -> torch.Tensor: """Return base bias + accumulated STDP delta, clamped to registered max_assets.""" N_clamped = min(N, self.max_assets) if N > self.max_assets: # Pad with zeros for the extra assets beyond max_assets pad = N - self.max_assets padded_delta = F.pad(self.stdp_delta[:, :N_clamped, :N_clamped], (0, pad, 0, pad), value=0.0) padded_base = base_adj_bias[:, :N, :N] return padded_base + padded_delta return base_adj_bias[:, :N_clamped, :N_clamped] + self.stdp_delta[:, :N_clamped, :N_clamped] def reset_plasticity(self) -> None: """Reset STDP state (call on regime shift detection).""" self.stdp_delta.zero_() self.running_spike.zero_() self.step_count.zero_() class CircadianPositionalEncoding(nn.Module): """ Multi-frequency oscillator encoding with orthogonal frequency initialization. Provides structured inductive bias for intraday patterns. """ def __init__(self, d_model: int = DEFAULT_CONFIG.d_model, seq_len: int = DEFAULT_CONFIG.seq_len, n_oscillators: int = 8, dropout: float = DEFAULT_CONFIG.dropout, min_period: float = 2.0, max_period: float = 288.0): super().__init__() self.d_model = d_model self.n_osc = n_oscillators # Log-spaced periods (geometric progression) log_min = math.log(min_period) log_max = math.log(max_period) init_periods = torch.exp(torch.linspace(log_min, log_max, n_oscillators)) self.log_T = nn.Parameter(torch.log(init_periods)) # Amplitudes and phases self.amp_sin = nn.Parameter(torch.ones(n_oscillators)) self.amp_cos = nn.Parameter(torch.ones(n_oscillators)) self.phases = nn.Parameter(torch.linspace(0, 2 * math.pi, n_oscillators)) # Projection with orthogonal initialization self.proj = nn.Linear(n_oscillators * 2, d_model) nn.init.orthogonal_(self.proj.weight) nn.init.zeros_(self.proj.bias) self.dropout = nn.Dropout(dropout) # Precompute time indices t = torch.arange(seq_len, dtype=torch.float32) self.register_buffer('t_template', t) def forward(self, seq_len: int, device: torch.device) -> torch.Tensor: t = self.t_template[:seq_len].to(device) periods = torch.exp(self.log_T) angle = 2 * math.pi * t[:, None] / periods[None, :] + self.phases sin_feat = self.amp_sin * torch.sin(angle) cos_feat = self.amp_cos * torch.cos(angle) pe = self.proj(torch.cat([sin_feat, cos_feat], dim=-1)) return self.dropout(pe.unsqueeze(0)) # ====================================================================================== # III. ADVANCED NEURAL ARCHITECTURE MODULES (Refactored) # ====================================================================================== class SelectiveSSMLayer(nn.Module): """ Selective State Space Model (Mamba architecture) with O(T) complexity. Refactored: Added numerical stability, gradient checkpointing, and mixed precision hooks. """ def __init__(self, d_model: int = DEFAULT_CONFIG.d_model, d_state: int = 16, d_conv: int = 4, expand: int = 2, dropout: float = DEFAULT_CONFIG.dropout): super().__init__() self.d_model = d_model self.d_state = d_state self.d_inner = d_model * expand self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, padding=d_conv - 1, groups=self.d_inner) self.x_proj = nn.Linear(self.d_inner, d_state * 2 + self.d_inner, bias=False) self.dt_proj = nn.Linear(self.d_inner, self.d_inner) # Stable A matrix (negative diagonal) A = torch.arange(1, d_state + 1).float().unsqueeze(0).repeat(self.d_inner, 1) self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(self.d_inner)) self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) # Cache for inference mode self._recurrent_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None def forward(self, x: torch.Tensor, recurrent: bool = False) -> torch.Tensor: """ x: (B, T, d_model) recurrent: If True, use O(1) per-step inference mode """ if recurrent and self._recurrent_state is not None: return self._recurrent_step(x) B, T, _ = x.shape residual = x xz = self.in_proj(x) x_gate, z = xz.chunk(2, dim=-1) # Causal depthwise conv x_conv = self.conv1d(x_gate.transpose(1, 2))[..., :T].transpose(1, 2) x_conv = F.silu(x_conv) # Input-dependent SSM parameters ssm_in = self.x_proj(x_conv) B_ssm, C_ssm, dt_raw = ssm_in.split([self.d_state, self.d_state, self.d_inner], dim=-1) dt = F.softplus(self.dt_proj(dt_raw)) A = -torch.exp(self.A_log) # Selective scan y = self._selective_scan(x_conv, dt, A, B_ssm, C_ssm) y = y * F.silu(z) out = self.out_proj(y) return self.norm(residual + self.dropout(out)) def _selective_scan(self, u: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor) -> torch.Tensor: """ Sequential selective scan. In production, replace with CUDA kernel: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn """ B_batch, T, d = u.shape d_state = A.shape[-1] # Zero-order hold discretization dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)) dB = dt.unsqueeze(-1) * B.unsqueeze(2) h = torch.zeros(B_batch, d, d_state, device=u.device) ys = [] for t in range(T): h = dA[:, t] * h + dB[:, t] * u[:, t].unsqueeze(-1) y_t = (h * C[:, t].unsqueeze(1)).sum(-1) ys.append(y_t) y = torch.stack(ys, dim=1) return y + self.D * u def _recurrent_step(self, x: torch.Tensor) -> torch.Tensor: """Single-step recurrent inference (O(1) per step).""" # Simplified for live trading; full implementation requires state caching return self.forward(x, recurrent=False) def reset_recurrent_state(self): """Reset recurrent state for new sequence.""" self._recurrent_state = None class PoincareBall: """ Poincaré ball model for hyperbolic geometry. All operations are exact and differentiable. """ @staticmethod def project(x: torch.Tensor, c: float = 1.0, eps: float = 1e-5) -> torch.Tensor: max_norm = (1.0 / math.sqrt(c)) - eps norm = x.norm(dim=-1, keepdim=True).clamp(min=1e-8) return torch.where(norm > max_norm, x / norm * max_norm, x) @staticmethod def mobius_add(x: torch.Tensor, y: torch.Tensor, c: float = 1.0) -> torch.Tensor: x2 = x.pow(2).sum(-1, keepdim=True) y2 = y.pow(2).sum(-1, keepdim=True) xy = (x * y).sum(-1, keepdim=True) num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y denom = (1 + 2 * c * xy + c ** 2 * x2 * y2).clamp(min=1e-8) return num / denom @staticmethod def exp_map_at_origin(v: torch.Tensor, c: float = 1.0) -> torch.Tensor: sqrt_c = math.sqrt(c) v_norm = v.norm(dim=-1, keepdim=True).clamp(min=1e-8) return torch.tanh(sqrt_c * v_norm) / (sqrt_c * v_norm) * v @staticmethod def log_map_at_origin(x: torch.Tensor, c: float = 1.0) -> torch.Tensor: sqrt_c = math.sqrt(c) x_norm = x.norm(dim=-1, keepdim=True).clamp(min=1e-8) return (1.0 / sqrt_c) * torch.atanh((sqrt_c * x_norm).clamp(-1 + 1e-6, 1 - 1e-6)) / x_norm * x class KANLayer(nn.Module): """ Kolmogorov-Arnold Network layer with Gaussian RBF basis. Interpretable 1D splines per input-output connection. """ def __init__(self, in_features: int, out_features: int, n_basis: int = 8): super().__init__() self.in_features = in_features self.out_features = out_features self.n_basis = n_basis # RBF centers and log-widths self.centers = nn.Parameter( torch.linspace(-2, 2, n_basis).unsqueeze(0).expand(in_features, -1).clone() ) self.log_widths = nn.Parameter(torch.zeros(in_features, n_basis)) # Spline coefficients self.coefficients = nn.Parameter(torch.randn(out_features, in_features, n_basis) * 0.05) # Residual linear path self.linear = nn.Linear(in_features, out_features) nn.init.xavier_uniform_(self.linear.weight, gain=0.1) def forward(self, x: torch.Tensor) -> torch.Tensor: *batch, n = x.shape xf = x.reshape(-1, n) M = xf.shape[0] # Gaussian RBF basis diff = xf.unsqueeze(-1) - self.centers.unsqueeze(0) widths = torch.exp(self.log_widths).unsqueeze(0) basis = torch.exp(-0.5 * (diff / widths.clamp(min=1e-4)) ** 2) # Spline output out_spline = torch.einsum('oid,mid->mo', self.coefficients, basis) # Residual linear out_linear = self.linear(xf) return (out_spline + out_linear).reshape(*batch, self.out_features) class KANScoringHead(nn.Module): """ Two-layer KAN for significance scoring. Interpretable: plot 1D spline to audit monotonicity. """ def __init__(self, n_basis: int = 8): super().__init__() self.kan1 = KANLayer(2, 16, n_basis) self.kan2 = KANLayer(16, 1, n_basis) def forward(self, value: torch.Tensor, log_var: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: combined = torch.cat([value, log_var], dim=-1) h = self.kan1(combined) significance_logits = self.kan2(h).squeeze(-1) significance_weight = torch.sigmoid(significance_logits) return significance_logits, significance_weight class NeuralODERegimeRouter(nn.Module): """ Continuous-time regime dynamics via Neural ODE (RK4 integration). Smooth regime transitions vs discrete switching. """ def __init__(self, d_model: int = DEFAULT_CONFIG.d_model, num_regimes: int = DEFAULT_CONFIG.num_regimes, n_steps: int = DEFAULT_CONFIG.ode_steps, dropout: float = DEFAULT_CONFIG.dropout): super().__init__() self.d_model = d_model self.num_regimes = num_regimes self.n_steps = n_steps # ODE dynamics: dz/dt = f([z, t]) self.ode_func = nn.Sequential( nn.Linear(d_model + 1, d_model * 2), nn.Tanh(), nn.Linear(d_model * 2, d_model), nn.Tanh() ) # Regime classifier on evolved state self.regime_classifier = nn.Sequential( nn.Linear(d_model, d_model // 2), HodgkinHuxleyGate(d_model // 2), nn.Dropout(dropout), nn.Linear(d_model // 2, num_regimes) ) # Regime-specific projections self.regime_projections = nn.ModuleList([ nn.Sequential( nn.Linear(d_model, d_model), HodgkinHuxleyGate(d_model), nn.LayerNorm(d_model) ) for _ in range(num_regimes) ]) def _rk4_step(self, z: torch.Tensor, t: float, dt: float) -> torch.Tensor: def f(z_, t_): t_vec = torch.full((*z_.shape[:-1], 1), t_, device=z_.device, dtype=z_.dtype) return self.ode_func(torch.cat([z_, t_vec], dim=-1)) k1 = f(z, t) k2 = f(z + dt / 2 * k1, t + dt / 2) k3 = f(z + dt / 2 * k2, t + dt / 2) k4 = f(z + dt * k3, t + dt) return z + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: dt = 1.0 / self.n_steps z = x for step in range(self.n_steps): z = self._rk4_step(z, step * dt, dt) regime_logits = self.regime_classifier(z) regime_probs = F.softmax(regime_logits, dim=-1) out = torch.zeros_like(z) for i, proj in enumerate(self.regime_projections): out = out + regime_probs[..., i:i+1] * proj(z) return out, regime_logits, regime_probs class MoETemporalEncoder(nn.Module): """ Mixture-of-Experts Temporal Encoder with sparse top-k routing. Combines Transformer and Mamba experts for efficient temporal processing. """ def __init__(self, input_dim: int = DEFAULT_CONFIG.feature_dim, d_model: int = DEFAULT_CONFIG.d_model, n_experts: int = DEFAULT_CONFIG.n_moe_experts, top_k: int = DEFAULT_CONFIG.moe_top_k, seq_len: int = DEFAULT_CONFIG.seq_len, dropout: float = DEFAULT_CONFIG.dropout, num_heads: int = DEFAULT_CONFIG.num_heads, num_encoder_layers: int = DEFAULT_CONFIG.num_encoder_layers): super().__init__() self.n_experts = n_experts self.top_k = top_k self.d_model = d_model self.balance_coeff = DEFAULT_CONFIG.moe_balance_coeff self.last_load_balance_loss = torch.tensor(0.0) # Expert pool — built dynamically so len(experts) always == n_experts. # Alternate Transformer / Mamba stubs so the pool is architecturally diverse. self.experts = nn.ModuleList([ TransformerExpertStub(input_dim, d_model, num_heads, num_encoder_layers, seq_len, dropout) if i % 2 == 0 else MambaExpertStub(input_dim, d_model, seq_len, dropout) for i in range(n_experts) ]) # Router uses first bar features self.router = nn.Linear(input_dim, n_experts) def forward(self, x: torch.Tensor) -> torch.Tensor: BN, T, feat_dim = x.shape # Routing from first bar router_logits = self.router(x[:, 0, :]) router_probs = F.softmax(router_logits, dim=-1) # Top-k sparse selection topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1) topk_probs = topk_probs / topk_probs.sum(-1, keepdim=True) # Load balancing loss expert_load = router_probs.mean(0) uniform = torch.ones_like(expert_load) / self.n_experts self.last_load_balance_loss = self.balance_coeff * ((expert_load - uniform).pow(2).sum()) # Accumulate expert outputs out = torch.zeros(BN, self.d_model, device=x.device, dtype=x.dtype) for k in range(self.top_k): for exp_id in range(self.n_experts): sel = topk_idx[:, k] == exp_id if sel.sum() == 0: continue w = topk_probs[sel, k] exp_out = self.experts[exp_id](x[sel]) out[sel] = out[sel] + w.unsqueeze(-1) * exp_out return out class TransformerExpertStub(nn.Module): """Transformer temporal encoder with circadian positional encoding.""" def __init__(self, input_dim: int, d_model: int, num_heads: int, num_layers: int, seq_len: int, dropout: float): super().__init__() self.d_model = d_model self.input_proj = nn.Sequential( nn.Linear(input_dim, d_model), nn.LayerNorm(d_model), HodgkinHuxleyGate(d_model) ) self.circadian_pe = CircadianPositionalEncoding(d_model, seq_len, dropout=dropout) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=num_heads, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True, norm_first=True ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.out_norm = nn.LayerNorm(d_model) self.register_buffer('_causal_mask', None, persistent=False) def _get_causal_mask(self, T: int, device: torch.device) -> torch.Tensor: if self._causal_mask is None or self._causal_mask.shape[0] != T or self._causal_mask.device != device: self._causal_mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool() return self._causal_mask def forward(self, x: torch.Tensor) -> torch.Tensor: T = x.shape[1] h = self.input_proj(x) h = h + self.circadian_pe(T, x.device) mask = self._get_causal_mask(T, x.device) h = self.transformer(h, mask=mask) return self.out_norm(h[:, -1, :]) class MambaExpertStub(nn.Module): """Mamba SSM temporal encoder wrapper.""" def __init__(self, input_dim: int, d_model: int, seq_len: int, dropout: float): super().__init__() self.proj = nn.Sequential(nn.Linear(input_dim, d_model), nn.LayerNorm(d_model)) self.ssm = SelectiveSSMLayer(d_model, dropout=dropout) self.d_model = d_model def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.proj(x) h = self.ssm(h) return h[:, -1, :] class DistributionalHead(nn.Module): """ Return distribution head with quantile regression. Uses CRPS loss for proper scoring. """ def __init__(self, d_model: int = DEFAULT_CONFIG.d_model, latent_dim: int = DEFAULT_CONFIG.latent_dim, n_quantiles: int = DEFAULT_CONFIG.num_quantiles, dropout: float = DEFAULT_CONFIG.dropout): super().__init__() self.n_quantiles = n_quantiles self.register_buffer("quantile_levels", torch.linspace(0.1, 0.9, n_quantiles)) self.encoder = nn.Sequential( nn.Linear(d_model, latent_dim * 2), HodgkinHuxleyGate(latent_dim * 2), nn.Dropout(dropout), nn.Linear(latent_dim * 2, latent_dim) ) self.quantile_head = nn.Linear(latent_dim, n_quantiles) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: z = self.encoder(x) quantiles = self.quantile_head(z) mid = self.n_quantiles // 2 # median index # With linspace(0.1, 0.9, n_quantiles), integer floor-division yields: # n_quantiles=9: q_lo=2 → τ=0.30, q_hi=6 → τ=0.70 (not true Q25/Q75). # Renamed to q_lo / q_hi to avoid the misleading q25/q75 labels. q_lo = self.n_quantiles // 4 # lower spread quantile q_hi = 3 * self.n_quantiles // 4 # upper spread quantile value = quantiles[..., mid].unsqueeze(-1) iqr = (quantiles[..., q_hi] - quantiles[..., q_lo]).clamp(min=1e-6) log_var = (2.0 * torch.log(iqr)).unsqueeze(-1) # CVaR@5% approximation: with n_quantiles=9 the lowest available τ is 0.10, # so n_tail=1 gives Q10 — the closest obtainable proxy to CVaR@5%. n_tail = max(1, self.n_quantiles // 9) # ≈ τ=0.10 tail cvar_05 = quantiles[..., :n_tail].mean(dim=-1) # Q10 approximation of CVaR@5% return z, value, log_var, quantiles, cvar_05 # ====================================================================================== # IV. LOSS FUNCTIONS # ====================================================================================== def crps_loss(quantiles: torch.Tensor, quantile_levels: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Continuous Ranked Probability Score (proper scoring rule for distributions). Supports two calling conventions: • 2-D quantiles (B, Q) — e.g. selected-asset slice, y_true shape (B,) • 3-D quantiles (B, N, Q) — e.g. full asset roster, y_true shape (B,) The previous implementation used `y_true.unsqueeze(-1)` → (B, 1) for both cases. When quantiles is 3-D, PyTorch left-pads (B, 1) to (1, B, 1) before broadcasting against (B, N, Q), producing a dim-1 mismatch of B vs N (e.g. "size 4 must match 9 at non-singleton dimension 1"). Fixed below by branching on ndim. """ if quantiles.dim() == 2: # (B, Q) — selected-asset slice or standalone call y_exp = y_true.unsqueeze(-1) # (B, 1) tau = quantile_levels.view(1, -1) # (1, Q) else: # (B, N, Q) — full asset roster y_exp = y_true.view(-1, 1, 1) # (B, 1, 1) tau = quantile_levels.view(1, 1, -1) # (1, 1, Q) errors = y_exp - quantiles pinball = torch.where(errors >= 0, tau * errors, (tau - 1) * errors) return pinball.mean() def v8_total_loss(out: Dict[str, torch.Tensor], rank_targets: torch.Tensor, return_true: torch.Tensor, quantile_levels: torch.Tensor, config: AXRVIConfig = DEFAULT_CONFIG) -> torch.Tensor: """ Combined auxiliary loss for AXRVINet v8 architecture components. NOTE: This function computes the architecture-level regularisation terms (CRPS distribution calibration, MoE load-balance, dendritic branch diversity, and regime entropy). It is NOT the primary training loss used by HybridTrainer.train_on_batch(), which applies its own 7-component objective including RL (TD error), value-consistency [S1], ranking margin, uncertainty penalty, and pinball/quantile loss [v7]. Components: - CRPS (distributional calibration of the quantile head) - MoE (expert load balance — pulled from MoETemporalEncoder) - GateReg (dendritic branch diversity — pulled from HyperbolicCrossAssetLayer) - RegEnt (regime-router entropy maximisation) The returned scalar can be added to HybridTrainer's total loss or used in a separate architecture fine-tuning pass. """ crps = crps_loss(out["quantiles"], quantile_levels, return_true) moe = out.get("moe_balance_loss", torch.tensor(0.0)) gate_reg = out.get("gate_entropy_loss", torch.tensor(0.0)) # Regime entropy regularisation (maximise diversity across the 4 regimes). # entropy is already positive; we want to MAXIMISE it, so we add -entropy # to the total loss (minimising the loss maximises entropy). probs = out["regime_probs"] + 1e-8 entropy = -(probs * probs.log()).sum(-1).mean() # positive scalar regime_ent = -entropy # negate so that minimising loss → maximising entropy return (config.lambda_crps * crps + config.lambda_moe * moe + config.lambda_entropy * regime_ent + config.lambda_gate_reg * gate_reg) # ====================================================================================== # V-pre. MISSING CROSS-ASSET LAYERS — required by AXRVINet # ====================================================================================== class HyperbolicCrossAssetLayer(nn.Module): """ Cross-asset interaction layer operating in Poincaré-ball hyperbolic space. Architecture per layer: 1. Project each asset embedding to Poincaré ball via exp_map 2. Pairwise Möbius addition for cross-asset context mixing 3. Multi-head attention (standard, real-valued) in tangent space 4. DendriticFFN for non-linear integration 5. Project back to Euclidean space via log_map 6. STDP adjacency update (online plasticity) Inputs : h (B, N, d_model) float32 Outputs: h' (B, N, d_model) float32 + gate_entropy scalar """ def __init__( self, d_model: int, num_heads: int, curvature: float = 1.0, dropout: float = 0.1, ): super().__init__() self.d_model = d_model self.c = curvature # Standard multi-head attention in tangent space self.attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True, ) self.attn_norm = nn.LayerNorm(d_model) self.attn_drop = nn.Dropout(dropout) # DendriticFFN for branch-wise processing self.ffn = DendriticFFN(d_model=d_model, dropout=dropout) self.ffn_norm = nn.LayerNorm(d_model) # STDP adjacency self.stdp = STDPAdjacencyLayer(n_heads=num_heads) # Learnable adjacency bias (n_heads, max_assets, max_assets) self.register_buffer( "_adj_bias", torch.zeros(num_heads, 16, 16), ) # State tracking self.last_attn_weights: Optional[torch.Tensor] = None self._last_gate_entropy: Optional[torch.Tensor] = None self._stdp_enabled: bool = True def forward( self, h: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: h : (B, N, d_model) float32 mask : (B, N) bool — True = stale / ignore Returns: h' : (B, N, d_model) float32 """ B, N, D = h.shape # ── 1. Map to Poincaré ball tangent space ───────────────────────── h_hyp = PoincareBall.exp_map_at_origin(h, c=self.c) # still float32 # ── 2. Multi-head self-attention (tangent space) ────────────────── key_pad_mask: Optional[torch.Tensor] = None if mask is not None: # mask: (B, N) bool; MultiheadAttention expects (B, N) float key_padding_mask # True entries are positions to ignore key_pad_mask = mask.squeeze(0) if mask.dim() == 3 else mask # (B, N) attn_out, attn_weights = self.attn( h_hyp, h_hyp, h_hyp, key_padding_mask=key_pad_mask, need_weights=True, average_attn_weights=True, ) self.last_attn_weights = attn_weights.detach() # (B, N, N) for diagnostics h_hyp = self.attn_norm(h_hyp + self.attn_drop(attn_out)) # ── 3. DendriticFFN ────────────────────────────────────────────── ffn_out, gate_entropy = self.ffn(h_hyp) self._last_gate_entropy = gate_entropy h_hyp = self.ffn_norm(h_hyp + ffn_out) # ── 4. Map back to Euclidean space ─────────────────────────────── h_out = PoincareBall.log_map_at_origin(h_hyp, c=self.c) # ── 5. STDP update (online — only during real training forward) ── if self._stdp_enabled and self.training: self.stdp.stdp_update(h_out.detach()) return h_out class QuantumAmplitudeAttention(nn.Module): """ Quantum Amplitude Attention — lightweight surrogate for the full QMHA. Instead of running the full per-batch statevector loop (which is O(B·N²·2^n) and too slow for real-time inference), this layer: 1. Projects (B, N, d_model) → (B, N, n_qubits) — classical compression 2. Computes Hilbert-space attention weights via learned complex phases 3. Aggregates with phase-weighted softmax across assets 4. Projects back to (B, N, d_model) The complex phase matrix encodes lead-lag structure between assets. last_phase is exposed for V8GlassBoxDiagnostics. This sits in AXRVINet as a PARALLEL branch to HyperbolicCrossAssetLayer; both outputs are blended: h = h_hyp + 0.1 * h_qaa The full QMHA (with statevectors) is active in QCSAMCrossAssetLayer. """ def __init__( self, d_model: int, num_heads: int, dropout: float = 0.1, n_qubits: int = 4, ): super().__init__() self.d_model = d_model self.n_qubits = n_qubits hilbert_dim = 2 ** n_qubits # Classical → quantum compression self.compress = nn.Linear(d_model, n_qubits) # Learnable complex phase matrix (n_qubits, n_qubits) self.phase_re = nn.Parameter(torch.randn(n_qubits, n_qubits) * 0.1) self.phase_im = nn.Parameter(torch.randn(n_qubits, n_qubits) * 0.1) # Re-expansion self.expand = nn.Linear(n_qubits, d_model) self.norm = nn.LayerNorm(d_model) self.drop = nn.Dropout(dropout) # Exposed for diagnostics self.last_phase: Optional[torch.Tensor] = None def forward( self, h: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: h : (B, N, d_model) float32 mask: (B, N) bool — stale assets Returns: h' : (B, N, d_model) float32 """ B, N, D = h.shape # Compress to n_qubits "amplitude" space: (B, N, n_qubits) x = self.compress(h) # (B, N, n_qubits) # Complex phase attention: A_ij = # Approximate as real dot product with phase-rotated keys phase = torch.complex(self.phase_re, self.phase_im) # (n_qubits, n_qubits) self.last_phase = phase.detach() # Apply phase rotation: x_rot = x @ (Re(phase) + Im(phase)) x_rot = x @ (self.phase_re + self.phase_im) # (B, N, n_qubits) # Attention weights: (B, N, N) attn = torch.bmm(x, x_rot.transpose(1, 2)) / (self.n_qubits ** 0.5) # Mask stale assets if mask is not None: m = mask if mask.dim() == 2 else mask.squeeze(0) # (B, N) attn = attn.masked_fill(m.unsqueeze(1), float("-inf")) attn = torch.softmax(attn, dim=-1) # Aggregate and re-expand x_agg = torch.bmm(attn, x) # (B, N, n_qubits) out = self.expand(x_agg) # (B, N, d_model) return self.norm(h + self.drop(out)) # ====================================================================================== # V-pre-b. QCSAMCrossAssetLayer — live QCSAM/FABLE cross-asset engine # ====================================================================================== class QCSAMCrossAssetLayer(nn.Module): """ Live QCSAM/FABLE cross-asset core — integrates the full quantum pipeline into the AXRVINet forward pass. This is the component that turns dead QCSAM/FABLE code into the active cross-asset interaction engine of the ranker. Architecture: Input : (B, N, d_model) float32 — temporal encoder output Step 1 : Adapter — project d_model → n_qubits → (B, N, n_qubits) double Step 2 : QuantumMultiHeadAttention(x=(B, N, n_qubits)) → final_state (B, hilbert_dim) cdouble + align_loss scalar Step 3 : QFFN(final_state) → evolved_state (B, hilbert_dim) cdouble Step 4 : Readout — |evolved_state|.real → (B, hilbert_dim) → broadcast to (B, N, hilbert_dim) → linear (hilbert_dim → d_model) Output : (B, N, d_model) float32 + align_loss scalar Gradient flow: Gradients flow through the adapter (linear), through the classical parts of QMHA (alpha parameters, phase parameters, FABLECLCU coefficients) and through the real part of the QFFN output back to the adapter. The statevector ops in QuantumFeatureMap are differentiable (pure PyTorch). Shape/dtype contract: adapter_input : (B, N, d_model) float32 adapter_output : (B, N, n_qubits) double qmha_input : (B, N, n_qubits) double [n_patches = N] qmha_output : (B, hilbert_dim=2^n_qubits) cdouble qffn_output : (B, hilbert_dim) cdouble readout_input : (B, N, hilbert_dim) float32 (real part broadcast) layer_output : (B, N, d_model) float32 Checkpoint notes: All submodules (adapter_proj, qmha, qffn, readout, residual_gate) are registered nn.Module / nn.Parameter attributes and are therefore fully captured by AXRVINet.state_dict(). FABLECLCU inside qmha is eagerly initialised via num_assets so its weights are in state_dict from the very first save. _clcu_n_patches_buf (persistent buffer) guards against accidental re-initialisation on load_state_dict. """ # Registry flag — set True after first successful forward. # Checked by run_qcsam_integration_contract_test(). _first_forward_complete: bool = False def __init__( self, d_model: int, n_qubits: int = 4, n_heads: int = 2, n_layers: int = 2, qffn_layers: int = 2, num_assets: int = -1, # ← FIXED: now accepts num_assets ): super().__init__() self.d_model = d_model self.n_qubits = n_qubits self.hilbert_dim = 2 ** n_qubits self.n_heads = n_heads # ── Adapter: float32 d_model → double n_qubits ─────────────────── self.adapter_proj = nn.Linear(d_model, n_qubits) nn.init.xavier_uniform_(self.adapter_proj.weight, gain=0.5) nn.init.zeros_(self.adapter_proj.bias) # ── QCSAM core ──────────────────────────────────────────────────── # Pass num_assets so FABLECLCU is eagerly initialised self.qmha = QuantumMultiHeadAttention( n_qubits = n_qubits, n_heads = n_heads, n_layers = n_layers, n_patches = num_assets, # ← FIXED: eagerly init CLCU ) # ── QFFN ───────────────────────────────────────────────────────── self.qffn = QFFN(n_qubits=n_qubits, n_layers=qffn_layers) # ── Readout: hilbert_dim → d_model ─────────────────────────────── self.readout = nn.Sequential( nn.Linear(self.hilbert_dim, d_model), nn.LayerNorm(d_model), ) self.residual_gate = nn.Parameter(torch.full((1,), 0.1)) self.last_align_loss: float = 0.0 self.last_qmha_diagnostics: dict = {} self._fwd_count: int = 0 def forward( self, h: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: h : (B, N, d_model) float32 Returns: h_out : (B, N, d_model) float32 — updated asset embeddings align_loss : scalar double — Hilbert alignment regulariser """ B, N, D = h.shape device = h.device # ── Shape / dtype assertions ────────────────────────────────────── assert h.dtype == torch.float32, \ f"[QCSAMCrossAssetLayer] Expected float32 input, got {h.dtype}" assert h.dim() == 3, \ f"[QCSAMCrossAssetLayer] Expected (B,N,D), got {h.shape}" # ── Step 1: Adapter ─────────────────────────────────────────────── # Project each asset's d_model embedding to n_qubits features, # rescale to [0, π] for data-encoding, convert to double. x_proj = self.adapter_proj(h) # (B, N, n_qubits) float32 # Normalise to [0, π] so QuantumFeatureMap Rx gates operate in range x_min = x_proj.min(dim=-1, keepdim=True).values x_max = x_proj.max(dim=-1, keepdim=True).values x_norm = (x_proj - x_min) / (x_max - x_min + 1e-8) * math.pi x_q = x_norm.double() # (B, N, n_qubits) double # Assert adapter output shape before QMHA assert x_q.shape == (B, N, self.n_qubits), \ f"[QCSAMCrossAssetLayer] Adapter output shape mismatch: {x_q.shape}" assert x_q.dtype == torch.double, \ f"[QCSAMCrossAssetLayer] Adapter output must be double, got {x_q.dtype}" # ── Step 2: QuantumMultiHeadAttention ───────────────────────────── # QMHA treats the N asset embeddings as n_patches=N 'patches'. # Each patch has n_qubits features. # Input: (B, N, n_qubits) double # Output: (B, 2**n_qubits) cdouble + align_loss scalar qmha_state, align_loss, qmha_diag = self.qmha(x_q) self.last_align_loss = align_loss.item() self.last_qmha_diagnostics = qmha_diag # Assert QMHA output assert qmha_state.shape == (B, self.hilbert_dim), \ f"[QCSAMCrossAssetLayer] QMHA output shape: {qmha_state.shape}, expected ({B},{self.hilbert_dim})" assert qmha_state.dtype == torch.cdouble, \ f"[QCSAMCrossAssetLayer] QMHA output must be cdouble, got {qmha_state.dtype}" # ── Step 3: QFFN ───────────────────────────────────────────────── # Apply quantum feedforward network to the aggregated statevector. # Input: (B, hilbert_dim) cdouble # Output: (B, hilbert_dim) cdouble evolved = self.qffn(qmha_state) # Assert QFFN output assert evolved.shape == (B, self.hilbert_dim), \ f"[QCSAMCrossAssetLayer] QFFN output shape: {evolved.shape}, expected ({B},{self.hilbert_dim})" # ── Step 4: Readout ─────────────────────────────────────────────── # Take the real part of the evolved statevector. # Broadcast the single (B, hilbert_dim) context vector over all N assets, # then project back to (B, N, d_model). real_part = evolved.real.float() # (B, hilbert_dim) float32 # Broadcast over N: (B, 1, hilbert_dim) → (B, N, hilbert_dim) context = real_part.unsqueeze(1).expand(-1, N, -1) # (B, N, hilbert_dim) delta = self.readout(context) # (B, N, d_model) # Residual gated addition gate = torch.sigmoid(self.residual_gate) # scalar ∈ (0, 1) h_out = h + gate * delta # (B, N, d_model) float32 # ── Mark first forward complete ─────────────────────────────────── self._fwd_count += 1 QCSAMCrossAssetLayer._first_forward_complete = True # ── QCSAM diagnostic logging ────────────────────────────────────── # Log on the very first forward pass, then every 100 calls. # (10 MC-samples × ~5s rank cycle → every ~50s in steady state.) if self._fwd_count == 1 or self._fwd_count % 100 == 0: gate_val = torch.sigmoid(self.residual_gate).item() heads_info = " | ".join( f"h{i}: attn_norm={hd.get('attn_matrix_norm', 0.0):.4f} " f"align_h={hd.get('align_loss_h', 0.0):.6f}" for i, hd in enumerate(self.last_qmha_diagnostics.get("heads", [])) ) logger.info( f"🔬 [QCSAM] fwd#{self._fwd_count} | " f"B={B} N={N} hilbert_dim={self.hilbert_dim} | " f"gate={gate_val:.4f} | " f"align_loss={self.last_align_loss:.6f} | " f"{heads_info}" ) return h_out, align_loss # ====================================================================================== # V. AXRVINet v8 — COMPLETE MODEL # ====================================================================================== class AXRVINet(nn.Module): def __init__(self, num_assets: int = 5, config: AXRVIConfig = DEFAULT_CONFIG, feature_dim: int = None, seq_len: int = None): super().__init__() """ AXRVI Significance Scoring Network — v8 Institutional Refactor + QCSAM Integration. Architecture: Input: (B, N, T, F) [1] MoETemporalEncoder → (B, N, d_model) [2] HyperbolicCrossAssetLayer × L → (B, N, d_model) [classical GNN, STDP] + QuantumAmplitudeAttention → blended in (parallel, lightweight) [3] QCSAMCrossAssetLayer → (B, N, d_model) [LIVE QCSAM/FABLE core] ├─ Adapter proj → (B, N, n_qubits) double ├─ QuantumMultiHeadAttention (FABLECLCU inside) ├─ QFFN └─ Readout linear → (B, N, d_model) float32 [4] NeuralODERegimeRouter → (B, N, d_model) + regime_probs [5] DistributionalHead → quantiles, value, log_var [6] KANScoringHead → significance scores QCSAM/FABLE integration: QCSAMCrossAssetLayer is the live cross-asset engine. QuantumMultiHeadAttention, FABLECLCU, QuantumFeatureMap, QFFN are all called inside it on every forward pass. align_loss from QMHA is returned in the output dict and consumed by HybridTrainer as L_align. """ """ Args: num_assets : number of parallel asset streams (N dimension) config : AXRVIConfig — all hyperparameters feature_dim : override config.feature_dim (backward compat) seq_len : override config.seq_len (backward compat) """ # Apply any direct overrides so callers that pass feature_dim/seq_len # directly (e.g. test_components) still get the right architecture. if feature_dim is not None or seq_len is not None: import copy config = copy.copy(config) # FIX 8: was cssopy.copy (NameError typo) if feature_dim is not None: config.feature_dim = feature_dim if seq_len is not None: config.seq_len = seq_len self.num_assets = num_assets self.config = config self.d_model = config.d_model # Temporal encoding (MoE with Mamba + Transformer) self.temporal = MoETemporalEncoder( input_dim=config.feature_dim, d_model=config.d_model, n_experts=config.n_moe_experts, top_k=config.moe_top_k, seq_len=config.seq_len, dropout=config.dropout, num_heads=config.num_heads, num_encoder_layers=config.num_encoder_layers ) # Cross-asset hyperbolic GNN layers (classical, with STDP) self.cross_asset_layers = nn.ModuleList([ HyperbolicCrossAssetLayer( d_model=config.d_model, num_heads=config.num_heads, curvature=config.hyperbolic_curvature, dropout=config.dropout ) for _ in range(config.num_cross_layers) ]) # ── LIVE QCSAM/FABLE cross-asset core ──────────────────────────── self.qcsam_layer = QCSAMCrossAssetLayer( d_model = config.d_model, n_qubits = config.qcsam_n_qubits, n_heads = config.qcsam_n_heads, n_layers = config.qcsam_n_layers, qffn_layers = config.qcsam_qffn_layers, num_assets = num_assets, # ← ADD THIS LINE ) # Quantum amplitude attention layers (parallel lightweight branch) self.quantum_attn = nn.ModuleList([ QuantumAmplitudeAttention( d_model=config.d_model, num_heads=config.num_heads, dropout=config.dropout, n_qubits=config.qcsam_n_qubits, ) for _ in range(config.num_cross_layers) ]) # Continuous-time regime router self.regime_router = NeuralODERegimeRouter( d_model=config.d_model, num_regimes=config.num_regimes, n_steps=config.ode_steps, dropout=config.dropout ) # Distributional head self.distributional = DistributionalHead( d_model=config.d_model, latent_dim=config.latent_dim, n_quantiles=config.num_quantiles, dropout=config.dropout ) # KAN scoring head self.scoring = KANScoringHead(n_basis=8) # Gradient checkpointing flag self._checkpoint_grads = config.checkpoint_grads def forward(self, sequences: torch.Tensor, stale_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: """ Forward pass through the entire network. Args: sequences : (B, N, T, F) float32 input features stale_mask: (B, N) bool — True = stale asset (pad with zeros in attention) Returns: Dictionary with all outputs including QCSAM align_loss. Dataflow: sequences (B,N,T,F) → MoETemporalEncoder → h (B,N,d_model) → HyperbolicCrossAsset × L (h updated in-place with residuals) + QuantumAmplitudeAttn (parallel blend: h += 0.1 * h_qaa) → QCSAMCrossAssetLayer → h (B,N,d_model) [LIVE QMHA+FABLE+QFFN] → NeuralODERegimeRouter → h + regime_probs → DistributionalHead → value, log_var, quantiles, cvar_05 → KANScoringHead → significance_logits, significance_weight """ B, N, T, feat_dim = sequences.shape # ── [1] Temporal encoding ───────────────────────────────────────── h = self.temporal(sequences.view(B * N, T, feat_dim)).view(B, N, self.d_model) # Assert temporal output shape assert h.shape == (B, N, self.d_model), \ f"[AXRVINet] temporal output shape mismatch: {h.shape}" assert h.dtype == torch.float32, \ f"[AXRVINet] temporal output must be float32, got {h.dtype}" # ── [2] Cross-asset processing (Hyperbolic GNN + QAA blend) ────── total_gate_entropy = torch.tensor(0.0, device=h.device) for hyp_layer, qaa_layer in zip(self.cross_asset_layers, self.quantum_attn): if self._checkpoint_grads and self.training: hyp_layer._stdp_enabled = False h_hyp = torch.utils.checkpoint.checkpoint( hyp_layer, h, stale_mask, use_reentrant=False ) hyp_layer._stdp_enabled = True h_qaa = torch.utils.checkpoint.checkpoint( qaa_layer, h, stale_mask, use_reentrant=False ) else: h_hyp = hyp_layer(h, mask=stale_mask) h_qaa = qaa_layer(h, mask=stale_mask) h = h_hyp + 0.1 * h_qaa if hyp_layer._last_gate_entropy is not None: total_gate_entropy = total_gate_entropy + hyp_layer._last_gate_entropy # ── [3] QCSAMCrossAssetLayer — LIVE QMHA + FABLE + QFFN ────────── # This is where the quantum stack becomes active on every forward pass. # Shape in: (B, N, d_model) float32 # Shape out: (B, N, d_model) float32 + align_loss scalar assert h.shape == (B, N, self.d_model), \ f"[AXRVINet] pre-QCSAM shape mismatch: {h.shape}" h, qcsam_align_loss = self.qcsam_layer(h) assert h.shape == (B, N, self.d_model), \ f"[AXRVINet] post-QCSAM shape mismatch: {h.shape}" assert h.dtype == torch.float32, \ f"[AXRVINet] post-QCSAM must be float32, got {h.dtype}" # ── [4] Regime routing ──────────────────────────────────────────── h, regime_logits, regime_probs = self.regime_router(h) # ── [5] Distributional head ─────────────────────────────────────── z, value, log_var, quantiles, cvar_05 = self.distributional(h) # Assert distributional head output shapes assert value.shape == (B, N, 1), f"value shape: {value.shape}" assert log_var.shape == (B, N, 1), f"log_var shape: {log_var.shape}" assert quantiles.shape[0] == B and quantiles.shape[1] == N, \ f"quantiles shape: {quantiles.shape}" # ── [6] Scoring ─────────────────────────────────────────────────── significance_logits, significance_weight = self.scoring(value, log_var) # Collect diagnostics last_phase = self.quantum_attn[-1].last_phase if self.quantum_attn else None return { # Core outputs "significance_logits": significance_logits, "significance_weight": significance_weight, "raw_scores": significance_logits, "probs": significance_weight, # Distributional "z": z, "value": value, "log_var": log_var, "quantiles": quantiles, "cvar_05": cvar_05, # Regime "regime_probs": regime_probs, "regime_logits": regime_logits, # QCSAM/FABLE alignment loss — consumed by HybridTrainer as L_align "qcsam_align_loss": qcsam_align_loss, # Attention diagnostics "attn_weights": { f"ca_layer_{i}": layer.last_attn_weights for i, layer in enumerate(self.cross_asset_layers) if layer.last_attn_weights is not None }, # v8 specific "lead_lag_phase": last_phase, "moe_balance_loss": self.temporal.last_load_balance_loss, "gate_entropy_loss": total_gate_entropy, # QCSAM diagnostics passthrough "qcsam_diagnostics": self.qcsam_layer.last_qmha_diagnostics, } @torch.no_grad() def forward_with_epistemic_uncertainty(self, sequences: torch.Tensor, mc_samples: int = 10) -> Dict[str, torch.Tensor]: """ Monte Carlo Dropout for epistemic uncertainty estimation. Collects all per-sample outputs and returns their means, including the distributional value head (median quantile) and CVaR@5% so that rank_and_gate() can use them on the MC code-path identically to the deterministic path. [S1 / v8] """ self.train() # enable dropout for MC sampling # Bug 8 fix: suppress STDP in all cross-asset layers during MC sampling. # self.train() sets training=True which would trigger stdp_update() on every # forward pass, corrupting adjacency weights with inference-time activations. for layer in self.cross_asset_layers: layer._stdp_enabled = False mc_weights = [] mc_log_vars = [] mc_quantiles = [] mc_regime = [] mc_values = [] # value-head (median quantile) per sample [S1] mc_cvar = [] # CVaR@5% per sample mc_align = [] # QCSAM alignment loss per sample try: for _ in range(mc_samples): out = self.forward(sequences) mc_weights.append(out["significance_weight"]) mc_log_vars.append(out["log_var"].squeeze(-1)) mc_quantiles.append(out["quantiles"]) mc_regime.append(out["regime_probs"]) mc_values.append(out["value"].squeeze(-1)) # (B, N) mc_cvar.append(out["cvar_05"]) # (B, N) mc_align.append(out["qcsam_align_loss"]) finally: self.eval() # always restore eval mode, even if a sample raises for layer in self.cross_asset_layers: layer._stdp_enabled = True # re-enable for subsequent train_on_batch() mc_t = torch.stack(mc_weights, dim=0) # (S, B, N) mean_sig = mc_t.mean(0) epistemic_var = mc_t.var(0) mean_value = torch.stack(mc_values, dim=0).mean(0) # (B, N) mean_cvar = torch.stack(mc_cvar, dim=0).mean(0) # (B, N) # Average align loss over MC samples mean_align = sum(mc_align) / len(mc_align) if mc_align else torch.tensor(0.0) # ── QCSAM inference summary (every 20 mc_forward calls) ─────────── if not hasattr(self, "_mc_fwd_count"): self._mc_fwd_count = 0 self._mc_fwd_count += 1 if self._mc_fwd_count % 20 == 1: align_val = mean_align.item() if hasattr(mean_align, "item") else float(mean_align) gate_val = torch.sigmoid(self.qcsam_layer.residual_gate).item() logger.info( f"🔬 [QCSAM/Inference] mc_fwd#{self._mc_fwd_count} | " f"samples={mc_samples} | " f"mean_align_loss={align_val:.6f} | " f"residual_gate={gate_val:.4f}" ) return { "significance_weight": mean_sig, "epistemic_variance": epistemic_var, "epistemic_std": torch.sqrt(epistemic_var + 1e-8), "log_var": torch.stack(mc_log_vars, dim=0).mean(0).unsqueeze(-1), "quantiles": torch.stack(mc_quantiles, dim=0).mean(0), "regime_probs": torch.stack(mc_regime, dim=0).mean(0), # [S1] Conditional expectation proxy — needed by ShreveRankingEngine "value": mean_value.unsqueeze(-1), # (B, N, 1) # CVaR@5% — needed by PortfolioRiskManager sizing layer "cvar_05": mean_cvar, # (B, N) # QCSAM alignment loss — averaged over MC samples "qcsam_align_loss": mean_align, } def reset_plasticity(self) -> None: """Reset STDP and other online learning states.""" for layer in self.cross_asset_layers: layer.stdp.reset_plasticity() def to_torchscript(self, filepath: Optional[str] = None) -> torch.jit.ScriptModule: """ Export to TorchScript for production deployment. """ self.eval() self.config.torchscript_compatible = True # Create a wrapper with fixed sequence length class Wrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model def forward(self, sequences: torch.Tensor) -> Dict[str, torch.Tensor]: return self.model(sequences, stale_mask=None) scripted = torch.jit.script(Wrapper(self)) if filepath: scripted.save(filepath) return scripted # ====================================================================================== # VI. GLASS BOX DIAGNOSTICS # ====================================================================================== class V8GlassBoxDiagnostics: """ Interpretability tools for AXRVINet v8. Provides dashboard-ready outputs for live monitoring. """ def __init__(self, model: AXRVINet): self.model = model self.config = model.config @torch.no_grad() def lead_lag_matrix(self, out: Dict[str, torch.Tensor], asset_names: List[str]) -> Dict: """Extract lead-lag phase matrix from QAA.""" phase = out.get("lead_lag_phase") if phase is None: return {} mean_phase = phase.mean(dim=(0, 1)).cpu().numpy() return { "matrix": mean_phase.tolist(), "assets": asset_names, "units": "radians (positive = row leads column)" } @torch.no_grad() def regime_dashboard(self, out: Dict[str, torch.Tensor], asset_names: List[str]) -> Dict: """Regime probabilities per asset.""" probs = out["regime_probs"].mean(0).cpu() regime_labels = ["Trending", "Mean-Rev", "High-Vol", "Crash/Spike"] return { asset: dict(zip(regime_labels, probs[i].tolist())) for i, asset in enumerate(asset_names[:len(probs)]) } @torch.no_grad() def kan_spline_data(self, input_range: Tuple[float, float] = (-3, 3), n_pts: int = 100) -> Dict: """Export KAN spline curves for monotonicity audit.""" xs = torch.linspace(*input_range, n_pts) dummy = torch.zeros(n_pts, 2) dummy[:, 0] = xs logits, _ = self.model.scoring(dummy[:, :1].unsqueeze(0), dummy[:, 1:].unsqueeze(0)) return { "x_value": xs.tolist(), "y_score": logits.squeeze(0).cpu().tolist(), "note": "Monotonic increasing = well-behaved scorer" } def full_report(self, out: Dict[str, torch.Tensor], asset_names: List[str]) -> Dict: """Generate complete diagnostic report.""" return { "lead_lag": self.lead_lag_matrix(out, asset_names), "regimes": self.regime_dashboard(out, asset_names), "kan_curves": self.kan_spline_data(), "cvar_05": out["cvar_05"].mean(0).cpu().tolist(), "quantiles": out["quantiles"].mean(0).cpu().tolist(), } # ====================================================================================== # VII. FACTORY FUNCTION FOR EASY INITIALIZATION # ====================================================================================== def create_axrvi_v8(num_assets: int = 5, config: Optional[AXRVIConfig] = None, device: Union[str, torch.device] = "cpu") -> AXRVINet: """ Factory function to create and initialise AXRVINet v8. The ``device`` default has been changed from "cuda" to "cpu" so the factory is safe to call in CPU-only environments (Spaces, laptops). Callers that know CUDA is available should pass device="cuda" explicitly. Weight initialisation follows Xavier-uniform (gain=0.5) for Linear layers and constant-one/zero for LayerNorm, which prevents activation explosion at the start of training through the deep hyperbolic + ODE stack. Example:: model = create_axrvi_v8(num_assets=10, config=cfg, device="cpu") out = model(sequences) # (1, 10, 20, 26) loss = v8_total_loss(out, rank_t, returns, model.distributional.quantile_levels) """ if config is None: config = DEFAULT_CONFIG model = AXRVINet(num_assets=num_assets, config=config) model = model.to(device) # Xavier-uniform weight initialisation — prevents gradient explosion through # the deep HH → hyperbolic → ODE → KAN stack on first forward pass. def init_weights(m: nn.Module) -> None: if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight, gain=0.5) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) model.apply(init_weights) return model def _axrvi_config_from_ranker_config(rc: "AssetRankerConfig") -> AXRVIConfig: """ Derive an ``AXRVIConfig`` from an ``AssetRankerConfig`` so all dimension constants live in one place. Called by ``QuasarAXRVIBridge.initialize()`` to construct the neural-net config directly from the top-level ranker config, avoiding duplicated constants. Any stand-alone caller that has an ``AssetRankerConfig`` may use this function as well. The mapping is: rc.feature_dim → AXRVIConfig.feature_dim (26) rc.d_model → AXRVIConfig.d_model (64) rc.n_heads → AXRVIConfig.num_heads (4) rc.feature_window → AXRVIConfig.seq_len (20) rc.n_encoder_layers → AXRVIConfig.num_encoder_layers (2) rc.n_cross_asset_layers → AXRVIConfig.num_cross_layers (2) rc.latent_dim → AXRVIConfig.latent_dim (32) rc.n_quantiles → AXRVIConfig.num_quantiles (9) rc.n_regimes → AXRVIConfig.num_regimes (4) rc.dropout → AXRVIConfig.dropout (0.1) """ return AXRVIConfig( feature_dim = rc.feature_dim, d_model = rc.d_model, num_heads = rc.n_heads, seq_len = rc.feature_window, num_encoder_layers = rc.n_encoder_layers, num_cross_layers = rc.n_cross_asset_layers, latent_dim = rc.latent_dim, num_quantiles = rc.n_quantiles, num_regimes = rc.n_regimes, dropout = rc.dropout, ) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 9 — BANDIT SELECTOR # ══════════════════════════════════════════════════════════════════════════════════════ class BanditSelector: """ Multi-armed bandit asset selector: UCB / Thompson Sampling / Greedy (v3). Learns which assets historically yield good rewards when their significance is high. """ class Strategy(Enum): UCB = "ucb" THOMPSON = "thompson" GREEDY = "greedy" def __init__( self, asset_ids: List[str], strategy: "BanditSelector.Strategy" = None, ucb_c: float = UCB_C, thompson_std: float = THOMPSON_STD, ): self.asset_ids = asset_ids self.strategy = strategy or self.Strategy.UCB self.ucb_c = ucb_c self.thompson_std = thompson_std self.t = 0 self.counts = {a: 1 for a in asset_ids} self.total_reward = {a: 0.0 for a in asset_ids} def select( self, neural_scores: np.ndarray, threshold: float = SCORE_THRESHOLD, max_select: int = MAX_CONCURRENT, ) -> Tuple[List[str], np.ndarray]: self.t += 1 if self.strategy == self.Strategy.UCB: log_t = math.log(max(self.t, 1)) exploration = np.array([ self.ucb_c * math.sqrt(log_t / max(self.counts[a], 1)) for a in self.asset_ids ]) final_scores = neural_scores + exploration elif self.strategy == self.Strategy.THOMPSON: noise = np.random.normal(0, self.thompson_std, len(self.asset_ids)) final_scores = neural_scores + noise else: # GREEDY final_scores = neural_scores.copy() ranked_idx = np.argsort(final_scores)[::-1] # Linear O(N) selection: iterate ranked candidates in order, # accept those above threshold until max_select is reached. selected: List[str] = [] for i in ranked_idx: if len(selected) >= max_select: break if final_scores[i] >= threshold: selected.append(self.asset_ids[i]) # ── Warmup fallback: always guarantee at least 1 trade for training data ── # If no asset cleared the threshold, force-select the top scorer so the # replay buffer can start filling even while the model is cold-starting. if not selected and len(ranked_idx) > 0: selected = [self.asset_ids[ranked_idx[0]]] logger.debug( f"[BanditSelector] Threshold={threshold:.3f} cleared by 0 assets — " f"forcing top scorer '{selected[0]}' (score={final_scores[ranked_idx[0]]:.3f}) " f"for training-data warmup." ) return selected, final_scores def update_reward(self, asset_id: str, reward: float) -> None: if asset_id in self.counts: self.counts[asset_id] += 1 self.total_reward[asset_id] += reward def get_avg_rewards(self) -> Dict[str, float]: return { a: self.total_reward[a] / max(self.counts[a], 1) for a in self.asset_ids } def state_dict(self) -> dict: """Return serialisable bandit state for checkpointing.""" return { "asset_ids": self.asset_ids, "strategy": self.strategy.value, "ucb_c": self.ucb_c, "thompson_std": self.thompson_std, "t": self.t, "counts": dict(self.counts), "total_reward": dict(self.total_reward), } def load_state_dict(self, state: dict) -> None: """Restore bandit state from a dict produced by state_dict().""" self.t = state.get("t", self.t) self.counts = state.get("counts", self.counts) self.total_reward = state.get("total_reward", self.total_reward) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 10 — RANKING ENGINE (v3) + CONSERVATIVE RANKER (v4) # ══════════════════════════════════════════════════════════════════════════════════════ class ShreveRankingEngine: """ [S1, S5, S6] Shreve-consistent ranking engine (v6). Priority formula: Π_t = D(t,τ) · Ê[R_{t+τ} | F_t] + ½ · σ²_t · Δt · κ_t · w_curv where: D(t,τ) = exp(−r · τ) discount factor [S6] Ê[R_{t+τ} | F_t] = hub_confidence · neural_significance (best available proxy for conditional expectation) [S1] ½ σ²_t Δt κ_t = Itô–Doeblin curvature correction [S5] κ_t = convexity feature [19], σ²_t from feature [6] w_curv = ShreveConfig.ito_curvature_weight Fallback: if significance_weights not provided, uses avn_accuracy. This replaces v5 RankingEngine (Final_Priority = confidence × significance). The legacy alias `rank()` still works and is mapped to `rank_risk_neutral()`. """ def __init__( self, shreve_config: Optional[ShreveConfig] = None, asset_buffers: Optional[Dict[str, "AssetStateBuffer"]] = None, ): self.sc = shreve_config or ShreveConfig() self.asset_buffers = asset_buffers or {} # Annualisation for τ conversion: seconds → fraction of year self._sec_per_year = 365.25 * 24 * 3600.0 # ── Main ranking method ─────────────────────────────────────────────────── def rank_risk_neutral( self, snapshots: Dict[str, "AssetSnapshot"], significance_weights: Optional[Dict[str, float]] = None, value_estimates: Optional[Dict[str, float]] = None, ) -> List["RankedAsset"]: """ Rank assets by Π_t = D(t,τ)·Ê[R_{t+τ}|F_t] + Itô curvature [S1,S5,S6]. Parameters ---------- snapshots : hub AssetSnapshots (confidence, signal) significance_weights : AXRVINet significance_weight per asset value_estimates : AXRVINet value head output per asset (V̂_t) [S1] """ # [S6] D(t,τ)=exp(-r·τ) ≈ 0.9999999 at τ=60s — numerically 1.0, removed from # priority formula. Logged only so Shreve grounding is visible in output. discount_log = math.exp( -self.sc.risk_free_rate * self.sc.horizon_seconds / self._sec_per_year ) results: List[RankedAsset] = [] # FIX: iterate over ALL configured asset symbols, not just hub snapshots. # hub_snapshots only contains assets whose publisher has connected — if any # space has not yet sent a metrics_update the snapshot is missing and that # asset is silently dropped from the ranking, causing the executor to see # only the one asset whose space happens to be the active publisher. _all_asset_ids = ( list(self.asset_buffers.keys()) if self.asset_buffers else list(snapshots.keys()) ) for space_name in _all_asset_ids: snap = snapshots.get(space_name) if snap is None: # Asset has no hub snapshot yet — create a neutral default so it # participates in ranking with zero significance weight. snap = AssetSnapshot(space_name=space_name) acc = snap.avn_accuracy # Raw neural significance — no hub_confidence multiplier, no Gaussian # lower-bound. Cross-asset competition is handled by low-τ softmax # in rank_and_gate() (τ=0.05) which converts these logits into # sharp allocation probabilities. sig_w = ( float(significance_weights.get(space_name, acc)) if significance_weights else acc ) # [S1] Blend value head when available: 70% value + 30% significance if value_estimates and space_name in value_estimates: v_hat = float(value_estimates[space_name]) priority = 0.7 * v_hat + 0.3 * sig_w else: priority = sig_w # unconstrained neural output, no compression # [S5] Itô–Doeblin curvature correction (retained — adds real signal) buf = self.asset_buffers.get(space_name) if buf is not None: sigma_ann = buf.feature_eng.get_raw_feature(6) kappa_t = buf.feature_eng.get_raw_feature(19) dt_years = self.sc.qv_dt_seconds / self._sec_per_year ito_curv = 0.5 * (sigma_ann ** 2) * dt_years * kappa_t priority += self.sc.ito_curvature_weight * ito_curv results.append(RankedAsset( space_name = space_name, signal_confidence = round(snap.signal_confidence, 6), # stored for logging only; not used in scoring significance_weight = round(sig_w, 6), final_priority = round(priority, 6), score = round(priority, 6), dominant_signal = snap.dominant_signal, avn_accuracy = round(acc, 6), training_steps = snap.training_steps, )) results.sort(key=lambda r: r.final_priority, reverse=True) for i, r in enumerate(results): r.rank = i + 1 return results # ── Legacy alias (v5 API compatibility) ────────────────────────────────── def rank( self, snapshots: Dict[str, "AssetSnapshot"], significance_weights: Optional[Dict[str, float]] = None, ) -> List["RankedAsset"]: """Backward-compatible alias → rank_risk_neutral() without value estimates.""" return self.rank_risk_neutral(snapshots, significance_weights) def top_asset(self, snapshots: Dict[str, "AssetSnapshot"]) -> Optional["RankedAsset"]: ranked = self.rank(snapshots) return ranked[0] if ranked else None # ── Keep legacy RankingEngine as thin wrapper for any external consumers ────── class RankingEngine(ShreveRankingEngine): """ Legacy alias for ShreveRankingEngine. Final_Priority now uses [S1,S5,S6] formula rather than plain conf×sig. """ pass class ConservativeRanker: """ AXRVI v7 — pass-through ranker. The old implementation collapsed all scores to ~0.30 by applying: hub_confidence × max(0, sig_weight − 1.96 × σ_total) Root cause (see why_scores_collapse_to_0_30 analysis): • hub_confidence is ≥ 0.99 for all assets → no cross-asset differentiation • sig_weight ≈ 0.50 for untrained net → same for everyone • Gaussian penalty 1.96 × 0.10 = 0.196 → floors everyone at 0.304 • Result: 4th-decimal hub_confidence edge permanently locks rank order FIX: return the raw significance logit z_i unchanged. Differentiation and cross-asset competition are handled upstream by the low-temperature softmax in rank_and_gate() (τ = 0.05). No hub_confidence gate, no Gaussian clip, no lower-bound clipping. """ def __init__( self, confidence_level: float = 0.95, # retained for API compat uncertainty_veto_threshold: float = 0.3, # retained for API compat ): self.confidence_level = confidence_level self.uncertainty_veto_threshold = uncertainty_veto_threshold def compute_conservative_priority( self, hub_confidence: float, # IGNORED — was the score-compression lever significance_weight: float, # raw logit or sigmoid weight from AXRVINet aleatoric_std: float = 0.1, # IGNORED — Gaussian penalty removed epistemic_std: float = 0.0, # IGNORED — Gaussian penalty removed ) -> float: # Return significance_weight unchanged. # The old formula hub_confidence × max(0, sig − 1.96σ) is gone. # Ranking competition happens via Softmax(logits / τ=0.05) downstream. return significance_weight def should_veto(self, epistemic_std: float, aleatoric_std: float) -> bool: # Veto logic retained for DynamicExecutionGate compatibility; never # called to suppress ranking scores any more. return math.sqrt(epistemic_std ** 2 + aleatoric_std ** 2) \ > self.uncertainty_veto_threshold # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 11 — DYNAMIC EXECUTION GATE (v4) # ══════════════════════════════════════════════════════════════════════════════════════ class DynamicExecutionGate: """ Seven-gate volatility/uncertainty/jump-adjusted execution filter (v4). Applied as Gate D, after the triple hub/significance/threshold gates (v3). Threshold adapts to recent performance history. """ def __init__( self, base_threshold: float = 0.0, # Dynamic: permissive start, adapts upward volatility_risk_aversion: float = 1.0, max_epistemic_uncertainty: float = 0.5, max_jump_risk: float = 0.7, high_vol_threshold: float = 2.0, martingale_epsilon: float = 0.05, # [S7] Gate E threshold ): self.base_threshold = base_threshold self.volatility_risk_aversion = volatility_risk_aversion self.max_epistemic_uncertainty = max_epistemic_uncertainty self.max_jump_risk = max_jump_risk self.high_vol_threshold = high_vol_threshold self._martingale_epsilon = martingale_epsilon # [S7] self.current_threshold = base_threshold self._performance_history: deque = deque(maxlen=100) def should_execute( self, hub_confidence: float, significance: float, volatility_ratio: float, jump_risk: float, epistemic_std: float, aleatoric_std: float, dominant_signal: str, martingale_deviation: float = 0.0, # [S7] default 0.0: callers that omit this are conservatively blocked by Gate E ) -> Tuple[bool, str]: if dominant_signal not in ("BUY", "SELL"): return False, f"Signal is {dominant_signal} (not directional)" if hub_confidence <= 0.0: return False, f"Hub confidence {hub_confidence:.3f} is zero" if significance < self.current_threshold: return False, ( f"Significance {significance:.3f} < threshold {self.current_threshold:.3f}" ) total_std = math.sqrt(epistemic_std ** 2 + aleatoric_std ** 2) if total_std > self.max_epistemic_uncertainty: return False, ( f"Total uncertainty {total_std:.3f} > {self.max_epistemic_uncertainty}" ) if jump_risk > self.max_jump_risk: return False, f"Jump risk {jump_risk:.3f} > {self.max_jump_risk}" if volatility_ratio > self.high_vol_threshold: return False, f"Vol ratio {volatility_ratio:.3f} > {self.high_vol_threshold}" risk_adj_sig = significance / (1 + self.volatility_risk_aversion * volatility_ratio) if risk_adj_sig < self.base_threshold * 0.5: return False, f"Risk-adjusted sig {risk_adj_sig:.3f} too low" # [S7] Gate E — Martingale null-hypothesis filter # H₀: E[Δ log S_{t+1} | F_t] = 0 (price process is a martingale) # Trade only if DevMart(t) > ε, i.e. the null is statistically rejected. if martingale_deviation <= self._martingale_epsilon: return False, ( f"Gate E: DevMart={martingale_deviation:.4f} ≤ ε={self._martingale_epsilon:.4f} " f"(returns consistent with martingale H₀ — no tradeable edge)" ) return True, "All gates passed (A–E)" def record_performance(self, sharpe: float) -> None: self._performance_history.append(sharpe) if len(self._performance_history) >= 10: recent_sharpe = float(np.mean(self._performance_history)) adjustment = 0.1 * (-recent_sharpe) new_thresh = float(np.clip(self.base_threshold + adjustment, 0.0, 0.8)) self.current_threshold = 0.9 * self.current_threshold + 0.1 * new_thresh def state_dict(self) -> dict: """Return serialisable gate state for checkpointing.""" return { "current_threshold": self.current_threshold, "performance_history": list(self._performance_history), } def load_state_dict(self, state: dict) -> None: """Restore gate state from a dict produced by state_dict().""" self.current_threshold = state.get("current_threshold", self.current_threshold) hist = state.get("performance_history", []) if hist: self._performance_history.extend(hist) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 12 — GIRSANOV PRIORITISED REPLAY BUFFER (v4) # ══════════════════════════════════════════════════════════════════════════════════════ class GirsanovReplayBuffer: """ Prioritised experience replay buffer with Girsanov-motivated weighting (v4). Priority ∝ TD error + 1/volatility + |reward| (1/σ weighting is motivated by Girsanov's theorem: changes of measure in stochastic calculus weight by the inverse diffusion coefficient.) Extended vs v4: is_ready() added for HybridTrainer compatibility. Episode dict extended to include both trainer keys (selected_idx, pnl_per_asset) and Girsanov keys (volatility, td_error). """ def __init__(self, capacity: int, config: Optional[GirsanovReplayConfig] = None): self.capacity = capacity self.config = config or GirsanovReplayConfig() self.buffer: List[dict] = [] self.priorities: List[float] = [] self.position = 0 def push(self, episode: dict) -> None: # [FIX-4a] Cap reward magnitude before storing so extreme jump events # don't dominate the priority distribution indefinitely. # tanh-squashed rewards are already in (−1,+1); raw rewards from # external callers might not be — hard-clip as a safety net. if "reward" in episode: episode = episode.copy() episode["reward"] = float(np.clip(episode["reward"], -1.0, 1.0)) priority = self._compute_priority(episode) priority = max(self.config.min_priority, min(self.config.max_priority, priority)) if len(self.buffer) < self.capacity: self.buffer.append(episode) self.priorities.append(priority) else: self.buffer[self.position] = episode self.priorities[self.position] = priority self.position = (self.position + 1) % self.capacity def _compute_priority(self, episode: dict) -> float: priority = 0.0 if self.config.use_td_error_weighting: # [FIX-4b] Clamp TD-error contribution so a single extreme transition # cannot monopolise sampling. Raw TD errors from an unstable early # training phase can be O(1000); cap at 5.0 to limit priority skew. td_contrib = min(abs(episode.get("td_error", 0.0)), 5.0) priority += td_contrib if self.config.use_vol_weighting: vol = episode.get("volatility", 1.0) priority += 1.0 / (vol + 1e-8) # [FIX-4c] Reward already capped in push(); still clip here defensively. priority += min(abs(episode.get("reward", 0.0)), 1.0) return max(self.config.min_priority, priority) def sample(self, batch_size: int) -> List[dict]: if not self.buffer: return [] # ── PnL-variance filter ─────────────────────────────────────────────── # Batches where pnl_per_asset ≈ constant across all assets yield # z_best ≈ z_worst → zero gradient for L_rank → significance_weight # stays uniform at 0.5 for all assets (the original collapse bug). # # Filter: keep only episodes where Var(pnl_per_asset) ≥ MIN_PNL_VAR. # The threshold is intentionally small (1e-6) so only truly flat PnL # vectors (e.g. all zeros on cold-start) are excluded. MIN_PNL_VAR = 1e-6 informative = [ ep for ep in self.buffer if ( "pnl_per_asset" in ep and float(np.var(ep["pnl_per_asset"])) >= MIN_PNL_VAR ) ] n_filtered = len(self.buffer) - len(informative) if n_filtered > 0: logger.debug( f"[GirsanovReplayBuffer] PnL-variance filter: " f"{n_filtered}/{len(self.buffer)} episodes excluded " f"(uniform pnl_per_asset — no gradient signal for L_rank)" ) # Fall back to full buffer if not enough informative episodes yet pool = informative if len(informative) >= batch_size else self.buffer if self.config.use_prioritized: # Build priority weights from the filtered pool pool_indices = [self.buffer.index(ep) for ep in pool if ep in self.buffer] all_prios = np.array(self.priorities) pool_prios = all_prios[pool_indices] if pool_indices else all_prios probs = pool_prios ** self.config.priority_alpha probs /= probs.sum() chosen = np.random.choice( len(pool), min(batch_size, len(pool)), p=probs, replace=False, ) weights = (1.0 / (len(pool) * probs[chosen])) ** self.config.priority_beta weights /= weights.max() samples = [] for idx, w in zip(chosen, weights): ep = pool[idx].copy() ep["importance_weight"] = float(w) samples.append(ep) return samples indices = np.random.choice( len(pool), min(batch_size, len(pool)), replace=False ) return [pool[i].copy() for i in indices] def update_priorities(self, indices: List[int], td_errors: List[float]) -> None: for idx, td_err in zip(indices, td_errors): if idx < len(self.priorities): new_p = self.priorities[idx] + abs(td_err) self.priorities[idx] = max( self.config.min_priority, min(self.config.max_priority, new_p) ) def is_ready(self, min_size: int) -> bool: """True once the buffer has at least `min_size` episodes.""" return len(self.buffer) >= min_size def __len__(self) -> int: return len(self.buffer) def save_state(self) -> dict: """Return serialisable buffer state for checkpointing.""" return { "capacity": self.capacity, "buffer": self.buffer, "priorities": self.priorities, "position": self.position, } def load_state(self, state: dict) -> None: """Restore buffer state from a dict produced by save_state().""" self.capacity = state.get("capacity", self.capacity) self.buffer = state.get("buffer", []) self.priorities = state.get("priorities", []) self.position = state.get("position", 0) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 13 — HYBRID TRAINER (v3) # ══════════════════════════════════════════════════════════════════════════════════════ class HybridTrainer: """ [S1] Trains AXRVINet with four loss components (v6): L_rl — RL (TD error): pushes value estimates toward discounted returns L_ce — Value-consistency (Shreve CE): (V̂_t − R_{t+τ})² [S1] Interprets model output as E[R_{t+τ} | F_t] and enforces that. L_rank — Ranking margin: high-PnL assets must have higher significance_logit L_risk — Uncertainty penalty: minimise aleatoric variance on selected assets total_loss = L_rl + λ_ce·L_ce + λ_rank·L_rank + λ_risk·L_risk Note: next_sequences in the episode dict must be a genuinely distinct future state (s_{t+1} ∈ F_{t+1}) captured after the trade closed. This is enforced by QuasarAXRVIBridge._close_position via the pending-episode store [S2]. """ def __init__( self, model: AXRVINet, lr: float = LEARNING_RATE, gamma: float = GAMMA, lambda_rank: float = LAMBDA_RANK, lambda_risk: float = LAMBDA_RISK, lambda_ce: float = 0.3, # [S1] value-consistency weight rank_margin: float = 0.1, ranker_logger: Optional[object] = None, ): self.model = model self.gamma = gamma self.lambda_rank = lambda_rank self.lambda_risk = lambda_risk self.lambda_ce = lambda_ce # [S1] self.lambda_ql = 0.2 # [v7] quantile distribution calibration self.rank_margin = rank_margin self.ranker_logger = ranker_logger # [v8] MoE load-balance and dendritic gate-entropy regularisation weights. # Pull from the model's AXRVIConfig so they stay in one place. _cfg = getattr(model, "config", None) self.lambda_moe = float(getattr(_cfg, "moe_balance_coeff", 0.01)) self.lambda_gate = float(getattr(_cfg, "lambda_gate_reg", 0.005)) self.lambda_crps = float(getattr(_cfg, "lambda_crps", 0.10)) # [Bug 4 fix] self.lambda_rent = float(getattr(_cfg, "lambda_entropy", 0.01)) # [Bug 4 fix] # Dispersion regulariser: penalises uniform logits (−Var(z)). self.lambda_div = float(getattr(_cfg, "lambda_div", 0.05)) # QCSAM/FABLE Hilbert alignment loss weight [live integration] self.lambda_align = float(getattr(_cfg, "qcsam_align_loss_weight", 0.05)) self.optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) self.scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=1000, eta_min=lr / 10 ) # Track which lazy-init parameter sets have already been added to the # optimizer so _register_lazy_params() is idempotent. self._lazy_params_registered: bool = False self.train_step = 0 self.loss_history = deque(maxlen=200) # ── AMP (Automatic Mixed Precision) ────────────────────────────────── # Reads use_amp from the model's AXRVIConfig so the flag is the single # source of truth. Falls back to False on CPU (AMP is CUDA-only). _cfg = getattr(model, "config", None) _use_amp = bool(getattr(_cfg, "use_amp", False)) _device = next(model.parameters()).device # AMP is only meaningful on CUDA; silently disable on CPU self.use_amp = _use_amp and _device.type == "cuda" self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp) def _register_lazy_params(self) -> None: """Add lazily-created FABLECLCU parameters to the optimizer. FABLECLCU._log_r / .phase are created inside QuantumMultiHeadAttention._ensure_clcu() on the *first* forward pass, which happens *after* AdamW was constructed in __init__. Those params therefore never appear in optimizer.param_groups and would otherwise be frozen at their zero-initialisation values throughout training. This method is idempotent — once called it sets _lazy_params_registered to True and becomes a no-op, so it can safely be called at the top of every train_on_batch() call. BUG FIX (v6.1): On checkpoint resume, axrvi_net.load_state_dict() is called before the optimizer is first used, which causes the FABLECLCU parameters to already be yielded by model.parameters() and therefore already present in the base AdamW param group. The old code then called add_param_group() with the same tensors → PyTorch raised: "some parameters appear in more than one parameter group" crashing train_on_batch() and halting all training indefinitely. Fix: collect the set of parameter ids already registered in any existing param group, and only add the genuinely new (unseen) params. If nothing is new, just mark as registered and return. """ if self._lazy_params_registered: return qmha = getattr( getattr(getattr(self.model, "qcsam_layer", None), "qmha", None), "clcu_heads", None, ) if qmha is None: return # first forward pass has not happened yet — retry next batch clcu_params = list(qmha.parameters()) if not clcu_params: return # ── Guard: filter out params already in any optimizer group ─────────── # This prevents the "parameters appear in more than one parameter group" # error that fires on checkpoint resume when model weights (including # lazily-created FABLECLCU params) are loaded before add_param_group(). existing_ids = { id(p) for grp in self.optimizer.param_groups for p in grp["params"] } new_params = [p for p in clcu_params if id(p) not in existing_ids] if not new_params: # All FABLECLCU params already covered by the base group (post-resume). self._lazy_params_registered = True logger.info( "[HybridTrainer] FABLECLCU params already in optimizer base group " "(checkpoint resume path) — skipping add_param_group." ) return self.optimizer.add_param_group({ "params": new_params, "lr": self.optimizer.defaults["lr"], "weight_decay": self.optimizer.defaults["weight_decay"], }) self._lazy_params_registered = True logger.info( f"[HybridTrainer] Registered {len(new_params)} FABLECLCU parameter(s) " "with AdamW (lazy-init fix)." ) def train_on_batch(self, episodes: List[dict]) -> dict: """ v7 training — 5-loss objective: L_rl — TD error (unchanged) L_ce — Value-consistency [S1] (unchanged) L_rank — Ranking margin (unchanged) L_risk — Uncertainty penalty (unchanged) L_ql — Quantile / pinball loss [NEW]: calibrates full return distribution total = L_rl + λ_ce·L_ce + λ_rank·L_rank + λ_risk·L_risk + λ_ql·L_ql """ if not episodes: return {} valid = [ ep for ep in episodes if "sequences" in ep and "next_sequences" in ep and "selected_idx" in ep and "pnl_per_asset" in ep ] if not valid: logger.warning( f"[HybridTrainer] No valid episodes in batch of {len(episodes)} — " f"skipping training step. Episodes may be missing required keys: " f"'sequences', 'next_sequences', 'selected_idx', 'pnl_per_asset'." ) return {} # ── [FIX] Drop episodes whose N dimension doesn't match the current model ── # This happens when new assets are added mid-run: old replay-buffer episodes # have shape (N_old, T, F) while the model now expects (N_new, T, F). # torch.stack requires identical shapes, so mismatched episodes crash training. expected_n = self.model.num_assets size_before = len(valid) valid = [ ep for ep in valid if ( len(ep["sequences"]) == expected_n and len(ep["next_sequences"]) == expected_n and len(ep["pnl_per_asset"]) == expected_n ) ] dropped = size_before - len(valid) if dropped: logger.warning( f"[HybridTrainer] Dropped {dropped}/{size_before} episodes with " f"stale asset-count (expected N={expected_n}). " f"These are left-over from before new assets were added and will " f"expire naturally as the replay buffer fills with fresh episodes." ) if not valid: logger.warning( f"[HybridTrainer] All {size_before} episodes have wrong N " f"(expected {expected_n}) — skipping this training step." ) return {} self.model.train() device = next(self.model.parameters()).device def _build(key: str) -> torch.Tensor: """Stack per-episode arrays → (B, N, T, F) or (B, N) depending on key.""" return torch.stack( [torch.tensor(ep[key], dtype=torch.float32) for ep in valid] ).to(device) seq_t = _build("sequences") # (B, N, T, F) next_seq_t = _build("next_sequences") # (B, N, T, F) selected = torch.tensor([ep["selected_idx"] for ep in valid], dtype=torch.long, device=device) rewards = torch.tensor([ep["reward"] for ep in valid], dtype=torch.float32, device=device) pnl_arr = torch.tensor( np.stack([ep["pnl_per_asset"] for ep in valid]), dtype=torch.float32, device=device) imp_w = torch.tensor([ep.get("importance_weight", 1.0) for ep in valid], dtype=torch.float32, device=device) # ── Forward pass (AMP-aware) ────────────────────────────────────────── with torch.cuda.amp.autocast(enabled=self.use_amp): out = self.model(seq_t) # Fix 1 — register FABLECLCU params that were lazily created during # this first forward pass but were absent when AdamW was constructed. self._register_lazy_params() scores = out["significance_logits"] value = out["value"].squeeze(-1) # (B, N) — median Q̂_{0.5} log_var = out["log_var"].squeeze(-1) # (B, N) quantiles = out["quantiles"] # (B, N, n_quantiles) with torch.no_grad(): next_out = self.model(next_seq_t) best_next_v = next_out["value"].squeeze(-1).max(dim=1).values # ── [FIX-7a] Reward distribution diagnostics ───────────────────── with torch.no_grad(): r_mean = rewards.mean().item() r_std = rewards.std().item() r_max = rewards.abs().max().item() if r_max > 0.95: # alert if rewards are near saturation logger.warning( f"[TrainDiag] Reward saturation: " f"mean={r_mean:.4f} std={r_std:.4f} absmax={r_max:.4f} " f"— check tanh squashing scale." ) else: logger.debug( f"[TrainDiag] Reward stats: " f"mean={r_mean:.4f} std={r_std:.4f} absmax={r_max:.4f}" ) # L_rl — TD error [FIX-2a: Huber replaces MSE to prevent quadratic explosion] selected_v = value.gather(1, selected.unsqueeze(1)).squeeze(1) td_target = rewards + self.gamma * best_next_v # Huber(δ=1): linear for |err|>1, quadratic for |err|≤1 — bounded gradient l_rl_raw = (imp_w * F.huber_loss(selected_v, td_target.detach(), reduction="none", delta=1.0)).mean() # L_ce — [S1] Value-consistency: selected_v ≈ E[R_{t+τ} | F_t] # [FIX-2b: Huber replaces MSE] same rationale as L_rl l_ce_raw = (imp_w * F.huber_loss(selected_v, rewards.detach(), reduction="none", delta=1.0)).mean() # L_rank — Pairwise ranking loss (Burges et al., 2005 / RankNet) best_idx = pnl_arr.argmax(dim=1) # (B,) worst_idx = pnl_arr.argmin(dim=1) # (B,) z_best = scores.gather(1, best_idx.unsqueeze(1)).squeeze(1) # (B,) z_worst = scores.gather(1, worst_idx.unsqueeze(1)).squeeze(1) # (B,) l_rank_raw = -F.logsigmoid(z_best - z_worst).mean() # L_div — Dispersion regulariser l_div = -scores.var(dim=-1).mean() # L_risk — Uncertainty penalty l_risk_raw = torch.exp(log_var.gather(1, selected.unsqueeze(1)).squeeze(1)).mean() # L_ql — Quantile / pinball loss [v7] sel_q = quantiles.gather( 1, selected.unsqueeze(1).unsqueeze(2).expand(-1, 1, quantiles.shape[-1]) ).squeeze(1) # (B, n_quantiles) tau = self.model.distributional.quantile_levels # (n_quantiles,) u = rewards.unsqueeze(1) - sel_q # (B, n_quantiles) l_ql_raw = torch.max(tau * u, (tau - 1.0) * u).mean() # L_moe — MoE load-balance regularisation [v8] l_moe = out.get("moe_balance_loss", torch.tensor(0.0, device=seq_t.device)) # L_gate — DendriticFFN gate-entropy regularisation [v8] l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device)) # L_crps — distributional calibration (CRPS proper scoring rule) [v8] ql_levels = self.model.distributional.quantile_levels l_crps = crps_loss(sel_q, ql_levels, rewards) # L_rent — regime-router entropy regularisation [v8] r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes) regime_ent = -(r_probs * r_probs.log()).sum(-1).mean() l_rent = -regime_ent # L_align — QCSAM/FABLE Hilbert-space alignment [live QMHA integration] l_align = out.get("qcsam_align_loss", torch.tensor(0.0, device=seq_t.device)) # ── [FIX-5] Per-component loss normalization ────────────────────── # Normalise each stochastic-calculus-critical component to unit # scale before weighting. This prevents a single component from # overwhelming the gradient when its natural scale drifts (e.g. # l_risk can be >>1 when log_var is large; l_rank is in [0, ln2]). _eps_norm = 1e-6 def _safe_norm(t: torch.Tensor) -> torch.Tensor: """Divide by running L1 norm to keep component at O(1).""" mag = t.detach().abs().clamp(min=_eps_norm) return t / mag l_rl = _safe_norm(l_rl_raw) l_ce = _safe_norm(l_ce_raw) l_rank = _safe_norm(l_rank_raw) l_risk = _safe_norm(l_risk_raw) l_ql = _safe_norm(l_ql_raw) # Total loss — 11-component objective (v8 + QCSAM alignment) loss = (l_rl + self.lambda_ce * l_ce + self.lambda_rank * l_rank + self.lambda_risk * l_risk + self.lambda_ql * l_ql + self.lambda_moe * l_moe + self.lambda_gate * l_gate + self.lambda_crps * l_crps + self.lambda_rent * l_rent + self.lambda_div * l_div + self.lambda_align * l_align) # ── Backward pass (AMP-aware) ───────────────────────────────────────── self.optimizer.zero_grad() self.scaler.scale(loss).backward() # Unscale before clipping so grad norms are in the original fp32 scale self.scaler.unscale_(self.optimizer) # [FIX-3] Gradient clipping — enforced every step; max_norm=1.0 # Unscaled norms are logged for diagnostics [FIX-7b]. grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) grad_norm_val = float(grad_norm) if grad_norm_val > 0.9: # approaching or hitting clip boundary logger.warning( f"[TrainDiag] Grad norm {grad_norm_val:.4f} " f"{'CLIPPED' if grad_norm_val >= 1.0 else 'near clip'} " f"at step {self.train_step + 1}." ) self.scaler.step(self.optimizer) self.scaler.update() # scheduler.step() is called externally (once per training epoch / rank # cycle) via step_scheduler(), NOT here per batch. Calling it here # compressed the entire T_max=1000 cosine schedule into ~14 hours of # 5-second rank cycles rather than the intended ~1000 training epochs. self.train_step += 1 loss_dict = { "total": loss.item(), # [FIX-7c] Raw (pre-norm) component values for diagnostics "rl": l_rl_raw.item(), "ce": l_ce_raw.item(), "rank": l_rank_raw.item(), "risk": l_risk_raw.item(), "ql": l_ql_raw.item(), # Normalised versions (what actually enters the gradient) "rl_norm": l_rl.item(), "ce_norm": l_ce.item(), "rank_norm": l_rank.item(), "risk_norm": l_risk.item(), "ql_norm": l_ql.item(), "div": l_div.item(), "moe": l_moe.item() if hasattr(l_moe, "item") else float(l_moe), "gate": l_gate.item() if hasattr(l_gate, "item") else float(l_gate), "crps": l_crps.item() if hasattr(l_crps, "item") else float(l_crps), "rent": l_rent.item() if hasattr(l_rent, "item") else float(l_rent), "align": l_align.item() if hasattr(l_align, "item") else float(l_align), # [FIX-7b] Gradient norm diagnostic "grad_norm": grad_norm_val, # [FIX-7a] Reward distribution summary "reward_mean": r_mean, "reward_std": r_std, "reward_absmax": r_max, "step": self.train_step, } self.loss_history.append(loss_dict) if self.ranker_logger: # [LABEL FIX] Previously passed len(valid) here — that's the # BATCH SIZE (number of episodes in this training step), not # the asset count. With TRAIN_BATCH=2 the field always showed # "asset_count=2" which looked like "only 2 of 10 assets are # training", but in fact every episode carries the full # (N=10, T, F) tensor and all 10 assets are trained per step. # # We now pass the TRUE asset count (model.num_assets) so the # dashboard/log reflects reality, and also log batch_size in # the human-readable line below so batch health stays visible. self.ranker_logger.training_update( step=self.train_step, loss=loss.item(), lr=self.optimizer.param_groups[0]["lr"], asset_count=self.model.num_assets, ) logger.info( f"🧠 [TrainingStep {self.train_step:>6d}] " f"total={loss.item():.4f} " f"assets={self.model.num_assets} batch={len(valid)}/{len(episodes)} " f"rl={l_rl_raw.item():.4f}(n={l_rl.item():.4f}) " f"ce={l_ce_raw.item():.4f}(n={l_ce.item():.4f}) " f"rank={l_rank_raw.item():.4f} " f"risk={l_risk_raw.item():.4f} " f"ql={l_ql_raw.item():.4f} " f"div={loss_dict['div']:.4f} moe={loss_dict['moe']:.4f} " f"gate={loss_dict['gate']:.4f} crps={loss_dict['crps']:.4f} " f"rent={loss_dict['rent']:.4f} align={loss_dict['align']:.4f} " f"grad={grad_norm_val:.4f} " f"r̄={r_mean:.4f} r_σ={r_std:.4f} r_max={r_max:.4f} " f"lr={self.optimizer.param_groups[0]['lr']:.2e}" ) return loss_dict def save(self, path: str) -> None: """ Persist model + optimizer + all loss-weight scalars to *path*. Called by step_scheduler() after every LR step, and by Bridge.run() on clean shutdown for the legacy single-file backward-compat checkpoint. """ torch.save({ "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.state_dict(), "step": self.train_step, "lambda_ce": self.lambda_ce, "lambda_ql": self.lambda_ql, "lambda_rank": self.lambda_rank, "lambda_risk": self.lambda_risk, "lambda_moe": self.lambda_moe, "lambda_gate": self.lambda_gate, "lambda_crps": self.lambda_crps, # v8 / Bug 4 "lambda_rent": self.lambda_rent, # v8 / Bug 4 "lambda_align": self.lambda_align, # QCSAM integration "rank_margin": self.rank_margin, "loss_history": list(self.loss_history), }, path) logger.info(f"✅ Model saved → {path}") def step_scheduler(self, path: Optional[str] = None) -> None: """ Advance the cosine LR schedule by one step then save a checkpoint. Call once per training epoch / rank-cycle, NOT once per batch. Args: path: Destination for the checkpoint file. Defaults to the model_path stored on the config (passed in by the Bridge). Providing an explicit path keeps this method usable standalone. """ self.scheduler.step() if path: self.save(path) def load(self, path: str) -> None: try: ckpt = torch.load(path, map_location="cpu") # ── Print .pt file contents summary ────────────────────────────── file_size_mb = os.path.getsize(path) / 1_048_576 logger.info(f"") logger.info(f"┌─── PT FILE CONTENTS: {path} ({file_size_mb:.2f} MB) ───") logger.info(f"│ train_step : {ckpt.get('step', 0)}") logger.info(f"│ loss_history : {len(ckpt.get('loss_history', []))} entries") if ckpt.get("loss_history"): last = ckpt["loss_history"][-1] logger.info(f"│ last loss : total={last.get('total', 0):.4f} " f"rl={last.get('rl', 0):.4f} ce={last.get('ce', 0):.4f} " f"rank={last.get('rank', 0):.4f} risk={last.get('risk', 0):.4f}") model_keys = list(ckpt.get("model", {}).keys()) logger.info(f"│ model keys : {len(model_keys)} tensors") logger.info(f"│ optimizer keys : {'present' if 'optimizer' in ckpt else 'missing'}") lambdas = {k: round(ckpt[k], 6) for k in ("lambda_ce","lambda_ql","lambda_rank","lambda_risk", "lambda_moe","lambda_gate","lambda_crps","lambda_rent","lambda_align") if k in ckpt} if lambdas: logger.info(f"│ loss weights : {lambdas}") logger.info(f"│ rank_margin : {ckpt.get('rank_margin', '—')}") logger.info(f"└{'─' * 60}") # Fix 2 — use strict=False so that a checkpoint saved after a # forward pass (which contains lazily-created clcu_heads.* keys) # can be loaded into a freshly constructed model where those # submodules don't yet exist. Log any mismatches explicitly so # problems are visible instead of being silently swallowed. # # FIX 2b — asset-count mismatch guard (e.g. 9-asset checkpoint # loaded into a 12-asset model). The FABLECLCU cross-asset # attention weights are parameterised by N, so a size mismatch # will crash rank_and_gate() with: # "size of tensor a (9) must match tensor b (12)" # Detect it early, log clearly, and skip the incompatible keys so # the model keeps its fresh random init for those layers instead. ckpt_model = ckpt.get("model", {}) current_n = self.model.num_assets # Sniff N from the first key that embeds asset count (e.g. a # weight of shape (N, …) in the temporal encoder input projection). ckpt_n = None for key, tensor in ckpt_model.items(): if hasattr(tensor, "shape") and len(tensor.shape) >= 1: # The significance_head bias has shape (N,) — reliable marker. if "significance_head" in key and tensor.shape[0] not in (1,): ckpt_n = int(tensor.shape[0]) break if ckpt_n is not None and ckpt_n != current_n: logger.warning( f"[HybridTrainer.load] ⚠️ ASSET COUNT MISMATCH: " f"checkpoint has N={ckpt_n} assets, current model has N={current_n}. " f"Stale replay episodes (shape N={ckpt_n}) will be dropped by " f"train_on_batch(). Model weights are loaded with strict=False — " f"incompatible tensors keep their random initialisation. " f"Run with --no-resume to start fully fresh and clear this warning." ) # Strip keys whose first weight dimension matches the old N so they # don't partially overwrite the correctly-sized fresh weights. filtered_ckpt = {} skipped = [] for key, tensor in ckpt_model.items(): if (hasattr(tensor, "shape") and len(tensor.shape) >= 1 and tensor.shape[0] == ckpt_n and ckpt_n != current_n): skipped.append(key) else: filtered_ckpt[key] = tensor if skipped: logger.warning( f"[HybridTrainer.load] Skipped {len(skipped)} mismatched " f"weight tensor(s) (N={ckpt_n}→{current_n}): {skipped[:8]}" + (" …" if len(skipped) > 8 else "") ) ckpt_model = filtered_ckpt incompatible = self.model.load_state_dict(ckpt_model, strict=False) if incompatible.missing_keys: logger.warning( f"[HybridTrainer.load] Missing keys in checkpoint " f"(will keep random init): {incompatible.missing_keys}" ) if incompatible.unexpected_keys: logger.warning( f"[HybridTrainer.load] Unexpected keys in checkpoint " f"(ignored): {incompatible.unexpected_keys}" ) self.optimizer.load_state_dict(ckpt["optimizer"]) # Restore scheduler position so cosine LR continues from saved step if "scheduler" in ckpt: self.scheduler.load_state_dict(ckpt["scheduler"]) self.train_step = ckpt.get("step", 0) self.lambda_ce = ckpt.get("lambda_ce", self.lambda_ce) self.lambda_ql = ckpt.get("lambda_ql", self.lambda_ql) self.lambda_rank = ckpt.get("lambda_rank", self.lambda_rank) self.lambda_risk = ckpt.get("lambda_risk", self.lambda_risk) self.lambda_moe = ckpt.get("lambda_moe", self.lambda_moe) self.lambda_gate = ckpt.get("lambda_gate", self.lambda_gate) self.lambda_crps = ckpt.get("lambda_crps", self.lambda_crps) # [Bug 4] self.lambda_rent = ckpt.get("lambda_rent", self.lambda_rent) # [Bug 4] self.lambda_align= ckpt.get("lambda_align", self.lambda_align) # QCSAM self.rank_margin = ckpt.get("rank_margin", self.rank_margin) if "loss_history" in ckpt: # FIX 4c: replace (not extend) to avoid doubling on resume _maxlen = self.loss_history.maxlen self.loss_history = deque(ckpt["loss_history"], maxlen=_maxlen) logger.info(f"✅ Model loaded ← {path} | resumed from train_step={self.train_step}") except FileNotFoundError: logger.info(f"ℹ️ No checkpoint at {path} — starting fresh") except Exception as e: logger.warning(f"⚠️ Could not load model: {e} — starting fresh") # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 14 — UNIFIED REWARD CALCULATOR (v3 strategies + v4 Ito correction) # ══════════════════════════════════════════════════════════════════════════════════════ class UnifiedRewardCalculator: """ [S4] Reward in Shreve GBM / Itô log-return units (v6). R = sgn(a_t) · log(S_{t+τ} / S_t) − commission − slippage_bps/10_000 This is the natural reward under dS = μS dt + σS dW (Shreve §1): log(S_{t+τ}/S_t) = (μ − ½σ²)τ + σ(W_{t+τ} − W_t) Sharpe / Sortino variants divide by realised vol of log-returns over the trade window, preserving the log-return base in both cases. The previous ito_simple_reward arithmetic-return logic is replaced entirely. ito_correction parameter retained for API compatibility but is now a no-op (the formula is already Itô-consistent by construction). """ ANNUALIZATION = 252 * 24 * 12 # bars per year def __init__( self, strategy: str = "simple", use_ito_correction: bool = True, # kept for API compat; no-op in v6 fee_per_trade: float = 0.001, slippage_bps: float = 2.0, ): self.strategy = strategy self.fee_per_trade = fee_per_trade self.slippage_bps = slippage_bps # ── Public API ──────────────────────────────────────────────────────────── def compute_reward( self, trade: "Trade", price_history: Optional[List[float]] = None, ) -> float: if trade.state != PositionState.CLOSED: raise ValueError("Trade must be closed before computing reward") if self.strategy == "sharpe": return self._sharpe_reward(trade, price_history) elif self.strategy == "sortino": return self._sortino_reward(trade, price_history) return self._log_return_reward(trade) # ── Core log-return reward [S4] ─────────────────────────────────────────── # Rolling reward history for volatility normalisation (class-level so all # instances share empirical scale; protected by GIL for CPython threads). _reward_history: deque = deque(maxlen=500) # ── [FIX-1a] Reward squashing ───────────────────────────────────────────── @staticmethod def _squash_reward(r: float, scale: float = 0.10) -> float: """ Soft-clip via tanh to keep rewards in (−1, +1). scale controls the sensitivity: • scale=0.10 → a raw log-return of ±0.10 (±10%) maps to ≈ ±0.76 • scale=0.05 → tighter; use for very-low-vol assets The tanh transform preserves ordinality (monotone) and sign, so log-return semantics [S4] are fully intact. It is differentiable everywhere, making it safer than hard clipping. """ return float(math.tanh(r / scale)) # ── [FIX-1b] Rolling-volatility normaliser ──────────────────────────────── @classmethod def _vol_normalise(cls, r: float) -> float: """ Optionally normalise by empirical rolling σ of recent rewards. Activated only when ≥20 samples are available; otherwise returns r. This is the discrete analogue of dividing by the diffusion coefficient σ_t. """ if len(cls._reward_history) < 20: return r sigma = float(np.std(cls._reward_history)) if sigma < 1e-8: return r # Divide then re-squash so we never blow up on sigma→0 edge cases return float(math.tanh(r / (3.0 * sigma))) # 3σ → tanh ≈ 1 def _log_return_reward(self, trade: "Trade") -> float: """ [S4] R = tanh( sgn(a_t) · log(S_{t+τ}/S_t) − fees − slip / scale ) Squashing via tanh keeps rewards bounded in (−1,+1) while preserving log-return semantics and the sign of the trade direction. """ if trade.exit_price is None or trade.entry_price <= 0: return 0.0 log_ret = math.log(trade.exit_price / trade.entry_price) if trade.direction == TradeDirection.SHORT: log_ret = -log_ret # sgn(a_t) slippage = self.slippage_bps / 10_000.0 raw = float(log_ret - self.fee_per_trade - slippage) squashed = self._squash_reward(raw) # [FIX-1a] self._reward_history.append(squashed) return squashed # ── Risk-adjusted variants (preserve log-return base) ──────────────────── def _sharpe_reward( self, trade: "Trade", price_history: Optional[List[float]], ) -> float: base = self._log_return_reward(trade) # already squashed if not price_history or len(price_history) < 2: return base log_rets = np.diff(np.log(np.maximum(price_history, 1e-12))) vol = float(np.std(log_rets)) # Normalise raw base (un-squash → normalise → re-squash) raw_unsquashed = math.atanh(max(-0.9999, min(0.9999, base))) * 0.10 normed = raw_unsquashed / vol if vol > 1e-8 else raw_unsquashed squashed = self._squash_reward(normed) self._reward_history.append(squashed) return squashed def _sortino_reward( self, trade: "Trade", price_history: Optional[List[float]], ) -> float: base = self._log_return_reward(trade) # already squashed if not price_history or len(price_history) < 2: return base log_rets = np.diff(np.log(np.maximum(price_history, 1e-12))) downside = log_rets[log_rets < 0] dvol = float(np.std(downside)) if len(downside) > 0 else 1e-8 raw_unsquashed = math.atanh(max(-0.9999, min(0.9999, base))) * 0.10 normed = raw_unsquashed / dvol if dvol > 1e-8 else raw_unsquashed squashed = self._squash_reward(normed) self._reward_history.append(squashed) return squashed # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 15 — PRICE STREAMER (v3 structure + v4 price_history list) # ══════════════════════════════════════════════════════════════════════════════════════ class PriceStreamer: """ Per-asset live price stream manager. Merged: v3 (PriceTick deque, get_ticks()) + v4 (price_history list, get_latest_price()). """ def __init__(self, symbol: str, history_len: int = 100): self.symbol = symbol self._ticks: deque = deque(maxlen=history_len) self._lock = Lock() self.latest_bid = 0.0 self.latest_ask = 0.0 self.latest_mid = 0.0 self.last_update_time = time.time() # avoids false stale on startup self.tick_count = 0 self.price_history: deque = deque(maxlen=history_len) # v4: for reward calculator; bounded to prevent memory growth def on_tick(self, bid: float, ask: float, timestamp: float) -> None: with self._lock: mid = (bid + ask) / 2.0 self.latest_bid = bid self.latest_ask = ask self.latest_mid = mid self.last_update_time = timestamp self.tick_count += 1 self.price_history.append(mid) self._ticks.append(PriceTick( symbol=self.symbol, bid=bid, ask=ask, mid=mid, timestamp=timestamp )) def get_latest_price(self) -> float: with self._lock: return self.latest_mid def get_latest_tick(self) -> Optional[PriceTick]: with self._lock: return self._ticks[-1] if self._ticks else None def get_ticks(self, n: int = 10) -> List[PriceTick]: with self._lock: return list(self._ticks)[-n:] def get_price_history(self, n: int = 20) -> List[float]: with self._lock: return self.price_history[-n:] if self.price_history else [] def get_average_price(self, n: int = 10) -> float: with self._lock: ticks = list(self._ticks)[-n:] return float(np.mean([t.mid for t in ticks])) if ticks else self.latest_mid def get_spread(self) -> float: with self._lock: return self.latest_ask - self.latest_bid def is_stale(self, timeout_sec: float = PRICE_UPDATE_TIMEOUT) -> bool: return (time.time() - self.last_update_time) > timeout_sec # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 16 — POSITION MANAGER (v3 with logger hooks) # ══════════════════════════════════════════════════════════════════════════════════════ class PositionManager: """ Broker-backed trade registry. Lifecycle contract: register_pending_buy() — buy sent to broker; trade enters PENDING confirm_buy() — broker buy confirmation received; trade enters OPEN mark_closing() — sell sent to broker; trade enters CLOSING close_trade_from_broker() — broker terminal event; trade enters CLOSED NO local PnL computation is authoritative. NO state transitions happen without a broker event. The legacy close_trade() method is retained as a thin wrapper around close_trade_from_broker() for backward compatibility with callers that pass an exit_price (e.g. checkpoint restore). It does NOT simulate fills. """ def __init__(self, ranker_logger: Optional[object] = None): self._open_trades: Dict[str, Trade] = {} # contract_id NOT yet known (PENDING) self._closed_trades: List[Trade] = [] self._lock = Lock() self.total_realized_pnl: float = 0.0 self.total_fees: float = 0.0 self.trades_opened: int = 0 self.trades_closed: int = 0 self.ranker_logger = ranker_logger # ── Phase 1: send buy ───────────────────────────────────────────────────── def register_pending_buy( self, trade_id: str, asset: str, direction: TradeDirection, quantity: float, broker_symbol: Optional[str] = None, ) -> Trade: """ Create a PENDING trade immediately after the buy message is sent. No price, no PnL, no fill — contract_id is unknown until broker confirms. """ trade = Trade( trade_id = trade_id, asset = asset, direction = direction, quantity = quantity, entry_time = time.time(), broker_symbol = broker_symbol, state = PositionState.PENDING, ) with self._lock: self._open_trades[trade_id] = trade self.trades_opened += 1 logger.info( f"⏳ [{asset}] BUY SENT | trade_id={trade_id} | " f"dir={direction.value.upper()} | awaiting broker confirmation" ) return trade # ── Phase 2: broker buy confirmation ────────────────────────────────────── def confirm_buy( self, trade_id: str, contract_id: str, buy_price: float, entry_tick: float, transaction_id: Optional[str] = None, shortcode: Optional[str] = None, broker_symbol: Optional[str] = None, ) -> Optional[Trade]: """ Called from _on_deriv_message when a 'buy' response arrives. Transitions trade PENDING → OPEN and binds all broker contract details. """ with self._lock: trade = self._open_trades.get(trade_id) if trade is None: logger.warning( f"[PositionManager.confirm_buy] trade_id={trade_id} not found" ) return None trade.confirm_open( contract_id = contract_id, buy_price = buy_price, entry_tick = entry_tick, transaction_id = transaction_id, shortcode = shortcode, broker_symbol = broker_symbol or trade.broker_symbol, ) if self.ranker_logger: self.ranker_logger.trade_open( trade_id = trade_id, asset = trade.asset, direction = trade.direction.value, price = entry_tick, qty = trade.quantity, ) logger.info( f"✅ [{trade.asset}] TRADE OPENED | trade_id={trade_id} | " f"contract_id={contract_id} | entry_tick={entry_tick:.4f} | " f"buy_price={buy_price:.4f}" ) return trade # ── Phase 3: send sell ──────────────────────────────────────────────────── def mark_closing(self, trade_id: str) -> None: """ Mark a trade as CLOSING after a sell request is sent. Actual close happens in close_trade_from_broker() when broker confirms. FIX: exit_time is stamped here so the CLOSING-timeout handler in monitor_positions() can correctly compute closing_duration. Without this stamp, closing_duration is always 0 and the 10-second timeout never fires, leaving trades stuck in CLOSING forever. """ with self._lock: trade = self._open_trades.get(trade_id) if trade: trade.state = PositionState.CLOSING trade.exit_time = time.time() # ← FIX: stamp so timeout works logger.info(f"[{trade_id}] ⏳ SELL SENT — awaiting broker terminal event") # ── Phase 4: broker terminal event ──────────────────────────────────────── def close_trade_from_broker( self, trade_id: str, status: str, profit: float, sell_price: Optional[float] = None, exit_tick: Optional[float] = None, ) -> Optional[Trade]: """ Called from _on_deriv_message when proposal_open_contract reports a terminal state. Authoritative close - profit comes directly from broker. """ with self._lock: if trade_id not in self._open_trades: return None trade = self._open_trades.pop(trade_id) # exit_tick / current_spot = actual underlying asset market price (e.g. 34074.02) # sell_price = Deriv contract dollar value (e.g. 0.98) # Prioritise market price; fall back to contract value only if nothing else exists. exit_price = exit_tick or (trade.current_spot or 0.0) or sell_price or 0.0 # ✅ Store the REAL profit from broker trade.profit = profit trade.realized_pnl = profit # authoritative broker P&L trade.sell_price = sell_price trade.exit_price = exit_price trade.exit_time = time.time() trade.status = status trade.state = PositionState.CLOSED trade.fees = 0.0 with self._lock: self._closed_trades.append(trade) self.trades_closed += 1 self.total_realized_pnl += trade.realized_pnl if self.ranker_logger: # ✅ Use REAL profit for return calculation, not price difference return_pct = (profit / trade.quantity) / trade.entry_price if trade.entry_price > 0 else 0.0 self.ranker_logger.trade_close( trade_id=trade_id, asset=trade.asset, pnl=profit, # ← REAL profit! return_pct=return_pct, exit_price=exit_price, # ✅ FIX: Include exit_price for dashboard! ) # ✅ Log the REAL profit, not the contract value. # FIX 2: Include exit_price so the dashboard log parser can populate # the EXIT column. Format mirrors the TRADE CLOSED | ID= pattern that # TRADE_CLOSE_RE_STRICT already matches. logger.info( f"INFO | TRADE | {trade.asset} | " f"TRADE CLOSED | ID={trade_id} | " f"pnl={profit:+.4f} | " f"exit_price={exit_price:.5f} | " f"status={status} | contract_id={trade.contract_id}" ) return trade # ── Legacy compatibility wrapper ─────────────────────────────────────────── def open_trade( self, trade_id: str, asset: str, direction: TradeDirection, entry_price: float, quantity: float, ) -> Trade: """ Backward-compatibility shim — delegates to register_pending_buy(). entry_price is stored as a non-authoritative hint; it will be overwritten by the broker's entry_tick on confirm_buy(). """ trade = self.register_pending_buy(trade_id, asset, direction, quantity) trade.entry_price = entry_price # non-authoritative hint return trade def close_trade( self, trade_id: str, exit_price: float, fees: float = 0.0, ) -> Optional[Trade]: """ Backward-compatibility shim — accepts an exit_price for callers that previously supplied it. Delegates to close_trade_from_broker() with profit derived from broker data if available, otherwise uses the passed exit_price only for the reward calculator (log-return path). Does NOT simulate fills or invent authoritative P&L. """ with self._lock: trade = self._open_trades.get(trade_id) if trade is None: logger.warning(f"⚠️ Trade {trade_id} not found") return None # Use broker profit if already set (authoritative); otherwise estimate. # For MULTUP/MULTDOWN: P&L = stake × multiplier × % move (not price diff × qty). # We approximate with the multiplier from ASSET_MULTIPLIER so the replay # reward is in the right ballpark before the broker terminal event arrives. if trade.profit is not None: profit = trade.profit else: ref = trade.entry_tick or trade.entry_price mult = ASSET_MULTIPLIER.get(trade.asset, 20) # ✅ FIX 3: Use actual stake paid to broker, not quantity*ref. # quantity is near-zero for Kelly-sized trades → quantity*ref ≈ 0, # collapsing the profit term and leaving only fees as the PnL. stake = trade.buy_price if (trade.buy_price and trade.buy_price > 0) else 1.0 if ref > 0 and exit_price > 0: pct_move = (exit_price - ref) / ref sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0 profit = sign * pct_move * stake * mult - fees else: profit = -fees # can't compute — assume no P&L beyond fees return self.close_trade_from_broker( trade_id = trade_id, status = trade.status or "sold", profit = profit, sell_price = exit_price, ) # ── Queries ──────────────────────────────────────────────────────────────── def get_open_trades(self) -> List[Trade]: """Return all trades that are PENDING, OPEN, or CLOSING (i.e. not CLOSED).""" with self._lock: return list(self._open_trades.values()) def get_confirmed_open_trades(self) -> List[Trade]: """Return only broker-confirmed OPEN or CLOSING trades (not PENDING).""" with self._lock: return [ t for t in self._open_trades.values() if t.state in (PositionState.OPEN, PositionState.CLOSING) ] def get_open_trade_by_asset(self, asset: str) -> Optional[Trade]: """Return the active trade for an asset, or None. Excludes CLOSING trades — sell already sent, awaiting broker confirmation.""" with self._lock: for t in self._open_trades.values(): if t.asset == asset and t.state != PositionState.CLOSING: return t return None def get_open_trade_by_contract(self, contract_id: str) -> Optional[Trade]: with self._lock: for t in self._open_trades.values(): if t.contract_id == str(contract_id): return t return None def get_unrealized_pnl(self, price_map: Dict[str, float]) -> float: """Non-authoritative spot estimate only.""" with self._lock: return sum( t.compute_unrealized_pnl(price_map.get(t.asset, t.entry_price)) for t in self._open_trades.values() ) def get_stats(self) -> dict: with self._lock: return { "open_trades": len(self._open_trades), "closed_trades": len(self._closed_trades), "trades_opened": self.trades_opened, "trades_closed": self.trades_closed, "total_realized_pnl": self.total_realized_pnl, "total_fees": self.total_fees, } # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 16b — PORTFOLIO RISK MANAGER # ══════════════════════════════════════════════════════════════════════════════════════ class PortfolioRiskManager: """ Institutional position-sizing with four layers and one unbreakable hard cap. Final size formula: notional = Kelly(μ,σ²) × significance × CVaR_adj × DD_adj notional = hard_min( notional, max_pos × total_capital ) ← NEVER BREACHED Layer 1 — Kelly / Conviction ───────────────────────────── f_kelly = kelly_fraction × (μ / σ²) μ = value_estimate (model's E[R|F_t], log-return units) σ² = realised_var (QV-based vol from feature [6], squared) Scaled by significance_weight ∈ [0,1] as a conviction multiplier. Layer 2 — CVaR adjustment ────────────────────────── CVaR@5% from the distributional head. If cvar_05 < cvar_floor → trade vetoed outright (quantity = 0). Otherwise: cvar_adj = clip(1 + cvar_05 / |cvar_floor|, 0, 1). Layer 3 — Drawdown circuit breaker (TEMPORARY HALT — 3 min default) ────────────────────────────────────────────────────────────────────── Tracks peak equity. On drawdown: ≥ drawdown_reduce_pct → all new sizes halved ≥ drawdown_halt_pct → all new trades blocked for halt_duration_secs then AUTOMATICALLY RESUMES — no manual reset needed. Layer 4 — Hard max_pos cap (UNBREAKABLE) ────────────────────────────────────────── Applied unconditionally as the very last step, in both fraction-space and notional-space (belt + braces). Cannot be overridden by any upstream layer. """ def __init__( self, asset_registry: Dict[str, dict], config: Optional[PortfolioRiskConfig] = None, ): self.asset_registry = asset_registry self.cfg = config or PortfolioRiskConfig() self._peak_equity: float = self.cfg.total_capital self._current_equity: float = self.cfg.total_capital self._equity_history: deque = deque(maxlen=500) self._equity_history.append(self._current_equity) # Committed capital per open trade: trade_id → notional self._committed: Dict[str, float] = {} # Circuit-breaker state — stores the timestamp when the halt was triggered. # None means not currently halted. Once time.time() > _halt_until, trading # automatically resumes with no manual intervention required. self._halt_until: Optional[float] = None logger.info( f"[PortfolioRiskManager] Initialized | " f"capital={self.cfg.total_capital:.2f} | " f"half-Kelly={self.cfg.kelly_fraction} | " f"halt_dd={self.cfg.drawdown_halt_pct:.0%} " f"(auto-resume after {self.cfg.halt_duration_secs:.0f}s) | " f"reduce_dd={self.cfg.drawdown_reduce_pct:.0%}" ) # ── Public API ───────────────────────────────────────────────────────────────────── def compute_position_size( self, asset_id: str, current_price: float, value_estimate: float, realized_vol: float, significance: float, cvar_05: float, fallback_notional: Optional[float] = None, ) -> Tuple[float, str]: """ Returns (quantity_in_asset_units, reason_string). PERFORMANCE RANKER REFACTOR (v7): • NEVER returns quantity=0 — the system MUST trade to collect performance data. • CVaR veto removed: demoted to LOG ONLY (trades must run to generate ranking data). • Drawdown circuit breaker removed as a BLOCKER: demoted to LOG ONLY. • Drawdown size-reduction removed: trade at full Kelly to maximise data fidelity. • If Kelly formula produces qty <= 0, falls back to fallback_notional / price (or TradeConfig.amount equivalent) so the minimum-trade guarantee is never broken by sizing arithmetic. """ # ── Hard guard: cannot size without a valid price ───────────────────── if current_price <= 0: # Even here we return the fallback rather than 0, because the caller # (process_axrvi_signal) will itself guard against price == 0. fb = (fallback_notional or self.cfg.min_notional) / max(current_price, 1.0) return fb, "Invalid price — using absolute fallback" dd = self._current_drawdown() now = time.time() # ── Layer 3: circuit breaker — LOG ONLY, never blocks ───────────────── # CHANGE v7: was return 0.0 → now logs a warning and continues. # Ranking requires live trades; a halt produces zero data. if dd >= self.cfg.drawdown_halt_pct: if self._halt_until is None: self._halt_until = now + self.cfg.halt_duration_secs logger.warning( f"[PortfolioRiskManager] ⚠️ Circuit breaker threshold HIT " f"(dd={dd:.1%} ≥ halt={self.cfg.drawdown_halt_pct:.1%}) — " f"LOGGING ONLY (not halting — ranker needs live trades for performance data)" ) elif now >= self._halt_until: self._halt_until = None logger.info("[PortfolioRiskManager] ✅ Circuit breaker cooldown expired") else: # Drawdown recovered below threshold — clear any stale halt timestamp if self._halt_until is not None and now >= self._halt_until: self._halt_until = None # ── Drawdown reduce — LOG ONLY, never halves sizes ──────────────────── # CHANGE v7: was dd_adj = 0.5 → removed entirely. # Full-size trades produce cleaner log-return signals for ranking. if dd >= self.cfg.drawdown_reduce_pct: logger.info( f"[PortfolioRiskManager] ℹ️ Drawdown reduce threshold reached " f"(dd={dd:.1%}) — size reduction SKIPPED (ranker mode: full size for data)" ) # ── Layer 2: CVaR check — LOG ONLY, never vetoes ────────────────────── # CHANGE v7: was return 0.0 → now logs and applies a mild cvar_adj only. # The ranker needs the trade to execute regardless of CVaR; blocking it # produces a gap in the performance record for this asset. if cvar_05 < self.cfg.cvar_floor: logger.info( f"[PortfolioRiskManager] ℹ️ CVaR@5%={cvar_05:.4f} below floor=" f"{self.cfg.cvar_floor:.4f} — CVaR VETO BYPASSED (ranker mode)" ) cvar_adj = max(0.0, 1.0 + cvar_05 / abs(self.cfg.cvar_floor)) if cvar_05 < 0 else 1.0 # ── Layer 1: Kelly / conviction ──────────────────────────────────────── mu = max(value_estimate, 0.0) realized_var = max(realized_vol ** 2, 1e-8) max_pos = self._get_max_pos(asset_id) f_kelly = 0.0 kelly_ok = False if mu > 0: f_kelly = self.cfg.kelly_fraction * (mu / realized_var) f_conviction = min(f_kelly * significance, max_pos) f_adjusted = f_conviction * cvar_adj # dd_adj removed (v7) # ── Layer 4 (portfolio-level): remaining risk budget ─────────────── committed_frac = self._total_committed_fraction() remaining = max(0.0, self.cfg.max_portfolio_risk - committed_frac) f_final = min(f_adjusted, remaining) # ── Layer 4 (asset-level): HARD max_pos cap — UNBREAKABLE ────────── f_final = min(f_final, max_pos) notional = f_final * self.cfg.total_capital notional = max(self.cfg.min_notional, min(self.cfg.max_notional, notional)) notional = min(notional, max_pos * self.cfg.total_capital) quantity = notional / current_price kelly_ok = quantity > 0 else: quantity = 0.0 notional = 0.0 f_final = 0.0 f_conviction = 0.0 # ── FALLBACK GUARANTEE: quantity must NEVER be 0 ───────────────────── # CHANGE v7: if Kelly produced 0 (no edge, risk-budget exhausted, etc.) # fall back to fallback_notional / price so the system always trades. # This is the minimum-data guarantee: every asset gets observed. if quantity <= 0: fb_notional = fallback_notional or self.cfg.min_notional quantity = fb_notional / current_price reason = ( f"Kelly qty=0 (μ={value_estimate:.5f}, kelly_ok={kelly_ok}) — " f"FALLBACK qty={quantity:.6f} from notional={fb_notional:.2f}" ) logger.info(f"[PortfolioRiskManager] [{asset_id}] FALLBACK | {reason}") return quantity, reason reason = ( f"f_kelly={f_kelly:.5f} sig={significance:.3f} " f"cvar_adj={cvar_adj:.3f} dd_adj=1.0(removed) " f"f_final={f_final:.5f} notional={notional:.2f} " f"max_pos={max_pos:.4f}[HARD CAP]" ) logger.info(f"[PortfolioRiskManager] [{asset_id}] qty={quantity:.6f} | {reason}") return quantity, reason def register_open(self, trade_id: str, notional: float) -> None: """Record committed capital when a trade opens.""" self._committed[trade_id] = notional def register_close(self, trade_id: str, pnl: float) -> None: """Release committed capital and update equity when a trade closes.""" self._committed.pop(trade_id, None) self._current_equity += pnl self._peak_equity = max(self._peak_equity, self._current_equity) self._equity_history.append(self._current_equity) def update_equity(self, new_equity: float) -> None: """Sync equity from an external account balance feed if available.""" self._current_equity = new_equity self._peak_equity = max(self._peak_equity, new_equity) self._equity_history.append(new_equity) def get_stats(self) -> dict: dd = self._current_drawdown() now = time.time() return { "current_equity": self._current_equity, "peak_equity": self._peak_equity, "current_drawdown": dd, "committed_notional": sum(self._committed.values()), "open_trades": len(self._committed), "circuit_breaker": self._halt_until is not None and now < self._halt_until, "resumes_in_secs": max(0.0, (self._halt_until or 0) - now), "sizing_reduced": dd >= self.cfg.drawdown_reduce_pct, } # ── Internal helpers ─────────────────────────────────────────────────────────────── def _current_drawdown(self) -> float: if self._peak_equity <= 0: return 0.0 raw_dd = (self._peak_equity - self._current_equity) / self._peak_equity return min(1.0, max(0.0, raw_dd)) # FIX: clamp [0,1] def _get_max_pos(self, asset_id: str) -> float: return self.asset_registry.get(asset_id, {}).get("max_pos", 0.002) def _total_committed_fraction(self) -> float: return sum(self._committed.values()) / max(self.cfg.total_capital, 1.0) def state_dict(self) -> dict: """Return serialisable portfolio state for checkpointing.""" return { "peak_equity": self._peak_equity, "current_equity": self._current_equity, "committed": dict(self._committed), "halt_until": self._halt_until, "equity_history": list(self._equity_history), } def load_state_dict(self, state: dict) -> None: """Restore portfolio state — preserves drawdown circuit-breaker status.""" self._peak_equity = state.get("peak_equity", self._peak_equity) self._current_equity = state.get("current_equity", self._current_equity) self._committed = state.get("committed", {}) self._halt_until = state.get("halt_until", None) hist = state.get("equity_history", []) if hist: self._equity_history.extend(hist) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 17 — DERIV WEBSOCKET CLIENT (v3 with logger hooks, sync on_message) # ══════════════════════════════════════════════════════════════════════════════════════ class DerivWebSocketClient: """ Manages the Deriv WebSocket connection. on_message must be a SYNCHRONOUS callable — called from the async listen loop via direct invocation (no await). This fixes v4's bug where on_message was async. """ def __init__( self, api_key: str, ws_url: str, on_message: Callable[[dict], None], # must be sync ranker_logger: Optional[object] = None, ): self.api_key = api_key self.ws_url = ws_url self.on_message = on_message self.ranker_logger = ranker_logger self.ws = None self.running = False self.connected = False self._lock = Lock() self._message_id = 0 # Track subscribed symbols so we can re-subscribe after reconnect self._subscribed_symbols: Set[str] = set() def _next_msg_id(self) -> int: self._message_id += 1 return self._message_id async def connect(self) -> bool: if websockets is None: logger.error("❌ websockets library not installed") return False for attempt in range(1, WS_MAX_RETRIES + 1): try: logger.info(f"📡 Connecting to Deriv (attempt {attempt}/{WS_MAX_RETRIES})…") self.ws = await asyncio.wait_for( websockets.connect( self.ws_url, ping_interval=20, # send a WS ping every 20s to keep alive ping_timeout=10, # treat as dead if pong not received in 10s close_timeout=10, ), timeout=15.0 ) self.connected = True self.running = True if self.ranker_logger: self.ranker_logger.connection_event("Deriv WebSocket", "connected") logger.info("✅ Connected to Deriv WebSocket") return True except Exception as e: logger.warning(f"⏱️ Connect error: {e}") await asyncio.sleep(WS_RECONNECT_DELAY) if self.ranker_logger: self.ranker_logger.connection_event( "Deriv WebSocket", "error", f"Failed after {WS_MAX_RETRIES} attempts" ) return False async def authenticate(self) -> bool: try: await self.ws.send( json.dumps({"authorize": self.api_key, "req_id": self._next_msg_id()}) ) data = json.loads(await asyncio.wait_for(self.ws.recv(), timeout=10.0)) if data.get("authorize"): logger.info("✅ Authenticated with Deriv") return True logger.error("❌ Authentication failed") return False except Exception as e: logger.error(f"❌ Auth error: {e}") return False async def subscribe_to_ticks(self, symbol: str) -> bool: try: await self.ws.send(json.dumps({ "ticks": symbol, "subscribe": 1, "req_id": self._next_msg_id() })) self._subscribed_symbols.add(symbol) # remember for post-reconnect replay logger.info(f"📊 Subscribed to {symbol}") return True except Exception as e: logger.error(f"❌ Subscription error for {symbol}: {e}") return False async def subscribe_to_poc(self, contract_id: str) -> bool: """ Subscribe to proposal_open_contract stream for a live contract. Delivers real-time status updates and terminal events (won/lost/sold/expired). Must be called after buy confirmation so we receive all lifecycle events. """ try: await self.ws.send(json.dumps({ "proposal_open_contract": 1, "contract_id": int(contract_id), "subscribe": 1, "req_id": self._next_msg_id(), })) logger.info(f"🔔 Subscribed to poc stream | contract_id={contract_id}") return True except Exception as e: logger.error(f"❌ poc subscription error | contract_id={contract_id}: {e}") return False async def forget_contract(self, subscription_id: str) -> None: """Unsubscribe from a poc stream after contract closes to avoid leaking subscriptions.""" try: await self.ws.send(json.dumps({ "forget": subscription_id, "req_id": self._next_msg_id(), })) except Exception: pass # best-effort cleanup async def send_message(self, msg: dict) -> bool: """ Send a message to Deriv with a hard 10s timeout. [HANG FIX — Layer 1] The previous implementation awaited self.ws.send() with no timeout. On a half-open TCP connection (silent proxy drop, NAT table flush, HF Spaces idle reap) the kernel send buffer fills and this await never returns, freezing every caller — including rank_and_gate() via _handle_rank_rotation → _close_position, and via _ensure_minimum_trades → process_axrvi_signal. No more rankings POST to the hub, its 60s TTL expires, dashboard shows 0.0000. On timeout we mark the connection dead and schedule reconnect() asynchronously so we don't block the caller. Callers see False and can react (same as any other send failure). """ try: msg["req_id"] = self._next_msg_id() await asyncio.wait_for( self.ws.send(json.dumps(msg)), timeout=10.0, ) return True except asyncio.TimeoutError: logger.error( "❌ Deriv ws.send() stalled >10s — connection is half-open. " "Scheduling reconnect." ) self.connected = False # Fire-and-forget reconnect so we don't block the hung caller. try: asyncio.get_running_loop().create_task(self.reconnect()) except RuntimeError: # No running loop (shouldn't happen here, but be safe). pass return False except Exception as e: logger.error(f"❌ Send error: {e}") return False async def listen(self) -> None: while self.running and self.connected: try: message = await asyncio.wait_for( self.ws.recv(), timeout=PRICE_UPDATE_TIMEOUT ) self.on_message(json.loads(message)) # sync callback except asyncio.TimeoutError: logger.warning("⏱️ No messages from Deriv (timeout) — pinging to verify connection") # Actively check if the connection is still alive before continuing try: pong_waiter = await self.ws.ping() await asyncio.wait_for(pong_waiter, timeout=5.0) logger.debug("✅ Deriv ping/pong OK — connection alive") except Exception as ping_err: logger.warning(f"⚠️ Deriv ping failed ({ping_err}) — triggering reconnect") self.connected = False await self.reconnect() except Exception as e: logger.error(f"❌ Listen error: {e}") self.connected = False await self.reconnect() async def reconnect(self) -> None: self.connected = False if self.ws: try: await self.ws.close() except Exception: pass for attempt in range(1, WS_MAX_RETRIES + 1): await asyncio.sleep(WS_RECONNECT_DELAY) if not self.running: logger.info("[Deriv] Reconnect aborted — bridge stopped") return logger.info(f"[Deriv] Reconnect attempt {attempt}/{WS_MAX_RETRIES}…") if await self.connect() and await self.authenticate(): # ── Critical: re-subscribe to all previously subscribed symbols ── # Without this, ticks stop arriving after every reconnect even # though the WebSocket connection itself is alive. symbols_to_restore = list(self._subscribed_symbols) for symbol in symbols_to_restore: try: # [HANG FIX — Layer 1] Same 10s cap as send_message so a # stalled re-subscribe can't hang the reconnect task. await asyncio.wait_for( self.ws.send(json.dumps({ "ticks": symbol, "subscribe": 1, "req_id": self._next_msg_id() })), timeout=10.0, ) logger.info(f"🔄 Re-subscribed to {symbol} after reconnect") except asyncio.TimeoutError: logger.warning( f"⚠️ Re-subscription to {symbol} timed out — will retry on next reconnect" ) except Exception as re_err: logger.warning(f"⚠️ Re-subscription failed for {symbol}: {re_err}") logger.info( f"✅ Reconnected to Deriv | restored {len(symbols_to_restore)} subscriptions" ) return logger.error("❌ Failed to reconnect to Deriv after max retries") async def close(self) -> None: self.running = self.connected = False if self.ws: try: await self.ws.close() except Exception: pass # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 18 — QUASAR AXRVI BRIDGE v5 (main orchestrator) # ══════════════════════════════════════════════════════════════════════════════════════ class QuasarAXRVIBridge: """ Context-Aware Signal Relevance Engine — fully integrated v5 orchestrator. Philosophy ────────── This system does NOT predict market direction. It answers one question per cycle: "Which asset's hub signal is most RELEVANT and TRUSTWORTHY right now?" Pipeline (executed every config.update_frequency_seconds) ────────────────────────────────────────────────────────── Hub WebSocket ──► AssetSnapshot (dominant_signal, signal_confidence, avn_accuracy, buy/sell counts) │ ▼ UnifiedFeatureEngine (26-dim: 19 base + 7 stochastic) │ ▼ AXRVINet + MC Dropout ──► significance_weight + epistemic_std │ ▼ RankingEngine (Final_Priority = confidence × significance) ConservativeRanker (lower-bound under uncertainty) │ ▼ BanditSelector (UCB/Thompson/Greedy) → top N candidates │ ▼ Execution Gates: Gate A: dominant_signal ∉ {NEUTRAL, HOLD} Gate B: hub signal_confidence > 0 Gate C: significance_weight ≥ SCORE_THRESHOLD Gate D: DynamicExecutionGate (vol + uncertainty + jump) │ ▼ Deriv API ──► BUY / SELL (hub's signal verbatim) │ ▼ GirsanovReplayBuffer → HybridTrainer (online learning) Hard constraints ──────────────── ✗ Never generates its own buy/sell signals ✗ Never modifies hub signals ✗ Never writes back to the hub or any asset space ✓ Subscriber-only WebSocket mode ✓ All trade execution uses hub's dominant_signal verbatim """ def __init__( self, config: Optional[AssetRankerConfig] = None, trade_config: Optional[TradeConfig] = None, reward_strategy: str = "simple", hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"), enable_logging: bool = True, checkpoint_dir: str = "./Ranker10", # new folder for 10-asset build resume: bool = True, # [RESUME FIX] default ON — see __init__ hf_repo_id: Optional[str] = "KarlQuant/quasar-axrvi-v10", # new HF repo (10 assets) ): self.config = config or AssetRankerConfig() self.trade_config = trade_config or TradeConfig() self.reward_strategy = reward_strategy self.enable_logging = enable_logging and LOGGING_AVAILABLE # ── [RESUME FIX] Environment variable override ──────────────────────── # HF Spaces entrypoints usually can't pass CLI flags — they just run # `python Quasar_axrvi_ranker.py`. To control resume behaviour there, # set the QUASAR_RESUME environment variable in the Space's secrets: # QUASAR_RESUME=0 / false / no → start fresh (overrides constructor) # QUASAR_RESUME=1 / true / yes → resume from latest checkpoint # (unset) → use the constructor argument _env_resume = os.environ.get("QUASAR_RESUME", "").strip().lower() if _env_resume in ("0", "false", "no", "off"): logger.warning( "[RESUME] QUASAR_RESUME env var forces fresh start " "(overriding resume=True constructor argument)" ) resume = False elif _env_resume in ("1", "true", "yes", "on"): resume = True # ── Checkpoint manager (local + optional HF sync) ───────────────────── self.checkpoint_mgr = RankerCheckpointManager( checkpoint_dir=checkpoint_dir, hf_repo_id=hf_repo_id, ) self.resume = resume # ── [RESUME FIX] Startup banner ─────────────────────────────────────── # Prints resume state + HF sync status in a single eyeballable block so # you can tell at a glance whether checkpoints will actually be used # and mirrored. The most common failure mode on Spaces is a missing # HF_TOKEN secret — that goes silent without this banner. _hf_enabled = self.checkpoint_mgr._hf.enabled _hf_token = "✅ set" if os.environ.get("HF_TOKEN") else "❌ missing" _hf_repo = os.environ.get("HF_REPO_ID") or hf_repo_id or "—" logger.info( "\n" + "═" * 66 + "\n" f" 📦 CHECKPOINT CONFIG\n" f" resume : {self.resume} " f"({'will attempt to restore on start' if self.resume else 'FRESH START — no restore'})\n" f" checkpoint_dir: {checkpoint_dir}\n" f" hf_repo : {_hf_repo}\n" f" hf_token : {_hf_token}\n" f" hf_sync : {'✅ enabled' if _hf_enabled else '❌ disabled (set HF_TOKEN + HF_REPO_ID)'}\n" + "═" * 66 ) # ── Structured logger (optional) ────────────────────────────────────── self.ranker_logger: Optional[object] = None self.log_bridge: Optional[object] = None if self.enable_logging: try: # FIX: use RANKER_LOG_DIR env var so the dashboard service can # always locate the log files via the same variable. _rl_log_dir = os.environ.get("RANKER_LOG_DIR", "./ranker_logs") self.ranker_logger = RankerLogger(log_dir=_rl_log_dir) self.log_bridge = RankerLogBridge(self.ranker_logger) except Exception as e: logger.warning(f"Structured logging unavailable: {e} — continuing without") self.enable_logging = False # ── Deriv connectivity ──────────────────────────────────────────────── self.ws_client: Optional[DerivWebSocketClient] = None self.price_streamers: Dict[str, PriceStreamer] = {} # ── AXRVI components ────────────────────────────────────────────────── self.axrvi_net: Optional[AXRVINet] = None self.bandit: Optional[BanditSelector] = None self.trainer: Optional[HybridTrainer] = None # ── Ranking & gating ────────────────────────────────────────────────── # ShreveRankingEngine wired with shreve_config; asset_buffers injected # after initialize() populates them via the shared dict reference — # because Python dicts are passed by reference, any keys added to # self.asset_buffers later are automatically visible to ranking_engine. self.asset_buffers: Dict[str, AssetStateBuffer] = {} self.ranking_engine = ShreveRankingEngine( shreve_config = self.config.shreve_config, asset_buffers = self.asset_buffers, # shared reference — auto-updated ) self.conservative_ranker = ConservativeRanker( confidence_level = self.config.uncertainty_config.confidence_level, uncertainty_veto_threshold = self.config.uncertainty_config.uncertainty_veto_threshold, ) self.execution_gate = DynamicExecutionGate( martingale_epsilon = self.config.shreve_config.martingale_gate_epsilon, # [S7] ) # ── Replay & reward ─────────────────────────────────────────────────── self.replay = GirsanovReplayBuffer( capacity = self.config.buffer_size, config = self.config.replay_config, ) self.reward_calc = UnifiedRewardCalculator( strategy = reward_strategy, use_ito_correction = True, ) # ── Trade management ────────────────────────────────────────────────── self.position_mgr = PositionManager(ranker_logger=self.ranker_logger) # ── Portfolio risk manager (sizing + circuit breaker) ───────────────── self.portfolio_risk_mgr = PortfolioRiskManager( asset_registry = self.config.asset_registry, config = self.config.portfolio_risk_config, ) # ── Hub subscriber ──────────────────────────────────────────────────── self.hub_subscriber = HubSubscriber( hub_url = hub_ws_url, on_update = self._on_hub_update, ranker_logger = self.ranker_logger, ) # ── Signal subscriber (high-priority side channel, v2.3+) ───────────── # Runs on its own thread, consumes /ws/signals at ~30 ms cadence, and # writes per-tick realtime signals directly into the shared # AssetSnapshot store via hub_subscriber.inject_signal(). Gives # rank_and_gate sub-second visibility into BUY/SELL actions instead of # waiting for the next metrics_update broadcast on /ws/subscribe # (which only carries cumulative aggregates, not per-tick data). self.signal_subscriber = SignalSubscriber( signal_url = _derive_signal_url(hub_ws_url), hub_subscriber = self.hub_subscriber, ranker_logger = self.ranker_logger, ) # ── Hub HTTP URL — for pushing AXRVI rankings to /api/flip/rankings ─ # Derived from the WS URL: ws://host:port/... → http://host:port self._hub_http_url: str = ( hub_ws_url .replace("wss://", "https://") .replace("ws://", "http://") .split("/ws/")[0] # strip the /ws/subscribe path ) # ── Internal state ──────────────────────────────────────────────────── self.running = False self._lock = Lock() self._price_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=200)) self.rank_count = 0 self._last_sequences: Optional[np.ndarray] = None self._last_selected_idx: Optional[int] = None self._selected_assets: List[str] = [] self._last_final_scores: Optional[np.ndarray] = None self._last_value_estimates: Dict[str, float] = {} # ── [HANG FIX — Layer 2] Re-entrancy guard for rank_and_gate ────────── # Both _rank_loop and monitor_positions' refill trigger call # rank_and_gate(). Without this lock they can enter concurrently and # corrupt shared state (rank_count, _last_value_estimates, the ranked # list, and the position manager). The lock is created lazily in # rank_and_gate() itself because asyncio.Lock() must be created inside # a running event loop on Python <3.10. self._rank_lock: Optional[asyncio.Lock] = None # ── [HANG FIX — Layer 3] Watchdog timestamp ─────────────────────────── # Updated at the END of every successful rank_and_gate() cycle. # _rank_watchdog() checks this and force-closes the ws (triggering # reconnect) if no cycle has completed within RANK_STALL_THRESHOLD_S. # This is the safety net that recovers from ANY cause of stall — # not just the ws.send() one we know about. self._last_rank_complete_ts: float = time.time() # [S2] Pending-episode store: keyed by trade_id. # s_t is captured at trade-open time; s_{t+1} is captured at close. # This gives a proper (s_t, a_t, r_t, s_{t+1}) tuple with s_t ∈ F_t, # s_{t+1} ∈ F_{t+1} as required by Shreve adaptedness (§4). self._pending_episodes: Dict[str, dict] = {} # [S8] Per-trade tick counters for minimum holding check (optimal stopping) self._trade_tick_counts: Dict[str, int] = {} # ── Deriv req_id → trade_id mapping (for contract_id binding) ──────── # When we send a buy, we record req_id→trade_id here so that when the # buy confirmation arrives in _on_deriv_message we can bind the # contract_id back to the correct Trade object. self._pending_req_to_trade: Dict[int, str] = {} # ── Deriv contract_id → trade_id mapping (for poc routing) ──────────── # When a buy confirmation arrives, we map contract_id→trade_id so that # subsequent proposal_open_contract updates can be routed to the correct # Trade object for live updates and terminal state handling. self._contract_to_trade: Dict[str, str] = {} # ── Last-known hub snapshots cache ──────────────────────────────────── # Used by _ensure_minimum_trades() to force-fill below the 3-trade floor # when rank_and_gate() has not yet produced a fresh ranked list. self._last_hub_snapshots: Dict[str, AssetSnapshot] = {} # ── Top-3 rotation tracker ──────────────────────────────────────────── # Tracks the set of assets currently in the top-3 ranked positions. # On every rank cycle we compare the new top-3 against the previous one. # If a 4th-ranked asset displaces the current 3rd (i.e. any asset leaves # the top-3), its open position is closed immediately so the new top-3 # asset can take the slot in the next execution pass. self._prev_top3: Set[str] = set() self._sync_thread: Optional[threading.Thread] = None self._sync_loop: Optional[asyncio.AbstractEventLoop] = None self.stats = { "connections_made": 0, "price_ticks_received": 0, "rank_cycles": 0, "trades_executed": 0, "trades_closed": 0, "total_pnl": 0.0, "selected_assets": Counter(), } logger.info("✅ QuasarAXRVIBridge v5 initialized") if self.ranker_logger: self.ranker_logger.connection_event("Ranker", "initialized", "v5.0-integrated") def _on_hub_update(self, space_name: str, snap: AssetSnapshot) -> None: buf = self.asset_buffers.get(space_name) if buf is not None: buf.on_hub_snapshot(snap) logger.debug( f"[{space_name}] snapshot | " f"confidence={snap.signal_confidence:.3f} | " f"signal={snap.dominant_signal}" ) else: # FIX 3: was a silent drop — now logged so mismatched space names # are visible instead of causing mysteriously missing signals. logger.warning( f"[HubSubscriber] ⚠️ No asset buffer for space_name='{space_name}' " f"— hub update dropped. " f"Known buffers: {sorted(self.asset_buffers.keys())}" ) # ── Initialization ──────────────────────────────────────────────────────────────── async def initialize(self) -> bool: try: # Deriv connection with infinite back-off retry self.ws_client = DerivWebSocketClient( api_key = DERIV_API_KEY, ws_url = DERIV_WS_URL, on_message = self._on_deriv_message, ranker_logger = self.ranker_logger, ) _backoff = 5 while True: if await self.ws_client.connect() and await self.ws_client.authenticate(): break logger.warning(f"⚠️ Deriv connect/auth failed — retrying in {_backoff}s") if self.ranker_logger: self.ranker_logger.connection_event("Deriv", "warn", "Retrying…") await asyncio.sleep(_backoff) _backoff = min(_backoff * 2, 60) # Asset infrastructure for asset_id in self.config.asset_symbols: cfg = self.config.asset_registry.get(asset_id, {}) self.price_streamers[asset_id] = PriceStreamer(symbol=asset_id) self.asset_buffers[asset_id] = AssetStateBuffer( asset_id = asset_id, cfg = cfg, stoch_config = self.config.stochastic_config, shreve_config = self.config.shreve_config, ) deriv_symbol = SYMBOL_MAP_REVERSE.get(asset_id) if deriv_symbol: await self.ws_client.subscribe_to_ticks(deriv_symbol) # Neural network — v8 architecture # Use the helper that maps AssetRankerConfig → AXRVIConfig so the # dimension constants stay in a single place (Bug: dead-code fix). _axrvi_cfg = _axrvi_config_from_ranker_config(self.config) # Override the v8-specific hyperparams that the helper leaves at defaults _axrvi_cfg.n_moe_experts = 4 _axrvi_cfg.moe_top_k = 2 _axrvi_cfg.moe_balance_coeff = 0.01 _axrvi_cfg.ode_steps = 6 _axrvi_cfg.hyperbolic_curvature= 1.0 _axrvi_cfg.stdp_lr = 0.005 _axrvi_cfg.stdp_decay = 0.995 _axrvi_cfg.stdp_hebbian_decay = 0.01 _axrvi_cfg.lambda_crps = 0.10 _axrvi_cfg.lambda_moe = 1.00 _axrvi_cfg.lambda_entropy = 0.01 _axrvi_cfg.lambda_gate_reg = 0.005 _axrvi_cfg.gradient_clip_norm = 1.0 _axrvi_cfg.use_amp = False _axrvi_cfg.checkpoint_grads = False _axrvi_cfg.torchscript_compatible = False self.axrvi_net = create_axrvi_v8( num_assets = len(self.config.asset_symbols), config = _axrvi_cfg, device = self.config.device, ) total_params = sum(p.numel() for p in self.axrvi_net.parameters()) trainable_params = sum(p.numel() for p in self.axrvi_net.parameters() if p.requires_grad) logger.info( f"🧠 AXRVINet v8 initialized | assets={len(self.config.asset_symbols)} | " f"feature_dim={_axrvi_cfg.feature_dim} | seq_len={_axrvi_cfg.seq_len} | " f"d_model={_axrvi_cfg.d_model} | heads={_axrvi_cfg.num_heads} | " f"enc_layers={_axrvi_cfg.num_encoder_layers} | cross_layers={_axrvi_cfg.num_cross_layers} | " f"quantiles={_axrvi_cfg.num_quantiles} | regimes={_axrvi_cfg.num_regimes} | " f"params={total_params:,} ({trainable_params:,} trainable) | device={self.config.device}" ) self.bandit = BanditSelector( asset_ids = self.config.asset_symbols, strategy = BanditSelector.Strategy(self.config.bandit_strategy), ucb_c = self.config.ucb_c, ) self.trainer = HybridTrainer( model = self.axrvi_net, lr = self.config.learning_rate, gamma = self.config.gamma, lambda_rank = self.config.lambda_rank, lambda_risk = self.config.lambda_risk, lambda_ce = self.config.shreve_config.value_consistency_loss_weight, ranker_logger = self.ranker_logger, ) self.trainer.load(self.config.model_path) # Wire asset_buffers into ranking engine after they are created [S5/S6] self.ranking_engine.asset_buffers = self.asset_buffers self.hub_subscriber.start() self.signal_subscriber.start() # ── Full-state checkpoint resume ────────────────────────────────── # Must run AFTER all components are initialised so _restore_checkpoint # can write directly into live objects. if self.resume: loaded = self.checkpoint_mgr.load(self) if not loaded: logger.info( "[RankerCheckpoint] No checkpoint found — " "starting fresh (this is normal on first run)" ) self.checkpoint_mgr.print_summary() self.stats["connections_made"] += 1 logger.info("✅ QuasarAXRVIBridge initialized successfully") if self.ranker_logger: self.ranker_logger.connection_event("Ranker", "ready", "All systems up") return True except Exception as e: logger.error(f"❌ Initialization error: {e}") if self.ranker_logger: self.ranker_logger.connection_event("Ranker", "error", str(e)) traceback.print_exc() return False # ── Deriv message handling ───────────────────────────────────────────────────────── def _on_deriv_message(self, msg: dict) -> None: """ Synchronous handler for Deriv WebSocket messages. Handles: tick — price update buy — contract open confirmation (PENDING → OPEN) proposal_open_contract — live contract update; terminal event → CLOSED balance — account balance sync to portfolio risk manager error — broker rejection log """ try: if "tick" in msg: self._on_price_tick(msg["tick"]) elif "buy" in msg: self._on_buy_confirmation(msg) elif "proposal_open_contract" in msg: self._on_poc_update(msg["proposal_open_contract"], msg) elif "sell" in msg: # FIX 1: Handle sell ACK from broker. # Previously this message was silently dropped, leaving trades # stuck in CLOSING with no observability. The authoritative # close still comes via the proposal_open_contract terminal # event — this block is ACK-logging only. sell_data = msg["sell"] contract_id = str(sell_data.get("contract_id", "")) sold_for = float(sell_data.get("sold_for", 0.0)) logger.info( f"[Deriv] ✅ SELL ACK | contract_id={contract_id} | " f"sold_for={sold_for:.4f} | waiting for POC terminal event" ) elif "balance" in msg: bal = msg["balance"] new_equity = float(bal.get("balance", 0.0)) self.portfolio_risk_mgr.update_equity(new_equity) logger.debug( f"[Deriv] Balance update | equity={new_equity:.2f}" ) elif "error" in msg: err = msg["error"] code = err.get("code", "UNKNOWN") message = err.get("message", str(err)) req_id = msg.get("req_id") # Clean up any pending trade associated with this req_id trade_id = self._pending_req_to_trade.pop(req_id, None) if trade_id: # Remove the pending trade stub — broker rejected the buy with self.position_mgr._lock: rejected = self.position_mgr._open_trades.pop(trade_id, None) if rejected: # ── Clean up all trade-level state so nothing leaks ── # _pending_episodes: episode was opened by _open_pending_episode # but there will never be a close event — discard it. self._pending_episodes.pop(trade_id, None) # _trade_tick_counts: monitor_positions tracks this; clear it. self._trade_tick_counts.pop(trade_id, None) # portfolio_risk_mgr: committed capital was registered in # process_axrvi_signal; release it with 0 PnL. self.portfolio_risk_mgr.register_close(trade_id, 0.0) logger.error( f"❌ [{rejected.asset}] BROKER REJECTED buy | " f"trade_id={trade_id} | code={code} | {message}" ) # Legacy-compatible execution failure log logger.error( f"[{rejected.asset}] EXECUTION FAILED | " f"reason=broker_rejected | code={code} | {message}" ) else: logger.error( f"❌ Deriv error (req_id={req_id}): [{code}] {message}" ) else: logger.error( f"❌ Deriv error (req_id={req_id}): [{code}] {message}" ) if self.ranker_logger: self.ranker_logger.connection_event("Deriv", "error", str(msg["error"])) except Exception as e: logger.error(f"❌ Deriv message handler error: {e}") traceback.print_exc() def _on_buy_confirmation(self, msg: dict) -> None: """ Handle broker buy confirmation: PENDING → OPEN. Bind contract_id, entry_tick, buy_price to the Trade object. Then subscribe to the proposal_open_contract stream for live updates. """ buy_data = msg["buy"] req_id = msg.get("req_id") trade_id = self._pending_req_to_trade.pop(req_id, None) contract_id = str(buy_data.get("contract_id", "")) if not contract_id: logger.error( f"❌ Buy confirmation missing contract_id | req_id={req_id}" ) return # Map contract_id → trade_id for poc routing if trade_id: self._contract_to_trade[contract_id] = trade_id buy_price = float(buy_data.get("buy_price", 0.0)) start_tick = float(buy_data.get("start_time", time.time())) # Deriv may provide longcode/shortcode/transaction_id directly in buy shortcode = buy_data.get("shortcode") tx_id = str(buy_data.get("transaction_id", "")) # Use spot price at confirmation as entry_tick (broker-authoritative) entry_tick = float( buy_data.get("spot", 0.0) or buy_data.get("entry_spot", 0.0) or 0.0 ) # If Deriv didn't include spot in buy, use current streamer price if entry_tick == 0.0 and trade_id: pending = self.position_mgr._open_trades.get(trade_id) if pending: s = self.price_streamers.get(pending.asset) entry_tick = s.latest_mid if s else 0.0 broker_symbol = buy_data.get("symbol", "") if trade_id: trade = self.position_mgr.confirm_buy( trade_id = trade_id, contract_id = contract_id, buy_price = buy_price, entry_tick = entry_tick, transaction_id = tx_id, shortcode = shortcode, broker_symbol = broker_symbol, ) if trade: # Legacy-compatible log: "TRADE OPENED" logger.info( f"[{trade.asset}] TRADE OPENED | trade_id={trade_id} | " f"contract_id={contract_id} | buy_price={buy_price:.4f}" ) # Subscribe to poc stream for live status + terminal event if self.ws_client and self.ws_client.connected: asyncio.get_running_loop().create_task( self.ws_client.subscribe_to_poc(contract_id) ) else: logger.warning( f"[Deriv] Buy confirmation — no trade_id for req_id={req_id} | " f"contract_id={contract_id} (late or orphaned confirmation) — " f"sending immediate SELL to avoid dangling open contract on broker side" ) # The contract is live on Deriv but we have no internal tracking for # it. Immediately sell at market so we don't accumulate un-tracked # open contracts that drain the account without appearing in our books. if self.ws_client and self.ws_client.connected: asyncio.get_running_loop().create_task( self.ws_client.send_message({"sell": contract_id, "price": 0}) ) def _on_poc_update(self, poc: dict, raw_msg: dict) -> None: """ Handle a proposal_open_contract update. IMPORTANT: For multiplier contracts: - 'profit' field = actual P&L in account currency (e.g., 0.03 = $0.03 profit) - 'bid_price' = current contract value (stake × multiplier × % move) - 'current_spot' = underlying asset price """ contract_id = str(poc.get("contract_id", "")) trade_id = self._contract_to_trade.get(contract_id) if not trade_id: return with self.position_mgr._lock: trade = self.position_mgr._open_trades.get(trade_id) if trade is None: return # ── Update current_spot on live ticks ──────────────────────────────── current_spot = poc.get("current_spot") if current_spot: trade.current_spot = float(current_spot) streamer = self.price_streamers.get(trade.asset) if streamer: trade.unrealized_pnl = trade.compute_unrealized_pnl(float(current_spot)) # ── Extract profit from broker (ONLY from 'profit' field) ──────────── # CRITICAL: Do NOT use bid_price as profit fallback! raw_profit = poc.get("profit") # Log live profit for monitoring (this is the REAL P&L) if raw_profit is not None: logger.debug(f"[{trade.asset}] Live profit: {float(raw_profit):+.4f}") # ✅ FIX 1: Cache live broker profit so rotation closes use the real P&L. # trade.profit stays None only if Deriv never sent a profit field at all. trade.profit = float(raw_profit) # ✅ FIX-MAXLOSS: dedicated field for OptStop max-loss check so the # monitor can read broker-authoritative live P&L without confusion # with the terminal-event authoritative profit set on close. trade._last_poc_profit = float(raw_profit) # ── Terminal state check ────────────────────────────────────────────── is_terminal = ( poc.get("is_expired", False) or poc.get("is_sold", False) or poc.get("status") in ("won", "lost", "sold", "expired") ) if not is_terminal: return # ── Extract broker-authoritative close data ─────────────────────────── status = poc.get("status", "expired") # ✅ ONLY use 'profit' field for P&L - NEVER fall back to bid_price if raw_profit is not None: profit = float(raw_profit) else: # If profit field is missing (shouldn't happen on terminal events), # log error and use 0 as fallback - but don't use bid_price! logger.error(f"[{trade.asset}] Terminal event missing profit field! contract_id={contract_id}") profit = 0.0 # bid_price is the contract value, NOT profit - store separately if needed bid_price = float(poc.get("bid_price", 0.0)) sell_price = float(poc.get("sell_price") or 0.0) exit_tick = float(poc.get("exit_tick") or poc.get("current_spot") or 0.0) logger.info( f"[{trade.asset}] CONTRACT TERMINAL | contract_id={contract_id} | " f"status={status} | profit={profit:+.4f} | bid_price={bid_price:.4f} | " f"sell_price={sell_price:.4f}" ) closed_trade = self.position_mgr.close_trade_from_broker( trade_id=trade_id, status=status, profit=profit, # ✅ Now using actual profit, not bid_price sell_price=sell_price if sell_price > 0 else None, exit_tick=exit_tick if exit_tick > 0 else None, ) if closed_trade is None: # Trade was already closed by the CLOSING-timeout handler in # monitor_positions() — nothing left to do. return # FIX 5: Remove the contract mapping BEFORE the REFILL trigger fires. # This ensures that if the CLOSING-timeout handler in monitor_positions() # races with this handler, it will find no contract_id and skip the # duplicate close_trade_from_broker() call. self._contract_to_trade.pop(contract_id, None) # Reward is now correctly based on actual profit reward = self._reward_from_broker(closed_trade) self.portfolio_risk_mgr.register_close(trade_id, closed_trade.realized_pnl) self._close_pending_episode(trade_id, reward) self._trade_tick_counts.pop(trade_id, None) self.stats["trades_closed"] += 1 self.stats["total_pnl"] += closed_trade.realized_pnl logger.info( f"💰 [{closed_trade.asset}] TRADE CLOSED | " f"exit_price={sell_price:.5f} | " f"reward={reward:+.6f} | profit={profit:+.4f} | " f"portfolio_dd={self.portfolio_risk_mgr._current_drawdown():.2%}" ) # Subscription cleanup sub_id = poc.get("id") or poc.get("subscription", {}).get("id") if sub_id and self.ws_client and self.ws_client.connected: asyncio.get_running_loop().create_task( self.ws_client.forget_contract(sub_id) ) # REFILL TRIGGER open_count = len(self.position_mgr.get_open_trades()) if open_count < 4 and self.running: logger.warning( f"[poc_terminal] ⚠️ REFILL TRIGGER — " f"open_count={open_count} < 4 after contract terminal. " f"Scheduling immediate rank_and_gate() to restore top-4." ) asyncio.get_running_loop().create_task(self._safe_rank_and_gate()) def _reward_from_broker(self, trade: "Trade") -> float: """ Compute replay reward from broker-authoritative profit. For multiplier contracts, trade.profit is the actual P&L in USD. With $1 stake, profit of 0.03 = 3% return = reward of 0.03. """ stake = self.trade_config.amount # should be 1.0 if trade.profit is not None: # Profit is already absolute P&L in account currency # Normalize to stake for reward consistency reward = trade.profit / stake if stake > 0 else 0.0 # Sanity check: realistic rewards for a $1 stake multiplier contract if reward < -1.0 or reward > 10.0: logger.warning( f"[_reward_from_broker] Unusual reward {reward:.4f} for {trade.asset} | " f"profit={trade.profit:.4f} stake={stake}" ) return reward # Fallback (should never happen with correct poc handling) logger.warning(f"[_reward_from_broker] No profit for {trade.asset}, returning 0") return 0.0 async def _safe_rank_and_gate(self) -> None: """Wrapper that silently catches errors from refill-triggered rank_and_gate().""" try: await self.rank_and_gate() except Exception as e: logger.error(f"[refill] rank_and_gate error: {e}") def _on_price_tick(self, tick_data: dict) -> None: try: symbol = tick_data.get("symbol") bid = float(tick_data.get("bid", 0)) ask = float(tick_data.get("ask", 0)) axrvi_id = SYMBOL_MAP.get(symbol) if not axrvi_id or axrvi_id not in self.price_streamers: return ts = time.time() self.price_streamers[axrvi_id].on_tick(bid, ask, ts) mid = (bid + ask) / 2.0 self._price_history[axrvi_id].append(mid) trade = self.position_mgr.get_open_trade_by_asset(axrvi_id) pnl = trade.compute_unrealized_pnl(mid) if trade else 0.0 self.asset_buffers[axrvi_id].on_price(mid, ask - bid, pnl) self.stats["price_ticks_received"] += 1 except Exception as e: logger.error(f"❌ Price tick error: {e}") # ── Signal processing (open trade) ──────────────────────────────────────────────── async def process_axrvi_signal( self, asset: str, action: str, value_estimate: float = 0.001, realized_vol: float = 0.01, significance: float = 0.5, cvar_05: float = 0.0, ) -> None: """ PERFORMANCE RANKER REFACTOR (v7): • Sizing veto removed: if compute_position_size() returns qty <= 0 we use a fixed minimum fallback (trade_config.amount / price) so the system ALWAYS opens a trade and produces a closed-episode reward signal. • Deriv API buy call is ALWAYS sent after the internal position is opened. The call uses SYMBOL_MAP_REVERSE to convert the internal asset id to the correct Deriv API symbol string. """ # Gate A still applies — only BUY/SELL are tradeable directions if action in ("HOLD", "NEUTRAL"): return streamer = self.price_streamers.get(asset) if not streamer or streamer.is_stale(): logger.warning(f"⚠️ [{asset}] No recent price — skipping signal") return # One open trade per asset — this guard is preserved (v7 constraint) if self.position_mgr.get_open_trade_by_asset(asset): logger.debug(f"[{asset}] Already has an open trade — skipping duplicate") return price = streamer.latest_mid trade_id = f"{asset}_{int(time.time() * 1000)}" # ── Institutional sizing (Kelly × conviction — vetoes removed in v7) ── quantity, sizing_reason = self.portfolio_risk_mgr.compute_position_size( asset_id = asset, current_price = price, value_estimate = value_estimate, realized_vol = max(realized_vol, 1e-4), significance = significance, cvar_05 = cvar_05, fallback_notional = self.trade_config.amount, # minimum data-collection stake ) # ── Fallback quantity guarantee (v7) ────────────────────────────────── # compute_position_size() now never returns 0, but be defensive: if for # any reason quantity is still 0, use the flat minimum so ranking data # is never lost. This block should never fire in normal operation. if quantity <= 0: fallback_qty = self.trade_config.amount / max(price, 1.0) logger.warning( f"[{asset}] Sizing returned qty=0 despite v7 refactor — " f"applying hard fallback qty={fallback_qty:.6f} | {sizing_reason}" ) quantity = fallback_qty # ── Open internal position (replay buffer + trainer integration) ────── direction = TradeDirection.LONG if action == "BUY" else TradeDirection.SHORT self.position_mgr.open_trade(trade_id, asset, direction, price, quantity) # ── Send actual buy order to Deriv API ──────────────────────────────── # CHANGE v7: this block is now UNCONDITIONAL — every internal open is # paired with a real Deriv contract. We use SYMBOL_MAP_REVERSE to # translate the internal asset id (e.g. "V75") to the Deriv symbol # string (e.g. "R_75") that the API expects. deriv_symbol = SYMBOL_MAP_REVERSE.get(asset) if deriv_symbol and self.ws_client and self.ws_client.connected: # MULTUP/MULTDOWN — the only contract type that works for all assets # (CRASH500/CRASH1000 reject CALL/PUT; MULTUP/MULTDOWN works everywhere) contract_type = "MULTUP" if action == "BUY" else "MULTDOWN" multiplier = ASSET_MULTIPLIER.get(asset, 50) # ── VALIDATION: Ensure multiplier is in broker's acceptable range ── # BUG FIX: Validates against ASSET_ACCEPTABLE_MULTIPLIERS to prevent # "Multiplier is not in acceptable range" rejections acceptable = ASSET_ACCEPTABLE_MULTIPLIERS.get(asset, [50]) if multiplier not in acceptable: # Fallback: pick smallest acceptable multiplier multiplier = min(acceptable) if acceptable else 50 logger.warning( f"[{asset}] ⚠️ Configured multiplier not in broker range. " f"Falling back to smallest acceptable: {multiplier}x" ) stake = self.trade_config.amount # always $1 stop_loss_amt = round(stake * ASSET_STOP_LOSS_FRAC.get(asset, 0.50), 2) take_profit_amt= round(stake * ASSET_TAKE_PROFIT_FRAC.get(asset, 0.80), 2) buy_msg = { "buy": "1", "price": stake, "parameters": { "amount": stake, "basis": "stake", "contract_type": contract_type, "currency": "USD", # NOTE: No duration/duration_unit — multipliers are open-ended # contracts with no fixed expiry. They close via the sell API # or when SL/TP is hit. Adding duration causes a broker error. "multiplier": multiplier, "symbol": deriv_symbol, "limit_order": { "stop_loss": stop_loss_amt, "take_profit": take_profit_amt, }, }, } sent = await self.ws_client.send_message(buy_msg) if sent: # send_message stamps req_id on buy_msg in-place — record the # mapping so _on_deriv_message can bind contract_id to this trade self._pending_req_to_trade[buy_msg["req_id"]] = trade_id logger.info( f"[{asset}] 📤 Deriv BUY sent | " f"contract={contract_type} | multiplier={multiplier}x | " f"stake=${stake} | SL=${stop_loss_amt} | TP=${take_profit_amt} | " f"deriv_symbol={deriv_symbol} | req_id={buy_msg['req_id']}" ) else: logger.error( f"[{asset}] ❌ Deriv BUY send FAILED — ws_client.send_message returned False" ) else: logger.error( f"[{asset}] ❌ Cannot send Deriv BUY — " f"deriv_symbol={deriv_symbol} (SYMBOL_MAP_REVERSE lookup) | " f"ws_connected={self.ws_client.connected if self.ws_client else 'no client'}" ) # Register committed notional with risk manager self.portfolio_risk_mgr.register_open(trade_id, quantity * price) # [S2] Capture s_t ∈ F_t at the moment of trade open # FIX 3: Guard against None — do NOT silently default to index 0, # which would attribute this episode to the wrong asset entirely. if self._last_selected_idx is None: logger.warning( f"[_open_pending_episode] [{asset}] No selected_idx available — " f"skipping episode capture for trade_id={trade_id}. " f"Trade will still execute but will NOT contribute to training." ) else: selected_idx = self._last_selected_idx self._open_pending_episode(trade_id, asset, selected_idx) # [S8] Initialise tick counter for optimal-stopping monitor self._trade_tick_counts[trade_id] = 0 buf = self.asset_buffers.get(asset) if buf and self.log_bridge: self.log_bridge.capture_signal(asset, buf, buf.axrvi_score) self.stats["trades_executed"] += 1 # ── Top-3 rotation — close displaced assets ─────────────────────────────────────── async def _handle_rank_rotation(self, current_top3: List[str]) -> None: """ Called after every rank cycle with the current top-3 asset IDs. Logic: • Compute which assets have LEFT the top-3 since the last cycle. • For each departed asset that still has an open position → close it now. • Update _prev_top3 to the current top-3 so next cycle can diff correctly. This implements the user rule: "Close when a 4th asset goes 3rd" i.e. whenever the ranked top-3 set changes, immediately exit the position on any asset that is no longer in the top-3. """ current_set = set(current_top3) departed = self._prev_top3 - current_set # assets that left top-3 if departed: logger.info( f"[Rotation] 🔄 Top-3 changed | " f"prev={sorted(self._prev_top3)} → now={sorted(current_set)} | " f"departing={sorted(departed)}" ) for asset in departed: trade = self.position_mgr.get_open_trade_by_asset(asset) if trade: # Skip if already closing — SELL was already sent to broker if trade.state == PositionState.CLOSING: logger.debug( f"[Rotation] ⏩ {asset} already CLOSING — skipping duplicate SELL" ) continue streamer = self.price_streamers.get(asset) price = streamer.latest_mid if streamer else trade.entry_price logger.info( f"[Rotation] 📤 Closing {asset} — no longer in top-3 | " f"price={price:.4f} | trade_id={trade.trade_id}" ) await self._close_position(trade.trade_id, price) self._prev_top3 = current_set # ── Position monitoring ──────────────────────────────────────────────────────────── async def _close_position(self, trade_id: str, exit_price: float) -> None: # ── Prevent re-sending SELL for a contract already in CLOSING state ─ with self.position_mgr._lock: open_trade = self.position_mgr._open_trades.get(trade_id) if open_trade is None: return # already closed if open_trade.state == PositionState.CLOSING: # Sell already sent; don't spam the broker — just wait for POC logger.debug( f"[{trade_id}] ⏳ Already CLOSING — skipping duplicate SELL | " f"contract_id={open_trade.contract_id}" ) return cid = open_trade.contract_id if open_trade else None if cid and self.ws_client and self.ws_client.connected: sell_sent = await self.ws_client.send_message({ "sell": cid, "price": 0, # 0 = market price / best available }) if sell_sent: logger.info(f"[{trade_id}] 📤 Deriv SELL sent | contract_id={cid}") else: logger.error(f"[{trade_id}] ❌ Deriv SELL send FAILED | contract_id={cid}") elif not cid: # contract_id was never bound — buy confirmation hasn't arrived yet # or trade was opened before this fix was deployed. Log and continue # so internal bookkeeping still closes cleanly. logger.warning( f"[{trade_id}] ⚠️ No contract_id bound — cannot send Deriv SELL. " f"Contract may expire naturally on Deriv side." ) # When the trade has a contract_id the Deriv broker will emit a # proposal_open_contract terminal event which calls # _handle_poc_terminal_event. That handler is the single authority # for register_close / _close_pending_episode / stats updates, so we # must NOT duplicate those calls here — that was the double-execution # bug. We mark the trade as CLOSING and let the poc handler finish it. if cid: self.position_mgr.mark_closing(trade_id) return # ── No contract_id path ──────────────────────────────────────────────── # Trade was never confirmed by the broker (e.g. buy response not yet # received, or pre-fix checkpoint). Do the full local close here # because no poc event will arrive. # # Fees: must be on stake, NOT on spot price. # Was: exit_price * commission_rate → e.g. 316738 * 0.001 = $316.74 (wrong!) # Now: stake * commission_rate → e.g. 1.0 * 0.001 = $0.001 (correct) fees = self.trade_config.amount * self.trade_config.commission_rate trade = self.position_mgr.close_trade(trade_id, exit_price, fees) if not trade: return price_hist = list(self._price_history[trade.asset]) reward = self.reward_calc.compute_reward(trade, price_hist or None) # Release committed capital and update portfolio equity self.portfolio_risk_mgr.register_close(trade_id, trade.realized_pnl) # [S2] Push proper (s_t, a_t, r_t, s_{t+1}) — s_{t+1} captured here self._close_pending_episode(trade_id, reward) self._trade_tick_counts.pop(trade_id, None) self.stats["trades_closed"] += 1 self.stats["total_pnl"] += trade.realized_pnl logger.info( f"💰 Closed {trade.asset} (no-cid fallback) | reward={reward:+.6f} | " f"pnl={trade.realized_pnl:+.4f} | " f"exit_price={exit_price:.5f} | " f"portfolio_dd={self.portfolio_risk_mgr._current_drawdown():.2%}" ) async def monitor_positions(self) -> None: """ [S8] Optimal-stopping exit rule: τ* = arg sup_{τ ∈ T} E[R_τ | F_t] At each tick, compare: G_t = sgn(direction) · log(S_t / S_entry) − fees (immediate liquidation value) C_t = model value estimate of future return (continuation value) Exit priority (evaluated in order, first match wins): 1. OptStop — G_t ≥ C_t + stopping_value_buffer (profit target) 2. MaxLoss — broker live P&L ≤ −(sl_frac × stake) (stop-loss) 3. HARD WALL — holding_duration > 120 s (absolute failsafe) 4. SoftExpiry — holding_duration > expiry_time (default 60 s) Minimum holding guard: do not evaluate stopping until _trade_tick_counts[trade_id] >= shreve_config.min_holding_ticks. REFILL TRIGGER (v7): After any position is closed, if open_trade_count drops below 3, immediately call rank_and_gate() to refill. CLOSING TIMEOUT FIX (v6.1): If a trade remains in CLOSING state for > 10 seconds without a terminal event from the broker, force-close it locally to prevent stuck trades. NOTE: requires mark_closing() to stamp exit_time (fixed in this version). """ sc = self.config.shreve_config # FIX 6: Extended from 10 → 30 seconds. # Deriv multiplier contracts can take 15-20 s to emit the terminal # proposal_open_contract event after the sell is acknowledged. # 10 s was too aggressive and caused false force-closes with # estimated (non-broker-authoritative) P&L. CLOSING_TIMEOUT_SECONDS = 30.0 # Maximum time to wait for broker terminal event while self.running: try: closed_any = False # track whether we closed a trade this tick for trade in self.position_mgr.get_open_trades(): # ── CLOSING TIMEOUT HANDLER ────────────────────────────────── # If trade has been CLOSING for > CLOSING_TIMEOUT_SECONDS, # force-close it locally (broker never responded) if trade.state == PositionState.CLOSING: closing_duration = time.time() - trade.exit_time if trade.exit_time else 0 if closing_duration > CLOSING_TIMEOUT_SECONDS: logger.warning( f"[{trade.asset}] ⚠️ CLOSING TIMEOUT | " f"trade_id={trade.trade_id} | " f"contract_id={trade.contract_id} | " f"stuck in CLOSING for {closing_duration:.1f}s — " f"forcing local close" ) # Force close locally (broker never responded) # Use current price as exit price streamer = self.price_streamers.get(trade.asset) price = streamer.latest_mid if streamer else trade.entry_price # Estimate profit from broker data if available if trade.profit is not None: profit = trade.profit else: # Fallback: estimate from price movement if price > 0 and trade.entry_price > 0: pct_move = (price - trade.entry_price) / trade.entry_price sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0 mult = ASSET_MULTIPLIER.get(trade.asset, 50) stake = trade.buy_price if (trade.buy_price and trade.buy_price > 0) else 1.0 profit = sign * pct_move * stake * mult else: profit = 0.0 closed_trade = self.position_mgr.close_trade_from_broker( trade_id=trade.trade_id, status="timeout", profit=profit, sell_price=price, exit_tick=price, ) if closed_trade: reward = self._reward_from_broker(closed_trade) self.portfolio_risk_mgr.register_close(trade.trade_id, closed_trade.realized_pnl) self._close_pending_episode(trade.trade_id, reward) self._trade_tick_counts.pop(trade.trade_id, None) self.stats["trades_closed"] += 1 self.stats["total_pnl"] += closed_trade.realized_pnl closed_any = True logger.info( f"💰 [{closed_trade.asset}] TRADE FORCE-CLOSED (timeout) | " f"reward={reward:+.6f} | profit={profit:+.4f}" ) continue # Skip other checks for CLOSING trades streamer = self.price_streamers.get(trade.asset) if not streamer: continue price = streamer.latest_mid trade.unrealized_pnl = trade.compute_unrealized_pnl(price) # Tick counter for minimum holding period tid = trade.trade_id self._trade_tick_counts[tid] = self._trade_tick_counts.get(tid, 0) + 1 # ── Optimal stopping evaluation [S8] ────────────────────── should_stop = False stop_reason = "" # Wall-clock minimum holding guard — replaces unreliable tick counter. # Ticks fire at ~2 s intervals but can bunch or stall; 20 s is absolute. holding_secs = time.time() - trade.entry_time if holding_secs >= sc.min_holding_seconds: # G_t — immediate liquidation log-return if price > 0 and trade.entry_price > 0: raw_log_ret = math.log(price / trade.entry_price) sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0 # Fees for G_t must be on stake, NOT on spot price. fees = self.trade_config.amount * self.trade_config.commission_rate slippage = self.trade_config.slippage_bps / 10_000.0 g_t = sign * raw_log_ret - fees / price - slippage # C_t — continuation value from model's value head buf = self.asset_buffers.get(trade.asset) c_t = self._last_value_estimates.get(trade.asset, 0.0) if buf is not None and c_t == 0.0: # Fallback: use directional significance as proxy c_t = buf.axrvi_score * 0.01 if g_t >= c_t + sc.stopping_value_buffer: should_stop = True stop_reason = ( f"OptStop: G_t={g_t:+.5f} ≥ C_t={c_t:+.5f} " f"+ buffer={sc.stopping_value_buffer:.5f}" ) # ── [2] Max-loss stop: broker live P&L hit stop-loss ───────── # Fires when price moved against us — G_t is negative so # OptStop never triggers, but we still need to cut losses. # Uses _last_poc_profit: broker-authoritative live P&L cached # from the poc stream on every tick (set in _on_poc_update). if not should_stop: sl_frac = ASSET_STOP_LOSS_FRAC.get(trade.asset, 0.50) stake = (trade.buy_price if (trade.buy_price and trade.buy_price > 0) else self.trade_config.amount) sl_threshold = -(sl_frac * stake) live_loss = getattr(trade, "_last_poc_profit", None) if live_loss is None: live_loss = trade.unrealized_pnl # approx fallback if live_loss is not None and live_loss <= sl_threshold: should_stop = True stop_reason = ( f"MaxLoss: live_pnl={live_loss:+.4f} ≤ " f"sl_threshold={sl_threshold:+.4f} " f"(SL={sl_frac:.0%} × stake={stake:.2f})" ) # ── [3] Hard absolute wall: 120 s no matter what ────────────── # Fires when OptStop and MaxLoss both failed (e.g. G_t oscillating # near zero, poc profit stream delayed). 120 s = 2× the soft expiry. MAX_HOLDING_SECONDS = 120 if not should_stop and holding_secs >= sc.min_holding_seconds and holding_secs > MAX_HOLDING_SECONDS: should_stop = True stop_reason = ( f"HARD WALL: held {holding_secs:.0f}s " f"> {MAX_HOLDING_SECONDS}s absolute limit" ) # ── [4] Soft expiry: configured expiry_time (default 60 s) ──── elif not should_stop and holding_secs >= sc.min_holding_seconds and holding_secs > self.trade_config.expiry_time: should_stop = True stop_reason = ( f"SoftExpiry: held {holding_secs:.0f}s " f"> expiry_time={self.trade_config.expiry_time}s" ) if should_stop: logger.info(f"[{trade.asset}] EXIT | {stop_reason}") await self._close_position(tid, price) closed_any = True # ── REFILL TRIGGER (v7) ─────────────────────────────────────── # After closing any trade this cycle, check whether open_count # has dropped below the 2-trade floor. If so, trigger a fresh # rank_and_gate() immediately — don't wait for the next scheduled # _rank_loop tick — so the minimum is restored as fast as possible. # # [HANG FIX — Layer 2] Only trigger a refill if rank_and_gate # is NOT currently running. Re-entering rank_and_gate while the # scheduled _rank_loop call is still inside it corrupts shared # state and can deadlock on the Deriv WS. If it's running, the # next scheduled tick will pick up the refill need within # update_frequency_seconds — acceptable latency vs a deadlock. if closed_any: open_count = len(self.position_mgr.get_open_trades()) if open_count < 4: if self._rank_lock is not None and self._rank_lock.locked(): logger.info( f"[monitor_positions] ⏩ REFILL SKIPPED — " f"rank_and_gate already running; next scheduled " f"tick will restore top-4 (open_count={open_count})" ) else: logger.warning( f"[monitor_positions] ⚠️ REFILL TRIGGER — " f"open_count={open_count} < 4 after close. " f"Calling rank_and_gate() immediately to restore top-4." ) try: await self.rank_and_gate() except Exception as refill_err: logger.error( f"[monitor_positions] ❌ Refill rank_and_gate error: {refill_err}", exc_info=True, ) await asyncio.sleep(2) except Exception as e: logger.error(f"❌ Position monitor error: {e}") await asyncio.sleep(2) async def _ensure_minimum_trades( self, hub_snapshots: Dict[str, "AssetSnapshot"], significance_map: Dict[str, float], value_map: Dict[str, float], cvar_map: Dict[str, float], aleatoric_map: Dict[str, float], epistemic_map: Dict[str, float], ) -> None: """ SAFETY NET — called at the start of every rank_and_gate() cycle. Guarantees that open_trade_count >= 4 before the per-candidate gate loop runs. If open_count < 4, this method forces immediate execution on the top-ranked available assets from the most recent hub_snapshots, bypassing all gate outcomes. The only hard filter respected here is Gate A: the asset must have a BUY or SELL dominant_signal — a NEUTRAL/HOLD signal cannot be acted on because we would not know which direction to trade. No more than (4 - open_count) trades are opened here; the remaining slots are filled by the normal enforcer loop in rank_and_gate(). """ open_count = len(self.position_mgr.get_open_trades()) if open_count >= 4: return # already at or above the floor — nothing to do needed = 4 - open_count logger.warning( f"[_ensure_minimum_trades] ⚠️ open_count={open_count} < 4 — " f"forcing execution on top {needed} ranked asset(s) to restore top-4" ) # Build a priority-sorted candidate list. # Prefer assets that are in the known top-4 (_prev_top3) — they are the # correct targets for this system. Fall back to the full hub_snapshots # pool only if the top-4 set is not yet populated (first run). # Within each tier, sort by value_map score descending. top3_pool = self._prev_top3 if self._prev_top3 else set(hub_snapshots.keys()) candidates: List[Tuple[float, str]] = [] for asset_id, snap in hub_snapshots.items(): if snap.dominant_signal in ("NEUTRAL", "HOLD"): continue # Gate A: cannot trade without a direction tier = 0 if asset_id in top3_pool else 1 # top-4 first priority = value_map.get(asset_id, significance_map.get(asset_id, 0.0)) candidates.append((tier, -priority, asset_id)) # sort: tier asc, priority desc candidates.sort() # (tier, -priority, asset_id) → top-4 highest-value first filled = 0 for _, _, asset_id in candidates: if filled >= needed: break # One trade per asset (constraint preserved) if self.position_mgr.get_open_trade_by_asset(asset_id): continue snap = hub_snapshots.get(asset_id) if snap is None: continue buf = self.asset_buffers.get(asset_id) realized_vol = buf.feature_eng.get_raw_feature(6) if buf else 0.01 # FIX 5: Guard stale price BEFORE force-executing so the warning # appears here (in _ensure_minimum_trades) rather than inside # process_axrvi_signal where it silently eats the slot. streamer = self.price_streamers.get(asset_id) if not streamer or streamer.is_stale(): logger.warning( f"[_ensure_minimum_trades] ⚠️ [{asset_id}] Price cache stale " f"(last update {time.time() - streamer.last_update_time:.1f}s ago) " f"— skipping force-execute; will retry next rank cycle." if streamer else f"[_ensure_minimum_trades] ⚠️ [{asset_id}] No price streamer — skipping." ) continue logger.warning( f"[_ensure_minimum_trades] FORCE EXECUTE {snap.dominant_signal} " f"on {asset_id} (top-4 floor enforcement, all gates bypassed)" ) await self.process_axrvi_signal( asset = asset_id, action = snap.dominant_signal, value_estimate = max(value_map.get(asset_id, 0.001), 1e-4), realized_vol = realized_vol, significance = significance_map.get(asset_id, 0.5), cvar_05 = cvar_map.get(asset_id, 0.0), ) filled += 1 if filled < needed: logger.warning( f"[_ensure_minimum_trades] Could only fill {filled}/{needed} slots " f"— not enough top-4 assets with BUY/SELL signals at this tick" ) def _push_rankings_to_hub(self, ranked: list) -> None: """ Fire-and-forget HTTP POST of the current AXRVI‑scored ranking list to the hub's /api/flip/rankings endpoint, including the real flip_direction. [HANG FIX — Layers 4 & 5] • Previously the timeout was 0.5s and failures logged at DEBUG, so transient hub slowness silently dropped rankings with no visibility. • Now: timeout=3.0s, failures logged at WARNING. • The HTTP POST runs in a thread-pool executor so hub latency never blocks the async rank loop — even if the hub is completely dead, the ranker keeps producing rankings. """ if not ranked: return # Capture the payload here (in the async caller's context) so the # snapshot read is consistent, then hand the blocking HTTP POST to # the executor. try: rankings_payload = [] for r in ranked: snap = self.hub_subscriber.get_snapshot(r.space_name) snap_signal = snap.dominant_signal if snap else "NEUTRAL" flip_dir = ( snap_signal if snap_signal not in ("NEUTRAL", "NONE", None, "") else r.dominant_signal ) rankings_payload.append({ "space_name": r.space_name, "score": r.score, "final_priority": r.final_priority, "rank": r.rank, "dominant_signal": flip_dir, "flip_direction": flip_dir, "avn_accuracy": r.avn_accuracy, "signal_confidence": r.signal_confidence, "epistemic_std": r.epistemic_std, "training_steps": r.training_steps, }) payload = json.dumps({"rankings": rankings_payload}).encode() url = f"{self._hub_http_url}/api/flip/rankings" except Exception as build_err: logger.warning(f"[Rankings] Payload build failed: {build_err}") return def _do_post() -> None: """Blocking HTTP POST — runs on a worker thread, not the event loop.""" try: import urllib.request as _urlreq req = _urlreq.Request( url, data=payload, headers={"Content-Type": "application/json"}, method="POST", ) _urlreq.urlopen(req, timeout=3.0) except Exception as post_err: # Warn-level so we see repeated failures in the log and can # diagnose whether the hub is the problem next time. logger.warning(f"[Rankings] Hub push failed: {post_err}") try: loop = asyncio.get_running_loop() loop.run_in_executor(None, _do_post) except RuntimeError: # No running loop (e.g. called from sync context) — do it inline. _do_post() async def rank_and_gate(self) -> None: """ [HANG FIX — Layers 2 & 3] Thin wrapper around _rank_and_gate_impl that holds self._rank_lock for the duration of the cycle (so _rank_loop and the monitor_positions refill trigger cannot re-enter concurrently) and stamps self._last_rank_complete_ts so the watchdog can detect stalls. Also logs cycle timing at WARNING level if a cycle exceeds 5s — slow cycles are the early symptom of a pending hang. """ # asyncio.Lock must be created inside a running loop on Python <3.10 if self._rank_lock is None: self._rank_lock = asyncio.Lock() t0 = time.time() async with self._rank_lock: try: await self._rank_and_gate_impl() finally: # Stamp even on failure: a fast failure is healthier than a # stalled coroutine. The watchdog should only fire when the # rank loop is actually stuck, not when it's crashing loudly. self._last_rank_complete_ts = time.time() elapsed = time.time() - t0 if elapsed > 5.0: logger.warning(f"[RankCycle] completed slowly in {elapsed:.2f}s") async def _rank_and_gate_impl(self) -> None: """ v6/v7 Shreve Ranking Cycle: 1. Data readiness check 2. Build 26-dim feature tensors (UnifiedFeatureEngine, QV vol [S3]) 3. AXRVINet + MC Dropout → significance_weight + V̂_t (value head) [S1] 4a. ShreveRankingEngine: Π_t = D(t,τ)·Ê[R|F_t] + ½σ²Δt·κ_t [S1,S5,S6] 4b. ConservativeRanker: lower-confidence-bound adjustment 4c. TOP-4 ROTATION: close any asset that left the top-4 since last cycle 5. BanditSelector: top 3 candidates 6. _ensure_minimum_trades() — safety net: force 3 open trades before gates run 7. TOP-3 ENFORCER: Hard filter: Gate A only (signal ∈ {BUY, SELL}) open_count < 3 → execute immediately, no veto (top-3 must be filled) open_count >= 3 → do not open new trades (hard ceiling = 3) Gates B, C, D, E → LOG ONLY (never block) 8. Train on replay buffer (with L_CE [S1]) """ self.rank_count += 1 # ── Step 1 ───────────────────────────────────────────────────────────── ready = [a for a in self.config.asset_symbols if self.asset_buffers[a].has_data()] if len(ready) < 1: logger.info(f"⏳ Waiting for data ({len(ready)}/{len(self.config.asset_symbols)} ready)") return # ── Step 2 ───────────────────────────────────────────────────────────── # _build_input_tensors returns (1,N,T,F), stale mask (1,N), and the # full asset_symbols list as active_ids (N is always the full roster). seq_t, stale_t, active_ids = self._build_input_tensors() # ── Step 3: neural inference ─────────────────────────────────────────── self.axrvi_net.eval() with torch.no_grad(): if self.config.uncertainty_config.use_mc_dropout: out = self.axrvi_net.forward_with_epistemic_uncertainty( seq_t, mc_samples=self.config.uncertainty_config.mc_samples, ) epistemic_std_arr = out["epistemic_std"].squeeze(0).numpy() else: out = self.axrvi_net(seq_t, stale_mask=stale_t) epistemic_std_arr = np.zeros(len(active_ids)) significance_arr = out["significance_weight"].squeeze(0).numpy() # Aleatoric uncertainty from log_var log_var_t = out.get("log_var", torch.zeros(1, len(active_ids), 1)) aleatoric_std_arr = torch.exp(0.5 * log_var_t).squeeze(0).squeeze(-1).detach().numpy() if aleatoric_std_arr.shape == (): aleatoric_std_arr = np.full(len(active_ids), 0.1) # [S1] Extract value-head estimates (V̂_t = Ê[R_{t+τ} | F_t] proxy) # "value" is now always present — MC path computes it from mean_quantiles. if "value" in out: v_arr = out["value"].squeeze(0).squeeze(-1).detach().numpy() else: v_arr = significance_arr # last-resort fallback significance_map: Dict[str, float] = {} epistemic_map: Dict[str, float] = {} aleatoric_map: Dict[str, float] = {} value_map: Dict[str, float] = {} # [S1] cvar_map: Dict[str, float] = {} # CVaR@5% for sizing layer for i, aid in enumerate(active_ids): significance_map[aid] = float(significance_arr[i]) epistemic_map[aid] = float(epistemic_std_arr[i]) if i < len(epistemic_std_arr) else 0.0 aleatoric_map[aid] = float(aleatoric_std_arr[i]) if i < len(aleatoric_std_arr) else 0.1 value_map[aid] = float(v_arr[i]) if i < len(v_arr) else 0.0 # Extract CVaR@5% per asset — DistributionalHead returns (B, N), MC path same if "cvar_05" in out: _cvar_raw = out["cvar_05"].detach() if _cvar_raw.dim() == 2: cvar_arr = _cvar_raw.squeeze(0).numpy() # (N,) else: cvar_arr = _cvar_raw.numpy() # already (N,) on det path for i, aid in enumerate(active_ids): cvar_map[aid] = float(cvar_arr[i]) if i < len(cvar_arr) else 0.0 else: cvar_map = {aid: 0.0 for aid in active_ids} # Cache value estimates for optimal-stopping monitor [S8] self._last_value_estimates = value_map # ── Step 4a: Shreve risk-neutral ranking [S1, S5, S6] ───────────────── hub_snapshots = self.hub_subscriber.get_all_snapshots() ranked = self.ranking_engine.rank_risk_neutral( snapshots = hub_snapshots, significance_weights = significance_map, value_estimates = value_map, ) # ── Step 4b: low-temperature softmax allocation (replaces ConservativeRanker) ── # The old ConservativeRanker applied hub_confidence × max(0, sig − 1.96σ) # which collapsed all scores to ~0.304 regardless of neural output. # # Replacement: π_i = Softmax(z_i / τ), τ = 0.05 # At τ=0.05 a 0.1 logit difference → ~7× probability ratio, giving the # network genuine ability to express conviction through score magnitude. # Rankings change every cycle as logits diverge during training. _SOFTMAX_TEMPERATURE = 0.05 if ranked: # Extract raw priorities (z_i) from ShreveRankingEngine output raw_logits = np.array([r.final_priority for r in ranked], dtype=np.float64) # Numerically stable softmax with temperature shifted = raw_logits - raw_logits.max() exp_v = np.exp(shifted / _SOFTMAX_TEMPERATURE) alloc_probs = exp_v / (exp_v.sum() + 1e-12) # π_i ∈ Δⁿ for r, prob in zip(ranked, alloc_probs): r.epistemic_std = epistemic_map.get(r.space_name, 0.0) r.aleatoric_std = aleatoric_map.get(r.space_name, 0.1) r.final_priority = float(prob) # allocation probability, not compressed score r.score = float(prob) # Logit spread diagnostic — should grow from ~0.002 → >0.1 as training progresses logit_spread = float(raw_logits.max() - raw_logits.min()) logger.debug( f"[SoftmaxRanker] τ={_SOFTMAX_TEMPERATURE} | " f"logit_spread={logit_spread:.4f} | " f"alloc_probs={[f'{p:.4f}' for p in alloc_probs[:5]]}" ) ranked.sort(key=lambda x: x.final_priority, reverse=True) for i, r in enumerate(ranked): r.rank = i + 1 # ── Top-3 rotation: close any asset that left the top-3 ─────────────── current_top3 = [r.space_name for r in ranked[:3]] await self._handle_rank_rotation(current_top3) if self.log_bridge: self.log_bridge.capture_ranking(ranked, hub_snapshots) # ── Push live AXRVI rankings to hub /api/flip/rankings ─────────────── # The hub's _compute_rankings() will serve these from /api/state so # the dashboard Score column reflects real AXRVINet priorities, not # the stale hub-snapshot vote-ratio fallback. self._push_rankings_to_hub(ranked) logger.info( "📊 Rankings: " + " | ".join( f"{r.rank}.{r.space_name} " f"priority={r.final_priority:+.4f} " f"(V̂={value_map.get(r.space_name, 0.0):+.3f}, " f"ε_std={r.epistemic_std:.3f}, signal={r.dominant_signal})" for r in ranked[:5] ) ) # ── Step 5: bandit selection ─────────────────────────────────────────── # Use softmax allocation probabilities (not raw significance_map) so the # bandit sees the same competitive scores as the ranking display. # Build a lookup from space_name → alloc_prob for fast O(1) access. alloc_prob_map: Dict[str, float] = ( {r.space_name: r.final_priority for r in ranked} if ranked else {} ) neural_scores = np.array([ alloc_prob_map.get(a, significance_map.get(a, 0.0)) for a in self.config.asset_symbols ]) selected_ids, final_scores = self.bandit.select( neural_scores = neural_scores, threshold = self.config.score_threshold, max_select = self.config.max_concurrent, ) self._selected_assets = selected_ids self._last_final_scores = final_scores # Store as (N, T, F) ndarray — used by _open_pending_episode [S2] self._last_sequences = seq_t.squeeze(0).numpy() # (N, T, F) self._last_selected_idx = ( active_ids.index(selected_ids[0]) if selected_ids else None ) self.stats["rank_cycles"] += 1 if selected_ids: self.stats["selected_assets"][selected_ids[0]] += 1 for i, aid in enumerate(active_ids): self.asset_buffers[aid].axrvi_score = float(final_scores[i]) self.asset_buffers[aid].is_enabled = aid in selected_ids # Cache latest hub snapshots for _ensure_minimum_trades() refill logic self._last_hub_snapshots = hub_snapshots # ── Steps 6–7: MINIMUM TRADES ENFORCER (v7) ─────────────────────────── # # PHILOSOPHY (v7): This system is a PERFORMANCE RANKER. Trades are the # measurement instrument. Zero active trades = zero ranking data. # # RULE: # open_count < 4 → execute immediately on top-ranked assets regardless # of gate outcomes — NO VETO ALLOWED # open_count == 4 → hard ceiling — do not open new trades # open_count >= 4 → do not open new trades (hard ceiling) # # Gate A (dominant_signal ∈ {BUY, SELL}) remains the ONLY hard filter. # Gates B, C, D, E are demoted to LOG ONLY — they never block a trade. # # _ensure_minimum_trades() is called first as a safety net before the # per-candidate loop runs. await self._ensure_minimum_trades(hub_snapshots, significance_map, value_map, cvar_map, aleatoric_map, epistemic_map) open_count = len(self.position_mgr.get_open_trades()) for asset_id in selected_ids: # Re-check live count each iteration so we never exceed the ceiling open_count = len(self.position_mgr.get_open_trades()) # Hard ceiling: 4 concurrent trades maximum if open_count >= 4: logger.debug( f"[{asset_id}] SKIP — open_count={open_count} ≥ 4 (hard ceiling)" ) break snap = hub_snapshots.get(asset_id) sig_w = significance_map.get(asset_id, 0.0) if snap is None: continue # ── Gate A — hard filter: directional signal required ────────────── # This is the ONLY gate that can block a trade in the enforcer model. if snap.dominant_signal in ("NEUTRAL", "HOLD"): logger.debug(f"[{asset_id}] SKIP Gate A — signal={snap.dominant_signal}") continue # ── Gate B — LOG ONLY (was: hub_confidence ≤ 0 → skip) ─────────── if snap.signal_confidence <= 0.0: logger.info( f"[{asset_id}] Gate B LOG: hub_confidence=0.00 " f"(open_count={open_count} — executing anyway if below floor)" ) # Below the 4-trade floor, execute regardless of confidence if open_count >= 4: continue # ── Gate C — LOG ONLY (was: significance < threshold → skip) ────── if sig_w < self.config.score_threshold: logger.info( f"[{asset_id}] Gate C LOG: significance={sig_w:.3f} " f"< threshold={self.config.score_threshold:.3f} " f"(open_count={open_count} — executing anyway if below floor)" ) if open_count >= 4: continue # ── Gates D + E — LOG ONLY (was: DynamicExecutionGate → block) ──── buf = self.asset_buffers.get(asset_id) vol_ratio = buf.feature_eng.get_raw_feature(7) if buf else 1.0 jump_risk = buf.feature_eng.get_raw_feature(24) if buf else 0.0 mart_dev_raw = buf.feature_eng.get_raw_feature(23) if buf else 1.0 realized_vol = buf.feature_eng.get_raw_feature(6) if buf else 0.01 execute, reason = self.execution_gate.should_execute( hub_confidence = snap.signal_confidence, significance = sig_w, volatility_ratio = vol_ratio, jump_risk = jump_risk, epistemic_std = epistemic_map.get(asset_id, 0.0), aleatoric_std = aleatoric_map.get(asset_id, 0.1), dominant_signal = snap.dominant_signal, martingale_deviation = mart_dev_raw, ) if not execute: logger.info( f"[{asset_id}] Gate D/E LOG (not blocking): {reason} " f"open_count={open_count}" ) # Below the 4-trade floor, enforce the minimum and execute anyway if open_count >= 4: continue # ── Determine whether to execute based on open_count ────────────── # open_count < 4 → execute unconditionally (top-4 enforcer) # open_count >= 4 → already caught by the ceiling check at loop top _top_prob = ranked[0].final_priority if ranked else 0.0 _raw_sig = significance_map.get(asset_id, 0.0) logger.info( f"[{asset_id}] EXECUTE {snap.dominant_signal} | " f"open_count={open_count} | " f"alloc_prob={_top_prob:.4f} | " f"sig_logit={_raw_sig:+.4f} | " f"V̂={value_map.get(asset_id, 0.0):+.3f}" ) await self.process_axrvi_signal( asset = asset_id, action = snap.dominant_signal, value_estimate = max(value_map.get(asset_id, 0.001), 1e-4), realized_vol = realized_vol, significance = sig_w, cvar_05 = cvar_map.get(asset_id, 0.0), ) # ── Step 8: train on replay buffer ──────────────────────────────────── if self.rank_count % self.config.train_every_n == 0: buf_ready = self.replay.is_ready(self.config.batch_size) buf_size = len(self.replay.buffer) if hasattr(self.replay, "buffer") else "?" if buf_ready: batch = self.replay.sample(self.config.batch_size) train_result = self.trainer.train_on_batch(batch) # Advance LR schedule once per training pass (not per batch) self.trainer.step_scheduler(path=self.config.model_path) if train_result: logger.info( f"✅ [Training] step={self.trainer.train_step} | " f"loss={train_result.get('total', 0):.4f} | " f"rl={train_result.get('rl', 0):.4f} | " f"ce={train_result.get('ce', 0):.4f} | " f"batch_size={len(batch)}" ) # ── HF SYNC: mirror trainer .pt to HF periodically ──────────── # FIX 6: Upload every 10 training steps (not every step) to avoid # flooding HF with commits. The full-state autosave (maybe_save) # handles persistence every 5 min regardless. if self.trainer.train_step % 10 == 0: self.checkpoint_mgr._hf.queue_upload( local_path=self.config.model_path, step=self.trainer.train_step, metadata={ "reason": "train_step", "loss": train_result.get("total", 0.0), "rl": train_result.get("rl", 0.0), "ce": train_result.get("ce", 0.0), "train_step": self.trainer.train_step, }, ) else: logger.warning( "⚠️ [Training] train_on_batch returned empty result — " "possible invalid episodes in batch (check sequences vs next_sequences diff)" ) else: logger.info( f"⏳ [Training] Replay buffer not ready yet | " f"size={buf_size}/{self.config.batch_size} | " f"step={self.trainer.train_step if self.trainer else 0} | " f"(waiting for {self.config.batch_size} closed-trade episodes)" ) # ── Time-based checkpoint: every 30 min regardless of training ──────── self.checkpoint_mgr.maybe_save(self) # ── Input tensor construction ────────────────────────────────────────────────────── def _build_input_tensors( self, ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: """ Build the (1, N, T, F) input tensor for AXRVINet from live asset buffers. Returns ------- seq_t : FloatTensor of shape (1, N, T, F) N = len(config.asset_symbols) — always the full roster. Assets that have no data yet are represented as all-zeros, and their stale flag is True so the model can mask them. stale_t : BoolTensor of shape (1, N) — True means the asset is stale and should be masked out in cross-asset attention layers. active_ids : list[str] — same order as the N dimension; always equals config.asset_symbols so callers can zip with model outputs. """ asset_ids = self.config.asset_symbols # canonical N-dim ordering T = self.config.feature_window # SEQ_LEN F = self.config.feature_dim # FEATURE_DIM seqs = [] stale_flags = [] for aid in asset_ids: buf = self.asset_buffers.get(aid) if buf is not None: seqs.append(buf.get_sequence()) # (T, F) float32 stale_flags.append(buf.is_stale) else: # Asset registered but buffer not yet created — pad with zeros seqs.append(np.zeros((T, F), dtype=np.float32)) stale_flags.append(True) # seq_np: (N, T, F) → add batch dim → (1, N, T, F) seq_np = np.stack(seqs, axis=0)[np.newaxis].astype(np.float32) seq_t = torch.from_numpy(seq_np) # (1, N, T, F) stale_t = torch.tensor([stale_flags], dtype=torch.bool) # (1, N) return seq_t, stale_t, asset_ids # ── Episode recording — Shreve non-anticipating transitions [S2] ───────────────── def _open_pending_episode(self, trade_id: str, asset_id: str, selected_idx: int) -> None: """ [S2] Capture s_t ∈ F_t at trade-OPEN time. ``_last_sequences`` is set in rank_and_gate() as ``seq_t.squeeze(0).numpy()`` → shape (N, T, F), where N = len(asset_symbols), T = feature_window, F = feature_dim. If no rank cycle has completed yet (cold-start edge-case), we fall back to building the tensor directly from the current asset buffers so the replay entry is never skipped. """ n_assets = len(self.config.asset_symbols) # --- s_t snapshot --- if self._last_sequences is not None: sequences = self._last_sequences # already (N, T, F) ndarray else: # cold-start fallback: build from live buffers right now seq_t, _, _ = self._build_input_tensors() sequences = seq_t.squeeze(0).numpy() # (N, T, F) # Quadratic-variation volatility [S3] — used for Girsanov priority buf = self.asset_buffers.get(asset_id) vol = 1.0 if buf and len(buf.feature_eng._returns) >= 5: arr = np.array(list(buf.feature_eng._returns)[-20:], dtype=np.float64) vol = float(np.sqrt(np.sum(arr ** 2))) + 1e-8 self._pending_episodes[trade_id] = { "sequences": sequences.tolist(), # s_t ∈ F_t — shape (N, T, F) "selected_idx": selected_idx, "asset_id": asset_id, "n_assets": n_assets, "volatility": vol, } def _close_pending_episode(self, trade_id: str, reward: float) -> None: """ [S2] Complete the (s_t, a_t, r_t, s_{t+1}) transition at trade-CLOSE. s_t — shape (N, T, F), captured at trade OPEN (F_t-measurable). s_{t+1} — shape (N, T, F), built from CURRENT buffers (F_{t+1}-measurable). Both tensors are stored as nested Python lists so they survive json-like serialisation inside the replay buffer dict and can be re-hydrated by HybridTrainer as torch.FloatTensor(ep["sequences"]) → (N, T, F). """ ep = self._pending_episodes.pop(trade_id, None) if ep is None: return asset_id = ep["asset_id"] n_assets = ep["n_assets"] selected_idx = ep["selected_idx"] # ✅ s_t — F_t-measurable state captured at trade OPEN sequences = np.array(ep["sequences"], dtype=np.float32) # (N, T, F) # ✅ s_{t+1} — F_{t+1}-measurable state captured at trade CLOSE # _build_input_tensors reflects all buffer updates that happened # during the trade lifetime — guaranteed temporal separation [S2]. next_seq_t, _, _ = self._build_input_tensors() next_sequences = next_seq_t.squeeze(0).numpy() # (N, T, F) # ── Episode validity check: s_t and s_{t+1} must differ ──────────── # If they are identical the TD error collapses to zero and training # produces no gradient. This happens when a trade opens and closes # within the same rank cycle before any price ticks arrive. # In that case we still push the episode but tag it so HybridTrainer # can apply a lower importance weight. # # [FIX] Guard against N-mismatch: if assets were added between trade # open (s_t) and close (s_{t+1}), the shapes differ and subtraction # would crash. In that case we skip the state_diff check and mark # the episode with state_diff=-1 so the trainer's N-filter drops it. if sequences.shape != next_sequences.shape: logger.warning( f"[_close_pending_episode] [{asset_id}] s_t shape {sequences.shape} " f"≠ s_{{t+1}} shape {next_sequences.shape} — asset list changed " f"during this trade. Episode will be discarded by trainer N-filter." ) state_diff = -1.0 # sentinel: trainer N-filter will drop this episode else: state_diff = float(np.mean(np.abs(next_sequences - sequences))) if state_diff < 1e-6: logger.warning( f"[_close_pending_episode] [{asset_id}] s_t ≈ s_{{t+1}} " f"(diff={state_diff:.2e}) — episode may not produce useful gradient. " f"Trade may have closed before a rank cycle completed." ) # Still push — the buffer needs data; the trainer will use low importance_weight pnl_proxy = np.zeros(n_assets, dtype=np.float32) if asset_id in self.config.asset_symbols: pnl_proxy[self.config.asset_symbols.index(asset_id)] = float(reward) self.replay.push({ "sequences": sequences.tolist(), # s_t ∈ F_t (N, T, F) "next_sequences": next_sequences.tolist(), # s_{t+1} ∈ F_{t+1} "selected_idx": selected_idx, "reward": float(reward), # log-return units [S4] "pnl_per_asset": pnl_proxy.tolist(), # Girsanov priority metadata [S2/GirsanovReplayBuffer] "volatility": ep["volatility"], "td_error": abs(reward), # State diversity marker — used for importance weighting "state_diff": state_diff, }) if asset_id in self.config.asset_symbols: self.bandit.update_reward(asset_id, reward) # ── Run loops ────────────────────────────────────────────────────────────────────── async def run(self) -> None: if not await self.initialize(): raise RuntimeError("Initialization failed — see logs above") self.running = True logger.info("🚀 QuasarAXRVIBridge v5 running") try: await asyncio.gather( self.ws_client.listen(), self._rank_loop(), self.monitor_positions(), # [HANG FIX — Layer 3] Watchdog task: auto-recover if the rank # loop stops completing cycles (ws.send() hang, deadlock, etc.) self._rank_watchdog(), ) except asyncio.CancelledError: pass except Exception as e: logger.error(f"❌ Run error: {e}") traceback.print_exc() finally: self.running = False await self.ws_client.close() self.hub_subscriber.stop() self.signal_subscriber.stop() # ── Full-state shutdown checkpoint ──────────────────────────────── # Saves everything: model, replay buffer, bandit, normalizers, etc. # This is the "progress matters" save — nothing is lost on restart. if self.trainer: logger.info("[RankerCheckpoint] 🛑 Shutdown detected — saving full state…") self.checkpoint_mgr.save(self, reason="shutdown") self.checkpoint_mgr.print_summary() # Also write legacy single-file model for backward compat self.trainer.save(self.config.model_path) if self.ranker_logger: self.ranker_logger.connection_event("Ranker", "stopped", "System shutdown") logger.info("✅ QuasarAXRVIBridge v5 stopped") async def _rank_loop(self) -> None: while self.running: try: await self.rank_and_gate() except Exception as e: # [HANG FIX — Layer 4] exc_info=True so we get a full traceback # in the ranker log the next time something breaks silently. # Previously errors were logged as a one-liner with no stack # frame, making root-cause diagnosis impossible. logger.error(f"❌ Rank loop error: {e}", exc_info=True) await asyncio.sleep(self.config.update_frequency_seconds) # ── [HANG FIX — Layer 3] Rank-loop watchdog ─────────────────────────────── # Periodically checks self._last_rank_complete_ts. If no rank cycle has # completed within RANK_STALL_THRESHOLD_S, assumes the rank loop is hung # (almost always on a ws.send() to a half-open Deriv connection) and # force-closes the websocket. That raises ConnectionClosed inside any # pending send/recv, unblocking the await and triggering reconnect(). # # This is the safety net: even if a NEW bug introduces a different hang, # the system auto-recovers within RANK_STALL_THRESHOLD_S instead of # sitting dead until a human restarts the Space. RANK_STALL_THRESHOLD_S: float = 120.0 # 4× the expected worst-case cycle RANK_WATCHDOG_POLL_S: float = 30.0 async def _rank_watchdog(self) -> None: """Force-close the Deriv ws if the rank loop stops completing cycles.""" logger.info( f"🐕 [RankWatchdog] started | stall_threshold={self.RANK_STALL_THRESHOLD_S}s " f"| poll_interval={self.RANK_WATCHDOG_POLL_S}s" ) while self.running: await asyncio.sleep(self.RANK_WATCHDOG_POLL_S) if not self.running: break since_last = time.time() - self._last_rank_complete_ts if since_last > self.RANK_STALL_THRESHOLD_S: logger.critical( f"🚨 [RankWatchdog] Rank loop has not completed a cycle for " f"{since_last:.0f}s (threshold={self.RANK_STALL_THRESHOLD_S}s). " f"Forcing Deriv ws close to unblock any pending send()." ) # Force-close the ws. This raises ConnectionClosed inside # whatever coroutine is awaiting ws.send() or ws.recv(), # unblocking the rank loop. listen()'s except branch will # then drive reconnect(). try: if self.ws_client and self.ws_client.ws: await self.ws_client.ws.close() except Exception as close_err: logger.warning( f"[RankWatchdog] ws.close() raised (expected on dead socket): {close_err}" ) # Reset the stamp so we don't spam CRITICAL every poll interval # while reconnect is in progress. self._last_rank_complete_ts = time.time() else: logger.debug( f"[RankWatchdog] healthy | last_cycle={since_last:.1f}s ago" ) def start_sync(self) -> None: """Start in synchronous (threading) mode.""" def _run_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) self._sync_loop = loop try: loop.run_until_complete(self.run()) finally: loop.close() self._sync_thread = threading.Thread( target=_run_loop, daemon=True, name="QuasarAXRVISync" ) self._sync_thread.start() logger.info("✅ QuasarAXRVIBridge v5 started (sync mode)") def stop_sync(self) -> None: self.running = False if self._sync_thread: self._sync_thread.join(timeout=10) logger.info("✅ QuasarAXRVIBridge v5 stopped") # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 18a — HUGGING FACE DATASET CHECKPOINT SYNC LAYER # ══════════════════════════════════════════════════════════════════════════════════════ # # HFDatasetCheckpointManager — mirrors every successful local .pt checkpoint # (written by QuasarCheckpointManager) to a Hugging Face Dataset repository. # # Environment variables (set these to enable HF sync): # HF_TOKEN — Hugging Face API token with write access # HF_REPO_ID — Dataset repo, e.g. "YourUsername/quasar-checkpoints" # # The HF layout exactly matches the local layout: # step_XXXXXXX.pt — full-state checkpoint (identical bytes to local) # ranker_index.json — master index compatible with QuasarCheckpointManager # latest — sentinel file holding the latest step number # # No existing save/load path is broken if HF is unavailable; all HF operations # are wrapped in try/except and log warnings rather than raising. class HFDatasetCheckpointManager: """ Hugging Face Dataset mirror for QUASAR AXRVI Ranker checkpoints. Mirrors local step_XXXXXXX.pt files (produced by QuasarCheckpointManager) to a HF Dataset repo in a background thread so the live trading loop is never blocked. The same naming convention is used so the index file stays compatible with local reload. """ INDEX_FILENAME = "ranker_index.json" LATEST_SENTINEL = "latest" def __init__( self, repo_id: str, token: Optional[str] = None, verbose: bool = True, ): self.repo_id = repo_id self.token = token or os.environ.get("HF_TOKEN") self.verbose = verbose self._hfapi = None # lazy-loaded HfApi instance self._hf_hub_dl = None # lazy-loaded hf_hub_download self._upload_count = 0 self._download_count = 0 self._last_upload_step: Optional[int] = None if self.verbose: logger.info( f"☁️ HFDatasetCheckpointManager | repo={self.repo_id} | " f"token={'✅' if self.token else '❌ missing — HF sync disabled'}" ) # ── Lazy HF import ───────────────────────────────────────────────────────── def _ensure_hf(self) -> bool: """Return True if huggingface_hub is importable and token is set.""" if not self.token: return False if self._hfapi is not None: return True try: from huggingface_hub import HfApi, hf_hub_download self._hfapi = HfApi(token=self.token) self._hf_hub_dl = hf_hub_download return True except ImportError: logger.warning( "☁️ HF sync unavailable — install huggingface_hub: pip install huggingface_hub" ) return False # ── HF path helpers ──────────────────────────────────────────────────────── def _hf_ckpt_path(self, step: int) -> str: return f"step_{step:07d}.pt" def _ensure_repo_exists(self) -> None: """Create the HF Dataset repo if it does not already exist. No-op if it exists.""" if not self._ensure_hf(): return try: self._hfapi.create_repo( repo_id = self.repo_id, repo_type = "dataset", exist_ok = True, # safe to call even if repo already exists private = True, ) if self.verbose: logger.info(f"☁️ HF repo ready (created or already exists): {self.repo_id}") except Exception as exc: logger.warning(f"☁️ Could not ensure HF repo exists (non-fatal): {exc}") # ── Upload ───────────────────────────────────────────────────────────────── def upload( self, local_path: Union[str, Path], step: int, metadata: Optional[Dict] = None, ) -> bool: """ Upload *local_path* to HF Dataset as step_XXXXXXX.pt. Updates ranker_index.json and latest sentinel on HF. FIX 3: All three files (.pt, index, sentinel) are batched into a SINGLE commit via create_commit + CommitOperationAdd, reducing HF commit usage from 3 → 1 per save. This prevents the 256/hr rate limit from being blown by the 11-space fleet. Returns True on success, False on any error (non-raising). """ if not self._ensure_hf(): return False # Auto-create the repo if this is a new HF folder (no-op if it already exists) self._ensure_repo_exists() local_path = Path(local_path) if not local_path.exists(): logger.warning(f"☁️ HF upload skipped — file not found: {local_path}") return False import tempfile import json as _json hf_path = self._hf_ckpt_path(step) try: if self.verbose: logger.info(f"☁️ Uploading {hf_path} → {self.repo_id} (batched commit) …") from huggingface_hub import CommitOperationAdd # ── Build updated index payload ──────────────────────────────────── existing: Dict[str, Any] = {} try: with tempfile.TemporaryDirectory() as tmpdir: idx_file = self._hf_hub_dl( repo_id=self.repo_id, filename=self.INDEX_FILENAME, repo_type="dataset", token=self.token, local_dir=tmpdir, local_dir_use_symlinks=False, ) with open(idx_file) as fh: existing = _json.load(fh) except Exception: pass # no existing index — start fresh entry: Dict[str, Any] = { "step": step, "filename": hf_path, "size_mb": round(local_path.stat().st_size / 1_048_576, 2), "timestamp": datetime.now().isoformat(), } if metadata: entry.update(metadata) if "checkpoints" not in existing: existing["checkpoints"] = [] existing["checkpoints"] = [ cp for cp in existing["checkpoints"] if cp.get("step") != step ] existing["checkpoints"].append(entry) existing["checkpoints"].sort(key=lambda x: x.get("step", 0)) existing["latest_step"] = step existing["last_updated"] = datetime.now().isoformat() existing["total_checkpoints"] = len(existing["checkpoints"]) index_bytes = _json.dumps(existing, indent=2).encode() sentinel_bytes = str(step).encode() # ── Single batched commit: .pt + index + sentinel ────────────────── with open(local_path, "rb") as pt_fh: pt_bytes = pt_fh.read() self._hfapi.create_commit( repo_id=self.repo_id, repo_type="dataset", operations=[ CommitOperationAdd( path_in_repo=hf_path, path_or_fileobj=pt_bytes, ), CommitOperationAdd( path_in_repo=self.INDEX_FILENAME, path_or_fileobj=index_bytes, ), CommitOperationAdd( path_in_repo=self.LATEST_SENTINEL, path_or_fileobj=sentinel_bytes, ), ], commit_message=f"checkpoint step={step:07d}", ) self._upload_count += 1 self._last_upload_step = step logger.info( f"☁️ ✅ HF upload complete (1 commit) | step={step} | repo={self.repo_id}" ) return True except Exception as exc: logger.warning(f"☁️ ⚠️ HF upload failed (non-fatal): {exc}") return False # _update_hf_index and _update_hf_latest removed in FIX 3. # All three files (.pt, index, sentinel) are now batched into a single # create_commit call inside upload() to avoid the 3x commit-per-save # rate-limit explosion. # ── Download ─────────────────────────────────────────────────────────────── def download( self, step: int, local_dir: Union[str, Path] = "./hf_restored", ) -> Optional[Path]: """Download step_XXXXXXX.pt from HF Dataset. Returns local Path or None.""" if not self._ensure_hf(): return None local_dir = Path(local_dir) local_dir.mkdir(parents=True, exist_ok=True) hf_path = self._hf_ckpt_path(step) try: logger.info(f"☁️ Downloading {hf_path} ← {self.repo_id} …") dl = self._hf_hub_dl( repo_id=self.repo_id, filename=hf_path, repo_type="dataset", token=self.token, local_dir=str(local_dir), local_dir_use_symlinks=False, ) self._download_count += 1 logger.info(f"☁️ ✅ HF download complete → {dl}") return Path(dl) except Exception as exc: logger.warning(f"☁️ ⚠️ HF download failed (non-fatal): {exc}") return None def download_latest( self, local_dir: Union[str, Path] = "./hf_restored", ) -> Optional[Path]: """Download the latest checkpoint from HF. Returns local Path or None.""" if not self._ensure_hf(): return None import tempfile, json as _json # Try sentinel first (fast path) step = None try: with tempfile.TemporaryDirectory() as tmpdir: sf = self._hf_hub_dl( repo_id=self.repo_id, filename=self.LATEST_SENTINEL, repo_type="dataset", token=self.token, local_dir=tmpdir, local_dir_use_symlinks=False, ) with open(sf) as fh: step = int(fh.read().strip()) except Exception: pass # Fallback: scan index if step is None: try: with tempfile.TemporaryDirectory() as tmpdir: idx_file = self._hf_hub_dl( repo_id=self.repo_id, filename=self.INDEX_FILENAME, repo_type="dataset", token=self.token, local_dir=tmpdir, local_dir_use_symlinks=False, ) with open(idx_file) as fh: idx = _json.load(fh) cps = idx.get("checkpoints", []) if cps: step = max(c["step"] for c in cps) except Exception: pass if step is None: logger.info("☁️ No HF checkpoints found.") return None return self.download(step, local_dir) # ── Stats ────────────────────────────────────────────────────────────────── @property def stats(self) -> Dict[str, Any]: return { "uploads": self._upload_count, "downloads": self._download_count, "last_upload_step": self._last_upload_step, } # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 18a.2 — HF SYNC LAYER (background thread, non-blocking) # ══════════════════════════════════════════════════════════════════════════════════════ class HFSyncLayer: """ Thin coordination wrapper used by RankerCheckpointManager. * After every successful local save, queues an async HF upload so the trading loop is not blocked by network I/O. * On load/resume, tries HF download if no local checkpoint exists. * All operations are best-effort; failures are logged, never raised. Configuration comes entirely from environment variables: HF_TOKEN — HF write token (required to enable sync) HF_REPO_ID — Dataset repo ID (required to enable sync) You can also pass hf_repo_id explicitly to RankerCheckpointManager. """ def __init__(self, hf_repo_id: Optional[str] = None): repo = hf_repo_id or os.environ.get("HF_REPO_ID", "") token = os.environ.get("HF_TOKEN", "") self.enabled = bool(repo and token) self._hf: Optional[HFDatasetCheckpointManager] = None if self.enabled: self._hf = HFDatasetCheckpointManager( repo_id=repo, token=token, verbose=True ) # Background upload queue self._queue: queue.Queue = queue.Queue() self._worker = threading.Thread( target=self._upload_worker, daemon=True, name="HFSyncWorker" ) self._worker.start() logger.info( f"☁️ HFSyncLayer enabled | repo={repo} | background worker started" ) else: if repo or token: # Partial config — warn but don't crash logger.warning( "☁️ HFSyncLayer: HF_TOKEN and HF_REPO_ID must both be set to " "enable HF sync. Running local-only." ) else: logger.info( "☁️ HFSyncLayer: HF_REPO_ID/HF_TOKEN not set — " "local-only checkpoints (set env vars to enable HF sync)" ) # ── Upload path ─────────────────────────────────────────────────────────── def queue_upload( self, local_path: Union[str, Path], step: int, metadata: Optional[Dict] = None, ) -> None: """ Enqueue a background HF upload. Returns immediately; never blocks. Called by RankerCheckpointManager.save() after every successful local save. """ if not self.enabled: return self._queue.put((local_path, step, metadata)) def _upload_worker(self) -> None: """Background thread: drains the upload queue.""" while True: try: local_path, step, metadata = self._queue.get(timeout=5) if self._hf is not None: self._hf.upload(local_path, step, metadata) self._queue.task_done() except queue.Empty: continue except Exception as exc: logger.warning(f"☁️ HF upload worker error (non-fatal): {exc}") # ── Restore path ────────────────────────────────────────────────────────── def restore_if_needed( self, local_checkpoint_found: bool, local_restore_dir: Union[str, Path] = "./hf_restored", ) -> Optional[Path]: """ If *local_checkpoint_found* is False, try to pull the latest checkpoint from HF and return the local path to the downloaded file. Returns None if HF is disabled, unavailable, or has no checkpoints. Called by RankerCheckpointManager.load() when local state is absent. """ if not self.enabled or local_checkpoint_found: return None if self._hf is None: return None logger.info("☁️ No local checkpoint — attempting HF restore …") return self._hf.download_latest(local_dir=local_restore_dir) # ── Stats ───────────────────────────────────────────────────────────────── @property def upload_count(self) -> int: return self._hf.stats["uploads"] if (self.enabled and self._hf) else 0 @property def download_count(self) -> int: return self._hf.stats["downloads"] if (self.enabled and self._hf) else 0 # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 18b — RANKER CHECKPOINT MANAGER # ══════════════════════════════════════════════════════════════════════════════════════ class RankerCheckpointManager: """ Bridge-facing checkpoint controller for QuasarAXRVIBridge. This is the authoritative checkpoint authority for the live trading system. It wraps QuasarCheckpointManager and exposes the three methods the bridge calls: load(bridge), save(bridge, reason=""), maybe_save(bridge), and print_summary(). All state is routed through correct attribute names on QuasarAXRVIBridge. """ # How often maybe_save() will actually persist (seconds) AUTOSAVE_INTERVAL: int = 300 # 5 minutes def __init__( self, checkpoint_dir: str = "./Ranker10", hf_repo_id: Optional[str] = None, ): self.checkpoint_dir = checkpoint_dir self._engine = QuasarCheckpointManager( checkpoint_dir = checkpoint_dir, max_checkpoints = 10, autosave_interval_minutes = 0, # we handle timing ourselves enable_async_save = False, enable_compression = True, verbose = True, ) # ── HF Sync Layer (non-blocking secondary persistence) ───────────── # Reads HF_TOKEN / HF_REPO_ID from environment; hf_repo_id overrides. self._hf = HFSyncLayer(hf_repo_id=hf_repo_id) self._last_save_time: float = 0.0 self._save_count: int = 0 # ── Public API ──────────────────────────────────────────────────────────── def save(self, bridge, reason: str = "") -> bool: """ Persist the complete live runtime state of *bridge*. Returns True on success. """ try: step = bridge.trainer.train_step if bridge.trainer else 0 ckpt = self._build_checkpoint(bridge, step, reason) path = self._engine.checkpoint_dir / self._engine.get_checkpoint_filename(step) torch.save(ckpt, path, _use_new_zipfile_serialization=True) self._engine._update_index(step, 0, os.path.getsize(path) / 1_048_576, self._engine.get_checkpoint_filename(step)) self._engine._cleanup_old_checkpoints() self._last_save_time = time.time() self._save_count += 1 logger.info(f"✅ [RankerCheckpointManager] Saved step={step} reason='{reason}' → {path}") # Validate the saved file is readable _ok = torch.load(path, map_location="cpu", weights_only=False) assert "step" in _ok, "Checkpoint validation failed: 'step' key missing" # ── HF SYNC: mirror to Hugging Face (non-blocking background upload) ── # Triggered on every successful local save so HF always tracks latest. self._hf.queue_upload( local_path=path, step=step, metadata={ "reason": reason, "rank_count": getattr(bridge, "rank_count", 0), "total_pnl": bridge.stats.get("total_pnl", 0.0) if hasattr(bridge, "stats") else 0.0, "train_step": step, }, ) return True except Exception as e: logger.error(f"❌ [RankerCheckpointManager] Save failed: {e}") traceback.print_exc() return False def load(self, bridge) -> bool: """ Load the latest checkpoint and restore all state into *bridge*. Primary path: local QuasarCheckpointManager index. Fallback path: HF Dataset download (when local has no checkpoints). Returns True if a checkpoint was found and loaded successfully. """ latest = self._engine.get_latest_checkpoint_info() if latest is None: # ── HF RESTORE FALLBACK: no local checkpoint → try HF ──────────── hf_path = self._hf.restore_if_needed( local_checkpoint_found=False, local_restore_dir=str(self._engine.checkpoint_dir / "hf_restored"), ) if hf_path is None: logger.info("[RankerCheckpointManager] No checkpoint found locally or on HF — starting fresh.") return False # HF download succeeded — load from the downloaded file logger.info(f"☁️ [RankerCheckpointManager] Restoring from HF checkpoint: {hf_path}") try: ckpt = torch.load(hf_path, map_location="cpu", weights_only=False) except Exception as e: logger.error(f"❌ [RankerCheckpointManager] Cannot read HF checkpoint {hf_path}: {e}") return False try: self._restore_checkpoint(bridge, ckpt) # Mirror into local index so future local lookups work step = ckpt.get("step", 0) size_mb = hf_path.stat().st_size / 1_048_576 local_dest = self._engine.checkpoint_dir / self._engine.get_checkpoint_filename(step) if not local_dest.exists(): import shutil as _shutil _shutil.copy2(hf_path, local_dest) self._engine._update_index(step, 0, size_mb, self._engine.get_checkpoint_filename(step)) logger.info( f"✅ [RankerCheckpointManager] ☁️ Restored from HF step={ckpt.get('step', 0)}" ) return True except Exception as e: logger.error(f"❌ [RankerCheckpointManager] HF restore failed: {e}") traceback.print_exc() return False # ── Normal local load ───────────────────────────────────────────────── path = self._engine.get_checkpoint_path(latest["step"]) if not path.exists(): logger.warning(f"[RankerCheckpointManager] Checkpoint file missing: {path}") return False try: ckpt = torch.load(path, map_location="cpu", weights_only=False) except Exception as e: logger.error(f"❌ [RankerCheckpointManager] Cannot read checkpoint {path}: {e}") return False try: self._restore_checkpoint(bridge, ckpt) logger.info( f"✅ [RankerCheckpointManager] Restored step={ckpt.get('step', 0)} " f"from {path}" ) return True except Exception as e: logger.error(f"❌ [RankerCheckpointManager] Restore failed: {e}") traceback.print_exc() return False def maybe_save(self, bridge) -> bool: """Save if AUTOSAVE_INTERVAL seconds have elapsed since last save.""" if time.time() - self._last_save_time >= self.AUTOSAVE_INTERVAL: return self.save(bridge, reason="autosave") return False def print_summary(self) -> None: """Print a summary of checkpoint history including HF sync stats.""" index = self._engine.index cps = index.get("checkpoints", []) print(f"\n{'='*60}") print(f"📁 RankerCheckpointManager | dir={self.checkpoint_dir}") print(f" Total saves this session : {self._save_count}") print(f" Checkpoints on disk : {len(cps)}") for cp in cps[-5:]: print(f" ├─ step={cp.get('step',0):>8d} {cp.get('size_mb',0):.2f} MB {cp.get('timestamp','')[:19]}") if self._hf.enabled: hf_obj = self._hf._hf # HFDatasetCheckpointManager repo_id = hf_obj.repo_id if hf_obj else "—" uploads = self._hf.upload_count downloads = self._hf.download_count queue_depth = self._hf._queue.qsize() if hasattr(self._hf, "_queue") else 0 last_step = (hf_obj.stats.get("last_upload_step", "—") if hf_obj else "—") print(f" ☁️ HF repo : {repo_id}") print(f" ☁️ HF uploads (session) : {uploads}") print(f" ☁️ HF downloads (session): {downloads}") print(f" ☁️ HF queue depth : {queue_depth} pending") print(f" ☁️ HF last upload step : {last_step}") else: print(f" ☁️ HF sync: disabled (set HF_TOKEN + HF_REPO_ID, or pass --hf-repo)") print(f"{'='*60}\n") # ── Internal helpers ────────────────────────────────────────────────────── def _build_checkpoint(self, bridge, step: int, reason: str) -> dict: """Collect all live bridge state into a flat checkpoint dict.""" ckpt: dict = { "version": "2.1", "step": step, "reason": reason, "timestamp": datetime.now().isoformat(), "num_assets": bridge.axrvi_net.num_assets if bridge.axrvi_net is not None else 0, "asset_symbols": list(bridge.config.asset_symbols), } # ── Model ───────────────────────────────────────────────────────────── if bridge.axrvi_net is not None: ckpt["axrvi_net"] = bridge.axrvi_net.state_dict() # ── Trainer ─────────────────────────────────────────────────────────── if bridge.trainer is not None: tr = bridge.trainer ckpt["optimizer"] = tr.optimizer.state_dict() ckpt["scheduler"] = tr.scheduler.state_dict() ckpt["train_step"] = tr.train_step ckpt["lambda_ce"] = tr.lambda_ce ckpt["lambda_ql"] = tr.lambda_ql ckpt["lambda_rank"] = tr.lambda_rank ckpt["lambda_risk"] = tr.lambda_risk ckpt["lambda_moe"] = tr.lambda_moe ckpt["lambda_gate"] = tr.lambda_gate ckpt["lambda_crps"] = tr.lambda_crps ckpt["lambda_rent"] = tr.lambda_rent ckpt["lambda_align"] = tr.lambda_align ckpt["rank_margin"] = tr.rank_margin ckpt["loss_history"] = list(tr.loss_history) # ── Replay buffer ───────────────────────────────────────────────────── ckpt["replay"] = bridge.replay.save_state() # ── Bandit ─────────────────────────────────────────────────────────── if bridge.bandit is not None: ckpt["bandit"] = bridge.bandit.state_dict() # ── Portfolio risk manager ───────────────────────────────────────────── ckpt["portfolio_risk_mgr"] = bridge.portfolio_risk_mgr.state_dict() # ── Execution gate ──────────────────────────────────────────────────── ckpt["execution_gate"] = bridge.execution_gate.state_dict() # ── Per-asset feature engine rolling history ────────────────────────── fe_states: dict = {} for asset_id, buf in bridge.asset_buffers.items(): if hasattr(buf, "feature_eng") and buf.feature_eng is not None: fe_states[asset_id] = buf.feature_eng.state_dict() ckpt["feature_engines"] = fe_states # ── Runtime counters & episode tracking ─────────────────────────────── ckpt["rank_count"] = bridge.rank_count ckpt["pending_episodes"] = dict(bridge._pending_episodes) ckpt["trade_tick_counts"] = dict(bridge._trade_tick_counts) return ckpt def _restore_checkpoint(self, bridge, ckpt: dict) -> None: """Restore all state from *ckpt* into the active bridge objects.""" # ── Validate critical keys ───────────────────────────────────────────── assert "step" in ckpt, "Checkpoint missing 'step' key" # ── num_assets compatibility guard ──────────────────────────────────── # If the checkpoint was saved with a different asset count (e.g. 12 before # V50/V30_1s were removed), the QCSAM/FABLE weight tensors will have the # wrong shape. We detect this early and skip model + optimizer weights so # the rest of the state (replay, bandit, counters) can still be restored. ckpt_num_assets = ckpt.get("num_assets", -1) current_num_assets = bridge.axrvi_net.num_assets if bridge.axrvi_net is not None else -1 _model_compatible = True if ckpt_num_assets != -1 and ckpt_num_assets != current_num_assets: ckpt_assets = ckpt.get("asset_symbols", "unknown") logger.warning( f"⚠️ [Restore] Asset count mismatch: checkpoint has {ckpt_num_assets} assets " f"({ckpt_assets}), but current model has {current_num_assets} assets " f"({list(bridge.config.asset_symbols)}). " f"Skipping axrvi_net + optimizer weights — model starts fresh. " f"All other state (replay, bandit, counters) will be restored." ) _model_compatible = False # ── Model ───────────────────────────────────────────────────────────── if "axrvi_net" in ckpt and bridge.axrvi_net is not None: if _model_compatible: incompatible = bridge.axrvi_net.load_state_dict(ckpt["axrvi_net"], strict=False) if incompatible.missing_keys: logger.warning(f"[Restore] axrvi_net missing keys: {incompatible.missing_keys}") if incompatible.unexpected_keys: logger.warning(f"[Restore] axrvi_net unexpected keys: {incompatible.unexpected_keys}") logger.info(" ✅ axrvi_net restored") else: logger.info(" ⏭️ axrvi_net skipped (asset count mismatch — fresh weights kept)") # ── Trainer ─────────────────────────────────────────────────────────── if bridge.trainer is not None: tr = bridge.trainer if _model_compatible: if "optimizer" in ckpt: tr.optimizer.load_state_dict(ckpt["optimizer"]) if "scheduler" in ckpt: tr.scheduler.load_state_dict(ckpt["scheduler"]) else: logger.info(" ⏭️ optimizer/scheduler skipped (asset count mismatch — fresh state kept)") tr.train_step = ckpt.get("train_step", ckpt.get("step", tr.train_step)) tr.lambda_ce = ckpt.get("lambda_ce", tr.lambda_ce) tr.lambda_ql = ckpt.get("lambda_ql", tr.lambda_ql) tr.lambda_rank = ckpt.get("lambda_rank", tr.lambda_rank) tr.lambda_risk = ckpt.get("lambda_risk", tr.lambda_risk) tr.lambda_moe = ckpt.get("lambda_moe", tr.lambda_moe) tr.lambda_gate = ckpt.get("lambda_gate", tr.lambda_gate) tr.lambda_crps = ckpt.get("lambda_crps", tr.lambda_crps) tr.lambda_rent = ckpt.get("lambda_rent", tr.lambda_rent) tr.lambda_align = ckpt.get("lambda_align", tr.lambda_align) tr.rank_margin = ckpt.get("rank_margin", tr.rank_margin) if "loss_history" in ckpt: # FIX 4: replace (not extend) so history isn't doubled on restore _maxlen = tr.loss_history.maxlen tr.loss_history = deque(ckpt["loss_history"], maxlen=_maxlen) logger.info(f" ✅ trainer restored | train_step={tr.train_step}") # ── Replay buffer ───────────────────────────────────────────────────── if "replay" in ckpt: bridge.replay.load_state(ckpt["replay"]) logger.info(f" ✅ replay restored | size={len(bridge.replay)}") # ── Bandit ─────────────────────────────────────────────────────────── if "bandit" in ckpt and bridge.bandit is not None: bridge.bandit.load_state_dict(ckpt["bandit"]) logger.info(" ✅ bandit restored") # ── Portfolio risk manager ───────────────────────────────────────────── if "portfolio_risk_mgr" in ckpt: bridge.portfolio_risk_mgr.load_state_dict(ckpt["portfolio_risk_mgr"]) logger.info(" ✅ portfolio_risk_mgr restored") # ── Execution gate ──────────────────────────────────────────────────── if "execution_gate" in ckpt: bridge.execution_gate.load_state_dict(ckpt["execution_gate"]) logger.info(" ✅ execution_gate restored") # ── Per-asset feature engine rolling history ────────────────────────── fe_states = ckpt.get("feature_engines", {}) for asset_id, state in fe_states.items(): buf = bridge.asset_buffers.get(asset_id) if buf is not None and hasattr(buf, "feature_eng") and buf.feature_eng is not None: buf.feature_eng.load_state_dict(state) logger.info(f" ✅ feature_engine[{asset_id}] restored") else: logger.warning(f" ⚠️ feature_engine[{asset_id}] not in bridge.asset_buffers — skipped") # ── Runtime counters (safe to restore) ─────────────────────────────── bridge.rank_count = ckpt.get("rank_count", bridge.rank_count) # FIX 5: Do NOT restore pending_episodes or trade_tick_counts. # These reference Deriv multiplier contracts that are gone after a restart. # Restoring them causes phantom open_count > 0, which blocks the top-4 # floor enforcer from opening any new trades and corrupts RL episode pairing. # rank_count is the only runtime counter safe to restore across restarts. stale_ep = ckpt.get("pending_episodes", {}) stale_tc = ckpt.get("trade_tick_counts", {}) if stale_ep: logger.warning( f" ⚠️ Discarded {len(stale_ep)} stale pending_episodes from checkpoint " "(Deriv contracts expired across restart — resetting to avoid phantom positions)" ) if stale_tc: logger.info( f" ℹ️ Discarded {len(stale_tc)} stale trade_tick_counts (reset on restart)" ) # bridge._pending_episodes and bridge._trade_tick_counts stay at their # freshly-initialised empty-dict state — no update needed. # Post-load validation assert bridge.rank_count >= 0, "rank_count must be non-negative after restore" # ============================================================================ # LOAD STATUS ENUMS # ============================================================================ class LoadStatus: SUCCESS = "success" FALLBACK = "fallback" REINITIALIZED = "reinitialized" SKIPPED = "skipped" NOT_FOUND = "not_found" FAILED = "failed" @dataclass class ComponentLoadResult: """Result of loading a single component.""" name: str status: str = LoadStatus.FAILED success: bool = False source_checkpoint: str = "" source_step: int = 0 keys_loaded: int = 0 load_time_ms: float = 0 error: str = "" fallback_attempts: int = 0 @dataclass class CheckpointLoadResult: """Complete checkpoint load result.""" success: bool = False training_steps: int = 0 episode_number: int = 0 timestamp: str = "" components: Dict[str, ComponentLoadResult] = field(default_factory=dict) errors: List[str] = field(default_factory=list) total_load_time: float = 0 def get_summary(self) -> Dict[str, Any]: succeeded = sum(1 for c in self.components.values() if c.success) fallback = sum(1 for c in self.components.values() if c.status == LoadStatus.FALLBACK) return { "total_components": len(self.components), "succeeded": succeeded, "fallback": fallback, "failed": len(self.components) - succeeded } @dataclass class CheckpointSaveResult: """Checkpoint save result.""" success: bool = False training_steps: int = 0 episode_number: int = 0 local_path: str = "" file_size_mb: float = 0 total_save_time: float = 0 components_saved: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list) # ============================================================================ # COMPONENT REINITIALIZER # ============================================================================ class ComponentReinitializer: """Reinitialize components when checkpoint loading fails.""" @staticmethod def reinit_module(module: nn.Module, name: str) -> bool: """Reinitialize module weights.""" try: def reset_weights(m): if hasattr(m, 'reset_parameters'): m.reset_parameters() elif hasattr(m, 'initialize'): m.initialize() module.apply(reset_weights) print(f" 🔄 {name}: Reinitialized fresh") return True except Exception as e: print(f" ❌ {name}: Reinit failed - {e}") return False # ============================================================================ # QUASAR AXRVI RANKER CHECKPOINT MANAGER # ============================================================================ class QuasarCheckpointManager: """ Native checkpoint manager for QUASAR AXRVI Ranker. Simplified structure (no instruments, no voyages): ./checkpoints/ ├── ranker_index.json # Master index of all checkpoints ├── latest # Symlink/sentinel to latest checkpoint └── step_0000012345.pt # Individual checkpoint files Features: - Automatic latest checkpoint detection on restart - Full state saving (weights, optimizer, replay buffer, normalization) - Async saving support - Automatic cleanup of old checkpoints - Graceful fallback on corrupted checkpoints """ INDEX_FILENAME = "ranker_index.json" # Unique name to avoid conflicts LATEST_SENTINEL = "latest" # File pointing to latest checkpoint CHECKPOINT_PREFIX = "step_" CHECKPOINT_SUFFIX = ".pt" def __init__( self, checkpoint_dir: str = "./checkpoints", max_checkpoints: int = 10, # Keep last N checkpoints autosave_interval_minutes: int = 5, enable_async_save: bool = True, enable_compression: bool = True, verbose: bool = True ): """ Initialize QUASAR checkpoint manager. Args: checkpoint_dir: Directory to store checkpoints max_checkpoints: Maximum number of checkpoints to keep autosave_interval_minutes: Auto-save interval (0 to disable) enable_async_save: Save in background thread enable_compression: Use PyTorch's compressed serialization verbose: Print detailed logs """ self.checkpoint_dir = Path(checkpoint_dir).resolve() self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.max_checkpoints = max_checkpoints self.enable_async_save = enable_async_save self.enable_compression = enable_compression self.verbose = verbose # Async save queue self.save_queue = queue.Queue() if enable_async_save else None self.save_worker_thread = None if enable_async_save: self._start_save_worker() # Autosave self.autosave_interval = autosave_interval_minutes * 60 self.last_save_time = time.time() self.autosave_enabled = False self.autosave_thread = None # Statistics self.save_count = 0 self.load_count = 0 # Load or create index self.index = self._load_or_create_index() # Reinitializer for fallback self.reinitializer = ComponentReinitializer() if self.verbose: self._print_init_banner() def _print_init_banner(self): """Print initialization banner.""" print(f"\n{'='*80}") print(f"⚡ QUASAR AXRVI RANKER CHECKPOINT MANAGER V1.0") print(f"{'='*80}") print(f" Checkpoint Directory: {self.checkpoint_dir}") print(f" Max Checkpoints: {self.max_checkpoints}") print(f" Async Save: {self.enable_async_save}") print(f" Compression: {self.enable_compression}") print(f" Known Checkpoints: {len(self.index.get('checkpoints', []))}") latest = self.get_latest_checkpoint_info() if latest: print(f" Latest Checkpoint: step_{latest['step']:07d} ({latest.get('timestamp', 'unknown')})") else: print(f" Latest Checkpoint: None (fresh start)") print(f"{'='*80}\n") # ======================================================================== # INDEX MANAGEMENT # ======================================================================== def _get_index_path(self) -> Path: """Get path to ranker index file.""" return self.checkpoint_dir / self.INDEX_FILENAME def _get_latest_sentinel_path(self) -> Path: """Get path to latest sentinel file.""" return self.checkpoint_dir / self.LATEST_SENTINEL def _load_or_create_index(self) -> Dict[str, Any]: """Load existing index or create new one.""" index_path = self._get_index_path() if index_path.exists(): try: with open(index_path, 'r') as f: index = json.load(f) # Validate index structure if 'checkpoints' not in index: index['checkpoints'] = [] if 'version' not in index: index['version'] = '1.0' # Clean up missing checkpoints valid_checkpoints = [] for cp in index.get('checkpoints', []): cp_path = self.checkpoint_dir / cp.get('filename', '') if cp_path.exists(): valid_checkpoints.append(cp) else: print(f" ⚠️ Removing missing checkpoint: {cp.get('filename')}") index['checkpoints'] = valid_checkpoints return index except Exception as e: print(f" ⚠️ Failed to load index: {e}, creating new") # Create new index return { "version": "1.0", "created": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(), "checkpoints": [] } def _save_index(self): """Save ranker index to disk.""" self.index["last_updated"] = datetime.now().isoformat() with open(self._get_index_path(), 'w') as f: json.dump(self.index, f, indent=2) def _update_index(self, training_steps: int, episode_number: int, file_size_mb: float, filename: str): """Update index with new checkpoint.""" # Add new checkpoint checkpoint_entry = { "step": training_steps, "episode": episode_number, "filename": filename, "size_mb": round(file_size_mb, 2), "timestamp": datetime.now().isoformat() } # Remove existing entry with same step if present self.index["checkpoints"] = [ cp for cp in self.index.get("checkpoints", []) if cp.get("step") != training_steps ] # Add new entry and sort self.index.setdefault("checkpoints", []).append(checkpoint_entry) self.index["checkpoints"].sort(key=lambda x: x.get("step", 0)) # Update latest self.index["latest_step"] = training_steps self.index["latest_episode"] = episode_number self.index["last_updated"] = datetime.now().isoformat() # Save index self._save_index() # Update latest sentinel self._update_latest_sentinel(training_steps, filename) # Cleanup old checkpoints self._cleanup_old_checkpoints() def _update_latest_sentinel(self, training_steps: int, filename: str): """Update latest sentinel file.""" sentinel_path = self._get_latest_sentinel_path() try: with open(sentinel_path, 'w') as f: f.write(f"{training_steps}\n{filename}") except Exception as e: if self.verbose: print(f" ⚠️ Failed to update sentinel: {e}") def _cleanup_old_checkpoints(self): """Remove old checkpoints beyond max_checkpoints.""" checkpoints = self.index.get("checkpoints", []) if len(checkpoints) <= self.max_checkpoints: return # Keep only the most recent N checkpoints to_remove = checkpoints[:-self.max_checkpoints] for cp in to_remove: filename = cp.get("filename") if filename: cp_path = self.checkpoint_dir / filename if cp_path.exists(): cp_path.unlink() if self.verbose: print(f" 🗑️ Removed old checkpoint: {filename}") # Update index self.index["checkpoints"] = checkpoints[-self.max_checkpoints:] self._save_index() # ======================================================================== # CHECKPOINT INFO # ======================================================================== def get_checkpoint_filename(self, training_steps: int) -> str: """Generate checkpoint filename.""" return f"{self.CHECKPOINT_PREFIX}{training_steps:07d}{self.CHECKPOINT_SUFFIX}" def get_checkpoint_path(self, training_steps: int) -> Path: """Get full checkpoint path.""" return self.checkpoint_dir / self.get_checkpoint_filename(training_steps) def list_checkpoints(self) -> List[Dict[str, Any]]: """List all available checkpoints.""" return self.index.get("checkpoints", []).copy() def get_latest_checkpoint_info(self) -> Optional[Dict[str, Any]]: """Get information about the latest checkpoint.""" checkpoints = self.list_checkpoints() return checkpoints[-1] if checkpoints else None def get_latest_step(self) -> int: """Get the latest training step from index.""" latest = self.get_latest_checkpoint_info() return latest.get("step", 0) if latest else 0 def checkpoint_exists(self, training_steps: int) -> bool: """Check if a checkpoint exists.""" cp_path = self.get_checkpoint_path(training_steps) return cp_path.exists() # ======================================================================== # SAVE CHECKPOINT # ======================================================================== def save_checkpoint( self, ranker, # QUASAR AXRVI Ranker instance training_steps: int, episode_number: int = 0, force_sync: bool = False, **extra_data ) -> CheckpointSaveResult: """ Save complete ranker state. Args: ranker: QUASAR AXRVI Ranker instance with all components training_steps: Current training step episode_number: Current episode number force_sync: Force synchronous save (ignore async setting) **extra_data: Additional data to include in checkpoint Returns: CheckpointSaveResult with save details """ start_time = time.time() result = CheckpointSaveResult( success=False, training_steps=training_steps, episode_number=episode_number ) if self.verbose: print(f"\n{'='*80}") print(f"💾 SAVING CHECKPOINT") print(f" Step: {training_steps}") print(f" Episode: {episode_number}") print(f"{'='*80}") # Build checkpoint dictionary checkpoint = { 'version': '1.0', 'training_steps': training_steps, 'episode_number': episode_number, 'timestamp': datetime.now().isoformat(), 'pytorch_version': torch.__version__, } # ==================================================================== # SAVE MODEL COMPONENTS # ==================================================================== # Primary model: bridge.axrvi_net (NOT ranker.quantum_system) if hasattr(ranker, 'axrvi_net') and ranker.axrvi_net is not None: checkpoint['axrvi_net'] = ranker.axrvi_net.state_dict() result.components_saved.append('axrvi_net') # ==================================================================== # SAVE OPTIMIZER STATES (bridge.trainer.optimizer / .scheduler) # ==================================================================== if hasattr(ranker, 'trainer') and ranker.trainer is not None: tr = ranker.trainer if hasattr(tr, 'optimizer') and tr.optimizer: checkpoint['optimizer'] = tr.optimizer.state_dict() result.components_saved.append('optimizer') if hasattr(tr, 'scheduler') and tr.scheduler: checkpoint['scheduler'] = tr.scheduler.state_dict() result.components_saved.append('scheduler') # ==================================================================== # SAVE TRAINING STATE # ==================================================================== # Replay buffer: bridge.replay (NOT ranker.replay_buffer) if hasattr(ranker, 'replay') and ranker.replay is not None: if hasattr(ranker.replay, 'save_state'): checkpoint['replay'] = ranker.replay.save_state() result.components_saved.append('replay') # Bandit: bridge.bandit (NOT ranker.bandit_preferences) if hasattr(ranker, 'bandit') and ranker.bandit is not None: if hasattr(ranker.bandit, 'state_dict'): checkpoint['bandit'] = ranker.bandit.state_dict() result.components_saved.append('bandit') # Portfolio: bridge.portfolio_risk_mgr (NOT ranker.portfolio_state) if hasattr(ranker, 'portfolio_risk_mgr') and ranker.portfolio_risk_mgr is not None: if hasattr(ranker.portfolio_risk_mgr, 'state_dict'): checkpoint['portfolio_risk_mgr'] = ranker.portfolio_risk_mgr.state_dict() result.components_saved.append('portfolio_risk_mgr') # Execution gate: bridge.execution_gate if hasattr(ranker, 'execution_gate') and ranker.execution_gate is not None: if hasattr(ranker.execution_gate, 'state_dict'): checkpoint['execution_gate'] = ranker.execution_gate.state_dict() result.components_saved.append('execution_gate') # Per-asset feature engine rolling history fe_states: dict = {} if hasattr(ranker, 'asset_buffers'): for asset_id, buf in ranker.asset_buffers.items(): if hasattr(buf, 'feature_eng') and buf.feature_eng is not None: fe_states[asset_id] = buf.feature_eng.state_dict() if fe_states: checkpoint['feature_engines'] = fe_states result.components_saved.append('feature_engines') # Pending episodes: bridge._pending_episodes (NOT ranker.pending_episodes) if hasattr(ranker, '_pending_episodes'): checkpoint['pending_episodes'] = dict(ranker._pending_episodes) result.components_saved.append('pending_episodes') # Tick counts: bridge._trade_tick_counts (NOT ranker.tick_counts) if hasattr(ranker, '_trade_tick_counts'): checkpoint['trade_tick_counts'] = dict(ranker._trade_tick_counts) result.components_saved.append('trade_tick_counts') # Runtime counters if hasattr(ranker, 'rank_count'): checkpoint['rank_count'] = ranker.rank_count result.components_saved.append('rank_count') # Trainer scalar state if hasattr(ranker, 'trainer') and ranker.trainer is not None: tr = ranker.trainer checkpoint['train_step'] = tr.train_step checkpoint['lambda_ce'] = tr.lambda_ce checkpoint['lambda_ql'] = tr.lambda_ql checkpoint['lambda_rank'] = tr.lambda_rank checkpoint['lambda_risk'] = tr.lambda_risk checkpoint['lambda_moe'] = tr.lambda_moe checkpoint['lambda_gate'] = tr.lambda_gate checkpoint['lambda_crps'] = tr.lambda_crps checkpoint['lambda_rent'] = tr.lambda_rent checkpoint['lambda_align'] = tr.lambda_align checkpoint['rank_margin'] = tr.rank_margin checkpoint['loss_history'] = list(tr.loss_history) result.components_saved.append('trainer_scalars') # Extra data checkpoint['extra'] = extra_data # ==================================================================== # WRITE TO DISK # ==================================================================== filename = self.get_checkpoint_filename(training_steps) checkpoint_path = self.checkpoint_dir / filename try: # Save checkpoint if self.enable_compression: torch.save(checkpoint, checkpoint_path, _use_new_zipfile_serialization=True) else: torch.save(checkpoint, checkpoint_path) # Get file size result.file_size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024) result.local_path = str(checkpoint_path) # Update index self._update_index(training_steps, episode_number, result.file_size_mb, filename) result.success = True result.total_save_time = time.time() - start_time self.save_count += 1 self.last_save_time = time.time() if self.verbose: print(f"\n ✅ SAVED: {filename} ({result.file_size_mb:.2f} MB)") print(f" Components: {len(result.components_saved)}") print(f" Time: {result.total_save_time:.2f}s") print(f"{'='*80}\n") except Exception as e: result.success = False result.errors.append(str(e)) if self.verbose: print(f"\n ❌ SAVE FAILED: {e}") print(f"{'='*80}\n") return result # ======================================================================== # LOAD CHECKPOINT # ======================================================================== def load_checkpoint( self, ranker, # QUASAR AXRVI Ranker instance training_steps: Optional[int] = None, device: str = 'cpu', # FIX 7: was 'cuda' — HF Spaces are CPU-only, 'cuda' crashes strict: bool = True, load_optimizer: bool = True, load_replay_buffer: bool = True, load_normalizer: bool = True ) -> CheckpointLoadResult: """ Load ranker state from checkpoint. Args: ranker: QUASAR AXRVI Ranker instance to load into training_steps: Specific step to load (None = latest) device: Device to load tensors to strict: Strict state dict loading load_optimizer: Load optimizer state load_replay_buffer: Load replay buffer load_normalizer: Load normalization state Returns: CheckpointLoadResult with load details """ start_time = time.time() result = CheckpointLoadResult(success=False) # Determine which checkpoint to load if training_steps is None: latest = self.get_latest_checkpoint_info() if latest is None: result.errors.append("No checkpoint found") if self.verbose: print(f"\n ℹ️ No checkpoint found - starting fresh") return result training_steps = latest.get("step") # Check if checkpoint exists checkpoint_path = self.get_checkpoint_path(training_steps) if not checkpoint_path.exists(): result.errors.append(f"Checkpoint not found: step_{training_steps}") if self.verbose: print(f"\n ❌ Checkpoint not found: step_{training_steps}") return result if self.verbose: print(f"\n{'='*80}") print(f"📥 LOADING CHECKPOINT") print(f" Step: {training_steps}") print(f" Path: {checkpoint_path}") print(f"{'='*80}") # Load checkpoint file try: checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) except Exception as e: result.errors.append(f"Failed to load checkpoint: {e}") if self.verbose: print(f"\n ❌ Failed to load: {e}") return result # Extract metadata result.training_steps = checkpoint.get('training_steps', training_steps) result.episode_number = checkpoint.get('episode_number', 0) result.timestamp = checkpoint.get('timestamp', '') # ==================================================================== # LOAD MODEL COMPONENTS # ==================================================================== print(f"\n 📦 LOADING COMPONENTS:") print(f" {'-'*70}") # ── Model: bridge.axrvi_net ──────────────────────────────────────────── if hasattr(ranker, 'axrvi_net') and ranker.axrvi_net is not None: if 'axrvi_net' in checkpoint: try: incompatible = ranker.axrvi_net.load_state_dict( checkpoint['axrvi_net'], strict=strict) result.components['axrvi_net'] = ComponentLoadResult( name='axrvi_net', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ axrvi_net") except Exception as e: result.components['axrvi_net'] = ComponentLoadResult( name='axrvi_net', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ❌ axrvi_net: {str(e)[:60]}") # ==================================================================== # LOAD OPTIMIZER / TRAINER STATE # ==================================================================== if load_optimizer and hasattr(ranker, 'trainer') and ranker.trainer is not None: tr = ranker.trainer # optimizer if hasattr(tr, 'optimizer') and tr.optimizer and 'optimizer' in checkpoint: try: tr.optimizer.load_state_dict(checkpoint['optimizer']) result.components['optimizer'] = ComponentLoadResult( name='optimizer', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ optimizer") except Exception as e: result.components['optimizer'] = ComponentLoadResult( name='optimizer', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ⚠️ optimizer: {str(e)[:50]}") # scheduler if hasattr(tr, 'scheduler') and tr.scheduler and 'scheduler' in checkpoint: try: tr.scheduler.load_state_dict(checkpoint['scheduler']) result.components['scheduler'] = ComponentLoadResult( name='scheduler', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ scheduler") except Exception as e: result.components['scheduler'] = ComponentLoadResult( name='scheduler', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ⚠️ scheduler: {str(e)[:50]}") # scalar trainer state tr.train_step = checkpoint.get("train_step", checkpoint.get("step", tr.train_step)) tr.lambda_ce = checkpoint.get("lambda_ce", tr.lambda_ce) tr.lambda_ql = checkpoint.get("lambda_ql", tr.lambda_ql) tr.lambda_rank = checkpoint.get("lambda_rank", tr.lambda_rank) tr.lambda_risk = checkpoint.get("lambda_risk", tr.lambda_risk) tr.lambda_moe = checkpoint.get("lambda_moe", tr.lambda_moe) tr.lambda_gate = checkpoint.get("lambda_gate", tr.lambda_gate) tr.lambda_crps = checkpoint.get("lambda_crps", tr.lambda_crps) tr.lambda_rent = checkpoint.get("lambda_rent", tr.lambda_rent) tr.lambda_align = checkpoint.get("lambda_align", tr.lambda_align) tr.rank_margin = checkpoint.get("rank_margin", tr.rank_margin) if "loss_history" in checkpoint: # FIX 4b: replace (not extend) to avoid doubling history on restore _maxlen = tr.loss_history.maxlen tr.loss_history = deque(checkpoint["loss_history"], maxlen=_maxlen) print(f" ✅ trainer scalars (train_step={tr.train_step})") # ==================================================================== # LOAD TRAINING STATE # ==================================================================== # Replay buffer: bridge.replay if load_replay_buffer and hasattr(ranker, 'replay') and ranker.replay is not None: if 'replay' in checkpoint: try: ranker.replay.load_state(checkpoint['replay']) result.components['replay'] = ComponentLoadResult( name='replay', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ replay ({len(ranker.replay)} episodes)") except Exception as e: result.components['replay'] = ComponentLoadResult( name='replay', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ⚠️ replay: {str(e)[:50]}") # Bandit: bridge.bandit if hasattr(ranker, 'bandit') and ranker.bandit is not None: if 'bandit' in checkpoint: try: ranker.bandit.load_state_dict(checkpoint['bandit']) result.components['bandit'] = ComponentLoadResult( name='bandit', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ bandit") except Exception as e: result.components['bandit'] = ComponentLoadResult( name='bandit', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ⚠️ bandit: {str(e)[:50]}") # Portfolio: bridge.portfolio_risk_mgr if hasattr(ranker, 'portfolio_risk_mgr') and ranker.portfolio_risk_mgr is not None: if 'portfolio_risk_mgr' in checkpoint: try: ranker.portfolio_risk_mgr.load_state_dict(checkpoint['portfolio_risk_mgr']) result.components['portfolio_risk_mgr'] = ComponentLoadResult( name='portfolio_risk_mgr', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ portfolio_risk_mgr") except Exception as e: result.components['portfolio_risk_mgr'] = ComponentLoadResult( name='portfolio_risk_mgr', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ⚠️ portfolio_risk_mgr: {str(e)[:50]}") # Execution gate: bridge.execution_gate if hasattr(ranker, 'execution_gate') and ranker.execution_gate is not None: if 'execution_gate' in checkpoint: try: ranker.execution_gate.load_state_dict(checkpoint['execution_gate']) result.components['execution_gate'] = ComponentLoadResult( name='execution_gate', status=LoadStatus.SUCCESS, success=True, source_checkpoint=f"step_{training_steps}" ) print(f" ✅ execution_gate") except Exception as e: result.components['execution_gate'] = ComponentLoadResult( name='execution_gate', status=LoadStatus.FAILED, success=False, error=str(e) ) print(f" ⚠️ execution_gate: {str(e)[:50]}") # Per-asset feature engines fe_states = checkpoint.get("feature_engines", {}) if fe_states and hasattr(ranker, 'asset_buffers'): for asset_id, state in fe_states.items(): buf = ranker.asset_buffers.get(asset_id) if buf is not None and hasattr(buf, 'feature_eng') and buf.feature_eng is not None: try: buf.feature_eng.load_state_dict(state) print(f" ✅ feature_engine[{asset_id}]") except Exception as e: print(f" ⚠️ feature_engine[{asset_id}]: {str(e)[:50]}") # FIX 5b: Do NOT restore pending_episodes or trade_tick_counts. # These reference Deriv multiplier contracts that are gone after a restart. # Restoring them causes phantom open_count > 0, blocking the top-4 floor # enforcer and corrupting RL episode pairing with dead contract IDs. stale_ep = checkpoint.get('pending_episodes', {}) stale_tc = checkpoint.get('trade_tick_counts', {}) if stale_ep: print( f" ⚠️ Discarded {len(stale_ep)} stale pending_episodes " "(Deriv contracts expired across restart — phantom positions prevented)" ) if stale_tc: print( f" ℹ️ Discarded {len(stale_tc)} stale trade_tick_counts (reset on restart)" ) # _pending_episodes and _trade_tick_counts stay at fresh empty-dict state. # Runtime counters (safe to restore) if 'rank_count' in checkpoint and hasattr(ranker, 'rank_count'): ranker.rank_count = checkpoint['rank_count'] print(f" ✅ rank_count={ranker.rank_count}") result.success = True result.total_load_time = time.time() - start_time self.load_count += 1 if self.verbose: summary = result.get_summary() print(f"\n {'='*70}") print(f" 📊 LOAD SUMMARY: step_{result.training_steps}") print(f" {'='*70}") print(f" Components: {summary['succeeded']}/{summary['total_components']} success") print(f" Episode: {result.episode_number}") print(f" Time: {result.total_load_time:.2f}s") print(f"{'='*80}\n") return result def load_latest_checkpoint(self, ranker, **kwargs) -> CheckpointLoadResult: """Load the latest available checkpoint.""" return self.load_checkpoint(ranker, training_steps=None, **kwargs) # ======================================================================== # AUTO-SAVE FUNCTIONALITY # ======================================================================== def start_autosave(self, ranker, get_training_steps_fn, get_episode_fn=None): """ Start automatic periodic saving. Args: ranker: Ranker instance to save get_training_steps_fn: Function that returns current training step get_episode_fn: Function that returns current episode (optional) """ if self.autosave_interval <= 0: return self.autosave_enabled = True def autosave_worker(): import time as _time while self.autosave_enabled: _time.sleep(self.autosave_interval) if not self.autosave_enabled: break try: steps = get_training_steps_fn() episode = get_episode_fn() if get_episode_fn else 0 if steps > self.get_latest_step(): self.save_checkpoint(ranker, steps, episode) except Exception as e: if self.verbose: print(f" ⚠️ Autosave failed: {e}") self.autosave_thread = threading.Thread(target=autosave_worker, daemon=True) self.autosave_thread.start() if self.verbose: print(f" ⏰ Autosave started (every {self.autosave_interval//60} minutes)") def stop_autosave(self): """Stop automatic saving.""" self.autosave_enabled = False if self.autosave_thread: self.autosave_thread.join(timeout=2) # ======================================================================== # ASYNC SAVE WORKER # ======================================================================== def _start_save_worker(self): """Start async save worker thread.""" def worker(): while True: try: task = self.save_queue.get(timeout=1) if task is None: break # Process save task (placeholder for future async save) self.save_queue.task_done() except queue.Empty: continue except Exception: pass self.save_worker_thread = threading.Thread(target=worker, daemon=True) self.save_worker_thread.start() def shutdown(self): """Shutdown checkpoint manager.""" self.stop_autosave() if self.save_queue: self.save_queue.put(None) if self.save_worker_thread: self.save_worker_thread.join(timeout=2) # ======================================================================== # UTILITY METHODS # ======================================================================== def print_structure(self): """Print checkpoint directory structure.""" print(f"\n{'='*80}") print(f"📁 QUASAR CHECKPOINT STRUCTURE") print(f"{'='*80}") print(f" Directory: {self.checkpoint_dir}") print(f" Total Checkpoints: {len(self.index.get('checkpoints', []))}") for cp in self.index.get('checkpoints', [])[-5:]: # Show last 5 print(f" ├─ step_{cp['step']:07d}.pt ({cp.get('size_mb', 0):.2f} MB)") print(f"{'='*80}\n") def clear_all_checkpoints(self, confirm: bool = False): """Delete all checkpoints.""" if not confirm: print(" ⚠️ Use confirm=True to delete all checkpoints") return for cp in self.index.get('checkpoints', []): cp_path = self.checkpoint_dir / cp.get('filename', '') if cp_path.exists(): cp_path.unlink() self.index['checkpoints'] = [] self._save_index() # Remove sentinel sentinel = self._get_latest_sentinel_path() if sentinel.exists(): sentinel.unlink() if self.verbose: print(f" 🗑️ Cleared all checkpoints") # ============================================================================ # INTEGRATION HELPER FOR QUASAR AXRVI RANKER # ============================================================================ def integrate_checkpoint_manager(ranker_class): """ Decorator/helper to integrate checkpoint manager into ranker class. Usage: @integrate_checkpoint_manager class QuasarAXRVIRanker: ... Or manually: ranker.checkpoint_manager = QuasarCheckpointManager(...) ranker.load_state = lambda: ranker.checkpoint_manager.load_latest_checkpoint(ranker) ranker.save_state = lambda step, ep: ranker.checkpoint_manager.save_checkpoint(ranker, step, ep) """ def enhanced_init(self, *args, checkpoint_dir="./checkpoints", **kwargs): # Store original init args self._checkpoint_dir = checkpoint_dir # Call original init original_init(self, *args, **kwargs) # Initialize checkpoint manager self.checkpoint_manager = QuasarCheckpointManager( checkpoint_dir=checkpoint_dir, verbose=getattr(self, 'verbose', True) ) # Auto-load latest checkpoint if available result = self.checkpoint_manager.load_latest_checkpoint(self) if result.success: print(f" ✅ Resumed from step {result.training_steps}, episode {result.episode_number}") # Return loaded state for ranker to use self._loaded_step = result.training_steps self._loaded_episode = result.episode_number else: self._loaded_step = 0 self._loaded_episode = 0 print(f" 🆕 Starting fresh training") # Store original init original_init = ranker_class.__init__ ranker_class.__init__ = enhanced_init # Add convenience methods def save_state(self, training_steps, episode_number=0): return self.checkpoint_manager.save_checkpoint(self, training_steps, episode_number) def load_state(self, training_steps=None): return self.checkpoint_manager.load_checkpoint(self, training_steps) def get_latest_step(self): return self.checkpoint_manager.get_latest_step() ranker_class.save_state = save_state ranker_class.load_state = load_state ranker_class.get_latest_step = get_latest_step return ranker_class # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 19 — ENTRY POINT HELPER # ══════════════════════════════════════════════════════════════════════════════════════ async def run_live_trading_system( asset_symbols: Optional[List[str]] = None, bandit_strategy: str = "ucb", reward_strategy: str = "simple", model_path: str = "quasar_axrvi_v6.pt", hub_ws_url: str = "ws://localhost:7860/ws/subscribe", enable_logging: bool = True, shreve_config: Optional[ShreveConfig] = None, checkpoint_dir: str = "./Ranker10", resume: bool = True, # [RESUME FIX] default ON — env var QUASAR_RESUME overrides hf_repo_id: Optional[str] = "KarlQuant/quasar-axrvi-v10", # new HF repo (10 assets) ) -> None: config = AssetRankerConfig( asset_symbols = asset_symbols or list(ASSET_REGISTRY.keys()), bandit_strategy = bandit_strategy, model_path = model_path, max_concurrent = 3, shreve_config = shreve_config or ShreveConfig(), portfolio_risk_config = PortfolioRiskConfig( total_capital = 6.0, kelly_fraction = 0.0, min_notional = 1.0, max_notional = 1.0, max_portfolio_risk = 0.50, drawdown_halt_pct = 0.33, halt_duration_secs = 60.0, drawdown_reduce_pct = 0.17, ), ) bridge = QuasarAXRVIBridge( config = config, trade_config = TradeConfig( amount = 1.0, expiry_time = 60, ), reward_strategy = reward_strategy, hub_ws_url = hub_ws_url, enable_logging = enable_logging, checkpoint_dir = checkpoint_dir, resume = resume, hf_repo_id = hf_repo_id, ) await bridge.run() # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 20 — COMPONENT TESTS (merged + corrected from v3 + v4) # ══════════════════════════════════════════════════════════════════════════════════════ def test_components() -> None: print("\n" + "=" * 72) print("QUASAR AXRVI v6 — SHREVE FRAMEWORK COMPONENT TESTS") print("=" * 72) # ── 1. UnifiedFeatureEngine (QV vol + jump-diffusion) ─────────────────── print("\n✓ Testing UnifiedFeatureEngine (26-dim, QV vol [S3], jump-diffusion [S9])…") sc = ShreveConfig(jump_diffusion_assets=("CRASH1000",)) fe_d = UnifiedFeatureEngine("V75", base_vol=75.0, shreve_config=sc) # diffusion fe_j = UnifiedFeatureEngine("CRASH1000", base_vol=100.0, shreve_config=sc) # jump-diffusion sig_kw = {"action": "BUY", "confidence": 0.8, "avn_accuracy": 0.9, "match_efficacy": 0.85} for price in [1500.0 + i * 2.0 for i in range(30)]: fe_d.on_price(price, 0.03, 0.0) fe_j.on_price(price, 0.05, 0.0) fe_d.on_signal("BUY", 0.8, 0.1, 0.9, 0.85) fe_j.on_signal("BUY", 0.8, 0.1, 0.9, 0.85) feats_d = fe_d.extract(sig_kw) feats_j = fe_j.extract(sig_kw) assert feats_d.shape == (26,), f"Expected (26,), got {feats_d.shape}" assert feats_j.shape == (26,) # QV-based vol [S3]: feature[6] raw must be > 0 assert fe_d.get_raw_feature(6) > 0, "QV vol should be positive" # jump-diffusion flag assert fe_j._is_jump_diffusion, "CRASH1000 should be a jump-diffusion asset" assert not fe_d._is_jump_diffusion, "V75 should not be a jump-diffusion asset" print(f" ✅ UnifiedFeatureEngine OK | shape={feats_d.shape} | qv_vol={fe_d.get_raw_feature(6):.4f}") # ── 2. AssetStateBuffer ────────────────────────────────────────────────── print("\n✓ Testing AssetStateBuffer…") buf = AssetStateBuffer("V75", {"base_vol": 75.0}, shreve_config=sc) snap = AssetSnapshot("V75") snap.apply_update({ "training": {"avn_accuracy": 0.8, "actor_loss": 0.1, "critic_loss": 0.2, "avn_loss": 0.05}, "voting": {"buy_count": 7, "sell_count": 3, "dominant_signal": "BUY"}, }) # v2.3+: per-tick action arrives via /ws/signals → apply_signal, not via # the snapshot's voting dict. Simulate that here. snap.apply_signal(action="BUY", price=1500.0) buf.on_price(1500.0, 0.03, 0.0) buf.on_hub_snapshot(snap) seq = buf.get_sequence() assert seq.shape == (SEQ_LEN, FEATURE_DIM) print(f" ✅ AssetStateBuffer OK | seq shape={seq.shape}") # ── 3. AXRVINet (forward + MC Dropout) ────────────────────────────────── print("\n✓ Testing AXRVINet…") net = AXRVINet(num_assets=3, feature_dim=26, seq_len=SEQ_LEN) test_input = torch.randn(1, 3, SEQ_LEN, 26) with torch.no_grad(): out = net(test_input) assert out["significance_weight"].shape == (1, 3) assert "value" in out and "log_var" in out mc_out = net.forward_with_epistemic_uncertainty(test_input, mc_samples=5) assert mc_out["epistemic_std"].shape == (1, 3) print(f" ✅ AXRVINet OK | sig={out['significance_weight'][0].tolist()}") # ── 4. ShreveRankingEngine [S1, S5, S6] ───────────────────────────────── print("\n✓ Testing ShreveRankingEngine (discounted risk-neutral priority [S6])…") sc_rank = ShreveConfig(risk_free_rate=0.05, horizon_seconds=60.0, ito_curvature_weight=0.0) re = ShreveRankingEngine(shreve_config=sc_rank) snapshots: Dict[str, AssetSnapshot] = {} for name, buy, sell, acc in [ ("V75", 7, 3, 0.60), ("V100_1s", 5, 5, 0.70), ("CRASH500", 9, 1, 0.40), ]: s = AssetSnapshot(space_name=name) s.apply_update({ "training": {"avn_accuracy": acc}, "voting": {"buy_count": buy, "sell_count": sell, "dominant_signal": "BUY" if buy > sell else "NEUTRAL"}, }) # v2.3+: realtime per-tick action lives in latest_action / dominant_signal, # populated by apply_signal (which mirrors the /ws/signals path). s.apply_signal(action="BUY" if buy > sell else "HOLD") snapshots[name] = s sig_map = {"V75": 0.60, "V100_1s": 0.70, "CRASH500": 0.40} ranked = re.rank_risk_neutral(snapshots, significance_weights=sig_map) assert ranked[0].space_name == "V75", f"Top should be V75, got {ranked[0].space_name}" # Verify discount is applied: priority < conf × sig for r in ranked: raw = r.signal_confidence * sig_map.get(r.space_name, r.avn_accuracy) discount = math.exp(-0.05 * 60.0 / (365.25 * 24 * 3600)) expected = discount * raw assert abs(r.final_priority - expected) < 1e-4, ( f"Priority mismatch for {r.space_name}: got {r.final_priority:.6f}, " f"expected ~{expected:.6f}" ) print(" ✅ ShreveRankingEngine OK | " + " > ".join(f"{r.space_name}({r.final_priority:.4f})" for r in ranked)) # ── 5. ConservativeRanker ──────────────────────────────────────────────── print("\n✓ Testing ConservativeRanker…") cr = ConservativeRanker(confidence_level=0.95) priority = cr.compute_conservative_priority(0.8, 0.7, aleatoric_std=0.1, epistemic_std=0.05) assert 0.0 <= priority <= 0.8 print(f" ✅ ConservativeRanker OK | priority={priority:.4f}") # ── 6. DynamicExecutionGate — Gates A–E (incl. martingale [S7]) ───────── print("\n✓ Testing DynamicExecutionGate (Gates A–E incl. martingale null-hypothesis [S7])…") gate = DynamicExecutionGate(martingale_epsilon=0.05) # Normal pass ok, r_ok = gate.should_execute(0.8, 0.6, 1.2, 0.1, 0.05, 0.05, "BUY", martingale_deviation=0.3) # Jump veto (Gate D) bad, r_bad = gate.should_execute(0.8, 0.6, 1.2, 0.9, 0.05, 0.05, "BUY", martingale_deviation=0.3) # Martingale veto (Gate E): DevMart ≤ ε mart, r_mart = gate.should_execute(0.8, 0.6, 1.2, 0.1, 0.05, 0.05, "BUY", martingale_deviation=0.01) assert ok, f"Should pass all gates: {r_ok}" assert not bad, f"Should veto on jump risk: {r_bad}" assert not mart, f"Should veto on martingale null: {r_mart}" assert "Gate E" in r_mart, f"Veto reason should mention Gate E: {r_mart}" print(f" ✅ DynamicExecutionGate OK | pass='{r_ok}' | jump='{r_bad}' | mart='{r_mart}'") # ── 7. GirsanovReplayBuffer ────────────────────────────────────────────── print("\n✓ Testing GirsanovReplayBuffer…") replay = GirsanovReplayBuffer(capacity=100) for r, v, td in [(0.5, 0.1, 0.2), (-0.3, 0.2, 0.5), (0.1, 0.05, 0.1)]: replay.push({ "reward": r, "volatility": v, "td_error": td, "sequences": [[[0.0] * 26] * SEQ_LEN], "next_sequences": [[[0.0] * 26] * SEQ_LEN], # distinct from sequences [S2] "selected_idx": 0, "pnl_per_asset": [r, 0.0, 0.0], }) assert replay.is_ready(3) samples = replay.sample(2) assert len(samples) == 2 and "importance_weight" in samples[0] print(f" ✅ GirsanovReplayBuffer OK | size={len(replay)}") # ── 8. HybridTrainer — 4-loss incl. L_CE [S1] ─────────────────────────── print("\n✓ Testing HybridTrainer (4-loss: rl + ce + rank + risk) [S1]…") model = AXRVINet(num_assets=3, config=AXRVIConfig(feature_dim=26, seq_len=SEQ_LEN)) trainer = HybridTrainer(model=model, lr=1e-3, lambda_ce=0.3) dummy_s = np.zeros((1, 3, SEQ_LEN, 26), dtype=np.float32) dummy_ns = np.ones( (1, 3, SEQ_LEN, 26), dtype=np.float32) * 0.01 # distinct [S2] episodes = [ { "sequences": dummy_s.tolist(), "next_sequences": dummy_ns.tolist(), # genuinely distinct s_{t+1} [S2] "selected_idx": 0, "reward": 0.5 if i % 2 == 0 else -0.1, "pnl_per_asset": [0.5, 0.0, 0.0] if i % 2 == 0 else [-0.1, 0.0, 0.0], } for i in range(8) ] loss_dict = trainer.train_on_batch(episodes) assert "total" in loss_dict and loss_dict["total"] >= 0.0 assert "ce" in loss_dict, "Value-consistency loss 'ce' missing from loss dict [S1]" assert "rl" in loss_dict assert "rank" in loss_dict assert "risk" in loss_dict print(f" ✅ HybridTrainer OK | total={loss_dict['total']:.4f} | " f"rl={loss_dict['rl']:.4f} | ce={loss_dict['ce']:.4f} | " f"rank={loss_dict['rank']:.4f} | risk={loss_dict['risk']:.4f}") # ── 9. UnifiedRewardCalculator — directional log-return [S4] ──────────── print("\n✓ Testing UnifiedRewardCalculator (directional log-return [S4])…") trade_l = Trade( trade_id="T001", asset="V75", direction=TradeDirection.LONG, entry_price=1500.0, entry_time=time.time(), quantity=0.1, ) trade_l.close(1515.0, time.time(), 0.0) # close with 0 fees (added by calc) trade_s = Trade( trade_id="T002", asset="V75", direction=TradeDirection.SHORT, entry_price=1500.0, entry_time=time.time(), quantity=0.1, ) trade_s.close(1485.0, time.time(), 0.0) calc = UnifiedRewardCalculator(strategy="simple", fee_per_trade=0.0, slippage_bps=0.0) r_long = calc.compute_reward(trade_l) r_short = calc.compute_reward(trade_s) # Both should yield positive log-returns: log(1515/1500) and log(1500/1485) expected_long = math.log(1515.0 / 1500.0) expected_short = math.log(1500.0 / 1485.0) assert abs(r_long - expected_long) < 1e-6, f"Long reward mismatch: {r_long:.6f} vs {expected_long:.6f}" assert abs(r_short - expected_short) < 1e-6, f"Short reward mismatch: {r_short:.6f} vs {expected_short:.6f}" # Test fee deduction calc_fees = UnifiedRewardCalculator(strategy="simple", fee_per_trade=0.001, slippage_bps=2.0) r_fees = calc_fees.compute_reward(trade_l) assert r_fees < r_long, "Reward with fees should be less than without" for strat in ("simple", "sharpe", "sortino"): rc = UnifiedRewardCalculator(strategy=strat) rv = rc.compute_reward(trade_l, [1500.0 + i for i in range(20)]) assert isinstance(rv, float) print(f" ✅ UnifiedRewardCalculator OK | long={r_long:+.6f} | short={r_short:+.6f}") # ── 10. Trade + PositionManager ───────────────────────────────────────── print("\n✓ Testing Trade + PositionManager…") mgr = PositionManager() mgr.open_trade("T003", "V75", TradeDirection.SHORT, 1500.0, 0.1) assert len(mgr.get_open_trades()) == 1 mgr.close_trade("T003", 1495.0, 0.15) assert mgr.trades_closed == 1 print(f" ✅ PositionManager OK | realized_pnl={mgr.total_realized_pnl:+.4f}") # ── 11. BanditSelector ────────────────────────────────────────────────── print("\n✓ Testing BanditSelector…") bandit = BanditSelector(["V75", "V100_1s", "CRASH1000"]) selected, _ = bandit.select(np.array([0.7, 0.5, 0.3]), threshold=0.0, max_select=3) assert len(selected) <= 3 print(f" ✅ BanditSelector OK | selected={selected}") # ── 12. PriceStreamer ──────────────────────────────────────────────────── print("\n✓ Testing PriceStreamer…") ps = PriceStreamer("V75") ps.on_tick(1499.5, 1500.5, time.time()) assert ps.get_latest_price() == 1500.0 print(f" ✅ PriceStreamer OK | mid={ps.get_latest_price():.2f}") # ── 13. Optimal-stopping logic [S8] sanity check ──────────────────────── print("\n✓ Testing optimal-stopping G_t vs C_t logic [S8]…") # Simulate: entry 1500, current 1515 LONG → G_t > 0 entry_p = 1500.0 current_p = 1515.0 g_t = math.log(current_p / entry_p) # log-return (no fee for simplicity) c_t = 0.002 # low continuation value assert g_t > c_t + 0.001, "Should trigger stop: G_t > C_t + buffer" # Reverse: G_t < C_t → continue current_p2 = 1501.0 g_t2 = math.log(current_p2 / entry_p) c_t2 = 0.02 # high continuation value assert g_t2 < c_t2, "Should continue: G_t < C_t" print(f" ✅ Optimal stopping OK | G_t={g_t:+.5f} > C_t={c_t:.5f} → stop | " f"G_t={g_t2:+.5f} < C_t={c_t2:.5f} → continue") # ── Optional: RankerLogger ────────────────────────────────────────────── if LOGGING_AVAILABLE: print("\n✓ Testing RankerLogger…") rl = RankerLogger(buffer_size=100, log_dir="./test_logs") rl.signal("V75", "BUY", 0.5, 0.8) entries = rl.get_recent(n=10) assert len(entries) >= 1 print(f" ✅ RankerLogger OK | {len(entries)} log entries") print("\n" + "=" * 72) print("✅ ALL COMPONENT TESTS PASSED (v6.0-shreve)") print("=" * 72) # ══════════════════════════════════════════════════════════════════════════════════════ # SECTION 21 — MAIN # ══════════════════════════════════════════════════════════════════════════════════════ def _parse_args(): """Parse CLI args, stripping any Jupyter kernel args.""" filtered = [a for a in sys.argv[1:] if not a.startswith("-f")] parser = argparse.ArgumentParser(description="QUASAR AXRVI Ranker v7 — Shreve Framework") parser.add_argument("--test", action="store_true", help="Run component tests and exit") parser.add_argument("--assets", nargs="+", default=list(ASSET_REGISTRY.keys()), help="Asset symbols to track and trade") parser.add_argument("--bandit", choices=["ucb", "thompson", "greedy"], default="ucb", help="Bandit selection strategy") parser.add_argument("--reward", choices=["simple", "sharpe", "sortino"], default="simple", help="Reward calculation strategy") parser.add_argument("--model", default="quasar_axrvi_v6.pt", help="Path to save/load model checkpoint") parser.add_argument("--hub", default=os.environ.get( "QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe" ), help="Central hub WebSocket URL (env: QUASAR_HUB_URL)") parser.add_argument("--sync", action="store_true", help="Run in synchronous (threading) mode instead of asyncio") parser.add_argument("--no-logs", action="store_true", help="Disable structured logging") parser.add_argument("--mc-samples", type=int, default=10, help="MC Dropout samples for epistemic uncertainty") parser.add_argument("--risk-free-rate", type=float, default=0.05, help="[S6] Annualised risk-free rate for discounting (default 0.05)") parser.add_argument("--horizon-seconds", type=float, default=60.0, help="[S6/S8] Trade horizon τ in seconds (default 60)") parser.add_argument("--martingale-epsilon", type=float, default=0.05, help="[S7] Gate E martingale deviation threshold (default 0.05)") parser.add_argument("--checkpoint-dir", default="./Ranker10", help="Directory for full-state checkpoints (default ./Ranker10)") # [RESUME FIX] Default is now --resume (load latest checkpoint). # Pass --no-resume to deliberately start fresh. # Env var QUASAR_RESUME=0|false further overrides this in the bridge. parser.add_argument("--no-resume", dest="no_resume", action="store_true", default=False, help="Start fresh, ignoring any existing checkpoint (default: resume).") parser.add_argument("--resume", dest="no_resume", action="store_false", help="Restore from the latest Ranker10 checkpoint (default).") parser.add_argument("--hf-repo", default=None, metavar="OWNER/REPO", help="Hugging Face Dataset repo for checkpoint sync " "(e.g. 'YourName/quasar-ckpts'). " "Overrides the HF_REPO_ID env var. " "Requires HF_TOKEN env var to be set.") return parser.parse_args(filtered) if __name__ == "__main__": args = _parse_args() if args.test: test_components() sys.exit(0) stoch_cfg = StochasticCalculusConfig() uncertainty_cfg = UncertaintyConfig(mc_samples=args.mc_samples) replay_cfg = GirsanovReplayConfig() shreve_cfg = ShreveConfig( risk_free_rate = args.risk_free_rate, horizon_seconds = args.horizon_seconds, martingale_gate_epsilon = args.martingale_epsilon, ) config = AssetRankerConfig( asset_symbols = args.assets, bandit_strategy = args.bandit, model_path = args.model, stochastic_config = stoch_cfg, uncertainty_config = uncertainty_cfg, replay_config = replay_cfg, shreve_config = shreve_cfg, max_concurrent = 3, # always maintain top-3 portfolio_risk_config = PortfolioRiskConfig( total_capital = 6.0, kelly_fraction = 0.0, min_notional = 1.0, max_notional = 1.0, max_portfolio_risk = 0.50, drawdown_halt_pct = 0.33, halt_duration_secs = 60.0, drawdown_reduce_pct= 0.17, ), ) bridge = QuasarAXRVIBridge( config = config, trade_config = TradeConfig( amount = 1.0, # $1 flat stake expiry_time = 60, # 1-minute max; optimal-stopping exits earlier ), reward_strategy = args.reward, hub_ws_url = args.hub, enable_logging = not args.no_logs, checkpoint_dir = args.checkpoint_dir, resume = not args.no_resume, # [RESUME FIX] default True — env var QUASAR_RESUME overrides hf_repo_id = args.hf_repo or "KarlQuant/quasar-axrvi-v10", ) try: if args.sync: bridge.start_sync() while True: time.sleep(1) else: try: import nest_asyncio nest_asyncio.apply() except ImportError: pass asyncio.run(run_live_trading_system( asset_symbols = args.assets, bandit_strategy = args.bandit, reward_strategy = args.reward, model_path = args.model, hub_ws_url = args.hub, enable_logging = not args.no_logs, checkpoint_dir = args.checkpoint_dir, # FIX 1: was silently ignored resume = not args.no_resume, # FIX 2: default False (always fresh start) hf_repo_id = args.hf_repo or "KarlQuant/quasar-axrvi-v10", )) except KeyboardInterrupt: print("\n👋 Shutting down…") if args.sync: bridge.stop_sync() except Exception as e: print(f"\n❌ Fatal error: {e}") traceback.print_exc() sys.exit(1)