From 92c3c707e12ec60e203cf22c91c562629fa81041 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 2 Jun 2025 12:17:29 +0530 Subject: [PATCH 01/10] refactor: reorganize LLM handler structure and improve tool call parsing --- application/agents/base.py | 16 +- application/agents/llm_handler.py | 351 ------------------ .../agents/tools/tool_action_parser.py | 33 +- application/llm/handlers/__init__.py | 0 application/llm/handlers/base.py | 317 ++++++++++++++++ application/llm/handlers/google.py | 78 ++++ application/llm/handlers/handler_creator.py | 18 + application/llm/handlers/openai.py | 57 +++ 8 files changed, 494 insertions(+), 376 deletions(-) delete mode 100644 application/agents/llm_handler.py create mode 100644 application/llm/handlers/__init__.py create mode 100644 application/llm/handlers/base.py create mode 100644 application/llm/handlers/google.py create mode 100644 application/llm/handlers/handler_creator.py create mode 100644 application/llm/handlers/openai.py diff --git a/application/agents/base.py b/application/agents/base.py index d44244cf..a4bbd001 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -2,16 +2,18 @@ import uuid from abc import ABC, abstractmethod from typing import Dict, Generator, List, Optional -from application.agents.llm_handler import get_llm_handler +from bson.objectid import ObjectId + from application.agents.tools.tool_action_parser import ToolActionParser from application.agents.tools.tool_manager import ToolManager from application.core.mongo_db import MongoDB +from application.core.settings import settings + +from application.llm.handlers.handler_creator import LLMHandlerCreator from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever -from application.core.settings import settings -from bson.objectid import ObjectId class BaseAgent(ABC): @@ -45,7 +47,9 @@ class BaseAgent(ABC): user_api_key=user_api_key, decoded_token=decoded_token, ) - self.llm_handler = get_llm_handler(llm_name) + self.llm_handler = LLMHandlerCreator.create_handler( + llm_name if llm_name else "default" + ) self.attachments = attachments or [] @log_activity() @@ -268,8 +272,8 @@ class BaseAgent(ABC): log_context: Optional[LogContext] = None, attachments: Optional[List[Dict]] = None, ): - resp = self.llm_handler.handle_response( - self, resp, tools_dict, messages, attachments + resp = self.llm_handler.process_message_flow( + self, resp, tools_dict, messages, attachments, True ) if log_context: data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"]) diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py deleted file mode 100644 index 1b995f71..00000000 --- a/application/agents/llm_handler.py +++ /dev/null @@ -1,351 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod - -from application.logging import build_stack_data - -logger = logging.getLogger(__name__) - - -class LLMHandler(ABC): - def __init__(self): - self.llm_calls = [] - self.tool_calls = [] - - @abstractmethod - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs): - pass - - def prepare_messages_with_attachments(self, agent, messages, attachments=None): - """ - Prepare messages with attachment content if available. - - Args: - agent: The current agent instance. - messages (list): List of message dictionaries. - attachments (list): List of attachment dictionaries with content. - - Returns: - list: Messages with attachment context added to the system prompt. - """ - if not attachments: - return messages - - logger.info(f"Preparing messages with {len(attachments)} attachments") - - supported_types = agent.llm.get_supported_attachment_types() - - supported_attachments = [] - unsupported_attachments = [] - - for attachment in attachments: - mime_type = attachment.get('mime_type') - if mime_type in supported_types: - supported_attachments.append(attachment) - else: - unsupported_attachments.append(attachment) - - # Process supported attachments with the LLM's custom method - prepared_messages = messages - if supported_attachments: - logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method") - prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments) - - # Process unsupported attachments with the default method - if unsupported_attachments: - logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method") - prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments) - - return prepared_messages - - def _append_attachment_content_to_system(self, messages, attachments): - """ - Default method to append attachment content to the system prompt. - - Args: - messages (list): List of message dictionaries. - attachments (list): List of attachment dictionaries with content. - - Returns: - list: Messages with attachment context added to the system prompt. - """ - prepared_messages = messages.copy() - - attachment_texts = [] - for attachment in attachments: - logger.info(f"Adding attachment {attachment.get('id')} to context") - if 'content' in attachment: - attachment_texts.append(f"Attached file content:\n\n{attachment['content']}") - - if attachment_texts: - combined_attachment_text = "\n\n".join(attachment_texts) - - system_found = False - for i in range(len(prepared_messages)): - if prepared_messages[i].get("role") == "system": - prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}" - system_found = True - break - - if not system_found: - prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text}) - - return prepared_messages - -class OpenAILLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - logger.info(f"Messages with attachments: {messages}") - if not stream: - while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": - message = json.loads(resp.model_dump_json())["message"] - keys_to_remove = {"audio", "function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) - - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - function_call_dict = { - "function_call": { - "name": call.function.name, - "args": call.function.arguments, - "call_id": call_id, - } - } - function_response_dict = { - "function_response": { - "name": call.function.name, - "response": {"result": tool_response}, - "call_id": call_id, - } - } - - messages.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages.append( - {"role": "tool", "content": [function_response_dict]} - ) - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - except Exception as e: - logging.error(f"Error executing tool: {str(e)}", exc_info=True) - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call_id, - } - ) - resp = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - return resp - - else: - text_buffer = "" - while True: - tool_calls = {} - for chunk in resp: - if isinstance(chunk, str) and len(chunk) > 0: - yield chunk - continue - elif hasattr(chunk, "delta"): - chunk_delta = chunk.delta - - if ( - hasattr(chunk_delta, "tool_calls") - and chunk_delta.tool_calls is not None - ): - for tool_call in chunk_delta.tool_calls: - index = tool_call.index - if index not in tool_calls: - tool_calls[index] = { - "id": "", - "function": {"name": "", "arguments": ""}, - } - - current = tool_calls[index] - if tool_call.id: - current["id"] = tool_call.id - if tool_call.function.name: - current["function"][ - "name" - ] = tool_call.function.name - if tool_call.function.arguments: - current["function"][ - "arguments" - ] += tool_call.function.arguments - tool_calls[index] = current - - if ( - hasattr(chunk, "finish_reason") - and chunk.finish_reason == "tool_calls" - ): - for index in sorted(tool_calls.keys()): - call = tool_calls[index] - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - if isinstance(call["function"]["arguments"], str): - call["function"]["arguments"] = json.loads(call["function"]["arguments"]) - - function_call_dict = { - "function_call": { - "name": call["function"]["name"], - "args": call["function"]["arguments"], - "call_id": call["id"], - } - } - function_response_dict = { - "function_response": { - "name": call["function"]["name"], - "response": {"result": tool_response}, - "call_id": call["id"], - } - } - - messages.append( - { - "role": "assistant", - "content": [function_call_dict], - } - ) - messages.append( - { - "role": "tool", - "content": [function_response_dict], - } - ) - - except Exception as e: - logging.error(f"Error executing tool: {str(e)}", exc_info=True) - messages.append( - { - "role": "assistant", - "content": f"Error executing tool: {str(e)}", - } - ) - tool_calls = {} - if hasattr(chunk_delta, "content") and chunk_delta.content: - # Add to buffer or yield immediately based on your preference - text_buffer += chunk_delta.content - yield text_buffer - text_buffer = "" - - if ( - hasattr(chunk, "finish_reason") - and chunk.finish_reason == "stop" - ): - return resp - elif isinstance(chunk, str) and len(chunk) == 0: - continue - - logger.info(f"Regenerating with messages: {messages}") - resp = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - - -class GoogleLLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): - from google.genai import types - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - - while True: - if not stream: - response = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - if response.candidates and response.candidates[0].content.parts: - tool_call_found = False - for part in response.candidates[0].content.parts: - if part.function_call: - tool_call_found = True - self.tool_calls.append(part.function_call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, part.function_call - ) - function_response_part = types.Part.from_function_response( - name=part.function_call.name, - response={"result": tool_response}, - ) - - messages.append( - {"role": "model", "content": [part.to_json_dict()]} - ) - messages.append( - { - "role": "tool", - "content": [function_response_part.to_json_dict()], - } - ) - - if ( - not tool_call_found - and response.candidates[0].content.parts - and response.candidates[0].content.parts[0].text - ): - return response.candidates[0].content.parts[0].text - elif not tool_call_found: - return response.candidates[0].content.parts - - else: - return response - - else: - response = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - - tool_call_found = False - for result in response: - if hasattr(result, "function_call"): - tool_call_found = True - self.tool_calls.append(result.function_call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, result.function_call - ) - function_response_part = types.Part.from_function_response( - name=result.function_call.name, - response={"result": tool_response}, - ) - - messages.append( - {"role": "model", "content": [result.to_json_dict()]} - ) - messages.append( - { - "role": "tool", - "content": [function_response_part.to_json_dict()], - } - ) - else: - tool_call_found = False - yield result - - if not tool_call_found: - return response - - -def get_llm_handler(llm_type): - handlers = { - "openai": OpenAILLMHandler(), - "google": GoogleLLMHandler(), - } - return handlers.get(llm_type, OpenAILLMHandler()) diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index c7da5a4c..0589ac88 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -17,26 +17,21 @@ class ToolActionParser: return parser(call) def _parse_openai_llm(self, call): - if isinstance(call, dict): - try: - call_args = json.loads(call["function"]["arguments"]) - tool_id = call["function"]["name"].split("_")[-1] - action_name = call["function"]["name"].rsplit("_", 1)[0] - except (KeyError, TypeError) as e: - logger.error(f"Error parsing OpenAI LLM call: {e}") - return None, None, None - else: - try: - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] - except (AttributeError, TypeError) as e: - logger.error(f"Error parsing OpenAI LLM call: {e}") - return None, None, None + try: + call_args = json.loads(call.arguments) + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") + return None, None, None return tool_id, action_name, call_args def _parse_google_llm(self, call): - call_args = call.args - tool_id = call.name.split("_")[-1] - action_name = call.name.rsplit("_", 1)[0] + try: + call_args = call.arguments + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing Google LLM call: {e}") + return None, None, None return tool_id, action_name, call_args diff --git a/application/llm/handlers/__init__.py b/application/llm/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py new file mode 100644 index 00000000..ede7cec3 --- /dev/null +++ b/application/llm/handlers/base.py @@ -0,0 +1,317 @@ +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generator, List, Optional, Union + +from application.logging import build_stack_data + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolCall: + """Represents a tool/function call from the LLM.""" + + id: str + name: str + arguments: Union[str, Dict] + index: Optional[int] = None + + @classmethod + def from_dict(cls, data: Dict) -> "ToolCall": + """Create ToolCall from dictionary.""" + return cls( + id=data.get("id", ""), + name=data.get("name", ""), + arguments=data.get("arguments", {}), + index=data.get("index"), + ) + + +@dataclass +class LLMResponse: + """Represents a response from the LLM.""" + + content: str + tool_calls: List[ToolCall] + finish_reason: str + raw_response: Any + + @property + def requires_tool_call(self) -> bool: + """Check if the response requires tool calls.""" + return bool(self.tool_calls) and self.finish_reason == "tool_calls" + + +class LLMHandler(ABC): + """Abstract base class for LLM handlers.""" + + def __init__(self): + self.llm_calls = [] + self.tool_calls = [] + + @abstractmethod + def parse_response(self, response: Any) -> LLMResponse: + """Parse raw LLM response into standardized format.""" + pass + + @abstractmethod + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create a tool result message for the conversation history.""" + pass + + @abstractmethod + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through streaming response chunks.""" + pass + + def process_message_flow( + self, + agent, + initial_response, + tools_dict: Dict, + messages: List[Dict], + attachments: Optional[List] = None, + stream: bool = False, + ) -> Union[str, Generator]: + """ + Main orchestration method for processing LLM message flow. + + Args: + agent: The agent instance + initial_response: Initial LLM response + tools_dict: Dictionary of available tools + messages: Conversation history + attachments: Optional attachments + stream: Whether to use streaming + + Returns: + Final response or generator for streaming + """ + messages = self.prepare_messages(agent, messages, attachments) + + if stream: + return self.handle_streaming(agent, initial_response, tools_dict, messages) + else: + return self.handle_non_streaming( + agent, initial_response, tools_dict, messages + ) + + def prepare_messages( + self, agent, messages: List[Dict], attachments: Optional[List] = None + ) -> List[Dict]: + """ + Prepare messages with attachments and provider-specific formatting. + + Args: + agent: The agent instance + messages: Original messages + attachments: List of attachments + + Returns: + Prepared messages list + """ + if not attachments: + return messages + logger.info(f"Preparing messages with {len(attachments)} attachments") + supported_types = agent.llm.get_supported_attachment_types() + + supported_attachments = [ + a for a in attachments if a.get("mime_type") in supported_types + ] + unsupported_attachments = [ + a for a in attachments if a.get("mime_type") not in supported_types + ] + + # Process supported attachments with the LLM's custom method + + if supported_attachments: + logger.info( + f"Processing {len(supported_attachments)} supported attachments" + ) + messages = agent.llm.prepare_messages_with_attachments( + messages, supported_attachments + ) + # Process unsupported attachments with default method + + if unsupported_attachments: + logger.info( + f"Processing {len(unsupported_attachments)} unsupported attachments" + ) + messages = self._append_unsupported_attachments( + messages, unsupported_attachments + ) + return messages + + def _append_unsupported_attachments( + self, messages: List[Dict], attachments: List[Dict] + ) -> List[Dict]: + """ + Default method to append unsupported attachment content to system prompt. + + Args: + messages: Current messages + attachments: List of unsupported attachments + + Returns: + Updated messages list + """ + prepared_messages = messages.copy() + attachment_texts = [] + + for attachment in attachments: + logger.info(f"Adding attachment {attachment.get('id')} to context") + if "content" in attachment: + attachment_texts.append( + f"Attached file content:\n\n{attachment['content']}" + ) + if attachment_texts: + combined_text = "\n\n".join(attachment_texts) + + system_msg = next( + (msg for msg in prepared_messages if msg.get("role") == "system"), + {"role": "system", "content": ""}, + ) + + if system_msg not in prepared_messages: + prepared_messages.insert(0, system_msg) + system_msg["content"] += f"\n\n{combined_text}" + return prepared_messages + + def handle_tool_calls( + self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict] + ) -> List[Dict]: + """ + Execute tool calls and update conversation history. + + Args: + agent: The agent instance + tool_calls: List of tool calls to execute + tools_dict: Available tools dictionary + messages: Current conversation history + + Returns: + Updated messages list + """ + updated_messages = messages.copy() + + for call in tool_calls: + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action(tools_dict, call) + + updated_messages.append( + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": call.name, + "args": call.arguments, + "call_id": call_id, + } + } + ], + } + ) + + updated_messages.append(self.create_tool_message(call, tool_response)) + + except Exception as e: + logger.error(f"Error executing tool: {str(e)}", exc_info=True) + updated_messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call.id, + } + ) + + return updated_messages + + def handle_non_streaming( + self, agent, response: Any, tools_dict: Dict, messages: List[Dict] + ) -> Union[str, Dict]: + """ + Handle non-streaming response flow. + + Args: + agent: The agent instance + response: Current LLM response + tools_dict: Available tools dictionary + messages: Conversation history + + Returns: + Final response after processing all tool calls + """ + parsed = self.parse_response(response) + self.llm_calls.append(build_stack_data(agent.llm)) + + while parsed.requires_tool_call: + messages = self.handle_tool_calls( + agent, parsed.tool_calls, tools_dict, messages + ) + + response = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + parsed = self.parse_response(response) + self.llm_calls.append(build_stack_data(agent.llm)) + + return parsed.content + + def handle_streaming( + self, agent, response: Any, tools_dict: Dict, messages: List[Dict] + ) -> Generator: + """ + Handle streaming response flow. + + Args: + agent: The agent instance + response: Current LLM response + tools_dict: Available tools dictionary + messages: Conversation history + + Yields: + Streaming response chunks + """ + buffer = "" + tool_calls = {} + + for chunk in self._iterate_stream(response): + if isinstance(chunk, str): + yield chunk + continue + parsed = self.parse_response(chunk) + + if parsed.tool_calls: + for call in parsed.tool_calls: + if call.index not in tool_calls: + tool_calls[call.index] = call + else: + existing = tool_calls[call.index] + if call.id: + existing.id = call.id + if call.name: + existing.name = call.name + if call.arguments: + existing.arguments += call.arguments + if parsed.finish_reason == "tool_calls": + messages = self.handle_tool_calls( + agent, list(tool_calls.values()), tools_dict, messages + ) + tool_calls = {} + + response = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + + yield from self.handle_streaming(agent, response, tools_dict, messages) + return + if parsed.content: + buffer += parsed.content + yield buffer + buffer = "" + if parsed.finish_reason == "stop": + return diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py new file mode 100644 index 00000000..b43f2a16 --- /dev/null +++ b/application/llm/handlers/google.py @@ -0,0 +1,78 @@ +import uuid +from typing import Any, Dict, Generator + +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +class GoogleLLMHandler(LLMHandler): + """Handler for Google's GenAI API.""" + + def parse_response(self, response: Any) -> LLMResponse: + """Parse Google response into standardized format.""" + + if isinstance(response, str): + return LLMResponse( + content=response, + tool_calls=[], + finish_reason="stop", + raw_response=response, + ) + + if hasattr(response, "candidates"): + parts = response.candidates[0].content.parts if response.candidates else [] + tool_calls = [ + ToolCall( + id=str(uuid.uuid4()), + name=part.function_call.name, + arguments=part.function_call.args, + ) + for part in parts + if hasattr(part, "function_call") and part.function_call is not None + ] + + content = " ".join( + part.text + for part in parts + if hasattr(part, "text") and part.text is not None + ) + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason="tool_calls" if tool_calls else "stop", + raw_response=response, + ) + + else: + tool_calls = [] + if hasattr(response, "function_call"): + tool_calls.append( + ToolCall( + id=str(uuid.uuid4()), + name=response.function_call.name, + arguments=response.function_call.args, + ) + ) + return LLMResponse( + content=response.text if hasattr(response, "text") else "", + tool_calls=tool_calls, + finish_reason="tool_calls" if tool_calls else "stop", + raw_response=response, + ) + + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create Google-style tool message.""" + from google.genai import types + + return { + "role": "tool", + "content": [ + types.Part.from_function_response( + name=tool_call.name, response={"result": result} + ).to_json_dict() + ], + } + + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through Google streaming response.""" + for chunk in response: + yield chunk diff --git a/application/llm/handlers/handler_creator.py b/application/llm/handlers/handler_creator.py new file mode 100644 index 00000000..e39c000d --- /dev/null +++ b/application/llm/handlers/handler_creator.py @@ -0,0 +1,18 @@ +from application.llm.handlers.base import LLMHandler +from application.llm.handlers.google import GoogleLLMHandler +from application.llm.handlers.openai import OpenAILLMHandler + + +class LLMHandlerCreator: + handlers = { + "openai": OpenAILLMHandler, + "google": GoogleLLMHandler, + "default": OpenAILLMHandler, + } + + @classmethod + def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler: + handler_class = cls.handlers.get(llm_type.lower()) + if not handler_class: + raise ValueError(f"No LLM handler class found for type {llm_type}") + return handler_class(*args, **kwargs) diff --git a/application/llm/handlers/openai.py b/application/llm/handlers/openai.py new file mode 100644 index 00000000..99ddde4c --- /dev/null +++ b/application/llm/handlers/openai.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, Generator + +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +class OpenAILLMHandler(LLMHandler): + """Handler for OpenAI API.""" + + def parse_response(self, response: Any) -> LLMResponse: + """Parse OpenAI response into standardized format.""" + if isinstance(response, str): + return LLMResponse( + content=response, + tool_calls=[], + finish_reason="stop", + raw_response=response, + ) + + message = getattr(response, "message", None) or getattr(response, "delta", None) + + tool_calls = [] + if hasattr(message, "tool_calls"): + tool_calls = [ + ToolCall( + id=getattr(tc, "id", ""), + name=getattr(tc.function, "name", ""), + arguments=getattr(tc.function, "arguments", ""), + index=getattr(tc, "index", None), + ) + for tc in message.tool_calls or [] + ] + return LLMResponse( + content=getattr(message, "content", ""), + tool_calls=tool_calls, + finish_reason=getattr(response, "finish_reason", ""), + raw_response=response, + ) + + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create OpenAI-style tool message.""" + return { + "role": "tool", + "content": [ + { + "function_response": { + "name": tool_call.name, + "response": {"result": result}, + "call_id": tool_call.id, + } + } + ], + } + + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through OpenAI streaming response.""" + for chunk in response: + yield chunk From 143f4aa886fdc3f9edd11b62d362b6cd71ec2e16 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 6 Jun 2025 14:41:44 +0530 Subject: [PATCH 02/10] refactor: streamline conversation handling and update agent pinning logic --- frontend/src/Navigation.tsx | 110 +++++++++++++++++------------------- 1 file changed, 53 insertions(+), 57 deletions(-) diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index f545d38c..13aed4f2 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -48,6 +48,7 @@ import { setConversations, setModalStateDeleteConv, setSelectedAgent, + setSharedAgents, } from './preferences/preferenceSlice'; import Upload from './upload/Upload'; @@ -169,70 +170,65 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleTogglePin = (agent: Agent) => { userService.togglePinAgent(agent.id ?? '', token).then((response) => { if (response.ok) { - const updatedAgents = agents?.map((a) => - a.id === agent.id ? { ...a, pinned: !a.pinned } : a, - ); - dispatch(setAgents(updatedAgents)); + const updatePinnedStatus = (a: Agent) => + a.id === agent.id ? { ...a, pinned: !a.pinned } : a; + dispatch(setAgents(agents?.map(updatePinnedStatus))); + dispatch(setSharedAgents(sharedAgents?.map(updatePinnedStatus))); } }); }; - const handleConversationClick = (index: string) => { - dispatch(setSelectedAgent(null)); - conversationService - .getConversation(index, token) - .then((response) => { - if (!response.ok) { - navigate('/'); - dispatch(setSelectedAgent(null)); - return null; - } - return response.json(); - }) - .then((data) => { - if (!data) return; - dispatch(setConversation(data.queries)); - dispatch( - updateConversationId({ - query: { conversationId: index }, - }), + const handleConversationClick = async (index: string) => { + try { + dispatch(setSelectedAgent(null)); + + const response = await conversationService.getConversation(index, token); + if (!response.ok) { + navigate('/'); + return; + } + + const data = await response.json(); + if (!data) return; + + dispatch(setConversation(data.queries)); + dispatch(updateConversationId({ query: { conversationId: index } })); + + if (!data.agent_id) { + navigate('/'); + return; + } + + let agent: Agent; + if (data.is_shared_usage) { + const sharedResponse = await userService.getSharedAgent( + data.shared_token, + token, ); - if (data.agent_id) { - if (data.is_shared_usage) { - userService - .getSharedAgent(data.shared_token, token) - .then((response) => { - if (!response.ok) { - navigate('/'); - dispatch(setSelectedAgent(null)); - return; - } - response.json().then((agent: Agent) => { - navigate(`/agents/shared/${agent.shared_token}`); - }); - }); - } else { - userService.getAgent(data.agent_id, token).then((response) => { - if (!response.ok) { - navigate('/'); - dispatch(setSelectedAgent(null)); - return; - } - response.json().then((agent: Agent) => { - if (agent.shared_token) - navigate(`/agents/shared/${agent.shared_token}`); - else { - dispatch(setSelectedAgent(agent)); - navigate('/'); - } - }); - }); - } - } else { + if (!sharedResponse.ok) { navigate('/'); - dispatch(setSelectedAgent(null)); + return; } - }); + agent = await sharedResponse.json(); + navigate(`/agents/shared/${agent.shared_token}`); + } else { + const agentResponse = await userService.getAgent(data.agent_id, token); + if (!agentResponse.ok) { + navigate('/'); + return; + } + agent = await agentResponse.json(); + if (agent.shared_token) { + navigate(`/agents/shared/${agent.shared_token}`); + } else { + await Promise.resolve(dispatch(setSelectedAgent(agent))); + navigate('/'); + } + } + } catch (error) { + console.error('Error handling conversation click:', error); + navigate('/'); + } }; const resetConversation = () => { From e9530d5ec5e5fa9fdd5a6ce2a25e886ec9608f61 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 6 Jun 2025 15:29:53 +0530 Subject: [PATCH 03/10] refactor: update env variable names --- application/agents/classic_agent.py | 113 +++++++++++++++-------- application/api/answer/routes.py | 24 +++-- application/core/settings.py | 13 ++- application/llm/llama_cpp.py | 13 ++- application/retriever/brave_search.py | 8 +- application/retriever/classic_rag.py | 6 +- application/retriever/duckduck_search.py | 8 +- application/utils.py | 4 +- application/worker.py | 18 ++-- 9 files changed, 121 insertions(+), 86 deletions(-) diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index b371123b..d0576511 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -1,8 +1,6 @@ from typing import Dict, Generator - from application.agents.base import BaseAgent from application.logging import LogContext - from application.retriever.base import BaseRetriever import logging @@ -10,55 +8,90 @@ logger = logging.getLogger(__name__) class ClassicAgent(BaseAgent): + """A simplified classic agent with clear execution flow. + + Usage: + 1. Processes a query through retrieval + 2. Sets up available tools + 3. Generates responses using LLM + 4. Handles tool interactions if needed + 5. Returns standardized outputs + + Easy to extend by overriding specific steps. + """ + def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: + """Main execution flow for the agent.""" + # Step 1: Retrieve relevant data retrieved_data = self._retriever_search(retriever, query, log_context) - if self.user_api_key: - tools_dict = self._get_tools(self.user_api_key) - else: - tools_dict = self._get_user_tools(self.user) + + # Step 2: Prepare tools + tools_dict = ( + self._get_user_tools(self.user) + if not self.user_api_key + else self._get_tools(self.user_api_key) + ) self._prepare_tools(tools_dict) + # Step 3: Build and process messages messages = self._build_messages(self.prompt, query, retrieved_data) + llm_response = self._llm_gen(messages, log_context) - resp = self._llm_gen(messages, log_context) + # Step 4: Handle the response + yield from self._handle_response( + llm_response, tools_dict, messages, log_context + ) - attachments = self.attachments - - if isinstance(resp, str): - yield {"answer": resp} - return - if ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield {"answer": resp.message.content} - return - - resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments) - - if isinstance(resp, str): - yield {"answer": resp} - elif ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield {"answer": resp.message.content} - else: - for line in resp: - if isinstance(line, str): - yield {"answer": line} + # Step 5: Return metadata + yield {"sources": retrieved_data} + yield {"tool_calls": self._get_truncated_tool_calls()} + # Log tool calls for debugging log_context.stacks.append( {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}} ) - yield {"sources": retrieved_data} - # clean tool_call_data only send first 50 characters of tool_call['result'] - for tool_call in self.tool_calls: - if len(str(tool_call["result"])) > 50: - tool_call["result"] = str(tool_call["result"])[:50] + "..." - yield {"tool_calls": self.tool_calls.copy()} + def _handle_response(self, response, tools_dict, messages, log_context): + """Handle different types of LLM responses consistently.""" + # Handle simple string responses + if isinstance(response, str): + yield {"answer": response} + return + + # Handle content from message objects + if hasattr(response, "message") and getattr(response.message, "content", None): + yield {"answer": response.message.content} + return + + # Handle complex responses that may require tool use + processed_response = self._llm_handler( + response, tools_dict, messages, log_context, self.attachments + ) + + # Yield the final processed response + if isinstance(processed_response, str): + yield {"answer": processed_response} + elif hasattr(processed_response, "message") and getattr( + processed_response.message, "content", None + ): + yield {"answer": processed_response.message.content} + else: + for line in processed_response: + if isinstance(line, str): + yield {"answer": line} + + def _get_truncated_tool_calls(self): + """Return tool calls with truncated results for cleaner output.""" + return [ + { + **tool_call, + "result": ( + f"{str(tool_call['result'])[:50]}..." + if len(str(tool_call["result"])) > 50 + else tool_call["result"] + ), + } + for tool_call in self.tool_calls + ] diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 2aa473d4..44ba035b 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -37,17 +37,17 @@ api.add_namespace(answer_ns) gpt_model = "" # to have some kind of default behaviour -if settings.LLM_NAME == "openai": +if settings.LLM_PROVIDER == "openai": gpt_model = "gpt-4o-mini" -elif settings.LLM_NAME == "anthropic": +elif settings.LLM_PROVIDER == "anthropic": gpt_model = "claude-2" -elif settings.LLM_NAME == "groq": +elif settings.LLM_PROVIDER == "groq": gpt_model = "llama3-8b-8192" -elif settings.LLM_NAME == "novita": +elif settings.LLM_PROVIDER == "novita": gpt_model = "deepseek/deepseek-r1" -if settings.MODEL_NAME: # in case there is particular model name configured - gpt_model = settings.MODEL_NAME +if settings.LLM_NAME: # in case there is particular model name configured + gpt_model = settings.LLM_NAME # load the prompts current_dir = os.path.dirname( @@ -322,7 +322,7 @@ def complete_stream( doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, @@ -453,9 +453,7 @@ class Stream(Resource): agent_type = settings.AGENT_NAME decoded_token = getattr(request, "decoded_token", None) user_sub = decoded_token.get("sub") if decoded_token else None - agent_key, is_shared_usage, shared_token = get_agent_key( - agent_id, user_sub - ) + agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub) if agent_key: data.update({"api_key": agent_key}) @@ -506,7 +504,7 @@ class Stream(Resource): agent = AgentCreator.create_agent( agent_type, endpoint="stream", - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, gpt_model=gpt_model, api_key=settings.API_KEY, user_api_key=user_api_key, @@ -659,7 +657,7 @@ class Answer(Resource): agent = AgentCreator.create_agent( agent_type, endpoint="api/answer", - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, gpt_model=gpt_model, api_key=settings.API_KEY, user_api_key=user_api_key, @@ -728,7 +726,7 @@ class Answer(Resource): doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, diff --git a/application/core/settings.py b/application/core/settings.py index 3be34242..05a09510 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -11,18 +11,18 @@ current_dir = os.path.dirname( class Settings(BaseSettings): AUTH_TYPE: Optional[str] = None - LLM_NAME: str = "docsgpt" - MODEL_NAME: Optional[str] = ( - None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo + LLM_PROVIDER: str = "docsgpt" + LLM_NAME: Optional[str] = ( + None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo ) EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" MONGO_URI: str = "mongodb://localhost:27017/docsgpt" MONGO_DB_NAME: str = "docsgpt" - MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") + LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") DEFAULT_MAX_HISTORY: int = 150 - MODEL_TOKEN_LIMITS: dict = { + LLM_TOKEN_LIMITS: dict = { "gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5, @@ -99,8 +99,7 @@ class Settings(BaseSettings): BRAVE_SEARCH_API_KEY: Optional[str] = None FLASK_DEBUG_MODE: bool = False - STORAGE_TYPE: str = "local" # local or s3 - + STORAGE_TYPE: str = "local" # local or s3 JWT_SECRET_KEY: str = "" diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 804c3c56..f0418b49 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -2,6 +2,7 @@ from application.llm.base import BaseLLM from application.core.settings import settings import threading + class LlamaSingleton: _instances = {} _lock = threading.Lock() # Add a lock for thread synchronization @@ -29,7 +30,7 @@ class LlamaCpp(BaseLLM): self, api_key=None, user_api_key=None, - llm_name=settings.MODEL_PATH, + llm_name=settings.LLM_PATH, *args, **kwargs, ): @@ -42,14 +43,18 @@ class LlamaCpp(BaseLLM): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False) + result = LlamaSingleton.query_model( + self.llama, prompt, max_tokens=150, echo=False + ) return result["choices"][0]["text"].split("### Answer \n")[-1] def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream) + result = LlamaSingleton.query_model( + self.llama, prompt, max_tokens=150, echo=False, stream=stream + ) for item in result: for choice in item["choices"]: - yield choice["text"] \ No newline at end of file + yield choice["text"] diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 33bdb894..123000e4 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -29,10 +29,10 @@ class BraveRetSearch(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) @@ -59,7 +59,7 @@ class BraveRetSearch(BaseRetriever): docs.append({"text": snippet, "title": title, "link": link}) except IndexError: pass - if settings.LLM_NAME == "llama.cpp": + if settings.LLM_PROVIDER == "llama.cpp": docs = [docs[0]] return docs @@ -84,7 +84,7 @@ class BraveRetSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=self.user_api_key, decoded_token=self.decoded_token, diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b8ac69e4..2b4e6df7 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever): token_limit=150, gpt_model="docsgpt", user_api_key=None, - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, api_key=settings.API_KEY, decoded_token=None, ): @@ -28,10 +28,10 @@ class ClassicRAG(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index 29bbba18..5abe5edd 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -28,10 +28,10 @@ class DuckDuckSearch(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) @@ -58,7 +58,7 @@ class DuckDuckSearch(BaseRetriever): ) except IndexError: pass - if settings.LLM_NAME == "llama.cpp": + if settings.LLM_PROVIDER == "llama.cpp": docs = [docs[0]] return docs @@ -83,7 +83,7 @@ class DuckDuckSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=self.user_api_key, decoded_token=self.decoded_token, diff --git a/application/utils.py b/application/utils.py index 6d47d31a..7a9cfd2b 100644 --- a/application/utils.py +++ b/application/utils.py @@ -74,8 +74,8 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): max_token_limit if max_token_limit and max_token_limit - < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) - else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) ) if not history: diff --git a/application/worker.py b/application/worker.py index 9829fde9..13c0ca30 100755 --- a/application/worker.py +++ b/application/worker.py @@ -143,8 +143,8 @@ def run_agent_logic(agent_config, input_data): agent = AgentCreator.create_agent( agent_type, endpoint="webhook", - llm_name=settings.LLM_NAME, - gpt_model=settings.MODEL_NAME, + llm_name=settings.LLM_PROVIDER, + gpt_model=settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key, prompt=prompt, @@ -159,7 +159,7 @@ def run_agent_logic(agent_config, input_data): prompt=prompt, chunks=chunks, token_limit=settings.DEFAULT_MAX_HISTORY, - gpt_model=settings.MODEL_NAME, + gpt_model=settings.LLM_NAME, user_api_key=user_api_key, decoded_token=decoded_token, ) @@ -449,7 +449,7 @@ def attachment_worker(self, file_info, user): try: self.update_state(state="PROGRESS", meta={"current": 10}) storage = StorageCreator.get_storage() - + self.update_state( state="PROGRESS", meta={"current": 30, "status": "Processing content"} ) @@ -458,9 +458,11 @@ def attachment_worker(self, file_info, user): relative_path, lambda local_path, **kwargs: SimpleDirectoryReader( input_files=[local_path], exclude_hidden=True, errors="ignore" - ).load_data()[0].text + ) + .load_data()[0] + .text, ) - + token_count = num_tokens_from_string(content) self.update_state( @@ -487,9 +489,7 @@ def attachment_worker(self, file_info, user): f"Stored attachment with ID: {attachment_id}", extra={"user": user} ) - self.update_state( - state="PROGRESS", meta={"current": 100, "status": "Complete"} - ) + self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"}) return { "filename": filename, From 5f5c31cd5b11f6d1f739711409d25ee9c031e981 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 6 Jun 2025 16:55:57 +0530 Subject: [PATCH 04/10] refactor: enhance LLM fallback handling and streamline method execution --- application/agents/base.py | 15 ++++- application/llm/base.py | 131 +++++++++++++++++++++++++++---------- 2 files changed, 110 insertions(+), 36 deletions(-) diff --git a/application/agents/base.py b/application/agents/base.py index a4bbd001..f48418b3 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -256,12 +256,21 @@ class BaseAgent(ABC): return retrieved_data def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None): - resp = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=self.tools - ) + gen_kwargs = {"model": self.gpt_model, "messages": messages} + + if ( + hasattr(self.llm, "_supports_tools") + and self.llm._supports_tools + and self.tools + ): + gen_kwargs["tools"] = self.tools + + resp = self.llm.gen_stream(**gen_kwargs) + if log_context: data = build_stack_data(self.llm, exclude_attributes=["client"]) log_context.stacks.append({"component": "llm", "data": data}) + return resp def _llm_handler( diff --git a/application/llm/base.py b/application/llm/base.py index 0607159d..b145816d 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -1,53 +1,118 @@ +import logging from abc import ABC, abstractmethod from application.cache import gen_cache, stream_cache + +from application.core.settings import settings from application.usage import gen_token_usage, stream_token_usage +logger = logging.getLogger(__name__) + class BaseLLM(ABC): - def __init__(self, decoded_token=None): + def __init__( + self, + decoded_token=None, + ): self.decoded_token = decoded_token self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + self.fallback_provider = settings.FALLBACK_LLM_PROVIDER + self.fallback_model_name = settings.FALLBACK_LLM_NAME + self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY + self._fallback_llm = None - def _apply_decorator(self, method, decorators, *args, **kwargs): - for decorator in decorators: - method = decorator(method) - return method(self, *args, **kwargs) + @property + def fallback_llm(self): + """Lazy-loaded fallback LLM instance.""" + if ( + self._fallback_llm is None + and self.fallback_provider + and self.fallback_model_name + ): + try: + from llm.llm_creator import LLMCreator + + self._fallback_llm = LLMCreator( + self.fallback_provider, + self.fallback_llm_api_key, + None, + self.decoded_token, + ) + except Exception as e: + logger.error( + f"Failed to initialize fallback LLM: {str(e)}", exc_info=True + ) + return self._fallback_llm + + def _execute_with_fallback( + self, method_name: str, decorators: list, *args, **kwargs + ): + """ + Unified method execution with fallback support. + + Args: + method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream') + decorators: List of decorators to apply + *args: Positional arguments + **kwargs: Keyword arguments + """ + + def decorated_method(): + method = getattr(self, method_name) + for decorator in decorators: + method = decorator(method) + return method(self, *args, **kwargs) + + try: + return decorated_method() + except Exception as e: + if not self.fallback_llm: + logger.error(f"Primary LLM failed and no fallback available: {str(e)}") + raise + logger.warning( + f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}" + ) + # Retry with fallback (without decorators for accurate token tracking) + + fallback_method = getattr( + self.fallback_llm, method_name.replace("_raw_", "") + ) + return fallback_method(*args, **kwargs) + + def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): + decorators = [gen_token_usage, gen_cache] + return self._execute_with_fallback( + "_raw_gen", + decorators, + model=model, + messages=messages, + stream=stream, + tools=tools, + *args, + **kwargs, + ) + + def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs): + decorators = [stream_cache, stream_token_usage] + return self._execute_with_fallback( + "_raw_gen_stream", + decorators, + model=model, + messages=messages, + stream=stream, + tools=tools, + *args, + **kwargs, + ) @abstractmethod def _raw_gen(self, model, messages, stream, tools, *args, **kwargs): pass - def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): - decorators = [gen_token_usage, gen_cache] - return self._apply_decorator( - self._raw_gen, - decorators=decorators, - model=model, - messages=messages, - stream=stream, - tools=tools, - *args, - **kwargs - ) - @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): pass - def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs): - decorators = [stream_cache, stream_token_usage] - return self._apply_decorator( - self._raw_gen_stream, - decorators=decorators, - model=model, - messages=messages, - stream=stream, - tools=tools, - *args, - **kwargs - ) - def supports_tools(self): return hasattr(self, "_supports_tools") and callable( getattr(self, "_supports_tools") @@ -55,11 +120,11 @@ class BaseLLM(ABC): def _supports_tools(self): raise NotImplementedError("Subclass must implement _supports_tools method") - + def get_supported_attachment_types(self): """ Return a list of MIME types supported by this LLM for file uploads. - + Returns: list: List of supported MIME types """ From 35f4b1323722fdb3efae6910155b5da5e65d2e3f Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 6 Jun 2025 17:05:15 +0530 Subject: [PATCH 05/10] refactor: add fallback LLM configuration options to settings --- application/core/settings.py | 3 +++ application/llm/base.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/application/core/settings.py b/application/core/settings.py index 05a09510..2ff371af 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -35,6 +35,9 @@ class Settings(BaseSettings): ) RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search AGENT_NAME: str = "classic" + FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm + FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm + FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" diff --git a/application/llm/base.py b/application/llm/base.py index b145816d..cbce4ffd 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -72,7 +72,6 @@ class BaseLLM(ABC): logger.warning( f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}" ) - # Retry with fallback (without decorators for accurate token tracking) fallback_method = getattr( self.fallback_llm, method_name.replace("_raw_", "") From e5b1a716590aa457b8d2bbc31271b2d05bd82883 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 6 Jun 2025 17:23:27 +0530 Subject: [PATCH 06/10] refactor: update fallback LLM initialization to use factory method --- application/llm/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/application/llm/base.py b/application/llm/base.py index cbce4ffd..bef3e11f 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -30,9 +30,9 @@ class BaseLLM(ABC): and self.fallback_model_name ): try: - from llm.llm_creator import LLMCreator + from application.llm.llm_creator import LLMCreator - self._fallback_llm = LLMCreator( + self._fallback_llm = LLMCreator.create_llm( self.fallback_provider, self.fallback_llm_api_key, None, From 3351f71813d3057251bce3756d8a83adfee4959e Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 11 Jun 2025 12:40:32 +0530 Subject: [PATCH 07/10] refactor: tool calls sent when pending and after completion --- application/agents/base.py | 54 +++++++++++++++---- application/agents/classic_agent.py | 44 --------------- application/api/answer/routes.py | 5 +- application/llm/handlers/base.py | 28 ++++++++-- .../src/conversation/conversationSlice.ts | 32 +++++++---- frontend/src/conversation/types/index.ts | 3 +- 6 files changed, 94 insertions(+), 72 deletions(-) diff --git a/application/agents/base.py b/application/agents/base.py index f48418b3..adebc125 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -136,6 +136,15 @@ class BaseAgent(ABC): parser = ToolActionParser(self.llm.__class__.__name__) tool_id, action_name, call_args = parser.parse_args(call) + call_id = getattr(call, "id", None) or str(uuid.uuid4()) + tool_call_data = { + "tool_name": tools_dict[tool_id]["name"], + "call_id": call_id, + "action_name": f"{action_name}_{tool_id}", + "arguments": call_args, + } + yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}} + tool_data = tools_dict[tool_id] action_data = ( tool_data["config"]["actions"][action_name] @@ -188,19 +197,26 @@ class BaseAgent(ABC): else: print(f"Executing tool: {action_name} with args: {call_args}") result = tool.execute_action(action_name, **parameters) - call_id = getattr(call, "id", None) + tool_call_data["result"] = result - tool_call_data = { - "tool_name": tool_data["name"], - "call_id": call_id if call_id is not None else "None", - "action_name": f"{action_name}_{tool_id}", - "arguments": call_args, - "result": result, - } + yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}} self.tool_calls.append(tool_call_data) return result, call_id + def _get_truncated_tool_calls(self): + return [ + { + **tool_call, + "result": ( + f"{str(tool_call['result'])[:50]}..." + if len(str(tool_call["result"])) > 50 + else tool_call["result"] + ), + } + for tool_call in self.tool_calls + ] + def _build_messages( self, system_prompt: str, @@ -264,13 +280,11 @@ class BaseAgent(ABC): and self.tools ): gen_kwargs["tools"] = self.tools - resp = self.llm.gen_stream(**gen_kwargs) if log_context: data = build_stack_data(self.llm, exclude_attributes=["client"]) log_context.stacks.append({"component": "llm", "data": data}) - return resp def _llm_handler( @@ -288,3 +302,23 @@ class BaseAgent(ABC): data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"]) log_context.stacks.append({"component": "llm_handler", "data": data}) return resp + + def _handle_response(self, response, tools_dict, messages, log_context): + if isinstance(response, str): + yield {"answer": response} + return + if hasattr(response, "message") and getattr(response.message, "content", None): + yield {"answer": response.message.content} + return + + processed_response_gen = self._llm_handler( + response, tools_dict, messages, log_context, self.attachments + ) + + for event in processed_response_gen: + if isinstance(event, str): + yield {"answer": event} + elif hasattr(event, "message") and getattr(event.message, "content", None): + yield {"answer": event.message.content} + elif isinstance(event, dict) and "type" in event: + yield event diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index d0576511..6fe73de0 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -23,7 +23,6 @@ class ClassicAgent(BaseAgent): def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: - """Main execution flow for the agent.""" # Step 1: Retrieve relevant data retrieved_data = self._retriever_search(retriever, query, log_context) @@ -52,46 +51,3 @@ class ClassicAgent(BaseAgent): log_context.stacks.append( {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}} ) - - def _handle_response(self, response, tools_dict, messages, log_context): - """Handle different types of LLM responses consistently.""" - # Handle simple string responses - if isinstance(response, str): - yield {"answer": response} - return - - # Handle content from message objects - if hasattr(response, "message") and getattr(response.message, "content", None): - yield {"answer": response.message.content} - return - - # Handle complex responses that may require tool use - processed_response = self._llm_handler( - response, tools_dict, messages, log_context, self.attachments - ) - - # Yield the final processed response - if isinstance(processed_response, str): - yield {"answer": processed_response} - elif hasattr(processed_response, "message") and getattr( - processed_response.message, "content", None - ): - yield {"answer": processed_response.message.content} - else: - for line in processed_response: - if isinstance(line, str): - yield {"answer": line} - - def _get_truncated_tool_calls(self): - """Return tool calls with truncated results for cleaner output.""" - return [ - { - **tool_call, - "result": ( - f"{str(tool_call['result'])[:50]}..." - if len(str(tool_call["result"])) > 50 - else tool_call["result"] - ), - } - for tool_call in self.tool_calls - ] diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 44ba035b..83c3db6f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -310,12 +310,13 @@ def complete_stream( yield f"data: {data}\n\n" elif "tool_calls" in line: tool_calls = line["tool_calls"] - data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls}) - yield f"data: {data}\n\n" elif "thought" in line: thought += line["thought"] data = json.dumps({"type": "thought", "thought": line["thought"]}) yield f"data: {data}\n\n" + elif "type" in line: + data = json.dumps(line) + yield f"data: {data}\n\n" if isNoneDoc: for doc in source_log_docs: diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index ede7cec3..43205472 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -180,7 +180,7 @@ class LLMHandler(ABC): def handle_tool_calls( self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict] - ) -> List[Dict]: + ) -> Generator: """ Execute tool calls and update conversation history. @@ -198,7 +198,13 @@ class LLMHandler(ABC): for call in tool_calls: try: self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action(tools_dict, call) + tool_executor_gen = agent._execute_tool_action(tools_dict, call) + while True: + try: + yield next(tool_executor_gen) + except StopIteration as e: + tool_response, call_id = e.value + break updated_messages.append( { @@ -231,7 +237,7 @@ class LLMHandler(ABC): def handle_non_streaming( self, agent, response: Any, tools_dict: Dict, messages: List[Dict] - ) -> Union[str, Dict]: + ) -> Generator: """ Handle non-streaming response flow. @@ -248,9 +254,15 @@ class LLMHandler(ABC): self.llm_calls.append(build_stack_data(agent.llm)) while parsed.requires_tool_call: - messages = self.handle_tool_calls( + tool_handler_gen = self.handle_tool_calls( agent, parsed.tool_calls, tools_dict, messages ) + while True: + try: + yield next(tool_handler_gen) + except StopIteration as e: + messages = e.value + break response = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools @@ -297,9 +309,15 @@ class LLMHandler(ABC): if call.arguments: existing.arguments += call.arguments if parsed.finish_reason == "tool_calls": - messages = self.handle_tool_calls( + tool_handler_gen = self.handle_tool_calls( agent, list(tool_calls.values()), tools_dict, messages ) + while True: + try: + yield next(tool_handler_gen) + except StopIteration as e: + messages = e.value + break tool_calls = {} response = agent.llm.gen_stream( diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 961260ea..03532792 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -14,6 +14,7 @@ import { ConversationState, Attachment, } from './conversationModels'; +import { ToolCallsType } from './types'; const initialState: ConversationState = { queries: [], @@ -110,11 +111,11 @@ export const fetchAnswer = createAsyncThunk< query: { sources: data.source ?? [] }, }), ); - } else if (data.type === 'tool_calls') { + } else if (data.type === 'tool_call') { dispatch( - updateToolCalls({ + updateToolCall({ index: targetIndex, - query: { tool_calls: data.tool_calls }, + tool_call: data.data as ToolCallsType, }), ); } else if (data.type === 'error') { @@ -280,12 +281,23 @@ export const conversationSlice = createSlice({ state.queries[index].sources!.push(query.sources![0]); } }, - updateToolCalls( - state, - action: PayloadAction<{ index: number; query: Partial }>, - ) { - const { index, query } = action.payload; - state.queries[index].tool_calls = query?.tool_calls ?? []; + updateToolCall(state, action) { + const { index, tool_call } = action.payload; + + if (!state.queries[index].tool_calls) { + state.queries[index].tool_calls = []; + } + + const existingIndex = state.queries[index].tool_calls.findIndex( + (call) => call.call_id === tool_call.call_id, + ); + + if (existingIndex !== -1) { + Object.assign( + state.queries[index].tool_calls[existingIndex], + tool_call, + ); + } else state.queries[index].tool_calls.push(tool_call); }, updateQuery( state, @@ -378,7 +390,7 @@ export const { updateConversationId, updateThought, updateStreamingSource, - updateToolCalls, + updateToolCall, setConversation, setAttachments, addAttachment, diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index 9b5f2365..4ccb04a1 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -3,5 +3,6 @@ export type ToolCallsType = { action_name: string; call_id: string; arguments: Record; - result: Record; + result?: Record; + status?: 'pending' | 'completed'; }; From aaecf52c9997b33a4f3002bfcb9a1b3d68a11c8c Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 11 Jun 2025 12:30:34 +0100 Subject: [PATCH 08/10] refactor: update docs LLM_NAME and MODEL_NAME to LLM_PROVIDER and LLM_NAME --- deployment/docker-compose-dev.yaml | 1 + deployment/docker-compose.yaml | 4 ++- deployment/k8s/docsgpt-secrets.yaml | 2 +- docs/pages/Deploying/Docker-Deploying.mdx | 10 +++--- docs/pages/Deploying/DocsGPT-Settings.mdx | 24 +++++++------- .../pages/Guides/How-to-use-different-LLM.mdx | 4 +-- docs/pages/Models/cloud-providers.mdx | 14 ++++----- docs/pages/Models/embeddings.md | 2 +- docs/pages/Models/local-inference.mdx | 24 +++++++------- setup.sh | 31 +++++++++---------- 10 files changed, 59 insertions(+), 57 deletions(-) diff --git a/deployment/docker-compose-dev.yaml b/deployment/docker-compose-dev.yaml index 8a3e75c4..a1658bd2 100644 --- a/deployment/docker-compose-dev.yaml +++ b/deployment/docker-compose-dev.yaml @@ -1,3 +1,4 @@ +name: docsgpt-oss services: redis: diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml index c4b81a08..da9249ea 100644 --- a/deployment/docker-compose.yaml +++ b/deployment/docker-compose.yaml @@ -1,3 +1,4 @@ +name: docsgpt-oss services: frontend: build: ../frontend @@ -17,13 +18,13 @@ services: environment: - API_KEY=$API_KEY - EMBEDDINGS_KEY=$API_KEY + - LLM_PROVIDER=$LLM_PROVIDER - LLM_NAME=$LLM_NAME - CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/1 - MONGO_URI=mongodb://mongo:27017/docsgpt - CACHE_REDIS_URL=redis://redis:6379/2 - OPENAI_BASE_URL=$OPENAI_BASE_URL - - MODEL_NAME=$MODEL_NAME ports: - "7091:7091" volumes: @@ -41,6 +42,7 @@ services: environment: - API_KEY=$API_KEY - EMBEDDINGS_KEY=$API_KEY + - LLM_PROVIDER=$LLM_PROVIDER - LLM_NAME=$LLM_NAME - CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/1 diff --git a/deployment/k8s/docsgpt-secrets.yaml b/deployment/k8s/docsgpt-secrets.yaml index 45a00973..2d494e8e 100644 --- a/deployment/k8s/docsgpt-secrets.yaml +++ b/deployment/k8s/docsgpt-secrets.yaml @@ -4,7 +4,7 @@ metadata: name: docsgpt-secrets type: Opaque data: - LLM_NAME: ZG9jc2dwdA== + LLM_PROVIDER: ZG9jc2dwdA== INTERNAL_KEY: aW50ZXJuYWw= CELERY_BROKER_URL: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA== CELERY_RESULT_BACKEND: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA== diff --git a/docs/pages/Deploying/Docker-Deploying.mdx b/docs/pages/Deploying/Docker-Deploying.mdx index 7bb70729..382db5fb 100644 --- a/docs/pages/Deploying/Docker-Deploying.mdx +++ b/docs/pages/Deploying/Docker-Deploying.mdx @@ -37,7 +37,7 @@ The fastest way to try out DocsGPT is by using the public API endpoint. This req Open the `.env` file and add the following lines: ``` - LLM_NAME=docsgpt + LLM_PROVIDER=docsgpt VITE_API_STREAMING=true ``` @@ -93,16 +93,16 @@ There are two Ollama optional files: 3. **Pull the Ollama Model:** - **Crucially, after launching with Ollama, you need to pull the desired model into the Ollama container.** Find the `MODEL_NAME` you configured in your `.env` file (e.g., `llama3.2:1b`). Then execute the following command to pull the model *inside* the running Ollama container: + **Crucially, after launching with Ollama, you need to pull the desired model into the Ollama container.** Find the `LLM_NAME` you configured in your `.env` file (e.g., `llama3.2:1b`). Then execute the following command to pull the model *inside* the running Ollama container: ```bash - docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-cpu.yaml exec -it ollama ollama pull + docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-cpu.yaml exec -it ollama ollama pull ``` or (for GPU): ```bash - docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-gpu.yaml exec -it ollama ollama pull + docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-gpu.yaml exec -it ollama ollama pull ``` - Replace `` with the actual model name from your `.env` file. + Replace `` with the actual model name from your `.env` file. 4. **Access DocsGPT in your browser:** diff --git a/docs/pages/Deploying/DocsGPT-Settings.mdx b/docs/pages/Deploying/DocsGPT-Settings.mdx index 239b35d7..92537934 100644 --- a/docs/pages/Deploying/DocsGPT-Settings.mdx +++ b/docs/pages/Deploying/DocsGPT-Settings.mdx @@ -20,9 +20,9 @@ The easiest and recommended way to configure basic settings is by using a `.env` **Example `.env` file structure:** ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_OPENAI_API_KEY -MODEL_NAME=gpt-4o +LLM_NAME=gpt-4o ``` ### 2. Configuration via `settings.py` file (Advanced) @@ -37,7 +37,7 @@ While modifying `settings.py` offers more flexibility, it's generally recommende Here are some of the most fundamental settings you'll likely want to configure: -- **`LLM_NAME`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with. +- **`LLM_PROVIDER`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with. - **Common values:** - `docsgpt`: Use the DocsGPT Public API Endpoint (simple and free, as offered in `setup.sh` option 1). @@ -49,11 +49,11 @@ Here are some of the most fundamental settings you'll likely want to configure: - `azure_openai`: Use Azure OpenAI Service. - `openai` (when using local inference engines like Ollama, Llama.cpp, TGI, etc.): This signals DocsGPT to use an OpenAI-compatible API format, even if the actual LLM is running locally. -- **`MODEL_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_NAME` you've selected. +- **`LLM_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_PROVIDER` you've selected. - **Examples:** - - For `LLM_NAME=openai`: `gpt-4o` - - For `LLM_NAME=google`: `gemini-2.0-flash` + - For `LLM_PROVIDER=openai`: `gpt-4o` + - For `LLM_PROVIDER=google`: `gemini-2.0-flash` - For local models (e.g., Ollama): `llama3.2:1b` (or any model name available in your setup). - **`EMBEDDINGS_NAME`**: This setting defines which embedding model DocsGPT will use to generate vector embeddings for your documents. Embeddings are numerical representations of text that allow DocsGPT to understand the semantic meaning of your documents for efficient search and retrieval. @@ -63,7 +63,7 @@ Here are some of the most fundamental settings you'll likely want to configure: - **`API_KEY`**: Required for most cloud-based LLM providers. This is your authentication key to access the LLM provider's API. You'll need to obtain this key from your chosen provider's platform. -- **`OPENAI_BASE_URL`**: Specifically used when `LLM_NAME` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server. +- **`OPENAI_BASE_URL`**: Specifically used when `LLM_PROVIDER` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server. ## Configuration Examples @@ -74,9 +74,9 @@ Let's look at some concrete examples of how to configure these settings in your To use OpenAI's `gpt-4o` model, you would configure your `.env` file like this: ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_OPENAI_API_KEY # Replace with your actual OpenAI API key -MODEL_NAME=gpt-4o +LLM_NAME=gpt-4o ``` Make sure to replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key. @@ -86,14 +86,14 @@ Make sure to replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key. To use a local Ollama server with the `llama3.2:1b` model, you would configure your `.env` file like this: ``` -LLM_NAME=openai # Using OpenAI compatible API format for local models +LLM_PROVIDER=openai # Using OpenAI compatible API format for local models API_KEY=None # API Key is not needed for local Ollama -MODEL_NAME=llama3.2:1b +LLM_NAME=llama3.2:1b OPENAI_BASE_URL=http://host.docker.internal:11434/v1 # Default Ollama API URL within Docker EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can also run embeddings locally if needed ``` -In this case, even though you are using Ollama locally, `LLM_NAME` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server. +In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server. ## Authentication Settings diff --git a/docs/pages/Guides/How-to-use-different-LLM.mdx b/docs/pages/Guides/How-to-use-different-LLM.mdx index 3bc8477d..217bc690 100644 --- a/docs/pages/Guides/How-to-use-different-LLM.mdx +++ b/docs/pages/Guides/How-to-use-different-LLM.mdx @@ -32,9 +32,9 @@ Choose the LLM of your choice. ### For Open source llm change: ### Step 1 -For open source version please edit `LLM_NAME`, `MODEL_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information. +For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information. ### Step 2 -Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_NAME. +Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER. For self-hosted please visit [🖥️ Local Inference](/Models/local-inference). diff --git a/docs/pages/Models/cloud-providers.mdx b/docs/pages/Models/cloud-providers.mdx index 86f2d132..36e737fa 100644 --- a/docs/pages/Models/cloud-providers.mdx +++ b/docs/pages/Models/cloud-providers.mdx @@ -13,15 +13,15 @@ The primary method for configuring your LLM provider in DocsGPT is through the ` To connect to a cloud LLM provider, you will typically need to configure the following basic settings in your `.env` file: -* **`LLM_NAME`**: This setting is essential and identifies the specific cloud provider you wish to use (e.g., `openai`, `google`, `anthropic`). -* **`MODEL_NAME`**: Specifies the exact model you want to utilize from your chosen provider (e.g., `gpt-4o`, `gemini-2.0-flash`, `claude-3-5-sonnet-latest`). Refer to your provider's documentation for a list of available models. +* **`LLM_PROVIDER`**: This setting is essential and identifies the specific cloud provider you wish to use (e.g., `openai`, `google`, `anthropic`). +* **`LLM_NAME`**: Specifies the exact model you want to utilize from your chosen provider (e.g., `gpt-4o`, `gemini-2.0-flash`, `claude-3-5-sonnet-latest`). Refer to your provider's documentation for a list of available models. * **`API_KEY`**: Almost all cloud LLM providers require an API key for authentication. Obtain your API key from your chosen provider's platform and securely store it in your `.env` file. ## Explicitly Supported Cloud Providers -DocsGPT offers direct, streamlined support for the following cloud LLM providers, making configuration straightforward. The table below outlines the `LLM_NAME` and example `MODEL_NAME` values to use for each provider in your `.env` file. +DocsGPT offers direct, streamlined support for the following cloud LLM providers, making configuration straightforward. The table below outlines the `LLM_PROVIDER` and example `LLM_NAME` values to use for each provider in your `.env` file. -| Provider | `LLM_NAME` | Example `MODEL_NAME` | +| Provider | `LLM_PROVIDER` | Example `LLM_NAME` | | :--------------------------- | :------------- | :-------------------------- | | DocsGPT Public API | `docsgpt` | `None` | | OpenAI | `openai` | `gpt-4o` | @@ -35,16 +35,16 @@ DocsGPT offers direct, streamlined support for the following cloud LLM providers DocsGPT's flexible architecture allows you to connect to any cloud provider that offers an API compatible with the OpenAI API standard. This opens up a vast ecosystem of LLM services. -To connect to an OpenAI-compatible cloud provider, you will still use `LLM_NAME=openai` in your `.env` file. However, you will also need to specify the API endpoint of your chosen provider using the `OPENAI_BASE_URL` setting. You will also likely need to provide an `API_KEY` and `MODEL_NAME` as required by that provider. +To connect to an OpenAI-compatible cloud provider, you will still use `LLM_PROVIDER=openai` in your `.env` file. However, you will also need to specify the API endpoint of your chosen provider using the `OPENAI_BASE_URL` setting. You will also likely need to provide an `API_KEY` and `LLM_NAME` as required by that provider. **Example for DeepSeek (OpenAI-Compatible API):** To connect to DeepSeek, which offers an OpenAI-compatible API, your `.env` file could be configured as follows: ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_API_KEY # Your DeepSeek API key -MODEL_NAME=deepseek-chat # Or your desired DeepSeek model name +LLM_NAME=deepseek-chat # Or your desired DeepSeek model name OPENAI_BASE_URL=https://api.deepseek.com/v1 # DeepSeek's OpenAI API URL ``` diff --git a/docs/pages/Models/embeddings.md b/docs/pages/Models/embeddings.md index 6dfb89b6..68015db7 100644 --- a/docs/pages/Models/embeddings.md +++ b/docs/pages/Models/embeddings.md @@ -60,7 +60,7 @@ To use OpenAI's `text-embedding-ada-002` embedding model, you need to set `EMBED **Example `.env` configuration for OpenAI Embeddings:** ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_OPENAI_API_KEY # Your OpenAI API Key EMBEDDINGS_NAME=openai_text-embedding-ada-002 ``` diff --git a/docs/pages/Models/local-inference.mdx b/docs/pages/Models/local-inference.mdx index 4aa6bca2..0bba907b 100644 --- a/docs/pages/Models/local-inference.mdx +++ b/docs/pages/Models/local-inference.mdx @@ -15,8 +15,8 @@ Setting up a local inference engine with DocsGPT is configured through environme To connect to a local inference engine, you will generally need to configure these settings in your `.env` file: -* **`LLM_NAME`**: Crucially set this to `openai`. This tells DocsGPT to use the OpenAI-compatible API format for communication, even though the LLM is local. -* **`MODEL_NAME`**: Specify the model name as recognized by your local inference engine. This might be a model identifier or left as `None` if the engine doesn't require explicit model naming in the API request. +* **`LLM_PROVIDER`**: Crucially set this to `openai`. This tells DocsGPT to use the OpenAI-compatible API format for communication, even though the LLM is local. +* **`LLM_NAME`**: Specify the model name as recognized by your local inference engine. This might be a model identifier or left as `None` if the engine doesn't require explicit model naming in the API request. * **`OPENAI_BASE_URL`**: This is essential. Set this to the base URL of your local inference engine's API endpoint. This tells DocsGPT where to find your local LLM server. * **`API_KEY`**: Generally, for local inference engines, you can set `API_KEY=None` as authentication is usually not required in local setups. @@ -24,16 +24,16 @@ To connect to a local inference engine, you will generally need to configure the DocsGPT is readily configurable to work with the following local inference engines, all communicating via the OpenAI API format. Here are example `OPENAI_BASE_URL` values for each, based on default setups: -| Inference Engine | `LLM_NAME` | `OPENAI_BASE_URL` | -| :---------------------------- | :--------- | :------------------------- | -| LLaMa.cpp | `openai` | `http://localhost:8000/v1` | -| Ollama | `openai` | `http://localhost:11434/v1` | -| Text Generation Inference (TGI)| `openai` | `http://localhost:8080/v1` | -| SGLang | `openai` | `http://localhost:30000/v1` | -| vLLM | `openai` | `http://localhost:8000/v1` | -| Aphrodite | `openai` | `http://localhost:2242/v1` | -| FriendliAI | `openai` | `http://localhost:8997/v1` | -| LMDeploy | `openai` | `http://localhost:23333/v1` | +| Inference Engine | `LLM_PROVIDER` | `OPENAI_BASE_URL` | +| :---------------------------- | :------------- | :------------------------- | +| LLaMa.cpp | `openai` | `http://localhost:8000/v1` | +| Ollama | `openai` | `http://localhost:11434/v1` | +| Text Generation Inference (TGI)| `openai` | `http://localhost:8080/v1` | +| SGLang | `openai` | `http://localhost:30000/v1` | +| vLLM | `openai` | `http://localhost:8000/v1` | +| Aphrodite | `openai` | `http://localhost:2242/v1` | +| FriendliAI | `openai` | `http://localhost:8997/v1` | +| LMDeploy | `openai` | `http://localhost:23333/v1` | **Important Note on `localhost` vs `host.docker.internal`:** diff --git a/setup.sh b/setup.sh index 5cf013fc..b072d546 100755 --- a/setup.sh +++ b/setup.sh @@ -169,7 +169,7 @@ prompt_ollama_options() { # 1) Use DocsGPT Public API Endpoint (simple and free) use_docs_public_api_endpoint() { echo -e "\n${NC}Setting up DocsGPT Public API Endpoint...${NC}" - echo "LLM_NAME=docsgpt" > .env + echo "LLM_PROVIDER=docsgpt" > .env echo "VITE_API_STREAMING=true" >> .env echo -e "${GREEN}.env file configured for DocsGPT Public API.${NC}" @@ -237,13 +237,12 @@ serve_local_ollama() { echo -e "\n${NC}Configuring for Ollama ($(echo "$docker_compose_file_suffix" | tr '[:lower:]' '[:upper:]'))...${NC}" # Using tr for uppercase - more compatible echo "API_KEY=xxxx" > .env # Placeholder API Key - echo "LLM_NAME=openai" >> .env - echo "MODEL_NAME=$model_name" >> .env + echo "LLM_PROVIDER=openai" >> .env + echo "LLM_NAME=$model_name" >> .env echo "VITE_API_STREAMING=true" >> .env echo "OPENAI_BASE_URL=http://ollama:11434/v1" >> .env echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> .env echo -e "${GREEN}.env file configured for Ollama ($(echo "$docker_compose_file_suffix" | tr '[:lower:]' '[:upper:]')${NC}${GREEN}).${NC}" - echo -e "${YELLOW}Note: MODEL_NAME is set to '${BOLD}$model_name${NC}${YELLOW}'. You can change it later in the .env file.${NC}" check_and_start_docker @@ -350,8 +349,8 @@ connect_local_inference_engine() { echo -e "\n${NC}Configuring for Local Inference Engine: ${BOLD}${engine_name}...${NC}" echo "API_KEY=None" > .env - echo "LLM_NAME=openai" >> .env - echo "MODEL_NAME=$model_name" >> .env + echo "LLM_PROVIDER=openai" >> .env + echo "LLM_NAME=$model_name" >> .env echo "VITE_API_STREAMING=true" >> .env echo "OPENAI_BASE_URL=$openai_base_url" >> .env echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> .env @@ -381,7 +380,7 @@ connect_local_inference_engine() { # 4) Connect Cloud API Provider connect_cloud_api_provider() { - local provider_choice api_key llm_name + local provider_choice api_key llm_provider local setup_result # Variable to store the return status get_api_key() { @@ -395,43 +394,43 @@ connect_cloud_api_provider() { case "$provider_choice" in 1) # OpenAI provider_name="OpenAI" - llm_name="openai" + llm_provider="openai" model_name="gpt-4o" get_api_key break ;; 2) # Google provider_name="Google (Vertex AI, Gemini)" - llm_name="google" + llm_provider="google" model_name="gemini-2.0-flash" get_api_key break ;; 3) # Anthropic provider_name="Anthropic (Claude)" - llm_name="anthropic" + llm_provider="anthropic" model_name="claude-3-5-sonnet-latest" get_api_key break ;; 4) # Groq provider_name="Groq" - llm_name="groq" + llm_provider="groq" model_name="llama-3.1-8b-instant" get_api_key break ;; 5) # HuggingFace Inference API provider_name="HuggingFace Inference API" - llm_name="huggingface" + llm_provider="huggingface" model_name="meta-llama/Llama-3.1-8B-Instruct" get_api_key break ;; 6) # Azure OpenAI provider_name="Azure OpenAI" - llm_name="azure_openai" + llm_provider="azure_openai" model_name="gpt-4o" get_api_key break ;; 7) # Novita provider_name="Novita" - llm_name="novita" + llm_provider="novita" model_name="deepseek/deepseek-r1" get_api_key break ;; @@ -442,8 +441,8 @@ connect_cloud_api_provider() { echo -e "\n${NC}Configuring for Cloud API Provider: ${BOLD}${provider_name}...${NC}" echo "API_KEY=$api_key" > .env - echo "LLM_NAME=$llm_name" >> .env - echo "MODEL_NAME=$model_name" >> .env + echo "LLM_PROVIDER=$llm_provider" >> .env + echo "LLM_NAME=$model_name" >> .env echo "VITE_API_STREAMING=true" >> .env echo -e "${GREEN}.env file configured for ${BOLD}${provider_name}${NC}${GREEN}.${NC}" From 9b839655a73157a7dd96b1d3debc98322da0880b Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 11 Jun 2025 19:28:15 +0530 Subject: [PATCH 09/10] refactor: improve tool call result handling and display in conversation components --- application/agents/base.py | 5 +++- .../src/conversation/ConversationBubble.tsx | 26 ++++++++++++------- .../src/conversation/ConversationMessages.tsx | 2 +- .../src/conversation/conversationSlice.ts | 9 ++++--- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/application/agents/base.py b/application/agents/base.py index adebc125..c9cc579d 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -197,7 +197,9 @@ class BaseAgent(ABC): else: print(f"Executing tool: {action_name} with args: {call_args}") result = tool.execute_action(action_name, **parameters) - tool_call_data["result"] = result + tool_call_data["result"] = ( + f"{str(result)[:50]}..." if len(str(result)) > 50 else result + ) yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}} self.tool_calls.append(tool_call_data) @@ -213,6 +215,7 @@ class BaseAgent(ABC): if len(str(tool_call["result"])) > 50 else tool_call["result"] ), + "status": "completed", } for tool_call in self.tool_calls ] diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 920005e3..c1c1f553 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -26,7 +26,9 @@ import UserIcon from '../assets/user.svg'; import Accordion from '../components/Accordion'; import Avatar from '../components/Avatar'; import CopyButton from '../components/CopyButton'; +import MermaidRenderer from '../components/MermaidRenderer'; import Sidebar from '../components/Sidebar'; +import Spinner from '../components/Spinner'; import SpeakButton from '../components/TextToSpeechButton'; import { useDarkTheme, useOutsideAlerter } from '../hooks'; import { @@ -36,7 +38,6 @@ import { import classes from './ConversationBubble.module.css'; import { FEEDBACK, MESSAGE_TYPE } from './conversationModels'; import { ToolCallsType } from './types'; -import MermaidRenderer from '../components/MermaidRenderer'; const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false; @@ -741,7 +742,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { {isToolCallsOpen && ( -
+
{toolCalls.map((toolCall, index) => (

-

- - {JSON.stringify(toolCall.result, null, 2)} + {toolCall.status === 'pending' && ( + + -

+ )} + {toolCall.status === 'completed' && ( +

+ + {JSON.stringify(toolCall.result, null, 2)} + +

+ )}
diff --git a/frontend/src/conversation/ConversationMessages.tsx b/frontend/src/conversation/ConversationMessages.tsx index 5c2150a6..ac40bf7b 100644 --- a/frontend/src/conversation/ConversationMessages.tsx +++ b/frontend/src/conversation/ConversationMessages.tsx @@ -131,7 +131,7 @@ export default function ConversationMessages({ ? LAST_BUBBLE_MARGIN : DEFAULT_BUBBLE_MARGIN; - if (query.thought || query.response) { + if (query.thought || query.response || query.tool_calls) { const isCurrentlyStreaming = status === 'loading' && index === queries.length - 1; return ( diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 03532792..2e15a2ea 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -293,10 +293,11 @@ export const conversationSlice = createSlice({ ); if (existingIndex !== -1) { - Object.assign( - state.queries[index].tool_calls[existingIndex], - tool_call, - ); + const existingCall = state.queries[index].tool_calls[existingIndex]; + state.queries[index].tool_calls[existingIndex] = { + ...existingCall, + ...tool_call, + }; } else state.queries[index].tool_calls.push(tool_call); }, updateQuery( From b414f79bc573f39152c06094d3ae2108a2066237 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 11 Jun 2025 19:37:32 +0530 Subject: [PATCH 10/10] fix: adjust width of tool calls display in ConversationBubble component --- frontend/src/conversation/ConversationBubble.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index c70ed6f7..1bf8dafe 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -769,7 +769,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
{isToolCallsOpen && ( -
+
{toolCalls.map((toolCall, index) => (