oscarzhang committed
Commit fc66e35 · verified · 1 parent: 4460a10

Upload folder using huggingface_hub
README.md CHANGED
@@ -1,3 +1,178 @@
- ---
- license: apache-2.0
- ---
+ # Wearable_TimeSeries_Health_Monitor
+
+ A multi-user health-monitoring solution for wearable devices: one model plus one config file is enough to build personalized anomaly detection for different users. The model is based on a **Phased LSTM + Temporal Fusion Transformer (TFT)** and integrates adaptive baselines, factor features, and second-granularity sliding windows over incoming data, so it can be dropped in quickly as a Hugging Face model or an internal enterprise service.
+
+ ---
+
+ ## 🌟 Highlights
+
+ | Capability | Description |
+ | --- | --- |
+ | **Plug and play** | The bundled `WearableAnomalyDetector` wrapper predicts as soon as the model is loaded; after one initialization it can keep monitoring multiple users |
+ | **Config-driven features** | `configs/features_config.json` describes every feature, its default value, and its category mapping; adding or removing SpO2, respiratory rate, etc. only takes a config change |
+ | **Multi-user real-time serving** | `FeatureCalculator` plus a lightweight `data_storage` cache handle per-user history management, baseline evolution, and batch inference |
+ | **Multi-scenario demo** | `test_wearable_service.py` ships three realistic "customer" cases (full sensors, missing fields, anonymous device), so you can try it even without raw data |
+ | **Adaptive-baseline support** | `UserDataManager` can be extended to feed personal/group baselines into the inference flow, continually improving per-user sensitivity |
+
+ ---
+
+ ## 📊 Core metrics (short-term window)
+
+ - **F1**: 0.2819
+ - **Precision**: 0.1769
+ - **Recall**: 0.6941
+ - **Best threshold**: 0.53
+ - **Window definition**: 12 records at 5-minute intervals (a 1-hour input window, predicting the next 0.5 hours)
+
+ > The model is recall-oriented, suited to "alert first, human-in-the-loop review afterwards" workflows. Precision and recall can be traded off via the threshold and the sampling strategy.
+
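The threshold/recall trade-off described in the note above can be illustrated with a small, self-contained sketch. The scores and labels below are synthetic examples, not real model output, and `prf` is a hypothetical helper, not code from this repository:

```python
# Hypothetical illustration: raising the decision threshold trades
# recall for precision on anomaly scores in [0, 1].
def prf(scores, labels, threshold):
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Synthetic scores/labels for demonstration only.
scores = [0.2, 0.4, 0.55, 0.6, 0.7, 0.9]
labels = [0, 0, 0, 1, 1, 1]
low = prf(scores, labels, 0.53)   # permissive: catches every anomaly, more false alarms
high = prf(scores, labels, 0.65)  # strict: fewer false alarms, misses one anomaly
```

On this toy data, the permissive threshold reaches full recall at lower precision, while the strict one flips that balance, mirroring the recall-first design choice above.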
+ ---
+
+ ## 🚀 Quick start
+
+ ### 1. Clone or download the model repository
+
+ ```bash
+ git clone https://huggingface.co/oscarzhang/Wearable_TimeSeries_Health_Monitor
+ cd Wearable_TimeSeries_Health_Monitor
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Run the built-in demo (no extra data needed)
+
+ ```bash
+ # Run the default ab60 case
+ python test_wearable_service.py
+
+ # Run all predefined customer cases in one batch
+ python test_wearable_service.py --case all
+
+ # Sample test data from the raw stage1 CSVs
+ python test_wearable_service.py --from-raw
+ ```
+
+ `test_wearable_service.py` will automatically:
+ - load `WearableAnomalyDetector`
+ - read the config-driven features
+ - build windows and run predictions
+ - print each "customer"'s anomaly score, threshold, and prediction details
+
+ ### 3. Calling it from your own code
+
+ ```python
+ from wearable_anomaly_detector import WearableAnomalyDetector
+
+ detector = WearableAnomalyDetector(
+     model_dir="checkpoints/phase2/exp_factor_balanced",
+     threshold=0.53,
+ )
+
+ result = detector.predict(data_points, return_score=True, return_details=True)
+ print(result)
+ ```
+
+ > `data_points` is the 12 most recent 5-minute records; if static features or device info are missing, the system fills them in automatically from the config/cache.
+
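A `data_points` list like the one expected above can be assembled as follows. This is a minimal sketch using the field names shown in this README; all sensor values here are synthetic placeholders:

```python
# Build the 12 most recent 5-minute records in the documented input schema.
# Values are synthetic; real deployments would read them from the device feed.
from datetime import datetime, timedelta

start = datetime(2024, 1, 1, 8, 0, 0)
data_points = [
    {
        "timestamp": (start + timedelta(minutes=5 * i)).isoformat(),
        "deviceId": "ab60",
        "features": {
            "hr": 72.0 + i,  # synthetic heart rate drift
            "hrv_rmssd": 30.0,
            "time_period_primary": "morning",
            "data_quality": "high",
        },
    }
    for i in range(12)
]
```

Twelve 5-minute records span exactly the 1-hour short-term window the detector consumes.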
+ ---
+
+ ## 🔧 Input and output
+
+ ### Input (a single data point)
+
+ ```python
+ {
+     "timestamp": "2024-01-01T08:00:00",
+     "deviceId": "ab60",          # optional; an anonymous ID is created when missing
+     "features": {
+         "hr": 72.0,
+         "hrv_rmssd": 30.0,
+         "time_period_primary": "morning",
+         "data_quality": "high",
+         ...
+     }
+ }
+ ```
+
+ - Each window needs 12 data points (1 hour by default)
+ - Whether a feature is required is controlled by `configs/features_config.json`
+ - Missing values automatically fall back to the `default` or to the value defined in `category_mapping`
+
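The fallback behavior in the last bullet can be sketched on its own. The config entries below are abridged from `configs/features_config.json` in this commit; the `coerce` helper is an illustration, not the repository's implementation:

```python
# Documented fallback: a missing value becomes the feature's `default`;
# a categorical string is resolved through `category_mapping`.
FEATURES = {
    "hr": {"default": 70.0},
    "data_quality": {
        "default": 0.9,
        "category_mapping": {"low": 0.3, "medium": 0.6, "high": 1.0},
    },
}

def coerce(name, value):
    cfg = FEATURES[name]
    if value is None:
        return cfg["default"]          # missing -> default
    mapping = cfg.get("category_mapping")
    if isinstance(value, str) and mapping:
        return mapping.get(value, cfg["default"])  # unknown category -> default
    return float(value)

coerce("hr", None)              # falls back to 70.0
coerce("data_quality", "high")  # mapped to 1.0
```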
+ ### Output
+
+ ```python
+ {
+     "is_anomaly": True,
+     "anomaly_score": 0.5760,
+     "threshold": 0.5300,
+     "details": {
+         "window_size": 12,
+         "model_output": 0.5760,
+         "prediction_confidence": 0.0460
+     }
+ }
+ ```
+
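A caller might consume the result dict above like this. Note that in the sample output, `prediction_confidence` happens to equal `anomaly_score - threshold` (0.5760 − 0.5300 = 0.0460); treating it as that margin is an assumption here, not something the repository documents:

```python
# Consuming the result dict shown above (values copied from the sample output).
result = {
    "is_anomaly": True,
    "anomaly_score": 0.5760,
    "threshold": 0.5300,
    "details": {
        "window_size": 12,
        "model_output": 0.5760,
        "prediction_confidence": 0.0460,  # assumed to be score minus threshold
    },
}

flagged = result["anomaly_score"] >= result["threshold"]
margin = round(result["anomaly_score"] - result["threshold"], 4)
```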
+ ---
+
+ ## 🧱 Model architecture and training
+
+ - **Backbone**: a Phased LSTM handles irregularly sampled sequences; a Temporal Fusion Transformer aggregates temporal context
+ - **Anomaly-detection head**: enhanced attention, a multi-layer MLP, and optional contrastive-learning/type auxiliary heads
+ - **Feature families**:
+   - Physiological: HR, HRV (RMSSD/SDNN/PNN50, ...)
+   - Activity: steps, distance, energy expenditure, accelerometer, gyroscope
+   - Environment: light, day/night labels, data quality
+   - Baseline: adaptive baseline mean/std plus deviation features
+ - **Label sources**: high-confidence questionnaire labels plus low-confidence adaptive-baseline labels
+ - **Training pipeline**: stage 1/2/3 data processing ➜ phase 1 self-supervised pretraining ➜ phase 2 supervised fine-tuning ➜ threshold/case calibration
+
+ ---
+
+ ## 📦 Repository layout (partial)
+
+ ```
+ ├─ configs/
+ │   └─ features_config.json        # feature definitions & normalization strategy
+ ├─ wearable_anomaly_detector.py    # core wrapper: loading, prediction, batching
+ ├─ feature_calculator.py           # config-driven feature building + per-user history cache
+ ├─ test_wearable_service.py        # Hugging Face demo script (with predefined cases)
+ └─ checkpoints/phase2/...          # model weights & summary
+ ```
+
+ ---
+
+ ## 📚 Data source and license
+
+ - The training data come from **"A continuous real-world dataset comprising wearable-based heart rate variability alongside sleep diaries"** (Baigutanova *et al.*, Scientific Data, 2025) and its Figshare dataset: [doi:10.1038/s41597-025-05801-3](https://www.nature.com/articles/s41597-025-05801-3) / [dataset link](https://springernature.figshare.com/articles/dataset/In-situ_wearable-based_dataset_of_continuous_heart_rate_variability_monitoring_accompanied_by_sleep_diaries/28509740).
+ - The dataset is released under **Creative Commons Attribution 4.0 (CC BY 4.0)**: free to use, modify, and redistribute, provided attribution and a link to the license are kept.
+ - This repository follows the CC BY 4.0 requirements for the original data; if you further process or republish it, please keep the attribution and license notice above.
+ - The code/model may use a license such as MIT or Apache as needed, but anything involving the data must still comply with CC BY 4.0.
+
+ ---
+
+ ## 🤝 Contributing and extending
+
+ Contributions are welcome:
+ 1. Add features or data sources ⇒ update `features_config.json` and open a PR
+ 2. Plug in new user-data management or baseline strategies ⇒ extend `FeatureCalculator` or contribute a `UserDataManager`
+ 3. Share cases or real deployment experience ⇒ open an Issue or Discussion
+
+ ---
+
+ ## 📄 License
+
+ To be determined (replace as the project requires).
+
+ ---
+
+ ## 🔖 Citation
+
+ ```bibtex
+ @software{Wearable_TimeSeries_Health_Monitor,
+   title  = {Wearable\_TimeSeries\_Health\_Monitor},
+   author = {oscarzhang},
+   year   = {2025},
+   url    = {https://huggingface.co/oscarzhang/Wearable_TimeSeries_Health_Monitor}
+ }
+ ```
+
checkpoints/phase2/exp_factor_balanced/best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f2f056ea3cec48902ffda2399e905189dce62826034470bb6514f8739eba9ff
+ size 27270610
configs/features_config.json ADDED
@@ -0,0 +1,108 @@
+ {
+   "metadata": {
+     "version": "1.0",
+     "description": "Wearable anomaly detection feature configuration"
+   },
+   "time_series": [
+     {"name": "hr", "enabled": true, "default": 70.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hr_resting", "enabled": true, "default": 65.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hrv_rmssd", "enabled": true, "default": 30.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hrv_sdnn", "enabled": true, "default": 40.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hrv_pnn50", "enabled": true, "default": 15.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "sdnn", "enabled": true, "default": 35.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "sdsd", "enabled": true, "default": 25.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "rmssd", "enabled": true, "default": 30.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "pnn20", "enabled": true, "default": 25.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "pnn50", "enabled": true, "default": 12.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "ibi", "enabled": true, "default": 0.86, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "lf/hf", "enabled": true, "default": 1.8, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "steps", "enabled": true, "default": 20.0, "normalization": {"type": "minmax", "min": 0.0, "max": 500.0}},
+     {"name": "distance", "enabled": true, "default": 10.0, "normalization": {"type": "minmax", "min": 0.0, "max": 2000.0}},
+     {"name": "calories", "enabled": true, "default": 1.5, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "acc_x_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "acc_y_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "acc_z_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "grv_x_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "grv_y_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "grv_z_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "grv_w_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "gyr_x_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "gyr_y_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "gyr_z_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "light_avg", "enabled": true, "default": 100.0, "normalization": {"type": "minmax", "min": 0.0, "max": 1000.0}},
+     {
+       "name": "time_period_primary",
+       "enabled": true,
+       "default": 2.0,
+       "normalization": {"type": "none"},
+       "category_mapping": {
+         "night": 0,
+         "morning": 1,
+         "day": 2,
+         "evening": 3,
+         "unknown": 4
+       }
+     },
+     {
+       "name": "time_period_secondary",
+       "enabled": true,
+       "default": 7.0,
+       "normalization": {"type": "none"},
+       "category_mapping": {
+         "commute_morning": 0,
+         "breakfast": 1,
+         "work_morning": 2,
+         "lunch": 3,
+         "work_afternoon": 4,
+         "commute_evening": 5,
+         "dinner": 6,
+         "rest_evening": 7,
+         "rest_night": 8,
+         "exercise": 9,
+         "unknown": 10
+       }
+     },
+     {"name": "is_weekend", "enabled": true, "default": 0.0, "normalization": {"type": "none"}},
+     {
+       "name": "data_quality",
+       "enabled": true,
+       "default": 0.9,
+       "normalization": {"type": "minmax", "min": 0.0, "max": 1.0},
+       "category_mapping": {
+         "low": 0.3,
+         "medium": 0.6,
+         "high": 1.0
+       }
+     },
+     {"name": "missingness_score", "enabled": true, "default": 0.0, "normalization": {"type": "minmax", "min": 0.0, "max": 1.0}},
+     {"name": "baseline_hrv_mean", "enabled": true, "default": 30.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "baseline_hrv_std", "enabled": true, "default": 5.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hrv_deviation_abs", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hrv_deviation_pct", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
+     {"name": "hrv_z_score", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}}
+   ],
+   "static": [
+     {"name": "age_group", "enabled": true, "default": -1},
+     {"name": "age_normalized", "enabled": true, "default": 0.5},
+     {"name": "sex", "enabled": true, "default": 0.5},
+     {"name": "marriage", "enabled": true, "default": -1},
+     {"name": "exercise", "enabled": true, "default": -1},
+     {"name": "coffee", "enabled": true, "default": -1},
+     {"name": "smoking", "enabled": true, "default": -1},
+     {"name": "drinking", "enabled": true, "default": -1},
+     {"name": "MEQ", "enabled": true, "default": 0.0},
+     {"name": "baseline_commute_morning_mean", "enabled": true, "default": 30.0},
+     {"name": "baseline_commute_morning_std", "enabled": true, "default": 5.0}
+   ],
+   "factor_features": {
+     "enabled": true,
+     "factor_names": ["physio", "activity", "context"],
+     "factor_dim": 4
+   },
+   "known_future": [
+     {"name": "hour_of_day", "enabled": true},
+     {"name": "day_of_week", "enabled": true},
+     {"name": "is_weekend", "enabled": true}
+   ]
+ }
+
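The two normalization modes declared in `features_config.json` (`zscore`, which pulls its statistics from `norm_params.json` when `use_norm_params` is true, and `minmax` with fixed bounds) can be sketched standalone. The statistics and bounds below are copied from the configs in this commit; the `normalize` function itself is an illustration, not the repository's code:

```python
# zscore stats for hrv_rmssd, copied from processed_data/stage3/norm_params.json.
NORM_PARAMS = {"hrv_rmssd": {"mean": 83.4627685546875, "std": 62.30027389526367}}

def normalize(name, value, cfg):
    kind = cfg["normalization"]["type"]
    if kind == "zscore":
        stats = NORM_PARAMS[name]
        return (value - stats["mean"]) / (stats["std"] or 1.0)
    if kind == "minmax":
        lo, hi = cfg["normalization"]["min"], cfg["normalization"]["max"]
        return (value - lo) / max(hi - lo, 1e-6)
    return value  # type "none": leave the value unchanged

z = normalize("hrv_rmssd", 30.0,
              {"normalization": {"type": "zscore", "use_norm_params": True}})
m = normalize("steps", 250.0,
              {"normalization": {"type": "minmax", "min": 0.0, "max": 500.0}})
```

With these settings, 250 steps lands at the midpoint of the min-max range, and an RMSSD of 30 ms normalizes to a negative z-score (below the dataset mean of ~83.46).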
feature_calculator.py ADDED
@@ -0,0 +1,273 @@
+ import json
+ from pathlib import Path
+ from typing import Dict, List, Optional, Any, Tuple
+ from collections import defaultdict
+ from datetime import datetime
+
+ import numpy as np
+ import pandas as pd
+
+
+ class FeatureCalculator:
+     """
+     Loads feature definitions from the config file and builds the window
+     structures needed for inference/training.
+     """
+
+     def __init__(
+         self,
+         config_path: Optional[Path] = None,
+         norm_params_path: Optional[Path] = None,
+         static_features_path: Optional[Path] = None,
+         storage_dir: Optional[Path] = None,
+     ):
+         base_dir = Path(__file__).parent
+         self.config_path = Path(config_path or base_dir / "configs" / "features_config.json")
+         self.norm_params_path = Path(norm_params_path or base_dir / "processed_data" / "stage3" / "norm_params.json")
+         self.static_features_path = Path(static_features_path or base_dir / "processed_data" / "stage2" / "static_features.csv")
+         self.storage_dir = Path(storage_dir or base_dir / "data_storage")
+         self.storage_dir.mkdir(parents=True, exist_ok=True)
+
+         self.features_config = self._load_json(self.config_path)
+         self.norm_params = self._load_json(self.norm_params_path) if self.norm_params_path.exists() else {}
+         self.static_features_dict = self._load_static_features(self.static_features_path)
+
+         self.time_series_features = [f for f in self.features_config.get("time_series", []) if f.get("enabled", True)]
+         self.static_feature_defs = [f for f in self.features_config.get("static", []) if f.get("enabled", True)]
+         self.known_future_defs = [f for f in self.features_config.get("known_future", []) if f.get("enabled", True)]
+         factor_cfg = self.features_config.get("factor_features", {})
+         self.factor_enabled = factor_cfg.get("enabled", False)
+         self.factor_names = factor_cfg.get("factor_names", [])
+         self.factor_dim = factor_cfg.get("factor_dim", 0)
+
+         # Simple in-memory history cache; makes it easy to add personalized features later
+         self.user_histories: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
+
+     @staticmethod
+     def _load_json(path: Path) -> Dict:
+         if not path.exists():
+             return {}
+         with open(path, "r") as f:
+             return json.load(f)
+
+     @staticmethod
+     def _load_static_features(static_file: Path) -> Dict[str, Dict]:
+         if not static_file.exists():
+             return {}
+         df = pd.read_csv(static_file)
+         static_dict = {}
+         for _, row in df.iterrows():
+             raw_id = row.get("deviceId")
+             if pd.notna(raw_id):  # skip rows without a device ID
+                 static_dict[str(raw_id)] = {
+                     col: row[col]
+                     for col in df.columns
+                     if col != "deviceId"
+                 }
+         return static_dict
+
+     @staticmethod
+     def _to_serializable(value):
+         if isinstance(value, np.integer):
+             return int(value)
+         if isinstance(value, np.floating):
+             return float(value)
+         if isinstance(value, (pd.Timestamp, datetime)):
+             return value.isoformat()
+         if isinstance(value, np.ndarray):
+             return value.tolist()
+         raise TypeError(f"Object of type {type(value)} is not JSON serializable")
+
+     def register_data_points(self, user_id: str, data_points: List[Dict]):
+         """
+         Lightweight cache of user data; also appends to data_storage/users/{user_id}.jsonl
+         """
+         if not user_id:
+             return
+         user_dir = self.storage_dir / "users"
+         user_dir.mkdir(exist_ok=True, parents=True)
+         history_file = user_dir / f"{user_id}.jsonl"
+
+         with history_file.open("a", encoding="utf-8") as f:
+             for point in data_points:
+                 serializable = dict(point)
+                 ts = serializable.get("timestamp")
+                 if hasattr(ts, "isoformat"):  # pd.Timestamp or datetime
+                     serializable["timestamp"] = ts.isoformat()
+                 f.write(json.dumps(serializable, ensure_ascii=False, default=self._to_serializable) + "\n")
+
+         self.user_histories[user_id].extend(data_points)
+         # Keep only the most recent 5,000 points in memory to bound usage
+         if len(self.user_histories[user_id]) > 5000:
+             self.user_histories[user_id] = self.user_histories[user_id][-5000:]
+
+     def normalize_series(self, values: List[float], feature_name: str, cfg: Dict) -> List[float]:
+         arr = np.array(values, dtype=np.float32)
+         norm_cfg = cfg.get("normalization", {"type": "none"})
+         norm_type = norm_cfg.get("type", "none")
+
+         if norm_type == "zscore":
+             mean, std = self._get_norm_stats(feature_name, norm_cfg)
+             if std == 0:
+                 std = 1.0
+             arr = (arr - mean) / std
+         elif norm_type == "minmax":
+             min_v = norm_cfg.get("min", 0.0)
+             max_v = norm_cfg.get("max", 1.0)
+             scale = max(max_v - min_v, 1e-6)
+             arr = (arr - min_v) / scale
+         # type "none": leave values unchanged
+
+         arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
+         return arr.tolist()
+
+     @staticmethod
+     def _coerce_value(value, feat_cfg):
+         default = feat_cfg.get("default", 0.0)
+         if value is None or pd.isna(value):
+             return default
+         category_mapping = feat_cfg.get("category_mapping")
+         if isinstance(value, str):
+             if category_mapping:
+                 return category_mapping.get(value, default)
+             try:
+                 return float(value)
+             except ValueError:
+                 return default
+         try:
+             return float(value)
+         except (TypeError, ValueError):
+             return default
+
+     def _get_norm_stats(self, feature_name: str, norm_cfg: Dict) -> Tuple[float, float]:
+         if norm_cfg.get("use_norm_params") and feature_name in self.norm_params:
+             stats = self.norm_params[feature_name]
+             return stats.get("mean", 0.0), stats.get("std", 1.0)
+         return norm_cfg.get("mean", 0.0), norm_cfg.get("std", 1.0)
+
+     def build_window(self, data_points: List[Dict], user_id: Optional[str] = None) -> Dict:
+         if len(data_points) < 12:
+             raise ValueError("Not enough data points: at least 12 are needed to build a short-term window")
+
+         if user_id:
+             self.register_data_points(user_id, data_points)
+
+         timestamps = []
+         input_features = {feat["name"]: [] for feat in self.time_series_features}
+
+         for point in data_points:
+             ts = point.get("timestamp")
+             if isinstance(ts, str):
+                 ts = pd.to_datetime(ts)
+             timestamps.append(ts)
+
+             feature_payload = point.get("features", {})
+             for feat_cfg in self.time_series_features:
+                 name = feat_cfg["name"]
+                 value = feature_payload.get(name)
+                 value = self._coerce_value(value, feat_cfg)
+                 input_features[name].append(value)
+
+         # delta_t: seconds between consecutive points (0.0 for the first)
+         delta_t = [0.0]
+         for i in range(1, len(timestamps)):
+             diff = (timestamps[i] - timestamps[i - 1]).total_seconds()
+             delta_t.append(float(diff))
+
+         # Normalization
+         normalized_features = {}
+         for feat_cfg in self.time_series_features:
+             name = feat_cfg["name"]
+             normalized_features[name] = self.normalize_series(input_features[name], name, feat_cfg)
+
+         static_features = self._build_static_features(data_points[0], user_id)
+         factor_features = self._build_factor_features(normalized_features)
+         known_future = self._build_known_future(timestamps[-6:] if len(timestamps) >= 6 else timestamps)
+
+         return {
+             "input_timestamp": timestamps[:12],
+             "input_delta_t": delta_t[:12],
+             "input_features": normalized_features,
+             "target_timestamp": timestamps[12:] if len(timestamps) > 12 else [],
+             "target_delta_t": delta_t[12:] if len(delta_t) > 12 else [],
+             "static_features": static_features,
+             "known_future_features": known_future,
+             "factor_features": factor_features,
+         }
+
+     def _build_static_features(self, first_point: Dict, user_id: Optional[str]) -> Dict:
+         static_payload = dict(first_point.get("static_features", {}))
+         device_id = first_point.get("deviceId") or user_id
+
+         if device_id and str(device_id) in self.static_features_dict:
+             for key, value in self.static_features_dict[str(device_id)].items():
+                 static_payload.setdefault(key, value)
+
+         result = {}
+         for feat_cfg in self.static_feature_defs:
+             name = feat_cfg["name"]
+             result[name] = static_payload.get(name, feat_cfg.get("default", 0.0))
+         return result
+
+     def _build_factor_features(self, normalized_features: Dict[str, List[float]]) -> Optional[Dict[str, List[float]]]:
+         if not self.factor_enabled or not self.factor_names:
+             return None
+
+         factor_vectors = {}
+         for factor_name in self.factor_names:
+             # For now each factor is summarized with simple mean/std/max/min statistics, making it easy to replace later
+             merged = []
+             for feat_name, values in normalized_features.items():
+                 if factor_name == "physio" and feat_name.startswith("hrv"):
+                     merged.extend(values)
+                 elif factor_name == "activity" and feat_name in {"steps", "distance", "calories"}:
+                     merged.extend(values)
+                 elif factor_name == "context" and feat_name in {"time_period_primary", "time_period_secondary", "is_weekend"}:
+                     merged.extend(values)
+
+             if not merged:
+                 factor_vectors[factor_name] = [0.0] * self.factor_dim
+             else:
+                 arr = np.array(merged, dtype=np.float32)
+                 stats = [
+                     float(arr.mean()),
+                     float(arr.std()),
+                     float(arr.max()),
+                     float(arr.min()),
+                 ]
+                 factor_vectors[factor_name] = stats[: self.factor_dim] if len(stats) >= self.factor_dim else stats + [0.0] * (self.factor_dim - len(stats))
+         return factor_vectors
+
+     def _build_known_future(self, timestamps: List[pd.Timestamp]) -> Dict[str, List[float]]:
+         hours, days, weekends = [], [], []
+         for ts in timestamps:
+             if pd.isna(ts):
+                 hours.append(12.0)
+                 days.append(3.0)
+                 weekends.append(0.0)
+             else:
+                 hours.append(float(ts.hour))
+                 days.append(float(ts.weekday()))
+                 weekends.append(float(1 if ts.weekday() >= 5 else 0))
+
+         result = {}
+         for cfg in self.known_future_defs:
+             name = cfg["name"]
+             if name == "hour_of_day":
+                 result[name] = hours
+             elif name == "day_of_week":
+                 result[name] = days
+             elif name == "is_weekend":
+                 result[name] = weekends
+         return result
+
+     def get_enabled_feature_names(self) -> List[str]:
+         return [feat["name"] for feat in self.time_series_features]
+
+
+ __all__ = ["FeatureCalculator"]
+
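The `delta_t` step inside `build_window` can be exercised on its own: it records the seconds between consecutive timestamps, with 0.0 for the first point. This standalone sketch uses plain `datetime` objects (the repository code uses pandas Timestamps, which behave the same for this subtraction):

```python
# Seconds between consecutive timestamps, 0.0 for the first point,
# mirroring the delta_t loop in FeatureCalculator.build_window.
from datetime import datetime

timestamps = [
    datetime.fromisoformat("2024-01-01T08:00:00"),
    datetime.fromisoformat("2024-01-01T08:05:00"),
    datetime.fromisoformat("2024-01-01T08:12:30"),
]
delta_t = [0.0] + [
    (timestamps[i] - timestamps[i - 1]).total_seconds()
    for i in range(1, len(timestamps))
]
```

Irregular gaps (here 300 s then 450 s) are preserved rather than resampled, which is what the Phased LSTM backbone consumes.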
processed_data/stage3/norm_params.json ADDED
@@ -0,0 +1,146 @@
+ {
+   "hr_mean": {
+     "mean": 79.88385009765625,
+     "std": 15.546831130981445,
+     "min": 33.0,
+     "max": 200.2244873046875
+   },
+   "hr_std": {
+     "mean": 12.757049560546875,
+     "std": 3.9224278926849365,
+     "min": 0.0,
+     "max": 32.2431755065918
+   },
+   "hr_median": {
+     "mean": 76.4555892944336,
+     "std": 6.908801555633545,
+     "min": 48.0,
+     "max": 104.0
+   },
+   "hr_resting": {
+     "mean": 65.74867248535156,
+     "std": 7.843548774719238,
+     "min": 44.12284469604492,
+     "max": 86.0
+   },
+   "hr_nrem": {
+     "mean": 61.779720306396484,
+     "std": 11.666051864624023,
+     "min": 0.0,
+     "max": 92.5469970703125
+   },
+   "hrv_rmssd": {
+     "mean": 83.4627685546875,
+     "std": 62.30027389526367,
+     "min": 0.0,
+     "max": 855.8391723632812
+   },
+   "hrv_sdnn": {
+     "mean": 100.59049987792969,
+     "std": 43.545467376708984,
+     "min": 0.0,
+     "max": 393.35162353515625
+   },
+   "steps": {
+     "mean": 342.7657470703125,
+     "std": 823.3682861328125,
+     "min": 0.0,
+     "max": 27004.0
+   },
+   "distance": {
+     "mean": 225.4749755859375,
+     "std": 504.8075866699219,
+     "min": 0.0,
+     "max": 10460.2998046875
+   },
+   "calories": {
+     "mean": 104.05133819580078,
+     "std": 211.85128784179688,
+     "min": 0.0,
+     "max": 2962.070068359375
+   },
+   "sleep_duration_total": {
+     "mean": 418.6901550292969,
+     "std": 142.2774200439453,
+     "min": 0.0,
+     "max": 1110.0
+   },
+   "sleep_efficiency": {
+     "mean": 93.89789581298828,
+     "std": 7.327056884765625,
+     "min": 34.0,
+     "max": 100.0
+   },
+   "sleep_deep_ratio": {
+     "mean": 1.00419020652771,
+     "std": 0.3390481770038605,
+     "min": 0.0,
+     "max": 4.310344696044922
+   },
+   "sleep_rem_ratio": {
+     "mean": 1.00448739528656,
+     "std": 0.35869544744491577,
+     "min": 0.0,
+     "max": 3.9259259700775146
+   },
+   "sleep_light_ratio": {
+     "mean": 0.9923003315925598,
+     "std": 0.23265497386455536,
+     "min": 0.0,
+     "max": 3.034313678741455
+   },
+   "spo2": {
+     "mean": 95.9047622680664,
+     "std": 1.04403817653656,
+     "min": 92.4000015258789,
+     "max": 100.0
+   },
+   "stress_score": {
+     "mean": 65.94886779785156,
+     "std": 28.051528930664062,
+     "min": 0.0,
+     "max": 93.0
+   },
+   "ALERT": {
+     "mean": 0.07375683635473251,
+     "std": 0.2613747715950012,
+     "min": 0.0,
+     "max": 1.0
+   },
+   "HAPPY": {
+     "mean": 0.1726546734571457,
+     "std": 0.37794846296310425,
+     "min": 0.0,
+     "max": 1.0
+   },
+   "NEUTRAL": {
+     "mean": 0.1967589408159256,
+     "std": 0.3975485563278198,
+     "min": 0.0,
+     "max": 1.0
+   },
+   "RESTED/RELAXED": {
+     "mean": 0.23211927711963654,
+     "std": 0.42218467593193054,
+     "min": 0.0,
+     "max": 1.0
+   },
+   "SAD": {
+     "mean": 0.018068943172693253,
+     "std": 0.13320080935955048,
+     "min": 0.0,
+     "max": 1.0
+   },
+   "TENSE/ANXIOUS": {
+     "mean": 0.10590820014476776,
+     "std": 0.3077200949192047,
+     "min": 0.0,
+     "max": 1.0
+   },
+   "TIRED": {
+     "mean": 0.20073312520980835,
+     "std": 0.4005488157272339,
+     "min": 0.0,
+     "max": 1.0
+   }
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch>=2.1.0
+ numpy>=1.24
+ pandas>=2.0
+ huggingface_hub>=0.23
test_wearable_service.py ADDED
@@ -0,0 +1,182 @@
+ import argparse
+ from pathlib import Path
+ from typing import List, Dict, Tuple
+
+ import pandas as pd
+
+ from wearable_anomaly_detector import load_detector
+ from feature_calculator import FeatureCalculator
+
+ # Predefined cases (taken from the raw wearable data)
+ PREDEFINED_CASES: Dict[str, Dict] = {
+     "ab60_morning_rest": {
+         "description": "User ab60: 12 consecutive windows from early-morning rest until just before breakfast, used to quickly verify that the service produces output",
+         "user_id": "ab60",
+         "data_points": [
+ {"timestamp": "2021-03-04T04:45:20.170000", "deviceId": "ab60", "features": {"hr": 91.65860215053765, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.33511196423747, "hrv_sdnn": 72.35486488414405, "hrv_pnn50": 0.3422818791946309, "sdnn": 72.35486488414405, "sdsd": 55.28945952794972, "rmssd": 73.33511196423747, "pnn20": 0.6677852348993288, "pnn50": 0.3422818791946309, "ibi": 671.2685790942928, "lf/hf": 0.6578861348372742, "acc_x_avg": 4.9294712037533515, "acc_y_avg": -3.1057153652814957, "acc_z_avg": 5.8750820100536005, "grv_x_avg": -0.5497846977211797, "grv_y_avg": 0.0042184631367292, "grv_z_avg": 0.1525969041554961, "grv_w_avg": 0.1771413800268097, "gyr_x_avg": -0.9240750428954416, "gyr_y_avg": 1.1994772238605889, "gyr_z_avg": 0.3142024215817695, "light_avg": 440.066889632107, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.2886646919083789}},
+ {"timestamp": "2021-03-04T05:15:20.497000", "deviceId": "ab60", "features": {"hr": 84.15604988203573, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.48193641157741, "hrv_sdnn": 74.51390542231427, "hrv_pnn50": 0.4143302180685358, "sdnn": 74.51390542231427, "sdsd": 51.36438677767364, "rmssd": 73.48193641157741, "pnn20": 0.7538940809968847, "pnn50": 0.4143302180685358, "ibi": 729.7248964415015, "lf/hf": 1.1884595045473076, "acc_x_avg": 5.264426130609508, "acc_y_avg": -2.4949751527126582, "acc_z_avg": 5.638523847957136, "grv_x_avg": -0.2725604313462818, "grv_y_avg": -0.0417301761553922, "grv_z_avg": 0.0857810462156731, "grv_w_avg": 0.2177219383791028, "gyr_x_avg": -0.6168720663094444, "gyr_y_avg": 1.3517548573342275, "gyr_z_avg": 0.3159611446751512, "light_avg": 121.7391304347826, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1872242091224755}},
+ {"timestamp": "2021-03-04T05:35:21.317000", "deviceId": "ab60", "features": {"hr": 82.22642857142857, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.0761458919254, "hrv_sdnn": 81.71063624330793, "hrv_pnn50": 0.45703125, "sdnn": 81.71063624330793, "sdsd": 47.19876056820636, "rmssd": 73.0761458919254, "pnn20": 0.7578125, "pnn50": 0.45703125, "ibi": 799.9492801995793, "lf/hf": 1.7751635086235489, "acc_x_avg": 4.37106831212324, "acc_y_avg": -0.3996270033489605, "acc_z_avg": 6.891127246483591, "grv_x_avg": -0.6645361294433281, "grv_y_avg": -0.3851286666666666, "grv_z_avg": 0.0854148665325285, "grv_w_avg": 0.189048981220657, "gyr_x_avg": -0.1036394095174265, "gyr_y_avg": 1.4141689068364616, "gyr_z_avg": 0.4626943733243945, "light_avg": 603.2266666666667, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.335979916085374}},
+ {"timestamp": "2021-03-04T05:55:21.428000", "deviceId": "ab60", "features": {"hr": 86.76466621712744, "hr_resting": 87.84302108870469, "hrv_rmssd": 74.59311564233765, "hrv_sdnn": 87.31034842159234, "hrv_pnn50": 0.391304347826087, "sdnn": 87.31034842159234, "sdsd": 52.14310859847846, "rmssd": 74.59311564233765, "pnn20": 0.7418478260869565, "pnn50": 0.391304347826087, "ibi": 709.0165592504526, "lf/hf": 2.4824085488532703, "acc_x_avg": 6.8974065639651725, "acc_y_avg": -0.6532005197588735, "acc_z_avg": 4.533895625586072, "grv_x_avg": -0.606806606831882, "grv_y_avg": -0.1301535706630945, "grv_z_avg": 0.0961370341594106, "grv_w_avg": 0.4334408365706628, "gyr_x_avg": -0.9999732337575332, "gyr_y_avg": 1.1268921553918274, "gyr_z_avg": 0.0987407849966511, "light_avg": 428.3612040133779, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1217623103705546}},
+ {"timestamp": "2021-03-04T06:00:21.434000", "deviceId": "ab60", "features": {"hr": 86.11454484380249, "hr_resting": 87.84302108870469, "hrv_rmssd": 80.21713778627182, "hrv_sdnn": 72.4260069463789, "hrv_pnn50": 0.4117647058823529, "sdnn": 72.4260069463789, "sdsd": 57.42532753452254, "rmssd": 80.21713778627182, "pnn20": 0.7536764705882353, "pnn50": 0.4117647058823529, "ibi": 711.989004966488, "lf/hf": 0.5778707099781514, "acc_x_avg": 2.6462390616208977, "acc_y_avg": -2.0968226182183503, "acc_z_avg": 7.492496608841257, "grv_x_avg": -0.3523352029470862, "grv_y_avg": -0.3171756162089749, "grv_z_avg": 0.0852522665773609, "grv_w_avg": 0.0408305612860013, "gyr_x_avg": -0.7581300395442356, "gyr_y_avg": 1.7399128639410202, "gyr_z_avg": 0.5810656782841818, "light_avg": 719.866220735786, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.3055760776711148}},
21
+ {"timestamp": "2021-03-04T06:05:21.530000", "deviceId": "ab60", "features": {"hr": 85.55876427132304, "hr_resting": 87.84302108870469, "hrv_rmssd": 61.68777738949315, "hrv_sdnn": 61.78286083667294, "hrv_pnn50": 0.3786127167630058, "sdnn": 61.78286083667294, "sdsd": 40.15746948668274, "rmssd": 61.68777738949315, "pnn20": 0.7283236994219653, "pnn50": 0.3786127167630058, "ibi": 715.325568241176, "lf/hf": 0.9103296457711838, "acc_x_avg": 3.0594024239785633, "acc_y_avg": -2.580148740120562, "acc_z_avg": 7.2776362170127245, "grv_x_avg": -0.2312993507712944, "grv_y_avg": -0.1165045117370892, "grv_z_avg": 0.1689689483568074, "grv_w_avg": 0.1020506572769953, "gyr_x_avg": -0.2617024135388738, "gyr_y_avg": 1.1428016065683635, "gyr_z_avg": 0.0963672908847178, "light_avg": 665.6321070234113, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1561355447930486}},
22
+ {"timestamp": "2021-03-04T06:15:21.623000", "deviceId": "ab60", "features": {"hr": 82.62592343854936, "hr_resting": 87.84302108870469, "hrv_rmssd": 81.00579057638534, "hrv_sdnn": 73.40584581276266, "hrv_pnn50": 0.4357366771159874, "sdnn": 73.40584581276266, "sdsd": 57.10852463732151, "rmssd": 81.00579057638534, "pnn20": 0.7366771159874608, "pnn50": 0.4357366771159874, "ibi": 746.8545044048512, "lf/hf": 2.2582676863270708, "acc_x_avg": 0.1661593281982583, "acc_y_avg": -1.6491836677829892, "acc_z_avg": 9.736414508372398, "grv_x_avg": -0.6639661513730741, "grv_y_avg": 0.2112439484259877, "grv_z_avg": 0.008919145344943, "grv_w_avg": 0.0059900468854654, "gyr_x_avg": -0.3158673777628935, "gyr_y_avg": 0.9963161399866036, "gyr_z_avg": 0.4727930301406551, "light_avg": 802.180602006689, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.174593188653174}},
23
+ {"timestamp": "2021-03-04T06:40:21.894000", "deviceId": "ab60", "features": {"hr": 83.86055107526882, "hr_resting": 87.84302108870469, "hrv_rmssd": 79.079102105631, "hrv_sdnn": 71.271849209199, "hrv_pnn50": 0.4221556886227545, "sdnn": 71.271849209199, "sdsd": 54.69084849089656, "rmssd": 79.079102105631, "pnn20": 0.7724550898203593, "pnn50": 0.4221556886227545, "ibi": 732.4375213133641, "lf/hf": 0.6677155405784594, "acc_x_avg": 1.7091933630274596, "acc_y_avg": -2.3761690361687893, "acc_z_avg": 8.209631811788354, "grv_x_avg": -0.3922270288010715, "grv_y_avg": 0.2412905398526454, "grv_z_avg": -0.0233138225050234, "grv_w_avg": 0.1066616838580042, "gyr_x_avg": 0.3474681640991295, "gyr_y_avg": 1.1165304762223718, "gyr_z_avg": 0.186041525117213, "light_avg": 729.1103678929766, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.162896032760479}},
24
+ {"timestamp": "2021-03-04T06:45:21.969000", "deviceId": "ab60", "features": {"hr": 85.06473364801079, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.21565395106201, "hrv_sdnn": 69.24046192941114, "hrv_pnn50": 0.4137931034482758, "sdnn": 69.24046192941114, "sdsd": 49.87687480659143, "rmssd": 73.21565395106201, "pnn20": 0.768025078369906, "pnn50": 0.4137931034482758, "ibi": 719.8668191606536, "lf/hf": 1.6620124103207512, "acc_x_avg": 2.2839716548257365, "acc_y_avg": -2.790153060991958, "acc_z_avg": 8.004115201072393, "grv_x_avg": -0.4554428230563004, "grv_y_avg": -0.1927247533512062, "grv_z_avg": 0.0215668719839142, "grv_w_avg": 0.1034467037533509, "gyr_x_avg": -0.0257573813672921, "gyr_y_avg": 1.0048659463806948, "gyr_z_avg": 0.3241018806970504, "light_avg": 755.3177257525084, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1959064930123424}},
25
+ {"timestamp": "2021-03-04T06:50:21.981000", "deviceId": "ab60", "features": {"hr": 84.85642062689585, "hr_resting": 87.84302108870469, "hrv_rmssd": 68.0148861140481, "hrv_sdnn": 76.76458770787994, "hrv_pnn50": 0.3815384615384615, "sdnn": 76.76458770787994, "sdsd": 46.69614203120664, "rmssd": 68.0148861140481, "pnn20": 0.6707692307692308, "pnn50": 0.3815384615384615, "ibi": 729.6592731642268, "lf/hf": 0.9127720565801908, "acc_x_avg": 1.71269043335566, "acc_y_avg": -1.0698834099129273, "acc_z_avg": 9.200909472873422, "grv_x_avg": -0.375676233087742, "grv_y_avg": 0.1822926999330206, "grv_z_avg": 0.0157954038847957, "grv_w_avg": 0.0460910823844608, "gyr_x_avg": -0.1220897675820484, "gyr_y_avg": 0.6891694554588073, "gyr_z_avg": 0.4564768801071655, "light_avg": 847.0335570469799, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1845048437257962}},
26
+ {"timestamp": "2021-03-04T06:55:21.985000", "deviceId": "ab60", "features": {"hr": 80.75881760161236, "hr_resting": 87.84302108870469, "hrv_rmssd": 66.04701786885973, "hrv_sdnn": 76.17603294638347, "hrv_pnn50": 0.3684210526315789, "sdnn": 76.17603294638347, "sdsd": 45.0541425499126, "rmssd": 66.04701786885973, "pnn20": 0.7151702786377709, "pnn50": 0.3684210526315789, "ibi": 757.0256501417455, "lf/hf": 1.269681575994282, "acc_x_avg": 3.739257345612857, "acc_y_avg": -2.4820750622906904, "acc_z_avg": 7.388303849966525, "grv_x_avg": -0.5472615157401197, "grv_y_avg": 0.3134806992632281, "grv_z_avg": 0.0378557434695244, "grv_w_avg": 0.2093843141326185, "gyr_x_avg": -0.5253990328638486, "gyr_y_avg": 1.6931924748490956, "gyr_z_avg": 0.665633814889338, "light_avg": 703.4496644295302, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1579866815850661}},
27
+ {"timestamp": "2021-03-04T07:00:22.020000", "deviceId": "ab60", "features": {"hr": 83.36467427803895, "hr_resting": 87.84302108870469, "hrv_rmssd": 85.01833953484277, "hrv_sdnn": 76.7068088047925, "hrv_pnn50": 0.4434782608695652, "sdnn": 76.7068088047925, "sdsd": 62.38011995356233, "rmssd": 85.01833953484277, "pnn20": 0.7536231884057971, "pnn50": 0.4434782608695652, "ibi": 724.6919175240898, "lf/hf": 1.3219549787885163, "acc_x_avg": 0.4397476164658638, "acc_y_avg": -1.638112914323962, "acc_z_avg": 9.6665646124498, "grv_x_avg": -0.3218441425702814, "grv_y_avg": 0.1350831372155288, "grv_z_avg": 0.0047377771084337, "grv_w_avg": 0.0104064123159303, "gyr_x_avg": -0.2149196720214184, "gyr_y_avg": 0.9163253052208852, "gyr_z_avg": 0.5661378821954464, "light_avg": 861.3110367892976, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1291275275920405}}
28
+ ]
29
+ },
30
+ "nd56_low_activity_sparse": {
31
+ "description": "nd56:多数传感器字段缺失,只提供心率/时间段等基础信息,验证模型默认填充能力",
32
+ "user_id": "nd56",
33
+ "data_points": [
34
+ {"timestamp": "2021-03-04T03:40:55.745000", "deviceId": "nd56", "features": {"hr": 73.3681592039801, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
35
+ {"timestamp": "2021-03-04T03:55:55.963000", "deviceId": "nd56", "features": {"hr": 70.58150365934797, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
36
+ {"timestamp": "2021-03-04T04:30:56.879000", "deviceId": "nd56", "features": {"hr": 79.00866089273818, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
37
+ {"timestamp": "2021-03-04T05:11:22.339000", "deviceId": "nd56", "features": {"hr": 76.88147410358566, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
38
+ {"timestamp": "2021-03-04T06:21:22.941000", "deviceId": "nd56", "features": {"hr": 79.52276503821868, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
39
+ {"timestamp": "2021-03-04T06:36:22.983000", "deviceId": "nd56", "features": {"hr": 75.333, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
40
+ {"timestamp": "2021-03-04T06:41:23.077000", "deviceId": "nd56", "features": {"hr": 75.7566401062417, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
41
+ {"timestamp": "2021-03-04T06:56:23.275000", "deviceId": "nd56", "features": {"hr": 74.49435590969456, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
42
+ {"timestamp": "2021-03-04T07:26:23.418000", "deviceId": "nd56", "features": {"hr": 73.90371845949535, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}},
43
+ {"timestamp": "2021-03-04T07:36:23.463000", "deviceId": "nd56", "features": {"hr": 72.7252911813644, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}},
44
+ {"timestamp": "2021-03-04T07:41:23.508000", "deviceId": "nd56", "features": {"hr": 70.14043824701196, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}},
45
+ {"timestamp": "2021-03-04T07:46:23.547000", "deviceId": "nd56", "features": {"hr": 71.33565737051792, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}}
46
+ ]
47
+ },
48
+ "anon_commuter_minimal": {
49
+ "description": "匿名通勤用户:仅提供心率/HRV/时间段等极简指标,deviceId缺失,��证服务可批量处理不同客户",
50
+ "user_id": "anon_commuter",
51
+ "data_points": [
52
+ {"timestamp": "2021-03-09T00:52:04.300000", "deviceId": None, "features": {"hr": 88.77359119706568, "hrv_rmssd": 68.78080551188198, "hrv_sdnn": 68.55954732717733, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.2340064304816851}},
53
+ {"timestamp": "2021-03-09T00:57:04.328000", "deviceId": None, "features": {"hr": 80.68233333333333, "hrv_rmssd": 59.2934623277283, "hrv_sdnn": 73.79953723523498, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1442517250480751}},
54
+ {"timestamp": "2021-03-09T01:02:04.404000", "deviceId": None, "features": {"hr": 75.80266666666667, "hrv_rmssd": 45.96602248249421, "hrv_sdnn": 86.08949225862196, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0475269119819883}},
55
+ {"timestamp": "2021-03-09T01:07:04.432000", "deviceId": None, "features": {"hr": 74.08533333333334, "hrv_rmssd": 35.124819470890934, "hrv_sdnn": 76.28600536828999, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0362464905334389}},
56
+ {"timestamp": "2021-03-09T01:12:04.441000", "deviceId": None, "features": {"hr": 76.9121411276375, "hrv_rmssd": 72.02894057601205, "hrv_sdnn": 87.11453810833142, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1366772654292948}},
57
+ {"timestamp": "2021-03-09T01:17:04.485000", "deviceId": None, "features": {"hr": 68.27024325224924, "hrv_rmssd": 36.20840500569033, "hrv_sdnn": 61.29313423988413, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": -0.0048301680504103}},
58
+ {"timestamp": "2021-03-09T01:22:04.496000", "deviceId": None, "features": {"hr": 71.42038640906063, "hrv_rmssd": 42.73973712175516, "hrv_sdnn": 100.9222381995624, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0282886513311317}},
59
+ {"timestamp": "2021-03-09T01:27:04.517000", "deviceId": None, "features": {"hr": 78.02364302364302, "hrv_rmssd": 68.8236455702779, "hrv_sdnn": 91.82364754816273, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.133595953991592}},
60
+ {"timestamp": "2021-03-09T01:32:04.589000", "deviceId": None, "features": {"hr": 74.2271818787475, "hrv_rmssd": 47.71422420363224, "hrv_sdnn": 80.68140559056071, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0569492438181573}},
61
+ {"timestamp": "2021-03-09T01:37:04.619000", "deviceId": None, "features": {"hr": 66.9780146568954, "hrv_rmssd": 35.74818419244297, "hrv_sdnn": 50.71532677029013, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0026578073089701}},
62
+ {"timestamp": "2021-03-09T01:47:04.721000", "deviceId": None, "features": {"hr": 74.31312458361093, "hrv_rmssd": 57.44451616196892, "hrv_sdnn": 98.6011860346101, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1253227425948505}},
63
+ {"timestamp": "2021-03-09T01:52:04.735000", "deviceId": None, "features": {"hr": 71.42538307794804, "hrv_rmssd": 50.01054170733226, "hrv_sdnn": 73.22227779949023, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0255565038546025}}
64
+ ]
65
+ }
66
+ }
67
+
68
+
69
+ def load_raw_data_window(
70
+ stage1_file: Path,
71
+ feature_names: List[str],
72
+ window_size: int = 12,
73
+ ) -> Tuple[List[Dict], str]:
+ """按设备分块读取 stage1 CSV,返回第一个凑满窗口的 (data_points, device_id)"""
74
+ stage1_path = Path(stage1_file)
75
+ if not stage1_path.exists():
76
+ raise FileNotFoundError(f"找不到原始数据文件: {stage1_file}")
77
+
78
+ base_cols = pd.read_csv(stage1_path, nrows=0).columns.tolist()
79
+ usecols = ['deviceId', 'ts_start'] + [feat for feat in feature_names if feat in base_cols]
80
+ usecols = list(dict.fromkeys(usecols))
81
+
82
+ buffers: Dict[str, List[Dict]] = {}
83
+ reader = pd.read_csv(
84
+ stage1_path,
85
+ usecols=usecols,
86
+ parse_dates=['ts_start'],
87
+ chunksize=10000,
88
+ )
89
+
90
+ for chunk in reader:
91
+ chunk = chunk.sort_values(['deviceId', 'ts_start'])
92
+ for device_id, group in chunk.groupby('deviceId'):
93
+ records = group.to_dict('records')
94
+ if device_id not in buffers:
95
+ buffers[device_id] = []
96
+ buffers[device_id].extend(records)
97
+ buffers[device_id] = sorted(buffers[device_id], key=lambda r: r['ts_start'])
98
+
99
+ if len(buffers[device_id]) >= window_size:
100
+ segment = buffers[device_id][:window_size]
101
+ data_points: List[Dict] = []
102
+ for row in segment:
103
+ feature_payload = {}
104
+ for feat in feature_names:
105
+ if feat in row and pd.notna(row[feat]):
106
+ feature_payload[feat] = row[feat]
107
+ data_points.append({
108
+ 'timestamp': row['ts_start'].to_pydatetime(),
109
+ 'deviceId': str(device_id),
110
+ 'features': feature_payload,
111
+ })
112
+ return data_points, str(device_id)
113
+
114
116
+
117
+ raise ValueError("没有找到满足窗口长度的用户数据(请检查原始数据是否存在足够连续的记录)")
118
+
119
+
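`load_raw_data_window` 的核心是按设备缓冲、排序并截取首个满窗口。这一逻辑可以用一个不依赖 pandas 的最小示意版本单独表达(仅作说明,函数名与记录字段为假设):

```python
from collections import defaultdict

def first_full_window(records, window_size=12):
    """Buffer records per device (as the chunked loader does) and return
    (window, device_id) for the first device that accumulates
    window_size time-ordered records."""
    buffers = defaultdict(list)
    for rec in records:
        dev = rec["deviceId"]
        buffers[dev].append(rec)
        # keep each buffer sorted by timestamp, mirroring the chunked loader
        buffers[dev].sort(key=lambda r: r["ts_start"])
        if len(buffers[dev]) >= window_size:
            return buffers[dev][:window_size], dev
    raise ValueError("no device reached the window size")
```

与上面的实现一样,窗口一旦凑满即返回,后续记录不再读取。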
120
+ def load_predefined_case(case_name: str) -> Tuple[List[Dict], str, str]:
121
+ if case_name not in PREDEFINED_CASES:
122
+ raise ValueError(f"未找到预置案例: {case_name},可选: {list(PREDEFINED_CASES.keys())}")
123
+ case = PREDEFINED_CASES[case_name]
124
+ data_points = []
125
+ for point in case["data_points"]:
126
+ converted = dict(point)
127
+ converted["timestamp"] = pd.to_datetime(converted["timestamp"])
128
+ data_points.append(converted)
129
+ return data_points, case["user_id"], case["description"]
130
+
131
+
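`load_predefined_case` 借助 `pd.to_datetime` 解析 ISO 时间戳字符串;对于上面这些预置案例,一个不依赖 pandas 的等价示意(仅供说明)如下:

```python
from datetime import datetime

def convert_case_points(points):
    """Return copies of the data points with ISO-8601 timestamp strings
    parsed into datetime objects."""
    converted = []
    for point in points:
        item = dict(point)
        item["timestamp"] = datetime.fromisoformat(item["timestamp"])
        converted.append(item)
    return converted
```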
132
+ def main():
133
+ parser = argparse.ArgumentParser(description="使用原始 wearables 数据测试新的推理服务")
134
+ parser.add_argument("--model-dir", type=str, default="checkpoints/phase2/exp_factor_balanced")
135
+ parser.add_argument("--stage1-file", type=str, default="processed_data/stage1/wearable_processed.csv")
136
+ parser.add_argument("--window-size", type=int, default=12)
137
+ parser.add_argument("--case", type=str, nargs="+", default=["ab60_morning_rest"], help="使用预置案例名(可多选,all=全部)")
138
+ parser.add_argument("--from-raw", action="store_true", help="从stage1原始文件抽样,而不是预置案例")
139
+ args = parser.parse_args()
140
+
141
+ base_dir = Path(__file__).parent
142
+ stage1_file = base_dir / args.stage1_file
143
+
144
+ feature_calculator = FeatureCalculator()
145
+ feature_names = feature_calculator.get_enabled_feature_names()
146
+
147
+ print("\n🚀 加载异常检测器...")
148
+ detector = load_detector(base_dir / args.model_dir)
149
+
150
+ def run_prediction(label: str, data_points: List[Dict], user_id_hint: str):
151
+ print(f"\n🧪 执行预测: {label}")
152
+ print(f" - 用户: {user_id_hint}")
153
+ print(f" - 窗口长度: {len(data_points)} (每个点5分钟)")
154
+ result = detector.predict(data_points, return_score=True, return_details=True)
155
+ print(" ▸ 是否异常:", "是" if result['is_anomaly'] else "否")
156
+ print(f" ▸ 异常分数: {result.get('anomaly_score', 0.0):.4f} (阈值 {result.get('threshold', detector.threshold):.4f})")
157
+ if result.get('details'):
158
+ print(" ▸ 详情:", result['details'])
159
+
160
+ if args.from_raw:
161
+ print("📥 正在从原始数据抽样窗口...")
162
+ if not stage1_file.exists():
163
+ raise FileNotFoundError(f"找不到原始数据文件: {stage1_file}")
164
+ data_points, device_id = load_raw_data_window(stage1_file, feature_names, window_size=args.window_size)
165
+ run_prediction("raw_sample", data_points, device_id)
166
+ else:
167
+ case_names = args.case
168
+ if "all" in case_names:
169
+ case_names = list(PREDEFINED_CASES.keys())
170
+ for case_name in case_names:
171
+ if case_name not in PREDEFINED_CASES:
172
+ print(f"⚠️ 跳过未知案例: {case_name}")
173
+ continue
174
+ data_points, user_id, desc = load_predefined_case(case_name)
175
+ print(f"\n🧾 使用预置案例: {case_name}")
176
+ print(f" - 描述: {desc}")
177
+ run_prediction(case_name, data_points, user_id)
178
+
179
+
180
+ if __name__ == "__main__":
181
+ main()
182
+
wearable_anomaly_detector.py ADDED
@@ -0,0 +1,428 @@
1
+ """
2
+ Wearable健康异常检测模型 - 标准化封装
3
+ 提供简单的API接口,用于实时异常检测
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import json
9
10
+ from pathlib import Path
11
+ from typing import Dict, List, Optional, Union
12
+ from datetime import datetime
13
+ import pandas as pd
14
+
15
+ # 添加项目根目录到路径
16
+ import sys
17
+ sys.path.insert(0, str(Path(__file__).parent.parent))
18
+
19
+ from models.phased_lstm_tft import PhasedLSTM_TFT, PhasedLSTM_TFT_WithEnhancedAnomalyDetection
20
+ from feature_calculator import FeatureCalculator
21
+
22
+
23
+ class WearableAnomalyDetector:
24
+ """
25
+ Wearable健康异常检测器
26
+
27
+ 使用示例:
28
+ detector = WearableAnomalyDetector(model_dir="checkpoints/phase2/exp_factor_balanced")
29
+ result = detector.predict(data_points)
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ model_dir: Union[str, Path],
35
+ device: Optional[str] = None,
36
+ threshold: Optional[float] = None
37
+ ):
38
+ """
39
+ 初始化异常检测器
40
+
41
+ 参数:
42
+ model_dir: 模型目录路径(包含best_model.pt和配置文件)
43
+ device: 设备('cuda'或'cpu'),如果为None则自动选择
44
+ threshold: 异常阈值,如果为None则从配置中读取
45
+ """
46
+ self.model_dir = Path(model_dir)
47
+ self.device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
48
+
49
+ # 加载配置
50
+ self.config = self._load_config()
51
+
52
+ # 确定阈值
53
+ if threshold is not None:
54
+ self.threshold = float(threshold)
55
+ else:
56
+ config_threshold = self.config.get('threshold')
57
+ if config_threshold is not None:
58
+ self.threshold = float(config_threshold)
59
+ else:
60
+ self.threshold = 0.53 # 默认阈值
61
+ print(f" ⚠️ 未找到阈值配置,使用默认值: {self.threshold:.4f}")
62
+
63
+ # 加载模型
64
+ self.model = self._load_model()
65
+ self.model.eval()
66
+
67
+ # 加载归一化参数(维持向后兼容)
68
+ self.norm_params = self._load_norm_params()
69
+
70
+ # 配置驱动特征计算
71
+ self.feature_calculator = FeatureCalculator(
72
+ config_path=self.config.get('feature_config_path'),
73
+ norm_params_path=Path(__file__).parent / 'processed_data' / 'stage3' / 'norm_params.json',
74
+ static_features_path=Path(__file__).parent / 'processed_data' / 'stage2' / 'static_features.csv',
75
+ storage_dir=Path(self.config.get('storage_dir', Path(__file__).parent / 'data_storage'))
76
+ )
77
+ self.features = self.feature_calculator.get_enabled_feature_names()
78
+ self.static_feature_names = [cfg["name"] for cfg in self.feature_calculator.static_feature_defs]
79
+ self.known_future_dim = max(len(self.feature_calculator.known_future_defs), 1)
80
+ self.factor_metadata = {
81
+ 'enabled': self.feature_calculator.factor_enabled,
82
+ 'factor_names': self.feature_calculator.factor_names,
83
+ 'factor_dim': self.feature_calculator.factor_dim
84
+ }
85
+
86
+ print(f"✅ 模型加载成功")
87
+ print(f" - 设备: {self.device}")
88
+ print(f" - 阈值: {self.threshold:.4f}")
89
+ print(f" - 特征数: {len(self.features)}")
90
+
91
+ def _load_config(self) -> Dict:
92
+ """加载模型配置"""
93
+ config_file = self.model_dir / 'config.json'
94
+ if config_file.exists():
95
+ with open(config_file, 'r') as f:
96
+ config = json.load(f)
97
+ return config
98
+
99
+ # 尝试从summary.json读取
100
+ summary_file = self.model_dir / 'summary.json'
101
+ if summary_file.exists():
102
+ with open(summary_file, 'r') as f:
103
+ summary = json.load(f)
104
+ config = {
105
+ 'threshold': summary.get('best_threshold'),
106
+ 'features': [], # 需要从其他地方获取
107
+ }
108
+ return config
109
+
110
+ # 如果都没有,返回空配置(使用默认值)
111
+ print(f" ⚠️ 未找到配置文件,使用默认配置")
112
+ return {}
113
+
114
+ def _load_model(self):
115
+ """加载模型"""
116
+ # 加载Phase1模型
117
+ phase1_model_path = self.model_dir.parent.parent / 'phase1' / 'best_model.pt'
118
+ if not phase1_model_path.exists():
119
+ raise FileNotFoundError(f"Phase1模型不存在: {phase1_model_path}")
120
+
121
+ checkpoint_phase1 = torch.load(phase1_model_path, map_location=self.device, weights_only=False)
122
+ phase1_config = checkpoint_phase1['config']
123
+
124
+ base_model = PhasedLSTM_TFT(phase1_config)
125
+ base_model.load_state_dict(checkpoint_phase1['model_state_dict'])
126
+ base_model = base_model.to(self.device)
127
+
128
+ # 加载factor_config
129
+ factor_config = self._load_factor_config()
130
+
131
+ # 创建Phase2模型
132
+ model = PhasedLSTM_TFT_WithEnhancedAnomalyDetection(
133
+ base_model,
134
+ num_anomaly_types=4,
135
+ use_enhanced_head=True,
136
+ use_multi_source_heads=False,
137
+ use_domain_adversarial=False,
138
+ factor_config=factor_config
139
+ )
140
+ model = model.to(self.device)
141
+
142
+ # 加载Phase2权重
143
+ phase2_model_path = self.model_dir / 'best_model.pt'
144
+ if not phase2_model_path.exists():
145
+ raise FileNotFoundError(f"Phase2模型不存在: {phase2_model_path}")
146
+
147
+ checkpoint_phase2 = torch.load(phase2_model_path, map_location=self.device, weights_only=False)
148
+ model.load_state_dict(checkpoint_phase2['model_state_dict'])
149
+
150
+ return model
151
+
152
+ def _load_factor_config(self) -> Optional[Dict]:
153
+ """加载因子特征配置"""
154
+ # 方法1: 从config.json读取(如果已加载)
155
+ if hasattr(self, 'factor_metadata') and self.factor_metadata:
156
+ if self.factor_metadata.get('enabled'):
157
+ return {
158
+ 'num_factors': len(self.factor_metadata.get('factor_names', [])),
159
+ 'factor_dim': self.factor_metadata.get('factor_dim', 0),
160
+ 'factor_names': self.factor_metadata.get('factor_names', []),
161
+ 'min_weight': 0.2,
162
+ 'dropout': 0.1,
163
+ }
164
+
165
+ # 方法2: 从窗口信息文件读取
166
+ window_info_file = Path(__file__).parent / 'processed_data' / 'stage3' / 'window_info_multi_scale.json'
167
+ if window_info_file.exists():
168
+ with open(window_info_file, 'r') as f:
169
+ window_info = json.load(f)
170
+ factor_metadata = window_info.get('factor_features', {})
171
+ if factor_metadata and factor_metadata.get('enabled'):
172
+ return {
173
+ 'num_factors': len(factor_metadata.get('factor_names', [])),
174
+ 'factor_dim': factor_metadata.get('factor_dim', 0),
175
+ 'factor_names': factor_metadata.get('factor_names', []),
176
+ 'min_weight': 0.2,
177
+ 'dropout': 0.1,
178
+ }
179
+ return None
180
+
181
+ def _load_norm_params(self) -> Optional[Dict]:
182
+ """加载归一化参数"""
183
+ norm_file = Path(__file__).parent / 'processed_data' / 'stage3' / 'norm_params.json'
184
+ if norm_file.exists():
185
+ with open(norm_file, 'r') as f:
186
+ return json.load(f)
187
+ return None
188
+
189
+ def predict(
190
+ self,
191
+ data_points: List[Dict],
192
+ return_score: bool = True,
193
+ return_details: bool = False
194
+ ) -> Dict:
195
+ """
196
+ 预测异常
197
+
198
+ 参数:
199
+ data_points: 数据点列表,每个数据点是一个字典,包含:
200
+ - timestamp: 时间戳(datetime或字符串)
201
+ - features: 特征字典,包含所有需要的特征值
202
+ - static_features: 静态特征字典(可选)
203
+ return_score: 是否返回异常分数
204
+ return_details: 是否返回详细信息
205
+
206
+ 返回:
207
+ {
208
+ 'is_anomaly': bool, # 是否异常
209
+ 'anomaly_score': float, # 异常分数(0-1)
210
+ 'threshold': float, # 使用的阈值
211
+ 'details': dict (可选) # 详细信息
212
+ }
213
+ """
214
+ if not data_points:
+ raise ValueError("data_points 不能为空")
+ user_id = data_points[0].get('deviceId') or data_points[0].get('user_id')
215
+ window = self.feature_calculator.build_window(data_points, user_id=user_id)
216
+
217
+ # 转换为模型输入格式
218
+ model_input = self._prepare_model_input(window)
219
+
220
+ # 模型预测
221
+ with torch.no_grad():
222
+ # 模型forward方法接受位置参数,需要按顺序传递
223
+ outputs = self.model(
224
+ model_input['x'],
225
+ model_input['delta_t'],
226
+ model_input['static_features'],
227
+ model_input['known_future_features'],
228
+ mask=model_input.get('mask'),
229
+ return_contrastive_features=model_input.get('return_contrastive_features', False),
230
+ source=None,
231
+ return_domain_features=False,
232
+ factor_features=model_input.get('factor_features')
233
+ )
234
+ anomaly_score = outputs['anomaly_score'].cpu().item()
235
+
236
+ # 判断是否异常
237
+ is_anomaly = anomaly_score >= self.threshold
238
+
239
+ result = {
240
+ 'is_anomaly': bool(is_anomaly),
241
+ 'threshold': float(self.threshold),
242
+ }
243
+
244
+ if return_score:
245
+ result['anomaly_score'] = float(anomaly_score)
246
+
247
+ if return_details:
248
+ result['details'] = {
249
+ 'window_size': len(data_points),
250
+ 'model_output': float(anomaly_score),
251
+ 'prediction_confidence': abs(anomaly_score - self.threshold),
252
+ }
253
+
254
+ return result
255
+
256
+ def _prepare_model_input(self, window: Dict) -> Dict:
257
+ """准备模型输入"""
258
+ input_features_list = []
259
+ for feat in self.features:
260
+ values = window['input_features'].get(feat, [0.0] * 12)
261
+ input_features_list.append(values)
262
+
263
+ # 转换为tensor
264
+ input_features = torch.tensor(
265
+ np.stack(input_features_list, axis=1),
266
+ dtype=torch.float32
267
+ ).unsqueeze(0).to(self.device) # [1, 12, num_features]
268
+
269
+ delta_t = torch.tensor(
270
+ window['input_delta_t'],
271
+ dtype=torch.float32
272
+ ).unsqueeze(-1).unsqueeze(0).to(self.device) # [1, 12, 1]
273
+
274
+ # 静态特征
275
+ static_feature_values = []
276
+ static_keys = self.static_feature_names or sorted(window['static_features'].keys())
277
+ for key in static_keys:
278
+ value = window['static_features'].get(key, 0.0)
279
+ static_feature_values.append(float(value))
280
+
281
+ if len(static_feature_values) == 0:
282
+ static_feature_values = [0.0]
283
+
284
+ static_features = torch.tensor(
285
+ static_feature_values,
286
+ dtype=torch.float32
287
+ ).unsqueeze(0).to(self.device) # [1, num_static]
288
+
289
+ # 已知未来特征
290
+ pred_len = len(window.get('target_timestamp', []))
291
+ if pred_len == 0:
292
+ pred_len = 6 # 默认预测长度
293
+
294
+ known_future = torch.zeros(1, pred_len, self.known_future_dim, dtype=torch.float32).to(self.device)
295
+ if 'known_future_features' in window:
296
+ kf = window['known_future_features']
297
+ for idx, cfg in enumerate(self.feature_calculator.known_future_defs):
298
+ name = cfg['name']
299
+ if name in kf:
300
+ series = kf[name][:pred_len]
301
+ if name == 'hour_of_day':
302
+ values = torch.tensor([float(h) / 23.0 for h in series], dtype=torch.float32)
303
+ elif name == 'day_of_week':
304
+ values = torch.tensor([float(d) / 6.0 for d in series], dtype=torch.float32)
305
+ else:
306
+ values = torch.tensor([float(v) for v in series], dtype=torch.float32)
307
+ known_future[0, :len(series), idx] = values
308
+
309
+ # 输入mask(假设所有数据都有效)
310
+ input_mask = torch.ones(1, 12, len(self.features), dtype=torch.float32).to(self.device)
311
+
312
+ # 因子特征
313
+ factor_features = None
314
+ if window.get('factor_features'):
315
+ factor_names = self.factor_metadata.get('factor_names', [])
316
+ factor_dim = self.factor_metadata.get('factor_dim', 4)
317
+ factor_vectors = []
318
+ for name in factor_names:
319
+ vec = window['factor_features'].get(name, [0.0] * factor_dim)
320
+ factor_vectors.append(vec[:factor_dim])
321
+ if factor_vectors:
322
+ factor_features = torch.tensor(
323
+ factor_vectors,
324
+ dtype=torch.float32
325
+ ).unsqueeze(0).to(self.device) # [1, num_factors, factor_dim]
326
+
327
+ return {
328
+ 'x': input_features,
329
+ 'delta_t': delta_t,
330
+ 'static_features': static_features,
331
+ 'known_future_features': known_future,
332
+ 'mask': input_mask,
333
+ 'factor_features': factor_features,
334
+ 'return_contrastive_features': False,
335
+ 'source': None,
336
+ 'return_domain_features': False,
337
+ }
338
+
339
+ def batch_predict(
340
+ self,
341
+ windows: List[List[Dict]],
342
+ return_scores: bool = True
343
+ ) -> List[Dict]:
344
+ """
345
+ 批量预测
346
+
347
+ 参数:
348
+ windows: 窗口列表,每个窗口是一个数据点列表
349
+ return_scores: 是否返回异常分数
350
+
351
+ 返回:
352
+ 预测结果列表
353
+ """
354
+ results = []
355
+ for window_data in windows:
356
+ result = self.predict(window_data, return_score=return_scores)
357
+ results.append(result)
358
+ return results
359
+
360
+ def update_threshold(self, threshold: float):
361
+ """更新异常阈值"""
362
+ self.threshold = threshold
363
+ print(f"✅ 阈值已更新为: {threshold:.4f}")
364
+
365
+
366
+ def load_detector(model_dir: Union[str, Path], **kwargs) -> WearableAnomalyDetector:
367
+ """
368
+ 便捷函数:加载异常检测器
369
+
370
+ 参数:
371
+ model_dir: 模型目录路径
372
+ **kwargs: 其他参数(device, threshold等)
373
+
374
+ 返回:
375
+ WearableAnomalyDetector实例
376
+ """
377
+ return WearableAnomalyDetector(model_dir, **kwargs)
378
+
379
+
380
+ if __name__ == '__main__':
381
+ # 使用示例
382
+ print("=" * 80)
383
+ print("Wearable健康异常检测器 - 使用示例")
384
+ print("=" * 80)
385
+
386
+ # 加载模型
387
+ model_dir = Path(__file__).parent / 'checkpoints' / 'phase2' / 'exp_factor_balanced'
388
+ detector = load_detector(model_dir)
389
+
390
+ # 模拟数据点(实际使用时应该从实时数据流获取)
391
+ print("\n模拟数据点...")
392
+ data_points = []
393
+ base_time = datetime.now()
394
+
395
+ # 使用一个真实的deviceId(如果静态特征表存在)
396
+ # 或者提供一个完整的静态特征示例
397
+ example_device_id = None
398
+ static_dict = detector.feature_calculator.static_features_dict
399
+ if static_dict:
400
+ example_device_id = list(static_dict.keys())[0]
401
+ print(f" 使用示例用户ID: {example_device_id}")
402
+
403
+ for i in range(12):
404
+ data_point = {
405
+ 'timestamp': base_time.replace(minute=i*5),
406
+ 'deviceId': example_device_id, # 提供deviceId以便加载完整静态特征
407
+ 'features': {
408
+ 'hr': 70.0 + np.random.randn() * 5,
409
+ 'hrv_rmssd': 30.0 + np.random.randn() * 3,
410
+ # ... 其他特征(简化示例,实际需要所有36个特征)
411
+ },
412
+ 'static_features': {
413
+ # 可以只提供部分特征,系统会自动从静态特征表补充
414
+ # 或者不提供,完全从静态特征表加载
415
+ }
416
+ }
417
+ data_points.append(data_point)
418
+
419
+ # 预测
420
+ result = detector.predict(data_points, return_score=True, return_details=True)
421
+
422
+ print(f"\n预测结果:")
423
+ print(f" - 是否异常: {result['is_anomaly']}")
424
+ print(f" - 异常分数: {result['anomaly_score']:.4f}")
425
+ print(f" - 阈值: {result['threshold']:.4f}")
426
+ if 'details' in result:
427
+ print(f" - 详细信息: {result['details']}")
428
+
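补充说明:`_prepare_model_input` 中的日历类已知未来特征通过除以 23(小时)与除以 6(星期)缩放到 [0, 1],可用如下独立示意表达(函数名为假设):

```python
from datetime import datetime

def scale_calendar(ts):
    """Scale hour-of-day and day-of-week into [0, 1], matching the
    known-future feature normalization in _prepare_model_input."""
    return ts.hour / 23.0, ts.weekday() / 6.0
```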