Spaces: Build error

Enoch Jason J committed · Commit 9fca407 · 0 Parent(s)

Add application file

Files changed:
- .gitignore +67 -0
- Dockerfile +35 -0
- document_pipeline.py +133 -0
- download_models.py +39 -0
- main.py +126 -0
- requirements.txt +8 -0
- requirements_local.txt +8 -0
.gitignore
ADDED
@@ -0,0 +1,67 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.pyc
*.pyd
*.pyo
*.dll

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.coverage
.tox/
htmlcov/
.pytest_cache/

# Editors
.vscode/
.idea/

# OS
.DS_Store
.Trashes
Thumbs.db

# Virtual environment
.venv/
venv/
env/

# Jupyter Notebook
.ipynb_checkpoints

# MyPy
.mypy_cache/

# PyInstaller
*.spec
build/
dist/

# Temporary files
*.tmp
*.bak
*.swp
*~
Dockerfile
ADDED
@@ -0,0 +1,35 @@
# Use a standard Python slim image for a lightweight CPU environment
FROM python:3.11-slim

# Set the working directory
WORKDIR /app

# --- Installation ---
# Copy the local requirements file and the model downloader script
COPY requirements_local.txt .
COPY download_models.py .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements_local.txt

# --- Pre-download and Cache Models during the Build Process ---
# This makes the container startup fast and reliable.
# The token is passed as a build argument and is not set as an env var in the final image.
ARG HUGGING_FACE_HUB_TOKEN
# The downloads must land in the image layer itself; a build-time cache mount on
# /root/.cache/huggingface would be discarded from the final image, leaving the
# runtime container without the models.
RUN HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN} python download_models.py

# Copy the main application code
COPY main.py .

# IMPORTANT: If your LoRA adapter is a local folder, you need to copy it in.
# For example:
# COPY ./my_local_lora_adapter /app/my_local_lora_adapter
# Then, in main.py, set LORA_ADAPTER_PATH = "/app/my_local_lora_adapter"

# Expose the port the app runs on
EXPOSE 8000

# Command to run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
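Build note: for a local build the token has to be supplied explicitly, e.g. docker build --build-arg HUGGING_FACE_HUB_TOKEN=... -t text-correction-api . (the image tag is illustrative). Be aware that a plain build arg consumed by a RUN step can be recovered from the image's build history, so for a long-lived token BuildKit's --secret mount is the more cautious channel.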
document_pipeline.py
ADDED
@@ -0,0 +1,133 @@
import requests
import re
from fpdf import FPDF
import os
import textract

# --- Configuration ---
AI_SERVICE_URL = "http://localhost:8000"
INPUT_DOC_PATH = "Doreen.doc"
OUTPUT_PDF_PATH = "Doreen DeFio_Dr. Daniel Rich_Report_Generated.pdf"

def correct_text_via_api(endpoint: str, text: str) -> str:
    try:
        response = requests.post(f"{AI_SERVICE_URL}/{endpoint}", json={"text": text})
        response.raise_for_status()
        return response.json()["corrected_text"]
    except requests.exceptions.RequestException as e:
        print(f"Error calling AI service at endpoint '{endpoint}': {e}")
        return text

def extract_text_from_doc(filepath):
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Input file not found at: {filepath}")
    try:
        text_bytes = textract.process(filepath)
        return text_bytes.decode('utf-8')
    except Exception as e:
        print(f"Error reading document with textract: {e}")
        return None

def parse_and_correct_text(raw_text):
    structured_data = {}
    current_section = None
    buffer = []
    key_value_pattern = re.compile(
        r'^\s*(Client Name|Date of Exam|Date of Accident|Examinee|Observed By|Performed By|Specialty|Facility|Facility Description|Appointment Scheduled For|Arrived at Office|Admitted to Exam Room|Intake Start|Exam Start|Exam End|Length of Exam|Total Length of Visit|Others Present|Description of IME physician|Layout of Exam Room|Did IME Physician Have Examinees Medical Records)\s*:\s*(.*)',
        re.IGNORECASE | re.DOTALL
    )
    section_headers = ["Intake:", "Exam:"]
    lines = [line.strip() for line in raw_text.split('\n') if line.strip()]

    i = 0
    while i < len(lines):
        line = lines[i]
        if line in section_headers:
            if current_section and buffer:
                full_paragraph = " ".join(buffer)
                grammar_corrected = correct_text_via_api("correct_grammar", full_paragraph)
                final_corrected = correct_text_via_api("correct_gender", grammar_corrected)
                structured_data[current_section] = final_corrected
            current_section = line.replace(":", "").strip()
            buffer = []
            i += 1
            continue
        match = key_value_pattern.match(line)
        if match:
            key, value = map(str.strip, match.groups())
            if not value and (i + 1) < len(lines) and not key_value_pattern.match(lines[i+1]) and lines[i+1] not in section_headers:
                value = lines[i+1]
                i += 1
            structured_data[key] = correct_text_via_api("correct_grammar", value)
        elif current_section:
            buffer.append(line)
        i += 1
    if current_section and buffer:
        full_paragraph = " ".join(buffer)
        grammar_corrected = correct_text_via_api("correct_grammar", full_paragraph)
        final_corrected = correct_text_via_api("correct_gender", grammar_corrected)
        structured_data[current_section] = final_corrected
    return structured_data

class PDF(FPDF):
    def header(self):
        self.set_font("DejaVu", "B", 15)
        self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
        self.ln(10)

    def footer(self):
        self.set_y(-15)
        self.set_font("DejaVu", "I", 8)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')

def generate_pdf(data, output_path):
    pdf = PDF()
    # --- FIX: Add a Unicode font that supports characters like ✅ ---
    # You may need to provide the path to the .ttf font file if not in a standard location.
    # This example assumes it can be found.
    try:
        pdf.add_font("DejaVu", "", "DejaVuSans.ttf", uni=True)
        pdf.add_font("DejaVu", "B", "DejaVuSans-Bold.ttf", uni=True)
        pdf.add_font("DejaVu", "I", "DejaVuSans-Oblique.ttf", uni=True)
    except RuntimeError:
        print("---")
        print("⚠️ FONT WARNING: DejaVuSans.ttf not found.")
        print("The PDF will be generated, but may have character issues.")
        print("Please download the DejaVu font family and place the .ttf files in this directory.")
        print("---")

    pdf.add_page()
    pdf.set_font("DejaVu", "", 12)
    key_order = [
        "Client Name", "Date of Exam", "Date of Accident", "Examinee", "Observed By",
        "Performed By", "Specialty", "Facility", "Facility Description",
        "Appointment Scheduled For", "Arrived at Office", "Admitted to Exam Room",
        "Intake Start", "Exam Start", "Exam End", "Length of Exam", "Total Length of Visit",
        "Others Present", "Description of IME physician", "Layout of Exam Room",
        "Did IME Physician Have Examinees Medical Records", "Intake", "Exam"
    ]
    for key in key_order:
        if key in data:
            value = data[key]
            pdf.set_font("DejaVu", "B", 12)
            pdf.cell(0, 10, f"{key}:", ln=True)
            pdf.set_font("DejaVu", "", 12)
            pdf.multi_cell(0, 8, str(value))
            pdf.ln(4)
    pdf.output(output_path)
    print(f"✅ Successfully generated PDF report at: {output_path}")

if __name__ == "__main__":
    print("--- Starting Document Transformation Pipeline ---")
    if os.path.exists(INPUT_DOC_PATH):
        print(f"1. Extracting text from '{INPUT_DOC_PATH}' using textract...")
        raw_document_text = extract_text_from_doc(INPUT_DOC_PATH)
        if raw_document_text:
            print("2. Parsing and correcting text via AI microservice...")
            corrected_data = parse_and_correct_text(raw_document_text)
            print(f"3. Generating PDF report '{OUTPUT_PDF_PATH}'...")
            generate_pdf(corrected_data, OUTPUT_PDF_PATH)
        print("--- Pipeline Finished ---")
    else:
        print(f"❌ ERROR: Input file not found: '{INPUT_DOC_PATH}'")
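For reference, each correct_text_via_api call above reduces to a plain HTTP POST against the FastAPI service defined in main.py. A minimal sketch of one round trip, assuming the service is already running on localhost:8000 (the sample sentence is invented):

import requests

# One correct_grammar round trip; the response schema comes from
# CorrectionResponse in main.py (original_text, corrected_text).
resp = requests.post(
    "http://localhost:8000/correct_grammar",
    json={"text": "She have arrived at the office at 9:15."},
)
resp.raise_for_status()
print(resp.json()["corrected_text"])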
download_models.py
ADDED
@@ -0,0 +1,39 @@
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# This script is run during the Docker build process to pre-download models.

GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
LORA_ADAPTER_PATH = "unsloth/gemma-2b-it-lora-test"

hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if not hf_token:
    raise ValueError("HUGGING_FACE_HUB_TOKEN environment variable is required to download models.")

print("--- Starting Model Pre-downloading ---")

# 1. Download Gender Model
print(f"Downloading: {GENDER_MODEL_PATH}")
AutoTokenizer.from_pretrained(GENDER_MODEL_PATH, token=hf_token)
AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH, token=hf_token)
print("✅ Gender model downloaded.")

# 2. Download Grammar Model (Base + Adapter)
print(f"Downloading: {BASE_MODEL_PATH}")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    token=hf_token,
    dtype=torch.float32,
)
AutoTokenizer.from_pretrained(BASE_MODEL_PATH, token=hf_token)
print("✅ Base model downloaded.")

print(f"Downloading: {LORA_ADAPTER_PATH}")
PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH, token=hf_token)
print("✅ LoRA adapter downloaded.")

print("--- Model Pre-downloading Complete ---")
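Since the point of this script is to populate the Hugging Face cache at build time, a quick way to confirm the cache actually holds the weights is to reload with local_files_only=True, which raises instead of silently re-downloading. A small sketch (not part of the build, model ID as above):

from transformers import AutoTokenizer

# Fails loudly if the build-time download did not land in the local cache.
tokenizer = AutoTokenizer.from_pretrained(
    "unsloth/gemma-2b-it",
    local_files_only=True,
)
print("Cache hit:", type(tokenizer).__name__)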
main.py
ADDED
@@ -0,0 +1,126 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import re
import os

# --- Import Libraries ---
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# --- Model Paths (will be loaded from local cache) ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
LORA_ADAPTER_PATH = "unsloth/gemma-2b-it-lora-test"

# --- Global variables for models ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"

print("--- Starting Model Loading ---")

# The token is only used during the build, not at runtime.
hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")

try:
    # Load models from the local cache inside the container. Startup is now fast.
    print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
    gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH, token=hf_token)
    gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH, token=hf_token).to(device)
    print("✅ Gender verifier model loaded successfully!")

    print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        token=hf_token,
        dtype=torch.float32,
    ).to(device)
    grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, token=hf_token)

    print(f"Applying LoRA adapter from cache: {LORA_ADAPTER_PATH}")
    grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH, token=hf_token).to(device)
    print("✅ Grammar correction model loaded successfully!")

    if grammar_tokenizer.pad_token is None:
        grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
    if gender_tokenizer.pad_token is None:
        gender_tokenizer.pad_token = gender_tokenizer.eos_token

except Exception as e:
    print(f"❌ Critical error during model loading: {e}")
    grammar_model = None
    gender_model = None

print("--- Model Loading Complete ---")


# --- FastAPI Application Setup ---
app = FastAPI(title="Text Correction API")

class CorrectionRequest(BaseModel):
    text: str

class CorrectionResponse(BaseModel):
    original_text: str
    corrected_text: str

# --- Helper Functions for Text Cleaning ---
def clean_grammar_response(text: str) -> str:
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            return parts[1].strip()
    return text.strip()

def clean_gender_response(text: str) -> str:
    if "Response:" in text:
        parts = text.split("Response:")
        if len(parts) > 1:
            text = parts[1].strip()
    text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', text, flags=re.IGNORECASE)
    return text.strip().strip('"')

def correct_gender_rules(text: str) -> str:
    corrections = {
        r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
    }
    for pattern, replacement in corrections.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
    return text

# --- API Endpoints ---

@app.post("/correct_grammar", response_model=CorrectionResponse)
async def handle_grammar_correction(request: CorrectionRequest):
    if not grammar_model or not grammar_tokenizer:
        raise HTTPException(status_code=503, detail="Grammar model is not available.")
    prompt_text = request.text
    input_text = f"Prompt: {prompt_text}\nResponse:"
    inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
    output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
    output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    corrected = clean_grammar_response(output_text)
    return CorrectionResponse(original_text=prompt_text, corrected_text=corrected)

@app.post("/correct_gender", response_model=CorrectionResponse)
async def handle_gender_correction(request: CorrectionRequest):
    if not gender_model or not gender_tokenizer:
        raise HTTPException(status_code=503, detail="Gender model is not available.")
    prompt_text = request.text
    input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{prompt_text}\nResponse:"
    inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
    output_ids = gender_model.generate(
        **inputs, max_new_tokens=64, temperature=0.0,
        do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
    )
    output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    cleaned_from_model = clean_gender_response(output_text)
    final_correction = correct_gender_rules(cleaned_from_model)
    return CorrectionResponse(original_text=prompt_text, corrected_text=final_correction)

@app.get("/")
def read_root():
    return {"status": "Text Correction API is running."}
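The endpoints can also be exercised in-process via FastAPI's test client. A minimal sketch, assuming httpx is installed (the test client depends on it, and it is in neither requirements file) and keeping in mind that importing main triggers the full model loading above:

from fastapi.testclient import TestClient

import main  # heavy import: loads both models at module import time

client = TestClient(main.app)
r = client.post("/correct_gender", json={"text": "Doreen said his knee hurts."})
print(r.status_code, r.json()["corrected_text"])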
requirements.txt
ADDED
@@ -0,0 +1,8 @@
fastapi
uvicorn[standard]
torch
transformers
peft
accelerate
pydantic
sentencepiece
requirements_local.txt
ADDED
@@ -0,0 +1,8 @@
fastapi
uvicorn[standard]
torch
transformers
peft
python-docx
fpdf2
textract
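A packaging note on the last entry: textract shells out to external tools for legacy formats (.doc extraction relies on the antiword binary being installed on the system), and its pinned dependencies have historically clashed with newer pip resolvers, so it is a plausible source of the Space's build error. A small hypothetical guard for the pipeline side:

import shutil

# textract delegates .doc parsing to antiword; check it is present before
# running the pipeline (illustrative guard, not in the original code).
if shutil.which("antiword") is None:
    raise SystemExit("antiword not found: install it before processing .doc files")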