mirror of https://github.com/arc53/DocsGPT.git (synced 2025-11-29 08:33:20 +00:00)
Merge pull request #1630 from siiddhantt/feat/show-tool-execution
feat: tool calls tracking
@@ -88,9 +88,6 @@ def get_data_from_api_key(api_key):
     if data is None:
         raise Exception("Invalid API Key, please generate new key", 401)

     if "retriever" not in data:
         data["retriever"] = None

     if "source" in data and isinstance(data["source"], DBRef):
         source_doc = db.dereference(data["source"])
         data["source"] = str(source_doc["_id"])
@@ -117,7 +114,9 @@ def is_azure_configured():
     )


-def save_conversation(conversation_id, question, response, source_log_docs, llm,index=None):
+def save_conversation(
+    conversation_id, question, response, source_log_docs, tool_calls, llm, index=None
+):
     if conversation_id is not None and index is not None:
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
@@ -126,20 +125,14 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     f"queries.{index}.prompt": question,
                     f"queries.{index}.response": response,
                     f"queries.{index}.sources": source_log_docs,
+                    f"queries.{index}.tool_calls": tool_calls,
                 }
-            }
+            },
         )
         ##remove following queries from the array
         conversations_collection.update_one(
             {"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
-            {
-                "$push":{
-                    "queries":{
-                        "$each":[],
-                        "$slice":index+1
-                    }
-                }
-            }
+            {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
         )
     elif conversation_id is not None and conversation_id != "None":
         conversations_collection.update_one(
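In the second `update_one` above, `$push` with an empty `$each` and a positive `$slice` is the MongoDB idiom for truncating an array in place: nothing is appended, and `queries` is cut to its first `index + 1` elements, discarding any queries that followed the one being regenerated. The equivalent operation on a plain Python list (hypothetical values, for illustration only):

    # {"$push": {"queries": {"$each": [], "$slice": index + 1}}} on a list:
    queries = ["q0", "q1", "q2", "q3"]
    index = 1
    queries = queries[: index + 1]  # keep elements 0..index, drop the rest
    assert queries == ["q0", "q1"]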
@@ -150,6 +143,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     "prompt": question,
                     "response": response,
                     "sources": source_log_docs,
+                    "tool_calls": tool_calls,
                 }
             }
         },
@@ -169,11 +163,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                 "role": "user",
                 "content": "Summarise following conversation in no more than 3 words, "
                 "respond ONLY with the summary, use the same language as the "
-                "system \n\nUser: "
-                + question
-                + "\n\n"
-                + "AI: "
-                + response,
+                "system \n\nUser: " + question + "\n\n" + "AI: " + response,
             },
         ]

@@ -188,6 +178,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm,
                     "prompt": question,
                     "response": response,
                     "sources": source_log_docs,
+                    "tool_calls": tool_calls,
                 }
             ],
         }
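With all three branches updated, every persisted query now carries a `tool_calls` array next to `prompt`, `response`, and `sources`. A sketch of the resulting conversation document, with illustrative values (only the field names come from this diff):

    conversation_doc = {
        "queries": [
            {
                "prompt": "What is the capital of France?",
                "response": "Paris.",
                "sources": [{"title": "geo.md", "text": "...", "source": "geo.md"}],
                "tool_calls": [
                    {
                        "tool_name": "api_tool",          # hypothetical tool
                        "call_id": "call_abc123",
                        "action_name": "lookup_capital_42",
                        "arguments": {"country": "France"},
                        "result": "Paris",
                    }
                ],
            }
        ],
    }

The per-record keys (`tool_name`, `call_id`, `action_name`, `arguments`, `result`) match the `tool_call_data` dict built in `Agent._execute_tool_action` further down.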
@@ -208,12 +199,13 @@ def get_prompt(prompt_id):


 def complete_stream(
-    question, retriever, conversation_id, user_api_key, isNoneDoc=False,index=None
+    question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None
 ):

     try:
         response_full = ""
         source_log_docs = []
+        tool_calls = []
         answer = retriever.gen()
         sources = retriever.search()
         for source in sources:
@@ -222,6 +214,7 @@ def complete_stream(
         if len(sources) > 0:
             data = json.dumps({"type": "source", "source": sources})
             yield f"data: {data}\n\n"

         for line in answer:
             if "answer" in line:
                 response_full += str(line["answer"])
@@ -229,6 +222,10 @@ def complete_stream(
                 yield f"data: {data}\n\n"
             elif "source" in line:
                 source_log_docs.append(line["source"])
+            elif "tool_calls" in line:
+                tool_calls = line["tool_calls"]
+                data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
+                yield f"data: {data}\n\n"

         if isNoneDoc:
             for doc in source_log_docs:
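`complete_stream` now emits a `tool_calls` server-sent event alongside the existing `answer`, `source`, and `id` events. A minimal client sketch; the endpoint URL and the `answer` payload key are assumptions for illustration, only the event type names are taken from this diff:

    import json

    import requests

    resp = requests.post(
        "http://localhost:7091/stream",  # hypothetical endpoint path
        json={"question": "What does this repo do?", "history": "[]"},
        stream=True,
    )
    for raw in resp.iter_lines():
        if not raw.startswith(b"data: "):
            continue
        event = json.loads(raw[len(b"data: "):])
        if event["type"] == "answer":
            print(event.get("answer", ""), end="")       # streamed tokens
        elif event["type"] == "source":
            print("\nsources:", event["source"])         # retrieved docs
        elif event["type"] == "tool_calls":
            print("\ntool calls:", event["tool_calls"])  # new in this PR
        elif event["type"] == "id":
            print("\nconversation id:", event["id"])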
@@ -239,7 +236,13 @@ def complete_stream(
             )
         if user_api_key is None:
             conversation_id = save_conversation(
-                conversation_id, question, response_full, source_log_docs, llm,index
+                conversation_id,
+                question,
+                response_full,
+                source_log_docs,
+                tool_calls,
+                llm,
+                index,
             )
         # send data.type = "end" to indicate that the stream has ended as json
         data = json.dumps({"type": "id", "id": str(conversation_id)})
@@ -303,7 +306,7 @@ class Stream(Resource):
             "isNoneDoc": fields.Boolean(
                 required=False, description="Flag indicating if no document is used"
             ),
-            "index":fields.Integer(
+            "index": fields.Integer(
                 required=False, description="The position where query is to be updated"
             ),
         },
@@ -315,22 +318,24 @@ class Stream(Resource):
         data = request.get_json()
         required_fields = ["question"]
         if "index" in data:
-            required_fields = ["question","conversation_id"]
+            required_fields = ["question", "conversation_id"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields

         try:
             question = data["question"]
-            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
+            history = limit_chat_history(
+                json.loads(data.get("history", [])), gpt_model=gpt_model
+            )
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
-
-            index=data.get("index",None)
-
+            index = data.get("index", None)
             chunks = int(data.get("chunks", 2))
             token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
             retriever_name = data.get("retriever", "classic")

             if "api_key" in data:
                 data_key = get_data_from_api_key(data["api_key"])
                 chunks = int(data_key.get("chunks", 2))
@@ -367,7 +372,7 @@ class Stream(Resource):
                 gpt_model=gpt_model,
                 user_api_key=user_api_key,
             )

             return Response(
                 complete_stream(
                     question=question,
@@ -395,7 +400,7 @@ class Stream(Resource):
             )
             status_code = 400
         return Response(
-            error_stream_generate('Unknown error occurred'),
+            error_stream_generate("Unknown error occurred"),
             status=status_code,
             mimetype="text/event-stream",
         )
@@ -442,14 +447,16 @@ class Answer(Resource):
     @api.doc(description="Provide an answer based on the question and retriever")
     def post(self):
         data = request.get_json()
         required_fields = ["question"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields

         try:
             question = data["question"]
-            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
+            history = limit_chat_history(
+                json.loads(data.get("history", [])), gpt_model=gpt_model
+            )
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
             chunks = int(data.get("chunks", 2))
@@ -490,13 +497,16 @@ class Answer(Resource):
                 user_api_key=user_api_key,
             )

-            source_log_docs = []
             response_full = ""
+            source_log_docs = []
+            tool_calls = []
             for line in retriever.gen():
                 if "source" in line:
                     source_log_docs.append(line["source"])
                 elif "answer" in line:
                     response_full += line["answer"]
+                elif "tool_calls" in line:
+                    tool_calls.append(line["tool_calls"])

             if data.get("isNoneDoc"):
                 for doc in source_log_docs:
@@ -509,7 +519,12 @@ class Answer(Resource):
             result = {"answer": response_full, "sources": source_log_docs}
             result["conversation_id"] = str(
                 save_conversation(
-                    conversation_id, question, response_full, source_log_docs, llm
+                    conversation_id,
+                    question,
+                    response_full,
+                    source_log_docs,
+                    tool_calls,
+                    llm,
                 )
             )
             retriever_params = retriever.get_params()
@@ -1,26 +1,37 @@
+import os
+
 from pathlib import Path
 from typing import Optional
-import os

 from pydantic_settings import BaseSettings

-current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+current_dir = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)


 class Settings(BaseSettings):
     LLM_NAME: str = "docsgpt"
-    MODEL_NAME: Optional[str] = None  # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
+    MODEL_NAME: Optional[str] = (
+        None  # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
+    )
     EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
-    MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
+    MODEL_TOKEN_LIMITS: dict = {
+        "gpt-4o-mini": 128000,
+        "gpt-3.5-turbo": 4096,
+        "claude-2": 1e5,
+        "gemini-2.0-flash-exp": 1e6,
+    }
     UPLOAD_FOLDER: str = "inputs"
     PARSE_PDF_AS_IMAGE: bool = False
-    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
+    VECTOR_STORE: str = (
+        "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
+    )
     RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search

     # LLM Cache
     CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -28,12 +39,18 @@ class Settings(BaseSettings):
     API_URL: str = "http://localhost:7091"  # backend url for celery worker

     API_KEY: Optional[str] = None  # LLM api key
-    EMBEDDINGS_KEY: Optional[str] = None  # api key for embeddings (if using openai, just copy API_KEY)
+    EMBEDDINGS_KEY: Optional[str] = (
+        None  # api key for embeddings (if using openai, just copy API_KEY)
+    )
     OPENAI_API_BASE: Optional[str] = None  # azure openai api base url
     OPENAI_API_VERSION: Optional[str] = None  # azure openai api version
     AZURE_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for answering
-    AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None  # azure deployment name for embeddings
-    OPENAI_BASE_URL: Optional[str] = None  # openai base url for open ai compatable models
+    AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
+        None  # azure deployment name for embeddings
+    )
+    OPENAI_BASE_URL: Optional[str] = (
+        None  # openai base url for open ai compatable models
+    )

     # elasticsearch
     ELASTIC_CLOUD_ID: Optional[str] = None  # cloud id for elasticsearch
@@ -68,12 +85,14 @@ class Settings(BaseSettings):

     # Milvus vectorstore config
     MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
     MILVUS_URI: Optional[str] = "./milvus_local.db"  # milvus lite version as default
     MILVUS_TOKEN: Optional[str] = ""

     # LanceDB vectorstore config
     LANCEDB_PATH: str = "/tmp/lancedb"  # Path where LanceDB stores its local data
-    LANCEDB_TABLE_NAME: Optional[str] = "docsgpts"  # Name of the table to use for storing vectors
+    LANCEDB_TABLE_NAME: Optional[str] = (
+        "docsgpts"  # Name of the table to use for storing vectors
+    )
     BRAVE_SEARCH_API_KEY: Optional[str] = None

     FLASK_DEBUG_MODE: bool = False
@@ -1,3 +1,5 @@
+import json
+
 from application.core.settings import settings
 from application.llm.base import BaseLLM

@@ -15,6 +17,63 @@ class OpenAILLM(BaseLLM):
         self.api_key = api_key
         self.user_api_key = user_api_key

+    def _clean_messages_openai(self, messages):
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "model":
+                role = "assistant"
+
+            if role and content is not None:
+                if isinstance(content, str):
+                    cleaned_messages.append({"role": role, "content": content})
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            cleaned_messages.append(
+                                {"role": role, "content": item["text"]}
+                            )
+                        elif "function_call" in item:
+                            tool_call = {
+                                "id": item["function_call"]["call_id"],
+                                "type": "function",
+                                "function": {
+                                    "name": item["function_call"]["name"],
+                                    "arguments": json.dumps(
+                                        item["function_call"]["args"]
+                                    ),
+                                },
+                            }
+                            cleaned_messages.append(
+                                {
+                                    "role": "assistant",
+                                    "content": None,
+                                    "tool_calls": [tool_call],
+                                }
+                            )
+                        elif "function_response" in item:
+                            cleaned_messages.append(
+                                {
+                                    "role": "tool",
+                                    "tool_call_id": item["function_response"][
+                                        "call_id"
+                                    ],
+                                    "content": json.dumps(
+                                        item["function_response"]["response"]["result"]
+                                    ),
+                                }
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format: {item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+        return cleaned_messages
+
     def _raw_gen(
         self,
         baseself,
@@ -25,9 +84,15 @@ class OpenAILLM(BaseLLM):
         engine=settings.AZURE_DEPLOYMENT_NAME,
         **kwargs,
     ):
+        messages = self._clean_messages_openai(messages)
+        print(messages)
         if tools:
             response = self.client.chat.completions.create(
-                model=model, messages=messages, stream=stream, tools=tools, **kwargs
+                model=model,
+                messages=messages,
+                stream=stream,
+                tools=tools,
+                **kwargs,
             )
             return response.choices[0]
         else:
@@ -46,6 +111,7 @@ class OpenAILLM(BaseLLM):
         engine=settings.AZURE_DEPLOYMENT_NAME,
         **kwargs,
     ):
+        messages = self._clean_messages_openai(messages)
         response = self.client.chat.completions.create(
             model=model, messages=messages, stream=stream, **kwargs
        )
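`_clean_messages_openai` converts DocsGPT's internal message format, where tool activity is stored as `function_call` / `function_response` content parts, into the flat OpenAI chat format with `tool_calls` and `role: "tool"` messages. A before/after sketch for a single tool call (values are illustrative):

    # Internal format going in (as ClassicRAG builds it below):
    internal = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": [{"function_call": {
            "name": "calculator", "args": {"expr": "2+2"}, "call_id": "c1"}}]},
        {"role": "tool", "content": [{"function_response": {
            "name": "calculator", "response": {"result": 4}, "call_id": "c1"}}]},
    ]

    # OpenAI-style messages coming out:
    cleaned = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": None, "tool_calls": [
            {"id": "c1", "type": "function",
             "function": {"name": "calculator", "arguments": '{"expr": "2+2"}'}}]},
        {"role": "tool", "tool_call_id": "c1", "content": "4"},
    ]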
@@ -1,3 +1,5 @@
+import uuid
+
 from application.core.settings import settings
 from application.retriever.base import BaseRetriever
 from application.tools.agent import Agent
@@ -35,6 +37,12 @@ class ClassicRAG(BaseRetriever):
             )
         )
         self.user_api_key = user_api_key
+        self.agent = Agent(
+            llm_name=settings.LLM_NAME,
+            gpt_model=self.gpt_model,
+            api_key=settings.API_KEY,
+            user_api_key=self.user_api_key,
+        )

     def _get_data(self):
         if self.chunks == 0:
@@ -78,21 +86,42 @@ class ClassicRAG(BaseRetriever):
                 messages_combine.append(
                     {"role": "assistant", "content": i["response"]}
                 )
+                if "tool_calls" in i:
+                    for tool_call in i["tool_calls"]:
+                        call_id = tool_call.get("call_id")
+                        if call_id is None or call_id == "None":
+                            call_id = str(uuid.uuid4())
+
+                        function_call_dict = {
+                            "function_call": {
+                                "name": tool_call.get("action_name"),
+                                "args": tool_call.get("arguments"),
+                                "call_id": call_id,
+                            }
+                        }
+                        function_response_dict = {
+                            "function_response": {
+                                "name": tool_call.get("action_name"),
+                                "response": {"result": tool_call.get("result")},
+                                "call_id": call_id,
+                            }
+                        }
+
+                        messages_combine.append(
+                            {"role": "assistant", "content": [function_call_dict]}
+                        )
+                        messages_combine.append(
+                            {"role": "tool", "content": [function_response_dict]}
+                        )

         messages_combine.append({"role": "user", "content": self.question})
         # llm = LLMCreator.create_llm(
         #     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
         # )
         # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
-        agent = Agent(
-            llm_name=settings.LLM_NAME,
-            gpt_model=self.gpt_model,
-            api_key=settings.API_KEY,
-            user_api_key=self.user_api_key,
-        )
-        completion = agent.gen(messages_combine)
+        completion = self.agent.gen(messages_combine)

         for line in completion:
             yield {"answer": str(line)}

+        yield {"tool_calls": self.agent.tool_calls.copy()}

     def search(self):
         return self._get_data()
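The retriever now streams the answer first and then yields one final `{"tool_calls": [...]}` item; the `.copy()` snapshots the agent's list so the caller's record is not affected when a later `gen()` call resets it. A small sketch of consuming a generator with this shape (the stand-in data is hypothetical):

    def consume(retriever_gen):
        # Collect answer text and the tool-call log from a ClassicRAG-style generator.
        answer_text, tool_calls = "", []
        for item in retriever_gen:
            if "answer" in item:
                answer_text += item["answer"]
            elif "tool_calls" in item:
                tool_calls = item["tool_calls"]
        return answer_text, tool_calls

    # Stand-in generator with the same item shapes as ClassicRAG.gen():
    fake_gen = iter([{"answer": "Hello"}, {"answer": " world"}, {"tool_calls": []}])
    print(consume(fake_gen))  # ('Hello world', [])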
@@ -16,6 +16,7 @@ class Agent:
         # Static tool configuration (to be replaced later)
         self.tools = []
         self.tool_config = {}
+        self.tool_calls = []

     def _get_user_tools(self, user="local"):
         mongo = MongoDB.get_client()
@@ -123,6 +124,16 @@ class Agent:
         print(f"Executing tool: {action_name} with args: {call_args}")
         result = tool.execute_action(action_name, **parameters)
         call_id = getattr(call, "id", None)
+
+        tool_call_data = {
+            "tool_name": tool_data["name"],
+            "call_id": call_id if call_id is not None else "None",
+            "action_name": f"{action_name}_{tool_id}",
+            "arguments": call_args,
+            "result": result,
+        }
+        self.tool_calls.append(tool_call_data)
+
         return result, call_id

     def _simple_tool_agent(self, messages):
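Each executed tool action is appended to `self.tool_calls`, so after a generation pass the caller can read the full log off the agent. A usage sketch, assuming the `application` package from this repository is importable (the model name and message are made up):

    from application.core.settings import settings
    from application.tools.agent import Agent

    agent = Agent(
        llm_name=settings.LLM_NAME,
        gpt_model="gpt-4o-mini",
        api_key=settings.API_KEY,
        user_api_key=None,
    )
    messages = [{"role": "user", "content": "What is 2 + 2?"}]
    answer = "".join(str(chunk) for chunk in agent.gen(messages))  # gen() resets tool_calls
    for record in agent.tool_calls:  # one dict per executed tool action
        print(record["tool_name"], record["action_name"], record["result"])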
@@ -134,7 +145,11 @@ class Agent:
         if isinstance(resp, str):
             yield resp
             return
-        if hasattr(resp, "message") and hasattr(resp.message, "content"):
+        if (
+            hasattr(resp, "message")
+            and hasattr(resp.message, "content")
+            and resp.message.content is not None
+        ):
             yield resp.message.content
             return
@@ -142,7 +157,11 @@ class Agent:

         if isinstance(resp, str):
             yield resp
-        elif hasattr(resp, "message") and hasattr(resp.message, "content"):
+        elif (
+            hasattr(resp, "message")
+            and hasattr(resp.message, "content")
+            and resp.message.content is not None
+        ):
             yield resp.message.content
         else:
             completion = self.llm.gen_stream(
@@ -154,6 +173,7 @@ class Agent:
         return

     def gen(self, messages):
+        self.tool_calls = []
         if self.llm.supports_tools():
             resp = self._simple_tool_agent(messages)
             for line in resp:
@@ -24,13 +24,28 @@ class OpenAILLMHandler(LLMHandler):
                     tool_response, call_id = agent._execute_tool_action(
                         tools_dict, call
                     )
-                    messages.append(
-                        {
-                            "role": "tool",
-                            "content": str(tool_response),
-                            "tool_call_id": call_id,
-                        }
-                    )
+                    function_call_dict = {
+                        "function_call": {
+                            "name": call.function.name,
+                            "args": call.function.arguments,
+                            "call_id": call_id,
+                        }
+                    }
+                    function_response_dict = {
+                        "function_response": {
+                            "name": call.function.name,
+                            "response": {"result": tool_response},
+                            "call_id": call_id,
+                        }
+                    }
+
+                    messages.append(
+                        {"role": "assistant", "content": [function_call_dict]}
+                    )
+                    messages.append(
+                        {"role": "tool", "content": [function_response_dict]}
+                    )

                 except Exception as e:
                     messages.append(
                         {
@@ -1,8 +1,9 @@
-import tiktoken
 import hashlib
-from flask import jsonify, make_response
 import re
+
+import tiktoken
+from flask import jsonify, make_response


 _encoding = None
@@ -22,6 +23,7 @@ def num_tokens_from_string(string: str) -> int:
     else:
         return 0

+
 def num_tokens_from_object_or_list(thing):
     if isinstance(thing, list):
         return sum([num_tokens_from_object_or_list(x) for x in thing])
@@ -32,6 +34,7 @@ def num_tokens_from_object_or_list(thing):
     else:
         return 0

+
 def count_tokens_docs(docs):
     docs_content = ""
     for doc in docs:
@@ -59,6 +62,7 @@ def check_required_fields(data, required_fields):
 def get_hash(data):
     return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()

+
 def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
     """
     Limits chat history based on token count.
@@ -67,38 +71,41 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
     from application.core.settings import settings

     max_token_limit = (
-        max_token_limit
-        if max_token_limit and
-        max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
-            gpt_model, settings.DEFAULT_MAX_HISTORY
-        )
-        else settings.MODEL_TOKEN_LIMITS.get(
-            gpt_model, settings.DEFAULT_MAX_HISTORY
-        )
+        max_token_limit
+        if max_token_limit
+        and max_token_limit
+        < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+        else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
     )

     if not history:
         return []

-    tokens_current_history = 0
-
     trimmed_history = []
-
     tokens_current_history = 0

     for message in reversed(history):
+        tokens_batch = 0
         if "prompt" in message and "response" in message:
-            tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
-                message["response"]
-            )
-            if tokens_current_history + tokens_batch < max_token_limit:
-                tokens_current_history += tokens_batch
-                trimmed_history.insert(0, message)
-            else:
-                break
+            tokens_batch += num_tokens_from_string(message["prompt"])
+            tokens_batch += num_tokens_from_string(message["response"])
+
+            if "tool_calls" in message:
+                for tool_call in message["tool_calls"]:
+                    tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
+                    tokens_batch += num_tokens_from_string(tool_call_string)
+
+            if tokens_current_history + tokens_batch < max_token_limit:
+                tokens_current_history += tokens_batch
+                trimmed_history.insert(0, message)
+            else:
+                break

     return trimmed_history


 def validate_function_name(function_name):
     """Validates if a function name matches the allowed pattern."""
     if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
         return False
     return True
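Tool calls now count against the history token budget: each record is flattened into a `Tool: ... | Action: ... | Args: ... | Response: ...` string and passed through `num_tokens_from_string` along with the prompt and response. A worked sketch of the arithmetic (token counts are made up; the real code uses tiktoken):

    prompt_tokens = 12
    response_tokens = 40
    tool_call_tokens = 25   # num_tokens_from_string(tool_call_string), hypothetically
    tokens_batch = prompt_tokens + response_tokens + tool_call_tokens  # 77

    max_token_limit = 150   # e.g. settings.DEFAULT_MAX_HISTORY
    # The message is kept only while the running total stays under the limit;
    # the first message that would overflow stops the backwards scan.
    assert tokens_batch < max_token_limit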
frontend/src/assets/chevron-down.svg (new file, 1 line)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-chevron-down"><path d="m6 9 6 6 6-6"/></svg>
frontend/src/components/Accordion.tsx (new file, 59 lines)
@@ -0,0 +1,59 @@
+import React, { useRef, useState } from 'react';
+
+import ChevronDown from '../assets/chevron-down.svg';
+
+type AccordionProps = {
+  title: string;
+  children: React.ReactNode;
+  className?: string;
+  titleClassName?: string;
+  contentClassName?: string;
+  open?: boolean;
+};
+
+export default function Accordion({
+  title,
+  children,
+  className = '',
+  titleClassName = '',
+  contentClassName = '',
+  open: initialOpen = false,
+}: AccordionProps) {
+  const contentRef = useRef<HTMLDivElement>(null);
+  const [isOpen, setIsOpen] = useState(initialOpen);
+
+  const accordionContentStyle = {
+    height: isOpen ? 'auto' : '0px',
+    transition: 'height 0.3s ease-in-out, opacity 0.3s ease-in-out',
+    overflow: 'hidden',
+  } as React.CSSProperties;
+
+  const toggleAccordion = () => {
+    setIsOpen(!isOpen);
+  };
+  return (
+    <div className={`shadow-sm overflow-hidden ${className}`}>
+      <button
+        className={`flex items-center justify-between w-full focus:outline-none ${titleClassName}`}
+        onClick={toggleAccordion}
+      >
+        <p className="break-words">{title}</p>
+        <img
+          src={ChevronDown}
+          className={`h-5 w-5 transform transition-transform duration-200 dark:invert ${
+            isOpen ? 'rotate-180' : ''
+          }`}
+          aria-hidden="true"
+        />
+      </button>
+
+      <div
+        ref={contentRef}
+        style={accordionContentStyle}
+        className={`px-4 ${contentClassName} ${isOpen ? 'pb-3' : 'pb-0'}`}
+      >
+        {children}
+      </div>
+    </div>
+  );
+}
@@ -40,7 +40,7 @@ export default function CoppyButton({
       />
     ) : (
       <Copy
-        className="cursor-pointer fill-none"
+        className="w-4 cursor-pointer fill-none"
         onClick={() => {
           handleCopyClick(text);
         }}
@@ -225,6 +225,7 @@ export default function Conversation() {
               message={query.response}
               type={'ANSWER'}
               sources={query.sources}
+              toolCalls={query.tool_calls}
               feedback={query.feedback}
               handleFeedback={(feedback: FEEDBACK) =>
                 handleFeedback(query, feedback, index)
@@ -1,6 +1,7 @@
 import 'katex/dist/katex.min.css';

 import { forwardRef, useRef, useState } from 'react';
+import { useTranslation } from 'react-i18next';
 import ReactMarkdown from 'react-markdown';
 import { useSelector } from 'react-redux';
 import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
@@ -8,27 +9,29 @@ import { vscDarkPlus } from 'react-syntax-highlighter/dist/cjs/styles/prism';
 import rehypeKatex from 'rehype-katex';
 import remarkGfm from 'remark-gfm';
 import remarkMath from 'remark-math';
-import { useTranslation } from 'react-i18next';

 import DocsGPT3 from '../assets/cute_docsgpt3.svg';
+import ChevronDown from '../assets/chevron-down.svg';
 import Dislike from '../assets/dislike.svg?react';
 import Document from '../assets/document.svg';
+import Edit from '../assets/edit.svg';
 import Like from '../assets/like.svg?react';
 import Link from '../assets/link.svg';
 import Sources from '../assets/sources.svg';
-import Edit from '../assets/edit.svg';
 import UserIcon from '../assets/user.png';
+import Accordion from '../components/Accordion';
 import Avatar from '../components/Avatar';
 import CopyButton from '../components/CopyButton';
 import Sidebar from '../components/Sidebar';
 import SpeakButton from '../components/TextToSpeechButton';
+import { useOutsideAlerter } from '../hooks';
 import {
   selectChunks,
   selectSelectedDocs,
 } from '../preferences/preferenceSlice';
 import classes from './ConversationBubble.module.css';
 import { FEEDBACK, MESSAGE_TYPE } from './conversationModels';
-import { useOutsideAlerter } from '../hooks';
+import { ToolCallsType } from './types';

 const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false;

@@ -41,6 +44,7 @@ const ConversationBubble = forwardRef<
     feedback?: FEEDBACK;
     handleFeedback?: (feedback: FEEDBACK) => void;
     sources?: { title: string; text: string; source: string }[];
+    toolCalls?: ToolCallsType[];
     retryBtn?: React.ReactElement;
     questionNumber?: number;
     handleUpdatedQuestionSubmission?: (
@@ -57,6 +61,7 @@ const ConversationBubble = forwardRef<
     feedback,
     handleFeedback,
     sources,
+    toolCalls,
     retryBtn,
     questionNumber,
     handleUpdatedQuestionSubmission,
@@ -307,6 +312,9 @@ const ConversationBubble = forwardRef<
           </div>
         )
       )}
+      {toolCalls && toolCalls.length > 0 && (
+        <ToolCalls toolCalls={toolCalls} />
+      )}
       <div className="flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
         <div className="my-2 flex flex-row items-center justify-center gap-3">
           <Avatar
@@ -586,3 +594,88 @@ function AllSources(sources: AllSourcesProps) {
 }

 export default ConversationBubble;
+
+function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
+  const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
+  return (
+    <div className="mb-4 w-full flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
+      <div className="my-2 flex flex-row items-center justify-center gap-3">
+        <Avatar
+          className="h-[26px] w-[30px] text-xl"
+          avatar={
+            <img
+              src={Sources}
+              alt={'ToolCalls'}
+              className="h-full w-full object-fill"
+            />
+          }
+        />
+        <button
+          className="flex flex-row items-center gap-2"
+          onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
+        >
+          <p className="text-base font-semibold">Tool Calls</p>
+          <img
+            src={ChevronDown}
+            alt="ChevronDown"
+            className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
+          />
+        </button>
+      </div>
+      {isToolCallsOpen && (
+        <div className="fade-in ml-3 mr-5 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
+          <div className="grid grid-cols-1 gap-2">
+            {toolCalls.map((toolCall, index) => (
+              <Accordion
+                key={`tool-call-${index}`}
+                title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
+                className="w-full rounded-[20px] bg-gray-1000 dark:bg-gun-metal hover:bg-[#F1F1F1] dark:hover:bg-[#2C2E3C]"
+                titleClassName="px-6 py-2 text-sm font-semibold"
+                children={
+                  <div className="flex flex-col gap-1">
+                    <div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
+                      <p className="flex flex-row items-center justify-between px-2 py-1 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
+                        <span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
+                          Arguments
+                        </span>{' '}
+                        <CopyButton
+                          text={JSON.stringify(toolCall.arguments, null, 2)}
+                        />
+                      </p>
+                      <p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
+                        <span
+                          className="text-black dark:text-gray-400 leading-[23px]"
+                          style={{ fontFamily: 'IBMPlexMono-Medium' }}
+                        >
+                          {JSON.stringify(toolCall.arguments, null, 2)}
+                        </span>
+                      </p>
+                    </div>
+                    <div className="flex flex-col border border-silver dark:border-silver/20 rounded-2xl">
+                      <p className="flex flex-row items-center justify-between px-2 py-1 text-sm font-semibold bg-black/10 dark:bg-[#191919] rounded-t-2xl break-words">
+                        <span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
+                          Response
+                        </span>{' '}
+                        <CopyButton
+                          text={JSON.stringify(toolCall.result, null, 2)}
+                        />
+                      </p>
+                      <p className="p-2 font-mono text-sm dark:tex dark:bg-[#222327] rounded-b-2xl break-words">
+                        <span
+                          className="text-black dark:text-gray-400 leading-[23px]"
+                          style={{ fontFamily: 'IBMPlexMono-Medium' }}
+                        >
+                          {JSON.stringify(toolCall.result, null, 2)}
+                        </span>
+                      </p>
+                    </div>
+                  </div>
+                }
+              />
+            ))}
+          </div>
+        </div>
+      )}
+    </div>
+  );
+}
@@ -121,6 +121,7 @@ export const SharedConversation = () => {
             message={query.response}
             type={'ANSWER'}
             sources={query.sources ?? []}
+            toolCalls={query.tool_calls}
           ></ConversationBubble>
         );
       } else if (query.error) {
@@ -1,6 +1,7 @@
 import conversationService from '../api/services/conversationService';
 import { Doc } from '../models/misc';
 import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
+import { ToolCallsType } from './types';

 export function handleFetchAnswer(
   question: string,
@@ -16,6 +17,7 @@ export function handleFetchAnswer(
       result: any;
       answer: any;
       sources: any;
+      toolCalls: ToolCallsType[];
       conversationId: any;
       query: string;
     }
@@ -23,13 +25,18 @@ export function handleFetchAnswer(
       result: any;
       answer: any;
       sources: any;
+      toolCalls: ToolCallsType[];
       query: string;
       conversationId: any;
       title: any;
     }
 > {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   const payload: RetrievalPayload = {
     question: question,
@@ -60,6 +67,7 @@ export function handleFetchAnswer(
         query: question,
         result,
         sources: data.sources,
+        toolCalls: data.tool_calls,
         conversationId: data.conversation_id,
       };
     });
@@ -78,7 +86,11 @@ export function handleFetchAnswerSteaming(
   indx?: number,
 ): Promise<Answer> {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   const payload: RetrievalPayload = {
     question: question,
@@ -155,7 +167,11 @@ export function handleSearch(
   token_limit: number,
 ) {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   const payload: RetrievalPayload = {
     question: question,
@@ -183,7 +199,11 @@ export function handleSearchViaApiKey(
   history: Array<any> = [],
 ) {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });
   return conversationService
     .search({
@@ -230,7 +250,11 @@ export function handleFetchSharedAnswerStreaming( //for shared conversations
   onEvent: (event: MessageEvent) => void,
 ): Promise<Answer> {
   history = history.map((item) => {
-    return { prompt: item.prompt, response: item.response };
+    return {
+      prompt: item.prompt,
+      response: item.response,
+      tool_calls: item.tool_calls,
+    };
   });

   return new Promise<Answer>((resolve, reject) => {
@@ -330,6 +354,7 @@ export function handleFetchSharedAnswer(
       query: question,
       result,
       sources: data.sources,
+      toolCalls: data.tool_calls,
     };
   });
 }
@@ -1,3 +1,5 @@
+import { ToolCallsType } from './types';
+
 export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
 export type Status = 'idle' | 'loading' | 'failed';
 export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
@@ -17,9 +19,10 @@ export interface Answer {
   answer: string;
   query: string;
   result: string;
-  sources: { title: string; text: string; source: string }[];
   conversationId: string | null;
   title: string | null;
+  sources: { title: string; text: string; source: string }[];
+  tool_calls: ToolCallsType[];
 }

 export interface Query {
@@ -27,10 +30,12 @@ export interface Query {
   response?: string;
   feedback?: FEEDBACK;
   error?: string;
-  sources?: { title: string; text: string; source: string }[];
   conversationId?: string | null;
   title?: string | null;
+  sources?: { title: string; text: string; source: string }[];
+  tool_calls?: ToolCallsType[];
 }

 export interface RetrievalPayload {
   question: string;
   active_docs?: string;
@@ -82,6 +82,13 @@ export const fetchAnswer = createAsyncThunk<
               query: { sources: data.source ?? [] },
             }),
           );
+        } else if (data.type === 'tool_calls') {
+          dispatch(
+            updateToolCalls({
+              index: indx ?? state.conversation.queries.length - 1,
+              query: { tool_calls: data.tool_calls },
+            }),
+          );
         } else if (data.type === 'error') {
           // set status to 'failed'
           dispatch(conversationSlice.actions.setStatus('failed'));
@@ -130,7 +137,11 @@ export const fetchAnswer = createAsyncThunk<
           dispatch(
             updateQuery({
               index: indx ?? state.conversation.queries.length - 1,
-              query: { response: answer.answer, sources: sourcesPrepped },
+              query: {
+                response: answer.answer,
+                sources: sourcesPrepped,
+                tool_calls: answer.toolCalls,
+              },
             }),
           );
           dispatch(
@@ -156,6 +167,7 @@ export const fetchAnswer = createAsyncThunk<
       query: question,
      result: '',
      sources: [],
+      tool_calls: [],
    };
  });

@@ -212,6 +224,15 @@ export const conversationSlice = createSlice({
        state.queries[index].sources!.push(query.sources![0]);
      }
    },
+    updateToolCalls(
+      state,
+      action: PayloadAction<{ index: number; query: Partial<Query> }>,
+    ) {
+      const { index, query } = action.payload;
+      if (!state.queries[index].tool_calls) {
+        state.queries[index].tool_calls = query?.tool_calls;
+      }
+    },
    updateQuery(
      state,
      action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -263,6 +284,7 @@ export const {
  updateStreamingQuery,
  updateConversationId,
  updateStreamingSource,
+  updateToolCalls,
  setConversation,
} = conversationSlice.actions;
export default conversationSlice.reducer;
@@ -51,6 +51,13 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
             query: { sources: data.source ?? [] },
           }),
         );
+      } else if (data.type === 'tool_calls') {
+        dispatch(
+          updateToolCalls({
+            index: state.sharedConversation.queries.length - 1,
+            query: { tool_calls: data.tool_calls },
+          }),
+        );
       } else if (data.type === 'error') {
         // set status to 'failed'
         dispatch(sharedConversationSlice.actions.setStatus('failed'));
@@ -107,6 +114,7 @@ export const fetchSharedAnswer = createAsyncThunk<Answer, { question: string }>(
       query: question,
       result: '',
       sources: [],
+      tool_calls: [],
     };
   },
 );
@@ -161,6 +169,15 @@ export const sharedConversationSlice = createSlice({
         };
       }
     },
+    updateToolCalls(
+      state,
+      action: PayloadAction<{ index: number; query: Partial<Query> }>,
+    ) {
+      const { index, query } = action.payload;
+      if (!state.queries[index].tool_calls) {
+        state.queries[index].tool_calls = query?.tool_calls;
+      }
+    },
     updateQuery(
       state,
       action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -232,6 +249,7 @@ export const {
   setClientApiKey,
   updateQuery,
   updateStreamingQuery,
+  updateToolCalls,
   addQuery,
   saveToLocalStorage,
   updateStreamingSource,
frontend/src/conversation/types/index.ts (new file, 7 lines)
@@ -0,0 +1,7 @@
+export type ToolCallsType = {
+  tool_name: string;
+  action_name: string;
+  call_id: string;
+  arguments: Record<string, any>;
+  result: Record<string, any>;
+};