Commit
Β·
6ab520d
1
Parent(s):
457c02a
chore: Add API service files
Browse files- .env.original +6 -0
- api/.env.original +4 -0
- api/Dockerfile +18 -0
- api/Dockerfile.populate +16 -0
- api/__init__.py +0 -0
- api/app/auth/__init__.py +0 -0
- api/app/auth/jwt.py +90 -0
- api/app/auth/router.py +29 -0
- api/app/auth/schema.py +17 -0
- api/app/db.py +38 -0
- api/app/feedback/models.py +24 -0
- api/app/feedback/router.py +28 -0
- api/app/feedback/schema.py +19 -0
- api/app/feedback/services.py +66 -0
- api/app/model/router.py +45 -0
- api/app/model/schema.py +12 -0
- api/app/model/services.py +61 -0
- api/app/settings.py +34 -0
- api/app/user/__init__.py +0 -0
- api/app/user/hashing.py +37 -0
- api/app/user/models.py +23 -0
- api/app/user/router.py +51 -0
- api/app/user/schema.py +16 -0
- api/app/user/services.py +84 -0
- api/app/user/validator.py +22 -0
- api/app/utils.py +59 -0
- api/compose.yml +16 -0
- api/main.py +12 -0
- api/populate_db.py +72 -0
- api/requirements.txt +77 -0
- api/tests/__init__.py +0 -0
- api/tests/dog.jpeg +3 -0
- api/tests/test_router_feedback.py +84 -0
- api/tests/test_router_model.py +96 -0
- api/tests/test_router_user.py +69 -0
- api/tests/test_utils.py +38 -0
- compose.yml +0 -1
.env.original
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Test variables
|
| 2 |
+
POSTGRES_DB=sp3
|
| 3 |
+
POSTGRES_USER=postgres
|
| 4 |
+
POSTGRES_PASSWORD=adlibitum
|
| 5 |
+
DATABASE_HOST=db
|
| 6 |
+
SECRET_KEY=S09WWWHXBAJDIUEREHCN3752346572452VGGGVWWW526194
|
api/.env.original
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
POSTGRES_USER=postgres
|
| 2 |
+
POSTGRES_PASSWORD=adlibitum
|
| 3 |
+
DATABASE_HOST=db
|
| 4 |
+
POSTGRES_DB=sp3
|
api/Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.8.13 AS base
|
| 2 |
+
|
| 3 |
+
ENV PYTHONPATH=$PYTHONPATH:/src/
|
| 4 |
+
|
| 5 |
+
COPY ./requirements.txt /src/requirements.txt
|
| 6 |
+
|
| 7 |
+
WORKDIR /src
|
| 8 |
+
|
| 9 |
+
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 10 |
+
|
| 11 |
+
COPY ./ /src/
|
| 12 |
+
|
| 13 |
+
FROM base AS test
|
| 14 |
+
RUN ["python", "-m", "pytest", "-v", "/src/tests"]
|
| 15 |
+
|
| 16 |
+
FROM base AS build
|
| 17 |
+
|
| 18 |
+
CMD gunicorn -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:5000 main:app
|
api/Dockerfile.populate
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.8.13
|
| 2 |
+
|
| 3 |
+
# 1. Copy the requirements.txt file to the image
|
| 4 |
+
ENV PYTHONPATH=$PYTHONPATH:/src/
|
| 5 |
+
COPY ./requirements.txt /src/requirements.txt
|
| 6 |
+
|
| 7 |
+
WORKDIR /src
|
| 8 |
+
|
| 9 |
+
# 2. Install the dependencies
|
| 10 |
+
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 11 |
+
|
| 12 |
+
# 3. Copy the content of the current directory to the image
|
| 13 |
+
COPY ./ /src/
|
| 14 |
+
|
| 15 |
+
# 4. Run the populate_db.py script
|
| 16 |
+
CMD ["python", "populate_db.py"]
|
api/__init__.py
ADDED
|
File without changes
|
api/app/auth/__init__.py
ADDED
|
File without changes
|
api/app/auth/jwt.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
|
| 3 |
+
from app.settings import SECRET_KEY
|
| 4 |
+
from fastapi import Depends, HTTPException, status
|
| 5 |
+
from fastapi.security import OAuth2PasswordBearer
|
| 6 |
+
from jose import JWTError, jwt
|
| 7 |
+
|
| 8 |
+
from . import schema
|
| 9 |
+
|
| 10 |
+
ALGORITHM = "HS256"
|
| 11 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_access_token(data: dict) -> str:
|
| 15 |
+
"""
|
| 16 |
+
Generates a JWT access token with an expiration time.
|
| 17 |
+
|
| 18 |
+
This function creates a JWT (JSON Web Token) that includes the provided
|
| 19 |
+
data and an expiration time. The token is signed using a secret key and
|
| 20 |
+
a specified algorithm.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
data (dict): A dictionary containing the data to be included in the token.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
str: The encoded JWT token as a string.
|
| 27 |
+
"""
|
| 28 |
+
to_encode = data.copy()
|
| 29 |
+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 30 |
+
to_encode.update({"exp": expire})
|
| 31 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
| 32 |
+
return encoded_jwt
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def verify_token(token: str, credentials_exception):
|
| 36 |
+
"""
|
| 37 |
+
Verifies the provided JWT token and extracts the user information.
|
| 38 |
+
|
| 39 |
+
This function decodes the given JWT token using the secret key and specified
|
| 40 |
+
algorithm. It checks for the presence of the user's email in the token's payload.
|
| 41 |
+
If the email is not found or the token is invalid, an exception is raised.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
token (str): The JWT token to be verified.
|
| 45 |
+
credentials_exception: The exception to raise if the token is invalid or the email is not found.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
TokenData: An object containing the user's email extracted from the token.
|
| 49 |
+
|
| 50 |
+
Raises:
|
| 51 |
+
credentials_exception: If the token is invalid or does not contain an email.
|
| 52 |
+
"""
|
| 53 |
+
try:
|
| 54 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
| 55 |
+
email: str = payload.get("sub")
|
| 56 |
+
if email is None:
|
| 57 |
+
raise credentials_exception
|
| 58 |
+
token_data = schema.TokenData(email=email)
|
| 59 |
+
except JWTError:
|
| 60 |
+
raise credentials_exception
|
| 61 |
+
return token_data
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_current_user(token: str = Depends(oauth2_scheme)):
|
| 68 |
+
"""
|
| 69 |
+
Retrieves the current authenticated user based on the provided JWT token.
|
| 70 |
+
|
| 71 |
+
This function extracts the JWT token from the request, verifies it, and returns
|
| 72 |
+
the user data associated with the token. If the token is invalid or cannot be
|
| 73 |
+
verified, an HTTP 401 Unauthorized exception is raised.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
token (str, optional): The JWT token extracted from the request using the OAuth2
|
| 77 |
+
password flow. Defaults to being fetched via `Depends(oauth2_scheme)`.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
TokenData: An object containing the user's information extracted from the token.
|
| 81 |
+
|
| 82 |
+
Raises:
|
| 83 |
+
HTTPException: If the token is invalid or the credentials cannot be validated.
|
| 84 |
+
"""
|
| 85 |
+
credentials_exception = HTTPException(
|
| 86 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 87 |
+
detail="Could not validate credentials",
|
| 88 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 89 |
+
)
|
| 90 |
+
return verify_token(token, credentials_exception)
|
api/app/auth/router.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app import db
|
| 2 |
+
from app.user import hashing
|
| 3 |
+
from app.user.models import User
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 5 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
| 6 |
+
from sqlalchemy.orm import Session
|
| 7 |
+
|
| 8 |
+
from .jwt import create_access_token
|
| 9 |
+
|
| 10 |
+
router = APIRouter(tags=["auth"])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@router.post("/login")
|
| 14 |
+
def login(
|
| 15 |
+
request: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(db.get_db)
|
| 16 |
+
):
|
| 17 |
+
user = db.query(User).filter(User.email == request.username).first()
|
| 18 |
+
|
| 19 |
+
if not user:
|
| 20 |
+
raise HTTPException(
|
| 21 |
+
status_code=status.HTTP_404_NOT_FOUND, detail="Invalid credentials"
|
| 22 |
+
)
|
| 23 |
+
if not hashing.verify_password(request.password, user.password):
|
| 24 |
+
raise HTTPException(
|
| 25 |
+
status_code=status.HTTP_404_NOT_FOUND, detail="Incorrect password"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
access_token = create_access_token(data={"sub": user.email})
|
| 29 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
api/app/auth/schema.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Login(BaseModel):
|
| 7 |
+
username: str
|
| 8 |
+
password: str
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Token(BaseModel):
|
| 12 |
+
access_token: str
|
| 13 |
+
token_type: str
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TokenData(BaseModel):
|
| 17 |
+
email: Optional[str] = None
|
api/app/db.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app import settings as config
|
| 2 |
+
from sqlalchemy import create_engine
|
| 3 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 4 |
+
from sqlalchemy.orm import sessionmaker
|
| 5 |
+
|
| 6 |
+
DATABASE_USERNAME = config.DATABASE_USERNAME
|
| 7 |
+
DATABASE_PASSWORD = config.DATABASE_PASSWORD
|
| 8 |
+
DATABASE_HOST = config.DATABASE_HOST
|
| 9 |
+
DATABASE_NAME = config.DATABASE_NAME
|
| 10 |
+
|
| 11 |
+
SQLALCHEMY_DATABASE_URL = f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}@{DATABASE_HOST}/{DATABASE_NAME}"
|
| 12 |
+
|
| 13 |
+
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
| 14 |
+
|
| 15 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 16 |
+
|
| 17 |
+
Base = declarative_base()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_db():
|
| 21 |
+
"""
|
| 22 |
+
Provides a database session for dependency injection.
|
| 23 |
+
|
| 24 |
+
This function is used to obtain a new database session instance from the
|
| 25 |
+
`SessionLocal` factory. It is intended to be used with dependency injection
|
| 26 |
+
in FastAPI to manage database sessions.
|
| 27 |
+
|
| 28 |
+
Yields:
|
| 29 |
+
Session: A SQLAlchemy database session.
|
| 30 |
+
|
| 31 |
+
Notes:
|
| 32 |
+
The session is automatically closed after use to ensure proper resource management.
|
| 33 |
+
"""
|
| 34 |
+
db = SessionLocal()
|
| 35 |
+
try:
|
| 36 |
+
yield db
|
| 37 |
+
finally:
|
| 38 |
+
db.close()
|
api/app/feedback/models.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.db import Base
|
| 2 |
+
from sqlalchemy import Column, Float, ForeignKey, Integer, String
|
| 3 |
+
from sqlalchemy.orm import relationship
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Feedback(Base):
|
| 7 |
+
__tablename__ = "feedbacks"
|
| 8 |
+
|
| 9 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 10 |
+
score = Column(Float)
|
| 11 |
+
predicted_class = Column(String(50))
|
| 12 |
+
feedback = Column(String(255))
|
| 13 |
+
user_id = Column(Integer, ForeignKey("users.id"))
|
| 14 |
+
image_file_name = Column(String(255))
|
| 15 |
+
user = relationship("User", back_populates="feedbacks")
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self, score, predicted_class, feedback, image_file_name, user, *args, **kwargs
|
| 19 |
+
):
|
| 20 |
+
self.predicted_class = predicted_class
|
| 21 |
+
self.feedback = feedback
|
| 22 |
+
self.score = score
|
| 23 |
+
self.image_file_name = image_file_name
|
| 24 |
+
self.user = user
|
api/app/feedback/router.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from app import db
|
| 4 |
+
from app.auth.jwt import get_current_user
|
| 5 |
+
from app.user.schema import User
|
| 6 |
+
from fastapi import APIRouter, Depends, status
|
| 7 |
+
from sqlalchemy.orm import Session
|
| 8 |
+
|
| 9 |
+
from . import schema, services
|
| 10 |
+
|
| 11 |
+
router = APIRouter(tags=["Feedback"], prefix="/feedback")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@router.post("/", status_code=status.HTTP_201_CREATED)
|
| 15 |
+
async def create_feedback(
|
| 16 |
+
request: schema.Feedback,
|
| 17 |
+
database: Session = Depends(db.get_db),
|
| 18 |
+
current_user: User = Depends(get_current_user),
|
| 19 |
+
):
|
| 20 |
+
return await services.new_feedback(request, current_user, database)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@router.get("/", response_model=List[schema.DisplayFeedback])
|
| 24 |
+
async def get_all_feedback(
|
| 25 |
+
database: Session = Depends(db.get_db),
|
| 26 |
+
current_user: User = Depends(get_current_user),
|
| 27 |
+
):
|
| 28 |
+
return await services.all_feedback(database, current_user)
|
api/app/feedback/schema.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Feedback(BaseModel):
|
| 5 |
+
score: float
|
| 6 |
+
predicted_class: str
|
| 7 |
+
image_file_name: str
|
| 8 |
+
feedback: str
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DisplayFeedback(BaseModel):
|
| 12 |
+
id: int
|
| 13 |
+
score: float
|
| 14 |
+
predicted_class: str
|
| 15 |
+
image_file_name: str
|
| 16 |
+
feedback: str
|
| 17 |
+
|
| 18 |
+
class Config:
|
| 19 |
+
orm_mode = True
|
api/app/feedback/services.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.auth.schema import TokenData
|
| 2 |
+
from app.user.models import User
|
| 3 |
+
from sqlalchemy.orm import Session
|
| 4 |
+
|
| 5 |
+
from . import models, schema
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
async def new_feedback(
|
| 9 |
+
request: schema.Feedback, current_user: TokenData, database: Session
|
| 10 |
+
) -> models.Feedback:
|
| 11 |
+
"""
|
| 12 |
+
Adds new feedback to the database associated with the current user.
|
| 13 |
+
|
| 14 |
+
This asynchronous function creates a new feedback entry in the database using
|
| 15 |
+
the provided feedback data and associates it with the current user. It first
|
| 16 |
+
retrieves the user from the database based on the email in the `current_user` object,
|
| 17 |
+
then creates and stores the new feedback entry.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
request (schema.Feedback): An object containing the feedback details such as score,
|
| 21 |
+
image file name, predicted class, and feedback text.
|
| 22 |
+
current_user (TokenData): An object containing the email of the currently authenticated user.
|
| 23 |
+
database (Session): The database session used for querying and committing changes to the database.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
models.Feedback: The newly created feedback entry stored in the database.
|
| 27 |
+
|
| 28 |
+
Raises:
|
| 29 |
+
Exception: If there is an issue with adding or committing the feedback to the database.
|
| 30 |
+
"""
|
| 31 |
+
user = database.query(User).filter(User.email == current_user.email).first()
|
| 32 |
+
new_feedback = models.Feedback(
|
| 33 |
+
score=request.score,
|
| 34 |
+
image_file_name=request.image_file_name,
|
| 35 |
+
predicted_class=request.predicted_class,
|
| 36 |
+
user=user,
|
| 37 |
+
feedback=request.feedback,
|
| 38 |
+
)
|
| 39 |
+
database.add(new_feedback)
|
| 40 |
+
database.commit()
|
| 41 |
+
database.refresh(new_feedback)
|
| 42 |
+
return new_feedback
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
async def all_feedback(database: Session, current_user: TokenData) -> models.Feedback:
|
| 46 |
+
"""
|
| 47 |
+
Retrieves all feedback entries associated with the current user from the database.
|
| 48 |
+
|
| 49 |
+
This asynchronous function queries the database for all feedback entries linked to
|
| 50 |
+
the user identified by the `current_user` object. It returns a list of feedback entries
|
| 51 |
+
associated with the user's ID.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
database (Session): The database session used for querying the database.
|
| 55 |
+
current_user (TokenData): An object containing the email of the currently authenticated user.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
list[models.Feedback]: A list of feedback entries associated with the current user.
|
| 59 |
+
|
| 60 |
+
Raises:
|
| 61 |
+
Exception: If there is an issue with querying the feedback entries from the database.
|
| 62 |
+
"""
|
| 63 |
+
user = database.query(User).filter(User.email == current_user.email).first()
|
| 64 |
+
return (
|
| 65 |
+
database.query(models.Feedback).filter(models.Feedback.user_id == user.id).all()
|
| 66 |
+
)
|
api/app/model/router.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from app import settings as config
|
| 4 |
+
from app import utils
|
| 5 |
+
from app.auth.jwt import get_current_user
|
| 6 |
+
from app.model.schema import PredictResponse
|
| 7 |
+
from app.model.services import model_predict
|
| 8 |
+
from fastapi import APIRouter, Depends, HTTPException, UploadFile, status # File
|
| 9 |
+
|
| 10 |
+
router = APIRouter(tags=["Model"], prefix="/model")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@router.post("/predict")
|
| 14 |
+
async def predict(file: UploadFile, current_user=Depends(get_current_user)):
|
| 15 |
+
rpse = {"success": False, "prediction": None, "score": None}
|
| 16 |
+
|
| 17 |
+
# Check a file was sent and that file is an image
|
| 18 |
+
if not file or not utils.allowed_file(file.filename):
|
| 19 |
+
raise HTTPException(
|
| 20 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 21 |
+
detail="File type is not supported.",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Store the image to disk, calculate hash before to avoid re-writing an image already uploaded.
|
| 25 |
+
new_filename = await utils.get_file_hash(file)
|
| 26 |
+
file_path = os.path.join(config.UPLOAD_FOLDER, new_filename)
|
| 27 |
+
|
| 28 |
+
if not os.path.exists(file_path):
|
| 29 |
+
with open(file_path, "wb") as out_file:
|
| 30 |
+
content = await file.read()
|
| 31 |
+
out_file.write(content)
|
| 32 |
+
|
| 33 |
+
# Reset file pointer to the beginning
|
| 34 |
+
await file.seek(0)
|
| 35 |
+
|
| 36 |
+
# Send the file to be processed by the model service
|
| 37 |
+
prediction, score = await model_predict(file_path)
|
| 38 |
+
|
| 39 |
+
# Update and return rpse dict with the corresponding values
|
| 40 |
+
rpse["success"] = True
|
| 41 |
+
rpse["prediction"] = prediction
|
| 42 |
+
rpse["score"] = score
|
| 43 |
+
rpse["image_file_name"] = new_filename
|
| 44 |
+
|
| 45 |
+
return PredictResponse(**rpse)
|
api/app/model/schema.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class PredictRequest(BaseModel):
|
| 5 |
+
file: str
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PredictResponse(BaseModel):
|
| 9 |
+
success: bool
|
| 10 |
+
prediction: str
|
| 11 |
+
score: float
|
| 12 |
+
image_file_name: str
|
api/app/model/services.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
|
| 5 |
+
import redis
|
| 6 |
+
|
| 7 |
+
from .. import settings
|
| 8 |
+
|
| 9 |
+
# Connect to Redis
|
| 10 |
+
db = redis.Redis(
|
| 11 |
+
host=settings.REDIS_IP, port=settings.REDIS_PORT, db=settings.REDIS_DB_ID
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def model_predict(image_name):
|
| 16 |
+
print(f"Processing image {image_name}...")
|
| 17 |
+
"""
|
| 18 |
+
Receives an image name and queues the job into Redis.
|
| 19 |
+
Will loop until getting the answer from our ML service.
|
| 20 |
+
|
| 21 |
+
Parameters
|
| 22 |
+
----------
|
| 23 |
+
image_name : str
|
| 24 |
+
Name for the image uploaded by the user.
|
| 25 |
+
|
| 26 |
+
Returns
|
| 27 |
+
-------
|
| 28 |
+
prediction, score : tuple(str, float)
|
| 29 |
+
Model predicted class as a string and the corresponding confidence
|
| 30 |
+
score as a number.
|
| 31 |
+
"""
|
| 32 |
+
prediction = None
|
| 33 |
+
score = None
|
| 34 |
+
|
| 35 |
+
# Assign an unique ID for this job and add it to the queue.
|
| 36 |
+
job_id = str(uuid4())
|
| 37 |
+
|
| 38 |
+
# Create a dict with the job data we will send through Redis
|
| 39 |
+
job_data = {"id": job_id, "image_name": image_name}
|
| 40 |
+
|
| 41 |
+
# Send the job to the model service using Redis
|
| 42 |
+
db.lpush(settings.REDIS_QUEUE, json.dumps(job_data))
|
| 43 |
+
|
| 44 |
+
# Loop until we received the response from our ML model
|
| 45 |
+
while True:
|
| 46 |
+
# Attempt to get model predictions using job_id
|
| 47 |
+
output = db.get(job_id)
|
| 48 |
+
|
| 49 |
+
# Check if the text was correctly processed by the ML model
|
| 50 |
+
if output is not None:
|
| 51 |
+
output = json.loads(output.decode("utf-8"))
|
| 52 |
+
prediction = output["prediction"]
|
| 53 |
+
score = output["score"]
|
| 54 |
+
|
| 55 |
+
db.delete(job_id)
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
# Sleep some time waiting for model results
|
| 59 |
+
time.sleep(settings.API_SLEEP)
|
| 60 |
+
|
| 61 |
+
return prediction, score
|
api/app/settings.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import dotenv
|
| 4 |
+
|
| 5 |
+
# Load environment variables from .env file
|
| 6 |
+
dotenv.load_dotenv()
|
| 7 |
+
|
| 8 |
+
# Run API in Debug mode
|
| 9 |
+
API_DEBUG = True
|
| 10 |
+
|
| 11 |
+
# We will store images uploaded by the user on this folder
|
| 12 |
+
UPLOAD_FOLDER = "uploads/"
|
| 13 |
+
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
| 14 |
+
|
| 15 |
+
# REDIS settings
|
| 16 |
+
|
| 17 |
+
# Queue name
|
| 18 |
+
REDIS_QUEUE = "service_queue"
|
| 19 |
+
# Port
|
| 20 |
+
REDIS_PORT = 6379
|
| 21 |
+
# DB Id
|
| 22 |
+
REDIS_DB_ID = 0
|
| 23 |
+
# Host IP
|
| 24 |
+
REDIS_IP = os.getenv("REDIS_IP", "redis")
|
| 25 |
+
# Sleep parameters which manages the
|
| 26 |
+
# interval between requests to our redis queue
|
| 27 |
+
API_SLEEP = 0.05
|
| 28 |
+
|
| 29 |
+
# Database settings
|
| 30 |
+
DATABASE_USERNAME = os.getenv("POSTGRES_USER")
|
| 31 |
+
DATABASE_PASSWORD = os.getenv("POSTGRES_PASSWORD")
|
| 32 |
+
DATABASE_HOST = os.getenv("DATABASE_HOST")
|
| 33 |
+
DATABASE_NAME = os.getenv("POSTGRES_DB")
|
| 34 |
+
SECRET_KEY = os.getenv("SECRET_KEY", "S09WWWHXBAJDIUEREHCN3752346572452VGGGVWWW526194")
|
api/app/user/__init__.py
ADDED
|
File without changes
|
api/app/user/hashing.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from passlib.context import CryptContext
|
| 2 |
+
|
| 3 |
+
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def verify_password(plain_password, hashed_password):
|
| 7 |
+
"""
|
| 8 |
+
Verifies if the provided plain password matches the hashed password.
|
| 9 |
+
|
| 10 |
+
This function compares a plain password with a hashed password to check if they
|
| 11 |
+
match using the Argon2 hashing algorithm. It returns `True` if the passwords match,
|
| 12 |
+
otherwise returns `False`.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
plain_password (str): The plain text password to be verified.
|
| 16 |
+
hashed_password (str): The hashed password to compare against.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
bool: `True` if the plain password matches the hashed password, otherwise `False`.
|
| 20 |
+
"""
|
| 21 |
+
return pwd_context.verify(plain_password, hashed_password)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_password_hash(password):
|
| 25 |
+
"""
|
| 26 |
+
Hashes the provided password using the Argon2 algorithm.
|
| 27 |
+
|
| 28 |
+
This function generates a hashed version of the given plain text password using
|
| 29 |
+
the Argon2 hashing algorithm. The resulting hash can be used for secure password storage.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
password (str): The plain text password to be hashed.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
str: The hashed password.
|
| 36 |
+
"""
|
| 37 |
+
return pwd_context.hash(password)
|
api/app/user/models.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.db import Base
|
| 2 |
+
from sqlalchemy import Column, Integer, String
|
| 3 |
+
from sqlalchemy.orm import relationship
|
| 4 |
+
|
| 5 |
+
from . import hashing
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class User(Base):
|
| 9 |
+
__tablename__ = "users"
|
| 10 |
+
|
| 11 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 12 |
+
name = Column(String(50))
|
| 13 |
+
email = Column(String(255), unique=True)
|
| 14 |
+
password = Column(String(255))
|
| 15 |
+
feedbacks = relationship("Feedback", back_populates="user")
|
| 16 |
+
|
| 17 |
+
def __init__(self, name, email, password, *args, **kwargs):
|
| 18 |
+
self.name = name
|
| 19 |
+
self.email = email
|
| 20 |
+
self.password = hashing.get_password_hash(password)
|
| 21 |
+
|
| 22 |
+
def check_password(self, password):
|
| 23 |
+
return hashing.verify_password(self.password, password)
|
api/app/user/router.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from typing import List
|
| 2 |
+
from app import db
|
| 3 |
+
from app.auth.jwt import get_current_user
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 5 |
+
from sqlalchemy.orm import Session
|
| 6 |
+
|
| 7 |
+
from . import schema, services, validator
|
| 8 |
+
|
| 9 |
+
router = APIRouter(tags=["Users"], prefix="/user")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@router.post("/", status_code=status.HTTP_201_CREATED)
|
| 13 |
+
async def create_user_registration(
|
| 14 |
+
request: schema.User, database: Session = Depends(db.get_db)
|
| 15 |
+
):
|
| 16 |
+
# Verify the user email doesn't already exist
|
| 17 |
+
if await validator.verify_email_exist(email=request.email, database=database):
|
| 18 |
+
raise HTTPException(
|
| 19 |
+
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# If the email doesn't exist, create a new user
|
| 23 |
+
new_user = await services.new_user_register(request=request, database=database)
|
| 24 |
+
|
| 25 |
+
return new_user
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@router.get("/")
|
| 29 |
+
async def get_all_users(
|
| 30 |
+
database: Session = Depends(db.get_db),
|
| 31 |
+
current_user: schema.User = Depends(get_current_user),
|
| 32 |
+
):
|
| 33 |
+
return await services.all_users(database)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@router.get("/{id}", response_model=schema.DisplayUser)
|
| 37 |
+
async def get_user_by_id(
|
| 38 |
+
id: int,
|
| 39 |
+
database: Session = Depends(db.get_db),
|
| 40 |
+
current_user: schema.User = Depends(get_current_user),
|
| 41 |
+
):
|
| 42 |
+
return await services.get_user_by_id(id, database)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
|
| 46 |
+
async def delete_user_by_id(
|
| 47 |
+
id: int,
|
| 48 |
+
database: Session = Depends(db.get_db),
|
| 49 |
+
current_user: schema.User = Depends(get_current_user),
|
| 50 |
+
):
|
| 51 |
+
return await services.delete_user_by_id(id, database)
|
api/app/user/schema.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, EmailStr, constr
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class User(BaseModel):
|
| 5 |
+
name: constr(min_length=2, max_length=50)
|
| 6 |
+
email: EmailStr
|
| 7 |
+
password: str
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DisplayUser(BaseModel):
|
| 11 |
+
id: int
|
| 12 |
+
name: str
|
| 13 |
+
email: str
|
| 14 |
+
|
| 15 |
+
class Config:
|
| 16 |
+
orm_mode = True
|
api/app/user/services.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import HTTPException, status
|
| 2 |
+
from sqlalchemy.orm import Session
|
| 3 |
+
|
| 4 |
+
from . import models, schema
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
async def new_user_register(request: schema.User, database: Session) -> models.User:
|
| 8 |
+
"""
|
| 9 |
+
Registers a new user in the database.
|
| 10 |
+
|
| 11 |
+
This asynchronous function creates a new user entry in the database using the provided
|
| 12 |
+
user details from the request. It adds the user to the database, commits the changes,
|
| 13 |
+
and returns the newly created user.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
request (schema.User): An object containing user details such as name, email, and password.
|
| 17 |
+
database (Session): The database session used for adding and committing the user to the database.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
models.User: The newly created user entry stored in the database.
|
| 21 |
+
"""
|
| 22 |
+
new_user = models.User(
|
| 23 |
+
name=request.name, email=request.email, password=request.password
|
| 24 |
+
)
|
| 25 |
+
database.add(new_user)
|
| 26 |
+
database.commit()
|
| 27 |
+
database.refresh(new_user)
|
| 28 |
+
return new_user
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
async def all_users(database: Session) -> models.User:
|
| 32 |
+
"""
|
| 33 |
+
Retrieves all users from the database.
|
| 34 |
+
|
| 35 |
+
This asynchronous function queries the database to retrieve a list of all users.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
database (Session): The database session used for querying the database.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
list[models.User]: A list of all user entries in the database.
|
| 42 |
+
"""
|
| 43 |
+
return database.query(models.User).all()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
async def get_user_by_id(id: int, database: Session) -> models.User:
|
| 47 |
+
"""
|
| 48 |
+
Retrieves a user from the database by their ID.
|
| 49 |
+
|
| 50 |
+
This asynchronous function queries the database for a user with the specified ID.
|
| 51 |
+
If the user is not found, an HTTP 404 Not Found exception is raised.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
id (int): The ID of the user to retrieve.
|
| 55 |
+
database (Session): The database session used for querying the database.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
models.User: The user entry with the specified ID.
|
| 59 |
+
|
| 60 |
+
Raises:
|
| 61 |
+
HTTPException: If the user with the specified ID is not found.
|
| 62 |
+
"""
|
| 63 |
+
user = database.query(models.User).filter(models.User.id == id).first()
|
| 64 |
+
if not user:
|
| 65 |
+
raise HTTPException(
|
| 66 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 67 |
+
detail=f"User with the id {id} is not available",
|
| 68 |
+
)
|
| 69 |
+
return user
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def delete_user_by_id(id: int, database: Session):
|
| 73 |
+
"""
|
| 74 |
+
Deletes a user from the database by their ID.
|
| 75 |
+
|
| 76 |
+
This asynchronous function removes the user with the specified ID from the database
|
| 77 |
+
and commits the changes.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
id (int): The ID of the user to delete.
|
| 81 |
+
database (Session): The database session used for querying and committing changes to the database.
|
| 82 |
+
"""
|
| 83 |
+
database.query(models.User).filter(models.User.id == id).delete()
|
| 84 |
+
database.commit()
|
api/app/user/validator.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from sqlalchemy.orm import Session
|
| 4 |
+
|
| 5 |
+
from .models import User
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
async def verify_email_exist(email: str, database: Session) -> Optional[User]:
|
| 9 |
+
"""
|
| 10 |
+
Checks if a user with the specified email exists in the database.
|
| 11 |
+
|
| 12 |
+
This asynchronous function queries the database to find a user with the given email.
|
| 13 |
+
It returns the user object if found, or `None` if no user with that email exists.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
email (str): The email address to check for existence.
|
| 17 |
+
database (Session): The database session used for querying the database.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Optional[User]: The user object if a user with the specified email exists, otherwise `None`.
|
| 21 |
+
"""
|
| 22 |
+
return database.query(User).filter(User.email == email).first()
|
api/app/utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def allowed_file(filename):
|
| 6 |
+
"""
|
| 7 |
+
Checks if the format for the file received is acceptable. For this
|
| 8 |
+
particular case, we must accept only image files. This is, files with
|
| 9 |
+
extension ".png", ".jpg", ".jpeg" or ".gif".
|
| 10 |
+
|
| 11 |
+
Parameters
|
| 12 |
+
----------
|
| 13 |
+
filename : str
|
| 14 |
+
Filename from werkzeug.datastructures.FileStorage file.
|
| 15 |
+
|
| 16 |
+
Returns
|
| 17 |
+
-------
|
| 18 |
+
bool
|
| 19 |
+
True if the file is an image, False otherwise.
|
| 20 |
+
"""
|
| 21 |
+
# Check if the file extension of the filename received is in the set of allowed extensions (".png", ".jpg", ".jpeg", ".gif")
|
| 22 |
+
allowed_extensions = [".png", ".jpg", ".jpeg", ".gif"]
|
| 23 |
+
|
| 24 |
+
return (
|
| 25 |
+
os.path.splitext(filename)[1].lower() in allowed_extensions
|
| 26 |
+
) # Alternatively use Path.suffix
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
async def get_file_hash(file):
|
| 30 |
+
"""
|
| 31 |
+
Returns a new filename based on the file content using MD5 hashing.
|
| 32 |
+
It uses hashlib.md5() function from Python standard library to get
|
| 33 |
+
the hash.
|
| 34 |
+
|
| 35 |
+
Parameters
|
| 36 |
+
----------
|
| 37 |
+
file : werkzeug.datastructures.FileStorage
|
| 38 |
+
File sent by user.
|
| 39 |
+
|
| 40 |
+
Returns
|
| 41 |
+
-------
|
| 42 |
+
str
|
| 43 |
+
New filename based in md5 file hash.
|
| 44 |
+
"""
|
| 45 |
+
# Read file content (byte string)
|
| 46 |
+
content = await file.read()
|
| 47 |
+
|
| 48 |
+
# Generate MD5 hash from content
|
| 49 |
+
file_hash = hashlib.md5( # Apply MD5 hash
|
| 50 |
+
content
|
| 51 |
+
).hexdigest() # hexdigest converts the result into a readable hex string
|
| 52 |
+
|
| 53 |
+
# Get the file extension
|
| 54 |
+
_, ext = os.path.splitext(file.filename)
|
| 55 |
+
|
| 56 |
+
# Reset file pointer to the beginning
|
| 57 |
+
await file.seek(0)
|
| 58 |
+
|
| 59 |
+
return f"{file_hash}{ext}"
|
api/compose.yml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
app:
|
| 3 |
+
build:
|
| 4 |
+
context: .
|
| 5 |
+
dockerfile: Dockerfile.populate
|
| 6 |
+
environment:
|
| 7 |
+
POSTGRES_DB: $POSTGRES_DB
|
| 8 |
+
POSTGRES_USER: $POSTGRES_USER
|
| 9 |
+
POSTGRES_PASSWORD: $POSTGRES_PASSWORD
|
| 10 |
+
DATABASE_HOST: $DATABASE_HOST
|
| 11 |
+
networks:
|
| 12 |
+
- shared_network
|
| 13 |
+
|
| 14 |
+
networks:
|
| 15 |
+
shared_network:
|
| 16 |
+
external: true
|
api/main.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.auth import router as auth_router
|
| 2 |
+
from app.feedback import router as feedback_router
|
| 3 |
+
from app.model import router as model_router
|
| 4 |
+
from app.user import router as user_router
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
|
| 7 |
+
app = FastAPI(title="Image Prediction API", version="0.0.1")
|
| 8 |
+
|
| 9 |
+
app.include_router(auth_router.router)
|
| 10 |
+
app.include_router(model_router.router)
|
| 11 |
+
app.include_router(user_router.router)
|
| 12 |
+
app.include_router(feedback_router.router)
|
api/populate_db.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import psycopg2
|
| 2 |
+
from app import settings as config
|
| 3 |
+
from app.db import Base
|
| 4 |
+
from app.user.models import User
|
| 5 |
+
from psycopg2.errors import DuplicateDatabase
|
| 6 |
+
from sqlalchemy import create_engine
|
| 7 |
+
from sqlalchemy.orm import sessionmaker
|
| 8 |
+
|
| 9 |
+
# Database configuration
|
| 10 |
+
DATABASE_USERNAME = config.DATABASE_USERNAME
|
| 11 |
+
DATABASE_PASSWORD = config.DATABASE_PASSWORD
|
| 12 |
+
DATABASE_HOST = config.DATABASE_HOST
|
| 13 |
+
DATABASE_NAME = config.DATABASE_NAME
|
| 14 |
+
|
| 15 |
+
# Create the initial connection URL to PostgreSQL (without specifying the database)
|
| 16 |
+
initial_connection_url = (
|
| 17 |
+
f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}@{DATABASE_HOST}/postgres"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
print(initial_connection_url)
|
| 21 |
+
|
| 22 |
+
conn = None
|
| 23 |
+
|
| 24 |
+
# Connect to PostgreSQL to create the database if it doesn't exist
|
| 25 |
+
try:
|
| 26 |
+
conn = psycopg2.connect(initial_connection_url)
|
| 27 |
+
conn.autocommit = True
|
| 28 |
+
cursor = conn.cursor()
|
| 29 |
+
|
| 30 |
+
# Create the database
|
| 31 |
+
cursor.execute(f"CREATE DATABASE {DATABASE_NAME}")
|
| 32 |
+
print(f"Database '{DATABASE_NAME}' created successfully")
|
| 33 |
+
|
| 34 |
+
except DuplicateDatabase as e:
|
| 35 |
+
if "already exists" in str(e):
|
| 36 |
+
print(f"Database '{DATABASE_NAME}' already exists.")
|
| 37 |
+
else:
|
| 38 |
+
print(f"Error creating database: {e}")
|
| 39 |
+
finally:
|
| 40 |
+
if conn:
|
| 41 |
+
cursor.close()
|
| 42 |
+
conn.close()
|
| 43 |
+
|
| 44 |
+
# Database connection URL to the newly created database
|
| 45 |
+
SQLALCHEMY_DATABASE_URL = f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}@{DATABASE_HOST}/{DATABASE_NAME}"
|
| 46 |
+
print(SQLALCHEMY_DATABASE_URL)
|
| 47 |
+
|
| 48 |
+
# Create engine
|
| 49 |
+
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
| 50 |
+
|
| 51 |
+
# Drop all tables if they exist
|
| 52 |
+
Base.metadata.drop_all(engine)
|
| 53 |
+
print("Tables dropped")
|
| 54 |
+
|
| 55 |
+
# Create all tables
|
| 56 |
+
Base.metadata.create_all(engine)
|
| 57 |
+
print("Tables created")
|
| 58 |
+
|
| 59 |
+
# Populate database with a default user
|
| 60 |
+
print("Populating database with default user")
|
| 61 |
+
Session = sessionmaker(bind=engine)
|
| 62 |
+
session = Session()
|
| 63 |
+
|
| 64 |
+
user = User(
|
| 65 |
+
name="Admin User",
|
| 66 |
+
password="admin",
|
| 67 |
+
email="admin@example.com",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
session.add(user)
|
| 71 |
+
session.commit()
|
| 72 |
+
print("Default user added")
|
api/requirements.txt
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gunicorn==20.1.0
|
| 2 |
+
redis==4.1.4
|
| 3 |
+
werkzeug==2.0.3
|
| 4 |
+
alembic==1.6.5
|
| 5 |
+
psycopg2-binary==2.9.1
|
| 6 |
+
amqp==5.1.1
|
| 7 |
+
anyio==3.6.2
|
| 8 |
+
argon2-cffi==21.3.0
|
| 9 |
+
argon2-cffi-bindings==21.2.0
|
| 10 |
+
black==22.12.0
|
| 11 |
+
boto3==1.21.32
|
| 12 |
+
botocore==1.24.46
|
| 13 |
+
certifi==2022.12.7
|
| 14 |
+
cffi==1.15.1
|
| 15 |
+
click==8.1.3
|
| 16 |
+
click-didyoumean==0.3.0
|
| 17 |
+
click-plugins==1.1.1
|
| 18 |
+
click-repl==0.2.0
|
| 19 |
+
cryptography==38.0.4
|
| 20 |
+
Deprecated==1.2.13
|
| 21 |
+
dnspython==2.2.1
|
| 22 |
+
ecdsa==0.18.0
|
| 23 |
+
email-validator==1.3.0
|
| 24 |
+
exceptiongroup==1.1.0
|
| 25 |
+
Faker==15.3.4
|
| 26 |
+
fastapi==0.88.0
|
| 27 |
+
h11==0.14.0
|
| 28 |
+
httpcore==0.16.3
|
| 29 |
+
httptools==0.5.0
|
| 30 |
+
httpx==0.23.1
|
| 31 |
+
idna==3.4
|
| 32 |
+
iniconfig==1.1.1
|
| 33 |
+
itsdangerous==2.1.2
|
| 34 |
+
Jinja2==3.1.2
|
| 35 |
+
jmespath==1.0.1
|
| 36 |
+
kombu==5.2.4
|
| 37 |
+
Mako==1.2.4
|
| 38 |
+
MarkupSafe==2.1.1
|
| 39 |
+
mypy-extensions==0.4.3
|
| 40 |
+
orjson==3.8.3
|
| 41 |
+
packaging==22.0
|
| 42 |
+
passlib==1.7.4
|
| 43 |
+
pathspec==0.10.3
|
| 44 |
+
platformdirs==2.6.0
|
| 45 |
+
pluggy==1.0.0
|
| 46 |
+
prompt-toolkit==3.0.36
|
| 47 |
+
pyasn1==0.4.8
|
| 48 |
+
pycparser==2.21
|
| 49 |
+
pydantic==1.10.2
|
| 50 |
+
pytest==7.2.0
|
| 51 |
+
pytest-asyncio==0.20.3
|
| 52 |
+
pytest-mock==3.10.0
|
| 53 |
+
python-dateutil==2.8.2
|
| 54 |
+
python-dotenv==0.21.0
|
| 55 |
+
python-editor==1.0.4
|
| 56 |
+
python-jose==3.3.0
|
| 57 |
+
python-multipart==0.0.5
|
| 58 |
+
pytz==2022.7
|
| 59 |
+
PyYAML==6.0
|
| 60 |
+
rfc3986==1.5.0
|
| 61 |
+
rsa==4.9
|
| 62 |
+
s3transfer==0.5.2
|
| 63 |
+
six==1.16.0
|
| 64 |
+
sniffio==1.3.0
|
| 65 |
+
SQLAlchemy==1.3.24
|
| 66 |
+
starlette==0.22.0
|
| 67 |
+
tomli==2.0.1
|
| 68 |
+
typing_extensions==4.4.0
|
| 69 |
+
ujson==5.6.0
|
| 70 |
+
urllib3==1.26.13
|
| 71 |
+
uvicorn==0.20.0
|
| 72 |
+
uvloop==0.17.0
|
| 73 |
+
vine==5.0.0
|
| 74 |
+
watchfiles==0.18.1
|
| 75 |
+
wcwidth==0.2.5
|
| 76 |
+
websockets==10.4
|
| 77 |
+
wrapt==1.14.1
|
api/tests/__init__.py
ADDED
|
File without changes
|
api/tests/dog.jpeg
ADDED
|
Git LFS Details
|
api/tests/test_router_feedback.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest import mock
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from app import db
|
| 5 |
+
from app.auth.jwt import get_current_user
|
| 6 |
+
from app.feedback.schema import DisplayFeedback, Feedback
|
| 7 |
+
from app.user.schema import User
|
| 8 |
+
from fastapi.testclient import TestClient
|
| 9 |
+
from main import app
|
| 10 |
+
from sqlalchemy.orm import Session
|
| 11 |
+
|
| 12 |
+
client = TestClient(app)
|
| 13 |
+
|
| 14 |
+
sample_user = User(
|
| 15 |
+
id=1,
|
| 16 |
+
username="testuser",
|
| 17 |
+
email="testuser@example.com",
|
| 18 |
+
name="Test User",
|
| 19 |
+
password="password",
|
| 20 |
+
)
|
| 21 |
+
sample_feedback = Feedback(
|
| 22 |
+
feedback="Great service!",
|
| 23 |
+
image_file_name="testimage.jpg",
|
| 24 |
+
predicted_class="dog",
|
| 25 |
+
score=0.95,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
|
| 30 |
+
def mock_db_session():
|
| 31 |
+
return mock.create_autospec(Session, instance=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@pytest.fixture
|
| 35 |
+
def mock_get_current_user():
|
| 36 |
+
return sample_user
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@mock.patch("app.feedback.router.services.new_feedback")
|
| 40 |
+
def test_create_feedback(mock_new_feedback, mock_db_session, mock_get_current_user):
|
| 41 |
+
mock_new_feedback.return_value = sample_feedback
|
| 42 |
+
|
| 43 |
+
payload = {
|
| 44 |
+
"feedback": "Great service!",
|
| 45 |
+
"image_file_name": "testimage.jpg",
|
| 46 |
+
"predicted_class": "dog",
|
| 47 |
+
"score": 0.95,
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
app.dependency_overrides[db.get_db] = lambda: mock_db_session
|
| 51 |
+
app.dependency_overrides[get_current_user] = lambda: mock_get_current_user
|
| 52 |
+
|
| 53 |
+
response = client.post(
|
| 54 |
+
"/feedback/",
|
| 55 |
+
json=payload,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
assert response.status_code == 201
|
| 59 |
+
|
| 60 |
+
mock_new_feedback.assert_called_once_with(payload, sample_user, mock_db_session)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@mock.patch("app.feedback.router.services.all_feedback")
|
| 64 |
+
def test_get_all_feedback(mock_all_feedback, mock_db_session, mock_get_current_user):
|
| 65 |
+
# Setup the mock service to return a list of feedback
|
| 66 |
+
mock_all_feedback.return_value = [
|
| 67 |
+
DisplayFeedback(
|
| 68 |
+
id=1,
|
| 69 |
+
feedback="Great service!",
|
| 70 |
+
score=0.95,
|
| 71 |
+
predicted_class="dog",
|
| 72 |
+
image_file_name="testimage.jpg",
|
| 73 |
+
)
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
app.dependency_overrides[db.get_db] = lambda: mock_db_session
|
| 77 |
+
app.dependency_overrides[get_current_user] = lambda: mock_get_current_user
|
| 78 |
+
response = client.get(
|
| 79 |
+
"/feedback/",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
assert response.status_code == 200
|
| 83 |
+
|
| 84 |
+
mock_all_feedback.assert_called_once_with(mock_db_session, sample_user)
|
api/tests/test_router_model.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from app.auth.jwt import get_current_user
|
| 5 |
+
|
| 6 |
+
# from app.model.schema import PredictResponse
|
| 7 |
+
from fastapi import UploadFile
|
| 8 |
+
|
| 9 |
+
# from fastapi.testclient import TestClient
|
| 10 |
+
from httpx import AsyncClient
|
| 11 |
+
from main import app
|
| 12 |
+
|
| 13 |
+
# π‘ NOTE Run tests with: pytest tests/test_router_model.py -v
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@pytest.mark.asyncio
|
| 17 |
+
async def test_predict():
|
| 18 |
+
mock_file = AsyncMock(spec=UploadFile)
|
| 19 |
+
mock_file.filename = "test_image.png"
|
| 20 |
+
mock_file.read = AsyncMock(return_value=b"fake-image-data")
|
| 21 |
+
|
| 22 |
+
mock_user = MagicMock()
|
| 23 |
+
mock_user.id = 1
|
| 24 |
+
|
| 25 |
+
mock_current_user = MagicMock()
|
| 26 |
+
mock_current_user.return_value = "testtoken"
|
| 27 |
+
|
| 28 |
+
app.dependency_overrides[get_current_user] = lambda: mock_current_user
|
| 29 |
+
|
| 30 |
+
with patch("app.model.router.utils.get_file_hash", return_value="fakehash123"):
|
| 31 |
+
with patch(
|
| 32 |
+
"app.model.router.model_predict", new_callable=AsyncMock
|
| 33 |
+
) as mock_model_predict:
|
| 34 |
+
with patch("app.model.router.os.path.exists", return_value=False):
|
| 35 |
+
mock_model_predict.return_value = ("cat", 0.95)
|
| 36 |
+
with patch("builtins.open", new_callable=MagicMock):
|
| 37 |
+
async with AsyncClient(app=app, base_url="http://test") as ac:
|
| 38 |
+
response = await ac.post(
|
| 39 |
+
"/model/predict",
|
| 40 |
+
files={
|
| 41 |
+
"file": (
|
| 42 |
+
"test_image.png",
|
| 43 |
+
mock_file.read.return_value,
|
| 44 |
+
"image/png",
|
| 45 |
+
)
|
| 46 |
+
},
|
| 47 |
+
headers={"Authorization": "Bearer testtoken"},
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
assert response.status_code == 200
|
| 51 |
+
|
| 52 |
+
response_data = response.json()
|
| 53 |
+
assert response_data["success"] is True
|
| 54 |
+
assert response_data["prediction"] == "cat"
|
| 55 |
+
assert response_data["score"] == 0.95
|
| 56 |
+
assert response_data["image_file_name"] == "fakehash123"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@pytest.mark.asyncio
|
| 60 |
+
async def test_predict_fails_bad_extension():
|
| 61 |
+
mock_file = AsyncMock(spec=UploadFile)
|
| 62 |
+
mock_file.filename = "test_image.png"
|
| 63 |
+
mock_file.read = AsyncMock(return_value=b"fake-image-data")
|
| 64 |
+
|
| 65 |
+
mock_user = MagicMock()
|
| 66 |
+
mock_user.id = 1
|
| 67 |
+
|
| 68 |
+
mock_current_user = MagicMock()
|
| 69 |
+
mock_current_user.return_value = "testtoken"
|
| 70 |
+
|
| 71 |
+
app.dependency_overrides[get_current_user] = lambda: mock_current_user
|
| 72 |
+
|
| 73 |
+
with patch("app.model.router.utils.get_file_hash", return_value="fakehash123"):
|
| 74 |
+
with patch(
|
| 75 |
+
"app.model.router.model_predict", new_callable=AsyncMock
|
| 76 |
+
) as mock_model_predict:
|
| 77 |
+
with patch("app.model.router.os.path.exists", return_value=False):
|
| 78 |
+
mock_model_predict.return_value = ("cat", 0.95)
|
| 79 |
+
with patch("builtins.open", new_callable=MagicMock):
|
| 80 |
+
async with AsyncClient(app=app, base_url="http://test") as ac:
|
| 81 |
+
response = await ac.post(
|
| 82 |
+
"/model/predict",
|
| 83 |
+
files={
|
| 84 |
+
"file": (
|
| 85 |
+
"test_image.pdf",
|
| 86 |
+
mock_file.read.return_value,
|
| 87 |
+
"image/png",
|
| 88 |
+
)
|
| 89 |
+
},
|
| 90 |
+
headers={"Authorization": "Bearer testtoken"},
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
assert response.status_code == 400
|
| 94 |
+
assert response.json() == {
|
| 95 |
+
"detail": "File type is not supported."
|
| 96 |
+
}
|
api/tests/test_router_user.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from app import db
|
| 5 |
+
from app.auth.jwt import create_access_token
|
| 6 |
+
from app.user.models import User
|
| 7 |
+
from app.user.schema import User as UserSchema
|
| 8 |
+
from httpx import AsyncClient
|
| 9 |
+
from main import app
|
| 10 |
+
from sqlalchemy.orm import Session
|
| 11 |
+
|
| 12 |
+
# π‘ NOTE Run tests with: pytest ./tests/test_router_user.py -v
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@pytest.mark.asyncio
|
| 16 |
+
async def test_all_users():
|
| 17 |
+
mock_session = MagicMock(spec=Session)
|
| 18 |
+
mock_user = User(
|
| 19 |
+
name="John Doe", email="john@yahoo.com", password="123456", kwargs={"id": 1}
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
mock_session.query(User).all.return_value = [mock_user]
|
| 23 |
+
|
| 24 |
+
app.dependency_overrides[db.get_db] = lambda: mock_session
|
| 25 |
+
|
| 26 |
+
async with AsyncClient(app=app, base_url="http://test") as ac:
|
| 27 |
+
user_access_token = create_access_token({"sub": "john@gmail.com"})
|
| 28 |
+
response = await ac.get(
|
| 29 |
+
"/user/", headers={"Authorization": f"Bearer {user_access_token}"}
|
| 30 |
+
)
|
| 31 |
+
assert response.status_code == 200
|
| 32 |
+
users = response.json()
|
| 33 |
+
assert len(users) == 1
|
| 34 |
+
assert users[0]["name"] == "John Doe"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@pytest.mark.asyncio
|
| 38 |
+
async def test_create_user_registration_success():
|
| 39 |
+
mock_session = MagicMock(spec=Session)
|
| 40 |
+
request = UserSchema(
|
| 41 |
+
id=0, name="John Doe", email="john@gmail.com", password="123456"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
mock_session.query(User).filter.return_value.first.return_value = None
|
| 45 |
+
|
| 46 |
+
app.dependency_overrides[db.get_db] = lambda: mock_session
|
| 47 |
+
|
| 48 |
+
async with AsyncClient(app=app, base_url="http://test") as ac:
|
| 49 |
+
response = await ac.post("/user/", json=request.dict())
|
| 50 |
+
|
| 51 |
+
assert response.status_code == 201
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@pytest.mark.asyncio
|
| 55 |
+
async def test_create_user_registration_fails():
|
| 56 |
+
mock_session = MagicMock(spec=Session)
|
| 57 |
+
mock_user = User(id=0, name="John Doe", email="john@gmail.com", password="123456")
|
| 58 |
+
request = UserSchema(
|
| 59 |
+
id=0, name="John Doe", email="john@gmail.com", password="123456"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
mock_session.query(User).filter.return_value.first.return_value = mock_user
|
| 63 |
+
|
| 64 |
+
app.dependency_overrides[db.get_db] = lambda: mock_session
|
| 65 |
+
|
| 66 |
+
async with AsyncClient(app=app, base_url="http://test") as ac:
|
| 67 |
+
response = await ac.post("/user/", json=request.dict())
|
| 68 |
+
|
| 69 |
+
assert response.status_code == 400
|
api/tests/test_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import os
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
|
| 4 |
+
import app.utils as utils
|
| 5 |
+
import pytest
|
| 6 |
+
from fastapi import UploadFile
|
| 7 |
+
from werkzeug.datastructures import FileStorage
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_allowed_file():
|
| 11 |
+
# π‘ NOTE Run test with: pytest ./tests/test_utils.py::test_allowed_file -v
|
| 12 |
+
assert utils.allowed_file("cat.JPG")
|
| 13 |
+
assert utils.allowed_file("cat.jpeg")
|
| 14 |
+
assert utils.allowed_file("cat.JPEG")
|
| 15 |
+
assert utils.allowed_file("../../car.PNG")
|
| 16 |
+
assert utils.allowed_file("/usr/var/src/car.gif")
|
| 17 |
+
|
| 18 |
+
assert not utils.allowed_file("cat.JPGG")
|
| 19 |
+
assert not utils.allowed_file("invoice.pdf")
|
| 20 |
+
assert not utils.allowed_file("/usr/src/slides.odt")
|
| 21 |
+
assert not utils.allowed_file("/usr/src/api")
|
| 22 |
+
assert not utils.allowed_file("/usr/src/api/")
|
| 23 |
+
assert not utils.allowed_file("/usr/src/dog.")
|
| 24 |
+
assert not utils.allowed_file("/usr/src/dog./")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.mark.asyncio
|
| 28 |
+
async def test_get_file_hash():
|
| 29 |
+
# π‘ NOTE Run test with: pytest ./tests/test_utils.py::test_get_file_hash -v
|
| 30 |
+
filename = "tests/dog.jpeg"
|
| 31 |
+
md5_filename = "0a7c757a80f2c5b13fa7a2a47a683593.jpeg"
|
| 32 |
+
with open(filename, "rb") as fp:
|
| 33 |
+
file = FileStorage(fp)
|
| 34 |
+
file = UploadFile(file=BytesIO(file.read()), filename="dog.jpeg")
|
| 35 |
+
|
| 36 |
+
new_filename = await utils.get_file_hash(file)
|
| 37 |
+
|
| 38 |
+
assert md5_filename == new_filename
|
compose.yml
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
version: "3.2"
|
| 2 |
services:
|
| 3 |
api:
|
| 4 |
image: flask_api
|
|
|
|
|
|
|
| 1 |
services:
|
| 2 |
api:
|
| 3 |
image: flask_api
|