Add GitLab Duo OAuth and PAT support

This commit is contained in:
LuxVTZ
2026-03-10 17:52:35 +04:00
parent 046865461e
commit bb28cd26ad
12 changed files with 2062 additions and 1 deletions

View File

@@ -79,6 +79,8 @@ func main() {
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var gitlabLogin bool
var gitlabTokenLogin bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
@@ -111,6 +113,8 @@ func main() {
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
@@ -527,6 +531,10 @@ func main() {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if gitlabLogin {
cmd.DoGitLabLogin(cfg, options)
} else if gitlabTokenLogin {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if kiroLogin {

View File

@@ -0,0 +1,492 @@
package gitlab
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
DefaultBaseURL = "https://gitlab.com"
DefaultCallbackPort = 17171
defaultOAuthScope = "api read_user"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
type OAuthResult struct {
Code string
State string
Error string
}
type OAuthServer struct {
server *http.Server
port int
resultChan chan *OAuthResult
errorChan chan error
mu sync.Mutex
running bool
}
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
CreatedAt int64 `json:"created_at"`
ExpiresIn int `json:"expires_in"`
}
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
PublicEmail string `json:"public_email"`
}
type PersonalAccessTokenSelf struct {
ID int64 `json:"id"`
Name string `json:"name"`
Scopes []string `json:"scopes"`
UserID int64 `json:"user_id"`
}
type ModelDetails struct {
ModelProvider string `json:"model_provider"`
ModelName string `json:"model_name"`
}
type DirectAccessResponse struct {
BaseURL string `json:"base_url"`
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
Headers map[string]string `json:"headers"`
ModelDetails *ModelDetails `json:"model_details,omitempty"`
}
type DiscoveredModel struct {
ModelProvider string
ModelName string
}
type AuthClient struct {
httpClient *http.Client
}
func NewAuthClient(cfg *config.Config) *AuthClient {
client := &http.Client{}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &AuthClient{httpClient: client}
}
func NormalizeBaseURL(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return DefaultBaseURL
}
if !strings.Contains(value, "://") {
value = "https://" + value
}
value = strings.TrimRight(value, "/")
return value
}
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
if token == nil {
return time.Time{}
}
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
}
if token.ExpiresIn > 0 {
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
}
return time.Time{}
}
func GeneratePKCECodes() (*PKCECodes, error) {
verifierBytes := make([]byte, 32)
if _, err := rand.Read(verifierBytes); err != nil {
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
}
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
sum := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
return &PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}
func NewOAuthServer(port int) *OAuthServer {
return &OAuthServer{
port: port,
resultChan: make(chan *OAuthResult, 1),
errorChan: make(chan error, 1),
}
}
func (s *OAuthServer) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.running {
return fmt.Errorf("gitlab oauth server already running")
}
if !s.isPortAvailable() {
return fmt.Errorf("port %d is already in use", s.port)
}
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", s.handleCallback)
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
s.running = true
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
s.errorChan <- err
}
}()
time.Sleep(100 * time.Millisecond)
return nil
}
func (s *OAuthServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.running || s.server == nil {
return nil
}
defer func() {
s.running = false
s.server = nil
}()
return s.server.Shutdown(ctx)
}
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
select {
case result := <-s.resultChan:
return result, nil
case err := <-s.errorChan:
return nil, err
case <-time.After(timeout):
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
}
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
query := r.URL.Query()
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
s.sendResult(&OAuthResult{Error: errParam})
http.Error(w, errParam, http.StatusBadRequest)
return
}
code := strings.TrimSpace(query.Get("code"))
state := strings.TrimSpace(query.Get("state"))
if code == "" || state == "" {
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
http.Error(w, "missing code or state", http.StatusBadRequest)
return
}
s.sendResult(&OAuthResult{Code: code, State: state})
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
}
func (s *OAuthServer) sendResult(result *OAuthResult) {
select {
case s.resultChan <- result:
default:
log.Debug("gitlab oauth result channel full, dropping callback result")
}
}
func (s *OAuthServer) isPortAvailable() bool {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
if err != nil {
return false
}
_ = listener.Close()
return true
}
func RedirectURL(port int) string {
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
}
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
if pkce == nil {
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
}
if strings.TrimSpace(clientID) == "" {
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
}
baseURL = NormalizeBaseURL(baseURL)
params := url.Values{
"client_id": {strings.TrimSpace(clientID)},
"response_type": {"code"},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"scope": {defaultOAuthScope},
"state": {strings.TrimSpace(state)},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
}
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
}
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {strings.TrimSpace(clientID)},
"code": {strings.TrimSpace(code)},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {strings.TrimSpace(codeVerifier)},
}
if secret := strings.TrimSpace(clientSecret); secret != "" {
form.Set("client_secret", secret)
}
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
form := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {strings.TrimSpace(refreshToken)},
}
if clientID = strings.TrimSpace(clientID); clientID != "" {
form.Set("client_id", clientID)
}
if secret := strings.TrimSpace(clientSecret); secret != "" {
form.Set("client_secret", secret)
}
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("gitlab token request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab token request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var token TokenResponse
if err := json.Unmarshal(body, &token); err != nil {
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
}
return &token, nil
}
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
if err != nil {
return nil, fmt.Errorf("gitlab user request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab user request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var user User
if err := json.Unmarshal(body, &user); err != nil {
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
}
return &user, nil
}
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var pat PersonalAccessTokenSelf
if err := json.Unmarshal(body, &pat); err != nil {
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
}
return &pat, nil
}
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
if err != nil {
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var direct DirectAccessResponse
if err := json.Unmarshal(body, &direct); err != nil {
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
}
if direct.Headers == nil {
direct.Headers = make(map[string]string)
}
return &direct, nil
}
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
if len(metadata) == 0 {
return nil
}
models := make([]DiscoveredModel, 0, 4)
seen := make(map[string]struct{})
appendModel := func(provider, name string) {
provider = strings.TrimSpace(provider)
name = strings.TrimSpace(name)
if name == "" {
return
}
key := strings.ToLower(provider + "\x00" + name)
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
models = append(models, DiscoveredModel{
ModelProvider: provider,
ModelName: name,
})
}
if raw, ok := metadata["model_details"]; ok {
appendDiscoveredModels(raw, appendModel)
}
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
for _, key := range []string{"models", "supported_models", "discovered_models"} {
if raw, ok := metadata[key]; ok {
appendDiscoveredModels(raw, appendModel)
}
}
return models
}
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
switch typed := raw.(type) {
case map[string]any:
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
if nested, ok := typed["models"]; ok {
appendDiscoveredModels(nested, appendModel)
}
case []any:
for _, item := range typed {
appendDiscoveredModels(item, appendModel)
}
case []string:
for _, item := range typed {
appendModel("", item)
}
case string:
appendModel("", typed)
}
}
func stringValue(raw any) string {
switch typed := raw.(type) {
case string:
return strings.TrimSpace(typed)
case fmt.Stringer:
return strings.TrimSpace(typed.String())
case json.Number:
return typed.String()
case int:
return strconv.Itoa(typed)
case int64:
return strconv.FormatInt(typed, 10)
case float64:
return strconv.FormatInt(int64(typed), 10)
default:
return ""
}
}

