AI_Doctors / backend /rag_engine.py
NandanData's picture
Upload 853 files
6331e04 verified
raw
history blame
1.58 kB
import os,glob,chromadb
from sentence_transformers import SentenceTransformer,models
from langchain_text_splitters import RecursiveCharacterTextSplitter
from utils.constants import CHROMA_DIR,DOCS_DIR,COLLECTION,EMB_MODEL_NAME
def get_embedder():
    """Build the sentence embedder from the configured base model.

    Wraps the raw transformer with a mean-pooling layer so per-token
    embeddings are reduced to a single fixed-size sentence vector.
    """
    transformer = models.Transformer(EMB_MODEL_NAME)
    pooling = models.Pooling(transformer.get_word_embedding_dimension())
    return SentenceTransformer(modules=[transformer, pooling])
def get_chroma():
    """Return ``(client, collection)`` for the persistent Chroma store.

    The collection is created on first use and configured for cosine
    similarity via the ``hnsw:space`` metadata key.
    """
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    collection = client.get_or_create_collection(
        COLLECTION, metadata={"hnsw:space": "cosine"}
    )
    return client, collection
def embed(m,txts):return m.encode(txts,convert_to_numpy=True).tolist()
def seed_index(col, m, folder):
    """Chunk every ``.txt`` file in *folder*, embed the chunks with *m*,
    and upsert them into the Chroma collection *col*.

    Chunk ids are ``"<title>-<i>"`` where the title is the file name
    without its extension. Returns the number of chunks indexed
    (0 when the folder contains no text to index).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    ids, docs, meta = [], [], []
    for path in glob.glob(os.path.join(folder, '*.txt')):
        # splitext only strips the final extension; a bare .replace('.txt','')
        # would also mangle names containing ".txt" in the middle.
        title = os.path.splitext(os.path.basename(path))[0]
        with open(path, encoding='utf-8') as f:
            text = f.read()
        for i, chunk in enumerate(splitter.split_text(text)):
            ids.append(f"{title}-{i}")
            docs.append(chunk)
            meta.append({"title": title, "source": path})
    if not docs:
        # Nothing to index — avoid calling the embedder/collection with
        # an empty batch.
        return 0
    embeddings = embed(m, docs)
    try:
        col.add(ids=ids, documents=docs, metadatas=meta, embeddings=embeddings)
    except Exception:
        # IDs already exist from a previous seeding run: drop the stale
        # entries and re-add. (Bare `except:` would also trap
        # KeyboardInterrupt/SystemExit.)
        col.delete(ids=ids)
        col.add(ids=ids, documents=docs, metadatas=meta, embeddings=embeddings)
    return len(docs)
def retrieve(col, m, q, k):
    """Embed query *q* and return the top-*k* matching chunks from *col*.

    Each hit is a dict with keys ``"text"``, ``"title"`` and ``"source"``
    taken from the stored documents/metadata; an empty list is returned
    when the query produced no result ids.
    """
    query_vec = embed(m, [q])[0]
    res = col.query(
        query_embeddings=[query_vec],
        n_results=k,
        include=["documents", "metadatas"],
    )
    hits = []
    if res.get("ids"):
        for doc, md in zip(res["documents"][0], res["metadatas"][0]):
            hits.append({"text": doc, "title": md["title"], "source": md["source"]})
    return hits