mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-14 10:11:19 +00:00
Merge branch 'main' into main
This commit is contained in:
@@ -20,9 +20,10 @@ with open(
|
||||
"r",
|
||||
) as f:
|
||||
final_prompt_template = f.read()
|
||||
|
||||
|
||||
MAX_ITERATIONS_REASONING = 10
|
||||
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -38,49 +39,69 @@ class ReActAgent(BaseAgent):
|
||||
collected_content = []
|
||||
if isinstance(resp, str):
|
||||
collected_content.append(resp)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
collected_content.append(resp.message.content)
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices") and resp.choices and
|
||||
hasattr(resp.choices[0], "message") and
|
||||
hasattr(resp.choices[0].message, "content") and
|
||||
resp.choices[0].message.content is not None
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices")
|
||||
and resp.choices
|
||||
and hasattr(resp.choices[0], "message")
|
||||
and hasattr(resp.choices[0].message, "content")
|
||||
and resp.choices[0].message.content is not None
|
||||
):
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
|
||||
hasattr(resp.content[0], "text")
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content")
|
||||
and isinstance(resp.content, list)
|
||||
and resp.content
|
||||
and hasattr(resp.content[0], "text")
|
||||
):
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
else:
|
||||
# Assume resp is a stream if not a recognized object
|
||||
chunk = None
|
||||
try:
|
||||
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
for (
|
||||
chunk
|
||||
) in (
|
||||
resp
|
||||
): # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
content_piece = ""
|
||||
# OpenAI-like stream
|
||||
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
|
||||
hasattr(chunk.choices[0], 'delta') and \
|
||||
hasattr(chunk.choices[0].delta, 'content') and \
|
||||
chunk.choices[0].delta.content is not None:
|
||||
if (
|
||||
hasattr(chunk, "choices")
|
||||
and len(chunk.choices) > 0
|
||||
and hasattr(chunk.choices[0], "delta")
|
||||
and hasattr(chunk.choices[0].delta, "content")
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
content_piece = chunk.choices[0].delta.content
|
||||
# Anthropic-like stream (ContentBlockDelta)
|
||||
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
|
||||
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
|
||||
elif (
|
||||
hasattr(chunk, "type")
|
||||
and chunk.type == "content_block_delta"
|
||||
and hasattr(chunk, "delta")
|
||||
and hasattr(chunk.delta, "text")
|
||||
):
|
||||
content_piece = chunk.delta.text
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
content_piece = chunk
|
||||
|
||||
if content_piece:
|
||||
collected_content.append(content_piece)
|
||||
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
|
||||
except (
|
||||
TypeError
|
||||
): # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(
|
||||
f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
|
||||
|
||||
logger.error(
|
||||
f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk) if chunk is not None else 'N/A'}"
|
||||
)
|
||||
|
||||
return "".join(collected_content)
|
||||
|
||||
@@ -112,8 +133,9 @@ class ReActAgent(BaseAgent):
|
||||
yield {"thought": line_chunk}
|
||||
self.plan = "".join(current_plan_parts)
|
||||
if self.plan:
|
||||
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
|
||||
|
||||
self.observations.append(
|
||||
f"Plan: {self.plan} Iteration: {iterating_reasoning}"
|
||||
)
|
||||
|
||||
max_obs_len = 20000
|
||||
obs_str = "\n".join(self.observations)
|
||||
@@ -125,34 +147,55 @@ class ReActAgent(BaseAgent):
|
||||
+ f"\n\nObservations:\n{obs_str}"
|
||||
+ f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
|
||||
)
|
||||
|
||||
|
||||
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
|
||||
|
||||
resp_from_llm_gen = self._llm_gen(messages, log_context)
|
||||
|
||||
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
|
||||
initial_llm_thought_content = self._extract_content_from_llm_response(
|
||||
resp_from_llm_gen
|
||||
)
|
||||
if initial_llm_thought_content:
|
||||
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
|
||||
self.observations.append(
|
||||
f"Initial thought/response: {initial_llm_thought_content}"
|
||||
)
|
||||
else:
|
||||
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
|
||||
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
|
||||
|
||||
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
|
||||
logger.info(
|
||||
"ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls)."
|
||||
)
|
||||
resp_after_handler = self._llm_handler(
|
||||
resp_from_llm_gen, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
for (
|
||||
tool_call_info
|
||||
) in (
|
||||
self.tool_calls
|
||||
): # Iterate over self.tool_calls populated by _llm_handler
|
||||
observation_string = (
|
||||
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
|
||||
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
|
||||
)
|
||||
self.observations.append(observation_string)
|
||||
|
||||
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
|
||||
content_after_handler = self._extract_content_from_llm_response(
|
||||
resp_after_handler
|
||||
)
|
||||
if content_after_handler:
|
||||
self.observations.append(f"Response after tool execution: {content_after_handler}")
|
||||
self.observations.append(
|
||||
f"Response after tool execution: {content_after_handler}"
|
||||
)
|
||||
else:
|
||||
logger.info("ReActAgent: LLM response after handler had no textual content.")
|
||||
logger.info(
|
||||
"ReActAgent: LLM response after handler had no textual content."
|
||||
)
|
||||
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
{
|
||||
"component": "agent_tool_calls",
|
||||
"data": {"tool_calls": self.tool_calls.copy()},
|
||||
}
|
||||
)
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
@@ -165,13 +208,17 @@ class ReActAgent(BaseAgent):
|
||||
display_tool_calls.append(cleaned_tc)
|
||||
if display_tool_calls:
|
||||
yield {"tool_calls": display_tool_calls}
|
||||
|
||||
|
||||
if "SATISFIED" in content_after_handler:
|
||||
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
|
||||
logger.info(
|
||||
"ReActAgent: LLM satisfied with the plan and data. Stopping reasoning."
|
||||
)
|
||||
break
|
||||
|
||||
# 3. Create Final Answer based on all observations
|
||||
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
|
||||
final_answer_stream = self._create_final_answer(
|
||||
query, self.observations, log_context
|
||||
)
|
||||
for answer_chunk in final_answer_stream:
|
||||
yield {"answer": answer_chunk}
|
||||
logger.info("ReActAgent: Finished generating final answer.")
|
||||
@@ -184,12 +231,16 @@ class ReActAgent(BaseAgent):
|
||||
summaries = docs_data if docs_data else "No documents retrieved."
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
|
||||
plan_prompt_filled = plan_prompt_filled.replace(
|
||||
"{observations}", "\n".join(self.observations)
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": plan_prompt_filled}]
|
||||
|
||||
plan_stream_from_llm = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
|
||||
model=self.gpt_model,
|
||||
messages=messages,
|
||||
tools=getattr(self, "tools", None), # Use self.tools
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm)
|
||||
@@ -206,8 +257,12 @@ class ReActAgent(BaseAgent):
|
||||
observation_string = "\n".join(observations)
|
||||
max_obs_len = 10000
|
||||
if len(observation_string) > max_obs_len:
|
||||
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
|
||||
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
|
||||
observation_string = (
|
||||
observation_string[:max_obs_len] + "\n...[observations truncated]"
|
||||
)
|
||||
logger.warning(
|
||||
"ReActAgent: Truncated observations for final answer prompt due to length."
|
||||
)
|
||||
|
||||
final_answer_prompt_filled = final_prompt_template.format(
|
||||
query=query, observations=observation_string
|
||||
@@ -226,4 +281,4 @@ class ReActAgent(BaseAgent):
|
||||
for chunk in final_answer_stream_from_llm:
|
||||
content_piece = self._extract_content_from_llm_response(chunk)
|
||||
if content_piece:
|
||||
yield content_piece
|
||||
yield content_piece
|
||||
|
||||
@@ -822,6 +822,70 @@ class PinnedAgents(Resource):
|
||||
return make_response(jsonify(list_pinned_agents), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/template_agents")
class GetTemplateAgents(Resource):
    """Endpoint listing the agents stored under the ``system`` user."""

    @api.doc(description="Get template/premade agents")
    def get(self):
        """Return id, name, description and image of every template agent.

        Responds with 200 and a JSON list on success. Any exception is logged
        with its traceback and answered with 400 and ``{"success": False}``.
        """
        try:
            summaries = []
            for agent_doc in agents_collection.find({"user": "system"}):
                summaries.append(
                    {
                        "id": str(agent_doc["_id"]),
                        "name": agent_doc["name"],
                        "description": agent_doc["description"],
                        # image is the only optional field; default to empty string
                        "image": agent_doc.get("image", ""),
                    }
                )
            return make_response(jsonify(summaries), 200)
        except Exception as e:
            current_app.logger.error(f"Template agents fetch error: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_ns.route("/adopt_agent")
class AdoptAgent(Resource):
    """Endpoint that copies a ``system``-owned agent to the calling user."""

    @api.doc(params={"id": "Agent ID"}, description="Adopt an agent by ID")
    def post(self):
        """Clone the template agent named by the ``id`` query argument.

        Responses:
            401 - the request has no decoded token.
            400 - ``id`` is missing, or any exception was raised (logged).
            404 - no agent owned by ``system`` has that id.
            200 - ``{"success": True, "agent": <copy>}``.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)

        agent_id = request.args.get("id")
        if not agent_id:
            return make_response(
                jsonify({"success": False, "message": "ID required"}), 400
            )

        try:
            template = agents_collection.find_one(
                {"_id": ObjectId(agent_id), "user": "system"}
            )
            if not template:
                return make_response(jsonify({"status": "Not found"}), 404)

            # Every template field except _id, re-owned by the caller and
            # given a fresh key.
            adopted = {
                field: value for field, value in template.items() if field != "_id"
            }
            adopted.update(
                {
                    "user": decoded_token["sub"],
                    "status": "published",
                    "lastUsedAt": datetime.datetime.now(datetime.timezone.utc),
                    "key": str(uuid.uuid4()),
                }
            )
            inserted = agents_collection.insert_one(adopted)

            # Drop _id again before building the response and expose the new
            # id as a plain string instead.
            response_agent = {
                field: value for field, value in adopted.items() if field != "_id"
            }
            response_agent["id"] = str(inserted.inserted_id)
            response_agent["tool_details"] = resolve_tool_details(
                response_agent.get("tools", [])
            )
            source_ref = response_agent.get("source")
            if isinstance(source_ref, DBRef):
                response_agent["source"] = str(source_ref.id)
            return make_response(
                jsonify({"success": True, "agent": response_agent}), 200
            )
        except Exception as e:
            current_app.logger.error(f"Agent adopt error: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_ns.route("/pin_agent")
|
||||
class PinAgent(Resource):
|
||||
@api.doc(params={"id": "ID of the agent"}, description="Pin or unpin an agent")
|
||||
|
||||
@@ -57,6 +57,8 @@ class Settings(BaseSettings):
|
||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_REDIRECT_URI: Optional[str] = "http://localhost:7091/api/connectors/callback" # Your project's redirect URI that you registered in Azure Portal.
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
@@ -1,44 +1,135 @@
|
||||
import base64
|
||||
import requests
|
||||
from typing import List
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from langchain_core.documents import Document
|
||||
from application.parser.schema.base import Document
|
||||
import mimetypes
|
||||
from application.core.settings import settings
|
||||
|
||||
class GitHubLoader(BaseRemote):
|
||||
def __init__(self):
|
||||
self.access_token = None
|
||||
self.access_token = settings.GITHUB_ACCESS_TOKEN
|
||||
self.headers = {
|
||||
"Authorization": f"token {self.access_token}"
|
||||
} if self.access_token else {}
|
||||
"Authorization": f"token {self.access_token}",
|
||||
"Accept": "application/vnd.github.v3+json"
|
||||
} if self.access_token else {
|
||||
"Accept": "application/vnd.github.v3+json"
|
||||
}
|
||||
return
|
||||
|
||||
def fetch_file_content(self, repo_url: str, file_path: str) -> str:
|
||||
def is_text_file(self, file_path: str) -> bool:
    """Return True when ``file_path`` looks like a text file.

    The lowercased path is first compared against a fixed list of suffixes.
    If none matches, the MIME type guessed from the path decides: ``text/*``,
    ``application/json`` and ``application/xml`` count as text.
    """
    known_suffixes = (
        '.txt', '.md', '.markdown', '.rst', '.json', '.xml', '.yaml', '.yml',
        '.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
        '.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala',
        '.html', '.css', '.scss', '.sass', '.less',
        '.sh', '.bash', '.zsh', '.fish',
        '.sql', '.r', '.m', '.mat',
        '.ini', '.cfg', '.conf', '.config', '.env',
        '.gitignore', '.dockerignore', '.editorconfig',
        '.log', '.csv', '.tsv',
    )

    # str.endswith accepts a tuple, so one call checks every suffix.
    if file_path.lower().endswith(known_suffixes):
        return True

    # Fall back to the MIME type guessed from the (original-case) path.
    guessed_type, _encoding = mimetypes.guess_type(file_path)
    if not guessed_type:
        return False
    return guessed_type.startswith("text") or guessed_type in (
        "application/json",
        "application/xml",
    )
|
||||
|
||||
def fetch_file_content(self, repo_url: str, file_path: str) -> Optional[str]:
|
||||
"""Fetch file content. Returns None if file should be skipped (binary files or empty files)."""
|
||||
url = f"https://api.github.com/repos/{repo_url}/contents/{file_path}"
|
||||
response = requests.get(url, headers=self.headers)
|
||||
response = self._make_request(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.json()
|
||||
mime_type, _ = mimetypes.guess_type(file_path) # Guess the MIME type based on the file extension
|
||||
content = response.json()
|
||||
|
||||
if content.get("encoding") == "base64":
|
||||
if mime_type and mime_type.startswith("text"): # Handle only text files
|
||||
try:
|
||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8")
|
||||
return f"Filename: {file_path}\n\n{decoded_content}"
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
return f"Filename: {file_path} is a binary file and was skipped."
|
||||
if content.get("encoding") == "base64":
|
||||
if self.is_text_file(file_path): # Handle only text files
|
||||
try:
|
||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8").strip()
|
||||
# Skip empty files
|
||||
if not decoded_content:
|
||||
return None
|
||||
return decoded_content
|
||||
except Exception:
|
||||
# If decoding fails, it's probably a binary file
|
||||
return None
|
||||
else:
|
||||
return f"Filename: {file_path}\n\n{content['content']}"
|
||||
# Skip binary files by returning None
|
||||
return None
|
||||
else:
|
||||
response.raise_for_status()
|
||||
file_content = content['content'].strip()
|
||||
# Skip empty files
|
||||
if not file_content:
|
||||
return None
|
||||
return file_content
|
||||
|
||||
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
    """Send a GET request to ``url``, retrying when GitHub reports a rate limit.

    Args:
        url: Full URL to request; ``self.headers`` is sent with every attempt.
        max_retries: Maximum number of attempts. Values below 1 are treated
            as 1 so that a request is always made.

    Returns:
        The response with status 200, or a non-200 response for which
        ``raise_for_status()`` did not raise.

    Raises:
        requests.HTTPError: For non-403 error statuses, and for a 403 whose
            body is not a JSON object.
        Exception: For a 403 that is not retried, with a message starting
            with "GitHub API".
    """
    import logging  # local import so this method does not depend on the file header

    logger = logging.getLogger(__name__)
    attempts = max(1, max_retries)

    for attempt in range(attempts):
        # A timeout keeps a stalled connection from hanging the caller forever.
        response = requests.get(url, headers=self.headers, timeout=30)

        if response.status_code == 200:
            return response

        if response.status_code != 403:
            response.raise_for_status()
            # Non-error status other than 200: repeating the request would
            # not change anything, so hand it back as-is.
            return response

        # 403: find out whether it is a rate limit or an auth problem.
        try:
            error_msg = str(response.json().get("message", "") or "")
        except (ValueError, AttributeError):
            # Body is not a JSON object; surface the plain HTTP error.
            response.raise_for_status()
            raise

        remaining = response.headers.get("X-RateLimit-Remaining", "unknown")
        reset_time = response.headers.get("X-RateLimit-Reset", "unknown")
        logger.warning("GitHub API 403 Error: %s", error_msg)
        logger.warning(
            "Rate limit remaining: %s, Reset time: %s", remaining, reset_time
        )

        if "rate limit" in error_msg.lower() and attempt < attempts - 1:
            wait_time = 2 ** attempt  # Exponential backoff
            logger.warning(
                "Rate limit hit, waiting %s seconds before retry...", wait_time
            )
            time.sleep(wait_time)
            continue

        # Out of retries, or not a rate-limit message: raise a helpful error.
        if remaining == "0":
            raise Exception(f"GitHub API rate limit exceeded. Please set GITHUB_ACCESS_TOKEN environment variable. Reset time: {reset_time}")
        raise Exception(f"GitHub API error: {error_msg}. This may require authentication - set GITHUB_ACCESS_TOKEN environment variable.")

    # Every iteration returns, raises or continues, and the last one cannot
    # continue, so this line is only a safeguard.
    raise Exception("GitHub API request failed without a response")
|
||||
|
||||
def fetch_repo_files(self, repo_url: str, path: str = "") -> List[str]:
|
||||
url = f"https://api.github.com/repos/{repo_url}/contents/{path}"
|
||||
response = requests.get(url, headers={**self.headers, "Accept": "application/vnd.github.v3.raw"})
|
||||
response = self._make_request(url)
|
||||
|
||||
contents = response.json()
|
||||
|
||||
# Handle error responses from GitHub API
|
||||
if isinstance(contents, dict) and "message" in contents:
|
||||
raise Exception(f"GitHub API error: {contents.get('message')}")
|
||||
|
||||
# Ensure contents is a list
|
||||
if not isinstance(contents, list):
|
||||
raise TypeError(f"Expected list from GitHub API, got {type(contents).__name__}: {contents}")
|
||||
|
||||
files = []
|
||||
for item in contents:
|
||||
if item["type"] == "file":
|
||||
@@ -53,6 +144,15 @@ class GitHubLoader(BaseRemote):
|
||||
documents = []
|
||||
for file_path in files:
|
||||
content = self.fetch_file_content(repo_name, file_path)
|
||||
documents.append(Document(page_content=content, metadata={"title": file_path,
|
||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"}))
|
||||
# Skip binary files (content is None)
|
||||
if content is None:
|
||||
continue
|
||||
documents.append(Document(
|
||||
text=content,
|
||||
doc_id=file_path,
|
||||
extra_info={
|
||||
"title": file_path,
|
||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"
|
||||
}
|
||||
))
|
||||
return documents
|
||||
|
||||
0
application/seed/__init__.py
Normal file
0
application/seed/__init__.py
Normal file
26
application/seed/commands.py
Normal file
26
application/seed/commands.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import click
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.seed.seeder import DatabaseSeeder
|
||||
|
||||
|
||||
@click.group()
def seed():
    """Database seeding commands"""
    # Intentionally empty: this function only names the command group that
    # the subcommands below attach to.
|
||||
|
||||
|
||||
@seed.command()
@click.option("--force", is_flag=True, help="Force reseeding even if data exists")
def init(force):
    """Initialize database with seed data"""
    # Pick the configured database from the shared client and hand it to
    # the seeder; --force is passed straight through.
    database = MongoDB.get_client()[settings.MONGO_DB_NAME]
    DatabaseSeeder(database).seed_initial_data(force=force)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed()
|
||||
36
application/seed/config/agents_template.yaml
Normal file
36
application/seed/config/agents_template.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
# Configuration for Premade Agents
|
||||
# This file contains template agents that will be seeded into the database
|
||||
|
||||
agents:
|
||||
# Basic Agent Template
|
||||
- name: "Agent Name" # Required: Unique name for the agent
|
||||
description: "What this agent does" # Required: Brief description of the agent's purpose
|
||||
image: "URL_TO_IMAGE" # Optional: URL to agent's avatar/image
|
||||
agent_type: "classic" # Required: Type of agent (e.g., classic, react, etc.)
|
||||
prompt_id: "default" # Optional: Reference to prompt template
|
||||
prompt: # Optional: Define new prompt
|
||||
name: "New Prompt"
|
||||
content: "You are new agent with cool new prompt."
|
||||
chunks: "0" # Optional: Chunking strategy for documents
|
||||
retriever: "" # Optional: Retriever type for document search
|
||||
|
||||
# Source Configuration (where the agent gets its knowledge)
|
||||
source: # Optional: Select a source to link with agent
|
||||
name: "Source Display Name" # Human-readable name for the source
|
||||
url: "https://example.com/data-source" # URL or path to knowledge source
|
||||
loader: "url" # Type of loader (url, pdf, txt, etc.)
|
||||
|
||||
# Tools Configuration (what capabilities the agent has)
|
||||
tools: # Optional: Remove if agent doesn't need tools
|
||||
- name: "tool_name" # Must match a supported tool name
|
||||
display_name: "Tool Display Name" # Optional: Human-readable name for the tool
|
||||
config:
|
||||
# Tool-specific configuration
|
||||
# Example for DuckDuckGo:
|
||||
# token: "${DDG_API_KEY}" # ${} denotes environment variable
|
||||
|
||||
# Add more tools as needed
|
||||
# - name: "another_tool"
|
||||
# config:
|
||||
# param1: "value1"
|
||||
# param2: "${ENV_VAR}"
|
||||
94
application/seed/config/premade_agents.yaml
Normal file
94
application/seed/config/premade_agents.yaml
Normal file
@@ -0,0 +1,94 @@
|
||||
# Configuration for Premade Agents
|
||||
|
||||
agents:
|
||||
- name: "Assistant"
|
||||
description: "Your general-purpose AI assistant. Ready to help with a wide range of tasks."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-logo.svg"
|
||||
agent_type: "classic"
|
||||
prompt_id: "default"
|
||||
chunks: "0"
|
||||
retriever: ""
|
||||
|
||||
# Tools Configuration
|
||||
tools:
|
||||
- name: "tool_name"
|
||||
display_name: "read_webpage"
|
||||
config:
|
||||
|
||||
- name: "Researcher"
|
||||
description: "A specialized research agent that performs deep dives into subjects."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-researcher.svg"
|
||||
agent_type: "react"
|
||||
prompt:
|
||||
name: "Researcher-Agent"
|
||||
content: |
|
||||
You are a specialized AI research assistant, DocsGPT. Your primary function is to conduct in-depth research on a given subject or question. You are methodical, thorough, and analytical. You should perform multiple iterations of thinking to gather and synthesize information before providing a final, comprehensive answer.
|
||||
|
||||
You have access to the 'Read Webpage' tool. Use this tool to explore sources, gather data, and deepen your understanding. Be proactive in using the tool to fill in knowledge gaps and validate information.
|
||||
|
||||
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
|
||||
If appropriate, your answers can include code examples, formatted as follows:
|
||||
```(language)
|
||||
(code)
|
||||
```
|
||||
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses. Try to respond with mermaid charts if visualization helps with users queries. You effectively utilize chat history, ensuring relevant and tailored responses. Try to use additional provided context if it's available, otherwise use your knowledge and tool capabilities.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "0"
|
||||
retriever: ""
|
||||
|
||||
# Tools Configuration
|
||||
tools:
|
||||
- name: "tool_name"
|
||||
display_name: "read_webpage"
|
||||
config:
|
||||
|
||||
- name: "Search Widget"
|
||||
description: "A powerful search widget agent. Ask it anything about DocsGPT"
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-search.svg"
|
||||
agent_type: "classic"
|
||||
prompt:
|
||||
name: "Search-Agent"
|
||||
content: |
|
||||
You are a website search assistant, DocsGPT. Your sole purpose is to help users find information within the provided context of the DocsGPT documentation. Act as a specialized search engine.
|
||||
|
||||
Your answers must be based *only* on the provided context. Do not use any external knowledge. If the answer is not in the context, inform the user that you could not find the information within the documentation.
|
||||
|
||||
Keep your responses concise and directly related to the user's query, pointing them to the most relevant information.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "8"
|
||||
retriever: ""
|
||||
|
||||
source:
|
||||
name: "DocsGPT-Docs"
|
||||
url: "https://d3dg1063dc54p9.cloudfront.net/agent-source/docsgpt-documentation.md" # URL to DocsGPT documentation
|
||||
loader: "url"
|
||||
|
||||
- name: "Support Widget"
|
||||
description: "A friendly support widget agent to help you with any questions."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-support.svg"
|
||||
agent_type: "classic"
|
||||
prompt:
|
||||
name: "Support-Agent"
|
||||
content: |
|
||||
You are a helpful AI support widget agent, DocsGPT. Your goal is to assist users by answering their questions about our website, product and its features. Provide friendly, clear, and direct support.
|
||||
|
||||
Your knowledge is strictly limited to the provided context from the DocsGPT documentation. You must not answer questions outside of this scope. If a user asks something you cannot answer from the context, politely state that you can only help with questions about this website.
|
||||
|
||||
Effectively utilize chat history to understand the user's issue fully. Guide users to the information they need in a helpful and conversational manner.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "8"
|
||||
retriever: ""
|
||||
|
||||
source:
|
||||
name: "DocsGPT-Docs"
|
||||
url: "https://d3dg1063dc54p9.cloudfront.net/agent-source/docsgpt-documentation.md" # URL to DocsGPT documentation
|
||||
loader: "url"
|
||||
277
application/seed/seeder.py
Normal file
277
application/seed/seeder.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from bson import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pymongo import MongoClient
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api.user.tasks import ingest_remote
|
||||
|
||||
load_dotenv()
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
class DatabaseSeeder:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.tools_collection = self.db["user_tools"]
|
||||
self.sources_collection = self.db["sources"]
|
||||
self.agents_collection = self.db["agents"]
|
||||
self.prompts_collection = self.db["prompts"]
|
||||
self.system_user_id = "system"
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def seed_initial_data(self, config_path: str = None, force=False):
|
||||
"""Main entry point for seeding all initial data"""
|
||||
if not force and self._is_already_seeded():
|
||||
self.logger.info("Database already seeded. Use force=True to reseed.")
|
||||
return
|
||||
config_path = config_path or os.path.join(
|
||||
os.path.dirname(__file__), "config", "premade_agents.yaml"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
self._seed_from_config(config)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load seeding config: {str(e)}")
|
||||
raise
|
||||
|
||||
def _seed_from_config(self, config: Dict):
|
||||
"""Seed all data from configuration"""
|
||||
self.logger.info("🌱 Starting seeding...")
|
||||
|
||||
if not config.get("agents"):
|
||||
self.logger.warning("No agents found in config")
|
||||
return
|
||||
used_tool_ids = set()
|
||||
|
||||
for agent_config in config["agents"]:
|
||||
try:
|
||||
self.logger.info(f"Processing agent: {agent_config['name']}")
|
||||
|
||||
# 1. Handle Source
|
||||
|
||||
source_result = self._handle_source(agent_config)
|
||||
if source_result is False:
|
||||
self.logger.error(
|
||||
f"Skipping agent {agent_config['name']} due to source ingestion failure"
|
||||
)
|
||||
continue
|
||||
source_id = source_result
|
||||
# 2. Handle Tools
|
||||
|
||||
tool_ids = self._handle_tools(agent_config)
|
||||
if len(tool_ids) == 0:
|
||||
self.logger.warning(
|
||||
f"No valid tools for agent {agent_config['name']}"
|
||||
)
|
||||
used_tool_ids.update(tool_ids)
|
||||
|
||||
# 3. Handle Prompt
|
||||
|
||||
prompt_id = self._handle_prompt(agent_config)
|
||||
|
||||
# 4. Create Agent
|
||||
|
||||
agent_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": agent_config["name"],
|
||||
"description": agent_config["description"],
|
||||
"image": agent_config.get("image", ""),
|
||||
"source": (
|
||||
DBRef("sources", ObjectId(source_id)) if source_id else ""
|
||||
),
|
||||
"tools": [str(tid) for tid in tool_ids],
|
||||
"agent_type": agent_config["agent_type"],
|
||||
"prompt_id": prompt_id or agent_config.get("prompt_id", "default"),
|
||||
"chunks": agent_config.get("chunks", "0"),
|
||||
"retriever": agent_config.get("retriever", ""),
|
||||
"status": "template",
|
||||
"createdAt": datetime.now(timezone.utc),
|
||||
"updatedAt": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
existing = self.agents_collection.find_one(
|
||||
{"user": self.system_user_id, "name": agent_config["name"]}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Updating existing agent: {agent_config['name']}")
|
||||
self.agents_collection.update_one(
|
||||
{"_id": existing["_id"]}, {"$set": agent_data}
|
||||
)
|
||||
agent_id = existing["_id"]
|
||||
else:
|
||||
self.logger.info(f"Creating new agent: {agent_config['name']}")
|
||||
result = self.agents_collection.insert_one(agent_data)
|
||||
agent_id = result.inserted_id
|
||||
self.logger.info(
|
||||
f"Successfully processed agent: {agent_config['name']} (ID: {agent_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Error processing agent {agent_config['name']}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
self.logger.info("✅ Database seeding completed")
|
||||
|
||||
def _handle_source(self, agent_config: Dict) -> Union[ObjectId, None, bool]:
    """Resolve the source for an agent, ingesting it when it is not stored yet.

    Args:
        agent_config: Agent definition. Its optional ``"source"`` entry must
            provide ``"url"`` and ``"name"``; ``"loader"`` defaults to
            ``"url"``.

    Returns:
        ``None`` when the agent has no source configured, the ``_id`` of an
        already stored source, the ``"id"`` reported by the ingestion task
        for a new source, or ``False`` when lookup/ingestion failed.
    """
    source_config = agent_config.get("source")
    if not source_config:
        self.logger.info(
            "No source provided for agent - will create agent without source"
        )
        return None

    remote_url = source_config["url"]
    self.logger.info(f"Ingesting source: {remote_url}")

    try:
        # Reuse a source the system user already owns for this URL.
        known = self.sources_collection.find_one(
            {"user": self.system_user_id, "remote_data": remote_url}
        )
        if known:
            self.logger.info(f"Source already exists: {known['_id']}")
            return known["_id"]

        # Hand the ingestion to the worker task and wait (up to 300 s) for it.
        job = ingest_remote.delay(
            source_data=remote_url,
            job_name=source_config["name"],
            user=self.system_user_id,
            loader=source_config.get("loader", "url"),
        )
        outcome = job.get(timeout=300)

        if not job.successful():
            raise Exception(f"Source ingestion failed: {outcome}")
        # The task result must be a dict carrying the new source's "id".
        if not (isinstance(outcome, dict) and "id" in outcome):
            raise Exception(f"Source ingestion result missing 'id': {outcome}")

        ingested_id = outcome["id"]
        self.logger.info(f"Source ingested successfully: {ingested_id}")
        return ingested_id
    except Exception as exc:
        # Any failure is reported as False so the caller can tell it apart
        # from "no source configured" (None).
        self.logger.error(f"Failed to ingest source: {str(exc)}")
        return False
|
||||
|
||||
def _handle_tools(self, agent_config: Dict) -> List[ObjectId]:
    """Find or create the tools listed in an agent config.

    Args:
        agent_config: Agent definition. Its optional ``"tools"`` entry is a
            list of tool dicts with a required ``"name"`` and optional
            ``"config"``, ``"display_name"`` and ``"description"`` keys.

    Returns:
        The IDs of the tools that were found or inserted, in config order.
        Entries that fail to process are logged and skipped, so the list can
        be shorter than the config (or empty).
    """
    tool_ids = []
    if not agent_config.get("tools"):
        return tool_ids

    for tool_config in agent_config["tools"]:
        # Bind a placeholder before the ``try`` so the error handler can
        # always format a name. Without it, a missing/invalid "name" made the
        # handler itself fail with UnboundLocalError (first entry) or report
        # the previous entry's name (later entries).
        tool_name = "<unnamed>"
        try:
            tool_name = tool_config["name"]
            # Resolve ``${ENV_VAR}`` placeholders before comparing/storing.
            processed_config = self._process_config(tool_config.get("config", {}))
            self.logger.info(f"Processing tool: {tool_name}")

            # A tool is considered the same only if owner, name and the
            # resolved config all match.
            existing = self.tools_collection.find_one(
                {
                    "user": self.system_user_id,
                    "name": tool_name,
                    "config": processed_config,
                }
            )
            if existing:
                self.logger.info(f"Tool already exists: {existing['_id']}")
                tool_ids.append(existing["_id"])
                continue

            tool_data = {
                "user": self.system_user_id,
                "name": tool_name,
                "displayName": tool_config.get("display_name", tool_name),
                "description": tool_config.get("description", ""),
                # Raises KeyError for a tool unknown to the tool manager;
                # handled below like any other per-tool failure.
                "actions": tool_manager.tools[tool_name].get_actions_metadata(),
                "config": processed_config,
                "status": True,
            }

            result = self.tools_collection.insert_one(tool_data)
            tool_ids.append(result.inserted_id)
            self.logger.info(f"Created new tool: {result.inserted_id}")
        except Exception as e:
            # Best effort: one bad tool must not abort the remaining ones.
            self.logger.error(f"Failed to process tool {tool_name}: {str(e)}")
            continue
    return tool_ids
|
||||
|
||||
def _handle_prompt(self, agent_config: Dict) -> Optional[str]:
    """Find or create the prompt of an agent and return its ID as a string.

    Args:
        agent_config: Agent definition. Its optional ``"prompt"`` entry may
            carry ``"name"`` (defaults to ``"<agent name> Prompt"``) and
            ``"content"``.

    Returns:
        The prompt ID as ``str``, or ``None`` when no prompt is configured,
        the prompt has no content, or the database operation failed.
    """
    prompt_config = agent_config.get("prompt")
    if not prompt_config:
        return None

    fallback_name = f"{agent_config['name']} Prompt"
    prompt_name = prompt_config.get("name", fallback_name)
    prompt_content = prompt_config.get("content", "")

    if not prompt_content:
        self.logger.warning(
            f"No prompt content provided for agent {agent_config['name']}"
        )
        return None

    self.logger.info(f"Processing prompt: {prompt_name}")

    try:
        # Reuse a prompt with identical owner, name and content.
        match = self.prompts_collection.find_one(
            {
                "user": self.system_user_id,
                "name": prompt_name,
                "content": prompt_content,
            }
        )
        if match:
            self.logger.info(f"Prompt already exists: {match['_id']}")
            return str(match["_id"])

        created = self.prompts_collection.insert_one(
            {
                "name": prompt_name,
                "content": prompt_content,
                "user": self.system_user_id,
            }
        )
        new_prompt_id = str(created.inserted_id)
        self.logger.info(f"Created new prompt: {new_prompt_id}")
        return new_prompt_id
    except Exception as exc:
        # Failures are logged and reported as "no prompt".
        self.logger.error(f"Failed to process prompt {prompt_name}: {str(exc)}")
        return None
|
||||
|
||||
def _process_config(self, config: Dict) -> Dict:
    """Return a copy of ``config`` with ``${VAR}`` placeholders resolved.

    Only top-level string values that start with ``${`` and end with ``}``
    are replaced, using the environment variable named between the braces
    (an empty string when it is unset). All other values are copied as-is.
    """

    def resolve(raw):
        is_placeholder = (
            isinstance(raw, str) and raw.startswith("${") and raw.endswith("}")
        )
        if not is_placeholder:
            return raw
        # Strip the leading "${" and the trailing "}" to get the variable name.
        return os.getenv(raw[2:-1], "")

    return {name: resolve(raw) for name, raw in config.items()}
|
||||
|
||||
def _is_already_seeded(self) -> bool:
    """Report whether at least one agent owned by the system user is stored."""
    owned_by_system_user = {"user": self.system_user_id}
    stored_agents = self.agents_collection.count_documents(owned_by_system_user)
    return stored_agents > 0
|
||||
|
||||
@classmethod
def initialize_from_env(cls, worker=None):
    """Build a seeder connected to the database named by the environment.

    Reads ``MONGO_URI`` (default ``mongodb://localhost:27017``) and
    ``MONGO_DB_NAME`` (default ``docsgpt``). ``worker`` is accepted but not
    used by this method.
    """
    connection_uri = os.getenv("MONGO_URI", "mongodb://localhost:27017")
    database_name = os.getenv("MONGO_DB_NAME", "docsgpt")
    database = MongoClient(connection_uri)[database_name]
    return cls(database)
|
||||
@@ -168,6 +168,10 @@ def validate_function_name(function_name):
|
||||
|
||||
|
||||
def generate_image_url(image_path):
|
||||
if isinstance(image_path, str) and (
|
||||
image_path.startswith("http://") or image_path.startswith("https://")
|
||||
):
|
||||
return image_path
|
||||
strategy = getattr(settings, "URL_STRATEGY", "backend")
|
||||
if strategy == "s3":
|
||||
bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
|
||||
|
||||
@@ -39,6 +39,7 @@ sources_collection = db["sources"]
|
||||
|
||||
# Constants
|
||||
|
||||
|
||||
MIN_TOKENS = 150
|
||||
MAX_TOKENS = 1250
|
||||
RECURSION_DEPTH = 2
|
||||
@@ -740,7 +741,13 @@ def remote_worker(
|
||||
if os.path.exists(full_path):
|
||||
shutil.rmtree(full_path)
|
||||
logging.info("remote_worker task completed successfully")
|
||||
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
||||
return {
|
||||
"id": str(id),
|
||||
"urls": source_data,
|
||||
"name_job": name_job,
|
||||
"user": user,
|
||||
"limited": False,
|
||||
}
|
||||
|
||||
|
||||
def sync(
|
||||
|
||||
Reference in New Issue
Block a user