fix: tool mapping

This commit is contained in:
Alex
2026-04-03 13:32:55 +01:00
parent be7da983e7
commit 1af09f114d
4 changed files with 755 additions and 2 deletions

View File

@@ -7,9 +7,22 @@ This module handles:
"""
import json
import re
import time
from typing import Any, Dict, List, Optional
# Pattern matching internal tool-id suffixes (e.g. _ct0, _ct12)
_TOOL_SUFFIX_RE = re.compile(r"_ct\d+$")
def _strip_tool_suffix(name: str) -> str:
"""Remove internal tool-id suffix from a tool name for client responses.
Internally tools are named ``action_ct0`` so the LLM can route calls.
Standard OpenAI clients expect the original registered name back.
"""
return _TOOL_SUFFIX_RE.sub("", name)
# ---------------------------------------------------------------------------
# Request translation
@@ -119,6 +132,8 @@ def translate_request(
if is_continuation(messages):
tool_actions = extract_tool_results(messages)
conversation_id = extract_conversation_id(messages)
if not conversation_id:
conversation_id = data.get("conversation_id")
result = {
"conversation_id": conversation_id,
"tool_actions": tool_actions,
@@ -199,7 +214,9 @@ def translate_response(
"id": tc.get("call_id", ""),
"type": "function",
"function": {
"name": tc.get("name", tc.get("action_name", "")),
"name": _strip_tool_suffix(
tc.get("action_name", tc.get("name", ""))
),
"arguments": (
json.dumps(tc["arguments"])
if isinstance(tc.get("arguments"), dict)
@@ -341,7 +358,9 @@ def translate_stream_event(
"id": tc_data.get("call_id", ""),
"type": "function",
"function": {
"name": tc_data.get("action_name", ""),
"name": _strip_tool_suffix(
tc_data.get("action_name", "")
),
"arguments": args_str,
},
}],

View File

@@ -128,6 +128,77 @@ class TestToolExecutorPrepare:
assert "value" not in result["properties"]["query"]
@pytest.mark.unit
class TestCheckPause:
def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"):
call = Mock()
call.name = name
call.id = call_id
call.arguments = arguments
call.thought_signature = None
return call
def test_client_side_tool_returns_suffixed_name(self, monkeypatch):
"""check_pause returns the LLM-facing suffixed name for internal routing."""
executor = ToolExecutor()
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
parse_args=Mock(return_value=("ct0", "write_file", {"path": "test.md"}))
),
)
tools_dict = {
"ct0": {
"name": "write_file",
"client_side": True,
"actions": [
{"name": "write_file", "description": "Write a file", "active": True, "parameters": {}},
],
}
}
call = self._make_call(name="write_file_ct0")
result = executor.check_pause(tools_dict, call, "MockLLM")
assert result is not None
# name keeps the suffix for LLM message reconstruction during continuation
assert result["name"] == "write_file_ct0"
# action_name is the clean parsed name
assert result["action_name"] == "write_file"
assert result["tool_id"] == "ct0"
def test_approval_required_returns_suffixed_name(self, monkeypatch):
"""check_pause for approval-required tools also returns suffixed name."""
executor = ToolExecutor()
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
parse_args=Mock(return_value=("t1", "delete_all", {}))
),
)
tools_dict = {
"t1": {
"name": "dangerous_tool",
"actions": [
{"name": "delete_all", "description": "Deletes everything", "active": True,
"require_approval": True, "parameters": {}},
],
}
}
call = self._make_call(name="delete_all_t1")
result = executor.check_pause(tools_dict, call, "MockLLM")
assert result is not None
assert result["name"] == "delete_all_t1"
assert result["action_name"] == "delete_all"
@pytest.mark.unit
class TestToolExecutorExecute:

View File

@@ -0,0 +1,539 @@
#!/usr/bin/env python3
r"""
Integration tests for the /v1/ chat completions API — client tool-call flow.
Tests the full lifecycle:
1. Send request with client tools → LLM triggers a tool call
2. Verify response returns clean tool names (no internal _ct\d+ suffix)
3. Send continuation with tool results + top-level conversation_id
4. Verify the continuation completes successfully
Usage:
python tests/integration/test_v1_tool_calls.py
python tests/integration/test_v1_tool_calls.py --base-url http://localhost:7091
"""
import json as json_module
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import requests
_THIS_DIR = Path(__file__).parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if str(_ROOT_DIR) not in sys.path:
sys.path.insert(0, str(_ROOT_DIR))
from tests.integration.base import DocsGPTTestBase, create_client_from_args
# Internal suffix pattern that should NOT appear in client responses
_CT_SUFFIX_RE = re.compile(r"_ct\d+$")
class V1ToolCallTests(DocsGPTTestBase):
"""Integration tests for /v1/ client tool-call flows."""
# -------------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------------
def get_or_create_agent_key(self) -> Optional[str]:
"""Get or create a test agent and return its API key."""
if hasattr(self, "_agent_key") and self._agent_key:
return self._agent_key
payload = {
"name": f"V1 ToolCall Test {int(time.time())}",
"description": "Integration test agent for tool-call flow",
"prompt_id": "default",
"chunks": 2,
"retriever": "classic",
"agent_type": "classic",
"status": "published",
"source": "default",
}
try:
response = self.post("/api/create_agent", json=payload, timeout=10)
if response.status_code in [200, 201]:
result = response.json()
api_key = result.get("key")
self._agent_id = result.get("id")
if api_key:
self._agent_key = api_key
self.print_info(f"Created test agent with key: {api_key[:8]}...")
return api_key
except Exception as e:
self.print_error(f"Failed to create agent: {e}")
return None
def _v1_headers(self, api_key: str) -> dict:
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
# A simple client tool definition in OpenAI format
_CLIENT_TOOLS = [
{
"type": "function",
"function": {
"name": "create",
"description": "Create a new todo item",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "The title of the new todo item",
}
},
"required": ["title"],
},
},
}
]
def _send_streaming_request(
self,
api_key: str,
messages: List[Dict],
tools: Optional[List[Dict]] = None,
conversation_id: Optional[str] = None,
) -> Tuple[List[Dict], str, Optional[Dict]]:
"""Send a streaming request and collect all events.
Returns:
(all_chunks, full_content, tool_call_info)
tool_call_info is a dict with 'name', 'arguments', 'call_id'
if the response paused for a client tool call, else None.
"""
body: Dict[str, Any] = {
"messages": messages,
"stream": True,
}
if tools:
body["tools"] = tools
if conversation_id:
body["conversation_id"] = conversation_id
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json=body,
headers=self._v1_headers(api_key),
stream=True,
timeout=120,
)
if response.status_code != 200:
raise RuntimeError(
f"Expected 200, got {response.status_code}: {response.text[:300]}"
)
chunks: List[Dict] = []
content_pieces: List[str] = []
tool_call_info: Optional[Dict] = None
conversation_id_from_response: Optional[str] = None
for line in response.iter_lines():
if not line:
continue
line_str = line.decode("utf-8")
if not line_str.startswith("data: "):
continue
data_str = line_str[6:]
if data_str.strip() == "[DONE]":
break
try:
chunk = json_module.loads(data_str)
chunks.append(chunk)
# Standard chunks
if "choices" in chunk:
delta = chunk["choices"][0].get("delta", {})
if "content" in delta:
content_pieces.append(delta["content"])
# Tool call delta
if "tool_calls" in delta:
tc = delta["tool_calls"][0]
tool_call_info = {
"call_id": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"].get("arguments", "{}"),
}
# Extension chunks
if "docsgpt" in chunk:
ext = chunk["docsgpt"]
if ext.get("type") == "id":
conversation_id_from_response = ext.get("conversation_id")
except json_module.JSONDecodeError:
pass
full_content = "".join(content_pieces)
# Attach conversation_id to tool_call_info for convenience
if tool_call_info and conversation_id_from_response:
tool_call_info["conversation_id"] = conversation_id_from_response
return chunks, full_content, tool_call_info
def _send_non_streaming_request(
self,
api_key: str,
messages: List[Dict],
tools: Optional[List[Dict]] = None,
conversation_id: Optional[str] = None,
) -> Dict:
"""Send a non-streaming request and return parsed JSON."""
body: Dict[str, Any] = {
"messages": messages,
"stream": False,
}
if tools:
body["tools"] = tools
if conversation_id:
body["conversation_id"] = conversation_id
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json=body,
headers=self._v1_headers(api_key),
timeout=120,
)
if response.status_code != 200:
raise RuntimeError(
f"Expected 200, got {response.status_code}: {response.text[:300]}"
)
return response.json()
# -------------------------------------------------------------------------
# Tests
# -------------------------------------------------------------------------
def test_streaming_tool_call_clean_name(self) -> bool:
"""Streaming: tool names returned to client must not have _ct suffixes."""
test_name = "v1 streaming tool call - clean name"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Test integration'. Call the tool now."},
]
chunks, content, tool_call_info = self._send_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
if not tool_call_info:
# LLM didn't trigger the tool — could happen, not a failure of our code
self.print_warning("LLM did not trigger a tool call (may need prompt tuning)")
self.print_info(f"Got text response instead: {content[:100]}")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
tool_name = tool_call_info["name"]
self.print_info(f"Tool call name: {tool_name}")
has_suffix = bool(_CT_SUFFIX_RE.search(tool_name))
if has_suffix:
self.print_error(f"Tool name has internal suffix: {tool_name}")
self.record_result(test_name, False, f"Suffix leak: {tool_name}")
return False
self.print_success(f"Tool name is clean: {tool_name}")
self.record_result(test_name, True, f"Clean name: {tool_name}")
return True
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_non_streaming_tool_call_clean_name(self) -> bool:
"""Non-streaming: tool names returned to client must not have _ct suffixes."""
test_name = "v1 non-streaming tool call - clean name"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Test non-stream'. Call the tool now."},
]
data = self._send_non_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
message = data["choices"][0]["message"]
tool_calls = message.get("tool_calls")
if not tool_calls:
content = message.get("content", "")
self.print_warning("LLM did not trigger a tool call")
self.print_info(f"Got text response: {content[:100]}")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
tool_name = tool_calls[0]["function"]["name"]
self.print_info(f"Tool call name: {tool_name}")
has_suffix = bool(_CT_SUFFIX_RE.search(tool_name))
if has_suffix:
self.print_error(f"Tool name has internal suffix: {tool_name}")
self.record_result(test_name, False, f"Suffix leak: {tool_name}")
return False
self.print_success(f"Tool name is clean: {tool_name}")
self.record_result(test_name, True, f"Clean name: {tool_name}")
return True
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool:
"""Full tool-call round-trip: trigger → get conversation_id → continue with top-level id."""
test_name = "v1 streaming tool continuation - top-level conversation_id"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
# Step 1: trigger a tool call
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Round trip test'. Call the tool now."},
]
chunks, content, tool_call_info = self._send_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
if not tool_call_info:
self.print_warning("LLM did not trigger a tool call")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
conversation_id = tool_call_info.get("conversation_id")
if not conversation_id:
self.print_error("No conversation_id returned in stream")
self.record_result(test_name, False, "Missing conversation_id")
return False
self.print_info(f"Got conversation_id: {conversation_id[:12]}...")
self.print_info(f"Tool call: {tool_call_info['name']}({tool_call_info['arguments']})")
# Step 2: send continuation with tool result + top-level conversation_id
# (standard OpenAI format — no docsgpt field in assistant message)
continuation_messages = [
*messages,
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": tool_call_info["call_id"],
"type": "function",
"function": {
"name": tool_call_info["name"],
"arguments": tool_call_info["arguments"],
},
}
],
},
{
"role": "tool",
"tool_call_id": tool_call_info["call_id"],
"content": json_module.dumps({"id": 99, "title": "Round trip test", "status": "created"}),
},
]
chunks2, content2, tool_call_info2 = self._send_streaming_request(
api_key,
continuation_messages,
tools=self._CLIENT_TOOLS,
conversation_id=conversation_id,
)
checks = [
(len(chunks2) > 0, f"continuation returned {len(chunks2)} chunks"),
(bool(content2) or tool_call_info2 is not None, "got content or another tool call"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
if content2:
self.print_info(f"Continuation response: {content2[:150]}")
self.record_result(
test_name,
all_passed,
"Full round-trip works" if all_passed else "Continuation failed",
)
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_non_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool:
"""Non-streaming full round-trip with top-level conversation_id."""
test_name = "v1 non-streaming tool continuation - top-level conversation_id"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
# Step 1: trigger a tool call
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Non-stream round trip'. Call the tool now."},
]
data = self._send_non_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
message = data["choices"][0]["message"]
tool_calls = message.get("tool_calls")
if not tool_calls:
self.print_warning("LLM did not trigger a tool call")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
conversation_id = data.get("docsgpt", {}).get("conversation_id")
if not conversation_id:
self.print_error("No conversation_id in response")
self.record_result(test_name, False, "Missing conversation_id")
return False
tc = tool_calls[0]
self.print_info(f"Got tool call: {tc['function']['name']}")
self.print_info(f"conversation_id: {conversation_id[:12]}...")
# Step 2: send continuation (standard format, top-level conversation_id)
continuation_messages = [
*messages,
{
"role": "assistant",
"content": None,
"tool_calls": [tc],
},
{
"role": "tool",
"tool_call_id": tc["id"],
"content": json_module.dumps({"id": 100, "title": "Non-stream round trip", "status": "created"}),
},
]
data2 = self._send_non_streaming_request(
api_key,
continuation_messages,
tools=self._CLIENT_TOOLS,
conversation_id=conversation_id,
)
message2 = data2["choices"][0]["message"]
has_response = bool(message2.get("content")) or bool(message2.get("tool_calls"))
if has_response:
self.print_success("Continuation returned a response")
content2 = message2.get("content", "")
if content2:
self.print_info(f"Response: {content2[:150]}")
else:
self.print_error("Continuation returned empty response")
self.record_result(
test_name,
has_response,
"Round-trip works" if has_response else "Empty continuation response",
)
return has_response
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Cleanup & Run All
# -------------------------------------------------------------------------
def cleanup(self):
if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated:
try:
self.post(f"/api/delete_agent?id={self._agent_id}", json={})
self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...")
except Exception:
pass
def run_all(self) -> bool:
self.print_header("V1 Tool-Call Flow Integration Tests")
self.print_info(f"Base URL: {self.base_url}")
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
try:
# Streaming tests
self.test_streaming_tool_call_clean_name()
time.sleep(1)
self.test_non_streaming_tool_call_clean_name()
time.sleep(1)
# Full round-trip tests
self.test_streaming_tool_continuation_with_top_level_conversation_id()
time.sleep(1)
self.test_non_streaming_tool_continuation_with_top_level_conversation_id()
time.sleep(1)
finally:
self.cleanup()
return self.print_summary()
def main():
client = create_client_from_args(V1ToolCallTests, "DocsGPT V1 Tool-Call Integration Tests")
success = client.run_all()
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()

View File

@@ -9,6 +9,7 @@ import json
import pytest
from application.api.v1.translator import (
_strip_tool_suffix,
convert_history,
extract_tool_results,
is_continuation,
@@ -187,6 +188,39 @@ class TestTranslateRequest:
assert len(result["tool_actions"]) == 1
assert result["tool_actions"][0]["call_id"] == "c1"
def test_continuation_with_top_level_conversation_id(self):
"""Standard clients send conversation_id at request level, not in messages."""
data = {
"conversation_id": "conv-top-level",
"messages": [
{"role": "user", "content": "Do stuff"},
{
"role": "assistant",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}],
},
{"role": "tool", "tool_call_id": "c1", "content": "done"},
],
}
result = translate_request(data, "key")
assert result["conversation_id"] == "conv-top-level"
def test_continuation_in_message_conversation_id_takes_precedence(self):
"""When both in-message and top-level conversation_id exist, in-message wins."""
data = {
"conversation_id": "conv-top-level",
"messages": [
{"role": "user", "content": "Do stuff"},
{
"role": "assistant",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}],
"docsgpt": {"conversation_id": "conv-in-message"},
},
{"role": "tool", "tool_call_id": "c1", "content": "done"},
],
}
result = translate_request(data, "key")
assert result["conversation_id"] == "conv-in-message"
def test_client_tools_passed_through(self):
data = {
"messages": [{"role": "user", "content": "Hi"}],
@@ -263,6 +297,51 @@ class TestTranslateResponse:
)
assert resp["docsgpt"]["tool_calls"] == tool_calls
def test_pending_tool_calls_strips_ct_suffix(self):
"""Internal _ct\\d+ suffixes must be stripped from tool names in responses."""
pending = [
{
"call_id": "c1",
"name": "get_weather_ct0",
"action_name": "get_weather_ct0",
"arguments": {"city": "SF"},
}
]
resp = translate_response(
conversation_id="c1",
answer="",
sources=[],
tool_calls=[],
thought="",
model_name="agent",
pending_tool_calls=pending,
)
tc = resp["choices"][0]["message"]["tool_calls"][0]
assert tc["function"]["name"] == "get_weather"
def test_pending_tool_calls_non_ct_suffix_preserved(self):
"""Non-client tool suffixes (e.g. _t1) should not be stripped."""
pending = [
{
"call_id": "c1",
"name": "search_t1",
"action_name": "search_t1",
"arguments": {"q": "test"},
}
]
resp = translate_response(
conversation_id="c1",
answer="",
sources=[],
tool_calls=[],
thought="",
model_name="agent",
pending_tool_calls=pending,
)
tc = resp["choices"][0]["message"]["tool_calls"][0]
# _t1 is NOT a client-tool suffix (_ct\d+), so it stays
assert tc["function"]["name"] == "search_t1"
def test_pending_tool_calls(self):
pending = [
{
@@ -366,6 +445,24 @@ class TestTranslateStreamEvent:
assert tc["id"] == "c1"
assert tc["function"]["name"] == "get_weather"
def test_tool_call_client_execution_strips_ct_suffix(self):
"""Internal _ct suffixes must be stripped from streaming tool call names."""
chunks = translate_stream_event(
{
"type": "tool_call",
"data": {
"call_id": "c1",
"action_name": "create_ct0",
"arguments": {"title": "test"},
"status": "requires_client_execution",
},
},
"chatcmpl-1", "agent",
)
parsed = json.loads(chunks[0].replace("data: ", "").strip())
tc = parsed["choices"][0]["delta"]["tool_calls"][0]
assert tc["function"]["name"] == "create"
def test_tool_call_completed(self):
chunks = translate_stream_event(
{
@@ -457,3 +554,30 @@ class TestTranslateStreamEvent:
assert "choices" not in parsed
# docsgpt key is present
assert "docsgpt" in parsed
# ---------------------------------------------------------------------------
# _strip_tool_suffix
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestStripToolSuffix:
def test_strips_ct0(self):
assert _strip_tool_suffix("create_ct0") == "create"
def test_strips_ct_multi_digit(self):
assert _strip_tool_suffix("write_file_ct12") == "write_file"
def test_preserves_non_ct_suffix(self):
assert _strip_tool_suffix("search_t1") == "search_t1"
def test_preserves_plain_name(self):
assert _strip_tool_suffix("get_weather") == "get_weather"
def test_preserves_empty(self):
assert _strip_tool_suffix("") == ""
def test_ct_in_middle_not_stripped(self):
assert _strip_tool_suffix("ct0_action") == "ct0_action"