From 8a3612e56c0152171d50512dacc125dcdd093135 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 13 Feb 2025 05:02:10 +0530 Subject: [PATCH] fix: improve tool call handling and UI adjustments --- application/tools/agent.py | 12 +++- application/utils.py | 55 +++++++++++-------- frontend/src/components/CopyButton.tsx | 2 +- .../src/conversation/ConversationBubble.tsx | 26 ++++++--- 4 files changed, 61 insertions(+), 34 deletions(-) diff --git a/application/tools/agent.py b/application/tools/agent.py index 3f44f42a..148160db 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -144,7 +144,11 @@ class Agent: if isinstance(resp, str): yield resp return - if hasattr(resp, "message") and hasattr(resp.message, "content"): + if ( + hasattr(resp, "message") + and hasattr(resp.message, "content") + and resp.message.content is not None + ): yield resp.message.content return @@ -152,7 +156,11 @@ class Agent: if isinstance(resp, str): yield resp - elif hasattr(resp, "message") and hasattr(resp.message, "content"): + elif ( + hasattr(resp, "message") + and hasattr(resp.message, "content") + and resp.message.content is not None + ): yield resp.message.content else: completion = self.llm.gen_stream( diff --git a/application/utils.py b/application/utils.py index 54d2086f..1a075cb6 100644 --- a/application/utils.py +++ b/application/utils.py @@ -1,8 +1,9 @@ -import tiktoken import hashlib -from flask import jsonify, make_response import re +import tiktoken +from flask import jsonify, make_response + _encoding = None @@ -22,6 +23,7 @@ def num_tokens_from_string(string: str) -> int: else: return 0 + def num_tokens_from_object_or_list(thing): if isinstance(thing, list): return sum([num_tokens_from_object_or_list(x) for x in thing]) @@ -32,6 +34,7 @@ def num_tokens_from_object_or_list(thing): else: return 0 + def count_tokens_docs(docs): docs_content = "" for doc in docs: @@ -59,6 +62,7 @@ def check_required_fields(data, required_fields): def get_hash(data): return hashlib.md5(data.encode()).hexdigest() + def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): """ Limits chat history based on token count. @@ -67,38 +71,41 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): from application.core.settings import settings max_token_limit = ( - max_token_limit - if max_token_limit and - max_token_limit < settings.MODEL_TOKEN_LIMITS.get( - gpt_model, settings.DEFAULT_MAX_HISTORY - ) - else settings.MODEL_TOKEN_LIMITS.get( - gpt_model, settings.DEFAULT_MAX_HISTORY - ) - ) - + max_token_limit + if max_token_limit + and max_token_limit + < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + ) if not history: return [] - - tokens_current_history = 0 + trimmed_history = [] - + tokens_current_history = 0 + for message in reversed(history): + tokens_batch = 0 if "prompt" in message and "response" in message: - tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string( - message["response"] - ) - if tokens_current_history + tokens_batch < max_token_limit: - tokens_current_history += tokens_batch - trimmed_history.insert(0, message) - else: - break + tokens_batch += num_tokens_from_string(message["prompt"]) + tokens_batch += num_tokens_from_string(message["response"]) + + if "tool_calls" in message: + for tool_call in message["tool_calls"]: + tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}" + tokens_batch += num_tokens_from_string(tool_call_string) + + if tokens_current_history + tokens_batch < max_token_limit: + tokens_current_history += tokens_batch + trimmed_history.insert(0, message) + else: + break return trimmed_history + def validate_function_name(function_name): """Validates if a function name matches the allowed pattern.""" if not re.match(r"^[a-zA-Z0-9_-]+$", function_name): return False - return True \ No newline at end of file + return True diff --git a/frontend/src/components/CopyButton.tsx b/frontend/src/components/CopyButton.tsx index e13f9133..f0559f52 100644 --- a/frontend/src/components/CopyButton.tsx +++ b/frontend/src/components/CopyButton.tsx @@ -40,7 +40,7 @@ export default function CoppyButton({ /> ) : ( { handleCopyClick(text); }} diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 7987ae8f..4d26a3e5 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -630,25 +630,37 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { key={`tool-call-${index}`} title={`${toolCall.tool_name} - ${toolCall.action_name}`} className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]" - titleClassName="px-4 py-2 text-sm font-semibold" + titleClassName="px-6 py-2 text-sm font-semibold" children={
-

- Arguments +

+ + Arguments + {' '} +

- + {toolCall.arguments}

-

- Response +

+ + Response + {' '} +

- + {toolCall.result}