mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-20 22:51:45 +00:00
Compare commits
617 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c48ef58e0 | ||
|
|
15b0d8d039 | ||
|
|
5dcca69e8c | ||
|
|
f5dc6483d5 | ||
|
|
d949921143 | ||
|
|
7b03f04670 | ||
|
|
1267fddf61 | ||
|
|
85c7d43bea | ||
|
|
44c74d6ea2 | ||
|
|
ba454dbfbf | ||
|
|
d1508ca030 | ||
|
|
d4a6a5ae15 | ||
|
|
7c24d54ca8 | ||
|
|
a4c1e32ff6 | ||
|
|
f56cf42461 | ||
|
|
3dea1da249 | ||
|
|
8fac29631d | ||
|
|
8fecd625d2 | ||
|
|
10b55b5ddd | ||
|
|
41ae2c81e7 | ||
|
|
278a89824c | ||
|
|
c4459c4346 | ||
|
|
61e0447f92 | ||
|
|
1dc3018fd6 | ||
|
|
26fd3eff03 | ||
|
|
5bfaf8086b | ||
|
|
6c0a1efd71 | ||
|
|
f5ed5c7453 | ||
|
|
65158cce46 | ||
|
|
1c6c3675d1 | ||
|
|
a583463d60 | ||
|
|
8ed290c1c4 | ||
|
|
727221df2e | ||
|
|
1d8e68ad15 | ||
|
|
0ab1f5412f | ||
|
|
9ded75d335 | ||
|
|
f135fdf7fc | ||
|
|
828df80088 | ||
|
|
c585caa0ce | ||
|
|
5bb69fa4ab | ||
|
|
344043b9f1 | ||
|
|
26c298ced1 | ||
|
|
5ab9afac83 | ||
|
|
65ce86338b | ||
|
|
2a97037d7b | ||
|
|
d801393841 | ||
|
|
b2c0cdfc88 | ||
|
|
f32c8c9620 | ||
|
|
0f45d89255 | ||
|
|
96056d0137 | ||
|
|
f780c289e8 | ||
|
|
ac36119a02 | ||
|
|
39dc4557c1 | ||
|
|
30e94b6792 | ||
|
|
938af75954 | ||
|
|
38f0ae5970 | ||
|
|
cf249586a9 | ||
|
|
1dba2d0f81 | ||
|
|
730809d8ea | ||
|
|
e8d1b79cb3 | ||
|
|
5e81b65f2f | ||
|
|
7e8e2226a6 | ||
|
|
f0c20e852f | ||
|
|
7cdf8e9872 | ||
|
|
c42480a574 | ||
|
|
55c146a0e7 | ||
|
|
e2e3c7dde0 | ||
|
|
9e0ab4d116 | ||
|
|
8783caf313 | ||
|
|
f6f4640c5e | ||
|
|
613fe6768d | ||
|
|
ad8e3964ff | ||
|
|
e9dc576409 | ||
|
|
941334da79 | ||
|
|
d54f816363 | ||
|
|
69b950db4c | ||
|
|
f43d25def1 | ||
|
|
a279192881 | ||
|
|
6a43d7285c | ||
|
|
578c312660 | ||
|
|
6bb9bf3132 | ||
|
|
343a2fc2f7 | ||
|
|
12b967118b | ||
|
|
70efd4e016 | ||
|
|
f5aa68ecda | ||
|
|
9a5f142c33 | ||
|
|
d390b95b76 | ||
|
|
d1f6224b70 | ||
|
|
fcc59d606d | ||
|
|
91e7591955 | ||
|
|
4607356333 | ||
|
|
9a9ed99072 | ||
|
|
5ae38584b8 | ||
|
|
c8b7e2b8d6 | ||
|
|
cad45ffa33 | ||
|
|
6a27bceec0 | ||
|
|
163d68318f | ||
|
|
0ea768011b | ||
|
|
8b9dbe10f0 | ||
|
|
341b4beea1 | ||
|
|
bea13f9724 | ||
|
|
9f5bdfaa31 | ||
|
|
9eabdd09db | ||
|
|
c3f8dc362e | ||
|
|
b85120873b | ||
|
|
6f58518c69 | ||
|
|
000fcb15fa | ||
|
|
ea43361492 | ||
|
|
c1818f197b | ||
|
|
b0653cec7b | ||
|
|
22a1a24cf5 | ||
|
|
7223fee2de | ||
|
|
ada8e2905e | ||
|
|
4ba10531da | ||
|
|
3774b56e9f | ||
|
|
c2d4137fb9 | ||
|
|
2ee938acaf | ||
|
|
8d5e470e1f | ||
|
|
65e9e892a4 | ||
|
|
3882494878 | ||
|
|
088c1d07f4 | ||
|
|
8430b28cfa | ||
|
|
f3ab8f4bc5 | ||
|
|
0e4f189c2e | ||
|
|
98509f615c | ||
|
|
e7a66ae504 | ||
|
|
754b126944 | ||
|
|
ae37ccffbf | ||
|
|
42c062bb5b | ||
|
|
87bf0b73d5 | ||
|
|
f389667ec3 | ||
|
|
29dba0399b | ||
|
|
a824e7cd0b | ||
|
|
140faef7dc | ||
|
|
adb580b344 | ||
|
|
06405f2129 | ||
|
|
b849bf79d6 | ||
|
|
59af2c57b1 | ||
|
|
d1fd2c4ad4 | ||
|
|
b6c6379bfa | ||
|
|
8f0e66b72e | ||
|
|
f63cf6ff7a | ||
|
|
d2419ed49d | ||
|
|
516d22c695 | ||
|
|
73cda6e836 | ||
|
|
0805989ee5 | ||
|
|
9b5ce8c64f | ||
|
|
058793c73a | ||
|
|
75da02af55 | ||
|
|
ab9ebea592 | ||
|
|
7ee37ee4b9 | ||
|
|
837afffb31 | ||
|
|
03a1bac898 | ||
|
|
3171d524f0 | ||
|
|
3e78a8d500 | ||
|
|
fcba912cc4 | ||
|
|
7170eeea5f | ||
|
|
e3eb048c7a | ||
|
|
a59e92435b | ||
|
|
108895fc04 | ||
|
|
abc293c642 | ||
|
|
da3a498a28 | ||
|
|
bb44671845 | ||
|
|
09e480036a | ||
|
|
249f969110 | ||
|
|
4f8acec2d8 | ||
|
|
34339f61ee | ||
|
|
4045378cb4 | ||
|
|
2df35449fe | ||
|
|
c744179645 | ||
|
|
9720b03a6b | ||
|
|
f2c0f3d325 | ||
|
|
4f99bc54f1 | ||
|
|
913f4a9c5f | ||
|
|
25d1c18a3f | ||
|
|
d09dd4d0b2 | ||
|
|
474fb042da | ||
|
|
8435c3d7be | ||
|
|
e783d0a62e | ||
|
|
b05f575e9b | ||
|
|
f5e9f01811 | ||
|
|
ff7dbb5867 | ||
|
|
e34b2b4f1d | ||
|
|
15c2f274ea | ||
|
|
37249339ac | ||
|
|
c422d16beb | ||
|
|
66cd50f603 | ||
|
|
caa529c282 | ||
|
|
51a4379bf4 | ||
|
|
acf98ed10e | ||
|
|
d1c07a091e | ||
|
|
c1a8adf1ab | ||
|
|
08e078fc25 | ||
|
|
105a21548f | ||
|
|
1734aa1664 | ||
|
|
ca11b236a7 | ||
|
|
6fdff8227d | ||
|
|
330e12d3c2 | ||
|
|
bd09c0bf09 | ||
|
|
b468ca79c3 | ||
|
|
d2c7e4e96a | ||
|
|
1c7003ff68 | ||
|
|
1b44364e78 | ||
|
|
ec77f4a4f5 | ||
|
|
f611dd6e96 | ||
|
|
07b7c1a1e0 | ||
|
|
51fd58d74f | ||
|
|
faae9c2f7c | ||
|
|
bc3a6e4646 | ||
|
|
b09b03e35e | ||
|
|
16231947e7 | ||
|
|
39b9a38fbc | ||
|
|
bd855abec9 | ||
|
|
7c3c2e9f64 | ||
|
|
c10f8ae2e2 | ||
|
|
a0bf33eca6 | ||
|
|
88dd9c715d | ||
|
|
a3e21df814 | ||
|
|
d3b94c9241 | ||
|
|
c1d7599829 | ||
|
|
d11936f292 | ||
|
|
17363edf25 | ||
|
|
279cbbbb8a | ||
|
|
486cd4c343 | ||
|
|
25feceb783 | ||
|
|
d26752250d | ||
|
|
b15453c369 | ||
|
|
04ba8c8bc3 | ||
|
|
6570692291 | ||
|
|
f73d55ddaa | ||
|
|
13aa5b3375 | ||
|
|
0fcc02fbea | ||
|
|
c03883ccf0 | ||
|
|
134a9eac9d | ||
|
|
6d8de0ade4 | ||
|
|
1587ff5e74 | ||
|
|
f033d3a6df | ||
|
|
145e0e0b5d | ||
|
|
f8d1bc06ea | ||
|
|
d5930f4e44 | ||
|
|
9b7d7021af | ||
|
|
e41c22ef44 | ||
|
|
5fc2bd393e | ||
|
|
55271403fb | ||
|
|
36fba66619 | ||
|
|
66eb12294a | ||
|
|
73b22ec29b | ||
|
|
c31ae2f3b5 | ||
|
|
76b53d6b5b | ||
|
|
a34dfed378 | ||
|
|
b9b127a7ea | ||
|
|
2741e7b7b3 | ||
|
|
1767a56d4f | ||
|
|
779e6c2d2f | ||
|
|
73c831747b | ||
|
|
b8b89f34f4 | ||
|
|
1fa094dac6 | ||
|
|
f55754621f | ||
|
|
ac26e7db43 | ||
|
|
10b824fcac | ||
|
|
e5d3541b5a | ||
|
|
79755e76ea | ||
|
|
35f158d526 | ||
|
|
6962e09dd9 | ||
|
|
4c4cbd44da | ||
|
|
26eca8b6ba | ||
|
|
62b17f40a1 | ||
|
|
511b8a992e | ||
|
|
7dccc7ba2f | ||
|
|
70c90687fd | ||
|
|
8144ffd5c8 | ||
|
|
0ab977c236 | ||
|
|
224f0de353 | ||
|
|
6b45d311ec | ||
|
|
d54de441d3 | ||
|
|
7386a70724 | ||
|
|
1821bf7051 | ||
|
|
d42b5d4e78 | ||
|
|
1b7447b682 | ||
|
|
40dee4453a | ||
|
|
8902e1cccb | ||
|
|
de5fe71478 | ||
|
|
dcfbec2990 | ||
|
|
c95620f90e | ||
|
|
754f3bcbc3 | ||
|
|
36973d4a6f | ||
|
|
9613f0b3f9 | ||
|
|
274f29e26b | ||
|
|
c8e79c3787 | ||
|
|
8afef43887 | ||
|
|
c1083cbfc6 | ||
|
|
c89d19b300 | ||
|
|
1e6bc81cfd | ||
|
|
1a149475e0 | ||
|
|
e5166841db | ||
|
|
19c52bcb60 | ||
|
|
bb9b2d1758 | ||
|
|
7fa527193c | ||
|
|
ed0eb51b4d | ||
|
|
0e4f669c8b | ||
|
|
76c064c729 | ||
|
|
d2f652f436 | ||
|
|
6a452a54d5 | ||
|
|
9e5693e74f | ||
|
|
528b1a2307 | ||
|
|
0cc978ec1d | ||
|
|
d312422ab4 | ||
|
|
fee736933b | ||
|
|
09c92aa0b5 | ||
|
|
8c67b3ae64 | ||
|
|
000e4ceb4e | ||
|
|
5c99846ecf | ||
|
|
cc32f5ff61 | ||
|
|
fbff68b9e0 | ||
|
|
7e1a543b79 | ||
|
|
d475aaba96 | ||
|
|
1dc4ecb1b8 | ||
|
|
1315f710f5 | ||
|
|
96f55570f7 | ||
|
|
0906aeca87 | ||
|
|
7333619f15 | ||
|
|
97c0487add | ||
|
|
74b862d8b8 | ||
|
|
2db8df8e38 | ||
|
|
a576088d5f | ||
|
|
66ff916838 | ||
|
|
7b0453074e | ||
|
|
a000eb523d | ||
|
|
18a4fedc7f | ||
|
|
5d6cdccda0 | ||
|
|
1b7f4ac3e1 | ||
|
|
afc1a5b814 | ||
|
|
7ed38db54f | ||
|
|
28c10f4e69 | ||
|
|
6e12441a3b | ||
|
|
65c439c18d | ||
|
|
0ed2d16596 | ||
|
|
db335ac616 | ||
|
|
f3c59165d7 | ||
|
|
e6690cb447 | ||
|
|
35907416b8 | ||
|
|
e8bb350467 | ||
|
|
5331d51f27 | ||
|
|
755ca75879 | ||
|
|
2398ebad55 | ||
|
|
c1bf298216 | ||
|
|
e005208d76 | ||
|
|
d1df70d02f | ||
|
|
f81acd0760 | ||
|
|
636da4c932 | ||
|
|
cccb77b552 | ||
|
|
2bd646ad70 | ||
|
|
52c1fa025e | ||
|
|
680105f84d | ||
|
|
f7069e9548 | ||
|
|
7275e99b41 | ||
|
|
c28b65f849 | ||
|
|
793840cdb4 | ||
|
|
8f421de532 | ||
|
|
be2dd60ee7 | ||
|
|
ea3e0b713e | ||
|
|
8179d5a8a4 | ||
|
|
6fa7abe434 | ||
|
|
5135c22cd6 | ||
|
|
1e27990561 | ||
|
|
e1e9fc43c1 | ||
|
|
b2921518ac | ||
|
|
dd64adbeeb | ||
|
|
616d41c06a | ||
|
|
e0e337aeb9 | ||
|
|
d52839fced | ||
|
|
4022e69651 | ||
|
|
56073ded69 | ||
|
|
9738a53f49 | ||
|
|
be3f8dbf7e | ||
|
|
9c6c3612a8 | ||
|
|
19e1a4447a | ||
|
|
36efcc6e28 | ||
|
|
a337ecf35c | ||
|
|
7c2ad4cda2 | ||
|
|
fb95813fbf | ||
|
|
db63f9b5d6 | ||
|
|
25f6c4a250 | ||
|
|
b24ae74216 | ||
|
|
59ad8f40dc | ||
|
|
ff03dc6a2c | ||
|
|
dc7187ca5b | ||
|
|
b1dcff778c | ||
|
|
cef2aeeb08 | ||
|
|
bcd1e8cc34 | ||
|
|
198b3f4a40 | ||
|
|
9fee7f488e | ||
|
|
1b46d39b8b | ||
|
|
c1241a98e2 | ||
|
|
8d8f5970ee | ||
|
|
f90120f846 | ||
|
|
0b94d36c4a | ||
|
|
152c310bb7 | ||
|
|
f6bbca35ab | ||
|
|
c8cee6a209 | ||
|
|
b5701f416b | ||
|
|
4b1a404fcb | ||
|
|
b93cce5412 | ||
|
|
c6cb24039d | ||
|
|
5382408489 | ||
|
|
67669196ed | ||
|
|
5c817a9b42 | ||
|
|
e08f68ed7c | ||
|
|
f09ed25fd3 | ||
|
|
58fd9bf964 | ||
|
|
7b3dfc67bc | ||
|
|
cdd24052d3 | ||
|
|
5da0decef6 | ||
|
|
733fd8edab | ||
|
|
af27f2b8bc | ||
|
|
2e1925d762 | ||
|
|
77254bd074 | ||
|
|
5b6342e6ac | ||
|
|
e166e56249 | ||
|
|
3960c93d51 | ||
|
|
339a81b650 | ||
|
|
560c020477 | ||
|
|
aec65e3be3 | ||
|
|
f44f0702f8 | ||
|
|
b76b79068f | ||
|
|
34c8ccb961 | ||
|
|
d08e164af3 | ||
|
|
8178efaeda | ||
|
|
86d5db472a | ||
|
|
020d36f6e8 | ||
|
|
1db23979e8 | ||
|
|
c3d5dbe96f | ||
|
|
5484489406 | ||
|
|
0ac52da460 | ||
|
|
817cebb321 | ||
|
|
683f3709d6 | ||
|
|
dbd42a42b2 | ||
|
|
ec24baf757 | ||
|
|
dea3e74d35 | ||
|
|
a6c3042e34 | ||
|
|
861537c9bd | ||
|
|
8c92cb0883 | ||
|
|
89d7be9525 | ||
|
|
2b79d7f22f | ||
|
|
2bb686f594 | ||
|
|
163fe287ce | ||
|
|
70988d387b | ||
|
|
52058a1659 | ||
|
|
df5595a0c9 | ||
|
|
ddaa9d2436 | ||
|
|
7b7b258c38 | ||
|
|
a00f774f5a | ||
|
|
9daf1ba8b5 | ||
|
|
76f2359637 | ||
|
|
dcb1c9be8a | ||
|
|
a24f4ace78 | ||
|
|
c631df8c3b | ||
|
|
54c3eb1b1e | ||
|
|
bb28cd26ad | ||
|
|
046865461e | ||
|
|
cf74ed2f0c | ||
|
|
c3762328a5 | ||
|
|
e333fbea3d | ||
|
|
efbe36d1d4 | ||
|
|
8553cfa40e | ||
|
|
30d5c95b26 | ||
|
|
d1e3195e6f | ||
|
|
05a35662ae | ||
|
|
ce53d3a287 | ||
|
|
5f58248016 | ||
|
|
4cc99e7449 | ||
|
|
71773fe032 | ||
|
|
a1e0fa0f39 | ||
|
|
fc2f0b6983 | ||
|
|
5c9997cdac | ||
|
|
6f81046730 | ||
|
|
0687472d01 | ||
|
|
7739738fb3 | ||
|
|
99d1ce247b | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 | ||
|
|
182b31963a | ||
|
|
4f48e5254a | ||
|
|
15dd5db1d7 | ||
|
|
424711b718 | ||
|
|
91a2b1f0b4 | ||
|
|
2b134fc378 | ||
|
|
b9153719b0 | ||
|
|
631e5c8331 | ||
|
|
e9c60a0a67 | ||
|
|
98a1bb5a7f | ||
|
|
ca90487a8c | ||
|
|
1042489f85 | ||
|
|
38277c1ea6 | ||
|
|
ee0c24628f | ||
|
|
07d6689d87 | ||
|
|
3a18f6fcca | ||
|
|
099e734a02 | ||
|
|
a52da26b5d | ||
|
|
522a68a4ea | ||
|
|
a02eda54d0 | ||
|
|
97ef633c57 | ||
|
|
dae8463ba1 | ||
|
|
7c1299922e | ||
|
|
ddcf1f279d | ||
|
|
7e6bb8fdc5 | ||
|
|
9cee8ef87b | ||
|
|
93fb841bcb | ||
|
|
0c05131aeb | ||
|
|
5ebc58fab4 | ||
|
|
2b609dd891 | ||
|
|
a8cbc68c3e | ||
|
|
11a795a01c | ||
|
|
89c428216e | ||
|
|
2695a99623 | ||
|
|
242aecd924 | ||
|
|
ce8cc1ba33 | ||
|
|
ad5253bd2b | ||
|
|
97fdd2e088 | ||
|
|
9397f7049f | ||
|
|
a14d19b92c | ||
|
|
8ae0c05ea6 | ||
|
|
8822f20d17 | ||
|
|
553d6f50ea | ||
|
|
f0e5a5a367 | ||
|
|
f6dfea9357 | ||
|
|
cc8dc7f62c | ||
|
|
a3846ea513 | ||
|
|
8d44be858e | ||
|
|
0e6bb076e9 | ||
|
|
ac135fc7cb | ||
|
|
4e1d09809d | ||
|
|
9e855f8100 | ||
|
|
25680a8259 | ||
|
|
13c93e8cfd | ||
|
|
88aa1b9fd1 | ||
|
|
352cb98ff0 | ||
|
|
ac95e92829 | ||
|
|
8526c2da25 | ||
|
|
68a6cabf8b | ||
|
|
ac0e387da1 | ||
|
|
7fe1d102cb | ||
|
|
5850492a93 | ||
|
|
fdbd4041ca | ||
|
|
ebef1fae2a | ||
|
|
c51851689b | ||
|
|
419bf784ab | ||
|
|
4bbeb92e9a | ||
|
|
b436dad8bc | ||
|
|
6ae15d6c44 | ||
|
|
0468bde0d6 | ||
|
|
1d7329e797 | ||
|
|
48ffc4dee7 | ||
|
|
7ebd8f0c44 | ||
|
|
b680c146c1 | ||
|
|
7d6660d181 | ||
|
|
d8e3d4e2b6 | ||
|
|
d26ad8224d | ||
|
|
5c84d69d42 | ||
|
|
527e4b7f26 | ||
|
|
b48485b42b | ||
|
|
79009bb3d4 | ||
|
|
26fc611f86 | ||
|
|
b43743d4f1 | ||
|
|
179e5434b1 | ||
|
|
9f95b31158 | ||
|
|
5da07eae4c | ||
|
|
835ae178d4 | ||
|
|
c80ab8bf0d | ||
|
|
ce87714ef1 | ||
|
|
0452b869e8 | ||
|
|
d2e5857b82 | ||
|
|
f9b005f21f | ||
|
|
532107b4fa | ||
|
|
c44793789b | ||
|
|
4e99525279 | ||
|
|
7547d1d0b3 | ||
|
|
68934942d0 | ||
|
|
09fec34e1c | ||
|
|
9229708b6c | ||
|
|
914db94e79 | ||
|
|
660bd7eff5 | ||
|
|
b907d21851 | ||
|
|
dd44413ba5 | ||
|
|
10fa0f2062 | ||
|
|
d6cc976d1f | ||
|
|
8aa2cce8c5 | ||
|
|
bf9b2c49df | ||
|
|
77b42c6165 | ||
|
|
30338ecec4 | ||
|
|
9a37defed3 | ||
|
|
c83a057996 | ||
|
|
68dd2bfe82 | ||
|
|
2baf35b3ef | ||
|
|
846e75b893 | ||
|
|
fc0257d6d9 | ||
|
|
f3c164d345 | ||
|
|
4040b1e766 | ||
|
|
b7588428c5 | ||
|
|
8f97a5f77c | ||
|
|
2a4d3e60f3 | ||
|
|
8b5af2ab84 | ||
|
|
d887716ebd | ||
|
|
5dc1848466 | ||
|
|
9491517b26 | ||
|
|
9370b5bd04 | ||
|
|
abb51a0d93 | ||
|
|
c8d809131b | ||
|
|
dd71c73a9f | ||
|
|
2615f489d6 | ||
|
|
14cb2b95c6 | ||
|
|
fdeef48498 |
81
.github/workflows/agents-md-guard.yml
vendored
Normal file
81
.github/workflows/agents-md-guard.yml
vendored
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
name: agents-md-guard
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- synchronize
|
||||||
|
- reopened
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-when-agents-md-changed:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Detect AGENTS.md changes and close PR
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const prNumber = context.payload.pull_request.number;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
|
||||||
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
per_page: 100,
|
||||||
|
});
|
||||||
|
|
||||||
|
const touchesAgentsMd = (path) =>
|
||||||
|
typeof path === "string" &&
|
||||||
|
(path === "AGENTS.md" || path.endsWith("/AGENTS.md"));
|
||||||
|
|
||||||
|
const touched = files.filter(
|
||||||
|
(f) => touchesAgentsMd(f.filename) || touchesAgentsMd(f.previous_filename),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (touched.length === 0) {
|
||||||
|
core.info("No AGENTS.md changes detected.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const changedList = touched
|
||||||
|
.map((f) =>
|
||||||
|
f.previous_filename && f.previous_filename !== f.filename
|
||||||
|
? `- ${f.previous_filename} -> ${f.filename}`
|
||||||
|
: `- ${f.filename}`,
|
||||||
|
)
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
const body = [
|
||||||
|
"This repository does not allow modifying `AGENTS.md` in pull requests.",
|
||||||
|
"",
|
||||||
|
"Detected changes:",
|
||||||
|
changedList,
|
||||||
|
"",
|
||||||
|
"Please revert these changes and open a new PR without touching `AGENTS.md`.",
|
||||||
|
].join("\n");
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
issue_number: prNumber,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
await github.rest.pulls.update({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
state: "closed",
|
||||||
|
});
|
||||||
|
|
||||||
|
core.setFailed("PR modifies AGENTS.md");
|
||||||
73
.github/workflows/auto-retarget-main-pr-to-dev.yml
vendored
Normal file
73
.github/workflows/auto-retarget-main-pr-to-dev.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
name: auto-retarget-main-pr-to-dev
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- reopened
|
||||||
|
- edited
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
retarget:
|
||||||
|
if: github.actor != 'github-actions[bot]'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Retarget PR base to dev
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const pr = context.payload.pull_request;
|
||||||
|
const prNumber = pr.number;
|
||||||
|
const { owner, repo } = context.repo;
|
||||||
|
|
||||||
|
const baseRef = pr.base?.ref;
|
||||||
|
const headRef = pr.head?.ref;
|
||||||
|
const desiredBase = "dev";
|
||||||
|
|
||||||
|
if (baseRef !== "main") {
|
||||||
|
core.info(`PR #${prNumber} base is ${baseRef}; nothing to do.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (headRef === desiredBase) {
|
||||||
|
core.info(`PR #${prNumber} is ${desiredBase} -> main; skipping retarget.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
core.info(`Retargeting PR #${prNumber} base from ${baseRef} to ${desiredBase}.`);
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.pulls.update({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
pull_number: prNumber,
|
||||||
|
base: desiredBase,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.setFailed(`Failed to retarget PR #${prNumber} to ${desiredBase}: ${error.message}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const body = [
|
||||||
|
`This pull request targeted \`${baseRef}\`.`,
|
||||||
|
"",
|
||||||
|
`The base branch has been automatically changed to \`${desiredBase}\`.`,
|
||||||
|
].join("\n");
|
||||||
|
|
||||||
|
try {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
owner,
|
||||||
|
repo,
|
||||||
|
issue_number: prNumber,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
core.warning(`Failed to comment on PR #${prNumber}: ${error.message}`);
|
||||||
|
}
|
||||||
14
.github/workflows/docker-image.yml
vendored
14
.github/workflows/docker-image.yml
vendored
@@ -16,6 +16,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -25,7 +29,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push (amd64)
|
- name: Build and push (amd64)
|
||||||
@@ -47,6 +51,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -56,7 +64,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push (arm64)
|
- name: Build and push (arm64)
|
||||||
@@ -90,7 +98,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Create and push multi-arch manifests
|
- name: Create and push multi-arch manifests
|
||||||
|
|||||||
4
.github/workflows/pr-test-build.yml
vendored
4
.github/workflows/pr-test-build.yml
vendored
@@ -12,6 +12,10 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
9
.github/workflows/release.yaml
vendored
9
.github/workflows/release.yaml
vendored
@@ -16,6 +16,10 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: |
|
||||||
|
git fetch --depth 1 https://github.com/router-for-me/models.git main
|
||||||
|
git show FETCH_HEAD:models.json > internal/registry/models/models.json
|
||||||
- run: git fetch --force --tags
|
- run: git fetch --force --tags
|
||||||
- uses: actions/setup-go@v4
|
- uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
@@ -23,15 +27,14 @@ jobs:
|
|||||||
cache: true
|
cache: true
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
VERSION=$(git describe --tags --always --dirty)
|
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
|
||||||
echo "VERSION=${VERSION}" >> $GITHUB_ENV
|
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- uses: goreleaser/goreleaser-action@v4
|
- uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
distribution: goreleaser
|
distribution: goreleaser
|
||||||
version: latest
|
version: latest
|
||||||
args: release --clean
|
args: release --clean --skip=validate
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
VERSION: ${{ env.VERSION }}
|
VERSION: ${{ env.VERSION }}
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,6 +1,7 @@
|
|||||||
# Binaries
|
# Binaries
|
||||||
cli-proxy-api
|
cli-proxy-api
|
||||||
cliproxy
|
cliproxy
|
||||||
|
/server
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
|
|
||||||
@@ -36,15 +37,16 @@ GEMINI.md
|
|||||||
|
|
||||||
# Tooling metadata
|
# Tooling metadata
|
||||||
.vscode/*
|
.vscode/*
|
||||||
|
.worktrees/
|
||||||
.codex/*
|
.codex/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.gemini/*
|
.gemini/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.agent/*
|
.agent/*
|
||||||
.agents/*
|
.agents/*
|
||||||
.agents/*
|
|
||||||
.opencode/*
|
.opencode/*
|
||||||
.idea/*
|
.idea/*
|
||||||
|
.beads/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
_bmad/*
|
_bmad/*
|
||||||
_bmad-output/*
|
_bmad-output/*
|
||||||
@@ -53,4 +55,10 @@ _bmad-output/*
|
|||||||
# macOS
|
# macOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
._*
|
._*
|
||||||
|
|
||||||
|
# Opencode
|
||||||
|
.beads/
|
||||||
|
.opencode/
|
||||||
|
.cli-proxy-api/
|
||||||
|
.venv/
|
||||||
*.bak
|
*.bak
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
builds:
|
builds:
|
||||||
- id: "cli-proxy-api-plus"
|
- id: "cli-proxy-api-plus"
|
||||||
env:
|
env:
|
||||||
@@ -6,6 +8,7 @@ builds:
|
|||||||
- linux
|
- linux
|
||||||
- windows
|
- windows
|
||||||
- darwin
|
- darwin
|
||||||
|
- freebsd
|
||||||
goarch:
|
goarch:
|
||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
|
|||||||
58
AGENTS.md
Normal file
58
AGENTS.md
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
Go 1.26+ proxy server providing OpenAI/Gemini/Claude/Codex compatible APIs with OAuth and round-robin load balancing.
|
||||||
|
|
||||||
|
## Repository
|
||||||
|
- GitHub: https://github.com/router-for-me/CLIProxyAPI
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
```bash
|
||||||
|
gofmt -w . # Format (required after Go changes)
|
||||||
|
go build -o cli-proxy-api ./cmd/server # Build
|
||||||
|
go run ./cmd/server # Run dev server
|
||||||
|
go test ./... # Run all tests
|
||||||
|
go test -v -run TestName ./path/to/pkg # Run single test
|
||||||
|
go build -o test-output ./cmd/server && rm test-output # Verify compile (REQUIRED after changes)
|
||||||
|
```
|
||||||
|
- Common flags: `--config <path>`, `--tui`, `--standalone`, `--local-model`, `--no-browser`, `--oauth-callback-port <port>`
|
||||||
|
|
||||||
|
## Config
|
||||||
|
- Default config: `config.yaml` (template: `config.example.yaml`)
|
||||||
|
- `.env` is auto-loaded from the working directory
|
||||||
|
- Auth material defaults under `auths/`
|
||||||
|
- Storage backends: file-based default; optional Postgres/git/object store (`PGSTORE_*`, `GITSTORE_*`, `OBJECTSTORE_*`)
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
- `cmd/server/` — Server entrypoint
|
||||||
|
- `internal/api/` — Gin HTTP API (routes, middleware, modules)
|
||||||
|
- `internal/api/modules/amp/` — Amp integration (Amp-style routes + reverse proxy)
|
||||||
|
- `internal/thinking/` — Main thinking/reasoning pipeline. `ApplyThinking()` (apply.go) parses suffixes (`suffix.go`, suffix overrides body), normalizes config to canonical `ThinkingConfig` (`types.go`), normalizes and validates centrally (`validate.go`/`convert.go`), then applies provider-specific output via `ProviderApplier`. Do not break this "canonical representation → per-provider translation" architecture.
|
||||||
|
- `internal/runtime/executor/` — Per-provider runtime executors (incl. Codex WebSocket)
|
||||||
|
- `internal/translator/` — Provider protocol translators (and shared `common`)
|
||||||
|
- `internal/registry/` — Model registry + remote updater (`StartModelsUpdater`); `--local-model` disables remote updates
|
||||||
|
- `internal/store/` — Storage implementations and secret resolution
|
||||||
|
- `internal/managementasset/` — Config snapshots and management assets
|
||||||
|
- `internal/cache/` — Request signature caching
|
||||||
|
- `internal/watcher/` — Config hot-reload and watchers
|
||||||
|
- `internal/wsrelay/` — WebSocket relay sessions
|
||||||
|
- `internal/usage/` — Usage and token accounting
|
||||||
|
- `internal/tui/` — Bubbletea terminal UI (`--tui`, `--standalone`)
|
||||||
|
- `sdk/cliproxy/` — Embeddable SDK entry (service/builder/watchers/pipeline)
|
||||||
|
- `test/` — Cross-module integration tests
|
||||||
|
|
||||||
|
## Code Conventions
|
||||||
|
- Keep changes small and simple (KISS)
|
||||||
|
- Comments in English only
|
||||||
|
- If editing code that already contains non-English comments, translate them to English (don’t add new non-English comments)
|
||||||
|
- For user-visible strings, keep the existing language used in that file/area
|
||||||
|
- New Markdown docs should be in English unless the file is explicitly language-specific (e.g. `README_CN.md`)
|
||||||
|
- As a rule, do not make standalone changes to `internal/translator/`. You may modify it only as part of broader changes elsewhere.
|
||||||
|
- If a task requires changing only `internal/translator/`, run `gh repo view --json viewerPermission -q .viewerPermission` to confirm you have `WRITE`, `MAINTAIN`, or `ADMIN`. If you do, you may proceed; otherwise, file a GitHub issue including the goal, rationale, and the intended implementation code, then stop further work.
|
||||||
|
- `internal/runtime/executor/` should contain executors and their unit tests only. Place any helper/supporting files under `internal/runtime/executor/helps/`.
|
||||||
|
- Follow `gofmt`; keep imports goimports-style; wrap errors with context where helpful
|
||||||
|
- Do not use `log.Fatal`/`log.Fatalf` (terminates the process); prefer returning errors and logging via logrus
|
||||||
|
- Shadowed variables: use method suffix (`errStart := server.Start()`)
|
||||||
|
- Wrap defer errors: `defer func() { if err := f.Close(); err != nil { log.Errorf(...) } }()`
|
||||||
|
- Use logrus structured logging; avoid leaking secrets/tokens in logs
|
||||||
|
- Avoid panics in HTTP handlers; prefer logged errors and meaningful HTTP status codes
|
||||||
|
- Timeouts are allowed only during credential acquisition; after an upstream connection is established, do not set timeouts for any subsequent network behavior. Intentional exceptions that must remain allowed are the Codex websocket liveness deadlines in `internal/runtime/executor/codex_websockets_executor.go`, the wsrelay session deadlines in `internal/wsrelay/session.go`, the management APICall timeout in `internal/api/handlers/management/api_tools.go`, and the `cmd/fetch_antigravity_models` utility timeouts
|
||||||
126
README.md
126
README.md
@@ -8,132 +8,6 @@ All third-party provider support is maintained by community contributors; CLIPro
|
|||||||
|
|
||||||
The Plus release stays in lockstep with the mainline features.
|
The Plus release stays in lockstep with the mainline features.
|
||||||
|
|
||||||
## Differences from the Mainline
|
|
||||||
|
|
||||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
|
||||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
|
||||||
|
|
||||||
## New Features (Plus Enhanced)
|
|
||||||
|
|
||||||
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
|
|
||||||
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
|
|
||||||
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
|
|
||||||
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
|
|
||||||
- **Device Fingerprint**: Device fingerprint generation for enhanced security
|
|
||||||
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
|
|
||||||
- **Usage Checker**: Real-time usage monitoring and quota management
|
|
||||||
- **Model Converter**: Unified model name conversion across providers
|
|
||||||
- **UTF-8 Stream Processing**: Improved streaming response handling
|
|
||||||
|
|
||||||
## Kiro Authentication
|
|
||||||
|
|
||||||
### CLI Login
|
|
||||||
|
|
||||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
|
||||||
|
|
||||||
**AWS Builder ID** (recommended):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Device code flow
|
|
||||||
./CLIProxyAPI --kiro-aws-login
|
|
||||||
|
|
||||||
# Authorization code flow
|
|
||||||
./CLIProxyAPI --kiro-aws-authcode
|
|
||||||
```
|
|
||||||
|
|
||||||
**Import token from Kiro IDE:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-import
|
|
||||||
```
|
|
||||||
|
|
||||||
To get a token from Kiro IDE:
|
|
||||||
|
|
||||||
1. Open Kiro IDE and login with Google (or GitHub)
|
|
||||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
|
||||||
3. Run: `./CLIProxyAPI --kiro-import`
|
|
||||||
|
|
||||||
**AWS IAM Identity Center (IDC):**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
|
||||||
|
|
||||||
# Specify region
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
|
||||||
```
|
|
||||||
|
|
||||||
**Additional flags:**
|
|
||||||
|
|
||||||
| Flag | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
|
||||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
|
||||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
|
||||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
|
||||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
|
||||||
|
|
||||||
### Web-based OAuth Login
|
|
||||||
|
|
||||||
Access the Kiro OAuth web interface at:
|
|
||||||
|
|
||||||
```
|
|
||||||
http://your-server:8080/v0/oauth/kiro
|
|
||||||
```
|
|
||||||
|
|
||||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
|
||||||
- AWS Builder ID login
|
|
||||||
- AWS Identity Center (IDC) login
|
|
||||||
- Token import from Kiro IDE
|
|
||||||
|
|
||||||
## Quick Deployment with Docker
|
|
||||||
|
|
||||||
### One-Command Deployment
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Create deployment directory
|
|
||||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
|
||||||
|
|
||||||
# Create docker-compose.yml
|
|
||||||
cat > docker-compose.yml << 'EOF'
|
|
||||||
services:
|
|
||||||
cli-proxy-api:
|
|
||||||
image: eceasy/cli-proxy-api-plus:latest
|
|
||||||
container_name: cli-proxy-api-plus
|
|
||||||
ports:
|
|
||||||
- "8317:8317"
|
|
||||||
volumes:
|
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
|
||||||
- ./auths:/root/.cli-proxy-api
|
|
||||||
- ./logs:/CLIProxyAPI/logs
|
|
||||||
restart: unless-stopped
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# Download example config
|
|
||||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
|
||||||
|
|
||||||
# Pull and start
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
|
|
||||||
Edit `config.yaml` before starting:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Basic configuration example
|
|
||||||
server:
|
|
||||||
port: 8317
|
|
||||||
|
|
||||||
# Add your provider configurations here
|
|
||||||
```
|
|
||||||
|
|
||||||
### Update to Latest Version
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/cli-proxy
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||||
|
|||||||
128
README_CN.md
128
README_CN.md
@@ -6,134 +6,6 @@
|
|||||||
|
|
||||||
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
||||||
|
|
||||||
该 Plus 版本的主线功能与主线功能强制同步。
|
|
||||||
|
|
||||||
## 与主线版本版本差异
|
|
||||||
|
|
||||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
|
||||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
|
||||||
|
|
||||||
## 新增功能 (Plus 增强版)
|
|
||||||
|
|
||||||
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
|
|
||||||
- **请求限流器**: 内置请求限流,防止 API 滥用
|
|
||||||
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
|
|
||||||
- **监控指标**: 请求指标收集,用于监控和调试
|
|
||||||
- **设备指纹**: 设备指纹生成,增强安全性
|
|
||||||
- **冷却管理**: 智能冷却机制,应对 API 速率限制
|
|
||||||
- **用量检查器**: 实时用量监控和配额管理
|
|
||||||
- **模型转换器**: 跨供应商的统一模型名称转换
|
|
||||||
- **UTF-8 流处理**: 改进的流式响应处理
|
|
||||||
|
|
||||||
## Kiro 认证
|
|
||||||
|
|
||||||
### 命令行登录
|
|
||||||
|
|
||||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
|
||||||
|
|
||||||
**AWS Builder ID**(推荐):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 设备码流程
|
|
||||||
./CLIProxyAPI --kiro-aws-login
|
|
||||||
|
|
||||||
# 授权码流程
|
|
||||||
./CLIProxyAPI --kiro-aws-authcode
|
|
||||||
```
|
|
||||||
|
|
||||||
**从 Kiro IDE 导入令牌:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-import
|
|
||||||
```
|
|
||||||
|
|
||||||
获取令牌步骤:
|
|
||||||
|
|
||||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
|
||||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
|
||||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
|
||||||
|
|
||||||
**AWS IAM Identity Center (IDC):**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
|
||||||
|
|
||||||
# 指定区域
|
|
||||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
|
||||||
```
|
|
||||||
|
|
||||||
**附加参数:**
|
|
||||||
|
|
||||||
| 参数 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
|
||||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
|
||||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
|
||||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
|
||||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
|
||||||
|
|
||||||
### 网页端 OAuth 登录
|
|
||||||
|
|
||||||
访问 Kiro OAuth 网页认证界面:
|
|
||||||
|
|
||||||
```
|
|
||||||
http://your-server:8080/v0/oauth/kiro
|
|
||||||
```
|
|
||||||
|
|
||||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
|
||||||
- AWS Builder ID 登录
|
|
||||||
- AWS Identity Center (IDC) 登录
|
|
||||||
- 从 Kiro IDE 导入令牌
|
|
||||||
|
|
||||||
## Docker 快速部署
|
|
||||||
|
|
||||||
### 一键部署
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 创建部署目录
|
|
||||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
|
||||||
|
|
||||||
# 创建 docker-compose.yml
|
|
||||||
cat > docker-compose.yml << 'EOF'
|
|
||||||
services:
|
|
||||||
cli-proxy-api:
|
|
||||||
image: eceasy/cli-proxy-api-plus:latest
|
|
||||||
container_name: cli-proxy-api-plus
|
|
||||||
ports:
|
|
||||||
- "8317:8317"
|
|
||||||
volumes:
|
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
|
||||||
- ./auths:/root/.cli-proxy-api
|
|
||||||
- ./logs:/CLIProxyAPI/logs
|
|
||||||
restart: unless-stopped
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# 下载示例配置
|
|
||||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
|
||||||
|
|
||||||
# 拉取并启动
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### 配置说明
|
|
||||||
|
|
||||||
启动前请编辑 `config.yaml`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# 基本配置示例
|
|
||||||
server:
|
|
||||||
port: 8317
|
|
||||||
|
|
||||||
# 在此添加你的供应商配置
|
|
||||||
```
|
|
||||||
|
|
||||||
### 更新到最新版本
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/cli-proxy
|
|
||||||
docker compose pull && docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||||
|
|||||||
BIN
assets/bmoplus.png
Normal file
BIN
assets/bmoplus.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
BIN
assets/lingtrue.png
Normal file
BIN
assets/lingtrue.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 129 KiB |
276
cmd/fetch_antigravity_models/main.go
Normal file
276
cmd/fetch_antigravity_models/main.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
// Command fetch_antigravity_models connects to the Antigravity API using the
|
||||||
|
// stored auth credentials and saves the dynamically fetched model list to a
|
||||||
|
// JSON file for inspection or offline use.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// go run ./cmd/fetch_antigravity_models [flags]
|
||||||
|
//
|
||||||
|
// Flags:
|
||||||
|
//
|
||||||
|
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
|
||||||
|
// --output <path> Output JSON file path (default: "antigravity_models.json")
|
||||||
|
// --pretty Pretty-print the output JSON (default: true)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
|
||||||
|
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
|
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||||
|
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logging.SetupBaseLogger()
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelOutput wraps the fetched model list with fetch metadata.
|
||||||
|
type modelOutput struct {
|
||||||
|
Models []modelEntry `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelEntry contains only the fields we want to keep for static model definitions.
|
||||||
|
type modelEntry struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
ContextLength int `json:"context_length,omitempty"`
|
||||||
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var authsDir string
|
||||||
|
var outputPath string
|
||||||
|
var pretty bool
|
||||||
|
|
||||||
|
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
|
||||||
|
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
|
||||||
|
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Resolve relative paths against the working directory.
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(authsDir) {
|
||||||
|
authsDir = filepath.Join(wd, authsDir)
|
||||||
|
}
|
||||||
|
if !filepath.IsAbs(outputPath) {
|
||||||
|
outputPath = filepath.Join(wd, outputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Scanning auth files in: %s\n", authsDir)
|
||||||
|
|
||||||
|
// Load all auth records from the directory.
|
||||||
|
fileStore := sdkauth.NewFileTokenStore()
|
||||||
|
fileStore.SetBaseDir(authsDir)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
auths, err := fileStore.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if len(auths) == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first enabled antigravity auth.
|
||||||
|
var chosen *coreauth.Auth
|
||||||
|
for _, a := range auths {
|
||||||
|
if a == nil || a.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
|
||||||
|
chosen = a
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if chosen == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
|
||||||
|
|
||||||
|
// Fetch models from the upstream Antigravity API.
|
||||||
|
fmt.Println("Fetching Antigravity model list from upstream...")
|
||||||
|
|
||||||
|
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
models := fetchModels(fetchCtx, chosen)
|
||||||
|
if len(models) == 0 {
|
||||||
|
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Fetched %d models.\n", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the output payload.
|
||||||
|
out := modelOutput{
|
||||||
|
Models: models,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal to JSON.
|
||||||
|
var raw []byte
|
||||||
|
if pretty {
|
||||||
|
raw, err = json.MarshalIndent(out, "", " ")
|
||||||
|
} else {
|
||||||
|
raw, err = json.Marshal(out)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Model list saved to: %s\n", outputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
|
||||||
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
|
||||||
|
|
||||||
|
for _, baseURL := range baseURLs {
|
||||||
|
modelsURL := baseURL + antigravityModelsPath
|
||||||
|
|
||||||
|
var payload []byte
|
||||||
|
if auth != nil && auth.Metadata != nil {
|
||||||
|
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
|
||||||
|
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
payload = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
|
||||||
|
if errReq != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
httpReq.Close = true
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent())
|
||||||
|
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
|
||||||
|
httpClient.Transport = transport
|
||||||
|
}
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
httpResp.Body.Close()
|
||||||
|
if errRead != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := gjson.GetBytes(bodyBytes, "models")
|
||||||
|
if !result.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var models []modelEntry
|
||||||
|
|
||||||
|
for originalName, modelData := range result.Map() {
|
||||||
|
modelID := strings.TrimSpace(originalName)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Skip internal/experimental models
|
||||||
|
switch modelID {
|
||||||
|
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
displayName := modelData.Get("displayName").String()
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = modelID
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := modelEntry{
|
||||||
|
ID: modelID,
|
||||||
|
Object: "model",
|
||||||
|
OwnedBy: "antigravity",
|
||||||
|
Type: "antigravity",
|
||||||
|
DisplayName: displayName,
|
||||||
|
Name: modelID,
|
||||||
|
Description: displayName,
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
|
||||||
|
entry.ContextLength = int(maxTok)
|
||||||
|
}
|
||||||
|
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
|
||||||
|
entry.MaxCompletionTokens = int(maxOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
models = append(models, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func metaStringValue(m map[string]interface{}, key string) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v, ok := m[key]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
20
cmd/mcpdebug/main.go
Normal file
20
cmd/mcpdebug/main.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Encode MCP result with empty execId
|
||||||
|
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
|
||||||
|
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
|
||||||
|
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
|
||||||
|
|
||||||
|
// Write to file for analysis
|
||||||
|
os.WriteFile("mcp_result.bin", resultBytes)
|
||||||
|
fmt.Println("Wrote mcp_result.bin")
|
||||||
|
}
|
||||||
32
cmd/protocheck/main.go
Normal file
32
cmd/protocheck/main.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ecm := cursorproto.NewMsg("ExecClientMessage")
|
||||||
|
|
||||||
|
// Try different field names
|
||||||
|
names := []string{
|
||||||
|
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
|
||||||
|
"shell_result", "shellResult",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range names {
|
||||||
|
fd := ecm.Descriptor().Fields().ByName(name)
|
||||||
|
if fd != nil {
|
||||||
|
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Field %q NOT FOUND\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List all fields
|
||||||
|
fmt.Println("\nAll fields in ExecClientMessage:")
|
||||||
|
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
|
||||||
|
f := ecm.Descriptor().Fields().Get(i)
|
||||||
|
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||||
@@ -74,14 +75,16 @@ func main() {
|
|||||||
var codexLogin bool
|
var codexLogin bool
|
||||||
var codexDeviceLogin bool
|
var codexDeviceLogin bool
|
||||||
var claudeLogin bool
|
var claudeLogin bool
|
||||||
var qwenLogin bool
|
|
||||||
var kiloLogin bool
|
var kiloLogin bool
|
||||||
var iflowLogin bool
|
var iflowLogin bool
|
||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
|
var gitlabLogin bool
|
||||||
|
var gitlabTokenLogin bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
var oauthCallbackPort int
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
var kimiLogin bool
|
var kimiLogin bool
|
||||||
|
var cursorLogin bool
|
||||||
var kiroLogin bool
|
var kiroLogin bool
|
||||||
var kiroGoogleLogin bool
|
var kiroGoogleLogin bool
|
||||||
var kiroAWSLogin bool
|
var kiroAWSLogin bool
|
||||||
@@ -92,30 +95,35 @@ func main() {
|
|||||||
var kiroIDCRegion string
|
var kiroIDCRegion string
|
||||||
var kiroIDCFlow string
|
var kiroIDCFlow string
|
||||||
var githubCopilotLogin bool
|
var githubCopilotLogin bool
|
||||||
|
var codeBuddyLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
|
var vertexImportPrefix string
|
||||||
var configPath string
|
var configPath string
|
||||||
var password string
|
var password string
|
||||||
var tuiMode bool
|
var tuiMode bool
|
||||||
var standalone bool
|
var standalone bool
|
||||||
var noIncognito bool
|
var noIncognito bool
|
||||||
var useIncognito bool
|
var useIncognito bool
|
||||||
|
var localModel bool
|
||||||
|
|
||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
|
||||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
|
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
|
||||||
|
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||||
|
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
|
||||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
@@ -126,12 +134,15 @@ func main() {
|
|||||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||||
|
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
|
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
|
||||||
flag.StringVar(&password, "password", "", "")
|
flag.StringVar(&password, "password", "", "")
|
||||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||||
|
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
|
||||||
|
|
||||||
flag.CommandLine.Usage = func() {
|
flag.CommandLine.Usage = func() {
|
||||||
out := flag.CommandLine.Output()
|
out := flag.CommandLine.Output()
|
||||||
@@ -177,6 +188,7 @@ func main() {
|
|||||||
gitStoreRemoteURL string
|
gitStoreRemoteURL string
|
||||||
gitStoreUser string
|
gitStoreUser string
|
||||||
gitStorePassword string
|
gitStorePassword string
|
||||||
|
gitStoreBranch string
|
||||||
gitStoreLocalPath string
|
gitStoreLocalPath string
|
||||||
gitStoreInst *store.GitTokenStore
|
gitStoreInst *store.GitTokenStore
|
||||||
gitStoreRoot string
|
gitStoreRoot string
|
||||||
@@ -246,6 +258,9 @@ func main() {
|
|||||||
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
||||||
gitStoreLocalPath = value
|
gitStoreLocalPath = value
|
||||||
}
|
}
|
||||||
|
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
|
||||||
|
gitStoreBranch = value
|
||||||
|
}
|
||||||
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
||||||
useObjectStore = true
|
useObjectStore = true
|
||||||
objectStoreEndpoint = value
|
objectStoreEndpoint = value
|
||||||
@@ -380,7 +395,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
||||||
authDir := filepath.Join(gitStoreRoot, "auths")
|
authDir := filepath.Join(gitStoreRoot, "auths")
|
||||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
|
||||||
gitStoreInst.SetBaseDir(authDir)
|
gitStoreInst.SetBaseDir(authDir)
|
||||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||||
log.Errorf("failed to prepare git token store: %v", errRepo)
|
log.Errorf("failed to prepare git token store: %v", errRepo)
|
||||||
@@ -499,7 +514,7 @@ func main() {
|
|||||||
|
|
||||||
if vertexImport != "" {
|
if vertexImport != "" {
|
||||||
// Handle Vertex service account import
|
// Handle Vertex service account import
|
||||||
cmd.DoVertexImport(cfg, vertexImport)
|
cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix)
|
||||||
} else if login {
|
} else if login {
|
||||||
// Handle Google/Gemini login
|
// Handle Google/Gemini login
|
||||||
cmd.DoLogin(cfg, projectID, options)
|
cmd.DoLogin(cfg, projectID, options)
|
||||||
@@ -509,6 +524,9 @@ func main() {
|
|||||||
} else if githubCopilotLogin {
|
} else if githubCopilotLogin {
|
||||||
// Handle GitHub Copilot login
|
// Handle GitHub Copilot login
|
||||||
cmd.DoGitHubCopilotLogin(cfg, options)
|
cmd.DoGitHubCopilotLogin(cfg, options)
|
||||||
|
} else if codeBuddyLogin {
|
||||||
|
// Handle CodeBuddy login
|
||||||
|
cmd.DoCodeBuddyLogin(cfg, options)
|
||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
@@ -518,16 +536,20 @@ func main() {
|
|||||||
} else if claudeLogin {
|
} else if claudeLogin {
|
||||||
// Handle Claude login
|
// Handle Claude login
|
||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
} else if qwenLogin {
|
|
||||||
cmd.DoQwenLogin(cfg, options)
|
|
||||||
} else if kiloLogin {
|
} else if kiloLogin {
|
||||||
cmd.DoKiloLogin(cfg, options)
|
cmd.DoKiloLogin(cfg, options)
|
||||||
} else if iflowLogin {
|
} else if iflowLogin {
|
||||||
cmd.DoIFlowLogin(cfg, options)
|
cmd.DoIFlowLogin(cfg, options)
|
||||||
} else if iflowCookie {
|
} else if iflowCookie {
|
||||||
cmd.DoIFlowCookieAuth(cfg, options)
|
cmd.DoIFlowCookieAuth(cfg, options)
|
||||||
|
} else if gitlabLogin {
|
||||||
|
cmd.DoGitLabLogin(cfg, options)
|
||||||
|
} else if gitlabTokenLogin {
|
||||||
|
cmd.DoGitLabTokenLogin(cfg, options)
|
||||||
} else if kimiLogin {
|
} else if kimiLogin {
|
||||||
cmd.DoKimiLogin(cfg, options)
|
cmd.DoKimiLogin(cfg, options)
|
||||||
|
} else if cursorLogin {
|
||||||
|
cmd.DoCursorLogin(cfg, options)
|
||||||
} else if kiroLogin {
|
} else if kiroLogin {
|
||||||
// For Kiro auth, default to incognito mode for multi-account support
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
// Users can explicitly override with --no-incognito
|
// Users can explicitly override with --no-incognito
|
||||||
@@ -569,10 +591,17 @@ func main() {
|
|||||||
cmd.WaitForCloudDeploy()
|
cmd.WaitForCloudDeploy()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if localModel && (!tuiMode || standalone) {
|
||||||
|
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
|
||||||
|
}
|
||||||
if tuiMode {
|
if tuiMode {
|
||||||
if standalone {
|
if standalone {
|
||||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
misc.StartAntigravityVersionUpdater(context.Background())
|
||||||
|
if !localModel {
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
}
|
||||||
hook := tui.NewLogHook(2000)
|
hook := tui.NewLogHook(2000)
|
||||||
hook.SetFormatter(&logging.LogFormatter{})
|
hook.SetFormatter(&logging.LogFormatter{})
|
||||||
log.AddHook(hook)
|
log.AddHook(hook)
|
||||||
@@ -643,15 +672,19 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
misc.StartAntigravityVersionUpdater(context.Background())
|
||||||
|
if !localModel {
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.AuthDir != "" {
|
if cfg.AuthDir != "" {
|
||||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||||
defer kiro.StopGlobalRefreshManager()
|
defer kiro.StopGlobalRefreshManager()
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ remote-management:
|
|||||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||||
disable-control-panel: false
|
disable-control-panel: false
|
||||||
|
|
||||||
|
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
|
||||||
|
# When enabled, the panel is only downloaded on first access if missing, and never auto-updated afterward.
|
||||||
|
# disable-auto-update-panel: false
|
||||||
|
|
||||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||||
|
|
||||||
@@ -68,7 +72,8 @@ error-logs-max-files: 10
|
|||||||
usage-statistics-enabled: false
|
usage-statistics-enabled: false
|
||||||
|
|
||||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||||
proxy-url: ''
|
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
|
||||||
|
proxy-url: ""
|
||||||
|
|
||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
force-model-prefix: false
|
force-model-prefix: false
|
||||||
@@ -87,26 +92,54 @@ max-retry-credentials: 0
|
|||||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||||
max-retry-interval: 30
|
max-retry-interval: 30
|
||||||
|
|
||||||
|
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||||
|
disable-cooling: false
|
||||||
|
|
||||||
|
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
|
||||||
|
# When > 0, overrides the default worker count (16).
|
||||||
|
# auth-auto-refresh-workers: 16
|
||||||
|
|
||||||
# Quota exceeded behavior
|
# Quota exceeded behavior
|
||||||
quota-exceeded:
|
quota-exceeded:
|
||||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||||
|
antigravity-credits: true # Whether to retry Antigravity quota_exhausted 429s once with enabledCreditTypes=["GOOGLE_ONE_AI"]
|
||||||
|
|
||||||
# Routing strategy for selecting credentials when multiple match.
|
# Routing strategy for selecting credentials when multiple match.
|
||||||
routing:
|
routing:
|
||||||
strategy: 'round-robin' # round-robin (default), fill-first
|
strategy: "round-robin" # round-robin (default), fill-first
|
||||||
|
# Enable universal session-sticky routing for all clients.
|
||||||
|
# Session IDs are extracted from: X-Session-ID header, Idempotency-Key,
|
||||||
|
# metadata.user_id, conversation_id, or first few messages hash.
|
||||||
|
# Automatic failover is always enabled when bound auth becomes unavailable.
|
||||||
|
session-affinity: false # default: false
|
||||||
|
# How long session-to-auth bindings are retained. Default: 1h
|
||||||
|
session-affinity-ttl: "1h"
|
||||||
|
|
||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# When true, enable Gemini CLI internal endpoints (/v1internal:*).
|
||||||
|
# Default is false for safety.
|
||||||
|
enable-gemini-cli-endpoint: false
|
||||||
|
|
||||||
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||||
nonstream-keepalive-interval: 0
|
nonstream-keepalive-interval: 0
|
||||||
|
|
||||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||||
# streaming:
|
# streaming:
|
||||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||||
|
|
||||||
|
# Signature cache validation for thinking blocks (Antigravity/Claude).
|
||||||
|
# When true (default), cached signatures are preferred and validated.
|
||||||
|
# When false, client signatures are used directly after normalization (bypass mode for testing).
|
||||||
|
# antigravity-signature-cache-enabled: true
|
||||||
|
|
||||||
|
# Bypass mode signature validation strictness (only applies when signature cache is disabled).
|
||||||
|
# When true, validates full Claude protobuf tree (Field 2 -> Field 1 structure).
|
||||||
|
# When false (default), only checks R/E prefix + base64 + first byte 0x12.
|
||||||
|
# antigravity-signature-bypass-strict: false
|
||||||
|
|
||||||
# Gemini API keys
|
# Gemini API keys
|
||||||
# gemini-api-key:
|
# gemini-api-key:
|
||||||
# - api-key: "AIzaSy...01"
|
# - api-key: "AIzaSy...01"
|
||||||
@@ -115,6 +148,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080"
|
# proxy-url: "socks5://proxy.example.com:1080"
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gemini-2.5-flash" # upstream model name
|
# - name: "gemini-2.5-flash" # upstream model name
|
||||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||||
@@ -133,6 +167,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gpt-5-codex" # upstream model name
|
# - name: "gpt-5-codex" # upstream model name
|
||||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||||
@@ -151,6 +186,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||||
@@ -169,14 +205,31 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "API"
|
# - "API"
|
||||||
# - "proxy"
|
# - "proxy"
|
||||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||||
|
# experimental-cch-signing: false # optional: default is false; when true, sign the final /v1/messages body using the current Claude Code cch algorithm
|
||||||
|
# # keep this disabled unless you explicitly need the behavior, so upstream seed changes fall back to legacy proxy behavior
|
||||||
|
|
||||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||||
# These are used as fallbacks when the client does not send its own headers.
|
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
|
||||||
|
# when the client omits them, while OS/arch remain runtime-derived. When
|
||||||
|
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
|
||||||
|
# while user-agent/package-version/runtime-version seed a software fingerprint that can
|
||||||
|
# still upgrade to newer official Claude client versions.
|
||||||
# claude-header-defaults:
|
# claude-header-defaults:
|
||||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||||
# package-version: "0.74.0"
|
# package-version: "0.74.0"
|
||||||
# runtime-version: "v24.3.0"
|
# runtime-version: "v24.3.0"
|
||||||
|
# os: "MacOS"
|
||||||
|
# arch: "arm64"
|
||||||
# timeout: "600"
|
# timeout: "600"
|
||||||
|
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
|
||||||
|
|
||||||
|
# Default headers for Codex OAuth model requests.
|
||||||
|
# These are used only for file-backed/OAuth Codex requests when the client
|
||||||
|
# does not send the header. `user-agent` applies to HTTP and websocket requests;
|
||||||
|
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
|
||||||
|
# codex-header-defaults:
|
||||||
|
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
|
||||||
|
# beta-features: "multi_agent"
|
||||||
|
|
||||||
# Kiro (AWS CodeWhisperer) configuration
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
# Note: Kiro API currently only operates in us-east-1 region
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
@@ -215,17 +268,32 @@ nonstream-keepalive-interval: 0
|
|||||||
# api-key-entries:
|
# api-key-entries:
|
||||||
# - api-key: "sk-or-v1-...b780"
|
# - api-key: "sk-or-v1-...b780"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||||
# models: # The models supported by the provider.
|
# models: # The models supported by the provider.
|
||||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||||
# alias: "kimi-k2" # The alias used in the API.
|
# alias: "kimi-k2" # The alias used in the API.
|
||||||
|
# thinking: # optional: omit to default to levels ["low","medium","high"]
|
||||||
|
# levels: ["low", "medium", "high"]
|
||||||
|
# # You may repeat the same alias to build an internal model pool.
|
||||||
|
# # The client still sees only one alias in the model list.
|
||||||
|
# # Requests to that alias will round-robin across the upstream names below,
|
||||||
|
# # and if the chosen upstream fails before producing output, the request will
|
||||||
|
# # continue with the next upstream model in the same alias pool.
|
||||||
|
# - name: "deepseek-v3.1"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
# - name: "glm-5"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
# - name: "kimi-k2.5"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
|
||||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
|
||||||
# vertex-api-key:
|
# vertex-api-key:
|
||||||
# - api-key: "vk-123..." # x-goog-api-key header
|
# - api-key: "vk-123..." # x-goog-api-key header
|
||||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# models: # optional: map aliases to upstream model names
|
# models: # optional: map aliases to upstream model names
|
||||||
@@ -233,6 +301,9 @@ nonstream-keepalive-interval: 0
|
|||||||
# alias: "vertex-flash" # client-visible alias
|
# alias: "vertex-flash" # client-visible alias
|
||||||
# - name: "gemini-2.5-pro"
|
# - name: "gemini-2.5-pro"
|
||||||
# alias: "vertex-pro"
|
# alias: "vertex-pro"
|
||||||
|
# excluded-models: # optional: models to exclude from listing
|
||||||
|
# - "imagen-3.0-generate-002"
|
||||||
|
# - "imagen-*"
|
||||||
|
|
||||||
# Amp Integration
|
# Amp Integration
|
||||||
# ampcode:
|
# ampcode:
|
||||||
@@ -270,8 +341,12 @@ nonstream-keepalive-interval: 0
|
|||||||
|
|
||||||
# Global OAuth model name aliases (per channel)
|
# Global OAuth model name aliases (per channel)
|
||||||
# These aliases rename model IDs for both model listing and request routing.
|
# These aliases rename model IDs for both model listing and request routing.
|
||||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
|
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||||
|
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||||
|
# you select the protocol surface, but inference backend selection can still follow the resolved
|
||||||
|
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
|
||||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||||
# oauth-model-alias:
|
# oauth-model-alias:
|
||||||
# antigravity:
|
# antigravity:
|
||||||
@@ -305,12 +380,6 @@ nonstream-keepalive-interval: 0
|
|||||||
# codex:
|
# codex:
|
||||||
# - name: "gpt-5"
|
# - name: "gpt-5"
|
||||||
# alias: "g5"
|
# alias: "g5"
|
||||||
# qwen:
|
|
||||||
# - name: "qwen3-coder-plus"
|
|
||||||
# alias: "qwen-plus"
|
|
||||||
# iflow:
|
|
||||||
# - name: "glm-4.7"
|
|
||||||
# alias: "glm-god"
|
|
||||||
# kimi:
|
# kimi:
|
||||||
# - name: "kimi-k2.5"
|
# - name: "kimi-k2.5"
|
||||||
# alias: "k2.5"
|
# alias: "k2.5"
|
||||||
@@ -339,10 +408,6 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "claude-3-5-haiku-20241022"
|
# - "claude-3-5-haiku-20241022"
|
||||||
# codex:
|
# codex:
|
||||||
# - "gpt-5-codex-mini"
|
# - "gpt-5-codex-mini"
|
||||||
# qwen:
|
|
||||||
# - "vision-model"
|
|
||||||
# iflow:
|
|
||||||
# - "tstars2.0"
|
|
||||||
# kimi:
|
# kimi:
|
||||||
# - "kimi-k2-thinking"
|
# - "kimi-k2-thinking"
|
||||||
# kiro:
|
# kiro:
|
||||||
|
|||||||
@@ -109,10 +109,19 @@ wait_for_service() {
|
|||||||
sleep 2
|
sleep 2
|
||||||
}
|
}
|
||||||
|
|
||||||
if [[ "${1:-}" == "--with-usage" ]]; then
|
case "${1:-}" in
|
||||||
WITH_USAGE=true
|
"")
|
||||||
export_stats_api_secret
|
;;
|
||||||
fi
|
"--with-usage")
|
||||||
|
WITH_USAGE=true
|
||||||
|
export_stats_api_secret
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Error: unknown option '${1}'. Did you mean '--with-usage'?"
|
||||||
|
echo "Usage: ./docker-build.sh [--with-usage]"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
# --- Step 1: Choose Environment ---
|
# --- Step 1: Choose Environment ---
|
||||||
echo "Please select an option:"
|
echo "Please select an option:"
|
||||||
|
|||||||
115
docs/gitlab-duo.md
Normal file
115
docs/gitlab-duo.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# GitLab Duo guide
|
||||||
|
|
||||||
|
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
|
||||||
|
|
||||||
|
It supports:
|
||||||
|
|
||||||
|
- OAuth login
|
||||||
|
- personal access token login
|
||||||
|
- automatic refresh of GitLab `direct_access` metadata
|
||||||
|
- dynamic model discovery from GitLab metadata
|
||||||
|
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
|
||||||
|
- Claude-compatible and OpenAI-compatible downstream APIs
|
||||||
|
|
||||||
|
## What this means
|
||||||
|
|
||||||
|
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
|
||||||
|
|
||||||
|
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
|
||||||
|
|
||||||
|
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
|
||||||
|
|
||||||
|
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
|
||||||
|
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
|
||||||
|
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
|
||||||
|
|
||||||
|
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
|
||||||
|
|
||||||
|
- a stable alias: `gitlab-duo`
|
||||||
|
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
|
||||||
|
|
||||||
|
## Login
|
||||||
|
|
||||||
|
OAuth login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-login
|
||||||
|
```
|
||||||
|
|
||||||
|
PAT login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-token-login
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also provide inputs through environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GITLAB_BASE_URL=https://gitlab.com
|
||||||
|
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||||
|
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||||
|
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- OAuth requires a GitLab OAuth application.
|
||||||
|
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
|
||||||
|
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
|
||||||
|
|
||||||
|
## Using the models
|
||||||
|
|
||||||
|
After login, start CLIProxyAPI normally and point your client at the local proxy.
|
||||||
|
|
||||||
|
You can select:
|
||||||
|
|
||||||
|
- `gitlab-duo` to use the current Duo-managed model for that account
|
||||||
|
- the discovered provider model name if you want to pin it explicitly
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "gitlab-duo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
|
||||||
|
|
||||||
|
## How model freshness works
|
||||||
|
|
||||||
|
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
|
||||||
|
|
||||||
|
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
|
||||||
|
|
||||||
|
This matches GitLab's current public contract better than hardcoding model names.
|
||||||
|
|
||||||
|
## Current scope
|
||||||
|
|
||||||
|
The GitLab Duo provider now has:
|
||||||
|
|
||||||
|
- OAuth and PAT auth flows
|
||||||
|
- runtime refresh of Duo gateway credentials
|
||||||
|
- native Anthropic gateway routing
|
||||||
|
- native OpenAI/Codex gateway routing
|
||||||
|
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
|
||||||
|
|
||||||
|
Still out of scope today:
|
||||||
|
|
||||||
|
- websocket or session-specific parity beyond the current HTTP APIs
|
||||||
|
- GitLab-specific IDE features that are not exposed through the public gateway contract
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||||
|
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||||
|
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||||
115
docs/gitlab-duo_CN.md
Normal file
115
docs/gitlab-duo_CN.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# GitLab Duo 使用说明
|
||||||
|
|
||||||
|
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
|
||||||
|
|
||||||
|
当前支持:
|
||||||
|
|
||||||
|
- OAuth 登录
|
||||||
|
- personal access token 登录
|
||||||
|
- 自动刷新 GitLab `direct_access` 元数据
|
||||||
|
- 根据 GitLab 返回的元数据动态发现模型
|
||||||
|
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
|
||||||
|
- Claude 兼容与 OpenAI 兼容下游 API
|
||||||
|
|
||||||
|
## 这意味着什么
|
||||||
|
|
||||||
|
如果 GitLab Duo 返回的是 Anthropic 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
|
||||||
|
|
||||||
|
如果 GitLab Duo 返回的是 OpenAI 托管模型,CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
|
||||||
|
|
||||||
|
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider:
|
||||||
|
|
||||||
|
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
|
||||||
|
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
|
||||||
|
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
|
||||||
|
|
||||||
|
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
|
||||||
|
|
||||||
|
- 一个稳定别名:`gitlab-duo`
|
||||||
|
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5` 或 `gpt-5-codex`
|
||||||
|
|
||||||
|
## 登录
|
||||||
|
|
||||||
|
OAuth 登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-login
|
||||||
|
```
|
||||||
|
|
||||||
|
PAT 登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./CLIProxyAPI -gitlab-token-login
|
||||||
|
```
|
||||||
|
|
||||||
|
也可以通过环境变量提供输入:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GITLAB_BASE_URL=https://gitlab.com
|
||||||
|
export GITLAB_OAUTH_CLIENT_ID=your-client-id
|
||||||
|
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
|
||||||
|
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
|
||||||
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
|
||||||
|
- OAuth 方式需要一个 GitLab OAuth application。
|
||||||
|
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上,`api` scope 是最稳妥的基线。
|
||||||
|
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
|
||||||
|
|
||||||
|
## 如何使用模型
|
||||||
|
|
||||||
|
登录完成后,正常启动 CLIProxyAPI,并让客户端连接到本地代理。
|
||||||
|
|
||||||
|
你可以选择:
|
||||||
|
|
||||||
|
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
|
||||||
|
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8080/v1/chat/completions \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "gitlab-duo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
如果该 GitLab 账号当前绑定的是 Anthropic 模型,Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型,OpenAI 兼容客户端可以通过 `/v1/chat/completions` 或 `/v1/responses` 使用它。
|
||||||
|
|
||||||
|
## 模型如何保持最新
|
||||||
|
|
||||||
|
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
|
||||||
|
|
||||||
|
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
|
||||||
|
|
||||||
|
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
|
||||||
|
|
||||||
|
## 当前覆盖范围
|
||||||
|
|
||||||
|
GitLab Duo Provider 目前已经具备:
|
||||||
|
|
||||||
|
- OAuth 和 PAT 登录流程
|
||||||
|
- Duo gateway 凭据的运行时刷新
|
||||||
|
- Anthropic gateway 原生路由
|
||||||
|
- OpenAI/Codex gateway 原生路由
|
||||||
|
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
|
||||||
|
|
||||||
|
当前仍未覆盖:
|
||||||
|
|
||||||
|
- websocket 或 session 级别的完全对齐
|
||||||
|
- GitLab 公开 gateway 契约之外的 IDE 专有能力
|
||||||
|
|
||||||
|
## 参考资料
|
||||||
|
|
||||||
|
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
|
||||||
|
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
|
||||||
|
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/
|
||||||
@@ -52,11 +52,11 @@ func init() {
|
|||||||
sdktr.Register(fOpenAI, fMyProv,
|
sdktr.Register(fOpenAI, fMyProv,
|
||||||
func(model string, raw []byte, stream bool) []byte { return raw },
|
func(model string, raw []byte, stream bool) []byte { return raw },
|
||||||
sdktr.ResponseTransform{
|
sdktr.ResponseTransform{
|
||||||
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
|
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
|
||||||
return []string{string(raw)}
|
return [][]byte{raw}
|
||||||
},
|
},
|
||||||
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
|
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
|
||||||
return string(raw)
|
return raw
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -83,6 +83,7 @@ require (
|
|||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/muesli/termenv v0.16.0 // indirect
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
|
github.com/pierrec/xxHash v0.1.5
|
||||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/rs/xid v1.5.0 // indirect
|
github.com/rs/xid v1.5.0 // indirect
|
||||||
@@ -91,8 +92,8 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -154,6 +154,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
|||||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
|
github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo=
|
||||||
|
github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
|
||||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,13 +13,13 @@ import (
|
|||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/oauth2/google"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultAPICallTimeout = 60 * time.Second
|
const defaultAPICallTimeout = 60 * time.Second
|
||||||
@@ -702,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||||
}
|
}
|
||||||
|
if h != nil && h.cfg != nil {
|
||||||
|
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
|
||||||
|
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if h != nil && h.cfg != nil {
|
if h != nil && h.cfg != nil {
|
||||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||||
@@ -724,50 +728,132 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
type apiKeyConfigEntry interface {
|
||||||
proxyStr = strings.TrimSpace(proxyStr)
|
GetAPIKey() string
|
||||||
if proxyStr == "" {
|
GetBaseURL() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
|
||||||
|
if auth == nil || len(entries) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
attrKey, attrBase := "", ""
|
||||||
proxyURL, errParse := url.Parse(proxyStr)
|
if auth.Attributes != nil {
|
||||||
if errParse != nil {
|
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
for i := range entries {
|
||||||
log.Debug("proxy URL missing scheme/host")
|
entry := &entries[i]
|
||||||
return nil
|
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
|
||||||
}
|
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
|
||||||
|
if attrKey != "" && attrBase != "" {
|
||||||
if proxyURL.Scheme == "socks5" {
|
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||||
var proxyAuth *proxy.Auth
|
return entry
|
||||||
if proxyURL.User != nil {
|
}
|
||||||
username := proxyURL.User.Username()
|
continue
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
}
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||||
if errSOCKS5 != nil {
|
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
return entry
|
||||||
return nil
|
}
|
||||||
}
|
}
|
||||||
return &http.Transport{
|
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||||
Proxy: nil,
|
return entry
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if attrKey != "" {
|
||||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
for i := range entries {
|
||||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
entry := &entries[i]
|
||||||
|
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
authKind, authAccount := auth.AccountInfo()
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := auth.Attributes
|
||||||
|
compatName := ""
|
||||||
|
providerKey := ""
|
||||||
|
if len(attrs) > 0 {
|
||||||
|
compatName = strings.TrimSpace(attrs["compat_name"])
|
||||||
|
providerKey = strings.TrimSpace(attrs["provider_key"])
|
||||||
|
}
|
||||||
|
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||||
|
return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
|
||||||
|
case "gemini":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
case "claude":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
case "codex":
|
||||||
|
if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
|
||||||
|
if cfg == nil || auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
apiKey = strings.TrimSpace(apiKey)
|
||||||
|
if apiKey == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
candidates := make([]string, 0, 3)
|
||||||
|
if v := strings.TrimSpace(compatName); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(providerKey); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(auth.Provider); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range cfg.OpenAICompatibility {
|
||||||
|
compat := &cfg.OpenAICompatibility[i]
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
||||||
|
for j := range compat.APIKeyEntries {
|
||||||
|
entry := &compat.APIKeyEntries[j]
|
||||||
|
if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
|
||||||
|
return strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||||
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||||
|
if errBuild != nil {
|
||||||
|
log.WithError(errBuild).Debug("build proxy transport failed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||||
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
||||||
if len(headers) == 0 {
|
if len(headers) == 0 {
|
||||||
|
|||||||
@@ -2,172 +2,211 @@ package management
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type memoryAuthStore struct {
|
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
mu sync.Mutex
|
t.Parallel()
|
||||||
items map[string]*coreauth.Auth
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
h := &Handler{
|
||||||
_ = ctx
|
cfg: &config.Config{
|
||||||
s.mu.Lock()
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
defer s.mu.Unlock()
|
|
||||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
|
||||||
for _, a := range s.items {
|
|
||||||
out = append(out, a.Clone())
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
|
||||||
_ = ctx
|
|
||||||
if auth == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
if s.items == nil {
|
|
||||||
s.items = make(map[string]*coreauth.Auth)
|
|
||||||
}
|
|
||||||
s.items[auth.ID] = auth.Clone()
|
|
||||||
s.mu.Unlock()
|
|
||||||
return auth.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
|
||||||
_ = ctx
|
|
||||||
s.mu.Lock()
|
|
||||||
delete(s.items, id)
|
|
||||||
s.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
t.Fatalf("expected POST, got %s", r.Method)
|
|
||||||
}
|
|
||||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
|
||||||
t.Fatalf("unexpected content-type: %s", ct)
|
|
||||||
}
|
|
||||||
bodyBytes, _ := io.ReadAll(r.Body)
|
|
||||||
_ = r.Body.Close()
|
|
||||||
values, err := url.ParseQuery(string(bodyBytes))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse form: %v", err)
|
|
||||||
}
|
|
||||||
if values.Get("grant_type") != "refresh_token" {
|
|
||||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
|
||||||
}
|
|
||||||
if values.Get("refresh_token") != "rt" {
|
|
||||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
|
||||||
}
|
|
||||||
if values.Get("client_id") != antigravityOAuthClientID {
|
|
||||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
|
||||||
}
|
|
||||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
|
||||||
t.Fatalf("unexpected client_secret")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"access_token": "new-token",
|
|
||||||
"refresh_token": "rt2",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"token_type": "Bearer",
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
store := &memoryAuthStore{}
|
|
||||||
manager := coreauth.NewManager(store, nil, nil)
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-test.json",
|
|
||||||
FileName: "antigravity-test.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "old-token",
|
|
||||||
"refresh_token": "rt",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
|
||||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
|
||||||
t.Fatalf("register auth: %v", err)
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
if httpTransport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
GeminiKey: []config.GeminiKey{{
|
||||||
|
APIKey: "gemini-key",
|
||||||
|
ProxyURL: "http://gemini-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
ClaudeKey: []config.ClaudeKey{{
|
||||||
|
APIKey: "claude-key",
|
||||||
|
ProxyURL: "http://claude-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
CodexKey: []config.CodexKey{{
|
||||||
|
APIKey: "codex-key",
|
||||||
|
ProxyURL: "http://codex-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
OpenAICompatibility: []config.OpenAICompatibility{{
|
||||||
|
Name: "bohe",
|
||||||
|
BaseURL: "https://bohe.example.com",
|
||||||
|
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
|
||||||
|
APIKey: "compat-key",
|
||||||
|
ProxyURL: "http://compat-proxy.example.com:8080",
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
auth *coreauth.Auth
|
||||||
|
wantProxy string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "gemini",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "gemini",
|
||||||
|
Attributes: map[string]string{"api_key": "gemini-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://gemini-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claude",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{"api_key": "claude-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://claude-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "codex",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "codex",
|
||||||
|
Attributes: map[string]string{"api_key": "codex-key"},
|
||||||
|
},
|
||||||
|
wantProxy: "http://codex-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openai-compatibility",
|
||||||
|
auth: &coreauth.Auth{
|
||||||
|
Provider: "bohe",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "compat-key",
|
||||||
|
"compat_name": "bohe",
|
||||||
|
"provider_key": "bohe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantProxy: "http://compat-proxy.example.com:8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
transport := h.apiCallTransport(tc.auth)
|
||||||
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != tc.wantProxy {
|
||||||
|
t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
geminiAuth := &coreauth.Auth{
|
||||||
|
ID: "gemini:apikey:123",
|
||||||
|
Provider: "gemini",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "shared-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
compatAuth := &coreauth.Auth{
|
||||||
|
ID: "openai-compatibility:bohe:456",
|
||||||
|
Provider: "bohe",
|
||||||
|
Label: "bohe",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "shared-key",
|
||||||
|
"compat_name": "bohe",
|
||||||
|
"provider_key": "bohe",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register gemini auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register compat auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiIndex := geminiAuth.EnsureIndex()
|
||||||
|
compatIndex := compatAuth.EnsureIndex()
|
||||||
|
if geminiIndex == compatIndex {
|
||||||
|
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &Handler{authManager: manager}
|
h := &Handler{authManager: manager}
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
|
||||||
if err != nil {
|
gotGemini := h.authByIndex(geminiIndex)
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
if gotGemini == nil {
|
||||||
|
t.Fatal("expected gemini auth by index")
|
||||||
}
|
}
|
||||||
if token != "new-token" {
|
if gotGemini.ID != geminiAuth.ID {
|
||||||
t.Fatalf("expected refreshed token, got %q", token)
|
t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
|
||||||
}
|
|
||||||
if callCount != 1 {
|
|
||||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, ok := manager.GetByID(auth.ID)
|
gotCompat := h.authByIndex(compatIndex)
|
||||||
if !ok || updated == nil {
|
if gotCompat == nil {
|
||||||
t.Fatalf("expected auth in manager after update")
|
t.Fatal("expected compat auth by index")
|
||||||
}
|
}
|
||||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
if gotCompat.ID != compatAuth.ID {
|
||||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-valid.json",
|
|
||||||
FileName: "antigravity-valid.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "ok-token",
|
|
||||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
h := &Handler{}
|
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
|
||||||
}
|
|
||||||
if token != "ok-token" {
|
|
||||||
t.Fatalf("expected existing token, got %q", token)
|
|
||||||
}
|
|
||||||
if callCount != 0 {
|
|
||||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
197
internal/api/handlers/management/auth_files_batch_test.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
|
||||||
|
files := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
|
||||||
|
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
for _, file := range files {
|
||||||
|
part, err := writer.CreateFormFile("file", file.name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create multipart file: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||||
|
t.Fatalf("failed to write multipart content: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close multipart writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.UploadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
|
||||||
|
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
fullPath := filepath.Join(authDir, file.name)
|
||||||
|
data, err := os.ReadFile(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
|
||||||
|
}
|
||||||
|
if string(data) != file.content {
|
||||||
|
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auths := manager.List()
|
||||||
|
if len(auths) != len(files) {
|
||||||
|
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
|
||||||
|
existingName := "alpha.json"
|
||||||
|
existingContent := `{"type":"codex","email":"alpha@example.com"}`
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to seed existing auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
files := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{name: existingName, content: `{"type":"codex"`},
|
||||||
|
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
for _, file := range files {
|
||||||
|
part, err := writer.CreateFormFile("file", file.name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create multipart file: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = part.Write([]byte(file.content)); err != nil {
|
||||||
|
t.Fatalf("failed to write multipart content: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close multipart writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.UploadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusMultiStatus {
|
||||||
|
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(authDir, existingName))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected existing auth file to remain readable: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != existingContent {
|
||||||
|
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected valid auth file to be created: %v", err)
|
||||||
|
}
|
||||||
|
if string(betaData) != files[1].content {
|
||||||
|
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
files := []string{"alpha.json", "beta.json"}
|
||||||
|
for _, name := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(
|
||||||
|
http.MethodDelete,
|
||||||
|
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
ctx.Request = req
|
||||||
|
|
||||||
|
h.DeleteAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
|
||||||
|
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range files {
|
||||||
|
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tempDir, "auth")
|
||||||
|
externalDir := filepath.Join(tempDir, "external")
|
||||||
|
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||||
|
}
|
||||||
|
if errMkdirExternal := os.MkdirAll(externalDir, 0o700); errMkdirExternal != nil {
|
||||||
|
t.Fatalf("failed to create external dir: %v", errMkdirExternal)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := "codex-user@example.com-plus.json"
|
||||||
|
shadowPath := filepath.Join(authDir, fileName)
|
||||||
|
realPath := filepath.Join(externalDir, fileName)
|
||||||
|
if errWriteShadow := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); errWriteShadow != nil {
|
||||||
|
t.Fatalf("failed to write shadow file: %v", errWriteShadow)
|
||||||
|
}
|
||||||
|
if errWriteReal := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); errWriteReal != nil {
|
||||||
|
t.Fatalf("failed to write real file: %v", errWriteReal)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "legacy/" + fileName,
|
||||||
|
FileName: fileName,
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusError,
|
||||||
|
Unavailable: true,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": realPath,
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "codex",
|
||||||
|
"email": "real@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
deleteRec := httptest.NewRecorder()
|
||||||
|
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||||
|
deleteCtx.Request = deleteReq
|
||||||
|
h.DeleteAuthFile(deleteCtx)
|
||||||
|
|
||||||
|
if deleteRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||||
|
}
|
||||||
|
if _, errStatReal := os.Stat(realPath); !os.IsNotExist(errStatReal) {
|
||||||
|
t.Fatalf("expected managed auth file to be removed, stat err: %v", errStatReal)
|
||||||
|
}
|
||||||
|
if _, errStatShadow := os.Stat(shadowPath); errStatShadow != nil {
|
||||||
|
t.Fatalf("expected shadow auth file to remain, stat err: %v", errStatShadow)
|
||||||
|
}
|
||||||
|
|
||||||
|
listRec := httptest.NewRecorder()
|
||||||
|
listCtx, _ := gin.CreateTestContext(listRec)
|
||||||
|
listReq := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||||
|
listCtx.Request = listReq
|
||||||
|
h.ListAuthFiles(listCtx)
|
||||||
|
|
||||||
|
if listRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String())
|
||||||
|
}
|
||||||
|
var listPayload map[string]any
|
||||||
|
if errUnmarshal := json.Unmarshal(listRec.Body.Bytes(), &listPayload); errUnmarshal != nil {
|
||||||
|
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||||
|
}
|
||||||
|
filesRaw, ok := listPayload["files"].([]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected files array, payload: %#v", listPayload)
|
||||||
|
}
|
||||||
|
if len(filesRaw) != 0 {
|
||||||
|
t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(filesRaw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
fileName := "fallback-user.json"
|
||||||
|
filePath := filepath.Join(authDir, fileName)
|
||||||
|
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex"}`), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||||
|
h.tokenStore = &memoryAuthStore{}
|
||||||
|
|
||||||
|
deleteRec := httptest.NewRecorder()
|
||||||
|
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||||
|
deleteCtx.Request = deleteReq
|
||||||
|
h.DeleteAuthFile(deleteCtx)
|
||||||
|
|
||||||
|
if deleteRec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||||
|
}
|
||||||
|
if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) {
|
||||||
|
t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", errStat)
|
||||||
|
}
|
||||||
|
}
|
||||||
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
62
internal/api/handlers/management/auth_files_download_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
fileName := "download-user.json"
|
||||||
|
expected := []byte(`{"type":"codex"}`)
|
||||||
|
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write auth file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := rec.Body.Bytes(); string(got) != string(expected) {
|
||||||
|
t.Fatalf("unexpected download content: %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
|
||||||
|
|
||||||
|
for _, name := range []string{
|
||||||
|
"../external/secret.json",
|
||||||
|
`..\\external\\secret.json`,
|
||||||
|
"nested/secret.json",
|
||||||
|
`nested\\secret.json`,
|
||||||
|
} {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
authDir := filepath.Join(tempDir, "auth")
|
||||||
|
externalDir := filepath.Join(tempDir, "external")
|
||||||
|
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("failed to create auth dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(externalDir, 0o700); err != nil {
|
||||||
|
t.Fatalf("failed to create external dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secretName := "secret.json"
|
||||||
|
secretPath := filepath.Join(externalDir, secretName)
|
||||||
|
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write external file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(
|
||||||
|
http.MethodGet,
|
||||||
|
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
h.DownloadAuthFile(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
164
internal/api/handlers/management/auth_files_gitlab_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
|
||||||
|
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v4/user":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": 42,
|
||||||
|
"username": "gitlab-user",
|
||||||
|
"name": "GitLab User",
|
||||||
|
"email": "gitlab@example.com",
|
||||||
|
})
|
||||||
|
case "/api/v4/personal_access_tokens/self":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": 7,
|
||||||
|
"name": "management-center",
|
||||||
|
"scopes": []string{"api", "read_user"},
|
||||||
|
"user_id": 42,
|
||||||
|
})
|
||||||
|
case "/api/v4/code_suggestions/direct_access":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"base_url": "https://cloud.gitlab.example.com",
|
||||||
|
"token": "gateway-token",
|
||||||
|
"expires_at": 1893456000,
|
||||||
|
"headers": map[string]string{
|
||||||
|
"X-Gitlab-Realm": "saas",
|
||||||
|
},
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
|
||||||
|
h.tokenStore = store
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.RequestGitLabPATToken(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
if got := resp["status"]; got != "ok" {
|
||||||
|
t.Fatalf("status = %#v, want ok", got)
|
||||||
|
}
|
||||||
|
if got := resp["model_provider"]; got != "anthropic" {
|
||||||
|
t.Fatalf("model_provider = %#v, want anthropic", got)
|
||||||
|
}
|
||||||
|
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
store.mu.Lock()
|
||||||
|
defer store.mu.Unlock()
|
||||||
|
if len(store.items) != 1 {
|
||||||
|
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
|
||||||
|
}
|
||||||
|
var saved *coreauth.Auth
|
||||||
|
for _, item := range store.items {
|
||||||
|
saved = item
|
||||||
|
}
|
||||||
|
if saved == nil {
|
||||||
|
t.Fatal("expected saved auth record")
|
||||||
|
}
|
||||||
|
if saved.Provider != "gitlab" {
|
||||||
|
t.Fatalf("provider = %q, want gitlab", saved.Provider)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||||
|
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["model_provider"]; got != "anthropic" {
|
||||||
|
t.Fatalf("saved model_provider = %#v, want anthropic", got)
|
||||||
|
}
|
||||||
|
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
|
||||||
|
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
authDir := t.TempDir()
|
||||||
|
state := "gitlab-state-123"
|
||||||
|
RegisterOAuthSession(state, "gitlab")
|
||||||
|
t.Cleanup(func() { CompleteOAuthSession(state) })
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
|
||||||
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.PostOAuthCallback(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
|
||||||
|
data, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read callback file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]string
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
t.Fatalf("decode callback payload: %v", err)
|
||||||
|
}
|
||||||
|
if got := payload["code"]; got != "test-code" {
|
||||||
|
t.Fatalf("callback code = %q, want test-code", got)
|
||||||
|
}
|
||||||
|
if got := payload["state"]; got != state {
|
||||||
|
t.Fatalf("callback state = %q, want %q", got, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
|
||||||
|
provider, err := NormalizeOAuthProvider("gitlab")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
|
||||||
|
}
|
||||||
|
if provider != "gitlab" {
|
||||||
|
t.Fatalf("provider = %q, want gitlab", provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
164
internal/api/handlers/management/auth_files_patch_fields_test.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPatchAuthFileFields_MergeHeadersAndDeleteEmptyValues(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "test.json",
|
||||||
|
FileName: "test.json",
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": "/tmp/test.json",
|
||||||
|
"header:X-Old": "old",
|
||||||
|
"header:X-Remove": "gone",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "claude",
|
||||||
|
"headers": map[string]any{
|
||||||
|
"X-Old": "old",
|
||||||
|
"X-Remove": "gone",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||||
|
|
||||||
|
body := `{"name":"test.json","prefix":"p1","proxy_url":"http://proxy.local","headers":{"X-Old":"new","X-New":"v","X-Remove":" ","X-Nope":""}}`
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
h.PatchAuthFileFields(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID("test.json")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth record to exist after patch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated.Prefix != "p1" {
|
||||||
|
t.Fatalf("prefix = %q, want %q", updated.Prefix, "p1")
|
||||||
|
}
|
||||||
|
if updated.ProxyURL != "http://proxy.local" {
|
||||||
|
t.Fatalf("proxy_url = %q, want %q", updated.ProxyURL, "http://proxy.local")
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated.Metadata == nil {
|
||||||
|
t.Fatalf("expected metadata to be non-nil")
|
||||||
|
}
|
||||||
|
if got, _ := updated.Metadata["prefix"].(string); got != "p1" {
|
||||||
|
t.Fatalf("metadata.prefix = %q, want %q", got, "p1")
|
||||||
|
}
|
||||||
|
if got, _ := updated.Metadata["proxy_url"].(string); got != "http://proxy.local" {
|
||||||
|
t.Fatalf("metadata.proxy_url = %q, want %q", got, "http://proxy.local")
|
||||||
|
}
|
||||||
|
|
||||||
|
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
raw, _ := json.Marshal(updated.Metadata["headers"])
|
||||||
|
t.Fatalf("metadata.headers = %T (%s), want map[string]any", updated.Metadata["headers"], string(raw))
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-Old"]; got != "new" {
|
||||||
|
t.Fatalf("metadata.headers.X-Old = %#v, want %q", got, "new")
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-New"]; got != "v" {
|
||||||
|
t.Fatalf("metadata.headers.X-New = %#v, want %q", got, "v")
|
||||||
|
}
|
||||||
|
if _, ok := headersMeta["X-Remove"]; ok {
|
||||||
|
t.Fatalf("expected metadata.headers.X-Remove to be deleted")
|
||||||
|
}
|
||||||
|
if _, ok := headersMeta["X-Nope"]; ok {
|
||||||
|
t.Fatalf("expected metadata.headers.X-Nope to be absent")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := updated.Attributes["header:X-Old"]; got != "new" {
|
||||||
|
t.Fatalf("attrs header:X-Old = %q, want %q", got, "new")
|
||||||
|
}
|
||||||
|
if got := updated.Attributes["header:X-New"]; got != "v" {
|
||||||
|
t.Fatalf("attrs header:X-New = %q, want %q", got, "v")
|
||||||
|
}
|
||||||
|
if _, ok := updated.Attributes["header:X-Remove"]; ok {
|
||||||
|
t.Fatalf("expected attrs header:X-Remove to be deleted")
|
||||||
|
}
|
||||||
|
if _, ok := updated.Attributes["header:X-Nope"]; ok {
|
||||||
|
t.Fatalf("expected attrs header:X-Nope to be absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatchAuthFileFields_HeadersEmptyMapIsNoop(t *testing.T) {
|
||||||
|
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: "noop.json",
|
||||||
|
FileName: "noop.json",
|
||||||
|
Provider: "claude",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"path": "/tmp/noop.json",
|
||||||
|
"header:X-Kee": "1",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "claude",
|
||||||
|
"headers": map[string]any{
|
||||||
|
"X-Kee": "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||||
|
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||||
|
|
||||||
|
body := `{"name":"noop.json","note":"hello","headers":{}}`
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
ctx.Request = req
|
||||||
|
h.PatchAuthFileFields(ctx)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID("noop.json")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth record to exist after patch")
|
||||||
|
}
|
||||||
|
if got := updated.Attributes["header:X-Kee"]; got != "1" {
|
||||||
|
t.Fatalf("attrs header:X-Kee = %q, want %q", got, "1")
|
||||||
|
}
|
||||||
|
headersMeta, ok := updated.Metadata["headers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected metadata.headers to remain a map, got %T", updated.Metadata["headers"])
|
||||||
|
}
|
||||||
|
if got := headersMeta["X-Kee"]; got != "1" {
|
||||||
|
t.Fatalf("metadata.headers.X-Kee = %#v, want %q", got, "1")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -214,19 +214,46 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.GeminiKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||||
|
for _, v := range h.cfg.GeminiKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
if len(out) != len(h.cfg.GeminiKey) {
|
||||||
|
h.cfg.GeminiKey = out
|
||||||
|
h.cfg.SanitizeGeminiKeys()
|
||||||
|
h.persist(c)
|
||||||
|
} else {
|
||||||
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if len(out) != len(h.cfg.GeminiKey) {
|
|
||||||
h.cfg.GeminiKey = out
|
matchIndex := -1
|
||||||
h.cfg.SanitizeGeminiKeys()
|
matchCount := 0
|
||||||
h.persist(c)
|
for i := range h.cfg.GeminiKey {
|
||||||
} else {
|
if strings.TrimSpace(h.cfg.GeminiKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount == 0 {
|
||||||
c.JSON(404, gin.H{"error": "item not found"})
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:matchIndex], h.cfg.GeminiKey[matchIndex+1:]...)
|
||||||
|
h.cfg.SanitizeGeminiKeys()
|
||||||
|
h.persist(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if idxStr := c.Query("index"); idxStr != "" {
|
if idxStr := c.Query("index"); idxStr != "" {
|
||||||
@@ -335,14 +362,39 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||||
if val := c.Query("api-key"); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.ClaudeKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||||
|
for _, v := range h.cfg.ClaudeKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.ClaudeKey = out
|
||||||
|
h.cfg.SanitizeClaudeKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.ClaudeKey {
|
||||||
|
if strings.TrimSpace(h.cfg.ClaudeKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:matchIndex], h.cfg.ClaudeKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.ClaudeKey = out
|
|
||||||
h.cfg.SanitizeClaudeKeys()
|
h.cfg.SanitizeClaudeKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -509,19 +561,24 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
for i := range arr {
|
for i := range arr {
|
||||||
normalizeVertexCompatKey(&arr[i])
|
normalizeVertexCompatKey(&arr[i])
|
||||||
|
if arr[i].APIKey == "" {
|
||||||
|
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
h.cfg.VertexCompatAPIKey = arr
|
h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
}
|
}
|
||||||
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||||
type vertexCompatPatch struct {
|
type vertexCompatPatch struct {
|
||||||
APIKey *string `json:"api-key"`
|
APIKey *string `json:"api-key"`
|
||||||
Prefix *string `json:"prefix"`
|
Prefix *string `json:"prefix"`
|
||||||
BaseURL *string `json:"base-url"`
|
BaseURL *string `json:"base-url"`
|
||||||
ProxyURL *string `json:"proxy-url"`
|
ProxyURL *string `json:"proxy-url"`
|
||||||
Headers *map[string]string `json:"headers"`
|
Headers *map[string]string `json:"headers"`
|
||||||
Models *[]config.VertexCompatModel `json:"models"`
|
Models *[]config.VertexCompatModel `json:"models"`
|
||||||
|
ExcludedModels *[]string `json:"excluded-models"`
|
||||||
}
|
}
|
||||||
var body struct {
|
var body struct {
|
||||||
Index *int `json:"index"`
|
Index *int `json:"index"`
|
||||||
@@ -585,6 +642,9 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
|||||||
if body.Value.Models != nil {
|
if body.Value.Models != nil {
|
||||||
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
||||||
}
|
}
|
||||||
|
if body.Value.ExcludedModels != nil {
|
||||||
|
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||||
|
}
|
||||||
normalizeVertexCompatKey(&entry)
|
normalizeVertexCompatKey(&entry)
|
||||||
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
@@ -593,13 +653,38 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
||||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.VertexCompatAPIKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
||||||
|
for _, v := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.VertexCompatAPIKey = out
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if strings.TrimSpace(h.cfg.VertexCompatAPIKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:matchIndex], h.cfg.VertexCompatAPIKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.VertexCompatAPIKey = out
|
|
||||||
h.cfg.SanitizeVertexCompatKeys()
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -911,14 +996,39 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||||
if val := c.Query("api-key"); val != "" {
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
if baseRaw, okBase := c.GetQuery("base-url"); okBase {
|
||||||
for _, v := range h.cfg.CodexKey {
|
base := strings.TrimSpace(baseRaw)
|
||||||
if v.APIKey != val {
|
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||||
|
for _, v := range h.cfg.CodexKey {
|
||||||
|
if strings.TrimSpace(v.APIKey) == val && strings.TrimSpace(v.BaseURL) == base {
|
||||||
|
continue
|
||||||
|
}
|
||||||
out = append(out, v)
|
out = append(out, v)
|
||||||
}
|
}
|
||||||
|
h.cfg.CodexKey = out
|
||||||
|
h.cfg.SanitizeCodexKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
matchIndex := -1
|
||||||
|
matchCount := 0
|
||||||
|
for i := range h.cfg.CodexKey {
|
||||||
|
if strings.TrimSpace(h.cfg.CodexKey[i].APIKey) == val {
|
||||||
|
matchCount++
|
||||||
|
if matchIndex == -1 {
|
||||||
|
matchIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchCount > 1 {
|
||||||
|
c.JSON(400, gin.H{"error": "multiple items match api-key; base-url is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if matchIndex != -1 {
|
||||||
|
h.cfg.CodexKey = append(h.cfg.CodexKey[:matchIndex], h.cfg.CodexKey[matchIndex+1:]...)
|
||||||
}
|
}
|
||||||
h.cfg.CodexKey = out
|
|
||||||
h.cfg.SanitizeCodexKeys()
|
h.cfg.SanitizeCodexKeys()
|
||||||
h.persist(c)
|
h.persist(c)
|
||||||
return
|
return
|
||||||
@@ -1029,6 +1139,7 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
|
|||||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||||
|
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
if len(entry.Models) == 0 {
|
if len(entry.Models) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeTestConfigFile(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
if errWrite := os.WriteFile(path, []byte("{}\n"), 0o600); errWrite != nil {
|
||||||
|
t.Fatalf("failed to write test config: %v", errWrite)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGeminiKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
GeminiKey: []config.GeminiKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key", nil)
|
||||||
|
|
||||||
|
h.DeleteGeminiKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.GeminiKey); got != 2 {
|
||||||
|
t.Fatalf("gemini keys len = %d, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteGeminiKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
GeminiKey: []config.GeminiKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/gemini-api-key?api-key=shared-key&base-url=https://a.example.com", nil)
|
||||||
|
|
||||||
|
h.DeleteGeminiKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.GeminiKey); got != 1 {
|
||||||
|
t.Fatalf("gemini keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.GeminiKey[0].BaseURL; got != "https://b.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://b.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteClaudeKey_DeletesEmptyBaseURLWhenExplicitlyProvided(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: ""},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://claude.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/claude-api-key?api-key=shared-key&base-url=", nil)
|
||||||
|
|
||||||
|
h.DeleteClaudeKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.ClaudeKey); got != 1 {
|
||||||
|
t.Fatalf("claude keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.ClaudeKey[0].BaseURL; got != "https://claude.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://claude.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteVertexCompatKey_DeletesOnlyMatchingBaseURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/vertex-api-key?api-key=shared-key&base-url=https://b.example.com", nil)
|
||||||
|
|
||||||
|
h.DeleteVertexCompatKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.VertexCompatAPIKey); got != 1 {
|
||||||
|
t.Fatalf("vertex keys len = %d, want 1", got)
|
||||||
|
}
|
||||||
|
if got := h.cfg.VertexCompatAPIKey[0].BaseURL; got != "https://a.example.com" {
|
||||||
|
t.Fatalf("remaining base-url = %q, want %q", got, "https://a.example.com")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCodexKey_RequiresBaseURLWhenAPIKeyDuplicated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := &Handler{
|
||||||
|
cfg: &config.Config{
|
||||||
|
CodexKey: []config.CodexKey{
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://a.example.com"},
|
||||||
|
{APIKey: "shared-key", BaseURL: "https://b.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
configFilePath: writeTestConfigFile(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/codex-api-key?api-key=shared-key", nil)
|
||||||
|
|
||||||
|
h.DeleteCodexKey(c)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.cfg.CodexKey); got != 2 {
|
||||||
|
t.Fatalf("codex keys len = %d, want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -228,14 +228,12 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
|||||||
return "anthropic", nil
|
return "anthropic", nil
|
||||||
case "codex", "openai":
|
case "codex", "openai":
|
||||||
return "codex", nil
|
return "codex", nil
|
||||||
|
case "gitlab":
|
||||||
|
return "gitlab", nil
|
||||||
case "gemini", "google":
|
case "gemini", "google":
|
||||||
return "gemini", nil
|
return "gemini", nil
|
||||||
case "iflow", "i-flow":
|
|
||||||
return "iflow", nil
|
|
||||||
case "antigravity", "anti-gravity":
|
case "antigravity", "anti-gravity":
|
||||||
return "antigravity", nil
|
return "antigravity", nil
|
||||||
case "qwen":
|
|
||||||
return "qwen", nil
|
|
||||||
case "kiro":
|
case "kiro":
|
||||||
return "kiro", nil
|
return "kiro", nil
|
||||||
case "github":
|
case "github":
|
||||||
|
|||||||
49
internal/api/handlers/management/test_store_test.go
Normal file
49
internal/api/handlers/management/test_store_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryAuthStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
items map[string]*coreauth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||||
|
for _, item := range s.items {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.items == nil {
|
||||||
|
s.items = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
s.items[auth.ID] = auth
|
||||||
|
return auth.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.items, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) SetBaseDir(string) {}
|
||||||
@@ -15,6 +15,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
|
||||||
|
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
|
||||||
|
|
||||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||||
type RequestInfo struct {
|
type RequestInfo struct {
|
||||||
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
if len(apiResponse) > 0 {
|
if len(apiResponse) > 0 {
|
||||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||||
}
|
}
|
||||||
|
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
|
||||||
|
if len(apiWebsocketTimeline) > 0 {
|
||||||
|
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
|
||||||
|
}
|
||||||
if err := w.streamWriter.Close(); err != nil {
|
if err := w.streamWriter.Close(); err != nil {
|
||||||
w.streamWriter = nil
|
w.streamWriter = nil
|
||||||
return err
|
return err
|
||||||
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||||
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
|
||||||
|
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
|
||||||
|
if !isExist {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, ok := apiTimeline.([]byte)
|
||||||
|
if !ok || len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(data)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||||
if !isExist {
|
if !isExist {
|
||||||
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||||
if c != nil {
|
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
|
||||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
return body
|
||||||
switch value := bodyOverride.(type) {
|
|
||||||
case []byte:
|
|
||||||
if len(value) > 0 {
|
|
||||||
return bytes.Clone(value)
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
if strings.TrimSpace(value) != "" {
|
|
||||||
return []byte(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
return w.requestInfo.Body
|
return w.requestInfo.Body
|
||||||
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
|
||||||
|
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if w.body == nil || w.body.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(w.body.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
|
||||||
|
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBodyOverride(c *gin.Context, key string) []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bodyOverride, isExist := c.Get(key)
|
||||||
|
if !isExist {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch value := bodyOverride.(type) {
|
||||||
|
case []byte:
|
||||||
|
if len(value) > 0 {
|
||||||
|
return bytes.Clone(value)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return []byte(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||||
if w.requestInfo == nil {
|
if w.requestInfo == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||||
}); ok {
|
}); ok {
|
||||||
return loggerWithOptions.LogRequestWithOptions(
|
return loggerWithOptions.LogRequestWithOptions(
|
||||||
w.requestInfo.URL,
|
w.requestInfo.URL,
|
||||||
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
|||||||
statusCode,
|
statusCode,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
|
websocketTimeline,
|
||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
forceLog,
|
forceLog,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
|
|||||||
statusCode,
|
statusCode,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
|
websocketTimeline,
|
||||||
apiRequestBody,
|
apiRequestBody,
|
||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
w.requestInfo.Timestamp,
|
w.requestInfo.Timestamp,
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||||
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
|||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
wrapper := &ResponseWriterWrapper{}
|
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||||
|
|
||||||
body := wrapper.extractRequestBody(c)
|
body := wrapper.extractRequestBody(c)
|
||||||
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
|||||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
|
||||||
|
wrapper.body.WriteString("original-response")
|
||||||
|
|
||||||
|
body := wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "original-response" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "original-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
|
||||||
|
body = wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "override-response" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "override-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'X'
|
||||||
|
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
|
||||||
|
t.Fatalf("response override should be cloned, got %q", string(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
|
||||||
|
|
||||||
|
body := wrapper.extractResponseBody(c)
|
||||||
|
if string(body) != "override-response-as-string" {
|
||||||
|
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
override := []byte("body-override")
|
||||||
|
c.Set(requestBodyOverrideContextKey, override)
|
||||||
|
|
||||||
|
body := extractBodyOverride(c, requestBodyOverrideContextKey)
|
||||||
|
if !bytes.Equal(body, override) {
|
||||||
|
t.Fatalf("body override = %q, want %q", string(body), string(override))
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'X'
|
||||||
|
if !bytes.Equal(override, []byte("body-override")) {
|
||||||
|
t.Fatalf("override mutated: %q", string(override))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
if got := wrapper.extractWebsocketTimeline(c); got != nil {
|
||||||
|
t.Fatalf("expected nil websocket timeline, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
|
||||||
|
body := wrapper.extractWebsocketTimeline(c)
|
||||||
|
if string(body) != "timeline" {
|
||||||
|
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
streamWriter := &testStreamingLogWriter{}
|
||||||
|
wrapper := &ResponseWriterWrapper{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
logger: &testRequestLogger{enabled: true},
|
||||||
|
requestInfo: &RequestInfo{
|
||||||
|
URL: "/v1/responses",
|
||||||
|
Method: "POST",
|
||||||
|
Headers: map[string][]string{"Content-Type": {"application/json"}},
|
||||||
|
RequestID: "req-1",
|
||||||
|
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
isStreaming: true,
|
||||||
|
streamWriter: streamWriter,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
|
||||||
|
|
||||||
|
if err := wrapper.Finalize(c); err != nil {
|
||||||
|
t.Fatalf("Finalize error: %v", err)
|
||||||
|
}
|
||||||
|
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
|
||||||
|
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
|
||||||
|
}
|
||||||
|
if !streamWriter.closed {
|
||||||
|
t.Fatal("expected stream writer to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRequestLogger struct {
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
|
||||||
|
return &testStreamingLogWriter{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *testRequestLogger) IsEnabled() bool {
|
||||||
|
return l.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
type testStreamingLogWriter struct {
|
||||||
|
apiWebsocketTimeline []byte
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||||
|
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
|
||||||
|
|
||||||
|
func (w *testStreamingLogWriter) Close() error {
|
||||||
|
w.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize request body: remove thinking blocks with invalid signatures
|
||||||
|
// to prevent upstream API 400 errors
|
||||||
|
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
|
||||||
|
|
||||||
// Restore the body for the handler to read
|
// Restore the body for the handler to read
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
@@ -249,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = true
|
||||||
c.Writer = rewriter
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
@@ -259,10 +264,17 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
} else if len(providers) > 0 {
|
} else if len(providers) > 0 {
|
||||||
// Log: Using local provider (free)
|
// Log: Using local provider (free)
|
||||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
// Wrap with ResponseRewriter for local providers too, because upstream
|
||||||
|
// proxies (e.g. NewAPI) may return a different model name and lack
|
||||||
|
// Amp-required fields like thinking.signature.
|
||||||
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = providerName != "claude"
|
||||||
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
handler(c)
|
handler(c)
|
||||||
|
rewriter.Flush()
|
||||||
} else {
|
} else {
|
||||||
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -77,6 +78,9 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
req.Header.Del("X-Api-Key")
|
req.Header.Del("X-Api-Key")
|
||||||
req.Header.Del("X-Goog-Api-Key")
|
req.Header.Del("X-Goog-Api-Key")
|
||||||
|
|
||||||
|
// Remove proxy, client identity, and browser fingerprint headers
|
||||||
|
misc.ScrubProxyAndFingerprintHeaders(req)
|
||||||
|
|
||||||
// Remove query-based credentials if they match the authenticated client API key.
|
// Remove query-based credentials if they match the authenticated client API key.
|
||||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||||
// breaking unrelated upstream query parameters.
|
// breaking unrelated upstream query parameters.
|
||||||
|
|||||||
@@ -129,11 +129,11 @@ func TestModifyResponse_GzipScenarios(t *testing.T) {
|
|||||||
wantCE: "",
|
wantCE: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "skips_non_2xx_status",
|
name: "decompresses_non_2xx_status_when_gzip_detected",
|
||||||
header: http.Header{},
|
header: http.Header{},
|
||||||
body: good,
|
body: good,
|
||||||
status: 404,
|
status: 404,
|
||||||
wantBody: good,
|
wantBody: goodJSON,
|
||||||
wantCE: "",
|
wantCE: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -12,15 +14,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||||
// It's used to rewrite model names in responses when model mapping is used
|
// It is used to rewrite model names in responses when model mapping is used
|
||||||
|
// and to keep Amp-compatible response shapes.
|
||||||
type ResponseRewriter struct {
|
type ResponseRewriter struct {
|
||||||
gin.ResponseWriter
|
gin.ResponseWriter
|
||||||
body *bytes.Buffer
|
body *bytes.Buffer
|
||||||
originalModel string
|
originalModel string
|
||||||
isStreaming bool
|
isStreaming bool
|
||||||
|
suppressThinking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponseRewriter creates a new response rewriter for model name substitution
|
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
||||||
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||||
return &ResponseRewriter{
|
return &ResponseRewriter{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
@@ -33,15 +37,15 @@ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
|||||||
|
|
||||||
func looksLikeSSEChunk(data []byte) bool {
|
func looksLikeSSEChunk(data []byte) bool {
|
||||||
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||||
// Heuristics are intentionally simple and cheap.
|
// We conservatively detect SSE by checking for "data:" / "event:" at the start of any line.
|
||||||
return bytes.Contains(data, []byte("data:")) ||
|
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||||
bytes.Contains(data, []byte("event:")) ||
|
trimmed := bytes.TrimSpace(line)
|
||||||
bytes.Contains(data, []byte("message_start")) ||
|
if bytes.HasPrefix(trimmed, []byte("data:")) ||
|
||||||
bytes.Contains(data, []byte("message_delta")) ||
|
bytes.HasPrefix(trimmed, []byte("event:")) {
|
||||||
bytes.Contains(data, []byte("content_block_start")) ||
|
return true
|
||||||
bytes.Contains(data, []byte("content_block_delta")) ||
|
}
|
||||||
bytes.Contains(data, []byte("content_block_stop")) ||
|
}
|
||||||
bytes.Contains(data, []byte("\n\n"))
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||||
@@ -95,7 +99,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
rewritten := rw.rewriteStreamChunk(data)
|
||||||
|
n, err := rw.ResponseWriter.Write(rewritten)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -106,7 +111,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
return rw.body.Write(data)
|
return rw.body.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush writes the buffered response with model names rewritten
|
|
||||||
func (rw *ResponseRewriter) Flush() {
|
func (rw *ResponseRewriter) Flush() {
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
@@ -115,40 +119,79 @@ func (rw *ResponseRewriter) Flush() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if rw.body.Len() > 0 {
|
if rw.body.Len() > 0 {
|
||||||
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||||
|
// Update Content-Length to match the rewritten body size, since
|
||||||
|
// signature injection and model name changes alter the payload length.
|
||||||
|
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
|
||||||
|
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
|
||||||
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// modelFieldPaths lists all JSON paths where model name may appear
|
|
||||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||||
|
|
||||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
func ensureAmpSignature(data []byte) []byte {
|
||||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
blockType := block.Get("type").String()
|
||||||
|
if blockType != "tool_use" && blockType != "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
signaturePath := fmt.Sprintf("content.%d.signature", index)
|
||||||
|
if gjson.GetBytes(data, signaturePath).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, signaturePath, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
|
||||||
|
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, "content_block.signature", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||||
|
if !rw.suppressThinking {
|
||||||
|
return data
|
||||||
|
}
|
||||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||||
if filtered.Exists() {
|
if filtered.Exists() {
|
||||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||||
filteredCount := filtered.Get("#").Int()
|
filteredCount := filtered.Get("#").Int()
|
||||||
|
|
||||||
if originalCount > filteredCount {
|
if originalCount > filteredCount {
|
||||||
var err error
|
var err error
|
||||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||||
} else {
|
|
||||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
|
||||||
// Log the result for verification
|
|
||||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
data = ensureAmpSignature(data)
|
||||||
|
data = rw.suppressAmpThinking(data)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
if rw.originalModel == "" {
|
if rw.originalModel == "" {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
@@ -160,24 +203,164 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
|
||||||
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
if rw.originalModel == "" {
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
return chunk
|
var out [][]byte
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(lines) {
|
||||||
|
line := lines[i]
|
||||||
|
trimmed := bytes.TrimSpace(line)
|
||||||
|
|
||||||
|
// Case 1: "event:" line - look ahead for its "data:" line
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("event: ")) {
|
||||||
|
// Scan forward past blank lines to find the data: line
|
||||||
|
dataIdx := -1
|
||||||
|
for j := i + 1; j < len(lines); j++ {
|
||||||
|
t := bytes.TrimSpace(lines[j])
|
||||||
|
if len(t) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(t, []byte("data: ")) {
|
||||||
|
dataIdx = j
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataIdx >= 0 {
|
||||||
|
// Found event+data pair - process through rewriter
|
||||||
|
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
|
if rewritten == nil {
|
||||||
|
i = dataIdx + 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Emit event line
|
||||||
|
out = append(out, line)
|
||||||
|
// Emit blank lines between event and data
|
||||||
|
for k := i + 1; k < dataIdx; k++ {
|
||||||
|
out = append(out, lines[k])
|
||||||
|
}
|
||||||
|
// Emit rewritten data
|
||||||
|
out = append(out, append([]byte("data: "), rewritten...))
|
||||||
|
i = dataIdx + 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No data line found (orphan event from cross-chunk split)
|
||||||
|
// Pass it through as-is - the data will arrive in the next chunk
|
||||||
|
out = append(out, line)
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: standalone "data:" line (no preceding event: in this chunk)
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
|
if rewritten != nil {
|
||||||
|
out = append(out, append([]byte("data: "), rewritten...))
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: everything else
|
||||||
|
out = append(out, line)
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSE format: "data: {json}\n\n"
|
return bytes.Join(out, []byte("\n"))
|
||||||
lines := bytes.Split(chunk, []byte("\n"))
|
}
|
||||||
for i, line := range lines {
|
|
||||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
// rewriteStreamEvent processes a single JSON event in the SSE stream.
|
||||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
// It rewrites model names and ensures signature fields exist.
|
||||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
// NOTE: streaming mode does NOT suppress thinking blocks - they are
|
||||||
// Rewrite JSON in the data line
|
// passed through with signature injection to avoid breaking SSE index
|
||||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
// alignment and TUI rendering.
|
||||||
lines[i] = append([]byte("data: "), rewritten...)
|
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||||
|
// Inject empty signature where needed
|
||||||
|
data = ensureAmpSignature(data)
|
||||||
|
|
||||||
|
// Rewrite model name
|
||||||
|
if rw.originalModel != "" {
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return bytes.Join(lines, []byte("\n"))
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
||||||
|
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
|
||||||
|
// array before forwarding to the upstream API.
|
||||||
|
// This prevents 400 errors from the API which requires valid signatures on thinking
|
||||||
|
// blocks and does not accept a signature field on tool_use blocks.
|
||||||
|
func SanitizeAmpRequestBody(body []byte) []byte {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for msgIdx, msg := range messages.Array() {
|
||||||
|
if msg.Get("role").String() != "assistant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var keepBlocks []interface{}
|
||||||
|
contentModified := false
|
||||||
|
|
||||||
|
for _, block := range content.Array() {
|
||||||
|
blockType := block.Get("type").String()
|
||||||
|
if blockType == "thinking" {
|
||||||
|
sig := block.Get("signature")
|
||||||
|
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
||||||
|
contentModified = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
|
||||||
|
blockRaw := []byte(block.Raw)
|
||||||
|
if blockType == "tool_use" && block.Get("signature").Exists() {
|
||||||
|
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
|
||||||
|
contentModified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
|
||||||
|
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentModified {
|
||||||
|
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
||||||
|
var err error
|
||||||
|
if len(keepBlocks) == 0 {
|
||||||
|
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
|
||||||
|
} else {
|
||||||
|
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if modified {
|
||||||
|
log.Debugf("Amp RequestSanitizer: sanitized request body")
|
||||||
|
}
|
||||||
|
return body
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -100,6 +101,80 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{}
|
||||||
|
|
||||||
|
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
|
||||||
|
result := rw.rewriteStreamChunk(chunk)
|
||||||
|
|
||||||
|
// Streaming mode preserves thinking blocks (does NOT suppress them)
|
||||||
|
// to avoid breaking SSE index alignment and TUI rendering
|
||||||
|
if !contains(result, []byte(`"content_block":{"type":"thinking"`)) {
|
||||||
|
t.Fatalf("expected thinking content_block_start to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"delta":{"type":"thinking_delta"`)) {
|
||||||
|
t.Fatalf("expected thinking_delta to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"type":"content_block_stop","index":0`)) {
|
||||||
|
t.Fatalf("expected content_block_stop for thinking block to be preserved, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"content_block":{"type":"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
// Signature should be injected into both thinking and tool_use blocks
|
||||||
|
if count := strings.Count(string(result), `"signature":""`); count != 2 {
|
||||||
|
t.Fatalf("expected 2 signature injections, but got %d in %s", count, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-whitespace")) {
|
||||||
|
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte("drop-number")) {
|
||||||
|
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte("keep-valid")) {
|
||||||
|
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte("keep-text")) {
|
||||||
|
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte(`"signature":""`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"valid-sig"`)) {
|
||||||
|
t.Fatalf("expected thinking signature to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-me")) {
|
||||||
|
t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte(`"signature"`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func contains(data, substr []byte) bool {
|
func contains(data, substr []byte) bool {
|
||||||
for i := 0; i <= len(data)-len(substr); i++ {
|
for i := 0; i <= len(data)-len(substr); i++ {
|
||||||
if string(data[i:i+len(substr)]) == string(substr) {
|
if string(data[i:i+len(substr)]) == string(substr) {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
@@ -262,6 +263,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
|
applySignatureCacheConfig(nil, cfg)
|
||||||
// Initialize management handler
|
// Initialize management handler
|
||||||
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
||||||
if optionState.localPassword != "" {
|
if optionState.localPassword != "" {
|
||||||
@@ -323,6 +325,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
// setupRoutes configures the API routes for the server.
|
// setupRoutes configures the API routes for the server.
|
||||||
// It defines the endpoints and associates them with their respective handlers.
|
// It defines the endpoints and associates them with their respective handlers.
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
|
s.engine.GET("/healthz", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
||||||
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
||||||
@@ -403,6 +409,20 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if errStr == "" {
|
||||||
|
errStr = c.Query("error_description")
|
||||||
|
}
|
||||||
|
if state != "" {
|
||||||
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
|
})
|
||||||
|
|
||||||
s.engine.GET("/google/callback", func(c *gin.Context) {
|
s.engine.GET("/google/callback", func(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
@@ -417,20 +437,6 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.engine.GET("/iflow/callback", func(c *gin.Context) {
|
|
||||||
code := c.Query("code")
|
|
||||||
state := c.Query("state")
|
|
||||||
errStr := c.Query("error")
|
|
||||||
if errStr == "" {
|
|
||||||
errStr = c.Query("error_description")
|
|
||||||
}
|
|
||||||
if state != "" {
|
|
||||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
|
||||||
}
|
|
||||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
||||||
})
|
|
||||||
|
|
||||||
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
|
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
@@ -555,6 +561,8 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||||
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
||||||
|
|
||||||
|
mgmt.GET("/copilot-quota", s.mgmt.GetCopilotQuota)
|
||||||
|
|
||||||
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
||||||
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
||||||
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
||||||
@@ -658,19 +666,21 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
|
|
||||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||||
|
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
|
||||||
|
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
|
||||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
|
||||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
|
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
@@ -943,6 +953,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applySignatureCacheConfig(oldCfg, cfg)
|
||||||
|
|
||||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||||
}
|
}
|
||||||
@@ -1081,3 +1093,37 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
|||||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func configuredSignatureCacheEnabled(cfg *config.Config) bool {
|
||||||
|
if cfg != nil && cfg.AntigravitySignatureCacheEnabled != nil {
|
||||||
|
return *cfg.AntigravitySignatureCacheEnabled
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func applySignatureCacheConfig(oldCfg, cfg *config.Config) {
|
||||||
|
newVal := configuredSignatureCacheEnabled(cfg)
|
||||||
|
newStrict := configuredSignatureBypassStrict(cfg)
|
||||||
|
if oldCfg == nil {
|
||||||
|
cache.SetSignatureCacheEnabled(newVal)
|
||||||
|
cache.SetSignatureBypassStrictMode(newStrict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oldVal := configuredSignatureCacheEnabled(oldCfg)
|
||||||
|
if oldVal != newVal {
|
||||||
|
cache.SetSignatureCacheEnabled(newVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldStrict := configuredSignatureBypassStrict(oldCfg)
|
||||||
|
if oldStrict != newStrict {
|
||||||
|
cache.SetSignatureBypassStrictMode(newStrict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configuredSignatureBypassStrict(cfg *config.Config) bool {
|
||||||
|
if cfg != nil && cfg.AntigravitySignatureBypassStrict != nil {
|
||||||
|
return *cfg.AntigravitySignatureBypassStrict
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server {
|
|||||||
return NewServer(cfg, authManager, accessManager, configPath)
|
return NewServer(cfg, authManager, accessManager, configPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHealthz(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String())
|
||||||
|
}
|
||||||
|
if resp.Status != "ok" {
|
||||||
|
t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -172,6 +195,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
true,
|
true,
|
||||||
"issue-1711",
|
"issue-1711",
|
||||||
time.Now(),
|
time.Now(),
|
||||||
|
|||||||
@@ -59,10 +59,30 @@ type ClaudeAuth struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *ClaudeAuth: A new Claude authentication service instance
|
// - *ClaudeAuth: A new Claude authentication service instance
|
||||||
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
||||||
|
return NewClaudeAuthWithProxyURL(cfg, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaudeAuthWithProxyURL creates a new Anthropic authentication service with a proxy override.
|
||||||
|
// proxyURL takes precedence over cfg.ProxyURL when non-empty.
|
||||||
|
func NewClaudeAuthWithProxyURL(cfg *config.Config, proxyURL string) *ClaudeAuth {
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg *config.SDKConfig
|
||||||
|
if cfg != nil {
|
||||||
|
sdkCfgCopy := cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
sdkCfgCopy.ProxyURL = effectiveProxyURL
|
||||||
|
sdkCfg = &sdkCfgCopy
|
||||||
|
} else if effectiveProxyURL != "" {
|
||||||
|
sdkCfgCopy := config.SDKConfig{ProxyURL: effectiveProxyURL}
|
||||||
|
sdkCfg = &sdkCfgCopy
|
||||||
|
}
|
||||||
|
|
||||||
// Use custom HTTP client with Firefox TLS fingerprint to bypass
|
// Use custom HTTP client with Firefox TLS fingerprint to bypass
|
||||||
// Cloudflare's bot detection on Anthropic domains
|
// Cloudflare's bot detection on Anthropic domains
|
||||||
return &ClaudeAuth{
|
return &ClaudeAuth{
|
||||||
httpClient: NewAnthropicHttpClient(&cfg.SDKConfig),
|
httpClient: NewAnthropicHttpClient(sdkCfg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +108,7 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
|||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"org:create_api_key user:profile user:inference"},
|
"scope": {"user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
"code_challenge_method": {"S256"},
|
"code_challenge_method": {"S256"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
|
|||||||
33
internal/auth/claude/anthropic_auth_proxy_test.go
Normal file
33
internal/auth/claude/anthropic_auth_proxy_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewClaudeAuthWithProxyURL_OverrideDirectTakesPrecedence(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "socks5://proxy.example.com:1080"}}
|
||||||
|
auth := NewClaudeAuthWithProxyURL(cfg, "direct")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*utlsRoundTripper)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected utlsRoundTripper, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.dialer != proxy.Direct {
|
||||||
|
t.Fatalf("expected proxy.Direct, got %T", transport.dialer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClaudeAuthWithProxyURL_OverrideProxyAppliedWithoutConfig(t *testing.T) {
|
||||||
|
auth := NewClaudeAuthWithProxyURL(nil, "socks5://proxy.example.com:1080")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*utlsRoundTripper)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected utlsRoundTripper, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.dialer == proxy.Direct {
|
||||||
|
t.Fatalf("expected proxy dialer, got %T", transport.dialer)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,12 +4,12 @@ package claude
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
tls "github.com/refraction-networking/utls"
|
tls "github.com/refraction-networking/utls"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
|
|||||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||||
var dialer proxy.Dialer = proxy.Direct
|
var dialer proxy.Dialer = proxy.Direct
|
||||||
if cfg != nil && cfg.ProxyURL != "" {
|
if cfg != nil {
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||||
if err != nil {
|
if errBuild != nil {
|
||||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||||
} else {
|
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
dialer = proxyDialer
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
|
||||||
} else {
|
|
||||||
dialer = pDialer
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
335
internal/auth/codebuddy/codebuddy_auth.go
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BaseURL = "https://copilot.tencent.com"
|
||||||
|
DefaultDomain = "www.codebuddy.cn"
|
||||||
|
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
|
||||||
|
|
||||||
|
codeBuddyStatePath = "/v2/plugin/auth/state"
|
||||||
|
codeBuddyTokenPath = "/v2/plugin/auth/token"
|
||||||
|
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
|
||||||
|
pollInterval = 5 * time.Second
|
||||||
|
maxPollDuration = 5 * time.Minute
|
||||||
|
codeLoginPending = 11217
|
||||||
|
codeSuccess = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodeBuddyAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
|
||||||
|
}
|
||||||
|
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthState holds the state and auth URL returned by the auth state API.
|
||||||
|
type AuthState struct {
|
||||||
|
State string
|
||||||
|
AuthURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
|
||||||
|
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
|
||||||
|
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
|
||||||
|
body := []byte("{}")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-Domain", "copilot.tencent.com")
|
||||||
|
req.Header.Set("X-No-Authorization", "true")
|
||||||
|
req.Header.Set("X-No-User-Id", "true")
|
||||||
|
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||||
|
req.Header.Set("X-No-Department-Info", "true")
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
req.Header.Set("X-Request-ID", requestID)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy auth state: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
State string `json:"state"`
|
||||||
|
AuthURL string `json:"authUrl"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Code != codeSuccess {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
|
||||||
|
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthState{
|
||||||
|
State: result.Data.State,
|
||||||
|
AuthURL: result.Data.AuthURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type pollResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
RequestID string `json:"requestId"`
|
||||||
|
Data *struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// doPollRequest performs a single polling request, safely reading and closing the response body
|
||||||
|
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
|
||||||
|
}
|
||||||
|
a.applyPollHeaders(req)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy poll: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
return body, resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForToken polls until the user completes browser authorization and returns auth data.
|
||||||
|
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
|
||||||
|
deadline := time.Now().Add(maxPollDuration)
|
||||||
|
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(pollInterval):
|
||||||
|
}
|
||||||
|
|
||||||
|
body, statusCode, err := a.doPollRequest(ctx, pollURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("codebuddy poll: request error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCode != http.StatusOK {
|
||||||
|
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var result pollResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch result.Code {
|
||||||
|
case codeSuccess:
|
||||||
|
if result.Data == nil {
|
||||||
|
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
|
||||||
|
}
|
||||||
|
userID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||||
|
return &CodeBuddyTokenStorage{
|
||||||
|
AccessToken: result.Data.AccessToken,
|
||||||
|
RefreshToken: result.Data.RefreshToken,
|
||||||
|
ExpiresIn: result.Data.ExpiresIn,
|
||||||
|
TokenType: result.Data.TokenType,
|
||||||
|
Domain: result.Data.Domain,
|
||||||
|
UserID: userID,
|
||||||
|
Type: "codebuddy",
|
||||||
|
}, nil
|
||||||
|
case codeLoginPending:
|
||||||
|
// continue polling
|
||||||
|
default:
|
||||||
|
// TODO: when the CodeBuddy API error code for user denial is known,
|
||||||
|
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
|
||||||
|
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrPollingTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
|
||||||
|
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", ErrJWTDecodeFailed
|
||||||
|
}
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||||
|
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
|
||||||
|
}
|
||||||
|
if claims.Sub == "" {
|
||||||
|
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
|
||||||
|
}
|
||||||
|
return claims.Sub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken exchanges a refresh token for a new access token.
|
||||||
|
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
|
||||||
|
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
|
||||||
|
if domain == "" {
|
||||||
|
domain = DefaultDomain
|
||||||
|
}
|
||||||
|
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
|
||||||
|
body := []byte("{}")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-Domain", domain)
|
||||||
|
req.Header.Set("X-Refresh-Token", refreshToken)
|
||||||
|
req.Header.Set("X-Auth-Refresh-Source", "plugin")
|
||||||
|
req.Header.Set("X-Request-ID", requestID)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("X-User-Id", userID)
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codebuddy refresh: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ExpiresIn int64 `json:"expiresIn"`
|
||||||
|
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(bodyBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
if result.Code != codeSuccess {
|
||||||
|
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
|
||||||
|
}
|
||||||
|
if result.Data == nil {
|
||||||
|
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
|
||||||
|
}
|
||||||
|
|
||||||
|
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
|
||||||
|
if newUserID == "" {
|
||||||
|
newUserID = userID
|
||||||
|
}
|
||||||
|
tokenDomain := result.Data.Domain
|
||||||
|
if tokenDomain == "" {
|
||||||
|
tokenDomain = domain
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CodeBuddyTokenStorage{
|
||||||
|
AccessToken: result.Data.AccessToken,
|
||||||
|
RefreshToken: result.Data.RefreshToken,
|
||||||
|
ExpiresIn: result.Data.ExpiresIn,
|
||||||
|
RefreshExpiresIn: result.Data.RefreshExpiresIn,
|
||||||
|
TokenType: result.Data.TokenType,
|
||||||
|
Domain: tokenDomain,
|
||||||
|
UserID: newUserID,
|
||||||
|
Type: "codebuddy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||||
|
req.Header.Set("X-No-Authorization", "true")
|
||||||
|
req.Header.Set("X-No-User-Id", "true")
|
||||||
|
req.Header.Set("X-No-Enterprise-Id", "true")
|
||||||
|
req.Header.Set("X-No-Department-Info", "true")
|
||||||
|
req.Header.Set("X-Product", "SaaS")
|
||||||
|
}
|
||||||
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
285
internal/auth/codebuddy/codebuddy_auth_http_test.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
|
||||||
|
func newTestAuth(serverURL string) *CodeBuddyAuth {
|
||||||
|
return &CodeBuddyAuth{
|
||||||
|
httpClient: http.DefaultClient,
|
||||||
|
baseURL: serverURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeJWT builds a minimal JWT with the given sub claim for testing.
|
||||||
|
func fakeJWT(sub string) string {
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
|
||||||
|
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
|
||||||
|
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||||
|
return header + "." + encodedPayload + ".sig"
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- FetchAuthState tests ---
|
||||||
|
|
||||||
|
func TestFetchAuthState_Success(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != codeBuddyStatePath {
|
||||||
|
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
|
||||||
|
}
|
||||||
|
if got := r.URL.Query().Get("platform"); got != "CLI" {
|
||||||
|
t.Errorf("expected platform=CLI, got %s", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("User-Agent"); got != UserAgent {
|
||||||
|
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"state": "test-state-abc",
|
||||||
|
"authUrl": "https://example.com/login?state=test-state-abc",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
result, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.State != "test-state-abc" {
|
||||||
|
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
|
||||||
|
}
|
||||||
|
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
|
||||||
|
t.Errorf("unexpected authURL: %s", result.AuthURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_NonOKStatus(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("internal error"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-200 status")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_APIErrorCode(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 10001,
|
||||||
|
"msg": "rate limited",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-zero code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchAuthState_MissingData(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"state": "",
|
||||||
|
"authUrl": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.FetchAuthState(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty state/authUrl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- RefreshToken tests ---
|
||||||
|
|
||||||
|
func TestRefreshToken_Success(t *testing.T) {
|
||||||
|
newAccessToken := fakeJWT("refreshed-user-456")
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.URL.Path; got != codeBuddyRefreshPath {
|
||||||
|
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
|
||||||
|
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
|
||||||
|
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-User-Id"); got != "user-123" {
|
||||||
|
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
|
||||||
|
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": newAccessToken,
|
||||||
|
"refreshToken": "new-refresh-token",
|
||||||
|
"expiresIn": 3600,
|
||||||
|
"refreshExpiresIn": 86400,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": "custom.domain.com",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if storage.AccessToken != newAccessToken {
|
||||||
|
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
|
||||||
|
}
|
||||||
|
if storage.RefreshToken != "new-refresh-token" {
|
||||||
|
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
|
||||||
|
}
|
||||||
|
if storage.UserID != "refreshed-user-456" {
|
||||||
|
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
|
||||||
|
}
|
||||||
|
if storage.ExpiresIn != 3600 {
|
||||||
|
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
|
||||||
|
}
|
||||||
|
if storage.RefreshExpiresIn != 86400 {
|
||||||
|
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
|
||||||
|
}
|
||||||
|
if storage.Domain != "custom.domain.com" {
|
||||||
|
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
|
||||||
|
}
|
||||||
|
if storage.Type != "codebuddy" {
|
||||||
|
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_DefaultDomain(t *testing.T) {
|
||||||
|
var receivedDomain string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedDomain = r.Header.Get("X-Domain")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": fakeJWT("user-1"),
|
||||||
|
"refreshToken": "rt",
|
||||||
|
"expiresIn": 3600,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": DefaultDomain,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if receivedDomain != DefaultDomain {
|
||||||
|
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Unauthorized(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401 response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Forbidden(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 403 response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_APIErrorCode(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 40001,
|
||||||
|
"msg": "invalid refresh token",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-zero API code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
|
||||||
|
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
|
||||||
|
// When the response domain is empty, it should fall back to the request domain.
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"msg": "ok",
|
||||||
|
"data": map[string]any{
|
||||||
|
"accessToken": "not-a-valid-jwt",
|
||||||
|
"refreshToken": "new-rt",
|
||||||
|
"expiresIn": 7200,
|
||||||
|
"tokenType": "bearer",
|
||||||
|
"domain": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
auth := newTestAuth(srv.URL)
|
||||||
|
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if storage.UserID != "original-uid" {
|
||||||
|
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
|
||||||
|
}
|
||||||
|
if storage.Domain != "original.domain.com" {
|
||||||
|
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
21
internal/auth/codebuddy/codebuddy_auth_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package codebuddy_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDecodeUserID_ValidJWT(t *testing.T) {
|
||||||
|
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
|
||||||
|
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
|
||||||
|
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
|
||||||
|
auth := codebuddy.NewCodeBuddyAuth(nil)
|
||||||
|
userID, err := auth.DecodeUserID(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if userID != "test-user-id-123" {
|
||||||
|
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
25
internal/auth/codebuddy/errors.go
Normal file
25
internal/auth/codebuddy/errors.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
|
||||||
|
ErrAccessDenied = errors.New("codebuddy: access denied by user")
|
||||||
|
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
|
||||||
|
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetUserFriendlyMessage(err error) string {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrPollingTimeout):
|
||||||
|
return "Authentication timed out. Please try again."
|
||||||
|
case errors.Is(err, ErrAccessDenied):
|
||||||
|
return "Access denied. Please try again and approve the login request."
|
||||||
|
case errors.Is(err, ErrJWTDecodeFailed):
|
||||||
|
return "Failed to decode token. Please try logging in again."
|
||||||
|
case errors.Is(err, ErrTokenFetchFailed):
|
||||||
|
return "Failed to fetch token from server. Please try again."
|
||||||
|
default:
|
||||||
|
return "Authentication failed: " + err.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
65
internal/auth/codebuddy/token.go
Normal file
65
internal/auth/codebuddy/token.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// Package codebuddy provides authentication and token management functionality
|
||||||
|
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
|
||||||
|
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
|
||||||
|
package codebuddy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
|
||||||
|
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
|
||||||
|
// for managing access tokens and user account information.
|
||||||
|
type CodeBuddyTokenStorage struct {
|
||||||
|
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// ExpiresIn is the number of seconds until the access token expires.
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
// RefreshExpiresIn is the number of seconds until the refresh token expires.
|
||||||
|
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
|
||||||
|
// TokenType is the type of token, typically "bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// Domain is the CodeBuddy service domain/region.
|
||||||
|
Domain string `json:"domain"`
|
||||||
|
// UserID is the user ID associated with this token.
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
// Type indicates the authentication provider type, always "codebuddy" for this storage.
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
|
||||||
|
// This method creates the necessary directory structure and writes the token
|
||||||
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - authFilePath: The full path where the token file should be saved
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the operation fails, nil otherwise
|
||||||
|
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
misc.LogSavingCredentials(authFilePath)
|
||||||
|
s.Type = "codebuddy"
|
||||||
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(s); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -37,8 +37,23 @@ type CodexAuth struct {
|
|||||||
// NewCodexAuth creates a new CodexAuth service instance.
|
// NewCodexAuth creates a new CodexAuth service instance.
|
||||||
// It initializes an HTTP client with proxy settings from the provided configuration.
|
// It initializes an HTTP client with proxy settings from the provided configuration.
|
||||||
func NewCodexAuth(cfg *config.Config) *CodexAuth {
|
func NewCodexAuth(cfg *config.Config) *CodexAuth {
|
||||||
|
return NewCodexAuthWithProxyURL(cfg, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCodexAuthWithProxyURL creates a new CodexAuth service instance.
|
||||||
|
// proxyURL takes precedence over cfg.ProxyURL when non-empty.
|
||||||
|
func NewCodexAuthWithProxyURL(cfg *config.Config, proxyURL string) *CodexAuth {
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg config.SDKConfig
|
||||||
|
if cfg != nil {
|
||||||
|
sdkCfg = cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sdkCfg.ProxyURL = effectiveProxyURL
|
||||||
return &CodexAuth{
|
return &CodexAuth{
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
httpClient: util.SetProxy(&sdkCfg, &http.Client{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
@@ -42,3 +44,37 @@ func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
|||||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewCodexAuthWithProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"}}
|
||||||
|
auth := NewCodexAuthWithProxyURL(cfg, "direct")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCodexAuthWithProxyURL_OverrideProxyTakesPrecedence(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://global.example.com:8080"}}
|
||||||
|
auth := NewCodexAuthWithProxyURL(cfg, "http://override.example.com:8081")
|
||||||
|
|
||||||
|
transport, ok := auth.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", auth.httpClient.Transport)
|
||||||
|
}
|
||||||
|
req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errReq != nil {
|
||||||
|
t.Fatalf("new request: %v", errReq)
|
||||||
|
}
|
||||||
|
proxyURL, errProxy := transport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("proxy func: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://override.example.com:8081" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://override.example.com:8081", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -22,11 +24,11 @@ const (
|
|||||||
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
||||||
|
|
||||||
// Common HTTP header values for Copilot API requests.
|
// Common HTTP header values for Copilot API requests.
|
||||||
copilotUserAgent = "GithubCopilot/1.0"
|
copilotUserAgent = "GithubCopilot/1.0"
|
||||||
copilotEditorVersion = "vscode/1.100.0"
|
copilotEditorVersion = "vscode/1.100.0"
|
||||||
copilotPluginVersion = "copilot/1.300.0"
|
copilotPluginVersion = "copilot/1.300.0"
|
||||||
copilotIntegrationID = "vscode-chat"
|
copilotIntegrationID = "vscode-chat"
|
||||||
copilotOpenAIIntent = "conversation-panel"
|
copilotOpenAIIntent = "conversation-panel"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CopilotAPIToken represents the Copilot API token response.
|
// CopilotAPIToken represents the Copilot API token response.
|
||||||
@@ -222,6 +224,165 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
|
||||||
|
type CopilotModelEntry struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Version string `json:"version,omitempty"`
|
||||||
|
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotModelLimits holds the token limits returned by the Copilot /models API
|
||||||
|
// under capabilities.limits. These limits vary by account type (individual vs
|
||||||
|
// business) and are the authoritative source for enforcing prompt size.
|
||||||
|
type CopilotModelLimits struct {
|
||||||
|
// MaxContextWindowTokens is the total context window (prompt + output).
|
||||||
|
MaxContextWindowTokens int
|
||||||
|
// MaxPromptTokens is the hard limit on input/prompt tokens.
|
||||||
|
// Exceeding this triggers a 400 error from the Copilot API.
|
||||||
|
MaxPromptTokens int
|
||||||
|
// MaxOutputTokens is the maximum number of output/completion tokens.
|
||||||
|
MaxOutputTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limits extracts the token limits from the model's capabilities map.
|
||||||
|
// Returns nil if no limits are available or the structure is unexpected.
|
||||||
|
//
|
||||||
|
// Expected Copilot API shape:
|
||||||
|
//
|
||||||
|
// "capabilities": {
|
||||||
|
// "limits": {
|
||||||
|
// "max_context_window_tokens": 200000,
|
||||||
|
// "max_prompt_tokens": 168000,
|
||||||
|
// "max_output_tokens": 32000
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (e *CopilotModelEntry) Limits() *CopilotModelLimits {
|
||||||
|
if e.Capabilities == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limitsRaw, ok := e.Capabilities["limits"]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limitsMap, ok := limitsRaw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &CopilotModelLimits{
|
||||||
|
MaxContextWindowTokens: anyToInt(limitsMap["max_context_window_tokens"]),
|
||||||
|
MaxPromptTokens: anyToInt(limitsMap["max_prompt_tokens"]),
|
||||||
|
MaxOutputTokens: anyToInt(limitsMap["max_output_tokens"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only return if at least one field is populated.
|
||||||
|
if result.MaxContextWindowTokens == 0 && result.MaxPromptTokens == 0 && result.MaxOutputTokens == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// anyToInt converts a JSON-decoded numeric value to int.
|
||||||
|
// Go's encoding/json decodes numbers into float64 when the target is any/interface{}.
|
||||||
|
func anyToInt(v any) int {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
case float32:
|
||||||
|
return int(n)
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||||
|
type CopilotModelsResponse struct {
|
||||||
|
Data []CopilotModelEntry `json:"data"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
|
||||||
|
const maxModelsResponseSize = 2 * 1024 * 1024
|
||||||
|
|
||||||
|
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||||
|
var allowedCopilotAPIHosts = map[string]bool{
|
||||||
|
"api.githubcopilot.com": true,
|
||||||
|
"api.individual.githubcopilot.com": true,
|
||||||
|
"api.business.githubcopilot.com": true,
|
||||||
|
"copilot-proxy.githubusercontent.com": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels fetches the list of available models from the Copilot API.
|
||||||
|
// It requires a valid Copilot API token (not the GitHub access token).
|
||||||
|
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
|
||||||
|
if apiToken == nil || apiToken.Token == "" {
|
||||||
|
return nil, fmt.Errorf("copilot: api token is required for listing models")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build models URL, validating the endpoint host to prevent SSRF.
|
||||||
|
modelsURL := copilotAPIEndpoint + "/models"
|
||||||
|
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
|
||||||
|
parsed, err := url.Parse(ep)
|
||||||
|
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
|
||||||
|
modelsURL = ep + "/models"
|
||||||
|
} else {
|
||||||
|
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: models request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot list models: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Limit response body to prevent memory exhaustion.
|
||||||
|
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
|
||||||
|
bodyBytes, err := io.ReadAll(limitedReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
|
return nil, fmt.Errorf("copilot: list models failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsResp CopilotModelsResponse
|
||||||
|
if err = json.Unmarshal(bodyBytes, &modelsResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to parse models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelsResp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModelsWithGitHubToken is a convenience method that exchanges a GitHub access token
|
||||||
|
// for a Copilot API token and then fetches the available models.
|
||||||
|
func (c *CopilotAuth) ListModelsWithGitHubToken(ctx context.Context, githubAccessToken string) ([]CopilotModelEntry, error) {
|
||||||
|
apiToken, err := c.GetCopilotAPIToken(ctx, githubAccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("copilot: failed to get API token for model listing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.ListModels(ctx, apiToken)
|
||||||
|
}
|
||||||
|
|
||||||
// buildChatCompletionURL builds the URL for chat completions API.
|
// buildChatCompletionURL builds the URL for chat completions API.
|
||||||
func buildChatCompletionURL() string {
|
func buildChatCompletionURL() string {
|
||||||
return copilotAPIEndpoint + "/chat/completions"
|
return copilotAPIEndpoint + "/chat/completions"
|
||||||
|
|||||||
33
internal/auth/cursor/filename.go
Normal file
33
internal/auth/cursor/filename.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package cursor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Cursor credentials.
|
||||||
|
// Priority: explicit label > auto-generated from JWT sub hash.
|
||||||
|
// If both label and subHash are empty, falls back to "cursor.json".
|
||||||
|
func CredentialFileName(label, subHash string) string {
|
||||||
|
label = strings.TrimSpace(label)
|
||||||
|
subHash = strings.TrimSpace(subHash)
|
||||||
|
if label != "" {
|
||||||
|
return fmt.Sprintf("cursor.%s.json", label)
|
||||||
|
}
|
||||||
|
if subHash != "" {
|
||||||
|
return fmt.Sprintf("cursor.%s.json", subHash)
|
||||||
|
}
|
||||||
|
return "cursor.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisplayLabel returns a human-readable label for the Cursor account.
|
||||||
|
func DisplayLabel(label, subHash string) string {
|
||||||
|
label = strings.TrimSpace(label)
|
||||||
|
if label != "" {
|
||||||
|
return "Cursor " + label
|
||||||
|
}
|
||||||
|
if subHash != "" {
|
||||||
|
return "Cursor " + subHash
|
||||||
|
}
|
||||||
|
return "Cursor User"
|
||||||
|
}
|
||||||
249
internal/auth/cursor/oauth.go
Normal file
249
internal/auth/cursor/oauth.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
|
||||||
|
package cursor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CursorLoginURL = "https://cursor.com/loginDeepControl"
|
||||||
|
CursorPollURL = "https://api2.cursor.sh/auth/poll"
|
||||||
|
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
|
||||||
|
|
||||||
|
pollMaxAttempts = 150
|
||||||
|
pollBaseDelay = 1 * time.Second
|
||||||
|
pollMaxDelay = 10 * time.Second
|
||||||
|
pollBackoffMultiply = 1.2
|
||||||
|
maxConsecutiveErrors = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthParams holds the PKCE parameters for Cursor login.
|
||||||
|
type AuthParams struct {
|
||||||
|
Verifier string
|
||||||
|
Challenge string
|
||||||
|
UUID string
|
||||||
|
LoginURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenPair holds the access and refresh tokens from Cursor.
|
||||||
|
type TokenPair struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeneratePKCE creates a PKCE verifier and challenge pair.
|
||||||
|
func GeneratePKCE() (verifier, challenge string, err error) {
|
||||||
|
verifierBytes := make([]byte, 96)
|
||||||
|
if _, err = rand.Read(verifierBytes); err != nil {
|
||||||
|
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||||
|
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
return verifier, challenge, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAuthParams creates the full set of auth params for Cursor login.
|
||||||
|
func GenerateAuthParams() (*AuthParams, error) {
|
||||||
|
verifier, challenge, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
uuidBytes := make([]byte, 16)
|
||||||
|
if _, err = rand.Read(uuidBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
|
||||||
|
}
|
||||||
|
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||||
|
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
|
||||||
|
|
||||||
|
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
|
||||||
|
CursorLoginURL, challenge, uuid)
|
||||||
|
|
||||||
|
return &AuthParams{
|
||||||
|
Verifier: verifier,
|
||||||
|
Challenge: challenge,
|
||||||
|
UUID: uuid,
|
||||||
|
LoginURL: loginURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForAuth polls the Cursor auth endpoint until the user completes login.
|
||||||
|
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
|
||||||
|
delay := pollBaseDelay
|
||||||
|
consecutiveErrors := 0
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
|
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
consecutiveErrors++
|
||||||
|
if consecutiveErrors >= maxConsecutiveErrors {
|
||||||
|
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
|
||||||
|
}
|
||||||
|
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
// Still waiting for user to authorize
|
||||||
|
consecutiveErrors = 0
|
||||||
|
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
|
var tokens TokenPair
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
|
||||||
|
}
|
||||||
|
return &tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
|
||||||
|
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes a Cursor access token using the refresh token.
|
||||||
|
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
|
||||||
|
strings.NewReader("{}"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+refreshToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens TokenPair
|
||||||
|
if err := json.Unmarshal(body, &tokens); err != nil {
|
||||||
|
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep original refresh token if not returned
|
||||||
|
if tokens.RefreshToken == "" {
|
||||||
|
tokens.RefreshToken = refreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
|
||||||
|
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
|
||||||
|
// the account. Returns empty string if parsing fails.
|
||||||
|
func ParseJWTSub(token string) string {
|
||||||
|
decoded := decodeJWTPayload(token)
|
||||||
|
if decoded == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return claims.Sub
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
|
||||||
|
// e.g. "auth0|user_2x..." → "a3f8b2c1"
|
||||||
|
func SubToShortHash(sub string) string {
|
||||||
|
if sub == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
h := sha256.Sum256([]byte(sub))
|
||||||
|
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeJWTPayload decodes the payload (middle) part of a JWT.
|
||||||
|
func decodeJWTPayload(token string) []byte {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := parts[1]
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
payload = strings.ReplaceAll(payload, "-", "+")
|
||||||
|
payload = strings.ReplaceAll(payload, "_", "/")
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
|
||||||
|
// Falls back to 1 hour from now if the token can't be parsed.
|
||||||
|
func GetTokenExpiry(token string) time.Time {
|
||||||
|
decoded := decodeJWTPayload(token)
|
||||||
|
if decoded == nil {
|
||||||
|
return time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Exp float64 `json:"exp"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
|
||||||
|
return time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
sec, frac := math.Modf(claims.Exp)
|
||||||
|
expiry := time.Unix(int64(sec), int64(frac*1e9))
|
||||||
|
// Subtract 5-minute safety margin
|
||||||
|
return expiry.Add(-5 * time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func minDuration(a, b time.Duration) time.Duration {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
84
internal/auth/cursor/proto/connect.go
Normal file
84
internal/auth/cursor/proto/connect.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connect streaming protocol framing. Each frame on the wire is
// [1 byte flags][4 bytes big-endian payload length][payload].
const (
	// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
	ConnectEndStreamFlag byte = 0x02
	// ConnectCompressionFlag indicates the payload is compressed (not supported).
	ConnectCompressionFlag byte = 0x01
	// ConnectFrameHeaderSize is the fixed 5-byte frame header.
	ConnectFrameHeaderSize = 5
)

// FrameConnectMessage wraps a protobuf payload in a Connect frame.
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
	out := make([]byte, ConnectFrameHeaderSize, ConnectFrameHeaderSize+len(data))
	out[0] = flags
	binary.BigEndian.PutUint32(out[1:], uint32(len(data)))
	return append(out, data...)
}

// ParseConnectFrame extracts one frame from a buffer.
// Returns (flags, payload, bytesConsumed, ok).
// ok is false when the buffer is too short for a complete frame.
// The returned payload aliases buf; callers must copy if they retain it.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
	if len(buf) < ConnectFrameHeaderSize {
		return 0, nil, 0, false
	}
	end := ConnectFrameHeaderSize + int(binary.BigEndian.Uint32(buf[1:ConnectFrameHeaderSize]))
	if len(buf) < end {
		return 0, nil, 0, false
	}
	return buf[0], buf[ConnectFrameHeaderSize:end], end, true
}
|
||||||
|
|
||||||
|
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
// The Code field contains the server-defined error code (e.g. gRPC standard codes
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
type ConnectError struct {
	Code    string // server-defined error code
	Message string // human-readable error description
}

// Error implements the error interface.
func (e *ConnectError) Error() string {
	return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
}

// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
// Returns nil if there is no error in the trailer.
// On error, returns a *ConnectError with the server's error code and message.
func ParseConnectEndStream(data []byte) error {
	if len(data) == 0 {
		return nil
	}

	var trailer struct {
		Error *struct {
			Code    string `json:"code"`
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(data, &trailer); err != nil {
		return fmt.Errorf("failed to parse Connect end stream: %w", err)
	}
	if trailer.Error == nil {
		return nil
	}

	// Never return an error with empty fields; substitute placeholders so
	// downstream logging/matching always has something to show.
	connErr := &ConnectError{Code: trailer.Error.Code, Message: trailer.Error.Message}
	if connErr.Code == "" {
		connErr.Code = "unknown"
	}
	if connErr.Message == "" {
		connErr.Message = "Unknown error"
	}
	return connErr
}
|
||||||
563
internal/auth/cursor/proto/decode.go
Normal file
563
internal/auth/cursor/proto/decode.go
Normal file
@@ -0,0 +1,563 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerMessageType identifies the kind of decoded server message.
type ServerMessageType int

// The values fall into three groups: content deltas streamed to the client,
// exec requests the proxy must answer (or reject, since it runs no local
// tools), and lifecycle/bookkeeping updates (heartbeat, turn end, tokens,
// checkpoint).
const (
	ServerMsgUnknown             ServerMessageType = iota
	ServerMsgTextDelta                             // Text content delta
	ServerMsgThinkingDelta                         // Thinking/reasoning delta
	ServerMsgThinkingCompleted                     // Thinking completed
	ServerMsgKvGetBlob                             // Server wants a blob
	ServerMsgKvSetBlob                             // Server wants to store a blob
	ServerMsgExecRequestCtx                        // Server requests context (tools, etc.)
	ServerMsgExecMcpArgs                           // Server wants MCP tool execution
	ServerMsgExecShellArgs                         // Rejected: shell command
	ServerMsgExecReadArgs                          // Rejected: file read
	ServerMsgExecWriteArgs                         // Rejected: file write
	ServerMsgExecDeleteArgs                        // Rejected: file delete
	ServerMsgExecLsArgs                            // Rejected: directory listing
	ServerMsgExecGrepArgs                          // Rejected: grep search
	ServerMsgExecFetchArgs                         // Rejected: HTTP fetch
	ServerMsgExecDiagnostics                       // Respond with empty diagnostics
	ServerMsgExecShellStream                       // Rejected: shell stream
	ServerMsgExecBgShellSpawn                      // Rejected: background shell
	ServerMsgExecWriteShellStdin                   // Rejected: write shell stdin
	ServerMsgExecOther                             // Other exec types (respond with empty)
	ServerMsgTurnEnded                             // Turn has ended (no more output)
	ServerMsgHeartbeat                             // Server heartbeat
	ServerMsgTokenDelta                            // Token usage delta
	ServerMsgCheckpoint                            // Conversation checkpoint update
)
|
||||||
|
|
||||||
|
// DecodedServerMessage holds parsed data from an AgentServerMessage.
// Only the fields relevant to Type are populated; the rest keep their
// zero values.
type DecodedServerMessage struct {
	Type ServerMessageType

	// For text/thinking deltas
	Text string

	// For KV messages
	KvId     uint32
	BlobId   []byte // hex-encoded blob ID
	BlobData []byte // for setBlobArgs

	// For exec messages
	ExecMsgId uint32
	ExecId    string

	// For MCP args
	McpToolName   string
	McpToolCallId string
	McpArgs       map[string][]byte // arg name -> protobuf-encoded value

	// For rejection context (only the field matching the exec kind is set,
	// e.g. Path for read/write/delete/ls, Command for shell, Url for fetch)
	Path             string
	Command          string
	WorkingDirectory string
	Url              string

	// For other exec - the raw field number for building a response
	ExecFieldNumber int

	// For TokenDeltaUpdate
	TokenDelta int64

	// For conversation checkpoint update (raw bytes, not decoded)
	CheckpointData []byte
}
|
||||||
|
|
||||||
|
// DecodeAgentServerMessage parses an AgentServerMessage and returns
// a structured representation of the first meaningful message found.
// It walks the top-level fields with protowire, dispatching known
// submessages (interaction update, exec, KV, checkpoint) to the dedicated
// decoders; unknown fields are skipped. The partially-filled msg is
// returned even on a wire-format error so callers can use what was parsed.
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
	msg := &DecodedServerMessage{Type: ServerMsgUnknown}

	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return msg, fmt.Errorf("invalid tag")
		}
		data = data[n:]

		switch typ {
		case protowire.BytesType:
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return msg, fmt.Errorf("invalid bytes field %d", num)
			}
			data = data[n:]

			// Debug: log top-level ASM fields
			log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))

			switch num {
			case ASM_InteractionUpdate:
				log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
				decodeInteractionUpdate(val, msg)
			case ASM_ExecServerMessage:
				log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
				decodeExecServerMessage(val, msg)
			case ASM_KvServerMessage:
				decodeKvServerMessage(val, msg)
			case ASM_ConversationCheckpoint:
				msg.Type = ServerMsgCheckpoint
				// Copy: val aliases the caller's buffer, but the checkpoint
				// is retained beyond this call.
				msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
				log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
			}

		case protowire.VarintType:
			_, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return msg, fmt.Errorf("invalid varint field %d", num)
			}
			data = data[n:]

		default:
			// Skip unknown wire types
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return msg, fmt.Errorf("invalid field %d", num)
			}
			data = data[n:]
		}
	}

	return msg, nil
}
|
||||||
|
|
||||||
|
// decodeInteractionUpdate parses an InteractionUpdate submessage into msg,
// mapping each known field number (text delta, thinking delta, token delta,
// heartbeat, turn end, ...) to a ServerMessageType. Later fields in the same
// update overwrite msg.Type — presumably each update carries one dominant
// field (TODO confirm against the proto definition). Malformed input aborts
// silently, keeping whatever was decoded so far.
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
	log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
			return
		}
		data = data[n:]
		log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
				return
			}
			data = data[n:]
			log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])

			switch num {
			case IU_TextDelta:
				msg.Type = ServerMsgTextDelta
				msg.Text = decodeStringField(val, TDU_Text)
				log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
			case IU_ThinkingDelta:
				msg.Type = ServerMsgThinkingDelta
				msg.Text = decodeStringField(val, TKD_Text)
				log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
			case IU_ThinkingCompleted:
				msg.Type = ServerMsgThinkingCompleted
				log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
			case 2:
				// tool_call_started - ignore but log
				log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
			case 3:
				// tool_call_completed - ignore but log
				log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
			case 8:
				// token_delta - extract token count
				msg.Type = ServerMsgTokenDelta
				msg.TokenDelta = decodeVarintField(val, 1)
				log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
			case 13:
				// heartbeat from server
				msg.Type = ServerMsgHeartbeat
			case 14:
				// turn_ended - critical: model finished generating
				msg.Type = ServerMsgTurnEnded
				log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
			case 16:
				// step_started - ignore
				log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
			case 17:
				// step_completed - ignore
				log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
			default:
				log.Debugf("decodeInteractionUpdate: unknown field %d", num)
			}
		} else {
			// Non-bytes fields are skipped wholesale.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// decodeKvServerMessage parses a KvServerMessage into msg: the request ID
// (KSM_Id, varint) plus either a getBlob request (blob ID only) or a
// setBlob request (ID + data, handled by decodeSetBlobArgs). Malformed
// input aborts silently.
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		switch typ {
		case protowire.VarintType:
			val, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return
			}
			data = data[n:]
			if num == KSM_Id {
				msg.KvId = uint32(val)
			}

		case protowire.BytesType:
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]

			switch num {
			case KSM_GetBlobArgs:
				msg.Type = ServerMsgKvGetBlob
				// decodeBytesField returns a copy, safe to retain.
				msg.BlobId = decodeBytesField(val, GBA_BlobId)
			case KSM_SetBlobArgs:
				msg.Type = ServerMsgKvSetBlob
				decodeSetBlobArgs(val, msg)
			}

		default:
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
switch num {
|
||||||
|
case SBA_BlobId:
|
||||||
|
msg.BlobId = val
|
||||||
|
case SBA_BlobData:
|
||||||
|
msg.BlobData = val
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeExecServerMessage parses an ExecServerMessage into msg: the message
// ID and exec ID, plus exactly one args submessage whose field number
// determines the exec kind (MCP call, shell, file read/write/..., fetch,
// diagnostics). Unknown bytes fields only claim msg.Type when it is still
// unset, because trailing metadata fields (e.g. span_context, 19) follow
// the real args field. Malformed input aborts silently.
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		switch typ {
		case protowire.VarintType:
			val, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return
			}
			data = data[n:]
			if num == ESM_Id {
				msg.ExecMsgId = uint32(val)
				log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
			}

		case protowire.BytesType:
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]

			// Debug: log all fields found in ExecServerMessage
			log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])

			switch num {
			case ESM_ExecId:
				msg.ExecId = string(val)
				log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
			case ESM_RequestContextArgs:
				msg.Type = ServerMsgExecRequestCtx
			case ESM_McpArgs:
				msg.Type = ServerMsgExecMcpArgs
				decodeMcpArgs(val, msg)
			case ESM_ShellArgs:
				msg.Type = ServerMsgExecShellArgs
				decodeShellArgs(val, msg)
			case ESM_ShellStreamArgs:
				msg.Type = ServerMsgExecShellStream
				decodeShellArgs(val, msg)
			case ESM_ReadArgs:
				msg.Type = ServerMsgExecReadArgs
				msg.Path = decodeStringField(val, RA_Path)
			case ESM_WriteArgs:
				msg.Type = ServerMsgExecWriteArgs
				msg.Path = decodeStringField(val, WA_Path)
			case ESM_DeleteArgs:
				msg.Type = ServerMsgExecDeleteArgs
				msg.Path = decodeStringField(val, DA_Path)
			case ESM_LsArgs:
				msg.Type = ServerMsgExecLsArgs
				msg.Path = decodeStringField(val, LA_Path)
			case ESM_GrepArgs:
				msg.Type = ServerMsgExecGrepArgs
			case ESM_FetchArgs:
				msg.Type = ServerMsgExecFetchArgs
				msg.Url = decodeStringField(val, FA_Url)
			case ESM_DiagnosticsArgs:
				msg.Type = ServerMsgExecDiagnostics
			case ESM_BackgroundShellSpawn:
				msg.Type = ServerMsgExecBgShellSpawn
				decodeShellArgs(val, msg) // same structure
			case ESM_WriteShellStdinArgs:
				msg.Type = ServerMsgExecWriteShellStdin
			default:
				// Unknown exec types - only set if we haven't identified the type yet
				// (other fields like span_context (19) come after the exec type field)
				if msg.Type == ServerMsgUnknown {
					msg.Type = ServerMsgExecOther
					msg.ExecFieldNumber = int(num)
				}
			}

		default:
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// decodeMcpArgs parses an McpArgs submessage into msg: tool name,
// tool-call ID, and the argument map (one protobuf map entry per MCA_Args
// field). msg.McpArgs is always (re)initialized, so an empty args message
// yields an empty — not nil — map. Malformed input aborts silently.
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
	msg.McpArgs = make(map[string][]byte)
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]

			switch num {
			case MCA_Name:
				msg.McpToolName = string(val)
			case MCA_Args:
				// Map entries are encoded as submessages with key=1, value=2
				decodeMapEntry(val, msg.McpArgs)
			case MCA_ToolCallId:
				msg.McpToolCallId = string(val)
			case MCA_ToolName:
				// ToolName takes precedence if present
				// (non-empty ToolName always overwrites; an empty one only
				// fills in when Name was absent too)
				if msg.McpToolName == "" || string(val) != "" {
					msg.McpToolName = string(val)
				}
			}
		} else {
			// Skip fields of any other wire type.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
func decodeMapEntry(data []byte, m map[string][]byte) {
|
||||||
|
var key string
|
||||||
|
var value []byte
|
||||||
|
for len(data) > 0 {
|
||||||
|
num, typ, n := protowire.ConsumeTag(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
|
||||||
|
if typ == protowire.BytesType {
|
||||||
|
val, n := protowire.ConsumeBytes(data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
if num == 1 {
|
||||||
|
key = string(val)
|
||||||
|
} else if num == 2 {
|
||||||
|
value = append([]byte(nil), val...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n := protowire.ConsumeFieldValue(num, typ, data)
|
||||||
|
if n < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = data[n:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if key != "" {
|
||||||
|
m[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeShellArgs parses a shell-style args submessage into msg,
// extracting the command string and working directory. It is shared by
// the plain shell, shell stream, and background-spawn exec kinds (same
// field layout). Malformed input aborts silently.
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return
			}
			data = data[n:]
			switch num {
			case SHA_Command:
				msg.Command = string(val)
			case SHA_WorkingDirectory:
				msg.WorkingDirectory = string(val)
			}
		} else {
			// Skip fields of any other wire type.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return
			}
			data = data[n:]
		}
	}
}
|
||||||
|
|
||||||
|
// --- Helper decoders ---

// decodeStringField extracts a string from the first matching field in a submessage.
// Returns "" when the field is absent or the wire data is malformed
// (string(val) copies, so the result does not alias data).
func decodeStringField(data []byte, targetField protowire.Number) string {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return ""
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return ""
			}
			data = data[n:]
			if num == targetField {
				return string(val)
			}
		} else {
			// Skip fields of any other wire type.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return ""
			}
			data = data[n:]
		}
	}
	return ""
}
|
||||||
|
|
||||||
|
// decodeBytesField extracts bytes from the first matching field in a submessage.
// Returns nil when the field is absent or the wire data is malformed. The
// result is a copy, so it stays valid after the caller's buffer is reused.
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return nil
		}
		data = data[n:]

		if typ == protowire.BytesType {
			val, n := protowire.ConsumeBytes(data)
			if n < 0 {
				return nil
			}
			data = data[n:]
			if num == targetField {
				// Detach from the input buffer before returning.
				return append([]byte(nil), val...)
			}
		} else {
			// Skip fields of any other wire type.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return nil
			}
			data = data[n:]
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
// Returns 0 when the field is absent or the wire data is malformed — callers
// cannot distinguish "absent" from an encoded zero.
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
	for len(data) > 0 {
		num, typ, n := protowire.ConsumeTag(data)
		if n < 0 {
			return 0
		}
		data = data[n:]
		if typ == protowire.VarintType {
			val, n := protowire.ConsumeVarint(data)
			if n < 0 {
				return 0
			}
			data = data[n:]
			if num == targetField {
				return int64(val)
			}
		} else {
			// Skip fields of any other wire type.
			n := protowire.ConsumeFieldValue(num, typ, data)
			if n < 0 {
				return 0
			}
			data = data[n:]
		}
	}
	return 0
}
|
||||||
|
|
||||||
|
// BlobIdHex returns the hex string of a blob ID for use as a map key.
func BlobIdHex(blobId []byte) string {
	encoded := hex.EncodeToString(blobId)
	return encoded
}
|
||||||
1244
internal/auth/cursor/proto/descriptor.go
Normal file
1244
internal/auth/cursor/proto/descriptor.go
Normal file
File diff suppressed because it is too large
Load Diff
664
internal/auth/cursor/proto/encode.go
Normal file
664
internal/auth/cursor/proto/encode.go
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
// Package proto provides protobuf encoding for Cursor's gRPC API,
|
||||||
|
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
|
||||||
|
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/encoding/protowire"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
"google.golang.org/protobuf/types/dynamicpb"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Public types ---
|
||||||
|
|
||||||
|
// RunRequestParams holds all data needed to build an AgentRunRequest.
type RunRequestParams struct {
	ModelId        string            // Cursor model identifier
	SystemPrompt   string            // serialized into a blob referenced by the conversation state
	UserText       string            // current user message text
	MessageId      string            // ID for the current user message
	ConversationId string            // ID of the conversation this request belongs to
	Images         []ImageData       // inline image attachments for the current message
	Turns          []TurnData        // prior user/assistant exchanges (ignored when RawCheckpoint is set)
	McpTools       []McpToolDef      // MCP tools exposed to the agent
	BlobStore      map[string][]byte // hex(sha256) -> data, populated during encoding
	RawCheckpoint  []byte            // if non-nil, use as conversation_state directly (from server checkpoint)
}

// ImageData is one inline image attachment.
type ImageData struct {
	MimeType string
	Data     []byte
}

// TurnData is one prior user/assistant exchange used to rebuild history.
type TurnData struct {
	UserText      string
	AssistantText string
}

// McpToolDef describes a single MCP tool definition sent to the server.
type McpToolDef struct {
	Name        string
	Description string
	InputSchema json.RawMessage // JSON schema, converted to a protobuf Value during encoding
}
|
||||||
|
|
||||||
|
// --- Helper: create a dynamic message and set fields ---

// newMsg creates an empty dynamic message for the named descriptor type.
func newMsg(name string) *dynamicpb.Message {
	return dynamicpb.NewMessage(Msg(name))
}

// field looks up a field descriptor on msg by its proto field name.
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
	return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
}

// setStr sets a string field; empty values are skipped so the field is
// simply left unset.
func setStr(msg *dynamicpb.Message, name, val string) {
	if val != "" {
		msg.Set(field(msg, name), protoreflect.ValueOfString(val))
	}
}

// setBytes sets a bytes field; empty values are skipped.
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
	if len(val) > 0 {
		msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
	}
}

// setUint32 sets a uint32 field unconditionally (zero included).
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
	msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
}

// setBool sets a bool field unconditionally (false included).
func setBool(msg *dynamicpb.Message, name string, val bool) {
	msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
}

// setMsg sets a submessage field.
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
	msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
}

// marshal serializes a dynamic message, panicking on failure — a marshal
// error here means the embedded descriptor and the construction code
// disagree, which is a programmer bug rather than a runtime condition.
func marshal(msg *dynamicpb.Message) []byte {
	b, err := proto.Marshal(msg)
	if err != nil {
		panic("cursor proto marshal: " + err.Error())
	}
	return b
}
|
||||||
|
|
||||||
|
// --- Encode functions mirroring cursor-fetch.ts ---
|
||||||
|
|
||||||
|
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
|
||||||
|
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
|
||||||
|
func EncodeHeartbeat() []byte {
|
||||||
|
hb := newMsg("ClientHeartbeat")
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "client_heartbeat", hb)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
// Mirrors buildCursorRequest() in cursor-fetch.ts.
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
//
// Side effects: p.BlobStore is created if nil and the serialized system
// prompt is inserted into it, keyed by the hex sha256 of its JSON form.
func EncodeRunRequest(p *RunRequestParams) []byte {
	if p.RawCheckpoint != nil {
		return encodeRunRequestWithCheckpoint(p)
	}

	if p.BlobStore == nil {
		p.BlobStore = make(map[string][]byte)
	}

	// --- Conversation turns ---
	// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
	var turnBytes [][]byte
	for _, turn := range p.Turns {
		// UserMessage for this turn
		um := newMsg("UserMessage")
		setStr(um, "text", turn.UserText)
		setStr(um, "message_id", generateId())
		umBytes := marshal(um)

		// Steps (assistant response)
		var stepBytes [][]byte
		if turn.AssistantText != "" {
			am := newMsg("AssistantMessage")
			setStr(am, "text", turn.AssistantText)
			step := newMsg("ConversationStep")
			setMsg(step, "assistant_message", am)
			stepBytes = append(stepBytes, marshal(step))
		}

		// AgentConversationTurnStructure (fields are bytes, not submessages)
		agentTurn := newMsg("AgentConversationTurnStructure")
		setBytes(agentTurn, "user_message", umBytes)
		for _, sb := range stepBytes {
			stepsField := field(agentTurn, "steps")
			list := agentTurn.Mutable(stepsField).List()
			list.Append(protoreflect.ValueOfBytes(sb))
		}

		// ConversationTurnStructure (oneof turn → agentConversationTurn)
		cts := newMsg("ConversationTurnStructure")
		setMsg(cts, "agent_conversation_turn", agentTurn)
		turnBytes = append(turnBytes, marshal(cts))
	}

	// --- System prompt blob ---
	// Marshal error ignored: a map[string]string cannot fail to encode.
	systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
	blobId := sha256Sum(systemJSON)
	p.BlobStore[hex.EncodeToString(blobId)] = systemJSON

	// --- ConversationStateStructure ---
	css := newMsg("ConversationStateStructure")
	// rootPromptMessagesJson: repeated bytes
	rootField := field(css, "root_prompt_messages_json")
	rootList := css.Mutable(rootField).List()
	rootList.Append(protoreflect.ValueOfBytes(blobId))
	// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
	turnsField := field(css, "turns")
	turnsList := css.Mutable(turnsField).List()
	for _, tb := range turnBytes {
		turnsList.Append(protoreflect.ValueOfBytes(tb))
	}
	turnsOldField := field(css, "turns_old")
	if turnsOldField != nil {
		turnsOldList := css.Mutable(turnsOldField).List()
		for _, tb := range turnBytes {
			turnsOldList.Append(protoreflect.ValueOfBytes(tb))
		}
	}

	// --- UserMessage (current) ---
	userMessage := newMsg("UserMessage")
	setStr(userMessage, "text", p.UserText)
	setStr(userMessage, "message_id", p.MessageId)

	// Images via SelectedContext
	if len(p.Images) > 0 {
		sc := newMsg("SelectedContext")
		imgsField := field(sc, "selected_images")
		imgsList := sc.Mutable(imgsField).List()
		for _, img := range p.Images {
			si := newMsg("SelectedImage")
			setStr(si, "uuid", generateId())
			setStr(si, "mime_type", img.MimeType)
			setBytes(si, "data", img.Data)
			imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
		}
		setMsg(userMessage, "selected_context", sc)
	}

	// --- UserMessageAction ---
	uma := newMsg("UserMessageAction")
	setMsg(uma, "user_message", userMessage)

	// --- ConversationAction ---
	ca := newMsg("ConversationAction")
	setMsg(ca, "user_message_action", uma)

	// --- ModelDetails ---
	// The same model ID is used for all three name variants.
	md := newMsg("ModelDetails")
	setStr(md, "model_id", p.ModelId)
	setStr(md, "display_model_id", p.ModelId)
	setStr(md, "display_name", p.ModelId)

	// --- AgentRunRequest ---
	arr := newMsg("AgentRunRequest")
	setMsg(arr, "conversation_state", css)
	setMsg(arr, "action", ca)
	setMsg(arr, "model_details", md)
	setStr(arr, "conversation_id", p.ConversationId)

	// McpTools
	if len(p.McpTools) > 0 {
		mcpTools := newMsg("McpTools")
		toolsField := field(mcpTools, "mcp_tools")
		toolsList := mcpTools.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
		setMsg(arr, "mcp_tools", mcpTools)
	}

	// --- AgentClientMessage ---
	acm := newMsg("AgentClientMessage")
	setMsg(acm, "run_request", arr)

	return marshal(acm)
}
|
||||||
|
|
||||||
|
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
//
// Sub-messages (UserMessage, ConversationAction, ModelDetails, McpTools) are
// built with the dynamicpb helpers and marshaled up front; the outer
// AgentRunRequest is then assembled by hand with protowire so the opaque
// checkpoint blob can be spliced in as field 1 without parsing it.
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
	// Build UserMessage
	userMessage := newMsg("UserMessage")
	setStr(userMessage, "text", p.UserText)
	setStr(userMessage, "message_id", p.MessageId)
	if len(p.Images) > 0 {
		// Images travel inside an optional SelectedContext sub-message.
		sc := newMsg("SelectedContext")
		imgsField := field(sc, "selected_images")
		imgsList := sc.Mutable(imgsField).List()
		for _, img := range p.Images {
			si := newMsg("SelectedImage")
			setStr(si, "uuid", generateId())
			setStr(si, "mime_type", img.MimeType)
			setBytes(si, "data", img.Data)
			imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
		}
		setMsg(userMessage, "selected_context", sc)
	}

	// Build ConversationAction with UserMessageAction
	uma := newMsg("UserMessageAction")
	setMsg(uma, "user_message", userMessage)
	ca := newMsg("ConversationAction")
	setMsg(ca, "user_message_action", uma)
	caBytes := marshal(ca)

	// Build ModelDetails (model_id is reused for all three display fields)
	md := newMsg("ModelDetails")
	setStr(md, "model_id", p.ModelId)
	setStr(md, "display_model_id", p.ModelId)
	setStr(md, "display_name", p.ModelId)
	mdBytes := marshal(md)

	// Build McpTools (left nil when no tools are configured)
	var mcpToolsBytes []byte
	if len(p.McpTools) > 0 {
		mcpTools := newMsg("McpTools")
		toolsField := field(mcpTools, "mcp_tools")
		toolsList := mcpTools.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				// JSON schema is carried as a serialized google.protobuf.Value.
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
		mcpToolsBytes = marshal(mcpTools)
	}

	// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
	var arrBuf []byte
	// field 1: conversation_state = raw checkpoint bytes (length-delimited)
	arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
	arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
	// field 2: action = ConversationAction
	arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
	arrBuf = protowire.AppendBytes(arrBuf, caBytes)
	// field 3: model_details = ModelDetails
	arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
	arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
	// field 4: mcp_tools = McpTools (omitted entirely when empty)
	if len(mcpToolsBytes) > 0 {
		arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
		arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
	}
	// field 5: conversation_id = string (optional)
	if p.ConversationId != "" {
		arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
		arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
	}

	// Wrap in AgentClientMessage field 1 (run_request)
	var acmBuf []byte
	acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
	acmBuf = protowire.AppendBytes(acmBuf, arrBuf)

	log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
	return acmBuf
}
|
||||||
|
|
||||||
|
// ResumeRequestParams holds data for a ResumeAction request.
type ResumeRequestParams struct {
	ModelId        string       // model identifier; also used for the display fields
	ConversationId string       // server-side conversation to resume
	McpTools       []McpToolDef // MCP tool definitions advertised to the server
}
|
||||||
|
|
||||||
|
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
// Used to resume a conversation by conversation_id without re-sending full history.
//
// NOTE(review): the tool definitions are attached twice — inside the
// ResumeAction's RequestContext and again at the AgentRunRequest top level.
// Presumably the server reads them from different places; confirm before
// deduplicating.
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
	// RequestContext with tools
	rc := newMsg("RequestContext")
	if len(p.McpTools) > 0 {
		toolsField := field(rc, "tools")
		toolsList := rc.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				// JSON schema is carried as a serialized google.protobuf.Value.
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
	}

	// ResumeAction
	ra := newMsg("ResumeAction")
	setMsg(ra, "request_context", rc)

	// ConversationAction with resume_action
	ca := newMsg("ConversationAction")
	setMsg(ca, "resume_action", ra)

	// ModelDetails
	md := newMsg("ModelDetails")
	setStr(md, "model_id", p.ModelId)
	setStr(md, "display_model_id", p.ModelId)
	setStr(md, "display_name", p.ModelId)

	// AgentRunRequest — no conversation_state needed for resume
	arr := newMsg("AgentRunRequest")
	setMsg(arr, "action", ca)
	setMsg(arr, "model_details", md)
	setStr(arr, "conversation_id", p.ConversationId)

	// McpTools at top level
	if len(p.McpTools) > 0 {
		mcpTools := newMsg("McpTools")
		toolsField := field(mcpTools, "mcp_tools")
		toolsList := mcpTools.Mutable(toolsField).List()
		for _, tool := range p.McpTools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
		setMsg(arr, "mcp_tools", mcpTools)
	}

	acm := newMsg("AgentClientMessage")
	setMsg(acm, "run_request", arr)
	return marshal(acm)
}
|
||||||
|
|
||||||
|
// --- KV response encoders ---
|
||||||
|
// Mirrors handleKvMessage() in cursor-fetch.ts
|
||||||
|
|
||||||
|
// EncodeKvGetBlobResult responds to a getBlobArgs request.
|
||||||
|
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
|
||||||
|
result := newMsg("GetBlobResult")
|
||||||
|
if blobData != nil {
|
||||||
|
setBytes(result, "blob_data", blobData)
|
||||||
|
}
|
||||||
|
|
||||||
|
kvc := newMsg("KvClientMessage")
|
||||||
|
setUint32(kvc, "id", kvId)
|
||||||
|
setMsg(kvc, "get_blob_result", result)
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "kv_client_message", kvc)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeKvSetBlobResult responds to a setBlobArgs request.
|
||||||
|
func EncodeKvSetBlobResult(kvId uint32) []byte {
|
||||||
|
result := newMsg("SetBlobResult")
|
||||||
|
|
||||||
|
kvc := newMsg("KvClientMessage")
|
||||||
|
setUint32(kvc, "id", kvId)
|
||||||
|
setMsg(kvc, "set_blob_result", result)
|
||||||
|
|
||||||
|
acm := newMsg("AgentClientMessage")
|
||||||
|
setMsg(acm, "kv_client_message", kvc)
|
||||||
|
return marshal(acm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Exec response encoders ---
|
||||||
|
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
|
||||||
|
|
||||||
|
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
// The tools are packed into a RequestContext, wrapped in RequestContextSuccess,
// placed in the RequestContextResult oneof, and finally framed as an
// ExecClientMessage inside an AgentClientMessage.
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
	// RequestContext with tools
	rc := newMsg("RequestContext")
	if len(tools) > 0 {
		toolsField := field(rc, "tools")
		toolsList := rc.Mutable(toolsField).List()
		for _, tool := range tools {
			td := newMsg("McpToolDefinition")
			setStr(td, "name", tool.Name)
			setStr(td, "description", tool.Description)
			if len(tool.InputSchema) > 0 {
				// JSON schema is carried as a serialized google.protobuf.Value.
				setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
			}
			setStr(td, "provider_identifier", "proxy")
			setStr(td, "tool_name", tool.Name)
			toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
		}
	}

	// RequestContextSuccess
	rcs := newMsg("RequestContextSuccess")
	setMsg(rcs, "request_context", rc)

	// RequestContextResult (oneof success)
	rcr := newMsg("RequestContextResult")
	setMsg(rcr, "success", rcs)

	return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
}
|
||||||
|
|
||||||
|
// EncodeExecMcpResult responds with MCP tool result.
|
||||||
|
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
|
||||||
|
textContent := newMsg("McpTextContent")
|
||||||
|
setStr(textContent, "text", content)
|
||||||
|
|
||||||
|
contentItem := newMsg("McpToolResultContentItem")
|
||||||
|
setMsg(contentItem, "text", textContent)
|
||||||
|
|
||||||
|
success := newMsg("McpSuccess")
|
||||||
|
contentField := field(success, "content")
|
||||||
|
contentList := success.Mutable(contentField).List()
|
||||||
|
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
|
||||||
|
setBool(success, "is_error", isError)
|
||||||
|
|
||||||
|
result := newMsg("McpResult")
|
||||||
|
setMsg(result, "success", success)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeExecMcpError responds with MCP error.
|
||||||
|
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
mcpErr := newMsg("McpError")
|
||||||
|
setStr(mcpErr, "error", errMsg)
|
||||||
|
|
||||||
|
result := newMsg("McpResult")
|
||||||
|
setMsg(result, "error", mcpErr)
|
||||||
|
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Rejection encoders (mirror handleExecMessage rejections) ---
|
||||||
|
|
||||||
|
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("ReadRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("ReadResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||||
|
rej := newMsg("ShellRejected")
|
||||||
|
setStr(rej, "command", command)
|
||||||
|
setStr(rej, "working_directory", workDir)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("ShellResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("WriteRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("WriteResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("DeleteRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("DeleteResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
|
||||||
|
rej := newMsg("LsRejected")
|
||||||
|
setStr(rej, "path", path)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("LsResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
grepErr := newMsg("GrepError")
|
||||||
|
setStr(grepErr, "error", errMsg)
|
||||||
|
result := newMsg("GrepResult")
|
||||||
|
setMsg(result, "error", grepErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
|
||||||
|
fetchErr := newMsg("FetchError")
|
||||||
|
setStr(fetchErr, "url", url)
|
||||||
|
setStr(fetchErr, "error", errMsg)
|
||||||
|
result := newMsg("FetchResult")
|
||||||
|
setMsg(result, "error", fetchErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
|
||||||
|
result := newMsg("DiagnosticsResult")
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
|
||||||
|
rej := newMsg("ShellRejected")
|
||||||
|
setStr(rej, "command", command)
|
||||||
|
setStr(rej, "working_directory", workDir)
|
||||||
|
setStr(rej, "reason", reason)
|
||||||
|
result := newMsg("BackgroundShellSpawnResult")
|
||||||
|
setMsg(result, "rejected", rej)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
|
||||||
|
wsErr := newMsg("WriteShellStdinError")
|
||||||
|
setStr(wsErr, "error", errMsg)
|
||||||
|
result := newMsg("WriteShellStdinResult")
|
||||||
|
setMsg(result, "error", wsErr)
|
||||||
|
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
// Mirrors sendExec() in cursor-fetch.ts.
//
// resultFieldName selects which member of the ExecClientMessage result oneof
// receives resultMsg. An unknown field name panics: it can only arise from a
// typo in a caller (programmer error), never from runtime data.
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
	ecm := newMsg("ExecClientMessage")
	setUint32(ecm, "id", id)
	// Force set exec_id even if empty - Cursor requires this field to be set
	ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))

	// Debug: check if field exists
	fd := field(ecm, resultFieldName)
	if fd == nil {
		panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
	}

	// Debug: log the actual field being set
	log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())

	ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))

	acm := newMsg("AgentClientMessage")
	setMsg(acm, "exec_client_message", ecm)
	return marshal(acm)
}
|
||||||
|
|
||||||
|
func listFields(msg *dynamicpb.Message) []string {
|
||||||
|
var names []string
|
||||||
|
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
|
||||||
|
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Utilities ---
|
||||||
|
|
||||||
|
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
//
// On any conversion failure the raw JSON bytes are returned unchanged.
// NOTE(review): the fallback bytes are NOT a valid google.protobuf.Value
// encoding — confirm the server tolerates them before relying on this path.
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
	if len(jsonData) == 0 {
		return nil
	}
	var v interface{}
	if err := json.Unmarshal(jsonData, &v); err != nil {
		return jsonData // fallback to raw JSON if parsing fails
	}
	pbVal, err := structpb.NewValue(v)
	if err != nil {
		return jsonData // fallback
	}
	b, err := proto.Marshal(pbVal)
	if err != nil {
		return jsonData // fallback
	}
	return b
}
|
||||||
|
|
||||||
|
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
|
||||||
|
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
|
||||||
|
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
|
||||||
|
val := &structpb.Value{}
|
||||||
|
if err := proto.Unmarshal(data, val); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return val.AsInterface(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sha256Sum returns the full 32-byte SHA-256 digest of data as a slice.
func sha256Sum(data []byte) []byte {
	digest := sha256.Sum256(data)
	return digest[:]
}
|
||||||
|
|
||||||
|
// idCounter backs generateId with a process-local sequence.
// NOTE(review): increments are unsynchronized — concurrent callers race.
// Guard with sync/atomic if generateId may be called from multiple goroutines.
var idCounter uint64

// generateId returns a 32-character hex identifier derived from idCounter.
//
// The previous implementation hashed only the low 24 bits of the counter
// (three bytes), so identifiers would start repeating after 2^24 calls.
// All eight counter bytes are hashed now, making collisions impossible for
// the lifetime of the counter.
func generateId() string {
	idCounter++
	var seed [8]byte
	for i := 0; i < 8; i++ {
		seed[i] = byte(idCounter >> (8 * uint(i))) // little-endian counter bytes
	}
	h := sha256.Sum256(seed[:])
	return hex.EncodeToString(h[:16])
}
|
||||||
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
332
internal/auth/cursor/proto/fieldnumbers.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
//
// NOTE: constant names use PREFIX_Underscore style on purpose, to mirror the
// proto message/field names they encode; renaming to Go MixedCaps would break
// the visual correspondence with agent_pb.ts (and all existing callers).
package proto

// AgentClientMessage (msg 118) oneof "message"
const (
	ACM_RunRequest           = 1 // AgentRunRequest
	ACM_ExecClientMessage    = 2 // ExecClientMessage
	ACM_KvClientMessage      = 3 // KvClientMessage
	ACM_ConversationAction   = 4 // ConversationAction
	ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
	ACM_InteractionResponse  = 6 // InteractionResponse
	ACM_ClientHeartbeat      = 7 // ClientHeartbeat
)

// AgentServerMessage (msg 119) oneof "message"
const (
	ASM_InteractionUpdate        = 1 // InteractionUpdate
	ASM_ExecServerMessage        = 2 // ExecServerMessage
	ASM_ConversationCheckpoint   = 3 // ConversationStateStructure
	ASM_KvServerMessage          = 4 // KvServerMessage
	ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
	ASM_InteractionQuery         = 7 // InteractionQuery (6 unused)
)

// AgentRunRequest (msg 91)
const (
	ARR_ConversationState = 1 // ConversationStateStructure
	ARR_Action            = 2 // ConversationAction
	ARR_ModelDetails      = 3 // ModelDetails
	ARR_McpTools          = 4 // McpTools
	ARR_ConversationId    = 5 // string (optional)
)

// ConversationStateStructure (msg 83)
const (
	CSS_RootPromptMessagesJson = 1  // repeated bytes
	CSS_TurnsOld               = 2  // repeated bytes (deprecated)
	CSS_Todos                  = 3  // repeated bytes
	CSS_PendingToolCalls       = 4  // repeated string
	CSS_Turns                  = 8  // repeated bytes (CURRENT field for turns)
	CSS_PreviousWorkspaceUris  = 9  // repeated string
	CSS_SelfSummaryCount       = 17 // uint32
	CSS_ReadPaths              = 18 // repeated string
)

// ConversationAction (msg 54) oneof "action"
const (
	CA_UserMessageAction = 1 // UserMessageAction
)

// UserMessageAction (msg 55)
const (
	UMA_UserMessage = 1 // UserMessage
)

// UserMessage (msg 63)
const (
	UM_Text            = 1 // string
	UM_MessageId       = 2 // string
	UM_SelectedContext = 3 // SelectedContext (optional)
)

// SelectedContext
const (
	SC_SelectedImages = 1 // repeated SelectedImage
)

// SelectedImage
const (
	SI_BlobId   = 1 // bytes (oneof dataOrBlobId)
	SI_Uuid     = 2 // string
	SI_Path     = 3 // string
	SI_MimeType = 7 // string
	SI_Data     = 8 // bytes (oneof dataOrBlobId)
)

// ModelDetails (msg 88)
const (
	MD_ModelId         = 1 // string
	MD_ThinkingDetails = 2 // ThinkingDetails (optional)
	MD_DisplayModelId  = 3 // string
	MD_DisplayName     = 4 // string
)

// McpTools (msg 307)
const (
	MT_McpTools = 1 // repeated McpToolDefinition
)

// McpToolDefinition (msg 306)
const (
	MTD_Name               = 1 // string
	MTD_Description        = 2 // string
	MTD_InputSchema        = 3 // bytes
	MTD_ProviderIdentifier = 4 // string
	MTD_ToolName           = 5 // string
)

// ConversationTurnStructure (msg 70) oneof "turn"
const (
	CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
)

// AgentConversationTurnStructure (msg 72)
const (
	ACTS_UserMessage = 1 // bytes (serialized UserMessage)
	ACTS_Steps       = 2 // repeated bytes (serialized ConversationStep)
)

// ConversationStep (msg 53) oneof "message"
const (
	CS_AssistantMessage = 1 // AssistantMessage
)

// AssistantMessage
const (
	AM_Text = 1 // string
)

// --- Server-side message fields ---

// InteractionUpdate oneof "message"
const (
	IU_TextDelta         = 1 // TextDeltaUpdate
	IU_ThinkingDelta     = 4 // ThinkingDeltaUpdate
	IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
)

// TextDeltaUpdate (msg 92)
const (
	TDU_Text = 1 // string
)

// ThinkingDeltaUpdate (msg 97)
const (
	TKD_Text = 1 // string
)

// KvServerMessage (msg 271)
const (
	KSM_Id          = 1 // uint32
	KSM_GetBlobArgs = 2 // GetBlobArgs
	KSM_SetBlobArgs = 3 // SetBlobArgs
)

// GetBlobArgs (msg 267)
const (
	GBA_BlobId = 1 // bytes
)

// SetBlobArgs (msg 269)
const (
	SBA_BlobId   = 1 // bytes
	SBA_BlobData = 2 // bytes
)

// KvClientMessage (msg 272)
const (
	KCM_Id            = 1 // uint32
	KCM_GetBlobResult = 2 // GetBlobResult
	KCM_SetBlobResult = 3 // SetBlobResult
)

// GetBlobResult (msg 268)
const (
	GBR_BlobData = 1 // bytes (optional)
)

// ExecServerMessage
const (
	ESM_Id     = 1  // uint32
	ESM_ExecId = 15 // string
	// oneof message:
	ESM_ShellArgs            = 2  // ShellArgs
	ESM_WriteArgs            = 3  // WriteArgs
	ESM_DeleteArgs           = 4  // DeleteArgs
	ESM_GrepArgs             = 5  // GrepArgs
	ESM_ReadArgs             = 7  // ReadArgs (NOTE: 6 is skipped)
	ESM_LsArgs               = 8  // LsArgs
	ESM_DiagnosticsArgs      = 9  // DiagnosticsArgs
	ESM_RequestContextArgs   = 10 // RequestContextArgs
	ESM_McpArgs              = 11 // McpArgs
	ESM_ShellStreamArgs      = 14 // ShellArgs (stream variant)
	ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
	ESM_FetchArgs            = 20 // FetchArgs
	ESM_WriteShellStdinArgs  = 23 // WriteShellStdinArgs
)

// ExecClientMessage
const (
	ECM_Id     = 1  // uint32
	ECM_ExecId = 15 // string
	// oneof message (mirrors server fields):
	ECM_ShellResult             = 2
	ECM_WriteResult             = 3
	ECM_DeleteResult            = 4
	ECM_GrepResult              = 5
	ECM_ReadResult              = 7
	ECM_LsResult                = 8
	ECM_DiagnosticsResult       = 9
	ECM_RequestContextResult    = 10
	ECM_McpResult               = 11
	ECM_ShellStream             = 14
	ECM_BackgroundShellSpawnRes = 16
	ECM_FetchResult             = 20
	ECM_WriteShellStdinResult   = 23
)

// McpArgs
const (
	MCA_Name               = 1 // string
	MCA_Args               = 2 // map<string, bytes>
	MCA_ToolCallId         = 3 // string
	MCA_ProviderIdentifier = 4 // string
	MCA_ToolName           = 5 // string
)

// RequestContextResult oneof "result"
const (
	RCR_Success = 1 // RequestContextSuccess
	RCR_Error   = 2 // RequestContextError
)

// RequestContextSuccess (msg 337)
const (
	RCS_RequestContext = 1 // RequestContext
)

// RequestContext
const (
	RC_Rules = 2 // repeated CursorRule
	RC_Tools = 7 // repeated McpToolDefinition
)

// McpResult oneof "result"
const (
	MCR_Success  = 1 // McpSuccess
	MCR_Error    = 2 // McpError
	MCR_Rejected = 3 // McpRejected
)

// McpSuccess (msg 290)
const (
	MCS_Content = 1 // repeated McpToolResultContentItem
	MCS_IsError = 2 // bool
)

// McpToolResultContentItem oneof "content"
const (
	MTRCI_Text = 1 // McpTextContent
)

// McpTextContent (msg 287)
const (
	MTC_Text = 1 // string
)

// McpError (msg 291)
const (
	MCE_Error = 1 // string
)

// --- Rejection messages ---

// ReadRejected: path=1, reason=2
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
// WriteRejected: path=1, reason=2
// DeleteRejected: path=1, reason=2
// LsRejected: path=1, reason=2
// GrepError: error=1
// FetchError: url=1, error=2
// WriteShellStdinError: error=1

// ReadResult oneof: success=1, error=2, rejected=3
// ShellResult oneof: success=1 (+ various), rejected=?
// The TS code uses specific result field numbers from the oneof:
const (
	RR_Rejected   = 3 // ReadResult.rejected
	SR_Rejected   = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
	WR_Rejected   = 5 // WriteResult.rejected
	DR_Rejected   = 3 // DeleteResult.rejected
	LR_Rejected   = 3 // LsResult.rejected
	GR_Error      = 2 // GrepResult.error
	FR_Error      = 2 // FetchResult.error
	BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
	WSSR_Error    = 2 // WriteShellStdinResult.error
)

// --- Rejection struct fields ---
const (
	REJ_Path        = 1
	REJ_Reason      = 2
	SREJ_Command    = 1
	SREJ_WorkingDir = 2
	SREJ_Reason     = 3
	SREJ_IsReadonly = 4
	GERR_Error      = 1
	FERR_Url        = 1
	FERR_Error      = 2
)

// ReadArgs
const (
	RA_Path = 1 // string
)

// WriteArgs
const (
	WA_Path = 1 // string
)

// DeleteArgs
const (
	DA_Path = 1 // string
)

// LsArgs
const (
	LA_Path = 1 // string
)

// ShellArgs
const (
	SHA_Command          = 1 // string
	SHA_WorkingDirectory = 2 // string
)

// FetchArgs
const (
	FA_Url = 1 // string
)
|
||||||
313
internal/auth/cursor/proto/h2stream.go
Normal file
313
internal/auth/cursor/proto/h2stream.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// defaultInitialWindowSize is the HTTP/2 default flow-control window
	// (RFC 7540 §6.9.2) assumed until the peer's SETTINGS says otherwise.
	defaultInitialWindowSize = 65535
	// maxFramePayload is the HTTP/2 default maximum DATA frame payload
	// (RFC 7540 §4.2); outgoing writes are chunked to this size.
	maxFramePayload = 16384
)
|
||||||
|
|
||||||
|
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
type H2Stream struct {
	framer   *http2.Framer // shared framer over conn; guarded by mu for writes
	conn     net.Conn      // underlying TLS connection
	streamID uint32        // the single client-initiated stream (always 1)
	mu       sync.Mutex    // serializes framer writes and frameNum access
	id       string        // unique identifier for debugging
	frameNum int64         // sequential frame counter for debugging

	dataCh chan []byte   // received DATA payloads, closed by readLoop on exit
	doneCh chan struct{} // closed by readLoop when the stream ends
	err    error         // terminal error; read via Err() after doneCh closes

	// Send-side flow control
	sendWindow int32      // available bytes we can send on this stream
	connWindow int32      // available bytes on the connection level
	windowCond *sync.Cond // signaled when window is updated
	windowMu   sync.Mutex // protects sendWindow, connWindow
}
|
||||||
|
|
||||||
|
// ID returns the unique identifier for this stream (for logging).
|
||||||
|
func (s *H2Stream) ID() string { return s.id }
|
||||||
|
|
||||||
|
// FrameNum returns the current frame number for debugging.
|
||||||
|
func (s *H2Stream) FrameNum() int64 {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.frameNum
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
// It performs the client preface/SETTINGS handshake by hand, encodes the
// request HEADERS with HPACK, and starts a background readLoop. headers may
// contain a ":path" pseudo-header; all other ':'-prefixed keys are dropped.
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
	tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
		NextProtos: []string{"h2"}, // require HTTP/2 via ALPN
	})
	if err != nil {
		return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
	}
	if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
		tlsConn.Close()
		return nil, fmt.Errorf("h2: server did not negotiate h2")
	}

	framer := http2.NewFramer(tlsConn, tlsConn)

	// Client connection preface
	if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
		tlsConn.Close()
		return nil, fmt.Errorf("h2: preface write failed: %w", err)
	}

	// Send initial SETTINGS (tell server how much WE can receive)
	if err := framer.WriteSettings(
		http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
		http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
	); err != nil {
		tlsConn.Close()
		return nil, fmt.Errorf("h2: settings write failed: %w", err)
	}

	// Connection-level window update (for receiving)
	if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
		tlsConn.Close()
		return nil, fmt.Errorf("h2: window update failed: %w", err)
	}

	// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
	// Track server's initial window size (how much WE can send)
	serverInitialWindowSize := int32(defaultInitialWindowSize)
	connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
	// Bounded loop: exit early on the server's SETTINGS ack, or give up after
	// 10 frames and proceed with whatever state was observed.
	for i := 0; i < 10; i++ {
		f, err := framer.ReadFrame()
		if err != nil {
			tlsConn.Close()
			return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
		}
		switch sf := f.(type) {
		case *http2.SettingsFrame:
			if !sf.IsAck() {
				sf.ForeachSetting(func(s http2.Setting) error {
					if s.ID == http2.SettingInitialWindowSize {
						serverInitialWindowSize = int32(s.Val)
						log.Debugf("h2: server initial window size: %d", s.Val)
					}
					return nil
				})
				// NOTE(review): ack write error is ignored; a failure would
				// surface on the next framer operation.
				framer.WriteSettingsAck()
			} else {
				goto handshakeDone
			}
		case *http2.WindowUpdateFrame:
			if sf.StreamID == 0 {
				connWindowSize += int32(sf.Increment)
				log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
			}
		default:
			// unexpected but continue
		}
	}
handshakeDone:

	// Build HEADERS
	streamID := uint32(1)
	var hdrBuf []byte
	enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
	// Pseudo-headers must precede regular headers (RFC 7540 §8.1.2.1).
	enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
	enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
	enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
	if p, ok := headers[":path"]; ok {
		enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
	}
	for k, v := range headers {
		if len(k) > 0 && k[0] == ':' {
			continue
		}
		enc.WriteField(hpack.HeaderField{Name: k, Value: v})
	}

	if err := framer.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      streamID,
		BlockFragment: hdrBuf,
		EndStream:     false, // keep the request side open for streaming writes
		EndHeaders:    true,
	}); err != nil {
		tlsConn.Close()
		return nil, fmt.Errorf("h2: headers write failed: %w", err)
	}

	s := &H2Stream{
		framer:     framer,
		conn:       tlsConn,
		streamID:   streamID,
		dataCh:     make(chan []byte, 256),
		doneCh:     make(chan struct{}),
		id:         fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
		frameNum:   0,
		sendWindow: serverInitialWindowSize,
		connWindow: connWindowSize,
	}
	s.windowCond = sync.NewCond(&s.windowMu)
	go s.readLoop()
	return s, nil
}
|
||||||
|
|
||||||
|
// Write sends a DATA frame on the stream, respecting flow control.
// Data is split into frames of at most maxFramePayload bytes, and each chunk
// waits until both the stream-level and connection-level send windows are
// positive. Blocks until all data is written or a framer write fails.
// NOTE(review): if Close is called while a writer is blocked here, the
// Broadcast wakes it but the window condition may still be false — confirm
// writers cannot be stranded after Close.
func (s *H2Stream) Write(data []byte) error {
	for len(data) > 0 {
		chunk := data
		if len(chunk) > maxFramePayload {
			chunk = data[:maxFramePayload]
		}

		// Wait for flow control window
		s.windowMu.Lock()
		for s.sendWindow <= 0 || s.connWindow <= 0 {
			s.windowCond.Wait()
		}
		// Limit chunk to available window
		allowed := int(s.sendWindow)
		if int(s.connWindow) < allowed {
			allowed = int(s.connWindow)
		}
		if len(chunk) > allowed {
			chunk = chunk[:allowed]
		}
		// Reserve window space before releasing the lock so concurrent
		// writers cannot oversubscribe it.
		s.sendWindow -= int32(len(chunk))
		s.connWindow -= int32(len(chunk))
		s.windowMu.Unlock()

		// mu serializes framer writes against readLoop's control-frame writes.
		s.mu.Lock()
		err := s.framer.WriteData(s.streamID, false, chunk)
		s.mu.Unlock()
		if err != nil {
			return err
		}
		data = data[len(chunk):]
	}
	return nil
}
|
||||||
|
|
||||||
|
// Data returns the channel of received data chunks.
// The channel is closed by readLoop when the stream ends.
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }

// Done returns a channel closed when the stream ends.
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }

// Err returns the error (if any) that caused the stream to close.
// Returns nil for a clean shutdown (EOF / StreamEnded).
// NOTE(review): err is written by readLoop without a lock; read it only after
// Done() is closed so the close provides the happens-before edge.
func (s *H2Stream) Err() error { return s.err }
|
// Close tears down the connection. Closing the conn causes readLoop's
// ReadFrame to fail, which closes doneCh/dataCh.
func (s *H2Stream) Close() {
	s.conn.Close()
	// Unblock any writers waiting on flow control
	// NOTE(review): woken writers re-check the window predicate, which is
	// unchanged here; verify blocked Write calls actually terminate.
	s.windowCond.Broadcast()
}
|
||||||
|
|
||||||
|
// readLoop is the single reader goroutine for the connection. It dispatches
// incoming frames: DATA payloads go to dataCh (with WINDOW_UPDATE replenishment),
// PING/SETTINGS are answered, and RST_STREAM/GOAWAY/end-of-stream terminate
// the loop. On exit it closes doneCh and dataCh.
func (s *H2Stream) readLoop() {
	defer close(s.doneCh)
	defer close(s.dataCh)

	for {
		f, err := s.framer.ReadFrame()
		if err != nil {
			// EOF is a clean shutdown; anything else is recorded for Err().
			if err != io.EOF {
				s.err = err
				log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
			}
			return
		}

		// Increment frame counter
		s.mu.Lock()
		s.frameNum++
		s.mu.Unlock()

		switch frame := f.(type) {
		case *http2.DataFrame:
			if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
				// Copy: the framer reuses its read buffer on the next ReadFrame.
				cp := make([]byte, len(frame.Data()))
				copy(cp, frame.Data())
				s.dataCh <- cp

				// Flow control: send WINDOW_UPDATE for received data
				s.mu.Lock()
				s.framer.WriteWindowUpdate(0, uint32(len(cp)))
				s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
				s.mu.Unlock()
			}
			if frame.StreamEnded() {
				return
			}

		case *http2.HeadersFrame:
			// Response headers are not surfaced; only END_STREAM matters here.
			if frame.StreamEnded() {
				return
			}

		case *http2.RSTStreamFrame:
			s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
			log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
			return

		case *http2.GoAwayFrame:
			s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
			return

		case *http2.PingFrame:
			if !frame.IsAck() {
				s.mu.Lock()
				s.framer.WritePing(true, frame.Data)
				s.mu.Unlock()
			}

		case *http2.SettingsFrame:
			if !frame.IsAck() {
				// Check for window size changes
				frame.ForeachSetting(func(setting http2.Setting) error {
					if setting.ID == http2.SettingInitialWindowSize {
						s.windowMu.Lock()
						// NOTE(review): this resets sendWindow to the new
						// initial value rather than adjusting by the delta
						// against in-flight data — confirm vs RFC 7540 §6.9.2.
						delta := int32(setting.Val) - s.sendWindow
						s.sendWindow += delta
						s.windowMu.Unlock()
						s.windowCond.Broadcast()
					}
					return nil
				})
				s.mu.Lock()
				s.framer.WriteSettingsAck()
				s.mu.Unlock()
			}

		case *http2.WindowUpdateFrame:
			// Update send-side flow control window
			s.windowMu.Lock()
			if frame.StreamID == 0 {
				s.connWindow += int32(frame.Increment)
			} else if frame.StreamID == s.streamID {
				s.sendWindow += int32(frame.Increment)
			}
			s.windowMu.Unlock()
			s.windowCond.Broadcast()
		}
	}
}
|
||||||
|
|
||||||
|
// sliceWriter adapts an external byte slice to io.Writer so the HPACK
// encoder can append its output directly into a caller-owned buffer.
type sliceWriter struct{ buf *[]byte }

// Write appends p to the underlying slice. It never fails.
func (w *sliceWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	*w.buf = append(*w.buf, p...)
	return n, nil
}
|
||||||
@@ -10,9 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
@@ -20,9 +18,9 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
}
|
}
|
||||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
if errBuild != nil {
|
||||||
if err == nil {
|
log.Errorf("%v", errBuild)
|
||||||
var transport *http.Transport
|
} else if transport != nil {
|
||||||
if proxyURL.Scheme == "socks5" {
|
proxyClient := &http.Client{Transport: transport}
|
||||||
// Handle SOCKS5 proxy.
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
auth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
|
||||||
}
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Handle HTTP/HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
if transport != nil {
|
|
||||||
proxyClient := &http.Client{Transport: transport}
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: ClientID,
|
ClientID: ClientID,
|
||||||
@@ -327,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
defer manualPromptTimer.Stop()
|
defer manualPromptTimer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var manualInputCh <-chan string
|
||||||
|
var manualInputErrCh <-chan error
|
||||||
|
|
||||||
waitForCallback:
|
waitForCallback:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -348,13 +329,14 @@ waitForCallback:
|
|||||||
return nil, err
|
return nil, err
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||||
if err != nil {
|
continue
|
||||||
return nil, err
|
case input := <-manualInputCh:
|
||||||
}
|
manualInputCh = nil
|
||||||
parsed, err := misc.ParseOAuthCallback(input)
|
manualInputErrCh = nil
|
||||||
if err != nil {
|
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||||
return nil, err
|
if errParse != nil {
|
||||||
|
return nil, errParse
|
||||||
}
|
}
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
continue
|
continue
|
||||||
@@ -367,6 +349,8 @@ waitForCallback:
|
|||||||
}
|
}
|
||||||
authCode = parsed.Code
|
authCode = parsed.Code
|
||||||
break waitForCallback
|
break waitForCallback
|
||||||
|
case errManual := <-manualInputErrCh:
|
||||||
|
return nil, errManual
|
||||||
case <-timeoutTimer.C:
|
case <-timeoutTimer.C:
|
||||||
return nil, fmt.Errorf("oauth flow timed out")
|
return nil, fmt.Errorf("oauth flow timed out")
|
||||||
}
|
}
|
||||||
|
|||||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
package gitlab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// DefaultBaseURL is the SaaS GitLab instance used when no base URL is given.
	DefaultBaseURL = "https://gitlab.com"
	// DefaultCallbackPort is the local port for the OAuth redirect listener.
	DefaultCallbackPort = 17171
	// defaultOAuthScope is the space-separated scope string requested
	// during authorization.
	defaultOAuthScope = "api read_user"
)
|
||||||
|
|
||||||
|
// PKCECodes holds an RFC 7636 PKCE verifier/challenge pair.
type PKCECodes struct {
	CodeVerifier  string // random secret sent at token exchange
	CodeChallenge string // base64url(sha256(verifier)), sent at authorization
}

// OAuthResult is the outcome of the OAuth redirect: either Code+State, or a
// provider-supplied Error string.
type OAuthResult struct {
	Code  string
	State string
	Error string
}
|
||||||
|
|
||||||
|
// OAuthServer is a short-lived local HTTP server that receives the GitLab
// OAuth redirect on /auth/callback.
type OAuthServer struct {
	server     *http.Server
	port       int
	resultChan chan *OAuthResult // buffered (1); first callback result wins
	errorChan  chan error        // buffered (1); serve failures
	mu         sync.Mutex        // guards server/running across Start/Stop
	running    bool
}
|
||||||
|
|
||||||
|
// TokenResponse mirrors the JSON body of GitLab's /oauth/token endpoint.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	RefreshToken string `json:"refresh_token"`
	Scope        string `json:"scope"`
	CreatedAt    int64  `json:"created_at"` // unix seconds at token creation
	ExpiresIn    int    `json:"expires_in"` // lifetime in seconds; 0 = unknown
}

// User mirrors the subset of GET /api/v4/user used by this package.
type User struct {
	ID          int64  `json:"id"`
	Username    string `json:"username"`
	Name        string `json:"name"`
	Email       string `json:"email"`
	PublicEmail string `json:"public_email"`
}
|
||||||
|
|
||||||
|
// PersonalAccessTokenSelf mirrors GET /api/v4/personal_access_tokens/self,
// describing the token used to make the request.
type PersonalAccessTokenSelf struct {
	ID     int64    `json:"id"`
	Name   string   `json:"name"`
	Scopes []string `json:"scopes"`
	UserID int64    `json:"user_id"`
}

// ModelDetails identifies a backing model advertised by GitLab.
type ModelDetails struct {
	ModelProvider string `json:"model_provider"`
	ModelName     string `json:"model_name"`
}

// DirectAccessResponse mirrors POST /api/v4/code_suggestions/direct_access:
// a short-lived token plus headers for calling the model gateway directly.
type DirectAccessResponse struct {
	BaseURL      string            `json:"base_url"`
	Token        string            `json:"token"`
	ExpiresAt    int64             `json:"expires_at"` // unix seconds
	Headers      map[string]string `json:"headers"`
	ModelDetails *ModelDetails     `json:"model_details,omitempty"`
}

// DiscoveredModel is a provider/name pair extracted from metadata maps.
type DiscoveredModel struct {
	ModelProvider string
	ModelName     string
}
|
||||||
|
|
||||||
|
// AuthClient performs GitLab OAuth and REST API calls over a shared HTTP client.
type AuthClient struct {
	httpClient *http.Client // proxy-aware when built from config
}

// NewAuthClient constructs an AuthClient, applying proxy settings from cfg
// when provided. A nil cfg yields a plain default client.
// NOTE(review): the client has no Timeout set; request deadlines rely on the
// ctx passed to each call — confirm that is intended.
func NewAuthClient(cfg *config.Config) *AuthClient {
	client := &http.Client{}
	if cfg != nil {
		client = util.SetProxy(&cfg.SDKConfig, client)
	}
	return &AuthClient{httpClient: client}
}
|
||||||
|
|
||||||
|
// NormalizeBaseURL canonicalizes a user-supplied GitLab base URL: blank input
// falls back to DefaultBaseURL, a missing scheme defaults to https, and any
// trailing slashes are stripped.
func NormalizeBaseURL(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return DefaultBaseURL
	}
	if !strings.Contains(trimmed, "://") {
		trimmed = "https://" + trimmed
	}
	return strings.TrimRight(trimmed, "/")
}
|
||||||
|
|
||||||
|
// TokenExpiry derives the absolute UTC expiry time of an OAuth token
// response. Server-provided created_at is preferred; otherwise expires_in is
// applied to now. The zero time means "no expiry information".
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
	switch {
	case token == nil:
		return time.Time{}
	case token.CreatedAt > 0 && token.ExpiresIn > 0:
		return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
	case token.ExpiresIn > 0:
		return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
	default:
		return time.Time{}
	}
}
|
||||||
|
|
||||||
|
// GeneratePKCECodes produces an RFC 7636 verifier/challenge pair using the
// S256 method: challenge = base64url(sha256(verifier)), no padding.
func GeneratePKCECodes() (*PKCECodes, error) {
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
	}
	verifier := base64.RawURLEncoding.EncodeToString(raw)
	digest := sha256.Sum256([]byte(verifier))
	return &PKCECodes{
		CodeVerifier:  verifier,
		CodeChallenge: base64.RawURLEncoding.EncodeToString(digest[:]),
	}, nil
}
|
||||||
|
|
||||||
|
func NewOAuthServer(port int) *OAuthServer {
|
||||||
|
return &OAuthServer{
|
||||||
|
port: port,
|
||||||
|
resultChan: make(chan *OAuthResult, 1),
|
||||||
|
errorChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start binds the callback port and begins serving /auth/callback in the
// background. It returns an error if the server is already running or the
// port cannot be bound.
//
// Fix: the original probed the port with isPortAvailable, then called
// ListenAndServe asynchronously and slept 100ms — a check-then-listen race in
// which another process could grab the port, plus a fragile startup sleep.
// Binding the listener synchronously removes both problems.
func (s *OAuthServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.running {
		return fmt.Errorf("gitlab oauth server already running")
	}

	// Bind now so failure is reported directly to the caller.
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
	if err != nil {
		return fmt.Errorf("port %d is already in use: %w", s.port, err)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/auth/callback", s.handleCallback)

	s.server = &http.Server{
		Addr:         fmt.Sprintf(":%d", s.port),
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	s.running = true

	go func() {
		// Serve on the already-bound listener; ErrServerClosed is the normal
		// result of Stop and is not reported.
		if err := s.server.Serve(listener); err != nil && err != http.ErrServerClosed {
			s.errorChan <- err
		}
	}()
	return nil
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the callback server if it is running; it is a
// no-op otherwise. The ctx bounds the graceful-shutdown wait.
func (s *OAuthServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.running || s.server == nil {
		return nil
	}
	srv := s.server
	s.running = false
	s.server = nil
	return srv.Shutdown(ctx)
}
|
||||||
|
|
||||||
|
// WaitForCallback blocks until the OAuth redirect delivers a result, the
// server fails, or the timeout elapses.
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case result := <-s.resultChan:
		return result, nil
	case err := <-s.errorChan:
		return nil, err
	case <-timer.C:
		return nil, fmt.Errorf("timeout waiting for OAuth callback")
	}
}
|
||||||
|
|
||||||
|
// handleCallback processes the OAuth redirect request. Provider errors and
// missing parameters are still delivered through sendResult so the goroutine
// blocked in WaitForCallback is always unblocked.
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	query := r.URL.Query()
	// Provider-signaled failure (e.g. access_denied).
	if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
		s.sendResult(&OAuthResult{Error: errParam})
		http.Error(w, errParam, http.StatusBadRequest)
		return
	}
	code := strings.TrimSpace(query.Get("code"))
	state := strings.TrimSpace(query.Get("state"))
	if code == "" || state == "" {
		s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
		http.Error(w, "missing code or state", http.StatusBadRequest)
		return
	}
	s.sendResult(&OAuthResult{Code: code, State: state})
	_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
}
|
||||||
|
|
||||||
|
// sendResult delivers result without blocking. Because resultChan has
// capacity 1, the first callback wins; duplicates are dropped with a debug log.
func (s *OAuthServer) sendResult(result *OAuthResult) {
	select {
	case s.resultChan <- result:
	default:
		log.Debug("gitlab oauth result channel full, dropping callback result")
	}
}
|
||||||
|
|
||||||
|
// isPortAvailable reports whether the callback port can currently be bound.
// NOTE(review): this probe is advisory only — the port can be taken between
// this check and the real listen (TOCTOU).
func (s *OAuthServer) isPortAvailable() bool {
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
	if err != nil {
		return false
	}
	_ = listener.Close()
	return true
}
|
||||||
|
|
||||||
|
// RedirectURL builds the local OAuth redirect URI for the given port; it must
// match the redirect URI registered with the GitLab application.
func RedirectURL(port int) string {
	return "http://localhost:" + strconv.Itoa(port) + "/auth/callback"
}
|
||||||
|
|
||||||
|
// GenerateAuthURL constructs the GitLab authorization-code URL with PKCE
// (S256). PKCE codes and a client ID are mandatory; other inputs are
// whitespace-trimmed as-is.
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
	switch {
	case pkce == nil:
		return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
	case strings.TrimSpace(clientID) == "":
		return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
	}

	query := url.Values{}
	query.Set("client_id", strings.TrimSpace(clientID))
	query.Set("response_type", "code")
	query.Set("redirect_uri", strings.TrimSpace(redirectURI))
	query.Set("scope", defaultOAuthScope)
	query.Set("state", strings.TrimSpace(state))
	query.Set("code_challenge", pkce.CodeChallenge)
	query.Set("code_challenge_method", "S256")
	return fmt.Sprintf("%s/oauth/authorize?%s", NormalizeBaseURL(baseURL), query.Encode()), nil
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokens redeems an authorization code (with its PKCE
// verifier) for an access/refresh token pair at /oauth/token. clientSecret is
// optional and omitted from the form when blank.
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", strings.TrimSpace(clientID))
	form.Set("code", strings.TrimSpace(code))
	form.Set("redirect_uri", strings.TrimSpace(redirectURI))
	form.Set("code_verifier", strings.TrimSpace(codeVerifier))
	if secret := strings.TrimSpace(clientSecret); secret != "" {
		form.Set("client_secret", secret)
	}
	return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
|
||||||
|
|
||||||
|
// RefreshTokens exchanges a refresh token for a new token pair at
// /oauth/token. clientID and clientSecret are optional and omitted from the
// form when blank.
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", strings.TrimSpace(refreshToken))
	if id := strings.TrimSpace(clientID); id != "" {
		form.Set("client_id", id)
	}
	if secret := strings.TrimSpace(clientSecret); secret != "" {
		form.Set("client_secret", secret)
	}
	return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
|
||||||
|
|
||||||
|
// postToken POSTs a form-encoded token request to tokenURL and decodes the
// JSON TokenResponse. Non-2xx responses are returned as errors that include
// the raw body for diagnostics.
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("gitlab token request failed: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("gitlab token request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Read the full body first so error responses can be echoed verbatim.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("gitlab token response read failed: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	var token TokenResponse
	if err := json.Unmarshal(body, &token); err != nil {
		return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
	}
	return &token, nil
}
|
||||||
|
|
||||||
|
// GetCurrentUser fetches the authenticated user's profile from
// GET /api/v4/user using the given bearer token.
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
	if err != nil {
		return nil, fmt.Errorf("gitlab user request failed: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("gitlab user request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Read the full body first so error responses can be echoed verbatim.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("gitlab user response read failed: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var user User
	if err := json.Unmarshal(body, &user); err != nil {
		return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
	}
	return &user, nil
}
|
||||||
|
|
||||||
|
// GetPersonalAccessTokenSelf describes the token used for the request via
// GET /api/v4/personal_access_tokens/self (PAT introspection).
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
	if err != nil {
		return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Read the full body first so error responses can be echoed verbatim.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var pat PersonalAccessTokenSelf
	if err := json.Unmarshal(body, &pat); err != nil {
		return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
	}
	return &pat, nil
}
|
||||||
|
|
||||||
|
// FetchDirectAccess requests short-lived direct model-gateway credentials via
// POST /api/v4/code_suggestions/direct_access. The returned Headers map is
// never nil.
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
	if err != nil {
		return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
	req.Header.Set("Accept", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Read the full body first so error responses can be echoed verbatim.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var direct DirectAccessResponse
	if err := json.Unmarshal(body, &direct); err != nil {
		return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
	}
	// Normalize so callers can add headers without a nil check.
	if direct.Headers == nil {
		direct.Headers = make(map[string]string)
	}
	return &direct, nil
}
|
||||||
|
|
||||||
|
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]DiscoveredModel, 0, 4)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
appendModel := func(provider, name string) {
|
||||||
|
provider = strings.TrimSpace(provider)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := strings.ToLower(name)
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
models = append(models, DiscoveredModel{
|
||||||
|
ModelProvider: provider,
|
||||||
|
ModelName: name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw, ok := metadata["model_details"]; ok {
|
||||||
|
appendDiscoveredModels(raw, appendModel)
|
||||||
|
}
|
||||||
|
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||||
|
|
||||||
|
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||||
|
if raw, ok := metadata[key]; ok {
|
||||||
|
appendDiscoveredModels(raw, appendModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||||
|
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||||
|
if nested, ok := typed["models"]; ok {
|
||||||
|
appendDiscoveredModels(nested, appendModel)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendDiscoveredModels(item, appendModel)
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendModel("", item)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
appendModel("", typed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringValue(raw any) string {
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
case fmt.Stringer:
|
||||||
|
return strings.TrimSpace(typed.String())
|
||||||
|
case json.Number:
|
||||||
|
return typed.String()
|
||||||
|
case int:
|
||||||
|
return strconv.Itoa(typed)
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(typed, 10)
|
||||||
|
case float64:
|
||||||
|
return strconv.FormatInt(int64(typed), 10)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
138
internal/auth/gitlab/gitlab_test.go
Normal file
138
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package gitlab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
|
||||||
|
client := NewAuthClient(nil)
|
||||||
|
pkce, err := GeneratePKCECodes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCECodes() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateAuthURL() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse(authURL) error = %v", err)
|
||||||
|
}
|
||||||
|
if got := parsed.Path; got != "/oauth/authorize" {
|
||||||
|
t.Fatalf("expected /oauth/authorize path, got %q", got)
|
||||||
|
}
|
||||||
|
query := parsed.Query()
|
||||||
|
if got := query.Get("client_id"); got != "client-id" {
|
||||||
|
t.Fatalf("expected client_id, got %q", got)
|
||||||
|
}
|
||||||
|
if got := query.Get("scope"); got != defaultOAuthScope {
|
||||||
|
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
|
||||||
|
}
|
||||||
|
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||||
|
t.Fatalf("expected PKCE method S256, got %q", got)
|
||||||
|
}
|
||||||
|
if got := query.Get("code_challenge"); got == "" {
|
||||||
|
t.Fatal("expected non-empty code_challenge")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/token" {
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("ParseForm() error = %v", err)
|
||||||
|
}
|
||||||
|
if got := r.Form.Get("grant_type"); got != "authorization_code" {
|
||||||
|
t.Fatalf("expected authorization_code grant, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
|
||||||
|
t.Fatalf("expected code_verifier, got %q", got)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"access_token": "oauth-access",
|
||||||
|
"refresh_token": "oauth-refresh",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "api read_user",
|
||||||
|
"created_at": 1710000000,
|
||||||
|
"expires_in": 3600,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewAuthClient(nil)
|
||||||
|
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
|
||||||
|
}
|
||||||
|
if token.AccessToken != "oauth-access" {
|
||||||
|
t.Fatalf("expected access token, got %q", token.AccessToken)
|
||||||
|
}
|
||||||
|
if token.RefreshToken != "oauth-refresh" {
|
||||||
|
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractDiscoveredModels(t *testing.T) {
|
||||||
|
models := ExtractDiscoveredModels(map[string]any{
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
"supported_models": []any{
|
||||||
|
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(models) != 2 {
|
||||||
|
t.Fatalf("expected 2 unique models, got %d", len(models))
|
||||||
|
}
|
||||||
|
if models[0].ModelName != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("unexpected first model %q", models[0].ModelName)
|
||||||
|
}
|
||||||
|
if models[1].ModelName != "gpt-4.1" {
|
||||||
|
t.Fatalf("unexpected second model %q", models[1].ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
|
||||||
|
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
|
||||||
|
t.Fatalf("expected bearer token, got %q", got)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"base_url": "https://cloud.gitlab.example.com",
|
||||||
|
"token": "gateway-token",
|
||||||
|
"expires_at": 1710003600,
|
||||||
|
"headers": map[string]string{
|
||||||
|
"X-Gitlab-Realm": "saas",
|
||||||
|
},
|
||||||
|
"model_details": map[string]any{
|
||||||
|
"model_provider": "anthropic",
|
||||||
|
"model_name": "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewAuthClient(nil)
|
||||||
|
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FetchDirectAccess() error = %v", err)
|
||||||
|
}
|
||||||
|
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
|
||||||
|
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -102,10 +102,24 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
|||||||
|
|
||||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||||
|
return NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg, deviceID, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceFlowClientWithDeviceIDAndProxyURL creates a new device flow client with a proxy override.
|
||||||
|
// proxyURL takes precedence over cfg.ProxyURL when non-empty.
|
||||||
|
func NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg *config.Config, deviceID string, proxyURL string) *DeviceFlowClient {
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg config.SDKConfig
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
sdkCfg = cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
sdkCfg.ProxyURL = effectiveProxyURL
|
||||||
|
client = util.SetProxy(&sdkCfg, client)
|
||||||
|
|
||||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||||
if resolvedDeviceID == "" {
|
if resolvedDeviceID == "" {
|
||||||
resolvedDeviceID = getOrCreateDeviceID()
|
resolvedDeviceID = getOrCreateDeviceID()
|
||||||
|
|||||||
42
internal/auth/kimi/kimi_proxy_test.go
Normal file
42
internal/auth/kimi/kimi_proxy_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package kimi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDeviceFlowClientWithDeviceIDAndProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"}}
|
||||||
|
client := NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg, "device-1", "direct")
|
||||||
|
|
||||||
|
transport, ok := client.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", client.httpClient.Transport)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDeviceFlowClientWithDeviceIDAndProxyURL_OverrideProxyTakesPrecedence(t *testing.T) {
|
||||||
|
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://global.example.com:8080"}}
|
||||||
|
client := NewDeviceFlowClientWithDeviceIDAndProxyURL(cfg, "device-1", "http://override.example.com:8081")
|
||||||
|
|
||||||
|
transport, ok := client.httpClient.Transport.(*http.Transport)
|
||||||
|
if !ok || transport == nil {
|
||||||
|
t.Fatalf("expected http.Transport, got %T", client.httpClient.Transport)
|
||||||
|
}
|
||||||
|
req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errReq != nil {
|
||||||
|
t.Fatalf("new request: %v", errReq)
|
||||||
|
}
|
||||||
|
proxyURL, errProxy := transport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("proxy func: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://override.example.com:8081" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://override.example.com:8081", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -748,4 +748,3 @@ func TestExtractRegionFromMetadata(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CooldownReason429 = "rate_limit_exceeded"
|
CooldownReason429 = "rate_limit_exceeded"
|
||||||
CooldownReasonSuspended = "account_suspended"
|
CooldownReasonSuspended = "account_suspended"
|
||||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||||
|
|
||||||
DefaultShortCooldown = 1 * time.Minute
|
DefaultShortCooldown = 1 * time.Minute
|
||||||
|
|||||||
@@ -26,9 +26,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
jitterRand *rand.Rand
|
jitterRand *rand.Rand
|
||||||
jitterRandOnce sync.Once
|
jitterRandOnce sync.Once
|
||||||
jitterMu sync.Mutex
|
jitterMu sync.Mutex
|
||||||
lastRequestTime time.Time
|
lastRequestTime time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ type TokenScorer struct {
|
|||||||
metrics map[string]*TokenMetrics
|
metrics map[string]*TokenMetrics
|
||||||
|
|
||||||
// Scoring weights
|
// Scoring weights
|
||||||
successRateWeight float64
|
successRateWeight float64
|
||||||
quotaWeight float64
|
quotaWeight float64
|
||||||
latencyWeight float64
|
latencyWeight float64
|
||||||
lastUsedWeight float64
|
lastUsedWeight float64
|
||||||
failPenaltyMultiplier float64
|
failPenaltyMultiplier float64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,359 +0,0 @@
|
|||||||
package qwen
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// QwenOAuthDeviceCodeEndpoint is the URL for initiating the OAuth 2.0 device authorization flow.
|
|
||||||
QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code"
|
|
||||||
// QwenOAuthTokenEndpoint is the URL for exchanging device codes or refresh tokens for access tokens.
|
|
||||||
QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token"
|
|
||||||
// QwenOAuthClientID is the client identifier for the Qwen OAuth 2.0 application.
|
|
||||||
QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56"
|
|
||||||
// QwenOAuthScope defines the permissions requested by the application.
|
|
||||||
QwenOAuthScope = "openid profile email model.completion"
|
|
||||||
// QwenOAuthGrantType specifies the grant type for the device code flow.
|
|
||||||
QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QwenTokenData represents the OAuth credentials, including access and refresh tokens.
|
|
||||||
type QwenTokenData struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
// RefreshToken is used to obtain a new access token when the current one expires.
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
// TokenType indicates the type of token, typically "Bearer".
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
// ResourceURL specifies the base URL of the resource server.
|
|
||||||
ResourceURL string `json:"resource_url,omitempty"`
|
|
||||||
// Expire indicates the expiration date and time of the access token.
|
|
||||||
Expire string `json:"expiry_date,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceFlow represents the response from the device authorization endpoint.
|
|
||||||
type DeviceFlow struct {
|
|
||||||
// DeviceCode is the code that the client uses to poll for an access token.
|
|
||||||
DeviceCode string `json:"device_code"`
|
|
||||||
// UserCode is the code that the user enters at the verification URI.
|
|
||||||
UserCode string `json:"user_code"`
|
|
||||||
// VerificationURI is the URL where the user can enter the user code to authorize the device.
|
|
||||||
VerificationURI string `json:"verification_uri"`
|
|
||||||
// VerificationURIComplete is a URI that includes the user_code, which can be used to automatically
|
|
||||||
// fill in the code on the verification page.
|
|
||||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
|
||||||
// ExpiresIn is the time in seconds until the device_code and user_code expire.
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
// Interval is the minimum time in seconds that the client should wait between polling requests.
|
|
||||||
Interval int `json:"interval"`
|
|
||||||
// CodeVerifier is the cryptographically random string used in the PKCE flow.
|
|
||||||
CodeVerifier string `json:"code_verifier"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// QwenTokenResponse represents the successful token response from the token endpoint.
|
|
||||||
type QwenTokenResponse struct {
|
|
||||||
// AccessToken is the token used to access protected resources.
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
// RefreshToken is used to obtain a new access token.
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
// TokenType indicates the type of token, typically "Bearer".
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
// ResourceURL specifies the base URL of the resource server.
|
|
||||||
ResourceURL string `json:"resource_url,omitempty"`
|
|
||||||
// ExpiresIn is the time in seconds until the access token expires.
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// QwenAuth manages authentication and token handling for the Qwen API.
|
|
||||||
type QwenAuth struct {
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQwenAuth creates a new QwenAuth instance with a proxy-configured HTTP client.
|
|
||||||
func NewQwenAuth(cfg *config.Config) *QwenAuth {
|
|
||||||
return &QwenAuth{
|
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCodeVerifier generates a cryptographically random string for the PKCE code verifier.
|
|
||||||
func (qa *QwenAuth) generateCodeVerifier() (string, error) {
|
|
||||||
bytes := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCodeChallenge creates a SHA-256 hash of the code verifier, used as the PKCE code challenge.
|
|
||||||
func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string {
|
|
||||||
hash := sha256.Sum256([]byte(codeVerifier))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(hash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// generatePKCEPair creates a new code verifier and its corresponding code challenge for PKCE.
|
|
||||||
func (qa *QwenAuth) generatePKCEPair() (string, string, error) {
|
|
||||||
codeVerifier, err := qa.generateCodeVerifier()
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
codeChallenge := qa.generateCodeChallenge(codeVerifier)
|
|
||||||
return codeVerifier, codeChallenge, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshTokens exchanges a refresh token for a new access token.
|
|
||||||
func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("grant_type", "refresh_token")
|
|
||||||
data.Set("refresh_token", refreshToken)
|
|
||||||
data.Set("client_id", QwenOAuthClientID)
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := qa.httpClient.Do(req)
|
|
||||||
|
|
||||||
// resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
var errorData map[string]interface{}
|
|
||||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
|
||||||
return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"])
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokenData QwenTokenResponse
|
|
||||||
if err = json.Unmarshal(body, &tokenData); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &QwenTokenData{
|
|
||||||
AccessToken: tokenData.AccessToken,
|
|
||||||
TokenType: tokenData.TokenType,
|
|
||||||
RefreshToken: tokenData.RefreshToken,
|
|
||||||
ResourceURL: tokenData.ResourceURL,
|
|
||||||
Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details.
|
|
||||||
func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) {
|
|
||||||
// Generate PKCE code verifier and challenge
|
|
||||||
codeVerifier, codeChallenge, err := qa.generatePKCEPair()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate PKCE pair: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("client_id", QwenOAuthClientID)
|
|
||||||
data.Set("scope", QwenOAuthScope)
|
|
||||||
data.Set("code_challenge", codeChallenge)
|
|
||||||
data.Set("code_challenge_method", "S256")
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := qa.httpClient.Do(req)
|
|
||||||
|
|
||||||
// resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("device authorization request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
var result DeviceFlow
|
|
||||||
if err = json.Unmarshal(body, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse device flow response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the response indicates success
|
|
||||||
if result.DeviceCode == "" {
|
|
||||||
return nil, fmt.Errorf("device authorization failed: device_code not found in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the code_verifier to the result so it can be used later for polling
|
|
||||||
result.CodeVerifier = codeVerifier
|
|
||||||
|
|
||||||
return &result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PollForToken polls the token endpoint with the device code to obtain an access token.
|
|
||||||
func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) {
|
|
||||||
pollInterval := 5 * time.Second
|
|
||||||
maxAttempts := 60 // 5 minutes max
|
|
||||||
|
|
||||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("grant_type", QwenOAuthGrantType)
|
|
||||||
data.Set("client_id", QwenOAuthClientID)
|
|
||||||
data.Set("device_code", deviceCode)
|
|
||||||
data.Set("code_verifier", codeVerifier)
|
|
||||||
|
|
||||||
resp, err := http.PostForm(QwenOAuthTokenEndpoint, data)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
// Parse the response as JSON to check for OAuth RFC 8628 standard errors
|
|
||||||
var errorData map[string]interface{}
|
|
||||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
|
||||||
// According to OAuth RFC 8628, handle standard polling responses
|
|
||||||
if resp.StatusCode == http.StatusBadRequest {
|
|
||||||
errorType, _ := errorData["error"].(string)
|
|
||||||
switch errorType {
|
|
||||||
case "authorization_pending":
|
|
||||||
// User has not yet approved the authorization request. Continue polling.
|
|
||||||
fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
case "slow_down":
|
|
||||||
// Client is polling too frequently. Increase poll interval.
|
|
||||||
pollInterval = time.Duration(float64(pollInterval) * 1.5)
|
|
||||||
if pollInterval > 10*time.Second {
|
|
||||||
pollInterval = 10 * time.Second
|
|
||||||
}
|
|
||||||
fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval)
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
continue
|
|
||||||
case "expired_token":
|
|
||||||
return nil, fmt.Errorf("device code expired. Please restart the authentication process")
|
|
||||||
case "access_denied":
|
|
||||||
return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For other errors, return with proper error information
|
|
||||||
errorType, _ := errorData["error"].(string)
|
|
||||||
errorDesc, _ := errorData["error_description"].(string)
|
|
||||||
return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If JSON parsing fails, fall back to text response
|
|
||||||
return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
|
||||||
}
|
|
||||||
// log.Debugf("%s", string(body))
|
|
||||||
// Success - parse token data
|
|
||||||
var response QwenTokenResponse
|
|
||||||
if err = json.Unmarshal(body, &response); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to QwenTokenData format and save
|
|
||||||
tokenData := &QwenTokenData{
|
|
||||||
AccessToken: response.AccessToken,
|
|
||||||
RefreshToken: response.RefreshToken,
|
|
||||||
TokenType: response.TokenType,
|
|
||||||
ResourceURL: response.ResourceURL,
|
|
||||||
Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("authentication timeout. Please restart the authentication process")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure.
|
|
||||||
func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) {
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
||||||
if attempt > 0 {
|
|
||||||
// Wait before retry
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-time.After(time.Duration(attempt) * time.Second):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
|
||||||
if err == nil {
|
|
||||||
return tokenData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
lastErr = err
|
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTokenStorage creates a QwenTokenStorage object from a QwenTokenData object.
|
|
||||||
func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage {
|
|
||||||
storage := &QwenTokenStorage{
|
|
||||||
AccessToken: tokenData.AccessToken,
|
|
||||||
RefreshToken: tokenData.RefreshToken,
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
ResourceURL: tokenData.ResourceURL,
|
|
||||||
Expire: tokenData.Expire,
|
|
||||||
}
|
|
||||||
|
|
||||||
return storage
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing token storage with new token data
|
|
||||||
func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) {
|
|
||||||
storage.AccessToken = tokenData.AccessToken
|
|
||||||
storage.RefreshToken = tokenData.RefreshToken
|
|
||||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
|
||||||
storage.ResourceURL = tokenData.ResourceURL
|
|
||||||
storage.Expire = tokenData.Expire
|
|
||||||
}
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
// Package qwen provides authentication and token management functionality
|
|
||||||
// for Alibaba's Qwen AI services. It handles OAuth2 token storage, serialization,
|
|
||||||
// and retrieval for maintaining authenticated sessions with the Qwen API.
|
|
||||||
package qwen
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication.
|
|
||||||
// It maintains compatibility with the existing auth system while adding Qwen-specific fields
|
|
||||||
// for managing access tokens, refresh tokens, and user account information.
|
|
||||||
type QwenTokenStorage struct {
|
|
||||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
// RefreshToken is used to obtain new access tokens when the current one expires.
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
// LastRefresh is the timestamp of the last token refresh operation.
|
|
||||||
LastRefresh string `json:"last_refresh"`
|
|
||||||
// ResourceURL is the base URL for API requests.
|
|
||||||
ResourceURL string `json:"resource_url"`
|
|
||||||
// Email is the Qwen account email address associated with this token.
|
|
||||||
Email string `json:"email"`
|
|
||||||
// Type indicates the authentication provider type, always "qwen" for this storage.
|
|
||||||
Type string `json:"type"`
|
|
||||||
// Expire is the timestamp when the current access token expires.
|
|
||||||
Expire string `json:"expired"`
|
|
||||||
|
|
||||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
|
||||||
// It is not exported to JSON directly to allow flattening during serialization.
|
|
||||||
Metadata map[string]any `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
|
||||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
|
||||||
ts.Metadata = meta
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
|
||||||
// This method creates the necessary directory structure and writes the token
|
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
|
||||||
// It merges any injected metadata into the top-level JSON object.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - authFilePath: The full path where the token file should be saved
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
// - error: An error if the operation fails, nil otherwise
|
|
||||||
func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|
||||||
misc.LogSavingCredentials(authFilePath)
|
|
||||||
ts.Type = "qwen"
|
|
||||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.Create(authFilePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create token file: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = f.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Merge metadata using helper
|
|
||||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
|
||||||
if errMerge != nil {
|
|
||||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -30,6 +30,10 @@ type VertexCredentialStorage struct {
|
|||||||
|
|
||||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Prefix optionally namespaces models for this credential (e.g., "teamA").
|
||||||
|
// This results in model names like "teamA/gemini-2.0-flash".
|
||||||
|
Prefix string `json:"prefix,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||||
|
|||||||
45
internal/cache/signature_cache.go
vendored
45
internal/cache/signature_cache.go
vendored
@@ -5,7 +5,10 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SignatureEntry holds a cached thinking signature with timestamp
|
// SignatureEntry holds a cached thinking signature with timestamp
|
||||||
@@ -193,3 +196,45 @@ func GetModelGroup(modelName string) string {
|
|||||||
}
|
}
|
||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var signatureCacheEnabled atomic.Bool
|
||||||
|
var signatureBypassStrictMode atomic.Bool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
signatureCacheEnabled.Store(true)
|
||||||
|
signatureBypassStrictMode.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSignatureCacheEnabled switches Antigravity signature handling between cache mode and bypass mode.
|
||||||
|
func SetSignatureCacheEnabled(enabled bool) {
|
||||||
|
previous := signatureCacheEnabled.Swap(enabled)
|
||||||
|
if previous == enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !enabled {
|
||||||
|
log.Info("antigravity signature cache DISABLED - bypass mode active, cached signatures will not be used for request translation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureCacheEnabled returns whether signature cache validation is enabled.
|
||||||
|
func SignatureCacheEnabled() bool {
|
||||||
|
return signatureCacheEnabled.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSignatureBypassStrictMode controls whether bypass mode uses strict protobuf-tree validation.
|
||||||
|
func SetSignatureBypassStrictMode(strict bool) {
|
||||||
|
previous := signatureBypassStrictMode.Swap(strict)
|
||||||
|
if previous == strict {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strict {
|
||||||
|
log.Debug("antigravity bypass signature validation: strict mode (protobuf tree)")
|
||||||
|
} else {
|
||||||
|
log.Debug("antigravity bypass signature validation: basic mode (R/E + 0x12)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureBypassStrictMode returns whether bypass mode uses strict protobuf-tree validation.
|
||||||
|
func SignatureBypassStrictMode() bool {
|
||||||
|
return signatureBypassStrictMode.Load()
|
||||||
|
}
|
||||||
|
|||||||
91
internal/cache/signature_cache_test.go
vendored
91
internal/cache/signature_cache_test.go
vendored
@@ -1,8 +1,12 @@
|
|||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testModelName = "claude-sonnet-4-5"
|
const testModelName = "claude-sonnet-4-5"
|
||||||
@@ -208,3 +212,90 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
|||||||
// but the logic is verified by the implementation
|
// but the logic is verified by the implementation
|
||||||
_ = time.Now() // Acknowledge we're not testing time passage
|
_ = time.Now() // Acknowledge we're not testing time passage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSignatureModeSetters_LogAtInfoLevel(t *testing.T) {
|
||||||
|
logger := log.StandardLogger()
|
||||||
|
previousOutput := logger.Out
|
||||||
|
previousLevel := logger.Level
|
||||||
|
previousCache := SignatureCacheEnabled()
|
||||||
|
previousStrict := SignatureBypassStrictMode()
|
||||||
|
SetSignatureCacheEnabled(true)
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
log.SetOutput(buffer)
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
log.SetOutput(previousOutput)
|
||||||
|
log.SetLevel(previousLevel)
|
||||||
|
SetSignatureCacheEnabled(previousCache)
|
||||||
|
SetSignatureBypassStrictMode(previousStrict)
|
||||||
|
})
|
||||||
|
|
||||||
|
SetSignatureCacheEnabled(false)
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
|
||||||
|
output := buffer.String()
|
||||||
|
if !strings.Contains(output, "antigravity signature cache DISABLED") {
|
||||||
|
t.Fatalf("expected info output for disabling signature cache, got: %q", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "strict mode (protobuf tree)") {
|
||||||
|
t.Fatalf("expected strict bypass mode log to stay below info level, got: %q", output)
|
||||||
|
}
|
||||||
|
if strings.Contains(output, "basic mode (R/E + 0x12)") {
|
||||||
|
t.Fatalf("expected basic bypass mode log to stay below info level, got: %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignatureModeSetters_DoNotRepeatSameStateLogs(t *testing.T) {
|
||||||
|
logger := log.StandardLogger()
|
||||||
|
previousOutput := logger.Out
|
||||||
|
previousLevel := logger.Level
|
||||||
|
previousCache := SignatureCacheEnabled()
|
||||||
|
previousStrict := SignatureBypassStrictMode()
|
||||||
|
SetSignatureCacheEnabled(false)
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
log.SetOutput(buffer)
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
log.SetOutput(previousOutput)
|
||||||
|
log.SetLevel(previousLevel)
|
||||||
|
SetSignatureCacheEnabled(previousCache)
|
||||||
|
SetSignatureBypassStrictMode(previousStrict)
|
||||||
|
})
|
||||||
|
|
||||||
|
SetSignatureCacheEnabled(false)
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
|
||||||
|
if buffer.Len() != 0 {
|
||||||
|
t.Fatalf("expected repeated setter calls with unchanged state to stay silent, got: %q", buffer.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignatureBypassStrictMode_LogsAtDebugLevel(t *testing.T) {
|
||||||
|
logger := log.StandardLogger()
|
||||||
|
previousOutput := logger.Out
|
||||||
|
previousLevel := logger.Level
|
||||||
|
previousStrict := SignatureBypassStrictMode()
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
buffer := &bytes.Buffer{}
|
||||||
|
log.SetOutput(buffer)
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
log.SetOutput(previousOutput)
|
||||||
|
log.SetLevel(previousLevel)
|
||||||
|
SetSignatureBypassStrictMode(previousStrict)
|
||||||
|
})
|
||||||
|
|
||||||
|
SetSignatureBypassStrictMode(true)
|
||||||
|
SetSignatureBypassStrictMode(false)
|
||||||
|
|
||||||
|
output := buffer.String()
|
||||||
|
if !strings.Contains(output, "strict mode (protobuf tree)") {
|
||||||
|
t.Fatalf("expected debug output for strict bypass mode, got: %q", output)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "basic mode (R/E + 0x12)") {
|
||||||
|
t.Fatalf("expected debug output for basic bypass mode, got: %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// newAuthManager creates a new authentication manager instance with all supported
|
// newAuthManager creates a new authentication manager instance with all supported
|
||||||
// authenticators and a file-based token store. It initializes authenticators for
|
// authenticators and a file-based token store.
|
||||||
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||||
@@ -16,13 +15,14 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewGeminiAuthenticator(),
|
sdkAuth.NewGeminiAuthenticator(),
|
||||||
sdkAuth.NewCodexAuthenticator(),
|
sdkAuth.NewCodexAuthenticator(),
|
||||||
sdkAuth.NewClaudeAuthenticator(),
|
sdkAuth.NewClaudeAuthenticator(),
|
||||||
sdkAuth.NewQwenAuthenticator(),
|
|
||||||
sdkAuth.NewIFlowAuthenticator(),
|
|
||||||
sdkAuth.NewAntigravityAuthenticator(),
|
sdkAuth.NewAntigravityAuthenticator(),
|
||||||
sdkAuth.NewKimiAuthenticator(),
|
sdkAuth.NewKimiAuthenticator(),
|
||||||
sdkAuth.NewKiroAuthenticator(),
|
sdkAuth.NewKiroAuthenticator(),
|
||||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||||
sdkAuth.NewKiloAuthenticator(),
|
sdkAuth.NewKiloAuthenticator(),
|
||||||
|
sdkAuth.NewGitLabAuthenticator(),
|
||||||
|
sdkAuth.NewCodeBuddyAuthenticator(),
|
||||||
|
sdkAuth.NewCursorAuthenticator(),
|
||||||
)
|
)
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|||||||
43
internal/cmd/codebuddy_login.go
Normal file
43
internal/cmd/codebuddy_login.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
|
||||||
|
// It initiates the OAuth authentication, displays the user code for the user to enter
|
||||||
|
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy and auth directory settings
|
||||||
|
// - options: Login options including browser behavior settings
|
||||||
|
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("CodeBuddy authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("CodeBuddy authentication successful!")
|
||||||
|
}
|
||||||
37
internal/cmd/cursor_login.go
Normal file
37
internal/cmd/cursor_login.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
|
||||||
|
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Cursor authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
log.Infof("Authentication saved to %s", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
log.Infof("Authenticated as %s", record.Label)
|
||||||
|
}
|
||||||
|
log.Info("Cursor authentication successful!")
|
||||||
|
}
|
||||||
69
internal/cmd/gitlab_login.go
Normal file
69
internal/cmd/gitlab_login.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"login_mode": "oauth",
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("GitLab Duo authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
Metadata: map[string]string{
|
||||||
|
"login_mode": "pat",
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("GitLab Duo PAT authentication successful!")
|
||||||
|
}
|
||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -27,11 +28,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
geminiCLIVersion = "v1internal"
|
geminiCLIVersion = "v1internal"
|
||||||
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
|
||||||
geminiCLIApiClient = "gl-node/22.17.0"
|
|
||||||
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type projectSelectionRequiredError struct{}
|
type projectSelectionRequiredError struct{}
|
||||||
@@ -409,9 +407,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
|
|||||||
return fmt.Errorf("create request: %w", errRequest)
|
return fmt.Errorf("create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
|
||||||
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
@@ -630,7 +626,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||||
@@ -651,7 +647,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
|
||||||
resp, errDo = httpClient.Do(req)
|
resp, errDo = httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DoQwenLogin handles the Qwen device flow using the shared authentication manager.
|
|
||||||
// It initiates the device-based authentication process for Qwen services and saves
|
|
||||||
// the authentication tokens to the configured auth directory.
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// - cfg: The application configuration
|
|
||||||
// - options: Login options including browser behavior and prompts
|
|
||||||
func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
|
||||||
if options == nil {
|
|
||||||
options = &LoginOptions{}
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := newAuthManager()
|
|
||||||
|
|
||||||
promptFn := options.Prompt
|
|
||||||
if promptFn == nil {
|
|
||||||
promptFn = func(prompt string) (string, error) {
|
|
||||||
fmt.Println()
|
|
||||||
fmt.Println(prompt)
|
|
||||||
var value string
|
|
||||||
_, err := fmt.Scanln(&value)
|
|
||||||
return value, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
|
||||||
NoBrowser: options.NoBrowser,
|
|
||||||
CallbackPort: options.CallbackPort,
|
|
||||||
Metadata: map[string]string{},
|
|
||||||
Prompt: promptFn,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
|
||||||
if err != nil {
|
|
||||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
|
||||||
log.Error(emailErr.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Printf("Qwen authentication failed: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if savedPath != "" {
|
|
||||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Qwen authentication successful!")
|
|
||||||
}
|
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
||||||
// it as a "vertex" provider credential. The file content is embedded in the auth
|
// it as a "vertex" provider credential. The file content is embedded in the auth
|
||||||
// file to allow portable deployment across stores.
|
// file to allow portable deployment across stores.
|
||||||
func DoVertexImport(cfg *config.Config, keyPath string) {
|
func DoVertexImport(cfg *config.Config, keyPath string, prefix string) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = &config.Config{}
|
cfg = &config.Config{}
|
||||||
}
|
}
|
||||||
@@ -62,13 +62,28 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
// Default location if not provided by user. Can be edited in the saved file later.
|
// Default location if not provided by user. Can be edited in the saved file later.
|
||||||
location := "us-central1"
|
location := "us-central1"
|
||||||
|
|
||||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
|
// Normalize and validate prefix: must be a single segment (no "/" allowed).
|
||||||
|
prefix = strings.TrimSpace(prefix)
|
||||||
|
prefix = strings.Trim(prefix, "/")
|
||||||
|
if prefix != "" && strings.Contains(prefix, "/") {
|
||||||
|
log.Errorf("vertex-import: prefix must be a single segment (no '/' allowed): %q", prefix)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include prefix in filename so importing the same project with different
|
||||||
|
// prefixes creates separate credential files instead of overwriting.
|
||||||
|
baseName := sanitizeFilePart(projectID)
|
||||||
|
if prefix != "" {
|
||||||
|
baseName = sanitizeFilePart(prefix) + "-" + baseName
|
||||||
|
}
|
||||||
|
fileName := fmt.Sprintf("vertex-%s.json", baseName)
|
||||||
// Build auth record
|
// Build auth record
|
||||||
storage := &vertex.VertexCredentialStorage{
|
storage := &vertex.VertexCredentialStorage{
|
||||||
ServiceAccount: sa,
|
ServiceAccount: sa,
|
||||||
ProjectID: projectID,
|
ProjectID: projectID,
|
||||||
Email: email,
|
Email: email,
|
||||||
Location: location,
|
Location: location,
|
||||||
|
Prefix: prefix,
|
||||||
}
|
}
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"service_account": sa,
|
"service_account": sa,
|
||||||
@@ -76,6 +91,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
"email": email,
|
"email": email,
|
||||||
"location": location,
|
"location": location,
|
||||||
"type": "vertex",
|
"type": "vertex",
|
||||||
|
"prefix": prefix,
|
||||||
"label": labelForVertex(projectID, email),
|
"label": labelForVertex(projectID, email),
|
||||||
}
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
|
|||||||
55
internal/config/claude_header_defaults_test.go
Normal file
55
internal/config/claude_header_defaults_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
configYAML := []byte(`
|
||||||
|
claude-header-defaults:
|
||||||
|
user-agent: " claude-cli/2.1.70 (external, cli) "
|
||||||
|
package-version: " 0.80.0 "
|
||||||
|
runtime-version: " v24.5.0 "
|
||||||
|
os: " MacOS "
|
||||||
|
arch: " arm64 "
|
||||||
|
timeout: " 900 "
|
||||||
|
stabilize-device-profile: false
|
||||||
|
`)
|
||||||
|
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := LoadConfigOptional(configPath, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
|
||||||
|
t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
|
||||||
|
t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
|
||||||
|
t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
|
||||||
|
t.Fatalf("OS = %q, want %q", got, "MacOS")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
|
||||||
|
t.Fatalf("Arch = %q, want %q", got, "arm64")
|
||||||
|
}
|
||||||
|
if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
|
||||||
|
t.Fatalf("Timeout = %q, want %q", got, "900")
|
||||||
|
}
|
||||||
|
if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
|
||||||
|
t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
|
||||||
|
}
|
||||||
|
if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
|
||||||
|
t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
32
internal/config/codex_websocket_header_defaults_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
configPath := filepath.Join(dir, "config.yaml")
|
||||||
|
configYAML := []byte(`
|
||||||
|
codex-header-defaults:
|
||||||
|
user-agent: " my-codex-client/1.0 "
|
||||||
|
beta-features: " feature-a,feature-b "
|
||||||
|
`)
|
||||||
|
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := LoadConfigOptional(configPath, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfigOptional() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
|
||||||
|
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
|
||||||
|
}
|
||||||
|
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
|
||||||
|
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -67,6 +68,10 @@ type Config struct {
|
|||||||
// DisableCooling disables quota cooldown scheduling when true.
|
// DisableCooling disables quota cooldown scheduling when true.
|
||||||
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
|
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
|
||||||
|
|
||||||
|
// AuthAutoRefreshWorkers overrides the size of the core auth auto-refresh worker pool.
|
||||||
|
// When <= 0, the default worker count is used.
|
||||||
|
AuthAutoRefreshWorkers int `yaml:"auth-auto-refresh-workers" json:"auth-auto-refresh-workers"`
|
||||||
|
|
||||||
// RequestRetry defines the retry times when the request failed.
|
// RequestRetry defines the retry times when the request failed.
|
||||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||||
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||||
@@ -84,6 +89,13 @@ type Config struct {
|
|||||||
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||||
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||||
|
|
||||||
|
// AntigravitySignatureCacheEnabled controls whether signature cache validation is enabled for thinking blocks.
|
||||||
|
// When true (default), cached signatures are preferred and validated.
|
||||||
|
// When false, client signatures are used directly after normalization (bypass mode).
|
||||||
|
AntigravitySignatureCacheEnabled *bool `yaml:"antigravity-signature-cache-enabled,omitempty" json:"antigravity-signature-cache-enabled,omitempty"`
|
||||||
|
|
||||||
|
AntigravitySignatureBypassStrict *bool `yaml:"antigravity-signature-bypass-strict,omitempty" json:"antigravity-signature-bypass-strict,omitempty"`
|
||||||
|
|
||||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||||
|
|
||||||
@@ -101,6 +113,10 @@ type Config struct {
|
|||||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||||
|
|
||||||
|
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
|
||||||
|
// These are used only when the client does not send its own headers.
|
||||||
|
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
|
||||||
|
|
||||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||||
|
|
||||||
@@ -119,12 +135,12 @@ type Config struct {
|
|||||||
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
|
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
|
||||||
|
|
||||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||||
|
|
||||||
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||||
// These aliases affect both model listing and model routing for supported channels:
|
// These aliases affect both model listing and model routing for supported channels:
|
||||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
// gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
|
||||||
//
|
//
|
||||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||||
@@ -141,13 +157,27 @@ type Config struct {
|
|||||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
|
||||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
|
||||||
|
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
|
||||||
|
// profiles are enabled, OS/Arch become the pinned platform baseline, while
|
||||||
|
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
|
||||||
type ClaudeHeaderDefaults struct {
|
type ClaudeHeaderDefaults struct {
|
||||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||||
Timeout string `yaml:"timeout" json:"timeout"`
|
OS string `yaml:"os" json:"os"`
|
||||||
|
Arch string `yaml:"arch" json:"arch"`
|
||||||
|
Timeout string `yaml:"timeout" json:"timeout"`
|
||||||
|
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodexHeaderDefaults configures fallback header values injected into Codex
|
||||||
|
// model requests for OAuth/file-backed auth when the client omits them.
|
||||||
|
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
|
||||||
|
type CodexHeaderDefaults struct {
|
||||||
|
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||||
|
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig holds HTTPS server settings.
|
// TLSConfig holds HTTPS server settings.
|
||||||
@@ -176,6 +206,9 @@ type RemoteManagement struct {
|
|||||||
SecretKey string `yaml:"secret-key"`
|
SecretKey string `yaml:"secret-key"`
|
||||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||||
|
// DisableAutoUpdatePanel disables automatic periodic background updates of the management panel asset from GitHub.
|
||||||
|
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
|
||||||
|
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
|
||||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||||
@@ -189,6 +222,10 @@ type QuotaExceeded struct {
|
|||||||
|
|
||||||
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||||
|
|
||||||
|
// AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once
|
||||||
|
// on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"].
|
||||||
|
AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingConfig configures how credentials are selected for requests.
|
// RoutingConfig configures how credentials are selected for requests.
|
||||||
@@ -196,6 +233,22 @@ type RoutingConfig struct {
|
|||||||
// Strategy selects the credential selection strategy.
|
// Strategy selects the credential selection strategy.
|
||||||
// Supported values: "round-robin" (default), "fill-first".
|
// Supported values: "round-robin" (default), "fill-first".
|
||||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||||
|
|
||||||
|
// ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients.
|
||||||
|
// When enabled, requests with the same session ID (extracted from metadata.user_id)
|
||||||
|
// are routed to the same auth credential when available.
|
||||||
|
// Deprecated: Use SessionAffinity instead for universal session support.
|
||||||
|
ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"`
|
||||||
|
|
||||||
|
// SessionAffinity enables universal session-sticky routing for all clients.
|
||||||
|
// Session IDs are extracted from multiple sources:
|
||||||
|
// X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash.
|
||||||
|
// Automatic failover is always enabled when bound auth becomes unavailable.
|
||||||
|
SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"`
|
||||||
|
|
||||||
|
// SessionAffinityTTL specifies how long session-to-auth bindings are retained.
|
||||||
|
// Default: 1h. Accepts duration strings like "30m", "1h", "2h30m".
|
||||||
|
SessionAffinityTTL string `yaml:"session-affinity-ttl,omitempty" json:"session-affinity-ttl,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthModelAlias defines a model ID alias for a specific channel.
|
// OAuthModelAlias defines a model ID alias for a specific channel.
|
||||||
@@ -235,8 +288,8 @@ type AmpCode struct {
|
|||||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
// is used for the upstream Amp request.
|
||||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||||
|
|
||||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||||
@@ -358,6 +411,11 @@ type ClaudeKey struct {
|
|||||||
|
|
||||||
// Cloak configures request cloaking for non-Claude-Code clients.
|
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||||
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||||
|
|
||||||
|
// ExperimentalCCHSigning enables opt-in final-body cch signing for cloaked
|
||||||
|
// Claude /v1/messages requests. It is disabled by default so upstream seed
|
||||||
|
// changes do not alter the proxy's legacy behavior.
|
||||||
|
ExperimentalCCHSigning bool `yaml:"experimental-cch-signing,omitempty" json:"experimental-cch-signing,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||||
@@ -556,6 +614,10 @@ type OpenAICompatibilityModel struct {
|
|||||||
|
|
||||||
// Alias is the model name alias that clients will use to reference this model.
|
// Alias is the model name alias that clients will use to reference this model.
|
||||||
Alias string `yaml:"alias" json:"alias"`
|
Alias string `yaml:"alias" json:"alias"`
|
||||||
|
|
||||||
|
// Thinking configures the thinking/reasoning capability for this model.
|
||||||
|
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
|
||||||
|
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||||
@@ -579,16 +641,6 @@ func LoadConfig(configFile string) (*Config, error) {
|
|||||||
// If optional is true and the file is missing, it returns an empty Config.
|
// If optional is true and the file is missing, it returns an empty Config.
|
||||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
|
||||||
// Reason: avoid mutating config.yaml during server startup.
|
|
||||||
// Re-enable the block below if automatic startup migration is needed again.
|
|
||||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
|
||||||
// // Log warning but don't fail - config loading should still work
|
|
||||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
|
||||||
// } else if migrated {
|
|
||||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Read the entire configuration file into memory.
|
// Read the entire configuration file into memory.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -683,12 +735,18 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||||
cfg.SanitizeGeminiKeys()
|
cfg.SanitizeGeminiKeys()
|
||||||
|
|
||||||
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
// Sanitize Vertex-compatible API keys.
|
||||||
cfg.SanitizeVertexCompatKeys()
|
cfg.SanitizeVertexCompatKeys()
|
||||||
|
|
||||||
// Sanitize Codex keys: drop entries without base-url
|
// Sanitize Codex keys: drop entries without base-url
|
||||||
cfg.SanitizeCodexKeys()
|
cfg.SanitizeCodexKeys()
|
||||||
|
|
||||||
|
// Sanitize Codex header defaults.
|
||||||
|
cfg.SanitizeCodexHeaderDefaults()
|
||||||
|
|
||||||
|
// Sanitize Claude header defaults.
|
||||||
|
cfg.SanitizeClaudeHeaderDefaults()
|
||||||
|
|
||||||
// Sanitize Claude key headers
|
// Sanitize Claude key headers
|
||||||
cfg.SanitizeClaudeKeys()
|
cfg.SanitizeClaudeKeys()
|
||||||
|
|
||||||
@@ -781,6 +839,30 @@ func payloadRawString(value any) ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
|
||||||
|
// configured Codex header fallback values.
|
||||||
|
func (cfg *Config) SanitizeCodexHeaderDefaults() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
|
||||||
|
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
|
||||||
|
// configured Claude fingerprint baseline values.
|
||||||
|
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
|
||||||
|
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
|
||||||
|
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
|
||||||
|
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
|
||||||
|
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
|
||||||
|
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||||
@@ -926,6 +1008,7 @@ func (cfg *Config) SanitizeKiroKeys() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
||||||
|
// It uses API key + base URL as the uniqueness key.
|
||||||
func (cfg *Config) SanitizeGeminiKeys() {
|
func (cfg *Config) SanitizeGeminiKeys() {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return
|
return
|
||||||
@@ -944,10 +1027,11 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
|||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
if _, exists := seen[entry.APIKey]; exists {
|
uniqueKey := entry.APIKey + "|" + entry.BaseURL
|
||||||
|
if _, exists := seen[uniqueKey]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seen[entry.APIKey] = struct{}{}
|
seen[uniqueKey] = struct{}{}
|
||||||
out = append(out, entry)
|
out = append(out, entry)
|
||||||
}
|
}
|
||||||
cfg.GeminiKey = out
|
cfg.GeminiKey = out
|
||||||
@@ -1676,9 +1760,6 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
|
|||||||
srcIdx := findMapKeyIndex(srcRoot, key)
|
srcIdx := findMapKeyIndex(srcRoot, key)
|
||||||
if srcIdx < 0 {
|
if srcIdx < 0 {
|
||||||
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
||||||
//
|
|
||||||
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
|
|
||||||
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
|
|
||||||
// When users delete the last channel from oauth-model-alias via the management API,
|
// When users delete the last channel from oauth-model-alias via the management API,
|
||||||
// we want that deletion to persist across hot reloads and restarts.
|
// we want that deletion to persist across hot reloads and restarts.
|
||||||
if key == "oauth-model-alias" {
|
if key == "oauth-model-alias" {
|
||||||
|
|||||||
61
internal/config/oauth_model_alias_defaults.go
Normal file
61
internal/config/oauth_model_alias_defaults.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// defaultKiroAliases returns default oauth-model-alias entries for Kiro.
|
||||||
|
// These aliases expose standard Claude IDs for Kiro-prefixed upstream models.
|
||||||
|
func defaultKiroAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
// Sonnet 4.6
|
||||||
|
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
|
// Sonnet 4.5
|
||||||
|
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||||
|
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||||
|
// Sonnet 4
|
||||||
|
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||||
|
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||||
|
// Opus 4.6
|
||||||
|
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||||
|
// Opus 4.5
|
||||||
|
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||||
|
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
// Haiku 4.5
|
||||||
|
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||||
|
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultGitHubCopilotAliases returns default oauth-model-alias entries for
|
||||||
|
// GitHub Copilot Claude models. It exposes hyphen-style IDs used by clients.
|
||||||
|
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||||
|
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHubCopilotAliasesFromModels generates oauth-model-alias entries from a dynamic
|
||||||
|
// list of model IDs fetched from the Copilot API. It auto-creates aliases for
|
||||||
|
// models whose ID contains a dot (e.g. "claude-opus-4.6" → "claude-opus-4-6"),
|
||||||
|
// which is the pattern used by Claude models on Copilot.
|
||||||
|
func GitHubCopilotAliasesFromModels(modelIDs []string) []OAuthModelAlias {
|
||||||
|
var aliases []OAuthModelAlias
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
for _, id := range modelIDs {
|
||||||
|
if !strings.Contains(id, ".") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hyphenID := strings.ReplaceAll(id, ".", "-")
|
||||||
|
key := id + "→" + hyphenID
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
aliases = append(aliases, OAuthModelAlias{Name: id, Alias: hyphenID, Fork: true})
|
||||||
|
}
|
||||||
|
return aliases
|
||||||
|
}
|
||||||
@@ -1,316 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// antigravityModelConversionTable maps old built-in aliases to actual model names
|
|
||||||
// for the antigravity channel during migration.
|
|
||||||
var antigravityModelConversionTable = map[string]string{
|
|
||||||
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
|
|
||||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
|
||||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
|
||||||
"gemini-3-flash-preview": "gemini-3-flash",
|
|
||||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
|
||||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
|
||||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
|
||||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
|
||||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
|
||||||
// names so that clients like Claude Code can use standard names directly.
|
|
||||||
func defaultKiroAliases() []OAuthModelAlias {
|
|
||||||
return []OAuthModelAlias{
|
|
||||||
// Sonnet 4.6
|
|
||||||
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
|
|
||||||
// Sonnet 4.5
|
|
||||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
|
||||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
|
||||||
// Sonnet 4
|
|
||||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
|
||||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
|
||||||
// Opus 4.6
|
|
||||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
|
||||||
// Opus 4.5
|
|
||||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
|
||||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
|
||||||
// Haiku 4.5
|
|
||||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
|
||||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
|
|
||||||
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
|
|
||||||
// This keeps compatibility with clients (e.g. Claude Code) that use
|
|
||||||
// Anthropic-style model IDs like "claude-opus-4-6".
|
|
||||||
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
|
||||||
return []OAuthModelAlias{
|
|
||||||
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
|
||||||
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
|
||||||
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
|
||||||
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
|
||||||
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
|
||||||
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
|
||||||
// for the antigravity channel when neither field exists.
|
|
||||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
|
||||||
return []OAuthModelAlias{
|
|
||||||
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
|
|
||||||
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
|
|
||||||
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
|
|
||||||
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
|
|
||||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
|
||||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
|
||||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
|
||||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
|
|
||||||
// to oauth-model-alias at startup. Returns true if migration was performed.
|
|
||||||
//
|
|
||||||
// Migration flow:
|
|
||||||
// 1. Check if oauth-model-alias exists -> skip migration
|
|
||||||
// 2. Check if oauth-model-mappings exists -> convert and migrate
|
|
||||||
// - For antigravity channel, convert old built-in aliases to actual model names
|
|
||||||
//
|
|
||||||
// 3. Neither exists -> add default antigravity config
|
|
||||||
func MigrateOAuthModelAlias(configFile string) (bool, error) {
|
|
||||||
data, err := os.ReadFile(configFile)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if len(data) == 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse YAML into node tree to preserve structure
|
|
||||||
var root yaml.Node
|
|
||||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
rootMap := root.Content[0]
|
|
||||||
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if oauth-model-alias already exists
|
|
||||||
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if oauth-model-mappings exists
|
|
||||||
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
|
|
||||||
if oldIdx >= 0 {
|
|
||||||
// Migrate from old field
|
|
||||||
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Neither field exists - add default antigravity config
|
|
||||||
return addDefaultAntigravityConfig(configFile, &root, rootMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
|
|
||||||
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
|
|
||||||
if oldIdx+1 >= len(rootMap.Content) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
oldValue := rootMap.Content[oldIdx+1]
|
|
||||||
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the old aliases
|
|
||||||
oldAliases := parseOldAliasNode(oldValue)
|
|
||||||
if len(oldAliases) == 0 {
|
|
||||||
// Remove the old field and write
|
|
||||||
removeMapKeyByIndex(rootMap, oldIdx)
|
|
||||||
return writeYAMLNode(configFile, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert model names for antigravity channel
|
|
||||||
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
|
|
||||||
for channel, entries := range oldAliases {
|
|
||||||
converted := make([]OAuthModelAlias, 0, len(entries))
|
|
||||||
for _, entry := range entries {
|
|
||||||
newEntry := OAuthModelAlias{
|
|
||||||
Name: entry.Name,
|
|
||||||
Alias: entry.Alias,
|
|
||||||
Fork: entry.Fork,
|
|
||||||
}
|
|
||||||
// Convert model names for antigravity channel
|
|
||||||
if strings.EqualFold(channel, "antigravity") {
|
|
||||||
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
|
|
||||||
newEntry.Name = actual
|
|
||||||
}
|
|
||||||
}
|
|
||||||
converted = append(converted, newEntry)
|
|
||||||
}
|
|
||||||
newAliases[channel] = converted
|
|
||||||
}
|
|
||||||
|
|
||||||
// For antigravity channel, supplement missing default aliases
|
|
||||||
if antigravityEntries, exists := newAliases["antigravity"]; exists {
|
|
||||||
// Build a set of already configured model names (upstream names)
|
|
||||||
configuredModels := make(map[string]bool, len(antigravityEntries))
|
|
||||||
for _, entry := range antigravityEntries {
|
|
||||||
configuredModels[entry.Name] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add missing default aliases
|
|
||||||
for _, defaultAlias := range defaultAntigravityAliases() {
|
|
||||||
if !configuredModels[defaultAlias.Name] {
|
|
||||||
antigravityEntries = append(antigravityEntries, defaultAlias)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newAliases["antigravity"] = antigravityEntries
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build new node
|
|
||||||
newNode := buildOAuthModelAliasNode(newAliases)
|
|
||||||
|
|
||||||
// Replace old key with new key and value
|
|
||||||
rootMap.Content[oldIdx].Value = "oauth-model-alias"
|
|
||||||
rootMap.Content[oldIdx+1] = newNode
|
|
||||||
|
|
||||||
return writeYAMLNode(configFile, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addDefaultAntigravityConfig adds the default antigravity configuration
|
|
||||||
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
|
|
||||||
defaults := map[string][]OAuthModelAlias{
|
|
||||||
"antigravity": defaultAntigravityAliases(),
|
|
||||||
}
|
|
||||||
newNode := buildOAuthModelAliasNode(defaults)
|
|
||||||
|
|
||||||
// Add new key-value pair
|
|
||||||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
|
|
||||||
rootMap.Content = append(rootMap.Content, keyNode, newNode)
|
|
||||||
|
|
||||||
return writeYAMLNode(configFile, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseOldAliasNode parses the old oauth-model-mappings node structure
|
|
||||||
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
|
|
||||||
if node == nil || node.Kind != yaml.MappingNode {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
result := make(map[string][]OAuthModelAlias)
|
|
||||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
|
||||||
channelNode := node.Content[i]
|
|
||||||
entriesNode := node.Content[i+1]
|
|
||||||
if channelNode == nil || entriesNode == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
|
|
||||||
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
|
|
||||||
for _, entryNode := range entriesNode.Content {
|
|
||||||
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := parseAliasEntry(entryNode)
|
|
||||||
if entry.Name != "" && entry.Alias != "" {
|
|
||||||
entries = append(entries, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(entries) > 0 {
|
|
||||||
result[channel] = entries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseAliasEntry parses a single alias entry node
|
|
||||||
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
|
|
||||||
var entry OAuthModelAlias
|
|
||||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
|
||||||
keyNode := node.Content[i]
|
|
||||||
valNode := node.Content[i+1]
|
|
||||||
if keyNode == nil || valNode == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
|
|
||||||
case "name":
|
|
||||||
entry.Name = strings.TrimSpace(valNode.Value)
|
|
||||||
case "alias":
|
|
||||||
entry.Alias = strings.TrimSpace(valNode.Value)
|
|
||||||
case "fork":
|
|
||||||
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
|
|
||||||
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
|
|
||||||
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
|
||||||
for channel, entries := range aliases {
|
|
||||||
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
|
|
||||||
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
|
|
||||||
for _, entry := range entries {
|
|
||||||
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
|
||||||
entryNode.Content = append(entryNode.Content,
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
|
|
||||||
)
|
|
||||||
if entry.Fork {
|
|
||||||
entryNode.Content = append(entryNode.Content,
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
|
|
||||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
entriesNode.Content = append(entriesNode.Content, entryNode)
|
|
||||||
}
|
|
||||||
node.Content = append(node.Content, channelNode, entriesNode)
|
|
||||||
}
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
|
|
||||||
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
|
|
||||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeYAMLNode writes the YAML node tree back to file
|
|
||||||
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
|
|
||||||
f, err := os.Create(configFile)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
enc := yaml.NewEncoder(f)
|
|
||||||
enc.SetIndent(2)
|
|
||||||
if err := enc.Encode(root); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if err := enc.Close(); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
@@ -1,245 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `oauth-model-alias:
|
|
||||||
gemini-cli:
|
|
||||||
- name: "gemini-2.5-pro"
|
|
||||||
alias: "g2.5p"
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if migrated {
|
|
||||||
t.Fatal("expected no migration when oauth-model-alias already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify file unchanged
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
|
||||||
t.Fatal("file should still contain oauth-model-alias")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `oauth-model-mappings:
|
|
||||||
gemini-cli:
|
|
||||||
- name: "gemini-2.5-pro"
|
|
||||||
alias: "g2.5p"
|
|
||||||
fork: true
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify new field exists and old field removed
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
if strings.Contains(string(data), "oauth-model-mappings:") {
|
|
||||||
t.Fatal("old field should be removed")
|
|
||||||
}
|
|
||||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
|
||||||
t.Fatal("new field should exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse and verify structure
|
|
||||||
var root yaml.Node
|
|
||||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
// Use old model names that should be converted
|
|
||||||
content := `oauth-model-mappings:
|
|
||||||
antigravity:
|
|
||||||
- name: "gemini-2.5-computer-use-preview-10-2025"
|
|
||||||
alias: "computer-use"
|
|
||||||
- name: "gemini-3-pro-preview"
|
|
||||||
alias: "g3p"
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify model names were converted
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
content = string(data)
|
|
||||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
|
||||||
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "gemini-3-pro-high") {
|
|
||||||
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify missing default aliases were supplemented
|
|
||||||
if !strings.Contains(content, "gemini-3-pro-image") {
|
|
||||||
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "gemini-3-flash") {
|
|
||||||
t.Fatal("expected missing default alias gemini-3-flash to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-sonnet-4-5") {
|
|
||||||
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
|
|
||||||
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
|
||||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
|
||||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `debug: true
|
|
||||||
port: 8080
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to add default config")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify default antigravity config was added
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
content = string(data)
|
|
||||||
if !strings.Contains(content, "oauth-model-alias:") {
|
|
||||||
t.Fatal("expected oauth-model-alias to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "antigravity:") {
|
|
||||||
t.Fatal("expected antigravity channel to be added")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
|
||||||
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
content := `debug: true
|
|
||||||
port: 8080
|
|
||||||
oauth-model-mappings:
|
|
||||||
gemini-cli:
|
|
||||||
- name: "test"
|
|
||||||
alias: "t"
|
|
||||||
api-keys:
|
|
||||||
- "key1"
|
|
||||||
- "key2"
|
|
||||||
`
|
|
||||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !migrated {
|
|
||||||
t.Fatal("expected migration to occur")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify other config preserved
|
|
||||||
data, _ := os.ReadFile(configFile)
|
|
||||||
content = string(data)
|
|
||||||
if !strings.Contains(content, "debug: true") {
|
|
||||||
t.Fatal("expected debug field to be preserved")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "port: 8080") {
|
|
||||||
t.Fatal("expected port field to be preserved")
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "api-keys:") {
|
|
||||||
t.Fatal("expected api-keys field to be preserved")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for nonexistent file: %v", err)
|
|
||||||
}
|
|
||||||
if migrated {
|
|
||||||
t.Fatal("expected no migration for nonexistent file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
configFile := filepath.Join(dir, "config.yaml")
|
|
||||||
|
|
||||||
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if migrated {
|
|
||||||
t.Fatal("expected no migration for empty file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -9,6 +9,10 @@ type SDKConfig struct {
|
|||||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||||
|
|
||||||
|
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
||||||
|
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
||||||
|
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
|
||||||
|
|
||||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||||
// credentials as well.
|
// credentials as well.
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ type VertexCompatKey struct {
|
|||||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
// BaseURL optionally overrides the Vertex-compatible API endpoint.
|
||||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
// When empty, requests fall back to the default Vertex API base URL.
|
||||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||||
|
|
||||||
// ProxyURL optionally overrides the global proxy for this API key.
|
// ProxyURL optionally overrides the global proxy for this API key.
|
||||||
@@ -34,6 +34,9 @@ type VertexCompatKey struct {
|
|||||||
|
|
||||||
// Models defines the model configurations including aliases for routing.
|
// Models defines the model configurations including aliases for routing.
|
||||||
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||||
|
|
||||||
|
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||||
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
||||||
@@ -68,12 +71,9 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
|||||||
}
|
}
|
||||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
if entry.BaseURL == "" {
|
|
||||||
// BaseURL is required for Vertex API key entries
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
|
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||||
|
|
||||||
// Sanitize models: remove entries without valid alias
|
// Sanitize models: remove entries without valid alias
|
||||||
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/flate"
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
@@ -41,15 +42,17 @@ type RequestLogger interface {
|
|||||||
// - statusCode: The response status code
|
// - statusCode: The response status code
|
||||||
// - responseHeaders: The response headers
|
// - responseHeaders: The response headers
|
||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
|
// - websocketTimeline: Optional downstream websocket event timeline
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
|
// - apiWebsocketTimeline: Optional upstream websocket event timeline
|
||||||
// - requestID: Optional request ID for log file naming
|
// - requestID: Optional request ID for log file naming
|
||||||
// - requestTimestamp: When the request was received
|
// - requestTimestamp: When the request was received
|
||||||
// - apiResponseTimestamp: When the API response was received
|
// - apiResponseTimestamp: When the API response was received
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||||
|
|
||||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||||
//
|
//
|
||||||
@@ -111,6 +114,16 @@ type StreamingLogWriter interface {
|
|||||||
// - error: An error if writing fails, nil otherwise
|
// - error: An error if writing fails, nil otherwise
|
||||||
WriteAPIResponse(apiResponse []byte) error
|
WriteAPIResponse(apiResponse []byte) error
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline writes the upstream websocket timeline to the log.
|
||||||
|
// This should be called when upstream communication happened over websocket.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if writing fails, nil otherwise
|
||||||
|
WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error
|
||||||
|
|
||||||
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -203,17 +216,17 @@ func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
if !l.enabled && !force {
|
if !l.enabled && !force {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -260,8 +273,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
|||||||
requestHeaders,
|
requestHeaders,
|
||||||
body,
|
body,
|
||||||
requestBodyPath,
|
requestBodyPath,
|
||||||
|
websocketTimeline,
|
||||||
apiRequest,
|
apiRequest,
|
||||||
apiResponse,
|
apiResponse,
|
||||||
|
apiWebsocketTimeline,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
statusCode,
|
statusCode,
|
||||||
responseHeaders,
|
responseHeaders,
|
||||||
@@ -518,8 +533,10 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
requestHeaders map[string][]string,
|
requestHeaders map[string][]string,
|
||||||
requestBody []byte,
|
requestBody []byte,
|
||||||
requestBodyPath string,
|
requestBodyPath string,
|
||||||
|
websocketTimeline []byte,
|
||||||
apiRequest []byte,
|
apiRequest []byte,
|
||||||
apiResponse []byte,
|
apiResponse []byte,
|
||||||
|
apiWebsocketTimeline []byte,
|
||||||
apiResponseErrors []*interfaces.ErrorMessage,
|
apiResponseErrors []*interfaces.ErrorMessage,
|
||||||
statusCode int,
|
statusCode int,
|
||||||
responseHeaders map[string][]string,
|
responseHeaders map[string][]string,
|
||||||
@@ -531,7 +548,16 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
if requestTimestamp.IsZero() {
|
if requestTimestamp.IsZero() {
|
||||||
requestTimestamp = time.Now()
|
requestTimestamp = time.Now()
|
||||||
}
|
}
|
||||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||||
|
downstreamTransport := inferDownstreamTransport(requestHeaders, websocketTimeline)
|
||||||
|
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||||
|
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp, downstreamTransport, upstreamTransport, !isWebsocketTranscript); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(w, "=== WEBSOCKET TIMELINE ===\n", "=== WEBSOCKET TIMELINE", websocketTimeline, time.Time{}); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(w, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||||
@@ -543,6 +569,12 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if isWebsocketTranscript {
|
||||||
|
// Intentionally omit the generic downstream HTTP response section for websocket
|
||||||
|
// transcripts. The durable session exchange is captured in WEBSOCKET TIMELINE,
|
||||||
|
// and appending a one-off upgrade response snapshot would dilute that transcript.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,6 +585,9 @@ func writeRequestInfoWithBody(
|
|||||||
body []byte,
|
body []byte,
|
||||||
bodyPath string,
|
bodyPath string,
|
||||||
timestamp time.Time,
|
timestamp time.Time,
|
||||||
|
downstreamTransport string,
|
||||||
|
upstreamTransport string,
|
||||||
|
includeBody bool,
|
||||||
) error {
|
) error {
|
||||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
@@ -566,10 +601,20 @@ func writeRequestInfoWithBody(
|
|||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(downstreamTransport) != "" {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport)); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upstreamTransport) != "" {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport)); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -584,36 +629,121 @@ func writeRequestInfoWithBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, 1); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !includeBody {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyTrailingNewlines := 1
|
||||||
if bodyPath != "" {
|
if bodyPath != "" {
|
||||||
bodyFile, errOpen := os.Open(bodyPath)
|
bodyFile, errOpen := os.Open(bodyPath)
|
||||||
if errOpen != nil {
|
if errOpen != nil {
|
||||||
return errOpen
|
return errOpen
|
||||||
}
|
}
|
||||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
tracker := &trailingNewlineTrackingWriter{writer: w}
|
||||||
|
written, errCopy := io.Copy(tracker, bodyFile)
|
||||||
|
if errCopy != nil {
|
||||||
_ = bodyFile.Close()
|
_ = bodyFile.Close()
|
||||||
return errCopy
|
return errCopy
|
||||||
}
|
}
|
||||||
|
if written > 0 {
|
||||||
|
bodyTrailingNewlines = tracker.trailingNewlines
|
||||||
|
}
|
||||||
if errClose := bodyFile.Close(); errClose != nil {
|
if errClose := bodyFile.Close(); errClose != nil {
|
||||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||||
}
|
}
|
||||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
|
} else if len(body) > 0 {
|
||||||
|
bodyTrailingNewlines = countTrailingNewlinesBytes(body)
|
||||||
}
|
}
|
||||||
|
if errWrite := writeSectionSpacing(w, bodyTrailingNewlines); errWrite != nil {
|
||||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func countTrailingNewlinesBytes(payload []byte) int {
|
||||||
|
count := 0
|
||||||
|
for i := len(payload) - 1; i >= 0; i-- {
|
||||||
|
if payload[i] != '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSectionSpacing(w io.Writer, trailingNewlines int) error {
|
||||||
|
missingNewlines := 3 - trailingNewlines
|
||||||
|
if missingNewlines <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, errWrite := io.WriteString(w, strings.Repeat("\n", missingNewlines))
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
type trailingNewlineTrackingWriter struct {
|
||||||
|
writer io.Writer
|
||||||
|
trailingNewlines int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *trailingNewlineTrackingWriter) Write(payload []byte) (int, error) {
|
||||||
|
written, errWrite := t.writer.Write(payload)
|
||||||
|
if written > 0 {
|
||||||
|
writtenPayload := payload[:written]
|
||||||
|
trailingNewlines := countTrailingNewlinesBytes(writtenPayload)
|
||||||
|
if trailingNewlines == len(writtenPayload) {
|
||||||
|
t.trailingNewlines += trailingNewlines
|
||||||
|
} else {
|
||||||
|
t.trailingNewlines = trailingNewlines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return written, errWrite
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSectionPayload(payload []byte) bool {
|
||||||
|
return len(bytes.TrimSpace(payload)) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferDownstreamTransport(headers map[string][]string, websocketTimeline []byte) string {
|
||||||
|
if hasSectionPayload(websocketTimeline) {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
for key, values := range headers {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(key), "Upgrade") {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(value), "websocket") {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline []byte, _ []*interfaces.ErrorMessage) string {
|
||||||
|
hasHTTP := hasSectionPayload(apiRequest) || hasSectionPayload(apiResponse)
|
||||||
|
hasWS := hasSectionPayload(apiWebsocketTimeline)
|
||||||
|
switch {
|
||||||
|
case hasHTTP && hasWS:
|
||||||
|
return "websocket+http"
|
||||||
|
case hasWS:
|
||||||
|
return "websocket"
|
||||||
|
case hasHTTP:
|
||||||
|
return "http"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -623,11 +753,6 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
|
||||||
return errWrite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
@@ -640,12 +765,9 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
|
||||||
return errWrite
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, countTrailingNewlinesBytes(payload)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -662,12 +784,17 @@ func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMe
|
|||||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
trailingNewlines := 1
|
||||||
if apiResponseErrors[i].Error != nil {
|
if apiResponseErrors[i].Error != nil {
|
||||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
errText := apiResponseErrors[i].Error.Error()
|
||||||
|
if _, errWrite := io.WriteString(w, errText); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if errText != "" {
|
||||||
|
trailingNewlines = countTrailingNewlinesBytes([]byte(errText))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
if errWrite := writeSectionSpacing(w, trailingNewlines); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -694,12 +821,18 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
var bufferedReader *bufio.Reader
|
||||||
return errWrite
|
if responseReader != nil {
|
||||||
|
bufferedReader = bufio.NewReader(responseReader)
|
||||||
|
}
|
||||||
|
if !responseBodyStartsWithLeadingNewline(bufferedReader) {
|
||||||
|
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseReader != nil {
|
if bufferedReader != nil {
|
||||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
if _, errCopy := io.Copy(w, bufferedReader); errCopy != nil {
|
||||||
return errCopy
|
return errCopy
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -717,6 +850,19 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func responseBodyStartsWithLeadingNewline(reader *bufio.Reader) bool {
|
||||||
|
if reader == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if peeked, _ := reader.Peek(2); len(peeked) >= 2 && peeked[0] == '\r' && peeked[1] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if peeked, _ := reader.Peek(1); len(peeked) >= 1 && peeked[0] == '\n' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// formatLogContent creates the complete log content for non-streaming requests.
|
// formatLogContent creates the complete log content for non-streaming requests.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -724,6 +870,7 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
// - method: The HTTP method
|
// - method: The HTTP method
|
||||||
// - headers: The request headers
|
// - headers: The request headers
|
||||||
// - body: The request body
|
// - body: The request body
|
||||||
|
// - websocketTimeline: The downstream websocket event timeline
|
||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
// - response: The raw response data
|
// - response: The raw response data
|
||||||
@@ -732,11 +879,42 @@ func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, respo
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: The formatted log content
|
// - string: The formatted log content
|
||||||
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, websocketTimeline, apiRequest, apiResponse, apiWebsocketTimeline, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
|
||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
isWebsocketTranscript := hasSectionPayload(websocketTimeline)
|
||||||
|
downstreamTransport := inferDownstreamTransport(headers, websocketTimeline)
|
||||||
|
upstreamTransport := inferUpstreamTransport(apiRequest, apiResponse, apiWebsocketTimeline, apiResponseErrors)
|
||||||
|
|
||||||
// Request info
|
// Request info
|
||||||
content.WriteString(l.formatRequestInfo(url, method, headers, body))
|
content.WriteString(l.formatRequestInfo(url, method, headers, body, downstreamTransport, upstreamTransport, !isWebsocketTranscript))
|
||||||
|
|
||||||
|
if len(websocketTimeline) > 0 {
|
||||||
|
if bytes.HasPrefix(websocketTimeline, []byte("=== WEBSOCKET TIMELINE")) {
|
||||||
|
content.Write(websocketTimeline)
|
||||||
|
if !bytes.HasSuffix(websocketTimeline, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== WEBSOCKET TIMELINE ===\n")
|
||||||
|
content.Write(websocketTimeline)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(apiWebsocketTimeline) > 0 {
|
||||||
|
if bytes.HasPrefix(apiWebsocketTimeline, []byte("=== API WEBSOCKET TIMELINE")) {
|
||||||
|
content.Write(apiWebsocketTimeline)
|
||||||
|
if !bytes.HasSuffix(apiWebsocketTimeline, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== API WEBSOCKET TIMELINE ===\n")
|
||||||
|
content.Write(apiWebsocketTimeline)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
if len(apiRequest) > 0 {
|
if len(apiRequest) > 0 {
|
||||||
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
|
||||||
@@ -773,6 +951,12 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str
|
|||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isWebsocketTranscript {
|
||||||
|
// Mirror writeNonStreamingLog: websocket transcripts end with the dedicated
|
||||||
|
// timeline sections instead of a generic downstream HTTP response block.
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
// Response section
|
// Response section
|
||||||
content.WriteString("=== RESPONSE ===\n")
|
content.WriteString("=== RESPONSE ===\n")
|
||||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||||
@@ -933,13 +1117,19 @@ func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - string: The formatted request information
|
// - string: The formatted request information
|
||||||
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
|
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte, downstreamTransport string, upstreamTransport string, includeBody bool) string {
|
||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
|
||||||
content.WriteString("=== REQUEST INFO ===\n")
|
content.WriteString("=== REQUEST INFO ===\n")
|
||||||
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
||||||
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||||
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||||
|
if strings.TrimSpace(downstreamTransport) != "" {
|
||||||
|
content.WriteString(fmt.Sprintf("Downstream Transport: %s\n", downstreamTransport))
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(upstreamTransport) != "" {
|
||||||
|
content.WriteString(fmt.Sprintf("Upstream Transport: %s\n", upstreamTransport))
|
||||||
|
}
|
||||||
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
|
|
||||||
@@ -952,6 +1142,10 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
|||||||
}
|
}
|
||||||
content.WriteString("\n")
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
if !includeBody {
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
content.WriteString("=== REQUEST BODY ===\n")
|
content.WriteString("=== REQUEST BODY ===\n")
|
||||||
content.Write(body)
|
content.Write(body)
|
||||||
content.WriteString("\n\n")
|
content.WriteString("\n\n")
|
||||||
@@ -1011,6 +1205,9 @@ type FileStreamingLogWriter struct {
|
|||||||
// apiResponse stores the upstream API response data.
|
// apiResponse stores the upstream API response data.
|
||||||
apiResponse []byte
|
apiResponse []byte
|
||||||
|
|
||||||
|
// apiWebsocketTimeline stores the upstream websocket event timeline.
|
||||||
|
apiWebsocketTimeline []byte
|
||||||
|
|
||||||
// apiResponseTimestamp captures when the API response was received.
|
// apiResponseTimestamp captures when the API response was received.
|
||||||
apiResponseTimestamp time.Time
|
apiResponseTimestamp time.Time
|
||||||
}
|
}
|
||||||
@@ -1092,6 +1289,21 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline buffers the upstream websocket timeline for later writing.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil (buffering cannot fail)
|
||||||
|
func (w *FileStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
|
||||||
|
if len(apiWebsocketTimeline) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||||
if !timestamp.IsZero() {
|
if !timestamp.IsZero() {
|
||||||
w.apiResponseTimestamp = timestamp
|
w.apiResponseTimestamp = timestamp
|
||||||
@@ -1100,7 +1312,7 @@ func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
|||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
// It writes all buffered data to the file in the correct order:
|
// It writes all buffered data to the file in the correct order:
|
||||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
// API WEBSOCKET TIMELINE -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if closing fails, nil otherwise
|
// - error: An error if closing fails, nil otherwise
|
||||||
@@ -1182,7 +1394,10 @@ func (w *FileStreamingLogWriter) asyncWriter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp, "http", inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTimeline, nil), true); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
if errWrite := writeAPISection(logFile, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTimeline, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||||
@@ -1265,6 +1480,17 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIWebsocketTimeline is a no-op implementation that does nothing and always returns nil.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiWebsocketTimeline: The upstream websocket event timeline (ignored)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil
|
||||||
|
func (w *NoOpStreamingLogWriter) WriteAPIWebsocketTimeline(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||||
|
|
||||||
// Close is a no-op implementation that does nothing and always returns nil.
|
// Close is a no-op implementation that does nothing and always returns nil.
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||||
managementSyncMinInterval = 30 * time.Second
|
managementSyncMinInterval = 30 * time.Second
|
||||||
updateCheckInterval = 3 * time.Hour
|
updateCheckInterval = 3 * time.Hour
|
||||||
|
maxAssetDownloadSize = 50 << 20 // 10 MB safety limit for management asset downloads
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManagementFileName exposes the control panel asset filename.
|
// ManagementFileName exposes the control panel asset filename.
|
||||||
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
|
|||||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cfg.RemoteManagement.DisableAutoUpdatePanel {
|
||||||
|
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configPath, _ := schedulerConfigPath.Load().(string)
|
configPath, _ := schedulerConfigPath.Load().(string)
|
||||||
staticDir := StaticDir(configPath)
|
staticDir := StaticDir(configPath)
|
||||||
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
|||||||
}
|
}
|
||||||
|
|
||||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
|
||||||
|
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||||
return false
|
return false
|
||||||
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
|
|||||||
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("read download body: %w", err)
|
return nil, "", fmt.Errorf("read download body: %w", err)
|
||||||
}
|
}
|
||||||
|
if int64(len(data)) > maxAssetDownloadSize {
|
||||||
|
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
|
||||||
|
}
|
||||||
|
|
||||||
sum := sha256.Sum256(data)
|
sum := sha256.Sum256(data)
|
||||||
return data, hex.EncodeToString(sum[:]), nil
|
return data, hex.EncodeToString(sum[:]), nil
|
||||||
|
|||||||
151
internal/misc/antigravity_version.go
Normal file
151
internal/misc/antigravity_version.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
|
||||||
|
package misc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases"
|
||||||
|
antigravityFallbackVersion = "1.21.9"
|
||||||
|
antigravityVersionCacheTTL = 6 * time.Hour
|
||||||
|
antigravityFetchTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type antigravityRelease struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
ExecutionID string `json:"execution_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
cachedAntigravityVersion = antigravityFallbackVersion
|
||||||
|
antigravityVersionMu sync.RWMutex
|
||||||
|
antigravityVersionExpiry time.Time
|
||||||
|
antigravityUpdaterOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version.
|
||||||
|
// This is intentionally decoupled from request execution to avoid blocking executors on version lookups.
|
||||||
|
func StartAntigravityVersionUpdater(ctx context.Context) {
|
||||||
|
antigravityUpdaterOnce.Do(func() {
|
||||||
|
go runAntigravityVersionUpdater(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runAntigravityVersionUpdater(ctx context.Context) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(antigravityVersionCacheTTL / 2)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
log.Infof("periodic antigravity version refresh started (interval=%s)", antigravityVersionCacheTTL/2)
|
||||||
|
|
||||||
|
refreshAntigravityVersion(ctx)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
refreshAntigravityVersion(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshAntigravityVersion(ctx context.Context) {
|
||||||
|
version, errFetch := fetchAntigravityLatestVersion(ctx)
|
||||||
|
|
||||||
|
antigravityVersionMu.Lock()
|
||||||
|
defer antigravityVersionMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if errFetch == nil {
|
||||||
|
cachedAntigravityVersion = version
|
||||||
|
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||||
|
log.WithField("version", version).Info("fetched latest antigravity version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry) {
|
||||||
|
cachedAntigravityVersion = antigravityFallbackVersion
|
||||||
|
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
|
||||||
|
log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater.
|
||||||
|
// It falls back to antigravityFallbackVersion if the cache is empty or stale.
|
||||||
|
func AntigravityLatestVersion() string {
|
||||||
|
antigravityVersionMu.RLock()
|
||||||
|
if cachedAntigravityVersion != "" && time.Now().Before(antigravityVersionExpiry) {
|
||||||
|
v := cachedAntigravityVersion
|
||||||
|
antigravityVersionMu.RUnlock()
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
antigravityVersionMu.RUnlock()
|
||||||
|
|
||||||
|
return antigravityFallbackVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityUserAgent returns the User-Agent string for antigravity requests
|
||||||
|
// using the latest version fetched from the releases API.
|
||||||
|
func AntigravityUserAgent() string {
|
||||||
|
return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion())
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchAntigravityLatestVersion(ctx context.Context) (string, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: antigravityFetchTimeout}
|
||||||
|
|
||||||
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil)
|
||||||
|
if errReq != nil {
|
||||||
|
return "", fmt.Errorf("build antigravity releases request: %w", errReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, errDo := client.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("fetch antigravity releases: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.WithError(errClose).Warn("antigravity releases response body close error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var releases []antigravityRelease
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode antigravity releases response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(releases) == 0 {
|
||||||
|
return "", errors.New("antigravity releases API returned empty list")
|
||||||
|
}
|
||||||
|
|
||||||
|
version := releases[0].Version
|
||||||
|
if version == "" {
|
||||||
|
return "", errors.New("antigravity releases API returned empty version")
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
@@ -4,10 +4,98 @@
|
|||||||
package misc
|
package misc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
|
||||||
|
GeminiCLIVersion = "0.31.0"
|
||||||
|
|
||||||
|
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
|
||||||
|
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI.
|
||||||
|
func geminiCLIOS() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
return "win32"
|
||||||
|
default:
|
||||||
|
return runtime.GOOS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI.
|
||||||
|
func geminiCLIArch() string {
|
||||||
|
switch runtime.GOARCH {
|
||||||
|
case "amd64":
|
||||||
|
return "x64"
|
||||||
|
case "386":
|
||||||
|
return "x86"
|
||||||
|
default:
|
||||||
|
return runtime.GOARCH
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format.
|
||||||
|
// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable.
|
||||||
|
func GeminiCLIUserAgent(model string) string {
|
||||||
|
if model == "" {
|
||||||
|
model = "unknown"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
|
||||||
|
// proxy infrastructure, client identity, or browser fingerprints from an
|
||||||
|
// outgoing request. This ensures requests to upstream services look like they
|
||||||
|
// originate directly from a native client rather than a third-party client
|
||||||
|
// behind a reverse proxy.
|
||||||
|
func ScrubProxyAndFingerprintHeaders(req *http.Request) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Proxy tracing headers ---
|
||||||
|
req.Header.Del("X-Forwarded-For")
|
||||||
|
req.Header.Del("X-Forwarded-Host")
|
||||||
|
req.Header.Del("X-Forwarded-Proto")
|
||||||
|
req.Header.Del("X-Forwarded-Port")
|
||||||
|
req.Header.Del("X-Real-IP")
|
||||||
|
req.Header.Del("Forwarded")
|
||||||
|
req.Header.Del("Via")
|
||||||
|
|
||||||
|
// --- Client identity headers ---
|
||||||
|
req.Header.Del("X-Title")
|
||||||
|
req.Header.Del("X-Stainless-Lang")
|
||||||
|
req.Header.Del("X-Stainless-Package-Version")
|
||||||
|
req.Header.Del("X-Stainless-Os")
|
||||||
|
req.Header.Del("X-Stainless-Arch")
|
||||||
|
req.Header.Del("X-Stainless-Runtime")
|
||||||
|
req.Header.Del("X-Stainless-Runtime-Version")
|
||||||
|
req.Header.Del("Http-Referer")
|
||||||
|
req.Header.Del("Referer")
|
||||||
|
|
||||||
|
// --- Browser / Chromium fingerprint headers ---
|
||||||
|
// These are sent by Electron-based clients (e.g. CherryStudio) using the
|
||||||
|
// Fetch API, but NOT by Node.js https module (which Antigravity uses).
|
||||||
|
req.Header.Del("Sec-Ch-Ua")
|
||||||
|
req.Header.Del("Sec-Ch-Ua-Mobile")
|
||||||
|
req.Header.Del("Sec-Ch-Ua-Platform")
|
||||||
|
req.Header.Del("Sec-Fetch-Mode")
|
||||||
|
req.Header.Del("Sec-Fetch-Site")
|
||||||
|
req.Header.Del("Sec-Fetch-Dest")
|
||||||
|
req.Header.Del("Priority")
|
||||||
|
|
||||||
|
// --- Encoding negotiation ---
|
||||||
|
// Antigravity (Node.js) sends "gzip, deflate, br" by default;
|
||||||
|
// Electron-based clients may add "zstd" which is a fingerprint mismatch.
|
||||||
|
req.Header.Del("Accept-Encoding")
|
||||||
|
}
|
||||||
|
|
||||||
// EnsureHeader ensures that a header exists in the target header map by checking
|
// EnsureHeader ensures that a header exists in the target header map by checking
|
||||||
// multiple sources in order of priority: source headers, existing target headers,
|
// multiple sources in order of priority: source headers, existing target headers,
|
||||||
// and finally the default value. It only sets the header if it's not already present
|
// and finally the default value. It only sets the header if it's not already present
|
||||||
|
|||||||
@@ -30,6 +30,23 @@ type OAuthCallback struct {
|
|||||||
ErrorDescription string
|
ErrorDescription string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
|
||||||
|
// the result. The returned channels are buffered (size 1) so the goroutine can
|
||||||
|
// complete even if the caller abandons the channels.
|
||||||
|
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
|
||||||
|
inputCh := make(chan string, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
input, err := promptFn(message)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inputCh <- input
|
||||||
|
}()
|
||||||
|
return inputCh, errCh
|
||||||
|
}
|
||||||
|
|
||||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||||
// It returns nil when the input is empty.
|
// It returns nil when the input is empty.
|
||||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||||
|
|||||||
@@ -1,12 +1,222 @@
|
|||||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||||
// Static model metadata is stored in model_definitions_static_data.go.
|
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||||
|
type staticModelsJSON struct {
|
||||||
|
Claude []*ModelInfo `json:"claude"`
|
||||||
|
Gemini []*ModelInfo `json:"gemini"`
|
||||||
|
Vertex []*ModelInfo `json:"vertex"`
|
||||||
|
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||||
|
AIStudio []*ModelInfo `json:"aistudio"`
|
||||||
|
CodexFree []*ModelInfo `json:"codex-free"`
|
||||||
|
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||||
|
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||||
|
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||||
|
Kimi []*ModelInfo `json:"kimi"`
|
||||||
|
Antigravity []*ModelInfo `json:"antigravity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaudeModels returns the standard Claude model definitions.
|
||||||
|
func GetClaudeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Claude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiModels returns the standard Gemini model definitions.
|
||||||
|
func GetGeminiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||||
|
func GetGeminiVertexModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Vertex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||||
|
func GetGeminiCLIModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().GeminiCLI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIStudioModels returns model definitions for AI Studio.
|
||||||
|
func GetAIStudioModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().AIStudio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||||
|
func GetCodexFreeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexFree)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||||
|
func GetCodexTeamModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexTeam)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||||
|
func GetCodexPlusModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPlus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||||
|
func GetCodexProModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPro)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||||
|
func GetKimiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Kimi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModels returns the standard Antigravity model definitions.
|
||||||
|
func GetAntigravityModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Antigravity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
|
||||||
|
// These models are served through the copilot.tencent.com API.
|
||||||
|
func GetCodeBuddyModels() []*ModelInfo {
|
||||||
|
now := int64(1748044800) // 2025-05-24
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "auto",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Auto",
|
||||||
|
Description: "Automatic model selection via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5v-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5v Turbo",
|
||||||
|
Description: "GLM-5v Turbo via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.1",
|
||||||
|
Description: "GLM-5.1 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.0-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.0 Turbo",
|
||||||
|
Description: "GLM-5.0 Turbo via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-5.0",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-5.0",
|
||||||
|
Description: "GLM-5.0 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "glm-4.7",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "GLM-4.7",
|
||||||
|
Description: "GLM-4.7 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "minimax-m2.7",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "MiniMax M2.7",
|
||||||
|
Description: "MiniMax M2.7 via CodeBuddy",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Kimi K2.5",
|
||||||
|
Description: "Kimi K2.5 via CodeBuddy",
|
||||||
|
ContextLength: 256000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2-thinking",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "Kimi K2 Thinking",
|
||||||
|
Description: "Kimi K2 Thinking via CodeBuddy",
|
||||||
|
ContextLength: 256000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{ZeroAllowed: true},
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "deepseek-v3-2-volc",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "tencent",
|
||||||
|
Type: "codebuddy",
|
||||||
|
DisplayName: "DeepSeek V3.2 (Volc)",
|
||||||
|
Description: "DeepSeek V3.2 via CodeBuddy",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||||
|
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*ModelInfo, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
out[i] = cloneModelInfo(m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||||
// It returns nil when the channel is unknown.
|
// It returns nil when the channel is unknown.
|
||||||
//
|
//
|
||||||
@@ -17,13 +227,9 @@ import (
|
|||||||
// - gemini-cli
|
// - gemini-cli
|
||||||
// - aistudio
|
// - aistudio
|
||||||
// - codex
|
// - codex
|
||||||
// - qwen
|
|
||||||
// - iflow
|
|
||||||
// - kimi
|
// - kimi
|
||||||
// - kiro
|
|
||||||
// - kilo
|
// - kilo
|
||||||
// - github-copilot
|
// - github-copilot
|
||||||
// - kiro
|
|
||||||
// - amazonq
|
// - amazonq
|
||||||
// - antigravity (returns static overrides only)
|
// - antigravity (returns static overrides only)
|
||||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||||
@@ -40,11 +246,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "aistudio":
|
case "aistudio":
|
||||||
return GetAIStudioModels()
|
return GetAIStudioModels()
|
||||||
case "codex":
|
case "codex":
|
||||||
return GetOpenAIModels()
|
return GetCodexProModels()
|
||||||
case "qwen":
|
|
||||||
return GetQwenModels()
|
|
||||||
case "iflow":
|
|
||||||
return GetIFlowModels()
|
|
||||||
case "kimi":
|
case "kimi":
|
||||||
return GetKimiModels()
|
return GetKimiModels()
|
||||||
case "github-copilot":
|
case "github-copilot":
|
||||||
@@ -56,33 +258,28 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "amazonq":
|
case "amazonq":
|
||||||
return GetAmazonQModels()
|
return GetAmazonQModels()
|
||||||
case "antigravity":
|
case "antigravity":
|
||||||
cfg := GetAntigravityModelConfig()
|
return GetAntigravityModels()
|
||||||
if len(cfg) == 0 {
|
case "codebuddy":
|
||||||
return nil
|
return GetCodeBuddyModels()
|
||||||
}
|
case "cursor":
|
||||||
models := make([]*ModelInfo, 0, len(cfg))
|
return GetCursorModels()
|
||||||
for modelID, entry := range cfg {
|
|
||||||
if modelID == "" || entry == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
models = append(models, &ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Object: "model",
|
|
||||||
OwnedBy: "antigravity",
|
|
||||||
Type: "antigravity",
|
|
||||||
Thinking: entry.Thinking,
|
|
||||||
MaxCompletionTokens: entry.MaxCompletionTokens,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
sort.Slice(models, func(i, j int) bool {
|
|
||||||
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
|
|
||||||
})
|
|
||||||
return models
|
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCursorModels returns the fallback Cursor model definitions.
|
||||||
|
func GetCursorModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
|
||||||
|
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
|
||||||
|
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
|
||||||
|
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||||
// Returns nil if no matching model is found.
|
// Returns nil if no matching model is found.
|
||||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||||
@@ -90,45 +287,46 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data := getModels()
|
||||||
allModels := [][]*ModelInfo{
|
allModels := [][]*ModelInfo{
|
||||||
GetClaudeModels(),
|
data.Claude,
|
||||||
GetGeminiModels(),
|
data.Gemini,
|
||||||
GetGeminiVertexModels(),
|
data.Vertex,
|
||||||
GetGeminiCLIModels(),
|
data.GeminiCLI,
|
||||||
GetAIStudioModels(),
|
data.AIStudio,
|
||||||
GetOpenAIModels(),
|
data.CodexPro,
|
||||||
GetQwenModels(),
|
data.Kimi,
|
||||||
GetIFlowModels(),
|
data.Antigravity,
|
||||||
GetKimiModels(),
|
|
||||||
GetGitHubCopilotModels(),
|
GetGitHubCopilotModels(),
|
||||||
GetKiroModels(),
|
GetKiroModels(),
|
||||||
GetKiloModels(),
|
GetKiloModels(),
|
||||||
GetAmazonQModels(),
|
GetAmazonQModels(),
|
||||||
|
GetCodeBuddyModels(),
|
||||||
|
GetCursorModels(),
|
||||||
}
|
}
|
||||||
for _, models := range allModels {
|
for _, models := range allModels {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if m != nil && m.ID == modelID {
|
if m != nil && m.ID == modelID {
|
||||||
return m
|
return cloneModelInfo(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
|
||||||
return &ModelInfo{
|
|
||||||
ID: modelID,
|
|
||||||
Thinking: cfg.Thinking,
|
|
||||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultCopilotClaudeContextLength is the conservative prompt token limit for
|
||||||
|
// Claude models accessed via the GitHub Copilot API. Individual accounts are
|
||||||
|
// capped at 128K; business accounts at 168K. When the dynamic /models API fetch
|
||||||
|
// succeeds, the real per-account limit overrides this value. This constant is
|
||||||
|
// only used as a safe fallback.
|
||||||
|
const defaultCopilotClaudeContextLength = 128000
|
||||||
|
|
||||||
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
func GetGitHubCopilotModels() []*ModelInfo {
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
now := int64(1732752000) // 2024-11-27
|
now := int64(1732752000) // 2024-11-27
|
||||||
|
copilotClaudeEndpoints := []string{"/chat/completions", "/messages"}
|
||||||
gpt4oEntries := []struct {
|
gpt4oEntries := []struct {
|
||||||
ID string
|
ID string
|
||||||
DisplayName string
|
DisplayName string
|
||||||
@@ -152,6 +350,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +365,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: entry.Description,
|
Description: entry.Description,
|
||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,6 +500,32 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
SupportedEndpoints: []string{"/responses"},
|
SupportedEndpoints: []string{"/responses"},
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.4",
|
||||||
|
Description: "OpenAI GPT-5.4 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.4-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.4 mini",
|
||||||
|
Description: "OpenAI GPT-5.4 mini via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
SupportedEndpoints: []string{"/responses"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-haiku-4.5",
|
ID: "claude-haiku-4.5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -308,9 +534,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Haiku 4.5",
|
DisplayName: "Claude Haiku 4.5",
|
||||||
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.1",
|
ID: "claude-opus-4.1",
|
||||||
@@ -320,9 +546,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.1",
|
DisplayName: "Claude Opus 4.1",
|
||||||
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 32000,
|
MaxCompletionTokens: 32000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.5",
|
ID: "claude-opus-4.5",
|
||||||
@@ -332,9 +558,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.5",
|
DisplayName: "Claude Opus 4.5",
|
||||||
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4.6",
|
ID: "claude-opus-4.6",
|
||||||
@@ -344,9 +571,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Opus 4.6",
|
DisplayName: "Claude Opus 4.6",
|
||||||
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4",
|
ID: "claude-sonnet-4",
|
||||||
@@ -356,9 +584,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4",
|
DisplayName: "Claude Sonnet 4",
|
||||||
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4.5",
|
ID: "claude-sonnet-4.5",
|
||||||
@@ -368,9 +597,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4.5",
|
DisplayName: "Claude Sonnet 4.5",
|
||||||
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4.6",
|
ID: "claude-sonnet-4.6",
|
||||||
@@ -380,9 +610,10 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Claude Sonnet 4.6",
|
DisplayName: "Claude Sonnet 4.6",
|
||||||
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
|
||||||
ContextLength: 200000,
|
ContextLength: defaultCopilotClaudeContextLength,
|
||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: copilotClaudeEndpoints,
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-2.5-pro",
|
ID: "gemini-2.5-pro",
|
||||||
@@ -394,6 +625,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-pro-preview",
|
ID: "gemini-3-pro-preview",
|
||||||
@@ -405,6 +637,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3.1-pro-preview",
|
ID: "gemini-3.1-pro-preview",
|
||||||
@@ -414,8 +647,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Gemini 3.1 Pro (Preview)",
|
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||||
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 173000,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
@@ -425,8 +659,9 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
DisplayName: "Gemini 3 Flash (Preview)",
|
DisplayName: "Gemini 3 Flash (Preview)",
|
||||||
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
||||||
ContextLength: 1048576,
|
ContextLength: 173000,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "grok-code-fast-1",
|
ID: "grok-code-fast-1",
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user