mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-06 16:25:04 +00:00)

feat: more logs on stream finish
@@ -296,6 +296,35 @@ class BaseLLM(ABC):
             },
         )
 
+    def _emit_stream_finished_log(
+        self,
+        model,
+        *,
+        prompt_tokens,
+        completion_tokens,
+        latency_ms,
+        cached_tokens=None,
+        error=None,
+    ):
+        # Paired with ``llm_stream_start`` so cost dashboards can sum tokens
+        # by user/agent/provider. Token counts are client-side estimates
+        # from ``stream_token_usage``; vendor-reported counts (incl.
+        # ``cached_tokens`` for prompt caching) require per-provider
+        # extraction in each ``_raw_gen_stream`` and aren't wired yet.
+        extra = {
+            "model": model,
+            "provider": self.provider_name,
+            "prompt_tokens": int(prompt_tokens),
+            "completion_tokens": int(completion_tokens),
+            "latency_ms": int(latency_ms),
+            "status": "error" if error is not None else "ok",
+        }
+        if cached_tokens is not None:
+            extra["cached_tokens"] = int(cached_tokens)
+        if error is not None:
+            extra["error_class"] = type(error).__name__
+        logging.info("llm_stream_finished", extra=extra)
+
     def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
         # Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
         # the ``stream_token_usage`` decorator pops that key, but the log
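For orientation, a successful stream produces one ``llm_stream_finished`` record whose ``extra`` payload looks roughly like this. The field names come from ``_emit_stream_finished_log`` above and the model/provider strings reuse the test fixtures further down; the token and latency numbers are made up:

# Illustrative payload only, not output from a real run.
sample_ok = {
    "model": "m1",                 # whatever model name was passed to gen_stream
    "provider": "fake-provider",   # provider_name of the concrete subclass
    "prompt_tokens": 812,          # client-side estimate from stream_token_usage
    "completion_tokens": 145,      # ditto, counted over the yielded chunks
    "latency_ms": 2310,
    "status": "ok",
}
# A failed stream instead carries "status": "error" plus "error_class", and
# "cached_tokens" only appears once per-provider vendor usage extraction lands.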
@@ -24,6 +24,15 @@ class LogContext:
         self.api_key = api_key
         self.query = query
         self.stacks = []
+        # Per-activity response aggregates populated by ``_consume_and_log``
+        # while it forwards stream items, then flushed onto the
+        # ``activity_finished`` event so every Flask request gets the
+        # same summary that ``run_agent_logic`` used to log only for the
+        # Celery webhook path.
+        self.answer_length = 0
+        self.thought_length = 0
+        self.source_count = 0
+        self.tool_call_count = 0
 
 
 def build_stack_data(
@@ -131,10 +140,8 @@ def log_activity() -> Callable:
                 raise
             finally:
                 _emit_activity_finished(
-                    activity_id=activity_id,
+                    context=context,
                     parent_activity_id=parent_activity_id,
-                    user=user,
-                    endpoint=endpoint,
                     started_at=started_at,
                     error=error,
                 )
@@ -147,32 +154,60 @@ def log_activity() -> Callable:
 
 def _emit_activity_finished(
     *,
-    activity_id: str,
+    context: "LogContext",
     parent_activity_id: str | None,
-    user: str,
-    endpoint: str,
     started_at: float,
     error: BaseException | None,
 ) -> None:
-    """Emit the paired ``activity_finished`` event with duration and outcome."""
+    """Emit the paired ``activity_finished`` event with duration, outcome,
+    and per-activity response aggregates accumulated in ``_consume_and_log``.
+    """
     duration_ms = int((time.monotonic() - started_at) * 1000)
     logging.info(
         "activity_finished",
         extra={
-            "activity_id": activity_id,
+            "activity_id": context.activity_id,
             "parent_activity_id": parent_activity_id,
-            "user_id": user,
-            "endpoint": endpoint,
+            "user_id": context.user,
+            "endpoint": context.endpoint,
             "duration_ms": duration_ms,
             "status": "error" if error is not None else "ok",
             "error_class": type(error).__name__ if error is not None else None,
+            "answer_length": context.answer_length,
+            "thought_length": context.thought_length,
+            "source_count": context.source_count,
+            "tool_call_count": context.tool_call_count,
         },
     )
 
 
+def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
+    """Mirror the per-line aggregation that ``run_agent_logic`` did for the
+    Celery webhook path, but at the generator-consumption layer so every
+    ``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
+    gets the same summary.
+    """
+    if not isinstance(item, dict):
+        return
+    if "answer" in item:
+        context.answer_length += len(str(item["answer"]))
+        return
+    if "thought" in item:
+        context.thought_length += len(str(item["thought"]))
+        return
+    sources = item.get("sources") if "sources" in item else None
+    if isinstance(sources, list):
+        context.source_count += len(sources)
+        return
+    tool_calls = item.get("tool_calls") if "tool_calls" in item else None
+    if isinstance(tool_calls, list):
+        context.tool_call_count += len(tool_calls)
+
+
 def _consume_and_log(generator: Generator, context: "LogContext"):
     try:
         for item in generator:
+            _accumulate_response_summary(item, context)
             yield item
     except Exception as e:
         logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
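A minimal sketch of how the dispatch above accumulates over a stream of items; it mirrors the unit tests further down, and the item sequence here is made up:

# Sketch only; LogContext and _accumulate_response_summary are the definitions above.
from application.logging import LogContext, _accumulate_response_summary

ctx = LogContext(endpoint="stream", activity_id="a1", user="u1", api_key="", query="q")
for item in (
    {"answer": "Hello "},
    {"answer": "world"},
    {"thought": "thinking..."},
    {"sources": [{"id": 1}, {"id": 2}]},
    {"tool_calls": [{"name": "search"}]},
    "non-dict items are ignored",
):
    _accumulate_response_summary(item, ctx)

assert (ctx.answer_length, ctx.thought_length) == (11, 11)
assert (ctx.source_count, ctx.tool_call_count) == (2, 1)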
@@ -1,5 +1,6 @@
 import sys
 import logging
+import time
 from datetime import datetime
 
 from application.storage.db.repositories.token_usage import TokenUsageRepository
@@ -145,19 +146,44 @@ def stream_token_usage(func):
             **kwargs,
         )
         batch = []
-        result = func(self, model, messages, stream, tools, **kwargs)
-        for r in result:
-            batch.append(r)
-            yield r
-        for line in batch:
-            call_usage["generated_tokens"] += _count_tokens(line)
-        self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
-        self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
-        update_token_usage(
-            self.decoded_token,
-            self.user_api_key,
-            call_usage,
-            getattr(self, "agent_id", None),
-        )
+        started_at = time.monotonic()
+        error: BaseException | None = None
+        try:
+            result = func(self, model, messages, stream, tools, **kwargs)
+            for r in result:
+                batch.append(r)
+                yield r
+        except Exception as exc:
+            # ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
+            # flow through as ``status="ok"`` — same convention as
+            # ``application.logging._consume_and_log``.
+            error = exc
+            raise
+        finally:
+            for line in batch:
+                call_usage["generated_tokens"] += _count_tokens(line)
+            self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
+            self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
+            # Persist usage rows only on success: a partial mid-stream
+            # failure shouldn't bill the user for a response they never got.
+            if error is None:
+                update_token_usage(
+                    self.decoded_token,
+                    self.user_api_key,
+                    call_usage,
+                    getattr(self, "agent_id", None),
+                )
+            emit = getattr(self, "_emit_stream_finished_log", None)
+            if callable(emit):
+                try:
+                    emit(
+                        model,
+                        prompt_tokens=call_usage["prompt_tokens"],
+                        completion_tokens=call_usage["generated_tokens"],
+                        latency_ms=int((time.monotonic() - started_at) * 1000),
+                        error=error,
+                    )
+                except Exception:
+                    logger.exception("Failed to emit llm_stream_finished")
 
     return wrapper
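The emit-from-``finally`` shape leans on ordinary generator semantics: the ``finally`` block runs on normal exhaustion, on an exception raised mid-iteration, and on ``GeneratorExit`` when the consumer stops reading. A standalone sketch (not project code) of those paths:

# Standalone illustration of why the finally block above always gets to run.
def stream():
    sent = []
    try:
        for chunk in ("alpha", "beta"):
            sent.append(chunk)
            yield chunk
    finally:
        print(f"stream finished after {len(sent)} chunk(s)")

list(stream())   # consumes everything -> "stream finished after 2 chunk(s)"

gen = stream()
next(gen)        # "alpha"
gen.close()      # consumer disconnects early -> "stream finished after 1 chunk(s)"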
@@ -432,15 +432,10 @@ def run_agent_logic(agent_config, input_data):
             "tool_calls": tool_calls,
             "thought": thought,
         }
-        logging.info(
-            "agent_response",
-            extra={
-                "answer_length": len(response_full),
-                "source_count": len(source_log_docs),
-                "tool_call_count": len(tool_calls),
-                "thought_length": len(thought),
-            },
-        )
+        # Per-activity summary fields (answer_length, thought_length,
+        # source_count, tool_call_count) now ride on the inner
+        # ``activity_finished`` event emitted by ``log_activity`` around
+        # ``Agent.gen`` above; no separate ``agent_response`` log needed.
         return result
     except Exception as e:
         logging.error(f"Error in run_agent_logic: {e}", exc_info=True)
@@ -151,6 +151,90 @@ class TestGenMethods:
         # BaseLLM default — concrete providers always override.
         assert evt.provider == "unknown"
+
+    @patch("application.llm.base.stream_cache", lambda f: f)
+    def test_gen_stream_emits_llm_stream_finished_on_success(self, caplog):
+        # Real ``stream_token_usage`` so the emit-from-finally path runs.
+        # ``update_token_usage`` short-circuits under pytest, so no DB
+        # mocking is needed.
+        import logging as _logging
+
+        class FakeProvider(StubLLM):
+            provider_name = "fake-provider"
+
+        llm = FakeProvider(raw_gen_stream_items=["alpha", "beta"])
+        llm.user_api_key = None
+        with caplog.at_level(_logging.INFO, logger="root"):
+            list(
+                llm.gen_stream(
+                    model="m1",
+                    messages=[{"role": "user", "content": "hi"}],
+                )
+            )
+
+        finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
+        assert len(finished) == 1
+        evt = finished[0]
+        assert evt.model == "m1"
+        assert evt.provider == "fake-provider"
+        assert evt.status == "ok"
+        assert isinstance(evt.prompt_tokens, int) and evt.prompt_tokens >= 0
+        assert isinstance(evt.completion_tokens, int) and evt.completion_tokens > 0
+        assert isinstance(evt.latency_ms, int) and evt.latency_ms >= 0
+        # ``cached_tokens`` is intentionally absent until per-provider
+        # vendor-usage extraction lands.
+        assert not hasattr(evt, "cached_tokens")
+        assert not hasattr(evt, "error_class")
+
+    @patch("application.llm.base.stream_cache", lambda f: f)
+    def test_gen_stream_emits_llm_stream_finished_on_error(self, caplog):
+        import logging as _logging
+
+        class FakeProvider(BaseLLM):
+            provider_name = "fake-provider"
+
+            def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kw):
+                return "x"
+
+            def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kw):
+                yield "partial"
+                raise RuntimeError("mid_stream_boom")
+
+        llm = FakeProvider()
+        llm.user_api_key = None
+        with caplog.at_level(_logging.INFO, logger="root"), pytest.raises(RuntimeError):
+            list(llm.gen_stream(model="m1", messages=[]))
+
+        finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
+        assert len(finished) == 1
+        evt = finished[0]
+        assert evt.status == "error"
+        assert evt.error_class == "RuntimeError"
+        # Partial completion tokens still recorded (the chunk yielded
+        # before the failure is in the batch).
+        assert evt.completion_tokens > 0
+
+    @patch("application.llm.base.stream_cache", lambda f: f)
+    def test_gen_stream_finished_event_paired_with_stream_start(self, caplog):
+        # The two events form a pair the cost dashboards join on; verify
+        # they always come in order and from the same provider/model.
+        import logging as _logging
+
+        class FakeProvider(StubLLM):
+            provider_name = "fake-provider"
+
+        llm = FakeProvider(raw_gen_stream_items=["x"])
+        llm.user_api_key = None
+        with caplog.at_level(_logging.INFO, logger="root"):
+            list(llm.gen_stream(model="m1", messages=[]))
+
+        records = [
+            r for r in caplog.records
+            if r.message in ("llm_stream_start", "llm_stream_finished")
+        ]
+        assert [r.message for r in records] == ["llm_stream_start", "llm_stream_finished"]
+        assert records[0].model == records[1].model == "m1"
+        assert records[0].provider == records[1].provider == "fake-provider"
 
 
 @pytest.mark.unit
 class TestProviderNameRegistry:
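These assertions can read ``evt.model`` and ``evt.provider`` directly because stdlib ``logging`` copies every key of ``extra`` onto the ``LogRecord``, which is what pytest's ``caplog`` hands back. A small standalone sketch (not project code) of that mechanism:

# Standalone sketch of the logging behaviour the tests above rely on.
import logging

class _Capture(logging.Handler):
    def emit(self, record):
        self.last = record

handler = _Capture()
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.INFO)

logging.info("llm_stream_finished", extra={"model": "m1", "provider": "fake-provider"})
assert handler.last.model == "m1"
assert handler.last.provider == "fake-provider"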
@@ -257,4 +257,115 @@ class TestLogActivity:
         assert finished.status == "error"
         assert finished.error_class == "ValueError"
 
+    def test_log_activity_emits_response_summary_aggregates(self, caplog):
+        # Replaces the ``agent_response`` event that ``run_agent_logic``
+        # used to emit only on the Celery webhook path: every Flask
+        # activity now gets the same aggregates on ``activity_finished``.
+        import logging as _logging
+
+        from application.logging import log_activity
+
+        class FakeAgent:
+            endpoint = "stream"
+            user = "user1"
+            user_api_key = ""
+            query = "q"
+
+        @log_activity()
+        def streaming(agent, log_context=None):
+            yield {"answer": "Hello "}
+            yield {"answer": "world"}
+            yield {"thought": "thinking..."}
+            yield {"sources": [{"id": "a"}, {"id": "b"}, {"id": "c"}]}
+            yield {"tool_calls": [{"name": "search"}, {"name": "fetch"}]}
+            yield "ignored-non-dict"
+            yield {"unrecognised": "noop"}
+
+        with patch("application.logging._log_activity_to_db"), \
+                caplog.at_level(_logging.INFO, logger="root"):
+            list(streaming(FakeAgent()))
+
+        finished = next(r for r in caplog.records if r.message == "activity_finished")
+        assert finished.answer_length == len("Hello world")
+        assert finished.thought_length == len("thinking...")
+        assert finished.source_count == 3
+        assert finished.tool_call_count == 2
+
+    def test_log_activity_aggregates_initialise_to_zero(self, caplog):
+        # No yields → summary fields still present and zero (so Axiom
+        # schemas don't get a missing-field hole on empty activities).
+        import logging as _logging
+
+        from application.logging import log_activity
+
+        class FakeAgent:
+            endpoint = "stream"
+            user = "user1"
+            user_api_key = ""
+            query = ""
+
+        @log_activity()
+        def empty(agent, log_context=None):
+            return
+            yield  # pragma: no cover — generator marker
+
+        with patch("application.logging._log_activity_to_db"), \
+                caplog.at_level(_logging.INFO, logger="root"):
+            list(empty(FakeAgent()))
+
+        finished = next(r for r in caplog.records if r.message == "activity_finished")
+        assert finished.answer_length == 0
+        assert finished.thought_length == 0
+        assert finished.source_count == 0
+        assert finished.tool_call_count == 0
+
+
+@pytest.mark.unit
+class TestAccumulateResponseSummary:
+    """Direct coverage of the dispatch table — easier to enumerate edge
+    cases here than in end-to-end ``log_activity`` tests."""
+
+    def _ctx(self):
+        from application.logging import LogContext
+
+        return LogContext(
+            endpoint="e", activity_id="a", user="u", api_key="k", query="q"
+        )
+
+    def test_answer_appends_length(self):
+        from application.logging import _accumulate_response_summary
+
+        ctx = self._ctx()
+        _accumulate_response_summary({"answer": "abcd"}, ctx)
+        _accumulate_response_summary({"answer": "ef"}, ctx)
+        assert ctx.answer_length == 6
+        assert ctx.thought_length == 0
+
+    def test_non_dict_items_are_ignored(self):
+        from application.logging import _accumulate_response_summary
+
+        ctx = self._ctx()
+        for item in ("string", 123, None, ["list"], object()):
+            _accumulate_response_summary(item, ctx)
+        assert ctx.answer_length == 0
+        assert ctx.source_count == 0
+
+    def test_sources_must_be_list(self):
+        # A malformed payload (sources=str) shouldn't crash the
+        # accumulator — drop it silently rather than half-count it.
+        from application.logging import _accumulate_response_summary
+
+        ctx = self._ctx()
+        _accumulate_response_summary({"sources": "not-a-list"}, ctx)
+        assert ctx.source_count == 0
+
+    def test_tool_calls_counted(self):
+        from application.logging import _accumulate_response_summary
+
+        ctx = self._ctx()
+        _accumulate_response_summary(
+            {"tool_calls": [{"name": "a"}, {"name": "b"}]}, ctx
+        )
+        assert ctx.tool_call_count == 2