Spaces:
Paused
Paused
| """ | |
| DigiPal - Advanced AI Monster Companion with 3D Generation | |
| Unified application with all features enabled by default | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional, List | |
| from datetime import datetime | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| import gradio as gr | |
| import torch | |
| import spaces | |
# Add src to path so modules under ./src can be imported.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

# Environment configuration - All features enabled by default.
# Values are read once at import time; the integer settings raise ValueError
# if the corresponding environment variable is not a valid integer.
ENV_CONFIG = {
    "LOG_LEVEL": os.getenv("LOG_LEVEL", "INFO"),
    "SERVER_NAME": os.getenv("SERVER_NAME", "0.0.0.0"),
    "SERVER_PORT": int(os.getenv("SERVER_PORT", "7860")),
    "API_PORT": int(os.getenv("API_PORT", "7861")),
    "SHARE": os.getenv("SHARE", "false").lower() == "true",
    "DEBUG": os.getenv("DEBUG", "false").lower() == "true",
    "MAX_THREADS": int(os.getenv("MAX_THREADS", "40")),
    "MCP_ENDPOINT": os.getenv("MCP_ENDPOINT", ""),
    "MCP_API_KEY": os.getenv("MCP_API_KEY", "")
}


def _resolve_log_level(name: str) -> int:
    """Map a level name such as "debug" to a logging level, defaulting to INFO."""
    level = logging.getLevelName(name.strip().upper())
    # getLevelName returns a "Level X" string for unknown names.
    return level if isinstance(level, int) else logging.INFO


# Configure logging. LOG_LEVEL used to be read into ENV_CONFIG but never
# applied (the level was hard-coded to INFO); honour it here.
logging.basicConfig(
    level=_resolve_log_level(ENV_CONFIG["LOG_LEVEL"]),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# HuggingFace Spaces detection: Spaces sets SPACE_ID in the environment.
IS_SPACES = os.getenv("SPACE_ID") is not None
# API Models: request bodies accepted by the REST handlers.
# (Plain comments rather than docstrings, so the generated OpenAPI schema
# stays exactly as before.)


# Body for creating a monster.
class CreateMonsterRequest(BaseModel):
    name: str
    # Upper-cased by create_monster and looked up by PersonalityType member name.
    personality: str


# Body for a care action ("feed", "train", "play", ...).
class MonsterActionRequest(BaseModel):
    action: str
    # Action-specific options, e.g. "food_type" for feed or "training_type"
    # for train. pydantic copies this default per instance.
    params: Dict[str, Any] = {}


# Body carrying one chat message from the user.
class MonsterTalkRequest(BaseModel):
    message: str


# Body for 3D generation; when description is empty a default one is built.
class Generate3DRequest(BaseModel):
    description: Optional[str] = None
| # Import core modules after environment setup | |
| try: | |
| from src.core.monster_engine_dw1 import Monster, PersonalityType | |
| from src.core.evolution_system import EvolutionSystem | |
| from src.ai.qwen_processor import QwenProcessor | |
| from src.ai.speech_engine import SpeechEngine | |
| from src.ui.state_manager import StateManager | |
| from src.deployment.zero_gpu_optimizer import get_optimal_device | |
| from src.pipelines.text_to_3d_pipeline import Text3DPipeline | |
| from src.pipelines.hunyuan3d_pipeline import Hunyuan3DClient | |
| from src.pipelines.opensource_3d_pipeline_v2 import ( | |
| OpenSourcePipelineV2, | |
| PipelineConfig, | |
| ModelProvider | |
| ) | |
| # UI imports | |
| from src.ui.gradio_interface_v2 import create_interface | |
| except ImportError as e: | |
| logger.error(f"Failed to import required modules: {e}") | |
| sys.exit(1) | |
# Initialize FastAPI app
app = FastAPI(title="DigiPal API", version="1.0.0")


def _cors_origins() -> List[str]:
    """Return allowed CORS origins from CORS_ALLOW_ORIGINS (comma-separated).

    Defaults to ``["*"]`` so behaviour is unchanged when the variable is unset.
    """
    raw = os.getenv("CORS_ALLOW_ORIGINS", "*")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


# Add CORS middleware for frontend communication.
app.add_middleware(
    CORSMiddleware,
    # Set CORS_ALLOW_ORIGINS to specific origins in production.
    allow_origins=_cors_origins(),
    # NOTE: with allow_origins=["*"] and credentials enabled, Starlette echoes
    # the request's Origin back for requests carrying cookies, i.e. any site
    # can make credentialed calls. Restrict the origins before relying on it.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global state management
class AppState:
    """Process-wide container for loaded monsters and AI components.

    The AI processors and the 3D pipeline start as ``None`` and are created
    by :meth:`initialize`.
    """

    def __init__(self):
        # Monsters currently held in memory, keyed by monster id.
        self.monsters: Dict[str, Monster] = {}
        self.state_manager = StateManager()
        self.qwen_processor = None
        self.speech_engine = None
        self.evolution_system = EvolutionSystem()
        self.text3d_pipeline = None
        # NOTE(review): not used by the visible code; ConnectionManager keeps
        # its own socket map. Kept for compatibility.
        self.active_connections: Dict[str, WebSocket] = {}
        self.initialized = False

    @staticmethod
    def _build_pipeline_config():
        """Return the 3D PipelineConfig: MCP if MCP_ENDPOINT is set, else HuggingFace."""
        if ENV_CONFIG["MCP_ENDPOINT"]:
            logger.info("Using MCP for 3D generation")
            return PipelineConfig(
                model_provider=ModelProvider.MCP,
                mcp_endpoint=ENV_CONFIG["MCP_ENDPOINT"],
                mcp_api_key=ENV_CONFIG["MCP_API_KEY"]
            )
        logger.info("Using local models for 3D generation")
        return PipelineConfig(
            model_provider=ModelProvider.HUGGINGFACE
        )

    async def initialize(self):
        """Initialize AI components and pipelines.

        A no-op once a previous call succeeded. On failure the error is
        logged with its traceback and re-raised; ``initialized`` stays False,
        so a later call will try again.
        """
        if self.initialized:
            return
        logger.info("Initializing AI components...")
        try:
            self.qwen_processor = QwenProcessor()
            self.speech_engine = SpeechEngine()
            self.text3d_pipeline = Text3DPipeline(
                pipeline_type="opensource_v2",
                config=self._build_pipeline_config()
            )
        except Exception as e:
            # logger.exception keeps the traceback, which the bare message
            # alone did not; model-loading failures are hard to debug without it.
            logger.exception(f"Failed to initialize components: {e}")
            raise
        self.initialized = True
        logger.info("All components initialized successfully")


# Create global app state (shared by all request handlers).
app_state = AppState()
# WebSocket connection manager
class ConnectionManager:
    """Track at most one live WebSocket per monster id and push JSON to it."""

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, monster_id: str):
        """Accept *websocket* and register it for *monster_id*.

        An existing socket registered under the same id is replaced (not closed).
        """
        await websocket.accept()
        self.active_connections[monster_id] = websocket

    def disconnect(self, monster_id: str):
        """Forget the socket for *monster_id*; a no-op if none is registered."""
        self.active_connections.pop(monster_id, None)

    async def send_update(self, monster_id: str, data: dict):
        """Send *data* as JSON to the monster's socket, if one is registered.

        A failed send drops the connection instead of propagating, so REST
        handlers that push updates do not fail because a client went away.
        Only ``Exception`` is caught: the previous bare ``except:`` also
        swallowed task cancellation and KeyboardInterrupt.
        """
        websocket = self.active_connections.get(monster_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(data)
        except Exception as exc:
            logger.warning(
                "Dropping WebSocket for monster %s after failed send: %s",
                monster_id, exc,
            )
            self.disconnect(monster_id)


manager = ConnectionManager()
# API Endpoints
@app.on_event("startup")
async def startup_event():
    """Initialize app state (AI processors, 3D pipeline) when the API starts.

    The visible source never registered this coroutine. As a result nothing
    called ``app_state.initialize()``, and the chat and 3D handlers operated
    on ``None`` components. FastAPI now prefers lifespan handlers, but
    ``on_event("startup")`` is still supported.
    """
    await app_state.initialize()
async def health_check():
    """Report liveness and whether the AI components finished initializing."""
    # NOTE(review): no route decorator is visible for this handler, so it is
    # not registered on ``app`` as shown here; confirm the intended path.
    report = {"status": "healthy"}
    report["initialized"] = app_state.initialized
    return report
async def list_monsters():
    """Return ``{"monsters": ...}`` with the state manager's saved monsters.

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    # NOTE(review): no route decorator is visible for this handler.
    try:
        saved = await app_state.state_manager.list_saved_monsters()
    except Exception as err:
        logger.error(f"Error listing monsters: {err}")
        raise HTTPException(status_code=500, detail=str(err))
    return {"monsters": saved}
