fix: structure improvements of messages

This commit is contained in:
Alex
2026-04-01 14:58:44 +01:00
parent 398f3acc8d
commit 8b9e595d85
8 changed files with 117 additions and 119 deletions

View File

@@ -138,17 +138,12 @@ class BaseAgent(ABC):
actions_by_id = {a["call_id"]: a for a in tool_actions}
# Build a single assistant message containing all tool calls so
# the message history matches the format LLM providers expect
# (one assistant message with N tool_calls, followed by N tool results).
tc_objects: List[Dict[str, Any]] = []
for pending in pending_tool_calls:
call_id = pending["call_id"]
action = actions_by_id.get(call_id)
if not action:
action = {
"call_id": call_id,
"decision": "denied",
"comment": "No response provided",
}
# Build the assistant tool-call message in standard format
args = pending["arguments"]
args_str = (
json.dumps(args) if isinstance(args, dict) else (args or "{}")
@@ -163,11 +158,25 @@ class BaseAgent(ABC):
}
if pending.get("thought_signature"):
tc_obj["thought_signature"] = pending["thought_signature"]
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tc_obj],
})
tc_objects.append(tc_obj)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tc_objects,
})
# Now process each pending call and append tool result messages
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
action = actions_by_id.get(call_id)
if not action:
action = {
"call_id": call_id,
"decision": "denied",
"comment": "No response provided",
}
if action.get("decision") == "approved":
# Execute the tool server-side

View File

@@ -126,33 +126,18 @@ class AnswerResource(Resource, BaseAnswerResource):
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
extra_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
extra_info = None
if error:
return make_response({"error": error}, 400)
if stream_result["error"]:
return make_response({"error": stream_result["error"]}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
"conversation_id": stream_result["conversation_id"],
"answer": stream_result["answer"],
"sources": stream_result["sources"],
"tool_calls": stream_result["tool_calls"],
"thought": stream_result["thought"],
}
extra_info = stream_result.get("extra")
if extra_info:
result.update(extra_info)
except Exception as e:

View File

@@ -540,8 +540,13 @@ class BaseAnswerResource:
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
def process_response_stream(self, stream) -> Dict[str, Any]:
"""Process the stream response for non-streaming endpoint.
Returns:
Dict with keys: conversation_id, answer, sources, tool_calls,
thought, error, and optional extra.
"""
conversation_id = ""
response_full = ""
source_log_docs = []
@@ -577,7 +582,14 @@ class BaseAnswerResource:
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"], None
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": event["error"],
}
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -585,30 +597,30 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly", None
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": "Stream ended unexpectedly",
}
result: Dict[str, Any] = {
"conversation_id": conversation_id,
"answer": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls,
"thought": thought,
"error": None,
}
if pending_tool_calls is not None:
return (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
{"pending_tool_calls": pending_tool_calls},
)
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
result["extra"] = {"pending_tool_calls": pending_tool_calls}
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
result["extra"] = {"structured": True, "schema": schema_info}
return result
def error_stream_generate(self, err_response):

View File

