feat: add tool calls tracking and show in frontend

This commit is contained in:
Siddhant Rai
2025-02-12 21:47:47 +05:30
parent 0de4241b56
commit e209699b19
13 changed files with 302 additions and 51 deletions

View File

@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
if "retriever" not in data:
data["retriever"] = None
if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
)
def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
def save_conversation(
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
):
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
}
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$push":{
"queries":{
"$each":[],
"$slice":index+1
}
}
}
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
}
},
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: "
+ question
+ "\n\n"
+ "AI: "
+ response,
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
"prompt": question,
"response": response,
"sources": source_log_docs,
"tool_calls": tool_calls,
}
],
}
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):
def complete_stream(
question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
):
try:
try:
response_full = ""
source_log_docs = []
tool_calls = []
answer = retriever.gen()
sources = retriever.search()
for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
if len(sources) > 0:
data = json.dumps({"type": "source", "source": sources})
yield f"data: {data}\n\n"
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
@@ -239,7 +236,13 @@ def complete_stream(
)
if user_api_key is None:
conversation_id = save_conversation(
conversation_id, question, response_full, source_log_docs, llm,index
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
index,
)
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -303,7 +306,7 @@ class Stream(Resource):
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index":fields.Integer(
"index": fields.Integer(
required=False, description="The position where query is to be updated"
),
},
@@ -315,22 +318,24 @@ class Stream(Resource):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question","conversation_id"]
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
index=data.get("index",None)
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
gpt_model=gpt_model,
user_api_key=user_api_key,
)
return Response(
complete_stream(
question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
)
status_code = 400
return Response(
error_stream_generate('Unknown error occurred'),
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
@@ -442,14 +447,16 @@ class Answer(Resource):
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
user_api_key=user_api_key,
)
source_log_docs = []
response_full = ""
source_log_docs = []
tool_calls = []
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]
elif "tool_calls" in line:
tool_calls.append(line["tool_calls"])
if data.get("isNoneDoc"):
for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id, question, response_full, source_log_docs, llm
conversation_id,
question,
response_full,
source_log_docs,
tool_calls,
llm,
)
)
retriever_params = retriever.get_params()

View File

@@ -35,6 +35,12 @@ class ClassicRAG(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
def _get_data(self):
if self.chunks == 0:
@@ -78,20 +84,24 @@ class ClassicRAG(BaseRetriever):
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
messages_combine.append(
{
"role": "assistant",
"content": f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}",
}
)
messages_combine.append({"role": "user", "content": self.question})
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = agent.gen(messages_combine)
completion = self.agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):
return self._get_data()

View File

@@ -16,6 +16,7 @@ class Agent:
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
self.tool_calls = []
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
@@ -123,6 +124,15 @@ class Agent:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
call_id = getattr(call, "id", None)
tool_call_data = {
"tool_name": tool_data["name"],
"action_name": action_name,
"arguments": str(call_args),
"result": str(result),
}
self.tool_calls.append(tool_call_data)
return result, call_id
def _simple_tool_agent(self, messages):
@@ -154,6 +164,7 @@ class Agent:
return
def gen(self, messages):
self.tool_calls = []
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-chevron-down"><path d="m6 9 6 6 6-6"/></svg>

After

Width:  |  Height:  |  Size: 246 B

View File

@@ -0,0 +1,59 @@
import React, { useRef, useState } from 'react';
import ChevronDown from '../assets/chevron-down.svg';
type AccordionProps = {
title: string;
children: React.ReactNode;
className?: string;
titleClassName?: string;
contentClassName?: string;
open?: boolean;
};
export default function Accordion({
title,
children,
className = '',
titleClassName = '',
contentClassName = '',
open: initialOpen = false,
}: AccordionProps) {
const contentRef = useRef<HTMLDivElement>(null);
const [isOpen, setIsOpen] = useState(initialOpen);
const accordionContentStyle = {
height: isOpen ? 'auto' : '0px',
transition: 'height 0.3s ease-in-out, opacity 0.3s ease-in-out',
overflow: 'hidden',
} as React.CSSProperties;
const toggleAccordion = () => {
setIsOpen(!isOpen);
};
return (
<div className={`shadow-sm overflow-hidden ${className}`}>
<button
className={`flex items-center justify-between w-full focus:outline-none ${titleClassName}`}
onClick={toggleAccordion}
>
<p className="break-words">{title}</p>
<img
src={ChevronDown}
className={`h-5 w-5 transform transition-transform duration-200 dark:invert ${
isOpen ? 'rotate-180' : ''
}`}
aria-hidden="true"
/>
</button>
<div
ref={contentRef}
style={accordionContentStyle}
className={`px-4 ${contentClassName} ${isOpen ? 'pb-3' : 'pb-0'}`}
>
{children}
</div>
</div>
);
}

