refactor: share html entity tool call decoding

This commit is contained in:
Peter Steinberger
2026-04-08 15:33:53 +01:00
parent b358db1775
commit 17bd5f1dd2
4 changed files with 212 additions and 196 deletions

View File

@@ -3,6 +3,7 @@ import { streamSimple } from "@mariozechner/pi-ai";
import type { ProviderWrapStreamFnContext } from "@openclaw/plugin-sdk/plugin-entry";
import {
composeProviderStreamWrappers,
createHtmlEntityToolCallArgumentDecodingWrapper,
createToolStreamWrapper,
} from "@openclaw/plugin-sdk/provider-stream-shared";
@@ -200,105 +201,8 @@ export function createXaiFastModeWrapper(
};
}
/**
 * Replacement text for each HTML entity spelling handled by `decodeHtmlEntities`.
 *
 * NOTE(review): the four quote/apostrophe spellings were reconstructed from the
 * named/numeric pairing used by the `<`, `>` and `&` entries — confirm against
 * the upstream file.
 */
const XAI_HTML_ENTITY_REPLACEMENTS: Readonly<Record<string, string>> = {
  "&quot;": '"',
  "&#34;": '"',
  "&apos;": "'",
  "&#39;": "'",
  "&lt;": "<",
  "&#60;": "<",
  "&gt;": ">",
  "&#62;": ">",
  "&amp;": "&",
  "&#38;": "&",
};

