| """ |
| Direct observable regression: f_θ(R, t) → {⟨Γ_μ⟩(R, t)}. |
| |
| Bypasses the shadow/Q-conditioning pipeline entirely. Predicts signal |
| matrix entries directly using learnable Fourier features for time encoding. |
| |
| Two variants: |
| - Shared frequencies (v2/v3): ω_k are global learnable parameters |
| - Geometry-conditioned (v4+): ω_k(R) = ω_k^{(0)} + g_φ(R)_k where g_φ |
| is a small MLP. Energy gaps depend on R, so the optimal Fourier basis |
| should too. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import asdict, dataclass |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class ObservableRegressorConfig:
    """Hyper-parameters for ``ObservableRegressor``.

    Field order matters: it defines the positional constructor signature and
    the key order of :meth:`to_dict`.
    """

    # --- output and trunk / amplitude MLP ---
    n_observables: int = 120  # number of predicted observables (output width)
    d_hidden: int = 256  # hidden width of the trunk / amplitude MLPs
    n_layers: int = 3  # Linear->GELU blocks before the output Linear

    # --- Fourier time features ---
    n_fourier: int = 64  # number of frequencies ω_k
    fourier_scale: float = 10.0  # upper end of the log-spaced initial ω_k (lower end 0.05)

    # --- frequency conditioning ---
    conditioned_frequencies: bool = False  # add freq_net(c) to the base frequencies
    freq_net_hidden: int = 64  # hidden width of freq_net (also used by dc_net)
    freq_net_layers: int = 2  # Linear layers in freq_net; values < 2 still yield 2
    n_orb_features: int = 0  # width of orb_energies input; 0 means condition on R only

    # --- adaptive bandwidth: ω = ω_op_eff * sigmoid(freq_net(c)) ---
    adaptive_bandwidth: bool = False  # requires conditioned_frequencies=True
    omega_op_floor: float = 0.0  # lower bound on ω_op; 0 disables the floor
    soft_omega_floor: bool = False  # softplus floor instead of a hard clamp
    soft_omega_beta: float = 10.0  # softplus sharpness for the soft floor
    standardize_orb_energies: bool = False  # apply stored (mean, std) to orb_energies

    # --- explicit Fourier-amplitude head ---
    explicit_amplitude: bool = False  # y = cos(ωt)·A(c) + sin(ωt)·B(c) + dc(c)
    amp_rank: int = 16  # low-rank factors used when 0 < rank < min(n_fourier, n_observables)
    with_residual: bool = False  # add a zero-initialised MLP residual on [R, sin, cos]

    def to_dict(self) -> dict:
        """Return the configuration as a plain ``dict`` keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping
