diff --git a/application/agents/base.py b/application/agents/base.py index 4fc53795..4576e98d 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -138,17 +138,12 @@ class BaseAgent(ABC): actions_by_id = {a["call_id"]: a for a in tool_actions} + # Build a single assistant message containing all tool calls so + # the message history matches the format LLM providers expect + # (one assistant message with N tool_calls, followed by N tool results). + tc_objects: List[Dict[str, Any]] = [] for pending in pending_tool_calls: call_id = pending["call_id"] - action = actions_by_id.get(call_id) - if not action: - action = { - "call_id": call_id, - "decision": "denied", - "comment": "No response provided", - } - - # Build the assistant tool-call message in standard format args = pending["arguments"] args_str = ( json.dumps(args) if isinstance(args, dict) else (args or "{}") @@ -163,11 +158,25 @@ class BaseAgent(ABC): } if pending.get("thought_signature"): tc_obj["thought_signature"] = pending["thought_signature"] - messages.append({ - "role": "assistant", - "content": None, - "tool_calls": [tc_obj], - }) + tc_objects.append(tc_obj) + + messages.append({ + "role": "assistant", + "content": None, + "tool_calls": tc_objects, + }) + + # Now process each pending call and append tool result messages + for pending in pending_tool_calls: + call_id = pending["call_id"] + args = pending["arguments"] + action = actions_by_id.get(call_id) + if not action: + action = { + "call_id": call_id, + "decision": "denied", + "comment": "No response provided", + } if action.get("decision") == "approved": # Execute the tool server-side diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py index b6fe288a..5fa7199f 100644 --- a/application/api/answer/routes/answer.py +++ b/application/api/answer/routes/answer.py @@ -126,33 +126,18 @@ class AnswerResource(Resource, BaseAnswerResource): stream_result = self.process_response_stream(stream) - if len(stream_result) == 7: - ( - conversation_id, - response, - sources, - tool_calls, - thought, - error, - extra_info, - ) = stream_result - else: - conversation_id, response, sources, tool_calls, thought, error = ( - stream_result - ) - extra_info = None - - if error: - return make_response({"error": error}, 400) + if stream_result["error"]: + return make_response({"error": stream_result["error"]}, 400) result = { - "conversation_id": conversation_id, - "answer": response, - "sources": sources, - "tool_calls": tool_calls, - "thought": thought, + "conversation_id": stream_result["conversation_id"], + "answer": stream_result["answer"], + "sources": stream_result["sources"], + "tool_calls": stream_result["tool_calls"], + "thought": stream_result["thought"], } + extra_info = stream_result.get("extra") if extra_info: result.update(extra_info) except Exception as e: diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py index 74932d1d..4a152b2a 100644 --- a/application/api/answer/routes/base.py +++ b/application/api/answer/routes/base.py @@ -540,8 +540,13 @@ class BaseAnswerResource: yield f"data: {data}\n\n" return - def process_response_stream(self, stream): - """Process the stream response for non-streaming endpoint""" + def process_response_stream(self, stream) -> Dict[str, Any]: + """Process the stream response for non-streaming endpoint. + + Returns: + Dict with keys: conversation_id, answer, sources, tool_calls, + thought, error, and optional extra. + """ conversation_id = "" response_full = "" source_log_docs = [] @@ -577,7 +582,14 @@ class BaseAnswerResource: thought = event["thought"] elif event["type"] == "error": logger.error(f"Error from stream: {event['error']}") - return None, None, None, None, event["error"], None + return { + "conversation_id": None, + "answer": None, + "sources": None, + "tool_calls": None, + "thought": None, + "error": event["error"], + } elif event["type"] == "end": stream_ended = True except (json.JSONDecodeError, KeyError) as e: @@ -585,30 +597,30 @@ class BaseAnswerResource: continue if not stream_ended: logger.error("Stream ended unexpectedly without an 'end' event.") - return None, None, None, None, "Stream ended unexpectedly", None + return { + "conversation_id": None, + "answer": None, + "sources": None, + "tool_calls": None, + "thought": None, + "error": "Stream ended unexpectedly", + } + + result: Dict[str, Any] = { + "conversation_id": conversation_id, + "answer": response_full, + "sources": source_log_docs, + "tool_calls": tool_calls, + "thought": thought, + "error": None, + } if pending_tool_calls is not None: - return ( - conversation_id, - response_full, - source_log_docs, - tool_calls, - thought, - None, - {"pending_tool_calls": pending_tool_calls}, - ) - - result = ( - conversation_id, - response_full, - source_log_docs, - tool_calls, - thought, - None, - ) + result["extra"] = {"pending_tool_calls": pending_tool_calls} if is_structured: - result = result + ({"structured": True, "schema": schema_info},) + result["extra"] = {"structured": True, "schema": schema_info} + return result def error_stream_generate(self, err_response): diff --git a/application/api/v1/routes.py b/application/api/v1/routes.py index df7930d6..d773d962 100644 --- a/application/api/v1/routes.py +++ b/application/api/v1/routes.py @@ -36,16 +36,21 @@ def _extract_bearer_token() -> Optional[str]: return None -def _get_model_name(api_key: str) -> str: - """Look up agent name for display as model name.""" +def _lookup_agent(api_key: str) -> Optional[Dict]: + """Look up the agent document for this API key.""" try: mongo = MongoDB.get_client() db = mongo[settings.MONGO_DB_NAME] - agent = db["agents"].find_one({"key": api_key}) - if agent: - return agent.get("name", api_key) + return db["agents"].find_one({"key": api_key}) except Exception: - pass + logger.warning("Failed to look up agent for API key", exc_info=True) + return None + + +def _get_model_name(agent: Optional[Dict], api_key: str) -> str: + """Return agent name for display as model name.""" + if agent: + return agent.get("name", api_key) return api_key @@ -72,7 +77,8 @@ def chat_completions(): ) is_stream = data.get("stream", False) - model_name = _get_model_name(api_key) + agent_doc = _lookup_agent(api_key) + model_name = _get_model_name(agent_doc, api_key) try: internal_data = translate_request(data, api_key) @@ -83,8 +89,10 @@ def chat_completions(): 400, ) - # Use the api_key as decoded token for agent auth - decoded_token = {"sub": "api_key_user"} + # Link decoded_token to the agent's owner so continuation state, + # logs, and tool execution use the correct user identity. + agent_user = agent_doc.get("user") if agent_doc else None + decoded_token = {"sub": agent_user or "api_key_user"} try: processor = StreamProcessor(internal_data, decoded_token) @@ -232,26 +240,21 @@ def _non_stream_response( result = helper.process_response_stream(stream) - if len(result) == 7: - conversation_id, answer, sources, tool_calls, thought, error, extra = result - else: - conversation_id, answer, sources, tool_calls, thought, error = result - extra = None - - if error: + if result["error"]: return make_response( - jsonify({"error": {"message": error, "type": "server_error"}}), + jsonify({"error": {"message": result["error"], "type": "server_error"}}), 500, ) + extra = result.get("extra") pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None response = translate_response( - conversation_id=conversation_id, - answer=answer or "", - sources=sources, - tool_calls=tool_calls, - thought=thought or "", + conversation_id=result["conversation_id"], + answer=result["answer"] or "", + sources=result["sources"], + tool_calls=result["tool_calls"], + thought=result["thought"] or "", model_name=model_name, pending_tool_calls=pending, ) diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index c7fcdf44..759635b6 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -411,15 +411,23 @@ function translateV1ChunkToInternalEvents( if (delta.tool_calls) { for (const tc of delta.tool_calls) { + let parsedArgs: Record = {}; + if (tc.function?.arguments) { + try { + parsedArgs = JSON.parse(tc.function.arguments); + } catch { + // Arguments may arrive as fragments during streaming; + // keep the raw string so downstream can accumulate it. + parsedArgs = { _raw: tc.function.arguments }; + } + } events.push({ type: 'tool_call', data: { call_id: tc.id, action_name: tc.function?.name || '', tool_name: tc.function?.name || '', - arguments: tc.function?.arguments - ? JSON.parse(tc.function.arguments) - : {}, + arguments: parsedArgs, status: 'requires_client_execution', }, }); diff --git a/tests/api/answer/routes/test_answer.py b/tests/api/answer/routes/test_answer.py index 53f2525d..4e4b7a8b 100644 --- a/tests/api/answer/routes/test_answer.py +++ b/tests/api/answer/routes/test_answer.py @@ -73,7 +73,7 @@ class TestAnswerResourcePost: ), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=(conv_id, "Hello", [], [], "", None), + return_value={"conversation_id": conv_id, "answer": "Hello", "sources": [], "tool_calls": [], "thought": "", "error": None}, ): resp = answer_client.post( "/api/answer", @@ -129,7 +129,7 @@ class TestAnswerResourcePost: return_value=iter([]), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=(None, None, None, None, None, "Stream error"), + return_value={"conversation_id": None, "answer": None, "sources": None, "tool_calls": None, "thought": None, "error": "Stream error"}, ): resp = answer_client.post( "/api/answer", @@ -173,15 +173,7 @@ class TestAnswerResourcePost: return_value=iter([]), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=( - conv_id, - '{"key": "val"}', - [], - [], - "", - None, - {"structured": True, "schema": {"type": "object"}}, - ), + return_value={"conversation_id": conv_id, "answer": '{"key": "val"}', "sources": [], "tool_calls": [], "thought": "", "error": None, "extra": {"structured": True, "schema": {"type": "object"}}}, ): resp = answer_client.post( "/api/answer", @@ -208,14 +200,7 @@ class TestAnswerResourcePost: return_value=iter([]), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=( - conv_id, - "answer text", - [{"title": "src"}], - [{"tool": "t"}], - "thinking...", - None, - ), + return_value={"conversation_id": conv_id, "answer": "answer text", "sources": [{"title": "src"}], "tool_calls": [{"tool": "t"}], "thought": "thinking...", "error": None}, ): resp = answer_client.post( "/api/answer", diff --git a/tests/api/answer/routes/test_base.py b/tests/api/answer/routes/test_base.py index eee2ab31..95b7e423 100644 --- a/tests/api/answer/routes/test_base.py +++ b/tests/api/answer/routes/test_base.py @@ -481,10 +481,10 @@ class TestProcessResponseStream: result = resource.process_response_stream(iter(stream)) - assert result[0] == conv_id - assert result[1] == "Hello world" - assert result[2] == [{"title": "doc1"}] - assert result[5] is None + assert result["conversation_id"] == conv_id + assert result["answer"] == "Hello world" + assert result["sources"] == [{"title": "doc1"}] + assert result["error"] is None def test_handles_stream_error(self, mock_mongo_db, flask_app): import json @@ -500,10 +500,8 @@ class TestProcessResponseStream: result = resource.process_response_stream(iter(stream)) - assert len(result) == 6 - assert result[0] is None - assert result[4] == "Test error" - assert result[5] is None + assert result["conversation_id"] is None + assert result["error"] == "Test error" def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource diff --git a/tests/api/answer/test_base_routes.py b/tests/api/answer/test_base_routes.py index dc6a19a5..d47076ad 100644 --- a/tests/api/answer/test_base_routes.py +++ b/tests/api/answer/test_base_routes.py @@ -295,10 +295,8 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "end"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[1] == "{}" - # Structured output adds extra tuple element - assert len(result) == 7 - assert result[6]["structured"] is True + assert result["answer"] == "{}" + assert result.get("extra", {}).get("structured") is True def test_handles_tool_calls_event(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource @@ -312,7 +310,7 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "end"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[3] == [{"name": "t1"}] + assert result["tool_calls"] == [{"name": "t1"}] def test_incomplete_stream(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource @@ -323,7 +321,7 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[4] == "Stream ended unexpectedly" + assert result["error"] == "Stream ended unexpectedly" def test_handles_thought_event(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource @@ -335,7 +333,7 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "end"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[4] == "thinking..." + assert result["thought"] == "thinking..." @pytest.mark.unit