Merge branch 'main' into 1059-migrating-database-to-new-model

Alex
2024-09-09 23:55:25 +01:00
64 changed files with 3517 additions and 4971 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import os
import sys
from flask import Blueprint, request, Response
from flask import Blueprint, request, Response, current_app
import json
import datetime
import logging
@@ -126,7 +126,11 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " + question + "\n\n" + "AI: " + response,
"language as the system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
},
{
"role": "user",
@@ -166,7 +170,10 @@ def get_prompt(prompt_id):
return prompt
def complete_stream(question, retriever, conversation_id, user_api_key):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False
):
try:
response_full = ""
source_log_docs = []
@@ -179,9 +186,17 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
elif "source" in line:
source_log_docs.append(line["source"])
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
if user_api_key is None:
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
@@ -205,7 +220,6 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
def stream():
try:
data = request.get_json()
# get parameter from url question
question = data["question"]
if "history" not in data:
history = []
@@ -252,10 +266,9 @@ def stream():
source = {}
user_api_key = None
""" if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source["active_docs"] """
current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)
prompt = get_prompt(prompt_id)
@@ -277,20 +290,23 @@ def stream():
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
isNoneDoc=data.get("isNoneDoc"),
),
mimetype="text/event-stream",
)
except ValueError as err:
except ValueError:
message = "Malformed request body"
print("\033[91merr", str(err), file=sys.stderr)
print("\033[91merr", str(message), file=sys.stderr)
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
current_app.logger.error(f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()}
)
message = e.args[0]
status_code = 400
# Custom exceptions with two arguments, index 1 as status code
@@ -357,6 +373,10 @@ def api_answer():
prompt = get_prompt(prompt_id)
current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
@@ -376,7 +396,13 @@ def api_answer():
elif "answer" in line:
response_full += line["answer"]
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
@@ -385,16 +411,15 @@ def api_answer():
return result
except Exception as e:
# print whole traceback
traceback.print_exc()
print(str(e))
current_app.logger.error(f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()}
)
return bad_request(500, str(e))
@answer.route("/api/search", methods=["POST"])
def api_search():
data = request.get_json()
# get parameter from url question
question = data["question"]
if "chunks" in data:
chunks = int(data["chunks"])
@@ -420,6 +445,10 @@ def api_search():
token_limit = data["token_limit"]
else:
token_limit = settings.DEFAULT_MAX_HISTORY
current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
@@ -433,4 +462,9 @@ def api_search():
user_api_key=user_api_key,
)
docs = retriever.search()
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
return docs
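For context on the new isNoneDoc flag threaded through /stream, /api/answer and /api/search, a minimal client-side sketch follows; the URL and any payload fields beyond question, history and isNoneDoc are assumptions, not part of this diff.
import requests

payload = {
    "question": "What does this project do?",
    "history": [],        # the routes fall back to [] when history is absent
    "isNoneDoc": True,    # new flag: every returned source is reported as "None"
}
# Hypothetical local deployment URL.
resp = requests.post("http://localhost:7091/api/answer", json=payload, timeout=60)
body = resp.json()
print(body["answer"], body["sources"])   # with isNoneDoc, each source carries {"source": "None"}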

View File

@@ -6,12 +6,14 @@ from application.core.settings import settings
from application.api.user.routes import user
from application.api.answer.routes import answer
from application.api.internal.routes import internal
from application.core.logging_config import setup_logging
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
setup_logging()
app = Flask(__name__)
app.register_blueprint(user)

View File

@@ -1,9 +1,15 @@
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
def make_celery(app_name=__name__):
celery = Celery(app_name, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_RESULT_BACKEND)
celery.conf.update(settings)
return celery
@setup_logging.connect
def config_loggers(*args, **kwargs):
from application.core.logging_config import setup_logging
setup_logging()
celery = make_celery()

View File

@@ -0,0 +1,22 @@
from logging.config import dictConfig
def setup_logging():
dictConfig({
'version': 1,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"formatter": "default",
}
},
'root': {
'level': 'INFO',
'handlers': ['console'],
},
})
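A short sketch of how this root configuration is consumed; the module and message are hypothetical, and the default formatter renders asctime/levelname/module/message only, so extra fields are attached to the record but not printed.
import logging
from application.core.logging_config import setup_logging

setup_logging()                    # installs the stdout handler on the root logger
log = logging.getLogger(__name__)  # child loggers propagate up to the root handler
log.info("ingest started", extra={"user": "local"})
# prints something like: [2024-09-09 12:00:00,000] INFO in mymodule: ingest started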

View File

@@ -18,7 +18,7 @@ class Settings(BaseSettings):
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
API_URL: str = "http://localhost:7091" # backend url for celery worker
@@ -29,6 +29,7 @@ class Settings(BaseSettings):
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for OpenAI-compatible models
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -61,6 +62,11 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_TOKEN: Optional[str] = ""
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
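For reference, a hedged example of environment values that exercise the new settings; the values are placeholders, not taken from this commit.
import os

# Placeholder values, illustrative only; in practice these live in the .env file
# that dotenv.load_dotenv() reads before Settings is instantiated.
os.environ["VECTOR_STORE"] = "milvus"
os.environ["MILVUS_COLLECTION_NAME"] = "docsgpt"
os.environ["MILVUS_URI"] = "./milvus_local.db"              # Milvus Lite file; use a server URI for a remote cluster
os.environ["MILVUS_TOKEN"] = ""                             # only needed for authenticated deployments
os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1"  # any OpenAI-compatible endpoint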

View File

@@ -2,25 +2,23 @@ from application.llm.base import BaseLLM
from application.core.settings import settings
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
global openai
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(
api_key=api_key,
)
if settings.OPENAI_BASE_URL:
self.client = OpenAI(
api_key=api_key,
base_url=settings.OPENAI_BASE_URL
)
else:
self.client = OpenAI(api_key=api_key)
self.api_key = api_key
self.user_api_key = user_api_key
def _get_openai(self):
# Import openai when needed
import openai
return openai
def _raw_gen(
self,
baseself,
@@ -29,7 +27,7 @@ class OpenAILLM(BaseLLM):
stream=False,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
@@ -44,7 +42,7 @@ class OpenAILLM(BaseLLM):
stream=True,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
@@ -73,8 +71,3 @@ class AzureOpenAILLM(OpenAILLM):
api_base=settings.OPENAI_API_BASE,
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
)
def _get_openai(self):
openai = super()._get_openai()
return openai

View File

@@ -3,7 +3,6 @@
Contains parser for html files.
"""
import re
from pathlib import Path
from typing import Dict, Union
@@ -18,66 +17,8 @@ class HTMLParser(BaseParser):
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
"""Parse file.
from langchain_community.document_loaders import BSHTMLLoader
Returns:
Union[str, List[str]]: a string or a List of strings.
"""
try:
from unstructured.partition.html import partition_html
from unstructured.staging.base import convert_to_isd
from unstructured.cleaners.core import clean
except ImportError:
raise ValueError("unstructured package is required to parse HTML files.")
# Using the unstructured library to convert the html to isd format
# isd sample : isd = [
# {"text": "My Title", "type": "Title"},
# {"text": "My Narrative", "type": "NarrativeText"}
# ]
with open(file, "r", encoding="utf-8") as fp:
elements = partition_html(file=fp)
isd = convert_to_isd(elements)
# Removing non ascii charactwers from isd_el['text']
for isd_el in isd:
isd_el['text'] = isd_el['text'].encode("ascii", "ignore").decode()
# Removing all the \n characters from isd_el['text'] using regex and replace with single space
# Removing all the extra spaces from isd_el['text'] using regex and replace with single space
for isd_el in isd:
isd_el['text'] = re.sub(r'\n', ' ', isd_el['text'], flags=re.MULTILINE | re.DOTALL)
isd_el['text'] = re.sub(r"\s{2,}", " ", isd_el['text'], flags=re.MULTILINE | re.DOTALL)
# more cleaning: extra_whitespaces, dashes, bullets, trailing_punctuation
for isd_el in isd:
clean(isd_el['text'], extra_whitespace=True, dashes=True, bullets=True, trailing_punctuation=True)
# Creating a list of all the indexes of isd_el['type'] = 'Title'
title_indexes = [i for i, isd_el in enumerate(isd) if isd_el['type'] == 'Title']
# Creating 'Chunks' - List of lists of strings
# each list starting with isd_el['type'] = 'Title' and all the data till the next 'Title'
# Each Chunk can be thought of as an individual set of data, which can be sent to the model
# Where Each Title is grouped together with the data under it
Chunks = [[]]
final_chunks = list(list())
for i, isd_el in enumerate(isd):
if i in title_indexes:
Chunks.append([])
Chunks[-1].append(isd_el['text'])
# Removing all the chunks with sum of length of all the strings in the chunk < 25
# TODO: This value can be an user defined variable
for chunk in Chunks:
# sum of length of all the strings in the chunk
sum = 0
sum += len(str(chunk))
if sum < 25:
Chunks.remove(chunk)
else:
# appending all the approved chunks to final_chunks as a single string
final_chunks.append(" ".join([str(item) for item in chunk]))
return final_chunks
loader = BSHTMLLoader(file)
data = loader.load()
return data
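A quick sketch of the new parser path, assuming HTMLParser can be constructed without arguments and that BSHTMLLoader's dependencies (beautifulsoup4, and lxml by default) are installed; BSHTMLLoader returns langchain Document objects, so parse_file now yields those instead of a list of plain strings.
from pathlib import Path

parser = HTMLParser()                                  # assumed no-arg construction
docs = parser.parse_file(Path("inputs/example.html"))  # hypothetical file path
print(docs[0].page_content[:80], docs[0].metadata)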

View File

@@ -5,7 +5,7 @@ from application.parser.remote.base import BaseRemote
class CrawlerLoader(BaseRemote):
def __init__(self, limit=10):
from langchain.document_loaders import WebBaseLoader
from langchain_community.document_loaders import WebBaseLoader
self.loader = WebBaseLoader # Initialize the document loader
self.limit = limit # Set the limit for the number of pages to scrape

View File

@@ -5,7 +5,7 @@ from application.parser.remote.base import BaseRemote
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
from langchain.document_loaders import WebBaseLoader
from langchain_community.document_loaders import WebBaseLoader
self.loader = WebBaseLoader
self.limit = limit # Adding limit to control the number of URLs to process

View File

@@ -1,34 +1,36 @@
anthropic==0.12.0
boto3==1.34.6
anthropic==0.34.0
boto3==1.34.153
beautifulsoup4==4.12.3
celery==5.3.6
dataclasses_json==0.6.3
dataclasses_json==0.6.7
docx2txt==0.8
duckduckgo-search==5.3.0
duckduckgo-search==6.2.6
EbookLib==0.18
elasticsearch==8.12.0
elasticsearch==8.14.0
escodegen==1.0.11
esprima==4.0.1
faiss-cpu==1.7.4
Flask==3.0.1
gunicorn==22.0.0
faiss-cpu==1.8.0.post1
gunicorn==23.0.0
html2text==2020.1.16
javalang==0.13.0
langchain==0.1.4
langchain-openai==0.0.5
langchain==0.2.16
langchain-community==0.2.16
langchain-core==0.2.38
langchain-openai==0.1.23
openapi3_parser==1.1.16
pandas==2.2.0
pydantic_settings==2.1.0
pymongo==4.6.3
pandas==2.2.2
pydantic_settings==2.4.0
pymongo==4.8.0
PyPDF2==3.0.1
python-dotenv==1.0.1
qdrant-client==1.9.0
qdrant-client==1.11.0
redis==5.0.1
Requests==2.32.0
retry==0.9.2
sentence-transformers
tiktoken
tiktoken==0.7.0
torch
tqdm==4.66.3
transformers==4.36.2
unstructured==0.12.2
transformers==4.44.0
Werkzeug==3.0.3

View File

@@ -2,7 +2,7 @@ import json
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
from application.utils import num_tokens_from_string
from langchain_community.tools import BraveSearch
@@ -78,7 +78,7 @@ class BraveRetSearch(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:

View File

@@ -3,7 +3,7 @@ from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
from application.utils import num_tokens_from_string
class ClassicRAG(BaseRetriever):
@@ -82,7 +82,7 @@ class ClassicRAG(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:

View File

@@ -1,7 +1,7 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
from application.utils import num_tokens_from_string
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
@@ -95,7 +95,7 @@ class DuckDuckSearch(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:

View File

@@ -5,15 +5,16 @@ from application.retriever.brave_search import BraveRetSearch
class RetrieverCreator:
retievers = {
retrievers = {
'classic': ClassicRAG,
'duckduck_search': DuckDuckSearch,
'brave_search': BraveRetSearch
'brave_search': BraveRetSearch,
'default': ClassicRAG
}
@classmethod
def create_retriever(cls, type, *args, **kwargs):
retiever_class = cls.retievers.get(type.lower())
retriever_class = cls.retrievers.get(type.lower())
if not retriever_class:
raise ValueError(f"No retriever class found for type {type}")
return retriever_class(*args, **kwargs)

View File

@@ -2,7 +2,7 @@ import sys
from pymongo import MongoClient
from datetime import datetime
from application.core.settings import settings
from application.utils import count_tokens
from application.utils import num_tokens_from_string
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
@@ -24,9 +24,9 @@ def update_token_usage(user_api_key, token_usage):
def gen_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += count_tokens(message["content"])
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, **kwargs)
self.token_usage["generated_tokens"] += count_tokens(result)
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -36,14 +36,14 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += count_tokens(message["content"])
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
batch = []
result = func(self, model, messages, stream, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
self.token_usage["generated_tokens"] += count_tokens(line)
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.user_api_key, self.token_usage)
return wrapper

View File

@@ -1,6 +1,22 @@
from transformers import GPT2TokenizerFast
import tiktoken
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.model_max_length = 100000
def count_tokens(string):
return len(tokenizer(string)['input_ids'])
_encoding = None
def get_encoding():
global _encoding
if _encoding is None:
_encoding = tiktoken.get_encoding("cl100k_base")
return _encoding
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
num_tokens = len(encoding.encode(string))
return num_tokens
def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
docs_content += doc.page_content
tokens = num_tokens_from_string(docs_content)
return tokens
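A small usage sketch of the new token helpers; the Doc class below is only a stand-in that mirrors the page_content attribute count_tokens_docs reads.
from application.utils import num_tokens_from_string, count_tokens_docs

print(num_tokens_from_string("How many tokens is this?"))  # an int, e.g. 6 with cl100k_base

class Doc:  # stand-in for the parser's Document objects
    def __init__(self, page_content):
        self.page_content = page_content

print(count_tokens_docs([Doc("hello world"), Doc("more text")]))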

View File

@@ -1,13 +1,30 @@
from abc import ABC, abstractmethod
import os
from langchain_community.embeddings import (
HuggingFaceEmbeddings,
CohereEmbeddings,
HuggingFaceInstructEmbeddings,
)
from sentence_transformers import SentenceTransformer
from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
return self.model.encode(query).tolist()
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
elif isinstance(text, list):
return self.embed_documents(text)
else:
raise ValueError("Input must be a string or a list of strings")
class EmbeddingsSingleton:
_instances = {}
@@ -23,16 +40,15 @@ class EmbeddingsSingleton:
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
"cohere_medium": CohereEmbeddings
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
}
if embeddings_name not in embeddings_factory:
raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
return embeddings_factory[embeddings_name](*args, **kwargs)
if embeddings_name in embeddings_factory:
return embeddings_factory[embeddings_name](*args, **kwargs)
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
@@ -58,22 +74,14 @@ class BaseVectorStore(ABC):
embeddings_name,
openai_api_key=embeddings_key
)
elif embeddings_name == "cohere_medium":
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
cohere_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./model/all-mpnet-base-v2"):
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_name="./model/all-mpnet-base-v2",
model_kwargs={"device": "cpu"}
embeddings_name="./model/all-mpnet-base-v2",
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model_kwargs={"device": "cpu"}
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
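A minimal sketch of the wrapper's behaviour; the model is downloaded on first use, and 768 is the standard output dimension of all-mpnet-base-v2.
from application.vectorstore.base import EmbeddingsWrapper

emb = EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2")
vec = emb.embed_query("hello world")        # list[float], len(vec) == emb.dimension (768)
vecs = emb.embed_documents(["a", "b"])      # list of two such vectors
assert len(vec) == emb.dimension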

View File

@@ -24,7 +24,8 @@ class FaissStore(BaseVectorStore):
)
else:
self.docsearch = FAISS.load_local(
self.path, embeddings
self.path, embeddings,
allow_dangerous_deserialization=True
)
self.assert_embedding_dimensions(embeddings)
@@ -47,10 +48,10 @@ class FaissStore(BaseVectorStore):
"""
if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
try:
word_embedding_dimension = embeddings.client[1].word_embedding_dimension
word_embedding_dimension = embeddings.dimension
except AttributeError as e:
raise AttributeError("word_embedding_dimension not found in embeddings.client[1]") from e
raise AttributeError("'dimension' attribute not found in embeddings instance. Make sure the embeddings object is properly initialized.") from e
docsearch_index_dimension = self.docsearch.index.d
if word_embedding_dimension != docsearch_index_dimension:
raise ValueError(f"word_embedding_dimension ({word_embedding_dimension}) " +
f"!= docsearch_index_word_embedding_dimension ({docsearch_index_dimension})")
raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " +
f"!= docsearch index dimension ({docsearch_index_dimension})")

View File

@@ -0,0 +1,37 @@
from typing import List, Optional
from uuid import uuid4
from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
class MilvusStore(BaseVectorStore):
def __init__(self, path: str = "", embeddings_key: str = "embeddings"):
super().__init__()
from langchain_milvus import Milvus
connection_args = {
"uri": settings.MILVUS_URI,
"token": settings.MILVUS_TOKEN,
}
self._docsearch = Milvus(
embedding_function=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key),
collection_name=settings.MILVUS_COLLECTION_NAME,
connection_args=connection_args,
)
self._path = path
def search(self, question, k=2, *args, **kwargs):
return self._docsearch.similarity_search(query=question, k=k, filter={"path": self._path}, *args, **kwargs)
def add_texts(self, texts: List[str], metadatas: Optional[List[dict]], *args, **kwargs):
ids = [str(uuid4()) for _ in range(len(texts))]
return self._docsearch.add_texts(texts=texts, metadatas=metadatas, ids=ids, *args, **kwargs)
def save_local(self, *args, **kwargs):
pass
def delete_index(self, *args, **kwargs):
pass
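A hedged usage sketch of the new store; the path and document text are placeholders, and it assumes the default local Milvus Lite URI plus a working EMBEDDINGS_NAME backend.
from application.vectorstore.milvus import MilvusStore

store = MilvusStore(path="local/default/")   # hypothetical source path used for filtering
store.add_texts(
    ["DocsGPT indexes your documentation."],
    metadatas=[{"path": "local/default/"}],  # metadata should carry the field the filter matches on
)
print(store.search("What does DocsGPT index?", k=1))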

View File

@@ -1,5 +1,6 @@
from application.vectorstore.faiss import FaissStore
from application.vectorstore.elasticsearch import ElasticsearchStore
from application.vectorstore.milvus import MilvusStore
from application.vectorstore.mongodb import MongoDBVectorStore
from application.vectorstore.qdrant import QdrantStore
@@ -10,6 +11,7 @@ class VectorCreator:
"elasticsearch": ElasticsearchStore,
"mongodb": MongoDBVectorStore,
"qdrant": QdrantStore,
"milvus": MilvusStore,
}
@classmethod

View File

@@ -2,8 +2,8 @@ import os
import shutil
import string
import zipfile
import tiktoken
from urllib.parse import urljoin
import logging
import requests
from bson.objectid import ObjectId
@@ -14,6 +14,8 @@ from application.parser.remote.remote_creator import RemoteCreator
from application.parser.open_ai_func import call_openai_api
from application.parser.schema.base import Document
from application.parser.token_func import group_split
from application.utils import count_tokens_docs
# Define a function to extract metadata from a given filename.
@@ -40,7 +42,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
max_depth (int): Maximum allowed depth of recursion to prevent infinite loops.
"""
if current_depth > max_depth:
print(f"Reached maximum recursion depth of {max_depth}")
logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
with zipfile.ZipFile(zip_path, "r") as zip_ref:
@@ -88,14 +90,13 @@ def ingest_worker(self, directory, formats, name_job, filename, user, retriever=
max_tokens = 1250
recursion_depth = 2
full_path = os.path.join(directory, user, name_job)
import sys
print(full_path, file=sys.stderr)
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": name_job})
# check if API_URL env variable is set
file_data = {"name": name_job, "file": filename, "user": user}
response = requests.get(urljoin(settings.API_URL, "/api/download"), params=file_data)
# check if file is in the response
print(response, file=sys.stderr)
response = requests.get(
urljoin(settings.API_URL, "/api/download"), params=file_data
)
file = response.content
if not os.path.exists(full_path):
@@ -134,7 +135,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user, retriever=
if sample:
for i in range(min(5, len(raw_docs))):
print(raw_docs[i].text)
logging.info(f"Sample document {i}: {raw_docs[i]}")
# get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl
# and send them to the server (provide user and name in form)
@@ -170,6 +171,7 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp", r
if not os.path.exists(full_path):
os.makedirs(full_path)
self.update_state(state="PROGRESS", meta={"current": 1})
logging.info(f"Remote job: {full_path}", extra={"user": user, "job": name_job, source_data: source_data})
remote_loader = RemoteCreator.create_loader(loader)
raw_docs = remote_loader.load_data(source_data)
@@ -202,23 +204,3 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp", r
shutil.rmtree(full_path)
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
def count_tokens_docs(docs):
# Here we convert the docs list to a string and calculate the number of tokens the string represents.
# docs_content = (" ".join(docs))
docs_content = ""
for doc in docs:
docs_content += doc.page_content
tokens, total_price = num_tokens_from_string(string=docs_content, encoding_name="cl100k_base")
# Here we print the number of tokens and the approx user cost with some visually appealing formatting.
return tokens
def num_tokens_from_string(string: str, encoding_name: str) -> int:
# Function to convert string to tokens and estimate user cost.
encoding = tiktoken.get_encoding(encoding_name)
num_tokens = len(encoding.encode(string))
total_price = (num_tokens / 1000) * 0.0004
return num_tokens, total_price