Compare commits

...

3 Commits

Author SHA1 Message Date
Alex
c0510b974d fix: marking executed tool calls on webhooks 2026-05-17 21:51:37 +01:00
Alex
97a362b703 feat: fix glibc memory overflow (#2478) 2026-05-17 11:55:39 +01:00
Pavel
29477b40b3 define conversation_id and initial_user_id on BaseAgent (#2474)
These attributes were only set by StreamProcessor after agent creation,
causing an AttributeError in _perform_mid_execution_compression when
the context limit was hit through other code paths (e.g. worker).
Declaring them as None in init lets the handler fall through to
in-memory compression gracefully.
2026-05-15 15:33:34 +01:00
8 changed files with 245 additions and 129 deletions

View File

@@ -8,7 +8,7 @@ RUN apt-get update && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
RUN if [ -f /usr/bin/python3.12 ]; then \
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
COPY . /app/application
# Change the ownership of the /app directory to the appuser
RUN mkdir -p /app/application/inputs/local
RUN chown -R appuser:appuser /app
@@ -82,6 +82,11 @@ ENV FLASK_APP=app.py \
FLASK_DEBUG=true \
PATH="/venv/bin:$PATH"
ENV MALLOC_ARENA_MAX=2 \
OMP_NUM_THREADS=4 \
MKL_NUM_THREADS=4 \
OPENBLAS_NUM_THREADS=4
# Expose the port the app runs on
EXPOSE 7091

View File

@@ -114,6 +114,8 @@ class BaseAgent(ABC):
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
self.conversation_id: Optional[str] = None
self.initial_user_id: Optional[str] = None
@log_activity()
def gen(

View File

@@ -1,5 +1,8 @@
import ctypes
import gc
import inspect
import logging
import sys
import threading
from celery import Celery
@@ -98,6 +101,34 @@ def _unbind_task_log_context(task_id, **_):
)
def _trim_native_heap() -> None:
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
# docling/torch parsing makes large transient allocations; glibc keeps the
# freed pages in per-thread malloc arenas rather than returning them, so a
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
# back. The symbol is glibc-only — absent in macOS libc.
if not sys.platform.startswith("linux"):
return
try:
ctypes.CDLL("libc.so.6").malloc_trim(0)
except (OSError, AttributeError):
pass
@task_postrun.connect
def _reclaim_memory_after_task(*args, **kwargs):
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
gc.collect()
torch = sys.modules.get("torch")
if torch is not None:
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
_trim_native_heap()
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.

View File

@@ -31,3 +31,10 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
result_expires = 86400 * 7
task_track_started = True
# Recycle the prefork worker child to bound native-heap growth from
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD

View File

@@ -36,6 +36,11 @@ class Settings(BaseSettings):
# and Dify defaults; long ingests can override via env.
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
CELERY_VISIBILITY_TIMEOUT: int = 3600
# Recycle the prefork worker child once its resident size crosses this many
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
MONGO_URI: Optional[str] = None
# User-data Postgres DB.

View File

@@ -63,7 +63,8 @@ class ToolCallAttemptsRepository:
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> None:
"""Insert OR upgrade a row to ``executed``.
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
there is no ``message_id``, as in ``mark_executed``.
Used as a fallback when ``record_proposed`` failed (DB outage)
and the tool ran anyway — preserves the journal so the
@@ -72,6 +73,7 @@ class ToolCallAttemptsRepository:
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
status = "executed" if message_id is not None else "confirmed"
self._conn.execute(
text(
"""
@@ -82,9 +84,9 @@ class ToolCallAttemptsRepository:
(:call_id, CAST(:tool_id AS uuid), :tool_name,
:action_name, CAST(:arguments AS jsonb),
CAST(:result AS jsonb), CAST(:message_id AS uuid),
'executed')
:status)
ON CONFLICT (call_id) DO UPDATE
SET status = 'executed',
SET status = :status,
result = EXCLUDED.result,
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
"""
@@ -97,6 +99,7 @@ class ToolCallAttemptsRepository:
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
"message_id": message_id,
"status": status,
},
)
@@ -108,7 +111,9 @@ class ToolCallAttemptsRepository:
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> bool:
"""Flip ``proposed`` → ``executed`` with the tool result.
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
when there is no ``message_id`` (a ``save_conversation=False``
request reserves no message, so no finalize will confirm it).
``artifact_id`` (when present) is stored alongside ``result`` in
the JSONB as audit data — the reconciler reads it for diagnostic
@@ -117,12 +122,14 @@ class ToolCallAttemptsRepository:
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
status = "executed" if message_id is not None else "confirmed"
sql = (
"UPDATE tool_call_attempts SET "
"status = 'executed', result = CAST(:result AS jsonb)"
"status = :status, result = CAST(:result AS jsonb)"
)
params: dict[str, Any] = {
"call_id": call_id,
"status": status,
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
}
if message_id is not None:

View File

@@ -4,19 +4,24 @@ Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
chunks in OpenAI's chat.completions streaming format, or a single response
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
Flags:
--tool-calls First response returns a tool call instead of text.
Subsequent responses (after a tool_result) return text.
Useful for triggering the tool-execution loop.
"""
import asyncio
import argparse
import json
import logging
import time
import uuid
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from flask import Flask, Response, request, jsonify
TOKEN_COUNT = 100
TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s
TOOL_CALL_MODE = False
logger = logging.getLogger("mock_llm")
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
@@ -39,7 +44,7 @@ FILLER_TOKENS = [
".",
]
app = FastAPI()
app = Flask(__name__)
def _token_stream_id() -> str:
@@ -63,11 +68,57 @@ def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None)
return f"data: {json.dumps(payload)}\n\n"
async def _stream_response(model: str, req_id: str):
def _gen_tool_call_stream(model: str, req_id: str):
"""Emit two tool_calls (search) in streaming format.
Two calls ensure the handler executes the first (which can return a
huge result), then hits _check_context_limit before the second.
"""
completion_id = _token_stream_id()
call_id_1 = f"call_{uuid.uuid4().hex[:12]}"
call_id_2 = f"call_{uuid.uuid4().hex[:12]}"
yield _sse_chunk(completion_id, model, {
"role": "assistant",
"content": None,
"tool_calls": [
{
"index": 0,
"id": call_id_1,
"type": "function",
"function": {"name": "search", "arguments": ""},
},
{
"index": 1,
"id": call_id_2,
"type": "function",
"function": {"name": "search", "arguments": ""},
},
],
})
args_json = json.dumps({"query": "Python programming basics"})
for ch in args_json:
time.sleep(TOKEN_DELAY_S)
yield _sse_chunk(completion_id, model, {
"tool_calls": [
{"index": 0, "function": {"arguments": ch}},
{"index": 1, "function": {"arguments": ch}},
],
})
yield _sse_chunk(completion_id, model, {}, finish_reason="tool_calls")
yield "data: [DONE]\n\n"
logger.info("[%s] tool_call stream done (ids=%s, %s)", req_id, call_id_1, call_id_2)
def _has_tool_result(messages: list) -> bool:
return any(m.get("role") == "tool" for m in messages)
def _gen_text_stream(model: str, req_id: str):
completion_id = _token_stream_id()
yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]):
await asyncio.sleep(TOKEN_DELAY_S)
for tok in FILLER_TOKENS[:TOKEN_COUNT]:
time.sleep(TOKEN_DELAY_S)
yield _sse_chunk(completion_id, model, {"content": tok})
yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
yield "data: [DONE]\n\n"
@@ -75,63 +126,84 @@ async def _stream_response(model: str, req_id: str):
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
body = await request.json()
def chat_completions():
body = request.get_json(force=True)
model = body.get("model", "mock")
stream = bool(body.get("stream", False))
messages = body.get("messages", [])
tools = body.get("tools")
req_id = uuid.uuid4().hex[:8]
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
logger.info(
"[%s] /chat/completions stream=%s model=%s tools=%s msgs=%d",
req_id, stream, model, bool(tools), len(messages),
)
use_tool_call = (
TOOL_CALL_MODE
and tools
and not _has_tool_result(messages)
)
if stream:
return StreamingResponse(
_stream_response(model, req_id),
media_type="text/event-stream",
gen = (
_gen_tool_call_stream(model, req_id) if use_tool_call
else _gen_text_stream(model, req_id)
)
return Response(
gen,
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
},
)
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
time.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
logger.info("[%s] non-stream done", req_id)
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
completion_id = _token_stream_id()
return JSONResponse(
{
"id": completion_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": TOKEN_COUNT,
"total_tokens": 10 + TOKEN_COUNT,
},
}
)
return jsonify({
"id": completion_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": TOKEN_COUNT,
"total_tokens": 10 + TOKEN_COUNT,
},
})
@app.get("/v1/models")
async def list_models():
return {
def list_models():
return jsonify({
"object": "list",
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
}
})
@app.get("/health")
async def health():
return {"status": "ok"}
def health():
return jsonify({"status": "ok"})
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--tool-calls", action="store_true",
help="First response returns a tool_call; subsequent responses return text.",
)
parser.add_argument("--port", type=int, default=8090)
args = parser.parse_args()
TOOL_CALL_MODE = args.tool_calls
if TOOL_CALL_MODE:
logger.info("Tool-call mode enabled")
app.run(host="127.0.0.1", port=args.port, debug=False, threaded=True)

View File

@@ -1,10 +1,9 @@
"""Tests for the journaled execute path on ToolExecutor.
Each tool call inserts a row into ``tool_call_attempts`` then flips
through ``proposed → executed`` (or ``proposed → failed``). The flip
to ``confirmed`` is owned by the message-finalize path and is only
asserted indirectly here (rows stay in ``executed`` so the reconciler
can pick them up).
Each tool call inserts a ``tool_call_attempts`` row and flips it
``proposed → executed`` (or ``→ failed``). With a ``message_id`` it
stays ``executed`` for the finalize path to confirm; without one
(``save_conversation=False``) it goes straight to ``confirmed``.
"""
from contextlib import contextmanager
@@ -75,11 +74,24 @@ def _make_call(name="test_action_t1", call_id="c1"):
return call
_TOOLS_DICT = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
@pytest.mark.unit
class TestExecuteJournaling:
def test_happy_path_proposed_then_executed(
def test_no_message_id_proposed_then_confirmed(
self, pg_conn, mock_tool_manager, monkeypatch
):
"""No reserved message (``save_conversation=False``) → row lands ``confirmed``, not ``executed``."""
executor = ToolExecutor(user="u")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
@@ -89,23 +101,12 @@ class TestExecuteJournaling:
)
_patch_db(monkeypatch, pg_conn)
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
events, result = _drain(executor.execute(tools_dict, _make_call(), "MockLLM"))
events, result = _drain(executor.execute(_TOOLS_DICT, _make_call(), "MockLLM"))
assert result[0] == "Tool result"
row = _select_attempt(pg_conn, "c1")
assert row is not None
assert row["status"] == "executed"
assert row["status"] == "confirmed"
assert row["tool_name"] == "test_tool"
assert row["action_name"] == "test_action"
assert row["arguments"] == {"q": "v"}
@@ -117,10 +118,7 @@ class TestExecuteJournaling:
def test_executor_message_id_is_persisted_on_executed_row(
self, pg_conn, mock_tool_manager, monkeypatch
):
"""When the route stamps a placeholder message_id on the executor,
the journal row carries it forward so ``confirm_executed_tool_calls``
can later flip it to ``confirmed``.
"""
"""The executor's message_id is carried onto the journal row, which stays ``executed``."""
from application.storage.db.repositories.conversations import (
ConversationsRepository,
)
@@ -147,18 +145,7 @@ class TestExecuteJournaling:
)
_patch_db(monkeypatch, pg_conn)
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
_drain(executor.execute(tools_dict, _make_call(call_id="cm1"), "MockLLM"))
_drain(executor.execute(_TOOLS_DICT, _make_call(call_id="cm1"), "MockLLM"))
row = _select_attempt(pg_conn, "cm1")
assert row is not None
@@ -180,18 +167,7 @@ class TestExecuteJournaling:
RuntimeError("boom")
)
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
gen = executor.execute(tools_dict, _make_call(call_id="c2"), "MockLLM")
gen = executor.execute(_TOOLS_DICT, _make_call(call_id="c2"), "MockLLM")
with pytest.raises(RuntimeError, match="boom"):
_drain(gen)
@@ -200,42 +176,10 @@ class TestExecuteJournaling:
assert row["status"] == "failed"
assert row["error"] == "boom"
def test_executed_row_lingers_for_reconciler_when_no_confirm(
self, pg_conn, mock_tool_manager, monkeypatch
):
"""No finalize_message call → row sits in ``executed``."""
executor = ToolExecutor(user="u")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls, **kw: Mock(
parse_args=Mock(return_value=("t1", "test_action", {}))
),
)
_patch_db(monkeypatch, pg_conn)
tools_dict = {
"t1": {
"id": "00000000-0000-0000-0000-000000000001",
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
],
}
}
_drain(executor.execute(tools_dict, _make_call(call_id="c3"), "MockLLM"))
row = _select_attempt(pg_conn, "c3")
assert row["status"] == "executed"
# Partial index `tool_call_attempts_pending_ts_idx` selects rows
# in ('proposed','executed') — the reconciler reads those.
assert row["status"] in ("proposed", "executed")
@pytest.mark.unit
class TestRepository:
def test_proposed_then_executed_round_trip(self, pg_conn):
def test_proposed_then_confirmed_when_no_message(self, pg_conn):
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
@@ -249,7 +193,50 @@ class TestRepository:
assert repo.mark_executed("c-x", {"out": "ok"}) is True
row = _select_attempt(pg_conn, "c-x")
assert row["status"] == "confirmed"
assert row["message_id"] is None
assert row["result"] == {"result": {"out": "ok"}}
def test_mark_executed_with_message_stays_executed(self, pg_conn):
from application.storage.db.repositories.conversations import (
ConversationsRepository,
)
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
# FK constraint: message_id must reference a real row.
conv_repo = ConversationsRepository(pg_conn)
conv = conv_repo.create("u-repo", "repo-msg-test")
msg = conv_repo.reserve_message(
str(conv["id"]),
prompt="q?",
placeholder_response="...",
request_id="req-repo-1",
status="pending",
)
message_uuid = str(msg["id"])
repo = ToolCallAttemptsRepository(pg_conn)
repo.record_proposed("c-m", "tool", "act", {})
assert (
repo.mark_executed("c-m", {"out": "ok"}, message_id=message_uuid) is True
)
row = _select_attempt(pg_conn, "c-m")
assert row["status"] == "executed"
assert str(row["message_id"]) == message_uuid
def test_upsert_executed_without_message_confirms(self, pg_conn):
"""``upsert_executed`` (DB-outage fallback) with no ``message_id`` lands ``confirmed``."""
from application.storage.db.repositories.tool_call_attempts import (
ToolCallAttemptsRepository,
)
repo = ToolCallAttemptsRepository(pg_conn)
repo.upsert_executed("c-up", "tool", "act", {"a": 1}, {"out": "ok"})
row = _select_attempt(pg_conn, "c-up")
assert row["status"] == "confirmed"
assert row["message_id"] is None
assert row["result"] == {"result": {"out": "ok"}}
def test_mark_failed_sets_error(self, pg_conn):