diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index a2e32eba..34e6abca 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
     if data is None:
         raise Exception("Invalid API Key, please generate new key", 401)
 
-    if "retriever" not in data:
-        data["retriever"] = None
-
     if "source" in data and isinstance(data["source"], DBRef):
         source_doc = db.dereference(data["source"])
         data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
     )
 
 
-def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
+def save_conversation(
+    conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
+):
     if conversation_id is not None and index is not None:
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     f"queries.{index}.prompt": question,
                     f"queries.{index}.response": response,
                     f"queries.{index}.sources": source_log_docs,
+                    f"queries.{index}.tool_calls": tool_calls,
                 }
-            }
+            },
         )
         ##remove following queries from the array
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
-            {
-                "$push":{
-                    "queries":{
-                        "$each":[],
-                        "$slice":index+1
-                    }
-                }
-            }
+            {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
         )
     elif conversation_id is not None and conversation_id != "None":
         conversations_collection.update_one(
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     "prompt": question,
                     "response": response,
                     "sources": source_log_docs,
+                    "tool_calls": tool_calls,
                 }
             }
         },
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                 "role": "user",
                 "content": "Summarise following conversation in no more than 3 words, "
                 "respond ONLY with the summary, use the same language as the "
-                "system \n\nUser: "
-                + question
-                + "\n\n"
-                + "AI: "
-                + response,
+                "system \n\nUser: " + question + "\n\n" + "AI: " + response,
             },
         ]
@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     "prompt": question,
                     "response": response,
                     "sources": source_log_docs,
+                    "tool_calls": tool_calls,
                 }
             ],
         }
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):
 
 
 def complete_stream(
-    question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
+    question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
 ):
-    try:
+    try:
         response_full = ""
         source_log_docs = []
+        tool_calls = []
         answer = retriever.gen()
         sources = retriever.search()
         for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
         if len(sources) > 0:
             data = json.dumps({"type": "source", "source": sources})
             yield f"data: {data}\n\n"
+
         for line in answer:
             if "answer" in line:
                 response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
                 yield f"data: {data}\n\n"
             elif "source" in line:
                 source_log_docs.append(line["source"])
+            elif "tool_calls" in line:
+                tool_calls = line["tool_calls"]
+                data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
+                yield f"data: {data}\n\n"
 
         if isNoneDoc:
             for doc in source_log_docs:
@@ -239,7 +236,13 @@ def complete_stream(
             )
         if user_api_key is None:
             conversation_id = save_conversation(
-                conversation_id, question, response_full, source_log_docs, llm,index
+                conversation_id,
+                question,
+                response_full,
+                source_log_docs,
+                tool_calls,
+                llm,
+                index,
             )
         # send data.type = "end" to indicate that the stream has ended as json
         data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -303,7 +306,7 @@ class Stream(Resource):
             "isNoneDoc": fields.Boolean(
                 required=False, description="Flag indicating if no document is used"
             ),
-            "index":fields.Integer(
+            "index": fields.Integer(
                 required=False, description="The position where query is to be updated"
             ),
         },
@@ -315,22 +318,24 @@ class Stream(Resource):
         data = request.get_json()
         required_fields = ["question"]
         if "index" in data:
-            required_fields = ["question","conversation_id"]
+            required_fields = ["question", "conversation_id"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields
 
         try:
             question = data["question"]
-            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
+            history = limit_chat_history(
+                json.loads(data.get("history", [])), gpt_model=gpt_model
+            )
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
-
-            index=data.get("index",None)
+
+            index = data.get("index", None)
             chunks = int(data.get("chunks", 2))
             token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
             retriever_name = data.get("retriever", "classic")
-
+
             if "api_key" in data:
                 data_key = get_data_from_api_key(data["api_key"])
                 chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
                 gpt_model=gpt_model,
                 user_api_key=user_api_key,
             )
