mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-20 22:51:45 +00:00
Compare commits
558 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c48ef58e0 | ||
|
|
15b0d8d039 | ||
|
|
5dcca69e8c | ||
|
|
f5dc6483d5 | ||
|
|
d949921143 | ||
|
|
7b03f04670 | ||
|
|
1267fddf61 | ||
|
|
85c7d43bea | ||
|
|
44c74d6ea2 | ||
|
|
ba454dbfbf | ||
|
|
d1508ca030 | ||
|
|
d4a6a5ae15 | ||
|
|
7c24d54ca8 | ||
|
|
a4c1e32ff6 | ||
|
|
f56cf42461 | ||
|
|
3dea1da249 | ||
|
|
8fac29631d | ||
|
|
8fecd625d2 | ||
|
|
10b55b5ddd | ||
|
|
41ae2c81e7 | ||
|
|
278a89824c | ||
|
|
c4459c4346 | ||
|
|
61e0447f92 | ||
|
|
1dc3018fd6 | ||
|
|
26fd3eff03 | ||
|
|
5bfaf8086b | ||
|
|
6c0a1efd71 | ||
|
|
f5ed5c7453 | ||
|
|
65158cce46 | ||
|
|
1c6c3675d1 | ||
|
|
a583463d60 | ||
|
|
8ed290c1c4 | ||
|
|
727221df2e | ||
|
|
1d8e68ad15 | ||
|
|
0ab1f5412f | ||
|
|
9ded75d335 | ||
|
|
f135fdf7fc | ||
|
|
828df80088 | ||
|
|
c585caa0ce | ||
|
|
5bb69fa4ab | ||
|
|
344043b9f1 | ||
|
|
26c298ced1 | ||
|
|
5ab9afac83 | ||
|
|
65ce86338b | ||
|
|
2a97037d7b | ||
|
|
d801393841 | ||
|
|
b2c0cdfc88 | ||
|
|
f32c8c9620 | ||
|
|
0f45d89255 | ||
|
|
96056d0137 | ||
|
|
f780c289e8 | ||
|
|
ac36119a02 | ||
|
|
39dc4557c1 | ||
|
|
30e94b6792 | ||
|
|
938af75954 | ||
|
|
38f0ae5970 | ||
|
|
cf249586a9 | ||
|
|
1dba2d0f81 | ||
|
|
730809d8ea | ||
|
|
e8d1b79cb3 | ||
|
|
5e81b65f2f | ||
|
|
7e8e2226a6 | ||
|
|
f0c20e852f | ||
|
|
7cdf8e9872 | ||
|
|
c42480a574 | ||
|
|
55c146a0e7 | ||
|
|
e2e3c7dde0 | ||
|
|
9e0ab4d116 | ||
|
|
8783caf313 | ||
|
|
f6f4640c5e | ||
|
|
613fe6768d | ||
|
|
ad8e3964ff | ||
|
|
e9dc576409 | ||
|
|
941334da79 | ||
|
|
d54f816363 | ||
|
|
69b950db4c | ||
|
|
f43d25def1 | ||
|
|
a279192881 | ||
|
|
6a43d7285c | ||
|
|
578c312660 | ||
|
|
6bb9bf3132 | ||
|
|
343a2fc2f7 | ||
|
|
12b967118b | ||
|
|
70efd4e016 | ||
|
|
f5aa68ecda | ||
|
|
9a5f142c33 | ||
|
|
d390b95b76 | ||
|
|
d1f6224b70 | ||
|
|
fcc59d606d | ||
|
|
91e7591955 | ||
|
|
4607356333 | ||
|
|
9a9ed99072 | ||
|
|
5ae38584b8 | ||
|
|
c8b7e2b8d6 | ||
|
|
cad45ffa33 | ||
|
|
6a27bceec0 | ||
|
|
163d68318f | ||
|
|
0ea768011b | ||
|
|
8b9dbe10f0 | ||
|
|
341b4beea1 | ||
|
|
bea13f9724 | ||
|
|
9f5bdfaa31 | ||
|
|
9eabdd09db | ||
|
|
c3f8dc362e | ||
|
|
b85120873b | ||
|
|
6f58518c69 | ||
|
|
000fcb15fa | ||
|
|
ea43361492 | ||
|
|
c1818f197b | ||
|
|
b0653cec7b | ||
|
|
22a1a24cf5 | ||
|
|
7223fee2de | ||
|
|
ada8e2905e | ||
|
|
4ba10531da | ||
|
|
3774b56e9f | ||
|
|
c2d4137fb9 | ||
|
|
2ee938acaf | ||
|
|
8d5e470e1f | ||
|
|
65e9e892a4 | ||
|
|
3882494878 | ||
|
|
088c1d07f4 | ||
|
|
8430b28cfa | ||
|
|
f3ab8f4bc5 | ||
|
|
0e4f189c2e | ||
|
|
98509f615c | ||
|
|
e7a66ae504 | ||
|
|
754b126944 | ||
|
|
ae37ccffbf | ||
|
|
42c062bb5b | ||
|
|
87bf0b73d5 | ||
|
|
f389667ec3 | ||
|
|
29dba0399b | ||
|
|
a824e7cd0b | ||
|
|
140faef7dc | ||
|
|
adb580b344 | ||
|
|
06405f2129 | ||
|
|
b849bf79d6 | ||
|
|
59af2c57b1 | ||
|
|
d1fd2c4ad4 | ||
|
|
b6c6379bfa | ||
|
|
8f0e66b72e | ||
|
|
f63cf6ff7a | ||
|
|
d2419ed49d | ||
|
|
516d22c695 | ||
|
|
73cda6e836 | ||
|
|
0805989ee5 | ||
|
|
9b5ce8c64f | ||
|
|
058793c73a | ||
|
|
75da02af55 | ||
|
|
ab9ebea592 | ||
|
|
7ee37ee4b9 | ||
|
|
837afffb31 | ||
|
|
03a1bac898 | ||
|
|
3171d524f0 | ||
|
|
3e78a8d500 | ||
|
|
fcba912cc4 | ||
|
|
7170eeea5f | ||
|
|
e3eb048c7a | ||
|
|
a59e92435b | ||
|
|
108895fc04 | ||
|
|
abc293c642 | ||
|
|
da3a498a28 | ||
|
|
bb44671845 | ||
|
|
09e480036a | ||
|
|
249f969110 | ||
|
|
4f8acec2d8 | ||
|
|
34339f61ee | ||
|
|
4045378cb4 | ||
|
|
2df35449fe | ||
|
|
c744179645 | ||
|
|
9720b03a6b | ||
|
|
f2c0f3d325 | ||
|
|
4f99bc54f1 | ||
|
|
913f4a9c5f | ||
|
|
25d1c18a3f | ||
|
|
d09dd4d0b2 | ||
|
|
474fb042da | ||
|
|
8435c3d7be | ||
|
|
e783d0a62e | ||
|
|
b05f575e9b | ||
|
|
f5e9f01811 | ||
|
|
ff7dbb5867 | ||
|
|
e34b2b4f1d | ||
|
|
15c2f274ea | ||
|
|
37249339ac | ||
|
|
c422d16beb | ||
|
|
66cd50f603 | ||
|
|
caa529c282 | ||
|
|
51a4379bf4 | ||
|
|
acf98ed10e | ||
|
|
d1c07a091e | ||
|
|
c1a8adf1ab | ||
|
|
08e078fc25 | ||
|
|
105a21548f | ||
|
|
1734aa1664 | ||
|
|
ca11b236a7 | ||
|
|
6fdff8227d | ||
|
|
330e12d3c2 | ||
|
|
bd09c0bf09 | ||
|
|
b468ca79c3 | ||
|
|
d2c7e4e96a | ||
|
|
1c7003ff68 | ||
|
|
1b44364e78 | ||
|
|
ec77f4a4f5 | ||
|
|
f611dd6e96 | ||
|
|
07b7c1a1e0 | ||
|
|
51fd58d74f | ||
|
|
faae9c2f7c | ||
|
|
bc3a6e4646 | ||
|
|
b09b03e35e | ||
|
|
16231947e7 | ||
|
|
39b9a38fbc | ||
|
|
bd855abec9 | ||
|
|
7c3c2e9f64 | ||
|
|
c10f8ae2e2 | ||
|
|
a0bf33eca6 | ||
|
|
88dd9c715d | ||
|
|
a3e21df814 | ||
|
|
d3b94c9241 | ||
|
|
c1d7599829 | ||
|
|
d11936f292 | ||
|
|
17363edf25 | ||
|
|
279cbbbb8a | ||
|
|
486cd4c343 | ||
|
|
25feceb783 | ||
|
|
d26752250d | ||
|
|
b15453c369 | ||
|
|
04ba8c8bc3 | ||
|
|
6570692291 | ||
|
|
f73d55ddaa | ||
|
|
13aa5b3375 | ||
|
|
0fcc02fbea | ||
|
|
c03883ccf0 | ||
|
|
134a9eac9d | ||
|
|
6d8de0ade4 | ||
|
|
1587ff5e74 | ||
|
|
f033d3a6df | ||
|
|
145e0e0b5d | ||
|
|
f8d1bc06ea | ||
|
|
d5930f4e44 | ||
|
|
9b7d7021af | ||
|
|
e41c22ef44 | ||
|
|
5fc2bd393e | ||
|
|
55271403fb | ||
|
|
36fba66619 | ||
|
|
66eb12294a | ||
|
|
73b22ec29b | ||
|
|
c31ae2f3b5 | ||
|
|
76b53d6b5b | ||
|
|
a34dfed378 | ||
|
|
b9b127a7ea | ||
|
|
2741e7b7b3 | ||
|
|
1767a56d4f | ||
|
|
779e6c2d2f | ||
|
|
73c831747b | ||
|
|
b8b89f34f4 | ||
|
|
1fa094dac6 | ||
|
|
f55754621f | ||
|
|
ac26e7db43 | ||
|
|
10b824fcac | ||
|
|
e5d3541b5a | ||
|
|
79755e76ea | ||
|
|
35f158d526 | ||
|
|
6962e09dd9 | ||
|
|
4c4cbd44da | ||
|
|
26eca8b6ba | ||
|
|
62b17f40a1 | ||
|
|
511b8a992e | ||
|
|
7dccc7ba2f | ||
|
|
70c90687fd | ||
|
|
8144ffd5c8 | ||
|
|
0ab977c236 | ||
|
|
224f0de353 | ||
|
|
6b45d311ec | ||
|
|
d54de441d3 | ||
|
|
7386a70724 | ||
|
|
1821bf7051 | ||
|
|
d42b5d4e78 | ||
|
|
1b7447b682 | ||
|
|
40dee4453a | ||
|
|
8902e1cccb | ||
|
|
de5fe71478 | ||
|
|
dcfbec2990 | ||
|
|
c95620f90e | ||
|
|
754f3bcbc3 | ||
|
|
36973d4a6f | ||
|
|
9613f0b3f9 | ||
|
|
274f29e26b | ||
|
|
c8e79c3787 | ||
|
|
8afef43887 | ||
|
|
c1083cbfc6 | ||
|
|
c89d19b300 | ||
|
|
1e6bc81cfd | ||
|
|
1a149475e0 | ||
|
|
e5166841db | ||
|
|
19c52bcb60 | ||
|
|
bb9b2d1758 | ||
|
|
7fa527193c | ||
|
|
ed0eb51b4d | ||
|
|
0e4f669c8b | ||
|
|
76c064c729 | ||
|
|
d2f652f436 | ||
|
|
6a452a54d5 | ||
|
|
9e5693e74f | ||
|
|
528b1a2307 | ||
|
|
0cc978ec1d | ||
|
|
d312422ab4 | ||
|
|
fee736933b | ||
|
|
09c92aa0b5 | ||
|
|
8c67b3ae64 | ||
|
|
000e4ceb4e | ||
|
|
5c99846ecf | ||
|
|
cc32f5ff61 | ||
|
|
fbff68b9e0 | ||
|
|
7e1a543b79 | ||
|
|
d475aaba96 | ||
|
|
1dc4ecb1b8 | ||
|
|
1315f710f5 | ||
|
|
96f55570f7 | ||
|
|
0906aeca87 | ||
|
|
7333619f15 | ||
|
|
97c0487add | ||
|
|
74b862d8b8 | ||
|
|
2db8df8e38 | ||
|
|
a576088d5f | ||
|
|
66ff916838 | ||
|
|
7b0453074e | ||
|
|
a000eb523d | ||
|
|
18a4fedc7f | ||
|
|
5d6cdccda0 | ||
|
|
1b7f4ac3e1 | ||
|
|
afc1a5b814 | ||
|
|
7ed38db54f | ||
|
|
28c10f4e69 | ||
|
|
6e12441a3b | ||
|
|
65c439c18d | ||
|
|
0ed2d16596 | ||
|
|
db335ac616 | ||
|
|
f3c59165d7 | ||
|
|
e6690cb447 | ||
|
|
35907416b8 | ||
|
|
e8bb350467 | ||
|
|
5331d51f27 | ||
|
|
755ca75879 | ||
|
|
2398ebad55 | ||
|
|
c1bf298216 | ||
|
|
e005208d76 | ||
|
|
d1df70d02f | ||
|
|
f81acd0760 | ||
|
|
636da4c932 | ||
|
|
cccb77b552 | ||
|
|
2bd646ad70 | ||
|
|
52c1fa025e | ||
|
|
680105f84d | ||
|
|
f7069e9548 | ||
|
|
7275e99b41 | ||
|
|
c28b65f849 | ||
|
|
793840cdb4 | ||
|
|
8f421de532 | ||
|
|
be2dd60ee7 | ||
|
|
ea3e0b713e | ||
|
|
8179d5a8a4 | ||
|
|
6fa7abe434 | ||
|
|
5135c22cd6 | ||
|
|
1e27990561 | ||
|
|
e1e9fc43c1 | ||
|
|
b2921518ac | ||
|
|
dd64adbeeb | ||
|
|
616d41c06a | ||
|
|
e0e337aeb9 | ||
|
|
d52839fced | ||
|
|
4022e69651 | ||
|
|
56073ded69 | ||
|
|
9738a53f49 | ||
|
|
be3f8dbf7e | ||
|
|
9c6c3612a8 | ||
|
|
19e1a4447a | ||
|
|
36efcc6e28 | ||
|
|
a337ecf35c | ||
|
|
7c2ad4cda2 | ||
|
|
fb95813fbf | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
cef2aeeb08 | ||
|
|
bcd1e8cc34 | ||
|
|
198b3f4a40 | ||
|
|
9fee7f488e | ||
|
|
1b46d39b8b | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
152c310bb7 | ||
|
|
f6bbca35ab | ||
|
|
c8cee6a209 | ||
|
|
b5701f416b | ||
|
|
4b1a404fcb | ||
|
|
b93cce5412 | ||
|
|
c6cb24039d | ||
|
|
5382408489 | ||
|
|
67669196ed | ||
|
|
5c817a9b42 | ||
|
|
e08f68ed7c | ||
|
|
f09ed25fd3 | ||
|
|
58fd9bf964 | ||
|
|
7b3dfc67bc | ||
|
|
cdd24052d3 | ||
|
|
5da0decef6 | ||
|
|
733fd8edab | ||
|
|
af27f2b8bc | ||
|
|
2e1925d762 | ||
|
|
77254bd074 | ||
|
|
5b6342e6ac | ||
|
|
e166e56249 | ||
|
|
3960c93d51 | ||
|
|
339a81b650 | ||
|
|
560c020477 | ||
|
|
aec65e3be3 | ||
|
|
f44f0702f8 | ||
|
|
b76b79068f | ||
|
|
34c8ccb961 | ||
|
|
d08e164af3 | ||
|
|
8178efaeda | ||
|
|
86d5db472a | ||
|
|
020d36f6e8 | ||
|
|
1db23979e8 | ||
|
|
c3d5dbe96f | ||
|
|
5484489406 | ||
|
|
0ac52da460 | ||
|
|
817cebb321 | ||
|
|
683f3709d6 | ||
|
|
dbd42a42b2 | ||
|
|
ec24baf757 | ||
|
|
dea3e74d35 | ||
|
|
a6c3042e34 | ||
|
|
861537c9bd | ||
|
|
8c92cb0883 | ||
|
|
89d7be9525 | ||
|
|
2b79d7f22f | ||
|
|
2bb686f594 | ||
|
|
163fe287ce | ||
|
|
70988d387b | ||
|
|
52058a1659 | ||
|
|
df5595a0c9 | ||
|
|
ddaa9d2436 | ||
|
|
7b7b258c38 | ||
|
|
a00f774f5a | ||
|
|
9daf1ba8b5 | ||
|
|
76f2359637 | ||
|
|
dcb1c9be8a | ||
|
|
a24f4ace78 | ||
|
|
c631df8c3b | ||
|
|
54c3eb1b1e | ||
|
|
bb28cd26ad | ||
|
|
046865461e | ||
|
|
cf74ed2f0c | ||
|
|
c3762328a5 | ||
|
|
e333fbea3d | ||
|
|
efbe36d1d4 | ||
|
|
8553cfa40e | ||
|
|
30d5c95b26 | ||
|
|
d1e3195e6f | ||
|
|
05a35662ae | ||
|
|
ce53d3a287 | ||
|
|
5f58248016 | ||
|
|
4cc99e7449 | ||
|
|
71773fe032 | ||
|
|
a1e0fa0f39 | ||
|
|
fc2f0b6983 | ||
|
|
5c9997cdac | ||
|
|
6f81046730 | ||
|
|
0687472d01 | ||
|
|
7739738fb3 | ||
|
|
99d1ce247b | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 | ||
|
|
182b31963a | ||
|
|
4f48e5254a | ||
|
|
15dd5db1d7 | ||
|
|
424711b718 | ||
|
|
91a2b1f0b4 | ||
|
|
2b134fc378 | ||
|
|
b9153719b0 | ||
|
|
631e5c8331 | ||
|
|
e9c60a0a67 | ||
|
|
98a1bb5a7f | ||
|
|
ca90487a8c | ||
|
|
1042489f85 | ||
|
|
38277c1ea6 | ||
|
|
ee0c24628f | ||
|
|
07d6689d87 | ||
|
|
3a18f6fcca | ||
|
|
099e734a02 | ||
|
|
a52da26b5d | ||
|
|
522a68a4ea | ||
|
|
a02eda54d0 | ||
|
|
97ef633c57 | ||
|
|
dae8463ba1 | ||
|
|
7c1299922e | ||
|
|
ddcf1f279d | ||
|
|
7e6bb8fdc5 | ||
|
|
9cee8ef87b | ||
|
|
93fb841bcb | ||
|
|
0c05131aeb | ||
|
|
5ebc58fab4 | ||
|
|
2b609dd891 | ||
|
|
a8cbc68c3e | ||
|
|
11a795a01c | ||
|
|
89c428216e | ||
|
|
2695a99623 | ||
|
|
242aecd924 | ||
|
|
ce8cc1ba33 | ||
|
|
ad5253bd2b | ||
|
|
97fdd2e088 | ||
|
|
9397f7049f | ||
|
|
a14d19b92c | ||
|
|
8ae0c05ea6 | ||
|
|
8822f20d17 | ||
|
|
553d6f50ea | ||
|
|
f0e5a5a367 | ||
|
|
f6dfea9357 | ||
|
|
cc8dc7f62c | ||
|
|
a3846ea513 | ||
|
|
8d44be858e | ||
|
|
0e6bb076e9 | ||
|
|
ac135fc7cb | ||
|
|
4e1d09809d | ||
|
|
9e855f8100 | ||
|
|
25680a8259 | ||
|
|
13c93e8cfd | ||
|
|
88aa1b9fd1 | ||
|
|
ac95e92829 | ||
|
|
8526c2da25 | ||
|
|
68a6cabf8b | ||
|
|
ac0e387da1 | ||
|
|
7fe1d102cb | ||
|
|
c51851689b | ||
|
|
419bf784ab | ||
|
|
7d6660d181 | ||
|
|
d8e3d4e2b6 | ||
|
|
dd44413ba5 | ||
|
|
10fa0f2062 | ||
|
|
30338ecec4 | ||
|
|
9a37defed3 | ||
|
|
c83a057996 | ||
|
|
b7588428c5 | ||
|
|
14cb2b95c6 | ||
|
|
fdeef48498 |
81
.github/workflows/agents-md-guard.yml
vendored
Normal file
81
.github/workflows/agents-md-guard.yml
vendored
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
name: agents-md-guard
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- synchronize
|
||||||
|
- reopened
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-when-agents-md-changed:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Detect AGENTS.md changes and close PR
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const prNumber = context.payload.pull_request.number;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
|
||||||
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
per_page: 100,
|
||||||
|
});
|
||||||
|
|
||||||
|
const touchesAgentsMd = (path) =>
|
||||||
|
typeof path === "string" &&
|
||||||
|
(path === "AGENTS.md" || path.endsWith("/AGENTS.md"));
|
||||||
|
|
||||||
|
const touched = files.filter(
|
||||||
|
(f) => touchesAgentsMd(f.filename) || touchesAgentsMd(f.previous_filename),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (touched.length === 0) {
|
||||||
|
core.info("No AGENTS.md changes detected.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const changedList = touched
|
||||||
|
.map((f) =>
|
||||||
|
f.previous_filename && f.previous_filename !== f.filename
|
||||||
|
? `- ${f.previous_filename} -> ${f.filename}`
|
||||||
|
: `- ${f.filename}`,
|
||||||
|
)
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
const body = [
|
||||||
|
"This repository does not allow modifying `AGENTS.md` in pull requests.",
|
||||||
|
"",
|
||||||
|
"Detected changes:",
|
||||||
|
changedList,
|
||||||
|
"",
|
||||||
|
"Please revert these changes and open a new PR without touching `AGENTS.md`.",
|
||||||
|
].join("\n");
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
issue_number: prNumber,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
await github.rest.pulls.update({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
state: "closed",
|
||||||
|
});
|
||||||
|
|
||||||
|
core.setFailed("PR modifies AGENTS.md");
|
||||||
73
.github/workflows/auto-retarget-main-pr-to-dev.yml
vendored
Normal file
73
.github/workflows/auto-retarget-main-pr-to-dev.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
name: auto-retarget-main-pr-to-dev
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- reopened
|
||||||
|
- edited
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
retarget:
|
||||||
|
if: github.actor != 'github-actions[bot]'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Retarget PR base to dev
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const pr = context.payload.pull_request;
|
||||||
|
const prNumber = pr.number;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
|
||||||
|
const baseRef = pr.base?.ref;
|
||||||
|
const headRef = pr.head?.ref;
|
||||||
|
const desiredBase = "dev";
|
||||||
|
|
||||||
|
if (baseRef !== "main") {
|
||||||
|
core.info(`PR #${prNumber} base is ${baseRef}; nothing to do.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (headRef === desiredBase) {
|
||||||
|
core.info(`PR #${prNumber} is ${desiredBase} -> main; skipping retarget.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
core.info(`Retargeting PR #${prNumber} base from ${baseRef} to ${desiredBase}.`);
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.pulls.update({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
base: desiredBase,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.setFailed(`Failed to retarget PR #${prNumber} to ${desiredBase}: ${error.message}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const body = [
|
||||||
|
`This pull request targeted \`${baseRef}\`.`,
|
||||||
|
"",
|
||||||
|
`The base branch has been automatically changed to \`${desiredBase}\`.`,
|
||||||
|
].join("\n");
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
issue_number: prNumber,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
|
||||||
|
}
|
||||||
14
.github/workflows/docker-image.yml
vendored
14
.github/workflows/docker-image.yml
vendored
@@ -16,6 +16,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -25,7 +29,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push (amd64)
|
- name: Build and push (amd64)
|
||||||
@@ -47,6 +51,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -56,7 +64,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push (arm64)
|
- name: Build and push (arm64)
|
||||||
@@ -90,7 +98,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Create and push multi-arch manifests
|
- name: Create and push multi-arch manifests
|
||||||
|
|||||||
4
.github/workflows/pr-test-build.yml
vendored
4
.github/workflows/pr-test-build.yml
vendored
@@ -12,6 +12,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
9
.github/workflows/release.yaml
vendored
9
.github/workflows/release.yaml
vendored
@@ -16,6 +16,10 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- run: git fetch --force --tags
|
- run: git fetch --force --tags
|
||||||
- uses: actions/setup-go@v4
|
- uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
@@ -23,15 +27,14 @@ jobs:
|
|||||||
cache: true
|
cache: true
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
VERSION=$(git describe --tags --always --dirty)
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo "VERSION=${VERSION}" >> $GITHUB_ENV
|
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- uses: goreleaser/goreleaser-action@v4
|
- uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
distribution: goreleaser
|
distribution: goreleaser
|
||||||
version: latest
|
version: latest
|
||||||
args: release --clean
|
args: release --clean --skip=validate
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
VERSION: ${{ env.VERSION }}
|
VERSION: ${{ env.VERSION }}
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,6 +1,7 @@
|
|||||||
# Binaries
|
# Binaries
|
||||||
cli-proxy-api
|
cli-proxy-api
|
||||||
cliproxy
|
cliproxy
|
||||||
|
/server
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
|
|
||||||
@@ -36,15 +37,16 @@ GEMINI.md
|
|||||||
|
|
||||||
# Tooling metadata
|
# Tooling metadata
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
.worktrees/
|
||||||
.codex/*
|
.codex/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.gemini/*
|
.gemini/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.agent/*
|
.agent/*
|
||||||
.agents/*
|
.agents/*
|
||||||
.agents/*
|
|
||||||
.opencode/*
|
.opencode/*
|
||||||
.idea/*
|
.idea/*
|
||||||
|
.beads/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
@@ -53,4 +55,10 @@ _bmad-output/*
|
|||||||
# macOS
|
# macOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
._*
|
._*
|
||||||
|
|
||||||
|
# Opencode
|
||||||
|
.beads/
|
||||||
|
.opencode/
|
||||||
|
.cli-proxy-api/
|
||||||
|
.venv/
|
||||||
*.bak
|
*.bak
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
builds:
|
builds:
|
||||||
- id: "cli-proxy-api-plus"
|
- id: "cli-proxy-api-plus"
|
||||||
env:
|
env:
|
||||||
@@ -6,6 +8,7 @@ builds:
|
|||||||
- linux
|
- linux
|
||||||
- windows
|
- windows
|
||||||
- darwin
|
- darwin
|
||||||
|
- freebsd
|
||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
|||||||
58
AGENTS.md
Normal file
58
AGENTS.md
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
Go 1.26+ proxy server providing OpenAI/Gemini/Claude/Codex compatible APIs with OAuth and round-robin load balancing.
|
||||||
|
|
||||||
|
## Repository
|
||||||
|
- GitHub: https://github.com/router-for-me/CLIProxyAPI
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
```bash
|
||||||
|
gofmt -w . # Format (required after Go changes)
|
||||||
|
go build -o cli-proxy-api ./cmd/server # Build
|
||||||
|
go run ./cmd/server # Run dev server
|
||||||
|
go test ./... # Run all tests
|
||||||
|
go test -v -run TestName ./path/to/pkg # Run single test
|
||||||
|
go build -o test-output ./cmd/server && rm test-output # Verify compile (REQUIRED after changes)
|
||||||
|
```
|
||||||
|
- Common flags: `--config <path>`, `--tui`, `--standalone`, `--local-model`, `--no-browser`, `--oauth-callback-port <port>`
|
||||||
|
|
||||||
|
## Config
|
||||||
|
- Default config: `config.yaml` (template: `config.example.yaml`)
|
||||||
|
- `.env` is auto-loaded from the working directory
|
||||||
|
- Auth material defaults under `auths/`
|
||||||
|
- Storage backends: file-based default; optional Postgres/git/object store (`PGSTORE_*`, `GITSTORE_*`, `OBJECTSTORE_*`)
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
- `cmd/server/` — Server entrypoint
|
||||||
|
- `internal/api/` — Gin HTTP API (routes, middleware, modules)
|
||||||
|
- `internal/api/modules/amp/` — Amp integration (Amp-style routes + reverse proxy)
|
||||||
|
- `internal/thinking/` — Main thinking/reasoning pipeline. `ApplyThinking()` (apply.go) parses suffixes (`suffix.go`, suffix overrides body), normalizes config to canonical `ThinkingConfig` (`types.go`), normalizes and validates centrally (`validate.go`/`convert.go`), then applies provider-specific output via `ProviderApplier`. Do not break this "canonical representation → per-provider translation" architecture.
|
||||||
|
- `internal/runtime/executor/` — Per-provider runtime executors (incl. Codex WebSocket)
|
||||||
|
- `internal/translator/` — Provider protocol translators (and shared `common`)
|
||||||
|
- `internal/registry/` — Model registry + remote updater (`StartModelsUpdater`); `--local-model` disables remote updates
|
||||||
|
- `internal/store/` — Storage implementations and secret resolution
|
||||||
|
- `internal/managementasset/` — Config snapshots and management assets
|
||||||
|
- `internal/cache/` — Request signature caching
|
||||||
|
- `internal/watcher/` — Config hot-reload and watchers
|
||||||
|
- `internal/wsrelay/` — WebSocket relay sessions
|
||||||
|
- `internal/usage/` — Usage and token accounting
|
||||||
|
- `internal/tui/` — Bubbletea terminal UI (`--tui`, `--standalone`)
|
||||||
|
- `sdk/cliproxy/` — Embeddable SDK entry (service/builder/watchers/pipeline)
|
||||||
|
- `test/` — Cross-module integration tests
|
||||||
|
|
||||||
|
## Code Conventions
|
||||||
|
- Keep changes small and simple (KISS)
|
||||||
|
- Comments in English only
|
||||||
|
- If editing code that already contains non-English comments, translate them to English (don’t add new non-English comments)
|
||||||
|
- For user-visible strings, keep the existing language used in that file/area
|
||||||
|
- New Markdown docs should be in English unless the file is explicitly language-specific (e.g. `README_CN.md`)
|
||||||
|
- As a rule, do not make standalone changes to `internal/translator/`. You may modify it only as part of broader changes elsewhere.
|
||||||
|
- If a task requires changing only `internal/translator/`, run `gh repo view --json viewerPermission -q .viewerPermission` to confirm you have `WRITE`, `MAINTAIN`, or `ADMIN`. If you do, you may proceed; otherwise, file a GitHub issue including the goal, rationale, and the intended implementation code, then stop further work.
|
||||||
|
- `internal/runtime/executor/` should contain executors and their unit tests only. Place any helper/supporting files under `internal/runtime/executor/helps/`.
|
||||||
|
- Follow `gofmt`; keep imports goimports-style; wrap errors with context where helpful
|
||||||
|
- Do not use `log.Fatal`/`log.Fatalf` (terminates the process); prefer returning errors and logging via logrus
|
||||||
|
- Shadowed variables: use method suffix (`errStart := server.Start()`)
|
||||||
|
- Wrap defer errors: `defer func() { if err := f.Close(); err != nil { log.Errorf(...) } }()`
|
||||||
|
- Use logrus structured logging; avoid leaking secrets/tokens in logs
|
||||||
|
- Avoid panics in HTTP handlers; prefer logged errors and meaningful HTTP status codes
|
||||||
|
- Timeouts are allowed only during credential acquisition; after an upstream connection is established, do not set timeouts for any subsequent network behavior. Intentional exceptions that must remain allowed are the Codex websocket liveness deadlines in `internal/runtime/executor/codex_websockets_executor.go`, the wsrelay session deadlines in `internal/wsrelay/session.go`, the management APICall timeout in `internal/api/handlers/management/api_tools.go`, and the `cmd/fetch_antigravity_models` utility timeouts
|
||||||
117
README.md
117
README.md
@@ -8,123 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
|
|||||||
|
|
||||||
The Plus release stays in lockstep with the mainline features.
|
The Plus release stays in lockstep with the mainline features.
|
||||||
|
|
||||||
## Differences from the Mainline
|
|
||||||
|
|
||||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
|
||||||
|
|
||||||
## New Features (Plus Enhanced)
|
|
||||||
|
|
||||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
|
||||||
|
|
||||||
## Kiro Authentication
|
|
||||||
|
|
||||||
### CLI Login
|
|
||||||
|
|
||||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
|
||||||
|
|
||||||
**AWS Builder ID** (recommended):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Device code flow
|
|
||||||
./CLIProxyAPI --kiro-aws-login
|
|
||||||
|
|
||||||
# Authorization code flow
|
|
||||||
./CLIProxyAPI --kiro-aws-authcode
|
|
||||||
```
|
|
||||||
|
|
||||||
**Import token from Kiro IDE:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-import
|
|
||||||
```
|
|
||||||
|
|
||||||
To get a token from Kiro IDE:
|
|
||||||
|
|
||||||
1. Open Kiro IDE and login with Google (or GitHub)
|
|
||||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
|
||||||
3. Run: `./CLIProxyAPI --kiro-import`
|
|
||||||
|
|
||||||
**AWS IAM Identity Center (IDC):**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
|
||||||
|
|
||||||
# Specify region
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
|
||||||
```
|
|
||||||
|
|
||||||
**Additional flags:**
|
|
||||||
|
|
||||||
| Flag | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
|
||||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
|
||||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
|
||||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
|
||||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
|
||||||
|
|
||||||
### Web-based OAuth Login
|
|
||||||
|
|
||||||
Access the Kiro OAuth web interface at:
|
|
||||||
|
|
||||||
```
|
|
||||||
http://your-server:8080/v0/oauth/kiro
|
|
||||||
```
|
|
||||||
|
|
||||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
|
||||||
- AWS Builder ID login
|
|
||||||
- AWS Identity Center (IDC) login
|
|
||||||
- Token import from Kiro IDE
|
|
||||||
|
|
||||||
## Quick Deployment with Docker
|
|
||||||
|
|
||||||
### One-Command Deployment
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Create deployment directory
|
|
||||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
|
||||||
|
|
||||||
# Create docker-compose.yml
|
|
||||||
cat > docker-compose.yml << 'EOF'
|
|
||||||
services:
|
|
||||||
cli-proxy-api:
|
|
||||||
image: eceasy/cli-proxy-api-plus:latest
|
|
||||||
container_name: cli-proxy-api-plus
|
|
||||||
ports:
|
|
||||||
- "8317:8317"
|
|
||||||
volumes:
|
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
|
||||||
- ./auths:/root/.cli-proxy-api
|
|
||||||
- ./logs:/CLIProxyAPI/logs
|
|
||||||
restart: unless-stopped
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# Download example config
|
|
||||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
|
||||||
|
|
||||||
# Pull and start
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
|
|
||||||
Edit `config.yaml` before starting:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Basic configuration example
|
|
||||||
server:
|
|
||||||
port: 8317
|
|
||||||
|
|
||||||
# Add your provider configurations here
|
|
||||||
```
|
|
||||||
|
|
||||||
### Update to Latest Version
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/cli-proxy
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||||
|
|||||||
121
README_CN.md
121
README_CN.md
@@ -6,125 +6,6 @@
|
|||||||
|
|
||||||
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
||||||
|
|
||||||
该 Plus 版本的主线功能与主线功能强制同步。
|
|
||||||
|
|
||||||
## 与主线版本版本差异
|
|
||||||
|
|
||||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
|
||||||
|
|
||||||
## 新增功能 (Plus 增强版)
|
|
||||||
|
|
||||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
|
||||||
|
|
||||||
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
|
||||||
|
|
||||||
### 命令行登录
|
|
||||||
|
|
||||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
|
||||||
|
|
||||||
**AWS Builder ID**(推荐):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 设备码流程
|
|
||||||
./CLIProxyAPI --kiro-aws-login
|
|
||||||
|
|
||||||
# 授权码流程
|
|
||||||
./CLIProxyAPI --kiro-aws-authcode
|
|
||||||
```
|
|
||||||
|
|
||||||
**从 Kiro IDE 导入令牌:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-import
|
|
||||||
```
|
|
||||||
|
|
||||||
获取令牌步骤:
|
|
||||||
|
|
||||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
|
||||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
|
||||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
|
||||||
|
|
||||||
**AWS IAM Identity Center (IDC):**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
|
||||||
|
|
||||||
# 指定区域
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
|
||||||
```
|
|
||||||
|
|
||||||
**附加参数:**
|
|
||||||
|
|
||||||
| 参数 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
|
||||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
|
||||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
|
||||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
|
||||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
|
||||||
|
|
||||||
### 网页端 OAuth 登录
|
|
||||||
|
|
||||||
访问 Kiro OAuth 网页认证界面:
|
|
||||||
|
|
||||||
```
|
|
||||||
http://your-server:8080/v0/oauth/kiro
|
|
||||||
```
|
|
||||||
|
|
||||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
|
||||||
- AWS Builder ID 登录
|
|
||||||
- AWS Identity Center (IDC) 登录
|
|
||||||
- 从 Kiro IDE 导入令牌
|
|
||||||
|
|
||||||
## Docker 快速部署
|
|
||||||
|
|
||||||
### 一键部署
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 创建部署目录
|
|
||||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
|
||||||
|
|
||||||
# 创建 docker-compose.yml
|
|
||||||
cat > docker-compose.yml << 'EOF'
|
|
||||||
services:
|
|
||||||
cli-proxy-api:
|
|
||||||
image: eceasy/cli-proxy-api-plus:latest
|
|
||||||
container_name: cli-proxy-api-plus
|
|
||||||
ports:
|
|
||||||
- "8317:8317"
|
|
||||||
volumes:
|
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
|
||||||
- ./auths:/root/.cli-proxy-api
|
|
||||||
- ./logs:/CLIProxyAPI/logs
|
|
||||||
restart: unless-stopped
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# 下载示例配置
|
|
||||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
|
||||||
|
|
||||||
# 拉取并启动
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### 配置说明
|
|
||||||
|
|
||||||
启动前请编辑 `config.yaml`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# 基本配置示例
|
|
||||||
server:
|
|
||||||
port: 8317
|
|
||||||
|
|
||||||
# 在此添加你的供应商配置
|
|
||||||
```
|
|
||||||
|
|
||||||
### 更新到最新版本
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/cli-proxy
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||||
@@ -133,4 +14,4 @@ docker compose pull && docker compose up -d
|
|||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||||
|
|||||||
BIN
assets/bmoplus.png
Normal file
BIN
assets/bmoplus.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
BIN
assets/lingtrue.png
Normal file
BIN
assets/lingtrue.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 129 KiB |
276
cmd/fetch_antigravity_models/main.go
Normal file
276
cmd/fetch_antigravity_models/main.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
// Command fetch_antigravity_models connects to the Antigravity API using the
|
||||||
|
// stored auth credentials and saves the dynamically fetched model list to a
|
||||||
|
// JSON file for inspection or offline use.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// go run ./cmd/fetch_antigravity_models [flags]
|
||||||
|
//
|
||||||
|
// Flags:
|
||||||
|
//
|
||||||
|
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
|
||||||
|
// --output <path> Output JSON file path (default: "antigravity_models.json")
|
||||||
|
// --pretty Pretty-print the output JSON (default: true)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
|
||||||
|
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
|
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||||
|
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logging.SetupBaseLogger()
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelOutput wraps the fetched model list with fetch metadata.
|
||||||
|
type modelOutput struct {
|
||||||
|
Models []modelEntry `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEntry contains only the fields we want to keep for static model definitions.
|
||||||
|
type modelEntry struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
ContextLength int `json:"context_length,omitempty"`
|
||||||
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var authsDir string
|
||||||
|
var outputPath string
|
||||||
|
var pretty bool
|
||||||
|
|
||||||
|
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
|
||||||
|
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
|
||||||
|
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Resolve relative paths against the working directory.
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(authsDir) {
|
||||||
|
authsDir = filepath.Join(wd, authsDir)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(outputPath) {
|
||||||
|
outputPath = filepath.Join(wd, outputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Scanning auth files in: %s\n", authsDir)
|
||||||
|
|
||||||
|
// Load all auth records from the directory.
|
||||||
|
fileStore := sdkauth.NewFileTokenStore()
|
||||||
|
fileStore.SetBaseDir(authsDir)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
auths, err := fileStore.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if len(auths) == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first enabled antigravity auth.
|
||||||
|
var chosen *coreauth.Auth
|
||||||
|
for _, a := range auths {
|
||||||
|
if a == nil || a.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
|
||||||
|
chosen = a
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if chosen == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
|
||||||
|
|
||||||
|
// Fetch models from the upstream Antigravity API.
|
||||||
|
fmt.Println("Fetching Antigravity model list from upstream...")
|
||||||
|
|
||||||
|
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
models := fetchModels(fetchCtx, chosen)
|
||||||
|
if len(models) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Fetched %d models.\n", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the output payload.
|
||||||
|
out := modelOutput{
|
||||||
|
Models: models,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal to JSON.
|
||||||
|
var raw []byte
|
||||||
|
if pretty {
|
||||||
|
raw, err = json.MarshalIndent(out, "", " ")
|
||||||
|
} else {
|
||||||
|
raw, err = json.Marshal(out)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Model list saved to: %s\n", outputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
||||||
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
|
||||||
|
|
||||||
|
for _, baseURL := range baseURLs {
|
||||||
|
modelsURL := baseURL + antigravityModelsPath
|
||||||
|
|
||||||
|
var payload []byte
|
||||||
|
if auth != nil && auth.Metadata != nil {
|
||||||
|
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||||
|
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
payload = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
|
||||||
|
if errReq != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
httpReq.Close = true
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent())
|
||||||
|
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||||
|
httpClient.Transport = transport
|
||||||
|
}
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
httpResp.Body.Close()
|
||||||
|
if errRead != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := gjson.GetBytes(bodyBytes, "models")
|
||||||
|
if !result.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []modelEntry
|
||||||
|
|
||||||
|
for originalName, modelData := range result.Map() {
|
||||||
|
modelID := strings.TrimSpace(originalName)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Skip internal/experimental models
|
||||||
|
switch modelID {
|
||||||
|
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
displayName := modelData.Get("displayName").String()
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = modelID
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := modelEntry{
|
||||||
|
ID: modelID,
|
||||||
|
Object: "model",
|
||||||
|
OwnedBy: "antigravity",
|
||||||
|
Type: "antigravity",
|
||||||
|
DisplayName: displayName,
|
||||||
|
Name: modelID,
|
||||||
|
Description: displayName,
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||||
|
entry.ContextLength = int(maxTok)
|
||||||
|
}
|
||||||
|
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||||
|
entry.MaxCompletionTokens = int(maxOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
models = append(models, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func metaStringValue(m map[string]interface{}, key string) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v, ok := m[key]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
20
cmd/mcpdebug/main.go
Normal file
20
cmd/mcpdebug/main.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Encode MCP result with empty execId
|
||||||
|
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
|
||||||
|
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
|
||||||
|
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
|
||||||
|
|
||||||
|
// Write to file for analysis
|
||||||
|
os.WriteFile("mcp_result.bin", resultBytes)
|
||||||
|
fmt.Println("Wrote mcp_result.bin")
|
||||||
|
}
|
||||||
32
cmd/protocheck/main.go
Normal file
32
cmd/protocheck/main.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||||
|
|
||||||
|
// Try different field names
|
||||||
|
names := []string{
|
||||||
|
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||||
|
"shell_result", "shellResult",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range names {
|
||||||
|
fd := ecm.Descriptor().Fields().ByName(name)
|
||||||
|
if fd != nil {
|
||||||
|
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List all fields
|
||||||
|
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||||
|
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||||
|
f := ecm.Descriptor().Fields().Get(i)
|
||||||
|
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||||
@@ -74,14 +75,16 @@ func main() {
|
|||||||
var codexLogin bool
|
var codexLogin bool
|
||||||
var codexDeviceLogin bool
|
var codexDeviceLogin bool
|
||||||
var claudeLogin bool
|
var claudeLogin bool
|
||||||
var qwenLogin bool
|
|
||||||
var kiloLogin bool
|
var kiloLogin bool
|
||||||
var iflowLogin bool
|
var iflowLogin bool
|
||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
|
var gitlabLogin bool
|
||||||
|
var gitlabTokenLogin bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
var oauthCallbackPort int
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
var kimiLogin bool
|
var kimiLogin bool
|
||||||
|
var cursorLogin bool
|
||||||
var kiroLogin bool
|
var kiroLogin bool
|
||||||
var kiroGoogleLogin bool
|
var kiroGoogleLogin bool
|
||||||
var kiroAWSLogin bool
|
var kiroAWSLogin bool
|
||||||
@@ -92,30 +95,35 @@ func main() {
|
|||||||
var kiroIDCRegion string
|
var kiroIDCRegion string
|
||||||
var kiroIDCFlow string
|
var kiroIDCFlow string
|
||||||
var githubCopilotLogin bool
|
var githubCopilotLogin bool
|
||||||
|
var codeBuddyLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
|
var vertexImportPrefix string
|
||||||
var configPath string
|
var configPath string
|
||||||
var password string
|
var password string
|
||||||
var tuiMode bool
|
var tuiMode bool
|
||||||
var standalone bool
|
var standalone bool
|
||||||
var noIncognito bool
|
var noIncognito bool
|
||||||
var useIncognito bool
|
var useIncognito bool
|
||||||
|
var localModel bool
|
||||||
|
|
||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
|
||||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
|
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
|
||||||
|
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||||
|
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
|
||||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
@@ -126,12 +134,15 @@ func main() {
|
|||||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||||
|
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
|
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
||||||
flag.StringVar(&password, "password", "", "")
|
flag.StringVar(&password, "password", "", "")
|
||||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||||
|
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||||
|
|
||||||
flag.CommandLine.Usage = func() {
|
flag.CommandLine.Usage = func() {
|
||||||
out := flag.CommandLine.Output()
|
out := flag.CommandLine.Output()
|
||||||
@@ -177,6 +188,7 @@ func main() {
|
|||||||
gitStoreRemoteURL string
|
gitStoreRemoteURL string
|
||||||
gitStoreUser string
|
gitStoreUser string
|
||||||
gitStorePassword string
|
gitStorePassword string
|
||||||
|
gitStoreBranch string
|
||||||
gitStoreLocalPath string
|
gitStoreLocalPath string
|
||||||
gitStoreInst *store.GitTokenStore
|
gitStoreInst *store.GitTokenStore
|
||||||
gitStoreRoot string
|
gitStoreRoot string
|
||||||
@@ -246,6 +258,9 @@ func main() {
|
|||||||
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
||||||
gitStoreLocalPath = value
|
gitStoreLocalPath = value
|
||||||
}
|
}
|
||||||
|
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
|
||||||
|
gitStoreBranch = value
|
||||||
|
}
|
||||||
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
||||||
useObjectStore = true
|
useObjectStore = true
|
||||||
objectStoreEndpoint = value
|
objectStoreEndpoint = value
|
||||||
@@ -380,7 +395,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
||||||
authDir := filepath.Join(gitStoreRoot, "auths")
|
authDir := filepath.Join(gitStoreRoot, "auths")
|
||||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
|
||||||
gitStoreInst.SetBaseDir(authDir)
|
gitStoreInst.SetBaseDir(authDir)
|
||||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||||
log.Errorf("failed to prepare git token store: %v", errRepo)
|
log.Errorf("failed to prepare git token store: %v", errRepo)
|
||||||
@@ -499,7 +514,7 @@ func main() {
|
|||||||
|
|
||||||
if vertexImport != "" {
|
if vertexImport != "" {
|
||||||
// Handle Vertex service account import
|
// Handle Vertex service account import
|
||||||
cmd.DoVertexImport(cfg, vertexImport)
|
cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix)
|
||||||
} else if login {
|
} else if login {
|
||||||
// Handle Google/Gemini login
|
// Handle Google/Gemini login
|
||||||
cmd.DoLogin(cfg, projectID, options)
|
cmd.DoLogin(cfg, projectID, options)
|
||||||
@@ -509,6 +524,9 @@ func main() {
|
|||||||
} else if githubCopilotLogin {
|
} else if githubCopilotLogin {
|
||||||
// Handle GitHub Copilot login
|
// Handle GitHub Copilot login
|
||||||
cmd.DoGitHubCopilotLogin(cfg, options)
|
cmd.DoGitHubCopilotLogin(cfg, options)
|
||||||
|
} else if codeBuddyLogin {
|
||||||
|
// Handle CodeBuddy login
|
||||||
|
cmd.DoCodeBuddyLogin(cfg, options)
|
||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
@@ -518,16 +536,20 @@ func main() {
|
|||||||
} else if claudeLogin {
|
} else if claudeLogin {
|
||||||
// Handle Claude login
|
// Handle Claude login
|
||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
} else if qwenLogin {
|
|
||||||
cmd.DoQwenLogin(cfg, options)
|
|
||||||
} else if kiloLogin {
|
} else if kiloLogin {
|
||||||
cmd.DoKiloLogin(cfg, options)
|
cmd.DoKiloLogin(cfg, options)
|
||||||
} else if iflowLogin {
|
} else if iflowLogin {
|
||||||
cmd.DoIFlowLogin(cfg, options)
|
cmd.DoIFlowLogin(cfg, options)
|
||||||
} else if iflowCookie {
|
} else if iflowCookie {
|
||||||
cmd.DoIFlowCookieAuth(cfg, options)
|
cmd.DoIFlowCookieAuth(cfg, options)
|
||||||
|
} else if gitlabLogin {
|
||||||
|
cmd.DoGitLabLogin(cfg, options)
|
||||||
|
} else if gitlabTokenLogin {
|
||||||
|
cmd.DoGitLabTokenLogin(cfg, options)
|
||||||
} else if kimiLogin {
|
} else if kimiLogin {
|
||||||
cmd.DoKimiLogin(cfg, options)
|
cmd.DoKimiLogin(cfg, options)
|
||||||
|
} else if cursorLogin {
|
||||||
|
cmd.DoCursorLogin(cfg, options)
|
||||||
} else if kiroLogin {
|
} else if kiroLogin {
|
||||||
// For Kiro auth, default to incognito mode for multi-account support
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
// Users can explicitly override with --no-incognito
|
// Users can explicitly override with --no-incognito
|
||||||
@@ -569,10 +591,17 @@ func main() {
|
|||||||
cmd.WaitForCloudDeploy()
|
cmd.WaitForCloudDeploy()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if localModel && (!tuiMode || standalone) {
|
||||||
|
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
|
||||||
|
}
|
||||||
if tuiMode {
|
if tuiMode {
|
||||||
if standalone {
|
if standalone {
|
||||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
misc.StartAntigravityVersionUpdater(context.Background())
|
||||||
|
if !localModel {
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
}
|
||||||
hook := tui.NewLogHook(2000)
|
hook := tui.NewLogHook(2000)
|
||||||
hook.SetFormatter(&logging.LogFormatter{})
|
hook.SetFormatter(&logging.LogFormatter{})
|
||||||
log.AddHook(hook)
|
log.AddHook(hook)
|
||||||
@@ -643,15 +672,19 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
misc.StartAntigravityVersionUpdater(context.Background())
|
||||||
|
if !localModel {
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.AuthDir != "" {
|
if cfg.AuthDir != "" {
|
||||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||||
defer kiro.StopGlobalRefreshManager()
|
defer kiro.StopGlobalRefreshManager()
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ remote-management:
|
|||||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||||
disable-control-panel: false
|
disable-control-panel: false
|
||||||
|
|
||||||
|
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
|
||||||
|
# When enabled, the panel is only downloaded on first access if missing, and never auto-updated afterward.
|
||||||
|
# disable-auto-update-panel: false
|
||||||
|
|
||||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||||
|
|
||||||
@@ -68,7 +72,8 @@ error-logs-max-files: 10
|
|||||||
usage-statistics-enabled: false
|
usage-statistics-enabled: false
|
||||||
|
|
||||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||||
proxy-url: ''
|
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
|
||||||
|
proxy-url: ""
|
||||||
|
|
||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
force-model-prefix: false
|
force-model-prefix: false
|
||||||
@@ -87,26 +92,54 @@ max-retry-credentials: 0
|
|||||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||||
max-retry-interval: 30
|
max-retry-interval: 30
|
||||||
|
|
||||||
|
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||||
|
disable-cooling: false
|
||||||
|
|
||||||
|
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
|
||||||
|
# When > 0, overrides the default worker count (16).
|
||||||
|
# auth-auto-refresh-workers: 16
|
||||||
|
|
||||||
# Quota exceeded behavior
|
# Quota exceeded behavior
|
||||||
quota-exceeded:
|
quota-exceeded:
|
||||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||||
|
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
|
||||||
|
|
||||||
# Routing strategy for selecting credentials when multiple match.
|
# Routing strategy for selecting credentials when multiple match.
|
||||||
routing:
|
routing:
|
||||||
strategy: 'round-robin' # round-robin (default), fill-first
|
strategy: "round-robin" # round-robin (default), fill-first
|
||||||
|
# Enable universal session-sticky routing for all clients.
|
||||||
|
# Session IDs are extracted from: X-Session-ID header, Idempotency-Key,
|
||||||
|
# metadata.user_id, conversation_id, or first few messages hash.
|
||||||
|
# Automatic failover is always enabled when bound auth becomes unavailable.
|
||||||
|
session-affinity: false # default: false
|
||||||
|
# How long session-to-auth bindings are retained. Default: 1h
|
||||||
|
session-affinity-ttl: "1h"
|
||||||
|
|
||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# When true, enable Gemini CLI internal endpoints (/v1internal:*).
|
||||||
|
# Default is false for safety.
|
||||||
|
enable-gemini-cli-endpoint: false
|
||||||
|
|
||||||
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||||
nonstream-keepalive-interval: 0
|
nonstream-keepalive-interval: 0
|
||||||
|
|
||||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||||
# streaming:
|
# streaming:
|
||||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||||
|
|
||||||
|
# Signature cache validation for thinking blocks (Antigravity/Claude).
|
||||||
|
# When true (default), cached signatures are preferred and validated.
|
||||||
|
# When false, client signatures are used directly after normalization (bypass mode for testing).
|
||||||
|
# antigravity-signature-cache-enabled: true
|
||||||
|
|
||||||
|
# Bypass mode signature validation strictness (only applies when signature cache is disabled).
|
||||||
|
# When true, validates full Claude protobuf tree (Field 2 -> Field 1 structure).
|
||||||
|
# When false (default), only checks R/E prefix + base64 + first byte 0x12.
|
||||||
|
# antigravity-signature-bypass-strict: false
|
||||||
|
|
||||||
# Gemini API keys
|
# Gemini API keys
|
||||||
# gemini-api-key:
|
# gemini-api-key:
|
||||||
# - api-key: "AIzaSy...01"
|
# - api-key: "AIzaSy...01"
|
||||||
@@ -115,6 +148,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080"
|
# proxy-url: "socks5://proxy.example.com:1080"
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gemini-2.5-flash" # upstream model name
|
# - name: "gemini-2.5-flash" # upstream model name
|
||||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||||
@@ -133,6 +167,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gpt-5-codex" # upstream model name
|
# - name: "gpt-5-codex" # upstream model name
|
||||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||||
@@ -151,6 +186,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||||
@@ -169,14 +205,31 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "API"
|
# - "API"
|
||||||
# - "proxy"
|
# - "proxy"
|
||||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||||
|
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
|
||||||
|
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
|
||||||
|
|
||||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||||
# These are used as fallbacks when the client does not send its own headers.
|
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||||
|
# when the client omits them, while OS/arch remain runtime-derived. When
|
||||||
|
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
|
||||||
|
# while user-agent/package-version/runtime-version seed a software fingerprint that can
|
||||||
|
# still upgrade to newer official Claude client versions.
|
||||||
# claude-header-defaults:
|
# claude-header-defaults:
|
||||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||||
# package-version: "0.74.0"
|
# package-version: "0.74.0"
|
||||||
# runtime-version: "v24.3.0"
|
# runtime-version: "v24.3.0"
|
||||||
|
# os: "MacOS"
|
||||||
|
# arch: "arm64"
|
||||||
# timeout: "600"
|
# timeout: "600"
|
||||||
|
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
|
||||||
|
|
||||||
|
# Default headers for Codex OAuth model requests.
|
||||||
|
# These are used only for file-backed/OAuth Codex requests when the client
|
||||||
|
# does not send the header. `user-agent` applies to HTTP and websocket requests;
|
||||||
|
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
|
||||||
|
# codex-header-defaults:
|
||||||
|
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
|
||||||
|
# beta-features: "multi_agent"
|
||||||
|
|
||||||
# Kiro (AWS CodeWhisperer) configuration
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
# Note: Kiro API currently only operates in us-east-1 region
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
@@ -215,17 +268,32 @@ nonstream-keepalive-interval: 0
|
|||||||
# api-key-entries:
|
# api-key-entries:
|
||||||
# - api-key: "sk-or-v1-...b780"
|
# - api-key: "sk-or-v1-...b780"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||||
# models: # The models supported by the provider.
|
# models: # The models supported by the provider.
|
||||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||||
# alias: "kimi-k2" # The alias used in the API.
|
# alias: "kimi-k2" # The alias used in the API.
|
||||||
|
# thinking: # optional: omit to default to levels ["low","medium","high"]
|
||||||
|
# levels: ["low", "medium", "high"]
|
||||||
|
# # You may repeat the same alias to build an internal model pool.
|
||||||
|
# # The client still sees only one alias in the model list.
|
||||||
|
# # Requests to that alias will round-robin across the upstream names below,
|
||||||
|
# # and if the chosen upstream fails before producing output, the request will
|
||||||
|
# # continue with the next upstream model in the same alias pool.
|
||||||
|
# - name: "deepseek-v3.1"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
# - name: "glm-5"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
# - name: "kimi-k2.5"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
|
||||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
|
||||||
# vertex-api-key:
|
# vertex-api-key:
|
||||||
# - api-key: "vk-123..." # x-goog-api-key header
|
# - api-key: "vk-123..." # x-goog-api-key header
|
||||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# models: # optional: map aliases to upstream model names
|
# models: # optional: map aliases to upstream model names
|
||||||
@@ -273,8 +341,12 @@ nonstream-keepalive-interval: 0
|
|||||||
|
|
||||||
# Global OAuth model name aliases (per channel)
|
# Global OAuth model name aliases (per channel)
|
||||||
# These aliases rename model IDs for both model listing and request routing.
|
# These aliases rename model IDs for both model listing and request routing.
|
||||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
|
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||||
|
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||||
|
# you select the protocol surface, but inference backend selection can still follow the resolved
|
||||||
|
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
|
||||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||||
# oauth-model-alias:
|
# oauth-model-alias:
|
||||||
# antigravity:
|
# antigravity:
|
||||||
@@ -308,12 +380,6 @@ nonstream-keepalive-interval: 0
|
|||||||
# codex:
|
# codex:
|
||||||
# - name: "gpt-5"
|
# - name: "gpt-5"
|
||||||
# alias: "g5"
|
# alias: "g5"
|
||||||
# qwen:
|
|
||||||
# - name: "qwen3-coder-plus"
|
|
||||||
# alias: "qwen-plus"
|
|
||||||
# iflow:
|
|
||||||
# - name: "glm-4.7"
|
|
||||||
# alias: "glm-god"
|
|
||||||
# kimi:
|
# kimi:
|
||||||
# - name: "kimi-k2.5"
|
# - name: "kimi-k2.5"
|
||||||
# alias: "k2.5"
|
# alias: "k2.5"
|
||||||
@@ -342,10 +408,6 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "claude-3-5-haiku-20241022"
|
# - "claude-3-5-haiku-20241022"
|
||||||
# codex:
|
# codex:
|
||||||
# - "gpt-5-codex-mini"
|
# - "gpt-5-codex-mini"
|
||||||
# qwen:
|
|
||||||
# - "vision-model"
|
|
||||||
# iflow:
|
|
||||||
# - "tstars2.0"
|
|
||||||
# kimi:
|
# kimi:
|
||||||
# - "kimi-k2-thinking"
|
# - "kimi-k2-thinking"
|
||||||
# kiro:
|
# kiro:
|
||||||
|
|||||||
@@ -109,10 +109,19 @@ wait_for_service() {
|
|||||||
sleep 2
|
sleep 2
|
||||||
}
|
}
|
||||||
|
|
||||||
if [[ "${1:-}" == "--with-usage" ]]; then
|
case "${1:-}" in
|
||||||
WITH_USAGE=true
|
"")
|
||||||
export_stats_api_secret
|
;;
|
||||||
fi
|
"--with-usage")
|
||||||
|
WITH_USAGE=true
|
||||||
|
export_stats_api_secret
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Error: unknown option '${1}'. Did you mean '--with-usage'?"
|
||||||
|
echo "Usage: ./docker-build.sh [--with-usage]"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
# --- Step 1: Choose Environment ---
|
# --- Step 1: Choose Environment ---
|
||||||
echo "Please select an option:"
|
echo "Please select an option:"
|
||||||
|
|||||||
115
docs/gitlab-duo.md
Normal file
115
docs/gitlab-duo.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# GitLab Duo guide
|
||||||
|
|
||||||
|
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
|
||||||
|
|
||||||
|
It supports:
|
||||||
|
|
||||||
|
- OAuth login
|
||||||
|
- personal access token login
|
||||||
|
- automatic refresh of GitLab `direct_access` metadata
|
||||||
|
- dynamic model discovery from GitLab metadata
|
||||||
|
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
|
||||||
|
- Claude-compatible and OpenAI-compatible downstream APIs
|
||||||
|
|
||||||
|
## What this means
|
||||||
|
|
||||||
|
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
|
||||||
|
|
||||||
|
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
|
||||||
|
|
||||||
|
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
|
||||||
|
|
||||||
|
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
|
||||||
|
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
|
||||||
|
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
|
||||||
|
|
||||||
|
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
|
||||||
|
|
||||||
|
- a stable alias: `gitlab-duo`
|
||||||
|
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
|
||||||
|
|
||||||
|
## Login
|
||||||
|
|
||||||
|
OAuth login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-login
|
||||||
|
```
|
||||||
|
|
||||||
|
PAT login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-token-login
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also provide inputs through environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GITLAB_BASE_URL=https://gitlab.com
|
||||||
|
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||||
|
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||||
|
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- OAuth requires a GitLab OAuth application.
|
||||||
|
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
|
||||||
|
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
|
||||||
|
|
||||||
|
## Using the models
|
||||||
|
|
||||||
|
After login, start CLIProxyAPI normally and point your client at the local proxy.
|
||||||
|
|
||||||
|
You can select:
|
||||||
|
|
||||||
|
- `gitlab-duo` to use the current Duo-managed model for that account
|
||||||
|
- the discovered provider model name if you want to pin it explicitly
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "gitlab-duo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
|
||||||
|
|
||||||
|
## How model freshness works
|
||||||
|
|
||||||
|
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
|
||||||
|
|
||||||
|
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
|
||||||
|
|
||||||
|
This matches GitLab's current public contract better than hardcoding model names.
|
||||||
|
|
||||||
|
## Current scope
|
||||||
|
|
||||||
|
The GitLab Duo provider now has:
|
||||||
|
|
||||||
|
- OAuth and PAT auth flows
|
||||||
|
- runtime refresh of Duo gateway credentials
|
||||||
|
- native Anthropic gateway routing
|
||||||
|
- native OpenAI/Codex gateway routing
|
||||||
|
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
|
||||||
|
|
||||||
|
Still out of scope today:
|
||||||
|
|
||||||
|
- websocket or session-specific parity beyond the current HTTP APIs
|
||||||
|
- GitLab-specific IDE features that are not exposed through the public gateway contract
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||||
|
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||||
|
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||||
115
docs/gitlab-duo_CN.md
Normal file
115
docs/gitlab-duo_CN.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# GitLab Duo 使用说明
|
||||||
|
|
||||||
|
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
|
||||||
|
|
||||||
|
当前支持:
|
||||||
|
|
||||||
|
- OAuth 登录
|
||||||
|
- personal access token 登录
|
||||||
|
- 自动刷新 GitLab `direct_access` 元数据
|
||||||
|
- 根据 GitLab 返回的元数据动态发现模型
|
||||||
|
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
|
||||||
|
- Claude 兼容与 OpenAI 兼容下游 API
|
||||||
|
|
||||||
|
## 这意味着什么
|
||||||
|
|
||||||
|
如果 GitLab Duo 返回的是 Anthropic 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
|
||||||
|
|
||||||
|
如果 GitLab Duo 返回的是 OpenAI 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
|
||||||
|
|
||||||
|
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider:
|
||||||
|
|
||||||
|
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
|
||||||
|
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
|
||||||
|
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
|
||||||
|
|
||||||
|
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
|
||||||
|
|
||||||
|
- 一个稳定别名:`gitlab-duo`
|
||||||
|
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5` 或 `gpt-5-codex`
|
||||||
|
|
||||||
|
## 登录
|
||||||
|
|
||||||
|
OAuth 登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-login
|
||||||
|
```
|
||||||
|
|
||||||
|
PAT 登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-token-login
|
||||||
|
```
|
||||||
|
|
||||||
|
也可以通过环境变量提供输入:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GITLAB_BASE_URL=https://gitlab.com
|
||||||
|
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||||
|
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||||
|
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
|
||||||
|
- OAuth 方式需要一个 GitLab OAuth application。
|
||||||
|
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上,`api` scope 是最稳妥的基线。
|
||||||
|
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
|
||||||
|
|
||||||
|
## 如何使用模型
|
||||||
|
|
||||||
|
登录完成后,正常启动 CLIProxyAPI,并让客户端连接到本地代理。
|
||||||
|
|
||||||
|
你可以选择:
|
||||||
|
|
||||||
|
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
|
||||||
|
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "gitlab-duo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
如果该 GitLab 账号当前绑定的是 Anthropic 模型,Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型,OpenAI 兼容客户端可以通过 `/v1/chat/completions` 或 `/v1/responses` 使用它。
|
||||||
|
|
||||||
|
## 模型如何保持最新
|
||||||
|
|
||||||
|
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
|
||||||
|
|
||||||
|
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
|
||||||
|
|
||||||
|
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
|
||||||
|
|
||||||
|
## 当前覆盖范围
|
||||||
|
|
||||||
|
GitLab Duo Provider 目前已经具备:
|
||||||
|
|
||||||
|
- OAuth 和 PAT 登录流程
|
||||||
|
- Duo gateway 凭据的运行时刷新
|
||||||
|
- Anthropic gateway 原生路由
|
||||||
|
- OpenAI/Codex gateway 原生路由
|
||||||
|
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
|
||||||
|
|
||||||
|
当前仍未覆盖:
|
||||||
|
|
||||||
|
- websocket 或 session 级别的完全对齐
|
||||||
|
- GitLab 公开 gateway 契约之外的 IDE 专有能力
|
||||||
|
|
||||||
|
## 参考资料
|
||||||
|
|
||||||
|
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||||
|
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||||
|
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||||
@@ -52,11 +52,11 @@ func init() {
|
|||||||
sdktr.Register(fOpenAI, fMyProv,
|
sdktr.Register(fOpenAI, fMyProv,
|
||||||
func(model string, raw []byte, stream bool) []byte { return raw },
|
func(model string, raw []byte, stream bool) []byte { return raw },
|
||||||
sdktr.ResponseTransform{
|
sdktr.ResponseTransform{
|
||||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
|
||||||
return []string{string(raw)}
|
return [][]byte{raw}
|
||||||
},
|
},
|
||||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
|
||||||
return string(raw)
|
return raw
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -83,6 +83,7 @@ require (
|
|||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/muesli/termenv v0.16.0 // indirect
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
|
github.com/pierrec/xxHash v0.1.5
|
||||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/rs/xid v1.5.0 // indirect
|
github.com/rs/xid v1.5.0 // indirect
|
||||||
@@ -91,8 +92,8 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
|||||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
|
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
|
||||||
|
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
|
||||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,13 +13,13 @@ import (
|
|||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/oauth2/google"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultAPICallTimeout = 60 * time.Second
|
const defaultAPICallTimeout = 60 * time.Second
|
||||||
@@ -702,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||||
}
|
}
|
||||||
|
if h != nil && h.cfg != nil {
|
||||||
|
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
|
||||||
|
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if h != nil && h.cfg != nil {
|
if h != nil && h.cfg != nil {
|
||||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||||
@@ -724,50 +728,132 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
type apiKeyConfigEntry interface {
|
||||||
proxyStr = strings.TrimSpace(proxyStr)
|
GetAPIKey() string
|
||||||
if proxyStr == "" {
|
GetBaseURL() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
|
||||||
|
if auth == nil || len(entries) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
attrKey, attrBase := "", ""
|
||||||
proxyURL, errParse := url.Parse(proxyStr)
|
if auth.Attributes != nil {
|
||||||
if errParse != nil {
|
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
for i := range entries {
|
||||||
log.Debug("proxy URL missing scheme/host")
|
entry := &entries[i]
|
||||||
return nil
|
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
|
||||||
}
|
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
|
||||||
|
if attrKey != "" && attrBase != "" {
|
||||||
if proxyURL.Scheme == "socks5" {
|
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||||
var proxyAuth *proxy.Auth
|
return entry
|
||||||
if proxyURL.User != nil {
|
}
|
||||||
username := proxyURL.User.Username()
|
continue
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
}
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||||
if errSOCKS5 != nil {
|
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
return entry
|
||||||
return nil
|
}
|
||||||
}
|
}
|
||||||
return &http.Transport{
|
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||||
Proxy: nil,
|
return entry
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if attrKey != "" {
|
||||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
for i := range entries {
|
||||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
entry := &entries[i]
|
||||||
|
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
authKind, authAccount := auth.AccountInfo()
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := auth.Attributes
|
||||||
|
compatName := ""
|
||||||
|
providerKey := ""
|
||||||
|
if len(attrs) > 0 {
|
||||||
|
compatName = strings.TrimSpace(attrs["compat_name"])
|
||||||
|
providerKey = strings.TrimSpace(attrs["provider_key"])
|
||||||
|
}
|
||||||
|
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||||
|
return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
|
||||||
|
case "gemini":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
case "claude":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
case "codex":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
apiKey = strings.TrimSpace(apiKey)
|
||||||
|
if apiKey == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
candidates := make([]string, 0, 3)
|
||||||
|
if v := strings.TrimSpace(compatName); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(providerKey); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(auth.Provider); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cfg.OpenAICompatibility {
|
||||||
|
compat := &cfg.OpenAICompatibility[i]
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
||||||
|
for j := range compat.APIKeyEntries {
|
||||||
|
entry := &compat.APIKeyEntries[j]
|
||||||
|
if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||||
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||||
|
if errBuild != nil {
|
||||||
|
log.WithError(errBuild).Debug("build proxy transport failed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||||
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
||||||
if len(headers) == 0 {
|
if len(headers) == 0 {
|
||||||
|
|||||||
@@ -2,172 +2,211 @@ package management
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type memoryAuthStore struct {
|
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
mu sync.Mutex
|
t.Parallel()
|
||||||
items map[string]*coreauth.Auth
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
h := &Handler{
|
||||||
_ = ctx
|
cfg: &config.Config{
|
||||||
s.mu.Lock()
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
defer s.mu.Unlock()
|
|
||||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
|
||||||
for _, a := range s.items {
|
|
||||||
out = append(out, a.Clone())
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
|
||||||
_ = ctx
|
|
||||||
if auth == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
if s.items == nil {
|
|
||||||
s.items = make(map[string]*coreauth.Auth)
|
|
||||||
}
|
|
||||||
s.items[auth.ID] = auth.Clone()
|
|
||||||
s.mu.Unlock()
|
|
||||||
return auth.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
|
||||||
_ = ctx
|
|
||||||
s.mu.Lock()
|
|
||||||
delete(s.items, id)
|
|
||||||
s.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
t.Fatalf("expected POST, got %s", r.Method)
|
|
||||||
}
|
|
||||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
|
||||||
t.Fatalf("unexpected content-type: %s", ct)
|
|
||||||
}
|
|
||||||
bodyBytes, _ := io.ReadAll(r.Body)
|
|
||||||
_ = r.Body.Close()
|
|
||||||
values, err := url.ParseQuery(string(bodyBytes))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse form: %v", err)
|
|
||||||
}
|
|
||||||
if values.Get("grant_type") != "refresh_token" {
|
|
||||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
|
||||||
}
|
|
||||||
if values.Get("refresh_token") != "rt" {
|
|
||||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
|
||||||
}
|
|
||||||
if values.Get("client_id") != antigravityOAuthClientID {
|
|
||||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
|
||||||
}
|
|
||||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
|
||||||
t.Fatalf("unexpected client_secret")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"access_token": "new-token",
|
|
||||||
"refresh_token": "rt2",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"token_type": "Bearer",
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
store := &memoryAuthStore{}
|
|
||||||
manager := coreauth.NewManager(store, nil, nil)
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-test.json",
|
|
||||||
FileName: "antigravity-test.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "old-token",
|
|
||||||
"refresh_token": "rt",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
|
||||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
|
||||||
t.Fatalf("register auth: %v", err)
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
if httpTransport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
GeminiKey: []config.GeminiKey{{
|
||||||
|
APIKey: "gemini-key",
|
||||||
|
ProxyURL: "http://gemini-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
ClaudeKey: []config.ClaudeKey{{
|
||||||
|
APIKey: "claude-key",
|
||||||
|
ProxyURL: "http://claude-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
CodexKey: []config.CodexKey{{
|
||||||
|
APIKey: "codex-key",
|
||||||
|
ProxyURL: "http://codex-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
OpenAICompatibility: []config.OpenAICompatibility{{
|
||||||
|
Name: "bohe",
|
||||||
|
BaseURL: "https://bohe.example.com",
|
||||||
|
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
|
||||||
|
APIKey: "compat-key",
|
||||||
|
ProxyURL: "http://compat-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth *coreauth.Auth
|
||||||
|
wantProxy string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "gemini",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "gemini",
|
||||||
|
Attributes: map[string]string{"api_key": "gemini-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://gemini-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claude",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{"api_key": "claude-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://claude-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "codex",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Attributes: map[string]string{"api_key": "codex-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://codex-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai-compatibility",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "bohe",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "compat-key",
|
||||||
|
"compat_name": "bohe",
|
||||||
|
"provider_key": "bohe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantProxy: "http://compat-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
transport := h.apiCallTransport(tc.auth)
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != tc.wantProxy {
|
||||||
|
t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
geminiAuth := &coreauth.Auth{
|
||||||
|
ID: "gemini:apikey:123",
|
||||||
|
Provider: "gemini",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "shared-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
compatAuth := &coreauth.Auth{
|
||||||
|
ID: "openai-compatibility:bohe:456",
|
||||||
|
Provider: "bohe",
|
||||||
|
Label: "bohe",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "shared-key",
|
||||||
|
"compat_name": "bohe",
|
||||||
|
"provider_key": "bohe",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register gemini auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register compat auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiIndex := geminiAuth.EnsureIndex()
|
||||||
|
compatIndex := compatAuth.EnsureIndex()
|
||||||
|
if geminiIndex == compatIndex {
|
||||||
|
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &Handler{authManager: manager}
|
h := &Handler{authManager: manager}
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
|
||||||
if err != nil {
|
gotGemini := h.authByIndex(geminiIndex)
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
if gotGemini == nil {
|
||||||
|
t.Fatal("expected gemini auth by index")
|
||||||
}
|
}
|
||||||
if token != "new-token" {
|
if gotGemini.ID != geminiAuth.ID {
|
||||||
t.Fatalf("expected refreshed token, got %q", token)
|
t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
|
||||||
}
|
|
||||||
if callCount != 1 {
|
|
||||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, ok := manager.GetByID(auth.ID)
|
gotCompat := h.authByIndex(compatIndex)
|
||||||
if !ok || updated == nil {
|
if gotCompat == nil {
|
||||||
t.Fatalf("expected auth in manager after update")
|
t.Fatal("expected compat auth by index")
|
||||||
}
|
}
|
||||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
if gotCompat.ID != compatAuth.ID {
|
||||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-valid.json",
|
|
||||||
FileName: "antigravity-valid.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "ok-token",
|
|
||||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
h := &Handler{}
|
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
|
||||||
}
|
|
||||||
if token != "ok-token" {
|
|
||||||
t.Fatalf("expected existing token, got %q", token)
|
|
||||||
}
|
|
||||||
if callCount != 0 {
|
|
||||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
|
||||||
|
files := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
|
||||||
|
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
for _, file := range files {
|
||||||
|
part, err := writer.CreateFormFile("file", file.name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create multipart file: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||||
|
t.Fatalf("failed to write multipart content: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close multipart writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.UploadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
|
||||||
|
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
fullPath := filepath.Join(authDir, file.name)
|
||||||
|
data, err := os.ReadFile(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
|
||||||
|
}
|
||||||
|
if string(data) != file.content {
|
||||||
|
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auths := manager.List()
|
||||||
|
if len(auths) != len(files) {
|
||||||
|
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
|
||||||
|
existingName := "alpha.json"
|
||||||
|
existingContent := `{"type":"codex","email":"alpha@example.com"}`
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to seed existing auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
files := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{name: existingName, content: `{"type":"codex"`},
|
||||||
|
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
for _, file := range files {
|
||||||
|
part, err := writer.CreateFormFile("file", file.name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create multipart file: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||||
|
t.Fatalf("failed to write multipart content: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close multipart writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.UploadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusMultiStatus {
|
||||||
|
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(authDir, existingName))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected existing auth file to remain readable: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != existingContent {
|
||||||
|
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected valid auth file to be created: %v", err)
|
||||||
|
}
|
||||||
|
if string(betaData) != files[1].content {
|
||||||
|
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
files := []string{"alpha.json", "beta.json"}
|
||||||
|
for _, name := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(
|
||||||
|
http.MethodDelete,
|
||||||
|
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.DeleteAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
|
||||||
|
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range files {
|
||||||
|
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
fileName := "download-user.json"
|
||||||
|
expected := []byte(`{"type":"codex"}`)
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := rec.Body.Bytes(); string(got) != string(expected) {
|
||||||
|
t.Fatalf("unexpected download content: %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
|
||||||
|
|
||||||
|
for _, name := range []string{
|
||||||
|
"../external/secret.json",
|
||||||
|
`..\\external\\secret.json`,
|
||||||
|
"nested/secret.json",
|
||||||
|
`nested\\secret.json`,
|
||||||
|
} {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tempDir, "auth")
|
||||||
|
externalDir := filepath.Join(tempDir, "external")
|
||||||
|
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(externalDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("failed to create external dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secretName := "secret.json"
|
||||||
|
secretPath := filepath.Join(externalDir, secretName)
|
||||||
|
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write external file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(
|
||||||
|
http.MethodGet,
|
||||||
|
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
|
||||||
|
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v4/user":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": 42,
|
||||||
|
"username": "gitlab-user",
|
||||||
|
"name": "GitLab User",
|
||||||
|
"email": "gitlab@example.com",
|
||||||
|
})
|
||||||
|
case "/api/v4/personal_access_tokens/self":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": 7,
|
||||||
|
"name": "management-center",
|
||||||
|
"scopes": []string{"api", "read_user"},
|
||||||
|
"user_id": 42,
|
||||||
|
})
|
||||||
|
case "/api/v4/code_suggestions/direct_access":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"base_url": "https://cloud.gitlab.example.com",
|
||||||
|
"token": "gateway-token",
|
||||||
|
"expires_at": 1893456000,
|
||||||
|
"headers": map[string]string{
|
||||||
|
"X-Gitlab-Realm": "saas",
|
||||||
|
},
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
|
||||||
|
h.tokenStore = store
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.RequestGitLabPATToken(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got := resp["status"]; got != "ok" {
|
||||||
|
t.Fatalf("status = %#v, want ok", got)
|
||||||
|
}
|
||||||
|
if got := resp["model_provider"]; got != "anthropic" {
|
||||||
|
t.Fatalf("model_provider = %#v, want anthropic", got)
|
||||||
|
}
|
||||||
|
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
store.mu.Lock()
|
||||||
|
defer store.mu.Unlock()
|
||||||
|
if len(store.items) != 1 {
|
||||||
|
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
|
||||||
|
}
|
||||||
|
var saved *coreauth.Auth
|
||||||
|
for _, item := range store.items {
|
||||||
|
saved = item
|
||||||
|
}
|
||||||
|
if saved == nil {
|
||||||
|
t.Fatal("expected saved auth record")
|
||||||
|
}
|
||||||
|
if saved.Provider != "gitlab" {
|
||||||
|
t.Fatalf("provider = %q, want gitlab", saved.Provider)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||||
|
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["model_provider"]; got != "anthropic" {
|
||||||
|
t.Fatalf("saved model_provider = %#v, want anthropic", got)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
|
||||||
|
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
state := "gitlab-state-123"
|
||||||
|
RegisterOAuthSession(state, "gitlab")
|
||||||
|
t.Cleanup(func() { CompleteOAuthSession(state) })
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.PostOAuthCallback(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
|
||||||
|
data, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read callback file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]string
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
t.Fatalf("decode callback payload: %v", err)
|
||||||
|
}
|
||||||
|
if got := payload["code"]; got != "test-code" {
|
||||||
|
t.Fatalf("callback code = %q, want test-code", got)
|
||||||
|
}
|
||||||
|
if got := payload["state"]; got != state {
|
||||||
|
t.Fatalf("callback state = %q, want %q", got, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
|
||||||
|
provider, err := NormalizeOAuthProvider("gitlab")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
|
||||||
|
}
|
||||||
|
if provider != "gitlab" {
|
||||||
|
t.Fatalf("provider = %q, want gitlab", provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "test.json",
|
||||||
|
FileName: "test.json",
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": "/tmp/test.json",
|
||||||
|
"header:X-Old": "old",
|
||||||
|
"header:X-Remove": "gone",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "claude",
|
||||||
|
"headers": map[string]any{
|
||||||
|
"X-Old": "old",
|
||||||
|
"X-Remove": "gone",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||||
|
|
||||||
|
body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
h.PatchAuthFileFields(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID("test.json")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth record to exist after patch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated.Prefix != "p1" {
|
||||||
|
t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
|
||||||
|
}
|
||||||
|
if updated.ProxyURL != "http://proxy.local" {
|
||||||
|
t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated.Metadata == nil {
|
||||||
|
t.Fatalf("expected metadata to be non-nil")
|
||||||
|
}
|
||||||
|
if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
|
||||||
|
t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
|
||||||
|
}
|
||||||
|
if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
|
||||||
|
t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
|
||||||
|
}
|
||||||
|
|
||||||
|
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
raw, _ := json.Marshal(updated.Metadata["headers"])
|
||||||
|
t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-Old"]; got != "new" {
|
||||||
|
t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-New"]; got != "v" {
|
||||||
|
t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
|
||||||
|
}
|
||||||
|
if _, ok := headersMeta["X-Remove"]; ok {
|
||||||
|
t.Fatalf("expected metadata.headers.X-Remove to be deleted")
|
||||||
|
}
|
||||||
|
if _, ok := headersMeta["X-Nope"]; ok {
|
||||||
|
t.Fatalf("expected metadata.headers.X-Nope to be absent")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := updated.Attributes["header:X-Old"]; got != "new" {
|
||||||
|
t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
|
||||||
|
}
|
||||||
|
if got := updated.Attributes["header:X-New"]; got != "v" {
|
||||||
|
t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
|
||||||
|
}
|
||||||
|
if _, ok := updated.Attributes["header:X-Remove"]; ok {
|
||||||
|
t.Fatalf("expected attrs header:X-Remove to be deleted")
|
||||||
|
}
|
||||||
|
if _, ok := updated.Attributes["header:X-Nope"]; ok {
|
||||||
|
t.Fatalf("expected attrs header:X-Nope to be absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "noop.json",
|
||||||
|
FileName: "noop.json",
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": "/tmp/noop.json",
|
||||||
|
"header:X-Kee": "1",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "claude",
|
||||||
|
"headers": map[string]any{
|
||||||
|
"X-Kee": "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||||
|
|
||||||
|
body := `{"name":"noop.json","note":"hello","headers":{}}`
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
h.PatchAuthFileFields(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID("noop.json")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth record to exist after patch")
|
||||||
|
}
|
||||||
|
if got := updated.Attributes["header:X-Kee"]; got != "1" {
|
||||||
|
t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
|
||||||
|
}
|
||||||
|
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-Kee"]; got != "1" {
|
||||||
|
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -214,19 +214,46 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.GeminiKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||||
|
for _, v := range h.cfg.GeminiKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
if len(out) != len(h.cfg.GeminiKey) {
|
||||||
|
h.cfg.GeminiKey = out
|
||||||
|
h.cfg.SanitizeGeminiKeys()
|
||||||
|
h.persist(c)
|
||||||
|
} else {
|
||||||
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if len(out) != len(h.cfg.GeminiKey) {
|
|
||||||
h.cfg.GeminiKey = out
|
matchIndex := -1
|
||||||
h.cfg.SanitizeGeminiKeys()
|
matchCount := 0
|
||||||
h.persist(c)
|
for i := range h.cfg.GeminiKey {
|
||||||
} else {
|
if strings.TrimSpace(h.cfg.GeminiKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount == 0 {
|
||||||
c.JSON(404, gin.H{"error": "item not found"})
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:matchIndex], h.cfg.GeminiKey[matchIndex+1:]...)
|
||||||
|
h.cfg.SanitizeGeminiKeys()
|
||||||
|
h.persist(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if idxStr := c.Query("index"); idxStr != "" {
|
if idxStr := c.Query("index"); idxStr != "" {
|
||||||
@@ -335,14 +362,39 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||||
if val := c.Query("api-key"); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.ClaudeKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||||
|
for _, v := range h.cfg.ClaudeKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.ClaudeKey = out
|
||||||
|
h.cfg.SanitizeClaudeKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.ClaudeKey {
|
||||||
|
if strings.TrimSpace(h.cfg.ClaudeKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:matchIndex], h.cfg.ClaudeKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.ClaudeKey = out
|
|
||||||
h.cfg.SanitizeClaudeKeys()
|
h.cfg.SanitizeClaudeKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -509,8 +561,12 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
for i := range arr {
|
for i := range arr {
|
||||||
normalizeVertexCompatKey(&arr[i])
|
normalizeVertexCompatKey(&arr[i])
|
||||||
|
if arr[i].APIKey == "" {
|
||||||
|
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
h.cfg.VertexCompatAPIKey = arr
|
h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
}
|
}
|
||||||
@@ -597,13 +653,38 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
||||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.VertexCompatAPIKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
||||||
|
for _, v := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.VertexCompatAPIKey = out
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if strings.TrimSpace(h.cfg.VertexCompatAPIKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:matchIndex], h.cfg.VertexCompatAPIKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.VertexCompatAPIKey = out
|
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -915,14 +996,39 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||||
if val := c.Query("api-key"); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.CodexKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||||
|
for _, v := range h.cfg.CodexKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.CodexKey = out
|
||||||
|
h.cfg.SanitizeCodexKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.CodexKey {
|
||||||
|
if strings.TrimSpace(h.cfg.CodexKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.CodexKey = append(h.cfg.CodexKey[:matchIndex], h.cfg.CodexKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.CodexKey = out
|
|
||||||
h.cfg.SanitizeCodexKeys()
|
h.cfg.SanitizeCodexKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeTestConfigFile(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if errWrite := os.WriteFile(path, []byte("{}\n"), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write test config: %v", errWrite)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGeminiKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
GeminiKey: []config.GeminiKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key", nil)
|
||||||
|
|
||||||
|
h.DeleteGeminiKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.GeminiKey); got != 2 {
|
||||||
|
t.Fatalf("gemini keys len = %d, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGeminiKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
GeminiKey: []config.GeminiKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key&base-url=https://a.example.com", nil)
|
||||||
|
|
||||||
|
h.DeleteGeminiKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.GeminiKey); got != 1 {
|
||||||
|
t.Fatalf("gemini keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.GeminiKey[0].BaseURL; got != "https://b.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://b.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteClaudeKey_DeletesEmptyBaseURLWhenExplicitlyProvided(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: ""},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://claude.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/claude-api-key?api-key=shared-key&base-url=", nil)
|
||||||
|
|
||||||
|
h.DeleteClaudeKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.ClaudeKey); got != 1 {
|
||||||
|
t.Fatalf("claude keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.ClaudeKey[0].BaseURL; got != "https://claude.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://claude.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteVertexCompatKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/vertex-api-key?api-key=shared-key&base-url=https://b.example.com", nil)
|
||||||
|
|
||||||
|
h.DeleteVertexCompatKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.VertexCompatAPIKey); got != 1 {
|
||||||
|
t.Fatalf("vertex keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.VertexCompatAPIKey[0].BaseURL; got != "https://a.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://a.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCodexKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
CodexKey: []config.CodexKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/codex-api-key?api-key=shared-key", nil)
|
||||||
|
|
||||||
|
h.DeleteCodexKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.CodexKey); got != 2 {
|
||||||
|
t.Fatalf("codex keys len = %d, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -228,14 +228,12 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
|||||||
return "anthropic", nil
|
return "anthropic", nil
|
||||||
case "codex", "openai":
|
case "codex", "openai":
|
||||||
return "codex", nil
|
return "codex", nil
|
||||||
|
case "gitlab":
|
||||||
|
return "gitlab", nil
|
||||||
case "gemini", "google":
|
case "gemini", "google":
|
||||||
return "gemini", nil
|
return "gemini", nil
|
||||||
case "iflow", "i-flow":
|
|
||||||
return "iflow", nil
|
|
||||||
case "antigravity", "anti-gravity":
|
case "antigravity", "anti-gravity":
|
||||||
return "antigravity", nil
|
return "antigravity", nil
|
||||||
case "qwen":
|
|
||||||
return "qwen", nil
|
|
||||||
case "kiro":
|
case "kiro":
|
||||||
return "kiro", nil
|
return "kiro", nil
|
||||||
case "github":
|
case "github":
|
||||||
|
|||||||
49
internal/api/handlers/management/test_store_test.go
Normal file
49
internal/api/handlers/management/test_store_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryAuthStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
items map[string]*coreauth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||||
|
for _, item := range s.items {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.items == nil {
|
||||||
|
s.items = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
s.items[auth.ID] = auth
|
||||||
|
return auth.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.items, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) SetBaseDir(string) {}
|
||||||
@@ -15,6 +15,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
|
||||||
|
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
|
||||||
|
|
||||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||||
type RequestInfo struct {
|
type RequestInfo struct {
|
||||||
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
if len(apiResponse) > 0 {
|
if len(apiResponse) > 0 {
|
||||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||||
}
|
}
|
||||||
|
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
|
||||||
|
if len(apiWebsocketTimeline) > 0 {
|
||||||
|
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
|
||||||
|
}
|
||||||
if err := w.streamWriter.Close(); err != nil {
|
if err := w.streamWriter.Close(); err != nil {
|
||||||
w.streamWriter = nil
|
w.streamWriter = nil
|
||||||
return err
|
return err
|
||||||
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||||
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
|
||||||
|
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
|
||||||
|
if !isExist {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, ok := apiTimeline.([]byte)
|
||||||
|
if !ok || len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(data)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||||
if !isExist {
|
if !isExist {
|
||||||
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||||
if c != nil {
|
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
|
||||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
return body
|
||||||
switch value := bodyOverride.(type) {
|
|
||||||
case []byte:
|
|
||||||
if len(value) > 0 {
|
|
||||||
return bytes.Clone(value)
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
if strings.TrimSpace(value) != "" {
|
|
||||||
return []byte(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
return w.requestInfo.Body
|
return w.requestInfo.Body
|
||||||
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
|
||||||
|
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if w.body == nil || w.body.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(w.body.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
|
||||||
|
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBodyOverride(c *gin.Context, key string) []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bodyOverride, isExist := c.Get(key)
|
||||||
|
if !isExist {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch value := bodyOverride.(type) {
|
||||||
|
case []byte:
|
||||||
|
if len(value) > 0 {
|
||||||
|
return bytes.Clone(value)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return []byte(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||||
if w.requestInfo == nil {
|
if w.requestInfo == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||||
}); ok {
|
}); ok {
|
||||||
return loggerWithOptions.LogRequestWithOptions(
|
return loggerWithOptions.LogRequestWithOptions(
|
||||||
w.requestInfo.URL,
|
w.requestInfo.URL,
|
||||||
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
|||||||
statusCode,
|
statusCode,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
|
websocketTimeline,
|
||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
forceLog,
|
forceLog,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
|||||||
statusCode,
|
statusCode,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
|
websocketTimeline,
|
||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
w.requestInfo.Timestamp,
|
w.requestInfo.Timestamp,
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||||
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
|||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
wrapper := &ResponseWriterWrapper{}
|
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||||
|
|
||||||
body := wrapper.extractRequestBody(c)
|
body := wrapper.extractRequestBody(c)
|
||||||
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
|||||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||||
|
wrapper.body.WriteString("original-response")
|
||||||
|
|
||||||
|
body := wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "original-response" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "original-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
|
||||||
|
body = wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "override-response" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "override-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'X'
|
||||||
|
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
|
||||||
|
t.Fatalf("response override should be cloned, got %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
|
||||||
|
|
||||||
|
body := wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "override-response-as-string" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
override := []byte("body-override")
|
||||||
|
c.Set(requestBodyOverrideContextKey, override)
|
||||||
|
|
||||||
|
body := extractBodyOverride(c, requestBodyOverrideContextKey)
|
||||||
|
if !bytes.Equal(body, override) {
|
||||||
|
t.Fatalf("body override = %q, want %q", string(body), string(override))
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'X'
|
||||||
|
if !bytes.Equal(override, []byte("body-override")) {
|
||||||
|
t.Fatalf("override mutated: %q", string(override))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
if got := wrapper.extractWebsocketTimeline(c); got != nil {
|
||||||
|
t.Fatalf("expected nil websocket timeline, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
|
||||||
|
body := wrapper.extractWebsocketTimeline(c)
|
||||||
|
if string(body) != "timeline" {
|
||||||
|
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
streamWriter := &testStreamingLogWriter{}
|
||||||
|
wrapper := &ResponseWriterWrapper{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
logger: &testRequestLogger{enabled: true},
|
||||||
|
requestInfo: &RequestInfo{
|
||||||
|
URL: "/v1/responses",
|
||||||
|
Method: "POST",
|
||||||
|
Headers: map[string][]string{"Content-Type": {"application/json"}},
|
||||||
|
RequestID: "req-1",
|
||||||
|
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
isStreaming: true,
|
||||||
|
streamWriter: streamWriter,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
|
||||||
|
|
||||||
|
if err := wrapper.Finalize(c); err != nil {
|
||||||
|
t.Fatalf("Finalize error: %v", err)
|
||||||
|
}
|
||||||
|
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
|
||||||
|
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
|
||||||
|
}
|
||||||
|
if !streamWriter.closed {
|
||||||
|
t.Fatal("expected stream writer to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRequestLogger struct {
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
|
||||||
|
return &testStreamingLogWriter{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) IsEnabled() bool {
|
||||||
|
return l.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
type testStreamingLogWriter struct {
|
||||||
|
apiWebsocketTimeline []byte
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||||
|
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) Close() error {
|
||||||
|
w.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize request body: remove thinking blocks with invalid signatures
|
||||||
|
// to prevent upstream API 400 errors
|
||||||
|
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
|
||||||
|
|
||||||
// Restore the body for the handler to read
|
// Restore the body for the handler to read
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
@@ -249,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = true
|
||||||
c.Writer = rewriter
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
@@ -259,10 +264,17 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
} else if len(providers) > 0 {
|
} else if len(providers) > 0 {
|
||||||
// Log: Using local provider (free)
|
// Log: Using local provider (free)
|
||||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
// Wrap with ResponseRewriter for local providers too, because upstream
|
||||||
|
// proxies (e.g. NewAPI) may return a different model name and lack
|
||||||
|
// Amp-required fields like thinking.signature.
|
||||||
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = providerName != "claude"
|
||||||
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
handler(c)
|
handler(c)
|
||||||
|
rewriter.Flush()
|
||||||
} else {
|
} else {
|
||||||
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|||||||
@@ -129,11 +129,11 @@ func TestModifyResponse_GzipScenarios(t *testing.T) {
|
|||||||
wantCE: "",
|
wantCE: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "skips_non_2xx_status",
|
name: "decompresses_non_2xx_status_when_gzip_detected",
|
||||||
header: http.Header{},
|
header: http.Header{},
|
||||||
body: good,
|
body: good,
|
||||||
status: 404,
|
status: 404,
|
||||||
wantBody: good,
|
wantBody: goodJSON,
|
||||||
wantCE: "",
|
wantCE: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -12,15 +14,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||||
// It's used to rewrite model names in responses when model mapping is used
|
// It is used to rewrite model names in responses when model mapping is used
|
||||||
|
// and to keep Amp-compatible response shapes.
|
||||||
type ResponseRewriter struct {
|
type ResponseRewriter struct {
|
||||||
gin.ResponseWriter
|
gin.ResponseWriter
|
||||||
body *bytes.Buffer
|
body *bytes.Buffer
|
||||||
originalModel string
|
originalModel string
|
||||||
isStreaming bool
|
isStreaming bool
|
||||||
|
suppressThinking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponseRewriter creates a new response rewriter for model name substitution
|
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
||||||
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||||
return &ResponseRewriter{
|
return &ResponseRewriter{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
@@ -33,15 +37,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
|||||||
|
|
||||||
func looksLikeSSEChunk(data []byte) bool {
|
func looksLikeSSEChunk(data []byte) bool {
|
||||||
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||||
// Heuristics are intentionally simple and cheap.
|
// We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
|
||||||
return bytes.Contains(data, []byte("data:")) ||
|
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||||
bytes.Contains(data, []byte("event:")) ||
|
trimmed := bytes.TrimSpace(line)
|
||||||
bytes.Contains(data, []byte("message_start")) ||
|
if bytes.HasPrefix(trimmed, []byte("data:")) ||
|
||||||
bytes.Contains(data, []byte("message_delta")) ||
|
bytes.HasPrefix(trimmed, []byte("event:")) {
|
||||||
bytes.Contains(data, []byte("content_block_start")) ||
|
return true
|
||||||
bytes.Contains(data, []byte("content_block_delta")) ||
|
}
|
||||||
bytes.Contains(data, []byte("content_block_stop")) ||
|
}
|
||||||
bytes.Contains(data, []byte("\n\n"))
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||||
@@ -95,7 +99,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
rewritten := rw.rewriteStreamChunk(data)
|
||||||
|
n, err := rw.ResponseWriter.Write(rewritten)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -106,7 +111,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
return rw.body.Write(data)
|
return rw.body.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush writes the buffered response with model names rewritten
|
|
||||||
func (rw *ResponseRewriter) Flush() {
|
func (rw *ResponseRewriter) Flush() {
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
@@ -115,40 +119,79 @@ func (rw *ResponseRewriter) Flush() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if rw.body.Len() > 0 {
|
if rw.body.Len() > 0 {
|
||||||
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||||
|
// Update Content-Length to match the rewritten body size, since
|
||||||
|
// signature injection and model name changes alter the payload length.
|
||||||
|
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
|
||||||
|
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
|
||||||
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// modelFieldPaths lists all JSON paths where model name may appear
|
|
||||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||||
|
|
||||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
func ensureAmpSignature(data []byte) []byte {
|
||||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
blockType := block.Get("type").String()
|
||||||
|
if blockType != "tool_use" && blockType != "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
signaturePath := fmt.Sprintf("content.%d.signature", index)
|
||||||
|
if gjson.GetBytes(data, signaturePath).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, signaturePath, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
|
||||||
|
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, "content_block.signature", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||||
|
if !rw.suppressThinking {
|
||||||
|
return data
|
||||||
|
}
|
||||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||||
if filtered.Exists() {
|
if filtered.Exists() {
|
||||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||||
filteredCount := filtered.Get("#").Int()
|
filteredCount := filtered.Get("#").Int()
|
||||||
|
|
||||||
if originalCount > filteredCount {
|
if originalCount > filteredCount {
|
||||||
var err error
|
var err error
|
||||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||||
} else {
|
|
||||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
|
||||||
// Log the result for verification
|
|
||||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
data = ensureAmpSignature(data)
|
||||||
|
data = rw.suppressAmpThinking(data)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
if rw.originalModel == "" {
|
if rw.originalModel == "" {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
@@ -160,24 +203,164 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
|
||||||
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
if rw.originalModel == "" {
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
return chunk
|
var out [][]byte
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(lines) {
|
||||||
|
line := lines[i]
|
||||||
|
trimmed := bytes.TrimSpace(line)
|
||||||
|
|
||||||
|
// Case 1: "event:" line - look ahead for its "data:" line
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("event: ")) {
|
||||||
|
// Scan forward past blank lines to find the data: line
|
||||||
|
dataIdx := -1
|
||||||
|
for j := i + 1; j < len(lines); j++ {
|
||||||
|
t := bytes.TrimSpace(lines[j])
|
||||||
|
if len(t) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(t, []byte("data: ")) {
|
||||||
|
dataIdx = j
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataIdx >= 0 {
|
||||||
|
// Found event+data pair - process through rewriter
|
||||||
|
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
|
if rewritten == nil {
|
||||||
|
i = dataIdx + 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Emit event line
|
||||||
|
out = append(out, line)
|
||||||
|
// Emit blank lines between event and data
|
||||||
|
for k := i + 1; k < dataIdx; k++ {
|
||||||
|
out = append(out, lines[k])
|
||||||
|
}
|
||||||
|
// Emit rewritten data
|
||||||
|
out = append(out, append([]byte("data: "), rewritten...))
|
||||||
|
i = dataIdx + 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No data line found (orphan event from cross-chunk split)
|
||||||
|
// Pass it through as-is - the data will arrive in the next chunk
|
||||||
|
out = append(out, line)
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: standalone "data:" line (no preceding event: in this chunk)
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
|
if rewritten != nil {
|
||||||
|
out = append(out, append([]byte("data: "), rewritten...))
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: everything else
|
||||||
|
out = append(out, line)
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSE format: "data: {json}\n\n"
|
return bytes.Join(out, []byte("\n"))
|
||||||
lines := bytes.Split(chunk, []byte("\n"))
|
}
|
||||||
for i, line := range lines {
|
|
||||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
// rewriteStreamEvent processes a single JSON event in the SSE stream.
|
||||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
// It rewrites model names and ensures signature fields exist.
|
||||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
// NOTE: streaming mode does NOT suppress thinking blocks - they are
|
||||||
// Rewrite JSON in the data line
|
// passed through with signature injection to avoid breaking SSE index
|
||||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
// alignment and TUI rendering.
|
||||||
lines[i] = append([]byte("data: "), rewritten...)
|
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||||
|
// Inject empty signature where needed
|
||||||
|
data = ensureAmpSignature(data)
|
||||||
|
|
||||||
|
// Rewrite model name
|
||||||
|
if rw.originalModel != "" {
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return bytes.Join(lines, []byte("\n"))
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
||||||
|
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
|
||||||
|
// array before forwarding to the upstream API.
|
||||||
|
// This prevents 400 errors from the API which requires valid signatures on thinking
|
||||||
|
// blocks and does not accept a signature field on tool_use blocks.
|
||||||
|
func SanitizeAmpRequestBody(body []byte) []byte {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for msgIdx, msg := range messages.Array() {
|
||||||
|
if msg.Get("role").String() != "assistant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepBlocks []interface{}
|
||||||
|
contentModified := false
|
||||||
|
|
||||||
|
for _, block := range content.Array() {
|
||||||
|
blockType := block.Get("type").String()
|
||||||
|
if blockType == "thinking" {
|
||||||
|
sig := block.Get("signature")
|
||||||
|
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
||||||
|
contentModified = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
|
||||||
|
blockRaw := []byte(block.Raw)
|
||||||
|
if blockType == "tool_use" && block.Get("signature").Exists() {
|
||||||
|
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
|
||||||
|
contentModified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
|
||||||
|
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentModified {
|
||||||
|
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
||||||
|
var err error
|
||||||
|
if len(keepBlocks) == 0 {
|
||||||
|
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
|
||||||
|
} else {
|
||||||
|
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if modified {
|
||||||
|
log.Debugf("Amp RequestSanitizer: sanitized request body")
|
||||||
|
}
|
||||||
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -100,6 +101,80 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{}
|
||||||
|
|
||||||
|
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
|
||||||
|
result := rw.rewriteStreamChunk(chunk)
|
||||||
|
|
||||||
|
// Streaming mode preserves thinking blocks (does NOT suppress them)
|
||||||
|
// to avoid breaking SSE index alignment and TUI rendering
|
||||||
|
if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
|
||||||
|
t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
|
||||||
|
t.Fatalf("expected thinking_delta to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
|
||||||
|
t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
// Signature should be injected into both thinking and tool_use blocks
|
||||||
|
if count := strings.Count(string(result), `"signature":""`); count != 2 {
|
||||||
|
t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-whitespace")) {
|
||||||
|
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte("drop-number")) {
|
||||||
|
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte("keep-valid")) {
|
||||||
|
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte("keep-text")) {
|
||||||
|
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte(`"signature":""`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"valid-sig"`)) {
|
||||||
|
t.Fatalf("expected thinking signature to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-me")) {
|
||||||
|
t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte(`"signature"`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func contains(data, substr []byte) bool {
|
func contains(data, substr []byte) bool {
|
||||||
for i := 0; i <= len(data)-len(substr); i++ {
|
for i := 0; i <= len(data)-len(substr); i++ {
|
||||||
if string(data[i:i+len(substr)]) == string(substr) {
|
if string(data[i:i+len(substr)]) == string(substr) {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
@@ -262,6 +263,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
|
applySignatureCacheConfig(nil, cfg)
|
||||||
// Initialize management handler
|
// Initialize management handler
|
||||||
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
||||||
if optionState.localPassword != "" {
|
if optionState.localPassword != "" {
|
||||||
@@ -323,6 +325,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
// setupRoutes configures the API routes for the server.
|
// setupRoutes configures the API routes for the server.
|
||||||
// It defines the endpoints and associates them with their respective handlers.
|
// It defines the endpoints and associates them with their respective handlers.
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
|
s.engine.GET("/healthz", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
||||||
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
||||||
@@ -403,6 +409,20 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if errStr == "" {
|
||||||
|
errStr = c.Query("error_description")
|
||||||
|
}
|
||||||
|
if state != "" {
|
||||||
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
|
})
|
||||||
|
|
||||||
s.engine.GET("/google/callback", func(c *gin.Context) {
|
s.engine.GET("/google/callback", func(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
@@ -417,20 +437,6 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.engine.GET("/iflow/callback", func(c *gin.Context) {
|
|
||||||
code := c.Query("code")
|
|
||||||
state := c.Query("state")
|
|
||||||
errStr := c.Query("error")
|
|
||||||
if errStr == "" {
|
|
||||||
errStr = c.Query("error_description")
|
|
||||||
}
|
|
||||||
if state != "" {
|
|
||||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
|
||||||
}
|
|
||||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
||||||
})
|
|
||||||
|
|
||||||
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
|
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
@@ -555,6 +561,8 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||||
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||||
|
|
||||||
|
mgmt.GET("/copilot-quota", s.mgmt.GetCopilotQuota)
|
||||||
|
|
||||||
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
||||||
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
||||||
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
||||||
@@ -658,19 +666,21 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
|
|
||||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||||
|
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
|
||||||
|
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
|
||||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
|
||||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
|
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
@@ -943,6 +953,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applySignatureCacheConfig(oldCfg, cfg)
|
||||||
|
|
||||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||||
}
|
}
|
||||||
@@ -1081,3 +1093,37 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
|||||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func configuredSignatureCacheEnabled(cfg *config.Config) bool {
|
||||||
|
if cfg != nil && cfg.AntigravitySignatureCacheEnabled != nil {
|
||||||
|
return *cfg.AntigravitySignatureCacheEnabled
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func applySignatureCacheConfig(oldCfg, cfg *config.Config) {
|
||||||
|
newVal := configuredSignatureCacheEnabled(cfg)
|
||||||
|
newStrict := configuredSignatureBypassStrict(cfg)
|
||||||
|
if oldCfg == nil {
|
||||||
|
cache.SetSignatureCacheEnabled(newVal)
|
||||||
|
cache.SetSignatureBypassStrictMode(newStrict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oldVal := configuredSignatureCacheEnabled(oldCfg)
|
||||||
|
if oldVal != newVal {
|
||||||
|
cache.SetSignatureCacheEnabled(newVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldStrict := configuredSignatureBypassStrict(oldCfg)
|
||||||
|
if oldStrict != newStrict {
|
||||||
|
cache.SetSignatureBypassStrictMode(newStrict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configuredSignatureBypassStrict(cfg *config.Config) bool {
|
||||||
|
if cfg != nil && cfg.AntigravitySignatureBypassStrict != nil {
|
||||||
|
return *cfg.AntigravitySignatureBypassStrict
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server {
|
|||||||
return NewServer(cfg, authManager, accessManager, configPath)
|
return NewServer(cfg, authManager, accessManager, configPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHealthz(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String())
|
||||||
|
}
|
||||||
|
if resp.Status != "ok" {
|
||||||
|
t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -172,6 +195,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
true,
|
true,
|
||||||
"issue-1711",
|
"issue-1711",
|
||||||
time.Now(),
|
time.Now(),
|
||||||
|
|||||||
@@ -59,10 +59,30 @@ type ClaudeAuth struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *ClaudeAuth: A new Claude authentication service instance
|
// - *ClaudeAuth: A new Claude authentication service instance
|
||||||
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
||||||
|
return NewClaudeAuthWithProxyURL(cfg, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaudeAuthWithProxyURL creates a new Anthropic authentication service with a proxy override.
|
||||||
|
// proxyURL takes precedence over cfg.ProxyURL when non-empty.
|
||||||
|
func NewClaudeAuthWithProxyURL(cfg *config.Config, proxyURL string) *ClaudeAuth {
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg *config.SDKConfig
|
||||||
|
if cfg != nil {
|
||||||
|
sdkCfgCopy := cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
sdkCfgCopy.ProxyURL = effectiveProxyURL
|
||||||
|
sdkCfg = &sdkCfgCopy
|
||||||
|
} else if effectiveProxyURL != "" {
|
||||||
|
sdkCfgCopy := config.SDKConfig{ProxyURL: effectiveProxyURL}
|
||||||
|
sdkCfg = &sdkCfgCopy
|
||||||
|
}
|
||||||
|
|
||||||
// Use custom HTTP client with Firefox TLS fingerprint to bypass
|
// Use custom HTTP client with Firefox TLS fingerprint to bypass
|
||||||
// Cloudflare's bot detection on Anthropic domains
|
// Cloudflare's bot detection on Anthropic domains
|
||||||
return &ClaudeAuth{
|
return &ClaudeAuth{
|
||||||
httpClient: NewAnthropicHttpClient(&cfg.SDKConfig),
|
httpClient: NewAnthropicHttpClient(sdkCfg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +108,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
|||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"org:create_api_key user:profile user:inference"},
|
"scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
"code_challenge_method": {"S256"},
|
"code_challenge_method": {"S256"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
|
|||||||
33
internal/auth/claude/anthropic_auth_proxy_test.go
Normal file
33
internal/auth/claude/anthropic_auth_proxy_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewClaudeAuthWithProxyURL_OverrideDirectTakesPrecedence(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "socks5://proxy.example.com:1080"}}
|
||||||
|
auth := NewClaudeAuthWithProxyURL(cfg, "direct")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*utlsRoundTripper)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected utlsRoundTripper, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.dialer != proxy.Direct {
|
||||||
|
t.Fatalf("expected proxy.Direct, got %T", transport.dialer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClaudeAuthWithProxyURL_OverrideProxyAppliedWithoutConfig(t *testing.T) {
|
||||||
|
auth := NewClaudeAuthWithProxyURL(nil, "socks5://proxy.example.com:1080")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*utlsRoundTripper)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected utlsRoundTripper, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.dialer == proxy.Direct {
|
||||||
|
t.Fatalf("expected proxy dialer, got %T", transport.dialer)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,12 +4,12 @@ package claude
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
tls "github.com/refraction-networking/utls"
|
tls "github.com/refraction-networking/utls"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
|
|||||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||||
var dialer proxy.Dialer = proxy.Direct
|
var dialer proxy.Dialer = proxy.Direct
|
||||||
if cfg != nil && cfg.ProxyURL != "" {
|
if cfg != nil {
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||||
if err != nil {
|
if errBuild != nil {
|
||||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||||
} else {
|
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
dialer = proxyDialer
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
|
||||||
} else {
|
|
||||||
dialer = pDialer
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BaseURL = "https://copilot.tencent.com"
|
||||||
|
DefaultDomain = "www.codebuddy.cn"
|
||||||
|
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
|
||||||
|
|
||||||
|
codeBuddyStatePath = "/v2/plugin/auth/state"
|
||||||
|
codeBuddyTokenPath = "/v2/plugin/auth/token"
|
||||||
|
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
|
||||||
|
pollInterval = 5 * time.Second
|
||||||
|
maxPollDuration = 5 * time.Minute
|
||||||
|
codeLoginPending = 11217
|
||||||
|
codeSuccess = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodeBuddyAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
|
||||||
|
}
|
||||||
|
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthState holds the state and auth URL returned by the auth state API.
|
||||||
|
type AuthState struct {
|
||||||
|
State string
|
||||||
|
AuthURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
|
||||||
|
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
|
||||||
|
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
|
||||||
|
body := []byte("{}")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-Domain", "copilot.tencent.com")
|
||||||
|
req.Header.Set("X-No-Authorization", "true")
|
||||||
|
req.Header.Set("X-No-User-Id", "true")
|
||||||
|
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||||
|
req.Header.Set("X-No-Department-Info", "true")
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
req.Header.Set("X-Request-ID", requestID)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy auth state: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
State string `json:"state"`
|
||||||
|
AuthURL string `json:"authUrl"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Code != codeSuccess {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthState{
|
||||||
|
State: result.Data.State,
|
||||||
|
AuthURL: result.Data.AuthURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type pollResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
RequestID string `json:"requestId"`
|
||||||
|
Data *struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// doPollRequest performs a single polling request, safely reading and closing the response body
|
||||||
|
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
|
||||||
|
}
|
||||||
|
a.applyPollHeaders(req)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy poll: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
return body, resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForToken polls until the user completes browser authorization and returns auth data.
|
||||||
|
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
|
||||||
|
deadline := time.Now().Add(maxPollDuration)
|
||||||
|
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(pollInterval):
|
||||||
|
}
|
||||||
|
|
||||||
|
body, statusCode, err := a.doPollRequest(ctx, pollURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("codebuddy poll: request error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCode != http.StatusOK {
|
||||||
|
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var result pollResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch result.Code {
|
||||||
|
case codeSuccess:
|
||||||
|
if result.Data == nil {
|
||||||
|
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
|
||||||
|
}
|
||||||
|
userID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||||
|
return &CodeBuddyTokenStorage{
|
||||||
|
AccessToken: result.Data.AccessToken,
|
||||||
|
RefreshToken: result.Data.RefreshToken,
|
||||||
|
ExpiresIn: result.Data.ExpiresIn,
|
||||||
|
TokenType: result.Data.TokenType,
|
||||||
|
Domain: result.Data.Domain,
|
||||||
|
UserID: userID,
|
||||||
|
Type: "codebuddy",
|
||||||
|
}, nil
|
||||||
|
case codeLoginPending:
|
||||||
|
// continue polling
|
||||||
|
default:
|
||||||
|
// TODO: when the CodeBuddy API error code for user denial is known,
|
||||||
|
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
|
||||||
|
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrPollingTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
|
||||||
|
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", ErrJWTDecodeFailed
|
||||||
|
}
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||||
|
}
|
||||||
|
if claims.Sub == "" {
|
||||||
|
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
|
||||||
|
}
|
||||||
|
return claims.Sub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken exchanges a refresh token for a new access token.
|
||||||
|
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
|
||||||
|
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
|
||||||
|
if domain == "" {
|
||||||
|
domain = DefaultDomain
|
||||||
|
}
|
||||||
|
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
|
||||||
|
body := []byte("{}")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-Domain", domain)
|
||||||
|
req.Header.Set("X-Refresh-Token", refreshToken)
|
||||||
|
req.Header.Set("X-Auth-Refresh-Source", "plugin")
|
||||||
|
req.Header.Set("X-Request-ID", requestID)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("X-User-Id", userID)
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy refresh: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn"`
|
||||||
|
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Code != codeSuccess {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
if result.Data == nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
|
||||||
|
}
|
||||||
|
|
||||||
|
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||||
|
if newUserID == "" {
|
||||||
|
newUserID = userID
|
||||||
|
}
|
||||||
|
tokenDomain := result.Data.Domain
|
||||||
|
if tokenDomain == "" {
|
||||||
|
tokenDomain = domain
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CodeBuddyTokenStorage{
|
||||||
|
AccessToken: result.Data.AccessToken,
|
||||||
|
RefreshToken: result.Data.RefreshToken,
|
||||||
|
ExpiresIn: result.Data.ExpiresIn,
|
||||||
|
RefreshExpiresIn: result.Data.RefreshExpiresIn,
|
||||||
|
TokenType: result.Data.TokenType,
|
||||||
|
Domain: tokenDomain,
|
||||||
|
UserID: newUserID,
|
||||||
|
Type: "codebuddy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-No-Authorization", "true")
|
||||||
|
req.Header.Set("X-No-User-Id", "true")
|
||||||
|
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||||
|
req.Header.Set("X-No-Department-Info", "true")
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
}
|
||||||
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
|
||||||
|
func newTestAuth(serverURL string) *CodeBuddyAuth {
|
||||||
|
return &CodeBuddyAuth{
|
||||||
|
httpClient: http.DefaultClient,
|
||||||
|
baseURL: serverURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeJWT builds a minimal JWT with the given sub claim for testing.
|
||||||
|
func fakeJWT(sub string) string {
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
|
||||||
|
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
|
||||||
|
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||||
|
return header + "." + encodedPayload + ".sig"
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- FetchAuthState tests ---
|
||||||
|
|
||||||
|
func TestFetchAuthState_Success(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != codeBuddyStatePath {
|
||||||
|
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
|
||||||
|
}
|
||||||
|
if got := r.URL.Query().Get("platform"); got != "CLI" {
|
||||||
|
t.Errorf("expected platform=CLI, got %s", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("User-Agent"); got != UserAgent {
|
||||||
|
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"state": "test-state-abc",
|
||||||
|
"authUrl": "https://example.com/login?state=test-state-abc",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
result, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.State != "test-state-abc" {
|
||||||
|
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
|
||||||
|
}
|
||||||
|
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
|
||||||
|
t.Errorf("unexpected authURL: %s", result.AuthURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_NonOKStatus(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("internal error"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-200 status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_APIErrorCode(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 10001,
|
||||||
|
"msg": "rate limited",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-zero code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_MissingData(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"state": "",
|
||||||
|
"authUrl": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty state/authUrl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- RefreshToken tests ---
|
||||||
|
|
||||||
|
func TestRefreshToken_Success(t *testing.T) {
|
||||||
|
newAccessToken := fakeJWT("refreshed-user-456")
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != codeBuddyRefreshPath {
|
||||||
|
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
|
||||||
|
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
|
||||||
|
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-User-Id"); got != "user-123" {
|
||||||
|
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
|
||||||
|
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": newAccessToken,
|
||||||
|
"refreshToken": "new-refresh-token",
|
||||||
|
"expiresIn": 3600,
|
||||||
|
"refreshExpiresIn": 86400,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": "custom.domain.com",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if storage.AccessToken != newAccessToken {
|
||||||
|
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
|
||||||
|
}
|
||||||
|
if storage.RefreshToken != "new-refresh-token" {
|
||||||
|
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
|
||||||
|
}
|
||||||
|
if storage.UserID != "refreshed-user-456" {
|
||||||
|
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
|
||||||
|
}
|
||||||
|
if storage.ExpiresIn != 3600 {
|
||||||
|
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
|
||||||
|
}
|
||||||
|
if storage.RefreshExpiresIn != 86400 {
|
||||||
|
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
|
||||||
|
}
|
||||||
|
if storage.Domain != "custom.domain.com" {
|
||||||
|
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
|
||||||
|
}
|
||||||
|
if storage.Type != "codebuddy" {
|
||||||
|
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_DefaultDomain(t *testing.T) {
|
||||||
|
var receivedDomain string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedDomain = r.Header.Get("X-Domain")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": fakeJWT("user-1"),
|
||||||
|
"refreshToken": "rt",
|
||||||
|
"expiresIn": 3600,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": DefaultDomain,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if receivedDomain != DefaultDomain {
|
||||||
|
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Unauthorized(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401 response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Forbidden(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 403 response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_APIErrorCode(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 40001,
|
||||||
|
"msg": "invalid refresh token",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-zero API code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
|
||||||
|
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
|
||||||
|
// When the response domain is empty, it should fall back to the request domain.
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": "not-a-valid-jwt",
|
||||||
|
"refreshToken": "new-rt",
|
||||||
|
"expiresIn": 7200,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if storage.UserID != "original-uid" {
|
||||||
|
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
|
||||||
|
}
|
||||||
|
if storage.Domain != "original.domain.com" {
|
||||||
|
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
21
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package codebuddy_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDecodeUserID_ValidJWT(t *testing.T) {
|
||||||
|
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
|
||||||
|
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
|
||||||
|
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
|
||||||
|
auth := codebuddy.NewCodeBuddyAuth(nil)
|
||||||
|
userID, err := auth.DecodeUserID(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if userID != "test-user-id-123" {
|
||||||
|
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
25
internal/auth/codebuddy/errors.go
Normal file
25
internal/auth/codebuddy/errors.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
|
||||||
|
ErrAccessDenied = errors.New("codebuddy: access denied by user")
|
||||||
|
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
|
||||||
|
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetUserFriendlyMessage(err error) string {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrPollingTimeout):
|
||||||
|
return "Authentication timed out. Please try again."
|
||||||
|
case errors.Is(err, ErrAccessDenied):
|
||||||
|
return "Access denied. Please try again and approve the login request."
|
||||||
|
case errors.Is(err, ErrJWTDecodeFailed):
|
||||||
|
return "Failed to decode token. Please try logging in again."
|
||||||
|
case errors.Is(err, ErrTokenFetchFailed):
|
||||||
|
return "Failed to fetch token from server. Please try again."
|
||||||
|
default:
|
||||||
|
return "Authentication failed: " + err.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
65
internal/auth/codebuddy/token.go
Normal file
65
internal/auth/codebuddy/token.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// Package codebuddy provides authentication and token management functionality
|
||||||
|
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
|
||||||
|
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
|
||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
|
||||||
|
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
|
||||||
|
// for managing access tokens and user account information.
|
||||||
|
type CodeBuddyTokenStorage struct {
|
||||||
|
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// ExpiresIn is the number of seconds until the access token expires.
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
// RefreshExpiresIn is the number of seconds until the refresh token expires.
|
||||||
|
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
|
||||||
|
// TokenType is the type of token, typically "bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// Domain is the CodeBuddy service domain/region.
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
// UserID is the user ID associated with this token.
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
// Type indicates the authentication provider type, always "codebuddy" for this storage.
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
|
||||||
|
// This method creates the necessary directory structure and writes the token
|
||||||
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - authFilePath: The full path where the token file should be saved
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the operation fails, nil otherwise
|
||||||
|
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
misc.LogSavingCredentials(authFilePath)
|
||||||
|
s.Type = "codebuddy"
|
||||||
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(s); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -37,8 +37,23 @@ type CodexAuth struct {
|
|||||||
// NewCodexAuth creates a new CodexAuth service instance.
|
// NewCodexAuth creates a new CodexAuth service instance.
|
||||||
// It initializes an HTTP client with proxy settings from the provided configuration.
|
// It initializes an HTTP client with proxy settings from the provided configuration.
|
||||||
func NewCodexAuth(cfg *config.Config) *CodexAuth {
|
func NewCodexAuth(cfg *config.Config) *CodexAuth {
|
||||||
|
return NewCodexAuthWithProxyURL(cfg, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCodexAuthWithProxyURL creates a new CodexAuth service instance.
|
||||||
|
// proxyURL takes precedence over cfg.ProxyURL when non-empty.
|
||||||
|
func NewCodexAuthWithProxyURL(cfg *config.Config, proxyURL string) *CodexAuth {
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg config.SDKConfig
|
||||||
|
if cfg != nil {
|
||||||
|
sdkCfg = cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sdkCfg.ProxyURL = effectiveProxyURL
|
||||||
return &CodexAuth{
|
return &CodexAuth{
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
httpClient: util.SetProxy(&sdkCfg, &http.Client{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
@@ -42,3 +44,37 @@ func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
|||||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewCodexAuthWithProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"}}
|
||||||
|
auth := NewCodexAuthWithProxyURL(cfg, "direct")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCodexAuthWithProxyURL_OverrideProxyTakesPrecedence(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://global.example.com:8080"}}
|
||||||
|
auth := NewCodexAuthWithProxyURL(cfg, "http://override.example.com:8081")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errReq != nil {
|
||||||
|
t.Fatalf("new request: %v", errReq)
|
||||||
|
}
|
||||||
|
proxyURL, errProxy := transport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("proxy func: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://override.example.com:8081" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://override.example.com:8081", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -22,11 +24,11 @@ const (
|
|||||||
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
||||||
|
|
||||||
// Common HTTP header values for Copilot API requests.
|
// Common HTTP header values for Copilot API requests.
|
||||||
copilotUserAgent = "GithubCopilot/1.0"
|
copilotUserAgent = "GithubCopilot/1.0"
|
||||||
copilotEditorVersion = "vscode/1.100.0"
|
copilotEditorVersion = "vscode/1.100.0"
|
||||||
copilotPluginVersion = "copilot/1.300.0"
|
copilotPluginVersion = "copilot/1.300.0"
|
||||||
copilotIntegrationID = "vscode-chat"
|
copilotIntegrationID = "vscode-chat"
|
||||||
copilotOpenAIIntent = "conversation-panel"
|
copilotOpenAIIntent = "conversation-panel"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CopilotAPIToken represents the Copilot API token response.
|
// CopilotAPIToken represents the Copilot API token response.
|
||||||
@@ -222,6 +224,165 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
|
||||||
|
type CopilotModelEntry struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotModelLimits holds the token limits returned by the Copilot /models API
|
||||||
|
// under capabilities.limits. These limits vary by account type (individual vs
|
||||||
|
// business) and are the authoritative source for enforcing prompt size.
|
||||||
|
type CopilotModelLimits struct {
|
||||||
|
// MaxContextWindowTokens is the total context window (prompt + output).
|
||||||
|
MaxContextWindowTokens int
|
||||||
|
// MaxPromptTokens is the hard limit on input/prompt tokens.
|
||||||
|
// Exceeding this triggers a 400 error from the Copilot API.
|
||||||
|
MaxPromptTokens int
|
||||||
|
// MaxOutputTokens is the maximum number of output/completion tokens.
|
||||||
|
MaxOutputTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limits extracts the token limits from the model's capabilities map.
|
||||||
|
// Returns nil if no limits are available or the structure is unexpected.
|
||||||
|
//
|
||||||
|
// Expected Copilot API shape:
|
||||||
|
//
|
||||||
|
// "capabilities": {
|
||||||
|
// "limits": {
|
||||||
|
// "max_context_window_tokens": 200000,
|
||||||
|
// "max_prompt_tokens": 168000,
|
||||||
|
// "max_output_tokens": 32000
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (e *CopilotModelEntry) Limits() *CopilotModelLimits {
|
||||||
|
if e.Capabilities == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limitsRaw, ok := e.Capabilities["limits"]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limitsMap, ok := limitsRaw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &CopilotModelLimits{
|
||||||
|
MaxContextWindowTokens: anyToInt(limitsMap["max_context_window_tokens"]),
|
||||||
|
MaxPromptTokens: anyToInt(limitsMap["max_prompt_tokens"]),
|
||||||
|
MaxOutputTokens: anyToInt(limitsMap["max_output_tokens"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only return if at least one field is populated.
|
||||||
|
if result.MaxContextWindowTokens == 0 && result.MaxPromptTokens == 0 && result.MaxOutputTokens == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// anyToInt converts a JSON-decoded numeric value to int.
|
||||||
|
// Go's encoding/json decodes numbers into float64 when the target is any/interface{}.
|
||||||
|
func anyToInt(v any) int {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
case float32:
|
||||||
|
return int(n)
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||||
|
type CopilotModelsResponse struct {
|
||||||
|
Data []CopilotModelEntry `json:"data"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
|
||||||
|
const maxModelsResponseSize = 2 * 1024 * 1024
|
||||||
|
|
||||||
|
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||||
|
var allowedCopilotAPIHosts = map[string]bool{
|
||||||
|
"api.githubcopilot.com": true,
|
||||||
|
"api.individual.githubcopilot.com": true,
|
||||||
|
"api.business.githubcopilot.com": true,
|
||||||
|
"copilot-proxy.githubusercontent.com": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels fetches the list of available models from the Copilot API.
|
||||||
|
// It requires a valid Copilot API token (not the GitHub access token).
|
||||||
|
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
|
||||||
|
if apiToken == nil || apiToken.Token == "" {
|
||||||
|
return nil, fmt.Errorf("copilot: api token is required for listing models")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build models URL, validating the endpoint host to prevent SSRF.
|
||||||
|
modelsURL := copilotAPIEndpoint + "/models"
|
||||||
|
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
|
||||||
|
parsed, err := url.Parse(ep)
|
||||||
|
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
|
||||||
|
modelsURL = ep + "/models"
|
||||||
|
} else {
|
||||||
|
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: models request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot list models: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Limit response body to prevent memory exhaustion.
|
||||||
|
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
|
||||||
|
bodyBytes, err := io.ReadAll(limitedReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
|
return nil, fmt.Errorf("copilot: list models failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsResp CopilotModelsResponse
|
||||||
|
if err = json.Unmarshal(bodyBytes, &modelsResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to parse models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelsResp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModelsWithGitHubToken is a convenience method that exchanges a GitHub access token
|
||||||
|
// for a Copilot API token and then fetches the available models.
|
||||||
|
func (c *CopilotAuth) ListModelsWithGitHubToken(ctx context.Context, githubAccessToken string) ([]CopilotModelEntry, error) {
|
||||||
|
apiToken, err := c.GetCopilotAPIToken(ctx, githubAccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to get API token for model listing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.ListModels(ctx, apiToken)
|
||||||
|
}
|
||||||
|
|
||||||
// buildChatCompletionURL builds the URL for chat completions API.
|
// buildChatCompletionURL builds the URL for chat completions API.
|
||||||
func buildChatCompletionURL() string {
|
func buildChatCompletionURL() string {
|
||||||
return copilotAPIEndpoint + "/chat/completions"
|
return copilotAPIEndpoint + "/chat/completions"
|
||||||
|
|||||||
33
internal/auth/cursor/filename.go
Normal file
33
internal/auth/cursor/filename.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package cursor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Cursor credentials.
|
||||||
|
// Priority: explicit label > auto-generated from JWT sub hash.
|
||||||
|
// If both label and subHash are empty, falls back to "cursor.json".
|
||||||
|
func CredentialFileName(label, subHash string) string {
|
||||||
|
label = strings.TrimSpace(label)
|
||||||
|
subHash = strings.TrimSpace(subHash)
|
||||||
|
if label != "" {
|
||||||
|
return fmt.Sprintf("cursor.%s.json", label)
|
||||||
|
}
|
||||||
|
if subHash != "" {
|
||||||
|
return fmt.Sprintf("cursor.%s.json", subHash)
|
||||||
|
}
|
||||||
|
return "cursor.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayLabel returns a human-readable label for the Cursor account.
|
||||||
|
func DisplayLabel(label, subHash string) string {
|
||||||
|
label = strings.TrimSpace(label)
|
||||||
|
if label != "" {
|
||||||
|
return "Cursor " + label
|
||||||
|
}
|
||||||
|
if subHash != "" {
|
||||||
|
return "Cursor " + subHash
|
||||||
|
}
|
||||||
|
return "Cursor User"
|
||||||
|
}
|
||||||
249
internal/auth/cursor/oauth.go
Normal file
249
internal/auth/cursor/oauth.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
|
||||||
|
package cursor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CursorLoginURL = "https://cursor.com/loginDeepControl"
|
||||||
|
CursorPollURL = "https://api2.cursor.sh/auth/poll"
|
||||||
|
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
|
||||||
|
|
||||||
|
pollMaxAttempts = 150
|
||||||
|
pollBaseDelay = 1 * time.Second
|
||||||
|
pollMaxDelay = 10 * time.Second
|
||||||
|
pollBackoffMultiply = 1.2
|
||||||
|
maxConsecutiveErrors = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthParams holds the PKCE parameters for Cursor login.
|
||||||
|
type AuthParams struct {
|
||||||
|
Verifier string
|
||||||
|
Challenge string
|
||||||
|
UUID string
|
||||||
|
LoginURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenPair holds the access and refresh tokens from Cursor.
|
||||||
|
type TokenPair struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeneratePKCE creates a PKCE verifier and challenge pair.
|
||||||
|
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||||
|
verifierBytes := make([]byte, 96)
|
||||||
|
if _, err = rand.Read(verifierBytes); err != nil {
|
||||||
|
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||||
|
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
return verifier, challenge, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAuthParams creates the full set of auth params for Cursor login.
|
||||||
|
func GenerateAuthParams() (*AuthParams, error) {
|
||||||
|
verifier, challenge, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
uuidBytes := make([]byte, 16)
|
||||||
|
if _, err = rand.Read(uuidBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
|
||||||
|
}
|
||||||
|
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||||
|
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
|
||||||
|
|
||||||
|
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
|
||||||
|
CursorLoginURL, challenge, uuid)
|
||||||
|
|
||||||
|
return &AuthParams{
|
||||||
|
Verifier: verifier,
|
||||||
|
Challenge: challenge,
|
||||||
|
UUID: uuid,
|
||||||
|
LoginURL: loginURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForAuth polls the Cursor auth endpoint until the user completes login.
|
||||||
|
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
|
||||||
|
delay := pollBaseDelay
|
||||||
|
consecutiveErrors := 0
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
|
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
consecutiveErrors++
|
||||||
|
if consecutiveErrors >= maxConsecutiveErrors {
|
||||||
|
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
|
||||||
|
}
|
||||||
|
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
// Still waiting for user to authorize
|
||||||
|
consecutiveErrors = 0
|
||||||
|
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
var tokens TokenPair
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
|
||||||
|
}
|
||||||
|
return &tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
|
||||||
|
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes a Cursor access token using the refresh token.
|
||||||
|
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
|
||||||
|
strings.NewReader("{}"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+refreshToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens TokenPair
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep original refresh token if not returned
|
||||||
|
if tokens.RefreshToken == "" {
|
||||||
|
tokens.RefreshToken = refreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
|
||||||
|
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
|
||||||
|
// the account. Returns empty string if parsing fails.
|
||||||
|
func ParseJWTSub(token string) string {
|
||||||
|
decoded := decodeJWTPayload(token)
|
||||||
|
if decoded == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return claims.Sub
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
|
||||||
|
// e.g. "auth0|user_2x..." → "a3f8b2c1"
|
||||||
|
func SubToShortHash(sub string) string {
|
||||||
|
if sub == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
h := sha256.Sum256([]byte(sub))
|
||||||
|
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeJWTPayload decodes the payload (middle) part of a JWT.
|
||||||
|
func decodeJWTPayload(token string) []byte {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := parts[1]
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
payload = strings.ReplaceAll(payload, "-", "+")
|
||||||
|
payload = strings.ReplaceAll(payload, "_", "/")
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
|
||||||
|
// Falls back to 1 hour from now if the token can't be parsed.
|
||||||
|
func GetTokenExpiry(token string) time.Time {
|
||||||
|
decoded := decodeJWTPayload(token)
|
||||||
|
if decoded == nil {
|
||||||
|
return time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Exp float64 `json:"exp"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
|
||||||
|
return time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
sec, frac := math.Modf(claims.Exp)
|
||||||
|
expiry := time.Unix(int64(sec), int64(frac*1e9))
|
||||||
|
// Subtract 5-minute safety margin
|
||||||
|
return expiry.Add(-5 * time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func minDuration(a, b time.Duration) time.Duration {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
84
internal/auth/cursor/proto/connect.go
Normal file
84
internal/auth/cursor/proto/connect.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
|
||||||
|
ConnectEndStreamFlag byte = 0x02
|
||||||
|
// ConnectCompressionFlag indicates the payload is compressed (not supported).
|
||||||
|
ConnectCompressionFlag byte = 0x01
|
||||||
|
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
|
||||||
|
ConnectFrameHeaderSize = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
|
||||||
|
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
|
||||||
|
func FrameConnectMessage(data []byte, flags byte) []byte {
|
||||||
|
frame := make([]byte, ConnectFrameHeaderSize+len(data))
|
||||||
|
frame[0] = flags
|
||||||
|
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
|
||||||
|
copy(frame[5:], data)
|
||||||
|
return frame
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConnectFrame extracts one frame from a buffer.
|
||||||
|
// Returns (flags, payload, bytesConsumed, ok).
|
||||||
|
// ok is false when the buffer is too short for a complete frame.
|
||||||
|
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
|
||||||
|
if len(buf) < ConnectFrameHeaderSize {
|
||||||
|
return 0, nil, 0, false
|
||||||
|
}
|
||||||
|
flags = buf[0]
|
||||||
|
length := binary.BigEndian.Uint32(buf[1:5])
|
||||||
|
total := ConnectFrameHeaderSize + int(length)
|
||||||
|
if len(buf) < total {
|
||||||
|
return 0, nil, 0, false
|
||||||
|
}
|
||||||
|
return flags, buf[5:total], total, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
|
||||||
|
// The Code field contains the server-defined error code (e.g. gRPC standard codes
|
||||||
|
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
|
||||||
|
type ConnectError struct {
|
||||||
|
Code string // server-defined error code
|
||||||
|
Message string // human-readable error description
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnectError) Error() string {
|
||||||
|
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
|
||||||
|
// Returns nil if there is no error in the trailer.
|
||||||
|
// On error, returns a *ConnectError with the server's error code and message.
|
||||||
|
func ParseConnectEndStream(data []byte) error {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var trailer struct {
|
||||||
|
Error *struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &trailer); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse Connect end stream: %w", err)
|
||||||
|
}
|
||||||
|
if trailer.Error != nil {
|
||||||
|
code := trailer.Error.Code
|
||||||
|
if code == "" {
|
||||||
|
code = "unknown"
|
||||||
|
}
|
||||||
|
msg := trailer.Error.Message
|
||||||
|
if msg == "" {
|
||||||
|
msg = "Unknown error"
|
||||||
|
}
|
||||||
|
return &ConnectError{Code: code, Message: msg}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
563
internal/auth/cursor/proto/decode.go
Normal file
563
internal/auth/cursor/proto/decode.go
Normal file
@@ -0,0 +1,563 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerMessageType identifies the kind of decoded server message.
|
||||||
|
type ServerMessageType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServerMsgUnknown ServerMessageType = iota
|
||||||
|
ServerMsgTextDelta // Text content delta
|
||||||
|
ServerMsgThinkingDelta // Thinking/reasoning delta
|
||||||
|
ServerMsgThinkingCompleted // Thinking completed
|
||||||
|
ServerMsgKvGetBlob // Server wants a blob
|
||||||
|
ServerMsgKvSetBlob // Server wants to store a blob
|
||||||
|
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
|
||||||
|
ServerMsgExecMcpArgs // Server wants MCP tool execution
|
||||||
|
ServerMsgExecShellArgs // Rejected: shell command
|
||||||
|
ServerMsgExecReadArgs // Rejected: file read
|
||||||
|
ServerMsgExecWriteArgs // Rejected: file write
|
||||||
|
ServerMsgExecDeleteArgs // Rejected: file delete
|
||||||
|
ServerMsgExecLsArgs // Rejected: directory listing
|
||||||
|
ServerMsgExecGrepArgs // Rejected: grep search
|
||||||
|
ServerMsgExecFetchArgs // Rejected: HTTP fetch
|
||||||
|
ServerMsgExecDiagnostics // Respond with empty diagnostics
|
||||||
|
ServerMsgExecShellStream // Rejected: shell stream
|
||||||
|
ServerMsgExecBgShellSpawn // Rejected: background shell
|
||||||
|
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
|
||||||
|
ServerMsgExecOther // Other exec types (respond with empty)
|
||||||
|
ServerMsgTurnEnded // Turn has ended (no more output)
|
||||||
|
ServerMsgHeartbeat // Server heartbeat
|
||||||
|
ServerMsgTokenDelta // Token usage delta
|
||||||
|
ServerMsgCheckpoint // Conversation checkpoint update
|
||||||
|
)
|
||||||
|
|
||||||
|
// DecodedServerMessage holds parsed data from an AgentServerMessage.
|
||||||
|
type DecodedServerMessage struct {
|
||||||
|
Type ServerMessageType
|
||||||
|
|
||||||
|
// For text/thinking deltas
|
||||||
|
Text string
|
||||||
|
|
||||||
|
// For KV messages
|
||||||
|
KvId uint32
|
||||||
|
BlobId []byte // hex-encoded blob ID
|
||||||
|
BlobData []byte // for setBlobArgs
|
||||||
|
|
||||||
|
// For exec messages
|
||||||
|
ExecMsgId uint32
|
||||||
|
ExecId string
|
||||||
|
|
||||||
|
// For MCP args
|
||||||
|
McpToolName string
|
||||||
|
McpToolCallId string
|
||||||
|
McpArgs map[string][]byte // arg name -> protobuf-encoded value
|
||||||
|
|
||||||
|
// For rejection context
|
||||||
|
Path string
|
||||||
|
Command string
|
||||||
|
WorkingDirectory string
|
||||||
|
Url string
|
||||||
|
|
||||||
|
// For other exec - the raw field number for building a response
|
||||||
|
ExecFieldNumber int
|
||||||
|
|
||||||
|
// For TokenDeltaUpdate
|
||||||
|
TokenDelta int64
|
||||||
|
|
||||||
|
// For conversation checkpoint update (raw bytes, not decoded)
|
||||||
|
CheckpointData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeAgentServerMessage parses an AgentServerMessage and returns
|
||||||
|
// a structured representation of the first meaningful message found.
|
||||||
|
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
|
||||||
|
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
|
||||||
|
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return msg, fmt.Errorf("invalid tag")
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case protowire.BytesType:
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return msg, fmt.Errorf("invalid bytes field %d", num)
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
// Debug: log top-level ASM fields
|
||||||
|
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
|
||||||
|
|
||||||
|
switch num {
|
||||||
|
case ASM_InteractionUpdate:
|
||||||
|
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
|
||||||
|
decodeInteractionUpdate(val, msg)
|
||||||
|
case ASM_ExecServerMessage:
|
||||||
|
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
|
||||||
|
decodeExecServerMessage(val, msg)
|
||||||
|
case ASM_KvServerMessage:
|
||||||
|
decodeKvServerMessage(val, msg)
|
||||||
|
case ASM_ConversationCheckpoint:
|
||||||
|
msg.Type = ServerMsgCheckpoint
|
||||||
|
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
|
||||||
|
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
case protowire.VarintType:
|
||||||
|
_, n := protowire.ConsumeVarint(data)
|
||||||
|
if n < 0 {
|
||||||
|
return msg, fmt.Errorf("invalid varint field %d", num)
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Skip unknown wire types
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return msg, fmt.Errorf("invalid field %d", num)
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
|
||||||
|
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||||
|
|
||||||
|
switch num {
|
||||||
|
case IU_TextDelta:
|
||||||
|
msg.Type = ServerMsgTextDelta
|
||||||
|
msg.Text = decodeStringField(val, TDU_Text)
|
||||||
|
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
|
||||||
|
case IU_ThinkingDelta:
|
||||||
|
msg.Type = ServerMsgThinkingDelta
|
||||||
|
msg.Text = decodeStringField(val, TKD_Text)
|
||||||
|
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
|
||||||
|
case IU_ThinkingCompleted:
|
||||||
|
msg.Type = ServerMsgThinkingCompleted
|
||||||
|
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
|
||||||
|
case 2:
|
||||||
|
// tool_call_started - ignore but log
|
||||||
|
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
|
||||||
|
case 3:
|
||||||
|
// tool_call_completed - ignore but log
|
||||||
|
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
|
||||||
|
case 8:
|
||||||
|
// token_delta - extract token count
|
||||||
|
msg.Type = ServerMsgTokenDelta
|
||||||
|
msg.TokenDelta = decodeVarintField(val, 1)
|
||||||
|
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
|
||||||
|
case 13:
|
||||||
|
// heartbeat from server
|
||||||
|
msg.Type = ServerMsgHeartbeat
|
||||||
|
case 14:
|
||||||
|
// turn_ended - critical: model finished generating
|
||||||
|
msg.Type = ServerMsgTurnEnded
|
||||||
|
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
|
||||||
|
case 16:
|
||||||
|
// step_started - ignore
|
||||||
|
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
|
||||||
|
case 17:
|
||||||
|
// step_completed - ignore
|
||||||
|
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
|
||||||
|
default:
|
||||||
|
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case protowire.VarintType:
|
||||||
|
val, n := protowire.ConsumeVarint(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == KSM_Id {
|
||||||
|
msg.KvId = uint32(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
case protowire.BytesType:
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
switch num {
|
||||||
|
case KSM_GetBlobArgs:
|
||||||
|
msg.Type = ServerMsgKvGetBlob
|
||||||
|
msg.BlobId = decodeBytesField(val, GBA_BlobId)
|
||||||
|
case KSM_SetBlobArgs:
|
||||||
|
msg.Type = ServerMsgKvSetBlob
|
||||||
|
decodeSetBlobArgs(val, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
switch num {
|
||||||
|
case SBA_BlobId:
|
||||||
|
msg.BlobId = val
|
||||||
|
case SBA_BlobData:
|
||||||
|
msg.BlobData = val
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case protowire.VarintType:
|
||||||
|
val, n := protowire.ConsumeVarint(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == ESM_Id {
|
||||||
|
msg.ExecMsgId = uint32(val)
|
||||||
|
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
case protowire.BytesType:
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
// Debug: log all fields found in ExecServerMessage
|
||||||
|
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
|
||||||
|
|
||||||
|
switch num {
|
||||||
|
case ESM_ExecId:
|
||||||
|
msg.ExecId = string(val)
|
||||||
|
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
|
||||||
|
case ESM_RequestContextArgs:
|
||||||
|
msg.Type = ServerMsgExecRequestCtx
|
||||||
|
case ESM_McpArgs:
|
||||||
|
msg.Type = ServerMsgExecMcpArgs
|
||||||
|
decodeMcpArgs(val, msg)
|
||||||
|
case ESM_ShellArgs:
|
||||||
|
msg.Type = ServerMsgExecShellArgs
|
||||||
|
decodeShellArgs(val, msg)
|
||||||
|
case ESM_ShellStreamArgs:
|
||||||
|
msg.Type = ServerMsgExecShellStream
|
||||||
|
decodeShellArgs(val, msg)
|
||||||
|
case ESM_ReadArgs:
|
||||||
|
msg.Type = ServerMsgExecReadArgs
|
||||||
|
msg.Path = decodeStringField(val, RA_Path)
|
||||||
|
case ESM_WriteArgs:
|
||||||
|
msg.Type = ServerMsgExecWriteArgs
|
||||||
|
msg.Path = decodeStringField(val, WA_Path)
|
||||||
|
case ESM_DeleteArgs:
|
||||||
|
msg.Type = ServerMsgExecDeleteArgs
|
||||||
|
msg.Path = decodeStringField(val, DA_Path)
|
||||||
|
case ESM_LsArgs:
|
||||||
|
msg.Type = ServerMsgExecLsArgs
|
||||||
|
msg.Path = decodeStringField(val, LA_Path)
|
||||||
|
case ESM_GrepArgs:
|
||||||
|
msg.Type = ServerMsgExecGrepArgs
|
||||||
|
case ESM_FetchArgs:
|
||||||
|
msg.Type = ServerMsgExecFetchArgs
|
||||||
|
msg.Url = decodeStringField(val, FA_Url)
|
||||||
|
case ESM_DiagnosticsArgs:
|
||||||
|
msg.Type = ServerMsgExecDiagnostics
|
||||||
|
case ESM_BackgroundShellSpawn:
|
||||||
|
msg.Type = ServerMsgExecBgShellSpawn
|
||||||
|
decodeShellArgs(val, msg) // same structure
|
||||||
|
case ESM_WriteShellStdinArgs:
|
||||||
|
msg.Type = ServerMsgExecWriteShellStdin
|
||||||
|
default:
|
||||||
|
// Unknown exec types - only set if we haven't identified the type yet
|
||||||
|
// (other fields like span_context (19) come after the exec type field)
|
||||||
|
if msg.Type == ServerMsgUnknown {
|
||||||
|
msg.Type = ServerMsgExecOther
|
||||||
|
msg.ExecFieldNumber = int(num)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
|
||||||
|
msg.McpArgs = make(map[string][]byte)
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
switch num {
|
||||||
|
case MCA_Name:
|
||||||
|
msg.McpToolName = string(val)
|
||||||
|
case MCA_Args:
|
||||||
|
// Map entries are encoded as submessages with key=1, value=2
|
||||||
|
decodeMapEntry(val, msg.McpArgs)
|
||||||
|
case MCA_ToolCallId:
|
||||||
|
msg.McpToolCallId = string(val)
|
||||||
|
case MCA_ToolName:
|
||||||
|
// ToolName takes precedence if present
|
||||||
|
if msg.McpToolName == "" || string(val) != "" {
|
||||||
|
msg.McpToolName = string(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeMapEntry(data []byte, m map[string][]byte) {
|
||||||
|
var key string
|
||||||
|
var value []byte
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == 1 {
|
||||||
|
key = string(val)
|
||||||
|
} else if num == 2 {
|
||||||
|
value = append([]byte(nil), val...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if key != "" {
|
||||||
|
m[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
switch num {
|
||||||
|
case SHA_Command:
|
||||||
|
msg.Command = string(val)
|
||||||
|
case SHA_WorkingDirectory:
|
||||||
|
msg.WorkingDirectory = string(val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper decoders ---
|
||||||
|
|
||||||
|
// decodeStringField extracts a string from the first matching field in a submessage.
|
||||||
|
func decodeStringField(data []byte, targetField protowire.Number) string {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return string(val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeBytesField extracts bytes from the first matching field in a submessage.
|
||||||
|
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return append([]byte(nil), val...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
|
||||||
|
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if typ == protowire.VarintType {
|
||||||
|
val, n := protowire.ConsumeVarint(data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == targetField {
|
||||||
|
return int64(val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlobIdHex returns the hex string of a blob ID for use as a map key.
|
||||||
|
func BlobIdHex(blobId []byte) string {
|
||||||
|
return hex.EncodeToString(blobId)
|
||||||
|
}
|
||||||
1244
internal/auth/cursor/proto/descriptor.go
Normal file
1244
internal/auth/cursor/proto/descriptor.go
Normal file
File diff suppressed because it is too large
Load Diff
664
internal/auth/cursor/proto/encode.go
Normal file
664
internal/auth/cursor/proto/encode.go
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
// Package proto provides protobuf encoding for Cursor's gRPC API,
|
||||||
|
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
|
||||||
|
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
"google.golang.org/protobuf/types/dynamicpb"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Public types ---
|
||||||
|
|
||||||
|
// RunRequestParams holds all data needed to build an AgentRunRequest.
|
||||||
|
type RunRequestParams struct {
|
||||||
|
ModelId string
|
||||||
|
SystemPrompt string
|
||||||
|
UserText string
|
||||||
|
MessageId string
|
||||||
|
ConversationId string
|
||||||
|
Images []ImageData
|
||||||
|
Turns []TurnData
|
||||||
|
McpTools []McpToolDef
|
||||||
|
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
|
||||||
|
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageData struct {
|
||||||
|
MimeType string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type TurnData struct {
|
||||||
|
UserText string
|
||||||
|
AssistantText string
|
||||||
|
}
|
||||||
|
|
||||||
|
type McpToolDef struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
InputSchema json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper: create a dynamic message and set fields ---
|
||||||
|
|
||||||
|
func newMsg(name string) *dynamicpb.Message {
|
||||||
|
return dynamicpb.NewMessage(Msg(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
|
||||||
|
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setStr(msg *dynamicpb.Message, name, val string) {
|
||||||
|
if val != "" {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
|
||||||
|
if len(val) > 0 {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setBool(msg *dynamicpb.Message, name string, val bool) {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
|
||||||
|
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshal(msg *dynamicpb.Message) []byte {
|
||||||
|
b, err := proto.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
panic("cursor proto marshal: " + err.Error())
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Encode functions mirroring cursor-fetch.ts ---
|
||||||
|
|
||||||
|
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
|
||||||
|
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
|
||||||
|
func EncodeHeartbeat() []byte {
|
||||||
|
hb := newMsg("ClientHeartbeat")
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "client_heartbeat", hb)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
|
||||||
|
// Mirrors buildCursorRequest() in cursor-fetch.ts.
|
||||||
|
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
|
||||||
|
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
|
||||||
|
func EncodeRunRequest(p *RunRequestParams) []byte {
|
||||||
|
if p.RawCheckpoint != nil {
|
||||||
|
return encodeRunRequestWithCheckpoint(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.BlobStore == nil {
|
||||||
|
p.BlobStore = make(map[string][]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Conversation turns ---
|
||||||
|
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
|
||||||
|
var turnBytes [][]byte
|
||||||
|
for _, turn := range p.Turns {
|
||||||
|
// UserMessage for this turn
|
||||||
|
um := newMsg("UserMessage")
|
||||||
|
setStr(um, "text", turn.UserText)
|
||||||
|
setStr(um, "message_id", generateId())
|
||||||
|
umBytes := marshal(um)
|
||||||
|
|
||||||
|
// Steps (assistant response)
|
||||||
|
var stepBytes [][]byte
|
||||||
|
if turn.AssistantText != "" {
|
||||||
|
am := newMsg("AssistantMessage")
|
||||||
|
setStr(am, "text", turn.AssistantText)
|
||||||
|
step := newMsg("ConversationStep")
|
||||||
|
setMsg(step, "assistant_message", am)
|
||||||
|
stepBytes = append(stepBytes, marshal(step))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AgentConversationTurnStructure (fields are bytes, not submessages)
|
||||||
|
agentTurn := newMsg("AgentConversationTurnStructure")
|
||||||
|
setBytes(agentTurn, "user_message", umBytes)
|
||||||
|
for _, sb := range stepBytes {
|
||||||
|
stepsField := field(agentTurn, "steps")
|
||||||
|
list := agentTurn.Mutable(stepsField).List()
|
||||||
|
list.Append(protoreflect.ValueOfBytes(sb))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationTurnStructure (oneof turn → agentConversationTurn)
|
||||||
|
cts := newMsg("ConversationTurnStructure")
|
||||||
|
setMsg(cts, "agent_conversation_turn", agentTurn)
|
||||||
|
turnBytes = append(turnBytes, marshal(cts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- System prompt blob ---
|
||||||
|
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
|
||||||
|
blobId := sha256Sum(systemJSON)
|
||||||
|
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
|
||||||
|
|
||||||
|
// --- ConversationStateStructure ---
|
||||||
|
css := newMsg("ConversationStateStructure")
|
||||||
|
// rootPromptMessagesJson: repeated bytes
|
||||||
|
rootField := field(css, "root_prompt_messages_json")
|
||||||
|
rootList := css.Mutable(rootField).List()
|
||||||
|
rootList.Append(protoreflect.ValueOfBytes(blobId))
|
||||||
|
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
|
||||||
|
turnsField := field(css, "turns")
|
||||||
|
turnsList := css.Mutable(turnsField).List()
|
||||||
|
for _, tb := range turnBytes {
|
||||||
|
turnsList.Append(protoreflect.ValueOfBytes(tb))
|
||||||
|
}
|
||||||
|
turnsOldField := field(css, "turns_old")
|
||||||
|
if turnsOldField != nil {
|
||||||
|
turnsOldList := css.Mutable(turnsOldField).List()
|
||||||
|
for _, tb := range turnBytes {
|
||||||
|
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UserMessage (current) ---
|
||||||
|
userMessage := newMsg("UserMessage")
|
||||||
|
setStr(userMessage, "text", p.UserText)
|
||||||
|
setStr(userMessage, "message_id", p.MessageId)
|
||||||
|
|
||||||
|
// Images via SelectedContext
|
||||||
|
if len(p.Images) > 0 {
|
||||||
|
sc := newMsg("SelectedContext")
|
||||||
|
imgsField := field(sc, "selected_images")
|
||||||
|
imgsList := sc.Mutable(imgsField).List()
|
||||||
|
for _, img := range p.Images {
|
||||||
|
si := newMsg("SelectedImage")
|
||||||
|
setStr(si, "uuid", generateId())
|
||||||
|
setStr(si, "mime_type", img.MimeType)
|
||||||
|
setBytes(si, "data", img.Data)
|
||||||
|
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||||
|
}
|
||||||
|
setMsg(userMessage, "selected_context", sc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UserMessageAction ---
|
||||||
|
uma := newMsg("UserMessageAction")
|
||||||
|
setMsg(uma, "user_message", userMessage)
|
||||||
|
|
||||||
|
// --- ConversationAction ---
|
||||||
|
ca := newMsg("ConversationAction")
|
||||||
|
setMsg(ca, "user_message_action", uma)
|
||||||
|
|
||||||
|
// --- ModelDetails ---
|
||||||
|
md := newMsg("ModelDetails")
|
||||||
|
setStr(md, "model_id", p.ModelId)
|
||||||
|
setStr(md, "display_model_id", p.ModelId)
|
||||||
|
setStr(md, "display_name", p.ModelId)
|
||||||
|
|
||||||
|
// --- AgentRunRequest ---
|
||||||
|
arr := newMsg("AgentRunRequest")
|
||||||
|
setMsg(arr, "conversation_state", css)
|
||||||
|
setMsg(arr, "action", ca)
|
||||||
|
setMsg(arr, "model_details", md)
|
||||||
|
setStr(arr, "conversation_id", p.ConversationId)
|
||||||
|
|
||||||
|
// McpTools
|
||||||
|
if len(p.McpTools) > 0 {
|
||||||
|
mcpTools := newMsg("McpTools")
|
||||||
|
toolsField := field(mcpTools, "mcp_tools")
|
||||||
|
toolsList := mcpTools.Mutable(toolsField).List()
|
||||||
|
for _, tool := range p.McpTools {
|
||||||
|
td := newMsg("McpToolDefinition")
|
||||||
|
setStr(td, "name", tool.Name)
|
||||||
|
setStr(td, "description", tool.Description)
|
||||||
|
if len(tool.InputSchema) > 0 {
|
||||||
|
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||||
|
}
|
||||||
|
setStr(td, "provider_identifier", "proxy")
|
||||||
|
setStr(td, "tool_name", tool.Name)
|
||||||
|
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||||
|
}
|
||||||
|
setMsg(arr, "mcp_tools", mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- AgentClientMessage ---
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "run_request", arr)
|
||||||
|
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
|
||||||
|
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
|
||||||
|
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
|
||||||
|
// Build UserMessage
|
||||||
|
userMessage := newMsg("UserMessage")
|
||||||
|
setStr(userMessage, "text", p.UserText)
|
||||||
|
setStr(userMessage, "message_id", p.MessageId)
|
||||||
|
if len(p.Images) > 0 {
|
||||||
|
sc := newMsg("SelectedContext")
|
||||||
|
imgsField := field(sc, "selected_images")
|
||||||
|
imgsList := sc.Mutable(imgsField).List()
|
||||||
|
for _, img := range p.Images {
|
||||||
|
si := newMsg("SelectedImage")
|
||||||
|
setStr(si, "uuid", generateId())
|
||||||
|
setStr(si, "mime_type", img.MimeType)
|
||||||
|
setBytes(si, "data", img.Data)
|
||||||
|
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
|
||||||
|
}
|
||||||
|
setMsg(userMessage, "selected_context", sc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build ConversationAction with UserMessageAction
|
||||||
|
uma := newMsg("UserMessageAction")
|
||||||
|
setMsg(uma, "user_message", userMessage)
|
||||||
|
ca := newMsg("ConversationAction")
|
||||||
|
setMsg(ca, "user_message_action", uma)
|
||||||
|
caBytes := marshal(ca)
|
||||||
|
|
||||||
|
// Build ModelDetails
|
||||||
|
md := newMsg("ModelDetails")
|
||||||
|
setStr(md, "model_id", p.ModelId)
|
||||||
|
setStr(md, "display_model_id", p.ModelId)
|
||||||
|
setStr(md, "display_name", p.ModelId)
|
||||||
|
mdBytes := marshal(md)
|
||||||
|
|
||||||
|
// Build McpTools
|
||||||
|
var mcpToolsBytes []byte
|
||||||
|
if len(p.McpTools) > 0 {
|
||||||
|
mcpTools := newMsg("McpTools")
|
||||||
|
toolsField := field(mcpTools, "mcp_tools")
|
||||||
|
toolsList := mcpTools.Mutable(toolsField).List()
|
||||||
|
for _, tool := range p.McpTools {
|
||||||
|
td := newMsg("McpToolDefinition")
|
||||||
|
setStr(td, "name", tool.Name)
|
||||||
|
setStr(td, "description", tool.Description)
|
||||||
|
if len(tool.InputSchema) > 0 {
|
||||||
|
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||||
|
}
|
||||||
|
setStr(td, "provider_identifier", "proxy")
|
||||||
|
setStr(td, "tool_name", tool.Name)
|
||||||
|
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||||
|
}
|
||||||
|
mcpToolsBytes = marshal(mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
|
||||||
|
var arrBuf []byte
|
||||||
|
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
|
||||||
|
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
|
||||||
|
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
|
||||||
|
// field 2: action = ConversationAction
|
||||||
|
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
|
||||||
|
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
|
||||||
|
// field 3: model_details = ModelDetails
|
||||||
|
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
|
||||||
|
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
|
||||||
|
// field 4: mcp_tools = McpTools
|
||||||
|
if len(mcpToolsBytes) > 0 {
|
||||||
|
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
|
||||||
|
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
|
||||||
|
}
|
||||||
|
// field 5: conversation_id = string
|
||||||
|
if p.ConversationId != "" {
|
||||||
|
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
|
||||||
|
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap in AgentClientMessage field 1 (run_request)
|
||||||
|
var acmBuf []byte
|
||||||
|
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
|
||||||
|
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
|
||||||
|
|
||||||
|
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
|
||||||
|
return acmBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResumeRequestParams holds data for a ResumeAction request.
|
||||||
|
type ResumeRequestParams struct {
|
||||||
|
ModelId string
|
||||||
|
ConversationId string
|
||||||
|
McpTools []McpToolDef
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
|
||||||
|
// Used to resume a conversation by conversation_id without re-sending full history.
|
||||||
|
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
|
||||||
|
// RequestContext with tools
|
||||||
|
rc := newMsg("RequestContext")
|
||||||
|
if len(p.McpTools) > 0 {
|
||||||
|
toolsField := field(rc, "tools")
|
||||||
|
toolsList := rc.Mutable(toolsField).List()
|
||||||
|
for _, tool := range p.McpTools {
|
||||||
|
td := newMsg("McpToolDefinition")
|
||||||
|
setStr(td, "name", tool.Name)
|
||||||
|
setStr(td, "description", tool.Description)
|
||||||
|
if len(tool.InputSchema) > 0 {
|
||||||
|
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||||
|
}
|
||||||
|
setStr(td, "provider_identifier", "proxy")
|
||||||
|
setStr(td, "tool_name", tool.Name)
|
||||||
|
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResumeAction
|
||||||
|
ra := newMsg("ResumeAction")
|
||||||
|
setMsg(ra, "request_context", rc)
|
||||||
|
|
||||||
|
// ConversationAction with resume_action
|
||||||
|
ca := newMsg("ConversationAction")
|
||||||
|
setMsg(ca, "resume_action", ra)
|
||||||
|
|
||||||
|
// ModelDetails
|
||||||
|
md := newMsg("ModelDetails")
|
||||||
|
setStr(md, "model_id", p.ModelId)
|
||||||
|
setStr(md, "display_model_id", p.ModelId)
|
||||||
|
setStr(md, "display_name", p.ModelId)
|
||||||
|
|
||||||
|
// AgentRunRequest — no conversation_state needed for resume
|
||||||
|
arr := newMsg("AgentRunRequest")
|
||||||
|
setMsg(arr, "action", ca)
|
||||||
|
setMsg(arr, "model_details", md)
|
||||||
|
setStr(arr, "conversation_id", p.ConversationId)
|
||||||
|
|
||||||
|
// McpTools at top level
|
||||||
|
if len(p.McpTools) > 0 {
|
||||||
|
mcpTools := newMsg("McpTools")
|
||||||
|
toolsField := field(mcpTools, "mcp_tools")
|
||||||
|
toolsList := mcpTools.Mutable(toolsField).List()
|
||||||
|
for _, tool := range p.McpTools {
|
||||||
|
td := newMsg("McpToolDefinition")
|
||||||
|
setStr(td, "name", tool.Name)
|
||||||
|
setStr(td, "description", tool.Description)
|
||||||
|
if len(tool.InputSchema) > 0 {
|
||||||
|
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||||
|
}
|
||||||
|
setStr(td, "provider_identifier", "proxy")
|
||||||
|
setStr(td, "tool_name", tool.Name)
|
||||||
|
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||||
|
}
|
||||||
|
setMsg(arr, "mcp_tools", mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "run_request", arr)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- KV response encoders ---
|
||||||
|
// Mirrors handleKvMessage() in cursor-fetch.ts
|
||||||
|
|
||||||
|
// EncodeKvGetBlobResult responds to a getBlobArgs request.
|
||||||
|
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
|
||||||
|
result := newMsg("GetBlobResult")
|
||||||
|
if blobData != nil {
|
||||||
|
setBytes(result, "blob_data", blobData)
|
||||||
|
}
|
||||||
|
|
||||||
|
kvc := newMsg("KvClientMessage")
|
||||||
|
setUint32(kvc, "id", kvId)
|
||||||
|
setMsg(kvc, "get_blob_result", result)
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "kv_client_message", kvc)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeKvSetBlobResult responds to a setBlobArgs request.
|
||||||
|
func EncodeKvSetBlobResult(kvId uint32) []byte {
|
||||||
|
result := newMsg("SetBlobResult")
|
||||||
|
|
||||||
|
kvc := newMsg("KvClientMessage")
|
||||||
|
setUint32(kvc, "id", kvId)
|
||||||
|
setMsg(kvc, "set_blob_result", result)
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "kv_client_message", kvc)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Exec response encoders ---
|
||||||
|
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
|
||||||
|
|
||||||
|
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
|
||||||
|
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
|
||||||
|
// RequestContext with tools
|
||||||
|
rc := newMsg("RequestContext")
|
||||||
|
if len(tools) > 0 {
|
||||||
|
toolsField := field(rc, "tools")
|
||||||
|
toolsList := rc.Mutable(toolsField).List()
|
||||||
|
for _, tool := range tools {
|
||||||
|
td := newMsg("McpToolDefinition")
|
||||||
|
setStr(td, "name", tool.Name)
|
||||||
|
setStr(td, "description", tool.Description)
|
||||||
|
if len(tool.InputSchema) > 0 {
|
||||||
|
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
|
||||||
|
}
|
||||||
|
setStr(td, "provider_identifier", "proxy")
|
||||||
|
setStr(td, "tool_name", tool.Name)
|
||||||
|
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestContextSuccess
|
||||||
|
rcs := newMsg("RequestContextSuccess")
|
||||||
|
setMsg(rcs, "request_context", rc)
|
||||||
|
|
||||||
|
// RequestContextResult (oneof success)
|
||||||
|
rcr := newMsg("RequestContextResult")
|
||||||
|
setMsg(rcr, "success", rcs)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeExecMcpResult responds with MCP tool result.
|
||||||
|
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
|
||||||
|
textContent := newMsg("McpTextContent")
|
||||||
|
setStr(textContent, "text", content)
|
||||||
|
|
||||||
|
contentItem := newMsg("McpToolResultContentItem")
|
||||||
|
setMsg(contentItem, "text", textContent)
|
||||||
|
|
||||||
|
success := newMsg("McpSuccess")
|
||||||
|
contentField := field(success, "content")
|
||||||
|
contentList := success.Mutable(contentField).List()
|
||||||
|
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
|
||||||
|
setBool(success, "is_error", isError)
|
||||||
|
|
||||||
|
result := newMsg("McpResult")
|
||||||
|
setMsg(result, "success", success)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeExecMcpError responds with MCP error.
|
||||||
|
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
mcpErr := newMsg("McpError")
|
||||||
|
setStr(mcpErr, "error", errMsg)
|
||||||
|
|
||||||
|
result := newMsg("McpResult")
|
||||||
|
setMsg(result, "error", mcpErr)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Rejection encoders (mirror handleExecMessage rejections) ---
|
||||||
|
|
||||||
|
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("ReadRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("ReadResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||||
|
rej := newMsg("ShellRejected")
|
||||||
|
setStr(rej, "command", command)
|
||||||
|
setStr(rej, "working_directory", workDir)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("ShellResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("WriteRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("WriteResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("DeleteRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("DeleteResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("LsRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("LsResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
grepErr := newMsg("GrepError")
|
||||||
|
setStr(grepErr, "error", errMsg)
|
||||||
|
result := newMsg("GrepResult")
|
||||||
|
setMsg(result, "error", grepErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
|
||||||
|
fetchErr := newMsg("FetchError")
|
||||||
|
setStr(fetchErr, "url", url)
|
||||||
|
setStr(fetchErr, "error", errMsg)
|
||||||
|
result := newMsg("FetchResult")
|
||||||
|
setMsg(result, "error", fetchErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
|
||||||
|
result := newMsg("DiagnosticsResult")
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||||
|
rej := newMsg("ShellRejected")
|
||||||
|
setStr(rej, "command", command)
|
||||||
|
setStr(rej, "working_directory", workDir)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("BackgroundShellSpawnResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
wsErr := newMsg("WriteShellStdinError")
|
||||||
|
setStr(wsErr, "error", errMsg)
|
||||||
|
result := newMsg("WriteShellStdinResult")
|
||||||
|
setMsg(result, "error", wsErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
|
||||||
|
// Mirrors sendExec() in cursor-fetch.ts.
|
||||||
|
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
|
||||||
|
ecm := newMsg("ExecClientMessage")
|
||||||
|
setUint32(ecm, "id", id)
|
||||||
|
// Force set exec_id even if empty - Cursor requires this field to be set
|
||||||
|
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
|
||||||
|
|
||||||
|
// Debug: check if field exists
|
||||||
|
fd := field(ecm, resultFieldName)
|
||||||
|
if fd == nil {
|
||||||
|
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug: log the actual field being set
|
||||||
|
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
|
||||||
|
|
||||||
|
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "exec_client_message", ecm)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func listFields(msg *dynamicpb.Message) []string {
|
||||||
|
var names []string
|
||||||
|
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
|
||||||
|
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Utilities ---
|
||||||
|
|
||||||
|
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
|
||||||
|
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
|
||||||
|
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
|
||||||
|
if len(jsonData) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var v interface{}
|
||||||
|
if err := json.Unmarshal(jsonData, &v); err != nil {
|
||||||
|
return jsonData // fallback to raw JSON if parsing fails
|
||||||
|
}
|
||||||
|
pbVal, err := structpb.NewValue(v)
|
||||||
|
if err != nil {
|
||||||
|
return jsonData // fallback
|
||||||
|
}
|
||||||
|
b, err := proto.Marshal(pbVal)
|
||||||
|
if err != nil {
|
||||||
|
return jsonData // fallback
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
|
||||||
|
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
|
||||||
|
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
|
||||||
|
val := &structpb.Value{}
|
||||||
|
if err := proto.Unmarshal(data, val); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return val.AsInterface(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sha256Sum(data []byte) []byte {
|
||||||
|
h := sha256.Sum256(data)
|
||||||
|
return h[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
var idCounter uint64
|
||||||
|
|
||||||
|
func generateId() string {
|
||||||
|
idCounter++
|
||||||
|
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
|
||||||
|
return hex.EncodeToString(h[:16])
|
||||||
|
}
|
||||||
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
|
||||||
|
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
|
||||||
|
package proto
|
||||||
|
|
||||||
|
// AgentClientMessage (msg 118) oneof "message"
|
||||||
|
const (
|
||||||
|
ACM_RunRequest = 1 // AgentRunRequest
|
||||||
|
ACM_ExecClientMessage = 2 // ExecClientMessage
|
||||||
|
ACM_KvClientMessage = 3 // KvClientMessage
|
||||||
|
ACM_ConversationAction = 4 // ConversationAction
|
||||||
|
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
|
||||||
|
ACM_InteractionResponse = 6 // InteractionResponse
|
||||||
|
ACM_ClientHeartbeat = 7 // ClientHeartbeat
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentServerMessage (msg 119) oneof "message"
|
||||||
|
const (
|
||||||
|
ASM_InteractionUpdate = 1 // InteractionUpdate
|
||||||
|
ASM_ExecServerMessage = 2 // ExecServerMessage
|
||||||
|
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
|
||||||
|
ASM_KvServerMessage = 4 // KvServerMessage
|
||||||
|
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
|
||||||
|
ASM_InteractionQuery = 7 // InteractionQuery
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentRunRequest (msg 91)
|
||||||
|
const (
|
||||||
|
ARR_ConversationState = 1 // ConversationStateStructure
|
||||||
|
ARR_Action = 2 // ConversationAction
|
||||||
|
ARR_ModelDetails = 3 // ModelDetails
|
||||||
|
ARR_McpTools = 4 // McpTools
|
||||||
|
ARR_ConversationId = 5 // string (optional)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationStateStructure (msg 83)
|
||||||
|
const (
|
||||||
|
CSS_RootPromptMessagesJson = 1 // repeated bytes
|
||||||
|
CSS_TurnsOld = 2 // repeated bytes (deprecated)
|
||||||
|
CSS_Todos = 3 // repeated bytes
|
||||||
|
CSS_PendingToolCalls = 4 // repeated string
|
||||||
|
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
|
||||||
|
CSS_PreviousWorkspaceUris = 9 // repeated string
|
||||||
|
CSS_SelfSummaryCount = 17 // uint32
|
||||||
|
CSS_ReadPaths = 18 // repeated string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationAction (msg 54) oneof "action"
|
||||||
|
const (
|
||||||
|
CA_UserMessageAction = 1 // UserMessageAction
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserMessageAction (msg 55)
|
||||||
|
const (
|
||||||
|
UMA_UserMessage = 1 // UserMessage
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserMessage (msg 63)
|
||||||
|
const (
|
||||||
|
UM_Text = 1 // string
|
||||||
|
UM_MessageId = 2 // string
|
||||||
|
UM_SelectedContext = 3 // SelectedContext (optional)
|
||||||
|
)
|
||||||
|
|
||||||
|
// SelectedContext
|
||||||
|
const (
|
||||||
|
SC_SelectedImages = 1 // repeated SelectedImage
|
||||||
|
)
|
||||||
|
|
||||||
|
// SelectedImage
|
||||||
|
const (
|
||||||
|
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
|
||||||
|
SI_Uuid = 2 // string
|
||||||
|
SI_Path = 3 // string
|
||||||
|
SI_MimeType = 7 // string
|
||||||
|
SI_Data = 8 // bytes (oneof dataOrBlobId)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelDetails (msg 88)
|
||||||
|
const (
|
||||||
|
MD_ModelId = 1 // string
|
||||||
|
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
|
||||||
|
MD_DisplayModelId = 3 // string
|
||||||
|
MD_DisplayName = 4 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpTools (msg 307)
|
||||||
|
const (
|
||||||
|
MT_McpTools = 1 // repeated McpToolDefinition
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpToolDefinition (msg 306)
|
||||||
|
const (
|
||||||
|
MTD_Name = 1 // string
|
||||||
|
MTD_Description = 2 // string
|
||||||
|
MTD_InputSchema = 3 // bytes
|
||||||
|
MTD_ProviderIdentifier = 4 // string
|
||||||
|
MTD_ToolName = 5 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationTurnStructure (msg 70) oneof "turn"
|
||||||
|
const (
|
||||||
|
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentConversationTurnStructure (msg 72)
|
||||||
|
const (
|
||||||
|
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
|
||||||
|
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationStep (msg 53) oneof "message"
|
||||||
|
const (
|
||||||
|
CS_AssistantMessage = 1 // AssistantMessage
|
||||||
|
)
|
||||||
|
|
||||||
|
// AssistantMessage
|
||||||
|
const (
|
||||||
|
AM_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Server-side message fields ---
|
||||||
|
|
||||||
|
// InteractionUpdate oneof "message"
|
||||||
|
const (
|
||||||
|
IU_TextDelta = 1 // TextDeltaUpdate
|
||||||
|
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
|
||||||
|
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
|
||||||
|
)
|
||||||
|
|
||||||
|
// TextDeltaUpdate (msg 92)
|
||||||
|
const (
|
||||||
|
TDU_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThinkingDeltaUpdate (msg 97)
|
||||||
|
const (
|
||||||
|
TKD_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// KvServerMessage (msg 271)
|
||||||
|
const (
|
||||||
|
KSM_Id = 1 // uint32
|
||||||
|
KSM_GetBlobArgs = 2 // GetBlobArgs
|
||||||
|
KSM_SetBlobArgs = 3 // SetBlobArgs
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetBlobArgs (msg 267)
|
||||||
|
const (
|
||||||
|
GBA_BlobId = 1 // bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetBlobArgs (msg 269)
|
||||||
|
const (
|
||||||
|
SBA_BlobId = 1 // bytes
|
||||||
|
SBA_BlobData = 2 // bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
// KvClientMessage (msg 272)
|
||||||
|
const (
|
||||||
|
KCM_Id = 1 // uint32
|
||||||
|
KCM_GetBlobResult = 2 // GetBlobResult
|
||||||
|
KCM_SetBlobResult = 3 // SetBlobResult
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetBlobResult (msg 268)
|
||||||
|
const (
|
||||||
|
GBR_BlobData = 1 // bytes (optional)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecServerMessage
|
||||||
|
const (
|
||||||
|
ESM_Id = 1 // uint32
|
||||||
|
ESM_ExecId = 15 // string
|
||||||
|
// oneof message:
|
||||||
|
ESM_ShellArgs = 2 // ShellArgs
|
||||||
|
ESM_WriteArgs = 3 // WriteArgs
|
||||||
|
ESM_DeleteArgs = 4 // DeleteArgs
|
||||||
|
ESM_GrepArgs = 5 // GrepArgs
|
||||||
|
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
|
||||||
|
ESM_LsArgs = 8 // LsArgs
|
||||||
|
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
|
||||||
|
ESM_RequestContextArgs = 10 // RequestContextArgs
|
||||||
|
ESM_McpArgs = 11 // McpArgs
|
||||||
|
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
|
||||||
|
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
|
||||||
|
ESM_FetchArgs = 20 // FetchArgs
|
||||||
|
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecClientMessage
|
||||||
|
const (
|
||||||
|
ECM_Id = 1 // uint32
|
||||||
|
ECM_ExecId = 15 // string
|
||||||
|
// oneof message (mirrors server fields):
|
||||||
|
ECM_ShellResult = 2
|
||||||
|
ECM_WriteResult = 3
|
||||||
|
ECM_DeleteResult = 4
|
||||||
|
ECM_GrepResult = 5
|
||||||
|
ECM_ReadResult = 7
|
||||||
|
ECM_LsResult = 8
|
||||||
|
ECM_DiagnosticsResult = 9
|
||||||
|
ECM_RequestContextResult = 10
|
||||||
|
ECM_McpResult = 11
|
||||||
|
ECM_ShellStream = 14
|
||||||
|
ECM_BackgroundShellSpawnRes = 16
|
||||||
|
ECM_FetchResult = 20
|
||||||
|
ECM_WriteShellStdinResult = 23
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpArgs
|
||||||
|
const (
|
||||||
|
MCA_Name = 1 // string
|
||||||
|
MCA_Args = 2 // map<string, bytes>
|
||||||
|
MCA_ToolCallId = 3 // string
|
||||||
|
MCA_ProviderIdentifier = 4 // string
|
||||||
|
MCA_ToolName = 5 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContextResult oneof "result"
|
||||||
|
const (
|
||||||
|
RCR_Success = 1 // RequestContextSuccess
|
||||||
|
RCR_Error = 2 // RequestContextError
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContextSuccess (msg 337)
|
||||||
|
const (
|
||||||
|
RCS_RequestContext = 1 // RequestContext
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContext
|
||||||
|
const (
|
||||||
|
RC_Rules = 2 // repeated CursorRule
|
||||||
|
RC_Tools = 7 // repeated McpToolDefinition
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpResult oneof "result"
|
||||||
|
const (
|
||||||
|
MCR_Success = 1 // McpSuccess
|
||||||
|
MCR_Error = 2 // McpError
|
||||||
|
MCR_Rejected = 3 // McpRejected
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpSuccess (msg 290)
|
||||||
|
const (
|
||||||
|
MCS_Content = 1 // repeated McpToolResultContentItem
|
||||||
|
MCS_IsError = 2 // bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpToolResultContentItem oneof "content"
|
||||||
|
const (
|
||||||
|
MTRCI_Text = 1 // McpTextContent
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpTextContent (msg 287)
|
||||||
|
const (
|
||||||
|
MTC_Text = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpError (msg 291)
|
||||||
|
const (
|
||||||
|
MCE_Error = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Rejection messages ---
|
||||||
|
|
||||||
|
// ReadRejected: path=1, reason=2
|
||||||
|
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
|
||||||
|
// WriteRejected: path=1, reason=2
|
||||||
|
// DeleteRejected: path=1, reason=2
|
||||||
|
// LsRejected: path=1, reason=2
|
||||||
|
// GrepError: error=1
|
||||||
|
// FetchError: url=1, error=2
|
||||||
|
// WriteShellStdinError: error=1
|
||||||
|
|
||||||
|
// ReadResult oneof: success=1, error=2, rejected=3
|
||||||
|
// ShellResult oneof: success=1 (+ various), rejected=?
|
||||||
|
// The TS code uses specific result field numbers from the oneof:
|
||||||
|
const (
|
||||||
|
RR_Rejected = 3 // ReadResult.rejected
|
||||||
|
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
|
||||||
|
WR_Rejected = 5 // WriteResult.rejected
|
||||||
|
DR_Rejected = 3 // DeleteResult.rejected
|
||||||
|
LR_Rejected = 3 // LsResult.rejected
|
||||||
|
GR_Error = 2 // GrepResult.error
|
||||||
|
FR_Error = 2 // FetchResult.error
|
||||||
|
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
|
||||||
|
WSSR_Error = 2 // WriteShellStdinResult.error
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Rejection struct fields ---
|
||||||
|
const (
|
||||||
|
REJ_Path = 1
|
||||||
|
REJ_Reason = 2
|
||||||
|
SREJ_Command = 1
|
||||||
|
SREJ_WorkingDir = 2
|
||||||
|
SREJ_Reason = 3
|
||||||
|
SREJ_IsReadonly = 4
|
||||||
|
GERR_Error = 1
|
||||||
|
FERR_Url = 1
|
||||||
|
FERR_Error = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadArgs
|
||||||
|
const (
|
||||||
|
RA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteArgs
|
||||||
|
const (
|
||||||
|
WA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeleteArgs
|
||||||
|
const (
|
||||||
|
DA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// LsArgs
|
||||||
|
const (
|
||||||
|
LA_Path = 1 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShellArgs
|
||||||
|
const (
|
||||||
|
SHA_Command = 1 // string
|
||||||
|
SHA_WorkingDirectory = 2 // string
|
||||||
|
)
|
||||||
|
|
||||||
|
// FetchArgs
|
||||||
|
const (
|
||||||
|
FA_Url = 1 // string
|
||||||
|
)
|
||||||
313
internal/auth/cursor/proto/h2stream.go
Normal file
313
internal/auth/cursor/proto/h2stream.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultInitialWindowSize = 65535 // HTTP/2 default
|
||||||
|
maxFramePayload = 16384 // HTTP/2 default max frame size
|
||||||
|
)
|
||||||
|
|
||||||
|
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
|
||||||
|
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
|
||||||
|
type H2Stream struct {
|
||||||
|
framer *http2.Framer
|
||||||
|
conn net.Conn
|
||||||
|
streamID uint32
|
||||||
|
mu sync.Mutex
|
||||||
|
id string // unique identifier for debugging
|
||||||
|
frameNum int64 // sequential frame counter for debugging
|
||||||
|
|
||||||
|
dataCh chan []byte
|
||||||
|
doneCh chan struct{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
// Send-side flow control
|
||||||
|
sendWindow int32 // available bytes we can send on this stream
|
||||||
|
connWindow int32 // available bytes on the connection level
|
||||||
|
windowCond *sync.Cond // signaled when window is updated
|
||||||
|
windowMu sync.Mutex // protects sendWindow, connWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the unique identifier for this stream (for logging).
|
||||||
|
func (s *H2Stream) ID() string { return s.id }
|
||||||
|
|
||||||
|
// FrameNum returns the current frame number for debugging.
|
||||||
|
func (s *H2Stream) FrameNum() int64 {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.frameNum
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
|
||||||
|
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
|
||||||
|
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
|
||||||
|
}
|
||||||
|
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: server did not negotiate h2")
|
||||||
|
}
|
||||||
|
|
||||||
|
framer := http2.NewFramer(tlsConn, tlsConn)
|
||||||
|
|
||||||
|
// Client connection preface
|
||||||
|
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: preface write failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial SETTINGS (tell server how much WE can receive)
|
||||||
|
if err := framer.WriteSettings(
|
||||||
|
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
|
||||||
|
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
|
||||||
|
); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: settings write failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection-level window update (for receiving)
|
||||||
|
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: window update failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
|
||||||
|
// Track server's initial window size (how much WE can send)
|
||||||
|
serverInitialWindowSize := int32(defaultInitialWindowSize)
|
||||||
|
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
f, err := framer.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
|
||||||
|
}
|
||||||
|
switch sf := f.(type) {
|
||||||
|
case *http2.SettingsFrame:
|
||||||
|
if !sf.IsAck() {
|
||||||
|
sf.ForeachSetting(func(s http2.Setting) error {
|
||||||
|
if s.ID == http2.SettingInitialWindowSize {
|
||||||
|
serverInitialWindowSize = int32(s.Val)
|
||||||
|
log.Debugf("h2: server initial window size: %d", s.Val)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
framer.WriteSettingsAck()
|
||||||
|
} else {
|
||||||
|
goto handshakeDone
|
||||||
|
}
|
||||||
|
case *http2.WindowUpdateFrame:
|
||||||
|
if sf.StreamID == 0 {
|
||||||
|
connWindowSize += int32(sf.Increment)
|
||||||
|
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// unexpected but continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handshakeDone:
|
||||||
|
|
||||||
|
// Build HEADERS
|
||||||
|
streamID := uint32(1)
|
||||||
|
var hdrBuf []byte
|
||||||
|
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
|
||||||
|
if p, ok := headers[":path"]; ok {
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
|
||||||
|
}
|
||||||
|
for k, v := range headers {
|
||||||
|
if len(k) > 0 && k[0] == ':' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||||
|
StreamID: streamID,
|
||||||
|
BlockFragment: hdrBuf,
|
||||||
|
EndStream: false,
|
||||||
|
EndHeaders: true,
|
||||||
|
}); err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, fmt.Errorf("h2: headers write failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &H2Stream{
|
||||||
|
framer: framer,
|
||||||
|
conn: tlsConn,
|
||||||
|
streamID: streamID,
|
||||||
|
dataCh: make(chan []byte, 256),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
|
||||||
|
frameNum: 0,
|
||||||
|
sendWindow: serverInitialWindowSize,
|
||||||
|
connWindow: connWindowSize,
|
||||||
|
}
|
||||||
|
s.windowCond = sync.NewCond(&s.windowMu)
|
||||||
|
go s.readLoop()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write sends a DATA frame on the stream, respecting flow control.
|
||||||
|
func (s *H2Stream) Write(data []byte) error {
|
||||||
|
for len(data) > 0 {
|
||||||
|
chunk := data
|
||||||
|
if len(chunk) > maxFramePayload {
|
||||||
|
chunk = data[:maxFramePayload]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for flow control window
|
||||||
|
s.windowMu.Lock()
|
||||||
|
for s.sendWindow <= 0 || s.connWindow <= 0 {
|
||||||
|
s.windowCond.Wait()
|
||||||
|
}
|
||||||
|
// Limit chunk to available window
|
||||||
|
allowed := int(s.sendWindow)
|
||||||
|
if int(s.connWindow) < allowed {
|
||||||
|
allowed = int(s.connWindow)
|
||||||
|
}
|
||||||
|
if len(chunk) > allowed {
|
||||||
|
chunk = chunk[:allowed]
|
||||||
|
}
|
||||||
|
s.sendWindow -= int32(len(chunk))
|
||||||
|
s.connWindow -= int32(len(chunk))
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
err := s.framer.WriteData(s.streamID, false, chunk)
|
||||||
|
s.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data = data[len(chunk):]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data returns the channel of received data chunks.
|
||||||
|
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
|
||||||
|
|
||||||
|
// Done returns a channel closed when the stream ends.
|
||||||
|
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
|
||||||
|
|
||||||
|
// Err returns the error (if any) that caused the stream to close.
|
||||||
|
// Returns nil for a clean shutdown (EOF / StreamEnded).
|
||||||
|
func (s *H2Stream) Err() error { return s.err }
|
||||||
|
|
||||||
|
// Close tears down the connection.
|
||||||
|
func (s *H2Stream) Close() {
|
||||||
|
s.conn.Close()
|
||||||
|
// Unblock any writers waiting on flow control
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *H2Stream) readLoop() {
|
||||||
|
defer close(s.doneCh)
|
||||||
|
defer close(s.dataCh)
|
||||||
|
|
||||||
|
for {
|
||||||
|
f, err := s.framer.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
s.err = err
|
||||||
|
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment frame counter
|
||||||
|
s.mu.Lock()
|
||||||
|
s.frameNum++
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
switch frame := f.(type) {
|
||||||
|
case *http2.DataFrame:
|
||||||
|
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
|
||||||
|
cp := make([]byte, len(frame.Data()))
|
||||||
|
copy(cp, frame.Data())
|
||||||
|
s.dataCh <- cp
|
||||||
|
|
||||||
|
// Flow control: send WINDOW_UPDATE for received data
|
||||||
|
s.mu.Lock()
|
||||||
|
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
|
||||||
|
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
if frame.StreamEnded() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.HeadersFrame:
|
||||||
|
if frame.StreamEnded() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.RSTStreamFrame:
|
||||||
|
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
|
||||||
|
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
|
||||||
|
return
|
||||||
|
|
||||||
|
case *http2.GoAwayFrame:
|
||||||
|
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
|
||||||
|
return
|
||||||
|
|
||||||
|
case *http2.PingFrame:
|
||||||
|
if !frame.IsAck() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.framer.WritePing(true, frame.Data)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.SettingsFrame:
|
||||||
|
if !frame.IsAck() {
|
||||||
|
// Check for window size changes
|
||||||
|
frame.ForeachSetting(func(setting http2.Setting) error {
|
||||||
|
if setting.ID == http2.SettingInitialWindowSize {
|
||||||
|
s.windowMu.Lock()
|
||||||
|
delta := int32(setting.Val) - s.sendWindow
|
||||||
|
s.sendWindow += delta
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
s.mu.Lock()
|
||||||
|
s.framer.WriteSettingsAck()
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
case *http2.WindowUpdateFrame:
|
||||||
|
// Update send-side flow control window
|
||||||
|
s.windowMu.Lock()
|
||||||
|
if frame.StreamID == 0 {
|
||||||
|
s.connWindow += int32(frame.Increment)
|
||||||
|
} else if frame.StreamID == s.streamID {
|
||||||
|
s.sendWindow += int32(frame.Increment)
|
||||||
|
}
|
||||||
|
s.windowMu.Unlock()
|
||||||
|
s.windowCond.Broadcast()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sliceWriter struct{ buf *[]byte }
|
||||||
|
|
||||||
|
func (w *sliceWriter) Write(p []byte) (int, error) {
|
||||||
|
*w.buf = append(*w.buf, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
@@ -10,9 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
@@ -20,9 +18,9 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
}
|
}
|
||||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
if errBuild != nil {
|
||||||
if err == nil {
|
log.Errorf("%v", errBuild)
|
||||||
var transport *http.Transport
|
} else if transport != nil {
|
||||||
if proxyURL.Scheme == "socks5" {
|
proxyClient := &http.Client{Transport: transport}
|
||||||
// Handle SOCKS5 proxy.
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
auth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
|
||||||
}
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Handle HTTP/HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
if transport != nil {
|
|
||||||
proxyClient := &http.Client{Transport: transport}
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: ClientID,
|
ClientID: ClientID,
|
||||||
@@ -327,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
defer manualPromptTimer.Stop()
|
defer manualPromptTimer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var manualInputCh <-chan string
|
||||||
|
var manualInputErrCh <-chan error
|
||||||
|
|
||||||
waitForCallback:
|
waitForCallback:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -348,13 +329,14 @@ waitForCallback:
|
|||||||
return nil, err
|
return nil, err
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||||
if err != nil {
|
continue
|
||||||
return nil, err
|
case input := <-manualInputCh:
|
||||||
}
|
manualInputCh = nil
|
||||||
parsed, err := misc.ParseOAuthCallback(input)
|
manualInputErrCh = nil
|
||||||
if err != nil {
|
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||||
return nil, err
|
if errParse != nil {
|
||||||
|
return nil, errParse
|
||||||
}
|
}
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
continue
|
continue
|
||||||
@@ -367,6 +349,8 @@ waitForCallback:
|
|||||||
}
|
}
|
||||||
authCode = parsed.Code
|
authCode = parsed.Code
|
||||||
break waitForCallback
|
break waitForCallback
|
||||||
|
case errManual := <-manualInputErrCh:
|
||||||
|
return nil, errManual
|
||||||
case <-timeoutTimer.C:
|
case <-timeoutTimer.C:
|
||||||
return nil, fmt.Errorf("oauth flow timed out")
|
return nil, fmt.Errorf("oauth flow timed out")
|
||||||
}
|
}
|
||||||
|
|||||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
package gitlab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultBaseURL = "https://gitlab.com"
|
||||||
|
DefaultCallbackPort = 17171
|
||||||
|
defaultOAuthScope = "api read_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PKCECodes struct {
|
||||||
|
CodeVerifier string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthResult struct {
|
||||||
|
Code string
|
||||||
|
State string
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthServer struct {
|
||||||
|
server *http.Server
|
||||||
|
port int
|
||||||
|
resultChan chan *OAuthResult
|
||||||
|
errorChan chan error
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
PublicEmail string `json:"public_email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PersonalAccessTokenSelf struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Scopes []string `json:"scopes"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelDetails struct {
|
||||||
|
ModelProvider string `json:"model_provider"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DirectAccessResponse struct {
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
Headers map[string]string `json:"headers"`
|
||||||
|
ModelDetails *ModelDetails `json:"model_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiscoveredModel struct {
|
||||||
|
ModelProvider string
|
||||||
|
ModelName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthClient(cfg *config.Config) *AuthClient {
|
||||||
|
client := &http.Client{}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &AuthClient{httpClient: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeBaseURL(raw string) string {
|
||||||
|
value := strings.TrimSpace(raw)
|
||||||
|
if value == "" {
|
||||||
|
return DefaultBaseURL
|
||||||
|
}
|
||||||
|
if !strings.Contains(value, "://") {
|
||||||
|
value = "https://" + value
|
||||||
|
}
|
||||||
|
value = strings.TrimRight(value, "/")
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
||||||
|
if token == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
||||||
|
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
||||||
|
}
|
||||||
|
if token.ExpiresIn > 0 {
|
||||||
|
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||||
|
}
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||||
|
verifierBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(verifierBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
||||||
|
}
|
||||||
|
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||||
|
sum := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||||
|
return &PKCECodes{
|
||||||
|
CodeVerifier: verifier,
|
||||||
|
CodeChallenge: challenge,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthServer(port int) *OAuthServer {
|
||||||
|
return &OAuthServer{
|
||||||
|
port: port,
|
||||||
|
resultChan: make(chan *OAuthResult, 1),
|
||||||
|
errorChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) Start() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.running {
|
||||||
|
return fmt.Errorf("gitlab oauth server already running")
|
||||||
|
}
|
||||||
|
if !s.isPortAvailable() {
|
||||||
|
return fmt.Errorf("port %d is already in use", s.port)
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", s.port),
|
||||||
|
Handler: mux,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
s.running = true
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
s.errorChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if !s.running || s.server == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
s.running = false
|
||||||
|
s.server = nil
|
||||||
|
}()
|
||||||
|
return s.server.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||||
|
select {
|
||||||
|
case result := <-s.resultChan:
|
||||||
|
return result, nil
|
||||||
|
case err := <-s.errorChan:
|
||||||
|
return nil, err
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
query := r.URL.Query()
|
||||||
|
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||||
|
s.sendResult(&OAuthResult{Error: errParam})
|
||||||
|
http.Error(w, errParam, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := strings.TrimSpace(query.Get("code"))
|
||||||
|
state := strings.TrimSpace(query.Get("state"))
|
||||||
|
if code == "" || state == "" {
|
||||||
|
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
||||||
|
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||||
|
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||||
|
select {
|
||||||
|
case s.resultChan <- result:
|
||||||
|
default:
|
||||||
|
log.Debug("gitlab oauth result channel full, dropping callback result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthServer) isPortAvailable() bool {
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_ = listener.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func RedirectURL(port int) string {
|
||||||
|
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
||||||
|
if pkce == nil {
|
||||||
|
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(clientID) == "" {
|
||||||
|
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
||||||
|
}
|
||||||
|
baseURL = NormalizeBaseURL(baseURL)
|
||||||
|
params := url.Values{
|
||||||
|
"client_id": {strings.TrimSpace(clientID)},
|
||||||
|
"response_type": {"code"},
|
||||||
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
|
"scope": {defaultOAuthScope},
|
||||||
|
"state": {strings.TrimSpace(state)},
|
||||||
|
"code_challenge": {pkce.CodeChallenge},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"client_id": {strings.TrimSpace(clientID)},
|
||||||
|
"code": {strings.TrimSpace(code)},
|
||||||
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
|
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
||||||
|
}
|
||||||
|
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||||
|
form.Set("client_secret", secret)
|
||||||
|
}
|
||||||
|
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||||
|
}
|
||||||
|
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
||||||
|
form.Set("client_id", clientID)
|
||||||
|
}
|
||||||
|
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||||
|
form.Set("client_secret", secret)
|
||||||
|
}
|
||||||
|
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
var token TokenResponse
|
||||||
|
if err := json.Unmarshal(body, &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
if err := json.Unmarshal(body, &user); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var pat PersonalAccessTokenSelf
|
||||||
|
if err := json.Unmarshal(body, &pat); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
return &pat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var direct DirectAccessResponse
|
||||||
|
if err := json.Unmarshal(body, &direct); err != nil {
|
||||||
|
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
||||||
|
}
|
||||||
|
if direct.Headers == nil {
|
||||||
|
direct.Headers = make(map[string]string)
|
||||||
|
}
|
||||||
|
return &direct, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]DiscoveredModel, 0, 4)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
appendModel := func(provider, name string) {
|
||||||
|
provider = strings.TrimSpace(provider)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
models = append(models, DiscoveredModel{
|
||||||
|
ModelProvider: provider,
|
||||||
|
ModelName: name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw, ok := metadata["model_details"]; ok {
|
||||||
|
appendDiscoveredModels(raw, appendModel)
|
||||||
|
}
|
||||||
|
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||||
|
|
||||||
|
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||||
|
if raw, ok := metadata[key]; ok {
|
||||||
|
appendDiscoveredModels(raw, appendModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||||
|
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||||
|
if nested, ok := typed["models"]; ok {
|
||||||
|
appendDiscoveredModels(nested, appendModel)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendDiscoveredModels(item, appendModel)
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendModel("", item)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
appendModel("", typed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringValue(raw any) string {
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
case fmt.Stringer:
|
||||||
|
return strings.TrimSpace(typed.String())
|
||||||
|
case json.Number:
|
||||||
|
return typed.String()
|
||||||
|
case int:
|
||||||
|
return strconv.Itoa(typed)
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(typed, 10)
|
||||||
|
case float64:
|
||||||
|
return strconv.FormatInt(int64(typed), 10)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
138
internal/auth/gitlab/gitlab_test.go
Normal file
138
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package gitlab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
|
||||||
|
client := NewAuthClient(nil)
|
||||||
|
pkce, err := GeneratePKCECodes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCECodes() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateAuthURL() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse(authURL) error = %v", err)
|
||||||
|
}
|
||||||
|
if got := parsed.Path; got != "/oauth/authorize" {
|
||||||
|
t.Fatalf("expected /oauth/authorize path, got %q", got)
|
||||||
|
}
|
||||||
|
query := parsed.Query()
|
||||||
|
if got := query.Get("client_id"); got != "client-id" {
|
||||||
|
t.Fatalf("expected client_id, got %q", got)
|
||||||
|
}
|
||||||
|
if got := query.Get("scope"); got != defaultOAuthScope {
|
||||||
|
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
|
||||||
|
}
|
||||||
|
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||||
|
t.Fatalf("expected PKCE method S256, got %q", got)
|
||||||
|
}
|
||||||
|
if got := query.Get("code_challenge"); got == "" {
|
||||||
|
t.Fatal("expected non-empty code_challenge")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/token" {
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("ParseForm() error = %v", err)
|
||||||
|
}
|
||||||
|
if got := r.Form.Get("grant_type"); got != "authorization_code" {
|
||||||
|
t.Fatalf("expected authorization_code grant, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
|
||||||
|
t.Fatalf("expected code_verifier, got %q", got)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"access_token": "oauth-access",
|
||||||
|
"refresh_token": "oauth-refresh",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "api read_user",
|
||||||
|
"created_at": 1710000000,
|
||||||
|
"expires_in": 3600,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewAuthClient(nil)
|
||||||
|
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
|
||||||
|
}
|
||||||
|
if token.AccessToken != "oauth-access" {
|
||||||
|
t.Fatalf("expected access token, got %q", token.AccessToken)
|
||||||
|
}
|
||||||
|
if token.RefreshToken != "oauth-refresh" {
|
||||||
|
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractDiscoveredModels(t *testing.T) {
|
||||||
|
models := ExtractDiscoveredModels(map[string]any{
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
"supported_models": []any{
|
||||||
|
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(models) != 2 {
|
||||||
|
t.Fatalf("expected 2 unique models, got %d", len(models))
|
||||||
|
}
|
||||||
|
if models[0].ModelName != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("unexpected first model %q", models[0].ModelName)
|
||||||
|
}
|
||||||
|
if models[1].ModelName != "gpt-4.1" {
|
||||||
|
t.Fatalf("unexpected second model %q", models[1].ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
|
||||||
|
t.Fatalf("expected bearer token, got %q", got)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"base_url": "https://cloud.gitlab.example.com",
|
||||||
|
"token": "gateway-token",
|
||||||
|
"expires_at": 1710003600,
|
||||||
|
"headers": map[string]string{
|
||||||
|
"X-Gitlab-Realm": "saas",
|
||||||
|
},
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewAuthClient(nil)
|
||||||
|
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FetchDirectAccess() error = %v", err)
|
||||||
|
}
|
||||||
|
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -102,10 +102,24 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
|||||||
|
|
||||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||||
|
return NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg, deviceID, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceFlowClientWithDeviceIDAndProxyURL creates a new device flow client with a proxy override.
|
||||||
|
// proxyURL takes precedence over cfg.ProxyURL when non-empty.
|
||||||
|
func NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg *config.Config, deviceID string, proxyURL string) *DeviceFlowClient {
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg config.SDKConfig
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
sdkCfg = cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
sdkCfg.ProxyURL = effectiveProxyURL
|
||||||
|
client = util.SetProxy(&sdkCfg, client)
|
||||||
|
|
||||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||||
if resolvedDeviceID == "" {
|
if resolvedDeviceID == "" {
|
||||||
resolvedDeviceID = getOrCreateDeviceID()
|
resolvedDeviceID = getOrCreateDeviceID()
|
||||||
|
|||||||
42
internal/auth/kimi/kimi_proxy_test.go
Normal file
42
internal/auth/kimi/kimi_proxy_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package kimi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDeviceFlowClientWithDeviceIDAndProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"}}
|
||||||
|
client := NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg, "device-1", "direct")
|
||||||
|
|
||||||
|
transport, ok := client.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", client.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDeviceFlowClientWithDeviceIDAndProxyURL_OverrideProxyTakesPrecedence(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://global.example.com:8080"}}
|
||||||
|
client := NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg, "device-1", "http://override.example.com:8081")
|
||||||
|
|
||||||
|
transport, ok := client.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", client.httpClient.Transport)
|
||||||
|
}
|
||||||
|
req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errReq != nil {
|
||||||
|
t.Fatalf("new request: %v", errReq)
|
||||||
|
}
|
||||||
|
proxyURL, errProxy := transport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("proxy func: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://override.example.com:8081" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://override.example.com:8081", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -748,4 +748,3 @@ func TestExtractRegionFromMetadata(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CooldownReason429 = "rate_limit_exceeded"
|
CooldownReason429 = "rate_limit_exceeded"
|
||||||
CooldownReasonSuspended = "account_suspended"
|
CooldownReasonSuspended = "account_suspended"
|
||||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||||
|
|
||||||
DefaultShortCooldown = 1 * time.Minute
|
DefaultShortCooldown = 1 * time.Minute
|
||||||
|
|||||||
@@ -26,9 +26,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
jitterRand *rand.Rand
|
jitterRand *rand.Rand
|
||||||
jitterRandOnce sync.Once
|
jitterRandOnce sync.Once
|
||||||
jitterMu sync.Mutex
|
jitterMu sync.Mutex
|
||||||
lastRequestTime time.Time
|
lastRequestTime time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ type TokenScorer struct {
|
|||||||
metrics map[string]*TokenMetrics
|
metrics map[string]*TokenMetrics
|
||||||
|
|
||||||
// Scoring weights
|
// Scoring weights
|
||||||
successRateWeight float64
|
successRateWeight float64
|
||||||
quotaWeight float64
|
quotaWeight float64
|
||||||
latencyWeight float64
|
latencyWeight float64
|
||||||
lastUsedWeight float64
|
lastUsedWeight float64
|
||||||
failPenaltyMultiplier float64
|
failPenaltyMultiplier float64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
|
|||||||
var listener net.Listener
|
var listener net.Listener
|
||||||
var err error
|
var err error
|
||||||
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
|
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
|
||||||
|
|
||||||
for _, port := range portRange {
|
for _, port := range portRange {
|
||||||
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -105,7 +105,7 @@ func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
|
|||||||
}
|
}
|
||||||
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
|
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
|
||||||
}
|
}
|
||||||
|
|
||||||
if listener == nil {
|
if listener == nil {
|
||||||
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
|
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,359 +0,0 @@
|
|||||||
package qwen
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// QwenOAuthDeviceCodeEndpoint is the URL for initiating the OAuth 2.0 device authorization flow.
|
|
||||||
QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code"
|
|
||||||
// QwenOAuthTokenEndpoint is the URL for exchanging device codes or refresh tokens for access tokens.
|
|
||||||
QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token"
|
|
||||||
// QwenOAuthClientID is the client identifier for the Qwen OAuth 2.0 application.
|
|
||||||
QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56"
|
|
||||||
// QwenOAuthScope defines the permissions requested by the application.
|
|
||||||
QwenOAuthScope = "openid profile email model.completion"
|
|
||||||
// QwenOAuthGrantType specifies the grant type for the device code flow.
|
|
||||||
QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QwenTokenData represents the OAuth credentials, including access and refresh tokens.
|
|
||||||
type QwenTokenData struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
// RefreshToken is used to obtain a new access token when the current one expires.
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
// TokenType indicates the type of token, typically "Bearer".
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
// ResourceURL specifies the base URL of the resource server.
|
|
||||||
ResourceURL string `json:"resource_url,omitempty"`
|
|
||||||
// Expire indicates the expiration date and time of the access token.
|
|
||||||
Expire string `json:"expiry_date,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceFlow represents the response from the device authorization endpoint.
|
|
||||||
type DeviceFlow struct {
|
|
||||||
// DeviceCode is the code that the client uses to poll for an access token.
|
|
||||||
DeviceCode string `json:"device_code"`
|
|
||||||
// UserCode is the code that the user enters at the verification URI.
|
|
||||||
UserCode string `json:"user_code"`
|
|
||||||
// VerificationURI is the URL where the user can enter the user code to authorize the device.
|
|
||||||
VerificationURI string `json:"verification_uri"`
|
|
||||||
// VerificationURIComplete is a URI that includes the user_code, which can be used to automatically
|
|
||||||
// fill in the code on the verification page.
|
|
||||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
|
||||||
// ExpiresIn is the time in seconds until the device_code and user_code expire.
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
// Interval is the minimum time in seconds that the client should wait between polling requests.
|
|
||||||
Interval int `json:"interval"`
|
|
||||||
// CodeVerifier is the cryptographically random string used in the PKCE flow.
|
|
||||||
CodeVerifier string `json:"code_verifier"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// QwenTokenResponse represents the successful token response from the token endpoint.
|
|
||||||
type QwenTokenResponse struct {
|
|
||||||
// AccessToken is the token used to access protected resources.
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
// RefreshToken is used to obtain a new access token.
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
// TokenType indicates the type of token, typically "Bearer".
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
// ResourceURL specifies the base URL of the resource server.
|
|
||||||
ResourceURL string `json:"resource_url,omitempty"`
|
|
||||||
// ExpiresIn is the time in seconds until the access token expires.
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// QwenAuth manages authentication and token handling for the Qwen API.
|
|
||||||
type QwenAuth struct {
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQwenAuth creates a new QwenAuth instance with a proxy-configured HTTP client.
|
|
||||||
func NewQwenAuth(cfg *config.Config) *QwenAuth {
|
|
||||||
return &QwenAuth{
|
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCodeVerifier generates a cryptographically random string for the PKCE code verifier.
|
|
||||||
func (qa *QwenAuth) generateCodeVerifier() (string, error) {
|
|
||||||
bytes := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCodeChallenge creates a SHA-256 hash of the code verifier, used as the PKCE code challenge.
|
|
||||||
func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string {
|
|
||||||
hash := sha256.Sum256([]byte(codeVerifier))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(hash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// generatePKCEPair creates a new code verifier and its corresponding code challenge for PKCE.
|
|
||||||
func (qa *QwenAuth) generatePKCEPair() (string, string, error) {
|
|
||||||
codeVerifier, err := qa.generateCodeVerifier()
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
codeChallenge := qa.generateCodeChallenge(codeVerifier)
|
|
||||||
return codeVerifier, codeChallenge, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshTokens exchanges a refresh token for a new access token.
|
|
||||||
func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("grant_type", "refresh_token")
|
|
||||||
data.Set("refresh_token", refreshToken)
|
|
||||||
data.Set("client_id", QwenOAuthClientID)
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := qa.httpClient.Do(req)
|
|
||||||
|
|
||||||
// resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
var errorData map[string]interface{}
|
|
||||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
|
||||||
return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"])
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokenData QwenTokenResponse
|
|
||||||
if err = json.Unmarshal(body, &tokenData); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &QwenTokenData{
|
|
||||||
AccessToken: tokenData.AccessToken,
|
|
||||||
TokenType: tokenData.TokenType,
|
|
||||||
RefreshToken: tokenData.RefreshToken,
|
|
||||||
ResourceURL: tokenData.ResourceURL,
|
|
||||||
Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details.
|
|
||||||
func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) {
|
|
||||||
// Generate PKCE code verifier and challenge
|
|
||||||
codeVerifier, codeChallenge, err := qa.generatePKCEPair()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate PKCE pair: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("client_id", QwenOAuthClientID)
|
|
||||||
data.Set("scope", QwenOAuthScope)
|
|
||||||
data.Set("code_challenge", codeChallenge)
|
|
||||||
data.Set("code_challenge_method", "S256")
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := qa.httpClient.Do(req)
|
|
||||||
|
|
||||||
// resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("device authorization request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
var result DeviceFlow
|
|
||||||
if err = json.Unmarshal(body, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse device flow response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the response indicates success
|
|
||||||
if result.DeviceCode == "" {
|
|
||||||
return nil, fmt.Errorf("device authorization failed: device_code not found in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the code_verifier to the result so it can be used later for polling
|
|
||||||
result.CodeVerifier = codeVerifier
|
|
||||||
|
|
||||||
return &result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PollForToken polls the token endpoint with the device code to obtain an access token.
|
|
||||||
func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) {
|
|
||||||
pollInterval := 5 * time.Second
|
|
||||||
maxAttempts := 60 // 5 minutes max
|
|
||||||
|
|
||||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("grant_type", QwenOAuthGrantType)
|
|
||||||
data.Set("client_id", QwenOAuthClientID)
|
|
||||||
data.Set("device_code", deviceCode)
|
|
||||||
data.Set("code_verifier", codeVerifier)
|
|
||||||
|
|
||||||
resp, err := http.PostForm(QwenOAuthTokenEndpoint, data)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
// Parse the response as JSON to check for OAuth RFC 8628 standard errors
|
|
||||||
var errorData map[string]interface{}
|
|
||||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
|
||||||
// According to OAuth RFC 8628, handle standard polling responses
|
|
||||||
if resp.StatusCode == http.StatusBadRequest {
|
|
||||||
errorType, _ := errorData["error"].(string)
|
|
||||||
switch errorType {
|
|
||||||
case "authorization_pending":
|
|
||||||
// User has not yet approved the authorization request. Continue polling.
|
|
||||||
fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
case "slow_down":
|
|
||||||
// Client is polling too frequently. Increase poll interval.
|
|
||||||
pollInterval = time.Duration(float64(pollInterval) * 1.5)
|
|
||||||
if pollInterval > 10*time.Second {
|
|
||||||
pollInterval = 10 * time.Second
|
|
||||||
}
|
|
||||||
fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
case "expired_token":
|
|
||||||
return nil, fmt.Errorf("device code expired. Please restart the authentication process")
|
|
||||||
case "access_denied":
|
|
||||||
return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For other errors, return with proper error information
|
|
||||||
errorType, _ := errorData["error"].(string)
|
|
||||||
errorDesc, _ := errorData["error_description"].(string)
|
|
||||||
return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If JSON parsing fails, fall back to text response
|
|
||||||
return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
|
||||||
}
|
|
||||||
// log.Debugf("%s", string(body))
|
|
||||||
// Success - parse token data
|
|
||||||
var response QwenTokenResponse
|
|
||||||
if err = json.Unmarshal(body, &response); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to QwenTokenData format and save
|
|
||||||
tokenData := &QwenTokenData{
|
|
||||||
AccessToken: response.AccessToken,
|
|
||||||
RefreshToken: response.RefreshToken,
|
|
||||||
TokenType: response.TokenType,
|
|
||||||
ResourceURL: response.ResourceURL,
|
|
||||||
Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("authentication timeout. Please restart the authentication process")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure.
|
|
||||||
func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) {
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
||||||
if attempt > 0 {
|
|
||||||
// Wait before retry
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-time.After(time.Duration(attempt) * time.Second):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
|
||||||
if err == nil {
|
|
||||||
return tokenData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
lastErr = err
|
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTokenStorage creates a QwenTokenStorage object from a QwenTokenData object.
|
|
||||||
func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage {
|
|
||||||
storage := &QwenTokenStorage{
|
|
||||||
AccessToken: tokenData.AccessToken,
|
|
||||||
RefreshToken: tokenData.RefreshToken,
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
ResourceURL: tokenData.ResourceURL,
|
|
||||||
Expire: tokenData.Expire,
|
|
||||||
}
|
|
||||||
|
|
||||||
return storage
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing token storage with new token data
|
|
||||||
func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) {
|
|
||||||
storage.AccessToken = tokenData.AccessToken
|
|
||||||
storage.RefreshToken = tokenData.RefreshToken
|
|
||||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
|
||||||
storage.ResourceURL = tokenData.ResourceURL
|
|
||||||
storage.Expire = tokenData.Expire
|
|
||||||
}
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
// Package qwen provides authentication and token management functionality
|
|
||||||
// for Alibaba's Qwen AI services. It handles OAuth2 token storage, serialization,
|
|
||||||
// and retrieval for maintaining authenticated sessions with the Qwen API.
|
|
||||||
package qwen
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication.
|
|
||||||
// It maintains compatibility with the existing auth system while adding Qwen-specific fields
|
|
||||||
// for managing access tokens, refresh tokens, and user account information.
|
|
||||||
type QwenTokenStorage struct {
|
|
||||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
// RefreshToken is used to obtain new access tokens when the current one expires.
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
// LastRefresh is the timestamp of the last token refresh operation.
|
|
||||||
LastRefresh string `json:"last_refresh"`
|
|
||||||
// ResourceURL is the base URL for API requests.
|
|
||||||
ResourceURL string `json:"resource_url"`
|
|
||||||
// Email is the Qwen account email address associated with this token.
|
|
||||||
Email string `json:"email"`
|
|
||||||
// Type indicates the authentication provider type, always "qwen" for this storage.
|
|
||||||
Type string `json:"type"`
|
|
||||||
// Expire is the timestamp when the current access token expires.
|
|
||||||
Expire string `json:"expired"`
|
|
||||||
|
|
||||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
|
||||||
// It is not exported to JSON directly to allow flattening during serialization.
|
|
||||||
Metadata map[string]any `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
|
||||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
|
||||||
ts.Metadata = meta
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
|
||||||
// This method creates the necessary directory structure and writes the token
|
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
|
||||||
// It merges any injected metadata into the top-level JSON object.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - authFilePath: The full path where the token file should be saved
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
// - error: An error if the operation fails, nil otherwise
|
|
||||||
func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|
||||||
misc.LogSavingCredentials(authFilePath)
|
|
||||||
ts.Type = "qwen"
|
|
||||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.Create(authFilePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create token file: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = f.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Merge metadata using helper
|
|
||||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
|
||||||
if errMerge != nil {
|
|
||||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -30,6 +30,10 @@ type VertexCredentialStorage struct {
|
|||||||
|
|
||||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Prefix optionally namespaces models for this credential (e.g., "teamA").
|
||||||
|
// This results in model names like "teamA/gemini-2.0-flash".
|
||||||
|
Prefix string `json:"prefix,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func CloseBrowser() error {
|
|||||||
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
|
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := lastBrowserProcess.Process.Kill()
|
err := lastBrowserProcess.Process.Kill()
|
||||||
lastBrowserProcess = nil
|
lastBrowserProcess = nil
|
||||||
return err
|
return err
|
||||||
|
|||||||
45
internal/cache/signature_cache.go
vendored
45
internal/cache/signature_cache.go
vendored
@@ -5,7 +5,10 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SignatureEntry holds a cached thinking signature with timestamp
|
// SignatureEntry holds a cached thinking signature with timestamp
|
||||||
@@ -193,3 +196,45 @@ func GetModelGroup(modelName string) string {
|
|||||||
}
|
}
|
||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var signatureCacheEnabled atomic.Bool
|
||||||
|
var signatureBypassStrictMode atomic.Bool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
signatureCacheEnabled.Store(true)
|
||||||
|
signatureBypassStrictMode.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSignatureCacheEnabled switches Antigravity signature handling between cache mode and bypass mode.
|
||||||
|
func SetSignatureCacheEnabled(enabled bool) {
|
||||||
|
previous := signatureCacheEnabled.Swap(enabled)
|
||||||
|
if previous == enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !enabled {
|
||||||
|
log.Info("antigravity signature cache DISABLED - bypass mode active, cached signatures will not be used for request translation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureCacheEnabled returns whether signature cache validation is enabled.
|
||||||
|
func SignatureCacheEnabled() bool {
|
||||||
|
return signatureCacheEnabled.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSignatureBypassStrictMode controls whether bypass mode uses strict protobuf-tree validation.
|
||||||
|
func SetSignatureBypassStrictMode(strict bool) {
|
||||||
|
previous := signatureBypassStrictMode.Swap(strict)
|
||||||
|
if previous == strict {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strict {
|
||||||
|
log.Debug("antigravity bypass signature validation: strict mode (protobuf tree)")
|
||||||
|
} else {
|
||||||
|
log.Debug("antigravity bypass signature validation: basic mode (R/E + 0x12)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureBypassStrictMode returns whether bypass mode uses strict protobuf-tree validation.
|
||||||
|
func SignatureBypassStrictMode() bool {
|
||||||
|
return signatureBypassStrictMode.Load()
|
||||||
|
}
|
||||||
|
|||||||
91
internal/cache/signature_cache_test.go
vendored
91
internal/cache/signature_cache_test.go
vendored
@@ -1,8 +1,12 @@
|
|||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testModelName = "claude-sonnet-4-5"
|
const testModelName = "claude-sonnet-4-5"
|
||||||
@@ -208,3 +212,90 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
|||||||
// but the logic is verified by the implementation
|
// but the logic is verified by the implementation
|
||||||
_ = time.Now() // Acknowledge we're not testing time passage
|
_ = time.Now() // Acknowledge we're not testing time passage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSignatureModeSetters_LogAtInfoLevel(t *testing.T) {
|
||||||
|
logger := log.StandardLogger()
|
||||||
|
previousOutput := logger.Out
|
||||||
|
previousLevel := logger.Level
|
||||||
|
previousCache := SignatureCacheEnabled()
|
||||||
|
previousStrict := SignatureBypassStrictMode()
|
||||||
|
SetSignatureCacheEnabled(true)
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
log.SetOutput(buffer)
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
log.SetOutput(previousOutput)
|
||||||
|
log.SetLevel(previousLevel)
|
||||||
|
SetSignatureCacheEnabled(previousCache)
|
||||||
|
SetSignatureBypassStrictMode(previousStrict)
|
||||||
|
})
|
||||||
|
|
||||||
|
SetSignatureCacheEnabled(false)
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
|
||||||
|
output := buffer.String()
|
||||||
|
if !strings.Contains(output, "antigravity signature cache DISABLED") {
|
||||||
|
t.Fatalf("expected info output for disabling signature cache, got: %q", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "strict mode (protobuf tree)") {
|
||||||
|
t.Fatalf("expected strict bypass mode log to stay below info level, got: %q", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "basic mode (R/E + 0x12)") {
|
||||||
|
t.Fatalf("expected basic bypass mode log to stay below info level, got: %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignatureModeSetters_DoNotRepeatSameStateLogs(t *testing.T) {
|
||||||
|
logger := log.StandardLogger()
|
||||||
|
previousOutput := logger.Out
|
||||||
|
previousLevel := logger.Level
|
||||||
|
previousCache := SignatureCacheEnabled()
|
||||||
|
previousStrict := SignatureBypassStrictMode()
|
||||||
|
SetSignatureCacheEnabled(false)
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
log.SetOutput(buffer)
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
log.SetOutput(previousOutput)
|
||||||
|
log.SetLevel(previousLevel)
|
||||||
|
SetSignatureCacheEnabled(previousCache)
|
||||||
|
SetSignatureBypassStrictMode(previousStrict)
|
||||||
|
})
|
||||||
|
|
||||||
|
SetSignatureCacheEnabled(false)
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
|
||||||
|
if buffer.Len() != 0 {
|
||||||
|
t.Fatalf("expected repeated setter calls with unchanged state to stay silent, got: %q", buffer.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignatureBypassStrictMode_LogsAtDebugLevel(t *testing.T) {
|
||||||
|
logger := log.StandardLogger()
|
||||||
|
previousOutput := logger.Out
|
||||||
|
previousLevel := logger.Level
|
||||||
|
previousStrict := SignatureBypassStrictMode()
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
log.SetOutput(buffer)
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
log.SetOutput(previousOutput)
|
||||||
|
log.SetLevel(previousLevel)
|
||||||
|
SetSignatureBypassStrictMode(previousStrict)
|
||||||
|
})
|
||||||
|
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
|
||||||
|
output := buffer.String()
|
||||||
|
if !strings.Contains(output, "strict mode (protobuf tree)") {
|
||||||
|
t.Fatalf("expected debug output for strict bypass mode, got: %q", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "basic mode (R/E + 0x12)") {
|
||||||
|
t.Fatalf("expected debug output for basic bypass mode, got: %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// newAuthManager creates a new authentication manager instance with all supported
|
// newAuthManager creates a new authentication manager instance with all supported
|
||||||
// authenticators and a file-based token store. It initializes authenticators for
|
// authenticators and a file-based token store.
|
||||||
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||||
@@ -16,13 +15,14 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewGeminiAuthenticator(),
|
sdkAuth.NewGeminiAuthenticator(),
|
||||||
sdkAuth.NewCodexAuthenticator(),
|
sdkAuth.NewCodexAuthenticator(),
|
||||||
sdkAuth.NewClaudeAuthenticator(),
|
sdkAuth.NewClaudeAuthenticator(),
|
||||||
sdkAuth.NewQwenAuthenticator(),
|
|
||||||
sdkAuth.NewIFlowAuthenticator(),
|
|
||||||
sdkAuth.NewAntigravityAuthenticator(),
|
sdkAuth.NewAntigravityAuthenticator(),
|
||||||
sdkAuth.NewKimiAuthenticator(),
|
sdkAuth.NewKimiAuthenticator(),
|
||||||
sdkAuth.NewKiroAuthenticator(),
|
sdkAuth.NewKiroAuthenticator(),
|
||||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||||
sdkAuth.NewKiloAuthenticator(),
|
sdkAuth.NewKiloAuthenticator(),
|
||||||
|
sdkAuth.NewGitLabAuthenticator(),
|
||||||
|
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||||
|
sdkAuth.NewCursorAuthenticator(),
|
||||||
)
|
)
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|||||||
43
internal/cmd/codebuddy_login.go
Normal file
43
internal/cmd/codebuddy_login.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
|
||||||
|
// It initiates the OAuth authentication, displays the user code for the user to enter
|
||||||
|
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy and auth directory settings
|
||||||
|
// - options: Login options including browser behavior settings
|
||||||
|
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("CodeBuddy authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("CodeBuddy authentication successful!")
|
||||||
|
}
|
||||||
37
internal/cmd/cursor_login.go
Normal file
37
internal/cmd/cursor_login.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
|
||||||
|
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Cursor authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
log.Infof("Authentication saved to %s", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
log.Infof("Authenticated as %s", record.Label)
|
||||||
|
}
|
||||||
|
log.Info("Cursor authentication successful!")
|
||||||
|
}
|
||||||
69
internal/cmd/gitlab_login.go
Normal file
69
internal/cmd/gitlab_login.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"login_mode": "oauth",
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("GitLab Duo authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"login_mode": "pat",
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("GitLab Duo PAT authentication successful!")
|
||||||
|
}
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DoQwenLogin handles the Qwen device flow using the shared authentication manager.
|
|
||||||
// It initiates the device-based authentication process for Qwen services and saves
|
|
||||||
// the authentication tokens to the configured auth directory.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - cfg: The application configuration
|
|
||||||
// - options: Login options including browser behavior and prompts
|
|
||||||
func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
|
||||||
if options == nil {
|
|
||||||
options = &LoginOptions{}
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := newAuthManager()
|
|
||||||
|
|
||||||
promptFn := options.Prompt
|
|
||||||
if promptFn == nil {
|
|
||||||
promptFn = func(prompt string) (string, error) {
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println(prompt)
|
|
||||||
var value string
|
|
||||||
_, err := fmt.Scanln(&value)
|
|
||||||
return value, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
|
||||||
NoBrowser: options.NoBrowser,
|
|
||||||
CallbackPort: options.CallbackPort,
|
|
||||||
Metadata: map[string]string{},
|
|
||||||
Prompt: promptFn,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
|
||||||
if err != nil {
|
|
||||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
|
||||||
log.Error(emailErr.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Printf("Qwen authentication failed: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if savedPath != "" {
|
|
||||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Qwen authentication successful!")
|
|
||||||
}
|
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
||||||
// it as a "vertex" provider credential. The file content is embedded in the auth
|
// it as a "vertex" provider credential. The file content is embedded in the auth
|
||||||
// file to allow portable deployment across stores.
|
// file to allow portable deployment across stores.
|
||||||
func DoVertexImport(cfg *config.Config, keyPath string) {
|
func DoVertexImport(cfg *config.Config, keyPath string, prefix string) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = &config.Config{}
|
cfg = &config.Config{}
|
||||||
}
|
}
|
||||||
@@ -62,13 +62,28 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
// Default location if not provided by user. Can be edited in the saved file later.
|
// Default location if not provided by user. Can be edited in the saved file later.
|
||||||
location := "us-central1"
|
location := "us-central1"
|
||||||
|
|
||||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
|
// Normalize and validate prefix: must be a single segment (no "/" allowed).
|
||||||
|
prefix = strings.TrimSpace(prefix)
|
||||||
|
prefix = strings.Trim(prefix, "/")
|
||||||
|
if prefix != "" && strings.Contains(prefix, "/") {
|
||||||
|
log.Errorf("vertex-import: prefix must be a single segment (no '/' allowed): %q", prefix)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include prefix in filename so importing the same project with different
|
||||||
|
// prefixes creates separate credential files instead of overwriting.
|
||||||
|
baseName := sanitizeFilePart(projectID)
|
||||||
|
if prefix != "" {
|
||||||
|
baseName = sanitizeFilePart(prefix) + "-" + baseName
|
||||||
|
}
|
||||||
|
fileName := fmt.Sprintf("vertex-%s.json", baseName)
|
||||||
// Build auth record
|
// Build auth record
|
||||||
storage := &vertex.VertexCredentialStorage{
|
storage := &vertex.VertexCredentialStorage{
|
||||||
ServiceAccount: sa,
|
ServiceAccount: sa,
|
||||||
ProjectID: projectID,
|
ProjectID: projectID,
|
||||||
Email: email,
|
Email: email,
|
||||||
Location: location,
|
Location: location,
|
||||||
|
Prefix: prefix,
|
||||||
}
|
}
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"service_account": sa,
|
"service_account": sa,
|
||||||
@@ -76,6 +91,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
"email": email,
|
"email": email,
|
||||||
"location": location,
|
"location": location,
|
||||||
"type": "vertex",
|
"type": "vertex",
|
||||||
|
"prefix": prefix,
|
||||||
"label": labelForVertex(projectID, email),
|
"label": labelForVertex(projectID, email),
|
||||||
}
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
|
|||||||
55
internal/config/claude_header_defaults_test.go
Normal file
55
internal/config/claude_header_defaults_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
configYAML := []byte(`
|
||||||
|
claude-header-defaults:
|
||||||
|
user-agent: " claude-cli/2.1.70 (external, cli) "
|
||||||
|
package-version: " 0.80.0 "
|
||||||
|
runtime-version: " v24.5.0 "
|
||||||
|
os: " MacOS "
|
||||||
|
arch: " arm64 "
|
||||||
|
timeout: " 900 "
|
||||||
|
stabilize-device-profile: false
|
||||||
|
`)
|
||||||
|
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := LoadConfigOptional(configPath, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
|
||||||
|
t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
|
||||||
|
t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
|
||||||
|
t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
|
||||||
|
t.Fatalf("OS = %q, want %q", got, "MacOS")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
|
||||||
|
t.Fatalf("Arch = %q, want %q", got, "arm64")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
|
||||||
|
t.Fatalf("Timeout = %q, want %q", got, "900")
|
||||||
|
}
|
||||||
|
if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||||
|
t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
|
||||||
|
}
|
||||||
|
if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
|
||||||
|
t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
configYAML := []byte(`
|
||||||
|
codex-header-defaults:
|
||||||
|
user-agent: " my-codex-client/1.0 "
|
||||||
|
beta-features: " feature-a,feature-b "
|
||||||
|
`)
|
||||||
|
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := LoadConfigOptional(configPath, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
|
||||||
|
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
|
||||||
|
}
|
||||||
|
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
|
||||||
|
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -67,6 +68,10 @@ type Config struct {
|
|||||||
// DisableCooling disables quota cooldown scheduling when true.
|
// DisableCooling disables quota cooldown scheduling when true.
|
||||||
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
|
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
|
||||||
|
|
||||||
|
// AuthAutoRefreshWorkers overrides the size of the core auth auto-refresh worker pool.
|
||||||
|
// When <= 0, the default worker count is used.
|
||||||
|
AuthAutoRefreshWorkers int `yaml:"auth-auto-refresh-workers" json:"auth-auto-refresh-workers"`
|
||||||
|
|
||||||
// RequestRetry defines the retry times when the request failed.
|
// RequestRetry defines the retry times when the request failed.
|
||||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||||
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||||
@@ -84,6 +89,13 @@ type Config struct {
|
|||||||
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||||
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||||
|
|
||||||
|
// AntigravitySignatureCacheEnabled controls whether signature cache validation is enabled for thinking blocks.
|
||||||
|
// When true (default), cached signatures are preferred and validated.
|
||||||
|
// When false, client signatures are used directly after normalization (bypass mode).
|
||||||
|
AntigravitySignatureCacheEnabled *bool `yaml:"antigravity-signature-cache-enabled,omitempty" json:"antigravity-signature-cache-enabled,omitempty"`
|
||||||
|
|
||||||
|
AntigravitySignatureBypassStrict *bool `yaml:"antigravity-signature-bypass-strict,omitempty" json:"antigravity-signature-bypass-strict,omitempty"`
|
||||||
|
|
||||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||||
|
|
||||||
@@ -101,6 +113,10 @@ type Config struct {
|
|||||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||||
|
|
||||||
|
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
|
||||||
|
// These are used only when the client does not send its own headers.
|
||||||
|
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
|
||||||
|
|
||||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||||
|
|
||||||
@@ -119,12 +135,12 @@ type Config struct {
|
|||||||
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
|
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
|
||||||
|
|
||||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||||
|
|
||||||
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||||
// These aliases affect both model listing and model routing for supported channels:
|
// These aliases affect both model listing and model routing for supported channels:
|
||||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
// gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||||
//
|
//
|
||||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||||
@@ -141,13 +157,27 @@ type Config struct {
|
|||||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
|
||||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
|
||||||
|
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
|
||||||
|
// profiles are enabled, OS/Arch become the pinned platform baseline, while
|
||||||
|
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
|
||||||
type ClaudeHeaderDefaults struct {
|
type ClaudeHeaderDefaults struct {
|
||||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||||
Timeout string `yaml:"timeout" json:"timeout"`
|
OS string `yaml:"os" json:"os"`
|
||||||
|
Arch string `yaml:"arch" json:"arch"`
|
||||||
|
Timeout string `yaml:"timeout" json:"timeout"`
|
||||||
|
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||||
|
// model requests for OAuth/file-backed auth when the client omits them.
|
||||||
|
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
|
||||||
|
type CodexHeaderDefaults struct {
|
||||||
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
|
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig holds HTTPS server settings.
|
// TLSConfig holds HTTPS server settings.
|
||||||
@@ -176,6 +206,9 @@ type RemoteManagement struct {
|
|||||||
SecretKey string `yaml:"secret-key"`
|
SecretKey string `yaml:"secret-key"`
|
||||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||||
|
// DisableAutoUpdatePanel disables automatic periodic background updates of the management panel asset from GitHub.
|
||||||
|
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
|
||||||
|
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
|
||||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||||
@@ -189,6 +222,10 @@ type QuotaExceeded struct {
|
|||||||
|
|
||||||
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||||
|
|
||||||
|
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
|
||||||
|
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
|
||||||
|
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingConfig configures how credentials are selected for requests.
|
// RoutingConfig configures how credentials are selected for requests.
|
||||||
@@ -196,6 +233,22 @@ type RoutingConfig struct {
|
|||||||
// Strategy selects the credential selection strategy.
|
// Strategy selects the credential selection strategy.
|
||||||
// Supported values: "round-robin" (default), "fill-first".
|
// Supported values: "round-robin" (default), "fill-first".
|
||||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||||
|
|
||||||
|
// ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients.
|
||||||
|
// When enabled, requests with the same session ID (extracted from metadata.user_id)
|
||||||
|
// are routed to the same auth credential when available.
|
||||||
|
// Deprecated: Use SessionAffinity instead for universal session support.
|
||||||
|
ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"`
|
||||||
|
|
||||||
|
// SessionAffinity enables universal session-sticky routing for all clients.
|
||||||
|
// Session IDs are extracted from multiple sources:
|
||||||
|
// X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash.
|
||||||
|
// Automatic failover is always enabled when bound auth becomes unavailable.
|
||||||
|
SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"`
|
||||||
|
|
||||||
|
// SessionAffinityTTL specifies how long session-to-auth bindings are retained.
|
||||||
|
// Default: 1h. Accepts duration strings like "30m", "1h", "2h30m".
|
||||||
|
SessionAffinityTTL string `yaml:"session-affinity-ttl,omitempty" json:"session-affinity-ttl,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthModelAlias defines a model ID alias for a specific channel.
|
// OAuthModelAlias defines a model ID alias for a specific channel.
|
||||||
@@ -235,8 +288,8 @@ type AmpCode struct {
|
|||||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
// is used for the upstream Amp request.
|
||||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||||
|
|
||||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||||
@@ -358,6 +411,11 @@ type ClaudeKey struct {
|
|||||||
|
|
||||||
// Cloak configures request cloaking for non-Claude-Code clients.
|
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||||
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||||
|
|
||||||
|
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
|
||||||
|
// Claude /v1/messages requests. It is disabled by default so upstream seed
|
||||||
|
// changes do not alter the proxy's legacy behavior.
|
||||||
|
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||||
@@ -556,6 +614,10 @@ type OpenAICompatibilityModel struct {
|
|||||||
|
|
||||||
// Alias is the model name alias that clients will use to reference this model.
|
// Alias is the model name alias that clients will use to reference this model.
|
||||||
Alias string `yaml:"alias" json:"alias"`
|
Alias string `yaml:"alias" json:"alias"`
|
||||||
|
|
||||||
|
// Thinking configures the thinking/reasoning capability for this model.
|
||||||
|
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
|
||||||
|
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||||
@@ -673,12 +735,18 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||||
cfg.SanitizeGeminiKeys()
|
cfg.SanitizeGeminiKeys()
|
||||||
|
|
||||||
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
// Sanitize Vertex-compatible API keys.
|
||||||
cfg.SanitizeVertexCompatKeys()
|
cfg.SanitizeVertexCompatKeys()
|
||||||
|
|
||||||
// Sanitize Codex keys: drop entries without base-url
|
// Sanitize Codex keys: drop entries without base-url
|
||||||
cfg.SanitizeCodexKeys()
|
cfg.SanitizeCodexKeys()
|
||||||
|
|
||||||
|
// Sanitize Codex header defaults.
|
||||||
|
cfg.SanitizeCodexHeaderDefaults()
|
||||||
|
|
||||||
|
// Sanitize Claude header defaults.
|
||||||
|
cfg.SanitizeClaudeHeaderDefaults()
|
||||||
|
|
||||||
// Sanitize Claude key headers
|
// Sanitize Claude key headers
|
||||||
cfg.SanitizeClaudeKeys()
|
cfg.SanitizeClaudeKeys()
|
||||||
|
|
||||||
@@ -771,6 +839,30 @@ func payloadRawString(value any) ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
|
||||||
|
// configured Codex header fallback values.
|
||||||
|
func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
|
||||||
|
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
|
||||||
|
// configured Claude fingerprint baseline values.
|
||||||
|
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
|
||||||
|
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
|
||||||
|
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
|
||||||
|
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
|
||||||
|
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
|
||||||
|
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||||
@@ -916,6 +1008,7 @@ func (cfg *Config) SanitizeKiroKeys() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
||||||
|
// It uses API key + base URL as the uniqueness key.
|
||||||
func (cfg *Config) SanitizeGeminiKeys() {
|
func (cfg *Config) SanitizeGeminiKeys() {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return
|
return
|
||||||
@@ -934,10 +1027,11 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
|||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
if _, exists := seen[entry.APIKey]; exists {
|
uniqueKey := entry.APIKey + "|" + entry.BaseURL
|
||||||
|
if _, exists := seen[uniqueKey]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seen[entry.APIKey] = struct{}{}
|
seen[uniqueKey] = struct{}{}
|
||||||
out = append(out, entry)
|
out = append(out, entry)
|
||||||
}
|
}
|
||||||
cfg.GeminiKey = out
|
cfg.GeminiKey = out
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
// defaultKiroAliases returns default oauth-model-alias entries for Kiro.
|
// defaultKiroAliases returns default oauth-model-alias entries for Kiro.
|
||||||
// These aliases expose standard Claude IDs for Kiro-prefixed upstream models.
|
// These aliases expose standard Claude IDs for Kiro-prefixed upstream models.
|
||||||
func defaultKiroAliases() []OAuthModelAlias {
|
func defaultKiroAliases() []OAuthModelAlias {
|
||||||
@@ -35,3 +37,25 @@ func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
|||||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GitHubCopilotAliasesFromModels generates oauth-model-alias entries from a dynamic
|
||||||
|
// list of model IDs fetched from the Copilot API. It auto-creates aliases for
|
||||||
|
// models whose ID contains a dot (e.g. "claude-opus-4.6" → "claude-opus-4-6"),
|
||||||
|
// which is the pattern used by Claude models on Copilot.
|
||||||
|
func GitHubCopilotAliasesFromModels(modelIDs []string) []OAuthModelAlias {
|
||||||
|
var aliases []OAuthModelAlias
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
for _, id := range modelIDs {
|
||||||
|
if !strings.Contains(id, ".") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hyphenID := strings.ReplaceAll(id, ".", "-")
|
||||||
|
key := id + "→" + hyphenID
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
aliases = append(aliases, OAuthModelAlias{Name: id, Alias: hyphenID, Fork: true})
|
||||||
|
}
|
||||||
|
return aliases
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ type SDKConfig struct {
|
|||||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||||
|
|
||||||
|
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
||||||
|
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
||||||
|
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
|
||||||
|
|
||||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||||
// credentials as well.
|
// credentials as well.
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ type VertexCompatKey struct {
|
|||||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
// BaseURL optionally overrides the Vertex-compatible API endpoint.
|
||||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
// When empty, requests fall back to the default Vertex API base URL.
|
||||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||||
|
|
||||||
// ProxyURL optionally overrides the global proxy for this API key.
|
// ProxyURL optionally overrides the global proxy for this API key.
|
||||||
@@ -71,10 +71,6 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
|||||||
}
|
}
|
||||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
if entry.BaseURL == "" {
|
|
||||||
// BaseURL is required for Vertex API key entries
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/flate"
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
@@ -41,15 +42,17 @@ type RequestLogger interface {
|
|||||||
// - statusCode: The response status code
|
// - statusCode: The response status code
|
||||||
// - responseHeaders: The response headers
|
// - responseHeaders: The response headers
|
||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
|
// - websocketTimeline: Optional downstream websocket event timeline
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
|
// - apiWebsocketTimeline: Optional upstream websocket event timeline
|
||||||
// - requestID: Optional request ID for log file naming
|
// - requestID: Optional request ID for log file naming
|
||||||
// - requestTimestamp: When the request was received
|
// - requestTimestamp: When the request was received
|
||||||
// - apiResponseTimestamp: When the API response was received
|
// - apiResponseTimestamp: When the API response was received
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||||
|
|
||||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||||
//
|
//
|
||||||
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
|
|||||||
// - error: An error if writing fails, nil otherwise
|
// - error: An error if writing fails, nil otherwise
|
||||||
WriteAPIResponse(apiResponse []byte) error
|
WriteAPIResponse(apiResponse []byte) error
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
|
||||||
|
// This should be called when upstream communication happened over websocket.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if writing fails, nil otherwise
|
||||||
|
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
|
||||||
|
|
||||||
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
if !l.enabled && !force {
|
if !l.enabled && !force {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
|||||||
requestHeaders,
|
requestHeaders,
|
||||||
body,
|
body,
|
||||||
requestBodyPath,
|
requestBodyPath,
|
||||||
|
websocketTimeline,
|
||||||
apiRequest,
|
apiRequest,
|
||||||
apiResponse,
|
apiResponse,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
statusCode,
|
statusCode,
|
||||||
responseHeaders,
|
responseHeaders,
|
||||||
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
requestHeaders map[string][]string,
|
requestHeaders map[string][]string,
|
||||||
requestBody []byte,
|
requestBody []byte,
|
||||||
requestBodyPath string,
|
requestBodyPath string,
|
||||||
|
websocketTimeline []byte,
|
||||||
apiRequest []byte,
|
apiRequest []byte,
|
||||||
apiResponse []byte,
|
apiResponse []byte,
|
||||||
|
apiWebsocketTimeline []byte,
|
||||||
apiResponseErrors []*interfaces.ErrorMessage,
|
apiResponseErrors []*interfaces.ErrorMessage,
|
||||||
statusCode int,
|
statusCode int,
|
||||||
responseHeaders map[string][]string,
|
responseHeaders map[string][]string,
|
||||||
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
if requestTimestamp.IsZero() {
|
if requestTimestamp.IsZero() {
|
||||||
requestTimestamp = time.Now()
|
requestTimestamp = time.Now()
|
||||||
}
|
}
|
||||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||||
|
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
|
||||||
|
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||||
|
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||||
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if isWebsocketTranscript {
|
||||||
|
// Intentionally omit the generic downstream HTTP response section for websocket
|
||||||
|
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
|
||||||
|
// and appending a one-off upgrade response snapshot would dilute that transcript.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
|
|||||||
body []byte,
|
body []byte,
|
||||||
bodyPath string,
|
bodyPath string,
|
||||||
timestamp time.Time,
|
timestamp time.Time,
|
||||||
|
downstreamTransport string,
|
||||||
|
upstreamTransport string,
|
||||||
|
includeBody bool,
|
||||||
) error {
|
) error {
|
||||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
|
|||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(downstreamTransport) != "" {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upstreamTransport) != "" {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !includeBody {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyTrailingNewlines := 1
|
||||||
if bodyPath != "" {
|
if bodyPath != "" {
|
||||||
bodyFile, errOpen := os.Open(bodyPath)
|
bodyFile, errOpen := os.Open(bodyPath)
|
||||||
if errOpen != nil {
|
if errOpen != nil {
|
||||||
return errOpen
|
return errOpen
|
||||||
}
|
}
|
||||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
tracker := &trailingNewlineTrackingWriter{writer: w}
|
||||||
|
written, errCopy := io.Copy(tracker, bodyFile)
|
||||||
|
if errCopy != nil {
|
||||||
_ = bodyFile.Close()
|
_ = bodyFile.Close()
|
||||||
return errCopy
|
return errCopy
|
||||||
}
|
}
|
||||||
|
if written > 0 {
|
||||||
|
bodyTrailingNewlines = tracker.trailingNewlines
|
||||||
|
}
|
||||||
if errClose := bodyFile.Close(); errClose != nil {
|
if errClose := bodyFile.Close(); errClose != nil {
|
||||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||||
}
|
}
|
||||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
|
} else if len(body) > 0 {
|
||||||
|
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
|
||||||
}
|
}
|
||||||
|
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
|
||||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func countTrailingNewlinesBytes(payload []byte) int {
|
||||||
|
count := 0
|
||||||
|
for i := len(payload) - 1; i >= 0; i-- {
|
||||||
|
if payload[i] != '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
|
||||||
|
missingNewlines := 3 - trailingNewlines
|
||||||
|
if missingNewlines <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
type trailingNewlineTrackingWriter struct {
|
||||||
|
writer io.Writer
|
||||||
|
trailingNewlines int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
|
||||||
|
written, errWrite := t.writer.Write(payload)
|
||||||
|
if written > 0 {
|
||||||
|
writtenPayload := payload[:written]
|
||||||
|
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
|
||||||
|
if trailingNewlines == len(writtenPayload) {
|
||||||
|
t.trailingNewlines += trailingNewlines
|
||||||
|
} else {
|
||||||
|
t.trailingNewlines = trailingNewlines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return written, errWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSectionPayload(payload []byte) bool {
|
||||||
|
return len(bytes.TrimSpace(payload)) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
|
||||||
|
if hasSectionPayload(websocketTimeline) {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
for key, values := range headers {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
|
||||||
|
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
|
||||||
|
hasWS := hasSectionPayload(apiWebsocketTimeline)
|
||||||
|
switch {
|
||||||
|
case hasHTTP && hasWS:
|
||||||
|
return "websocket+http"
|
||||||
|
case hasWS:
|
||||||
|
return "websocket"
|
||||||
|
case hasHTTP:
|
||||||
|
return "http"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
|
||||||
return errWrite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
|
||||||
return errWrite
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
|
|||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
trailingNewlines := 1
|
||||||
if apiResponseErrors[i].Error != nil {
|
if apiResponseErrors[i].Error != nil {
|
||||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
errText := apiResponseErrors[i].Error.Error()
|
||||||
|
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if errText != "" {
|
||||||
|
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
var bufferedReader *bufio.Reader
|
||||||
return errWrite
|
if responseReader != nil {
|
||||||
|
bufferedReader = bufio.NewReader(responseReader)
|
||||||
|
}
|
||||||
|
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
|
||||||
|
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseReader != nil {
|
if bufferedReader != nil {
|
||||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
|
||||||
return errCopy
|
return errCopy
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
|
||||||
|
if reader == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// formatLogContent creates the complete log content for non-streaming requests.
|
// formatLogContent creates the complete log content for non-streaming requests.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
// - method: The HTTP method
|
// - method: The HTTP method
|
||||||
// - headers: The request headers
|
// - headers: The request headers
|
||||||
// - body: The request body
|
// - body: The request body
|
||||||
|
// - websocketTimeline: The downstream websocket event timeline
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: The formatted log content
|
// - string: The formatted log content
|
||||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||||
|
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
|
||||||
|
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||||
|
|
||||||
// Request info
|
// Request info
|
||||||
content.WriteString(l.formatRequestInfo(url, method, headers, body))
|
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
|
||||||
|
|
||||||
|
if len(websocketTimeline) > 0 {
|
||||||
|
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
|
||||||
|
content.Write(websocketTimeline)
|
||||||
|
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
|
||||||
|
content.Write(websocketTimeline)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(apiWebsocketTimeline) > 0 {
|
||||||
|
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
|
||||||
|
content.Write(apiWebsocketTimeline)
|
||||||
|
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
|
||||||
|
content.Write(apiWebsocketTimeline)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
if len(apiRequest) > 0 {
|
if len(apiRequest) > 0 {
|
||||||
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
||||||
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
|
|||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isWebsocketTranscript {
|
||||||
|
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
|
||||||
|
// timeline sections instead of a generic downstream HTTP response block.
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
// Response section
|
// Response section
|
||||||
content.WriteString("=== RESPONSE ===\n")
|
content.WriteString("=== RESPONSE ===\n")
|
||||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||||
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: The formatted request information
|
// - string: The formatted request information
|
||||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
|
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
|
||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
|
||||||
content.WriteString("=== REQUEST INFO ===\n")
|
content.WriteString("=== REQUEST INFO ===\n")
|
||||||
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
||||||
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||||
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||||
|
if strings.TrimSpace(downstreamTransport) != "" {
|
||||||
|
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upstreamTransport) != "" {
|
||||||
|
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
|
||||||
|
}
|
||||||
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
|
|
||||||
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
|||||||
}
|
}
|
||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
if !includeBody {
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
content.WriteString("=== REQUEST BODY ===\n")
|
content.WriteString("=== REQUEST BODY ===\n")
|
||||||
content.Write(body)
|
content.Write(body)
|
||||||
content.WriteString("\n\n")
|
content.WriteString("\n\n")
|
||||||
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
|
|||||||
// apiResponse stores the upstream API response data.
|
// apiResponse stores the upstream API response data.
|
||||||
apiResponse []byte
|
apiResponse []byte
|
||||||
|
|
||||||
|
// apiWebsocketTimeline stores the upstream websocket event timeline.
|
||||||
|
apiWebsocketTimeline []byte
|
||||||
|
|
||||||
// apiResponseTimestamp captures when the API response was received.
|
// apiResponseTimestamp captures when the API response was received.
|
||||||
apiResponseTimestamp time.Time
|
apiResponseTimestamp time.Time
|
||||||
}
|
}
|
||||||
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil (buffering cannot fail)
|
||||||
|
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||||
|
if len(apiWebsocketTimeline) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||||
if !timestamp.IsZero() {
|
if !timestamp.IsZero() {
|
||||||
w.apiResponseTimestamp = timestamp
|
w.apiResponseTimestamp = timestamp
|
||||||
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
|||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
// It writes all buffered data to the file in the correct order:
|
// It writes all buffered data to the file in the correct order:
|
||||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if closing fails, nil otherwise
|
// - error: An error if closing fails, nil otherwise
|
||||||
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||||
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil
|
||||||
|
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||||
|
|
||||||
// Close is a no-op implementation that does nothing and always returns nil.
|
// Close is a no-op implementation that does nothing and always returns nil.
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||||
managementSyncMinInterval = 30 * time.Second
|
managementSyncMinInterval = 30 * time.Second
|
||||||
updateCheckInterval = 3 * time.Hour
|
updateCheckInterval = 3 * time.Hour
|
||||||
|
maxAssetDownloadSize = 50 << 20 // 10 MB safety limit for management asset downloads
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManagementFileName exposes the control panel asset filename.
|
// ManagementFileName exposes the control panel asset filename.
|
||||||
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
|
|||||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cfg.RemoteManagement.DisableAutoUpdatePanel {
|
||||||
|
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configPath, _ := schedulerConfigPath.Load().(string)
|
configPath, _ := schedulerConfigPath.Load().(string)
|
||||||
staticDir := StaticDir(configPath)
|
staticDir := StaticDir(configPath)
|
||||||
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
|||||||
}
|
}
|
||||||
|
|
||||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
|
||||||
|
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||||
return false
|
return false
|
||||||
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
|
|||||||
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("read download body: %w", err)
|
return nil, "", fmt.Errorf("read download body: %w", err)
|
||||||
}
|
}
|
||||||
|
if int64(len(data)) > maxAssetDownloadSize {
|
||||||
|
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
|
||||||
|
}
|
||||||
|
|
||||||
sum := sha256.Sum256(data)
|
sum := sha256.Sum256(data)
|
||||||
return data, hex.EncodeToString(sum[:]), nil
|
return data, hex.EncodeToString(sum[:]), nil
|
||||||
|
|||||||
151
internal/misc/antigravity_version.go
Normal file
151
internal/misc/antigravity_version.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
|
||||||
|
package misc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases"
|
||||||
|
antigravityFallbackVersion = "1.21.9"
|
||||||
|
antigravityVersionCacheTTL = 6 * time.Hour
|
||||||
|
antigravityFetchTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type antigravityRelease struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
ExecutionID string `json:"execution_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
cachedAntigravityVersion = antigravityFallbackVersion
|
||||||
|
antigravityVersionMu sync.RWMutex
|
||||||
|
antigravityVersionExpiry time.Time
|
||||||
|
antigravityUpdaterOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version.
|
||||||
|
// This is intentionally decoupled from request execution to avoid blocking executors on version lookups.
|
||||||
|
func StartAntigravityVersionUpdater(ctx context.Context) {
|
||||||
|
antigravityUpdaterOnce.Do(func() {
|
||||||
|
go runAntigravityVersionUpdater(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runAntigravityVersionUpdater(ctx context.Context) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(antigravityVersionCacheTTL / 2)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
log.Infof("periodic antigravity version refresh started (interval=%s)", antigravityVersionCacheTTL/2)
|
||||||
|
|
||||||
|
refreshAntigravityVersion(ctx)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
refreshAntigravityVersion(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshAntigravityVersion(ctx context.Context) {
|
||||||
|
version, errFetch := fetchAntigravityLatestVersion(ctx)
|
||||||
|
|
||||||
|
antigravityVersionMu.Lock()
|
||||||
|
defer antigravityVersionMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if errFetch == nil {
|
||||||
|
cachedAntigravityVersion = version
|
||||||
|
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||||
|
log.WithField("version", version).Info("fetched latest antigravity version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry) {
|
||||||
|
cachedAntigravityVersion = antigravityFallbackVersion
|
||||||
|
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||||
|
log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater.
|
||||||
|
// It falls back to antigravityFallbackVersion if the cache is empty or stale.
|
||||||
|
func AntigravityLatestVersion() string {
|
||||||
|
antigravityVersionMu.RLock()
|
||||||
|
if cachedAntigravityVersion != "" && time.Now().Before(antigravityVersionExpiry) {
|
||||||
|
v := cachedAntigravityVersion
|
||||||
|
antigravityVersionMu.RUnlock()
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
antigravityVersionMu.RUnlock()
|
||||||
|
|
||||||
|
return antigravityFallbackVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityUserAgent returns the User-Agent string for antigravity requests
|
||||||
|
// using the latest version fetched from the releases API.
|
||||||
|
func AntigravityUserAgent() string {
|
||||||
|
return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion())
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchAntigravityLatestVersion(ctx context.Context) (string, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: antigravityFetchTimeout}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil)
|
||||||
|
if errReq != nil {
|
||||||
|
return "", fmt.Errorf("build antigravity releases request: %w", errReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, errDo := client.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("fetch antigravity releases: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.WithError(errClose).Warn("antigravity releases response body close error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var releases []antigravityRelease
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode antigravity releases response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(releases) == 0 {
|
||||||
|
return "", errors.New("antigravity releases API returned empty list")
|
||||||
|
}
|
||||||
|
|
||||||
|
version := releases[0].Version
|
||||||
|
if version == "" {
|
||||||
|
return "", errors.New("antigravity releases API returned empty version")
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
@@ -30,6 +30,23 @@ type OAuthCallback struct {
|
|||||||
ErrorDescription string
|
ErrorDescription string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
|
||||||
|
// the result. The returned channels are buffered (size 1) so the goroutine can
|
||||||
|
// complete even if the caller abandons the channels.
|
||||||
|
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
|
||||||
|
inputCh := make(chan string, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
input, err := promptFn(message)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inputCh <- input
|
||||||
|
}()
|
||||||
|
return inputCh, errCh
|
||||||
|
}
|
||||||
|
|
||||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||||
// It returns nil when the input is empty.
|
// It returns nil when the input is empty.
|
||||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||||
|
|||||||
@@ -1,12 +1,222 @@
|
|||||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||||
// Static model metadata is stored in model_definitions_static_data.go.
|
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||||
|
type staticModelsJSON struct {
|
||||||
|
Claude []*ModelInfo `json:"claude"`
|
||||||
|
Gemini []*ModelInfo `json:"gemini"`
|
||||||
|
Vertex []*ModelInfo `json:"vertex"`
|
||||||
|
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||||
|
AIStudio []*ModelInfo `json:"aistudio"`
|
||||||
|
CodexFree []*ModelInfo `json:"codex-free"`
|
||||||
|
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||||
|
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||||
|
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||||
|
Kimi []*ModelInfo `json:"kimi"`
|
||||||
|
Antigravity []*ModelInfo `json:"antigravity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaudeModels returns the standard Claude model definitions.
|
||||||
|
func GetClaudeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Claude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiModels returns the standard Gemini model definitions.
|
||||||
|
func GetGeminiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||||
|
func GetGeminiVertexModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Vertex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||||
|
func GetGeminiCLIModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().GeminiCLI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIStudioModels returns model definitions for AI Studio.
|
||||||
|
func GetAIStudioModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().AIStudio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||||
|
func GetCodexFreeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexFree)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||||
|
func GetCodexTeamModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexTeam)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||||
|
func GetCodexPlusModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPlus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||||
|
func GetCodexProModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPro)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||||
|
func GetKimiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Kimi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModels returns the standard Antigravity model definitions.
|
||||||
|
func GetAntigravityModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Antigravity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
|
||||||
|
// These models are served through the copilot.tencent.com API.
|
||||||
|
func GetCodeBuddyModels() []*ModelInfo {
|
||||||
|
now := int64(1748044800) // 2025-05-24
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "auto",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Auto",
|
||||||
|
Description: "Automatic model selection via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5v-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5v Turbo",
|
||||||
|
Description: "GLM-5v Turbo via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.1",
|
||||||
|
Description: "GLM-5.1 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.0-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.0 Turbo",
|
||||||
|
Description: "GLM-5.0 Turbo via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.0",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.0",
|
||||||
|
Description: "GLM-5.0 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-4.7",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-4.7",
|
||||||
|
Description: "GLM-4.7 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "minimax-m2.7",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "MiniMax M2.7",
|
||||||
|
Description: "MiniMax M2.7 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Kimi K2.5",
|
||||||
|
Description: "Kimi K2.5 via CodeBuddy",
|
||||||
|
ContextLength: 256000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2-thinking",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Kimi K2 Thinking",
|
||||||
|
Description: "Kimi K2 Thinking via CodeBuddy",
|
||||||
|
ContextLength: 256000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "deepseek-v3-2-volc",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||||
|
Description: "DeepSeek V3.2 via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||||
|
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*ModelInfo, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
out[i] = cloneModelInfo(m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||||
// It returns nil when the channel is unknown.
|
// It returns nil when the channel is unknown.
|
||||||
//
|
//
|
||||||
@@ -17,13 +227,9 @@ import (
|
|||||||
// - gemini-cli
|
// - gemini-cli
|
||||||
// - aistudio
|
// - aistudio
|
||||||
// - codex
|
// - codex
|
||||||
// - qwen
|
|
||||||
// - iflow
|
|
||||||
// - kimi
|
// - kimi
|
||||||
// - kiro
|
|
||||||
// - kilo
|
// - kilo
|
||||||
// - github-copilot
|
// - github-copilot
|
||||||
// - kiro
|
|
||||||
// - amazonq
|
// - amazonq
|
||||||
// - antigravity (returns static overrides only)
|
// - antigravity (returns static overrides only)
|
||||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||||
@@ -40,11 +246,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "aistudio":
|
case "aistudio":
|
||||||
return GetAIStudioModels()
|
return GetAIStudioModels()
|
||||||
case "codex":
|
case "codex":
|
||||||
return GetOpenAIModels()
|
return GetCodexProModels()
|
||||||
case "qwen":
|
|
||||||
return GetQwenModels()
|
|
||||||
case "iflow":
|
|
||||||
return GetIFlowModels()
|
|
||||||
case "kimi":
|
case "kimi":
|
||||||
return GetKimiModels()
|
return GetKimiModels()
|
||||||
case "github-copilot":
|
case "github-copilot":
|
||||||
@@ -56,33 +258,28 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "amazonq":
|
case "amazonq":
|
||||||
return GetAmazonQModels()
|
return GetAmazonQModels()
|
||||||
case "antigravity":
|
case "antigravity":
|
||||||
cfg := GetAntigravityModelConfig()
|
return GetAntigravityModels()
|
||||||
if len(cfg) == 0 {
|
case "codebuddy":
|
||||||
return nil
|
return GetCodeBuddyModels()
|
||||||
}
|
case "cursor":
|
||||||
models := make([]*ModelInfo, 0, len(cfg))
|
return GetCursorModels()
|
||||||
for modelID, entry := range cfg {
|
|
||||||
if modelID == "" || entry == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
models = append(models, &ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Object: "model",
|
|
||||||
OwnedBy: "antigravity",
|
|
||||||
Type: "antigravity",
|
|
||||||
Thinking: entry.Thinking,
|
|
||||||
MaxCompletionTokens: entry.MaxCompletionTokens,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
sort.Slice(models, func(i, j int) bool {
|
|
||||||
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
|
|
||||||
})
|
|
||||||
return models
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCursorModels returns the fallback Cursor model definitions.
|
||||||
|
func GetCursorModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
|
||||||
|
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
|
||||||
|
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
|
||||||
|
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||||
// Returns nil if no matching model is found.
|
// Returns nil if no matching model is found.
|
||||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||||
@@ -90,45 +287,46 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data := getModels()
|
||||||
allModels := [][]*ModelInfo{
|
allModels := [][]*ModelInfo{
|
||||||
GetClaudeModels(),
|
data.Claude,
|
||||||
GetGeminiModels(),
|
data.Gemini,
|
||||||
GetGeminiVertexModels(),
|
data.Vertex,
|
||||||
GetGeminiCLIModels(),
|
data.GeminiCLI,
|
||||||
GetAIStudioModels(),
|
data.AIStudio,
|
||||||
GetOpenAIModels(),
|
data.CodexPro,
|
||||||
GetQwenModels(),
|
data.Kimi,
|
||||||
GetIFlowModels(),
|
data.Antigravity,
|
||||||
GetKimiModels(),
|
|
||||||
GetGitHubCopilotModels(),
|
GetGitHubCopilotModels(),
|
||||||
GetKiroModels(),
|
GetKiroModels(),
|
||||||
GetKiloModels(),
|
GetKiloModels(),
|
||||||
GetAmazonQModels(),
|
GetAmazonQModels(),
|
||||||
|
GetCodeBuddyModels(),
|
||||||
|
GetCursorModels(),
|
||||||
}
|
}
|
||||||
for _, models := range allModels {
|
for _, models := range allModels {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if m != nil && m.ID == modelID {
|
if m != nil && m.ID == modelID {
|
||||||
return m
|
return cloneModelInfo(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
|
||||||
return &ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Thinking: cfg.Thinking,
|
|
||||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultCopilotClaudeContextLength is the conservative prompt token limit for
|
||||||
|
// Claude models accessed via the GitHub Copilot API. Individual accounts are
|
||||||
|
// capped at 128K; business accounts at 168K. When the dynamic /models API fetch
|
||||||
|
// succeeds, the real per-account limit overrides this value. This constant is
|
||||||
|
// only used as a safe fallback.
|
||||||
|
const defaultCopilotClaudeContextLength = 128000
|
||||||
|
|
||||||
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
func GetGitHubCopilotModels() []*ModelInfo {
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
now := int64(1732752000) // 2024-11-27
|
now := int64(1732752000) // 2024-11-27
|
||||||
|
copilotClaudeEndpoints := []string{"/chat/completions", "/messages"}
|
||||||
gpt4oEntries := []struct {
|
gpt4oEntries := []struct {
|
||||||
ID string
|
ID string
|
||||||
DisplayName string
|
DisplayName string
|
||||||
@@ -152,6 +350,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +365,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: entry.Description,
|
Description: entry.Description,
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,6 +500,32 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
SupportedEndpoints: []string{"/responses"},
|
SupportedEndpoints: []string{"/responses"},
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.4",
|
||||||
|
Description: "OpenAI GPT-5.4 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.4 mini",
|
||||||
|
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-haiku-4.5",
|
ID: "claude-haiku-4.5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -308,9 +534,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Haiku 4.5",
|
DisplayName: "Claude Haiku 4.5",
|
||||||
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.1",
|
ID: "claude-opus-4.1",
|
||||||
@@ -320,9 +546,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.1",
|
DisplayName: "Claude Opus 4.1",
|
||||||
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 32000,
|
MaxCompletionTokens: 32000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.5",
|
ID: "claude-opus-4.5",
|
||||||
@@ -332,9 +558,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.5",
|
DisplayName: "Claude Opus 4.5",
|
||||||
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.6",
|
ID: "claude-opus-4.6",
|
||||||
@@ -344,9 +571,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.6",
|
DisplayName: "Claude Opus 4.6",
|
||||||
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4",
|
ID: "claude-sonnet-4",
|
||||||
@@ -356,9 +584,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4",
|
DisplayName: "Claude Sonnet 4",
|
||||||
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4.5",
|
ID: "claude-sonnet-4.5",
|
||||||
@@ -368,9 +597,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4.5",
|
DisplayName: "Claude Sonnet 4.5",
|
||||||
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4.6",
|
ID: "claude-sonnet-4.6",
|
||||||
@@ -380,9 +610,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4.6",
|
DisplayName: "Claude Sonnet 4.6",
|
||||||
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-2.5-pro",
|
ID: "gemini-2.5-pro",
|
||||||
@@ -394,6 +625,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-pro-preview",
|
ID: "gemini-3-pro-preview",
|
||||||
@@ -405,6 +637,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3.1-pro-preview",
|
ID: "gemini-3.1-pro-preview",
|
||||||
@@ -414,8 +647,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Gemini 3.1 Pro (Preview)",
|
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||||
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 173000,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
@@ -425,8 +659,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Gemini 3 Flash (Preview)",
|
DisplayName: "Gemini 3 Flash (Preview)",
|
||||||
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 173000,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "grok-code-fast-1",
|
ID: "grok-code-fast-1",
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
70
internal/registry/model_definitions_test.go
Normal file
70
internal/registry/model_definitions_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGitHubCopilotGeminiModelsAreChatOnly(t *testing.T) {
|
||||||
|
models := GetGitHubCopilotModels()
|
||||||
|
required := map[string]bool{
|
||||||
|
"gemini-2.5-pro": false,
|
||||||
|
"gemini-3-pro-preview": false,
|
||||||
|
"gemini-3.1-pro-preview": false,
|
||||||
|
"gemini-3-flash-preview": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if _, ok := required[model.ID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
required[model.ID] = true
|
||||||
|
if len(model.SupportedEndpoints) != 1 || model.SupportedEndpoints[0] != "/chat/completions" {
|
||||||
|
t.Fatalf("model %q supported endpoints = %v, want [/chat/completions]", model.ID, model.SupportedEndpoints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelID, found := range required {
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected GitHub Copilot model %q in definitions", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubCopilotClaudeModelsSupportMessages(t *testing.T) {
|
||||||
|
models := GetGitHubCopilotModels()
|
||||||
|
required := map[string]bool{
|
||||||
|
"claude-haiku-4.5": false,
|
||||||
|
"claude-opus-4.1": false,
|
||||||
|
"claude-opus-4.5": false,
|
||||||
|
"claude-opus-4.6": false,
|
||||||
|
"claude-sonnet-4": false,
|
||||||
|
"claude-sonnet-4.5": false,
|
||||||
|
"claude-sonnet-4.6": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if _, ok := required[model.ID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
required[model.ID] = true
|
||||||
|
if !containsString(model.SupportedEndpoints, "/chat/completions") {
|
||||||
|
t.Fatalf("model %q supported endpoints = %v, missing /chat/completions", model.ID, model.SupportedEndpoints)
|
||||||
|
}
|
||||||
|
if !containsString(model.SupportedEndpoints, "/messages") {
|
||||||
|
t.Fatalf("model %q supported endpoints = %v, missing /messages", model.ID, model.SupportedEndpoints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelID, found := range required {
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected GitHub Copilot model %q in definitions", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsString(items []string, want string) bool {
|
||||||
|
for _, item := range items {
|
||||||
|
if item == want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -64,20 +64,25 @@ type ModelInfo struct {
|
|||||||
UserDefined bool `json:"-"`
|
UserDefined bool `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type availableModelsCacheEntry struct {
|
||||||
|
models []map[string]any
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||||
// Values are interpreted in provider-native token units.
|
// Values are interpreted in provider-native token units.
|
||||||
type ThinkingSupport struct {
|
type ThinkingSupport struct {
|
||||||
// Min is the minimum allowed thinking budget (inclusive).
|
// Min is the minimum allowed thinking budget (inclusive).
|
||||||
Min int `json:"min,omitempty"`
|
Min int `json:"min,omitempty" yaml:"min,omitempty"`
|
||||||
// Max is the maximum allowed thinking budget (inclusive).
|
// Max is the maximum allowed thinking budget (inclusive).
|
||||||
Max int `json:"max,omitempty"`
|
Max int `json:"max,omitempty" yaml:"max,omitempty"`
|
||||||
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
|
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
|
||||||
ZeroAllowed bool `json:"zero_allowed,omitempty"`
|
ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
|
||||||
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
|
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
|
||||||
DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
|
DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
|
||||||
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
|
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
|
||||||
// When set, the model uses level-based reasoning instead of token budgets.
|
// When set, the model uses level-based reasoning instead of token budgets.
|
||||||
Levels []string `json:"levels,omitempty"`
|
Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelRegistration tracks a model's availability
|
// ModelRegistration tracks a model's availability
|
||||||
@@ -118,6 +123,8 @@ type ModelRegistry struct {
|
|||||||
clientProviders map[string]string
|
clientProviders map[string]string
|
||||||
// mutex ensures thread-safe access to the registry
|
// mutex ensures thread-safe access to the registry
|
||||||
mutex *sync.RWMutex
|
mutex *sync.RWMutex
|
||||||
|
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
|
||||||
|
availableModelsCache map[string]availableModelsCacheEntry
|
||||||
// hook is an optional callback sink for model registration changes
|
// hook is an optional callback sink for model registration changes
|
||||||
hook ModelRegistryHook
|
hook ModelRegistryHook
|
||||||
}
|
}
|
||||||
@@ -130,15 +137,28 @@ var registryOnce sync.Once
|
|||||||
func GetGlobalRegistry() *ModelRegistry {
|
func GetGlobalRegistry() *ModelRegistry {
|
||||||
registryOnce.Do(func() {
|
registryOnce.Do(func() {
|
||||||
globalRegistry = &ModelRegistry{
|
globalRegistry = &ModelRegistry{
|
||||||
models: make(map[string]*ModelRegistration),
|
models: make(map[string]*ModelRegistration),
|
||||||
clientModels: make(map[string][]string),
|
clientModels: make(map[string][]string),
|
||||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||||
clientProviders: make(map[string]string),
|
clientProviders: make(map[string]string),
|
||||||
mutex: &sync.RWMutex{},
|
availableModelsCache: make(map[string]availableModelsCacheEntry),
|
||||||
|
mutex: &sync.RWMutex{},
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return globalRegistry
|
return globalRegistry
|
||||||
}
|
}
|
||||||
|
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
|
||||||
|
if r.availableModelsCache == nil {
|
||||||
|
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
|
||||||
|
if len(r.availableModelsCache) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clear(r.availableModelsCache)
|
||||||
|
}
|
||||||
|
|
||||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||||
@@ -153,9 +173,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||||
return info
|
return cloneModelInfo(info)
|
||||||
}
|
}
|
||||||
return LookupStaticModelInfo(modelID)
|
return cloneModelInfo(LookupStaticModelInfo(modelID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetHook sets an optional hook for observing model registration changes.
|
// SetHook sets an optional hook for observing model registration changes.
|
||||||
@@ -169,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const defaultModelRegistryHookTimeout = 5 * time.Second
|
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||||
|
const modelQuotaExceededWindow = 5 * time.Minute
|
||||||
|
|
||||||
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||||
hook := r.hook
|
hook := r.hook
|
||||||
@@ -213,6 +234,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
|||||||
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
provider := strings.ToLower(clientProvider)
|
provider := strings.ToLower(clientProvider)
|
||||||
uniqueModelIDs := make([]string, 0, len(models))
|
uniqueModelIDs := make([]string, 0, len(models))
|
||||||
@@ -238,6 +260,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientModels, clientID)
|
delete(r.clientModels, clientID)
|
||||||
delete(r.clientModelInfos, clientID)
|
delete(r.clientModelInfos, clientID)
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -265,6 +288,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
} else {
|
} else {
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
r.triggerModelsRegistered(provider, clientID, models)
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
@@ -367,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
reg.LastUpdated = now
|
reg.LastUpdated = now
|
||||||
|
// Re-registering an existing client/model binding starts a fresh registry
|
||||||
|
// snapshot for that binding. Cooldown and suspension are transient
|
||||||
|
// scheduling state and must not survive this reconciliation step.
|
||||||
if reg.QuotaExceededClients != nil {
|
if reg.QuotaExceededClients != nil {
|
||||||
delete(reg.QuotaExceededClients, clientID)
|
delete(reg.QuotaExceededClients, clientID)
|
||||||
}
|
}
|
||||||
@@ -408,6 +435,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
r.triggerModelsRegistered(provider, clientID, models)
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||||
@@ -511,6 +539,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
|||||||
if len(model.SupportedOutputModalities) > 0 {
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
|
if model.Thinking != nil {
|
||||||
|
copyThinking := *model.Thinking
|
||||||
|
if len(model.Thinking.Levels) > 0 {
|
||||||
|
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
copyModel.Thinking = ©Thinking
|
||||||
|
}
|
||||||
return ©Model
|
return ©Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -540,6 +575,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
|
|||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
r.unregisterClientInternal(clientID)
|
r.unregisterClientInternal(clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
||||||
@@ -606,9 +642,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
now := time.Now()
|
||||||
|
registration.QuotaExceededClients[clientID] = &now
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -620,9 +659,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
|||||||
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -638,6 +679,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
|||||||
}
|
}
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
registration, exists := r.models[modelID]
|
registration, exists := r.models[modelID]
|
||||||
if !exists || registration == nil {
|
if !exists || registration == nil {
|
||||||
@@ -651,6 +693,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
|||||||
}
|
}
|
||||||
registration.SuspendedClients[clientID] = reason
|
registration.SuspendedClients[clientID] = reason
|
||||||
registration.LastUpdated = time.Now()
|
registration.LastUpdated = time.Now()
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
if reason != "" {
|
if reason != "" {
|
||||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||||
} else {
|
} else {
|
||||||
@@ -668,6 +711,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
|||||||
}
|
}
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
registration, exists := r.models[modelID]
|
registration, exists := r.models[modelID]
|
||||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||||
@@ -678,6 +722,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
|||||||
}
|
}
|
||||||
delete(registration.SuspendedClients, clientID)
|
delete(registration.SuspendedClients, clientID)
|
||||||
registration.LastUpdated = time.Now()
|
registration.LastUpdated = time.Now()
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -713,22 +758,51 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []map[string]any: List of available models in the requested format
|
// - []map[string]any: List of available models in the requested format
|
||||||
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
||||||
r.mutex.RLock()
|
now := time.Now()
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
|
|
||||||
models := make([]map[string]any, 0)
|
r.mutex.RLock()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||||
|
models := cloneModelMaps(cache.models)
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
|
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||||
|
return cloneModelMaps(cache.models)
|
||||||
|
}
|
||||||
|
|
||||||
|
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
|
||||||
|
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
|
||||||
|
models: cloneModelMaps(models),
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||||
|
models := make([]map[string]any, 0, len(r.models))
|
||||||
|
var expiresAt time.Time
|
||||||
|
|
||||||
for _, registration := range r.models {
|
for _, registration := range r.models {
|
||||||
// Check if model has any non-quota-exceeded clients
|
|
||||||
availableClients := registration.Count
|
availableClients := registration.Count
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
// Count clients that have exceeded quota but haven't recovered yet
|
|
||||||
expiredClients := 0
|
expiredClients := 0
|
||||||
for _, quotaTime := range registration.QuotaExceededClients {
|
for _, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
|
||||||
|
if now.Before(recoveryAt) {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
|
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||||
|
expiresAt = recoveryAt
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -749,7 +823,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
effectiveClients = 0
|
effectiveClients = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include models that have available clients, or those solely cooling down.
|
|
||||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||||
model := r.convertModelToMap(registration.Info, handlerType)
|
model := r.convertModelToMap(registration.Info, handlerType)
|
||||||
if model != nil {
|
if model != nil {
|
||||||
@@ -758,7 +831,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return models
|
return models, expiresAt
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelMaps(models []map[string]any) []map[string]any {
|
||||||
|
cloned := make([]map[string]any, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
cloned = append(cloned, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copyModel := make(map[string]any, len(model))
|
||||||
|
for key, value := range model {
|
||||||
|
copyModel[key] = cloneModelMapValue(value)
|
||||||
|
}
|
||||||
|
cloned = append(cloned, copyModel)
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelMapValue(value any) any {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
copyMap := make(map[string]any, len(typed))
|
||||||
|
for key, entry := range typed {
|
||||||
|
copyMap[key] = cloneModelMapValue(entry)
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
case []any:
|
||||||
|
copySlice := make([]any, len(typed))
|
||||||
|
for i, entry := range typed {
|
||||||
|
copySlice[i] = cloneModelMapValue(entry)
|
||||||
|
}
|
||||||
|
return copySlice
|
||||||
|
case []string:
|
||||||
|
return append([]string(nil), typed...)
|
||||||
|
default:
|
||||||
|
return value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||||
@@ -822,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
result := make([]*ModelInfo, 0, len(providerModels))
|
result := make([]*ModelInfo, 0, len(providerModels))
|
||||||
|
|
||||||
@@ -844,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -874,11 +983,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
|
|
||||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||||
if entry.info != nil {
|
if entry.info != nil {
|
||||||
result = append(result, entry.info)
|
result = append(result, cloneModelInfo(entry.info))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ok && registration != nil && registration.Info != nil {
|
if ok && registration != nil && registration.Info != nil {
|
||||||
result = append(result, registration.Info)
|
result = append(result, cloneModelInfo(registration.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -898,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
|||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
|
|
||||||
// Count clients that have exceeded quota but haven't recovered yet
|
// Count clients that have exceeded quota but haven't recovered yet
|
||||||
expiredClients := 0
|
expiredClients := 0
|
||||||
for _, quotaTime := range registration.QuotaExceededClients {
|
for _, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -987,13 +1095,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
|||||||
if reg.Providers != nil {
|
if reg.Providers != nil {
|
||||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||||
return info
|
return cloneModelInfo(info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback to global info (last registered)
|
// Fallback to global info (last registered)
|
||||||
return reg.Info
|
return cloneModelInfo(reg.Info)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1033,7 +1141,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
result["max_completion_tokens"] = model.MaxCompletionTokens
|
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||||
}
|
}
|
||||||
if len(model.SupportedParameters) > 0 {
|
if len(model.SupportedParameters) > 0 {
|
||||||
result["supported_parameters"] = model.SupportedParameters
|
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedEndpoints) > 0 {
|
if len(model.SupportedEndpoints) > 0 {
|
||||||
result["supported_endpoints"] = model.SupportedEndpoints
|
result["supported_endpoints"] = model.SupportedEndpoints
|
||||||
@@ -1069,6 +1177,16 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Include context limits so Claude Code can manage conversation
|
||||||
|
// context correctly, especially for Copilot-proxied models whose
|
||||||
|
// real prompt limit (128K-168K) is much lower than the 1M window
|
||||||
|
// that Claude Code may assume for Opus 4.6 with 1M context enabled.
|
||||||
|
if model.ContextLength > 0 {
|
||||||
|
result["context_length"] = model.ContextLength
|
||||||
|
}
|
||||||
|
if model.MaxCompletionTokens > 0 {
|
||||||
|
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
case "gemini":
|
case "gemini":
|
||||||
@@ -1094,13 +1212,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
result["outputTokenLimit"] = model.OutputTokenLimit
|
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||||
}
|
}
|
||||||
if len(model.SupportedGenerationMethods) > 0 {
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedInputModalities) > 0 {
|
if len(model.SupportedInputModalities) > 0 {
|
||||||
result["supportedInputModalities"] = model.SupportedInputModalities
|
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedOutputModalities) > 0 {
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1129,16 +1247,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
invalidated := false
|
||||||
|
|
||||||
for modelID, registration := range r.models {
|
for modelID, registration := range r.models {
|
||||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
invalidated = true
|
||||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if invalidated {
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||||
@@ -1152,8 +1274,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
// - string: The model ID of the first available model, or empty string if none available
|
// - string: The model ID of the first available model, or empty string if none available
|
||||||
// - error: An error if no models are available
|
// - error: An error if no models are available
|
||||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
|
|
||||||
// Get all available models for this handler type
|
// Get all available models for this handler type
|
||||||
models := r.GetAvailableModels(handlerType)
|
models := r.GetAvailableModels(handlerType)
|
||||||
@@ -1213,13 +1333,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
|||||||
// Prefer client's own model info to preserve original type/owned_by
|
// Prefer client's own model info to preserve original type/owned_by
|
||||||
if clientInfos != nil {
|
if clientInfos != nil {
|
||||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||||
result = append(result, info)
|
result = append(result, cloneModelInfo(info))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback to global registry (for backwards compatibility)
|
// Fallback to global registry (for backwards compatibility)
|
||||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||||
result = append(result, reg.Info)
|
result = append(result, cloneModelInfo(reg.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|||||||
54
internal/registry/model_registry_cache_test.go
Normal file
54
internal/registry/model_registry_cache_test.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModels("openai")
|
||||||
|
if len(first) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(first))
|
||||||
|
}
|
||||||
|
first[0]["id"] = "mutated"
|
||||||
|
first[0]["display_name"] = "Mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModels("openai")
|
||||||
|
if got := second[0]["id"]; got != "m1" {
|
||||||
|
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
|
||||||
|
}
|
||||||
|
if got := second[0]["display_name"]; got != "Model One" {
|
||||||
|
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||||
|
|
||||||
|
models := r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(models))
|
||||||
|
}
|
||||||
|
if got := models[0]["display_name"]; got != "Model One" {
|
||||||
|
t.Fatalf("expected initial display_name Model One, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if got := models[0]["display_name"]; got != "Model One Updated" {
|
||||||
|
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.SuspendClientModel("client-1", "m1", "manual")
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 0 {
|
||||||
|
t.Fatalf("expected no available models after suspension, got %d", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ResumeClientModel("client-1", "m1")
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected model to reappear after resume, got %d", len(models))
|
||||||
|
}
|
||||||
|
}
|
||||||
149
internal/registry/model_registry_safety_test.go
Normal file
149
internal/registry/model_registry_safety_test.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetModelInfoReturnsClone(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetModelInfo("m1", "gemini")
|
||||||
|
if first == nil {
|
||||||
|
t.Fatal("expected model info")
|
||||||
|
}
|
||||||
|
first.DisplayName = "mutated"
|
||||||
|
first.Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetModelInfo("m1", "gemini")
|
||||||
|
if second.DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
|
||||||
|
}
|
||||||
|
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelsForClientReturnsClones(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetModelsForClient("client-1")
|
||||||
|
if len(first) != 1 || first[0] == nil {
|
||||||
|
t.Fatalf("expected one model, got %+v", first)
|
||||||
|
}
|
||||||
|
first[0].DisplayName = "mutated"
|
||||||
|
first[0].Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetModelsForClient("client-1")
|
||||||
|
if len(second) != 1 || second[0] == nil {
|
||||||
|
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||||
|
}
|
||||||
|
if second[0].DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||||
|
}
|
||||||
|
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModelsByProvider("gemini")
|
||||||
|
if len(first) != 1 || first[0] == nil {
|
||||||
|
t.Fatalf("expected one model, got %+v", first)
|
||||||
|
}
|
||||||
|
first[0].DisplayName = "mutated"
|
||||||
|
first[0].Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModelsByProvider("gemini")
|
||||||
|
if len(second) != 1 || second[0] == nil {
|
||||||
|
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||||
|
}
|
||||||
|
if second[0].DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||||
|
}
|
||||||
|
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
|
||||||
|
r.SetModelQuotaExceeded("client-1", "m1")
|
||||||
|
if models := r.GetAvailableModels("openai"); len(models) != 1 {
|
||||||
|
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
quotaTime := time.Now().Add(-6 * time.Minute)
|
||||||
|
r.models["m1"].QuotaExceededClients["client-1"] = "aTime
|
||||||
|
r.mutex.Unlock()
|
||||||
|
|
||||||
|
r.CleanupExpiredQuotas()
|
||||||
|
|
||||||
|
if count := r.GetModelCount("m1"); count != 1 {
|
||||||
|
t.Fatalf("expected model count 1 after cleanup, got %d", count)
|
||||||
|
}
|
||||||
|
models := r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
|
||||||
|
}
|
||||||
|
if got := models[0]["id"]; got != "m1" {
|
||||||
|
t.Fatalf("expected model id m1, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "openai", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
SupportedParameters: []string{"temperature", "top_p"},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModels("openai")
|
||||||
|
if len(first) != 1 {
|
||||||
|
t.Fatalf("expected one model, got %d", len(first))
|
||||||
|
}
|
||||||
|
params, ok := first[0]["supported_parameters"].([]string)
|
||||||
|
if !ok || len(params) != 2 {
|
||||||
|
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
|
||||||
|
}
|
||||||
|
params[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModels("openai")
|
||||||
|
params, ok = second[0]["supported_parameters"].([]string)
|
||||||
|
if !ok || len(params) != 2 || params[0] != "temperature" {
|
||||||
|
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
|
||||||
|
first := LookupModelInfo("glm-4.6")
|
||||||
|
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
|
||||||
|
t.Fatalf("expected static model with thinking levels, got %+v", first)
|
||||||
|
}
|
||||||
|
first.Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := LookupModelInfo("glm-4.6")
|
||||||
|
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
|
||||||
|
t.Fatalf("expected static lookup clone, got %+v", second)
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user