diff --git a/application/api/v1/translator.py b/application/api/v1/translator.py index 3d7717ff..9212bea3 100644 --- a/application/api/v1/translator.py +++ b/application/api/v1/translator.py @@ -7,9 +7,22 @@ This module handles: """ import json +import re import time from typing import Any, Dict, List, Optional +# Pattern matching internal tool-id suffixes (e.g. _ct0, _ct12) +_TOOL_SUFFIX_RE = re.compile(r"_ct\d+$") + + +def _strip_tool_suffix(name: str) -> str: + """Remove internal tool-id suffix from a tool name for client responses. + + Internally tools are named ``action_ct0`` so the LLM can route calls. + Standard OpenAI clients expect the original registered name back. + """ + return _TOOL_SUFFIX_RE.sub("", name) + # --------------------------------------------------------------------------- # Request translation @@ -119,6 +132,8 @@ def translate_request( if is_continuation(messages): tool_actions = extract_tool_results(messages) conversation_id = extract_conversation_id(messages) + if not conversation_id: + conversation_id = data.get("conversation_id") result = { "conversation_id": conversation_id, "tool_actions": tool_actions, @@ -199,7 +214,9 @@ def translate_response( "id": tc.get("call_id", ""), "type": "function", "function": { - "name": tc.get("name", tc.get("action_name", "")), + "name": _strip_tool_suffix( + tc.get("action_name", tc.get("name", "")) + ), "arguments": ( json.dumps(tc["arguments"]) if isinstance(tc.get("arguments"), dict) @@ -341,7 +358,9 @@ def translate_stream_event( "id": tc_data.get("call_id", ""), "type": "function", "function": { - "name": tc_data.get("action_name", ""), + "name": _strip_tool_suffix( + tc_data.get("action_name", "") + ), "arguments": args_str, }, }], diff --git a/tests/agents/test_tool_executor.py b/tests/agents/test_tool_executor.py index 96be815c..bb3038fe 100644 --- a/tests/agents/test_tool_executor.py +++ b/tests/agents/test_tool_executor.py @@ -128,6 +128,77 @@ class TestToolExecutorPrepare: assert "value" not in result["properties"]["query"] +@pytest.mark.unit +class TestCheckPause: + + def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = arguments + call.thought_signature = None + return call + + def test_client_side_tool_returns_suffixed_name(self, monkeypatch): + """check_pause returns the LLM-facing suffixed name for internal routing.""" + executor = ToolExecutor() + + monkeypatch.setattr( + "application.agents.tool_executor.ToolActionParser", + lambda _cls: Mock( + parse_args=Mock(return_value=("ct0", "write_file", {"path": "test.md"})) + ), + ) + + tools_dict = { + "ct0": { + "name": "write_file", + "client_side": True, + "actions": [ + {"name": "write_file", "description": "Write a file", "active": True, "parameters": {}}, + ], + } + } + + call = self._make_call(name="write_file_ct0") + result = executor.check_pause(tools_dict, call, "MockLLM") + + assert result is not None + # name keeps the suffix for LLM message reconstruction during continuation + assert result["name"] == "write_file_ct0" + # action_name is the clean parsed name + assert result["action_name"] == "write_file" + assert result["tool_id"] == "ct0" + + def test_approval_required_returns_suffixed_name(self, monkeypatch): + """check_pause for approval-required tools also returns suffixed name.""" + executor = ToolExecutor() + + monkeypatch.setattr( + "application.agents.tool_executor.ToolActionParser", + lambda _cls: Mock( + parse_args=Mock(return_value=("t1", "delete_all", {})) + ), + ) + + tools_dict = { + "t1": { + "name": "dangerous_tool", + "actions": [ + {"name": "delete_all", "description": "Deletes everything", "active": True, + "require_approval": True, "parameters": {}}, + ], + } + } + + call = self._make_call(name="delete_all_t1") + result = executor.check_pause(tools_dict, call, "MockLLM") + + assert result is not None + assert result["name"] == "delete_all_t1" + assert result["action_name"] == "delete_all" + + @pytest.mark.unit class TestToolExecutorExecute: diff --git a/tests/integration/test_v1_tool_calls.py b/tests/integration/test_v1_tool_calls.py new file mode 100644 index 00000000..af0f650a --- /dev/null +++ b/tests/integration/test_v1_tool_calls.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python3 +r""" +Integration tests for the /v1/ chat completions API — client tool-call flow. + +Tests the full lifecycle: +1. Send request with client tools → LLM triggers a tool call +2. Verify response returns clean tool names (no internal _ct\d+ suffix) +3. Send continuation with tool results + top-level conversation_id +4. Verify the continuation completes successfully + +Usage: + python tests/integration/test_v1_tool_calls.py + python tests/integration/test_v1_tool_calls.py --base-url http://localhost:7091 +""" + +import json as json_module +import re +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import requests + +_THIS_DIR = Path(__file__).parent +_TESTS_DIR = _THIS_DIR.parent +_ROOT_DIR = _TESTS_DIR.parent +if str(_ROOT_DIR) not in sys.path: + sys.path.insert(0, str(_ROOT_DIR)) + +from tests.integration.base import DocsGPTTestBase, create_client_from_args + +# Internal suffix pattern that should NOT appear in client responses +_CT_SUFFIX_RE = re.compile(r"_ct\d+$") + + +class V1ToolCallTests(DocsGPTTestBase): + """Integration tests for /v1/ client tool-call flows.""" + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + + def get_or_create_agent_key(self) -> Optional[str]: + """Get or create a test agent and return its API key.""" + if hasattr(self, "_agent_key") and self._agent_key: + return self._agent_key + + payload = { + "name": f"V1 ToolCall Test {int(time.time())}", + "description": "Integration test agent for tool-call flow", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "classic", + "status": "published", + "source": "default", + } + + try: + response = self.post("/api/create_agent", json=payload, timeout=10) + if response.status_code in [200, 201]: + result = response.json() + api_key = result.get("key") + self._agent_id = result.get("id") + if api_key: + self._agent_key = api_key + self.print_info(f"Created test agent with key: {api_key[:8]}...") + return api_key + except Exception as e: + self.print_error(f"Failed to create agent: {e}") + + return None + + def _v1_headers(self, api_key: str) -> dict: + return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + # A simple client tool definition in OpenAI format + _CLIENT_TOOLS = [ + { + "type": "function", + "function": { + "name": "create", + "description": "Create a new todo item", + "parameters": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "The title of the new todo item", + } + }, + "required": ["title"], + }, + }, + } + ] + + def _send_streaming_request( + self, + api_key: str, + messages: List[Dict], + tools: Optional[List[Dict]] = None, + conversation_id: Optional[str] = None, + ) -> Tuple[List[Dict], str, Optional[Dict]]: + """Send a streaming request and collect all events. + + Returns: + (all_chunks, full_content, tool_call_info) + tool_call_info is a dict with 'name', 'arguments', 'call_id' + if the response paused for a client tool call, else None. + """ + body: Dict[str, Any] = { + "messages": messages, + "stream": True, + } + if tools: + body["tools"] = tools + if conversation_id: + body["conversation_id"] = conversation_id + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json=body, + headers=self._v1_headers(api_key), + stream=True, + timeout=120, + ) + + if response.status_code != 200: + raise RuntimeError( + f"Expected 200, got {response.status_code}: {response.text[:300]}" + ) + + chunks: List[Dict] = [] + content_pieces: List[str] = [] + tool_call_info: Optional[Dict] = None + conversation_id_from_response: Optional[str] = None + + for line in response.iter_lines(): + if not line: + continue + line_str = line.decode("utf-8") + if not line_str.startswith("data: "): + continue + + data_str = line_str[6:] + if data_str.strip() == "[DONE]": + break + + try: + chunk = json_module.loads(data_str) + chunks.append(chunk) + + # Standard chunks + if "choices" in chunk: + delta = chunk["choices"][0].get("delta", {}) + if "content" in delta: + content_pieces.append(delta["content"]) + + # Tool call delta + if "tool_calls" in delta: + tc = delta["tool_calls"][0] + tool_call_info = { + "call_id": tc.get("id", ""), + "name": tc["function"]["name"], + "arguments": tc["function"].get("arguments", "{}"), + } + + # Extension chunks + if "docsgpt" in chunk: + ext = chunk["docsgpt"] + if ext.get("type") == "id": + conversation_id_from_response = ext.get("conversation_id") + + except json_module.JSONDecodeError: + pass + + full_content = "".join(content_pieces) + + # Attach conversation_id to tool_call_info for convenience + if tool_call_info and conversation_id_from_response: + tool_call_info["conversation_id"] = conversation_id_from_response + + return chunks, full_content, tool_call_info + + def _send_non_streaming_request( + self, + api_key: str, + messages: List[Dict], + tools: Optional[List[Dict]] = None, + conversation_id: Optional[str] = None, + ) -> Dict: + """Send a non-streaming request and return parsed JSON.""" + body: Dict[str, Any] = { + "messages": messages, + "stream": False, + } + if tools: + body["tools"] = tools + if conversation_id: + body["conversation_id"] = conversation_id + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json=body, + headers=self._v1_headers(api_key), + timeout=120, + ) + + if response.status_code != 200: + raise RuntimeError( + f"Expected 200, got {response.status_code}: {response.text[:300]}" + ) + + return response.json() + + # ------------------------------------------------------------------------- + # Tests + # ------------------------------------------------------------------------- + + def test_streaming_tool_call_clean_name(self) -> bool: + """Streaming: tool names returned to client must not have _ct suffixes.""" + test_name = "v1 streaming tool call - clean name" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Test integration'. Call the tool now."}, + ] + chunks, content, tool_call_info = self._send_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + if not tool_call_info: + # LLM didn't trigger the tool — could happen, not a failure of our code + self.print_warning("LLM did not trigger a tool call (may need prompt tuning)") + self.print_info(f"Got text response instead: {content[:100]}") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + tool_name = tool_call_info["name"] + self.print_info(f"Tool call name: {tool_name}") + + has_suffix = bool(_CT_SUFFIX_RE.search(tool_name)) + if has_suffix: + self.print_error(f"Tool name has internal suffix: {tool_name}") + self.record_result(test_name, False, f"Suffix leak: {tool_name}") + return False + + self.print_success(f"Tool name is clean: {tool_name}") + self.record_result(test_name, True, f"Clean name: {tool_name}") + return True + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_non_streaming_tool_call_clean_name(self) -> bool: + """Non-streaming: tool names returned to client must not have _ct suffixes.""" + test_name = "v1 non-streaming tool call - clean name" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Test non-stream'. Call the tool now."}, + ] + data = self._send_non_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + message = data["choices"][0]["message"] + tool_calls = message.get("tool_calls") + + if not tool_calls: + content = message.get("content", "") + self.print_warning("LLM did not trigger a tool call") + self.print_info(f"Got text response: {content[:100]}") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + tool_name = tool_calls[0]["function"]["name"] + self.print_info(f"Tool call name: {tool_name}") + + has_suffix = bool(_CT_SUFFIX_RE.search(tool_name)) + if has_suffix: + self.print_error(f"Tool name has internal suffix: {tool_name}") + self.record_result(test_name, False, f"Suffix leak: {tool_name}") + return False + + self.print_success(f"Tool name is clean: {tool_name}") + self.record_result(test_name, True, f"Clean name: {tool_name}") + return True + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool: + """Full tool-call round-trip: trigger → get conversation_id → continue with top-level id.""" + test_name = "v1 streaming tool continuation - top-level conversation_id" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + # Step 1: trigger a tool call + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Round trip test'. Call the tool now."}, + ] + chunks, content, tool_call_info = self._send_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + if not tool_call_info: + self.print_warning("LLM did not trigger a tool call") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + conversation_id = tool_call_info.get("conversation_id") + if not conversation_id: + self.print_error("No conversation_id returned in stream") + self.record_result(test_name, False, "Missing conversation_id") + return False + + self.print_info(f"Got conversation_id: {conversation_id[:12]}...") + self.print_info(f"Tool call: {tool_call_info['name']}({tool_call_info['arguments']})") + + # Step 2: send continuation with tool result + top-level conversation_id + # (standard OpenAI format — no docsgpt field in assistant message) + continuation_messages = [ + *messages, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_info["call_id"], + "type": "function", + "function": { + "name": tool_call_info["name"], + "arguments": tool_call_info["arguments"], + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": tool_call_info["call_id"], + "content": json_module.dumps({"id": 99, "title": "Round trip test", "status": "created"}), + }, + ] + + chunks2, content2, tool_call_info2 = self._send_streaming_request( + api_key, + continuation_messages, + tools=self._CLIENT_TOOLS, + conversation_id=conversation_id, + ) + + checks = [ + (len(chunks2) > 0, f"continuation returned {len(chunks2)} chunks"), + (bool(content2) or tool_call_info2 is not None, "got content or another tool call"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + if content2: + self.print_info(f"Continuation response: {content2[:150]}") + + self.record_result( + test_name, + all_passed, + "Full round-trip works" if all_passed else "Continuation failed", + ) + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_non_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool: + """Non-streaming full round-trip with top-level conversation_id.""" + test_name = "v1 non-streaming tool continuation - top-level conversation_id" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + # Step 1: trigger a tool call + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Non-stream round trip'. Call the tool now."}, + ] + data = self._send_non_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + message = data["choices"][0]["message"] + tool_calls = message.get("tool_calls") + + if not tool_calls: + self.print_warning("LLM did not trigger a tool call") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + conversation_id = data.get("docsgpt", {}).get("conversation_id") + if not conversation_id: + self.print_error("No conversation_id in response") + self.record_result(test_name, False, "Missing conversation_id") + return False + + tc = tool_calls[0] + self.print_info(f"Got tool call: {tc['function']['name']}") + self.print_info(f"conversation_id: {conversation_id[:12]}...") + + # Step 2: send continuation (standard format, top-level conversation_id) + continuation_messages = [ + *messages, + { + "role": "assistant", + "content": None, + "tool_calls": [tc], + }, + { + "role": "tool", + "tool_call_id": tc["id"], + "content": json_module.dumps({"id": 100, "title": "Non-stream round trip", "status": "created"}), + }, + ] + + data2 = self._send_non_streaming_request( + api_key, + continuation_messages, + tools=self._CLIENT_TOOLS, + conversation_id=conversation_id, + ) + + message2 = data2["choices"][0]["message"] + has_response = bool(message2.get("content")) or bool(message2.get("tool_calls")) + + if has_response: + self.print_success("Continuation returned a response") + content2 = message2.get("content", "") + if content2: + self.print_info(f"Response: {content2[:150]}") + else: + self.print_error("Continuation returned empty response") + + self.record_result( + test_name, + has_response, + "Round-trip works" if has_response else "Empty continuation response", + ) + return has_response + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # Cleanup & Run All + # ------------------------------------------------------------------------- + + def cleanup(self): + if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated: + try: + self.post(f"/api/delete_agent?id={self._agent_id}", json={}) + self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...") + except Exception: + pass + + def run_all(self) -> bool: + self.print_header("V1 Tool-Call Flow Integration Tests") + self.print_info(f"Base URL: {self.base_url}") + self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}") + + try: + # Streaming tests + self.test_streaming_tool_call_clean_name() + time.sleep(1) + + self.test_non_streaming_tool_call_clean_name() + time.sleep(1) + + # Full round-trip tests + self.test_streaming_tool_continuation_with_top_level_conversation_id() + time.sleep(1) + + self.test_non_streaming_tool_continuation_with_top_level_conversation_id() + time.sleep(1) + + finally: + self.cleanup() + + return self.print_summary() + + +def main(): + client = create_client_from_args(V1ToolCallTests, "DocsGPT V1 Tool-Call Integration Tests") + success = client.run_all() + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/test_v1_translator.py b/tests/test_v1_translator.py index b8a5e4c3..1f33e31c 100644 --- a/tests/test_v1_translator.py +++ b/tests/test_v1_translator.py @@ -9,6 +9,7 @@ import json import pytest from application.api.v1.translator import ( + _strip_tool_suffix, convert_history, extract_tool_results, is_continuation, @@ -187,6 +188,39 @@ class TestTranslateRequest: assert len(result["tool_actions"]) == 1 assert result["tool_actions"][0]["call_id"] == "c1" + def test_continuation_with_top_level_conversation_id(self): + """Standard clients send conversation_id at request level, not in messages.""" + data = { + "conversation_id": "conv-top-level", + "messages": [ + {"role": "user", "content": "Do stuff"}, + { + "role": "assistant", + "tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "c1", "content": "done"}, + ], + } + result = translate_request(data, "key") + assert result["conversation_id"] == "conv-top-level" + + def test_continuation_in_message_conversation_id_takes_precedence(self): + """When both in-message and top-level conversation_id exist, in-message wins.""" + data = { + "conversation_id": "conv-top-level", + "messages": [ + {"role": "user", "content": "Do stuff"}, + { + "role": "assistant", + "tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}], + "docsgpt": {"conversation_id": "conv-in-message"}, + }, + {"role": "tool", "tool_call_id": "c1", "content": "done"}, + ], + } + result = translate_request(data, "key") + assert result["conversation_id"] == "conv-in-message" + def test_client_tools_passed_through(self): data = { "messages": [{"role": "user", "content": "Hi"}], @@ -263,6 +297,51 @@ class TestTranslateResponse: ) assert resp["docsgpt"]["tool_calls"] == tool_calls + def test_pending_tool_calls_strips_ct_suffix(self): + """Internal _ct\\d+ suffixes must be stripped from tool names in responses.""" + pending = [ + { + "call_id": "c1", + "name": "get_weather_ct0", + "action_name": "get_weather_ct0", + "arguments": {"city": "SF"}, + } + ] + resp = translate_response( + conversation_id="c1", + answer="", + sources=[], + tool_calls=[], + thought="", + model_name="agent", + pending_tool_calls=pending, + ) + tc = resp["choices"][0]["message"]["tool_calls"][0] + assert tc["function"]["name"] == "get_weather" + + def test_pending_tool_calls_non_ct_suffix_preserved(self): + """Non-client tool suffixes (e.g. _t1) should not be stripped.""" + pending = [ + { + "call_id": "c1", + "name": "search_t1", + "action_name": "search_t1", + "arguments": {"q": "test"}, + } + ] + resp = translate_response( + conversation_id="c1", + answer="", + sources=[], + tool_calls=[], + thought="", + model_name="agent", + pending_tool_calls=pending, + ) + tc = resp["choices"][0]["message"]["tool_calls"][0] + # _t1 is NOT a client-tool suffix (_ct\d+), so it stays + assert tc["function"]["name"] == "search_t1" + def test_pending_tool_calls(self): pending = [ { @@ -366,6 +445,24 @@ class TestTranslateStreamEvent: assert tc["id"] == "c1" assert tc["function"]["name"] == "get_weather" + def test_tool_call_client_execution_strips_ct_suffix(self): + """Internal _ct suffixes must be stripped from streaming tool call names.""" + chunks = translate_stream_event( + { + "type": "tool_call", + "data": { + "call_id": "c1", + "action_name": "create_ct0", + "arguments": {"title": "test"}, + "status": "requires_client_execution", + }, + }, + "chatcmpl-1", "agent", + ) + parsed = json.loads(chunks[0].replace("data: ", "").strip()) + tc = parsed["choices"][0]["delta"]["tool_calls"][0] + assert tc["function"]["name"] == "create" + def test_tool_call_completed(self): chunks = translate_stream_event( { @@ -457,3 +554,30 @@ class TestTranslateStreamEvent: assert "choices" not in parsed # docsgpt key is present assert "docsgpt" in parsed + + +# --------------------------------------------------------------------------- +# _strip_tool_suffix +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestStripToolSuffix: + + def test_strips_ct0(self): + assert _strip_tool_suffix("create_ct0") == "create" + + def test_strips_ct_multi_digit(self): + assert _strip_tool_suffix("write_file_ct12") == "write_file" + + def test_preserves_non_ct_suffix(self): + assert _strip_tool_suffix("search_t1") == "search_t1" + + def test_preserves_plain_name(self): + assert _strip_tool_suffix("get_weather") == "get_weather" + + def test_preserves_empty(self): + assert _strip_tool_suffix("") == "" + + def test_ct_in_middle_not_stripped(self): + assert _strip_tool_suffix("ct0_action") == "ct0_action"