feat: more logs on stream finish

Alex
2026-04-28 02:27:02 +01:00
parent 552bfe016a
commit f0c39dec23
6 changed files with 313 additions and 33 deletions

View File

@@ -296,6 +296,35 @@ class BaseLLM(ABC):
},
)
def _emit_stream_finished_log(
self,
model,
*,
prompt_tokens,
completion_tokens,
latency_ms,
cached_tokens=None,
error=None,
):
# Paired with ``llm_stream_start`` so cost dashboards can sum tokens
# by user/agent/provider. Token counts are client-side estimates
# from ``stream_token_usage``; vendor-reported counts (incl.
# ``cached_tokens`` for prompt caching) require per-provider
# extraction in each ``_raw_gen_stream`` and aren't wired yet.
extra = {
"model": model,
"provider": self.provider_name,
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
"latency_ms": int(latency_ms),
"status": "error" if error is not None else "ok",
}
if cached_tokens is not None:
extra["cached_tokens"] = int(cached_tokens)
if error is not None:
extra["error_class"] = type(error).__name__
logging.info("llm_stream_finished", extra=extra)
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
# the ``stream_token_usage`` decorator pops that key, but the log

View File

@@ -24,6 +24,15 @@ class LogContext:
self.api_key = api_key
self.query = query
self.stacks = []
# Per-activity response aggregates populated by ``_consume_and_log``
# while it forwards stream items, then flushed onto the
# ``activity_finished`` event so every Flask request gets the
# same summary that ``run_agent_logic`` used to log only for the
# Celery webhook path.
self.answer_length = 0
self.thought_length = 0
self.source_count = 0
self.tool_call_count = 0
def build_stack_data(
@@ -131,10 +140,8 @@ def log_activity() -> Callable:
raise
finally:
_emit_activity_finished(
context=context,
parent_activity_id=parent_activity_id,
started_at=started_at,
error=error,
)
@@ -147,32 +154,60 @@ def log_activity() -> Callable:
def _emit_activity_finished(
*,
context: "LogContext",
parent_activity_id: str | None,
started_at: float,
error: BaseException | None,
) -> None:
"""Emit the paired ``activity_finished`` event with duration and outcome."""
"""Emit the paired ``activity_finished`` event with duration, outcome,
and per-activity response aggregates accumulated in ``_consume_and_log``.
"""
duration_ms = int((time.monotonic() - started_at) * 1000)
logging.info(
"activity_finished",
extra={
"activity_id": activity_id,
"activity_id": context.activity_id,
"parent_activity_id": parent_activity_id,
"user_id": user,
"endpoint": endpoint,
"user_id": context.user,
"endpoint": context.endpoint,
"duration_ms": duration_ms,
"status": "error" if error is not None else "ok",
"error_class": type(error).__name__ if error is not None else None,
"answer_length": context.answer_length,
"thought_length": context.thought_length,
"source_count": context.source_count,
"tool_call_count": context.tool_call_count,
},
)
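For reference, here is how a JSON handler might render the event above for a successful streaming request (all values made up):

{
    "message": "activity_finished",
    "activity_id": "act-123",
    "parent_activity_id": None,
    "user_id": "user1",
    "endpoint": "stream",
    "duration_ms": 2140,
    "status": "ok",
    "error_class": None,
    "answer_length": 512,
    "thought_length": 96,
    "source_count": 3,
    "tool_call_count": 2,
}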
def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
"""Mirror the per-line aggregation that ``run_agent_logic`` did for the
Celery webhook path, but at the generator-consumption layer so every
``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
gets the same summary.
"""
if not isinstance(item, dict):
return
if "answer" in item:
context.answer_length += len(str(item["answer"]))
return
if "thought" in item:
context.thought_length += len(str(item["thought"]))
return
sources = item.get("sources")
if isinstance(sources, list):
context.source_count += len(sources)
return
tool_calls = item.get("tool_calls")
if isinstance(tool_calls, list):
context.tool_call_count += len(tool_calls)
def _consume_and_log(generator: Generator, context: "LogContext"):
try:
for item in generator:
_accumulate_response_summary(item, context)
yield item
except Exception as e:
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")

View File

@@ -1,5 +1,6 @@
import sys
import logging
import time
from datetime import datetime
from application.storage.db.repositories.token_usage import TokenUsageRepository
@@ -145,19 +146,44 @@ def stream_token_usage(func):
**kwargs,
)
batch = []
started_at = time.monotonic()
error: BaseException | None = None
try:
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
except Exception as exc:
# ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
# flow through as ``status="ok"`` — same convention as
# ``application.logging._consume_and_log``.
error = exc
raise
finally:
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
# Persist usage rows only on success: a partial mid-stream
# failure shouldn't bill the user for a response they never got.
if error is None:
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
emit = getattr(self, "_emit_stream_finished_log", None)
if callable(emit):
try:
emit(
model,
prompt_tokens=call_usage["prompt_tokens"],
completion_tokens=call_usage["generated_tokens"],
latency_ms=int((time.monotonic() - started_at) * 1000),
error=error,
)
except Exception:
logger.exception("Failed to emit llm_stream_finished")
return wrapper
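The ``except Exception`` clause above deliberately lets ``GeneratorExit`` (consumer disconnected) and ``KeyboardInterrupt`` pass uncaught, because both derive from ``BaseException`` rather than ``Exception``; the ``finally`` block still runs and reports ``status="ok"``. A self-contained sketch of that convention:

def wrapped(inner):
    error = None
    try:
        for item in inner:
            yield item
    except Exception as exc:  # GeneratorExit/KeyboardInterrupt skip this
        error = exc
        raise
    finally:
        print("status =", "error" if error is not None else "ok")

gen = wrapped(iter(["a", "b", "c"]))
next(gen)    # consume one item
gen.close()  # raises GeneratorExit at the yield; prints "status = ok"

A mid-stream ``RuntimeError`` from ``inner``, by contrast, sets ``error`` and reports ``status = error`` before re-raising.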

View File

@@ -432,15 +432,10 @@ def run_agent_logic(agent_config, input_data):
"tool_calls": tool_calls,
"thought": thought,
}
# Per-activity summary fields (answer_length, thought_length,
# source_count, tool_call_count) now ride on the inner
# ``activity_finished`` event emitted by ``log_activity`` around
# ``Agent.gen`` above; no separate ``agent_response`` log needed.
return result
except Exception as e:
logging.error(f"Error in run_agent_logic: {e}", exc_info=True)

View File

@@ -151,6 +151,90 @@ class TestGenMethods:
# BaseLLM default — concrete providers always override.
assert evt.provider == "unknown"
@patch("application.llm.base.stream_cache", lambda f: f)
def test_gen_stream_emits_llm_stream_finished_on_success(self, caplog):
# Real ``stream_token_usage`` so the emit-from-finally path runs.
# ``update_token_usage`` short-circuits under pytest, so no DB
# mocking is needed.
import logging as _logging
class FakeProvider(StubLLM):
provider_name = "fake-provider"
llm = FakeProvider(raw_gen_stream_items=["alpha", "beta"])
llm.user_api_key = None
with caplog.at_level(_logging.INFO, logger="root"):
list(
llm.gen_stream(
model="m1",
messages=[{"role": "user", "content": "hi"}],
)
)
finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
assert len(finished) == 1
evt = finished[0]
assert evt.model == "m1"
assert evt.provider == "fake-provider"
assert evt.status == "ok"
assert isinstance(evt.prompt_tokens, int) and evt.prompt_tokens >= 0
assert isinstance(evt.completion_tokens, int) and evt.completion_tokens > 0
assert isinstance(evt.latency_ms, int) and evt.latency_ms >= 0
# ``cached_tokens`` is intentionally absent until per-provider
# vendor-usage extraction lands.
assert not hasattr(evt, "cached_tokens")
assert not hasattr(evt, "error_class")
@patch("application.llm.base.stream_cache", lambda f: f)
def test_gen_stream_emits_llm_stream_finished_on_error(self, caplog):
import logging as _logging
class FakeProvider(BaseLLM):
provider_name = "fake-provider"
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kw):
return "x"
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kw):
yield "partial"
raise RuntimeError("mid_stream_boom")
llm = FakeProvider()
llm.user_api_key = None
with caplog.at_level(_logging.INFO, logger="root"), pytest.raises(RuntimeError):
list(llm.gen_stream(model="m1", messages=[]))
finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
assert len(finished) == 1
evt = finished[0]
assert evt.status == "error"
assert evt.error_class == "RuntimeError"
# Partial completion tokens still recorded (the chunk yielded
# before the failure is in the batch).
assert evt.completion_tokens > 0
@patch("application.llm.base.stream_cache", lambda f: f)
def test_gen_stream_finished_event_paired_with_stream_start(self, caplog):
# The two events form a pair the cost dashboards join on; verify
# they always come in order and from the same provider/model.
import logging as _logging
class FakeProvider(StubLLM):
provider_name = "fake-provider"
llm = FakeProvider(raw_gen_stream_items=["x"])
llm.user_api_key = None
with caplog.at_level(_logging.INFO, logger="root"):
list(llm.gen_stream(model="m1", messages=[]))
records = [
r for r in caplog.records
if r.message in ("llm_stream_start", "llm_stream_finished")
]
assert [r.message for r in records] == ["llm_stream_start", "llm_stream_finished"]
assert records[0].model == records[1].model == "m1"
assert records[0].provider == records[1].provider == "fake-provider"
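As a sanity check on the join the dashboards perform, here is a toy version of the downstream aggregation (the event dicts are illustrative stand-ins for records shipped by the log pipeline):

from collections import defaultdict

events = [
    {"message": "llm_stream_finished", "provider": "openai", "model": "m1",
     "prompt_tokens": 120, "completion_tokens": 40},
    {"message": "llm_stream_finished", "provider": "openai", "model": "m1",
     "prompt_tokens": 80, "completion_tokens": 25},
]

totals = defaultdict(lambda: {"prompt": 0, "completion": 0})
for event in events:
    if event["message"] != "llm_stream_finished":
        continue
    key = (event["provider"], event["model"])
    totals[key]["prompt"] += event["prompt_tokens"]
    totals[key]["completion"] += event["completion_tokens"]

print(dict(totals))
# {('openai', 'm1'): {'prompt': 200, 'completion': 65}}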
@pytest.mark.unit
class TestProviderNameRegistry:

View File

@@ -257,4 +257,115 @@ class TestLogActivity:
assert finished.status == "error"
assert finished.error_class == "ValueError"
def test_log_activity_emits_response_summary_aggregates(self, caplog):
# Replaces the ``agent_response`` event that ``run_agent_logic``
# used to emit only on the Celery webhook path: every Flask
# activity now gets the same aggregates on ``activity_finished``.
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "stream"
user = "user1"
user_api_key = ""
query = "q"
@log_activity()
def streaming(agent, log_context=None):
yield {"answer": "Hello "}
yield {"answer": "world"}
yield {"thought": "thinking..."}
yield {"sources": [{"id": "a"}, {"id": "b"}, {"id": "c"}]}
yield {"tool_calls": [{"name": "search"}, {"name": "fetch"}]}
yield "ignored-non-dict"
yield {"unrecognised": "noop"}
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"):
list(streaming(FakeAgent()))
finished = next(r for r in caplog.records if r.message == "activity_finished")
assert finished.answer_length == len("Hello world")
assert finished.thought_length == len("thinking...")
assert finished.source_count == 3
assert finished.tool_call_count == 2
def test_log_activity_aggregates_initialise_to_zero(self, caplog):
# No yields → summary fields still present and zero (so Axiom
# schemas don't get a missing-field hole on empty activities).
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "stream"
user = "user1"
user_api_key = ""
query = ""
@log_activity()
def empty(agent, log_context=None):
return
yield # pragma: no cover — generator marker
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"):
list(empty(FakeAgent()))
finished = next(r for r in caplog.records if r.message == "activity_finished")
assert finished.answer_length == 0
assert finished.thought_length == 0
assert finished.source_count == 0
assert finished.tool_call_count == 0
@pytest.mark.unit
class TestAccumulateResponseSummary:
"""Direct coverage of the dispatch table — easier to enumerate edge
cases here than in end-to-end ``log_activity`` tests."""
def _ctx(self):
from application.logging import LogContext
return LogContext(
endpoint="e", activity_id="a", user="u", api_key="k", query="q"
)
def test_answer_appends_length(self):
from application.logging import _accumulate_response_summary
ctx = self._ctx()
_accumulate_response_summary({"answer": "abcd"}, ctx)
_accumulate_response_summary({"answer": "ef"}, ctx)
assert ctx.answer_length == 6
assert ctx.thought_length == 0
def test_non_dict_items_are_ignored(self):
from application.logging import _accumulate_response_summary
ctx = self._ctx()
for item in ("string", 123, None, ["list"], object()):
_accumulate_response_summary(item, ctx)
assert ctx.answer_length == 0
assert ctx.source_count == 0
def test_sources_must_be_list(self):
# A malformed payload (sources=str) shouldn't crash the
# accumulator — drop it silently rather than half-count it.
from application.logging import _accumulate_response_summary
ctx = self._ctx()
_accumulate_response_summary({"sources": "not-a-list"}, ctx)
assert ctx.source_count == 0
def test_tool_calls_counted(self):
from application.logging import _accumulate_response_summary
ctx = self._ctx()
_accumulate_response_summary(
{"tool_calls": [{"name": "a"}, {"name": "b"}]}, ctx
)
assert ctx.tool_call_count == 2