Gil Stetler committed on
Commit 8a29bed · 1 Parent(s): 67dad62

add % mse and rmse

Files changed (1)
app.py +33 -17
app.py CHANGED
@@ -18,7 +18,7 @@ dtype = torch.bfloat16 if device == "cuda" else torch.float32
 # Load once at startup (HF Spaces cache between runs)
 pipe = ChronosPipeline.from_pretrained(
     MODEL_ID,
-    device_map="auto",
+    device_map="auto",  # uses GPU if available
     torch_dtype=dtype,
 )
 
@@ -33,9 +33,7 @@ def run_forecast_and_evaluate():
     if n <= PREDICTION_LENGTH + 5:
         raise gr.Error("Time series too short for a holdout evaluation.")
 
-    # 2) Train/forecast split:
-    #    Use all but the last PREDICTION_LENGTH points as context (train),
-    #    and compare forecast to the real last PREDICTION_LENGTH points (test).
+    # 2) Holdout split: forecast the last 12 points
     y_train = y[: n - PREDICTION_LENGTH]
     y_test = y[n - PREDICTION_LENGTH :]
 
@@ -44,21 +42,29 @@ def run_forecast_and_evaluate():
     samples = fcst[0].cpu().numpy()  # (S, H)
 
     # 3) Summaries & metrics
-    low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
+    p10, p50, p90 = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
 
-    # "mean standard error" is ambiguous; commonly MSE + RMSE are reported:
-    mse = float(np.mean((median - y_test) ** 2))
+    # Point forecast = median
+    mse = float(np.mean((p50 - y_test) ** 2))
     rmse = float(np.sqrt(mse))
 
-    # 4) Plot: full history + forecast horizon vs ground truth
+    # Percent versions (relative to the mean of the true holdout)
+    mean_y = float(np.mean(y_test))
+    rmse_pct = float(100.0 * rmse / mean_y)       # RMSE as % of mean
+    mse_pct = float(100.0 * mse / (mean_y ** 2))  # MSE as % of mean^2
+
+    # (Optional) MAPE if you ever want it:
+    # mape_pct = float(100.0 * np.mean(np.abs((p50 - y_test) / y_test)))
+
+    # 4) Plot: history + forecast horizon vs ground truth
     fig = plt.figure(figsize=(9, 4))
     x_hist = np.arange(len(y_train))
     x_fcst = np.arange(len(y_train), len(y_train) + PREDICTION_LENGTH)
 
     plt.plot(x_hist, y_train, label="history")
     plt.plot(x_fcst, y_test, label="actual (holdout)")
-    plt.plot(x_fcst, median, linestyle="--", label="forecast (median)")
-    plt.fill_between(x_fcst, low, high, alpha=0.3, label="80% interval")
+    plt.plot(x_fcst, p50, linestyle="--", label="forecast (median)")
+    plt.fill_between(x_fcst, p10, p90, alpha=0.3, label="80% interval")
     plt.title("Chronos-T5-Large • Holdout Evaluation")
     plt.xlabel("time")
     plt.ylabel("#Passengers")
@@ -69,22 +75,32 @@ def run_forecast_and_evaluate():
     out_json = {
         "prediction_length": int(PREDICTION_LENGTH),
         "num_samples": int(NUM_SAMPLES),
-        "metrics": {"MSE": mse, "RMSE": rmse},
-        "median": median.tolist(),
-        "p10": low.tolist(),
-        "p90": high.tolist(),
+        "metrics": {
+            "MSE": mse,
+            "RMSE": rmse,
+            "RMSE_%_of_mean": rmse_pct,
+            "MSE_%_of_mean^2": mse_pct,
+            # "MAPE_%": mape_pct,  # uncomment if you add MAPE
+            "mean_of_truth": mean_y,
+        },
+        "median": p50.tolist(),
+        "p10": p10.tolist(),
+        "p90": p90.tolist(),
         "actual": y_test.tolist(),
     }
 
-    # Metrics text to display prominently
-    metrics_md = f"**MSE:** {mse:.3f}  **RMSE:** {rmse:.3f}"
+    metrics_md = (
+        f"**MSE:** {mse:.3f}  **RMSE:** {rmse:.3f}  "
+        f"**RMSE% of mean:** {rmse_pct:.2f}%  "
+        f"**MSE% of mean²:** {mse_pct:.3f}%"
+    )
     return fig, out_json, metrics_md
 
 with gr.Blocks(title="Chronos-T5-Large • Holdout Demo") as demo:
     gr.Markdown(
         "## Chronos-T5-Large (zero-shot forecasting) — Holdout Evaluation\n"
         "Click **Run** to forecast the last 12 months from AirPassengers and compare to the true values.\n"
-        "Computation runs on this Space's server hardware."
+        "Shows MSE, RMSE, and RMSE% / MSE% relative to the mean of the 12 true values."
    )
    run_btn = gr.Button("Run", variant="primary")
    plot = gr.Plot(label="Forecast vs Actual (holdout)")
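
For reference, the percent metrics added in this commit are just MSE and RMSE normalized by the mean of the true holdout values. Below is a minimal, self-contained sketch of that arithmetic using made-up toy numbers (not AirPassengers data); the optional MAPE line mirrors the commented-out one in app.py and assumes y_test contains no zeros.

import numpy as np

# Toy holdout values and toy median forecast, purely for illustration
y_test = np.array([420.0, 390.0, 432.0, 405.0])
p50 = np.array([410.0, 400.0, 425.0, 415.0])

mse = float(np.mean((p50 - y_test) ** 2))      # mean squared error
rmse = float(np.sqrt(mse))                     # root mean squared error

mean_y = float(np.mean(y_test))                # mean of the true holdout
rmse_pct = 100.0 * rmse / mean_y               # RMSE as a % of that mean
mse_pct = 100.0 * mse / (mean_y ** 2)          # MSE as a % of the squared mean

# Optional MAPE (undefined if any y_test value is zero)
mape_pct = 100.0 * float(np.mean(np.abs((p50 - y_test) / y_test)))

print(f"MSE={mse:.3f}  RMSE={rmse:.3f}  "
      f"RMSE%={rmse_pct:.2f}%  MSE%={mse_pct:.3f}%  MAPE%={mape_pct:.2f}%")

Normalizing by the holdout mean makes the error scale-free, so RMSE% can be compared across series with different magnitudes, which is the point of reporting it alongside the raw MSE and RMSE.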