File size: 3,311 Bytes
1886358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from torchvision import models, transforms
import torch.nn.functional as F
import librosa, soundfile as sf, tempfile
import numpy as np
import matplotlib.pyplot as plt
import librosa.display
from PIL import Image
import io
from feature_extract import AudioFeatureExtractor
import requests, os

# === CONFIG ===
# Location of the fine-tuned MobileNet weights on the Hugging Face Hub.
MODEL_REPO = "Chula-PD/voice-mobilenet-pd"
MODEL_FILE = "MobileNet_Model.pth"
MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === FastAPI Init ===
app = FastAPI(title="CheckPD Voice API", version="1.0")

# Allow CORS (for React frontend).
# Origins default to "*" (any site). Set the CORS_ALLOW_ORIGINS environment
# variable to a comma-separated list of origins to restrict access to
# specific domains.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_any = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if _cors_allow_any else _cors_origins,
    # Credentialed requests must not be combined with a wildcard origin, so
    # cookies/authorization are only allowed once explicit origins are set.
    allow_credentials=not _cors_allow_any,
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Load Model ===
def load_model():
    """Build the 2-class MobileNetV3-Small classifier and load its weights.

    If ``MODEL_FILE`` is not present in the working directory it is first
    downloaded from ``MODEL_URL``.

    Returns:
        The model, placed on ``device`` and switched to eval mode.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if not os.path.exists(MODEL_FILE):
        print("Downloading model weights from Hugging Face...")
        # (connect, read) timeouts so a stalled connection cannot hang startup.
        response = requests.get(MODEL_URL, timeout=(10, 120))
        # Fail loudly instead of saving an HTML error page as the weights file.
        response.raise_for_status()
        # Write to a side file and rename, so an interrupted download never
        # leaves a truncated MODEL_FILE that passes the exists() check above.
        partial_path = f"{MODEL_FILE}.part"
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, MODEL_FILE)

    model = models.mobilenet_v3_small(weights=None)
    # Swap the final classifier layer for a 2-output head.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features, 2)
    # NOTE(review): torch.load unpickles the file, and the file comes from the
    # network. Consider passing weights_only=True once the minimum supported
    # torch version allows it.
    model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
    # Inputs are moved to `device` at inference time, so the model must live
    # there too; otherwise a CUDA host raises a device-mismatch error.
    model.to(device)
    model.eval()
    return model

# Output index -> label name, in the order of the classifier's two outputs.
classes = ["HC", "PD"]

# Loaded once at import time and shared by every request.
model = load_model()

# === Image Transform ===
# Per-channel mean/std passed to Normalize (the values commonly used for
# ImageNet-pretrained torchvision models).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Resize to 224x224, convert to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocess_steps)

@app.get("/")
def home():
    """Return a static status message showing the API is reachable."""
    status = {"message": "CheckPD Voice API is running."}
    return status

def _mel_to_image(mel_db):
    """Render a mel spectrogram as an axis-free RGB PIL image held in memory."""
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        librosa.display.specshow(
            mel_db, sr=16000, hop_length=51, cmap="viridis", ax=ax
        )
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails; pyplot keeps
        # every open figure alive, which would leak memory across requests.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded voice recording as "HC" or "PD".

    The upload is written to a temporary file, turned into a mel spectrogram
    by ``AudioFeatureExtractor``, rendered to an image, and passed through
    the model.

    Args:
        file: The uploaded audio file.

    Returns:
        A dict with ``"label"`` (an entry of ``classes``) and ``"confidence"``
        (softmax probability of that label, rounded to 4 decimals).

    Raises:
        HTTPException: 400 if the upload is empty; 500 if any processing
            step fails (the detail carries the underlying error message).
    """
    wav_path = None
    try:
        audio_bytes = await file.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")

        # delete=False: the extractor re-opens the file by path after this
        # handle is closed; the file is removed in the finally block below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
            tmp.write(audio_bytes)

        # Load and preprocess audio
        extractor = AudioFeatureExtractor(wav_path, sr=16000)
        mel_db = extractor.get_melspectrogram()

        # Convert mel to image
        image = _mel_to_image(mel_db)

        # Predict
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = F.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_idx].item()

        return {
            "label": classes[pred_idx],
            "confidence": round(confidence, 4)
        }

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) from being
        # rewrapped as a 500 by the generic handler.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if wav_path is not None:
            try:
                os.remove(wav_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not mask the
                # response or the original error.
                pass