From e9530d5ec5e5fa9fdd5a6ce2a25e886ec9608f61 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 6 Jun 2025 15:29:53 +0530 Subject: [PATCH] refactor: update env variable names --- application/agents/classic_agent.py | 113 +++++++++++++++-------- application/api/answer/routes.py | 24 +++-- application/core/settings.py | 13 ++- application/llm/llama_cpp.py | 13 ++- application/retriever/brave_search.py | 8 +- application/retriever/classic_rag.py | 6 +- application/retriever/duckduck_search.py | 8 +- application/utils.py | 4 +- application/worker.py | 18 ++-- 9 files changed, 121 insertions(+), 86 deletions(-) diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index b371123b..d0576511 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -1,8 +1,6 @@ from typing import Dict, Generator - from application.agents.base import BaseAgent from application.logging import LogContext - from application.retriever.base import BaseRetriever import logging @@ -10,55 +8,90 @@ logger = logging.getLogger(__name__) class ClassicAgent(BaseAgent): + """A simplified classic agent with clear execution flow. + + Usage: + 1. Processes a query through retrieval + 2. Sets up available tools + 3. Generates responses using LLM + 4. Handles tool interactions if needed + 5. Returns standardized outputs + + Easy to extend by overriding specific steps. + """ + def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: + """Main execution flow for the agent.""" + # Step 1: Retrieve relevant data retrieved_data = self._retriever_search(retriever, query, log_context) - if self.user_api_key: - tools_dict = self._get_tools(self.user_api_key) - else: - tools_dict = self._get_user_tools(self.user) + + # Step 2: Prepare tools + tools_dict = ( + self._get_user_tools(self.user) + if not self.user_api_key + else self._get_tools(self.user_api_key) + ) self._prepare_tools(tools_dict) + # Step 3: Build and process messages messages = self._build_messages(self.prompt, query, retrieved_data) + llm_response = self._llm_gen(messages, log_context) - resp = self._llm_gen(messages, log_context) + # Step 4: Handle the response + yield from self._handle_response( + llm_response, tools_dict, messages, log_context + ) - attachments = self.attachments - - if isinstance(resp, str): - yield {"answer": resp} - return - if ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield {"answer": resp.message.content} - return - - resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments) - - if isinstance(resp, str): - yield {"answer": resp} - elif ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield {"answer": resp.message.content} - else: - for line in resp: - if isinstance(line, str): - yield {"answer": line} + # Step 5: Return metadata + yield {"sources": retrieved_data} + yield {"tool_calls": self._get_truncated_tool_calls()} + # Log tool calls for debugging log_context.stacks.append( {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}} ) - yield {"sources": retrieved_data} - # clean tool_call_data only send first 50 characters of tool_call['result'] - for tool_call in self.tool_calls: - if len(str(tool_call["result"])) > 50: - tool_call["result"] = str(tool_call["result"])[:50] + "..." - yield {"tool_calls": self.tool_calls.copy()} + def _handle_response(self, response, tools_dict, messages, log_context): + """Handle different types of LLM responses consistently.""" + # Handle simple string responses + if isinstance(response, str): + yield {"answer": response} + return + + # Handle content from message objects + if hasattr(response, "message") and getattr(response.message, "content", None): + yield {"answer": response.message.content} + return + + # Handle complex responses that may require tool use + processed_response = self._llm_handler( + response, tools_dict, messages, log_context, self.attachments + ) + + # Yield the final processed response + if isinstance(processed_response, str): + yield {"answer": processed_response} + elif hasattr(processed_response, "message") and getattr( + processed_response.message, "content", None + ): + yield {"answer": processed_response.message.content} + else: + for line in processed_response: + if isinstance(line, str): + yield {"answer": line} + + def _get_truncated_tool_calls(self): + """Return tool calls with truncated results for cleaner output.""" + return [ + { + **tool_call, + "result": ( + f"{str(tool_call['result'])[:50]}..." + if len(str(tool_call["result"])) > 50 + else tool_call["result"] + ), + } + for tool_call in self.tool_calls + ] diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 2aa473d4..44ba035b 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -37,17 +37,17 @@ api.add_namespace(answer_ns) gpt_model = "" # to have some kind of default behaviour -if settings.LLM_NAME == "openai": +if settings.LLM_PROVIDER == "openai": gpt_model = "gpt-4o-mini" -elif settings.LLM_NAME == "anthropic": +elif settings.LLM_PROVIDER == "anthropic": gpt_model = "claude-2" -elif settings.LLM_NAME == "groq": +elif settings.LLM_PROVIDER == "groq": gpt_model = "llama3-8b-8192" -elif settings.LLM_NAME == "novita": +elif settings.LLM_PROVIDER == "novita": gpt_model = "deepseek/deepseek-r1" -if settings.MODEL_NAME: # in case there is particular model name configured - gpt_model = settings.MODEL_NAME +if settings.LLM_NAME: # in case there is particular model name configured + gpt_model = settings.LLM_NAME # load the prompts current_dir = os.path.dirname( @@ -322,7 +322,7 @@ def complete_stream( doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, @@ -453,9 +453,7 @@ class Stream(Resource): agent_type = settings.AGENT_NAME decoded_token = getattr(request, "decoded_token", None) user_sub = decoded_token.get("sub") if decoded_token else None - agent_key, is_shared_usage, shared_token = get_agent_key( - agent_id, user_sub - ) + agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub) if agent_key: data.update({"api_key": agent_key}) @@ -506,7 +504,7 @@ class Stream(Resource): agent = AgentCreator.create_agent( agent_type, endpoint="stream", - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, gpt_model=gpt_model, api_key=settings.API_KEY, user_api_key=user_api_key, @@ -659,7 +657,7 @@ class Answer(Resource): agent = AgentCreator.create_agent( agent_type, endpoint="api/answer", - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, gpt_model=gpt_model, api_key=settings.API_KEY, user_api_key=user_api_key, @@ -728,7 +726,7 @@ class Answer(Resource): doc["source"] = "None" llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=user_api_key, decoded_token=decoded_token, diff --git a/application/core/settings.py b/application/core/settings.py index 3be34242..05a09510 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -11,18 +11,18 @@ current_dir = os.path.dirname( class Settings(BaseSettings): AUTH_TYPE: Optional[str] = None - LLM_NAME: str = "docsgpt" - MODEL_NAME: Optional[str] = ( - None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo + LLM_PROVIDER: str = "docsgpt" + LLM_NAME: Optional[str] = ( + None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo ) EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2" CELERY_BROKER_URL: str = "redis://localhost:6379/0" CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" MONGO_URI: str = "mongodb://localhost:27017/docsgpt" MONGO_DB_NAME: str = "docsgpt" - MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") + LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") DEFAULT_MAX_HISTORY: int = 150 - MODEL_TOKEN_LIMITS: dict = { + LLM_TOKEN_LIMITS: dict = { "gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5, @@ -99,8 +99,7 @@ class Settings(BaseSettings): BRAVE_SEARCH_API_KEY: Optional[str] = None FLASK_DEBUG_MODE: bool = False - STORAGE_TYPE: str = "local" # local or s3 - + STORAGE_TYPE: str = "local" # local or s3 JWT_SECRET_KEY: str = "" diff --git a/application/llm/llama_cpp.py b/application/llm/llama_cpp.py index 804c3c56..f0418b49 100644 --- a/application/llm/llama_cpp.py +++ b/application/llm/llama_cpp.py @@ -2,6 +2,7 @@ from application.llm.base import BaseLLM from application.core.settings import settings import threading + class LlamaSingleton: _instances = {} _lock = threading.Lock() # Add a lock for thread synchronization @@ -29,7 +30,7 @@ class LlamaCpp(BaseLLM): self, api_key=None, user_api_key=None, - llm_name=settings.MODEL_PATH, + llm_name=settings.LLM_PATH, *args, **kwargs, ): @@ -42,14 +43,18 @@ class LlamaCpp(BaseLLM): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False) + result = LlamaSingleton.query_model( + self.llama, prompt, max_tokens=150, echo=False + ) return result["choices"][0]["text"].split("### Answer \n")[-1] def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" - result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream) + result = LlamaSingleton.query_model( + self.llama, prompt, max_tokens=150, echo=False, stream=stream + ) for item in result: for choice in item["choices"]: - yield choice["text"] \ No newline at end of file + yield choice["text"] diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py index 33bdb894..123000e4 100644 --- a/application/retriever/brave_search.py +++ b/application/retriever/brave_search.py @@ -29,10 +29,10 @@ class BraveRetSearch(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) @@ -59,7 +59,7 @@ class BraveRetSearch(BaseRetriever): docs.append({"text": snippet, "title": title, "link": link}) except IndexError: pass - if settings.LLM_NAME == "llama.cpp": + if settings.LLM_PROVIDER == "llama.cpp": docs = [docs[0]] return docs @@ -84,7 +84,7 @@ class BraveRetSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=self.user_api_key, decoded_token=self.decoded_token, diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b8ac69e4..2b4e6df7 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever): token_limit=150, gpt_model="docsgpt", user_api_key=None, - llm_name=settings.LLM_NAME, + llm_name=settings.LLM_PROVIDER, api_key=settings.API_KEY, decoded_token=None, ): @@ -28,10 +28,10 @@ class ClassicRAG(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py index 29bbba18..5abe5edd 100644 --- a/application/retriever/duckduck_search.py +++ b/application/retriever/duckduck_search.py @@ -28,10 +28,10 @@ class DuckDuckSearch(BaseRetriever): self.token_limit = ( token_limit if token_limit - < settings.MODEL_TOKEN_LIMITS.get( + < settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) - else settings.MODEL_TOKEN_LIMITS.get( + else settings.LLM_TOKEN_LIMITS.get( self.gpt_model, settings.DEFAULT_MAX_HISTORY ) ) @@ -58,7 +58,7 @@ class DuckDuckSearch(BaseRetriever): ) except IndexError: pass - if settings.LLM_NAME == "llama.cpp": + if settings.LLM_PROVIDER == "llama.cpp": docs = [docs[0]] return docs @@ -83,7 +83,7 @@ class DuckDuckSearch(BaseRetriever): messages_combine.append({"role": "user", "content": self.question}) llm = LLMCreator.create_llm( - settings.LLM_NAME, + settings.LLM_PROVIDER, api_key=settings.API_KEY, user_api_key=self.user_api_key, decoded_token=self.decoded_token, diff --git a/application/utils.py b/application/utils.py index 6d47d31a..7a9cfd2b 100644 --- a/application/utils.py +++ b/application/utils.py @@ -74,8 +74,8 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"): max_token_limit if max_token_limit and max_token_limit - < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) - else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) + else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY) ) if not history: diff --git a/application/worker.py b/application/worker.py index 9829fde9..13c0ca30 100755 --- a/application/worker.py +++ b/application/worker.py @@ -143,8 +143,8 @@ def run_agent_logic(agent_config, input_data): agent = AgentCreator.create_agent( agent_type, endpoint="webhook", - llm_name=settings.LLM_NAME, - gpt_model=settings.MODEL_NAME, + llm_name=settings.LLM_PROVIDER, + gpt_model=settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key, prompt=prompt, @@ -159,7 +159,7 @@ def run_agent_logic(agent_config, input_data): prompt=prompt, chunks=chunks, token_limit=settings.DEFAULT_MAX_HISTORY, - gpt_model=settings.MODEL_NAME, + gpt_model=settings.LLM_NAME, user_api_key=user_api_key, decoded_token=decoded_token, ) @@ -449,7 +449,7 @@ def attachment_worker(self, file_info, user): try: self.update_state(state="PROGRESS", meta={"current": 10}) storage = StorageCreator.get_storage() - + self.update_state( state="PROGRESS", meta={"current": 30, "status": "Processing content"} ) @@ -458,9 +458,11 @@ def attachment_worker(self, file_info, user): relative_path, lambda local_path, **kwargs: SimpleDirectoryReader( input_files=[local_path], exclude_hidden=True, errors="ignore" - ).load_data()[0].text + ) + .load_data()[0] + .text, ) - + token_count = num_tokens_from_string(content) self.update_state( @@ -487,9 +489,7 @@ def attachment_worker(self, file_info, user): f"Stored attachment with ID: {attachment_id}", extra={"user": user} ) - self.update_state( - state="PROGRESS", meta={"current": 100, "status": "Complete"} - ) + self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"}) return { "filename": filename,