feat: add tool calls tracking and show in frontend

This commit is contained in:
Siddhant Rai
2025-02-12 21:47:47 +05:30
parent 0de4241b56
commit e209699b19
13 changed files with 302 additions and 51 deletions

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-chevron-down"><path d="m6 9 6 6 6-6"/></svg>

After

Width:  |  Height:  |  Size: 246 B

View File

@@ -0,0 +1,59 @@
import React, { useRef, useState } from 'react';
import ChevronDown from '../assets/chevron-down.svg';
type AccordionProps = {
title: string;
children: React.ReactNode;
className?: string;
titleClassName?: string;
contentClassName?: string;
open?: boolean;
};
export default function Accordion({
title,
children,
className = '',
titleClassName = '',
contentClassName = '',
open: initialOpen = false,
}: AccordionProps) {
const contentRef = useRef<HTMLDivElement>(null);
const [isOpen, setIsOpen] = useState(initialOpen);
const accordionContentStyle = {
height: isOpen ? 'auto' : '0px',
transition: 'height 0.3s ease-in-out, opacity 0.3s ease-in-out',
overflow: 'hidden',
} as React.CSSProperties;
const toggleAccordion = () => {
setIsOpen(!isOpen);
};
return (
<div className={`shadow-sm overflow-hidden ${className}`}>
<button
className={`flex items-center justify-between w-full focus:outline-none ${titleClassName}`}
onClick={toggleAccordion}
>
<p className="break-words">{title}</p>
<img
src={ChevronDown}
className={`h-5 w-5 transform transition-transform duration-200 dark:invert ${
isOpen ? 'rotate-180' : ''
}`}
aria-hidden="true"
/>
</button>
<div
ref={contentRef}
style={accordionContentStyle}
className={`px-4 ${contentClassName} ${isOpen ? 'pb-3' : 'pb-0'}`}
>
{children}
</div>
</div>
);
}

View File