async def create_monster(request: CreateMonsterRequest):
    """Create a new monster, persist it and cache it in memory.

    ``request.personality`` is upper-cased and looked up by
    ``PersonalityType`` member name.

    Returns:
        The new monster's id, name, personality, stage and stats.

    Raises:
        HTTPException(400): the personality name is unknown.
        HTTPException(500): any failure while creating or saving the monster.
    """
    # NOTE(review): no route decorator is visible for this handler.
    try:
        personality = PersonalityType[request.personality.upper()]
    except KeyError:
        # An unknown personality is a client error; it used to surface as a
        # 500 whose detail was just the KeyError text.
        raise HTTPException(
            status_code=400,
            detail=f"Unknown personality: {request.personality}"
        ) from None

    try:
        monster = Monster(name=request.name, personality=personality)
        # Persist first, then cache, so a failed save does not leave an
        # unreachable monster in memory.
        await app_state.state_manager.save_monster(monster)
        app_state.monsters[monster.id] = monster
        return {
            "id": monster.id,
            "name": monster.name,
            "personality": monster.personality.value,
            "stage": monster.stage.value,
            "stats": monster.get_stats()
        }
    except Exception as e:
        logger.error(f"Error creating monster: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_monster(monster_id: str):
    """Return one monster's full state, loading it from storage on a cache miss.

    The in-memory map ``app_state.monsters`` is consulted first. Otherwise
    the monster is loaded via the state manager and cached there. Only the
    last 10 conversation-history entries are returned.

    Raises:
        HTTPException(404): the monster is neither cached nor stored.
        HTTPException(500): any other error while loading or serializing.
    """
    # NOTE(review): no route decorator is visible for this handler.
    try:
        if monster_id not in app_state.monsters:
            stored = await app_state.state_manager.load_monster_by_id(monster_id)
            if not stored:
                raise HTTPException(status_code=404, detail="Monster not found")
            app_state.monsters[monster_id] = stored
        monster = app_state.monsters[monster_id]
        return {
            "id": monster.id,
            "name": monster.name,
            "personality": monster.personality.value,
            "stage": monster.stage.value,
            "stats": monster.get_stats(),
            "model_url": monster.model_url,
            "conversation_history": monster.conversation_history[-10:],
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error loading monster: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def perform_action(monster_id: str, request: MonsterActionRequest):
    """Perform a care action on the monster, save it and push a stats update.

    Supported actions: ``feed`` (``params["food_type"]``, default
    "balanced"), ``train`` (``params["training_type"]``, default
    "strength"), ``play``, ``clean``, ``heal``, ``discipline``, ``rest``.

    Returns:
        ``{"success": True, "result": <action result>, "stats": <stats>}``.

    Raises:
        HTTPException(404): monster neither in memory nor in storage.
        HTTPException(400): unknown action.
        HTTPException(500): any other failure.
    """
    # NOTE(review): no route decorator is visible for this handler.
    try:
        monster = app_state.monsters.get(monster_id)
        if monster is None:
            # Match get_monster: fall back to persistent storage so actions
            # work after a restart without a prior load request.
            monster = await app_state.state_manager.load_monster_by_id(monster_id)
            if not monster:
                raise HTTPException(status_code=404, detail="Monster not found")
            app_state.monsters[monster_id] = monster

        params = request.params
        # Lambdas defer the method lookup until that action is requested.
        actions = {
            "feed": lambda: monster.feed(params.get("food_type", "balanced")),
            "train": lambda: monster.train(params.get("training_type", "strength")),
            "play": lambda: monster.play(),
            "clean": lambda: monster.clean(),
            "heal": lambda: monster.heal(),
            "discipline": lambda: monster.discipline(),
            "rest": lambda: monster.rest(),
        }
        action = actions.get(request.action)
        if action is None:
            raise HTTPException(status_code=400, detail=f"Unknown action: {request.action}")
        result = action()

        # Save state
        await app_state.state_manager.save_monster(monster)

        # Send real-time update to a connected WebSocket client, if any.
        await manager.send_update(monster_id, {
            "type": "stats_update",
            "stats": monster.get_stats(),
            "stage": monster.stage.value
        })

        return {
            "success": True,
            "result": result,
            "stats": monster.get_stats()
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error performing action: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def talk_to_monster(monster_id: str, request: MonsterTalkRequest):
    """Send a text message to the monster and return its reply.

    The awaited ``generate_response_mcp`` is used when ``MCP_ENDPOINT`` is
    configured and the processor has a ``use_mcp`` attribute. Otherwise
    ``generate_response`` is called directly. The user message and the reply
    are appended to the conversation history with ISO timestamps, and the
    monster is saved.

    Returns:
        ``{"response": <reply>, "stats": <monster stats>}``.

    Raises:
        HTTPException(404): monster neither in memory nor in storage.
        HTTPException(503): the language-model processor is not initialized.
        HTTPException(500): any other failure.
    """
    # NOTE(review): no route decorator is visible for this handler.
    try:
        monster = app_state.monsters.get(monster_id)
        if monster is None:
            # Match get_monster: fall back to persistent storage.
            monster = await app_state.state_manager.load_monster_by_id(monster_id)
            if not monster:
                raise HTTPException(status_code=404, detail="Monster not found")
            app_state.monsters[monster_id] = monster

        processor = app_state.qwen_processor
        if processor is None:
            # initialize() has not run or failed; previously this surfaced as
            # a 500 from an AttributeError on None.
            raise HTTPException(status_code=503, detail="Language model is not initialized")

        # NOTE(review): hasattr() only checks that ``use_mcp`` exists, not that
        # it is truthy; confirm this is the intended switch.
        if ENV_CONFIG["MCP_ENDPOINT"] and hasattr(processor, 'use_mcp'):
            response = await processor.generate_response_mcp(
                monster, request.message
            )
        else:
            # NOTE(review): if generate_response is slow and synchronous it
            # blocks the event loop; consider asyncio.to_thread.
            response = processor.generate_response(
                monster, request.message
            )

        # Update conversation history
        monster.conversation_history.append({
            "role": "user",
            "content": request.message,
            "timestamp": datetime.now().isoformat()
        })
        monster.conversation_history.append({
            "role": "assistant",
            "content": response,
            "timestamp": datetime.now().isoformat()
        })

        # Save state
        await app_state.state_manager.save_monster(monster)

        return {
            "response": response,
            "stats": monster.get_stats()
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error talking to monster: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def generate_3d_model(monster_id: str, request: Generate3DRequest):
    """Generate a 3D model for the monster and record its URL.

    If ``request.description`` is empty, a description is built from the
    monster's personality and stage. Output is written under
    ``data/models/<monster_id>``. ``monster.model_url`` is set to
    ``/models/<monster_id>/<file name>``, the monster is saved, and a
    ``model_update`` message is pushed to a connected WebSocket client.

    Raises:
        HTTPException(404): monster neither in memory nor in storage.
        HTTPException(503): the 3D pipeline is not initialized.
        HTTPException(500): generation failed or any other error.
    """
    # NOTE(review): no route decorator is visible for this handler.
    try:
        monster = app_state.monsters.get(monster_id)
        if monster is None:
            # Match get_monster: fall back to persistent storage.
            monster = await app_state.state_manager.load_monster_by_id(monster_id)
            if not monster:
                raise HTTPException(status_code=404, detail="Monster not found")
            app_state.monsters[monster_id] = monster

        pipeline = app_state.text3d_pipeline
        if pipeline is None:
            # initialize() has not run or failed; previously an AttributeError.
            raise HTTPException(status_code=503, detail="3D generation pipeline is not initialized")

        # Generate description if not provided
        description = (
            request.description
            or f"A {monster.personality.value} {monster.stage.value} digital monster"
        )

        logger.info(f"Generating 3D model for {monster.name}: {description}")
        model_path = await pipeline.generate(
            description,
            output_dir=f"data/models/{monster_id}"
        )
        if not model_path:
            # Path(None) would otherwise fail with an unhelpful TypeError.
            logger.error(f"3D generation returned no model path for {monster_id}")
            raise HTTPException(status_code=500, detail="3D generation did not return a model")

        # Update monster with model URL
        monster.model_url = f"/models/{monster_id}/{Path(model_path).name}"
        await app_state.state_manager.save_monster(monster)

        # Send update via WebSocket
        await manager.send_update(monster_id, {
            "type": "model_update",
            "model_url": monster.model_url
        })

        return {
            "success": True,
            "model_url": monster.model_url
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error generating 3D model: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def websocket_endpoint(websocket: WebSocket, monster_id: str):
    """WebSocket endpoint for real-time monster updates.

    If the monster is loaded in memory, the endpoint first sends an
    ``initial_state`` message. Then, every 30 seconds, it calls
    ``update_time_based_stats()`` and sends a ``stats_update`` message. The
    socket is unregistered from ``manager`` however the loop ends. Previously
    only ``WebSocketDisconnect`` did this, so other errors left stale entries.
    """
    # NOTE(review): no route decorator is visible for this handler. Also, a
    # client watching a monster that is never loaded receives no sends, so its
    # disconnect is never detected and this loop keeps running.
    await manager.connect(websocket, monster_id)
    try:
        # Send initial stats
        if monster_id in app_state.monsters:
            monster = app_state.monsters[monster_id]
            await websocket.send_json({
                "type": "initial_state",
                "stats": monster.get_stats(),
                "stage": monster.stage.value,
                "model_url": monster.model_url
            })

        # Keep connection alive and handle stat degradation
        while True:
            await asyncio.sleep(30)  # Update every 30 seconds
            if monster_id in app_state.monsters:
                monster = app_state.monsters[monster_id]
                monster.update_time_based_stats()
                await websocket.send_json({
                    "type": "stats_update",
                    "stats": monster.get_stats(),
                    "stage": monster.stage.value
                })
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    finally:
        manager.disconnect(monster_id)
# Gradio interface used as an admin / fallback panel
def create_gradio_interface():
    """Build and return the Gradio admin panel via ``create_interface()``."""
    return create_interface()
# Main entry point
if __name__ == "__main__":
    # Create the directories the app writes into.
    for directory in ("data/saves", "data/models", "data/cache", "logs"):
        os.makedirs(directory, exist_ok=True)

    # Log startup info
    separator = "=" * 60
    logger.info(separator)
    logger.info("DigiPal - Advanced AI Monster Companion")
    logger.info(separator)
    logger.info(f"Environment: {'HuggingFace Spaces' if IS_SPACES else 'Local'}")
    logger.info(f"API Port: {ENV_CONFIG['API_PORT']}")
    logger.info(f"Gradio Port: {ENV_CONFIG['SERVER_PORT']}")
    logger.info(f"MCP Enabled: {bool(ENV_CONFIG['MCP_ENDPOINT'])}")
    logger.info(separator)

    async def run_servers():
        """Run the Gradio UI in the background and serve the FastAPI app.

        ``Blocks.launch()`` is a regular (synchronous) function that blocks
        the calling thread unless ``prevent_thread_lock=True``. The former
        ``asyncio.gather(server.serve(), gr_interface.launch(...))`` blocked
        inside launch, so the API server never started; had launch returned,
        gather would have received a non-awaitable.
        """
        config = uvicorn.Config(
            app,
            host=ENV_CONFIG["SERVER_NAME"],
            port=ENV_CONFIG["API_PORT"],
            log_level=ENV_CONFIG["LOG_LEVEL"].lower()
        )
        server = uvicorn.Server(config)

        # Start Gradio's server on its own background thread and return.
        # NOTE(review): on HuggingFace Spaces usually only the app port is
        # exposed publicly, so API_PORT may be unreachable there; confirm.
        gr_interface = create_gradio_interface()
        gr_interface.launch(
            server_name=ENV_CONFIG["SERVER_NAME"],
            server_port=ENV_CONFIG["SERVER_PORT"],
            share=ENV_CONFIG["SHARE"],
            max_threads=ENV_CONFIG["MAX_THREADS"],
            show_error=True,
            prevent_thread_lock=True
        )
        try:
            # Serve the API in this event loop until uvicorn shuts down.
            await server.serve()
        finally:
            gr_interface.close()

    # Run the servers
    asyncio.run(run_servers())