# Hugging Face Space -- status: Running
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| from torchvision import models, transforms | |
| import torch.nn.functional as F | |
| import librosa, soundfile as sf, tempfile | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import librosa.display | |
| from PIL import Image | |
| import io | |
| from feature_extract import AudioFeatureExtractor | |
| import requests, os | |
# === CONFIG ===
# Hugging Face repository and file name of the fine-tuned weights.
MODEL_REPO = "Chula-PD/voice-mobilenet-pd"
MODEL_FILE = "MobileNet_Model.pth"
# Direct-download URL for the weights file on the repository's main branch.
MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# === FastAPI Init ===
app = FastAPI(title="CheckPD Voice API", version="1.0")

# Allow CORS (for the React frontend).
# Allowed origins are read from the CORS_ALLOW_ORIGINS environment variable as
# a comma-separated list (e.g. "https://app.example.com,https://example.com").
# When it is unset or empty, every origin is allowed, as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed requests (cookies, Authorization headers) must not be
    # combined with a wildcard origin: that would let any website make
    # authenticated calls on behalf of a visitor. Credentials are therefore
    # enabled only once the origins have been restricted to explicit domains.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load Model ===
def load_model():
    """Build the MobileNetV3-Small classifier and load its fine-tuned weights.

    If ``MODEL_FILE`` is not present in the working directory it is first
    downloaded from ``MODEL_URL``. The network's final classifier layer is
    replaced by a 2-output linear layer before the weights are loaded.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.

    Raises:
        requests.exceptions.RequestException: if the download fails, times
            out, or the server answers with an HTTP error status.
    """
    if not os.path.exists(MODEL_FILE):
        print("Downloading model weights from Hugging Face...")
        # (connect timeout, read timeout) so a stalled connection cannot hang
        # application start-up forever.
        response = requests.get(MODEL_URL, timeout=(10, 120))
        # Without this check an HTTP error page would be saved as the weights
        # file and, because the file then exists, never be re-downloaded.
        response.raise_for_status()
        # Write to a side file and rename, so an interrupted write never
        # leaves a truncated MODEL_FILE behind.
        partial_path = MODEL_FILE + ".part"
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, MODEL_FILE)

    model = models.mobilenet_v3_small(weights=None)
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features, 2)

    # NOTE(review): torch.load unpickles the file; only load weights from a
    # trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the existing (CPU) parameters, so the
    # model must be moved explicitly to match inputs that are sent to `device`.
    model.to(device)
    model.eval()
    return model
# Load the classifier once at import time so every request reuses it.
model = load_model()
# Output labels, indexed by the model's predicted class index.
classes = ["HC", "PD"]

# === Image Transform ===
# Resize to 224x224, convert to a tensor, then normalize each channel.
# The mean/std values match the standard ImageNet statistics used with
# torchvision models.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# NOTE(review): the route path "/" is assumed from convention; confirm it
# matches what the frontend / health checks call.
@app.get("/")
def home():
    """Health-check endpoint: report that the API is up.

    Returns:
        A dict with a single ``"message"`` key.
    """
    return {"message": "CheckPD Voice API is running."}
def _mel_to_image(mel_db):
    """Render a mel spectrogram as an RGB PIL image with no axes or padding."""
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        librosa.display.specshow(mel_db, sr=16000, hop_length=51, cmap="viridis")
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure even when rendering fails; otherwise
        # pyplot keeps it alive and memory grows with every request.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


# NOTE(review): the route path "/predict" is inferred from the function name;
# confirm it matches what the frontend calls.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded voice recording.

    The upload is written to a temporary ``.wav`` file, converted to a mel
    spectrogram by ``AudioFeatureExtractor``, rendered to an image, and passed
    through the model.

    Args:
        file: The uploaded audio file.

    Returns:
        A dict with ``"label"`` (an entry of ``classes``) and ``"confidence"``
        (the softmax probability of that label, rounded to 4 decimals).

    Raises:
        HTTPException: 400 if the upload is empty; 500 if any processing or
            inference step fails.
    """
    wav_path = None
    try:
        payload = await file.read()
        if not payload:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")

        # Load and preprocess audio. The file is closed before it is read
        # back by path, which also works on platforms that forbid reopening
        # an open temporary file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
            tmp.write(payload)

        extractor = AudioFeatureExtractor(wav_path, sr=16000)
        mel_db = extractor.get_melspectrogram()

        # Convert mel to image
        image = _mel_to_image(mel_db)

        # Predict. Send the input to whichever device the model's parameters
        # actually live on, so the two can never disagree.
        model_device = next(model.parameters()).device
        input_tensor = transform(image).unsqueeze(0).to(model_device)
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = F.softmax(outputs, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()

        return {
            "label": classes[pred_idx],
            "confidence": round(confidence, 4)
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The temp file is created with delete=False, so it must be removed
        # explicitly or every request leaves a file behind.
        if wav_path is not None:
            try:
                os.remove(wav_path)
            except OSError:
                # Best-effort cleanup; a failed delete must not mask the
                # request's real result or error.
                pass