diff --git a/src/agents/pi-embedded-helpers/errors.ts b/src/agents/pi-embedded-helpers/errors.ts index 2979cb67a51..ad079ab9115 100644 --- a/src/agents/pi-embedded-helpers/errors.ts +++ b/src/agents/pi-embedded-helpers/errors.ts @@ -14,6 +14,18 @@ export function formatBillingErrorMessage(provider?: string): string { export const BILLING_ERROR_USER_MESSAGE = formatBillingErrorMessage(); const RATE_LIMIT_ERROR_USER_MESSAGE = "⚠️ API rate limit reached. Please try again later."; +const OVERLOADED_ERROR_USER_MESSAGE = + "The AI service is temporarily overloaded. Please try again in a moment."; + +function formatRateLimitOrOverloadedErrorCopy(raw: string): string | undefined { + if (isRateLimitErrorMessage(raw)) { + return RATE_LIMIT_ERROR_USER_MESSAGE; + } + if (isOverloadedErrorMessage(raw)) { + return OVERLOADED_ERROR_USER_MESSAGE; + } + return undefined; +} export function isContextOverflowError(errorMessage?: string): boolean { if (!errorMessage) { @@ -463,12 +475,9 @@ export function formatAssistantErrorText( return `LLM request rejected: ${invalidRequest[1]}`; } - if (isRateLimitErrorMessage(raw)) { - return RATE_LIMIT_ERROR_USER_MESSAGE; - } - - if (isOverloadedErrorMessage(raw)) { - return "The AI service is temporarily overloaded. Please try again in a moment."; + const transientCopy = formatRateLimitOrOverloadedErrorCopy(raw); + if (transientCopy) { + return transientCopy; } if (isBillingErrorMessage(raw)) { @@ -523,11 +532,9 @@ export function sanitizeUserFacingText(text: string, opts?: { errorContext?: boo } if (ERROR_PREFIX_RE.test(trimmed)) { - if (isRateLimitErrorMessage(trimmed)) { - return RATE_LIMIT_ERROR_USER_MESSAGE; - } - if (isOverloadedErrorMessage(trimmed)) { - return "The AI service is temporarily overloaded. Please try again in a moment."; + const prefixedCopy = formatRateLimitOrOverloadedErrorCopy(trimmed); + if (prefixedCopy) { + return prefixedCopy; } if (isTimeoutErrorMessage(trimmed)) { return "LLM request timed out."; diff --git a/src/agents/pi-embedded-subscribe.handlers.messages.ts b/src/agents/pi-embedded-subscribe.handlers.messages.ts index eb18187193a..f4e738c6663 100644 --- a/src/agents/pi-embedded-subscribe.handlers.messages.ts +++ b/src/agents/pi-embedded-subscribe.handlers.messages.ts @@ -57,7 +57,7 @@ export function handleMessageUpdate( return; } - ctx.state.lastAssistant = msg; + ctx.noteLastAssistant(msg); const assistantEvent = evt.assistantMessageEvent; const assistantRecord = @@ -200,7 +200,7 @@ export function handleMessageEnd( } const assistantMessage = msg; - ctx.state.lastAssistant = assistantMessage; + ctx.noteLastAssistant(assistantMessage); ctx.recordAssistantUsage((assistantMessage as { usage?: unknown }).usage); promoteThinkingTagsToBlocks(assistantMessage); diff --git a/src/agents/pi-embedded-subscribe.handlers.types.ts b/src/agents/pi-embedded-subscribe.handlers.types.ts index 2a626ad86a0..fcb6a3e75e5 100644 --- a/src/agents/pi-embedded-subscribe.handlers.types.ts +++ b/src/agents/pi-embedded-subscribe.handlers.types.ts @@ -72,6 +72,7 @@ export type EmbeddedPiSubscribeContext = { blockChunking?: BlockReplyChunking; blockChunker: EmbeddedBlockChunker | null; hookRunner?: HookRunner; + noteLastAssistant: (msg: AgentMessage) => void; shouldEmitToolResult: () => boolean; shouldEmitToolOutput: () => boolean; diff --git a/src/agents/pi-embedded-subscribe.ts b/src/agents/pi-embedded-subscribe.ts index 102d0811ab1..146066d21c7 100644 --- a/src/agents/pi-embedded-subscribe.ts +++ b/src/agents/pi-embedded-subscribe.ts @@ -1,3 +1,4 @@ +import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { InlineCodeState } from "../markdown/code-spans.js"; import type { EmbeddedPiSubscribeContext, @@ -569,6 +570,12 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar resetAssistantMessageState(0); }; + const noteLastAssistant = (msg: AgentMessage) => { + if (msg?.role === "assistant") { + state.lastAssistant = msg; + } + }; + const ctx: EmbeddedPiSubscribeContext = { params, state, @@ -576,6 +583,7 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar blockChunking, blockChunker, hookRunner: params.hookRunner, + noteLastAssistant, shouldEmitToolResult, shouldEmitToolOutput, emitToolSummary,