@@ -36,16 +36,21 @@ def _extract_bearer_token() -> Optional[str]:
return None
def _get_model_name(api_key: str) -> str:
"""Look up agent name for display as model name."""
def _lookup_agent(api_key: str) -> Optional[Dict]:
"""Look up the agent document for this API key."""
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agent = db["agents"].find_one({"key": api_key})
if agent:
return agent.get("name", api_key)
return db["agents"].find_one({"key": api_key})
except Exception:
pass
logger.warning("Failed to look up agent for API key", exc_info=True)
return None
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
"""Return agent name for display as model name."""
if agent:
return agent.get("name", api_key)
return api_key
@@ -72,7 +77,8 @@ def chat_completions():
)
is_stream = data.get("stream", False)
model_name = _get_model_name(api_key)
agent_doc = _lookup_agent(api_key)
model_name = _get_model_name(agent_doc, api_key)
try:
internal_data = translate_request(data, api_key)
@@ -83,8 +89,10 @@ def chat_completions():
400,
)
# Use the api_key as decoded token for agent auth
decoded_token = {"sub": "api_key_user"}
# Link decoded_token to the agent's owner so continuation state,
# logs, and tool execution use the correct user identity.
agent_user = agent_doc.get("user") if agent_doc else None
decoded_token = {"sub": agent_user or "api_key_user"}
try:
processor = StreamProcessor(internal_data, decoded_token)
@@ -232,26 +240,21 @@ def _non_stream_response(
result = helper.process_response_stream(stream)
if len(result) == 7:
conversation_id, answer, sources, tool_calls, thought, error, extra = result
else:
conversation_id, answer, sources, tool_calls, thought, error = result
extra = None
if error:
if result["error"]:
return make_response(
jsonify({"error": {"message": error, "type": "server_error"}}),
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
500,
)
extra = result.get("extra")
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
response = translate_response(
conversation_id=conversation_id,
answer=answer or "",
sources=sources,
tool_calls=tool_calls,
thought=thought or "",
conversation_id=result["conversation_id"],
answer=result["answer"] or "",
sources=result["sources"],
tool_calls=result["tool_calls"],
thought=result["thought"] or "",
model_name=model_name,
pending_tool_calls=pending,
)

View File

@@ -411,15 +411,23 @@ function translateV1ChunkToInternalEvents(
if (delta.tool_calls) {
for (const tc of delta.tool_calls) {
let parsedArgs: Record<string, any> = {};
if (tc.function?.arguments) {
try {
parsedArgs = JSON.parse(tc.function.arguments);
} catch {
// Arguments may arrive as fragments during streaming;
// keep the raw string so downstream can accumulate it.
parsedArgs = { _raw: tc.function.arguments };
}
}
events.push({
type: 'tool_call',
data: {
call_id: tc.id,
action_name: tc.function?.name || '',
tool_name: tc.function?.name || '',
arguments: tc.function?.arguments
? JSON.parse(tc.function.arguments)
: {},
arguments: parsedArgs,
status: 'requires_client_execution',
},
});

View File

@@ -73,7 +73,7 @@ class TestAnswerResourcePost:
),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(conv_id, "Hello", [], [], "", None),
return_value={"conversation_id": conv_id, "answer": "Hello", "sources": [], "tool_calls": [], "thought": "", "error": None},
):
resp = answer_client.post(
"/api/answer",
@@ -129,7 +129,7 @@ class TestAnswerResourcePost:
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(None, None, None, None, None, "Stream error"),
return_value={"conversation_id": None, "answer": None, "sources": None, "tool_calls": None, "thought": None, "error": "Stream error"},
):
resp = answer_client.post(
"/api/answer",
@@ -173,15 +173,7 @@ class TestAnswerResourcePost:
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(
conv_id,
'{"key": "val"}',
[],
[],
"",
None,
{"structured": True, "schema": {"type": "object"}},
),
return_value={"conversation_id": conv_id, "answer": '{"key": "val"}', "sources": [], "tool_calls": [], "thought": "", "error": None, "extra": {"structured": True, "schema": {"type": "object"}}},
):
resp = answer_client.post(
"/api/answer",
@@ -208,14 +200,7 @@ class TestAnswerResourcePost:
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(
conv_id,
"answer text",
[{"title": "src"}],
[{"tool": "t"}],
"thinking...",
None,
),
return_value={"conversation_id": conv_id, "answer": "answer text", "sources": [{"title": "src"}], "tool_calls": [{"tool": "t"}], "thought": "thinking...", "error": None},
):
resp = answer_client.post(
"/api/answer",

View File

@@ -481,10 +481,10 @@ class TestProcessResponseStream:
result = resource.process_response_stream(iter(stream))
assert result[0] == conv_id
assert result[1] == "Hello world"
assert result[2] == [{"title": "doc1"}]
assert result[5] is None
assert result["conversation_id"] == conv_id
assert result["answer"] == "Hello world"
assert result["sources"] == [{"title": "doc1"}]
assert result["error"] is None
def test_handles_stream_error(self, mock_mongo_db, flask_app):
import json
@@ -500,10 +500,8 @@ class TestProcessResponseStream:
result = resource.process_response_stream(iter(stream))
assert len(result) == 6
assert result[0] is None
assert result[4] == "Test error"
assert result[5] is None
assert result["conversation_id"] is None
assert result["error"] == "Test error"
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource

View File

@@ -295,10 +295,8 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[1] == "{}"
# Structured output adds extra tuple element
assert len(result) == 7
assert result[6]["structured"] is True
assert result["answer"] == "{}"
assert result.get("extra", {}).get("structured") is True
def test_handles_tool_calls_event(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
@@ -312,7 +310,7 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[3] == [{"name": "t1"}]
assert result["tool_calls"] == [{"name": "t1"}]
def test_incomplete_stream(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
@@ -323,7 +321,7 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[4] == "Stream ended unexpectedly"
assert result["error"] == "Stream ended unexpectedly"
def test_handles_thought_event(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
@@ -335,7 +333,7 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[4] == "thinking..."
assert result["thought"] == "thinking..."
@pytest.mark.unit