diff --git a/extensions/lmstudio/src/stream.test.ts b/extensions/lmstudio/src/stream.test.ts index c5976c17b76..496f3dbe81a 100644 --- a/extensions/lmstudio/src/stream.test.ts +++ b/extensions/lmstudio/src/stream.test.ts @@ -36,6 +36,57 @@ afterAll(() => { type StreamEvent = { type: string } & Record; +function requireRecord(value: unknown, label: string): Record { + if (!value || typeof value !== "object" || Array.isArray(value)) { + throw new Error(`expected ${label} to be a record`); + } + return value as Record; +} + +function expectRecordFields(record: Record, fields: Record) { + for (const [key, value] of Object.entries(fields)) { + expect(record[key]).toEqual(value); + } +} + +function expectSingleDoneEvent(events: StreamEvent[]) { + expect(events).toHaveLength(1); + expect(events[0]?.type).toBe("done"); +} + +function requireMockCallArg(mock: { mock: { calls: unknown[][] } }, label: string) { + const call = mock.mock.calls[0]; + if (!call) { + throw new Error(`expected ${label} call`); + } + return call; +} + +function expectEnsureLoadedFields(fields: Record) { + const [params] = requireMockCallArg(ensureLmstudioModelLoadedMock, "ensureLmstudioModelLoaded"); + const record = requireRecord(params, "ensureLmstudioModelLoaded params"); + for (const [key, value] of Object.entries(fields)) { + if (key === "ssrfPolicy") { + expectRecordFields( + requireRecord(record.ssrfPolicy, "ssrfPolicy"), + value as Record, + ); + } else { + expect(record[key]).toEqual(value); + } + } +} + +function expectBaseStreamModelFields(baseStream: StreamFn, fields: Record) { + const call = requireMockCallArg( + baseStream as unknown as { mock: { calls: unknown[][] } }, + "base stream", + ); + expectRecordFields(requireRecord(call[0], "base stream model"), fields); + expect(call[1]).toBeDefined(); + expect(call[2]).toBeUndefined(); +} + async function collectEvents(stream: ReturnType): Promise { const resolved = stream instanceof Promise ? await stream : stream; const events: StreamEvent[] = []; @@ -135,17 +186,15 @@ describe("lmstudio stream wrapper", () => { ); const events = await collectEvents(stream); - expect(events).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(events); expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledTimes(1); - expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith( - expect.objectContaining({ - baseUrl: "http://lmstudio.internal:1234/v1", - modelKey: "qwen3-8b-instruct", - requestedContextLength: 131072, - apiKey: "lmstudio-token", - ssrfPolicy: { allowedHostnames: ["lmstudio.internal"] }, - }), - ); + expectEnsureLoadedFields({ + baseUrl: "http://lmstudio.internal:1234/v1", + modelKey: "qwen3-8b-instruct", + requestedContextLength: 131072, + apiKey: "lmstudio-token", + ssrfPolicy: { allowedHostnames: ["lmstudio.internal"] }, + }); }); it("prefers model contextTokens over contextWindow for preload requests", async () => { @@ -160,17 +209,15 @@ describe("lmstudio stream wrapper", () => { ); const events = await collectEvents(stream); - expect(events).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(events); expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledTimes(1); - expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledWith( - expect.objectContaining({ - baseUrl: "http://lmstudio.internal:1234/v1", - modelKey: "qwen3-8b-instruct", - requestedContextLength: 64000, - apiKey: "lmstudio-token", - ssrfPolicy: { allowedHostnames: ["lmstudio.internal"] }, - }), - ); + expectEnsureLoadedFields({ + baseUrl: "http://lmstudio.internal:1234/v1", + modelKey: "qwen3-8b-instruct", + requestedContextLength: 64000, + apiKey: "lmstudio-token", + ssrfPolicy: { allowedHostnames: ["lmstudio.internal"] }, + }); }); it("continues inference when preload fails", async () => { @@ -202,7 +249,7 @@ describe("lmstudio stream wrapper", () => { undefined as never, ); const events = await collectEvents(stream); - expect(events).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(events); expect(baseStream).toHaveBeenCalledTimes(1); }); @@ -237,16 +284,16 @@ describe("lmstudio stream wrapper", () => { ), ); - expect(events).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(events); expect(ensureLmstudioModelLoadedMock).not.toHaveBeenCalled(); expect(baseStream).toHaveBeenCalledTimes(1); - expect(baseStream).toHaveBeenCalledWith( - expect.objectContaining({ - compat: expect.objectContaining({ supportsUsageInStreaming: true }), - }), - expect.anything(), - undefined, + const [model] = requireMockCallArg( + baseStream as unknown as { mock: { calls: unknown[][] } }, + "base stream", ); + expectRecordFields(requireRecord(requireRecord(model, "base stream model").compat, "compat"), { + supportsUsageInStreaming: true, + }); }); it("dedupes concurrent preload requests for the same model and context", async () => { @@ -308,8 +355,8 @@ describe("lmstudio stream wrapper", () => { resolvePreload(); const [firstEvents, secondEvents] = await Promise.all([firstPromise, secondPromise]); - expect(firstEvents).toEqual([expect.objectContaining({ type: "done" })]); - expect(secondEvents).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(firstEvents); + expectSingleDoneEvent(secondEvents); expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledTimes(1); }); @@ -343,7 +390,7 @@ describe("lmstudio stream wrapper", () => { undefined as never, ), ); - expect(firstEvents).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(firstEvents); expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledTimes(1); const secondEvents = await collectEvents( @@ -357,7 +404,7 @@ describe("lmstudio stream wrapper", () => { undefined as never, ), ); - expect(secondEvents).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(secondEvents); // The second call must NOT retry preload because cooldown is active, but // the underlying stream must still run so the user gets a response. expect(ensureLmstudioModelLoadedMock).toHaveBeenCalledTimes(1); @@ -450,19 +497,17 @@ describe("lmstudio stream wrapper", () => { ); const events = await collectEvents(stream); - expect(events).toEqual([expect.objectContaining({ type: "done" })]); + expectSingleDoneEvent(events); expect(baseStream).toHaveBeenCalledTimes(1); - expect(baseStream).toHaveBeenCalledWith( - expect.objectContaining({ - provider: "lmstudio", - compat: expect.objectContaining({ - supportsDeveloperRole: false, - supportsUsageInStreaming: true, - }), - }), - expect.anything(), - undefined, + expectBaseStreamModelFields(baseStream, { provider: "lmstudio" }); + const [model] = requireMockCallArg( + baseStream as unknown as { mock: { calls: unknown[][] } }, + "base stream", ); + expectRecordFields(requireRecord(requireRecord(model, "base stream model").compat, "compat"), { + supportsDeveloperRole: false, + supportsUsageInStreaming: true, + }); }); it("promotes standalone bracketed local-model tool text to a structured tool call", async () => { @@ -511,12 +556,13 @@ describe("lmstudio stream wrapper", () => { }; expect(done.reason).toBe("toolUse"); expect(done.message?.stopReason).toBe("toolUse"); - expect(done.message?.content?.[0]).toMatchObject({ + const toolCall = requireRecord(done.message?.content?.[0], "tool call content"); + expectRecordFields(toolCall, { type: "toolCall", name: "mempalace_mempalace_search", arguments: { query: "codename", wing: "personal", room: "identities" }, }); - expect(String(done.message?.content?.[0]?.id)).toMatch(/^call_[a-f0-9]{24}$/); + expect(String(toolCall.id)).toMatch(/^call_[a-f0-9]{24}$/); }); it("promotes standalone Harmony local-model tool text to a structured tool call", async () => { @@ -555,7 +601,7 @@ describe("lmstudio stream wrapper", () => { reason?: string; }; expect(done.reason).toBe("toolUse"); - expect(done.message?.content?.[0]).toMatchObject({ + expectRecordFields(requireRecord(done.message?.content?.[0], "tool call content"), { type: "toolCall", name: "read", arguments: { path: "/path/to/file", line_start: 1, line_end: 400 }, @@ -597,8 +643,14 @@ describe("lmstudio stream wrapper", () => { "text_end", "done", ]); - expect(events.find((event) => event.type === "text_delta")).toMatchObject({ - delta: rawToolText, - }); + expectRecordFields( + requireRecord( + events.find((event) => event.type === "text_delta"), + "text delta", + ), + { + delta: rawToolText, + }, + ); }); });