mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 13:25:08 +00:00
Compare commits
28 Commits
fix-badly-
...
feat-notif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
827a0bb382 | ||
|
|
b04cb44ab5 | ||
|
|
42384a0e92 | ||
|
|
0bce35ad29 | ||
|
|
9de8bb4499 | ||
|
|
cdbd3f061d | ||
|
|
2ac46fd858 | ||
|
|
daa4320da2 | ||
|
|
e70a7a5115 | ||
|
|
150d9f4e37 | ||
|
|
746bcbc5f9 | ||
|
|
aa91117fbf | ||
|
|
abbd56cb66 | ||
|
|
85d8375e6c | ||
|
|
7e98d21b61 | ||
|
|
249f9f9fe0 | ||
|
|
6c4346eb84 | ||
|
|
cb3ca8a36b | ||
|
|
4c8230fb6c | ||
|
|
649557798d | ||
|
|
afe8354ca5 | ||
|
|
5483eb0e27 | ||
|
|
bd2985db47 | ||
|
|
b99147ba83 | ||
|
|
c3023f8b71 | ||
|
|
c168a530f5 | ||
|
|
2d539f3199 | ||
|
|
ed9444cf3d |
@@ -8,7 +8,7 @@ RUN apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
RUN if [ -f /usr/bin/python3.12 ]; then \
|
||||
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
|
||||
COPY . /app/application
|
||||
|
||||
# Change the ownership of the /app directory to the appuser
|
||||
|
||||
|
||||
RUN mkdir -p /app/application/inputs/local
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
@@ -82,11 +82,6 @@ ENV FLASK_APP=app.py \
|
||||
FLASK_DEBUG=true \
|
||||
PATH="/venv/bin:$PATH"
|
||||
|
||||
ENV MALLOC_ARENA_MAX=2 \
|
||||
OMP_NUM_THREADS=4 \
|
||||
MKL_NUM_THREADS=4 \
|
||||
OPENBLAS_NUM_THREADS=4
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 7091
|
||||
|
||||
|
||||
@@ -114,8 +114,6 @@ class BaseAgent(ABC):
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.initial_user_id: Optional[str] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import ctypes
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
@@ -101,34 +98,6 @@ def _unbind_task_log_context(task_id, **_):
|
||||
)
|
||||
|
||||
|
||||
def _trim_native_heap() -> None:
|
||||
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
|
||||
# docling/torch parsing makes large transient allocations; glibc keeps the
|
||||
# freed pages in per-thread malloc arenas rather than returning them, so a
|
||||
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
|
||||
# back. The symbol is glibc-only — absent in macOS libc.
|
||||
if not sys.platform.startswith("linux"):
|
||||
return
|
||||
try:
|
||||
ctypes.CDLL("libc.so.6").malloc_trim(0)
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def _reclaim_memory_after_task(*args, **kwargs):
|
||||
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
|
||||
gc.collect()
|
||||
torch = sys.modules.get("torch")
|
||||
if torch is not None:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
_trim_native_heap()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -31,10 +31,3 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
# Recycle the prefork worker child to bound native-heap growth from
|
||||
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
|
||||
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
|
||||
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
|
||||
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
|
||||
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD
|
||||
|
||||
@@ -36,11 +36,6 @@ class Settings(BaseSettings):
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Recycle the prefork worker child once its resident size crosses this many
|
||||
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
|
||||
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
|
||||
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
|
||||
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
|
||||
@@ -63,8 +63,7 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
|
||||
there is no ``message_id``, as in ``mark_executed``.
|
||||
"""Insert OR upgrade a row to ``executed``.
|
||||
|
||||
Used as a fallback when ``record_proposed`` failed (DB outage)
|
||||
and the tool ran anyway — preserves the journal so the
|
||||
@@ -73,7 +72,6 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -84,9 +82,9 @@ class ToolCallAttemptsRepository:
|
||||
(:call_id, CAST(:tool_id AS uuid), :tool_name,
|
||||
:action_name, CAST(:arguments AS jsonb),
|
||||
CAST(:result AS jsonb), CAST(:message_id AS uuid),
|
||||
:status)
|
||||
'executed')
|
||||
ON CONFLICT (call_id) DO UPDATE
|
||||
SET status = :status,
|
||||
SET status = 'executed',
|
||||
result = EXCLUDED.result,
|
||||
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
|
||||
"""
|
||||
@@ -99,7 +97,6 @@ class ToolCallAttemptsRepository:
|
||||
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
"message_id": message_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -111,9 +108,7 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
|
||||
when there is no ``message_id`` (a ``save_conversation=False``
|
||||
request reserves no message, so no finalize will confirm it).
|
||||
"""Flip ``proposed`` → ``executed`` with the tool result.
|
||||
|
||||
``artifact_id`` (when present) is stored alongside ``result`` in
|
||||
the JSONB as audit data — the reconciler reads it for diagnostic
|
||||
@@ -122,14 +117,12 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
sql = (
|
||||
"UPDATE tool_call_attempts SET "
|
||||
"status = :status, result = CAST(:result AS jsonb)"
|
||||
"status = 'executed', result = CAST(:result AS jsonb)"
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
"call_id": call_id,
|
||||
"status": status,
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
}
|
||||
if message_id is not None:
|
||||
|
||||
@@ -4,24 +4,19 @@ Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
|
||||
chunks in OpenAI's chat.completions streaming format, or a single response
|
||||
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
|
||||
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
|
||||
|
||||
Flags:
|
||||
--tool-calls First response returns a tool call instead of text.
|
||||
Subsequent responses (after a tool_result) return text.
|
||||
Useful for triggering the tool-execution loop.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from flask import Flask, Response, request, jsonify
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
TOKEN_COUNT = 100
|
||||
TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s
|
||||
TOOL_CALL_MODE = False
|
||||
|
||||
logger = logging.getLogger("mock_llm")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
|
||||
@@ -44,7 +39,7 @@ FILLER_TOKENS = [
|
||||
".",
|
||||
]
|
||||
|
||||
app = Flask(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def _token_stream_id() -> str:
|
||||
@@ -68,57 +63,11 @@ def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None)
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
def _gen_tool_call_stream(model: str, req_id: str):
|
||||
"""Emit two tool_calls (search) in streaming format.
|
||||
|
||||
Two calls ensure the handler executes the first (which can return a
|
||||
huge result), then hits _check_context_limit before the second.
|
||||
"""
|
||||
completion_id = _token_stream_id()
|
||||
call_id_1 = f"call_{uuid.uuid4().hex[:12]}"
|
||||
call_id_2 = f"call_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
yield _sse_chunk(completion_id, model, {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": call_id_1,
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": ""},
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"id": call_id_2,
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": ""},
|
||||
},
|
||||
],
|
||||
})
|
||||
args_json = json.dumps({"query": "Python programming basics"})
|
||||
for ch in args_json:
|
||||
time.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {
|
||||
"tool_calls": [
|
||||
{"index": 0, "function": {"arguments": ch}},
|
||||
{"index": 1, "function": {"arguments": ch}},
|
||||
],
|
||||
})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="tool_calls")
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("[%s] tool_call stream done (ids=%s, %s)", req_id, call_id_1, call_id_2)
|
||||
|
||||
|
||||
def _has_tool_result(messages: list) -> bool:
|
||||
return any(m.get("role") == "tool" for m in messages)
|
||||
|
||||
|
||||
def _gen_text_stream(model: str, req_id: str):
|
||||
async def _stream_response(model: str, req_id: str):
|
||||
completion_id = _token_stream_id()
|
||||
yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
|
||||
for tok in FILLER_TOKENS[:TOKEN_COUNT]:
|
||||
time.sleep(TOKEN_DELAY_S)
|
||||
for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]):
|
||||
await asyncio.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {"content": tok})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -126,84 +75,63 @@ def _gen_text_stream(model: str, req_id: str):
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
def chat_completions():
|
||||
body = request.get_json(force=True)
|
||||
async def chat_completions(request: Request):
|
||||
body = await request.json()
|
||||
model = body.get("model", "mock")
|
||||
stream = bool(body.get("stream", False))
|
||||
messages = body.get("messages", [])
|
||||
tools = body.get("tools")
|
||||
req_id = uuid.uuid4().hex[:8]
|
||||
logger.info(
|
||||
"[%s] /chat/completions stream=%s model=%s tools=%s msgs=%d",
|
||||
req_id, stream, model, bool(tools), len(messages),
|
||||
)
|
||||
|
||||
use_tool_call = (
|
||||
TOOL_CALL_MODE
|
||||
and tools
|
||||
and not _has_tool_result(messages)
|
||||
)
|
||||
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
|
||||
|
||||
if stream:
|
||||
gen = (
|
||||
_gen_tool_call_stream(model, req_id) if use_tool_call
|
||||
else _gen_text_stream(model, req_id)
|
||||
)
|
||||
return Response(
|
||||
gen,
|
||||
mimetype="text/event-stream",
|
||||
return StreamingResponse(
|
||||
_stream_response(model, req_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
time.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
logger.info("[%s] non-stream done", req_id)
|
||||
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
|
||||
completion_id = _token_stream_id()
|
||||
return jsonify({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
def list_models():
|
||||
return jsonify({
|
||||
async def list_models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return jsonify({"status": "ok"})
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--tool-calls", action="store_true",
|
||||
help="First response returns a tool_call; subsequent responses return text.",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8090)
|
||||
args = parser.parse_args()
|
||||
TOOL_CALL_MODE = args.tool_calls
|
||||
if TOOL_CALL_MODE:
|
||||
logger.info("Tool-call mode enabled")
|
||||
app.run(host="127.0.0.1", port=args.port, debug=False, threaded=True)
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for the journaled execute path on ToolExecutor.
|
||||
|
||||
Each tool call inserts a ``tool_call_attempts`` row and flips it
|
||||
``proposed → executed`` (or ``→ failed``). With a ``message_id`` it
|
||||
stays ``executed`` for the finalize path to confirm; without one
|
||||
(``save_conversation=False``) it goes straight to ``confirmed``.
|
||||
Each tool call inserts a row into ``tool_call_attempts`` then flips
|
||||
through ``proposed → executed`` (or ``proposed → failed``). The flip
|
||||
to ``confirmed`` is owned by the message-finalize path and is only
|
||||
asserted indirectly here (rows stay in ``executed`` so the reconciler
|
||||
can pick them up).
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
@@ -74,24 +75,11 @@ def _make_call(name="test_action_t1", call_id="c1"):
|
||||
return call
|
||||
|
||||
|
||||
_TOOLS_DICT = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExecuteJournaling:
|
||||
def test_no_message_id_proposed_then_confirmed(
|
||||
def test_happy_path_proposed_then_executed(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No reserved message (``save_conversation=False``) → row lands ``confirmed``, not ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
@@ -101,12 +89,23 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
events, result = _drain(executor.execute(_TOOLS_DICT, _make_call(), "MockLLM"))
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
events, result = _drain(executor.execute(tools_dict, _make_call(), "MockLLM"))
|
||||
assert result[0] == "Tool result"
|
||||
|
||||
row = _select_attempt(pg_conn, "c1")
|
||||
assert row is not None
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["status"] == "executed"
|
||||
assert row["tool_name"] == "test_tool"
|
||||
assert row["action_name"] == "test_action"
|
||||
assert row["arguments"] == {"q": "v"}
|
||||
@@ -118,7 +117,10 @@ class TestExecuteJournaling:
|
||||
def test_executor_message_id_is_persisted_on_executed_row(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""The executor's message_id is carried onto the journal row, which stays ``executed``."""
|
||||
"""When the route stamps a placeholder message_id on the executor,
|
||||
the journal row carries it forward so ``confirm_executed_tool_calls``
|
||||
can later flip it to ``confirmed``.
|
||||
"""
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
@@ -145,7 +147,18 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
_drain(executor.execute(_TOOLS_DICT, _make_call(call_id="cm1"), "MockLLM"))
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="cm1"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "cm1")
|
||||
assert row is not None
|
||||
@@ -167,7 +180,18 @@ class TestExecuteJournaling:
|
||||
RuntimeError("boom")
|
||||
)
|
||||
|
||||
gen = executor.execute(_TOOLS_DICT, _make_call(call_id="c2"), "MockLLM")
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
gen = executor.execute(tools_dict, _make_call(call_id="c2"), "MockLLM")
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
_drain(gen)
|
||||
|
||||
@@ -176,10 +200,42 @@ class TestExecuteJournaling:
|
||||
assert row["status"] == "failed"
|
||||
assert row["error"] == "boom"
|
||||
|
||||
def test_executed_row_lingers_for_reconciler_when_no_confirm(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No finalize_message call → row sits in ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "test_action", {}))
|
||||
),
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="c3"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "c3")
|
||||
assert row["status"] == "executed"
|
||||
# Partial index `tool_call_attempts_pending_ts_idx` selects rows
|
||||
# in ('proposed','executed') — the reconciler reads those.
|
||||
assert row["status"] in ("proposed", "executed")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRepository:
|
||||
def test_proposed_then_confirmed_when_no_message(self, pg_conn):
|
||||
def test_proposed_then_executed_round_trip(self, pg_conn):
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
@@ -193,50 +249,7 @@ class TestRepository:
|
||||
|
||||
assert repo.mark_executed("c-x", {"out": "ok"}) is True
|
||||
row = _select_attempt(pg_conn, "c-x")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_executed_with_message_stays_executed(self, pg_conn):
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
# FK constraint: message_id must reference a real row.
|
||||
conv_repo = ConversationsRepository(pg_conn)
|
||||
conv = conv_repo.create("u-repo", "repo-msg-test")
|
||||
msg = conv_repo.reserve_message(
|
||||
str(conv["id"]),
|
||||
prompt="q?",
|
||||
placeholder_response="...",
|
||||
request_id="req-repo-1",
|
||||
status="pending",
|
||||
)
|
||||
message_uuid = str(msg["id"])
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.record_proposed("c-m", "tool", "act", {})
|
||||
assert (
|
||||
repo.mark_executed("c-m", {"out": "ok"}, message_id=message_uuid) is True
|
||||
)
|
||||
row = _select_attempt(pg_conn, "c-m")
|
||||
assert row["status"] == "executed"
|
||||
assert str(row["message_id"]) == message_uuid
|
||||
|
||||
def test_upsert_executed_without_message_confirms(self, pg_conn):
|
||||
"""``upsert_executed`` (DB-outage fallback) with no ``message_id`` lands ``confirmed``."""
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.upsert_executed("c-up", "tool", "act", {"a": 1}, {"out": "ok"})
|
||||
row = _select_attempt(pg_conn, "c-up")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_failed_sets_error(self, pg_conn):
|
||||
|
||||
Reference in New Issue
Block a user