import os
import datetime as dt

import pandas as pd
import torch
import gradio as gr
import yfinance as yf
from chronos import BaseChronosPipeline  # from 'chronos-forecasting'

# ---- Global cache: load each model only once and reuse it across requests ----
_PIPELINE_CACHE = {}

# Quantile levels reported in the table and plot: lower band, median, upper band.
_QUANTILE_LEVELS = (0.1, 0.5, 0.9)


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``model_id`` on ``device``, loading it on first use.

    Pipelines are memoised in ``_PIPELINE_CACHE`` keyed by ``(model_id, device)``
    so repeated forecasts do not reload the weights.
    """
    key = (model_id, device)
    if key not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[key] = BaseChronosPipeline.from_pretrained(
            model_id,
            device_map=device,  # "cpu" / "cuda" (Spaces default is cpu)
            # float32 on CPU; bfloat16 on accelerators to reduce memory use.
            torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
        )
    return _PIPELINE_CACHE[key]


# ---- Price data loading (yfinance) ----
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d") -> pd.Series:
    """Download closing prices for ``ticker`` as a 1-D float Series indexed by date.

    Korean listings use an exchange suffix, e.g. ``005930.KS`` (Samsung Electronics).

    Raises:
        ValueError: if no data was returned or there are no usable ``Close`` values.
    """
    df = yf.download(str(ticker).strip(), start=start, end=end, interval=interval, progress=False)
    if df is None or df.empty or "Close" not in df:
        raise ValueError("데이터가 없거나 'Close' 열을 찾을 수 없습니다. 티커/날짜를 확인하세요.")

    close = df["Close"]
    # Recent yfinance releases return (Price, Ticker) MultiIndex columns even for a
    # single ticker, making df["Close"] a one-column DataFrame. Without this, .values
    # would be 2-D (N, 1) and the model would see N series of length 1.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]

    s = close.dropna().astype(float)
    if s.empty:
        raise ValueError("데이터가 없거나 'Close' 열을 찾을 수 없습니다. 티커/날짜를 확인하세요.")
    return s


def _to_numpy(t: torch.Tensor):
    """Convert a model output tensor into a float32 NumPy array on the CPU."""
    # .numpy() rejects CUDA and bfloat16 tensors, so normalise both first.
    return t.detach().to(device="cpu", dtype=torch.float32).numpy()


def _forecast_quantiles(pipe, series: pd.Series, horizon: int) -> pd.DataFrame:
    """Forecast ``horizon`` steps past ``series``; return q10/q50/q90 per step."""
    # Chronos input: a 1-D float tensor holding the history.
    context = torch.tensor(series.to_numpy(), dtype=torch.float32)
    # Ask for the exact levels we display. Chronos-Bolt's plain predict() returns
    # every quantile level the model was trained on, so indexing its output as
    # [0], [1], [2] picks the three lowest levels rather than q10/q50/q90.
    # predict_quantiles returns (quantiles, mean); quantiles is
    # (batch=1, horizon, len(quantile_levels)).
    quantiles, _mean = pipe.predict_quantiles(
        context,
        prediction_length=horizon,
        quantile_levels=list(_QUANTILE_LEVELS),
    )
    q = _to_numpy(quantiles[0])  # (horizon, 3)
    return pd.DataFrame(
        {"q10": q[:, 0], "q50": q[:, 1], "q90": q[:, 2]},
        index=pd.RangeIndex(1, horizon + 1, name="step"),
    )


def _future_index(history_index: pd.DatetimeIndex, horizon: int) -> pd.DatetimeIndex:
    """Build ``horizon`` timestamps that follow the last history timestamp."""
    # Exchanges are closed on weekends: if the history has no weekend rows, step by
    # business days so forecast points land on trading days. Assets that also trade
    # on weekends (e.g. crypto) keep a calendar-day step.
    trades_on_weekends = bool((history_index.dayofweek >= 5).any())
    freq = "D" if trades_on_weekends else "B"
    # The range starts at the last observed date, which is then dropped.
    return pd.date_range(history_index[-1], periods=horizon + 1, freq=freq)[1:]


def _make_figure(ticker, series: pd.Series, future_index, df_fcst: pd.DataFrame):
    """Plot price history, the median forecast, and the q10–q90 band."""
    from matplotlib.figure import Figure

    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure in a global registry until plt.close(), which leaks memory in a
    # long-running server, and pyplot's global state is not thread-safe.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(series.index, series.to_numpy(), label="history")
    ax.plot(future_index, df_fcst["q50"].to_numpy(), label="forecast(q50)")
    ax.fill_between(
        future_index,
        df_fcst["q10"].to_numpy(),
        df_fcst["q90"].to_numpy(),
        alpha=0.2,
        label="q10–q90",
    )
    ax.set_title(f"{ticker} forecast by Chronos-Bolt")
    ax.legend()
    fig.tight_layout()
    return fig


# ---- Forecast callback (invoked by Gradio) ----
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Load prices, forecast ``horizon`` steps, and build the three UI outputs.

    Returns:
        A ``(figure_or_None, quantile_table, message)`` tuple matching the
        ``(plot, table, note)`` outputs wired to the button. On failure the plot
        is cleared, the table is empty, and the message describes the error.
    """
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:  # UI boundary: show any download/parse failure to the user
        # gr.Plot.update() was removed in Gradio 4; returning None clears the plot.
        return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"

    H = int(horizon)
    try:
        pipe = get_pipeline(model_id, device)
        df_fcst = _forecast_quantiles(pipe, series, H)
    except Exception as e:  # UI boundary: model download / inference failures
        return None, pd.DataFrame(), f"예측 오류: {e}"

    fig = _make_figure(ticker, series, _future_index(series.index, H), df_fcst)

    note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
    return fig, df_fcst, note


# ---- Gradio UI ----
with gr.Blocks(title="Chronos Stock Forecast") as demo:
    gr.Markdown("# Chronos 주가 예측 데모")
    with gr.Row():
        ticker = gr.Textbox(value="AAPL", label="티커 (예: AAPL, MSFT, 005930.KS)")
        horizon = gr.Slider(5, 60, value=20, step=1, label="예측 길이 H (일)")
    with gr.Row():
        start = gr.Textbox(
            value=(dt.date.today() - dt.timedelta(days=365)).isoformat(),
            label="시작일 (YYYY-MM-DD)",
        )
        end = gr.Textbox(value=dt.date.today().isoformat(), label="종료일 (YYYY-MM-DD)")
    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="모델",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
        interval = gr.Dropdown(choices=["1d"], value="1d", label="간격")
    btn = gr.Button("예측 실행")
    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="예측 결과 (분위수)")
    note = gr.Markdown()

    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )

if __name__ == "__main__":
    demo.launch()