alx-d committed
Commit 96627ba · verified · 1 Parent(s): 40767bb

Create advanced_rag.py

Files changed (1)
advanced_rag.py +124 -0
advanced_rag.py ADDED
@@ -0,0 +1,124 @@
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ from typing import List
+
+ from langchain_community.llms import Replicate  # importing these from langchain is deprecated; use langchain_community for several modules here
+ from langchain_community.document_loaders import OnlinePDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import CohereEmbeddings
+ from langchain_community.retrievers import BM25Retriever
+ from langchain.retrievers import EnsembleRetriever
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import CohereRerank
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema import StrOutputParser
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+
+
+ class ElevatedRagChain:
+     '''
+     ElevatedRagChain integrates components from the LangChain library into an advanced
+     retrieval-augmented generation (RAG) system. It processes documents by loading,
+     chunking, and embedding them, and adds the chunk embeddings to a FAISS vector store
+     for efficient retrieval. The embeddings are then used to retrieve relevant document
+     chunks in response to user queries.
+     The chunks are retrieved by an ensemble retriever (BM25 retriever + FAISS retriever)
+     and passed through a Cohere reranker before being used as context
+     for generating answers with a Llama 2 large language model (LLM).
+     '''
+     def __init__(self) -> None:
+         '''
+         Initialize the class with a predefined model, embedding function, retriever weights, and top_k value.
+         '''
+         self.llama2_70b = 'meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48'
+         self.embed_func = CohereEmbeddings(model="embed-english-light-v3.0")
+         self.bm25_weight = 0.6
+         self.faiss_weight = 0.4
+         self.top_k = 5
+
+
+     def add_pdfs_to_vectore_store(
+         self,
+         pdf_links: List[str],
+         chunk_size: int = 1500,
+     ) -> None:
+         '''
+         Process PDF documents by loading, chunking, embedding, and adding them to a FAISS
+         vector store, then build the advanced RAG system on top of the indexed chunks.
+         Args:
+             pdf_links (List[str]): list of URLs pointing to the PDF documents to be processed
+             chunk_size (int, optional): size of text chunks to split the documents into, defaults to 1500
+         '''
+         # load pdfs
+         self.raw_data = [OnlinePDFLoader(doc).load()[0] for doc in pdf_links]
+
+         # chunk text
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
+         self.split_data = self.text_splitter.split_documents(self.raw_data)
+
+         # add chunks to the BM25 retriever (sparse, keyword-based retrieval)
+         self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
+         self.bm25_retriever.k = self.top_k
+
+         # embed and add chunks to the FAISS vector store (dense, embedding-based retrieval)
+         self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
+         self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
+         print("All PDFs processed and added to the vector store.")
+
+         # build the advanced RAG system
+         self.build_elevated_rag_system()
+         print("RAG system built successfully.")
+
+
+     def build_elevated_rag_system(self) -> None:
+         '''
+         Build an advanced RAG system from different components:
+             * BM25 retriever
+             * FAISS vector store retriever
+             * Cohere reranker
+             * Llama 2 model
+         '''
+         # combine the BM25 and FAISS retrievers into an ensemble retriever
+         self.ensemble_retriever = EnsembleRetriever(
+             retrievers=[self.bm25_retriever, self.faiss_retriever],
+             weights=[self.bm25_weight, self.faiss_weight]
+         )
+
+         # use a reranker to improve retrieval quality
+         self.reranker = CohereRerank(top_n=5)
+         # combine the ensemble retriever and the reranker
+         self.rerank_retriever = ContextualCompressionRetriever(
+             base_retriever=self.ensemble_retriever,
+             base_compressor=self.reranker,
+         )
+
+         # define the prompt template for the language model
+         RAG_PROMPT_TEMPLATE = """\
+         Use the following context to provide a detailed technical answer to the user's question.
+         Do not use an introduction similar to "Based on the provided documents, ...", just answer the question.
+         If you don't know the answer, please respond with "I don't know".
+
+         Context:
+         {context}
+
+         User's question:
+         {question}
+         """
+         self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
+         self.str_output_parser = StrOutputParser()
+
+         # parallel execution of context retrieval and question passing
+         self.entry_point_and_elevated_retriever = RunnableParallel(
+             {
+                 "context": self.rerank_retriever,
+                 "question": RunnablePassthrough()
+             }
+         )
+
+         # initialize the Llama 2 model with specific parameters
+         self.llm = Replicate(
+             model=self.llama2_70b,
+             model_kwargs={"temperature": 0.5, "top_p": 1, "max_new_tokens": 1000}
+         )
+
+         # chain the components into the final elevated RAG system using LangChain Expression Language (LCEL)
+         self.elevated_rag_chain = self.entry_point_and_elevated_retriever | self.rag_prompt | self.llm  # | self.str_output_parser
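
A minimal usage sketch (not part of the commit): it assumes COHERE_API_KEY and REPLICATE_API_TOKEN are set in the environment (needed by the Cohere embeddings/reranker and the Replicate-hosted Llama 2 model, respectively), and the PDF URL is only an example.

# hypothetical usage sketch -- the env vars and example URL are assumptions, not part of the diff
rag = ElevatedRagChain()
rag.add_pdfs_to_vectore_store(["https://arxiv.org/pdf/2005.11401.pdf"])
answer = rag.elevated_rag_chain.invoke("What is retrieval-augmented generation?")
print(answer)  # the Replicate LLM returns the generated answer as a string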
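For intuition on what bm25_weight and faiss_weight control: EnsembleRetriever fuses the two ranked result lists with weighted Reciprocal Rank Fusion. Below is a simplified sketch of that fusion logic; the function name, toy chunk ids, and the conventional smoothing constant c=60 are illustrative assumptions, not LangChain's exact implementation.

# toy weighted Reciprocal Rank Fusion: each list contributes weight / (rank + c)
def weighted_rrf(ranked_lists, weights, c=60):
    scores = {}
    for docs, weight in zip(ranked_lists, weights):
        for rank, doc in enumerate(docs, start=1):
            scores[doc] = scores.get(doc, 0.0) + weight / (rank + c)
    # highest fused score first
    return sorted(scores, key=scores.get, reverse=True)

bm25_hits = ["chunk_a", "chunk_b", "chunk_c"]   # BM25 ranking
faiss_hits = ["chunk_b", "chunk_d", "chunk_a"]  # FAISS ranking
print(weighted_rrf([bm25_hits, faiss_hits], weights=[0.6, 0.4]))
# chunk_a and chunk_b appear in both lists, so they outrank the single-list hits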