# File-viewer snapshot header: size 4,133 bytes, revision d24798b.
# (The line-number gutter copied from the viewer has been dropped.)
import os
import datetime as dt
import pandas as pd
import torch
import gradio as gr
import yfinance as yf
from chronos import BaseChronosPipeline # from 'chronos-forecasting'
# ---- Global cache: load each model only once and reuse it ----
_PIPELINE_CACHE = {}


def get_pipeline(model_id: str, device: str = "cpu"):
    """Return the Chronos pipeline for ``(model_id, device)``, loading it on first use.

    Loaded pipelines are kept in the module-level ``_PIPELINE_CACHE`` dict,
    keyed by the ``(model_id, device)`` tuple, so repeated calls with the same
    arguments return the same object.

    Args:
        model_id: Model identifier handed to ``BaseChronosPipeline.from_pretrained``.
        device: Value passed as ``device_map``; "cpu" / "cuda" (Spaces default is cpu).

    Returns:
        The cached (or newly loaded) pipeline object.
    """
    cache_key = (model_id, device)
    try:
        return _PIPELINE_CACHE[cache_key]
    except KeyError:
        pass

    # float32 on CPU, bfloat16 on every other device.
    dtype = torch.bfloat16 if device != "cpu" else torch.float32
    pipeline = BaseChronosPipeline.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=dtype,
    )
    _PIPELINE_CACHE[cache_key] = pipeline
    return pipeline
# ---- Stock price data loading (yfinance) ----
def load_close_series(ticker: str, start: str, end: str, interval: str = "1d"):
    """Download closing prices for one ticker and return them as a float Series.

    Args:
        ticker: Yahoo Finance symbol. Korean stocks use an exchange suffix,
            e.g. 005930.KS (Samsung Electronics). Surrounding whitespace is
            stripped before the request.
        start: Start date, forwarded to ``yf.download``.
        end: End date, forwarded to ``yf.download``.
        interval: Bar interval, forwarded to ``yf.download``.

    Returns:
        A 1-D ``pandas.Series`` of closing prices cast to float, with missing
        values dropped, indexed as returned by yfinance.

    Raises:
        ValueError: If the download is empty, has no 'Close' column, or the
            'Close' column contains only missing values.
    """
    error_message = "데이터가 없거나 'Close' 열을 찾을 수 없습니다. 티커/날짜를 확인하세요."

    df = yf.download(
        ticker.strip(), start=start, end=end, interval=interval, progress=False
    )
    if df is None or df.empty or "Close" not in df:
        raise ValueError(error_message)

    close = df["Close"]
    # When the columns are a (field, ticker) MultiIndex -- which recent yfinance
    # releases produce even for a single symbol -- df["Close"] is a one-column
    # DataFrame. Collapse it so callers always get a 1-D Series; a 2-D (N, 1)
    # array would otherwise be fed to the model as N separate series.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]

    s = close.dropna().astype(float)
    if s.empty:
        # Every close was NaN: nothing to forecast from.
        raise ValueError(error_message)
    return s