View File

@@ -0,0 +1,72 @@
package gitlab
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func TestNormalizeBaseURL(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{name: "default", in: "", want: DefaultBaseURL},
{name: "plain host", in: "gitlab.example.com", want: "https://gitlab.example.com"},
{name: "trim trailing slash", in: "https://gitlab.example.com/", want: "https://gitlab.example.com"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeBaseURL(tc.in); got != tc.want {
t.Fatalf("NormalizeBaseURL(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
func TestFetchDirectAccess_ParsesModelDetails(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if got := r.Header.Get("Authorization"); got != "Bearer pat-123" {
t.Fatalf("expected Authorization header, got %q", got)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"base_url":"https://gateway.gitlab.example.com/v1",
"token":"duo-gateway-token",
"expires_at":2000000000,
"headers":{
"X-Gitlab-Realm":"saas",
"X-Gitlab-Host-Name":"gitlab.example.com"
},
"model_details":{
"model_provider":"anthropic",
"model_name":"claude-sonnet-4-5"
}
}`))
}))
defer server.Close()
client := &AuthClient{httpClient: server.Client()}
direct, err := client.FetchDirectAccess(context.Background(), server.URL, "pat-123")
if err != nil {
t.Fatalf("FetchDirectAccess returned error: %v", err)
}
if direct.BaseURL != "https://gateway.gitlab.example.com/v1" {
t.Fatalf("unexpected base_url %q", direct.BaseURL)
}
if direct.Token != "duo-gateway-token" {
t.Fatalf("unexpected token %q", direct.Token)
}
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
t.Fatalf("unexpected model details: %+v", direct.ModelDetails)
}
if direct.Headers["X-Gitlab-Realm"] != "saas" {
t.Fatalf("expected X-Gitlab-Realm header, got %+v", direct.Headers)
}
}

View File

@@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,69 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
"login_mode": "oauth",
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
if err != nil {
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("GitLab Duo authentication successful!")
}
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
Metadata: map[string]string{
"login_mode": "pat",
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
if err != nil {
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("GitLab Duo PAT authentication successful!")
}

View File

@@ -0,0 +1,746 @@
package executor
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
const (
gitLabProviderKey = "gitlab"
gitLabAuthMethodOAuth = "oauth"
gitLabAuthMethodPAT = "pat"
gitLabChatEndpoint = "/api/v4/chat/completions"
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
)
type GitLabExecutor struct {
cfg *config.Config
}
type gitLabPrompt struct {
Instruction string
FileName string
ContentAboveCursor string
ChatContext []map[string]any
CodeSuggestionContext []map[string]any
}
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
return &GitLabExecutor{cfg: cfg}
}
func (e *GitLabExecutor) Identifier() string { return gitLabProviderKey }
func (e *GitLabExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
translated, err := e.translateToOpenAI(req, opts)
if err != nil {
return resp, err
}
prompt := buildGitLabPrompt(translated)
if strings.TrimSpace(prompt.Instruction) == "" && strings.TrimSpace(prompt.ContentAboveCursor) == "" {
err = statusErr{code: http.StatusBadRequest, msg: "gitlab duo executor: request has no usable text content"}
return resp, err
}
text, err := e.invoke(ctx, auth, prompt)
if err != nil {
return resp, err
}
responseModel := gitLabResolvedModel(auth, req.Model)
openAIResponse := buildGitLabOpenAIResponse(responseModel, text, translated)
reporter.publish(ctx, parseOpenAIUsage(openAIResponse))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(
ctx,
sdktranslator.FromString("openai"),
opts.SourceFormat,
req.Model,
opts.OriginalRequest,
translated,
openAIResponse,
&param,
)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: make(http.Header)}, nil
}
func (e *GitLabExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
translated, err := e.translateToOpenAI(req, opts)
if err != nil {
return nil, err
}
prompt := buildGitLabPrompt(translated)
if strings.TrimSpace(prompt.Instruction) == "" && strings.TrimSpace(prompt.ContentAboveCursor) == "" {
return nil, statusErr{code: http.StatusBadRequest, msg: "gitlab duo executor: request has no usable text content"}
}
text, err := e.invoke(ctx, auth, prompt)
if err != nil {
return nil, err
}
responseModel := gitLabResolvedModel(auth, req.Model)
openAIResponse := buildGitLabOpenAIResponse(responseModel, text, translated)
reporter.publish(ctx, parseOpenAIUsage(openAIResponse))
reporter.ensurePublished(ctx)
out := make(chan cliproxyexecutor.StreamChunk, 8)
go func() {
defer close(out)
var param any
lines := buildGitLabOpenAIStream(responseModel, text)
for _, line := range lines {
chunks := sdktranslator.TranslateStream(
ctx,
sdktranslator.FromString("openai"),
opts.SourceFormat,
req.Model,
opts.OriginalRequest,
translated,
[]byte(line),
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: make(http.Header), Chunks: out}, nil
}
func (e *GitLabExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("gitlab duo executor: auth is nil")
}
baseURL := gitLabBaseURL(auth)
token := gitLabPrimaryToken(auth)
if baseURL == "" || token == "" {
return nil, fmt.Errorf("gitlab duo executor: missing base URL or token")
}
client := gitlab.NewAuthClient(e.cfg)
method := strings.ToLower(strings.TrimSpace(gitLabMetadataString(auth.Metadata, "auth_method", "auth_kind")))
if method == "" {
method = gitLabAuthMethodOAuth
}
if method == gitLabAuthMethodOAuth {
if refreshed, refreshErr := e.refreshOAuthToken(ctx, client, auth, baseURL); refreshErr == nil && refreshed != nil {
token = refreshed.AccessToken
applyGitLabTokenMetadata(auth.Metadata, refreshed)
}
}
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
if err != nil && method == gitLabAuthMethodOAuth {
if refreshed, refreshErr := e.refreshOAuthToken(ctx, client, auth, baseURL); refreshErr == nil && refreshed != nil {
token = refreshed.AccessToken
applyGitLabTokenMetadata(auth.Metadata, refreshed)
direct, err = client.FetchDirectAccess(ctx, baseURL, token)
}
}
if err != nil {
return nil, err
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["type"] = gitLabProviderKey
auth.Metadata["auth_method"] = method
auth.Metadata["auth_kind"] = gitLabAuthKind(method)
auth.Metadata["base_url"] = gitlab.NormalizeBaseURL(baseURL)
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
mergeGitLabDirectAccessMetadata(auth.Metadata, direct)
return auth, nil
}
func (e *GitLabExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
translated := sdktranslator.TranslateRequest(opts.SourceFormat, sdktranslator.FromString("openai"), baseModel, req.Payload, false)
enc, err := tokenizerForModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("gitlab duo executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, translated)
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: buildOpenAIUsageJSON(count), Headers: make(http.Header)}, nil
}
func (e *GitLabExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("gitlab duo executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if token := gitLabPrimaryToken(auth); token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
}
return newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(httpReq)
}
func (e *GitLabExecutor) translateToOpenAI(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
return sdktranslator.TranslateRequest(opts.SourceFormat, sdktranslator.FromString("openai"), baseModel, req.Payload, opts.Stream), nil
}
func (e *GitLabExecutor) invoke(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) {
if text, err := e.requestChat(ctx, auth, prompt); err == nil {
return text, nil
} else if !shouldFallbackToCodeSuggestions(err) {
return "", err
}
return e.requestCodeSuggestions(ctx, auth, prompt)
}
func (e *GitLabExecutor) requestChat(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) {
body := map[string]any{
"content": prompt.Instruction,
"with_clean_history": true,
}
if len(prompt.ChatContext) > 0 {
body["additional_context"] = prompt.ChatContext
}
return e.doJSONTextRequest(ctx, auth, gitLabChatEndpoint, body)
}
func (e *GitLabExecutor) requestCodeSuggestions(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) {
contentAbove := strings.TrimSpace(prompt.ContentAboveCursor)
if contentAbove == "" {
contentAbove = prompt.Instruction
}
body := map[string]any{
"current_file": map[string]any{
"file_name": prompt.FileName,
"content_above_cursor": contentAbove,
"content_below_cursor": "",
},
"intent": "generation",
"generation_type": "small_file",
"user_instruction": prompt.Instruction,
"stream": false,
}
if len(prompt.CodeSuggestionContext) > 0 {
body["context"] = prompt.CodeSuggestionContext
}
return e.doJSONTextRequest(ctx, auth, gitLabCodeSuggestionsEndpoint, body)
}
func (e *GitLabExecutor) doJSONTextRequest(ctx context.Context, auth *cliproxyauth.Auth, endpoint string, payload map[string]any) (string, error) {
token := gitLabPrimaryToken(auth)
baseURL := gitLabBaseURL(auth)
if token == "" || baseURL == "" {
return "", statusErr{code: http.StatusUnauthorized, msg: "gitlab duo executor: missing credentials"}
}
body, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("gitlab duo executor: marshal request failed: %w", err)
}
url := strings.TrimRight(baseURL, "/") + endpoint
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "CLIProxyAPI/GitLab-Duo")
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: req.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
resp, err := httpClient.Do(req)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return "", err
}
defer func() { _ = resp.Body.Close() }()
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
respBody, err := io.ReadAll(resp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return "", err
}
appendAPIResponseChunk(ctx, e.cfg, respBody)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", statusErr{code: resp.StatusCode, msg: strings.TrimSpace(string(respBody))}
}
text, err := parseGitLabTextResponse(endpoint, respBody)
if err != nil {
return "", err
}
return strings.TrimSpace(text), nil
}
func (e *GitLabExecutor) refreshOAuthToken(ctx context.Context, client *gitlab.AuthClient, auth *cliproxyauth.Auth, baseURL string) (*gitlab.TokenResponse, error) {
if auth == nil {
return nil, fmt.Errorf("gitlab duo executor: auth is nil")
}
refreshToken := gitLabMetadataString(auth.Metadata, "refresh_token")
if refreshToken == "" {
return nil, fmt.Errorf("gitlab duo executor: refresh token missing")
}
if !gitLabOAuthTokenNeedsRefresh(auth.Metadata) && gitLabPrimaryToken(auth) != "" {
return nil, nil
}
return client.RefreshTokens(
ctx,
baseURL,
gitLabMetadataString(auth.Metadata, "oauth_client_id"),
gitLabMetadataString(auth.Metadata, "oauth_client_secret"),
refreshToken,
)
}
func buildGitLabPrompt(payload []byte) gitLabPrompt {
root := gjson.ParseBytes(payload)
prompt := gitLabPrompt{
FileName: "prompt.txt",
}
msgs := root.Get("messages")
if msgs.Exists() && msgs.IsArray() {
systemIndex := 0
contextIndex := 0
transcript := make([]string, 0, len(msgs.Array()))
var lastUser string
msgs.ForEach(func(_, msg gjson.Result) bool {
role := strings.TrimSpace(msg.Get("role").String())
if role == "" {
role = "user"
}
content := openAIContentText(msg.Get("content"))
if content == "" {
return true
}
switch role {
case "system":
systemIndex++
prompt.ChatContext = append(prompt.ChatContext, map[string]any{
"category": "snippet",
"id": fmt.Sprintf("system-%d", systemIndex),
"content": content,
})
case "user":
lastUser = content
contextIndex++
prompt.CodeSuggestionContext = append(prompt.CodeSuggestionContext, map[string]any{
"type": "snippet",
"name": fmt.Sprintf("user-%d", contextIndex),
"content": content,
})
transcript = append(transcript, "User:\n"+content)
default:
contextIndex++
prompt.ChatContext = append(prompt.ChatContext, map[string]any{
"category": "snippet",
"id": fmt.Sprintf("%s-%d", role, contextIndex),
"content": content,
})
prompt.CodeSuggestionContext = append(prompt.CodeSuggestionContext, map[string]any{
"type": "snippet",
"name": fmt.Sprintf("%s-%d", role, contextIndex),
"content": content,
})
transcript = append(transcript, strings.Title(role)+":\n"+content)
}
return true
})
prompt.Instruction = strings.TrimSpace(lastUser)
prompt.ContentAboveCursor = truncateGitLabPrompt(strings.Join(transcript, "\n\n"), 12000)
}
if prompt.Instruction == "" {
for _, key := range []string{"prompt", "input", "instructions"} {
if value := strings.TrimSpace(root.Get(key).String()); value != "" {
prompt.Instruction = value
break
}
}
}
if prompt.ContentAboveCursor == "" {
prompt.ContentAboveCursor = prompt.Instruction
}
prompt.Instruction = truncateGitLabPrompt(prompt.Instruction, 4000)
prompt.ContentAboveCursor = truncateGitLabPrompt(prompt.ContentAboveCursor, 12000)
return prompt
}
func openAIContentText(content gjson.Result) string {
segments := make([]string, 0, 8)
collectOpenAIContent(content, &segments)
return strings.TrimSpace(strings.Join(segments, "\n"))
}
func truncateGitLabPrompt(value string, limit int) string {
value = strings.TrimSpace(value)
if limit <= 0 || len(value) <= limit {
return value
}
return strings.TrimSpace(value[:limit])
}
func parseGitLabTextResponse(endpoint string, body []byte) (string, error) {
if endpoint == gitLabChatEndpoint {
var text string
if err := json.Unmarshal(body, &text); err == nil {
return text, nil
}
if value := strings.TrimSpace(gjson.GetBytes(body, "response").String()); value != "" {
return value, nil
}
}
if value := strings.TrimSpace(gjson.GetBytes(body, "choices.0.text").String()); value != "" {
return value, nil
}
if value := strings.TrimSpace(gjson.GetBytes(body, "response").String()); value != "" {
return value, nil
}
var plain string
if err := json.Unmarshal(body, &plain); err == nil && strings.TrimSpace(plain) != "" {
return plain, nil
}
return "", fmt.Errorf("gitlab duo executor: upstream returned no text payload")
}
func shouldFallbackToCodeSuggestions(err error) bool {
if err == nil {
return false
}
status, ok := err.(interface{ StatusCode() int })
if !ok {
return false
}
switch status.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented:
return true
default:
return false
}
}
func buildGitLabOpenAIResponse(model, text string, translatedReq []byte) []byte {
promptTokens, completionTokens := gitLabUsage(model, translatedReq, text)
payload := map[string]any{
"id": fmt.Sprintf("gitlab-%d", time.Now().UnixNano()),
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": text,
},
"finish_reason": "stop",
}},
"usage": map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
},
}
raw, _ := json.Marshal(payload)
return raw
}
func buildGitLabOpenAIStream(model, text string) []string {
now := time.Now().Unix()
id := fmt.Sprintf("gitlab-%d", time.Now().UnixNano())
chunks := []map[string]any{
{
"id": id,
"object": "chat.completion.chunk",
"created": now,
"model": model,
"choices": []map[string]any{{
"index": 0,
"delta": map[string]any{"role": "assistant"},
}},
},
{
"id": id,
"object": "chat.completion.chunk",
"created": now,
"model": model,
"choices": []map[string]any{{
"index": 0,
"delta": map[string]any{"content": text},
}},
},
{
"id": id,
"object": "chat.completion.chunk",
"created": now,
"model": model,
"choices": []map[string]any{{
"index": 0,
"delta": map[string]any{},
"finish_reason": "stop",
}},
},
}
lines := make([]string, 0, len(chunks)+1)
for _, chunk := range chunks {
raw, _ := json.Marshal(chunk)
lines = append(lines, "data: "+string(raw))
}
lines = append(lines, "data: [DONE]")
return lines
}
func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64) {
enc, err := tokenizerForModel(model)
if err != nil {
return 0, 0
}
promptTokens, err := countOpenAIChatTokens(enc, translatedReq)
if err != nil {
promptTokens = 0
}
completionCount, err := enc.Count(strings.TrimSpace(text))
if err != nil {
return promptTokens, 0
}
return promptTokens, int64(completionCount)
}
func gitLabPrimaryToken(auth *cliproxyauth.Auth) string {
if auth == nil || auth.Metadata == nil {
return ""
}
if token := gitLabMetadataString(auth.Metadata, "access_token"); token != "" {
return token
}
return gitLabMetadataString(auth.Metadata, "personal_access_token")
}
func gitLabBaseURL(auth *cliproxyauth.Auth) string {
if auth == nil || auth.Metadata == nil {
return ""
}
return gitlab.NormalizeBaseURL(gitLabMetadataString(auth.Metadata, "base_url"))
}
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
return requested
}
if auth != nil && auth.Metadata != nil {
for _, model := range gitlab.ExtractDiscoveredModels(auth.Metadata) {
if name := strings.TrimSpace(model.ModelName); name != "" {
return name
}
}
}
if requested != "" {
return requested
}
return "gitlab-duo"
}
func gitLabMetadataString(metadata map[string]any, keys ...string) string {
for _, key := range keys {
if metadata == nil {
return ""
}
if value, ok := metadata[key].(string); ok {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
}
return ""
}
func gitLabOAuthTokenNeedsRefresh(metadata map[string]any) bool {
expiry := gitLabMetadataString(metadata, "oauth_expires_at")
if expiry == "" {
return true
}
ts, err := time.Parse(time.RFC3339, expiry)
if err != nil {
return true
}
return time.Until(ts) <= 5*time.Minute
}
func applyGitLabTokenMetadata(metadata map[string]any, tokenResp *gitlab.TokenResponse) {
if metadata == nil || tokenResp == nil {
return
}
if accessToken := strings.TrimSpace(tokenResp.AccessToken); accessToken != "" {
metadata["access_token"] = accessToken
}
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
metadata["refresh_token"] = refreshToken
}
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
metadata["token_type"] = tokenType
}
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
metadata["scope"] = scope
}
if expiry := gitlab.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
}
}
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlab.DirectAccessResponse) {
if metadata == nil || direct == nil {
return
}
if base := strings.TrimSpace(direct.BaseURL); base != "" {
metadata["duo_gateway_base_url"] = base
}
if token := strings.TrimSpace(direct.Token); token != "" {
metadata["duo_gateway_token"] = token
}
if direct.ExpiresAt > 0 {
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
if ttl := expiry.Sub(time.Now().UTC()); ttl > 0 {
interval := int(ttl.Seconds()) / 2
switch {
case interval < 60:
interval = 60
case interval > 240:
interval = 240
}
metadata["refresh_interval_seconds"] = interval
}
}
if len(direct.Headers) > 0 {
headers := make(map[string]string, len(direct.Headers))
for key, value := range direct.Headers {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
headers[key] = value
}
if len(headers) > 0 {
metadata["duo_gateway_headers"] = headers
}
}
if direct.ModelDetails != nil {
modelDetails := map[string]any{}
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
modelDetails["model_provider"] = provider
metadata["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
modelDetails["model_name"] = model
metadata["model_name"] = model
}
if len(modelDetails) > 0 {
metadata["model_details"] = modelDetails
}
}
}
func gitLabAuthKind(method string) string {
switch strings.ToLower(strings.TrimSpace(method)) {
case gitLabAuthMethodPAT:
return "personal_access_token"
default:
return "oauth"
}
}
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
models := make([]*registry.ModelInfo, 0, 4)
seen := make(map[string]struct{}, 4)
addModel := func(id, displayName, provider string) {
id = strings.TrimSpace(id)
if id == "" {
return
}
key := strings.ToLower(id)
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
models = append(models, &registry.ModelInfo{
ID: id,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "gitlab",
Type: "gitlab",
DisplayName: displayName,
Description: provider,
UserDefined: true,
})
}
addModel("gitlab-duo", "GitLab Duo", "gitlab")
if auth == nil {
return models
}
for _, model := range gitlab.ExtractDiscoveredModels(auth.Metadata) {
name := strings.TrimSpace(model.ModelName)
if name == "" {
continue
}
displayName := "GitLab Duo"
if provider := strings.TrimSpace(model.ModelProvider); provider != "" {
displayName = fmt.Sprintf("GitLab Duo (%s)", provider)
}
addModel(name, displayName, strings.TrimSpace(model.ModelProvider))
}
return models
}

View File

@@ -0,0 +1,124 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
func TestGitLabExecutorRefresh_WithPATStoresGatewayMetadata(t *testing.T) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
t.Fatalf("unexpected path %s", r.URL.Path)
}
if got := r.Header.Get("Authorization"); got != "Bearer pat-123" {
t.Fatalf("unexpected Authorization header %q", got)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"base_url":"` + server.URL + `",
"token":"gateway-token",
"expires_at":2000000000,
"headers":{"X-Gitlab-Realm":"saas"},
"model_details":{"model_provider":"mistral","model_name":"codestral-2501"}
}`))
}))
defer server.Close()
exec := NewGitLabExecutor(nil)
auth := &cliproxyauth.Auth{
ID: "gitlab-pat.json",
Provider: "gitlab",
Metadata: map[string]any{
"type": "gitlab",
"auth_method": "pat",
"base_url": server.URL,
"personal_access_token": "pat-123",
},
}
updated, err := exec.Refresh(context.Background(), auth)
if err != nil {
t.Fatalf("Refresh returned error: %v", err)
}
if got := metadataString(updated.Metadata, "duo_gateway_token"); got != "gateway-token" {
t.Fatalf("unexpected gateway token %q", got)
}
if got := gitLabModelName(updated); got != "codestral-2501" {
t.Fatalf("unexpected model name %q", got)
}
headers := gitLabHeaders(updated)
if headers["X-Gitlab-Realm"] != "saas" {
t.Fatalf("unexpected gateway headers %+v", headers)
}
}
func TestGitLabExecutorExecute_UsesGatewayHeadersAndResolvedModel(t *testing.T) {
var receivedAuth string
var receivedRealm string
var receivedModel string
gateway := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
receivedAuth = r.Header.Get("Authorization")
receivedRealm = r.Header.Get("X-Gitlab-Realm")
receivedModel = findJSONField(string(body), `"model":"`, `"`)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"ok","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}]}`))
}))
defer gateway.Close()
exec := NewGitLabExecutor(nil)
auth := &cliproxyauth.Auth{
ID: "gitlab-oauth.json",
Provider: "gitlab",
Metadata: map[string]any{
"type": "gitlab",
"auth_method": "oauth",
"duo_gateway_base_url": gateway.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]any{"X-Gitlab-Realm": "saas"},
"model_details": map[string]any{"model_name": "codestral-2501", "model_provider": "mistral"},
},
}
resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("Execute returned error: %v", err)
}
if len(resp.Payload) == 0 {
t.Fatal("expected non-empty payload")
}
if receivedAuth != "Bearer gateway-token" {
t.Fatalf("unexpected Authorization header %q", receivedAuth)
}
if receivedRealm != "saas" {
t.Fatalf("unexpected X-Gitlab-Realm header %q", receivedRealm)
}
if receivedModel != "codestral-2501" {
t.Fatalf("unexpected resolved model %q", receivedModel)
}
}
func findJSONField(body, prefix, suffix string) string {
start := strings.Index(body, prefix)
if start < 0 {
return ""
}
start += len(prefix)
end := strings.Index(body[start:], suffix)
if end < 0 {
return ""
}
return body[start : start+end]
}

