|
|
import json
import logging
import os
import re
from difflib import SequenceMatcher
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from scripts.main import main

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG API", version="1.0.0")

# Base directory of this module, used to resolve session and chunk paths.
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
@app.post("/get_response")
async def get_response(query: str, top_k: int = 20, top_n: int = 10):
    """Run the RAG pipeline for a query and return the generated answer and its session ID."""
    response, session_id = main(query, top_k=top_k, top_n=top_n)
    return {
        "response": response,
        "session_id": session_id,
    }
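
# Example call (assumptions: the service runs on localhost:8000 and scripts.main.main
# returns a (response_text, session_id) tuple, as unpacked above):
#   curl -X POST "http://localhost:8000/get_response?query=hoc%20phi&top_k=20&top_n=10"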
|
|
|
|
|
@app.get("/api/session/{session_id}")
async def get_session_data(session_id: str):
    """Return the raw session JSON stored under scripts/sessions/<session_id>.json."""
    session_path = os.path.join(current_dir, "scripts", "sessions", f"{session_id}.json")
    if not os.path.exists(session_path):
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        with open(session_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Error reading session {session_id}: {e}")
        raise HTTPException(status_code=500, detail="Error reading session data")
|
|
|
|
|
@app.get("/api/document/{doc_id}")
async def get_document_content(doc_id: str):
    """Return the raw markdown content of a document.

    Note: the path is currently hardcoded to a single converted file and does not
    depend on doc_id.
    """
    doc_path = "converted/11. QĐ về Học phí final (25-10-2021).md"

    if not os.path.exists(doc_path):
        raise HTTPException(status_code=404, detail="Document not found")

    try:
        with open(doc_path, "r", encoding="utf-8") as f:
            content = f.read()
        return {"content": content, "doc_id": doc_id}
    except Exception as e:
        logger.error(f"Error reading document {doc_id}: {e}")
        raise HTTPException(status_code=500, detail="Error reading document")
|
|
|
|
|
|
|
|
def find_all_positions(content: str, sub: str) -> list[int]:
    """Return every start index of sub in content (overlapping matches included)."""
    positions = []
    start = 0
    while True:
        pos = content.find(sub, start)
        if pos == -1:
            break
        positions.append(pos)
        start = pos + 1
    return positions
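
# Example (overlapping matches are counted): find_all_positions("ababab", "aba") -> [0, 2]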
|
|
|
|
|
def find_best_match(text_to_find, markdown_content, threshold=0.8, best_matches=None):
    """Find the most similar span in markdown_content for text_to_find.

    Returns (best_match, text_to_find); both may have a leading overlap with a
    previously matched chunk trimmed off. The threshold parameter is currently unused.
    """
    best_match = None
    best_ratio = 0
    best_matches = best_matches or []

    # Use the first characters of the chunk as an anchor to locate candidate spans.
    start_markers = text_to_find[:50]
    if start_markers and start_markers[0] == "|":
        # Table chunks: use a shorter anchor so formatting differences matter less.
        start_markers = start_markers[:20]
    if start_markers and start_markers in markdown_content:
        start_pos_list = find_all_positions(markdown_content, start_markers)
        candidates = [markdown_content[start_pos:start_pos + len(text_to_find)] for start_pos in start_pos_list]
        ratios = [SequenceMatcher(None, text_to_find, candidate).ratio() for candidate in candidates]
        best_ratio = max(ratios)
        best_match = candidates[ratios.index(best_ratio)]

    # Trim any 50-200 character overlap the current chunk shares with a previously matched chunk.
    for prev_best_match, _ in best_matches:
        for overlap_size in range(50, 201):
            if len(prev_best_match) >= overlap_size and len(text_to_find) >= overlap_size:
                prev_end = prev_best_match[-overlap_size:]
                curr_start = text_to_find[:overlap_size]
                if prev_end == curr_start:
                    logger.info(f"Found overlap of {overlap_size} characters, slicing {overlap_size} chars from current chunk")
                    if best_match is not None:
                        best_match = best_match[overlap_size:]
                    text_to_find = text_to_find[overlap_size:]
                    break

    return best_match, text_to_find
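
# Sketch of the intended calling pattern (an assumption inferred from the signature;
# this helper is not called elsewhere in this module, and the variable names below are hypothetical):
#   matches = []
#   for chunk_text in retrieved_chunk_texts:
#       match, trimmed = find_best_match(chunk_text, markdown_content, best_matches=matches)
#       matches.append((match, trimmed))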
|
|
|
|
|
@app.get("/api/highlighted-document/{doc_id}")
async def get_highlighted_document(doc_id: str, session_id: str):
    """Return the document reconstructed from its chunks, with the session's retrieved texts highlighted."""
    session_path = os.path.join(current_dir, "scripts", "sessions", f"{session_id}.json")
    if not os.path.exists(session_path):
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        with open(session_path, "r", encoding="utf-8") as f:
            session_data = json.load(f)

        # Only highlight the retrieved texts that belong to the requested document.
        texts = [item["text"] for item in session_data if item["doc_id"] == doc_id]
        if not texts:
            raise HTTPException(status_code=404, detail="No texts found for this document")

        highlighted_content, highlighting_stats = await highlight_doc_with_chunks_new(doc_id, texts)

        return {
            "content": highlighted_content,
            "doc_id": doc_id,
            "highlighted_count": highlighting_stats["highlighted_count"],
            "total_texts": highlighting_stats["total_texts"],
            "success_rate": highlighting_stats["success_rate"],
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing highlighted document {doc_id}: {e}")
        raise HTTPException(status_code=500, detail="Error processing document")
|
|
|
|
|
def extract_sequence_from_id(chunk_id: str) -> int:
    """Extract the sequence number from a chunk ID (the trailing ::C<n> suffix)."""
    match = re.search(r'::C(\d+)$', chunk_id)
    if match:
        return int(match.group(1))
    return 0
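
# Example: extract_sequence_from_id("DOC::C12") -> 12; IDs without a "::C<n>" suffix return 0.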
|
|
|
|
|
def load_document_chunks(doc_id: str) -> list:
    """Load all chunks belonging to a document and return them in sequence order."""
    chunks_path = Path(current_dir) / "chunks"
    manifest_path = chunks_path / "chunks_manifest.json"

    if not manifest_path.exists():
        return []

    with open(manifest_path, "r", encoding="utf-8") as f:
        manifest = json.load(f)

    # Collect every chunk whose ID starts with the document ID.
    doc_chunks = []
    for chunk_info in manifest["chunks"]:
        if chunk_info["id"].startswith(doc_id):
            chunk_file_path = chunk_info["path"]
            if os.path.exists(chunk_file_path):
                with open(chunk_file_path, "r", encoding="utf-8") as f:
                    doc_chunks.append(json.load(f))

    # Order chunks by the sequence number encoded in their IDs (::C<n>).
    doc_chunks.sort(key=lambda x: extract_sequence_from_id(x["id"]))
    return doc_chunks
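
# Assumed manifest layout (inferred from the lookups above, not verified against the chunking
# pipeline): chunks/chunks_manifest.json holds {"chunks": [{"id": "...", "path": "..."}, ...]},
# and each referenced chunk file is a JSON object with at least "id", "chunk_text",
# "content_type", "path" and (optionally) "chunk_for_embedding" keys.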
|
|
|
|
|
def reconstruct_document(chunks: list) -> str:
    """Rebuild a markdown document from its ordered chunks."""
    if not chunks:
        return ""

    document_parts = []
    current_path = []

    for chunk in chunks:
        content_type = chunk.get("content_type", "text")
        chunk_text = chunk.get("chunk_text", "")
        path = chunk.get("path", [])

        # When the heading path changes, emit a markdown header for the first level that differs.
        if path != current_path:
            for i, path_item in enumerate(path):
                if i >= len(current_path) or path_item != current_path[i]:
                    if path_item and path_item not in ["ROOT", "TABLE"]:
                        level = i + 1
                        header_marker = "#" * min(level, 6)
                        document_parts.append(f"\n{header_marker} {path_item}\n")
                    break
            current_path = path

        if content_type == "table":
            # Keep tables on their own lines so markdown rendering stays intact.
            document_parts.append(f"\n{chunk_text}\n")
        elif chunk_text.strip():
            document_parts.append(chunk_text)

    return "\n".join(document_parts)
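
# Example (hypothetical chunk records with the fields read above):
#   chunks = [
#       {"id": "DOC::C0", "path": ["ROOT", "Điều 1"], "content_type": "text", "chunk_text": "Nội dung điều 1"},
#       {"id": "DOC::C1", "path": ["ROOT", "Điều 2"], "content_type": "text", "chunk_text": "Nội dung điều 2"},
#   ]
#   reconstruct_document(chunks) emits "\n## Điều 2\n" before the second chunk's text,
#   because its heading path differs from the first chunk's at level 2.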
|
|
|
|
|
def find_text_positions_in_reconstructed_doc(text_to_find: str, reconstructed_doc: str) -> list:
    """Find all (start, end) positions of text_to_find in the reconstructed document."""
    positions = []
    start = 0

    while True:
        pos = reconstructed_doc.find(text_to_find, start)
        if pos == -1:
            break
        positions.append((pos, pos + len(text_to_find)))
        start = pos + 1

    return positions
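
# Example: find_text_positions_in_reconstructed_doc("ab", "xabyab") -> [(1, 3), (4, 6)]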
|
|
|
|
|
def highlight_text_in_reconstructed_doc(texts_to_highlight: list, reconstructed_doc: str, chunks: list = None) -> str:
    """Wrap every occurrence of the given texts in the reconstructed document with a highlight span."""
    if not texts_to_highlight:
        return reconstructed_doc

    highlighted_doc = reconstructed_doc

    # Highlight longer texts first so shorter ones cannot split a longer match.
    sorted_texts = sorted(texts_to_highlight, key=len, reverse=True)

    for i, text in enumerate(sorted_texts):
        if not text.strip():
            continue

        positions = find_text_positions_in_reconstructed_doc(text, highlighted_doc)

        # Fallback: if the text is not in the reconstructed document but does appear in a
        # chunk's embedding text, append it to the end so it can still be highlighted.
        if not positions and chunks:
            for chunk in chunks:
                chunk_embedding = chunk.get("chunk_for_embedding", "")
                if text in chunk_embedding:
                    highlighted_doc += f"\n\n{text}"
                    positions = [(len(highlighted_doc) - len(text), len(highlighted_doc))]
                    break

        # Replace from the end so earlier positions stay valid while markup is inserted.
        for start, end in reversed(positions):
            highlighted_text = f'<span class="highlighted-text" data-index="{i}">{text}</span>'
            highlighted_doc = highlighted_doc[:start] + highlighted_text + highlighted_doc[end:]

    return highlighted_doc
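
# The consumer is expected to style the inserted markup, e.g.
#   <span class="highlighted-text" data-index="0">...</span>
# (an assumption about the frontend; this module only produces the spans).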
|
|
|
|
|
async def highlight_doc_with_chunks_new(doc_id: str, texts: list) -> tuple:
    """Highlight a document reconstructed from its chunks instead of the original markdown file."""
    chunks = load_document_chunks(doc_id)

    if not chunks:
        return f"⚠️ No chunks found for document {doc_id}", {
            "highlighted_count": 0,
            "total_texts": len(texts),
            "success_rate": 0.0,
        }

    reconstructed_doc = reconstruct_document(chunks)

    if not reconstructed_doc.strip():
        return f"⚠️ Document {doc_id} has no content", {
            "highlighted_count": 0,
            "total_texts": len(texts),
            "success_rate": 0.0,
        }

    highlighted_doc = highlight_text_in_reconstructed_doc(texts, reconstructed_doc, chunks)

    # Count only texts that occur verbatim in the reconstructed document;
    # texts recovered via the chunk_for_embedding fallback are not counted here.
    highlighted_count = 0
    for text in texts:
        if text.strip() and text in reconstructed_doc:
            highlighted_count += 1

    total = len([t for t in texts if t.strip()])
    success_rate = (highlighted_count / total * 100) if total > 0 else 0.0

    return highlighted_doc, {
        "highlighted_count": highlighted_count,
        "total_texts": total,
        "success_rate": success_rate,
    }
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")