# BrianIsaac - commit a6f24fe
# fix: add @spaces.GPU decorator for HF Spaces ZeroGPU compatibility
"""Ensemble Predictor MCP Server.
This MCP server provides time series forecasting using ensemble deep learning methods:
- Chronos-Bolt Tiny (Amazon's pre-trained foundation model - 9M params)
- Tiny Time Mixers (IBM's lightweight MLP-Mixer - ~1M params)
- N-HiTS (Nixtla's hierarchical neural network)
- Ensemble combination (mean, median, weighted)
Designed for financial time series with proper preprocessing and uncertainty quantification.
"""
import logging
from typing import Dict, List, Optional, Sequence, Literal
from decimal import Decimal
import numpy as np
import pandas as pd
from fastmcp import FastMCP
from pydantic import BaseModel, Field
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
# HuggingFace Spaces GPU decorator (auto-installed on HF Spaces).
# Per the original author's note, the real decorator has no effect outside
# ZeroGPU environments (CPU, local dev).
try:
    import spaces
    SPACES_GPU_AVAILABLE = True
except ImportError:
    SPACES_GPU_AVAILABLE = False

    # Stand-in so that ``@spaces.GPU`` keeps working when the package is absent.
    class _SpacesMock:
        """Minimal replacement for the ``spaces`` module outside HF Spaces."""

        @staticmethod
        def GPU(task=None, duration: int = 120):
            """No-op GPU decorator usable with or without call parentheses.

            Args:
                task: The function being decorated when used bare
                    (``@spaces.GPU``). Any non-callable value (for example a
                    duration passed positionally) is ignored.
                duration: Accepted for signature compatibility; unused.

            Returns:
                ``task`` itself when used bare, otherwise a decorator that
                returns its argument unchanged.
            """
            if callable(task):
                # Bare form: the function is handed to us directly.
                return task

            def decorator(func):
                return func

            return decorator

    spaces = _SpacesMock()
logger = logging.getLogger(__name__)

# MCP server instance; tools defined in this module register against it.
mcp = FastMCP("ensemble-predictor")

# --- Optional dependency probes -------------------------------------------
# Each backend is imported defensively and its availability recorded in a
# module-level flag so the forecasting code can skip missing libraries.

# Chronos (torch is imported in the same guarded block).
try:
    import torch
    from chronos import BaseChronosPipeline
except ImportError:
    CHRONOS_AVAILABLE = False
    logger.warning("Chronos library not available")
else:
    CHRONOS_AVAILABLE = True
    logger.info("Chronos library available")

# TTM (Tiny Time Mixers).
try:
    from tsfm_public import TinyTimeMixerForPrediction
    from tsfm_public.toolkit.get_model import get_model as get_ttm_model
except ImportError:
    TTM_AVAILABLE = False
    logger.warning("TTM library not available")
else:
    TTM_AVAILABLE = True
    logger.info("TTM library available")

# N-HiTS via NeuralForecast.
try:
    from neuralforecast import NeuralForecast
    from neuralforecast.models import NHITS
    from neuralforecast.losses.pytorch import MAE
except ImportError:
    NHITS_AVAILABLE = False
    logger.warning("NeuralForecast library not available")
else:
    NHITS_AVAILABLE = True
    logger.info("NeuralForecast library available")

# --- Model cache ------------------------------------------------------------
# Populated lazily by the forecasting helpers on first use and then reused.
# NOTE(review): _nhits_model is not referenced by the code visible here.
_chronos_pipeline = None
_ttm_model = None
_nhits_model = None
class ForecastRequest(BaseModel):
    """Input payload describing the series to forecast and how to forecast it."""

    # Identifier of the instrument; used for logging and echoed in the result.
    ticker: str

    # Historical observations (at least 10 values are required).
    prices: List[Decimal] = Field(
        ...,
        min_length=10,
        description="Historical prices",
    )
    dates: Optional[List[str]] = Field(
        default=None,
        description="Corresponding dates",
    )

    # Forecast configuration.
    forecast_horizon: int = Field(
        default=30,
        ge=1,
        le=252,
        description="Number of periods to forecast",
    )
    confidence_level: float = Field(
        default=0.95,
        ge=0.5,
        le=0.99,
        description="Confidence level for intervals",
    )
    use_returns: bool = Field(
        default=True,
        description="Forecast returns instead of raw prices",
    )

    # Ensemble configuration.
    ensemble_method: Literal["mean", "median", "weighted"] = Field(
        default="mean",
        description="Ensemble combination method",
    )
    weights: Optional[Dict[str, float]] = Field(
        default=None,
        description="Model weights for weighted ensemble",
    )