462
sdk/auth/gitlab.go Normal file
View File

@@ -0,0 +1,462 @@
package auth
import (
"context"
"fmt"
"strings"
"time"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
const (
gitLabLoginModeMetadataKey = "login_mode"
gitLabLoginModeOAuth = "oauth"
gitLabLoginModePAT = "pat"
gitLabBaseURLMetadataKey = "base_url"
gitLabOAuthClientIDMetadataKey = "oauth_client_id"
gitLabOAuthClientSecretMetadataKey = "oauth_client_secret"
gitLabPersonalAccessTokenMetadataKey = "personal_access_token"
)
var gitLabRefreshLead = 5 * time.Minute
type GitLabAuthenticator struct {
CallbackPort int
}
func NewGitLabAuthenticator() *GitLabAuthenticator {
return &GitLabAuthenticator{CallbackPort: gitlabauth.DefaultCallbackPort}
}
func (a *GitLabAuthenticator) Provider() string {
return "gitlab"
}
func (a *GitLabAuthenticator) RefreshLead() *time.Duration {
return &gitLabRefreshLead
}
func (a *GitLabAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
switch strings.ToLower(strings.TrimSpace(opts.Metadata[gitLabLoginModeMetadataKey])) {
case "", gitLabLoginModeOAuth:
return a.loginOAuth(ctx, cfg, opts)
case gitLabLoginModePAT:
return a.loginPAT(ctx, cfg, opts)
default:
return nil, fmt.Errorf("gitlab auth: unsupported login mode %q", opts.Metadata[gitLabLoginModeMetadataKey])
}
}
func (a *GitLabAuthenticator) loginOAuth(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
client := gitlabauth.NewAuthClient(cfg)
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
clientID, err := a.requireInput(opts, gitLabOAuthClientIDMetadataKey, "Enter GitLab OAuth application client ID: ")
if err != nil {
return nil, err
}
clientSecret, err := a.optionalInput(opts, gitLabOAuthClientSecretMetadataKey, "Enter GitLab OAuth application client secret (press Enter for public PKCE app): ")
if err != nil {
return nil, err
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
redirectURI := gitlabauth.RedirectURL(callbackPort)
pkceCodes, err := gitlabauth.GeneratePKCECodes()
if err != nil {
return nil, err
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("gitlab state generation failed: %w", err)
}
oauthServer := gitlabauth.NewOAuthServer(callbackPort)
if err := oauthServer.Start(); err != nil {
return nil, err
}
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
log.Warnf("gitlab oauth server stop error: %v", stopErr)
}
}()
authURL, err := client.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
if err != nil {
return nil, err
}
if !opts.NoBrowser {
fmt.Println("Opening browser for GitLab Duo authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for GitLab OAuth callback...")
callbackCh := make(chan *gitlabauth.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
go func() {
result, waitErr := oauthServer.WaitForCallback(5 * time.Minute)
if waitErr != nil {
callbackErrCh <- waitErr
return
}
callbackCh <- result
}()
var result *gitlabauth.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
input, promptErr := opts.Prompt("Paste the GitLab callback URL (or press Enter to keep waiting): ")
if promptErr != nil {
return nil, promptErr
}
parsed, parseErr := misc.ParseOAuthCallback(input)
if parseErr != nil {
return nil, parseErr
}
if parsed == nil {
continue
}
result = &gitlabauth.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
}
if result.Error != "" {
return nil, fmt.Errorf("gitlab oauth returned error: %s", result.Error)
}
if result.State != state {
return nil, fmt.Errorf("gitlab auth: state mismatch")
}
tokenResp, err := client.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, result.Code, pkceCodes.CodeVerifier)
if err != nil {
return nil, err
}
accessToken := strings.TrimSpace(tokenResp.AccessToken)
if accessToken == "" {
return nil, fmt.Errorf("gitlab auth: missing access token")
}
user, err := client.GetCurrentUser(ctx, baseURL, accessToken)
if err != nil {
return nil, err
}
direct, err := client.FetchDirectAccess(ctx, baseURL, accessToken)
if err != nil {
return nil, err
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata[gitLabOAuthClientIDMetadataKey] = clientID
if strings.TrimSpace(clientSecret) != "" {
metadata[gitLabOAuthClientSecretMetadataKey] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
fmt.Println("GitLab Duo authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: identifier,
Metadata: metadata,
}, nil
}
func (a *GitLabAuthenticator) loginPAT(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
client := gitlabauth.NewAuthClient(cfg)
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
token, err := a.requireInput(opts, gitLabPersonalAccessTokenMetadataKey, "Enter GitLab personal access token: ")
if err != nil {
return nil, err
}
user, err := client.GetCurrentUser(ctx, baseURL, token)
if err != nil {
return nil, err
}
_, err = client.GetPersonalAccessTokenSelf(ctx, baseURL, token)
if err != nil {
return nil, err
}
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
if err != nil {
return nil, err
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
metadata["auth_kind"] = "personal_access_token"
metadata[gitLabPersonalAccessTokenMetadataKey] = strings.TrimSpace(token)
metadata["token_preview"] = maskGitLabToken(token)
metadata["username"] = strings.TrimSpace(user.Username)
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
fmt.Println("GitLab Duo PAT authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: identifier + " (PAT)",
Metadata: metadata,
}, nil
}
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
metadata := map[string]any{
"type": "gitlab",
"auth_method": strings.TrimSpace(mode),
gitLabBaseURLMetadataKey: gitlabauth.NormalizeBaseURL(baseURL),
"last_refresh": time.Now().UTC().Format(time.RFC3339),
"refresh_interval_seconds": 240,
}
if tokenResp != nil {
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
metadata["refresh_token"] = refreshToken
}
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
metadata["token_type"] = tokenType
}
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
metadata["scope"] = scope
}
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
}
}
mergeGitLabDirectAccessMetadata(metadata, direct)
return metadata
}
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
if metadata == nil || direct == nil {
return
}
if base := strings.TrimSpace(direct.BaseURL); base != "" {
metadata["duo_gateway_base_url"] = base
}
if token := strings.TrimSpace(direct.Token); token != "" {
metadata["duo_gateway_token"] = token
}
if direct.ExpiresAt > 0 {
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
now := time.Now().UTC()
if ttl := expiry.Sub(now); ttl > 0 {
interval := int(ttl.Seconds()) / 2
switch {
case interval < 60:
interval = 60
case interval > 240:
interval = 240
}
metadata["refresh_interval_seconds"] = interval
}
}
if len(direct.Headers) > 0 {
headers := make(map[string]string, len(direct.Headers))
for key, value := range direct.Headers {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
headers[key] = value
}
if len(headers) > 0 {
metadata["duo_gateway_headers"] = headers
}
}
if direct.ModelDetails != nil {
modelDetails := map[string]any{}
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
modelDetails["model_provider"] = provider
metadata["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
modelDetails["model_name"] = model
metadata["model_name"] = model
}
if len(modelDetails) > 0 {
metadata["model_details"] = modelDetails
}
}
}
func (a *GitLabAuthenticator) resolveString(opts *LoginOptions, key, fallback string) string {
if opts != nil && opts.Metadata != nil {
if value := strings.TrimSpace(opts.Metadata[key]); value != "" {
return value
}
}
if strings.TrimSpace(fallback) != "" {
return fallback
}
return ""
}
func (a *GitLabAuthenticator) requireInput(opts *LoginOptions, key, prompt string) (string, error) {
if value := a.resolveString(opts, key, ""); value != "" {
return value, nil
}
if opts != nil && opts.Prompt != nil {
value, err := opts.Prompt(prompt)
if err != nil {
return "", err
}
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed, nil
}
}
return "", fmt.Errorf("gitlab auth: missing required %s", key)
}
func (a *GitLabAuthenticator) optionalInput(opts *LoginOptions, key, prompt string) (string, error) {
if value := a.resolveString(opts, key, ""); value != "" {
return value, nil
}
if opts != nil && opts.Prompt != nil {
value, err := opts.Prompt(prompt)
if err != nil {
return "", err
}
return strings.TrimSpace(value), nil
}
return "", nil
}
func primaryGitLabEmail(user *gitlabauth.User) string {
if user == nil {
return ""
}
if value := strings.TrimSpace(user.Email); value != "" {
return value
}
return strings.TrimSpace(user.PublicEmail)
}
func gitLabAccountIdentifier(user *gitlabauth.User) string {
if user == nil {
return "user"
}
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return "user"
}
func sanitizeGitLabFileName(value string) string {
value = strings.TrimSpace(strings.ToLower(value))
if value == "" {
return "user"
}
var builder strings.Builder
lastDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
builder.WriteRune(r)
lastDash = false
case r >= '0' && r <= '9':
builder.WriteRune(r)
lastDash = false
case r == '-' || r == '_' || r == '.':
builder.WriteRune(r)
lastDash = false
default:
if !lastDash {
builder.WriteRune('-')
lastDash = true
}
}
}
result := strings.Trim(builder.String(), "-")
if result == "" {
return "user"
}
return result
}
func maskGitLabToken(token string) string {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return ""
}
if len(trimmed) <= 8 {
return trimmed
}
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
}

