mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-03 20:33:45 +00:00
proxy for api-tool
-Only for api_tool for now, if this solution works well then implementation for other tools is required - Need to check api keys creation with the current proxies - Show connection string example at creation - locale needs updates for other languages
This commit is contained in:
@@ -11,4 +11,7 @@ class AgentCreator:
|
||||
agent_class = cls.agents.get(type.lower())
|
||||
if not agent_class:
|
||||
raise ValueError(f"No agent class found for type {type}")
|
||||
config = kwargs.pop('config', None)
|
||||
if isinstance(config, dict) and 'proxy_id' in config and 'proxy_id' not in kwargs:
|
||||
kwargs['proxy_id'] = config['proxy_id']
|
||||
return agent_class(*args, **kwargs)
|
||||
|
||||
@@ -17,6 +17,7 @@ class BaseAgent:
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
proxy_id=None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm = LLMCreator.create_llm(
|
||||
@@ -30,6 +31,7 @@ class BaseAgent:
|
||||
self.tools = []
|
||||
self.tool_config = {}
|
||||
self.tool_calls = []
|
||||
self.proxy_id = proxy_id
|
||||
|
||||
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
|
||||
raise NotImplementedError('Method "gen" must be implemented in the child class')
|
||||
@@ -41,6 +43,11 @@ class BaseAgent:
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
if hasattr(self, 'proxy_id') and self.proxy_id:
|
||||
for tool_id, tool in tools_by_id.items():
|
||||
if 'config' not in tool:
|
||||
tool['config'] = {}
|
||||
tool['config']['proxy_id'] = self.proxy_id
|
||||
return tools_by_id
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
@@ -126,6 +133,7 @@ class BaseAgent:
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
"proxy_id": self.proxy_id,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
|
||||
@@ -18,9 +18,10 @@ class ClassicAgent(BaseAgent):
|
||||
prompt="",
|
||||
chat_history=None,
|
||||
decoded_token=None,
|
||||
proxy_id=None,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token, proxy_id
|
||||
)
|
||||
self.user = decoded_token.get("sub")
|
||||
self.prompt = prompt
|
||||
|
||||
@@ -23,15 +23,43 @@ class APITool(Tool):
|
||||
)
|
||||
|
||||
def _make_api_call(self, url, method, headers, query_params, body):
|
||||
sanitized_headers = {}
|
||||
for key, value in headers.items():
|
||||
if isinstance(value, str):
|
||||
sanitized_value = value.encode('latin-1', errors='ignore').decode('latin-1')
|
||||
sanitized_headers[key] = sanitized_value
|
||||
else:
|
||||
sanitized_headers[key] = value
|
||||
|
||||
if query_params:
|
||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
||||
if isinstance(body, dict):
|
||||
body = json.dumps(body)
|
||||
response = None
|
||||
try:
|
||||
print(f"Making API call: {method} {url} with body: {body}")
|
||||
if body == "{}":
|
||||
body = None
|
||||
response = requests.request(method, url, headers=headers, data=body)
|
||||
|
||||
proxy_id = self.config.get("proxy_id", None)
|
||||
request_kwargs = {
|
||||
'method': method,
|
||||
'url': url,
|
||||
'headers': sanitized_headers,
|
||||
'data': body
|
||||
}
|
||||
try:
|
||||
if proxy_id:
|
||||
from application.agents.tools.proxy_handler import apply_proxy_to_request
|
||||
response = apply_proxy_to_request(
|
||||
requests.request,
|
||||
proxy_id=proxy_id,
|
||||
**request_kwargs
|
||||
)
|
||||
else:
|
||||
response = requests.request(**request_kwargs)
|
||||
except ImportError:
|
||||
response = requests.request(**request_kwargs)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get(
|
||||
"Content-Type", "application/json"
|
||||
|
||||
63
application/agents/tools/proxy_handler.py
Normal file
63
application/agents/tools/proxy_handler.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
import requests
|
||||
from typing import Dict, Optional
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get MongoDB connection
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
proxies_collection = db["proxies"]
|
||||
|
||||
def get_proxy_config(proxy_id: str) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Retrieve proxy configuration from the database.
|
||||
|
||||
Args:
|
||||
proxy_id: The ID of the proxy configuration
|
||||
|
||||
Returns:
|
||||
A dictionary with proxy configuration or None if not found
|
||||
"""
|
||||
if not proxy_id or proxy_id == "none":
|
||||
return None
|
||||
|
||||
try:
|
||||
if ObjectId.is_valid(proxy_id):
|
||||
proxy_config = proxies_collection.find_one({"_id": ObjectId(proxy_id)})
|
||||
if proxy_config and "connection" in proxy_config:
|
||||
connection_str = proxy_config["connection"].strip()
|
||||
if connection_str:
|
||||
# Format proxy for requests library
|
||||
return {
|
||||
"http": connection_str,
|
||||
"https": connection_str
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving proxy configuration: {e}")
|
||||
return None
|
||||
|
||||
def apply_proxy_to_request(request_func, proxy_id=None, **kwargs):
|
||||
"""
|
||||
Apply proxy configuration to a requests function if available.
|
||||
This is a minimal wrapper that doesn't change the function signature.
|
||||
|
||||
Args:
|
||||
request_func: The requests function to call (e.g., requests.get, requests.post)
|
||||
proxy_id: Optional proxy ID to use
|
||||
**kwargs: Arguments to pass to the request function
|
||||
|
||||
Returns:
|
||||
The response from the request
|
||||
"""
|
||||
if proxy_id:
|
||||
proxy_config = get_proxy_config(proxy_id)
|
||||
if proxy_config:
|
||||
kwargs['proxies'] = proxy_config
|
||||
logger.info(f"Using proxy for request")
|
||||
|
||||
return request_func(**kwargs)
|
||||
@@ -335,6 +335,9 @@ class Stream(Resource):
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"proxy_id": fields.String(
|
||||
required=False, description="Proxy ID to use for API calls"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
@@ -376,6 +379,7 @@ class Stream(Resource):
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
proxy_id = data.get("proxy_id", None)
|
||||
|
||||
index = data.get("index", None)
|
||||
chunks = int(data.get("chunks", 2))
|
||||
@@ -386,6 +390,7 @@ class Stream(Resource):
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
proxy_id = data_key.get("proxy_id", None)
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
@@ -422,6 +427,7 @@ class Stream(Resource):
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
proxy_id=proxy_id,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
@@ -496,6 +502,9 @@ class Answer(Resource):
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"proxy_id": fields.String(
|
||||
required=False, description="Proxy ID to use for API calls"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
@@ -527,6 +536,7 @@ class Answer(Resource):
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
proxy_id = data.get("proxy_id", None)
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
@@ -535,6 +545,7 @@ class Answer(Resource):
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
proxy_id = data_key.get("proxy_id", None)
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
@@ -569,6 +580,7 @@ class Answer(Resource):
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
proxy_id=proxy_id,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ db = mongo["docsgpt"]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
proxies_collection = db["proxies"]
|
||||
feedback_collection = db["feedback"]
|
||||
api_key_collection = db["api_keys"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
@@ -919,6 +920,183 @@ class UpdatePrompt(Resource):
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@user_ns.route("/api/get_proxies")
|
||||
class GetProxies(Resource):
|
||||
@api.doc(description="Get all proxies for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
proxies = proxies_collection.find({"user": user})
|
||||
list_proxies = [
|
||||
{"id": "none", "name": "None", "type": "public"},
|
||||
]
|
||||
|
||||
for proxy in proxies:
|
||||
list_proxies.append(
|
||||
{
|
||||
"id": str(proxy["_id"]),
|
||||
"name": proxy["name"],
|
||||
"type": "private",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving proxies: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(jsonify(list_proxies), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/get_single_proxy")
|
||||
class GetSingleProxy(Resource):
|
||||
@api.doc(params={"id": "ID of the proxy"}, description="Get a single proxy by ID")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
proxy_id = request.args.get("id")
|
||||
if not proxy_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
if proxy_id == "none":
|
||||
return make_response(jsonify({"connection": ""}), 200)
|
||||
|
||||
proxy = proxies_collection.find_one(
|
||||
{"_id": ObjectId(proxy_id), "user": user}
|
||||
)
|
||||
if not proxy:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving proxy: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(jsonify({"connection": proxy["connection"]}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/create_proxy")
|
||||
class CreateProxy(Resource):
|
||||
create_proxy_model = api.model(
|
||||
"CreateProxyModel",
|
||||
{
|
||||
"connection": fields.String(
|
||||
required=True, description="Connection string of the proxy"
|
||||
),
|
||||
"name": fields.String(required=True, description="Name of the proxy"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_proxy_model)
|
||||
@api.doc(description="Create a new proxy")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["connection", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
resp = proxies_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"connection": data["connection"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating proxy: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/delete_proxy")
|
||||
class DeleteProxy(Resource):
|
||||
delete_proxy_model = api.model(
|
||||
"DeleteProxyModel",
|
||||
{"id": fields.String(required=True, description="Proxy ID to delete")},
|
||||
)
|
||||
|
||||
@api.expect(delete_proxy_model)
|
||||
@api.doc(description="Delete a proxy by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
# Don't allow deleting the 'none' proxy
|
||||
if data["id"] == "none":
|
||||
return make_response(jsonify({"success": False, "message": "Cannot delete default proxy"}), 400)
|
||||
|
||||
result = proxies_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
if result.deleted_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Proxy not found"}), 404)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting proxy: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/update_proxy")
|
||||
class UpdateProxy(Resource):
|
||||
update_proxy_model = api.model(
|
||||
"UpdateProxyModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Proxy ID to update"),
|
||||
"name": fields.String(required=True, description="New name of the proxy"),
|
||||
"connection": fields.String(
|
||||
required=True, description="New connection string of the proxy"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_proxy_model)
|
||||
@api.doc(description="Update an existing proxy")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name", "connection"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
# Don't allow updating the 'none' proxy
|
||||
if data["id"] == "none":
|
||||
return make_response(jsonify({"success": False, "message": "Cannot update default proxy"}), 400)
|
||||
|
||||
result = proxies_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "connection": data["connection"]}},
|
||||
)
|
||||
if result.modified_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Proxy not found"}), 404)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating proxy: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@user_ns.route("/api/get_api_keys")
|
||||
class GetApiKeys(Resource):
|
||||
|
||||
@@ -14,6 +14,7 @@ esutils==1.0.1
|
||||
Flask==3.1.0
|
||||
faiss-cpu==1.9.0.post1
|
||||
flask-restx==1.3.0
|
||||
gevent==24.11.1
|
||||
google-genai==1.3.0
|
||||
google-generativeai==0.8.3
|
||||
gTTS==2.5.4
|
||||
@@ -35,7 +36,7 @@ langchain-community==0.3.19
|
||||
langchain-core==0.3.45
|
||||
langchain-openai==0.3.8
|
||||
langchain-text-splitters==0.3.6
|
||||
langsmith==0.3.19
|
||||
langsmith==0.3.15
|
||||
lazy-object-proxy==1.10.0
|
||||
lxml==5.3.1
|
||||
markupsafe==3.0.2
|
||||
@@ -65,7 +66,7 @@ py==1.11.0
|
||||
pydantic==2.10.6
|
||||
pydantic-core==2.27.2
|
||||
pydantic-settings==2.7.1
|
||||
pymongo==4.11.3
|
||||
pymongo==4.10.1
|
||||
pypdf==5.2.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
|
||||
Reference in New Issue
Block a user