-
+
             return Response(
                 complete_stream(
                     question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
             )
             status_code = 400
         return Response(
-            error_stream_generate('Unknown error occurred'),
+            error_stream_generate("Unknown error occurred"),
             status=status_code,
             mimetype="text/event-stream",
         )
@@ -442,14 +447,16 @@ class Answer(Resource):
     @api.doc(description="Provide an answer based on the question and retriever")
     def post(self):
         data = request.get_json()
-        required_fields = ["question"]
+        required_fields = ["question"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields
 
         try:
             question = data["question"]
-            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
+            history = limit_chat_history(
+                json.loads(data.get("history", [])), gpt_model=gpt_model
+            )
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
             chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
                 user_api_key=user_api_key,
             )
 
-            source_log_docs = []
             response_full = ""
+            source_log_docs = []
+            tool_calls = []
             for line in retriever.gen():
                 if "source" in line:
                     source_log_docs.append(line["source"])
                 elif "answer" in line:
                     response_full += line["answer"]
+                elif "tool_calls" in line:
+                    tool_calls.append(line["tool_calls"])
 
             if data.get("isNoneDoc"):
                 for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
             result = {"answer": response_full, "sources": source_log_docs}
             result["conversation_id"] = str(
                 save_conversation(
-                    conversation_id, question, response_full, source_log_docs, llm
+                    conversation_id,
+                    question,
+                    response_full,
+                    source_log_docs,
+                    tool_calls,
+                    llm,
                 )
             )
             retriever_params = retriever.get_params()
diff --git a/application/core/settings.py b/application/core/settings.py
index 0bace432..5842da33 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -1,26 +1,37 @@
+import os
 from pathlib import Path
 from typing import Optional
-import os
 
 from pydantic_settings import BaseSettings
 
-current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+current_dir = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
 
 
 class Settings(BaseSettings):
     LLM_NAME: str = "docsgpt"
-    MODEL_NAME: Optional[str] = None  # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
+    MODEL_NAME: Optional[str] = (
+        None  # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
+    )
     EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
-    MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
+    MODEL_TOKEN_LIMITS: dict = {
+        "gpt-4o-mini": 128000,
+        "gpt-3.5-turbo": 4096,
+        "claude-2": 1e5,
+        "gemini-2.0-flash-exp": 1e6,
+    }
     UPLOAD_FOLDER: str = "inputs"
     PARSE_PDF_AS_IMAGE: bool = False
-    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
-    RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
+    VECTOR_STORE: str = (
+        "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
+    )
+    RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
 
     # LLM Cache
     CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -28,12 +39,18 @@ class Settings(BaseSettings):
 
     API_URL: str = "http://localhost:7091"  # backend url for celery worker
     API_KEY: Optional[str] = None  # LLM api key
-    EMBEDDINGS_KEY: Optional[str] = None  # api key for embeddings (if using openai, just copy API_KEY)
+    EMBEDDINGS_KEY: Optional[str] = (
+        None  # api key for embeddings (if using openai, just copy API_KEY)
+    )
     OPENAI_API_BASE: Optional[str] = None  # azure openai api base url
     OPENAI_API_VERSION: Optional[str] = None  # azure openai api version
     AZURE_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for answering
-    AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for embeddings
-    OPENAI_BASE_URL: Optional[str] = None  # openai base url for open ai compatable models
+    AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
+        None  # azure deployment name for embeddings
+    )
+    OPENAI_BASE_URL: Optional[str] = (
+        None  # openai base url for open ai compatable models
+    )
 
     # elasticsearch
     ELASTIC_CLOUD_ID: Optional[str] = None  # cloud id for elasticsearch
@@ -68,12 +85,14 @@ class Settings(BaseSettings):
 
     # Milvus vectorstore config
     MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
-    MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
+    MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
     MILVUS_TOKEN: Optional[str] = ""
 
     # LanceDB vectorstore config
     LANCEDB_PATH: str = "/tmp/lancedb"  # Path where LanceDB stores its local data
