mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
fix: structure improvements of messages
This commit is contained in:
@@ -138,17 +138,12 @@ class BaseAgent(ABC):
|
||||
|
||||
actions_by_id = {a["call_id"]: a for a in tool_actions}
|
||||
|
||||
# Build a single assistant message containing all tool calls so
|
||||
# the message history matches the format LLM providers expect
|
||||
# (one assistant message with N tool_calls, followed by N tool results).
|
||||
tc_objects: List[Dict[str, Any]] = []
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
action = actions_by_id.get(call_id)
|
||||
if not action:
|
||||
action = {
|
||||
"call_id": call_id,
|
||||
"decision": "denied",
|
||||
"comment": "No response provided",
|
||||
}
|
||||
|
||||
# Build the assistant tool-call message in standard format
|
||||
args = pending["arguments"]
|
||||
args_str = (
|
||||
json.dumps(args) if isinstance(args, dict) else (args or "{}")
|
||||
@@ -163,11 +158,25 @@ class BaseAgent(ABC):
|
||||
}
|
||||
if pending.get("thought_signature"):
|
||||
tc_obj["thought_signature"] = pending["thought_signature"]
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tc_obj],
|
||||
})
|
||||
tc_objects.append(tc_obj)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tc_objects,
|
||||
})
|
||||
|
||||
# Now process each pending call and append tool result messages
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
action = actions_by_id.get(call_id)
|
||||
if not action:
|
||||
action = {
|
||||
"call_id": call_id,
|
||||
"decision": "denied",
|
||||
"comment": "No response provided",
|
||||
}
|
||||
|
||||
if action.get("decision") == "approved":
|
||||
# Execute the tool server-side
|
||||
|
||||
@@ -126,33 +126,18 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if len(stream_result) == 7:
|
||||
(
|
||||
conversation_id,
|
||||
response,
|
||||
sources,
|
||||
tool_calls,
|
||||
thought,
|
||||
error,
|
||||
extra_info,
|
||||
) = stream_result
|
||||
else:
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
stream_result
|
||||
)
|
||||
extra_info = None
|
||||
|
||||
if error:
|
||||
return make_response({"error": error}, 400)
|
||||
if stream_result["error"]:
|
||||
return make_response({"error": stream_result["error"]}, 400)
|
||||
|
||||
result = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"conversation_id": stream_result["conversation_id"],
|
||||
"answer": stream_result["answer"],
|
||||
"sources": stream_result["sources"],
|
||||
"tool_calls": stream_result["tool_calls"],
|
||||
"thought": stream_result["thought"],
|
||||
}
|
||||
|
||||
extra_info = stream_result.get("extra")
|
||||
if extra_info:
|
||||
result.update(extra_info)
|
||||
except Exception as e:
|
||||
|
||||
@@ -540,8 +540,13 @@ class BaseAnswerResource:
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream):
|
||||
"""Process the stream response for non-streaming endpoint"""
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
"""Process the stream response for non-streaming endpoint.
|
||||
|
||||
Returns:
|
||||
Dict with keys: conversation_id, answer, sources, tool_calls,
|
||||
thought, error, and optional extra.
|
||||
"""
|
||||
conversation_id = ""
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
@@ -577,7 +582,14 @@ class BaseAnswerResource:
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"], None
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": event["error"],
|
||||
}
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -585,30 +597,30 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": "Stream ended unexpectedly",
|
||||
}
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if pending_tool_calls is not None:
|
||||
return (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
thought,
|
||||
None,
|
||||
{"pending_tool_calls": pending_tool_calls},
|
||||
)
|
||||
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
thought,
|
||||
None,
|
||||
)
|
||||
result["extra"] = {"pending_tool_calls": pending_tool_calls}
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
result["extra"] = {"structured": True, "schema": schema_info}
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -36,16 +36,21 @@ def _extract_bearer_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _get_model_name(api_key: str) -> str:
|
||||
"""Look up agent name for display as model name."""
|
||||
def _lookup_agent(api_key: str) -> Optional[Dict]:
|
||||
"""Look up the agent document for this API key."""
|
||||
try:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
agent = db["agents"].find_one({"key": api_key})
|
||||
if agent:
|
||||
return agent.get("name", api_key)
|
||||
return db["agents"].find_one({"key": api_key})
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning("Failed to look up agent for API key", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
|
||||
"""Return agent name for display as model name."""
|
||||
if agent:
|
||||
return agent.get("name", api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
@@ -72,7 +77,8 @@ def chat_completions():
|
||||
)
|
||||
|
||||
is_stream = data.get("stream", False)
|
||||
model_name = _get_model_name(api_key)
|
||||
agent_doc = _lookup_agent(api_key)
|
||||
model_name = _get_model_name(agent_doc, api_key)
|
||||
|
||||
try:
|
||||
internal_data = translate_request(data, api_key)
|
||||
@@ -83,8 +89,10 @@ def chat_completions():
|
||||
400,
|
||||
)
|
||||
|
||||
# Use the api_key as decoded token for agent auth
|
||||
decoded_token = {"sub": "api_key_user"}
|
||||
# Link decoded_token to the agent's owner so continuation state,
|
||||
# logs, and tool execution use the correct user identity.
|
||||
agent_user = agent_doc.get("user") if agent_doc else None
|
||||
decoded_token = {"sub": agent_user or "api_key_user"}
|
||||
|
||||
try:
|
||||
processor = StreamProcessor(internal_data, decoded_token)
|
||||
@@ -232,26 +240,21 @@ def _non_stream_response(
|
||||
|
||||
result = helper.process_response_stream(stream)
|
||||
|
||||
if len(result) == 7:
|
||||
conversation_id, answer, sources, tool_calls, thought, error, extra = result
|
||||
else:
|
||||
conversation_id, answer, sources, tool_calls, thought, error = result
|
||||
extra = None
|
||||
|
||||
if error:
|
||||
if result["error"]:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": error, "type": "server_error"}}),
|
||||
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
|
||||
500,
|
||||
)
|
||||
|
||||
extra = result.get("extra")
|
||||
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
|
||||
|
||||
response = translate_response(
|
||||
conversation_id=conversation_id,
|
||||
answer=answer or "",
|
||||
sources=sources,
|
||||
tool_calls=tool_calls,
|
||||
thought=thought or "",
|
||||
conversation_id=result["conversation_id"],
|
||||
answer=result["answer"] or "",
|
||||
sources=result["sources"],
|
||||
tool_calls=result["tool_calls"],
|
||||
thought=result["thought"] or "",
|
||||
model_name=model_name,
|
||||
pending_tool_calls=pending,
|
||||
)
|
||||
|
||||
@@ -411,15 +411,23 @@ function translateV1ChunkToInternalEvents(
|
||||
|
||||
if (delta.tool_calls) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
let parsedArgs: Record<string, any> = {};
|
||||
if (tc.function?.arguments) {
|
||||
try {
|
||||
parsedArgs = JSON.parse(tc.function.arguments);
|
||||
} catch {
|
||||
// Arguments may arrive as fragments during streaming;
|
||||
// keep the raw string so downstream can accumulate it.
|
||||
parsedArgs = { _raw: tc.function.arguments };
|
||||
}
|
||||
}
|
||||
events.push({
|
||||
type: 'tool_call',
|
||||
data: {
|
||||
call_id: tc.id,
|
||||
action_name: tc.function?.name || '',
|
||||
tool_name: tc.function?.name || '',
|
||||
arguments: tc.function?.arguments
|
||||
? JSON.parse(tc.function.arguments)
|
||||
: {},
|
||||
arguments: parsedArgs,
|
||||
status: 'requires_client_execution',
|
||||
},
|
||||
});
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestAnswerResourcePost:
|
||||
),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(conv_id, "Hello", [], [], "", None),
|
||||
return_value={"conversation_id": conv_id, "answer": "Hello", "sources": [], "tool_calls": [], "thought": "", "error": None},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
@@ -129,7 +129,7 @@ class TestAnswerResourcePost:
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(None, None, None, None, None, "Stream error"),
|
||||
return_value={"conversation_id": None, "answer": None, "sources": None, "tool_calls": None, "thought": None, "error": "Stream error"},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
@@ -173,15 +173,7 @@ class TestAnswerResourcePost:
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
'{"key": "val"}',
|
||||
[],
|
||||
[],
|
||||
"",
|
||||
None,
|
||||
{"structured": True, "schema": {"type": "object"}},
|
||||
),
|
||||
return_value={"conversation_id": conv_id, "answer": '{"key": "val"}', "sources": [], "tool_calls": [], "thought": "", "error": None, "extra": {"structured": True, "schema": {"type": "object"}}},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
@@ -208,14 +200,7 @@ class TestAnswerResourcePost:
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
"answer text",
|
||||
[{"title": "src"}],
|
||||
[{"tool": "t"}],
|
||||
"thinking...",
|
||||
None,
|
||||
),
|
||||
return_value={"conversation_id": conv_id, "answer": "answer text", "sources": [{"title": "src"}], "tool_calls": [{"tool": "t"}], "thought": "thinking...", "error": None},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
|
||||
@@ -481,10 +481,10 @@ class TestProcessResponseStream:
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result[0] == conv_id
|
||||
assert result[1] == "Hello world"
|
||||
assert result[2] == [{"title": "doc1"}]
|
||||
assert result[5] is None
|
||||
assert result["conversation_id"] == conv_id
|
||||
assert result["answer"] == "Hello world"
|
||||
assert result["sources"] == [{"title": "doc1"}]
|
||||
assert result["error"] is None
|
||||
|
||||
def test_handles_stream_error(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
@@ -500,10 +500,8 @@ class TestProcessResponseStream:
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert len(result) == 6
|
||||
assert result[0] is None
|
||||
assert result[4] == "Test error"
|
||||
assert result[5] is None
|
||||
assert result["conversation_id"] is None
|
||||
assert result["error"] == "Test error"
|
||||
|
||||
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
@@ -295,10 +295,8 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[1] == "{}"
|
||||
# Structured output adds extra tuple element
|
||||
assert len(result) == 7
|
||||
assert result[6]["structured"] is True
|
||||
assert result["answer"] == "{}"
|
||||
assert result.get("extra", {}).get("structured") is True
|
||||
|
||||
def test_handles_tool_calls_event(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
@@ -312,7 +310,7 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[3] == [{"name": "t1"}]
|
||||
assert result["tool_calls"] == [{"name": "t1"}]
|
||||
|
||||
def test_incomplete_stream(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
@@ -323,7 +321,7 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[4] == "Stream ended unexpectedly"
|
||||
assert result["error"] == "Stream ended unexpectedly"
|
||||
|
||||
def test_handles_thought_event(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
@@ -335,7 +333,7 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[4] == "thinking..."
|
||||
assert result["thought"] == "thinking..."
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
Reference in New Issue
Block a user