mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
133 lines
4.8 KiB
Python
133 lines
4.8 KiB
Python
import os
from typing import Dict, Generator, List, Optional

from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
from application.retriever.base import BaseRetriever
# Repository root: three levels up from this file, so the prompt templates
# can be opened with absolute paths regardless of the current working dir.
current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Template asking the LLM for a step-by-step plan (ReAct "reason" phase).
# Explicit encoding: prompt files are UTF-8; the platform default may not be.
with open(
    os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"),
    "r",
    encoding="utf-8",
) as f:
    planning_prompt = f.read()

# Template used to synthesize the final answer from collected observations.
with open(
    os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
    "r",
    encoding="utf-8",
) as f:
    final_prompt = f.read()
class ReActAgent(BaseAgent):
    """ReAct-style agent: stream a plan, execute tool calls while collecting
    observations, then stream a final answer built from those observations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Plan text accumulated from the streamed planning response.
        self.plan = ""
        # Free-text observations (LLM outputs and tool results) that feed
        # the final-answer prompt.
        self.observations: List[str] = []

    def _gen_inner(
        self, query: str, retriever: BaseRetriever, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Yield streaming events for *query*.

        Each yielded dict carries exactly one of the keys:
        "thought" (plan chunks), "sources" (retrieved documents),
        "tool_calls" (truncated tool-call records), "answer" (answer chunks).
        """
        retrieved_data = self._retriever_search(retriever, query, log_context)

        # API-key callers get the tool set bound to that key; otherwise use
        # the tools configured for the authenticated user.
        if self.user_api_key:
            tools_dict = self._get_tools(self.user_api_key)
        else:
            tools_dict = self._get_user_tools(self.user)
        self._prepare_tools(tools_dict)

        docs_together = "\n".join([doc["text"] for doc in retrieved_data])
        plan = self._create_plan(query, docs_together, log_context)
        for line in plan:
            if isinstance(line, str):
                self.plan += line
                yield {"thought": line}

        prompt = self.prompt + f"\nFollow this plan: {self.plan}"
        messages = self._build_messages(prompt, query, retrieved_data)

        resp = self._llm_gen(messages, log_context)

        # A plain string can never have a `.message` attribute, so these
        # branches are mutually exclusive — record whichever form arrived.
        if isinstance(resp, str):
            self.observations.append(resp)
        elif (
            hasattr(resp, "message")
            and hasattr(resp.message, "content")
            and resp.message.content is not None
        ):
            self.observations.append(resp.message.content)

        resp = self._llm_handler(resp, tools_dict, messages, log_context)

        # Record every executed tool call as a textual observation.
        for tool_call in self.tool_calls:
            observation = (
                f"Action '{tool_call['action_name']}' of tool '{tool_call['tool_name']}' "
                f"with arguments '{tool_call['arguments']}' returned: '{tool_call['result']}'"
            )
            self.observations.append(observation)

        if isinstance(resp, str):
            self.observations.append(resp)
        elif (
            hasattr(resp, "message")
            and hasattr(resp.message, "content")
            and resp.message.content is not None
        ):
            self.observations.append(resp.message.content)
        else:
            # Handled response carried no usable content: stream a fresh
            # completion and collect its string chunks as observations.
            completion = self.llm.gen_stream(
                model=self.gpt_model, messages=messages, tools=self.tools
            )
            for line in completion:
                if isinstance(line, str):
                    self.observations.append(line)

        log_context.stacks.append(
            {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
        )

        yield {"sources": retrieved_data}
        # Truncate tool results to 50 characters before sending them to the
        # client so the event payload stays small.
        for tool_call in self.tool_calls:
            if len(str(tool_call["result"])) > 50:
                tool_call["result"] = str(tool_call["result"])[:50] + "..."
        yield {"tool_calls": self.tool_calls.copy()}

        final_answer = self._create_final_answer(query, self.observations, log_context)
        for line in final_answer:
            if isinstance(line, str):
                yield {"answer": line}

    def _create_plan(
        self, query: str, docs_data: str, log_context: Optional[LogContext] = None
    ) -> Generator[str, None, None]:
        """Stream a plan for *query*, grounded in *docs_data* when the
        planning template has a ``{summaries}`` placeholder.
        """
        plan_prompt = planning_prompt.replace("{query}", query)
        if "{summaries}" in planning_prompt:
            plan_prompt = plan_prompt.replace("{summaries}", docs_data)

        messages = [{"role": "user", "content": plan_prompt}]
        plan = self.llm.gen_stream(
            model=self.gpt_model, messages=messages, tools=self.tools
        )
        if log_context:
            data = build_stack_data(self.llm)
            log_context.stacks.append({"component": "planning_llm", "data": data})
        return plan

    def _create_final_answer(
        self,
        query: str,
        observations: List[str],
        log_context: Optional[LogContext] = None,
    ) -> Generator[str, None, None]:
        """Stream the final answer synthesized from *observations*.

        Returns the LLM stream directly; callers iterate it for answer
        chunks (the previous ``-> str`` annotation was incorrect).
        """
        observation_string = "\n".join(observations)
        final_answer_prompt = final_prompt.format(
            query=query, observations=observation_string
        )

        messages = [{"role": "user", "content": final_answer_prompt}]
        final_answer = self.llm.gen_stream(model=self.gpt_model, messages=messages)
        if log_context:
            data = build_stack_data(self.llm)
            log_context.stacks.append({"component": "final_answer_llm", "data": data})
        return final_answer