diff --git a/.gitignore b/.gitignore index 2b9c215a..02493d24 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ cliproxy # Configuration config.yaml .env - +.mcp.json # Generated content bin/* logs/* diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index 3681faf8..af83ad0c 100644 --- a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -110,7 +110,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat - useResponses := useGitHubCopilotResponsesEndpoint(from) + useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model) to := sdktranslator.FromString("openai") if useResponses { to = sdktranslator.FromString("openai-response") @@ -123,6 +123,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = e.normalizeModel(req.Model, body) body = flattenAssistantContent(body) + if useResponses { + body = normalizeGitHubCopilotResponsesInput(body) + body = normalizeGitHubCopilotResponsesTools(body) + } else { + body = normalizeGitHubCopilotChatTools(body) + } requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel) body, _ = sjson.SetBytes(body, "stream", false) @@ -199,7 +205,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. } var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + converted := "" + if useResponses && from.String() == "claude" { + converted = translateGitHubCopilotResponsesNonStreamToClaude(data) + } else { + converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + } resp = cliproxyexecutor.Response{Payload: []byte(converted)} reporter.ensurePublished(ctx) return resp, nil @@ -216,7 +227,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat - useResponses := useGitHubCopilotResponsesEndpoint(from) + useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model) to := sdktranslator.FromString("openai") if useResponses { to = sdktranslator.FromString("openai-response") @@ -229,6 +240,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = e.normalizeModel(req.Model, body) body = flattenAssistantContent(body) + if useResponses { + body = normalizeGitHubCopilotResponsesInput(body) + body = normalizeGitHubCopilotResponsesTools(body) + } else { + body = normalizeGitHubCopilotChatTools(body) + } requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel) body, _ = sjson.SetBytes(body, "stream", true) @@ -329,7 +346,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox } } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + var chunks []string + if useResponses && from.String() == "claude" { + chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m) + } else { + chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + } for i := range chunks { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} } @@ -483,8 +505,12 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte return body } -func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool { - return sourceFormat.String() == "openai-response" +func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool { + if sourceFormat.String() == "openai-response" { + return true + } + baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName) + return strings.Contains(baseModel, "codex") } // flattenAssistantContent converts assistant message content from array format @@ -519,6 +545,411 @@ func flattenAssistantContent(body []byte) []byte { return result } +func normalizeGitHubCopilotChatTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() { + filtered := "[]" + if tools.IsArray() { + for _, tool := range tools.Array() { + if tool.Get("type").String() != "function" { + continue + } + filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw) + } + } + body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered)) + } + + toolChoice := gjson.GetBytes(body, "tool_choice") + if !toolChoice.Exists() { + return body + } + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "auto", "none", "required": + return body + } + } + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body +} + +func normalizeGitHubCopilotResponsesInput(body []byte) []byte { + input := gjson.GetBytes(body, "input") + if input.Exists() { + if input.Type == gjson.String { + return body + } + inputString := input.Raw + if input.Type != gjson.JSON { + inputString = input.String() + } + body, _ = sjson.SetBytes(body, "input", inputString) + return body + } + + var parts []string + if system := gjson.GetBytes(body, "system"); system.Exists() { + if text := strings.TrimSpace(collectTextFromNode(system)); text != "" { + parts = append(parts, text) + } + } + if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { + for _, msg := range messages.Array() { + if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" { + parts = append(parts, text) + } + } + } + body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n")) + return body +} + +func normalizeGitHubCopilotResponsesTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() { + filtered := "[]" + if tools.IsArray() { + for _, tool := range tools.Array() { + toolType := tool.Get("type").String() + // Accept OpenAI format (type="function") and Claude format + // (no type field, but has top-level name + input_schema). + if toolType != "" && toolType != "function" { + continue + } + name := tool.Get("name").String() + if name == "" { + name = tool.Get("function.name").String() + } + if name == "" { + continue + } + normalized := `{"type":"function","name":""}` + normalized, _ = sjson.Set(normalized, "name", name) + if desc := tool.Get("description").String(); desc != "" { + normalized, _ = sjson.Set(normalized, "description", desc) + } else if desc = tool.Get("function.description").String(); desc != "" { + normalized, _ = sjson.Set(normalized, "description", desc) + } + if params := tool.Get("parameters"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } else if params = tool.Get("function.parameters"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } else if params = tool.Get("input_schema"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } + filtered, _ = sjson.SetRaw(filtered, "-1", normalized) + } + } + body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered)) + } + + toolChoice := gjson.GetBytes(body, "tool_choice") + if !toolChoice.Exists() { + return body + } + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "auto", "none", "required": + return body + default: + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body + } + } + if toolChoice.Type == gjson.JSON { + choiceType := toolChoice.Get("type").String() + if choiceType == "function" { + name := toolChoice.Get("name").String() + if name == "" { + name = toolChoice.Get("function.name").String() + } + if name != "" { + normalized := `{"type":"function","name":""}` + normalized, _ = sjson.Set(normalized, "name", name) + body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized)) + return body + } + } + } + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body +} + +func collectTextFromNode(node gjson.Result) string { + if !node.Exists() { + return "" + } + if node.Type == gjson.String { + return node.String() + } + if node.IsArray() { + var parts []string + for _, item := range node.Array() { + if item.Type == gjson.String { + if text := item.String(); text != "" { + parts = append(parts, text) + } + continue + } + if text := item.Get("text").String(); text != "" { + parts = append(parts, text) + continue + } + if nested := collectTextFromNode(item.Get("content")); nested != "" { + parts = append(parts, nested) + } + } + return strings.Join(parts, "\n") + } + if node.Type == gjson.JSON { + if text := node.Get("text").String(); text != "" { + return text + } + if nested := collectTextFromNode(node.Get("content")); nested != "" { + return nested + } + return node.Raw + } + return node.String() +} + +type githubCopilotResponsesStreamToolState struct { + Index int + ID string + Name string +} + +type githubCopilotResponsesStreamState struct { + MessageStarted bool + MessageStopSent bool + TextBlockStarted bool + TextBlockIndex int + NextContentIndex int + HasToolUse bool + OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState + ItemIDToTool map[string]*githubCopilotResponsesStreamToolState +} + +func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string { + root := gjson.ParseBytes(data) + out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` + out, _ = sjson.Set(out, "id", root.Get("id").String()) + out, _ = sjson.Set(out, "model", root.Get("model").String()) + + hasToolUse := false + if output := root.Get("output"); output.Exists() && output.IsArray() { + for _, item := range output.Array() { + switch item.Get("type").String() { + case "message": + if content := item.Get("content"); content.Exists() && content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() != "output_text" { + continue + } + text := part.Get("text").String() + if text == "" { + continue + } + block := `{"type":"text","text":""}` + block, _ = sjson.Set(block, "text", text) + out, _ = sjson.SetRaw(out, "content.-1", block) + } + } + case "function_call": + hasToolUse = true + toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` + toolID := item.Get("call_id").String() + if toolID == "" { + toolID = item.Get("id").String() + } + toolUse, _ = sjson.Set(toolUse, "id", toolID) + toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String()) + if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) { + argObj := gjson.Parse(args) + if argObj.IsObject() { + toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw) + } + } + out, _ = sjson.SetRaw(out, "content.-1", toolUse) + } + } + } + + inputTokens := root.Get("usage.input_tokens").Int() + outputTokens := root.Get("usage.output_tokens").Int() + out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) + out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) + if hasToolUse { + out, _ = sjson.Set(out, "stop_reason", "tool_use") + } else { + out, _ = sjson.Set(out, "stop_reason", "end_turn") + } + return out +} + +func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string { + if *param == nil { + *param = &githubCopilotResponsesStreamState{ + TextBlockIndex: -1, + OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState), + ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState), + } + } + state := (*param).(*githubCopilotResponsesStreamState) + + if !bytes.HasPrefix(line, dataTag) { + return nil + } + payload := bytes.TrimSpace(line[5:]) + if bytes.Equal(payload, []byte("[DONE]")) { + return nil + } + if !gjson.ValidBytes(payload) { + return nil + } + + event := gjson.GetBytes(payload, "type").String() + results := make([]string, 0, 4) + ensureMessageStart := func() { + if state.MessageStarted { + return + } + messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}` + messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String()) + messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String()) + results = append(results, "event: message_start\ndata: "+messageStart+"\n\n") + state.MessageStarted = true + } + startTextBlockIfNeeded := func() { + if state.TextBlockStarted { + return + } + if state.TextBlockIndex < 0 { + state.TextBlockIndex = state.NextContentIndex + state.NextContentIndex++ + } + contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex) + results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") + state.TextBlockStarted = true + } + stopTextBlockIfNeeded := func() { + if !state.TextBlockStarted { + return + } + contentBlockStop := `{"type":"content_block_stop","index":0}` + contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex) + results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") + state.TextBlockStarted = false + state.TextBlockIndex = -1 + } + resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState { + if itemID != "" { + if tool, ok := state.ItemIDToTool[itemID]; ok { + return tool + } + } + if tool, ok := state.OutputIndexToTool[outputIndex]; ok { + if itemID != "" { + state.ItemIDToTool[itemID] = tool + } + return tool + } + return nil + } + + switch event { + case "response.created": + ensureMessageStart() + case "response.output_text.delta": + ensureMessageStart() + startTextBlockIfNeeded() + delta := gjson.GetBytes(payload, "delta").String() + if delta != "" { + contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}` + contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex) + contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta) + results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n") + } + case "response.output_item.added": + if gjson.GetBytes(payload, "item.type").String() != "function_call" { + break + } + ensureMessageStart() + stopTextBlockIfNeeded() + state.HasToolUse = true + tool := &githubCopilotResponsesStreamToolState{ + Index: state.NextContentIndex, + ID: gjson.GetBytes(payload, "item.call_id").String(), + Name: gjson.GetBytes(payload, "item.name").String(), + } + if tool.ID == "" { + tool.ID = gjson.GetBytes(payload, "item.id").String() + } + state.NextContentIndex++ + outputIndex := int(gjson.GetBytes(payload, "output_index").Int()) + state.OutputIndexToTool[outputIndex] = tool + if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" { + state.ItemIDToTool[itemID] = tool + } + contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` + contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index) + contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID) + contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name) + results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") + case "response.output_item.delta": + item := gjson.GetBytes(payload, "item") + if item.Get("type").String() != "function_call" { + break + } + tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int())) + if tool == nil { + break + } + partial := gjson.GetBytes(payload, "delta").String() + if partial == "" { + partial = item.Get("arguments").String() + } + if partial == "" { + break + } + inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` + inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index) + inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial) + results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n") + case "response.output_item.done": + if gjson.GetBytes(payload, "item.type").String() != "function_call" { + break + } + tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int())) + if tool == nil { + break + } + contentBlockStop := `{"type":"content_block_stop","index":0}` + contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index) + results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") + case "response.completed": + ensureMessageStart() + stopTextBlockIfNeeded() + if !state.MessageStopSent { + stopReason := "end_turn" + if state.HasToolUse { + stopReason = "tool_use" + } + messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason) + messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int()) + messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int()) + results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n") + results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") + state.MessageStopSent = true + } + } + + return results +} + // isHTTPSuccess checks if the status code indicates success (2xx). func isHTTPSuccess(statusCode int) bool { return statusCode >= 200 && statusCode < 300 diff --git a/internal/runtime/executor/github_copilot_executor_test.go b/internal/runtime/executor/github_copilot_executor_test.go index ef077fd6..2895c8a7 100644 --- a/internal/runtime/executor/github_copilot_executor_test.go +++ b/internal/runtime/executor/github_copilot_executor_test.go @@ -1,8 +1,10 @@ package executor import ( + "strings" "testing" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" "github.com/tidwall/gjson" ) @@ -52,3 +54,189 @@ func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) { }) } } + +func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) { + t.Parallel() + if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") { + t.Fatal("expected openai-response source to use /responses") + } +} + +func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) { + t.Parallel() + if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") { + t.Fatal("expected codex model to use /responses") + } +} + +func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) { + t.Parallel() + if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") { + t.Fatal("expected default openai source with non-codex model to use /chat/completions") + } +} + +func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`) + got := normalizeGitHubCopilotChatTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 1 { + t.Fatalf("tools len = %d, want 1", len(tools)) + } + if tools[0].Get("type").String() != "function" { + t.Fatalf("tool type = %q, want function", tools[0].Get("type").String()) + } +} + +func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`) + got := normalizeGitHubCopilotChatTools(body) + if gjson.GetBytes(got, "tool_choice").String() != "auto" { + t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw) + } +} + +func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) { + t.Parallel() + body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`) + got := normalizeGitHubCopilotResponsesInput(body) + in := gjson.GetBytes(got, "input") + if in.Type != gjson.String { + t.Fatalf("input type = %v, want string", in.Type) + } + if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") { + t.Fatalf("input = %q, want merged text", in.String()) + } +} + +func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) { + t.Parallel() + body := []byte(`{"input":{"foo":"bar"}}`) + got := normalizeGitHubCopilotResponsesInput(body) + in := gjson.GetBytes(got, "input") + if in.Type != gjson.String { + t.Fatalf("input type = %v, want string", in.Type) + } + if !strings.Contains(in.String(), "foo") { + t.Fatalf("input = %q, want stringified object", in.String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`) + got := normalizeGitHubCopilotResponsesTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 1 { + t.Fatalf("tools len = %d, want 1", len(tools)) + } + if tools[0].Get("name").String() != "sum" { + t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String()) + } + if !tools[0].Get("parameters").Exists() { + t.Fatal("expected parameters to be preserved") + } +} + +func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`) + got := normalizeGitHubCopilotResponsesTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 2 { + t.Fatalf("tools len = %d, want 2", len(tools)) + } + if tools[0].Get("type").String() != "function" { + t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String()) + } + if tools[0].Get("name").String() != "Bash" { + t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String()) + } + if tools[0].Get("description").String() != "Run commands" { + t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String()) + } + if !tools[0].Get("parameters").Exists() { + t.Fatal("expected parameters to be set from input_schema") + } + if tools[0].Get("parameters.properties.command").Exists() != true { + t.Fatal("expected parameters.properties.command to exist") + } + if tools[1].Get("name").String() != "Read" { + t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) { + t.Parallel() + body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`) + got := normalizeGitHubCopilotResponsesTools(body) + if gjson.GetBytes(got, "tool_choice.type").String() != "function" { + t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String()) + } + if gjson.GetBytes(got, "tool_choice.name").String() != "sum" { + t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) { + t.Parallel() + body := []byte(`{"tool_choice":{"type":"function"}}`) + got := normalizeGitHubCopilotResponsesTools(body) + if gjson.GetBytes(got, "tool_choice").String() != "auto" { + t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw) + } +} + +func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) { + t.Parallel() + resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`) + out := translateGitHubCopilotResponsesNonStreamToClaude(resp) + if gjson.Get(out, "type").String() != "message" { + t.Fatalf("type = %q, want message", gjson.Get(out, "type").String()) + } + if gjson.Get(out, "content.0.type").String() != "text" { + t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String()) + } + if gjson.Get(out, "content.0.text").String() != "hello" { + t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String()) + } +} + +func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) { + t.Parallel() + resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`) + out := translateGitHubCopilotResponsesNonStreamToClaude(resp) + if gjson.Get(out, "content.0.type").String() != "tool_use" { + t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String()) + } + if gjson.Get(out, "content.0.name").String() != "sum" { + t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String()) + } + if gjson.Get(out, "stop_reason").String() != "tool_use" { + t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String()) + } +} + +func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) { + t.Parallel() + var param any + + created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m) + if len(created) == 0 || !strings.Contains(created[0], "message_start") { + t.Fatalf("created events = %#v, want message_start", created) + } + + delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m) + joinedDelta := strings.Join(delta, "") + if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") { + t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta) + } + + completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m) + joinedCompleted := strings.Join(completed, "") + if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") { + t.Fatalf("completed events = %#v, want message_delta + message_stop", completed) + } +}