mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: enhance ClassicAgent and ReActAgent with tool preparation steps
This commit is contained in:
@@ -11,11 +11,12 @@ class ClassicAgent(BaseAgent):
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
|
||||
resp = self._llm_gen(messages, log_context)
|
||||
|
||||
if isinstance(resp, str):
|
||||
@@ -46,5 +47,6 @@ class ClassicAgent(BaseAgent):
|
||||
for line in completion:
|
||||
if isinstance(line, str):
|
||||
yield {"answer": line}
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
yield {"tool_calls": self.tool_calls.copy()}
|
||||
|
||||
@@ -19,6 +19,9 @@ class ReActAgent(BaseAgent):
|
||||
) -> Generator[Dict, None, None]:
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
plan = self._create_plan(query, docs_together, log_context)
|
||||
for line in plan:
|
||||
@@ -29,9 +32,6 @@ class ReActAgent(BaseAgent):
|
||||
prompt = self.prompt + f"\nFollow this plan: {self.plan}"
|
||||
messages = self._build_messages(prompt, query, retrieved_data)
|
||||
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self._llm_gen(messages, log_context)
|
||||
|
||||
if isinstance(resp, str):
|
||||
@@ -85,7 +85,10 @@ class ReActAgent(BaseAgent):
|
||||
plan_prompt = plan_prompt.replace("{summaries}", summaries)
|
||||
|
||||
messages = [{"role": "user", "content": plan_prompt}]
|
||||
plan = self.llm.gen_stream(model=self.gpt_model, messages=messages)
|
||||
print(self.tools)
|
||||
plan = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm)
|
||||
log_context.stacks.append({"component": "planning_llm", "data": data})
|
||||
|
||||
Reference in New Issue
Block a user