diff --git a/application/agents/base.py b/application/agents/base.py index d44244cf..c9cc579d 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -2,16 +2,18 @@ import uuid from abc import ABC, abstractmethod from typing import Dict, Generator, List, Optional -from application.agents.llm_handler import get_llm_handler +from bson.objectid import ObjectId + from application.agents.tools.tool_action_parser import ToolActionParser from application.agents.tools.tool_manager import ToolManager from application.core.mongo_db import MongoDB +from application.core.settings import settings + +from application.llm.handlers.handler_creator import LLMHandlerCreator from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever -from application.core.settings import settings -from bson.objectid import ObjectId class BaseAgent(ABC): @@ -45,7 +47,9 @@ class BaseAgent(ABC): user_api_key=user_api_key, decoded_token=decoded_token, ) - self.llm_handler = get_llm_handler(llm_name) + self.llm_handler = LLMHandlerCreator.create_handler( + llm_name if llm_name else "default" + ) self.attachments = attachments or [] @log_activity() @@ -132,6 +136,15 @@ class BaseAgent(ABC): parser = ToolActionParser(self.llm.__class__.__name__) tool_id, action_name, call_args = parser.parse_args(call) + call_id = getattr(call, "id", None) or str(uuid.uuid4()) + tool_call_data = { + "tool_name": tools_dict[tool_id]["name"], + "call_id": call_id, + "action_name": f"{action_name}_{tool_id}", + "arguments": call_args, + } + yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}} + tool_data = tools_dict[tool_id] action_data = ( tool_data["config"]["actions"][action_name] @@ -184,19 +197,29 @@ class BaseAgent(ABC): else: print(f"Executing tool: {action_name} with args: {call_args}") result = tool.execute_action(action_name, **parameters) - call_id = getattr(call, "id", None) + tool_call_data["result"] = ( + f"{str(result)[:50]}..." if len(str(result)) > 50 else result + ) - tool_call_data = { - "tool_name": tool_data["name"], - "call_id": call_id if call_id is not None else "None", - "action_name": f"{action_name}_{tool_id}", - "arguments": call_args, - "result": result, - } + yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}} self.tool_calls.append(tool_call_data) return result, call_id + def _get_truncated_tool_calls(self): + return [ + { + **tool_call, + "result": ( + f"{str(tool_call['result'])[:50]}..." + if len(str(tool_call["result"])) > 50 + else tool_call["result"] + ), + "status": "completed", + } + for tool_call in self.tool_calls + ] + def _build_messages( self, system_prompt: str, @@ -252,9 +275,16 @@ class BaseAgent(ABC): return retrieved_data def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None): - resp = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=self.tools - ) + gen_kwargs = {"model": self.gpt_model, "messages": messages} + + if ( + hasattr(self.llm, "_supports_tools") + and self.llm._supports_tools + and self.tools + ): + gen_kwargs["tools"] = self.tools + resp = self.llm.gen_stream(**gen_kwargs) + if log_context: data = build_stack_data(self.llm, exclude_attributes=["client"]) log_context.stacks.append({"component": "llm", "data": data}) @@ -268,10 +298,30 @@ class BaseAgent(ABC): log_context: Optional[LogContext] = None, attachments: Optional[List[Dict]] = None, ): - resp = self.llm_handler.handle_response( - self, resp, tools_dict, messages, attachments + resp = self.llm_handler.process_message_flow( + self, resp, tools_dict, messages, attachments, True ) if log_context: data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"]) log_context.stacks.append({"component": "llm_handler", "data": data}) return resp + + def _handle_response(self, response, tools_dict, messages, log_context): + if isinstance(response, str): + yield {"answer": response} + return + if hasattr(response, "message") and getattr(response.message, "content", None): + yield {"answer": response.message.content} + return + + processed_response_gen = self._llm_handler( + response, tools_dict, messages, log_context, self.attachments + ) + + for event in processed_response_gen: + if isinstance(event, str): + yield {"answer": event} + elif hasattr(event, "message") and getattr(event.message, "content", None): + yield {"answer": event.message.content} + elif isinstance(event, dict) and "type" in event: + yield event diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index b371123b..6fe73de0 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -1,8 +1,6 @@ from typing import Dict, Generator - from application.agents.base import BaseAgent from application.logging import LogContext - from application.retriever.base import BaseRetriever import logging @@ -10,55 +8,46 @@ logger = logging.getLogger(__name__) class ClassicAgent(BaseAgent): + """A simplified classic agent with clear execution flow. + + Usage: + 1. Processes a query through retrieval + 2. Sets up available tools + 3. Generates responses using LLM + 4. Handles tool interactions if needed + 5. Returns standardized outputs + + Easy to extend by overriding specific steps. + """ + def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: + # Step 1: Retrieve relevant data retrieved_data = self._retriever_search(retriever, query, log_context) - if self.user_api_key: - tools_dict = self._get_tools(self.user_api_key) - else: - tools_dict = self._get_user_tools(self.user) + + # Step 2: Prepare tools + tools_dict = ( + self._get_user_tools(self.user) + if not self.user_api_key + else self._get_tools(self.user_api_key) + ) self._prepare_tools(tools_dict) + # Step 3: Build and process messages messages = self._build_messages(self.prompt, query, retrieved_data) + llm_response = self._llm_gen(messages, log_context) - resp = self._llm_gen(messages, log_context) + # Step 4: Handle the response + yield from self._handle_response( + llm_response, tools_dict, messages, log_context + ) - attachments = self.attachments - - if isinstance(resp, str): - yield {"answer": resp} - return - if ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield {"answer": resp.message.content} - return - - resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments) - - if isinstance(resp, str): - yield {"answer": resp} - elif ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield {"answer": resp.message.content} - else: - for line in resp: - if isinstance(line, str): - yield {"answer": line} + # Step 5: Return metadata + yield {"sources": retrieved_data} + yield {"tool_calls": self._get_truncated_tool_calls()} + # Log tool calls for debugging log_context.stacks.append( {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}} ) - - yield {"sources": retrieved_data} - # clean tool_call_data only send first 50 characters of tool_call['result'] - for tool_call in self.tool_calls: - if len(str(tool_call["result"])) > 50: - tool_call["result"] = str(tool_call["result"])[:50] + "..." - yield {"tool_calls": self.tool_calls.copy()} diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py deleted file mode 100644 index 1b995f71..00000000 --- a/application/agents/llm_handler.py +++ /dev/null @@ -1,351 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod - -from application.logging import build_stack_data - -logger = logging.getLogger(__name__) - - -class LLMHandler(ABC): - def __init__(self): - self.llm_calls = [] - self.tool_calls = [] - - @abstractmethod - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs): - pass - - def prepare_messages_with_attachments(self, agent, messages, attachments=None): - """ - Prepare messages with attachment content if available. - - Args: - agent: The current agent instance. - messages (list): List of message dictionaries. - attachments (list): List of attachment dictionaries with content. - - Returns: - list: Messages with attachment context added to the system prompt. - """ - if not attachments: - return messages - - logger.info(f"Preparing messages with {len(attachments)} attachments") - - supported_types = agent.llm.get_supported_attachment_types() - - supported_attachments = [] - unsupported_attachments = [] - - for attachment in attachments: - mime_type = attachment.get('mime_type') - if mime_type in supported_types: - supported_attachments.append(attachment) - else: - unsupported_attachments.append(attachment) - - # Process supported attachments with the LLM's custom method - prepared_messages = messages - if supported_attachments: - logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method") - prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments) - - # Process unsupported attachments with the default method - if unsupported_attachments: - logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method") - prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments) - - return prepared_messages - - def _append_attachment_content_to_system(self, messages, attachments): - """ - Default method to append attachment content to the system prompt. - - Args: - messages (list): List of message dictionaries. - attachments (list): List of attachment dictionaries with content. - - Returns: - list: Messages with attachment context added to the system prompt. - """ - prepared_messages = messages.copy() - - attachment_texts = [] - for attachment in attachments: - logger.info(f"Adding attachment {attachment.get('id')} to context") - if 'content' in attachment: - attachment_texts.append(f"Attached file content:\n\n{attachment['content']}") - - if attachment_texts: - combined_attachment_text = "\n\n".join(attachment_texts) - - system_found = False - for i in range(len(prepared_messages)): - if prepared_messages[i].get("role") == "system": - prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}" - system_found = True - break - - if not system_found: - prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text}) - - return prepared_messages - -class OpenAILLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - logger.info(f"Messages with attachments: {messages}") - if not stream: - while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": - message = json.loads(resp.model_dump_json())["message"] - keys_to_remove = {"audio", "function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) - - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - function_call_dict = { - "function_call": { - "name": call.function.name, - "args": call.function.arguments, - "call_id": call_id, - } - } - function_response_dict = { - "function_response": { - "name": call.function.name, - "response": {"result": tool_response}, - "call_id": call_id, - } - } - - messages.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages.append( - {"role": "tool", "content": [function_response_dict]} - ) - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - except Exception as e: - logging.error(f"Error executing tool: {str(e)}", exc_info=True) - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call_id, - } - ) - resp = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - return resp - - else: - text_buffer = "" - while True: - tool_calls = {} - for chunk in resp: - if isinstance(chunk, str) and len(chunk) > 0: - yield chunk - continue - elif hasattr(chunk, "delta"): - chunk_delta = chunk.delta - - if ( - hasattr(chunk_delta, "tool_calls") - and chunk_delta.tool_calls is not None - ): - for tool_call in chunk_delta.tool_calls: - index = tool_call.index - if index not in tool_calls: - tool_calls[index] = { - "id": "", - "function": {"name": "", "arguments": ""}, - } - - current = tool_calls[index] - if tool_call.id: - current["id"] = tool_call.id - if tool_call.function.name: - current["function"][ - "name" - ] = tool_call.function.name - if tool_call.function.arguments: - current["function"][ - "arguments" - ] += tool_call.function.arguments - tool_calls[index] = current - - if ( - hasattr(chunk, "finish_reason") - and chunk.finish_reason == "tool_calls" - ): - for index in sorted(tool_calls.keys()): - call = tool_calls[index] - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - if isinstance(call["function"]["arguments"], str): - call["function"]["arguments"] = json.loads(call["function"]["arguments"]) - - function_call_dict = { - "function_call": { - "name": call["function"]["name"], - "args": call["function"]["arguments"], - "call_id": call["id"], - } - } - function_response_dict = { - "function_response": { - "name": call["function"]["name"], - "response": {"result": tool_response}, - "call_id": call["id"], - } - } - - messages.append( - { - "role": "assistant", - "content": [function_call_dict], - } - ) - messages.append( - { - "role": "tool", - "content": [function_response_dict], - } - ) - - except Exception as e: - logging.error(f"Error executing tool: {str(e)}", exc_info=True) - messages.append( - { - "role": "assistant", - "content": f"Error executing tool: {str(e)}", - } - ) - tool_calls = {} - if hasattr(chunk_delta, "content") and chunk_delta.content: - # Add to buffer or yield immediately based on your preference - text_buffer += chunk_delta.content - yield text_buffer - text_buffer = "" - - if ( - hasattr(chunk, "finish_reason") - and chunk.finish_reason == "stop" - ): - return resp - elif isinstance(chunk, str) and len(chunk) == 0: - continue - - logger.info(f"Regenerating with messages: {messages}") - resp = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - - -class GoogleLLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): - from google.genai import types - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - - while True: - if not stream: - response = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - if response.candidates and response.candidates[0].content.parts: - tool_call_found = False - for part in response.candidates[0].content.parts: - if part.function_call: - tool_call_found = True - self.tool_calls.append(part.function_call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, part.function_call - ) - function_response_part = types.Part.from_function_response( - name=part.function_call.name, - response={"result": tool_response}, - ) - - messages.append( - {"role": "model", "content": [part.to_json_dict()]} - ) - messages.append( - { - "role": "tool", - "content": [function_response_part.to_json_dict()], - } - ) - - if ( - not tool_call_found - and response.candidates[0].content.parts - and response.candidates[0].content.parts[0].text - ): - return response.candidates[0].content.parts[0].text - elif not tool_call_found: - return response.candidates[0].content.parts - - else: - return response - - else: - response = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - - tool_call_found = False - for result in response: - if hasattr(result, "function_call"): - tool_call_found = True - self.tool_calls.append(result.function_call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, result.function_call - ) - function_response_part = types.Part.from_function_response( - name=result.function_call.name, - response={"result": tool_response}, - ) - - messages.append( - {"role": "model", "content": [result.to_json_dict()]} - ) - messages.append( - { - "role": "tool", - "content": [function_response_part.to_json_dict()], - } - ) - else: - tool_call_found = False - yield result - - if not tool_call_found: - return response - - -def get_llm_handler(llm_type): - handlers = { - "openai": OpenAILLMHandler(), - "google": GoogleLLMHandler(), - } - return handlers.get(llm_type, OpenAILLMHandler()) diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index c7da5a4c..0589ac88 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -17,26 +17,21 @@ class ToolActionParser: return parser(call) def _parse_openai_llm(self, call): - if isinstance(call, dict): - try: - call_args = json.loads(call["function"]["arguments"]) - tool_id = call["function"]["name"].split("_")[-1] - action_name = call["function"]["name"].rsplit("_", 1)[0] - except (KeyError, TypeError) as e: - logger.error(f"Error parsing OpenAI LLM call: {e}") - return None, None, None - else: - try: - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] - except (AttributeError, TypeError) as e: - logger.error(f"Error parsing OpenAI LLM call: {e}") - return None, None, None + try: + call_args = json.loads(call.arguments) + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") + return None, None, None return tool_id, action_name, call_args def _parse_google_llm(self, call): - call_args = call.args - tool_id = call.name.split("_")[-1] - action_name = call.name.rsplit("_", 1)[0] + try: + call_args = call.arguments + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing Google LLM call: {e}") + return None, None, None return tool_id, action_name, call_args diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 3c2ba866..469ea98c 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -37,17 +37,17 @@ api.add_namespace(answer_ns) gpt_model = "" # to have some kind of default behaviour -if settings.LLM_NAME == "openai": +if settings.LLM_PROVIDER == "openai": gpt_model = "gpt-4o-mini" -elif settings.LLM_NAME == "anthropic": +elif settings.LLM_PROVIDER == "anthropic": gpt_model = "claude-2" -elif settings.LLM_NAME == "groq": +elif settings.LLM_PROVIDER == "groq": gpt_model = "llama3-8b-8192" -elif settings.LLM_NAME == "novita": +elif settings.LLM_PROVIDER == "novita": gpt_model = "deepseek/deepseek-r1" -if settings.MODEL_NAME: # in case there is particular model name configured - gpt_model = settings.MODEL_NAME +if settings.LLM_NAME: # in case there is particular model name configured + gpt_model = settings.LLM_NAME # load the prompts current_dir = os.path.dirname( @@ -307,19 +307,20 @@ def complete_stream( yield f"data: {data}\n\n" elif "tool_calls" in line: tool_calls = line["tool_calls"] - data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls}) - yield f"data: {data}\n\n" elif "thought" in line: thought += line["thought"] data = json.dumps({"type": "thought", "thought": line["thought"]}) yield f"data: {data}\n\n" + elif "type" in line: + data = json.dumps(line) + yield f"data: {data}\n\n" if isNoneDoc: for doc in source_log_docs: doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, @@ -451,9 +452,7 @@ class Stream(Resource): agent_type = settings.AGENT_NAME decoded_token = getattr(request, "decoded_token", None) user_sub = decoded_token.get("sub") if decoded_token else None - agent_key, is_shared_usage, shared_token = get_agent_key( - agent_id, user_sub - ) + agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub) if agent_key: data.update({"api_key": agent_key}) @@ -504,7 +503,7 @@ class Stream(Resource): agent = AgentCreator.create_agent( agent_type, endpoint="stream", - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, gpt_model=gpt_model, api_key=settings.API_KEY, user_api_key=user_api_key, @@ -658,7 +657,7 @@ class Answer(Resource): agent = AgentCreator.create_agent( agent_type, endpoint="api/answer", - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, gpt_model=gpt_model, api_key=settings.API_KEY, user_api_key=user_api_key, @@ -727,7 +726,7 @@ class Answer(Resource): doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, diff --git a/application/core/settings.py b/application/core/settings.py index 3be34242..2ff371af 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -11,18 +11,18 @@ current_dir = os.path.dirname( class Settings(BaseSettings): AUTH_TYPE: Optional[str] = None - LLM_NAME: str = "docsgpt" - MODEL_NAME: Optional[str] = ( - None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo + LLM_PROVIDER: str = "docsgpt" + LLM_NAME: Optional[str] = ( + None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo ) EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" MONGO_URI: str = "mongodb://localhost:27017/docsgpt" MONGO_DB_NAME: str = "docsgpt" - MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") + LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") DEFAULT_MAX_HISTORY: int = 150 - MODEL_TOKEN_LIMITS: dict = { + LLM_TOKEN_LIMITS: dict = { "gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5, @@ -35,6 +35,9 @@ class Settings(BaseSettings): ) RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search AGENT_NAME: str = "classic" + FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm + FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm + FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" @@ -99,8 +102,7 @@ class Settings(BaseSettings): BRAVE_SEARCH_API_KEY: Optional[str] = None FLASK_DEBUG_MODE: bool = False - STORAGE_TYPE: str = "local" # local or s3 - + STORAGE_TYPE: str = "local" # local or s3 JWT_SECRET_KEY: str = "" diff --git a/application/llm/base.py b/application/llm/base.py index 0607159d..bef3e11f 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -1,53 +1,117 @@ +import logging from abc import ABC, abstractmethod from application.cache import gen_cache, stream_cache + +from application.core.settings import settings from application.usage import gen_token_usage, stream_token_usage +logger = logging.getLogger(__name__) + class BaseLLM(ABC): - def __init__(self, decoded_token=None): + def __init__( + self, + decoded_token=None, + ): self.decoded_token = decoded_token self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + self.fallback_provider = settings.FALLBACK_LLM_PROVIDER + self.fallback_model_name = settings.FALLBACK_LLM_NAME + self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY + self._fallback_llm = None - def _apply_decorator(self, method, decorators, *args, **kwargs): - for decorator in decorators: - method = decorator(method) - return method(self, *args, **kwargs) + @property + def fallback_llm(self): + """Lazy-loaded fallback LLM instance.""" + if ( + self._fallback_llm is None + and self.fallback_provider + and self.fallback_model_name + ): + try: + from application.llm.llm_creator import LLMCreator + + self._fallback_llm = LLMCreator.create_llm( + self.fallback_provider, + self.fallback_llm_api_key, + None, + self.decoded_token, + ) + except Exception as e: + logger.error( + f"Failed to initialize fallback LLM: {str(e)}", exc_info=True + ) + return self._fallback_llm + + def _execute_with_fallback( + self, method_name: str, decorators: list, *args, **kwargs + ): + """ + Unified method execution with fallback support. + + Args: + method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream') + decorators: List of decorators to apply + *args: Positional arguments + **kwargs: Keyword arguments + """ + + def decorated_method(): + method = getattr(self, method_name) + for decorator in decorators: + method = decorator(method) + return method(self, *args, **kwargs) + + try: + return decorated_method() + except Exception as e: + if not self.fallback_llm: + logger.error(f"Primary LLM failed and no fallback available: {str(e)}") + raise + logger.warning( + f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}" + ) + + fallback_method = getattr( + self.fallback_llm, method_name.replace("_raw_", "") + ) + return fallback_method(*args, **kwargs) + + def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): + decorators = [gen_token_usage, gen_cache] + return self._execute_with_fallback( + "_raw_gen", + decorators, + model=model, + messages=messages, + stream=stream, + tools=tools, + *args, + **kwargs, + ) + + def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs): + decorators = [stream_cache, stream_token_usage] + return self._execute_with_fallback( + "_raw_gen_stream", + decorators, + model=model, + messages=messages, + stream=stream, + tools=tools, + *args, + **kwargs, + ) @abstractmethod def _raw_gen(self, model, messages, stream, tools, *args, **kwargs): pass - def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): - decorators = [gen_token_usage, gen_cache] - return self._apply_decorator( - self._raw_gen, - decorators=decorators, - model=model, - messages=messages, - stream=stream, - tools=tools, - *args, - **kwargs - ) - @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): pass - def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs): - decorators = [stream_cache, stream_token_usage] - return self._apply_decorator( - self._raw_gen_stream, - decorators=decorators, - model=model, - messages=messages, - stream=stream, - tools=tools, - *args, - **kwargs - ) - def supports_tools(self): return hasattr(self, "_supports_tools") and callable( getattr(self, "_supports_tools") @@ -55,11 +119,11 @@ class BaseLLM(ABC): def _supports_tools(self): raise NotImplementedError("Subclass must implement _supports_tools method") - + def get_supported_attachment_types(self): """ Return a list of MIME types supported by this LLM for file uploads. - + Returns: list: List of supported MIME types """ diff --git a/application/llm/handlers/__init__.py b/application/llm/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py new file mode 100644 index 00000000..43205472 --- /dev/null +++ b/application/llm/handlers/base.py @@ -0,0 +1,335 @@ +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generator, List, Optional, Union + +from application.logging import build_stack_data + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolCall: + """Represents a tool/function call from the LLM.""" + + id: str + name: str + arguments: Union[str, Dict] + index: Optional[int] = None + + @classmethod + def from_dict(cls, data: Dict) -> "ToolCall": + """Create ToolCall from dictionary.""" + return cls( + id=data.get("id", ""), + name=data.get("name", ""), + arguments=data.get("arguments", {}), + index=data.get("index"), + ) + + +@dataclass +class LLMResponse: + """Represents a response from the LLM.""" + + content: str + tool_calls: List[ToolCall] + finish_reason: str + raw_response: Any + + @property + def requires_tool_call(self) -> bool: + """Check if the response requires tool calls.""" + return bool(self.tool_calls) and self.finish_reason == "tool_calls" + + +class LLMHandler(ABC): + """Abstract base class for LLM handlers.""" + + def __init__(self): + self.llm_calls = [] + self.tool_calls = [] + + @abstractmethod + def parse_response(self, response: Any) -> LLMResponse: + """Parse raw LLM response into standardized format.""" + pass + + @abstractmethod + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create a tool result message for the conversation history.""" + pass + + @abstractmethod + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through streaming response chunks.""" + pass + + def process_message_flow( + self, + agent, + initial_response, + tools_dict: Dict, + messages: List[Dict], + attachments: Optional[List] = None, + stream: bool = False, + ) -> Union[str, Generator]: + """ + Main orchestration method for processing LLM message flow. + + Args: + agent: The agent instance + initial_response: Initial LLM response + tools_dict: Dictionary of available tools + messages: Conversation history + attachments: Optional attachments + stream: Whether to use streaming + + Returns: + Final response or generator for streaming + """ + messages = self.prepare_messages(agent, messages, attachments) + + if stream: + return self.handle_streaming(agent, initial_response, tools_dict, messages) + else: + return self.handle_non_streaming( + agent, initial_response, tools_dict, messages + ) + + def prepare_messages( + self, agent, messages: List[Dict], attachments: Optional[List] = None + ) -> List[Dict]: + """ + Prepare messages with attachments and provider-specific formatting. + + Args: + agent: The agent instance + messages: Original messages + attachments: List of attachments + + Returns: + Prepared messages list + """ + if not attachments: + return messages + logger.info(f"Preparing messages with {len(attachments)} attachments") + supported_types = agent.llm.get_supported_attachment_types() + + supported_attachments = [ + a for a in attachments if a.get("mime_type") in supported_types + ] + unsupported_attachments = [ + a for a in attachments if a.get("mime_type") not in supported_types + ] + + # Process supported attachments with the LLM's custom method + + if supported_attachments: + logger.info( + f"Processing {len(supported_attachments)} supported attachments" + ) + messages = agent.llm.prepare_messages_with_attachments( + messages, supported_attachments + ) + # Process unsupported attachments with default method + + if unsupported_attachments: + logger.info( + f"Processing {len(unsupported_attachments)} unsupported attachments" + ) + messages = self._append_unsupported_attachments( + messages, unsupported_attachments + ) + return messages + + def _append_unsupported_attachments( + self, messages: List[Dict], attachments: List[Dict] + ) -> List[Dict]: + """ + Default method to append unsupported attachment content to system prompt. + + Args: + messages: Current messages + attachments: List of unsupported attachments + + Returns: + Updated messages list + """ + prepared_messages = messages.copy() + attachment_texts = [] + + for attachment in attachments: + logger.info(f"Adding attachment {attachment.get('id')} to context") + if "content" in attachment: + attachment_texts.append( + f"Attached file content:\n\n{attachment['content']}" + ) + if attachment_texts: + combined_text = "\n\n".join(attachment_texts) + + system_msg = next( + (msg for msg in prepared_messages if msg.get("role") == "system"), + {"role": "system", "content": ""}, + ) + + if system_msg not in prepared_messages: + prepared_messages.insert(0, system_msg) + system_msg["content"] += f"\n\n{combined_text}" + return prepared_messages + + def handle_tool_calls( + self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict] + ) -> Generator: + """ + Execute tool calls and update conversation history. + + Args: + agent: The agent instance + tool_calls: List of tool calls to execute + tools_dict: Available tools dictionary + messages: Current conversation history + + Returns: + Updated messages list + """ + updated_messages = messages.copy() + + for call in tool_calls: + try: + self.tool_calls.append(call) + tool_executor_gen = agent._execute_tool_action(tools_dict, call) + while True: + try: + yield next(tool_executor_gen) + except StopIteration as e: + tool_response, call_id = e.value + break + + updated_messages.append( + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": call.name, + "args": call.arguments, + "call_id": call_id, + } + } + ], + } + ) + + updated_messages.append(self.create_tool_message(call, tool_response)) + + except Exception as e: + logger.error(f"Error executing tool: {str(e)}", exc_info=True) + updated_messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call.id, + } + ) + + return updated_messages + + def handle_non_streaming( + self, agent, response: Any, tools_dict: Dict, messages: List[Dict] + ) -> Generator: + """ + Handle non-streaming response flow. + + Args: + agent: The agent instance + response: Current LLM response + tools_dict: Available tools dictionary + messages: Conversation history + + Returns: + Final response after processing all tool calls + """ + parsed = self.parse_response(response) + self.llm_calls.append(build_stack_data(agent.llm)) + + while parsed.requires_tool_call: + tool_handler_gen = self.handle_tool_calls( + agent, parsed.tool_calls, tools_dict, messages + ) + while True: + try: + yield next(tool_handler_gen) + except StopIteration as e: + messages = e.value + break + + response = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + parsed = self.parse_response(response) + self.llm_calls.append(build_stack_data(agent.llm)) + + return parsed.content + + def handle_streaming( + self, agent, response: Any, tools_dict: Dict, messages: List[Dict] + ) -> Generator: + """ + Handle streaming response flow. + + Args: + agent: The agent instance + response: Current LLM response + tools_dict: Available tools dictionary + messages: Conversation history + + Yields: + Streaming response chunks + """ + buffer = "" + tool_calls = {} + + for chunk in self._iterate_stream(response): + if isinstance(chunk, str): + yield chunk + continue + parsed = self.parse_response(chunk) + + if parsed.tool_calls: + for call in parsed.tool_calls: + if call.index not in tool_calls: + tool_calls[call.index] = call + else: + existing = tool_calls[call.index] + if call.id: + existing.id = call.id + if call.name: + existing.name = call.name + if call.arguments: + existing.arguments += call.arguments + if parsed.finish_reason == "tool_calls": + tool_handler_gen = self.handle_tool_calls( + agent, list(tool_calls.values()), tools_dict, messages + ) + while True: + try: + yield next(tool_handler_gen) + except StopIteration as e: + messages = e.value + break + tool_calls = {} + + response = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + + yield from self.handle_streaming(agent, response, tools_dict, messages) + return + if parsed.content: + buffer += parsed.content + yield buffer + buffer = "" + if parsed.finish_reason == "stop": + return diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py new file mode 100644 index 00000000..b43f2a16 --- /dev/null +++ b/application/llm/handlers/google.py @@ -0,0 +1,78 @@ +import uuid +from typing import Any, Dict, Generator + +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +class GoogleLLMHandler(LLMHandler): + """Handler for Google's GenAI API.""" + + def parse_response(self, response: Any) -> LLMResponse: + """Parse Google response into standardized format.""" + + if isinstance(response, str): + return LLMResponse( + content=response, + tool_calls=[], + finish_reason="stop", + raw_response=response, + ) + + if hasattr(response, "candidates"): + parts = response.candidates[0].content.parts if response.candidates else [] + tool_calls = [ + ToolCall( + id=str(uuid.uuid4()), + name=part.function_call.name, + arguments=part.function_call.args, + ) + for part in parts + if hasattr(part, "function_call") and part.function_call is not None + ] + + content = " ".join( + part.text + for part in parts + if hasattr(part, "text") and part.text is not None + ) + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason="tool_calls" if tool_calls else "stop", + raw_response=response, + ) + + else: + tool_calls = [] + if hasattr(response, "function_call"): + tool_calls.append( + ToolCall( + id=str(uuid.uuid4()), + name=response.function_call.name, + arguments=response.function_call.args, + ) + ) + return LLMResponse( + content=response.text if hasattr(response, "text") else "", + tool_calls=tool_calls, + finish_reason="tool_calls" if tool_calls else "stop", + raw_response=response, + ) + + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create Google-style tool message.""" + from google.genai import types + + return { + "role": "tool", + "content": [ + types.Part.from_function_response( + name=tool_call.name, response={"result": result} + ).to_json_dict() + ], + } + + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through Google streaming response.""" + for chunk in response: + yield chunk diff --git a/application/llm/handlers/handler_creator.py b/application/llm/handlers/handler_creator.py new file mode 100644 index 00000000..e39c000d --- /dev/null +++ b/application/llm/handlers/handler_creator.py @@ -0,0 +1,18 @@ +from application.llm.handlers.base import LLMHandler +from application.llm.handlers.google import GoogleLLMHandler +from application.llm.handlers.openai import OpenAILLMHandler + + +class LLMHandlerCreator: + handlers = { + "openai": OpenAILLMHandler, + "google": GoogleLLMHandler, + "default": OpenAILLMHandler, + } + + @classmethod + def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler: + handler_class = cls.handlers.get(llm_type.lower()) + if not handler_class: + raise ValueError(f"No LLM handler class found for type {llm_type}") + return handler_class(*args, **kwargs) diff --git a/application/llm/handlers/openai.py b/application/llm/handlers/openai.py new file mode 100644 index 00000000..99ddde4c --- /dev/null +++ b/application/llm/handlers/openai.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, Generator + +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +class OpenAILLMHandler(LLMHandler): + """Handler for OpenAI API.""" + + def parse_response(self, response: Any) -> LLMResponse: + """Parse OpenAI response into standardized format.""" + if isinstance(response, str): + return LLMResponse( + content=response, + tool_calls=[], + finish_reason="stop", + raw_response=response, + ) + + message = getattr(response, "message", None) or getattr(response, "delta", None) + + tool_calls = [] + if hasattr(message, "tool_calls"): + tool_calls = [ + ToolCall( + id=getattr(tc, "id", ""), + name=getattr(tc.function, "name", ""), + arguments=getattr(tc.function, "arguments", ""), + index=getattr(tc, "index", None), + ) + for tc in message.tool_calls or [] + ] + return LLMResponse( + content=getattr(message, "content", ""), + tool_calls=tool_calls, + finish_reason=getattr(response, "finish_reason", ""), + raw_response=response, + ) + + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create OpenAI-style tool message.""" + return { + "role": "tool", + "content": [ + { + "function_response": { + "name": tool_call.name, + "response": {"result": result}, + "call_id": tool_call.id, + } + } + ], + } + + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through OpenAI streaming response.""" + for chunk in response: + yield chunk diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 804c3c56..f0418b49 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -2,6 +2,7 @@ from application.llm.base import BaseLLM from application.core.settings import settings import threading + class LlamaSingleton: _instances = {} _lock = threading.Lock() # Add a lock for thread synchronization @@ -29,7 +30,7 @@ class LlamaCpp(BaseLLM): self, api_key=None, user_api_key=None, - llm_name=settings.MODEL_PATH, + llm_name=settings.LLM_PATH, *args, **kwargs, ): @@ -42,14 +43,18 @@ class LlamaCpp(BaseLLM): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False) + result = LlamaSingleton.query_model( + self.llama, prompt, max_tokens=150, echo=False + ) return result["choices"][0]["text"].split("### Answer \n")[-1] def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream) + result = LlamaSingleton.query_model( + self.llama, prompt, max_tokens=150, echo=False, stream=stream + ) for item in result: for choice in item["choices"]: - yield choice["text"] \ No newline at end of file + yield choice["text"] diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 33bdb894..123000e4 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -29,10 +29,10 @@ class BraveRetSearch(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) @@ -59,7 +59,7 @@ class BraveRetSearch(BaseRetriever): docs.append({"text": snippet, "title": title, "link": link}) except IndexError: pass - if settings.LLM_NAME == "llama.cpp": + if settings.LLM_PROVIDER == "llama.cpp": docs = [docs[0]] return docs @@ -84,7 +84,7 @@ class BraveRetSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=self.user_api_key, decoded_token=self.decoded_token, diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 804025e3..9416b4f7 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever): token_limit=150, gpt_model="docsgpt", user_api_key=None, - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, api_key=settings.API_KEY, decoded_token=None, ): @@ -28,10 +28,10 @@ class ClassicRAG(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index 29bbba18..5abe5edd 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -28,10 +28,10 @@ class DuckDuckSearch(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) @@ -58,7 +58,7 @@ class DuckDuckSearch(BaseRetriever): ) except IndexError: pass - if settings.LLM_NAME == "llama.cpp": + if settings.LLM_PROVIDER == "llama.cpp": docs = [docs[0]] return docs @@ -83,7 +83,7 @@ class DuckDuckSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=self.user_api_key, decoded_token=self.decoded_token, diff --git a/application/utils.py b/application/utils.py index 548c8828..e749c788 100644 --- a/application/utils.py +++ b/application/utils.py @@ -102,8 +102,8 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): max_token_limit if max_token_limit and max_token_limit - < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) - else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) ) if not history: diff --git a/application/worker.py b/application/worker.py index 235c969e..c6178931 100755 --- a/application/worker.py +++ b/application/worker.py @@ -143,8 +143,8 @@ def run_agent_logic(agent_config, input_data): agent = AgentCreator.create_agent( agent_type, endpoint="webhook", - llm_name=settings.LLM_NAME, - gpt_model=settings.MODEL_NAME, + llm_name=settings.LLM_PROVIDER, + gpt_model=settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key, prompt=prompt, @@ -159,7 +159,7 @@ def run_agent_logic(agent_config, input_data): prompt=prompt, chunks=chunks, token_limit=settings.DEFAULT_MAX_HISTORY, - gpt_model=settings.MODEL_NAME, + gpt_model=settings.LLM_NAME, user_api_key=user_api_key, decoded_token=decoded_token, ) @@ -452,7 +452,7 @@ def attachment_worker(self, file_info, user): try: self.update_state(state="PROGRESS", meta={"current": 10}) storage = StorageCreator.get_storage() - + self.update_state( state="PROGRESS", meta={"current": 30, "status": "Processing content"} ) @@ -461,9 +461,11 @@ def attachment_worker(self, file_info, user): relative_path, lambda local_path, **kwargs: SimpleDirectoryReader( input_files=[local_path], exclude_hidden=True, errors="ignore" - ).load_data()[0].text + ) + .load_data()[0] + .text, ) - + token_count = num_tokens_from_string(content) self.update_state( @@ -491,9 +493,7 @@ def attachment_worker(self, file_info, user): f"Stored attachment with ID: {attachment_id}", extra={"user": user} ) - self.update_state( - state="PROGRESS", meta={"current": 100, "status": "Complete"} - ) + self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"}) return { "filename": filename, diff --git a/deployment/docker-compose-dev.yaml b/deployment/docker-compose-dev.yaml index 8a3e75c4..a1658bd2 100644 --- a/deployment/docker-compose-dev.yaml +++ b/deployment/docker-compose-dev.yaml @@ -1,3 +1,4 @@ +name: docsgpt-oss services: redis: diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml index c4b81a08..da9249ea 100644 --- a/deployment/docker-compose.yaml +++ b/deployment/docker-compose.yaml @@ -1,3 +1,4 @@ +name: docsgpt-oss services: frontend: build: ../frontend @@ -17,13 +18,13 @@ services: environment: - API_KEY=$API_KEY - EMBEDDINGS_KEY=$API_KEY + - LLM_PROVIDER=$LLM_PROVIDER - LLM_NAME=$LLM_NAME - CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/1 - MONGO_URI=mongodb://mongo:27017/docsgpt - CACHE_REDIS_URL=redis://redis:6379/2 - OPENAI_BASE_URL=$OPENAI_BASE_URL - - MODEL_NAME=$MODEL_NAME ports: - "7091:7091" volumes: @@ -41,6 +42,7 @@ services: environment: - API_KEY=$API_KEY - EMBEDDINGS_KEY=$API_KEY + - LLM_PROVIDER=$LLM_PROVIDER - LLM_NAME=$LLM_NAME - CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/1 diff --git a/deployment/k8s/docsgpt-secrets.yaml b/deployment/k8s/docsgpt-secrets.yaml index 45a00973..2d494e8e 100644 --- a/deployment/k8s/docsgpt-secrets.yaml +++ b/deployment/k8s/docsgpt-secrets.yaml @@ -4,7 +4,7 @@ metadata: name: docsgpt-secrets type: Opaque data: - LLM_NAME: ZG9jc2dwdA== + LLM_PROVIDER: ZG9jc2dwdA== INTERNAL_KEY: aW50ZXJuYWw= CELERY_BROKER_URL: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA== CELERY_RESULT_BACKEND: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA== diff --git a/docs/pages/Deploying/Docker-Deploying.mdx b/docs/pages/Deploying/Docker-Deploying.mdx index 7bb70729..382db5fb 100644 --- a/docs/pages/Deploying/Docker-Deploying.mdx +++ b/docs/pages/Deploying/Docker-Deploying.mdx @@ -37,7 +37,7 @@ The fastest way to try out DocsGPT is by using the public API endpoint. This req Open the `.env` file and add the following lines: ``` - LLM_NAME=docsgpt + LLM_PROVIDER=docsgpt VITE_API_STREAMING=true ``` @@ -93,16 +93,16 @@ There are two Ollama optional files: 3. **Pull the Ollama Model:** - **Crucially, after launching with Ollama, you need to pull the desired model into the Ollama container.** Find the `MODEL_NAME` you configured in your `.env` file (e.g., `llama3.2:1b`). Then execute the following command to pull the model *inside* the running Ollama container: + **Crucially, after launching with Ollama, you need to pull the desired model into the Ollama container.** Find the `LLM_NAME` you configured in your `.env` file (e.g., `llama3.2:1b`). Then execute the following command to pull the model *inside* the running Ollama container: ```bash - docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-cpu.yaml exec -it ollama ollama pull + docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-cpu.yaml exec -it ollama ollama pull ``` or (for GPU): ```bash - docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-gpu.yaml exec -it ollama ollama pull + docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-gpu.yaml exec -it ollama ollama pull ``` - Replace `` with the actual model name from your `.env` file. + Replace `` with the actual model name from your `.env` file. 4. **Access DocsGPT in your browser:** diff --git a/docs/pages/Deploying/DocsGPT-Settings.mdx b/docs/pages/Deploying/DocsGPT-Settings.mdx index 239b35d7..92537934 100644 --- a/docs/pages/Deploying/DocsGPT-Settings.mdx +++ b/docs/pages/Deploying/DocsGPT-Settings.mdx @@ -20,9 +20,9 @@ The easiest and recommended way to configure basic settings is by using a `.env` **Example `.env` file structure:** ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_OPENAI_API_KEY -MODEL_NAME=gpt-4o +LLM_NAME=gpt-4o ``` ### 2. Configuration via `settings.py` file (Advanced) @@ -37,7 +37,7 @@ While modifying `settings.py` offers more flexibility, it's generally recommende Here are some of the most fundamental settings you'll likely want to configure: -- **`LLM_NAME`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with. +- **`LLM_PROVIDER`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with. - **Common values:** - `docsgpt`: Use the DocsGPT Public API Endpoint (simple and free, as offered in `setup.sh` option 1). @@ -49,11 +49,11 @@ Here are some of the most fundamental settings you'll likely want to configure: - `azure_openai`: Use Azure OpenAI Service. - `openai` (when using local inference engines like Ollama, Llama.cpp, TGI, etc.): This signals DocsGPT to use an OpenAI-compatible API format, even if the actual LLM is running locally. -- **`MODEL_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_NAME` you've selected. +- **`LLM_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_PROVIDER` you've selected. - **Examples:** - - For `LLM_NAME=openai`: `gpt-4o` - - For `LLM_NAME=google`: `gemini-2.0-flash` + - For `LLM_PROVIDER=openai`: `gpt-4o` + - For `LLM_PROVIDER=google`: `gemini-2.0-flash` - For local models (e.g., Ollama): `llama3.2:1b` (or any model name available in your setup). - **`EMBEDDINGS_NAME`**: This setting defines which embedding model DocsGPT will use to generate vector embeddings for your documents. Embeddings are numerical representations of text that allow DocsGPT to understand the semantic meaning of your documents for efficient search and retrieval. @@ -63,7 +63,7 @@ Here are some of the most fundamental settings you'll likely want to configure: - **`API_KEY`**: Required for most cloud-based LLM providers. This is your authentication key to access the LLM provider's API. You'll need to obtain this key from your chosen provider's platform. -- **`OPENAI_BASE_URL`**: Specifically used when `LLM_NAME` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server. +- **`OPENAI_BASE_URL`**: Specifically used when `LLM_PROVIDER` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server. ## Configuration Examples @@ -74,9 +74,9 @@ Let's look at some concrete examples of how to configure these settings in your To use OpenAI's `gpt-4o` model, you would configure your `.env` file like this: ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_OPENAI_API_KEY # Replace with your actual OpenAI API key -MODEL_NAME=gpt-4o +LLM_NAME=gpt-4o ``` Make sure to replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key. @@ -86,14 +86,14 @@ Make sure to replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key. To use a local Ollama server with the `llama3.2:1b` model, you would configure your `.env` file like this: ``` -LLM_NAME=openai # Using OpenAI compatible API format for local models +LLM_PROVIDER=openai # Using OpenAI compatible API format for local models API_KEY=None # API Key is not needed for local Ollama -MODEL_NAME=llama3.2:1b +LLM_NAME=llama3.2:1b OPENAI_BASE_URL=http://host.docker.internal:11434/v1 # Default Ollama API URL within Docker EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can also run embeddings locally if needed ``` -In this case, even though you are using Ollama locally, `LLM_NAME` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server. +In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server. ## Authentication Settings diff --git a/docs/pages/Guides/How-to-use-different-LLM.mdx b/docs/pages/Guides/How-to-use-different-LLM.mdx index 3bc8477d..217bc690 100644 --- a/docs/pages/Guides/How-to-use-different-LLM.mdx +++ b/docs/pages/Guides/How-to-use-different-LLM.mdx @@ -32,9 +32,9 @@ Choose the LLM of your choice. ### For Open source llm change: ### Step 1 -For open source version please edit `LLM_NAME`, `MODEL_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information. +For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information. ### Step 2 -Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_NAME. +Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER. For self-hosted please visit [🖥️ Local Inference](/Models/local-inference). diff --git a/docs/pages/Models/cloud-providers.mdx b/docs/pages/Models/cloud-providers.mdx index 86f2d132..36e737fa 100644 --- a/docs/pages/Models/cloud-providers.mdx +++ b/docs/pages/Models/cloud-providers.mdx @@ -13,15 +13,15 @@ The primary method for configuring your LLM provider in DocsGPT is through the ` To connect to a cloud LLM provider, you will typically need to configure the following basic settings in your `.env` file: -* **`LLM_NAME`**: This setting is essential and identifies the specific cloud provider you wish to use (e.g., `openai`, `google`, `anthropic`). -* **`MODEL_NAME`**: Specifies the exact model you want to utilize from your chosen provider (e.g., `gpt-4o`, `gemini-2.0-flash`, `claude-3-5-sonnet-latest`). Refer to your provider's documentation for a list of available models. +* **`LLM_PROVIDER`**: This setting is essential and identifies the specific cloud provider you wish to use (e.g., `openai`, `google`, `anthropic`). +* **`LLM_NAME`**: Specifies the exact model you want to utilize from your chosen provider (e.g., `gpt-4o`, `gemini-2.0-flash`, `claude-3-5-sonnet-latest`). Refer to your provider's documentation for a list of available models. * **`API_KEY`**: Almost all cloud LLM providers require an API key for authentication. Obtain your API key from your chosen provider's platform and securely store it in your `.env` file. ## Explicitly Supported Cloud Providers -DocsGPT offers direct, streamlined support for the following cloud LLM providers, making configuration straightforward. The table below outlines the `LLM_NAME` and example `MODEL_NAME` values to use for each provider in your `.env` file. +DocsGPT offers direct, streamlined support for the following cloud LLM providers, making configuration straightforward. The table below outlines the `LLM_PROVIDER` and example `LLM_NAME` values to use for each provider in your `.env` file. -| Provider | `LLM_NAME` | Example `MODEL_NAME` | +| Provider | `LLM_PROVIDER` | Example `LLM_NAME` | | :--------------------------- | :------------- | :-------------------------- | | DocsGPT Public API | `docsgpt` | `None` | | OpenAI | `openai` | `gpt-4o` | @@ -35,16 +35,16 @@ DocsGPT offers direct, streamlined support for the following cloud LLM providers DocsGPT's flexible architecture allows you to connect to any cloud provider that offers an API compatible with the OpenAI API standard. This opens up a vast ecosystem of LLM services. -To connect to an OpenAI-compatible cloud provider, you will still use `LLM_NAME=openai` in your `.env` file. However, you will also need to specify the API endpoint of your chosen provider using the `OPENAI_BASE_URL` setting. You will also likely need to provide an `API_KEY` and `MODEL_NAME` as required by that provider. +To connect to an OpenAI-compatible cloud provider, you will still use `LLM_PROVIDER=openai` in your `.env` file. However, you will also need to specify the API endpoint of your chosen provider using the `OPENAI_BASE_URL` setting. You will also likely need to provide an `API_KEY` and `LLM_NAME` as required by that provider. **Example for DeepSeek (OpenAI-Compatible API):** To connect to DeepSeek, which offers an OpenAI-compatible API, your `.env` file could be configured as follows: ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_API_KEY # Your DeepSeek API key -MODEL_NAME=deepseek-chat # Or your desired DeepSeek model name +LLM_NAME=deepseek-chat # Or your desired DeepSeek model name OPENAI_BASE_URL=https://api.deepseek.com/v1 # DeepSeek's OpenAI API URL ``` diff --git a/docs/pages/Models/embeddings.md b/docs/pages/Models/embeddings.md index 6dfb89b6..68015db7 100644 --- a/docs/pages/Models/embeddings.md +++ b/docs/pages/Models/embeddings.md @@ -60,7 +60,7 @@ To use OpenAI's `text-embedding-ada-002` embedding model, you need to set `EMBED **Example `.env` configuration for OpenAI Embeddings:** ``` -LLM_NAME=openai +LLM_PROVIDER=openai API_KEY=YOUR_OPENAI_API_KEY # Your OpenAI API Key EMBEDDINGS_NAME=openai_text-embedding-ada-002 ``` diff --git a/docs/pages/Models/local-inference.mdx b/docs/pages/Models/local-inference.mdx index 4aa6bca2..0bba907b 100644 --- a/docs/pages/Models/local-inference.mdx +++ b/docs/pages/Models/local-inference.mdx @@ -15,8 +15,8 @@ Setting up a local inference engine with DocsGPT is configured through environme To connect to a local inference engine, you will generally need to configure these settings in your `.env` file: -* **`LLM_NAME`**: Crucially set this to `openai`. This tells DocsGPT to use the OpenAI-compatible API format for communication, even though the LLM is local. -* **`MODEL_NAME`**: Specify the model name as recognized by your local inference engine. This might be a model identifier or left as `None` if the engine doesn't require explicit model naming in the API request. +* **`LLM_PROVIDER`**: Crucially set this to `openai`. This tells DocsGPT to use the OpenAI-compatible API format for communication, even though the LLM is local. +* **`LLM_NAME`**: Specify the model name as recognized by your local inference engine. This might be a model identifier or left as `None` if the engine doesn't require explicit model naming in the API request. * **`OPENAI_BASE_URL`**: This is essential. Set this to the base URL of your local inference engine's API endpoint. This tells DocsGPT where to find your local LLM server. * **`API_KEY`**: Generally, for local inference engines, you can set `API_KEY=None` as authentication is usually not required in local setups. @@ -24,16 +24,16 @@ To connect to a local inference engine, you will generally need to configure the DocsGPT is readily configurable to work with the following local inference engines, all communicating via the OpenAI API format. Here are example `OPENAI_BASE_URL` values for each, based on default setups: -| Inference Engine | `LLM_NAME` | `OPENAI_BASE_URL` | -| :---------------------------- | :--------- | :------------------------- | -| LLaMa.cpp | `openai` | `http://localhost:8000/v1` | -| Ollama | `openai` | `http://localhost:11434/v1` | -| Text Generation Inference (TGI)| `openai` | `http://localhost:8080/v1` | -| SGLang | `openai` | `http://localhost:30000/v1` | -| vLLM | `openai` | `http://localhost:8000/v1` | -| Aphrodite | `openai` | `http://localhost:2242/v1` | -| FriendliAI | `openai` | `http://localhost:8997/v1` | -| LMDeploy | `openai` | `http://localhost:23333/v1` | +| Inference Engine | `LLM_PROVIDER` | `OPENAI_BASE_URL` | +| :---------------------------- | :------------- | :------------------------- | +| LLaMa.cpp | `openai` | `http://localhost:8000/v1` | +| Ollama | `openai` | `http://localhost:11434/v1` | +| Text Generation Inference (TGI)| `openai` | `http://localhost:8080/v1` | +| SGLang | `openai` | `http://localhost:30000/v1` | +| vLLM | `openai` | `http://localhost:8000/v1` | +| Aphrodite | `openai` | `http://localhost:2242/v1` | +| FriendliAI | `openai` | `http://localhost:8997/v1` | +| LMDeploy | `openai` | `http://localhost:23333/v1` | **Important Note on `localhost` vs `host.docker.internal`:** diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index 57c5de45..300253b8 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -48,6 +48,7 @@ import { setConversations, setModalStateDeleteConv, setSelectedAgent, + setSharedAgents, } from './preferences/preferenceSlice'; import Upload from './upload/Upload'; @@ -169,73 +170,65 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const handleTogglePin = (agent: Agent) => { userService.togglePinAgent(agent.id ?? '', token).then((response) => { if (response.ok) { - const updatedAgents = agents?.map((a) => - a.id === agent.id ? { ...a, pinned: !a.pinned } : a, - ); - dispatch(setAgents(updatedAgents)); + const updatePinnedStatus = (a: Agent) => + a.id === agent.id ? { ...a, pinned: !a.pinned } : a; + dispatch(setAgents(agents?.map(updatePinnedStatus))); + dispatch(setSharedAgents(sharedAgents?.map(updatePinnedStatus))); } }); }; - const handleConversationClick = (index: string) => { - dispatch(setSelectedAgent(null)); - conversationService - .getConversation(index, token) - .then((response) => { - if (!response.ok) { - navigate('/'); - dispatch(setSelectedAgent(null)); - return null; - } - return response.json(); - }) - .then((data) => { - if (!data) return; - dispatch(setConversation(data.queries)); - dispatch( - updateConversationId({ - query: { conversationId: index }, - }), + const handleConversationClick = async (index: string) => { + try { + dispatch(setSelectedAgent(null)); + + const response = await conversationService.getConversation(index, token); + if (!response.ok) { + navigate('/'); + return; + } + + const data = await response.json(); + if (!data) return; + + dispatch(setConversation(data.queries)); + dispatch(updateConversationId({ query: { conversationId: index } })); + + if (!data.agent_id) { + navigate('/'); + return; + } + + let agent: Agent; + if (data.is_shared_usage) { + const sharedResponse = await userService.getSharedAgent( + data.shared_token, + token, ); - if (isMobile || isTablet) { - setNavOpen(false); - } - if (data.agent_id) { - if (data.is_shared_usage) { - userService - .getSharedAgent(data.shared_token, token) - .then((response) => { - if (!response.ok) { - navigate('/'); - dispatch(setSelectedAgent(null)); - return; - } - response.json().then((agent: Agent) => { - navigate(`/agents/shared/${agent.shared_token}`); - }); - }); - } else { - userService.getAgent(data.agent_id, token).then((response) => { - if (!response.ok) { - navigate('/'); - dispatch(setSelectedAgent(null)); - return; - } - response.json().then((agent: Agent) => { - if (agent.shared_token) - navigate(`/agents/shared/${agent.shared_token}`); - else { - dispatch(setSelectedAgent(agent)); - navigate('/'); - } - }); - }); - } - } else { + if (!sharedResponse.ok) { navigate('/'); - dispatch(setSelectedAgent(null)); + return; } - }); + agent = await sharedResponse.json(); + navigate(`/agents/shared/${agent.shared_token}`); + } else { + const agentResponse = await userService.getAgent(data.agent_id, token); + if (!agentResponse.ok) { + navigate('/'); + return; + } + agent = await agentResponse.json(); + if (agent.shared_token) { + navigate(`/agents/shared/${agent.shared_token}`); + } else { + await Promise.resolve(dispatch(setSelectedAgent(agent))); + navigate('/'); + } + } + } catch (error) { + console.error('Error handling conversation click:', error); + navigate('/'); + } }; const resetConversation = () => { diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 102a8363..1bf8dafe 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -26,7 +26,9 @@ import UserIcon from '../assets/user.svg'; import Accordion from '../components/Accordion'; import Avatar from '../components/Avatar'; import CopyButton from '../components/CopyButton'; +import MermaidRenderer from '../components/MermaidRenderer'; import Sidebar from '../components/Sidebar'; +import Spinner from '../components/Spinner'; import SpeakButton from '../components/TextToSpeechButton'; import { useDarkTheme, useOutsideAlerter } from '../hooks'; import { @@ -36,7 +38,6 @@ import { import classes from './ConversationBubble.module.css'; import { FEEDBACK, MESSAGE_TYPE } from './conversationModels'; import { ToolCallsType } from './types'; -import MermaidRenderer from '../components/MermaidRenderer'; const DisableSourceFE = import.meta.env.VITE_DISABLE_SOURCE_FE || false; @@ -768,7 +769,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { {isToolCallsOpen && ( -
+
{toolCalls.map((toolCall, index) => (

-

- - {JSON.stringify(toolCall.result, null, 2)} + {toolCall.status === 'pending' && ( + + -

+ )} + {toolCall.status === 'completed' && ( +

+ + {JSON.stringify(toolCall.result, null, 2)} + +

+ )}
diff --git a/frontend/src/conversation/ConversationMessages.tsx b/frontend/src/conversation/ConversationMessages.tsx index 8b31dc9c..4bc2bb08 100644 --- a/frontend/src/conversation/ConversationMessages.tsx +++ b/frontend/src/conversation/ConversationMessages.tsx @@ -131,7 +131,7 @@ export default function ConversationMessages({ ? LAST_BUBBLE_MARGIN : DEFAULT_BUBBLE_MARGIN; - if (query.thought || query.response) { + if (query.thought || query.response || query.tool_calls) { const isCurrentlyStreaming = status === 'loading' && index === queries.length - 1; return ( diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 036962c4..f1d6e615 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -4,14 +4,21 @@ import { getConversations } from '../preferences/preferenceApi'; import { setConversations } from '../preferences/preferenceSlice'; import store from '../store'; import { - selectCompletedAttachments, clearAttachments, + selectCompletedAttachments, } from '../upload/uploadSlice'; import { handleFetchAnswer, handleFetchAnswerSteaming, } from './conversationHandlers'; -import { Answer, Query, Status, ConversationState } from './conversationModels'; +import { + Answer, + Attachment, + ConversationState, + Query, + Status, +} from './conversationModels'; +import { ToolCallsType } from './types'; const initialState: ConversationState = { queries: [], @@ -112,11 +119,11 @@ export const fetchAnswer = createAsyncThunk< query: { sources: data.source ?? [] }, }), ); - } else if (data.type === 'tool_calls') { + } else if (data.type === 'tool_call') { dispatch( - updateToolCalls({ + updateToolCall({ index: targetIndex, - query: { tool_calls: data.tool_calls }, + tool_call: data.data as ToolCallsType, }), ); } else if (data.type === 'error') { @@ -282,12 +289,24 @@ export const conversationSlice = createSlice({ state.queries[index].sources!.push(query.sources![0]); } }, - updateToolCalls( - state, - action: PayloadAction<{ index: number; query: Partial }>, - ) { - const { index, query } = action.payload; - state.queries[index].tool_calls = query?.tool_calls ?? []; + updateToolCall(state, action) { + const { index, tool_call } = action.payload; + + if (!state.queries[index].tool_calls) { + state.queries[index].tool_calls = []; + } + + const existingIndex = state.queries[index].tool_calls.findIndex( + (call) => call.call_id === tool_call.call_id, + ); + + if (existingIndex !== -1) { + const existingCall = state.queries[index].tool_calls[existingIndex]; + state.queries[index].tool_calls[existingIndex] = { + ...existingCall, + ...tool_call, + }; + } else state.queries[index].tool_calls.push(tool_call); }, updateQuery( state, @@ -347,7 +366,7 @@ export const { updateConversationId, updateThought, updateStreamingSource, - updateToolCalls, + updateToolCall, setConversation, setStatus, raiseError, diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index 9b5f2365..4ccb04a1 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -3,5 +3,6 @@ export type ToolCallsType = { action_name: string; call_id: string; arguments: Record; - result: Record; + result?: Record; + status?: 'pending' | 'completed'; }; diff --git a/setup.sh b/setup.sh index 5cf013fc..b072d546 100755 --- a/setup.sh +++ b/setup.sh @@ -169,7 +169,7 @@ prompt_ollama_options() { # 1) Use DocsGPT Public API Endpoint (simple and free) use_docs_public_api_endpoint() { echo -e "\n${NC}Setting up DocsGPT Public API Endpoint...${NC}" - echo "LLM_NAME=docsgpt" > .env + echo "LLM_PROVIDER=docsgpt" > .env echo "VITE_API_STREAMING=true" >> .env echo -e "${GREEN}.env file configured for DocsGPT Public API.${NC}" @@ -237,13 +237,12 @@ serve_local_ollama() { echo -e "\n${NC}Configuring for Ollama ($(echo "$docker_compose_file_suffix" | tr '[:lower:]' '[:upper:]'))...${NC}" # Using tr for uppercase - more compatible echo "API_KEY=xxxx" > .env # Placeholder API Key - echo "LLM_NAME=openai" >> .env - echo "MODEL_NAME=$model_name" >> .env + echo "LLM_PROVIDER=openai" >> .env + echo "LLM_NAME=$model_name" >> .env echo "VITE_API_STREAMING=true" >> .env echo "OPENAI_BASE_URL=http://ollama:11434/v1" >> .env echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> .env echo -e "${GREEN}.env file configured for Ollama ($(echo "$docker_compose_file_suffix" | tr '[:lower:]' '[:upper:]')${NC}${GREEN}).${NC}" - echo -e "${YELLOW}Note: MODEL_NAME is set to '${BOLD}$model_name${NC}${YELLOW}'. You can change it later in the .env file.${NC}" check_and_start_docker @@ -350,8 +349,8 @@ connect_local_inference_engine() { echo -e "\n${NC}Configuring for Local Inference Engine: ${BOLD}${engine_name}...${NC}" echo "API_KEY=None" > .env - echo "LLM_NAME=openai" >> .env - echo "MODEL_NAME=$model_name" >> .env + echo "LLM_PROVIDER=openai" >> .env + echo "LLM_NAME=$model_name" >> .env echo "VITE_API_STREAMING=true" >> .env echo "OPENAI_BASE_URL=$openai_base_url" >> .env echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> .env @@ -381,7 +380,7 @@ connect_local_inference_engine() { # 4) Connect Cloud API Provider connect_cloud_api_provider() { - local provider_choice api_key llm_name + local provider_choice api_key llm_provider local setup_result # Variable to store the return status get_api_key() { @@ -395,43 +394,43 @@ connect_cloud_api_provider() { case "$provider_choice" in 1) # OpenAI provider_name="OpenAI" - llm_name="openai" + llm_provider="openai" model_name="gpt-4o" get_api_key break ;; 2) # Google provider_name="Google (Vertex AI, Gemini)" - llm_name="google" + llm_provider="google" model_name="gemini-2.0-flash" get_api_key break ;; 3) # Anthropic provider_name="Anthropic (Claude)" - llm_name="anthropic" + llm_provider="anthropic" model_name="claude-3-5-sonnet-latest" get_api_key break ;; 4) # Groq provider_name="Groq" - llm_name="groq" + llm_provider="groq" model_name="llama-3.1-8b-instant" get_api_key break ;; 5) # HuggingFace Inference API provider_name="HuggingFace Inference API" - llm_name="huggingface" + llm_provider="huggingface" model_name="meta-llama/Llama-3.1-8B-Instruct" get_api_key break ;; 6) # Azure OpenAI provider_name="Azure OpenAI" - llm_name="azure_openai" + llm_provider="azure_openai" model_name="gpt-4o" get_api_key break ;; 7) # Novita provider_name="Novita" - llm_name="novita" + llm_provider="novita" model_name="deepseek/deepseek-r1" get_api_key break ;; @@ -442,8 +441,8 @@ connect_cloud_api_provider() { echo -e "\n${NC}Configuring for Cloud API Provider: ${BOLD}${provider_name}...${NC}" echo "API_KEY=$api_key" > .env - echo "LLM_NAME=$llm_name" >> .env - echo "MODEL_NAME=$model_name" >> .env + echo "LLM_PROVIDER=$llm_provider" >> .env + echo "LLM_NAME=$model_name" >> .env echo "VITE_API_STREAMING=true" >> .env echo -e "${GREEN}.env file configured for ${BOLD}${provider_name}${NC}${GREEN}.${NC}"