refactor: harden msteams lifecycle and attachment flows

This commit is contained in:
Peter Steinberger
2026-03-02 21:18:22 +00:00
parent d98a61a977
commit 866bd91c65
15 changed files with 459 additions and 112 deletions

View File

@@ -6,12 +6,12 @@ import {
isDownloadableAttachment,
isRecord,
isUrlAllowed,
type MSTeamsAttachmentFetchPolicy,
normalizeContentType,
resolveMediaSsrfPolicy,
resolveAttachmentFetchPolicy,
resolveRequestUrl,
resolveAuthAllowedHosts,
resolveAllowedHosts,
safeFetch,
safeFetchWithPolicy,
} from "./shared.js";
import type {
MSTeamsAccessTokenProvider,
@@ -95,12 +95,11 @@ async function fetchWithAuthFallback(params: {
tokenProvider?: MSTeamsAccessTokenProvider;
fetchFn?: typeof fetch;
requestInit?: RequestInit;
allowHosts: string[];
authAllowHosts: string[];
policy: MSTeamsAttachmentFetchPolicy;
}): Promise<Response> {
const firstAttempt = await safeFetch({
const firstAttempt = await safeFetchWithPolicy({
url: params.url,
allowHosts: params.allowHosts,
policy: params.policy,
fetchFn: params.fetchFn,
requestInit: params.requestInit,
});
@@ -113,7 +112,7 @@ async function fetchWithAuthFallback(params: {
if (firstAttempt.status !== 401 && firstAttempt.status !== 403) {
return firstAttempt;
}
if (!isUrlAllowed(params.url, params.authAllowHosts)) {
if (!isUrlAllowed(params.url, params.policy.authAllowHosts)) {
return firstAttempt;
}
@@ -124,10 +123,9 @@ async function fetchWithAuthFallback(params: {
const token = await params.tokenProvider.getAccessToken(scope);
const authHeaders = new Headers(params.requestInit?.headers);
authHeaders.set("Authorization", `Bearer ${token}`);
const authAttempt = await safeFetch({
const authAttempt = await safeFetchWithPolicy({
url: params.url,
allowHosts: params.allowHosts,
authorizationAllowHosts: params.authAllowHosts,
policy: params.policy,
fetchFn,
requestInit: {
...params.requestInit,
@@ -171,8 +169,11 @@ export async function downloadMSTeamsAttachments(params: {
if (list.length === 0) {
return [];
}
const allowHosts = resolveAllowedHosts(params.allowHosts);
const authAllowHosts = resolveAuthAllowedHosts(params.authAllowHosts);
const policy = resolveAttachmentFetchPolicy({
allowHosts: params.allowHosts,
authAllowHosts: params.authAllowHosts,
});
const allowHosts = policy.allowHosts;
const ssrfPolicy = resolveMediaSsrfPolicy(allowHosts);
// Download ANY downloadable attachment (not just images)
@@ -249,8 +250,7 @@ export async function downloadMSTeamsAttachments(params: {
tokenProvider: params.tokenProvider,
fetchFn: params.fetchFn,
requestInit: init,
allowHosts,
authAllowHosts,
policy,
}),
});
out.push(media);

View File

@@ -3,16 +3,17 @@ import { getMSTeamsRuntime } from "../runtime.js";
import { downloadMSTeamsAttachments } from "./download.js";
import { downloadAndStoreMSTeamsRemoteMedia } from "./remote-media.js";
import {
applyAuthorizationHeaderForUrl,
GRAPH_ROOT,
inferPlaceholder,
isRecord,
isUrlAllowed,
type MSTeamsAttachmentFetchPolicy,
normalizeContentType,
resolveMediaSsrfPolicy,
resolveAttachmentFetchPolicy,
resolveRequestUrl,
resolveAuthAllowedHosts,
resolveAllowedHosts,
safeFetch,
safeFetchWithPolicy,
} from "./shared.js";
import type {
MSTeamsAccessTokenProvider,
@@ -243,9 +244,11 @@ export async function downloadMSTeamsGraphMedia(params: {
if (!params.messageUrl || !params.tokenProvider) {
return { media: [] };
}
const allowHosts = resolveAllowedHosts(params.allowHosts);
const authAllowHosts = resolveAuthAllowedHosts(params.authAllowHosts);
const ssrfPolicy = resolveMediaSsrfPolicy(allowHosts);
const policy: MSTeamsAttachmentFetchPolicy = resolveAttachmentFetchPolicy({
allowHosts: params.allowHosts,
authAllowHosts: params.authAllowHosts,
});
const ssrfPolicy = resolveMediaSsrfPolicy(policy.allowHosts);
const messageUrl = params.messageUrl;
let accessToken: string;
try {
@@ -291,7 +294,7 @@ export async function downloadMSTeamsGraphMedia(params: {
try {
// SharePoint URLs need to be accessed via Graph shares API
const shareUrl = att.contentUrl!;
if (!isUrlAllowed(shareUrl, allowHosts)) {
if (!isUrlAllowed(shareUrl, policy.allowHosts)) {
continue;
}
const encodedUrl = Buffer.from(shareUrl).toString("base64url");
@@ -307,15 +310,15 @@ export async function downloadMSTeamsGraphMedia(params: {
fetchImpl: async (input, init) => {
const requestUrl = resolveRequestUrl(input);
const headers = new Headers(init?.headers);
if (isUrlAllowed(requestUrl, authAllowHosts)) {
headers.set("Authorization", `Bearer ${accessToken}`);
} else {
headers.delete("Authorization");
}
return await safeFetch({
applyAuthorizationHeaderForUrl({
headers,
url: requestUrl,
allowHosts,
authorizationAllowHosts: authAllowHosts,
authAllowHosts: policy.authAllowHosts,
bearerToken: accessToken,
});
return await safeFetchWithPolicy({
url: requestUrl,
policy,
fetchFn,
requestInit: {
...init,
@@ -373,8 +376,8 @@ export async function downloadMSTeamsGraphMedia(params: {
attachments: filteredAttachments,
maxBytes: params.maxBytes,
tokenProvider: params.tokenProvider,
allowHosts,
authAllowHosts,
allowHosts: policy.allowHosts,
authAllowHosts: policy.authAllowHosts,
fetchFn: params.fetchFn,
preserveFilenames: params.preserveFilenames,
});

View File

@@ -1,12 +1,15 @@
import { describe, expect, it, vi } from "vitest";
import {
applyAuthorizationHeaderForUrl,
isPrivateOrReservedIP,
isUrlAllowed,
resolveAndValidateIP,
resolveAttachmentFetchPolicy,
resolveAllowedHosts,
resolveAuthAllowedHosts,
resolveMediaSsrfPolicy,
safeFetch,
safeFetchWithPolicy,
} from "./shared.js";
const publicResolve = async () => ({ address: "13.107.136.10" });
@@ -34,6 +37,18 @@ describe("msteams attachment allowlists", () => {
expect(resolveAuthAllowedHosts(["*", "graph.microsoft.com"])).toEqual(["*"]);
});
it("resolves a normalized attachment fetch policy", () => {
expect(
resolveAttachmentFetchPolicy({
allowHosts: ["sharepoint.com"],
authAllowHosts: ["graph.microsoft.com"],
}),
).toEqual({
allowHosts: ["sharepoint.com"],
authAllowHosts: ["graph.microsoft.com"],
});
});
it("requires https and host suffix match", () => {
const allowHosts = resolveAllowedHosts(["sharepoint.com"]);
expect(isUrlAllowed("https://contoso.sharepoint.com/file.png", allowHosts)).toBe(true);
@@ -294,4 +309,70 @@ describe("safeFetch", () => {
}),
).rejects.toThrow("blocked by allowlist");
});
it("strips authorization across redirects outside auth allowlist", async () => {
const seenAuth: string[] = [];
const fetchMock = vi.fn(async (url: string, init?: RequestInit) => {
const auth = new Headers(init?.headers).get("authorization") ?? "";
seenAuth.push(`${url}|${auth}`);
if (url === "https://teams.sharepoint.com/file.pdf") {
return new Response(null, {
status: 302,
headers: { location: "https://cdn.sharepoint.com/storage/file.pdf" },
});
}
return new Response("ok", { status: 200 });
});
const headers = new Headers({ Authorization: "Bearer secret" });
const res = await safeFetch({
url: "https://teams.sharepoint.com/file.pdf",
allowHosts: ["sharepoint.com"],
authorizationAllowHosts: ["graph.microsoft.com"],
fetchFn: fetchMock as unknown as typeof fetch,
requestInit: { headers },
resolveFn: publicResolve,
});
expect(res.status).toBe(200);
expect(seenAuth[0]).toContain("Bearer secret");
expect(seenAuth[1]).toMatch(/\|$/);
});
});
describe("attachment fetch auth helpers", () => {
it("sets and clears authorization header by auth allowlist", () => {
const headers = new Headers();
applyAuthorizationHeaderForUrl({
headers,
url: "https://graph.microsoft.com/v1.0/me",
authAllowHosts: ["graph.microsoft.com"],
bearerToken: "token-1",
});
expect(headers.get("authorization")).toBe("Bearer token-1");
applyAuthorizationHeaderForUrl({
headers,
url: "https://evil.example.com/collect",
authAllowHosts: ["graph.microsoft.com"],
bearerToken: "token-1",
});
expect(headers.get("authorization")).toBeNull();
});
it("safeFetchWithPolicy forwards policy allowlists", async () => {
const fetchMock = vi.fn(async (_url: string, _init?: RequestInit) => {
return new Response("ok", { status: 200 });
});
const res = await safeFetchWithPolicy({
url: "https://teams.sharepoint.com/file.pdf",
policy: resolveAttachmentFetchPolicy({
allowHosts: ["sharepoint.com"],
authAllowHosts: ["graph.microsoft.com"],
}),
fetchFn: fetchMock as unknown as typeof fetch,
resolveFn: publicResolve,
});
expect(res.status).toBe(200);
expect(fetchMock).toHaveBeenCalledOnce();
});
});

View File

@@ -266,10 +266,42 @@ export function resolveAuthAllowedHosts(input?: string[]): string[] {
return normalizeHostnameSuffixAllowlist(input, DEFAULT_MEDIA_AUTH_HOST_ALLOWLIST);
}
export type MSTeamsAttachmentFetchPolicy = {
allowHosts: string[];
authAllowHosts: string[];
};
export function resolveAttachmentFetchPolicy(params?: {
allowHosts?: string[];
authAllowHosts?: string[];
}): MSTeamsAttachmentFetchPolicy {
return {
allowHosts: resolveAllowedHosts(params?.allowHosts),
authAllowHosts: resolveAuthAllowedHosts(params?.authAllowHosts),
};
}
export function isUrlAllowed(url: string, allowlist: string[]): boolean {
return isHttpsUrlAllowedByHostnameSuffixAllowlist(url, allowlist);
}
export function applyAuthorizationHeaderForUrl(params: {
headers: Headers;
url: string;
authAllowHosts: string[];
bearerToken?: string;
}): void {
if (!params.bearerToken) {
params.headers.delete("Authorization");
return;
}
if (isUrlAllowed(params.url, params.authAllowHosts)) {
params.headers.set("Authorization", `Bearer ${params.bearerToken}`);
return;
}
params.headers.delete("Authorization");
}
export function resolveMediaSsrfPolicy(allowHosts: string[]): SsrFPolicy | undefined {
return buildHostnameAllowlistPolicyFromSuffixAllowlist(allowHosts);
}
@@ -408,3 +440,20 @@ export async function safeFetch(params: {
throw new Error(`Too many redirects (>${MAX_SAFE_REDIRECTS})`);
}
export async function safeFetchWithPolicy(params: {
url: string;
policy: MSTeamsAttachmentFetchPolicy;
fetchFn?: typeof fetch;
requestInit?: RequestInit;
resolveFn?: (hostname: string) => Promise<{ address: string }>;
}): Promise<Response> {
return await safeFetch({
url: params.url,
allowHosts: params.policy.allowHosts,
authorizationAllowHosts: params.policy.authAllowHosts,
fetchFn: params.fetchFn,
requestInit: params.requestInit,
resolveFn: params.resolveFn,
});
}

View File

@@ -10,7 +10,7 @@ import {
} from "openclaw/plugin-sdk";
import type { MSTeamsAccessTokenProvider } from "./attachments/types.js";
import type { StoredConversationReference } from "./conversation-store.js";
import { classifyMSTeamsSendError, isRevokedProxyError } from "./errors.js";
import { classifyMSTeamsSendError } from "./errors.js";
import { prepareFileConsentActivity, requiresFileConsent } from "./file-consent-helpers.js";
import { buildTeamsFileInfoCard } from "./graph-chat.js";
import {
@@ -20,6 +20,7 @@ import {
} from "./graph-upload.js";
import { extractFilename, extractMessageId, getMimeType, isLocalPath } from "./media-helpers.js";
import { parseMentions } from "./mentions.js";
import { withRevokedProxyFallback } from "./revoked-context.js";
import { getMSTeamsRuntime } from "./runtime.js";
/**
@@ -441,34 +442,42 @@ export async function sendMSTeamsMessages(params: {
}
};
const sendMessagesInContext = async (
const sendMessageInContext = async (
ctx: SendContext,
batch: MSTeamsRenderedMessage[] = messages,
offset = 0,
message: MSTeamsRenderedMessage,
messageIndex: number,
): Promise<string> => {
const response = await sendWithRetry(
async () =>
await ctx.sendActivity(
await buildActivity(
message,
params.conversationRef,
params.tokenProvider,
params.sharePointSiteId,
params.mediaMaxBytes,
),
),
{ messageIndex, messageCount: messages.length },
);
return extractMessageId(response) ?? "unknown";
};
const sendMessageBatchInContext = async (
ctx: SendContext,
batch: MSTeamsRenderedMessage[],
startIndex: number,
): Promise<string[]> => {
const messageIds: string[] = [];
for (const [idx, message] of batch.entries()) {
const response = await sendWithRetry(
async () =>
await ctx.sendActivity(
await buildActivity(
message,
params.conversationRef,
params.tokenProvider,
params.sharePointSiteId,
params.mediaMaxBytes,
),
),
{ messageIndex: offset + idx, messageCount: messages.length },
);
messageIds.push(extractMessageId(response) ?? "unknown");
messageIds.push(await sendMessageInContext(ctx, message, startIndex + idx));
}
return messageIds;
};
const sendProactively = async (
batch: MSTeamsRenderedMessage[] = messages,
offset = 0,
batch: MSTeamsRenderedMessage[],
startIndex: number,
): Promise<string[]> => {
const baseRef = buildConversationReference(params.conversationRef);
const proactiveRef: MSTeamsConversationReference = {
@@ -478,7 +487,7 @@ export async function sendMSTeamsMessages(params: {
const messageIds: string[] = [];
await params.adapter.continueConversation(params.appId, proactiveRef, async (ctx) => {
messageIds.push(...(await sendMessagesInContext(ctx, batch, offset)));
messageIds.push(...(await sendMessageBatchInContext(ctx, batch, startIndex)));
});
return messageIds;
};
@@ -490,16 +499,21 @@ export async function sendMSTeamsMessages(params: {
}
const messageIds: string[] = [];
for (const [idx, message] of messages.entries()) {
try {
messageIds.push(...(await sendMessagesInContext(ctx, [message], idx)));
} catch (err) {
if (!isRevokedProxyError(err)) {
throw err;
}
const remaining = messages.slice(idx);
if (remaining.length > 0) {
messageIds.push(...(await sendProactively(remaining, idx)));
}
const result = await withRevokedProxyFallback({
run: async () => ({
ids: [await sendMessageInContext(ctx, message, idx)],
fellBack: false,
}),
onRevoked: async () => {
const remaining = messages.slice(idx);
return {
ids: remaining.length > 0 ? await sendProactively(remaining, idx) : [],
fellBack: true,
};
},
});
messageIds.push(...result.ids);
if (result.fellBack) {
return messageIds;
}
}

View File

@@ -155,10 +155,7 @@ describe("msteams file consent invoke authz", () => {
}),
);
// Wait for async upload to complete
await vi.waitFor(() => {
expect(fileConsentMockState.uploadToConsentUrl).toHaveBeenCalledTimes(1);
});
expect(fileConsentMockState.uploadToConsentUrl).toHaveBeenCalledTimes(1);
expect(fileConsentMockState.uploadToConsentUrl).toHaveBeenCalledWith(
expect.objectContaining({
@@ -192,12 +189,9 @@ describe("msteams file consent invoke authz", () => {
}),
);
// Wait for async handler to complete
await vi.waitFor(() => {
expect(sendActivity).toHaveBeenCalledWith(
"The file upload request has expired. Please try sending the file again.",
);
});
expect(sendActivity).toHaveBeenCalledWith(
"The file upload request has expired. Please try sending the file again.",
);
expect(fileConsentMockState.uploadToConsentUrl).not.toHaveBeenCalled();
expect(getPendingUpload(uploadId)).toBeDefined();

View File

@@ -7,6 +7,7 @@ import { createMSTeamsMessageHandler } from "./monitor-handler/message-handler.j
import type { MSTeamsMonitorLogger } from "./monitor-types.js";
import { getPendingUpload, removePendingUpload } from "./pending-uploads.js";
import type { MSTeamsPollStore } from "./polls.js";
import { withRevokedProxyFallback } from "./revoked-context.js";
import type { MSTeamsTurnContext } from "./sdk-types.js";
export type MSTeamsAccessTokenProvider = {
@@ -146,10 +147,19 @@ export function registerMSTeamsHandlers<T extends MSTeamsActivityHandler>(
// Send invoke response IMMEDIATELY to prevent Teams timeout
await ctx.sendActivity({ type: "invokeResponse", value: { status: 200 } });
// Handle file upload asynchronously (don't await)
handleFileConsentInvoke(ctx, deps.log).catch((err) => {
try {
await withRevokedProxyFallback({
run: async () => await handleFileConsentInvoke(ctx, deps.log),
onRevoked: async () => true,
onRevokedLog: () => {
deps.log.debug?.(
"turn context revoked during file consent invoke; skipping delayed response",
);
},
});
} catch (err) {
deps.log.debug?.("file consent handler error", { error: String(err) });
});
}
return;
}
return originalRun.call(handler, context);

View File

@@ -6,6 +6,9 @@ import type { MSTeamsPollStore } from "./polls.js";
type FakeServer = EventEmitter & {
close: (callback?: (err?: Error | null) => void) => void;
setTimeout: (msecs: number) => FakeServer;
requestTimeout: number;
headersTimeout: number;
};
const expressControl = vi.hoisted(() => ({
@@ -14,6 +17,18 @@ const expressControl = vi.hoisted(() => ({
vi.mock("openclaw/plugin-sdk", () => ({
DEFAULT_WEBHOOK_MAX_BODY_BYTES: 1024 * 1024,
keepHttpServerTaskAlive: vi.fn(
async (params: { abortSignal?: AbortSignal; onAbort?: () => Promise<void> | void }) => {
await new Promise<void>((resolve) => {
if (params.abortSignal?.aborted) {
resolve();
return;
}
params.abortSignal?.addEventListener("abort", () => resolve(), { once: true });
});
await params.onAbort?.();
},
),
mergeAllowlist: (params: { existing?: string[]; additions?: string[] }) =>
Array.from(new Set([...(params.existing ?? []), ...(params.additions ?? [])])),
summarizeMapping: vi.fn(),
@@ -31,6 +46,9 @@ vi.mock("express", () => {
post: vi.fn(),
listen: vi.fn((_port: number) => {
const server = new EventEmitter() as FakeServer;
server.setTimeout = vi.fn((_msecs: number) => server);
server.requestTimeout = 0;
server.headersTimeout = 0;
server.close = (callback?: (err?: Error | null) => void) => {
queueMicrotask(() => {
server.emit("close");

View File

@@ -2,6 +2,7 @@ import type { Server } from "node:http";
import type { Request, Response } from "express";
import {
DEFAULT_WEBHOOK_MAX_BODY_BYTES,
keepHttpServerTaskAlive,
mergeAllowlist,
summarizeMapping,
type OpenClawConfig,
@@ -333,25 +334,12 @@ export async function monitorMSTeamsProvider(
});
};
// Handle abort signal
const onAbort = () => {
void shutdown();
};
if (opts.abortSignal) {
if (opts.abortSignal.aborted) {
onAbort();
} else {
opts.abortSignal.addEventListener("abort", onAbort, { once: true });
}
}
// Keep this task alive until shutdown/close so gateway runtime does not treat startup as exit.
await new Promise<void>((resolve) => {
httpServer.once("close", () => {
resolve();
});
// Keep this task alive until close so gateway runtime does not treat startup as exit.
await keepHttpServerTaskAlive({
server: httpServer,
abortSignal: opts.abortSignal,
onAbort: shutdown,
});
opts.abortSignal?.removeEventListener("abort", onAbort);
return { app: expressApp, shutdown };
}

View File

@@ -13,7 +13,6 @@ import {
classifyMSTeamsSendError,
formatMSTeamsSendErrorHint,
formatUnknownError,
isRevokedProxyError,
} from "./errors.js";
import {
buildConversationReference,
@@ -22,6 +21,7 @@ import {
sendMSTeamsMessages,
} from "./messenger.js";
import type { MSTeamsMonitorLogger } from "./monitor-types.js";
import { withRevokedProxyFallback } from "./revoked-context.js";
import { getMSTeamsRuntime } from "./runtime.js";
import type { MSTeamsTurnContext } from "./sdk-types.js";
@@ -53,23 +53,24 @@ export function createMSTeamsReplyDispatcher(params: {
* the stored conversation reference so the user still sees the "…" bubble.
*/
const sendTypingIndicator = async () => {
try {
await params.context.sendActivity({ type: "typing" });
} catch (err) {
if (!isRevokedProxyError(err)) {
throw err;
}
// Turn context revoked — fall back to proactive typing.
params.log.debug?.("turn context revoked, sending typing via proactive messaging");
const baseRef = buildConversationReference(params.conversationRef);
await params.adapter.continueConversation(
params.appId,
{ ...baseRef, activityId: undefined },
async (ctx) => {
await ctx.sendActivity({ type: "typing" });
},
);
}
await withRevokedProxyFallback({
run: async () => {
await params.context.sendActivity({ type: "typing" });
},
onRevoked: async () => {
const baseRef = buildConversationReference(params.conversationRef);
await params.adapter.continueConversation(
params.appId,
{ ...baseRef, activityId: undefined },
async (ctx) => {
await ctx.sendActivity({ type: "typing" });
},
);
},
onRevokedLog: () => {
params.log.debug?.("turn context revoked, sending typing via proactive messaging");
},
});
};
const typingCallbacks = createTypingCallbacks({

View File

@@ -0,0 +1,39 @@
import { describe, expect, it, vi } from "vitest";
import { withRevokedProxyFallback } from "./revoked-context.js";
describe("msteams revoked context helper", () => {
it("returns primary result when no error occurs", async () => {
await expect(
withRevokedProxyFallback({
run: async () => "ok",
onRevoked: async () => "fallback",
}),
).resolves.toBe("ok");
});
it("uses fallback when proxy-revoked TypeError is thrown", async () => {
const onRevokedLog = vi.fn();
await expect(
withRevokedProxyFallback({
run: async () => {
throw new TypeError("Cannot perform 'get' on a proxy that has been revoked");
},
onRevoked: async () => "fallback",
onRevokedLog,
}),
).resolves.toBe("fallback");
expect(onRevokedLog).toHaveBeenCalledOnce();
});
it("rethrows non-revoked errors", async () => {
const err = Object.assign(new Error("boom"), { statusCode: 500 });
await expect(
withRevokedProxyFallback({
run: async () => {
throw err;
},
onRevoked: async () => "fallback",
}),
).rejects.toBe(err);
});
});

View File

@@ -0,0 +1,17 @@
import { isRevokedProxyError } from "./errors.js";
export async function withRevokedProxyFallback<T>(params: {
run: () => Promise<T>;
onRevoked: () => Promise<T>;
onRevokedLog?: () => void;
}): Promise<T> {
try {
return await params.run();
} catch (err) {
if (!isRevokedProxyError(err)) {
throw err;
}
params.onRevokedLog?.();
return await params.onRevoked();
}
}

View File

@@ -0,0 +1,66 @@
import { EventEmitter } from "node:events";
import { describe, expect, it, vi } from "vitest";
import { keepHttpServerTaskAlive, waitUntilAbort } from "./channel-lifecycle.js";
type FakeServer = EventEmitter & {
close: (callback?: () => void) => void;
};
function createFakeServer(): FakeServer {
const server = new EventEmitter() as FakeServer;
server.close = (callback) => {
queueMicrotask(() => {
server.emit("close");
callback?.();
});
};
return server;
}
describe("plugin-sdk channel lifecycle helpers", () => {
it("resolves waitUntilAbort when signal aborts", async () => {
const abort = new AbortController();
const task = waitUntilAbort(abort.signal);
const early = await Promise.race([
task.then(() => "resolved"),
new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)),
]);
expect(early).toBe("pending");
abort.abort();
await expect(task).resolves.toBeUndefined();
});
it("keeps server task pending until close, then resolves", async () => {
const server = createFakeServer();
const task = keepHttpServerTaskAlive({ server });
const early = await Promise.race([
task.then(() => "resolved"),
new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)),
]);
expect(early).toBe("pending");
server.close();
await expect(task).resolves.toBeUndefined();
});
it("triggers abort hook once and resolves after close", async () => {
const server = createFakeServer();
const abort = new AbortController();
const onAbort = vi.fn(async () => {
server.close();
});
const task = keepHttpServerTaskAlive({
server,
abortSignal: abort.signal,
onAbort,
});
abort.abort();
await expect(task).resolves.toBeUndefined();
expect(onAbort).toHaveBeenCalledOnce();
});
});

View File

@@ -0,0 +1,66 @@
type CloseAwareServer = {
once: (event: "close", listener: () => void) => unknown;
};
/**
* Return a promise that resolves when the signal is aborted.
*
* If no signal is provided, the promise stays pending forever.
*/
export function waitUntilAbort(signal?: AbortSignal): Promise<void> {
return new Promise<void>((resolve) => {
if (!signal) {
return;
}
if (signal.aborted) {
resolve();
return;
}
signal.addEventListener("abort", () => resolve(), { once: true });
});
}
/**
* Keep a channel/provider task pending until the HTTP server closes.
*
* When an abort signal is provided, `onAbort` is invoked once and should
* trigger server shutdown. The returned promise resolves only after `close`.
*/
export async function keepHttpServerTaskAlive(params: {
server: CloseAwareServer;
abortSignal?: AbortSignal;
onAbort?: () => void | Promise<void>;
}): Promise<void> {
const { server, abortSignal, onAbort } = params;
let abortTask: Promise<void> = Promise.resolve();
let abortTriggered = false;
const triggerAbort = () => {
if (abortTriggered) {
return;
}
abortTriggered = true;
abortTask = Promise.resolve(onAbort?.()).then(() => undefined);
};
const onAbortSignal = () => {
triggerAbort();
};
if (abortSignal) {
if (abortSignal.aborted) {
triggerAbort();
} else {
abortSignal.addEventListener("abort", onAbortSignal, { once: true });
}
}
await new Promise<void>((resolve) => {
server.once("close", () => resolve());
});
if (abortSignal) {
abortSignal.removeEventListener("abort", onAbortSignal);
}
await abortTask;
}

View File

@@ -149,6 +149,7 @@ export {
WEBHOOK_IN_FLIGHT_DEFAULTS,
} from "./webhook-request-guards.js";
export type { WebhookBodyReadProfile, WebhookInFlightLimiter } from "./webhook-request-guards.js";
export { keepHttpServerTaskAlive, waitUntilAbort } from "./channel-lifecycle.js";
export type { AgentMediaPayload } from "./agent-media-payload.js";
export { buildAgentMediaPayload } from "./agent-media-payload.js";
export {