View File

@@ -225,6 +225,7 @@ export default function Conversation() {
message={query.response}
type={'ANSWER'}
sources={query.sources}
toolCalls={query.tool_calls}
feedback={query.feedback}
handleFeedback={(feedback: FEEDBACK) =>
handleFeedback(query, feedback, index)

View File

@@ -1,6 +1,7 @@
import 'katex/dist/katex.min.css';
import { forwardRef, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import ReactMarkdown from 'react-markdown';
import { useSelector } from 'react-redux';
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
@@ -8,27 +9,29 @@ import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism';
import rehypeKatex from 'rehype-katex';
import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math';
import { useTranslation } from 'react-i18next';
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
import ChevronDown from '../assets/chevron-down.svg';
import Dislike from '../assets/dislike.svg?react';
import Document from '../assets/document.svg';
import Edit from '../assets/edit.svg';
import Like from '../assets/like.svg?react';
import Link from '../assets/link.svg';
import Sources from '../assets/sources.svg';
import Edit from '../assets/edit.svg';
import UserIcon from '../assets/user.png';
import Accordion from '../components/Accordion';
import Avatar from '../components/Avatar';
import CopyButton from '../components/CopyButton';
import Sidebar from '../components/Sidebar';
import SpeakButton from '../components/TextToSpeechButton';
import { useOutsideAlerter } from '../hooks';
import {
selectChunks,
selectSelectedDocs,
} from '../preferences/preferenceSlice';
import classes from './ConversationBubble.module.css';
import { FEEDBACK, MESSAGE_TYPE } from './conversationModels';
import { useOutsideAlerter } from '../hooks';
import { ToolCallsType } from './types';
const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false;
@@ -41,6 +44,7 @@ const ConversationBubble = forwardRef<
feedback?: FEEDBACK;
handleFeedback?: (feedback: FEEDBACK) => void;
sources?: { title: string; text: string; source: string }[];
toolCalls?: ToolCallsType[];
retryBtn?: React.ReactElement;
questionNumber?: number;
handleUpdatedQuestionSubmission?: (
@@ -57,6 +61,7 @@ const ConversationBubble = forwardRef<
feedback,
handleFeedback,
sources,
toolCalls,
retryBtn,
questionNumber,
handleUpdatedQuestionSubmission,
@@ -307,6 +312,9 @@ const ConversationBubble = forwardRef<
</div>
)
)}
{toolCalls && toolCalls.length > 0 && (
<ToolCalls toolCalls={toolCalls} />
)}
<div className="flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
@@ -586,3 +594,72 @@ function AllSources(sources: AllSourcesProps) {
}
export default ConversationBubble;
function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
return (
<div className="mb-4 w-full flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">Tool Calls</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
/>
</button>
</div>
{isToolCallsOpen && (
<div className="fade-in ml-3 mr-5 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
<div className="grid grid-cols-1 gap-2">
{toolCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name}`}
className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]"
titleClassName="px-4 py-2 text-sm font-semibold"
children={
<div className="flex flex-col gap-1">
<div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
<p className="p-2 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
Arguments
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span className="text-black dark:text-gray-400">
{toolCall.arguments}
</span>
</p>
</div>
<div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
<p className="p-2 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
Response
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span className="text-black dark:text-gray-400">
{toolCall.result}
</span>
</p>
</div>
</div>
}
/>
))}
</div>
</div>
)}
</div>
);
}

View File

@@ -121,6 +121,7 @@ export const SharedConversation = () => {
message={query.response}
type={'ANSWER'}
sources={query.sources ?? []}
toolCalls={query.tool_calls}
></ConversationBubble>
);
} else if (query.error) {

View File

@@ -1,6 +1,7 @@
import conversationService from '../api/services/conversationService';
import { Doc } from '../models/misc';
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
import { ToolCallsType } from './types';
export function handleFetchAnswer(
question: string,
@@ -16,6 +17,7 @@ export function handleFetchAnswer(
result: any;
answer: any;
sources: any;
toolCalls: ToolCallsType[];
conversationId: any;
query: string;
}
@@ -23,13 +25,18 @@ export function handleFetchAnswer(
result: any;
answer: any;
sources: any;
toolCalls: ToolCallsType[];
query: string;
conversationId: any;
title: any;
}
> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -60,6 +67,7 @@ export function handleFetchAnswer(
query: question,
result,
sources: data.sources,
toolCalls: data.tool_calls,
conversationId: data.conversation_id,
};
});
@@ -78,7 +86,11 @@ export function handleFetchAnswerSteaming(
indx?: number,
): Promise<Answer> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -155,7 +167,11 @@ export function handleSearch(
token_limit: number,
) {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -183,7 +199,11 @@ export function handleSearchViaApiKey(
history: Array<any> = [],
) {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
return conversationService
.search({
@@ -230,7 +250,11 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations
onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
return new Promise<Answer>((resolve, reject) => {
@@ -330,6 +354,7 @@ export function handleFetchSharedAnswer(
query: question,
result,
sources: data.sources,
toolCalls: data.tool_calls,
};
});
}

View File

@@ -1,3 +1,5 @@
import { ToolCallsType } from './types';
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
export type Status = 'idle' | 'loading' | 'failed';
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
@@ -17,9 +19,10 @@ export interface Answer {
answer: string;
query: string;
result: string;
sources: { title: string; text: string; source: string }[];
conversationId: string | null;
title: string | null;
sources: { title: string; text: string; source: string }[];
tool_calls: ToolCallsType[];
}
export interface Query {
@@ -27,10 +30,12 @@ export interface Query {
response?: string;
feedback?: FEEDBACK;
error?: string;
sources?: { title: string; text: string; source: string }[];
conversationId?: string | null;
title?: string | null;
sources?: { title: string; text: string; source: string }[];
tool_calls?: ToolCallsType[];
}
export interface RetrievalPayload {
question: string;
active_docs?: string;

View File

@@ -82,6 +82,13 @@ export const fetchAnswer = createAsyncThunk<
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
dispatch(
updateToolCalls({
index: indx ?? state.conversation.queries.length - 1,
query: { tool_calls: data.tool_calls },
}),
);
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(conversationSlice.actions.setStatus('failed'));
@@ -130,7 +137,11 @@ export const fetchAnswer = createAsyncThunk<
dispatch(
updateQuery({
index: indx ?? state.conversation.queries.length - 1,
query: { response: answer.answer, sources: sourcesPrepped },
query: {
response: answer.answer,
sources: sourcesPrepped,
tool_calls: answer.toolCalls,
},
}),
);
dispatch(
@@ -156,6 +167,7 @@ export const fetchAnswer = createAsyncThunk<
query: question,
result: '',
sources: [],
tool_calls: [],
};
});
@@ -212,6 +224,15 @@ export const conversationSlice = createSlice({
state.queries[index].sources!.push(query.sources![0]);
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = query?.tool_calls;
}
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -263,6 +284,7 @@ export const {
updateStreamingQuery,
updateConversationId,
updateStreamingSource,
updateToolCalls,
setConversation,
} = conversationSlice.actions;
export default conversationSlice.reducer;

View File

@@ -51,6 +51,13 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
dispatch(
updateToolCalls({
index: state.sharedConversation.queries.length - 1,
query: { tool_calls: data.tool_calls },
}),
);
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(sharedConversationSlice.actions.setStatus('failed'));
@@ -107,6 +114,7 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
query: question,
result: '',
sources: [],
tool_calls: [],
};
},
);
@@ -161,6 +169,15 @@ export const sharedConversationSlice = createSlice({
};
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = query?.tool_calls;
}
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -232,6 +249,7 @@ export const {
setClientApiKey,
updateQuery,
updateStreamingQuery,
updateToolCalls,
addQuery,
saveToLocalStorage,
updateStreamingSource,

View File

@@ -0,0 +1,6 @@
export type ToolCallsType = {
tool_name: string;
action_name: string;
arguments: string;
result: string;
};