"""Tests for the ``/model/predict`` endpoint of the model router."""

import contextlib
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from app.auth.jwt import get_current_user
from fastapi import UploadFile  # noqa: F401  (kept: other parts of the file may use it)
from httpx import AsyncClient
from main import app

# NOTE: Run tests with: pytest tests/test_router_model.py -v

# Payload uploaded by every test; the content is never decoded because
# ``model_predict`` is patched out.
_FAKE_IMAGE_BYTES = b"fake-image-data"


@pytest.fixture
def override_current_user():
    """Replace the ``get_current_user`` dependency for a single test.

    The override is installed on the shared ``app`` object, so it is undone
    after the test; otherwise it would leak into every test that runs later
    against the same application instance.

    Yields:
        The ``MagicMock`` that the overridden dependency returns.
    """
    mock_current_user = MagicMock()
    mock_current_user.return_value = "testtoken"

    overrides = app.dependency_overrides
    had_previous = get_current_user in overrides
    previous = overrides.get(get_current_user)
    overrides[get_current_user] = lambda: mock_current_user

    yield mock_current_user

    # Put back whatever was registered before this test ran.
    if had_previous:
        overrides[get_current_user] = previous
    else:
        overrides.pop(get_current_user, None)


async def _post_image(upload_name):
    """POST the fake image to ``/model/predict`` and return the response.

    While the request is in flight the following names are patched: the file
    hash helper (returns ``"fakehash123"``), ``model_predict`` (returns
    ``("cat", 0.95)``), ``os.path.exists`` (returns ``False``) and
    ``builtins.open`` (a ``MagicMock``).

    Args:
        upload_name: File name sent in the multipart upload.

    Returns:
        The ``httpx`` response produced by the application.
    """
    with contextlib.ExitStack() as stack:
        # Same patch order as entering nested ``with`` blocks one by one.
        stack.enter_context(
            patch(
                "app.model.router.utils.get_file_hash",
                return_value="fakehash123",
            )
        )
        stack.enter_context(
            patch(
                "app.model.router.model_predict",
                new_callable=AsyncMock,
                return_value=("cat", 0.95),
            )
        )
        stack.enter_context(
            patch("app.model.router.os.path.exists", return_value=False)
        )
        stack.enter_context(patch("builtins.open", new_callable=MagicMock))

        # NOTE(review): newer httpx releases deprecate ``AsyncClient(app=...)``
        # in favour of ``transport=ASGITransport(app=...)``; kept unchanged to
        # match whatever version the project pins -- confirm before upgrading.
        async with AsyncClient(app=app, base_url="http://test") as ac:
            return await ac.post(
                "/model/predict",
                files={
                    "file": (upload_name, _FAKE_IMAGE_BYTES, "image/png"),
                },
                headers={"Authorization": "Bearer testtoken"},
            )


@pytest.mark.asyncio
@pytest.mark.usefixtures("override_current_user")
async def test_predict():
    """A ``.png`` upload returns the patched prediction, score and hash."""
    response = await _post_image("test_image.png")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["success"] is True
    assert response_data["prediction"] == "cat"
    assert response_data["score"] == 0.95
    assert response_data["image_file_name"] == "fakehash123"


@pytest.mark.asyncio
@pytest.mark.usefixtures("override_current_user")
async def test_predict_fails_bad_extension():
    """A ``.pdf`` upload is rejected with HTTP 400 and an explanatory detail."""
    response = await _post_image("test_image.pdf")

    assert response.status_code == 400
    assert response.json() == {"detail": "File type is not supported."}