# SvP / hybrid.py
# Author: pberck
# Haystack docstore and hybrid context in Pufendorf bot.
# Commit: 3dc3f9b
import os
import sys
from haystack.document_stores.in_memory import InMemoryDocumentStore
from datasets import load_from_disk
from haystack import Document
from haystack.components.writers import DocumentWriter
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.preprocessors.document_splitter import DocumentSplitter
from haystack import Pipeline
from haystack.components.retrievers.in_memory import (
InMemoryBM25Retriever,
InMemoryEmbeddingRetriever,
)
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
# from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.rankers import SentenceTransformersSimilarityRanker
from haystack.document_stores.types import DuplicatePolicy
from haystack.components.converters import PyPDFToDocument
from haystack.components.preprocessors import DocumentCleaner
from haystack.components.builders import PromptBuilder
from pathlib import Path
from haystack.components.converters import DOCXToDocument
import re
import argparse
"""
Usage:
python hybrid.py -c newstore.store
python hybrid.py -r newstore.store -q "who is pufendorf"
"""
# Bi-encoder model used both by the document embedder (at indexing time) and
# by the query text embedder (at retrieval time).
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
# Cross-encoder model used by SentenceTransformersSimilarityRanker to re-rank
# retrieved candidates; see https://huggingface.co/BAAI/bge-reranker-base
reranker_model = "BAAI/bge-reranker-base"
def build_store_from_dir(dir_path: str) -> list[Document]:
    """Convert every PDF and DOCX file under ``dir_path`` into Haystack documents.

    The directory is searched recursively. Each converted document carries the
    path of the file it came from in ``meta["source"]``. The discovered PDF
    and DOCX path lists are printed for visibility.

    Despite its name, this function does not build a document store: it
    returns the converted documents, which the caller then indexes (for
    example with ``create_index_split``).

    Args:
        dir_path: Root directory to search.

    Returns:
        The converted PDF documents followed by the converted DOCX documents.
        Empty when no matching files are found.
    """
    root = Path(dir_path)
    pdfs = sorted(str(p) for p in root.rglob("*.pdf"))
    docxs = sorted(str(p) for p in root.rglob("*.docx"))
    print(pdfs)
    print(docxs)

    docs: list[Document] = []
    # Only instantiate a converter when there are files for it to convert.
    converters = ((pdfs, PyPDFToDocument), (docxs, DOCXToDocument))
    for paths, converter_cls in converters:
        if not paths:
            continue
        out = converter_cls().run(
            sources=paths, meta=[{"source": p} for p in paths]
        )
        docs.extend(out["documents"])
    return docs
# Example usage:
# docs = build_store_from_dir("/path/to/folder")
# print(len(docs))
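# Splits the documents into overlapping sentence windows, embeds them, writes
# them to doc_store, and returns a hybrid retrieval pipeline over that store.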
def create_index_split(docs, doc_store, split_length=5, split_overlap=1):
    """Index ``docs`` into ``doc_store`` as embedded sentence chunks.

    Each document is split into chunks of ``split_length`` sentences, with
    ``split_overlap`` sentences shared between consecutive chunks. The chunks
    are embedded with ``embedding_model`` and written to ``doc_store`` using
    ``DuplicatePolicy.SKIP``, so documents already present are not rewritten.

    Args:
        docs: Documents to index.
        doc_store: Document store that receives the embedded chunks.
        split_length: Number of sentences per chunk.
        split_overlap: Number of sentences shared by consecutive chunks.

    Returns:
        A hybrid retrieval pipeline over ``doc_store`` (built by
        ``create_hybrid_retriever``).
    """
    stages = [
        (
            "document_splitter",
            DocumentSplitter(
                split_by="sentence",
                split_length=split_length,
                split_overlap=split_overlap,
            ),
        ),
        (
            "document_embedder",
            SentenceTransformersDocumentEmbedder(model=embedding_model),
        ),
        (
            "document_writer",
            DocumentWriter(doc_store, policy=DuplicatePolicy.SKIP),
        ),
    ]

    pipeline = Pipeline()
    for stage_name, component in stages:
        pipeline.add_component(stage_name, component)
    # Chain the stages linearly: splitter -> embedder -> writer.
    for (upstream, _), (downstream, _) in zip(stages, stages[1:]):
        pipeline.connect(upstream, downstream)

    pipeline.run({"document_splitter": {"documents": docs}})
    return create_hybrid_retriever(doc_store)
# Builds only the retrieval pipeline for an already-indexed document store:
# embedding and BM25 retrievers whose results are joined and then re-ranked.
def create_hybrid_retriever(doc_store):
    """Build a hybrid (dense + BM25) retrieval pipeline over ``doc_store``.

    The query is embedded by ``text_embedder`` and fed to
    ``embedding_retriever``; in parallel ``bm25_retriever`` does keyword
    retrieval. Both candidate lists go into ``document_joiner``, whose output
    is re-ranked by ``ranker`` (``reranker_model``).

    The store is expected to already hold document embeddings made with
    ``embedding_model``. Run the pipeline with ``retrieve``.
    """
    components = {
        "text_embedder": SentenceTransformersTextEmbedder(model=embedding_model),
        "embedding_retriever": InMemoryEmbeddingRetriever(doc_store),
        "bm25_retriever": InMemoryBM25Retriever(doc_store),
        "document_joiner": DocumentJoiner(),
        # SentenceTransformersSimilarityRanker needs haystack-ai >= 2.14.
        "ranker": SentenceTransformersSimilarityRanker(model=reranker_model),
    }
    links = [
        ("text_embedder", "embedding_retriever"),
        ("bm25_retriever", "document_joiner"),
        ("embedding_retriever", "document_joiner"),
        ("document_joiner", "ranker"),
    ]

    pipeline = Pipeline()
    for name, component in components.items():
        pipeline.add_component(name, component)
    for sender, receiver in links:
        pipeline.connect(sender, receiver)
    return pipeline
def create_embedding_retriever(doc_store):
    """Build a dense-only retrieval pipeline over ``doc_store``.

    The query is embedded with ``embedding_model``, matched against the
    stored document embeddings, and the candidates are re-ranked with
    ``reranker_model``. Run the pipeline with ``retrieve_embedded``.
    """
    # Alternative query model noted earlier: "BAAI/bge-small-en-v1.5".
    # For GPU use, pass device=ComponentDevice.from_str("cuda:0").
    query_embedder = SentenceTransformersTextEmbedder(model=embedding_model)
    dense_retriever = InMemoryEmbeddingRetriever(doc_store)
    reranker = SentenceTransformersSimilarityRanker(model=reranker_model)

    pipeline = Pipeline()
    for name, component in (
        ("text_embedder", query_embedder),
        ("embedding_retriever", dense_retriever),
        ("ranker", reranker),
    ):
        pipeline.add_component(name, component)
    pipeline.connect("text_embedder", "embedding_retriever")
    pipeline.connect("embedding_retriever", "ranker")
    return pipeline
def create_bm25_retriever(doc_store):
    """Build a keyword-only (BM25) retrieval pipeline over ``doc_store``.

    BM25 candidates are re-ranked with ``reranker_model``. No embedding model
    is involved at query time. Run the pipeline with ``retrieve_bm25``.
    """
    bm25_retriever = InMemoryBM25Retriever(doc_store)
    ranker = SentenceTransformersSimilarityRanker(model=reranker_model)
    bm25_retrieval = Pipeline()
    bm25_retrieval.add_component("bm25_retriever", bm25_retriever)
    bm25_retrieval.add_component("ranker", ranker)
    bm25_retrieval.connect("bm25_retriever", "ranker")
    return bm25_retrieval
# Run a pre-built retrieval pipeline and return the top_k re-ranked documents.
# Idea: if the query mentions a name, restrict the search with a metadata
# filter, e.g.:
# filters = {
#     "operator": "AND",
#     "conditions": [
#         {"field": "meta.type", "operator": "==", "value": "article"},
#         {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]},
#     ],
# }
# results = doc_store.filter_documents(filters=filters)
def retrieve(retriever, query, top_k=8, scale=True):
    """Run a hybrid pipeline built by ``create_hybrid_retriever``.

    Args:
        retriever: Pipeline with ``text_embedder``, ``bm25_retriever``,
            ``embedding_retriever`` and ``ranker`` components.
        query: Query text, used for embedding, BM25 and re-ranking.
        top_k: Maximum number of documents each retriever and the ranker
            return.
        scale: Passed as ``scale_score`` to every scoring component, so the
            CLI's ``--scale`` switch also affects the final (ranker) scores.

    Returns:
        The documents output by the ``ranker`` component.
    """
    # A metadata filter could be added to the retrievers via a "filters" key,
    # e.g. {"field": "meta.researcher_name", "operator": "==", "value": "..."}.
    result = retriever.run(
        {
            "text_embedder": {"text": query},
            "bm25_retriever": {
                "query": query,
                "top_k": top_k,
                "scale_score": scale,
            },
            "embedding_retriever": {"top_k": top_k, "scale_score": scale},
            "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
        }
    )
    return result["ranker"]["documents"]
def retrieve_embedded(retriever, query, top_k=8, scale=True):
    """Run a dense-only pipeline built by ``create_embedding_retriever``.

    Args:
        retriever: Pipeline with ``text_embedder``, ``embedding_retriever``
            and ``ranker`` components.
        query: Query text to embed and to re-rank against.
        top_k: Maximum number of documents retrieved and kept after ranking.
        scale: Passed as ``scale_score`` to the retriever and the ranker.

    Returns:
        The documents output by the ``ranker`` component.
    """
    limits = {"top_k": top_k, "scale_score": scale}
    inputs = {
        "text_embedder": {"text": query},
        "embedding_retriever": dict(limits),
        "ranker": {"query": query, **limits},
    }
    outputs = retriever.run(inputs)
    return outputs["ranker"]["documents"]
def retrieve_bm25(retriever, query, top_k=8, scale=True):
    """Run a keyword-only pipeline built by ``create_bm25_retriever``.

    Args:
        retriever: Pipeline with ``bm25_retriever`` and ``ranker`` components.
        query: Query text for BM25 retrieval and re-ranking.
        top_k: Maximum number of documents retrieved and kept after ranking.
        scale: Passed as ``scale_score`` to both the BM25 retriever and the
            ranker, consistent with ``retrieve_embedded``.

    Returns:
        The documents output by the ``ranker`` component.
    """
    # A metadata filter could be added to bm25_retriever via a "filters" key,
    # e.g. {"field": "meta.researcher_name", "operator": "==", "value": "..."}.
    result = retriever.run(
        {
            "bm25_retriever": {
                "query": query,
                "top_k": top_k,
                "scale_score": scale,
            },
            "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
        }
    )
    return result["ranker"]["documents"]
def print_res(doc, width=0):
    """Print one retrieved document on a single line: score, then text.

    Whitespace in the content is collapsed to single spaces. When the
    document's meta has a ``researcher_name`` entry, the text is prefixed
    with ``"<researcher_name>:"``.

    Args:
        doc: A Haystack ``Document`` (anything with ``meta`` dict, ``content``
            and ``score`` attributes). ``content``/``score`` may be ``None``.
        width: Total line width to fit, e.g. the terminal width. When > 0 and
            the text is too long, it is cut and "..." appended so that score,
            text and newline fit in ``width`` columns (assuming the score
            formats to 7 characters, as ``0.12345`` does). 0 disables
            truncation.
    """
    text = " ".join((doc.content or "").split())
    if "researcher_name" in doc.meta:
        text = f"{doc.meta['researcher_name']}:{text}"
    if width > 0:
        # Room left after the score plus separating space (8) and the LF (1).
        available = width - 8 - 1
        if len(text) > available:
            # Only mark truncation when something was actually cut; clamp so a
            # very narrow width never produces a negative slice bound.
            text = text[: max(available - 3, 0)] + "..."
    score = "-" if doc.score is None else f"{doc.score:.5f}"
    print(score, text)
if __name__ == "__main__":
    import shutil

    # shutil falls back to a default size when stdout is not a terminal
    # (e.g. piped output), where os.get_terminal_size() raises OSError.
    terminal_width = shutil.get_terminal_size().columns

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--create_store", help="Create a new data store.", default=None
    )
    # NOTE(review): --dataset is parsed but not used anywhere in this script.
    parser.add_argument("-d", "--dataset", help="Dataset filename.", default=None)
    parser.add_argument(
        "--docs_dir",
        help="Directory with PDF/DOCX files to index when creating a store.",
        default="../Gradio/docs",
    )
    parser.add_argument("-r", "--read_store", help="Read a data store.", default=None)
    parser.add_argument(
        "-s",
        "--scale",
        action="store_false",
        help="Do not scale retrieved scores.",
        default=True,
    )
    parser.add_argument("--top_k", type=int, help="Retriever top_k.", default=8)
    parser.add_argument("-q", "--query", help="Query DBs.", default=None)
    args = parser.parse_args()
    query = args.query

    if args.create_store:
        docs = build_store_from_dir(args.docs_dir)
        rs_doc_store = InMemoryDocumentStore()
        print("Starting create_index_split()")
        create_index_split(docs, rs_doc_store)
        rs_doc_store.save_to_disk(args.create_store)
        print("Ready create_index_split()")

    if not query:
        # Creating a store without a query is a complete run; anything else
        # needs a query, so report it instead of failing inside the pipeline.
        if args.create_store:
            sys.exit(0)
        parser.error("a query (-q/--query) is required unless creating a store (-c)")

    # Without -r, query the store just created, or fall back to the default.
    if not args.read_store:
        args.read_store = args.create_store or "research_docs_ns.store"

    print(f"Loading document store {args.read_store}...")
    # load_from_disk is a classmethod; no throwaway instance is needed.
    doc_store = InMemoryDocumentStore.load_from_disk(args.read_store)
    print(f"Number of documents: {doc_store.count_documents()}.")

    # Docs are already indexed/embedded in the store, so only the query-side
    # pipelines are built here. Each entry: header, pipeline factory, runner.
    runs = (
        ("Hybrid", create_hybrid_retriever, retrieve),
        ("Embedding", create_embedding_retriever, retrieve_embedded),
        ("bm25", create_bm25_retriever, retrieve_bm25),
    )
    for title, build_pipeline, run_pipeline in runs:
        pipeline = build_pipeline(doc_store)
        documents = run_pipeline(pipeline, query, top_k=args.top_k, scale=args.scale)
        print("=" * 80)
        print(f"== {title}")
        print("=" * 80)
        for doc in documents:
            print_res(doc, terminal_width)