class ForecastResult(BaseModel):
    """Ensemble forecast output: point predictions, bounds, and provenance."""

    ticker: str
    method: str
    forecast_horizon: int

    # Point forecast, one value per forecast step.
    predictions: List[Decimal] = Field(
        ...,
        description="Point forecast values",
    )

    # Interval bounds aligned with ``predictions``.
    lower_bound: List[Decimal] = Field(
        ...,
        description="Lower confidence bound",
    )
    upper_bound: List[Decimal] = Field(
        ...,
        description="Upper confidence bound",
    )
    confidence_level: float

    # Provenance: which models contributed, and their individual outputs.
    models_used: List[str] = Field(
        ...,
        description="Models included in ensemble",
    )
    individual_forecasts: Optional[Dict[str, List[Decimal]]] = Field(
        default=None,
        description="Individual model predictions",
    )
    metadata: Dict[str, str] = Field(default_factory=dict)
def _preprocess_financial_series(
    prices: np.ndarray,
    use_returns: bool = True,
    max_clip_std: float = 5.0
) -> np.ndarray:
    """Preprocess financial time series for forecasting.

    With ``use_returns`` the series is converted to log returns, non-finite
    values are treated as missing, outliers are clipped to
    ``mean +/- max_clip_std * std`` and missing values are filled by linear
    interpolation. Otherwise the log of the prices is returned unchanged.

    Args:
        prices: Raw price series
        use_returns: If True, convert to log returns
        max_clip_std: Maximum standard deviations for outlier clipping

    Returns:
        Preprocessed series (one element shorter than ``prices`` when
        ``use_returns`` is True)

    Raises:
        ValueError: If returns are requested but none of them is finite
            (for example when the prices are not positive), since there is
            nothing to interpolate from.
    """
    if not use_returns:
        # Use log prices for better stability
        return np.log(prices)

    # Non-positive prices yield -inf/NaN here; they are handled explicitly
    # below, so the numpy warnings would only be noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        returns = np.diff(np.log(prices))

    # Treat infinities as missing so a single marker (NaN) is used
    returns = np.where(np.isinf(returns), np.nan, returns)

    if len(returns) == 0:
        return returns

    missing = np.isnan(returns)
    if missing.all():
        raise ValueError(
            "Cannot compute log returns: no finite returns in price series"
        )

    # Clip extreme outliers (statistics ignore missing values)
    mean, std = np.nanmean(returns), np.nanstd(returns)
    if std > 0:
        returns = np.clip(returns, mean - max_clip_std * std, mean + max_clip_std * std)

    # Fill missing values by linear interpolation between valid neighbours;
    # np.interp holds the nearest valid value for leading/trailing gaps.
    if missing.any():
        returns[missing] = np.interp(
            np.flatnonzero(missing),
            np.flatnonzero(~missing),
            returns[~missing]
        )

    return returns
def _postprocess_predictions(
    predictions: np.ndarray,
    last_price: float,
    use_returns: bool = True
) -> np.ndarray:
    """Map model-space predictions back to price levels.

    Args:
        predictions: Predicted log returns, or predicted log prices when
            ``use_returns`` is False
        last_price: Last observed price; the anchor for compounding returns
        use_returns: Whether ``predictions`` holds returns

    Returns:
        Predicted prices
    """
    if not use_returns:
        # Log prices: simply invert the log transform
        return np.exp(predictions)

    # Log returns: accumulate them and compound from the last known price
    growth = np.exp(np.cumsum(predictions))
    return last_price * growth
