Spaces:
Running
Running
feat: add baseline scores JSON, inference script, and update Dockerfile for improved project structure
6d74982 | // main.go — GridMind-RL HTTP server (OpenEnv-compliant) | |
| // Exposes: POST /step, POST /reset, GET /state, GET /health, GET /replay, GET /grade, GET /metrics | |
| // Port: 7860 (Hugging Face Spaces compatible) | |
| package main | |
| import ( | |
| "encoding/json" | |
| "fmt" | |
| "log" | |
| "math" | |
| "net/http" | |
| "os" | |
| "strconv" | |
| "sync" | |
| "sync/atomic" | |
| "time" | |
| "gridmind-rl/env" | |
| ) | |
| // ────────────────────────────────────────────── | |
| // Prometheus-style metrics (OpenTelemetry) | |
| // ────────────────────────────────────────────── | |
// Metrics aggregates in-process counters that are rendered in the Prometheus
// text exposition format by the /metrics endpoint.
//
// stepCount and errorCount are updated with sync/atomic (errorCount is bumped
// directly by request handlers); every other field is guarded by mu.
type Metrics struct {
	mu               sync.Mutex
	stepCount        int64   // total recorded steps (atomic)
	stepLatencySum   float64 // sum of recorded step latencies, in milliseconds
	stepLatencyCount int64
	rewardSum        float64
	rewardCount      int64
	rewardMin        float64 // only meaningful once rewardCount > 0
	rewardMax        float64 // only meaningful once rewardCount > 0
	// Histograms
	actionBuckets map[string]int64 // hvac bucket counts
	errorCount    int64            // total request errors (atomic)
}

// metrics is the process-wide metrics registry used by the HTTP handlers.
var metrics = &Metrics{
	rewardMin:     math.MaxFloat64,
	rewardMax:     -math.MaxFloat64,
	actionBuckets: map[string]int64{"low": 0, "mid": 0, "high": 0},
}

// recordStep records one environment step: its latency in milliseconds and
// the reward it produced.
func (m *Metrics) recordStep(latencyMs float64, reward float64) {
	atomic.AddInt64(&m.stepCount, 1)

	m.mu.Lock()
	defer m.mu.Unlock()

	m.stepLatencySum += latencyMs
	m.stepLatencyCount++
	m.rewardSum += reward

	// Seed min/max from the first sample so the result does not depend on how
	// the struct was initialised (sentinels or zero value).
	if m.rewardCount == 0 || reward < m.rewardMin {
		m.rewardMin = reward
	}
	if m.rewardCount == 0 || reward > m.rewardMax {
		m.rewardMax = reward
	}
	m.rewardCount++
}

// recordAction increments the histogram bucket for an HVAC power level:
// "low" below 0.33, "mid" below 0.66, "high" otherwise.
func (m *Metrics) recordAction(hvac float64) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Writing to a nil map panics; create it lazily for zero-value Metrics.
	if m.actionBuckets == nil {
		m.actionBuckets = make(map[string]int64, 3)
	}
	switch {
	case hvac < 0.33:
		m.actionBuckets["low"]++
	case hvac < 0.66:
		m.actionBuckets["mid"]++
	default:
		m.actionBuckets["high"]++
	}
}

