diff --git a/application/worker.py b/application/worker.py index 3f957527..5c5dc367 100755 --- a/application/worker.py +++ b/application/worker.py @@ -146,6 +146,14 @@ def upload_index(full_path, file_data): def run_agent_logic(agent_config, input_data): try: + from application.core.model_utils import ( + get_api_key_for_provider, + get_default_model_id, + get_provider_from_model_id, + validate_model_id, + ) + from application.utils import calculate_doc_token_budget + source = agent_config.get("source") retriever = agent_config.get("retriever", "classic") if isinstance(source, DBRef): @@ -160,31 +168,62 @@ def run_agent_logic(agent_config, input_data): user_api_key = agent_config["key"] agent_type = agent_config.get("agent_type", "classic") decoded_token = {"sub": agent_config.get("user")} + json_schema = agent_config.get("json_schema") prompt = get_prompt(prompt_id, db["prompts"]) - agent = AgentCreator.create_agent( - agent_type, - endpoint="webhook", - llm_name=settings.LLM_PROVIDER, - model_id=settings.LLM_NAME, - api_key=settings.API_KEY, - user_api_key=user_api_key, - prompt=prompt, - chat_history=[], - decoded_token=decoded_token, - attachments=[], + + # Determine model_id: check agent's default_model_id, fallback to system default + agent_default_model = agent_config.get("default_model_id", "") + if agent_default_model and validate_model_id(agent_default_model): + model_id = agent_default_model + else: + model_id = get_default_model_id() + + # Get provider and API key for the selected model + provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER + system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER) + + # Calculate proper doc_token_limit based on model's context window + history_token_limit = 2000 # Default for webhooks + doc_token_limit = calculate_doc_token_budget( + model_id=model_id, history_token_limit=history_token_limit ) + retriever = RetrieverCreator.create_retriever( retriever, source=source, chat_history=[], prompt=prompt, chunks=chunks, - token_limit=settings.DEFAULT_MAX_HISTORY, - model_id=settings.LLM_NAME, + doc_token_limit=doc_token_limit, + model_id=model_id, user_api_key=user_api_key, decoded_token=decoded_token, ) - answer = agent.gen(query=input_data, retriever=retriever) + + # Pre-fetch documents using the retriever + retrieved_docs = [] + try: + docs = retriever.search(input_data) + if docs: + retrieved_docs = docs + except Exception as e: + logging.warning(f"Failed to retrieve documents: {e}") + + agent = AgentCreator.create_agent( + agent_type, + endpoint="webhook", + llm_name=provider or settings.LLM_PROVIDER, + model_id=model_id, + api_key=system_api_key, + user_api_key=user_api_key, + prompt=prompt, + chat_history=[], + retrieved_docs=retrieved_docs, + decoded_token=decoded_token, + attachments=[], + json_schema=json_schema, + ) + answer = agent.gen(query=input_data) response_full = "" thought = "" source_log_docs = []