refactor: update env variable names

Author: Siddhant Rai
Date: 2025-06-06 15:29:53 +05:30
parent 143f4aa886
commit e9530d5ec5
9 changed files with 121 additions and 86 deletions

View File

@@ -1,8 +1,6 @@
 from typing import Dict, Generator
 from application.agents.base import BaseAgent
 from application.logging import LogContext
 from application.retriever.base import BaseRetriever
 import logging
@@ -10,55 +8,90 @@ logger = logging.getLogger(__name__)
 class ClassicAgent(BaseAgent):
+    """A simplified classic agent with clear execution flow.
+    Usage:
+        1. Processes a query through retrieval
+        2. Sets up available tools
+        3. Generates responses using LLM
+        4. Handles tool interactions if needed
+        5. Returns standardized outputs
+    Easy to extend by overriding specific steps.
+    """
     def _gen_inner(
         self, query: str, retriever: BaseRetriever, log_context: LogContext
     ) -> Generator[Dict, None, None]:
+        """Main execution flow for the agent."""
+        # Step 1: Retrieve relevant data
         retrieved_data = self._retriever_search(retriever, query, log_context)
-        if self.user_api_key:
-            tools_dict = self._get_tools(self.user_api_key)
-        else:
-            tools_dict = self._get_user_tools(self.user)
+        # Step 2: Prepare tools
+        tools_dict = (
+            self._get_user_tools(self.user)
+            if not self.user_api_key
+            else self._get_tools(self.user_api_key)
+        )
         self._prepare_tools(tools_dict)
+        # Step 3: Build and process messages
         messages = self._build_messages(self.prompt, query, retrieved_data)
-        resp = self._llm_gen(messages, log_context)
-        attachments = self.attachments
-        if isinstance(resp, str):
-            yield {"answer": resp}
-            return
-        if (
-            hasattr(resp, "message")
-            and hasattr(resp.message, "content")
-            and resp.message.content is not None
-        ):
-            yield {"answer": resp.message.content}
-            return
-        resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
-        if isinstance(resp, str):
-            yield {"answer": resp}
-        elif (
-            hasattr(resp, "message")
-            and hasattr(resp.message, "content")
-            and resp.message.content is not None
-        ):
-            yield {"answer": resp.message.content}
-        else:
-            for line in resp:
-                if isinstance(line, str):
-                    yield {"answer": line}
+        llm_response = self._llm_gen(messages, log_context)
+        # Step 4: Handle the response
+        yield from self._handle_response(
+            llm_response, tools_dict, messages, log_context
+        )
+        # Step 5: Return metadata
+        yield {"sources": retrieved_data}
+        yield {"tool_calls": self._get_truncated_tool_calls()}
+        # Log tool calls for debugging
         log_context.stacks.append(
             {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
         )
-        yield {"sources": retrieved_data}
-        # clean tool_call_data only send first 50 characters of tool_call['result']
-        for tool_call in self.tool_calls:
-            if len(str(tool_call["result"])) > 50:
-                tool_call["result"] = str(tool_call["result"])[:50] + "..."
-        yield {"tool_calls": self.tool_calls.copy()}
+    def _handle_response(self, response, tools_dict, messages, log_context):
+        """Handle different types of LLM responses consistently."""
+        # Handle simple string responses
+        if isinstance(response, str):
+            yield {"answer": response}
+            return
+        # Handle content from message objects
+        if hasattr(response, "message") and getattr(response.message, "content", None):
+            yield {"answer": response.message.content}
+            return
+        # Handle complex responses that may require tool use
+        processed_response = self._llm_handler(
+            response, tools_dict, messages, log_context, self.attachments
+        )
+        # Yield the final processed response
+        if isinstance(processed_response, str):
+            yield {"answer": processed_response}
+        elif hasattr(processed_response, "message") and getattr(
+            processed_response.message, "content", None
+        ):
+            yield {"answer": processed_response.message.content}
+        else:
+            for line in processed_response:
+                if isinstance(line, str):
+                    yield {"answer": line}
+    def _get_truncated_tool_calls(self):
+        """Return tool calls with truncated results for cleaner output."""
+        return [
+            {
+                **tool_call,
+                "result": (
+                    f"{str(tool_call['result'])[:50]}..."
+                    if len(str(tool_call["result"])) > 50
+                    else tool_call["result"]
+                ),
+            }
+            for tool_call in self.tool_calls
+        ]
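The new _get_truncated_tool_calls helper is a pure transformation over self.tool_calls. A minimal standalone sketch of the same truncation pattern (the function name and inputs here are illustrative, not part of this commit):

def truncate_tool_calls(tool_calls, limit=50):
    # Copy each tool call, capping str(result) at `limit` characters.
    return [
        {
            **call,
            "result": (
                f"{str(call['result'])[:limit]}..."
                if len(str(call["result"])) > limit
                else call["result"]
            ),
        }
        for call in tool_calls
    ]

# truncate_tool_calls([{"name": "search", "result": "x" * 200}])
# -> result becomes the first 50 characters followed by "..."

Unlike the removed in-place loop, this builds new dicts, so the full results recorded via log_context (which takes only a shallow copy of the list) are not clobbered.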

View File

@@ -37,17 +37,17 @@ api.add_namespace(answer_ns)
 gpt_model = ""
 # to have some kind of default behaviour
-if settings.LLM_NAME == "openai":
+if settings.LLM_PROVIDER == "openai":
     gpt_model = "gpt-4o-mini"
-elif settings.LLM_NAME == "anthropic":
+elif settings.LLM_PROVIDER == "anthropic":
     gpt_model = "claude-2"
-elif settings.LLM_NAME == "groq":
+elif settings.LLM_PROVIDER == "groq":
     gpt_model = "llama3-8b-8192"
-elif settings.LLM_NAME == "novita":
+elif settings.LLM_PROVIDER == "novita":
     gpt_model = "deepseek/deepseek-r1"
-if settings.MODEL_NAME:  # in case a particular model name is configured
-    gpt_model = settings.MODEL_NAME
+if settings.LLM_NAME:  # in case a particular model name is configured
+    gpt_model = settings.LLM_NAME
 # load the prompts
 current_dir = os.path.dirname(
@@ -322,7 +322,7 @@ def complete_stream(
             doc["source"] = "None"

     llm = LLMCreator.create_llm(
-        settings.LLM_NAME,
+        settings.LLM_PROVIDER,
         api_key=settings.API_KEY,
         user_api_key=user_api_key,
         decoded_token=decoded_token,
@@ -453,9 +453,7 @@ class Stream(Resource):
         agent_type = settings.AGENT_NAME
         decoded_token = getattr(request, "decoded_token", None)
         user_sub = decoded_token.get("sub") if decoded_token else None
-        agent_key, is_shared_usage, shared_token = get_agent_key(
-            agent_id, user_sub
-        )
+        agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub)

         if agent_key:
             data.update({"api_key": agent_key})
@@ -506,7 +504,7 @@
             agent = AgentCreator.create_agent(
                 agent_type,
                 endpoint="stream",
-                llm_name=settings.LLM_NAME,
+                llm_name=settings.LLM_PROVIDER,
                 gpt_model=gpt_model,
                 api_key=settings.API_KEY,
                 user_api_key=user_api_key,
@@ -659,7 +657,7 @@ class Answer(Resource):
         agent = AgentCreator.create_agent(
             agent_type,
             endpoint="api/answer",
-            llm_name=settings.LLM_NAME,
+            llm_name=settings.LLM_PROVIDER,
             gpt_model=gpt_model,
             api_key=settings.API_KEY,
             user_api_key=user_api_key,
@@ -728,7 +726,7 @@ class Answer(Resource):
             doc["source"] = "None"

         llm = LLMCreator.create_llm(
-            settings.LLM_NAME,
+            settings.LLM_PROVIDER,
             api_key=settings.API_KEY,
             user_api_key=user_api_key,
             decoded_token=decoded_token,
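The provider-to-default-model chain above amounts to a table lookup with an override. A sketch of the equivalent selection logic (DEFAULTS and pick_model are illustrative names, not part of the codebase):

DEFAULTS = {
    "openai": "gpt-4o-mini",
    "anthropic": "claude-2",
    "groq": "llama3-8b-8192",
    "novita": "deepseek/deepseek-r1",
}

def pick_model(provider, explicit_name=None):
    # An explicitly configured LLM_NAME wins over the provider default.
    return explicit_name or DEFAULTS.get(provider, "")

# pick_model("openai") -> "gpt-4o-mini"
# pick_model("openai", "gpt-4") -> "gpt-4"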

View File

@@ -11,18 +11,18 @@ current_dir = os.path.dirname(
 class Settings(BaseSettings):
     AUTH_TYPE: Optional[str] = None
-    LLM_NAME: str = "docsgpt"
-    MODEL_NAME: Optional[str] = (
-        None  # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
+    LLM_PROVIDER: str = "docsgpt"
+    LLM_NAME: Optional[str] = (
+        None  # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
     )
     EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MONGO_DB_NAME: str = "docsgpt"
-    MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
+    LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
-    MODEL_TOKEN_LIMITS: dict = {
+    LLM_TOKEN_LIMITS: dict = {
         "gpt-4o-mini": 128000,
         "gpt-3.5-turbo": 4096,
         "claude-2": 1e5,
@@ -101,7 +101,6 @@ class Settings(BaseSettings):
     FLASK_DEBUG_MODE: bool = False
     STORAGE_TYPE: str = "local"  # local or s3
     JWT_SECRET_KEY: str = ""

View File

@@ -2,6 +2,7 @@ from application.llm.base import BaseLLM
 from application.core.settings import settings
 import threading

+
 class LlamaSingleton:
     _instances = {}
     _lock = threading.Lock()  # Add a lock for thread synchronization
@@ -29,7 +30,7 @@ class LlamaCpp(BaseLLM):
         self,
         api_key=None,
         user_api_key=None,
-        llm_name=settings.MODEL_PATH,
+        llm_name=settings.LLM_PATH,
         *args,
         **kwargs,
     ):
@@ -42,14 +43,18 @@ class LlamaCpp(BaseLLM):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
+        result = LlamaSingleton.query_model(
+            self.llama, prompt, max_tokens=150, echo=False
+        )
         return result["choices"][0]["text"].split("### Answer \n")[-1]

     def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
-        result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
+        result = LlamaSingleton.query_model(
+            self.llama, prompt, max_tokens=150, echo=False, stream=stream
+        )
         for item in result:
             for choice in item["choices"]:
                 yield choice["text"]
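For reference, the LlamaSingleton pattern above serializes model construction behind a class-level lock so concurrent requests share one loaded model per path. A minimal self-contained sketch of the idiom (load_model stands in for the real llama.cpp constructor):

import threading

def load_model(path):
    # Stand-in for the real llama.cpp constructor.
    return object()

class ModelSingleton:
    _instances = {}
    _lock = threading.Lock()  # one lock guards the shared cache

    @classmethod
    def get_instance(cls, path):
        # Check-and-create under the lock so each model loads at most once.
        with cls._lock:
            if path not in cls._instances:
                cls._instances[path] = load_model(path)
            return cls._instances[path]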

View File

@@ -29,10 +29,10 @@ class BraveRetSearch(BaseRetriever):
         self.token_limit = (
             token_limit
             if token_limit
-            < settings.MODEL_TOKEN_LIMITS.get(
+            < settings.LLM_TOKEN_LIMITS.get(
                 self.gpt_model, settings.DEFAULT_MAX_HISTORY
             )
-            else settings.MODEL_TOKEN_LIMITS.get(
+            else settings.LLM_TOKEN_LIMITS.get(
                 self.gpt_model, settings.DEFAULT_MAX_HISTORY
             )
         )
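The conditional expression above is a clamp: use the requested limit unless it exceeds the model's configured ceiling. A one-function sketch of the equivalent (illustrative, not a change this commit makes):

def clamp_token_limit(requested, ceiling):
    # Same result as the if/else above: never exceed the model's ceiling.
    return min(requested, ceiling)

# clamp_token_limit(200000, 128000) -> 128000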
@@ -59,7 +59,7 @@ class BraveRetSearch(BaseRetriever):
                     docs.append({"text": snippet, "title": title, "link": link})
             except IndexError:
                 pass

-        if settings.LLM_NAME == "llama.cpp":
+        if settings.LLM_PROVIDER == "llama.cpp":
             docs = [docs[0]]
         return docs
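The single-document guard for llama.cpp presumably keeps prompts inside a small local context window (the commit does not state this). Note that docs[0] raises IndexError when the search returns nothing; a slice is a safer equivalent:

# Hedged alternative to `docs = [docs[0]]`: identical for non-empty results,
# but yields [] instead of raising IndexError when docs is empty.
if settings.LLM_PROVIDER == "llama.cpp":
    docs = docs[:1]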
@@ -84,7 +84,7 @@ class BraveRetSearch(BaseRetriever):
         messages_combine.append({"role": "user", "content": self.question})

         llm = LLMCreator.create_llm(
-            settings.LLM_NAME,
+            settings.LLM_PROVIDER,
             api_key=settings.API_KEY,
             user_api_key=self.user_api_key,
             decoded_token=self.decoded_token,

View File

@@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever):
         token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
-        llm_name=settings.LLM_NAME,
+        llm_name=settings.LLM_PROVIDER,
         api_key=settings.API_KEY,
         decoded_token=None,
     ):
@@ -28,10 +28,10 @@ class ClassicRAG(BaseRetriever):
         self.token_limit = (
             token_limit
             if token_limit
-            < settings.MODEL_TOKEN_LIMITS.get(
+            < settings.LLM_TOKEN_LIMITS.get(
                 self.gpt_model, settings.DEFAULT_MAX_HISTORY
             )
-            else settings.MODEL_TOKEN_LIMITS.get(
+            else settings.LLM_TOKEN_LIMITS.get(
                 self.gpt_model, settings.DEFAULT_MAX_HISTORY
             )
         )

View File

@@ -28,10 +28,10 @@ class DuckDuckSearch(BaseRetriever):
         self.token_limit = (
             token_limit
             if token_limit
-            < settings.MODEL_TOKEN_LIMITS.get(
+            < settings.LLM_TOKEN_LIMITS.get(
                 self.gpt_model, settings.DEFAULT_MAX_HISTORY
             )
-            else settings.MODEL_TOKEN_LIMITS.get(
+            else settings.LLM_TOKEN_LIMITS.get(
                 self.gpt_model, settings.DEFAULT_MAX_HISTORY
             )
         )
@@ -58,7 +58,7 @@ class DuckDuckSearch(BaseRetriever):
             )
         except IndexError:
             pass

-        if settings.LLM_NAME == "llama.cpp":
+        if settings.LLM_PROVIDER == "llama.cpp":
             docs = [docs[0]]
         return docs
@@ -83,7 +83,7 @@ class DuckDuckSearch(BaseRetriever):
         messages_combine.append({"role": "user", "content": self.question})

         llm = LLMCreator.create_llm(
-            settings.LLM_NAME,
+            settings.LLM_PROVIDER,
             api_key=settings.API_KEY,
             user_api_key=self.user_api_key,
             decoded_token=self.decoded_token,

View File

@@ -74,8 +74,8 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
     max_token_limit = (
         max_token_limit
         if max_token_limit
        and max_token_limit
-        < settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
-        else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+        < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
+        else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
     )
     if not history:
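Given the signature in the hunk header, usage is unchanged by the rename; a hedged example call (the shape of `history` is assumed, only the keywords come from the signature above):

# Trim a conversation for a known model; passing max_token_limit=None falls
# back to LLM_TOKEN_LIMITS / DEFAULT_MAX_HISTORY, as the expression above shows.
trimmed = limit_chat_history(history, max_token_limit=1000, gpt_model="gpt-4o-mini")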

View File

@@ -143,8 +143,8 @@ def run_agent_logic(agent_config, input_data):
     agent = AgentCreator.create_agent(
         agent_type,
         endpoint="webhook",
-        llm_name=settings.LLM_NAME,
-        gpt_model=settings.MODEL_NAME,
+        llm_name=settings.LLM_PROVIDER,
+        gpt_model=settings.LLM_NAME,
         api_key=settings.API_KEY,
         user_api_key=user_api_key,
         prompt=prompt,
@@ -159,7 +159,7 @@ def run_agent_logic(agent_config, input_data):
         prompt=prompt,
         chunks=chunks,
         token_limit=settings.DEFAULT_MAX_HISTORY,
-        gpt_model=settings.MODEL_NAME,
+        gpt_model=settings.LLM_NAME,
         user_api_key=user_api_key,
         decoded_token=decoded_token,
     )
@@ -458,7 +458,9 @@ def attachment_worker(self, file_info, user):
             relative_path,
             lambda local_path, **kwargs: SimpleDirectoryReader(
                 input_files=[local_path], exclude_hidden=True, errors="ignore"
-            ).load_data()[0].text
+            )
+            .load_data()[0]
+            .text,
         )

         token_count = num_tokens_from_string(content)
@@ -487,9 +489,7 @@ def attachment_worker(self, file_info, user):
         logger.info(
             f"Stored attachment with ID: {attachment_id}", extra={"user": user}
         )
-        self.update_state(
-            state="PROGRESS", meta={"current": 100, "status": "Complete"}
-        )
+        self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})

         return {
             "filename": filename,