-    LANCEDB_TABLE_NAME: Optional[str] = "docsgpts"  # Name of the table to use for storing vectors
+    LANCEDB_TABLE_NAME: Optional[str] = (
+        "docsgpts"  # Name of the table to use for storing vectors
+    )
 
     BRAVE_SEARCH_API_KEY: Optional[str] = None
     FLASK_DEBUG_MODE: bool = False
diff --git a/application/llm/openai.py b/application/llm/openai.py
index b507a1da..36861584 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -1,3 +1,5 @@
+import json
+
 from application.core.settings import settings
 
 from application.llm.base import BaseLLM
@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
         self.api_key = api_key
         self.user_api_key = user_api_key
 
+    def _clean_messages_openai(self, messages):
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "model":
+                role = "assistant"
+
+            if role and content is not None:
+                if isinstance(content, str):
+                    cleaned_messages.append({"role": role, "content": content})
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            cleaned_messages.append(
+                                {"role": role, "content": item["text"]}
+                            )
+                        elif "function_call" in item:
+                            tool_call = {
+                                "id": item["function_call"]["call_id"],
+                                "type": "function",
+                                "function": {
+                                    "name": item["function_call"]["name"],
+                                    "arguments": json.dumps(
+                                        item["function_call"]["args"]
+                                    ),
+                                },
+                            }
+                            cleaned_messages.append(
+                                {
+                                    "role": "assistant",
+                                    "content": None,
+                                    "tool_calls": [tool_call],
+                                }
+                            )
+                        elif "function_response" in item:
+                            cleaned_messages.append(
+                                {
+                                    "role": "tool",
+                                    "tool_call_id": item["function_response"][
+                                        "call_id"
+                                    ],
+                                    "content": json.dumps(
+                                        item["function_response"]["response"]["result"]
+                                    ),
+                                }
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format: {item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+        return cleaned_messages
+
     def _raw_gen(
         self,
         baseself,
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
         engine=settings.AZURE_DEPLOYMENT_NAME,
         **kwargs,
     ):
+        messages = self._clean_messages_openai(messages)
+        print(messages)
         if tools:
             response = self.client.chat.completions.create(
-                model=model, messages=messages, stream=stream, tools=tools, **kwargs
+                model=model,
+                messages=messages,
+                stream=stream,
+                tools=tools,
+                **kwargs,
             )
             return response.choices[0]
         else:
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
         engine=settings.AZURE_DEPLOYMENT_NAME,
         **kwargs,
     ):
+        messages = self._clean_messages_openai(messages)
         response = self.client.chat.completions.create(
             model=model, messages=messages, stream=stream, **kwargs
         )
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index d4c7d755..ca40f966 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -1,3 +1,5 @@
+import uuid
+
 from application.core.settings import settings
 from application.retriever.base import BaseRetriever
 from application.tools.agent import Agent
@@ -35,6 +37,12 @@ class ClassicRAG(BaseRetriever):
             )
         )
         self.user_api_key = user_api_key
+        self.agent = Agent(
+            llm_name=settings.LLM_NAME,
+            gpt_model=self.gpt_model,
+            api_key=settings.API_KEY,
+            user_api_key=self.user_api_key,
+        )
 
     def _get_data(self):
         if self.chunks == 0:
@@ -78,21 +86,42 @@ class ClassicRAG(BaseRetriever):
                     messages_combine.append(
                         {"role": "assistant", "content": i["response"]}
                     )
+                if "tool_calls" in i:
+                    for tool_call in i["tool_calls"]:
+                        call_id = tool_call.get("call_id")
+                        if call_id is None or call_id == "None":
+                            call_id = str(uuid.uuid4())
+
+                        function_call_dict = {
+                            "function_call": {
+                                "name": tool_call.get("action_name"),
+                                "args": tool_call.get("arguments"),
+                                "call_id": call_id,
+                            }
+                        }
+                        function_response_dict = {
+                            "function_response": {
+                                "name": tool_call.get("action_name"),
+                                "response": {"result": tool_call.get("result")},
+                                "call_id": call_id,
+                            }
+                        }
+
+                        messages_combine.append(
+                            {"role": "assistant", "content": [function_call_dict]}
+                        )
+                        messages_combine.append(
+                            {"role": "tool", "content": [function_response_dict]}
+                        )
+
         messages_combine.append({"role": "user", "content": self.question})
 
