diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 4c393714..abb2f67c 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -26,6 +26,7 @@ db = mongo["docsgpt"] conversations_collection = db["conversations"] vectors_collection = db["vectors"] prompts_collection = db["prompts"] +api_key_collection = db["api_keys"] answer = Blueprint('answer', __name__) if settings.LLM_NAME == "gpt4": @@ -74,6 +75,12 @@ def run_async_chain(chain, question, chat_history): result["answer"] = answer return result +def get_data_from_api_key(api_key): + data = api_key_collection.find_one({"key": api_key}) + if data is None: + return bad_request(401, "Invalid API key") + return data + def get_vectorstore(data): if "active_docs" in data: @@ -95,8 +102,8 @@ def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME -def complete_stream(question, docsearch, chat_history, api_key, prompt_id, conversation_id): - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key) +def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id): + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) if prompt_id == 'default': prompt = chat_combine_template @@ -182,10 +189,15 @@ def stream(): data = request.get_json() # get parameter from url question question = data["question"] - history = data["history"] - # history to json object from string - history = json.loads(history) - conversation_id = data["conversation_id"] + if "history" not in data: + history = [] + else: + history = data["history"] + history = json.loads(history) + if "conversation_id" not in data: + conversation_id = None + else: + conversation_id = data["conversation_id"] if 'prompt_id' in data: prompt_id = data["prompt_id"] else: @@ -193,23 +205,18 @@ def stream(): # check if active_docs is set - if not api_key_set: - api_key = data["api_key"] - else: - api_key = settings.API_KEY - if not embeddings_key_set: - embeddings_key = data["embeddings_key"] - else: - embeddings_key = settings.EMBEDDINGS_KEY - if "active_docs" in data: + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + elif "active_docs" in data: vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) else: vectorstore = "" - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key) + docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) return Response( complete_stream(question, docsearch, - chat_history=history, api_key=api_key, + chat_history=history, prompt_id=prompt_id, conversation_id=conversation_id), mimetype="text/event-stream" ) @@ -219,20 +226,15 @@ def stream(): def api_answer(): data = request.get_json() question = data["question"] - history = data["history"] + if "history" not in data: + history = [] + else: + history = data["history"] if "conversation_id" not in data: conversation_id = None else: conversation_id = data["conversation_id"] print("-" * 5) - if not api_key_set: - api_key = data["api_key"] - else: - api_key = settings.API_KEY - if not embeddings_key_set: - embeddings_key = data["embeddings_key"] - else: - embeddings_key = settings.EMBEDDINGS_KEY if 'prompt_id' in data: prompt_id = data["prompt_id"] else: @@ -250,13 +252,17 @@ def api_answer(): # use try and except to check for exception try: # check if the vectorstore is set - vectorstore = get_vectorstore(data) + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + vectorstore = get_vectorstore({"active_docs": data_key["source"]}) + else: + vectorstore = get_vectorstore(data) # loading the index and the store and the prompt template # Note if you have used other embeddings than OpenAI, you need to change the embeddings - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key) + docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) - llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=api_key) + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) @@ -348,18 +354,14 @@ def api_search(): # get parameter from url question question = data["question"] - if not embeddings_key_set: - if "embeddings_key" in data: - embeddings_key = data["embeddings_key"] - else: - embeddings_key = settings.EMBEDDINGS_KEY - else: - embeddings_key = settings.EMBEDDINGS_KEY - if "active_docs" in data: + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + vectorstore = data_key["source"] + elif "active_docs" in data: vectorstore = get_vectorstore({"active_docs": data["active_docs"]}) else: vectorstore = "" - docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, embeddings_key) + docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY) docs = docsearch.search(question, k=2) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 1779472b..239278b9 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,4 +1,5 @@ import os +import uuid from flask import Blueprint, request, jsonify import requests from pymongo import MongoClient @@ -16,6 +17,7 @@ conversations_collection = db["conversations"] vectors_collection = db["vectors"] prompts_collection = db["prompts"] feedback_collection = db["feedback"] +api_key_collection = db["api_keys"] user = Blueprint('user', __name__) current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -343,5 +345,41 @@ def update_prompt_name(): +@user.route("/api/get_api_keys", methods=["GET"]) +def get_api_keys(): + user = "local" + keys = api_key_collection.find({"user": user}) + list_keys = [] + for key in keys: + list_keys.append({"id": str(key["_id"]), "name": key["name"], "key": key["key"][:4] + "..." + key["key"][-4:], "source": key["source"]}) + return jsonify(list_keys) +@user.route("/api/create_api_key", methods=["POST"]) +def create_api_key(): + data = request.get_json() + name = data["name"] + source = data["source"] + key = str(uuid.uuid4()) + user = "local" + resp = api_key_collection.insert_one( + { + "name": name, + "key": key, + "source": source, + "user": user, + } + ) + new_id = str(resp.inserted_id) + return {"id": new_id, "key": key} + +@user.route("/api/delete_api_key", methods=["POST"]) +def delete_api_key(): + data = request.get_json() + id = data["id"] + api_key_collection.delete_one( + { + "_id": ObjectId(id), + } + ) + return {"status": "ok"} diff --git a/frontend/src/conversation/conversationApi.ts b/frontend/src/conversation/conversationApi.ts index 8293df1b..4f789f15 100644 --- a/frontend/src/conversation/conversationApi.ts +++ b/frontend/src/conversation/conversationApi.ts @@ -6,7 +6,6 @@ const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com'; export function fetchAnswerApi( question: string, signal: AbortSignal, - apiKey: string, selectedDocs: Doc, history: Array = [], conversationId: string | null, @@ -59,8 +58,6 @@ export function fetchAnswerApi( }, body: JSON.stringify({ question: question, - api_key: apiKey, - embeddings_key: apiKey, history: history, active_docs: docPath, conversation_id: conversationId, @@ -90,7 +87,6 @@ export function fetchAnswerApi( export function fetchAnswerSteaming( question: string, signal: AbortSignal, - apiKey: string, selectedDocs: Doc, history: Array = [], conversationId: string | null, @@ -124,8 +120,6 @@ export function fetchAnswerSteaming( return new Promise((resolve, reject) => { const body = { question: question, - api_key: apiKey, - embeddings_key: apiKey, active_docs: docPath, history: JSON.stringify(history), conversation_id: conversationId, @@ -188,7 +182,6 @@ export function fetchAnswerSteaming( } export function searchEndpoint( question: string, - apiKey: string, selectedDocs: Doc, conversation_id: string | null, history: Array = [], diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 35aadd9a..ed4e41e4 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -23,7 +23,6 @@ export const fetchAnswer = createAsyncThunk( await fetchAnswerSteaming( question, signal, - state.preference.apiKey, state.preference.selectedDocs!, state.conversation.queries, state.conversation.conversationId, @@ -47,7 +46,6 @@ export const fetchAnswer = createAsyncThunk( searchEndpoint( //search for sources post streaming question, - state.preference.apiKey, state.preference.selectedDocs!, state.conversation.conversationId, state.conversation.queries, @@ -81,7 +79,6 @@ export const fetchAnswer = createAsyncThunk( const answer = await fetchAnswerApi( question, signal, - state.preference.apiKey, state.preference.selectedDocs!, state.conversation.queries, state.conversation.conversationId,