|
|
from typing import Dict, List

import json
import os
import re

import gradio as gr

from scripts.main import main, feedback
from utils import count_tokens

current_dir = os.path.dirname(os.path.abspath(__file__))
output_path = os.path.join(current_dir, "output.json")
converted_path = os.path.join(current_dir, "converted")
|
|
|
|
|
def extract_sequence_from_id(chunk_id: str) -> int:
    """Extract the sequence number from a chunk ID."""
    match = re.search(r'::C(\d+)$', chunk_id)
    if match:
        return int(match.group(1))
    return 0
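
# Usage sketch (assumes the "<doc_id>::C<n>" ID convention implied by the
# regex above; the IDs shown are hypothetical):
#   extract_sequence_from_id("DOC_001::C12")  # -> 12
#   extract_sequence_from_id("DOC_001")       # -> 0 (no ::C<n> suffix)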
|
|
|
|
|
def load_document_chunks(doc_id: str) -> List[Dict]:
    """Load all chunks of a document and sort them in sequence order."""
    chunks_path = os.path.join(current_dir, "chunks")
    manifest_path = os.path.join(chunks_path, "chunks_manifest.json")

    if not os.path.exists(manifest_path):
        return []

    with open(manifest_path, "r", encoding="utf-8") as f:
        manifest = json.load(f)

    # Collect every chunk whose ID belongs to this document.
    doc_chunks = []
    for chunk_info in manifest["chunks"]:
        if chunk_info["id"].startswith(doc_id):
            chunk_file_path = chunk_info["path"]
            if os.path.exists(chunk_file_path):
                with open(chunk_file_path, "r", encoding="utf-8") as f:
                    chunk_data = json.load(f)
                doc_chunks.append(chunk_data)

    # Restore the original document order using the ::C<n> suffix.
    doc_chunks.sort(key=lambda x: extract_sequence_from_id(x["id"]))
    return doc_chunks
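
# Assumed manifest layout, inferred from the fields read above ("id" and
# "path" are the only keys this code relies on; the values are hypothetical):
# {
#   "chunks": [
#     {"id": "DOC_001::C1", "path": "chunks/DOC_001_C1.json"},
#     {"id": "DOC_001::C2", "path": "chunks/DOC_001_C2.json"}
#   ]
# }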
|
|
|
|
|
def reconstruct_document(chunks: List[Dict]) -> str:
    """Reconstruct the original document from its chunks."""
    if not chunks:
        return ""

    document_parts = []
    current_path = []

    for chunk in chunks:
        content_type = chunk.get("content_type", "text")
        chunk_text = chunk.get("chunk_text", "")
        path = chunk.get("path", [])

        # When the heading path changes, emit a header for the first item
        # that differs from the previous chunk's path.
        if path != current_path:
            for i, path_item in enumerate(path):
                if i >= len(current_path) or path_item != current_path[i]:
                    if path_item and path_item not in ["ROOT", "TABLE"]:
                        level = i + 1
                        # "Điểm" (point-level) items always render as h4;
                        # other items map their depth to a heading level.
                        if "Điểm" in path_item:
                            header_marker = "####"
                        else:
                            header_marker = "#" * min(level, 6)
                        document_parts.append(f"\n{header_marker} {path_item}\n")
                    break
            current_path = path

        if content_type == "table":
            # Keep tables on their own lines so they render intact.
            document_parts.append(f"\n{chunk_text}\n")
        else:
            if chunk_text.strip():
                document_parts.append(chunk_text)

    return "\n".join(document_parts)
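
# Sketch of the assumed per-chunk schema (only the keys read above; the
# values, including the path items, are illustrative):
#   {"id": "DOC_001::C3", "content_type": "text", "chunk_text": "...",
#    "path": ["ROOT", "Chương I", "Điều 1"]}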
|
|
|
|
|
def find_text_positions_in_reconstructed_doc(text_to_find: str, reconstructed_doc: str) -> List[tuple]:
    """Find all (start, end) positions of a text in the reconstructed document."""
    positions = []
    start = 0

    while True:
        pos = reconstructed_doc.find(text_to_find, start)
        if pos == -1:
            break
        positions.append((pos, pos + len(text_to_find)))
        # Advance past this match so overlapping hits are not double-counted
        # (overlapping spans would corrupt the replacement pass below).
        start = pos + len(text_to_find)

    return positions
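
# Example (non-overlapping matches):
#   find_text_positions_in_reconstructed_doc("ab", "ab ab")  # -> [(0, 2), (3, 5)]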
|
|
|
|
|
def highlight_text_in_reconstructed_doc(texts_to_highlight: List[str], reconstructed_doc: str, chunks: List[Dict] = None) -> str:
    """Highlight the given texts inside the reconstructed document."""
    if not texts_to_highlight:
        return reconstructed_doc

    highlighted_doc = reconstructed_doc

    # Highlight longer texts first so a short text cannot break up a longer
    # match that contains it.
    sorted_texts = sorted(texts_to_highlight, key=len, reverse=True)

    for text in sorted_texts:
        if not text.strip():
            continue

        positions = find_text_positions_in_reconstructed_doc(text, highlighted_doc)

        # Fallback: if the text is missing from the reconstruction but appears
        # in a chunk's embedding text, append it so it can still be shown.
        if not positions and chunks:
            for chunk in chunks:
                chunk_embedding = chunk.get("chunk_for_embedding", "")
                if text in chunk_embedding:
                    highlighted_doc += f"\n\n{text}"
                    positions = [(len(highlighted_doc) - len(text), len(highlighted_doc))]
                    break

        # Replace from the end so earlier positions stay valid.
        for start, end in reversed(positions):
            token_count = count_tokens(text)
            highlighted_text = f'<span style="color:green; font-weight:bold; background-color:yellow;">{text}</span> ({token_count} tokens)'
            highlighted_doc = highlighted_doc[:start] + highlighted_text + highlighted_doc[end:]

    return highlighted_doc
|
|
|
|
|
def format_highlighted_doc(highlighted_doc: str) -> str:
    """Format the highlighted doc to be more readable."""
    # Convert the Markdown headers produced by reconstruct_document into HTML.
    formatted_doc = re.sub(r'^\s*# (.+)$', r'<h1>\1</h1>', highlighted_doc, flags=re.MULTILINE)
    formatted_doc = re.sub(r'^\s*## (.+)$', r'<h2>\1</h2>', formatted_doc, flags=re.MULTILINE)
    formatted_doc = re.sub(r'^\s*### (.+)$', r'<h3>\1</h3>', formatted_doc, flags=re.MULTILINE)
    formatted_doc = re.sub(r'^\s*#### (.+)$', r'<h4>\1</h4>', formatted_doc, flags=re.MULTILINE)

    formatted_doc = formatted_doc.replace("\n", "<br>")
    return formatted_doc
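
# Example of the header conversion plus newline flattening:
#   "# Section 1\nBody" -> "<h1>Section 1</h1><br>Body"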
|
|
|
|
|
def highlight_doc_with_chunks(doc_id: str, texts: List[str]) -> str:
    """Highlight a document using its chunks instead of the original Markdown file."""
    chunks = load_document_chunks(doc_id)

    if not chunks:
        return f"⚠️ No chunks found for document {doc_id}"

    reconstructed_doc = reconstruct_document(chunks)

    if not reconstructed_doc.strip():
        return f"⚠️ Document {doc_id} has no content"

    highlighted_doc = highlight_text_in_reconstructed_doc(texts, reconstructed_doc, chunks)

    # Count how many of the requested texts were actually found in the
    # reconstruction (fallback appends are not counted as successes).
    highlighted_count = 0
    for text in texts:
        if text.strip() and text in reconstructed_doc:
            highlighted_count += 1
    total = len([t for t in texts if t.strip()])
    success_rate = (highlighted_count / total * 100) if total > 0 else 0.0

    summary = f"""
<div style='background-color: #f0f0f0; padding: 10px; margin: 10px 0; border-radius: 5px;'>
<h3>Highlight Summary:</h3>
<p><strong>Document ID:</strong> {doc_id}</p>
<p><strong>Actually highlighted:</strong> {highlighted_count}</p>
<p><strong>Success rate:</strong> {success_rate:.1f}%</p>
</div>
"""

    return summary + f"<pre style='white-space: pre-wrap;'>{format_highlighted_doc(highlighted_doc)}</pre>"
|
|
|
|
|
def format_user_prompt(user_prompt: str) -> str:
    """Format the user prompt to be more readable."""
    token_count = count_tokens(user_prompt)
    formatted_prompt = user_prompt.replace("Chunk", "<strong>Chunk</strong>")
    formatted_prompt = formatted_prompt.replace("\n", "<br>")
    formatted_prompt = f"<p><strong>Total tokens:</strong> {token_count}</p><br>" + formatted_prompt
    return formatted_prompt
|
|
|
|
|
|
|
|
current_session_id = None


def get_feedback(is_like: bool, session_id: str):
    return feedback(is_like, session_id)
|
|
|
|
|
def response_generator(query: str, top_k: int = 20, top_n: int = 10):
    global current_session_id
    response, session_id = main(query, top_k=top_k, top_n=top_n)
    current_session_id = session_id

    session_path = f"sessions/{session_id}.json"
    with open(session_path, "r", encoding="utf-8") as f:
        session_output = json.load(f)

    rag_results = session_output[0]["rag_results"]
    user_prompt = session_output[0]["user_prompt"]

    # Group the retrieved texts by their source document.
    doc_ids_set = {item["doc_id"] for item in rag_results}
    chunks_retrieved = [{
        "doc_id": doc_id,
        "texts": [item["text"] for item in rag_results if item["doc_id"] == doc_id]
    } for doc_id in doc_ids_set]

    highlighted_texts = [highlight_doc_with_chunks(chunk["doc_id"], chunk["texts"]) for chunk in chunks_retrieved]
    user_prompt = format_user_prompt(user_prompt)

    # The UI has exactly 15 document tabs: truncate when more documents were
    # retrieved, and pad with empty strings when there are fewer.
    highlighted_texts = highlighted_texts[:15]
    while len(highlighted_texts) < 15:
        highlighted_texts.append("")

    return response, current_session_id, *highlighted_texts, user_prompt
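
# Assumed session-file layout, inferred from the reads above (keys shown are
# the ones accessed; example values are hypothetical):
# [
#   {
#     "rag_results": [{"doc_id": "DOC_001", "text": "..."}],
#     "user_prompt": "..."
#   }
# ]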
|
|
|
|
|
def get_like_feedback():
    global current_session_id
    if current_session_id:
        result = get_feedback(True, current_session_id)
        print(f"Like feedback: {result}")
        return f"✅ {result}"
    return "❌ No active session"


def get_dislike_feedback():
    global current_session_id
    if current_session_id:
        result = get_feedback(False, current_session_id)
        print(f"Dislike feedback: {result}")
        return f"👎 {result}"
    return "❌ No active session"


def clear_feedback():
    return ""
|
|
|
|
|
|
|
|
with gr.Blocks(title="RAG") as demo:
    gr.Markdown("# RAG System")
    gr.Markdown("Query the document and see highlighted results (Link Google Drive: https://drive.google.com/drive/folders/1gQ-KCaTHIoYWxds_UnrDrGu4sE1yU8PJ?usp=sharing)")

    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(lines=5, label="Query", placeholder="Enter your question here...")
            with gr.Row():
                gr.HTML("")
                submit_btn = gr.Button("Submit", variant="primary", size="sm")
        with gr.Column(scale=1):
            response_output = gr.Textbox(lines=8, label="Response", interactive=True)
            session_id_output = gr.Textbox(lines=1, label="Current Session ID", interactive=False)
            feedback_output = gr.Textbox(lines=2, label="Feedback Status", interactive=False)
            with gr.Row():
                gr.HTML("")
                like_btn = gr.Button("Like", variant="primary", size="sm")
                dislike_btn = gr.Button("Dislike", variant="primary", size="sm")

    # One tab per retrieved document (matches the fixed 15-slot padding in
    # response_generator), plus a tab for the final user prompt.
    with gr.Tabs():
        html_outputs = []
        for i in range(15):
            with gr.TabItem(f"Document Chunk {i+1}"):
                html_outputs.append(gr.HTML())

        with gr.TabItem("User Prompt"):
            user_prompt_output = gr.HTML(label="User Prompt")

    submit_btn.click(
        fn=response_generator,
        inputs=[query_input],
        outputs=[response_output, session_id_output] + html_outputs + [user_prompt_output]
    ).then(
        fn=clear_feedback,
        inputs=[],
        outputs=[feedback_output]
    )
    like_btn.click(
        fn=get_like_feedback,
        inputs=[],
        outputs=[feedback_output]
    )
    dislike_btn.click(
        fn=get_dislike_feedback,
        inputs=[],
        outputs=[feedback_output]
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|