|
|
import os |
|
|
import sys |
|
|
from haystack.document_stores.in_memory import InMemoryDocumentStore |
|
|
from datasets import load_from_disk |
|
|
from haystack import Document |
|
|
from haystack.components.writers import DocumentWriter |
|
|
from haystack.components.embedders import SentenceTransformersDocumentEmbedder |
|
|
from haystack.components.preprocessors.document_splitter import DocumentSplitter |
|
|
from haystack import Pipeline |
|
|
from haystack.components.retrievers.in_memory import ( |
|
|
InMemoryBM25Retriever, |
|
|
InMemoryEmbeddingRetriever, |
|
|
) |
|
|
from haystack.components.embedders import SentenceTransformersTextEmbedder |
|
|
from haystack.components.joiners import DocumentJoiner |
|
|
|
|
|
|
|
|
from haystack.components.rankers import SentenceTransformersSimilarityRanker |
|
|
from haystack.document_stores.types import DuplicatePolicy |
|
|
from haystack.components.converters import PyPDFToDocument |
|
|
from haystack.components.preprocessors import DocumentCleaner |
|
|
from haystack.components.builders import PromptBuilder |
|
|
from pathlib import Path |
|
|
from haystack.components.converters import DOCXToDocument |
|
|
import re |
|
|
import argparse |
|
|
|
|
|
|
|
|
""" |
|
|
python hybrid.py -c newstore.store
|
|
python hybrid.py -r newstore.store -q "who is pufendorf" |
|
|
""" |
|
|
|
|
|
embedding_model = "sentence-transformers/all-MiniLM-L6-v2" |
|
|
|
|
|
|
|
|
reranker_model = "BAAI/bge-reranker-base" |
|
|
|
|
|
|
|
|
def build_store_from_dir(dir_path: str) -> list[Document]:
    """Convert every PDF and DOCX file under *dir_path* into Haystack Documents.

    Files are discovered recursively and processed in sorted order; each
    resulting Document carries a ``source`` meta field with the original file
    path.

    Note: the previous return annotation (``InMemoryDocumentStore``) was wrong —
    this function returns the converted Documents; the caller is responsible
    for indexing them into a store (see ``create_index_split``).

    :param dir_path: Root directory to scan for ``*.pdf`` and ``*.docx`` files.
    :returns: List of converted ``Document`` objects (possibly empty).
    """
    root = Path(dir_path)
    pdfs = sorted(str(p) for p in root.rglob("*.pdf"))
    docxs = sorted(str(p) for p in root.rglob("*.docx"))

    print(pdfs)
    print(docxs)

    docs: list[Document] = []
    # Run each converter only on its own file type; both emit {"documents": [...]}.
    for paths, converter in ((pdfs, PyPDFToDocument()), (docxs, DOCXToDocument())):
        if paths:
            out = converter.run(sources=paths, meta=[{"source": p} for p in paths])
            docs.extend(out["documents"])

    return docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_index_split(docs, doc_store, split_length=5, split_overlap=1):
    """Split *docs* into sentence chunks, embed them, and write them to *doc_store*.

    Duplicate documents are skipped on write. Returns a hybrid retrieval
    pipeline (BM25 + embedding, joined and re-ranked) built on the freshly
    populated store.

    :param docs: Documents to index.
    :param doc_store: Target ``InMemoryDocumentStore``.
    :param split_length: Number of sentences per chunk.
    :param split_overlap: Sentence overlap between consecutive chunks.
    """
    indexing = Pipeline()
    indexing.add_component(
        "document_splitter",
        DocumentSplitter(
            split_by="sentence",
            split_length=split_length,
            split_overlap=split_overlap,
        ),
    )
    indexing.add_component(
        "document_embedder",
        SentenceTransformersDocumentEmbedder(model=embedding_model),
    )
    indexing.add_component(
        "document_writer",
        DocumentWriter(doc_store, policy=DuplicatePolicy.SKIP),
    )

    indexing.connect("document_splitter", "document_embedder")
    indexing.connect("document_embedder", "document_writer")

    indexing.run({"document_splitter": {"documents": docs}})

    return create_hybrid_retriever(doc_store)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_hybrid_retriever(doc_store):
    """Build a hybrid retrieval pipeline over *doc_store*.

    BM25 and embedding retrieval run in parallel; their results are merged by a
    ``DocumentJoiner`` and re-ranked with a cross-encoder similarity ranker.
    """
    components = {
        "text_embedder": SentenceTransformersTextEmbedder(model=embedding_model),
        "embedding_retriever": InMemoryEmbeddingRetriever(doc_store),
        "bm25_retriever": InMemoryBM25Retriever(doc_store),
        "document_joiner": DocumentJoiner(),
        "ranker": SentenceTransformersSimilarityRanker(model=reranker_model),
    }

    pipeline = Pipeline()
    for name, component in components.items():
        pipeline.add_component(name, component)

    # Two retrieval branches feed the joiner; the joined set goes to the ranker.
    pipeline.connect("text_embedder", "embedding_retriever")
    pipeline.connect("bm25_retriever", "document_joiner")
    pipeline.connect("embedding_retriever", "document_joiner")
    pipeline.connect("document_joiner", "ranker")

    return pipeline
