# --- Import Libraries ---
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import re

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# --- Model Paths ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# This points to the fine-tuned LoRA adapter on the Hugging Face Hub.
LORA_ADAPTER_PATH = "enoch10jason/gemma-grammar-lora"

# --- Global variables for models ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"

| print("--- Starting Model Loading ---") | |
| try: | |
| # Models are loaded from the pre-downloaded cache in the image. | |
| # No token is needed at runtime because the files are already cached. | |
| print(f"Loading gender model from cache: {GENDER_MODEL_PATH}") | |
| gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH) | |
| gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device) | |
| print("β Gender verifier model loaded successfully!") | |
| print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}") | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL_PATH, | |
| dtype=torch.float32, | |
| ).to(device) | |
| grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH) | |
| print(f"Applying LoRA adapter from cache: {LORA_ADAPTER_PATH}") | |
| grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device) | |
| print("β Grammar correction model loaded successfully!") | |
| if grammar_tokenizer.pad_token is None: | |
| grammar_tokenizer.pad_token = grammar_tokenizer.eos_token | |
| if gender_tokenizer.pad_token is None: | |
| gender_tokenizer.pad_token = gender_tokenizer.eos_token | |
| except Exception as e: | |
| print(f"β Critical error during model loading: {e}") | |
| grammar_model = None | |
| gender_model = None | |
| print("--- Model Loading Complete ---") | |
# --- FastAPI Application Setup ---
app = FastAPI(title="Text Correction API")


class CorrectionRequest(BaseModel):
    text: str


class CorrectionResponse(BaseModel):
    original_text: str
    corrected_text: str

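# Illustrative payload shapes for the endpoints below:
#   request body: {"text": "he is a girl"}
#   response:     {"original_text": "he is a girl", "corrected_text": "he is a boy"}
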
# --- Helper Functions ---
def clean_grammar_response(text: str) -> str:
    # Keep only the text after the "Response:" marker, if one is present.
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            return parts[1].strip()
    return text.strip()


def clean_gender_response(text: str) -> str:
    # Keep the text after "Response:", then strip any echoed prompt labels.
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            text = parts[1].strip()
    text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', text, flags=re.IGNORECASE)
    return text.strip().strip('"')


def correct_gender_rules(text: str) -> str:
    # Deterministic post-pass for gender pairings the model often misses.
    corrections = {
        r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
    }
    for pattern, replacement in corrections.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
    return text

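# For example, the rule table above gives:
#   correct_gender_rules("He loves his husband.")  ->  "He loves his wife."
#   correct_gender_rules("she is a boy")           ->  "she is a girl"
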
# --- API Endpoints ---
@app.post("/correct/grammar", response_model=CorrectionResponse)  # route path assumed
async def handle_grammar_correction(request: CorrectionRequest):
    if not grammar_model or not grammar_tokenizer:
        raise HTTPException(status_code=503, detail="Grammar model is not available.")
    prompt_text = request.text
    input_text = f"Prompt: {prompt_text}\nResponse:"
    inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
    output_ids = grammar_model.generate(
        **inputs, max_new_tokens=64, do_sample=False,
        pad_token_id=grammar_tokenizer.pad_token_id,  # avoids a pad-token warning
    )
    output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    corrected = clean_grammar_response(output_text)
    return CorrectionResponse(original_text=prompt_text, corrected_text=corrected)

@app.post("/correct/gender", response_model=CorrectionResponse)  # route path assumed
async def handle_gender_correction(request: CorrectionRequest):
    if not gender_model or not gender_tokenizer:
        raise HTTPException(status_code=503, detail="Gender model is not available.")
    prompt_text = request.text
    input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{prompt_text}\nResponse:"
    inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
    # Greedy decoding; temperature=0.0 is dropped because it is ignored (and
    # warned about) when do_sample=False.
    output_ids = gender_model.generate(
        **inputs, max_new_tokens=64, do_sample=False,
        eos_token_id=gender_tokenizer.eos_token_id,
        pad_token_id=gender_tokenizer.pad_token_id,
    )
    output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    cleaned_from_model = clean_gender_response(output_text)
    final_correction = correct_gender_rules(cleaned_from_model)
    return CorrectionResponse(original_text=prompt_text, corrected_text=final_correction)

@app.get("/")
def read_root():
    return {"status": "Text Correction API is running."}