| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
| import plotly.graph_objects as go |
|
|
|
|
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Draw the train and test splits of a univariate series on a single figure.

    The first column of each frame is plotted as a line against that frame's index.

    Args:
        df1 (pd.DataFrame): Train dataset; drawn in steel blue as "Training Data".
        df2 (pd.DataFrame): Test dataset; drawn in gold as "Test Data".

    Returns:
        go.Figure: Figure holding both line traces, with the legend shown and the
            "plotly_white" template applied.
    """
    # (frame, legend label, colour) for each split, in drawing order.
    splits = (
        (df1, "Training Data", "steelblue"),
        (df2, "Test Data", "gold"),
    )

    figure = go.Figure()
    for frame, label, colour in splits:
        series = go.Scatter(
            x=frame.index,
            y=frame.iloc[:, 0],
            mode="lines",
            name=label,
            line={"color": colour},
            marker={"color": colour},
        )
        figure.add_trace(series)

    figure.update_layout(
        title="Univariate Time Series",
        xaxis={"title": "Date"},
        yaxis={"title": "Value"},
        showlegend=True,
        template="plotly_white",
    )
    return figure
|
|
|
|
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
    """
    Plot the true values and sample-based forecasts using Plotly.

    For every forecast, each individual sample path is drawn as a faint line and
    the mean over the samples (axis 0) is drawn as a dashed red line.

    Args:
        df (pd.DataFrame): DataFrame with the true values. The first column is plotted
            against the index (converted with ``pd.to_datetime``); the first column's
            name is used for the figure title and the y-axis label.
        forecasts (List[pd.DataFrame]): Forecasts to overlay.
            NOTE(review): despite the annotation, each element is accessed through
            ``.samples`` and ``.index.to_timestamp()``, which a plain DataFrame does
            not provide -- presumably these are sample-forecast objects (e.g. from
            GluonTS); confirm against the caller and tighten the annotation.

    Returns:
        go.Figure: Plotly figure object.
    """
    fig = go.Figure()

    # Ground truth first so it sits underneath the forecast traces.
    fig.add_trace(
        go.Scatter(
            x=pd.to_datetime(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Sample-path colours cycle when there are more forecasts than colours.
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # The x-axis is identical for every sample path and for the mean, so
        # convert the index once per forecast instead of once per trace.
        timestamps = forecast.index.to_timestamp()
        samples = forecast.samples

        for sample in samples:
            fig.add_trace(
                go.Scatter(
                    x=timestamps,
                    y=sample,
                    mode="lines",
                    opacity=0.15,
                    name=f"Forecast {i + 1}",
                    showlegend=False,
                    hoverinfo="none",
                    line=dict(color=color),
                )
            )

        mean_forecast = np.mean(samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=timestamps,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                legendgroup="mean forecast",
                # Only the first mean trace gets a legend entry; the shared
                # legendgroup ties the remaining ones to it.
                showlegend=i == 0,
            )
        )

    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",
    )

    return fig
|
|