fix: address PR review feedback

- Fix SSRF: validate API endpoint host against allowlist before use
- Limit /models response body to 2MB to prevent memory exhaustion (DoS)
- Use MakeAuthenticatedRequest for consistent headers across API calls
- Trim trailing slash on API endpoint to prevent double-slash URLs
- Use ListModelsWithGitHubToken to simplify token exchange + listing
- Deduplicate model IDs to prevent incorrect registry reference counting
- Remove dead capabilities enrichment code block
- Remove unused ModelExtra field with misleading json:"-" tag
- Extract magic numbers to named constants (defaultCopilotContextLength)
- Remove redundant hyphenID == id check (already filtered by Contains)
- Use defer cancel() for context timeout in service.go
This commit is contained in:
yx-bot7
2026-03-04 15:43:51 +08:00
parent d8e3d4e2b6
commit 7d6660d181
4 changed files with 49 additions and 37 deletions

View File

@@ -8,6 +8,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -224,14 +226,13 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
type CopilotModelEntry struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Capabilities map[string]any `json:"capabilities,omitempty"`
ModelExtra map[string]any `json:"-"` // catch-all for unknown fields
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Capabilities map[string]any `json:"capabilities,omitempty"`
}
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
@@ -240,6 +241,17 @@ type CopilotModelsResponse struct {
Object string `json:"object"`
}
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
const maxModelsResponseSize = 2 * 1024 * 1024
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
var allowedCopilotAPIHosts = map[string]bool{
"api.githubcopilot.com": true,
"api.individual.githubcopilot.com": true,
"api.business.githubcopilot.com": true,
"copilot-proxy.githubusercontent.com": true,
}
// ListModels fetches the list of available models from the Copilot API.
// It requires a valid Copilot API token (not the GitHub access token).
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
@@ -247,23 +259,22 @@ func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken)
return nil, fmt.Errorf("copilot: api token is required for listing models")
}
// Build models URL, validating the endpoint host to prevent SSRF.
modelsURL := copilotAPIEndpoint + "/models"
if apiToken.Endpoints.API != "" {
modelsURL = apiToken.Endpoints.API + "/models"
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
parsed, err := url.Parse(ep)
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
modelsURL = ep + "/models"
} else {
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
if err != nil {
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiToken.Token)
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", copilotUserAgent)
req.Header.Set("Editor-Version", copilotEditorVersion)
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
req.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("copilot: models request failed: %w", err)
@@ -274,7 +285,9 @@ func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
// Limit response body to prevent memory exhaustion.
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
bodyBytes, err := io.ReadAll(limitedReader)
if err != nil {
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
}