From fd2b6c111c40a336d29a0ec7a89439777a7e1a98 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 31 Mar 2025 17:02:36 +0530 Subject: [PATCH] feat: enhance ClassicAgent and ReActAgent with tool preparation steps --- application/agents/classic_agent.py | 4 +++- application/agents/react_agent.py | 11 +++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 3328f7f8..ce01e2e9 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -11,11 +11,12 @@ class ClassicAgent(BaseAgent): self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: retrieved_data = self._retriever_search(retriever, query, log_context) - messages = self._build_messages(self.prompt, query, retrieved_data) tools_dict = self._get_user_tools(self.user) self._prepare_tools(tools_dict) + messages = self._build_messages(self.prompt, query, retrieved_data) + resp = self._llm_gen(messages, log_context) if isinstance(resp, str): @@ -46,5 +47,6 @@ class ClassicAgent(BaseAgent): for line in completion: if isinstance(line, str): yield {"answer": line} + yield {"sources": retrieved_data} yield {"tool_calls": self.tool_calls.copy()} diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py index f4fee0e7..572a4e51 100644 --- a/application/agents/react_agent.py +++ b/application/agents/react_agent.py @@ -19,6 +19,9 @@ class ReActAgent(BaseAgent): ) -> Generator[Dict, None, None]: retrieved_data = self._retriever_search(retriever, query, log_context) + tools_dict = self._get_user_tools(self.user) + self._prepare_tools(tools_dict) + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) plan = self._create_plan(query, docs_together, log_context) for line in plan: @@ -29,9 +32,6 @@ class ReActAgent(BaseAgent): prompt = self.prompt + f"\nFollow this plan: {self.plan}" messages = self._build_messages(prompt, query, retrieved_data) - tools_dict = self._get_user_tools(self.user) - self._prepare_tools(tools_dict) - resp = self._llm_gen(messages, log_context) if isinstance(resp, str): @@ -85,7 +85,10 @@ class ReActAgent(BaseAgent): plan_prompt = plan_prompt.replace("{summaries}", summaries) messages = [{"role": "user", "content": plan_prompt}] - plan = self.llm.gen_stream(model=self.gpt_model, messages=messages) + print(self.tools) + plan = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) if log_context: data = build_stack_data(self.llm) log_context.stacks.append({"component": "planning_llm", "data": data})