Spaces:
Sleeping
Sleeping
| """Orchestrator registry and management.""" | |
| import asyncio | |
| import os | |
| import re | |
| import shutil | |
| import subprocess | |
| from typing import Any | |
| from pathlib import Path | |
| from .config import OrchestratorConfig | |
class OrchestratorRegistry:
    """Registry for managing available orchestrators/CLIs.

    Holds one ``OrchestratorConfig`` per orchestrator name and runs tasks
    through the configured command as an asyncio subprocess.
    """

    # Parent-process environment variables forwarded to orchestrator
    # subprocesses; everything else is dropped (allowlist approach).
    _ALLOWED_ENV_VARS = (
        "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM",
        "PYTHONPATH", "NODE_PATH", "OPENROUTER_API_KEY",
        "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY",
        "TMPDIR", "TEMP", "TMP", "USERPROFILE", "SYSTEMROOT",
    )

    # Accepted names for config-supplied environment variables:
    # upper-case letters, digits and underscores, not starting with a digit.
    _ENV_NAME_RE = re.compile(r"^[A-Z_][A-Z0-9_]*$")

    def __init__(self):
        # name -> configuration, populated via register().
        self.orchestrators: dict[str, OrchestratorConfig] = {}
        # NOTE(review): not read or written anywhere in this class; kept for
        # compatibility with possible external users — confirm before removing.
        self._active_sessions: dict[str, asyncio.subprocess.Process] = {}

    def register(self, config: OrchestratorConfig) -> None:
        """Register an orchestrator under ``config.name``, replacing any existing entry."""
        self.orchestrators[config.name] = config

    def unregister(self, name: str) -> None:
        """Unregister an orchestrator; unknown names are ignored."""
        self.orchestrators.pop(name, None)

    def get(self, name: str) -> OrchestratorConfig | None:
        """Return the configuration for *name*, or None if it is not registered."""
        return self.orchestrators.get(name)

    def list_enabled(self) -> list[str]:
        """List the names of all orchestrators whose config is enabled."""
        return [name for name, config in self.orchestrators.items() if config.enabled]

    @staticmethod
    def _resolve_command(cmd: list[str]) -> list[str]:
        """
        Resolve command to full path on Windows.

        On Windows, asyncio.create_subprocess_exec() doesn't reliably search PATH,
        so the executable (first element) is resolved with shutil.which().

        Args:
            cmd: Command list (executable followed by its arguments).

        Returns:
            Resolved command: on Windows, the executable replaced by its full
            path when shutil.which() finds it; otherwise *cmd* unchanged
            (always unchanged on non-Windows systems or for an empty list).
        """
        if os.name != "nt" or not cmd:
            # On Unix systems, PATH search works fine
            return cmd
        resolved = shutil.which(cmd[0])
        if resolved:
            return [resolved] + cmd[1:]
        return cmd

    @staticmethod
    async def _kill_process(process: asyncio.subprocess.Process) -> None:
        """Kill *process* and wait for it to exit; do nothing if it is already gone."""
        try:
            process.kill()
        except ProcessLookupError:
            return
        await process.wait()

    async def execute(
        self,
        orchestrator_name: str,
        task: str,
        timeout: int | None = None,
        progress_callback: Any = None,
    ) -> tuple[str, str, int]:
        """
        Execute a task using the specified orchestrator.

        The command line is ``config.command`` (a string or an argv-prefix
        list), followed by ``config.args``, with *task* as the final argument.
        The child runs with a reduced environment: allowlisted variables from
        ``os.environ`` plus the validly-named entries of ``config.env``.

        Args:
            orchestrator_name: Name of orchestrator to use.
            task: Task description/query, passed as the last argument.
            timeout: Optional timeout in seconds; a falsy value falls back to
                ``config.timeout``.
            progress_callback: Optional callable invoked with each stripped
                stdout line. If it returns an awaitable (e.g. it is an
                ``async def`` function) the result is awaited. Exceptions it
                raises are logged at debug level and otherwise ignored.

        Returns:
            tuple: (stdout, stderr, return_code); return_code is 0 when the
            process reports none.

        Raises:
            ValueError: If the orchestrator is unknown or disabled.
            TimeoutError: If the process and its output streams do not finish
                within the timeout.
            RuntimeError: For any other failure while spawning or running the
                process (the original exception is chained as ``__cause__``).

        On every exit path (success, error, timeout, or cancellation of the
        calling task) a still-running child process is killed and reaped.
        """
        import inspect
        import logging

        logger = logging.getLogger(__name__)

        config = self.get(orchestrator_name)
        if not config:
            raise ValueError(f"Orchestrator '{orchestrator_name}' not found")
        if not config.enabled:
            raise ValueError(f"Orchestrator '{orchestrator_name}' is disabled")

        # Build command: executable (or argv prefix) + configured args + task.
        if isinstance(config.command, list):
            cmd = config.command + config.args + [task]
        else:
            cmd = [config.command] + config.args + [task]
        resolved_cmd = self._resolve_command(cmd)

        timeout_seconds = timeout or config.timeout

        # Allowlisted parent environment, then validated config overrides.
        safe_env = {
            key: os.environ[key]
            for key in self._ALLOWED_ENV_VARS
            if key in os.environ
        }
        for key, value in config.env.items():
            if not self._ENV_NAME_RE.match(key):
                logger.warning("Skipping invalid environment variable name: %s", key)
                continue
            safe_env[key] = value

        stdout_chunks: list[str] = []
        stderr_chunks: list[str] = []

        async def read_stream(stream, chunks, callback=None):
            """Drain *stream* line by line into *chunks*, reporting lines to *callback*."""
            # NOTE(review): StreamReader.readline() raises ValueError for a line
            # longer than the stream limit (64 KiB by default); that surfaces
            # from execute() as a RuntimeError.
            while True:
                line = await stream.readline()
                if not line:
                    break
                decoded_line = line.decode("utf-8", errors="replace")
                chunks.append(decoded_line)
                if not callback:
                    continue
                try:
                    result = callback(decoded_line.strip())
                    if inspect.isawaitable(result):
                        await result
                except Exception:
                    # Progress reporting is best-effort: a failing callback must
                    # not abort the run, but should not vanish without a trace.
                    logger.debug(
                        "progress_callback raised for orchestrator '%s'; ignoring",
                        orchestrator_name,
                        exc_info=True,
                    )

        process = None
        readers: list[asyncio.Task] = []
        try:
            process = await asyncio.create_subprocess_exec(
                *resolved_cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=safe_env,
            )
            readers = [
                asyncio.create_task(
                    read_stream(process.stdout, stdout_chunks, progress_callback)
                ),
                asyncio.create_task(read_stream(process.stderr, stderr_chunks)),
            ]
            try:
                # Wait for the process AND both stream readers, so output still
                # buffered when the process exits is not lost, and a hung
                # process or blocked stream still hits the timeout.
                await asyncio.wait_for(
                    asyncio.gather(process.wait(), *readers),
                    timeout=timeout_seconds,
                )
            except asyncio.TimeoutError:
                raise TimeoutError(
                    f"Orchestrator '{orchestrator_name}' timed out after {timeout_seconds}s"
                )
            return (
                "".join(stdout_chunks),
                "".join(stderr_chunks),
                process.returncode or 0,
            )
        except (TimeoutError, RuntimeError):
            raise
        except Exception as e:
            raise RuntimeError(
                f"Orchestrator '{orchestrator_name}' failed: {str(e)}"
            ) from e
        finally:
            # `finally` (not `except Exception`) so that cancellation of the
            # calling task, a BaseException, also cleans up the child.
            if process is not None and process.returncode is None:
                await self._kill_process(process)
            for reader in readers:
                reader.cancel()  # no-op for readers that already finished
            if readers:
                await asyncio.gather(*readers, return_exceptions=True)

    def validate_all(self) -> dict[str, bool]:
        """
        Validate all registered orchestrators are available.

        An orchestrator counts as available when its executable
        (``config.command`` if it is a string, otherwise its first element)
        is found by ``shutil.which``, the same lookup ``_resolve_command``
        uses. An empty command list is reported as unavailable. Disabled
        orchestrators are checked too.

        Returns:
            dict: {orchestrator_name: is_available}
        """
        results: dict[str, bool] = {}
        for name, config in self.orchestrators.items():
            command = config.command
            if isinstance(command, str):
                executable = command
            else:
                executable = command[0] if command else ""
            results[name] = bool(executable) and shutil.which(executable) is not None
        return results