diff --git a/src/agents/bash-tools.exec-approval-request.ts b/src/agents/bash-tools.exec-approval-request.ts index b68aa37d398..2b08495a400 100644 --- a/src/agents/bash-tools.exec-approval-request.ts +++ b/src/agents/bash-tools.exec-approval-request.ts @@ -42,3 +42,27 @@ export async function requestExecApprovalDecision( : undefined; return typeof decisionValue === "string" ? decisionValue : null; } + +export async function requestExecApprovalDecisionForHost(params: { + approvalId: string; + command: string; + workdir: string; + host: "gateway" | "node"; + security: ExecSecurity; + ask: ExecAsk; + agentId?: string; + resolvedPath?: string; + sessionKey?: string; +}): Promise { + return await requestExecApprovalDecision({ + id: params.approvalId, + command: params.command, + cwd: params.workdir, + host: params.host, + security: params.security, + ask: params.ask, + agentId: params.agentId, + resolvedPath: params.resolvedPath, + sessionKey: params.sessionKey, + }); +} diff --git a/src/agents/bash-tools.exec-host-gateway.ts b/src/agents/bash-tools.exec-host-gateway.ts index f742ee3862a..7b44f6fb3c6 100644 --- a/src/agents/bash-tools.exec-host-gateway.ts +++ b/src/agents/bash-tools.exec-host-gateway.ts @@ -16,7 +16,7 @@ import { } from "../infra/exec-approvals.js"; import type { SafeBinProfile } from "../infra/exec-safe-bin-policy.js"; import { markBackgrounded, tail } from "./bash-process-registry.js"; -import { requestExecApprovalDecision } from "./bash-tools.exec-approval-request.js"; +import { requestExecApprovalDecisionForHost } from "./bash-tools.exec-approval-request.js"; import { DEFAULT_APPROVAL_TIMEOUT_MS, DEFAULT_NOTIFY_TAIL_CHARS, @@ -81,6 +81,19 @@ export async function processGatewayAllowlist( const analysisOk = allowlistEval.analysisOk; const allowlistSatisfied = hostSecurity === "allowlist" && analysisOk ? allowlistEval.allowlistSatisfied : false; + const recordMatchedAllowlistUse = (resolvedPath?: string) => { + if (allowlistMatches.length === 0) { + return; + } + const seen = new Set(); + for (const match of allowlistMatches) { + if (seen.has(match.pattern)) { + continue; + } + seen.add(match.pattern); + recordAllowlistUse(approvals.file, params.agentId, match, params.command, resolvedPath); + } + }; const hasHeredocSegment = allowlistEval.segments.some((segment) => segment.argv.some((token) => token.startsWith("<<")), ); @@ -113,10 +126,10 @@ export async function processGatewayAllowlist( void (async () => { let decision: string | null = null; try { - decision = await requestExecApprovalDecision({ - id: approvalId, + decision = await requestExecApprovalDecisionForHost({ + approvalId, command: params.command, - cwd: params.workdir, + workdir: params.workdir, host: "gateway", security: hostSecurity, ask: hostAsk, @@ -186,22 +199,7 @@ export async function processGatewayAllowlist( return; } - if (allowlistMatches.length > 0) { - const seen = new Set(); - for (const match of allowlistMatches) { - if (seen.has(match.pattern)) { - continue; - } - seen.add(match.pattern); - recordAllowlistUse( - approvals.file, - params.agentId, - match, - params.command, - resolvedPath ?? undefined, - ); - } - } + recordMatchedAllowlistUse(resolvedPath ?? undefined); let run: Awaited> | null = null; try { @@ -321,22 +319,7 @@ export async function processGatewayAllowlist( } } - if (allowlistMatches.length > 0) { - const seen = new Set(); - for (const match of allowlistMatches) { - if (seen.has(match.pattern)) { - continue; - } - seen.add(match.pattern); - recordAllowlistUse( - approvals.file, - params.agentId, - match, - params.command, - allowlistEval.segments[0]?.resolution?.resolvedPath, - ); - } - } + recordMatchedAllowlistUse(allowlistEval.segments[0]?.resolution?.resolvedPath); return { execCommandOverride }; } diff --git a/src/agents/bash-tools.exec-host-node.ts b/src/agents/bash-tools.exec-host-node.ts index 3cca1bc121a..642c898107c 100644 --- a/src/agents/bash-tools.exec-host-node.ts +++ b/src/agents/bash-tools.exec-host-node.ts @@ -12,7 +12,7 @@ import { resolveExecApprovalsFromFile, } from "../infra/exec-approvals.js"; import { buildNodeShellCommand } from "../infra/node-shell.js"; -import { requestExecApprovalDecision } from "./bash-tools.exec-approval-request.js"; +import { requestExecApprovalDecisionForHost } from "./bash-tools.exec-approval-request.js"; import { DEFAULT_APPROVAL_TIMEOUT_MS, createApprovalSlug, @@ -178,10 +178,10 @@ export async function executeNodeHostCommand( void (async () => { let decision: string | null = null; try { - decision = await requestExecApprovalDecision({ - id: approvalId, + decision = await requestExecApprovalDecisionForHost({ + approvalId, command: params.command, - cwd: params.workdir, + workdir: params.workdir, host: "node", security: hostSecurity, ask: hostAsk, diff --git a/src/agents/pi-embedded-subscribe.ts b/src/agents/pi-embedded-subscribe.ts index b3326c39e3a..7d2195b98ce 100644 --- a/src/agents/pi-embedded-subscribe.ts +++ b/src/agents/pi-embedded-subscribe.ts @@ -317,35 +317,10 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar } return `\`\`\`txt\n${trimmed}\n\`\`\``; }; - const emitToolSummary = (toolName?: string, meta?: string) => { + const emitToolResultMessage = (toolName: string | undefined, message: string) => { if (!params.onToolResult) { return; } - const agg = formatToolAggregate(toolName, meta ? [meta] : undefined, { - markdown: useMarkdown, - }); - const { text: cleanedText, mediaUrls } = parseReplyDirectives(agg); - const filteredMediaUrls = filterToolResultMediaUrls(toolName, mediaUrls ?? []); - if (!cleanedText && filteredMediaUrls.length === 0) { - return; - } - try { - void params.onToolResult({ - text: cleanedText, - mediaUrls: filteredMediaUrls.length ? filteredMediaUrls : undefined, - }); - } catch { - // ignore tool result delivery failures - } - }; - const emitToolOutput = (toolName?: string, meta?: string, output?: string) => { - if (!params.onToolResult || !output) { - return; - } - const agg = formatToolAggregate(toolName, meta ? [meta] : undefined, { - markdown: useMarkdown, - }); - const message = `${agg}\n${formatToolOutputBlock(output)}`; const { text: cleanedText, mediaUrls } = parseReplyDirectives(message); const filteredMediaUrls = filterToolResultMediaUrls(toolName, mediaUrls ?? []); if (!cleanedText && filteredMediaUrls.length === 0) { @@ -360,6 +335,22 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar // ignore tool result delivery failures } }; + const emitToolSummary = (toolName?: string, meta?: string) => { + const agg = formatToolAggregate(toolName, meta ? [meta] : undefined, { + markdown: useMarkdown, + }); + emitToolResultMessage(toolName, agg); + }; + const emitToolOutput = (toolName?: string, meta?: string, output?: string) => { + if (!output) { + return; + } + const agg = formatToolAggregate(toolName, meta ? [meta] : undefined, { + markdown: useMarkdown, + }); + const message = `${agg}\n${formatToolOutputBlock(output)}`; + emitToolResultMessage(toolName, message); + }; const stripBlockTags = ( text: string, diff --git a/src/agents/skills-status.ts b/src/agents/skills-status.ts index 64f38ed9fd1..02ff68e2efc 100644 --- a/src/agents/skills-status.ts +++ b/src/agents/skills-status.ts @@ -1,6 +1,6 @@ import path from "node:path"; import type { OpenClawConfig } from "../config/config.js"; -import { evaluateEntryMetadataRequirementsForCurrentPlatform } from "../shared/entry-status.js"; +import { evaluateEntryRequirementsForCurrentPlatform } from "../shared/entry-status.js"; import type { RequirementConfigCheck, Requirements } from "../shared/requirements.js"; import { CONFIG_DIR } from "../utils.js"; import { @@ -191,17 +191,15 @@ function buildSkillStatus( ? bundledNames.has(entry.skill.name) : entry.skill.source === "openclaw-bundled"; - const requirementStatus = evaluateEntryMetadataRequirementsForCurrentPlatform({ - always, - metadata: entry.metadata, - frontmatter: entry.frontmatter, - hasLocalBin: hasBinary, - remote: eligibility?.remote, - isEnvSatisfied, - isConfigSatisfied, - }); const { emoji, homepage, required, missing, requirementsSatisfied, configChecks } = - requirementStatus; + evaluateEntryRequirementsForCurrentPlatform({ + always, + entry, + hasLocalBin: hasBinary, + remote: eligibility?.remote, + isEnvSatisfied, + isConfigSatisfied, + }); const eligible = !disabled && !blockedByAllowlist && requirementsSatisfied; return { diff --git a/src/agents/system-prompt.ts b/src/agents/system-prompt.ts index 74530db6897..9027bba92d7 100644 --- a/src/agents/system-prompt.ts +++ b/src/agents/system-prompt.ts @@ -5,6 +5,7 @@ import type { MemoryCitationsMode } from "../config/types.memory.js"; import { listDeliverableMessageChannels } from "../utils/message-channel.js"; import type { ResolvedTimeFormat } from "./date-time.js"; import type { EmbeddedContextFile } from "./pi-embedded-helpers.js"; +import type { EmbeddedSandboxInfo } from "./pi-embedded-runner/types.js"; import { sanitizeForPromptLiteral } from "./sanitize-for-prompt.js"; /** @@ -229,20 +230,7 @@ export function buildAgentSystemPrompt(params: { repoRoot?: string; }; messageToolHints?: string[]; - sandboxInfo?: { - enabled: boolean; - workspaceDir?: string; - containerWorkspaceDir?: string; - workspaceAccess?: "none" | "ro" | "rw"; - agentWorkspaceMount?: string; - browserBridgeUrl?: string; - browserNoVncUrl?: string; - hostBrowserAllowed?: boolean; - elevated?: { - allowed: boolean; - defaultLevel: "on" | "off" | "ask" | "full"; - }; - }; + sandboxInfo?: EmbeddedSandboxInfo; /** Reaction guidance for the agent (for Telegram minimal/extensive modes). */ reactionGuidance?: { level: "minimal" | "extensive"; diff --git a/src/browser/pw-role-snapshot.ts b/src/browser/pw-role-snapshot.ts index adf80794994..7a0b0ae70fe 100644 --- a/src/browser/pw-role-snapshot.ts +++ b/src/browser/pw-role-snapshot.ts @@ -266,6 +266,46 @@ function processLine( return enhanced; } +type InteractiveSnapshotLine = NonNullable>; + +function buildInteractiveSnapshotLines(params: { + lines: string[]; + options: RoleSnapshotOptions; + resolveRef: (parsed: InteractiveSnapshotLine) => { ref: string; nth?: number } | null; + recordRef: (parsed: InteractiveSnapshotLine, ref: string, nth?: number) => void; + includeSuffix: (suffix: string) => boolean; +}): string[] { + const out: string[] = []; + for (const line of params.lines) { + const parsed = matchInteractiveSnapshotLine(line, params.options); + if (!parsed) { + continue; + } + if (!INTERACTIVE_ROLES.has(parsed.role)) { + continue; + } + const resolved = params.resolveRef(parsed); + if (!resolved?.ref) { + continue; + } + params.recordRef(parsed, resolved.ref, resolved.nth); + + let enhanced = `- ${parsed.roleRaw}`; + if (parsed.name) { + enhanced += ` "${parsed.name}"`; + } + enhanced += ` [ref=${resolved.ref}]`; + if ((resolved.nth ?? 0) > 0) { + enhanced += ` [nth=${resolved.nth}]`; + } + if (params.includeSuffix(parsed.suffix)) { + enhanced += parsed.suffix; + } + out.push(enhanced); + } + return out; +} + export function parseRoleRef(raw: string): string | null { const trimmed = raw.trim(); if (!trimmed) { @@ -294,39 +334,24 @@ export function buildRoleSnapshotFromAriaSnapshot( }; if (options.interactive) { - const result: string[] = []; - for (const line of lines) { - const parsed = matchInteractiveSnapshotLine(line, options); - if (!parsed) { - continue; - } - const { roleRaw, role, name, suffix } = parsed; - if (!INTERACTIVE_ROLES.has(role)) { - continue; - } - - const ref = nextRef(); - const nth = tracker.getNextIndex(role, name); - tracker.trackRef(role, name, ref); - refs[ref] = { - role, - name, - nth, - }; - - let enhanced = `- ${roleRaw}`; - if (name) { - enhanced += ` "${name}"`; - } - enhanced += ` [ref=${ref}]`; - if (nth > 0) { - enhanced += ` [nth=${nth}]`; - } - if (suffix.includes("[")) { - enhanced += suffix; - } - result.push(enhanced); - } + const result = buildInteractiveSnapshotLines({ + lines, + options, + resolveRef: ({ role, name }) => { + const ref = nextRef(); + const nth = tracker.getNextIndex(role, name); + tracker.trackRef(role, name, ref); + return { ref, nth }; + }, + recordRef: ({ role, name }, ref, nth) => { + refs[ref] = { + role, + name, + nth, + }; + }, + includeSuffix: (suffix) => suffix.includes("["), + }); removeNthFromNonDuplicates(refs, tracker); @@ -370,23 +395,18 @@ export function buildRoleSnapshotFromAiSnapshot( const refs: RoleRefMap = {}; if (options.interactive) { - const out: string[] = []; - for (const line of lines) { - const parsed = matchInteractiveSnapshotLine(line, options); - if (!parsed) { - continue; - } - const { roleRaw, role, name, suffix } = parsed; - if (!INTERACTIVE_ROLES.has(role)) { - continue; - } - const ref = parseAiSnapshotRef(suffix); - if (!ref) { - continue; - } - refs[ref] = { role, ...(name ? { name } : {}) }; - out.push(`- ${roleRaw}${name ? ` "${name}"` : ""}${suffix}`); - } + const out = buildInteractiveSnapshotLines({ + lines, + options, + resolveRef: ({ suffix }) => { + const ref = parseAiSnapshotRef(suffix); + return ref ? { ref } : null; + }, + recordRef: ({ role, name }, ref) => { + refs[ref] = { role, ...(name ? { name } : {}) }; + }, + includeSuffix: () => true, + }); return { snapshot: out.join("\n") || "(no interactive elements)", refs, diff --git a/src/gateway/server-node-events.ts b/src/gateway/server-node-events.ts index 0878134f416..ce1d699797f 100644 --- a/src/gateway/server-node-events.ts +++ b/src/gateway/server-node-events.ts @@ -200,6 +200,21 @@ function parseSessionKeyFromPayloadJSON(payloadJSON: string): string | null { return sessionKey.length > 0 ? sessionKey : null; } +function parsePayloadObject(payloadJSON?: string | null): Record | null { + if (!payloadJSON) { + return null; + } + let payload: unknown; + try { + payload = JSON.parse(payloadJSON) as unknown; + } catch { + return null; + } + return typeof payload === "object" && payload !== null + ? (payload as Record) + : null; +} + async function sendReceiptAck(params: { cfg: ReturnType; deps: NodeEventContext["deps"]; @@ -232,17 +247,10 @@ async function sendReceiptAck(params: { export const handleNodeEvent = async (ctx: NodeEventContext, nodeId: string, evt: NodeEvent) => { switch (evt.event) { case "voice.transcript": { - if (!evt.payloadJSON) { + const obj = parsePayloadObject(evt.payloadJSON); + if (!obj) { return; } - let payload: unknown; - try { - payload = JSON.parse(evt.payloadJSON) as unknown; - } catch { - return; - } - const obj = - typeof payload === "object" && payload !== null ? (payload as Record) : {}; const text = typeof obj.text === "string" ? obj.text.trim() : ""; if (!text) { return; @@ -455,17 +463,10 @@ export const handleNodeEvent = async (ctx: NodeEventContext, nodeId: string, evt case "exec.started": case "exec.finished": case "exec.denied": { - if (!evt.payloadJSON) { + const obj = parsePayloadObject(evt.payloadJSON); + if (!obj) { return; } - let payload: unknown; - try { - payload = JSON.parse(evt.payloadJSON) as unknown; - } catch { - return; - } - const obj = - typeof payload === "object" && payload !== null ? (payload as Record) : {}; const sessionKey = typeof obj.sessionKey === "string" ? obj.sessionKey.trim() : `node-${nodeId}`; if (!sessionKey) { @@ -519,17 +520,10 @@ export const handleNodeEvent = async (ctx: NodeEventContext, nodeId: string, evt return; } case "push.apns.register": { - if (!evt.payloadJSON) { + const obj = parsePayloadObject(evt.payloadJSON); + if (!obj) { return; } - let payload: unknown; - try { - payload = JSON.parse(evt.payloadJSON) as unknown; - } catch { - return; - } - const obj = - typeof payload === "object" && payload !== null ? (payload as Record) : {}; const token = typeof obj.token === "string" ? obj.token : ""; const topic = typeof obj.topic === "string" ? obj.topic : ""; const environment = obj.environment; diff --git a/src/gateway/server.agent.gateway-server-agent.mocks.ts b/src/gateway/server.agent.gateway-server-agent.mocks.ts index 3dd42d4ab40..b930ccbc67f 100644 --- a/src/gateway/server.agent.gateway-server-agent.mocks.ts +++ b/src/gateway/server.agent.gateway-server-agent.mocks.ts @@ -1,23 +1,9 @@ import { vi } from "vitest"; -import type { PluginRegistry } from "../plugins/registry.js"; +import { createEmptyPluginRegistry, type PluginRegistry } from "../plugins/registry.js"; import { setActivePluginRegistry } from "../plugins/runtime.js"; export const registryState: { registry: PluginRegistry } = { - registry: { - plugins: [], - tools: [], - hooks: [], - typedHooks: [], - channels: [], - providers: [], - gatewayHandlers: {}, - httpHandlers: [], - httpRoutes: [], - cliRegistrars: [], - services: [], - commands: [], - diagnostics: [], - } as PluginRegistry, + registry: createEmptyPluginRegistry(), }; export function setRegistry(registry: PluginRegistry) { diff --git a/src/hooks/hooks-status.ts b/src/hooks/hooks-status.ts index a65cee2487c..d609b919039 100644 --- a/src/hooks/hooks-status.ts +++ b/src/hooks/hooks-status.ts @@ -1,6 +1,6 @@ import path from "node:path"; import type { OpenClawConfig } from "../config/config.js"; -import { evaluateEntryMetadataRequirementsForCurrentPlatform } from "../shared/entry-status.js"; +import { evaluateEntryRequirementsForCurrentPlatform } from "../shared/entry-status.js"; import type { RequirementConfigCheck, Requirements } from "../shared/requirements.js"; import { CONFIG_DIR } from "../utils.js"; import { hasBinary, isConfigPathTruthy, resolveHookConfig } from "./config.js"; @@ -91,17 +91,15 @@ function buildHookStatus( Boolean(process.env[envName] || hookConfig?.env?.[envName]); const isConfigSatisfied = (pathStr: string) => isConfigPathTruthy(config, pathStr); - const requirementStatus = evaluateEntryMetadataRequirementsForCurrentPlatform({ - always, - metadata: entry.metadata, - frontmatter: entry.frontmatter, - hasLocalBin: hasBinary, - remote: eligibility?.remote, - isEnvSatisfied, - isConfigSatisfied, - }); const { emoji, homepage, required, missing, requirementsSatisfied, configChecks } = - requirementStatus; + evaluateEntryRequirementsForCurrentPlatform({ + always, + entry, + hasLocalBin: hasBinary, + remote: eligibility?.remote, + isEnvSatisfied, + isConfigSatisfied, + }); const eligible = !disabled && requirementsSatisfied; diff --git a/src/infra/exec-wrapper-resolution.ts b/src/infra/exec-wrapper-resolution.ts index c5ae325e973..2fae35f315e 100644 --- a/src/infra/exec-wrapper-resolution.ts +++ b/src/infra/exec-wrapper-resolution.ts @@ -346,30 +346,7 @@ export function unwrapDispatchWrappersForResolution( } function extractPosixShellInlineCommand(argv: string[]): string | null { - for (let i = 1; i < argv.length; i += 1) { - const token = argv[i]?.trim(); - if (!token) { - continue; - } - const lower = token.toLowerCase(); - if (lower === "--") { - break; - } - if (POSIX_INLINE_COMMAND_FLAGS.has(lower)) { - const cmd = argv[i + 1]?.trim(); - return cmd ? cmd : null; - } - if (/^-[^-]*c[^-]*$/i.test(token)) { - const commandIndex = lower.indexOf("c"); - const inline = token.slice(commandIndex + 1).trim(); - if (inline) { - return inline; - } - const cmd = argv[i + 1]?.trim(); - return cmd ? cmd : null; - } - } - return null; + return extractInlineCommandByFlags(argv, POSIX_INLINE_COMMAND_FLAGS, { allowCombinedC: true }); } function extractCmdInlineCommand(argv: string[]): string | null { @@ -389,6 +366,14 @@ function extractCmdInlineCommand(argv: string[]): string | null { } function extractPowerShellInlineCommand(argv: string[]): string | null { + return extractInlineCommandByFlags(argv, POWERSHELL_INLINE_COMMAND_FLAGS); +} + +function extractInlineCommandByFlags( + argv: string[], + flags: ReadonlySet, + options: { allowCombinedC?: boolean } = {}, +): string | null { for (let i = 1; i < argv.length; i += 1) { const token = argv[i]?.trim(); if (!token) { @@ -398,7 +383,16 @@ function extractPowerShellInlineCommand(argv: string[]): string | null { if (lower === "--") { break; } - if (POWERSHELL_INLINE_COMMAND_FLAGS.has(lower)) { + if (flags.has(lower)) { + const cmd = argv[i + 1]?.trim(); + return cmd ? cmd : null; + } + if (options.allowCombinedC && /^-[^-]*c[^-]*$/i.test(token)) { + const commandIndex = lower.indexOf("c"); + const inline = token.slice(commandIndex + 1).trim(); + if (inline) { + return inline; + } const cmd = argv[i + 1]?.trim(); return cmd ? cmd : null; } diff --git a/src/memory/embeddings-mistral.ts b/src/memory/embeddings-mistral.ts index 33b1afe5282..7d9f2bb3dfe 100644 --- a/src/memory/embeddings-mistral.ts +++ b/src/memory/embeddings-mistral.ts @@ -1,6 +1,8 @@ import type { SsrFPolicy } from "../infra/net/ssrf.js"; -import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; -import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; +import { + createRemoteEmbeddingProvider, + resolveRemoteEmbeddingClient, +} from "./embeddings-remote-provider.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; export type MistralEmbeddingClient = { @@ -28,31 +30,13 @@ export async function createMistralEmbeddingProvider( options: EmbeddingProviderOptions, ): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> { const client = await resolveMistralEmbeddingClient(options); - const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; - - const embed = async (input: string[]): Promise => { - if (input.length === 0) { - return []; - } - return await fetchRemoteEmbeddingVectors({ - url, - headers: client.headers, - ssrfPolicy: client.ssrfPolicy, - body: { model: client.model, input }, - errorPrefix: "mistral embeddings failed", - }); - }; return { - provider: { + provider: createRemoteEmbeddingProvider({ id: "mistral", - model: client.model, - embedQuery: async (text) => { - const [vec] = await embed([text]); - return vec ?? []; - }, - embedBatch: embed, - }, + client, + errorPrefix: "mistral embeddings failed", + }), client, }; } @@ -60,11 +44,10 @@ export async function createMistralEmbeddingProvider( export async function resolveMistralEmbeddingClient( options: EmbeddingProviderOptions, ): Promise { - const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ + return await resolveRemoteEmbeddingClient({ provider: "mistral", options, defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL, + normalizeModel: normalizeMistralModel, }); - const model = normalizeMistralModel(options.model); - return { baseUrl, headers, ssrfPolicy, model }; } diff --git a/src/memory/embeddings-openai.ts b/src/memory/embeddings-openai.ts index 02b92e68f60..af8184f4452 100644 --- a/src/memory/embeddings-openai.ts +++ b/src/memory/embeddings-openai.ts @@ -1,6 +1,8 @@ import type { SsrFPolicy } from "../infra/net/ssrf.js"; -import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; -import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; +import { + createRemoteEmbeddingProvider, + resolveRemoteEmbeddingClient, +} from "./embeddings-remote-provider.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; export type OpenAiEmbeddingClient = { @@ -33,32 +35,14 @@ export async function createOpenAiEmbeddingProvider( options: EmbeddingProviderOptions, ): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> { const client = await resolveOpenAiEmbeddingClient(options); - const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; - - const embed = async (input: string[]): Promise => { - if (input.length === 0) { - return []; - } - return await fetchRemoteEmbeddingVectors({ - url, - headers: client.headers, - ssrfPolicy: client.ssrfPolicy, - body: { model: client.model, input }, - errorPrefix: "openai embeddings failed", - }); - }; return { - provider: { + provider: createRemoteEmbeddingProvider({ id: "openai", - model: client.model, + client, + errorPrefix: "openai embeddings failed", maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model], - embedQuery: async (text) => { - const [vec] = await embed([text]); - return vec ?? []; - }, - embedBatch: embed, - }, + }), client, }; } @@ -66,11 +50,10 @@ export async function createOpenAiEmbeddingProvider( export async function resolveOpenAiEmbeddingClient( options: EmbeddingProviderOptions, ): Promise { - const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ + return await resolveRemoteEmbeddingClient({ provider: "openai", options, defaultBaseUrl: DEFAULT_OPENAI_BASE_URL, + normalizeModel: normalizeOpenAiModel, }); - const model = normalizeOpenAiModel(options.model); - return { baseUrl, headers, ssrfPolicy, model }; } diff --git a/src/memory/embeddings-remote-client.ts b/src/memory/embeddings-remote-client.ts index 13b45f4777e..3a150c388aa 100644 --- a/src/memory/embeddings-remote-client.ts +++ b/src/memory/embeddings-remote-client.ts @@ -3,7 +3,7 @@ import type { SsrFPolicy } from "../infra/net/ssrf.js"; import type { EmbeddingProviderOptions } from "./embeddings.js"; import { buildRemoteBaseUrlPolicy } from "./remote-http.js"; -type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral"; +export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral"; export async function resolveRemoteEmbeddingBearerClient(params: { provider: RemoteEmbeddingProviderId; diff --git a/src/memory/embeddings-remote-provider.ts b/src/memory/embeddings-remote-provider.ts new file mode 100644 index 00000000000..0d191af57e9 --- /dev/null +++ b/src/memory/embeddings-remote-provider.ts @@ -0,0 +1,63 @@ +import type { SsrFPolicy } from "../infra/net/ssrf.js"; +import { + resolveRemoteEmbeddingBearerClient, + type RemoteEmbeddingProviderId, +} from "./embeddings-remote-client.js"; +import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type RemoteEmbeddingClient = { + baseUrl: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + model: string; +}; + +export function createRemoteEmbeddingProvider(params: { + id: string; + client: RemoteEmbeddingClient; + errorPrefix: string; + maxInputTokens?: number; +}): EmbeddingProvider { + const { client } = params; + const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; + + const embed = async (input: string[]): Promise => { + if (input.length === 0) { + return []; + } + return await fetchRemoteEmbeddingVectors({ + url, + headers: client.headers, + ssrfPolicy: client.ssrfPolicy, + body: { model: client.model, input }, + errorPrefix: params.errorPrefix, + }); + }; + + return { + id: params.id, + model: client.model, + ...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}), + embedQuery: async (text) => { + const [vec] = await embed([text]); + return vec ?? []; + }, + embedBatch: embed, + }; +} + +export async function resolveRemoteEmbeddingClient(params: { + provider: RemoteEmbeddingProviderId; + options: EmbeddingProviderOptions; + defaultBaseUrl: string; + normalizeModel: (model: string) => string; +}): Promise { + const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ + provider: params.provider, + options: params.options, + defaultBaseUrl: params.defaultBaseUrl, + }); + const model = params.normalizeModel(params.options.model); + return { baseUrl, headers, ssrfPolicy, model }; +} diff --git a/src/shared/entry-status.ts b/src/shared/entry-status.ts index 009e90a06fc..0141e0815a9 100644 --- a/src/shared/entry-status.ts +++ b/src/shared/entry-status.ts @@ -64,3 +64,30 @@ export function evaluateEntryMetadataRequirementsForCurrentPlatform( localPlatform: process.platform, }); } + +export function evaluateEntryRequirementsForCurrentPlatform(params: { + always: boolean; + entry: { + metadata?: (RequirementsMetadata & { emoji?: string; homepage?: string }) | null; + frontmatter?: { + emoji?: string; + homepage?: string; + website?: string; + url?: string; + } | null; + }; + hasLocalBin: (bin: string) => boolean; + remote?: RequirementRemote; + isEnvSatisfied: (envName: string) => boolean; + isConfigSatisfied: (pathStr: string) => boolean; +}): ReturnType { + return evaluateEntryMetadataRequirementsForCurrentPlatform({ + always: params.always, + metadata: params.entry.metadata, + frontmatter: params.entry.frontmatter, + hasLocalBin: params.hasLocalBin, + remote: params.remote, + isEnvSatisfied: params.isEnvSatisfied, + isConfigSatisfied: params.isConfigSatisfied, + }); +}