# --- Import Libraries ---
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import re
import os

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# --- Model Paths ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# Points to the fine-tuned LoRA adapter on the Hugging Face Hub.
LORA_ADAPTER_PATH = "enoch10jason/gemma-grammar-lora"
# --- Global variables for models ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"
print("--- Starting Model Loading ---")
try:
    # Models are loaded from the pre-downloaded cache in the image.
    # No token is needed at runtime because the files are already cached.
    print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
    gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
    gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
    print("✅ Gender verifier model loaded successfully!")

    print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        dtype=torch.float32,
    ).to(device)
    grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

    print(f"Applying LoRA adapter from cache: {LORA_ADAPTER_PATH}")
    grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
    print("✅ Grammar correction model loaded successfully!")

    # Ensure both tokenizers have a pad token; fall back to EOS if missing.
    if grammar_tokenizer.pad_token is None:
        grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
    if gender_tokenizer.pad_token is None:
        gender_tokenizer.pad_token = gender_tokenizer.eos_token
except Exception as e:
    print(f"❌ Critical error during model loading: {e}")
    grammar_model = None
    gender_model = None
print("--- Model Loading Complete ---")
# --- FastAPI Application Setup ---
app = FastAPI(title="Text Correction API")
class CorrectionRequest(BaseModel):
    text: str

class CorrectionResponse(BaseModel):
    original_text: str
    corrected_text: str
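# Illustrative request/response shapes for both endpoints (the corrected text
# is model output, so it is shown here only as a placeholder):
#   POST body: {"text": "she go to school yesterday"}
#   Response:  {"original_text": "she go to school yesterday", "corrected_text": "..."}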
# --- Helper Functions ---
def clean_grammar_response(text: str) -> str:
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            return parts[1].strip()
    return text.strip()
def clean_gender_response(text: str) -> str:
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            text = parts[1].strip()
    text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', text, flags=re.IGNORECASE)
    return text.strip().strip('"')
def correct_gender_rules(text: str) -> str:
    corrections = {
        r'\bher wife\b': 'her husband',
        r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy',
        r'\bshe is a boy\b': 'she is a girl',
    }
    for pattern, replacement in corrections.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
    return text
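# Example of the rule-based pass above (hypothetical input, for illustration):
#   correct_gender_rules("He kissed his husband.")  ->  "He kissed his wife."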
# --- API Endpoints ---
@app.post("/correct_grammar", response_model=CorrectionResponse)
async def handle_grammar_correction(request: CorrectionRequest):
if not grammar_model or not grammar_tokenizer:
raise HTTPException(status_code=503, detail="Grammar model is not available.")
prompt_text = request.text
input_text = f"Prompt: {prompt_text}\nResponse:"
inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
corrected = clean_grammar_response(output_text)
return CorrectionResponse(original_text=prompt_text, corrected_text=corrected)
@app.post("/correct_gender", response_model=CorrectionResponse)
async def handle_gender_correction(request: CorrectionRequest):
if not gender_model or not gender_tokenizer:
raise HTTPException(status_code=503, detail="Gender model is not available.")
prompt_text = request.text
input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{prompt_text}\nResponse:"
inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
output_ids = gender_model.generate(
**inputs, max_new_tokens=64, temperature=0.0,
do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
)
output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
cleaned_from_model = clean_gender_response(output_text)
final_correction = correct_gender_rules(cleaned_from_model)
return CorrectionResponse(original_text=prompt_text, corrected_text=final_correction)
@app.get("/")
def read_root():
    return {"status": "Text Correction API is running."}