-        # llm = LLMCreator.create_llm(
-        #     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
-        # )
-        # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
-        agent = Agent(
-            llm_name=settings.LLM_NAME,
-            gpt_model=self.gpt_model,
-            api_key=settings.API_KEY,
-            user_api_key=self.user_api_key,
-        )
-        completion = agent.gen(messages_combine)
+        completion = self.agent.gen(messages_combine)
+
         for line in completion:
             yield {"answer": str(line)}
 
+        yield {"tool_calls": self.agent.tool_calls.copy()}
+
     def search(self):
         return self._get_data()
diff --git a/application/tools/agent.py b/application/tools/agent.py
index d0743cd9..10798862 100644
--- a/application/tools/agent.py
+++ b/application/tools/agent.py
@@ -16,6 +16,7 @@ class Agent:
         # Static tool configuration (to be replaced later)
         self.tools = []
         self.tool_config = {}
+        self.tool_calls = []
 
     def _get_user_tools(self, user="local"):
         mongo = MongoDB.get_client()
@@ -123,6 +124,16 @@ class Agent:
         print(f"Executing tool: {action_name} with args: {call_args}")
         result = tool.execute_action(action_name, **parameters)
         call_id = getattr(call, "id", None)
+
+        tool_call_data = {
+            "tool_name": tool_data["name"],
+            "call_id": call_id if call_id is not None else "None",
+            "action_name": f"{action_name}_{tool_id}",
+            "arguments": call_args,
+            "result": result,
+        }
+        self.tool_calls.append(tool_call_data)
+
         return result, call_id
 
     def _simple_tool_agent(self, messages):
@@ -134,7 +145,11 @@
         if isinstance(resp, str):
             yield resp
             return
-        if hasattr(resp, "message") and hasattr(resp.message, "content"):
+        if (
+            hasattr(resp, "message")
+            and hasattr(resp.message, "content")
+            and resp.message.content is not None
+        ):
             yield resp.message.content
             return
 
@@ -142,7 +157,11 @@
 
         if isinstance(resp, str):
             yield resp
