From 866bd91c659ca558a40356fdeede3c3dfe8e816f Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Mon, 2 Mar 2026 21:18:22 +0000 Subject: [PATCH] refactor: harden msteams lifecycle and attachment flows --- .../msteams/src/attachments/download.ts | 30 +++---- extensions/msteams/src/attachments/graph.ts | 37 +++++---- .../msteams/src/attachments/shared.test.ts | 81 +++++++++++++++++++ extensions/msteams/src/attachments/shared.ts | 49 +++++++++++ extensions/msteams/src/messenger.ts | 76 ++++++++++------- .../src/monitor-handler.file-consent.test.ts | 14 +--- extensions/msteams/src/monitor-handler.ts | 16 +++- .../msteams/src/monitor.lifecycle.test.ts | 18 +++++ extensions/msteams/src/monitor.ts | 24 ++---- extensions/msteams/src/reply-dispatcher.ts | 37 ++++----- .../msteams/src/revoked-context.test.ts | 39 +++++++++ extensions/msteams/src/revoked-context.ts | 17 ++++ src/plugin-sdk/channel-lifecycle.test.ts | 66 +++++++++++++++ src/plugin-sdk/channel-lifecycle.ts | 66 +++++++++++++++ src/plugin-sdk/index.ts | 1 + 15 files changed, 459 insertions(+), 112 deletions(-) create mode 100644 extensions/msteams/src/revoked-context.test.ts create mode 100644 extensions/msteams/src/revoked-context.ts create mode 100644 src/plugin-sdk/channel-lifecycle.test.ts create mode 100644 src/plugin-sdk/channel-lifecycle.ts diff --git a/extensions/msteams/src/attachments/download.ts b/extensions/msteams/src/attachments/download.ts index f40de08ece6..5a982df1b9f 100644 --- a/extensions/msteams/src/attachments/download.ts +++ b/extensions/msteams/src/attachments/download.ts @@ -6,12 +6,12 @@ import { isDownloadableAttachment, isRecord, isUrlAllowed, + type MSTeamsAttachmentFetchPolicy, normalizeContentType, resolveMediaSsrfPolicy, + resolveAttachmentFetchPolicy, resolveRequestUrl, - resolveAuthAllowedHosts, - resolveAllowedHosts, - safeFetch, + safeFetchWithPolicy, } from "./shared.js"; import type { MSTeamsAccessTokenProvider, @@ -95,12 +95,11 @@ async function fetchWithAuthFallback(params: { tokenProvider?: MSTeamsAccessTokenProvider; fetchFn?: typeof fetch; requestInit?: RequestInit; - allowHosts: string[]; - authAllowHosts: string[]; + policy: MSTeamsAttachmentFetchPolicy; }): Promise { - const firstAttempt = await safeFetch({ + const firstAttempt = await safeFetchWithPolicy({ url: params.url, - allowHosts: params.allowHosts, + policy: params.policy, fetchFn: params.fetchFn, requestInit: params.requestInit, }); @@ -113,7 +112,7 @@ async function fetchWithAuthFallback(params: { if (firstAttempt.status !== 401 && firstAttempt.status !== 403) { return firstAttempt; } - if (!isUrlAllowed(params.url, params.authAllowHosts)) { + if (!isUrlAllowed(params.url, params.policy.authAllowHosts)) { return firstAttempt; } @@ -124,10 +123,9 @@ async function fetchWithAuthFallback(params: { const token = await params.tokenProvider.getAccessToken(scope); const authHeaders = new Headers(params.requestInit?.headers); authHeaders.set("Authorization", `Bearer ${token}`); - const authAttempt = await safeFetch({ + const authAttempt = await safeFetchWithPolicy({ url: params.url, - allowHosts: params.allowHosts, - authorizationAllowHosts: params.authAllowHosts, + policy: params.policy, fetchFn, requestInit: { ...params.requestInit, @@ -171,8 +169,11 @@ export async function downloadMSTeamsAttachments(params: { if (list.length === 0) { return []; } - const allowHosts = resolveAllowedHosts(params.allowHosts); - const authAllowHosts = resolveAuthAllowedHosts(params.authAllowHosts); + const policy = resolveAttachmentFetchPolicy({ + allowHosts: params.allowHosts, + authAllowHosts: params.authAllowHosts, + }); + const allowHosts = policy.allowHosts; const ssrfPolicy = resolveMediaSsrfPolicy(allowHosts); // Download ANY downloadable attachment (not just images) @@ -249,8 +250,7 @@ export async function downloadMSTeamsAttachments(params: { tokenProvider: params.tokenProvider, fetchFn: params.fetchFn, requestInit: init, - allowHosts, - authAllowHosts, + policy, }), }); out.push(media); diff --git a/extensions/msteams/src/attachments/graph.ts b/extensions/msteams/src/attachments/graph.ts index f921a2cfa04..a50356e3ced 100644 --- a/extensions/msteams/src/attachments/graph.ts +++ b/extensions/msteams/src/attachments/graph.ts @@ -3,16 +3,17 @@ import { getMSTeamsRuntime } from "../runtime.js"; import { downloadMSTeamsAttachments } from "./download.js"; import { downloadAndStoreMSTeamsRemoteMedia } from "./remote-media.js"; import { + applyAuthorizationHeaderForUrl, GRAPH_ROOT, inferPlaceholder, isRecord, isUrlAllowed, + type MSTeamsAttachmentFetchPolicy, normalizeContentType, resolveMediaSsrfPolicy, + resolveAttachmentFetchPolicy, resolveRequestUrl, - resolveAuthAllowedHosts, - resolveAllowedHosts, - safeFetch, + safeFetchWithPolicy, } from "./shared.js"; import type { MSTeamsAccessTokenProvider, @@ -243,9 +244,11 @@ export async function downloadMSTeamsGraphMedia(params: { if (!params.messageUrl || !params.tokenProvider) { return { media: [] }; } - const allowHosts = resolveAllowedHosts(params.allowHosts); - const authAllowHosts = resolveAuthAllowedHosts(params.authAllowHosts); - const ssrfPolicy = resolveMediaSsrfPolicy(allowHosts); + const policy: MSTeamsAttachmentFetchPolicy = resolveAttachmentFetchPolicy({ + allowHosts: params.allowHosts, + authAllowHosts: params.authAllowHosts, + }); + const ssrfPolicy = resolveMediaSsrfPolicy(policy.allowHosts); const messageUrl = params.messageUrl; let accessToken: string; try { @@ -291,7 +294,7 @@ export async function downloadMSTeamsGraphMedia(params: { try { // SharePoint URLs need to be accessed via Graph shares API const shareUrl = att.contentUrl!; - if (!isUrlAllowed(shareUrl, allowHosts)) { + if (!isUrlAllowed(shareUrl, policy.allowHosts)) { continue; } const encodedUrl = Buffer.from(shareUrl).toString("base64url"); @@ -307,15 +310,15 @@ export async function downloadMSTeamsGraphMedia(params: { fetchImpl: async (input, init) => { const requestUrl = resolveRequestUrl(input); const headers = new Headers(init?.headers); - if (isUrlAllowed(requestUrl, authAllowHosts)) { - headers.set("Authorization", `Bearer ${accessToken}`); - } else { - headers.delete("Authorization"); - } - return await safeFetch({ + applyAuthorizationHeaderForUrl({ + headers, url: requestUrl, - allowHosts, - authorizationAllowHosts: authAllowHosts, + authAllowHosts: policy.authAllowHosts, + bearerToken: accessToken, + }); + return await safeFetchWithPolicy({ + url: requestUrl, + policy, fetchFn, requestInit: { ...init, @@ -373,8 +376,8 @@ export async function downloadMSTeamsGraphMedia(params: { attachments: filteredAttachments, maxBytes: params.maxBytes, tokenProvider: params.tokenProvider, - allowHosts, - authAllowHosts, + allowHosts: policy.allowHosts, + authAllowHosts: policy.authAllowHosts, fetchFn: params.fetchFn, preserveFilenames: params.preserveFilenames, }); diff --git a/extensions/msteams/src/attachments/shared.test.ts b/extensions/msteams/src/attachments/shared.test.ts index 4aa1c0e8bab..186a70f71aa 100644 --- a/extensions/msteams/src/attachments/shared.test.ts +++ b/extensions/msteams/src/attachments/shared.test.ts @@ -1,12 +1,15 @@ import { describe, expect, it, vi } from "vitest"; import { + applyAuthorizationHeaderForUrl, isPrivateOrReservedIP, isUrlAllowed, resolveAndValidateIP, + resolveAttachmentFetchPolicy, resolveAllowedHosts, resolveAuthAllowedHosts, resolveMediaSsrfPolicy, safeFetch, + safeFetchWithPolicy, } from "./shared.js"; const publicResolve = async () => ({ address: "13.107.136.10" }); @@ -34,6 +37,18 @@ describe("msteams attachment allowlists", () => { expect(resolveAuthAllowedHosts(["*", "graph.microsoft.com"])).toEqual(["*"]); }); + it("resolves a normalized attachment fetch policy", () => { + expect( + resolveAttachmentFetchPolicy({ + allowHosts: ["sharepoint.com"], + authAllowHosts: ["graph.microsoft.com"], + }), + ).toEqual({ + allowHosts: ["sharepoint.com"], + authAllowHosts: ["graph.microsoft.com"], + }); + }); + it("requires https and host suffix match", () => { const allowHosts = resolveAllowedHosts(["sharepoint.com"]); expect(isUrlAllowed("https://contoso.sharepoint.com/file.png", allowHosts)).toBe(true); @@ -294,4 +309,70 @@ describe("safeFetch", () => { }), ).rejects.toThrow("blocked by allowlist"); }); + + it("strips authorization across redirects outside auth allowlist", async () => { + const seenAuth: string[] = []; + const fetchMock = vi.fn(async (url: string, init?: RequestInit) => { + const auth = new Headers(init?.headers).get("authorization") ?? ""; + seenAuth.push(`${url}|${auth}`); + if (url === "https://teams.sharepoint.com/file.pdf") { + return new Response(null, { + status: 302, + headers: { location: "https://cdn.sharepoint.com/storage/file.pdf" }, + }); + } + return new Response("ok", { status: 200 }); + }); + + const headers = new Headers({ Authorization: "Bearer secret" }); + const res = await safeFetch({ + url: "https://teams.sharepoint.com/file.pdf", + allowHosts: ["sharepoint.com"], + authorizationAllowHosts: ["graph.microsoft.com"], + fetchFn: fetchMock as unknown as typeof fetch, + requestInit: { headers }, + resolveFn: publicResolve, + }); + expect(res.status).toBe(200); + expect(seenAuth[0]).toContain("Bearer secret"); + expect(seenAuth[1]).toMatch(/\|$/); + }); +}); + +describe("attachment fetch auth helpers", () => { + it("sets and clears authorization header by auth allowlist", () => { + const headers = new Headers(); + applyAuthorizationHeaderForUrl({ + headers, + url: "https://graph.microsoft.com/v1.0/me", + authAllowHosts: ["graph.microsoft.com"], + bearerToken: "token-1", + }); + expect(headers.get("authorization")).toBe("Bearer token-1"); + + applyAuthorizationHeaderForUrl({ + headers, + url: "https://evil.example.com/collect", + authAllowHosts: ["graph.microsoft.com"], + bearerToken: "token-1", + }); + expect(headers.get("authorization")).toBeNull(); + }); + + it("safeFetchWithPolicy forwards policy allowlists", async () => { + const fetchMock = vi.fn(async (_url: string, _init?: RequestInit) => { + return new Response("ok", { status: 200 }); + }); + const res = await safeFetchWithPolicy({ + url: "https://teams.sharepoint.com/file.pdf", + policy: resolveAttachmentFetchPolicy({ + allowHosts: ["sharepoint.com"], + authAllowHosts: ["graph.microsoft.com"], + }), + fetchFn: fetchMock as unknown as typeof fetch, + resolveFn: publicResolve, + }); + expect(res.status).toBe(200); + expect(fetchMock).toHaveBeenCalledOnce(); + }); }); diff --git a/extensions/msteams/src/attachments/shared.ts b/extensions/msteams/src/attachments/shared.ts index 88ff64970b6..7897b52803e 100644 --- a/extensions/msteams/src/attachments/shared.ts +++ b/extensions/msteams/src/attachments/shared.ts @@ -266,10 +266,42 @@ export function resolveAuthAllowedHosts(input?: string[]): string[] { return normalizeHostnameSuffixAllowlist(input, DEFAULT_MEDIA_AUTH_HOST_ALLOWLIST); } +export type MSTeamsAttachmentFetchPolicy = { + allowHosts: string[]; + authAllowHosts: string[]; +}; + +export function resolveAttachmentFetchPolicy(params?: { + allowHosts?: string[]; + authAllowHosts?: string[]; +}): MSTeamsAttachmentFetchPolicy { + return { + allowHosts: resolveAllowedHosts(params?.allowHosts), + authAllowHosts: resolveAuthAllowedHosts(params?.authAllowHosts), + }; +} + export function isUrlAllowed(url: string, allowlist: string[]): boolean { return isHttpsUrlAllowedByHostnameSuffixAllowlist(url, allowlist); } +export function applyAuthorizationHeaderForUrl(params: { + headers: Headers; + url: string; + authAllowHosts: string[]; + bearerToken?: string; +}): void { + if (!params.bearerToken) { + params.headers.delete("Authorization"); + return; + } + if (isUrlAllowed(params.url, params.authAllowHosts)) { + params.headers.set("Authorization", `Bearer ${params.bearerToken}`); + return; + } + params.headers.delete("Authorization"); +} + export function resolveMediaSsrfPolicy(allowHosts: string[]): SsrFPolicy | undefined { return buildHostnameAllowlistPolicyFromSuffixAllowlist(allowHosts); } @@ -408,3 +440,20 @@ export async function safeFetch(params: { throw new Error(`Too many redirects (>${MAX_SAFE_REDIRECTS})`); } + +export async function safeFetchWithPolicy(params: { + url: string; + policy: MSTeamsAttachmentFetchPolicy; + fetchFn?: typeof fetch; + requestInit?: RequestInit; + resolveFn?: (hostname: string) => Promise<{ address: string }>; +}): Promise { + return await safeFetch({ + url: params.url, + allowHosts: params.policy.allowHosts, + authorizationAllowHosts: params.policy.authAllowHosts, + fetchFn: params.fetchFn, + requestInit: params.requestInit, + resolveFn: params.resolveFn, + }); +} diff --git a/extensions/msteams/src/messenger.ts b/extensions/msteams/src/messenger.ts index e421b8bf3eb..4a913192944 100644 --- a/extensions/msteams/src/messenger.ts +++ b/extensions/msteams/src/messenger.ts @@ -10,7 +10,7 @@ import { } from "openclaw/plugin-sdk"; import type { MSTeamsAccessTokenProvider } from "./attachments/types.js"; import type { StoredConversationReference } from "./conversation-store.js"; -import { classifyMSTeamsSendError, isRevokedProxyError } from "./errors.js"; +import { classifyMSTeamsSendError } from "./errors.js"; import { prepareFileConsentActivity, requiresFileConsent } from "./file-consent-helpers.js"; import { buildTeamsFileInfoCard } from "./graph-chat.js"; import { @@ -20,6 +20,7 @@ import { } from "./graph-upload.js"; import { extractFilename, extractMessageId, getMimeType, isLocalPath } from "./media-helpers.js"; import { parseMentions } from "./mentions.js"; +import { withRevokedProxyFallback } from "./revoked-context.js"; import { getMSTeamsRuntime } from "./runtime.js"; /** @@ -441,34 +442,42 @@ export async function sendMSTeamsMessages(params: { } }; - const sendMessagesInContext = async ( + const sendMessageInContext = async ( ctx: SendContext, - batch: MSTeamsRenderedMessage[] = messages, - offset = 0, + message: MSTeamsRenderedMessage, + messageIndex: number, + ): Promise => { + const response = await sendWithRetry( + async () => + await ctx.sendActivity( + await buildActivity( + message, + params.conversationRef, + params.tokenProvider, + params.sharePointSiteId, + params.mediaMaxBytes, + ), + ), + { messageIndex, messageCount: messages.length }, + ); + return extractMessageId(response) ?? "unknown"; + }; + + const sendMessageBatchInContext = async ( + ctx: SendContext, + batch: MSTeamsRenderedMessage[], + startIndex: number, ): Promise => { const messageIds: string[] = []; for (const [idx, message] of batch.entries()) { - const response = await sendWithRetry( - async () => - await ctx.sendActivity( - await buildActivity( - message, - params.conversationRef, - params.tokenProvider, - params.sharePointSiteId, - params.mediaMaxBytes, - ), - ), - { messageIndex: offset + idx, messageCount: messages.length }, - ); - messageIds.push(extractMessageId(response) ?? "unknown"); + messageIds.push(await sendMessageInContext(ctx, message, startIndex + idx)); } return messageIds; }; const sendProactively = async ( - batch: MSTeamsRenderedMessage[] = messages, - offset = 0, + batch: MSTeamsRenderedMessage[], + startIndex: number, ): Promise => { const baseRef = buildConversationReference(params.conversationRef); const proactiveRef: MSTeamsConversationReference = { @@ -478,7 +487,7 @@ export async function sendMSTeamsMessages(params: { const messageIds: string[] = []; await params.adapter.continueConversation(params.appId, proactiveRef, async (ctx) => { - messageIds.push(...(await sendMessagesInContext(ctx, batch, offset))); + messageIds.push(...(await sendMessageBatchInContext(ctx, batch, startIndex))); }); return messageIds; }; @@ -490,16 +499,21 @@ export async function sendMSTeamsMessages(params: { } const messageIds: string[] = []; for (const [idx, message] of messages.entries()) { - try { - messageIds.push(...(await sendMessagesInContext(ctx, [message], idx))); - } catch (err) { - if (!isRevokedProxyError(err)) { - throw err; - } - const remaining = messages.slice(idx); - if (remaining.length > 0) { - messageIds.push(...(await sendProactively(remaining, idx))); - } + const result = await withRevokedProxyFallback({ + run: async () => ({ + ids: [await sendMessageInContext(ctx, message, idx)], + fellBack: false, + }), + onRevoked: async () => { + const remaining = messages.slice(idx); + return { + ids: remaining.length > 0 ? await sendProactively(remaining, idx) : [], + fellBack: true, + }; + }, + }); + messageIds.push(...result.ids); + if (result.fellBack) { return messageIds; } } diff --git a/extensions/msteams/src/monitor-handler.file-consent.test.ts b/extensions/msteams/src/monitor-handler.file-consent.test.ts index 1fc6714a451..386ffc34853 100644 --- a/extensions/msteams/src/monitor-handler.file-consent.test.ts +++ b/extensions/msteams/src/monitor-handler.file-consent.test.ts @@ -155,10 +155,7 @@ describe("msteams file consent invoke authz", () => { }), ); - // Wait for async upload to complete - await vi.waitFor(() => { - expect(fileConsentMockState.uploadToConsentUrl).toHaveBeenCalledTimes(1); - }); + expect(fileConsentMockState.uploadToConsentUrl).toHaveBeenCalledTimes(1); expect(fileConsentMockState.uploadToConsentUrl).toHaveBeenCalledWith( expect.objectContaining({ @@ -192,12 +189,9 @@ describe("msteams file consent invoke authz", () => { }), ); - // Wait for async handler to complete - await vi.waitFor(() => { - expect(sendActivity).toHaveBeenCalledWith( - "The file upload request has expired. Please try sending the file again.", - ); - }); + expect(sendActivity).toHaveBeenCalledWith( + "The file upload request has expired. Please try sending the file again.", + ); expect(fileConsentMockState.uploadToConsentUrl).not.toHaveBeenCalled(); expect(getPendingUpload(uploadId)).toBeDefined(); diff --git a/extensions/msteams/src/monitor-handler.ts b/extensions/msteams/src/monitor-handler.ts index 27d3e06929f..ac1b469e8be 100644 --- a/extensions/msteams/src/monitor-handler.ts +++ b/extensions/msteams/src/monitor-handler.ts @@ -7,6 +7,7 @@ import { createMSTeamsMessageHandler } from "./monitor-handler/message-handler.j import type { MSTeamsMonitorLogger } from "./monitor-types.js"; import { getPendingUpload, removePendingUpload } from "./pending-uploads.js"; import type { MSTeamsPollStore } from "./polls.js"; +import { withRevokedProxyFallback } from "./revoked-context.js"; import type { MSTeamsTurnContext } from "./sdk-types.js"; export type MSTeamsAccessTokenProvider = { @@ -146,10 +147,19 @@ export function registerMSTeamsHandlers( // Send invoke response IMMEDIATELY to prevent Teams timeout await ctx.sendActivity({ type: "invokeResponse", value: { status: 200 } }); - // Handle file upload asynchronously (don't await) - handleFileConsentInvoke(ctx, deps.log).catch((err) => { + try { + await withRevokedProxyFallback({ + run: async () => await handleFileConsentInvoke(ctx, deps.log), + onRevoked: async () => true, + onRevokedLog: () => { + deps.log.debug?.( + "turn context revoked during file consent invoke; skipping delayed response", + ); + }, + }); + } catch (err) { deps.log.debug?.("file consent handler error", { error: String(err) }); - }); + } return; } return originalRun.call(handler, context); diff --git a/extensions/msteams/src/monitor.lifecycle.test.ts b/extensions/msteams/src/monitor.lifecycle.test.ts index abf69b23d0e..132718ce307 100644 --- a/extensions/msteams/src/monitor.lifecycle.test.ts +++ b/extensions/msteams/src/monitor.lifecycle.test.ts @@ -6,6 +6,9 @@ import type { MSTeamsPollStore } from "./polls.js"; type FakeServer = EventEmitter & { close: (callback?: (err?: Error | null) => void) => void; + setTimeout: (msecs: number) => FakeServer; + requestTimeout: number; + headersTimeout: number; }; const expressControl = vi.hoisted(() => ({ @@ -14,6 +17,18 @@ const expressControl = vi.hoisted(() => ({ vi.mock("openclaw/plugin-sdk", () => ({ DEFAULT_WEBHOOK_MAX_BODY_BYTES: 1024 * 1024, + keepHttpServerTaskAlive: vi.fn( + async (params: { abortSignal?: AbortSignal; onAbort?: () => Promise | void }) => { + await new Promise((resolve) => { + if (params.abortSignal?.aborted) { + resolve(); + return; + } + params.abortSignal?.addEventListener("abort", () => resolve(), { once: true }); + }); + await params.onAbort?.(); + }, + ), mergeAllowlist: (params: { existing?: string[]; additions?: string[] }) => Array.from(new Set([...(params.existing ?? []), ...(params.additions ?? [])])), summarizeMapping: vi.fn(), @@ -31,6 +46,9 @@ vi.mock("express", () => { post: vi.fn(), listen: vi.fn((_port: number) => { const server = new EventEmitter() as FakeServer; + server.setTimeout = vi.fn((_msecs: number) => server); + server.requestTimeout = 0; + server.headersTimeout = 0; server.close = (callback?: (err?: Error | null) => void) => { queueMicrotask(() => { server.emit("close"); diff --git a/extensions/msteams/src/monitor.ts b/extensions/msteams/src/monitor.ts index 8ae4f7e3173..f2adba52139 100644 --- a/extensions/msteams/src/monitor.ts +++ b/extensions/msteams/src/monitor.ts @@ -2,6 +2,7 @@ import type { Server } from "node:http"; import type { Request, Response } from "express"; import { DEFAULT_WEBHOOK_MAX_BODY_BYTES, + keepHttpServerTaskAlive, mergeAllowlist, summarizeMapping, type OpenClawConfig, @@ -333,25 +334,12 @@ export async function monitorMSTeamsProvider( }); }; - // Handle abort signal - const onAbort = () => { - void shutdown(); - }; - if (opts.abortSignal) { - if (opts.abortSignal.aborted) { - onAbort(); - } else { - opts.abortSignal.addEventListener("abort", onAbort, { once: true }); - } - } - - // Keep this task alive until shutdown/close so gateway runtime does not treat startup as exit. - await new Promise((resolve) => { - httpServer.once("close", () => { - resolve(); - }); + // Keep this task alive until close so gateway runtime does not treat startup as exit. + await keepHttpServerTaskAlive({ + server: httpServer, + abortSignal: opts.abortSignal, + onAbort: shutdown, }); - opts.abortSignal?.removeEventListener("abort", onAbort); return { app: expressApp, shutdown }; } diff --git a/extensions/msteams/src/reply-dispatcher.ts b/extensions/msteams/src/reply-dispatcher.ts index 7f9dd098f43..3ddf7b18c5e 100644 --- a/extensions/msteams/src/reply-dispatcher.ts +++ b/extensions/msteams/src/reply-dispatcher.ts @@ -13,7 +13,6 @@ import { classifyMSTeamsSendError, formatMSTeamsSendErrorHint, formatUnknownError, - isRevokedProxyError, } from "./errors.js"; import { buildConversationReference, @@ -22,6 +21,7 @@ import { sendMSTeamsMessages, } from "./messenger.js"; import type { MSTeamsMonitorLogger } from "./monitor-types.js"; +import { withRevokedProxyFallback } from "./revoked-context.js"; import { getMSTeamsRuntime } from "./runtime.js"; import type { MSTeamsTurnContext } from "./sdk-types.js"; @@ -53,23 +53,24 @@ export function createMSTeamsReplyDispatcher(params: { * the stored conversation reference so the user still sees the "…" bubble. */ const sendTypingIndicator = async () => { - try { - await params.context.sendActivity({ type: "typing" }); - } catch (err) { - if (!isRevokedProxyError(err)) { - throw err; - } - // Turn context revoked — fall back to proactive typing. - params.log.debug?.("turn context revoked, sending typing via proactive messaging"); - const baseRef = buildConversationReference(params.conversationRef); - await params.adapter.continueConversation( - params.appId, - { ...baseRef, activityId: undefined }, - async (ctx) => { - await ctx.sendActivity({ type: "typing" }); - }, - ); - } + await withRevokedProxyFallback({ + run: async () => { + await params.context.sendActivity({ type: "typing" }); + }, + onRevoked: async () => { + const baseRef = buildConversationReference(params.conversationRef); + await params.adapter.continueConversation( + params.appId, + { ...baseRef, activityId: undefined }, + async (ctx) => { + await ctx.sendActivity({ type: "typing" }); + }, + ); + }, + onRevokedLog: () => { + params.log.debug?.("turn context revoked, sending typing via proactive messaging"); + }, + }); }; const typingCallbacks = createTypingCallbacks({ diff --git a/extensions/msteams/src/revoked-context.test.ts b/extensions/msteams/src/revoked-context.test.ts new file mode 100644 index 00000000000..20c339d9434 --- /dev/null +++ b/extensions/msteams/src/revoked-context.test.ts @@ -0,0 +1,39 @@ +import { describe, expect, it, vi } from "vitest"; +import { withRevokedProxyFallback } from "./revoked-context.js"; + +describe("msteams revoked context helper", () => { + it("returns primary result when no error occurs", async () => { + await expect( + withRevokedProxyFallback({ + run: async () => "ok", + onRevoked: async () => "fallback", + }), + ).resolves.toBe("ok"); + }); + + it("uses fallback when proxy-revoked TypeError is thrown", async () => { + const onRevokedLog = vi.fn(); + await expect( + withRevokedProxyFallback({ + run: async () => { + throw new TypeError("Cannot perform 'get' on a proxy that has been revoked"); + }, + onRevoked: async () => "fallback", + onRevokedLog, + }), + ).resolves.toBe("fallback"); + expect(onRevokedLog).toHaveBeenCalledOnce(); + }); + + it("rethrows non-revoked errors", async () => { + const err = Object.assign(new Error("boom"), { statusCode: 500 }); + await expect( + withRevokedProxyFallback({ + run: async () => { + throw err; + }, + onRevoked: async () => "fallback", + }), + ).rejects.toBe(err); + }); +}); diff --git a/extensions/msteams/src/revoked-context.ts b/extensions/msteams/src/revoked-context.ts new file mode 100644 index 00000000000..a8ac1859434 --- /dev/null +++ b/extensions/msteams/src/revoked-context.ts @@ -0,0 +1,17 @@ +import { isRevokedProxyError } from "./errors.js"; + +export async function withRevokedProxyFallback(params: { + run: () => Promise; + onRevoked: () => Promise; + onRevokedLog?: () => void; +}): Promise { + try { + return await params.run(); + } catch (err) { + if (!isRevokedProxyError(err)) { + throw err; + } + params.onRevokedLog?.(); + return await params.onRevoked(); + } +} diff --git a/src/plugin-sdk/channel-lifecycle.test.ts b/src/plugin-sdk/channel-lifecycle.test.ts new file mode 100644 index 00000000000..020510c914a --- /dev/null +++ b/src/plugin-sdk/channel-lifecycle.test.ts @@ -0,0 +1,66 @@ +import { EventEmitter } from "node:events"; +import { describe, expect, it, vi } from "vitest"; +import { keepHttpServerTaskAlive, waitUntilAbort } from "./channel-lifecycle.js"; + +type FakeServer = EventEmitter & { + close: (callback?: () => void) => void; +}; + +function createFakeServer(): FakeServer { + const server = new EventEmitter() as FakeServer; + server.close = (callback) => { + queueMicrotask(() => { + server.emit("close"); + callback?.(); + }); + }; + return server; +} + +describe("plugin-sdk channel lifecycle helpers", () => { + it("resolves waitUntilAbort when signal aborts", async () => { + const abort = new AbortController(); + const task = waitUntilAbort(abort.signal); + + const early = await Promise.race([ + task.then(() => "resolved"), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)), + ]); + expect(early).toBe("pending"); + + abort.abort(); + await expect(task).resolves.toBeUndefined(); + }); + + it("keeps server task pending until close, then resolves", async () => { + const server = createFakeServer(); + const task = keepHttpServerTaskAlive({ server }); + + const early = await Promise.race([ + task.then(() => "resolved"), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)), + ]); + expect(early).toBe("pending"); + + server.close(); + await expect(task).resolves.toBeUndefined(); + }); + + it("triggers abort hook once and resolves after close", async () => { + const server = createFakeServer(); + const abort = new AbortController(); + const onAbort = vi.fn(async () => { + server.close(); + }); + + const task = keepHttpServerTaskAlive({ + server, + abortSignal: abort.signal, + onAbort, + }); + + abort.abort(); + await expect(task).resolves.toBeUndefined(); + expect(onAbort).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/plugin-sdk/channel-lifecycle.ts b/src/plugin-sdk/channel-lifecycle.ts new file mode 100644 index 00000000000..4687e167352 --- /dev/null +++ b/src/plugin-sdk/channel-lifecycle.ts @@ -0,0 +1,66 @@ +type CloseAwareServer = { + once: (event: "close", listener: () => void) => unknown; +}; + +/** + * Return a promise that resolves when the signal is aborted. + * + * If no signal is provided, the promise stays pending forever. + */ +export function waitUntilAbort(signal?: AbortSignal): Promise { + return new Promise((resolve) => { + if (!signal) { + return; + } + if (signal.aborted) { + resolve(); + return; + } + signal.addEventListener("abort", () => resolve(), { once: true }); + }); +} + +/** + * Keep a channel/provider task pending until the HTTP server closes. + * + * When an abort signal is provided, `onAbort` is invoked once and should + * trigger server shutdown. The returned promise resolves only after `close`. + */ +export async function keepHttpServerTaskAlive(params: { + server: CloseAwareServer; + abortSignal?: AbortSignal; + onAbort?: () => void | Promise; +}): Promise { + const { server, abortSignal, onAbort } = params; + let abortTask: Promise = Promise.resolve(); + let abortTriggered = false; + + const triggerAbort = () => { + if (abortTriggered) { + return; + } + abortTriggered = true; + abortTask = Promise.resolve(onAbort?.()).then(() => undefined); + }; + + const onAbortSignal = () => { + triggerAbort(); + }; + + if (abortSignal) { + if (abortSignal.aborted) { + triggerAbort(); + } else { + abortSignal.addEventListener("abort", onAbortSignal, { once: true }); + } + } + + await new Promise((resolve) => { + server.once("close", () => resolve()); + }); + + if (abortSignal) { + abortSignal.removeEventListener("abort", onAbortSignal); + } + await abortTask; +} diff --git a/src/plugin-sdk/index.ts b/src/plugin-sdk/index.ts index 4d656634602..f31d2c1ff64 100644 --- a/src/plugin-sdk/index.ts +++ b/src/plugin-sdk/index.ts @@ -149,6 +149,7 @@ export { WEBHOOK_IN_FLIGHT_DEFAULTS, } from "./webhook-request-guards.js"; export type { WebhookBodyReadProfile, WebhookInFlightLimiter } from "./webhook-request-guards.js"; +export { keepHttpServerTaskAlive, waitUntilAbort } from "./channel-lifecycle.js"; export type { AgentMediaPayload } from "./agent-media-payload.js"; export { buildAgentMediaPayload } from "./agent-media-payload.js"; export {