diff --git a/README.md b/README.md index 4c0953a3..03393629 100644 --- a/README.md +++ b/README.md @@ -48,10 +48,13 @@ - [x] Add tools (Jan 2025) - [x] Manually updating chunks in the app UI (Feb 2025) - [x] Devcontainer for easy development (Feb 2025) +- [x] ReACT agent (March 2025) - [ ] Anthropic Tool compatibility +- [ ] New input box in the conversation menu - [ ] Add triggerable actions / tools (webhook) - [ ] Add OAuth 2.0 authentication for tools and sources -- [ ] Chatbots menu re-design to handle tools, scheduling, and more +- [ ] Chatbots menu re-design to handle tools, agent types, and more +- [ ] Agent scheduling You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT! diff --git a/application/agents/__init__.py b/application/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py index a76d9faf..bf37d4ec 100644 --- a/application/agents/agent_creator.py +++ b/application/agents/agent_creator.py @@ -1,9 +1,11 @@ from application.agents.classic_agent import ClassicAgent +from application.agents.react_agent import ReActAgent class AgentCreator: agents = { "classic": ClassicAgent, + "react": ReActAgent, } @classmethod diff --git a/application/agents/base.py b/application/agents/base.py index d0f972a9..2361ef9a 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -1,4 +1,6 @@ -from typing import Dict, Generator +import uuid +from abc import ABC, abstractmethod +from typing import Dict, Generator, List, Optional from application.agents.llm_handler import get_llm_handler from application.agents.tools.tool_action_parser import ToolActionParser @@ -6,19 +8,34 @@ from application.agents.tools.tool_manager import ToolManager from application.core.mongo_db import MongoDB from application.llm.llm_creator import LLMCreator +from application.logging import build_stack_data, log_activity, LogContext +from application.retriever.base import BaseRetriever -class BaseAgent: +class BaseAgent(ABC): def __init__( self, - endpoint, - llm_name, - gpt_model, - api_key, - user_api_key=None, - decoded_token=None, + endpoint: str, + llm_name: str, + gpt_model: str, + api_key: str, + user_api_key: Optional[str] = None, + prompt: str = "", + chat_history: Optional[List[Dict]] = None, + decoded_token: Optional[Dict] = None, ): self.endpoint = endpoint + self.llm_name = llm_name + self.gpt_model = gpt_model + self.api_key = api_key + self.user_api_key = user_api_key + self.prompt = prompt + self.decoded_token = decoded_token or {} + self.user: str = decoded_token.get("sub") + self.tool_config: Dict = {} + self.tools: List[Dict] = [] + self.tool_calls: List[Dict] = [] + self.chat_history: List[Dict] = chat_history if chat_history is not None else [] self.llm = LLMCreator.create_llm( llm_name, api_key=api_key, @@ -26,13 +43,18 @@ class BaseAgent: decoded_token=decoded_token, ) self.llm_handler = get_llm_handler(llm_name) - self.gpt_model = gpt_model - self.tools = [] - self.tool_config = {} - self.tool_calls = [] - def gen(self, *args, **kwargs) -> Generator[Dict, None, None]: - raise NotImplementedError('Method "gen" must be implemented in the child class') + @log_activity() + def gen( + self, query: str, retriever: BaseRetriever, log_context: LogContext = None + ) -> Generator[Dict, None, None]: + yield from self._gen_inner(query, retriever, log_context) + + @abstractmethod + def _gen_inner( + self, query: str, retriever: BaseRetriever, log_context: LogContext + ) -> Generator[Dict, None, None]: + pass def _get_user_tools(self, user="local"): mongo = MongoDB.get_client() @@ -109,14 +131,12 @@ class BaseAgent: for param, details in action_data[param_type]["properties"].items(): if param not in call_args and "value" in details: target_dict[param] = details["value"] - for param, value in call_args.items(): for param_type, target_dict in param_types.items(): if param_type in action_data and param in action_data[param_type].get( "properties", {} ): target_dict[param] = value - tm = ToolManager(config={}) tool = tm.load_tool( tool_data["name"], @@ -151,3 +171,79 @@ class BaseAgent: self.tool_calls.append(tool_call_data) return result, call_id + + def _build_messages( + self, + system_prompt: str, + query: str, + retrieved_data: List[Dict], + ) -> List[Dict]: + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + p_chat_combine = system_prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + + for i in self.chat_history: + if "prompt" in i and "response" in i: + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "assistant", "content": i["response"]}) + if "tool_calls" in i: + for tool_call in i["tool_calls"]: + call_id = tool_call.get("call_id") or str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + messages_combine.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages_combine.append( + {"role": "tool", "content": [function_response_dict]} + ) + messages_combine.append({"role": "user", "content": query}) + return messages_combine + + def _retriever_search( + self, + retriever: BaseRetriever, + query: str, + log_context: Optional[LogContext] = None, + ) -> List[Dict]: + retrieved_data = retriever.search(query) + if log_context: + data = build_stack_data(retriever, exclude_attributes=["llm"]) + log_context.stacks.append({"component": "retriever", "data": data}) + return retrieved_data + + def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None): + resp = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) + if log_context: + data = build_stack_data(self.llm) + log_context.stacks.append({"component": "llm", "data": data}) + return resp + + def _llm_handler( + self, + resp, + tools_dict: Dict, + messages: List[Dict], + log_context: Optional[LogContext] = None, + ): + resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) + if log_context: + data = build_stack_data(self.llm_handler) + log_context.stacks.append({"component": "llm_handler", "data": data}) + return resp diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 2752c833..ce01e2e9 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -1,86 +1,23 @@ -import uuid from typing import Dict, Generator from application.agents.base import BaseAgent -from application.logging import build_stack_data, log_activity, LogContext +from application.logging import LogContext from application.retriever.base import BaseRetriever class ClassicAgent(BaseAgent): - def __init__( - self, - endpoint, - llm_name, - gpt_model, - api_key, - user_api_key=None, - prompt="", - chat_history=None, - decoded_token=None, - ): - super().__init__( - endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token - ) - self.user = decoded_token.get("sub") - self.prompt = prompt - self.chat_history = chat_history if chat_history is not None else [] - - @log_activity() - def gen( - self, query: str, retriever: BaseRetriever, log_context: LogContext = None - ) -> Generator[Dict, None, None]: - yield from self._gen_inner(query, retriever, log_context) - def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: retrieved_data = self._retriever_search(retriever, query, log_context) - docs_together = "\n".join([doc["text"] for doc in retrieved_data]) - p_chat_combine = self.prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - - if len(self.chat_history) > 0: - for i in self.chat_history: - if "prompt" in i and "response" in i: - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append( - {"role": "assistant", "content": i["response"]} - ) - if "tool_calls" in i: - for tool_call in i["tool_calls"]: - call_id = tool_call.get("call_id") - if call_id is None or call_id == "None": - call_id = str(uuid.uuid4()) - - function_call_dict = { - "function_call": { - "name": tool_call.get("action_name"), - "args": tool_call.get("arguments"), - "call_id": call_id, - } - } - function_response_dict = { - "function_response": { - "name": tool_call.get("action_name"), - "response": {"result": tool_call.get("result")}, - "call_id": call_id, - } - } - - messages_combine.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages_combine.append( - {"role": "tool", "content": [function_response_dict]} - ) - messages_combine.append({"role": "user", "content": query}) - tools_dict = self._get_user_tools(self.user) self._prepare_tools(tools_dict) - resp = self._llm_gen(messages_combine, log_context) + messages = self._build_messages(self.prompt, query, retrieved_data) + + resp = self._llm_gen(messages, log_context) if isinstance(resp, str): yield {"answer": resp} @@ -93,7 +30,7 @@ class ClassicAgent(BaseAgent): yield {"answer": resp.message.content} return - resp = self._llm_handler(resp, tools_dict, messages_combine, log_context) + resp = self._llm_handler(resp, tools_dict, messages, log_context) if isinstance(resp, str): yield {"answer": resp} @@ -105,7 +42,7 @@ class ClassicAgent(BaseAgent): yield {"answer": resp.message.content} else: completion = self.llm.gen_stream( - model=self.gpt_model, messages=messages_combine, tools=self.tools + model=self.gpt_model, messages=messages, tools=self.tools ) for line in completion: if isinstance(line, str): @@ -113,28 +50,3 @@ class ClassicAgent(BaseAgent): yield {"sources": retrieved_data} yield {"tool_calls": self.tool_calls.copy()} - - def _retriever_search(self, retriever, query, log_context): - retrieved_data = retriever.search(query) - if log_context: - data = build_stack_data(retriever, exclude_attributes=["llm"]) - log_context.stacks.append({"component": "retriever", "data": data}) - return retrieved_data - - def _llm_gen(self, messages_combine, log_context): - resp = self.llm.gen_stream( - model=self.gpt_model, messages=messages_combine, tools=self.tools - ) - if log_context: - data = build_stack_data(self.llm) - log_context.stacks.append({"component": "llm", "data": data}) - return resp - - def _llm_handler(self, resp, tools_dict, messages_combine, log_context): - resp = self.llm_handler.handle_response( - self, resp, tools_dict, messages_combine - ) - if log_context: - data = build_stack_data(self.llm_handler) - log_context.stacks.append({"component": "llm_handler", "data": data}) - return resp diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py new file mode 100644 index 00000000..f721b487 --- /dev/null +++ b/application/agents/react_agent.py @@ -0,0 +1,121 @@ +import os +from typing import Dict, Generator, List + +from application.agents.base import BaseAgent +from application.logging import build_stack_data, LogContext +from application.retriever.base import BaseRetriever + +current_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +with open( + os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r" +) as f: + planning_prompt = f.read() +with open( + os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), + "r", +) as f: + final_prompt = f.read() + + +class ReActAgent(BaseAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.plan = "" + self.observations: List[str] = [] + + def _gen_inner( + self, query: str, retriever: BaseRetriever, log_context: LogContext + ) -> Generator[Dict, None, None]: + retrieved_data = self._retriever_search(retriever, query, log_context) + + tools_dict = self._get_user_tools(self.user) + self._prepare_tools(tools_dict) + + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + plan = self._create_plan(query, docs_together, log_context) + for line in plan: + if isinstance(line, str): + self.plan += line + yield {"thought": line} + + prompt = self.prompt + f"\nFollow this plan: {self.plan}" + messages = self._build_messages(prompt, query, retrieved_data) + + resp = self._llm_gen(messages, log_context) + + if isinstance(resp, str): + self.observations.append(resp) + if ( + hasattr(resp, "message") + and hasattr(resp.message, "content") + and resp.message.content is not None + ): + self.observations.append(resp.message.content) + + resp = self._llm_handler(resp, tools_dict, messages, log_context) + + for tool_call in self.tool_calls: + observation = ( + f"Action '{tool_call['action_name']}' of tool '{tool_call['tool_name']}' " + f"with arguments '{tool_call['arguments']}' returned: '{tool_call['result']}'" + ) + self.observations.append(observation) + + if isinstance(resp, str): + self.observations.append(resp) + elif ( + hasattr(resp, "message") + and hasattr(resp.message, "content") + and resp.message.content is not None + ): + self.observations.append(resp.message.content) + else: + completion = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) + for line in completion: + if isinstance(line, str): + self.observations.append(line) + + yield {"sources": retrieved_data} + yield {"tool_calls": self.tool_calls.copy()} + + final_answer = self._create_final_answer(query, self.observations, log_context) + for line in final_answer: + if isinstance(line, str): + yield {"answer": line} + + def _create_plan( + self, query: str, docs_data: str, log_context: LogContext = None + ) -> Generator[str, None, None]: + plan_prompt = planning_prompt.replace("{query}", query) + if "{summaries}" in planning_prompt: + summaries = docs_data + plan_prompt = plan_prompt.replace("{summaries}", summaries) + + messages = [{"role": "user", "content": plan_prompt}] + print(self.tools) + plan = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) + if log_context: + data = build_stack_data(self.llm) + log_context.stacks.append({"component": "planning_llm", "data": data}) + return plan + + def _create_final_answer( + self, query: str, observations: List[str], log_context: LogContext = None + ) -> str: + observation_string = "\n".join(observations) + final_answer_prompt = final_prompt.format( + query=query, observations=observation_string + ) + + messages = [{"role": "user", "content": final_answer_prompt}] + final_answer = self.llm.gen_stream(model=self.gpt_model, messages=messages) + if log_context: + data = build_stack_data(self.llm) + log_context.stacks.append({"component": "final_answer_llm", "data": data}) + return final_answer diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 34081784..8ecd218f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -121,6 +121,7 @@ def save_conversation( conversation_id, question, response, + thought, source_log_docs, tool_calls, llm, @@ -136,6 +137,7 @@ def save_conversation( "$set": { f"queries.{index}.prompt": question, f"queries.{index}.response": response, + f"queries.{index}.thought": thought, f"queries.{index}.sources": source_log_docs, f"queries.{index}.tool_calls": tool_calls, f"queries.{index}.timestamp": current_time, @@ -155,6 +157,7 @@ def save_conversation( "queries": { "prompt": question, "response": response, + "thought": thought, "sources": source_log_docs, "tool_calls": tool_calls, "timestamp": current_time, @@ -190,6 +193,7 @@ def save_conversation( { "prompt": question, "response": response, + "thought": thought, "sources": source_log_docs, "tool_calls": tool_calls, "timestamp": current_time, @@ -230,9 +234,7 @@ def complete_stream( should_save_conversation=True, ): try: - response_full = "" - source_log_docs = [] - tool_calls = [] + response_full, thought, source_log_docs, tool_calls = "", "", [], [] answer = agent.gen(query=question, retriever=retriever) @@ -258,6 +260,10 @@ def complete_stream( tool_calls = line["tool_calls"] data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls}) yield f"data: {data}\n\n" + elif "thought" in line: + thought += line["thought"] + data = json.dumps({"type": "thought", "thought": line["thought"]}) + yield f"data: {data}\n\n" if isNoneDoc: for doc in source_log_docs: @@ -275,6 +281,7 @@ def complete_stream( conversation_id, question, response_full, + thought, source_log_docs, tool_calls, llm, diff --git a/application/prompts/react_final_prompt.txt b/application/prompts/react_final_prompt.txt new file mode 100644 index 00000000..50916127 --- /dev/null +++ b/application/prompts/react_final_prompt.txt @@ -0,0 +1,3 @@ +Query: {query} +Observations: {observations} +Now, using the insights from the observations, formulate a well-structured and precise final answer. \ No newline at end of file diff --git a/application/prompts/react_planning_prompt.txt b/application/prompts/react_planning_prompt.txt new file mode 100644 index 00000000..3fd17116 --- /dev/null +++ b/application/prompts/react_planning_prompt.txt @@ -0,0 +1,10 @@ +You are an AI assistant and talk like you're thinking out loud. Given the following query, outline a concise thought process that includes key steps and considerations necessary for effective analysis and response. Avoid pointwise formatting. The goal is to break down the query into manageable components without excessive detail, focusing on clarity and logical progression. + +Include the following elements in your thought process: +1. Identify the main objective of the query. +2. Determine any relevant context or background information needed to understand the query. +3. List potential approaches or methods to address the query. +4. Highlight any critical factors or constraints that may influence the outcome. + +Query: {query} +Summaries: {summaries} \ No newline at end of file diff --git a/frontend/src/assets/cloud.svg b/frontend/src/assets/cloud.svg new file mode 100644 index 00000000..9a8727db --- /dev/null +++ b/frontend/src/assets/cloud.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 7bc042cf..ec244087 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -6,16 +6,16 @@ import ReactMarkdown from 'react-markdown'; import { useSelector } from 'react-redux'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; import { - vscDarkPlus, oneLight, + vscDarkPlus, } from 'react-syntax-highlighter/dist/cjs/styles/prism'; import rehypeKatex from 'rehype-katex'; import remarkGfm from 'remark-gfm'; import remarkMath from 'remark-math'; -import { useDarkTheme } from '../hooks'; -import DocsGPT3 from '../assets/cute_docsgpt3.svg'; import ChevronDown from '../assets/chevron-down.svg'; +import Cloud from '../assets/cloud.svg'; +import DocsGPT3 from '../assets/cute_docsgpt3.svg'; import Dislike from '../assets/dislike.svg?react'; import Document from '../assets/document.svg'; import Edit from '../assets/edit.svg'; @@ -28,7 +28,7 @@ import Avatar from '../components/Avatar'; import CopyButton from '../components/CopyButton'; import Sidebar from '../components/Sidebar'; import SpeakButton from '../components/TextToSpeechButton'; -import { useOutsideAlerter } from '../hooks'; +import { useDarkTheme, useOutsideAlerter } from '../hooks'; import { selectChunks, selectSelectedDocs, @@ -42,11 +42,12 @@ const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false; const ConversationBubble = forwardRef< HTMLDivElement, { - message: string; + message?: string; type: MESSAGE_TYPE; className?: string; feedback?: FEEDBACK; handleFeedback?: (feedback: FEEDBACK) => void; + thought?: string; sources?: { title: string; text: string; source: string }[]; toolCalls?: ToolCallsType[]; retryBtn?: React.ReactElement; @@ -65,6 +66,7 @@ const ConversationBubble = forwardRef< className, feedback, handleFeedback, + thought, sources, toolCalls, retryBtn, @@ -160,7 +162,7 @@ const ConversationBubble = forwardRef< + + {isThoughtOpen && ( +
+
+ +
+ + {language} + + +
+ + {String(children).replace(/\n$/, '')} + +
+ ) : ( + + {children} + + ); + }, + ul({ children }) { + return ( + + ); + }, + ol({ children }) { + return ( +
    + {children} +
+ ); + }, + table({ children }) { + return ( +
+ + {children} +
+
+ ); + }, + thead({ children }) { + return ( + + {children} + + ); + }, + tr({ children }) { + return ( + + {children} + + ); + }, + th({ children }) { + return {children}; + }, + td({ children }) { + return {children}; + }, + }} + > + {preprocessLaTeX(thought ?? '')} + +
+ + )} + + ); +} diff --git a/frontend/src/conversation/ConversationMessages.tsx b/frontend/src/conversation/ConversationMessages.tsx index 601b9cf1..a578538f 100644 --- a/frontend/src/conversation/ConversationMessages.tsx +++ b/frontend/src/conversation/ConversationMessages.tsx @@ -1,11 +1,12 @@ import { Fragment, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import ConversationBubble from './ConversationBubble'; -import Hero from '../Hero'; -import { FEEDBACK, Query, Status } from './conversationModels'; + import ArrowDown from '../assets/arrow-down.svg'; import RetryIcon from '../components/RetryIcon'; +import Hero from '../Hero'; import { useDarkTheme } from '../hooks'; +import ConversationBubble from './ConversationBubble'; +import { FEEDBACK, Query, Status } from './conversationModels'; interface ConversationMessagesProps { handleQuestion: (params: { @@ -83,13 +84,14 @@ export default function ConversationMessages({ const prepResponseView = (query: Query, index: number) => { let responseView; - if (query.response) { + if (query.thought || query.response) { responseView = ( }>, + ) { + const { index, query } = action.payload; + if (query.thought != undefined) { + state.queries[index].thought = + (state.queries[index].thought || '') + query.thought; + } else { + state.queries[index] = { + ...state.queries[index], + ...query, + }; + } + }, updateStreamingSource( state, action: PayloadAction<{ index: number; query: Partial }>, @@ -286,6 +312,7 @@ export const { resendQuery, updateStreamingQuery, updateConversationId, + updateThought, updateStreamingSource, updateToolCalls, setConversation, diff --git a/frontend/src/conversation/sharedConversationSlice.ts b/frontend/src/conversation/sharedConversationSlice.ts index 7ae4862b..4f784ad2 100644 --- a/frontend/src/conversation/sharedConversationSlice.ts +++ b/frontend/src/conversation/sharedConversationSlice.ts @@ -44,6 +44,15 @@ export const fetchSharedAnswer = createAsyncThunk( // set status to 'idle' dispatch(sharedConversationSlice.actions.setStatus('idle')); dispatch(saveToLocalStorage()); + } else if (data.type === 'thought') { + const result = data.thought; + console.log('thought', result); + dispatch( + updateThought({ + index: state.sharedConversation.queries.length - 1, + query: { thought: result }, + }), + ); } else if (data.type === 'source') { dispatch( updateStreamingSource({ @@ -113,6 +122,7 @@ export const fetchSharedAnswer = createAsyncThunk( answer: '', query: question, result: '', + thought: '', sources: [], tool_calls: [], }; @@ -183,6 +193,21 @@ export const sharedConversationSlice = createSlice({ ...query, }; }, + updateThought( + state, + action: PayloadAction<{ index: number; query: Partial }>, + ) { + const { index, query } = action.payload; + if (query.thought != undefined) { + state.queries[index].thought = + (state.queries[index].thought || '') + query.thought; + } else { + state.queries[index] = { + ...state.queries[index], + ...query, + }; + } + }, updateStreamingSource( state, action: PayloadAction<{ index: number; query: Partial }>, @@ -243,6 +268,7 @@ export const { setClientApiKey, updateQuery, updateStreamingQuery, + updateThought, updateToolCalls, addQuery, saveToLocalStorage,