/** Matches exactly the entity spellings listed in `XAI_HTML_ENTITY_REPLACEMENTS` (case-sensitive). */
const XAI_HTML_ENTITY_PATTERN = /&(?:quot|apos|lt|gt|amp|#34|#39|#60|#62|#38);/g;

/**
 * Decodes the quote, apostrophe, angle-bracket and ampersand entities in `value`.
 *
 * The input is scanned once, so text produced by decoding an ampersand entity
 * is never re-interpreted as the start of another entity
 * (`&amp;lt;` becomes `&lt;`, not `<`).
 *
 * @param value - Text that may contain HTML entities.
 * @returns The text with every recognised entity replaced by its character.
 */
function decodeHtmlEntities(value: string): string {
  return value.replace(
    XAI_HTML_ENTITY_PATTERN,
    (entity) => XAI_HTML_ENTITY_REPLACEMENTS[entity] ?? entity,
  );
}
/**
 * Recursively decodes HTML entities in every string reachable from `value`.
 *
 * Strings are returned decoded, arrays are mapped into a new array, and plain
 * objects are updated in place (the same object reference is returned).
 * Any other value is returned untouched.
 */
function decodeHtmlEntitiesInObject(value: unknown): unknown {
  if (typeof value === "string") {
    return decodeHtmlEntities(value);
  }
  if (value === null || typeof value !== "object") {
    return value;
  }
  if (Array.isArray(value)) {
    return value.map((item) => decodeHtmlEntitiesInObject(item));
  }
  const target = value as Record<string, unknown>;
  Object.entries(target).forEach(([property, current]) => {
    // Overwrite the property on the original object rather than copying it.
    target[property] = decodeHtmlEntitiesInObject(current);
  });
  return target;
}
/**
 * Decodes HTML entities inside the arguments of every tool call block of `message`.
 *
 * Nothing happens unless `message` is an object whose `content` is an array.
 * Only blocks with `type === "toolCall"` and a non-null object `arguments`
 * are touched; their `arguments` property is reassigned to the decoded value.
 */
function decodeXaiToolCallArgumentsInMessage(message: unknown): void {
  const blocks =
    message && typeof message === "object"
      ? (message as { content?: unknown }).content
      : undefined;
  if (!Array.isArray(blocks)) {
    return;
  }
  for (const candidate of blocks) {
    if (!candidate || typeof candidate !== "object") {
      continue;
    }
    const toolCall = candidate as { type?: unknown; arguments?: unknown };
    const isToolCallWithArguments = toolCall.type === "toolCall" && Boolean(toolCall.arguments);
    if (isToolCallWithArguments && typeof toolCall.arguments === "object") {
      toolCall.arguments = decodeHtmlEntitiesInObject(toolCall.arguments);
    }
  }
}
/**
 * Patches `stream` in place so tool call arguments are entity-decoded both in
 * the final message returned by `result()` and in the `partial` / `message`
 * payloads of every event produced by async iteration.
 *
 * @returns The same `stream` object that was passed in.
 */
function wrapStreamDecodeXaiToolCallArguments(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  const readFinalMessage = stream.result.bind(stream);
  const openIterator = stream[Symbol.asyncIterator].bind(stream);

  stream.result = async () => {
    const finalMessage = await readFinalMessage();
    decodeXaiToolCallArgumentsInMessage(finalMessage);
    return finalMessage;
  };

  (stream as { [Symbol.asyncIterator]: typeof openIterator })[Symbol.asyncIterator] = () => {
    const source = openIterator();
    return {
      async next() {
        const step = await source.next();
        if (step.done || !step.value || typeof step.value !== "object") {
          return step;
        }
        const event = step.value as { partial?: unknown; message?: unknown };
        decodeXaiToolCallArgumentsInMessage(event.partial);
        decodeXaiToolCallArgumentsInMessage(event.message);
        return step;
      },
      // Forward early termination when the underlying iterator supports it;
      // otherwise report completion.
      async return(value?: unknown) {
        return source.return?.(value) ?? { done: true as const, value: undefined };
      },
      async throw(error?: unknown) {
        return source.throw?.(error) ?? { done: true as const, value: undefined };
      },
    };
  };

  return stream;
}
/**
 * Builds a stream function that calls `baseStreamFn` and wraps the stream it
 * produces with `wrapStreamDecodeXaiToolCallArguments`.
 *
 * Works whether `baseStreamFn` returns the stream directly or a thenable that
 * resolves to it.
 */
export function createXaiToolCallArgumentDecodingWrapper(baseStreamFn: StreamFn): StreamFn {
  return (model, context, options) => {
    const produced = baseStreamFn(model, context, options);
    if (!produced || typeof produced !== "object" || !("then" in produced)) {
      // Synchronous result: wrap it right away.
      return wrapStreamDecodeXaiToolCallArguments(produced);
    }
    return Promise.resolve(produced).then((resolved) =>
      wrapStreamDecodeXaiToolCallArguments(resolved),
    );
  };
}
/**
 * xAI-named alias for `createHtmlEntityToolCallArgumentDecodingWrapper`, which
 * this file imports from "@openclaw/plugin-sdk/provider-stream-shared".
 * Both names refer to the same function.
 */
export const createXaiToolCallArgumentDecodingWrapper =
createHtmlEntityToolCallArgumentDecodingWrapper;
export function wrapXaiProviderStream(ctx: ProviderWrapStreamFnContext): StreamFn | undefined {
const extraParams = ctx.extraParams;

View File

@@ -1,5 +1,9 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import {
createHtmlEntityToolCallArgumentDecodingWrapper,
decodeHtmlEntitiesInObject,
} from "../../../plugin-sdk/provider-stream-shared.js";
import { normalizeProviderId } from "../../model-selection.js";
import { log } from "../logger.js";
@@ -375,102 +379,8 @@ export function shouldRepairMalformedAnthropicToolCallArguments(provider?: strin
return normalizeProviderId(provider ?? "") === "kimi";
}
/** Stateless pre-check: does the text contain at least one entity we decode? */
const HTML_ENTITY_RE = /&(?:amp|lt|gt|quot|apos|#39|#x[0-9a-f]+|#\d+);/i;

/**
 * Global matcher used for the actual decode.
 * Capture 1 = named entity, capture 2 = hexadecimal digits, capture 3 = decimal digits.
 */
const HTML_ENTITY_DECODE_RE = /&(?:(amp|lt|gt|quot|apos)|#x([0-9a-f]+)|#(\d+));/gi;

/** Characters for the named entities accepted by `HTML_ENTITY_DECODE_RE` (lower-case keys). */
const NAMED_HTML_ENTITIES: Readonly<Record<string, string>> = {
  amp: "&",
  lt: "<",
  gt: ">",
  quot: '"',
  apos: "'",
};

/** Largest value accepted by `String.fromCodePoint`. */
const MAX_UNICODE_CODE_POINT = 0x10ffff;

/**
 * Decodes named (`&amp;`, `&lt;`, `&gt;`, `&quot;`, `&apos;`) and numeric
 * (`&#39;`, `&#x27;`) HTML entities in `value`.
 *
 * The text is scanned exactly once, so output of one replacement is never
 * re-read as a new entity (`&amp;lt;` yields `&lt;`). Numeric references
 * outside the Unicode range are left as written instead of throwing.
 */
function decodeHtmlEntities(value: string): string {
  return value.replace(
    HTML_ENTITY_DECODE_RE,
    (entity: string, named?: string, hex?: string, dec?: string) => {
      if (named !== undefined) {
        return NAMED_HTML_ENTITIES[named.toLowerCase()] ?? entity;
      }
      const codePoint =
        hex !== undefined ? Number.parseInt(hex, 16) : Number.parseInt(dec ?? "", 10);
      // String.fromCodePoint throws a RangeError above U+10FFFF; keep the raw text instead.
      return Number.isInteger(codePoint) && codePoint <= MAX_UNICODE_CODE_POINT
        ? String.fromCodePoint(codePoint)
        : entity;
    },
  );
}

/**
 * Returns a copy of `obj` in which every string (at any depth of nested
 * arrays and objects) has its HTML entities decoded.
 *
 * Arrays and objects are rebuilt rather than mutated. Object keys are copied
 * as own data properties, so a key named `__proto__` stays an ordinary key
 * and never changes the prototype of the result. Values that are neither
 * strings, arrays nor objects are returned unchanged.
 */
export function decodeHtmlEntitiesInObject(obj: unknown): unknown {
  if (typeof obj === "string") {
    return HTML_ENTITY_RE.test(obj) ? decodeHtmlEntities(obj) : obj;
  }
  if (Array.isArray(obj)) {
    return obj.map((item) => decodeHtmlEntitiesInObject(item));
  }
  if (obj && typeof obj === "object") {
    return Object.fromEntries(
      Object.entries(obj as Record<string, unknown>).map(
        ([key, val]): [string, unknown] => [key, decodeHtmlEntitiesInObject(val)],
      ),
    );
  }
  return obj;
}
/**
 * Entity-decodes the `arguments` of each tool call block found in
 * `message.content`.
 *
 * Messages that are not objects, or whose `content` is not an array, are
 * ignored. A block qualifies when its `type` is `"toolCall"` and its
 * `arguments` is a non-null object; the block's `arguments` is then replaced
 * with the value returned by `decodeHtmlEntitiesInObject`.
 */
function decodeXaiToolCallArgumentsInMessage(message: unknown): void {
  if (message === null || message === undefined || typeof message !== "object") {
    return;
  }
  const { content } = message as { content?: unknown };
  if (!Array.isArray(content)) {
    return;
  }
  content.forEach((entry: unknown) => {
    if (!entry || typeof entry !== "object") {
      return;
    }
    const toolCall = entry as { type?: unknown; arguments?: unknown };
    if (toolCall.type !== "toolCall") {
      return;
    }
    if (toolCall.arguments && typeof toolCall.arguments === "object") {
      toolCall.arguments = decodeHtmlEntitiesInObject(toolCall.arguments);
    }
  });
}
/**
 * Replaces `stream.result` and `stream[Symbol.asyncIterator]` with versions
 * that run `decodeXaiToolCallArgumentsInMessage` on the final message and on
 * the `partial` / `message` fields of each iterated event.
 *
 * The stream is modified in place and returned.
 */
function wrapStreamDecodeXaiToolCallArguments(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  const innerResult = stream.result.bind(stream);
  stream.result = async () => {
    const resolved = await innerResult();
    decodeXaiToolCallArgumentsInMessage(resolved);
    return resolved;
  };

  const innerIteratorFactory = stream[Symbol.asyncIterator].bind(stream);
  const decodingIteratorFactory: typeof innerIteratorFactory = () => {
    const inner = innerIteratorFactory();
    return {
      async next() {
        const item = await inner.next();
        const payload = item.done ? undefined : item.value;
        if (payload && typeof payload === "object") {
          const event = payload as { partial?: unknown; message?: unknown };
          decodeXaiToolCallArgumentsInMessage(event.partial);
          decodeXaiToolCallArgumentsInMessage(event.message);
        }
        return item;
      },
      // Delegate when the inner iterator implements the optional method,
      // otherwise answer with a completed result.
      async return(value?: unknown) {
        return inner.return?.(value) ?? { done: true as const, value: undefined };
      },
      async throw(error?: unknown) {
        return inner.throw?.(error) ?? { done: true as const, value: undefined };
      },
    };
  };
  (stream as { [Symbol.asyncIterator]: typeof innerIteratorFactory })[Symbol.asyncIterator] =
    decodingIteratorFactory;

  return stream;
}
/**
 * Wraps `baseFn` so that HTML entities in tool call arguments are decoded in
 * the streams it produces.
 *
 * This is a thin, xAI-named entry point: all work is delegated to the shared
 * `createHtmlEntityToolCallArgumentDecodingWrapper` imported at the top of
 * this file, so there is a single implementation of the decoding logic.
 *
 * @param baseFn - The stream function to decorate.
 * @returns A stream function with the same call signature as `baseFn`.
 */
export function wrapStreamFnDecodeXaiToolCallArguments(baseFn: StreamFn): StreamFn {
  return createHtmlEntityToolCallArgumentDecodingWrapper(baseFn);
}
export { decodeHtmlEntitiesInObject };

View File

@@ -0,0 +1,98 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { describe, expect, it } from "vitest";
import {
createHtmlEntityToolCallArgumentDecodingWrapper,
decodeHtmlEntitiesInObject,
} from "./provider-stream-shared.js";
/** Minimal stand-in for a provider stream: a final result plus async iteration. */
type FakeWrappedStream = {
  result: () => Promise<unknown>;
  [Symbol.asyncIterator]: () => AsyncIterator<unknown>;
};

/**
 * Builds a fake stream for tests.
 *
 * `result()` resolves to `params.resultMessage`, and each call to
 * `[Symbol.asyncIterator]()` starts a fresh iterator that yields
 * `params.events` in order.
 */
function createFakeStream(params: {
  events: unknown[];
  resultMessage: unknown;
}): FakeWrappedStream {
  async function* replayEvents(): AsyncGenerator<unknown> {
    for (const queued of params.events) {
      yield queued;
    }
  }

  const fake: FakeWrappedStream = {
    result: async () => params.resultMessage,
    [Symbol.asyncIterator]: () => replayEvents(),
  };
  return fake;
}
// Unit test for the recursive decoder: named and hexadecimal entities inside
// nested strings and arrays must all be decoded.
describe("decodeHtmlEntitiesInObject", () => {
  it("recursively decodes string values", () => {
    const encoded = {
      command: "cd ~/dev &amp;&amp; echo &quot;ok&quot;",
      args: ["&lt;input&gt;", "&#x27;quoted&#x27;"],
    };
    const expected = {
      command: 'cd ~/dev && echo "ok"',
      args: ["<input>", "'quoted'"],
    };

    const actual = decodeHtmlEntitiesInObject(encoded);

    expect(actual).toEqual(expected);
  });
});
// Checks the wrapper end to end against a fake stream: both the final
// message and a streamed `partial` event must come back entity-decoded.
describe("createHtmlEntityToolCallArgumentDecodingWrapper", () => {
  it("decodes tool call arguments in final and streaming messages", async () => {
    // Helper that builds a single tool call content block.
    const toolCall = (args: Record<string, unknown>) => ({ type: "toolCall", arguments: args });

    const encodedFinal = {
      content: [toolCall({ command: "echo &quot;result&quot; &amp;&amp; true" })],
    };
    const encodedEvent = {
      partial: {
        content: [toolCall({ path: "&lt;stream&gt;", nested: { quote: "&#39;x&#39;" } })],
      },
    };

    const baseStreamFn: StreamFn = () =>
      createFakeStream({ events: [encodedEvent], resultMessage: encodedFinal }) as never;
    const wrappedStreamFn = createHtmlEntityToolCallArgumentDecodingWrapper(baseStreamFn);
    const stream = wrappedStreamFn({} as never, {} as never, {}) as FakeWrappedStream;

    await expect(stream.result()).resolves.toEqual({
      content: [toolCall({ command: 'echo "result" && true' })],
    });

    const iterator = stream[Symbol.asyncIterator]();
    await expect(iterator.next()).resolves.toEqual({
      done: false,
      value: {
        partial: {
          content: [toolCall({ path: "<stream>", nested: { quote: "'x'" } })],
        },
      },
    });
  });
});

View File

@@ -1,4 +1,5 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
export type ProviderStreamWrapperFactory =
| ((streamFn: StreamFn | undefined) => StreamFn | undefined)
@@ -16,6 +17,109 @@ export function composeProviderStreamWrappers(
);
}
/** Stateless pre-check: does the text contain at least one entity we decode? */
const HTML_ENTITY_RE = /&(?:amp|lt|gt|quot|apos|#39|#x[0-9a-f]+|#\d+);/i;

/**
 * Global matcher used for the actual decode.
 * Capture 1 = named entity, capture 2 = hexadecimal digits, capture 3 = decimal digits.
 */
const HTML_ENTITY_DECODE_RE = /&(?:(amp|lt|gt|quot|apos)|#x([0-9a-f]+)|#(\d+));/gi;

/** Characters for the named entities accepted by `HTML_ENTITY_DECODE_RE` (lower-case keys). */
const NAMED_HTML_ENTITIES: Readonly<Record<string, string>> = {
  amp: "&",
  lt: "<",
  gt: ">",
  quot: '"',
  apos: "'",
};

/** Largest value accepted by `String.fromCodePoint`. */
const MAX_UNICODE_CODE_POINT = 0x10ffff;

/**
 * Decodes named (`&amp;`, `&lt;`, `&gt;`, `&quot;`, `&apos;`) and numeric
 * (`&#39;`, `&#x27;`) HTML entities in `value`.
 *
 * The text is scanned exactly once, so output of one replacement is never
 * re-read as a new entity (`&amp;lt;` yields `&lt;`). Numeric references
 * outside the Unicode range are left as written instead of throwing.
 */
function decodeHtmlEntities(value: string): string {
  return value.replace(
    HTML_ENTITY_DECODE_RE,
    (entity: string, named?: string, hex?: string, dec?: string) => {
      if (named !== undefined) {
        return NAMED_HTML_ENTITIES[named.toLowerCase()] ?? entity;
      }
      const codePoint =
        hex !== undefined ? Number.parseInt(hex, 16) : Number.parseInt(dec ?? "", 10);
      // String.fromCodePoint throws a RangeError above U+10FFFF; keep the raw text instead.
      return Number.isInteger(codePoint) && codePoint <= MAX_UNICODE_CODE_POINT
        ? String.fromCodePoint(codePoint)
        : entity;
    },
  );
}

/**
 * Returns a copy of `value` in which every string (at any depth of nested
 * arrays and objects) has its HTML entities decoded.
 *
 * Arrays and objects are rebuilt rather than mutated. Object keys are copied
 * as own data properties, so a key named `__proto__` stays an ordinary key
 * and never changes the prototype of the result. Values that are neither
 * strings, arrays nor objects are returned unchanged.
 */
export function decodeHtmlEntitiesInObject(value: unknown): unknown {
  if (typeof value === "string") {
    return HTML_ENTITY_RE.test(value) ? decodeHtmlEntities(value) : value;
  }
  if (Array.isArray(value)) {
    return value.map((entry) => decodeHtmlEntitiesInObject(entry));
  }
  if (value && typeof value === "object") {
    return Object.fromEntries(
      Object.entries(value as Record<string, unknown>).map(
        ([key, entry]): [string, unknown] => [key, decodeHtmlEntitiesInObject(entry)],
      ),
    );
  }
  return value;
}
/**
 * Decodes HTML entities in the `arguments` of every tool call block in
 * `message.content`.
 *
 * Returns without doing anything when `message` is not an object or its
 * `content` is not an array. For each content block whose `type` is
 * `"toolCall"` and whose `arguments` is a non-null object, `arguments` is
 * reassigned to the result of `decodeHtmlEntitiesInObject`.
 */
function decodeToolCallArgumentsHtmlEntitiesInMessage(message: unknown): void {
  const isObject = (candidate: unknown): candidate is object =>
    Boolean(candidate) && typeof candidate === "object";

  if (!isObject(message)) {
    return;
  }
  const blocks = (message as { content?: unknown }).content;
  if (!Array.isArray(blocks)) {
    return;
  }
  for (const rawBlock of blocks) {
    if (!isObject(rawBlock)) {
      continue;
    }
    const toolCall = rawBlock as { type?: unknown; arguments?: unknown };
    if (toolCall.type === "toolCall" && isObject(toolCall.arguments)) {
      toolCall.arguments = decodeHtmlEntitiesInObject(toolCall.arguments);
    }
  }
}
/**
 * Decorates `stream` in place so that tool call arguments are HTML-entity
 * decoded wherever a message surfaces:
 *  - the message resolved by `stream.result()`
 *  - the `partial` and `message` fields of each event yielded by iteration
 *
 * @returns The same stream instance, now patched.
 */
function wrapStreamDecodeToolCallArgumentHtmlEntities(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  const upstreamResult = stream.result.bind(stream);
  const upstreamIterator = stream[Symbol.asyncIterator].bind(stream);

  stream.result = async () => {
    const completed = await upstreamResult();
    decodeToolCallArgumentsHtmlEntitiesInMessage(completed);
    return completed;
  };

  (stream as { [Symbol.asyncIterator]: typeof upstreamIterator })[Symbol.asyncIterator] = () => {
    const cursor = upstreamIterator();
    return {
      async next() {
        const outcome = await cursor.next();
        if (outcome.done || !outcome.value || typeof outcome.value !== "object") {
          return outcome;
        }
        const event = outcome.value as { partial?: unknown; message?: unknown };
        decodeToolCallArgumentsHtmlEntitiesInMessage(event.partial);
        decodeToolCallArgumentsHtmlEntitiesInMessage(event.message);
        return outcome;
      },
      // `return` and `throw` are optional on the upstream iterator; fall back
      // to a completed result when they are missing.
      async return(value?: unknown) {
        return cursor.return?.(value) ?? { done: true as const, value: undefined };
      },
      async throw(error?: unknown) {
        return cursor.throw?.(error) ?? { done: true as const, value: undefined };
      },
    };
  };

  return stream;
}
/**
 * Creates a stream function whose streams have HTML entities decoded in
 * tool call arguments.
 *
 * @param baseStreamFn - Stream function to wrap; `streamSimple` is used when
 *   this is `undefined`.
 * @returns A stream function that forwards its arguments to the underlying
 *   function and wraps the resulting stream, whether that stream is returned
 *   directly or through a thenable.
 */
export function createHtmlEntityToolCallArgumentDecodingWrapper(
  baseStreamFn: StreamFn | undefined,
): StreamFn {
  const streamFn = baseStreamFn ?? streamSimple;
  return (model, context, options) => {
    const produced = streamFn(model, context, options);
    return produced && typeof produced === "object" && "then" in produced
      ? Promise.resolve(produced).then((resolved) =>
          wrapStreamDecodeToolCallArgumentHtmlEntities(resolved),
        )
      : wrapStreamDecodeToolCallArgumentHtmlEntities(produced);
  };
}
export {
applyAnthropicPayloadPolicyToParams,
resolveAnthropicPayloadPolicy,