def _forecast_chronos(
    context: np.ndarray,
    horizon: int,
    model_name: str = "amazon/chronos-bolt-tiny"
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Forecast with a Chronos pipeline, loading and caching it on first use.

    The pipeline is stored in the module-level ``_chronos_pipeline``; once it
    is set, ``model_name`` is not consulted again.

    Args:
        context: Historical context (preprocessed)
        horizon: Forecast horizon
        model_name: Chronos model to load when no pipeline is cached yet

    Returns:
        Tuple of (median_forecast, lower_bound, upper_bound), where the bounds
        are the 0.025 and 0.975 quantiles taken over the first axis of the
        pipeline output for the first batch item

    Raises:
        RuntimeError: If the Chronos library is not available
    """
    global _chronos_pipeline

    if not CHRONOS_AVAILABLE:
        raise RuntimeError("Chronos library not available")

    # Lazy one-time load; the result is cached globally
    if _chronos_pipeline is None:
        logger.info(f"Loading Chronos model: {model_name}")
        try:
            # Prefer GPU (with bfloat16) and fall back to CPU (float32)
            if torch.cuda.is_available():
                device, dtype = "cuda", torch.bfloat16
                logger.info("GPU detected - using CUDA for inference")
            else:
                device, dtype = "cpu", torch.float32
                logger.info("No GPU detected - using CPU for inference")

            _chronos_pipeline = BaseChronosPipeline.from_pretrained(
                model_name,
                device_map=device,
                dtype=dtype,
            )
            logger.info(f"Chronos model loaded successfully on {device}")
        except Exception as e:
            logger.error(f"Failed to load Chronos model: {e}")
            raise

    # Place the input on whatever device the model parameters live on
    model_device = next(_chronos_pipeline.model.parameters()).device
    context_tensor = torch.tensor(context, dtype=torch.float32).to(model_device)

    try:
        forecast = _chronos_pipeline.predict(
            context_tensor,
            prediction_length=horizon,
        )

        # Bring the result to host memory and keep the first batch item.
        # NOTE(review): assumes the output is [batch, num_samples, horizon];
        # confirm that the loaded pipeline returns sample paths rather than
        # a fixed set of quantile levels on that axis.
        draws = forecast.cpu().numpy()[0]

        # Summarise across the first axis: centre and a nominal 95% band
        lower = np.quantile(draws, 0.025, axis=0)
        upper = np.quantile(draws, 0.975, axis=0)
        median = np.median(draws, axis=0)

        return median, lower, upper
    except Exception as e:
        logger.error(f"Chronos forecasting failed: {e}")
        raise
def _forecast_ttm(
    context: np.ndarray,
    horizon: int,
    context_length: int = 512
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate forecast using IBM Tiny Time Mixers.

    The model is loaded on first use and cached in the module-level
    ``_ttm_model``. The prediction length requested from the loader is the
    smallest configured length that covers ``horizon``; if a cached model
    turns out to produce fewer steps than ``horizon`` it is reloaded with a
    sufficient length. The cached model keeps the ``context_length`` it was
    first loaded with.

    Args:
        context: Historical context (preprocessed returns)
        horizon: Forecast horizon
        context_length: Model context length (512, 1024, or 1536)

    Returns:
        Tuple of (forecast, lower_bound, upper_bound), each of length
        ``horizon``

    Raises:
        RuntimeError: If the TTM library is not available
        ValueError: If ``horizon`` exceeds the largest configured prediction
            length, or the model still returns fewer steps than requested
    """
    global _ttm_model

    if not TTM_AVAILABLE:
        raise RuntimeError("TTM library not available")

    # Pick the smallest configured prediction length that covers the horizon.
    # Choosing the *nearest* length could select one shorter than the horizon,
    # which truncates the predictions and breaks the bounds arithmetic below.
    supported_lengths = (96, 192, 336, 720)
    if horizon > supported_lengths[-1]:
        raise ValueError(
            f"TTM forecast horizon {horizon} exceeds the maximum "
            f"prediction length {supported_lengths[-1]}"
        )
    prediction_length = next(p for p in supported_lengths if p >= horizon)

    def load_model():
        """Load a TTM model for ``prediction_length`` and place it on a device."""
        logger.info("Loading TTM model")
        try:
            model = get_ttm_model(
                model_path="ibm-granite/granite-timeseries-ttm-r2",
                context_length=context_length,
                prediction_length=prediction_length,
            )
            # Move to appropriate device
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
            model.eval()
            logger.info(f"TTM model loaded on {device}")
            return model
        except Exception as e:
            logger.error(f"Failed to load TTM model: {e}")
            raise

    def run_model(model, batch):
        """Run one forward pass without gradients and return a numpy array."""
        with torch.no_grad():
            output = model(batch.to(next(model.parameters()).device))
        return output.prediction_outputs.cpu().numpy()

    # Load model on first use
    if _ttm_model is None:
        _ttm_model = load_model()

    try:
        # Pad or truncate context to match the model's context length
        if len(context) >= context_length:
            input_context = context[-context_length:]
        else:
            # Pad the front with the mean when the context is shorter
            pad_length = context_length - len(context)
            input_context = np.concatenate([
                np.full(pad_length, np.mean(context)),
                context
            ])

        # Reshape to (batch, context_length, channels)
        input_tensor = torch.tensor(
            input_context.reshape(1, -1, 1),
            dtype=torch.float32
        )

        raw = run_model(_ttm_model, input_tensor)

        # A model cached for an earlier, shorter horizon cannot serve this
        # request; replace it with one that covers the horizon and retry once.
        if raw.shape[1] < horizon:
            logger.info(f"Reloading TTM model for horizon={horizon}")
            _ttm_model = load_model()
            raw = run_model(_ttm_model, input_tensor)

        if raw.shape[1] < horizon:
            raise ValueError(
                f"TTM returned {raw.shape[1]} steps, fewer than the "
                f"requested horizon {horizon}"
            )

        predictions = raw[0, :horizon, 0]

        # TTM output carries no uncertainty here; estimate it from the
        # context's standard deviation, widening with the square root of the
        # step index.
        historical_std = np.std(context)
        uncertainty = historical_std * np.sqrt(np.arange(1, horizon + 1))

        lower = predictions - 1.96 * uncertainty
        upper = predictions + 1.96 * uncertainty

        return predictions, lower, upper
    except Exception as e:
        logger.error(f"TTM forecasting failed: {e}")
        raise
def _forecast_nhits(
    context: np.ndarray,
    horizon: int,
    max_steps: int = 100
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Fit a small N-HiTS model on the context and forecast ``horizon`` steps.

    A fresh model is trained on every call.

    Args:
        context: Historical context (preprocessed returns)
        horizon: Forecast horizon
        max_steps: Training steps (lower = faster)

    Returns:
        Tuple of (forecast, lower_bound, upper_bound)

    Raises:
        RuntimeError: If the NeuralForecast library is not available
    """
    if not NHITS_AVAILABLE:
        raise RuntimeError("NeuralForecast library not available")

    try:
        n_obs = len(context)

        # NeuralForecast wants a long-format frame; the series has no real
        # timestamps here, so a daily index ending now is synthesised.
        frame = pd.DataFrame({
            'unique_id': 'series',
            'ds': pd.date_range(end=pd.Timestamp.now(), periods=n_obs, freq='D'),
            'y': context.astype(np.float32)
        })

        # Look-back window: up to three horizons, bounded by the data we have
        lookback = min(n_obs - 1, horizon * 3)

        # Lightweight configuration: 3 stacks with 1 block each
        nhits = NHITS(
            h=horizon,
            input_size=lookback,
            max_steps=max_steps,
            n_blocks=[1, 1, 1],
            mlp_units=[[256, 256], [256, 256], [256, 256]],
            n_pool_kernel_size=[2, 2, 1],
            n_freq_downsample=[4, 2, 1],
            learning_rate=1e-3,
            batch_size=32,
            scaler_type='robust',
            loss=MAE(),
            accelerator='auto',
            enable_progress_bar=False,
            logger=False,
        )

        # Train, then forecast
        forecaster = NeuralForecast(models=[nhits], freq='D')
        forecaster.fit(df=frame)
        predicted = forecaster.predict()

        predictions = predicted['NHITS'].values

        # Bounds come from the context's standard deviation, widening with
        # the square root of the step index.
        uncertainty = np.std(context) * np.sqrt(np.arange(1, horizon + 1))
        half_width = 1.96 * uncertainty
        lower = predictions - half_width
        upper = predictions + half_width

        logger.info(f"N-HiTS forecast completed (steps={max_steps})")

        return predictions, lower, upper
    except Exception as e:
        logger.error(f"N-HiTS forecasting failed: {e}")
        raise
def _combine_forecasts(
    forecasts: Dict[str, np.ndarray],
    method: str = "mean",
    weights: Optional[Dict[str, float]] = None
) -> np.ndarray:
    """Combine multiple forecasts using specified method.

    Args:
        forecasts: Dictionary of model forecasts (all of the same length)
        method: Combination method (mean, median, weighted)
        weights: Optional weights for weighted averaging, keyed by model
            name. Models without an entry get weight 0; when omitted, all
            models are weighted equally. Weights are normalised over the
            models present in ``forecasts``.

    Returns:
        Combined forecast

    Raises:
        ValueError: If ``forecasts`` is empty, the weights of the supplied
            models sum to zero, or ``method`` is not recognised
    """
    if not forecasts:
        raise ValueError("No forecasts to combine")

    names = list(forecasts)
    # One row per model
    forecast_array = np.array([forecasts[name] for name in names])

    if method == "mean":
        return np.mean(forecast_array, axis=0)

    if method == "median":
        return np.median(forecast_array, axis=0)

    if method == "weighted":
        if weights is None:
            raw_weights = np.ones(len(names), dtype=float)
        else:
            raw_weights = np.array(
                [weights.get(name, 0) for name in names], dtype=float
            )

        if raw_weights.sum() == 0:
            raise ValueError("Total weight is zero")

        # np.average normalises by the weight sum and returns a floating
        # result, so integer-valued forecasts combine without casting errors.
        return np.average(forecast_array, axis=0, weights=raw_weights)

    raise ValueError(f"Unknown combination method: {method}")
def _as_decimals(values) -> List[Decimal]:
    """Convert numeric values to Decimal through their string representation."""
    return [Decimal(str(value)) for value in values]


# The compute path lives in a plain synchronous function so that the GPU and
# retry decorators wrap the code that actually runs the models. Decorators
# stacked above ``@mcp.tool()`` only wrap whatever that call returns; the
# function registered as the tool is the one ``mcp.tool()`` itself receives.
@spaces.GPU(duration=120)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    # Exception already covers TimeoutError and ConnectionError, so listing
    # them separately added nothing; the retry policy is unchanged.
    retry=retry_if_exception_type(Exception),
    # Surface the last real error rather than tenacity's RetryError wrapper.
    reraise=True,
)
def _run_ensemble_forecast(request: ForecastRequest) -> ForecastResult:
    """Run every available model, combine the outputs and build the result.

    Args:
        request: Validated forecast request

    Returns:
        Ensemble forecast in price space

    Raises:
        RuntimeError: If no model produced a forecast
    """
    try:
        # Convert prices to numpy
        prices_array = np.array([float(p) for p in request.prices])

        # Preprocess to returns or log prices
        context = _preprocess_financial_series(prices_array, use_returns=request.use_returns)

        # The helpers compute their bounds at a fixed nominal 95% level
        # (0.025/0.975 quantiles or +/-1.96 sigma); make a mismatch visible.
        if abs(request.confidence_level - 0.95) > 1e-9:
            logger.warning(
                f"Requested confidence_level={request.confidence_level} but "
                "model bounds are computed at a fixed nominal 95% level"
            )

        # (result key, label for success log, label for failure log,
        #  availability flag, forecasting function)
        model_runners = (
            ("chronos", "Chronos-Bolt Tiny", "Chronos", CHRONOS_AVAILABLE, _forecast_chronos),
            ("ttm", "TTM", "TTM", TTM_AVAILABLE, _forecast_ttm),
            ("nhits", "N-HiTS", "N-HiTS", NHITS_AVAILABLE, _forecast_nhits),
        )

        individual_forecasts = {}
        individual_lower = {}
        individual_upper = {}
        models_used = []

        # Best effort: a failing model is logged and skipped so the others
        # can still contribute to the ensemble.
        for key, done_label, fail_label, available, run_model in model_runners:
            if not available:
                continue
            try:
                pred, lower, upper = run_model(context, request.forecast_horizon)
            except Exception as e:
                logger.warning(f"{fail_label} forecast failed: {e}")
                continue
            individual_forecasts[key] = pred
            individual_lower[key] = lower
            individual_upper[key] = upper
            models_used.append(key)
            logger.info(f"{done_label} forecast completed")

        if not individual_forecasts:
            raise RuntimeError("All forecasting models failed")

        # Combine point forecasts and bounds with the same ensemble method
        combine_kwargs = {
            "method": request.ensemble_method,
            "weights": request.weights,
        }
        ensemble_pred = _combine_forecasts(individual_forecasts, **combine_kwargs)
        ensemble_lower = _combine_forecasts(individual_lower, **combine_kwargs)
        ensemble_upper = _combine_forecasts(individual_upper, **combine_kwargs)

        # Post-process back to price space
        last_price = float(prices_array[-1])

        def to_prices(series):
            return _postprocess_predictions(
                series,
                last_price,
                use_returns=request.use_returns
            )

        # Convert individual forecasts to price space for output
        individual_forecasts_price = {
            name: _as_decimals(to_prices(forecast))
            for name, forecast in individual_forecasts.items()
        }

        result = ForecastResult(
            ticker=request.ticker,
            method=f"ensemble_{request.ensemble_method}",
            forecast_horizon=request.forecast_horizon,
            predictions=_as_decimals(to_prices(ensemble_pred)),
            lower_bound=_as_decimals(to_prices(ensemble_lower)),
            upper_bound=_as_decimals(to_prices(ensemble_upper)),
            confidence_level=request.confidence_level,
            models_used=models_used,
            individual_forecasts=individual_forecasts_price,
            metadata={
                "num_models": str(len(models_used)),
                "ensemble_method": request.ensemble_method,
                "use_returns": str(request.use_returns),
                "chronos_available": str(CHRONOS_AVAILABLE),
                "ttm_available": str(TTM_AVAILABLE),
                "nhits_available": str(NHITS_AVAILABLE),
            }
        )

        logger.info(
            f"Ensemble forecast complete for {request.ticker}: "
            f"{len(models_used)} models, method={request.ensemble_method}"
        )

        return result

    except Exception as e:
        logger.error(f"Ensemble forecasting error for {request.ticker}: {e}")
        raise


@mcp.tool()
async def forecast_ensemble(request: ForecastRequest) -> ForecastResult:
    """Generate ensemble forecast for time series.

    Combines the available forecasting models (Chronos, Tiny Time Mixers,
    N-HiTS) to produce predictions with uncertainty bounds.

    Args:
        request: Forecast request with prices and parameters

    Returns:
        Ensemble forecast with confidence intervals

    Raises:
        ValueError: If fewer than 10 historical prices are supplied
        RuntimeError: If every forecasting model fails

    Example:
        >>> await forecast_ensemble(ForecastRequest(
        ...     ticker="AAPL",
        ...     prices=[150.0, 151.0, 152.5, ...],
        ...     forecast_horizon=30
        ... ))
    """
    logger.info(f"Generating ensemble forecast for {request.ticker} (horizon={request.forecast_horizon})")

    # Checked before entering the retried compute path: bad input is
    # deterministic, so retrying it would only add delay.
    if len(request.prices) < 10:
        message = "Need at least 10 historical prices for forecasting"
        logger.error(f"Ensemble forecasting error for {request.ticker}: {message}")
        raise ValueError(message)

    return _run_ensemble_forecast(request)
if __name__ == "__main__":
    # Executed as a script: start serving the registered MCP tools.
    mcp.run()