// prometheus renders the current metrics in the Prometheus text exposition
// format. Averages, minimum and maximum are reported as 0 until at least one
// sample has been recorded, rather than leaking the internal sentinels.
func (m *Metrics) prometheus() string {
	m.mu.Lock()
	defer m.mu.Unlock()

	avgLatency := 0.0
	if m.stepLatencyCount > 0 {
		avgLatency = m.stepLatencySum / float64(m.stepLatencyCount)
	}

	avgReward, minReward, maxReward := 0.0, 0.0, 0.0
	if m.rewardCount > 0 {
		avgReward = m.rewardSum / float64(m.rewardCount)
		minReward = m.rewardMin
		maxReward = m.rewardMax
	}

	return fmt.Sprintf(`# HELP gridmind_steps_total Total environment steps taken
# TYPE gridmind_steps_total counter
gridmind_steps_total %d
# HELP gridmind_step_latency_ms_avg Average step latency (ms)
# TYPE gridmind_step_latency_ms_avg gauge
gridmind_step_latency_ms_avg %.4f
# HELP gridmind_reward_avg Average reward per step
# TYPE gridmind_reward_avg gauge
gridmind_reward_avg %.4f
# HELP gridmind_reward_min Minimum reward seen
# TYPE gridmind_reward_min gauge
gridmind_reward_min %.4f
# HELP gridmind_reward_max Maximum reward seen
# TYPE gridmind_reward_max gauge
gridmind_reward_max %.4f
# HELP gridmind_action_hvac_bucket HVAC power level distribution
# TYPE gridmind_action_hvac_bucket counter
gridmind_action_hvac_bucket{bin="low"} %d
gridmind_action_hvac_bucket{bin="mid"} %d
gridmind_action_hvac_bucket{bin="high"} %d
# HELP gridmind_errors_total Total request errors
# TYPE gridmind_errors_total counter
gridmind_errors_total %d
`,
		atomic.LoadInt64(&m.stepCount),
		avgLatency, avgReward,
		minReward, maxReward,
		m.actionBuckets["low"], m.actionBuckets["mid"], m.actionBuckets["high"],
		atomic.LoadInt64(&m.errorCount),
	)
}
| // ────────────────────────────────────────────── | |
| // Server | |
| // ────────────────────────────────────────────── | |
| type Server struct { | |
| envMgr *env.Environment | |
| } | |
| func newServer() *Server { | |
| return &Server{envMgr: env.NewEnvironment()} | |
| } | |
| func (s *Server) routes() *http.ServeMux { | |
| mux := http.NewServeMux() | |
| mux.HandleFunc("/health", s.handleHealth) | |
| mux.HandleFunc("/ping", s.handlePing) | |
| mux.HandleFunc("/reset", s.handleReset) | |
| mux.HandleFunc("/step", s.handleStep) | |
| mux.HandleFunc("/state", s.handleState) | |
| mux.HandleFunc("/replay", s.handleReplay) | |
| mux.HandleFunc("/grade", s.handleGrade) | |
| mux.HandleFunc("/tasks", s.handleTasks) | |
| mux.HandleFunc("/metrics", s.handleMetrics) | |
| return mux | |
| } | |
| // ── /health ────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusOK) | |
| json.NewEncoder(w).Encode(map[string]string{"status": "ok", "version": "1.0.0"}) | |
| } | |
| func (s *Server) handlePing(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusOK) | |
| json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) | |
| } | |
| // ── /reset ─────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { | |
| if r.Method != http.MethodPost { | |
| http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | |
| return | |
| } | |
| var req env.ResetRequest | |
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | |
| // Allow empty body → defaults | |
| req = env.ResetRequest{TaskID: 1} | |
| } | |
| if req.TaskID == 0 { | |
| req.TaskID = 1 | |
| } | |
| resp := s.envMgr.Reset(req) | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(resp) | |
| } | |
| // ── /step ──────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleStep(w http.ResponseWriter, r *http.Request) { | |
| if r.Method != http.MethodPost { | |
| http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | |
| return | |
| } | |
| start := time.Now() | |
| // Accept both single action and array of actions | |
| var actions []env.ActionModel | |
| var single env.ActionModel | |
| body := make([]byte, 0, 512) | |
| buf := make([]byte, 512) | |
| for { | |
| n, err := r.Body.Read(buf) | |
| body = append(body, buf[:n]...) | |
| if err != nil { | |
| break | |
| } | |
| } | |
| if len(body) > 0 && body[0] == '[' { | |
| if err := json.Unmarshal(body, &actions); err != nil { | |
| atomic.AddInt64(&metrics.errorCount, 1) | |
| http.Error(w, "invalid action array: "+err.Error(), http.StatusBadRequest) | |
| return | |
| } | |
| } else { | |
| if err := json.Unmarshal(body, &single); err != nil { | |
| atomic.AddInt64(&metrics.errorCount, 1) | |
| http.Error(w, "invalid action: "+err.Error(), http.StatusBadRequest) | |
| return | |
| } | |
| actions = []env.ActionModel{single} | |
| } | |
| responses, done := s.envMgr.Step(actions) | |
| latency := float64(time.Since(start).Microseconds()) / 1000.0 | |
| for _, resp := range responses { | |
| metrics.recordStep(latency, resp.Reward) | |
| } | |
| if len(actions) > 0 { | |
| metrics.recordAction(actions[0].HVACPowerLevel) | |
| } | |
| w.Header().Set("Content-Type", "application/json") | |
| if done && len(responses) == 1 { | |
| responses[0].Done = true | |
| } | |
| // Return single response if single building, array otherwise | |
| if len(responses) == 1 { | |
| json.NewEncoder(w).Encode(responses[0]) | |
| } else { | |
| json.NewEncoder(w).Encode(responses) | |
| } | |
| } | |
| // ── /state ─────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleState(w http.ResponseWriter, r *http.Request) { | |
| if r.Method != http.MethodGet { | |
| http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | |
| return | |
| } | |
| state := s.envMgr.GetState() | |
| // Add CORS for dashboard | |
| w.Header().Set("Access-Control-Allow-Origin", "*") | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(state) | |
| } | |
| // ── /replay ────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleReplay(w http.ResponseWriter, r *http.Request) { | |
| if r.Method != http.MethodGet { | |
| http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | |
| return | |
| } | |
| replay := s.envMgr.GetReplay() | |
| w.Header().Set("Content-Type", "application/json") | |
| w.Header().Set("Access-Control-Allow-Origin", "*") | |
| json.NewEncoder(w).Encode(map[string]interface{}{ | |
| "replay": replay, | |
| "steps": len(replay), | |
| }) | |
| } | |
| // ── /grade ─────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleGrade(w http.ResponseWriter, r *http.Request) { | |
| state := s.envMgr.GetState() | |
| replay := s.envMgr.GetReplay() | |
| // Collect per-building exploit penalties | |
| penalties := make([]float64, len(state.Buildings)) | |
| for i := range state.Buildings { | |
| _, pen := s.envMgr.ExploitDetected(i) | |
| penalties[i] = pen | |
| } | |
| // Build building states from public state | |
| buildings := make([]*env.BuildingState, len(state.Buildings)) | |
| for i, pub := range state.Buildings { | |
| jobsCopy := make([]env.BatchJob, len(pub.Jobs)) | |
| copy(jobsCopy, pub.Jobs) | |
| buildings[i] = &env.BuildingState{ | |
| CumulativeCost: pub.CumulativeCost, | |
| BaselineCost: pub.BaselineCost, | |
| CumulativeCarbon: pub.CumulativeCarbon, | |
| BaselineCarbon: pub.BaselineCarbon, | |
| Jobs: jobsCopy, | |
| } | |
| } | |
| // Reconstruct temp history from public state | |
| tempHistory := make([][]float64, len(state.Buildings)) | |
| for i, pub := range state.Buildings { | |
| tempHistory[i] = pub.TempHistory | |
| } | |
| grade := env.GradeEpisode(env.GradeEpisodeInput{ | |
| TaskID: state.TaskID, | |
| Buildings: buildings, | |
| Replay: replay, | |
| TempHistory: tempHistory, | |
| TMin: env.TMinDefault, | |
| TMax: env.TMaxDefault, | |
| ExploitPenalties: penalties, | |
| }) | |
| w.Header().Set("Content-Type", "application/json") | |
| w.Header().Set("Access-Control-Allow-Origin", "*") | |
| json.NewEncoder(w).Encode(grade) | |
| } | |
| // ── /tasks ─────────────────────────────────────────────────────────────────── | |
| func (s *Server) handleTasks(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(env.AllTasks()) | |
| } | |
| // ── /metrics ───────────────────────────────────────────────────────────────── | |
| func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "text/plain; version=0.0.4") | |
| fmt.Fprint(w, metrics.prometheus()) | |
| } | |
| // ────────────────────────────────────────────── | |
| // Entry point | |
| // ────────────────────────────────────────────── | |
| func main() { | |
| port := os.Getenv("PORT") | |
| if port == "" { | |
| port = "7860" | |
| } | |
| // Validate port | |
| if _, err := strconv.Atoi(port); err != nil { | |
| log.Fatalf("invalid PORT: %s", port) | |
| } | |
| srv := newServer() | |
| // Perform initial reset so /state is always valid | |
| var seed int64 = 42 | |
| srv.envMgr.Reset(env.ResetRequest{Seed: &seed, TaskID: 1, NumBuildings: 1}) | |
| log.Printf("GridMind-RL environment server starting on :%s", port) | |
| log.Printf("Endpoints: GET /health /ping /state /replay /grade /tasks /metrics | POST /reset /step") | |
| mux := withCORS(withLogging(srv.routes())) | |
| if err := http.ListenAndServe(":"+port, mux); err != nil { | |
| log.Fatalf("server error: %v", err) | |
| } | |
| } | |
| // ────────────────────────────────────────────── | |
| // Middleware | |
| // ────────────────────────────────────────────── | |
// withLogging wraps next so that, after each request has been served, one log
// line with the method, URL path and elapsed time is written.
func withLogging(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)
		log.Printf("%s %s %s", r.Method, r.URL.Path, elapsed)
	}
	return http.HandlerFunc(logged)
}
// withCORS wraps next so that every response carries permissive CORS headers.
// OPTIONS requests are answered directly with 204 No Content and are not
// forwarded to next.
func withCORS(next http.Handler) http.Handler {
	corsHeaders := [][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "GET, POST, OPTIONS"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range corsHeaders {
			hdr.Set(kv[0], kv[1])
		}

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	}
	return http.HandlerFunc(handler)
}