"""FastAPI service exposing grammar- and gender-correction endpoints.

Two causal LMs are loaded once at import time (from the image's local HF
cache, so no auth token is needed at runtime):

  * ``gender_model``  — a small Gemma model used to rewrite sentences with
    correct gender agreement (backed up by deterministic regex rules).
  * ``grammar_model`` — a Gemma-2B base model with a LoRA adapter applied,
    used for general grammar correction.

If loading fails, the endpoints answer 503 instead of crashing the app.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import re
import os

# --- Import Libraries ---
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# --- Model Paths ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# This correctly points to your model on the Hugging Face Hub.
LORA_ADAPTER_PATH = "enoch10jason/gemma-grammar-lora"

# Maximum number of tokens generated per correction request.
MAX_NEW_TOKENS = 64

# --- Global variables for models ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"

print("--- Starting Model Loading ---")
try:
    # Models are loaded from the pre-downloaded cache in the image.
    # No token is needed at runtime because the files are already cached.
    print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
    gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
    gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
    print("✅ Gender verifier model loaded successfully!")

    print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
    # NOTE(review): `dtype=` is the post-4.56 transformers keyword; older
    # releases call it `torch_dtype=` — confirm against the pinned version.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        dtype=torch.float32,
    ).to(device)
    grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

    print(f"Applying LoRA adapter from cache: {LORA_ADAPTER_PATH}")
    grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
    print("✅ Grammar correction model loaded successfully!")

    # Gemma tokenizers may ship without a pad token; fall back to EOS so
    # that batched/padded generation does not fail.
    if grammar_tokenizer.pad_token is None:
        grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
    if gender_tokenizer.pad_token is None:
        gender_tokenizer.pad_token = gender_tokenizer.eos_token
except Exception as e:
    # Deliberate best-effort: keep the app importable so the health route
    # works; the correction endpoints will report 503 below.
    print(f"❌ Critical error during model loading: {e}")
    grammar_model = None
    gender_model = None
print("--- Model Loading Complete ---")

# --- FastAPI Application Setup ---
app = FastAPI(title="Text Correction API")


class CorrectionRequest(BaseModel):
    # Raw input sentence to be corrected.
    text: str


class CorrectionResponse(BaseModel):
    # Echo of the input plus the model's corrected rewrite.
    original_text: str
    corrected_text: str


# --- Helper Functions ---
def _generate_completion(model, tokenizer, input_text: str) -> str:
    """Run greedy generation and return ONLY the newly generated text.

    The prompt tokens are sliced off the output ids so the decoded string
    contains just the completion, regardless of how the model echoes the
    prompt. Greedy decoding (`do_sample=False`) is used; passing
    `temperature=0.0` alongside it is invalid in transformers, so no
    temperature is supplied.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)


def clean_grammar_response(text: str) -> str:
    """Strip any echoed prompt, keeping only the part after 'Response:'."""
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            return parts[1].strip()
    return text.strip()


def clean_gender_response(text: str) -> str:
    """Strip echoed prompt text and common model-added label prefixes."""
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            text = parts[1].strip()
    text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', text, flags=re.IGNORECASE)
    return text.strip().strip('"')


def correct_gender_rules(text: str) -> str:
    """Apply deterministic regex fixes for known gender-agreement errors.

    Acts as a safety net on top of the model output: guarantees these
    specific mismatches are corrected even if the model misses them.
    """
    corrections = {
        r'\bher wife\b': 'her husband',
        r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy',
        r'\bshe is a boy\b': 'she is a girl',
    }
    for pattern, replacement in corrections.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
    return text


# --- API Endpoints ---
@app.post("/correct_grammar", response_model=CorrectionResponse)
async def handle_grammar_correction(request: CorrectionRequest):
    """Correct the grammar of `request.text` using the LoRA grammar model.

    Raises HTTP 503 when the model failed to load at startup.
    """
    if grammar_model is None or grammar_tokenizer is None:
        raise HTTPException(status_code=503, detail="Grammar model is not available.")

    prompt_text = request.text
    input_text = f"Prompt: {prompt_text}\nResponse:"
    output_text = _generate_completion(grammar_model, grammar_tokenizer, input_text)
    corrected = clean_grammar_response(output_text)
    return CorrectionResponse(original_text=prompt_text, corrected_text=corrected)


@app.post("/correct_gender", response_model=CorrectionResponse)
async def handle_gender_correction(request: CorrectionRequest):
    """Rewrite `request.text` with correct grammar and gender agreement.

    Model output is cleaned of prompt echoes/labels, then passed through
    the deterministic regex rules. Raises HTTP 503 when the model failed
    to load at startup.
    """
    if gender_model is None or gender_tokenizer is None:
        raise HTTPException(status_code=503, detail="Gender model is not available.")

    prompt_text = request.text
    input_text = (
        "Prompt: Please rewrite the sentence with correct grammar and gender. "
        f"Output ONLY the corrected sentence:\n{prompt_text}\nResponse:"
    )
    output_text = _generate_completion(gender_model, gender_tokenizer, input_text)
    cleaned_from_model = clean_gender_response(output_text)
    final_correction = correct_gender_rules(cleaned_from_model)
    return CorrectionResponse(original_text=prompt_text, corrected_text=final_correction)


@app.get("/")
def read_root():
    """Liveness probe: confirms the API process is up."""
    return {"status": "Text Correction API is running."}