alx-d committed
Commit ec42367 · verified · 1 Parent(s): 96627ba

Update advanced_rag.py

Files changed (1)
  1. advanced_rag.py +76 -54
advanced_rag.py CHANGED
@@ -2,7 +2,11 @@ import os
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 from typing import List
 
-from langchain_community.llms import Replicate  # importing from langchain deprecated; use langchain_community for several modules here
+# Imports for our DeepSeek model pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from langchain_community.llms import HuggingFacePipeline  # community import; the plain langchain.llms path is deprecated
+
+# Other LangChain and community imports
 from langchain_community.document_loaders import OnlinePDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
@@ -15,110 +19,128 @@ from langchain.prompts import ChatPromptTemplate
 from langchain.schema import StrOutputParser
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 
+def create_deepseek_pipeline() -> HuggingFacePipeline:
+    """
+    Create a HuggingFace pipeline using the DeepSeek-R1 model and wrap it as a LangChain LLM.
+    """
+    # Load the DeepSeek model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(
+        "deepseek-ai/DeepSeek-R1",
+        trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
+
+    # Create a text-generation pipeline.
+    # Parameters such as max_length, temperature, and top_p can be adjusted as needed.
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        trust_remote_code=True,
+        max_length=2048,
+        do_sample=True,
+        temperature=0.5,
+        top_p=1
+    )
+
+    # Wrap the pipeline with HuggingFacePipeline for LangChain compatibility
+    return HuggingFacePipeline(pipeline=pipe)
 
 class ElevatedRagChain:
     '''
-    Class ElevatedRagChain integrates various components from the langchain library to build
-    an advanced retrieval-augmented generation (RAG) system designed to process documents
-    by reading in, chunking, embedding, and adding their chunk embeddings to FAISS vector store
-    for efficient retrieval. It uses the embeddings to retrieve relevant document chunks
-    in response to user queries.
-    The chunks are retrieved using an ensemble retriever (BM25 retriever + FAISS retriever)
-    and passed through a Cohere reranker before being used as context
-    for generating answers using a Llama 2 large language model (LLM).
+    ElevatedRagChain integrates various components from LangChain to build an advanced
+    retrieval-augmented generation (RAG) system. It processes PDF documents by loading,
+    chunking, embedding, and adding their embeddings to a FAISS vector store for efficient
+    retrieval. It then uses an ensemble retriever (BM25 + FAISS) with a reranker and uses a
+    DeepSeek model (via a Transformers pipeline) for generating detailed technical answers.
     '''
     def __init__(self) -> None:
         '''
-        Initialize the class with predefined model, embedding function, weights, and top_k value
+        Initialize the class with predefined embedding function, weights, and top_k value.
         '''
-        self.llama2_70b = 'meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48'
         self.embed_func = CohereEmbeddings(model="embed-english-light-v3.0")
         self.bm25_weight = 0.6
         self.faiss_weight = 0.4
         self.top_k = 5
 
-
     def add_pdfs_to_vectore_store(
         self,
         pdf_links: List,
-        chunk_size: int=1500,
-    ) -> None:
+        chunk_size: int = 1500,
+    ) -> None:
         '''
         Processes PDF documents by loading, chunking, embedding, and adding them to a FAISS vector store.
-        Build an advanced RAG system
+
         Args:
-            pdf_links (List): list of URLs pointing to the PDF documents to be processed
-            chunk_size (int, optional): size of text chunks to split the documents into, defaults to 1500
+            pdf_links (List): list of URLs pointing to the PDF documents to be processed.
+            chunk_size (int, optional): size of text chunks to split the documents into, defaults to 1500.
         '''
-        # load pdfs
-        self.raw_data = [ OnlinePDFLoader(doc).load()[0] for doc in pdf_links ]
+        # Load PDFs
+        self.raw_data = [OnlinePDFLoader(doc).load()[0] for doc in pdf_links]
 
-        # chunk text
+        # Chunk text
         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
         self.split_data = self.text_splitter.split_documents(self.raw_data)
 
-        # add chunks to BM25 retriever
-        self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
+        # Create BM25 retriever from the split documents
+        self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
         self.bm25_retriever.k = self.top_k
 
-        # embed and add chunks to vectore store
-        self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
-        self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
-        print("All PDFs processed and added to vectore store.")
+        # Embed and add chunks to FAISS vector store
+        self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
+        self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
+        print("All PDFs processed and added to vector store.")
 
-        # build advanced RAG system
+        # Build the advanced RAG system
         self.build_elevated_rag_system()
         print("RAG system is built successfully.")
 
-
     def build_elevated_rag_system(self) -> None:
         '''
-        Build an advanced RAG system from different components:
-            * BM25 retriever
-            * FAISS vector store retriever
-            * Llama 2 model
+        Build an advanced RAG system by combining:
+            - BM25 retriever
+            - FAISS vector store retriever
+            - A DeepSeek model (via a HuggingFace pipeline)
         '''
-        # combine BM25 and FAISS retrievers into an ensemble retriever
+        # Combine BM25 and FAISS retrievers into an ensemble retriever
         self.ensemble_retriever = EnsembleRetriever(
             retrievers=[self.bm25_retriever, self.faiss_retriever],
             weights=[self.bm25_weight, self.faiss_weight]
         )
 
-        # use reranker to improve retrieval quality
+        # Use a Cohere reranker to improve retrieval quality
         self.reranker = CohereRerank(top_n=5)
-        self.rerank_retriever = ContextualCompressionRetriever(  # combine ensemble retriever and reranker
+        self.rerank_retriever = ContextualCompressionRetriever(
             base_retriever=self.ensemble_retriever,
             base_compressor=self.reranker,
         )
 
-        # define prompt template for the language model
+        # Define the prompt template for the language model
         RAG_PROMPT_TEMPLATE = """\
-        Use the following context to provide a detailed technical answer the user's question.
-        Do not use an introduction similar to "Based on the provided documents, ...", just answer the question.
-        If you don't know the answer, please respond with "I don't know".
+        Use the following context to provide a detailed technical answer to the user's question.
+        Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
+        If you don't know the answer, please respond with "I don't know".
 
-        Context:
-        {context}
+        Context:
+        {context}
 
-        User's question:
-        {question}
-        """
+        User's question:
+        {question}
+        """
         self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
         self.str_output_parser = StrOutputParser()
 
-        # parallel execution of context retrieval and question passing
+        # Prepare parallel execution of context retrieval and question processing
         self.entry_point_and_elevated_retriever = RunnableParallel(
             {
-                "context" : self.rerank_retriever,
-                "question" : RunnablePassthrough()
+                "context": self.rerank_retriever,
+                "question": RunnablePassthrough()
             }
         )
 
-        # initialize Llama 2 model with specific parameters
-        self.llm = Replicate(
-            model=self.llama2_70b,
-            model_kwargs={"temperature": 0.5, "top_p": 1, "max_new_tokens": 1000}
-        )
+        # Initialize the DeepSeek model using a HuggingFace pipeline as our LLM
+        self.llm = create_deepseek_pipeline()
 
-        # chain components to form final elevated RAG system using LangChain Expression Language (LCEL)
-        self.elevated_rag_chain = self.entry_point_and_elevated_retriever | self.rag_prompt | self.llm  #| self.str_output_parser
+        # Chain the components to form the final elevated RAG system.
+        # Note: Depending on your needs, you may wish to add self.str_output_parser at the end.
+        self.elevated_rag_chain = self.entry_point_and_elevated_retriever | self.rag_prompt | self.llm
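For quick verification, the new DeepSeek wrapper can be exercised on its own: HuggingFacePipeline implements LangChain's Runnable interface, so it responds to .invoke() like any other LLM. A minimal sketch, assuming this file is importable as advanced_rag and that the machine has enough memory to load deepseek-ai/DeepSeek-R1 (a very large model); the prompt string is illustrative:

# Smoke test for the DeepSeek wrapper alone (module name assumed).
from advanced_rag import create_deepseek_pipeline

llm = create_deepseek_pipeline()
# .invoke() takes a prompt string and returns the generated text.
print(llm.invoke("Summarize retrieval-augmented generation in two sentences."))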
 
 
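An end-to-end sketch of the updated class follows, under the same assumptions plus a COHERE_API_KEY in the environment for CohereEmbeddings and CohereRerank; the PDF URL is a placeholder, not taken from the commit:

# End-to-end sketch (placeholder URL; COHERE_API_KEY required).
from advanced_rag import ElevatedRagChain

rag = ElevatedRagChain()
rag.add_pdfs_to_vectore_store(["https://example.com/some_paper.pdf"])

# The chain currently ends at the LLM, so invoke() returns the raw generated
# string; appending self.str_output_parser (as the in-code note suggests)
# would be effectively a pass-through for plain-string LLM output.
answer = rag.elevated_rag_chain.invoke("What method does the paper propose?")
print(answer)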