From 2d033d2aa8420af3f60e9931e9be1f87e1b5c318 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Tue, 3 Mar 2026 02:42:29 +0000 Subject: [PATCH] refactor(agents): split tool-result char estimator --- .../tool-result-char-estimator.ts | 169 ++++++++++++++++++ .../tool-result-context-guard.ts | 166 ++--------------- 2 files changed, 184 insertions(+), 151 deletions(-) create mode 100644 src/agents/pi-embedded-runner/tool-result-char-estimator.ts diff --git a/src/agents/pi-embedded-runner/tool-result-char-estimator.ts b/src/agents/pi-embedded-runner/tool-result-char-estimator.ts new file mode 100644 index 00000000000..16bdc5e43eb --- /dev/null +++ b/src/agents/pi-embedded-runner/tool-result-char-estimator.ts @@ -0,0 +1,169 @@ +import type { AgentMessage } from "@mariozechner/pi-agent-core"; + +export const CHARS_PER_TOKEN_ESTIMATE = 4; +export const TOOL_RESULT_CHARS_PER_TOKEN_ESTIMATE = 2; +const IMAGE_CHAR_ESTIMATE = 8_000; + +export type MessageCharEstimateCache = WeakMap; + +function isTextBlock(block: unknown): block is { type: "text"; text: string } { + return !!block && typeof block === "object" && (block as { type?: unknown }).type === "text"; +} + +function isImageBlock(block: unknown): boolean { + return !!block && typeof block === "object" && (block as { type?: unknown }).type === "image"; +} + +function estimateUnknownChars(value: unknown): number { + if (typeof value === "string") { + return value.length; + } + if (value === undefined) { + return 0; + } + try { + const serialized = JSON.stringify(value); + return typeof serialized === "string" ? serialized.length : 0; + } catch { + return 256; + } +} + +export function isToolResultMessage(msg: AgentMessage): boolean { + const role = (msg as { role?: unknown }).role; + const type = (msg as { type?: unknown }).type; + return role === "toolResult" || role === "tool" || type === "toolResult"; +} + +function getToolResultContent(msg: AgentMessage): unknown[] { + if (!isToolResultMessage(msg)) { + return []; + } + const content = (msg as { content?: unknown }).content; + if (typeof content === "string") { + return [{ type: "text", text: content }]; + } + return Array.isArray(content) ? content : []; +} + +export function getToolResultText(msg: AgentMessage): string { + const content = getToolResultContent(msg); + const chunks: string[] = []; + for (const block of content) { + if (isTextBlock(block)) { + chunks.push(block.text); + } + } + return chunks.join("\n"); +} + +function estimateMessageChars(msg: AgentMessage): number { + if (!msg || typeof msg !== "object") { + return 0; + } + + if (msg.role === "user") { + const content = msg.content; + if (typeof content === "string") { + return content.length; + } + let chars = 0; + if (Array.isArray(content)) { + for (const block of content) { + if (isTextBlock(block)) { + chars += block.text.length; + } else if (isImageBlock(block)) { + chars += IMAGE_CHAR_ESTIMATE; + } else { + chars += estimateUnknownChars(block); + } + } + } + return chars; + } + + if (msg.role === "assistant") { + let chars = 0; + const content = (msg as { content?: unknown }).content; + if (Array.isArray(content)) { + for (const block of content) { + if (!block || typeof block !== "object") { + continue; + } + const typed = block as { + type?: unknown; + text?: unknown; + thinking?: unknown; + arguments?: unknown; + }; + if (typed.type === "text" && typeof typed.text === "string") { + chars += typed.text.length; + } else if (typed.type === "thinking" && typeof typed.thinking === "string") { + chars += typed.thinking.length; + } else if (typed.type === "toolCall") { + try { + chars += JSON.stringify(typed.arguments ?? {}).length; + } catch { + chars += 128; + } + } else { + chars += estimateUnknownChars(block); + } + } + } + return chars; + } + + if (isToolResultMessage(msg)) { + let chars = 0; + const content = getToolResultContent(msg); + for (const block of content) { + if (isTextBlock(block)) { + chars += block.text.length; + } else if (isImageBlock(block)) { + chars += IMAGE_CHAR_ESTIMATE; + } else { + chars += estimateUnknownChars(block); + } + } + const details = (msg as { details?: unknown }).details; + chars += estimateUnknownChars(details); + const weightedChars = Math.ceil( + chars * (CHARS_PER_TOKEN_ESTIMATE / TOOL_RESULT_CHARS_PER_TOKEN_ESTIMATE), + ); + return Math.max(chars, weightedChars); + } + + return 256; +} + +export function createMessageCharEstimateCache(): MessageCharEstimateCache { + return new WeakMap(); +} + +export function estimateMessageCharsCached( + msg: AgentMessage, + cache: MessageCharEstimateCache, +): number { + const hit = cache.get(msg); + if (hit !== undefined) { + return hit; + } + const estimated = estimateMessageChars(msg); + cache.set(msg, estimated); + return estimated; +} + +export function estimateContextChars( + messages: AgentMessage[], + cache: MessageCharEstimateCache, +): number { + return messages.reduce((sum, msg) => sum + estimateMessageCharsCached(msg, cache), 0); +} + +export function invalidateMessageCharsCacheEntry( + cache: MessageCharEstimateCache, + msg: AgentMessage, +): void { + cache.delete(msg); +} diff --git a/src/agents/pi-embedded-runner/tool-result-context-guard.ts b/src/agents/pi-embedded-runner/tool-result-context-guard.ts index b1c02f0f87b..4a3d3482421 100644 --- a/src/agents/pi-embedded-runner/tool-result-context-guard.ts +++ b/src/agents/pi-embedded-runner/tool-result-context-guard.ts @@ -1,11 +1,19 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import { + CHARS_PER_TOKEN_ESTIMATE, + TOOL_RESULT_CHARS_PER_TOKEN_ESTIMATE, + type MessageCharEstimateCache, + createMessageCharEstimateCache, + estimateContextChars, + estimateMessageCharsCached, + getToolResultText, + invalidateMessageCharsCacheEntry, + isToolResultMessage, +} from "./tool-result-char-estimator.js"; -const CHARS_PER_TOKEN_ESTIMATE = 4; // Keep a conservative input budget to absorb tokenizer variance and provider framing overhead. const CONTEXT_INPUT_HEADROOM_RATIO = 0.75; const SINGLE_TOOL_RESULT_CONTEXT_SHARE = 0.5; -const TOOL_RESULT_CHARS_PER_TOKEN_ESTIMATE = 2; -const IMAGE_CHAR_ESTIMATE = 8_000; export const CONTEXT_LIMIT_TRUNCATION_NOTICE = "[truncated: output exceeded context limit]"; const CONTEXT_LIMIT_TRUNCATION_SUFFIX = `\n${CONTEXT_LIMIT_TRUNCATION_NOTICE}`; @@ -23,152 +31,6 @@ type GuardableAgent = object; type GuardableAgentRecord = { transformContext?: GuardableTransformContext; }; -type MessageCharEstimateCache = WeakMap; - -function isTextBlock(block: unknown): block is { type: "text"; text: string } { - return !!block && typeof block === "object" && (block as { type?: unknown }).type === "text"; -} - -function isImageBlock(block: unknown): boolean { - return !!block && typeof block === "object" && (block as { type?: unknown }).type === "image"; -} - -function estimateUnknownChars(value: unknown): number { - if (typeof value === "string") { - return value.length; - } - if (value === undefined) { - return 0; - } - try { - const serialized = JSON.stringify(value); - return typeof serialized === "string" ? serialized.length : 0; - } catch { - return 256; - } -} - -function isToolResultMessage(msg: AgentMessage): boolean { - const role = (msg as { role?: unknown }).role; - const type = (msg as { type?: unknown }).type; - return role === "toolResult" || role === "tool" || type === "toolResult"; -} - -function getToolResultContent(msg: AgentMessage): unknown[] { - if (!isToolResultMessage(msg)) { - return []; - } - const content = (msg as { content?: unknown }).content; - if (typeof content === "string") { - return [{ type: "text", text: content }]; - } - return Array.isArray(content) ? content : []; -} - -function getToolResultText(msg: AgentMessage): string { - const content = getToolResultContent(msg); - const chunks: string[] = []; - for (const block of content) { - if (isTextBlock(block)) { - chunks.push(block.text); - } - } - return chunks.join("\n"); -} - -function estimateMessageChars(msg: AgentMessage): number { - if (!msg || typeof msg !== "object") { - return 0; - } - - if (msg.role === "user") { - const content = msg.content; - if (typeof content === "string") { - return content.length; - } - let chars = 0; - if (Array.isArray(content)) { - for (const block of content) { - if (isTextBlock(block)) { - chars += block.text.length; - } else if (isImageBlock(block)) { - chars += IMAGE_CHAR_ESTIMATE; - } else { - chars += estimateUnknownChars(block); - } - } - } - return chars; - } - - if (msg.role === "assistant") { - let chars = 0; - const content = (msg as { content?: unknown }).content; - if (Array.isArray(content)) { - for (const block of content) { - if (!block || typeof block !== "object") { - continue; - } - const typed = block as { - type?: unknown; - text?: unknown; - thinking?: unknown; - arguments?: unknown; - }; - if (typed.type === "text" && typeof typed.text === "string") { - chars += typed.text.length; - } else if (typed.type === "thinking" && typeof typed.thinking === "string") { - chars += typed.thinking.length; - } else if (typed.type === "toolCall") { - try { - chars += JSON.stringify(typed.arguments ?? {}).length; - } catch { - chars += 128; - } - } else { - chars += estimateUnknownChars(block); - } - } - } - return chars; - } - - if (isToolResultMessage(msg)) { - let chars = 0; - const content = getToolResultContent(msg); - for (const block of content) { - if (isTextBlock(block)) { - chars += block.text.length; - } else if (isImageBlock(block)) { - chars += IMAGE_CHAR_ESTIMATE; - } else { - chars += estimateUnknownChars(block); - } - } - const details = (msg as { details?: unknown }).details; - chars += estimateUnknownChars(details); - const weightedChars = Math.ceil( - chars * (CHARS_PER_TOKEN_ESTIMATE / TOOL_RESULT_CHARS_PER_TOKEN_ESTIMATE), - ); - return Math.max(chars, weightedChars); - } - - return 256; -} - -function estimateMessageCharsCached(msg: AgentMessage, cache: MessageCharEstimateCache): number { - const hit = cache.get(msg); - if (hit !== undefined) { - return hit; - } - const estimated = estimateMessageChars(msg); - cache.set(msg, estimated); - return estimated; -} - -function estimateContextChars(messages: AgentMessage[], cache: MessageCharEstimateCache): number { - return messages.reduce((sum, msg) => sum + estimateMessageCharsCached(msg, cache), 0); -} function truncateTextToBudget(text: string, maxChars: number): string { if (text.length <= maxChars) { @@ -284,7 +146,9 @@ function applyMessageMutationInPlace( } } Object.assign(targetRecord, sourceRecord); - cache?.delete(target); + if (cache) { + invalidateMessageCharsCacheEntry(cache, target); + } } function enforceToolResultContextBudgetInPlace(params: { @@ -293,7 +157,7 @@ function enforceToolResultContextBudgetInPlace(params: { maxSingleToolResultChars: number; }): void { const { messages, contextBudgetChars, maxSingleToolResultChars } = params; - const estimateCache: MessageCharEstimateCache = new WeakMap(); + const estimateCache = createMessageCharEstimateCache(); // Ensure each tool result has an upper bound before considering total context usage. for (const message of messages) {