mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
elastic2
This commit is contained in:
@@ -1,22 +1,38 @@
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
import elasticsearch
|
||||
#from langchain.vectorstores.elasticsearch import ElasticsearchStore
|
||||
|
||||
class Document(str):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
"""String text."""
|
||||
metadata: dict
|
||||
"""Arbitrary metadata"""
|
||||
|
||||
|
||||
class ElasticsearchStore(BaseVectorStore):
|
||||
_es_connection = None # Class attribute to hold the Elasticsearch connection
|
||||
|
||||
def __init__(self, path, embeddings_key, index_name="docsgpt"):
|
||||
def __init__(self, path, embeddings_key, index_name=settings.ELASTIC_INDEX):
|
||||
super().__init__()
|
||||
self.path = path.replace("/app/application/indexes/", "")
|
||||
self.path = path.replace("application/indexes/", "")
|
||||
self.embeddings_key = embeddings_key
|
||||
self.index_name = index_name
|
||||
|
||||
if ElasticsearchStore._es_connection is None:
|
||||
connection_params = {}
|
||||
connection_params["cloud_id"] = settings.ELASTIC_CLOUD_ID
|
||||
connection_params["basic_auth"] = (settings.ELASTIC_USERNAME, settings.ELASTIC_PASSWORD)
|
||||
if settings.ELASTIC_URL:
|
||||
connection_params["hosts"] = [settings.ELASTIC_URL]
|
||||
connection_params["http_auth"] = (settings.ELASTIC_USERNAME, settings.ELASTIC_PASSWORD)
|
||||
elif settings.ELASTIC_CLOUD_ID:
|
||||
connection_params["cloud_id"] = settings.ELASTIC_CLOUD_ID
|
||||
connection_params["basic_auth"] = (settings.ELASTIC_USERNAME, settings.ELASTIC_PASSWORD)
|
||||
else:
|
||||
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
|
||||
|
||||
|
||||
|
||||
ElasticsearchStore._es_connection = elasticsearch.Elasticsearch(**connection_params)
|
||||
|
||||
self.docsearch = ElasticsearchStore._es_connection
|
||||
@@ -94,106 +110,112 @@ class ElasticsearchStore(BaseVectorStore):
|
||||
},
|
||||
"rank": {"rrf": {}},
|
||||
}
|
||||
resp = self.docsearch.search(index=index_name, query=full_query['query'], size=k, knn=full_query['knn'])
|
||||
return resp
|
||||
resp = self.docsearch.search(index=self.index_name, query=full_query['query'], size=k, knn=full_query['knn'])
|
||||
# create Documnets objects from the results page_content ['_source']['text'], metadata ['_source']['metadata']
|
||||
import sys
|
||||
print(self.path, file=sys.stderr)
|
||||
print(resp, file=sys.stderr)
|
||||
doc_list = []
|
||||
for hit in resp['hits']['hits']:
|
||||
|
||||
doc_list.append(Document(page_content = hit['_source']['text'], metadata = hit['_source']['metadata']))
|
||||
return doc_list
|
||||
|
||||
def _create_index_if_not_exists(
|
||||
self, index_name, dims_length
|
||||
):
|
||||
def _create_index_if_not_exists(
|
||||
self, index_name, dims_length
|
||||
):
|
||||
|
||||
if self.client.indices.exists(index=index_name):
|
||||
print(f"Index {index_name} already exists.")
|
||||
if self._es_connection.indices.exists(index=index_name):
|
||||
print(f"Index {index_name} already exists.")
|
||||
|
||||
else:
|
||||
self.strategy.before_index_setup(
|
||||
client=self.client,
|
||||
text_field=self.query_field,
|
||||
vector_query_field=self.vector_query_field,
|
||||
)
|
||||
else:
|
||||
|
||||
indexSettings = self.index(
|
||||
dims_length=dims_length,
|
||||
)
|
||||
self.client.indices.create(index=index_name, **indexSettings)
|
||||
def index(
|
||||
self,
|
||||
dims_length,
|
||||
):
|
||||
indexSettings = self.index(
|
||||
dims_length=dims_length,
|
||||
)
|
||||
self._es_connection.indices.create(index=index_name, **indexSettings)
|
||||
|
||||
|
||||
return {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dims_length,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
}
|
||||
def index(
|
||||
self,
|
||||
dims_length,
|
||||
):
|
||||
return {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dims_length,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts,
|
||||
metadatas = None,
|
||||
ids = None,
|
||||
refresh_indices = True,
|
||||
create_index_if_not_exists = True,
|
||||
bulk_kwargs = None,
|
||||
**kwargs,
|
||||
def add_texts(
|
||||
self,
|
||||
texts,
|
||||
metadatas = None,
|
||||
ids = None,
|
||||
refresh_indices = True,
|
||||
create_index_if_not_exists = True,
|
||||
bulk_kwargs = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
from elasticsearch.helpers import BulkIndexError, bulk
|
||||
|
||||
from elasticsearch.helpers import BulkIndexError, bulk
|
||||
|
||||
bulk_kwargs = bulk_kwargs or {}
|
||||
import uuid
|
||||
embeddings = []
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
requests = []
|
||||
embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
|
||||
bulk_kwargs = bulk_kwargs or {}
|
||||
import uuid
|
||||
embeddings = []
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
requests = []
|
||||
embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
|
||||
|
||||
vectors = embeddings.embed_documents(list(texts))
|
||||
vectors = embeddings.embed_documents(list(texts))
|
||||
|
||||
dims_length = len(vectors[0])
|
||||
dims_length = len(vectors[0])
|
||||
|
||||
if create_index_if_not_exists:
|
||||
self._create_index_if_not_exists(
|
||||
index_name=self.index_name, dims_length=dims_length
|
||||
if create_index_if_not_exists:
|
||||
self._create_index_if_not_exists(
|
||||
index_name=self.index_name, dims_length=dims_length
|
||||
)
|
||||
|
||||
for i, (text, vector) in enumerate(zip(texts, vectors)):
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
|
||||
requests.append(
|
||||
{
|
||||
"_op_type": "index",
|
||||
"_index": self.index_name,
|
||||
"text": text,
|
||||
"vector": vector,
|
||||
"metadata": metadata,
|
||||
"_id": ids[i],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if len(requests) > 0:
|
||||
try:
|
||||
success, failed = bulk(
|
||||
self._es_connection,
|
||||
requests,
|
||||
stats_only=True,
|
||||
refresh=refresh_indices,
|
||||
**bulk_kwargs,
|
||||
)
|
||||
return ids
|
||||
except BulkIndexError as e:
|
||||
print(f"Error adding texts: {e}")
|
||||
firstError = e.errors[0].get("index", {}).get("error", {})
|
||||
print(f"First error reason: {firstError.get('reason')}")
|
||||
raise e
|
||||
|
||||
for i, (text, vector) in enumerate(zip(texts, vectors)):
|
||||
metadata = metadatas[i] if metadatas else {}
|
||||
else:
|
||||
return []
|
||||
|
||||
requests.append(
|
||||
{
|
||||
"_op_type": "index",
|
||||
"_index": self.index_name,
|
||||
"text": text,
|
||||
"vector": vector,
|
||||
"metadata": metadata,
|
||||
"_id": ids[i],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if len(requests) > 0:
|
||||
try:
|
||||
success, failed = bulk(
|
||||
self.client,
|
||||
requests,
|
||||
stats_only=True,
|
||||
refresh=refresh_indices,
|
||||
**bulk_kwargs,
|
||||
)
|
||||
return ids
|
||||
except BulkIndexError as e:
|
||||
print(f"Error adding texts: {e}")
|
||||
firstError = e.errors[0].get("index", {}).get("error", {})
|
||||
print(f"First error reason: {firstError.get('reason')}")
|
||||
raise e
|
||||
|
||||
else:
|
||||
return []
|
||||
def delete_index(self):
|
||||
self._es_connection.delete_by_query(index=self.index_name, query={"match": {
|
||||
"metadata.filename.keyword": self.path}},)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user