fix: align websocket stream fallback types

This commit is contained in:
Peter Steinberger
2026-03-23 08:59:16 +00:00
parent fb602c9b02
commit 4ea014d581
2 changed files with 18 additions and 11 deletions

View File

@@ -14,7 +14,12 @@
* Skipped in CI — no API key available and we avoid billable external calls.
*/
import type { AssistantMessage, Context } from "@mariozechner/pi-ai";
import type {
AssistantMessage,
AssistantMessageEvent,
AssistantMessageEventStream,
Context,
} from "@mariozechner/pi-ai";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const API_KEY = process.env.OPENAI_API_KEY;
@@ -23,6 +28,7 @@ const testFn = LIVE ? it : it.skip;
type OpenAIWsStreamModule = typeof import("./openai-ws-stream.js");
type StreamFactory = OpenAIWsStreamModule["createOpenAIWebSocketStreamFn"];
type StreamReturn = ReturnType<ReturnType<StreamFactory>>;
let openAIWsStreamModule: OpenAIWsStreamModule;
const model = {
@@ -78,17 +84,16 @@ function makeToolResultMessage(
} as unknown as StreamFnParams[1]["messages"][number];
}
async function collectEvents(
stream: ReturnType<ReturnType<typeof createOpenAIWebSocketStreamFn>>,
): Promise<Array<{ type: string; message?: AssistantMessage }>> {
const events: Array<{ type: string; message?: AssistantMessage }> = [];
for await (const event of stream as AsyncIterable<{ type: string; message?: AssistantMessage }>) {
async function collectEvents(stream: StreamReturn): Promise<AssistantMessageEvent[]> {
const events: AssistantMessageEvent[] = [];
const resolvedStream: AssistantMessageEventStream = await stream;
for await (const event of resolvedStream) {
events.push(event);
}
return events;
}
function expectDone(events: Array<{ type: string; message?: AssistantMessage }>): AssistantMessage {
function expectDone(events: AssistantMessageEvent[]): AssistantMessage {
const done = events.find((event) => event.type === "done")?.message;
expect(done).toBeDefined();
return done!;

View File

@@ -26,6 +26,7 @@ import type { StreamFn } from "@mariozechner/pi-agent-core";
import type {
AssistantMessage,
AssistantMessageEvent,
AssistantMessageEventStream,
Context,
Message,
StopReason,
@@ -69,10 +70,11 @@ interface WsSession {
/** Module-level registry: sessionId → WsSession */
const wsRegistry = new Map<string, WsSession>();
type AssistantMessageEventStreamLike = AsyncIterable<AssistantMessageEvent> & {
type AssistantMessageEventStreamLike = {
push(event: AssistantMessageEvent): void;
end(result?: AssistantMessage): void;
result(): Promise<AssistantMessage>;
[Symbol.asyncIterator](): AsyncIterator<AssistantMessageEvent>;
};
class LocalAssistantMessageEventStream implements AssistantMessageEventStreamLike {
@@ -114,7 +116,7 @@ class LocalAssistantMessageEventStream implements AssistantMessageEventStreamLik
}
while (this.waiting.length > 0) {
const waiter = this.waiting.shift();
waiter?.({ value: undefined as AssistantMessageEvent, done: true });
waiter?.({ value: undefined as unknown as AssistantMessageEvent, done: true });
}
}
@@ -142,10 +144,10 @@ class LocalAssistantMessageEventStream implements AssistantMessageEventStreamLik
}
}
function createEventStream(): AssistantMessageEventStreamLike {
function createEventStream(): AssistantMessageEventStream {
return typeof createAssistantMessageEventStream === "function"
? createAssistantMessageEventStream()
: new LocalAssistantMessageEventStream();
: (new LocalAssistantMessageEventStream() as unknown as AssistantMessageEventStream);
}
// ─────────────────────────────────────────────────────────────────────────────