Compare commits


1 Commit

Author  SHA1        Message     Date
Alex    e0a8cc178b  feat: BYOM  2026-04-27 21:50:45 +01:00
35 changed files with 109 additions and 1910 deletions

View File

@@ -47,13 +47,11 @@
</ul>
## Roadmap
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
- [x] SharePoint & Confluence connectors ( March–April 2026 )
- [x] Research mode ( March 2026 )
- [x] Postgres migration for user data ( April 2026 )
- [x] OpenTelemetry observability ( April 2026 )
- [x] Bring Your Own Model (BYOM) ( April 2026 )
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full API tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!

View File

@@ -274,14 +274,7 @@ class ToolExecutor:
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(
"tool_call_parse_failed",
extra={
"llm_class_name": llm_class_name,
"llm_tool_name": llm_name,
"call_id": call_id,
},
)
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
@@ -296,15 +289,7 @@ class ToolExecutor:
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(
"tool_id_not_found",
extra={
"tool_id": tool_id,
"llm_tool_name": llm_name,
"call_id": call_id,
"available_tool_count": len(tools_dict),
},
)
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
@@ -371,15 +356,7 @@ class ToolExecutor:
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
"missing 'id' on tool row."
)
logger.error(
"tool_load_failed",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
"call_id": call_id,
},
)
logger.error(error_message)
tool_call_data["result"] = error_message
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
@@ -474,12 +451,10 @@ class ToolExecutor:
row_id = tool_data.get("id")
if not row_id:
logger.error(
"tool_missing_row_id",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
},
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
"skipping load to avoid binding a non-UUID downstream.",
tool_data.get("name"),
tool_id,
)
return None
tool_config["tool_id"] = str(row_id)

View File

@@ -12,12 +12,6 @@ logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
# Per-image token estimate. Provider tokenizers vary widely
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
# depends on resolution/detail we can't see here. Errs slightly high
# so the threshold check stays conservative.
_IMAGE_PART_TOKEN_ESTIMATE = 1500
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
@@ -35,36 +29,12 @@ class TokenCounter:
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, image parts, etc.)
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += TokenCounter._count_content_part(item)
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def _count_content_part(item: Dict) -> int:
# Image/file attachments are billed by the provider per image,
# not proportional to the inline bytes/base64 string.
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
# which trips spurious compression and overflows downstream
# input limits.
item_type = item.get("type")
if "files" in item:
files = item.get("files")
count = len(files) if isinstance(files, list) and files else 1
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
if "image_url" in item or item_type in {
"image",
"image_url",
"input_image",
"file",
}:
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
return num_tokens_from_string(str(item))
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
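
The deleted `_count_content_part` exists because `str(item)` on an inline image scales with payload size. A rough, hypothetical back-of-envelope (the 4-characters-per-token heuristic stands in for `num_tokens_from_string`, which is not shown here):

```python
import base64

# ~1 MB of image bytes becomes ~1.4 MB of base64 text; a naive
# str()-based count then dwarfs the flat per-image estimate.
image_bytes = b"\x89PNG" + b"\x00" * (1024 * 1024)
as_text = base64.b64encode(image_bytes).decode()
naive_tokens = len(as_text) // 4      # ~350,000 with this heuristic
flat_estimate = 1500                  # _IMAGE_PART_TOKEN_ESTIMATE
print(naive_tokens // flat_estimate)  # inflation factor: 200x+
```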

View File

@@ -9,7 +9,6 @@ from jose import jwt
from application.auth import handle_auth
from application.core import log_context
from application.core.logging_config import setup_logging
setup_logging()
@@ -113,38 +112,6 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
@app.before_request
def _bind_log_context():
"""Bind activity_id + endpoint for the duration of this request.
Runs before ``authenticate_request``; ``user_id`` is overlaid in a
follow-up handler once the JWT has been decoded.
"""
if request.method == "OPTIONS":
return None
activity_id = str(uuid.uuid4())
request.activity_id = activity_id
token = log_context.bind(
activity_id=activity_id,
endpoint=request.endpoint,
)
setattr(request, _LOG_CTX_TOKEN_ATTR, token)
return None
@app.teardown_request
def _reset_log_context(_exc):
# SSE streams keep yielding after teardown fires, but a2wsgi runs each
# request inside ``copy_context().run(...)``, so this reset doesn't
# leak into the stream's view of the context.
token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
if token is not None:
log_context.reset(token)
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
@@ -181,21 +148,6 @@ def authenticate_request():
request.decoded_token = decoded_token
@app.before_request
def _bind_user_id_to_log_context():
# Registered after ``authenticate_request`` (Flask runs before_request
# handlers in registration order), so ``request.decoded_token`` is
# populated by the time we read it. ``teardown_request`` unwinds the
# whole request-level bind, so no separate reset token is needed here.
if request.method == "OPTIONS":
return None
decoded_token = getattr(request, "decoded_token", None)
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
if user_id:
log_context.bind(user_id=user_id)
return None
@app.after_request
def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""

View File

@@ -1,4 +1,3 @@
import hashlib
import json
import logging
import time
@@ -11,14 +10,6 @@ from application.utils import get_hash
logger = logging.getLogger(__name__)
def _cache_default(value):
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
# hash so the cache key stays bounded in size and stable across identical content.
if isinstance(value, (bytes, bytearray, memoryview)):
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
return repr(value)
_redis_instance = None
_redis_creation_failed = False
_instance_lock = Lock()
@@ -45,7 +36,7 @@ def get_redis_instance():
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(messages, default=_cache_default)
messages_str = json.dumps(messages)
tools_str = json.dumps(str(tools)) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)
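
Why the removed `default=` hook matters: `json.dumps` raises `TypeError` on raw bytes, and `repr()` of a large payload would make the key unbounded. A small sketch (the message shape is illustrative):

```python
import hashlib
import json

def _cache_default(value):
    # Hash raw binary so the key stays bounded and stable across
    # identical content; repr() everything else json can't serialize.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
    return repr(value)

msg = {"role": "user", "content": [{"files": [{"file_bytes": b"\x89PNG"}]}]}
# json.dumps([msg])  -> TypeError: Object of type bytes is not JSON serializable
key_material = json.dumps([msg], default=_cache_default)  # bounded, stable
```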

View File

@@ -1,17 +1,8 @@
import inspect
import logging
import threading
from celery import Celery
from application.core import log_context
from application.core.settings import settings
from celery.signals import (
setup_logging,
task_postrun,
task_prerun,
worker_process_init,
worker_ready,
)
from celery.signals import setup_logging, worker_process_init, worker_ready
def make_celery(app_name=__name__):
@@ -50,54 +41,6 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
dispose_engine()
# Most tasks in this repo accept ``user`` where the log context wants
# ``user_id``; map task parameter names to context keys explicitly.
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
"user": "user_id",
"user_id": "user_id",
"agent_id": "agent_id",
"conversation_id": "conversation_id",
}
_task_log_tokens: dict[str, object] = {}
@task_prerun.connect
def _bind_task_log_context(task_id, task, args, kwargs, **_):
# Resolve task args by parameter name — nearly every task in this repo
# is called positionally, so ``kwargs.get('user')`` would bind nothing.
ctx = {"activity_id": task_id}
try:
sig = inspect.signature(task.run)
bound = sig.bind_partial(*args, **kwargs).arguments
except (TypeError, ValueError):
bound = dict(kwargs)
for param_name, value in bound.items():
ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
if ctx_key and value:
ctx[ctx_key] = value
_task_log_tokens[task_id] = log_context.bind(**ctx)
@task_postrun.connect
def _unbind_task_log_context(task_id, **_):
# ``task_postrun`` fires on both success and failure. Required for
# Celery: unlike the Flask path, tasks aren't isolated in their own
# ``copy_context().run(...)``, so a missing reset would leak the
# bind onto the next task on the same worker.
token = _task_log_tokens.pop(task_id, None)
if token is None:
return
try:
log_context.reset(token)
except ValueError:
# task_prerun and task_postrun ran on different threads (non-default
# Celery pool); the token isn't valid in this context. Drop it.
logging.getLogger(__name__).debug(
"log_context reset skipped for task %s", task_id
)
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.

View File

@@ -1,57 +0,0 @@
"""Per-activity logging context backed by ``contextvars``.
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
they land as first-class attributes on the OTLP log export rather than being
buried inside formatted message bodies.
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
via the token returned by ``bind``.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Mapping
_CTX_KEYS: frozenset[str] = frozenset(
{
"activity_id",
"parent_activity_id",
"user_id",
"agent_id",
"conversation_id",
"endpoint",
"model",
}
)
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
def bind(**kwargs: object) -> Token:
"""Overlay the given keys onto the current context.
Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
stamp a stray field name onto every record), as are ``None`` values
(a missing attribute is more useful than the literal string ``"None"``).
"""
overlay = {
k: str(v)
for k, v in kwargs.items()
if k in _CTX_KEYS and v is not None
}
new = {**_ctx.get(), **overlay}
return _ctx.set(new)
def reset(token: Token) -> None:
"""Restore the context to the snapshot captured by the matching ``bind``."""
_ctx.reset(token)
def snapshot() -> Mapping[str, str]:
"""Return the current context dict. Treat as read-only; use :func:`bind`."""
return _ctx.get()
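
Typical usage of this now-deleted module, for reference: overlay keys for a scope, then restore the prior snapshot in a `finally` block (LIFO via the returned token):

```python
from application.core import log_context

token = log_context.bind(activity_id="a1", user_id="u1")
try:
    inner = log_context.bind(agent_id="agent-7")  # overlays; keeps a1/u1
    try:
        ...  # every LogRecord emitted here carries all three keys
    finally:
        log_context.reset(inner)  # back to {a1, u1}
finally:
    log_context.reset(token)      # back to the original snapshot
```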

View File

@@ -2,36 +2,6 @@ import logging
import os
from logging.config import dictConfig
from application.core.log_context import snapshot as _ctx_snapshot
# Loggers with ``propagate=False`` don't share root's handlers, so the
# context filter has to be installed on their handlers directly.
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
"uvicorn",
"uvicorn.access",
"uvicorn.error",
"celery.app.trace",
"celery.worker.strategy",
"gunicorn.error",
"gunicorn.access",
)
class _ContextFilter(logging.Filter):
"""Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
Must be installed on **handlers**, not loggers: Python skips logger-level
filters when a child logger's record propagates up. The ``hasattr`` guard
keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
"""
def filter(self, record: logging.LogRecord) -> bool:
for key, value in _ctx_snapshot().items():
if not hasattr(record, key):
setattr(record, key, value)
return True
def _otlp_logs_enabled() -> bool:
"""Return True when the user has opted in to OTLP log export.
@@ -90,23 +60,3 @@ def setup_logging() -> None:
for handler in preserved_handlers:
if handler not in root.handlers:
root.addHandler(handler)
_install_context_filter()
def _install_context_filter() -> None:
"""Attach :class:`_ContextFilter` to root's handlers + every handler on
the known non-propagating loggers. Skipping handlers that already carry
one keeps repeat ``setup_logging`` calls from stacking filters.
"""
def _has_ctx_filter(handler: logging.Handler) -> bool:
return any(isinstance(f, _ContextFilter) for f in handler.filters)
for handler in logging.getLogger().handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
for name in _NON_PROPAGATING_LOGGERS:
for handler in logging.getLogger(name).handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
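
The "handlers, not loggers" rule in the removed docstring is stdlib behavior worth seeing in isolation; a minimal sketch (names are illustrative):

```python
import logging

class Stamp(logging.Filter):
    def filter(self, record):
        record.stamped = True
        return True

root = logging.getLogger()
handler = logging.StreamHandler()
root.addHandler(handler)

root.addFilter(Stamp())     # consulted only for records logged via root
handler.addFilter(Stamp())  # consulted for every record the handler sees

# Propagates from the child to root's handler: the handler-level filter
# stamps it, while the logger-level filter on root is skipped entirely.
logging.getLogger("child").warning("hi")
```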

View File

@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
provider_name = "anthropic"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

View File

@@ -1,6 +1,5 @@
import logging
from abc import ABC, abstractmethod
from typing import ClassVar
from application.cache import gen_cache, stream_cache
@@ -11,10 +10,6 @@ logger = logging.getLogger(__name__)
class BaseLLM(ABC):
# Stamped onto the ``llm_stream_start`` event so dashboards can group
# calls by vendor. Subclasses override.
provider_name: ClassVar[str] = "unknown"
def __init__(
self,
decoded_token=None,
@@ -123,26 +118,6 @@ class BaseLLM(ABC):
return args_dict
return {k: v for k, v in args_dict.items() if v is not None}
@staticmethod
def _is_non_retriable_client_error(exc: BaseException) -> bool:
"""4xx errors mean the request itself is malformed — retrying with
a different model fails identically and doubles the work. Only
transient/5xx/connection errors should trigger fallback."""
try:
from google.genai.errors import ClientError as _GenaiClientError
if isinstance(exc, _GenaiClientError):
return True
except ImportError:
pass
for attr in ("status_code", "code", "http_status"):
v = getattr(exc, attr, None)
if isinstance(v, int) and 400 <= v < 500:
return True
resp = getattr(exc, "response", None)
v = getattr(resp, "status_code", None)
return isinstance(v, int) and 400 <= v < 500
def _execute_with_fallback(
self, method_name: str, decorators: list, *args, **kwargs
):
@@ -166,18 +141,12 @@ class BaseLLM(ABC):
if is_stream:
return self._stream_with_fallback(
decorated_method, method_name, decorators, *args, **kwargs
decorated_method, method_name, *args, **kwargs
)
try:
return decorated_method()
except Exception as e:
if self._is_non_retriable_client_error(e):
logger.error(
f"Primary LLM failed with non-retriable client error; "
f"skipping fallback: {str(e)}"
)
raise
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
@@ -187,27 +156,14 @@ class BaseLLM(ABC):
f"{fallback.model_id}. Error: {str(e)}"
)
# Apply decorators to fallback's raw method directly — calling
# fallback.gen() would re-enter the orchestrator and recurse via
# fallback.fallback_llm.
fallback_method = getattr(fallback, method_name)
for decorator in decorators:
fallback_method = decorator(fallback_method)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
try:
return fallback_method(fallback, *args, **fallback_kwargs)
except Exception as e2:
if self._is_non_retriable_client_error(e2):
logger.error(
f"Fallback LLM failed with non-retriable client "
f"error; giving up: {str(e2)}"
)
else:
logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
raise
return fallback_method(*args, **fallback_kwargs)
def _stream_with_fallback(
self, decorated_method, method_name, decorators, *args, **kwargs
self, decorated_method, method_name, *args, **kwargs
):
"""
Wrapper generator that catches mid-stream errors and falls back.
@@ -220,12 +176,6 @@ class BaseLLM(ABC):
try:
yield from decorated_method()
except Exception as e:
if self._is_non_retriable_client_error(e):
logger.error(
f"Primary LLM failed mid-stream with non-retriable client "
f"error; skipping fallback: {str(e)}"
)
raise
if not self.fallback_llm:
logger.error(
f"Primary LLM failed and no fallback configured: {str(e)}"
@@ -236,37 +186,11 @@ class BaseLLM(ABC):
f"Primary LLM failed mid-stream. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
# Apply decorators to fallback's raw stream method directly —
# calling fallback.gen_stream() would re-enter the orchestrator
# and recurse via fallback.fallback_llm. Emit the stream-start
# event manually so dashboards still see the fallback's
# provider/model when the response actually comes from it.
fallback._emit_stream_start_log(
fallback.model_id,
kwargs.get("messages"),
kwargs.get("tools"),
bool(
kwargs.get("_usage_attachments")
or kwargs.get("attachments")
),
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_method = getattr(fallback, method_name)
for decorator in decorators:
fallback_method = decorator(fallback_method)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
try:
yield from fallback_method(fallback, *args, **fallback_kwargs)
except Exception as e2:
if self._is_non_retriable_client_error(e2):
logger.error(
f"Fallback LLM failed mid-stream with non-retriable "
f"client error; giving up: {str(e2)}"
)
else:
logger.error(
f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
)
raise
yield from fallback_method(*args, **fallback_kwargs)
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
@@ -281,58 +205,7 @@ class BaseLLM(ABC):
**kwargs,
)
def _emit_stream_start_log(self, model, messages, tools, has_attachments):
# Stamped with ``self.provider_name`` so dashboards can group calls
# by vendor; the fallback path emits its own copy on the fallback
# instance so the actual responding provider is recorded.
logging.info(
"llm_stream_start",
extra={
"model": model,
"provider": self.provider_name,
"message_count": len(messages) if messages is not None else 0,
"has_attachments": bool(has_attachments),
"has_tools": bool(tools),
},
)
def _emit_stream_finished_log(
self,
model,
*,
prompt_tokens,
completion_tokens,
latency_ms,
cached_tokens=None,
error=None,
):
# Paired with ``llm_stream_start`` so cost dashboards can sum tokens
# by user/agent/provider. Token counts are client-side estimates
# from ``stream_token_usage``; vendor-reported counts (incl.
# ``cached_tokens`` for prompt caching) require per-provider
# extraction in each ``_raw_gen_stream`` and aren't wired yet.
extra = {
"model": model,
"provider": self.provider_name,
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
"latency_ms": int(latency_ms),
"status": "error" if error is not None else "ok",
}
if cached_tokens is not None:
extra["cached_tokens"] = int(cached_tokens)
if error is not None:
extra["error_class"] = type(error).__name__
logging.info("llm_stream_finished", extra=extra)
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
# the ``stream_token_usage`` decorator pops that key, but the log
# fires before the decorator runs so it's still in ``kwargs`` here.
has_attachments = bool(
kwargs.get("_usage_attachments") or kwargs.get("attachments")
)
self._emit_stream_start_log(model, messages, tools, has_attachments)
decorators = [stream_cache, stream_token_usage]
return self._execute_with_fallback(
"_raw_gen_stream",

View File

@@ -6,8 +6,6 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(OpenAILLM):
provider_name = "docsgpt"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,

View File

@@ -10,8 +10,6 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
provider_name = "google"
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
@@ -81,39 +79,24 @@ class GoogleLLM(BaseLLM):
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type not in self.get_supported_attachment_types():
continue
try:
# Images go inline as bytes per Google's guidance for
# requests under 20MB; the Files API can return before
# the upload reaches ACTIVE state and yield an empty URI.
if mime_type.startswith("image/"):
file_bytes = self._read_attachment_bytes(attachment)
files.append(
{"file_bytes": file_bytes, "mime_type": mime_type}
)
else:
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
if not file_uri:
raise ValueError(
f"Google Files API returned empty URI for "
f"{attachment.get('path', 'unknown')}"
)
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(
f"GoogleLLM: Error processing attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
except Exception as e:
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
@@ -129,9 +112,7 @@ class GoogleLLM(BaseLLM):
Returns:
str: Google AI file URI for the uploaded file.
"""
# Truthy check, not membership: a poisoned cache row of "" or
# None must be treated as a miss and trigger a fresh upload.
if attachment.get("google_file_uri"):
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
@@ -145,10 +126,6 @@ class GoogleLLM(BaseLLM):
file=local_path
).uri,
)
if not file_uri:
raise ValueError(
f"Google Files API upload returned empty URI for {file_path}"
)
# Cache the Google file URI on the attachment row so we don't
# re-upload on the next LLM call. Accept either a PG UUID
@@ -182,26 +159,6 @@ class GoogleLLM(BaseLLM):
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
raise
def _read_attachment_bytes(self, attachment):
"""
Read attachment bytes from storage for inline transmission.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
bytes: Raw file bytes.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
return self.storage.process_file(
file_path,
lambda local_path, **kwargs: open(local_path, "rb").read(),
)
def _clean_messages_google(self, messages):
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
@@ -341,24 +298,12 @@ class GoogleLLM(BaseLLM):
)
elif "files" in item:
for file_data in item["files"]:
if "file_bytes" in file_data:
parts.append(
types.Part.from_bytes(
data=file_data["file_bytes"],
mime_type=file_data["mime_type"],
)
)
elif file_data.get("file_uri"):
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
logging.warning(
"GoogleLLM: dropping file part with empty URI and no bytes"
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
@@ -596,6 +541,22 @@ class GoogleLLM(BaseLLM):
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, "file_data") and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(
model=model,
contents=messages,
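
The inline-bytes-vs-URI routing the removed code performs, as a standalone sketch built on the same `google.genai` types the hunk itself references (`to_part` is a hypothetical helper):

```python
from google.genai import types

def to_part(file_data: dict) -> types.Part:
    # Images travel inline (reliable for sub-20MB requests); everything
    # else rides a Files API URI, which must be non-empty to be usable.
    if "file_bytes" in file_data:
        return types.Part.from_bytes(data=file_data["file_bytes"],
                                     mime_type=file_data["mime_type"])
    if file_data.get("file_uri"):
        return types.Part.from_uri(file_uri=file_data["file_uri"],
                                   mime_type=file_data["mime_type"])
    raise ValueError("file part has neither bytes nor a usable URI")
```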

View File

@@ -5,8 +5,6 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
class GroqLLM(OpenAILLM):
provider_name = "groq"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,

View File

@@ -280,26 +280,7 @@ class LLMHandler(ABC):
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
# Image attachments arrive with raw bytes / base64
# inline (see GoogleLLM.prepare_messages_with_attachments).
# ``str(item)`` would dump the whole byte/base64
# blob into the compression prompt and bust the
# compression LLM's input limit.
files = item.get("files") or []
descriptors = []
if isinstance(files, list):
for f in files:
if isinstance(f, dict):
descriptors.append(
f.get("mime_type") or "file"
)
elif isinstance(f, str):
descriptors.append(f)
if not descriptors:
descriptors = ["file"]
parts_text.append(
f"[attachment: {', '.join(descriptors)}]"
)
parts_text.append(str(item))
return "\n".join(parts_text)
return ""

View File

@@ -26,8 +26,6 @@ class LlamaSingleton:
class LlamaCpp(BaseLLM):
provider_name = "llama_cpp"
def __init__(
self,
api_key=None,

View File

@@ -5,8 +5,6 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"
class NovitaLLM(OpenAILLM):
provider_name = "novita"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,

View File

@@ -5,8 +5,6 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
class OpenRouterLLM(OpenAILLM):
provider_name = "openrouter"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,

View File

@@ -61,7 +61,6 @@ def _truncate_base64_for_logging(messages):
class OpenAILLM(BaseLLM):
provider_name = "openai"
def __init__(
self,

View File

@@ -3,7 +3,6 @@ from application.core.settings import settings
class PremAILLM(BaseLLM):
provider_name = "premai"
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from premai import Prem

View File

@@ -59,7 +59,6 @@ class LineIterator:
class SagemakerAPILLM(BaseLLM):
provider_name = "sagemaker"
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
import boto3

View File

@@ -1,13 +1,11 @@
import datetime
import functools
import inspect
import time
import logging
import uuid
from typing import Any, Callable, Dict, Generator, List
from application.core import log_context
from application.storage.db.repositories.stack_logs import StackLogsRepository
from application.storage.db.session import db_session
@@ -24,15 +22,6 @@ class LogContext:
self.api_key = api_key
self.query = query
self.stacks = []
# Per-activity response aggregates populated by ``_consume_and_log``
# while it forwards stream items, then flushed onto the
# ``activity_finished`` event so every Flask request gets the
# same summary that ``run_agent_logic`` used to log only for the
# Celery webhook path.
self.answer_length = 0
self.thought_length = 0
self.source_count = 0
self.tool_call_count = 0
def build_stack_data(
@@ -89,125 +78,25 @@ def log_activity() -> Callable:
user = data.get("user", "local")
api_key = data.get("user_api_key", "")
query = kwargs.get("query", getattr(args[0], "query", ""))
agent_id = getattr(args[0], "agent_id", None) or kwargs.get("agent_id")
conversation_id = (
kwargs.get("conversation_id")
or getattr(args[0], "conversation_id", None)
)
model = getattr(args[0], "gpt_model", None) or getattr(args[0], "model", None)
# Capture the surrounding activity_id before overlaying ours,
# so nested activities record the parent → child link.
parent_activity_id = log_context.snapshot().get("activity_id")
context = LogContext(endpoint, activity_id, user, api_key, query)
kwargs["log_context"] = context
ctx_token = log_context.bind(
activity_id=activity_id,
parent_activity_id=parent_activity_id,
user_id=user,
agent_id=agent_id,
conversation_id=conversation_id,
endpoint=endpoint,
model=model,
)
started_at = time.monotonic()
logging.info(
"activity_started",
extra={
"activity_id": activity_id,
"parent_activity_id": parent_activity_id,
"user_id": user,
"agent_id": agent_id,
"conversation_id": conversation_id,
"endpoint": endpoint,
"model": model,
},
f"Starting activity: {endpoint} - {activity_id} - User: {user}"
)
error: BaseException | None = None
try:
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
except Exception as exc:
# Only ``Exception`` counts as an activity error; ``GeneratorExit``
# (consumer disconnected mid-stream) and ``KeyboardInterrupt``
# flow through the finally as ``status="ok"``, matching
# ``_consume_and_log``.
error = exc
raise
finally:
_emit_activity_finished(
context=context,
parent_activity_id=parent_activity_id,
started_at=started_at,
error=error,
)
log_context.reset(ctx_token)
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
return wrapper
return decorator
def _emit_activity_finished(
*,
context: "LogContext",
parent_activity_id: str | None,
started_at: float,
error: BaseException | None,
) -> None:
"""Emit the paired ``activity_finished`` event with duration, outcome,
and per-activity response aggregates accumulated in ``_consume_and_log``.
"""
duration_ms = int((time.monotonic() - started_at) * 1000)
logging.info(
"activity_finished",
extra={
"activity_id": context.activity_id,
"parent_activity_id": parent_activity_id,
"user_id": context.user,
"endpoint": context.endpoint,
"duration_ms": duration_ms,
"status": "error" if error is not None else "ok",
"error_class": type(error).__name__ if error is not None else None,
"answer_length": context.answer_length,
"thought_length": context.thought_length,
"source_count": context.source_count,
"tool_call_count": context.tool_call_count,
},
)
def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
"""Mirror the per-line aggregation that ``run_agent_logic`` did for the
Celery webhook path, but at the generator-consumption layer so every
``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
gets the same summary.
"""
if not isinstance(item, dict):
return
if "answer" in item:
context.answer_length += len(str(item["answer"]))
return
if "thought" in item:
context.thought_length += len(str(item["thought"]))
return
sources = item.get("sources") if "sources" in item else None
if isinstance(sources, list):
context.source_count += len(sources)
return
tool_calls = item.get("tool_calls") if "tool_calls" in item else None
if isinstance(tool_calls, list):
context.tool_call_count += len(tool_calls)
def _consume_and_log(generator: Generator, context: "LogContext"):
try:
for item in generator:
_accumulate_response_summary(item, context)
yield item
except Exception as e:
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
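
Reduced to essentials, the removed decorator is a paired-event pattern: capture the parent id before overlaying, time with `time.monotonic()`, and emit the finished event from `finally` so errors are counted too. A sketch assuming the `log_context` module shown earlier in this diff (`run_activity` is a hypothetical stand-in for the decorator):

```python
import logging
import time
import uuid

from application.core import log_context

def run_activity(endpoint, work):
    parent = log_context.snapshot().get("activity_id")  # before overlay
    token = log_context.bind(activity_id=str(uuid.uuid4()),
                             parent_activity_id=parent,
                             endpoint=endpoint)
    started = time.monotonic()
    error = None
    try:
        return work()
    except Exception as exc:
        error = exc
        raise
    finally:
        logging.info("activity_finished", extra={
            "duration_ms": int((time.monotonic() - started) * 1000),
            "status": "error" if error is not None else "ok",
        })
        log_context.reset(token)
```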

View File

@@ -17,21 +17,6 @@ _UPDATABLE_SCALARS = {
_UPDATABLE_JSONB = {"metadata"}
def _attachment_to_dict(row: Any) -> dict:
"""row_to_dict + ``upload_path``→``path`` alias.
Pre-Postgres, the Mongo attachment shape used ``path``. The PG column
is ``upload_path``; LLM provider code (google_ai/openai/anthropic and
handlers/base) still reads ``attachment.get("path")``. Mirroring the
``id``/``_id`` dual-emit in row_to_dict so consumers don't need to
know which storage backend produced the dict.
"""
out = row_to_dict(row)
if "upload_path" in out and out.get("path") is None:
out["path"] = out["upload_path"]
return out
class AttachmentsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
@@ -81,7 +66,7 @@ class AttachmentsRepository:
"legacy_mongo_id": legacy_mongo_id,
},
)
return _attachment_to_dict(result.fetchone())
return row_to_dict(result.fetchone())
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
@@ -91,7 +76,7 @@ class AttachmentsRepository:
{"id": attachment_id, "user_id": user_id},
)
row = result.fetchone()
return _attachment_to_dict(row) if row is not None else None
return row_to_dict(row) if row is not None else None
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
@@ -170,14 +155,14 @@ class AttachmentsRepository:
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return _attachment_to_dict(row) if row is not None else None
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
{"user_id": user_id},
)
return [_attachment_to_dict(r) for r in result.fetchall()]
return [row_to_dict(r) for r in result.fetchall()]
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
"""Partial update. Used by the LLM providers to cache their

View File

@@ -1,6 +1,5 @@
import sys
import logging
import time
from datetime import datetime
from application.storage.db.repositories.token_usage import TokenUsageRepository
@@ -21,15 +20,6 @@ def _serialize_for_token_count(value):
if value is None:
return ""
# Raw binary payloads (image/file attachments arrive as ``bytes`` from
# ``GoogleLLM.prepare_messages_with_attachments``) — without this
# branch they fall through to ``str(value)`` below, which produces a
# multi-megabyte ``"b'\\x89PNG...'"`` repr-string and inflates
# ``prompt_tokens`` by orders of magnitude. Same intent as the
# data-URL skip above.
if isinstance(value, (bytes, bytearray, memoryview)):
return ""
if isinstance(value, list):
return [_serialize_for_token_count(item) for item in value]
@@ -155,44 +145,19 @@ def stream_token_usage(func):
**kwargs,
)
batch = []
started_at = time.monotonic()
error: BaseException | None = None
try:
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
except Exception as exc:
# ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
# flow through as ``status="ok"`` — same convention as
# ``application.logging._consume_and_log``.
error = exc
raise
finally:
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
# Persist usage rows only on success: a partial mid-stream
# failure shouldn't bill the user for a response they never got.
if error is None:
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
emit = getattr(self, "_emit_stream_finished_log", None)
if callable(emit):
try:
emit(
model,
prompt_tokens=call_usage["prompt_tokens"],
completion_tokens=call_usage["generated_tokens"],
latency_ms=int((time.monotonic() - started_at) * 1000),
error=error,
)
except Exception:
logger.exception("Failed to emit llm_stream_finished")
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return wrapper
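
The removed version of this wrapper forwards chunks, accounts in `finally` so a mid-stream failure is still measured, and persists usage only on success. A reduced sketch (`persist_usage` and `emit_finished` are hypothetical stand-ins for `update_token_usage` and `_emit_stream_finished_log`):

```python
import time

def metered(stream_fn, persist_usage, emit_finished):
    def wrapper(*args, **kwargs):
        chunks, error = [], None
        started = time.monotonic()
        try:
            for chunk in stream_fn(*args, **kwargs):
                chunks.append(chunk)
                yield chunk
        except Exception as exc:
            error = exc
            raise
        finally:
            if error is None:
                persist_usage(chunks)  # never bill a partial stream
            emit_finished(len(chunks),
                          int((time.monotonic() - started) * 1000), error)
    return wrapper
```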

View File

@@ -432,10 +432,7 @@ def run_agent_logic(agent_config, input_data):
"tool_calls": tool_calls,
"thought": thought,
}
# Per-activity summary fields (answer_length, thought_length,
# source_count, tool_call_count) now ride on the inner
# ``activity_finished`` event emitted by ``log_activity`` around
# ``Agent.gen`` above; no separate ``agent_response`` log needed.
logging.info(f"Agent response: {result}")
return result
except Exception as e:
logging.error(f"Error in run_agent_logic: {e}", exc_info=True)

View File

@@ -1,161 +0,0 @@
"""Unit tests for ``log_context`` + ``_ContextFilter``."""
from __future__ import annotations
import logging
import pytest
from application.core import log_context
from application.core.logging_config import _ContextFilter
@pytest.fixture(autouse=True)
def _clean_log_ctx():
# The contextvar is module-scoped; snapshot at entry, restore at exit
# to keep tests from leaking state into each other.
token = log_context.bind()
try:
yield
finally:
log_context.reset(token)
@pytest.mark.unit
class TestBindAndSnapshot:
def test_bind_returns_token_and_snapshot_reflects_overlay(self):
token = log_context.bind(activity_id="a1", user_id="u1")
assert log_context.snapshot() == {"activity_id": "a1", "user_id": "u1"}
log_context.reset(token)
assert log_context.snapshot() == {}
def test_bind_drops_unknown_keys(self):
token = log_context.bind(activity_id="a1", not_a_real_key="boom")
try:
assert log_context.snapshot() == {"activity_id": "a1"}
finally:
log_context.reset(token)
def test_bind_drops_none_values(self):
token = log_context.bind(activity_id="a1", agent_id=None)
try:
assert "agent_id" not in log_context.snapshot()
finally:
log_context.reset(token)
def test_bind_coerces_values_to_str(self):
token = log_context.bind(activity_id=42)
try:
assert log_context.snapshot()["activity_id"] == "42"
finally:
log_context.reset(token)
def test_nested_bind_overlays_and_resets_lifo(self):
outer = log_context.bind(activity_id="outer", user_id="u1")
inner = log_context.bind(activity_id="inner", agent_id="agent-1")
# Inner overrides activity_id, keeps user_id from outer, adds agent_id.
assert log_context.snapshot() == {
"activity_id": "inner",
"user_id": "u1",
"agent_id": "agent-1",
}
log_context.reset(inner)
assert log_context.snapshot() == {"activity_id": "outer", "user_id": "u1"}
log_context.reset(outer)
assert log_context.snapshot() == {}
def test_parent_activity_id_pattern(self):
outer = log_context.bind(activity_id="parent-1")
parent = log_context.snapshot().get("activity_id")
inner = log_context.bind(activity_id="child-1", parent_activity_id=parent)
try:
snap = log_context.snapshot()
assert snap["activity_id"] == "child-1"
assert snap["parent_activity_id"] == "parent-1"
finally:
log_context.reset(inner)
log_context.reset(outer)
@pytest.mark.unit
class TestContextFilter:
def _make_record(self, **extra) -> logging.LogRecord:
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname=__file__,
lineno=1,
msg="hello",
args=(),
exc_info=None,
)
for k, v in extra.items():
setattr(record, k, v)
return record
def test_stamps_record_with_context(self):
token = log_context.bind(activity_id="a1", user_id="u1")
try:
record = self._make_record()
assert _ContextFilter().filter(record) is True
assert record.activity_id == "a1"
assert record.user_id == "u1"
finally:
log_context.reset(token)
def test_explicit_extra_wins_over_context(self):
# extra={} sets attributes on the record before the filter runs;
# the filter must not overwrite them.
token = log_context.bind(activity_id="from-ctx")
try:
record = self._make_record(activity_id="from-extra")
_ContextFilter().filter(record)
assert record.activity_id == "from-extra"
finally:
log_context.reset(token)
def test_no_op_when_context_empty(self):
record = self._make_record()
assert _ContextFilter().filter(record) is True
assert not hasattr(record, "activity_id")
@pytest.mark.unit
class TestFilterWiringEndToEnd:
"""Regression guard: the filter must be installed on handlers, not on
loggers — Python skips logger-level filters during propagation.
"""
def test_propagated_record_gets_stamped(self):
from application.core.logging_config import _install_context_filter
captured: list[logging.LogRecord] = []
class _Capture(logging.Handler):
def emit(self, record):
captured.append(record)
root = logging.getLogger()
saved_handlers = list(root.handlers)
saved_level = root.level
try:
root.handlers = [_Capture()]
root.setLevel(logging.DEBUG)
_install_context_filter()
child = logging.getLogger("test_log_context.propagation")
child.setLevel(logging.DEBUG)
token = log_context.bind(activity_id="propagated-id")
try:
child.info("from a child logger")
finally:
log_context.reset(token)
assert captured, "Capture handler should have received the record"
assert captured[0].activity_id == "propagated-id"
finally:
root.handlers = saved_handlers
root.setLevel(saved_level)

View File

@@ -360,38 +360,7 @@ class TestExtractTextFromContent:
handler = ConcreteHandler()
content = [{"files": ["/tmp/a.txt"]}]
result = handler._extract_text_from_content(content)
assert result == "[attachment: /tmp/a.txt]"
def test_list_with_inline_image_bytes(self):
# Google attaches images as inline bytes; stringifying them into
# the compression prompt would bust the compression LLM's input
# limit. The placeholder must describe the attachment without
# embedding the bytes.
handler = ConcreteHandler()
content = [
{
"files": [
{"file_bytes": b"\x89PNG" + b"\x00" * 1000, "mime_type": "image/png"}
]
}
]
result = handler._extract_text_from_content(content)
assert result == "[attachment: image/png]"
assert "PNG" not in result
assert "\\x" not in result
def test_list_with_multiple_files(self):
handler = ConcreteHandler()
content = [
{
"files": [
{"file_bytes": b"a", "mime_type": "image/png"},
{"file_uri": "https://x", "mime_type": "image/jpeg"},
]
}
]
result = handler._extract_text_from_content(content)
assert result == "[attachment: image/png, image/jpeg]"
assert "files" in result
def test_list_with_none_text(self):
handler = ConcreteHandler()

View File

@@ -49,9 +49,6 @@ class FailingLLM(BaseLLM):
class FallbackLLM(BaseLLM):
# _execute_with_fallback applies decorators to the fallback's raw method
# directly and never calls .gen() / .gen_stream() on it, so
# tracking lives on the raw methods.
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.gen_called = False
@@ -65,6 +62,14 @@ class FallbackLLM(BaseLLM):
self.gen_stream_called = True
yield "fallback_chunk"
def gen(self, *args, **kwargs):
self.gen_called = True
return "fallback_gen_result"
def gen_stream(self, *args, **kwargs):
self.gen_stream_called = True
yield "fallback_stream_chunk"
# ---------------------------------------------------------------------------
# gen / gen_stream decorator application
@@ -90,180 +95,6 @@ class TestGenMethods:
)
assert result == ["a", "b"]
@patch("application.llm.base.stream_cache", lambda f: f)
@patch("application.llm.base.stream_token_usage", lambda f: f)
def test_gen_stream_emits_llm_stream_start_event(self, caplog):
import logging as _logging
class FakeProvider(StubLLM):
provider_name = "fake-provider"
llm = FakeProvider(raw_gen_stream_items=["x"])
with caplog.at_level(_logging.INFO, logger="root"):
list(
llm.gen_stream(
model="m1",
messages=[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hey"}],
tools=[{"name": "t"}],
_usage_attachments=[{"path": "/tmp/a.png"}],
)
)
starts = [r for r in caplog.records if r.message == "llm_stream_start"]
assert len(starts) == 1
evt = starts[0]
assert evt.model == "m1"
assert evt.provider == "fake-provider"
assert evt.message_count == 2
# ``_usage_attachments`` is what ``Agent._llm_gen`` actually passes;
# the alias check below covers the bare ``attachments=`` form.
assert evt.has_attachments is True
assert evt.has_tools is True
@patch("application.llm.base.stream_cache", lambda f: f)
@patch("application.llm.base.stream_token_usage", lambda f: f)
def test_gen_stream_recognises_attachments_kwarg_alias(self, caplog):
import logging as _logging
llm = StubLLM(raw_gen_stream_items=["x"])
with caplog.at_level(_logging.INFO, logger="root"):
list(
llm.gen_stream(
model="m1", messages=[], attachments=[{"path": "/tmp/a"}]
)
)
evt = next(r for r in caplog.records if r.message == "llm_stream_start")
assert evt.has_attachments is True
@patch("application.llm.base.stream_cache", lambda f: f)
@patch("application.llm.base.stream_token_usage", lambda f: f)
def test_gen_stream_emits_event_without_attachments_or_tools(self, caplog):
import logging as _logging
llm = StubLLM(raw_gen_stream_items=["x"])
with caplog.at_level(_logging.INFO, logger="root"):
list(llm.gen_stream(model="m1", messages=[]))
evt = next(r for r in caplog.records if r.message == "llm_stream_start")
assert evt.message_count == 0
assert evt.has_attachments is False
assert evt.has_tools is False
# BaseLLM default — concrete providers always override.
assert evt.provider == "unknown"
@patch("application.llm.base.stream_cache", lambda f: f)
def test_gen_stream_emits_llm_stream_finished_on_success(self, caplog):
# Real ``stream_token_usage`` so the emit-from-finally path runs.
# ``update_token_usage`` short-circuits under pytest, so no DB
# mocking is needed.
import logging as _logging
class FakeProvider(StubLLM):
provider_name = "fake-provider"
llm = FakeProvider(raw_gen_stream_items=["alpha", "beta"])
llm.user_api_key = None
with caplog.at_level(_logging.INFO, logger="root"):
list(
llm.gen_stream(
model="m1",
messages=[{"role": "user", "content": "hi"}],
)
)
finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
assert len(finished) == 1
evt = finished[0]
assert evt.model == "m1"
assert evt.provider == "fake-provider"
assert evt.status == "ok"
assert isinstance(evt.prompt_tokens, int) and evt.prompt_tokens >= 0
assert isinstance(evt.completion_tokens, int) and evt.completion_tokens > 0
assert isinstance(evt.latency_ms, int) and evt.latency_ms >= 0
# ``cached_tokens`` is intentionally absent until per-provider
# vendor-usage extraction lands.
assert not hasattr(evt, "cached_tokens")
assert not hasattr(evt, "error_class")
@patch("application.llm.base.stream_cache", lambda f: f)
def test_gen_stream_emits_llm_stream_finished_on_error(self, caplog):
import logging as _logging
class FakeProvider(BaseLLM):
provider_name = "fake-provider"
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kw):
return "x"
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kw):
yield "partial"
raise RuntimeError("mid_stream_boom")
llm = FakeProvider()
llm.user_api_key = None
with caplog.at_level(_logging.INFO, logger="root"), pytest.raises(RuntimeError):
list(llm.gen_stream(model="m1", messages=[]))
finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
assert len(finished) == 1
evt = finished[0]
assert evt.status == "error"
assert evt.error_class == "RuntimeError"
# Partial completion tokens still recorded (the chunk yielded
# before the failure is in the batch).
assert evt.completion_tokens > 0
@patch("application.llm.base.stream_cache", lambda f: f)
def test_gen_stream_finished_event_paired_with_stream_start(self, caplog):
# The two events form a pair the cost dashboards join on; verify
# they always come in order and from the same provider/model.
import logging as _logging
class FakeProvider(StubLLM):
provider_name = "fake-provider"
llm = FakeProvider(raw_gen_stream_items=["x"])
llm.user_api_key = None
with caplog.at_level(_logging.INFO, logger="root"):
list(llm.gen_stream(model="m1", messages=[]))
records = [
r for r in caplog.records
if r.message in ("llm_stream_start", "llm_stream_finished")
]
assert [r.message for r in records] == ["llm_stream_start", "llm_stream_finished"]
assert records[0].model == records[1].model == "m1"
assert records[0].provider == records[1].provider == "fake-provider"
@pytest.mark.unit
class TestProviderNameRegistry:
"""A new provider without ``provider_name`` would silently report
``provider="unknown"`` in telemetry. Pin the expected values here."""
def test_provider_names_match_expectations(self):
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.open_router import OpenRouterLLM
from application.llm.openai import OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
assert OpenAILLM.provider_name == "openai"
assert GoogleLLM.provider_name == "google"
assert AnthropicLLM.provider_name == "anthropic"
assert GroqLLM.provider_name == "groq"
assert NovitaLLM.provider_name == "novita"
assert OpenRouterLLM.provider_name == "openrouter"
assert DocsGPTAPILLM.provider_name == "docsgpt"
assert PremAILLM.provider_name == "premai"
assert LlamaCpp.provider_name == "llama_cpp"
assert SagemakerAPILLM.provider_name == "sagemaker"
@patch("application.llm.base.gen_cache", lambda f: f)
@patch("application.llm.base.gen_token_usage", lambda f: f)
def test_gen_passes_tools(self):
@@ -309,7 +140,7 @@ class TestExecuteWithFallbackNonStreaming:
llm._fallback_llm = fallback
result = llm.gen(model="m", messages=[])
assert result == "fallback_result"
assert result == "fallback_gen_result"
assert fallback.gen_called
@@ -336,97 +167,10 @@ class TestStreamWithFallback:
llm._fallback_llm = fallback
result = list(llm.gen_stream(model="m", messages=[]))
assert "fallback_chunk" in result
assert "fallback_stream_chunk" in result
assert fallback.gen_stream_called
# ---------------------------------------------------------------------------
# Non-retriable client error guard
# ---------------------------------------------------------------------------
class _StatusError(Exception):
"""Mimics openai/anthropic-shaped client errors with a status_code."""
def __init__(self, status_code, message="bad request"):
super().__init__(message)
self.status_code = status_code
class _ClientErrorLLM(BaseLLM):
def __init__(self, status_code, **kwargs):
super().__init__(**kwargs)
self._status = status_code
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kw):
raise _StatusError(self._status)
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kw):
raise _StatusError(self._status)
@pytest.mark.unit
class TestNonRetriableClientError:
def test_helper_detects_4xx_status_code(self):
assert BaseLLM._is_non_retriable_client_error(_StatusError(400))
assert BaseLLM._is_non_retriable_client_error(_StatusError(404))
assert BaseLLM._is_non_retriable_client_error(_StatusError(429))
def test_helper_passes_5xx_through(self):
# 5xx and connection errors should still trigger fallback.
assert not BaseLLM._is_non_retriable_client_error(_StatusError(500))
assert not BaseLLM._is_non_retriable_client_error(_StatusError(503))
assert not BaseLLM._is_non_retriable_client_error(RuntimeError("oops"))
def test_helper_detects_genai_client_error(self):
try:
from google.genai.errors import ClientError
except ImportError:
pytest.skip("google-genai not installed")
# ClientError(code, response_json, response=None)
exc = ClientError(400, {"error": {"message": "bad", "code": 400}}, None)
assert BaseLLM._is_non_retriable_client_error(exc)
def test_helper_detects_response_status_code(self):
exc = RuntimeError("wrapped")
exc.response = type("R", (), {"status_code": 401})()
assert BaseLLM._is_non_retriable_client_error(exc)
@patch("application.llm.base.gen_cache", lambda f: f)
@patch("application.llm.base.gen_token_usage", lambda f: f)
def test_4xx_skips_fallback(self):
fallback = FallbackLLM(model_id="fallback-model")
llm = _ClientErrorLLM(status_code=400)
llm._fallback_llm = fallback
with pytest.raises(_StatusError):
llm.gen(model="m", messages=[])
assert not fallback.gen_called
@patch("application.llm.base.stream_cache", lambda f: f)
@patch("application.llm.base.stream_token_usage", lambda f: f)
def test_4xx_skips_stream_fallback(self):
fallback = FallbackLLM(model_id="fallback-model")
llm = _ClientErrorLLM(status_code=400)
llm._fallback_llm = fallback
with pytest.raises(_StatusError):
list(llm.gen_stream(model="m", messages=[]))
assert not fallback.gen_stream_called
@patch("application.llm.base.gen_cache", lambda f: f)
@patch("application.llm.base.gen_token_usage", lambda f: f)
def test_5xx_still_falls_back(self):
fallback = FallbackLLM(model_id="fallback-model")
llm = _ClientErrorLLM(status_code=503)
llm._fallback_llm = fallback
result = llm.gen(model="m", messages=[])
assert result == "fallback_result"
assert fallback.gen_called
# ---------------------------------------------------------------------------
# fallback_llm property: backup model resolution
# ---------------------------------------------------------------------------

View File

@@ -31,20 +31,12 @@ class FakeLLM(BaseLLM):
self.gen_stream_called = False
self.last_model_received = None # tracks the model kwarg passed to gen/gen_stream
# Track at the raw-method level. _execute_with_fallback applies
# decorators to the fallback's raw method directly and
# never calls .gen() / .gen_stream() on it, so a public-method
# override would not register fallback hops.
def _raw_gen(self, baseself, model, messages, stream, tools=None, **kwargs):
self.gen_called = True
self.last_model_received = model
if self.fail_at is not None:
raise RuntimeError("primary model unavailable")
return self.responses[0]
def _raw_gen_stream(self, baseself, model, messages, stream, tools=None, **kwargs):
self.gen_stream_called = True
self.last_model_received = model
yielded = 0
for chunk in self.stream_chunks:
if self.fail_at is not None and yielded >= self.fail_at:
@@ -52,6 +44,18 @@ class FakeLLM(BaseLLM):
yield chunk
yielded += 1
# Wrap gen/gen_stream so we can track whether the fallback instance was used
# and which model kwarg it received
def gen(self, *args, **kwargs):
self.gen_called = True
self.last_model_received = kwargs.get("model")
return super().gen(*args, **kwargs)
def gen_stream(self, *args, **kwargs):
self.gen_stream_called = True
self.last_model_received = kwargs.get("model")
return super().gen_stream(*args, **kwargs)
# Helpers
@@ -290,140 +294,6 @@ class TestStreamingFallback:
with pytest.raises(RuntimeError, match="mid-stream failure"):
list(primary.gen_stream(**CALL_ARGS))
def test_fallback_emits_stream_start_with_fallback_provider(
self, patch_model_utils, caplog
):
# The fallback raw-stream path bypasses ``gen_stream``, so it must
# emit its own ``llm_stream_start`` event tagged with the fallback
# vendor — otherwise dashboards record only the failed primary
# even when the response came from the backup.
import logging as _logging
class FallbackProvider(FakeLLM):
provider_name = "fallback-vendor"
backup = FallbackProvider(
stream_chunks=["b1"], model_id="backup-model-id"
)
patch_model_utils(
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
class PrimaryProvider(FakeLLM):
provider_name = "primary-vendor"
primary = PrimaryProvider(
stream_chunks=["x"],
fail_at=0,
backup_models=["backup-model-id"],
)
with caplog.at_level(_logging.INFO, logger="root"):
list(
primary.gen_stream(
model="primary-model",
messages=[{"role": "user", "content": "hi"}],
)
)
starts = [r for r in caplog.records if r.message == "llm_stream_start"]
assert len(starts) == 2
assert starts[0].provider == "primary-vendor"
assert starts[0].model == "primary-model"
assert starts[1].provider == "fallback-vendor"
assert starts[1].model == "backup-model-id"
# Tests — fallback never re-enters the orchestrator (Option B regression)
@pytest.mark.integration
class TestFallbackNoRecursion:
"""When the primary fails, _execute_with_fallback applies decorators to
the fallback's raw method directly. The fallback's own ``fallback_llm``
property must never be accessed — otherwise a fallback failure would
re-enter the orchestrator and walk the global FALLBACK_LLM_* chain
unboundedly."""
def test_backup_fallback_llm_property_never_accessed_on_gen_failure(
self, monkeypatch, patch_model_utils
):
backup = FakeLLM(fail_at=0) # backup also fails
accessed_on = []
original_property = BaseLLM.fallback_llm
def tracked_fallback_llm(self_llm):
accessed_on.append(self_llm)
return original_property.fget(self_llm)
monkeypatch.setattr(
BaseLLM, "fallback_llm", property(tracked_fallback_llm)
)
patch_model_utils(
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
primary = FakeLLM(fail_at=0, backup_models=["backup-model"])
with pytest.raises(RuntimeError, match="primary model unavailable"):
primary.gen(**CALL_ARGS)
assert primary in accessed_on # primary lazy-loaded its fallback
assert backup not in accessed_on # backup's chain was never walked
def test_backup_fallback_llm_property_never_accessed_on_stream_failure(
self, monkeypatch, patch_model_utils
):
backup = FakeLLM(stream_chunks=["x"], fail_at=0)
accessed_on = []
original_property = BaseLLM.fallback_llm
def tracked_fallback_llm(self_llm):
accessed_on.append(self_llm)
return original_property.fget(self_llm)
monkeypatch.setattr(
BaseLLM, "fallback_llm", property(tracked_fallback_llm)
)
patch_model_utils(
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
primary = FakeLLM(
stream_chunks=["y"], fail_at=0, backup_models=["backup-model"]
)
with pytest.raises(RuntimeError, match="mid-stream failure"):
list(primary.gen_stream(**CALL_ARGS))
assert primary in accessed_on
assert backup not in accessed_on
def test_fallback_failure_propagates_without_chain(self, patch_model_utils):
"""When both primary and fallback fail, the fallback's exception
propagates cleanly — no third hop, no extra retries."""
backup = FakeLLM(fail_at=0)
patch_model_utils(
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
primary = FakeLLM(fail_at=0, backup_models=["backup-model"])
with pytest.raises(RuntimeError, match="primary model unavailable"):
primary.gen(**CALL_ARGS)
assert backup.gen_called # confirms fallback raw method WAS invoked
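The invariant these tests pin down, reduced to a sketch with illustrative callables:

def one_hop_fallback_sketch(primary_call, backup_call):
    # Exactly one backup attempt; a failing backup propagates instead of
    # walking the global FALLBACK_LLM_* chain.
    try:
        return primary_call()
    except Exception:
        return backup_call()  # any exception here escapes to the caller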
# Tests — backup model priority over global fallback

View File

@@ -28,11 +28,10 @@ from application.llm.google_ai import GoogleLLM
class _FakePart:
def __init__(self, text=None, function_call=None, file_data=None, inline_data=None, thought=False, **kwargs):
def __init__(self, text=None, function_call=None, file_data=None, thought=False, **kwargs):
self.text = text
self.function_call = function_call or kwargs.get("functionCall")
self.file_data = file_data
self.inline_data = inline_data
self.thought = thought
self.thoughtSignature = kwargs.get("thoughtSignature")
@@ -54,12 +53,6 @@ class _FakePart:
file_data=types.SimpleNamespace(file_uri=file_uri, mime_type=mime_type)
)
@staticmethod
def from_bytes(data, mime_type):
return _FakePart(
inline_data=types.SimpleNamespace(data=data, mime_type=mime_type)
)
class _FakeContent:
def __init__(self, role, parts):
@@ -230,43 +223,6 @@ class TestCleanMessagesGoogle:
for p in cleaned[0].parts
)
def test_files_with_inline_bytes(self, llm):
msgs = [
{
"role": "user",
"content": [
{
"files": [
{"file_bytes": b"\x89PNG", "mime_type": "image/png"}
]
},
],
}
]
cleaned, _ = llm._clean_messages_google(msgs)
assert len(cleaned) == 1
inline_parts = [
p for p in cleaned[0].parts
if getattr(p, "inline_data", None) is not None
]
assert len(inline_parts) == 1
assert inline_parts[0].inline_data.data == b"\x89PNG"
assert inline_parts[0].inline_data.mime_type == "image/png"
def test_files_with_empty_uri_dropped(self, llm):
msgs = [
{
"role": "user",
"content": [
{"files": [{"file_uri": "", "mime_type": "image/png"}]},
],
}
]
cleaned, _ = llm._clean_messages_google(msgs)
        # The empty-URI part is dropped; with no other parts the content
        # is empty, so the message itself is never appended.
assert cleaned == []
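A sketch of the file-part branching the two tests above imply, using plain dicts for the parts; `part_from_file_sketch` is hypothetical, not the GoogleLLM method.

def part_from_file_sketch(entry):
    # Inline bytes win; a non-empty URI is kept; an empty URI yields no
    # part at all (and a message with no surviving parts is dropped).
    if entry.get("file_bytes"):
        return {"inline_data": {"data": entry["file_bytes"],
                                "mime_type": entry["mime_type"]}}
    if entry.get("file_uri"):
        return {"file_data": {"file_uri": entry["file_uri"],
                              "mime_type": entry["mime_type"]}}
    return None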
def test_unexpected_list_item_raises(self, llm):
msgs = [{"role": "user", "content": [{"unknown_key": "val"}]}]
with pytest.raises(ValueError, match="Unexpected content dictionary"):
@@ -754,9 +710,7 @@ class TestPrepareMessagesWithAttachments:
def test_upload_error_adds_text_fallback(self, llm, monkeypatch):
monkeypatch.setattr(
llm,
"_read_attachment_bytes",
lambda a: (_ for _ in ()).throw(Exception("fail")),
llm, "_upload_file_to_google", lambda a: (_ for _ in ()).throw(Exception("fail"))
)
msgs = [{"role": "user", "content": "hi"}]
attachments = [
@@ -770,57 +724,8 @@ class TestPrepareMessagesWithAttachments:
]
assert len(text_parts) == 1
def test_pdf_upload_error_adds_text_fallback(self, llm, monkeypatch):
monkeypatch.setattr(
llm,
"_upload_file_to_google",
lambda a: (_ for _ in ()).throw(Exception("fail")),
)
msgs = [{"role": "user", "content": "hi"}]
attachments = [
{"mime_type": "application/pdf", "path": "/tmp/doc.pdf", "content": "x"},
]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
text_parts = [
p for p in user_msg["content"]
if isinstance(p, dict) and p.get("type") == "text" and "could not" in p.get("text", "").lower()
]
assert len(text_parts) == 1
def test_pdf_empty_uri_adds_text_fallback(self, llm, monkeypatch):
monkeypatch.setattr(llm, "_upload_file_to_google", lambda a: "")
msgs = [{"role": "user", "content": "hi"}]
attachments = [
{"mime_type": "application/pdf", "path": "/tmp/doc.pdf", "content": "x"},
]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
files_entries = [
p for p in user_msg["content"] if isinstance(p, dict) and "files" in p
]
assert files_entries == []
text_parts = [
p for p in user_msg["content"]
if isinstance(p, dict) and p.get("type") == "text" and "could not" in p.get("text", "").lower()
]
assert len(text_parts) == 1
def test_image_uses_inline_bytes(self, llm, monkeypatch):
monkeypatch.setattr(llm, "_read_attachment_bytes", lambda a: b"\x89PNG-bytes")
msgs = [{"role": "user", "content": "hi"}]
attachments = [{"mime_type": "image/png", "path": "/img.png"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
user_msg = next(m for m in result if m["role"] == "user")
files_entry = next(
p for p in user_msg["content"] if isinstance(p, dict) and "files" in p
)
assert files_entry["files"] == [
{"file_bytes": b"\x89PNG-bytes", "mime_type": "image/png"}
]
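Taken together, the fallback and inline-bytes tests suggest routing like the sketch below; `attach_sketch` is a hypothetical condensation, not the real `prepare_messages_with_attachments`.

def attach_sketch(llm, attachment):
    # Images are read as raw bytes; PDFs go through upload. Any failure,
    # or an empty upload URI, degrades to a "could not ..." text part
    # instead of a files entry.
    try:
        if attachment["mime_type"].startswith("image/"):
            data = llm._read_attachment_bytes(attachment)
            return {"files": [{"file_bytes": data,
                               "mime_type": attachment["mime_type"]}]}
        uri = llm._upload_file_to_google(attachment)
        if not uri:
            raise ValueError("empty URI")
        return {"files": [{"file_uri": uri,
                           "mime_type": attachment["mime_type"]}]}
    except Exception:
        return {"type": "text",
                "text": "The attached file could not be processed."}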
def test_no_user_message_creates_one(self, llm, monkeypatch):
monkeypatch.setattr(llm, "_read_attachment_bytes", lambda a: b"png")
monkeypatch.setattr(llm, "_upload_file_to_google", lambda a: "gs://uri")
msgs = [{"role": "system", "content": "sys"}]
attachments = [{"mime_type": "image/png", "path": "/img.png"}]
result = llm.prepare_messages_with_attachments(msgs, attachments)
@@ -841,26 +746,6 @@ class TestUploadFileToGoogle:
result = llm._upload_file_to_google(attachment)
assert result == "gs://cached"
def test_empty_cached_uri_triggers_reupload(self, llm, monkeypatch):
        # Poisoned-cache repro: an empty-string google_file_uri must be
        # treated as a cache miss and trigger a re-upload, not be
        # returned as-is.
monkeypatch.setattr(
"application.llm.google_ai.settings",
types.SimpleNamespace(GOOGLE_API_KEY="k", API_KEY="k"),
)
result = llm._upload_file_to_google(
{"google_file_uri": "", "path": "/tmp/file.pdf"}
)
assert result == "gs://fake-uri"
def test_empty_upload_uri_raises(self, llm):
llm.storage = types.SimpleNamespace(
file_exists=lambda p: True,
process_file=lambda path, fn, **kw: "",
)
with pytest.raises(ValueError, match="empty URI"):
llm._upload_file_to_google({"path": "/tmp/file.pdf"})
def test_raises_for_no_path(self, llm):
with pytest.raises(ValueError, match="No file path"):
llm._upload_file_to_google({})
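Read as a group, these tests pin the following contract; the sketch is a hypothetical condensation and the processor callback is illustrative.

def upload_contract_sketch(llm, attachment):
    cached = attachment.get("google_file_uri")
    if cached:  # "" is falsy, so a poisoned cache re-uploads
        return cached
    path = attachment.get("path")
    if not path:
        raise ValueError("No file path")
    uri = llm.storage.process_file(path, lambda **kw: None)
    if not uri:
        raise ValueError("uploader returned an empty URI")
    return uri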

View File

@@ -228,7 +228,6 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch):
process_file=lambda path, processor_func, **kwargs: "gs://file_uri"
)
monkeypatch.setattr(llm, "_upload_file_to_google", lambda att: "gs://file_uri")
monkeypatch.setattr(llm, "_read_attachment_bytes", lambda att: b"png-bytes")
messages = [{"role": "user", "content": "Hi"}]
attachments = [
@@ -241,9 +240,4 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch):
assert isinstance(user_msg["content"], list)
files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None)
assert files_entry is not None
files = files_entry["files"]
assert len(files) == 2
image_part = next(f for f in files if f["mime_type"] == "image/png")
pdf_part = next(f for f in files if f["mime_type"] == "application/pdf")
assert image_part == {"file_bytes": b"png-bytes", "mime_type": "image/png"}
assert pdf_part == {"file_uri": "gs://file_uri", "mime_type": "application/pdf"}
assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2

View File

@@ -31,23 +31,6 @@ class TestCreate:
doc = repo.create("u", "f", "/p")
assert doc["_id"] == doc["id"]
def test_create_aliases_upload_path_as_path(self, pg_conn):
# LLM provider code (google_ai/openai/anthropic and handlers/base)
# reads attachment.get("path") — preserved from the legacy Mongo
# shape. Repo emits both keys so consumers don't need to know
# which storage backend produced the dict.
repo = _repo(pg_conn)
doc = repo.create("u", "f", "/uploads/x.png")
assert doc["path"] == "/uploads/x.png"
assert doc["upload_path"] == "/uploads/x.png"
def test_get_aliases_upload_path_as_path(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("u", "f", "/uploads/y.pdf")
fetched = repo.get(created["id"], "u")
assert fetched is not None
assert fetched["path"] == "/uploads/y.pdf"
def test_create_with_legacy_mongo_id(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create(

View File

@@ -438,94 +438,3 @@ def test_stream_cache_key_generation_failure_yields(mock_make_redis):
messages = ["not_a_dict"]
result = list(mock_function(None, "model", messages, stream=True, tools=None))
assert result == ["fallback_chunk"]
# =====================================================================
# gen_cache_key with inline bytes (Google attachments)
# =====================================================================
@pytest.mark.unit
def test_gen_cache_key_handles_inline_bytes():
"""Image attachments arrive in messages as raw bytes (see
GoogleLLM.prepare_messages_with_attachments). gen_cache_key must not
crash on json.dumps of bytes."""
msgs = [
{
"role": "user",
"content": [{"file_bytes": b"\x00\x01\x02", "mime_type": "image/png"}],
}
]
key = gen_cache_key(msgs, model="x")
assert isinstance(key, str)
assert len(key) == 32
@pytest.mark.unit
def test_gen_cache_key_stable_for_same_bytes():
"""Two requests with identical image bytes must produce the same key
— otherwise we'd never get cache hits on image-bearing prompts."""
a = [
{
"role": "user",
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
}
]
b = [
{
"role": "user",
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
}
]
assert gen_cache_key(a, "m") == gen_cache_key(b, "m")
@pytest.mark.unit
def test_gen_cache_key_differs_for_different_bytes():
"""Different image bytes must produce different keys — otherwise two
different images would collide in cache."""
a = [
{
"role": "user",
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
}
]
b = [
{
"role": "user",
"content": [{"file_bytes": b"xyz", "mime_type": "image/png"}],
}
]
assert gen_cache_key(a, "m") != gen_cache_key(b, "m")
@pytest.mark.unit
def test_gen_cache_key_handles_bytearray_and_memoryview():
"""The default helper covers all bytes-like types so refactors that
swap bytes for bytearray/memoryview don't silently re-introduce the
TypeError."""
msgs_ba = [
{
"role": "user",
"content": [
{"file_bytes": bytearray(b"abc"), "mime_type": "image/png"}
],
}
]
msgs_mv = [
{
"role": "user",
"content": [
{"file_bytes": memoryview(b"abc"), "mime_type": "image/png"}
],
}
]
msgs_b = [
{
"role": "user",
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
}
]
# All three should hash the same content to the same key.
assert gen_cache_key(msgs_ba, "m") == gen_cache_key(msgs_b, "m")
assert gen_cache_key(msgs_mv, "m") == gen_cache_key(msgs_b, "m")
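The behavior under test can be reproduced with a bytes-aware json.dumps default; `gen_cache_key_sketch` below is a hypothetical equivalent, not the production implementation.

import hashlib
import json

def gen_cache_key_sketch(messages, model):
    def default(value):
        # Decode bytes-like values deterministically instead of letting
        # json.dumps raise TypeError; identical bytes hash identically.
        if isinstance(value, (bytes, bytearray, memoryview)):
            return bytes(value).hex()
        raise TypeError(f"unserializable: {type(value)!r}")

    payload = json.dumps({"messages": messages, "model": model},
                         sort_keys=True, default=default)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()  # 32 hex chars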

View File

@@ -427,109 +427,6 @@ class TestCompressionService:
assert token_count_with_tools > token_count_without_tools
def test_count_message_tokens_skips_inline_image_bytes(self):
# Google attaches images as raw bytes inline. Stringifying the
# part would tokenize the byte repr (~2M tokens for a 1MB image),
# trigger spurious compression, and overflow downstream input
# limits. The estimate must stay bounded.
big_bytes = b"\x00" * 1_000_000
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "describe this"},
{
"files": [
{"file_bytes": big_bytes, "mime_type": "image/png"}
]
},
],
}
]
tokens = TokenCounter.count_message_tokens(messages)
assert tokens < 5000
def test_count_message_tokens_per_image_estimate_scales(self):
msgs_one = [
{
"role": "user",
"content": [
{
"files": [
{"file_bytes": b"x", "mime_type": "image/png"}
]
}
],
}
]
msgs_two = [
{
"role": "user",
"content": [
{
"files": [
{"file_bytes": b"x", "mime_type": "image/png"},
{"file_uri": "https://x", "mime_type": "image/jpeg"},
]
}
],
}
]
assert TokenCounter.count_message_tokens(
msgs_two
) > TokenCounter.count_message_tokens(msgs_one)
def test_count_message_tokens_skips_openai_image_url(self):
# OpenAI puts images inline as base64 data URLs. ``str(item)`` on
# the dict would tokenize the entire base64 payload.
big_b64 = "A" * 1_000_000
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{big_b64}"
},
},
],
}
]
tokens = TokenCounter.count_message_tokens(messages)
assert tokens < 5000
def test_count_message_tokens_skips_anthropic_image(self):
# Anthropic puts images inline as ``source.data`` base64.
big_b64 = "A" * 1_000_000
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": big_b64,
},
},
],
}
]
tokens = TokenCounter.count_message_tokens(messages)
assert tokens < 5000
def test_count_message_tokens_still_counts_text_parts(self):
# The image bypass must not regress regular text-part counting.
messages = [
{
"role": "user",
"content": [{"type": "text", "text": "hello world"}],
}
]
assert TokenCounter.count_message_tokens(messages) > 0
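The image shapes exercised above suggest part-level counting like the sketch below; both helper names and the flat estimate are hypothetical, and `num_tokens_sketch` stands in for the codebase's tokenizer.

def num_tokens_sketch(text: str) -> int:
    # Crude stand-in tokenizer, for illustration only.
    return max(1, len(text) // 4)

def count_part_sketch(item: dict, flat_image_estimate: int = 1000) -> int:
    # Image-bearing shapes (Google files entries, OpenAI image_url data
    # URLs, Anthropic base64 sources) get a flat bounded estimate; text
    # parts are tokenized normally.
    if "files" in item:
        return flat_image_estimate * len(item["files"])
    if item.get("type") in ("image_url", "image"):
        return flat_image_estimate
    if item.get("type") == "text":
        return num_tokens_sketch(item.get("text", ""))
    return num_tokens_sketch(str(item))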
def test_format_conversation_for_compression(
self, prompt_builder, sample_conversation
):

View File

@@ -153,219 +153,4 @@ class TestLogActivity:
):
list(failing_gen(FakeAgent()))
def test_log_activity_emits_lifecycle_events(self, caplog):
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "test"
user = "user1"
user_api_key = "k"
query = "q"
agent_id = "agent-7"
conversation_id = "conv-3"
@log_activity()
def gen(agent, log_context=None):
yield "x"
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"):
list(gen(FakeAgent()))
messages = [r.message for r in caplog.records]
assert "activity_started" in messages
assert "activity_finished" in messages
started = next(r for r in caplog.records if r.message == "activity_started")
finished = next(r for r in caplog.records if r.message == "activity_finished")
assert started.endpoint == "test"
assert started.user_id == "user1"
assert started.agent_id == "agent-7"
assert started.conversation_id == "conv-3"
assert started.parent_activity_id is None # top-level activity
assert finished.activity_id == started.activity_id
assert finished.status == "ok"
assert isinstance(finished.duration_ms, int)
assert finished.duration_ms >= 0
assert finished.error_class is None
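A sketch of the lifecycle shape this test asserts; `lifecycle_sketch` is an illustrative generator wrapper, not the real `log_activity` decorator.

import logging
import time
import uuid

def lifecycle_sketch(endpoint, gen):
    activity_id = str(uuid.uuid4())
    log = logging.getLogger()
    log.info("activity_started",
             extra={"endpoint": endpoint, "activity_id": activity_id,
                    "parent_activity_id": None})
    started = time.monotonic()
    status, error_class = "ok", None
    try:
        yield from gen
    except Exception as exc:  # re-raise after recording the failure
        status, error_class = "error", type(exc).__name__
        raise
    finally:
        log.info("activity_finished",
                 extra={"activity_id": activity_id, "status": status,
                        "error_class": error_class,
                        "duration_ms": int((time.monotonic() - started) * 1000)})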
def test_log_activity_records_parent_activity_id_when_nested(self, caplog):
# Sub-agents / workflow_agents wrap an outer @log_activity gen;
# the inner activity_started event must link to the outer's id.
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "outer"
user = "user1"
user_api_key = ""
query = ""
class InnerAgent:
endpoint = "inner"
user = "user1"
user_api_key = ""
query = ""
@log_activity()
def inner_gen(agent, log_context=None):
yield "i"
@log_activity()
def outer_gen(agent, log_context=None):
yield from inner_gen(InnerAgent())
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"):
list(outer_gen(FakeAgent()))
starts = [r for r in caplog.records if r.message == "activity_started"]
assert len(starts) == 2
outer_start, inner_start = starts
assert outer_start.endpoint == "outer"
assert outer_start.parent_activity_id is None
assert inner_start.endpoint == "inner"
assert inner_start.parent_activity_id == outer_start.activity_id
def test_log_activity_records_error_status_on_failure(self, caplog):
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "boom"
user = "user1"
user_api_key = ""
query = ""
@log_activity()
def failing(agent, log_context=None):
yield "before"
raise ValueError("bad thing")
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"), \
pytest.raises(ValueError):
list(failing(FakeAgent()))
finished = next(r for r in caplog.records if r.message == "activity_finished")
assert finished.status == "error"
assert finished.error_class == "ValueError"
def test_log_activity_emits_response_summary_aggregates(self, caplog):
# Replaces the ``agent_response`` event that ``run_agent_logic``
# used to emit only on the Celery webhook path: every Flask
# activity now gets the same aggregates on ``activity_finished``.
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "stream"
user = "user1"
user_api_key = ""
query = "q"
@log_activity()
def streaming(agent, log_context=None):
yield {"answer": "Hello "}
yield {"answer": "world"}
yield {"thought": "thinking..."}
yield {"sources": [{"id": "a"}, {"id": "b"}, {"id": "c"}]}
yield {"tool_calls": [{"name": "search"}, {"name": "fetch"}]}
yield "ignored-non-dict"
yield {"unrecognised": "noop"}
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"):
list(streaming(FakeAgent()))
finished = next(r for r in caplog.records if r.message == "activity_finished")
assert finished.answer_length == len("Hello world")
assert finished.thought_length == len("thinking...")
assert finished.source_count == 3
assert finished.tool_call_count == 2
def test_log_activity_aggregates_initialise_to_zero(self, caplog):
# No yields → summary fields still present and zero (so Axiom
# schemas don't get a missing-field hole on empty activities).
import logging as _logging
from application.logging import log_activity
class FakeAgent:
endpoint = "stream"
user = "user1"
user_api_key = ""
query = ""
@log_activity()
def empty(agent, log_context=None):
return
yield # pragma: no cover — generator marker
with patch("application.logging._log_activity_to_db"), \
caplog.at_level(_logging.INFO, logger="root"):
list(empty(FakeAgent()))
finished = next(r for r in caplog.records if r.message == "activity_finished")
assert finished.answer_length == 0
assert finished.thought_length == 0
assert finished.source_count == 0
assert finished.tool_call_count == 0
@pytest.mark.unit
class TestAccumulateResponseSummary:
"""Direct coverage of the dispatch table — easier to enumerate edge
cases here than in end-to-end ``log_activity`` tests."""
def _ctx(self):
from application.logging import LogContext
return LogContext(
endpoint="e", activity_id="a", user="u", api_key="k", query="q"
)
def test_answer_appends_length(self):
from application.logging import _accumulate_response_summary
ctx = self._ctx()
_accumulate_response_summary({"answer": "abcd"}, ctx)
_accumulate_response_summary({"answer": "ef"}, ctx)
assert ctx.answer_length == 6
assert ctx.thought_length == 0
def test_non_dict_items_are_ignored(self):
from application.logging import _accumulate_response_summary
ctx = self._ctx()
for item in ("string", 123, None, ["list"], object()):
_accumulate_response_summary(item, ctx)
assert ctx.answer_length == 0
assert ctx.source_count == 0
def test_sources_must_be_list(self):
# A malformed payload (sources=str) shouldn't crash the
# accumulator — drop it silently rather than half-count it.
from application.logging import _accumulate_response_summary
ctx = self._ctx()
_accumulate_response_summary({"sources": "not-a-list"}, ctx)
assert ctx.source_count == 0
def test_tool_calls_counted(self):
from application.logging import _accumulate_response_summary
ctx = self._ctx()
_accumulate_response_summary(
{"tool_calls": [{"name": "a"}, {"name": "b"}]}, ctx
)
assert ctx.tool_call_count == 2
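The edge cases enumerated above reduce to a small dispatch sketch; `accumulate_sketch` is hypothetical but mirrors the asserted behavior.

def accumulate_sketch(item, ctx):
    # Non-dict items are ignored; string answers/thoughts accumulate
    # lengths; only list-typed sources/tool_calls accumulate counts.
    if not isinstance(item, dict):
        return
    if isinstance(item.get("answer"), str):
        ctx.answer_length += len(item["answer"])
    if isinstance(item.get("thought"), str):
        ctx.thought_length += len(item["thought"])
    if isinstance(item.get("sources"), list):
        ctx.source_count += len(item["sources"])
    if isinstance(item.get("tool_calls"), list):
        ctx.tool_call_count += len(item["tool_calls"])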

View File

@@ -361,16 +361,6 @@ class TestSerializeForTokenCount:
def test_none_returns_empty(self):
assert _serialize_for_token_count(None) == ""
def test_bytes_returns_empty(self):
# Regression: image/file attachments arrive as ``bytes`` from the
# provider-specific message preparation. Without an explicit
# branch they fell through to ``str(value)`` and inflated
# ``prompt_tokens`` by millions per call.
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 4096
assert _serialize_for_token_count(png_header) == ""
assert _serialize_for_token_count(bytearray(png_header)) == ""
assert _serialize_for_token_count(memoryview(png_header)) == ""
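The regression reduces to one explicit branch; a minimal sketch, assuming str() remains the final fallback:

def serialize_sketch(value):
    # Bytes-like payloads serialize to "" so a multi-megabyte image repr
    # can never inflate prompt_tokens.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return ""
    return str(value)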
def test_list_recursion(self):
result = _serialize_for_token_count(["hello", "world"])
assert result == ["hello", "world"]
@@ -448,11 +438,6 @@ class TestCountTokens:
data_url = "data:image/png;base64,iVBORw0KGgoAAAA..."
assert _count_tokens(data_url) == 0
def test_bytes_returns_zero(self):
# Regression: a multi-megabyte ``bytes`` payload (image attachment)
# used to be repr-stringified and counted as millions of tokens.
assert _count_tokens(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100000) == 0
def test_dict_counts(self):
assert _count_tokens({"key": "some text here"}) > 0
@@ -518,26 +503,6 @@ class TestCountPromptTokens:
)
assert tokens_with > tokens_without
def test_bytes_in_message_content_does_not_inflate_count(self):
# Production regression: a single image attachment landed as bytes
# inside ``content`` and the prior repr-fallback pushed
# ``prompt_tokens`` past 2,000,000 on Axiom. Verify the bytes
# branch keeps the count bounded by the surrounding text.
text_only = [{"content": "Summarize this image."}]
with_bytes = [
{
"content": [
{"type": "text", "text": "Summarize this image."},
{"type": "image", "data": b"\x89PNG\r\n" + b"\x00" * 200_000},
]
}
]
baseline = _count_prompt_tokens(text_only, tools=None)
with_attachment = _count_prompt_tokens(with_bytes, tools=None)
# 200KB of zero bytes used to register as ~200K tokens; cap the
# acceptable inflation at a small constant for tool-format overhead.
assert with_attachment - baseline < 50
def test_message_with_tool_calls_field(self):
messages = [
{