bert / 模型匯入方式.txt
AlanRex's picture
Upload 模型匯入方式.txt
d715e4f
是一種「黑箱」方法,您這邊的程式碼只需要提供原始數據,而不用管模型內部用了哪些特徵,所有複雜的處理都由另一端完成。這是一個很好的軟體設計概念,可以讓您的程式碼更乾淨、更具彈性。
要實現這個目標,最好的做法是將所有與模型相關的程式碼(載入模型、選擇特徵、正規化、預測)都封裝在一個獨立的 Python 檔案中。我會提供兩個檔案的程式碼範例:
model_predictor.py:這個檔案由您的組員或您來維護,它包含了所有模型預測的邏輯。
HUGING_FACE_V1.0.py:您的主程式,它只需要簡單地匯入 model_predictor.py 中的函式並使用即可。
第一步:建立 model_predictor.py
這個檔案假設模型是您的組員訓練的,並且他們知道模型需要哪些特徵。您可以請他們將這個檔案的內容替換成他們實際的模型和資料處理邏輯。
這個範例假設模型需要 Close 和 Volume 這兩個特徵,並需要 60 天的歷史數據。
請將以下程式碼儲存為一個新檔案,命名為 model_predictor.py:
Python
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
# 這個函式將所有模型相關的邏輯封裝起來
def predict_with_lstm_model(df_historical):
"""
使用預訓練的 LSTM 模型對歷史數據進行預測。
參數:
df_historical (pd.DataFrame): 包含歷史股價數據的 DataFrame。
回傳:
tuple: (預測值, 狀態訊息)
"""
# === 這段程式碼由您的組員提供和維護,您不需要修改 ===
# 這裡定義模型所需的特徵。請您的組員根據實際訓練模型時使用的特徵來修改這裡。
features_to_use = ['Close', 'Volume']
sequence_length = 60 # 模型需要60天的歷史數據
try:
# 1. 確保數據包含所有所需特徵
if not all(feature in df_historical.columns for feature in features_to_use):
missing_features = [f for f in features_to_use if f not in df_historical.columns]
return None, f"原始數據中缺少模型所需特徵: {', '.join(missing_features)}。"
# 2. 準備輸入數據
if len(df_historical) < sequence_length:
return None, "數據長度不足以進行預測。"
df_predict_input = df_historical[features_to_use].tail(sequence_length)
data_to_predict = df_predict_input.values
# 3. 正規化數據(假設訓練時使用 MinMaxScaler)
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data_to_predict)
# 4. 重塑為模型期望的形狀:[樣本數, 序列長度, 特徵數]
X_test = np.reshape(scaled_data, (1, scaled_data.shape[0], scaled_data.shape[1]))
# 5. 載入模型
model = tf.keras.models.load_model('stock_lstm_model_v2.keras')
# 6. 進行預測
predicted_scaled = model.predict(X_test, verbose=0)
# 7. 反正規化,將預測結果轉換回原始價格範圍
dummy_array = np.zeros(shape=(1, len(features_to_use)))
dummy_array[0, 0] = predicted_scaled[0, 0]
prediction = scaler.inverse_transform(dummy_array)[0, 0]
return prediction, f"台指期模型預測下一個交易日收盤價為:{prediction:.2f}點。"
except Exception as e:
print(f"載入或預測模型時發生錯誤: {e}")
return None, "模型載入或預測失敗,請檢查模型檔案或輸入資料。"
第二步:修改您的 HUGING_FACE_V1.0.py
現在,您的主程式將變得非常簡潔。您只需要在開頭匯入 predict_with_lstm_model 函式,並在適當的回調函式中呼叫它即可。
請用以下程式碼替換您的 HUGING_FACE_V1.0.py 內容:
Python
# 系統套件
import os
from datetime import datetime, timedelta
# 數據處理
import pandas as pd
import numpy as np
import yfinance as yf
# Dash & Plotly
from dash import Dash, dcc, html, callback, Input, Output
import dash
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# === 匯入您組員提供的模型預測函式 ===
from model_predictor import predict_with_lstm_model
# 台股代號對應表 (移除台指期,因為它現在是獨立區塊)
TAIWAN_STOCKS = {
'台積電': '2330.TW',
'聯發科': '2454.TW',
'鴻海': '2317.TW',
'台塑': '1301.TW',
'中華電': '2412.TW',
'富邦金': '2881.TW',
'國泰金': '2882.TW',
'台達電': '2308.TW',
'統一': '1216.TW',
'日月光': '2311.TW',
'長榮': '2306.TW',
'慧洋-KY': '2637.TW',
'上銀': '2049.TW',
'台泥': '1101.TW',
'譜瑞-KY': '4966.TW',
'貿聯-KY': '3665.TW'
}
# 產業分類
INDUSTRY_MAPPING = {
'2330.TW': '半導體',
'2454.TW': '半導體',
'2317.TW': '電子',
'1301.TW': '塑化',
'2412.TW': '通訊',
'2881.TW': '金融',
'2882.TW': '金融',
'2308.TW': '電子',
'1216.TW': '食品',
'2311.TW': '半導體',
'2306.TW': '航運',
'2637.TW': '航運',
'2049.TW': '機械',
'1101.TW': '水泥',
'4966.TW': '半導體',
'3665.TW': '電子'
}
# 輔助函式: 獲取股價數據
def get_stock_data(symbol, start, end):
try:
df = yf.download(symbol, start=start, end=end)
return df
except Exception as e:
print(f"下載數據時發生錯誤: {e}")
return pd.DataFrame()
# 輔助函式: 計算技術指標
def calculate_technical_indicators(df):
df['MA5'] = df['Close'].rolling(window=5).mean()
df['MA20'] = df['Close'].rolling(window=20).mean()
delta = df['Close'].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
rs = gain / loss
df['RSI'] = 100 - (100 / (1 + rs))
ema12 = df['Close'].ewm(span=12, adjust=False).mean()
ema26 = df['Close'].ewm(span=26, adjust=False).mean()
df['MACD'] = ema12 - ema26
df['Signal'] = df['MACD'].ewm(span=9, adjust=False).mean()
df['MACD_Hist'] = df['MACD'] - df['Signal']
df['Upper_BB'] = df['MA20'] + 2 * df['Close'].rolling(window=20).std()
df['Lower_BB'] = df['MA20'] - 2 * df['Close'].rolling(window=20).std()
df['RSV'] = ((df['Close'] - df['Low'].rolling(window=9).min()) / (df['High'].rolling(window=9).max() - df['Low'].rolling(window=9).min())) * 100
df['K'] = df['RSV'].ewm(alpha=1/3, adjust=False).mean()
df['D'] = df['K'].ewm(alpha=1/3, adjust=False).mean()
df['Williams_%R'] = ((df['High'].rolling(window=14).max() - df['Close']) / (df['High'].rolling(window=14).max() - df['Low'].rolling(window=14).min())) * -100
return df
# 應用程式啟動
app = dash.Dash(__name__)
# 應用程式佈局
app.layout = html.Div(style={'font-family': 'Arial, sans-serif', 'background-color': '#f0f2f5', 'padding': '20px'}, children=[
html.H1("台股股價與金融數據分析儀表板", style={'text-align': 'center', 'color': '#333'}),
html.Div([
html.Div([
html.H3("股票選擇與時間範圍", style={'color': '#444'}),
html.Div([
html.Label('選擇股票:', style={'font-weight': 'bold', 'margin-right': '10px'}),
dcc.Dropdown(
id='stock-dropdown',
options=[{'label': name, 'value': symbol} for name, symbol in TAIWAN_STOCKS.items()],
value='2330.TW',
style={'width': '80%'}
)
], style={'display': 'flex', 'align-items': 'center', 'margin-bottom': '10px'}),
html.Div([
html.Label('開始日期:', style={'font-weight': 'bold', 'margin-right': '10px'}),
dcc.DatePickerSingle(
id='start-date-picker',
initial_visible_month=datetime.now() - timedelta(days=365),
date=datetime.now() - timedelta(days=365)
)
], style={'display': 'flex', 'align-items': 'center', 'margin-bottom': '10px'}),
html.Div([
html.Label('結束日期:', style={'font-weight': 'bold', 'margin-right': '10px'}),
dcc.DatePickerSingle(
id='end-date-picker',
initial_visible_month=datetime.now(),
date=datetime.now()
)
], style={'display': 'flex', 'align-items': 'center'})
], className='card', style={'flex': '1'}),
html.Div([
html.H3("AI深度學習預測 (台指期)", style={'color': '#444'}),
html.Div(id='lstm-prediction-text', style={'font-size': '20px', 'font-weight': 'bold', 'margin-top': '15px'})
], className='card', style={'flex': '1', 'margin-left': '20px'}),
], style={'display': 'flex', 'justify-content': 'space-between', 'margin-bottom': '20px'}),
html.Div([
html.H3("技術分析", style={'color': '#444'}),
dcc.Graph(id='candlestick-chart', style={'height': '600px'}),
dcc.Graph(id='sub-chart-1', style={'height': '300px'}),
dcc.Graph(id='sub-chart-2', style={'height': '300px'}),
dcc.Graph(id='sub-chart-3', style={'height': '300px'}),
], className='card', style={'margin-bottom': '20px'}),
html.Div([
html.Div([
html.H3("產業分析", style={'color': '#444'}),
html.P(id='industry-text', style={'font-size': '16px'}),
html.Div(id='industry-gauge', style={'margin-top': '20px'})
], className='card', style={'flex': '1'}),
html.Div([
html.H3("新聞摘要", style={'color': '#444'}),
html.Div(id='news-section', style={'font-size': '16px', 'max-height': '300px', 'overflow-y': 'auto'})
], className='card', style={'flex': '1', 'margin-left': '20px'})
], style={'display': 'flex', 'justify-content': 'space-between'}),
])
# 回調函式: 更新所有圖表和資訊
@app.callback(
Output('candlestick-chart', 'figure'),
Output('sub-chart-1', 'figure'),
Output('sub-chart-2', 'figure'),
Output('sub-chart-3', 'figure'),
Output('industry-text', 'children'),
Output('industry-gauge', 'children'),
Output('news-section', 'children'),
Output('lstm-prediction-text', 'children'),
Input('stock-dropdown', 'value'),
Input('start-date-picker', 'date'),
Input('end-date-picker', 'date')
)
def update_stock_info(selected_stock, start_date, end_date):
start_date_obj = datetime.strptime(start_date, '%Y-%m-%d')
end_date_obj = datetime.strptime(end_date, '%Y-%m-%d')
# 獲取主要股票數據
df = get_stock_data(selected_stock, start_date_obj, end_date_obj)
if df.empty:
return (go.Figure(), go.Figure(), go.Figure(), go.Figure(),
"無法獲取數據,請檢查股票代號或時間範圍。",
html.Div(), "無法獲取數據,請檢查網路或API。",
"無法進行預測,因為缺乏歷史數據。")
df = calculate_technical_indicators(df)
# 創建主圖 (K線圖)
candlestick_fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
row_heights=[0.7, 0.3])
candlestick_fig.add_trace(go.Candlestick(x=df.index, open=df['Open'], high=df['High'],
low=df['Low'], close=df['Close'], name='K線圖'),
row=1, col=1)
candlestick_fig.add_trace(go.Bar(x=df.index, y=df['Volume'], name='成交量', marker_color='rgba(158,202,225,0.8)'),
row=2, col=1)
candlestick_fig.update_layout(title='股價K線圖與成交量', yaxis_title='價格', xaxis_rangeslider_visible=False, height=600)
# 創建子圖 1 (MACD)
macd_fig = go.Figure()
macd_fig.add_trace(go.Bar(x=df.index, y=df['MACD_Hist'], name='MACD柱狀圖', marker_color=np.where(df['MACD_Hist'] > 0, '#4CAF50', '#FF5733')))
macd_fig.add_trace(go.Scatter(x=df.index, y=df['MACD'], mode='lines', name='MACD線', line=dict(color='#337AB7')))
macd_fig.add_trace(go.Scatter(x=df.index, y=df['Signal'], mode='lines', name='信號線', line=dict(color='#FFC300')))
macd_fig.update_layout(title='MACD指標', xaxis_rangeslider_visible=False)
# 創建子圖 2 (RSI 與 KD)
rsi_kd_fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1)
rsi_kd_fig.add_trace(go.Scatter(x=df.index, y=df['RSI'], mode='lines', name='RSI', line=dict(color='#8A2BE2')), row=1, col=1)
rsi_kd_fig.add_trace(go.Scatter(x=df.index, y=df['K'], mode='lines', name='K值', line=dict(color='orange')), row=2, col=1)
rsi_kd_fig.add_trace(go.Scatter(x=df.index, y=df['D'], mode='lines', name='D值', line=dict(color='blue')), row=2, col=1)
rsi_kd_fig.update_layout(title='RSI與KD指標', xaxis_rangeslider_visible=False)
# 創建子圖 3 (布林通道)
bb_fig = go.Figure()
bb_fig.add_trace(go.Candlestick(x=df.index, open=df['Open'], high=df['High'], low=df['Low'], close=df['Close'], name='K線圖'))
bb_fig.add_trace(go.Scatter(x=df.index, y=df['Upper_BB'], name='布林上軌', line=dict(color='red', width=1, dash='dash')))
bb_fig.add_trace(go.Scatter(x=df.index, y=df['Lower_BB'], name='布林下軌', line=dict(color='green', width=1, dash='dash')))
bb_fig.add_trace(go.Scatter(x=df.index, y=df['MA20'], name='中軌', line=dict(color='blue', width=1, dash='solid')))
bb_fig.update_layout(title='布林通道', xaxis_rangeslider_visible=False)
# 產業分析與新聞摘要
industry = INDUSTRY_MAPPING.get(selected_stock, '未知')
industry_text = f"此為{selected_stock} ({list(TAIWAN_STOCKS.keys())[list(TAIWAN_STOCKS.values()).index(selected_stock)]}),隸屬於{industry}產業。"
gauge_value = (df['RSI'].iloc[-1]) if not df['RSI'].isnull().all() else 50
gauge_fig = go.Figure(go.Indicator(
mode="gauge+number",
value=gauge_value,
title={'text': "相對強弱指標 (RSI)"},
gauge={
'axis': {'range': [None, 100]},
'steps': [
{'range': [0, 30], 'color': "lightcoral"},
{'range': [30, 70], 'color': "lightgray"},
{'range': [70, 100], 'color': "lightgreen"}
],
'threshold': {
'line': {'color': "red", 'width': 4},
'thickness': 0.75,
'value': 90
}
}
))
gauge_fig.update_layout(height=200, margin=dict(l=20, r=20, t=40, b=20))
gauge_html = dcc.Graph(figure=gauge_fig)
# 模擬新聞摘要
stock_name = [name for name, symbol in TAIWAN_STOCKS.items() if symbol == selected_stock][0]
news_items = [
f"📈 {stock_name}獲外資調升目標價,看好後續發展前景",
f"💼 法人預期{stock_name}下季營收將較上季成長5-10%",
f"🌐 國際市場波動對{stock_name}影響有限,基本面穩健",
f"⚡ 產業景氣回溫,{stock_name}受惠程度值得關注",
f"📊 技術面顯示{stock_name}突破關鍵壓力,短線偏多"
]
news_content = html.Div([html.P(news) for news in news_items])
# 處理台指期預測
# 這裡我們只負責獲取數據,然後將其傳入預測函式
df_futures = get_stock_data('@TX.F.TW', start_date_obj - timedelta(days=90), end_date_obj)
prediction_value, prediction_text = predict_with_lstm_model(df_futures)
return (candlestick_fig, macd_fig, rsi_kd_fig, bb_fig,
industry_text, gauge_html, news_content,
f"{prediction_text}")
if __name__ == '__main__':
app.run_server(debug=True)
這樣一來,您就可以完全將注意力放在應用程式的介面和功能上,而不用擔心模型內部如何運作。當您的組員更新模型時,他們只需要修改 model_predictor.py 檔案,您的主程式則不需要做任何變動,這大大簡化了後續的維護工作。