Upload folder using huggingface_hub
- README.md +178 -3
- checkpoints/phase2/exp_factor_balanced/best_model.pt +3 -0
- configs/features_config.json +108 -0
- feature_calculator.py +273 -0
- processed_data/stage3/norm_params.json +146 -0
- requirements.txt +4 -0
- test_wearable_service.py +182 -0
- wearable_anomaly_detector.py +428 -0
README.md
CHANGED
@@ -1,3 +1,178 @@
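# Wearable_TimeSeries_Health_Monitor

A multi-user health monitoring solution for wearable devices: one model and one configuration file are enough to build personalized anomaly detection for different users. The model is based on **Phased LSTM + Temporal Fusion Transformer (TFT)** and integrates adaptive baselines, factor features, and second-level sliding-window data handling, which makes it suitable for quick integration as a HuggingFace model or as an internal enterprise service.

---

## 🌟 Model Highlights

| Capability | Description |
| --- | --- |
| **Plug and play** | Ships with the `WearableAnomalyDetector` wrapper: load the model and predict. After a single initialization it can keep monitoring multiple users |
| **Config-driven features** | `configs/features_config.json` describes every feature, its default value, and its category mapping. Adding or removing features such as blood oxygen or respiration rate only requires a config change |
| **Multi-user real-time service** | `FeatureCalculator` plus a lightweight `data_storage` cache provide user history management, baseline evolution, and batch inference |
| **Multi-scenario demo** | `test_wearable_service.py` bundles 3 real “customer” cases: full sensor set, missing fields, and anonymous device, so you can try it immediately even without raw data |
| **Adaptive baseline support** | `UserDataManager` can be extended to feed personal or group baselines into the inference flow and keep improving per-individual sensitivity |

---

## 📊 Core Metrics (Short-Term Window)

- **F1**: 0.2819
- **Precision**: 0.1769
- **Recall**: 0.6941
- **Best threshold**: 0.53
- **Window definition**: 12 records at 5-minute intervals (a 1-hour window, predicting the next 0.5 hours)

> The model leans toward recall, which suits “alert first, then human-in-the-loop review” scenarios. Precision and recall can be traded off through the threshold or the sampling strategy.
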
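
As a quick sanity check, the reported F1 follows from the precision and recall listed above via the harmonic mean. The sketch below recomputes it; the function name `f1_score` is illustrative and not part of this repository.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall as reported in the metrics list.
reported_f1 = f1_score(0.1769, 0.6941)
print(round(reported_f1, 4))
```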
---

## 🚀 Quick Start

### 1. Clone or download the model repository

```bash
git clone https://huggingface.co/oscarzhang/Wearable_TimeSeries_Health_Monitor
cd Wearable_TimeSeries_Health_Monitor
pip install -r requirements.txt
```

### 2. Run the built-in demo (no extra data required)

```bash
# Run the ab60 case by default
python test_wearable_service.py

# Run all preset customers in one batch
python test_wearable_service.py --case all

# Sample test data from the raw stage1 CSV
python test_wearable_service.py --from-raw
```

`test_wearable_service.py` automatically:
- Loads `WearableAnomalyDetector`
- Reads the config-driven features
- Builds windows and runs prediction
- Prints the anomaly score, threshold, and prediction details for each “customer”

### 3. Call it from application code

```python
from wearable_anomaly_detector import WearableAnomalyDetector

detector = WearableAnomalyDetector(
    model_dir="checkpoints/phase2/exp_factor_balanced",
    threshold=0.53,
)

result = detector.predict(data_points, return_score=True, return_details=True)
print(result)
```

> `data_points` holds the 12 most recent 5-minute records. If static features or device information are missing, the system fills them in automatically from the configuration or the cache.

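
To try the call above without real sensor data, `data_points` can be assembled by hand. This is a minimal sketch, assuming the record layout shown in the input section of this README; the helper name `make_data_points` and all feature values are made up for illustration.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def make_data_points(start: datetime, device_id: str = "ab60", n: int = 12) -> List[Dict]:
    """Build n synthetic records spaced 5 minutes apart (illustrative values only)."""
    points = []
    for i in range(n):
        ts = start + timedelta(minutes=5 * i)
        points.append({
            "timestamp": ts.isoformat(),
            "deviceId": device_id,
            "features": {
                "hr": 72.0 + i,        # synthetic heart rate
                "hrv_rmssd": 30.0,     # synthetic HRV value
                "time_period_primary": "morning",
                "data_quality": "high",
            },
        })
    return points


data_points = make_data_points(datetime(2024, 1, 1, 8, 0, 0))
print(len(data_points), data_points[0]["timestamp"], data_points[-1]["timestamp"])
```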
---

## 🔧 Input and Output

### Input (a single data point)

```python
{
    "timestamp": "2024-01-01T08:00:00",
    "deviceId": "ab60",  # optional; an anonymous ID is created automatically when missing
    "features": {
        "hr": 72.0,
        "hrv_rmssd": 30.0,
        "time_period_primary": "morning",
        "data_quality": "high",
        ...
    }
}
```

- Each window needs 12 records (1 hour by default)
- Whether a feature is required is controlled by `configs/features_config.json`
- Missing values automatically fall back to the `default` or to the value defined in `category_mapping`

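
The fallback rule in the last bullet can be sketched as follows. This mirrors the `_coerce_value` helper in `feature_calculator.py` in simplified form; `coerce` is an illustrative name, and the two config entries are abbreviated copies of entries in `configs/features_config.json`.

```python
import math
from typing import Any, Dict


def coerce(value: Any, feat_cfg: Dict) -> float:
    """Map a raw feature value to a number using default / category_mapping."""
    default = feat_cfg.get("default", 0.0)
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return default                          # missing value -> default
    if isinstance(value, str):
        mapping = feat_cfg.get("category_mapping")
        if mapping:
            return mapping.get(value, default)  # unknown label -> default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default                          # unparseable value -> default


data_quality = {"name": "data_quality", "default": 0.9,
                "category_mapping": {"low": 0.3, "medium": 0.6, "high": 1.0}}
hr = {"name": "hr", "default": 70.0}

print(coerce("high", data_quality))  # label found in category_mapping
print(coerce(None, hr))              # missing value
print(coerce("n/a", data_quality))   # unknown label
```

Keeping the rule in one small function means every feature, numeric or categorical, goes through the same path before normalization.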
### Output

```python
{
    "is_anomaly": True,
    "anomaly_score": 0.5760,
    "threshold": 0.5300,
    "details": {
        "window_size": 12,
        "model_output": 0.5760,
        "prediction_confidence": 0.0460
    }
}
```

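
In the sample output, `prediction_confidence` equals the distance between the score and the threshold (0.5760 - 0.5300 = 0.0460). The sketch below reproduces the sample under that assumption. The actual logic lives in `wearable_anomaly_detector.py`, which is not shown here, so treat `decide` and its inclusive comparison as hypothetical.

```python
from typing import Dict


def decide(score: float, threshold: float = 0.53, window_size: int = 12) -> Dict:
    """Turn a model score into the result layout shown above (assumed rule)."""
    return {
        "is_anomaly": score >= threshold,  # assumption: inclusive comparison
        "anomaly_score": round(score, 4),
        "threshold": round(threshold, 4),
        "details": {
            "window_size": window_size,
            "model_output": round(score, 4),
            # assumption: confidence = distance from the decision boundary
            "prediction_confidence": round(abs(score - threshold), 4),
        },
    }


result = decide(0.5760)
print(result)
```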
---

## 🧱 Model Architecture and Training

- **Backbone**: Phased LSTM for irregularly sampled sequences + Temporal Fusion Transformer for aggregating temporal context
- **Anomaly detection head**: enhanced attention, multi-layer MLP, optional contrastive learning / anomaly-type auxiliary heads
- **Feature set**:
  - Physiological: HR, HRV (RMSSD/SDNN/PNN50…)
  - Activity: steps, distance, energy expenditure, accelerometer, gyroscope
  - Environment: light, day/night label, data quality
  - Baseline: adaptive baseline mean/std + deviation features
- **Label sources**: high-confidence labels from questionnaires + low-confidence labels from the adaptive baseline
- **Training pipeline**: Stage1/2/3 data processing ➜ Phase1 self-supervised pretraining ➜ Phase2 supervised fine-tuning ➜ threshold/case calibration

---

## 📦 Repository Structure (partial)

```
├─ configs/
│   └─ features_config.json     # feature definitions & normalization strategy
├─ wearable_anomaly_detector.py # core wrapper: loading, prediction, batching
├─ feature_calculator.py        # config-driven feature building + user history cache
├─ test_wearable_service.py     # HuggingFace demo script (with preset cases)
└─ checkpoints/phase2/...       # model weights & summary
```

---

## 📚 Data Source and License

- The training data is based on **“A continuous real-world dataset comprising wearable-based heart rate variability alongside sleep diaries”** (Baigutanova *et al.*, Scientific Data, 2025) and its Figshare dataset: [doi:10.1038/s41597-025-05801-3](https://www.nature.com/articles/s41597-025-05801-3) / [dataset link](https://springernature.figshare.com/articles/dataset/In-situ_wearable-based_dataset_of_continuous_heart_rate_variability_monitoring_accompanied_by_sleep_diaries/28509740).
- The dataset is released under **Creative Commons Attribution 4.0 (CC BY 4.0)**: it may be freely used, modified, and distributed, provided that attribution is kept and a link to the license is included.
- This repository follows the CC BY 4.0 requirements for the original data. If you further process or publish work built on it, please keep the attribution and license notice above.
- Code and models may use MIT/Apache or another license as needed, but anything involving the data must still comply with CC BY 4.0.

---

## 🤝 Contributing and Extending

Contributions are welcome:
1. Adding features or data sources ⇒ update `features_config.json` and submit a PR
2. Plugging in new user data management or baseline strategies ⇒ extend `FeatureCalculator` or contribute a `UserDataManager`
3. Sharing cases or real-world deployment experience ⇒ open an Issue or Discussion

---

## 📄 License

To be determined (replace as the project requires).

---

## 🔖 Citation

```bibtex
@software{Wearable_TimeSeries_Health_Monitor,
  title = {Wearable\_TimeSeries\_Health\_Monitor},
  author = {oscarzhang},
  year = {2025},
  url = {https://huggingface.co/oscarzhang/Wearable_TimeSeries_Health_Monitor}
}
```
checkpoints/phase2/exp_factor_balanced/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f2f056ea3cec48902ffda2399e905189dce62826034470bb6514f8739eba9ff
size 27270610
configs/features_config.json
ADDED
@@ -0,0 +1,108 @@
{
  "metadata": {
    "version": "1.0",
    "description": "Wearable anomaly detection feature configuration"
  },
  "time_series": [
    {"name": "hr", "enabled": true, "default": 70.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hr_resting", "enabled": true, "default": 65.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hrv_rmssd", "enabled": true, "default": 30.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hrv_sdnn", "enabled": true, "default": 40.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hrv_pnn50", "enabled": true, "default": 15.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "sdnn", "enabled": true, "default": 35.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "sdsd", "enabled": true, "default": 25.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "rmssd", "enabled": true, "default": 30.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "pnn20", "enabled": true, "default": 25.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "pnn50", "enabled": true, "default": 12.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "ibi", "enabled": true, "default": 0.86, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "lf/hf", "enabled": true, "default": 1.8, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "steps", "enabled": true, "default": 20.0, "normalization": {"type": "minmax", "min": 0.0, "max": 500.0}},
    {"name": "distance", "enabled": true, "default": 10.0, "normalization": {"type": "minmax", "min": 0.0, "max": 2000.0}},
    {"name": "calories", "enabled": true, "default": 1.5, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "acc_x_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "acc_y_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "acc_z_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "grv_x_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "grv_y_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "grv_z_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "grv_w_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "gyr_x_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "gyr_y_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "gyr_z_avg", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "light_avg", "enabled": true, "default": 100.0, "normalization": {"type": "minmax", "min": 0.0, "max": 1000.0}},
    {
      "name": "time_period_primary",
      "enabled": true,
      "default": 2.0,
      "normalization": {"type": "none"},
      "category_mapping": {
        "night": 0,
        "morning": 1,
        "day": 2,
        "evening": 3,
        "unknown": 4
      }
    },
    {
      "name": "time_period_secondary",
      "enabled": true,
      "default": 7.0,
      "normalization": {"type": "none"},
      "category_mapping": {
        "commute_morning": 0,
        "breakfast": 1,
        "work_morning": 2,
        "lunch": 3,
        "work_afternoon": 4,
        "commute_evening": 5,
        "dinner": 6,
        "rest_evening": 7,
        "rest_night": 8,
        "exercise": 9,
        "unknown": 10
      }
    },
    {"name": "is_weekend", "enabled": true, "default": 0.0, "normalization": {"type": "none"}},
    {
      "name": "data_quality",
      "enabled": true,
      "default": 0.9,
      "normalization": {"type": "minmax", "min": 0.0, "max": 1.0},
      "category_mapping": {
        "low": 0.3,
        "medium": 0.6,
        "high": 1.0
      }
    },
    {"name": "missingness_score", "enabled": true, "default": 0.0, "normalization": {"type": "minmax", "min": 0.0, "max": 1.0}},
    {"name": "baseline_hrv_mean", "enabled": true, "default": 30.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "baseline_hrv_std", "enabled": true, "default": 5.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hrv_deviation_abs", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hrv_deviation_pct", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}},
    {"name": "hrv_z_score", "enabled": true, "default": 0.0, "normalization": {"type": "zscore", "use_norm_params": true}}
  ],
  "static": [
    {"name": "age_group", "enabled": true, "default": -1},
    {"name": "age_normalized", "enabled": true, "default": 0.5},
    {"name": "sex", "enabled": true, "default": 0.5},
    {"name": "marriage", "enabled": true, "default": -1},
    {"name": "exercise", "enabled": true, "default": -1},
    {"name": "coffee", "enabled": true, "default": -1},
    {"name": "smoking", "enabled": true, "default": -1},
    {"name": "drinking", "enabled": true, "default": -1},
    {"name": "MEQ", "enabled": true, "default": 0.0},
    {"name": "baseline_commute_morning_mean", "enabled": true, "default": 30.0},
    {"name": "baseline_commute_morning_std", "enabled": true, "default": 5.0}
  ],
  "factor_features": {
    "enabled": true,
    "factor_names": ["physio", "activity", "context"],
    "factor_dim": 4
  },
  "known_future": [
    {"name": "hour_of_day", "enabled": true},
    {"name": "day_of_week", "enabled": true},
    {"name": "is_weekend", "enabled": true}
  ]
}
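
The `normalization` entries in this config select between z-score and min-max scaling. Below is a simplified sketch of how one such entry could be applied to a series, following the logic of `normalize_series` in `feature_calculator.py`. `normalize` is an illustrative name, and the z-score statistics passed in are the `hr_resting` values from `processed_data/stage3/norm_params.json`.

```python
from typing import Dict, List, Optional


def normalize(values: List[float], norm_cfg: Dict,
              stats: Optional[Dict[str, float]] = None) -> List[float]:
    """Apply the scaling described by one `normalization` config entry."""
    kind = norm_cfg.get("type", "none")
    if kind == "zscore":
        stats = stats or {}
        mean = stats.get("mean", 0.0)
        std = stats.get("std", 1.0) or 1.0   # guard against zero std
        return [(v - mean) / std for v in values]
    if kind == "minmax":
        lo = norm_cfg.get("min", 0.0)
        hi = norm_cfg.get("max", 1.0)
        scale = max(hi - lo, 1e-6)           # guard against zero range
        return [(v - lo) / scale for v in values]
    return list(values)                      # "none": pass through unchanged


steps_cfg = {"type": "minmax", "min": 0.0, "max": 500.0}
print(normalize([0.0, 250.0, 500.0], steps_cfg))

hr_resting_stats = {"mean": 65.74867248535156, "std": 7.843548774719238}
print(normalize([65.74867248535156], {"type": "zscore"}, hr_resting_stats))
```

Note that min-max scaling here does not clip: a `steps` value above 500 would map to a number greater than 1.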
feature_calculator.py
ADDED
@@ -0,0 +1,273 @@
import json
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


class FeatureCalculator:
    """
    Load feature definitions from a single config file and build the
    window structure needed for inference/training.
    """

    def __init__(
        self,
        config_path: Optional[Path] = None,
        norm_params_path: Optional[Path] = None,
        static_features_path: Optional[Path] = None,
        storage_dir: Optional[Path] = None,
    ):
        base_dir = Path(__file__).parent
        self.config_path = Path(config_path or base_dir / "configs" / "features_config.json")
        self.norm_params_path = Path(norm_params_path or base_dir / "processed_data" / "stage3" / "norm_params.json")
        self.static_features_path = Path(static_features_path or base_dir / "processed_data" / "stage2" / "static_features.csv")
        self.storage_dir = Path(storage_dir or base_dir / "data_storage")
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.features_config = self._load_json(self.config_path)
        self.norm_params = self._load_json(self.norm_params_path) if self.norm_params_path.exists() else {}
        self.static_features_dict = self._load_static_features(self.static_features_path)

        self.time_series_features = [f for f in self.features_config.get("time_series", []) if f.get("enabled", True)]
        self.static_feature_defs = [f for f in self.features_config.get("static", []) if f.get("enabled", True)]
        self.known_future_defs = [f for f in self.features_config.get("known_future", []) if f.get("enabled", True)]
        factor_cfg = self.features_config.get("factor_features", {})
        self.factor_enabled = factor_cfg.get("enabled", False)
        self.factor_names = factor_cfg.get("factor_names", [])
        self.factor_dim = factor_cfg.get("factor_dim", 0)

        # Simple in-memory history cache, kept to make personalized features easier to add later
        self.user_histories: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    @staticmethod
    def _load_json(path: Path) -> Dict:
        if not path.exists():
            return {}
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _load_static_features(static_file: Path) -> Dict[str, Dict]:
        if not static_file.exists():
            return {}
        df = pd.read_csv(static_file)
        static_dict = {}
        for _, row in df.iterrows():
            raw_id = row.get("deviceId")
            # str(None) / str(nan) are truthy, so check for a missing ID before converting
            if raw_id is None or pd.isna(raw_id):
                continue
            static_dict[str(raw_id)] = {
                col: row[col]
                for col in df.columns
                if col != "deviceId"
            }
        return static_dict

    @staticmethod
    def _to_serializable(value):
        # Fallback for json.dumps: convert NumPy/pandas types to plain Python types
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        if isinstance(value, (pd.Timestamp, datetime)):
            return value.isoformat()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value)} is not JSON serializable")

    def register_data_points(self, user_id: str, data_points: List[Dict]):
        """
        Lightweight cache of user data; also appends to data_storage/users/{user_id}.jsonl
        """
        if not user_id:
            return
        user_dir = self.storage_dir / "users"
        user_dir.mkdir(exist_ok=True, parents=True)
        history_file = user_dir / f"{user_id}.jsonl"

        with history_file.open("a", encoding="utf-8") as f:
            for point in data_points:
                serializable = dict(point)
                ts = serializable.get("timestamp")
                # Covers pd.Timestamp and datetime alike
                if hasattr(ts, "isoformat"):
                    serializable["timestamp"] = ts.isoformat()
                f.write(json.dumps(serializable, ensure_ascii=False, default=self._to_serializable) + "\n")

        self.user_histories[user_id].extend(data_points)
        # Keep only the most recent 5,000 records in memory to bound usage
        if len(self.user_histories[user_id]) > 5000:
            self.user_histories[user_id] = self.user_histories[user_id][-5000:]

    def normalize_series(self, values: List[float], feature_name: str, cfg: Dict) -> List[float]:
        arr = np.array(values, dtype=np.float32)
        norm_cfg = cfg.get("normalization", {"type": "none"})
        norm_type = norm_cfg.get("type", "none")

        if norm_type == "zscore":
            mean, std = self._get_norm_stats(feature_name, norm_cfg)
            if std == 0:
                std = 1.0
            arr = (arr - mean) / std
        elif norm_type == "minmax":
            min_v = norm_cfg.get("min", 0.0)
            max_v = norm_cfg.get("max", 1.0)
            scale = max(max_v - min_v, 1e-6)
            arr = (arr - min_v) / scale
        # norm_type == "none": leave values unchanged

        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        return arr.tolist()

    @staticmethod
    def _coerce_value(value, feat_cfg):
        default = feat_cfg.get("default", 0.0)
        if value is None or pd.isna(value):
            return default
        category_mapping = feat_cfg.get("category_mapping")
        if isinstance(value, str):
            if category_mapping:
                return category_mapping.get(value, default)
            try:
                return float(value)
            except ValueError:
                return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _get_norm_stats(self, feature_name: str, norm_cfg: Dict) -> Tuple[float, float]:
        # Stats from norm_params.json are used only when the feature name matches a key exactly
        if norm_cfg.get("use_norm_params") and feature_name in self.norm_params:
            stats = self.norm_params[feature_name]
            return stats.get("mean", 0.0), stats.get("std", 1.0)
        return norm_cfg.get("mean", 0.0), norm_cfg.get("std", 1.0)

    def build_window(self, data_points: List[Dict], user_id: Optional[str] = None) -> Dict:
        if len(data_points) < 12:
            raise ValueError("Not enough data points: at least 12 are required to build a short-term window")

        if user_id:
            self.register_data_points(user_id, data_points)

        timestamps = []
        input_features = {feat["name"]: [] for feat in self.time_series_features}

        for point in data_points:
            ts = point.get("timestamp")
            if isinstance(ts, str):
                ts = pd.to_datetime(ts)
            timestamps.append(ts)

            feature_payload = point.get("features", {})
            for feat_cfg in self.time_series_features:
                name = feat_cfg["name"]
                value = feature_payload.get(name)
                value = self._coerce_value(value, feat_cfg)
                input_features[name].append(value)

        # delta_t: seconds elapsed since the previous record (0.0 for the first one)
        delta_t = [0.0]
        for i in range(1, len(timestamps)):
            diff = (timestamps[i] - timestamps[i - 1]).total_seconds()
            delta_t.append(float(diff))

        # Normalization
        normalized_features = {}
        for feat_cfg in self.time_series_features:
            name = feat_cfg["name"]
            normalized_features[name] = self.normalize_series(input_features[name], name, feat_cfg)

        static_features = self._build_static_features(data_points[0], user_id)
        factor_features = self._build_factor_features(normalized_features)
        known_future = self._build_known_future(timestamps[-6:] if len(timestamps) >= 6 else timestamps)

        return {
            "input_timestamp": timestamps[:12],
            "input_delta_t": delta_t[:12],
            "input_features": normalized_features,
            "target_timestamp": timestamps[12:] if len(timestamps) > 12 else [],
            "target_delta_t": delta_t[12:] if len(delta_t) > 12 else [],
            "static_features": static_features,
            "known_future_features": known_future,
            "factor_features": factor_features,
        }

    def _build_static_features(self, first_point: Dict, user_id: Optional[str]) -> Dict:
        static_payload = dict(first_point.get("static_features", {}))
        device_id = first_point.get("deviceId") or user_id

        if device_id and str(device_id) in self.static_features_dict:
            for key, value in self.static_features_dict[str(device_id)].items():
                static_payload.setdefault(key, value)

        result = {}
        for feat_cfg in self.static_feature_defs:
            name = feat_cfg["name"]
            result[name] = static_payload.get(name, feat_cfg.get("default", 0.0))
        return result

    def _build_factor_features(self, normalized_features: Dict[str, List[float]]) -> Optional[Dict[str, List[float]]]:
        if not self.factor_enabled or not self.factor_names:
            return None

        factor_vectors = {}
        for factor_name in self.factor_names:
            # For now each factor is a simple mean/std/max/min summary, which is easy to replace later
            merged = []
            for feat_name, values in normalized_features.items():
                if factor_name == "physio" and feat_name.startswith("hrv"):
                    merged.extend(values)
                elif factor_name == "activity" and feat_name in {"steps", "distance", "calories"}:
                    merged.extend(values)
                elif factor_name == "context" and feat_name in {"time_period_primary", "time_period_secondary", "is_weekend"}:
                    merged.extend(values)

            if not merged:
                factor_vectors[factor_name] = [0.0] * self.factor_dim
            else:
                arr = np.array(merged, dtype=np.float32)
                stats = [
                    float(arr.mean()),
                    float(arr.std()),
                    float(arr.max()),
                    float(arr.min())
                ]
                # Truncate or zero-pad the summary to exactly factor_dim entries
                factor_vectors[factor_name] = stats[: self.factor_dim] if len(stats) >= self.factor_dim else stats + [0.0] * (self.factor_dim - len(stats))
        return factor_vectors

    def _build_known_future(self, timestamps: List[pd.Timestamp]) -> Dict[str, List[float]]:
        hours, days, weekends = [], [], []
        for ts in timestamps:
            if pd.isna(ts):
                # Neutral placeholders when a timestamp is missing
                hours.append(12.0)
                days.append(3.0)
                weekends.append(0.0)
            else:
                hours.append(float(ts.hour))
                days.append(float(ts.weekday()))
                weekends.append(float(1 if ts.weekday() >= 5 else 0))

        result = {}
        for cfg in self.known_future_defs:
            name = cfg["name"]
            if name == "hour_of_day":
                result[name] = hours
            elif name == "day_of_week":
                result[name] = days
            elif name == "is_weekend":
                result[name] = weekends
        return result

    def get_enabled_feature_names(self) -> List[str]:
        return [feat["name"] for feat in self.time_series_features]


__all__ = ["FeatureCalculator"]
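
`build_window` returns `input_delta_t`, the gap in seconds between consecutive records, which is presumably what lets the Phased LSTM cope with irregular sampling. Here is a standalone sketch of that step using only the standard library; `compute_delta_t` is an illustrative name, not a function in this file.

```python
from datetime import datetime
from typing import List


def compute_delta_t(timestamps: List[str]) -> List[float]:
    """Seconds elapsed since the previous record; the first entry is 0.0."""
    parsed = [datetime.fromisoformat(ts) for ts in timestamps]
    deltas = [0.0]
    for prev, cur in zip(parsed, parsed[1:]):
        deltas.append((cur - prev).total_seconds())
    return deltas


# Two regular 5-minute steps followed by a 15-minute gap.
stamps = ["2024-01-01T08:00:00", "2024-01-01T08:05:00",
          "2024-01-01T08:10:00", "2024-01-01T08:25:00"]
print(compute_delta_t(stamps))
```

A gap such as the 900-second one above stays visible to the model instead of being hidden by an assumed fixed 5-minute step.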
processed_data/stage3/norm_params.json
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hr_mean": {
|
| 3 |
+
"mean": 79.88385009765625,
|
| 4 |
+
"std": 15.546831130981445,
|
| 5 |
+
"min": 33.0,
|
| 6 |
+
"max": 200.2244873046875
|
| 7 |
+
},
|
| 8 |
+
"hr_std": {
|
| 9 |
+
"mean": 12.757049560546875,
|
| 10 |
+
"std": 3.9224278926849365,
|
| 11 |
+
"min": 0.0,
|
| 12 |
+
"max": 32.2431755065918
|
| 13 |
+
},
|
| 14 |
+
"hr_median": {
|
| 15 |
+
"mean": 76.4555892944336,
|
| 16 |
+
"std": 6.908801555633545,
|
| 17 |
+
"min": 48.0,
|
| 18 |
+
"max": 104.0
|
| 19 |
+
},
|
| 20 |
+
"hr_resting": {
|
| 21 |
+
"mean": 65.74867248535156,
|
| 22 |
+
"std": 7.843548774719238,
|
| 23 |
+
"min": 44.12284469604492,
|
| 24 |
+
"max": 86.0
|
| 25 |
+
},
|
| 26 |
+
"hr_nrem": {
|
| 27 |
+
"mean": 61.779720306396484,
|
| 28 |
+
"std": 11.666051864624023,
|
| 29 |
+
"min": 0.0,
|
| 30 |
+
"max": 92.5469970703125
|
| 31 |
+
},
|
| 32 |
+
"hrv_rmssd": {
|
| 33 |
+
"mean": 83.4627685546875,
|
| 34 |
+
"std": 62.30027389526367,
|
| 35 |
+
"min": 0.0,
|
| 36 |
+
"max": 855.8391723632812
|
| 37 |
+
},
|
| 38 |
+
"hrv_sdnn": {
|
| 39 |
+
"mean": 100.59049987792969,
|
| 40 |
+
"std": 43.545467376708984,
|
| 41 |
+
"min": 0.0,
|
| 42 |
+
"max": 393.35162353515625
|
| 43 |
+
},
|
| 44 |
+
"steps": {
|
| 45 |
+
"mean": 342.7657470703125,
|
| 46 |
+
"std": 823.3682861328125,
|
| 47 |
+
"min": 0.0,
|
| 48 |
+
"max": 27004.0
|
| 49 |
+
},
|
| 50 |
+
"distance": {
|
| 51 |
+
"mean": 225.4749755859375,
|
| 52 |
+
"std": 504.8075866699219,
|
| 53 |
+
"min": 0.0,
|
| 54 |
+
"max": 10460.2998046875
|
| 55 |
+
},
|
| 56 |
+
"calories": {
|
| 57 |
+
"mean": 104.05133819580078,
|
| 58 |
+
"std": 211.85128784179688,
|
| 59 |
+
"min": 0.0,
|
| 60 |
+
"max": 2962.070068359375
|
| 61 |
+
},
|
| 62 |
+
"sleep_duration_total": {
|
| 63 |
+
"mean": 418.6901550292969,
|
| 64 |
+
"std": 142.2774200439453,
|
| 65 |
+
"min": 0.0,
|
| 66 |
+
"max": 1110.0
|
| 67 |
+
},
|
| 68 |
+
"sleep_efficiency": {
|
| 69 |
+
"mean": 93.89789581298828,
|
| 70 |
+
"std": 7.327056884765625,
|
| 71 |
+
"min": 34.0,
|
| 72 |
+
"max": 100.0
|
| 73 |
+
},
|
| 74 |
+
"sleep_deep_ratio": {
|
| 75 |
+
"mean": 1.00419020652771,
|
| 76 |
+
"std": 0.3390481770038605,
|
| 77 |
+
"min": 0.0,
|
| 78 |
+
"max": 4.310344696044922
|
| 79 |
+
},
|
| 80 |
+
"sleep_rem_ratio": {
|
| 81 |
+
"mean": 1.00448739528656,
|
| 82 |
+
"std": 0.35869544744491577,
|
| 83 |
+
"min": 0.0,
|
| 84 |
+
"max": 3.9259259700775146
|
| 85 |
+
},
|
| 86 |
+
"sleep_light_ratio": {
|
| 87 |
+
"mean": 0.9923003315925598,
|
| 88 |
+
"std": 0.23265497386455536,
|
| 89 |
+
"min": 0.0,
|
| 90 |
+
"max": 3.034313678741455
|
| 91 |
+
},
|
| 92 |
+
"spo2": {
|
| 93 |
+
"mean": 95.9047622680664,
|
| 94 |
+
"std": 1.04403817653656,
|
| 95 |
+
"min": 92.4000015258789,
|
| 96 |
+
"max": 100.0
|
| 97 |
+
},
|
| 98 |
+
"stress_score": {
|
| 99 |
+
"mean": 65.94886779785156,
|
| 100 |
+
"std": 28.051528930664062,
|
| 101 |
+
"min": 0.0,
|
| 102 |
+
"max": 93.0
|
| 103 |
+
},
|
| 104 |
+
"ALERT": {
|
| 105 |
+
"mean": 0.07375683635473251,
|
| 106 |
+
"std": 0.2613747715950012,
|
| 107 |
+
"min": 0.0,
|
| 108 |
+
"max": 1.0
|
| 109 |
+
},
|
| 110 |
+
"HAPPY": {
|
| 111 |
+
"mean": 0.1726546734571457,
|
| 112 |
+
"std": 0.37794846296310425,
|
| 113 |
+
"min": 0.0,
|
| 114 |
+
"max": 1.0
|
| 115 |
+
},
|
| 116 |
+
"NEUTRAL": {
|
| 117 |
+
"mean": 0.1967589408159256,
|
| 118 |
+
"std": 0.3975485563278198,
|
| 119 |
+
"min": 0.0,
|
| 120 |
+
"max": 1.0
|
| 121 |
+
},
|
| 122 |
+
"RESTED/RELAXED": {
|
| 123 |
+
"mean": 0.23211927711963654,
|
| 124 |
+
"std": 0.42218467593193054,
|
| 125 |
+
"min": 0.0,
|
| 126 |
+
"max": 1.0
|
| 127 |
+
},
|
| 128 |
+
"SAD": {
|
| 129 |
+
"mean": 0.018068943172693253,
|
| 130 |
+
"std": 0.13320080935955048,
|
| 131 |
+
"min": 0.0,
|
| 132 |
+
"max": 1.0
|
| 133 |
+
},
|
| 134 |
+
"TENSE/ANXIOUS": {
|
| 135 |
+
"mean": 0.10590820014476776,
|
| 136 |
+
"std": 0.3077200949192047,
|
| 137 |
+
"min": 0.0,
|
| 138 |
+
"max": 1.0
|
| 139 |
+
},
|
| 140 |
+
"TIRED": {
|
| 141 |
+
"mean": 0.20073312520980835,
|
| 142 |
+
"std": 0.4005488157272339,
|
| 143 |
+
"min": 0.0,
|
| 144 |
+
"max": 1.0
|
| 145 |
+
}
|
| 146 |
+
}
|
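The per-feature statistics above (`mean`, `std`, `min`, `max`) are the kind of parameters usually applied to standardize inputs before inference. This diff does not show how `FeatureCalculator` consumes them, so the following is only a minimal sketch that assumes clipping to the recorded range followed by plain z-score scaling. `normalize_feature` and `NORM_PARAMS` are hypothetical names, not part of the repository.

```python
from typing import Dict

# Two entries copied from norm_params.json above (rounded for readability).
NORM_PARAMS: Dict[str, Dict[str, float]] = {
    "hr_resting": {"mean": 65.7487, "std": 7.8435, "min": 44.1228, "max": 86.0},
    "spo2": {"mean": 95.9048, "std": 1.0440, "min": 92.4, "max": 100.0},
}


def normalize_feature(
    name: str,
    value: float,
    params: Dict[str, Dict[str, float]] = NORM_PARAMS,
    eps: float = 1e-8,
) -> float:
    """Clip to the observed [min, max] range, then z-score (assumed scheme)."""
    stats = params[name]
    clipped = min(max(value, stats["min"]), stats["max"])
    # eps guards against a zero standard deviation for constant features.
    return (clipped - stats["mean"]) / (stats["std"] + eps)
```

Under these assumptions a reading equal to the feature mean maps to roughly 0, and out-of-range readings saturate at the normalized value of `min` or `max`.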
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
numpy>=1.24
|
| 3 |
+
pandas>=2.0
|
| 4 |
+
huggingface_hub>=0.23
|
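The four pins above are minimum versions. As a rough illustration of what `>=` means for such pins, here is a small sketch that compares dotted release numbers. `release_tuple` and `meets_minimum` are hypothetical helpers written for this example; in practice `pip` performs this check at install time, and this sketch ignores pre-release and epoch rules.

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Extract the leading dotted-number part, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Return True if `installed` is at least `minimum` (release numbers only)."""
    have, need = release_tuple(installed), release_tuple(minimum)
    width = max(len(have), len(need))
    # Pad with zeros so that '2.0' and '2.0.0' compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need
```

For example, `meets_minimum("1.23.5", "1.24")` is false, which would flag a NumPy install that is too old for this repository.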
test_wearable_service.py
ADDED
|
@@ -0,0 +1,182 @@
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from wearable_anomaly_detector import load_detector
|
| 8 |
+
from feature_calculator import FeatureCalculator
|
| 9 |
+
|
| 10 |
+
# Predefined cases (taken from raw wearable data)
|
| 11 |
+
PREDEFINED_CASES: Dict[str, Dict] = {
|
| 12 |
+
"ab60_morning_rest": {
|
| 13 |
+
"description": "User ab60: 12 consecutive windows from early-morning rest up to breakfast, used to quickly verify that the service produces output",
|
| 14 |
+
"user_id": "ab60",
|
| 15 |
+
"data_points": [
|
| 16 |
+
{"timestamp": "2021-03-04T04:45:20.170000", "deviceId": "ab60", "features": {"hr": 91.65860215053765, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.33511196423747, "hrv_sdnn": 72.35486488414405, "hrv_pnn50": 0.3422818791946309, "sdnn": 72.35486488414405, "sdsd": 55.28945952794972, "rmssd": 73.33511196423747, "pnn20": 0.6677852348993288, "pnn50": 0.3422818791946309, "ibi": 671.2685790942928, "lf/hf": 0.6578861348372742, "acc_x_avg": 4.9294712037533515, "acc_y_avg": -3.1057153652814957, "acc_z_avg": 5.8750820100536005, "grv_x_avg": -0.5497846977211797, "grv_y_avg": 0.0042184631367292, "grv_z_avg": 0.1525969041554961, "grv_w_avg": 0.1771413800268097, "gyr_x_avg": -0.9240750428954416, "gyr_y_avg": 1.1994772238605889, "gyr_z_avg": 0.3142024215817695, "light_avg": 440.066889632107, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.2886646919083789}},
|
| 17 |
+
{"timestamp": "2021-03-04T05:15:20.497000", "deviceId": "ab60", "features": {"hr": 84.15604988203573, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.48193641157741, "hrv_sdnn": 74.51390542231427, "hrv_pnn50": 0.4143302180685358, "sdnn": 74.51390542231427, "sdsd": 51.36438677767364, "rmssd": 73.48193641157741, "pnn20": 0.7538940809968847, "pnn50": 0.4143302180685358, "ibi": 729.7248964415015, "lf/hf": 1.1884595045473076, "acc_x_avg": 5.264426130609508, "acc_y_avg": -2.4949751527126582, "acc_z_avg": 5.638523847957136, "grv_x_avg": -0.2725604313462818, "grv_y_avg": -0.0417301761553922, "grv_z_avg": 0.0857810462156731, "grv_w_avg": 0.2177219383791028, "gyr_x_avg": -0.6168720663094444, "gyr_y_avg": 1.3517548573342275, "gyr_z_avg": 0.3159611446751512, "light_avg": 121.7391304347826, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1872242091224755}},
|
| 18 |
+
{"timestamp": "2021-03-04T05:35:21.317000", "deviceId": "ab60", "features": {"hr": 82.22642857142857, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.0761458919254, "hrv_sdnn": 81.71063624330793, "hrv_pnn50": 0.45703125, "sdnn": 81.71063624330793, "sdsd": 47.19876056820636, "rmssd": 73.0761458919254, "pnn20": 0.7578125, "pnn50": 0.45703125, "ibi": 799.9492801995793, "lf/hf": 1.7751635086235489, "acc_x_avg": 4.37106831212324, "acc_y_avg": -0.3996270033489605, "acc_z_avg": 6.891127246483591, "grv_x_avg": -0.6645361294433281, "grv_y_avg": -0.3851286666666666, "grv_z_avg": 0.0854148665325285, "grv_w_avg": 0.189048981220657, "gyr_x_avg": -0.1036394095174265, "gyr_y_avg": 1.4141689068364616, "gyr_z_avg": 0.4626943733243945, "light_avg": 603.2266666666667, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.335979916085374}},
|
| 19 |
+
{"timestamp": "2021-03-04T05:55:21.428000", "deviceId": "ab60", "features": {"hr": 86.76466621712744, "hr_resting": 87.84302108870469, "hrv_rmssd": 74.59311564233765, "hrv_sdnn": 87.31034842159234, "hrv_pnn50": 0.391304347826087, "sdnn": 87.31034842159234, "sdsd": 52.14310859847846, "rmssd": 74.59311564233765, "pnn20": 0.7418478260869565, "pnn50": 0.391304347826087, "ibi": 709.0165592504526, "lf/hf": 2.4824085488532703, "acc_x_avg": 6.8974065639651725, "acc_y_avg": -0.6532005197588735, "acc_z_avg": 4.533895625586072, "grv_x_avg": -0.606806606831882, "grv_y_avg": -0.1301535706630945, "grv_z_avg": 0.0961370341594106, "grv_w_avg": 0.4334408365706628, "gyr_x_avg": -0.9999732337575332, "gyr_y_avg": 1.1268921553918274, "gyr_z_avg": 0.0987407849966511, "light_avg": 428.3612040133779, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1217623103705546}},
|
| 20 |
+
{"timestamp": "2021-03-04T06:00:21.434000", "deviceId": "ab60", "features": {"hr": 86.11454484380249, "hr_resting": 87.84302108870469, "hrv_rmssd": 80.21713778627182, "hrv_sdnn": 72.4260069463789, "hrv_pnn50": 0.4117647058823529, "sdnn": 72.4260069463789, "sdsd": 57.42532753452254, "rmssd": 80.21713778627182, "pnn20": 0.7536764705882353, "pnn50": 0.4117647058823529, "ibi": 711.989004966488, "lf/hf": 0.5778707099781514, "acc_x_avg": 2.6462390616208977, "acc_y_avg": -2.0968226182183503, "acc_z_avg": 7.492496608841257, "grv_x_avg": -0.3523352029470862, "grv_y_avg": -0.3171756162089749, "grv_z_avg": 0.0852522665773609, "grv_w_avg": 0.0408305612860013, "gyr_x_avg": -0.7581300395442356, "gyr_y_avg": 1.7399128639410202, "gyr_z_avg": 0.5810656782841818, "light_avg": 719.866220735786, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.3055760776711148}},
|
| 21 |
+
{"timestamp": "2021-03-04T06:05:21.530000", "deviceId": "ab60", "features": {"hr": 85.55876427132304, "hr_resting": 87.84302108870469, "hrv_rmssd": 61.68777738949315, "hrv_sdnn": 61.78286083667294, "hrv_pnn50": 0.3786127167630058, "sdnn": 61.78286083667294, "sdsd": 40.15746948668274, "rmssd": 61.68777738949315, "pnn20": 0.7283236994219653, "pnn50": 0.3786127167630058, "ibi": 715.325568241176, "lf/hf": 0.9103296457711838, "acc_x_avg": 3.0594024239785633, "acc_y_avg": -2.580148740120562, "acc_z_avg": 7.2776362170127245, "grv_x_avg": -0.2312993507712944, "grv_y_avg": -0.1165045117370892, "grv_z_avg": 0.1689689483568074, "grv_w_avg": 0.1020506572769953, "gyr_x_avg": -0.2617024135388738, "gyr_y_avg": 1.1428016065683635, "gyr_z_avg": 0.0963672908847178, "light_avg": 665.6321070234113, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1561355447930486}},
|
| 22 |
+
{"timestamp": "2021-03-04T06:15:21.623000", "deviceId": "ab60", "features": {"hr": 82.62592343854936, "hr_resting": 87.84302108870469, "hrv_rmssd": 81.00579057638534, "hrv_sdnn": 73.40584581276266, "hrv_pnn50": 0.4357366771159874, "sdnn": 73.40584581276266, "sdsd": 57.10852463732151, "rmssd": 81.00579057638534, "pnn20": 0.7366771159874608, "pnn50": 0.4357366771159874, "ibi": 746.8545044048512, "lf/hf": 2.2582676863270708, "acc_x_avg": 0.1661593281982583, "acc_y_avg": -1.6491836677829892, "acc_z_avg": 9.736414508372398, "grv_x_avg": -0.6639661513730741, "grv_y_avg": 0.2112439484259877, "grv_z_avg": 0.008919145344943, "grv_w_avg": 0.0059900468854654, "gyr_x_avg": -0.3158673777628935, "gyr_y_avg": 0.9963161399866036, "gyr_z_avg": 0.4727930301406551, "light_avg": 802.180602006689, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.174593188653174}},
|
| 23 |
+
{"timestamp": "2021-03-04T06:40:21.894000", "deviceId": "ab60", "features": {"hr": 83.86055107526882, "hr_resting": 87.84302108870469, "hrv_rmssd": 79.079102105631, "hrv_sdnn": 71.271849209199, "hrv_pnn50": 0.4221556886227545, "sdnn": 71.271849209199, "sdsd": 54.69084849089656, "rmssd": 79.079102105631, "pnn20": 0.7724550898203593, "pnn50": 0.4221556886227545, "ibi": 732.4375213133641, "lf/hf": 0.6677155405784594, "acc_x_avg": 1.7091933630274596, "acc_y_avg": -2.3761690361687893, "acc_z_avg": 8.209631811788354, "grv_x_avg": -0.3922270288010715, "grv_y_avg": 0.2412905398526454, "grv_z_avg": -0.0233138225050234, "grv_w_avg": 0.1066616838580042, "gyr_x_avg": 0.3474681640991295, "gyr_y_avg": 1.1165304762223718, "gyr_z_avg": 0.186041525117213, "light_avg": 729.1103678929766, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.162896032760479}},
|
| 24 |
+
{"timestamp": "2021-03-04T06:45:21.969000", "deviceId": "ab60", "features": {"hr": 85.06473364801079, "hr_resting": 87.84302108870469, "hrv_rmssd": 73.21565395106201, "hrv_sdnn": 69.24046192941114, "hrv_pnn50": 0.4137931034482758, "sdnn": 69.24046192941114, "sdsd": 49.87687480659143, "rmssd": 73.21565395106201, "pnn20": 0.768025078369906, "pnn50": 0.4137931034482758, "ibi": 719.8668191606536, "lf/hf": 1.6620124103207512, "acc_x_avg": 2.2839716548257365, "acc_y_avg": -2.790153060991958, "acc_z_avg": 8.004115201072393, "grv_x_avg": -0.4554428230563004, "grv_y_avg": -0.1927247533512062, "grv_z_avg": 0.0215668719839142, "grv_w_avg": 0.1034467037533509, "gyr_x_avg": -0.0257573813672921, "gyr_y_avg": 1.0048659463806948, "gyr_z_avg": 0.3241018806970504, "light_avg": 755.3177257525084, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1959064930123424}},
|
| 25 |
+
{"timestamp": "2021-03-04T06:50:21.981000", "deviceId": "ab60", "features": {"hr": 84.85642062689585, "hr_resting": 87.84302108870469, "hrv_rmssd": 68.0148861140481, "hrv_sdnn": 76.76458770787994, "hrv_pnn50": 0.3815384615384615, "sdnn": 76.76458770787994, "sdsd": 46.69614203120664, "rmssd": 68.0148861140481, "pnn20": 0.6707692307692308, "pnn50": 0.3815384615384615, "ibi": 729.6592731642268, "lf/hf": 0.9127720565801908, "acc_x_avg": 1.71269043335566, "acc_y_avg": -1.0698834099129273, "acc_z_avg": 9.200909472873422, "grv_x_avg": -0.375676233087742, "grv_y_avg": 0.1822926999330206, "grv_z_avg": 0.0157954038847957, "grv_w_avg": 0.0460910823844608, "gyr_x_avg": -0.1220897675820484, "gyr_y_avg": 0.6891694554588073, "gyr_z_avg": 0.4564768801071655, "light_avg": 847.0335570469799, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1845048437257962}},
|
| 26 |
+
{"timestamp": "2021-03-04T06:55:21.985000", "deviceId": "ab60", "features": {"hr": 80.75881760161236, "hr_resting": 87.84302108870469, "hrv_rmssd": 66.04701786885973, "hrv_sdnn": 76.17603294638347, "hrv_pnn50": 0.3684210526315789, "sdnn": 76.17603294638347, "sdsd": 45.0541425499126, "rmssd": 66.04701786885973, "pnn20": 0.7151702786377709, "pnn50": 0.3684210526315789, "ibi": 757.0256501417455, "lf/hf": 1.269681575994282, "acc_x_avg": 3.739257345612857, "acc_y_avg": -2.4820750622906904, "acc_z_avg": 7.388303849966525, "grv_x_avg": -0.5472615157401197, "grv_y_avg": 0.3134806992632281, "grv_z_avg": 0.0378557434695244, "grv_w_avg": 0.2093843141326185, "gyr_x_avg": -0.5253990328638486, "gyr_y_avg": 1.6931924748490956, "gyr_z_avg": 0.665633814889338, "light_avg": 703.4496644295302, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1579866815850661}},
|
| 27 |
+
{"timestamp": "2021-03-04T07:00:22.020000", "deviceId": "ab60", "features": {"hr": 83.36467427803895, "hr_resting": 87.84302108870469, "hrv_rmssd": 85.01833953484277, "hrv_sdnn": 76.7068088047925, "hrv_pnn50": 0.4434782608695652, "sdnn": 76.7068088047925, "sdsd": 62.38011995356233, "rmssd": 85.01833953484277, "pnn20": 0.7536231884057971, "pnn50": 0.4434782608695652, "ibi": 724.6919175240898, "lf/hf": 1.3219549787885163, "acc_x_avg": 0.4397476164658638, "acc_y_avg": -1.638112914323962, "acc_z_avg": 9.6665646124498, "grv_x_avg": -0.3218441425702814, "grv_y_avg": 0.1350831372155288, "grv_z_avg": 0.0047377771084337, "grv_w_avg": 0.0104064123159303, "gyr_x_avg": -0.2149196720214184, "gyr_y_avg": 0.9163253052208852, "gyr_z_avg": 0.5661378821954464, "light_avg": 861.3110367892976, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1291275275920405}}
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
"nd56_low_activity_sparse": {
|
| 31 |
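+
"description": "nd56: most sensor fields are missing and only basic information such as heart rate and time period is provided; verifies the model's default-filling behavior",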
|
| 32 |
+
"user_id": "nd56",
|
| 33 |
+
"data_points": [
|
| 34 |
+
{"timestamp": "2021-03-04T03:40:55.745000", "deviceId": "nd56", "features": {"hr": 73.3681592039801, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 35 |
+
{"timestamp": "2021-03-04T03:55:55.963000", "deviceId": "nd56", "features": {"hr": 70.58150365934797, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 36 |
+
{"timestamp": "2021-03-04T04:30:56.879000", "deviceId": "nd56", "features": {"hr": 79.00866089273818, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 37 |
+
{"timestamp": "2021-03-04T05:11:22.339000", "deviceId": "nd56", "features": {"hr": 76.88147410358566, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 38 |
+
{"timestamp": "2021-03-04T06:21:22.941000", "deviceId": "nd56", "features": {"hr": 79.52276503821868, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 39 |
+
{"timestamp": "2021-03-04T06:36:22.983000", "deviceId": "nd56", "features": {"hr": 75.333, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 40 |
+
{"timestamp": "2021-03-04T06:41:23.077000", "deviceId": "nd56", "features": {"hr": 75.7566401062417, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 41 |
+
{"timestamp": "2021-03-04T06:56:23.275000", "deviceId": "nd56", "features": {"hr": 74.49435590969456, "steps": None, "time_period_primary": "morning", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium"}},
|
| 42 |
+
{"timestamp": "2021-03-04T07:26:23.418000", "deviceId": "nd56", "features": {"hr": 73.90371845949535, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}},
|
| 43 |
+
{"timestamp": "2021-03-04T07:36:23.463000", "deviceId": "nd56", "features": {"hr": 72.7252911813644, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}},
|
| 44 |
+
{"timestamp": "2021-03-04T07:41:23.508000", "deviceId": "nd56", "features": {"hr": 70.14043824701196, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}},
|
| 45 |
+
{"timestamp": "2021-03-04T07:46:23.547000", "deviceId": "nd56", "features": {"hr": 71.33565737051792, "steps": None, "time_period_primary": "morning", "time_period_secondary": "breakfast", "is_weekend": 0, "data_quality": "medium"}}
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
"anon_commuter_minimal": {
|
| 49 |
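+
"description": "Anonymous commuter: only minimal metrics such as heart rate, HRV and time period are provided and deviceId is missing; verifies that the service can batch-process different clients",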
|
| 50 |
+
"user_id": "anon_commuter",
|
| 51 |
+
"data_points": [
|
| 52 |
+
{"timestamp": "2021-03-09T00:52:04.300000", "deviceId": None, "features": {"hr": 88.77359119706568, "hrv_rmssd": 68.78080551188198, "hrv_sdnn": 68.55954732717733, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.2340064304816851}},
|
| 53 |
+
{"timestamp": "2021-03-09T00:57:04.328000", "deviceId": None, "features": {"hr": 80.68233333333333, "hrv_rmssd": 59.2934623277283, "hrv_sdnn": 73.79953723523498, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1442517250480751}},
|
| 54 |
+
{"timestamp": "2021-03-09T01:02:04.404000", "deviceId": None, "features": {"hr": 75.80266666666667, "hrv_rmssd": 45.96602248249421, "hrv_sdnn": 86.08949225862196, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0475269119819883}},
|
| 55 |
+
{"timestamp": "2021-03-09T01:07:04.432000", "deviceId": None, "features": {"hr": 74.08533333333334, "hrv_rmssd": 35.124819470890934, "hrv_sdnn": 76.28600536828999, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0362464905334389}},
|
| 56 |
+
{"timestamp": "2021-03-09T01:12:04.441000", "deviceId": None, "features": {"hr": 76.9121411276375, "hrv_rmssd": 72.02894057601205, "hrv_sdnn": 87.11453810833142, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1366772654292948}},
|
| 57 |
+
{"timestamp": "2021-03-09T01:17:04.485000", "deviceId": None, "features": {"hr": 68.27024325224924, "hrv_rmssd": 36.20840500569033, "hrv_sdnn": 61.29313423988413, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": -0.0048301680504103}},
|
| 58 |
+
{"timestamp": "2021-03-09T01:22:04.496000", "deviceId": None, "features": {"hr": 71.42038640906063, "hrv_rmssd": 42.73973712175516, "hrv_sdnn": 100.9222381995624, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0282886513311317}},
|
| 59 |
+
{"timestamp": "2021-03-09T01:27:04.517000", "deviceId": None, "features": {"hr": 78.02364302364302, "hrv_rmssd": 68.8236455702779, "hrv_sdnn": 91.82364754816273, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.133595953991592}},
|
| 60 |
+
{"timestamp": "2021-03-09T01:32:04.589000", "deviceId": None, "features": {"hr": 74.2271818787475, "hrv_rmssd": 47.71422420363224, "hrv_sdnn": 80.68140559056071, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0569492438181573}},
|
| 61 |
+
{"timestamp": "2021-03-09T01:37:04.619000", "deviceId": None, "features": {"hr": 66.9780146568954, "hrv_rmssd": 35.74818419244297, "hrv_sdnn": 50.71532677029013, "steps": None, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0026578073089701}},
|
| 62 |
+
{"timestamp": "2021-03-09T01:47:04.721000", "deviceId": None, "features": {"hr": 74.31312458361093, "hrv_rmssd": 57.44451616196892, "hrv_sdnn": 98.6011860346101, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.1253227425948505}},
|
| 63 |
+
{"timestamp": "2021-03-09T01:52:04.735000", "deviceId": None, "features": {"hr": 71.42538307794804, "hrv_rmssd": 50.01054170733226, "hrv_sdnn": 73.22227779949023, "steps": 0.0, "time_period_primary": "night", "time_period_secondary": "rest_night", "is_weekend": 0, "data_quality": "medium", "missingness_score": 0.0255565038546025}}
|
| 64 |
+
]
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_raw_data_window(
|
| 70 |
+
stage1_file: Path,
|
| 71 |
+
feature_names: List[str],
|
| 72 |
+
window_size: int = 12,
|
| 73 |
+
) -> Tuple[List[Dict], str]:
|
| 74 |
+
stage1_path = Path(stage1_file)
|
| 75 |
+
if not stage1_path.exists():
|
| 76 |
+
raise FileNotFoundError(f"Raw data file not found: {stage1_file}")
|
| 77 |
+
|
| 78 |
+
base_cols = pd.read_csv(stage1_path, nrows=0).columns.tolist()
|
| 79 |
+
usecols = ['deviceId', 'ts_start'] + [feat for feat in feature_names if feat in base_cols]
|
| 80 |
+
usecols = list(dict.fromkeys(usecols))
|
| 81 |
+
|
| 82 |
+
buffers: Dict[str, List[Dict]] = {}
|
| 83 |
+
reader = pd.read_csv(
|
| 84 |
+
stage1_path,
|
| 85 |
+
usecols=usecols,
|
| 86 |
+
parse_dates=['ts_start'],
|
| 87 |
+
chunksize=10000,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
for chunk in reader:
|
| 91 |
+
chunk = chunk.sort_values(['deviceId', 'ts_start'])
|
| 92 |
+
for device_id, group in chunk.groupby('deviceId'):
|
| 93 |
+
records = group.to_dict('records')
|
| 94 |
+
if device_id not in buffers:
|
| 95 |
+
buffers[device_id] = []
|
| 96 |
+
buffers[device_id].extend(records)
|
| 97 |
+
buffers[device_id] = sorted(buffers[device_id], key=lambda r: r['ts_start'])
|
| 98 |
+
|
| 99 |
+
if len(buffers[device_id]) >= window_size:
|
| 100 |
+
segment = buffers[device_id][:window_size]
|
| 101 |
+
data_points: List[Dict] = []
|
| 102 |
+
for row in segment:
|
| 103 |
+
feature_payload = {}
|
| 104 |
+
for feat in feature_names:
|
| 105 |
+
if feat in row and pd.notna(row[feat]):
|
| 106 |
+
feature_payload[feat] = row[feat]
|
| 107 |
+
data_points.append({
|
| 108 |
+
'timestamp': row['ts_start'].to_pydatetime(),
|
| 109 |
+
'deviceId': str(device_id),
|
| 110 |
+
'features': feature_payload,
|
| 111 |
+
})
|
| 112 |
+
return data_points, str(device_id)
|
| 113 |
+
|
| 114 |
+
if len(buffers[device_id]) > window_size * 4:
|
| 115 |
+
buffers[device_id] = buffers[device_id][-window_size*2:]
|
| 116 |
+
|
| 117 |
+
raise ValueError("No user data satisfies the window length (check that the raw data contains enough consecutive records)")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def load_predefined_case(case_name: str) -> Tuple[List[Dict], str, str]:
|
| 121 |
+
if case_name not in PREDEFINED_CASES:
|
| 122 |
+
raise ValueError(f"Predefined case not found: {case_name}; available: {list(PREDEFINED_CASES.keys())}")
|
| 123 |
+
case = PREDEFINED_CASES[case_name]
|
| 124 |
+
data_points = []
|
| 125 |
+
for point in case["data_points"]:
|
| 126 |
+
converted = dict(point)
|
| 127 |
+
converted["timestamp"] = pd.to_datetime(converted["timestamp"])
|
| 128 |
+
data_points.append(converted)
|
| 129 |
+
return data_points, case["user_id"], case["description"]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main():
|
| 133 |
+
parser = argparse.ArgumentParser(description="Test the new inference service with raw wearables data")
|
| 134 |
+
parser.add_argument("--model-dir", type=str, default="checkpoints/phase2/exp_factor_balanced")
|
| 135 |
+
parser.add_argument("--stage1-file", type=str, default="processed_data/stage1/wearable_processed.csv")
|
| 136 |
+
parser.add_argument("--window-size", type=int, default=12)
|
| 137 |
+
parser.add_argument("--case", type=str, nargs="+", default=["ab60_morning_rest"], help="Predefined case name(s) (multiple allowed; all = every case)")
|
| 138 |
+
parser.add_argument("--from-raw", action="store_true", help="Sample from the stage1 raw file instead of using predefined cases")
|
| 139 |
+
args = parser.parse_args()
|
| 140 |
+
|
| 141 |
+
base_dir = Path(__file__).parent
|
| 142 |
+
stage1_file = base_dir / args.stage1_file
|
| 143 |
+
|
| 144 |
+
feature_calculator = FeatureCalculator()
|
| 145 |
+
feature_names = feature_calculator.get_enabled_feature_names()
|
| 146 |
+
|
| 147 |
+
print("\n🚀 Loading anomaly detector...")
|
| 148 |
+
detector = load_detector(base_dir / args.model_dir)
|
| 149 |
+
|
| 150 |
+
def run_prediction(label: str, data_points: List[Dict], user_id_hint: str):
|
| 151 |
+
print(f"\n🧪 Running prediction: {label}")
|
| 152 |
+
print(f" - User: {user_id_hint}")
|
| 153 |
+
print(f" - Window length: {len(data_points)} (5 minutes per point)")
|
| 154 |
+
result = detector.predict(data_points, return_score=True, return_details=True)
|
| 155 |
+
print(" ▸ Anomaly:", "yes" if result['is_anomaly'] else "no")
|
| 156 |
+
print(f" ▸ Anomaly score: {result.get('anomaly_score', 0.0):.4f} (threshold {result.get('threshold', detector.threshold):.4f})")
|
| 157 |
+
if result.get('details'):
|
| 158 |
+
print(" ▸ Details:", result['details'])
|
| 159 |
+
|
| 160 |
+
if args.from_raw:
|
| 161 |
+
print("📥 Sampling a window from the raw data...")
|
| 162 |
+
if not stage1_file.exists():
|
| 163 |
+
raise FileNotFoundError(f"Raw data file not found: {stage1_file}")
|
| 164 |
+
data_points, device_id = load_raw_data_window(stage1_file, feature_names, window_size=args.window_size)
|
| 165 |
+
run_prediction("raw_sample", data_points, device_id)
|
| 166 |
+
else:
|
| 167 |
+
case_names = args.case
|
| 168 |
+
if "all" in case_names:
|
| 169 |
+
case_names = list(PREDEFINED_CASES.keys())
|
| 170 |
+
for case_name in case_names:
|
| 171 |
+
if case_name not in PREDEFINED_CASES:
|
| 172 |
+
print(f"⚠️ Skipping unknown case: {case_name}")
|
| 173 |
+
continue
|
| 174 |
+
data_points, user_id, desc = load_predefined_case(case_name)
|
| 175 |
+
print(f"\n🧾 Using predefined case: {case_name}")
|
| 176 |
+
print(f" - Description: {desc}")
|
| 177 |
+
run_prediction(case_name, data_points, user_id)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == "__main__":
|
| 181 |
+
main()
|
| 182 |
+
|
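`load_raw_data_window` in the script above streams the CSV in chunks, buffers rows per `deviceId`, and returns as soon as one device has accumulated `window_size` records. The core of that selection logic can be sketched in memory without pandas. `first_full_window` is a hypothetical simplification for illustration, not a function from the repository; it reuses the `deviceId` and `ts_start` field names seen in the script.

```python
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Tuple


def first_full_window(
    records: List[Dict], window_size: int = 12
) -> Optional[Tuple[str, List[Dict]]]:
    """Return (device_id, window) for the first device that reaches
    `window_size` buffered records, or None if no device does."""
    buffers: Dict[str, List[Dict]] = defaultdict(list)
    for record in records:
        device_id = str(record["deviceId"])
        buffers[device_id].append(record)
        if len(buffers[device_id]) >= window_size:
            # Order chronologically before slicing, as the script does.
            ordered = sorted(buffers[device_id], key=lambda r: r["ts_start"])
            return device_id, ordered[:window_size]
    return None


# Usage: device "a" is the first to collect three records.
sample = [
    {"deviceId": "a", "ts_start": datetime(2021, 3, 4, 6, 10)},
    {"deviceId": "b", "ts_start": datetime(2021, 3, 4, 6, 0)},
    {"deviceId": "a", "ts_start": datetime(2021, 3, 4, 6, 0)},
    {"deviceId": "a", "ts_start": datetime(2021, 3, 4, 6, 5)},
]
selected = first_full_window(sample, window_size=3)
```

One design consequence worth noting: the window comes from whichever device fills its buffer first in file order, so the sampled user depends on how the CSV rows are arranged.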
wearable_anomaly_detector.py
ADDED
|
@@ -0,0 +1,428 @@
|
| 1 |
+
"""
|
| 2 |
+
Wearable health anomaly detection model - standardized wrapper
|
| 3 |
+
Provides a simple API for real-time anomaly detection
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import json
|
| 9 |
+
import pickle
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Union
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
# Add the project root directory to the import path
|
| 16 |
+
import sys
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
from models.phased_lstm_tft import PhasedLSTM_TFT, PhasedLSTM_TFT_WithEnhancedAnomalyDetection
|
| 20 |
+
from feature_calculator import FeatureCalculator
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WearableAnomalyDetector:
|
| 24 |
+
"""
|
| 25 |
+
Wearable health anomaly detector
|
| 26 |
+
|
| 27 |
+
Usage example:
|
| 28 |
+
detector = WearableAnomalyDetector(model_dir="checkpoints/phase2/exp_factor_balanced")
|
| 29 |
+
result = detector.predict(data_points)
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
model_dir: Union[str, Path],
|
| 35 |
+
device: Optional[str] = None,
|
| 36 |
+
threshold: Optional[float] = None
|
| 37 |
+
):
|
| 38 |
+
"""
|
| 39 |
+
Initialize the anomaly detector
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
model_dir: Path to the model directory (contains best_model.pt and the config files)
|
| 43 |
+
device: Device ('cuda' or 'cpu'); selected automatically if None
|
| 44 |
+
threshold: Anomaly threshold; read from the config if None
|
| 45 |
+
"""
|
| 46 |
+
self.model_dir = Path(model_dir)
|
| 47 |
+
self.device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
|
| 48 |
+
|
| 49 |
+
# 加载配置
|
| 50 |
+
self.config = self._load_config()
|
| 51 |
+
|
| 52 |
+
# 确定阈值
|
| 53 |
+
if threshold is not None:
|
| 54 |
+
self.threshold = float(threshold)
|
| 55 |
+
else:
|
| 56 |
+
config_threshold = self.config.get('threshold')
|
| 57 |
+
if config_threshold is not None:
|
| 58 |
+
self.threshold = float(config_threshold)
|
| 59 |
+
else:
|
| 60 |
+
self.threshold = 0.53 # 默认阈值
|
| 61 |
+
print(f" ⚠️ 未找到阈值配置,使用默认值: {self.threshold:.4f}")
|
| 62 |
+
|
| 63 |
+
# 加载模型
|
| 64 |
+
self.model = self._load_model()
|
| 65 |
+
self.model.eval()
|
| 66 |
+
|
| 67 |
+
# 加载归一化参数(维持向后兼容)
|
| 68 |
+
self.norm_params = self._load_norm_params()
|
| 69 |
+
|
| 70 |
+
# 配置驱动特征计算
|
| 71 |
+
self.feature_calculator = FeatureCalculator(
|
| 72 |
+
config_path=self.config.get('feature_config_path'),
|
| 73 |
+
norm_params_path=Path(__file__).parent / 'processed_data' / 'stage3' / 'norm_params.json',
|
| 74 |
+
static_features_path=Path(__file__).parent / 'processed_data' / 'stage2' / 'static_features.csv',
|
| 75 |
+
storage_dir=Path(self.config.get('storage_dir', Path(__file__).parent / 'data_storage'))
|
| 76 |
+
)
|
| 77 |
+
self.features = self.feature_calculator.get_enabled_feature_names()
|
| 78 |
+
self.static_feature_names = [cfg["name"] for cfg in self.feature_calculator.static_feature_defs]
|
| 79 |
+
self.known_future_dim = max(len(self.feature_calculator.known_future_defs), 1)
|
| 80 |
+
self.factor_metadata = {
|
| 81 |
+
'enabled': self.feature_calculator.factor_enabled,
|
| 82 |
+
'factor_names': self.feature_calculator.factor_names,
|
| 83 |
+
'factor_dim': self.feature_calculator.factor_dim
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
print(f"✅ 模型加载成功")
|
| 87 |
+
print(f" - 设备: {self.device}")
|
| 88 |
+
print(f" - 阈值: {self.threshold:.4f}")
|
| 89 |
+
print(f" - 特征数: {len(self.features)}")
|
| 90 |
+
|
    def _load_config(self) -> Dict:
        """Load the model configuration."""
        config_file = self.model_dir / 'config.json'
        if config_file.exists():
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)
            return config

        # Fall back to summary.json.
        summary_file = self.model_dir / 'summary.json'
        if summary_file.exists():
            with open(summary_file, 'r', encoding='utf-8') as f:
                summary = json.load(f)
            config = {
                'threshold': summary.get('best_threshold'),
                'features': [],  # must be obtained from elsewhere
            }
            return config

        # Neither file exists: return an empty config so that defaults are used.
        print(" ⚠️ No config file found; using the default configuration")
        return {}

    def _load_model(self):
        """Load the Phase 1 base model and wrap it with the Phase 2 anomaly detection head."""
        # Load the Phase 1 model.
        phase1_model_path = self.model_dir.parent.parent / 'phase1' / 'best_model.pt'
        if not phase1_model_path.exists():
            raise FileNotFoundError(f"Phase 1 model not found: {phase1_model_path}")

        # weights_only=False unpickles arbitrary Python objects, so load only trusted checkpoints.
        checkpoint_phase1 = torch.load(phase1_model_path, map_location=self.device, weights_only=False)
        phase1_config = checkpoint_phase1['config']

        base_model = PhasedLSTM_TFT(phase1_config)
        base_model.load_state_dict(checkpoint_phase1['model_state_dict'])
        base_model = base_model.to(self.device)

        # Load the factor feature configuration.
        factor_config = self._load_factor_config()

        # Build the Phase 2 model.
        model = PhasedLSTM_TFT_WithEnhancedAnomalyDetection(
            base_model,
            num_anomaly_types=4,
            use_enhanced_head=True,
            use_multi_source_heads=False,
            use_domain_adversarial=False,
            factor_config=factor_config
        )
        model = model.to(self.device)

        # Load the Phase 2 weights.
        phase2_model_path = self.model_dir / 'best_model.pt'
        if not phase2_model_path.exists():
            raise FileNotFoundError(f"Phase 2 model not found: {phase2_model_path}")

        checkpoint_phase2 = torch.load(phase2_model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint_phase2['model_state_dict'])

        return model

    def _load_factor_config(self) -> Optional[Dict]:
        """Load the factor feature configuration."""
        # Option 1: use the factor metadata if it has already been loaded.
        # During __init__ this method runs before self.factor_metadata is set,
        # so the first call falls through to option 2.
        if hasattr(self, 'factor_metadata') and self.factor_metadata:
            if self.factor_metadata.get('enabled'):
                return {
                    'num_factors': len(self.factor_metadata.get('factor_names', [])),
                    'factor_dim': self.factor_metadata.get('factor_dim', 0),
                    'factor_names': self.factor_metadata.get('factor_names', []),
                    'min_weight': 0.2,
                    'dropout': 0.1,
                }

        # Option 2: read it from the window info file.
        window_info_file = Path(__file__).parent / 'processed_data' / 'stage3' / 'window_info_multi_scale.json'
        if window_info_file.exists():
            with open(window_info_file, 'r', encoding='utf-8') as f:
                window_info = json.load(f)
            factor_metadata = window_info.get('factor_features', {})
            if factor_metadata and factor_metadata.get('enabled'):
                return {
                    'num_factors': len(factor_metadata.get('factor_names', [])),
                    'factor_dim': factor_metadata.get('factor_dim', 0),
                    'factor_names': factor_metadata.get('factor_names', []),
                    'min_weight': 0.2,
                    'dropout': 0.1,
                }
        return None

    def _load_norm_params(self) -> Optional[Dict]:
        """Load the normalization parameters, or return None if the file is missing."""
        norm_file = Path(__file__).parent / 'processed_data' / 'stage3' / 'norm_params.json'
        if norm_file.exists():
            with open(norm_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None

    def predict(
        self,
        data_points: List[Dict],
        return_score: bool = True,
        return_details: bool = False
    ) -> Dict:
        """
        Predict whether a window of data points is anomalous.

        Args:
            data_points: List of data points; each one is a dict containing:
                - timestamp: Timestamp (datetime or string)
                - features: Dict with all required feature values
                - static_features: Dict of static features (optional)
            return_score: Whether to include the anomaly score.
            return_details: Whether to include detailed information.

        Returns:
            {
                'is_anomaly': bool,      # whether the window is anomalous
                'anomaly_score': float,  # anomaly score (0-1)
                'threshold': float,      # threshold that was applied
                'details': dict          # detailed information (optional)
            }
        """
        # Guard against an empty window, which would otherwise fail on data_points[0].
        if not data_points:
            raise ValueError("data_points must contain at least one data point")

        user_id = data_points[0].get('deviceId') or data_points[0].get('user_id')
        window = self.feature_calculator.build_window(data_points, user_id=user_id)

        # Convert the window into the model input format.
        model_input = self._prepare_model_input(window)

        # Run the model.
        with torch.no_grad():
            # forward() takes the first four inputs as positional arguments, so order matters.
            outputs = self.model(
                model_input['x'],
                model_input['delta_t'],
                model_input['static_features'],
                model_input['known_future_features'],
                mask=model_input.get('mask'),
                return_contrastive_features=model_input.get('return_contrastive_features', False),
                source=None,
                return_domain_features=False,
                factor_features=model_input.get('factor_features')
            )
            anomaly_score = outputs['anomaly_score'].cpu().item()

        # A score at or above the threshold counts as an anomaly.
        is_anomaly = anomaly_score >= self.threshold

        result = {
            'is_anomaly': bool(is_anomaly),
            'threshold': float(self.threshold),
        }

        if return_score:
            result['anomaly_score'] = float(anomaly_score)

        if return_details:
            result['details'] = {
                'window_size': len(data_points),
                'model_output': float(anomaly_score),
                # Distance of the score from the threshold.
                'prediction_confidence': abs(anomaly_score - self.threshold),
            }

        return result

    def _prepare_model_input(self, window: Dict) -> Dict:
        """Prepare the model input."""
        input_features_list = []
        for feat in self.features:
            # Features missing from the window fall back to zeros over the 12 input steps.
            values = window['input_features'].get(feat, [0.0] * 12)
            input_features_list.append(values)

        # Convert to a tensor.
        input_features = torch.tensor(
            np.stack(input_features_list, axis=1),
            dtype=torch.float32
        ).unsqueeze(0).to(self.device)  # [1, 12, num_features]

        delta_t = torch.tensor(
            window['input_delta_t'],
            dtype=torch.float32
        ).unsqueeze(-1).unsqueeze(0).to(self.device)  # [1, 12, 1]

        # Static features.
        static_feature_values = []
        static_keys = self.static_feature_names or sorted(window['static_features'].keys())
        for key in static_keys:
            value = window['static_features'].get(key, 0.0)
            static_feature_values.append(float(value))

        if len(static_feature_values) == 0:
            static_feature_values = [0.0]

        static_features = torch.tensor(
            static_feature_values,
            dtype=torch.float32
        ).unsqueeze(0).to(self.device)  # [1, num_static]

        # Known future features.
        pred_len = len(window.get('target_timestamp', []))
        if pred_len == 0:
            pred_len = 6  # default prediction length

        known_future = torch.zeros(1, pred_len, self.known_future_dim, dtype=torch.float32).to(self.device)
        if 'known_future_features' in window:
            kf = window['known_future_features']
            for idx, cfg in enumerate(self.feature_calculator.known_future_defs):
                name = cfg['name']
                if name in kf:
                    series = kf[name][:pred_len]
                    if name == 'hour_of_day':
                        # Scale hours 0-23 to [0, 1].
                        values = torch.tensor([float(h) / 23.0 for h in series], dtype=torch.float32)
                    elif name == 'day_of_week':
                        # Scale days 0-6 to [0, 1].
                        values = torch.tensor([float(d) / 6.0 for d in series], dtype=torch.float32)
                    else:
                        values = torch.tensor([float(v) for v in series], dtype=torch.float32)
                    known_future[0, :len(series), idx] = values

        # Input mask (assumes every data point is valid).
        input_mask = torch.ones(1, 12, len(self.features), dtype=torch.float32).to(self.device)

        # Factor features.
        factor_features = None
        if window.get('factor_features'):
            factor_names = self.factor_metadata.get('factor_names', [])
            factor_dim = self.factor_metadata.get('factor_dim', 4)
            factor_vectors = []
            for name in factor_names:
                vec = window['factor_features'].get(name, [0.0] * factor_dim)
                factor_vectors.append(vec[:factor_dim])
            if factor_vectors:
                factor_features = torch.tensor(
                    factor_vectors,
                    dtype=torch.float32
                ).unsqueeze(0).to(self.device)  # [1, num_factors, factor_dim]

        return {
            'x': input_features,
            'delta_t': delta_t,
            'static_features': static_features,
            'known_future_features': known_future,
            'mask': input_mask,
            'factor_features': factor_features,
            'return_contrastive_features': False,
            'source': None,
            'return_domain_features': False,
        }

    def batch_predict(
        self,
        windows: List[List[Dict]],
        return_scores: bool = True
    ) -> List[Dict]:
        """
        Predict a batch of windows, one window at a time.

        Args:
            windows: List of windows; each window is a list of data points.
            return_scores: Whether to include the anomaly scores.

        Returns:
            List of prediction results, in the same order as the input windows.
        """
        results = []
        for window_data in windows:
            result = self.predict(window_data, return_score=return_scores)
            results.append(result)
        return results

    def update_threshold(self, threshold: float):
        """Update the anomaly threshold."""
        # Cast to float for consistency with how __init__ stores the threshold.
        self.threshold = float(threshold)
        print(f"✅ Threshold updated to: {self.threshold:.4f}")


def load_detector(model_dir: Union[str, Path], **kwargs) -> WearableAnomalyDetector:
    """
    Convenience function: load an anomaly detector.

    Args:
        model_dir: Path to the model directory.
        **kwargs: Other arguments (device, threshold, etc.).

    Returns:
        A WearableAnomalyDetector instance.
    """
    return WearableAnomalyDetector(model_dir, **kwargs)


if __name__ == '__main__':
    # Usage example.
    print("=" * 80)
    print("Wearable health anomaly detector - usage example")
    print("=" * 80)

    # Load the model.
    model_dir = Path(__file__).parent / 'checkpoints' / 'phase2' / 'exp_factor_balanced'
    detector = load_detector(model_dir)

    # Simulated data points (in real use these come from a live data stream).
    print("\nSimulating data points...")
    data_points = []
    base_time = datetime.now()

    # Use a real deviceId if the static feature table exists;
    # otherwise provide a complete set of static features.
    example_device_id = None
    static_dict = detector.feature_calculator.static_features_dict
    if static_dict:
        example_device_id = list(static_dict.keys())[0]
        print(f" Using example user ID: {example_device_id}")

    for i in range(12):
        data_point = {
            # Minutes 0, 5, ..., 55 of the current hour.
            'timestamp': base_time.replace(minute=i*5),
            'deviceId': example_device_id,  # lets the full static features be loaded
            'features': {
                'hr': 70.0 + np.random.randn() * 5,
                'hrv_rmssd': 30.0 + np.random.randn() * 3,
                # ... other features (simplified example; all 36 features are required in practice)
            },
            'static_features': {
                # Provide only some features and let the system fill in the rest
                # from the static feature table, or provide none and load them all from it.
            }
        }
        data_points.append(data_point)

    # Predict.
    result = detector.predict(data_points, return_score=True, return_details=True)

    print("\nPrediction result:")
    print(f" - Is anomaly: {result['is_anomaly']}")
    print(f" - Anomaly score: {result['anomaly_score']:.4f}")
    print(f" - Threshold: {result['threshold']:.4f}")
    if 'details' in result:
        print(f" - Details: {result['details']}")
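The threshold handling in the file above can be sketched on its own, without loading any model. This is a minimal standalone illustration, not part of the committed file: `resolve_threshold` and `classify` are hypothetical helper names. It assumes only what the code shows: an explicit threshold wins over the config value, 0.53 is the fallback, and a score at or above the threshold is flagged as an anomaly.

```python
from typing import Dict, Optional

DEFAULT_THRESHOLD = 0.53  # same fallback value as in WearableAnomalyDetector.__init__


def resolve_threshold(explicit: Optional[float], config: Dict) -> float:
    """Pick the threshold: explicit argument, then config, then the default."""
    if explicit is not None:
        return float(explicit)
    config_threshold = config.get('threshold')
    if config_threshold is not None:
        return float(config_threshold)
    return DEFAULT_THRESHOLD


def classify(score: float, threshold: float) -> Dict:
    """Apply the same comparison as predict(): score >= threshold means anomaly."""
    return {
        'is_anomaly': score >= threshold,
        'anomaly_score': float(score),
        'threshold': float(threshold),
        # Distance from the threshold, reported as 'prediction_confidence' in the details.
        'prediction_confidence': abs(score - threshold),
    }


if __name__ == '__main__':
    # Hypothetical values: no explicit threshold, config provides 0.4.
    threshold = resolve_threshold(None, {'threshold': 0.4})
    print(classify(0.61, threshold))
```

Note that `prediction_confidence` is only the absolute distance between the score and the threshold, so it is not a calibrated probability.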