Mirror of https://github.com/arc53/DocsGPT.git (synced 2025-11-29 08:33:20 +00:00)
Compare commits
2 Commits
fix/model-
...
dc2faf7a7e
| Author | SHA1 | Date | |
|---|---|---|---|
| | dc2faf7a7e | | |
| | 67e0d222d1 | | |
@@ -103,11 +103,10 @@ class StreamProcessor:
|
|||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Initialize all required components for processing"""
|
"""Initialize all required components for processing"""
|
||||||
self._validate_and_set_model()
|
|
||||||
self._configure_agent()
|
self._configure_agent()
|
||||||
|
self._validate_and_set_model()
|
||||||
self._configure_source()
|
self._configure_source()
|
||||||
self._configure_retriever()
|
self._configure_retriever()
|
||||||
self._configure_agent()
|
|
||||||
self._load_conversation_history()
|
self._load_conversation_history()
|
||||||
self._process_attachments()
|
self._process_attachments()
|
||||||
|
|
||||||
@@ -230,7 +229,12 @@ class StreamProcessor:
|
|||||||
)
|
)
|
||||||
self.model_id = requested_model
|
self.model_id = requested_model
|
||||||
else:
|
else:
|
||||||
self.model_id = get_default_model_id()
|
# Check if agent has a default model configured
|
||||||
|
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||||
|
if agent_default_model and validate_model_id(agent_default_model):
|
||||||
|
self.model_id = agent_default_model
|
||||||
|
else:
|
||||||
|
self.model_id = get_default_model_id()
|
||||||
|
|
||||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||||
"""Get API key for agent with access control"""
|
"""Get API key for agent with access control"""
|
||||||
@@ -303,6 +307,10 @@ class StreamProcessor:
|
|||||||
data["sources"] = sources_list
|
data["sources"] = sources_list
|
||||||
else:
|
else:
|
||||||
data["sources"] = []
|
data["sources"] = []
|
||||||
|
|
||||||
|
# Preserve model configuration from agent
|
||||||
|
data["default_model_id"] = data.get("default_model_id", "")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _configure_source(self):
|
def _configure_source(self):
|
||||||
@@ -355,6 +363,7 @@ class StreamProcessor:
|
|||||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||||
"user_api_key": api_key,
|
"user_api_key": api_key,
|
||||||
"json_schema": data_key.get("json_schema"),
|
"json_schema": data_key.get("json_schema"),
|
||||||
|
"default_model_id": data_key.get("default_model_id", ""),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.initial_user_id = data_key.get("user")
|
self.initial_user_id = data_key.get("user")
|
||||||
@@ -379,6 +388,7 @@ class StreamProcessor:
|
|||||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||||
"user_api_key": self.agent_key,
|
"user_api_key": self.agent_key,
|
||||||
"json_schema": data_key.get("json_schema"),
|
"json_schema": data_key.get("json_schema"),
|
||||||
|
"default_model_id": data_key.get("default_model_id", ""),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.decoded_token = (
|
self.decoded_token = (
|
||||||
@@ -405,6 +415,7 @@ class StreamProcessor:
|
|||||||
"agent_type": settings.AGENT_NAME,
|
"agent_type": settings.AGENT_NAME,
|
||||||
"user_api_key": None,
|
"user_api_key": None,
|
||||||
"json_schema": None,
|
"json_schema": None,
|
||||||
|
"default_model_id": "",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ OPENAI_MODELS = [
|
|||||||
supports_tools=True,
|
supports_tools=True,
|
||||||
supports_structured_output=True,
|
supports_structured_output=True,
|
||||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||||
context_window=400000,
|
context_window=200000,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
AvailableModel(
|
AvailableModel(
|
||||||
@@ -49,7 +49,7 @@ OPENAI_MODELS = [
|
|||||||
supports_tools=True,
|
supports_tools=True,
|
||||||
supports_structured_output=True,
|
supports_structured_output=True,
|
||||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||||
context_window=400000,
|
context_window=200000,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -133,7 +133,7 @@ GOOGLE_MODELS = [
|
|||||||
supports_tools=True,
|
supports_tools=True,
|
||||||
supports_structured_output=True,
|
supports_structured_output=True,
|
||||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||||
context_window=20000, # Set low for testing compression
|
context_window=2000000,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -146,6 +146,14 @@ def upload_index(full_path, file_data):
|
|||||||
|
|
||||||
def run_agent_logic(agent_config, input_data):
|
def run_agent_logic(agent_config, input_data):
|
||||||
try:
|
try:
|
||||||
|
from application.core.model_utils import (
|
||||||
|
get_api_key_for_provider,
|
||||||
|
get_default_model_id,
|
||||||
|
get_provider_from_model_id,
|
||||||
|
validate_model_id,
|
||||||
|
)
|
||||||
|
from application.utils import calculate_doc_token_budget
|
||||||
|
|
||||||
source = agent_config.get("source")
|
source = agent_config.get("source")
|
||||||
retriever = agent_config.get("retriever", "classic")
|
retriever = agent_config.get("retriever", "classic")
|
||||||
if isinstance(source, DBRef):
|
if isinstance(source, DBRef):
|
||||||
@@ -160,31 +168,62 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
user_api_key = agent_config["key"]
|
user_api_key = agent_config["key"]
|
||||||
agent_type = agent_config.get("agent_type", "classic")
|
agent_type = agent_config.get("agent_type", "classic")
|
||||||
decoded_token = {"sub": agent_config.get("user")}
|
decoded_token = {"sub": agent_config.get("user")}
|
||||||
|
json_schema = agent_config.get("json_schema")
|
||||||
prompt = get_prompt(prompt_id, db["prompts"])
|
prompt = get_prompt(prompt_id, db["prompts"])
|
||||||
agent = AgentCreator.create_agent(
|
|
||||||
agent_type,
|
# Determine model_id: check agent's default_model_id, fallback to system default
|
||||||
endpoint="webhook",
|
agent_default_model = agent_config.get("default_model_id", "")
|
||||||
llm_name=settings.LLM_PROVIDER,
|
if agent_default_model and validate_model_id(agent_default_model):
|
||||||
model_id=settings.LLM_NAME,
|
model_id = agent_default_model
|
||||||
api_key=settings.API_KEY,
|
else:
|
||||||
user_api_key=user_api_key,
|
model_id = get_default_model_id()
|
||||||
prompt=prompt,
|
|
||||||
chat_history=[],
|
# Get provider and API key for the selected model
|
||||||
decoded_token=decoded_token,
|
provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
|
||||||
attachments=[],
|
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||||
|
|
||||||
|
# Calculate proper doc_token_limit based on model's context window
|
||||||
|
history_token_limit = 2000 # Default for webhooks
|
||||||
|
doc_token_limit = calculate_doc_token_budget(
|
||||||
|
model_id=model_id, history_token_limit=history_token_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
retriever = RetrieverCreator.create_retriever(
|
||||||
retriever,
|
retriever,
|
||||||
source=source,
|
source=source,
|
||||||
chat_history=[],
|
chat_history=[],
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
chunks=chunks,
|
chunks=chunks,
|
||||||
token_limit=settings.DEFAULT_MAX_HISTORY,
|
doc_token_limit=doc_token_limit,
|
||||||
model_id=settings.LLM_NAME,
|
model_id=model_id,
|
||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
)
|
)
|
||||||
answer = agent.gen(query=input_data, retriever=retriever)
|
|
||||||
|
# Pre-fetch documents using the retriever
|
||||||
|
retrieved_docs = []
|
||||||
|
try:
|
||||||
|
docs = retriever.search(input_data)
|
||||||
|
if docs:
|
||||||
|
retrieved_docs = docs
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to retrieve documents: {e}")
|
||||||
|
|
||||||
|
agent = AgentCreator.create_agent(
|
||||||
|
agent_type,
|
||||||
|
endpoint="webhook",
|
||||||
|
llm_name=provider or settings.LLM_PROVIDER,
|
||||||
|
model_id=model_id,
|
||||||
|
api_key=system_api_key,
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
prompt=prompt,
|
||||||
|
chat_history=[],
|
||||||
|
retrieved_docs=retrieved_docs,
|
||||||
|
decoded_token=decoded_token,
|
||||||
|
attachments=[],
|
||||||
|
json_schema=json_schema,
|
||||||
|
)
|
||||||
|
answer = agent.gen(query=input_data)
|
||||||
response_full = ""
|
response_full = ""
|
||||||
thought = ""
|
thought = ""
|
||||||
source_log_docs = []
|
source_log_docs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user