"""Gradio demo: probabilistic retail sales forecasting with a T5-encoder "world model".

At import time the app downloads a pretrained checkpoint and a pickled scaler
from the Hugging Face Hub. If either one fails to load, both are left as
``None`` and forecasts fall back to a simple linear-trend heuristic, so the UI
keeps working.
"""

import pickle
import zlib

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import T5EncoderModel, AutoConfig
from sklearn.preprocessing import StandardScaler
from huggingface_hub import hf_hub_download
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

MODEL_NAME = 'google/t5-efficient-tiny'
HF_REPO_ID = 'superdkj/retail-world-model-v1'
CONTEXT_LENGTH = 60      # days of history fed to the model
PREDICTION_LENGTH = 14   # days forecast
NUM_VARIATES = 5         # per-day features: scaled sales, promo, day-of-week, month, family
EMBED_DIM = 64           # hidden width of the mean/variance heads
_Z_90 = 1.645            # z-score of a two-sided 90% Gaussian interval
_REQUIRED_COLUMNS = ('date', 'sales')


class RetailWorldModel(nn.Module):
    """T5 encoder followed by an LSTM rollout that predicts a per-step mean and variance.

    The encoder summarizes the projected context window. Its last hidden state
    seeds a 2-layer LSTM, which is unrolled ``pred_len`` steps with each output
    fed back in as the next input. Two small heads map every rolled-out state to
    a mean and a positive (Softplus) variance.

    Attribute names must not change: they are the keys of the pretrained state
    dict loaded from the Hub.
    """

    def __init__(self, base_model_name, context_len, pred_len, num_variates, embed_dim):
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        self.context_len = context_len  # stored only; forward() does not enforce it
        self.pred_len = pred_len
        self.num_variates = num_variates
        self.embed_dim = embed_dim
        d_model = self.config.d_model
        # Map the raw per-day feature vector into the encoder's embedding space.
        self.input_proj = nn.Linear(num_variates, d_model)
        self.latent_dynamics = nn.LSTM(d_model, d_model, 2, batch_first=True, dropout=0.1)
        self.mean_head = nn.Sequential(
            nn.Linear(d_model, embed_dim), nn.GELU(), nn.Linear(embed_dim, 1)
        )
        self.var_head = nn.Sequential(
            nn.Linear(d_model, embed_dim), nn.GELU(), nn.Linear(embed_dim, 1), nn.Softplus()
        )

    def forward(self, context):
        """Roll the latent dynamics forward ``pred_len`` steps.

        Args:
            context: float tensor of shape (batch, seq_len, num_variates).

        Returns:
            dict with ``'mean'`` and ``'var'``, each of shape (batch, pred_len),
            in the scaled units the model was trained on.
        """
        x = self.input_proj(context)
        enc_out = self.encoder(inputs_embeds=x, return_dict=True).last_hidden_state
        last_state = enc_out[:, -1:, :]
        # Seed the hidden state of both LSTM layers with the encoder's last step.
        h = last_state.transpose(0, 1).repeat(2, 1, 1)
        c = torch.zeros_like(h)
        states = []
        step_input = last_state
        for _ in range(self.pred_len):
            step_out, (h, c) = self.latent_dynamics(step_input, (h, c))
            states.append(step_out)
            step_input = step_out  # autoregressive: feed the output back in
        states = torch.cat(states, dim=1)
        mean = self.mean_head(states).squeeze(-1)
        var = self.var_head(states).squeeze(-1)
        return {'mean': mean, 'var': var}


