mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
Compare commits
1 Commits
0.17.1
...
feat-bring
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e0a8cc178b |
12
README.md
12
README.md
@@ -47,13 +47,11 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
|
||||
- [x] SharePoint & Confluence connectors ( March – April 2026 )
|
||||
- [x] Research mode ( March 2026 )
|
||||
- [x] Postgres migration for user data ( April 2026 )
|
||||
- [x] OpenTelemetry observability ( April 2026 )
|
||||
- [x] Bring Your Own Model (BYOM) ( April 2026 )
|
||||
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full api tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
|
||||
|
||||
@@ -274,14 +274,7 @@ class ToolExecutor:
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
logger.error(
|
||||
"tool_call_parse_failed",
|
||||
extra={
|
||||
"llm_class_name": llm_class_name,
|
||||
"llm_tool_name": llm_name,
|
||||
"call_id": call_id,
|
||||
},
|
||||
)
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
@@ -296,15 +289,7 @@ class ToolExecutor:
|
||||
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(
|
||||
"tool_id_not_found",
|
||||
extra={
|
||||
"tool_id": tool_id,
|
||||
"llm_tool_name": llm_name,
|
||||
"call_id": call_id,
|
||||
"available_tool_count": len(tools_dict),
|
||||
},
|
||||
)
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
@@ -371,15 +356,7 @@ class ToolExecutor:
|
||||
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
|
||||
"missing 'id' on tool row."
|
||||
)
|
||||
logger.error(
|
||||
"tool_load_failed",
|
||||
extra={
|
||||
"tool_name": tool_data.get("name"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"call_id": call_id,
|
||||
},
|
||||
)
|
||||
logger.error(error_message)
|
||||
tool_call_data["result"] = error_message
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
@@ -474,12 +451,10 @@ class ToolExecutor:
|
||||
row_id = tool_data.get("id")
|
||||
if not row_id:
|
||||
logger.error(
|
||||
"tool_missing_row_id",
|
||||
extra={
|
||||
"tool_name": tool_data.get("name"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
},
|
||||
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
|
||||
"skipping load to avoid binding a non-UUID downstream.",
|
||||
tool_data.get("name"),
|
||||
tool_id,
|
||||
)
|
||||
return None
|
||||
tool_config["tool_id"] = str(row_id)
|
||||
|
||||
@@ -12,12 +12,6 @@ logger = logging.getLogger(__name__)
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
# Per-image token estimate. Provider tokenizers vary widely
|
||||
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
|
||||
# depends on resolution/detail we can't see here. Errs slightly high
|
||||
# so the threshold check stays conservative.
|
||||
_IMAGE_PART_TOKEN_ESTIMATE = 1500
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
@@ -35,36 +29,12 @@ class TokenCounter:
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, image parts, etc.)
|
||||
# Handle structured content (tool calls, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += TokenCounter._count_content_part(item)
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def _count_content_part(item: Dict) -> int:
|
||||
# Image/file attachments are billed by the provider per image,
|
||||
# not proportional to the inline bytes/base64 string.
|
||||
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
|
||||
# which trips spurious compression and overflows downstream
|
||||
# input limits.
|
||||
item_type = item.get("type")
|
||||
|
||||
if "files" in item:
|
||||
files = item.get("files")
|
||||
count = len(files) if isinstance(files, list) and files else 1
|
||||
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
|
||||
|
||||
if "image_url" in item or item_type in {
|
||||
"image",
|
||||
"image_url",
|
||||
"input_image",
|
||||
"file",
|
||||
}:
|
||||
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
|
||||
|
||||
return num_tokens_from_string(str(item))
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
|
||||
@@ -9,7 +9,6 @@ from jose import jwt
|
||||
|
||||
from application.auth import handle_auth
|
||||
|
||||
from application.core import log_context
|
||||
from application.core.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -113,38 +112,6 @@ def generate_token():
|
||||
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
|
||||
|
||||
|
||||
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
|
||||
|
||||
|
||||
@app.before_request
|
||||
def _bind_log_context():
|
||||
"""Bind activity_id + endpoint for the duration of this request.
|
||||
|
||||
Runs before ``authenticate_request``; ``user_id`` is overlaid in a
|
||||
follow-up handler once the JWT has been decoded.
|
||||
"""
|
||||
if request.method == "OPTIONS":
|
||||
return None
|
||||
activity_id = str(uuid.uuid4())
|
||||
request.activity_id = activity_id
|
||||
token = log_context.bind(
|
||||
activity_id=activity_id,
|
||||
endpoint=request.endpoint,
|
||||
)
|
||||
setattr(request, _LOG_CTX_TOKEN_ATTR, token)
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _reset_log_context(_exc):
|
||||
# SSE streams keep yielding after teardown fires, but a2wsgi runs each
|
||||
# request inside ``copy_context().run(...)``, so this reset doesn't
|
||||
# leak into the stream's view of the context.
|
||||
token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
|
||||
if token is not None:
|
||||
log_context.reset(token)
|
||||
|
||||
|
||||
@app.before_request
|
||||
def enforce_stt_request_size_limits():
|
||||
if request.method == "OPTIONS":
|
||||
@@ -181,21 +148,6 @@ def authenticate_request():
|
||||
request.decoded_token = decoded_token
|
||||
|
||||
|
||||
@app.before_request
|
||||
def _bind_user_id_to_log_context():
|
||||
# Registered after ``authenticate_request`` (Flask runs before_request
|
||||
# handlers in registration order), so ``request.decoded_token`` is
|
||||
# populated by the time we read it. ``teardown_request`` unwinds the
|
||||
# whole request-level bind, so no separate reset token is needed here.
|
||||
if request.method == "OPTIONS":
|
||||
return None
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
if user_id:
|
||||
log_context.bind(user_id=user_id)
|
||||
return None
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response: Response) -> Response:
|
||||
"""Add CORS headers for the pure Flask development entrypoint."""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@@ -11,14 +10,6 @@ from application.utils import get_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cache_default(value):
|
||||
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
|
||||
# hash so the cache key stays bounded in size and stable across identical content.
|
||||
if isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
|
||||
return repr(value)
|
||||
|
||||
_redis_instance = None
|
||||
_redis_creation_failed = False
|
||||
_instance_lock = Lock()
|
||||
@@ -45,7 +36,7 @@ def get_redis_instance():
|
||||
def gen_cache_key(messages, model="docgpt", tools=None):
|
||||
if not all(isinstance(msg, dict) for msg in messages):
|
||||
raise ValueError("All messages must be dictionaries.")
|
||||
messages_str = json.dumps(messages, default=_cache_default)
|
||||
messages_str = json.dumps(messages)
|
||||
tools_str = json.dumps(str(tools)) if tools else ""
|
||||
combined = f"{model}_{messages_str}_{tools_str}"
|
||||
cache_key = get_hash(combined)
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
from application.core import log_context
|
||||
from application.core.settings import settings
|
||||
from celery.signals import (
|
||||
setup_logging,
|
||||
task_postrun,
|
||||
task_prerun,
|
||||
worker_process_init,
|
||||
worker_ready,
|
||||
)
|
||||
from celery.signals import setup_logging, worker_process_init, worker_ready
|
||||
|
||||
|
||||
def make_celery(app_name=__name__):
|
||||
@@ -50,54 +41,6 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
|
||||
dispose_engine()
|
||||
|
||||
|
||||
# Most tasks in this repo accept ``user`` where the log context wants
|
||||
# ``user_id``; map task parameter names to context keys explicitly.
|
||||
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
|
||||
"user": "user_id",
|
||||
"user_id": "user_id",
|
||||
"agent_id": "agent_id",
|
||||
"conversation_id": "conversation_id",
|
||||
}
|
||||
|
||||
_task_log_tokens: dict[str, object] = {}
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def _bind_task_log_context(task_id, task, args, kwargs, **_):
|
||||
# Resolve task args by parameter name — nearly every task in this repo
|
||||
# is called positionally, so ``kwargs.get('user')`` would bind nothing.
|
||||
ctx = {"activity_id": task_id}
|
||||
try:
|
||||
sig = inspect.signature(task.run)
|
||||
bound = sig.bind_partial(*args, **kwargs).arguments
|
||||
except (TypeError, ValueError):
|
||||
bound = dict(kwargs)
|
||||
for param_name, value in bound.items():
|
||||
ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
|
||||
if ctx_key and value:
|
||||
ctx[ctx_key] = value
|
||||
_task_log_tokens[task_id] = log_context.bind(**ctx)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def _unbind_task_log_context(task_id, **_):
|
||||
# ``task_postrun`` fires on both success and failure. Required for
|
||||
# Celery: unlike the Flask path, tasks aren't isolated in their own
|
||||
# ``copy_context().run(...)``, so a missing reset would leak the
|
||||
# bind onto the next task on the same worker.
|
||||
token = _task_log_tokens.pop(task_id, None)
|
||||
if token is None:
|
||||
return
|
||||
try:
|
||||
log_context.reset(token)
|
||||
except ValueError:
|
||||
# task_prerun and task_postrun ran on different threads (non-default
|
||||
# Celery pool); the token isn't valid in this context. Drop it.
|
||||
logging.getLogger(__name__).debug(
|
||||
"log_context reset skipped for task %s", task_id
|
||||
)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Per-activity logging context backed by ``contextvars``.
|
||||
|
||||
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
|
||||
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
|
||||
they land as first-class attributes on the OTLP log export rather than being
|
||||
buried inside formatted message bodies.
|
||||
|
||||
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
|
||||
via the token returned by ``bind``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
_CTX_KEYS: frozenset[str] = frozenset(
|
||||
{
|
||||
"activity_id",
|
||||
"parent_activity_id",
|
||||
"user_id",
|
||||
"agent_id",
|
||||
"conversation_id",
|
||||
"endpoint",
|
||||
"model",
|
||||
}
|
||||
)
|
||||
|
||||
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
|
||||
|
||||
|
||||
def bind(**kwargs: object) -> Token:
|
||||
"""Overlay the given keys onto the current context.
|
||||
|
||||
Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
|
||||
Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
|
||||
stamp a stray field name onto every record), as are ``None`` values
|
||||
(a missing attribute is more useful than the literal string ``"None"``).
|
||||
"""
|
||||
overlay = {
|
||||
k: str(v)
|
||||
for k, v in kwargs.items()
|
||||
if k in _CTX_KEYS and v is not None
|
||||
}
|
||||
new = {**_ctx.get(), **overlay}
|
||||
return _ctx.set(new)
|
||||
|
||||
|
||||
def reset(token: Token) -> None:
|
||||
"""Restore the context to the snapshot captured by the matching ``bind``."""
|
||||
_ctx.reset(token)
|
||||
|
||||
|
||||
def snapshot() -> Mapping[str, str]:
|
||||
"""Return the current context dict. Treat as read-only; use :func:`bind`."""
|
||||
return _ctx.get()
|
||||
@@ -2,36 +2,6 @@ import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
|
||||
from application.core.log_context import snapshot as _ctx_snapshot
|
||||
|
||||
|
||||
# Loggers with ``propagate=False`` don't share root's handlers, so the
|
||||
# context filter has to be installed on their handlers directly.
|
||||
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
|
||||
"uvicorn",
|
||||
"uvicorn.access",
|
||||
"uvicorn.error",
|
||||
"celery.app.trace",
|
||||
"celery.worker.strategy",
|
||||
"gunicorn.error",
|
||||
"gunicorn.access",
|
||||
)
|
||||
|
||||
|
||||
class _ContextFilter(logging.Filter):
|
||||
"""Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
|
||||
|
||||
Must be installed on **handlers**, not loggers: Python skips logger-level
|
||||
filters when a child logger's record propagates up. The ``hasattr`` guard
|
||||
keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
for key, value in _ctx_snapshot().items():
|
||||
if not hasattr(record, key):
|
||||
setattr(record, key, value)
|
||||
return True
|
||||
|
||||
|
||||
def _otlp_logs_enabled() -> bool:
|
||||
"""Return True when the user has opted in to OTLP log export.
|
||||
@@ -90,23 +60,3 @@ def setup_logging() -> None:
|
||||
for handler in preserved_handlers:
|
||||
if handler not in root.handlers:
|
||||
root.addHandler(handler)
|
||||
|
||||
_install_context_filter()
|
||||
|
||||
|
||||
def _install_context_filter() -> None:
|
||||
"""Attach :class:`_ContextFilter` to root's handlers + every handler on
|
||||
the known non-propagating loggers. Skipping handlers that already carry
|
||||
one keeps repeat ``setup_logging`` calls from stacking filters.
|
||||
"""
|
||||
|
||||
def _has_ctx_filter(handler: logging.Handler) -> bool:
|
||||
return any(isinstance(f, _ContextFilter) for f in handler.filters)
|
||||
|
||||
for handler in logging.getLogger().handlers:
|
||||
if not _has_ctx_filter(handler):
|
||||
handler.addFilter(_ContextFilter())
|
||||
for name in _NON_PROPAGATING_LOGGERS:
|
||||
for handler in logging.getLogger(name).handlers:
|
||||
if not _has_ctx_filter(handler):
|
||||
handler.addFilter(_ContextFilter())
|
||||
|
||||
@@ -11,7 +11,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLLM(BaseLLM):
|
||||
provider_name = "anthropic"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
|
||||
from application.cache import gen_cache, stream_cache
|
||||
|
||||
@@ -11,10 +10,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
# Stamped onto the ``llm_stream_start`` event so dashboards can group
|
||||
# calls by vendor. Subclasses override.
|
||||
provider_name: ClassVar[str] = "unknown"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
@@ -123,26 +118,6 @@ class BaseLLM(ABC):
|
||||
return args_dict
|
||||
return {k: v for k, v in args_dict.items() if v is not None}
|
||||
|
||||
@staticmethod
|
||||
def _is_non_retriable_client_error(exc: BaseException) -> bool:
|
||||
"""4xx errors mean the request itself is malformed — retrying with
|
||||
a different model fails identically and doubles the work. Only
|
||||
transient/5xx/connection errors should trigger fallback."""
|
||||
try:
|
||||
from google.genai.errors import ClientError as _GenaiClientError
|
||||
|
||||
if isinstance(exc, _GenaiClientError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
for attr in ("status_code", "code", "http_status"):
|
||||
v = getattr(exc, attr, None)
|
||||
if isinstance(v, int) and 400 <= v < 500:
|
||||
return True
|
||||
resp = getattr(exc, "response", None)
|
||||
v = getattr(resp, "status_code", None)
|
||||
return isinstance(v, int) and 400 <= v < 500
|
||||
|
||||
def _execute_with_fallback(
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
@@ -166,18 +141,12 @@ class BaseLLM(ABC):
|
||||
|
||||
if is_stream:
|
||||
return self._stream_with_fallback(
|
||||
decorated_method, method_name, decorators, *args, **kwargs
|
||||
decorated_method, method_name, *args, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if self._is_non_retriable_client_error(e):
|
||||
logger.error(
|
||||
f"Primary LLM failed with non-retriable client error; "
|
||||
f"skipping fallback: {str(e)}"
|
||||
)
|
||||
raise
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
|
||||
raise
|
||||
@@ -187,27 +156,14 @@ class BaseLLM(ABC):
|
||||
f"{fallback.model_id}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
# Apply decorators to fallback's raw method directly — calling
|
||||
# fallback.gen() would re-enter the orchestrator and recurse via
|
||||
# fallback.fallback_llm.
|
||||
fallback_method = getattr(fallback, method_name)
|
||||
for decorator in decorators:
|
||||
fallback_method = decorator(fallback_method)
|
||||
fallback_method = getattr(
|
||||
fallback, method_name.replace("_raw_", "")
|
||||
)
|
||||
fallback_kwargs = {**kwargs, "model": fallback.model_id}
|
||||
try:
|
||||
return fallback_method(fallback, *args, **fallback_kwargs)
|
||||
except Exception as e2:
|
||||
if self._is_non_retriable_client_error(e2):
|
||||
logger.error(
|
||||
f"Fallback LLM failed with non-retriable client "
|
||||
f"error; giving up: {str(e2)}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
|
||||
raise
|
||||
return fallback_method(*args, **fallback_kwargs)
|
||||
|
||||
def _stream_with_fallback(
|
||||
self, decorated_method, method_name, decorators, *args, **kwargs
|
||||
self, decorated_method, method_name, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Wrapper generator that catches mid-stream errors and falls back.
|
||||
@@ -220,12 +176,6 @@ class BaseLLM(ABC):
|
||||
try:
|
||||
yield from decorated_method()
|
||||
except Exception as e:
|
||||
if self._is_non_retriable_client_error(e):
|
||||
logger.error(
|
||||
f"Primary LLM failed mid-stream with non-retriable client "
|
||||
f"error; skipping fallback: {str(e)}"
|
||||
)
|
||||
raise
|
||||
if not self.fallback_llm:
|
||||
logger.error(
|
||||
f"Primary LLM failed and no fallback configured: {str(e)}"
|
||||
@@ -236,37 +186,11 @@ class BaseLLM(ABC):
|
||||
f"Primary LLM failed mid-stream. Falling back to "
|
||||
f"{fallback.model_id}. Error: {str(e)}"
|
||||
)
|
||||
# Apply decorators to fallback's raw stream method directly —
|
||||
# calling fallback.gen_stream() would re-enter the orchestrator
|
||||
# and recurse via fallback.fallback_llm. Emit the stream-start
|
||||
# event manually so dashboards still see the fallback's
|
||||
# provider/model when the response actually comes from it.
|
||||
fallback._emit_stream_start_log(
|
||||
fallback.model_id,
|
||||
kwargs.get("messages"),
|
||||
kwargs.get("tools"),
|
||||
bool(
|
||||
kwargs.get("_usage_attachments")
|
||||
or kwargs.get("attachments")
|
||||
),
|
||||
fallback_method = getattr(
|
||||
fallback, method_name.replace("_raw_", "")
|
||||
)
|
||||
fallback_method = getattr(fallback, method_name)
|
||||
for decorator in decorators:
|
||||
fallback_method = decorator(fallback_method)
|
||||
fallback_kwargs = {**kwargs, "model": fallback.model_id}
|
||||
try:
|
||||
yield from fallback_method(fallback, *args, **fallback_kwargs)
|
||||
except Exception as e2:
|
||||
if self._is_non_retriable_client_error(e2):
|
||||
logger.error(
|
||||
f"Fallback LLM failed mid-stream with non-retriable "
|
||||
f"client error; giving up: {str(e2)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
|
||||
)
|
||||
raise
|
||||
yield from fallback_method(*args, **fallback_kwargs)
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
@@ -281,58 +205,7 @@ class BaseLLM(ABC):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _emit_stream_start_log(self, model, messages, tools, has_attachments):
|
||||
# Stamped with ``self.provider_name`` so dashboards can group calls
|
||||
# by vendor; the fallback path emits its own copy on the fallback
|
||||
# instance so the actual responding provider is recorded.
|
||||
logging.info(
|
||||
"llm_stream_start",
|
||||
extra={
|
||||
"model": model,
|
||||
"provider": self.provider_name,
|
||||
"message_count": len(messages) if messages is not None else 0,
|
||||
"has_attachments": bool(has_attachments),
|
||||
"has_tools": bool(tools),
|
||||
},
|
||||
)
|
||||
|
||||
def _emit_stream_finished_log(
|
||||
self,
|
||||
model,
|
||||
*,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
latency_ms,
|
||||
cached_tokens=None,
|
||||
error=None,
|
||||
):
|
||||
# Paired with ``llm_stream_start`` so cost dashboards can sum tokens
|
||||
# by user/agent/provider. Token counts are client-side estimates
|
||||
# from ``stream_token_usage``; vendor-reported counts (incl.
|
||||
# ``cached_tokens`` for prompt caching) require per-provider
|
||||
# extraction in each ``_raw_gen_stream`` and aren't wired yet.
|
||||
extra = {
|
||||
"model": model,
|
||||
"provider": self.provider_name,
|
||||
"prompt_tokens": int(prompt_tokens),
|
||||
"completion_tokens": int(completion_tokens),
|
||||
"latency_ms": int(latency_ms),
|
||||
"status": "error" if error is not None else "ok",
|
||||
}
|
||||
if cached_tokens is not None:
|
||||
extra["cached_tokens"] = int(cached_tokens)
|
||||
if error is not None:
|
||||
extra["error_class"] = type(error).__name__
|
||||
logging.info("llm_stream_finished", extra=extra)
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
|
||||
# the ``stream_token_usage`` decorator pops that key, but the log
|
||||
# fires before the decorator runs so it's still in ``kwargs`` here.
|
||||
has_attachments = bool(
|
||||
kwargs.get("_usage_attachments") or kwargs.get("attachments")
|
||||
)
|
||||
self._emit_stream_start_log(model, messages, tools, has_attachments)
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._execute_with_fallback(
|
||||
"_raw_gen_stream",
|
||||
|
||||
@@ -6,8 +6,6 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
|
||||
DOCSGPT_MODEL = "docsgpt"
|
||||
|
||||
class DocsGPTAPILLM(OpenAILLM):
|
||||
provider_name = "docsgpt"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=DOCSGPT_API_KEY,
|
||||
|
||||
@@ -10,8 +10,6 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
provider_name = "google"
|
||||
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
@@ -81,39 +79,24 @@ class GoogleLLM(BaseLLM):
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type not in self.get_supported_attachment_types():
|
||||
continue
|
||||
try:
|
||||
# Images go inline as bytes per Google's guidance for
|
||||
# requests under 20MB; the Files API can return before
|
||||
# the upload reaches ACTIVE state and yield an empty URI.
|
||||
if mime_type.startswith("image/"):
|
||||
file_bytes = self._read_attachment_bytes(attachment)
|
||||
files.append(
|
||||
{"file_bytes": file_bytes, "mime_type": mime_type}
|
||||
)
|
||||
else:
|
||||
if mime_type in self.get_supported_attachment_types():
|
||||
try:
|
||||
file_uri = self._upload_file_to_google(attachment)
|
||||
if not file_uri:
|
||||
raise ValueError(
|
||||
f"Google Files API returned empty URI for "
|
||||
f"{attachment.get('path', 'unknown')}"
|
||||
)
|
||||
logging.info(
|
||||
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
|
||||
)
|
||||
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"GoogleLLM: Error processing attachment: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"GoogleLLM: Error uploading file: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
@@ -129,9 +112,7 @@ class GoogleLLM(BaseLLM):
|
||||
Returns:
|
||||
str: Google AI file URI for the uploaded file.
|
||||
"""
|
||||
# Truthy check, not membership: a poisoned cache row of "" or
|
||||
# None must be treated as a miss and trigger a fresh upload.
|
||||
if attachment.get("google_file_uri"):
|
||||
if "google_file_uri" in attachment:
|
||||
return attachment["google_file_uri"]
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
@@ -145,10 +126,6 @@ class GoogleLLM(BaseLLM):
|
||||
file=local_path
|
||||
).uri,
|
||||
)
|
||||
if not file_uri:
|
||||
raise ValueError(
|
||||
f"Google Files API upload returned empty URI for {file_path}"
|
||||
)
|
||||
|
||||
# Cache the Google file URI on the attachment row so we don't
|
||||
# re-upload on the next LLM call. Accept either a PG UUID
|
||||
@@ -182,26 +159,6 @@ class GoogleLLM(BaseLLM):
|
||||
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _read_attachment_bytes(self, attachment):
|
||||
"""
|
||||
Read attachment bytes from storage for inline transmission.
|
||||
|
||||
Args:
|
||||
attachment (dict): Attachment dictionary with path and metadata.
|
||||
|
||||
Returns:
|
||||
bytes: Raw file bytes.
|
||||
"""
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
return self.storage.process_file(
|
||||
file_path,
|
||||
lambda local_path, **kwargs: open(local_path, "rb").read(),
|
||||
)
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""
|
||||
Convert OpenAI format messages to Google AI format and collect system prompts.
|
||||
@@ -341,24 +298,12 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
elif "files" in item:
|
||||
for file_data in item["files"]:
|
||||
if "file_bytes" in file_data:
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=file_data["file_bytes"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
elif file_data.get("file_uri"):
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"GoogleLLM: dropping file part with empty URI and no bytes"
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
@@ -596,6 +541,22 @@ class GoogleLLM(BaseLLM):
|
||||
config.response_mime_type = "application/json"
|
||||
# Check if we have both tools and file attachments
|
||||
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
if hasattr(part, "file_data") and part.file_data is not None:
|
||||
has_attachments = True
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
messages_summary = self._summarize_messages_for_log(messages)
|
||||
logging.info(
|
||||
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
|
||||
model,
|
||||
messages_summary,
|
||||
has_attachments,
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=messages,
|
||||
|
||||
@@ -5,8 +5,6 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
||||
|
||||
|
||||
class GroqLLM(OpenAILLM):
|
||||
provider_name = "groq"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -280,26 +280,7 @@ class LLMHandler(ABC):
|
||||
# Keep serialized function calls/responses so the compressor sees actions
|
||||
parts_text.append(str(item))
|
||||
elif "files" in item:
|
||||
# Image attachments arrive with raw bytes / base64
|
||||
# inline (see GoogleLLM.prepare_messages_with_attachments).
|
||||
# ``str(item)`` would dump the whole byte/base64
|
||||
# blob into the compression prompt and bust the
|
||||
# compression LLM's input limit.
|
||||
files = item.get("files") or []
|
||||
descriptors = []
|
||||
if isinstance(files, list):
|
||||
for f in files:
|
||||
if isinstance(f, dict):
|
||||
descriptors.append(
|
||||
f.get("mime_type") or "file"
|
||||
)
|
||||
elif isinstance(f, str):
|
||||
descriptors.append(f)
|
||||
if not descriptors:
|
||||
descriptors = ["file"]
|
||||
parts_text.append(
|
||||
f"[attachment: {', '.join(descriptors)}]"
|
||||
)
|
||||
parts_text.append(str(item))
|
||||
return "\n".join(parts_text)
|
||||
return ""
|
||||
|
||||
|
||||
@@ -26,8 +26,6 @@ class LlamaSingleton:
|
||||
|
||||
|
||||
class LlamaCpp(BaseLLM):
|
||||
provider_name = "llama_cpp"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
|
||||
@@ -5,8 +5,6 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
class NovitaLLM(OpenAILLM):
|
||||
provider_name = "novita"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -5,8 +5,6 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
class OpenRouterLLM(OpenAILLM):
|
||||
provider_name = "openrouter"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -61,7 +61,6 @@ def _truncate_base64_for_logging(messages):
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
provider_name = "openai"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -3,7 +3,6 @@ from application.core.settings import settings
|
||||
|
||||
|
||||
class PremAILLM(BaseLLM):
|
||||
provider_name = "premai"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from premai import Prem
|
||||
|
||||
@@ -59,7 +59,6 @@ class LineIterator:
|
||||
|
||||
|
||||
class SagemakerAPILLM(BaseLLM):
|
||||
provider_name = "sagemaker"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
import boto3
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import datetime
|
||||
import functools
|
||||
import inspect
|
||||
import time
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Generator, List
|
||||
|
||||
from application.core import log_context
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
@@ -24,15 +22,6 @@ class LogContext:
|
||||
self.api_key = api_key
|
||||
self.query = query
|
||||
self.stacks = []
|
||||
# Per-activity response aggregates populated by ``_consume_and_log``
|
||||
# while it forwards stream items, then flushed onto the
|
||||
# ``activity_finished`` event so every Flask request gets the
|
||||
# same summary that ``run_agent_logic`` used to log only for the
|
||||
# Celery webhook path.
|
||||
self.answer_length = 0
|
||||
self.thought_length = 0
|
||||
self.source_count = 0
|
||||
self.tool_call_count = 0
|
||||
|
||||
|
||||
def build_stack_data(
|
||||
@@ -89,125 +78,25 @@ def log_activity() -> Callable:
|
||||
user = data.get("user", "local")
|
||||
api_key = data.get("user_api_key", "")
|
||||
query = kwargs.get("query", getattr(args[0], "query", ""))
|
||||
agent_id = getattr(args[0], "agent_id", None) or kwargs.get("agent_id")
|
||||
conversation_id = (
|
||||
kwargs.get("conversation_id")
|
||||
or getattr(args[0], "conversation_id", None)
|
||||
)
|
||||
model = getattr(args[0], "gpt_model", None) or getattr(args[0], "model", None)
|
||||
|
||||
# Capture the surrounding activity_id before overlaying ours,
|
||||
# so nested activities record the parent → child link.
|
||||
parent_activity_id = log_context.snapshot().get("activity_id")
|
||||
|
||||
context = LogContext(endpoint, activity_id, user, api_key, query)
|
||||
kwargs["log_context"] = context
|
||||
|
||||
ctx_token = log_context.bind(
|
||||
activity_id=activity_id,
|
||||
parent_activity_id=parent_activity_id,
|
||||
user_id=user,
|
||||
agent_id=agent_id,
|
||||
conversation_id=conversation_id,
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
)
|
||||
|
||||
started_at = time.monotonic()
|
||||
logging.info(
|
||||
"activity_started",
|
||||
extra={
|
||||
"activity_id": activity_id,
|
||||
"parent_activity_id": parent_activity_id,
|
||||
"user_id": user,
|
||||
"agent_id": agent_id,
|
||||
"conversation_id": conversation_id,
|
||||
"endpoint": endpoint,
|
||||
"model": model,
|
||||
},
|
||||
f"Starting activity: {endpoint} - {activity_id} - User: {user}"
|
||||
)
|
||||
|
||||
error: BaseException | None = None
|
||||
try:
|
||||
generator = func(*args, **kwargs)
|
||||
yield from _consume_and_log(generator, context)
|
||||
except Exception as exc:
|
||||
# Only ``Exception`` counts as an activity error; ``GeneratorExit``
|
||||
# (consumer disconnected mid-stream) and ``KeyboardInterrupt``
|
||||
# flow through the finally as ``status="ok"``, matching
|
||||
# ``_consume_and_log``.
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
_emit_activity_finished(
|
||||
context=context,
|
||||
parent_activity_id=parent_activity_id,
|
||||
started_at=started_at,
|
||||
error=error,
|
||||
)
|
||||
log_context.reset(ctx_token)
|
||||
generator = func(*args, **kwargs)
|
||||
yield from _consume_and_log(generator, context)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _emit_activity_finished(
|
||||
*,
|
||||
context: "LogContext",
|
||||
parent_activity_id: str | None,
|
||||
started_at: float,
|
||||
error: BaseException | None,
|
||||
) -> None:
|
||||
"""Emit the paired ``activity_finished`` event with duration, outcome,
|
||||
and per-activity response aggregates accumulated in ``_consume_and_log``.
|
||||
"""
|
||||
duration_ms = int((time.monotonic() - started_at) * 1000)
|
||||
logging.info(
|
||||
"activity_finished",
|
||||
extra={
|
||||
"activity_id": context.activity_id,
|
||||
"parent_activity_id": parent_activity_id,
|
||||
"user_id": context.user,
|
||||
"endpoint": context.endpoint,
|
||||
"duration_ms": duration_ms,
|
||||
"status": "error" if error is not None else "ok",
|
||||
"error_class": type(error).__name__ if error is not None else None,
|
||||
"answer_length": context.answer_length,
|
||||
"thought_length": context.thought_length,
|
||||
"source_count": context.source_count,
|
||||
"tool_call_count": context.tool_call_count,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
|
||||
"""Mirror the per-line aggregation that ``run_agent_logic`` did for the
|
||||
Celery webhook path, but at the generator-consumption layer so every
|
||||
``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
|
||||
gets the same summary.
|
||||
"""
|
||||
if not isinstance(item, dict):
|
||||
return
|
||||
if "answer" in item:
|
||||
context.answer_length += len(str(item["answer"]))
|
||||
return
|
||||
if "thought" in item:
|
||||
context.thought_length += len(str(item["thought"]))
|
||||
return
|
||||
sources = item.get("sources") if "sources" in item else None
|
||||
if isinstance(sources, list):
|
||||
context.source_count += len(sources)
|
||||
return
|
||||
tool_calls = item.get("tool_calls") if "tool_calls" in item else None
|
||||
if isinstance(tool_calls, list):
|
||||
context.tool_call_count += len(tool_calls)
|
||||
|
||||
|
||||
def _consume_and_log(generator: Generator, context: "LogContext"):
|
||||
try:
|
||||
for item in generator:
|
||||
_accumulate_response_summary(item, context)
|
||||
yield item
|
||||
except Exception as e:
|
||||
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
|
||||
|
||||
@@ -17,21 +17,6 @@ _UPDATABLE_SCALARS = {
|
||||
_UPDATABLE_JSONB = {"metadata"}
|
||||
|
||||
|
||||
def _attachment_to_dict(row: Any) -> dict:
|
||||
"""row_to_dict + ``upload_path``→``path`` alias.
|
||||
|
||||
Pre-Postgres, the Mongo attachment shape used ``path``. The PG column
|
||||
is ``upload_path``; LLM provider code (google_ai/openai/anthropic and
|
||||
handlers/base) still reads ``attachment.get("path")``. Mirroring the
|
||||
``id``/``_id`` dual-emit in row_to_dict so consumers don't need to
|
||||
know which storage backend produced the dict.
|
||||
"""
|
||||
out = row_to_dict(row)
|
||||
if "upload_path" in out and out.get("path") is None:
|
||||
out["path"] = out["upload_path"]
|
||||
return out
|
||||
|
||||
|
||||
class AttachmentsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
@@ -81,7 +66,7 @@ class AttachmentsRepository:
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return _attachment_to_dict(result.fetchone())
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
@@ -91,7 +76,7 @@ class AttachmentsRepository:
|
||||
{"id": attachment_id, "user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return _attachment_to_dict(row) if row is not None else None
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
|
||||
@@ -170,14 +155,14 @@ class AttachmentsRepository:
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return _attachment_to_dict(row) if row is not None else None
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return [_attachment_to_dict(r) for r in result.fetchall()]
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Partial update. Used by the LLM providers to cache their
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
@@ -21,15 +20,6 @@ def _serialize_for_token_count(value):
|
||||
if value is None:
|
||||
return ""
|
||||
|
||||
# Raw binary payloads (image/file attachments arrive as ``bytes`` from
|
||||
# ``GoogleLLM.prepare_messages_with_attachments``) — without this
|
||||
# branch they fall through to ``str(value)`` below, which produces a
|
||||
# multi-megabyte ``"b'\\x89PNG...'"`` repr-string and inflates
|
||||
# ``prompt_tokens`` by orders of magnitude. Same intent as the
|
||||
# data-URL skip above.
|
||||
if isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return ""
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_serialize_for_token_count(item) for item in value]
|
||||
|
||||
@@ -155,44 +145,19 @@ def stream_token_usage(func):
|
||||
**kwargs,
|
||||
)
|
||||
batch = []
|
||||
started_at = time.monotonic()
|
||||
error: BaseException | None = None
|
||||
try:
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
except Exception as exc:
|
||||
# ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
|
||||
# flow through as ``status="ok"`` — same convention as
|
||||
# ``application.logging._consume_and_log``.
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
for line in batch:
|
||||
call_usage["generated_tokens"] += _count_tokens(line)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
# Persist usage rows only on success: a partial mid-stream
|
||||
# failure shouldn't bill the user for a response they never got.
|
||||
if error is None:
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
emit = getattr(self, "_emit_stream_finished_log", None)
|
||||
if callable(emit):
|
||||
try:
|
||||
emit(
|
||||
model,
|
||||
prompt_tokens=call_usage["prompt_tokens"],
|
||||
completion_tokens=call_usage["generated_tokens"],
|
||||
latency_ms=int((time.monotonic() - started_at) * 1000),
|
||||
error=error,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to emit llm_stream_finished")
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
for line in batch:
|
||||
call_usage["generated_tokens"] += _count_tokens(line)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -432,10 +432,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
}
|
||||
# Per-activity summary fields (answer_length, thought_length,
|
||||
# source_count, tool_call_count) now ride on the inner
|
||||
# ``activity_finished`` event emitted by ``log_activity`` around
|
||||
# ``Agent.gen`` above; no separate ``agent_response`` log needed.
|
||||
logging.info(f"Agent response: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error in run_agent_logic: {e}", exc_info=True)
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
"""Unit tests for ``log_context`` + ``_ContextFilter``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core import log_context
|
||||
from application.core.logging_config import _ContextFilter
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_log_ctx():
|
||||
# The contextvar is module-scoped; snapshot at entry, restore at exit
|
||||
# to keep tests from leaking state into each other.
|
||||
token = log_context.bind()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBindAndSnapshot:
|
||||
|
||||
def test_bind_returns_token_and_snapshot_reflects_overlay(self):
|
||||
token = log_context.bind(activity_id="a1", user_id="u1")
|
||||
assert log_context.snapshot() == {"activity_id": "a1", "user_id": "u1"}
|
||||
log_context.reset(token)
|
||||
assert log_context.snapshot() == {}
|
||||
|
||||
def test_bind_drops_unknown_keys(self):
|
||||
token = log_context.bind(activity_id="a1", not_a_real_key="boom")
|
||||
try:
|
||||
assert log_context.snapshot() == {"activity_id": "a1"}
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
def test_bind_drops_none_values(self):
|
||||
token = log_context.bind(activity_id="a1", agent_id=None)
|
||||
try:
|
||||
assert "agent_id" not in log_context.snapshot()
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
def test_bind_coerces_values_to_str(self):
|
||||
token = log_context.bind(activity_id=42)
|
||||
try:
|
||||
assert log_context.snapshot()["activity_id"] == "42"
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
def test_nested_bind_overlays_and_resets_lifo(self):
|
||||
outer = log_context.bind(activity_id="outer", user_id="u1")
|
||||
inner = log_context.bind(activity_id="inner", agent_id="agent-1")
|
||||
# Inner overrides activity_id, keeps user_id from outer, adds agent_id.
|
||||
assert log_context.snapshot() == {
|
||||
"activity_id": "inner",
|
||||
"user_id": "u1",
|
||||
"agent_id": "agent-1",
|
||||
}
|
||||
log_context.reset(inner)
|
||||
assert log_context.snapshot() == {"activity_id": "outer", "user_id": "u1"}
|
||||
log_context.reset(outer)
|
||||
assert log_context.snapshot() == {}
|
||||
|
||||
def test_parent_activity_id_pattern(self):
|
||||
outer = log_context.bind(activity_id="parent-1")
|
||||
parent = log_context.snapshot().get("activity_id")
|
||||
inner = log_context.bind(activity_id="child-1", parent_activity_id=parent)
|
||||
try:
|
||||
snap = log_context.snapshot()
|
||||
assert snap["activity_id"] == "child-1"
|
||||
assert snap["parent_activity_id"] == "parent-1"
|
||||
finally:
|
||||
log_context.reset(inner)
|
||||
log_context.reset(outer)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContextFilter:
|
||||
|
||||
def _make_record(self, **extra) -> logging.LogRecord:
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname=__file__,
|
||||
lineno=1,
|
||||
msg="hello",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
for k, v in extra.items():
|
||||
setattr(record, k, v)
|
||||
return record
|
||||
|
||||
def test_stamps_record_with_context(self):
|
||||
token = log_context.bind(activity_id="a1", user_id="u1")
|
||||
try:
|
||||
record = self._make_record()
|
||||
assert _ContextFilter().filter(record) is True
|
||||
assert record.activity_id == "a1"
|
||||
assert record.user_id == "u1"
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
def test_explicit_extra_wins_over_context(self):
|
||||
# extra={} sets attributes on the record before the filter runs;
|
||||
# the filter must not overwrite them.
|
||||
token = log_context.bind(activity_id="from-ctx")
|
||||
try:
|
||||
record = self._make_record(activity_id="from-extra")
|
||||
_ContextFilter().filter(record)
|
||||
assert record.activity_id == "from-extra"
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
def test_no_op_when_context_empty(self):
|
||||
record = self._make_record()
|
||||
assert _ContextFilter().filter(record) is True
|
||||
assert not hasattr(record, "activity_id")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFilterWiringEndToEnd:
|
||||
"""Regression guard: the filter must be installed on handlers, not on
|
||||
loggers — Python skips logger-level filters during propagation.
|
||||
"""
|
||||
|
||||
def test_propagated_record_gets_stamped(self):
|
||||
from application.core.logging_config import _install_context_filter
|
||||
|
||||
captured: list[logging.LogRecord] = []
|
||||
|
||||
class _Capture(logging.Handler):
|
||||
def emit(self, record):
|
||||
captured.append(record)
|
||||
|
||||
root = logging.getLogger()
|
||||
saved_handlers = list(root.handlers)
|
||||
saved_level = root.level
|
||||
try:
|
||||
root.handlers = [_Capture()]
|
||||
root.setLevel(logging.DEBUG)
|
||||
_install_context_filter()
|
||||
|
||||
child = logging.getLogger("test_log_context.propagation")
|
||||
child.setLevel(logging.DEBUG)
|
||||
|
||||
token = log_context.bind(activity_id="propagated-id")
|
||||
try:
|
||||
child.info("from a child logger")
|
||||
finally:
|
||||
log_context.reset(token)
|
||||
|
||||
assert captured, "Capture handler should have received the record"
|
||||
assert captured[0].activity_id == "propagated-id"
|
||||
finally:
|
||||
root.handlers = saved_handlers
|
||||
root.setLevel(saved_level)
|
||||
@@ -360,38 +360,7 @@ class TestExtractTextFromContent:
|
||||
handler = ConcreteHandler()
|
||||
content = [{"files": ["/tmp/a.txt"]}]
|
||||
result = handler._extract_text_from_content(content)
|
||||
assert result == "[attachment: /tmp/a.txt]"
|
||||
|
||||
def test_list_with_inline_image_bytes(self):
|
||||
# Google attaches images as inline bytes; stringifying them into
|
||||
# the compression prompt would bust the compression LLM's input
|
||||
# limit. The placeholder must describe the attachment without
|
||||
# embedding the bytes.
|
||||
handler = ConcreteHandler()
|
||||
content = [
|
||||
{
|
||||
"files": [
|
||||
{"file_bytes": b"\x89PNG" + b"\x00" * 1000, "mime_type": "image/png"}
|
||||
]
|
||||
}
|
||||
]
|
||||
result = handler._extract_text_from_content(content)
|
||||
assert result == "[attachment: image/png]"
|
||||
assert "PNG" not in result
|
||||
assert "\\x" not in result
|
||||
|
||||
def test_list_with_multiple_files(self):
|
||||
handler = ConcreteHandler()
|
||||
content = [
|
||||
{
|
||||
"files": [
|
||||
{"file_bytes": b"a", "mime_type": "image/png"},
|
||||
{"file_uri": "https://x", "mime_type": "image/jpeg"},
|
||||
]
|
||||
}
|
||||
]
|
||||
result = handler._extract_text_from_content(content)
|
||||
assert result == "[attachment: image/png, image/jpeg]"
|
||||
assert "files" in result
|
||||
|
||||
def test_list_with_none_text(self):
|
||||
handler = ConcreteHandler()
|
||||
|
||||
@@ -49,9 +49,6 @@ class FailingLLM(BaseLLM):
|
||||
|
||||
|
||||
class FallbackLLM(BaseLLM):
|
||||
# _execute_with_fallback applies decorators to the fallback's raw method
|
||||
# directly and never calls .gen() / .gen_stream() on it, so
|
||||
# tracking lives on the raw methods.
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.gen_called = False
|
||||
@@ -65,6 +62,14 @@ class FallbackLLM(BaseLLM):
|
||||
self.gen_stream_called = True
|
||||
yield "fallback_chunk"
|
||||
|
||||
def gen(self, *args, **kwargs):
|
||||
self.gen_called = True
|
||||
return "fallback_gen_result"
|
||||
|
||||
def gen_stream(self, *args, **kwargs):
|
||||
self.gen_stream_called = True
|
||||
yield "fallback_stream_chunk"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# gen / gen_stream decorator application
|
||||
@@ -90,180 +95,6 @@ class TestGenMethods:
|
||||
)
|
||||
assert result == ["a", "b"]
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
@patch("application.llm.base.stream_token_usage", lambda f: f)
|
||||
def test_gen_stream_emits_llm_stream_start_event(self, caplog):
|
||||
import logging as _logging
|
||||
|
||||
class FakeProvider(StubLLM):
|
||||
provider_name = "fake-provider"
|
||||
|
||||
llm = FakeProvider(raw_gen_stream_items=["x"])
|
||||
with caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(
|
||||
llm.gen_stream(
|
||||
model="m1",
|
||||
messages=[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hey"}],
|
||||
tools=[{"name": "t"}],
|
||||
_usage_attachments=[{"path": "/tmp/a.png"}],
|
||||
)
|
||||
)
|
||||
|
||||
starts = [r for r in caplog.records if r.message == "llm_stream_start"]
|
||||
assert len(starts) == 1
|
||||
evt = starts[0]
|
||||
assert evt.model == "m1"
|
||||
assert evt.provider == "fake-provider"
|
||||
assert evt.message_count == 2
|
||||
# ``_usage_attachments`` is what ``Agent._llm_gen`` actually passes;
|
||||
# the alias check below covers the bare ``attachments=`` form.
|
||||
assert evt.has_attachments is True
|
||||
assert evt.has_tools is True
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
@patch("application.llm.base.stream_token_usage", lambda f: f)
|
||||
def test_gen_stream_recognises_attachments_kwarg_alias(self, caplog):
|
||||
import logging as _logging
|
||||
|
||||
llm = StubLLM(raw_gen_stream_items=["x"])
|
||||
with caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(
|
||||
llm.gen_stream(
|
||||
model="m1", messages=[], attachments=[{"path": "/tmp/a"}]
|
||||
)
|
||||
)
|
||||
evt = next(r for r in caplog.records if r.message == "llm_stream_start")
|
||||
assert evt.has_attachments is True
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
@patch("application.llm.base.stream_token_usage", lambda f: f)
|
||||
def test_gen_stream_emits_event_without_attachments_or_tools(self, caplog):
|
||||
import logging as _logging
|
||||
|
||||
llm = StubLLM(raw_gen_stream_items=["x"])
|
||||
with caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(llm.gen_stream(model="m1", messages=[]))
|
||||
|
||||
evt = next(r for r in caplog.records if r.message == "llm_stream_start")
|
||||
assert evt.message_count == 0
|
||||
assert evt.has_attachments is False
|
||||
assert evt.has_tools is False
|
||||
# BaseLLM default — concrete providers always override.
|
||||
assert evt.provider == "unknown"
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
def test_gen_stream_emits_llm_stream_finished_on_success(self, caplog):
|
||||
# Real ``stream_token_usage`` so the emit-from-finally path runs.
|
||||
# ``update_token_usage`` short-circuits under pytest, so no DB
|
||||
# mocking is needed.
|
||||
import logging as _logging
|
||||
|
||||
class FakeProvider(StubLLM):
|
||||
provider_name = "fake-provider"
|
||||
|
||||
llm = FakeProvider(raw_gen_stream_items=["alpha", "beta"])
|
||||
llm.user_api_key = None
|
||||
with caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(
|
||||
llm.gen_stream(
|
||||
model="m1",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
)
|
||||
|
||||
finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
|
||||
assert len(finished) == 1
|
||||
evt = finished[0]
|
||||
assert evt.model == "m1"
|
||||
assert evt.provider == "fake-provider"
|
||||
assert evt.status == "ok"
|
||||
assert isinstance(evt.prompt_tokens, int) and evt.prompt_tokens >= 0
|
||||
assert isinstance(evt.completion_tokens, int) and evt.completion_tokens > 0
|
||||
assert isinstance(evt.latency_ms, int) and evt.latency_ms >= 0
|
||||
# ``cached_tokens`` is intentionally absent until per-provider
|
||||
# vendor-usage extraction lands.
|
||||
assert not hasattr(evt, "cached_tokens")
|
||||
assert not hasattr(evt, "error_class")
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
def test_gen_stream_emits_llm_stream_finished_on_error(self, caplog):
|
||||
import logging as _logging
|
||||
|
||||
class FakeProvider(BaseLLM):
|
||||
provider_name = "fake-provider"
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kw):
|
||||
return "x"
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kw):
|
||||
yield "partial"
|
||||
raise RuntimeError("mid_stream_boom")
|
||||
|
||||
llm = FakeProvider()
|
||||
llm.user_api_key = None
|
||||
with caplog.at_level(_logging.INFO, logger="root"), pytest.raises(RuntimeError):
|
||||
list(llm.gen_stream(model="m1", messages=[]))
|
||||
|
||||
finished = [r for r in caplog.records if r.message == "llm_stream_finished"]
|
||||
assert len(finished) == 1
|
||||
evt = finished[0]
|
||||
assert evt.status == "error"
|
||||
assert evt.error_class == "RuntimeError"
|
||||
# Partial completion tokens still recorded (the chunk yielded
|
||||
# before the failure is in the batch).
|
||||
assert evt.completion_tokens > 0
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
def test_gen_stream_finished_event_paired_with_stream_start(self, caplog):
|
||||
# The two events form a pair the cost dashboards join on; verify
|
||||
# they always come in order and from the same provider/model.
|
||||
import logging as _logging
|
||||
|
||||
class FakeProvider(StubLLM):
|
||||
provider_name = "fake-provider"
|
||||
|
||||
llm = FakeProvider(raw_gen_stream_items=["x"])
|
||||
llm.user_api_key = None
|
||||
with caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(llm.gen_stream(model="m1", messages=[]))
|
||||
|
||||
records = [
|
||||
r for r in caplog.records
|
||||
if r.message in ("llm_stream_start", "llm_stream_finished")
|
||||
]
|
||||
assert [r.message for r in records] == ["llm_stream_start", "llm_stream_finished"]
|
||||
assert records[0].model == records[1].model == "m1"
|
||||
assert records[0].provider == records[1].provider == "fake-provider"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProviderNameRegistry:
|
||||
"""A new provider without ``provider_name`` would silently report
|
||||
``provider="unknown"`` in telemetry. Pin the expected values here."""
|
||||
|
||||
def test_provider_names_match_expectations(self):
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
from application.llm.openai import OpenAILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
|
||||
assert OpenAILLM.provider_name == "openai"
|
||||
assert GoogleLLM.provider_name == "google"
|
||||
assert AnthropicLLM.provider_name == "anthropic"
|
||||
assert GroqLLM.provider_name == "groq"
|
||||
assert NovitaLLM.provider_name == "novita"
|
||||
assert OpenRouterLLM.provider_name == "openrouter"
|
||||
assert DocsGPTAPILLM.provider_name == "docsgpt"
|
||||
assert PremAILLM.provider_name == "premai"
|
||||
assert LlamaCpp.provider_name == "llama_cpp"
|
||||
assert SagemakerAPILLM.provider_name == "sagemaker"
|
||||
|
||||
@patch("application.llm.base.gen_cache", lambda f: f)
|
||||
@patch("application.llm.base.gen_token_usage", lambda f: f)
|
||||
def test_gen_passes_tools(self):
|
||||
@@ -309,7 +140,7 @@ class TestExecuteWithFallbackNonStreaming:
|
||||
llm._fallback_llm = fallback
|
||||
|
||||
result = llm.gen(model="m", messages=[])
|
||||
assert result == "fallback_result"
|
||||
assert result == "fallback_gen_result"
|
||||
assert fallback.gen_called
|
||||
|
||||
|
||||
@@ -336,97 +167,10 @@ class TestStreamWithFallback:
|
||||
llm._fallback_llm = fallback
|
||||
|
||||
result = list(llm.gen_stream(model="m", messages=[]))
|
||||
assert "fallback_chunk" in result
|
||||
assert "fallback_stream_chunk" in result
|
||||
assert fallback.gen_stream_called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-retriable client error guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StatusError(Exception):
|
||||
"""Mimics openai/anthropic-shaped client errors with a status_code."""
|
||||
|
||||
def __init__(self, status_code, message="bad request"):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class _ClientErrorLLM(BaseLLM):
|
||||
def __init__(self, status_code, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._status = status_code
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kw):
|
||||
raise _StatusError(self._status)
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kw):
|
||||
raise _StatusError(self._status)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNonRetriableClientError:
|
||||
|
||||
def test_helper_detects_4xx_status_code(self):
|
||||
assert BaseLLM._is_non_retriable_client_error(_StatusError(400))
|
||||
assert BaseLLM._is_non_retriable_client_error(_StatusError(404))
|
||||
assert BaseLLM._is_non_retriable_client_error(_StatusError(429))
|
||||
|
||||
def test_helper_passes_5xx_through(self):
|
||||
# 5xx and connection errors should still trigger fallback.
|
||||
assert not BaseLLM._is_non_retriable_client_error(_StatusError(500))
|
||||
assert not BaseLLM._is_non_retriable_client_error(_StatusError(503))
|
||||
assert not BaseLLM._is_non_retriable_client_error(RuntimeError("oops"))
|
||||
|
||||
def test_helper_detects_genai_client_error(self):
|
||||
try:
|
||||
from google.genai.errors import ClientError
|
||||
except ImportError:
|
||||
pytest.skip("google-genai not installed")
|
||||
# ClientError(code, response_json, response=None)
|
||||
exc = ClientError(400, {"error": {"message": "bad", "code": 400}}, None)
|
||||
assert BaseLLM._is_non_retriable_client_error(exc)
|
||||
|
||||
def test_helper_detects_response_status_code(self):
|
||||
exc = RuntimeError("wrapped")
|
||||
exc.response = type("R", (), {"status_code": 401})()
|
||||
assert BaseLLM._is_non_retriable_client_error(exc)
|
||||
|
||||
@patch("application.llm.base.gen_cache", lambda f: f)
|
||||
@patch("application.llm.base.gen_token_usage", lambda f: f)
|
||||
def test_4xx_skips_fallback(self):
|
||||
fallback = FallbackLLM(model_id="fallback-model")
|
||||
llm = _ClientErrorLLM(status_code=400)
|
||||
llm._fallback_llm = fallback
|
||||
|
||||
with pytest.raises(_StatusError):
|
||||
llm.gen(model="m", messages=[])
|
||||
assert not fallback.gen_called
|
||||
|
||||
@patch("application.llm.base.stream_cache", lambda f: f)
|
||||
@patch("application.llm.base.stream_token_usage", lambda f: f)
|
||||
def test_4xx_skips_stream_fallback(self):
|
||||
fallback = FallbackLLM(model_id="fallback-model")
|
||||
llm = _ClientErrorLLM(status_code=400)
|
||||
llm._fallback_llm = fallback
|
||||
|
||||
with pytest.raises(_StatusError):
|
||||
list(llm.gen_stream(model="m", messages=[]))
|
||||
assert not fallback.gen_stream_called
|
||||
|
||||
@patch("application.llm.base.gen_cache", lambda f: f)
|
||||
@patch("application.llm.base.gen_token_usage", lambda f: f)
|
||||
def test_5xx_still_falls_back(self):
|
||||
fallback = FallbackLLM(model_id="fallback-model")
|
||||
llm = _ClientErrorLLM(status_code=503)
|
||||
llm._fallback_llm = fallback
|
||||
|
||||
result = llm.gen(model="m", messages=[])
|
||||
assert result == "fallback_result"
|
||||
assert fallback.gen_called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fallback_llm property: backup model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -31,20 +31,12 @@ class FakeLLM(BaseLLM):
|
||||
self.gen_stream_called = False
|
||||
self.last_model_received = None # tracks the model kwarg passed to gen/gen_stream
|
||||
|
||||
# Track at the raw-method level. _execute_with_fallback applies
|
||||
# decorators to the fallback's raw method directly and
|
||||
# never calls .gen() / .gen_stream() on it, so a public-method
|
||||
# override would not register fallback hops.
|
||||
def _raw_gen(self, baseself, model, messages, stream, tools=None, **kwargs):
|
||||
self.gen_called = True
|
||||
self.last_model_received = model
|
||||
if self.fail_at is not None:
|
||||
raise RuntimeError("primary model unavailable")
|
||||
return self.responses[0]
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream, tools=None, **kwargs):
|
||||
self.gen_stream_called = True
|
||||
self.last_model_received = model
|
||||
yielded = 0
|
||||
for chunk in self.stream_chunks:
|
||||
if self.fail_at is not None and yielded >= self.fail_at:
|
||||
@@ -52,6 +44,18 @@ class FakeLLM(BaseLLM):
|
||||
yield chunk
|
||||
yielded += 1
|
||||
|
||||
# Wrap gen/gen_stream so we can track whether the fallback instance was used
|
||||
# and which model kwarg it received
|
||||
def gen(self, *args, **kwargs):
|
||||
self.gen_called = True
|
||||
self.last_model_received = kwargs.get("model")
|
||||
return super().gen(*args, **kwargs)
|
||||
|
||||
def gen_stream(self, *args, **kwargs):
|
||||
self.gen_stream_called = True
|
||||
self.last_model_received = kwargs.get("model")
|
||||
return super().gen_stream(*args, **kwargs)
|
||||
|
||||
|
||||
# Helpers
|
||||
|
||||
@@ -290,140 +294,6 @@ class TestStreamingFallback:
|
||||
with pytest.raises(RuntimeError, match="mid-stream failure"):
|
||||
list(primary.gen_stream(**CALL_ARGS))
|
||||
|
||||
def test_fallback_emits_stream_start_with_fallback_provider(
|
||||
self, patch_model_utils, caplog
|
||||
):
|
||||
# The fallback raw-stream path bypasses ``gen_stream``, so it must
|
||||
# emit its own ``llm_stream_start`` event tagged with the fallback
|
||||
# vendor — otherwise dashboards record only the failed primary
|
||||
# even when the response came from the backup.
|
||||
import logging as _logging
|
||||
|
||||
class FallbackProvider(FakeLLM):
|
||||
provider_name = "fallback-vendor"
|
||||
|
||||
backup = FallbackProvider(
|
||||
stream_chunks=["b1"], model_id="backup-model-id"
|
||||
)
|
||||
patch_model_utils(
|
||||
get_provider=lambda m, **_kwargs: "openai",
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=lambda type, **kw: backup,
|
||||
)
|
||||
|
||||
class PrimaryProvider(FakeLLM):
|
||||
provider_name = "primary-vendor"
|
||||
|
||||
primary = PrimaryProvider(
|
||||
stream_chunks=["x"],
|
||||
fail_at=0,
|
||||
backup_models=["backup-model-id"],
|
||||
)
|
||||
|
||||
with caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(
|
||||
primary.gen_stream(
|
||||
model="primary-model",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
)
|
||||
|
||||
starts = [r for r in caplog.records if r.message == "llm_stream_start"]
|
||||
assert len(starts) == 2
|
||||
assert starts[0].provider == "primary-vendor"
|
||||
assert starts[0].model == "primary-model"
|
||||
assert starts[1].provider == "fallback-vendor"
|
||||
assert starts[1].model == "backup-model-id"
|
||||
|
||||
|
||||
# Tests — fallback never re-enters the orchestrator (Option B regression)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestFallbackNoRecursion:
|
||||
"""When the primary fails, _execute_with_fallback applies decorators to
|
||||
the fallback's raw method directly. The fallback's own ``fallback_llm``
|
||||
property must never be accessed — otherwise a fallback failure would
|
||||
re-enter the orchestrator and walk the global FALLBACK_LLM_* chain
|
||||
unboundedly."""
|
||||
|
||||
def test_backup_fallback_llm_property_never_accessed_on_gen_failure(
|
||||
self, monkeypatch, patch_model_utils
|
||||
):
|
||||
backup = FakeLLM(fail_at=0) # backup also fails
|
||||
|
||||
accessed_on = []
|
||||
original_property = BaseLLM.fallback_llm
|
||||
|
||||
def tracked_fallback_llm(self_llm):
|
||||
accessed_on.append(self_llm)
|
||||
return original_property.fget(self_llm)
|
||||
|
||||
monkeypatch.setattr(
|
||||
BaseLLM, "fallback_llm", property(tracked_fallback_llm)
|
||||
)
|
||||
|
||||
patch_model_utils(
|
||||
get_provider=lambda m, **_kwargs: "openai",
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=lambda type, **kw: backup,
|
||||
)
|
||||
|
||||
primary = FakeLLM(fail_at=0, backup_models=["backup-model"])
|
||||
with pytest.raises(RuntimeError, match="primary model unavailable"):
|
||||
primary.gen(**CALL_ARGS)
|
||||
|
||||
assert primary in accessed_on # primary lazy-loaded its fallback
|
||||
assert backup not in accessed_on # backup's chain was never walked
|
||||
|
||||
def test_backup_fallback_llm_property_never_accessed_on_stream_failure(
|
||||
self, monkeypatch, patch_model_utils
|
||||
):
|
||||
backup = FakeLLM(stream_chunks=["x"], fail_at=0)
|
||||
|
||||
accessed_on = []
|
||||
original_property = BaseLLM.fallback_llm
|
||||
|
||||
def tracked_fallback_llm(self_llm):
|
||||
accessed_on.append(self_llm)
|
||||
return original_property.fget(self_llm)
|
||||
|
||||
monkeypatch.setattr(
|
||||
BaseLLM, "fallback_llm", property(tracked_fallback_llm)
|
||||
)
|
||||
|
||||
patch_model_utils(
|
||||
get_provider=lambda m, **_kwargs: "openai",
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=lambda type, **kw: backup,
|
||||
)
|
||||
|
||||
primary = FakeLLM(
|
||||
stream_chunks=["y"], fail_at=0, backup_models=["backup-model"]
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="mid-stream failure"):
|
||||
list(primary.gen_stream(**CALL_ARGS))
|
||||
|
||||
assert primary in accessed_on
|
||||
assert backup not in accessed_on
|
||||
|
||||
def test_fallback_failure_propagates_without_chain(self, patch_model_utils):
|
||||
"""When both primary and fallback fail, the fallback's exception
|
||||
propagates cleanly — no third hop, no extra retries."""
|
||||
backup = FakeLLM(fail_at=0)
|
||||
|
||||
patch_model_utils(
|
||||
get_provider=lambda m, **_kwargs: "openai",
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=lambda type, **kw: backup,
|
||||
)
|
||||
|
||||
primary = FakeLLM(fail_at=0, backup_models=["backup-model"])
|
||||
with pytest.raises(RuntimeError, match="primary model unavailable"):
|
||||
primary.gen(**CALL_ARGS)
|
||||
|
||||
assert backup.gen_called # confirms fallback raw method WAS invoked
|
||||
|
||||
|
||||
# Tests — backup model priority over global fallback
|
||||
|
||||
|
||||
@@ -28,11 +28,10 @@ from application.llm.google_ai import GoogleLLM
|
||||
|
||||
|
||||
class _FakePart:
|
||||
def __init__(self, text=None, function_call=None, file_data=None, inline_data=None, thought=False, **kwargs):
|
||||
def __init__(self, text=None, function_call=None, file_data=None, thought=False, **kwargs):
|
||||
self.text = text
|
||||
self.function_call = function_call or kwargs.get("functionCall")
|
||||
self.file_data = file_data
|
||||
self.inline_data = inline_data
|
||||
self.thought = thought
|
||||
self.thoughtSignature = kwargs.get("thoughtSignature")
|
||||
|
||||
@@ -54,12 +53,6 @@ class _FakePart:
|
||||
file_data=types.SimpleNamespace(file_uri=file_uri, mime_type=mime_type)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data, mime_type):
|
||||
return _FakePart(
|
||||
inline_data=types.SimpleNamespace(data=data, mime_type=mime_type)
|
||||
)
|
||||
|
||||
|
||||
class _FakeContent:
|
||||
def __init__(self, role, parts):
|
||||
@@ -230,43 +223,6 @@ class TestCleanMessagesGoogle:
|
||||
for p in cleaned[0].parts
|
||||
)
|
||||
|
||||
def test_files_with_inline_bytes(self, llm):
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"files": [
|
||||
{"file_bytes": b"\x89PNG", "mime_type": "image/png"}
|
||||
]
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
cleaned, _ = llm._clean_messages_google(msgs)
|
||||
assert len(cleaned) == 1
|
||||
inline_parts = [
|
||||
p for p in cleaned[0].parts
|
||||
if getattr(p, "inline_data", None) is not None
|
||||
]
|
||||
assert len(inline_parts) == 1
|
||||
assert inline_parts[0].inline_data.data == b"\x89PNG"
|
||||
assert inline_parts[0].inline_data.mime_type == "image/png"
|
||||
|
||||
def test_files_with_empty_uri_dropped(self, llm):
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"files": [{"file_uri": "", "mime_type": "image/png"}]},
|
||||
],
|
||||
}
|
||||
]
|
||||
cleaned, _ = llm._clean_messages_google(msgs)
|
||||
# Empty URI part is dropped; no other parts means the whole
|
||||
# content is empty and the message itself is not appended.
|
||||
assert cleaned == []
|
||||
|
||||
def test_unexpected_list_item_raises(self, llm):
|
||||
msgs = [{"role": "user", "content": [{"unknown_key": "val"}]}]
|
||||
with pytest.raises(ValueError, match="Unexpected content dictionary"):
|
||||
@@ -754,9 +710,7 @@ class TestPrepareMessagesWithAttachments:
|
||||
|
||||
def test_upload_error_adds_text_fallback(self, llm, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
llm,
|
||||
"_read_attachment_bytes",
|
||||
lambda a: (_ for _ in ()).throw(Exception("fail")),
|
||||
llm, "_upload_file_to_google", lambda a: (_ for _ in ()).throw(Exception("fail"))
|
||||
)
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
attachments = [
|
||||
@@ -770,57 +724,8 @@ class TestPrepareMessagesWithAttachments:
|
||||
]
|
||||
assert len(text_parts) == 1
|
||||
|
||||
def test_pdf_upload_error_adds_text_fallback(self, llm, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
llm,
|
||||
"_upload_file_to_google",
|
||||
lambda a: (_ for _ in ()).throw(Exception("fail")),
|
||||
)
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
attachments = [
|
||||
{"mime_type": "application/pdf", "path": "/tmp/doc.pdf", "content": "x"},
|
||||
]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
text_parts = [
|
||||
p for p in user_msg["content"]
|
||||
if isinstance(p, dict) and p.get("type") == "text" and "could not" in p.get("text", "").lower()
|
||||
]
|
||||
assert len(text_parts) == 1
|
||||
|
||||
def test_pdf_empty_uri_adds_text_fallback(self, llm, monkeypatch):
|
||||
monkeypatch.setattr(llm, "_upload_file_to_google", lambda a: "")
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
attachments = [
|
||||
{"mime_type": "application/pdf", "path": "/tmp/doc.pdf", "content": "x"},
|
||||
]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
files_entries = [
|
||||
p for p in user_msg["content"] if isinstance(p, dict) and "files" in p
|
||||
]
|
||||
assert files_entries == []
|
||||
text_parts = [
|
||||
p for p in user_msg["content"]
|
||||
if isinstance(p, dict) and p.get("type") == "text" and "could not" in p.get("text", "").lower()
|
||||
]
|
||||
assert len(text_parts) == 1
|
||||
|
||||
def test_image_uses_inline_bytes(self, llm, monkeypatch):
|
||||
monkeypatch.setattr(llm, "_read_attachment_bytes", lambda a: b"\x89PNG-bytes")
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
attachments = [{"mime_type": "image/png", "path": "/img.png"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
user_msg = next(m for m in result if m["role"] == "user")
|
||||
files_entry = next(
|
||||
p for p in user_msg["content"] if isinstance(p, dict) and "files" in p
|
||||
)
|
||||
assert files_entry["files"] == [
|
||||
{"file_bytes": b"\x89PNG-bytes", "mime_type": "image/png"}
|
||||
]
|
||||
|
||||
def test_no_user_message_creates_one(self, llm, monkeypatch):
|
||||
monkeypatch.setattr(llm, "_read_attachment_bytes", lambda a: b"png")
|
||||
monkeypatch.setattr(llm, "_upload_file_to_google", lambda a: "gs://uri")
|
||||
msgs = [{"role": "system", "content": "sys"}]
|
||||
attachments = [{"mime_type": "image/png", "path": "/img.png"}]
|
||||
result = llm.prepare_messages_with_attachments(msgs, attachments)
|
||||
@@ -841,26 +746,6 @@ class TestUploadFileToGoogle:
|
||||
result = llm._upload_file_to_google(attachment)
|
||||
assert result == "gs://cached"
|
||||
|
||||
def test_empty_cached_uri_triggers_reupload(self, llm, monkeypatch):
|
||||
# Poisoned-cache repro: an empty-string google_file_uri must be
|
||||
# treated as a miss and re-upload, not returned as-is.
|
||||
monkeypatch.setattr(
|
||||
"application.llm.google_ai.settings",
|
||||
types.SimpleNamespace(GOOGLE_API_KEY="k", API_KEY="k"),
|
||||
)
|
||||
result = llm._upload_file_to_google(
|
||||
{"google_file_uri": "", "path": "/tmp/file.pdf"}
|
||||
)
|
||||
assert result == "gs://fake-uri"
|
||||
|
||||
def test_empty_upload_uri_raises(self, llm):
|
||||
llm.storage = types.SimpleNamespace(
|
||||
file_exists=lambda p: True,
|
||||
process_file=lambda path, fn, **kw: "",
|
||||
)
|
||||
with pytest.raises(ValueError, match="empty URI"):
|
||||
llm._upload_file_to_google({"path": "/tmp/file.pdf"})
|
||||
|
||||
def test_raises_for_no_path(self, llm):
|
||||
with pytest.raises(ValueError, match="No file path"):
|
||||
llm._upload_file_to_google({})
|
||||
|
||||
@@ -228,7 +228,6 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch):
|
||||
process_file=lambda path, processor_func, **kwargs: "gs://file_uri"
|
||||
)
|
||||
monkeypatch.setattr(llm, "_upload_file_to_google", lambda att: "gs://file_uri")
|
||||
monkeypatch.setattr(llm, "_read_attachment_bytes", lambda att: b"png-bytes")
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
attachments = [
|
||||
@@ -241,9 +240,4 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch):
|
||||
assert isinstance(user_msg["content"], list)
|
||||
files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None)
|
||||
assert files_entry is not None
|
||||
files = files_entry["files"]
|
||||
assert len(files) == 2
|
||||
image_part = next(f for f in files if f["mime_type"] == "image/png")
|
||||
pdf_part = next(f for f in files if f["mime_type"] == "application/pdf")
|
||||
assert image_part == {"file_bytes": b"png-bytes", "mime_type": "image/png"}
|
||||
assert pdf_part == {"file_uri": "gs://file_uri", "mime_type": "application/pdf"}
|
||||
assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2
|
||||
|
||||
@@ -31,23 +31,6 @@ class TestCreate:
|
||||
doc = repo.create("u", "f", "/p")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
def test_create_aliases_upload_path_as_path(self, pg_conn):
|
||||
# LLM provider code (google_ai/openai/anthropic and handlers/base)
|
||||
# reads attachment.get("path") — preserved from the legacy Mongo
|
||||
# shape. Repo emits both keys so consumers don't need to know
|
||||
# which storage backend produced the dict.
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("u", "f", "/uploads/x.png")
|
||||
assert doc["path"] == "/uploads/x.png"
|
||||
assert doc["upload_path"] == "/uploads/x.png"
|
||||
|
||||
def test_get_aliases_upload_path_as_path(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("u", "f", "/uploads/y.pdf")
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched is not None
|
||||
assert fetched["path"] == "/uploads/y.pdf"
|
||||
|
||||
def test_create_with_legacy_mongo_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create(
|
||||
|
||||
@@ -438,94 +438,3 @@ def test_stream_cache_key_generation_failure_yields(mock_make_redis):
|
||||
messages = ["not_a_dict"]
|
||||
result = list(mock_function(None, "model", messages, stream=True, tools=None))
|
||||
assert result == ["fallback_chunk"]
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# gen_cache_key with inline bytes (Google attachments)
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_cache_key_handles_inline_bytes():
|
||||
"""Image attachments arrive in messages as raw bytes (see
|
||||
GoogleLLM.prepare_messages_with_attachments). gen_cache_key must not
|
||||
crash on json.dumps of bytes."""
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"file_bytes": b"\x00\x01\x02", "mime_type": "image/png"}],
|
||||
}
|
||||
]
|
||||
key = gen_cache_key(msgs, model="x")
|
||||
assert isinstance(key, str)
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_cache_key_stable_for_same_bytes():
|
||||
"""Two requests with identical image bytes must produce the same key
|
||||
— otherwise we'd never get cache hits on image-bearing prompts."""
|
||||
a = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
|
||||
}
|
||||
]
|
||||
b = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
|
||||
}
|
||||
]
|
||||
assert gen_cache_key(a, "m") == gen_cache_key(b, "m")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_cache_key_differs_for_different_bytes():
|
||||
"""Different image bytes must produce different keys — otherwise two
|
||||
different images would collide in cache."""
|
||||
a = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
|
||||
}
|
||||
]
|
||||
b = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"file_bytes": b"xyz", "mime_type": "image/png"}],
|
||||
}
|
||||
]
|
||||
assert gen_cache_key(a, "m") != gen_cache_key(b, "m")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_cache_key_handles_bytearray_and_memoryview():
|
||||
"""The default helper covers all bytes-like types so refactors that
|
||||
swap bytes for bytearray/memoryview don't silently re-introduce the
|
||||
TypeError."""
|
||||
msgs_ba = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"file_bytes": bytearray(b"abc"), "mime_type": "image/png"}
|
||||
],
|
||||
}
|
||||
]
|
||||
msgs_mv = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"file_bytes": memoryview(b"abc"), "mime_type": "image/png"}
|
||||
],
|
||||
}
|
||||
]
|
||||
msgs_b = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"file_bytes": b"abc", "mime_type": "image/png"}],
|
||||
}
|
||||
]
|
||||
# All three should hash the same content to the same key.
|
||||
assert gen_cache_key(msgs_ba, "m") == gen_cache_key(msgs_b, "m")
|
||||
assert gen_cache_key(msgs_mv, "m") == gen_cache_key(msgs_b, "m")
|
||||
|
||||
@@ -427,109 +427,6 @@ class TestCompressionService:
|
||||
|
||||
assert token_count_with_tools > token_count_without_tools
|
||||
|
||||
def test_count_message_tokens_skips_inline_image_bytes(self):
|
||||
# Google attaches images as raw bytes inline. Stringifying the
|
||||
# part would tokenize the byte repr (~2M tokens for a 1MB image),
|
||||
# trigger spurious compression, and overflow downstream input
|
||||
# limits. The estimate must stay bounded.
|
||||
big_bytes = b"\x00" * 1_000_000
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{
|
||||
"files": [
|
||||
{"file_bytes": big_bytes, "mime_type": "image/png"}
|
||||
]
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
tokens = TokenCounter.count_message_tokens(messages)
|
||||
assert tokens < 5000
|
||||
|
||||
def test_count_message_tokens_per_image_estimate_scales(self):
|
||||
msgs_one = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"files": [
|
||||
{"file_bytes": b"x", "mime_type": "image/png"}
|
||||
]
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
msgs_two = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"files": [
|
||||
{"file_bytes": b"x", "mime_type": "image/png"},
|
||||
{"file_uri": "https://x", "mime_type": "image/jpeg"},
|
||||
]
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
assert TokenCounter.count_message_tokens(
|
||||
msgs_two
|
||||
) > TokenCounter.count_message_tokens(msgs_one)
|
||||
|
||||
def test_count_message_tokens_skips_openai_image_url(self):
|
||||
# OpenAI puts images inline as base64 data URLs. ``str(item)`` on
|
||||
# the dict would tokenize the entire base64 payload.
|
||||
big_b64 = "A" * 1_000_000
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{big_b64}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
tokens = TokenCounter.count_message_tokens(messages)
|
||||
assert tokens < 5000
|
||||
|
||||
def test_count_message_tokens_skips_anthropic_image(self):
|
||||
# Anthropic puts images inline as ``source.data`` base64.
|
||||
big_b64 = "A" * 1_000_000
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": big_b64,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
tokens = TokenCounter.count_message_tokens(messages)
|
||||
assert tokens < 5000
|
||||
|
||||
def test_count_message_tokens_still_counts_text_parts(self):
|
||||
# The image bypass must not regress regular text-part counting.
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello world"}],
|
||||
}
|
||||
]
|
||||
assert TokenCounter.count_message_tokens(messages) > 0
|
||||
|
||||
def test_format_conversation_for_compression(
|
||||
self, prompt_builder, sample_conversation
|
||||
):
|
||||
|
||||
@@ -153,219 +153,4 @@ class TestLogActivity:
|
||||
):
|
||||
list(failing_gen(FakeAgent()))
|
||||
|
||||
def test_log_activity_emits_lifecycle_events(self, caplog):
|
||||
import logging as _logging
|
||||
|
||||
from application.logging import log_activity
|
||||
|
||||
class FakeAgent:
|
||||
endpoint = "test"
|
||||
user = "user1"
|
||||
user_api_key = "k"
|
||||
query = "q"
|
||||
agent_id = "agent-7"
|
||||
conversation_id = "conv-3"
|
||||
|
||||
@log_activity()
|
||||
def gen(agent, log_context=None):
|
||||
yield "x"
|
||||
|
||||
with patch("application.logging._log_activity_to_db"), \
|
||||
caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(gen(FakeAgent()))
|
||||
|
||||
messages = [r.message for r in caplog.records]
|
||||
assert "activity_started" in messages
|
||||
assert "activity_finished" in messages
|
||||
|
||||
started = next(r for r in caplog.records if r.message == "activity_started")
|
||||
finished = next(r for r in caplog.records if r.message == "activity_finished")
|
||||
|
||||
assert started.endpoint == "test"
|
||||
assert started.user_id == "user1"
|
||||
assert started.agent_id == "agent-7"
|
||||
assert started.conversation_id == "conv-3"
|
||||
assert started.parent_activity_id is None # top-level activity
|
||||
|
||||
assert finished.activity_id == started.activity_id
|
||||
assert finished.status == "ok"
|
||||
assert isinstance(finished.duration_ms, int)
|
||||
assert finished.duration_ms >= 0
|
||||
assert finished.error_class is None
|
||||
|
||||
def test_log_activity_records_parent_activity_id_when_nested(self, caplog):
|
||||
# Sub-agents / workflow_agents wrap an outer @log_activity gen;
|
||||
# the inner activity_started event must link to the outer's id.
|
||||
import logging as _logging
|
||||
|
||||
from application.logging import log_activity
|
||||
|
||||
class FakeAgent:
|
||||
endpoint = "outer"
|
||||
user = "user1"
|
||||
user_api_key = ""
|
||||
query = ""
|
||||
|
||||
class InnerAgent:
|
||||
endpoint = "inner"
|
||||
user = "user1"
|
||||
user_api_key = ""
|
||||
query = ""
|
||||
|
||||
@log_activity()
|
||||
def inner_gen(agent, log_context=None):
|
||||
yield "i"
|
||||
|
||||
@log_activity()
|
||||
def outer_gen(agent, log_context=None):
|
||||
yield from inner_gen(InnerAgent())
|
||||
|
||||
with patch("application.logging._log_activity_to_db"), \
|
||||
caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(outer_gen(FakeAgent()))
|
||||
|
||||
starts = [r for r in caplog.records if r.message == "activity_started"]
|
||||
assert len(starts) == 2
|
||||
outer_start, inner_start = starts
|
||||
assert outer_start.endpoint == "outer"
|
||||
assert outer_start.parent_activity_id is None
|
||||
assert inner_start.endpoint == "inner"
|
||||
assert inner_start.parent_activity_id == outer_start.activity_id
|
||||
|
||||
def test_log_activity_records_error_status_on_failure(self, caplog):
|
||||
import logging as _logging
|
||||
|
||||
from application.logging import log_activity
|
||||
|
||||
class FakeAgent:
|
||||
endpoint = "boom"
|
||||
user = "user1"
|
||||
user_api_key = ""
|
||||
query = ""
|
||||
|
||||
@log_activity()
|
||||
def failing(agent, log_context=None):
|
||||
yield "before"
|
||||
raise ValueError("bad thing")
|
||||
|
||||
with patch("application.logging._log_activity_to_db"), \
|
||||
caplog.at_level(_logging.INFO, logger="root"), \
|
||||
pytest.raises(ValueError):
|
||||
list(failing(FakeAgent()))
|
||||
|
||||
finished = next(r for r in caplog.records if r.message == "activity_finished")
|
||||
assert finished.status == "error"
|
||||
assert finished.error_class == "ValueError"
|
||||
|
||||
def test_log_activity_emits_response_summary_aggregates(self, caplog):
|
||||
# Replaces the ``agent_response`` event that ``run_agent_logic``
|
||||
# used to emit only on the Celery webhook path: every Flask
|
||||
# activity now gets the same aggregates on ``activity_finished``.
|
||||
import logging as _logging
|
||||
|
||||
from application.logging import log_activity
|
||||
|
||||
class FakeAgent:
|
||||
endpoint = "stream"
|
||||
user = "user1"
|
||||
user_api_key = ""
|
||||
query = "q"
|
||||
|
||||
@log_activity()
|
||||
def streaming(agent, log_context=None):
|
||||
yield {"answer": "Hello "}
|
||||
yield {"answer": "world"}
|
||||
yield {"thought": "thinking..."}
|
||||
yield {"sources": [{"id": "a"}, {"id": "b"}, {"id": "c"}]}
|
||||
yield {"tool_calls": [{"name": "search"}, {"name": "fetch"}]}
|
||||
yield "ignored-non-dict"
|
||||
yield {"unrecognised": "noop"}
|
||||
|
||||
with patch("application.logging._log_activity_to_db"), \
|
||||
caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(streaming(FakeAgent()))
|
||||
|
||||
finished = next(r for r in caplog.records if r.message == "activity_finished")
|
||||
assert finished.answer_length == len("Hello world")
|
||||
assert finished.thought_length == len("thinking...")
|
||||
assert finished.source_count == 3
|
||||
assert finished.tool_call_count == 2
|
||||
|
||||
def test_log_activity_aggregates_initialise_to_zero(self, caplog):
|
||||
# No yields → summary fields still present and zero (so Axiom
|
||||
# schemas don't get a missing-field hole on empty activities).
|
||||
import logging as _logging
|
||||
|
||||
from application.logging import log_activity
|
||||
|
||||
class FakeAgent:
|
||||
endpoint = "stream"
|
||||
user = "user1"
|
||||
user_api_key = ""
|
||||
query = ""
|
||||
|
||||
@log_activity()
|
||||
def empty(agent, log_context=None):
|
||||
return
|
||||
yield # pragma: no cover — generator marker
|
||||
|
||||
with patch("application.logging._log_activity_to_db"), \
|
||||
caplog.at_level(_logging.INFO, logger="root"):
|
||||
list(empty(FakeAgent()))
|
||||
|
||||
finished = next(r for r in caplog.records if r.message == "activity_finished")
|
||||
assert finished.answer_length == 0
|
||||
assert finished.thought_length == 0
|
||||
assert finished.source_count == 0
|
||||
assert finished.tool_call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAccumulateResponseSummary:
|
||||
"""Direct coverage of the dispatch table — easier to enumerate edge
|
||||
cases here than in end-to-end ``log_activity`` tests."""
|
||||
|
||||
def _ctx(self):
|
||||
from application.logging import LogContext
|
||||
|
||||
return LogContext(
|
||||
endpoint="e", activity_id="a", user="u", api_key="k", query="q"
|
||||
)
|
||||
|
||||
def test_answer_appends_length(self):
|
||||
from application.logging import _accumulate_response_summary
|
||||
|
||||
ctx = self._ctx()
|
||||
_accumulate_response_summary({"answer": "abcd"}, ctx)
|
||||
_accumulate_response_summary({"answer": "ef"}, ctx)
|
||||
assert ctx.answer_length == 6
|
||||
assert ctx.thought_length == 0
|
||||
|
||||
def test_non_dict_items_are_ignored(self):
|
||||
from application.logging import _accumulate_response_summary
|
||||
|
||||
ctx = self._ctx()
|
||||
for item in ("string", 123, None, ["list"], object()):
|
||||
_accumulate_response_summary(item, ctx)
|
||||
assert ctx.answer_length == 0
|
||||
assert ctx.source_count == 0
|
||||
|
||||
def test_sources_must_be_list(self):
|
||||
# A malformed payload (sources=str) shouldn't crash the
|
||||
# accumulator — drop it silently rather than half-count it.
|
||||
from application.logging import _accumulate_response_summary
|
||||
|
||||
ctx = self._ctx()
|
||||
_accumulate_response_summary({"sources": "not-a-list"}, ctx)
|
||||
assert ctx.source_count == 0
|
||||
|
||||
def test_tool_calls_counted(self):
|
||||
from application.logging import _accumulate_response_summary
|
||||
|
||||
ctx = self._ctx()
|
||||
_accumulate_response_summary(
|
||||
{"tool_calls": [{"name": "a"}, {"name": "b"}]}, ctx
|
||||
)
|
||||
assert ctx.tool_call_count == 2
|
||||
|
||||
|
||||
|
||||
@@ -361,16 +361,6 @@ class TestSerializeForTokenCount:
|
||||
def test_none_returns_empty(self):
|
||||
assert _serialize_for_token_count(None) == ""
|
||||
|
||||
def test_bytes_returns_empty(self):
|
||||
# Regression: image/file attachments arrive as ``bytes`` from the
|
||||
# provider-specific message preparation. Without an explicit
|
||||
# branch they fell through to ``str(value)`` and inflated
|
||||
# ``prompt_tokens`` by millions per call.
|
||||
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 4096
|
||||
assert _serialize_for_token_count(png_header) == ""
|
||||
assert _serialize_for_token_count(bytearray(png_header)) == ""
|
||||
assert _serialize_for_token_count(memoryview(png_header)) == ""
|
||||
|
||||
def test_list_recursion(self):
|
||||
result = _serialize_for_token_count(["hello", "world"])
|
||||
assert result == ["hello", "world"]
|
||||
@@ -448,11 +438,6 @@ class TestCountTokens:
|
||||
data_url = "data:image/png;base64,iVBORw0KGgoAAAA..."
|
||||
assert _count_tokens(data_url) == 0
|
||||
|
||||
def test_bytes_returns_zero(self):
|
||||
# Regression: a multi-megabyte ``bytes`` payload (image attachment)
|
||||
# used to be repr-stringified and counted as millions of tokens.
|
||||
assert _count_tokens(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100000) == 0
|
||||
|
||||
def test_dict_counts(self):
|
||||
assert _count_tokens({"key": "some text here"}) > 0
|
||||
|
||||
@@ -518,26 +503,6 @@ class TestCountPromptTokens:
|
||||
)
|
||||
assert tokens_with > tokens_without
|
||||
|
||||
def test_bytes_in_message_content_does_not_inflate_count(self):
|
||||
# Production regression: a single image attachment landed as bytes
|
||||
# inside ``content`` and the prior repr-fallback pushed
|
||||
# ``prompt_tokens`` past 2,000,000 on Axiom. Verify the bytes
|
||||
# branch keeps the count bounded by the surrounding text.
|
||||
text_only = [{"content": "Summarize this image."}]
|
||||
with_bytes = [
|
||||
{
|
||||
"content": [
|
||||
{"type": "text", "text": "Summarize this image."},
|
||||
{"type": "image", "data": b"\x89PNG\r\n" + b"\x00" * 200_000},
|
||||
]
|
||||
}
|
||||
]
|
||||
baseline = _count_prompt_tokens(text_only, tools=None)
|
||||
with_attachment = _count_prompt_tokens(with_bytes, tools=None)
|
||||
# 200KB of zero bytes used to register as ~200K tokens; cap the
|
||||
# acceptable inflation at a small constant for tool-format overhead.
|
||||
assert with_attachment - baseline < 50
|
||||
|
||||
def test_message_with_tool_calls_field(self):
|
||||
messages = [
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user