feat: add tool calls tracking and show in frontend

@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
     if data is None:
         raise Exception("Invalid API Key, please generate new key", 401)
 
-    if "retriever" not in data:
-        data["retriever"] = None
-
     if "source" in data and isinstance(data["source"], DBRef):
         source_doc = db.dereference(data["source"])
         data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
     )
 
 
-def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
+def save_conversation(
+    conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
+):
     if conversation_id is not None and index is not None:
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     f"queries.{index}.prompt": question,
                     f"queries.{index}.response": response,
                     f"queries.{index}.sources": source_log_docs,
+                    f"queries.{index}.tool_calls": tool_calls,
                 }
             },
         )
         ##remove following queries from the array
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
-            {
-                "$push":{
-                    "queries":{
-                        "$each":[],
-                        "$slice":index+1
-                    }
-                }
-            }
+            {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
         )
     elif conversation_id is not None and conversation_id != "None":
         conversations_collection.update_one(
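
Note: the one-liner that replaces the removed block above relies on a standard MongoDB idiom — a "$push" with "$each": [] adds nothing, so a positive "$slice" purely truncates the array to its first index + 1 elements, discarding every query after the one being regenerated. A minimal pymongo sketch of the same idiom (client, database, and collection names here are illustrative, not from this commit):

    from bson import ObjectId
    from pymongo import MongoClient

    conversations = MongoClient()["docsgpt"]["conversations"]  # illustrative names
    index = 2  # keep queries[0..2], drop everything after them
    conversations.update_one(
        {"_id": ObjectId("507f1f77bcf86cd799439011")},
        {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
    )
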
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                         "prompt": question,
                         "response": response,
                         "sources": source_log_docs,
+                        "tool_calls": tool_calls,
                     }
                 }
             },
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                 "role": "user",
                 "content": "Summarise following conversation in no more than 3 words, "
                 "respond ONLY with the summary, use the same language as the "
-                "system \n\nUser: "
-                + question
-                + "\n\n"
-                + "AI: "
-                + response,
+                "system \n\nUser: " + question + "\n\n" + "AI: " + response,
             },
         ]
 
@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     "prompt": question,
                     "response": response,
                     "sources": source_log_docs,
+                    "tool_calls": tool_calls,
                 }
             ],
         }
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):
 
 
 def complete_stream(
-    question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
+    question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
 ):
 
     try:
         response_full = ""
         source_log_docs = []
+        tool_calls = []
         answer = retriever.gen()
         sources = retriever.search()
         for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
         if len(sources) > 0:
             data = json.dumps({"type": "source", "source": sources})
             yield f"data: {data}\n\n"
+
         for line in answer:
             if "answer" in line:
                 response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
                 yield f"data: {data}\n\n"
             elif "source" in line:
                 source_log_docs.append(line["source"])
+            elif "tool_calls" in line:
+                tool_calls = line["tool_calls"]
+                data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
+                yield f"data: {data}\n\n"
 
         if isNoneDoc:
             for doc in source_log_docs:
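
For reference, the new event interleaves with the existing stream: clients now receive {"type": "tool_calls", ...} frames alongside the "answer", "source", and "id" frames. A minimal consumer sketch, assuming the endpoint is mounted at /stream on the default backend port (neither the route path nor the port appears in this diff, and the "answer" frame shape is inferred from the surrounding code):

    import json
    import requests

    resp = requests.post(
        "http://localhost:7091/stream",  # assumed host and route
        json={"question": "What is DocsGPT?", "history": "[]"},
        stream=True,
    )
    for raw in resp.iter_lines():
        if not raw.startswith(b"data: "):
            continue  # skip blank keep-alive separators
        event = json.loads(raw[len(b"data: "):])
        if event["type"] == "tool_calls":
            print("tool calls:", event["tool_calls"])
        elif event["type"] == "answer":
            print(event["answer"], end="")
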
@@ -239,7 +236,13 @@ def complete_stream(
             )
         if user_api_key is None:
             conversation_id = save_conversation(
-                conversation_id, question, response_full, source_log_docs, llm,index
+                conversation_id,
+                question,
+                response_full,
+                source_log_docs,
+                tool_calls,
+                llm,
+                index,
             )
         # send data.type = "end" to indicate that the stream has ended as json
         data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -303,7 +306,7 @@ class Stream(Resource):
         "isNoneDoc": fields.Boolean(
             required=False, description="Flag indicating if no document is used"
         ),
-        "index":fields.Integer(
+        "index": fields.Integer(
             required=False, description="The position where query is to be updated"
         ),
     },
@@ -315,22 +318,24 @@ class Stream(Resource):
         data = request.get_json()
         required_fields = ["question"]
         if "index" in data:
-            required_fields = ["question","conversation_id"]
+            required_fields = ["question", "conversation_id"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields
 
         try:
             question = data["question"]
-            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
+            history = limit_chat_history(
+                json.loads(data.get("history", [])), gpt_model=gpt_model
+            )
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
-
-            index=data.get("index",None)
-
+            index = data.get("index", None)
             chunks = int(data.get("chunks", 2))
             token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
             retriever_name = data.get("retriever", "classic")
+
+
             if "api_key" in data:
                 data_key = get_data_from_api_key(data["api_key"])
                 chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
             gpt_model=gpt_model,
             user_api_key=user_api_key,
         )
 
         return Response(
             complete_stream(
                 question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
             )
             status_code = 400
             return Response(
-                error_stream_generate('Unknown error occurred'),
+                error_stream_generate("Unknown error occurred"),
                 status=status_code,
                 mimetype="text/event-stream",
             )
@@ -442,14 +447,16 @@ class Answer(Resource):
     @api.doc(description="Provide an answer based on the question and retriever")
     def post(self):
         data = request.get_json()
         required_fields = ["question"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields
 
         try:
             question = data["question"]
-            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
+            history = limit_chat_history(
+                json.loads(data.get("history", [])), gpt_model=gpt_model
+            )
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
             chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
                 user_api_key=user_api_key,
             )
 
-            source_log_docs = []
             response_full = ""
+            source_log_docs = []
+            tool_calls = []
             for line in retriever.gen():
                 if "source" in line:
                     source_log_docs.append(line["source"])
                 elif "answer" in line:
                     response_full += line["answer"]
+                elif "tool_calls" in line:
+                    tool_calls.append(line["tool_calls"])
 
             if data.get("isNoneDoc"):
                 for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
             result = {"answer": response_full, "sources": source_log_docs}
             result["conversation_id"] = str(
                 save_conversation(
-                    conversation_id, question, response_full, source_log_docs, llm
+                    conversation_id,
+                    question,
+                    response_full,
+                    source_log_docs,
+                    tool_calls,
+                    llm,
                 )
             )
             retriever_params = retriever.get_params()

@@ -35,6 +35,12 @@ class ClassicRAG(BaseRetriever):
             )
         )
         self.user_api_key = user_api_key
+        self.agent = Agent(
+            llm_name=settings.LLM_NAME,
+            gpt_model=self.gpt_model,
+            api_key=settings.API_KEY,
+            user_api_key=self.user_api_key,
+        )
 
     def _get_data(self):
         if self.chunks == 0:
@@ -78,20 +84,24 @@ class ClassicRAG(BaseRetriever):
                 messages_combine.append(
                     {"role": "assistant", "content": i["response"]}
                 )
+                if "tool_calls" in i:
+                    for tool_call in i["tool_calls"]:
+                        messages_combine.append(
+                            {
+                                "role": "assistant",
+                                "content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}",
+                            }
+                        )
         messages_combine.append({"role": "user", "content": self.question})
-        # llm = LLMCreator.create_llm(
-        #     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
-        # )
-        # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
-        agent = Agent(
-            llm_name=settings.LLM_NAME,
-            gpt_model=self.gpt_model,
-            api_key=settings.API_KEY,
-            user_api_key=self.user_api_key,
-        )
-        completion = agent.gen(messages_combine)
 
+        completion = self.agent.gen(messages_combine)
         for line in completion:
             yield {"answer": str(line)}
+        yield {"tool_calls": self.agent.tool_calls.copy()}
 
     def search(self):
         return self._get_data()
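
To make the replay format above concrete: when a stored query carries tool_calls, each one is flattened into a plain assistant message before the user's new question. A sketch with a hypothetical record (tool and action names invented for illustration):

    tool_call = {  # hypothetical record, shaped like the ones the Agent saves
        "tool_name": "web_search",
        "action_name": "search",
        "arguments": "{'query': 'docsgpt'}",
        "result": "top hits ...",
    }
    message = {
        "role": "assistant",
        "content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | "
        f"Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}",
    }
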
@@ -16,6 +16,7 @@ class Agent:
         # Static tool configuration (to be replaced later)
         self.tools = []
         self.tool_config = {}
+        self.tool_calls = []
 
     def _get_user_tools(self, user="local"):
         mongo = MongoDB.get_client()
@@ -123,6 +124,15 @@ class Agent:
         print(f"Executing tool: {action_name} with args: {call_args}")
         result = tool.execute_action(action_name, **parameters)
         call_id = getattr(call, "id", None)
+
+        tool_call_data = {
+            "tool_name": tool_data["name"],
+            "action_name": action_name,
+            "arguments": str(call_args),
+            "result": str(result),
+        }
+        self.tool_calls.append(tool_call_data)
+
         return result, call_id
 
     def _simple_tool_agent(self, messages):
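
Taken together with the save_conversation changes above, every executed tool call is recorded as four stringified fields and ends up on the stored query document. A sketch of the resulting shape (all values illustrative):

    query_doc = {
        "prompt": "What's the weather in Berlin?",
        "response": "It is sunny, around 21°C.",
        "sources": [],
        "tool_calls": [
            {
                "tool_name": "weather_api",         # tool_data["name"]
                "action_name": "get_current",       # the action invoked
                "arguments": "{'city': 'Berlin'}",  # str(call_args)
                "result": "{'temp': 21}",           # str(result)
            }
        ],
    }
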
@@ -154,6 +164,7 @@ class Agent:
             return
 
     def gen(self, messages):
+        self.tool_calls = []
         if self.llm.supports_tools():
            resp = self._simple_tool_agent(messages)
            for line in resp: