mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-10 08:10:46 +00:00
Merge branch 'main' into 1059-migrating-database-to-new-model
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from flask import Blueprint, request, Response
|
||||
from flask import Blueprint, request, Response, current_app
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
@@ -126,7 +126,11 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
"language as the system \n\nUser: "
|
||||
+ question
|
||||
+ "\n\n"
|
||||
+ "AI: "
|
||||
+ response,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -166,7 +170,10 @@ def get_prompt(prompt_id):
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(question, retriever, conversation_id, user_api_key):
|
||||
def complete_stream(
|
||||
question, retriever, conversation_id, user_api_key, isNoneDoc=False
|
||||
):
|
||||
|
||||
try:
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
@@ -179,9 +186,17 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
|
||||
elif "source" in line:
|
||||
source_log_docs.append(line["source"])
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
if user_api_key is None:
|
||||
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||
conversation_id = save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
)
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
@@ -205,7 +220,6 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
|
||||
def stream():
|
||||
try:
|
||||
data = request.get_json()
|
||||
# get parameter from url question
|
||||
question = data["question"]
|
||||
if "history" not in data:
|
||||
history = []
|
||||
@@ -252,10 +266,9 @@ def stream():
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
""" if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"] """
|
||||
current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})}
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
@@ -277,20 +290,23 @@ def stream():
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
print("\033[91merr", str(err), file=sys.stderr)
|
||||
print("\033[91merr", str(message), file=sys.stderr)
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
print("\033[91merr", str(e), file=sys.stderr)
|
||||
current_app.logger.error(f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
message = e.args[0]
|
||||
status_code = 400
|
||||
# # Custom exceptions with two arguments, index 1 as status code
|
||||
@@ -357,6 +373,10 @@ def api_answer():
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})}
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
@@ -376,7 +396,13 @@ def api_answer():
|
||||
elif "answer" in line:
|
||||
response_full += line["answer"]
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key)
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
@@ -385,16 +411,15 @@ def api_answer():
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# print whole traceback
|
||||
traceback.print_exc()
|
||||
print(str(e))
|
||||
current_app.logger.error(f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()}
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
|
||||
@answer.route("/api/search", methods=["POST"])
|
||||
def api_search():
|
||||
data = request.get_json()
|
||||
# get parameter from url question
|
||||
question = data["question"]
|
||||
if "chunks" in data:
|
||||
chunks = int(data["chunks"])
|
||||
@@ -420,6 +445,10 @@ def api_search():
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
current_app.logger.info(f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})}
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
@@ -433,4 +462,9 @@ def api_search():
|
||||
user_api_key=user_api_key,
|
||||
)
|
||||
docs = retriever.search()
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
return docs
|
||||
|
||||
@@ -6,12 +6,14 @@ from application.core.settings import settings
|
||||
from application.api.user.routes import user
|
||||
from application.api.answer.routes import answer
|
||||
from application.api.internal.routes import internal
|
||||
from application.core.logging_config import setup_logging
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import pathlib
|
||||
pathlib.PosixPath = pathlib.WindowsPath
|
||||
|
||||
dotenv.load_dotenv()
|
||||
setup_logging()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from celery import Celery
|
||||
from application.core.settings import settings
|
||||
from celery.signals import setup_logging
|
||||
|
||||
def make_celery(app_name=__name__):
|
||||
celery = Celery(app_name, broker=settings.CELERY_BROKER_URL, backend=settings.CELERY_RESULT_BACKEND)
|
||||
celery.conf.update(settings)
|
||||
return celery
|
||||
|
||||
@setup_logging.connect
|
||||
def config_loggers(*args, **kwargs):
|
||||
from application.core.logging_config import setup_logging
|
||||
setup_logging()
|
||||
|
||||
celery = make_celery()
|
||||
|
||||
22
application/core/logging_config.py
Normal file
22
application/core/logging_config.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from logging.config import dictConfig
|
||||
|
||||
def setup_logging():
|
||||
dictConfig({
|
||||
'version': 1,
|
||||
'formatters': {
|
||||
'default': {
|
||||
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stdout",
|
||||
"formatter": "default",
|
||||
}
|
||||
},
|
||||
'root': {
|
||||
'level': 'INFO',
|
||||
'handlers': ['console'],
|
||||
},
|
||||
})
|
||||
@@ -18,7 +18,7 @@ class Settings(BaseSettings):
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant"
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus"
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
|
||||
|
||||
API_URL: str = "http://localhost:7091" # backend url for celery worker
|
||||
@@ -29,6 +29,7 @@ class Settings(BaseSettings):
|
||||
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
|
||||
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
|
||||
|
||||
# elasticsearch
|
||||
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
|
||||
@@ -61,6 +62,11 @@ class Settings(BaseSettings):
|
||||
QDRANT_PATH: Optional[str] = None
|
||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||
|
||||
# Milvus vectorstore config
|
||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
|
||||
MILVUS_TOKEN: Optional[str] = ""
|
||||
|
||||
BRAVE_SEARCH_API_KEY: Optional[str] = None
|
||||
|
||||
FLASK_DEBUG_MODE: bool = False
|
||||
|
||||
@@ -2,25 +2,23 @@ from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
global openai
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
)
|
||||
if settings.OPENAI_BASE_URL:
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=settings.OPENAI_BASE_URL
|
||||
)
|
||||
else:
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _get_openai(self):
|
||||
# Import openai when needed
|
||||
import openai
|
||||
|
||||
return openai
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -29,7 +27,7 @@ class OpenAILLM(BaseLLM):
|
||||
stream=False,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
@@ -44,7 +42,7 @@ class OpenAILLM(BaseLLM):
|
||||
stream=True,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
@@ -73,8 +71,3 @@ class AzureOpenAILLM(OpenAILLM):
|
||||
api_base=settings.OPENAI_API_BASE,
|
||||
deployment_name=settings.AZURE_DEPLOYMENT_NAME,
|
||||
)
|
||||
|
||||
def _get_openai(self):
|
||||
openai = super()._get_openai()
|
||||
|
||||
return openai
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
Contains parser for html files.
|
||||
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
@@ -18,66 +17,8 @@ class HTMLParser(BaseParser):
|
||||
return {}
|
||||
|
||||
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
|
||||
"""Parse file.
|
||||
from langchain_community.document_loaders import BSHTMLLoader
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: a string or a List of strings.
|
||||
"""
|
||||
try:
|
||||
from unstructured.partition.html import partition_html
|
||||
from unstructured.staging.base import convert_to_isd
|
||||
from unstructured.cleaners.core import clean
|
||||
except ImportError:
|
||||
raise ValueError("unstructured package is required to parse HTML files.")
|
||||
|
||||
# Using the unstructured library to convert the html to isd format
|
||||
# isd sample : isd = [
|
||||
# {"text": "My Title", "type": "Title"},
|
||||
# {"text": "My Narrative", "type": "NarrativeText"}
|
||||
# ]
|
||||
with open(file, "r", encoding="utf-8") as fp:
|
||||
elements = partition_html(file=fp)
|
||||
isd = convert_to_isd(elements)
|
||||
|
||||
# Removing non ascii charactwers from isd_el['text']
|
||||
for isd_el in isd:
|
||||
isd_el['text'] = isd_el['text'].encode("ascii", "ignore").decode()
|
||||
|
||||
# Removing all the \n characters from isd_el['text'] using regex and replace with single space
|
||||
# Removing all the extra spaces from isd_el['text'] using regex and replace with single space
|
||||
for isd_el in isd:
|
||||
isd_el['text'] = re.sub(r'\n', ' ', isd_el['text'], flags=re.MULTILINE | re.DOTALL)
|
||||
isd_el['text'] = re.sub(r"\s{2,}", " ", isd_el['text'], flags=re.MULTILINE | re.DOTALL)
|
||||
|
||||
# more cleaning: extra_whitespaces, dashes, bullets, trailing_punctuation
|
||||
for isd_el in isd:
|
||||
clean(isd_el['text'], extra_whitespace=True, dashes=True, bullets=True, trailing_punctuation=True)
|
||||
|
||||
# Creating a list of all the indexes of isd_el['type'] = 'Title'
|
||||
title_indexes = [i for i, isd_el in enumerate(isd) if isd_el['type'] == 'Title']
|
||||
|
||||
# Creating 'Chunks' - List of lists of strings
|
||||
# each list starting with isd_el['type'] = 'Title' and all the data till the next 'Title'
|
||||
# Each Chunk can be thought of as an individual set of data, which can be sent to the model
|
||||
# Where Each Title is grouped together with the data under it
|
||||
|
||||
Chunks = [[]]
|
||||
final_chunks = list(list())
|
||||
|
||||
for i, isd_el in enumerate(isd):
|
||||
if i in title_indexes:
|
||||
Chunks.append([])
|
||||
Chunks[-1].append(isd_el['text'])
|
||||
|
||||
# Removing all the chunks with sum of length of all the strings in the chunk < 25
|
||||
# TODO: This value can be an user defined variable
|
||||
for chunk in Chunks:
|
||||
# sum of length of all the strings in the chunk
|
||||
sum = 0
|
||||
sum += len(str(chunk))
|
||||
if sum < 25:
|
||||
Chunks.remove(chunk)
|
||||
else:
|
||||
# appending all the approved chunks to final_chunks as a single string
|
||||
final_chunks.append(" ".join([str(item) for item in chunk]))
|
||||
return final_chunks
|
||||
loader = BSHTMLLoader(file)
|
||||
data = loader.load()
|
||||
return data
|
||||
|
||||
@@ -5,7 +5,7 @@ from application.parser.remote.base import BaseRemote
|
||||
|
||||
class CrawlerLoader(BaseRemote):
|
||||
def __init__(self, limit=10):
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
self.loader = WebBaseLoader # Initialize the document loader
|
||||
self.limit = limit # Set the limit for the number of pages to scrape
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from application.parser.remote.base import BaseRemote
|
||||
|
||||
class SitemapLoader(BaseRemote):
|
||||
def __init__(self, limit=20):
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
self.loader = WebBaseLoader
|
||||
self.limit = limit # Adding limit to control the number of URLs to process
|
||||
|
||||
|
||||
@@ -1,34 +1,36 @@
|
||||
anthropic==0.12.0
|
||||
boto3==1.34.6
|
||||
anthropic==0.34.0
|
||||
boto3==1.34.153
|
||||
beautifulsoup4==4.12.3
|
||||
celery==5.3.6
|
||||
dataclasses_json==0.6.3
|
||||
dataclasses_json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==5.3.0
|
||||
duckduckgo-search==6.2.6
|
||||
EbookLib==0.18
|
||||
elasticsearch==8.12.0
|
||||
elasticsearch==8.14.0
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
faiss-cpu==1.7.4
|
||||
Flask==3.0.1
|
||||
gunicorn==22.0.0
|
||||
faiss-cpu==1.8.0.post1
|
||||
gunicorn==23.0.0
|
||||
html2text==2020.1.16
|
||||
javalang==0.13.0
|
||||
langchain==0.1.4
|
||||
langchain-openai==0.0.5
|
||||
langchain==0.2.16
|
||||
langchain-community==0.2.16
|
||||
langchain-core==0.2.38
|
||||
langchain-openai==0.1.23
|
||||
openapi3_parser==1.1.16
|
||||
pandas==2.2.0
|
||||
pydantic_settings==2.1.0
|
||||
pymongo==4.6.3
|
||||
pandas==2.2.2
|
||||
pydantic_settings==2.4.0
|
||||
pymongo==4.8.0
|
||||
PyPDF2==3.0.1
|
||||
python-dotenv==1.0.1
|
||||
qdrant-client==1.9.0
|
||||
qdrant-client==1.11.0
|
||||
redis==5.0.1
|
||||
Requests==2.32.0
|
||||
retry==0.9.2
|
||||
sentence-transformers
|
||||
tiktoken
|
||||
tiktoken==0.7.0
|
||||
torch
|
||||
tqdm==4.66.3
|
||||
transformers==4.36.2
|
||||
unstructured==0.12.2
|
||||
transformers==4.44.0
|
||||
Werkzeug==3.0.3
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import count_tokens
|
||||
from application.utils import num_tokens_from_string
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
self.chat_history.reverse()
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
|
||||
@@ -3,7 +3,7 @@ from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
from application.utils import count_tokens
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class ClassicRAG(BaseRetriever):
|
||||
@@ -82,7 +82,7 @@ class ClassicRAG(BaseRetriever):
|
||||
self.chat_history.reverse()
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import count_tokens
|
||||
from application.utils import num_tokens_from_string
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
@@ -95,7 +95,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
self.chat_history.reverse()
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
|
||||
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
|
||||
i["response"]
|
||||
)
|
||||
if tokens_current_history + tokens_batch < self.token_limit:
|
||||
|
||||
@@ -5,15 +5,16 @@ from application.retriever.brave_search import BraveRetSearch
|
||||
|
||||
|
||||
class RetrieverCreator:
|
||||
retievers = {
|
||||
retrievers = {
|
||||
'classic': ClassicRAG,
|
||||
'duckduck_search': DuckDuckSearch,
|
||||
'brave_search': BraveRetSearch
|
||||
'brave_search': BraveRetSearch,
|
||||
'default': ClassicRAG
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_retriever(cls, type, *args, **kwargs):
|
||||
retiever_class = cls.retievers.get(type.lower())
|
||||
retiever_class = cls.retrievers.get(type.lower())
|
||||
if not retiever_class:
|
||||
raise ValueError(f"No retievers class found for type {type}")
|
||||
return retiever_class(*args, **kwargs)
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
from pymongo import MongoClient
|
||||
from datetime import datetime
|
||||
from application.core.settings import settings
|
||||
from application.utils import count_tokens
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
mongo = MongoClient(settings.MONGO_URI)
|
||||
db = mongo["docsgpt"]
|
||||
@@ -24,9 +24,9 @@ def update_token_usage(user_api_key, token_usage):
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += count_tokens(message["content"])
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
result = func(self, model, messages, stream, **kwargs)
|
||||
self.token_usage["generated_tokens"] += count_tokens(result)
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
return result
|
||||
|
||||
@@ -36,14 +36,14 @@ def gen_token_usage(func):
|
||||
def stream_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += count_tokens(message["content"])
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
|
||||
batch = []
|
||||
result = func(self, model, messages, stream, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
for line in batch:
|
||||
self.token_usage["generated_tokens"] += count_tokens(line)
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
from transformers import GPT2TokenizerFast
|
||||
import tiktoken
|
||||
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
||||
tokenizer.model_max_length = 100000
|
||||
def count_tokens(string):
|
||||
return len(tokenizer(string)['input_ids'])
|
||||
_encoding = None
|
||||
|
||||
def get_encoding():
|
||||
global _encoding
|
||||
if _encoding is None:
|
||||
_encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return _encoding
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
encoding = get_encoding()
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
|
||||
def count_tokens_docs(docs):
|
||||
docs_content = ""
|
||||
for doc in docs:
|
||||
docs_content += doc.page_content
|
||||
|
||||
tokens = num_tokens_from_string(docs_content)
|
||||
return tokens
|
||||
@@ -1,13 +1,30 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from langchain_community.embeddings import (
|
||||
HuggingFaceEmbeddings,
|
||||
CohereEmbeddings,
|
||||
HuggingFaceInstructEmbeddings,
|
||||
)
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from application.core.settings import settings
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
|
||||
def embed_query(self, query: str):
|
||||
return self.model.encode(query).tolist()
|
||||
|
||||
def embed_documents(self, documents: list):
|
||||
return self.model.encode(documents).tolist()
|
||||
|
||||
def __call__(self, text):
|
||||
if isinstance(text, str):
|
||||
return self.embed_query(text)
|
||||
elif isinstance(text, list):
|
||||
return self.embed_documents(text)
|
||||
else:
|
||||
raise ValueError("Input must be a string or a list of strings")
|
||||
|
||||
|
||||
|
||||
class EmbeddingsSingleton:
|
||||
_instances = {}
|
||||
|
||||
@@ -23,16 +40,15 @@ class EmbeddingsSingleton:
|
||||
def _create_instance(embeddings_name, *args, **kwargs):
|
||||
embeddings_factory = {
|
||||
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": HuggingFaceEmbeddings,
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": HuggingFaceEmbeddings,
|
||||
"huggingface_hkunlp/instructor-large": HuggingFaceInstructEmbeddings,
|
||||
"cohere_medium": CohereEmbeddings
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
|
||||
}
|
||||
|
||||
if embeddings_name not in embeddings_factory:
|
||||
raise ValueError(f"Invalid embeddings_name: {embeddings_name}")
|
||||
|
||||
return embeddings_factory[embeddings_name](*args, **kwargs)
|
||||
if embeddings_name in embeddings_factory:
|
||||
return embeddings_factory[embeddings_name](*args, **kwargs)
|
||||
else:
|
||||
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
def __init__(self):
|
||||
@@ -58,22 +74,14 @@ class BaseVectorStore(ABC):
|
||||
embeddings_name,
|
||||
openai_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "cohere_medium":
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
cohere_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
if os.path.exists("./model/all-mpnet-base-v2"):
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
model_name="./model/all-mpnet-base-v2",
|
||||
model_kwargs={"device": "cpu"}
|
||||
embeddings_name="./model/all-mpnet-base-v2",
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
model_kwargs={"device": "cpu"}
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||
|
||||
@@ -24,7 +24,8 @@ class FaissStore(BaseVectorStore):
|
||||
)
|
||||
else:
|
||||
self.docsearch = FAISS.load_local(
|
||||
self.path, embeddings
|
||||
self.path, embeddings,
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
self.assert_embedding_dimensions(embeddings)
|
||||
|
||||
@@ -47,10 +48,10 @@ class FaissStore(BaseVectorStore):
|
||||
"""
|
||||
if settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
try:
|
||||
word_embedding_dimension = embeddings.client[1].word_embedding_dimension
|
||||
word_embedding_dimension = embeddings.dimension
|
||||
except AttributeError as e:
|
||||
raise AttributeError("word_embedding_dimension not found in embeddings.client[1]") from e
|
||||
raise AttributeError("'dimension' attribute not found in embeddings instance. Make sure the embeddings object is properly initialized.") from e
|
||||
docsearch_index_dimension = self.docsearch.index.d
|
||||
if word_embedding_dimension != docsearch_index_dimension:
|
||||
raise ValueError(f"word_embedding_dimension ({word_embedding_dimension}) " +
|
||||
f"!= docsearch_index_word_embedding_dimension ({docsearch_index_dimension})")
|
||||
raise ValueError(f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " +
|
||||
f"!= docsearch index dimension ({docsearch_index_dimension})")
|
||||
37
application/vectorstore/milvus.py
Normal file
37
application/vectorstore/milvus.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
|
||||
|
||||
class MilvusStore(BaseVectorStore):
|
||||
def __init__(self, path: str = "", embeddings_key: str = "embeddings"):
|
||||
super().__init__()
|
||||
from langchain_milvus import Milvus
|
||||
|
||||
connection_args = {
|
||||
"uri": settings.MILVUS_URI,
|
||||
"token": settings.MILVUS_TOKEN,
|
||||
}
|
||||
self._docsearch = Milvus(
|
||||
embedding_function=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key),
|
||||
collection_name=settings.MILVUS_COLLECTION_NAME,
|
||||
connection_args=connection_args,
|
||||
)
|
||||
self._path = path
|
||||
|
||||
def search(self, question, k=2, *args, **kwargs):
|
||||
return self._docsearch.similarity_search(query=question, k=k, filter={"path": self._path} *args, **kwargs)
|
||||
|
||||
def add_texts(self, texts: List[str], metadatas: Optional[List[dict]], *args, **kwargs):
|
||||
ids = [str(uuid4()) for _ in range(len(texts))]
|
||||
|
||||
return self._docsearch.add_texts(texts=texts, metadatas=metadatas, ids=ids, *args, **kwargs)
|
||||
|
||||
def save_local(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
pass
|
||||
@@ -1,5 +1,6 @@
|
||||
from application.vectorstore.faiss import FaissStore
|
||||
from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
from application.vectorstore.milvus import MilvusStore
|
||||
from application.vectorstore.mongodb import MongoDBVectorStore
|
||||
from application.vectorstore.qdrant import QdrantStore
|
||||
|
||||
@@ -10,6 +11,7 @@ class VectorCreator:
|
||||
"elasticsearch": ElasticsearchStore,
|
||||
"mongodb": MongoDBVectorStore,
|
||||
"qdrant": QdrantStore,
|
||||
"milvus": MilvusStore,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,8 +2,8 @@ import os
|
||||
import shutil
|
||||
import string
|
||||
import zipfile
|
||||
import tiktoken
|
||||
from urllib.parse import urljoin
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from bson.objectid import ObjectId
|
||||
@@ -14,6 +14,8 @@ from application.parser.remote.remote_creator import RemoteCreator
|
||||
from application.parser.open_ai_func import call_openai_api
|
||||
from application.parser.schema.base import Document
|
||||
from application.parser.token_func import group_split
|
||||
from application.utils import count_tokens_docs
|
||||
|
||||
|
||||
|
||||
# Define a function to extract metadata from a given filename.
|
||||
@@ -40,7 +42,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
|
||||
max_depth (int): Maximum allowed depth of recursion to prevent infinite loops.
|
||||
"""
|
||||
if current_depth > max_depth:
|
||||
print(f"Reached maximum recursion depth of {max_depth}")
|
||||
logging.warning(f"Reached maximum recursion depth of {max_depth}")
|
||||
return
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
@@ -88,14 +90,13 @@ def ingest_worker(self, directory, formats, name_job, filename, user, retriever=
|
||||
max_tokens = 1250
|
||||
recursion_depth = 2
|
||||
full_path = os.path.join(directory, user, name_job)
|
||||
import sys
|
||||
|
||||
print(full_path, file=sys.stderr)
|
||||
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": name_job})
|
||||
# check if API_URL env variable is set
|
||||
file_data = {"name": name_job, "file": filename, "user": user}
|
||||
response = requests.get(urljoin(settings.API_URL, "/api/download"), params=file_data)
|
||||
# check if file is in the response
|
||||
print(response, file=sys.stderr)
|
||||
response = requests.get(
|
||||
urljoin(settings.API_URL, "/api/download"), params=file_data
|
||||
)
|
||||
file = response.content
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
@@ -134,7 +135,7 @@ def ingest_worker(self, directory, formats, name_job, filename, user, retriever=
|
||||
|
||||
if sample:
|
||||
for i in range(min(5, len(raw_docs))):
|
||||
print(raw_docs[i].text)
|
||||
logging.info(f"Sample document {i}: {raw_docs[i]}")
|
||||
|
||||
# get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl
|
||||
# and send them to the server (provide user and name in form)
|
||||
@@ -170,6 +171,7 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp", r
|
||||
if not os.path.exists(full_path):
|
||||
os.makedirs(full_path)
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
logging.info(f"Remote job: {full_path}", extra={"user": user, "job": name_job, source_data: source_data})
|
||||
|
||||
remote_loader = RemoteCreator.create_loader(loader)
|
||||
raw_docs = remote_loader.load_data(source_data)
|
||||
@@ -202,23 +204,3 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp", r
|
||||
shutil.rmtree(full_path)
|
||||
|
||||
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
||||
|
||||
|
||||
def count_tokens_docs(docs):
|
||||
# Here we convert the docs list to a string and calculate the number of tokens the string represents.
|
||||
# docs_content = (" ".join(docs))
|
||||
docs_content = ""
|
||||
for doc in docs:
|
||||
docs_content += doc.page_content
|
||||
|
||||
tokens, total_price = num_tokens_from_string(string=docs_content, encoding_name="cl100k_base")
|
||||
# Here we print the number of tokens and the approx user cost with some visually appealing formatting.
|
||||
return tokens
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
||||
# Function to convert string to tokens and estimate user cost.
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
num_tokens = len(encoding.encode(string))
|
||||
total_price = (num_tokens / 1000) * 0.0004
|
||||
return num_tokens, total_price
|
||||
|
||||
Reference in New Issue
Block a user