fix(memoryFlush): correct context token accounting for flush gating (#5343)

Merged via squash.

Prepared head SHA: afaa7bae3b
Co-authored-by: jarvis-medmatic <252428873+jarvis-medmatic@users.noreply.github.com>
Co-authored-by: jalehman <550978+jalehman@users.noreply.github.com>
Reviewed-by: @jalehman
This commit is contained in:
Jarvis
2026-03-01 01:54:57 +01:00
committed by GitHub
parent 812a996b2f
commit fcb6859784
8 changed files with 478 additions and 44 deletions

View File

@@ -0,0 +1 @@
- Memory flush: fix usage-threshold gating and transcript fallback paths so flushes run reliably when expected (#5343) (thanks @jarvis-medmatic)

View File

@@ -149,6 +149,7 @@ export function derivePromptTokens(usage?: {
export function deriveSessionTotalTokens(params: {
usage?: {
input?: number;
output?: number;
total?: number;
cacheRead?: number;
cacheWrite?: number;
@@ -159,11 +160,14 @@ export function deriveSessionTotalTokens(params: {
const promptOverride = params.promptTokens;
const hasPromptOverride =
typeof promptOverride === "number" && Number.isFinite(promptOverride) && promptOverride > 0;
const usage = params.usage;
if (!usage && !hasPromptOverride) {
return undefined;
}
const input = usage?.input ?? 0;
// NOTE: SessionEntry.totalTokens is used as a prompt/context snapshot.
// It intentionally excludes completion/output tokens.
const promptTokens = hasPromptOverride
? promptOverride
: derivePromptTokens({
@@ -171,15 +175,12 @@ export function deriveSessionTotalTokens(params: {
cacheRead: usage?.cacheRead,
cacheWrite: usage?.cacheWrite,
});
let total = promptTokens ?? usage?.total ?? input;
if (!(total > 0)) {
if (!(typeof promptTokens === "number") || !Number.isFinite(promptTokens) || promptTokens <= 0) {
return undefined;
}
// NOTE: Do NOT clamp total to contextTokens here. The stored totalTokens
// should reflect the actual token count (or best estimate). Clamping causes
// /status to display contextTokens/contextTokens (100%) when the accumulated
// input exceeds the context window, hiding the real usage. The display layer
// (formatTokens in status.ts) already caps the percentage at 999%.
return total;
// Keep this value unclamped; display layers are responsible for capping
// percentages for terminal output.
return promptTokens;
}

View File

@@ -1,10 +1,25 @@
import crypto from "node:crypto";
import fs from "node:fs";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { estimateMessagesTokens } from "../../agents/compaction.js";
import { runWithModelFallback } from "../../agents/model-fallback.js";
import { isCliProvider } from "../../agents/model-selection.js";
import { runEmbeddedPiAgent } from "../../agents/pi-embedded.js";
import { resolveSandboxConfigForAgent, resolveSandboxRuntimeStatus } from "../../agents/sandbox.js";
import {
derivePromptTokens,
hasNonzeroUsage,
normalizeUsage,
type UsageLike,
} from "../../agents/usage.js";
import type { OpenClawConfig } from "../../config/config.js";
import { type SessionEntry, updateSessionStoreEntry } from "../../config/sessions.js";
import {
resolveAgentIdFromSessionKey,
resolveSessionFilePath,
resolveSessionFilePathOptions,
type SessionEntry,
updateSessionStoreEntry,
} from "../../config/sessions.js";
import { logVerbose } from "../../globals.js";
import { registerAgentRunContext } from "../../infra/agent-events.js";
import type { TemplateContext } from "../templating.js";
@@ -24,9 +39,152 @@ import {
import type { FollowupRun } from "./queue.js";
import { incrementCompactionCount } from "./session-updates.js";
/**
 * Estimates how many tokens the pending user prompt will contribute to the
 * next model turn, for memory-flush gating.
 *
 * @param prompt Raw user prompt text (may be blank or whitespace-only).
 * @returns A positive integer token estimate, or undefined when the prompt is
 *          blank or the estimator yields a non-finite/non-positive value.
 */
export function estimatePromptTokensForMemoryFlush(prompt?: string): number | undefined {
  const text = prompt?.trim();
  if (!text) {
    return undefined;
  }
  const probe: AgentMessage = { role: "user", content: text, timestamp: Date.now() };
  const estimate = estimateMessagesTokens([probe]);
  // Round up so gating errs toward flushing slightly early rather than late.
  return Number.isFinite(estimate) && estimate > 0 ? Math.ceil(estimate) : undefined;
}
/**
 * Projects the prompt/context size of the NEXT model turn for memory-flush
 * gating: the last known prompt snapshot plus the previous completion's
 * output (which becomes part of the next input) plus an estimate of the
 * pending user prompt.
 *
 * @param basePromptTokens    Last persisted/derived prompt-token snapshot.
 * @param lastOutputTokens    Output tokens of the previous completion.
 * @param promptTokenEstimate Estimated tokens for the pending user prompt.
 * @returns The projected token count; always a finite number >= 0.
 */
export function resolveEffectivePromptTokens(
  basePromptTokens?: number,
  lastOutputTokens?: number,
  promptTokenEstimate?: number,
): number {
  // Coerce missing, NaN, Infinity, or negative inputs to 0 so a single bad
  // usage reading cannot poison the gating arithmetic (Math.max(0, NaN) is
  // NaN, which would otherwise propagate into the projected total).
  const sanitize = (value?: number): number =>
    typeof value === "number" && Number.isFinite(value) && value > 0 ? value : 0;
  // Flush gating projects the next input context by adding the previous
  // completion and the current user prompt estimate.
  return sanitize(basePromptTokens) + sanitize(lastOutputTokens) + sanitize(promptTokenEstimate);
}
/** Prompt/output token usage recovered from the tail of a session transcript. */
export type SessionTranscriptUsageSnapshot = {
  // Prompt/context tokens derived from the last nonzero-usage transcript line.
  promptTokens?: number;
  // Output tokens from that same line, when present and positive.
  outputTokens?: number;
};
// Keep a generous near-threshold window so large assistant outputs still trigger
// transcript reads in time to flip memory-flush gating when needed.
const TRANSCRIPT_OUTPUT_READ_BUFFER_TOKENS = 8192;
// Tail reads scan the transcript backwards in chunks of this many bytes (64 KiB).
const TRANSCRIPT_TAIL_CHUNK_BYTES = 64 * 1024;
/**
 * Attempts to extract a normalized, nonzero usage record from one transcript
 * (JSONL) line. Blank lines, malformed JSON, and lines without nonzero usage
 * all yield undefined.
 */
function parseUsageFromTranscriptLine(line: string): ReturnType<typeof normalizeUsage> | undefined {
  const candidate = line.trim();
  if (!candidate) {
    return undefined;
  }
  try {
    const record = JSON.parse(candidate) as {
      message?: { usage?: UsageLike };
      usage?: UsageLike;
    };
    // Some transcript shapes nest usage under `message`, others keep it top-level.
    const normalized = normalizeUsage(record.message?.usage ?? record.usage);
    if (normalized && hasNonzeroUsage(normalized)) {
      return normalized;
    }
  } catch {
    // Malformed lines are expected when reading a file tail; skip silently.
  }
  return undefined;
}
/**
 * Scans a session transcript (JSONL) backwards from the end of the file and
 * returns the most recent line that parses to nonzero usage.
 *
 * Reads the tail in fixed-size chunks so large transcripts do not need to be
 * loaded fully into memory.
 *
 * @param logPath Path to the session JSONL log.
 * @returns The normalized usage from the newest matching line, or undefined
 *          when no line yields nonzero usage.
 */
async function readLastNonzeroUsageFromSessionLog(logPath: string) {
  const handle = await fs.promises.open(logPath, "r");
  try {
    const stat = await handle.stat();
    let position = stat.size;
    // Holds the (possibly partial) first line of the previously-read chunk; it
    // is re-joined with the next (earlier) chunk so lines split across chunk
    // boundaries are reassembled before parsing.
    let leadingPartial = "";
    while (position > 0) {
      const chunkSize = Math.min(TRANSCRIPT_TAIL_CHUNK_BYTES, position);
      const start = position - chunkSize;
      const buffer = Buffer.allocUnsafe(chunkSize);
      const { bytesRead } = await handle.read(buffer, 0, chunkSize, start);
      if (bytesRead <= 0) {
        break;
      }
      const chunk = buffer.toString("utf-8", 0, bytesRead);
      // The chunk ends exactly where the previous iteration's partial began,
      // so plain concatenation reconstructs the split line.
      const combined = `${chunk}${leadingPartial}`;
      const lines = combined.split(/\n+/);
      // The first split element may itself be cut off at the chunk start;
      // defer it to the next (earlier) iteration instead of parsing it now.
      leadingPartial = lines.shift() ?? "";
      // Walk the remaining lines newest-first so the latest usage wins.
      for (let i = lines.length - 1; i >= 0; i -= 1) {
        const usage = parseUsageFromTranscriptLine(lines[i] ?? "");
        if (usage) {
          return usage;
        }
      }
      position = start;
    }
    // Reached the start of the file: the leftover partial is a complete line.
    return parseUsageFromTranscriptLine(leadingPartial);
  } finally {
    await handle.close();
  }
}
/**
 * Recovers the latest prompt/output token usage from the session transcript,
 * used as a fallback when the cached SessionEntry totals are stale or missing.
 *
 * Best-effort: any path-resolution or I/O failure yields undefined instead of
 * surfacing an error into the memory-flush gating path.
 *
 * @param sessionId    Session identifier used to locate the transcript.
 * @param sessionEntry Cached session entry (may carry a sessionFile/transcriptPath).
 * @param sessionKey   Session key used to derive the owning agent id.
 * @param opts         Optional store path override for file resolution.
 * @returns A usage snapshot, or undefined when nothing usable was found.
 */
export async function readPromptTokensFromSessionLog(
  sessionId?: string,
  sessionEntry?: SessionEntry,
  sessionKey?: string,
  opts?: { storePath?: string },
): Promise<SessionTranscriptUsageSnapshot | undefined> {
  if (!sessionId) {
    return undefined;
  }
  try {
    const transcriptPath = (
      sessionEntry as (SessionEntry & { transcriptPath?: string }) | undefined
    )?.transcriptPath?.trim();
    const sessionFile = sessionEntry?.sessionFile?.trim() || transcriptPath;
    const pathOpts = resolveSessionFilePathOptions({
      agentId: resolveAgentIdFromSessionKey(sessionKey),
      storePath: opts?.storePath,
    });
    // Normalize sessionFile through resolveSessionFilePath so relative entries
    // are resolved against the sessions dir/store layout, not process.cwd().
    const logPath = resolveSessionFilePath(
      sessionId,
      sessionFile ? { sessionFile } : sessionEntry,
      pathOpts,
    );
    const lastUsage = await readLastNonzeroUsageFromSessionLog(logPath);
    if (!lastUsage) {
      return undefined;
    }
    const promptTokens = derivePromptTokens(lastUsage);
    const rawOutput = lastUsage.output;
    const outputTokens =
      typeof rawOutput === "number" && Number.isFinite(rawOutput) && rawOutput > 0
        ? rawOutput
        : undefined;
    const hasPrompt = typeof promptTokens === "number";
    const hasOutput = typeof outputTokens === "number";
    if (!hasPrompt && !hasOutput) {
      return undefined;
    }
    return { promptTokens, outputTokens };
  } catch {
    return undefined;
  }
}
export async function runMemoryFlushIfNeeded(params: {
cfg: OpenClawConfig;
followupRun: FollowupRun;
promptForEstimate?: string;
sessionCtx: TemplateContext;
opts?: GetReplyOptions;
defaultModel: string;
@@ -58,28 +216,156 @@ export async function runMemoryFlushIfNeeded(params: {
return sandboxCfg.workspaceAccess === "rw";
})();
const isCli = isCliProvider(params.followupRun.run.provider, params.cfg);
const canAttemptFlush = memoryFlushWritable && !params.isHeartbeat && !isCli;
let entry =
params.sessionEntry ??
(params.sessionKey ? params.sessionStore?.[params.sessionKey] : undefined);
const contextWindowTokens = resolveMemoryFlushContextWindowTokens({
modelId: params.followupRun.run.model ?? params.defaultModel,
agentCfgContextTokens: params.agentCfgContextTokens,
});
const promptTokenEstimate = estimatePromptTokensForMemoryFlush(
params.promptForEstimate ?? params.followupRun.prompt,
);
const persistedPromptTokensRaw = entry?.totalTokens;
const persistedPromptTokens =
typeof persistedPromptTokensRaw === "number" &&
Number.isFinite(persistedPromptTokensRaw) &&
persistedPromptTokensRaw > 0
? persistedPromptTokensRaw
: undefined;
const hasFreshPersistedPromptTokens =
typeof persistedPromptTokens === "number" && entry?.totalTokensFresh === true;
const flushThreshold =
contextWindowTokens -
memoryFlushSettings.reserveTokensFloor -
memoryFlushSettings.softThresholdTokens;
// When totals are stale/unknown, derive prompt + last output from transcript so memory
// flush can still be evaluated against projected next-input size.
//
// When totals are fresh, only read the transcript when we're close enough to the
// threshold that missing the last output tokens could flip the decision.
const shouldReadTranscriptForOutput =
canAttemptFlush &&
entry &&
hasFreshPersistedPromptTokens &&
typeof promptTokenEstimate === "number" &&
Number.isFinite(promptTokenEstimate) &&
flushThreshold > 0 &&
(persistedPromptTokens ?? 0) + promptTokenEstimate >=
flushThreshold - TRANSCRIPT_OUTPUT_READ_BUFFER_TOKENS;
const shouldReadTranscript =
canAttemptFlush && entry && (!hasFreshPersistedPromptTokens || shouldReadTranscriptForOutput);
const transcriptUsageSnapshot = shouldReadTranscript
? await readPromptTokensFromSessionLog(
params.followupRun.run.sessionId,
entry,
params.sessionKey ?? params.followupRun.run.sessionKey,
{ storePath: params.storePath },
)
: undefined;
const transcriptPromptTokens = transcriptUsageSnapshot?.promptTokens;
const transcriptOutputTokens = transcriptUsageSnapshot?.outputTokens;
const hasReliableTranscriptPromptTokens =
typeof transcriptPromptTokens === "number" &&
Number.isFinite(transcriptPromptTokens) &&
transcriptPromptTokens > 0;
const shouldPersistTranscriptPromptTokens =
hasReliableTranscriptPromptTokens &&
(!hasFreshPersistedPromptTokens ||
(transcriptPromptTokens ?? 0) > (persistedPromptTokens ?? 0));
if (entry && shouldPersistTranscriptPromptTokens) {
const nextEntry = {
...entry,
totalTokens: transcriptPromptTokens,
totalTokensFresh: true,
};
entry = nextEntry;
if (params.sessionKey && params.sessionStore) {
params.sessionStore[params.sessionKey] = nextEntry;
}
if (params.storePath && params.sessionKey) {
try {
const updatedEntry = await updateSessionStoreEntry({
storePath: params.storePath,
sessionKey: params.sessionKey,
update: async () => ({ totalTokens: transcriptPromptTokens, totalTokensFresh: true }),
});
if (updatedEntry) {
entry = updatedEntry;
if (params.sessionStore) {
params.sessionStore[params.sessionKey] = updatedEntry;
}
}
} catch (err) {
logVerbose(`failed to persist derived prompt totalTokens: ${String(err)}`);
}
}
}
const promptTokensSnapshot = Math.max(
hasFreshPersistedPromptTokens ? (persistedPromptTokens ?? 0) : 0,
hasReliableTranscriptPromptTokens ? (transcriptPromptTokens ?? 0) : 0,
);
const hasFreshPromptTokensSnapshot =
promptTokensSnapshot > 0 &&
(hasFreshPersistedPromptTokens || hasReliableTranscriptPromptTokens);
const projectedTokenCount = hasFreshPromptTokensSnapshot
? resolveEffectivePromptTokens(
promptTokensSnapshot,
transcriptOutputTokens,
promptTokenEstimate,
)
: undefined;
const tokenCountForFlush =
typeof projectedTokenCount === "number" &&
Number.isFinite(projectedTokenCount) &&
projectedTokenCount > 0
? projectedTokenCount
: undefined;
// Diagnostic logging to understand why memory flush may not trigger.
logVerbose(
`memoryFlush check: sessionKey=${params.sessionKey} ` +
`tokenCount=${tokenCountForFlush ?? "undefined"} ` +
`contextWindow=${contextWindowTokens} threshold=${flushThreshold} ` +
`isHeartbeat=${params.isHeartbeat} isCli=${isCli} memoryFlushWritable=${memoryFlushWritable} ` +
`compactionCount=${entry?.compactionCount ?? 0} memoryFlushCompactionCount=${entry?.memoryFlushCompactionCount ?? "undefined"} ` +
`persistedPromptTokens=${persistedPromptTokens ?? "undefined"} persistedFresh=${entry?.totalTokensFresh === true} ` +
`promptTokensEst=${promptTokenEstimate ?? "undefined"} transcriptPromptTokens=${transcriptPromptTokens ?? "undefined"} transcriptOutputTokens=${transcriptOutputTokens ?? "undefined"} ` +
`projectedTokenCount=${projectedTokenCount ?? "undefined"}`,
);
const shouldFlushMemory =
memoryFlushSettings &&
memoryFlushWritable &&
!params.isHeartbeat &&
!isCliProvider(params.followupRun.run.provider, params.cfg) &&
!isCli &&
shouldRunMemoryFlush({
entry:
params.sessionEntry ??
(params.sessionKey ? params.sessionStore?.[params.sessionKey] : undefined),
contextWindowTokens: resolveMemoryFlushContextWindowTokens({
modelId: params.followupRun.run.model ?? params.defaultModel,
agentCfgContextTokens: params.agentCfgContextTokens,
}),
entry,
tokenCount: tokenCountForFlush,
contextWindowTokens,
reserveTokensFloor: memoryFlushSettings.reserveTokensFloor,
softThresholdTokens: memoryFlushSettings.softThresholdTokens,
});
if (!shouldFlushMemory) {
return params.sessionEntry;
return entry ?? params.sessionEntry;
}
let activeSessionEntry = params.sessionEntry;
logVerbose(
`memoryFlush triggered: sessionKey=${params.sessionKey} tokenCount=${tokenCountForFlush ?? "undefined"} threshold=${flushThreshold}`,
);
let activeSessionEntry = entry ?? params.sessionEntry;
const activeSessionStore = params.sessionStore;
const flushRunId = crypto.randomUUID();
if (params.sessionKey) {

View File

@@ -1492,6 +1492,66 @@ describe("runReplyAgent memory flush", () => {
});
});
it("uses configured prompts for memory flush runs", async () => {
await withTempStore(async (storePath) => {
const sessionKey = "main";
const sessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
totalTokens: 80_000,
compactionCount: 1,
};
await seedSessionStore({ storePath, sessionKey, entry: sessionEntry });
const calls: Array<EmbeddedRunParams> = [];
state.runEmbeddedPiAgentMock.mockImplementation(async (params: EmbeddedRunParams) => {
calls.push(params);
if (params.prompt?.includes("Write notes.")) {
return { payloads: [], meta: {} };
}
return {
payloads: [{ text: "ok" }],
meta: { agentMeta: { usage: { input: 1, output: 1 } } },
};
});
const baseRun = createBaseRun({
storePath,
sessionEntry,
config: {
agents: {
defaults: {
compaction: {
memoryFlush: {
prompt: "Write notes.",
systemPrompt: "Flush memory now.",
},
},
},
},
},
runOverrides: { extraSystemPrompt: "extra system" },
});
await runReplyAgentWithBase({
baseRun,
storePath,
sessionKey,
sessionEntry,
commandBody: "hello",
});
const flushCall = calls[0];
expect(flushCall?.prompt).toContain("Write notes.");
expect(flushCall?.prompt).toContain("NO_REPLY");
expect(flushCall?.extraSystemPrompt).toContain("extra system");
expect(flushCall?.extraSystemPrompt).toContain("Flush memory now.");
expect(flushCall?.extraSystemPrompt).toContain("NO_REPLY");
expect(calls[1]?.prompt).toBe("hello");
});
});
it("runs a memory flush turn and updates session metadata", async () => {
await withTempStore(async (storePath) => {
const sessionKey = "main";
@@ -1541,6 +1601,66 @@ describe("runReplyAgent memory flush", () => {
});
});
it("runs memory flush when transcript fallback uses a relative sessionFile path", async () => {
await withTempStore(async (storePath) => {
const sessionKey = "main";
const sessionFile = "session-relative.jsonl";
const transcriptPath = path.join(path.dirname(storePath), sessionFile);
await fs.mkdir(path.dirname(transcriptPath), { recursive: true });
await fs.writeFile(
transcriptPath,
JSON.stringify({ usage: { input: 90_000, output: 8_000 } }),
"utf-8",
);
const sessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
sessionFile,
totalTokens: 10,
totalTokensFresh: false,
compactionCount: 1,
};
await seedSessionStore({ storePath, sessionKey, entry: sessionEntry });
const calls: Array<{ prompt?: string }> = [];
state.runEmbeddedPiAgentMock.mockImplementation(async (params: EmbeddedRunParams) => {
calls.push({ prompt: params.prompt });
if (params.prompt?.includes("Pre-compaction memory flush.")) {
return { payloads: [], meta: {} };
}
return {
payloads: [{ text: "ok" }],
meta: { agentMeta: { usage: { input: 1, output: 1 } } },
};
});
const baseRun = createBaseRun({
storePath,
sessionEntry,
runOverrides: { sessionFile },
});
await runReplyAgentWithBase({
baseRun,
storePath,
sessionKey,
sessionEntry,
commandBody: "hello",
});
expect(calls).toHaveLength(2);
expect(calls[0]?.prompt).toContain("Pre-compaction memory flush.");
expect(calls[0]?.prompt).toContain("Current time:");
expect(calls[0]?.prompt).toMatch(/memory\/\d{4}-\d{2}-\d{2}\.md/);
expect(calls[1]?.prompt).toBe("hello");
const stored = JSON.parse(await fs.readFile(storePath, "utf-8"));
expect(stored[sessionKey].memoryFlushAt).toBeTypeOf("number");
});
});
it("skips memory flush when disabled in config", async () => {
await withTempStore(async (storePath) => {
const sessionKey = "main";

View File

@@ -251,6 +251,7 @@ export async function runReplyAgent(params: {
activeSessionEntry = await runMemoryFlushIfNeeded({
cfg,
followupRun,
promptForEstimate: followupRun.prompt,
sessionCtx,
opts,
defaultModel,

View File

@@ -115,11 +115,27 @@ export function shouldRunMemoryFlush(params: {
SessionEntry,
"totalTokens" | "totalTokensFresh" | "compactionCount" | "memoryFlushCompactionCount"
>;
/**
* Optional token count override for flush gating. When provided, this value is
* treated as a fresh context snapshot and used instead of the cached
* SessionEntry.totalTokens (which may be stale/unknown).
*/
tokenCount?: number;
contextWindowTokens: number;
reserveTokensFloor: number;
softThresholdTokens: number;
}): boolean {
const totalTokens = resolveFreshSessionTotalTokens(params.entry);
if (!params.entry) {
return false;
}
const override = params.tokenCount;
const overrideTokens =
typeof override === "number" && Number.isFinite(override) && override > 0
? Math.floor(override)
: undefined;
const totalTokens = overrideTokens ?? resolveFreshSessionTotalTokens(params.entry);
if (!totalTokens || totalTokens <= 0) {
return false;
}
@@ -134,8 +150,8 @@ export function shouldRunMemoryFlush(params: {
return false;
}
const compactionCount = params.entry?.compactionCount ?? 0;
const lastFlushAt = params.entry?.memoryFlushCompactionCount;
const compactionCount = params.entry.compactionCount ?? 0;
const lastFlushAt = params.entry.memoryFlushCompactionCount;
if (typeof lastFlushAt === "number" && lastFlushAt === compactionCount) {
return false;
}

View File

@@ -79,16 +79,20 @@ export async function updateSessionStoreAfterAgentRun(params: {
if (hasNonzeroUsage(usage)) {
const input = usage.input ?? 0;
const output = usage.output ?? 0;
const totalTokens =
deriveSessionTotalTokens({
usage,
contextTokens,
promptTokens,
}) ?? input;
const totalTokens = deriveSessionTotalTokens({
usage,
contextTokens,
promptTokens,
});
next.inputTokens = input;
next.outputTokens = output;
next.totalTokens = totalTokens;
next.totalTokensFresh = true;
if (typeof totalTokens === "number" && Number.isFinite(totalTokens) && totalTokens > 0) {
next.totalTokens = totalTokens;
next.totalTokensFresh = true;
} else {
next.totalTokens = undefined;
next.totalTokensFresh = false;
}
next.cacheRead = usage.cacheRead ?? 0;
next.cacheWrite = usage.cacheWrite ?? 0;
}

View File

@@ -514,27 +514,32 @@ export async function runCronIsolatedAgentTurn(params: {
if (hasNonzeroUsage(usage)) {
const input = usage.input ?? 0;
const output = usage.output ?? 0;
const totalTokens =
deriveSessionTotalTokens({
usage,
contextTokens,
promptTokens,
}) ?? input;
const totalTokens = deriveSessionTotalTokens({
usage,
contextTokens,
promptTokens,
});
cronSession.sessionEntry.inputTokens = input;
cronSession.sessionEntry.outputTokens = output;
cronSession.sessionEntry.totalTokens = totalTokens;
cronSession.sessionEntry.totalTokensFresh = true;
const telemetryUsage: NonNullable<CronRunTelemetry["usage"]> = {
input_tokens: input,
output_tokens: output,
};
if (typeof totalTokens === "number" && Number.isFinite(totalTokens) && totalTokens > 0) {
cronSession.sessionEntry.totalTokens = totalTokens;
cronSession.sessionEntry.totalTokensFresh = true;
telemetryUsage.total_tokens = totalTokens;
} else {
cronSession.sessionEntry.totalTokens = undefined;
cronSession.sessionEntry.totalTokensFresh = false;
}
cronSession.sessionEntry.cacheRead = usage.cacheRead ?? 0;
cronSession.sessionEntry.cacheWrite = usage.cacheWrite ?? 0;
telemetry = {
model: modelUsed,
provider: providerUsed,
usage: {
input_tokens: input,
output_tokens: output,
total_tokens: totalTokens,
},
usage: telemetryUsage,
};
} else {
telemetry = {