-        elif hasattr(resp, "message") and hasattr(resp.message, "content"):
+        elif (
+            hasattr(resp, "message")
+            and hasattr(resp.message, "content")
+            and resp.message.content is not None
+        ):
             yield resp.message.content
         else:
             completion = self.llm.gen_stream(
@@ -154,6 +173,7 @@
         return
 
     def gen(self, messages):
+        self.tool_calls = []
         if self.llm.supports_tools():
             resp = self._simple_tool_agent(messages)
             for line in resp:
diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py
index cc7494c0..334d2c4c 100644
--- a/application/tools/llm_handler.py
+++ b/application/tools/llm_handler.py
@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
                     tool_response, call_id = agent._execute_tool_action(
                         tools_dict, call
                     )
-                    messages.append(
-                        {
-                            "role": "tool",
-                            "content": str(tool_response),
-                            "tool_call_id": call_id,
+                    function_call_dict = {
+                        "function_call": {
+                            "name": call.function.name,
+                            "args": call.function.arguments,
+                            "call_id": call_id,
                         }
-                    )
+                    }
+                    function_response_dict = {
+                        "function_response": {
+                            "name": call.function.name,
+                            "response": {"result": tool_response},
+                            "call_id": call_id,
+                        }
+                    }
+
+                    messages.append(
+                        {"role": "assistant", "content": [function_call_dict]}
                     )
+                    messages.append(
+                        {"role": "tool", "content": [function_response_dict]}
+                    )
+
                 except Exception as e:
                     messages.append(
                         {
diff --git a/application/utils.py b/application/utils.py
index 38bbc622..6d47d31a 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -1,8 +1,9 @@
-import tiktoken
 import hashlib
-from flask import jsonify, make_response
 import re
 
+import tiktoken
+from flask import jsonify, make_response
+
 _encoding = None
 
 
@@ -22,6 +23,7 @@ def num_tokens_from_string(string: str) -> int:
     else:
         return 0
 
+
 def num_tokens_from_object_or_list(thing):
     if isinstance(thing, list):
         return sum([num_tokens_from_object_or_list(x) for x in thing])
@@ -32,6 +34,7 @@ def num_tokens_from_object_or_list(thing):
     else:
         return 0
 
+
 def count_tokens_docs(docs):
     docs_content = ""
     for doc in docs:
@@ -59,6 +62,7 @@ def check_required_fields(data, required_fields):
 def get_hash(data):
     return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
 
+
 def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
     """
     Limits chat history based on token count.
@@ -67,38 +71,41 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
     from application.core.settings import settings
 
     max_token_limit = (
-        max_token_limit
-        if max_token_limit and
-        max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
-            gpt_model, settings.DEFAULT_MAX_HISTORY
-        )
-        else settings.MODEL_TOKEN_LIMITS.get(
-            gpt_model, settings.DEFAULT_MAX_HISTORY
-        )
-    )
-
+        max_token_limit
+        if max_token_limit
+        and max_token_limit
+        < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+        else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+    )
     if not history:
         return []
-
-    tokens_current_history = 0
+
     trimmed_history = []
-
+    tokens_current_history = 0
+
     for message in reversed(history):
+        tokens_batch = 0
         if "prompt" in message and "response" in message:
-            tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
-                message["response"]
-            )
-            if tokens_current_history + tokens_batch < max_token_limit:
-                tokens_current_history += tokens_batch
-                trimmed_history.insert(0, message)
-            else:
-                break
+            tokens_batch += num_tokens_from_string(message["prompt"])
+            tokens_batch += num_tokens_from_string(message["response"])
+
+        if "tool_calls" in message:
+            for tool_call in message["tool_calls"]:
+                tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
+                tokens_batch += num_tokens_from_string(tool_call_string)
+
+        if tokens_current_history + tokens_batch < max_token_limit:
+            tokens_current_history += tokens_batch
+            trimmed_history.insert(0, message)
+        else:
+            break
 
     return trimmed_history
 
+
 def validate_function_name(function_name):
     """Validates if a function name matches the allowed pattern."""
     if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
         return False
-    return True
\ No newline at end of file
+    return True
diff --git a/frontend/src/assets/chevron-down.svg b/frontend/src/assets/chevron-down.svg
new file mode 100644
index 00000000..b2605251
--- /dev/null
+++ b/frontend/src/assets/chevron-down.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/frontend/src/components/Accordion.tsx b/frontend/src/components/Accordion.tsx
new file mode 100644
index 00000000..ec0a81d7
--- /dev/null
+++ b/frontend/src/components/Accordion.tsx
@@ -0,0 +1,59 @@
+import React, { useRef, useState } from 'react';
+
+import ChevronDown from '../assets/chevron-down.svg';
+
+type AccordionProps = {
+  title: string;
+  children: React.ReactNode;
+  className?: string;
+  titleClassName?: string;
+  contentClassName?: string;
+  open?: boolean;
+};
+
+export default function Accordion({
+  title,
+  children,
+  className = '',
+  titleClassName = '',
+  contentClassName = '',
+  open: initialOpen = false,
+}: AccordionProps) {
+  const contentRef = useRef(null);
+  const [isOpen, setIsOpen] = useState(initialOpen);
+
+  const accordionContentStyle = {
+    height: isOpen ? 'auto' : '0px',
+    transition: 'height 0.3s ease-in-out, opacity 0.3s ease-in-out',
+    overflow: 'hidden',
+  } as React.CSSProperties;
+
+  const toggleAccordion = () => {
+    setIsOpen(!isOpen);
+  };
+  return (
+
+
+
+      {children}
+
+
+  );
+}
diff --git a/frontend/src/components/CopyButton.tsx b/frontend/src/components/CopyButton.tsx
index e13f9133..f0559f52 100644
--- a/frontend/src/components/CopyButton.tsx
+++ b/frontend/src/components/CopyButton.tsx
@@ -40,7 +40,7 @@ export default function CoppyButton({
         />
       ) : (
-
+
           onClick={() => {
             handleCopyClick(text);
           }}
diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx
index f4511fc3..29438f93 100644
--- a/frontend/src/conversation/Conversation.tsx
+++ b/frontend/src/conversation/Conversation.tsx
@@ -225,6 +225,7 @@ export default function Conversation() {
               message={query.response}
               type={'ANSWER'}
               sources={query.sources}
+              toolCalls={query.tool_calls}
               feedback={query.feedback}
               handleFeedback={(feedback: FEEDBACK) =>
                 handleFeedback(query, feedback, index)
diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx
index 668b0935..ab70584f 100644
--- a/frontend/src/conversation/ConversationBubble.tsx
+++ b/frontend/src/conversation/ConversationBubble.tsx
@@ -1,6 +1,7 @@
 import 'katex/dist/katex.min.css';
 
 import { forwardRef, useRef, useState } from 'react';
+import { useTranslation } from 'react-i18next';
 import ReactMarkdown from 'react-markdown';
 import { useSelector } from 'react-redux';
 import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
@@ -8,27 +9,29 @@ import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism';
 import rehypeKatex from 'rehype-katex';
 import remarkGfm from 'remark-gfm';
 import remarkMath from 'remark-math';
-import { useTranslation } from 'react-i18next';
 
 import DocsGPT3 from '../assets/cute_docsgpt3.svg';
+import ChevronDown from '../assets/chevron-down.svg';
 import Dislike from '../assets/dislike.svg?react';
 import Document from '../assets/document.svg';
+import Edit from '../assets/edit.svg';
 import Like from '../assets/like.svg?react';
 import Link from '../assets/link.svg';
 import Sources from '../assets/sources.svg';
-import Edit from '../assets/edit.svg';
 import UserIcon from '../assets/user.png';
+import Accordion from '../components/Accordion';
 import Avatar from '../components/Avatar';
 import CopyButton from '../components/CopyButton';
 import Sidebar from '../components/Sidebar';
 import SpeakButton from '../components/TextToSpeechButton';
+import { useOutsideAlerter } from '../hooks';
 import {
   selectChunks,
   selectSelectedDocs,
 } from '../preferences/preferenceSlice';
 import classes from './ConversationBubble.module.css';
 import { FEEDBACK, MESSAGE_TYPE } from './conversationModels';
-import { useOutsideAlerter } from '../hooks';
+import { ToolCallsType } from './types';
 
 const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false;
@@ -41,6 +44,7 @@
     feedback?: FEEDBACK;
     handleFeedback?: (feedback: FEEDBACK) => void;
     sources?: { title: string; text: string; source: string }[];
+    toolCalls?: ToolCallsType[];
     retryBtn?: React.ReactElement;
     questionNumber?: number;
     handleUpdatedQuestionSubmission?: (
@@ -57,6 +61,7 @@
     feedback,
     handleFeedback,
     sources,
+    toolCalls,
     retryBtn,
     questionNumber,
     handleUpdatedQuestionSubmission,
@@ -307,6 +312,9 @@
               )
             )}
+          {toolCalls && toolCalls.length > 0 && (
+
+          )}
+
+
+            } />
+
+        {isToolCallsOpen && (
+
+            {toolCalls.map((toolCall, index) => (
+
+
+
+                      Arguments
+                      {' '}
+
+
+
+                        {JSON.stringify(toolCall.arguments, null, 2)}
+
+
+
+
+                      Response
+                      {' '}
+
+
+
+                        {JSON.stringify(toolCall.result, null, 2)}
+
+
+
+                } />
+            ))}
+
+        )}
+  );
+}
diff --git a/frontend/src/conversation/SharedConversation.tsx b/frontend/src/conversation/SharedConversation.tsx
index c6f73351..b231ed1e 100644
--- a/frontend/src/conversation/SharedConversation.tsx
+++ b/frontend/src/conversation/SharedConversation.tsx
@@ -121,6 +121,7 @@ export const SharedConversation = () => {
                 message={query.response}
                 type={'ANSWER'}
                 sources={query.sources ?? []}
+                toolCalls={query.tool_calls}
               ></ConversationBubble>
             );
           } else if (query.error) {
diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts
index 1abdb74d..ddd84bd3 100644
--- a/frontend/src/conversation/conversationHandlers.ts
+++ b/frontend/src/conversation/conversationHandlers.ts
@@ -1,6 +1,7 @@
 import conversationService from '../api/services/conversationService';
 import { Doc } from '../models/misc';
 import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
+import { ToolCallsType } from './types';
 
 export function handleFetchAnswer(
   question: string,
@@ -16,6 +17,7 @@ export function handleFetchAnswer(
       result: any;
       answer: any;
       sources: any;
+      toolCalls: ToolCallsType[];
       conversationId: any;
       query: string;
     }
@@ -23,13 +25,18 @@ export function handleFetchAnswer(
       result: any;
       answer: any;
      sources: any;
+      toolCalls: ToolCallsType[];
       query: string;
       conversationId: any;
       title: any;
     }
 > {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   const payload: RetrievalPayload = {
     question: question,
@@ -60,6 +67,7 @@
         query: question,
         result,
         sources: data.sources,
+        toolCalls: data.tool_calls,
         conversationId: data.conversation_id,
       };
     });
@@ -78,7 +86,11 @@ export function handleFetchAnswerSteaming(
   indx?: number,
 ): Promise<Answer> {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   const payload: RetrievalPayload = {
     question: question,
@@ -155,7 +167,11 @@ export function handleSearch(
   token_limit: number,
 ) {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   const payload: RetrievalPayload = {
     question: question,
@@ -183,7 +199,11 @@ export function handleSearchViaApiKey(
   history: Array<any> = [],
 ) {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   return conversationService
     .search({
@@ -230,7 +250,11 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations
   onEvent: (event: MessageEvent) => void,
 ): Promise<Answer> {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
 
   return new Promise((resolve, reject) => {
@@ -330,6 +354,7 @@ export function handleFetchSharedAnswer(
         query: question,
         result,
         sources: data.sources,
+        toolCalls: data.tool_calls,
       };
     });
 }
diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts
index a499deec..c747c55c 100644
--- a/frontend/src/conversation/conversationModels.ts
+++ b/frontend/src/conversation/conversationModels.ts
@@ -1,3 +1,5 @@
+import { ToolCallsType } from './types';
+
 export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
 export type Status = 'idle' | 'loading' | 'failed';
 export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
@@ -17,9 +19,10 @@ export interface Answer {
   answer: string;
   query: string;
   result: string;
-  sources: { title: string; text: string; source: string }[];
   conversationId: string | null;
   title: string | null;
+  sources: { title: string; text: string; source: string }[];
+  tool_calls: ToolCallsType[];
 }
 
 export interface Query {
@@ -27,10 +30,12 @@ export interface Query {
   response?: string;
   feedback?: FEEDBACK;
   error?: string;
-  sources?: { title: string; text: string; source: string }[];
   conversationId?: string | null;
   title?: string | null;
+  sources?: { title: string; text: string; source: string }[];
+  tool_calls?: ToolCallsType[];
 }
+
 export interface RetrievalPayload {
   question: string;
   active_docs?: string;
diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts
index 69f81e21..4473d3a9 100644
--- a/frontend/src/conversation/conversationSlice.ts
+++ b/frontend/src/conversation/conversationSlice.ts
@@ -82,6 +82,13 @@ export const fetchAnswer = createAsyncThunk<
               query: { sources: data.source ?? [] },
             }),
           );
+        } else if (data.type === 'tool_calls') {
+          dispatch(
+            updateToolCalls({
+              index: indx ?? state.conversation.queries.length - 1,
+              query: { tool_calls: data.tool_calls },
+            }),
+          );
         } else if (data.type === 'error') {
           // set status to 'failed'
           dispatch(conversationSlice.actions.setStatus('failed'));
@@ -130,7 +137,11 @@ export const fetchAnswer = createAsyncThunk<
           dispatch(
             updateQuery({
               index: indx ?? state.conversation.queries.length - 1,
-              query: { response: answer.answer, sources: sourcesPrepped },
+              query: {
+                response: answer.answer,
+                sources: sourcesPrepped,
+                tool_calls: answer.toolCalls,
+              },
             }),
           );
           dispatch(
@@ -156,6 +167,7 @@ export const fetchAnswer = createAsyncThunk<
       query: question,
       result: '',
       sources: [],
+      tool_calls: [],
     };
   });
 
@@ -212,6 +224,15 @@ export const conversationSlice = createSlice({
         state.queries[index].sources!.push(query.sources![0]);
       }
     },
+    updateToolCalls(
+      state,
+      action: PayloadAction<{ index: number; query: Partial<Query> }>,
+    ) {
+      const { index, query } = action.payload;
+      if (!state.queries[index].tool_calls) {
+        state.queries[index].tool_calls = query?.tool_calls;
+      }
+    },
     updateQuery(
       state,
       action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -263,6 +284,7 @@ export const {
   updateStreamingQuery,
   updateConversationId,
   updateStreamingSource,
+  updateToolCalls,
   setConversation,
 } = conversationSlice.actions;
 export default conversationSlice.reducer;
diff --git a/frontend/src/conversation/sharedConversationSlice.ts b/frontend/src/conversation/sharedConversationSlice.ts
index c95f44e8..3140b418 100644
--- a/frontend/src/conversation/sharedConversationSlice.ts
+++ b/frontend/src/conversation/sharedConversationSlice.ts
@@ -51,6 +51,13 @@ export const fetchSharedAnswer = createAsyncThunk(
               query: { sources: data.source ?? [] },
             }),
          );
+        } else if (data.type === 'tool_calls') {
+          dispatch(
+            updateToolCalls({
+              index: state.sharedConversation.queries.length - 1,
+              query: { tool_calls: data.tool_calls },
+            }),
+          );
         } else if (data.type === 'error') {
           // set status to 'failed'
           dispatch(sharedConversationSlice.actions.setStatus('failed'));
@@ -107,6 +114,7 @@ export const fetchSharedAnswer = createAsyncThunk(
       query: question,
       result: '',
       sources: [],
+      tool_calls: [],
     };
   },
 );
@@ -161,6 +169,15 @@ export const sharedConversationSlice = createSlice({
       };
       }
     },
+    updateToolCalls(
+      state,
+      action: PayloadAction<{ index: number; query: Partial<Query> }>,
+    ) {
+      const { index, query } = action.payload;
+      if (!state.queries[index].tool_calls) {
+        state.queries[index].tool_calls = query?.tool_calls;
+      }
+    },
     updateQuery(
       state,
       action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -232,6 +249,7 @@ export const {
   setClientApiKey,
   updateQuery,
   updateStreamingQuery,
+  updateToolCalls,
   addQuery,
   saveToLocalStorage,
   updateStreamingSource,
diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts
new file mode 100644
index 00000000..9b5f2365
--- /dev/null
+++ b/frontend/src/conversation/types/index.ts
@@ -0,0 +1,7 @@
+export type ToolCallsType = {
+  tool_name: string;
+  action_name: string;
+  call_id: string;
+  arguments: Record<string, any>;
+  result: Record<string, any>;
+};