View File

@@ -17,6 +17,7 @@ func init() {
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {

View File

@@ -390,6 +390,27 @@ func (a *Auth) AccountInfo() (string, string) {
// Check metadata for email first (OAuth-style auth)
if a.Metadata != nil {
if method, ok := a.Metadata["auth_method"].(string); ok {
switch strings.ToLower(strings.TrimSpace(method)) {
case "oauth":
for _, key := range []string{"email", "username", "name"} {
if value, okValue := a.Metadata[key].(string); okValue {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return "oauth", trimmed
}
}
}
case "pat", "personal_access_token":
for _, key := range []string{"username", "email", "name", "token_preview"} {
if value, okValue := a.Metadata[key].(string); okValue {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return "personal_access_token", trimmed
}
}
}
return "personal_access_token", ""
}
}
if v, ok := a.Metadata["email"].(string); ok {
email := strings.TrimSpace(v)
if email != "" {

View File

@@ -119,6 +119,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
sdkAuth.NewCodexAuthenticator(),
sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewQwenAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
)
}
@@ -444,6 +445,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
case "github-copilot":
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
case "gitlab":
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {
@@ -891,7 +894,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
models = applyExcludedModels(models, excluded)
case "kimi":
models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded)
models = applyExcludedModels(models, excluded)
case "github-copilot":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
@@ -903,6 +906,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kilo":
models = executor.FetchKiloModels(context.Background(), a, s.cfg)
models = applyExcludedModels(models, excluded)
case "gitlab":
models = executor.GitLabModelsFromAuth(a)
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {

View File

@@ -0,0 +1,59 @@
package cliproxy
import (
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestRegisterModelsForAuth_GitLabUsesDiscoveredModelAndAlias(t *testing.T) {
service := &Service{cfg: &config.Config{}}
auth := &coreauth.Auth{
ID: "gitlab-auth",
Provider: "gitlab",
Status: coreauth.StatusActive,
Metadata: map[string]any{
"model_details": map[string]any{
"model_provider": "mistral",
"model_name": "codestral-2501",
},
},
}
reg := registry.GetGlobalRegistry()
reg.UnregisterClient(auth.ID)
t.Cleanup(func() {
reg.UnregisterClient(auth.ID)
})
service.registerModelsForAuth(auth)
models := reg.GetModelsForClient(auth.ID)
if len(models) == 0 {
t.Fatal("expected GitLab models to be registered")
}
seenActual := false
seenAlias := false
for _, model := range models {
if model == nil {
continue
}
switch strings.TrimSpace(model.ID) {
case "codestral-2501":
seenActual = true
case "gitlab-duo":
seenAlias = true
}
}
if !seenActual {
t.Fatal("expected discovered GitLab model to be registered")
}
if !seenAlias {
t.Fatal("expected stable GitLab Duo alias to be registered")
}
}