|
|
|
|
|
|
|
|
def create_embedding_retriever(doc_store):
    """Build an embedding-only retrieval pipeline over *doc_store*.

    The query is embedded, matched against stored embeddings, and the hits are
    re-ranked with a cross-encoder similarity ranker.
    """
    pipeline = Pipeline()
    pipeline.add_component(
        "text_embedder", SentenceTransformersTextEmbedder(model=embedding_model)
    )
    pipeline.add_component(
        "embedding_retriever", InMemoryEmbeddingRetriever(doc_store)
    )
    pipeline.add_component(
        "ranker", SentenceTransformersSimilarityRanker(model=reranker_model)
    )

    pipeline.connect("text_embedder", "embedding_retriever")
    pipeline.connect("embedding_retriever", "ranker")

    return pipeline
|
|
|
|
|
|
|
|
def create_bm25_retriever(doc_store):
    """Build a BM25-only retrieval pipeline over *doc_store*.

    BM25 hits are re-ranked with a cross-encoder similarity ranker. (The
    unused ``DocumentJoiner`` the original constructed here has been removed —
    it was never added to the pipeline.)
    """
    bm25_retrieval = Pipeline()
    bm25_retrieval.add_component("bm25_retriever", InMemoryBM25Retriever(doc_store))
    bm25_retrieval.add_component(
        "ranker", SentenceTransformersSimilarityRanker(model=reranker_model)
    )
    bm25_retrieval.connect("bm25_retriever", "ranker")

    return bm25_retrieval
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve(retriever, query, top_k=8, scale=True):
    """Run *query* through the hybrid pipeline and return the ranked documents.

    Bug fix: the ``scale`` flag previously applied only to the BM25 branch —
    ``embedding_retriever`` and ``ranker`` hard-coded ``scale_score=True``, so
    disabling scaling via ``-s`` only half-worked. All three now honor it.

    :param retriever: Pipeline built by ``create_hybrid_retriever``.
    :param query: Query string.
    :param top_k: Number of documents each stage keeps.
    :param scale: Whether to scale retrieval/ranking scores.
    :returns: Documents from the final ranker stage.
    """
    result = retriever.run(
        {
            "text_embedder": {"text": query},
            "bm25_retriever": {
                "query": query,
                "top_k": top_k,
                "scale_score": scale,
            },
            "embedding_retriever": {"top_k": top_k, "scale_score": scale},
            "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
        }
    )
    return result["ranker"]["documents"]
|
|
|
|
|
|
|
|
def retrieve_embedded(retriever, query, top_k=8, scale=True):
    """Run *query* through the embedding-only pipeline and return ranked documents.

    :param retriever: Pipeline built by ``create_embedding_retriever``.
    :param query: Query string.
    :param top_k: Number of documents each stage keeps.
    :param scale: Whether to scale retrieval/ranking scores.
    """
    payload = {
        "text_embedder": {"text": query},
        "embedding_retriever": {"top_k": top_k, "scale_score": scale},
        "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
    }
    return retriever.run(payload)["ranker"]["documents"]
