mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-23 23:33:49 +00:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0155a01bb1 | ||
|
|
cfeee5d511 | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
10e0ea1309 | ||
|
|
0bd221ff41 | ||
|
|
5fda6f8ef3 | ||
|
|
9b956f6338 | ||
|
|
09923f654c | ||
|
|
ae7b972649 | ||
|
|
47885e3710 | ||
|
|
4b9a260b37 | ||
|
|
aea337cfe2 | ||
|
|
811f8f8b4f | ||
|
|
27734a23b1 | ||
|
|
1b8e538a77 | ||
|
|
41c2385aca |
@@ -28,4 +28,6 @@ bin/*
|
|||||||
.claude/*
|
.claude/*
|
||||||
.vscode/*
|
.vscode/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.bmad/*
|
.agent/*
|
||||||
|
.bmad/*
|
||||||
|
_bmad/*
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -32,7 +32,9 @@ GEMINI.md
|
|||||||
.vscode/*
|
.vscode/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.serena/*
|
.serena/*
|
||||||
|
.agent/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
|
_bmad/*
|
||||||
.mcp/cache/
|
.mcp/cache/
|
||||||
|
|
||||||
# macOS
|
# macOS
|
||||||
|
|||||||
@@ -242,6 +242,21 @@ func GetGeminiVertexModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-pro-image-preview",
|
ID: "gemini-3-pro-image-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -317,7 +332,22 @@ func GetGeminiCLIModels() []*ModelInfo {
|
|||||||
Name: "models/gemini-3-pro-preview",
|
Name: "models/gemini-3-pro-preview",
|
||||||
Version: "3.0",
|
Version: "3.0",
|
||||||
DisplayName: "Gemini 3 Pro Preview",
|
DisplayName: "Gemini 3 Pro Preview",
|
||||||
Description: "Gemini 3 Pro Preview",
|
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||||
InputTokenLimit: 1048576,
|
InputTokenLimit: 1048576,
|
||||||
OutputTokenLimit: 65536,
|
OutputTokenLimit: 65536,
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
@@ -389,6 +419,21 @@ func GetAIStudioModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-pro-latest",
|
ID: "gemini-pro-latest",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
|
|||||||
@@ -73,6 +73,10 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
|||||||
|
|
||||||
// Execute performs a non-streaming request to the Antigravity API.
|
// Execute performs a non-streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
if strings.Contains(req.Model, "claude") {
|
||||||
|
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||||
|
}
|
||||||
|
|
||||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return resp, errToken
|
return resp, errToken
|
||||||
@@ -164,6 +168,336 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// executeClaudeNonStream performs a claude non-streaming request to the Antigravity API.
|
||||||
|
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||||
|
if errToken != nil {
|
||||||
|
return resp, errToken
|
||||||
|
}
|
||||||
|
if updatedAuth != nil {
|
||||||
|
auth = updatedAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("antigravity")
|
||||||
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
|
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||||
|
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||||
|
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||||
|
|
||||||
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
|
var lastStatus int
|
||||||
|
var lastBody []byte
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for idx, baseURL := range baseURLs {
|
||||||
|
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
||||||
|
if errReq != nil {
|
||||||
|
err = errReq
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errDo
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errDo
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errRead
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
go func(resp *http.Response) {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
|
// Filter usage metadata for all models
|
||||||
|
// Only retain usage statistics in the terminal chunk
|
||||||
|
line = FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
|
payload := jsonPayload(line)
|
||||||
|
if payload == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
}
|
||||||
|
}(httpResp)
|
||||||
|
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
for chunk := range out {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
return resp, chunk.Err
|
||||||
|
}
|
||||||
|
if len(chunk.Payload) > 0 {
|
||||||
|
_, _ = buffer.Write(chunk.Payload)
|
||||||
|
_, _ = buffer.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||||
|
|
||||||
|
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case lastStatus != 0:
|
||||||
|
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||||
|
responseTemplate := ""
|
||||||
|
var traceID string
|
||||||
|
var finishReason string
|
||||||
|
var modelVersion string
|
||||||
|
var responseID string
|
||||||
|
var role string
|
||||||
|
var usageRaw string
|
||||||
|
parts := make([]map[string]interface{}, 0)
|
||||||
|
var pendingKind string
|
||||||
|
var pendingText strings.Builder
|
||||||
|
var pendingThoughtSig string
|
||||||
|
|
||||||
|
flushPending := func() {
|
||||||
|
if pendingKind == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text := pendingText.String()
|
||||||
|
switch pendingKind {
|
||||||
|
case "text":
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
|
pendingKind = ""
|
||||||
|
pendingText.Reset()
|
||||||
|
pendingThoughtSig = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parts = append(parts, map[string]interface{}{"text": text})
|
||||||
|
case "thought":
|
||||||
|
if strings.TrimSpace(text) == "" && pendingThoughtSig == "" {
|
||||||
|
pendingKind = ""
|
||||||
|
pendingText.Reset()
|
||||||
|
pendingThoughtSig = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
part := map[string]interface{}{"thought": true}
|
||||||
|
part["text"] = text
|
||||||
|
if pendingThoughtSig != "" {
|
||||||
|
part["thoughtSignature"] = pendingThoughtSig
|
||||||
|
}
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
pendingKind = ""
|
||||||
|
pendingText.Reset()
|
||||||
|
pendingThoughtSig = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizePart := func(partResult gjson.Result) map[string]interface{} {
|
||||||
|
var m map[string]interface{}
|
||||||
|
_ = json.Unmarshal([]byte(partResult.Raw), &m)
|
||||||
|
if m == nil {
|
||||||
|
m = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
sig := partResult.Get("thoughtSignature").String()
|
||||||
|
if sig == "" {
|
||||||
|
sig = partResult.Get("thought_signature").String()
|
||||||
|
}
|
||||||
|
if sig != "" {
|
||||||
|
m["thoughtSignature"] = sig
|
||||||
|
delete(m, "thought_signature")
|
||||||
|
}
|
||||||
|
if inlineData, ok := m["inline_data"]; ok {
|
||||||
|
m["inlineData"] = inlineData
|
||||||
|
delete(m, "inline_data")
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range bytes.Split(stream, []byte("\n")) {
|
||||||
|
trimmed := bytes.TrimSpace(line)
|
||||||
|
if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
root := gjson.ParseBytes(trimmed)
|
||||||
|
responseNode := root.Get("response")
|
||||||
|
if !responseNode.Exists() {
|
||||||
|
if root.Get("candidates").Exists() {
|
||||||
|
responseNode = root
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseTemplate = responseNode.Raw
|
||||||
|
|
||||||
|
if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" {
|
||||||
|
traceID = traceResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() {
|
||||||
|
role = roleResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" {
|
||||||
|
finishReason = finishResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" {
|
||||||
|
modelVersion = modelResult.String()
|
||||||
|
}
|
||||||
|
if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" {
|
||||||
|
responseID = responseIDResult.String()
|
||||||
|
}
|
||||||
|
if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
|
||||||
|
usageRaw = usageResult.Raw
|
||||||
|
} else if usageResult := root.Get("usageMetadata"); usageResult.Exists() {
|
||||||
|
usageRaw = usageResult.Raw
|
||||||
|
}
|
||||||
|
|
||||||
|
if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
|
||||||
|
for _, part := range partsResult.Array() {
|
||||||
|
hasFunctionCall := part.Get("functionCall").Exists()
|
||||||
|
hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists()
|
||||||
|
sig := part.Get("thoughtSignature").String()
|
||||||
|
if sig == "" {
|
||||||
|
sig = part.Get("thought_signature").String()
|
||||||
|
}
|
||||||
|
text := part.Get("text").String()
|
||||||
|
thought := part.Get("thought").Bool()
|
||||||
|
|
||||||
|
if hasFunctionCall || hasInlineData {
|
||||||
|
flushPending()
|
||||||
|
parts = append(parts, normalizePart(part))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if thought || part.Get("text").Exists() {
|
||||||
|
kind := "text"
|
||||||
|
if thought {
|
||||||
|
kind = "thought"
|
||||||
|
}
|
||||||
|
if pendingKind != "" && pendingKind != kind {
|
||||||
|
flushPending()
|
||||||
|
}
|
||||||
|
pendingKind = kind
|
||||||
|
pendingText.WriteString(text)
|
||||||
|
if kind == "thought" && sig != "" {
|
||||||
|
pendingThoughtSig = sig
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
flushPending()
|
||||||
|
parts = append(parts, normalizePart(part))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushPending()
|
||||||
|
|
||||||
|
if responseTemplate == "" {
|
||||||
|
responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}`
|
||||||
|
}
|
||||||
|
|
||||||
|
partsJSON, _ := json.Marshal(parts)
|
||||||
|
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
||||||
|
if role != "" {
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
||||||
|
}
|
||||||
|
if finishReason != "" {
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
||||||
|
}
|
||||||
|
if modelVersion != "" {
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
||||||
|
}
|
||||||
|
if responseID != "" {
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
||||||
|
}
|
||||||
|
if usageRaw != "" {
|
||||||
|
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
||||||
|
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
||||||
|
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := `{"response":{},"traceId":""}`
|
||||||
|
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
||||||
|
if traceID != "" {
|
||||||
|
output, _ = sjson.Set(output, "traceId", traceID)
|
||||||
|
}
|
||||||
|
return []byte(output)
|
||||||
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
ctx = context.WithValue(ctx, "alt", "")
|
ctx = context.WithValue(ctx, "alt", "")
|
||||||
@@ -549,27 +883,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||||
}
|
}
|
||||||
|
|
||||||
strJSON = util.DeleteKey(strJSON, "$schema")
|
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||||
strJSON = util.DeleteKey(strJSON, "maxItems")
|
// const->enum conversion, and flattening of types/anyOf.
|
||||||
strJSON = util.DeleteKey(strJSON, "minItems")
|
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||||
strJSON = util.DeleteKey(strJSON, "minLength")
|
|
||||||
strJSON = util.DeleteKey(strJSON, "maxLength")
|
|
||||||
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
|
||||||
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
|
|
||||||
strJSON = util.DeleteKey(strJSON, "$ref")
|
|
||||||
strJSON = util.DeleteKey(strJSON, "$defs")
|
|
||||||
|
|
||||||
paths = make([]string, 0)
|
|
||||||
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
|
||||||
for _, p := range paths {
|
|
||||||
anyOf := gjson.Get(strJSON, p)
|
|
||||||
if anyOf.IsArray() {
|
|
||||||
anyOfItems := anyOf.Array()
|
|
||||||
if len(anyOfItems) > 0 {
|
|
||||||
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = []byte(strJSON)
|
payload = []byte(strJSON)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ package claude
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -350,24 +349,25 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
response := map[string]interface{}{
|
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
"id": root.Get("response.responseId").String(),
|
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
||||||
"type": "message",
|
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||||
"role": "assistant",
|
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||||
"model": root.Get("response.modelVersion").String(),
|
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||||
"content": []interface{}{},
|
|
||||||
"stop_reason": nil,
|
contentArrayInitialized := false
|
||||||
"stop_sequence": nil,
|
ensureContentArray := func() {
|
||||||
"usage": map[string]interface{}{
|
if contentArrayInitialized {
|
||||||
"input_tokens": promptTokens,
|
return
|
||||||
"output_tokens": outputTokens,
|
}
|
||||||
},
|
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
||||||
|
contentArrayInitialized = true
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := root.Get("response.candidates.0.content.parts")
|
parts := root.Get("response.candidates.0.content.parts")
|
||||||
var contentBlocks []interface{}
|
|
||||||
textBuilder := strings.Builder{}
|
textBuilder := strings.Builder{}
|
||||||
thinkingBuilder := strings.Builder{}
|
thinkingBuilder := strings.Builder{}
|
||||||
|
thinkingSignature := ""
|
||||||
toolIDCounter := 0
|
toolIDCounter := 0
|
||||||
hasToolCall := false
|
hasToolCall := false
|
||||||
|
|
||||||
@@ -375,28 +375,43 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
if textBuilder.Len() == 0 {
|
if textBuilder.Len() == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
ensureContentArray()
|
||||||
"type": "text",
|
block := `{"type":"text","text":""}`
|
||||||
"text": textBuilder.String(),
|
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||||
})
|
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||||
textBuilder.Reset()
|
textBuilder.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
flushThinking := func() {
|
flushThinking := func() {
|
||||||
if thinkingBuilder.Len() == 0 {
|
if thinkingBuilder.Len() == 0 && thinkingSignature == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
ensureContentArray()
|
||||||
"type": "thinking",
|
block := `{"type":"thinking","thinking":""}`
|
||||||
"thinking": thinkingBuilder.String(),
|
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||||
})
|
if thinkingSignature != "" {
|
||||||
|
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
||||||
|
}
|
||||||
|
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||||
thinkingBuilder.Reset()
|
thinkingBuilder.Reset()
|
||||||
|
thinkingSignature = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if parts.IsArray() {
|
if parts.IsArray() {
|
||||||
for _, part := range parts.Array() {
|
for _, part := range parts.Array() {
|
||||||
|
isThought := part.Get("thought").Bool()
|
||||||
|
if isThought {
|
||||||
|
sig := part.Get("thoughtSignature")
|
||||||
|
if !sig.Exists() {
|
||||||
|
sig = part.Get("thought_signature")
|
||||||
|
}
|
||||||
|
if sig.Exists() && sig.String() != "" {
|
||||||
|
thinkingSignature = sig.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if text := part.Get("text"); text.Exists() && text.String() != "" {
|
if text := part.Get("text"); text.Exists() && text.String() != "" {
|
||||||
if part.Get("thought").Bool() {
|
if isThought {
|
||||||
flushText()
|
flushText()
|
||||||
thinkingBuilder.WriteString(text.String())
|
thinkingBuilder.WriteString(text.String())
|
||||||
continue
|
continue
|
||||||
@@ -413,21 +428,16 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
|
|
||||||
name := functionCall.Get("name").String()
|
name := functionCall.Get("name").String()
|
||||||
toolIDCounter++
|
toolIDCounter++
|
||||||
toolBlock := map[string]interface{}{
|
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||||
"type": "tool_use",
|
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||||
"id": fmt.Sprintf("tool_%d", toolIDCounter),
|
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||||
"name": name,
|
|
||||||
"input": map[string]interface{}{},
|
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) {
|
||||||
|
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if args := functionCall.Get("args"); args.Exists() {
|
ensureContentArray()
|
||||||
var parsed interface{}
|
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
||||||
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
|
|
||||||
toolBlock["input"] = parsed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
contentBlocks = append(contentBlocks, toolBlock)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -436,8 +446,6 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
flushThinking()
|
flushThinking()
|
||||||
flushText()
|
flushText()
|
||||||
|
|
||||||
response["content"] = contentBlocks
|
|
||||||
|
|
||||||
stopReason := "end_turn"
|
stopReason := "end_turn"
|
||||||
if hasToolCall {
|
if hasToolCall {
|
||||||
stopReason = "tool_use"
|
stopReason = "tool_use"
|
||||||
@@ -453,19 +461,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response["stop_reason"] = stopReason
|
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
||||||
|
|
||||||
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
|
if promptTokens == 0 && outputTokens == 0 {
|
||||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||||
delete(response, "usage")
|
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
encoded, err := json.Marshal(response)
|
return responseJSON
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return string(encoded)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ type Params struct {
|
|||||||
HasFirstResponse bool
|
HasFirstResponse bool
|
||||||
ResponseType int
|
ResponseType int
|
||||||
ResponseIndex int
|
ResponseIndex int
|
||||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||||
}
|
}
|
||||||
|
|
||||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||||
@@ -179,6 +179,18 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
|||||||
usedTool = true
|
usedTool = true
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
|
|
||||||
|
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||||
|
// If we are already in tool use mode and name is empty, treat as continuation (delta).
|
||||||
|
if (*param).(*Params).ResponseType == 3 && fcName == "" {
|
||||||
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
}
|
||||||
|
// Continue to next part without closing/opening logic
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Handle state transitions when switching to function calls
|
// Handle state transitions when switching to function calls
|
||||||
// Close any existing function call block first
|
// Close any existing function call block first
|
||||||
if (*param).(*Params).ResponseType == 3 {
|
if (*param).(*Params).ResponseType == 3 {
|
||||||
|
|||||||
496
internal/util/gemini_schema.go
Normal file
496
internal/util/gemini_schema.go
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
// Package util provides utility functions for the CLI Proxy API server.
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||||
|
|
||||||
|
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API.
|
||||||
|
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||||
|
// semantic information as description hints.
|
||||||
|
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||||
|
// Phase 1: Convert and add hints
|
||||||
|
jsonStr = convertRefsToHints(jsonStr)
|
||||||
|
jsonStr = convertConstToEnum(jsonStr)
|
||||||
|
jsonStr = addEnumHints(jsonStr)
|
||||||
|
jsonStr = addAdditionalPropertiesHints(jsonStr)
|
||||||
|
jsonStr = moveConstraintsToDescription(jsonStr)
|
||||||
|
|
||||||
|
// Phase 2: Flatten complex structures
|
||||||
|
jsonStr = mergeAllOf(jsonStr)
|
||||||
|
jsonStr = flattenAnyOfOneOf(jsonStr)
|
||||||
|
jsonStr = flattenTypeArrays(jsonStr)
|
||||||
|
|
||||||
|
// Phase 3: Cleanup
|
||||||
|
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||||
|
jsonStr = cleanupRequiredFields(jsonStr)
|
||||||
|
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertRefsToHints converts $ref to description hints (Lazy Hint strategy).
// Rather than inlining the referenced definition (which may be cyclic), the
// node holding the $ref is replaced by a plain object whose description names
// the referenced definition, preserving any existing description text.
func convertRefsToHints(jsonStr string) string {
	paths := findPaths(jsonStr, "$ref")
	// Deepest paths first so replacing a parent never invalidates a pending path.
	sortByDepth(paths)

	for _, p := range paths {
		refVal := gjson.Get(jsonStr, p).String()
		defName := refVal
		// Keep only the final segment of e.g. "#/definitions/User".
		if idx := strings.LastIndex(refVal, "/"); idx >= 0 {
			defName = refVal[idx+1:]
		}

		parentPath := trimSuffix(p, ".$ref")
		hint := fmt.Sprintf("See: %s", defName)
		if existing := gjson.Get(jsonStr, descriptionPath(parentPath)).String(); existing != "" {
			hint = fmt.Sprintf("%s (%s)", existing, hint)
		}

		// sjson.Set on "description" performs proper JSON string escaping.
		replacement := `{"type":"object","description":""}`
		replacement, _ = sjson.Set(replacement, "description", hint)
		jsonStr = setRawAt(jsonStr, parentPath, replacement)
	}
	return jsonStr
}
|
||||||
|
|
||||||
|
func convertConstToEnum(jsonStr string) string {
|
||||||
|
for _, p := range findPaths(jsonStr, "const") {
|
||||||
|
val := gjson.Get(jsonStr, p)
|
||||||
|
if !val.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
enumPath := trimSuffix(p, ".const") + ".enum"
|
||||||
|
if !gjson.Get(jsonStr, enumPath).Exists() {
|
||||||
|
jsonStr, _ = sjson.Set(jsonStr, enumPath, []interface{}{val.Value()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func addEnumHints(jsonStr string) string {
|
||||||
|
for _, p := range findPaths(jsonStr, "enum") {
|
||||||
|
arr := gjson.Get(jsonStr, p)
|
||||||
|
if !arr.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
items := arr.Array()
|
||||||
|
if len(items) <= 1 || len(items) > 10 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var vals []string
|
||||||
|
for _, item := range items {
|
||||||
|
vals = append(vals, item.String())
|
||||||
|
}
|
||||||
|
jsonStr = appendHint(jsonStr, trimSuffix(p, ".enum"), "Allowed: "+strings.Join(vals, ", "))
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func addAdditionalPropertiesHints(jsonStr string) string {
|
||||||
|
for _, p := range findPaths(jsonStr, "additionalProperties") {
|
||||||
|
if gjson.Get(jsonStr, p).Type == gjson.False {
|
||||||
|
jsonStr = appendHint(jsonStr, trimSuffix(p, ".additionalProperties"), "No extra properties allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// unsupportedConstraints lists JSON Schema validation keywords that the
// Gemini API rejects; their values are preserved as description hints before
// the keywords are removed.
var unsupportedConstraints = []string{
	"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
	"pattern", "minItems", "maxItems",
}
|
||||||
|
|
||||||
|
func moveConstraintsToDescription(jsonStr string) string {
|
||||||
|
for _, key := range unsupportedConstraints {
|
||||||
|
for _, p := range findPaths(jsonStr, key) {
|
||||||
|
val := gjson.Get(jsonStr, p)
|
||||||
|
if !val.Exists() || val.IsObject() || val.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parentPath := trimSuffix(p, "."+key)
|
||||||
|
if isPropertyDefinition(parentPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jsonStr = appendHint(jsonStr, parentPath, fmt.Sprintf("%s: %s", key, val.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeAllOf folds every allOf branch into its parent by unioning the
// branches' "properties" and "required" lists, then deletes the allOf key.
func mergeAllOf(jsonStr string) string {
	paths := findPaths(jsonStr, "allOf")
	// Deepest first so nested allOf blocks merge before their ancestors.
	sortByDepth(paths)

	for _, p := range paths {
		allOf := gjson.Get(jsonStr, p)
		if !allOf.IsArray() {
			continue
		}
		parentPath := trimSuffix(p, ".allOf")

		for _, item := range allOf.Array() {
			// Copy each branch property up into the parent's properties map.
			if props := item.Get("properties"); props.IsObject() {
				props.ForEach(func(key, value gjson.Result) bool {
					// Escape keys so names containing "." are not split into
					// nested paths by sjson. The closure intentionally
					// mutates the captured jsonStr.
					destPath := joinPath(parentPath, "properties."+escapeGJSONPathKey(key.String()))
					jsonStr, _ = sjson.SetRaw(jsonStr, destPath, value.Raw)
					return true
				})
			}
			// Union the branch's required list into the parent's, dedup'd.
			if req := item.Get("required"); req.IsArray() {
				reqPath := joinPath(parentPath, "required")
				current := getStrings(jsonStr, reqPath)
				for _, r := range req.Array() {
					if s := r.String(); !contains(current, s) {
						current = append(current, s)
					}
				}
				jsonStr, _ = sjson.Set(jsonStr, reqPath, current)
			}
		}
		jsonStr, _ = sjson.Delete(jsonStr, p)
	}
	return jsonStr
}
|
||||||
|
|
||||||
|
// flattenAnyOfOneOf collapses anyOf/oneOf unions into the single "best"
// alternative (see selectBest), folding the parent's description and the
// full list of alternative types into the surviving schema's description.
func flattenAnyOfOneOf(jsonStr string) string {
	for _, key := range []string{"anyOf", "oneOf"} {
		paths := findPaths(jsonStr, key)
		// Deepest first: children are rewritten before their ancestors.
		sortByDepth(paths)

		for _, p := range paths {
			arr := gjson.Get(jsonStr, p)
			if !arr.IsArray() || len(arr.Array()) == 0 {
				continue
			}

			parentPath := trimSuffix(p, "."+key)
			parentDesc := gjson.Get(jsonStr, descriptionPath(parentPath)).String()

			items := arr.Array()
			bestIdx, allTypes := selectBest(items)
			selected := items[bestIdx].Raw

			if parentDesc != "" {
				selected = mergeDescriptionRaw(selected, parentDesc)
			}

			// Record the discarded alternatives so the model still knows them.
			if len(allTypes) > 1 {
				hint := "Accepts: " + strings.Join(allTypes, " | ")
				selected = appendHintRaw(selected, hint)
			}

			// Replace the parent node wholesale with the chosen branch.
			jsonStr = setRawAt(jsonStr, parentPath, selected)
		}
	}
	return jsonStr
}
|
||||||
|
|
||||||
|
func selectBest(items []gjson.Result) (bestIdx int, types []string) {
|
||||||
|
bestScore := -1
|
||||||
|
for i, item := range items {
|
||||||
|
t := item.Get("type").String()
|
||||||
|
score := 0
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case t == "object" || item.Get("properties").Exists():
|
||||||
|
score, t = 3, orDefault(t, "object")
|
||||||
|
case t == "array" || item.Get("items").Exists():
|
||||||
|
score, t = 2, orDefault(t, "array")
|
||||||
|
case t != "" && t != "null":
|
||||||
|
score = 1
|
||||||
|
default:
|
||||||
|
t = orDefault(t, "null")
|
||||||
|
}
|
||||||
|
|
||||||
|
if t != "" {
|
||||||
|
types = append(types, t)
|
||||||
|
}
|
||||||
|
if score > bestScore {
|
||||||
|
bestScore, bestIdx = score, i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// flattenTypeArrays reduces JSON Schema type arrays (e.g. ["string","null"])
// to a single type, annotates multi-type unions and nullability in the
// description, and removes nullable properties from "required" lists.
func flattenTypeArrays(jsonStr string) string {
	paths := findPaths(jsonStr, "type")
	sortByDepth(paths)

	// objectPath -> names of its properties that allowed "null".
	nullableFields := make(map[string][]string)

	for _, p := range paths {
		res := gjson.Get(jsonStr, p)
		if !res.IsArray() || len(res.Array()) == 0 {
			continue
		}

		hasNull := false
		var nonNullTypes []string
		for _, item := range res.Array() {
			s := item.String()
			if s == "null" {
				hasNull = true
			} else if s != "" {
				nonNullTypes = append(nonNullTypes, s)
			}
		}

		// Default to "string" when the array held only "null"/blank entries.
		firstType := "string"
		if len(nonNullTypes) > 0 {
			firstType = nonNullTypes[0]
		}

		jsonStr, _ = sjson.Set(jsonStr, p, firstType)

		parentPath := trimSuffix(p, ".type")
		if len(nonNullTypes) > 1 {
			hint := "Accepts: " + strings.Join(nonNullTypes, " | ")
			jsonStr = appendHint(jsonStr, parentPath, hint)
		}

		if hasNull {
			parts := splitGJSONPath(p)
			// Only fields shaped <object>.properties.<name>.type can be
			// dropped from their enclosing object's "required" list.
			if len(parts) >= 3 && parts[len(parts)-3] == "properties" {
				fieldNameEscaped := parts[len(parts)-2]
				fieldName := unescapeGJSONPathKey(fieldNameEscaped)
				objectPath := strings.Join(parts[:len(parts)-3], ".")
				nullableFields[objectPath] = append(nullableFields[objectPath], fieldName)

				propPath := joinPath(objectPath, "properties."+fieldNameEscaped)
				jsonStr = appendHint(jsonStr, propPath, "(nullable)")
			}
		}
	}

	// Nullable implies optional: strip those fields from "required".
	for objectPath, fields := range nullableFields {
		reqPath := joinPath(objectPath, "required")
		req := gjson.Get(jsonStr, reqPath)
		if !req.IsArray() {
			continue
		}

		var filtered []string
		for _, r := range req.Array() {
			if !contains(fields, r.String()) {
				filtered = append(filtered, r.String())
			}
		}

		// An empty "required" array is invalid for Gemini; delete it instead.
		if len(filtered) == 0 {
			jsonStr, _ = sjson.Delete(jsonStr, reqPath)
		} else {
			jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
		}
	}
	return jsonStr
}
|
||||||
|
|
||||||
|
func removeUnsupportedKeywords(jsonStr string) string {
|
||||||
|
keywords := append(unsupportedConstraints,
|
||||||
|
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||||
|
)
|
||||||
|
for _, key := range keywords {
|
||||||
|
for _, p := range findPaths(jsonStr, key) {
|
||||||
|
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupRequiredFields drops entries from "required" arrays that no longer
// reference an existing property (e.g. after flattening removed them).
func cleanupRequiredFields(jsonStr string) string {
	for _, p := range findPaths(jsonStr, "required") {
		parentPath := trimSuffix(p, ".required")
		propsPath := joinPath(parentPath, "properties")

		req := gjson.Get(jsonStr, p)
		props := gjson.Get(jsonStr, propsPath)
		if !req.IsArray() || !props.IsObject() {
			continue
		}

		var valid []string
		for _, r := range req.Array() {
			key := r.String()
			// Escape so property names containing "." resolve correctly.
			if props.Get(escapeGJSONPathKey(key)).Exists() {
				valid = append(valid, key)
			}
		}

		// Rewrite only when something was actually filtered out.
		if len(valid) != len(req.Array()) {
			if len(valid) == 0 {
				jsonStr, _ = sjson.Delete(jsonStr, p)
			} else {
				jsonStr, _ = sjson.Set(jsonStr, p, valid)
			}
		}
	}
	return jsonStr
}
|
||||||
|
|
||||||
|
// --- Helpers ---
|
||||||
|
|
||||||
|
// findPaths returns the gjson path of every occurrence of the given field
// name anywhere in the document.
func findPaths(jsonStr, field string) []string {
	var paths []string
	// NOTE(review): Walk is defined elsewhere in this package; assumed to
	// traverse the parsed document and append matching key paths — confirm.
	Walk(gjson.Parse(jsonStr), "", field, &paths)
	return paths
}
|
||||||
|
|
||||||
|
// sortByDepth orders paths so the longest (deepest) paths come first,
// letting callers rewrite child nodes before their ancestors. Path string
// length is used as a cheap proxy for nesting depth.
func sortByDepth(paths []string) {
	sort.Slice(paths, func(a, b int) bool {
		return len(paths[a]) > len(paths[b])
	})
}
|
||||||
|
|
||||||
|
// trimSuffix strips suffix (e.g. ".const") from a gjson path. When the path
// IS the bare key — a top-level occurrence such as "const" — it returns ""
// so the result addresses the document root.
func trimSuffix(path, suffix string) string {
	if strings.TrimPrefix(suffix, ".") == path {
		return ""
	}
	return strings.TrimSuffix(path, suffix)
}
|
||||||
|
|
||||||
|
// joinPath concatenates a gjson base path and a suffix, treating an empty
// base as the document root.
func joinPath(base, suffix string) string {
	if len(base) == 0 {
		return suffix
	}
	return base + "." + suffix
}
|
||||||
|
|
||||||
|
func setRawAt(jsonStr, path, value string) string {
|
||||||
|
if path == "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
result, _ := sjson.SetRaw(jsonStr, path, value)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPropertyDefinition reports whether path points directly at a "properties"
// map, i.e. its children are user-defined property names rather than schema
// keywords.
func isPropertyDefinition(path string) bool {
	const marker = "properties"
	if path == marker {
		return true
	}
	return strings.HasSuffix(path, "."+marker)
}
|
||||||
|
|
||||||
|
// descriptionPath returns the gjson path of the "description" field owned by
// parentPath; the document root ("" or "@this") maps to the bare key.
func descriptionPath(parentPath string) string {
	switch parentPath {
	case "", "@this":
		return "description"
	}
	return parentPath + ".description"
}
|
||||||
|
|
||||||
|
func appendHint(jsonStr, parentPath, hint string) string {
|
||||||
|
descPath := parentPath + ".description"
|
||||||
|
if parentPath == "" || parentPath == "@this" {
|
||||||
|
descPath = "description"
|
||||||
|
}
|
||||||
|
existing := gjson.Get(jsonStr, descPath).String()
|
||||||
|
if existing != "" {
|
||||||
|
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||||
|
}
|
||||||
|
jsonStr, _ = sjson.Set(jsonStr, descPath, hint)
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendHintRaw(jsonRaw, hint string) string {
|
||||||
|
existing := gjson.Get(jsonRaw, "description").String()
|
||||||
|
if existing != "" {
|
||||||
|
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||||
|
}
|
||||||
|
jsonRaw, _ = sjson.Set(jsonRaw, "description", hint)
|
||||||
|
return jsonRaw
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStrings(jsonStr, path string) []string {
|
||||||
|
var result []string
|
||||||
|
if arr := gjson.Get(jsonStr, path); arr.IsArray() {
|
||||||
|
for _, r := range arr.Array() {
|
||||||
|
result = append(result, r.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// contains reports whether item occurs in slice.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// orDefault returns val unless it is empty, in which case def is returned.
func orDefault(val, def string) string {
	if val != "" {
		return val
	}
	return def
}
|
||||||
|
|
||||||
|
// escapeGJSONPathKey escapes ".", "*" and "?" in a map key so the key can be
// used verbatim as a single gjson/sjson path segment.
func escapeGJSONPathKey(key string) string {
	return gjsonPathKeyReplacer.Replace(key)
}
|
||||||
|
|
||||||
|
// unescapeGJSONPathKey reverses escapeGJSONPathKey, turning a gjson path
// segment back into the raw map key by dropping backslash escapes.
func unescapeGJSONPathKey(key string) string {
	if !strings.Contains(key, `\`) {
		return key
	}
	out := make([]byte, 0, len(key))
	for i := 0; i < len(key); i++ {
		// A backslash escapes the next byte; a trailing lone backslash is
		// kept as-is.
		if key[i] == '\\' && i+1 < len(key) {
			i++
		}
		out = append(out, key[i])
	}
	return string(out)
}
|
||||||
|
|
||||||
|
// splitGJSONPath splits a gjson path on unescaped dots. Backslash escapes are
// preserved inside each returned segment (use unescapeGJSONPathKey to decode
// a segment). An empty path yields nil.
func splitGJSONPath(path string) []string {
	if path == "" {
		return nil
	}

	segments := make([]string, 0, strings.Count(path, ".")+1)
	current := make([]byte, 0, len(path))

	for i := 0; i < len(path); i++ {
		switch {
		case path[i] == '\\' && i+1 < len(path):
			// Keep the escape pair intact inside the segment.
			current = append(current, path[i], path[i+1])
			i++
		case path[i] == '.':
			segments = append(segments, string(current))
			current = current[:0]
		default:
			current = append(current, path[i])
		}
	}
	return append(segments, string(current))
}
|
||||||
|
|
||||||
|
func mergeDescriptionRaw(schemaRaw, parentDesc string) string {
|
||||||
|
childDesc := gjson.Get(schemaRaw, "description").String()
|
||||||
|
switch {
|
||||||
|
case childDesc == "":
|
||||||
|
schemaRaw, _ = sjson.Set(schemaRaw, "description", parentDesc)
|
||||||
|
return schemaRaw
|
||||||
|
case childDesc == parentDesc:
|
||||||
|
return schemaRaw
|
||||||
|
default:
|
||||||
|
combined := fmt.Sprintf("%s (%s)", parentDesc, childDesc)
|
||||||
|
schemaRaw, _ = sjson.Set(schemaRaw, "description", combined)
|
||||||
|
return schemaRaw
|
||||||
|
}
|
||||||
|
}
|
||||||
613
internal/util/gemini_schema_test.go
Normal file
613
internal/util/gemini_schema_test.go
Normal file
@@ -0,0 +1,613 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_ConstToEnum verifies that a "const" keyword is
// rewritten as a single-value "enum".
func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"kind": {
				"type": "string",
				"const": "InsightVizNode"
			}
		}
	}`

	expected := `{
		"type": "object",
		"properties": {
			"kind": {
				"type": "string",
				"enum": ["InsightVizNode"]
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_TypeFlattening_Nullable verifies that
// ["string","null"] collapses to "string", gains a "(nullable)" hint, and
// that the nullable field is removed from "required".
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"name": {
				"type": ["string", "null"]
			},
			"other": {
				"type": "string"
			}
		},
		"required": ["name", "other"]
	}`

	expected := `{
		"type": "object",
		"properties": {
			"name": {
				"type": "string",
				"description": "(nullable)"
			},
			"other": {
				"type": "string"
			}
		},
		"required": ["other"]
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_ConstraintsToDescription verifies unsupported
// validation keywords (minItems, minLength) are removed from the schema and
// preserved as description hints.
func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"tags": {
				"type": "array",
				"description": "List of tags",
				"minItems": 1
			},
			"name": {
				"type": "string",
				"description": "User name",
				"minLength": 3
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)

	// minItems should be REMOVED and moved to description.
	if strings.Contains(result, `"minItems"`) {
		t.Errorf("minItems keyword should be removed")
	}
	if !strings.Contains(result, "minItems: 1") {
		t.Errorf("minItems hint missing in description")
	}

	// minLength should be moved to description.
	if !strings.Contains(result, "minLength: 3") {
		t.Errorf("minLength hint missing in description")
	}
	if strings.Contains(result, `"minLength":`) || strings.Contains(result, `"minLength" :`) {
		t.Errorf("minLength keyword should be removed")
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection verifies that
// anyOf keeps the most structured branch (object over null) and records the
// alternatives as a hint.
func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"query": {
				"anyOf": [
					{ "type": "null" },
					{
						"type": "object",
						"properties": {
							"kind": { "type": "string" }
						}
					}
				]
			}
		}
	}`

	expected := `{
		"type": "object",
		"properties": {
			"query": {
				"type": "object",
				"description": "Accepts: null | object",
				"properties": {
					"kind": { "type": "string" }
				}
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_OneOfFlattening verifies oneOf collapses to
// its first scalar branch with an "Accepts" hint listing the alternatives.
func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"config": {
				"oneOf": [
					{ "type": "string" },
					{ "type": "integer" }
				]
			}
		}
	}`

	expected := `{
		"type": "object",
		"properties": {
			"config": {
				"type": "string",
				"description": "Accepts: string | integer"
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_AllOfMerging verifies allOf branches are
// merged into the parent: properties unioned, required lists concatenated.
func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
	input := `{
		"type": "object",
		"allOf": [
			{
				"properties": {
					"a": { "type": "string" }
				},
				"required": ["a"]
			},
			{
				"properties": {
					"b": { "type": "integer" }
				},
				"required": ["b"]
			}
		]
	}`

	expected := `{
		"type": "object",
		"properties": {
			"a": { "type": "string" },
			"b": { "type": "integer" }
		},
		"required": ["a", "b"]
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_RefHandling verifies $ref nodes become plain
// objects with a "See: <Definition>" hint and that "definitions" is stripped.
func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
	input := `{
		"definitions": {
			"User": {
				"type": "object",
				"properties": {
					"name": { "type": "string" }
				}
			}
		},
		"type": "object",
		"properties": {
			"customer": { "$ref": "#/definitions/User" }
		}
	}`

	expected := `{
		"type": "object",
		"properties": {
			"customer": {
				"type": "object",
				"description": "See: User"
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping verifies an
// existing description containing quotes/backslashes is preserved (properly
// JSON-escaped) when the "See:" hint is appended.
func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) {
	input := `{
		"definitions": {
			"User": {
				"type": "object",
				"properties": {
					"name": { "type": "string" }
				}
			}
		},
		"type": "object",
		"properties": {
			"customer": {
				"description": "He said \"hi\"\\nsecond line",
				"$ref": "#/definitions/User"
			}
		}
	}`

	expected := `{
		"type": "object",
		"properties": {
			"customer": {
				"type": "object",
				"description": "He said \"hi\"\\nsecond line (See: User)"
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"definitions": {
|
||||||
|
"Node": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"child": { "$ref": "#/definitions/Node" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"$ref": "#/definitions/Node"
|
||||||
|
}`
|
||||||
|
|
||||||
|
result := CleanJSONSchemaForGemini(input)
|
||||||
|
|
||||||
|
var resMap map[string]interface{}
|
||||||
|
json.Unmarshal([]byte(result), &resMap)
|
||||||
|
|
||||||
|
if resMap["type"] != "object" {
|
||||||
|
t.Errorf("Expected type: object, got: %v", resMap["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
desc, ok := resMap["description"].(string)
|
||||||
|
if !ok || !strings.Contains(desc, "Node") {
|
||||||
|
t.Errorf("Expected description hint containing 'Node', got: %v", resMap["description"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_RequiredCleanup verifies "required" entries
// with no matching property are dropped.
func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"a": {"type": "string"},
			"b": {"type": "string"}
		},
		"required": ["a", "b", "c"]
	}`

	expected := `{
		"type": "object",
		"properties": {
			"a": {"type": "string"},
			"b": {"type": "string"}
		},
		"required": ["a", "b"]
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys verifies allOf merging
// keeps property names containing "." intact (no sjson path splitting).
func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
	input := `{
		"type": "object",
		"allOf": [
			{
				"properties": {
					"my.param": { "type": "string" }
				},
				"required": ["my.param"]
			},
			{
				"properties": {
					"b": { "type": "integer" }
				},
				"required": ["b"]
			}
		]
	}`

	expected := `{
		"type": "object",
		"properties": {
			"my.param": { "type": "string" },
			"b": { "type": "integer" }
		},
		"required": ["my.param", "b"]
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
|
||||||
|
// A tool has an argument named "pattern" - should NOT be treated as a constraint
|
||||||
|
input := `{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The regex pattern"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["pattern"]
|
||||||
|
}`
|
||||||
|
|
||||||
|
expected := `{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The regex pattern"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["pattern"]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result := CleanJSONSchemaForGemini(input)
|
||||||
|
compareJSON(t, expected, result)
|
||||||
|
|
||||||
|
var resMap map[string]interface{}
|
||||||
|
json.Unmarshal([]byte(result), &resMap)
|
||||||
|
props, _ := resMap["properties"].(map[string]interface{})
|
||||||
|
if _, ok := props["description"]; ok {
|
||||||
|
t.Errorf("Invalid 'description' property injected into properties map")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_DotKeys verifies $ref removal on a property
// whose name contains "." neither loses the key nor creates an artifact
// "my" key through sjson path splitting.
func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"my.param": {
				"type": "string",
				"$ref": "#/definitions/MyType"
			}
		},
		"definitions": {
			"MyType": { "type": "string" }
		}
	}`

	result := CleanJSONSchemaForGemini(input)

	var resMap map[string]interface{}
	if err := json.Unmarshal([]byte(result), &resMap); err != nil {
		t.Fatalf("Failed to unmarshal result: %v", err)
	}

	props, ok := resMap["properties"].(map[string]interface{})
	if !ok {
		t.Fatalf("properties missing")
	}

	if val, ok := props["my.param"]; !ok {
		t.Fatalf("Key 'my.param' is missing. Result: %s", result)
	} else {
		valMap, _ := val.(map[string]interface{})
		if _, hasRef := valMap["$ref"]; hasRef {
			t.Errorf("Key 'my.param' still contains $ref")
		}
		if _, ok := props["my"]; ok {
			t.Errorf("Artifact key 'my' created by sjson splitting")
		}
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_AnyOfAlternativeHints verifies all anyOf
// branch types appear in the generated "Accepts:" hint.
func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"value": {
				"anyOf": [
					{ "type": "string" },
					{ "type": "integer" },
					{ "type": "null" }
				]
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)

	if !strings.Contains(result, "Accepts:") {
		t.Errorf("Expected alternative types hint, got: %s", result)
	}
	if !strings.Contains(result, "string") || !strings.Contains(result, "integer") {
		t.Errorf("Expected all alternative types in hint, got: %s", result)
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_NullableHint verifies the "(nullable)" hint is
// appended while the original description is kept.
func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"name": {
				"type": ["string", "null"],
				"description": "User name"
			}
		},
		"required": ["name"]
	}`

	result := CleanJSONSchemaForGemini(input)

	if !strings.Contains(result, "(nullable)") {
		t.Errorf("Expected nullable hint, got: %s", result)
	}
	if !strings.Contains(result, "User name") {
		t.Errorf("Expected original description to be preserved, got: %s", result)
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey verifies the
// nullable flattening path also works when the property name contains ".".
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"my.param": {
				"type": ["string", "null"]
			},
			"other": {
				"type": "string"
			}
		},
		"required": ["my.param", "other"]
	}`

	expected := `{
		"type": "object",
		"properties": {
			"my.param": {
				"type": "string",
				"description": "(nullable)"
			},
			"other": {
				"type": "string"
			}
		},
		"required": ["other"]
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_EnumHint verifies multi-value enums generate
// an "Allowed:" hint listing the values.
func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"status": {
				"type": "string",
				"enum": ["active", "inactive", "pending"],
				"description": "Current status"
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)

	if !strings.Contains(result, "Allowed:") {
		t.Errorf("Expected enum values hint, got: %s", result)
	}
	if !strings.Contains(result, "active") || !strings.Contains(result, "inactive") {
		t.Errorf("Expected enum values in hint, got: %s", result)
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_AdditionalPropertiesHint verifies
// "additionalProperties": false turns into a description hint.
func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"name": { "type": "string" }
		},
		"additionalProperties": false
	}`

	result := CleanJSONSchemaForGemini(input)

	if !strings.Contains(result, "No extra properties allowed") {
		t.Errorf("Expected additionalProperties hint, got: %s", result)
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription verifies
// that flattening an anyOf keeps the first branch's type while merging the
// parent description, the chosen branch's description, and an "Accepts" hint
// listing every branch type.
func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"config": {
				"description": "Parent desc",
				"anyOf": [
					{ "type": "string", "description": "Child desc" },
					{ "type": "integer" }
				]
			}
		}
	}`

	expected := `{
		"type": "object",
		"properties": {
			"config": {
				"type": "string",
				"description": "Parent desc (Child desc) (Accepts: string | integer)"
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)
	compareJSON(t, expected, result)
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_SingleEnumNoHint verifies that an enum with a
// single value does not generate an "Allowed: ..." hint (there is nothing to
// choose between).
func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"kind": {
				"type": "string",
				"enum": ["fixed"]
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)

	if strings.Contains(result, "Allowed:") {
		t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
	}
}
|
||||||
|
|
||||||
|
// TestCleanJSONSchemaForGemini_MultipleNonNullTypes verifies that a type
// union with several non-null members produces an "Accepts: ..." hint that
// mentions every member type.
func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
	input := `{
		"type": "object",
		"properties": {
			"value": {
				"type": ["string", "integer", "boolean"]
			}
		}
	}`

	result := CleanJSONSchemaForGemini(input)

	if !strings.Contains(result, "Accepts:") {
		t.Errorf("Expected multiple types hint, got: %s", result)
	}
	if !strings.Contains(result, "string") || !strings.Contains(result, "integer") || !strings.Contains(result, "boolean") {
		t.Errorf("Expected all types in hint, got: %s", result)
	}
}
|
||||||
|
|
||||||
|
// compareJSON asserts that two JSON documents are semantically equal,
// ignoring key order and formatting. It fails the test fatally on parse
// errors, and reports both documents pretty-printed on any mismatch.
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
	var want, got map[string]interface{}
	if errWant, errGot := json.Unmarshal([]byte(expectedJSON), &want), json.Unmarshal([]byte(actualJSON), &got); errWant != nil || errGot != nil {
		t.Fatalf("JSON Unmarshal error. Exp: %v, Act: %v", errWant, errGot)
	}
	if reflect.DeepEqual(want, got) {
		return
	}
	wantPretty, _ := json.MarshalIndent(want, "", " ")
	gotPretty, _ := json.MarshalIndent(got, "", " ")
	t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(wantPretty), string(gotPretty))
}
|
||||||
@@ -6,6 +6,7 @@ package util
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -28,10 +29,17 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
|||||||
// For JSON objects and arrays, iterate through each child
|
// For JSON objects and arrays, iterate through each child
|
||||||
value.ForEach(func(key, val gjson.Result) bool {
|
value.ForEach(func(key, val gjson.Result) bool {
|
||||||
var childPath string
|
var childPath string
|
||||||
|
// Escape special characters for gjson/sjson path syntax
|
||||||
|
// . -> \.
|
||||||
|
// * -> \*
|
||||||
|
// ? -> \?
|
||||||
|
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||||
|
safeKey := keyReplacer.Replace(key.String())
|
||||||
|
|
||||||
if path == "" {
|
if path == "" {
|
||||||
childPath = key.String()
|
childPath = safeKey
|
||||||
} else {
|
} else {
|
||||||
childPath = path + "." + key.String()
|
childPath = path + "." + safeKey
|
||||||
}
|
}
|
||||||
if key.String() == field {
|
if key.String() == field {
|
||||||
*paths = append(*paths, childPath)
|
*paths = append(*paths, childPath)
|
||||||
|
|||||||
270
internal/watcher/clients.go
Normal file
270
internal/watcher/clients.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
// clients.go implements watcher client lifecycle logic and persistence helpers.
|
||||||
|
// It reloads clients, handles incremental auth file changes, and persists updates when supported.
|
||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// reloadClients performs a full client reload: it recounts API-key-backed
// clients from the config, optionally rescans the auth directory (rebuilding
// the per-file hash cache), notifies the server via the reload callback, and
// re-dispatches the auth state.
//
// rescanAuth controls whether the auth directory is walked again.
// affectedOAuthProviders lists providers whose cached auths must be dropped
// so they are re-synthesized (used when oauth-excluded-models changed).
// forceAuthRefresh forces modify updates even for auths that compare equal.
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
	log.Debugf("starting full client load process")

	w.clientsMutex.RLock()
	cfg := w.config
	w.clientsMutex.RUnlock()

	if cfg == nil {
		log.Error("config is nil, cannot reload clients")
		return
	}

	// Drop cached auths for the affected providers so the subsequent refresh
	// re-announces them from scratch.
	if len(affectedOAuthProviders) > 0 {
		w.clientsMutex.Lock()
		if w.currentAuths != nil {
			filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
			for id, auth := range w.currentAuths {
				if auth == nil {
					continue
				}
				provider := strings.ToLower(strings.TrimSpace(auth.Provider))
				if _, match := matchProvider(provider, affectedOAuthProviders); match {
					continue
				}
				filtered[id] = auth
			}
			w.currentAuths = filtered
			log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
		} else {
			w.currentAuths = nil
		}
		w.clientsMutex.Unlock()
	}

	geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
	totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
	log.Debugf("loaded %d API key clients", totalAPIKeyClients)

	var authFileCount int
	if rescanAuth {
		authFileCount = w.loadFileClients(cfg)
		log.Debugf("loaded %d file-based clients", authFileCount)
	} else {
		// Without a rescan, the existing hash cache is the source of truth
		// for how many auth files are known.
		w.clientsMutex.RLock()
		authFileCount = len(w.lastAuthHashes)
		w.clientsMutex.RUnlock()
		log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
	}

	// Rebuild the SHA-256 hash cache from disk so future file events can be
	// deduplicated against current content.
	if rescanAuth {
		w.clientsMutex.Lock()

		w.lastAuthHashes = make(map[string]string)
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
		} else if resolvedAuthDir != "" {
			// Walk errors are deliberately ignored: a partial hash cache only
			// causes extra (harmless) reloads later.
			_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return nil
				}
				if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
					if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
						sum := sha256.Sum256(data)
						normalizedPath := w.normalizeAuthPath(path)
						w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
					}
				}
				return nil
			})
		}
		w.clientsMutex.Unlock()
	}

	totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount

	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback before auth refresh")
		w.reloadCallback(cfg)
	}

	w.refreshAuthState(forceAuthRefresh)

	log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
		totalNewClients,
		authFileCount,
		geminiAPIKeyCount,
		vertexCompatAPIKeyCount,
		claudeAPIKeyCount,
		codexAPIKeyCount,
		openAICompatCount,
	)
}
|
||||||
|
|
||||||
|
// addOrUpdateClient reacts to a created or modified auth file. It skips the
// event when the file's SHA-256 hash matches the cached one, otherwise
// records the new hash, refreshes the dispatched auth state, notifies the
// reload callback, and asks the persister to sync the file.
func (w *Watcher) addOrUpdateClient(path string) {
	data, errRead := os.ReadFile(path)
	if errRead != nil {
		log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
		return
	}
	if len(data) == 0 {
		// Editors often truncate-then-write; skip the transient empty state.
		log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
		return
	}

	sum := sha256.Sum256(data)
	curHash := hex.EncodeToString(sum[:])
	normalized := w.normalizeAuthPath(path)

	w.clientsMutex.Lock()

	cfg := w.config
	if cfg == nil {
		log.Error("config is nil, cannot add or update client")
		w.clientsMutex.Unlock()
		return
	}
	// Hash dedup: many fsnotify events fire for a single logical write.
	if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
		log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
		w.clientsMutex.Unlock()
		return
	}

	w.lastAuthHashes[normalized] = curHash

	w.clientsMutex.Unlock() // Unlock before the callback

	w.refreshAuthState(false)

	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback after add/update")
		w.reloadCallback(cfg)
	}
	w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
}
|
||||||
|
|
||||||
|
func (w *Watcher) removeClient(path string) {
|
||||||
|
normalized := w.normalizeAuthPath(path)
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
|
||||||
|
cfg := w.config
|
||||||
|
delete(w.lastAuthHashes, normalized)
|
||||||
|
|
||||||
|
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||||
|
|
||||||
|
w.refreshAuthState(false)
|
||||||
|
|
||||||
|
if w.reloadCallback != nil {
|
||||||
|
log.Debugf("triggering server update callback after removal")
|
||||||
|
w.reloadCallback(cfg)
|
||||||
|
}
|
||||||
|
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||||
|
authFileCount := 0
|
||||||
|
successfulAuthCount := 0
|
||||||
|
|
||||||
|
authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
|
||||||
|
if errResolveAuthDir != nil {
|
||||||
|
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if authDir == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("error accessing path %s: %v", path, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||||
|
authFileCount++
|
||||||
|
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
||||||
|
if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
|
||||||
|
successfulAuthCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if errWalk != nil {
|
||||||
|
log.Errorf("error walking auth directory: %v", errWalk)
|
||||||
|
}
|
||||||
|
log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
|
||||||
|
return authFileCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
||||||
|
geminiAPIKeyCount := 0
|
||||||
|
vertexCompatAPIKeyCount := 0
|
||||||
|
claudeAPIKeyCount := 0
|
||||||
|
codexAPIKeyCount := 0
|
||||||
|
openAICompatCount := 0
|
||||||
|
|
||||||
|
if len(cfg.GeminiKey) > 0 {
|
||||||
|
geminiAPIKeyCount += len(cfg.GeminiKey)
|
||||||
|
}
|
||||||
|
if len(cfg.VertexCompatAPIKey) > 0 {
|
||||||
|
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
|
||||||
|
}
|
||||||
|
if len(cfg.ClaudeKey) > 0 {
|
||||||
|
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
||||||
|
}
|
||||||
|
if len(cfg.CodexKey) > 0 {
|
||||||
|
codexAPIKeyCount += len(cfg.CodexKey)
|
||||||
|
}
|
||||||
|
if len(cfg.OpenAICompatibility) > 0 {
|
||||||
|
for _, compatConfig := range cfg.OpenAICompatibility {
|
||||||
|
openAICompatCount += len(compatConfig.APIKeyEntries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// persistConfigAsync pushes the current configuration to the store persister
// in a background goroutine with a 30-second timeout. No-op when the watcher
// or its persister is nil.
func (w *Watcher) persistConfigAsync() {
	if w == nil || w.storePersister == nil {
		return
	}
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := w.storePersister.PersistConfig(ctx); err != nil {
			log.Errorf("failed to persist config change: %v", err)
		}
	}()
}
|
||||||
|
|
||||||
|
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
|
||||||
|
if w == nil || w.storePersister == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filtered := make([]string, 0, len(paths))
|
||||||
|
for _, p := range paths {
|
||||||
|
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
||||||
|
filtered = append(filtered, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil {
|
||||||
|
log.Errorf("failed to persist auth changes: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
134
internal/watcher/config_reload.go
Normal file
134
internal/watcher/config_reload.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
// config_reload.go implements debounced configuration hot reload.
|
||||||
|
// It detects material changes and reloads clients when the config changes.
|
||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (w *Watcher) stopConfigReloadTimer() {
|
||||||
|
w.configReloadMu.Lock()
|
||||||
|
if w.configReloadTimer != nil {
|
||||||
|
w.configReloadTimer.Stop()
|
||||||
|
w.configReloadTimer = nil
|
||||||
|
}
|
||||||
|
w.configReloadMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleConfigReload (re)arms the debounce timer so that rapid successive
// config writes collapse into a single reload after configReloadDebounce.
func (w *Watcher) scheduleConfigReload() {
	w.configReloadMu.Lock()
	defer w.configReloadMu.Unlock()
	if w.configReloadTimer != nil {
		w.configReloadTimer.Stop()
	}
	w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
		// Clear the timer reference before reloading so a concurrent
		// schedule call arms a fresh timer instead of touching this one.
		w.configReloadMu.Lock()
		w.configReloadTimer = nil
		w.configReloadMu.Unlock()
		w.reloadConfigIfChanged()
	})
}
|
||||||
|
|
||||||
|
// reloadConfigIfChanged re-reads the config file and triggers a full reload
// only when its SHA-256 hash differs from the last applied one. After a
// successful reload it re-hashes the file (the reload may rewrite it) and
// persists the config change asynchronously.
func (w *Watcher) reloadConfigIfChanged() {
	data, err := os.ReadFile(w.configPath)
	if err != nil {
		log.Errorf("failed to read config file for hash check: %v", err)
		return
	}
	if len(data) == 0 {
		// Editors often truncate-then-write; skip the transient empty state.
		log.Debugf("ignoring empty config file write event")
		return
	}
	sum := sha256.Sum256(data)
	newHash := hex.EncodeToString(sum[:])

	w.clientsMutex.RLock()
	currentHash := w.lastConfigHash
	w.clientsMutex.RUnlock()

	if currentHash != "" && currentHash == newHash {
		log.Debugf("config file content unchanged (hash match), skipping reload")
		return
	}
	log.Infof("config file changed, reloading: %s", w.configPath)
	if w.reloadConfig() {
		// Re-hash after the reload in case it normalized/rewrote the file,
		// so the next event compares against the on-disk content.
		finalHash := newHash
		if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
			sumUpdated := sha256.Sum256(updatedData)
			finalHash = hex.EncodeToString(sumUpdated[:])
		} else if errRead != nil {
			log.WithError(errRead).Debug("failed to compute updated config hash after reload")
		}
		w.clientsMutex.Lock()
		w.lastConfigHash = finalHash
		w.clientsMutex.Unlock()
		w.persistConfigAsync()
	}
}
|
||||||
|
|
||||||
|
// reloadConfig loads the config file from disk, swaps it into the watcher,
// logs material changes versus the previous config, and triggers a client
// reload. It returns true when the new config was applied, false when
// loading failed (the old config stays active).
func (w *Watcher) reloadConfig() bool {
	log.Debug("=========================== CONFIG RELOAD ============================")
	log.Debugf("starting config reload from: %s", w.configPath)

	newConfig, errLoadConfig := config.LoadConfig(w.configPath)
	if errLoadConfig != nil {
		log.Errorf("failed to reload config: %v", errLoadConfig)
		return false
	}

	// Keep pointing at the mirrored auth dir when one is in use; otherwise
	// normalize the configured auth dir via util.ResolveAuthDir.
	if w.mirroredAuthDir != "" {
		newConfig.AuthDir = w.mirroredAuthDir
	} else {
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
		} else {
			newConfig.AuthDir = resolvedAuthDir
		}
	}

	// Swap in the new config; the previous one is reconstructed from the
	// stored YAML snapshot so old and new can be diffed below.
	w.clientsMutex.Lock()
	var oldConfig *config.Config
	_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
	w.oldConfigYaml, _ = yaml.Marshal(newConfig)
	w.config = newConfig
	w.clientsMutex.Unlock()

	var affectedOAuthProviders []string
	if oldConfig != nil {
		_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
	}

	util.SetLogLevel(newConfig)
	if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
		log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
	}

	if oldConfig != nil {
		details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
		if len(details) > 0 {
			log.Debugf("config changes detected:")
			for _, d := range details {
				log.Debugf(" %s", d)
			}
		} else {
			log.Debugf("no material config field changes detected")
		}
	}

	// Only rescan auth files when the auth directory actually moved; force
	// an auth refresh when the model-prefix toggle flips.
	authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
	forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix

	log.Infof("config successfully reloaded, triggering client reload")
	w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
	return true
}
|
||||||
273
internal/watcher/dispatcher.go
Normal file
273
internal/watcher/dispatcher.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
// dispatcher.go implements auth update dispatching and queue management.
|
||||||
|
// It batches, deduplicates, and delivers auth updates to registered consumers.
|
||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setAuthUpdateQueue attaches (or detaches, when queue is nil) the consumer
// channel for auth updates. Any previous dispatch loop is cancelled and its
// waiter woken so it can observe cancellation; a fresh loop is started for a
// non-nil queue.
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
	w.clientsMutex.Lock()
	defer w.clientsMutex.Unlock()
	w.authQueue = queue
	if w.dispatchCond == nil {
		w.dispatchCond = sync.NewCond(&w.dispatchMu)
	}
	if w.dispatchCancel != nil {
		w.dispatchCancel()
		// Broadcast under dispatchMu so a loop blocked in Cond.Wait wakes up
		// and re-checks its (now cancelled) context.
		if w.dispatchCond != nil {
			w.dispatchMu.Lock()
			w.dispatchCond.Broadcast()
			w.dispatchMu.Unlock()
		}
		w.dispatchCancel = nil
	}
	if queue != nil {
		ctx, cancel := context.WithCancel(context.Background())
		w.dispatchCancel = cancel
		go w.dispatchLoop(ctx)
	}
}
|
||||||
|
|
||||||
|
// dispatchRuntimeAuthUpdate applies a runtime-originated auth update (coming
// from the auth manager rather than the filesystem) to the watcher's caches
// and forwards it to the auth queue. It returns false when the watcher is
// nil or no queue is attached; in the latter case the caches are still
// updated.
func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
	if w == nil {
		return false
	}
	w.clientsMutex.Lock()
	if w.runtimeAuths == nil {
		w.runtimeAuths = make(map[string]*coreauth.Auth)
	}
	switch update.Action {
	case AuthUpdateActionAdd, AuthUpdateActionModify:
		if update.Auth != nil && update.Auth.ID != "" {
			// Store clones so later mutations by the caller cannot race with
			// the watcher's own snapshots.
			clone := update.Auth.Clone()
			w.runtimeAuths[clone.ID] = clone
			if w.currentAuths == nil {
				w.currentAuths = make(map[string]*coreauth.Auth)
			}
			w.currentAuths[clone.ID] = clone.Clone()
		}
	case AuthUpdateActionDelete:
		// The ID may come from either the update itself or its Auth payload.
		id := update.ID
		if id == "" && update.Auth != nil {
			id = update.Auth.ID
		}
		if id != "" {
			delete(w.runtimeAuths, id)
			if w.currentAuths != nil {
				delete(w.currentAuths, id)
			}
		}
	}
	w.clientsMutex.Unlock()
	if w.getAuthQueue() == nil {
		return false
	}
	w.dispatchAuthUpdates([]AuthUpdate{update})
	return true
}
|
||||||
|
|
||||||
|
// refreshAuthState rebuilds the auth snapshot (config/file-synthesized auths
// plus runtime-registered ones), diffs it against the last dispatched state,
// and forwards the resulting updates. force re-announces auths even when
// they compare equal.
func (w *Watcher) refreshAuthState(force bool) {
	auths := w.SnapshotCoreAuths()
	w.clientsMutex.Lock()
	if len(w.runtimeAuths) > 0 {
		for _, a := range w.runtimeAuths {
			if a != nil {
				auths = append(auths, a.Clone())
			}
		}
	}
	updates := w.prepareAuthUpdatesLocked(auths, force)
	w.clientsMutex.Unlock()
	// Dispatch outside the lock: enqueueing takes dispatchMu and may signal.
	w.dispatchAuthUpdates(updates)
}
|
||||||
|
|
||||||
|
// prepareAuthUpdatesLocked diffs the freshly-snapshotted auth list against
// w.currentAuths and returns the add/modify/delete updates needed to bring a
// consumer in sync. force turns every still-present auth into a modify
// update even when unchanged. The caller must hold clientsMutex;
// w.currentAuths is replaced by the new state as a side effect.
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
	newState := make(map[string]*coreauth.Auth, len(auths))
	for _, auth := range auths {
		if auth == nil || auth.ID == "" {
			continue
		}
		newState[auth.ID] = auth.Clone()
	}
	// First snapshot ever: everything is an add (nothing to diff against).
	if w.currentAuths == nil {
		w.currentAuths = newState
		if w.authQueue == nil {
			return nil
		}
		updates := make([]AuthUpdate, 0, len(newState))
		for id, auth := range newState {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
		}
		return updates
	}
	// No consumer attached: just record the new state silently.
	if w.authQueue == nil {
		w.currentAuths = newState
		return nil
	}
	updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths))
	for id, auth := range newState {
		if existing, ok := w.currentAuths[id]; !ok {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
		} else if force || !authEqual(existing, auth) {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
		}
	}
	// Anything present before but absent now is a delete.
	for id := range w.currentAuths {
		if _, ok := newState[id]; !ok {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
		}
	}
	w.currentAuths = newState
	return updates
}
|
||||||
|
|
||||||
|
// dispatchAuthUpdates enqueues updates into the pending map — deduplicating
// by auth ID so only the latest update per auth survives — and wakes the
// dispatch loop. No-op when there are no updates or no queue is attached.
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
	if len(updates) == 0 {
		return
	}
	queue := w.getAuthQueue()
	if queue == nil {
		return
	}
	// Base timestamp disambiguates keys for updates that carry no auth ID,
	// so they are never coalesced with each other.
	baseTS := time.Now().UnixNano()
	w.dispatchMu.Lock()
	if w.pendingUpdates == nil {
		w.pendingUpdates = make(map[string]AuthUpdate)
	}
	for idx, update := range updates {
		key := w.authUpdateKey(update, baseTS+int64(idx))
		if _, exists := w.pendingUpdates[key]; !exists {
			w.pendingOrder = append(w.pendingOrder, key)
		}
		// Later updates for the same key overwrite earlier pending ones.
		w.pendingUpdates[key] = update
	}
	if w.dispatchCond != nil {
		w.dispatchCond.Signal()
	}
	w.dispatchMu.Unlock()
}
|
||||||
|
|
||||||
|
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
|
||||||
|
if update.ID != "" {
|
||||||
|
return update.ID
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%d", update.Action, ts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchLoop runs until ctx is cancelled, pulling batches of pending auth
// updates and sending them to the attached queue. Each send also honors ctx
// cancellation so a blocked consumer cannot wedge shutdown.
func (w *Watcher) dispatchLoop(ctx context.Context) {
	for {
		batch, ok := w.nextPendingBatch(ctx)
		if !ok {
			return
		}
		queue := w.getAuthQueue()
		if queue == nil {
			// NOTE(review): a batch popped here is discarded when the queue
			// was detached in the meantime — confirm that is intended.
			if ctx.Err() != nil {
				return
			}
			time.Sleep(10 * time.Millisecond)
			continue
		}
		for _, update := range batch {
			select {
			case queue <- update:
			case <-ctx.Done():
				return
			}
		}
	}
}
|
||||||
|
|
||||||
|
// nextPendingBatch blocks until at least one pending update exists (or ctx
// is cancelled), then drains the entire pending set in insertion order. The
// second return value is false when the dispatch loop should terminate.
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
	w.dispatchMu.Lock()
	defer w.dispatchMu.Unlock()
	for len(w.pendingOrder) == 0 {
		if ctx.Err() != nil {
			return nil, false
		}
		// Wait atomically releases dispatchMu; re-check cancellation after
		// waking since stopDispatch broadcasts without adding work.
		w.dispatchCond.Wait()
		if ctx.Err() != nil {
			return nil, false
		}
	}
	batch := make([]AuthUpdate, 0, len(w.pendingOrder))
	for _, key := range w.pendingOrder {
		batch = append(batch, w.pendingUpdates[key])
		delete(w.pendingUpdates, key)
	}
	// Keep the slice's capacity for the next batch.
	w.pendingOrder = w.pendingOrder[:0]
	return batch, true
}
|
||||||
|
|
||||||
|
// getAuthQueue returns the currently attached auth update queue, or nil when
// no consumer is registered. Access is guarded by clientsMutex.
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
	w.clientsMutex.RLock()
	defer w.clientsMutex.RUnlock()
	return w.authQueue
}
|
||||||
|
|
||||||
|
// stopDispatch cancels the dispatch loop, drops every pending update, wakes
// any waiter so it can observe cancellation, and detaches the auth queue.
func (w *Watcher) stopDispatch() {
	if w.dispatchCancel != nil {
		w.dispatchCancel()
		w.dispatchCancel = nil
	}
	w.dispatchMu.Lock()
	w.pendingOrder = nil
	w.pendingUpdates = nil
	if w.dispatchCond != nil {
		// Wake a loop blocked in nextPendingBatch so it re-checks ctx.
		w.dispatchCond.Broadcast()
	}
	w.dispatchMu.Unlock()
	w.clientsMutex.Lock()
	w.authQueue = nil
	w.clientsMutex.Unlock()
}
|
||||||
|
|
||||||
|
// authEqual reports whether two auths are semantically equal after zeroing
// their volatile fields (timestamps, runtime state) via normalizeAuth.
func authEqual(a, b *coreauth.Auth) bool {
	return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b))
}
|
||||||
|
|
||||||
|
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
|
||||||
|
if a == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := a.Clone()
|
||||||
|
clone.CreatedAt = time.Time{}
|
||||||
|
clone.UpdatedAt = time.Time{}
|
||||||
|
clone.LastRefreshedAt = time.Time{}
|
||||||
|
clone.NextRefreshAfter = time.Time{}
|
||||||
|
clone.Runtime = nil
|
||||||
|
clone.Quota.NextRecoverAt = time.Time{}
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// snapshotCoreAuths synthesizes the full set of auths visible from the given
// config and auth directory: config-declared entries first, then auth files
// on disk. Synthesizer errors are silently skipped, so the snapshot is
// best-effort.
func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth {
	ctx := &synthesizer.SynthesisContext{
		Config:      cfg,
		AuthDir:     authDir,
		Now:         time.Now(),
		IDGenerator: synthesizer.NewStableIDGenerator(),
	}

	var out []*coreauth.Auth

	configSynth := synthesizer.NewConfigSynthesizer()
	if auths, err := configSynth.Synthesize(ctx); err == nil {
		out = append(out, auths...)
	}

	fileSynth := synthesizer.NewFileSynthesizer()
	if auths, err := fileSynth.Synthesize(ctx); err == nil {
		out = append(out, auths...)
	}

	return out
}
|
||||||
260
internal/watcher/events.go
Normal file
260
internal/watcher/events.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
// events.go implements fsnotify event handling for config and auth file changes.
|
||||||
|
// It normalizes paths, debounces noisy events, and triggers reload/update logic.
|
||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// matchProvider reports whether provider equals any entry in targets,
// case-insensitively and ignoring surrounding whitespace on both sides. It
// returns the normalized (trimmed, lower-cased) provider along with the
// match result.
func matchProvider(provider string, targets []string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(provider))
	for _, candidate := range targets {
		if strings.EqualFold(normalized, strings.TrimSpace(candidate)) {
			return normalized, true
		}
	}
	return normalized, false
}
|
||||||
|
|
||||||
|
// start registers the config file, the auth directory, and (optionally) the
// Kiro IDE token cache with the fsnotify watcher, launches the event pump,
// and performs the initial full client load. It returns an error when either
// required watch cannot be established.
func (w *Watcher) start(ctx context.Context) error {
	if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
		log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
		return errAddConfig
	}
	log.Debugf("watching config file: %s", w.configPath)

	if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
		log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
		return errAddAuthDir
	}
	log.Debugf("watching auth directory: %s", w.authDir)

	// Optional watch; failure is non-fatal (logged at debug level inside).
	w.watchKiroIDETokenFile()

	go w.processEvents(ctx)

	// Initial load: rescan auth files, no provider filtering, no forced refresh.
	w.reloadClients(true, nil, false)
	return nil
}
|
||||||
|
|
||||||
|
// watchKiroIDETokenFile adds the Kiro IDE SSO token cache directory
// (~/.aws/sso/cache) to the fsnotify watcher when it exists. All failures
// are logged at debug level only — this watch is strictly optional.
func (w *Watcher) watchKiroIDETokenFile() {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err)
		return
	}

	kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")

	if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
		log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
		return
	}

	if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil {
		log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd)
		return
	}
	log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir)
}
|
||||||
|
|
||||||
|
func (w *Watcher) processEvents(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case event, ok := <-w.watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.handleEvent(event)
|
||||||
|
case errWatch, ok := <-w.watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("file watcher error: %v", errWatch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||||
|
// Filter only relevant events: config file or auth-dir JSON files.
|
||||||
|
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
||||||
|
normalizedName := w.normalizeAuthPath(event.Name)
|
||||||
|
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
|
||||||
|
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||||
|
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||||
|
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||||
|
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||||
|
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
||||||
|
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
|
||||||
|
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isKiroIDEToken {
|
||||||
|
w.handleKiroIDETokenChange(event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
||||||
|
|
||||||
|
// Handle config file changes
|
||||||
|
if isConfigEvent {
|
||||||
|
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
||||||
|
w.scheduleConfigReload()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle auth directory changes incrementally (.json only)
|
||||||
|
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||||
|
if w.shouldDebounceRemove(normalizedName, now) {
|
||||||
|
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||||
|
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||||
|
time.Sleep(replaceCheckDelay)
|
||||||
|
if _, statErr := os.Stat(event.Name); statErr == nil {
|
||||||
|
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||||
|
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||||
|
w.addOrUpdateClient(event.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !w.isKnownAuthFile(event.Name) {
|
||||||
|
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||||
|
w.removeClient(event.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||||
|
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||||
|
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||||
|
w.addOrUpdateClient(event.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) isKiroIDETokenFile(path string) bool {
|
||||||
|
normalized := filepath.ToSlash(path)
|
||||||
|
return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
|
||||||
|
log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name)
|
||||||
|
|
||||||
|
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||||
|
time.Sleep(replaceCheckDelay)
|
||||||
|
if _, statErr := os.Stat(event.Name); statErr != nil {
|
||||||
|
log.Debugf("Kiro IDE token file removed: %s", event.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to load Kiro IDE token after change: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)
|
||||||
|
|
||||||
|
w.refreshAuthState(true)
|
||||||
|
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
cfg := w.config
|
||||||
|
w.clientsMutex.RUnlock()
|
||||||
|
|
||||||
|
if w.reloadCallback != nil && cfg != nil {
|
||||||
|
log.Debugf("triggering server update callback after Kiro IDE token change")
|
||||||
|
w.reloadCallback(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||||
|
data, errRead := os.ReadFile(path)
|
||||||
|
if errRead != nil {
|
||||||
|
return false, errRead
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
curHash := hex.EncodeToString(sum[:])
|
||||||
|
|
||||||
|
normalized := w.normalizeAuthPath(path)
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
prevHash, ok := w.lastAuthHashes[normalized]
|
||||||
|
w.clientsMutex.RUnlock()
|
||||||
|
if ok && prevHash == curHash {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) isKnownAuthFile(path string) bool {
|
||||||
|
normalized := w.normalizeAuthPath(path)
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
defer w.clientsMutex.RUnlock()
|
||||||
|
_, ok := w.lastAuthHashes[normalized]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) normalizeAuthPath(path string) string {
|
||||||
|
trimmed := strings.TrimSpace(path)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
cleaned := filepath.Clean(trimmed)
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
|
||||||
|
cleaned = strings.ToLower(cleaned)
|
||||||
|
}
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
|
||||||
|
if normalizedPath == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
if w.lastRemoveTimes == nil {
|
||||||
|
w.lastRemoveTimes = make(map[string]time.Time)
|
||||||
|
}
|
||||||
|
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
|
||||||
|
if now.Sub(last) < authRemoveDebounceWindow {
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.lastRemoveTimes[normalizedPath] = now
|
||||||
|
if len(w.lastRemoveTimes) > 128 {
|
||||||
|
cutoff := now.Add(-2 * authRemoveDebounceWindow)
|
||||||
|
for p, t := range w.lastRemoveTimes {
|
||||||
|
if t.Before(cutoff) {
|
||||||
|
delete(w.lastRemoveTimes, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,46 +1,22 @@
|
|||||||
// Package watcher provides file system monitoring functionality for the CLI Proxy API.
|
// Package watcher watches config/auth files and triggers hot reloads.
|
||||||
// It watches configuration files and authentication directories for changes,
|
// It supports cross-platform fsnotify event handling.
|
||||||
// automatically reloading clients and configuration when files are modified.
|
|
||||||
// The package handles cross-platform file system events and supports hot-reloading.
|
|
||||||
package watcher
|
package watcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func matchProvider(provider string, targets []string) (string, bool) {
|
|
||||||
p := strings.ToLower(strings.TrimSpace(provider))
|
|
||||||
for _, t := range targets {
|
|
||||||
if strings.EqualFold(p, strings.TrimSpace(t)) {
|
|
||||||
return p, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// storePersister captures persistence-capable token store methods used by the watcher.
|
// storePersister captures persistence-capable token store methods used by the watcher.
|
||||||
type storePersister interface {
|
type storePersister interface {
|
||||||
PersistConfig(ctx context.Context) error
|
PersistConfig(ctx context.Context) error
|
||||||
@@ -132,54 +108,7 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
|
|||||||
|
|
||||||
// Start begins watching the configuration file and authentication directory
|
// Start begins watching the configuration file and authentication directory
|
||||||
func (w *Watcher) Start(ctx context.Context) error {
|
func (w *Watcher) Start(ctx context.Context) error {
|
||||||
// Watch the config file
|
return w.start(ctx)
|
||||||
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
|
|
||||||
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
|
|
||||||
return errAddConfig
|
|
||||||
}
|
|
||||||
log.Debugf("watching config file: %s", w.configPath)
|
|
||||||
|
|
||||||
// Watch the auth directory
|
|
||||||
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
|
|
||||||
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
|
|
||||||
return errAddAuthDir
|
|
||||||
}
|
|
||||||
log.Debugf("watching auth directory: %s", w.authDir)
|
|
||||||
|
|
||||||
// Watch Kiro IDE token file directory for automatic token updates
|
|
||||||
w.watchKiroIDETokenFile()
|
|
||||||
|
|
||||||
// Start the event processing goroutine
|
|
||||||
go w.processEvents(ctx)
|
|
||||||
|
|
||||||
// Perform an initial full reload based on current config and auth dir
|
|
||||||
w.reloadClients(true, nil, false)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher.
|
|
||||||
// This enables automatic detection of token updates from Kiro IDE.
|
|
||||||
func (w *Watcher) watchKiroIDETokenFile() {
|
|
||||||
homeDir, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Kiro IDE stores tokens in ~/.aws/sso/cache/
|
|
||||||
kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
|
||||||
|
|
||||||
// Check if directory exists
|
|
||||||
if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
|
|
||||||
log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil {
|
|
||||||
log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the file watcher
|
// Stop stops the file watcher
|
||||||
@@ -189,15 +118,6 @@ func (w *Watcher) Stop() error {
|
|||||||
return w.watcher.Close()
|
return w.watcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Watcher) stopConfigReloadTimer() {
|
|
||||||
w.configReloadMu.Lock()
|
|
||||||
if w.configReloadTimer != nil {
|
|
||||||
w.configReloadTimer.Stop()
|
|
||||||
w.configReloadTimer = nil
|
|
||||||
}
|
|
||||||
w.configReloadMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetConfig updates the current configuration
|
// SetConfig updates the current configuration
|
||||||
func (w *Watcher) SetConfig(cfg *config.Config) {
|
func (w *Watcher) SetConfig(cfg *config.Config) {
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
@@ -208,873 +128,20 @@ func (w *Watcher) SetConfig(cfg *config.Config) {
|
|||||||
|
|
||||||
// SetAuthUpdateQueue sets the queue used to emit auth updates.
|
// SetAuthUpdateQueue sets the queue used to emit auth updates.
|
||||||
func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
|
func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||||
w.clientsMutex.Lock()
|
w.setAuthUpdateQueue(queue)
|
||||||
defer w.clientsMutex.Unlock()
|
|
||||||
w.authQueue = queue
|
|
||||||
if w.dispatchCond == nil {
|
|
||||||
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
|
||||||
}
|
|
||||||
if w.dispatchCancel != nil {
|
|
||||||
w.dispatchCancel()
|
|
||||||
if w.dispatchCond != nil {
|
|
||||||
w.dispatchMu.Lock()
|
|
||||||
w.dispatchCond.Broadcast()
|
|
||||||
w.dispatchMu.Unlock()
|
|
||||||
}
|
|
||||||
w.dispatchCancel = nil
|
|
||||||
}
|
|
||||||
if queue != nil {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
w.dispatchCancel = cancel
|
|
||||||
go w.dispatchLoop(ctx)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
|
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
|
||||||
// to push auth updates through the same queue used by file/config watchers.
|
// to push auth updates through the same queue used by file/config watchers.
|
||||||
// Returns true if the update was enqueued; false if no queue is configured.
|
// Returns true if the update was enqueued; false if no queue is configured.
|
||||||
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||||
if w == nil {
|
return w.dispatchRuntimeAuthUpdate(update)
|
||||||
return false
|
|
||||||
}
|
|
||||||
w.clientsMutex.Lock()
|
|
||||||
if w.runtimeAuths == nil {
|
|
||||||
w.runtimeAuths = make(map[string]*coreauth.Auth)
|
|
||||||
}
|
|
||||||
switch update.Action {
|
|
||||||
case AuthUpdateActionAdd, AuthUpdateActionModify:
|
|
||||||
if update.Auth != nil && update.Auth.ID != "" {
|
|
||||||
clone := update.Auth.Clone()
|
|
||||||
w.runtimeAuths[clone.ID] = clone
|
|
||||||
if w.currentAuths == nil {
|
|
||||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
|
||||||
}
|
|
||||||
w.currentAuths[clone.ID] = clone.Clone()
|
|
||||||
}
|
|
||||||
case AuthUpdateActionDelete:
|
|
||||||
id := update.ID
|
|
||||||
if id == "" && update.Auth != nil {
|
|
||||||
id = update.Auth.ID
|
|
||||||
}
|
|
||||||
if id != "" {
|
|
||||||
delete(w.runtimeAuths, id)
|
|
||||||
if w.currentAuths != nil {
|
|
||||||
delete(w.currentAuths, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
if w.getAuthQueue() == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
w.dispatchAuthUpdates([]AuthUpdate{update})
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Watcher) refreshAuthState(force bool) {
|
|
||||||
auths := w.SnapshotCoreAuths()
|
|
||||||
w.clientsMutex.Lock()
|
|
||||||
if len(w.runtimeAuths) > 0 {
|
|
||||||
for _, a := range w.runtimeAuths {
|
|
||||||
if a != nil {
|
|
||||||
auths = append(auths, a.Clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
updates := w.prepareAuthUpdatesLocked(auths, force)
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
w.dispatchAuthUpdates(updates)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
|
|
||||||
newState := make(map[string]*coreauth.Auth, len(auths))
|
|
||||||
for _, auth := range auths {
|
|
||||||
if auth == nil || auth.ID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newState[auth.ID] = auth.Clone()
|
|
||||||
}
|
|
||||||
if w.currentAuths == nil {
|
|
||||||
w.currentAuths = newState
|
|
||||||
if w.authQueue == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
updates := make([]AuthUpdate, 0, len(newState))
|
|
||||||
for id, auth := range newState {
|
|
||||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
|
||||||
}
|
|
||||||
return updates
|
|
||||||
}
|
|
||||||
if w.authQueue == nil {
|
|
||||||
w.currentAuths = newState
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths))
|
|
||||||
for id, auth := range newState {
|
|
||||||
if existing, ok := w.currentAuths[id]; !ok {
|
|
||||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
|
||||||
} else if force || !authEqual(existing, auth) {
|
|
||||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for id := range w.currentAuths {
|
|
||||||
if _, ok := newState[id]; !ok {
|
|
||||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.currentAuths = newState
|
|
||||||
return updates
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
|
|
||||||
if len(updates) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
queue := w.getAuthQueue()
|
|
||||||
if queue == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
baseTS := time.Now().UnixNano()
|
|
||||||
w.dispatchMu.Lock()
|
|
||||||
if w.pendingUpdates == nil {
|
|
||||||
w.pendingUpdates = make(map[string]AuthUpdate)
|
|
||||||
}
|
|
||||||
for idx, update := range updates {
|
|
||||||
key := w.authUpdateKey(update, baseTS+int64(idx))
|
|
||||||
if _, exists := w.pendingUpdates[key]; !exists {
|
|
||||||
w.pendingOrder = append(w.pendingOrder, key)
|
|
||||||
}
|
|
||||||
w.pendingUpdates[key] = update
|
|
||||||
}
|
|
||||||
if w.dispatchCond != nil {
|
|
||||||
w.dispatchCond.Signal()
|
|
||||||
}
|
|
||||||
w.dispatchMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
|
|
||||||
if update.ID != "" {
|
|
||||||
return update.ID
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s:%d", update.Action, ts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) dispatchLoop(ctx context.Context) {
|
|
||||||
for {
|
|
||||||
batch, ok := w.nextPendingBatch(ctx)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
queue := w.getAuthQueue()
|
|
||||||
if queue == nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, update := range batch {
|
|
||||||
select {
|
|
||||||
case queue <- update:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
|
|
||||||
w.dispatchMu.Lock()
|
|
||||||
defer w.dispatchMu.Unlock()
|
|
||||||
for len(w.pendingOrder) == 0 {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
w.dispatchCond.Wait()
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
batch := make([]AuthUpdate, 0, len(w.pendingOrder))
|
|
||||||
for _, key := range w.pendingOrder {
|
|
||||||
batch = append(batch, w.pendingUpdates[key])
|
|
||||||
delete(w.pendingUpdates, key)
|
|
||||||
}
|
|
||||||
w.pendingOrder = w.pendingOrder[:0]
|
|
||||||
return batch, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
|
|
||||||
w.clientsMutex.RLock()
|
|
||||||
defer w.clientsMutex.RUnlock()
|
|
||||||
return w.authQueue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) stopDispatch() {
|
|
||||||
if w.dispatchCancel != nil {
|
|
||||||
w.dispatchCancel()
|
|
||||||
w.dispatchCancel = nil
|
|
||||||
}
|
|
||||||
w.dispatchMu.Lock()
|
|
||||||
w.pendingOrder = nil
|
|
||||||
w.pendingUpdates = nil
|
|
||||||
if w.dispatchCond != nil {
|
|
||||||
w.dispatchCond.Broadcast()
|
|
||||||
}
|
|
||||||
w.dispatchMu.Unlock()
|
|
||||||
w.clientsMutex.Lock()
|
|
||||||
w.authQueue = nil
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) persistConfigAsync() {
|
|
||||||
if w == nil || w.storePersister == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := w.storePersister.PersistConfig(ctx); err != nil {
|
|
||||||
log.Errorf("failed to persist config change: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
|
|
||||||
if w == nil || w.storePersister == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
filtered := make([]string, 0, len(paths))
|
|
||||||
for _, p := range paths {
|
|
||||||
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
|
||||||
filtered = append(filtered, trimmed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(filtered) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil {
|
|
||||||
log.Errorf("failed to persist auth changes: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func authEqual(a, b *coreauth.Auth) bool {
|
|
||||||
return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
|
|
||||||
if a == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
clone := a.Clone()
|
|
||||||
clone.CreatedAt = time.Time{}
|
|
||||||
clone.UpdatedAt = time.Time{}
|
|
||||||
clone.LastRefreshedAt = time.Time{}
|
|
||||||
clone.NextRefreshAfter = time.Time{}
|
|
||||||
clone.Runtime = nil
|
|
||||||
clone.Quota.NextRecoverAt = time.Time{}
|
|
||||||
return clone
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClients sets the file-based clients.
|
|
||||||
// SetClients removed
|
|
||||||
// SetAPIKeyClients removed
|
|
||||||
|
|
||||||
// processEvents handles file system events
|
|
||||||
func (w *Watcher) processEvents(ctx context.Context) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case event, ok := <-w.watcher.Events:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.handleEvent(event)
|
|
||||||
case errWatch, ok := <-w.watcher.Errors:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("file watcher error: %v", errWatch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
|
||||||
data, errRead := os.ReadFile(path)
|
|
||||||
if errRead != nil {
|
|
||||||
return false, errRead
|
|
||||||
}
|
|
||||||
if len(data) == 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
sum := sha256.Sum256(data)
|
|
||||||
curHash := hex.EncodeToString(sum[:])
|
|
||||||
|
|
||||||
normalized := w.normalizeAuthPath(path)
|
|
||||||
w.clientsMutex.RLock()
|
|
||||||
prevHash, ok := w.lastAuthHashes[normalized]
|
|
||||||
w.clientsMutex.RUnlock()
|
|
||||||
if ok && prevHash == curHash {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) isKnownAuthFile(path string) bool {
|
|
||||||
normalized := w.normalizeAuthPath(path)
|
|
||||||
w.clientsMutex.RLock()
|
|
||||||
defer w.clientsMutex.RUnlock()
|
|
||||||
_, ok := w.lastAuthHashes[normalized]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) normalizeAuthPath(path string) string {
|
|
||||||
trimmed := strings.TrimSpace(path)
|
|
||||||
if trimmed == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
cleaned := filepath.Clean(trimmed)
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
|
|
||||||
cleaned = strings.ToLower(cleaned)
|
|
||||||
}
|
|
||||||
return cleaned
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
|
|
||||||
if normalizedPath == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
w.clientsMutex.Lock()
|
|
||||||
if w.lastRemoveTimes == nil {
|
|
||||||
w.lastRemoveTimes = make(map[string]time.Time)
|
|
||||||
}
|
|
||||||
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
|
|
||||||
if now.Sub(last) < authRemoveDebounceWindow {
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.lastRemoveTimes[normalizedPath] = now
|
|
||||||
if len(w.lastRemoveTimes) > 128 {
|
|
||||||
cutoff := now.Add(-2 * authRemoveDebounceWindow)
|
|
||||||
for p, t := range w.lastRemoveTimes {
|
|
||||||
if t.Before(cutoff) {
|
|
||||||
delete(w.lastRemoveTimes, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleEvent processes individual file system events
|
|
||||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
|
||||||
// Filter only relevant events: config file or auth-dir JSON files.
|
|
||||||
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
|
||||||
normalizedName := w.normalizeAuthPath(event.Name)
|
|
||||||
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
|
|
||||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
|
||||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
|
||||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
|
||||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
|
||||||
|
|
||||||
// Check for Kiro IDE token file changes
|
|
||||||
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
|
||||||
|
|
||||||
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
|
|
||||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle Kiro IDE token file changes
|
|
||||||
if isKiroIDEToken {
|
|
||||||
w.handleKiroIDETokenChange(event)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
|
||||||
|
|
||||||
// Handle config file changes
|
|
||||||
if isConfigEvent {
|
|
||||||
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
|
||||||
w.scheduleConfigReload()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle auth directory changes incrementally (.json only)
|
|
||||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
|
||||||
if w.shouldDebounceRemove(normalizedName, now) {
|
|
||||||
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
|
||||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
|
||||||
time.Sleep(replaceCheckDelay)
|
|
||||||
if _, statErr := os.Stat(event.Name); statErr == nil {
|
|
||||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
|
||||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
|
||||||
w.addOrUpdateClient(event.Name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !w.isKnownAuthFile(event.Name) {
|
|
||||||
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
|
||||||
w.removeClient(event.Name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
|
||||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
|
||||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
|
||||||
w.addOrUpdateClient(event.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *Watcher) scheduleConfigReload() {
|
|
||||||
w.configReloadMu.Lock()
|
|
||||||
defer w.configReloadMu.Unlock()
|
|
||||||
if w.configReloadTimer != nil {
|
|
||||||
w.configReloadTimer.Stop()
|
|
||||||
}
|
|
||||||
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
|
|
||||||
w.configReloadMu.Lock()
|
|
||||||
w.configReloadTimer = nil
|
|
||||||
w.configReloadMu.Unlock()
|
|
||||||
w.reloadConfigIfChanged()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// isKiroIDETokenFile checks if the given path is the Kiro IDE token file.
|
|
||||||
func (w *Watcher) isKiroIDETokenFile(path string) bool {
|
|
||||||
// Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/
|
|
||||||
// Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes)
|
|
||||||
normalized := filepath.ToSlash(path)
|
|
||||||
return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache")
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleKiroIDETokenChange processes changes to the Kiro IDE token file.
// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth.
func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
	log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name)

	if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
		// Token file removed - wait briefly for potential atomic replace
		time.Sleep(replaceCheckDelay)
		// If the file still does not exist after the grace period it was a
		// real removal, not an atomic rename-over; nothing to reload.
		if _, statErr := os.Stat(event.Name); statErr != nil {
			log.Debugf("Kiro IDE token file removed: %s", event.Name)
			return
		}
	}

	// Try to load the updated token
	tokenData, err := kiroauth.LoadKiroIDEToken()
	if err != nil {
		// Best-effort: a transient or partially-written file is simply skipped.
		log.Debugf("failed to load Kiro IDE token after change: %v", err)
		return
	}

	log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)

	// Trigger auth state refresh to pick up the new token
	w.refreshAuthState(true)

	// Notify callback if set; read config under the lock, invoke outside it.
	w.clientsMutex.RLock()
	cfg := w.config
	w.clientsMutex.RUnlock()

	if w.reloadCallback != nil && cfg != nil {
		log.Debugf("triggering server update callback after Kiro IDE token change")
		w.reloadCallback(cfg)
	}
}
||||||
|
|
||||||
// reloadConfigIfChanged reloads the configuration only when the file content
// actually changed. It hashes the config file (SHA-256) and compares against
// the cached hash to skip spurious filesystem events; empty writes are
// ignored. On a successful reload the hash is recomputed (the reload itself
// may rewrite the file) and the config is persisted asynchronously.
func (w *Watcher) reloadConfigIfChanged() {
	data, err := os.ReadFile(w.configPath)
	if err != nil {
		log.Errorf("failed to read config file for hash check: %v", err)
		return
	}
	if len(data) == 0 {
		// Editors often truncate before writing; wait for the real content.
		log.Debugf("ignoring empty config file write event")
		return
	}
	sum := sha256.Sum256(data)
	newHash := hex.EncodeToString(sum[:])

	w.clientsMutex.RLock()
	currentHash := w.lastConfigHash
	w.clientsMutex.RUnlock()

	if currentHash != "" && currentHash == newHash {
		log.Debugf("config file content unchanged (hash match), skipping reload")
		return
	}
	log.Infof("config file changed, reloading: %s", w.configPath)
	if w.reloadConfig() {
		// The reload may rewrite the config file; re-hash so the cache
		// reflects what is on disk now, falling back to the pre-reload hash.
		finalHash := newHash
		if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
			sumUpdated := sha256.Sum256(updatedData)
			finalHash = hex.EncodeToString(sumUpdated[:])
		} else if errRead != nil {
			log.WithError(errRead).Debug("failed to compute updated config hash after reload")
		}
		w.clientsMutex.Lock()
		w.lastConfigHash = finalHash
		w.clientsMutex.Unlock()
		w.persistConfigAsync()
	}
}
||||||
|
|
||||||
// reloadConfig reloads the configuration and triggers a full reload.
// It loads the config from disk, pins or resolves the auth directory,
// swaps the watcher's config under the lock (keeping a YAML snapshot of
// the previous one for diffing), applies the log level, and kicks off a
// client reload. Returns false only when the config could not be loaded.
func (w *Watcher) reloadConfig() bool {
	log.Debug("=========================== CONFIG RELOAD ============================")
	log.Debugf("starting config reload from: %s", w.configPath)

	newConfig, errLoadConfig := config.LoadConfig(w.configPath)
	if errLoadConfig != nil {
		log.Errorf("failed to reload config: %v", errLoadConfig)
		return false
	}

	// A mirrored auth dir (if set) always wins over the configured one;
	// otherwise resolve the configured path, keeping the raw value on error.
	if w.mirroredAuthDir != "" {
		newConfig.AuthDir = w.mirroredAuthDir
	} else {
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
		} else {
			newConfig.AuthDir = resolvedAuthDir
		}
	}

	// Swap in the new config; oldConfig is rebuilt from the stored YAML
	// snapshot so the diff below compares against what was actually applied.
	w.clientsMutex.Lock()
	var oldConfig *config.Config
	_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
	w.oldConfigYaml, _ = yaml.Marshal(newConfig)
	w.config = newConfig
	w.clientsMutex.Unlock()

	var affectedOAuthProviders []string
	if oldConfig != nil {
		_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
	}

	// Always apply the current log level based on the latest config.
	// This ensures logrus reflects the desired level even if change detection misses.
	util.SetLogLevel(newConfig)
	// Additional debug for visibility when the flag actually changes.
	if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
		log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
	}

	// Log configuration changes in debug mode, only when there are material diffs
	if oldConfig != nil {
		details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
		if len(details) > 0 {
			log.Debugf("config changes detected:")
			for _, d := range details {
				log.Debugf(" %s", d)
			}
		} else {
			log.Debugf("no material config field changes detected")
		}
	}

	// With no previous config, treat the auth dir as changed to force a scan.
	authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
	forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix

	log.Infof("config successfully reloaded, triggering client reload")
	// Reload clients with new config
	w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
	return true
}
||||||
|
|
||||||
// reloadClients performs a full scan and reload of all clients.
//
// rescanAuth forces a rescan of the auth directory (startup or authDir
// change); affectedOAuthProviders lists providers whose cached auth entries
// must be dropped so oauth-excluded-models changes take effect;
// forceAuthRefresh is forwarded to refreshAuthState.
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
	log.Debugf("starting full client load process")

	w.clientsMutex.RLock()
	cfg := w.config
	w.clientsMutex.RUnlock()

	if cfg == nil {
		log.Error("config is nil, cannot reload clients")
		return
	}

	// Drop cached auth entries for providers affected by the
	// oauth-excluded-models change so they get re-synthesized.
	if len(affectedOAuthProviders) > 0 {
		w.clientsMutex.Lock()
		if w.currentAuths != nil {
			filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
			for id, auth := range w.currentAuths {
				if auth == nil {
					continue
				}
				provider := strings.ToLower(strings.TrimSpace(auth.Provider))
				if _, match := matchProvider(provider, affectedOAuthProviders); match {
					continue
				}
				filtered[id] = auth
			}
			w.currentAuths = filtered
			log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
		} else {
			w.currentAuths = nil
		}
		w.clientsMutex.Unlock()
	}

	// Unregister all old API key clients before creating new ones
	// no legacy clients to unregister

	// Create new API key clients based on the new config
	geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
	totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
	log.Debugf("loaded %d API key clients", totalAPIKeyClients)

	var authFileCount int
	if rescanAuth {
		// Load file-based clients when explicitly requested (startup or authDir change)
		authFileCount = w.loadFileClients(cfg)
		log.Debugf("loaded %d file-based clients", authFileCount)
	} else {
		// Preserve existing auth hashes and only report current known count to avoid redundant scans.
		w.clientsMutex.RLock()
		authFileCount = len(w.lastAuthHashes)
		w.clientsMutex.RUnlock()
		log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
	}

	// no legacy file-based clients to unregister

	// Update client maps
	if rescanAuth {
		w.clientsMutex.Lock()

		// Rebuild auth file hash cache for current clients
		w.lastAuthHashes = make(map[string]string)
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
		} else if resolvedAuthDir != "" {
			// Walk errors are intentionally ignored: a partially-built hash
			// cache only means some files will be reloaded on next event.
			_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return nil
				}
				if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
					if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
						sum := sha256.Sum256(data)
						normalizedPath := w.normalizeAuthPath(path)
						w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
					}
				}
				return nil
			})
		}
		w.clientsMutex.Unlock()
	}

	totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount

	// Ensure consumers observe the new configuration before auth updates dispatch.
	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback before auth refresh")
		w.reloadCallback(cfg)
	}

	w.refreshAuthState(forceAuthRefresh)

	log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
		totalNewClients,
		authFileCount,
		geminiAPIKeyCount,
		vertexCompatAPIKeyCount,
		claudeAPIKeyCount,
		codexAPIKeyCount,
		openAICompatCount,
	)
}
|
|
||||||
|
|
||||||
// createClientFromFile creates a single client instance from a given token file path.
|
|
||||||
// createClientFromFile removed (legacy)
|
|
||||||
|
|
||||||
// addOrUpdateClient handles the addition or update of a single client.
|
|
||||||
func (w *Watcher) addOrUpdateClient(path string) {
|
|
||||||
data, errRead := os.ReadFile(path)
|
|
||||||
if errRead != nil {
|
|
||||||
log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(data) == 0 {
|
|
||||||
log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := sha256.Sum256(data)
|
|
||||||
curHash := hex.EncodeToString(sum[:])
|
|
||||||
normalized := w.normalizeAuthPath(path)
|
|
||||||
|
|
||||||
w.clientsMutex.Lock()
|
|
||||||
|
|
||||||
cfg := w.config
|
|
||||||
if cfg == nil {
|
|
||||||
log.Error("config is nil, cannot add or update client")
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
|
||||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
|
||||||
w.clientsMutex.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update hash cache
|
|
||||||
w.lastAuthHashes[normalized] = curHash
|
|
||||||
|
|
||||||
w.clientsMutex.Unlock() // Unlock before the callback
|
|
||||||
|
|
||||||
w.refreshAuthState(false)
|
|
||||||
|
|
||||||
if w.reloadCallback != nil {
|
|
||||||
log.Debugf("triggering server update callback after add/update")
|
|
||||||
w.reloadCallback(cfg)
|
|
||||||
}
|
|
||||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeClient handles the removal of a single client.
//
// It drops the cached content hash for the removed auth file, refreshes the
// auth state, notifies the reload callback, and asynchronously persists the
// removal. Deleting from the hash map is safe even if the path was unknown.
func (w *Watcher) removeClient(path string) {
	normalized := w.normalizeAuthPath(path)
	w.clientsMutex.Lock()

	cfg := w.config
	delete(w.lastAuthHashes, normalized)

	w.clientsMutex.Unlock() // Release the lock before the callback

	w.refreshAuthState(false)

	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback after removal")
		w.reloadCallback(cfg)
	}
	w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
}
||||||
|
|
||||||
// SnapshotCombinedClients returns a snapshot of current combined clients.
|
|
||||||
// SnapshotCombinedClients removed
|
|
||||||
|
|
||||||
// SnapshotCoreAuths converts current clients snapshot into core auth entries.
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
	// Read the config under the lock; the synthesis helper runs lock-free.
	w.clientsMutex.RLock()
	cfg := w.config
	w.clientsMutex.RUnlock()
	return snapshotCoreAuths(cfg, w.authDir)
}
|
|
||||||
|
|
||||||
// buildCombinedClientMap merges file-based clients with API key clients from the cache.
|
|
||||||
// buildCombinedClientMap removed
|
|
||||||
|
|
||||||
// unregisterClientWithReason attempts to call client-specific unregister hooks with context.
|
|
||||||
// unregisterClientWithReason removed
|
|
||||||
|
|
||||||
// loadFileClients scans the auth directory and creates clients from .json files.
//
// It resolves the auth directory from cfg, walks it, and counts every .json
// file found (authFileCount) plus the subset that is non-empty and readable
// (successfulAuthCount, logged only). Returns the total .json file count;
// returns 0 when the directory cannot be resolved or is unset.
func (w *Watcher) loadFileClients(cfg *config.Config) int {
	authFileCount := 0
	successfulAuthCount := 0

	authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
	if errResolveAuthDir != nil {
		log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
		return 0
	}
	if authDir == "" {
		return 0
	}

	errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
		if err != nil {
			// Propagating err aborts the walk; the caller logs it below.
			log.Debugf("error accessing path %s: %v", path, err)
			return err
		}
		if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
			authFileCount++
			log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
			// Count readable JSON files as successful auth entries
			if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
				successfulAuthCount++
			}
		}
		return nil
	})

	if errWalk != nil {
		log.Errorf("error walking auth directory: %v", errWalk)
	}
	log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
	return authFileCount
}
||||||
|
|
||||||
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
|
||||||
geminiAPIKeyCount := 0
|
|
||||||
vertexCompatAPIKeyCount := 0
|
|
||||||
claudeAPIKeyCount := 0
|
|
||||||
codexAPIKeyCount := 0
|
|
||||||
openAICompatCount := 0
|
|
||||||
|
|
||||||
if len(cfg.GeminiKey) > 0 {
|
|
||||||
// Stateless executor handles Gemini API keys; avoid constructing legacy clients.
|
|
||||||
geminiAPIKeyCount += len(cfg.GeminiKey)
|
|
||||||
}
|
|
||||||
if len(cfg.VertexCompatAPIKey) > 0 {
|
|
||||||
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
|
|
||||||
}
|
|
||||||
if len(cfg.ClaudeKey) > 0 {
|
|
||||||
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
|
||||||
}
|
|
||||||
if len(cfg.CodexKey) > 0 {
|
|
||||||
codexAPIKeyCount += len(cfg.CodexKey)
|
|
||||||
}
|
|
||||||
if len(cfg.OpenAICompatibility) > 0 {
|
|
||||||
// Do not construct legacy clients for OpenAI-compat providers; these are handled by the stateless executor.
|
|
||||||
for _, compatConfig := range cfg.OpenAICompatibility {
|
|
||||||
openAICompatCount += len(compatConfig.APIKeyEntries)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,6 +17,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
@@ -489,6 +491,28 @@ func TestAuthFileUnchangedUsesHash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthFileUnchangedEmptyAndMissing(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
emptyFile := filepath.Join(tmpDir, "empty.json")
|
||||||
|
if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write empty auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := &Watcher{lastAuthHashes: make(map[string]string)}
|
||||||
|
unchanged, err := w.authFileUnchanged(emptyFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error for empty file: %v", err)
|
||||||
|
}
|
||||||
|
if unchanged {
|
||||||
|
t.Fatal("expected empty file to be treated as changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = w.authFileUnchanged(filepath.Join(tmpDir, "missing.json"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing auth file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReloadClientsCachesAuthHashes(t *testing.T) {
|
func TestReloadClientsCachesAuthHashes(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
authFile := filepath.Join(tmpDir, "one.json")
|
authFile := filepath.Join(tmpDir, "one.json")
|
||||||
@@ -528,6 +552,23 @@ func TestReloadClientsLogsConfigDiffs(t *testing.T) {
|
|||||||
w.reloadClients(false, nil, false)
|
w.reloadClients(false, nil, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReloadClientsHandlesNilConfig(t *testing.T) {
|
||||||
|
w := &Watcher{}
|
||||||
|
w.reloadClients(true, nil, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadClientsFiltersProvidersWithNilCurrentAuths(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: tmp,
|
||||||
|
config: &config.Config{AuthDir: tmp},
|
||||||
|
}
|
||||||
|
w.reloadClients(false, []string{"match"}, false)
|
||||||
|
if w.currentAuths != nil && len(w.currentAuths) != 0 {
|
||||||
|
t.Fatalf("expected currentAuths to be nil or empty, got %d", len(w.currentAuths))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
|
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
|
||||||
w := &Watcher{}
|
w := &Watcher{}
|
||||||
queue := make(chan AuthUpdate, 1)
|
queue := make(chan AuthUpdate, 1)
|
||||||
@@ -541,6 +582,45 @@ func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPersistAsyncEarlyReturns(t *testing.T) {
|
||||||
|
var nilWatcher *Watcher
|
||||||
|
nilWatcher.persistConfigAsync()
|
||||||
|
nilWatcher.persistAuthAsync("msg", "a")
|
||||||
|
|
||||||
|
w := &Watcher{}
|
||||||
|
w.persistConfigAsync()
|
||||||
|
w.persistAuthAsync("msg", " ", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorPersister struct {
|
||||||
|
configCalls int32
|
||||||
|
authCalls int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *errorPersister) PersistConfig(context.Context) error {
|
||||||
|
atomic.AddInt32(&p.configCalls, 1)
|
||||||
|
return fmt.Errorf("persist config error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *errorPersister) PersistAuthFiles(context.Context, string, ...string) error {
|
||||||
|
atomic.AddInt32(&p.authCalls, 1)
|
||||||
|
return fmt.Errorf("persist auth error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistAsyncErrorPaths(t *testing.T) {
|
||||||
|
p := &errorPersister{}
|
||||||
|
w := &Watcher{storePersister: p}
|
||||||
|
w.persistConfigAsync()
|
||||||
|
w.persistAuthAsync("msg", "a")
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
if atomic.LoadInt32(&p.configCalls) != 1 {
|
||||||
|
t.Fatalf("expected PersistConfig to be called once, got %d", p.configCalls)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&p.authCalls) != 1 {
|
||||||
|
t.Fatalf("expected PersistAuthFiles to be called once, got %d", p.authCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) {
|
func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) {
|
||||||
w := &Watcher{}
|
w := &Watcher{}
|
||||||
w.stopConfigReloadTimer()
|
w.stopConfigReloadTimer()
|
||||||
@@ -608,6 +688,803 @@ func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDispatchLoopExitsOnContextDoneWhileSending(t *testing.T) {
|
||||||
|
queue := make(chan AuthUpdate) // unbuffered to block sends
|
||||||
|
w := &Watcher{
|
||||||
|
authQueue: queue,
|
||||||
|
pendingUpdates: map[string]AuthUpdate{
|
||||||
|
"k": {Action: AuthUpdateActionAdd, ID: "k"},
|
||||||
|
},
|
||||||
|
pendingOrder: []string{"k"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.dispatchLoop(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("expected dispatchLoop to exit after ctx canceled while blocked on send")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessEventsHandlesEventErrorAndChannelClose(t *testing.T) {
|
||||||
|
w := &Watcher{
|
||||||
|
watcher: &fsnotify.Watcher{
|
||||||
|
Events: make(chan fsnotify.Event, 2),
|
||||||
|
Errors: make(chan error, 2),
|
||||||
|
},
|
||||||
|
configPath: "config.yaml",
|
||||||
|
authDir: "auth",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.processEvents(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
w.watcher.Events <- fsnotify.Event{Name: "unrelated.txt", Op: fsnotify.Write}
|
||||||
|
w.watcher.Errors <- fmt.Errorf("watcher error")
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
close(w.watcher.Events)
|
||||||
|
close(w.watcher.Errors)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("processEvents did not exit after channels closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessEventsReturnsWhenErrorsChannelClosed(t *testing.T) {
|
||||||
|
w := &Watcher{
|
||||||
|
watcher: &fsnotify.Watcher{
|
||||||
|
Events: nil,
|
||||||
|
Errors: make(chan error),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
close(w.watcher.Errors)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.processEvents(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("processEvents did not exit after errors channel closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventIgnoresUnrelatedFiles(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: filepath.Join(tmpDir, "note.txt"), Op: fsnotify.Write})
|
||||||
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
|
t.Fatalf("expected no reloads for unrelated file, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventConfigChangeSchedulesReload(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: configPath, Op: fsnotify.Write})
|
||||||
|
|
||||||
|
time.Sleep(400 * time.Millisecond)
|
||||||
|
if atomic.LoadInt32(&reloads) != 1 {
|
||||||
|
t.Fatalf("expected config change to trigger reload once, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authFile := filepath.Join(authDir, "a.json")
|
||||||
|
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
|
||||||
|
if atomic.LoadInt32(&reloads) != 1 {
|
||||||
|
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventRemoveDebounceSkips(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authFile := filepath.Join(authDir, "remove.json")
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
lastRemoveTimes: map[string]time.Time{
|
||||||
|
filepath.Clean(authFile): time.Now(),
|
||||||
|
},
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||||
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
|
t.Fatalf("expected remove to be debounced, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventAtomicReplaceUnchangedSkips(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authFile := filepath.Join(authDir, "same.json")
|
||||||
|
content := []byte(`{"type":"demo"}`)
|
||||||
|
if err := os.WriteFile(authFile, content, 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(content)
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(sum[:])
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
|
||||||
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
|
t.Fatalf("expected unchanged atomic replace to be skipped, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authFile := filepath.Join(authDir, "change.json")
|
||||||
|
oldContent := []byte(`{"type":"demo","v":1}`)
|
||||||
|
newContent := []byte(`{"type":"demo","v":2}`)
|
||||||
|
if err := os.WriteFile(authFile, newContent, 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
oldSum := sha256.Sum256(oldContent)
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
|
||||||
|
if atomic.LoadInt32(&reloads) != 1 {
|
||||||
|
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventRemoveUnknownFileIgnored(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authFile := filepath.Join(authDir, "unknown.json")
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||||
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
|
t.Fatalf("expected unknown remove to be ignored, got %d", reloads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authFile := filepath.Join(authDir, "known.json")
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
configPath: configPath,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
|
||||||
|
|
||||||
|
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
|
||||||
|
if atomic.LoadInt32(&reloads) != 1 {
|
||||||
|
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
|
||||||
|
}
|
||||||
|
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
|
||||||
|
t.Fatal("expected known auth hash to be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAuthPathAndDebounceCleanup(t *testing.T) {
|
||||||
|
w := &Watcher{}
|
||||||
|
if got := w.normalizeAuthPath(" "); got != "" {
|
||||||
|
t.Fatalf("expected empty normalize result, got %q", got)
|
||||||
|
}
|
||||||
|
if got := w.normalizeAuthPath(" a/../b "); got != filepath.Clean("a/../b") {
|
||||||
|
t.Fatalf("unexpected normalize result: %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
w.lastRemoveTimes = make(map[string]time.Time, 140)
|
||||||
|
old := time.Now().Add(-3 * authRemoveDebounceWindow)
|
||||||
|
for i := 0; i < 129; i++ {
|
||||||
|
w.lastRemoveTimes[fmt.Sprintf("old-%d", i)] = old
|
||||||
|
}
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
|
||||||
|
w.shouldDebounceRemove("new-path", time.Now())
|
||||||
|
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
gotLen := len(w.lastRemoveTimes)
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
if gotLen >= 129 {
|
||||||
|
t.Fatalf("expected debounce cleanup to shrink map, got %d", gotLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthStateDispatchesRuntimeAuths(t *testing.T) {
|
||||||
|
queue := make(chan AuthUpdate, 8)
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: t.TempDir(),
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: w.authDir})
|
||||||
|
w.SetAuthUpdateQueue(queue)
|
||||||
|
defer w.stopDispatch()
|
||||||
|
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
w.runtimeAuths = map[string]*coreauth.Auth{
|
||||||
|
"nil": nil,
|
||||||
|
"r1": {ID: "r1", Provider: "runtime"},
|
||||||
|
}
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
|
||||||
|
w.refreshAuthState(false)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case u := <-queue:
|
||||||
|
if u.Action != AuthUpdateActionAdd || u.ID != "r1" {
|
||||||
|
t.Fatalf("unexpected auth update: %+v", u)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for runtime auth update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddOrUpdateClientEdgeCases(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := tmpDir
|
||||||
|
authFile := filepath.Join(tmpDir, "edge.json")
|
||||||
|
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
emptyFile := filepath.Join(tmpDir, "empty.json")
|
||||||
|
if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write empty auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: authDir,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
|
||||||
|
w.addOrUpdateClient(filepath.Join(tmpDir, "missing.json"))
|
||||||
|
w.addOrUpdateClient(emptyFile)
|
||||||
|
if atomic.LoadInt32(&reloads) != 0 {
|
||||||
|
t.Fatalf("expected no reloads for missing/empty file, got %d", reloads)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.addOrUpdateClient(authFile) // config nil -> should not panic or update
|
||||||
|
if len(w.lastAuthHashes) != 0 {
|
||||||
|
t.Fatalf("expected no hash entries without config, got %d", len(w.lastAuthHashes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadFileClientsWalkError(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
noAccessDir := filepath.Join(tmpDir, "0noaccess")
|
||||||
|
if err := os.MkdirAll(noAccessDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create noaccess dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Chmod(noAccessDir, 0); err != nil {
|
||||||
|
t.Skipf("chmod not supported: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = os.Chmod(noAccessDir, 0o755) }()
|
||||||
|
|
||||||
|
cfg := &config.Config{AuthDir: tmpDir}
|
||||||
|
w := &Watcher{}
|
||||||
|
w.SetConfig(cfg)
|
||||||
|
|
||||||
|
count := w.loadFileClients(cfg)
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatalf("expected count 0 due to walk error, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadConfigIfChangedHandlesMissingAndEmpty(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := &Watcher{
|
||||||
|
configPath: filepath.Join(tmpDir, "missing.yaml"),
|
||||||
|
authDir: authDir,
|
||||||
|
}
|
||||||
|
w.reloadConfigIfChanged() // missing file -> log + return
|
||||||
|
|
||||||
|
emptyPath := filepath.Join(tmpDir, "empty.yaml")
|
||||||
|
if err := os.WriteFile(emptyPath, []byte(""), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write empty config: %v", err)
|
||||||
|
}
|
||||||
|
w.configPath = emptyPath
|
||||||
|
w.reloadConfigIfChanged() // empty file -> early return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadConfigUsesMirroredAuthDir(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+filepath.Join(tmpDir, "other")+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := &Watcher{
|
||||||
|
configPath: configPath,
|
||||||
|
authDir: authDir,
|
||||||
|
mirroredAuthDir: authDir,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
if ok := w.reloadConfig(); !ok {
|
||||||
|
t.Fatal("expected reloadConfig to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
defer w.clientsMutex.RUnlock()
|
||||||
|
if w.config == nil || w.config.AuthDir != authDir {
|
||||||
|
t.Fatalf("expected AuthDir to be overridden by mirroredAuthDir %s, got %+v", authDir, w.config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tmpDir, "auth")
|
||||||
|
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
|
||||||
|
// Ensure SnapshotCoreAuths yields a provider that is NOT affected, so we can assert it survives.
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, "provider-b.json"), []byte(`{"type":"provider-b","email":"b@example.com"}`), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldCfg := &config.Config{
|
||||||
|
AuthDir: authDir,
|
||||||
|
OAuthExcludedModels: map[string][]string{
|
||||||
|
"provider-a": {"m1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg := &config.Config{
|
||||||
|
AuthDir: authDir,
|
||||||
|
OAuthExcludedModels: map[string][]string{
|
||||||
|
"provider-a": {"m2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := yaml.Marshal(newCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal config: %v", err)
|
||||||
|
}
|
||||||
|
if err = os.WriteFile(configPath, data, 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := &Watcher{
|
||||||
|
configPath: configPath,
|
||||||
|
authDir: authDir,
|
||||||
|
lastAuthHashes: make(map[string]string),
|
||||||
|
currentAuths: map[string]*coreauth.Auth{
|
||||||
|
"a": {ID: "a", Provider: "provider-a"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.SetConfig(oldCfg)
|
||||||
|
|
||||||
|
if ok := w.reloadConfig(); !ok {
|
||||||
|
t.Fatal("expected reloadConfig to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
defer w.clientsMutex.RUnlock()
|
||||||
|
for _, auth := range w.currentAuths {
|
||||||
|
if auth != nil && auth.Provider == "provider-a" {
|
||||||
|
t.Fatal("expected affected provider auth to be filtered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
foundB := false
|
||||||
|
for _, auth := range w.currentAuths {
|
||||||
|
if auth != nil && auth.Provider == "provider-b" {
|
||||||
|
foundB = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundB {
|
||||||
|
t.Fatal("expected unaffected provider auth to remain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartFailsWhenAuthDirMissing(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("auth_dir: "+filepath.Join(tmpDir, "missing-auth")+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
authDir := filepath.Join(tmpDir, "missing-auth")
|
||||||
|
|
||||||
|
w, err := NewWatcher(configPath, authDir, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create watcher: %v", err)
|
||||||
|
}
|
||||||
|
defer w.Stop()
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := w.Start(ctx); err == nil {
|
||||||
|
t.Fatal("expected Start to fail for missing auth dir")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchRuntimeAuthUpdateReturnsFalseWithoutQueue(t *testing.T) {
|
||||||
|
w := &Watcher{}
|
||||||
|
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: &coreauth.Auth{ID: "a"}}); ok {
|
||||||
|
t.Fatal("expected DispatchRuntimeAuthUpdate to return false when no queue configured")
|
||||||
|
}
|
||||||
|
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, Auth: &coreauth.Auth{ID: "a"}}); ok {
|
||||||
|
t.Fatal("expected DispatchRuntimeAuthUpdate delete to return false when no queue configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAuthNil(t *testing.T) {
|
||||||
|
if normalizeAuth(nil) != nil {
|
||||||
|
t.Fatal("expected normalizeAuth(nil) to return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stubStore implements coreauth.Store plus watcher-specific persistence helpers.
// It records call counts and arguments so tests can assert persistence behavior.
type stubStore struct {
	authDir         string   // value returned by AuthDir()
	cfgPersisted    int32    // atomic count of PersistConfig invocations
	authPersisted   int32    // atomic count of PersistAuthFiles invocations
	lastAuthMessage string   // message from the most recent PersistAuthFiles call
	lastAuthPaths   []string // paths from the most recent PersistAuthFiles call
}
func (s *stubStore) List(context.Context) ([]*coreauth.Auth, error) { return nil, nil }
|
||||||
|
func (s *stubStore) Save(context.Context, *coreauth.Auth) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
func (s *stubStore) Delete(context.Context, string) error { return nil }
|
||||||
|
func (s *stubStore) PersistConfig(context.Context) error {
|
||||||
|
atomic.AddInt32(&s.cfgPersisted, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubStore) PersistAuthFiles(_ context.Context, message string, paths ...string) error {
|
||||||
|
atomic.AddInt32(&s.authPersisted, 1)
|
||||||
|
s.lastAuthMessage = message
|
||||||
|
s.lastAuthPaths = paths
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *stubStore) AuthDir() string { return s.authDir }
|
||||||
|
|
||||||
|
func TestNewWatcherDetectsPersisterAndAuthDir(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
store := &stubStore{authDir: tmp}
|
||||||
|
orig := sdkAuth.GetTokenStore()
|
||||||
|
sdkAuth.RegisterTokenStore(store)
|
||||||
|
defer sdkAuth.RegisterTokenStore(orig)
|
||||||
|
|
||||||
|
w, err := NewWatcher("config.yaml", "auth", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewWatcher failed: %v", err)
|
||||||
|
}
|
||||||
|
if w.storePersister == nil {
|
||||||
|
t.Fatal("expected storePersister to be set from token store")
|
||||||
|
}
|
||||||
|
if w.mirroredAuthDir != tmp {
|
||||||
|
t.Fatalf("expected mirroredAuthDir %s, got %s", tmp, w.mirroredAuthDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistConfigAndAuthAsyncInvokePersister(t *testing.T) {
|
||||||
|
w := &Watcher{
|
||||||
|
storePersister: &stubStore{},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.persistConfigAsync()
|
||||||
|
w.persistAuthAsync("msg", " a ", "", "b ")
|
||||||
|
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
store := w.storePersister.(*stubStore)
|
||||||
|
if atomic.LoadInt32(&store.cfgPersisted) != 1 {
|
||||||
|
t.Fatalf("expected PersistConfig to be called once, got %d", store.cfgPersisted)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&store.authPersisted) != 1 {
|
||||||
|
t.Fatalf("expected PersistAuthFiles to be called once, got %d", store.authPersisted)
|
||||||
|
}
|
||||||
|
if store.lastAuthMessage != "msg" {
|
||||||
|
t.Fatalf("unexpected auth message: %s", store.lastAuthMessage)
|
||||||
|
}
|
||||||
|
if len(store.lastAuthPaths) != 2 || store.lastAuthPaths[0] != "a" || store.lastAuthPaths[1] != "b" {
|
||||||
|
t.Fatalf("unexpected filtered paths: %#v", store.lastAuthPaths)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScheduleConfigReloadDebounces(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
authDir := tmp
|
||||||
|
cfgPath := tmp + "/config.yaml"
|
||||||
|
if err := os.WriteFile(cfgPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloads int32
|
||||||
|
w := &Watcher{
|
||||||
|
configPath: cfgPath,
|
||||||
|
authDir: authDir,
|
||||||
|
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
|
||||||
|
}
|
||||||
|
w.SetConfig(&config.Config{AuthDir: authDir})
|
||||||
|
|
||||||
|
w.scheduleConfigReload()
|
||||||
|
w.scheduleConfigReload()
|
||||||
|
|
||||||
|
time.Sleep(400 * time.Millisecond)
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&reloads) != 1 {
|
||||||
|
t.Fatalf("expected single debounced reload, got %d", reloads)
|
||||||
|
}
|
||||||
|
if w.lastConfigHash == "" {
|
||||||
|
t.Fatal("expected lastConfigHash to be set after reload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareAuthUpdatesLockedForceAndDelete(t *testing.T) {
|
||||||
|
w := &Watcher{
|
||||||
|
currentAuths: map[string]*coreauth.Auth{
|
||||||
|
"a": {ID: "a", Provider: "p1"},
|
||||||
|
},
|
||||||
|
authQueue: make(chan AuthUpdate, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, false)
|
||||||
|
if len(updates) != 1 || updates[0].Action != AuthUpdateActionModify || updates[0].ID != "a" {
|
||||||
|
t.Fatalf("unexpected modify updates: %+v", updates)
|
||||||
|
}
|
||||||
|
|
||||||
|
updates = w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, true)
|
||||||
|
if len(updates) != 1 || updates[0].Action != AuthUpdateActionModify {
|
||||||
|
t.Fatalf("expected force modify, got %+v", updates)
|
||||||
|
}
|
||||||
|
|
||||||
|
updates = w.prepareAuthUpdatesLocked([]*coreauth.Auth{}, false)
|
||||||
|
if len(updates) != 1 || updates[0].Action != AuthUpdateActionDelete || updates[0].ID != "a" {
|
||||||
|
t.Fatalf("expected delete for missing auth, got %+v", updates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthEqualIgnoresTemporalFields(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
a := &coreauth.Auth{ID: "x", CreatedAt: now}
|
||||||
|
b := &coreauth.Auth{ID: "x", CreatedAt: now.Add(5 * time.Second)}
|
||||||
|
if !authEqual(a, b) {
|
||||||
|
t.Fatal("expected authEqual to ignore temporal differences")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDispatchLoopExitsWhenQueueNilAndContextCanceled(t *testing.T) {
|
||||||
|
w := &Watcher{
|
||||||
|
dispatchCond: nil,
|
||||||
|
pendingUpdates: map[string]AuthUpdate{"k": {ID: "k"}},
|
||||||
|
pendingOrder: []string{"k"},
|
||||||
|
}
|
||||||
|
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.dispatchLoop(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
w.dispatchMu.Lock()
|
||||||
|
w.dispatchCond.Broadcast()
|
||||||
|
w.dispatchMu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("dispatchLoop did not exit after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadClientsFiltersOAuthProvidersWithoutRescan(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
w := &Watcher{
|
||||||
|
authDir: tmp,
|
||||||
|
config: &config.Config{AuthDir: tmp},
|
||||||
|
currentAuths: map[string]*coreauth.Auth{
|
||||||
|
"a": {ID: "a", Provider: "Match"},
|
||||||
|
"b": {ID: "b", Provider: "other"},
|
||||||
|
},
|
||||||
|
lastAuthHashes: map[string]string{"cached": "hash"},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.reloadClients(false, []string{"match"}, false)
|
||||||
|
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
defer w.clientsMutex.RUnlock()
|
||||||
|
if _, ok := w.currentAuths["a"]; ok {
|
||||||
|
t.Fatal("expected filtered provider to be removed")
|
||||||
|
}
|
||||||
|
if len(w.lastAuthHashes) != 1 {
|
||||||
|
t.Fatalf("expected existing hash cache to be retained, got %d", len(w.lastAuthHashes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScheduleProcessEventsStopsOnContextDone(t *testing.T) {
|
||||||
|
w := &Watcher{
|
||||||
|
watcher: &fsnotify.Watcher{
|
||||||
|
Events: make(chan fsnotify.Event, 1),
|
||||||
|
Errors: make(chan error, 1),
|
||||||
|
},
|
||||||
|
configPath: "config.yaml",
|
||||||
|
authDir: "auth",
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.processEvents(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("processEvents did not exit on context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// hexString returns the lowercase hexadecimal encoding of data.
// An empty or nil slice yields the empty string.
func hexString(data []byte) string {
	// The %x verb already emits lowercase hex digits, so the previous
	// strings.ToLower wrapper was redundant and has been removed.
	return fmt.Sprintf("%x", data)
}
||||||
|
|||||||
Reference in New Issue
Block a user