diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 7eed8434..893edd3a 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -74,7 +74,7 @@ def run_async_chain(chain, question, chat_history): def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) - + # # Raise custom exception if the API key is not found if data is None: raise Exception("Invalid API Key, please generate new key", 401) @@ -129,10 +129,10 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm) "content": "Summarise following conversation in no more than 3 " "words, respond ONLY with the summary, use the same " "language as the system \n\nUser: " - +question - +"\n\n" - +"AI: " - +response, + + question + + "\n\n" + + "AI: " + + response, }, { "role": "user", @@ -172,7 +172,9 @@ def get_prompt(prompt_id): return prompt -def complete_stream(question, retriever, conversation_id, user_api_key): +def complete_stream( + question, retriever, conversation_id, user_api_key, isNoneDoc=False +): try: response_full = "" @@ -186,126 +188,136 @@ def complete_stream(question, retriever, conversation_id, user_api_key): elif "source" in line: source_log_docs.append(line["source"]) + if isNoneDoc: + for doc in source_log_docs: + doc["source"] = "None" + llm = LLMCreator.create_llm( settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) - if(user_api_key is None): + ) + if user_api_key is None: conversation_id = save_conversation( conversation_id, question, response_full, source_log_docs, llm ) # send data.type = "end" to indicate that the stream has ended as json data = json.dumps({"type": "id", "id": str(conversation_id)}) yield f"data: {data}\n\n" - + data = json.dumps({"type": "end"}) yield f"data: {data}\n\n" except Exception as e: print("\033[91merr", str(e), file=sys.stderr) - data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.", - "error_exception": str(e)}) + data = json.dumps( + { + "type": "error", + "error": "Please try again later. We apologize for any inconvenience.", + "error_exception": str(e), + } + ) yield f"data: {data}\n\n" - return + return + @answer.route("/stream", methods=["POST"]) def stream(): - try: - data = request.get_json() - # get parameter from url question - question = data["question"] - if "history" not in data: - history = [] - else: - history = data["history"] - history = json.loads(history) - if "conversation_id" not in data: - conversation_id = None - else: - conversation_id = data["conversation_id"] - if "prompt_id" in data: - prompt_id = data["prompt_id"] - else: - prompt_id = "default" - if "selectedDocs" in data and data["selectedDocs"] is None: - chunks = 0 - elif "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY + try: + data = request.get_json() + question = data["question"] + if "history" not in data: + history = [] + else: + history = data["history"] + history = json.loads(history) + if "conversation_id" not in data: + conversation_id = None + else: + conversation_id = data["conversation_id"] + if "prompt_id" in data: + prompt_id = data["prompt_id"] + else: + prompt_id = "default" + if "selectedDocs" in data and data["selectedDocs"] is None: + chunks = 0 + elif "chunks" in data: + chunks = int(data["chunks"]) + else: + chunks = 2 + if "token_limit" in data: + token_limit = data["token_limit"] + else: + token_limit = settings.DEFAULT_MAX_HISTORY - # check if active_docs or api_key is set + # check if active_docs or api_key is set - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} - user_api_key = data["api_key"] - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - user_api_key = None - else: - source = {} - user_api_key = None + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key["chunks"]) + prompt_id = data_key["prompt_id"] + source = {"active_docs": data_key["source"]} + user_api_key = data["api_key"] + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + user_api_key = None + else: + source = {} + user_api_key = None - if ( - source["active_docs"].split("/")[0] == "default" - or source["active_docs"].split("/")[0] == "local" - ): - retriever_name = "classic" - else: - retriever_name = source["active_docs"] + if source["active_docs"].split("/")[0] in ["default", "local"]: + retriever_name = "classic" + else: + retriever_name = source["active_docs"] - prompt = get_prompt(prompt_id) + prompt = get_prompt(prompt_id) - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - ) - - return Response( - complete_stream( + retriever = RetrieverCreator.create_retriever( + retriever_name, question=question, - retriever=retriever, - conversation_id=conversation_id, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, user_api_key=user_api_key, - ), - mimetype="text/event-stream", - ) - - except ValueError: - message = "Malformed request body" - print("\033[91merr", str(message), file=sys.stderr) - return Response( - error_stream_generate(message), - status=400, - mimetype="text/event-stream", - ) - except Exception as e: + ) + + return Response( + complete_stream( + question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + isNoneDoc=data.get("isNoneDoc"), + ), + mimetype="text/event-stream", + ) + + except ValueError: + message = "Malformed request body" + print("\033[91merr", str(message), file=sys.stderr) + return Response( + error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) + except Exception as e: print("\033[91merr", str(e), file=sys.stderr) message = e.args[0] status_code = 400 # # Custom exceptions with two arguments, index 1 as status code - if(len(e.args) >= 2): + if len(e.args) >= 2: status_code = e.args[1] return Response( - error_stream_generate(message), - status=status_code, - mimetype="text/event-stream", - ) + error_stream_generate(message), + status=status_code, + mimetype="text/event-stream", + ) + + def error_stream_generate(err_response): - data = json.dumps({"type": "error", "error":err_response}) - yield f"data: {data}\n\n" + data = json.dumps({"type": "error", "error": err_response}) + yield f"data: {data}\n\n" + @answer.route("/api/answer", methods=["POST"]) def api_answer(): @@ -346,10 +358,7 @@ def api_answer(): source = data user_api_key = None - if ( - source["active_docs"].split("/")[0] == "default" - or source["active_docs"].split("/")[0] == "local" - ): + if source["active_docs"].split("/")[0] in ["default", "local"]: retriever_name = "classic" else: retriever_name = source["active_docs"] @@ -375,6 +384,10 @@ def api_answer(): elif "answer" in line: response_full += line["answer"] + if data.get("isNoneDoc"): + for doc in source_log_docs: + doc["source"] = "None" + llm = LLMCreator.create_llm( settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key ) @@ -395,7 +408,6 @@ def api_answer(): @answer.route("/api/search", methods=["POST"]) def api_search(): data = request.get_json() - # get parameter from url question question = data["question"] if "chunks" in data: chunks = int(data["chunks"]) @@ -413,10 +425,7 @@ def api_search(): source = {} user_api_key = None - if ( - source["active_docs"].split("/")[0] == "default" - or source["active_docs"].split("/")[0] == "local" - ): + if source["active_docs"].split("/")[0] in ["default", "local"]: retriever_name = "classic" else: retriever_name = source["active_docs"] @@ -437,4 +446,9 @@ def api_search(): user_api_key=user_api_key, ) docs = retriever.search() + + if data.get("isNoneDoc"): + for doc in docs: + doc["source"] = "None" + return docs diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index eb83c5f4..4fc0f172 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -1,5 +1,6 @@ import { forwardRef, useState } from 'react'; import ReactMarkdown from 'react-markdown'; +import { useSelector } from 'react-redux'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism'; import remarkGfm from 'remark-gfm'; @@ -14,6 +15,7 @@ import Sources from '../assets/sources.svg'; import Avatar from '../components/Avatar'; import CopyButton from '../components/CopyButton'; import Sidebar from '../components/Sidebar'; +import { selectChunks } from '../preferences/preferenceSlice'; import classes from './ConversationBubble.module.css'; import { FEEDBACK, MESSAGE_TYPE } from './conversationModels'; @@ -34,6 +36,7 @@ const ConversationBubble = forwardRef< { message, type, className, feedback, handleFeedback, sources, retryBtn }, ref, ) { + const chunks = useSelector(selectChunks); const [isLikeHovered, setIsLikeHovered] = useState(false); const [isDislikeHovered, setIsDislikeHovered] = useState(false); const [isLikeClicked, setIsLikeClicked] = useState(false); @@ -59,12 +62,17 @@ const ConversationBubble = forwardRef< ref={ref} className={`flex flex-wrap self-start ${className} group flex-col dark:text-bright-gray`} > - {DisableSourceFE || type === 'ERROR' ? null : !sources || - sources.length === 0 ? ( + {DisableSourceFE || + type === 'ERROR' || + chunks === '0' || + sources?.length === 0 || + sources?.some( + (source) => source.source === 'None', + ) ? null : !sources ? (
Sources
+Sources
response.json())
.then((data) => {
diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts
index 75c457a9..23962a28 100644
--- a/frontend/src/conversation/conversationSlice.ts
+++ b/frontend/src/conversation/conversationSlice.ts
@@ -62,7 +62,7 @@ export const fetchAnswer = createAsyncThunk