iBrokeTheCode commited on
Commit
6ab520d
Β·
1 Parent(s): 457c02a

chore: Add API service files

Browse files
.env.original ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Test variables
2
+ POSTGRES_DB=sp3
3
+ POSTGRES_USER=postgres
4
+ POSTGRES_PASSWORD=adlibitum
5
+ DATABASE_HOST=db
6
+ SECRET_KEY=S09WWWHXBAJDIUEREHCN3752346572452VGGGVWWW526194
api/.env.original ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ POSTGRES_USER=postgres
2
+ POSTGRES_PASSWORD=adlibitum
3
+ DATABASE_HOST=db
4
+ POSTGRES_DB=sp3
api/Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8.13 AS base
2
+
3
+ ENV PYTHONPATH=$PYTHONPATH:/src/
4
+
5
+ COPY ./requirements.txt /src/requirements.txt
6
+
7
+ WORKDIR /src
8
+
9
+ RUN pip install --upgrade pip && pip install -r requirements.txt
10
+
11
+ COPY ./ /src/
12
+
13
+ FROM base AS test
14
+ RUN ["python", "-m", "pytest", "-v", "/src/tests"]
15
+
16
+ FROM base AS build
17
+
18
+ CMD gunicorn -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:5000 main:app
api/Dockerfile.populate ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8.13
2
+
3
+ # 1. Copy the requirements.txt file to the image
4
+ ENV PYTHONPATH=$PYTHONPATH:/src/
5
+ COPY ./requirements.txt /src/requirements.txt
6
+
7
+ WORKDIR /src
8
+
9
+ # 2. Install the dependencies
10
+ RUN pip install --upgrade pip && pip install -r requirements.txt
11
+
12
+ # 3. Copy the content of the current directory to the image
13
+ COPY ./ /src/
14
+
15
+ # 4. Run the populate_db.py script
16
+ CMD ["python", "populate_db.py"]
api/__init__.py ADDED
File without changes
api/app/auth/__init__.py ADDED
File without changes
api/app/auth/jwt.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+
3
+ from app.settings import SECRET_KEY
4
+ from fastapi import Depends, HTTPException, status
5
+ from fastapi.security import OAuth2PasswordBearer
6
+ from jose import JWTError, jwt
7
+
8
+ from . import schema
9
+
10
+ ALGORITHM = "HS256"
11
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
12
+
13
+
14
+ def create_access_token(data: dict) -> str:
15
+ """
16
+ Generates a JWT access token with an expiration time.
17
+
18
+ This function creates a JWT (JSON Web Token) that includes the provided
19
+ data and an expiration time. The token is signed using a secret key and
20
+ a specified algorithm.
21
+
22
+ Args:
23
+ data (dict): A dictionary containing the data to be included in the token.
24
+
25
+ Returns:
26
+ str: The encoded JWT token as a string.
27
+ """
28
+ to_encode = data.copy()
29
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
30
+ to_encode.update({"exp": expire})
31
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
32
+ return encoded_jwt
33
+
34
+
35
+ def verify_token(token: str, credentials_exception):
36
+ """
37
+ Verifies the provided JWT token and extracts the user information.
38
+
39
+ This function decodes the given JWT token using the secret key and specified
40
+ algorithm. It checks for the presence of the user's email in the token's payload.
41
+ If the email is not found or the token is invalid, an exception is raised.
42
+
43
+ Args:
44
+ token (str): The JWT token to be verified.
45
+ credentials_exception: The exception to raise if the token is invalid or the email is not found.
46
+
47
+ Returns:
48
+ TokenData: An object containing the user's email extracted from the token.
49
+
50
+ Raises:
51
+ credentials_exception: If the token is invalid or does not contain an email.
52
+ """
53
+ try:
54
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
55
+ email: str = payload.get("sub")
56
+ if email is None:
57
+ raise credentials_exception
58
+ token_data = schema.TokenData(email=email)
59
+ except JWTError:
60
+ raise credentials_exception
61
+ return token_data
62
+
63
+
64
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
65
+
66
+
67
+ def get_current_user(token: str = Depends(oauth2_scheme)):
68
+ """
69
+ Retrieves the current authenticated user based on the provided JWT token.
70
+
71
+ This function extracts the JWT token from the request, verifies it, and returns
72
+ the user data associated with the token. If the token is invalid or cannot be
73
+ verified, an HTTP 401 Unauthorized exception is raised.
74
+
75
+ Args:
76
+ token (str, optional): The JWT token extracted from the request using the OAuth2
77
+ password flow. Defaults to being fetched via `Depends(oauth2_scheme)`.
78
+
79
+ Returns:
80
+ TokenData: An object containing the user's information extracted from the token.
81
+
82
+ Raises:
83
+ HTTPException: If the token is invalid or the credentials cannot be validated.
84
+ """
85
+ credentials_exception = HTTPException(
86
+ status_code=status.HTTP_401_UNAUTHORIZED,
87
+ detail="Could not validate credentials",
88
+ headers={"WWW-Authenticate": "Bearer"},
89
+ )
90
+ return verify_token(token, credentials_exception)
api/app/auth/router.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app import db
2
+ from app.user import hashing
3
+ from app.user.models import User
4
+ from fastapi import APIRouter, Depends, HTTPException, status
5
+ from fastapi.security import OAuth2PasswordRequestForm
6
+ from sqlalchemy.orm import Session
7
+
8
+ from .jwt import create_access_token
9
+
10
+ router = APIRouter(tags=["auth"])
11
+
12
+
13
+ @router.post("/login")
14
+ def login(
15
+ request: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(db.get_db)
16
+ ):
17
+ user = db.query(User).filter(User.email == request.username).first()
18
+
19
+ if not user:
20
+ raise HTTPException(
21
+ status_code=status.HTTP_404_NOT_FOUND, detail="Invalid credentials"
22
+ )
23
+ if not hashing.verify_password(request.password, user.password):
24
+ raise HTTPException(
25
+ status_code=status.HTTP_404_NOT_FOUND, detail="Incorrect password"
26
+ )
27
+
28
+ access_token = create_access_token(data={"sub": user.email})
29
+ return {"access_token": access_token, "token_type": "bearer"}
api/app/auth/schema.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class Login(BaseModel):
7
+ username: str
8
+ password: str
9
+
10
+
11
+ class Token(BaseModel):
12
+ access_token: str
13
+ token_type: str
14
+
15
+
16
+ class TokenData(BaseModel):
17
+ email: Optional[str] = None
api/app/db.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app import settings as config
2
+ from sqlalchemy import create_engine
3
+ from sqlalchemy.ext.declarative import declarative_base
4
+ from sqlalchemy.orm import sessionmaker
5
+
6
+ DATABASE_USERNAME = config.DATABASE_USERNAME
7
+ DATABASE_PASSWORD = config.DATABASE_PASSWORD
8
+ DATABASE_HOST = config.DATABASE_HOST
9
+ DATABASE_NAME = config.DATABASE_NAME
10
+
11
+ SQLALCHEMY_DATABASE_URL = f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}@{DATABASE_HOST}/{DATABASE_NAME}"
12
+
13
+ engine = create_engine(SQLALCHEMY_DATABASE_URL)
14
+
15
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
16
+
17
+ Base = declarative_base()
18
+
19
+
20
+ def get_db():
21
+ """
22
+ Provides a database session for dependency injection.
23
+
24
+ This function is used to obtain a new database session instance from the
25
+ `SessionLocal` factory. It is intended to be used with dependency injection
26
+ in FastAPI to manage database sessions.
27
+
28
+ Yields:
29
+ Session: A SQLAlchemy database session.
30
+
31
+ Notes:
32
+ The session is automatically closed after use to ensure proper resource management.
33
+ """
34
+ db = SessionLocal()
35
+ try:
36
+ yield db
37
+ finally:
38
+ db.close()
api/app/feedback/models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.db import Base
2
+ from sqlalchemy import Column, Float, ForeignKey, Integer, String
3
+ from sqlalchemy.orm import relationship
4
+
5
+
6
+ class Feedback(Base):
7
+ __tablename__ = "feedbacks"
8
+
9
+ id = Column(Integer, primary_key=True, index=True)
10
+ score = Column(Float)
11
+ predicted_class = Column(String(50))
12
+ feedback = Column(String(255))
13
+ user_id = Column(Integer, ForeignKey("users.id"))
14
+ image_file_name = Column(String(255))
15
+ user = relationship("User", back_populates="feedbacks")
16
+
17
+ def __init__(
18
+ self, score, predicted_class, feedback, image_file_name, user, *args, **kwargs
19
+ ):
20
+ self.predicted_class = predicted_class
21
+ self.feedback = feedback
22
+ self.score = score
23
+ self.image_file_name = image_file_name
24
+ self.user = user
api/app/feedback/router.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from app import db
4
+ from app.auth.jwt import get_current_user
5
+ from app.user.schema import User
6
+ from fastapi import APIRouter, Depends, status
7
+ from sqlalchemy.orm import Session
8
+
9
+ from . import schema, services
10
+
11
+ router = APIRouter(tags=["Feedback"], prefix="/feedback")
12
+
13
+
14
+ @router.post("/", status_code=status.HTTP_201_CREATED)
15
+ async def create_feedback(
16
+ request: schema.Feedback,
17
+ database: Session = Depends(db.get_db),
18
+ current_user: User = Depends(get_current_user),
19
+ ):
20
+ return await services.new_feedback(request, current_user, database)
21
+
22
+
23
+ @router.get("/", response_model=List[schema.DisplayFeedback])
24
+ async def get_all_feedback(
25
+ database: Session = Depends(db.get_db),
26
+ current_user: User = Depends(get_current_user),
27
+ ):
28
+ return await services.all_feedback(database, current_user)
api/app/feedback/schema.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
+ class Feedback(BaseModel):
5
+ score: float
6
+ predicted_class: str
7
+ image_file_name: str
8
+ feedback: str
9
+
10
+
11
+ class DisplayFeedback(BaseModel):
12
+ id: int
13
+ score: float
14
+ predicted_class: str
15
+ image_file_name: str
16
+ feedback: str
17
+
18
+ class Config:
19
+ orm_mode = True
api/app/feedback/services.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.auth.schema import TokenData
2
+ from app.user.models import User
3
+ from sqlalchemy.orm import Session
4
+
5
+ from . import models, schema
6
+
7
+
8
+ async def new_feedback(
9
+ request: schema.Feedback, current_user: TokenData, database: Session
10
+ ) -> models.Feedback:
11
+ """
12
+ Adds new feedback to the database associated with the current user.
13
+
14
+ This asynchronous function creates a new feedback entry in the database using
15
+ the provided feedback data and associates it with the current user. It first
16
+ retrieves the user from the database based on the email in the `current_user` object,
17
+ then creates and stores the new feedback entry.
18
+
19
+ Args:
20
+ request (schema.Feedback): An object containing the feedback details such as score,
21
+ image file name, predicted class, and feedback text.
22
+ current_user (TokenData): An object containing the email of the currently authenticated user.
23
+ database (Session): The database session used for querying and committing changes to the database.
24
+
25
+ Returns:
26
+ models.Feedback: The newly created feedback entry stored in the database.
27
+
28
+ Raises:
29
+ Exception: If there is an issue with adding or committing the feedback to the database.
30
+ """
31
+ user = database.query(User).filter(User.email == current_user.email).first()
32
+ new_feedback = models.Feedback(
33
+ score=request.score,
34
+ image_file_name=request.image_file_name,
35
+ predicted_class=request.predicted_class,
36
+ user=user,
37
+ feedback=request.feedback,
38
+ )
39
+ database.add(new_feedback)
40
+ database.commit()
41
+ database.refresh(new_feedback)
42
+ return new_feedback
43
+
44
+
45
+ async def all_feedback(database: Session, current_user: TokenData) -> models.Feedback:
46
+ """
47
+ Retrieves all feedback entries associated with the current user from the database.
48
+
49
+ This asynchronous function queries the database for all feedback entries linked to
50
+ the user identified by the `current_user` object. It returns a list of feedback entries
51
+ associated with the user's ID.
52
+
53
+ Args:
54
+ database (Session): The database session used for querying the database.
55
+ current_user (TokenData): An object containing the email of the currently authenticated user.
56
+
57
+ Returns:
58
+ list[models.Feedback]: A list of feedback entries associated with the current user.
59
+
60
+ Raises:
61
+ Exception: If there is an issue with querying the feedback entries from the database.
62
+ """
63
+ user = database.query(User).filter(User.email == current_user.email).first()
64
+ return (
65
+ database.query(models.Feedback).filter(models.Feedback.user_id == user.id).all()
66
+ )
api/app/model/router.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from app import settings as config
4
+ from app import utils
5
+ from app.auth.jwt import get_current_user
6
+ from app.model.schema import PredictResponse
7
+ from app.model.services import model_predict
8
+ from fastapi import APIRouter, Depends, HTTPException, UploadFile, status # File
9
+
10
+ router = APIRouter(tags=["Model"], prefix="/model")
11
+
12
+
13
+ @router.post("/predict")
14
+ async def predict(file: UploadFile, current_user=Depends(get_current_user)):
15
+ rpse = {"success": False, "prediction": None, "score": None}
16
+
17
+ # Check a file was sent and that file is an image
18
+ if not file or not utils.allowed_file(file.filename):
19
+ raise HTTPException(
20
+ status_code=status.HTTP_400_BAD_REQUEST,
21
+ detail="File type is not supported.",
22
+ )
23
+
24
+ # Store the image to disk, calculate hash before to avoid re-writing an image already uploaded.
25
+ new_filename = await utils.get_file_hash(file)
26
+ file_path = os.path.join(config.UPLOAD_FOLDER, new_filename)
27
+
28
+ if not os.path.exists(file_path):
29
+ with open(file_path, "wb") as out_file:
30
+ content = await file.read()
31
+ out_file.write(content)
32
+
33
+ # Reset file pointer to the beginning
34
+ await file.seek(0)
35
+
36
+ # Send the file to be processed by the model service
37
+ prediction, score = await model_predict(file_path)
38
+
39
+ # Update and return rpse dict with the corresponding values
40
+ rpse["success"] = True
41
+ rpse["prediction"] = prediction
42
+ rpse["score"] = score
43
+ rpse["image_file_name"] = new_filename
44
+
45
+ return PredictResponse(**rpse)
api/app/model/schema.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
+ class PredictRequest(BaseModel):
5
+ file: str
6
+
7
+
8
+ class PredictResponse(BaseModel):
9
+ success: bool
10
+ prediction: str
11
+ score: float
12
+ image_file_name: str
api/app/model/services.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ from uuid import uuid4
4
+
5
+ import redis
6
+
7
+ from .. import settings
8
+
9
+ # Connect to Redis
10
+ db = redis.Redis(
11
+ host=settings.REDIS_IP, port=settings.REDIS_PORT, db=settings.REDIS_DB_ID
12
+ )
13
+
14
+
15
+ async def model_predict(image_name):
16
+ print(f"Processing image {image_name}...")
17
+ """
18
+ Receives an image name and queues the job into Redis.
19
+ Will loop until getting the answer from our ML service.
20
+
21
+ Parameters
22
+ ----------
23
+ image_name : str
24
+ Name for the image uploaded by the user.
25
+
26
+ Returns
27
+ -------
28
+ prediction, score : tuple(str, float)
29
+ Model predicted class as a string and the corresponding confidence
30
+ score as a number.
31
+ """
32
+ prediction = None
33
+ score = None
34
+
35
+ # Assign an unique ID for this job and add it to the queue.
36
+ job_id = str(uuid4())
37
+
38
+ # Create a dict with the job data we will send through Redis
39
+ job_data = {"id": job_id, "image_name": image_name}
40
+
41
+ # Send the job to the model service using Redis
42
+ db.lpush(settings.REDIS_QUEUE, json.dumps(job_data))
43
+
44
+ # Loop until we received the response from our ML model
45
+ while True:
46
+ # Attempt to get model predictions using job_id
47
+ output = db.get(job_id)
48
+
49
+ # Check if the text was correctly processed by the ML model
50
+ if output is not None:
51
+ output = json.loads(output.decode("utf-8"))
52
+ prediction = output["prediction"]
53
+ score = output["score"]
54
+
55
+ db.delete(job_id)
56
+ break
57
+
58
+ # Sleep some time waiting for model results
59
+ time.sleep(settings.API_SLEEP)
60
+
61
+ return prediction, score
api/app/settings.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import dotenv
4
+
5
+ # Load environment variables from .env file
6
+ dotenv.load_dotenv()
7
+
8
+ # Run API in Debug mode
9
+ API_DEBUG = True
10
+
11
+ # We will store images uploaded by the user on this folder
12
+ UPLOAD_FOLDER = "uploads/"
13
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
14
+
15
+ # REDIS settings
16
+
17
+ # Queue name
18
+ REDIS_QUEUE = "service_queue"
19
+ # Port
20
+ REDIS_PORT = 6379
21
+ # DB Id
22
+ REDIS_DB_ID = 0
23
+ # Host IP
24
+ REDIS_IP = os.getenv("REDIS_IP", "redis")
25
+ # Sleep parameters which manages the
26
+ # interval between requests to our redis queue
27
+ API_SLEEP = 0.05
28
+
29
+ # Database settings
30
+ DATABASE_USERNAME = os.getenv("POSTGRES_USER")
31
+ DATABASE_PASSWORD = os.getenv("POSTGRES_PASSWORD")
32
+ DATABASE_HOST = os.getenv("DATABASE_HOST")
33
+ DATABASE_NAME = os.getenv("POSTGRES_DB")
34
+ SECRET_KEY = os.getenv("SECRET_KEY", "S09WWWHXBAJDIUEREHCN3752346572452VGGGVWWW526194")
api/app/user/__init__.py ADDED
File without changes
api/app/user/hashing.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from passlib.context import CryptContext
2
+
3
+ pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
4
+
5
+
6
+ def verify_password(plain_password, hashed_password):
7
+ """
8
+ Verifies if the provided plain password matches the hashed password.
9
+
10
+ This function compares a plain password with a hashed password to check if they
11
+ match using the Argon2 hashing algorithm. It returns `True` if the passwords match,
12
+ otherwise returns `False`.
13
+
14
+ Args:
15
+ plain_password (str): The plain text password to be verified.
16
+ hashed_password (str): The hashed password to compare against.
17
+
18
+ Returns:
19
+ bool: `True` if the plain password matches the hashed password, otherwise `False`.
20
+ """
21
+ return pwd_context.verify(plain_password, hashed_password)
22
+
23
+
24
+ def get_password_hash(password):
25
+ """
26
+ Hashes the provided password using the Argon2 algorithm.
27
+
28
+ This function generates a hashed version of the given plain text password using
29
+ the Argon2 hashing algorithm. The resulting hash can be used for secure password storage.
30
+
31
+ Args:
32
+ password (str): The plain text password to be hashed.
33
+
34
+ Returns:
35
+ str: The hashed password.
36
+ """
37
+ return pwd_context.hash(password)
api/app/user/models.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.db import Base
2
+ from sqlalchemy import Column, Integer, String
3
+ from sqlalchemy.orm import relationship
4
+
5
+ from . import hashing
6
+
7
+
8
+ class User(Base):
9
+ __tablename__ = "users"
10
+
11
+ id = Column(Integer, primary_key=True, index=True)
12
+ name = Column(String(50))
13
+ email = Column(String(255), unique=True)
14
+ password = Column(String(255))
15
+ feedbacks = relationship("Feedback", back_populates="user")
16
+
17
+ def __init__(self, name, email, password, *args, **kwargs):
18
+ self.name = name
19
+ self.email = email
20
+ self.password = hashing.get_password_hash(password)
21
+
22
+ def check_password(self, password):
23
+ return hashing.verify_password(self.password, password)
api/app/user/router.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from typing import List
2
+ from app import db
3
+ from app.auth.jwt import get_current_user
4
+ from fastapi import APIRouter, Depends, HTTPException, status
5
+ from sqlalchemy.orm import Session
6
+
7
+ from . import schema, services, validator
8
+
9
+ router = APIRouter(tags=["Users"], prefix="/user")
10
+
11
+
12
+ @router.post("/", status_code=status.HTTP_201_CREATED)
13
+ async def create_user_registration(
14
+ request: schema.User, database: Session = Depends(db.get_db)
15
+ ):
16
+ # Verify the user email doesn't already exist
17
+ if await validator.verify_email_exist(email=request.email, database=database):
18
+ raise HTTPException(
19
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
20
+ )
21
+
22
+ # If the email doesn't exist, create a new user
23
+ new_user = await services.new_user_register(request=request, database=database)
24
+
25
+ return new_user
26
+
27
+
28
+ @router.get("/")
29
+ async def get_all_users(
30
+ database: Session = Depends(db.get_db),
31
+ current_user: schema.User = Depends(get_current_user),
32
+ ):
33
+ return await services.all_users(database)
34
+
35
+
36
+ @router.get("/{id}", response_model=schema.DisplayUser)
37
+ async def get_user_by_id(
38
+ id: int,
39
+ database: Session = Depends(db.get_db),
40
+ current_user: schema.User = Depends(get_current_user),
41
+ ):
42
+ return await services.get_user_by_id(id, database)
43
+
44
+
45
+ @router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
46
+ async def delete_user_by_id(
47
+ id: int,
48
+ database: Session = Depends(db.get_db),
49
+ current_user: schema.User = Depends(get_current_user),
50
+ ):
51
+ return await services.delete_user_by_id(id, database)
api/app/user/schema.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr, constr
2
+
3
+
4
+ class User(BaseModel):
5
+ name: constr(min_length=2, max_length=50)
6
+ email: EmailStr
7
+ password: str
8
+
9
+
10
+ class DisplayUser(BaseModel):
11
+ id: int
12
+ name: str
13
+ email: str
14
+
15
+ class Config:
16
+ orm_mode = True
api/app/user/services.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException, status
2
+ from sqlalchemy.orm import Session
3
+
4
+ from . import models, schema
5
+
6
+
7
+ async def new_user_register(request: schema.User, database: Session) -> models.User:
8
+ """
9
+ Registers a new user in the database.
10
+
11
+ This asynchronous function creates a new user entry in the database using the provided
12
+ user details from the request. It adds the user to the database, commits the changes,
13
+ and returns the newly created user.
14
+
15
+ Args:
16
+ request (schema.User): An object containing user details such as name, email, and password.
17
+ database (Session): The database session used for adding and committing the user to the database.
18
+
19
+ Returns:
20
+ models.User: The newly created user entry stored in the database.
21
+ """
22
+ new_user = models.User(
23
+ name=request.name, email=request.email, password=request.password
24
+ )
25
+ database.add(new_user)
26
+ database.commit()
27
+ database.refresh(new_user)
28
+ return new_user
29
+
30
+
31
+ async def all_users(database: Session) -> models.User:
32
+ """
33
+ Retrieves all users from the database.
34
+
35
+ This asynchronous function queries the database to retrieve a list of all users.
36
+
37
+ Args:
38
+ database (Session): The database session used for querying the database.
39
+
40
+ Returns:
41
+ list[models.User]: A list of all user entries in the database.
42
+ """
43
+ return database.query(models.User).all()
44
+
45
+
46
+ async def get_user_by_id(id: int, database: Session) -> models.User:
47
+ """
48
+ Retrieves a user from the database by their ID.
49
+
50
+ This asynchronous function queries the database for a user with the specified ID.
51
+ If the user is not found, an HTTP 404 Not Found exception is raised.
52
+
53
+ Args:
54
+ id (int): The ID of the user to retrieve.
55
+ database (Session): The database session used for querying the database.
56
+
57
+ Returns:
58
+ models.User: The user entry with the specified ID.
59
+
60
+ Raises:
61
+ HTTPException: If the user with the specified ID is not found.
62
+ """
63
+ user = database.query(models.User).filter(models.User.id == id).first()
64
+ if not user:
65
+ raise HTTPException(
66
+ status_code=status.HTTP_404_NOT_FOUND,
67
+ detail=f"User with the id {id} is not available",
68
+ )
69
+ return user
70
+
71
+
72
+ async def delete_user_by_id(id: int, database: Session):
73
+ """
74
+ Deletes a user from the database by their ID.
75
+
76
+ This asynchronous function removes the user with the specified ID from the database
77
+ and commits the changes.
78
+
79
+ Args:
80
+ id (int): The ID of the user to delete.
81
+ database (Session): The database session used for querying and committing changes to the database.
82
+ """
83
+ database.query(models.User).filter(models.User.id == id).delete()
84
+ database.commit()
api/app/user/validator.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from sqlalchemy.orm import Session
4
+
5
+ from .models import User
6
+
7
+
8
+ async def verify_email_exist(email: str, database: Session) -> Optional[User]:
9
+ """
10
+ Checks if a user with the specified email exists in the database.
11
+
12
+ This asynchronous function queries the database to find a user with the given email.
13
+ It returns the user object if found, or `None` if no user with that email exists.
14
+
15
+ Args:
16
+ email (str): The email address to check for existence.
17
+ database (Session): The database session used for querying the database.
18
+
19
+ Returns:
20
+ Optional[User]: The user object if a user with the specified email exists, otherwise `None`.
21
+ """
22
+ return database.query(User).filter(User.email == email).first()
api/app/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+
4
+
5
+ def allowed_file(filename):
6
+ """
7
+ Checks if the format for the file received is acceptable. For this
8
+ particular case, we must accept only image files. This is, files with
9
+ extension ".png", ".jpg", ".jpeg" or ".gif".
10
+
11
+ Parameters
12
+ ----------
13
+ filename : str
14
+ Filename from werkzeug.datastructures.FileStorage file.
15
+
16
+ Returns
17
+ -------
18
+ bool
19
+ True if the file is an image, False otherwise.
20
+ """
21
+ # Check if the file extension of the filename received is in the set of allowed extensions (".png", ".jpg", ".jpeg", ".gif")
22
+ allowed_extensions = [".png", ".jpg", ".jpeg", ".gif"]
23
+
24
+ return (
25
+ os.path.splitext(filename)[1].lower() in allowed_extensions
26
+ ) # Alternatively use Path.suffix
27
+
28
+
29
+ async def get_file_hash(file):
30
+ """
31
+ Returns a new filename based on the file content using MD5 hashing.
32
+ It uses hashlib.md5() function from Python standard library to get
33
+ the hash.
34
+
35
+ Parameters
36
+ ----------
37
+ file : werkzeug.datastructures.FileStorage
38
+ File sent by user.
39
+
40
+ Returns
41
+ -------
42
+ str
43
+ New filename based in md5 file hash.
44
+ """
45
+ # Read file content (byte string)
46
+ content = await file.read()
47
+
48
+ # Generate MD5 hash from content
49
+ file_hash = hashlib.md5( # Apply MD5 hash
50
+ content
51
+ ).hexdigest() # hexdigest converts the result into a readable hex string
52
+
53
+ # Get the file extension
54
+ _, ext = os.path.splitext(file.filename)
55
+
56
+ # Reset file pointer to the beginning
57
+ await file.seek(0)
58
+
59
+ return f"{file_hash}{ext}"
api/compose.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ app:
3
+ build:
4
+ context: .
5
+ dockerfile: Dockerfile.populate
6
+ environment:
7
+ POSTGRES_DB: $POSTGRES_DB
8
+ POSTGRES_USER: $POSTGRES_USER
9
+ POSTGRES_PASSWORD: $POSTGRES_PASSWORD
10
+ DATABASE_HOST: $DATABASE_HOST
11
+ networks:
12
+ - shared_network
13
+
14
+ networks:
15
+ shared_network:
16
+ external: true
api/main.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.auth import router as auth_router
2
+ from app.feedback import router as feedback_router
3
+ from app.model import router as model_router
4
+ from app.user import router as user_router
5
+ from fastapi import FastAPI
6
+
7
+ app = FastAPI(title="Image Prediction API", version="0.0.1")
8
+
9
+ app.include_router(auth_router.router)
10
+ app.include_router(model_router.router)
11
+ app.include_router(user_router.router)
12
+ app.include_router(feedback_router.router)
api/populate_db.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg2
2
+ from app import settings as config
3
+ from app.db import Base
4
+ from app.user.models import User
5
+ from psycopg2.errors import DuplicateDatabase
6
+ from sqlalchemy import create_engine
7
+ from sqlalchemy.orm import sessionmaker
8
+
9
+ # Database configuration
10
+ DATABASE_USERNAME = config.DATABASE_USERNAME
11
+ DATABASE_PASSWORD = config.DATABASE_PASSWORD
12
+ DATABASE_HOST = config.DATABASE_HOST
13
+ DATABASE_NAME = config.DATABASE_NAME
14
+
15
+ # Create the initial connection URL to PostgreSQL (without specifying the database)
16
+ initial_connection_url = (
17
+ f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}@{DATABASE_HOST}/postgres"
18
+ )
19
+
20
+ print(initial_connection_url)
21
+
22
+ conn = None
23
+
24
+ # Connect to PostgreSQL to create the database if it doesn't exist
25
+ try:
26
+ conn = psycopg2.connect(initial_connection_url)
27
+ conn.autocommit = True
28
+ cursor = conn.cursor()
29
+
30
+ # Create the database
31
+ cursor.execute(f"CREATE DATABASE {DATABASE_NAME}")
32
+ print(f"Database '{DATABASE_NAME}' created successfully")
33
+
34
+ except DuplicateDatabase as e:
35
+ if "already exists" in str(e):
36
+ print(f"Database '{DATABASE_NAME}' already exists.")
37
+ else:
38
+ print(f"Error creating database: {e}")
39
+ finally:
40
+ if conn:
41
+ cursor.close()
42
+ conn.close()
43
+
44
+ # Database connection URL to the newly created database
45
+ SQLALCHEMY_DATABASE_URL = f"postgresql://{DATABASE_USERNAME}:{DATABASE_PASSWORD}@{DATABASE_HOST}/{DATABASE_NAME}"
46
+ print(SQLALCHEMY_DATABASE_URL)
47
+
48
+ # Create engine
49
+ engine = create_engine(SQLALCHEMY_DATABASE_URL)
50
+
51
+ # Drop all tables if they exist
52
+ Base.metadata.drop_all(engine)
53
+ print("Tables dropped")
54
+
55
+ # Create all tables
56
+ Base.metadata.create_all(engine)
57
+ print("Tables created")
58
+
59
+ # Populate database with a default user
60
+ print("Populating database with default user")
61
+ Session = sessionmaker(bind=engine)
62
+ session = Session()
63
+
64
+ user = User(
65
+ name="Admin User",
66
+ password="admin",
67
+ email="admin@example.com",
68
+ )
69
+
70
+ session.add(user)
71
+ session.commit()
72
+ print("Default user added")
api/requirements.txt ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gunicorn==20.1.0
2
+ redis==4.1.4
3
+ werkzeug==2.0.3
4
+ alembic==1.6.5
5
+ psycopg2-binary==2.9.1
6
+ amqp==5.1.1
7
+ anyio==3.6.2
8
+ argon2-cffi==21.3.0
9
+ argon2-cffi-bindings==21.2.0
10
+ black==22.12.0
11
+ boto3==1.21.32
12
+ botocore==1.24.46
13
+ certifi==2022.12.7
14
+ cffi==1.15.1
15
+ click==8.1.3
16
+ click-didyoumean==0.3.0
17
+ click-plugins==1.1.1
18
+ click-repl==0.2.0
19
+ cryptography==38.0.4
20
+ Deprecated==1.2.13
21
+ dnspython==2.2.1
22
+ ecdsa==0.18.0
23
+ email-validator==1.3.0
24
+ exceptiongroup==1.1.0
25
+ Faker==15.3.4
26
+ fastapi==0.88.0
27
+ h11==0.14.0
28
+ httpcore==0.16.3
29
+ httptools==0.5.0
30
+ httpx==0.23.1
31
+ idna==3.4
32
+ iniconfig==1.1.1
33
+ itsdangerous==2.1.2
34
+ Jinja2==3.1.2
35
+ jmespath==1.0.1
36
+ kombu==5.2.4
37
+ Mako==1.2.4
38
+ MarkupSafe==2.1.1
39
+ mypy-extensions==0.4.3
40
+ orjson==3.8.3
41
+ packaging==22.0
42
+ passlib==1.7.4
43
+ pathspec==0.10.3
44
+ platformdirs==2.6.0
45
+ pluggy==1.0.0
46
+ prompt-toolkit==3.0.36
47
+ pyasn1==0.4.8
48
+ pycparser==2.21
49
+ pydantic==1.10.2
50
+ pytest==7.2.0
51
+ pytest-asyncio==0.20.3
52
+ pytest-mock==3.10.0
53
+ python-dateutil==2.8.2
54
+ python-dotenv==0.21.0
55
+ python-editor==1.0.4
56
+ python-jose==3.3.0
57
+ python-multipart==0.0.5
58
+ pytz==2022.7
59
+ PyYAML==6.0
60
+ rfc3986==1.5.0
61
+ rsa==4.9
62
+ s3transfer==0.5.2
63
+ six==1.16.0
64
+ sniffio==1.3.0
65
+ SQLAlchemy==1.3.24
66
+ starlette==0.22.0
67
+ tomli==2.0.1
68
+ typing_extensions==4.4.0
69
+ ujson==5.6.0
70
+ urllib3==1.26.13
71
+ uvicorn==0.20.0
72
+ uvloop==0.17.0
73
+ vine==5.0.0
74
+ watchfiles==0.18.1
75
+ wcwidth==0.2.5
76
+ websockets==10.4
77
+ wrapt==1.14.1
api/tests/__init__.py ADDED
File without changes
api/tests/dog.jpeg ADDED