|
|
|
|
|
|
|
|
def retrieve_bm25(retriever, query, top_k=8, scale=True):
    """Run *query* through the BM25-only pipeline and return ranked documents.

    Bug fix: the ranker previously hard-coded ``scale_score=True`` while the
    BM25 retriever honored the ``scale`` flag; both stages now honor it.

    :param retriever: Pipeline built by ``create_bm25_retriever``.
    :param query: Query string.
    :param top_k: Number of documents each stage keeps.
    :param scale: Whether to scale retrieval/ranking scores.
    :returns: Documents from the final ranker stage.
    """
    result = retriever.run(
        {
            "bm25_retriever": {
                "query": query,
                "top_k": top_k,
                "scale_score": scale,
            },
            "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
        }
    )
    return result["ranker"]["documents"]
|
|
|
|
|
|
|
|
def print_res(doc, width=0):
    """Print a document's score followed by a one-line preview of its content.

    The content is whitespace-normalized; if the document's meta carries a
    ``researcher_name`` it is prefixed as ``name:``. When *width* > 0 the text
    is cut to fit a terminal of that width.

    Bug fix: previously ``"..."`` was appended whenever *width* > 0, even when
    the text already fit; the ellipsis is now added only on actual truncation.
    """
    text = " ".join(doc.content.split())
    try:
        text = doc.meta["researcher_name"] + ":" + text
    except KeyError:
        pass  # no researcher name in meta; print content alone
    if width > 0:
        # Budget: 8 chars of score, 1 separator space, 3 for the ellipsis.
        max_len = width - 8 - 3 - 1
        if len(text) > max_len:
            text = text[:max_len] + "..."
    print("{:.5f}".format(doc.score), text)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    terminal_width = os.get_terminal_size().columns

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--create_store", help="Create a new data store.", default=None
    )
    parser.add_argument("-d", "--dataset", help="Dataset filename.", default=None)
    parser.add_argument("-r", "--read_store", help="Read a data store.", default=None)
    parser.add_argument(
        "-s",
        "--scale",
        action="store_false",
        help="Do not scale retrieved scores.",
        default=True,
    )
    parser.add_argument("--top_k", type=int, help="Retriever top_k.", default=8)
    parser.add_argument("-q", "--query", help="Query DBs.", default=None)
    args = parser.parse_args()
    query = args.query

    # -c: convert, split, embed, index, and persist the store to disk.
    if args.create_store:
        # NOTE(review): the source directory is hard-coded — confirm this path.
        docs = build_store_from_dir("../Gradio/docs")
        rs_doc_store = InMemoryDocumentStore()
        # Bug fix: these messages previously named a nonexistent
        # create_index_nosplit(); the function actually run is create_index_split().
        print("Starting create_index_split()")
        create_index_split(docs, rs_doc_store)
        rs_doc_store.save_to_disk(args.create_store)
        print("Ready create_index_split()")

    # Without a query there is nothing left to do.
    if not args.query:
        sys.exit(0)

    # Pick the store to query: explicit -r, else the one just created, else the default.
    if not args.read_store and not args.create_store:
        args.read_store = "research_docs_ns.store"
    elif not args.read_store and args.create_store:
        args.read_store = args.create_store
    print(f"Loading document store {args.read_store}...")
    # load_from_disk is a classmethod — no throwaway instance needed.
    doc_store = InMemoryDocumentStore.load_from_disk(args.read_store)
    print(f"Number of documents: {doc_store.count_documents()}.")

    # Run the same query through all three strategies so results can be compared.
    hybrid_retrieval = create_hybrid_retriever(doc_store)
    documents = retrieve(hybrid_retrieval, query, top_k=args.top_k, scale=args.scale)
    print("=" * 80)
    print("== Hybrid")
    print("=" * 80)
    for doc in documents:
        print_res(doc, terminal_width)

    embedding_retrieval = create_embedding_retriever(doc_store)
    documents = retrieve_embedded(
        embedding_retrieval, query, top_k=args.top_k, scale=args.scale
    )
    print("=" * 80)
    print("== Embedding")
    print("=" * 80)
    for doc in documents:
        print_res(doc, terminal_width)

    bm25_retrieval = create_bm25_retriever(doc_store)
    documents = retrieve_bm25(bm25_retrieval, query, top_k=args.top_k, scale=args.scale)
    print("=" * 80)
    print("== bm25")
    print("=" * 80)
    for doc in documents:
        print_res(doc, terminal_width)
|
|
|