def _load_model_and_scaler():
    """Download and load the checkpoint and scaler; return (model, scaler) or (None, None).

    Both objects are returned together or not at all. The model's inputs and
    outputs are only meaningful in the scaler's units, so using one without the
    other would silently produce mis-scaled forecasts.
    """
    try:
        net = RetailWorldModel(MODEL_NAME, CONTEXT_LENGTH, PREDICTION_LENGTH, NUM_VARIATES, EMBED_DIM)
        ckpt_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
        scaler_path = hf_hub_download(repo_id=HF_REPO_ID, filename="scaler.pkl")
        state_dict = torch.load(ckpt_path, map_location='cpu')
        incompatible = net.load_state_dict(state_dict, strict=False)
        # strict=False never raises on key mismatches, so report them explicitly.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f"Warning: checkpoint mismatch ({len(incompatible.missing_keys)} missing, "
                f"{len(incompatible.unexpected_keys)} unexpected keys); forecasts may be unreliable."
            )
        net.eval()
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only acceptable because HF_REPO_ID is a repository we trust.
        with open(scaler_path, 'rb') as f:
            loaded_scaler = pickle.load(f)
    except Exception as e:  # best effort: the app still works via the heuristic fallback
        print(f"Model not available: {e}")
        return None, None
    print("Model loaded!")
    return net, loaded_scaler


model, scaler = _load_model_and_scaler()

FAMILIES = [
    'AUTOMOTIVE', 'BABY CARE', 'BEAUTY', 'BEVERAGES', 'BOOKS', 'BREAD/BAKERY',
    'CELEBRATION', 'CLEANING', 'DAIRY', 'DELI', 'EGGS', 'FROZEN FOODS',
    'GROCERY I', 'GROCERY II', 'HARDWARE', 'HOME AND KITCHEN I', 'HOME AND KITCHEN II',
    'HOME APPLIANCES', 'HOME CARE', 'LADIESWEAR', 'LAWN AND GARDEN', 'LINGERIE',
    'LIQUOR,WINE,BEER', 'MAGAZINES', 'MEATS', 'PERSONAL CARE', 'PET SUPPLIES',
    'PLAYERS AND ELECTRONICS', 'POULTRY', 'PREPARED FOODS', 'PRODUCE',
    'SCHOOL AND OFFICE SUPPLIES', 'SEAFOOD',
]
# Scalar family encoding, fed to the model as a constant feature. The divisor is
# the literal 33, so values span [0, 32/33]. Presumably this matches the
# training-time encoding; do not change it without retraining.
FAMILY_MAP = {f: i / 33 for i, f in enumerate(FAMILIES)}


def generate_sample_data(family_name, store_nbr):
    """Return CONTEXT_LENGTH days of synthetic daily sales ending today.

    Columns: 'date' (YYYY-MM-DD strings), 'sales' (rounded to 2 decimals) and
    'onpromotion' (0/1). The series is base level + linear trend + weekly sine
    + Gaussian noise, clipped at 0.

    The RNG is seeded from a process-stable CRC32 of the family name plus the
    store number. The same inputs therefore give the same series across
    restarts, and the global NumPy RNG is left untouched.
    """
    seed = 42 + zlib.crc32(str(family_name).encode('utf-8')) % 1000 + int(store_nbr)
    rng = np.random.RandomState(seed)
    dates = pd.date_range(end=pd.Timestamp.now(), periods=CONTEXT_LENGTH, freq='D')
    base_sales = rng.uniform(50, 500)
    trend = rng.uniform(-1, 1)
    seasonality = 20 * np.sin(2 * np.pi * np.arange(CONTEXT_LENGTH) / 7)
    noise = rng.normal(0, 15, CONTEXT_LENGTH)
    sales = np.maximum(0, base_sales + trend * np.arange(CONTEXT_LENGTH) + seasonality + noise)
    onprom = rng.choice([0, 1], CONTEXT_LENGTH, p=[0.8, 0.2])
    return pd.DataFrame({
        'date': dates.strftime('%Y-%m-%d'),
        'sales': sales.round(2),
        'onpromotion': onprom,
    })


