import os
from pathlib import Path
from typing import Any

import pandas as pd
from fastembed import SparseTextEmbedding, SparseEmbedding
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
# Configuration pulled from environment variables.
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
HF_TOKEN = os.getenv("HF_TOKEN")
RECREATE_DB = os.getenv("RECREATE_DB", "False").lower() == "true"
DATA_REPO = os.getenv("DATA_REPO")
DATA_FILENAME = os.getenv("DATA_FILENAME")
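# Example environment (the variable names come from this script; the values
# below are purely illustrative placeholders):
#   VLLM_DTYPE=float16
#   DATA_PATH=/data
#   RECREATE_DB=true
#   DATA_REPO=<your-hf-dataset-repo>
#   DATA_FILENAME=<knowledge_cards>.parquet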
# Embedded (local, file-based) Qdrant instance persisted under DB_PATH.
client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"

dense_model_dims = 1024
dense_batch_size = 128
sparse_batch_size = 256

# Dense encoder (1024-dim) plus a BM25 sparse encoder for hybrid retrieval.
dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
# Utils
def convert_serialized_sparse_embeddings(sparse_dict: dict[str, float]) -> SparseEmbedding:
    """Rebuild a SparseEmbedding from a serialized dict whose keys were
    stringified for PyArrow compatibility, converting them back to ints."""
    return SparseEmbedding.from_dict({int(k): v for k, v in sparse_dict.items()})
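# Hypothetical usage of the helper above (e.g. if sparse vectors were
# precomputed and serialized into the parquet rows; the column name
# "sparse_embedding" is illustrative and not part of this script):
#   sparse = convert_serialized_sparse_embeddings(row["sparse_embedding"])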

def ingest_data(chunks: list[dict[str, Any]]):
    if client.collection_exists(collection_name) and RECREATE_DB:
        print("Recreating collection.", flush=True)
        client.delete_collection(collection_name)
    elif client.collection_exists(collection_name):
        print("Collection already exists, skipping ingestion.", flush=True)
        return

    print("Ingesting knowledge cards...", flush=True)
    # Named dense + sparse vectors on the same collection enable hybrid search;
    # the IDF modifier lets Qdrant apply IDF weighting to the BM25 sparse
    # vectors on the server side.
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "dense": qmodels.VectorParams(
                size=dense_model_dims,
                distance=qmodels.Distance.COSINE,
            )
        },
        sparse_vectors_config={
            "sparse": qmodels.SparseVectorParams(modifier=qmodels.Modifier.IDF)
        },
    )

    # Generate embeddings
    chunk_texts = [chunk["text"] for chunk in chunks]
    dense_vectors = list(
        dense_encoder.encode(
            chunk_texts,
            batch_size=dense_batch_size,
            normalize_embeddings=True,
        )
    )
    sparse_vectors = list(
        sparse_encoder.embed(chunk_texts, batch_size=sparse_batch_size)
    )

    # Upload to db
    client.upload_points(
        collection_name=collection_name,
        points=[
            qmodels.PointStruct(
                id=idx,
                payload=chunk,
                vector={"dense": dense, "sparse": sparse.as_object()},
            )
            for idx, (chunk, dense, sparse) in enumerate(
                zip(chunks, dense_vectors, sparse_vectors)
            )
        ],
    )

def ingest():
    # Pull the knowledge card parquet file from the Hugging Face dataset repo.
    downloaded_path = hf_hub_download(
        repo_id=DATA_REPO, filename=DATA_FILENAME, token=HF_TOKEN, repo_type="dataset"
    )
    print(f"Downloaded knowledge card dataset; path = {downloaded_path}", flush=True)
    chunk_df = pd.read_parquet(downloaded_path)
    chunks = chunk_df.to_dict(orient="records")
    ingest_data(chunks=chunks)
    print("Ingestion is finished.", flush=True)
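
# Illustrative only: a minimal sketch of how the hybrid collection built above
# could be queried (dense + BM25 candidates fused with reciprocal rank fusion).
# This function is not called anywhere in this script, its name and parameters
# are our own, and it assumes a recent qdrant-client with query_points support.
def hybrid_search(query: str, limit: int = 5) -> list[qmodels.ScoredPoint]:
    # Encode the query with the same dense and sparse models used at ingestion.
    dense_query = dense_encoder.encode(query, normalize_embeddings=True).tolist()
    sparse_query = next(iter(sparse_encoder.embed([query])))
    response = client.query_points(
        collection_name=collection_name,
        prefetch=[
            # Candidate set from the dense vector space.
            qmodels.Prefetch(query=dense_query, using="dense", limit=limit * 4),
            # Candidate set from the BM25 sparse vector space.
            qmodels.Prefetch(
                query=qmodels.SparseVector(
                    indices=sparse_query.indices.tolist(),
                    values=sparse_query.values.tolist(),
                ),
                using="sparse",
                limit=limit * 4,
            ),
        ],
        # Fuse both candidate lists with reciprocal rank fusion.
        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
        limit=limit,
    )
    return response.points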

if __name__ == "__main__":
    ingest()