Refactor conversationSlice.ts and conversationApi.ts

This commit is contained in:
Alex
2024-03-28 13:43:10 +00:00
parent e146922367
commit 97fabf51b8
4 changed files with 78 additions and 48 deletions

View File

@@ -26,6 +26,7 @@ db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
answer = Blueprint('answer', __name__)
if settings.LLM_NAME == "gpt4":
@@ -74,6 +75,12 @@ def run_async_chain(chain, question, chat_history):
result["answer"] = answer
return result
def get_data_from_api_key(api_key):
data = api_key_collection.find_one({"key": api_key})
if data is None:
return bad_request(401, "Invalid API key")
return data
def get_vectorstore(data):
if "active_docs" in data:
@@ -95,8 +102,8 @@ def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
if prompt_id == 'default':
prompt = chat_combine_template
@@ -182,10 +189,15 @@ def stream():
data = request.get_json()
# get parameter from url question
question = data["question"]
history = data["history"]
# history to json object from string
history = json.loads(history)
conversation_id = data["conversation_id"]
if "history" not in data:
history = []
else:
history = data["history"]
history = json.loads(history)
if "conversation_id" not in data:
conversation_id = None
else:
conversation_id = data["conversation_id"]
if 'prompt_id' in data:
prompt_id = data["prompt_id"]
else:
@@ -193,23 +205,18 @@ def stream():
# check if active_docs is set
if not api_key_set:
api_key = data["api_key"]
else:
api_key = settings.API_KEY
if not embeddings_key_set:
embeddings_key = data["embeddings_key"]
else:
embeddings_key = settings.EMBEDDINGS_KEY
if "active_docs" in data:
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
else:
vectorstore = ""
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
return Response(
complete_stream(question, docsearch,
chat_history=history, api_key=api_key,
chat_history=history,
prompt_id=prompt_id,
conversation_id=conversation_id), mimetype="text/event-stream"
)
@@ -219,20 +226,15 @@ def stream():
def api_answer():
data = request.get_json()
question = data["question"]
history = data["history"]
if "history" not in data:
history = []
else:
history = data["history"]
if "conversation_id" not in data:
conversation_id = None
else:
conversation_id = data["conversation_id"]
print("-" * 5)
if not api_key_set:
api_key = data["api_key"]
else:
api_key = settings.API_KEY
if not embeddings_key_set:
embeddings_key = data["embeddings_key"]
else:
embeddings_key = settings.EMBEDDINGS_KEY
if 'prompt_id' in data:
prompt_id = data["prompt_id"]
else:
@@ -250,13 +252,17 @@ def api_answer():
# use try and except to check for exception
try:
# check if the vectorstore is set
vectorstore = get_vectorstore(data)
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
else:
vectorstore = get_vectorstore(data)
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
@@ -348,18 +354,14 @@ def api_search():
# get parameter from url question
question = data["question"]
if not embeddings_key_set:
if "embeddings_key" in data:
embeddings_key = data["embeddings_key"]
else:
embeddings_key = settings.EMBEDDINGS_KEY
else:
embeddings_key = settings.EMBEDDINGS_KEY
if "active_docs" in data:
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = data_key["source"]
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
else:
vectorstore = ""
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key)
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
docs = docsearch.search(question, k=2)

View File

@@ -1,4 +1,5 @@
import os
import uuid
from flask import Blueprint, request, jsonify
import requests
from pymongo import MongoClient
@@ -16,6 +17,7 @@ conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
api_key_collection = db["api_keys"]
user = Blueprint('user', __name__)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -343,5 +345,41 @@ def update_prompt_name():
@user.route("/api/get_api_keys", methods=["GET"])
def get_api_keys():
user = "local"
keys = api_key_collection.find({"user": user})
list_keys = []
for key in keys:
list_keys.append({"id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], "source": key["source"]})
return jsonify(list_keys)
@user.route("/api/create_api_key", methods=["POST"])
def create_api_key():
data = request.get_json()
name = data["name"]
source = data["source"]
key = str(uuid.uuid4())
user = "local"
resp = api_key_collection.insert_one(
{
"name": name,
"key": key,
"source": source,
"user": user,
}
)
new_id = str(resp.inserted_id)
return {"id": new_id, "key": key}
@user.route("/api/delete_api_key", methods=["POST"])
def delete_api_key():
data = request.get_json()
id = data["id"]
api_key_collection.delete_one(
{
"_id": ObjectId(id),
}
)
return {"status": "ok"}