def _load_history(history_csv, family_name, store_nbr):
    """Return (history DataFrame, human-readable description of where it came from).

    Synthetic demo data is substituted when no CSV is given, the CSV cannot be
    read, it lacks a required column, or it has fewer than CONTEXT_LENGTH rows.
    The description says which case applied, so the caller can show it.
    """
    if history_csv is None:
        return generate_sample_data(family_name, store_nbr), "synthetic demo data (no CSV uploaded)"
    path = history_csv.name if hasattr(history_csv, 'name') else history_csv
    try:
        df = pd.read_csv(path)
    except Exception as exc:  # UI boundary: any read/parse error falls back to demo data
        reason = f"could not read CSV: {exc}"
    else:
        missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
        if missing:
            reason = f"CSV is missing column(s): {', '.join(missing)}"
        elif len(df) < CONTEXT_LENGTH:
            reason = f"CSV has {len(df)} rows; at least {CONTEXT_LENGTH} are required"
        else:
            return df, "uploaded CSV"
    return generate_sample_data(family_name, store_nbr), f"synthetic demo data ({reason})"


def _build_context(df, family_name, onpromotion_flag):
    """Assemble the (len(df), NUM_VARIATES) float32 model input from a history frame.

    Requires the global ``scaler``. ``df['date']`` must already be datetime.
    """
    dates = df['date']
    sales_scaled = scaler.transform(df['sales'].values.reshape(-1, 1)).flatten()
    if 'onpromotion' in df.columns:
        promo = df['onpromotion'].values
    else:
        # The UI checkbox is only used when the history has no promotion column.
        promo = np.full(len(df), int(onpromotion_flag))
    family_enc = FAMILY_MAP.get(family_name, 0.5)
    return np.stack([
        sales_scaled,
        promo,
        dates.dt.dayofweek.values / 6.0,
        dates.dt.month.values / 12.0,
        np.full(len(df), family_enc),
    ], axis=1).astype(np.float32)


def _forecast(df, family_name, onpromotion_flag):
    """Return (mean, std) forecasts in sales units, each of length PREDICTION_LENGTH."""
    if model is not None and scaler is not None:
        context = _build_context(df, family_name, onpromotion_flag)
        with torch.no_grad():
            out = model(torch.tensor(context).unsqueeze(0))
        mean_scaled = out['mean'].squeeze(0).numpy()
        std_scaled = np.sqrt(out['var'].squeeze(0).numpy())
        mean_sales = scaler.inverse_transform(mean_scaled.reshape(-1, 1)).flatten()
        # Assumes an affine scaler such as StandardScaler: a std only needs the scale factor.
        std_sales = std_scaled * scaler.scale_[0]
        return mean_sales, std_sales

    # Heuristic fallback: extrapolate the linear trend of the last 14 days from
    # the last observed value, with a flat band of half the history's std.
    sales = df['sales'].to_numpy(dtype=float)
    recent = sales[-14:]
    slope = np.polyfit(np.arange(len(recent)), recent, 1)[0] if len(recent) > 1 else 0.0
    mean_sales = sales[-1] + slope * np.arange(1, PREDICTION_LENGTH + 1)
    std_sales = np.full(PREDICTION_LENGTH, 0.5 * (df['sales'].std() + 1e-6))
    return mean_sales, std_sales


def _plot_forecast(df, results, future_dates, family_name, store):
    """Plot history, forecast and 90% band.

    Builds a standalone ``Figure`` rather than going through pyplot. pyplot keeps
    a global registry of every figure it creates, which would leak one figure per
    request and is not thread-safe under Gradio's worker threads.
    """
    fig = Figure(figsize=(12, 6))
    ax = fig.subplots()
    hist_dates = df['date']
    ax.plot(hist_dates, df['sales'], label='Historical', color='blue', alpha=0.7, marker='o', markersize=3)
    ax.plot(future_dates, results['predicted_sales'], label='Predicted', color='red', linewidth=2, marker='s', markersize=4)
    ax.fill_between(future_dates, results['lower_90'], results['upper_90'], alpha=0.3, color='red', label='90% CI')
    ax.axvline(x=hist_dates.iloc[-1], color='gray', linestyle='--', alpha=0.5, label='Forecast Start')
    ax.set_xlabel('Date')
    ax.set_ylabel('Sales ($)')
    ax.set_title(f'Retail Sales Forecast - {family_name} (Store {store})')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig


