Merge pull request #1630 from siiddhantt/feat/show-tool-execution

feat: tool calls tracking
This commit is contained in:
Alex
2025-02-14 10:27:15 +00:00
committed by GitHub
18 changed files with 502 additions and 99 deletions

View File

@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
}
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
}
},
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
],
}
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):
try:
try:
response_full = ""
source_log_docs = []
tool_calls = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
@@ -239,7 +236,13 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm,index
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
index,
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -303,7 +306,7 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
"index": fields.Integer(
required=False, description="The position where query is to be updated"
),
},
@@ -315,22 +318,24 @@ class Stream(Resource):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question","conversation_id"]
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index=data.get("index",None)
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
)
status_code = 400
return Response(
error_stream_generate('Unknown error occurred'),
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
@@ -442,14 +447,16 @@ class Answer(Resource):
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
source_log_docs = []
tool_calls = []
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
elif "tool_calls" in line:
tool_calls.append(line["tool_calls"])
if data.get("isNoneDoc"):
for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
)
)
retriever_params = retriever.get_params()

View File

@@ -1,26 +1,37 @@
import os
from pathlib import Path
from typing import Optional
import os
from pydantic_settings import BaseSettings
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
class Settings(BaseSettings):
LLM_NAME: str = "docsgpt"
MODEL_NAME: Optional[str] = None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
MODEL_NAME: Optional[str] = (
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
MODEL_TOKEN_LIMITS: dict = {
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.0-flash-exp": 1e6,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -28,12 +39,18 @@ class Settings(BaseSettings):
API_URL: str = "http://localhost:7091" # backend url for celery worker
API_KEY: Optional[str] = None # LLM api key
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -68,12 +85,14 @@ class Settings(BaseSettings):
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False

View File

@@ -1,3 +1,5 @@
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
self.api_key = api_key
self.user_api_key = user_api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
self,
baseself,
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
print(messages)
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs,
):
messages = self._clean_messages_openai(messages)
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)

View File

@@ -1,3 +1,5 @@
import uuid
from application.core.settings import settings
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
@@ -35,6 +37,12 @@ class ClassicRAG(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
def _get_data(self):
if self.chunks == 0:
@@ -78,21 +86,42 @@ class ClassicRAG(BaseRetriever):
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": self.question})
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = agent.gen(messages_combine)
completion = self.agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):
return self._get_data()

View File