Git LFS Details

  • SHA256: 8fb0318ed02f8a28133664be7faa884475cbb961c8f1ee1ff46410e34e38b8b3
  • Pointer size: 130 Bytes
  • Size of remote file: 34.6 kB
api/tests/test_router_feedback.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest import mock
2
+
3
+ import pytest
4
+ from app import db
5
+ from app.auth.jwt import get_current_user
6
+ from app.feedback.schema import DisplayFeedback, Feedback
7
+ from app.user.schema import User
8
+ from fastapi.testclient import TestClient
9
+ from main import app
10
+ from sqlalchemy.orm import Session
11
+
12
+ client = TestClient(app)
13
+
14
+ sample_user = User(
15
+ id=1,
16
+ username="testuser",
17
+ email="testuser@example.com",
18
+ name="Test User",
19
+ password="password",
20
+ )
21
+ sample_feedback = Feedback(
22
+ feedback="Great service!",
23
+ image_file_name="testimage.jpg",
24
+ predicted_class="dog",
25
+ score=0.95,
26
+ )
27
+
28
+
29
+ @pytest.fixture
30
+ def mock_db_session():
31
+ return mock.create_autospec(Session, instance=True)
32
+
33
+
34
+ @pytest.fixture
35
+ def mock_get_current_user():
36
+ return sample_user
37
+
38
+
39
+ @mock.patch("app.feedback.router.services.new_feedback")
40
+ def test_create_feedback(mock_new_feedback, mock_db_session, mock_get_current_user):
41
+ mock_new_feedback.return_value = sample_feedback
42
+
43
+ payload = {
44
+ "feedback": "Great service!",
45
+ "image_file_name": "testimage.jpg",
46
+ "predicted_class": "dog",
47
+ "score": 0.95,
48
+ }
49
+
50
+ app.dependency_overrides[db.get_db] = lambda: mock_db_session
51
+ app.dependency_overrides[get_current_user] = lambda: mock_get_current_user
52
+
53
+ response = client.post(
54
+ "/feedback/",
55
+ json=payload,
56
+ )
57
+
58
+ assert response.status_code == 201
59
+
60
+ mock_new_feedback.assert_called_once_with(payload, sample_user, mock_db_session)
61
+
62
+
63
+ @mock.patch("app.feedback.router.services.all_feedback")
64
+ def test_get_all_feedback(mock_all_feedback, mock_db_session, mock_get_current_user):
65
+ # Setup the mock service to return a list of feedback
66
+ mock_all_feedback.return_value = [
67
+ DisplayFeedback(
68
+ id=1,
69
+ feedback="Great service!",
70
+ score=0.95,
71
+ predicted_class="dog",
72
+ image_file_name="testimage.jpg",
73
+ )
74
+ ]
75
+
76
+ app.dependency_overrides[db.get_db] = lambda: mock_db_session
77
+ app.dependency_overrides[get_current_user] = lambda: mock_get_current_user
78
+ response = client.get(
79
+ "/feedback/",
80
+ )
81
+
82
+ assert response.status_code == 200
83
+
84
+ mock_all_feedback.assert_called_once_with(mock_db_session, sample_user)
api/tests/test_router_model.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import AsyncMock, MagicMock, patch
2
+
3
+ import pytest
4
+ from app.auth.jwt import get_current_user
5
+
6
+ # from app.model.schema import PredictResponse
7
+ from fastapi import UploadFile
8
+
9
+ # from fastapi.testclient import TestClient
10
+ from httpx import AsyncClient
11
+ from main import app
12
+
13
+ # πŸ’‘ NOTE Run tests with: pytest tests/test_router_model.py -v
14
+
15
+
16
+ @pytest.mark.asyncio
17
+ async def test_predict():
18
+ mock_file = AsyncMock(spec=UploadFile)
19
+ mock_file.filename = "test_image.png"
20
+ mock_file.read = AsyncMock(return_value=b"fake-image-data")
21
+
22
+ mock_user = MagicMock()
23
+ mock_user.id = 1
24
+
25
+ mock_current_user = MagicMock()
26
+ mock_current_user.return_value = "testtoken"
27
+
28
+ app.dependency_overrides[get_current_user] = lambda: mock_current_user
29
+
30
+ with patch("app.model.router.utils.get_file_hash", return_value="fakehash123"):
31
+ with patch(
32
+ "app.model.router.model_predict", new_callable=AsyncMock
33
+ ) as mock_model_predict:
34
+ with patch("app.model.router.os.path.exists", return_value=False):
35
+ mock_model_predict.return_value = ("cat", 0.95)
36
+ with patch("builtins.open", new_callable=MagicMock):
37
+ async with AsyncClient(app=app, base_url="http://test") as ac:
38
+ response = await ac.post(
39
+ "/model/predict",
40
+ files={
41
+ "file": (
42
+ "test_image.png",
43
+ mock_file.read.return_value,
44
+ "image/png",
45
+ )
46
+ },
47
+ headers={"Authorization": "Bearer testtoken"},
48
+ )
49
+
50
+ assert response.status_code == 200
51
+
52
+ response_data = response.json()
53
+ assert response_data["success"] is True
54
+ assert response_data["prediction"] == "cat"
55
+ assert response_data["score"] == 0.95
56
+ assert response_data["image_file_name"] == "fakehash123"
57
+
58
+
59
+ @pytest.mark.asyncio
60
+ async def test_predict_fails_bad_extension():
61
+ mock_file = AsyncMock(spec=UploadFile)
62
+ mock_file.filename = "test_image.png"
63
+ mock_file.read = AsyncMock(return_value=b"fake-image-data")
64
+
65
+ mock_user = MagicMock()
66
+ mock_user.id = 1
67
+
68
+ mock_current_user = MagicMock()
69
+ mock_current_user.return_value = "testtoken"
70
+
71
+ app.dependency_overrides[get_current_user] = lambda: mock_current_user
72
+
73
+ with patch("app.model.router.utils.get_file_hash", return_value="fakehash123"):
74
+ with patch(
75
+ "app.model.router.model_predict", new_callable=AsyncMock
76
+ ) as mock_model_predict:
77
+ with patch("app.model.router.os.path.exists", return_value=False):
78
+ mock_model_predict.return_value = ("cat", 0.95)
79
+ with patch("builtins.open", new_callable=MagicMock):
80
+ async with AsyncClient(app=app, base_url="http://test") as ac:
81
+ response = await ac.post(
82
+ "/model/predict",
83
+ files={
84
+ "file": (
85
+ "test_image.pdf",
86
+ mock_file.read.return_value,
87
+ "image/png",
88
+ )
89
+ },
90
+ headers={"Authorization": "Bearer testtoken"},
91
+ )
92
+
93
+ assert response.status_code == 400
94
+ assert response.json() == {
95
+ "detail": "File type is not supported."
96
+ }
api/tests/test_router_user.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import MagicMock
2
+
3
+ import pytest
4
+ from app import db
5
+ from app.auth.jwt import create_access_token
6
+ from app.user.models import User
7
+ from app.user.schema import User as UserSchema
8
+ from httpx import AsyncClient
9
+ from main import app
10
+ from sqlalchemy.orm import Session
11
+
12
+ # πŸ’‘ NOTE Run tests with: pytest ./tests/test_router_user.py -v
13
+
14
+
15
+ @pytest.mark.asyncio
16
+ async def test_all_users():
17
+ mock_session = MagicMock(spec=Session)
18
+ mock_user = User(
19
+ name="John Doe", email="john@yahoo.com", password="123456", kwargs={"id": 1}
20
+ )
21
+
22
+ mock_session.query(User).all.return_value = [mock_user]
23
+
24
+ app.dependency_overrides[db.get_db] = lambda: mock_session
25
+
26
+ async with AsyncClient(app=app, base_url="http://test") as ac:
27
+ user_access_token = create_access_token({"sub": "john@gmail.com"})
28
+ response = await ac.get(
29
+ "/user/", headers={"Authorization": f"Bearer {user_access_token}"}
30
+ )
31
+ assert response.status_code == 200
32
+ users = response.json()
33
+ assert len(users) == 1
34
+ assert users[0]["name"] == "John Doe"
35
+
36
+
37
+ @pytest.mark.asyncio
38
+ async def test_create_user_registration_success():
39
+ mock_session = MagicMock(spec=Session)
40
+ request = UserSchema(
41
+ id=0, name="John Doe", email="john@gmail.com", password="123456"
42
+ )
43
+
44
+ mock_session.query(User).filter.return_value.first.return_value = None
45
+
46
+ app.dependency_overrides[db.get_db] = lambda: mock_session
47
+
48
+ async with AsyncClient(app=app, base_url="http://test") as ac:
49
+ response = await ac.post("/user/", json=request.dict())
50
+
51
+ assert response.status_code == 201
52
+
53
+
54
+ @pytest.mark.asyncio
55
+ async def test_create_user_registration_fails():
56
+ mock_session = MagicMock(spec=Session)
57
+ mock_user = User(id=0, name="John Doe", email="john@gmail.com", password="123456")
58
+ request = UserSchema(
59
+ id=0, name="John Doe", email="john@gmail.com", password="123456"
60
+ )
61
+
62
+ mock_session.query(User).filter.return_value.first.return_value = mock_user
63
+
64
+ app.dependency_overrides[db.get_db] = lambda: mock_session
65
+
66
+ async with AsyncClient(app=app, base_url="http://test") as ac:
67
+ response = await ac.post("/user/", json=request.dict())
68
+
69
+ assert response.status_code == 400
api/tests/test_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ from io import BytesIO
3
+
4
+ import app.utils as utils
5
+ import pytest
6
+ from fastapi import UploadFile
7
+ from werkzeug.datastructures import FileStorage
8
+
9
+
10
+ def test_allowed_file():
11
+ # πŸ’‘ NOTE Run test with: pytest ./tests/test_utils.py::test_allowed_file -v
12
+ assert utils.allowed_file("cat.JPG")
13
+ assert utils.allowed_file("cat.jpeg")
14
+ assert utils.allowed_file("cat.JPEG")
15
+ assert utils.allowed_file("../../car.PNG")
16
+ assert utils.allowed_file("/usr/var/src/car.gif")
17
+
18
+ assert not utils.allowed_file("cat.JPGG")
19
+ assert not utils.allowed_file("invoice.pdf")
20
+ assert not utils.allowed_file("/usr/src/slides.odt")
21
+ assert not utils.allowed_file("/usr/src/api")
22
+ assert not utils.allowed_file("/usr/src/api/")
23
+ assert not utils.allowed_file("/usr/src/dog.")
24
+ assert not utils.allowed_file("/usr/src/dog./")
25
+
26
+
27
+ @pytest.mark.asyncio
28
+ async def test_get_file_hash():
29
+ # πŸ’‘ NOTE Run test with: pytest ./tests/test_utils.py::test_get_file_hash -v
30
+ filename = "tests/dog.jpeg"
31
+ md5_filename = "0a7c757a80f2c5b13fa7a2a47a683593.jpeg"
32
+ with open(filename, "rb") as fp:
33
+ file = FileStorage(fp)
34
+ file = UploadFile(file=BytesIO(file.read()), filename="dog.jpeg")
35
+
36
+ new_filename = await utils.get_file_hash(file)
37
+
38
+ assert md5_filename == new_filename
compose.yml CHANGED
@@ -1,4 +1,3 @@
1
- version: "3.2"
2
  services:
3
  api:
4
  image: flask_api
 
 
1
  services:
2
  api:
3
  image: flask_api