def _results_to_markdown(results):
    """Render the results table as Markdown.

    ``DataFrame.to_markdown`` needs the optional ``tabulate`` package; without it,
    fall back to a fixed-width code block.
    """
    try:
        return results.to_markdown(index=False)
    except ImportError:
        return "```\n" + results.to_string(index=False) + "\n```"


def predict_sales(history_csv, family_name, store_nbr, onpromotion_flag):
    """Forecast PREDICTION_LENGTH days of sales for the Gradio UI.

    Args:
        history_csv: uploaded file (object with ``.name`` or a path), or None.
            Expected columns: 'date', 'sales' and optionally 'onpromotion'.
        family_name: product family; unknown names get encoding 0.5.
        store_nbr: store number from the UI (may arrive as a float).
        onpromotion_flag: used as the promotion feature only when the history
            has no 'onpromotion' column.

    Returns:
        (Markdown summary string, matplotlib Figure).

    Raises:
        gr.Error: if the store number field is empty.
    """
    if store_nbr is None:
        raise gr.Error("Please enter a store number.")
    store = int(store_nbr)

    df, source = _load_history(history_csv, family_name, store)
    df = df.copy()
    df['date'] = pd.to_datetime(df['date'])
    # Sort so tail() really selects the most recent CONTEXT_LENGTH days.
    df = df.sort_values('date', kind='stable').tail(CONTEXT_LENGTH).copy()

    mean_sales, std_sales = _forecast(df, family_name, onpromotion_flag)
    half_width = _Z_90 * std_sales

    future_dates = pd.date_range(df['date'].iloc[-1] + pd.Timedelta(days=1), periods=PREDICTION_LENGTH)
    results = pd.DataFrame({
        'date': future_dates.strftime('%Y-%m-%d'),
        'predicted_sales': np.maximum(0, mean_sales).round(2),
        'lower_90': np.maximum(0, mean_sales - half_width).round(2),
        'upper_90': (mean_sales + half_width).round(2),
    })

    fig = _plot_forecast(df, results, future_dates, family_name, store)

    summary = "\n".join([
        "## 📊 Forecast Summary",
        f"- **Product Family**: {family_name}",
        f"- **Store**: {store}",
        f"- **Data Source**: {source}",
        f"- **Historical Period**: {len(df)} days",
        f"- **Forecast Period**: {PREDICTION_LENGTH} days",
        f"- **Avg Predicted Sales**: ${results['predicted_sales'].mean():.2f}",
        f"- **Total Forecasted Sales**: ${results['predicted_sales'].sum():.2f}",
        "",
        "### Predictions Table",
        _results_to_markdown(results),
        "",
    ])
    return summary, fig


with gr.Blocks(title="Retail World Model") as demo:
    gr.Markdown(
        """
# 🛒 Retail World Model - Sales Forecasting

A transformer-based **world model** that encodes historical retail data, imagines future states via latent dynamics, and decodes probabilistic forecasts with 90% confidence intervals.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            history_csv = gr.File(label="Sales History CSV", file_types=[".csv"])
            family_dropdown = gr.Dropdown(choices=FAMILIES, label="Product Family", value="BEVERAGES")
            store_input = gr.Number(label="Store Number", value=1, minimum=1, maximum=54, step=1)
            onprom_check = gr.Checkbox(label="On Promotion", value=False)
            predict_btn = gr.Button("🔮 Predict", variant="primary")
        with gr.Column(scale=2):
            summary_output = gr.Markdown()
            chart_output = gr.Plot()
    gr.Markdown(
        "Leave CSV empty to use synthetic demo data. "
        f"An uploaded CSV needs `date` and `sales` columns (optionally `onpromotion`) and at least {CONTEXT_LENGTH} rows. "
        "Model trained on Favorita store sales dataset (54 stores, 33 product families, ~3M records)."
    )
    predict_btn.click(
        fn=predict_sales,
        inputs=[history_csv, family_dropdown, store_input, onprom_check],
        outputs=[summary_output, chart_output],
    )

if __name__ == "__main__":
    demo.launch()