@@ -225,6 +225,7 @@ export default function Conversation() {
message={query.response}
type={'ANSWER'}
sources={query.sources}
toolCalls={query.tool_calls}
feedback={query.feedback}
handleFeedback={(feedback: FEEDBACK) =>
handleFeedback(query, feedback, index)

View File

@@ -1,6 +1,7 @@
import 'katex/dist/katex.min.css';
import { forwardRef, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import ReactMarkdown from 'react-markdown';
import { useSelector } from 'react-redux';
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
@@ -8,27 +9,29 @@ import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism';
import rehypeKatex from 'rehype-katex';
import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math';
import { useTranslation } from 'react-i18next';
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
import ChevronDown from '../assets/chevron-down.svg';
import Dislike from '../assets/dislike.svg?react';
import Document from '../assets/document.svg';
import Edit from '../assets/edit.svg';
import Like from '../assets/like.svg?react';
import Link from '../assets/link.svg';
import Sources from '../assets/sources.svg';
import Edit from '../assets/edit.svg';
import UserIcon from '../assets/user.png';
import Accordion from '../components/Accordion';
import Avatar from '../components/Avatar';
import CopyButton from '../components/CopyButton';
import Sidebar from '../components/Sidebar';
import SpeakButton from '../components/TextToSpeechButton';
import { useOutsideAlerter } from '../hooks';
import {
selectChunks,
selectSelectedDocs,
} from '../preferences/preferenceSlice';
import classes from './ConversationBubble.module.css';
import { FEEDBACK, MESSAGE_TYPE } from './conversationModels';
import { useOutsideAlerter } from '../hooks';
import { ToolCallsType } from './types';
const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false;
@@ -41,6 +44,7 @@ const ConversationBubble = forwardRef<
feedback?: FEEDBACK;
handleFeedback?: (feedback: FEEDBACK) => void;
sources?: { title: string; text: string; source: string }[];
toolCalls?: ToolCallsType[];
retryBtn?: React.ReactElement;
questionNumber?: number;
handleUpdatedQuestionSubmission?: (
@@ -57,6 +61,7 @@ const ConversationBubble = forwardRef<
feedback,
handleFeedback,
sources,
toolCalls,
retryBtn,
questionNumber,
handleUpdatedQuestionSubmission,
@@ -307,6 +312,9 @@ const ConversationBubble = forwardRef<
</div>
)
)}
{toolCalls && toolCalls.length > 0 && (
<ToolCalls toolCalls={toolCalls} />
)}
<div className="flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
@@ -586,3 +594,72 @@ function AllSources(sources: AllSourcesProps) {
}
export default ConversationBubble;
function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
return (
<div className="mb-4 w-full flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">Tool Calls</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
/>
</button>
</div>
{isToolCallsOpen && (
<div className="fade-in ml-3 mr-5 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
<div className="grid grid-cols-1 gap-2">
{toolCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name}`}
className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]"
titleClassName="px-4 py-2 text-sm font-semibold"
children={
<div className="flex flex-col gap-1">
<div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
<p className="p-2 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
Arguments
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span className="text-black dark:text-gray-400">
{toolCall.arguments}
</span>
</p>
</div>
<div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
<p className="p-2 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
Response
</p>
<p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
<span className="text-black dark:text-gray-400">
{toolCall.result}
</span>
</p>
</div>
</div>
}
/>
))}
</div>
</div>
)}
</div>
);
}

View File

@@ -121,6 +121,7 @@ export const SharedConversation = () => {
message={query.response}
type={'ANSWER'}
sources={query.sources ?? []}
toolCalls={query.tool_calls}
></ConversationBubble>
);
} else if (query.error) {

View File

@@ -1,6 +1,7 @@
import conversationService from '../api/services/conversationService';
import { Doc } from '../models/misc';
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
import { ToolCallsType } from './types';
export function handleFetchAnswer(
question: string,
@@ -16,6 +17,7 @@ export function handleFetchAnswer(
result: any;
answer: any;
sources: any;
toolCalls: ToolCallsType[];
conversationId: any;
query: string;
}
@@ -23,13 +25,18 @@ export function handleFetchAnswer(
result: any;
answer: any;
sources: any;
toolCalls: ToolCallsType[];
query: string;
conversationId: any;
title: any;
}
> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -60,6 +67,7 @@ export function handleFetchAnswer(
query: question,
result,
sources: data.sources,
toolCalls: data.tool_calls,
conversationId: data.conversation_id,
};
});
@@ -78,7 +86,11 @@ export function handleFetchAnswerSteaming(
indx?: number,
): Promise<Answer> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -155,7 +167,11 @@ export function handleSearch(
token_limit: number,
) {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
const payload: RetrievalPayload = {
question: question,
@@ -183,7 +199,11 @@ export function handleSearchViaApiKey(
history: Array<any> = [],
) {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
return conversationService
.search({
@@ -230,7 +250,11 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations
onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
history = history.map((item) => {
return { prompt: item.prompt, response: item.response };
return {
prompt: item.prompt,
response: item.response,
tool_calls: item.tool_calls,
};
});
return new Promise<Answer>((resolve, reject) => {
@@ -330,6 +354,7 @@ export function handleFetchSharedAnswer(
query: question,
result,
sources: data.sources,
toolCalls: data.tool_calls,
};
});
}

View File

@@ -1,3 +1,5 @@
import { ToolCallsType } from './types';
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
export type Status = 'idle' | 'loading' | 'failed';
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
@@ -17,9 +19,10 @@ export interface Answer {
answer: string;
query: string;
result: string;
sources: { title: string; text: string; source: string }[];
conversationId: string | null;
title: string | null;
sources: { title: string; text: string; source: string }[];
tool_calls: ToolCallsType[];
}
export interface Query {
@@ -27,10 +30,12 @@ export interface Query {
response?: string;
feedback?: FEEDBACK;
error?: string;
sources?: { title: string; text: string; source: string }[];
conversationId?: string | null;
title?: string | null;
sources?: { title: string; text: string; source: string }[];
tool_calls?: ToolCallsType[];
}
export interface RetrievalPayload {
question: string;
active_docs?: string;

View File

@@ -82,6 +82,13 @@ export const fetchAnswer = createAsyncThunk<
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
dispatch(
updateToolCalls({
index: indx ?? state.conversation.queries.length - 1,
query: { tool_calls: data.tool_calls },
}),
);
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(conversationSlice.actions.setStatus('failed'));
@@ -130,7 +137,11 @@ export const fetchAnswer = createAsyncThunk<
dispatch(
updateQuery({
index: indx ?? state.conversation.queries.length - 1,
query: { response: answer.answer, sources: sourcesPrepped },
query: {
response: answer.answer,
sources: sourcesPrepped,
tool_calls: answer.toolCalls,
},
}),
);
dispatch(
@@ -156,6 +167,7 @@ export const fetchAnswer = createAsyncThunk<
query: question,
result: '',
sources: [],
tool_calls: [],
};
});
@@ -212,6 +224,15 @@ export const conversationSlice = createSlice({
state.queries[index].sources!.push(query.sources![0]);
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = query?.tool_calls;
}
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -263,6 +284,7 @@ export const {
updateStreamingQuery,
updateConversationId,
updateStreamingSource,
updateToolCalls,
setConversation,
} = conversationSlice.actions;
export default conversationSlice.reducer;

View File

@@ -51,6 +51,13 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_calls') {
dispatch(
updateToolCalls({
index: state.sharedConversation.queries.length - 1,
query: { tool_calls: data.tool_calls },
}),
);
} else if (data.type === 'error') {
// set status to 'failed'
dispatch(sharedConversationSlice.actions.setStatus('failed'));
@@ -107,6 +114,7 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
query: question,
result: '',
sources: [],
tool_calls: [],
};
},
);
@@ -161,6 +169,15 @@ export const sharedConversationSlice = createSlice({
};
}
},
updateToolCalls(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = query?.tool_calls;
}
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -232,6 +249,7 @@ export const {
setClientApiKey,
updateQuery,
updateStreamingQuery,
updateToolCalls,
addQuery,
saveToLocalStorage,
updateStreamingSource,

View File

@@ -0,0 +1,6 @@
export type ToolCallsType = {
tool_name: string;
action_name: string;
arguments: string;
result: string;
};