feat: add tool calls tracking and show in frontend

This commit is contained in:
Siddhant Rai
2025-02-12 21:47:47 +05:30
parent 0de4241b56
commit e209699b19
13 changed files with 302 additions and 51 deletions

View File

@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
}
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
}
},
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
],
}
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):
try:
try:
response_full = ""
source_log_docs = []
tool_calls = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
@@ -239,7 +236,13 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm,index
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
index,
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -303,7 +306,7 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
"index": fields.Integer(
required=False, description="The position where query is to be updated"
),
},
@@ -315,22 +318,24 @@ class Stream(Resource):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question","conversation_id"]
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index=data.get("index",None)
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
)
status_code = 400
return Response(
error_stream_generate('Unknown error occurred'),
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
@@ -442,14 +447,16 @@ class Answer(Resource):
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
source_log_docs = []
tool_calls = []
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
elif "tool_calls" in line:
tool_calls.append(line["tool_calls"])
if data.get("isNoneDoc"):
for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
)
)
retriever_params = retriever.get_params()