# ---- Forecast function (called by Gradio) ----
def run_forecast(ticker, start_date, end_date, horizon, model_id, device, interval):
    """Forecast future closing prices for one ticker and build the UI outputs.

    Args:
        ticker: Symbol passed to ``load_close_series``.
        start_date: Start of the history window, passed to ``load_close_series``.
        end_date: End of the history window, passed to ``load_close_series``.
        horizon: Number of future steps to predict; converted with ``int``.
        model_id: Model identifier passed to ``get_pipeline``.
        device: Device name passed to ``get_pipeline``.
        interval: Bar interval passed to ``load_close_series``.

    Returns:
        A ``(figure, table, note)`` tuple: a matplotlib figure showing the
        history, the median forecast and the q10-q90 band; a DataFrame with
        columns ``q10``/``q50``/``q90`` indexed by ``step`` 1..H; and a note
        string. If loading the data fails, returns
        ``(None, empty DataFrame, error message)`` instead.
    """
    try:
        series = load_close_series(ticker, start_date, end_date, interval)
    except Exception as e:  # UI boundary: surface any loading failure as text
        # Plain None clears the plot output; the gr.Plot.update(...) helper is
        # not available in current Gradio releases and would raise here.
        return None, pd.DataFrame(), f"데이터 로딩 오류: {e}"

    pipe = get_pipeline(model_id, device)
    H = int(horizon)

    # Chronos input: 1-D float tensor. Flatten so a single-column (N, 1) frame
    # is also fed to the model as one series of length N.
    history = series.to_numpy(dtype=float).reshape(-1)
    context = torch.tensor(history, dtype=torch.float32)

    # Request exactly the three levels we label. Chronos-Bolt's plain
    # predict() returns every trained quantile (0.1, 0.2, ..., 0.9), so its
    # first three rows would be q10/q20/q30 rather than q10/q50/q90.
    # Output: quantiles has shape (num_series=1, H, num_levels=3).
    quantiles, _mean = pipe.predict_quantiles(
        context,
        prediction_length=H,
        quantile_levels=[0.1, 0.5, 0.9],
    )
    q10 = quantiles[0, :, 0].detach().cpu().float().numpy()
    q50 = quantiles[0, :, 1].detach().cpu().float().numpy()
    q90 = quantiles[0, :, 2].detach().cpu().float().numpy()

    # Table data
    df_fcst = pd.DataFrame(
        {"q10": q10, "q50": q50, "q90": q90},
        index=pd.RangeIndex(1, H + 1, name="step"),
    )

    # Plot. A standalone Figure is not registered in pyplot's global state, so
    # repeated requests do not pile up open figures in a long-running server.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(series.index, history, label="history")

    # Build the future x-axis: closes are daily, so use a 'D' frequency.
    # NOTE(review): 'D' includes weekends, while exchange data usually has
    # trading days only -- confirm whether a business-day axis is wanted.
    future_index = pd.date_range(series.index[-1], periods=H + 1, freq="D")[1:]
    ax.plot(future_index, q50, label="forecast(q50)")
    ax.fill_between(future_index, q10, q90, alpha=0.2, label="q10–q90")
    ax.set_title(f"{ticker} forecast by Chronos-Bolt")
    ax.legend()
    fig.tight_layout()

    note = "※ 데모 목적입니다. 투자 판단의 책임은 본인에게 있습니다."
    return fig, df_fcst, note
# ---- Gradio UI ----
# Layout: three input rows, then the run button and the three outputs.
# NOTE(review): the row nesting below is reconstructed from the statement
# order -- confirm against the intended layout.
with gr.Blocks(title="Chronos Stock Forecast") as demo:
    gr.Markdown("# Chronos 주가 예측 데모")

    with gr.Row():
        ticker = gr.Textbox(value="AAPL", label="티커 (예: AAPL, MSFT, 005930.KS)")
        horizon = gr.Slider(5, 60, value=20, step=1, label="예측 길이 H (일)")

    with gr.Row():
        # Defaults are computed once, when the UI is built: the last 365 days
        # up to today's date.
        start = gr.Textbox(
            value=(dt.date.today() - dt.timedelta(days=365)).isoformat(),
            label="시작일 (YYYY-MM-DD)",
        )
        end = gr.Textbox(
            value=dt.date.today().isoformat(),
            label="종료일 (YYYY-MM-DD)",
        )

    with gr.Row():
        model_id = gr.Dropdown(
            choices=[
                "amazon/chronos-bolt-tiny",
                "amazon/chronos-bolt-mini",
                "amazon/chronos-bolt-small",
                "amazon/chronos-bolt-base",
            ],
            value="amazon/chronos-bolt-small",
            label="모델",
        )
        device = gr.Dropdown(choices=["cpu"], value="cpu", label="디바이스")
        interval = gr.Dropdown(choices=["1d"], value="1d", label="간격")

    btn = gr.Button("예측 실행")
    plot = gr.Plot(label="History + Forecast")
    table = gr.Dataframe(label="예측 결과 (분위수)")
    note = gr.Markdown()

    # Input order must match run_forecast's positional parameters.
    btn.click(
        fn=run_forecast,
        inputs=[ticker, start, end, horizon, model_id, device, interval],
        outputs=[plot, table, note],
    )

if __name__ == "__main__":
    demo.launch()
|