"""Gradio demo for a Question Answering System trained on SQuAD 2.0."""

import sys
from pathlib import Path

import gradio as gr

# Make the local `src` package importable when this script is run directly.
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))

from src.models.bert_based_model import BertBasedQAModel
from src.config.model_configs import OriginalBertQAConfig
from src.etl.types import QAExample

# Load the fine-tuned model from the local checkpoint directory and run inference on CPU.
model = BertBasedQAModel.load_from_experiment(
    experiment_dir=Path("checkpoint"),
    config_class=OriginalBertQAConfig,
    device="cpu",
)


def answer_question(context: str, question: str) -> str:
    """Process a QA request and return the predicted answer text."""
    if not context.strip():
        return "Please provide context text."
    if not question.strip():
        return "Please provide a question."

    try:
        # Wrap the user input in a single QAExample. The answer fields are left
        # empty because no gold labels exist at inference time.
        example = QAExample(
            question_id="demo",
            title="Demo",
            question=question.strip(),
            context=context.strip(),
            answer_texts=[],
            answer_starts=[],
            is_impossible=True,
        )

        predictions = model.predict({"demo": example})
        answer = predictions["demo"].predicted_answer

        return answer if answer else "No answer found."

    except Exception as e:
        return f"Error: {str(e)}"


demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(lines=8, placeholder="Enter context paragraph...", label="Context"),
        gr.Textbox(placeholder="Enter your question...", label="Question"),
    ],
    outputs=gr.Textbox(label="Answer", show_copy_button=True, lines=4),
    title="SQuAD 2.0 Question Answering",
    description="BERT-base model fine-tuned on the SQuAD 2.0 dataset",
    allow_flagging="never",
    deep_link=False,
    theme="earneleh/paris",
)

if __name__ == "__main__":
    demo.launch()