@@ -16,6 +16,7 @@ class Agent:
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
self.tool_calls = []
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
@@ -123,6 +124,16 @@ class Agent:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)
tool_call_data = {
"tool_name": tool_data["name"],
"call_id": call_id if call_id is not None else "None",
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": result,
}
self.tool_calls.append(tool_call_data)
return result, call_id
def _simple_tool_agent(self, messages):
@@ -134,7 +145,11 @@ class Agent:
if isinstance(resp, str):
yield resp
return
if hasattr(resp, "message") and hasattr(resp.message, "content"):
if (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield resp.message.content
return
@@ -142,7 +157,11 @@ class Agent:
if isinstance(resp, str):
yield resp
elif hasattr(resp, "message") and hasattr(resp.message, "content"):
elif (
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
yield resp.message.content
else:
completion = self.llm.gen_stream(
@@ -154,6 +173,7 @@ class Agent:
return
def gen(self, messages):
self.tool_calls = []
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:

View File

@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
function_call_dict = {
"function_call": {
"name": call.function.name,
"args": call.function.arguments,
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": call.function.name,
"response": {"result": tool_response},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
except Exception as e:
messages.append(
{

View File

@@ -1,8 +1,9 @@
import tiktoken
import hashlib
from flask import jsonify, make_response
import re
import tiktoken
from flask import jsonify, make_response
_encoding = None
@@ -22,6 +23,7 @@ def num_tokens_from_string(string: str) -> int:
else:
return 0
def num_tokens_from_object_or_list(thing):
if isinstance(thing, list):
return sum([num_tokens_from_object_or_list(x) for x in thing])
@@ -32,6 +34,7 @@ def num_tokens_from_object_or_list(thing):
else:
return 0
def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
@@ -59,6 +62,7 @@ def check_required_fields(data, required_fields):
def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
"""
Limits chat history based on token count.
@@ -67,38 +71,41 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
from application.core.settings import settings
max_token_limit = (
max_token_limit
if max_token_limit and
max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.MODEL_TOKEN_LIMITS.get(
gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
max_token_limit
if max_token_limit
and max_token_limit
< settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
)
if not history:
return []
tokens_current_history = 0
trimmed_history = []
tokens_current_history = 0
for message in reversed(history):
tokens_batch = 0
if "prompt" in message and "response" in message:
tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
message["response"]
)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
tokens_batch += num_tokens_from_string(message["prompt"])
tokens_batch += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
tokens_batch += num_tokens_from_string(tool_call_string)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
return trimmed_history
def validate_function_name(function_name):
"""Validates if a function name matches the allowed pattern."""
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True
return True

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-chevron-down"><path d="m6 9 6 6 6-6"/></svg>

After

Width:  |  Height:  |  Size: 246 B

View File

@@ -0,0 +1,59 @@
import React, { useRef, useState } from 'react';
import ChevronDown from '../assets/chevron-down.svg';
type AccordionProps = {
title: string;
children: React.ReactNode;
className?: string;
titleClassName?: string;
contentClassName?: string;
open?: boolean;
};
export default function Accordion({
title,
children,
className = '',
titleClassName = '',
contentClassName = '',
open: initialOpen = false,
}: AccordionProps) {
const contentRef = useRef<HTMLDivElement>(null);
const [isOpen, setIsOpen] = useState(initialOpen);
const accordionContentStyle = {
height: isOpen ? 'auto' : '0px',
transition: 'height 0.3s ease-in-out, opacity 0.3s ease-in-out',
overflow: 'hidden',
} as React.CSSProperties;
const toggleAccordion = () => {
setIsOpen(!isOpen);
};
return (
<div className={`shadow-sm overflow-hidden ${className}`}>
<button
className={`flex items-center justify-between w-full focus:outline-none ${titleClassName}`}
onClick={toggleAccordion}
>
<p className="break-words">{title}</p>
<img
src={ChevronDown}
className={`h-5 w-5 transform transition-transform duration-200 dark:invert ${
isOpen ? 'rotate-180' : ''
}`}
aria-hidden="true"
/>
</button>
<div
ref={contentRef}
style={accordionContentStyle}
className={`px-4 ${contentClassName} ${isOpen ? 'pb-3' : 'pb-0'}`}
>
{children}
</div>
</div>
);
}

View File

@@ -40,7 +40,7 @@ export default function CoppyButton({
/>
) : (
<Copy
className="cursor-pointer fill-none"
className="w-4 cursor-pointer fill-none"
onClick={() => {
handleCopyClick(text);
}}

View File

@@ -225,6 +225,7 @@ export default function Conversation() {
message={query.response}
type={'ANSWER'}
sources={query.sources}
toolCalls={query.tool_calls}
feedback={query.feedback}
handleFeedback={(feedback: FEEDBACK) =>
handleFeedback(query, feedback, index)

View File

@@ -1,6 +1,7 @@
import 'katex/dist/katex.min.css';
import { forwardRef, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import ReactMarkdown from 'react-markdown';
import { useSelector } from 'react-redux';
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
@@ -8,27 +9,29 @@ import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism';
import rehypeKatex from 'rehype-katex';
import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math';
import { useTranslation } from 'react-i18next';
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
import ChevronDown from '../assets/chevron-down.svg';
import Dislike from '../assets/dislike.svg?react';
import Document from '../assets/document.svg';
import Edit from '../assets/edit.svg';
import Like from '../assets/like.svg?react';
import Link from '../assets/link.svg';
import Sources from '../assets/sources.svg';
import Edit from '../assets/edit.svg';
import UserIcon from '../assets/user.png';
import Accordion from '../components/Accordion';
import Avatar from '../components/Avatar';
import CopyButton from '../components/CopyButton';
import Sidebar from '../components/Sidebar';
import SpeakButton from '../components/TextToSpeechButton';
import { useOutsideAlerter } from '../hooks';
import {
selectChunks,
selectSelectedDocs,
} from '../preferences/preferenceSlice';
import classes from './ConversationBubble.module.css';
import { FEEDBACK, MESSAGE_TYPE } from './conversationModels';
import { useOutsideAlerter } from '../hooks';
import { ToolCallsType } from './types';
const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false;
@@ -41,6 +44,7 @@ const ConversationBubble = forwardRef<
feedback?: FEEDBACK;
handleFeedback?: (feedback: FEEDBACK) => void;
sources?: { title: string; text: string; source: string }[];
toolCalls?: ToolCallsType[];
retryBtn?: React.ReactElement;
questionNumber?: number;
handleUpdatedQuestionSubmission?: (
@@ -57,6 +61,7 @@ const ConversationBubble = forwardRef<
feedback,
handleFeedback,
sources,
toolCalls,
retryBtn,
questionNumber,
handleUpdatedQuestionSubmission,
@@ -307,6 +312,9 @@ const ConversationBubble = forwardRef<
</div>
)
)}
{toolCalls && toolCalls.length > 0 && (
<ToolCalls toolCalls={toolCalls} />
)}
<div className="flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
@@ -586,3 +594,88 @@ function AllSources(sources: AllSourcesProps) {
}
export default ConversationBubble;
function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
return (
<div className="mb-4 w-full flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">Tool Calls</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
/>
</button>
</div>
{isToolCallsOpen && (
<div className="fade-in ml-3 mr-5 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
<div className="grid grid-cols-1 gap-2">
{toolCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]"
titleClassName="px-6 py-2 text-sm font-semibold"
children={
<div className="flex flex-col gap-1">
<div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
<p className="flex flex-row items-center justify-between px-2 py-1 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Arguments
</span>{' '}
<CopyButton
text={JSON.stringify(toolCall.arguments, null, 2)}
/>
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span
className="text-black dark:text-gray-400 leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
</span>
</p>
</div>
<div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
<p className="flex flex-row items-center justify-between px-2 py-1 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Response
</span>{' '}
<CopyButton
text={JSON.stringify(toolCall.result, null, 2)}
/>
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span
className="text-black dark:text-gray-400 leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
</span>
</p>
</div>
</div>
}
/>
))}
</div>
</div>
)}
</div>
);
}

View File

@@ -121,6 +121,7 @@ export const SharedConversation = () => {
message={query.response}
type={'ANSWER'}
sources={query.sources ?? []}
toolCalls={query.tool_calls}
></ConversationBubble>
);
} else if (query.error) {

View File

@@ -1,6 +1,7 @@
import conversationService from '../api/services/conversationService';
import { Doc } from '../models/misc';
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
import { ToolCallsType } from './types';
export function handleFetchAnswer(
question: string,
@@ -16,6 +17,7 @@ export function handleFetchAnswer(
result: any;
answer: any;
sources: any;
toolCalls: ToolCallsType[];
conversationId: any;
query: string;
}
@@ -23,13 +25,18 @@ export function handleFetchAnswer(
result: any;
answer: any;
sources: any;
toolCalls: ToolCallsType[];
query: string;
conversationId: any;
title: any;
}
> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -60,6 +67,7 @@ export function handleFetchAnswer(
query: question,
result,
sources: data.sources,
toolCalls: data.tool_calls,
conversationId: data.conversation_id,
};
});
@@ -78,7 +86,11 @@ export function handleFetchAnswerSteaming(
indx?: number,
): Promise<Answer> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -155,7 +167,11 @@ export function handleSearch(
token_limit: number,
) {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -183,7 +199,11 @@ export function handleSearchViaApiKey(
history: Array<any> = [],
) {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
return conversationService
.search({
@@ -230,7 +250,11 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations
onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
return new Promise<Answer>((resolve, reject) => {
@@ -330,6 +354,7 @@ export function handleFetchSharedAnswer(
query: question,
result,
sources: data.sources,
toolCalls: data.tool_calls,
};
});
}

View File

@@ -1,3 +1,5 @@
import { ToolCallsType } from './types';
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
export type Status = 'idle' | 'loading' | 'failed';
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
@@ -17,9 +19,10 @@ export interface Answer {
answer: string;
query: string;
result: string;
sources: { title: string; text: string; source: string }[];
conversationId: string | null;
title: string | null;
sources: { title: string; text: string; source: string }[];
tool_calls: ToolCallsType[];
}
export interface Query {
@@ -27,10 +30,12 @@ export interface Query {
response?: string;
feedback?: FEEDBACK;
error?: string;
sources?: { title: string; text: string; source: string }[];
conversationId?: string | null;
title?: string | null;
sources?: { title: string; text: string; source: string }[];
tool_calls?: ToolCallsType[];
}
export interface RetrievalPayload {
question: string;
active_docs?: string;

View File

@@ -82,6 +82,13 @@ export const fetchAnswer = createAsyncThunk<
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
dispatch(
updateToolCalls({
index: indx ?? state.conversation.queries.length - 1,
query: { tool_calls: data.tool_calls },
}),
);
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(conversationSlice.actions.setStatus('failed'));
@@ -130,7 +137,11 @@ export const fetchAnswer = createAsyncThunk<
dispatch(
updateQuery({
index: indx ?? state.conversation.queries.length - 1,
query: { response: answer.answer, sources: sourcesPrepped },
query: {
response: answer.answer,
sources: sourcesPrepped,
tool_calls: answer.toolCalls,
},
}),
);
dispatch(
@@ -156,6 +167,7 @@ export const fetchAnswer = createAsyncThunk<
query: question,
result: '',
sources: [],
tool_calls: [],
};
});
@@ -212,6 +224,15 @@ export const conversationSlice = createSlice({
state.queries[index].sources!.push(query.sources![0]);
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = query?.tool_calls;
}
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -263,6 +284,7 @@ export const {
updateStreamingQuery,
updateConversationId,
updateStreamingSource,
updateToolCalls,
setConversation,
} = conversationSlice.actions;
export default conversationSlice.reducer;

View File

@@ -51,6 +51,13 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
dispatch(
updateToolCalls({
index: state.sharedConversation.queries.length - 1,
query: { tool_calls: data.tool_calls },
}),
);
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(sharedConversationSlice.actions.setStatus('failed'));
@@ -107,6 +114,7 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
query: question,
result: '',
sources: [],
tool_calls: [],
};
},
);
@@ -161,6 +169,15 @@ export const sharedConversationSlice = createSlice({
};
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = query?.tool_calls;
}
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -232,6 +249,7 @@ export const {
setClientApiKey,
updateQuery,
updateStreamingQuery,
updateToolCalls,
addQuery,
saveToLocalStorage,
updateStreamingSource,

View File

@@ -0,0 +1,7 @@
export type ToolCallsType = {
tool_name: string;
action_name: string;
call_id: string;
arguments: Record<string, any>;
result: Record<string, any>;
};