From 6a3d4f9fadac1ff7a5045e63c2458341c832bd71 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Mon, 23 Mar 2026 04:36:06 -0700 Subject: [PATCH] test: isolate pi model and reset-model thread fixtures --- src/agents/pi-embedded-runner/compact.ts | 3 +- src/agents/pi-embedded-runner/model.test.ts | 7 ++ src/agents/pi-embedded-runner/model.ts | 65 ++++++++--- .../compaction-safeguard.test.ts | 10 +- .../pi-extensions/compaction-safeguard.ts | 12 +- .../reply/session-reset-model.test.ts | 109 ++++++++++++++++++ src/auto-reply/reply/session-reset-model.ts | 5 +- src/auto-reply/reply/session.test.ts | 97 ---------------- 8 files changed, 189 insertions(+), 119 deletions(-) create mode 100644 src/auto-reply/reply/session-reset-model.test.ts diff --git a/src/agents/pi-embedded-runner/compact.ts b/src/agents/pi-embedded-runner/compact.ts index df6466cf17c..7641afbffd7 100644 --- a/src/agents/pi-embedded-runner/compact.ts +++ b/src/agents/pi-embedded-runner/compact.ts @@ -390,7 +390,7 @@ async function runPostCompactionSideEffects(params: { type CompactionHookRunner = { hasHooks?: (hookName?: string) => boolean; runBeforeCompaction?: ( - metrics: { messageCount: number; tokenCount?: number }, + metrics: { messageCount: number; tokenCount?: number; sessionFile?: string }, context: { sessionId: string; agentId: string; @@ -1337,6 +1337,7 @@ export async function compactEmbeddedPiSession( await hookRunner.runBeforeCompaction( { messageCount: -1, + sessionFile: params.sessionFile, }, hookCtx, ); diff --git a/src/agents/pi-embedded-runner/model.test.ts b/src/agents/pi-embedded-runner/model.test.ts index 69bfc3192d7..12c08f6b0f9 100644 --- a/src/agents/pi-embedded-runner/model.test.ts +++ b/src/agents/pi-embedded-runner/model.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { discoverModels } from "../pi-model-discovery.js"; import { createProviderRuntimeTestMock } from "./model.provider-runtime.test-support.js"; vi.mock("../pi-model-discovery.js", () => ({ @@ -62,7 +63,10 @@ function resolveModelForTest( agentDir?: string, cfg?: OpenClawConfig, ) { + const resolvedAgentDir = agentDir ?? "/tmp/agent"; return resolveModel(provider, modelId, agentDir, cfg, { + authStorage: { mocked: true } as never, + modelRegistry: discoverModels({ mocked: true } as never, resolvedAgentDir), runtimeHooks: createRuntimeHooks(), }); } @@ -74,7 +78,10 @@ function resolveModelAsyncForTest( cfg?: OpenClawConfig, options?: { retryTransientProviderRuntimeMiss?: boolean }, ) { + const resolvedAgentDir = agentDir ?? "/tmp/agent"; return resolveModelAsync(provider, modelId, agentDir, cfg, { + authStorage: { mocked: true } as never, + modelRegistry: discoverModels({ mocked: true } as never, resolvedAgentDir), ...options, runtimeHooks: createRuntimeHooks(), }); diff --git a/src/agents/pi-embedded-runner/model.ts b/src/agents/pi-embedded-runner/model.ts index c30c531bad1..ab2c0dac460 100644 --- a/src/agents/pi-embedded-runner/model.ts +++ b/src/agents/pi-embedded-runner/model.ts @@ -76,6 +76,13 @@ function normalizeResolvedModel(params: { agentDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): Model { + const normalizedInputModel = + Array.isArray(params.model.input) && params.model.input.length > 0 + ? params.model + : ({ + ...params.model, + input: ["text"], + } as Model); const runtimeHooks = params.runtimeHooks ?? DEFAULT_PROVIDER_RUNTIME_HOOKS; const pluginNormalized = runtimeHooks.normalizeProviderResolvedModelWithPlugin({ provider: params.provider, @@ -84,14 +91,36 @@ function normalizeResolvedModel(params: { config: params.cfg, agentDir: params.agentDir, provider: params.provider, - modelId: params.model.id, - model: params.model, + modelId: normalizedInputModel.id, + model: normalizedInputModel, }, }) as Model | undefined; if (pluginNormalized) { return normalizeModelCompat(pluginNormalized); } - return normalizeResolvedProviderModel(params); + return normalizeResolvedProviderModel({ + provider: params.provider, + model: normalizedInputModel, + }); +} + +function findInlineModelMatch(params: { + providers: Record; + provider: string; + modelId: string; +}) { + const inlineModels = buildInlineProviderModels(params.providers); + const exact = inlineModels.find( + (entry) => entry.provider === params.provider && entry.id === params.modelId, + ); + if (exact) { + return exact; + } + const normalizedProvider = normalizeProviderId(params.provider); + return inlineModels.find( + (entry) => + normalizeProviderId(entry.provider) === normalizedProvider && entry.id === params.modelId, + ); } export { buildModelAliasLines }; @@ -212,11 +241,11 @@ function resolveExplicitModelWithRegistry(params: { return { kind: "suppressed" }; } const providerConfig = resolveConfiguredProviderConfig(cfg, provider); - const inlineModels = buildInlineProviderModels(cfg?.models?.providers ?? {}); - const normalizedProvider = normalizeProviderId(provider); - const inlineMatch = inlineModels.find( - (entry) => normalizeProviderId(entry.provider) === normalizedProvider && entry.id === modelId, - ); + const inlineMatch = findInlineModelMatch({ + providers: cfg?.models?.providers ?? {}, + provider, + modelId, + }); if (inlineMatch?.api) { return { kind: "resolved", @@ -249,9 +278,11 @@ function resolveExplicitModelWithRegistry(params: { } const providers = cfg?.models?.providers ?? {}; - const fallbackInlineMatch = buildInlineProviderModels(providers).find( - (entry) => normalizeProviderId(entry.provider) === normalizedProvider && entry.id === modelId, - ); + const fallbackInlineMatch = findInlineModelMatch({ + providers, + provider, + modelId, + }); if (fallbackInlineMatch?.api) { return { kind: "resolved", @@ -385,6 +416,8 @@ export function resolveModel( agentDir?: string, cfg?: OpenClawConfig, options?: { + authStorage?: AuthStorage; + modelRegistry?: ModelRegistry; runtimeHooks?: ProviderRuntimeHooks; }, ): { @@ -394,8 +427,8 @@ export function resolveModel( modelRegistry: ModelRegistry; } { const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir(); - const authStorage = discoverAuthStorage(resolvedAgentDir); - const modelRegistry = discoverModels(authStorage, resolvedAgentDir); + const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir); + const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir); const model = resolveModelWithRegistry({ provider, modelId, @@ -421,6 +454,8 @@ export async function resolveModelAsync( agentDir?: string, cfg?: OpenClawConfig, options?: { + authStorage?: AuthStorage; + modelRegistry?: ModelRegistry; retryTransientProviderRuntimeMiss?: boolean; runtimeHooks?: ProviderRuntimeHooks; }, @@ -431,8 +466,8 @@ export async function resolveModelAsync( modelRegistry: ModelRegistry; }> { const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir(); - const authStorage = discoverAuthStorage(resolvedAgentDir); - const modelRegistry = discoverModels(authStorage, resolvedAgentDir); + const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir); + const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir); const explicitModel = resolveExplicitModelWithRegistry({ provider, modelId, diff --git a/src/agents/pi-extensions/compaction-safeguard.test.ts b/src/agents/pi-extensions/compaction-safeguard.test.ts index d7c8886b32e..e7032183ce9 100644 --- a/src/agents/pi-extensions/compaction-safeguard.test.ts +++ b/src/agents/pi-extensions/compaction-safeguard.test.ts @@ -4,7 +4,7 @@ import path from "node:path"; import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { Api, Model } from "@mariozechner/pi-ai"; import type { ExtensionAPI, ExtensionContext } from "@mariozechner/pi-coding-agent"; -import { describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import type { OpenClawConfig } from "../../config/config.js"; import * as compactionModule from "../compaction.js"; import { buildEmbeddedExtensionFactories } from "../pi-embedded-runner/extensions.js"; @@ -51,6 +51,14 @@ const { SUMMARY_TRUNCATED_MARKER, } = __testing; +beforeEach(() => { + __testing.setSummarizeInStagesForTest(mockSummarizeInStages); +}); + +afterEach(() => { + __testing.setSummarizeInStagesForTest(); +}); + function stubSessionManager(): ExtensionContext["sessionManager"] { const stub: ExtensionContext["sessionManager"] = { getCwd: () => "/stub", diff --git a/src/agents/pi-extensions/compaction-safeguard.ts b/src/agents/pi-extensions/compaction-safeguard.ts index 5b7d4f082f0..a9b25a72e6a 100644 --- a/src/agents/pi-extensions/compaction-safeguard.ts +++ b/src/agents/pi-extensions/compaction-safeguard.ts @@ -66,6 +66,9 @@ const STRICT_EXACT_IDENTIFIERS_INSTRUCTION = "For ## Exact identifiers, preserve literal values exactly as seen (IDs, URLs, file paths, ports, hashes, dates, times)."; const POLICY_OFF_EXACT_IDENTIFIERS_INSTRUCTION = "For ## Exact identifiers, include identifiers only when needed for continuity; do not enforce literal-preservation rules."; +const compactionSafeguardDeps = { + summarizeInStages, +}; type ToolFailure = { toolCallId: string; @@ -906,7 +909,7 @@ export default function compactionSafeguardExtension(api: ExtensionAPI): void { Math.floor(contextWindowTokens * droppedChunkRatio) - SUMMARIZATION_OVERHEAD_TOKENS, ); - droppedSummary = await summarizeInStages({ + droppedSummary = await compactionSafeguardDeps.summarizeInStages({ messages: pruned.droppedMessagesList, model, apiKey, @@ -977,7 +980,7 @@ export default function compactionSafeguardExtension(api: ExtensionAPI): void { try { historySummary = messagesToSummarize.length > 0 - ? await summarizeInStages({ + ? await compactionSafeguardDeps.summarizeInStages({ messages: messagesToSummarize, model, apiKey, @@ -993,7 +996,7 @@ export default function compactionSafeguardExtension(api: ExtensionAPI): void { summaryWithoutPreservedTurns = historySummary; if (preparation.isSplitTurn && turnPrefixMessages.length > 0) { - const prefixSummary = await summarizeInStages({ + const prefixSummary = await compactionSafeguardDeps.summarizeInStages({ messages: turnPrefixMessages, model, apiKey, @@ -1111,6 +1114,9 @@ export default function compactionSafeguardExtension(api: ExtensionAPI): void { } export const __testing = { + setSummarizeInStagesForTest(next?: typeof summarizeInStages) { + compactionSafeguardDeps.summarizeInStages = next ?? summarizeInStages; + }, collectToolFailures, formatToolFailuresSection, splitPreservedRecentTurns, diff --git a/src/auto-reply/reply/session-reset-model.test.ts b/src/auto-reply/reply/session-reset-model.test.ts new file mode 100644 index 00000000000..7140c4ebe01 --- /dev/null +++ b/src/auto-reply/reply/session-reset-model.test.ts @@ -0,0 +1,109 @@ +import { describe, expect, it } from "vitest"; +import type { ModelCatalogEntry } from "../../agents/model-catalog.js"; +import { buildModelAliasIndex } from "../../agents/model-selection.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { SessionEntry } from "../../config/sessions.js"; +import { applyResetModelOverride } from "./session-reset-model.js"; + +const modelCatalog: ModelCatalogEntry[] = [ + { provider: "minimax", id: "m2.7", name: "M2.7" }, + { provider: "openai", id: "gpt-4o-mini", name: "GPT-4o mini" }, +]; + +describe("applyResetModelOverride", () => { + it("selects a model hint and strips it from the body", async () => { + const cfg = {} as OpenClawConfig; + const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); + const sessionEntry: SessionEntry = { + sessionId: "s1", + updatedAt: Date.now(), + }; + const sessionStore: Record = { "agent:main:dm:1": sessionEntry }; + const sessionCtx = { BodyStripped: "minimax summarize" }; + const ctx = { ChatType: "direct" }; + + await applyResetModelOverride({ + cfg, + resetTriggered: true, + bodyStripped: "minimax summarize", + sessionCtx, + ctx, + sessionEntry, + sessionStore, + sessionKey: "agent:main:dm:1", + defaultProvider: "openai", + defaultModel: "gpt-4o-mini", + aliasIndex, + modelCatalog, + }); + + expect(sessionEntry.providerOverride).toBe("minimax"); + expect(sessionEntry.modelOverride).toBe("m2.7"); + expect(sessionCtx.BodyStripped).toBe("summarize"); + }); + + it("clears auth profile overrides when reset applies a model", async () => { + const cfg = {} as OpenClawConfig; + const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); + const sessionEntry: SessionEntry = { + sessionId: "s1", + updatedAt: Date.now(), + authProfileOverride: "anthropic:default", + authProfileOverrideSource: "user", + authProfileOverrideCompactionCount: 2, + }; + const sessionStore: Record = { "agent:main:dm:1": sessionEntry }; + const sessionCtx = { BodyStripped: "minimax summarize" }; + const ctx = { ChatType: "direct" }; + + await applyResetModelOverride({ + cfg, + resetTriggered: true, + bodyStripped: "minimax summarize", + sessionCtx, + ctx, + sessionEntry, + sessionStore, + sessionKey: "agent:main:dm:1", + defaultProvider: "openai", + defaultModel: "gpt-4o-mini", + aliasIndex, + modelCatalog, + }); + + expect(sessionEntry.authProfileOverride).toBeUndefined(); + expect(sessionEntry.authProfileOverrideSource).toBeUndefined(); + expect(sessionEntry.authProfileOverrideCompactionCount).toBeUndefined(); + }); + + it("skips when resetTriggered is false", async () => { + const cfg = {} as OpenClawConfig; + const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); + const sessionEntry: SessionEntry = { + sessionId: "s1", + updatedAt: Date.now(), + }; + const sessionStore: Record = { "agent:main:dm:1": sessionEntry }; + const sessionCtx = { BodyStripped: "minimax summarize" }; + const ctx = { ChatType: "direct" }; + + await applyResetModelOverride({ + cfg, + resetTriggered: false, + bodyStripped: "minimax summarize", + sessionCtx, + ctx, + sessionEntry, + sessionStore, + sessionKey: "agent:main:dm:1", + defaultProvider: "openai", + defaultModel: "gpt-4o-mini", + aliasIndex, + modelCatalog, + }); + + expect(sessionEntry.providerOverride).toBeUndefined(); + expect(sessionEntry.modelOverride).toBeUndefined(); + expect(sessionCtx.BodyStripped).toBe("minimax summarize"); + }); +}); diff --git a/src/auto-reply/reply/session-reset-model.ts b/src/auto-reply/reply/session-reset-model.ts index 101720e2dd2..f3e1d873e2c 100644 --- a/src/auto-reply/reply/session-reset-model.ts +++ b/src/auto-reply/reply/session-reset-model.ts @@ -1,4 +1,4 @@ -import { loadModelCatalog } from "../../agents/model-catalog.js"; +import { loadModelCatalog, type ModelCatalogEntry } from "../../agents/model-catalog.js"; import { buildAllowedModelSet, modelKey, @@ -99,6 +99,7 @@ export async function applyResetModelOverride(params: { defaultProvider: string; defaultModel: string; aliasIndex: ModelAliasIndex; + modelCatalog?: ModelCatalogEntry[]; }): Promise { if (!params.resetTriggered) { return {}; @@ -113,7 +114,7 @@ export async function applyResetModelOverride(params: { return {}; } - const catalog = await loadModelCatalog({ config: params.cfg }); + const catalog = params.modelCatalog ?? (await loadModelCatalog({ config: params.cfg })); const allowed = buildAllowedModelSet({ cfg: params.cfg, catalog, diff --git a/src/auto-reply/reply/session.test.ts b/src/auto-reply/reply/session.test.ts index 534936c142c..cd8f4184ade 100644 --- a/src/auto-reply/reply/session.test.ts +++ b/src/auto-reply/reply/session.test.ts @@ -3,7 +3,6 @@ import os from "node:os"; import path from "node:path"; import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import * as bootstrapCache from "../../agents/bootstrap-cache.js"; -import { buildModelAliasIndex } from "../../agents/model-selection.js"; import type { OpenClawConfig } from "../../config/config.js"; import type { SessionEntry } from "../../config/sessions.js"; import { formatZonedTimestamp } from "../../infra/format-time/format-datetime.ts"; @@ -12,7 +11,6 @@ import { registerSessionBindingAdapter, } from "../../infra/outbound/session-binding-service.js"; import { enqueueSystemEvent, resetSystemEventsForTest } from "../../infra/system-events.js"; -import { applyResetModelOverride } from "./session-reset-model.js"; import { drainFormattedSystemEvents } from "./session-updates.js"; import { persistSessionUsageUpdate } from "./session-usage.js"; import { initSessionState } from "./session.js"; @@ -1261,101 +1259,6 @@ describe("initSessionState reset triggers in Slack channels", () => { }); }); -describe("applyResetModelOverride", () => { - it("selects a model hint and strips it from the body", async () => { - const cfg = {} as OpenClawConfig; - const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); - const sessionEntry: SessionEntry = { - sessionId: "s1", - updatedAt: Date.now(), - }; - const sessionStore: Record = { "agent:main:dm:1": sessionEntry }; - const sessionCtx = { BodyStripped: "minimax summarize" }; - const ctx = { ChatType: "direct" }; - - await applyResetModelOverride({ - cfg, - resetTriggered: true, - bodyStripped: "minimax summarize", - sessionCtx, - ctx, - sessionEntry, - sessionStore, - sessionKey: "agent:main:dm:1", - defaultProvider: "openai", - defaultModel: "gpt-4o-mini", - aliasIndex, - }); - - expect(sessionEntry.providerOverride).toBe("minimax"); - expect(sessionEntry.modelOverride).toBe("m2.7"); - expect(sessionCtx.BodyStripped).toBe("summarize"); - }); - - it("clears auth profile overrides when reset applies a model", async () => { - const cfg = {} as OpenClawConfig; - const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); - const sessionEntry: SessionEntry = { - sessionId: "s1", - updatedAt: Date.now(), - authProfileOverride: "anthropic:default", - authProfileOverrideSource: "user", - authProfileOverrideCompactionCount: 2, - }; - const sessionStore: Record = { "agent:main:dm:1": sessionEntry }; - const sessionCtx = { BodyStripped: "minimax summarize" }; - const ctx = { ChatType: "direct" }; - - await applyResetModelOverride({ - cfg, - resetTriggered: true, - bodyStripped: "minimax summarize", - sessionCtx, - ctx, - sessionEntry, - sessionStore, - sessionKey: "agent:main:dm:1", - defaultProvider: "openai", - defaultModel: "gpt-4o-mini", - aliasIndex, - }); - - expect(sessionEntry.authProfileOverride).toBeUndefined(); - expect(sessionEntry.authProfileOverrideSource).toBeUndefined(); - expect(sessionEntry.authProfileOverrideCompactionCount).toBeUndefined(); - }); - - it("skips when resetTriggered is false", async () => { - const cfg = {} as OpenClawConfig; - const aliasIndex = buildModelAliasIndex({ cfg, defaultProvider: "openai" }); - const sessionEntry: SessionEntry = { - sessionId: "s1", - updatedAt: Date.now(), - }; - const sessionStore: Record = { "agent:main:dm:1": sessionEntry }; - const sessionCtx = { BodyStripped: "minimax summarize" }; - const ctx = { ChatType: "direct" }; - - await applyResetModelOverride({ - cfg, - resetTriggered: false, - bodyStripped: "minimax summarize", - sessionCtx, - ctx, - sessionEntry, - sessionStore, - sessionKey: "agent:main:dm:1", - defaultProvider: "openai", - defaultModel: "gpt-4o-mini", - aliasIndex, - }); - - expect(sessionEntry.providerOverride).toBeUndefined(); - expect(sessionEntry.modelOverride).toBeUndefined(); - expect(sessionCtx.BodyStripped).toBe("minimax summarize"); - }); -}); - describe("initSessionState preserves behavior overrides across /new and /reset", () => { async function seedSessionStoreWithOverrides(params: { storePath: string;