From a9f6a06446be28408de07146cc27c0774448ac44 Mon Sep 17 00:00:00 2001 From: akashmangoai Date: Fri, 11 Oct 2024 22:58:18 +0530 Subject: [PATCH] lazy import & fixed other issue --- application/vectorstore/lancedb.py | 54 ++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/application/vectorstore/lancedb.py b/application/vectorstore/lancedb.py index caec57e9..25d62318 100644 --- a/application/vectorstore/lancedb.py +++ b/application/vectorstore/lancedb.py @@ -1,6 +1,5 @@ from typing import List, Optional -import pyarrow as pa -import lancedb +import importlib from application.vectorstore.base import BaseVectorStore from application.core.settings import settings @@ -8,21 +7,37 @@ class LanceDBVectorStore(BaseVectorStore): """Class for LanceDB Vector Store integration.""" def __init__(self, path: str = settings.LANCEDB_PATH, - table_name: str = settings.LANCEDB_TABLE_NAME, + table_name_prefix: str = settings.LANCEDB_TABLE_NAME, + source_id: str = None, embeddings_key: str = "embeddings"): """Initialize the LanceDB vector store.""" super().__init__() self.path = path - self.table_name = table_name + self.table_name = f"{table_name_prefix}_{source_id}" if source_id else table_name_prefix self.embeddings_key = embeddings_key - self._lance_db = None # Updated to snake_case + self._lance_db = None self.docsearch = None + self._pa = None # PyArrow (pa) will be lazy loaded + + @property + def pa(self): + """Lazy load pyarrow module.""" + if self._pa is None: + self._pa = importlib.import_module("pyarrow") + return self._pa + + @property + def lancedb(self): + """Lazy load lancedb module.""" + if not hasattr(self, "_lancedb_module"): + self._lancedb_module = importlib.import_module("lancedb") + return self._lancedb_module @property def lance_db(self): """Lazy load the LanceDB connection.""" if self._lance_db is None: - self._lance_db = lancedb.connect(self.path) + self._lance_db = self.lancedb.connect(self.path) return self._lance_db @property @@ -39,21 +54,23 @@ class LanceDBVectorStore(BaseVectorStore): """Ensure the table exists before performing operations.""" if self.table is None: embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key) - schema = pa.schema([ - pa.field("vector", pa.list_(pa.float32(), list_size=embeddings.dimension)), - pa.field("text", pa.string()), - pa.field("metadata", pa.struct([ - pa.field("key", pa.string()), - pa.field("value", pa.string()) + schema = self.pa.schema([ + self.pa.field("vector", self.pa.list_(self.pa.float32(), list_size=embeddings.dimension)), + self.pa.field("text", self.pa.string()), + self.pa.field("metadata", self.pa.struct([ + self.pa.field("key", self.pa.string()), + self.pa.field("value", self.pa.string()) ])) ]) self.docsearch = self.lance_db.create_table(self.table_name, schema=schema) - def add_texts(self, texts: List[str], metadatas: Optional[List[dict]] = None): + def add_texts(self, texts: List[str], metadatas: Optional[List[dict]] = None, source_id: str = None): """Add texts with metadata and their embeddings to the LanceDB table.""" embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key).embed_documents(texts) vectors = [] for embedding, text, metadata in zip(embeddings, texts, metadatas or [{}] * len(texts)): + if source_id: + metadata["source_id"] = source_id metadata_struct = [{"key": k, "value": str(v)} for k, v in metadata.items()] vectors.append({ "vector": embedding, @@ -89,5 +106,14 @@ class LanceDBVectorStore(BaseVectorStore): def filter_documents(self, filter_condition: dict) -> List[dict]: """Filter documents based on certain conditions.""" self.ensure_table_exists() - filtered_data = self.docsearch.filter(filter_condition).to_list() + + # Ensure source_id exists in the filter condition + if 'source_id' not in filter_condition: + raise ValueError("filter_condition must contain 'source_id'") + + source_id = filter_condition["source_id"] + + # Use LanceDB's native filtering if supported, otherwise filter manually + filtered_data = self.docsearch.filter(lambda x: x.metadata and x.metadata.get("source_id") == source_id).to_list() + return filtered_data \ No newline at end of file