fix: clamp xAI live gateway thinking

Peter Steinberger
2026-05-06 05:30:22 +01:00
parent afc2c2e207
commit e9987ffc3a
4 changed files with 266 additions and 23 deletions

@@ -0,0 +1,5 @@
import type { ProviderThinkingProfile } from "openclaw/plugin-sdk/plugin-entry";
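
// Pin the thinking profile to a single "off" level; any other requested
// level falls back to this default.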
export function resolveThinkingProfile(): ProviderThinkingProfile {
return { levels: [{ id: "off" }], defaultLevel: "off" };
}

@@ -3,7 +3,12 @@ import fs from "node:fs/promises";
import { createServer } from "node:net";
import os from "node:os";
import path from "node:path";
import type { Api, Model } from "@mariozechner/pi-ai";
import {
clampThinkingLevel,
type Api,
type Model,
type ModelThinkingLevel,
} from "@mariozechner/pi-ai";
import { afterEach, describe, expect, it } from "vitest";
import { resolveAgentWorkspaceDir, resolveDefaultAgentDir } from "../agents/agent-scope.js";
import { ensureAuthProfileStore, saveAuthProfileStore } from "../agents/auth-profiles/store.js";
@@ -36,6 +41,7 @@ import { clearRuntimeConfigSnapshot, getRuntimeConfig } from "../config/io.js";
import type { ModelsConfig, ModelProviderConfig, OpenClawConfig } from "../config/types.js";
import { isTruthyEnvValue } from "../infra/env.js";
import { normalizeGoogleModelId } from "../plugin-sdk/google-model-id.js";
import { resolveProviderThinkingProfile } from "../plugins/provider-runtime.js";
import { DEFAULT_AGENT_ID } from "../routing/session-key.js";
import { stripAssistantInternalScaffolding } from "../shared/text/assistant-visible-text.js";
import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
@@ -587,6 +593,95 @@ describe("resolveGatewayLiveMaxModels", () => {
});
});
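
// Minimal model fixture for the live-gateway tests; callers spread it and
// override individual fields (e.g. `reasoning`) as needed.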
function createGatewayLiveTestModel(provider: string, id: string): Model<Api> {
return {
provider,
id,
name: id,
api: "openai-responses",
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 1_000,
maxTokens: 100,
reasoning: false,
} as Model<Api>;
}
describe("resolveExplicitLiveModelCandidates", () => {
it("uses targeted registry lookup for explicit provider/model filters", () => {
const model = createGatewayLiveTestModel("xai", "grok-4.3");
const matcher = createLiveTargetMatcher({
providerFilter: new Set(["xai"]),
modelFilter: new Set(["xai/grok-4.3"]),
env: {},
});
const candidates = resolveExplicitLiveModelCandidates({
modelRegistry: {
find(provider, modelId) {
expect(provider).toBe("xai");
expect(modelId).toBe("grok-4.3");
return model;
},
getAll() {
throw new Error("explicit model lookup should not enumerate registry");
},
},
modelFilter: new Set(["xai/grok-4.3"]),
providerFilter: new Set(["xai"]),
targetMatcher: matcher,
});
expect(candidates).toEqual([model]);
});
it("falls back to enumeration for ambiguous model-only filters", () => {
const matcher = createLiveTargetMatcher({
providerFilter: null,
modelFilter: new Set(["grok-4.3"]),
env: {},
});
expect(
resolveExplicitLiveModelCandidates({
modelRegistry: {
find() {
throw new Error("ambiguous model-only lookup should not use direct find");
},
getAll() {
return [];
},
},
modelFilter: new Set(["grok-4.3"]),
providerFilter: null,
targetMatcher: matcher,
}),
).toBeNull();
});
});
describe("resolveGatewayLiveModelThinkingLevel", () => {
it("clamps requested thinking to levels supported by model metadata", () => {
expect(
resolveGatewayLiveModelThinkingLevel({
cfg: {},
model: {
...createGatewayLiveTestModel("xai", "grok-4.3"),
reasoning: true,
thinkingLevelMap: {
off: null,
minimal: null,
low: null,
medium: null,
high: null,
xhigh: null,
},
},
requestedLevel: "low",
}),
).toBe("off");
});
});

function isGoogleModelNotFoundText(text: string): boolean {
const trimmed = text.trim();
if (!trimmed) {
@@ -1281,6 +1376,101 @@ type GatewayModelSuiteParams = {
providerOverrides?: Record<string, ModelProviderConfig>;
};

type LiveModelRegistry = {
find(provider: string, modelId: string): Model<Api> | null | undefined;
getAll(): Array<Model<Api>>;
};
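
// Parse an explicit live-model ref. "provider/model" is always accepted; a
// bare model id resolves only when the provider filter names exactly one
// provider.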
function parseExplicitLiveModelRef(
raw: string,
providerFilter: Set<string> | null,
): { provider: string; modelId: string } | null {
const trimmed = raw.trim();
if (!trimmed) {
return null;
}
const slash = trimmed.indexOf("/");
if (slash !== -1) {
const provider = normalizeProviderId(trimmed.slice(0, slash));
const modelId = trimmed.slice(slash + 1).trim();
return provider && modelId ? { provider, modelId } : null;
}
if (!providerFilter || providerFilter.size !== 1) {
return null;
}
const [provider] = [...providerFilter];
return provider ? { provider: normalizeProviderId(provider), modelId: trimmed } : null;
}
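
// Resolve explicit model filters with targeted registry lookups instead of
// enumerating every model. Returns null when any ref is ambiguous or unknown
// so the caller can fall back to full enumeration.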
function resolveExplicitLiveModelCandidates(params: {
modelRegistry: LiveModelRegistry;
modelFilter: Set<string> | null;
providerFilter: Set<string> | null;
targetMatcher: ReturnType<typeof createLiveTargetMatcher>;
}): Array<Model<Api>> | null {
if (!params.modelFilter || params.modelFilter.size === 0) {
return null;
}
const candidates: Array<Model<Api>> = [];
const seen = new Set<string>();
for (const raw of params.modelFilter) {
const ref = parseExplicitLiveModelRef(raw, params.providerFilter);
if (!ref) {
return null;
}
const model = params.modelRegistry.find(ref.provider, ref.modelId);
if (!model) {
return null;
}
if (
!params.targetMatcher.matchesProvider(model.provider) ||
!params.targetMatcher.matchesModel(model.provider, model.id)
) {
return null;
}
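// De-duplicate case-insensitively on the normalized provider/model pair.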
const key = `${normalizeProviderId(model.provider)}/${model.id.toLowerCase()}`;
if (!seen.has(key)) {
seen.add(key);
candidates.push(model);
}
}
return candidates;
}
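
// Clamp a requested thinking level for a live probe: a declared provider
// thinking profile wins; otherwise clamp against the model's own metadata.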
function resolveGatewayLiveModelThinkingLevel(params: {
cfg: OpenClawConfig;
model: Model<Api>;
requestedLevel: string;
}): string {
const { model, requestedLevel } = params;
const normalized = requestedLevel.trim() as ModelThinkingLevel;
if (!["off", "minimal", "low", "medium", "high", "xhigh"].includes(normalized)) {
return requestedLevel;
}
const profile = resolveProviderThinkingProfile({
provider: model.provider,
config: params.cfg,
context: {
provider: model.provider,
modelId: model.id,
reasoning: model.reasoning,
},
});
if (profile) {
const levelIds = profile.levels.map((level) => level.id);
if (levelIds.includes(normalized)) {
return normalized;
}
if (profile.defaultLevel) {
return profile.defaultLevel;
}
if (levelIds.length === 1) {
return levelIds[0] ?? requestedLevel;
}
}
return clampThinkingLevel(model, normalized);
}

function buildLiveGatewayConfig(params: {
cfg: OpenClawConfig;
candidates: Array<Model<Api>>;
@@ -1549,6 +1739,14 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
for (const [index, model] of params.candidates.entries()) {
const modelKey = `${model.provider}/${model.id}`;
const progressLabel = `[${params.label}] ${index + 1}/${total} ${modelKey}`;
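// Clamp the suite-wide thinking level per model so providers that expose
// only a single level (xAI pins "off") receive a supported value.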
const thinkingLevel = resolveGatewayLiveModelThinkingLevel({
cfg: params.cfg,
model,
requestedLevel: params.thinkingLevel,
});
if (thinkingLevel !== params.thinkingLevel) {
logProgress(`${progressLabel}: thinking ${params.thinkingLevel} -> ${thinkingLevel}`);
}
// Use a separate session per model: live providers can finalize late after
// skip/retry paths, and a reset on a reused key does not isolate those
// delayed transcript writes from the next model probe.
@@ -1589,7 +1787,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
modelKey,
message:
"Explain in 2-3 sentences how the JavaScript event loop handles microtasks vs macrotasks. Must mention both words: microtask and macrotask.",
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: prompt`,
});
if (!text) {
@@ -1601,7 +1799,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
modelKey,
message:
"Explain in 2-3 sentences how the JavaScript event loop handles microtasks vs macrotasks. Must mention both words: microtask and macrotask.",
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: prompt-retry`,
});
}
@@ -1650,7 +1848,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
modelKey,
message:
"Answer in exactly two short sentences. Include the exact lowercase words microtask and macrotask. No bullets.",
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: prompt-keyword-retry`,
});
if (retryText) {
@@ -1697,7 +1895,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
: "OpenClaw live tool probe (local, safe): " +
`use the tool named \`read\` (or \`Read\`) with JSON arguments {"path":"${toolProbePath}"}. ` +
"Then reply with the two nonce values you read (include both).",
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: tool-read`,
});
if (
@@ -1768,7 +1966,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
`mkdir -p "${tempDir}" && printf '%s' '${nonceC}' > "${toolWritePath}". ` +
`Then use the tool named \`read\` (or \`Read\`) with JSON arguments {"path":"${toolWritePath}"}. ` +
"Finally reply including the nonce text you read back.",
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: tool-exec`,
});
if (
@@ -1836,7 +2034,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
content: imageBase64,
},
],
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: image`,
});
if (
@@ -1883,7 +2081,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
idempotencyKey: `idem-${runId2}-1`,
modelKey,
message: `Call the tool named \`read\` (or \`Read\`) on "${toolProbePath}". Do not write any other text.`,
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: tool-only-regression-first`,
});
assertNoReasoningTags({
@@ -1899,7 +2097,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
idempotencyKey: `idem-${runId2}-2`,
modelKey,
message: `Now answer: what are the values of nonceA and nonceB in "${toolProbePath}"? Reply with exactly: ${nonceA} ${nonceB}.`,
thinkingLevel: params.thinkingLevel,
thinkingLevel,
context: `${progressLabel}: tool-only-regression-second`,
});
assertNoReasoningTags({
@@ -1919,7 +2117,7 @@ async function runGatewayModelSuite(params: GatewayModelSuiteParams) {
sessionKey,
modelKey,
label: progressLabel,
thinkingLevel: params.thinkingLevel,
thinkingLevel,
});
}
return "done";
@@ -2171,7 +2369,6 @@ describeLive("gateway live (dev agent, profile keys)", () => {
const agentDir = resolveDefaultAgentDir(cfg);
const authStorage = discoverAuthStorage(agentDir);
const modelRegistry = discoverModels(authStorage, agentDir);
const all = modelRegistry.getAll();
const rawModels = process.env.OPENCLAW_LIVE_GATEWAY_MODELS?.trim();
const useModern = !rawModels || rawModels === "modern" || rawModels === "all";
@@ -2184,18 +2381,29 @@ describeLive("gateway live (dev agent, profile keys)", () => {
config: cfg,
env: process.env,
});
const wanted = filter
? all.filter((m) => targetMatcher.matchesModel(m.provider, m.id))
: all.filter(
(m) =>
!shouldExcludeProviderFromDefaultHighSignalLiveSweep({
provider: m.provider,
useExplicitModels: useExplicit,
providerFilter: PROVIDERS,
config: cfg,
env: process.env,
}) && isHighSignalLiveModelRef({ provider: m.provider, id: m.id }),
);
let wanted = useExplicit
? resolveExplicitLiveModelCandidates({
modelRegistry,
modelFilter: filter,
providerFilter: PROVIDERS,
targetMatcher,
})
: null;
if (!wanted) {
const all = modelRegistry.getAll();
wanted = filter
? all.filter((m) => targetMatcher.matchesModel(m.provider, m.id))
: all.filter(
(m) =>
!shouldExcludeProviderFromDefaultHighSignalLiveSweep({
provider: m.provider,
useExplicitModels: useExplicit,
providerFilter: PROVIDERS,
config: cfg,
env: process.env,
}) && isHighSignalLiveModelRef({ provider: m.provider, id: m.id }),
);
}
const candidates: Array<Model<Api>> = [];
const skipped: Array<{ model: string; error: string }> = [];

@@ -70,6 +70,7 @@ let resolveProviderStreamFn: typeof import("./provider-runtime.js").resolveProvi
let resolveProviderCacheTtlEligibility: typeof import("./provider-runtime.js").resolveProviderCacheTtlEligibility;
let resolveProviderBinaryThinking: typeof import("./provider-runtime.js").resolveProviderBinaryThinking;
let createProviderEmbeddingProvider: typeof import("./provider-runtime.js").createProviderEmbeddingProvider;
let resolveProviderThinkingProfile: typeof import("./provider-runtime.js").resolveProviderThinkingProfile;
let resolveProviderDefaultThinkingLevel: typeof import("./provider-runtime.js").resolveProviderDefaultThinkingLevel;
let resolveProviderModernModelRef: typeof import("./provider-runtime.js").resolveProviderModernModelRef;
let resolveProviderReasoningOutputModeWithPlugin: typeof import("./provider-runtime.js").resolveProviderReasoningOutputModeWithPlugin;
@@ -295,6 +296,7 @@ describe("provider-runtime", () => {
resolveProviderCacheTtlEligibility,
resolveProviderBinaryThinking,
createProviderEmbeddingProvider,
resolveProviderThinkingProfile,
resolveProviderDefaultThinkingLevel,
resolveProviderModernModelRef,
resolveProviderReasoningOutputModeWithPlugin,
@@ -1154,6 +1156,30 @@ describe("provider-runtime", () => {
expect(resolvePluginProvidersMock).not.toHaveBeenCalled();
});
it("resolves thinking profiles from bundled policy surface before runtime plugins", () => {
const resolveThinkingProfile = vi.fn(() => ({
levels: [{ id: "off" as const }],
defaultLevel: "off" as const,
}));
resolveBundledProviderPolicySurfaceMock.mockReturnValue({
resolveThinkingProfile,
});
expect(
resolveProviderThinkingProfile({
provider: "xai",
context: {
provider: "xai",
modelId: "grok-4.3",
reasoning: true,
},
}),
).toEqual({ levels: [{ id: "off" }], defaultLevel: "off" });
expect(resolveThinkingProfile).toHaveBeenCalledTimes(1);
expect(resolvePluginProvidersMock).not.toHaveBeenCalled();
});
it("resolves provider config defaults through owner plugins", () => {
resolvePluginProvidersMock.mockReturnValue([
{

@@ -763,6 +763,10 @@ export function resolveProviderThinkingProfile(params: {
env?: NodeJS.ProcessEnv;
context: ProviderDefaultThinkingPolicyContext;
}): ProviderThinkingProfile | null | undefined {
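// A bundled policy surface takes precedence over runtime plugins when it
// declares a thinking profile resolver.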
const bundledSurface = resolveBundledProviderPolicySurface(params.provider);
if (bundledSurface?.resolveThinkingProfile) {
return bundledSurface.resolveThinkingProfile(params.context) ?? undefined;
}
return resolveProviderRuntimePlugin(params)?.resolveThinkingProfile?.(params.context);
}