feat Milvus integration

This commit is contained in:
Alex
2024-09-05 23:41:51 +01:00
parent 6c583eedb9
commit d232229abf
9 changed files with 56 additions and 56 deletions

View File

@@ -8,13 +8,15 @@ EbookLib==0.18
elasticsearch==8.12.0
escodegen==1.0.11
esprima==4.0.1
faiss-cpu==1.7.4
faiss-cpu==1.8.0.post1
Flask==3.0.1
gunicorn==22.0.0
html2text==2020.1.16
javalang==0.13.0
langchain==0.1.4
langchain-openai==0.0.5
langchain==0.2.16
langchain-community==0.2.16
langchain-core==0.2.38
langchain-openai==0.1.23
openapi3_parser==1.1.16
pandas==2.2.0
pydantic_settings==2.1.0
@@ -26,7 +28,7 @@ redis==5.0.1
Requests==2.32.0
retry==0.9.2
sentence-transformers
tiktoken
tiktoken==0.7.0
torch
tqdm==4.66.3
transformers==4.36.2

View File

@@ -1,13 +1,30 @@
from abc import ABC, abstractmethod
import os
from langchain_community.embeddings import (
HuggingFaceEmbeddings,
CohereEmbeddings,
HuggingFaceInstructEmbeddings,
)
from sentence_transformers import SentenceTransformer
from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
return self.model.encode(query).tolist()
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
elif isinstance(text, list):
return self.embed_documents(text)
else:
raise ValueError("Input must be a string or a list of strings")
class EmbeddingsSingleton:
_instances = {}
@@ -23,16 +40,15 @@ class EmbeddingsSingleton:
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
"cohere_medium": CohereEmbeddings
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
}
if embeddings_name not in embeddings_factory:
raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
return embeddings_factory[embeddings_name](*args, **kwargs)
if embeddings_name in embeddings_factory:
return embeddings_factory[embeddings_name](*args, **kwargs)
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
@@ -58,22 +74,14 @@ class BaseVectorStore(ABC):
embeddings_name,
openai_api_key=embeddings_key
)
elif embeddings_name == "cohere_medium":
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
cohere_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./model/all-mpnet-base-v2"):
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_name="./model/all-mpnet-base-v2",
model_kwargs={"device": "cpu"}
embeddings_name="./model/all-mpnet-base-v2",
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_kwargs={"device": "cpu"}
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)

View File

@@ -14,7 +14,8 @@ class FaissStore(BaseVectorStore):
)
else:
self.docsearch = FAISS.load_local(
self.path, embeddings
self.path, embeddings,
allow_dangerous_deserialization=True
)
self.assert_embedding_dimensions(embeddings)
@@ -37,10 +38,10 @@ class FaissStore(BaseVectorStore):
"""
if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
try:
word_embedding_dimension = embeddings.client[1].word_embedding_dimension
word_embedding_dimension = embeddings.dimension
except AttributeError as e:
raise AttributeError("word_embedding_dimension not found in embeddings.client[1]") from e
raise AttributeError("'dimension' attribute not found in embeddings instance. Make sure the embeddings object is properly initialized.") from e
docsearch_index_dimension = self.docsearch.index.d
if word_embedding_dimension != docsearch_index_dimension:
raise ValueError(f"word_embedding_dimension ({word_embedding_dimension}) " +
f"!= docsearch_index_word_embedding_dimension ({docsearch_index_dimension})")
raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " +
f"!= docsearch index dimension ({docsearch_index_dimension})")

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from langchain_community.vectorstores.milvus import Milvus
from uuid import uuid4
from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
@@ -8,28 +9,26 @@ from application.vectorstore.base import BaseVectorStore
class MilvusStore(BaseVectorStore):
def __init__(self, path: str = "", embeddings_key: str = "embeddings"):
super().__init__()
if path:
connection_args ={
"uri": path,
"tpken": settings.MILVUS_TOKEN,
}
else:
connection_args = {
"uri": settings.MILVUS_URL,
'token': settings.MILVUS_TOKEN,
}
from langchain_milvus import Milvus
connection_args = {
"uri": settings.MILVUS_URI,
"token": settings.MILVUS_TOKEN,
}
self._docsearch = Milvus(
embedding_function=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key),
collection_name=settings.COLLECTION_NAME,
collection_name=settings.MILVUS_COLLECTION_NAME,
connection_args=connection_args,
drop_old=True,
)
self._path = path
def search(self, question, k=2, *args, **kwargs):
return self._docsearch.similarity_search(query=question, k=k, *args, **kwargs)
return self._docsearch.similarity_search(query=question, k=k, filter={"path": self._path} *args, **kwargs)
def add_texts(self, texts: List[str], metadatas: Optional[List[dict]], *args, **kwargs):
return self._docsearch.add_texts(texts=texts, metadatas=metadatas, *args, **kwargs)
ids = [str(uuid4()) for _ in range(len(texts))]
return self._docsearch.add_texts(texts=texts, metadatas=metadatas, ids=ids, *args, **kwargs)
def save_local(self, *args, **kwargs):
pass