|
|
|
|
class ObservableRegressor(nn.Module):
    """Direct regression: (R, t) -> n_observables observable expectations.

    Time is encoded with Fourier features sin(ω_k t), cos(ω_k t) for
    k = 1..n_fourier. How ω is obtained depends on the config:

    - Shared (default): ω_k = omega_base_k, global learnable parameters
      initialised log-uniformly on [0.05, fourier_scale].
    - Conditioned (``conditioned_frequencies``):
      ω_k = omega_base_k + freq_net(c)_k.
    - Adaptive bandwidth (``adaptive_bandwidth``, requires conditioned):
      ω_k = ω_op_eff * sigmoid(freq_net(c))_k, where ω_op_eff is the
      caller-supplied ceiling ω_op, optionally floored at ``omega_op_floor``
      (hard clamp, or softplus when ``soft_omega_floor``).

    The conditioning input c is ``orb_energies`` (optionally standardised)
    when ``n_orb_features > 0`` and it is supplied, otherwise R.

    Output head:

    - MLP (default): GELU MLP on [R, sin(ωt), cos(ωt)] ∈ R^{1 + 2*n_fourier}.
    - Explicit amplitude (``explicit_amplitude``):
      y = cos(ωt)·A(c) + sin(ωt)·B(c) + dc(c), where A and B have shape
      (n_fourier, n_observables) and are predicted by ``amp_net``. They are
      factorised as U·V when 0 < amp_rank < min(n_fourier, n_observables).
      An optional zero-initialised MLP residual (``with_residual``) can be
      added on top.
    """

    def __init__(self, config: ObservableRegressorConfig):
        super().__init__()
        self.config = config
        K = config.n_fourier
        n_obs = config.n_observables

        # Base frequencies, log-spaced on [0.05, fourier_scale].
        log_omega = torch.linspace(np.log(0.05), np.log(config.fourier_scale), K)
        self.omega_base = nn.Parameter(log_omega.exp())

        if config.adaptive_bandwidth and not config.conditioned_frequencies:
            raise ValueError("adaptive_bandwidth requires conditioned_frequencies=True")

        # Width of the conditioning input c (orbital energies, or R alone).
        cond_in = config.n_orb_features if config.n_orb_features > 0 else 1

        # NOTE: sub-networks are built in a fixed order: freq_net, then
        # amp_net + dc_net or net, then the optional residual net. Layer
        # construction and the xavier pass draw from the global RNG, so this
        # order determines the initial weights for a given seed.
        if config.conditioned_frequencies:
            # freq_net_layers counts Linear layers (input + hidden + output);
            # values below 2 still produce the minimal two-Linear network.
            self.freq_net = self._mlp(
                cond_in, config.freq_net_hidden, config.freq_net_layers - 2, K
            )
        else:
            self.freq_net = None

        # Standardisation stats for orbital energies. They are the identity
        # until set_orb_normalization() is called.
        n_orb = max(config.n_orb_features, 1)
        self.register_buffer("orb_mean", torch.zeros(n_orb))
        self.register_buffer("orb_std", torch.ones(n_orb))

        if config.explicit_amplitude:
            r = config.amp_rank
            if self._uses_low_rank(config):
                # Factors U_a (K x r), V_a (r x n_obs), U_b (K x r), V_b (r x n_obs).
                amp_out = 2 * (K * r + r * n_obs)
            else:
                # Dense amplitude matrices A and B, each K x n_obs.
                amp_out = 2 * K * n_obs
            self.amp_net = self._mlp(cond_in, config.d_hidden, config.n_layers - 1, amp_out)
            # Time-independent offset per observable.
            self.dc_net = self._mlp(cond_in, config.freq_net_hidden, 0, n_obs)
            self.net = None
        else:
            self.net = self._mlp(1 + 2 * K, config.d_hidden, config.n_layers - 1, n_obs)
            self.amp_net = None
            self.dc_net = None

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        if config.explicit_amplitude and config.with_residual:
            # Residual MLP trunk on [R, sin(ωt), cos(ωt)]. Its output layer is
            # zeroed, so at init the model equals the explicit-amplitude head.
            self.net = self._mlp(1 + 2 * K, config.d_hidden, config.n_layers - 1, n_obs)
            for p in self.net.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
            nn.init.zeros_(self.net[-1].weight)
            nn.init.zeros_(self.net[-1].bias)

        if self.freq_net is not None:
            head = self.freq_net[-1]
            # Zero weights: at init ω does not depend on the conditioning input.
            nn.init.zeros_(head.weight)
            if config.adaptive_bandwidth:
                # Bias = logit of sigma values spread over [0.05, 0.95], so
                # the initial ω_k range from 0.05 to 0.95 times ω_op_eff.
                init_sigma = torch.linspace(0.05, 0.95, K)
                with torch.no_grad():
                    head.bias.copy_(torch.log(init_sigma / (1.0 - init_sigma)))
            else:
                # Zero offset: at init ω = omega_base.
                nn.init.zeros_(head.bias)

    @staticmethod
    def _mlp(d_in: int, d_hidden: int, n_extra_hidden: int, d_out: int) -> nn.Sequential:
        """Build Linear->GELU, ``n_extra_hidden`` x (Linear->GELU), then Linear.

        A non-positive ``n_extra_hidden`` adds no extra hidden blocks.
        """
        layers = [nn.Linear(d_in, d_hidden), nn.GELU()]
        for _ in range(n_extra_hidden):
            layers += [nn.Linear(d_hidden, d_hidden), nn.GELU()]
        layers.append(nn.Linear(d_hidden, d_out))
        return nn.Sequential(*layers)

    @staticmethod
    def _uses_low_rank(config) -> bool:
        """Return whether the explicit-amplitude head uses rank-``amp_rank`` factors."""
        return 0 < config.amp_rank < min(config.n_fourier, config.n_observables)

    @torch.no_grad()
    def set_orb_normalization(self, mean, std):
        """Assign orbital-energy standardisation stats.

        Called by the trainer after computing per-feature mean/std over the
        training R-set.

        Args:
            mean: per-feature means. Accepts a tensor or array-like that
                ``copy_`` can broadcast to (n_orb_features,).
            std: per-feature standard deviations, clamped below at 1e-8 to
                avoid division by zero.

        Runs under ``no_grad`` so that stats computed from grad-tracking
        tensors do not attach an autograd graph to the persistent buffers.
        """
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        self.orb_mean.copy_(mean.to(self.orb_mean.device, self.orb_mean.dtype))
        self.orb_std.copy_(
            std.clamp_min(1e-8).to(self.orb_std.device, self.orb_std.dtype)
        )

    def _conditioning_input(
        self, R: torch.Tensor, orb_energies: torch.Tensor | None
    ) -> torch.Tensor:
        """Return the input c for freq_net / amp_net / dc_net.

        NOTE(review): when n_orb_features > 0 but orb_energies is None, R is
        used instead. That only matches the networks' input width when
        n_orb_features == 1.
        """
        if self.config.n_orb_features > 0 and orb_energies is not None:
            return orb_energies
        return R

    def _effective_omega_op(self, omega_op, like: torch.Tensor) -> torch.Tensor:
        """Return the floored ω_op as a column tensor.

        The result has shape (B, 1), or (1, 1) for a scalar, and matches the
        dtype/device of ``like``.
        """
        omega_op = torch.as_tensor(omega_op, dtype=like.dtype, device=like.device)
        omega_op = omega_op.reshape(-1, 1)  # accepts scalar, (B,) or (B, 1)
        floor = self.config.omega_op_floor
        if floor <= 0.0:
            return omega_op
        if self.config.soft_omega_floor:
            # Smooth max(floor, ω_op): tends to floor far below, ω_op far above.
            return floor + F.softplus(omega_op - floor, beta=self.config.soft_omega_beta)
        return torch.clamp(omega_op, min=floor)

    def _frequencies(self, cond: torch.Tensor, omega_op) -> torch.Tensor:
        """Return ω with shape (1, n_fourier) or (B, n_fourier) per the config mode."""
        if self.config.adaptive_bandwidth:
            if omega_op is None:
                raise ValueError("adaptive_bandwidth requires omega_op input")
            sigma = torch.sigmoid(self.freq_net(cond))
            return self._effective_omega_op(omega_op, sigma) * sigma
        if self.freq_net is not None:
            return self.omega_base[None, :] + self.freq_net(cond)
        return self.omega_base[None, :]

    def _explicit_amplitude_head(
        self, R: torch.Tensor, wt: torch.Tensor, cond: torch.Tensor
    ) -> torch.Tensor:
        """Compute y = cos(ωt)·A(c) + sin(ωt)·B(c) + dc(c), plus the optional residual."""
        cfg = self.config
        B = R.shape[0]
        K = cfg.n_fourier
        n_obs = cfg.n_observables
        cos_t = torch.cos(wt)
        sin_t = torch.sin(wt)
        amp = self.amp_net(cond)
        if self._uses_low_rank(cfg):
            r = cfg.amp_rank
            a_U, a_V, b_U, b_V = torch.split(
                amp, [K * r, r * n_obs, K * r, r * n_obs], dim=-1
            )
            a_U = a_U.view(B, K, r)
            a_V = a_V.view(B, r, n_obs)
            b_U = b_U.view(B, K, r)
            b_V = b_V.view(B, r, n_obs)
            # Contract the Fourier axis first, so the dense K x n_obs
            # amplitude matrices are never materialised.
            tmp_a = torch.einsum("bk,bkr->br", cos_t, a_U)
            tmp_b = torch.einsum("bk,bkr->br", sin_t, b_U)
            y = torch.einsum("br,brn->bn", tmp_a, a_V) + torch.einsum(
                "br,brn->bn", tmp_b, b_V
            )
        else:
            a, b = amp.view(B, 2, K, n_obs).unbind(dim=1)
            y = torch.einsum("bk,bkn->bn", cos_t, a) + torch.einsum(
                "bk,bkn->bn", sin_t, b
            )
        y = y + self.dc_net(cond)
        if cfg.with_residual:
            # Same [R, sin, cos] layout as the plain MLP head.
            y = y + self.net(torch.cat([R, sin_t, cos_t], dim=-1))
        return y

    def forward(
        self,
        rt: torch.Tensor,
        orb_energies: torch.Tensor | None = None,
        omega_op: torch.Tensor | float | None = None,
    ) -> torch.Tensor:
        """Predict observable expectations.

        Args:
            rt: (B, 2) tensor with [R, t] per sample.
            orb_energies: (B, n_orb) tensor of HF orbital energies, or None.
                Standardised with the stored stats when
                ``config.standardize_orb_energies`` is set.
            omega_op: operational frequency ceiling ω_op(R). Required when
                ``config.adaptive_bandwidth`` is True. Accepts a (B,) or
                (B, 1) tensor, or a scalar broadcast over the batch.

        Returns:
            (B, n_observables) predicted observable expectations.

        Raises:
            ValueError: if adaptive_bandwidth is enabled and omega_op is None.
        """
        cfg = self.config
        R = rt[:, 0:1]
        t = rt[:, 1:2]

        if (
            cfg.standardize_orb_energies
            and cfg.n_orb_features > 0
            and orb_energies is not None
        ):
            orb_energies = (orb_energies - self.orb_mean) / self.orb_std

        cond = self._conditioning_input(R, orb_energies)
        omega = self._frequencies(cond, omega_op)
        # t is (B, 1), so wt broadcasts to (B, n_fourier): R, t and wt share batch B.
        wt = omega * t

        if cfg.explicit_amplitude:
            return self._explicit_amplitude_head(R, wt, cond)

        x = torch.cat([R, torch.sin(wt), torch.cos(wt)], dim=-1)
        return self.net(x)
|
|
|
|
def init_observable_regressor(
    n_observables: int = 120,
    d_hidden: int = 256,
    n_layers: int = 3,
    n_fourier: int = 64,
    fourier_scale: float = 10.0,
    conditioned_frequencies: bool = False,
    freq_net_hidden: int = 64,
    freq_net_layers: int = 2,
    n_orb_features: int = 0,
    adaptive_bandwidth: bool = False,
    omega_op_floor: float = 0.0,
    soft_omega_floor: bool = False,
    soft_omega_beta: float = 10.0,
    standardize_orb_energies: bool = False,
    explicit_amplitude: bool = False,
    amp_rank: int = 16,
    with_residual: bool = False,
):
    """Construct a freshly initialised ``ObservableRegressor``.

    Each keyword maps one-to-one onto the ``ObservableRegressorConfig``
    field of the same name. The arguments are packed into a config, which
    is passed to the model constructor.
    """
    # This must stay the first statement: at this point the only locals are
    # the parameters above, so they form exactly the config's keyword set.
    config_kwargs = dict(locals())
    config = ObservableRegressorConfig(**config_kwargs)
    return ObservableRegressor(config)
|
|