mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
Merge branch 'main' into feat/analytics-and-logs
This commit is contained in:
@@ -9,6 +9,7 @@ import traceback
|
||||
|
||||
from pymongo import MongoClient
|
||||
from bson.objectid import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
mongo = MongoClient(settings.MONGO_URI)
|
||||
db = mongo["docsgpt"]
|
||||
conversations_collection = db["conversations"]
|
||||
vectors_collection = db["vectors"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
api_key_collection = db["api_keys"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
@@ -37,9 +38,7 @@ if settings.MODEL_NAME: # in case there is particular model name configured
|
||||
gpt_model = settings.MODEL_NAME
|
||||
|
||||
# load the prompts
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
|
||||
@@ -75,35 +74,34 @@ def run_async_chain(chain, question, chat_history):
|
||||
|
||||
def get_data_from_api_key(api_key):
|
||||
data = api_key_collection.find_one({"key": api_key})
|
||||
|
||||
# # Raise custom exception if the API key is not found
|
||||
if data is None:
|
||||
raise Exception("Invalid API Key, please generate new key", 401)
|
||||
|
||||
if "retriever" not in data:
|
||||
data["retriever"] = None
|
||||
|
||||
if "source" in data and isinstance(data["source"], DBRef):
|
||||
source_doc = db.dereference(data["source"])
|
||||
data["source"] = str(source_doc["_id"])
|
||||
if "retriever" in source_doc:
|
||||
data["retriever"] = source_doc["retriever"]
|
||||
else:
|
||||
data["source"] = {}
|
||||
return data
|
||||
|
||||
|
||||
def get_vectorstore(data):
|
||||
if "active_docs" in data:
|
||||
if data["active_docs"].split("/")[0] == "default":
|
||||
vectorstore = ""
|
||||
elif data["active_docs"].split("/")[0] == "local":
|
||||
vectorstore = "indexes/" + data["active_docs"]
|
||||
else:
|
||||
vectorstore = "vectors/" + data["active_docs"]
|
||||
if data["active_docs"] == "default":
|
||||
vectorstore = ""
|
||||
else:
|
||||
vectorstore = ""
|
||||
vectorstore = os.path.join("application", vectorstore)
|
||||
return vectorstore
|
||||
def get_retriever(source_id: str):
|
||||
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if doc is None:
|
||||
raise Exception("Source document does not exist", 404)
|
||||
retriever_name = None if "retriever" not in doc else doc["retriever"]
|
||||
return retriever_name
|
||||
|
||||
|
||||
|
||||
def is_azure_configured():
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
|
||||
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||
@@ -263,33 +261,33 @@ def stream():
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
# check if active_docs or api_key is set
|
||||
## retriever can be "brave_search, duckduck_search or classic"
|
||||
retriever_name = data["retriever"] if "retriever" in data else "classic"
|
||||
|
||||
# check if active_docs or api_key is set
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
retriever_name = data_key["retriever"] or retriever_name
|
||||
user_api_key = data["api_key"]
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
source = {"active_docs" : data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
if source["active_docs"].split("/")[0] in ["default", "local"]:
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
|
||||
current_app.logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})}
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
question=question,
|
||||
@@ -369,6 +367,10 @@ def api_answer():
|
||||
else:
|
||||
token_limit = settings.DEFAULT_MAX_HISTORY
|
||||
|
||||
## retriever can be brave_search, duckduck_search or classic
|
||||
retriever_name = data["retriever"] if "retriever" in data else "classic"
|
||||
|
||||
# use try and except to check for exception
|
||||
try:
|
||||
# check if the vectorstore is set
|
||||
if "api_key" in data:
|
||||
@@ -376,15 +378,15 @@ def api_answer():
|
||||
chunks = int(data_key["chunks"])
|
||||
prompt_id = data_key["prompt_id"]
|
||||
source = {"active_docs": data_key["source"]}
|
||||
retriever_name = data_key["retriever"] or retriever_name
|
||||
user_api_key = data["api_key"]
|
||||
else:
|
||||
source = data
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs":data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
|
||||
if source["active_docs"].split("/")[0] in ["default", "local"]:
|
||||
retriever_name = "classic"
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
@@ -421,8 +423,8 @@ def api_answer():
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = save_conversation(
|
||||
conversation_id, question, response_full, source_log_docs, llm
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(conversation_id, question, response_full, source_log_docs, llm)
|
||||
)
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
@@ -459,19 +461,19 @@ def api_search():
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key["chunks"])
|
||||
source = {"active_docs": data_key["source"]}
|
||||
user_api_key = data["api_key"]
|
||||
source = {"active_docs":data_key["source"]}
|
||||
user_api_key = data_key["api_key"]
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
source = {"active_docs":data["active_docs"]}
|
||||
user_api_key = None
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
|
||||
if source["active_docs"].split("/")[0] in ["default", "local"]:
|
||||
retriever_name = "classic"
|
||||
if "retriever" in data:
|
||||
retriever_name = data["retriever"]
|
||||
else:
|
||||
retriever_name = source["active_docs"]
|
||||
retriever_name = "classic"
|
||||
if "token_limit" in data:
|
||||
token_limit = data["token_limit"]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user