mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
Merge pull request #822 from arc53/feat/mongodb
MongoDB integration as a vector store
This commit is contained in:
8
application/vectorstore/document_class.py
Normal file
8
application/vectorstore/document_class.py
Normal file
@@ -0,0 +1,8 @@
|
||||
class Document(str):
    """A piece of text plus its associated metadata.

    Behaves exactly like a plain ``str`` holding the page content while
    additionally exposing ``page_content`` and ``metadata`` attributes.
    """

    def __new__(cls, page_content: str, metadata: dict):
        # str is immutable, so all initialization happens in __new__.
        self = str.__new__(cls, page_content)
        self.page_content = page_content
        self.metadata = metadata
        return self
||||
@@ -1,16 +1,8 @@
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.document_class import Document
|
||||
import elasticsearch
|
||||
|
||||
class Document(str):
    """String subclass that stores a piece of text and associated metadata."""

    def __new__(cls, page_content: str, metadata: dict):
        # The base type is immutable, hence construction in __new__
        # rather than __init__.
        obj = super().__new__(cls, page_content)
        obj.page_content = page_content
        obj.metadata = metadata
        return obj
||||
|
||||
|
||||
|
||||
|
||||
|
||||
126
application/vectorstore/mongodb.py
Normal file
126
application/vectorstore/mongodb.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.document_class import Document
|
||||
|
||||
class MongoDBVectorStore(BaseVectorStore):
    """Vector store backed by MongoDB Atlas Vector Search.

    Each stored document carries the raw text (``text_key``), its embedding
    (``embedding_key``), and a ``store`` field identifying the source index
    path, which scopes searches and deletions to this store.
    """

    def __init__(
        self,
        path: str = "",
        embeddings_key: str = "embeddings",
        collection: str = "documents",
        index_name: str = "vector_search_index",
        text_key: str = "text",
        embedding_key: str = "embedding",
        database: str = "docsgpt",
    ):
        self._index_name = index_name
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._embeddings_key = embeddings_key
        self._mongo_uri = settings.MONGO_URI
        # Normalize the index path: drop the local prefix and any trailing
        # slash so it matches the "store" value saved with each document.
        self._path = path.replace("application/indexes/", "").rstrip("/")
        self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)

        try:
            import pymongo
        except ImportError:
            raise ImportError(
                "Could not import pymongo python package. "
                "Please install it with `pip install pymongo`."
            )

        self._client = pymongo.MongoClient(self._mongo_uri)
        self._database = self._client[database]
        self._collection = self._database[collection]

    def search(self, question, k=2, *args, **kwargs):
        """Return the ``k`` documents most similar to ``question``.

        Runs an Atlas ``$vectorSearch`` aggregation filtered to this store's
        path and wraps each hit in a ``Document`` whose metadata is the
        remaining (non-id, non-text, non-embedding) fields.
        """
        query_vector = self._embedding.embed_query(question)

        pipeline = [
            {
                "$vectorSearch": {
                    "queryVector": query_vector,
                    "path": self._embedding_key,
                    "limit": k,
                    "numCandidates": k * 10,
                    "index": self._index_name,
                    "filter": {
                        "store": {"$eq": self._path}
                    }
                }
            }
        ]

        cursor = self._collection.aggregate(pipeline)

        results = []
        for doc in cursor:
            text = doc[self._text_key]
            # Everything except the id, text and embedding is metadata.
            doc.pop("_id")
            doc.pop(self._text_key)
            doc.pop(self._embedding_key, None)  # tolerate projected-out field
            results.append(Document(text, doc))
        return results

    def _insert_texts(self, texts, metadatas):
        """Embed ``texts`` and bulk-insert them alongside their metadata.

        Returns the list of inserted MongoDB ids (empty for empty input).
        """
        if not texts:
            return []
        embeddings = self._embedding.embed_documents(texts)
        to_insert = [
            {self._text_key: t, self._embedding_key: embedding, **m}
            for t, m, embedding in zip(texts, metadatas, embeddings)
        ]
        # Insert the documents in MongoDB Atlas in a single round trip.
        insert_result = self._collection.insert_many(to_insert)
        return insert_result.inserted_ids

    def add_texts(
        self,
        texts,
        metadatas=None,
        ids=None,
        refresh_indices=True,
        create_index_if_not_exists=True,
        bulk_kwargs=None,
        **kwargs,
    ):
        """Embed and store ``texts`` in batches, returning the inserted ids.

        ``ids``, ``refresh_indices``, ``create_index_if_not_exists`` and
        ``bulk_kwargs`` are accepted for interface compatibility but are
        currently unused; the Atlas search index is assumed to already exist.
        """
        batch_size = 100
        # Lazily supply a fresh empty metadata dict per text when none given.
        _metadatas = metadatas or ({} for _ in texts)
        texts_batch = []
        metadatas_batch = []
        result_ids = []
        for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            if (i + 1) % batch_size == 0:
                result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
                texts_batch = []
                metadatas_batch = []
        # Flush the final partial batch, if any.
        if texts_batch:
            result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
        return result_ids

    def delete_index(self, *args, **kwargs):
        """Delete every document belonging to this store's path."""
        self._collection.delete_many({"store": self._path})
|
||||
@@ -1,11 +1,13 @@
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
from application.vectorstore.mongodb import MongoDBVectorStore
|
||||
|
||||
|
||||
class VectorCreator:
|
||||
vectorstores = {
|
||||
'faiss': FaissStore,
|
||||
'elasticsearch':ElasticsearchStore
|
||||
'elasticsearch':ElasticsearchStore,
|
||||
'mongodb': MongoDBVectorStore,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
class TestOpenAILLM(unittest.TestCase):
|
||||
@@ -10,23 +9,3 @@ class TestOpenAILLM(unittest.TestCase):
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.llm.api_key, self.api_key)
|
||||
|
||||
@patch('application.llm.openai.openai.ChatCompletion.create')
|
||||
def test_gen(self, mock_create):
|
||||
model = "test_model"
|
||||
engine = "test_engine"
|
||||
messages = ["test_message"]
|
||||
response = {"choices": [{"message": {"content": "test_response"}}]}
|
||||
mock_create.return_value = response
|
||||
result = self.llm.gen(model, engine, messages)
|
||||
self.assertEqual(result, "test_response")
|
||||
|
||||
@patch('application.llm.openai.openai.ChatCompletion.create')
|
||||
def test_gen_stream(self, mock_create):
|
||||
model = "test_model"
|
||||
engine = "test_engine"
|
||||
messages = ["test_message"]
|
||||
response = [{"choices": [{"delta": {"content": "test_response"}}]}]
|
||||
mock_create.return_value = response
|
||||
result = list(self.llm.gen_stream(model, engine, messages))
|
||||
self.assertEqual(result, ["test_response"])
|
||||
|
||||
Reference in New Issue
Block a user