diff --git a/application/core/settings.py b/application/core/settings.py
index 7346da08..d4b02481 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -18,7 +18,7 @@ class Settings(BaseSettings):
     DEFAULT_MAX_HISTORY: int = 150
     MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
     UPLOAD_FOLDER: str = "inputs"
-    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus"
+    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
     RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
 
     # LLM Cache
@@ -70,6 +70,9 @@ class Settings(BaseSettings):
     MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
     MILVUS_TOKEN: Optional[str] = ""
 
+    # LanceDB vectorstore config
+    LANCEDB_PATH: str = "/tmp/lancedb"  # Path where LanceDB stores its local data
+    LANCEDB_TABLE_NAME: Optional[str] = "docsgpts"  # Name of the table to use for storing vectors
     BRAVE_SEARCH_API_KEY: Optional[str] = None
 
     FLASK_DEBUG_MODE: bool = False
diff --git a/application/vectorstore/lancedb.py b/application/vectorstore/lancedb.py
new file mode 100644
index 00000000..25d62318
--- /dev/null
+++ b/application/vectorstore/lancedb.py
@@ -0,0 +1,132 @@
+from typing import List, Optional
+import importlib
+from application.vectorstore.base import BaseVectorStore
+from application.core.settings import settings
+
+
+class LanceDBVectorStore(BaseVectorStore):
+    """LanceDB-backed vector store.
+
+    Rows are stored as {vector, text, metadata}, where metadata is a
+    list of {key, value} string pairs so arbitrary metadata fits the
+    fixed Arrow schema.
+    """
+
+    def __init__(self, path: str = settings.LANCEDB_PATH,
+                 table_name_prefix: str = settings.LANCEDB_TABLE_NAME,
+                 source_id: str = None,
+                 embeddings_key: str = "embeddings"):
+        """Initialize the LanceDB vector store."""
+        super().__init__()
+        self.path = path
+        self.table_name = f"{table_name_prefix}_{source_id}" if source_id else table_name_prefix
+        self.embeddings_key = embeddings_key
+        self._lance_db = None
+        self.docsearch = None
+        self._pa = None  # PyArrow (pa) is lazy loaded
+
+    @property
+    def pa(self):
+        """Lazy load pyarrow module."""
+        if self._pa is None:
+            self._pa = importlib.import_module("pyarrow")
+        return self._pa
+
+    @property
+    def lancedb(self):
+        """Lazy load lancedb module."""
+        if not hasattr(self, "_lancedb_module"):
+            self._lancedb_module = importlib.import_module("lancedb")
+        return self._lancedb_module
+
+    @property
+    def lance_db(self):
+        """Lazy load the LanceDB connection."""
+        if self._lance_db is None:
+            self._lance_db = self.lancedb.connect(self.path)
+        return self._lance_db
+
+    @property
+    def table(self):
+        """Lazy open the LanceDB table; None if it does not exist yet."""
+        if self.docsearch is None and self.table_name in self.lance_db.table_names():
+            self.docsearch = self.lance_db.open_table(self.table_name)
+        return self.docsearch
+
+    def ensure_table_exists(self):
+        """Create the table with the expected schema if it is missing."""
+        if self.table is None:
+            embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
+            # "metadata" is a LIST of {key, value} structs -- add_texts
+            # writes one struct per metadata entry.
+            schema = self.pa.schema([
+                self.pa.field("vector", self.pa.list_(self.pa.float32(), list_size=embeddings.dimension)),
+                self.pa.field("text", self.pa.string()),
+                self.pa.field("metadata", self.pa.list_(self.pa.struct([
+                    self.pa.field("key", self.pa.string()),
+                    self.pa.field("value", self.pa.string())
+                ])))
+            ])
+            self.docsearch = self.lance_db.create_table(self.table_name, schema=schema)
+
+    def add_texts(self, texts: List[str], metadatas: Optional[List[dict]] = None, source_id: str = None):
+        """Embed texts and append them, with metadata, to the table."""
+        embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key).embed_documents(texts)
+        if not metadatas:
+            metadatas = [{} for _ in texts]  # one dict per row, never shared
+        rows = []
+        for embedding, text, metadata in zip(embeddings, texts, metadatas):
+            metadata = dict(metadata)  # do not mutate the caller's dict
+            if source_id:
+                metadata["source_id"] = source_id
+            metadata_struct = [{"key": k, "value": str(v)} for k, v in metadata.items()]
+            rows.append({
+                "vector": embedding,
+                "text": text,
+                "metadata": metadata_struct
+            })
+        self.ensure_table_exists()
+        self.docsearch.add(rows)
+
+    def search(self, query: str, k: int = 2, *args, **kwargs):
+        """Return the top-k (distance, text, metadata) tuples for a query."""
+        self.ensure_table_exists()
+        query_embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key).embed_query(query)
+        results = self.docsearch.search(query_embedding).limit(k).to_list()
+        return [(result["_distance"], result["text"], result["metadata"]) for result in results]
+
+    def delete_index(self):
+        """Drop the LanceDB table backing this store."""
+        if self.table is not None:
+            self.lance_db.drop_table(self.table_name)
+            self.docsearch = None  # force re-open/create on next use
+
+    def assert_embedding_dimensions(self, embeddings):
+        """Raise if the embedding dimension differs from the table's vector size."""
+        word_embedding_dimension = embeddings.dimension
+        if self.table is not None:
+            # fixed-size list length == stored vector dimension
+            table_index_dimension = self.docsearch.schema.field("vector").type.list_size
+            if word_embedding_dimension != table_index_dimension:
+                raise ValueError(
+                    f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) "
+                    f"!= table index dimension ({table_index_dimension})"
+                )
+
+    def filter_documents(self, filter_condition: dict) -> List[dict]:
+        """Return all rows whose metadata contains the given source_id."""
+        self.ensure_table_exists()
+
+        if 'source_id' not in filter_condition:
+            raise ValueError("filter_condition must contain 'source_id'")
+
+        source_id = filter_condition["source_id"]
+
+        # "metadata" is a list of {key, value} pairs; match any pair
+        # with key == "source_id" and the requested value.
+        rows = self.docsearch.to_arrow().to_pylist()
+        return [
+            row for row in rows
+            if any(m.get("key") == "source_id" and m.get("value") == str(source_id)
+                   for m in (row.get("metadata") or []))
+        ]